# []

In [1]:
# inbuilt 
import os
import sys
import math

# most common
import numpy as np
import matplotlib.pyplot as plt

# pytorch
import torch as tt
import torch.nn as nn
import torch.functional as ff
import torch.distributions as dd
import torch.utils.data as ud

tt.manual_seed(281703975047300) # manually sets a seed for random sampling creation ops
print('Manual-Seed:', tt.initial_seed()) # current seed for default rng


Manual-Seed: 281703975047300


In [2]:
batch_size = 5
n_class = 7

In [3]:
y_pred = tt.softmax( tt.rand((batch_size, n_class)), dim=-1)
y_pred.shape, y_pred, tt.sum(y_pred, dim=-1)

(torch.Size([5, 7]),
 tensor([[0.1289, 0.1138, 0.1313, 0.1819, 0.1729, 0.1838, 0.0875],
         [0.0941, 0.1342, 0.0978, 0.1929, 0.2000, 0.1198, 0.1612],
         [0.1829, 0.1168, 0.1837, 0.1428, 0.0957, 0.1177, 0.1604],
         [0.1237, 0.1257, 0.1511, 0.1112, 0.1141, 0.2515, 0.1226],
         [0.0992, 0.1894, 0.1155, 0.0908, 0.1005, 0.1942, 0.2103]]),
 tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000]))

In [4]:
y_labels = tt.randint(0, n_class, size=(batch_size,))
y_labels.shape, y_labels

(torch.Size([5]), tensor([4, 4, 5, 2, 4]))

In [5]:
# convert one-hot
y_true = tt.zeros_like(y_pred)
for i in range(batch_size):
    y_true[i, y_labels[i]]=1
y_true.shape, y_true

(torch.Size([5, 7]),
 tensor([[0., 0., 0., 0., 1., 0., 0.],
         [0., 0., 0., 0., 1., 0., 0.],
         [0., 0., 0., 0., 0., 1., 0.],
         [0., 0., 1., 0., 0., 0., 0.],
         [0., 0., 0., 0., 1., 0., 0.]]))

In [6]:
common = [0,2,4,6]
rare = [1,3,5]

In [7]:
class AsymmetricFocalLoss(nn.Module):
    """For Imbalanced datasets
    Parameters
    ----------
    delta : float, optional
        controls weight given to false positive and false negatives, by default 0.25
    gamma : float, optional
        Focal Tversky loss' focal parameter controls degree of down-weighting of easy examples, by default 2.0
    epsilon : float, optional
        clip values to prevent division by zero error
    """
    def __init__(self, delta=0.7, gamma=2., epsilon=1e-07):
        super(AsymmetricFocalLoss, self).__init__()
        self.delta = delta
        self.gamma = gamma
        self.epsilon = epsilon

    def forward(self, y_pred, y_true):
        y_pred = tt.clamp(y_pred, self.epsilon, 1. - self.epsilon)
        cross_entropy = -y_true * tt.log(y_pred)

        print(f'{cross_entropy.shape=}\n{cross_entropy=}')
    
	# Calculate losses separately for each class, only suppressing background class
        back_ce = tt.pow(1 - y_pred[:,common], self.gamma) * cross_entropy[:,common]
        back_ce =  (1 - self.delta) * back_ce
        print(f'{back_ce.shape=}\n{back_ce=}')


        fore_ce = cross_entropy[:,rare]
        fore_ce = self.delta * fore_ce
        print(f'{fore_ce.shape=}\n{fore_ce=}')

        loss_stack = tt.stack([back_ce, fore_ce], axis=-1)
        print(f'{loss_stack.shape=}\n{loss_stack=}')

        loss_sum=tt.sum(loss_stack, axis=-1)
        print(f'{loss_sum.shape=}\n{loss_sum=}')

        loss = tt.mean(loss_sum)

        return loss

In [8]:
lossF = AsymmetricFocalLoss()
loss = lossF(y_pred, y_true)

loss.shape, loss

cross_entropy.shape=torch.Size([5, 7])
cross_entropy=tensor([[0.0000, 0.0000, 0.0000, 0.0000, 1.7552, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 1.6096, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 2.1396, 0.0000],
        [0.0000, 0.0000, 1.8897, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 2.2972, 0.0000, 0.0000]])
back_ce.shape=torch.Size([5, 4])
back_ce=tensor([[0.0000, 0.0000, 0.3602, 0.0000],
        [0.0000, 0.0000, 0.3091, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.4085, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.5576, 0.0000]])
fore_ce.shape=torch.Size([5, 3])
fore_ce=tensor([[0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 1.4977],
        [0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000]])


RuntimeError: stack expects each tensor to be equal size, but got [5, 4] at entry 0 and [5, 3] at entry 1