### Dealing with imbalanced datasets

**Date:** 14/11/2021  
**Author:** Murad Popattia

Methods for dealing with imbalanced datasets:
- Oversampling (draw minority-class samples more often, e.g. with `WeightedRandomSampler`)
- Class weighting (give the minority class a higher weight in the loss function, i.e. multiply its loss term by a larger factor)

In [12]:
import torch
import torchvision.datasets as datasets
import os
from torch.utils.data import WeightedRandomSampler, DataLoader
import torchvision.transforms as transforms
import torch.nn as nn

In [8]:
# Class weighting:
# For instance, with 50 examples of class 1 and 1 example of class 2,
# weight the per-class loss terms as
#     class 1 -> x1
#     class 2 -> x50 ...

# loss_fn = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 50.0])) etc.
# (the weight tensor must have a floating-point dtype)

In [41]:
def get_loader(root_dir, batch_size, *, replacement=True):
    """Build a DataLoader that oversamples minority classes of an image folder.

    Every sample is weighted by the inverse of its class frequency and drawn
    through a ``WeightedRandomSampler``, so each class is drawn about equally
    often per epoch.

    Args:
        root_dir: Root directory in ``torchvision.datasets.ImageFolder``
            layout (one sub-directory per class).
        batch_size: Number of samples per batch.
        replacement: Forwarded to ``WeightedRandomSampler``. With ``True``
            (the default) minority samples can be drawn repeatedly, which is
            what balances the batches. With ``False`` every sample is drawn
            exactly once per epoch, so the original class imbalance remains.

    Returns:
        A ``DataLoader`` over the images, resized to 224x224 and converted
        to tensors, yielding ``(images, labels)`` batches.
    """
    my_transforms = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ])

    # Build the dataset with torch's ImageFolder class.
    dataset = datasets.ImageFolder(root_dir, transform=my_transforms)

    # Take labels from `dataset.targets` rather than walking the filesystem.
    # os.walk visits class directories in arbitrary order, so its counts
    # need not line up with ImageFolder's label indices. It also counts
    # non-image and nested files. Iterating `dataset` instead would load and
    # transform every image just to read its label.
    targets = torch.as_tensor(dataset.targets, dtype=torch.long)
    class_counts = torch.bincount(targets, minlength=len(dataset.classes))

    # Inverse class frequency: e.g. 50 vs 1 samples -> weights 1/50 vs 1
    # (equivalent to 1 vs 50; the sampler only uses relative weights).
    # clamp(min=1) avoids a division by zero for a class with no samples;
    # no sample refers to such a class, so its weight is never used.
    class_weights = 1.0 / class_counts.clamp(min=1).double()

    # WeightedRandomSampler expects one weight per sample, so look up each
    # sample's class weight.
    sample_weights = class_weights[targets]

    sampler = WeightedRandomSampler(
        sample_weights,
        num_samples=len(sample_weights),
        replacement=replacement,
    )
    return DataLoader(dataset, batch_size=batch_size, sampler=sampler)

In [42]:
# Balanced loader over the imbalanced image folder, 8 images per batch.
loader = get_loader(root_dir="../datasets/imbalance_dataset", batch_size=8)

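##### With replacement = False

(Output recorded from a run in which the sampler was created with `replacement=False`. Without replacement, each sample is drawn exactly once per epoch, so the batches keep the original class imbalance.)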

In [24]:
# Show the label tensor of every batch to inspect the class mix.
for _images, batch_labels in loader:
    print(batch_labels)

tensor([0, 0, 1, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0, 0, 0, 0, 0, 0])
tensor([0, 0, 0])


##### With replacement = True

In [40]:
# Inspect the class mix of every batch drawn by the loader.
for batch in loader:
    _, batch_labels = batch
    print(batch_labels)

tensor([1, 1, 1, 0, 0, 1, 1, 1])
tensor([1, 1, 0, 1, 1, 0, 1, 1])
tensor([0, 1, 1, 0, 1, 1, 0, 0])
tensor([1, 0, 0, 0, 0, 1, 0, 0])
tensor([1, 1, 1, 1, 0, 0, 1, 0])
tensor([1, 0, 0, 1, 0, 1, 1, 0])
tensor([1, 0, 1])


##### Counting the number of samples we get

In [43]:
# Count how many samples of each label the loader draws over several epochs,
# then report the (floored) per-epoch average. Label 0 is counted as
# retrievers, label 1 as hounds.
NUM_EPOCHS = 10

num_retrievers = 0
num_hounds = 0

for _epoch in range(NUM_EPOCHS):
    for _images, labels in loader:
        # .item() keeps the running totals as plain Python ints.
        num_retrievers += (labels == 0).sum().item()
        num_hounds += (labels == 1).sum().item()

print(
    f"Retrievers: {num_retrievers // NUM_EPOCHS}, "
    f"Hounds: {num_hounds // NUM_EPOCHS}"
)

Rerievers: 26, Hounds: 24


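So we can see that the sampled batches are now balanced. Per epoch, the loader draws roughly as many retrievers as hounds, even though the dataset on disk is still imbalanced; the balancing comes entirely from the weighted sampler.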