From 1fccbb9cce01451249094f0c7f9b031c3d54063b Mon Sep 17 00:00:00 2001 From: Sebastian Penhouet Date: Thu, 24 Jun 2021 11:25:24 +0200 Subject: [PATCH 1/2] Add parameter unpacking for transforms Signed-off-by: Sebastian Penhouet --- monai/transforms/compose.py | 10 ++++--- monai/transforms/transform.py | 49 ++++++++++++++++++++++++++++++----- tests/test_compose.py | 20 ++++++++++++++ 3 files changed, 69 insertions(+), 10 deletions(-) diff --git a/monai/transforms/compose.py b/monai/transforms/compose.py index 085d5dcfc4..7fa8d6600b 100644 --- a/monai/transforms/compose.py +++ b/monai/transforms/compose.py @@ -102,12 +102,16 @@ class Compose(Randomizable, InvertibleTransform): """ def __init__( - self, transforms: Optional[Union[Sequence[Callable], Callable]] = None, map_items: bool = True + self, + transforms: Optional[Union[Sequence[Callable], Callable]] = None, + map_items: bool = True, + unpack_items: bool = False, ) -> None: if transforms is None: transforms = [] self.transforms = ensure_tuple(transforms) self.map_items = map_items + self.unpack_items = unpack_items self.set_random_state(seed=get_seed()) def set_random_state(self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None) -> "Compose": @@ -152,7 +156,7 @@ def __len__(self): def __call__(self, input_): for _transform in self.transforms: - input_ = apply_transform(_transform, input_, self.map_items) + input_ = apply_transform(_transform, input_, self.map_items, self.unpack_items) return input_ def inverse(self, data): @@ -162,5 +166,5 @@ def inverse(self, data): # loop backwards over transforms for t in reversed(invertible_transforms): - data = apply_transform(t.inverse, data, self.map_items) + data = apply_transform(t.inverse, data, self.map_items, self.unpack_items) return data diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index e5715ee702..da244b4a5d 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -14,7 +14,7 @@ import logging from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, Generator, Hashable, Iterable, List, Optional, Tuple +from typing import Any, Callable, Dict, Generator, Hashable, Iterable, List, Optional, Tuple, TypeVar, Union import numpy as np import torch @@ -25,28 +25,63 @@ __all__ = ["ThreadUnsafe", "apply_transform", "Randomizable", "RandomizableTransform", "Transform", "MapTransform"] +ReturnType = TypeVar("ReturnType") -def apply_transform(transform: Callable, data, map_items: bool = True): + +def _apply_transform( + transform: Callable[..., ReturnType], parameters: Any, unpack_parameters: bool = False +) -> ReturnType: + """ + Perform transformation `transform` with the provided parameters `parameters`. + + If `parameters` is a tuple and `unpack_items` is True, each parameter of `parameters` is unpacked + as arguments to `transform`. + Otherwise `parameters` is considered as single argument to `transform`. + + Args: + transform (Callable[..., ReturnType]): a callable to be used to transform `data`. + parameters (Any): parameters for the `transform`. + unpack_parameters (bool, optional): whether to unpack parameters for `transform`. Defaults to False. + + Returns: + ReturnType: The return type of `transform`. + """ + if isinstance(parameters, tuple) and unpack_parameters: + return transform(*parameters) + + return transform(parameters) + + +def apply_transform( + transform: Callable[..., ReturnType], + data: Any, + map_items: bool = True, + unpack_items: bool = False, +) -> Union[List[ReturnType], ReturnType]: """ Transform `data` with `transform`. + If `data` is a list or tuple and `map_data` is True, each item of `data` will be transformed and this method returns a list of outcomes. otherwise transform will be applied once with `data` as the argument. Args: - transform: a callable to be used to transform `data` - data: an object to be transformed. - map_items: whether to apply transform to each item in `data`, + transform (Callable[..., ReturnType]): a callable to be used to transform `data`. + data (Any): an object to be transformed. + map_items (bool, optional): whether to apply transform to each item in `data`, if `data` is a list or tuple. Defaults to True. + unpack_items (bool, optional): [description]. Defaults to False. Raises: Exception: When ``transform`` raises an exception. + Returns: + Union[List[ReturnType], ReturnType]: The return type of `transform` or a list thereof. """ try: if isinstance(data, (list, tuple)) and map_items: - return [transform(item) for item in data] - return transform(data) + return [_apply_transform(transform, item, unpack_items) for item in data] + return _apply_transform(transform, data, unpack_items) except Exception as e: if not isinstance(transform, transforms.compose.Compose): diff --git a/tests/test_compose.py b/tests/test_compose.py index 77736a4c77..28783cad23 100644 --- a/tests/test_compose.py +++ b/tests/test_compose.py @@ -79,6 +79,26 @@ def c(d): # transform to handle dict data for item in value: self.assertDictEqual(item, {"a": 2, "b": 1, "c": 2}) + def test_non_dict_compose_with_unpack(self): + def a(i, i2): + return i + "a", i2 + "a2" + + def b(i, i2): + return i + "b", i2 + "b2" + + c = Compose([a, b, a, b], map_items=False, unpack_items=True) + self.assertEqual(c(("", "")), ("abab", "a2b2a2b2")) + + def test_list_non_dict_compose_with_unpack(self): + def a(i, i2): + return i + "a", i2 + "a2" + + def b(i, i2): + return i + "b", i2 + "b2" + + c = Compose([a, b, a, b], unpack_items=True) + self.assertEqual(c([("", ""), ("t", "t")]), [("abab", "a2b2a2b2"), ("tabab", "ta2b2a2b2")]) + def test_list_dict_compose_no_map(self): def a(d): # transform to handle dict data d = dict(d) From 0de158133bdc49281cc2c36dd82d78d5b10732de Mon Sep 17 00:00:00 2001 From: Sebastian Penhouet Date: Thu, 24 Jun 2021 15:07:26 +0200 Subject: [PATCH 2/2] Fix mypy errors and enable unpacking in image_dataset Signed-off-by: Sebastian Penhouet --- monai/data/dataset.py | 13 +++++++++---- monai/data/image_dataset.py | 18 ++++++++++-------- monai/engines/utils.py | 10 +++++++--- tests/test_image_dataset.py | 3 +-- 4 files changed, 27 insertions(+), 17 deletions(-) diff --git a/monai/data/dataset.py b/monai/data/dataset.py index 74b9726081..18540abab0 100644 --- a/monai/data/dataset.py +++ b/monai/data/dataset.py @@ -1031,7 +1031,7 @@ def __init__( self, npzfile: Union[str, IO], keys: Dict[str, str], - transform: Optional[Callable] = None, + transform: Optional[Callable[..., Dict[str, Any]]] = None, other_keys: Optional[Sequence[str]] = (), ): self.npzfile: Union[str, IO] = npzfile if isinstance(npzfile, str) else "STREAM" @@ -1058,10 +1058,15 @@ def __len__(self): def _transform(self, index: int): data = {k: v[index] for k, v in self.arrays.items()} - if self.transform is not None: - data = apply_transform(self.transform, data) + if not self.transform: + return data - return data + transformed_data = apply_transform(self.transform, data) + + if not isinstance(transformed_data, dict): + raise AssertionError("With a dict supplied to apply_transform a single dict return is expected.") + + return transformed_data class CSVDataset(Dataset): diff --git a/monai/data/image_dataset.py b/monai/data/image_dataset.py index fdde151b2f..874b9dc004 100644 --- a/monai/data/image_dataset.py +++ b/monai/data/image_dataset.py @@ -111,20 +111,22 @@ def __getitem__(self, index: int): if self.transform is not None: if isinstance(self.transform, Randomizable): self.transform.set_random_state(seed=self._seed) - img = apply_transform( - self.transform, (img, meta_data) if self.transform_with_metadata else img, map_items=False - ) + if self.transform_with_metadata: - img, meta_data = img + img, meta_data = apply_transform(self.transform, (img, meta_data), map_items=False, unpack_items=True) + else: + img = apply_transform(self.transform, img, map_items=False) if self.seg_transform is not None: if isinstance(self.seg_transform, Randomizable): self.seg_transform.set_random_state(seed=self._seed) - seg = apply_transform( - self.seg_transform, (seg, seg_meta_data) if self.transform_with_metadata else seg, map_items=False - ) + if self.transform_with_metadata: - seg, seg_meta_data = seg + seg, seg_meta_data = apply_transform( + self.seg_transform, (seg, seg_meta_data), map_items=False, unpack_items=True + ) + else: + seg = apply_transform(self.seg_transform, seg, map_items=False) if self.labels is not None: label = self.labels[index] diff --git a/monai/engines/utils.py b/monai/engines/utils.py index 34d37836a3..c234a46296 100644 --- a/monai/engines/utils.py +++ b/monai/engines/utils.py @@ -131,7 +131,7 @@ def default_make_latent( return torch.randn(num_latents, latent_size).to(device=device, non_blocking=non_blocking) -def engine_apply_transform(batch: Any, output: Any, transform: Callable): +def engine_apply_transform(batch: Any, output: Any, transform: Callable[..., Dict]): """ Apply transform for the engine.state.batch and engine.state.output. If `batch` and `output` are dictionaries, temporarily combine them for the transform, @@ -141,8 +141,12 @@ def engine_apply_transform(batch: Any, output: Any, transform: Callable): if isinstance(batch, dict) and isinstance(output, dict): data = dict(batch) data.update(output) - data = apply_transform(transform, data) - for k, v in data.items(): + transformed_data = apply_transform(transform, data) + + if not isinstance(transformed_data, dict): + raise AssertionError("With a dict supplied to apply_transform a single dict return is expected.") + + for k, v in transformed_data.items(): # split the output data of post transforms into `output` and `batch`, # `batch` should be read-only, so save the generated key-value into `output` if k in output or k not in batch: diff --git a/tests/test_image_dataset.py b/tests/test_image_dataset.py index 173d24f350..3b3c06c87c 100644 --- a/tests/test_image_dataset.py +++ b/tests/test_image_dataset.py @@ -36,8 +36,7 @@ def __call__(self, data): class _TestCompose(Compose): - def __call__(self, input_): - data, meta = input_ + def __call__(self, data, meta): data = self.transforms[0](data, meta) # ensure channel first data, _, meta["affine"] = self.transforms[1](data, meta["affine"]) # spacing if len(self.transforms) == 3: