From aa646594757b82340f5cee9f4dced41e52aa6337 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 12 Nov 2021 12:03:41 +0800 Subject: [PATCH 1/4] [DLMED] enhance no keys for allow_missing_keys Signed-off-by: Nic Ma --- monai/apps/deepgrow/transforms.py | 6 +++- monai/transforms/croppad/dictionary.py | 18 ++++++++++-- monai/transforms/intensity/dictionary.py | 24 ++++++++++++--- monai/transforms/post/dictionary.py | 6 ++-- monai/transforms/spatial/dictionary.py | 37 +++++++++++++++++++----- monai/transforms/utility/dictionary.py | 4 +++ tests/test_center_scale_cropd.py | 8 ++++- 7 files changed, 85 insertions(+), 18 deletions(-) diff --git a/monai/apps/deepgrow/transforms.py b/monai/apps/deepgrow/transforms.py index 7428173932..2357af9cbe 100644 --- a/monai/apps/deepgrow/transforms.py +++ b/monai/apps/deepgrow/transforms.py @@ -652,8 +652,12 @@ def bounding_box(self, points, img_shape): def __call__(self, data): d: Dict = dict(data) + image_key = first(self.key_iterator(d)) + if image_key is None: + return d + guidance = d[self.guidance] - original_spatial_shape = d[first(self.key_iterator(d))].shape[1:] + original_spatial_shape = d[image_key].shape[1:] box_start, box_end = self.bounding_box(np.array(guidance[0] + guidance[1]), original_spatial_shape) center = list(np.mean([box_start, box_end], axis=0).astype(int)) spatial_size = self.spatial_size diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index 58e40a3e3b..824a9b0305 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -480,8 +480,12 @@ def __init__( def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) + image_key = first(self.key_iterator(d)) + if image_key is None: + return d + # use the spatial size of first image to scale, expect all images have the same spatial size - img_size = d[first(self.key_iterator(d))].shape[1:] + img_size = d[image_key].shape[1:] ndim = len(img_size) roi_size = [ceil(r * s) for r, s in zip(ensure_tuple_rep(self.roi_scale, ndim), img_size)] cropper = CenterSpatialCrop(roi_size) @@ -575,7 +579,11 @@ def randomize(self, img_size: Sequence[int]) -> None: def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) - self.randomize(d[first(self.key_iterator(d))].shape[1:]) # image shape from the first data key + image_key = first(self.key_iterator(d)) + if image_key is None: + return d + + self.randomize(d[image_key].shape[1:]) # image shape from the first data key if self._size is None: raise RuntimeError("self._size not specified.") for key in self.key_iterator(d): @@ -669,7 +677,11 @@ def __init__( self.max_roi_scale = max_roi_scale def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: - img_size = data[first(self.key_iterator(data))].shape[1:] # type: ignore + image_key = first(self.key_iterator(data)) # type: ignore + if image_key is None: + return dict(data) + + img_size = data[image_key].shape[1:] ndim = len(img_size) self.roi_size = [ceil(r * s) for r, s in zip(ensure_tuple_rep(self.roi_scale, ndim), img_size)] if self.max_roi_scale is not None: diff --git a/monai/transforms/intensity/dictionary.py b/monai/transforms/intensity/dictionary.py index ca047890e8..6e7546c0bd 100644 --- a/monai/transforms/intensity/dictionary.py +++ b/monai/transforms/intensity/dictionary.py @@ -187,7 +187,11 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N return d # all the keys share the same random noise - self.rand_gaussian_noise.randomize(d[first(self.key_iterator(d))]) + image_key = first(self.key_iterator(d)) + if image_key is None: + return d + + self.rand_gaussian_noise.randomize(d[image_key]) for key in self.key_iterator(d): d[key] = self.rand_gaussian_noise(img=d[key], randomize=False) return d @@ -621,7 +625,11 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N return d # all the keys share the same random bias factor - self.rand_bias_field.randomize(img_size=d[first(self.key_iterator(d))].shape[1:]) + image_key = first(self.key_iterator(d)) + if image_key is None: + return d + + self.rand_bias_field.randomize(img_size=d[image_key].shape[1:]) for key in self.key_iterator(d): d[key] = self.rand_bias_field(d[key], randomize=False) return d @@ -1466,7 +1474,11 @@ def __call__(self, data): return d # expect all the specified keys have same spatial shape and share same random holes - self.dropper.randomize(d[first(self.key_iterator(d))].shape[1:]) + image_key = first(self.key_iterator(d)) + if image_key is None: + return d + + self.dropper.randomize(d[image_key].shape[1:]) for key in self.key_iterator(d): d[key] = self.dropper(img=d[key], randomize=False) @@ -1531,7 +1543,11 @@ def __call__(self, data): return d # expect all the specified keys have same spatial shape and share same random holes - self.shuffle.randomize(d[first(self.key_iterator(d))].shape[1:]) + image_key = first(self.key_iterator(d)) + if image_key is None: + return d + + self.shuffle.randomize(d[image_key].shape[1:]) for key in self.key_iterator(d): d[key] = self.shuffle(img=d[key], randomize=False) diff --git a/monai/transforms/post/dictionary.py b/monai/transforms/post/dictionary.py index 596b4b3a21..fc49120b02 100644 --- a/monai/transforms/post/dictionary.py +++ b/monai/transforms/post/dictionary.py @@ -361,11 +361,13 @@ def __init__( def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) items: Union[List[NdarrayOrTensor], NdarrayOrTensor] - if len(self.keys) == 1: + if len(self.keys) == 1 and self.keys[0] in d: items = d[self.keys[0]] else: items = [d[key] for key in self.key_iterator(d)] - d[self.output_key] = self.ensemble(items) + + if len(items) > 0: + d[self.output_key] = self.ensemble(items) return d diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 8bfdd6fd52..21a05d045b 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -812,13 +812,16 @@ def set_random_state( def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) + image_key = first(self.key_iterator(d)) + if image_key is None: + return d + self.randomize(None) # all the keys share the same random Affine factor self.rand_affine.randomize() device = self.rand_affine.resampler.device - - spatial_size = d[first(self.key_iterator(d))].shape[1:] + spatial_size = d[image_key].shape[1:] sp_size = fall_back_tuple(self.rand_affine.spatial_size, spatial_size) # change image size or do random transform do_resampling = self._do_transform or (sp_size != ensure_tuple(spatial_size)) @@ -977,9 +980,13 @@ def set_random_state( def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) + image_key = first(self.key_iterator(d)) + if image_key is None: + return d + self.randomize(None) - sp_size = fall_back_tuple(self.rand_2d_elastic.spatial_size, d[first(self.key_iterator(d))].shape[1:]) + sp_size = fall_back_tuple(self.rand_2d_elastic.spatial_size, d[image_key].shape[1:]) # all the keys share the same random elastic factor self.rand_2d_elastic.randomize(sp_size) @@ -1109,9 +1116,13 @@ def set_random_state( def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) + image_key = first(self.key_iterator(d)) + if image_key is None: + return d + self.randomize(None) - sp_size = fall_back_tuple(self.rand_3d_elastic.spatial_size, d[first(self.key_iterator(d))].shape[1:]) + sp_size = fall_back_tuple(self.rand_3d_elastic.spatial_size, d[image_key].shape[1:]) # all the keys share the same random elastic factor self.rand_3d_elastic.randomize(sp_size) @@ -1259,10 +1270,14 @@ def set_random_state( def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) + image_key = first(self.key_iterator(d)) + if image_key is None: + return d + self.randomize(None) # all the keys share the same random selected axis - self.flipper.randomize(d[first(self.key_iterator(d))]) + self.flipper.randomize(d[image_key]) for key in self.key_iterator(d): if self._do_transform: d[key] = self.flipper(d[key], randomize=False) @@ -1683,10 +1698,14 @@ def set_random_state( def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) + image_key = first(self.key_iterator(d)) + if image_key is None: + return d + self.randomize(None) # all the keys share the same random zoom factor - self.rand_zoom.randomize(d[first(self.key_iterator(d))]) + self.rand_zoom.randomize(d[image_key]) for key, mode, padding_mode, align_corners in self.key_iterator( d, self.mode, self.padding_mode, self.align_corners ): @@ -1868,7 +1887,11 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N if not self._do_transform: return d - self.rand_grid_distortion.randomize(d[first(self.key_iterator(d))].shape[1:]) + image_key = first(self.key_iterator(d)) + if image_key is None: + return d + + self.rand_grid_distortion.randomize(d[image_key].shape[1:]) for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode): d[key] = self.rand_grid_distortion(d[key], mode=mode, padding_mode=padding_mode, randomize=False) return d diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index 1412790227..05a3c4b0f9 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -917,6 +917,10 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N elif not isinstance(d[key], data_type): raise TypeError("All items in data must have the same type.") output.append(d[key]) + + if len(output) == 0: + return d + if data_type is np.ndarray: d[self.name] = np.concatenate(output, axis=self.dim) elif data_type is torch.Tensor: diff --git a/tests/test_center_scale_cropd.py b/tests/test_center_scale_cropd.py index 313e8e7f7e..a827d611e6 100644 --- a/tests/test_center_scale_cropd.py +++ b/tests/test_center_scale_cropd.py @@ -33,9 +33,15 @@ (3, 2, 2, 2), ] +TEST_CASE_4 = [ + {"keys": "test", "roi_scale": 0.6, "allow_missing_keys": True}, + np.random.randint(0, 2, size=[3, 3, 3, 3]), + (3, 3, 3, 3), +] + class TestCenterScaleCropd(unittest.TestCase): - @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_3]) + @parameterized.expand([TEST_CASE_0, TEST_CASE_1, TEST_CASE_3, TEST_CASE_4]) def test_shape(self, input_param, input_data, expected_shape): result = CenterScaleCropd(**input_param)({"img": input_data}) np.testing.assert_allclose(result["img"].shape, expected_shape) From 4eb4470fef32060a97cc485988d5def0d1a91b43 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 12 Nov 2021 19:33:41 +0800 Subject: [PATCH 2/4] [DLMED] update according to comments Signed-off-by: Nic Ma --- monai/apps/deepgrow/transforms.py | 16 +++------- monai/transforms/croppad/dictionary.py | 22 +++++++------- monai/transforms/intensity/dictionary.py | 26 ++++++++--------- monai/transforms/spatial/dictionary.py | 37 ++++++++++++------------ 4 files changed, 46 insertions(+), 55 deletions(-) diff --git a/monai/apps/deepgrow/transforms.py b/monai/apps/deepgrow/transforms.py index 2357af9cbe..127aa6411d 100644 --- a/monai/apps/deepgrow/transforms.py +++ b/monai/apps/deepgrow/transforms.py @@ -19,15 +19,7 @@ from monai.transforms import Resize, SpatialCrop from monai.transforms.transform import MapTransform, Randomizable, Transform from monai.transforms.utils import generate_spatial_bounding_box, is_positive -from monai.utils import ( - InterpolateMode, - deprecated_arg, - ensure_tuple, - ensure_tuple_rep, - first, - min_version, - optional_import, -) +from monai.utils import InterpolateMode, deprecated_arg, ensure_tuple, ensure_tuple_rep, min_version, optional_import measure, _ = optional_import("skimage.measure", "0.14.2", min_version) distance_transform_cdt, _ = optional_import("scipy.ndimage.morphology", name="distance_transform_cdt") @@ -652,12 +644,12 @@ def bounding_box(self, points, img_shape): def __call__(self, data): d: Dict = dict(data) - image_key = first(self.key_iterator(d)) - if image_key is None: + exist_keys = list(self.key_iterator(d)) + if not exist_keys: return d guidance = d[self.guidance] - original_spatial_shape = d[image_key].shape[1:] + original_spatial_shape = d[exist_keys[0]].shape[1:] box_start, box_end = self.bounding_box(np.array(guidance[0] + guidance[1]), original_spatial_shape) center = list(np.mean([box_start, box_end], axis=0).astype(int)) spatial_size = self.spatial_size diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index 824a9b0305..df42d73ba1 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -51,7 +51,7 @@ weighted_patch_samples, ) from monai.utils import ImageMetaKey as Key -from monai.utils import Method, NumpyPadMode, PytorchPadMode, ensure_tuple, ensure_tuple_rep, fall_back_tuple, first +from monai.utils import Method, NumpyPadMode, PytorchPadMode, ensure_tuple, ensure_tuple_rep, fall_back_tuple from monai.utils.enums import InverseKeys __all__ = [ @@ -480,12 +480,12 @@ def __init__( def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) - image_key = first(self.key_iterator(d)) - if image_key is None: + exist_keys = list(self.key_iterator(d)) + if not exist_keys: return d # use the spatial size of first image to scale, expect all images have the same spatial size - img_size = d[image_key].shape[1:] + img_size = d[exist_keys[0]].shape[1:] ndim = len(img_size) roi_size = [ceil(r * s) for r, s in zip(ensure_tuple_rep(self.roi_scale, ndim), img_size)] cropper = CenterSpatialCrop(roi_size) @@ -579,11 +579,11 @@ def randomize(self, img_size: Sequence[int]) -> None: def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) - image_key = first(self.key_iterator(d)) - if image_key is None: + exist_keys = list(self.key_iterator(d)) + if not exist_keys: return d - self.randomize(d[image_key].shape[1:]) # image shape from the first data key + self.randomize(d[exist_keys[0]].shape[1:]) # image shape from the first data key if self._size is None: raise RuntimeError("self._size not specified.") for key in self.key_iterator(d): @@ -677,11 +677,11 @@ def __init__( self.max_roi_scale = max_roi_scale def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: - image_key = first(self.key_iterator(data)) # type: ignore - if image_key is None: - return dict(data) + exist_keys = list(self.key_iterator(data)) # type: ignore + if not exist_keys: + return data # type: ignore - img_size = data[image_key].shape[1:] + img_size = data[exist_keys[0]].shape[1:] ndim = len(img_size) self.roi_size = [ceil(r * s) for r, s in zip(ensure_tuple_rep(self.roi_scale, ndim), img_size)] if self.max_roi_scale is not None: diff --git a/monai/transforms/intensity/dictionary.py b/monai/transforms/intensity/dictionary.py index 6e7546c0bd..79ca3f8faa 100644 --- a/monai/transforms/intensity/dictionary.py +++ b/monai/transforms/intensity/dictionary.py @@ -53,7 +53,7 @@ ) from monai.transforms.transform import MapTransform, RandomizableTransform from monai.transforms.utils import is_positive -from monai.utils import ensure_tuple, ensure_tuple_rep, first +from monai.utils import ensure_tuple, ensure_tuple_rep from monai.utils.deprecate_utils import deprecated_arg __all__ = [ @@ -187,11 +187,11 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N return d # all the keys share the same random noise - image_key = first(self.key_iterator(d)) - if image_key is None: + exist_keys = list(self.key_iterator(d)) + if not exist_keys: return d - self.rand_gaussian_noise.randomize(d[image_key]) + self.rand_gaussian_noise.randomize(d[exist_keys[0]]) for key in self.key_iterator(d): d[key] = self.rand_gaussian_noise(img=d[key], randomize=False) return d @@ -625,11 +625,11 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N return d # all the keys share the same random bias factor - image_key = first(self.key_iterator(d)) - if image_key is None: + exist_keys = list(self.key_iterator(d)) + if not exist_keys: return d - self.rand_bias_field.randomize(img_size=d[image_key].shape[1:]) + self.rand_bias_field.randomize(img_size=d[exist_keys[0]].shape[1:]) for key in self.key_iterator(d): d[key] = self.rand_bias_field(d[key], randomize=False) return d @@ -1474,11 +1474,11 @@ def __call__(self, data): return d # expect all the specified keys have same spatial shape and share same random holes - image_key = first(self.key_iterator(d)) - if image_key is None: + exist_keys = list(self.key_iterator(d)) + if not exist_keys: return d - self.dropper.randomize(d[image_key].shape[1:]) + self.dropper.randomize(d[exist_keys[0]].shape[1:]) for key in self.key_iterator(d): d[key] = self.dropper(img=d[key], randomize=False) @@ -1543,11 +1543,11 @@ def __call__(self, data): return d # expect all the specified keys have same spatial shape and share same random holes - image_key = first(self.key_iterator(d)) - if image_key is None: + exist_keys = list(self.key_iterator(d)) + if not exist_keys: return d - self.shuffle.randomize(d[image_key].shape[1:]) + self.shuffle.randomize(d[exist_keys[0]].shape[1:]) for key in self.key_iterator(d): d[key] = self.shuffle(img=d[key], randomize=False) diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 21a05d045b..aa921eade8 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -60,7 +60,6 @@ ensure_tuple, ensure_tuple_rep, fall_back_tuple, - first, ) from monai.utils.deprecate_utils import deprecated_arg from monai.utils.enums import InverseKeys @@ -812,8 +811,8 @@ def set_random_state( def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) - image_key = first(self.key_iterator(d)) - if image_key is None: + exist_keys = list(self.key_iterator(d)) + if not exist_keys: return d self.randomize(None) @@ -821,7 +820,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N self.rand_affine.randomize() device = self.rand_affine.resampler.device - spatial_size = d[image_key].shape[1:] + spatial_size = d[exist_keys[0]].shape[1:] sp_size = fall_back_tuple(self.rand_affine.spatial_size, spatial_size) # change image size or do random transform do_resampling = self._do_transform or (sp_size != ensure_tuple(spatial_size)) @@ -980,13 +979,13 @@ def set_random_state( def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) - image_key = first(self.key_iterator(d)) - if image_key is None: + exist_keys = list(self.key_iterator(d)) + if not exist_keys: return d self.randomize(None) - sp_size = fall_back_tuple(self.rand_2d_elastic.spatial_size, d[image_key].shape[1:]) + sp_size = fall_back_tuple(self.rand_2d_elastic.spatial_size, d[exist_keys[0]].shape[1:]) # all the keys share the same random elastic factor self.rand_2d_elastic.randomize(sp_size) @@ -1116,13 +1115,13 @@ def set_random_state( def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) - image_key = first(self.key_iterator(d)) - if image_key is None: + exist_keys = list(self.key_iterator(d)) + if not exist_keys: return d self.randomize(None) - sp_size = fall_back_tuple(self.rand_3d_elastic.spatial_size, d[image_key].shape[1:]) + sp_size = fall_back_tuple(self.rand_3d_elastic.spatial_size, d[exist_keys[0]].shape[1:]) # all the keys share the same random elastic factor self.rand_3d_elastic.randomize(sp_size) @@ -1270,14 +1269,14 @@ def set_random_state( def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) - image_key = first(self.key_iterator(d)) - if image_key is None: + exist_keys = list(self.key_iterator(d)) + if not exist_keys: return d self.randomize(None) # all the keys share the same random selected axis - self.flipper.randomize(d[image_key]) + self.flipper.randomize(d[exist_keys[0]]) for key in self.key_iterator(d): if self._do_transform: d[key] = self.flipper(d[key], randomize=False) @@ -1698,14 +1697,14 @@ def set_random_state( def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) - image_key = first(self.key_iterator(d)) - if image_key is None: + exist_keys = list(self.key_iterator(d)) + if not exist_keys: return d self.randomize(None) # all the keys share the same random zoom factor - self.rand_zoom.randomize(d[image_key]) + self.rand_zoom.randomize(d[exist_keys[0]]) for key, mode, padding_mode, align_corners in self.key_iterator( d, self.mode, self.padding_mode, self.align_corners ): @@ -1887,11 +1886,11 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N if not self._do_transform: return d - image_key = first(self.key_iterator(d)) - if image_key is None: + exist_keys = list(self.key_iterator(d)) + if not exist_keys: return d - self.rand_grid_distortion.randomize(d[image_key].shape[1:]) + self.rand_grid_distortion.randomize(d[exist_keys[0]].shape[1:]) for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode): d[key] = self.rand_grid_distortion(d[key], mode=mode, padding_mode=padding_mode, randomize=False) return d From 5be7210d7140679d596d04b8e9d93dea3bf19d95 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 16 Nov 2021 08:16:56 +0800 Subject: [PATCH 3/4] [DLMED] update according to comments Signed-off-by: Nic Ma --- monai/apps/deepgrow/transforms.py | 8 +++--- monai/transforms/croppad/dictionary.py | 18 ++++++------ monai/transforms/intensity/dictionary.py | 26 ++++++++--------- monai/transforms/spatial/dictionary.py | 36 ++++++++++++------------ 4 files changed, 44 insertions(+), 44 deletions(-) diff --git a/monai/apps/deepgrow/transforms.py b/monai/apps/deepgrow/transforms.py index 127aa6411d..30da49dc8b 100644 --- a/monai/apps/deepgrow/transforms.py +++ b/monai/apps/deepgrow/transforms.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import json -from typing import Callable, Dict, Optional, Sequence, Union +from typing import Callable, Dict, Hashable, List, Optional, Sequence, Union import numpy as np import torch @@ -644,12 +644,12 @@ def bounding_box(self, points, img_shape): def __call__(self, data): d: Dict = dict(data) - exist_keys = list(self.key_iterator(d)) - if not exist_keys: + first_key: Union[Hashable, List] = next(self.key_iterator(d), []) + if first_key == []: return d guidance = d[self.guidance] - original_spatial_shape = d[exist_keys[0]].shape[1:] + original_spatial_shape = d[first_key].shape[1:] # type: ignore box_start, box_end = self.bounding_box(np.array(guidance[0] + guidance[1]), original_spatial_shape) center = list(np.mean([box_start, box_end], axis=0).astype(int)) spatial_size = self.spatial_size diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index 79d85e5ad8..461d0b7cd0 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -480,12 +480,12 @@ def __init__( def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) - exist_keys = list(self.key_iterator(d)) - if not exist_keys: + first_key: Union[Hashable, List] = next(self.key_iterator(d), []) + if first_key == []: return d # use the spatial size of first image to scale, expect all images have the same spatial size - img_size = d[exist_keys[0]].shape[1:] + img_size = d[first_key].shape[1:] # type: ignore ndim = len(img_size) roi_size = [ceil(r * s) for r, s in zip(ensure_tuple_rep(self.roi_scale, ndim), img_size)] cropper = CenterSpatialCrop(roi_size) @@ -579,11 +579,11 @@ def randomize(self, img_size: Sequence[int]) -> None: def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) - exist_keys = list(self.key_iterator(d)) - if not exist_keys: + first_key: Union[Hashable, List] = next(self.key_iterator(d), []) + if first_key == []: return d - self.randomize(d[exist_keys[0]].shape[1:]) # image shape from the first data key + self.randomize(d[first_key].shape[1:]) # type: ignore if self._size is None: raise RuntimeError("self._size not specified.") for key in self.key_iterator(d): @@ -677,11 +677,11 @@ def __init__( self.max_roi_scale = max_roi_scale def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: - exist_keys = list(self.key_iterator(data)) # type: ignore - if not exist_keys: + first_key: Union[Hashable, List] = next(self.key_iterator(data), []) # type: ignore + if first_key == []: return data # type: ignore - img_size = data[exist_keys[0]].shape[1:] + img_size = data[first_key].shape[1:] # type: ignore ndim = len(img_size) self.roi_size = [ceil(r * s) for r, s in zip(ensure_tuple_rep(self.roi_scale, ndim), img_size)] if self.max_roi_scale is not None: diff --git a/monai/transforms/intensity/dictionary.py b/monai/transforms/intensity/dictionary.py index 79ca3f8faa..82ac5fee1c 100644 --- a/monai/transforms/intensity/dictionary.py +++ b/monai/transforms/intensity/dictionary.py @@ -15,7 +15,7 @@ Class names are ended with 'd' to denote dictionary-based transforms. """ -from typing import Callable, Dict, Hashable, Mapping, Optional, Sequence, Tuple, Union +from typing import Callable, Dict, Hashable, List, Mapping, Optional, Sequence, Tuple, Union import numpy as np @@ -187,11 +187,11 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N return d # all the keys share the same random noise - exist_keys = list(self.key_iterator(d)) - if not exist_keys: + first_key: Union[Hashable, List] = next(self.key_iterator(d), []) + if first_key == []: return d - self.rand_gaussian_noise.randomize(d[exist_keys[0]]) + self.rand_gaussian_noise.randomize(d[first_key]) # type: ignore for key in self.key_iterator(d): d[key] = self.rand_gaussian_noise(img=d[key], randomize=False) return d @@ -625,11 +625,11 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N return d # all the keys share the same random bias factor - exist_keys = list(self.key_iterator(d)) - if not exist_keys: + first_key: Union[Hashable, List] = next(self.key_iterator(d), []) + if first_key == []: return d - self.rand_bias_field.randomize(img_size=d[exist_keys[0]].shape[1:]) + self.rand_bias_field.randomize(img_size=d[first_key].shape[1:]) # type: ignore for key in self.key_iterator(d): d[key] = self.rand_bias_field(d[key], randomize=False) return d @@ -1474,11 +1474,11 @@ def __call__(self, data): return d # expect all the specified keys have same spatial shape and share same random holes - exist_keys = list(self.key_iterator(d)) - if not exist_keys: + first_key: Union[Hashable, List] = next(self.key_iterator(d), []) + if first_key == []: return d - self.dropper.randomize(d[exist_keys[0]].shape[1:]) + self.dropper.randomize(d[first_key].shape[1:]) # type: ignore for key in self.key_iterator(d): d[key] = self.dropper(img=d[key], randomize=False) @@ -1543,11 +1543,11 @@ def __call__(self, data): return d # expect all the specified keys have same spatial shape and share same random holes - exist_keys = list(self.key_iterator(d)) - if not exist_keys: + first_key: Union[Hashable, List] = next(self.key_iterator(d), []) + if first_key == []: return d - self.shuffle.randomize(d[exist_keys[0]].shape[1:]) + self.shuffle.randomize(d[first_key].shape[1:]) # type: ignore for key in self.key_iterator(d): d[key] = self.shuffle(img=d[key], randomize=False) diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 977b114274..b7048239a9 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -816,8 +816,8 @@ def set_random_state( def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) - exist_keys = list(self.key_iterator(d)) - if not exist_keys: + first_key: Union[Hashable, List] = next(self.key_iterator(d), []) + if first_key == []: return d self.randomize(None) @@ -825,7 +825,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N self.rand_affine.randomize() device = self.rand_affine.resampler.device - spatial_size = d[exist_keys[0]].shape[1:] + spatial_size = d[first_key].shape[1:] # type: ignore sp_size = fall_back_tuple(self.rand_affine.spatial_size, spatial_size) # change image size or do random transform do_resampling = self._do_transform or (sp_size != ensure_tuple(spatial_size)) @@ -984,13 +984,13 @@ def set_random_state( def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) - exist_keys = list(self.key_iterator(d)) - if not exist_keys: + first_key: Union[Hashable, List] = next(self.key_iterator(d), []) + if first_key == []: return d self.randomize(None) - sp_size = fall_back_tuple(self.rand_2d_elastic.spatial_size, d[exist_keys[0]].shape[1:]) + sp_size = fall_back_tuple(self.rand_2d_elastic.spatial_size, d[first_key].shape[1:]) # type: ignore # all the keys share the same random elastic factor self.rand_2d_elastic.randomize(sp_size) @@ -1120,13 +1120,13 @@ def set_random_state( def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) - exist_keys = list(self.key_iterator(d)) - if not exist_keys: + first_key: Union[Hashable, List] = next(self.key_iterator(d), []) + if first_key == []: return d self.randomize(None) - sp_size = fall_back_tuple(self.rand_3d_elastic.spatial_size, d[exist_keys[0]].shape[1:]) + sp_size = fall_back_tuple(self.rand_3d_elastic.spatial_size, d[first_key].shape[1:]) # type: ignore # all the keys share the same random elastic factor self.rand_3d_elastic.randomize(sp_size) @@ -1274,14 +1274,14 @@ def set_random_state( def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) - exist_keys = list(self.key_iterator(d)) - if not exist_keys: + first_key: Union[Hashable, List] = next(self.key_iterator(d), []) + if first_key == []: return d self.randomize(None) # all the keys share the same random selected axis - self.flipper.randomize(d[exist_keys[0]]) + self.flipper.randomize(d[first_key]) # type: ignore for key in self.key_iterator(d): if self._do_transform: d[key] = self.flipper(d[key], randomize=False) @@ -1702,14 +1702,14 @@ def set_random_state( def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) - exist_keys = list(self.key_iterator(d)) - if not exist_keys: + first_key: Union[Hashable, List] = next(self.key_iterator(d), []) + if first_key == []: return d self.randomize(None) # all the keys share the same random zoom factor - self.rand_zoom.randomize(d[exist_keys[0]]) + self.rand_zoom.randomize(d[first_key]) # type: ignore for key, mode, padding_mode, align_corners in self.key_iterator( d, self.mode, self.padding_mode, self.align_corners ): @@ -1891,11 +1891,11 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N if not self._do_transform: return d - exist_keys = list(self.key_iterator(d)) - if not exist_keys: + first_key: Union[Hashable, List] = next(self.key_iterator(d), []) + if first_key == []: return d - self.rand_grid_distortion.randomize(d[exist_keys[0]].shape[1:]) + self.rand_grid_distortion.randomize(d[first_key].shape[1:]) # type: ignore for key, mode, padding_mode in self.key_iterator(d, self.mode, self.padding_mode): d[key] = self.rand_grid_distortion(d[key], mode=mode, padding_mode=padding_mode, randomize=False) return d From 05c256c22177077045d242982240d4c6e3d07464 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Tue, 16 Nov 2021 23:42:50 +0800 Subject: [PATCH 4/4] [DLMED] update according to comments Signed-off-by: Nic Ma --- monai/apps/deepgrow/transforms.py | 2 +- monai/transforms/croppad/dictionary.py | 6 +++--- monai/transforms/intensity/dictionary.py | 8 ++++---- monai/transforms/spatial/dictionary.py | 12 ++++++------ monai/transforms/transform.py | 13 ++++++++++++- 5 files changed, 26 insertions(+), 15 deletions(-) diff --git a/monai/apps/deepgrow/transforms.py b/monai/apps/deepgrow/transforms.py index 30da49dc8b..1a3c035083 100644 --- a/monai/apps/deepgrow/transforms.py +++ b/monai/apps/deepgrow/transforms.py @@ -644,7 +644,7 @@ def bounding_box(self, points, img_shape): def __call__(self, data): d: Dict = dict(data) - first_key: Union[Hashable, List] = next(self.key_iterator(d), []) + first_key: Union[Hashable, List] = self.first_key(d) if first_key == []: return d diff --git a/monai/transforms/croppad/dictionary.py b/monai/transforms/croppad/dictionary.py index 461d0b7cd0..d5112f8fb6 100644 --- a/monai/transforms/croppad/dictionary.py +++ b/monai/transforms/croppad/dictionary.py @@ -480,7 +480,7 @@ def __init__( def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) - first_key: Union[Hashable, List] = next(self.key_iterator(d), []) + first_key: Union[Hashable, List] = self.first_key(d) if first_key == []: return d @@ -579,7 +579,7 @@ def randomize(self, img_size: Sequence[int]) -> None: def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) - first_key: Union[Hashable, List] = next(self.key_iterator(d), []) + first_key: Union[Hashable, List] = self.first_key(d) if first_key == []: return d @@ -677,7 +677,7 @@ def __init__( self.max_roi_scale = max_roi_scale def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: - first_key: Union[Hashable, List] = next(self.key_iterator(data), []) # type: ignore + first_key: Union[Hashable, List] = self.first_key(data) # type: ignore if first_key == []: return data # type: ignore diff --git a/monai/transforms/intensity/dictionary.py b/monai/transforms/intensity/dictionary.py index 82ac5fee1c..71809d7ea2 100644 --- a/monai/transforms/intensity/dictionary.py +++ b/monai/transforms/intensity/dictionary.py @@ -187,7 +187,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N return d # all the keys share the same random noise - first_key: Union[Hashable, List] = next(self.key_iterator(d), []) + first_key: Union[Hashable, List] = self.first_key(d) if first_key == []: return d @@ -625,7 +625,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N return d # all the keys share the same random bias factor - first_key: Union[Hashable, List] = next(self.key_iterator(d), []) + first_key: Union[Hashable, List] = self.first_key(d) if first_key == []: return d @@ -1474,7 +1474,7 @@ def __call__(self, data): return d # expect all the specified keys have same spatial shape and share same random holes - first_key: Union[Hashable, List] = next(self.key_iterator(d), []) + first_key: Union[Hashable, List] = self.first_key(d) if first_key == []: return d @@ -1543,7 +1543,7 @@ def __call__(self, data): return d # expect all the specified keys have same spatial shape and share same random holes - first_key: Union[Hashable, List] = next(self.key_iterator(d), []) + first_key: Union[Hashable, List] = self.first_key(d) if first_key == []: return d diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index b7048239a9..e4fc4cf3a9 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -816,7 +816,7 @@ def set_random_state( def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) - first_key: Union[Hashable, List] = next(self.key_iterator(d), []) + first_key: Union[Hashable, List] = self.first_key(d) if first_key == []: return d @@ -984,7 +984,7 @@ def set_random_state( def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) - first_key: Union[Hashable, List] = next(self.key_iterator(d), []) + first_key: Union[Hashable, List] = self.first_key(d) if first_key == []: return d @@ -1120,7 +1120,7 @@ def set_random_state( def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) - first_key: Union[Hashable, List] = next(self.key_iterator(d), []) + first_key: Union[Hashable, List] = self.first_key(d) if first_key == []: return d @@ -1274,7 +1274,7 @@ def set_random_state( def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) - first_key: Union[Hashable, List] = next(self.key_iterator(d), []) + first_key: Union[Hashable, List] = self.first_key(d) if first_key == []: return d @@ -1702,7 +1702,7 @@ def set_random_state( def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) - first_key: Union[Hashable, List] = next(self.key_iterator(d), []) + first_key: Union[Hashable, List] = self.first_key(d) if first_key == []: return d @@ -1891,7 +1891,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N if not self._do_transform: return d - first_key: Union[Hashable, List] = next(self.key_iterator(d), []) + first_key: Union[Hashable, List] = self.first_key(d) if first_key == []: return d diff --git a/monai/transforms/transform.py b/monai/transforms/transform.py index 61794308f4..58e8d29e43 100644 --- a/monai/transforms/transform.py +++ b/monai/transforms/transform.py @@ -21,7 +21,7 @@ from monai import transforms from monai.config import KeysCollection -from monai.utils import MAX_SEED, ensure_tuple +from monai.utils import MAX_SEED, ensure_tuple, first from monai.utils.enums import TransformBackends __all__ = ["ThreadUnsafe", "apply_transform", "Randomizable", "RandomizableTransform", "Transform", "MapTransform"] @@ -361,3 +361,14 @@ def key_iterator(self, data: Dict[Hashable, Any], *extra_iterables: Optional[Ite yield (key,) + tuple(_ex_iters) if extra_iterables else key elif not self.allow_missing_keys: raise KeyError(f"Key was missing ({key}) and allow_missing_keys==False") + + def first_key(self, data: Dict[Hashable, Any]): + """ + Get the first available key of `self.keys` in the input `data` dictionary. + If no available key, return an empty list `[]`. + + Args: + data: data that the transform will be applied to. + + """ + return first(self.key_iterator(data), [])