Skip to content

Commit

Permalink
Adding a placeholder for resampling of point clouds
Browse files Browse the repository at this point in the history
Signed-off-by: Ben Murray <ben.murray@gmail.com>
  • Loading branch information
atbenmurray committed Jan 18, 2024
1 parent 0a1df5f commit 9de7dbf
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 10 deletions.
2 changes: 1 addition & 1 deletion monai/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@
from .lazy.array import ApplyPending
from .lazy.dictionary import ApplyPendingd, ApplyPendingD, ApplyPendingDict
from .lazy.functional import apply_pending
from .lazy.utils import combine_transforms, resample
from .lazy.utils import combine_transforms, resample_image
from .meta_utility.dictionary import (
FromMetaTensord,
FromMetaTensorD,
Expand Down
10 changes: 7 additions & 3 deletions monai/transforms/lazy/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@
combine_transforms,
is_compatible_apply_kwargs,
kwargs_from_pending,
resample,
resample_image,
resample_points,
)
from monai.transforms.traits import LazyTrait
from monai.transforms.transform import MapTransform
Expand Down Expand Up @@ -336,7 +337,7 @@ def apply_pending(data: torch.Tensor | MetaTensor, pending: list | None = None,
# carry out an intermediate resample here due to incompatibility between arguments
_cur_kwargs = cur_kwargs.copy()
_cur_kwargs.update(override_kwargs)
data = resample(data.to(device), cumulative_xform, _cur_kwargs)
data = resample_image(data.to(device), cumulative_xform, _cur_kwargs)

next_matrix = affine_from_pending(p)
if next_matrix.shape[0] == 3:
Expand All @@ -345,7 +346,10 @@ def apply_pending(data: torch.Tensor | MetaTensor, pending: list | None = None,
cumulative_xform = combine_transforms(cumulative_xform, next_matrix)
cur_kwargs.update(new_kwargs)
cur_kwargs.update(override_kwargs)
data = resample(data.to(device), cumulative_xform, cur_kwargs)
if data.kind() == 'pixel':
data = resample_image(data.to(device), cumulative_xform, cur_kwargs)
elif data.kind() == 'point':
data = resample_points(data.to(device), cumulative_xform, cur_kwargs)
if isinstance(data, MetaTensor):
for p in pending:
data.push_applied_operation(p)
Expand Down
9 changes: 7 additions & 2 deletions monai/transforms/lazy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from monai.transforms.utils_pytorch_numpy_unification import allclose
from monai.utils import LazyAttr, convert_to_numpy, convert_to_tensor, look_up_option

__all__ = ["resample", "combine_transforms"]
__all__ = ["resample_image", "combine_transforms"]


def affine_from_pending(pending_item):
Expand Down Expand Up @@ -91,7 +91,7 @@ def requires_interp(matrix, atol=AFFINE_TOL):
__override_lazy_keywords = {*list(LazyAttr), "atol"}


def resample(data: torch.Tensor, matrix: NdarrayOrTensor, kwargs: dict | None = None):
def resample_image(data: torch.Tensor, matrix: NdarrayOrTensor, kwargs: dict | None = None):
"""
Resample `data` using the affine transformation defined by ``matrix``.
Expand Down Expand Up @@ -173,3 +173,8 @@ def resample(data: torch.Tensor, matrix: NdarrayOrTensor, kwargs: dict | None =
resampler.lazy = False # resampler is a lazytransform
with resampler.trace_transform(False): # don't track this transform in `img`
return resampler(img=img, **call_kwargs)


def resample_points(data: torch.Tensor, matrix: NdarrayOrTensor, kwargs: dict | None = None):
# Handle all point resampling here
raise NotImplementedError()
8 changes: 4 additions & 4 deletions tests/test_resample.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import torch
from parameterized import parameterized

from monai.transforms.lazy.functional import resample
from monai.transforms.lazy.functional import resample_image
from monai.utils import convert_to_tensor
from tests.utils import assert_allclose, get_arange_img

Expand All @@ -37,12 +37,12 @@ def rotate_90_2d():
class TestResampleFunction(unittest.TestCase):
@parameterized.expand(RESAMPLE_FUNCTION_CASES)
def test_resample_function_impl(self, img, matrix, expected):
out = resample(convert_to_tensor(img), matrix, {"lazy_shape": img.shape[1:], "lazy_padding_mode": "border"})
out = resample_image(convert_to_tensor(img), matrix, {"lazy_shape": img.shape[1:], "lazy_padding_mode": "border"})
assert_allclose(out[0], expected, type_test=False)

img = convert_to_tensor(img, dtype=torch.uint8)
out = resample(img, matrix, {"lazy_resample_mode": "auto", "lazy_dtype": torch.float})
out_1 = resample(img, matrix, {"lazy_resample_mode": "other value", "lazy_dtype": torch.float})
out = resample_image(img, matrix, {"lazy_resample_mode": "auto", "lazy_dtype": torch.float})
out_1 = resample_image(img, matrix, {"lazy_resample_mode": "other value", "lazy_dtype": torch.float})
self.assertIs(out.dtype, out_1.dtype) # testing dtype in different lazy_resample_mode


Expand Down

0 comments on commit 9de7dbf

Please sign in to comment.