Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 7 additions & 11 deletions monai/apps/deepgrow/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from typing import Callable, Dict, Optional, Sequence, Union
from typing import Callable, Dict, Hashable, List, Optional, Sequence, Union

import numpy as np
import torch
Expand All @@ -19,15 +19,7 @@
from monai.transforms import Resize, SpatialCrop
from monai.transforms.transform import MapTransform, Randomizable, Transform
from monai.transforms.utils import generate_spatial_bounding_box, is_positive
from monai.utils import (
InterpolateMode,
deprecated_arg,
ensure_tuple,
ensure_tuple_rep,
first,
min_version,
optional_import,
)
from monai.utils import InterpolateMode, deprecated_arg, ensure_tuple, ensure_tuple_rep, min_version, optional_import

measure, _ = optional_import("skimage.measure", "0.14.2", min_version)
distance_transform_cdt, _ = optional_import("scipy.ndimage.morphology", name="distance_transform_cdt")
Expand Down Expand Up @@ -652,8 +644,12 @@ def bounding_box(self, points, img_shape):

def __call__(self, data):
d: Dict = dict(data)
first_key: Union[Hashable, List] = self.first_key(d)
if first_key == []:
return d

guidance = d[self.guidance]
original_spatial_shape = d[first(self.key_iterator(d))].shape[1:]
original_spatial_shape = d[first_key].shape[1:] # type: ignore
box_start, box_end = self.bounding_box(np.array(guidance[0] + guidance[1]), original_spatial_shape)
center = list(np.mean([box_start, box_end], axis=0).astype(int))
spatial_size = self.spatial_size
Expand Down
20 changes: 16 additions & 4 deletions monai/transforms/croppad/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
weighted_patch_samples,
)
from monai.utils import ImageMetaKey as Key
from monai.utils import Method, NumpyPadMode, PytorchPadMode, ensure_tuple, ensure_tuple_rep, fall_back_tuple, first
from monai.utils import Method, NumpyPadMode, PytorchPadMode, ensure_tuple, ensure_tuple_rep, fall_back_tuple
from monai.utils.enums import TraceKeys

