diff --git a/monai/losses/deform.py b/monai/losses/deform.py index fea56010c7..1a2d6349a5 100644 --- a/monai/losses/deform.py +++ b/monai/losses/deform.py @@ -52,9 +52,12 @@ class BendingEnergyLoss(_Loss): DeepReg (https://github.com/DeepRegNet/DeepReg) """ - def __init__(self, reduction: Union[LossReduction, str] = LossReduction.MEAN) -> None: + def __init__(self, normalize: bool = False, reduction: Union[LossReduction, str] = LossReduction.MEAN) -> None: """ Args: + normalize: + Whether to divide out spatial sizes in order to make the computation roughly + invariant to image scale (i.e. vector field sampling resolution). Defaults to False. reduction: {``"none"``, ``"mean"``, ``"sum"``} Specifies the reduction to apply to the output. Defaults to ``"mean"``. @@ -63,6 +66,7 @@ def __init__(self, reduction: Union[LossReduction, str] = LossReduction.MEAN) -> - ``"sum"``: the output will be summed. """ super().__init__(reduction=LossReduction(reduction).value) + self.normalize = normalize def forward(self, pred: torch.Tensor) -> torch.Tensor: """ @@ -74,20 +78,35 @@ def forward(self, pred: torch.Tensor) -> torch.Tensor: """ if pred.ndim not in [3, 4, 5]: - raise ValueError(f"expecting 3-d, 4-d or 5-d pred, instead got pred of shape {pred.shape}") + raise ValueError(f"Expecting 3-d, 4-d or 5-d pred, instead got pred of shape {pred.shape}") for i in range(pred.ndim - 2): if pred.shape[-i - 1] <= 4: - raise ValueError("all spatial dimensions must > 4, got pred of shape {pred.shape}") + raise ValueError(f"All spatial dimensions must be > 4, got spatial dimensions {pred.shape[2:]}") + if pred.shape[1] != pred.ndim - 2: + raise ValueError( + f"Number of vector components, {pred.shape[1]}, does not match number of spatial dimensions, {pred.ndim-2}" + ) # first order gradient first_order_gradient = [spatial_gradient(pred, dim) for dim in range(2, pred.ndim)] + # spatial dimensions in a shape suited for broadcasting below + if self.normalize: + spatial_dims = torch.tensor(pred.shape, device=pred.device)[2:].reshape((1, -1) + (pred.ndim - 2) * (1,)) + energy = torch.tensor(0) for dim_1, g in enumerate(first_order_gradient): dim_1 += 2 - energy = spatial_gradient(g, dim_1) ** 2 + energy + if self.normalize: + g *= pred.shape[dim_1] / spatial_dims + energy = energy + (spatial_gradient(g, dim_1) * pred.shape[dim_1]) ** 2 + else: + energy = energy + spatial_gradient(g, dim_1) ** 2 for dim_2 in range(dim_1 + 1, pred.ndim): - energy = 2 * spatial_gradient(g, dim_2) ** 2 + energy + if self.normalize: + energy = energy + 2 * (spatial_gradient(g, dim_2) * pred.shape[dim_2]) ** 2 + else: + energy = energy + 2 * spatial_gradient(g, dim_2) ** 2 if self.reduction == LossReduction.MEAN.value: energy = torch.mean(energy) # the batch and channel average diff --git a/tests/test_bending_energy.py b/tests/test_bending_energy.py index f254b9624c..77cd4b42c9 100644 --- a/tests/test_bending_energy.py +++ b/tests/test_bending_energy.py @@ -22,9 +22,28 @@ TEST_CASES = [ [{}, {"pred": torch.ones((1, 3, 5, 5, 5), device=device)}, 0.0], [{}, {"pred": torch.arange(0, 5, device=device)[None, None, None, None, :].expand(1, 3, 5, 5, 5)}, 0.0], - [{}, {"pred": torch.arange(0, 5, device=device)[None, None, None, None, :].expand(1, 3, 5, 5, 5) ** 2}, 4.0], - [{}, {"pred": torch.arange(0, 5, device=device)[None, None, None, :].expand(1, 3, 5, 5) ** 2}, 4.0], - [{}, {"pred": torch.arange(0, 5, device=device)[None, None, :].expand(1, 3, 5) ** 2}, 4.0], + [ + {"normalize": False}, + {"pred": torch.arange(0, 5, device=device)[None, None, None, None, :].expand(1, 3, 5, 5, 5) ** 2}, + 4.0, + ], + [ + {"normalize": False}, + {"pred": torch.arange(0, 5, device=device)[None, None, None, :].expand(1, 2, 5, 5) ** 2}, + 4.0, + ], + [{"normalize": False}, {"pred": torch.arange(0, 5, device=device)[None, None, :].expand(1, 1, 5) ** 2}, 4.0], + [ + {"normalize": True}, + {"pred": torch.arange(0, 5, device=device)[None, None, None, None, :].expand(1, 3, 5, 5, 5) ** 2}, + 100.0, + ], + [ + {"normalize": True}, + {"pred": torch.arange(0, 5, device=device)[None, None, None, :].expand(1, 2, 5, 5) ** 2}, + 100.0, + ], + [{"normalize": True}, {"pred": torch.arange(0, 5, device=device)[None, None, :].expand(1, 1, 5) ** 2}, 100.0], ] @@ -37,18 +56,24 @@ def test_shape(self, input_param, input_data, expected_val): def test_ill_shape(self): loss = BendingEnergyLoss() # not in 3-d, 4-d, 5-d - with self.assertRaisesRegex(ValueError, ""): + with self.assertRaisesRegex(ValueError, "Expecting 3-d, 4-d or 5-d"): loss.forward(torch.ones((1, 3), device=device)) - with self.assertRaisesRegex(ValueError, ""): - loss.forward(torch.ones((1, 3, 5, 5, 5, 5), device=device)) + with self.assertRaisesRegex(ValueError, "Expecting 3-d, 4-d or 5-d"): + loss.forward(torch.ones((1, 4, 5, 5, 5, 5), device=device)) # spatial_dim < 5 - with self.assertRaisesRegex(ValueError, ""): + with self.assertRaisesRegex(ValueError, "All spatial dimensions"): loss.forward(torch.ones((1, 3, 4, 5, 5), device=device)) - with self.assertRaisesRegex(ValueError, ""): + with self.assertRaisesRegex(ValueError, "All spatial dimensions"): loss.forward(torch.ones((1, 3, 5, 4, 5))) - with self.assertRaisesRegex(ValueError, ""): + with self.assertRaisesRegex(ValueError, "All spatial dimensions"): loss.forward(torch.ones((1, 3, 5, 5, 4))) + # number of vector components unequal to number of spatial dims + with self.assertRaisesRegex(ValueError, "Number of vector components"): + loss.forward(torch.ones((1, 2, 5, 5, 5))) + with self.assertRaisesRegex(ValueError, "Number of vector components"): + loss.forward(torch.ones((1, 2, 5, 5, 5))) + def test_ill_opts(self): pred = torch.rand(1, 3, 5, 5, 5).to(device=device) with self.assertRaisesRegex(ValueError, ""): diff --git a/tests/test_reg_loss_integration.py b/tests/test_reg_loss_integration.py index 2949ee1519..74fd1f7959 100644 --- a/tests/test_reg_loss_integration.py +++ b/tests/test_reg_loss_integration.py @@ -20,7 +20,7 @@ from tests.utils import SkipIfBeforePyTorchVersion TEST_CASES = [ - [BendingEnergyLoss, {}, ["pred"]], + [BendingEnergyLoss, {}, ["pred"], 3], [LocalNormalizedCrossCorrelationLoss, {"kernel_size": 7, "kernel_type": "rectangular"}, ["pred", "target"]], [LocalNormalizedCrossCorrelationLoss, {"kernel_size": 5, "kernel_type": "triangular"}, ["pred", "target"]], [LocalNormalizedCrossCorrelationLoss, {"kernel_size": 3, "kernel_type": "gaussian"}, ["pred", "target"]], @@ -42,7 +42,7 @@ def tearDown(self): @parameterized.expand(TEST_CASES) @SkipIfBeforePyTorchVersion((1, 9)) - def test_convergence(self, loss_type, loss_args, forward_args): + def test_convergence(self, loss_type, loss_args, forward_args, pred_channels=1): """ The goal of this test is to assess if the gradient of the loss function is correct by testing if we can train a one layer neural network @@ -64,7 +64,7 @@ def __init__(self): self.layer = nn.Sequential( nn.Conv3d(in_channels=1, out_channels=1, kernel_size=3, padding=1), nn.ReLU(), - nn.Conv3d(in_channels=1, out_channels=1, kernel_size=3, padding=1), + nn.Conv3d(in_channels=1, out_channels=pred_channels, kernel_size=3, padding=1), ) def forward(self, x):