Skip to content

Commit

Permalink
[Feature] support DEKR (open-mmlab#1693)
Browse files Browse the repository at this point in the history
  • Loading branch information
JinluZhang1126 committed Oct 14, 2022
1 parent e79d12b commit a5ce55b
Show file tree
Hide file tree
Showing 36 changed files with 3,053 additions and 101 deletions.
22 changes: 22 additions & 0 deletions configs/body/2d_kpt_sview_rgb_img/dekr/README.md
@@ -0,0 +1,22 @@
# Bottom-up Human Pose Estimation via Disentangled Keypoint Regression (DEKR)

<!-- [ALGORITHM] -->

<details>
<summary align="right"><a href="https://arxiv.org/abs/2104.02300">DEKR (CVPR'2021)</a></summary>

```bibtex
@inproceedings{geng2021bottom,
title={Bottom-up human pose estimation via disentangled keypoint regression},
author={Geng, Zigang and Sun, Ke and Xiao, Bin and Zhang, Zhaoxiang and Wang, Jingdong},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
pages={14676--14686},
year={2021}
}
```

</details>

DEKR is a popular 2D bottom-up pose estimation approach that simultaneously detects all the instances and regresses the offsets from the instance centers to joints.

In order to predict the offsets more accurately, the offsets of different joints are regressed using separated branches with deformable convolutional layers. Thus convolution kernels with different shapes are adopted to extract features for the corresponding joint.
78 changes: 78 additions & 0 deletions configs/body/2d_kpt_sview_rgb_img/dekr/coco/hrnet_coco.md
@@ -0,0 +1,78 @@
<!-- [ALGORITHM] -->

<details>
<summary align="right"><a href="https://arxiv.org/abs/2104.02300">DEKR (CVPR'2021)</a></summary>

```bibtex
@inproceedings{geng2021bottom,
title={Bottom-up human pose estimation via disentangled keypoint regression},
author={Geng, Zigang and Sun, Ke and Xiao, Bin and Zhang, Zhaoxiang and Wang, Jingdong},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
pages={14676--14686},
year={2021}
}
```

</details>

<!-- [ALGORITHM] -->

<details>
<summary align="right"><a href="http://openaccess.thecvf.com/content_CVPR_2019/html/Sun_Deep_High-Resolution_Representation_Learning_for_Human_Pose_Estimation_CVPR_2019_paper.html">HRNet (CVPR'2019)</a></summary>

```bibtex
@inproceedings{sun2019deep,
title={Deep high-resolution representation learning for human pose estimation},
author={Sun, Ke and Xiao, Bin and Liu, Dong and Wang, Jingdong},
booktitle={Proceedings of the IEEE conference on computer vision and pattern recognition},
pages={5693--5703},
year={2019}
}
```

</details>

<!-- [DATASET] -->

<details>
<summary align="right"><a href="https://link.springer.com/chapter/10.1007/978-3-319-10602-1_48">COCO (ECCV'2014)</a></summary>

```bibtex
@inproceedings{lin2014microsoft,
title={Microsoft coco: Common objects in context},
author={Lin, Tsung-Yi and Maire, Michael and Belongie, Serge and Hays, James and Perona, Pietro and Ramanan, Deva and Doll{\'a}r, Piotr and Zitnick, C Lawrence},
booktitle={European conference on computer vision},
pages={740--755},
year={2014},
organization={Springer}
}
```

</details>

Results on COCO val2017 without multi-scale test

| Arch | Input Size | AP | AP<sup>50</sup> | AP<sup>75</sup> | AR | AR<sup>50</sup> | ckpt | log |
| :-------------------------------------------- | :--------: | :---: | :-------------: | :-------------: | :---: | :-------------: | :-------------------------------------------: | :-------------------------------------------: |
| [HRNet-w32](/configs/body/2d_kpt_sview_rgb_img/disentangled_keypoint_regression/coco/hrnet_w32_coco_512x512.py) | 512x512 | 0.680 | 0.868 | 0.745 | 0.728 | 0.897 | [ckpt](https://download.openmmlab.com/mmpose/bottom_up/dekr/hrnet_w32_coco_512x512-2a3056de_20220928.pth) | [log](https://download.openmmlab.com/mmpose/bottom_up/dekr/hrnet_w32_coco_512x512-20220928.log.json) |
| [HRNet-w48](/configs/body/2d_kpt_sview_rgb_img/disentangled_keypoint_regression/coco/hrnet_w48_coco_640x640.py) | 640x640 | 0.709 | 0.876 | 0.773 | 0.758 | 0.909 | [ckpt](https://download.openmmlab.com/mmpose/bottom_up/dekr/hrnet_w48_coco_640x640-8854b2f1_20220930.pth) | [log](https://download.openmmlab.com/mmpose/bottom_up/dekr/hrnet_w48_coco_640x640-20220930.log.json) |

Results on COCO val2017 with multi-scale test. 3 default scales (\[2, 1, 0.5\]) are used

| Arch | Input Size | AP | AP<sup>50</sup> | AP<sup>75</sup> | AR | AR<sup>50</sup> | ckpt |
| :------------------------------------------------------------------ | :--------: | :---: | :-------------: | :-------------: | :---: | :-------------: | :------------------------------------------------------------------: |
| [HRNet-w32](/configs/body/2d_kpt_sview_rgb_img/disentangled_keypoint_regression/coco/hrnet_w32_coco_512x512_multiscale.py)\* | 512x512 | 0.705 | 0.878 | 0.767 | 0.759 | 0.921 | [ckpt](https://download.openmmlab.com/mmpose/bottom_up/dekr/hrnet_w32_coco_512x512-2a3056de_20220928.pth) |
| [HRNet-w48](/configs/body/2d_kpt_sview_rgb_img/disentangled_keypoint_regression/coco/hrnet_w48_coco_640x640_multiscale.py)\* | 640x640 | 0.722 | 0.882 | 0.785 | 0.778 | 0.928 | [ckpt](https://download.openmmlab.com/mmpose/bottom_up/dekr/hrnet_w48_coco_640x640-8854b2f1_20220930.pth) |

\* these configs are generally used for evaluation. The training settings are identical to their single-scale counterparts.

The results of models provided by the authors on COCO val2017 using the same evaluation protocol

| Arch | Input Size | Setting | AP | AP<sup>50</sup> | AP<sup>75</sup> | AR | AR<sup>50</sup> | ckpt |
| :-------- | :--------: | :----------: | :---: | :-------------: | :-------------: | :---: | :-------------: | :----------------------------------------------------------: |
| HRNet-w32 | 512x512 | single-scale | 0.678 | 0.868 | 0.744 | 0.728 | 0.897 | see [official implementation](https://github.com/HRNet/DEKR) |
| HRNet-w48 | 640x640 | single-scale | 0.707 | 0.876 | 0.773 | 0.757 | 0.909 | see [official implementation](https://github.com/HRNet/DEKR) |
| HRNet-w32 | 512x512 | multi-scale | 0.708 | 0.880 | 0.773 | 0.763 | 0.921 | see [official implementation](https://github.com/HRNet/DEKR) |
| HRNet-w48 | 640x640 | multi-scale | 0.721 | 0.881 | 0.786 | 0.779 | 0.927 | see [official implementation](https://github.com/HRNet/DEKR) |

The discrepancy between these results and that shown in paper is attributed to the differences in implementation details in evaluation process.
73 changes: 73 additions & 0 deletions configs/body/2d_kpt_sview_rgb_img/dekr/coco/hrnet_coco.yml
@@ -0,0 +1,73 @@
Collections:
- Name: DEKR
Paper:
Title: Bottom-up human pose estimation via disentangled keypoint regression
URL: https://arxiv.org/abs/2104.02300
README: https://github.com/open-mmlab/mmpose/blob/master/docs/en/papers/algorithms/dekr.md
Models:
- Config: configs/body/2d_kpt_sview_rgb_img/disentangled_keypoint_regression/coco/hrnet_w32_coco_512x512.py
In Collection: DEKR
Metadata:
Architecture: &id001
- DEKR
- HRNet
Training Data: COCO
Name: disentangled_keypoint_regression_hrnet_w32_coco_512x512
Results:
- Dataset: COCO
Metrics:
AP: 0.68
AP@0.5: 0.868
AP@0.75: 0.745
AR: 0.728
AR@0.5: 0.897
Task: Body 2D Keypoint
Weights: https://download.openmmlab.com/mmpose/bottom_up/dekr/hrnet_w32_coco_512x512-2a3056de_20220928.pth
- Config: configs/body/2d_kpt_sview_rgb_img/disentangled_keypoint_regression/coco/hrnet_w48_coco_640x640.py
In Collection: DEKR
Metadata:
Architecture: *id001
Training Data: COCO
Name: disentangled_keypoint_regression_hrnet_w48_coco_640x640
Results:
- Dataset: COCO
Metrics:
AP: 0.709
AP@0.5: 0.876
AP@0.75: 0.773
AR: 0.758
AR@0.5: 0.909
Task: Body 2D Keypoint
Weights: https://download.openmmlab.com/mmpose/bottom_up/dekr/hrnet_w48_coco_640x640-8854b2f1_20220930.pth
- Config: configs/body/2d_kpt_sview_rgb_img/disentangled_keypoint_regression/coco/hrnet_w32_coco_512x512_multiscale.py
In Collection: DEKR
Metadata:
Architecture: *id001
Training Data: COCO
Name: disentangled_keypoint_regression_hrnet_w32_coco_512x512_multiscale
Results:
- Dataset: COCO
Metrics:
AP: 0.705
AP@0.5: 0.878
AP@0.75: 0.767
AR: 0.759
AR@0.5: 0.921
Task: Body 2D Keypoint
Weights: https://download.openmmlab.com/mmpose/bottom_up/dekr/hrnet_w32_coco_512x512-2a3056de_20220928.pth
- Config: configs/body/2d_kpt_sview_rgb_img/disentangled_keypoint_regression/coco/hrnet_w48_coco_640x640_multiscale.py
In Collection: DEKR
Metadata:
Architecture: *id001
Training Data: COCO
Name: disentangled_keypoint_regression_hrnet_w48_coco_640x640_multiscale
Results:
- Dataset: COCO
Metrics:
AP: 0.722
AP@0.5: 0.882
AP@0.75: 0.785
AR: 0.778
AR@0.5: 0.928
Task: Body 2D Keypoint
Weights: https://download.openmmlab.com/mmpose/bottom_up/dekr/hrnet_w48_coco_640x640-8854b2f1_20220930.pth
196 changes: 196 additions & 0 deletions configs/body/2d_kpt_sview_rgb_img/dekr/coco/hrnet_w32_coco_512x512.py
@@ -0,0 +1,196 @@
_base_ = [
'../../../../_base_/default_runtime.py',
'../../../../_base_/datasets/coco.py'
]
checkpoint_config = dict(interval=20)
evaluation = dict(interval=20, metric='mAP', save_best='AP')

optimizer = dict(
type='Adam',
lr=0.001,
)
optimizer_config = dict(grad_clip=None)
# learning policy
lr_config = dict(
policy='step',
warmup='linear',
warmup_iters=500,
warmup_ratio=0.001,
step=[90, 120])
total_epochs = 140
channel_cfg = dict(
dataset_joints=17,
dataset_channel=[
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
],
inference_channel=[
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16
])

data_cfg = dict(
image_size=512,
base_size=256,
base_sigma=2,
heatmap_size=[128],
num_joints=channel_cfg['dataset_joints'],
dataset_channel=channel_cfg['dataset_channel'],
inference_channel=channel_cfg['inference_channel'],
num_scales=1,
scale_aware_sigma=False,
)

# model settings
model = dict(
type='DisentangledKeypointRegressor',
pretrained='https://download.openmmlab.com/mmpose/'
'pretrain_models/hrnet_w32-36af842e.pth',
backbone=dict(
type='HRNet',
in_channels=3,
extra=dict(
stage1=dict(
num_modules=1,
num_branches=1,
block='BOTTLENECK',
num_blocks=(4, ),
num_channels=(64, )),
stage2=dict(
num_modules=1,
num_branches=2,
block='BASIC',
num_blocks=(4, 4),
num_channels=(32, 64)),
stage3=dict(
num_modules=4,
num_branches=3,
block='BASIC',
num_blocks=(4, 4, 4),
num_channels=(32, 64, 128)),
stage4=dict(
num_modules=3,
num_branches=4,
block='BASIC',
num_blocks=(4, 4, 4, 4),
num_channels=(32, 64, 128, 256),
multiscale_output=True)),
),
keypoint_head=dict(
type='DEKRHead',
in_channels=(32, 64, 128, 256),
in_index=(0, 1, 2, 3),
num_heatmap_filters=32,
num_joints=channel_cfg['dataset_joints'],
input_transform='resize_concat',
heatmap_loss=dict(
type='JointsMSELoss',
use_target_weight=True,
loss_weight=1.0,
),
offset_loss=dict(
type='SoftWeightSmoothL1Loss',
use_target_weight=True,
supervise_empty=False,
loss_weight=0.002,
beta=1 / 9.0,
)),
train_cfg=dict(),
test_cfg=dict(
num_joints=channel_cfg['dataset_joints'],
max_num_people=30,
project2image=False,
align_corners=False,
max_pool_kernel=5,
use_nms=True,
nms_dist_thr=0.05,
nms_joints_thr=8,
keypoint_threshold=0.01,
rescore_cfg=dict(
in_channels=74,
norm_indexes=(5, 6),
pretrained='https://download.openmmlab.com/mmpose/'
'pretrain_models/kpt_rescore_coco-33d58c5c.pth'),
flip_test=True))

train_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='BottomUpRandomAffine',
rot_factor=30,
scale_factor=[0.75, 1.5],
scale_type='short',
trans_factor=40),
dict(type='BottomUpRandomFlip', flip_prob=0.5),
dict(type='ToTensor'),
dict(
type='NormalizeTensor',
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]),
dict(type='GetKeypointCenterArea'),
dict(
type='BottomUpGenerateHeatmapTarget',
sigma=(2, 4),
gen_center_heatmap=True,
bg_weight=0.1,
),
dict(
type='BottomUpGenerateOffsetTarget',
radius=4,
),
dict(
type='Collect',
keys=['img', 'heatmaps', 'masks', 'offsets', 'offset_weights'],
meta_keys=[]),
]

val_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='BottomUpGetImgSize', test_scale_factor=[1]),
dict(
type='BottomUpResizeAlign',
transforms=[
dict(type='ToTensor'),
dict(
type='NormalizeTensor',
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]),
]),
dict(
type='Collect',
keys=['img'],
meta_keys=[
'image_file', 'aug_data', 'test_scale_factor', 'base_size',
'center', 'scale', 'flip_index', 'num_joints', 'skeleton',
'image_size', 'heatmap_size'
]),
]

test_pipeline = val_pipeline

data_root = 'data/coco'
data = dict(
workers_per_gpu=4,
train_dataloader=dict(samples_per_gpu=10),
val_dataloader=dict(samples_per_gpu=1),
test_dataloader=dict(samples_per_gpu=1),
train=dict(
type='BottomUpCocoDataset',
ann_file=f'{data_root}/annotations/person_keypoints_train2017.json',
img_prefix=f'{data_root}/train2017/',
data_cfg=data_cfg,
pipeline=train_pipeline,
dataset_info={{_base_.dataset_info}}),
val=dict(
type='BottomUpCocoDataset',
ann_file=f'{data_root}/annotations/person_keypoints_val2017.json',
img_prefix=f'{data_root}/val2017/',
data_cfg=data_cfg,
pipeline=val_pipeline,
dataset_info={{_base_.dataset_info}}),
test=dict(
type='BottomUpCocoDataset',
ann_file=f'{data_root}/annotations/person_keypoints_val2017.json',
img_prefix=f'{data_root}/val2017/',
data_cfg=data_cfg,
pipeline=test_pipeline,
dataset_info={{_base_.dataset_info}}),
)

0 comments on commit a5ce55b

Please sign in to comment.