In [39]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm
import time

In [40]:
# Select the compute device once: the GPU when PyTorch can see one, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Device:", device)

cuda


In [41]:
# Seed the global PyTorch and NumPy random generators with the same value.
# NumPy is seeded last so the cell's final expression evaluates to None (no Out[] display).
seed = 42
for seed_fn in (torch.manual_seed, np.random.seed):
    seed_fn(seed)

In [42]:
# --- Shared preprocessing ---
# Both models in this notebook are created with ImageNet-pretrained weights
# ("IMAGENET1K_V1"), so inputs are normalised with the ImageNet channel statistics
# those weights were trained with rather than a generic (0.5, 0.5, 0.5).
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)
IMAGE_SIZE = 224  # CIFAR-10 images are upscaled to this side length


def build_cifar_transform(train):
    """Return the preprocessing pipeline applied to each CIFAR-10 image.

    Args:
        train: when True, a random horizontal flip is inserted after the resize
            as data augmentation; when False the pipeline is deterministic.

    Returns:
        A ``transforms.Compose`` that resizes, optionally flips, converts the
        image to a tensor and normalises it with the ImageNet mean/std.
    """
    steps = [transforms.Resize(IMAGE_SIZE)]
    if train:
        steps.append(transforms.RandomHorizontalFlip())
    steps.append(transforms.ToTensor())
    steps.append(transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD))
    return transforms.Compose(steps)


# --- CNN Transforms ---
transform_train_cnn = build_cifar_transform(train=True)
transform_test_cnn = build_cifar_transform(train=False)

# --- ViT Transforms ---
# Same recipe as the CNN; separate objects are kept so either can be tuned independently.
transform_train_vit = build_cifar_transform(train=True)
transform_test_vit = build_cifar_transform(train=False)

# CIFAR-10 Datasets
train_dataset_cnn = datasets.CIFAR10(root="./data", train=True, download=True, transform=transform_train_cnn)
test_dataset_cnn = datasets.CIFAR10(root="./data", train=False, download=True, transform=transform_test_cnn)

train_dataset_vit = datasets.CIFAR10(root="./data", train=True, download=True, transform=transform_train_vit)
test_dataset_vit = datasets.CIFAR10(root="./data", train=False, download=True, transform=transform_test_vit)

# DataLoaders (only the training loaders shuffle; the ViT uses a smaller batch size)
train_loader_cnn = DataLoader(train_dataset_cnn, batch_size=128, shuffle=True, num_workers=4, pin_memory=True)
test_loader_cnn = DataLoader(test_dataset_cnn, batch_size=128, shuffle=False, num_workers=4, pin_memory=True)

train_loader_vit = DataLoader(train_dataset_vit, batch_size=64, shuffle=True, num_workers=4, pin_memory=True)
test_loader_vit = DataLoader(test_dataset_vit, batch_size=64, shuffle=False, num_workers=4, pin_memory=True)

classes = train_dataset_cnn.classes
print("Classes:", classes)

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Classes: ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']


In [43]:
# CNN: ImageNet-pretrained VGG16 with a replaced final classifier layer
def get_vgg16_model(num_classes=10, freeze_backbone=True):
    """Build a pretrained VGG16 whose last classifier layer outputs ``num_classes`` logits.

    Args:
        num_classes: number of output units of the new final linear layer.
        freeze_backbone: when True, every parameter except those of the new
            final layer has ``requires_grad`` set to False.

    Returns:
        The modified ``torchvision`` VGG16 model.
    """
    model = models.vgg16(weights="IMAGENET1K_V1")

    # Swap the last classifier layer for a freshly initialised one of the right width.
    replaced_layer = model.classifier[6]
    model.classifier[6] = nn.Linear(replaced_layer.in_features, num_classes)

    if not freeze_backbone:
        return model

    # Freeze everything, then re-enable gradients for the new layer only.
    model.requires_grad_(False)
    model.classifier[6].requires_grad_(True)
    return model

# ViT pretrained head
from torchvision.models import vit_b_16
def get_vit_model(num_classes=10, freeze_backbone=True):
    """Build a pretrained ViT-B/16 whose head outputs ``num_classes`` logits.

    Args:
        num_classes: number of output units of the new head.
        freeze_backbone: when True, only the new head keeps ``requires_grad=True``;
            every other parameter is frozen.

    Returns:
        The modified ``torchvision`` ViT-B/16 model.
    """
    model_vit = vit_b_16(weights="IMAGENET1K_V1")

    # Replace the head with a fresh linear layer of matching input width.
    head_width = model_vit.heads.head.in_features
    model_vit.heads.head = nn.Linear(head_width, num_classes)

    if freeze_backbone:
        # Turn gradients off everywhere, then back on for the new head.
        model_vit.requires_grad_(False)
        model_vit.heads.head.requires_grad_(True)
    return model_vit


In [44]:
from tqdm import tqdm

