From 0690d8da350b724eed6491a7d192def80d65a6af Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Fri, 2 Jul 2021 19:51:25 +0100 Subject: [PATCH 1/5] update focalloss to use sigmoid Signed-off-by: Wenqi Li --- monai/losses/focal_loss.py | 17 +++++---- tests/test_focal_loss.py | 71 +++++++++++++++++++++++++++++--------- tests/test_masked_loss.py | 2 +- 3 files changed, 63 insertions(+), 27 deletions(-) diff --git a/monai/losses/focal_loss.py b/monai/losses/focal_loss.py index 81a564148f..eba85634d8 100644 --- a/monai/losses/focal_loss.py +++ b/monai/losses/focal_loss.py @@ -22,7 +22,7 @@ class FocalLoss(_Loss): """ - Reimplementation of the Focal Loss described in: + Reimplementation of the Focal Loss (with a build-in sigmoid activation) described in: - "Focal Loss for Dense Object Detection", T. Lin et al., ICCV 2017 - "AnatomyNet: Deep learning for fast and fully automated whole‐volume segmentation of head and neck anatomy", @@ -78,7 +78,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: Args: input: the shape should be BNH[WD], where N is the number of classes. The input should be the original logits since it will be transferred by - `F.log_softmax` in the forward function. + a sigmoid in the forward function. target: the shape should be BNH[WD] or B1H[WD], where N is the number of classes. Raises: @@ -117,10 +117,8 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: i = i.reshape(b, n, -1) t = t.reshape(b, n, -1) - # Compute the log proba. - logpt = F.log_softmax(i, dim=1) - # Get the proba - pt = torch.exp(logpt) # B,H*W or B,N,H*W + max_val = (-i).clamp(min=0) + ce = i - i * t + max_val + ((-max_val).exp() + (-i - max_val).exp()).log() if self.weight is not None: class_weight: Optional[torch.Tensor] = None @@ -142,11 +140,12 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: at = class_weight[None, :, None] # N => 1,N,1 at = at.expand((t.size(0), -1, t.size(2))) # 1,N,1 => B,N,H*W # Multiply the log proba by their weights. - logpt = logpt * at + ce = ce * at # Compute the loss mini-batch. 
- weight = torch.pow(-pt + 1.0, self.gamma) - loss = torch.mean(-weight * t * logpt, dim=-1) + p = F.logsigmoid(-i * (t * 2.0 - 1.0)) + loss = torch.mean((p * self.gamma).exp() * ce, dim=-1) + if self.reduction == LossReduction.SUM.value: return loss.sum() if self.reduction == LossReduction.NONE.value: diff --git a/tests/test_focal_loss.py b/tests/test_focal_loss.py index 0c247702cb..1bf119d170 100644 --- a/tests/test_focal_loss.py +++ b/tests/test_focal_loss.py @@ -22,9 +22,9 @@ class TestFocalLoss(unittest.TestCase): def test_consistency_with_cross_entropy_2d(self): - # For gamma=0 the focal loss reduces to the cross entropy loss - focal_loss = FocalLoss(to_onehot_y=True, gamma=0.0, reduction="mean", weight=1.0) - ce = nn.CrossEntropyLoss(reduction="mean") + """For gamma=0 the focal loss reduces to the cross entropy loss""" + focal_loss = FocalLoss(to_onehot_y=False, gamma=0.0, reduction="mean", weight=1.0) + ce = nn.BCEWithLogitsLoss(reduction="mean") max_error = 0 class_num = 10 batch_size = 128 @@ -32,12 +32,12 @@ def test_consistency_with_cross_entropy_2d(self): # Create a random tensor of shape (batch_size, class_num, 8, 4) x = torch.rand(batch_size, class_num, 8, 4, requires_grad=True) # Create a random batch of classes - l = torch.randint(low=0, high=class_num, size=(batch_size, 1, 8, 4)) + l = torch.randint(low=0, high=2, size=(batch_size, class_num, 8, 4)).float() if torch.cuda.is_available(): x = x.cuda() l = l.cuda() output0 = focal_loss(x, l) - output1 = ce(x, l[:, 0]) / class_num + output1 = ce(x, l) a = float(output0.cpu().detach()) b = float(output1.cpu().detach()) if abs(a - b) > max_error: @@ -45,9 +45,9 @@ def test_consistency_with_cross_entropy_2d(self): self.assertAlmostEqual(max_error, 0.0, places=3) def test_consistency_with_cross_entropy_2d_onehot_label(self): - # For gamma=0 the focal loss reduces to the cross entropy loss - focal_loss = FocalLoss(to_onehot_y=False, gamma=0.0, reduction="mean") - ce = nn.CrossEntropyLoss(reduction="mean") + """For gamma=0 the focal loss reduces to the cross entropy loss""" + focal_loss = FocalLoss(to_onehot_y=True, gamma=0.0, reduction="mean") + ce = nn.BCEWithLogitsLoss(reduction="mean") max_error = 0 class_num = 10 batch_size = 128 @@ -59,8 +59,8 @@ def test_consistency_with_cross_entropy_2d_onehot_label(self): if torch.cuda.is_available(): x = x.cuda() l = l.cuda() - output0 = focal_loss(x, one_hot(l, num_classes=class_num)) - output1 = ce(x, l[:, 0]) / class_num + output0 = focal_loss(x, l) + output1 = ce(x, one_hot(l, num_classes=class_num)) a = float(output0.cpu().detach()) b = float(output1.cpu().detach()) if abs(a - b) > max_error: @@ -68,9 +68,9 @@ def test_consistency_with_cross_entropy_2d_onehot_label(self): self.assertAlmostEqual(max_error, 0.0, places=3) def test_consistency_with_cross_entropy_classification(self): - # for gamma=0 the focal loss reduces to the cross entropy loss + """for gamma=0 the focal loss reduces to the cross entropy loss""" focal_loss = FocalLoss(to_onehot_y=True, gamma=0.0, reduction="mean") - ce = nn.CrossEntropyLoss(reduction="mean") + ce = nn.BCEWithLogitsLoss(reduction="mean") max_error = 0 class_num = 10 batch_size = 128 @@ -84,19 +84,43 @@ def test_consistency_with_cross_entropy_classification(self): x = x.cuda() l = l.cuda() output0 = focal_loss(x, l) - output1 = ce(x, l[:, 0]) / class_num + output1 = ce(x, one_hot(l, num_classes=class_num)) a = float(output0.cpu().detach()) b = float(output1.cpu().detach()) if abs(a - b) > max_error: max_error = abs(a - b) 
self.assertAlmostEqual(max_error, 0.0, places=3) + def test_consistency_with_cross_entropy_classification_01(self): + # for gamma=0.1 the focal loss differs from the cross entropy loss + focal_loss = FocalLoss(to_onehot_y=True, gamma=0.1, reduction="mean") + ce = nn.BCEWithLogitsLoss(reduction="mean") + max_error = 0 + class_num = 10 + batch_size = 128 + for _ in range(100): + # Create a random scores tensor of shape (batch_size, class_num) + x = torch.rand(batch_size, class_num, requires_grad=True) + # Create a random batch of classes + l = torch.randint(low=0, high=class_num, size=(batch_size, 1)) + l = l.long() + if torch.cuda.is_available(): + x = x.cuda() + l = l.cuda() + output0 = focal_loss(x, l) + output1 = ce(x, one_hot(l, num_classes=class_num)) + a = float(output0.cpu().detach()) + b = float(output1.cpu().detach()) + if abs(a - b) > max_error: + max_error = abs(a - b) + self.assertNotAlmostEqual(max_error, 0.0, places=3) + def test_bin_seg_2d(self): # define 2d examples target = torch.tensor([[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]]) # add another dimension corresponding to the batch (batch size = 1 here) target = target.unsqueeze(0) # shape (1, H, W) - pred_very_good = 1000 * F.one_hot(target, num_classes=2).permute(0, 3, 1, 2).float() + pred_very_good = 100 * F.one_hot(target, num_classes=2).permute(0, 3, 1, 2).float() - 50.0 # initialize the mean dice loss loss = FocalLoss(to_onehot_y=True) @@ -112,7 +136,7 @@ def test_empty_class_2d(self): target = torch.tensor([[0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]) # add another dimension corresponding to the batch (batch size = 1 here) target = target.unsqueeze(0) # shape (1, H, W) - pred_very_good = 1000 * F.one_hot(target, num_classes=num_classes).permute(0, 3, 1, 2).float() + pred_very_good = 1000 * F.one_hot(target, num_classes=num_classes).permute(0, 3, 1, 2).float() - 500.0 # initialize the mean dice loss loss = FocalLoss(to_onehot_y=True) @@ -128,7 +152,7 @@ def test_multi_class_seg_2d(self): target = torch.tensor([[0, 0, 0, 0], [0, 1, 2, 0], [0, 3, 4, 0], [0, 0, 0, 0]]) # add another dimension corresponding to the batch (batch size = 1 here) target = target.unsqueeze(0) # shape (1, H, W) - pred_very_good = 1000 * F.one_hot(target, num_classes=num_classes).permute(0, 3, 1, 2).float() + pred_very_good = 1000 * F.one_hot(target, num_classes=num_classes).permute(0, 3, 1, 2).float() - 500.0 # initialize the mean dice loss loss = FocalLoss(to_onehot_y=True) loss_onehot = FocalLoss(to_onehot_y=False) @@ -159,7 +183,7 @@ def test_bin_seg_3d(self): # add another dimension corresponding to the batch (batch size = 1 here) target = target.unsqueeze(0) # shape (1, H, W, D) target_one_hot = F.one_hot(target, num_classes=num_classes).permute(0, 4, 1, 2, 3) # test one hot - pred_very_good = 1000 * F.one_hot(target, num_classes=num_classes).permute(0, 4, 1, 2, 3).float() + pred_very_good = 1000 * F.one_hot(target, num_classes=num_classes).permute(0, 4, 1, 2, 3).float() - 500.0 # initialize the mean dice loss loss = FocalLoss(to_onehot_y=True) @@ -173,6 +197,19 @@ def test_bin_seg_3d(self): focal_loss_good = float(loss_onehot(pred_very_good, target_one_hot).cpu()) self.assertAlmostEqual(focal_loss_good, 0.0, places=3) + def test_foreground(self): + background = torch.ones(1, 1, 5, 5) + foreground = torch.zeros(1, 1, 5, 5) + target = torch.cat((background, foreground), dim=1) + input = torch.cat((background, foreground), dim=1) + target[:, 0, 2, 2] = 0 + target[:, 1, 2, 2] = 1 + + fgbg = FocalLoss(to_onehot_y=False, 
include_background=True)(input, target) + fg = FocalLoss(to_onehot_y=False, include_background=False)(input, target) + self.assertAlmostEqual(float(fgbg.cpu()), 0.1116, places=3) + self.assertAlmostEqual(float(fg.cpu()), 0.1733, places=3) + def test_ill_opts(self): chn_input = torch.ones((1, 2, 3)) chn_target = torch.ones((1, 2, 3)) diff --git a/tests/test_masked_loss.py b/tests/test_masked_loss.py index 5a8b2bb68f..261b8131ad 100644 --- a/tests/test_masked_loss.py +++ b/tests/test_masked_loss.py @@ -32,7 +32,7 @@ "to_onehot_y": True, "reduction": "sum", }, - [(12.105497, 18.805185), (10.636354, 6.3138)], + [(14.538666, 13.17672), (13.17672, 6.3138)], ], ] From d74a5619ff7c0b9a4fa8b984a62fbf3e6037acc0 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Fri, 2 Jul 2021 20:01:32 +0100 Subject: [PATCH 2/5] temp tests Signed-off-by: Wenqi Li --- tests/test_masked_loss.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/test_masked_loss.py b/tests/test_masked_loss.py index 261b8131ad..7ef4fbab57 100644 --- a/tests/test_masked_loss.py +++ b/tests/test_masked_loss.py @@ -50,15 +50,16 @@ def test_shape(self, input_param, expected_val): label = torch.randint(low=0, high=2, size=size) label = torch.argmax(label, dim=1, keepdim=True) pred = torch.randn(size) - print(label[0, 0, 0]) result = MaskedLoss(**input_param)(pred, label, None) out = result.detach().cpu().numpy() - checked = np.allclose(out, expected_val[0][0]) or np.allclose(out, expected_val[0][1]) - self.assertTrue(checked) + print(out) + # checked = np.allclose(out, expected_val[0][0]) or np.allclose(out, expected_val[0][1]) + # self.assertTrue(checked) mask = torch.randint(low=0, high=2, size=label.shape) result = MaskedLoss(**input_param)(pred, label, mask) out = result.detach().cpu().numpy() + print(out) checked = np.allclose(out, expected_val[1][0]) or np.allclose(out, expected_val[1][1]) self.assertTrue(checked) From db64d08b8071738c14963f5932d877c1062d6b4c Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Fri, 2 Jul 2021 20:17:33 +0100 Subject: [PATCH 3/5] fixes tests Signed-off-by: Wenqi Li --- tests/test_masked_loss.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/tests/test_masked_loss.py b/tests/test_masked_loss.py index 7ef4fbab57..225e3d9668 100644 --- a/tests/test_masked_loss.py +++ b/tests/test_masked_loss.py @@ -32,7 +32,7 @@ "to_onehot_y": True, "reduction": "sum", }, - [(14.538666, 13.17672), (13.17672, 6.3138)], + [(14.538666, 20.191753), (13.17672, 8.251623)], ], ] @@ -52,14 +52,12 @@ def test_shape(self, input_param, expected_val): pred = torch.randn(size) result = MaskedLoss(**input_param)(pred, label, None) out = result.detach().cpu().numpy() - print(out) - # checked = np.allclose(out, expected_val[0][0]) or np.allclose(out, expected_val[0][1]) - # self.assertTrue(checked) + checked = np.allclose(out, expected_val[0][0]) or np.allclose(out, expected_val[0][1]) + self.assertTrue(checked) mask = torch.randint(low=0, high=2, size=label.shape) result = MaskedLoss(**input_param)(pred, label, mask) out = result.detach().cpu().numpy() - print(out) checked = np.allclose(out, expected_val[1][0]) or np.allclose(out, expected_val[1][1]) self.assertTrue(checked) From dfb331864234dcfdb0040223ba240de2e76f3c82 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Mon, 5 Jul 2021 20:12:05 +0100 Subject: [PATCH 4/5] update based on comments Signed-off-by: Wenqi Li --- monai/losses/focal_loss.py | 6 +++--- tests/test_focal_loss.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git 
a/monai/losses/focal_loss.py b/monai/losses/focal_loss.py index eba85634d8..de98aa3635 100644 --- a/monai/losses/focal_loss.py +++ b/monai/losses/focal_loss.py @@ -77,12 +77,12 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: """ Args: input: the shape should be BNH[WD], where N is the number of classes. - The input should be the original logits since it will be transferred by + The input should be the original logits since it will be transformed by a sigmoid in the forward function. target: the shape should be BNH[WD] or B1H[WD], where N is the number of classes. Raises: - AssertionError: When input and target (after one hot transform if setted) + ValueError: When input and target (after one hot transform if set) have different shapes. ValueError: When ``self.reduction`` is not one of ["mean", "sum", "none"]. ValueError: When ``self.weight`` is a sequence and the length is not equal to the @@ -107,7 +107,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: input = input[:, 1:] if target.shape != input.shape: - raise AssertionError(f"ground truth has different shape ({target.shape}) from input ({input.shape})") + raise ValueError(f"ground truth has different shape ({target.shape}) from input ({input.shape})") i = input t = target diff --git a/tests/test_focal_loss.py b/tests/test_focal_loss.py index 1bf119d170..1314fe3841 100644 --- a/tests/test_focal_loss.py +++ b/tests/test_focal_loss.py @@ -219,7 +219,7 @@ def test_ill_opts(self): def test_ill_shape(self): chn_input = torch.ones((1, 2, 3)) chn_target = torch.ones((1, 3)) - with self.assertRaisesRegex(AssertionError, ""): + with self.assertRaisesRegex(ValueError, ""): FocalLoss(reduction="mean")(chn_input, chn_target) def test_ill_class_weight(self): From 7e066dc794f0efd409aef4017c980a05b05ad3d4 Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Wed, 7 Jul 2021 00:40:07 +0100 Subject: [PATCH 5/5] update based on comments Signed-off-by: Wenqi Li --- monai/losses/focal_loss.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/monai/losses/focal_loss.py b/monai/losses/focal_loss.py index de98aa3635..b4b3698e5b 100644 --- a/monai/losses/focal_loss.py +++ b/monai/losses/focal_loss.py @@ -117,6 +117,8 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: i = i.reshape(b, n, -1) t = t.reshape(b, n, -1) + # computing binary cross entropy with logits + # see also https://github.com/pytorch/pytorch/blob/v1.9.0/aten/src/ATen/native/Loss.cpp#L231 max_val = (-i).clamp(min=0) ce = i - i * t + max_val + ((-max_val).exp() + (-i - max_val).exp()).log() @@ -143,6 +145,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: ce = ce * at # Compute the loss mini-batch. + # (1-p_t)^gamma * log(p_t) with reduced chance of overflow p = F.logsigmoid(-i * (t * 2.0 - 1.0)) loss = torch.mean((p * self.gamma).exp() * ce, dim=-1)
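
Note on the core change: patch 1 drops the `F.log_softmax` formulation and computes the focal loss from raw logits with a per-channel sigmoid, i.e. `-(1 - p_t)^gamma * log(p_t)`, using the numerically stable binary-cross-entropy-with-logits form plus a log-space modulating factor. The self-contained sketch below (plain PyTorch; the helper name `sigmoid_focal_loss` is illustrative and not part of the patch) reproduces that computation and the gamma=0 equivalence with `BCEWithLogitsLoss` that the updated consistency tests assert:

import torch
import torch.nn.functional as F

def sigmoid_focal_loss(logits: torch.Tensor, targets: torch.Tensor, gamma: float = 2.0) -> torch.Tensor:
    """Mean focal loss on raw logits with 0/1 targets (sketch of the patched formulation)."""
    i, t = logits, targets
    # Numerically stable binary cross entropy with logits, matching the form used in the patch
    # (equivalent to F.binary_cross_entropy_with_logits with no reduction).
    max_val = (-i).clamp(min=0)
    ce = i - i * t + max_val + ((-max_val).exp() + (-i - max_val).exp()).log()
    # Focal modulation (1 - p_t) ** gamma computed in log space to reduce overflow:
    # for 0/1 targets, log(1 - p_t) == logsigmoid(-i * (2 * t - 1)).
    log_one_minus_pt = F.logsigmoid(-i * (t * 2.0 - 1.0))
    return torch.mean((log_one_minus_pt * gamma).exp() * ce)

# With gamma=0 the modulating factor is 1, so the result should match BCEWithLogitsLoss,
# which is what the rewritten test_consistency_with_cross_entropy_* tests check.
x = torch.randn(4, 3, 8, 8)
y = torch.randint(0, 2, (4, 3, 8, 8)).float()
assert torch.allclose(
    sigmoid_focal_loss(x, y, gamma=0.0),
    F.binary_cross_entropy_with_logits(x, y),
    atol=1e-5,
)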
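
Note on the test fixture changes: the "very good" predictions move from `1000 * F.one_hot(...)` to `1000 * F.one_hot(...) - 500.0` (or `100 * ... - 50.0`) because, under the old softmax formulation, only the relative magnitude across channels mattered, so logits of 0 vs 1000 already gave probability ~1 for the target channel. With a per-channel sigmoid, a logit of 0 maps to probability 0.5, so non-target channels must be pushed strongly negative for the loss to vanish. A small sketch of the effect (plain PyTorch, independent of MONAI):

import torch
import torch.nn.functional as F

target = torch.tensor([[0.0, 1.0]])        # one-hot target for a single 2-class sample
unshifted = 1000.0 * target                # logits 0 and 1000: fine for softmax, not for sigmoid
shifted = 1000.0 * target - 500.0          # logits -500 and 500: both channels saturate correctly

# Softmax view: both predictions put essentially all mass on class 1.
print(F.softmax(unshifted, dim=1), F.softmax(shifted, dim=1))

# Sigmoid view: the unshifted background logit of 0 gives p=0.5, hence a nonzero BCE.
print(F.binary_cross_entropy_with_logits(unshifted, target))  # ~0.3466 (= log(2) / 2)
print(F.binary_cross_entropy_with_logits(shifted, target))    # ~0.0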