In [None]:


loss_fn = torch.nn.CrossEntropyLoss(reduction='none')
loss_fn(y_pred, y)

tensor([0.3992, 0.3992])

In [8]:
soft = torch.softmax(y_pred, dim=1)
soft

tensor([[0.0479, 0.6709, 0.2233, 0.0377, 0.0203]])

In [10]:
log = torch.log(soft)
-log

tensor([[3.0392, 0.3992, 1.4992, 3.2792, 3.8992]])

In [11]:
import torch
import torch.nn.functional as F

def focal_loss_multiclass(inputs, targets, alpha=1, gamma=2):
    """
    Multi-class focal loss implementation
    - inputs: raw logits from the model
    - targets: true class labels (as integer indices, not one-hot encoded)
    """
    # Convert logits to log probabilities
    log_prob = F.log_softmax(inputs, dim=-1)
    prob = torch.exp(log_prob)  # Calculate probabilities from log probabilities

    # Gather the probabilities corresponding to the correct classes
    targets_one_hot = F.one_hot(targets, num_classes=inputs.shape[-1])
    pt = torch.sum(prob * targets_one_hot, dim=-1)

    # Apply focal adjustment
    focal_loss = -alpha * (1 - pt) ** gamma * torch.sum(log_prob * targets_one_hot, dim=-1)
    
    return focal_loss.mean()

In [12]:
focal_loss_multiclass(y_pred, y)

tensor(0.0432)

In [179]:
import torch

alpha = torch.tensor([0.1, 1, 0.5, 0.7, 0.7])

y_pred = torch.tensor([[0.56, 3.2, 2.1, 0.32, -0.3], [0.56, 3.2, 2.1, 0.32, -0.3]])
y_class = torch.tensor([1, 1])
y_ord = torch.tensor([[1, 1, 1, 1, 0], [1, 1, 1, 1, 1]]).float()

In [180]:
print(y_ord.to(dtype=torch.int64))
print(torch.sum(y_ord.to(dtype=torch.int64), dim=1) - 1)

tensor([[1, 1, 1, 1, 0],
        [1, 1, 1, 1, 1]])
tensor([3, 4])


In [189]:
class FocalLoss(torch.nn.Module):
    def __init__(self, alpha, gamma, headType=None):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.headType = headType
    def forward(self, inputs, targets):
        
        if self.headType == 'classification':
            ce_loss = F.cross_entropy(inputs, targets, reduction='none')
            # apply class weights
            # alpha is for the weight class, weights is the correct weight for each class and looks like this: tensor([a, b, c, d, a, ...])
            weights = self.alpha.gather(0, targets)

        if self.headType == 'ordinal':
            ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
            # apply class weights
            # alpha is for the weight class, weights is the correct weight for each class and looks like this: tensor([a, b, c, d, a, ...])
            weights = self.alpha

        pt = torch.exp(-ce_loss)
        print('pt is: ', pt)
        print('ce_loss', ce_loss)
        loss = (weights * ((1 - pt) ** self.gamma) * ce_loss)
        return loss.mean()

In [190]:
loss_fn = FocalLoss(headType='classification', alpha=alpha, gamma=1)
loss_fn(y_pred, y_class)

pt is:  tensor([0.6709, 0.6709])
ce_loss tensor([0.3992, 0.3992])


tensor(0.1314)

In [191]:
loss_fn = FocalLoss(headType='ordinal', alpha=alpha, gamma=0)
loss_fn(y_pred, y_ord)

pt is:  tensor([[0.6365, 0.9608, 0.8909, 0.5793, 0.5744],
        [0.6365, 0.9608, 0.8909, 0.5793, 0.4256]])
ce_loss tensor([[0.4518, 0.0400, 0.1155, 0.5459, 0.5544],
        [0.4518, 0.0400, 0.1155, 0.5459, 0.8544]])


tensor(0.2036)

In [201]:
y_logits = torch.tensor([[0.56, 3.2, 2.1, 0.32, -0.3], [0.56, 3.2, 2.1, 0.32, -0.3], [2, 2, 3, 1, -5]])
y_class = torch.tensor([1, 1])
y_ord = torch.tensor([[1, 1, 1, 1, 0], [1, 1, 1, 1, 1], [1, 1, 1, 1, 0]]).float()

print(F.binary_cross_entropy_with_logits(y_logits, y_ord, reduction='none'))
print(F.binary_cross_entropy_with_logits(y_logits, y_ord, reduction='mean'))

tensor([[0.4518, 0.0400, 0.1155, 0.5459, 0.5544],
        [0.4518, 0.0400, 0.1155, 0.5459, 0.8544],
        [0.1269, 0.1269, 0.0486, 0.3133, 0.0067]])
tensor(0.2892)


In [207]:
from torch import tensor
from torchmetrics.classification import MulticlassCohenKappa
target = tensor([2, 1, 0, 1])
preds = tensor([2, 1, 0, 0])
metric = MulticlassCohenKappa(num_classes=3)
metric(preds, target)

tensor(0.6364)