From 698b5eb8c4de3e44b64216b322c69200a75ac037 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 2 Mar 2022 13:40:12 +0000 Subject: [PATCH 1/5] SplitDim Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/__init__.py | 4 ++ monai/transforms/utility/array.py | 57 +++++++++++++++------ monai/transforms/utility/dictionary.py | 68 ++++++++++++++++++++------ tests/test_splitdim.py | 50 +++++++++++++++++++ tests/test_splitdimd.py | 65 ++++++++++++++++++++++++ 5 files changed, 213 insertions(+), 31 deletions(-) create mode 100644 tests/test_splitdim.py create mode 100644 tests/test_splitdimd.py diff --git a/monai/transforms/__init__.py b/monai/transforms/__init__.py index a3dc439a51..8e6ccc8b94 100644 --- a/monai/transforms/__init__.py +++ b/monai/transforms/__init__.py @@ -412,6 +412,7 @@ RepeatChannel, SimulateDelay, SplitChannel, + SplitDim, SqueezeDim, ToCupy, ToDevice, @@ -509,6 +510,9 @@ SplitChanneld, SplitChannelD, SplitChannelDict, + SplitDimd, + SplitDimD, + SplitDimDict, SqueezeDimd, SqueezeDimD, SqueezeDimDict, diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 0100c33719..d9c203c5bc 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -37,6 +37,7 @@ convert_to_cupy, convert_to_numpy, convert_to_tensor, + deprecated, deprecated_arg, ensure_tuple, look_up_option, @@ -62,6 +63,7 @@ "EnsureType", "RepeatChannel", "RemoveRepeatedChannel", + "SplitDim", "SplitChannel", "CastToType", "ToTensor", @@ -281,33 +283,56 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: return img[:: self.repeats, :] -class SplitChannel(Transform): +class SplitDim(Transform): """ - Split Numpy array or PyTorch Tensor data according to the channel dim. - It can help applying different following transforms to different channels. + Given an image of size X along a certain dimension, return a list of length X containing + images. Useful for converting 3D images into a stack of 2D images, for example. - Args: - channel_dim: which dimension of input image is the channel, default to 0. + Note: `torch.split`/`np.split` is used, so the outputs are views of the input (shallow copy). + Args: + dim: dimension on which to split + keepdim: if `True`, output will have singleton in the split dimension. If `False`, this + dimension will be squeezed. """ backend = [TransformBackends.TORCH, TransformBackends.NUMPY] - def __init__(self, channel_dim: int = 0) -> None: - self.channel_dim = channel_dim + def __init__(self, dim: int = -1, keepdim: bool = True) -> None: + self.dim = dim + self.keepdim = keepdim def __call__(self, img: NdarrayOrTensor) -> List[NdarrayOrTensor]: - num_classes = img.shape[self.channel_dim] - if num_classes <= 1: - raise RuntimeError("input image does not contain multiple channels.") + """ + Apply the transform to `img`. + """ + n_out = img.shape[self.dim] + if n_out <= 1: + raise RuntimeError("Input image is singleton along dimension to be split.") + if isinstance(img, torch.Tensor): + outputs = list(torch.split(img, 1, self.dim)) + else: + outputs = np.split(img, n_out, self.dim) + if not self.keepdim: + outputs = [o.squeeze(self.dim) for o in outputs] + return outputs - outputs = [] - slices = [slice(None)] * len(img.shape) - for i in range(num_classes): - slices[self.channel_dim] = slice(i, i + 1) - outputs.append(img[tuple(slices)]) - return outputs +@deprecated(since="0.8", msg_suffix="please use `SplitDim` instead.") +class SplitChannel(SplitDim): + """ + Split Numpy array or PyTorch Tensor data according to the channel dim. + It can help applying different following transforms to different channels. + + Note: `torch.split`/`np.split` is used, so the outputs are views of the input (shallow copy). + + Args: + channel_dim: which dimension of input image is the channel, default to 0. + + """ + + def __init__(self, channel_dim: int = 0) -> None: + super().__init__(channel_dim) class CastToType(Transform): diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index ecf9aaffa4..9d859eb97e 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -49,7 +49,7 @@ RemoveRepeatedChannel, RepeatChannel, SimulateDelay, - SplitChannel, + SplitDim, SqueezeDim, ToCupy, ToDevice, @@ -61,7 +61,7 @@ ) from monai.transforms.utils import extreme_points_to_image, get_extreme_points from monai.transforms.utils_pytorch_numpy_unification import concatenate -from monai.utils import convert_to_numpy, deprecated_arg, ensure_tuple, ensure_tuple_rep +from monai.utils import convert_to_numpy, deprecated, deprecated_arg, ensure_tuple, ensure_tuple_rep from monai.utils.enums import PostFix, TraceKeys, TransformBackends from monai.utils.type_conversion import convert_to_dst_type @@ -150,6 +150,9 @@ "SplitChannelD", "SplitChannelDict", "SplitChanneld", + "SplitDimD", + "SplitDimDict", + "SplitDimd", "SqueezeDimD", "SqueezeDimDict", "SqueezeDimd", @@ -372,19 +375,14 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N return d -class SplitChanneld(MapTransform): - """ - Dictionary-based wrapper of :py:class:`monai.transforms.SplitChannel`. - All the input specified by `keys` should be split into same count of data. - """ - - backend = SplitChannel.backend - +class SplitDimd(MapTransform): def __init__( self, keys: KeysCollection, output_postfixes: Optional[Sequence[str]] = None, - channel_dim: int = 0, + dim: int = 0, + keepdim: bool = True, + update_meta: bool = True, allow_missing_keys: bool = False, ) -> None: """ @@ -395,13 +393,17 @@ def __init__( for example: if the key of input data is `pred` and split 2 classes, the output data keys will be: pred_(output_postfixes[0]), pred_(output_postfixes[1]) if None, using the index number: `pred_0`, `pred_1`, ... `pred_N`. - channel_dim: which dimension of input image is the channel, default to 0. + dim: which dimension of input image is the channel, default to 0. + keepdim: if `True`, output will have singleton in the split dimension. If `False`, this + dimension will be squeezed. + update_meta: if `True`, copy `[key]_meta_dict` for each output and update affine to + reflect the cropped image allow_missing_keys: don't raise exception if key is missing. - """ - super().__init__(keys, allow_missing_keys) + MapTransform.__init__(self, keys, allow_missing_keys) self.output_postfixes = output_postfixes - self.splitter = SplitChannel(channel_dim=channel_dim) + self.splitter = SplitDim(dim, keepdim) + self.update_meta = update_meta def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: d = dict(data) @@ -415,9 +417,44 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N if split_key in d: raise RuntimeError(f"input data already contains key {split_key}.") d[split_key] = r + + if self.update_meta: + orig_meta = d.get(PostFix.meta(key), None) + if orig_meta is not None: + split_meta_key = PostFix.meta(split_key) + d[split_meta_key] = deepcopy(orig_meta) + dim = self.splitter.dim + if dim > 0: # don't update affine if channel dim + shift = np.zeros_like(d[split_meta_key]["affine"]) + shift[dim - 1, -1] = i * d[split_meta_key]["pixdim"][dim] # type: ignore + d[split_meta_key]["affine"] += shift @ d[split_meta_key]["affine"] # type: ignore + return d +@deprecated(since="0.8", msg_suffix="please use `SplitDimd` instead.") +class SplitChanneld(SplitDimd): + """ + Dictionary-based wrapper of :py:class:`monai.transforms.SplitChannel`. + All the input specified by `keys` should be split into same count of data. + """ + + def __init__( + self, + keys: KeysCollection, + output_postfixes: Optional[Sequence[str]] = None, + channel_dim: int = 0, + allow_missing_keys: bool = False, + ) -> None: + super().__init__( + keys, + output_postfixes=output_postfixes, + dim=channel_dim, + update_meta=False, # for backwards compatibility + allow_missing_keys=allow_missing_keys, + ) + + class CastToTyped(MapTransform): """ Dictionary-based wrapper of :py:class:`monai.transforms.CastToType`. @@ -1637,6 +1674,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N RemoveRepeatedChannelD = RemoveRepeatedChannelDict = RemoveRepeatedChanneld RepeatChannelD = RepeatChannelDict = RepeatChanneld SplitChannelD = SplitChannelDict = SplitChanneld +SplitDimD = SplitDimDict = SplitDimd CastToTypeD = CastToTypeDict = CastToTyped ToTensorD = ToTensorDict = ToTensord EnsureTypeD = EnsureTypeDict = EnsureTyped diff --git a/tests/test_splitdim.py b/tests/test_splitdim.py new file mode 100644 index 0000000000..623396a8fe --- /dev/null +++ b/tests/test_splitdim.py @@ -0,0 +1,50 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +from parameterized import parameterized + +from monai.transforms.utility.array import SplitDim +from tests.utils import TEST_NDARRAYS + +TESTS = [] +for p in TEST_NDARRAYS: + for keepdim in (True, False): + TESTS.append(((2, 10, 8, 7), keepdim, p)) + + +class TestSplitDim(unittest.TestCase): + @parameterized.expand(TESTS) + def test_correct_shape(self, shape, keepdim, im_type): + arr = im_type(np.random.rand(*shape)) + for dim in range(arr.ndim): + out = SplitDim(dim, keepdim)(arr) + self.assertIsInstance(out, (list, tuple)) + self.assertEqual(len(out), arr.shape[dim]) + expected_ndim = arr.ndim if keepdim else arr.ndim - 1 + self.assertEqual(out[0].ndim, expected_ndim) + # assert is a shallow copy + arr[0, 0, 0, 0] *= 2 + self.assertEqual(arr.flatten()[0], out[0].flatten()[0]) + + def test_error(self): + """Should fail because splitting along singleton dimension""" + shape = (2, 1, 8, 7) + for p in TEST_NDARRAYS: + arr = p(np.random.rand(*shape)) + with self.assertRaises(RuntimeError): + _ = SplitDim(dim=1)(arr) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_splitdimd.py b/tests/test_splitdimd.py new file mode 100644 index 0000000000..2979bc5141 --- /dev/null +++ b/tests/test_splitdimd.py @@ -0,0 +1,65 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from copy import deepcopy + +import numpy as np +from parameterized import parameterized + +from monai.transforms import LoadImaged +from monai.transforms.utility.dictionary import SplitDimd +from tests.utils import TEST_NDARRAYS, make_nifti_image, make_rand_affine + +TESTS = [] +for p in TEST_NDARRAYS: + for keepdim in (True, False): + for update_meta in (True, False): + TESTS.append((keepdim, p, update_meta)) + + +class TestSplitDimd(unittest.TestCase): + @classmethod + def setUpClass(cls): + arr = np.random.rand(2, 10, 8, 7) + affine = make_rand_affine() + data = {"i": make_nifti_image(arr, affine)} + + cls.data = LoadImaged("i")(data) + + @parameterized.expand(TESTS) + def test_correct_shape(self, keepdim, im_type, update_meta): + data = deepcopy(self.data) + data["i"] = im_type(data["i"]) + arr = data["i"] + for dim in range(arr.ndim): + out = SplitDimd("i", dim=dim, keepdim=keepdim, update_meta=update_meta)(data) + self.assertIsInstance(out, dict) + num_new_keys = 2 if update_meta else 1 + self.assertEqual(len(out.keys()), len(data.keys()) + num_new_keys * arr.shape[dim]) + out = out["i_0"] + expected_ndim = arr.ndim if keepdim else arr.ndim - 1 + self.assertEqual(out.ndim, expected_ndim) + # assert is a shallow copy + arr[0, 0, 0, 0] *= 2 + self.assertEqual(arr.flatten()[0], out.flatten()[0]) + + def test_error(self): + """Should fail because splitting along singleton dimension""" + shape = (2, 1, 8, 7) + for p in TEST_NDARRAYS: + arr = p(np.random.rand(*shape)) + with self.assertRaises(RuntimeError): + _ = SplitDimd("i", dim=1)({"i": arr}) + + +if __name__ == "__main__": + unittest.main() From 452ffc423ec68f7c4d026626986f612ae0cebcf6 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 2 Mar 2022 13:42:37 +0000 Subject: [PATCH 2/5] fix Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/utility/dictionary.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index 9d859eb97e..d53fc8ba3c 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -426,7 +426,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N dim = self.splitter.dim if dim > 0: # don't update affine if channel dim shift = np.zeros_like(d[split_meta_key]["affine"]) - shift[dim - 1, -1] = i * d[split_meta_key]["pixdim"][dim] # type: ignore + shift[dim - 1, -1] = i * d[split_meta_key]["pixdim"][dim] # type: ignore d[split_meta_key]["affine"] += shift @ d[split_meta_key]["affine"] # type: ignore return d From cfe392addcad4fc91d5d2745356c4561aa8f75ce Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Wed, 2 Mar 2022 15:49:16 +0000 Subject: [PATCH 3/5] fixes Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/utility/dictionary.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index d53fc8ba3c..b5055723e9 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -400,7 +400,7 @@ def __init__( reflect the cropped image allow_missing_keys: don't raise exception if key is missing. """ - MapTransform.__init__(self, keys, allow_missing_keys) + super().__init__(keys, allow_missing_keys) self.output_postfixes = output_postfixes self.splitter = SplitDim(dim, keepdim) self.update_meta = update_meta @@ -425,7 +425,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N d[split_meta_key] = deepcopy(orig_meta) dim = self.splitter.dim if dim > 0: # don't update affine if channel dim - shift = np.zeros_like(d[split_meta_key]["affine"]) + shift = np.zeros_like(d[split_meta_key]["affine"]) # type: ignore shift[dim - 1, -1] = i * d[split_meta_key]["pixdim"][dim] # type: ignore d[split_meta_key]["affine"] += shift @ d[split_meta_key]["affine"] # type: ignore From bd8e99f7cd4095e8e5b4e8559efe80aa83291833 Mon Sep 17 00:00:00 2001 From: Richard Brown <33289025+rijobro@users.noreply.github.com> Date: Thu, 7 Apr 2022 11:15:30 +0100 Subject: [PATCH 4/5] fix update meta Signed-off-by: Richard Brown <33289025+rijobro@users.noreply.github.com> --- monai/transforms/utility/dictionary.py | 6 +++--- tests/test_splitdimd.py | 17 +++++++++++++++-- 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index b5055723e9..564b2993e7 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -425,9 +425,9 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N d[split_meta_key] = deepcopy(orig_meta) dim = self.splitter.dim if dim > 0: # don't update affine if channel dim - shift = np.zeros_like(d[split_meta_key]["affine"]) # type: ignore - shift[dim - 1, -1] = i * d[split_meta_key]["pixdim"][dim] # type: ignore - d[split_meta_key]["affine"] += shift @ d[split_meta_key]["affine"] # type: ignore + shift = np.eye(len(d[split_meta_key]["affine"])) # type: ignore + shift[dim - 1, -1] = i # type: ignore + d[split_meta_key]["affine"] = d[split_meta_key]["affine"] @ shift # type: ignore return d diff --git a/tests/test_splitdimd.py b/tests/test_splitdimd.py index 2979bc5141..6b164a3cb8 100644 --- a/tests/test_splitdimd.py +++ b/tests/test_splitdimd.py @@ -17,7 +17,7 @@ from monai.transforms import LoadImaged from monai.transforms.utility.dictionary import SplitDimd -from tests.utils import TEST_NDARRAYS, make_nifti_image, make_rand_affine +from tests.utils import TEST_NDARRAYS, assert_allclose, make_nifti_image, make_rand_affine TESTS = [] for p in TEST_NDARRAYS: @@ -36,7 +36,7 @@ def setUpClass(cls): cls.data = LoadImaged("i")(data) @parameterized.expand(TESTS) - def test_correct_shape(self, keepdim, im_type, update_meta): + def test_correct(self, keepdim, im_type, update_meta): data = deepcopy(self.data) data["i"] = im_type(data["i"]) arr = data["i"] @@ -45,6 +45,19 @@ def test_correct_shape(self, keepdim, im_type, update_meta): self.assertIsInstance(out, dict) num_new_keys = 2 if update_meta else 1 self.assertEqual(len(out.keys()), len(data.keys()) + num_new_keys * arr.shape[dim]) + # if updating meta data, pick some random points and + # check same world coordinates between input and output + if update_meta: + for _ in range(10): + idx = [np.random.choice(i) for i in arr.shape] + split_im_idx = idx[dim] + split_idx = deepcopy(idx) + split_idx[dim] = 0 + # idx[1:] to remove channel and then add 1 for 4th element + real_world = data["i_meta_dict"]["affine"] @ (idx[1:] + [1]) + real_world2 = out[f"i_{split_im_idx}_meta_dict"]["affine"] @ (split_idx[1:] + [1]) + assert_allclose(real_world, real_world2) + out = out["i_0"] expected_ndim = arr.ndim if keepdim else arr.ndim - 1 self.assertEqual(out.ndim, expected_ndim) From 1d84d4637cd18cb4d8f63427d1468c122ba4c4b3 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Thu, 7 Apr 2022 13:02:01 +0100 Subject: [PATCH 5/5] update docs Signed-off-by: Wenqi Li --- docs/source/transforms.rst | 12 ++++++++++++ monai/transforms/utility/array.py | 3 ++- tests/min_tests.py | 1 + 3 files changed, 15 insertions(+), 1 deletion(-) diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 8fc832a253..78fb303093 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -803,6 +803,12 @@ Utility :members: :special-members: __call__ +`SplitDim` +"""""""""" +.. autoclass:: SplitDim + :members: + :special-members: __call__ + `SplitChannel` """""""""""""" .. autoclass:: SplitChannel @@ -1638,6 +1644,12 @@ Utility (Dict) :members: :special-members: __call__ +`SplitDimd` +""""""""""" +.. autoclass:: SplitDimd + :members: + :special-members: __call__ + `SplitChanneld` """"""""""""""" .. autoclass:: SplitChanneld diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index f3ac0d0e26..bc0c09e949 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -286,7 +286,8 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: class SplitDim(Transform): """ Given an image of size X along a certain dimension, return a list of length X containing - images. Useful for converting 3D images into a stack of 2D images, for example. + images. Useful for converting 3D images into a stack of 2D images, splitting multichannel inputs into + single channels, for example. Note: `torch.split`/`np.split` is used, so the outputs are views of the input (shallow copy). diff --git a/tests/min_tests.py b/tests/min_tests.py index 9bf95f3f49..988a703e5a 100644 --- a/tests/min_tests.py +++ b/tests/min_tests.py @@ -143,6 +143,7 @@ def run_testsuit(): "test_smartcachedataset", "test_spacing", "test_spacingd", + "test_splitdimd", "test_surface_distance", "test_testtimeaugmentation", "test_torchvision",