In [1]:
# reference: https://github.com/facebookresearch/fvcore/blob/main/fvcore/nn/focal_loss.py

In [2]:
import numpy as np
import torch
from torch.nn import functional as F

In [3]:
batch_size = 32
class_num = 128
# Fix the random seeds so that Restart & Run All reproduces the same inputs,
# labels and outputs. np.random.seed is last so the cell displays nothing.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

In [4]:
inputs = np.random.randn(batch_size,class_num)
inputs = torch.tensor(inputs,dtype=torch.float32)

In [5]:
inputs

tensor([[-1.9175, -1.3873, -0.1195,  ...,  2.2159,  0.8280, -0.5806],
        [-0.3103, -0.5549, -0.7568,  ...,  0.7188, -1.1780, -1.0346],
        [-2.5204,  0.1631, -1.3644,  ...,  0.7163,  1.5505,  1.0744],
        ...,
        [-1.3316,  0.0582,  1.5016,  ..., -0.4068,  1.8959,  0.0139],
        [-0.0294, -0.6537, -0.1888,  ..., -0.0366, -0.7971, -1.3894],
        [ 0.0738, -0.5909, -0.9811,  ..., -0.1257, -0.5891, -2.3516]])

In [6]:
labels = [np.random.randint(0,128) for _ in range(batch_size)]

In [7]:
targets = np.zeros(inputs.shape)

In [8]:
for bid in range(batch_size):
    targets[bid][labels[bid]] =1 

In [9]:
targets = torch.tensor(targets,dtype=torch.float32)

In [10]:
targets.shape

torch.Size([32, 128])

In [11]:
ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")

In [12]:
# echo cross binary loss for each class
ce_loss

tensor([[0.1371, 0.2229, 0.6352,  ..., 2.3194, 1.1905, 0.4444],
        [0.5500, 0.4537, 0.3847,  ..., 1.1158, 0.2684, 0.3041],
        [0.0774, 0.7780, 0.2276,  ..., 1.1141, 1.7429, 1.3682],
        ...,
        [0.2343, 0.7227, 1.7027,  ..., 0.5103, 2.0358, 0.7001],
        [0.6786, 0.4188, 0.6032,  ..., 0.6750, 0.3720, 1.6119],
        [0.7307, 0.4407, 0.3184,  ..., 0.6323, 0.4414, 0.0909]])

In [13]:
# probality to be true
probalities = torch.sigmoid(inputs)

In [14]:
probalities

tensor([[0.1281, 0.1998, 0.4702,  ..., 0.9017, 0.6959, 0.3588],
        [0.4230, 0.3647, 0.3193,  ..., 0.6723, 0.2354, 0.2622],
        [0.0744, 0.5407, 0.2035,  ..., 0.6718, 0.8250, 0.7454],
        ...,
        [0.2089, 0.5145, 0.8178,  ..., 0.3997, 0.8694, 0.5035],
        [0.4927, 0.3422, 0.4529,  ..., 0.4909, 0.3106, 0.1995],
        [0.5184, 0.3564, 0.2727,  ..., 0.4686, 0.3568, 0.0869]])

In [15]:
# probablities to be correct
probalities_correct = probalities * targets + (1 - probalities) * (1 - targets)

<img src="./../focal_loss.png" width=500 height=500>

In [16]:
gamma = 2

In [17]:
loss = ce_loss * ((1 - probalities_correct) ** gamma)

In [18]:
probalities

tensor([[0.1281, 0.1998, 0.4702,  ..., 0.9017, 0.6959, 0.3588],
        [0.4230, 0.3647, 0.3193,  ..., 0.6723, 0.2354, 0.2622],
        [0.0744, 0.5407, 0.2035,  ..., 0.6718, 0.8250, 0.7454],
        ...,
        [0.2089, 0.5145, 0.8178,  ..., 0.3997, 0.8694, 0.5035],
        [0.4927, 0.3422, 0.4529,  ..., 0.4909, 0.3106, 0.1995],
        [0.5184, 0.3564, 0.2727,  ..., 0.4686, 0.3568, 0.0869]])

In [19]:
ce_loss

