Skip to content

Commit

Permalink
[Feature] Release dyhead big model. (open-mmlab#7733)
Browse files Browse the repository at this point in the history
* Release dyhead with big model

* Update new config

* Update config

* Fix lint

* Update

* Update
  • Loading branch information
jbwang1997 authored and ZwwWayne committed Jul 19, 2022
1 parent f261664 commit 96be838
Show file tree
Hide file tree
Showing 3 changed files with 183 additions and 0 deletions.
6 changes: 6 additions & 0 deletions configs/dyhead/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,12 @@ The complex nature of combining localization and classification in object detect
We have not conduct ablation study between the two settings.
`dict(type='Pad', size_divisor=128)` may further improve AP by prefer spatial alignment across pyramid levels, although large padding reduces efficiency.

We also trained the model with Swin-L backbone. Results are as below.

| Method | Backbone | Style | Setting | Lr schd | mstrain | box AP | Config | Download |
|:------:|:--------:|:-------:|:------------:|:-------:|:-------:|:------:|:------:|:--------:|
| ATSS | Swin-L | caffe | reproduction | 2x | 480~1200| 56.2 | [config](./atss_swin-l-p4-w12_fpn_dyhead_mstrain_2x_coco.py) | [model](https://download.openmmlab.com/mmdetection/v2.0/dyhead/atss_swin-l-p4-w12_fpn_dyhead_mstrain_2x_coco/atss_swin-l-p4-w12_fpn_dyhead_mstrain_2x_coco_20220509_100315-bc5b6516.pth) | [log](https://download.openmmlab.com/mmdetection/v2.0/dyhead/atss_swin-l-p4-w12_fpn_dyhead_mstrain_2x_coco/atss_swin-l-p4-w12_fpn_dyhead_mstrain_2x_coco_20220509_100315.log.json) |

## Relation to Other Methods

- DyHead can be regarded as an improved [SEPC](https://arxiv.org/abs/2005.03101) with [DyReLU modules](https://arxiv.org/abs/2003.10027) and simplified [SE blocks](https://arxiv.org/abs/1709.01507).
Expand Down
164 changes: 164 additions & 0 deletions configs/dyhead/atss_swin-l-p4-w12_fpn_dyhead_mstrain_2x_coco.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
_base_ = '../_base_/default_runtime.py'

pretrained = 'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22k.pth' # noqa
model = dict(
type='ATSS',
backbone=dict(
type='SwinTransformer',
pretrain_img_size=384,
embed_dims=192,
depths=[2, 2, 18, 2],
num_heads=[6, 12, 24, 48],
window_size=12,
mlp_ratio=4,
qkv_bias=True,
qk_scale=None,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.2,
patch_norm=True,
out_indices=(1, 2, 3),
# Please only add indices that would be used
# in FPN, otherwise some parameter will not be used
with_cp=False,
convert_weights=True,
init_cfg=dict(type='Pretrained', checkpoint=pretrained)),
neck=[
dict(
type='FPN',
in_channels=[384, 768, 1536],
out_channels=256,
start_level=0,
add_extra_convs='on_output',
num_outs=5),
dict(
type='DyHead',
in_channels=256,
out_channels=256,
num_blocks=6,
# disable zero_init_offset to follow official implementation
zero_init_offset=False)
],
bbox_head=dict(
type='ATSSHead',
num_classes=80,
in_channels=256,
pred_kernel_size=1, # follow DyHead official implementation
stacked_convs=0,
feat_channels=256,
anchor_generator=dict(
type='AnchorGenerator',
ratios=[1.0],
octave_base_scale=8,
scales_per_octave=1,
strides=[8, 16, 32, 64, 128],
center_offset=0.5), # follow DyHead official implementation
bbox_coder=dict(
type='DeltaXYWHBBoxCoder',
target_means=[.0, .0, .0, .0],
target_stds=[0.1, 0.1, 0.2, 0.2]),
loss_cls=dict(
type='FocalLoss',
use_sigmoid=True,
gamma=2.0,
alpha=0.25,
loss_weight=1.0),
loss_bbox=dict(type='GIoULoss', loss_weight=2.0),
loss_centerness=dict(
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0)),
# training and testing settings
train_cfg=dict(
assigner=dict(type='ATSSAssigner', topk=9),
allowed_border=-1,
pos_weight=-1,
debug=False),
test_cfg=dict(
nms_pre=1000,
min_bbox_size=0,
score_thr=0.05,
nms=dict(type='nms', iou_threshold=0.6),
max_per_img=100))

# dataset settings
dataset_type = 'CocoDataset'
data_root = 'data/coco/'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)

train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations', with_bbox=True),
dict(
type='Resize',
img_scale=[(2000, 480), (2000, 1200)],
multiscale_mode='range',
keep_ratio=True,
backend='pillow'),
dict(type='RandomFlip', flip_ratio=0.5),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=128),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(2000, 1200),
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True, backend='pillow'),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size_divisor=128),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]

# Use RepeatDataset to speed up training
data = dict(
samples_per_gpu=2,
workers_per_gpu=2,
train=dict(
type='RepeatDataset',
times=2,
dataset=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_train2017.json',
img_prefix=data_root + 'train2017/',
pipeline=train_pipeline)),
val=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_val2017.json',
img_prefix=data_root + 'val2017/',
pipeline=test_pipeline),
test=dict(
type=dataset_type,
ann_file=data_root + 'annotations/instances_val2017.json',
img_prefix=data_root + 'val2017/',
pipeline=test_pipeline))
evaluation = dict(interval=1, metric='bbox')

# optimizer
optimizer_config = dict(grad_clip=None)
optimizer = dict(
type='AdamW',
lr=0.00005,
betas=(0.9, 0.999),
weight_decay=0.05,
paramwise_cfg=dict(
custom_keys={
'absolute_pos_embed': dict(decay_mult=0.),
'relative_position_bias_table': dict(decay_mult=0.),
'norm': dict(decay_mult=0.)
}))

# learning policy
lr_config = dict(
policy='step',
warmup='linear',
warmup_iters=500,
warmup_ratio=0.001,
step=[8, 11])
runner = dict(type='EpochBasedRunner', max_epochs=12)
13 changes: 13 additions & 0 deletions configs/dyhead/metafile.yml
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,16 @@ Models:
Metrics:
box AP: 43.3
Weights: https://download.openmmlab.com/mmdetection/v2.0/dyhead/atss_r50_fpn_dyhead_4x4_1x_coco/atss_r50_fpn_dyhead_4x4_1x_coco_20211219_023314-eaa620c6.pth

- Name: atss_swin-l-p4-w12_fpn_dyhead_mstrain_2x_coco
In Collection: DyHead
Config: configs/dyhead/atss_swin-l-p4-w12_fpn_dyhead_mstrain_2x_coco.py
Metadata:
Training Memory (GB): 58.4
Epochs: 24
Results:
- Task: Object Detection
Dataset: COCO
Metrics:
box AP: 56.2
Weights: https://download.openmmlab.com/mmdetection/v2.0/dyhead/atss_swin-l-p4-w12_fpn_dyhead_mstrain_2x_coco/atss_swin-l-p4-w12_fpn_dyhead_mstrain_2x_coco_20220509_100315-bc5b6516.pth

0 comments on commit 96be838

Please sign in to comment.