def run_epoch(model, loader, criterion, optimizer=None, device=None, desc=""):
    """Run one pass over ``loader``, training when an optimizer is supplied.

    Args:
        model: module called as ``model(images)`` to produce per-class scores.
        loader: iterable yielding ``(images, labels)`` batches.
        criterion: callable ``criterion(outputs, labels)`` returning a scalar loss.
            The per-sample average below assumes it returns the mean over the
            batch (the default reduction of ``nn.CrossEntropyLoss``).
        optimizer: when given, the model is put in training mode and updated after
            every batch; when None, the model is put in eval mode and gradients
            are disabled.
        device: target passed to ``Tensor.to`` for each batch.
        desc: label shown on the progress bar.

    Returns:
        ``(avg_loss, avg_acc)``: the loss averaged over samples and the fraction of
        correct predictions. Both are ``nan`` when the loader yields no samples.
    """
    is_train = optimizer is not None
    model.train(is_train)  # train(False) is what eval() does

    loss_sum = 0.0  # sum of per-sample losses (batch mean * batch size)
    correct, total = 0, 0

    pbar = tqdm(loader, desc=desc, leave=False)
    for images, labels in pbar:
        images, labels = images.to(device), labels.to(device)
        batch_size = labels.size(0)

        if is_train:
            optimizer.zero_grad()

        with torch.set_grad_enabled(is_train):
            outputs = model(images)
            loss = criterion(outputs, labels)

            if is_train:
                loss.backward()
                optimizer.step()

        batch_loss = loss.item()
        # Weight by batch size so a smaller final batch does not skew the epoch mean.
        loss_sum += batch_loss * batch_size
        total += batch_size
        correct += (outputs.argmax(dim=1) == labels).sum().item()

        pbar.set_postfix(
            loss=f"{batch_loss:.4f}",
            acc=f"{100*correct/max(total, 1):.2f}%"
        )

    if total == 0:
        # Nothing was evaluated; report "undefined" instead of dividing by zero.
        return float("nan"), float("nan")

    return loss_sum / total, correct / total

In [45]:
def train_model(model, train_loader, val_loader, epochs, lr, device, experiment_name):
    """Train ``model`` for ``epochs`` epochs and record per-epoch metrics.

    Args:
        model: network to train; it is moved to ``device`` first.
        train_loader: batches used for the optimisation pass of each epoch.
        val_loader: batches used for the gradient-free evaluation pass of each epoch.
        epochs: number of train + evaluation rounds to run.
        lr: learning rate for the Adam optimizer.
        device: device the model and every batch are moved to.
        experiment_name: label printed in the start/finish banners.

    Returns:
        dict with the lists ``train_loss``, ``val_loss``, ``train_acc``,
        ``val_acc`` and ``epoch_time`` (seconds), one entry per epoch.

    Raises:
        ValueError: raised by the optimizer if no parameter of ``model`` has
            ``requires_grad=True``.
    """
    print(f"\n===== Starting experiment: {experiment_name} =====\n")

    model = model.to(device)
    criterion = nn.CrossEntropyLoss()

    # Optimise exactly the parameters left trainable by the model factory. This covers
    # the frozen-backbone case (head only) and an unfrozen model (all parameters)
    # without guessing the architecture from attribute names.
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(trainable_params, lr=lr)

    history = {
        "train_loss": [],
        "val_loss": [],
        "train_acc": [],
        "val_acc": [],
        "epoch_time": []
    }

    for epoch in range(epochs):
        # perf_counter is monotonic, so the elapsed time cannot be distorted by clock changes.
        start_time = time.perf_counter()

        # --- Training ---
        train_loss, train_acc = run_epoch(
            model,
            train_loader,
            criterion,
            optimizer,
            device,
            desc=f"Training Epoch {epoch+1}/{epochs}"
        )

        # --- Validation ---
        val_loss, val_acc = run_epoch(
            model,
            val_loader,
            criterion,
            optimizer=None,
            device=device,
            desc=f"Validation Epoch {epoch+1}/{epochs}"
        )

        epoch_time = time.perf_counter() - start_time

        history["train_loss"].append(train_loss)
        history["val_loss"].append(val_loss)
        history["train_acc"].append(train_acc)
        history["val_acc"].append(val_acc)
        history["epoch_time"].append(epoch_time)

        print(
            f"Epoch {epoch+1}/{epochs} | "
            f"TrainLoss: {train_loss:.4f} | TrainAcc: {train_acc*100:.2f}% | "
            f"ValLoss: {val_loss:.4f} | ValAcc: {val_acc*100:.2f}% | "
            f"Time: {epoch_time:.1f}s"
        )

    print(f"\n===== Finished experiment: {experiment_name} =====\n")
    return history

In [46]:
# Train each architecture from its pretrained weights once per epoch budget and
# collect the per-epoch histories in `results`, keyed by "<arch>_<epochs>ep".
results = {}
epochs_list = [3, 5, 10]

# One row per architecture:
# (name prefix, model factory, training loader, evaluation loader, learning rate)
experiment_specs = [
    ("VGG16", get_vgg16_model, train_loader_cnn, test_loader_cnn, 1e-3),
    ("ViT", get_vit_model, train_loader_vit, test_loader_vit, 3e-4),
]
last_trained = {}  # most recently trained model of each architecture

for arch_name, build_model, arch_train_loader, arch_eval_loader, arch_lr in experiment_specs:
    arch_model = None
    for ep in epochs_list:
        exp_name = f"{arch_name}_{ep}ep"
        # train_model prints the start/finish banners itself, so none is printed
        # here (the banner used to appear twice per experiment).
        arch_model = build_model()  # fresh pretrained model for every epoch budget
        history = train_model(
            arch_model,
            arch_train_loader,
            arch_eval_loader,
            epochs=ep,
            lr=arch_lr,
            device=device,
            experiment_name=exp_name
        )
        results[exp_name] = history
    last_trained[arch_name] = arch_model

# Keep the original module-level names pointing at the last model trained per architecture.
vgg_model = last_trained["VGG16"]
vit_model = last_trained["ViT"]