diff --git a/monai/networks/blocks/warp.py b/monai/networks/blocks/warp.py
index 8a570e42c4..5fd82c19f2 100644
--- a/monai/networks/blocks/warp.py
+++ b/monai/networks/blocks/warp.py
@@ -110,18 +110,29 @@ def get_reference_grid(self, ddf: torch.Tensor, jitter: bool = False, seed: int
         self.ref_grid.requires_grad = False
         return self.ref_grid
 
-    def forward(self, image: torch.Tensor, ddf: torch.Tensor):
+    def forward(
+        self, image: torch.Tensor, ddf: torch.Tensor, keypoints: torch.Tensor | None = None
+    ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
         """
         Args:
             image: Tensor in shape (batch, num_channels, H, W[, D])
             ddf: Tensor in the same spatial size as image, in shape (batch, ``spatial_dims``, H, W[, D])
+            keypoints: Tensor in shape (batch, N, ``spatial_dims``), optional
 
         Returns:
             warped_image in the same shape as image (batch, num_channels, H, W[, D])
+            warped_keypoints in the same shape as keypoints (batch, N, ``spatial_dims``), if keypoints is not None
         """
+        batch_size = image.shape[0]
         spatial_dims = len(image.shape) - 2
         if spatial_dims not in (2, 3):
             raise NotImplementedError(f"got unsupported spatial_dims={spatial_dims}, currently support 2 or 3.")
+        if keypoints is not None:
+            if keypoints.shape[-1] != spatial_dims:
+                raise ValueError(
+                    f"Given input {spatial_dims}-d image, the last dimension of the input keypoints must be {spatial_dims}, "
+                    f"got {keypoints.shape}."
+                )
         ddf_shape = (image.shape[0], spatial_dims) + tuple(image.shape[2:])
         if ddf.shape != ddf_shape:
             raise ValueError(
@@ -136,12 +147,33 @@ def forward(self, image: torch.Tensor, ddf: torch.Tensor):
                 grid[..., i] = grid[..., i] * 2 / (dim - 1) - 1
             index_ordering: list[int] = list(range(spatial_dims - 1, -1, -1))
             grid = grid[..., index_ordering]  # z, y, x -> x, y, z
-            return F.grid_sample(
+            warped_image = F.grid_sample(
                 image, grid, mode=self._interp_mode, padding_mode=f"{self._padding_mode}", align_corners=True
             )
-
-        # using csrc resampling
-        return grid_pull(image, grid, bound=self._padding_mode, extrapolate=True, interpolation=self._interp_mode)
+        else:
+            # using csrc resampling
+            warped_image = grid_pull(
+                image, grid, bound=self._padding_mode, extrapolate=True, interpolation=self._interp_mode
+            )
+        if keypoints is not None:
+            with torch.no_grad():
+                offset = torch.as_tensor(image.shape[2:]).to(keypoints) / 2.0
+                offset = offset.unsqueeze(0).unsqueeze(0)
+                normalized_keypoints = torch.flip((keypoints - offset) / offset, (-1,))
+            ddf_keypoints = (
+                F.grid_sample(
+                    ddf,
+                    normalized_keypoints.view(batch_size, -1, *([1] * (spatial_dims - 1)), spatial_dims),
+                    mode=self._interp_mode,
+                    padding_mode=f"{self._padding_mode}",
+                    align_corners=True,
+                )
+                .view(batch_size, spatial_dims, -1)
+                .permute((0, 2, 1))
+            )
+            warped_keypoints = keypoints + ddf_keypoints
+            return warped_image, warped_keypoints
+        return warped_image
 
 
 class DVF2DDF(nn.Module):
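Note: the keypoint branch above samples the DDF at each (normalized, axis-flipped) keypoint location and adds the sampled displacement to the original coordinates, so keypoint warping stays differentiable with respect to the DDF. A minimal sketch of the extended `Warp.forward`, assuming this patch is applied; shapes and values are illustrative:

```python
import torch

from monai.networks.blocks import Warp

warp = Warp(mode="bilinear", padding_mode="zeros")

image = torch.randn(1, 1, 32, 32, 32)  # (batch, num_channels, H, W, D)
ddf = torch.zeros(1, 3, 32, 32, 32)  # zero displacement field
keypoints = torch.tensor([[[16.0, 16.0, 16.0]]])  # (batch, N, spatial_dims), voxel coordinates

# with keypoints supplied, forward returns a (warped_image, warped_keypoints) pair
warped_image, warped_keypoints = warp(image, ddf, keypoints)

# a zero DDF leaves the keypoints unchanged
assert torch.allclose(warped_keypoints, keypoints)
```
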
diff --git a/monai/networks/nets/voxelmorph.py b/monai/networks/nets/voxelmorph.py
index 4923b6ad60..0a95564f2a 100644
--- a/monai/networks/nets/voxelmorph.py
+++ b/monai/networks/nets/voxelmorph.py
@@ -405,6 +405,15 @@ class serves as a wrapper that concatenates the input pair of moving and fixed i
             fixed = torch.randn(1, 1, 160, 192, 224)
             warped, ddf = net(moving, fixed)
 
+            # Example with optional moving_seg and fixed_keypoints
+            # (one_hot is available from monai.networks.utils)
+            moving_seg = torch.randint(0, 4, (1, 1, 160, 192, 224)).float()
+            moving_seg = one_hot(moving_seg, num_classes=4)
+            fixed_keypoints = torch.tensor([[[80, 96, 112], [40, 48, 56]]]).float()
+            warped_img, warped_seg, warped_keypoints, ddf = net(
+                moving, fixed, moving_seg=moving_seg, fixed_keypoints=fixed_keypoints
+            )
+
     """
 
     def __init__(
@@ -440,13 +449,37 @@ def __init__(
         self.dvf2ddf = DVF2DDF(num_steps=self.integration_steps, mode="bilinear", padding_mode="zeros")
         self.warp = Warp(mode="bilinear", padding_mode="zeros")
 
-    def forward(self, moving: torch.Tensor, fixed: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+    def forward(
+        self,
+        moving: torch.Tensor,
+        fixed: torch.Tensor,
+        moving_seg: torch.Tensor | None = None,
+        fixed_keypoints: torch.Tensor | None = None,
+    ) -> tuple[torch.Tensor, ...]:
         if moving.shape != fixed.shape:
             raise ValueError(
                 "The spatial shape of the moving image should be the same as the spatial shape of the fixed image."
                 f" Got {moving.shape} and {fixed.shape} instead."
             )
 
+        if moving_seg is not None:
+            if moving_seg.shape[0] != moving.shape[0]:
+                raise ValueError(
+                    f"Batch dimension mismatch: moving_seg={moving_seg.shape[0]}, moving={moving.shape[0]}"
+                )
+            if moving_seg.shape[2:] != moving.shape[2:]:
+                raise ValueError(
+                    "The spatial shape of the moving segmentation must match the spatial shape of the moving image. "
+                    f"Got {moving_seg.shape[2:]} vs {moving.shape[2:]}."
+                )
+
+        if fixed_keypoints is not None:
+            if fixed_keypoints.shape[-1] != self.spatial_dims:
+                raise ValueError(
+                    "The last dimension of the fixed keypoints should be equal to the number of spatial dimensions."
+                    f" Got {fixed_keypoints.shape[-1]} and {self.spatial_dims} instead."
+                )
+
         x = self.backbone(torch.cat([moving, fixed], dim=1))
 
         if x.shape[1] != self.spatial_dims:
@@ -470,7 +503,14 @@ def forward(self, moving: torch.Tensor, fixed: torch.Tensor) -> tuple[torch.Tens
         if self.half_res:
             x = F.interpolate(x * 0.5, scale_factor=2.0, mode="trilinear", align_corners=True)
 
-        return self.warp(moving, x), x
+        if moving_seg is None and fixed_keypoints is None:
+            return self.warp(moving, x), x
+        elif moving_seg is None and fixed_keypoints is not None:
+            return *self.warp(moving, x, fixed_keypoints), x
+        elif moving_seg is not None and fixed_keypoints is None:
+            return self.warp(moving, x), self.warp(moving_seg, x), x
+        else:
+            return self.warp(moving, x), *self.warp(moving_seg, x, fixed_keypoints), x
 
 
 voxelmorph = VoxelMorph
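Note: because the return arity of `VoxelMorph.forward` now depends on which optional inputs are supplied, callers must unpack accordingly. A sketch of the four call patterns under this patch; shapes are illustrative, and the default constructor is assumed to build the standard 3-D backbone:

```python
import torch

from monai.networks.nets import VoxelMorph

net = VoxelMorph()
moving = torch.randn(1, 1, 96, 96, 48)
fixed = torch.randn(1, 1, 96, 96, 48)
moving_seg = torch.randn(1, 2, 96, 96, 48)  # e.g. a one-hot label map
fixed_keypoints = torch.rand(1, 5, 3) * 48  # (batch, N, spatial_dims), voxel coordinates

warped, ddf = net(moving, fixed)
warped, warped_kp, ddf = net(moving, fixed, fixed_keypoints=fixed_keypoints)
warped, warped_seg, ddf = net(moving, fixed, moving_seg=moving_seg)
warped, warped_seg, warped_kp, ddf = net(moving, fixed, moving_seg=moving_seg, fixed_keypoints=fixed_keypoints)
```
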
diff --git a/tests/networks/nets/test_voxelmorph.py b/tests/networks/nets/test_voxelmorph.py
index 1a04bab568..340db771af 100644
--- a/tests/networks/nets/test_voxelmorph.py
+++ b/tests/networks/nets/test_voxelmorph.py
@@ -171,6 +171,18 @@
     TEST_CASE_9,
 ]
 
+TEST_CASE_SEG_0 = [
+    {"spatial_dims": 3},
+    (1, 1, 96, 96, 48),  # moving image
+    (1, 1, 96, 96, 48),  # fixed image
+    (1, 2, 96, 96, 48),  # moving label
+    (1, 1, 96, 96, 48),  # expected warped moving image
+    (1, 2, 96, 96, 48),  # expected warped moving label
+    (1, 3, 96, 96, 48),  # expected ddf
+]
+
+CASES_SEG = [TEST_CASE_SEG_0]
+
 ILL_CASE_0 = [  # spatial_dims = 1
     {
         "spatial_dims": 1,
@@ -243,6 +255,15 @@
 
 ILL_CASES_IN_SHAPE = [ILL_CASES_IN_SHAPE_0, ILL_CASES_IN_SHAPE_1]
 
+ILL_CASE_SEG_SHAPE_0 = [  # moving_seg and moving image shapes do not match
+    {"spatial_dims": 3},
+    (1, 1, 96, 96, 48),
+    (1, 1, 96, 96, 48),
+    (1, 2, 80, 96, 48),
+]
+
+ILL_CASES_SEG_SHAPE = [ILL_CASE_SEG_SHAPE_0]
+
 
 class TestVOXELMORPH(unittest.TestCase):
 
     @parameterized.expand(CASES)
@@ -252,6 +273,28 @@ def test_shape(self, input_param, input_shape, expected_shape):
             result = net.forward(torch.randn(input_shape).to(device))
             self.assertEqual(result.shape, expected_shape)
 
+    @parameterized.expand(CASES_SEG)
+    def test_shape_seg(
+        self,
+        input_param,
+        moving_shape,
+        fixed_shape,
+        moving_seg_shape,
+        expected_warped_moving_shape,
+        expected_warped_moving_seg_shape,
+        expected_ddf_shape,
+    ):
+        net = VoxelMorph(**input_param).to(device)
+        with eval_mode(net):
+            warped_moving, warped_moving_seg, ddf = net.forward(
+                torch.randn(moving_shape).to(device),
+                torch.randn(fixed_shape).to(device),
+                torch.randn(moving_seg_shape).to(device),
+            )
+            self.assertEqual(warped_moving.shape, expected_warped_moving_shape)
+            self.assertEqual(warped_moving_seg.shape, expected_warped_moving_seg_shape)
+            self.assertEqual(ddf.shape, expected_ddf_shape)
+
     def test_script(self):
         net = VoxelMorphUNet(
             spatial_dims=2,
@@ -275,6 +318,17 @@ def test_ill_input_shape(self, input_param, moving_shape, fixed_shape):
             with eval_mode(net):
                 _ = net.forward(torch.randn(moving_shape).to(device), torch.randn(fixed_shape).to(device))
 
+    @parameterized.expand(ILL_CASES_SEG_SHAPE)
+    def test_ill_input_seg_shape(self, input_param, moving_shape, fixed_shape, moving_seg_shape):
+        with self.assertRaises((ValueError, RuntimeError)):
+            net = VoxelMorph(**input_param).to(device)
+            with eval_mode(net):
+                _ = net.forward(
+                    torch.randn(moving_shape).to(device),
+                    torch.randn(fixed_shape).to(device),
+                    torch.randn(moving_seg_shape).to(device),
+                )
+
 
 if __name__ == "__main__":
     unittest.main()
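Note: the new tests exercise only the `moving_seg` path; the `fixed_keypoints` path is untested here. A possible companion case, sketched in the same style as `test_shape_seg`; the case name, shapes, and method name are hypothetical additions to `TestVOXELMORPH`, relying on the module's existing imports, `device`, and `eval_mode`:

```python
TEST_CASE_KP_0 = [
    {"spatial_dims": 3},
    (1, 1, 96, 96, 48),  # moving image
    (1, 1, 96, 96, 48),  # fixed image
    (1, 5, 3),  # fixed keypoints
    (1, 1, 96, 96, 48),  # expected warped moving image
    (1, 5, 3),  # expected warped keypoints
    (1, 3, 96, 96, 48),  # expected ddf
]

@parameterized.expand([TEST_CASE_KP_0])
def test_shape_keypoints(
    self, input_param, moving_shape, fixed_shape, kp_shape, expected_warped_shape, expected_kp_shape, expected_ddf_shape
):
    net = VoxelMorph(**input_param).to(device)
    with eval_mode(net):
        warped, warped_kp, ddf = net.forward(
            torch.randn(moving_shape).to(device),
            torch.randn(fixed_shape).to(device),
            # keypoints drawn uniformly inside the volume, in voxel coordinates
            fixed_keypoints=(torch.rand(kp_shape).to(device) * 48),
        )
        self.assertEqual(warped.shape, expected_warped_shape)
        self.assertEqual(warped_kp.shape, expected_kp_shape)
        self.assertEqual(ddf.shape, expected_ddf_shape)
```
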