26 changes: 14 additions & 12 deletions monai/losses/focal_loss.py
@@ -22,7 +22,7 @@

class FocalLoss(_Loss):
"""
Reimplementation of the Focal Loss described in:
Reimplementation of the Focal Loss (with a built-in sigmoid activation) described in:

- "Focal Loss for Dense Object Detection", T. Lin et al., ICCV 2017
- "AnatomyNet: Deep learning for fast and fully automated whole‐volume segmentation of head and neck anatomy",
@@ -77,12 +77,12 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
"""
Args:
input: the shape should be BNH[WD], where N is the number of classes.
The input should be the original logits since it will be transferred by
`F.log_softmax` in the forward function.
The input should be the original logits since it will be transformed by
a sigmoid in the forward function.
target: the shape should be BNH[WD] or B1H[WD], where N is the number of classes.

Raises:
AssertionError: When input and target (after one hot transform if setted)
ValueError: When input and target (after one hot transform if set)
have different shapes.
ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"].
ValueError: When ``self.weight`` is a sequence and the length is not equal to the
@@ -107,7 +107,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
input = input[:, 1:]

if target.shape != input.shape:
raise AssertionError(f"ground truth has different shape ({target.shape}) from input ({input.shape})")
raise ValueError(f"ground truth has different shape ({target.shape}) from input ({input.shape})")

i = input
t = target
@@ -117,10 +117,10 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
i = i.reshape(b, n, -1)
t = t.reshape(b, n, -1)

# Compute the log proba.
logpt = F.log_softmax(i, dim=1)
# Get the proba
pt = torch.exp(logpt) # B,H*W or B,N,H*W
# computing binary cross entropy with logits
# see also https://github.com/pytorch/pytorch/blob/v1.9.0/aten/src/ATen/native/Loss.cpp#L231
max_val = (-i).clamp(min=0)
ce = i - i * t + max_val + ((-max_val).exp() + (-i - max_val).exp()).log()

if self.weight is not None:
class_weight: Optional[torch.Tensor] = None
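The `max_val`/`ce` pair above is the standard numerically stable form of binary cross entropy with logits, matching the linked Loss.cpp. A quick sanity check, assuming only `torch`, that it agrees with `F.binary_cross_entropy_with_logits`:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
i = torch.randn(4, 3, 8) * 20  # large-magnitude logits to stress numerical stability
t = torch.randint(0, 2, (4, 3, 8)).float()

# the two lines from the forward pass
max_val = (-i).clamp(min=0)
ce = i - i * t + max_val + ((-max_val).exp() + (-i - max_val).exp()).log()

ref = F.binary_cross_entropy_with_logits(i, t, reduction="none")
assert torch.allclose(ce, ref, atol=1e-6)
```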
@@ -142,11 +142,13 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
at = class_weight[None, :, None] # N => 1,N,1
at = at.expand((t.size(0), -1, t.size(2))) # 1,N,1 => B,N,H*W
# Multiply the loss terms by the class weights.
logpt = logpt * at
ce = ce * at

# Compute the loss mini-batch.
weight = torch.pow(-pt + 1.0, self.gamma)
loss = torch.mean(-weight * t * logpt, dim=-1)
# (1-p_t)^gamma * log(p_t) with reduced chance of overflow
p = F.logsigmoid(-i * (t * 2.0 - 1.0))
loss = torch.mean((p * self.gamma).exp() * ce, dim=-1)

if self.reduction == LossReduction.SUM.value:
return loss.sum()
if self.reduction == LossReduction.NONE.value:
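The modulating factor is computed in log space: `F.logsigmoid(-i * (t * 2.0 - 1.0))` equals log(1 - p_t), so `(p * self.gamma).exp()` recovers (1 - p_t)^gamma without first forming 1 - p_t, which underflows for very confident predictions. A small check of that identity, assuming only `torch`:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
i = torch.randn(2, 3, 16)  # logits
t = torch.randint(0, 2, (2, 3, 16)).float()
gamma = 2.0

# direct form: (1 - p_t)^gamma
prob = torch.sigmoid(i)
p_t = prob * t + (1 - prob) * (1 - t)
direct = (1.0 - p_t) ** gamma

# log-space form used in the forward pass above
log_one_minus_pt = F.logsigmoid(-i * (t * 2.0 - 1.0))
stable = (log_one_minus_pt * gamma).exp()

assert torch.allclose(direct, stable, atol=1e-6)
```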
73 changes: 55 additions & 18 deletions tests/test_focal_loss.py
@@ -22,32 +22,32 @@

class TestFocalLoss(unittest.TestCase):
def test_consistency_with_cross_entropy_2d(self):
# For gamma=0 the focal loss reduces to the cross entropy loss
focal_loss = FocalLoss(to_onehot_y=True, gamma=0.0, reduction="mean", weight=1.0)
ce = nn.CrossEntropyLoss(reduction="mean")
"""For gamma=0 the focal loss reduces to the cross entropy loss"""
focal_loss = FocalLoss(to_onehot_y=False, gamma=0.0, reduction="mean", weight=1.0)
ce = nn.BCEWithLogitsLoss(reduction="mean")
max_error = 0
class_num = 10
batch_size = 128
for _ in range(100):
# Create a random tensor of shape (batch_size, class_num, 8, 4)
x = torch.rand(batch_size, class_num, 8, 4, requires_grad=True)
# Create a random batch of classes
l = torch.randint(low=0, high=class_num, size=(batch_size, 1, 8, 4))
l = torch.randint(low=0, high=2, size=(batch_size, class_num, 8, 4)).float()
if torch.cuda.is_available():
x = x.cuda()
l = l.cuda()
output0 = focal_loss(x, l)
output1 = ce(x, l[:, 0]) / class_num
output1 = ce(x, l)
a = float(output0.cpu().detach())
b = float(output1.cpu().detach())
if abs(a - b) > max_error:
max_error = abs(a - b)
self.assertAlmostEqual(max_error, 0.0, places=3)

def test_consistency_with_cross_entropy_2d_onehot_label(self):
# For gamma=0 the focal loss reduces to the cross entropy loss
focal_loss = FocalLoss(to_onehot_y=False, gamma=0.0, reduction="mean")
ce = nn.CrossEntropyLoss(reduction="mean")
"""For gamma=0 the focal loss reduces to the cross entropy loss"""
focal_loss = FocalLoss(to_onehot_y=True, gamma=0.0, reduction="mean")
ce = nn.BCEWithLogitsLoss(reduction="mean")
max_error = 0
class_num = 10
batch_size = 128
@@ -59,18 +59,18 @@ def test_consistency_with_cross_entropy_2d_onehot_label(self):
if torch.cuda.is_available():
x = x.cuda()
l = l.cuda()
output0 = focal_loss(x, one_hot(l, num_classes=class_num))
output1 = ce(x, l[:, 0]) / class_num
output0 = focal_loss(x, l)
output1 = ce(x, one_hot(l, num_classes=class_num))
a = float(output0.cpu().detach())
b = float(output1.cpu().detach())
if abs(a - b) > max_error:
max_error = abs(a - b)
self.assertAlmostEqual(max_error, 0.0, places=3)

def test_consistency_with_cross_entropy_classification(self):
# for gamma=0 the focal loss reduces to the cross entropy loss
"""for gamma=0 the focal loss reduces to the cross entropy loss"""
focal_loss = FocalLoss(to_onehot_y=True, gamma=0.0, reduction="mean")
ce = nn.CrossEntropyLoss(reduction="mean")
ce = nn.BCEWithLogitsLoss(reduction="mean")
max_error = 0
class_num = 10
batch_size = 128
@@ -84,19 +84,43 @@ def test_consistency_with_cross_entropy_classification(self):
x = x.cuda()
l = l.cuda()
output0 = focal_loss(x, l)
output1 = ce(x, l[:, 0]) / class_num
output1 = ce(x, one_hot(l, num_classes=class_num))
a = float(output0.cpu().detach())
b = float(output1.cpu().detach())
if abs(a - b) > max_error:
max_error = abs(a - b)
self.assertAlmostEqual(max_error, 0.0, places=3)

def test_consistency_with_cross_entropy_classification_01(self):
# For gamma=0.1 the focal loss differs from the binary cross entropy loss
focal_loss = FocalLoss(to_onehot_y=True, gamma=0.1, reduction="mean")
ce = nn.BCEWithLogitsLoss(reduction="mean")
max_error = 0
class_num = 10
batch_size = 128
for _ in range(100):
# Create a random scores tensor of shape (batch_size, class_num)
x = torch.rand(batch_size, class_num, requires_grad=True)
# Create a random batch of classes
l = torch.randint(low=0, high=class_num, size=(batch_size, 1))
l = l.long()
if torch.cuda.is_available():
x = x.cuda()
l = l.cuda()
output0 = focal_loss(x, l)
output1 = ce(x, one_hot(l, num_classes=class_num))
a = float(output0.cpu().detach())
b = float(output1.cpu().detach())
if abs(a - b) > max_error:
max_error = abs(a - b)
self.assertNotAlmostEqual(max_error, 0.0, places=3)

def test_bin_seg_2d(self):
# define 2d examples
target = torch.tensor([[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]])
# add another dimension corresponding to the batch (batch size = 1 here)
target = target.unsqueeze(0) # shape (1, H, W)
pred_very_good = 1000 * F.one_hot(target, num_classes=2).permute(0, 3, 1, 2).float()
pred_very_good = 100 * F.one_hot(target, num_classes=2).permute(0, 3, 1, 2).float() - 50.0

# initialize the focal loss
loss = FocalLoss(to_onehot_y=True)
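The `- 50.0` shift in the "very good" prediction matters because the loss now applies a per-channel sigmoid rather than a softmax: a confident prediction needs large negative logits on the wrong channels, not just a large positive margin on the right one. A one-line illustration:

```python
import torch
# sigmoid saturates towards 1 for +50 and towards 0 for -50,
# so 100 * one_hot - 50 yields near-certain per-channel predictions
print(torch.sigmoid(torch.tensor([50.0, -50.0])))  # approximately [1.0, 0.0]
```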
@@ -112,7 +112,7 @@ def test_empty_class_2d(self):
target = torch.tensor([[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]])
# add another dimension corresponding to the batch (batch size = 1 here)
target = target.unsqueeze(0) # shape (1, H, W)
pred_very_good = 1000 * F.one_hot(target, num_classes=num_classes).permute(0, 3, 1, 2).float()
pred_very_good = 1000 * F.one_hot(target, num_classes=num_classes).permute(0, 3, 1, 2).float() - 500.0

# initialize the focal loss
loss = FocalLoss(to_onehot_y=True)
@@ -128,7 +128,7 @@ def test_multi_class_seg_2d(self):
target = torch.tensor([[0, 0, 0, 0], [0, 1, 2, 0], [0, 3, 4, 0], [0, 0, 0, 0]])
# add another dimension corresponding to the batch (batch size = 1 here)
target = target.unsqueeze(0) # shape (1, H, W)
pred_very_good = 1000 * F.one_hot(target, num_classes=num_classes).permute(0, 3, 1, 2).float()
pred_very_good = 1000 * F.one_hot(target, num_classes=num_classes).permute(0, 3, 1, 2).float() - 500.0
# initialize the focal loss
loss = FocalLoss(to_onehot_y=True)
loss_onehot = FocalLoss(to_onehot_y=False)
@@ -159,7 +183,7 @@ def test_bin_seg_3d(self):
# add another dimension corresponding to the batch (batch size = 1 here)
target = target.unsqueeze(0) # shape (1, H, W, D)
target_one_hot = F.one_hot(target, num_classes=num_classes).permute(0, 4, 1, 2, 3) # test one hot
pred_very_good = 1000 * F.one_hot(target, num_classes=num_classes).permute(0, 4, 1, 2, 3).float()
pred_very_good = 1000 * F.one_hot(target, num_classes=num_classes).permute(0, 4, 1, 2, 3).float() - 500.0

# initialize the focal loss
loss = FocalLoss(to_onehot_y=True)
@@ -173,6 +197,19 @@ def test_bin_seg_3d(self):
focal_loss_good = float(loss_onehot(pred_very_good, target_one_hot).cpu())
self.assertAlmostEqual(focal_loss_good, 0.0, places=3)

def test_foreground(self):
background = torch.ones(1, 1, 5, 5)
foreground = torch.zeros(1, 1, 5, 5)
target = torch.cat((background, foreground), dim=1)
input = torch.cat((background, foreground), dim=1)
target[:, 0, 2, 2] = 0
target[:, 1, 2, 2] = 1

fgbg = FocalLoss(to_onehot_y=False, include_background=True)(input, target)
fg = FocalLoss(to_onehot_y=False, include_background=False)(input, target)
self.assertAlmostEqual(float(fgbg.cpu()), 0.1116, places=3)
self.assertAlmostEqual(float(fg.cpu()), 0.1733, places=3)

def test_ill_opts(self):
chn_input = torch.ones((1, 2, 3))
chn_target = torch.ones((1, 2, 3))
@@ -182,7 +219,7 @@ def test_ill_opts(self):
def test_ill_shape(self):
chn_input = torch.ones((1, 2, 3))
chn_target = torch.ones((1, 3))
with self.assertRaisesRegex(AssertionError, ""):
with self.assertRaisesRegex(ValueError, ""):
FocalLoss(reduction="mean")(chn_input, chn_target)

def test_ill_class_weight(self):
3 changes: 1 addition & 2 deletions tests/test_masked_loss.py
@@ -32,7 +32,7 @@
"to_onehot_y": True,
"reduction": "sum",
},
[(12.105497, 18.805185), (10.636354, 6.3138)],
[(14.538666, 20.191753), (13.17672, 8.251623)],
],
]

@@ -50,7 +50,6 @@ def test_shape(self, input_param, expected_val):
label = torch.randint(low=0, high=2, size=size)
label = torch.argmax(label, dim=1, keepdim=True)
pred = torch.randn(size)
print(label[0, 0, 0])
result = MaskedLoss(**input_param)(pred, label, None)
out = result.detach().cpu().numpy()
checked = np.allclose(out, expected_val[0][0]) or np.allclose(out, expected_val[0][1])