Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 23 additions & 5 deletions monai/losses/unified_focal_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,8 @@ def __init__(
gamma: float = 0.5,
delta: float = 0.7,
reduction: LossReduction | str = LossReduction.MEAN,
use_softmax: bool = False,
use_sigmoid: bool = False,
):
"""
Args:
Expand All @@ -170,8 +172,14 @@ def __init__(
weight : weight for each loss function. Defaults to 0.5.
gamma : value of the exponent gamma in the definition of the Focal loss. Defaults to 0.5.
delta : weight of the background. Defaults to 0.7.


reduction: {``"none"``, ``"mean"``, ``"sum"``}
Specifies the reduction to apply to the output. Defaults to ``"mean"``.
use_softmax: if True, use softmax to transform the input logits into probabilities.
Defaults to False. Mutually exclusive with ``use_sigmoid``.
use_sigmoid: if True, use sigmoid to transform the input logits into probabilities.
Defaults to False. Mutually exclusive with ``use_softmax``.
When both ``use_softmax`` and ``use_sigmoid`` are False, the input is assumed
to already be probabilities.

Example:
>>> import torch
Expand All @@ -182,22 +190,25 @@ def __init__(
>>> fl(pred, grnd)
"""
super().__init__(reduction=LossReduction(reduction).value)
if use_softmax and use_sigmoid:
raise ValueError("use_softmax and use_sigmoid are mutually exclusive.")
self.to_onehot_y = to_onehot_y
self.num_classes = num_classes
self.gamma = gamma
self.delta = delta
self.weight: float = weight
self.use_softmax = use_softmax
self.use_sigmoid = use_sigmoid
self.asy_focal_loss = AsymmetricFocalLoss(gamma=self.gamma, delta=self.delta)
self.asy_focal_tversky_loss = AsymmetricFocalTverskyLoss(gamma=self.gamma, delta=self.delta)

# TODO: Implement this function to support multiple classes segmentation
def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
"""
Args:
y_pred : the shape should be BNH[WD], where N is the number of classes.
It only supports binary segmentation.
The input should be the original logits since it will be transformed by
a sigmoid in the forward function.
The input can be raw logits or probabilities depending on ``use_softmax``
and ``use_sigmoid`` settings.
y_true : the shape should be BNH[WD], where N is the number of classes.
It only supports binary segmentation.

Expand All @@ -213,6 +224,13 @@ def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
if len(y_pred.shape) != 4 and len(y_pred.shape) != 5:
raise ValueError(f"input shape must be 4 or 5, but got {y_pred.shape}")

# Apply activation BEFORE one_hot encoding, since one_hot uses
# values as scatter indices and raw logits would cause index errors.
if self.use_softmax:
y_pred = torch.softmax(y_pred, dim=1)
elif self.use_sigmoid:
y_pred = torch.sigmoid(y_pred)

if y_pred.shape[1] == 1:
y_pred = one_hot(y_pred, num_classes=self.num_classes)
y_true = one_hot(y_true, num_classes=self.num_classes)
Expand Down
18 changes: 18 additions & 0 deletions tests/losses/test_unified_focal_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,24 @@ def test_with_cuda(self):
print(output)
np.testing.assert_allclose(output.detach().cpu().numpy(), 0.0, atol=1e-4, rtol=1e-4)

def test_use_sigmoid(self):
loss = AsymmetricUnifiedFocalLoss(use_sigmoid=True)
y_pred = torch.tensor([[[[10.0, -10], [-10, 10.0]]], [[[10.0, -10], [-10, 10.0]]]])
y_true = torch.tensor([[[[1.0, 0], [0, 1.0]]], [[[1.0, 0], [0, 1.0]]]])
result = loss(y_pred, y_true)
self.assertTrue(result.item() >= 0)
Comment thread
coderabbitai[bot] marked this conversation as resolved.

def test_use_softmax(self):
loss = AsymmetricUnifiedFocalLoss(use_softmax=True)
y_pred = torch.tensor([[[[10.0, -10], [-10, 10.0]]], [[[10.0, -10], [-10, 10.0]]]])
y_true = torch.tensor([[[[1.0, 0], [0, 1.0]]], [[[1.0, 0], [0, 1.0]]]])
result = loss(y_pred, y_true)
self.assertTrue(result.item() >= 0)

def test_mutually_exclusive(self):
with self.assertRaises(ValueError):
AsymmetricUnifiedFocalLoss(use_softmax=True, use_sigmoid=True)


if __name__ == "__main__":
unittest.main()
Loading