Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions docs/source/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -91,9 +91,16 @@ Generic Interfaces
Functionals
-----------

Crop and Pad (functional)
^^^^^^^^^^^^^^^^^^^^^^^^^
.. automodule:: monai.transforms.croppad.functional
:members:

Spatial (functional)
^^^^^^^^^^^^^^^^^^^^
.. automodule:: monai.transforms.spatial.functional
:members:

.. currentmodule:: monai.transforms

Vanilla Transforms
Expand Down
11 changes: 8 additions & 3 deletions monai/data/meta_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,7 +461,7 @@ def affine(self) -> torch.Tensor:
@affine.setter
def affine(self, d: NdarrayTensor) -> None:
    """
    Set the affine.

    The input is converted once with ``torch.as_tensor`` into a ``torch.float64`` tensor
    on the CPU and stored in ``self.meta`` under ``MetaKeys.AFFINE``.

    Args:
        d: the new affine matrix.
    """
    # `torch.double` is an alias of `torch.float64`, a single conversion is sufficient
    self.meta[MetaKeys.AFFINE] = torch.as_tensor(d, device=torch.device("cpu"), dtype=torch.float64)

@property
def pixdim(self):
Expand All @@ -471,7 +471,10 @@ def pixdim(self):
return affine_to_spacing(self.affine)

def peek_pending_shape(self):
"""Get the currently expected spatial shape as if all the pending operations are executed."""
"""
Get the currently expected spatial shape as if all the pending operations are executed.
For tensors that have more than 3 spatial dimensions, only the shapes of the top 3 dimensions will be returned.
"""
res = None
if self.pending_operations:
res = self.pending_operations[-1].get(LazyAttr.SHAPE, None)
Expand All @@ -480,11 +483,13 @@ def peek_pending_shape(self):

def peek_pending_affine(self):
    """
    Get the currently expected affine as if all the pending operations are executed.

    Starting from ``self.affine``, the ``LazyAttr.AFFINE`` entry of every pending operation
    is converted to ``torch.float64``, passed through ``monai.data.utils.to_affine_nd``
    with ``r = len(self.affine) - 1`` and merged into the running result with
    ``monai.transforms.lazy.utils.combine_transforms``.
    Pending operations whose converted affine is ``None`` are skipped.
    """
    res = self.affine
    r = len(res) - 1
    for p in self.pending_operations:
        # convert exactly once, directly to float64
        next_matrix = convert_to_tensor(p.get(LazyAttr.AFFINE), dtype=torch.float64)
        if next_matrix is None:
            continue
        # align the running result with the type of the next matrix before combining
        res = convert_to_dst_type(res, next_matrix)[0]
        next_matrix = monai.data.utils.to_affine_nd(r, next_matrix)
        res = monai.transforms.lazy.utils.combine_transforms(res, next_matrix)
    return res

