### Training TinyImageNet on Inception

Problem faced: the Google Colab runtime ran out of memory (RAM) and crashed.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision import transforms
import matplotlib.pyplot as plt
import os

ref: https://towardsdatascience.com/pytorch-ignite-classifying-tiny-imagenet-with-efficientnet-e5b1768e5e8f/

In [None]:
class CustomBatchNorm2d(nn.Module):
    """Hand-written batch normalization for 4-D ``(N, C, H, W)`` inputs.

    In training mode each channel is normalized with the statistics of the
    current batch and the running estimates are updated with an exponential
    moving average. In eval mode the stored running estimates are used.

    Note: the running variance is accumulated from the *biased* batch
    variance (``unbiased=False``), which differs slightly from
    ``torch.nn.BatchNorm2d`` (it stores the unbiased estimate).

    Args:
        num_features: Number of channels ``C``.
        eps: Constant added to the variance before the square root.
        momentum: Weight of the current batch in the running-average update.
    """

    def __init__(self, num_features, eps=1e-5, momentum=0.1):
        super().__init__()
        self.eps = eps
        self.momentum = momentum

        # Learnable per-channel scale and shift.
        self.gamma = nn.Parameter(torch.ones(num_features))
        self.beta = nn.Parameter(torch.zeros(num_features))

        # Running statistics are buffers: part of state_dict, not optimized.
        self.register_buffer("running_mean", torch.zeros(num_features))
        self.register_buffer("running_var", torch.ones(num_features))

    def forward(self, x):
        """Normalize ``x`` per channel and apply the learned scale and shift."""
        if self.training:
            mean = x.mean(dim=(0, 2, 3))
            var = x.var(dim=(0, 2, 3), unbiased=False)

            # Update the running statistics in place and outside autograd.
            # Re-assigning the buffers with graph-attached tensors would make
            # every step's buffer reference the previous step's graph, so the
            # graph (and memory) would grow for the whole training run.
            with torch.no_grad():
                self.running_mean.mul_(1 - self.momentum).add_(
                    mean.to(self.running_mean.dtype), alpha=self.momentum
                )
                self.running_var.mul_(1 - self.momentum).add_(
                    var.to(self.running_var.dtype), alpha=self.momentum
                )
        else:
            mean = self.running_mean
            var = self.running_var

        # Broadcast the per-channel vectors over batch and spatial dims.
        x_hat = (x - mean[None, :, None, None]) / torch.sqrt(var[None, :, None, None] + self.eps)
        out = self.gamma[None, :, None, None] * x_hat + self.beta[None, :, None, None]
        return out


In [None]:
class ConvBNAct(nn.Module):
    """Convolution, optional batch norm, then ReLU, wrapped in ``self.net``.

    Args:
        in_c: Input channels of the convolution.
        out_c: Output channels of the convolution.
        k: Kernel size.
        s: Stride.
        p: Padding.
        use_bn: When true, a ``CustomBatchNorm2d`` follows the convolution and
            the convolution is created without a bias term.
    """

    def __init__(self, in_c, out_c, k, s=1, p=0, use_bn=False):
        super().__init__()
        # The bias is dropped when a norm layer (which has its own shift) follows.
        conv = nn.Conv2d(
            in_c, out_c, kernel_size=k, stride=s, padding=p, bias=not use_bn
        )
        norm = [CustomBatchNorm2d(out_c)] if use_bn else []
        self.net = nn.Sequential(conv, *norm, nn.ReLU())

    def forward(self, x):
        """Run ``x`` through the conv / (norm) / ReLU stack."""
        return self.net(x)


