Imports and Seeding

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset, ConcatDataset, WeightedRandomSampler
from torchvision import datasets, transforms
import numpy as np
import random
import os

def seed_all(seed=42):
    """Seed the random number generators used by this notebook.

    Args:
        seed: Integer seed handed to Python's ``random`` module, NumPy's
            global generator, and PyTorch (CPU and CUDA seeding calls).
    """
    # Pure-Python and NumPy generators first ...
    random.seed(seed)
    np.random.seed(seed)

    # ... then the PyTorch generators.
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


# Seed once, up front, before any dataset or model code runs.
seed_all(seed=42)

Dataset Transforms (Equivalent to Transform enum)

In [None]:
def build_transform_list(transform_names):
    """Build a composed augmentation pipeline from a list of transform names.

    Args:
        transform_names: Iterable of names. Each must be one of
            ``"Translate"``, ``"Shear"``, ``"Scale"`` or ``"Rotation"``.
            The resulting pipeline applies them in the order given.

    Returns:
        A ``transforms.Compose`` holding one random augmentation per name.

    Raises:
        ValueError: If a name is not one of the supported transforms. An
            unknown name used to be skipped silently, which produced a
            pipeline with fewer augmentations than its label claimed.
    """
    # Factories rather than instances, so only the requested ops are built.
    factories = {
        "Translate": lambda: transforms.RandomAffine(0, translate=(0.1, 0.1)),
        "Shear": lambda: transforms.RandomAffine(0, shear=10),
        "Scale": lambda: transforms.RandomAffine(0, scale=(0.9, 1.1)),
        "Rotation": lambda: transforms.RandomRotation(10),
    }

    ops = []
    for name in transform_names:
        try:
            factory = factories[name]
        except KeyError:
            supported = ", ".join(sorted(factories))
            raise ValueError(
                f"Unknown transform {name!r}; expected one of: {supported}"
            ) from None
        ops.append(factory())

    return transforms.Compose(ops)


DatasetIdent Equivalent

In [None]:
class DatasetIdent:
    """Identifies one dataset variant by the augmentations applied to it.

    An instance with no transform names stands for the plain dataset; its
    string form is ``"Plain"``. Otherwise the string form is the transform
    names joined by single spaces.
    """

    PLAIN = "Plain"
    ALL = "All"

    def __init__(self, transforms=None):
        """Store the transform names for this variant.

        Args:
            transforms: Optional iterable of transform names understood by
                ``build_transform_list``. ``None`` or an empty iterable
                means the plain dataset.
        """
        # Copy, so that a caller mutating its own list afterwards cannot
        # change what this ident represents.
        self.transforms = list(transforms) if transforms else []

    def prepare(self, base_dataset):
        """Return the dataset variant this ident describes.

        Args:
            base_dataset: Dataset exposing ``root``, ``train`` and
                (optionally) ``transform`` attributes; the data must
                already be present under ``root`` because nothing is
                downloaded here.

        Returns:
            ``base_dataset`` itself when there are no transforms, otherwise
            a new ``datasets.MNIST`` over the same root/split whose samples
            go through the augmentations and then the base transform.
        """
        if not self.transforms:
            return base_dataset

        augment = build_transform_list(self.transforms)

        # Run the augmentations *in front of* the base dataset's own
        # transform instead of replacing it. Replacing it would drop the
        # base conversion step (ToTensor in this notebook) and samples
        # would come back in the dataset's raw image form, which the
        # default DataLoader collate function cannot batch.
        base_transform = getattr(base_dataset, "transform", None)
        if base_transform is not None:
            pipeline = transforms.Compose([augment, base_transform])
        else:
            pipeline = augment

        return datasets.MNIST(
            root=base_dataset.root,
            train=base_dataset.train,
            download=False,
            transform=pipeline,
        )

    def __str__(self):
        if not self.transforms:
            return self.PLAIN
        return " ".join(self.transforms)


Generate Dataset Idents (Exact Logic Port)

In [None]:
def generate_idents(num_samples_base=None):
    """Enumerate every subset of the four augmentations as dataset idents.

    Args:
        num_samples_base: Sample budget per active transform. When falsy
            (``None`` or ``0``) every ident is paired with ``None``.

    Returns:
        A list of 16 ``(DatasetIdent, size)`` tuples, ordered by the bit
        mask selecting the transforms (mask 0 is the plain dataset, mask 15
        uses all four). ``size`` is ``num_samples_base`` times the number
        of selected transforms, or ``None``.
    """
    available = ["Shear", "Scale", "Rotation", "Translate"]
    result = []

    for mask in range(1 << len(available)):
        # Bit ``pos`` of the mask switches ``available[pos]`` on.
        chosen = [
            name for pos, name in enumerate(available) if mask & (1 << pos)
        ]

        # An empty selection yields the plain ident, a full one the
        # all-transforms ident; DatasetIdent treats both uniformly.
        ident = DatasetIdent(chosen)

        if num_samples_base:
            size = num_samples_base * len(chosen)
        else:
            size = None

        result.append((ident, size))

    return result


Build Training & Validation Datasets

In [None]:
root = "./data"

# Official MNIST training split, downloaded on first use and converted to
# tensors.
base_train = datasets.MNIST(
    root=root,
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)

# First 55k samples are used for training, the remaining 5k for validation.
train_plain = Subset(base_train, indices=range(55_000))
valid_plain = Subset(base_train, indices=range(55_000, 60_000))

# Training idents carry a per-transform sample budget; validation idents
# carry none.
train_idents = generate_idents(num_samples_base=10_000)
valid_idents = generate_idents(num_samples_base=None)