Expand Down
2 changes: 2 additions & 0 deletions monai/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@
SpatialPadD,
SpatialPadDict,
)
from .croppad.functional import pad_func, pad_nd
from .intensity.array import (
AdjustContrast,
ComputeHoVerMaps,
Expand Down Expand Up @@ -453,6 +454,7 @@
ZoomD,
ZoomDict,
)
from .spatial.functional import spatial_resample
from .traits import LazyTrait, MultiSampleTrait, RandomizableTrait, ThreadUnsafe
from .transform import LazyTransform, MapTransform, Randomizable, RandomizableTransform, Transform, apply_transform
from .utility.array import (
Expand Down
11 changes: 7 additions & 4 deletions monai/transforms/inverse.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,9 @@ def push_transform(self, data, *args, **kwargs):
return data.copy_meta_from(meta_obj)
if do_transform:
xform = data.pending_operations.pop()
extra = xform.copy()
xform.update(transform_info)
meta_obj = self.push_transform(data, transform_info=xform, lazy_evaluation=lazy_eval)
meta_obj = self.push_transform(data, transform_info=xform, lazy_evaluation=lazy_eval, extra_info=extra)
return data.copy_meta_from(meta_obj)
return data
kwargs["lazy_evaluation"] = lazy_eval
Expand Down Expand Up @@ -177,9 +178,9 @@ def track_transform_meta(
if not lazy_evaluation and affine is not None and isinstance(data_t, MetaTensor):
# not lazy evaluation, directly update the metatensor affine (don't push to the stack)
orig_affine = data_t.peek_pending_affine()
orig_affine = convert_to_dst_type(orig_affine, affine)[0]
affine = orig_affine @ to_affine_nd(len(orig_affine) - 1, affine, dtype=affine.dtype)
out_obj.meta[MetaKeys.AFFINE] = convert_to_tensor(affine, device=torch.device("cpu"))
orig_affine = convert_to_dst_type(orig_affine, affine, dtype=torch.float64)[0]
affine = orig_affine @ to_affine_nd(len(orig_affine) - 1, affine, dtype=torch.float64)
out_obj.meta[MetaKeys.AFFINE] = convert_to_tensor(affine, device=torch.device("cpu"), dtype=torch.float64)

if not (get_track_meta() and transform_info and transform_info.get(TraceKeys.TRACING)):
if isinstance(data, Mapping):
Expand All @@ -199,6 +200,8 @@ def track_transform_meta(
info[TraceKeys.ORIG_SIZE] = data_t.shape[1:]
# include extra_info
if extra_info is not None:
extra_info.pop(LazyAttr.SHAPE, None)
extra_info.pop(LazyAttr.AFFINE, None)
info[TraceKeys.EXTRA_INFO] = extra_info

# push the transform info to the applied_operation or pending_operation stack
Expand Down
2 changes: 1 addition & 1 deletion monai/transforms/lazy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,6 @@ def resample(data: torch.Tensor, matrix: NdarrayOrTensor, spatial_size, kwargs:
"padding_mode": kwargs.pop(LazyAttr.PADDING_MODE, None),
}
resampler = monai.transforms.SpatialResample(**init_kwargs)
# resampler.lazy_evaluation = False # resampler is a lazytransform
resampler.lazy_evaluation = False # resampler is a lazytransform
with resampler.trace_transform(False): # don't track this transform in `img`
return resampler(img=img, **call_kwargs)
2 changes: 1 addition & 1 deletion monai/transforms/spatial/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ def __call__(
"""
# get dtype as torch (e.g., torch.float64)
dtype_pt = get_equivalent_dtype(dtype or self.dtype or img.dtype, torch.Tensor)
align_corners = self.align_corners if align_corners is None else align_corners
align_corners = align_corners if align_corners is not None else self.align_corners
mode = mode if mode is not None else self.mode
padding_mode = padding_mode if padding_mode is not None else self.padding_mode
return spatial_resample(
Expand Down
9 changes: 7 additions & 2 deletions monai/transforms/spatial/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
Zoom,
)
from monai.transforms.traits import MultiSampleTrait
from monai.transforms.transform import MapTransform, RandomizableTransform
from monai.transforms.transform import LazyTransform, MapTransform, RandomizableTransform
from monai.transforms.utils import create_grid
from monai.utils import (
GridSampleMode,
Expand Down Expand Up @@ -142,7 +142,7 @@
]


class SpatialResampled(MapTransform, InvertibleTransform):
class SpatialResampled(MapTransform, InvertibleTransform, LazyTransform):
"""
Dictionary-based wrapper of :py:class:`monai.transforms.SpatialResample`.

Expand Down Expand Up @@ -204,6 +204,11 @@ def __init__(
self.dtype = ensure_tuple_rep(dtype, len(self.keys))
self.dst_keys = ensure_tuple_rep(dst_keys, len(self.keys))

@LazyTransform.lazy_evaluation.setter  # type: ignore
def lazy_evaluation(self, val: bool) -> None:
    """
    Set the lazy evaluation flag on this dictionary wrapper.

    The value is stored in ``self._lazy_evaluation`` and also assigned to
    ``self.sp_transform.lazy_evaluation`` so that both objects carry the same flag.

    Args:
        val: the new lazy evaluation flag.
    """
    self._lazy_evaluation = val
    # forward the flag to the wrapped transform
    self.sp_transform.lazy_evaluation = val

def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]:
d: dict = dict(data)
for key, mode, padding_mode, align_corners, dtype, dst_key in self.key_iterator(
Expand Down
30 changes: 29 additions & 1 deletion monai/transforms/spatial/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,37 @@
def spatial_resample(
img, dst_affine, spatial_size, mode, padding_mode, align_corners, dtype_pt, transform_info
) -> torch.Tensor:
"""
Functional implementation of resampling the input image to the specified ``dst_affine`` matrix and ``spatial_size``.
This function operates eagerly or lazily according to
``transform_info[TraceKeys.LAZY_EVALUATION]`` (default ``False``).

Args:
img: data to be resampled, assuming `img` is channel-first.
dst_affine: target affine matrix, if None, use the input affine matrix, effectively no resampling.
spatial_size: output spatial size, if the component is ``-1``, use the corresponding input spatial size.
mode: {``"bilinear"``, ``"nearest"``} or spline interpolation order 0-5 (integers).
Interpolation mode to calculate output values. Defaults to ``"bilinear"``.
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
When it's an integer, the numpy (cpu tensor)/cupy (cuda tensor) backends will be used
and the value represents the order of the spline interpolation.
See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
padding_mode: {``"zeros"``, ``"border"``, ``"reflection"``}
Padding mode for outside grid values. Defaults to ``"border"``.
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
When `mode` is an integer, using numpy/cupy backends, this argument accepts
{'reflect', 'grid-mirror', 'constant', 'grid-constant', 'nearest', 'mirror', 'grid-wrap', 'wrap'}.
See also: https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.map_coordinates.html
align_corners: Geometrically, we consider the pixels of the input as squares rather than points.
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.grid_sample.html
This functional form defines no default and has no ``self``: the calling transform must pass the already resolved value.
dtype_pt: data `dtype` for resampling computation.
transform_info: a dictionary with the relevant information pertaining to an applied transform.
"""
original_spatial_shape = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:]
src_affine: torch.Tensor = img.peek_pending_affine() if isinstance(img, MetaTensor) else torch.eye(4)
img = convert_to_tensor(data=img, track_meta=get_track_meta())
# ensure spatial rank is <= 3
spatial_rank = min(len(img.shape) - 1, src_affine.shape[0] - 1, 3)
if (not isinstance(spatial_size, int) or spatial_size != -1) and spatial_size is not None:
spatial_rank = min(len(ensure_tuple(spatial_size)), 3) # infer spatial rank based on spatial_size
Expand Down Expand Up @@ -101,7 +129,7 @@ def spatial_resample(
# no significant change or lazy change, return original image
out = convert_to_tensor(img, track_meta=get_track_meta())
return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info # type: ignore
im_size = torch.tensor(img.shape).tolist()
im_size = list(img.shape)
chns, in_sp_size, additional_dims = im_size[0], im_size[1 : spatial_rank + 1], im_size[spatial_rank + 1 :]

if additional_dims:
Expand Down
48 changes: 43 additions & 5 deletions tests/test_spatial_resample.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from monai.data.meta_tensor import MetaTensor
from monai.data.utils import to_affine_nd
from monai.transforms import SpatialResample
from monai.transforms.lazy.functional import apply_transforms
from monai.utils import optional_import
from tests.utils import TEST_DEVICES, TEST_NDARRAYS_ALL, assert_allclose

Expand Down Expand Up @@ -131,6 +132,28 @@
TEST_TORCH_INPUT.append(t + [track_meta])


def get_apply_param(init_param=None, call_param=None):
    """
    Collect the keyword arguments that are forwarded to ``apply_transforms``.

    Only the keys ``"pending"``, ``"mode"``, ``"padding_mode"``, ``"dtype"`` and
    ``"align_corners"`` are picked up. When a key is present in both dictionaries,
    the value from ``call_param`` takes precedence over the one from ``init_param``.

    Args:
        init_param: keyword arguments used to construct the transform; may be ``None`` or empty.
        call_param: keyword arguments used to call the transform; may be ``None`` or empty.

    Returns:
        a new dictionary holding the selected entries, ordered as the keys are listed above.
    """
    apply_keys = ("pending", "mode", "padding_mode", "dtype", "align_corners")
    # sources are visited in order for every key, so a later source overrides an earlier one
    sources = [param for param in (init_param, call_param) if param]
    return {key: source[key] for key in apply_keys for source in sources if key in source}


def test_resampler_lazy(resampler, non_lazy_out, init_param=None, call_param=None):
    """
    Check that running ``resampler`` lazily agrees with its eager output ``non_lazy_out``.

    The resampler is switched to lazy evaluation (and left in that state) and called with
    ``call_param``. The pending affine and pending shape of the result are compared with
    ``non_lazy_out.affine`` and ``non_lazy_out.shape[1:4]``. The pending operations are then
    executed with ``apply_transforms`` and the outcome is compared with ``non_lazy_out``.

    Args:
        resampler: the transform instance under test.
        non_lazy_out: output of the same transform obtained without lazy evaluation.
        init_param: keyword arguments that were used to construct ``resampler``.
        call_param: keyword arguments for calling ``resampler``; it is unpacked with ``**``,
            so a mapping must be provided despite the ``None`` default.
    """
    resampler.lazy_evaluation = True
    pending_result = resampler(**call_param)

    # the pending metadata must already describe the eager result
    assert_allclose(pending_result.peek_pending_affine(), non_lazy_out.affine)
    assert_allclose(pending_result.peek_pending_shape(), non_lazy_out.shape[1:4])

    # executing the pending operations must reproduce the eager output
    evaluated = apply_transforms(pending_result, **get_apply_param(init_param, call_param))[0]
    assert_allclose(evaluated, non_lazy_out, rtol=1e-5)


class TestSpatialResample(unittest.TestCase):
@parameterized.expand(TESTS)
def test_flips(self, img, device, data_param, expected_output):
Expand All @@ -140,9 +163,14 @@ def test_flips(self, img, device, data_param, expected_output):
img.affine = torch.eye(4)
if hasattr(img, "to"):
img = img.to(device)
out = SpatialResample()(img=img, **data_param)
resampler = SpatialResample()
call_param = data_param.copy()
call_param["img"] = img
out = resampler(**call_param)
assert_allclose(out, expected_output, rtol=1e-2, atol=1e-2)
assert_allclose(to_affine_nd(len(out.shape) - 1, out.affine), data_param["dst_affine"])
assert_allclose(to_affine_nd(len(out.shape) - 1, out.affine), call_param["dst_affine"])

test_resampler_lazy(resampler, out, init_param=None, call_param=call_param)

@parameterized.expand(TEST_4_5_D)
def test_4d_5d(self, new_shape, tile, device, dtype, expected_data):
Expand All @@ -152,10 +180,15 @@ def test_4d_5d(self, new_shape, tile, device, dtype, expected_data):

dst = torch.tensor([[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, -1.0, 1.5], [0.0, 0.0, 0.0, 1.0]])
dst = dst.to(dtype)
out = SpatialResample(dtype=dtype, align_corners=True)(img=img, dst_affine=dst, align_corners=False)
init_param = {"dtype": dtype, "align_corners": True}
call_param = {"img": img, "dst_affine": dst, "align_corners": False}
resampler = SpatialResample(**init_param)
out = resampler(**call_param)
assert_allclose(out, expected_data[None], rtol=1e-2, atol=1e-2)
assert_allclose(out.affine, dst.to(torch.float32), rtol=1e-2, atol=1e-2)

test_resampler_lazy(resampler, out, init_param, call_param)

@parameterized.expand(TEST_DEVICES)
def test_ill_affine(self, device):
img = MetaTensor(torch.arange(12).reshape(1, 2, 2, 3)).to(device)
Expand All @@ -182,9 +215,14 @@ def test_input_torch(self, new_shape, tile, device, dtype, expected_data, track_
img = torch.as_tensor(np.tile(img, tile)).to(device)
dst = torch.tensor([[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, -1.0, 1.5], [0.0, 0.0, 0.0, 1.0]])
dst = dst.to(dtype).to(device)

out = SpatialResample(dtype=dtype)(img=img, dst_affine=dst)
init_param = {"dtype": dtype}
call_param = {"img": img, "dst_affine": dst}
resampler = SpatialResample(**init_param)
out = resampler(**call_param)
assert_allclose(out, expected_data[None], rtol=1e-2, atol=1e-2)

test_resampler_lazy(resampler, out, init_param, call_param)

if track_meta:
self.assertIsInstance(out, MetaTensor)
assert_allclose(out.affine, dst.to(torch.float32), rtol=1e-2, atol=1e-2)
Expand Down
33 changes: 30 additions & 3 deletions tests/test_spatial_resampled.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from monai.data.meta_tensor import MetaTensor
from monai.data.utils import to_affine_nd
from monai.transforms.lazy.functional import apply_transforms
from monai.transforms.spatial.dictionary import SpatialResampled
from tests.utils import TEST_DEVICES, assert_allclose

Expand Down Expand Up @@ -85,19 +86,45 @@
)


def get_apply_param(init_param=None, call_param=None):
    """
    Collect the keyword arguments that are forwarded to ``apply_transforms``.

    Only the keys ``"pending"``, ``"mode"``, ``"padding_mode"``, ``"dtype"`` and
    ``"align_corners"`` are picked up. When a key is present in both dictionaries,
    the value from ``call_param`` takes precedence over the one from ``init_param``.

    Args:
        init_param: keyword arguments used to construct the transform; may be ``None`` or empty.
        call_param: keyword arguments used to call the transform; may be ``None`` or empty.

    Returns:
        a new dictionary holding the selected entries, ordered as the keys are listed above.
    """
    apply_keys = ("pending", "mode", "padding_mode", "dtype", "align_corners")
    # sources are visited in order for every key, so a later source overrides an earlier one
    sources = [param for param in (init_param, call_param) if param]
    return {key: source[key] for key in apply_keys for source in sources if key in source}


class TestSpatialResample(unittest.TestCase):
@parameterized.expand(TESTS)
def test_flips_inverse(self, img, device, dst_affine, kwargs, expected_output):
img = MetaTensor(img, affine=torch.eye(4)).to(device)
data = {"img": img, "dst_affine": dst_affine}

xform = SpatialResampled(keys="img", **kwargs)
output_data = xform(data)
init_param = kwargs.copy()
init_param["keys"] = "img"
call_param = {"data": data}
xform = SpatialResampled(**init_param)
output_data = xform(**call_param)
out = output_data["img"]

assert_allclose(out, expected_output, rtol=1e-2, atol=1e-2)
assert_allclose(to_affine_nd(len(out.shape) - 1, out.affine), dst_affine, rtol=1e-2, atol=1e-2)

# check lazy
lazy_xform = SpatialResampled(**init_param)
lazy_xform.lazy_evaluation = True
pending_output_data = lazy_xform(**call_param)
pending_out = pending_output_data["img"]
assert_allclose(pending_out.peek_pending_affine(), out.affine)
assert_allclose(pending_out.peek_pending_shape(), out.shape[1:4])
apply_param = get_apply_param(init_param=init_param, call_param=call_param)
lazy_out = apply_transforms(pending_out, **apply_param)[0]
assert_allclose(lazy_out, out, rtol=1e-5)

# check inverse
inverted = xform.inverse(output_data)["img"]
self.assertEqual(inverted.applied_operations, []) # no further invert after inverting
expected_affine = to_affine_nd(len(out.affine) - 1, torch.eye(4))
Expand Down