Skip to content

Commit

Permalink
Refactor YOLOX (open-mmlab#6443)
Browse files Browse the repository at this point in the history
* Fix aug test error when the number of prediction bboxes is 0 (open-mmlab#6398)

* Fix aug test error when the number of prediction bboxes is 0

* test

* test

* fix lint

* Support custom pin_memory and persistent_workers

* [Docs] Chinese version of robustness_benchmarking.md (open-mmlab#6375)

* Chinese version of robustness_benchmarking.md

* Update docs_zh-CN/robustness_benchmarking.md

Co-authored-by: RangiLyu <lyuchqi@gmail.com>

* Update docs_zh-CN/robustness_benchmarking.md

Co-authored-by: RangiLyu <lyuchqi@gmail.com>

* Update docs_zh-CN/robustness_benchmarking.md

Co-authored-by: RangiLyu <lyuchqi@gmail.com>

* Update docs_zh-CN/robustness_benchmarking.md

Co-authored-by: RangiLyu <lyuchqi@gmail.com>

* Update docs_zh-CN/robustness_benchmarking.md

Co-authored-by: RangiLyu <lyuchqi@gmail.com>

* Update docs_zh-CN/robustness_benchmarking.md

Co-authored-by: RangiLyu <lyuchqi@gmail.com>

* Update robustness_benchmarking.md

* Update robustness_benchmarking.md

* Update robustness_benchmarking.md

* Update robustness_benchmarking.md

* Update robustness_benchmarking.md

* Update robustness_benchmarking.md

Co-authored-by: RangiLyu <lyuchqi@gmail.com>

* update yolox_s

* update yolox_s

* support dynamic eval interval

* fix some error

* support ceph

* fix none error

* fix batch error

* replace resize

* fix comment

* fix docstr

* Update the link of checkpoints (open-mmlab#6460)

* [Feature]: Support plot confusion matrix. (open-mmlab#6344)

* remove pin_memory

* update

* fix unittest

* update cfg

* fix error

* add unittest

* [Fix] Fix SpatialReductionAttention in PVT. (open-mmlab#6488)

* [Fix] Fix SpatialReductionAttention in PVT

* Add warning

* Save coco summarize print information to logger (open-mmlab#6505)

* Fix type error in 2_new_data_mode (open-mmlab#6469)

* Always map location to cpu when load checkpoint (open-mmlab#6405)

* configs: update groie README (open-mmlab#6401)

Signed-off-by: Leonardo Rossi <leonardo.rossi@unipr.it>

* [Fix] fix config path in docs (open-mmlab#6396)

* [Enchance] Set a random seed when the user does not set a seed. (open-mmlab#6457)

* fix random seed bug

* add comment

* enchance random seed

* rename

Co-authored-by: Haobo Yuan <yuanhaobo@whu.edu.cn>

* [BugFixed] fix wrong trunc_normal_init use (open-mmlab#6432)

* fix wrong trunc_normal_init use

* fix wrong trunc_normal_init use

* fix open-mmlab#6446

Co-authored-by: Uno Wu <st9007a@gmail.com>
Co-authored-by: Leonardo Rossi <leonardo.rossi@unipr.it>
Co-authored-by: BigDong <yudongwang@tju.edu.cn>
Co-authored-by: Haian Huang(深度眸) <1286304229@qq.com>
Co-authored-by: Haobo Yuan <yuanhaobo@whu.edu.cn>
Co-authored-by: Shusheng Yang <shusheng.yang@qq.com>

* bump version to v2.18.1 (open-mmlab#6510)

* bump version to v2.18.1

* Update changelog.md

* add some comment

* fix some comment

* update readme

* fix lint

* add reduce mean

* update

* update readme

* update params

Co-authored-by: Cedric Luo <luochunhua1996@outlook.com>
Co-authored-by: RangiLyu <lyuchqi@gmail.com>
Co-authored-by: Guangchen Lin <347630870@qq.com>
Co-authored-by: Andrea Panizza <8233615+AndreaPi@users.noreply.github.com>
Co-authored-by: Uno Wu <st9007a@gmail.com>
Co-authored-by: Leonardo Rossi <leonardo.rossi@unipr.it>
Co-authored-by: BigDong <yudongwang@tju.edu.cn>
Co-authored-by: Haobo Yuan <yuanhaobo@whu.edu.cn>
Co-authored-by: Shusheng Yang <shusheng.yang@qq.com>
  • Loading branch information
10 people authored and ZwwWayne committed Jul 18, 2022
1 parent 495732d commit fa5dde9
Show file tree
Hide file tree
Showing 20 changed files with 504 additions and 185 deletions.
4 changes: 2 additions & 2 deletions configs/yolox/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@

| Backbone | size | Mem (GB) | box AP | Config | Download |
|:---------:|:-------:|:-------:|:-------:|:--------:|:------:|
| YOLOX-Tiny | 416 | 3.6 | 31.6 | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/yolox/yolox_tiny_8x8_300e_coco.py) |[model](https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_tiny_8x8_300e_coco/yolox_tiny_8x8_300e_coco_20210806_234250-4ff3b67e.pth) &#124; [log](https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_tiny_8x8_300e_coco/yolox_tiny_8x8_300e_coco_20210806_234250.log.json) |
| YOLOX-s | 640 | 7.6 | 40.5 | [config](https://github.com/open-mmlab/mmdetection/tree/master/configs/yolox/yolox_s_8x8_300e_coco.py) |[model](https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_s_8x8_300e_coco/yolox_s_8x8_300e_coco_20211121_095711-4592a793.pth) &#124; [log](https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_s_8x8_300e_coco/yolox_s_8x8_300e_coco_20211121_095711.log.json) |


**Note**:

1. The test score threshold is 0.001.
2. We find that the performance is unstable and may fluctuate by about 0.7 mAP. We will continue to investigate and improve it.
11 changes: 6 additions & 5 deletions configs/yolox/metafile.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,17 @@ Collections:
URL: https://github.com/open-mmlab/mmdetection/blob/v2.15.1/mmdet/models/detectors/yolox.py#L6
Version: v2.15.1


Models:
- Name: yolox_tiny_8x8_300e_coco
- Name: yolox_s_8x8_300e_coco
In Collection: YOLOX
Config: configs/yolox/yolox_tiny_8x8_300e_coco.py
Config: configs/yolox/yolox_s_8x8_300e_coco.py
Metadata:
Training Memory (GB): 3.6
Training Memory (GB): 7.6
Epochs: 300
Results:
- Task: Object Detection
Dataset: COCO
Metrics:
box AP: 31.6
Weights: https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_tiny_8x8_300e_coco/yolox_tiny_8x8_300e_coco_20210806_234250-4ff3b67e.pth
box AP: 40.5
Weights: https://download.openmmlab.com/mmdetection/v2.0/yolox/yolox_s_8x8_300e_coco/yolox_s_8x8_300e_coco_20211121_095711-4592a793.pth
80 changes: 49 additions & 31 deletions configs/yolox/yolox_s_8x8_300e_coco.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
_base_ = ['../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py']

img_scale = (640, 640)

# model settings
model = dict(
type='YOLOX',
input_size=img_scale,
random_size_range=(15, 25),
random_size_interval=10,
backbone=dict(type='CSPDarknet', deepen_factor=0.33, widen_factor=0.5),
neck=dict(
type='YOLOXPAFPN',
Expand All @@ -20,11 +25,6 @@
data_root = 'data/coco/'
dataset_type = 'CocoDataset'

img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)

img_scale = (640, 640)

train_pipeline = [
dict(type='Mosaic', img_scale=img_scale, pad_val=114.0),
dict(
Expand All @@ -36,16 +36,19 @@
img_scale=img_scale,
ratio_range=(0.8, 1.6),
pad_val=114.0),
dict(
type='PhotoMetricDistortion',
brightness_delta=32,
contrast_range=(0.5, 1.5),
saturation_range=(0.5, 1.5),
hue_delta=18),
dict(type='YOLOXHSVRandomAug'),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='Resize', keep_ratio=True),
dict(type='Pad', pad_to_square=True, pad_val=114.0),
dict(type='Normalize', **img_norm_cfg),
# According to the official implementation, multi-scale
# training is not considered here but in the
# 'mmdet/models/detectors/yolox.py'.
dict(type='Resize', img_scale=img_scale, keep_ratio=True),
dict(
type='Pad',
pad_to_square=True,
# If the image is three-channel, the pad value needs
# to be set separately for each channel.
pad_val=dict(img=(114.0, 114.0, 114.0))),
dict(type='FilterAnnotations', min_gt_bbox_wh=(1, 1), keep_empty=False),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
]
Expand All @@ -57,13 +60,12 @@
ann_file=data_root + 'annotations/instances_train2017.json',
img_prefix=data_root + 'train2017/',
pipeline=[
dict(type='LoadImageFromFile', to_float32=True),
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True)
],
filter_empty_gt=False,
),
pipeline=train_pipeline,
dynamic_scale=img_scale)
pipeline=train_pipeline)

test_pipeline = [
dict(type='LoadImageFromFile'),
Expand All @@ -74,16 +76,19 @@
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Pad', size=img_scale, pad_val=114.0),
dict(type='Normalize', **img_norm_cfg),
dict(
type='Pad',
pad_to_square=True,
pad_val=dict(img=(114.0, 114.0, 114.0))),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img'])
])
]

data = dict(
samples_per_gpu=8,
workers_per_gpu=2,
workers_per_gpu=4,
persistent_workers=True,
train=train_dataset,
val=dict(
type=dataset_type,
Expand All @@ -107,6 +112,11 @@
paramwise_cfg=dict(norm_decay_mult=0., bias_decay_mult=0.))
optimizer_config = dict(grad_clip=None)

max_epochs = 300
num_last_epochs = 15
resume_from = None
interval = 10

# learning policy
lr_config = dict(
_delete_=True,
Expand All @@ -116,27 +126,35 @@
warmup_by_epoch=True,
warmup_ratio=1,
warmup_iters=5, # 5 epoch
num_last_epochs=15,
num_last_epochs=num_last_epochs,
min_lr_ratio=0.05)
runner = dict(type='EpochBasedRunner', max_epochs=300)

resume_from = None
interval = 10
runner = dict(type='EpochBasedRunner', max_epochs=max_epochs)

custom_hooks = [
dict(type='YOLOXModeSwitchHook', num_last_epochs=15, priority=48),
dict(
type='SyncRandomSizeHook',
ratio_range=(14, 26),
img_scale=img_scale,
type='YOLOXModeSwitchHook',
num_last_epochs=num_last_epochs,
priority=48),
dict(
type='SyncNormHook',
num_last_epochs=15,
num_last_epochs=num_last_epochs,
interval=interval,
priority=48),
dict(type='ExpMomentumEMAHook', resume_from=resume_from, priority=49)
dict(
type='ExpMomentumEMAHook',
resume_from=resume_from,
momentum=0.0001,
priority=49)
]
checkpoint_config = dict(interval=interval)
evaluation = dict(interval=interval, metric='bbox')
evaluation = dict(
save_best='auto',
# The evaluation interval is 'interval' when running epoch is
# less than ‘max_epochs - num_last_epochs’.
# The evaluation interval is 1 when running epoch is greater than
# or equal to ‘max_epochs - num_last_epochs’.
interval=interval,
dynamic_intervals=[(max_epochs - num_last_epochs, 1)],
metric='bbox')
log_config = dict(interval=50)
49 changes: 12 additions & 37 deletions configs/yolox/yolox_tiny_8x8_300e_coco.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,11 @@

# model settings
model = dict(
random_size_range=(10, 20),
backbone=dict(deepen_factor=0.33, widen_factor=0.375),
neck=dict(in_channels=[96, 192, 384], out_channels=96),
bbox_head=dict(in_channels=96, feat_channels=96))

# dataset settings
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)

img_scale = (640, 640)

train_pipeline = [
Expand All @@ -18,16 +15,14 @@
type='RandomAffine',
scaling_ratio_range=(0.5, 1.5),
border=(-img_scale[0] // 2, -img_scale[1] // 2)),
dict(
type='PhotoMetricDistortion',
brightness_delta=32,
contrast_range=(0.5, 1.5),
saturation_range=(0.5, 1.5),
hue_delta=18),
dict(type='YOLOXHSVRandomAug'),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='Resize', keep_ratio=True),
dict(type='Pad', pad_to_square=True, pad_val=114.0),
dict(type='Normalize', **img_norm_cfg),
dict(type='Resize', img_scale=img_scale, keep_ratio=True),
dict(
type='Pad',
pad_to_square=True,
pad_val=dict(img=(114.0, 114.0, 114.0))),
dict(type='FilterAnnotations', min_gt_bbox_wh=(1, 1), keep_empty=False),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
]
Expand All @@ -41,8 +36,10 @@
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Pad', size=(416, 416), pad_val=114.0),
dict(type='Normalize', **img_norm_cfg),
dict(
type='Pad',
pad_to_square=True,
pad_val=dict(img=(114.0, 114.0, 114.0))),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img'])
])
Expand All @@ -54,25 +51,3 @@
train=train_dataset,
val=dict(pipeline=test_pipeline),
test=dict(pipeline=test_pipeline))

resume_from = None
interval = 10

# Execute in the order of insertion when the priority is the same.
# The smaller the value, the higher the priority
custom_hooks = [
dict(type='YOLOXModeSwitchHook', num_last_epochs=15, priority=48),
dict(
type='SyncRandomSizeHook',
ratio_range=(10, 20),
img_scale=img_scale,
priority=48),
dict(
type='SyncNormHook',
num_last_epochs=15,
interval=interval,
priority=48),
dict(type='ExpMomentumEMAHook', resume_from=resume_from, priority=49)
]
checkpoint_config = dict(interval=interval)
evaluation = dict(interval=interval, metric='bbox')
3 changes: 2 additions & 1 deletion mmdet/core/bbox/assigners/sim_ota_assigner.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,8 @@ def get_in_gt_and_in_center_info(self, priors, gt_bboxes):
def dynamic_k_matching(self, cost, pairwise_ious, num_gt, valid_mask):
matching_matrix = torch.zeros_like(cost)
# select candidate topk ious for dynamic-k calculation
topk_ious, _ = torch.topk(pairwise_ious, self.candidate_topk, dim=0)
candidate_topk = min(self.candidate_topk, pairwise_ious.size(0))
topk_ious, _ = torch.topk(pairwise_ious, candidate_topk, dim=0)
# calculate dynamic k for each gt
dynamic_ks = torch.clamp(topk_ious.sum(0).int(), min=1)
for gt_idx in range(num_gt):
Expand Down
65 changes: 65 additions & 0 deletions mmdet/core/evaluation/eval_hooks.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,52 @@
# Copyright (c) OpenMMLab. All rights reserved.
import bisect
import os.path as osp

import mmcv
import torch.distributed as dist
from mmcv.runner import DistEvalHook as BaseDistEvalHook
from mmcv.runner import EvalHook as BaseEvalHook
from torch.nn.modules.batchnorm import _BatchNorm


def _calc_dynamic_intervals(start_interval, dynamic_interval_list):
assert mmcv.is_list_of(dynamic_interval_list, tuple)

dynamic_milestones = [0]
dynamic_milestones.extend(
[dynamic_interval[0] for dynamic_interval in dynamic_interval_list])
dynamic_intervals = [start_interval]
dynamic_intervals.extend(
[dynamic_interval[1] for dynamic_interval in dynamic_interval_list])
return dynamic_milestones, dynamic_intervals


class EvalHook(BaseEvalHook):

def __init__(self, *args, dynamic_intervals=None, **kwargs):
super(EvalHook, self).__init__(*args, **kwargs)

self.use_dynamic_intervals = dynamic_intervals is not None
if self.use_dynamic_intervals:
self.dynamic_milestones, self.dynamic_intervals = \
_calc_dynamic_intervals(self.interval, dynamic_intervals)

def _decide_interval(self, runner):
if self.use_dynamic_intervals:
progress = runner.epoch if self.by_epoch else runner.iter
step = bisect.bisect(self.dynamic_milestones, (progress + 1))
# Dynamically modify the evaluation interval
self.interval = self.dynamic_intervals[step - 1]

def before_train_epoch(self, runner):
"""Evaluate the model only at the start of training by epoch."""
self._decide_interval(runner)
super().before_train_epoch(runner)

def before_train_iter(self, runner):
self._decide_interval(runner)
super().before_train_iter(runner)

def _do_evaluate(self, runner):
"""perform evaluation and save ckpt."""
if not self._should_evaluate(runner):
Expand All @@ -22,8 +60,35 @@ def _do_evaluate(self, runner):
self._save_ckpt(runner, key_score)


# Note: Considering that MMCV's EvalHook updated its interface in V1.3.16,
# in order to avoid strong version dependency, we did not directly
# inherit EvalHook but BaseDistEvalHook.
class DistEvalHook(BaseDistEvalHook):

def __init__(self, *args, dynamic_intervals=None, **kwargs):
super(DistEvalHook, self).__init__(*args, **kwargs)

self.use_dynamic_intervals = dynamic_intervals is not None
if self.use_dynamic_intervals:
self.dynamic_milestones, self.dynamic_intervals = \
_calc_dynamic_intervals(self.interval, dynamic_intervals)

def _decide_interval(self, runner):
if self.use_dynamic_intervals:
progress = runner.epoch if self.by_epoch else runner.iter
step = bisect.bisect(self.dynamic_milestones, (progress + 1))
# Dynamically modify the evaluation interval
self.interval = self.dynamic_intervals[step - 1]

def before_train_epoch(self, runner):
"""Evaluate the model only at the start of training by epoch."""
self._decide_interval(runner)
super().before_train_epoch(runner)

def before_train_iter(self, runner):
self._decide_interval(runner)
super().before_train_iter(runner)

def _do_evaluate(self, runner):
"""perform evaluation and save ckpt."""
# Synchronization of BatchNorm's buffer (running_mean
Expand Down
14 changes: 14 additions & 0 deletions mmdet/core/hook/yolox_mode_switch_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def __init__(self,
skip_type_keys=('Mosaic', 'RandomAffine', 'MixUp')):
self.num_last_epochs = num_last_epochs
self.skip_type_keys = skip_type_keys
self._restart_dataloader = False

def before_train_epoch(self, runner):
"""Close mosaic and mixup augmentation and switches to use L1 loss."""
Expand All @@ -33,6 +34,19 @@ def before_train_epoch(self, runner):
model = model.module
if (epoch + 1) == runner.max_epochs - self.num_last_epochs:
runner.logger.info('No mosaic and mixup aug now!')
# The dataset pipeline cannot be updated when persistent_workers
# is True, so we need to force the dataloader's multi-process
# restart. This is a very hacky approach.
train_loader.dataset.update_skip_type_keys(self.skip_type_keys)
if hasattr(train_loader, 'persistent_workers'
) and train_loader.persistent_workers is True:
train_loader._DataLoader__initialized = False
train_loader._iterator = None
self._restart_dataloader = True
runner.logger.info('Add additional L1 loss now!')
model.bbox_head.use_l1 = True
else:
# Once the restart is complete, we need to restore
# the initialization flag.
if self._restart_dataloader:
train_loader._DataLoader__initialized = True
Loading

0 comments on commit fa5dde9

Please sign in to comment.