In [1]:
from Resnet import ResNet, Block, create_pairs
import random
import pandas as pd
import numpy as np
import os
import torch
from torch.nn import TripletMarginLoss
from torchvision.io import decode_image, ImageReadMode
from torch.utils.data import Dataset, Sampler, DataLoader
from torchvision.transforms import v2
from torch.optim import Adam
from collections import defaultdict

In [2]:
"""
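Create class-balanced batches of n_classes * n_samples indices (5 x 3 = 15 in this notebook). Samples are consumed without reuse, except when a short class is padded by resampling its leftovers, until fewer than n_classes classes still have samples left; any samples remaining at that point are skipped for the epoch.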
"""

def create_batches(label_mapping, n_classes, n_samples):
    """Build class-balanced batches of dataset indices.

    Each batch holds ``n_samples`` indices from each of ``n_classes`` distinct
    classes chosen at random, i.e. ``n_classes * n_samples`` indices in total.
    Rounds continue until fewer than ``n_classes`` classes still have unused
    indices; whatever is left in those classes is not emitted.

    When a drawn class has fewer than ``n_samples`` unused indices left, the
    shortfall is filled by sampling with replacement from those leftovers, so a
    batch can contain duplicate indices.

    Args:
        label_mapping: mapping ``class -> list of dataset indices``. Each list
            is consumed front-to-back in its current order (shuffle it first
            for randomness). The mapping and its lists are NOT modified.
        n_classes: number of distinct classes per batch; must be positive.
        n_samples: number of indices per class per batch; must be positive.

    Returns:
        A list of batches, each a list of ``n_classes * n_samples`` indices.

    Raises:
        ValueError: if ``n_classes`` or ``n_samples`` is not positive (either
            would otherwise loop forever, as no class list ever empties).
    """
    if n_classes <= 0 or n_samples <= 0:
        raise ValueError(
            f"n_classes and n_samples must be positive, got {n_classes} and {n_samples}"
        )

    # Work on copies: draining the caller's lists would leave a long-lived
    # sampler with nothing to yield on its next epoch (or after a len() call).
    remaining = {cls: list(indices) for cls, indices in label_mapping.items()}
    batches = []

    while True:
        # Only classes that still have unused indices can be drawn this round.
        available_classes = [cls for cls, indices in remaining.items() if indices]
        if len(available_classes) < n_classes:
            break

        batch = []
        for cls in random.sample(available_classes, n_classes):
            indices = remaining[cls]
            if len(indices) >= n_samples:
                chosen = indices[:n_samples]
                remaining[cls] = indices[n_samples:]
            else:
                # Too few left: pad by resampling the leftovers, then retire the class.
                chosen = indices + random.choices(indices, k=n_samples - len(indices))
                remaining[cls] = []
            batch.extend(chosen)

        batches.append(batch)

    return batches


class BalancedBatchSampler(Sampler):
    """Batch sampler that yields class-balanced batches for metric learning.

    Every batch contains ``n_samples`` dataset indices from each of
    ``n_classes`` distinct labels, planned by ``create_batches``. Use it via
    ``DataLoader(dataset, batch_sampler=...)``.

    Args:
        labels: per-sample labels aligned with the dataset's indices; each
            element must be convertible with ``int()``.
        n_classes: number of distinct labels per batch.
        n_samples: number of indices per label per batch.
    """

    def __init__(self, labels, n_classes, n_samples):
        self.labels = labels
        self.n_classes = n_classes
        self.n_samples = n_samples

        # label -> list of dataset indices carrying that label
        self.label_mapping = defaultdict(list)
        for idx, label in enumerate(labels):
            self.label_mapping[int(label)].append(idx)

    def _build_batches(self):
        """Plan one epoch of batches from copies of the per-label index lists.

        Copies are passed so the sampler's own lists are never consumed, even
        if create_batches drains what it receives; otherwise every epoch after
        the first (and any iteration after a ``len()`` call) would be empty.
        """
        mapping_copy = {label: list(indices) for label, indices in self.label_mapping.items()}
        return create_batches(mapping_copy, self.n_classes, self.n_samples)

    def __iter__(self):
        # Reshuffle each label's indices in place so each epoch gets new batches.
        for indices in self.label_mapping.values():
            random.shuffle(indices)
        yield from self._build_batches()

    def __len__(self):
        # The batch count depends on which classes get drawn each round, so with
        # uneven class sizes this may differ from what __iter__ yields. The global
        # RNG state is restored so asking for the length does not shift training
        # randomness.
        state = random.getstate()
        try:
            return len(self._build_batches())
        finally:
            random.setstate(state)

In [3]:
class ImageDataset(Dataset):
    """Image/label dataset backed by an annotations CSV.

    Each CSV row names an image file (relative to ``img_dir``) and its label.
    Images are decoded as RGB via ``decode_image`` and passed through
    ``transform`` when one is given.

    Args:
        annotation_file: path or buffer accepted by ``pandas.read_csv``.
        img_dir: directory that the image paths in the CSV are relative to.
        transform: optional callable applied to the decoded image.
        path_column: column holding the image path, as a position (int) or a
            name (str). Defaults to the first column, as before.
        label_column: column holding the label, as a position (int) or a name
            (str). Defaults to the second column, as before. Passing a name
            such as ``'encoded_ground_truth'`` keeps the labels in sync with
            code that reads that column by name.
    """

    def __init__(self, annotation_file, img_dir, transform=None, path_column=0, label_column=1):
        self.img_label = pd.read_csv(annotation_file)
        self.img_dir = img_dir
        self.transform = transform
        # Resolve column selectors to positions once so __getitem__ stays an iloc lookup.
        self.path_position = self._column_position(path_column)
        self.label_position = self._column_position(label_column)

    def _column_position(self, column):
        """Return the positional index for ``column`` (a column name or an int position)."""
        if isinstance(column, str):
            return self.img_label.columns.get_loc(column)
        return column

    def __len__(self):
        return len(self.img_label)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for annotation row ``idx``."""
        img_path = os.path.join(self.img_dir, self.img_label.iloc[idx, self.path_position])
        image = decode_image(img_path, mode=ImageReadMode.RGB)
        label = self.img_label.iloc[idx, self.label_position]
        # Explicit None check: a transform object should never be skipped for being falsy.
        if self.transform is not None:
            image = self.transform(image)

        return image, label

In [4]:
# Shared preprocessing for every split: Resize(256) (an int size matches the
# smaller edge), a 224x224 center crop, then conversion to float32 with
# value scaling enabled.
transform = v2.Compose([
    v2.Resize(256),
    v2.CenterCrop(224),
    v2.ToDtype(torch.float32, scale=True),
])

# NOTE(review): all three splits resolve image paths against 'Data/train/';
# confirm the validation/test images really live in that directory.
IMAGE_DIR = 'Data/train/'
train_data = ImageDataset('Data/train_data.csv', IMAGE_DIR, transform)
valid_data = ImageDataset('Data/valid_data.csv', IMAGE_DIR, transform)
test_data = ImageDataset('Data/test_data.csv', IMAGE_DIR, transform)

In [5]:
# Seed every RNG the notebook draws from so Restart & Run All is repeatable:
# `random` drives the balanced sampler (shuffle / sample / choices), and torch
# is presumably used by ResNet's weight initialisation — confirm in Resnet.py.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Stage depths [3, 4, 6, 3] on 3-channel (RGB) input; whether this is a
# ResNet-34- or ResNet-50-style network depends on `Block` — TODO confirm.
resnet_model = ResNet(Block, [3, 4, 6, 3], image_channels=3)

In [6]:
# Each batch = N_CLASSES_PER_BATCH labels x N_SAMPLES_PER_CLASS images.
N_CLASSES_PER_BATCH = 5
N_SAMPLES_PER_CLASS = 3

# Reuse the annotations frame the training dataset already loaded, instead of
# re-reading the CSV, so sampler indices are guaranteed to match dataset rows.
df = train_data.img_label
label_tensor = torch.tensor(df['encoded_ground_truth'].values)

sampler = BalancedBatchSampler(label_tensor, N_CLASSES_PER_BATCH, N_SAMPLES_PER_CLASS)

# batch_sampler fully determines batch composition, so no batch_size/shuffle here.
dataloader = DataLoader(train_data, batch_sampler=sampler)

In [7]:
# Training hyper-parameters. batch_size equals 5 classes x 3 samples used by the
# sampler; it is not passed to the DataLoader, whose batches come from batch_sampler.
learning_rate, epochs, batch_size = 1e-3, 50, 15

In [8]:
# Smoke-test the forward pass on a single batch. next() with a sentinel replaces
# the for/break idiom and fails loudly on an empty loader, instead of silently
# leaving x / y / pred undefined for the cells below.
batch = next(iter(dataloader), None)
if batch is None:
    raise RuntimeError(
        "dataloader yielded no batches; check the sampler's n_classes / n_samples against the label counts"
    )
x, y = batch[0], batch[1]
pred = resnet_model(x)

In [9]:
# Triplet margin criterion with its defaults written out: margin 1.0, p=2 (Euclidean)
# distance, mean reduction over the triplets in a batch.
loss = TripletMarginLoss(margin=1.0, p=2.0, reduction='mean')

In [10]:
# Build triplets from the batch labels `y` and model outputs `pred`; the following
# cell unpacks every item of `res` as an (anchor, positive, negative) triple.
# NOTE(review): create_pairs comes from the local Resnet module — confirm how it
# mines triplets (all valid combinations vs. hard/semi-hard) and whether it can return none.
res = create_pairs(y, pred)

In [29]:
# Unzip the (anchor, positive, negative) triplets into three parallel lists.
anchor_embedding = []
positive_embedding = []
negative_embedding = []
for anchor, positive, negative in res:
    anchor_embedding.append(anchor)
    positive_embedding.append(positive)
    negative_embedding.append(negative)

# torch.stack rejects an empty list with an obscure message; say what actually went wrong.
if not anchor_embedding:
    raise ValueError("create_pairs returned no triplets for this batch; cannot compute the triplet loss")

anchor_tensor = torch.stack(anchor_embedding)
positive_tensor = torch.stack(positive_embedding)
negative_tensor = torch.stack(negative_embedding)

# Bind the result so it can be inspected or back-propagated (triplet_loss.backward()).
triplet_loss = loss(anchor_tensor, positive_tensor, negative_tensor)
triplet_loss

tensor(0.8518, grad_fn=<MeanBackward0>)