In [2]:
import torch
import torchvision.datasets as datasets
import os
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, WeightedRandomSampler
import torch.nn as nn

# Methods for dealing with imbalanced datasets:
# 1. Over-sampling: Add more copies of the minority class
# 2. Class weighting: Increase the loss for the minority class

# Over-sampling
# 1. Random over-sampling: Add random copies of the minority class
# 2. Synthetic over-sampling: Generate synthetic samples of the minority class using techniques like SMOTE

# Class weighting
# 1. Assign a weight to each class in the loss function
# 2. The weight of the minority class is higher than the weight of the majority class

In [3]:
# 2. Class weighting
# Weight of the minority class is 50
loss_fn = nn.CrossEntropyLoss(weight=torch.tensor([1, 50]))

In [4]:
# 1. Random over-sampling (using WeightedRandomSampler)
def get_loader(root_dir, batch_size):
    my_transforms = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor()
    ])

    dataset = datasets.ImageFolder(root=root_dir, transform=my_transforms)
    
    # get the inverse of the number of samples in each class
    # and use them as class weights
    class_weights = []
    for root, subdir, files in os.walk(root_dir):
        if len(files) > 0: # skip empty folders
            class_weights.append(1/len(files))

    sample_weights = [0] * len(dataset)
    # create the sample_weights
    # (weight for each sample based on its class)
    for idx, (data, label) in enumerate(dataset):
        class_weight = class_weights[label]
        sample_weights[idx] = class_weight

    # we can use a WeightedRandomSampler to oversample the minority class
    # replacement=True to make sure we can see examples more than once
    # otherwise it won't over-sample
    sampler = WeightedRandomSampler(sample_weights, 
                                    num_samples=len(sample_weights), 
                                    replacement=True)
    
    # make our loader and specify the sampler
    loader = DataLoader(dataset, batch_size=batch_size, sampler=sampler)

    return loader

In [6]:
loader = get_loader(root_dir='dataset/PetImages', batch_size=8)
for data, label in loader:
    print(label)
    break

UnidentifiedImageError: cannot identify image file <_io.BufferedReader name='dataset/PetImages/Cat/666.jpg'>