Skip to content

Commit 0968da2

Browse files
CPBridgeericspodKumoLiu
authored
Improve Orientation transform to use the "space" (LPS vs RAS) of a metatensor by default (#8473)
Fix for #8467 ### Description As detailed in #8467, the `Orientation` transform currently always assumes a tensor's affine matrix is in RAS, regardless of the `meta["space"]` attribute, leading to incorrect performance for LPS metatensors unless the `labels` are explicitly defined by the user (and it is not at all clear that this needs to be done or how it should be done). The code in this PR checks whether the input tensor is a metatensor with its affine defined in LPS space. If so, it adjusts the `labels` passed to `nibabel.orientations.axcodes2ornt` to give the expected behaviour for LPS tensors. The default value of the `labels` parameter of the `Orientation` transform (and `OrientationD`) has changed from `(('L', 'R'), ('P', 'A'), ('I', 'S'))` to `None`. However, since the behaviour of `nibabel.orientations.axcodes2ornt` when passed `labels=None` is equivalent to when passing `labels=(('L', 'R'), ('P', 'A'), ('I', 'S'))`, I would not consider this a breaking change. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [x] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [x] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Chris Bridge <chrisbridge44@googlemail.com> Signed-off-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Co-authored-by: Eric Kerfoot <17726042+ericspod@users.noreply.github.com> Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
1 parent d4ba52e commit 0968da2

File tree

4 files changed

+206
-22
lines changed

4 files changed

+206
-22
lines changed

monai/transforms/spatial/array.py

Lines changed: 44 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464
GridSamplePadMode,
6565
InterpolateMode,
6666
NumpyPadMode,
67+
SpaceKeys,
6768
convert_to_cupy,
6869
convert_to_dst_type,
6970
convert_to_numpy,
@@ -75,6 +76,7 @@
7576
issequenceiterable,
7677
optional_import,
7778
)
79+
from monai.utils.deprecate_utils import deprecated_arg_default
7880
from monai.utils.enums import GridPatchSort, PatchKeys, TraceKeys, TransformBackends
7981
from monai.utils.misc import ImageMetaKey as Key
8082
from monai.utils.module import look_up_option
@@ -556,11 +558,20 @@ class Orientation(InvertibleTransform, LazyTransform):
556558

557559
backend = [TransformBackends.NUMPY, TransformBackends.TORCH]
558560

