In [2]:
import torch
import torch.nn.functional as F
import torch.nn as nn

class FocalLoss(nn.Module):
    """Multi-class focal loss computed on top of softmax cross-entropy.

    For every sample the loss is ``(1 - p_t) ** gamma * CE``, where ``CE`` is
    the unreduced cross-entropy and ``p_t = exp(-CE)`` is the probability the
    model assigns to the target class. When ``alpha`` is given, each sample is
    additionally scaled by ``alpha[target]``.

    Args:
        alpha: Optional per-class weights, indexable by class id. May be a
            tensor or a plain Python sequence of numbers (converted to a
            tensor). ``None`` disables class weighting.
        gamma: Focusing exponent applied to ``(1 - p_t)``.
        reduction: ``'mean'`` or ``'sum'`` to reduce over all elements; any
            other value returns the unreduced per-sample loss.
    """

    def __init__(self, alpha=None, gamma=2, reduction='mean'):
        super(FocalLoss, self).__init__()
        if isinstance(alpha, nn.Parameter):
            # Keep learnable weights registered as a parameter, as plain
            # attribute assignment on an nn.Module would.
            self.alpha = alpha
        else:
            if alpha is not None and not torch.is_tensor(alpha):
                # Lists/tuples cannot be indexed by a tensor of targets.
                alpha = torch.as_tensor(alpha, dtype=torch.get_default_dtype())
            # Buffer: follows the module through .to()/.cuda().
            # Non-persistent: state_dict keys stay the same as before.
            self.register_buffer('alpha', alpha, persistent=False)
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        """Return the focal loss for ``inputs`` against ``targets``.

        Args:
            inputs: Raw, unnormalized class scores as accepted by
                ``F.cross_entropy`` (class dimension at index 1).
            targets: Integer class indices; also used to look up ``alpha``.

        Returns:
            A scalar tensor for ``'mean'`` / ``'sum'`` reduction, otherwise
            a tensor of per-sample losses shaped like ``targets``.
        """
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        # exp(-CE) recovers the softmax probability of the target class.
        pt = torch.exp(-ce_loss)
        focal_loss = (1 - pt) ** self.gamma * ce_loss

        if self.alpha is not None:
            # The loss module is often left on CPU while data lives on an
            # accelerator; align devices before indexing by targets.
            alpha_t = self.alpha.to(inputs.device)[targets]
            focal_loss = alpha_t * focal_loss

        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            return focal_loss

In [3]:
# Fix the RNG state so the sampled tensors below are repeatable
torch.manual_seed(42)

# Problem dimensions for the toy example
num_classes = 5
batch_size = 3

# Sample random logits (tracked for autograd) and random integer class labels
inputs = torch.randn((batch_size, num_classes), requires_grad=True)
targets = torch.randint(0, num_classes, size=(batch_size,), dtype=torch.long)

# Build the criterion with one weight per class
focal_loss = FocalLoss(
    alpha=torch.tensor([0.2, 0.3, 0.1, 0.4, 0.5]),
    gamma=2,
    reduction='mean',
)

# Evaluate the criterion on the sampled batch
loss = focal_loss(inputs, targets)

# Report the inputs, the labels and the resulting scalar loss
print("Random Input Tensor:", inputs, sep="\n")
print("\nRandom Target Tensor:", targets, sep="\n")
print("\nFocal Loss:", loss.item(), sep="\n")

Random Input Tensor:
tensor([[ 0.3367,  0.1288,  0.2345,  0.2303, -1.1229],
        [-0.1863,  2.2082, -0.6380,  0.4617,  0.2674],
        [ 0.5349,  0.8094,  1.1103, -1.6898, -0.9890]], requires_grad=True)

Random Target Tensor:
tensor([0, 4, 3])

Focal Loss:
0.8353108763694763