class InceptionBlock(nn.Module):
    """Four parallel branches whose outputs are concatenated along channels.

    Branches (output channels): 1x1 conv (16); 1x1 -> 3x3 (24);
    1x1 -> 3x3 -> 3x3 (24); 3x3 max-pool -> 1x1 conv (16).
    All branches use stride 1 with matching padding, so the output has
    16 + 24 + 24 + 16 = 80 channels.

    Args:
        in_c: Number of input channels.
        use_bn: Forwarded to every ``ConvBNAct`` in the block.
    """

    def __init__(self, in_c, use_bn=False):
        super().__init__()

        def unit(c_in, c_out, kernel, pad=0):
            # Local shorthand so the shared use_bn flag is passed only once.
            return ConvBNAct(c_in, c_out, kernel, p=pad, use_bn=use_bn)

        # Branch 1: pointwise projection.
        self.b1 = unit(in_c, 16, 1)

        # Branch 2: pointwise reduction followed by one 3x3 conv.
        self.b2 = nn.Sequential(
            unit(in_c, 16, 1),
            unit(16, 24, 3, pad=1),
        )

        # Branch 3: pointwise reduction followed by two stacked 3x3 convs.
        self.b3 = nn.Sequential(
            unit(in_c, 16, 1),
            unit(16, 24, 3, pad=1),
            unit(24, 24, 3, pad=1),
        )

        # Branch 4: size-preserving max-pool, then pointwise projection.
        self.b4 = nn.Sequential(
            nn.MaxPool2d(3, stride=1, padding=1),
            unit(in_c, 16, 1),
        )

    def forward(self, x):
        """Apply every branch to ``x`` and concatenate on the channel axis."""
        branches = (self.b1, self.b2, self.b3, self.b4)
        return torch.cat([branch(x) for branch in branches], dim=1)


In [None]:
class TinyInception(nn.Module):
    """Small Inception-style classifier.

    Layout: 3x3 stem conv (3 -> 32) -> two ``InceptionBlock`` stages
    (32 -> 80 -> 80) -> global average pool -> linear classifier.

    Args:
        num_classes: Size of the classifier output.
        use_bn: Forwarded to the stem and both inception blocks.
    """

    def __init__(self, num_classes=200, use_bn=False):
        super().__init__()
        # A stem convolution first stabilizes the representation and denoises.
        self.stem = ConvBNAct(3, 32, 3, p=1, use_bn=use_bn)
        self.inc1 = InceptionBlock(32, use_bn=use_bn)
        self.inc2 = InceptionBlock(80, use_bn=use_bn)
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(80, num_classes)

    def forward(self, x):
        """Return the class logits for a batch of images."""
        features = self.inc2(self.inc1(self.stem(x)))
        pooled = self.pool(features)
        # Collapse the pooled (N, 80, 1, 1) map to (N, 80) for the classifier.
        flat = pooled.view(pooled.shape[0], -1)
        return self.fc(flat)


In [None]:
# Download the Tiny-ImageNet-200 archive from the Stanford CS231n server.
# (IPython shell escape: needs a notebook runtime and network access.)
!wget http://cs231n.stanford.edu/tiny-imagenet-200.zip

# Extract the downloaded archive; -qq suppresses unzip's per-file output.
!unzip -qq 'tiny-imagenet-200.zip'

# Root directory of the extracted dataset.
# NOTE(review): load_tiny_imagenet() below relies on its own default path
# instead of this constant -- confirm whether DATA_DIR is still needed.

DATA_DIR = 'tiny-imagenet-200'

--2025-11-22 03:37:40--  http://cs231n.stanford.edu/tiny-imagenet-200.zip
Resolving cs231n.stanford.edu (cs231n.stanford.edu)... 171.64.64.64
Connecting to cs231n.stanford.edu (cs231n.stanford.edu)|171.64.64.64|:80... connected.
HTTP request sent, awaiting response... 301 Moved Permanently
Location: https://cs231n.stanford.edu/tiny-imagenet-200.zip [following]
--2025-11-22 03:37:40--  https://cs231n.stanford.edu/tiny-imagenet-200.zip
Connecting to cs231n.stanford.edu (cs231n.stanford.edu)|171.64.64.64|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 248100043 (237M) [application/zip]
Saving to: ‘tiny-imagenet-200.zip’


