In [1]:
from mmdet.apis import init_detector
import mmcv
from mmcv import Config


import copy
import os.path as osp

import numpy as np

from mmdet.datasets.builder import DATASETS
from mmdet.datasets.custom import CustomDataset

from mmdet.apis import set_random_seed


import json

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import glob as _glob
import os
    
def glob(dir, pats, recursive=False):  # faster than match, python3 only
    """Collect paths under ``dir`` that match one or more glob patterns.

    Args:
        dir: Directory that every pattern is joined onto.
        pats: A single pattern string, or a list/tuple of pattern strings.
        recursive: Forwarded to ``glob.glob``.

    Returns:
        list: Matches for each pattern, concatenated in pattern order.
        Duplicates are kept when several patterns match the same path.
    """
    if not isinstance(pats, (list, tuple)):
        pats = [pats]
    return [
        hit
        for pattern in pats
        for hit in _glob.glob(os.path.join(dir, pattern), recursive=recursive)
    ]


In [3]:
#### The argument of load_annotations may be given any name inside the custom dataset class,
#### but in the configuration below the value must always be passed under the key `ann_file`.
@DATASETS.register_module()
class Drive_dataset(CustomDataset):
    """Detection dataset that reads one JSON label file per image.

    ``load_annotations`` is handed the value configured as ``ann_file`` and
    treats it as a directory of label files. Each JSON file is read for the
    keys ``image_name``, ``image_size`` (unpacked as ``width, height``) and
    ``Annotation`` (a list of objects carrying ``class_name`` and ``data``).
    Images are looked up in an ``images`` directory that is a sibling of the
    label directory.
    """

    # Class names; the integer label of a class is its position in this tuple.
    CLASSES=('car','bus','truck', 'special vehicle', 'motorcycle','bicycle','personal mobility','person','Traffic_light', 'Traffic_sign')


    def load_annotations(self, ann_fol):
        """Build the list of per-image annotation records.

        Args:
            ann_fol (str): Directory holding the JSON label files (the
                ``ann_file`` entry of the dataset config).

        Returns:
            list[dict]: One dict per label file with ``filename``, ``width``,
            ``height`` and ``ann``; ``ann`` holds ``bboxes`` as a float32
            array of shape (n, 4) and ``labels`` as an int64 array.

        Raises:
            KeyError: If a label file lacks an expected key or names a class
                that is not listed in ``CLASSES``.
        """
        # Single source of truth for name -> label id, derived from CLASSES.
        cat2label = {k: i for i, k in enumerate(self.CLASSES)}

        data_infos = []

        # Keep regular files only and sort them so that sample indices are
        # reproducible between runs (glob gives no ordering guarantee).
        ls = sorted(p for p in glob(ann_fol, '*', True) if osp.isfile(p))

        # Images live next to the label directory: <parent>/images/<image_name>.
        # Resolved from the directory itself so that absolute or deeper paths
        # work as well as the two-level relative layout.
        images_dir = osp.join(osp.dirname(osp.normpath(ann_fol)), 'images')

        for idx, an in enumerate(ls):
            with open(an, "r") as json_file:
                json_data = json.load(json_file)

            filename = osp.join(images_dir, json_data['image_name'])
            width, height = json_data['image_size']

            data_info = dict(filename=filename, width=width, height=height)

            gt_bboxes = []
            gt_labels = []

            for ann_data in json_data['Annotation']:
                gt_labels.append(cat2label[ann_data['class_name']])
                # NOTE(review): 'data' is presumably [x1, y1, x2, y2] in
                # pixels - confirm against the label specification.
                gt_bboxes.append(ann_data['data'])

            data_anno = dict(
                    bboxes=np.array(gt_bboxes, dtype=np.float32).reshape(-1, 4),
                    # np.long is deprecated since NumPy 1.20 and removed in 1.24.
                    labels=np.array(gt_labels, dtype=np.int64))

            data_info.update(ann=data_anno)
            data_infos.append(data_info)

            # Coarse progress report for large label directories.
            if idx!=0 and idx%20000==0:
                print(str(idx)+'/'+str(len(ls))+' load annotations END!')

        return data_infos

