Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 24 additions & 5 deletions monai/losses/deform.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,12 @@ class BendingEnergyLoss(_Loss):
DeepReg (https://github.com/DeepRegNet/DeepReg)
"""

def __init__(self, reduction: Union[LossReduction, str] = LossReduction.MEAN) -> None:
def __init__(self, normalize: bool = False, reduction: Union[LossReduction, str] = LossReduction.MEAN) -> None:
"""
Args:
normalize:
Whether to divide out spatial sizes in order to make the computation roughly
invariant to image scale (i.e. vector field sampling resolution). Defaults to False.
reduction: {``"none"``, ``"mean"``, ``"sum"``}
Specifies the reduction to apply to the output. Defaults to ``"mean"``.

Expand All @@ -63,6 +66,7 @@ def __init__(self, reduction: Union[LossReduction, str] = LossReduction.MEAN) ->
- ``"sum"``: the output will be summed.
"""
super().__init__(reduction=LossReduction(reduction).value)
self.normalize = normalize

def forward(self, pred: torch.Tensor) -> torch.Tensor:
"""
Expand All @@ -74,20 +78,35 @@ def forward(self, pred: torch.Tensor) -> torch.Tensor:

"""
if pred.ndim not in [3, 4, 5]:
raise ValueError(f"expecting 3-d, 4-d or 5-d pred, instead got pred of shape {pred.shape}")
raise ValueError(f"Expecting 3-d, 4-d or 5-d pred, instead got pred of shape {pred.shape}")
for i in range(pred.ndim - 2):
if pred.shape[-i - 1] <= 4:
raise ValueError("all spatial dimensions must > 4, got pred of shape {pred.shape}")
raise ValueError(f"All spatial dimensions must be > 4, got spatial dimensions {pred.shape[2:]}")
if pred.shape[1] != pred.ndim - 2:
raise ValueError(
f"Number of vector components, {pred.shape[1]}, does not match number of spatial dimensions, {pred.ndim-2}"
)

# first order gradient
first_order_gradient = [spatial_gradient(pred, dim) for dim in range(2, pred.ndim)]

# spatial dimensions in a shape suited for broadcasting below
if self.normalize:
spatial_dims = torch.tensor(pred.shape, device=pred.device)[2:].reshape((1, -1) + (pred.ndim - 2) * (1,))

energy = torch.tensor(0)
for dim_1, g in enumerate(first_order_gradient):
dim_1 += 2
energy = spatial_gradient(g, dim_1) ** 2 + energy
if self.normalize:
g *= pred.shape[dim_1] / spatial_dims
energy = energy + (spatial_gradient(g, dim_1) * pred.shape[dim_1]) ** 2
else:
energy = energy + spatial_gradient(g, dim_1) ** 2
for dim_2 in range(dim_1 + 1, pred.ndim):
energy = 2 * spatial_gradient(g, dim_2) ** 2 + energy
if self.normalize:
energy = energy + 2 * (spatial_gradient(g, dim_2) * pred.shape[dim_2]) ** 2
else:
energy = energy + 2 * spatial_gradient(g, dim_2) ** 2

if self.reduction == LossReduction.MEAN.value:
energy = torch.mean(energy) # the batch and channel average
Expand Down
43 changes: 34 additions & 9 deletions tests/test_bending_energy.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,28 @@
TEST_CASES = [
[{}, {"pred": torch.ones((1, 3, 5, 5, 5), device=device)}, 0.0],
[{}, {"pred": torch.arange(0, 5, device=device)[None, None, None, None, :].expand(1, 3, 5, 5, 5)}, 0.0],
[{}, {"pred": torch.arange(0, 5, device=device)[None, None, None, None, :].expand(1, 3, 5, 5, 5) ** 2}, 4.0],
[{}, {"pred": torch.arange(0, 5, device=device)[None, None, None, :].expand(1, 3, 5, 5) ** 2}, 4.0],
[{}, {"pred": torch.arange(0, 5, device=device)[None, None, :].expand(1, 3, 5) ** 2}, 4.0],
[
{"normalize": False},
{"pred": torch.arange(0, 5, device=device)[None, None, None, None, :].expand(1, 3, 5, 5, 5) ** 2},
4.0,
],
[
{"normalize": False},
{"pred": torch.arange(0, 5, device=device)[None, None, None, :].expand(1, 2, 5, 5) ** 2},
4.0,
],
[{"normalize": False}, {"pred": torch.arange(0, 5, device=device)[None, None, :].expand(1, 1, 5) ** 2}, 4.0],
[
{"normalize": True},
{"pred": torch.arange(0, 5, device=device)[None, None, None, None, :].expand(1, 3, 5, 5, 5) ** 2},
100.0,
],
[
{"normalize": True},
{"pred": torch.arange(0, 5, device=device)[None, None, None, :].expand(1, 2, 5, 5) ** 2},
100.0,
],
[{"normalize": True}, {"pred": torch.arange(0, 5, device=device)[None, None, :].expand(1, 1, 5) ** 2}, 100.0],
]


Expand All @@ -37,18 +56,24 @@ def test_shape(self, input_param, input_data, expected_val):
def test_ill_shape(self):
loss = BendingEnergyLoss()
# not in 3-d, 4-d, 5-d
with self.assertRaisesRegex(ValueError, ""):
with self.assertRaisesRegex(ValueError, "Expecting 3-d, 4-d or 5-d"):
loss.forward(torch.ones((1, 3), device=device))
with self.assertRaisesRegex(ValueError, ""):
loss.forward(torch.ones((1, 3, 5, 5, 5, 5), device=device))
with self.assertRaisesRegex(ValueError, "Expecting 3-d, 4-d or 5-d"):
loss.forward(torch.ones((1, 4, 5, 5, 5, 5), device=device))
# spatial_dim < 5
with self.assertRaisesRegex(ValueError, ""):
with self.assertRaisesRegex(ValueError, "All spatial dimensions"):
loss.forward(torch.ones((1, 3, 4, 5, 5), device=device))
with self.assertRaisesRegex(ValueError, ""):
with self.assertRaisesRegex(ValueError, "All spatial dimensions"):
loss.forward(torch.ones((1, 3, 5, 4, 5)))
with self.assertRaisesRegex(ValueError, ""):
with self.assertRaisesRegex(ValueError, "All spatial dimensions"):
loss.forward(torch.ones((1, 3, 5, 5, 4)))

# number of vector components unequal to number of spatial dims
with self.assertRaisesRegex(ValueError, "Number of vector components"):
loss.forward(torch.ones((1, 2, 5, 5, 5)))
with self.assertRaisesRegex(ValueError, "Number of vector components"):
loss.forward(torch.ones((1, 2, 5, 5, 5)))

def test_ill_opts(self):
pred = torch.rand(1, 3, 5, 5, 5).to(device=device)
with self.assertRaisesRegex(ValueError, ""):
Expand Down
6 changes: 3 additions & 3 deletions tests/test_reg_loss_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from tests.utils import SkipIfBeforePyTorchVersion

TEST_CASES = [
[BendingEnergyLoss, {}, ["pred"]],
[BendingEnergyLoss, {}, ["pred"], 3],
[LocalNormalizedCrossCorrelationLoss, {"kernel_size": 7, "kernel_type": "rectangular"}, ["pred", "target"]],
[LocalNormalizedCrossCorrelationLoss, {"kernel_size": 5, "kernel_type": "triangular"}, ["pred", "target"]],
[LocalNormalizedCrossCorrelationLoss, {"kernel_size": 3, "kernel_type": "gaussian"}, ["pred", "target"]],
Expand All @@ -42,7 +42,7 @@ def tearDown(self):

@parameterized.expand(TEST_CASES)
@SkipIfBeforePyTorchVersion((1, 9))
def test_convergence(self, loss_type, loss_args, forward_args):
def test_convergence(self, loss_type, loss_args, forward_args, pred_channels=1):
"""
The goal of this test is to assess if the gradient of the loss function
is correct by testing if we can train a one layer neural network
Expand All @@ -64,7 +64,7 @@ def __init__(self):
self.layer = nn.Sequential(
nn.Conv3d(in_channels=1, out_channels=1, kernel_size=3, padding=1),
nn.ReLU(),
nn.Conv3d(in_channels=1, out_channels=1, kernel_size=3, padding=1),
nn.Conv3d(in_channels=1, out_channels=pred_channels, kernel_size=3, padding=1),
)

def forward(self, x):
Expand Down