Commit e79d12b: Add CID to mmpose (open-mmlab#1604)

kennethwdk committed Oct 11, 2022
1 parent 4339895 commit e79d12b

Showing 15 changed files with 1,415 additions and 8 deletions.
41 changes: 41 additions & 0 deletions configs/body/2d_kpt_sview_rgb_img/cid/coco/cid_coco.md
@@ -0,0 +1,41 @@
<!-- [ALGORITHM] -->

<details>
<summary align="right"><a href="https://openaccess.thecvf.com/content/CVPR2022/html/Wang_Contextual_Instance_Decoupling_for_Robust_Multi-Person_Pose_Estimation_CVPR_2022_paper.html">CID (CVPR'2022)</a></summary>

```bibtex
@InProceedings{Wang_2022_CVPR,
    author    = {Wang, Dongkai and Zhang, Shiliang},
    title     = {Contextual Instance Decoupling for Robust Multi-Person Pose Estimation},
    booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
    month     = {June},
    year      = {2022},
    pages     = {11060-11068}
}
```

</details>

<!-- [DATASET] -->

<details>
<summary align="right"><a href="https://link.springer.com/chapter/10.1007/978-3-319-10602-1_48">COCO (ECCV'2014)</a></summary>

```bibtex
@inproceedings{lin2014microsoft,
  title={Microsoft coco: Common objects in context},
  author={Lin, Tsung-Yi and Maire, Michael and Belongie, Serge and Hays, James and Perona, Pietro and Ramanan, Deva and Doll{\'a}r, Piotr and Zitnick, C Lawrence},
  booktitle={European conference on computer vision},
  pages={740--755},
  year={2014},
  organization={Springer}
}
```

</details>

Results on COCO val2017 without multi-scale testing

| Arch | Input Size | AP | AP<sup>50</sup> | AP<sup>75</sup> | AR | AR<sup>50</sup> | ckpt | log |
| :-------------------------------------------- | :--------: | :---: | :-------------: | :-------------: | :---: | :-------------: | :-------------------------------------------: | :-------------------------------------------: |
| [CID](/configs/body/2d_kpt_sview_rgb_img/cid/coco/hrnet_w32_coco_512x512.py) | 512x512 | 0.702 | 0.887 | 0.768 | 0.755 | 0.926 | [ckpt](https://download.openmmlab.com/mmpose/bottom_up/cid/hrnet_w32_coco_512x512-867b9659_20220928.pth) | [log](https://download.openmmlab.com/mmpose/bottom_up/cid/hrnet_w32_coco_512x512-20220928.log.json) |
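
For a quick check of the released checkpoint, the model can be loaded through mmpose's bottom-up inference API. This is a minimal sketch, assuming the mmpose 0.x API that this commit targets; `demo.jpg` is a placeholder image path:

```python
from mmpose.apis import inference_bottom_up_pose_model, init_pose_model

config = 'configs/body/2d_kpt_sview_rgb_img/cid/coco/hrnet_w32_coco_512x512.py'
checkpoint = ('https://download.openmmlab.com/mmpose/bottom_up/cid/'
              'hrnet_w32_coco_512x512-867b9659_20220928.pth')

model = init_pose_model(config, checkpoint, device='cuda:0')
# each result dict carries a 'keypoints' array of shape (num_joints, 3),
# holding (x, y, score) per keypoint
pose_results, _ = inference_bottom_up_pose_model(model, 'demo.jpg')
```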
24 changes: 24 additions & 0 deletions configs/body/2d_kpt_sview_rgb_img/cid/coco/cid_coco.yml
@@ -0,0 +1,24 @@
Collections:
- Name: CID
  Paper:
    Title: Contextual Instance Decoupling for Robust Multi-Person Pose Estimation
    URL: https://openaccess.thecvf.com/content/CVPR2022/html/Wang_Contextual_Instance_Decoupling_for_Robust_Multi-Person_Pose_Estimation_CVPR_2022_paper.html
  README: https://github.com/open-mmlab/mmpose/blob/master/docs/en/papers/algorithms/cid.md
Models:
- Config: configs/body/2d_kpt_sview_rgb_img/cid/coco/hrnet_w32_coco_512x512.py
  In Collection: CID
  Metadata:
    Architecture:
    - CID
    Training Data: COCO
  Name: cid_hrnet_w32_coco_512x512
  Results:
  - Dataset: COCO
    Metrics:
      AP: 0.702
      AP@0.5: 0.887
      AP@0.75: 0.768
      AR: 0.755
      AR@0.5: 0.926
    Task: Body 2D Keypoint
    Weights: https://download.openmmlab.com/mmpose/bottom_up/cid/hrnet_w32_coco_512x512-867b9659_20220928.pth
171 changes: 171 additions & 0 deletions configs/body/2d_kpt_sview_rgb_img/cid/coco/hrnet_w32_coco_512x512.py
@@ -0,0 +1,171 @@
_base_ = [
    '../../../../_base_/default_runtime.py',
    '../../../../_base_/datasets/coco.py'
]
checkpoint_config = dict(interval=20)
evaluation = dict(interval=20, metric='mAP', save_best='AP')

optimizer = dict(
    type='Adam',
    lr=0.001,
)
optimizer_config = dict(grad_clip=None)
# learning policy
lr_config = dict(policy='step', step=[90, 120])
total_epochs = 140
channel_cfg = dict(
    dataset_joints=17,
    dataset_channel=[
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
    ],
    inference_channel=[
        0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16
    ])

data_cfg = dict(
    image_size=512,
    base_size=256,
    base_sigma=2,
    heatmap_size=[128],
    num_joints=channel_cfg['dataset_joints'],
    dataset_channel=channel_cfg['dataset_channel'],
    inference_channel=channel_cfg['inference_channel'],
    num_scales=1,
    scale_aware_sigma=False,
    with_bbox=True,
    use_nms=True,
    soft_nms=False,
    oks_thr=0.8,
)

# model settings
model = dict(
    type='CID',
    pretrained='https://download.openmmlab.com/mmpose/'
    'pretrain_models/hrnet_w32-36af842e.pth',
    backbone=dict(
        type='HRNet',
        in_channels=3,
        extra=dict(
            stage1=dict(
                num_modules=1,
                num_branches=1,
                block='BOTTLENECK',
                num_blocks=(4, ),
                num_channels=(64, )),
            stage2=dict(
                num_modules=1,
                num_branches=2,
                block='BASIC',
                num_blocks=(4, 4),
                num_channels=(32, 64)),
            stage3=dict(
                num_modules=4,
                num_branches=3,
                block='BASIC',
                num_blocks=(4, 4, 4),
                num_channels=(32, 64, 128)),
            stage4=dict(
                num_modules=3,
                num_branches=4,
                block='BASIC',
                num_blocks=(4, 4, 4, 4),
                num_channels=(32, 64, 128, 256),
                multiscale_output=True)),
    ),
    keypoint_head=dict(
        type='CIDHead',
        in_channels=480,
        gfd_channels=32,
        num_joints=17,
        multi_hm_loss_factor=1.0,
        single_hm_loss_factor=4.0,
        contrastive_loss_factor=1.0,
        max_train_instances=200,
        prior_prob=0.01),
    train_cfg=dict(),
    test_cfg=dict(
        num_joints=channel_cfg['dataset_joints'],
        flip_test=True,
        max_num_people=30,
        detection_threshold=0.01,
        center_pool_kernel=3))

train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='BottomUpRandomAffine',
        rot_factor=30,
        scale_factor=[0.75, 1.5],
        scale_type='short',
        trans_factor=40),
    dict(type='BottomUpRandomFlip', flip_prob=0.5),
    dict(type='ToTensor'),
    dict(
        type='NormalizeTensor',
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]),
    dict(
        type='CIDGenerateTarget',
        max_num_people=30,
    ),
    dict(
        type='Collect',
        keys=[
            'img', 'multi_heatmap', 'multi_mask', 'instance_coord',
            'instance_heatmap', 'instance_mask', 'instance_valid'
        ],
        meta_keys=[]),
]

val_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='BottomUpGetImgSize', test_scale_factor=[1]),
    dict(
        type='BottomUpResizeAlign',
        transforms=[
            dict(type='ToTensor'),
            dict(
                type='NormalizeTensor',
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]),
        ]),
    dict(
        type='Collect',
        keys=['img'],
        meta_keys=[
            'image_file', 'aug_data', 'test_scale_factor', 'base_size',
            'center', 'scale', 'flip_index'
        ]),
]

test_pipeline = val_pipeline

data_root = 'data/coco'
data = dict(
    workers_per_gpu=2,
    train_dataloader=dict(samples_per_gpu=20),
    val_dataloader=dict(samples_per_gpu=1),
    test_dataloader=dict(samples_per_gpu=1),
    train=dict(
        type='BottomUpCocoDataset',
        ann_file=f'{data_root}/annotations/person_keypoints_train2017.json',
        img_prefix=f'{data_root}/train2017/',
        data_cfg=data_cfg,
        pipeline=train_pipeline,
        dataset_info={{_base_.dataset_info}}),
    val=dict(
        type='BottomUpCocoDataset',
        ann_file=f'{data_root}/annotations/person_keypoints_val2017.json',
        img_prefix=f'{data_root}/val2017/',
        data_cfg=data_cfg,
        pipeline=val_pipeline,
        dataset_info={{_base_.dataset_info}}),
    test=dict(
        type='BottomUpCocoDataset',
        ann_file=f'{data_root}/annotations/person_keypoints_val2017.json',
        img_prefix=f'{data_root}/val2017/',
        data_cfg=data_cfg,
        pipeline=test_pipeline,
        dataset_info={{_base_.dataset_info}}),
)
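
One non-obvious value in this config: `in_channels=480` for `CIDHead` equals the sum of HRNet-W32's four branch widths, which suggests the head consumes the upsampled and concatenated multi-scale backbone outputs. This is an inference from the config values, not a statement about the head's exact internals:

```python
# Branch widths taken from stage4's num_channels in the config above;
# with multiscale_output=True all four resolutions are produced.
hrnet_w32_branches = (32, 64, 128, 256)
assert sum(hrnet_w32_branches) == 480  # matches CIDHead's in_channels
```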
31 changes: 31 additions & 0 deletions docs/en/papers/algorithms/cid.md
@@ -0,0 +1,31 @@
# Contextual Instance Decoupling for Robust Multi-Person Pose Estimation

<!-- [ALGORITHM] -->

<details>
<summary align="right"><a href="https://openaccess.thecvf.com/content/CVPR2022/html/Wang_Contextual_Instance_Decoupling_for_Robust_Multi-Person_Pose_Estimation_CVPR_2022_paper.html">CID (CVPR'2022)</a></summary>

```bibtex
@InProceedings{Wang_2022_CVPR,
    author    = {Wang, Dongkai and Zhang, Shiliang},
    title     = {Contextual Instance Decoupling for Robust Multi-Person Pose Estimation},
    booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
    month     = {June},
    year      = {2022},
    pages     = {11060-11068}
}
```

</details>

## Abstract

<!-- [ABSTRACT] -->

Crowded scenes make it challenging to differentiate persons and locate their pose keypoints. This paper proposes Contextual Instance Decoupling (CID), a new pipeline for multi-person pose estimation. Instead of relying on person bounding boxes to spatially differentiate persons, CID decouples the persons in an image into multiple instance-aware feature maps. Each of these feature maps is then used to infer the keypoints of a specific person. Compared with bounding box detection, CID is differentiable and robust to detection errors. Decoupling persons into separate feature maps makes it possible to isolate distractions from other persons and to exploit context cues at scales larger than the bounding box. Experiments show that CID outperforms previous multi-person pose estimation pipelines on crowded-scene benchmarks in both accuracy and efficiency. For instance, it achieves 71.3% AP on CrowdPose, outperforming the recent single-stage DEKR by 5.6%, the bottom-up CenterAttention by 3.7%, and the top-down JCSPPE by 5.3%. This advantage persists on the commonly used COCO benchmark.
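
To make the pipeline concrete, here is a deliberately simplified sketch of the decoupling idea in PyTorch. It is illustrative only, not the mmpose `CIDHead` (which additionally uses contrastive training and attention-based decoupling, among other components); all names below are hypothetical:

```python
import torch
import torch.nn as nn


class ToyCID(nn.Module):
    """Sketch: decouple a shared feature map into per-instance feature
    maps, then predict keypoint heatmaps separately for each instance."""

    def __init__(self, feat_channels=480, gfd_channels=32, num_joints=17):
        super().__init__()
        # coupled heatmap covering all person centers in the image
        self.center_head = nn.Conv2d(feat_channels, 1, 1)
        self.reduce = nn.Conv2d(feat_channels, gfd_channels, 1)
        # keypoint heatmaps predicted from a decoupled, instance-aware map
        self.kpt_head = nn.Conv2d(gfd_channels, num_joints, 1)

    def forward(self, feats, instance_coords):
        # feats: (B, C, H, W); instance_coords: (y, x) person-center locations
        center_hm = self.center_head(feats).sigmoid()
        global_feat = self.reduce(feats)
        instance_hms = []
        for y, x in instance_coords:
            # the feature vector at a person's center acts as an instance
            # embedding that modulates the shared feature map
            inst_vec = global_feat[:, :, y, x]  # (B, gfd_channels)
            inst_feat = global_feat * inst_vec[:, :, None, None].sigmoid()
            instance_hms.append(self.kpt_head(inst_feat).sigmoid())
        return center_hm, instance_hms


model = ToyCID()
feats = torch.randn(1, 480, 128, 128)  # stand-in for backbone features
center_hm, inst_hms = model(feats, [(64, 64), (30, 100)])
```

Note that no bounding boxes appear anywhere above: instances are separated purely in feature space, which is what makes the pipeline differentiable end to end.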

<!-- [IMAGE] -->

<div align=center>
<img src="https://github.com/kennethwdk/CID/raw/main/img/framework.png">
</div>
@@ -62,6 +62,7 @@ def __init__(self,
        self.ann_info['inference_channel'] = data_cfg['inference_channel']
        self.ann_info['dataset_channel'] = data_cfg['dataset_channel']

        self.with_bbox = data_cfg.get('with_bbox', False)
        self.use_nms = data_cfg.get('use_nms', False)
        self.soft_nms = data_cfg.get('soft_nms', True)
        self.oks_thr = data_cfg.get('oks_thr', 0.9)
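
Because the new flag is read with `data_cfg.get('with_bbox', False)`, configs that predate this commit keep their behavior unchanged; only configs that opt in, such as the CID config above (`with_bbox=True`), get bboxes and areas attached to each record. A minimal illustration of this opt-in pattern, with hypothetical configs:

```python
old_cfg = dict(use_nms=True, oks_thr=0.8)     # pre-existing config, no flag
cid_cfg = dict(with_bbox=True, use_nms=True)  # CID-style config opts in

assert old_cfg.get('with_bbox', False) is False  # old behavior preserved
assert cid_cfg.get('with_bbox', False) is True   # bboxes/areas will be added
```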
16 changes: 16 additions & 0 deletions mmpose/datasets/datasets/bottom_up/bottom_up_coco.py
@@ -111,6 +111,22 @@ def _get_single(self, idx):
        db_rec['mask'] = mask_list
        db_rec['joints'] = joints_list

        if self.with_bbox:
            # add bbox and area
            num_people = len(anno)
            areas = np.zeros((num_people, 1))
            bboxes = np.zeros((num_people, 4, 2))
            for i, obj in enumerate(anno):
                x, y, w, h = obj['bbox']  # COCO bbox format: (x, y, w, h)
                areas[i, 0] = w * h
                # four corners, ordered (tl, tr, bl, br), all seeded at (x, y)
                bboxes[i, :, 0] = x
                bboxes[i, :, 1] = y
                bboxes[i, 1, 0] += w  # top-right: shift x by width
                bboxes[i, 2, 1] += h  # bottom-left: shift y by height
                bboxes[i, 3, 0] += w  # bottom-right: shift both
                bboxes[i, 3, 1] += h
            db_rec['bboxes'] = bboxes
            db_rec['areas'] = areas

        return db_rec

    def _get_joints(self, anno):
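
For reference, tracing the corner construction in `_get_single` with a hypothetical COCO bbox shows the layout of each `db_rec['bboxes'][i]`:

```python
import numpy as np

x, y, w, h = 10., 20., 30., 40.  # hypothetical COCO bbox (x, y, w, h)
corners = np.array([
    [x, y],          # top-left
    [x + w, y],      # top-right
    [x, y + h],      # bottom-left
    [x + w, y + h],  # bottom-right
])
print(corners)  # [[10. 20.] [40. 20.] [10. 60.] [40. 60.]]
```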
