From 399ebda3713112ad72734da701e30b2a608c61d2 Mon Sep 17 00:00:00 2001
From: CanWang
Date: Wed, 26 Jan 2022 11:58:36 +0800
Subject: [PATCH] [Feature] Support Concat dataset (#1139)

* concat dataset

* Add unit tests for ConcatDataset

Co-authored-by: canwang
Co-authored-by: ly015
---
 mmpose/datasets/builder.py                  | 46 +++++++++++++-
 tests/test_datasets/test_dataset_wrapper.py | 67 +++++++++++++++++++++
 2 files changed, 111 insertions(+), 2 deletions(-)
 create mode 100644 tests/test_datasets/test_dataset_wrapper.py

diff --git a/mmpose/datasets/builder.py b/mmpose/datasets/builder.py
index cdee6d1ebf..990ba859e0 100644
--- a/mmpose/datasets/builder.py
+++ b/mmpose/datasets/builder.py
@@ -1,4 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import copy
 import platform
 import random
 from functools import partial
@@ -6,8 +7,9 @@
 import numpy as np
 from mmcv.parallel import collate
 from mmcv.runner import get_dist_info
-from mmcv.utils import Registry, build_from_cfg
+from mmcv.utils import Registry, build_from_cfg, is_seq_of
 from mmcv.utils.parrots_wrapper import _get_dataloader
+from torch.utils.data.dataset import ConcatDataset
 
 from .samplers import DistributedSampler
 
@@ -24,6 +26,39 @@
 PIPELINES = Registry('pipeline')
 
 
+def _concat_dataset(cfg, default_args=None):
+    types = cfg['type']
+    ann_files = cfg['ann_file']
+    img_prefixes = cfg.get('img_prefix', None)
+    dataset_infos = cfg.get('dataset_info', None)
+
+    num_joints = cfg['data_cfg'].get('num_joints', None)
+    dataset_channel = cfg['data_cfg'].get('dataset_channel', None)
+
+    datasets = []
+    num_dset = len(ann_files)
+    for i in range(num_dset):
+        cfg_copy = copy.deepcopy(cfg)
+        cfg_copy['ann_file'] = ann_files[i]
+
+        if isinstance(types, (list, tuple)):
+            cfg_copy['type'] = types[i]
+        if isinstance(img_prefixes, (list, tuple)):
+            cfg_copy['img_prefix'] = img_prefixes[i]
+        if isinstance(dataset_infos, (list, tuple)):
+            cfg_copy['dataset_info'] = dataset_infos[i]
+
+        if isinstance(num_joints, (list, tuple)):
+            cfg_copy['data_cfg']['num_joints'] = num_joints[i]
+
+        if is_seq_of(dataset_channel, list):
+            cfg_copy['data_cfg']['dataset_channel'] = dataset_channel[i]
+
+        datasets.append(build_dataset(cfg_copy, default_args))
+
+    return ConcatDataset(datasets)
+
+
 def build_dataset(cfg, default_args=None):
     """Build a dataset from config dict.
 
@@ -37,9 +72,16 @@ def build_dataset(cfg, default_args=None):
     """
     from .dataset_wrappers import RepeatDataset
 
-    if cfg['type'] == 'RepeatDataset':
+    if isinstance(cfg, (list, tuple)):
+        dataset = ConcatDataset([build_dataset(c, default_args) for c in cfg])
+    elif cfg['type'] == 'ConcatDataset':
+        dataset = ConcatDataset(
+            [build_dataset(c, default_args) for c in cfg['datasets']])
+    elif cfg['type'] == 'RepeatDataset':
         dataset = RepeatDataset(
             build_dataset(cfg['dataset'], default_args), cfg['times'])
+    elif isinstance(cfg.get('ann_file'), (list, tuple)):
+        dataset = _concat_dataset(cfg, default_args)
     else:
         dataset = build_from_cfg(cfg, DATASETS, default_args)
     return dataset
diff --git a/tests/test_datasets/test_dataset_wrapper.py b/tests/test_datasets/test_dataset_wrapper.py
new file mode 100644
index 0000000000..f724d251d6
--- /dev/null
+++ b/tests/test_datasets/test_dataset_wrapper.py
@@ -0,0 +1,67 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from mmcv import Config
+
+from mmpose.datasets.builder import build_dataset
+
+
+def test_concat_dataset():
+    # build COCO-like dataset config
+    dataset_info = Config.fromfile(
+        'configs/_base_/datasets/coco.py').dataset_info
+
+    channel_cfg = dict(
+        num_output_channels=17,
+        dataset_joints=17,
+        dataset_channel=[
+            [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16],
+        ],
+        inference_channel=[
+            0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16
+        ])
+
+    data_cfg = dict(
+        image_size=[192, 256],
+        heatmap_size=[48, 64],
+        num_output_channels=channel_cfg['num_output_channels'],
+        num_joints=channel_cfg['dataset_joints'],
+        dataset_channel=channel_cfg['dataset_channel'],
+        inference_channel=channel_cfg['inference_channel'],
+        soft_nms=False,
+        nms_thr=1.0,
+        oks_thr=0.9,
+        vis_thr=0.2,
+        use_gt_bbox=True,
+        det_bbox_thr=0.0,
+        bbox_file='tests/data/coco/test_coco_det_AP_H_56.json',
+    )
+
+    dataset_cfg = dict(
+        type='TopDownCocoDataset',
+        ann_file='tests/data/coco/test_coco.json',
+        img_prefix='tests/data/coco/',
+        data_cfg=data_cfg,
+        pipeline=[],
+        dataset_info=dataset_info)
+
+    dataset = build_dataset(dataset_cfg)
+
+    # Case 1: build ConcatDataset explicitly
+    concat_dataset_cfg = dict(
+        type='ConcatDataset', datasets=[dataset_cfg, dataset_cfg])
+    concat_dataset = build_dataset(concat_dataset_cfg)
+    assert len(concat_dataset) == 2 * len(dataset)
+
+    # Case 2: build ConcatDataset from cfg sequence
+    concat_dataset = build_dataset([dataset_cfg, dataset_cfg])
+    assert len(concat_dataset) == 2 * len(dataset)
+
+    # Case 3: build ConcatDataset from ann_file sequence
+    concat_dataset_cfg = dataset_cfg.copy()
+    for key in ['ann_file', 'type', 'img_prefix', 'dataset_info']:
+        val = concat_dataset_cfg[key]
+        concat_dataset_cfg[key] = [val] * 2
+    for key in ['num_joints', 'dataset_channel']:
+        val = concat_dataset_cfg['data_cfg'][key]
+        concat_dataset_cfg['data_cfg'][key] = [val] * 2
+    concat_dataset = build_dataset(concat_dataset_cfg)
+    assert len(concat_dataset) == 2 * len(dataset)
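
Usage sketch (not part of the patch): the snippet below illustrates the three config styles that build_dataset accepts after this change. It reuses the COCO test data and dataset_info config from the unit test above, and assumes it is run from the mmpose repository root so those paths resolve.

# Minimal sketch of the three ways the updated build_dataset can return a
# torch ConcatDataset. Paths and the 'TopDownCocoDataset' type are taken from
# the unit test above; adapt them to real data for actual training.
import copy

from mmcv import Config

from mmpose.datasets.builder import build_dataset

dataset_info = Config.fromfile('configs/_base_/datasets/coco.py').dataset_info
channel = list(range(17))
data_cfg = dict(
    image_size=[192, 256],
    heatmap_size=[48, 64],
    num_output_channels=17,
    num_joints=17,
    dataset_channel=[channel],
    inference_channel=channel,
    soft_nms=False,
    nms_thr=1.0,
    oks_thr=0.9,
    vis_thr=0.2,
    use_gt_bbox=True,
    det_bbox_thr=0.0,
    bbox_file='tests/data/coco/test_coco_det_AP_H_56.json')
coco_cfg = dict(
    type='TopDownCocoDataset',
    ann_file='tests/data/coco/test_coco.json',
    img_prefix='tests/data/coco/',
    data_cfg=data_cfg,
    pipeline=[],
    dataset_info=dataset_info)

# 1) Explicit 'ConcatDataset' wrapper config.
concat_1 = build_dataset(
    dict(type='ConcatDataset', datasets=[coco_cfg, coco_cfg]))

# 2) A plain sequence of dataset configs.
concat_2 = build_dataset([coco_cfg, coco_cfg])

# 3) A single config whose per-dataset keys are sequences; build_dataset
#    dispatches to _concat_dataset because 'ann_file' is a list. Note that
#    data_cfg['num_joints'] and data_cfg['dataset_channel'] must then also be
#    given per dataset.
multi_cfg = copy.deepcopy(coco_cfg)
for key in ('type', 'ann_file', 'img_prefix', 'dataset_info'):
    multi_cfg[key] = [multi_cfg[key]] * 2
for key in ('num_joints', 'dataset_channel'):
    multi_cfg['data_cfg'][key] = [multi_cfg['data_cfg'][key]] * 2
concat_3 = build_dataset(multi_cfg)

# All three are ConcatDatasets of two copies of the same COCO test set.
assert len(concat_1) == len(concat_2) == len(concat_3)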