__all__ = [
Expand Down Expand Up @@ -480,8 +480,12 @@ def __init__(

def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
d = dict(data)
first_key: Union[Hashable, List] = self.first_key(d)
if first_key == []:
return d

# use the spatial size of first image to scale, expect all images have the same spatial size
img_size = d[first(self.key_iterator(d))].shape[1:]
img_size = d[first_key].shape[1:] # type: ignore
ndim = len(img_size)
roi_size = [ceil(r * s) for r, s in zip(ensure_tuple_rep(self.roi_scale, ndim), img_size)]
cropper = CenterSpatialCrop(roi_size)
Expand Down Expand Up @@ -575,7 +579,11 @@ def randomize(self, img_size: Sequence[int]) -> None:

def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
d = dict(data)
self.randomize(d[first(self.key_iterator(d))].shape[1:]) # image shape from the first data key
first_key: Union[Hashable, List] = self.first_key(d)
if first_key == []:
return d

self.randomize(d[first_key].shape[1:]) # type: ignore
if self._size is None:
raise RuntimeError("self._size not specified.")
for key in self.key_iterator(d):
Expand Down Expand Up @@ -669,7 +677,11 @@ def __init__(
self.max_roi_scale = max_roi_scale

def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
img_size = data[first(self.key_iterator(data))].shape[1:] # type: ignore
first_key: Union[Hashable, List] = self.first_key(data) # type: ignore
if first_key == []:
return data # type: ignore

img_size = data[first_key].shape[1:] # type: ignore
ndim = len(img_size)
self.roi_size = [ceil(r * s) for r, s in zip(ensure_tuple_rep(self.roi_scale, ndim), img_size)]
if self.max_roi_scale is not None:
Expand Down
28 changes: 22 additions & 6 deletions monai/transforms/intensity/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
Class names are ended with 'd' to denote dictionary-based transforms.
"""

from typing import Callable, Dict, Hashable, Mapping, Optional, Sequence, Tuple, Union
from typing import Callable, Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Union

import numpy as np

Expand Down Expand Up @@ -53,7 +53,7 @@
)
from monai.transforms.transform import MapTransform, RandomizableTransform
from monai.transforms.utils import is_positive
from monai.utils import ensure_tuple, ensure_tuple_rep, first
from monai.utils import ensure_tuple, ensure_tuple_rep
from monai.utils.deprecate_utils import deprecated_arg

__all__ = [
Expand Down Expand Up @@ -187,7 +187,11 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N
return d

# all the keys share the same random noise
self.rand_gaussian_noise.randomize(d[first(self.key_iterator(d))])
first_key: Union[Hashable, List] = self.first_key(d)
if first_key == []:
return d

self.rand_gaussian_noise.randomize(d[first_key]) # type: ignore
for key in self.key_iterator(d):
d[key] = self.rand_gaussian_noise(img=d[key], randomize=False)
return d
Expand Down Expand Up @@ -621,7 +625,11 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N
return d

# all the keys share the same random bias factor
self.rand_bias_field.randomize(img_size=d[first(self.key_iterator(d))].shape[1:])
first_key: Union[Hashable, List] = self.first_key(d)
if first_key == []:
return d

self.rand_bias_field.randomize(img_size=d[first_key].shape[1:]) # type: ignore
for key in self.key_iterator(d):
d[key] = self.rand_bias_field(d[key], randomize=False)
return d
Expand Down Expand Up @@ -1466,7 +1474,11 @@ def __call__(self, data):
return d

# expect all the specified keys have same spatial shape and share same random holes
self.dropper.randomize(d[first(self.key_iterator(d))].shape[1:])
first_key: Union[Hashable, List] = self.first_key(d)
if first_key == []:
return d

self.dropper.randomize(d[first_key].shape[1:]) # type: ignore
for key in self.key_iterator(d):
d[key] = self.dropper(img=d[key], randomize=False)

Expand Down Expand Up @@ -1531,7 +1543,11 @@ def __call__(self, data):
return d

# expect all the specified keys have same spatial shape and share same random holes
self.shuffle.randomize(d[first(self.key_iterator(d))].shape[1:])
first_key: Union[Hashable, List] = self.first_key(d)
if first_key == []:
return d

self.shuffle.randomize(d[first_key].shape[1:]) # type: ignore
for key in self.key_iterator(d):
d[key] = self.shuffle(img=d[key], randomize=False)

Expand Down
6 changes: 4 additions & 2 deletions monai/transforms/post/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,11 +359,13 @@ def __init__(
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
d = dict(data)
items: Union[List[NdarrayOrTensor], NdarrayOrTensor]
if len(self.keys) == 1:
if len(self.keys) == 1 and self.keys[0] in d:
items = d[self.keys[0]]
else:
items = [d[key] for key in self.key_iterator(d)]
d[self.output_key] = self.ensemble(items)

if len(items) > 0:
d[self.output_key] = self.ensemble(items)

return d

Expand Down
38 changes: 30 additions & 8 deletions monai/transforms/spatial/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@
ensure_tuple,
ensure_tuple_rep,
fall_back_tuple,
first,
)
from monai.utils.deprecate_utils import deprecated_arg
from monai.utils.enums import TraceKeys
Expand Down Expand Up @@ -817,13 +816,16 @@ def set_random_state(

def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
d = dict(data)
first_key: Union[Hashable, List] = self.first_key(d)
if first_key == []:
return d

self.randomize(None)
# all the keys share the same random Affine factor
self.rand_affine.randomize()

device = self.rand_affine.resampler.device

spatial_size = d[first(self.key_iterator(d))].shape[1:]
spatial_size = d[first_key].shape[1:] # type: ignore
sp_size = fall_back_tuple(self.rand_affine.spatial_size, spatial_size)
# change image size or do random transform
do_resampling = self._do_transform or (sp_size != ensure_tuple(spatial_size))
Expand Down Expand Up @@ -982,9 +984,13 @@ def set_random_state(

def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
d = dict(data)
first_key: Union[Hashable, List] = self.first_key(d)
if first_key == []:
return d

self.randomize(None)

sp_size = fall_back_tuple(self.rand_2d_elastic.spatial_size, d[first(self.key_iterator(d))].shape[1:])
sp_size = fall_back_tuple(self.rand_2d_elastic.spatial_size, d[first_key].shape[1:]) # type: ignore
# all the keys share the same random elastic factor
self.rand_2d_elastic.randomize(sp_size)

Expand Down Expand Up @@ -1114,9 +1120,13 @@ def set_random_state(

def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
d = dict(data)
first_key: Union[Hashable, List] = self.first_key(d)
if first_key == []:
return d

self.randomize(None)

sp_size = fall_back_tuple(self.rand_3d_elastic.spatial_size, d[first(self.key_iterator(d))].shape[1:])
sp_size = fall_back_tuple(self.rand_3d_elastic.spatial_size, d[first_key].shape[1:]) # type: ignore
# all the keys share the same random elastic factor
self.rand_3d_elastic.randomize(sp_size)

Expand Down Expand Up @@ -1264,10 +1274,14 @@ def set_random_state(

def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
d = dict(data)
first_key: Union[Hashable, List] = self.first_key(d)
if first_key == []:
return d

self.randomize(None)

# all the keys share the same random selected axis
self.flipper.randomize(d[first(self.key_iterator(d))])
self.flipper.randomize(d[first_key]) # type: ignore
for key in self.key_iterator(d):
if self._do_transform:
d[key] = self.flipper(d[key], randomize=False)
Expand Down Expand Up @@ -1688,10 +1702,14 @@ def set_random_state(

def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
d = dict(data)
first_key: Union[Hashable, List] = self.first_key(d)
if first_key == []:
return d

self.randomize(None)

# all the keys share the same random zoom factor
self.rand_zoom.randomize(d[first(self.key_iterator(d))])
self.rand_zoom.randomize(d[first_key]) # type: ignore
for key, mode, padding_mode, align_corners in self.key_iterator(
d, self.mode, self.padding_mode, self.align_corners
):
Expand Down Expand Up @@ -1873,7 +1891,11 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N
if not self._do_transform:
return d

self.rand_grid_distortion.randomize(d[first(self.key_iterator(d))].shape[1:])
first_key: Union[Hashable, List] = self.first_key(d)
if first_key == []:
return d

self.rand_grid_distortion.randomize(d[first_key].shape[1:]) # type: ignore
for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode):
d[key] = self.rand_grid_distortion(d[key], mode=mode, padding_mode=padding_mode, randomize=False)
return d
Expand Down
13 changes: 12 additions & 1 deletion monai/transforms/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from monai import transforms
from monai.config import KeysCollection
from monai.utils import MAX_SEED, ensure_tuple
from monai.utils import MAX_SEED, ensure_tuple, first
from monai.utils.enums import TransformBackends

__all__ = ["ThreadUnsafe", "apply_transform", "Randomizable", "RandomizableTransform", "Transform", "MapTransform"]
Expand Down Expand Up @@ -361,3 +361,14 @@ def key_iterator(self, data: Dict[Hashable, Any], *extra_iterables: Optional[Ite
yield (key,) + tuple(_ex_iters) if extra_iterables else key
elif not self.allow_missing_keys:
raise KeyError(f"Key was missing ({key}) and allow_missing_keys==False")

def first_key(self, data: Dict[Hashable, Any]):
    """
    Return the first entry of ``self.keys`` that is present in ``data``.

    Iterates the keys via ``self.key_iterator`` (which honours
    ``allow_missing_keys``) and yields the first available one; when none of
    the configured keys exist in ``data``, an empty list ``[]`` is returned
    as the sentinel so callers can bail out early.

    Args:
        data: data dictionary that the transform will be applied to.

    """
    return next(iter(self.key_iterator(data)), [])
4 changes: 4 additions & 0 deletions monai/transforms/utility/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -917,6 +917,10 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N
elif not isinstance(d[key], data_type):
raise TypeError("All items in data must have the same type.")
output.append(d[key])

if len(output) == 0:
return d

if data_type is np.ndarray:
d[self.name] = np.concatenate(output, axis=self.dim)
elif data_type is torch.Tensor:
Expand Down
8 changes: 7 additions & 1 deletion tests/test_center_scale_cropd.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,15 @@
(3, 2, 2, 2),
]

# Exercises ``allow_missing_keys=True``: the configured key ("test") is not
# present in the data dict built by ``test_shape`` (which passes {"img": ...}),
# so the crop is expected to be a no-op — the expected shape equals the
# input shape (3, 3, 3, 3).
TEST_CASE_4 = [
    {"keys": "test", "roi_scale": 0.6, "allow_missing_keys": True},
    np.random.randint(0, 2, size=[3, 3, 3, 3]),
    (3, 3, 3, 3),
]


class TestCenterScaleCropd(unittest.TestCase):
@parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_3])
@parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_3, TEST_CASE_4])
def test_shape(self, input_param, input_data, expected_shape):
result = CenterScaleCropd(**input_param)({"img": input_data})
np.testing.assert_allclose(result["img"].shape, expected_shape)
Expand Down