Skip to content

Commit

Permalink
[Refactor] Refactor data flow to make the interface more natural (ope…
Browse files Browse the repository at this point in the history
…n-mmlab#468)

* [Refactor]: modify interface of Visualizer.add_datasample (open-mmlab#365)

* [Refactor] Refactor data flow: refine `data_preprocessor`. (open-mmlab#359)

* refine data_preprocessor

* remove unused BATCH_DATA alias

* Fix type hints

* rename move_data to cast_data

* [Refactor] Refactor data flow: collate data in `collate_fn` of `DataLoader`  (open-mmlab#323)

* acollate data in dataloader

* fix docstring

* refine comment

* fix as comment

* refactor default collate and psedo collate

* foramt test file

* fix docstring

* fix as comment

* rename elem to data_item

* minor fix

* fix as comment

* [Refactor] Refactor data flow: `data_batch` argument of `Evaluator.process is a `dict` (open-mmlab#360)

* refine evaluator and metric

* compatible with new default collate

* replace default collate with pseudo

* Handle data_batch in metric

* fix unit test

* fix unit test

* fix unit test

* minor refine

* make data_batch optional

make data_batch optional

* rename outputs to predictions

* fix ut

* rename predictions to outputs

* fix docstring

* fix docstring

* fix unit test

* make outputs and data_batch to kwargs

* fix unit test

* keep signature of metric

* fix ut

* rename pred_sample arguments to data_sample(Visualizer)

* fix loop and ut

* [refactor]: Refactor model dataflow (open-mmlab#398)

* [Refactor] Refactor data flow: refine `data_preprocessor`. (open-mmlab#359)

* refine data_preprocessor

* remove unused BATCH_DATA alias

* Fix type hints

* rename move_data to cast_data

* refactor model data flow

tmp_commt

tmp commit

* make val_cfg and test_cfg optional

* roll back runner

* pass test mmdet

* fix as comment

fix as comment

fix ci in DataPreprocessor

* fix ut

* fix ut

* fix rebase main

* [Fix]: Fix test val ddp (open-mmlab#462)

* [Fix] Fix docstring and type hint of data flow (open-mmlab#463)

* Fix docstring of data flow

* change signature of hook

* fix unit test

* resolve conflicts

* fix lint
  • Loading branch information
HAOCHENYE committed Aug 24, 2022
1 parent 7e1d7af commit 8770c6c
Show file tree
Hide file tree
Showing 30 changed files with 843 additions and 449 deletions.
4 changes: 1 addition & 3 deletions docs/zh_cn/tutorials/hook.md
Expand Up @@ -228,10 +228,8 @@ class CheckInvalidLossHook(Hook):
Args:
runner (Runner): The runner of the training process.
batch_idx (int): The index of the current batch in the train loop.
data_batch (Sequence[dict], optional): Data from dataloader.
Defaults to None.
data_batch (dict or tuple or list, optional): Data from dataloader.
outputs (dict, optional): Outputs from model.
Defaults to None.
"""
if self.every_n_train_iters(runner, self.interval):
assert torch.isfinite(outputs['loss']),\
Expand Down
5 changes: 3 additions & 2 deletions mmengine/dataset/__init__.py
Expand Up @@ -2,10 +2,11 @@
from .base_dataset import BaseDataset, Compose, force_full_init
from .dataset_wrapper import ClassBalancedDataset, ConcatDataset, RepeatDataset
from .sampler import DefaultSampler, InfiniteSampler
from .utils import pseudo_collate, worker_init_fn
from .utils import (COLLATE_FUNCTIONS, default_collate, pseudo_collate,
worker_init_fn)

__all__ = [
'BaseDataset', 'Compose', 'force_full_init', 'ClassBalancedDataset',
'ConcatDataset', 'RepeatDataset', 'DefaultSampler', 'InfiniteSampler',
'worker_init_fn', 'pseudo_collate'
'worker_init_fn', 'pseudo_collate', 'COLLATE_FUNCTIONS', 'default_collate'
]
133 changes: 123 additions & 10 deletions mmengine/dataset/utils.py
@@ -1,11 +1,16 @@
# Copyright (c) OpenMMLab. All rights reserved.
import random
from typing import Sequence
from typing import Any, Mapping, Sequence

import numpy as np
import torch
from torch.utils.data._utils.collate import \
default_collate as torch_default_collate

DATA_BATCH = Sequence[dict]
from mmengine.registry import Registry
from mmengine.structures import BaseDataElement

COLLATE_FUNCTIONS = Registry('Collate Functions')


def worker_init_fn(worker_id: int, num_workers: int, rank: int,
Expand All @@ -28,16 +33,124 @@ def worker_init_fn(worker_id: int, num_workers: int, rank: int,
torch.manual_seed(worker_seed)


def pseudo_collate(data_batch: DATA_BATCH) -> DATA_BATCH:
"""The default behavior of dataloader is to merge a list of samples to form
a mini-batch of Tensor(s). However, in MMEngine, ``pseudo_collate`` does
nothing just returns ``data_batch``.
@COLLATE_FUNCTIONS.register_module()
def pseudo_collate(data_batch: Sequence) -> Any:
"""Convert list of data sampled from dataset into a batch of data, of which
type consistent with the type of each data_itement in ``data_batch``.
The default behavior of dataloader is to merge a list of samples to form
a mini-batch of Tensor(s). However, in MMEngine, ``pseudo_collate``
will not stack tensors to batch tensors, and convert int, float, ndarray to
tensors.
This code is referenced from:
`Pytorch default_collate <https://github.com/pytorch/pytorch/blob/master/torch/utils/data/_utils/collate.py>`_. # noqa: E501
Args:
data_batch (Sequence): Batch of data from dataloader.
Returns:
Any: Transversed Data in the same format as the data_itement of
``data_batch``.
"""
data_item = data_batch[0]
data_item_type = type(data_item)
if isinstance(data_item, (str, bytes)):
return data_batch
elif isinstance(data_item, tuple) and hasattr(data_item, '_fields'):
# named tuple
return data_item_type(*(pseudo_collate(samples)
for samples in zip(*data_batch)))
elif isinstance(data_item, Sequence):
# check to make sure that the data_itements in batch have
# consistent size
it = iter(data_batch)
data_item_size = len(next(it))
if not all(len(data_item) == data_item_size for data_item in it):
raise RuntimeError(
'each data_itement in list of batch should be of equal size')
transposed = list(zip(*data_batch))

if isinstance(data_item, tuple):
return [pseudo_collate(samples)
for samples in transposed] # Compat with Pytorch.
else:
try:
return data_item_type(
[pseudo_collate(samples) for samples in transposed])
except TypeError:
# The sequence type may not support `__init__(iterable)`
# (e.g., `range`).
return [pseudo_collate(samples) for samples in transposed]
elif isinstance(data_item, Mapping):
return data_item_type({
key: pseudo_collate([d[key] for d in data_batch])
for key in data_item
})
else:
return data_batch


@COLLATE_FUNCTIONS.register_module()
def default_collate(data_batch: Sequence) -> Any:
"""Convert list of data sampled from dataset into a batch of data, of which
type consistent with the type of each data_itement in ``data_batch``.
Different from :func:`pseudo_collate`, ``default_collate`` will stack
tensor contained in ``data_batch`` into a batched tensor with the
first dimension batch size, and then move input tensor to the target
device.
Different from ``default_collate`` in pytorch, ``default_collate`` will
not process ``BaseDataElement``.
This code is referenced from:
`Pytorch default_collate <https://github.com/pytorch/pytorch/blob/master/torch/utils/data/_utils/collate.py>`_. # noqa: E501
Note:
``default_collate`` only accept input tensor with the same shape.
Args:
data_batch (Sequence[dict]): Batch of data from
dataloader.
data_batch (Sequence): Data sampled from dataset.
Returns:
Sequence[dict]: Return input ``data_batch``.
Any: Data in the same format as the data_itement of ``data_batch``, of which
tensors have been stacked, and ndarray, int, float have been
converted to tensors.
"""
return data_batch
data_item = data_batch[0]
data_item_type = type(data_item)

if isinstance(data_item, (BaseDataElement, str, bytes)):
return data_batch
elif isinstance(data_item, tuple) and hasattr(data_item, '_fields'):
# named_tuple
return data_item_type(*(default_collate(samples)
for samples in zip(*data_batch)))
elif isinstance(data_item, Sequence):
# check to make sure that the data_itements in batch have
# consistent size
it = iter(data_batch)
data_item_size = len(next(it))
if not all(len(data_item) == data_item_size for data_item in it):
raise RuntimeError(
'each data_itement in list of batch should be of equal size')
transposed = list(zip(*data_batch))

if isinstance(data_item, tuple):
return [default_collate(samples)
for samples in transposed] # Compat with Pytorch.
else:
try:
return data_item_type(
[default_collate(samples) for samples in transposed])
except TypeError:
# The sequence type may not support `__init__(iterable)`
# (e.g., `range`).
return [default_collate(samples) for samples in transposed]
elif isinstance(data_item, Mapping):
return data_item_type({
key: default_collate([d[key] for d in data_batch])
for key in data_item
})
else:
return torch_default_collate(data_batch)
63 changes: 32 additions & 31 deletions mmengine/evaluator/evaluator.py
@@ -1,6 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Iterator, List, Optional, Sequence, Union
from typing import Any, Iterator, List, Optional, Sequence, Union

from mmengine.dataset import pseudo_collate
from mmengine.registry import EVALUATOR, METRICS
from mmengine.structures import BaseDataElement
from .metric import BaseMetric
Expand Down Expand Up @@ -37,34 +38,26 @@ def dataset_meta(self, dataset_meta: dict) -> None:
for metric in self.metrics:
metric.dataset_meta = dataset_meta

def process(self, data_batch: Sequence[dict],
predictions: Sequence[BaseDataElement]):
def process(self,
data_samples: Sequence[BaseDataElement],
data_batch: Optional[Any] = None):
"""Convert ``BaseDataSample`` to dict and invoke process method of each
metric.
Args:
data_batch (Sequence[dict]): A batch of data from the dataloader.
predictions (Sequence[BaseDataElement]): A batch of outputs from
the model.
data_samples (Sequence[BaseDataElement]): predictions of the model,
and the ground truth of the validation set.
data_batch (Any, optional): A batch of data from the dataloader.
"""
_data_batch = []
for data in data_batch:
if isinstance(data['data_sample'], BaseDataElement):
_data_batch.append(
dict(
inputs=data['inputs'],
data_sample=data['data_sample'].to_dict()))
_data_samples = []
for data_sample in data_samples:
if isinstance(data_sample, BaseDataElement):
_data_samples.append(data_sample.to_dict())
else:
_data_batch.append(data)
_predictions = []
for pred in predictions:
if isinstance(pred, BaseDataElement):
_predictions.append(pred.to_dict())
else:
_predictions.append(pred)
_data_samples.append(data_sample)

for metric in self.metrics:
metric.process(_data_batch, _predictions)
metric.process(data_batch, _data_samples)

def evaluate(self, size: int) -> dict:
"""Invoke ``evaluate`` method of each metric and collect the metrics
Expand Down Expand Up @@ -97,20 +90,26 @@ def evaluate(self, size: int) -> dict:
return metrics

def offline_evaluate(self,
data: Sequence,
predictions: Sequence,
data_samples: Sequence,
data: Optional[Sequence] = None,
chunk_size: int = 1):
"""Offline evaluate the dumped predictions on the given data .
Args:
data (Sequence): All data of the validation set.
predictions (Sequence): All predictions of the model on the
validation set.
data_samples (Sequence): All predictions and ground truth of the
model and the validation set.
data (Sequence, optional): All data of the validation set.
chunk_size (int): The number of data samples and predictions to be
processed in a batch.
"""

# support chunking iterable objects
if data is not None:
assert len(data_samples) == len(data), (
'outputs and data should have the same length, but got '
f'outputs length: {len(data_samples)} '
f'data length: {len(data)}')

def get_chunks(seq: Iterator, chunk_size=1):
stop = False
while not stop:
Expand All @@ -125,9 +124,11 @@ def get_chunks(seq: Iterator, chunk_size=1):
yield chunk

size = 0
for data_chunk, pred_chunk in zip(
get_chunks(iter(data), chunk_size),
get_chunks(iter(predictions), chunk_size)):
size += len(data_chunk)
self.process(data_chunk, pred_chunk)
for output_chunk in get_chunks(iter(data_samples), chunk_size):
if data is not None:
data_chunk = pseudo_collate(data[size:size + chunk_size])
else:
data_chunk = None
size += len(output_chunk)
self.process(output_chunk, data_chunk)
return self.evaluate(size)
10 changes: 4 additions & 6 deletions mmengine/evaluator/metric.py
Expand Up @@ -58,15 +58,14 @@ def dataset_meta(self, dataset_meta: dict) -> None:
self._dataset_meta = dataset_meta

@abstractmethod
def process(self, data_batch: Sequence[dict],
predictions: Sequence[dict]) -> None:
def process(self, data_batch: Any, data_samples: Sequence[dict]) -> None:
"""Process one batch of data samples and predictions. The processed
results should be stored in ``self.results``, which will be used to
compute the metrics when all batches have been processed.
Args:
data_batch (Sequence[dict]): A batch of data from the dataloader.
predictions (Sequence[dict]): A batch of outputs from
data_batch (Any): A batch of data from the dataloader.
data_samples (Sequence[dict]): A batch of outputs from
the model.
"""

Expand Down Expand Up @@ -146,8 +145,7 @@ def __init__(self,
raise ValueError('The output file must be a pkl file.')
self.out_file_path = out_file_path

def process(self, data_batch: Sequence[dict],
predictions: Sequence[dict]) -> None:
def process(self, data_batch: Any, predictions: Sequence[dict]) -> None:
"""transfer tensors in predictions to CPU."""
self.results.extend(_to_cpu(predictions))

Expand Down
7 changes: 3 additions & 4 deletions mmengine/hooks/checkpoint_hook.py
Expand Up @@ -12,7 +12,7 @@
from mmengine.utils import is_list_of, is_seq_of
from .hook import Hook

DATA_BATCH = Optional[Sequence[dict]]
DATA_BATCH = Optional[Union[dict, tuple, list]]


@HOOKS.register_module()
Expand Down Expand Up @@ -470,9 +470,8 @@ def after_train_iter(self,
Args:
runner (Runner): The runner of the training process.
batch_idx (int): The index of the current batch in the train loop.
data_batch (Sequence[dict], optional): Data from dataloader.
Defaults to None.
outputs (dict, optional): Outputs from model. Defaults to None.
data_batch (dict or tuple or list, optional): Data from dataloader.
outputs (dict, optional): Outputs from model.
"""
if self.by_epoch:
return
Expand Down
10 changes: 3 additions & 7 deletions mmengine/hooks/empty_cache_hook.py
Expand Up @@ -4,10 +4,9 @@
import torch

from mmengine.registry import HOOKS
from mmengine.structures import BaseDataElement
from .hook import Hook

DATA_BATCH = Optional[Sequence[dict]]
DATA_BATCH = Optional[Union[dict, tuple, list]]


@HOOKS.register_module()
Expand Down Expand Up @@ -38,18 +37,15 @@ def _after_iter(self,
runner,
batch_idx: int,
data_batch: DATA_BATCH = None,
outputs: Optional[Union[dict,
Sequence[BaseDataElement]]] = None,
outputs: Optional[Union[dict, Sequence]] = None,
mode: str = 'train') -> None:
"""Empty cache after an iteration.
Args:
runner (Runner): The runner of the training process.
batch_idx (int): The index of the current batch in the loop.
data_batch (Sequence[dict], optional): Data from dataloader.
Defaults to None.
data_batch (dict or tuple or list, optional): Data from dataloader.
outputs (dict or sequence, optional): Outputs from model.
Defaults to None.
mode (str): Current mode of runner. Defaults to 'train'.
"""
if self._do_after_iter:
Expand Down

0 comments on commit 8770c6c

Please sign in to comment.