# Focal Loss

In [0]:
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F

## Cross Entropy Loss

To understand Focal Loss, we first need to understand Cross Entropy Loss.

In PyTorch, `CrossEntropyLoss` is `LogSoftmax` followed by `NLLLoss`.

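`NLLLoss` is the negative log-likelihood loss, used for multiclass classification. Averaged over $n$ samples it is

$$
\boldsymbol{\mathcal{L}}=-\frac{1}{n}\sum_{i=1}^{n}\log(\hat{y}^{(i)})
$$

where $\hat{y}^{(i)}$ is the predicted probability of the true class of sample $i$. PyTorch's `NLLLoss` takes the log-probabilities $\log(\hat{y}^{(i)})$ directly as input, which is why it is paired with `LogSoftmax` below.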


#### Code to understand NLL loss (see also https://isaacchanghau.github.io/post/loss_functions/)

In [2]:
def NLLLoss(logs, targets):
    """
    Loop-based reference implementation of the negative log-likelihood loss.

    :args

    logs : log-probabilities, one row per sample (e.g. the output of
        LogSoftmax over dim 1).

    targets : integer class index of the true class for each row of ``logs``.

    :returns
    scalar tensor: minus the mean of the log-probability each sample assigns
    to its target class (same dtype as ``logs``).
    """
    # keep the dtype of `logs` so float64 inputs are not silently truncated
    out = torch.zeros_like(targets, dtype=logs.dtype)
    for i in range(len(targets)):
        # log-probability the model assigned to the true class of sample i
        out[i] = logs[i][targets[i]]
    return -out.mean()


torch.manual_seed(0)  # make the random example reproducible on Restart & Run All

x = torch.randn(3, 5)          # logits: 3 samples, 5 classes
y = torch.tensor([4, 1, 2])    # true class index of each sample (int64)
cross_entropy_loss = torch.nn.CrossEntropyLoss()
log_softmax = torch.nn.LogSoftmax(dim=1)
x_log = log_softmax(x)

nll_loss = torch.nn.NLLLoss()
torch_ce = cross_entropy_loss(x, y)
torch_nll = nll_loss(x_log, y)
custom_nll = NLLLoss(x_log, y)
print("Torch CrossEntropyLoss: ", torch_ce)
print("Torch NLL loss: ", torch_nll)
print("Custom NLL loss: ", custom_nll)

# CrossEntropyLoss == LogSoftmax + NLLLoss, and the loop version must agree with both
assert torch.allclose(torch_ce, torch_nll) and torch.allclose(torch_nll, custom_nll)

Torch CrossEntropyLoss:  tensor(1.8096)
Torch NLL loss:  tensor(1.8096)
Custom NLL loss:  tensor(1.8096)


### Binary Cross Entropy
$$
\boldsymbol{\mathcal{L}} = -y\log(p) - (1-y)\log(1-p)
$$
where $p$ is the predicted probability of the positive class and $y \in \{0, 1\}$ is the true label.

## [Focal Loss](https://arxiv.org/pdf/1708.02002.pdf)

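#### Binary Focal Loss

Focal loss multiplies binary cross entropy by a modulating factor $(1-p_t)^{\gamma}$ and a class weight $\alpha_t$. Here $p_t = p$ if $y = 1$ and $p_t = 1 - p$ otherwise, and $\alpha_t$ is the positive weight for $y = 1$ and the negative weight otherwise:

$$
\boldsymbol{\mathcal{L}} = -\alpha_t (1-p_t)^{\gamma} \log(p_t)
$$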

In [0]:
def binary_focal_loss(y_pred, y_true, gamma=2.0, alpha=0.25, reduction="mean",
                      function=torch.sigmoid, **kwargs):
    """
    Binary version of Focal Loss.

    :args

    y_pred : raw predictions, passed through ``function`` first; when
        ``function`` is None they must already be probabilities in [0, 1].

    y_true : target labels, broadcastable against ``y_pred``. Entries equal
        to 1 are positives and entries equal to 0 are negatives; any other
        value contributes zero loss.

    gamma : focusing (damping) parameter; the paper reports that the default
        2.0 works well. gamma=0 with alpha=[1, 1] reduces to binary cross
        entropy.

    alpha : class-balancing weights. A scalar ``a`` weights positives by
        ``a`` and negatives by ``1 - a`` (default 0.25 -> 0.25 / 0.75). A
        list/tuple ``(pos_alpha, neg_alpha)`` sets both weights explicitly.

    reduction : "mean", "sum" or "none".

    function : activation applied to ``y_pred`` (e.g. torch.sigmoid or
        torch.softmax), or None if ``y_pred`` already holds probabilities.

    **kwargs : extra keyword arguments forwarded to ``function``
        (e.g. ``dim`` for softmax).

    :returns
    scalar tensor for "mean" / "sum"; element-wise loss tensor for "none".

    :raises
    TypeError : if ``alpha`` is not an int, float, list or tuple.
    ValueError : if ``reduction`` is not one of "mean", "sum", "none".
    """
    # validate up front (raising a plain string would itself be a TypeError)
    if reduction not in ("mean", "sum", "none"):
        raise ValueError(f"Wrong reduction {reduction!r}; choose one of ['mean', 'sum', 'none']")

    if isinstance(alpha, (list, tuple)):
        pos_alpha, neg_alpha = alpha[0], alpha[1]  # explicit positive / negative weights
    elif isinstance(alpha, (int, float)):
        pos_alpha, neg_alpha = alpha, 1 - alpha
    else:
        raise TypeError(f"alpha must be an int, float, list or tuple, got {type(alpha).__name__}")

    if function is not None:
        y_pred = function(y_pred, **kwargs)  # apply activation function
    else:
        assert ((y_pred <= 1) & (y_pred >= 0)).all().item(), \
            "y_pred values should be in the range of 0 to 1 inclusive when function is None"

    if not torch.is_floating_point(y_pred):
        y_pred = y_pred.float()
    # Saturated probabilities (e.g. float32 sigmoid(x) == 1.0 for x >~ 17) would
    # give log(0) = -inf; clamping the log arguments keeps the loss and grads finite.
    tiny = torch.finfo(y_pred.dtype).tiny

    # Positives keep y_pred; every other position becomes 1 so that
    # (1 - pt) ** gamma == 0 and log(pt) == 0 there.
    pos_pt = torch.where(y_true == 1, y_pred, torch.ones_like(y_pred))
    # Negatives keep y_pred; every other position becomes 0 so that log(1 - pt) == 0.
    neg_pt = torch.where(y_true == 0, y_pred, torch.zeros_like(y_pred))

    # modulating factors shrink towards 0 for well-classified samples
    pos_modulating = (1 - pos_pt) ** gamma
    neg_modulating = neg_pt ** gamma

    pos = -pos_alpha * pos_modulating * torch.log(pos_pt.clamp(min=tiny))
    neg = -neg_alpha * neg_modulating * torch.log((1 - neg_pt).clamp(min=tiny))

    loss = pos + neg  # element-wise focal loss

    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
        return loss.sum()
    return loss  # reduction == "none"
    

In [4]:
# Sanity check: with gamma=0 and alpha=[1, 1] the focal loss reduces to plain
# binary cross entropy, so both values must agree.
y_pred = torch.randn(32, 10)
y_true = torch.empty(32, 10).random_(2)  # random 0/1 labels
bce = F.binary_cross_entropy_with_logits(y_pred, y_true)
focal = binary_focal_loss(y_pred, y_true, gamma=0, alpha=[1, 1])
assert torch.allclose(bce, focal), (bce, focal)
bce, focal

(tensor(0.8163), tensor(0.8163))

#### Categorical Focal Loss

This categorical focal loss is written specifically for object detectors such as SSD and YOLO. We assume that 75% of the anchors are background (class index 0) and 25% are foreground.

In [0]:
def categorical_focal_loss(y_pred, y_true, gamma=2.0, alpha=(0.25, 0.75), reduction="mean"):
    """
    Categorical version of Focal Loss for detection-style targets.

    Class index 0 is treated as background; every other class index is
    foreground.

    :args

    y_pred : raw class scores (logits) with classes on dim 1, e.g. (N, C).

    y_true : integer class indices, shaped like ``y_pred`` without dim 1,
        e.g. (N,).

    gamma : focusing (damping) parameter; the paper reports that the default
        2.0 works well. gamma=0 with alpha=(1, 1) reduces to cross entropy.

    alpha : ``(fore_alpha, back_alpha)`` as a list/tuple, or a scalar ``a``
        meaning ``(a, 1 - a)``.

    reduction : "mean", "sum" or "none".

    :returns
    scalar tensor for "mean" / "sum"; per-sample loss for "none".

    :raises
    TypeError : if ``alpha`` is not an int, float, list or tuple.
    ValueError : if ``reduction`` is not one of "mean", "sum", "none".
    """
    # validate up front (raising a plain string would itself be a TypeError)
    if reduction not in ("mean", "sum", "none"):
        raise ValueError(f"Wrong reduction {reduction!r}; choose one of ['mean', 'sum', 'none']")

    if isinstance(alpha, (list, tuple)):
        fore_alpha, back_alpha = alpha[0], alpha[1]
    elif isinstance(alpha, (int, float)):
        fore_alpha, back_alpha = alpha, 1 - alpha
    else:
        raise TypeError(f"alpha must be an int, float, list or tuple, got {type(alpha).__name__}")

    y_true = y_true.long()
    # log_softmax avoids softmax underflow, where 0 * log(0) would turn the loss into NaN
    log_probs = F.log_softmax(y_pred, dim=1)
    # log-probability of the true class of every sample (no CPU one-hot matrix needed)
    log_pt = log_probs.gather(1, y_true.unsqueeze(1)).squeeze(1)
    pt = log_pt.exp()

    loss = -((1 - pt) ** gamma) * log_pt  # focal loss of the true class

    # A common method for addressing class imbalance is to weight class 1 by alpha
    # and class -1 by 1 - alpha: background (class 0) gets back_alpha, others fore_alpha.
    alpha_t = torch.full_like(loss, fore_alpha).masked_fill(y_true == 0, back_alpha)
    loss = alpha_t * loss

    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
        return loss.sum()
    return loss  # reduction == "none"
    

In [6]:
# Generating SSD-like data with a random configuration
BATCH_SIZE = 32
NUM_ANCHORS = 8732
NUM_CLASSES = 21            # class 0 is background, 1..20 are foreground
BACKGROUND_FRACTION = 0.75  # share of anchors forced to background

# confidence_prediction: (batch_size, num_anchors, num_classes) raw class scores
confidence_prediction = torch.randn((BATCH_SIZE, NUM_ANCHORS, NUM_CLASSES), requires_grad=True)

# target: (batch_size, num_anchors) integer class index per anchor
target = torch.randint(0, NUM_CLASSES, (BATCH_SIZE, NUM_ANCHORS))

## Reshaping: one row of class scores per anchor
confidence_prediction = confidence_prediction.view(-1, NUM_CLASSES)
print(f"confidence_prediction shape {confidence_prediction.shape}")
target = target.view(-1)
print(f"target shape {target.shape}")

# fill the first 75% of targets with background (0.75 * 279424 = 209568 for these sizes)
num_background = int(BACKGROUND_FRACTION * target.shape[0])
target[:num_background] = 0

F.cross_entropy(confidence_prediction, target), categorical_focal_loss(confidence_prediction, target)

confidence_prediction shape torch.Size([279424, 21])
target shape torch.Size([279424])


(tensor(3.5091, grad_fn=<NllLossBackward>),
 tensor(2.0646, grad_fn=<MeanBackward1>))