forked from open-mmlab/mmpose
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Feature] Support Concat dataset (open-mmlab#1139)
* concat dataset * Add unit tests for ConcatDataset Co-authored-by: canwang <wangcan@sensentime.com> Co-authored-by: ly015 <liyining0712@gmail.com>
- Loading branch information
1 parent
a121d4d
commit 399ebda
Showing
2 changed files
with
111 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from mmcv import Config | ||
|
||
from mmpose.datasets.builder import build_dataset | ||
|
||
|
||
def test_concat_dataset(): | ||
# build COCO-like dataset config | ||
dataset_info = Config.fromfile( | ||
'configs/_base_/datasets/coco.py').dataset_info | ||
|
||
channel_cfg = dict( | ||
num_output_channels=17, | ||
dataset_joints=17, | ||
dataset_channel=[ | ||
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16], | ||
], | ||
inference_channel=[ | ||
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16 | ||
]) | ||
|
||
data_cfg = dict( | ||
image_size=[192, 256], | ||
heatmap_size=[48, 64], | ||
num_output_channels=channel_cfg['num_output_channels'], | ||
num_joints=channel_cfg['dataset_joints'], | ||
dataset_channel=channel_cfg['dataset_channel'], | ||
inference_channel=channel_cfg['inference_channel'], | ||
soft_nms=False, | ||
nms_thr=1.0, | ||
oks_thr=0.9, | ||
vis_thr=0.2, | ||
use_gt_bbox=True, | ||
det_bbox_thr=0.0, | ||
bbox_file='tests/data/coco/test_coco_det_AP_H_56.json', | ||
) | ||
|
||
dataset_cfg = dict( | ||
type='TopDownCocoDataset', | ||
ann_file='tests/data/coco/test_coco.json', | ||
img_prefix='tests/data/coco/', | ||
data_cfg=data_cfg, | ||
pipeline=[], | ||
dataset_info=dataset_info) | ||
|
||
dataset = build_dataset(dataset_cfg) | ||
|
||
# Case 1: build ConcatDataset explicitly | ||
concat_dataset_cfg = dict( | ||
type='ConcatDataset', datasets=[dataset_cfg, dataset_cfg]) | ||
concat_dataset = build_dataset(concat_dataset_cfg) | ||
assert len(concat_dataset) == 2 * len(dataset) | ||
|
||
# Case 2: build ConcatDataset from cfg sequence | ||
concat_dataset = build_dataset([dataset_cfg, dataset_cfg]) | ||
assert len(concat_dataset) == 2 * len(dataset) | ||
|
||
# Case 3: build ConcatDataset from ann_file sequence | ||
concat_dataset_cfg = dataset_cfg.copy() | ||
for key in ['ann_file', 'type', 'img_prefix', 'dataset_info']: | ||
val = concat_dataset_cfg[key] | ||
concat_dataset_cfg[key] = [val] * 2 | ||
for key in ['num_joints', 'dataset_channel']: | ||
val = concat_dataset_cfg['data_cfg'][key] | ||
concat_dataset_cfg['data_cfg'][key] = [val] * 2 | ||
concat_dataset = build_dataset(concat_dataset_cfg) | ||
assert len(concat_dataset) == 2 * len(dataset) |