Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Geometric transform -- Resize #7509

Draft
wants to merge 8 commits into
base: geometric
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions monai/transforms/inverse.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,13 @@ def track_transform_meta(
extra_info.pop(LazyAttr.AFFINE, None)
info[TraceKeys.EXTRA_INFO] = extra_info

# update refer meta
if isinstance(data_t, MetaTensor):
if data_t.meta.get("refer_meta", None) is not None:
data_t.meta["refer_meta"]["spatial_shape"] = (
sp_size if sp_size is not None else info.get(TraceKeys.ORIG_SIZE, [])
)

# push the transform info to the applied_operation or pending_operation stack
if lazy:
if sp_size is None:
Expand Down
48 changes: 27 additions & 21 deletions monai/transforms/spatial/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@
affine_func,
flip,
orientation,
resize,
resize_image,
resize_point,
rotate,
rotate90,
spatial_resample,
Expand All @@ -51,6 +52,7 @@
create_scale,
create_shear,
create_translate,
get_input_shape,
map_spatial_axes,
resolves_modes,
scale_affine,
Expand Down Expand Up @@ -764,8 +766,9 @@ def __init__(
self.anti_aliasing = anti_aliasing
self.anti_aliasing_sigma = anti_aliasing_sigma
self.dtype = dtype
self.operators = [resize_point, resize_image] # type: ignore

def __call__(
def __call__( # type: ignore[return]
self,
img: torch.Tensor,
mode: str | None = None,
Expand Down Expand Up @@ -806,21 +809,24 @@ def __call__(
anti_aliasing = self.anti_aliasing if anti_aliasing is None else anti_aliasing
anti_aliasing_sigma = self.anti_aliasing_sigma if anti_aliasing_sigma is None else anti_aliasing_sigma

input_ndim = img.ndim - 1 # spatial ndim
input_shape = get_input_shape(img) # spatial shape
input_ndim = len(input_shape) # spatial ndim
if self.size_mode == "all":
output_ndim = len(ensure_tuple(self.spatial_size))
if output_ndim > input_ndim:
# only works for pixel data
kind = img.meta.get("kind", "pixel") if isinstance(img, MetaTensor) else "pixel"
if output_ndim > input_ndim and kind == "pixel":
input_shape = ensure_tuple_size(img.shape, output_ndim + 1, 1)
img = img.reshape(input_shape)
elif output_ndim < input_ndim:
raise ValueError(
"len(spatial_size) must be greater or equal to img spatial dimensions, "
f"got spatial_size={output_ndim} img={input_ndim}."
)
_sp = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:]
_sp = get_input_shape(img)
sp_size = fall_back_tuple(self.spatial_size, _sp)
else: # for the "longest" mode
img_size = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:]
img_size = input_shape
if not isinstance(self.spatial_size, int):
raise ValueError("spatial_size must be an int number if size_mode is 'longest'.")
scale = self.spatial_size / max(img_size)
Expand All @@ -830,28 +836,28 @@ def __call__(
_align_corners = self.align_corners if align_corners is None else align_corners
_dtype = get_equivalent_dtype(dtype or self.dtype or img.dtype, torch.Tensor)
lazy_ = self.lazy if lazy is None else lazy
return resize( # type: ignore
img,
tuple(int(_s) for _s in sp_size),
_mode,
_align_corners,
_dtype,
input_ndim,
anti_aliasing,
anti_aliasing_sigma,
lazy_,
self.get_transform_info(),
)
kwargs = {
"mode": _mode,
"align_corners": _align_corners,
"anti_aliasing": anti_aliasing,
"anti_aliasing_sigma": anti_aliasing_sigma,
}
for operator in self.operators:
ret: torch.Tensor = operator( # type: ignore
img, tuple(int(_s) for _s in sp_size), _dtype, input_ndim, lazy_, self.get_transform_info(), **kwargs
)
if ret is not None:
return ret

def inverse(self, data: torch.Tensor) -> torch.Tensor:
transform = self.pop_transform(data)
return self.inverse_transform(data, transform)

def inverse_transform(self, data: torch.Tensor, transform) -> torch.Tensor:
orig_size = transform[TraceKeys.ORIG_SIZE]
mode = transform[TraceKeys.EXTRA_INFO]["mode"]
align_corners = transform[TraceKeys.EXTRA_INFO]["align_corners"]
dtype = transform[TraceKeys.EXTRA_INFO]["dtype"]
mode = transform[TraceKeys.EXTRA_INFO].get("mode", None)
align_corners = transform[TraceKeys.EXTRA_INFO].get("align_corners", None)
dtype = transform[TraceKeys.EXTRA_INFO].get("dtype", None)
xform = Resize(
spatial_size=orig_size,
mode=mode,
Expand Down
135 changes: 114 additions & 21 deletions monai/transforms/spatial/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,22 @@
import monai
from monai.config import USE_COMPILED
from monai.config.type_definitions import NdarrayOrTensor
from monai.data.box_utils import COMPUTE_DTYPE, get_spatial_dims
from monai.data.meta_obj import get_track_meta
from monai.data.meta_tensor import MetaTensor
from monai.data.utils import AFFINE_TOL, compute_shape_offset, to_affine_nd
from monai.networks.layers import AffineTransform
from monai.transforms.croppad.array import ResizeWithPadOrCrop
from monai.transforms.intensity.array import GaussianSmooth
from monai.transforms.inverse import TraceableTransform
from monai.transforms.utils import create_rotate, create_translate, resolves_modes, scale_affine
from monai.transforms.utils import (
convert_data_type,
create_rotate,
create_scale,
create_translate,
resolves_modes,
scale_affine,
)
from monai.transforms.utils_pytorch_numpy_unification import allclose
from monai.utils import (
LazyAttr,
Expand All @@ -50,7 +58,17 @@
cupy_ndi, _ = optional_import("cupyx.scipy.ndimage")
np_ndi, _ = optional_import("scipy.ndimage")

__all__ = ["spatial_resample", "orientation", "flip", "resize", "rotate", "zoom", "rotate90", "affine_func"]
__all__ = [
"spatial_resample",
"orientation",
"flip",
"resize_image",
"resize_point",
"rotate",
"zoom",
"rotate90",
"affine_func",
]


def _maybe_new_metatensor(img, dtype=None, device=None):
Expand Down Expand Up @@ -265,9 +283,7 @@ def flip(img, sp_axes, lazy, transform_info):
return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out


def resize(
img, out_size, mode, align_corners, dtype, input_ndim, anti_aliasing, anti_aliasing_sigma, lazy, transform_info
):
def resize_image(img, out_size, dtype, input_ndim, lazy, transform_info, **kwargs):
"""
Functional implementation of resize.
This function operates eagerly or lazily according to
Expand All @@ -292,23 +308,14 @@ def resize(
lazy: a flag that indicates whether the operation should be performed lazily or not
transform_info: a dictionary with the relevant information pertaining to an applied transform.
"""
img = convert_to_tensor(img, track_meta=get_track_meta())
# TODO
kind = img.meta.get("kind", "pixel") if isinstance(img, MetaTensor) else "pixel"
if kind != "pixel":
return None
anti_aliasing = kwargs.pop("anti_aliasing")
anti_aliasing_sigma = kwargs.pop("anti_aliasing_sigma")
orig_size = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:]
extra_info = {
"mode": mode,
"align_corners": align_corners if align_corners is not None else TraceKeys.NONE,
"dtype": str(dtype)[6:], # dtype as string; remove "torch": torch.float32 -> float32
"new_dim": len(orig_size) - input_ndim,
}
meta_info = TraceableTransform.track_transform_meta(
img,
sp_size=out_size,
affine=scale_affine(orig_size, out_size),
extra_info=extra_info,
orig_size=orig_size,
transform_info=transform_info,
lazy=lazy,
)
mode, align_corners, meta_info = resize_helper(img, orig_size, out_size, dtype, input_ndim, lazy, transform_info, **kwargs)
if lazy:
if anti_aliasing and lazy:
warnings.warn("anti-aliasing is not compatible with lazy evaluation.")
Expand Down Expand Up @@ -339,6 +346,92 @@ def resize(
return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out


def _apply_affine_to_points(points, affine, include_shift: bool = True) -> torch.Tensor:
"""
This internal function applies affine matrices to the point coordinate
Args:
points: point coordinates, Nx2 or Nx3 torch tensor or ndarray, representing [x, y] or [x, y, z]
affine: affine matrix to be applied to the point coordinates, sized (spatial_dims+1,spatial_dims+1)
include_shift: default True, whether the function apply translation (shift) in the affine transform
Returns:
transformed point coordinates, with same data type as ``points``, does not share memory with ``points``
"""
# convert numpy to tensor if needed
points_t, *_ = convert_data_type(points, torch.Tensor)
points_t = points_t.to(dtype=COMPUTE_DTYPE)
affine_t, *_ = convert_to_dst_type(src=affine, dst=points_t)
spatial_dims = get_spatial_dims(points=points_t)

# compute new points
if include_shift:
# append 1 to form Nx(spatial_dims+1) vector, then transpose
points_affine = torch.cat(
[points_t, torch.ones(points_t.shape[0], 1, device=points_t.device, dtype=points_t.dtype)], dim=1
).transpose(0, 1)
# apply affine
points_affine = torch.matmul(affine_t, points_affine)
# remove appended 1 and transpose back
points_affine = points_affine[:spatial_dims, :].transpose(0, 1)
else:
points_affine = points_t.transpose(0, 1)
points_affine = torch.matmul(affine_t[:spatial_dims, :spatial_dims], points_affine)
points_affine = points_affine.transpose(0, 1)

# convert tensor back to numpy if needed
points_affine, *_ = convert_to_dst_type(src=points_affine, dst=points)

return points_affine


def resize_point(points, out_size, dtype, input_ndim, lazy, transform_info, **kwargs):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How can we avoid having *_point(...) functions for each operation? At this level, shouldn't we be able to always just pass the correct transform? If we think about object space (raster space) and world space transforms then we shouldn't need to be performing these calculations based on spatial size here.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As far as I understand it, resize should have no effect on point data, as it is entirely a "raster" space operation that preserves the extents of an image or volume, no?

Copy link
Contributor

@atbenmurray atbenmurray Mar 8, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As discussed in the meeting, the alternative of taking the operation that defines the operation out of the function and putting it into its own implementation function works just as well as what I am suggesting.

def resize_point(data, lazy, ...):
    meta_info = get_resize(...)
    out = apply_to_geom(data, metainfo, lazy)
    return out

def resize_image(data, lazy, ...):
    meta_info = get_resize(...)
    if not lazy:
        out = torch.interpolate(data, meta_info.affine)
    else:
        out = out.copy_meta_from(meta_info)
    return out

# TODO
kind = points.meta.get("kind", "pixel") if isinstance(points, MetaTensor) else "pixel"
if kind != "point":
return None
if points.meta.get("refer_meta", None) is not None:
src_spatial_size = points.meta["refer_meta"].get("spatial_shape", None)
else:
raise ValueError("Resize cannot be applied to a point without a reference meta.")
*_, meta_info = resize_helper(points, src_spatial_size, out_size, dtype, input_ndim, lazy, transform_info, **kwargs)
out = _maybe_new_metatensor(points)
if lazy:
return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else meta_info
if tuple(convert_to_numpy(src_spatial_size)) == out_size:
out = _maybe_new_metatensor(points, dtype=dtype)
return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out
spatial_dims = get_spatial_dims(points=points[0])
scaling_factor = [out_size[axis] / float(src_spatial_size[axis]) for axis in range(spatial_dims)]
affine = create_scale(spatial_dims=spatial_dims, scaling_factor=scaling_factor)
ret: torch.Tensor = _apply_affine_to_points(points[0], affine, include_shift=True)

out, *_ = convert_to_dst_type(src=ret.unsqueeze(0), dst=points, dtype=dtype)
return out.copy_meta_from(meta_info) if isinstance(out, MetaTensor) else out


def resize_helper(data, src_spatial_size, out_size, dtype, input_ndim, lazy, transform_info, **kwargs):
data = convert_to_tensor(data, track_meta=get_track_meta())
extra_info={
"dtype": str(dtype)[6:], # dtype as string; remove "torch": torch.float32 -> float32
"new_dim": len(src_spatial_size) - input_ndim,
}
mode = kwargs.pop("mode", None)
align_corners = kwargs.pop("align_corners", None)
if mode is not None:
extra_info["mode"] = mode
if align_corners is not None:
extra_info["align_corners"] = align_corners

meta_info = TraceableTransform.track_transform_meta(
data,
sp_size=out_size,
affine=scale_affine(src_spatial_size, out_size),
extra_info=extra_info,
orig_size=src_spatial_size,
transform_info=transform_info,
lazy=lazy,
)
return mode, align_corners, meta_info

def rotate(img, angle, output_shape, mode, padding_mode, align_corners, dtype, lazy, transform_info):
"""
Functional implementation of rotate.
Expand Down
14 changes: 14 additions & 0 deletions monai/transforms/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import monai
from monai.config import DtypeLike, IndexSelection
from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor
from monai.data.meta_tensor import MetaTensor
from monai.networks.layers import GaussianFilter
from monai.networks.utils import meshgrid_ij
from monai.transforms.compose import Compose
Expand Down Expand Up @@ -2221,5 +2222,18 @@ def distance_transform_edt(
return convert_data_type(r_vals[0] if len(r_vals) == 1 else r_vals, output_type=type(img), device=device)[0]


def get_input_shape(data):
if isinstance(data, MetaTensor):
if data.meta.get("refer_meta", None) is not None:
refer_shape = data.meta["refer_meta"].get("spatial_shape", None)
if refer_shape is not None:
input_shape = refer_shape
else:
input_shape = data.peek_pending_shape()
else:
input_shape = data.shape[1:]
return input_shape


if __name__ == "__main__":
print_transform_backends()
29 changes: 28 additions & 1 deletion tests/test_resize.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

from monai.data import MetaTensor, set_track_meta
from monai.transforms import Resize
from monai.utils import convert_to_dst_type
from tests.lazy_transforms_utils import test_resampler_lazy
from tests.utils import (
TEST_NDARRAYS_ALL,
Expand Down Expand Up @@ -93,7 +94,7 @@ def test_correct_results(self, spatial_size, mode, anti_aliasing):
]

expected = np.stack(expected).astype(np.float32)
for p in TEST_NDARRAYS_ALL:
for p in TEST_NDARRAYS_ALL[:1]:
im = p(self.imt[0])
call_param = {"img": im}
out = resize(**call_param)
Expand Down Expand Up @@ -136,6 +137,32 @@ def test_longest_infinite_decimals(self):
ret = resize(np.random.randint(0, 2, size=[1, 2544, 3032]))
self.assertTupleEqual(ret.shape, (1, 846, 1008))

@parameterized.expand(
[
((32, -1), "all", [[[12, 6], [18, 9], [24, 8]]]),
((32, 32, 32), "all", [[[12, 6, 32], [18, 9, 0], [24, 8, 18]]]),
((128, 64), "all", [[[12, 6], [18, 9], [24, 64]]]), # already in a good shape
(32, "longest", [[[12, 6], [18, 9], [24, 8]]]),
]
)
def test_point(self, spatial_size, size_mode, data):
init_param = {"spatial_size": spatial_size, "dtype": np.int64, "size_mode": size_mode}
resize = Resize(**init_param)
if spatial_size == (32, -1):
spatial_size = (32, 64)
elif spatial_size == 32:
spatial_size = (32, 16)
refer_shape = (128, 64) if len(spatial_size) == 2 else (128, 64, 64)
data = MetaTensor(data, meta={"kind": "point", "refer_meta": {"spatial_shape": refer_shape}})
expected = [data[0][..., i] * (spatial_size[i] / refer_shape[i]) for i in range(len(refer_shape))]
expected, *_ = convert_to_dst_type(torch.stack(expected, dim=1).unsqueeze(0), data)
out = resize(data)
im_inv = resize.inverse(out)
self.assertTrue(not im_inv.applied_operations)
assert_allclose(im_inv.shape, data.shape)
assert_allclose(out, expected, type_test="tensor")
assert_allclose(im_inv.affine, data.affine, atol=1e-3, rtol=1e-3)


if __name__ == "__main__":
unittest.main()
Loading
Loading