From 8c25e0cb7377ead4fda46f7bd4a866aa6a0a4e81 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 12 Nov 2021 10:10:21 +0800 Subject: [PATCH 1/2] [DLMED] add affine to dict transform Signed-off-by: Nic Ma --- monai/transforms/spatial/dictionary.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/monai/transforms/spatial/dictionary.py b/monai/transforms/spatial/dictionary.py index 8bfdd6fd52..0fc0b46cc1 100644 --- a/monai/transforms/spatial/dictionary.py +++ b/monai/transforms/spatial/dictionary.py @@ -605,6 +605,7 @@ def __init__( shear_params: Optional[Union[Sequence[float], float]] = None, translate_params: Optional[Union[Sequence[float], float]] = None, scale_params: Optional[Union[Sequence[float], float]] = None, + affine: Optional[NdarrayOrTensor] = None, spatial_size: Optional[Union[Sequence[int], int]] = None, mode: GridSampleModeSequence = GridSampleMode.BILINEAR, padding_mode: GridSamplePadModeSequence = GridSamplePadMode.REFLECTION, @@ -631,6 +632,9 @@ def __init__( pixel/voxel relative to the center of the input image. Defaults to no translation. scale_params: scale factor for every spatial dims. a tuple of 2 floats for 2D, a tuple of 3 floats for 3D. Defaults to `1.0`. + affine: if applied, ignore the params (`rotate_params`, etc.) and use the + supplied matrix. Should be square with each side = num of image spatial + dimensions + 1. spatial_size: output image spatial size. if `spatial_size` and `self.spatial_size` are not defined, or smaller than 1, the transform will use the spatial size of `img`. @@ -662,6 +666,7 @@ def __init__( shear_params=shear_params, translate_params=translate_params, scale_params=scale_params, + affine=affine, spatial_size=spatial_size, device=device, ) From 1f18eec5fe24a10552a6f5464f29eaf8a9387161 Mon Sep 17 00:00:00 2001 From: Nic Ma Date: Fri, 12 Nov 2021 19:02:41 +0800 Subject: [PATCH 2/2] [DLMED] add unit tests Signed-off-by: Nic Ma --- tests/test_affine.py | 11 +++++++++++ tests/test_affine_grid.py | 30 ++++++++++++++++++++++++++++++ tests/test_affined.py | 13 +++++++++++++ 3 files changed, 54 insertions(+) diff --git a/tests/test_affine.py b/tests/test_affine.py index bd89f1a436..e5d11e2a82 100644 --- a/tests/test_affine.py +++ b/tests/test_affine.py @@ -56,6 +56,17 @@ p(np.array([[[0.0, 0.0, 0.0, 0.0], [0.0, 2.0, 0.0, 0.0], [0.0, 3.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0]]])), ] ) + TESTS.append( + [ + dict( + affine=p(torch.tensor([[0.0, -1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])), + padding_mode="zeros", + device=device, + ), + {"img": p(np.arange(4).reshape((1, 2, 2))), "spatial_size": (4, 4)}, + p(np.array([[[0.0, 0.0, 0.0, 0.0], [0.0, 2.0, 0.0, 0.0], [0.0, 3.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0]]])), + ] + ) TESTS.append( [ dict(padding_mode="zeros", device=device), diff --git a/tests/test_affine_grid.py b/tests/test_affine_grid.py index c12a395b47..27e049843c 100644 --- a/tests/test_affine_grid.py +++ b/tests/test_affine_grid.py @@ -54,6 +54,35 @@ ), ] ) + TESTS.append( + [ + { + "affine": p( + torch.tensor( + [[-10.8060, -8.4147, 0.0000], [-16.8294, 5.4030, 0.0000], [0.0000, 0.0000, 1.0000]] + ) + ) + }, + {"grid": p(torch.ones((3, 3, 3)))}, + p( + torch.tensor( + [ + [ + [-19.2208, -19.2208, -19.2208], + [-19.2208, -19.2208, -19.2208], + [-19.2208, -19.2208, -19.2208], + ], + [ + [-11.4264, -11.4264, -11.4264], + [-11.4264, -11.4264, -11.4264], + [-11.4264, -11.4264, -11.4264], + ], + [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 1.0]], + ] + ) + ), + ] + ) TESTS.append( [ {"rotate_params": (1.0, 1.0, 1.0), "scale_params": (-20, 10), "device": device}, @@ -99,6 +128,7 @@ ] ) + _rtol = 5e-2 if is_tf32_env() else 1e-4 diff --git a/tests/test_affined.py b/tests/test_affined.py index e9c468e755..a7c818fe65 100644 --- a/tests/test_affined.py +++ b/tests/test_affined.py @@ -49,6 +49,19 @@ p(np.array([[[0.0, 0.0, 0.0, 0.0], [0.0, 2.0, 0.0, 0.0], [0.0, 3.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0]]])), ] ) + TESTS.append( + [ + dict( + keys="img", + affine=p(torch.tensor([[0.0, -1.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 1.0]])), + padding_mode="zeros", + spatial_size=(4, 4), + device=device, + ), + {"img": p(np.arange(4).reshape((1, 2, 2)))}, + p(np.array([[[0.0, 0.0, 0.0, 0.0], [0.0, 2.0, 0.0, 0.0], [0.0, 3.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0]]])), + ] + ) TESTS.append( [ dict(keys="img", padding_mode="zeros", spatial_size=(-1, 0, 0), device=device),