2025-11-22 03:37:43 (85.3 MB/s) - ‘tiny-imagenet-200.zip’ saved [248100043/248100043]



In [None]:
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import os

class TinyImageNetVal(Dataset):
    """Validation split described by ``<root>/val_annotations.txt``.

    Each non-blank annotation line must start with an image file name and a
    class label (WNID); any further fields on the line are ignored. Images
    are read from ``<root>/images``.

    Args:
        root: Directory holding ``images/`` and ``val_annotations.txt``.
        class_to_idx: Mapping from label string to integer class index.
        transform: Optional callable applied to each loaded RGB image.

    Raises:
        ValueError: If a non-blank annotation line has fewer than two fields.
        KeyError: If a label is missing from ``class_to_idx``.
    """

    def __init__(self, root, class_to_idx, transform=None):
        self.transform = transform
        self.img_dir = os.path.join(root, "images")

        ann_file = os.path.join(root, "val_annotations.txt")
        self.samples = []

        with open(ann_file, "r") as f:
            for line_no, line in enumerate(f, start=1):
                fields = line.split()
                if not fields:
                    # Tolerate blank lines (e.g. a trailing newline).
                    continue
                if len(fields) < 2:
                    raise ValueError(
                        f"{ann_file}:{line_no}: expected '<image> <label> ...', "
                        f"got {line!r}"
                    )
                img, label = fields[0], fields[1]
                img_path = os.path.join(self.img_dir, img)
                # map wnid -> class index
                self.samples.append((img_path, class_to_idx[label]))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        img_path, label = self.samples[idx]
        # convert() returns a new, fully loaded image, so the underlying file
        # can be closed right away instead of waiting for garbage collection.
        with Image.open(img_path) as raw:
            img = raw.convert("RGB")

        if self.transform is not None:
            img = self.transform(img)

        return img, label


