In [22]:
from collections import Counter
from typing import Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from sklearn.model_selection import StratifiedShuffleSplit
from torch.utils.data import Dataset, DataLoader, Subset
from torchvision import models, transforms
from tqdm import tqdm

# ====== Custom Dataset ======
class TaskDataset(Dataset):
    """Map-style dataset of ``(id, image, label)`` samples stored in parallel lists.

    ``__init__`` creates empty ``ids`` / ``imgs`` / ``labels`` lists; nothing in
    this class fills them, so they must be populated externally. If instances
    of this class are pickled (e.g. inside a ``torch.load``-ed file — TODO
    confirm for ``Train.pt``), keep the class and attribute names stable so
    those pickles keep loading.

    Args:
        transform: Optional callable applied to each image in ``__getitem__``.
    """

    def __init__(self, transform=None):
        self.ids = []
        self.imgs = []
        self.labels = []
        self.transform = transform

    def __getitem__(self, index) -> Tuple[int, torch.Tensor, int]:
        """Return the ``(id, image, label)`` triple stored at position ``index``.

        The image is passed through ``self.transform`` when one is set; without
        a transform it is returned exactly as stored (so not necessarily a
        tensor, despite the annotation).
        """
        id_ = self.ids[index]
        img = self.imgs[index]
        if self.transform is not None:
            img = self.transform(img)
        label = self.labels[index]
        return id_, img, label

    def __len__(self):
        """Number of samples, taken from the length of ``ids``."""
        return len(self.ids)

def to_rgb(image):
    """Return ``image`` in RGB mode.

    An image whose ``mode`` is already ``'RGB'`` is returned as-is (the same
    object); any other image is converted via ``image.convert('RGB')``.
    """
    if image.mode == 'RGB':
        return image
    return image.convert('RGB')

# ====== Load data ======
# SECURITY: weights_only=False makes torch.load fall back to full pickle
# deserialization, which can execute arbitrary code — only load trusted files.
# NOTE(review): presumably Train.pt pickles a TaskDataset (which would explain
# why the class must be defined before this line) — TODO confirm.
raw_dataset = torch.load("Train.pt", weights_only=False)

# Iterate the object that was actually loaded (the previous version iterated an
# undefined name `data`, which only worked via leftover kernel state).
dataset = [(idx, to_rgb(img), label) for idx, img, label in raw_dataset]
labels = [label for _, _, label in dataset]

In [23]:
# ====== Stratified Split with Rare Tag Handling ======
# StratifiedShuffleSplit needs at least two members per class, so samples whose
# label occurs only once are excluded from stratification and all go to train.
TEST_SIZE = 0.2
SPLIT_SEED = 42
BATCH_SIZE = 32
NUM_WORKERS = 2

labels_np = np.array(labels)
tag_counts = Counter(labels_np)
per_sample_tag_count = np.array([tag_counts[tag] for tag in labels_np])

# np.flatnonzero always returns an integer index array. The earlier
# np.array([...]) form produced float64 for an empty list (e.g. no rare tags),
# which made `train_idx` float after concatenation and broke Subset indexing.
rare_indices = np.flatnonzero(per_sample_tag_count < 2)
strat_eligible = np.flatnonzero(per_sample_tag_count >= 2)
strat_tags = labels_np[strat_eligible]

sss = StratifiedShuffleSplit(n_splits=1, test_size=TEST_SIZE, random_state=SPLIT_SEED)
# split() returns positions into `strat_eligible`; map them back to dataset positions.
train_idx_r, test_idx_r = next(sss.split(strat_eligible, strat_tags))

train_idx_r = strat_eligible[train_idx_r]
test_idx_r = strat_eligible[test_idx_r]
train_idx = np.concatenate([rare_indices, train_idx_r])
test_idx = test_idx_r

# Sanity checks: integer indices, disjoint splits, every sample used exactly once.
assert np.issubdtype(train_idx.dtype, np.integer)
assert np.issubdtype(test_idx.dtype, np.integer)
assert np.intersect1d(train_idx, test_idx).size == 0
assert len(train_idx) + len(test_idx) == len(dataset)

train_data = Subset(dataset, train_idx)
test_data = Subset(dataset, test_idx)

# NOTE(review): `dataset` appears to hold raw PIL images (to_rgb calls .mode /
# .convert) and no transform is applied, so the default collate_fn will likely
# fail to batch them — add a transform (e.g. transforms.ToTensor()) before
# iterating these loaders.
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

In [29]:
from collections import defaultdict
import numpy as np

def check_distribution_across_splits(train_idx, test_idx, data):
    """Print a per-label table of how samples are divided between train and test.

    Args:
        train_idx: Iterable of positional indices into ``data`` for the train
            split (the same indices handed to ``Subset``).
        test_idx: Iterable of positional indices into ``data`` for the test split.
        data: Sequence of ``(id, image, label)`` tuples; only the label is read.

    Raises:
        IndexError: If an index is out of range for ``data``.
    """
    # Look labels up by *position*: the split indices are positions into `data`,
    # not the per-sample ids stored in each tuple's first field. (The previous
    # id-keyed lookup silently skipped any index that wasn't also an id.)
    labels = [label for _, _, label in data]

    tag_counts = defaultdict(lambda: {"train": 0, "test": 0, "total": 0})
    for split, indices in (("train", train_idx), ("test", test_idx)):
        for idx in indices:
            tag = labels[int(idx)]  # int() also accepts numpy integer/float scalars
            tag_counts[tag][split] += 1
            tag_counts[tag]["total"] += 1

    # Header and rows share the same column widths and separators so they align.
    header = f"{'Tag':<15} {'Train':>7} {'Test':>7} {'Total':>7} | {'Train %':>8} {'Test %':>8}"
    print(header)
    print("-" * len(header))

    for tag in sorted(tag_counts):
        info = tag_counts[tag]
        total = info["total"]
        train_pct = 100 * info["train"] / total if total else 0
        test_pct = 100 * info["test"] / total if total else 0
        print(f"{tag:<15} {info['train']:>7} {info['test']:>7} {total:>7} | {train_pct:7.2f}% {test_pct:7.2f}%")


In [31]:
# Per-label train/test counts; with test_size=0.2, stratified labels should land close to 80/20.
check_distribution_across_splits(train_idx=train_idx, test_idx=test_idx, data=dataset)

Tag               Train    Test   Total |  Train %   Test %
----------------------------------------------------------------------
0                   338       86     424 |   79.72%    20.28%
1                  6650     1682    8332 |   79.81%    20.19%
2                 17005     4214   21219 |   80.14%    19.86%
3                  2613      648    3261 |   80.13%    19.87%
4                  5105     1310    6415 |   79.58%    20.42%
5                 10829     2697   13526 |   80.06%    19.94%
6                  3380      861    4241 |   79.70%    20.30%
7                   329       80     409 |   80.44%    19.56%
8                 22994     5795   28789 |   79.87%    20.13%
9                  3414      854    4268 |   79.99%    20.01%
