Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions monai/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1031,7 +1031,7 @@ def __init__(
self,
npzfile: Union[str, IO],
keys: Dict[str, str],
transform: Optional[Callable] = None,
transform: Optional[Callable[..., Dict[str, Any]]] = None,
other_keys: Optional[Sequence[str]] = (),
):
self.npzfile: Union[str, IO] = npzfile if isinstance(npzfile, str) else "STREAM"
Expand All @@ -1058,10 +1058,15 @@ def __len__(self):
def _transform(self, index: int):
data = {k: v[index] for k, v in self.arrays.items()}

if self.transform is not None:
data = apply_transform(self.transform, data)
if not self.transform:
return data

return data
transformed_data = apply_transform(self.transform, data)

if not isinstance(transformed_data, dict):
raise AssertionError("With a dict supplied to apply_transform a single dict return is expected.")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @Spenhouet ,

Thanks very much for your quick PR!
I think this assert is not correct here, some MONAI sampling transforms return a list of dictionaries instead of a single dict.
I will submit a quick PR to fix it soon, CC @wyli @ericspod @rijobro .
Others of this PR look good to me.

Thanks.


return transformed_data


class CSVDataset(Dataset):
Expand Down
18 changes: 10 additions & 8 deletions monai/data/image_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,20 +111,22 @@ def __getitem__(self, index: int):
if self.transform is not None:
if isinstance(self.transform, Randomizable):
self.transform.set_random_state(seed=self._seed)
img = apply_transform(
self.transform, (img, meta_data) if self.transform_with_metadata else img, map_items=False
)

if self.transform_with_metadata:
img, meta_data = img
img, meta_data = apply_transform(self.transform, (img, meta_data), map_items=False, unpack_items=True)
else:
img = apply_transform(self.transform, img, map_items=False)

if self.seg_transform is not None:
if isinstance(self.seg_transform, Randomizable):
self.seg_transform.set_random_state(seed=self._seed)
seg = apply_transform(
self.seg_transform, (seg, seg_meta_data) if self.transform_with_metadata else seg, map_items=False
)

if self.transform_with_metadata:
seg, seg_meta_data = seg
seg, seg_meta_data = apply_transform(
self.seg_transform, (seg, seg_meta_data), map_items=False, unpack_items=True
)
else:
seg = apply_transform(self.seg_transform, seg, map_items=False)

