[Feature] Support K-fold cross-validation (open-mmlab#563)
* Support to use `indices` to specify which samples to evaluate.

* Add KFoldDataset wrapper

* Rename 'K' to 'num_splits' according to sklearn

* Add `kfold-cross-valid.py`

* Add unit tests

* Add help doc and docstring
mzr1996 authored and Ezra-Yu committed Feb 14, 2022
1 parent d79c830 commit 5beec41
Showing 8 changed files with 641 additions and 5 deletions.
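Taken together, the changes below let a config wrap any existing dataset in a `KFoldDataset`. A minimal sketch of such a config, assuming the usual mmcls `data` layout (the ImageNet paths and the empty pipeline are placeholders):

data = dict(
    train=dict(
        type='KFoldDataset',
        fold=0,        # which fold is held out for validation
        num_splits=5,  # total number of folds
        seed=1,        # shuffle before splitting; omit to keep the original order
        dataset=dict(
            type='ImageNet',
            data_prefix='data/imagenet/train',
            pipeline=[]),
    ),
    val=dict(
        type='KFoldDataset',
        fold=0,
        num_splits=5,
        seed=1,
        test_mode=True,  # use the held-out fold instead of the training folds
        dataset=dict(
            type='ImageNet',
            data_prefix='data/imagenet/train',
            pipeline=[]),
    ),
)

The train and val entries must share the same `seed` and `num_splits` so the held-out fold is exactly the complement of the training folds; running all folds then amounts to repeating this with `fold` from 0 to `num_splits - 1`, which is presumably what the new `kfold-cross-valid.py` tool automates.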
4 changes: 2 additions & 2 deletions mmcls/datasets/__init__.py
@@ -4,7 +4,7 @@
build_dataset, build_sampler)
from .cifar import CIFAR10, CIFAR100
from .dataset_wrappers import (ClassBalancedDataset, ConcatDataset,
RepeatDataset)
KFoldDataset, RepeatDataset)
from .imagenet import ImageNet
from .imagenet21k import ImageNet21k
from .mnist import MNIST, FashionMNIST
@@ -17,5 +17,5 @@
'VOC', 'MultiLabelDataset', 'build_dataloader', 'build_dataset',
'DistributedSampler', 'ConcatDataset', 'RepeatDataset',
'ClassBalancedDataset', 'DATASETS', 'PIPELINES', 'ImageNet21k', 'SAMPLERS',
'build_sampler', 'RepeatAugSampler'
'build_sampler', 'RepeatAugSampler', 'KFoldDataset'
]
5 changes: 5 additions & 0 deletions mmcls/datasets/base_dataset.py
@@ -118,6 +118,7 @@ def evaluate(self,
results,
metric='accuracy',
metric_options=None,
indices=None,
logger=None):
"""Evaluate the dataset.
@@ -128,6 +129,8 @@
metric_options (dict, optional): Options for calculating metrics.
Allowed keys are 'topk', 'thrs' and 'average_mode'.
Defaults to None.
indices (list, optional): The indices of samples corresponding to
the results. Defaults to None.
logger (logging.Logger | str, optional): Logger used for printing
related information during evaluation. Defaults to None.
Returns:
@@ -145,6 +148,8 @@
eval_results = {}
results = np.vstack(results)
gt_labels = self.get_gt_labels()
if indices is not None:
gt_labels = gt_labels[indices]
num_imgs = len(results)
assert len(gt_labels) == num_imgs, 'dataset testing results should '\
'be of the same length as gt_labels.'
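The new `indices` argument lets a wrapper score predictions for only part of the dataset against the matching slice of ground-truth labels. A rough sketch of the intended call, assuming `dataset` is an already-built mmcls classification dataset and the random scores are placeholders:

import numpy as np

# fake class scores for samples 3 and 7 of a 10-class dataset
fold_results = [np.random.rand(10), np.random.rand(10)]
metrics = dataset.evaluate(fold_results, metric='accuracy', indices=[3, 7])
# inside evaluate(), gt_labels is sliced to gt_labels[[3, 7]], so the
# length check against the two result rows still holds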
10 changes: 9 additions & 1 deletion mmcls/datasets/builder.py
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import platform
import random
from functools import partial
@@ -25,7 +26,7 @@

def build_dataset(cfg, default_args=None):
from .dataset_wrappers import (ConcatDataset, RepeatDataset,
ClassBalancedDataset)
ClassBalancedDataset, KFoldDataset)
if isinstance(cfg, (list, tuple)):
dataset = ConcatDataset([build_dataset(c, default_args) for c in cfg])
elif cfg['type'] == 'RepeatDataset':
@@ -34,6 +35,13 @@ def build_dataset(cfg, default_args=None):
elif cfg['type'] == 'ClassBalancedDataset':
dataset = ClassBalancedDataset(
build_dataset(cfg['dataset'], default_args), cfg['oversample_thr'])
elif cfg['type'] == 'KFoldDataset':
cp_cfg = copy.deepcopy(cfg)
if cp_cfg.get('test_mode', None) is None:
cp_cfg['test_mode'] = (default_args or {}).pop('test_mode', False)
cp_cfg['dataset'] = build_dataset(cp_cfg['dataset'], default_args)
cp_cfg.pop('type')
dataset = KFoldDataset(**cp_cfg)
else:
dataset = build_from_cfg(cfg, DATASETS, default_args)

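A consequence of the `KFoldDataset` branch above is that `test_mode` passed through `default_args` is consumed by the wrapper rather than forwarded to the wrapped dataset, which the new builder test exercises. A small sketch, with placeholder ImageNet fields:

cfg = dict(
    type='KFoldDataset',
    fold=0,
    num_splits=5,
    dataset=dict(type='ImageNet', data_prefix='data/imagenet/train', pipeline=[]),
)
ds = build_dataset(cfg, default_args={'test_mode': True})
assert ds.test_mode              # popped from default_args by the wrapper
assert not ds.dataset.test_mode  # the wrapped ImageNet keeps its own default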
53 changes: 53 additions & 0 deletions mmcls/datasets/dataset_wrappers.py
@@ -170,3 +170,56 @@ def __getitem__(self, idx):

def __len__(self):
return len(self.repeat_indices)


@DATASETS.register_module()
class KFoldDataset:
"""A wrapper of dataset for K-Fold cross-validation.
K-Fold cross-validation divides all the samples into groups of almost equal
size, called folds. k-1 of the folds are used for training and the remaining
fold is used for validation.
Args:
dataset (:obj:`CustomDataset`): The dataset to be divided.
fold (int): The fold used to do validation. Defaults to 0.
num_splits (int): The number of all folds. Defaults to 5.
test_mode (bool): Whether to return the validation fold (True) or the
training folds (False). Defaults to False.
seed (int, optional): The seed used to shuffle the dataset before splitting.
If None, the dataset is not shuffled. Defaults to None.
"""

def __init__(self,
dataset,
fold=0,
num_splits=5,
test_mode=False,
seed=None):
self.dataset = dataset
self.CLASSES = dataset.CLASSES
self.test_mode = test_mode
self.num_splits = num_splits

length = len(dataset)
indices = list(range(length))
if isinstance(seed, int):
rng = np.random.default_rng(seed)
rng.shuffle(indices)

test_start = length * fold // num_splits
test_end = length * (fold + 1) // num_splits
if test_mode:
self.indices = indices[test_start:test_end]
else:
self.indices = indices[:test_start] + indices[test_end:]

def __getitem__(self, idx):
return self.dataset[self.indices[idx]]

def __len__(self):
return len(self.indices)

def evaluate(self, *args, **kwargs):
kwargs['indices'] = self.indices
return self.dataset.evaluate(*args, **kwargs)
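For a concrete feel of the split arithmetic, the held-out indices of each fold are the contiguous slice `[length * fold // num_splits, length * (fold + 1) // num_splits)`, so with 10 samples and 3 splits:

length, num_splits = 10, 3
for fold in range(num_splits):
    start = length * fold // num_splits
    end = length * (fold + 1) // num_splits
    print(fold, list(range(start, end)))
# prints: 0 [0, 1, 2]
#         1 [3, 4, 5]
#         2 [6, 7, 8, 9]

Every sample lands in exactly one validation fold, and the last fold absorbs the remainder when `length` is not divisible by `num_splits`.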
3 changes: 3 additions & 0 deletions mmcls/datasets/multi_label.py
@@ -28,6 +28,7 @@ def evaluate(self,
results,
metric='mAP',
metric_options=None,
indices=None,
logger=None,
**deprecated_kwargs):
"""Evaluate the dataset.
@@ -62,6 +63,8 @@
eval_results = {}
results = np.vstack(results)
gt_labels = self.get_gt_labels()
if indices is not None:
gt_labels = gt_labels[indices]
num_imgs = len(results)
assert len(gt_labels) == num_imgs, 'dataset testing results should '\
'be of the same length as gt_labels.'
152 changes: 151 additions & 1 deletion tests/test_data/test_builder.py
@@ -1,9 +1,14 @@
import os.path as osp
from copy import deepcopy
from unittest.mock import patch

import torch
from mmcv.utils import digit_version

from mmcls.datasets import build_dataloader
from mmcls.datasets import ImageNet, build_dataloader, build_dataset
from mmcls.datasets.dataset_wrappers import (ClassBalancedDataset,
ConcatDataset, KFoldDataset,
RepeatDataset)


class TestDataloaderBuilder():
@@ -119,3 +124,148 @@ def test_distributed(self, _):
expect = torch.tensor(
[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6][1::2])
assert all(torch.cat(list(iter(dataloader))) == expect)


class TestDatasetBuilder():

@classmethod
def setup_class(cls):
data_prefix = osp.join(osp.dirname(__file__), '../data/dataset')
cls.dataset_cfg = dict(
type='ImageNet',
data_prefix=data_prefix,
ann_file=osp.join(data_prefix, 'ann.txt'),
pipeline=[],
test_mode=False,
)

def test_normal_dataset(self):
# Test build
dataset = build_dataset(self.dataset_cfg)
assert isinstance(dataset, ImageNet)
assert dataset.test_mode == self.dataset_cfg['test_mode']

# Test default_args
dataset = build_dataset(self.dataset_cfg, {'test_mode': True})
assert dataset.test_mode == self.dataset_cfg['test_mode']

cp_cfg = deepcopy(self.dataset_cfg)
cp_cfg.pop('test_mode')
dataset = build_dataset(cp_cfg, {'test_mode': True})
assert dataset.test_mode

def test_concat_dataset(self):
# Test build
dataset = build_dataset([self.dataset_cfg, self.dataset_cfg])
assert isinstance(dataset, ConcatDataset)
assert dataset.datasets[0].test_mode == self.dataset_cfg['test_mode']

# Test default_args
dataset = build_dataset([self.dataset_cfg, self.dataset_cfg],
{'test_mode': True})
assert dataset.datasets[0].test_mode == self.dataset_cfg['test_mode']

cp_cfg = deepcopy(self.dataset_cfg)
cp_cfg.pop('test_mode')
dataset = build_dataset([cp_cfg, cp_cfg], {'test_mode': True})
assert dataset.datasets[0].test_mode

def test_repeat_dataset(self):
# Test build
dataset = build_dataset(
dict(type='RepeatDataset', dataset=self.dataset_cfg, times=3))
assert isinstance(dataset, RepeatDataset)
assert dataset.dataset.test_mode == self.dataset_cfg['test_mode']

# Test default_args
dataset = build_dataset(
dict(type='RepeatDataset', dataset=self.dataset_cfg, times=3),
{'test_mode': True})
assert dataset.dataset.test_mode == self.dataset_cfg['test_mode']

cp_cfg = deepcopy(self.dataset_cfg)
cp_cfg.pop('test_mode')
dataset = build_dataset(
dict(type='RepeatDataset', dataset=cp_cfg, times=3),
{'test_mode': True})
assert dataset.dataset.test_mode

def test_class_balance_dataset(self):
# Test build
dataset = build_dataset(
dict(
type='ClassBalancedDataset',
dataset=self.dataset_cfg,
oversample_thr=1.,
))
assert isinstance(dataset, ClassBalancedDataset)
assert dataset.dataset.test_mode == self.dataset_cfg['test_mode']

# Test default_args
dataset = build_dataset(
dict(
type='ClassBalancedDataset',
dataset=self.dataset_cfg,
oversample_thr=1.,
), {'test_mode': True})
assert dataset.dataset.test_mode == self.dataset_cfg['test_mode']

cp_cfg = deepcopy(self.dataset_cfg)
cp_cfg.pop('test_mode')
dataset = build_dataset(
dict(
type='ClassBalancedDataset',
dataset=cp_cfg,
oversample_thr=1.,
), {'test_mode': True})
assert dataset.dataset.test_mode

def test_kfold_dataset(self):
# Test build
dataset = build_dataset(
dict(
type='KFoldDataset',
dataset=self.dataset_cfg,
fold=0,
num_splits=5,
test_mode=False,
))
assert isinstance(dataset, KFoldDataset)
assert not dataset.test_mode
assert dataset.dataset.test_mode == self.dataset_cfg['test_mode']

# Test default_args
dataset = build_dataset(
dict(
type='KFoldDataset',
dataset=self.dataset_cfg,
fold=0,
num_splits=5,
test_mode=False,
),
default_args={
'test_mode': True,
'classes': [1, 2, 3]
})
assert not dataset.test_mode
assert dataset.dataset.test_mode == self.dataset_cfg['test_mode']
assert dataset.dataset.CLASSES == [1, 2, 3]

cp_cfg = deepcopy(self.dataset_cfg)
cp_cfg.pop('test_mode')
dataset = build_dataset(
dict(
type='KFoldDataset',
dataset=cp_cfg,
fold=0,
num_splits=5,
),
default_args={
'test_mode': True,
'classes': [1, 2, 3]
})
# The test_mode in default_args will be passed to KFoldDataset
assert dataset.test_mode
assert not dataset.dataset.test_mode
# Other default_args will be passed to child dataset.
assert dataset.dataset.CLASSES == [1, 2, 3]
64 changes: 63 additions & 1 deletion tests/test_data/test_datasets/test_dataset_wrapper.py
@@ -8,7 +8,20 @@
import pytest

from mmcls.datasets import (BaseDataset, ClassBalancedDataset, ConcatDataset,
RepeatDataset)
KFoldDataset, RepeatDataset)


def mock_evaluate(results,
metric='accuracy',
metric_options=None,
indices=None,
logger=None):
return dict(
results=results,
metric=metric,
metric_options=metric_options,
indices=indices,
logger=logger)


@patch.multiple(BaseDataset, __abstractmethods__=set())
@@ -23,6 +36,8 @@ def construct_toy_multi_label_dataset(length):
dataset.data_infos = MagicMock()
dataset.data_infos.__len__.return_value = length
dataset.get_cat_ids = MagicMock(side_effect=lambda idx: cat_ids_list[idx])

dataset.evaluate = MagicMock(side_effect=mock_evaluate)
return dataset, cat_ids_list


@@ -35,6 +50,7 @@ def construct_toy_single_label_dataset(length):
dataset.data_infos = MagicMock()
dataset.data_infos.__len__.return_value = length
dataset.get_cat_ids = MagicMock(side_effect=lambda idx: cat_ids_list[idx])
dataset.evaluate = MagicMock(side_effect=mock_evaluate)
return dataset, cat_ids_list


@@ -107,3 +123,49 @@ def test_class_balanced_dataset(construct_dataset):
for idx in np.random.randint(0, len(repeat_factor_dataset), 3):
assert repeat_factor_dataset[idx] == bisect.bisect_right(
repeat_factors_cumsum, idx)


@pytest.mark.parametrize('construct_dataset', [
'construct_toy_multi_label_dataset', 'construct_toy_single_label_dataset'
])
def test_kfold_dataset(construct_dataset):
construct_toy_dataset = eval(construct_dataset)
dataset, _ = construct_toy_dataset(10)

# test without random seed
train_datasets = [
KFoldDataset(dataset, fold=i, num_splits=3, test_mode=False)
for i in range(5)
]
test_datasets = [
KFoldDataset(dataset, fold=i, num_splits=3, test_mode=True)
for i in range(5)
]

assert sum([i.indices for i in test_datasets], []) == list(range(10))
for train_set, test_set in zip(train_datasets, test_datasets):
train_samples = [train_set[i] for i in range(len(train_set))]
test_samples = [test_set[i] for i in range(len(test_set))]
assert set(train_samples + test_samples) == set(range(10))

# test with random seed
train_datasets = [
KFoldDataset(dataset, fold=i, num_splits=3, test_mode=False, seed=1)
for i in range(5)
]
test_datasets = [
KFoldDataset(dataset, fold=i, num_splits=3, test_mode=True, seed=1)
for i in range(5)
]

assert sum([i.indices for i in test_datasets], []) != list(range(10))
assert set(sum([i.indices for i in test_datasets], [])) == set(range(10))
for train_set, test_set in zip(train_datasets, test_datasets):
train_samples = [train_set[i] for i in range(len(train_set))]
test_samples = [test_set[i] for i in range(len(test_set))]
assert set(train_samples + test_samples) == set(range(10))

# test evaluate
for test_set in test_datasets:
eval_inputs = test_set.evaluate(None)
assert eval_inputs['indices'] == test_set.indices