def compose_dataset(idents, base_subset):
    """Concatenate one dataset view per ident into a single dataset.

    Args:
        idents: Iterable of ``(ident, size)`` pairs. ``ident.prepare`` is
            called with the dataset underlying ``base_subset``; ``size`` is
            the number of samples to draw for that ident, or a falsy value
            to use every sample once.
        base_subset: A ``Subset`` whose ``dataset`` and ``indices`` define
            which samples every ident is restricted to.

    Returns:
        A ``ConcatDataset`` over all ident views, in the order given.
    """
    parts = []

    for ident, size in idents:
        view = Subset(ident.prepare(base_subset.dataset), base_subset.indices)

        if size:
            # Draw ``size`` positions uniformly with replacement and expose
            # them as a Subset. ConcatDataset indexes its members, so every
            # member must be a map-style dataset; a DataLoader wrapped
            # around a sampler cannot be indexed and fails on first access.
            picks = torch.randint(len(view), (size,)).tolist()
            view = Subset(view, picks)

        parts.append(view)

    return ConcatDataset(parts)

DataLoaders

In [None]:
batch_size = 256

# Expand each ident list into one concatenated dataset per split.
train_dataset = compose_dataset(idents=train_idents, base_subset=train_plain)
valid_dataset = compose_dataset(idents=valid_idents, base_subset=valid_plain)

# Training batches are reshuffled every epoch; validation order is fixed.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=8,
)
valid_loader = DataLoader(
    dataset=valid_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=8,
)

Optimizer & LR Scheduler (Exact Burn Match)

In [None]:
import math


def _linear_factor(step, start, end, num_steps):
    """Interpolate linearly from ``start`` to ``end`` over ``num_steps``, then hold ``end``."""
    progress = min(step, num_steps) / num_steps
    return start + (end - start) * progress


def _cosine_factor(step, period):
    """Cosine decay from 1 to 0 over ``period`` steps, restarting afterwards."""
    # NOTE(review): assumes the Burn cosine schedule wraps around after
    # ``period`` steps (warm restarts) -- confirm against the Rust source.
    phase = step % (period + 1)
    return 0.5 * (1.0 + math.cos(math.pi * phase / period))


def composed_lr_factor(step):
    """Return the learning-rate multiplier for optimiser step ``step``.

    The three schedules are *multiplied* together rather than run one after
    another. Run sequentially on a base LR of 1.0, the warm-up alone ramps
    AdamW up to a learning rate of 1.0. As a product the peak stays in the
    1e-3 range, because the 1e-2 -> 1e-6 decay scales the whole curve.

    NOTE(review): assumes Burn's composed scheduler combines its members by
    product (believed to be its default reduction) -- confirm against the
    Rust configuration this notebook was ported from.
    """
    return (
        _cosine_factor(step, 2000)
        * _linear_factor(step, 1e-8, 1.0, 2000)
        * _linear_factor(step, 1e-2, 1e-6, 10000)
    )


# The base LR stays at 1.0 so the multiplier above *is* the effective rate.
optimizer = optim.AdamW(
    model.parameters(),
    lr=1.0,
    weight_decay=5e-5
)

# Stepped once per batch by the training loop.
scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=composed_lr_factor)

Training Loop with Early Stopping

In [None]:
# Early-stopping state: stop after `patience` consecutive epochs without a
# new best validation loss.
best_val_loss = float("inf")
patience = 5
no_improve = 0

# Built once and reused for every batch instead of being re-instantiated
# inside the loops.
criterion = nn.CrossEntropyLoss()

for epoch in range(20):
    # ---- training pass ----
    model.train()
    train_loss = 0

    for x, y in train_loader:
        optimizer.zero_grad()
        # squeeze(1) drops axis 1 when it has size 1 -- presumably the
        # single channel axis of the image batch; TODO confirm against the
        # model's expected input shape.
        out = model(x.squeeze(1))
        loss = criterion(out, y)
        loss.backward()
        optimizer.step()
        # The schedule is advanced per batch, not per epoch.
        scheduler.step()
        train_loss += loss.item()

    # Mean of the per-batch losses, reported alongside the validation loss.
    train_loss /= len(train_loader)

    # ---- validation pass ----
    model.eval()
    val_loss = 0
    correct = 0
    total = 0

    with torch.no_grad():
        for x, y in valid_loader:
            out = model(x.squeeze(1))
            loss = criterion(out, y)
            val_loss += loss.item()
            correct += (out.argmax(1) == y).sum().item()
            total += y.size(0)

    val_loss /= len(valid_loader)
    acc = correct / total

    print(
        f"Epoch {epoch}: train_loss={train_loss:.4f}, "
        f"val_loss={val_loss:.4f}, acc={acc:.4f}"
    )

    if val_loss < best_val_loss:
        # New best: reset the patience counter and checkpoint the weights.
        best_val_loss = val_loss
        no_improve = 0
        torch.save(model.state_dict(), "model.pt")
    else:
        no_improve += 1
        if no_improve >= patience:
            print("Early stopping")
            break

Test Evaluation (Per DatasetIdent)

In [None]:
# Official MNIST test split, evaluated once per augmentation combination.
test_base = datasets.MNIST(
    root=root,
    train=False,
    download=True,
    transform=transforms.ToTensor(),
)
test_idents = generate_idents(num_samples_base=None)

# Restore the checkpoint written by the training loop.
checkpoint = torch.load("model.pt")
model.load_state_dict(checkpoint)
model.eval()

with torch.no_grad():
    for ident, _size in test_idents:
        test_loader = DataLoader(ident.prepare(test_base), batch_size=batch_size)

        hits = 0
        seen = 0
        for images, labels in test_loader:
            predicted = model(images.squeeze(1)).argmax(1)
            hits += (predicted == labels).sum().item()
            seen += labels.size(0)

        print(f"{ident}: accuracy = {hits / seen:.4f}")