Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 14 additions & 10 deletions monai/transforms/utility/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,20 +127,22 @@ class AsChannelFirst(Transform):

Args:
channel_dim: which dimension of input image is the channel, default is the last dimension.
if channel_dim is a sequence, the transform will move the channel dimensions to the front.
"""

backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

def __init__(self, channel_dim: int = -1) -> None:
if not (isinstance(channel_dim, int) and channel_dim >= -1):
raise AssertionError("invalid channel dimension.")
self.channel_dim = channel_dim
def __init__(self, channel_dim: Union[int, Sequence[int]] = -1) -> None:
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this transform is deprecated, if we still want this feature, it should be in EnsureChannelFirst cc @ericspod

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was related to another task you were working on. If we haven't had the need to implement this I would close this PR and return the concept later, though it seems that Transpose has similar behaviour anyhow.

self.channel_dim = ensure_tuple(channel_dim)
self.target_dim = tuple(range(len(self.channel_dim)))

def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
"""
Apply the transform to `img`.
"""
out: NdarrayOrTensor = convert_to_tensor(moveaxis(img, self.channel_dim, 0), track_meta=get_track_meta())
out: NdarrayOrTensor = convert_to_tensor(
moveaxis(img, self.channel_dim, self.target_dim), track_meta=get_track_meta()
)
return out


Expand All @@ -157,20 +159,22 @@ class AsChannelLast(Transform):

Args:
channel_dim: which dimension of input image is the channel, default is the first dimension.
if channel_dim is a sequence, the transform will move the channel dimensions to the back.
"""

backend = [TransformBackends.TORCH, TransformBackends.NUMPY]

def __init__(self, channel_dim: int = 0) -> None:
if not (isinstance(channel_dim, int) and channel_dim >= -1):
raise AssertionError("invalid channel dimension.")
self.channel_dim = channel_dim
def __init__(self, channel_dim: Union[int, Sequence[int]] = 0) -> None:
self.channel_dim = ensure_tuple(channel_dim)
self.target_dim = tuple(range(-len(self.channel_dim), 0))

def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
"""
Apply the transform to `img`.
"""
out: NdarrayOrTensor = convert_to_tensor(moveaxis(img, self.channel_dim, -1), track_meta=get_track_meta())
out: NdarrayOrTensor = convert_to_tensor(
moveaxis(img, self.channel_dim, self.target_dim), track_meta=get_track_meta()
)
return out


Expand Down
10 changes: 8 additions & 2 deletions monai/transforms/utility/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,12 +218,15 @@ class AsChannelFirstd(MapTransform):

backend = AsChannelFirst.backend

def __init__(self, keys: KeysCollection, channel_dim: int = -1, allow_missing_keys: bool = False) -> None:
def __init__(
self, keys: KeysCollection, channel_dim: Union[int, Sequence[int]] = -1, allow_missing_keys: bool = False
) -> None:
"""
Args:
keys: keys of the corresponding items to be transformed.
See also: :py:class:`monai.transforms.compose.MapTransform`
channel_dim: which dimension of input image is the channel, default is the last dimension.
if channel_dim is a sequence, the transform will move the channel dimensions to the front.
allow_missing_keys: don't raise exception if key is missing.
"""
super().__init__(keys, allow_missing_keys)
Expand All @@ -243,12 +246,15 @@ class AsChannelLastd(MapTransform):

backend = AsChannelLast.backend

def __init__(self, keys: KeysCollection, channel_dim: int = 0, allow_missing_keys: bool = False) -> None:
def __init__(
self, keys: KeysCollection, channel_dim: Union[int, Sequence[int]] = 0, allow_missing_keys: bool = False
) -> None:
"""
Args:
keys: keys of the corresponding items to be transformed.
See also: :py:class:`monai.transforms.compose.MapTransform`
channel_dim: which dimension of input image is the channel, default is the first dimension.
if channel_dim is a sequence, the transform will move the channel dimensions to the back.
allow_missing_keys: don't raise exception if key is missing.
"""
super().__init__(keys, allow_missing_keys)
Expand Down
6 changes: 5 additions & 1 deletion tests/test_as_channel_first.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
TESTS.append([p, {"channel_dim": -1}, (4, 1, 2, 3)])
TESTS.append([p, {"channel_dim": 3}, (4, 1, 2, 3)])
TESTS.append([p, {"channel_dim": 2}, (3, 1, 2, 4)])
TESTS.append([p, {"channel_dim": (1, 2)}, (2, 3, 1, 4)])


class TestAsChannelFirst(unittest.TestCase):
Expand All @@ -31,7 +32,10 @@ def test_value(self, in_type, input_param, expected_shape):
test_data = in_type(np.random.randint(0, 2, size=[1, 2, 3, 4]))
result = AsChannelFirst(**input_param)(test_data)
self.assertTupleEqual(result.shape, expected_shape)
expected = moveaxis(test_data, input_param["channel_dim"], 0)
if isinstance(input_param["channel_dim"], int):
expected = moveaxis(test_data, input_param["channel_dim"], 0)
else: # sequence
expected = moveaxis(test_data, input_param["channel_dim"], tuple(range(len(input_param["channel_dim"]))))
assert_allclose(result, expected, type_test="tensor")


Expand Down
1 change: 1 addition & 0 deletions tests/test_as_channel_firstd.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
TESTS.append([p, {"keys": ["image", "label", "extra"], "channel_dim": -1}, (4, 1, 2, 3)])
TESTS.append([p, {"keys": ["image", "label", "extra"], "channel_dim": 3}, (4, 1, 2, 3)])
TESTS.append([p, {"keys": ["image", "label", "extra"], "channel_dim": 2}, (3, 1, 2, 4)])
TESTS.append([p, {"keys": ["image", "label", "extra"], "channel_dim": (2, 3)}, (3, 4, 1, 2)])


class TestAsChannelFirstd(unittest.TestCase):
Expand Down
1 change: 1 addition & 0 deletions tests/test_as_channel_last.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
TESTS.append([p, {"channel_dim": 0}, (2, 3, 4, 1)])
TESTS.append([p, {"channel_dim": 1}, (1, 3, 4, 2)])
TESTS.append([p, {"channel_dim": 3}, (1, 2, 3, 4)])
TESTS.append([p, {"channel_dim": (1, 2)}, (1, 4, 2, 3)])


class TestAsChannelLast(unittest.TestCase):
Expand Down
1 change: 1 addition & 0 deletions tests/test_as_channel_lastd.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
TESTS.append([p, {"keys": ["image", "label", "extra"], "channel_dim": 0}, (2, 3, 4, 1)])
TESTS.append([p, {"keys": ["image", "label", "extra"], "channel_dim": 1}, (1, 3, 4, 2)])
TESTS.append([p, {"keys": ["image", "label", "extra"], "channel_dim": 3}, (1, 2, 3, 4)])
TESTS.append([p, {"keys": ["image", "label", "extra"], "channel_dim": (1, 0)}, (3, 4, 2, 1)])


class TestAsChannelLastd(unittest.TestCase):
Expand Down