From 801f244ef2700e432d90c546077de3d40335169b Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Wed, 12 Jan 2022 14:24:40 +0000 Subject: [PATCH] update to support multiple dims Signed-off-by: Wenqi Li --- monai/transforms/utility/array.py | 20 ++++++++++---------- monai/transforms/utility/dictionary.py | 10 ++++++++-- tests/test_as_channel_first.py | 6 +++++- tests/test_as_channel_firstd.py | 1 + tests/test_as_channel_last.py | 1 + tests/test_as_channel_lastd.py | 1 + 6 files changed, 26 insertions(+), 13 deletions(-) diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index a107cf1cb1..8a5453f19a 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -118,20 +118,20 @@ class AsChannelFirst(Transform): Args: channel_dim: which dimension of input image is the channel, default is the last dimension. + if channel_dim is a sequence, the transform will move the channel dimensions to the front. """ backend = [TransformBackends.TORCH, TransformBackends.NUMPY] - def __init__(self, channel_dim: int = -1) -> None: - if not (isinstance(channel_dim, int) and channel_dim >= -1): - raise AssertionError("invalid channel dimension.") - self.channel_dim = channel_dim + def __init__(self, channel_dim: Union[int, Sequence[int]] = -1) -> None: + self.channel_dim = ensure_tuple(channel_dim) + self.target_dim = tuple(range(len(self.channel_dim))) def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ Apply the transform to `img`. """ - return moveaxis(img, self.channel_dim, 0) + return moveaxis(img, self.channel_dim, self.target_dim) class AsChannelLast(Transform): @@ -147,20 +147,20 @@ class AsChannelLast(Transform): Args: channel_dim: which dimension of input image is the channel, default is the first dimension. + if channel_dim is a sequence, the transform will move the channel dimensions to the back. """ backend = [TransformBackends.TORCH, TransformBackends.NUMPY] - def __init__(self, channel_dim: int = 0) -> None: - if not (isinstance(channel_dim, int) and channel_dim >= -1): - raise AssertionError("invalid channel dimension.") - self.channel_dim = channel_dim + def __init__(self, channel_dim: Union[int, Sequence[int]] = 0) -> None: + self.channel_dim = ensure_tuple(channel_dim) + self.target_dim = tuple(range(-len(self.channel_dim), 0)) def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ Apply the transform to `img`. """ - return moveaxis(img, self.channel_dim, -1) + return moveaxis(img, self.channel_dim, self.target_dim) class AddChannel(Transform): diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index b611d2ed30..c9c547fe41 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -213,12 +213,15 @@ class AsChannelFirstd(MapTransform): backend = AsChannelFirst.backend - def __init__(self, keys: KeysCollection, channel_dim: int = -1, allow_missing_keys: bool = False) -> None: + def __init__( + self, keys: KeysCollection, channel_dim: Union[int, Sequence[int]] = -1, allow_missing_keys: bool = False + ) -> None: """ Args: keys: keys of the corresponding items to be transformed. See also: :py:class:`monai.transforms.compose.MapTransform` channel_dim: which dimension of input image is the channel, default is the last dimension. + if channel_dim is a sequence, the transform will move the channel dimensions to the front. allow_missing_keys: don't raise exception if key is missing. """ super().__init__(keys, allow_missing_keys) @@ -238,12 +241,15 @@ class AsChannelLastd(MapTransform): backend = AsChannelLast.backend - def __init__(self, keys: KeysCollection, channel_dim: int = 0, allow_missing_keys: bool = False) -> None: + def __init__( + self, keys: KeysCollection, channel_dim: Union[int, Sequence[int]] = 0, allow_missing_keys: bool = False + ) -> None: """ Args: keys: keys of the corresponding items to be transformed. See also: :py:class:`monai.transforms.compose.MapTransform` channel_dim: which dimension of input image is the channel, default is the first dimension. + if channel_dim is a sequence, the transform will move the channel dimensions to the back. allow_missing_keys: don't raise exception if key is missing. """ super().__init__(keys, allow_missing_keys) diff --git a/tests/test_as_channel_first.py b/tests/test_as_channel_first.py index a2d56295b8..427489cc4b 100644 --- a/tests/test_as_channel_first.py +++ b/tests/test_as_channel_first.py @@ -23,6 +23,7 @@ TESTS.append([p, {"channel_dim": -1}, (4, 1, 2, 3)]) TESTS.append([p, {"channel_dim": 3}, (4, 1, 2, 3)]) TESTS.append([p, {"channel_dim": 2}, (3, 1, 2, 4)]) + TESTS.append([p, {"channel_dim": (1, 2)}, (2, 3, 1, 4)]) class TestAsChannelFirst(unittest.TestCase): @@ -33,7 +34,10 @@ def test_value(self, in_type, input_param, expected_shape): self.assertTupleEqual(result.shape, expected_shape) if isinstance(test_data, torch.Tensor): test_data = test_data.cpu().numpy() - expected = np.moveaxis(test_data, input_param["channel_dim"], 0) + if isinstance(input_param["channel_dim"], int): + expected = np.moveaxis(test_data, input_param["channel_dim"], 0) + else: # sequence + expected = np.moveaxis(test_data, input_param["channel_dim"], tuple(range(len(input_param["channel_dim"])))) assert_allclose(result, expected, type_test=False) diff --git a/tests/test_as_channel_firstd.py b/tests/test_as_channel_firstd.py index 91086f9299..d725e3160e 100644 --- a/tests/test_as_channel_firstd.py +++ b/tests/test_as_channel_firstd.py @@ -22,6 +22,7 @@ TESTS.append([p, {"keys": ["image", "label", "extra"], "channel_dim": -1}, (4, 1, 2, 3)]) TESTS.append([p, {"keys": ["image", "label", "extra"], "channel_dim": 3}, (4, 1, 2, 3)]) TESTS.append([p, {"keys": ["image", "label", "extra"], "channel_dim": 2}, (3, 1, 2, 4)]) + TESTS.append([p, {"keys": ["image", "label", "extra"], "channel_dim": (2, 3)}, (3, 4, 1, 2)]) class TestAsChannelFirstd(unittest.TestCase): diff --git a/tests/test_as_channel_last.py b/tests/test_as_channel_last.py index e6446ab7a6..2c28b2259d 100644 --- a/tests/test_as_channel_last.py +++ b/tests/test_as_channel_last.py @@ -22,6 +22,7 @@ TESTS.append([p, {"channel_dim": 0}, (2, 3, 4, 1)]) TESTS.append([p, {"channel_dim": 1}, (1, 3, 4, 2)]) TESTS.append([p, {"channel_dim": 3}, (1, 2, 3, 4)]) + TESTS.append([p, {"channel_dim": (1, 2)}, (1, 4, 2, 3)]) class TestAsChannelLast(unittest.TestCase): diff --git a/tests/test_as_channel_lastd.py b/tests/test_as_channel_lastd.py index a6d94d216a..35ca9208ed 100644 --- a/tests/test_as_channel_lastd.py +++ b/tests/test_as_channel_lastd.py @@ -22,6 +22,7 @@ TESTS.append([p, {"keys": ["image", "label", "extra"], "channel_dim": 0}, (2, 3, 4, 1)]) TESTS.append([p, {"keys": ["image", "label", "extra"], "channel_dim": 1}, (1, 3, 4, 2)]) TESTS.append([p, {"keys": ["image", "label", "extra"], "channel_dim": 3}, (1, 2, 3, 4)]) + TESTS.append([p, {"keys": ["image", "label", "extra"], "channel_dim": (1, 0)}, (3, 4, 2, 1)]) class TestAsChannelLastd(unittest.TestCase):