In [1]:
### Classes (10 in total):
### 'car', 'pedestrian', 'traffic sign', 'motorcycle', 'bus', 'truck', 'bicycle', 'traffic light', 'special vehicle', 'non'

In [2]:
from mmdet.apis import init_detector, inference_detector
import mmcv
from mmcv import Config


import copy
import os.path as osp

import numpy as np

from mmdet.datasets.builder import DATASETS
from mmdet.datasets.custom import CustomDataset

from mmdet.apis import set_random_seed


import json

In [3]:
@DATASETS.register_module()
class Auto_drive_dataset(CustomDataset):
    """2D bounding-box dataset whose annotations live in one JSON file per image.

    ``ann_file`` is a text file listing one JSON annotation path per line.
    From each JSON file this class reads ``Source_Image_Info`` (``Img_path``,
    ``Img_name``, ``Resolution``) and ``Annotation`` (a list of entries, each
    with a ``Label`` and a ``Coordinate``).
    """

    CLASSES = ('car', 'pedestrian', 'traffic sign', 'motorcycle', 'bus',
               'truck', 'bicycle', 'traffic light', 'special vehicle', 'non')

    def load_annotations(self, ann_file):
        """Read every JSON file listed in ``ann_file`` into per-image info dicts.

        Args:
            ann_file (str): Path of the text file that lists the JSON
                annotation files, one path per line.

        Returns:
            list[dict]: One dict per image with the keys ``filename``,
            ``width``, ``height`` and ``ann``. ``ann`` holds ``bboxes``
            (float32 array of shape (n, 4)) and ``labels`` (int64 array of
            shape (n,)).

        Raises:
            KeyError: If a JSON file lacks an expected key or contains a
                ``Label`` that is not listed in ``CLASSES``.
        """
        # Derive the label ids from CLASSES so the two can never drift apart.
        cat2label = {name: label for label, name in enumerate(self.CLASSES)}

        # Load the list of JSON annotation paths from the split file.
        image_list = mmcv.list_from_file(ann_file)
        total = len(image_list)

        data_infos = []
        for idx, json_path in enumerate(image_list):
            with open(json_path, "r") as json_file:
                json_data = json.load(json_file)

            source_info = json_data['Source_Image_Info']
            img_path = source_info['Img_path']
            # The image prefix from the config was not being applied
            # correctly, so the relative image path is spelled out here.
            filename = '2D_BB' + '/' + img_path[0:9] + '/' + img_path + '/' + source_info['Img_name']

            # NOTE(review): assumes Resolution is stored as [height, width];
            # confirm against the dataset specification.
            height, width = source_info['Resolution']

            # NOTE(review): Coordinate is passed through unchanged; confirm it
            # uses the box layout the detector expects.
            annotations = json_data['Annotation']
            gt_labels = [cat2label[anno['Label']] for anno in annotations]
            gt_bboxes = [anno['Coordinate'] for anno in annotations]

            data_infos.append(dict(
                filename=filename,
                width=width,
                height=height,
                ann=dict(
                    # reshape keeps the (0, 4) shape for images without boxes.
                    bboxes=np.array(gt_bboxes, dtype=np.float32).reshape(-1, 4),
                    # np.long is a deprecated alias; use an explicit 64-bit int.
                    labels=np.array(gt_labels, dtype=np.int64))))

            # Periodic progress report: files processed so far / total files.
            if idx != 0 and idx % 20000 == 0:
                print(str(idx) + '/' + str(total) + ' load annotations END!')

        return data_infos
    

In [4]:
## Additional fix: switched to the config that matches the pretrained checkpoint received earlier
cfg = Config.fromfile('mmdetection/configs/yolof/yolof_r50_c5_8x8_1x_coco.py') 

In [5]:
# Use the custom dataset type instead of the base config's dataset.
cfg.dataset_type = 'Auto_drive_dataset'
cfg.data_root = ''

# Single input scale shared by the train and test pipelines below.
cfg.img_scale = (384, 384)

## Added: training pipeline ##
cfg.train_pipeline = [
    {'type': 'LoadImageFromFile', 'to_float32': True},
    {'type': 'LoadAnnotations'},
    {
        'type': 'PhotoMetricDistortion',
        'brightness_delta': 32,
        'contrast_range': (0.5, 1.5),
        'saturation_range': (0.5, 1.5),
        'hue_delta': 18,
    },
    {'type': 'RandomFlip', 'flip_ratio': 0.5},
    {'type': 'Resize', 'keep_ratio': True, 'img_scale': cfg.img_scale},
    {'type': 'Pad', 'pad_to_square': True, 'pad_val': 114.0},
    {'type': 'Normalize', **cfg.img_norm_cfg},
    {'type': 'DefaultFormatBundle'},
    {'type': 'Collect', 'keys': ['img', 'gt_bboxes', 'gt_labels']},
]