In [4]:
## Additional change: switched to the config that is meant to match the pretrained weights received earlier
cfg = Config.fromfile('UniverseNet/configs/waymo_open/universenet50_2008_fp16_4x4_mstrain_640_1280_1x_waymo_open_f0.py') 

In [5]:
# Show the config exactly as loaded from the file, before any overrides are applied.
print(f'Config:\n{cfg.pretty_text}')

Config:
pretrained_ckpt = 'https://github.com/shinya7y/weights/releases/download/v1.0.2/res2net50_v1b_26w_4s-3cf99910_mmdetv2-92ed3313.pth'
model = dict(
    type='GFL',
    backbone=dict(
        type='Res2Net',
        depth=50,
        scales=4,
        base_width=26,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        frozen_stages=1,
        norm_cfg=dict(type='SyncBN', requires_grad=True),
        norm_eval=False,
        style='pytorch',
        dcn=dict(type='DCN', deform_groups=1, fallback_on_stride=False),
        stage_with_dcn=(False, False, False, True),
        init_cfg=dict(
            type='Pretrained',
            checkpoint=
            'https://github.com/shinya7y/weights/releases/download/v1.0.2/res2net50_v1b_26w_4s-3cf99910_mmdetv2-92ed3313.pth'
        )),
    neck=[
        dict(
            type='FPN',
            in_channels=[256, 512, 1024, 2048],
            out_channels=256,
            start_level=1,
            add_extra_convs='on_output',
   

In [6]:
## Additions and overrides ##
cfg.dataset_type  = 'Drive_dataset'
cfg.data_root = ''

## Training runs on a single GPU, so use plain BN instead of SyncBN
cfg.model.backbone.norm_cfg=dict(type='BN', requires_grad=True)

## Train-style pipeline without augmentation (flip_ratio=0.0); defined so it can also be applied to the validation set
cfg.train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='LoadAnnotations', with_bbox=True),
    dict(
        type='Resize',
        img_scale=(1920, 1200),
        multiscale_mode='range',
        keep_ratio=True),
    dict(type='RandomFlip', flip_ratio=0.0),
    dict(
        type='Normalize',
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        to_rgb=True),
    dict(type='Pad', size_divisor=32),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
    
]

### Test pipeline, intended for later testing; it is also assigned to cfg.data.val below.
### Author's note: the validation pass of the workflow is run with the train-style pipeline above.
cfg.test_pipeline = [
    ### Pipeline for test-time augmentation, used at test time
    dict(type='LoadImageFromFile'),
    dict(
                type='MultiScaleFlipAug',
                img_scale=(1920, 1200),
                flip=False,
                transforms=[
                    dict(type='Resize', keep_ratio=True),
                    dict(type='RandomFlip'),
                    dict(
                        type='Normalize',
                        mean=[123.675, 116.28, 103.53],
                        std=[58.395, 57.12, 57.375],
                        to_rgb=True),
                    dict(type='Pad', size_divisor=32),
                    dict(type='ImageToTensor', keys=['img']),
                      dict(type='Collect', keys=['img'])
                ])
]


cfg.data=dict(
    samples_per_gpu=10, # batch size
    workers_per_gpu=12, # number of data-loading workers used to fetch batches
    # train dataset 
    train=dict(
        type=cfg.dataset_type,
        ann_file='2DBB/training/labels/',
        pipeline=cfg.train_pipeline),
    # validation dataset  
    val=dict(
        type=cfg.dataset_type,
        ann_file='2DBB/validation/labels',
        pipeline=cfg.test_pipeline),
    test=None)

## Number of classes
cfg.model.bbox_head.num_classes=10

## Device declaration for training on GPU
cfg.device='cuda'

## Where weights and training logs are saved
cfg.work_dir = 'checkpoints_Best_ver2'

## log interval
cfg.log_config.interval = 8000 # in iterations

cfg.seed = 2024

## Fix the random seed
set_random_seed(cfg.seed, deterministic=False)

cfg.workflow = [('train', 1), ('val',1)]

cfg.evaluation = dict(interval=1, metric='mAP')

### Path of the COCO-pretrained weights to load
cfg.load_from = 'universenet50_2008_fp16_4x4_mstrain_480_960_2x_coco_20200815_epoch_24-81356447.pth'

### Number of epochs
cfg.runner = dict(type='EpochBasedRunner', max_epochs=50)

### GPUs to use
cfg.gpu_ids = range(1)

In [7]:
# Print the config again to check that the overrides set in the previous cell took effect.
print(f'Config:\n{cfg.pretty_text}')

Config:
pretrained_ckpt = 'https://github.com/shinya7y/weights/releases/download/v1.0.2/res2net50_v1b_26w_4s-3cf99910_mmdetv2-92ed3313.pth'
model = dict(
    type='GFL',
    backbone=dict(
        type='Res2Net',
        depth=50,
        scales=4,
        base_width=26,
        num_stages=4,
        out_indices=(0, 1, 2, 3),
        frozen_stages=1,
        norm_cfg=dict(type='BN', requires_grad=True),
        norm_eval=False,
        style='pytorch',
        dcn=dict(type='DCN', deform_groups=1, fallback_on_stride=False),
        stage_with_dcn=(False, False, False, True),
        init_cfg=dict(
            type='Pretrained',
            checkpoint=
            'https://github.com/shinya7y/weights/releases/download/v1.0.2/res2net50_v1b_26w_4s-3cf99910_mmdetv2-92ed3313.pth'
        )),
    neck=[
        dict(
            type='FPN',
            in_channels=[256, 512, 1024, 2048],
            out_channels=256,
            start_level=1,
            add_extra_convs='on_output',
       

In [8]:
from mmdet.datasets import build_dataset
from mmdet.models import build_detector
from mmdet.apis import train_detector
import torch
import copy

In [9]:
# Build dataset
### validation ###
# Copy the val dataset config and give it the train-style pipeline (cfg.data.train.pipeline
# at this point) so the copy yields gt_bboxes / gt_labels; cfg.data.val itself is left unchanged.
val_dataset=copy.deepcopy(cfg.data.val)
val_dataset.pipeline=cfg.data.train.pipeline
val_ds = build_dataset(val_dataset)
### validation  ###


## Actual training pipeline, including augmentation
cfg.train_pipeline = [
    # dict(type='LoadImageFromFile'),
    # dict(type='LoadAnnotations', with_bbox=True),
    dict(type='Mosaic', img_scale= (320, 480)  , pad_val=114.0),  # given as (height, width), same as for crop
    dict(
        type='PhotoMetricDistortion',
        brightness_delta=32,
        contrast_range=(0.5, 1.5),
        saturation_range=(0.5, 1.5),
        hue_delta=18),
    dict(type='RandomFlip', flip_ratio=0.5),
    dict(
        type='Normalize',
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        to_rgb=True),
    dict(type='Pad', size_divisor=32),
    dict(type='FilterAnnotations', min_gt_bbox_wh=(1, 1), keep_empty=False),
    dict(type='DefaultFormatBundle'),
    dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
    
]

# cfg.data.train.pipeline = cfg.train_pipeline

### Oversample to mitigate class imbalance, plus mosaic augmentation
### MultiImageMixDataset needs the loading pipeline and the augmentation pipeline to be separate,
### so the train dataset is restructured as below and train_pipeline above was changed to match
cfg.data.train = dict(
       type='MultiImageMixDataset',
                    dataset=dict(
                        type='ClassBalancedDataset',
                        oversample_thr=0.1,
                        dataset=dict(
                            type=cfg.dataset_type,
                            ann_file='2DBB/training/labels/',
                            pipeline=[
                                dict(type='LoadImageFromFile'),
                                dict(type='LoadAnnotations', with_bbox=True)
                            ],
                            filter_empty_gt=False)),
                     pipeline = cfg.train_pipeline)


# Train dataset first, validation dataset second, in the same order as cfg.workflow.
datasets = [build_dataset(cfg.data.train), val_ds]

# Size of the wrapped training set.
print(len(datasets[0]))
# Build the detector
model = build_detector(cfg.model)

# Add an attribute for visualization convenience
model.CLASSES = datasets[0].CLASSES

# Metadata handed to train_detector: the full config text and the seed.
meta=dict()
meta['config'] = cfg.pretty_text
meta['seed'] = cfg.seed
# Create work_dir
mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))

## Run training
train_detector(model, datasets, cfg, distributed=False, validate=True, meta=meta)

Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  'CustomDataset does not support filtering empty gt images.')


20000/80000 load annotations END!
40000/80000 load annotations END!
60000/80000 load annotations END!


2024-08-14 12:36:24,471 - mmdet - INFO - image shape: height=320, width=480 in Mosaic.__init__


95202


2024-08-14 12:36:25,326 - mmdet - INFO - Automatic scaling of learning rate (LR) has been disabled.
2024-08-14 12:36:26,977 - mmdet - INFO - load checkpoint from local path: universenet50_2008_fp16_4x4_mstrain_480_960_2x_coco_20200815_epoch_24-81356447.pth

size mismatch for bbox_head.gfl_cls.weight: copying a param with shape torch.Size([80, 256, 3, 3]) from checkpoint, the shape in current model is torch.Size([10, 256, 3, 3]).
size mismatch for bbox_head.gfl_cls.bias: copying a param with shape torch.Size([80]) from checkpoint, the shape in current model is torch.Size([10]).
2024-08-14 12:36:27,050 - mmdet - INFO - Start running, host: root@de6c290dc761, work_dir: /root/checkpoints_Best_ver2
2024-08-14 12:36:27,051 - mmdet - INFO - Hooks will be executed in the following order:
before_run:
(VERY_HIGH   ) StepLrUpdaterHook                  
(ABOVE_NORMAL) Fp16OptimizerHook                  
(NORMAL      ) CheckpointHook                     
(LOW         ) EvalHook                     

[>>>>>>>>>>>>>>>>>>>>>>>>] 10000/10000, 15.4 task/s, elapsed: 648s, ETA:     0s
---------------iou_thr: 0.5---------------


2024-08-14 14:03:30,792 - mmdet - INFO - 
+-------------------+-------+--------+--------+-------+
| class             | gts   | dets   | recall | ap    |
+-------------------+-------+--------+--------+-------+
| car               | 32964 | 296290 | 0.884  | 0.696 |
| bus               | 979   | 34646  | 0.637  | 0.373 |
| truck             | 5874  | 147325 | 0.798  | 0.514 |
| special vehicle   | 202   | 13160  | 0.337  | 0.066 |
| motorcycle        | 235   | 5536   | 0.455  | 0.208 |
| bicycle           | 62    | 2073   | 0.387  | 0.203 |
| personal mobility | 43    | 6902   | 0.651  | 0.128 |
| person            | 7347  | 142852 | 0.797  | 0.559 |
| Traffic_light     | 3298  | 92718  | 0.870  | 0.465 |
| Traffic_sign      | 2889  | 162428 | 0.869  | 0.653 |
+-------------------+-------+--------+--------+-------+
| mAP               |       |        |        | 0.386 |
+-------------------+-------+--------+--------+-------+
2024-08-14 14:03:31,071 - mmdet - INFO - Epoch(val) [1][10000]

[>>>>>>>>>>>>>>>>>>>>>>>>] 10000/10000, 15.3 task/s, elapsed: 655s, ETA:     0s
---------------iou_thr: 0.5---------------


2024-08-14 15:42:01,970 - mmdet - INFO - 
+-------------------+-------+--------+--------+-------+
| class             | gts   | dets   | recall | ap    |
+-------------------+-------+--------+--------+-------+
| car               | 32964 | 376682 | 0.907  | 0.733 |
| bus               | 979   | 25865  | 0.623  | 0.392 |
| truck             | 5874  | 195182 | 0.811  | 0.522 |
| special vehicle   | 202   | 11232  | 0.272  | 0.078 |
| motorcycle        | 235   | 7639   | 0.468  | 0.209 |
| bicycle           | 62    | 1364   | 0.306  | 0.125 |
| personal mobility | 43    | 4626   | 0.558  | 0.179 |
| person            | 7347  | 93827  | 0.754  | 0.531 |
| Traffic_light     | 3298  | 46465  | 0.767  | 0.448 |
| Traffic_sign      | 2889  | 78465  | 0.804  | 0.587 |
+-------------------+-------+--------+--------+-------+
| mAP               |       |        |        | 0.380 |
+-------------------+-------+--------+--------+-------+
2024-08-14 15:42:02,288 - mmdet - INFO - Epoch(val) [2][10000]

[>>>>>>>>>>>>>>>>>>>>>>>>] 10000/10000, 15.1 task/s, elapsed: 663s, ETA:     0s
---------------iou_thr: 0.5---------------


2024-08-14 17:20:49,532 - mmdet - INFO - 
+-------------------+-------+--------+--------+-------+
| class             | gts   | dets   | recall | ap    |
+-------------------+-------+--------+--------+-------+
| car               | 32964 | 331445 | 0.858  | 0.620 |
| bus               | 979   | 23690  | 0.640  | 0.406 |
| truck             | 5874  | 160936 | 0.804  | 0.506 |
| special vehicle   | 202   | 12392  | 0.312  | 0.068 |
| motorcycle        | 235   | 6474   | 0.481  | 0.276 |
| bicycle           | 62    | 1939   | 0.387  | 0.200 |
| personal mobility | 43    | 4210   | 0.512  | 0.077 |
| person            | 7347  | 67612  | 0.719  | 0.527 |
| Traffic_light     | 3298  | 37435  | 0.775  | 0.451 |
| Traffic_sign      | 2889  | 47584  | 0.799  | 0.634 |
+-------------------+-------+--------+--------+-------+
| mAP               |       |        |        | 0.376 |
+-------------------+-------+--------+--------+-------+
2024-08-14 17:20:49,545 - mmdet - INFO - Epoch(val) [3][10000]

[>>>>>>>>>>>>>>>>>>>>>>>>] 10000/10000, 15.1 task/s, elapsed: 663s, ETA:     0s
---------------iou_thr: 0.5---------------


2024-08-14 18:59:45,345 - mmdet - INFO - 
+-------------------+-------+--------+--------+-------+
| class             | gts   | dets   | recall | ap    |
+-------------------+-------+--------+--------+-------+
| car               | 32964 | 456109 | 0.833  | 0.638 |
| bus               | 979   | 17287  | 0.605  | 0.350 |
| truck             | 5874  | 198194 | 0.784  | 0.484 |
| special vehicle   | 202   | 11621  | 0.213  | 0.030 |
| motorcycle        | 235   | 3999   | 0.302  | 0.177 |
| bicycle           | 62    | 1260   | 0.161  | 0.101 |
| personal mobility | 43    | 3757   | 0.372  | 0.051 |
| person            | 7347  | 43845  | 0.613  | 0.431 |
| Traffic_light     | 3298  | 48193  | 0.743  | 0.464 |
| Traffic_sign      | 2889  | 97578  | 0.798  | 0.606 |
+-------------------+-------+--------+--------+-------+
| mAP               |       |        |        | 0.333 |
+-------------------+-------+--------+--------+-------+
2024-08-14 18:59:45,682 - mmdet - INFO - Epoch(val) [4][10000]

[>>>>>>>>>>>>>>>>>>>>>>>>] 10000/10000, 14.9 task/s, elapsed: 673s, ETA:     0s
---------------iou_thr: 0.5---------------


2024-08-14 20:38:54,708 - mmdet - INFO - 
+-------------------+-------+--------+--------+-------+
| class             | gts   | dets   | recall | ap    |
+-------------------+-------+--------+--------+-------+
| car               | 32964 | 483135 | 0.811  | 0.609 |
| bus               | 979   | 25390  | 0.534  | 0.251 |
| truck             | 5874  | 142998 | 0.724  | 0.435 |
| special vehicle   | 202   | 15959  | 0.198  | 0.030 |
| motorcycle        | 235   | 3406   | 0.260  | 0.126 |
| bicycle           | 62    | 933    | 0.177  | 0.107 |
| personal mobility | 43    | 2195   | 0.302  | 0.036 |
| person            | 7347  | 112949 | 0.672  | 0.454 |
| Traffic_light     | 3298  | 35711  | 0.741  | 0.434 |
| Traffic_sign      | 2889  | 45379  | 0.739  | 0.587 |
+-------------------+-------+--------+--------+-------+
| mAP               |       |        |        | 0.307 |
+-------------------+-------+--------+--------+-------+
2024-08-14 20:38:54,721 - mmdet - INFO - Epoch(val) [5][10000]

[>>>>>>>>>>>>>>>>>>>>>>>>] 10000/10000, 14.9 task/s, elapsed: 671s, ETA:     0s
---------------iou_thr: 0.5---------------


2024-08-14 22:18:00,511 - mmdet - INFO - 
+-------------------+-------+--------+--------+-------+
| class             | gts   | dets   | recall | ap    |
+-------------------+-------+--------+--------+-------+
| car               | 32964 | 216749 | 0.763  | 0.580 |
| bus               | 979   | 27217  | 0.520  | 0.289 |
| truck             | 5874  | 176109 | 0.732  | 0.418 |
| special vehicle   | 202   | 8837   | 0.104  | 0.014 |
| motorcycle        | 235   | 3268   | 0.204  | 0.103 |
| bicycle           | 62    | 1493   | 0.210  | 0.066 |
| personal mobility | 43    | 3175   | 0.163  | 0.024 |
| person            | 7347  | 47067  | 0.592  | 0.384 |
| Traffic_light     | 3298  | 39011  | 0.700  | 0.384 |
| Traffic_sign      | 2889  | 83505  | 0.772  | 0.601 |
+-------------------+-------+--------+--------+-------+
| mAP               |       |        |        | 0.286 |
+-------------------+-------+--------+--------+-------+
2024-08-14 22:18:00,850 - mmdet - INFO - Epoch(val) [6][10000]

[>>>>>>>>>>>>>>>>>>>>>>>>] 10000/10000, 14.7 task/s, elapsed: 680s, ETA:     0s
---------------iou_thr: 0.5---------------


2024-08-14 23:57:13,795 - mmdet - INFO - 
+-------------------+-------+--------+--------+-------+
| class             | gts   | dets   | recall | ap    |
+-------------------+-------+--------+--------+-------+
| car               | 32964 | 346813 | 0.781  | 0.561 |
| bus               | 979   | 17418  | 0.464  | 0.233 |
| truck             | 5874  | 179425 | 0.697  | 0.396 |
| special vehicle   | 202   | 12626  | 0.099  | 0.006 |
| motorcycle        | 235   | 4970   | 0.200  | 0.055 |
| bicycle           | 62    | 1452   | 0.161  | 0.005 |
| personal mobility | 43    | 3409   | 0.233  | 0.025 |
| person            | 7347  | 97756  | 0.611  | 0.379 |
| Traffic_light     | 3298  | 39164  | 0.654  | 0.408 |
| Traffic_sign      | 2889  | 45849  | 0.716  | 0.535 |
+-------------------+-------+--------+--------+-------+
| mAP               |       |        |        | 0.260 |
+-------------------+-------+--------+--------+-------+
2024-08-14 23:57:13,801 - mmdet - INFO - Epoch(val) [7][10000]

[>>>>>>>>>>>>>>>>>>>>>>>>] 10000/10000, 14.8 task/s, elapsed: 676s, ETA:     0s
---------------iou_thr: 0.5---------------


2024-08-15 01:36:16,754 - mmdet - INFO - 
+-------------------+-------+--------+--------+-------+
| class             | gts   | dets   | recall | ap    |
+-------------------+-------+--------+--------+-------+
| car               | 32964 | 396139 | 0.756  | 0.562 |
| bus               | 979   | 16477  | 0.465  | 0.264 |
| truck             | 5874  | 129538 | 0.676  | 0.398 |
| special vehicle   | 202   | 20152  | 0.163  | 0.018 |
| motorcycle        | 235   | 5846   | 0.226  | 0.071 |
| bicycle           | 62    | 2004   | 0.194  | 0.033 |
| personal mobility | 43    | 5371   | 0.326  | 0.026 |
| person            | 7347  | 129396 | 0.615  | 0.362 |
| Traffic_light     | 3298  | 25343  | 0.672  | 0.411 |
| Traffic_sign      | 2889  | 59831  | 0.751  | 0.589 |
+-------------------+-------+--------+--------+-------+
| mAP               |       |        |        | 0.273 |
+-------------------+-------+--------+--------+-------+
2024-08-15 01:36:16,769 - mmdet - INFO - Epoch(val) [8][10000]

KeyboardInterrupt: 