In [1]:
from Resnet import ResNet, Block
import random
import pandas as pd
import numpy as np
import os
import torch
from torchvision.io import decode_image, ImageReadMode
from torch.utils.data import Dataset, Sampler, DataLoader
from torchvision.transforms import v2
from collections import defaultdict

In [2]:
def create_batches(label_mapping, n_classes, n_samples):
    """Group sample indices into class-balanced batches.

    Every batch is built from ``n_classes`` classes drawn at random (without
    replacement) among the classes that still have unused indices, and takes
    exactly ``n_samples`` indices from each of them, so a batch always holds
    ``n_classes * n_samples`` entries. A class's indices are consumed in the
    order given. When a class has fewer than ``n_samples`` unused indices
    left, the remainder is used and padded with random repeats of those same
    indices; the class is then exhausted. Batching stops as soon as fewer than
    ``n_classes`` classes have unused indices, so indices of the last
    remaining classes may never be emitted.

    Args:
        label_mapping: Mapping from class label to a list of sample indices.
            It is only read, never modified.
        n_classes: Number of distinct classes per batch; must be >= 1.
        n_samples: Number of indices taken per class per batch; must be >= 1.

    Returns:
        A list of batches, each a flat list of indices laid out class by class.

    Raises:
        ValueError: If ``n_classes`` or ``n_samples`` is smaller than 1. Such
            values would otherwise never drain the index pools and the loop
            below would not terminate.
    """
    if n_classes < 1 or n_samples < 1:
        raise ValueError(
            "n_classes and n_samples must both be >= 1, "
            f"got n_classes={n_classes}, n_samples={n_samples}"
        )

    # Position of the first unused index for every class. Tracking an offset
    # avoids re-copying the remainder of each index list after every batch.
    cursor = dict.fromkeys(label_mapping, 0)

    # Classes that still have unused indices, kept in the mapping's key order
    # so random.sample sees the same population as a full rescan would.
    active = [cls for cls, indices in label_mapping.items() if len(indices) > 0]

    batches = []
    while len(active) >= n_classes:
        batch = []

        for cls in random.sample(active, n_classes):
            indices = label_mapping[cls]
            start = cursor[cls]
            stop = start + n_samples

            if stop <= len(indices):
                chosen = indices[start:stop]
            else:
                # Not enough unused indices: take what is left and pad with
                # random repeats drawn from that same remainder.
                leftover = indices[start:]
                chosen = leftover + random.choices(leftover, k=n_samples - len(leftover))
                stop = len(indices)

            cursor[cls] = stop
            batch.extend(chosen)

        batches.append(batch)
        active = [cls for cls in active if cursor[cls] < len(label_mapping[cls])]

    return batches


class BalancedBatchSampler(Sampler):
    """Batch sampler that yields class-balanced batches of dataset indices.

    Each yielded batch is a list of ``n_classes * n_samples`` indices produced
    by ``create_batches`` from a per-class index pool that is reshuffled at
    the start of every iteration. Intended to be passed to a ``DataLoader``
    through its ``batch_sampler`` argument.

    Args:
        labels: Iterable of per-sample labels; each element must be
            convertible with ``int()``. Position ``i`` is the label of
            dataset index ``i``.
        n_classes: Number of distinct classes per batch; must be >= 1.
        n_samples: Number of indices per class per batch; must be >= 1.

    Raises:
        ValueError: If ``n_classes`` or ``n_samples`` is smaller than 1
            (iterating with such values would never terminate).
    """

    def __init__(self, labels, n_classes, n_samples):
        # Fail at construction time rather than hanging on the first epoch.
        if n_classes < 1 or n_samples < 1:
            raise ValueError(
                "n_classes and n_samples must both be >= 1, "
                f"got n_classes={n_classes}, n_samples={n_samples}"
            )

        self.labels = labels
        self.n_classes = n_classes
        self.n_samples = n_samples

        # class label -> list of dataset indices carrying that label
        self.label_mapping = defaultdict(list)
        for idx, label in enumerate(labels):
            self.label_mapping[int(label)].append(idx)

    def __iter__(self):
        """Reshuffle every class pool in place, then yield balanced batches."""
        for indices in self.label_mapping.values():
            random.shuffle(indices)

        yield from create_batches(self.label_mapping, self.n_classes, self.n_samples)

    def __len__(self):
        """Return the number of batches one simulated pass produces.

        The batching picks classes at random, so the count from this
        simulated pass is not guaranteed to equal the number of batches a
        later iteration yields. The global ``random`` state is saved and
        restored around the simulation so that merely asking for the length
        (e.g. by a progress bar) does not alter a seeded batch sequence.
        """
        rng_state = random.getstate()
        try:
            return len(create_batches(self.label_mapping, self.n_classes, self.n_samples))
        finally:
            random.setstate(rng_state)