561+
@deprecated_arg_default(
562+
name="labels",
563+
old_default=(("L", "R"), ("P", "A"), ("I", "S")),
564+
new_default=None,
565+
msg_suffix=(
566+
"Default value changed to None meaning that the transform now uses the 'space' of a "
567+
"meta-tensor, if applicable, to determine appropriate axis labels."
568+
),
569+
)
559570
def __init__(
560571
self,
561572
axcodes: str | None = None,
562573
as_closest_canonical: bool = False,
563-
labels: Sequence[tuple[str, str]] | None = (("L", "R"), ("P", "A"), ("I", "S")),
574+
labels: Sequence[tuple[str, str]] | None = None,
564575
lazy: bool = False,
565576
) -> None:
566577
"""
@@ -573,7 +584,14 @@ def __init__(
573584
as_closest_canonical: if True, load the image as closest to canonical axis format.
574585
labels: optional, None or sequence of (2,) sequences
575586
(2,) sequences are labels for (beginning, end) of output axis.
576-
Defaults to ``(('L', 'R'), ('P', 'A'), ('I', 'S'))``.
587+
If ``None``, an appropriate value is chosen depending on the
588+
value of the ``"space"`` metadata item of a metatensor: if
589+
``"space"`` is ``"LPS"``, the value used is ``(('R', 'L'),
590+
('A', 'P'), ('I', 'S'))``, if ``"space"`` is ``"RPS"`` or the
591+
input is not a meta-tensor or has no ``"space"`` item, the
592+
value ``(('L', 'R'), ('P', 'A'), ('I', 'S'))`` is used. If not
593+
``None``, the provided value is always used and the ``"space"``
594+
metadata item (if any) of the input is ignored.
577595
lazy: a flag to indicate whether this transform should execute lazily or not.
578596
Defaults to False
579597
@@ -619,9 +637,19 @@ def __call__(self, data_array: torch.Tensor, lazy: bool | None = None) -> torch.
619637
raise ValueError(f"data_array must have at least one spatial dimension, got {spatial_shape}.")
620638
affine_: np.ndarray
621639
affine_np: np.ndarray
640+
labels = self.labels
622641
if isinstance(data_array, MetaTensor):
623642
affine_np, *_ = convert_data_type(data_array.peek_pending_affine(), np.ndarray)
624643
affine_ = to_affine_nd(sr, affine_np)
644+
645+
# Set up "labels" such that LPS tensors are handled correctly by default
646+
if (
647+
self.labels is None
648+
and "space" in data_array.meta
649+
and SpaceKeys(data_array.meta["space"]) == SpaceKeys.LPS
650+
):
651+
labels = (("R", "L"), ("A", "P"), ("I", "S")) # value for LPS
652+
625653
else:
626654
warnings.warn("`data_array` is not of type `MetaTensor, assuming affine to be identity.")
627655
# default to identity
@@ -640,7 +668,7 @@ def __call__(self, data_array: torch.Tensor, lazy: bool | None = None) -> torch.
640668
f"{self.__class__.__name__}: spatial shape = {spatial_shape}, channels = {data_array.shape[0]},"
641669
"please make sure the input is in the channel-first format."
642670
)
643-
dst = nib.orientations.axcodes2ornt(self.axcodes[:sr], labels=self.labels)
671+
dst = nib.orientations.axcodes2ornt(self.axcodes[:sr], labels=labels)
644672
if len(dst) < sr:
645673
raise ValueError(
646674
f"axcodes must match data_array spatially, got axcodes={len(self.axcodes)}D data_array={sr}D"
@@ -653,8 +681,19 @@ def inverse(self, data: torch.Tensor) -> torch.Tensor:
653681
transform = self.pop_transform(data)
654682
# Create inverse transform
655683
orig_affine = transform[TraceKeys.EXTRA_INFO]["original_affine"]
656-
orig_axcodes = nib.orientations.aff2axcodes(orig_affine)
657-
inverse_transform = Orientation(axcodes=orig_axcodes, as_closest_canonical=False, labels=self.labels)
684+
labels = self.labels
685+
686+
# Set up "labels" such that LPS tensors are handled correctly by default
687+
if (
688+
isinstance(data, MetaTensor)
689+
and self.labels is None
690+
and "space" in data.meta
691+
and SpaceKeys(data.meta["space"]) == SpaceKeys.LPS
692+
):
693+
labels = (("R", "L"), ("A", "P"), ("I", "S")) # value for LPS
694+
695+
orig_axcodes = nib.orientations.aff2axcodes(orig_affine, labels=labels)
696+
inverse_transform = Orientation(axcodes=orig_axcodes, as_closest_canonical=False, labels=labels)
658697
# Apply inverse
659698
with inverse_transform.trace_transform(False):
660699
data = inverse_transform(data)

monai/transforms/spatial/dictionary.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@
7171
ensure_tuple_rep,
7272
fall_back_tuple,
7373
)
74+
from monai.utils.deprecate_utils import deprecated_arg_default
7475
from monai.utils.enums import TraceKeys
7576
from monai.utils.module import optional_import
7677

@@ -545,12 +546,21 @@ class Orientationd(MapTransform, InvertibleTransform, LazyTransform):
545546

546547
backend = Orientation.backend
547548

549+
@deprecated_arg_default(
550+
name="labels",
551+
old_default=(("L", "R"), ("P", "A"), ("I", "S")),
552+
new_default=None,
553+
msg_suffix=(
554+
"Default value changed to None meaning that the transform now uses the 'space' of a "
555+
"meta-tensor, if applicable, to determine appropriate axis labels."
556+
),
557+
)
548558
def __init__(
549559
self,
550560
keys: KeysCollection,
551561
axcodes: str | None = None,
552562
as_closest_canonical: bool = False,
553-
labels: Sequence[tuple[str, str]] | None = (("L", "R"), ("P", "A"), ("I", "S")),
563+
labels: Sequence[tuple[str, str]] | None = None,
554564
allow_missing_keys: bool = False,
555565
lazy: bool = False,
556566
) -> None:
@@ -564,7 +574,14 @@ def __init__(
564574
as_closest_canonical: if True, load the image as closest to canonical axis format.
565575
labels: optional, None or sequence of (2,) sequences
566576
(2,) sequences are labels for (beginning, end) of output axis.
567-
Defaults to ``(('L', 'R'), ('P', 'A'), ('I', 'S'))``.
577+
If ``None``, an appropriate value is chosen depending on the
578+
value of the ``"space"`` metadata item of a metatensor: if
579+
``"space"`` is ``"LPS"``, the value used is ``(('R', 'L'),
580+
('A', 'P'), ('I', 'S'))``, if ``"space"`` is ``"RPS"`` or the
581+
input is not a meta-tensor or has no ``"space"`` item, the
582+
value ``(('L', 'R'), ('P', 'A'), ('I', 'S'))`` is used. If not
583+
``None``, the provided value is always used and the ``"space"``
584+
metadata item (if any) of the input is ignored.
568585
allow_missing_keys: don't raise exception if key is missing.
569586
lazy: a flag to indicate whether this transform should execute lazily or not.
570587
Defaults to False

tests/transforms/test_orientation.py

Lines changed: 87 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from __future__ import annotations
1313

1414
import unittest
15+
from typing import cast
1516

1617
import nibabel as nib
1718
import numpy as np
@@ -21,6 +22,7 @@
2122
from monai.data.meta_obj import set_track_meta
2223
from monai.data.meta_tensor import MetaTensor
2324
from monai.transforms import Orientation, create_rotate, create_translate
25+
from monai.utils import SpaceKeys
2426
from tests.lazy_transforms_utils import test_resampler_lazy
2527
from tests.test_utils import TEST_DEVICES, assert_allclose
2628

@@ -33,6 +35,18 @@
3335
torch.eye(4),
3436
torch.arange(12).reshape((2, 1, 2, 3)),
3537
"RAS",
38+
False,
39+
*device,
40+
]
41+
)
42+
TESTS.append(
43+
[
44+
{"axcodes": "LPS"},
45+
torch.arange(12).reshape((2, 1, 2, 3)),
46+
torch.eye(4),
47+
torch.arange(12).reshape((2, 1, 2, 3)),
48+
"LPS",
49+
True,
3650
*device,
3751
]
3852
)
@@ -43,6 +57,18 @@
4357
torch.as_tensor(np.diag([-1, -1, 1, 1])),
4458
torch.tensor([[[[3, 4, 5]], [[0, 1, 2]]], [[[9, 10, 11]], [[6, 7, 8]]]]),
4559
"ALS",
60+
False,
61+
*device,
62+
]
63+
)
64+
TESTS.append(
65+
[
66+
{"axcodes": "PRS"},
67+
torch.arange(12).reshape((2, 1, 2, 3)),
68+
torch.as_tensor(np.diag([-1, -1, 1, 1])),
69+
torch.tensor([[[[3, 4, 5]], [[0, 1, 2]]], [[[9, 10, 11]], [[6, 7, 8]]]]),
70+
"PRS",
71+
True,
4672
*device,
4773
]
4874
)
@@ -53,6 +79,18 @@
5379
torch.as_tensor(np.diag([-1, -1, 1, 1])),
5480
torch.tensor([[[[3, 4, 5], [0, 1, 2]]], [[[9, 10, 11], [6, 7, 8]]]]),
5581
"RAS",
82+
False,
83+
*device,
84+
]
85+
)
86+
TESTS.append(
87+
[
88+
{"axcodes": "LPS"},
89+
torch.arange(12).reshape((2, 1, 2, 3)),
90+
torch.as_tensor(np.diag([-1, -1, 1, 1])),
91+
torch.tensor([[[[3, 4, 5], [0, 1, 2]]], [[[9, 10, 11], [6, 7, 8]]]]),
92+
"LPS",
93+
True,
5694
*device,
5795
]
5896
)
@@ -63,6 +101,18 @@
63101
torch.eye(3),
64102
torch.tensor([[[0], [1], [2]], [[3], [4], [5]]]),
65103
"AL",
104+
False,
105+
*device,
106+
]
107+
)
108+
TESTS.append(
109+
[
110+
{"axcodes": "PR"},
111+
torch.arange(6).reshape((2, 1, 3)),
112+
torch.eye(3),
113+
torch.tensor([[[0], [1], [2]], [[3], [4], [5]]]),
114+
"PR",
115+
True,
66116
*device,
67117
]
68118
)
@@ -73,6 +123,18 @@
73123
torch.eye(2),
74124
torch.tensor([[2, 1, 0], [5, 4, 3]]),
75125
"L",
126+
False,
127+
*device,
128+
]
129+
)
130+
TESTS.append(
131+
[
132+
{"axcodes": "R"},
133+
torch.arange(6).reshape((2, 3)),
134+
torch.eye(2),
135+
torch.tensor([[2, 1, 0], [5, 4, 3]]),
136+
"R",
137+
True,
76138
*device,
77139
]
78140
)
@@ -83,6 +145,7 @@
83145
torch.eye(2),
84146
torch.tensor([[2, 1, 0], [5, 4, 3]]),
85147
"L",
148+
False,
86149
*device,
87150
]
88151
)
@@ -93,6 +156,7 @@
93156
torch.as_tensor(np.diag([-1, 1])),
94157
torch.arange(6).reshape((2, 3)),
95158
"L",
159+
False,
96160
*device,
97161
]
98162
)
@@ -107,6 +171,7 @@
107171
),
108172
torch.tensor([[[[2, 5]], [[1, 4]], [[0, 3]]], [[[8, 11]], [[7, 10]], [[6, 9]]]]),
109173
"LPS",
174+
False,
110175
*device,
111176
]
112177
)
@@ -121,6 +186,7 @@
121186
),
122187
torch.tensor([[[[0, 3]], [[1, 4]], [[2, 5]]], [[[6, 9]], [[7, 10]], [[8, 11]]]]),
123188
"RAS",
189+
False,
124190
*device,
125191
]
126192
)
@@ -131,6 +197,7 @@
131197
torch.as_tensor(create_translate(2, (10, 20)) @ create_rotate(2, (np.pi / 3)) @ np.diag([-1, -0.2, 1])),
132198
torch.tensor([[[3, 0], [4, 1], [5, 2]]]),
133199
"RA",
200+
False,
134201
*device,
135202
]
136203
)
@@ -141,6 +208,7 @@
141208
torch.as_tensor(create_translate(2, (10, 20)) @ create_rotate(2, (np.pi / 3)) @ np.diag([-1, -0.2, 1])),
142209
torch.tensor([[[2, 5], [1, 4], [0, 3]]]),
143210
"LP",
211+
False,
144212
*device,
145213
]
146214
)
@@ -151,6 +219,7 @@
151219
torch.as_tensor(np.diag([-1, -0.2, -1, 1, 1])),
152220
torch.zeros((1, 2, 3, 4, 5)),
153221
"LPID",
222+
False,
154223
*device,
155224
]
156225
)
@@ -161,6 +230,7 @@
161230
torch.as_tensor(np.diag([-1, -0.2, -1, 1, 1])),
162231
torch.zeros((1, 2, 3, 4, 5)),
163232
"RASD",
233+
False,
164234
*device,
165235
]
166236
)
@@ -175,6 +245,11 @@
175245
[{"axcodes": "RA"}, torch.arange(12).reshape((2, 1, 2, 3)), torch.eye(4)]
176246
]
177247

248+
TESTS_INVERSE = []
249+
for device in TEST_DEVICES:
250+
TESTS_INVERSE.append([True, *device])
251+
TESTS_INVERSE.append([False, *device])
252+
178253

179254
class TestOrientationCase(unittest.TestCase):
180255
@parameterized.expand(TESTS)
@@ -185,17 +260,20 @@ def test_ornt_meta(
185260
affine: torch.Tensor,
186261
expected_data: torch.Tensor,
187262
expected_code: str,
263+
lps_convention: bool,
188264
device,
189265
):
190-
img = MetaTensor(img, affine=affine).to(device)
266+
meta = {"space": SpaceKeys.LPS} if lps_convention else None
267+
img = MetaTensor(img, affine=affine, meta=meta).to(device)
191268
ornt = Orientation(**init_param)
192269
call_param = {"data_array": img}
193270
res = ornt(**call_param) # type: ignore[arg-type]
194271
if img.ndim in (3, 4):
195272
test_resampler_lazy(ornt, res, init_param, call_param)
196273

197274
assert_allclose(res, expected_data.to(device))
198-
new_code = nib.orientations.aff2axcodes(res.affine.cpu(), labels=ornt.labels) # type: ignore
275+
labels = (("R", "L"), ("A", "P"), ("I", "S")) if lps_convention else ornt.labels
276+
new_code = nib.orientations.aff2axcodes(res.affine.cpu(), labels=labels) # type: ignore
199277
self.assertEqual("".join(new_code), expected_code)
200278

201279
@parameterized.expand(TESTS_TORCH)
@@ -224,23 +302,23 @@ def test_bad_params(self, init_param, img: torch.Tensor, affine: torch.Tensor):
224302
with self.assertRaises(ValueError):
225303
Orientation(**init_param)(img)
226304

227-
@parameterized.expand(TEST_DEVICES)
228-
def test_inverse(self, device):
305+
@parameterized.expand(TESTS_INVERSE)
306+
def test_inverse(self, lps_convention: bool, device):
229307
img_t = torch.rand((1, 10, 9, 8), dtype=torch.float32, device=device)
230308
affine = torch.tensor(
231309
[[0, 0, -1, 0], [1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1]], dtype=torch.float32, device="cpu"
232310
)
233-
meta = {"fname": "somewhere"}
311+
meta = {"fname": "somewhere", "space": SpaceKeys.LPS if lps_convention else SpaceKeys.RAS}
234312
img = MetaTensor(img_t, affine=affine, meta=meta)
235313
tr = Orientation("LPS")
236314
# check that image and affine have changed
237-
img = tr(img)
315+
img = cast(MetaTensor, tr(img))
238316
self.assertNotEqual(img.shape, img_t.shape)
239-
self.assertGreater((affine - img.affine).max(), 0.5)
317+
self.assertGreater(float((affine - img.affine).max()), 0.5)
240318
# check that with inverse, image affine are back to how they were
241-
img = tr.inverse(img)
319+
img = cast(MetaTensor, tr.inverse(img))
242320
self.assertEqual(img.shape, img_t.shape)
243-
self.assertLess((affine - img.affine).max(), 1e-2)
321+
self.assertLess(float((affine - img.affine).max()), 1e-2)
244322

245323

246324
if __name__ == "__main__":

0 commit comments

Comments
 (0)