# []

In [1]:
# inbuilt 
import os
import sys
import math

# most common
import numpy as np
import matplotlib.pyplot as plt

# pytorch
import torch as tt
import torch.nn as nn
import torch.functional as ff
import torch.distributions as dd
import torch.utils.data as ud

tt.manual_seed(281703975047300) # manually sets a seed for random sampling creation ops
print('Manual-Seed:', tt.initial_seed()) # current seed for default rng


Manual-Seed: 281703975047300


In [2]:
batch_size = 5
n_class = 2

In [3]:
y_pred = tt.softmax( tt.rand((batch_size, n_class)), dim=-1)
y_pred.shape, y_pred, tt.sum(y_pred, dim=-1)

(torch.Size([5, 2]),
 tensor([[0.5310, 0.4690],
         [0.4192, 0.5808],
         [0.4847, 0.5153],
         [0.5497, 0.4503],
         [0.5785, 0.4215]]),
 tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000]))

In [4]:
y_labels = tt.randint(0, n_class, size=(batch_size,))
y_labels.shape, y_labels

(torch.Size([5]), tensor([0, 1, 0, 0, 0]))

In [5]:
# convert one-hot
y_true = tt.zeros_like(y_pred)
for i in range(batch_size):
    y_true[i, y_labels[i]]=1
y_true.shape, y_true

(torch.Size([5, 2]),
 tensor([[1., 0.],
         [0., 1.],
         [1., 0.],
         [1., 0.],
         [1., 0.]]))

In [6]:
common = [0]
rare = [1]

In [7]:
class AsymmetricFocalLoss(nn.Module):
    """For Imbalanced datasets
    Parameters
    ----------
    delta : float, optional
        controls weight given to false positive and false negatives, by default 0.25
    gamma : float, optional
        Focal Tversky loss' focal parameter controls degree of down-weighting of easy examples, by default 2.0
    epsilon : float, optional
        clip values to prevent division by zero error
    """
    def __init__(self, delta=0.7, gamma=2., epsilon=1e-07):
        super(AsymmetricFocalLoss, self).__init__()
        self.delta = delta
        self.gamma = gamma
        self.epsilon = epsilon

    def forward(self, y_pred, y_true):
        y_pred = tt.clamp(y_pred, self.epsilon, 1. - self.epsilon)
        cross_entropy = -y_true * tt.log(y_pred)

        print(f'{cross_entropy.shape=}\n{cross_entropy=}')

	# Calculate losses separately for each class, only suppressing background class
        back_ce = tt.pow(1 - y_pred[:,common], self.gamma) * cross_entropy[:,common]
        back_ce =  (1 - self.delta) * back_ce
        print(f'{back_ce.shape=}\n{back_ce=}')


        fore_ce = cross_entropy[:,rare]
        fore_ce = self.delta * fore_ce
        print(f'{fore_ce.shape=}\n{fore_ce=}')

        loss_stack = tt.stack([back_ce, fore_ce], axis=-1)
        print(f'{loss_stack.shape=}\n{loss_stack=}')

        loss_sum=tt.sum(loss_stack, axis=-1)
        print(f'{loss_sum.shape=}\n{loss_sum=}')

        loss = tt.mean(loss_sum)

        return loss

In [8]:
lossF = AsymmetricFocalLoss()
loss = lossF(y_pred, y_true)

loss.shape, loss

cross_entropy.shape=torch.Size([5, 2])
cross_entropy=tensor([[0.6330, 0.0000],
        [0.0000, 0.5433],
        [0.7242, 0.0000],
        [0.5983, 0.0000],
        [0.5473, 0.0000]])
back_ce.shape=torch.Size([5, 1])
back_ce=tensor([[0.0418],
        [0.0000],
        [0.0577],
        [0.0364],
        [0.0292]])
fore_ce.shape=torch.Size([5, 1])
fore_ce=tensor([[0.0000],
        [0.3803],
        [0.0000],
        [0.0000],
        [0.0000]])
loss_stack.shape=torch.Size([5, 1, 2])
loss_stack=tensor([[[0.0418, 0.0000]],

        [[0.0000, 0.3803]],

        [[0.0577, 0.0000]],

        [[0.0364, 0.0000]],

        [[0.0292, 0.0000]]])
loss_sum.shape=torch.Size([5, 1])
loss_sum=tensor([[0.0418],
        [0.3803],
        [0.0577],
        [0.0364],
        [0.0292]])


(torch.Size([]), tensor(0.1091))