In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    def __init__(self, alpha=1, gamma=2, reduction='mean'):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)
        focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss

        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            return focal_loss



In [8]:
pred = torch.randn(4, 2)  # batch, channel, x, y, z
y = torch.LongTensor([1, 1, 0, 1])


In [9]:
floss = FocalLoss()
celoss = torch.nn.CrossEntropyLoss()
floss(pred, y), celoss(pred, y)

(tensor(0.7236), tensor(0.7236))