From 79de4e104d7cfe415ee1f3894b277fb786702ecf Mon Sep 17 00:00:00 2001 From: Ebrahim Ebrahim Date: Tue, 14 Dec 2021 22:40:04 -0500 Subject: [PATCH 1/2] make bending energy loss invariant to resolution fixes #3485 Signed-off-by: Ebrahim Ebrahim --- monai/losses/deform.py | 29 ++++++++++++++++---- tests/test_bending_energy.py | 43 +++++++++++++++++++++++------- tests/test_reg_loss_integration.py | 6 ++--- 3 files changed, 61 insertions(+), 17 deletions(-) diff --git a/monai/losses/deform.py b/monai/losses/deform.py index fea56010c7..4365e45568 100644 --- a/monai/losses/deform.py +++ b/monai/losses/deform.py @@ -52,9 +52,12 @@ class BendingEnergyLoss(_Loss): DeepReg (https://github.com/DeepRegNet/DeepReg) """ - def __init__(self, reduction: Union[LossReduction, str] = LossReduction.MEAN) -> None: + def __init__(self, normalize: bool = True, reduction: Union[LossReduction, str] = LossReduction.MEAN) -> None: """ Args: + normalize: + Whether to divide out spatial sizes in order to make the computation roughly + invariant to image scale (i.e. vector field sampling resolution). Defaults to True. reduction: {``"none"``, ``"mean"``, ``"sum"``} Specifies the reduction to apply to the output. Defaults to ``"mean"``. @@ -63,6 +66,7 @@ def __init__(self, reduction: Union[LossReduction, str] = LossReduction.MEAN) -> - ``"sum"``: the output will be summed. """ super().__init__(reduction=LossReduction(reduction).value) + self.normalize = normalize def forward(self, pred: torch.Tensor) -> torch.Tensor: """ @@ -74,20 +78,35 @@ def forward(self, pred: torch.Tensor) -> torch.Tensor: """ if pred.ndim not in [3, 4, 5]: - raise ValueError(f"expecting 3-d, 4-d or 5-d pred, instead got pred of shape {pred.shape}") + raise ValueError(f"Expecting 3-d, 4-d or 5-d pred, instead got pred of shape {pred.shape}") for i in range(pred.ndim - 2): if pred.shape[-i - 1] <= 4: - raise ValueError("all spatial dimensions must > 4, got pred of shape {pred.shape}") + raise ValueError(f"All spatial dimensions must be > 4, got spatial dimensions {pred.shape[2:]}") + if pred.shape[1] != pred.ndim - 2: + raise ValueError( + f"Number of vector components, {pred.shape[1]}, does not match number of spatial dimensions, {pred.ndim-2}" + ) # first order gradient first_order_gradient = [spatial_gradient(pred, dim) for dim in range(2, pred.ndim)] + # spatial dimensions in a shape suited for broadcasting below + if self.normalize: + spatial_dims = torch.tensor(pred.shape, device=pred.device)[2:].reshape((1, -1) + (pred.ndim - 2) * (1,)) + energy = torch.tensor(0) for dim_1, g in enumerate(first_order_gradient): dim_1 += 2 - energy = spatial_gradient(g, dim_1) ** 2 + energy + if self.normalize: + g *= pred.shape[dim_1] / spatial_dims + energy = energy + (spatial_gradient(g, dim_1) * pred.shape[dim_1]) ** 2 + else: + energy = energy + spatial_gradient(g, dim_1) ** 2 for dim_2 in range(dim_1 + 1, pred.ndim): - energy = 2 * spatial_gradient(g, dim_2) ** 2 + energy + if self.normalize: + energy = energy + 2 * (spatial_gradient(g, dim_2) * pred.shape[dim_2]) ** 2 + else: + energy = energy + 2 * spatial_gradient(g, dim_2) ** 2 if self.reduction == LossReduction.MEAN.value: energy = torch.mean(energy) # the batch and channel average diff --git a/tests/test_bending_energy.py b/tests/test_bending_energy.py index f254b9624c..77cd4b42c9 100644 --- a/tests/test_bending_energy.py +++ b/tests/test_bending_energy.py @@ -22,9 +22,28 @@ TEST_CASES = [ [{}, {"pred": torch.ones((1, 3, 5, 5, 5), device=device)}, 0.0], [{}, {"pred": torch.arange(0, 5, device=device)[None, None, None, None, :].expand(1, 3, 5, 5, 5)}, 0.0], - [{}, {"pred": torch.arange(0, 5, device=device)[None, None, None, None, :].expand(1, 3, 5, 5, 5) ** 2}, 4.0], - [{}, {"pred": torch.arange(0, 5, device=device)[None, None, None, :].expand(1, 3, 5, 5) ** 2}, 4.0], - [{}, {"pred": torch.arange(0, 5, device=device)[None, None, :].expand(1, 3, 5) ** 2}, 4.0], + [ + {"normalize": False}, + {"pred": torch.arange(0, 5, device=device)[None, None, None, None, :].expand(1, 3, 5, 5, 5) ** 2}, + 4.0, + ], + [ + {"normalize": False}, + {"pred": torch.arange(0, 5, device=device)[None, None, None, :].expand(1, 2, 5, 5) ** 2}, + 4.0, + ], + [{"normalize": False}, {"pred": torch.arange(0, 5, device=device)[None, None, :].expand(1, 1, 5) ** 2}, 4.0], + [ + {"normalize": True}, + {"pred": torch.arange(0, 5, device=device)[None, None, None, None, :].expand(1, 3, 5, 5, 5) ** 2}, + 100.0, + ], + [ + {"normalize": True}, + {"pred": torch.arange(0, 5, device=device)[None, None, None, :].expand(1, 2, 5, 5) ** 2}, + 100.0, + ], + [{"normalize": True}, {"pred": torch.arange(0, 5, device=device)[None, None, :].expand(1, 1, 5) ** 2}, 100.0], ] @@ -37,18 +56,24 @@ def test_shape(self, input_param, input_data, expected_val): def test_ill_shape(self): loss = BendingEnergyLoss() # not in 3-d, 4-d, 5-d - with self.assertRaisesRegex(ValueError, ""): + with self.assertRaisesRegex(ValueError, "Expecting 3-d, 4-d or 5-d"): loss.forward(torch.ones((1, 3), device=device)) - with self.assertRaisesRegex(ValueError, ""): - loss.forward(torch.ones((1, 3, 5, 5, 5, 5), device=device)) + with self.assertRaisesRegex(ValueError, "Expecting 3-d, 4-d or 5-d"): + loss.forward(torch.ones((1, 4, 5, 5, 5, 5), device=device)) # spatial_dim < 5 - with self.assertRaisesRegex(ValueError, ""): + with self.assertRaisesRegex(ValueError, "All spatial dimensions"): loss.forward(torch.ones((1, 3, 4, 5, 5), device=device)) - with self.assertRaisesRegex(ValueError, ""): + with self.assertRaisesRegex(ValueError, "All spatial dimensions"): loss.forward(torch.ones((1, 3, 5, 4, 5))) - with self.assertRaisesRegex(ValueError, ""): + with self.assertRaisesRegex(ValueError, "All spatial dimensions"): loss.forward(torch.ones((1, 3, 5, 5, 4))) + # number of vector components unequal to number of spatial dims + with self.assertRaisesRegex(ValueError, "Number of vector components"): + loss.forward(torch.ones((1, 2, 5, 5, 5))) + with self.assertRaisesRegex(ValueError, "Number of vector components"): + loss.forward(torch.ones((1, 2, 5, 5, 5))) + def test_ill_opts(self): pred = torch.rand(1, 3, 5, 5, 5).to(device=device) with self.assertRaisesRegex(ValueError, ""): diff --git a/tests/test_reg_loss_integration.py b/tests/test_reg_loss_integration.py index 2949ee1519..74fd1f7959 100644 --- a/tests/test_reg_loss_integration.py +++ b/tests/test_reg_loss_integration.py @@ -20,7 +20,7 @@ from tests.utils import SkipIfBeforePyTorchVersion TEST_CASES = [ - [BendingEnergyLoss, {}, ["pred"]], + [BendingEnergyLoss, {}, ["pred"], 3], [LocalNormalizedCrossCorrelationLoss, {"kernel_size": 7, "kernel_type": "rectangular"}, ["pred", "target"]], [LocalNormalizedCrossCorrelationLoss, {"kernel_size": 5, "kernel_type": "triangular"}, ["pred", "target"]], [LocalNormalizedCrossCorrelationLoss, {"kernel_size": 3, "kernel_type": "gaussian"}, ["pred", "target"]], @@ -42,7 +42,7 @@ def tearDown(self): @parameterized.expand(TEST_CASES) @SkipIfBeforePyTorchVersion((1, 9)) - def test_convergence(self, loss_type, loss_args, forward_args): + def test_convergence(self, loss_type, loss_args, forward_args, pred_channels=1): """ The goal of this test is to assess if the gradient of the loss function is correct by testing if we can train a one layer neural network @@ -64,7 +64,7 @@ def __init__(self): self.layer = nn.Sequential( nn.Conv3d(in_channels=1, out_channels=1, kernel_size=3, padding=1), nn.ReLU(), - nn.Conv3d(in_channels=1, out_channels=1, kernel_size=3, padding=1), + nn.Conv3d(in_channels=1, out_channels=pred_channels, kernel_size=3, padding=1), ) def forward(self, x): From bb2366fd92d2a9b22d211a0e1c83e314b292e4a6 Mon Sep 17 00:00:00 2001 From: Ebrahim Ebrahim Date: Wed, 15 Dec 2021 14:45:45 -0500 Subject: [PATCH 2/2] set BendingEnergyLoss default normalize to False Maybe it's more important that the default behavior match usage of the term "bending energy" elsewhere, rather than that it be the most convenient behavior. Signed-off-by: Ebrahim Ebrahim --- monai/losses/deform.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/monai/losses/deform.py b/monai/losses/deform.py index 4365e45568..1a2d6349a5 100644 --- a/monai/losses/deform.py +++ b/monai/losses/deform.py @@ -52,12 +52,12 @@ class BendingEnergyLoss(_Loss): DeepReg (https://github.com/DeepRegNet/DeepReg) """ - def __init__(self, normalize: bool = True, reduction: Union[LossReduction, str] = LossReduction.MEAN) -> None: + def __init__(self, normalize: bool = False, reduction: Union[LossReduction, str] = LossReduction.MEAN) -> None: """ Args: normalize: Whether to divide out spatial sizes in order to make the computation roughly - invariant to image scale (i.e. vector field sampling resolution). Defaults to True. + invariant to image scale (i.e. vector field sampling resolution). Defaults to False. reduction: {``"none"``, ``"mean"``, ``"sum"``} Specifies the reduction to apply to the output. Defaults to ``"mean"``.