In [15]:
import torch
import torch.nn as nn
import numpy as np


class FocalLoss(nn.Module):
    """Focal loss computed against soft (probability-vector) targets.

    Per sample the loss is ``-sum_c alpha_c * (1 - p_c)**gamma * t_c * log(p_c)``
    where ``p = softmax(pred, dim=1)`` and ``t`` is the target distribution;
    the result is averaged over the batch. In ``mode="kl"`` the term
    ``t_c * log(p_c)`` is replaced by ``t_c * (log(p_c) - log(t_c))``, which
    gives a focal-weighted KL divergence. ``p``, ``1 - p`` and ``t`` are all
    clamped to ``[eps, 1 - eps]`` before taking logarithms.

    With ``gamma=0``, uniform ``alpha`` and no value hitting the clamp, the
    "class" mode equals soft-target cross entropy and the "kl" mode equals
    ``KLDivLoss(reduction="batchmean")``.

    Args:
        gamma: Focusing exponent applied to ``1 - p``.
        num_classes: Number of classes; only used to build a uniform
            ``alpha`` when ``alpha`` is None.
        alpha: Optional per-class weights (tensor or sequence). Defaults to
            all ones.
        max_batch_size: Retained for backward compatibility. It no longer
            limits the batch size because ``alpha`` is broadcast instead of
            being pre-tiled.
        eps: Clamp bound that keeps logarithms finite.
        mode: ``"kl"`` selects the KL variant; any other value selects the
            cross-entropy ("class") variant.
    """

    def __init__(
        self,
        gamma=0.0,
        num_classes=6,
        alpha=None,
        max_batch_size=256,
        eps=1e-4,
        mode="class",
    ):
        super().__init__()
        self.gamma = gamma
        if alpha is None:
            alpha = torch.ones(num_classes)
        else:
            alpha = torch.as_tensor(alpha)
        # Shape (1, C) broadcasts against any batch size. Non-persistent so
        # the state_dict stays identical to the previous plain-attribute
        # implementation (existing checkpoints keep loading with strict=True).
        self.register_buffer("alpha", alpha.reshape(1, -1), persistent=False)
        self.max_batch_size = max_batch_size
        # Not used by forward(); kept so code that accesses it keeps working.
        self.kl_loss = nn.KLDivLoss(reduction="batchmean")
        self.eps = eps
        self.mode = mode

    def forward_log_alpha(self, x1, x2):
        """Return the element-wise product ``x1 * log(x2)``."""
        return x1 * torch.log(x2)

    def forward(self, pred, target):
        """Compute the batch-averaged focal loss.

        Args:
            pred: Raw logits; softmax and the class sum are taken over
                ``dim=1`` (assumed shape ``(batch, num_classes)``).
            target: Target probabilities with the same shape as ``pred``.

        Returns:
            Scalar tensor holding the mean per-sample loss.
        """
        probs = torch.nn.functional.softmax(pred, dim=1)
        probs = torch.clamp(probs, self.eps, 1 - self.eps)
        inv_probs = torch.clamp(1 - probs, self.eps, 1 - self.eps)
        target = torch.clamp(target, self.eps, 1 - self.eps)

        log_term = self.forward_log_alpha(target, probs)
        if self.mode == "kl":
            # Subtracting t * log(t) turns cross entropy into KL divergence.
            log_term = log_term - self.forward_log_alpha(target, target)

        # The buffer follows module.to(...), but move it explicitly as well so
        # a module left on CPU still works with inputs on another device.
        alpha = self.alpha.to(pred.device)
        loss = -(inv_probs**self.gamma) * alpha * log_term
        return torch.mean(torch.sum(loss, dim=1))

In [16]:
# Toy batch of two samples over six classes: uniform soft targets, and
# logits that are uniform except for one boosted entry in the first row.
pred = torch.ones(2, 6) / 6
target = torch.ones(2, 6) / 6
pred[0, 0] = 1

# Validation of the focal loss: with the default gamma=0 and uniform alpha it
# must reduce to plain cross entropy on soft targets.
cross = nn.CrossEntropyLoss()
cross_entropy_ref = cross(pred, target)
print(cross_entropy_ref)
focal_loss = FocalLoss(mode="class")
focal_class_value = focal_loss(pred, target)
print(focal_class_value)
assert torch.allclose(
    focal_class_value, cross_entropy_ref, atol=1e-5
), "FocalLoss(mode='class') should match CrossEntropyLoss when gamma=0"

# Validation of the focal kl-divergence: same reduction argument, checked
# against KLDivLoss with batchmean reduction.
kl_div = nn.KLDivLoss(reduction="batchmean")
kl_div_ref = kl_div(torch.nn.functional.log_softmax(pred, dim=1), target)
print(kl_div_ref)
focal_loss = FocalLoss(mode="kl")
focal_kl_value = focal_loss(pred, target)
print(focal_kl_value)
assert torch.allclose(
    focal_kl_value, kl_div_ref, atol=1e-5
), "FocalLoss(mode='kl') should match KLDivLoss(batchmean) when gamma=0"

tensor(1.8204)
tensor(1.8204)
tensor(0.0287)
tensor(0.0287)
