Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 20 additions & 9 deletions monai/transforms/spatial/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ def __init__(
padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.BORDER,
align_corners: bool = False,
dtype: DtypeLike = np.float64,
image_only: bool = False,
) -> None:
"""
Args:
Expand All @@ -114,13 +115,16 @@ def __init__(
dtype: data type for resampling computation. Defaults to ``np.float64`` for best precision.
If None, use the data type of input data. To be compatible with other modules,
the output data type is always ``np.float32``.
image_only: if True return only the image volume, otherwise return (image, original affine, new affine).

"""
self.pixdim = np.array(ensure_tuple(pixdim), dtype=np.float64)
self.diagonal = diagonal
self.mode: GridSampleMode = GridSampleMode(mode)
self.padding_mode: GridSamplePadMode = GridSamplePadMode(padding_mode)
self.align_corners = align_corners
self.dtype = dtype
self.image_only = image_only

def __call__(
self,
Expand All @@ -131,7 +135,7 @@ def __call__(
align_corners: Optional[bool] = None,
dtype: DtypeLike = None,
output_spatial_shape: Optional[np.ndarray] = None,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray, np.ndarray]]:
"""
Args:
data_array: in shape (num_channels, H[, W, ...]).
Expand Down Expand Up @@ -204,7 +208,8 @@ def __call__(
)
output_data = np.asarray(output_data.squeeze(0).detach().cpu().numpy(), dtype=np.float32) # type: ignore
new_affine = to_affine_nd(affine, new_affine)
return output_data, affine, new_affine

return output_data if self.image_only else (output_data, affine, new_affine)


class Orientation(Transform):
Expand All @@ -217,6 +222,7 @@ def __init__(
axcodes: Optional[str] = None,
as_closest_canonical: bool = False,
labels: Optional[Sequence[Tuple[str, str]]] = tuple(zip("LPI", "RAS")),
image_only: bool = False,
) -> None:
"""
Args:
Expand All @@ -229,6 +235,7 @@ def __init__(
labels: optional, None or sequence of (2,) sequences
(2,) sequences are labels for (beginning, end) of output axis.
Defaults to ``(('L', 'R'), ('P', 'A'), ('I', 'S'))``.
image_only: if True return only the image volume, otherwise return (image, original affine, new affine).

Raises:
ValueError: When ``axcodes=None`` and ``as_closest_canonical=True``. Incompatible values.
Expand All @@ -243,10 +250,11 @@ def __init__(
self.axcodes = axcodes
self.as_closest_canonical = as_closest_canonical
self.labels = labels
self.image_only = image_only

def __call__(
self, data_array: np.ndarray, affine: Optional[np.ndarray] = None
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray, np.ndarray]]:
"""
original orientation of `data_array` is defined by `affine`.

Expand Down Expand Up @@ -289,7 +297,8 @@ def __call__(
data_array = np.ascontiguousarray(nib.orientations.apply_orientation(data_array, ornt))
new_affine = affine_ @ nib.orientations.inv_ornt_aff(spatial_ornt, shape)
new_affine = to_affine_nd(affine, new_affine)
return data_array, affine, new_affine

return data_array if self.image_only else (data_array, affine, new_affine)


class Flip(Transform):
Expand Down Expand Up @@ -1270,6 +1279,7 @@ def __init__(
padding_mode: Union[GridSamplePadMode, str] = GridSamplePadMode.REFLECTION,
as_tensor_output: bool = False,
device: Optional[torch.device] = None,
image_only: bool = False,
) -> None:
"""
The affine transformations are applied in rotate, shear, translate, scale order.
Expand All @@ -1296,6 +1306,7 @@ def __init__(
as_tensor_output: the computation is implemented using pytorch tensors, this option specifies
whether to convert it back to numpy arrays.
device: device on which the tensor will be allocated.
image_only: if True return only the image volume, otherwise return (image, affine).
"""
self.affine_grid = AffineGrid(
rotate_params=rotate_params,
Expand All @@ -1305,6 +1316,7 @@ def __init__(
as_tensor_output=True,
device=device,
)
self.image_only = image_only
self.resampler = Resample(as_tensor_output=as_tensor_output, device=device)
self.spatial_size = spatial_size
self.mode: GridSampleMode = GridSampleMode(mode)
Expand All @@ -1316,7 +1328,7 @@ def __call__(
spatial_size: Optional[Union[Sequence[int], int]] = None,
mode: Optional[Union[GridSampleMode, str]] = None,
padding_mode: Optional[Union[GridSamplePadMode, str]] = None,
) -> Tuple[Union[np.ndarray, torch.Tensor], Union[np.ndarray, torch.Tensor]]:
):
"""
Args:
img: shape must be (num_channels, H, W[, D]),
Expand All @@ -1334,10 +1346,9 @@ def __call__(
"""
sp_size = fall_back_tuple(spatial_size or self.spatial_size, img.shape[1:])
grid, affine = self.affine_grid(spatial_size=sp_size)
return (
self.resampler(img=img, grid=grid, mode=mode or self.mode, padding_mode=padding_mode or self.padding_mode),
affine,
)
ret = self.resampler(img, grid=grid, mode=mode or self.mode, padding_mode=padding_mode or self.padding_mode)

return ret if self.image_only else (ret, affine)


class RandAffine(RandomizableTransform):
Expand Down
9 changes: 8 additions & 1 deletion tests/test_affine.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@
{"img": np.arange(9).reshape((1, 3, 3)), "spatial_size": (-1, 0)},
np.arange(9).reshape(1, 3, 3),
],
[
dict(padding_mode="zeros", as_tensor_output=False, device=None, image_only=True),
{"img": np.arange(9).reshape((1, 3, 3)), "spatial_size": (-1, 0)},
np.arange(9).reshape(1, 3, 3),
],
[
dict(padding_mode="zeros", as_tensor_output=False, device=None),
{"img": np.arange(4).reshape((1, 2, 2))},
Expand Down Expand Up @@ -78,7 +83,9 @@ class TestAffine(unittest.TestCase):
@parameterized.expand(TEST_CASES)
def test_affine(self, input_param, input_data, expected_val):
g = Affine(**input_param)
result, _ = g(**input_data)
result = g(**input_data)
if isinstance(result, tuple):
result = result[0]
self.assertEqual(isinstance(result, torch.Tensor), isinstance(expected_val, torch.Tensor))
np.testing.assert_allclose(result, expected_val, rtol=1e-4, atol=1e-4)

Expand Down
10 changes: 10 additions & 0 deletions tests/test_orientation.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,13 @@
np.arange(12).reshape((2, 1, 2, 3)),
"RAS",
],
[
{"axcodes": "RAS", "image_only": True},
np.arange(12).reshape((2, 1, 2, 3)),
{"affine": np.eye(4)},
np.arange(12).reshape((2, 1, 2, 3)),
"RAS",
],
[
{"axcodes": "ALS"},
np.arange(12).reshape((2, 1, 2, 3)),
Expand Down Expand Up @@ -114,6 +121,9 @@ class TestOrientationCase(unittest.TestCase):
def test_ornt(self, init_param, img, data_param, expected_data, expected_code):
ornt = Orientation(**init_param)
res = ornt(img, **data_param)
if not isinstance(res, tuple):
np.testing.assert_allclose(res, expected_data)
return
np.testing.assert_allclose(res[0], expected_data)
original_affine = data_param["affine"]
np.testing.assert_allclose(original_affine, res[1])
Expand Down
9 changes: 9 additions & 0 deletions tests/test_spacing.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,12 @@
{"affine": np.eye(4)},
np.array([[[[1.0, 1.0]], [[1.0, 1.0]]]]),
],
[
{"pixdim": 1.0, "padding_mode": "zeros", "dtype": float, "image_only": True},
np.ones((1, 2, 1, 2)), # data
{"affine": np.eye(4)},
np.array([[[[1.0, 1.0]], [[1.0, 1.0]]]]),
],
[
{"pixdim": (1.0, 1.0, 1.0), "padding_mode": "zeros", "dtype": float},
np.ones((1, 2, 1, 2)), # data
Expand Down Expand Up @@ -145,6 +151,9 @@ class TestSpacingCase(unittest.TestCase):
@parameterized.expand(TEST_CASES)
def test_spacing(self, init_param, img, data_param, expected_output):
res = Spacing(**init_param)(img, **data_param)
if not isinstance(res, tuple):
np.testing.assert_allclose(res, expected_output, atol=1e-6)
return
np.testing.assert_allclose(res[0], expected_output, atol=1e-6)
sr = len(res[0].shape) - 1
if isinstance(init_param["pixdim"], float):
Expand Down