tensor([[0.1371, 0.2229, 0.6352,  ..., 2.3194, 1.1905, 0.4444],
        [0.5500, 0.4537, 0.3847,  ..., 1.1158, 0.2684, 0.3041],
        [0.0774, 0.7780, 0.2276,  ..., 1.1141, 1.7429, 1.3682],
        ...,
        [0.2343, 0.7227, 1.7027,  ..., 0.5103, 2.0358, 0.7001],
        [0.6786, 0.4188, 0.6032,  ..., 0.6750, 0.3720, 1.6119],
        [0.7307, 0.4407, 0.3184,  ..., 0.6323, 0.4414, 0.0909]])

In [20]:
ce_loss.mean()

tensor(0.8169)

In [21]:
loss

tensor([[2.2517e-03, 8.9038e-03, 1.4042e-01,  ..., 1.8857e+00, 5.7659e-01,
         5.7210e-02],
        [9.8429e-02, 6.0354e-02, 3.9232e-02,  ..., 5.0440e-01, 1.4876e-02,
         2.0903e-02],
        [4.2868e-04, 2.2743e-01, 9.4257e-03,  ..., 5.0282e-01, 1.1862e+00,
         7.6028e-01],
        ...,
        [1.0225e-02, 1.9133e-01, 1.1388e+00,  ..., 8.1509e-02, 1.5388e+00,
         1.7748e-01],
        [1.6470e-01, 4.9026e-02, 1.2374e-01,  ..., 1.6264e-01, 3.5896e-02,
         1.0329e+00],
        [1.9639e-01, 5.5984e-02, 2.3669e-02,  ..., 1.3884e-01, 5.6204e-02,
         6.8737e-04]])

In [22]:
loss = loss.mean()

In [23]:
loss

tensor(0.3583)

In [None]:
def sigmoid_focal_loss(
    inputs: torch.Tensor,
    targets: torch.Tensor,
    alpha: float = -1,
    gamma: float = 2,
    reduction: str = "none",
) -> torch.Tensor:
    """
    Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.

    For binary targets this computes, element-wise,
    FL(p_t) = -alpha_t * (1 - p_t) ** gamma * log(p_t),
    where p = sigmoid(inputs) and p_t = p for positive targets, 1 - p for negative ones.

    Args:
        inputs: A float tensor of arbitrary shape.
                The raw predictions (logits, before sigmoid) for each example.
        targets: A float tensor with the same shape as inputs. Stores the binary
                 classification label for each element in inputs
                 (0 for the negative class and 1 for the positive class).
        alpha: (optional) Weighting factor in range (0,1) to balance
               positive vs negative examples. Any negative value
               (default -1) disables the weighting.
        gamma: Exponent of the modulating factor (1 - p_t) to
               balance easy vs hard examples. gamma = 0 gives plain
               binary cross-entropy.
        reduction: 'none' | 'mean' | 'sum'
                 'none': No reduction will be applied to the output.
                 'mean': The output will be averaged.
                 'sum': The output will be summed.
    Returns:
        Loss tensor with the reduction option applied.
    Raises:
        ValueError: If ``reduction`` is not one of 'none', 'mean', 'sum'.
    """
    # Validate up front so a typo such as "avg" fails loudly instead of
    # silently returning the unreduced loss.
    if reduction not in ("none", "mean", "sum"):
        raise ValueError(
            f"Invalid value for 'reduction': {reduction!r}; "
            "expected one of 'none', 'mean', 'sum'."
        )

    inputs = inputs.float()
    targets = targets.float()
    p = torch.sigmoid(inputs)
    # BCE-with-logits is the numerically stable form of -log(p_t) for binary targets.
    ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    # p_t: probability assigned to the ground-truth label of each element.
    p_t = p * targets + (1 - p) * (1 - targets)
    # Modulating factor down-weights well-classified elements (p_t close to 1).
    loss = ce_loss * ((1 - p_t) ** gamma)

    if alpha >= 0:
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss

    if reduction == "mean":
        loss = loss.mean()
    elif reduction == "sum":
        loss = loss.sum()

    return loss
