diff --git a/monai/transforms/utility/array.py b/monai/transforms/utility/array.py index 2ed036e9a6..ae2d56a11e 100644 --- a/monai/transforms/utility/array.py +++ b/monai/transforms/utility/array.py @@ -127,20 +127,22 @@ class AsChannelFirst(Transform): Args: channel_dim: which dimension of input image is the channel, default is the last dimension. + if channel_dim is a sequence, the transform will move the channel dimensions to the front. """ backend = [TransformBackends.TORCH, TransformBackends.NUMPY] - def __init__(self, channel_dim: int = -1) -> None: - if not (isinstance(channel_dim, int) and channel_dim >= -1): - raise AssertionError("invalid channel dimension.") - self.channel_dim = channel_dim + def __init__(self, channel_dim: Union[int, Sequence[int]] = -1) -> None: + self.channel_dim = ensure_tuple(channel_dim) + self.target_dim = tuple(range(len(self.channel_dim))) def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ Apply the transform to `img`. """ - out: NdarrayOrTensor = convert_to_tensor(moveaxis(img, self.channel_dim, 0), track_meta=get_track_meta()) + out: NdarrayOrTensor = convert_to_tensor( + moveaxis(img, self.channel_dim, self.target_dim), track_meta=get_track_meta() + ) return out @@ -157,20 +159,22 @@ class AsChannelLast(Transform): Args: channel_dim: which dimension of input image is the channel, default is the first dimension. + if channel_dim is a sequence, the transform will move the channel dimensions to the back. """ backend = [TransformBackends.TORCH, TransformBackends.NUMPY] - def __init__(self, channel_dim: int = 0) -> None: - if not (isinstance(channel_dim, int) and channel_dim >= -1): - raise AssertionError("invalid channel dimension.") - self.channel_dim = channel_dim + def __init__(self, channel_dim: Union[int, Sequence[int]] = 0) -> None: + self.channel_dim = ensure_tuple(channel_dim) + self.target_dim = tuple(range(-len(self.channel_dim), 0)) def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor: """ Apply the transform to `img`. """ - out: NdarrayOrTensor = convert_to_tensor(moveaxis(img, self.channel_dim, -1), track_meta=get_track_meta()) + out: NdarrayOrTensor = convert_to_tensor( + moveaxis(img, self.channel_dim, self.target_dim), track_meta=get_track_meta() + ) return out diff --git a/monai/transforms/utility/dictionary.py b/monai/transforms/utility/dictionary.py index 92eaf95d27..c948f1f530 100644 --- a/monai/transforms/utility/dictionary.py +++ b/monai/transforms/utility/dictionary.py @@ -218,12 +218,15 @@ class AsChannelFirstd(MapTransform): backend = AsChannelFirst.backend - def __init__(self, keys: KeysCollection, channel_dim: int = -1, allow_missing_keys: bool = False) -> None: + def __init__( + self, keys: KeysCollection, channel_dim: Union[int, Sequence[int]] = -1, allow_missing_keys: bool = False + ) -> None: """ Args: keys: keys of the corresponding items to be transformed. See also: :py:class:`monai.transforms.compose.MapTransform` channel_dim: which dimension of input image is the channel, default is the last dimension. + if channel_dim is a sequence, the transform will move the channel dimensions to the front. allow_missing_keys: don't raise exception if key is missing. """ super().__init__(keys, allow_missing_keys) @@ -243,12 +246,15 @@ class AsChannelLastd(MapTransform): backend = AsChannelLast.backend - def __init__(self, keys: KeysCollection, channel_dim: int = 0, allow_missing_keys: bool = False) -> None: + def __init__( + self, keys: KeysCollection, channel_dim: Union[int, Sequence[int]] = 0, allow_missing_keys: bool = False + ) -> None: """ Args: keys: keys of the corresponding items to be transformed. See also: :py:class:`monai.transforms.compose.MapTransform` channel_dim: which dimension of input image is the channel, default is the first dimension. + if channel_dim is a sequence, the transform will move the channel dimensions to the back. allow_missing_keys: don't raise exception if key is missing. """ super().__init__(keys, allow_missing_keys) diff --git a/tests/test_as_channel_first.py b/tests/test_as_channel_first.py index 732c559a1a..f3a4c53de3 100644 --- a/tests/test_as_channel_first.py +++ b/tests/test_as_channel_first.py @@ -23,6 +23,7 @@ TESTS.append([p, {"channel_dim": -1}, (4, 1, 2, 3)]) TESTS.append([p, {"channel_dim": 3}, (4, 1, 2, 3)]) TESTS.append([p, {"channel_dim": 2}, (3, 1, 2, 4)]) + TESTS.append([p, {"channel_dim": (1, 2)}, (2, 3, 1, 4)]) class TestAsChannelFirst(unittest.TestCase): @@ -31,7 +32,10 @@ def test_value(self, in_type, input_param, expected_shape): test_data = in_type(np.random.randint(0, 2, size=[1, 2, 3, 4])) result = AsChannelFirst(**input_param)(test_data) self.assertTupleEqual(result.shape, expected_shape) - expected = moveaxis(test_data, input_param["channel_dim"], 0) + if isinstance(input_param["channel_dim"], int): + expected = moveaxis(test_data, input_param["channel_dim"], 0) + else: # sequence + expected = moveaxis(test_data, input_param["channel_dim"], tuple(range(len(input_param["channel_dim"])))) assert_allclose(result, expected, type_test="tensor") diff --git a/tests/test_as_channel_firstd.py b/tests/test_as_channel_firstd.py index 91086f9299..d725e3160e 100644 --- a/tests/test_as_channel_firstd.py +++ b/tests/test_as_channel_firstd.py @@ -22,6 +22,7 @@ TESTS.append([p, {"keys": ["image", "label", "extra"], "channel_dim": -1}, (4, 1, 2, 3)]) TESTS.append([p, {"keys": ["image", "label", "extra"], "channel_dim": 3}, (4, 1, 2, 3)]) TESTS.append([p, {"keys": ["image", "label", "extra"], "channel_dim": 2}, (3, 1, 2, 4)]) + TESTS.append([p, {"keys": ["image", "label", "extra"], "channel_dim": (2, 3)}, (3, 4, 1, 2)]) class TestAsChannelFirstd(unittest.TestCase): diff --git a/tests/test_as_channel_last.py b/tests/test_as_channel_last.py index e6446ab7a6..2c28b2259d 100644 --- a/tests/test_as_channel_last.py +++ b/tests/test_as_channel_last.py @@ -22,6 +22,7 @@ TESTS.append([p, {"channel_dim": 0}, (2, 3, 4, 1)]) TESTS.append([p, {"channel_dim": 1}, (1, 3, 4, 2)]) TESTS.append([p, {"channel_dim": 3}, (1, 2, 3, 4)]) + TESTS.append([p, {"channel_dim": (1, 2)}, (1, 4, 2, 3)]) class TestAsChannelLast(unittest.TestCase): diff --git a/tests/test_as_channel_lastd.py b/tests/test_as_channel_lastd.py index a6d94d216a..35ca9208ed 100644 --- a/tests/test_as_channel_lastd.py +++ b/tests/test_as_channel_lastd.py @@ -22,6 +22,7 @@ TESTS.append([p, {"keys": ["image", "label", "extra"], "channel_dim": 0}, (2, 3, 4, 1)]) TESTS.append([p, {"keys": ["image", "label", "extra"], "channel_dim": 1}, (1, 3, 4, 2)]) TESTS.append([p, {"keys": ["image", "label", "extra"], "channel_dim": 3}, (1, 2, 3, 4)]) + TESTS.append([p, {"keys": ["image", "label", "extra"], "channel_dim": (1, 0)}, (3, 4, 2, 1)]) class TestAsChannelLastd(unittest.TestCase):