# Using Sampler To Combat Class Imbalance

## Quick Answer

If you want an equal split, check out the amazing repository https://github.com/ncullen93/torchsample by @ncullen93, specifically the StratifiedSampler at https://github.com/ncullen93/torchsample/blob/master/torchsample/samplers.py#L22. I got this link from the discussion at https://discuss.pytorch.org/t/how-to-enable-the-dataloader-to-sample-from-each-class-with-equal-probability/911.

## What is Class Imbalance?

Class imbalance is a common problem in machine learning. It arises when the data is heavily skewed: we have lots of data for one class and very little for another. For example, suppose there is a small object in every image and the task is to find it. A starting point would be to divide the image into small patches and, for every patch, decide whether the object is present or not. Call the two classes foreground and background. When you divide the image into patches for training, the number of patches containing the small object could be an order of magnitude smaller than the number of background patches.

## Why is Class Imbalance a problem?

Given that we have a class imbalance, why should it be a problem? First, note that the neural network tries to minimize the loss function. So what happens when you use a simple cross entropy loss? Suppose that in any given batch (say a batch of 10), 9 samples are background and 1 is foreground. If the network predicts every sample as background, it still gets a small loss; this can correspond to a local minimum that the network might find difficult to escape.
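
To make the batch-of-10 argument concrete, here is a small sketch in plain Python. The 9:1 split comes from the text; the predicted probabilities (0.5 and 0.1) and the helper `mean_cross_entropy` are assumptions chosen purely for illustration.

```python
import math

def mean_cross_entropy(p_foreground, labels):
    """Mean binary cross entropy when every sample gets the same predicted
    foreground probability (labels: 1 = foreground, 0 = background)."""
    total = 0.0
    for y in labels:
        # Probability the model assigns to the true class of this sample.
        p = p_foreground if y == 1 else 1.0 - p_foreground
        total += -math.log(p)
    return total / len(labels)

# One batch of 10: 9 background samples, 1 foreground sample.
batch = [0] * 9 + [1]

# An undecided network (50/50 on every sample).
loss_uniform = mean_cross_entropy(0.5, batch)

# A network that has only learned "almost everything is background".
loss_background = mean_cross_entropy(0.1, batch)

print(f"uniform guess:           {loss_uniform:.3f}")
print(f"mostly-background guess: {loss_background:.3f}")
```

The constant mostly-background guess roughly halves the loss without looking at the input at all, which is why it is an attractive place for the optimizer to settle.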

## How to combat Class Imbalance?

There are three easy ways:

1. Undersample the background class. This is the easiest of all: simply throw away a large number of background samples.
2. Use class weights. In PyTorch you can pass a parameter called `weight` to the cross entropy loss, with values in inverse proportion to the class distribution. This penalizes the network more when it gets the foreground class wrong.
3. Use a sampler, which will pick samples from every class equally (or at a fixed ratio). This is the method this notebook focuses on.
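
Option 2 could look roughly like the sketch below. The class counts (900 background, 100 foreground) and the dummy logits are made up for illustration, and normalising the weights to sum to one is just one possible choice of scale.

```python
import torch
import torch.nn as nn

# Hypothetical class counts: 900 background (class 0), 100 foreground (class 1).
class_counts = torch.tensor([900.0, 100.0])

# Inverse-frequency weights, normalised to sum to 1 (the scale is a free choice).
class_weights = 1.0 / class_counts
class_weights = class_weights / class_weights.sum()

# `weight` takes one entry per class, not one per sample.
criterion = nn.CrossEntropyLoss(weight=class_weights)

# Dummy logits for a batch of 4 and their labels, just to show the call.
logits = torch.tensor([[2.0, 0.5], [1.5, 0.2], [0.3, 1.0], [2.2, 0.1]])
targets = torch.tensor([0, 0, 1, 1])
loss = criterion(logits, targets)
print(class_weights, loss.item())
```

With these counts the foreground class ends up weighted nine times more heavily than the background class.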

## Understanding a few mechanics of Sampler

As usual, we start with the dogs and cats dataset.

In [1]:
import matplotlib
matplotlib.use('Agg')

In [2]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [3]:
from fastai.imports import *
from fastai.transforms import *
from fastai.conv_learner import *
from fastai.model import *
from fastai.dataset import *
from fastai.sgdr import *
from fastai.plots import *

In [5]:
PATH = "../data/dogscats/"
sz=224

In [9]:
arch=resnet34
data = ImageClassifierData.from_paths(PATH, tfms=tfms_from_model(arch, sz), bs=16)

In [None]:
??torch.multinomial

Before moving on to the mechanics, let's see what is inside `WeightedRandomSampler` (note this is the same as https://github.com/pytorch/pytorch/blob/master/torch/utils/data/sampler.py#L74).

The important thing to note here is that `weights` does not hold one weight per class (`class_weights`); rather, it holds one weight for each index in the dataset.
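
The difference between per-class and per-index weights can be sketched in plain Python. The `labels` list here is a hypothetical toy dataset, and inverse class frequency is only one possible weighting scheme.

```python
from collections import Counter

# Hypothetical labels for a tiny imbalanced dataset (0 = background, 1 = foreground).
labels = [0, 0, 0, 0, 0, 0, 0, 0, 1, 1]

# Per-class weight: inverse of the class frequency.
counts = Counter(labels)
class_weights = {cls: 1.0 / n for cls, n in counts.items()}

# Per-index weight: look up the class weight of each sample's label.
# This list has one entry per dataset index, which is what the sampler expects.
sample_weights = [class_weights[y] for y in labels]

# The total weight per class is now equal, so each class is drawn about equally often.
weight_per_class = {
    cls: sum(w for w, y in zip(sample_weights, labels) if y == cls)
    for cls in counts
}
print(len(sample_weights), weight_per_class)
```

Note that `sample_weights` has the same length as the dataset, while `class_weights` has only one entry per class.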

In [6]:
class WeightedRandomSampler(object):
    r"""Samples elements from [0,..,len(weights)-1] with given probabilities (weights).

    Arguments:
        weights (list)   : a list of weights, not necessarily summing up to one
        num_samples (int): number of samples to draw
        replacement (bool): if ``True``, samples are drawn with replacement.
            If not, they are drawn without replacement, which means that when a
            sample index is drawn for a row, it cannot be drawn again for that row.
    """

    def __init__(self, weights, num_samples, replacement=False):
        self.weights = T(weights)  # T: fastai helper that converts the list to a torch tensor
        self.num_samples = num_samples
        self.replacement = replacement

    def __iter__(self):
        return iter(torch.multinomial(self.weights, self.num_samples, self.replacement))

    def __len__(self):
        return self.num_samples
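
All of the sampling work in the class above is done by `torch.multinomial`. The following is a minimal sketch of how that call behaves on its own; the four weights are arbitrary values chosen for illustration.

```python
import torch

torch.manual_seed(0)  # only to make the sketch repeatable

# One weight per index; they need not sum to one.
weights = torch.tensor([0.25, 0.25, 0.75, 0.75])

# With replacement: indices may repeat, so we can draw more samples than indices.
with_repl = torch.multinomial(weights, 8, replacement=True)

# Without replacement: each index appears at most once,
# so num_samples cannot exceed the number of non-zero weights.
without_repl = torch.multinomial(weights, 4, replacement=False)

print(with_repl.tolist())
print(sorted(without_repl.tolist()))
```

The returned tensor holds dataset indices, which is why the sampler can simply wrap it in `iter(...)`.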

A few convenience functions to be used later. `check_ws` computes the class distribution given the weighted sampler `ws` and the dataset `ds`. `check_sampler` does the same given only the dataloader `dl`.

In [13]:
def check_ws(ws, ds):
    """Print the fraction of label-0 and label-1 samples drawn by sampler `ws` from dataset `ds`."""
    n0 = 0
    n1 = 0
    tot = 0

    for w in tqdm(ws):
        d = ds[w]
        d1 = d[1]
        if d1 == 0:
            n0 += 1
        elif d1 == 1:
            n1 += 1
        tot += 1
    assert tot == n0 + n1
    print(n0/tot)
    print(n1/tot)

In [11]:
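def check_sampler(dl):
    """Print the fraction of label-0 and label-1 samples seen in one pass over `dl`."""
    c0 = 0
    c1 = 0
    tot = 0
    for x, y in tqdm(dl):
        ys = y.shape[0]
        num1 = torch.sum(y)  # labels are 0/1, so the sum counts the label-1 samples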
        c0 += ys - num1
        c1 += num1
        tot += ys
    assert tot == c0 + c1
    print(c0 / tot)
    print(c1 / tot)

In [12]:
check_sampler(data.trn_dl)

100%|██████████| 1438/1438 [00:19<00:00, 73.21it/s]
0.5
0.5


As expected, the numbers of cats and dogs are the same.

In [14]:
ds = data.trn_dl.dataset

In [15]:
len(ds)

23000

We will now assign a weight to each index.

In [16]:
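wl2i = []  # one weight per dataset index

l2p = [0.25, 0.75]  # weight for label 0 and label 1, respectively

for ind, d in enumerate(tqdm(ds)):
    wl2i.append(l2p[d[1]])  # d[1] is the label of sample `ind`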

100%|██████████| 23000/23000 [01:19<00:00, 287.83it/s]


Now we will make a weighted sampler with this list. `replacement` can be set to either `True` or `False`. With `True`, the same index can be drawn more than once, which is how samples of an under-represented class get repeated. The latter (`False`) is also a bit slower.

In [17]:
ws2 = WeightedRandomSampler(wl2i, len(ds)//2, replacement=True)

In [18]:
check_ws(ws2, ds)

100%|██████████| 11500/11500 [00:39<00:00, 288.25it/s]
0.2543478260869565
0.7456521739130435


This is more or less exactly the distribution we expected.

In [19]:
ws3 = WeightedRandomSampler(wl2i, len(ds)//2, replacement=False)

In [20]:
check_ws(ws3, ds)

100%|██████████| 11500/11500 [00:48<00:00, 234.70it/s]
0.3183478260869565
0.6816521739130434


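The distribution is slightly skewed because of `replacement=False`: each index can be drawn only once, so as the heavily weighted class gets used up, the other class becomes more likely to be picked.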
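
The skew can be reproduced with a small simulation in plain Python. This sketch assumes that sampling without replacement behaves like a sequence of weighted draws from whatever is left in the pool; the helper `simulate_without_replacement` is hypothetical and only tracks how many items of each class remain.

```python
import random

def simulate_without_replacement(n0, n1, w0, w1, num_draws, seed=0):
    """Draw `num_draws` items without replacement from two classes holding `n0`
    and `n1` items with per-item weights `w0` and `w1`.
    Returns the fraction of draws that came from class 0 and from class 1."""
    rng = random.Random(seed)
    drawn0 = drawn1 = 0
    for _ in range(num_draws):
        # Probability that the next draw is class 1, given what is left in the pool.
        p1 = (w1 * n1) / (w0 * n0 + w1 * n1)
        if rng.random() < p1:
            n1 -= 1
            drawn1 += 1
        else:
            n0 -= 1
            drawn0 += 1
    return drawn0 / num_draws, drawn1 / num_draws

# Same setup as the notebook: 11500 items per class, weights 0.25 / 0.75,
# and half of the dataset drawn.
frac0, frac1 = simulate_without_replacement(11500, 11500, 0.25, 0.75, 11500)
print(frac0, frac1)
```

Under these assumptions the class-0 fraction comes out noticeably above 0.25, in the same direction as the measured figures.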

We can confirm the same with a dataloader as well.

In [21]:
new_dl = DataLoader(ds, batch_size=16, 
                    sampler=WeightedRandomSampler(wl2i, len(ds)//2, replacement=True),
                    num_workers=4)

In [22]:
check_sampler(new_dl)

100%|██████████| 719/719 [00:15<00:00, 47.65it/s]
0.24956521739130436
0.7504347826086957


In [23]:
new_dl_no_replacement = DataLoader(ds, batch_size=16, 
                        sampler=WeightedRandomSampler(wl2i, len(ds)//2, replacement=False),
                        num_workers=4)

In [24]:
check_sampler(new_dl_no_replacement)

100%|██████████| 719/719 [00:24<00:00, 29.29it/s]
0.31965217391304346
0.6803478260869565
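
For readers who do not have the fastai data objects at hand, the same experiment can be sketched with PyTorch alone, using the built-in `torch.utils.data.WeightedRandomSampler`. The toy `TensorDataset` with random features is an assumption made only so the example is self-contained; the 0.25 / 0.75 weights mirror the ones used above.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

torch.manual_seed(0)  # only to make the sketch repeatable

# Toy balanced dataset: 1000 samples per class, with random features.
labels = torch.cat([torch.zeros(1000, dtype=torch.long),
                    torch.ones(1000, dtype=torch.long)])
features = torch.randn(len(labels), 3)
dataset = TensorDataset(features, labels)

# One weight per index, looked up from each sample's label.
label_to_weight = torch.tensor([0.25, 0.75])
sample_weights = label_to_weight[labels]

sampler = WeightedRandomSampler(sample_weights, num_samples=len(dataset) // 2,
                                replacement=True)
loader = DataLoader(dataset, batch_size=16, sampler=sampler)

# Count how often each class shows up over one pass of the loader.
counts = torch.zeros(2)
for _, y in loader:
    counts += torch.bincount(y, minlength=2).float()
fractions = counts / counts.sum()
print(fractions.tolist())
```

With replacement enabled, the printed fractions should land close to the 0.25 / 0.75 weights, as in the fastai run above.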
