Skip to content

Commit

Permalink
[Feature] Support Concat dataset (open-mmlab#1139)
Browse files Browse the repository at this point in the history
* concat dataset

* Add unit tests for ConcatDataset

Co-authored-by: canwang <wangcan@sensentime.com>
Co-authored-by: ly015 <liyining0712@gmail.com>
  • Loading branch information
3 people committed Jan 26, 2022
1 parent a121d4d commit 399ebda
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 2 deletions.
46 changes: 44 additions & 2 deletions mmpose/datasets/builder.py
@@ -1,13 +1,15 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import platform
import random
from functools import partial

import numpy as np
from mmcv.parallel import collate
from mmcv.runner import get_dist_info
from mmcv.utils import Registry, build_from_cfg
from mmcv.utils import Registry, build_from_cfg, is_seq_of
from mmcv.utils.parrots_wrapper import _get_dataloader
from torch.utils.data.dataset import ConcatDataset

from .samplers import DistributedSampler

Expand All @@ -24,6 +26,39 @@
PIPELINES = Registry('pipeline')


def _concat_dataset(cfg, default_args=None):
    """Build a ``ConcatDataset`` from a config whose ``ann_file`` is a
    sequence.

    Each entry of ``cfg['ann_file']`` produces one sub-dataset. Fields that
    are themselves sequences (``type``, ``img_prefix``, ``dataset_info``,
    ``data_cfg.num_joints``, ``data_cfg.dataset_channel``) are distributed
    element-wise to the sub-datasets; scalar fields are shared by all of
    them.

    Args:
        cfg (dict): Dataset config with a sequence-valued ``ann_file``.
        default_args (dict | None): Default args forwarded to each
            sub-dataset's constructor.

    Returns:
        ConcatDataset: Concatenation of all built sub-datasets.
    """
    types = cfg['type']
    img_prefixes = cfg.get('img_prefix', None)
    dataset_infos = cfg.get('dataset_info', None)

    num_joints = cfg['data_cfg'].get('num_joints', None)
    dataset_channel = cfg['data_cfg'].get('dataset_channel', None)

    datasets = []
    for idx, ann_file in enumerate(cfg['ann_file']):
        # Deep-copy so per-dataset overrides never leak into the shared cfg.
        sub_cfg = copy.deepcopy(cfg)
        sub_cfg['ann_file'] = ann_file

        if isinstance(types, (list, tuple)):
            sub_cfg['type'] = types[idx]
        if isinstance(img_prefixes, (list, tuple)):
            sub_cfg['img_prefix'] = img_prefixes[idx]
        if isinstance(dataset_infos, (list, tuple)):
            sub_cfg['dataset_info'] = dataset_infos[idx]

        if isinstance(num_joints, (list, tuple)):
            sub_cfg['data_cfg']['num_joints'] = num_joints[idx]

        # dataset_channel is itself a list of lists, so a per-dataset value
        # is detected as a sequence *of lists* rather than via isinstance.
        if is_seq_of(dataset_channel, list):
            sub_cfg['data_cfg']['dataset_channel'] = dataset_channel[idx]

        datasets.append(build_dataset(sub_cfg, default_args))

    return ConcatDataset(datasets)


def build_dataset(cfg, default_args=None):
    """Build a dataset from a config dict.

    Args:
        cfg (dict | list | tuple): A dataset config, or a sequence of
            configs whose resulting datasets are concatenated.
        default_args (dict | None): Default initialization arguments
            forwarded to the dataset constructor.

    Returns:
        Dataset: The constructed dataset.
    """
    from .dataset_wrappers import RepeatDataset

    # A bare sequence of configs -> concatenate the built datasets.
    if isinstance(cfg, (list, tuple)):
        return ConcatDataset([build_dataset(c, default_args) for c in cfg])
    # Explicit ConcatDataset wrapper config.
    if cfg['type'] == 'ConcatDataset':
        return ConcatDataset(
            [build_dataset(c, default_args) for c in cfg['datasets']])
    # RepeatDataset wrapper config.
    if cfg['type'] == 'RepeatDataset':
        return RepeatDataset(
            build_dataset(cfg['dataset'], default_args), cfg['times'])
    # A single config whose ann_file is a sequence -> one dataset per file.
    if isinstance(cfg.get('ann_file'), (list, tuple)):
        return _concat_dataset(cfg, default_args)
    return build_from_cfg(cfg, DATASETS, default_args)
Expand Down
67 changes: 67 additions & 0 deletions tests/test_datasets/test_dataset_wrapper.py
@@ -0,0 +1,67 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy

from mmcv import Config

from mmpose.datasets.builder import build_dataset


def test_concat_dataset():
    """Check the three ways ``build_dataset`` produces a ConcatDataset."""
    # build COCO-like dataset config
    dataset_info = Config.fromfile(
        'configs/_base_/datasets/coco.py').dataset_info

    channel_cfg = dict(
        num_output_channels=17,
        dataset_joints=17,
        dataset_channel=[
            [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
        ],
        inference_channel=[
            0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16
        ])

    data_cfg = dict(
        image_size=[192, 256],
        heatmap_size=[48, 64],
        num_output_channels=channel_cfg['num_output_channels'],
        num_joints=channel_cfg['dataset_joints'],
        dataset_channel=channel_cfg['dataset_channel'],
        inference_channel=channel_cfg['inference_channel'],
        soft_nms=False,
        nms_thr=1.0,
        oks_thr=0.9,
        vis_thr=0.2,
        use_gt_bbox=True,
        det_bbox_thr=0.0,
        bbox_file='tests/data/coco/test_coco_det_AP_H_56.json',
    )

    dataset_cfg = dict(
        type='TopDownCocoDataset',
        ann_file='tests/data/coco/test_coco.json',
        img_prefix='tests/data/coco/',
        data_cfg=data_cfg,
        pipeline=[],
        dataset_info=dataset_info)

    dataset = build_dataset(dataset_cfg)

    # Case 1: build ConcatDataset explicitly
    concat_dataset_cfg = dict(
        type='ConcatDataset', datasets=[dataset_cfg, dataset_cfg])
    concat_dataset = build_dataset(concat_dataset_cfg)
    assert len(concat_dataset) == 2 * len(dataset)

    # Case 2: build ConcatDataset from cfg sequence
    concat_dataset = build_dataset([dataset_cfg, dataset_cfg])
    assert len(concat_dataset) == 2 * len(dataset)

    # Case 3: build ConcatDataset from ann_file sequence.
    # Use deepcopy (not dict.copy) so turning the nested data_cfg values
    # into lists does not mutate the shared ``data_cfg`` dict referenced by
    # ``dataset_cfg`` and the datasets built in Cases 1-2.
    concat_dataset_cfg = copy.deepcopy(dataset_cfg)
    for key in ['ann_file', 'type', 'img_prefix', 'dataset_info']:
        val = concat_dataset_cfg[key]
        concat_dataset_cfg[key] = [val] * 2
    for key in ['num_joints', 'dataset_channel']:
        val = concat_dataset_cfg['data_cfg'][key]
        concat_dataset_cfg['data_cfg'][key] = [val] * 2
    concat_dataset = build_dataset(concat_dataset_cfg)
    assert len(concat_dataset) == 2 * len(dataset)

0 comments on commit 399ebda

Please sign in to comment.