In [None]:
def load_tiny_imagenet(path="./tiny-imagenet-200", batch_size=32, num_workers=2):
    """Build the Tiny-ImageNet training and validation ``DataLoader`` objects.

    The training split is read with ``ImageFolder`` from ``<path>/train``;
    the validation split is read with ``TinyImageNetVal`` from ``<path>/val``
    and reuses the training set's label -> index mapping so both splits
    agree on class indices.

    Args:
        path: Root directory of the extracted dataset.
        batch_size: Batch size used by both loaders.
        num_workers: Worker processes used by both loaders.

    Returns:
        Tuple ``(trainloader, valloader)``; only the training loader shuffles.
    """
    transform = transforms.Compose([
        transforms.Resize((64, 64)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    trainset = ImageFolder(os.path.join(path, "train"), transform=transform)

    # Class mapping: WNID -> index
    class_to_idx = trainset.class_to_idx

    val_dir = os.path.join(path, "val")
    valset = TinyImageNetVal(val_dir, class_to_idx, transform=transform)

    # Pinned host memory only helps host-to-GPU copies; without CUDA,
    # requesting it just triggers a DataLoader warning.
    pin_memory = torch.cuda.is_available()

    trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True,
                             num_workers=num_workers, pin_memory=pin_memory)
    valloader = DataLoader(valset, batch_size=batch_size, shuffle=False,
                           num_workers=num_workers, pin_memory=pin_memory)

    return trainloader, valloader


In [None]:
def save_checkpoint(model, optimizer, scaler, epoch, best_val_acc, path="checkpoint.pth"):
    """Serialize the training state to ``path`` with ``torch.save``.

    Args:
        model: Module whose ``state_dict()`` is stored.
        optimizer: Optimizer whose ``state_dict()`` is stored.
        scaler: Object with a ``state_dict()`` method, or ``None``/falsy to
            store ``None`` instead.
        epoch: Index of the epoch that just finished.
        best_val_acc: Best validation accuracy so far.
        path: Destination file.
    """
    checkpoint = {
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
        "scaler_state": scaler.state_dict() if scaler else None,
        "epoch": epoch,
        "best_val_acc": best_val_acc,
    }

    if isinstance(path, (str, os.PathLike)):
        # Write to a sibling temp file and swap it into place, so a crash in
        # the middle of the write cannot corrupt the previous checkpoint.
        tmp_path = f"{os.fspath(path)}.tmp"
        torch.save(checkpoint, tmp_path)
        os.replace(tmp_path, path)
    else:
        # File-like destinations are handed to torch.save unchanged.
        torch.save(checkpoint, path)

    print(f"Checkpoint saved at epoch {epoch}")

def load_checkpoint(model, optimizer, scaler, path="checkpoint.pth"):
    """Restore training state written by ``save_checkpoint``.

    Args:
        model: Module that receives the stored ``model_state``.
        optimizer: Optimizer that receives the stored ``optimizer_state``.
        scaler: Optional object with ``load_state_dict``; restored only when
            both it and the stored ``scaler_state`` are truthy.
        path: Checkpoint file to read.

    Returns:
        Tuple ``(start_epoch, best_val_acc)`` where ``start_epoch`` is the
        stored epoch index plus one (the next epoch to run).
    """
    # Deserialize onto the GPU only if one exists; a hard-coded "cuda" target
    # makes torch.load fail on CPU-only machines.
    map_location = "cuda" if torch.cuda.is_available() else "cpu"
    checkpoint = torch.load(path, map_location=map_location)

    model.load_state_dict(checkpoint["model_state"])
    optimizer.load_state_dict(checkpoint["optimizer_state"])

    if scaler and checkpoint["scaler_state"]:
        scaler.load_state_dict(checkpoint["scaler_state"])

    start_epoch = checkpoint["epoch"] + 1
    best_val_acc = checkpoint["best_val_acc"]

    print(f"Resuming from epoch {start_epoch}")
    return start_epoch, best_val_acc


In [None]:
def train_model(model, trainloader, valloader, epochs=100, lr=0.015,
                resume=True, ckpt_path="checkpoint.pth"):
    """Train ``model`` with SGD, validate each epoch, and checkpoint each epoch.

    Runs on CUDA when available (with automatic mixed precision), otherwise
    on CPU in full precision.

    Args:
        model: Network to train; moved to the selected device.
        trainloader: Iterable of ``(inputs, targets)`` training batches.
        valloader: Iterable of ``(inputs, targets)`` validation batches.
        epochs: Total number of epochs (upper bound of the epoch range).
        lr: SGD learning rate (momentum is fixed at 0.9).
        resume: If true and ``ckpt_path`` exists, continue from it.
        ckpt_path: File the checkpoint is read from and written to.

    Returns:
        Tuple ``(train_acc_list, val_acc_list, epos)``: per-epoch training
        accuracy, validation accuracy and zero-based epoch indices. After a
        resume these cover only the epochs executed by this call.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # torch.cuda.amp targets CUDA only; switch it off explicitly on CPU
    # instead of relying on its warn-and-disable fallback.
    use_amp = device == "cuda"
    model = model.to(device)

    opt = optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    criterion = nn.CrossEntropyLoss()
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

    train_acc_list = []
    val_acc_list = []
    epos = []

    # ----------- LOAD CHECKPOINT -----------
    start_epoch = 0
    best_val_acc = 0.0
    if resume and os.path.exists(ckpt_path):
        # load_checkpoint returns (next_epoch, best_val_acc) and reports the
        # resume epoch itself, so no extra message is printed here.
        start_epoch, best_val_acc = load_checkpoint(model, opt, scaler, ckpt_path)
        if not isinstance(best_val_acc, (int, float)):
            # Checkpoints written by an earlier revision of this loop stored
            # the checkpoint path in this slot; fall back to a clean value.
            best_val_acc = 0.0
    # ---------------------------------------

    for epoch in range(start_epoch, epochs):
        model.train()
        correct = total = 0

        for x, y in trainloader:
            x, y = x.to(device), y.to(device)
            opt.zero_grad()

            with torch.cuda.amp.autocast(enabled=use_amp):
                out = model(x)
                loss = criterion(out, y)

            # With AMP disabled the scaler passes through to plain
            # backward() / opt.step().
            scaler.scale(loss).backward()
            scaler.step(opt)
            scaler.update()

            _, pred = out.max(1)
            correct += pred.eq(y).sum().item()
            total += y.size(0)

        # max(..., 1) keeps an empty loader from raising ZeroDivisionError.
        train_acc = correct / max(total, 1)
        train_acc_list.append(train_acc)
        epos.append(epoch)

        # Validation
        model.eval()
        correct = total = 0
        with torch.no_grad():
            for x, y in valloader:
                x, y = x.to(device), y.to(device)
                out = model(x)
                _, pred = out.max(1)
                correct += pred.eq(y).sum().item()
                total += y.size(0)

        val_acc = correct / max(total, 1)
        val_acc_list.append(val_acc)
        best_val_acc = max(best_val_acc, val_acc)

        print(f"Epoch {epoch+1}: Train Acc={train_acc:.2f} | Val Acc={val_acc:.2f}")

        # ----------- SAVE CHECKPOINT EVERY EPOCH -----------
        # Pass best_val_acc and the path in their own slots so the checkpoint
        # goes to ckpt_path and carries a real accuracy value.
        save_checkpoint(model, opt, scaler, epoch, best_val_acc, ckpt_path)
        # ---------------------------------------------------

    return train_acc_list, val_acc_list, epos


In [None]:
# Build the training and validation loaders from the default dataset path.
trainloader, valloader = load_tiny_imagenet()

In [None]:
# Batch-norm variant of the network, trained for 30 epochs.
model_bn = TinyInception(use_bn=True)

# Returned histories: training accuracy, validation accuracy, epoch indices.
acc_bn_train, acc_bn_val, ep = train_model(
    model_bn,
    trainloader,
    valloader,
    epochs=30,
    lr=0.015,
)

  scaler = torch.cuda.amp.GradScaler()
  with torch.cuda.amp.autocast():


Epoch 1: Train Acc=0.07 | Val Acc=0.11
Checkpoint saved at epoch 0
Epoch 2: Train Acc=0.13 | Val Acc=0.13
Checkpoint saved at epoch 1
Epoch 3: Train Acc=0.16 | Val Acc=0.19
Checkpoint saved at epoch 2
Epoch 4: Train Acc=0.19 | Val Acc=0.18
Checkpoint saved at epoch 3
Epoch 5: Train Acc=0.21 | Val Acc=0.23
Checkpoint saved at epoch 4
Epoch 6: Train Acc=0.22 | Val Acc=0.23
Checkpoint saved at epoch 5
Epoch 7: Train Acc=0.23 | Val Acc=0.23
Checkpoint saved at epoch 6
Epoch 8: Train Acc=0.25 | Val Acc=0.24
Checkpoint saved at epoch 7
Epoch 9: Train Acc=0.25 | Val Acc=0.26
Checkpoint saved at epoch 8
Epoch 10: Train Acc=0.26 | Val Acc=0.26
Checkpoint saved at epoch 9
Epoch 11: Train Acc=0.27 | Val Acc=0.26
Checkpoint saved at epoch 10
Epoch 12: Train Acc=0.28 | Val Acc=0.26
Checkpoint saved at epoch 11


In [None]:
# Validation-accuracy curve of the batch-norm model.
plt.figure(figsize=(10, 5))

# `ep` holds zero-based epoch indices; shift by one for a 1-based x axis.
plt.plot([epoch_idx + 1 for epoch_idx in ep], acc_bn_val, label="BN")

plt.xlabel("Epoch")
plt.ylabel("Accuracy")
plt.title("Validation Accuracy vs Epoch (Tiny-ImageNet)")
plt.grid()
plt.legend()

# Write the figure to disk, then display it.
plt.savefig("plot.png")
plt.show()
