# CIFAR-10 Model Comparison (Custom CNN vs ResNet50 vs VGG16 vs DenseNet121)
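This notebook trains a small custom CNN from scratch and fine-tunes three ImageNet-pretrained torchvision models (ResNet50, VGG16, DenseNet121) on **CIFAR-10**. Each model is kept in memory, its best-epoch weights are saved as a checkpoint, and its **predictions are visualized** on the *same sample test images* so the models can be compared side by side.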

**What you get:**
- Trained models accessible via `trained_models[...]`
- Saved weights in `checkpoints/`
- Test-accuracy curves (one line per model, one point per epoch)
- Prediction grids (ground-truth label vs. predicted label, with softmax confidence)


## 1. Imports & Setup

In [None]:
import os
import time
import copy
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import torchvision
import torchvision.transforms as transforms
from torchvision import models

# Seed both RNGs this notebook relies on (NumPy and PyTorch) with one value.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Device:", device)

# Display names for the ten CIFAR-10 categories; the visualization code
# indexes this tuple with integer labels and predicted class ids.
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)

## 2. Dataset & DataLoaders

In [None]:
# ToTensor yields float tensors in [0, 1]; Normalize then maps every channel
# to [-1, 1] via (x - 0.5) / 0.5. `_denorm` inverts exactly this mapping, so
# the two must be changed together.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5),
                         (0.5, 0.5, 0.5))
])

# The same deterministic transform is used for both splits (no augmentation).
trainset = torchvision.datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
testset  = torchvision.datasets.CIFAR10(root="./data", train=False, download=True, transform=transform)

BATCH_SIZE = 64
# Page-locked host memory only speeds up host -> GPU copies. Requesting it on
# a CPU-only machine does nothing useful and makes recent PyTorch versions
# emit a warning, so enable it only when CUDA is actually available.
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2,
    pin_memory=torch.cuda.is_available(),
)
# The test loader is not shuffled, so its first batch is always the same
# images; the prediction grids rely on that.
testloader = torch.utils.data.DataLoader(
    testset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2,
    pin_memory=torch.cuda.is_available(),
)

print("Train size:", len(trainset), "| Test size:", len(testset))

## 3. Model Definitions

In [None]:
class CustomCNN(nn.Module):
    """Small two-conv / three-linear classifier for 3x32x32 images.

    Args:
        num_classes: Width of the final linear layer (number of logits).
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # Convolutional stem. The single max-pool module has no parameters
        # and is applied after each convolution in forward().
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)

        # Spatial size for a 32x32 input:
        #   32 -conv5-> 28 -pool-> 14 -conv5-> 10 -pool-> 5
        # so the flattened feature vector has 16 channels * 5 * 5 = 400 values.
        flat_features = 16 * 5 * 5
        self.fc1 = nn.Linear(in_features=flat_features, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=num_classes)

    def forward(self, x):
        """Return raw (un-normalized) class logits for a batch of images."""
        feats = self.pool(F.relu(self.conv1(x)))
        feats = self.pool(F.relu(self.conv2(feats)))

        # Keep the batch dimension, collapse channels and spatial dims.
        hidden = feats.flatten(start_dim=1)
        hidden = F.relu(self.fc1(hidden))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)

def build_resnet50(num_classes=10):
    """Build a torchvision ResNet-50 with its default pretrained weights.

    The final `fc` layer is swapped for a freshly initialised linear layer
    with `num_classes` outputs; all other layers keep the loaded weights.
    """
    net = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
    head_in = net.fc.in_features
    net.fc = nn.Linear(head_in, num_classes)
    return net

def build_vgg16(num_classes=10):
    """Build a torchvision VGG-16 with its default pretrained weights.

    Entry 6 of `classifier` (the output layer) is replaced by a freshly
    initialised linear layer producing `num_classes` logits.
    """
    net = models.vgg16(weights=models.VGG16_Weights.DEFAULT)
    old_head = net.classifier[6]
    net.classifier[6] = nn.Linear(old_head.in_features, num_classes)
    return net

def build_densenet121(num_classes=10):
    """Build a torchvision DenseNet-121 with its default pretrained weights.

    The `classifier` layer is replaced by a freshly initialised linear layer
    producing `num_classes` logits.
    """
    net = models.densenet121(weights=models.DenseNet121_Weights.DEFAULT)
    head_in = net.classifier.in_features
    net.classifier = nn.Linear(head_in, num_classes)
    return net

## 4. Training & Evaluation Utilities

In [None]:
def train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over `loader`.

    Args:
        model: Network to update; switched to training mode.
        loader: Iterable of `(inputs, labels)` batches.
        criterion: Loss callable applied as `criterion(outputs, labels)`.
        optimizer: Optimizer stepped once per batch.
        device: Device the batches are moved to before the forward pass.

    Returns:
        `(mean_loss, accuracy)` over every sample seen this epoch.
    """
    model.train()
    loss_sum, n_correct, n_seen = 0.0, 0, 0

    for batch_x, batch_y in loader:
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)

        optimizer.zero_grad()
        logits = model(batch_x)
        batch_loss = criterion(logits, batch_y)
        batch_loss.backward()
        optimizer.step()

        # Weight the batch loss by batch size so the final division gives a
        # per-sample mean even when the last batch is smaller.
        loss_sum += batch_loss.item() * batch_x.size(0)
        predicted = logits.max(dim=1).indices
        n_correct += (predicted == batch_y).sum().item()
        n_seen += batch_y.size(0)

    return loss_sum / n_seen, n_correct / n_seen


@torch.no_grad()
def evaluate(model, loader, criterion, device):
    """Score `model` on `loader` without computing gradients.

    Args:
        model: Network to score; switched to evaluation mode.
        loader: Iterable of `(inputs, labels)` batches.
        criterion: Loss callable applied as `criterion(outputs, labels)`.
        device: Device the batches are moved to before the forward pass.

    Returns:
        `(mean_loss, accuracy)` over every sample in `loader`.
    """
    model.eval()
    loss_sum, n_correct, n_seen = 0.0, 0, 0

    for batch_x, batch_y in loader:
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)

        logits = model(batch_x)
        batch_loss = criterion(logits, batch_y)

        # Size-weighted so the final division yields a per-sample mean.
        loss_sum += batch_loss.item() * batch_x.size(0)
        predicted = logits.max(dim=1).indices
        n_correct += (predicted == batch_y).sum().item()
        n_seen += batch_y.size(0)

    return loss_sum / n_seen, n_correct / n_seen


def fit_model(model, model_name, trainloader, testloader, device, epochs=5, lr=0.001, momentum=0.9):
    """Train `model` with SGD and return it with its best-scoring weights.

    After every epoch the model is scored on `testloader`; a copy of the
    weights is kept whenever the accuracy strictly improves, and those
    weights are loaded back into the model before returning.

    NOTE(review): the "best" epoch is selected on `testloader`, so the
    returned `best_acc` is an optimistic estimate if that loader is the real
    test split. Pass a held-out validation loader for unbiased reporting.

    Args:
        model: Network to train; moved to `device`.
        model_name: Label used only in the per-epoch progress line.
        trainloader: Batches used for optimisation.
        testloader: Batches used for per-epoch scoring and model selection.
        device: Device to train on.
        epochs: Number of passes over `trainloader`.
        lr: SGD learning rate.
        momentum: SGD momentum.

    Returns:
        `(model, history, best_acc, best_epoch)`. `history` maps
        "train_loss", "train_acc", "test_loss" and "test_acc" to per-epoch
        lists. If no epoch ran (`epochs <= 0`) the model is returned
        untouched with `best_acc == -1.0` and `best_epoch == -1`.
    """
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=lr, momentum=momentum)

    history = {"train_loss": [], "train_acc": [], "test_loss": [], "test_acc": []}

    best_acc = -1.0
    best_state = None
    best_epoch = -1

    for epoch in range(1, epochs + 1):
        t0 = time.time()

        tr_loss, tr_acc = train_one_epoch(model, trainloader, criterion, optimizer, device)
        te_loss, te_acc = evaluate(model, testloader, criterion, device)

        history["train_loss"].append(tr_loss)
        history["train_acc"].append(tr_acc)
        history["test_loss"].append(te_loss)
        history["test_acc"].append(te_acc)

        if te_acc > best_acc:
            best_acc = te_acc
            # Deep copy: state_dict() returns references to the live tensors,
            # which later epochs would otherwise overwrite in place.
            best_state = copy.deepcopy(model.state_dict())
            best_epoch = epoch

        dt = time.time() - t0
        print(f"[{model_name}] Epoch {epoch:02d}/{epochs} | "
              f"Train: loss={tr_loss:.4f}, acc={tr_acc:.4f} | "
              f"Test: loss={te_loss:.4f}, acc={te_acc:.4f} | "
              f"{dt:.1f}s")

    # With epochs <= 0 the loop never runs and there is no snapshot to
    # restore; load_state_dict(None) would raise, so skip the restore.
    if best_state is not None:
        model.load_state_dict(best_state)
    return model, history, best_acc, best_epoch

## 5. Visualization Utilities (Predictions on Sample Images)

In [None]:
def _denorm(img_tensor):
    """Undo Normalize(mean=0.5, std=0.5): map values from [-1, 1] back to [0, 1]."""
    std, mean = 0.5, 0.5
    return img_tensor * std + mean

@torch.no_grad()
def visualize_model_predictions(model, model_name, testloader, device, n_images=8, fixed_batch=None):
    """Plot a grid of images titled with ground truth, prediction and confidence.

    Args:
        model: Trained network; switched to evaluation mode.
        model_name: Text used as the figure's super-title.
        testloader: Source of a batch when `fixed_batch` is not given (its
            first batch is used).
        device: Device the images are moved to for the forward pass.
        n_images: Maximum number of images to draw.
        fixed_batch: Optional `(images, labels)` pair to draw from instead of
            `testloader`, so several models can be shown on the same images.

    Returns:
        None. Shows a matplotlib figure; draws nothing when there are no
        images to show.
    """
    model.eval()

    if fixed_batch is None:
        images, labels = next(iter(testloader))
    else:
        images, labels = fixed_batch

    # Only the images that are actually drawn go through the model; the rest
    # of the batch would be wasted compute and device memory.
    n_shown = min(n_images, images.shape[0])
    if n_shown <= 0:
        return
    images = images[:n_shown].to(device)

    outputs = model(images)
    probs = F.softmax(outputs, dim=1)
    confs, preds = torch.max(probs, 1)

    # Undo the input normalisation and clamp so imshow gets values in [0, 1].
    images_vis = _denorm(images.detach().cpu()).clamp(0, 1).numpy()

    # Size the grid from the number of images really shown, so a short batch
    # does not leave empty rows in the figure.
    cols = 4
    rows = int(np.ceil(n_shown / cols))
    plt.figure(figsize=(cols * 4, rows * 4))

    per_image = zip(images_vis, labels[:n_shown].tolist(), preds.tolist(), confs.tolist())
    for position, (img, gt_id, pred_id, cf) in enumerate(per_image, start=1):
        plt.subplot(rows, cols, position)
        # Tensors are channel-first (C, H, W); imshow expects (H, W, C).
        plt.imshow(np.transpose(img, (1, 2, 0)))
        gt = classes[gt_id]
        pr = classes[pred_id]
        plt.title(f"GT: {gt}\nPred: {pr} ({cf:.2f})", fontsize=10)
        plt.axis("off")

    plt.suptitle(model_name, fontsize=14)
    plt.tight_layout()
    plt.show()

## 6. Train All Models (Store + Save + Visualize)
This section:
- Trains each model
- Stores it in `trained_models`
- Saves weights in `checkpoints/`
- Visualizes predictions on the **same** sample batch for fair comparison


In [None]:
# Training hyperparameters shared by every model in the comparison.
EPOCHS, LR, MOMENTUM = 5, 0.001, 0.9

# Directory that receives one saved state_dict per trained model.
os.makedirs("checkpoints", exist_ok=True)

# Per-model results: the models and their histories are keyed by display
# name; `summary` collects one [name, best_acc, best_epoch] row per model.
trained_models, histories, summary = {}, {}, []

# Take the first test batch once and reuse it for every model's prediction
# grid, so all models are shown on the SAME images (fair comparison).
fixed_batch = tuple(next(iter(testloader)))
fixed_images, fixed_labels = fixed_batch

### 6.1 Train Custom CNN

In [None]:
# Baseline trained from scratch. fit_model returns the model with its
# best-scoring epoch's weights already loaded.
model_name = "Custom CNN"
model = CustomCNN(num_classes=10)

model, history, best_acc, best_epoch = fit_model(
    model=model,
    model_name=model_name,
    trainloader=trainloader,
    testloader=testloader,
    device=device,
    epochs=EPOCHS,
    lr=LR,
    momentum=MOMENTUM,
)

# Keep the model and its curves in memory, then persist the weights.
trained_models[model_name], histories[model_name] = model, history
torch.save(model.state_dict(), "checkpoints/custom_cnn_cifar10.pth")

# Record the headline numbers and draw predictions on the shared batch.
summary.append([model_name, best_acc, best_epoch])
visualize_model_predictions(
    model, model_name, testloader, device,
    n_images=8, fixed_batch=fixed_batch,
)

### 6.2 Train ResNet50

In [None]:
# Fine-tune the pretrained ResNet-50. fit_model returns the model with its
# best-scoring epoch's weights already loaded.
model_name = "ResNet50"
model = build_resnet50(num_classes=10)

model, history, best_acc, best_epoch = fit_model(
    model=model,
    model_name=model_name,
    trainloader=trainloader,
    testloader=testloader,
    device=device,
    epochs=EPOCHS,
    lr=LR,
    momentum=MOMENTUM,
)

# Keep the model and its curves in memory, then persist the weights.
trained_models[model_name], histories[model_name] = model, history
torch.save(model.state_dict(), "checkpoints/resnet50_cifar10.pth")

# Record the headline numbers and draw predictions on the shared batch.
summary.append([model_name, best_acc, best_epoch])
visualize_model_predictions(
    model, model_name, testloader, device,
    n_images=8, fixed_batch=fixed_batch,
)

### 6.3 Train VGG16

In [None]:
# Fine-tune the pretrained VGG-16. fit_model returns the model with its
# best-scoring epoch's weights already loaded.
model_name = "VGG16"
model = build_vgg16(num_classes=10)

model, history, best_acc, best_epoch = fit_model(
    model=model,
    model_name=model_name,
    trainloader=trainloader,
    testloader=testloader,
    device=device,
    epochs=EPOCHS,
    lr=LR,
    momentum=MOMENTUM,
)

# Keep the model and its curves in memory, then persist the weights.
trained_models[model_name], histories[model_name] = model, history
torch.save(model.state_dict(), "checkpoints/vgg16_cifar10.pth")

# Record the headline numbers and draw predictions on the shared batch.
summary.append([model_name, best_acc, best_epoch])
visualize_model_predictions(
    model, model_name, testloader, device,
    n_images=8, fixed_batch=fixed_batch,
)

### 6.4 Train DenseNet121

In [None]:
# Fine-tune the pretrained DenseNet-121. fit_model returns the model with
# its best-scoring epoch's weights already loaded.
model_name = "DenseNet121"
model = build_densenet121(num_classes=10)

model, history, best_acc, best_epoch = fit_model(
    model=model,
    model_name=model_name,
    trainloader=trainloader,
    testloader=testloader,
    device=device,
    epochs=EPOCHS,
    lr=LR,
    momentum=MOMENTUM,
)

# Keep the model and its curves in memory, then persist the weights.
trained_models[model_name], histories[model_name] = model, history
torch.save(model.state_dict(), "checkpoints/densenet121_cifar10.pth")

# Record the headline numbers and draw predictions on the shared batch.
summary.append([model_name, best_acc, best_epoch])
visualize_model_predictions(
    model, model_name, testloader, device,
    n_images=8, fixed_batch=fixed_batch,
)

## 7. Plot Accuracy Curves

In [None]:
# One test-accuracy curve per trained model, all on the same axes.
plt.figure(figsize=(10, 6))
for name, h in histories.items():
    test_acc = h["test_acc"]
    # Take the x-axis from the recorded curve rather than the global EPOCHS:
    # if EPOCHS was edited after training, or a model was trained for a
    # different number of epochs, the lengths differ and plot() would raise.
    plt.plot(range(1, len(test_acc) + 1), test_acc, marker="o", label=name)

plt.xlabel("Epoch")
plt.ylabel("Test Accuracy")
plt.title("Model Comparison on CIFAR-10 (Test Accuracy)")
plt.legend()
plt.grid(True)
plt.show()

## 8. Summary Table

In [None]:
import pandas as pd

# One row per trained model, in training order: display name, best test
# accuracy, and the 1-based epoch at which that accuracy was reached.
df = pd.DataFrame(
    data=summary,
    columns=["Model", "Best Test Acc", "Best Epoch"],
)
df