# Class-balanced-loss-pytorch

PyTorch implementation ([GitHub](https://github.com/vandit15/Class-balanced-loss-pytorch/tree/master)) of the paper
[Class-Balanced Loss Based on Effective Number of Samples](https://arxiv.org/abs/1901.05555), presented at CVPR'19.

See also: [Medium Article](https://medium.com/@vandit_15/handling-class-imbalanced-data-using-a-loss-specifically-made-for-it-6e58fd65ffab?source=friends_link&sk=ac09ea6061990ead2a2f90e3767ae91f)

## Key Idea:

* The loss targets long-tailed data distributions (i.e., a few classes account for most of the data, while most classes are under-represented).
* As the number of samples increases, the additional benefit of a newly added data point diminishes.
* Data overlap is measured by associating each sample with a small neighboring region rather than a single point.
* The effective number of samples is calculated as $(1-\beta^n)/(1-\beta)$, where $n$ is the number of samples and $\beta \in [0, 1)$ is a hyperparameter.
* A re-weighting scheme uses the effective number of samples for each class to re-balance the loss, thereby yielding a class-balanced loss.
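
The effective-number formula and the weights derived from it can be illustrated with a small NumPy sketch. The helper names `effective_number` and `class_balanced_weights` and the class counts are illustrative assumptions, not part of the repository; the normalization (weights rescaled to sum to the number of classes) mirrors the `CB_loss` code in this README.

```python
import numpy as np

def effective_number(n, beta):
    """Effective number of samples: (1 - beta^n) / (1 - beta)."""
    n = np.asarray(n, dtype=np.float64)
    return (1.0 - np.power(beta, n)) / (1.0 - beta)

def class_balanced_weights(samples_per_cls, beta):
    """Weights proportional to 1 / effective number, rescaled to sum to the class count."""
    weights = 1.0 / effective_number(samples_per_cls, beta)
    return weights / weights.sum() * len(samples_per_cls)

# Hypothetical long-tailed class counts, from the most to the least frequent class.
counts = [5000, 500, 50, 5]
for beta in (0.0, 0.9, 0.999, 0.9999):
    print(beta, np.round(class_balanced_weights(counts, beta), 4))
```

With $\beta = 0$ every class has an effective number of 1, so all weights are equal and no re-weighting takes place. As $\beta \to 1$ the effective number approaches $n$, so the weights approach inverse class frequency.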

In [3]:
"""Pytorch implementation of Class-Balanced-Loss
   Reference: "Class-Balanced Loss Based on Effective Number of Samples"
   Authors: Yin Cui and
               Menglin Jia and
               Tsung-Yi Lin and
               Yang Song and
               Serge J. Belongie
   https://arxiv.org/abs/1901.05555, CVPR'19.
"""

import numpy as np
import torch
import torch.nn.functional as F

def focal_loss(labels, logits, alpha, gamma):
    """Compute the focal loss between `logits` and the ground truth `labels`.

    Focal loss = -alpha_t * (1-pt)^gamma * log(pt)
    where pt is the probability of being classified to the true class.
    pt = p (if true class), otherwise pt = 1 - p. p = sigmoid(logit).

    Args:
      labels: A float tensor of size [batch, num_classes].
      logits: A float tensor of size [batch, num_classes].
      alpha: A float tensor broadcastable to [batch, num_classes]
        specifying per-example weight for balanced cross entropy
        (CB_loss passes a tensor of size [batch, num_classes]).
      gamma: A float scalar modulating loss from hard and easy examples.

    Returns:
      focal_loss: A float32 scalar representing normalized total loss.
    """
    # Per-element sigmoid cross entropy, left unreduced so it can be re-weighted below.
    BCLoss = F.binary_cross_entropy_with_logits(input=logits, target=labels, reduction="none")

    if gamma == 0.0:
        modulator = 1.0
    else:
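        # (1 - pt)^gamma in closed form. softplus(-x) equals log(1 + exp(-x))
        # but does not overflow for large negative logits.
        modulator = torch.exp(-gamma * labels * logits - gamma * F.softplus(-logits))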

    loss = modulator * BCLoss

    weighted_loss = alpha * loss
    focal_loss = torch.sum(weighted_loss)

    focal_loss /= torch.sum(labels)
    return focal_loss

def CB_loss(labels, logits, samples_per_cls, no_of_classes, loss_type, beta, gamma):
    """Compute the Class Balanced Loss between `logits` and the ground truth `labels`.

    Class Balanced Loss: ((1-beta)/(1-beta^n))*Loss(labels, logits)
    where Loss is one of the standard losses used for Neural Networks.

    Args:
      labels: An int64 tensor of size [batch] holding class indices.
      logits: A float tensor of size [batch, no_of_classes].
      samples_per_cls: A python list of size [no_of_classes].
      no_of_classes: total number of classes. int
      loss_type: string. One of "sigmoid", "focal", "softmax".
      beta: float. Hyperparameter for Class balanced loss.
      gamma: float. Hyperparameter for Focal loss.

    Returns:
      cb_loss: A float tensor representing class balanced loss
    """
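    # 1 - beta^n per class; the effective number is this value divided by (1 - beta).
    effective_num = 1.0 - np.power(beta, samples_per_cls)
    weights = (1.0 - beta) / np.array(effective_num)
    # Rescale so that the class weights sum to no_of_classes.
    weights = weights / np.sum(weights) * no_of_classes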

    labels_one_hot = F.one_hot(labels, no_of_classes).float()

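    # Select each example's class weight and expand it to [batch, no_of_classes].
    # The weights are created on the same device as the logits.
    weights = torch.tensor(weights, device=logits.device).float()
    weights = weights.unsqueeze(0)
    weights = weights.repeat(labels_one_hot.shape[0], 1) * labels_one_hot
    weights = weights.sum(1)
    weights = weights.unsqueeze(1)
    weights = weights.repeat(1, no_of_classes)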

    if loss_type == "focal":
        cb_loss = focal_loss(labels_one_hot, logits, weights, gamma)
    elif loss_type == "sigmoid":
        cb_loss = F.binary_cross_entropy_with_logits(input=logits, target=labels_one_hot, weight=weights)
    elif loss_type == "softmax":
        pred = logits.softmax(dim=1)
        cb_loss = F.binary_cross_entropy(input=pred, target=labels_one_hot, weight=weights)
    else:
        raise ValueError(f"Unknown loss_type: {loss_type!r}")
    return cb_loss
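
The modulator in `focal_loss` is written as a closed-form expression in the logits rather than as `(1 - pt)^gamma` directly. The sketch below checks numerically that the two forms agree. The random inputs, the seed, and the variable names are illustrative assumptions.

```python
import torch

torch.manual_seed(0)
gamma = 2.0
logits = torch.randn(4, 3, dtype=torch.float64)
labels = torch.randint(0, 2, (4, 3)).to(torch.float64)

# Closed-form modulator expressed in terms of logits.
modulator = torch.exp(-gamma * labels * logits
                      - gamma * torch.log(1 + torch.exp(-logits)))

# Direct definition: pt = p for positive labels, 1 - p for negative labels.
p = torch.sigmoid(logits)
pt = labels * p + (1.0 - labels) * (1.0 - p)
direct = (1.0 - pt) ** gamma

print(torch.allclose(modulator, direct))
```

The identity follows from `log(1 + exp(-x)) = -log(sigmoid(x))`: for a negative label the modulator reduces to `p^gamma`, and for a positive label the extra factor `exp(-gamma * x)` turns it into `(1 - p)^gamma`.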

In [7]:
no_of_classes = 5
logits = torch.rand(10, no_of_classes).float()
labels = torch.randint(0, no_of_classes, size = (10,))
beta = 0.9999
gamma = 2.0
samples_per_cls = [5,3,1,2,2]
loss_type = "focal"
cb_loss = CB_loss(labels, logits, samples_per_cls, no_of_classes, loss_type, beta, gamma)
print(cb_loss)

tensor(1.8500)
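
The per-example weights that `CB_loss` builds through one-hot encoding and `repeat` can also be obtained by indexing the class-weight vector with the labels. The sketch below reuses `samples_per_cls` and `beta` from the example above and compares the two routes; the batch of labels and the variable names are illustrative assumptions.

```python
import numpy as np
import torch
import torch.nn.functional as F

samples_per_cls = [5, 3, 1, 2, 2]  # same counts as in the example above
beta = 0.9999
no_of_classes = len(samples_per_cls)

# Normalized class weights, computed as at the top of CB_loss.
class_weights = (1.0 - beta) / (1.0 - np.power(beta, samples_per_cls))
class_weights = class_weights / class_weights.sum() * no_of_classes
class_weights = torch.tensor(class_weights, dtype=torch.float32)

labels = torch.tensor([0, 2, 2, 4])  # hypothetical batch of class indices

# One-hot / repeat route, as in CB_loss.
one_hot = F.one_hot(labels, no_of_classes).float()
via_one_hot = (class_weights.unsqueeze(0).repeat(labels.shape[0], 1) * one_hot).sum(1)

# Equivalent lookup by indexing.
via_index = class_weights[labels]

print(class_weights)
print(torch.allclose(via_one_hot, via_index))
```

The class with a single sample (index 2) receives the largest weight, and the class with five samples (index 0) the smallest.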