In [3]:
class ImageDataset(Dataset):
    """Dataset of images listed in an annotation CSV.

    The CSV is loaded once at construction. For each row, column 0 is used as
    the image file name (joined onto ``img_dir``) and column 1 as the label.
    """

    def __init__(self, annotation_file, img_dir, transform=None):
        """Load the annotation table and remember where the images live.

        Args:
            annotation_file: Path of the CSV passed to ``pd.read_csv``.
            img_dir: Directory that the file names in column 0 are relative to.
            transform: Optional callable applied to each decoded image.
        """
        self.transform = transform
        self.img_dir = img_dir
        self.img_label = pd.read_csv(annotation_file)

    def __len__(self):
        """Return the number of rows in the annotation table."""
        return self.img_label.shape[0]

    def __getitem__(self, idx):
        """Return ``(image, label)`` for row ``idx``.

        The image is decoded in RGB mode and passed through ``transform``
        when one was supplied; the label is returned as stored in the table.
        """
        annotations = self.img_label
        file_name = annotations.iloc[idx, 0]
        image = decode_image(os.path.join(self.img_dir, file_name), mode=ImageReadMode.RGB)
        label = annotations.iloc[idx, 1]

        if not self.transform:
            return image, label
        return self.transform(image), label

In [4]:
# Deterministic preprocessing shared by all three splits: resize, take the
# central 224x224 crop, then convert to float32 with value scaling.
transform = v2.Compose(
    [
        v2.Resize(256),
        v2.CenterCrop(224),
        v2.ToDtype(torch.float32, scale=True),
    ]
)

# Each split has its own annotation CSV, but every split reads its images
# from 'Data/train/'.
# NOTE(review): presumably all images live in one folder and the CSVs define
# the split — confirm that valid/test are not meant to use another directory.
train_data = ImageDataset(
    annotation_file='Data/train_data.csv', img_dir='Data/train/', transform=transform
)
valid_data = ImageDataset(
    annotation_file='Data/valid_data.csv', img_dir='Data/train/', transform=transform
)
test_data = ImageDataset(
    annotation_file='Data/test_data.csv', img_dir='Data/train/', transform=transform
)

In [5]:
# Instantiate the project-local ResNet (from Resnet.py) for 3-channel input.
# NOTE(review): [3, 4, 6, 3] is presumably the number of Block units per
# stage — confirm against the ResNet constructor in Resnet.py.
resnet_model = ResNet(Block, [3, 4, 6, 3], image_channels=3)

In [6]:
# Per-row class ids, read from the same CSV that backs `train_data`, so
# position i in `label_tensor` is the label of dataset index i.
df = pd.read_csv('Data/train_data.csv')
label_tensor = torch.tensor(df['encoded_ground_truth'].values)

# 5 classes x 3 samples per class = 15 images per batch.
sampler = BalancedBatchSampler(labels=label_tensor, n_classes=5, n_samples=3)

# The sampler yields whole batches of indices, hence `batch_sampler`.
dataloader = DataLoader(dataset=train_data, batch_sampler=sampler)

In [23]:
# Run a single forward pass on the first batch only (the loop exits right
# after it); `y` is unpacked but not used here.
for batch in dataloader:
    x, y = batch[0], batch[1]
    pred = resnet_model(x)
    break