In [1]:
import torchvision.datasets as datasets
import torch
import torch.nn as nn
import os
from torch.utils.data import WeightedRandomSampler, DataLoader
import torchvision.transforms as transforms

In [2]:
# # # methods for dealing with imbalanced datasets
# # 1. Oversampling
# # 2. Class Weighting

In [3]:
# # # class weighting
# # # 1. calculate the class weights
# # # loss_fn = nn.CrossEntropyLoss(weight=torch.tensor([1, 50], dtype=torch.float32))

In [4]:
def get_loader(root_dir:str, batch_size:int) -> any:
    my_transforms = transforms.Compose(
        [
            transforms.Resize((224, 224)),
            transforms.ToTensor()
        ]
    )
    # load the dataset
    dataset = datasets.ImageFolder(root=root_dir, transform=my_transforms)
    # create some class weight
    class_weights: list[int] = []
    for _, _, files in os.walk(root_dir):
        if len(files) > 0:
            class_weights.append(1/len(files))
        
    # create sample weight
    sample_weights: list[int] = [0] * len(dataset) # each element in the list is a sample starts with this sample weight
    for idx, (_, label) in enumerate(dataset):
        class_weight = class_weights[label]
        sample_weights[idx] = class_weight
    # create a sampler to take in the sample weights
    sampler = WeightedRandomSampler(
        sample_weights, 
        num_samples=len(sample_weights), 
        replacement=True # set to False to see this example only once (not ideal for oversampling)
    )
    # create a loader to take in the sampler
    loader = DataLoader(
        dataset, 
        batch_size=batch_size, 
        sampler=sampler
    )
    return loader


In [5]:
loader = get_loader(root_dir='../dataset/imbalance_data/', batch_size=8)


In [6]:
num_retrievers = []
num_elkhounds = []

for epoch in range(10):
    for data, target in loader:
        num_retrievers.append((target == 0).sum().item())
        num_elkhounds.append((target == 1).sum().item())
        
print(f"Num of Retrievers : {len(num_retrievers)} \nNum of Elkhounds : {len(num_elkhounds)}")

Num of Retrievers : 70 
Num of Elkhounds : 70
