In [2]:
#https://medium.com/visionwizard/understanding-focal-loss-a-quick-read-b914422913e7
#https://discuss.pytorch.org/t/is-this-a-correct-implementation-for-focal-loss-in-pytorch/43327/6
#https://saturncloud.io/blog/how-to-use-class-weights-with-focal-loss-in-pytorch-for-imbalanced-multiclass-classification/
import torch
import numpy as np
import pandas as pd 
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


In [5]:
class FocalLoss(nn.Module):

    """
        Class to represent the custom loss function, Focal Loss, computed on raw logits
        for binary (or element-wise multi-label) targets.

        Per element the loss is ``alpha * (1 - p_t) ** gamma * BCE``, where ``BCE`` is the
        binary cross-entropy computed from the logits and ``p_t = exp(-BCE)`` is the
        probability the model assigns to the true label.

        Attributes:
            alpha (float): Constant scaling factor applied to every element (the same value
                is used for positive and negative targets; this is not a per-class alpha_t).
            gamma (float): Focusing parameter; larger values down-weight well-classified
                examples more strongly.
                (https://medium.com/geekculture/everything-about-focal-loss-f2d8ab294133)
            reduction (str): How per-element losses are reduced: "mean" (default), "sum"
                or "none".
    """

    _REDUCTIONS = ("mean", "sum", "none")

    def __init__(self, alpha, gamma, reduction="mean"):

        """
            Initializes the focal loss.

            Args:
                alpha (float): Constant scaling factor applied to every element.
                gamma (float): Focusing parameter that modulates the impact of
                    well-classified examples.
                    (https://medium.com/geekculture/everything-about-focal-loss-f2d8ab294133)
                reduction (str, optional): "mean" (default, the original behavior),
                    "sum", or "none" to return the unreduced per-element loss.

            Raises:
                ValueError: If ``reduction`` is not one of "mean", "sum" or "none".
        """

        super().__init__()
        if reduction not in self._REDUCTIONS:
            raise ValueError(
                f"reduction must be one of {self._REDUCTIONS}, got {reduction!r}"
            )
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, y_pred_logits, y_true):

        """
            Computes the focal loss from Binary Cross Entropy (BCE) on logits.

            Args:
                y_pred_logits (torch.Tensor): Raw (pre-sigmoid) model outputs.
                y_true (torch.Tensor): Binary targets with the same shape as
                    ``y_pred_logits``; integer/bool tensors are cast to the logits' dtype.

            Returns:
                torch.Tensor: The reduced focal loss (a scalar for "mean"/"sum"), or the
                per-element loss tensor when ``reduction == "none"``.
        """

        # BCE-with-logits needs floating targets matching the logits' dtype; casting here
        # also allows integer/bool label tensors to be passed directly.
        y_true = y_true.to(dtype=y_pred_logits.dtype)

        bce = nn.functional.binary_cross_entropy_with_logits(
            y_pred_logits, y_true, reduction="none"
        )

        # p_t = exp(-bce), hence 1 - p_t = -expm1(-bce); expm1 stays accurate when bce is
        # tiny (well-classified examples), where 1 - exp(-bce) would cancel catastrophically.
        one_minus_pt = -torch.expm1(-bce)

        # -log(p_t) is exactly bce, so reuse it instead of computing log(exp(-bce)): for
        # large losses exp(-bce) underflows to 0 and log(0) = -inf would make the loss inf.
        # NOTE(review): with 0 < gamma < 1 the gradient of one_minus_pt ** gamma is unbounded
        # where one_minus_pt == 0 — consider clamping if such gammas are used.
        focal = self.alpha * one_minus_pt ** self.gamma * bce

        if self.reduction == "mean":
            return focal.mean()
        if self.reduction == "sum":
            return focal.sum()
        return focal