if self.labels is not None:
label = self.labels[index]
Expand Down
10 changes: 7 additions & 3 deletions monai/engines/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def default_make_latent(
return torch.randn(num_latents, latent_size).to(device=device, non_blocking=non_blocking)


def engine_apply_transform(batch: Any, output: Any, transform: Callable):
def engine_apply_transform(batch: Any, output: Any, transform: Callable[..., Dict]):
"""
Apply transform for the engine.state.batch and engine.state.output.
If `batch` and `output` are dictionaries, temporarily combine them for the transform,
Expand All @@ -141,8 +141,12 @@ def engine_apply_transform(batch: Any, output: Any, transform: Callable):
if isinstance(batch, dict) and isinstance(output, dict):
data = dict(batch)
data.update(output)
data = apply_transform(transform, data)
for k, v in data.items():
transformed_data = apply_transform(transform, data)

if not isinstance(transformed_data, dict):
raise AssertionError("With a dict supplied to apply_transform a single dict return is expected.")

for k, v in transformed_data.items():
# split the output data of post transforms into `output` and `batch`,
# `batch` should be read-only, so save the generated key-value into `output`
if k in output or k not in batch:
Expand Down
10 changes: 7 additions & 3 deletions monai/transforms/compose.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,12 +102,16 @@ class Compose(Randomizable, InvertibleTransform):
"""

def __init__(
self, transforms: Optional[Union[Sequence[Callable], Callable]] = None, map_items: bool = True
self,
transforms: Optional[Union[Sequence[Callable], Callable]] = None,
map_items: bool = True,
unpack_items: bool = False,
) -> None:
if transforms is None:
transforms = []
self.transforms = ensure_tuple(transforms)
self.map_items = map_items
self.unpack_items = unpack_items
self.set_random_state(seed=get_seed())

def set_random_state(self, seed: Optional[int] = None, state: Optional[np.random.RandomState] = None) -> "Compose":
Expand Down Expand Up @@ -152,7 +156,7 @@ def __len__(self):

def __call__(self, input_):
for _transform in self.transforms:
input_ = apply_transform(_transform, input_, self.map_items)
input_ = apply_transform(_transform, input_, self.map_items, self.unpack_items)
return input_

def inverse(self, data):
Expand All @@ -162,5 +166,5 @@ def inverse(self, data):

# loop backwards over transforms
for t in reversed(invertible_transforms):
data = apply_transform(t.inverse, data, self.map_items)
data = apply_transform(t.inverse, data, self.map_items, self.unpack_items)
return data
49 changes: 42 additions & 7 deletions monai/transforms/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import logging
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, Generator, Hashable, Iterable, List, Optional, Tuple
from typing import Any, Callable, Dict, Generator, Hashable, Iterable, List, Optional, Tuple, TypeVar, Union

import numpy as np
import torch
Expand All @@ -25,28 +25,63 @@

__all__ = ["ThreadUnsafe", "apply_transform", "Randomizable", "RandomizableTransform", "Transform", "MapTransform"]

ReturnType = TypeVar("ReturnType")

def apply_transform(transform: Callable, data, map_items: bool = True):

def _apply_transform(
transform: Callable[..., ReturnType], parameters: Any, unpack_parameters: bool = False
) -> ReturnType:
"""
Perform transformation `transform` with the provided parameters `parameters`.

If `parameters` is a tuple and `unpack_items` is True, each parameter of `parameters` is unpacked
as arguments to `transform`.
Otherwise `parameters` is considered as single argument to `transform`.

Args:
transform (Callable[..., ReturnType]): a callable to be used to transform `data`.
parameters (Any): parameters for the `transform`.
unpack_parameters (bool, optional): whether to unpack parameters for `transform`. Defaults to False.

Returns:
ReturnType: The return type of `transform`.
"""
if isinstance(parameters, tuple) and unpack_parameters:
return transform(*parameters)

return transform(parameters)


def apply_transform(
transform: Callable[..., ReturnType],
data: Any,
map_items: bool = True,
unpack_items: bool = False,
) -> Union[List[ReturnType], ReturnType]:
"""
Transform `data` with `transform`.

If `data` is a list or tuple and `map_data` is True, each item of `data` will be transformed
and this method returns a list of outcomes.
otherwise transform will be applied once with `data` as the argument.

Args:
transform: a callable to be used to transform `data`
data: an object to be transformed.
map_items: whether to apply transform to each item in `data`,
transform (Callable[..., ReturnType]): a callable to be used to transform `data`.
data (Any): an object to be transformed.
map_items (bool, optional): whether to apply transform to each item in `data`,
if `data` is a list or tuple. Defaults to True.
unpack_items (bool, optional): [description]. Defaults to False.

Raises:
Exception: When ``transform`` raises an exception.

Returns:
Union[List[ReturnType], ReturnType]: The return type of `transform` or a list thereof.
"""
try:
if isinstance(data, (list, tuple)) and map_items:
return [transform(item) for item in data]
return transform(data)
return [_apply_transform(transform, item, unpack_items) for item in data]
return _apply_transform(transform, data, unpack_items)
except Exception as e:

if not isinstance(transform, transforms.compose.Compose):
Expand Down
20 changes: 20 additions & 0 deletions tests/test_compose.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,26 @@ def c(d): # transform to handle dict data
for item in value:
self.assertDictEqual(item, {"a": 2, "b": 1, "c": 2})

def test_non_dict_compose_with_unpack(self):
def a(i, i2):
return i + "a", i2 + "a2"

def b(i, i2):
return i + "b", i2 + "b2"

c = Compose([a, b, a, b], map_items=False, unpack_items=True)
self.assertEqual(c(("", "")), ("abab", "a2b2a2b2"))

def test_list_non_dict_compose_with_unpack(self):
def a(i, i2):
return i + "a", i2 + "a2"

def b(i, i2):
return i + "b", i2 + "b2"

c = Compose([a, b, a, b], unpack_items=True)
self.assertEqual(c([("", ""), ("t", "t")]), [("abab", "a2b2a2b2"), ("tabab", "ta2b2a2b2")])

def test_list_dict_compose_no_map(self):
def a(d): # transform to handle dict data
d = dict(d)
Expand Down
3 changes: 1 addition & 2 deletions tests/test_image_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,7 @@ def __call__(self, data):


class _TestCompose(Compose):
def __call__(self, input_):
data, meta = input_
def __call__(self, data, meta):
data = self.transforms[0](data, meta) # ensure channel first
data, _, meta["affine"] = self.transforms[1](data, meta["affine"]) # spacing
if len(self.transforms) == 3:
Expand Down