# Test/validation pipeline: one scale, no flip augmentation.
cfg.test_pipeline = [
    {'type': 'LoadImageFromFile'},
    {
        'type': 'MultiScaleFlipAug',
        'img_scale': cfg.img_scale,
        'flip': False,
        'transforms': [
            {'type': 'Resize', 'keep_ratio': True},
            {'type': 'RandomFlip'},
            {'type': 'Pad', 'size': cfg.img_scale, 'pad_val': 114.0},
            {'type': 'Normalize', **cfg.img_norm_cfg},
            {'type': 'DefaultFormatBundle'},
            {'type': 'Collect', 'keys': ['img']},
        ],
    },
]

# Replace the data section wholesale: only train and val splits are defined.
cfg.data = {
    'samples_per_gpu': 48,
    'workers_per_gpu': 2,
    'train': {
        'type': cfg.dataset_type,
        'ann_file': 'splits/train.txt',
        'pipeline': cfg.train_pipeline,
    },
    'val': {
        'type': cfg.dataset_type,
        'ann_file': 'splits/val.txt',
        'pipeline': cfg.test_pipeline,
    },
}

## Added: start from the pretrained checkpoint ##
cfg.load_from = 'pretrain/yolof_r50_c5_8x8_1x_coco.pth'

cfg.work_dir = 'checkpoints'

cfg.evaluation.metric = 'mAP'

# Additional fix: the base config defaults to 80 classes, which raised an error.
cfg.model.bbox_head.num_classes = 10

cfg.log_config.interval = 1875  # in iterations
cfg.evaluation.interval = 1  # in epochs
cfg.checkpoint_config.interval = 1  # in epochs

# Seed the RNGs with the same value that is recorded in the config.
cfg.seed = 0
set_random_seed(cfg.seed, deterministic=False)
cfg.gpu_ids = range(1)

print(f'Config:\n{cfg.pretty_text}')

Config:
dataset_type = 'Auto_drive_dataset'
data_root = ''
img_norm_cfg = dict(
    mean=[103.53, 116.28, 123.675], std=[1.0, 1.0, 1.0], to_rgb=False)
train_pipeline = [
    dict(type='LoadImageFromFile', to_float32=True),
    dict(type='LoadAnnotations'),
    dict(
        type='PhotoMetricDistortion',
        brightness_delta=32,
        contrast_range=(0.5, 1.5),
        saturation_range=(0.5, 1.5),
        hue_delta=18),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(type='Resize', keep_ratio=True, img_scale=(384, 384)),
    dict(type='Pad', pad_to_square=True, pad_val=114.0),
    dict(
        type='Normalize',
        mean=[103.53, 116.28, 123.675],
        std=[1.0, 1.0, 1.0],
        to_rgb=False),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='MultiScaleFlipAug',
        img_scale=(384, 384),
        flip=False,
        transforms=[
   

In [6]:
from mmdet.datasets import build_dataset
from mmdet.models import build_detector
from mmdet.apis import train_detector
import torch

In [7]:
# Build the training dataset from the train split defined in cfg.data;
# train_detector receives it wrapped in a list.
datasets = [build_dataset(cfg.data.train)]

# Build the detector from the model section of the config
model = build_detector(
    cfg.model, train_cfg=cfg.get('train_cfg'), test_cfg=cfg.get('test_cfg'))
# Add an attribute for visualization convenience
# NOTE(review): the load log below reports a size mismatch for
# bbox_head.cls_score (400 vs 50 output channels), so that layer is not
# restored from the checkpoint; confirm its initialisation is as intended.
model.CLASSES = datasets[0].CLASSES

# Create work_dir
mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
### Stopped after only 4 epochs because of limited time.
train_detector(model, datasets, cfg, distributed=False, validate=True)

Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  'CustomDataset does not support filtering empty gt images.')
2021-11-19 03:42:15,625 - mmdet - INFO - load checkpoint from local path: pretrain/yolof_r50_c5_8x8_1x_coco.pth

size mismatch for bbox_head.cls_score.weight: copying a param with shape torch.Size([400, 512, 3, 3]) from checkpoint, the shape in current model is torch.Size([50, 512, 3, 3]).
size mismatch for bbox_head.cls_score.bias: copying a param with shape torch.Size([400]) from checkpoint, the shape in current model is torch.Size([50]).
2021-11-19 03:42:15,738 - mmdet - INFO - Start running, host: root@ae169e2ce13c, work_dir: /root/checkpoints
2021-11-19 03:42:15,739 - mmdet - INFO - Hooks will be executed in the following order:
before_run:
(VERY_HIGH   ) StepLrUpdaterHook                  
(NORMAL      ) CheckpointHook                     
(LOW         ) EvalHook                           
(VERY_LO

[>>>>>>>>>>>>>>>>>>>>>>>>>>] 9519/9519, 58.0 task/s, elapsed: 164s, ETA:     0s
---------------iou_thr: 0.5---------------


2021-11-19 04:30:58,489 - mmdet - INFO - 
+-----------------+-------+--------+--------+-------+
| class           | gts   | dets   | recall | ap    |
+-----------------+-------+--------+--------+-------+
| car             | 41250 | 109660 | 0.336  | 0.267 |
| pedestrian      | 2939  | 40823  | 0.145  | 0.112 |
| traffic sign    | 16540 | 76471  | 0.093  | 0.058 |
| motorcycle      | 501   | 24115  | 0.126  | 0.019 |
| bus             | 1604  | 86016  | 0.479  | 0.250 |
| truck           | 8922  | 91644  | 0.424  | 0.279 |
| bicycle         | 202   | 7749   | 0.248  | 0.070 |
| traffic light   | 7416  | 8752   | 0.019  | 0.004 |
| special vehicle | 596   | 28738  | 0.460  | 0.242 |
| non             | 4905  | 30665  | 0.013  | 0.002 |
+-----------------+-------+--------+--------+-------+
| mAP             |       |        |        | 0.130 |
+-----------------+-------+--------+--------+-------+
2021-11-19 04:30:58,519 - mmdet - INFO - Epoch(val) [1][9519]	AP50: 0.1300, mAP: 0.1303
2021-1

[>>>>>>>>>>>>>>>>>>>>>>>>>>] 9519/9519, 58.3 task/s, elapsed: 163s, ETA:     0s
---------------iou_thr: 0.5---------------


2021-11-19 05:17:11,528 - mmdet - INFO - 
+-----------------+-------+--------+--------+-------+
| class           | gts   | dets   | recall | ap    |
+-----------------+-------+--------+--------+-------+
| car             | 41250 | 117448 | 0.357  | 0.300 |
| pedestrian      | 2939  | 25346  | 0.174  | 0.121 |
| traffic sign    | 16540 | 51887  | 0.094  | 0.058 |
| motorcycle      | 501   | 16327  | 0.188  | 0.050 |
| bus             | 1604  | 39940  | 0.448  | 0.307 |
| truck           | 8922  | 102235 | 0.431  | 0.317 |
| bicycle         | 202   | 25110  | 0.287  | 0.106 |
| traffic light   | 7416  | 18986  | 0.029  | 0.012 |
| special vehicle | 596   | 27026  | 0.446  | 0.256 |
| non             | 4905  | 32851  | 0.016  | 0.003 |
+-----------------+-------+--------+--------+-------+
| mAP             |       |        |        | 0.153 |
+-----------------+-------+--------+--------+-------+
2021-11-19 05:17:11,561 - mmdet - INFO - Epoch(val) [2][9519]	AP50: 0.1530, mAP: 0.1531
2021-1

[>>>>>>>>>>>>>>>>>>>>>>>>>>] 9519/9519, 58.1 task/s, elapsed: 164s, ETA:     0s
---------------iou_thr: 0.5---------------


2021-11-19 06:00:57,856 - mmdet - INFO - 
+-----------------+-------+--------+--------+-------+
| class           | gts   | dets   | recall | ap    |
+-----------------+-------+--------+--------+-------+
| car             | 41250 | 100185 | 0.359  | 0.316 |
| pedestrian      | 2939  | 23850  | 0.158  | 0.129 |
| traffic sign    | 16540 | 49789  | 0.093  | 0.060 |
| motorcycle      | 501   | 19172  | 0.144  | 0.047 |
| bus             | 1604  | 42285  | 0.472  | 0.357 |
| truck           | 8922  | 102677 | 0.436  | 0.335 |
| bicycle         | 202   | 9851   | 0.277  | 0.132 |
| traffic light   | 7416  | 15707  | 0.019  | 0.004 |
| special vehicle | 596   | 35996  | 0.508  | 0.345 |
| non             | 4905  | 37607  | 0.018  | 0.003 |
+-----------------+-------+--------+--------+-------+
| mAP             |       |        |        | 0.173 |
+-----------------+-------+--------+--------+-------+
2021-11-19 06:00:57,886 - mmdet - INFO - Epoch(val) [3][9519]	AP50: 0.1730, mAP: 0.1727
2021-1

[>>>>>>>>>>>>>>>>>>>>>>>>>>] 9519/9519, 57.9 task/s, elapsed: 164s, ETA:     0s
---------------iou_thr: 0.5---------------


2021-11-19 06:44:51,207 - mmdet - INFO - 
+-----------------+-------+--------+--------+-------+
| class           | gts   | dets   | recall | ap    |
+-----------------+-------+--------+--------+-------+
| car             | 41250 | 134476 | 0.391  | 0.342 |
| pedestrian      | 2939  | 10928  | 0.174  | 0.111 |
| traffic sign    | 16540 | 54939  | 0.123  | 0.086 |
| motorcycle      | 501   | 8161   | 0.174  | 0.047 |
| bus             | 1604  | 45891  | 0.513  | 0.382 |
| truck           | 8922  | 120862 | 0.467  | 0.361 |
| bicycle         | 202   | 11033  | 0.317  | 0.209 |
| traffic light   | 7416  | 17112  | 0.033  | 0.011 |
| special vehicle | 596   | 30719  | 0.537  | 0.388 |
| non             | 4905  | 20379  | 0.023  | 0.010 |
+-----------------+-------+--------+--------+-------+
| mAP             |       |        |        | 0.195 |
+-----------------+-------+--------+--------+-------+
2021-11-19 06:44:51,237 - mmdet - INFO - Epoch(val) [4][9519]	AP50: 0.1950, mAP: 0.1946


KeyboardInterrupt: 