# []

In [None]:
# inbuilt 
import os
import sys
import math

# most common
import numpy as np
import matplotlib.pyplot as plt

# pytorch
import torch as tt
import torch.nn as nn
import torch.functional as ff
import torch.distributions as dd
import torch.utils.data as ud

# Report (this does NOT set anything) the seed that torch's default
# random generator was initialised with.
print('Manual-Seed:', tt.initial_seed())


In [2]:
import torch
import torch.nn as nn
# Pin torch's default RNG so the torch.rand / torch.randint calls in later
# cells produce the same tensors on every run.
torch.manual_seed(seed=281703975047300)


class AsymmetricFocalLoss(nn.Module):
    """Asymmetric focal loss for imbalanced classification.

    Classes listed in ``common`` contribute a focal-modulated, down-weighted
    cross-entropy term ``(1 - delta) * (1 - p) ** gamma * CE``; classes
    listed in ``rare`` contribute a plain weighted term ``delta * CE``.
    Classes in neither list contribute nothing. The per-sample terms are
    summed over the selected classes and then averaged.

    Parameters
    ----------
    common : list of int, required
        Indices (along the class axis) of the common classes whose easy,
        high-probability examples are down-weighted.
    rare : list of int, required
        Indices (along the class axis) of the rare classes.
    delta : float, optional
        Weight of the rare-class terms; common-class terms are weighted by
        ``1 - delta``. By default 0.7.
    gamma : float, optional
        Focal exponent applied to the common-class terms; larger values
        down-weight well-classified examples more strongly. By default 2.0.
    epsilon : float, optional
        Predictions are clamped to ``[epsilon, 1 - epsilon]`` so that
        ``log`` never receives 0. By default 1e-07.
    """

    def __init__(self, common, rare, delta=0.7, gamma=2., epsilon=1e-07):
        super().__init__()
        self.delta = delta
        self.gamma = gamma
        self.epsilon = epsilon
        self.common = common
        self.rare = rare

    def forward(self, y_pred, y_labels):
        """Return the mean asymmetric focal loss as a 0-d tensor.

        Parameters
        ----------
        y_pred : torch.Tensor
            Class probabilities with the class axis at dim 1, e.g.
            ``(batch, n_class)`` or ``(batch, n_class, *spatial)``. Assumed
            to already be probabilities (e.g. softmax output); no
            normalisation is applied here.
        y_labels : torch.Tensor or sequence of int
            Integer class labels shaped like ``y_pred`` without dim 1
            (e.g. ``(batch,)``), or with a size-1 dim 1.

        Raises
        ------
        ValueError
            If both ``common`` and ``rare`` are empty.
        """
        device = y_pred.device
        common = torch.as_tensor(self.common, dtype=torch.long, device=device)
        rare = torch.as_tensor(self.rare, dtype=torch.long, device=device)
        if common.numel() + rare.numel() == 0:
            raise ValueError(
                "AsymmetricFocalLoss needs at least one class index in "
                "`common` or `rare`"
            )

        # One-hot encode the labels along the class axis (dim 1) in one
        # vectorised op instead of a per-sample Python loop.
        labels = torch.as_tensor(y_labels, device=device).long()
        if labels.dim() == y_pred.dim() - 1:
            labels = labels.unsqueeze(1)
        y_true = torch.zeros_like(y_pred).scatter_(1, labels, 1.0)

        # Clamp so log() stays finite; CE is non-zero only at the true class.
        y_pred = torch.clamp(y_pred, self.epsilon, 1. - self.epsilon)
        cross_entropy = -y_true * torch.log(y_pred)

        # Common classes: focal modulation suppresses easy examples.
        common_ce = (
            (1 - self.delta)
            * torch.pow(1 - y_pred[:, common], self.gamma)
            * cross_entropy[:, common]
        )
        # Rare classes: plain weighted cross-entropy, no suppression.
        rare_ce = self.delta * cross_entropy[:, rare]

        # Sum over the selected classes (common first, then rare, matching
        # the original stacking order), then average over all other axes.
        loss_per_item = torch.cat([common_ce, rare_ce], dim=1).sum(dim=1)
        return loss_per_item.mean()

y_pred=tensor([[0.5261, 0.4739],
        [0.4097, 0.5903],
        [0.4087, 0.5913],
        [0.3588, 0.6412],
        [0.3381, 0.6619]])
y_labels=tensor([0, 0, 0, 0, 1])


In [4]:
# Binary case: class 0 is treated as common, class 1 as rare.
lossF = AsymmetricFocalLoss(rare=[1], common=[0])
loss = lossF(y_pred=y_pred, y_labels=y_labels)

(loss.shape, loss)

(torch.Size([]), tensor(0.1091))

In [9]:
batch_size = 5
n_class = 7

# Random probability rows (softmax of uniform logits) and random integer labels.
y_pred = torch.rand((batch_size, n_class)).softmax(dim=-1)
y_labels = torch.randint(0, n_class, size=(batch_size,))
print(f'{y_pred=}\n{y_labels=}')

# Even class indices are common, odd ones are rare.
lossF = AsymmetricFocalLoss(
    common=list(range(0, n_class, 2)),
    rare=list(range(1, n_class, 2)),
)
loss = lossF(y_pred, y_labels)

print(f'{loss=}')

y_pred=tensor([[0.1955, 0.1455, 0.0976, 0.1869, 0.1043, 0.1173, 0.1529],
        [0.1613, 0.1635, 0.1121, 0.1290, 0.1571, 0.0993, 0.1777],
        [0.0978, 0.1340, 0.1025, 0.1993, 0.2197, 0.1041, 0.1425],
        [0.1371, 0.1113, 0.1771, 0.1560, 0.0897, 0.1554, 0.1734],
        [0.1960, 0.1890, 0.1403, 0.1076, 0.1714, 0.1079, 0.0878]])
y_labels=tensor([0, 3, 2, 5, 3])
loss=tensor(1.0328)
