### The CIFAR-10 dataset
The CIFAR-10 dataset consists of 60000 32x32 colour images in 10 classes, with 6000 images per class. There are 50000 training images and 10000 test images.

The dataset is divided into five training batches and one test batch, each with 10000 images. The test batch contains exactly 1000 randomly-selected images from each class. The training batches contain the remaining images in random order, but some training batches may contain more images from one class than another. Between them, the training batches contain exactly 5000 images from each class.

The classes are completely mutually exclusive. There is no overlap between automobiles and trucks. "Automobile" includes sedans, SUVs, things of that sort. "Truck" includes only big trucks. Neither includes pickup trucks.

In [None]:
# Suppress UserWarnings raised from scikit-learn modules so they do not clutter
# the notebook output.
import warnings

warnings.filterwarnings(
    "ignore",
    category=UserWarning,
    module=".*sklearn.*",
)

## EDA
Preliminary exploratory data analysis (currently commented out); it can be extended later.

In [None]:
# import torch
# import torchvision
# import torchvision.transforms as transforms
# import matplotlib.pyplot as plt
# import pandas as pd
# from sklearn.manifold import TSNE # Visual cluster structure—see how classes separate in 2D
# import numpy as np

# # 1. Load CIFAR-10
# transform = transforms.Compose([transforms.ToTensor()])
# trainset = torchvision.datasets.CIFAR10(
#     root='./data', train=True, download=True, transform=transform)
# classes = trainset.classes
# targets = torch.tensor(trainset.targets)

# # 2. Class counts => Ensures your data is balanced (it is, *5,000* images per class).
# class_counts = [(classes[i], int((targets == i).sum().item()))
#                 for i in range(len(classes))]
# df_counts = pd.DataFrame(class_counts, columns=['Class', 'Count'])
# print(df_counts)

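# # 3. Compute per-channel mean & std => Useful for normalization before training.
# # NOTE: averaging per-image stds (as done below) underestimates the dataset-wide
# # per-channel std because it ignores variation between images; compute the std over all pixels of a channel instead.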
# loader = torch.utils.data.DataLoader(trainset, batch_size=5000, shuffle=False, num_workers=2)
# mean = 0.
# std = 0.
# n_samples = 0
# for data, _ in loader:
#     bs = data.size(0)
#     data = data.view(bs, data.size(1), -1)
#     mean += data.mean(2).sum(0)
#     std  += data.std(2).sum(0)
#     n_samples += bs
# mean /= n_samples
# std  /= n_samples
# print(f"Mean per channel: {mean.tolist()}")
# print(f"Std  per channel: {std.tolist()}")

# # 4. Plot Class Distribution
# plt.figure(figsize=(8,4))
# plt.bar(df_counts['Class'], df_counts['Count'])
# plt.xticks(rotation=45)
# plt.title("CIFAR-10 Class Distribution")
# plt.tight_layout()
# plt.show()

# # 5. Pixel-value Histograms per Channel
# # Check for any channel biases or artifacts.
# # Stack a subset of images to save time
# subset = torch.stack([trainset[i][0] for i in range(10000)])  # 10k images
# for ch, col in enumerate(['Red', 'Green', 'Blue']):
#     plt.figure()
#     plt.hist(subset[:, ch, :, :].numpy().ravel(), bins=50)
#     plt.title(f"{col} Channel Histogram (10k samples)")
#     plt.xlabel("Pixel value")
#     plt.ylabel("Frequency")
#     plt.tight_layout()
#     plt.show()

# # 6. t-SNE Embedding (flattened, on 2000 random images) => t-distributed Stochastic Neighbor Embedding
# np.random.seed(42)
# idxs = np.random.choice(len(trainset), 2000, replace=False)
# data_flat = np.stack([trainset[i][0].numpy().ravel() for i in idxs])
# labels   = [trainset[i][1] for i in idxs]
# tsne = TSNE(n_components=2, perplexity=30, random_state=0)
# emb = tsne.fit_transform(data_flat)

# plt.figure(figsize=(6,6))
# scatter = plt.scatter(emb[:,0], emb[:,1], c=labels, alpha=0.6, cmap='tab10')
# plt.legend(handles=scatter.legend_elements()[0], labels=classes, bbox_to_anchor=(1.05,1))
# plt.title("t-SNE of CIFAR-10 (2k samples)")
# plt.tight_layout()
# plt.show()

# # 7. Sample Grid => A quick sanity check of raw images. 
# fig, axes = plt.subplots(4,8, figsize=(12,6))
# for ax in axes.flatten():
#     i = torch.randint(len(trainset), (1,)).item()
#     img, lbl = trainset[i]
#     img = img.permute(1,2,0).numpy()
#     ax.imshow(img)
#     ax.set_title(classes[lbl], fontsize=8)
#     ax.axis('off')
# plt.tight_layout()
# plt.show()

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
import time
import copy
import os
from tqdm import tqdm

In [None]:
# Create directories for saving results
os.makedirs("results", exist_ok=True)
os.makedirs("visualizations", exist_ok=True)

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

## ADJUSTABLE PARAMETERS
Try changing the following parameters to see how they affect the results:

In [None]:
# --------------------------------------------------------------------------
# Optimisation
# --------------------------------------------------------------------------
BATCH_SIZE = 128          # mini-batch size for both the train and test loaders
NUM_EPOCHS = 5            # number of full passes over the training set
LEARNING_RATE = 0.01      # base SGD learning rate
WEIGHT_DECAY = 5e-4       # L2 penalty handed to the SGD optimizer
MOMENTUM = 0.9            # SGD momentum factor

# --------------------------------------------------------------------------
# Which architectures to train (True = train and evaluate that model)
# --------------------------------------------------------------------------
TRAIN_ALEXNET = True      # AlexNetCIFAR
TRAIN_VGG16 = True        # VGG16CIFAR (no batch norm)
TRAIN_VGG16_BN = True     # VGG16BN (batch norm after every conv)
TRAIN_VGG8 = True         # VGG8 (shallower VGG variant)

# --------------------------------------------------------------------------
# Classifier-head settings
# --------------------------------------------------------------------------
ALEXNET_FC_SIZE = 4096    # width of AlexNet's first fully connected layer (4096 as in the paper)
VGG_FC_SIZE = 4096        # width of each hidden fully connected layer in the VGG models (4096 as in the paper)
USE_DROPOUT = True        # insert dropout layers into the classifier heads
DROPOUT_RATE = 0.5        # probability of zeroing an activation in those dropout layers

# --------------------------------------------------------------------------
# Input pipeline (augmentation only affects the training split;
# NORMALIZE_DATA applies to both splits)
# --------------------------------------------------------------------------
USE_DATA_AUGMENTATION = True   # master switch for the two options below
HORIZONTAL_FLIP = True         # random left-right flip
RANDOM_CROP = True             # pad by 4 px, then random 32x32 crop
NORMALIZE_DATA = True          # per-channel mean/std normalisation

# --------------------------------------------------------------------------
# Which figures to write to ./visualizations
# --------------------------------------------------------------------------
SAVE_FILTERS = True            # first-layer filter grid
SAVE_FEATURE_MAPS = True       # first-layer activation grid
SAVE_TRAINING_CURVES = True    # accuracy/loss curves and time per epoch

| Parameter                                                       | What it does                                                                                   | Typical effects of changing it                                                                                                                                            |
| --------------------------------------------------------------- | ---------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| **BATCH\_SIZE**                                                 | Number of samples processed before the model’s weights are updated.                            | • Larger batch → smoother gradient estimates, faster GPU utilization, but more memory.<br>• Smaller batch → noisier updates, can generalize better, but slower per‑epoch. |
| **NUM\_EPOCHS**                                                 | How many full passes over the training set.                                                    | • More epochs → model can learn more, but risk overfitting.<br>• Fewer epochs → faster run but may underfit.                                                              |
| **LEARNING\_RATE**                                              | Step size for the optimizer when updating weights (controls how “big” each update is).         | • High LR → faster initial learning but risk of divergence.<br>• Low LR → stable convergence but slower training.                                                         |
| **WEIGHT\_DECAY**                                               | L2‐regularization coefficient (penalizes large weights).                                       | • Higher → stronger regularization (helps prevent overfitting).<br>• Lower → less regularization (can fit data better, but risk overfit).                                 |
| **MOMENTUM**                                                    | In SGD, momentum keeps a fraction of the previous update to smooth and accelerate convergence. | • High (e.g. 0.9+) → faster convergence and can escape shallow minima.<br>• Low (e.g. 0.5) → more responsive to recent gradients but noisier.                             |
| **TRAIN\_ALEXNET, TRAIN\_VGG16, TRAIN\_VGG16\_BN, TRAIN\_VGG8** | Booleans to enable/disable training of each architecture.                                      | Turn each model on/off if you only want to run a subset of experiments.                                                                                                   |
| **ALEXNET\_FC\_SIZE**                                           | Number of neurons in AlexNet’s first fully‑connected (FC) layer.                               | • Larger → more capacity, but more parameters and risk of overfitting.<br>• Smaller → faster training, fewer parameters.                                                  |
| **VGG\_FC\_SIZE**                                               | Number of neurons in each of VGG’s hidden FC layers (not the output layer).                    | Similar trade‑off as ALEXNET\_FC\_SIZE for VGG architectures.                                                                                                             |
| **USE\_DROPOUT**                                                | Whether to include dropout layers in the classifier head.                                      | • True → adds regularization by randomly zeroing activations during training (dropout is inactive in eval mode).<br>• False → no dropout (less regularization, higher risk of overfitting). |
| **DROPOUT\_RATE**                                               | Probability that each neuron is “dropped” (set to zero) during training.                       | • Higher (e.g. 0.7) → stronger regularization.<br>• Lower (e.g. 0.3) → milder regularization.                                                                             |
| **USE\_DATA\_AUGMENTATION**                                     | Master switch for whether to apply random crops & flips to training images.                    | • True → generally improves generalization by exposing model to varied views.<br>• False → training on only original images.                                              |
| **HORIZONTAL\_FLIP**                                            | If augmenting, whether to randomly flip images left–right.                                     | • Often safe for natural images (e.g. cars, animals).                                                                                                                     |
| **RANDOM\_CROP**                                                | If augmenting, whether to pad each image by 4 pixels and take a random 32×32 crop.            | • Encourages robustness to small translations.                                                                                                                            |
| **NORMALIZE\_DATA**                                             | Whether to subtract the dataset mean and divide by the std per channel.                        | • Almost always recommended for stable / faster convergence.<br>• When off, the network receives raw pixel values scaled to [0, 1] by `ToTensor`.                         |


```
Parameter                | Example values to try
BATCH_SIZE               | 64, 128, 256         
NUM_EPOCHS               | 40, 60, 80           
LEARNING_RATE            | 0.1, 0.01, 0.001     
WEIGHT_DECAY             | 1e-3, 5e-4, 1e-4     
MOMENTUM                 | 0.9, 0.95, 0.99      
ALEXNET_FC_SIZE          | 4096, 2048, 1024     
VGG_FC_SIZE              | 4096, 2048, 1024     
DROPOUT_RATE             | 0.3, 0.5, 0.7        
USE_DROPOUT              | True vs. False       
USE_DATA_AUGMENTATION    | True vs. False       
HORIZONTAL_FLIP          | True vs. False       
RANDOM_CROP              | True vs. False       
NORMALIZE_DATA           | True vs. False       
```

## DATA LOADING AND PREPROCESSING

In [None]:
# Commonly used per-channel (R, G, B) mean and std of the CIFAR-10 training
# images; consumed by transforms.Normalize below.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2470, 0.2435, 0.2616)


def _build_transform(augment):
    """Return the preprocessing pipeline for one dataset split.

    Disabled steps are left out of the pipeline, not replaced by an identity
    ``transforms.Lambda(lambda x: x)``. Lambdas cannot be pickled, and
    ``DataLoader(num_workers > 0)`` must pickle the dataset, including its
    transform, on platforms that start workers with ``spawn`` (Windows, macOS).

    Args:
        augment: If True, add the random crop / horizontal flip steps enabled
            by ``RANDOM_CROP`` and ``HORIZONTAL_FLIP``.

    Returns:
        A ``transforms.Compose`` of (optional) augmentation, ``ToTensor`` and
        (if ``NORMALIZE_DATA``) per-channel normalization, in that order.
    """
    steps = []
    if augment:
        if RANDOM_CROP:
            # Pad 4 px on each side, then crop back to 32x32 at a random offset.
            steps.append(transforms.RandomCrop(32, padding=4))
        if HORIZONTAL_FLIP:
            steps.append(transforms.RandomHorizontalFlip())
    steps.append(transforms.ToTensor())
    if NORMALIZE_DATA:
        steps.append(transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD))
    return transforms.Compose(steps)


# Augmentation (when enabled) is applied to the training split only;
# normalization is controlled by NORMALIZE_DATA for both splits.
transform_train = _build_transform(USE_DATA_AUGMENTATION)
transform_test = _build_transform(augment=False)

# Load CIFAR-10 dataset (downloaded into ./data if not already present)
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)

# Classes in CIFAR-10 (short display names, in label-index order)
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

## MODEL DEFINITIONS

In [None]:
# AlexNet architecture adapted for CIFAR-10
class AlexNetCIFAR(nn.Module):
    """AlexNet-style CNN adapted to 32x32 CIFAR-10 images.

    Five 3x3 convolutions, three of them followed by 2x2 max-pooling, reduce
    a 32x32 input to a 256 x 4 x 4 feature map. That map is flattened and fed
    to a three-layer classifier of width ``fc_size -> 1024 -> num_classes``.

    Args:
        num_classes: Number of output logits.
        use_dropout: Put a dropout layer before each of the first two linear
            layers. ``None`` (default) uses the notebook-level ``USE_DROPOUT``.
        dropout_rate: Dropout probability; ``None`` uses ``DROPOUT_RATE``.
        fc_size: Width of the first fully connected layer; ``None`` uses
            ``ALEXNET_FC_SIZE``.
    """

    def __init__(self, num_classes=10, use_dropout=None, dropout_rate=None, fc_size=None):
        super().__init__()
        # Globals are read at construction time, so editing the parameter cell
        # and re-instantiating picks up the new values.
        if use_dropout is None:
            use_dropout = USE_DROPOUT
        if dropout_rate is None:
            dropout_rate = DROPOUT_RATE
        if fc_size is None:
            fc_size = ALEXNET_FC_SIZE

        self.features = nn.Sequential(
            # Conv1
            nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),

            # Conv2
            nn.Conv2d(64, 192, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),

            # Conv3
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),

            # Conv4
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),

            # Conv5
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )

        def dropout():
            # Empty when dropout is disabled, so the Sequential contains exactly
            # the same layers (and state_dict indices) as the explicit layouts.
            return [nn.Dropout(dropout_rate)] if use_dropout else []

        self.classifier = nn.Sequential(
            *dropout(),
            nn.Linear(256 * 4 * 4, fc_size),
            nn.ReLU(inplace=True),
            *dropout(),
            nn.Linear(fc_size, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, num_classes),
        )

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)`` for input ``x``."""
        x = self.features(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

# VGG-16 architecture for CIFAR-10
class VGG16CIFAR(nn.Module):
    """VGG-16 (13 conv + 3 FC layers) for 32x32 CIFAR-10 images, no batch norm.

    Five blocks of 3x3 convolutions, each ending in 2x2 max-pooling, reduce a
    32x32 input to 512 x 1 x 1. That vector feeds a classifier of width
    ``512 -> fc_size -> fc_size -> num_classes``.

    Args:
        num_classes: Number of output logits.
        use_dropout: Put dropout after each hidden FC layer's ReLU. ``None``
            (default) uses the notebook-level ``USE_DROPOUT``.
        dropout_rate: Dropout probability; ``None`` uses ``DROPOUT_RATE``.
        fc_size: Width of both hidden FC layers; ``None`` uses ``VGG_FC_SIZE``.
    """

    # Layer plan: integers are 3x3 conv output channels, "M" is a 2x2 max-pool.
    _CFG = (64, 64, "M",
            128, 128, "M",
            256, 256, 256, "M",
            512, 512, 512, "M",
            512, 512, 512, "M")

    def __init__(self, num_classes=10, use_dropout=None, dropout_rate=None, fc_size=None):
        super().__init__()
        # Globals are read at construction time, so editing the parameter cell
        # and re-instantiating picks up the new values.
        if use_dropout is None:
            use_dropout = USE_DROPOUT
        if dropout_rate is None:
            dropout_rate = DROPOUT_RATE
        if fc_size is None:
            fc_size = VGG_FC_SIZE

        # Build Conv-ReLU pairs and pools in the same order (and therefore with
        # the same Sequential indices) as the hand-written layer list.
        layers = []
        in_channels = 3
        for v in self._CFG:
            if v == "M":
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
                continue
            layers += [nn.Conv2d(in_channels, v, kernel_size=3, padding=1), nn.ReLU(inplace=True)]
            in_channels = v
        self.features = nn.Sequential(*layers)

        head = [nn.Linear(512, fc_size), nn.ReLU(inplace=True)]
        if use_dropout:
            head.append(nn.Dropout(dropout_rate))
        head += [nn.Linear(fc_size, fc_size), nn.ReLU(inplace=True)]
        if use_dropout:
            head.append(nn.Dropout(dropout_rate))
        head.append(nn.Linear(fc_size, num_classes))
        self.classifier = nn.Sequential(*head)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)`` for input ``x``."""
        x = self.features(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

# VGG-16 with Batch Normalization
class VGG16BN(nn.Module):
    """VGG-16 for 32x32 CIFAR-10 images with BatchNorm after every convolution.

    Same topology as the plain VGG-16 (five conv blocks, each ending in 2x2
    max-pooling, so a 32x32 input becomes 512 x 1 x 1). Every 3x3 conv is
    followed by ``BatchNorm2d`` and ReLU. The classifier is
    ``512 -> fc_size -> fc_size -> num_classes``.

    Args:
        num_classes: Number of output logits.
        use_dropout: Put dropout after each hidden FC layer's ReLU. ``None``
            (default) uses the notebook-level ``USE_DROPOUT``.
        dropout_rate: Dropout probability; ``None`` uses ``DROPOUT_RATE``.
        fc_size: Width of both hidden FC layers; ``None`` uses ``VGG_FC_SIZE``.
    """

    # Layer plan: integers are 3x3 conv output channels, "M" is a 2x2 max-pool.
    _CFG = (64, 64, "M",
            128, 128, "M",
            256, 256, 256, "M",
            512, 512, 512, "M",
            512, 512, 512, "M")

    def __init__(self, num_classes=10, use_dropout=None, dropout_rate=None, fc_size=None):
        super().__init__()
        # Globals are read at construction time, so editing the parameter cell
        # and re-instantiating picks up the new values.
        if use_dropout is None:
            use_dropout = USE_DROPOUT
        if dropout_rate is None:
            dropout_rate = DROPOUT_RATE
        if fc_size is None:
            fc_size = VGG_FC_SIZE

        # Conv-BN-ReLU triples and pools, in the same order (and with the same
        # Sequential indices) as the hand-written layer list.
        layers = []
        in_channels = 3
        for v in self._CFG:
            if v == "M":
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
                continue
            layers += [
                nn.Conv2d(in_channels, v, kernel_size=3, padding=1),
                nn.BatchNorm2d(v),
                nn.ReLU(inplace=True),
            ]
            in_channels = v
        self.features = nn.Sequential(*layers)

        head = [nn.Linear(512, fc_size), nn.ReLU(inplace=True)]
        if use_dropout:
            head.append(nn.Dropout(dropout_rate))
        head += [nn.Linear(fc_size, fc_size), nn.ReLU(inplace=True)]
        if use_dropout:
            head.append(nn.Dropout(dropout_rate))
        head.append(nn.Linear(fc_size, num_classes))
        self.classifier = nn.Sequential(*head)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)`` for input ``x``."""
        x = self.features(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

# VGG-8 (Reduced depth)
class VGG8(nn.Module):
    """Shallow VGG-style network for 32x32 CIFAR-10 images.

    Four blocks of two 3x3 convolutions, each block ending in 2x2 max-pooling,
    reduce a 32x32 input to 512 x 2 x 2. The flattened 2048-vector feeds a
    classifier of width ``2048 -> fc_size -> num_classes``.

    Args:
        num_classes: Number of output logits.
        use_dropout: Put dropout after the hidden FC layer's ReLU. ``None``
            (default) uses the notebook-level ``USE_DROPOUT``.
        dropout_rate: Dropout probability; ``None`` uses ``DROPOUT_RATE``.
        fc_size: Width of the hidden FC layer; ``None`` uses ``VGG_FC_SIZE``.
    """

    # Layer plan: integers are 3x3 conv output channels, "M" is a 2x2 max-pool.
    _CFG = (64, 64, "M",
            128, 128, "M",
            256, 256, "M",
            512, 512, "M")

    def __init__(self, num_classes=10, use_dropout=None, dropout_rate=None, fc_size=None):
        super().__init__()
        # Globals are read at construction time, so editing the parameter cell
        # and re-instantiating picks up the new values.
        if use_dropout is None:
            use_dropout = USE_DROPOUT
        if dropout_rate is None:
            dropout_rate = DROPOUT_RATE
        if fc_size is None:
            fc_size = VGG_FC_SIZE

        # Conv-ReLU pairs and pools, in the same order (and with the same
        # Sequential indices) as the hand-written layer list.
        layers = []
        in_channels = 3
        for v in self._CFG:
            if v == "M":
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
                continue
            layers += [nn.Conv2d(in_channels, v, kernel_size=3, padding=1), nn.ReLU(inplace=True)]
            in_channels = v
        self.features = nn.Sequential(*layers)

        head = [nn.Linear(512 * 2 * 2, fc_size), nn.ReLU(inplace=True)]
        if use_dropout:
            head.append(nn.Dropout(dropout_rate))
        head.append(nn.Linear(fc_size, num_classes))
        self.classifier = nn.Sequential(*head)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)`` for input ``x``."""
        x = self.features(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x

## TRAINING AND EVALUATION FUNCTIONS

In [None]:
# Function to count parameters
def count_parameters(model):
    """Return the total number of elements in ``model``'s trainable parameters.

    Only parameters with ``requires_grad=True`` are counted, so frozen
    weights are excluded.
    """
    n_trainable = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        n_trainable += param.numel()
    return n_trainable

# Training helpers
def _run_epoch(model, loader, criterion, desc, optimizer=None):
    """Make one pass over ``loader`` and return ``(mean_loss, accuracy)``.

    With an ``optimizer`` the model is put in training mode and updated after
    every batch. Without one it is put in eval mode and run without gradient
    tracking. Batches are moved to the notebook-level ``device``.

    Args:
        model: Network to run (expected to already be on ``device``).
        loader: Iterable of ``(inputs, labels)`` batches.
        criterion: Loss called as ``criterion(outputs, labels)``.
        desc: Label shown on the tqdm progress bar.
        optimizer: If given, do ``zero_grad`` / ``backward`` / ``step`` per batch.

    Returns:
        Sample-weighted mean loss and accuracy as a fraction in [0, 1].

    Raises:
        ValueError: If ``loader`` yields no samples (averages are undefined).
    """
    training = optimizer is not None
    model.train(training)  # train(False) is the same as eval()

    running_loss = 0.0
    correct = 0
    total = 0

    with torch.set_grad_enabled(training):
        pbar = tqdm(loader, desc=desc)
        for inputs, labels in pbar:
            inputs, labels = inputs.to(device), labels.to(device)

            if training:
                optimizer.zero_grad()

            outputs = model(inputs)
            loss = criterion(outputs, labels)

            if training:
                loss.backward()
                optimizer.step()

            _, preds = torch.max(outputs, 1)
            # Weight the batch-mean loss by batch size so the epoch average is per sample.
            running_loss += loss.item() * inputs.size(0)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

            # Update progress bar with running statistics
            pbar.set_postfix({'loss': loss.item(), 'acc': f"{100 * correct / total:.2f}%"})

    if total == 0:
        raise ValueError(f"{desc}: data loader yielded no samples")
    return running_loss / total, correct / total


# Training function
def train_model(model, criterion, optimizer, scheduler, num_epochs=NUM_EPOCHS,
                train_loader=None, val_loader=None):
    """Train ``model`` and return it loaded with its best-validation-epoch weights.

    Each epoch does one training pass, then steps ``scheduler`` (if given),
    then one validation pass. A copy of the weights with the highest
    validation accuracy is kept and loaded back into the model at the end.

    NOTE(review): by default the validation pass uses the global ``testloader``.
    The checkpoint is therefore chosen on the test set, and a test accuracy
    measured afterwards on the same data is optimistically biased. Pass a
    held-out ``val_loader`` (e.g. split off the training set) to avoid this.

    Args:
        model: Network to train; moved to ``device``.
        criterion: Loss function.
        optimizer: Optimizer over ``model``'s parameters.
        scheduler: LR scheduler stepped once per epoch, or None.
        num_epochs: Number of epochs (default bound from ``NUM_EPOCHS`` when
            this cell runs).
        train_loader: Training batches; defaults to the global ``trainloader``.
        val_loader: Validation batches; defaults to the global ``testloader``.

    Returns:
        ``(model, train_losses, val_losses, train_accs, val_accs, times_per_epoch)``.
        Accuracies are fractions in [0, 1]. Times are wall-clock seconds per
        epoch, covering training and validation.

    Raises:
        ValueError: If either loader yields no samples.
    """
    if train_loader is None:
        train_loader = trainloader
    if val_loader is None:
        val_loader = testloader

    model = model.to(device)
    train_losses, val_losses = [], []
    train_accs, val_accs = [], []
    times_per_epoch = []

    best_acc = 0.0
    best_model_wts = copy.deepcopy(model.state_dict())

    for epoch in range(num_epochs):
        start_time = time.time()
        epoch_tag = f"Epoch {epoch+1}/{num_epochs}"

        # Training phase
        train_loss, train_acc = _run_epoch(model, train_loader, criterion,
                                           f"{epoch_tag} [Train]", optimizer)
        if scheduler:
            scheduler.step()

        # Validation phase
        val_loss, val_acc = _run_epoch(model, val_loader, criterion, f"{epoch_tag} [Val]")

        train_losses.append(train_loss)
        train_accs.append(train_acc)
        val_losses.append(val_loss)
        val_accs.append(val_acc)

        # Keep a snapshot of the best weights seen so far
        if val_acc > best_acc:
            best_acc = val_acc
            best_model_wts = copy.deepcopy(model.state_dict())

        epoch_time = time.time() - start_time
        times_per_epoch.append(epoch_time)

        print(f'Epoch {epoch+1}/{num_epochs}')
        print(f'Train Loss: {train_losses[-1]:.4f}, Train Acc: {train_accs[-1]:.4f}')
        print(f'Val Loss: {val_losses[-1]:.4f}, Val Acc: {val_accs[-1]:.4f}')
        print(f'Time: {epoch_time:.2f}s')
        print('-' * 50)

    model.load_state_dict(best_model_wts)
    return model, train_losses, val_losses, train_accs, val_accs, times_per_epoch

# Function to evaluate model on test set
def evaluate_model(model, test_loader):
    """Return the top-1 accuracy of ``model`` on ``test_loader`` as a percentage.

    Puts the model in eval mode, disables gradient tracking, and moves every
    batch to the notebook-level ``device``. An empty loader raises
    ``ZeroDivisionError``.
    """
    model.eval()
    n_correct, n_seen = 0, 0

    with torch.no_grad():
        for batch_inputs, batch_labels in test_loader:
            batch_inputs = batch_inputs.to(device)
            batch_labels = batch_labels.to(device)
            logits = model(batch_inputs)
            _, predicted = torch.max(logits.data, 1)
            n_seen += batch_labels.size(0)
            n_correct += (predicted == batch_labels).sum().item()

    return 100 * n_correct / n_seen

## VISUALIZATION FUNCTIONS

In [None]:
# Function to visualize first layer filters
def visualize_filters(model, title):
    """Save an image grid of the first convolutional layer's filters.

    Writes ``visualizations/{title}_filters.png``. Does nothing when the
    notebook-level ``SAVE_FILTERS`` flag is False.

    All filters are drawn, eight per row. Each is min-max normalised on its
    own. Filters with three input channels (an RGB input layer) are shown in
    colour; otherwise only the first input channel is shown with ``viridis``.

    Args:
        model: Network whose first ``nn.Conv2d`` in ``model.modules()`` order is drawn.
        title: Used in the figure title and the output file name.
    """
    if not SAVE_FILTERS:
        return

    first_layer = next((m for m in model.modules() if isinstance(m, nn.Conv2d)), None)
    if first_layer is None:
        print("No convolutional layer found")
        return

    # Conv2d weight layout: (out_channels, in_channels, kH, kW)
    weights = first_layer.weight.detach().cpu().numpy()
    n_filters = weights.shape[0]
    n_cols = 8
    n_rows = -(-n_filters // n_cols)  # ceiling division

    fig, axes = plt.subplots(n_rows, n_cols, figsize=(12, 1.5 * n_rows + 1), squeeze=False)
    fig.subplots_adjust(hspace=0.4, wspace=0.4)
    fig.suptitle(f'First Layer Filters - {title}', fontsize=16)

    for i, ax in enumerate(axes.flat):
        ax.axis('off')
        if i >= n_filters:
            continue
        kernel = weights[i]
        if kernel.shape[0] == 3:
            img = kernel.transpose(1, 2, 0)  # (kH, kW, 3) for an RGB imshow
            cmap = None
        else:
            img = kernel[0]
            cmap = 'viridis'
        img = (img - img.min()) / (img.max() - img.min() + 1e-8)
        ax.imshow(img, cmap=cmap)
        ax.set_title(f'Filter {i+1}', fontsize=8)

    fig.savefig(f'visualizations/{title}_filters.png')
    plt.close(fig)

# Function to visualize feature maps
def visualize_feature_maps(model, title):
    """Save the first conv layer's activations for one test image as a grid.

    Takes the first image of the first ``testloader`` batch and runs it
    through ``model`` in eval mode without gradients. A forward hook captures
    the output of the first ``nn.Conv2d`` in ``model.modules()``, i.e. before
    any BatchNorm/ReLU that follows it. Every channel is written, eight per
    row, to ``visualizations/{title}_feature_maps.png``.

    Does nothing when ``SAVE_FEATURE_MAPS`` is False. Leaves ``model`` in eval mode.
    """
    if not SAVE_FEATURE_MAPS:
        return

    first_conv = next((m for m in model.modules() if isinstance(m, nn.Conv2d)), None)
    if first_conv is None:
        print("No convolutional layer found")
        return

    captured = {}

    def hook_fn(module, inputs, output):
        captured['features'] = output.detach().cpu()

    # Fetch the input before registering the hook so a loading error cannot leave it attached.
    images, _ = next(iter(testloader))

    model.eval()
    hook = first_conv.register_forward_hook(hook_fn)
    try:
        with torch.no_grad():
            model(images[0:1].to(device))
    finally:
        # Always detach the hook so later forward passes are unaffected.
        hook.remove()

    if 'features' not in captured:
        return

    feature_map = captured['features'][0]  # drop the batch dimension (batch of one)
    n_maps = feature_map.shape[0]
    n_cols = 8
    n_rows = -(-n_maps // n_cols)  # ceiling division

    fig, axes = plt.subplots(n_rows, n_cols, figsize=(12, 1.5 * n_rows + 1), squeeze=False)
    fig.subplots_adjust(hspace=0.4, wspace=0.4)
    fig.suptitle(f'Feature Maps - {title}', fontsize=16)

    for i, ax in enumerate(axes.flat):
        ax.axis('off')
        if i >= n_maps:
            continue
        img = feature_map[i].numpy()
        img = (img - img.min()) / (img.max() - img.min() + 1e-8)
        ax.imshow(img, cmap='viridis')
        ax.set_title(f'Map {i+1}', fontsize=8)

    fig.savefig(f'visualizations/{title}_feature_maps.png')
    plt.close(fig)

# Function to plot training history
def plot_history(histories, model_names):
    """Save training curves for several models.

    Writes two images:

    * ``visualizations/training_history.png``: a 2x2 grid of training/
      validation accuracy and loss per epoch, one line per model.
    * ``visualizations/time_per_epoch.png``: seconds per epoch for each model.

    Does nothing when ``SAVE_TRAINING_CURVES`` is False.

    Args:
        histories: One tuple per model, in the order
            ``(train_accs, val_accs, train_losses, val_losses, times)``.
        model_names: Legend label for each entry of ``histories``.
    """
    if not SAVE_TRAINING_CURVES:
        return

    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    # (axis, panel title, y-label, legend suffix), in the same order as the
    # first four entries of each history tuple.
    panels = (
        (axes[0, 0], 'Training Accuracy', 'Accuracy', 'Train'),
        (axes[0, 1], 'Validation Accuracy', 'Accuracy', 'Val'),
        (axes[1, 0], 'Training Loss', 'Loss', 'Train'),
        (axes[1, 1], 'Validation Loss', 'Loss', 'Val'),
    )

    for name, history in zip(model_names, histories):
        train_accs, val_accs, train_losses, val_losses, _ = history
        epochs = range(1, len(train_accs) + 1)
        for (ax, _, _, suffix), series in zip(panels, (train_accs, val_accs, train_losses, val_losses)):
            ax.plot(epochs, series, marker='o', linestyle='-', label=f'{name} {suffix}')

    # Axis decoration is loop-invariant: apply it once after all lines are drawn.
    for ax, panel_title, ylabel, _ in panels:
        ax.set_title(panel_title)
        ax.set_xlabel('Epochs')
        ax.set_ylabel(ylabel)
        ax.grid(True)
        if ax.get_legend_handles_labels()[0]:  # skip the empty-legend warning
            ax.legend()

    fig.tight_layout()
    fig.savefig('visualizations/training_history.png')
    plt.close(fig)

    # Plot time per epoch
    fig, ax = plt.subplots(figsize=(10, 6))
    for name, history in zip(model_names, histories):
        times = history[4]
        ax.plot(range(1, len(times) + 1), times, marker='o', linestyle='-', label=name)

    ax.set_title('Time per Epoch')
    ax.set_xlabel('Epochs')
    ax.set_ylabel('Time (seconds)')
    ax.grid(True)
    if ax.get_legend_handles_labels()[0]:
        ax.legend()
    fig.tight_layout()
    fig.savefig('visualizations/time_per_epoch.png')
    plt.close(fig)

## MAIN EXECUTION

In [None]:
def _to_json_scalar(value):
    """Return *value* as a plain Python scalar suitable for ``json``.

    0-d ``torch.Tensor`` objects and NumPy scalar types (e.g. ``np.float32``)
    are not JSON-serializable but expose ``.item()``, which yields the
    equivalent built-in number. Anything without a callable ``item`` attribute
    (plain ``int``/``float``/``str``) is returned unchanged.
    """
    item = getattr(value, "item", None)
    return item() if callable(item) else value


def _train_and_evaluate(model, criterion, *, name, title, checkpoint_path, vis_name):
    """Train, checkpoint, test and visualize a single model.

    Args:
        model: Freshly constructed model to train.
        criterion: Loss function forwarded to ``train_model``.
        name: Display name used in console output (e.g. ``"VGG-16 BN"``).
        title: Banner text printed before training starts.
        checkpoint_path: Path the trained ``state_dict`` is saved to.
        vis_name: Label forwarded to ``visualize_filters`` and
            ``visualize_feature_maps``.

    Returns:
        ``(history, summary)``. ``history`` is the
        ``(train_accs, val_accs, train_losses, val_losses, times)`` tuple that
        ``plot_history`` unpacks. ``summary`` is the per-model metrics dict
        stored in the experiment results.
    """
    import os

    print("\n" + "="*50)
    print(title)
    print("="*50)

    print(f"{name} parameters: {count_parameters(model):,}")

    optimizer = optim.SGD(model.parameters(),
                          lr=LEARNING_RATE,
                          momentum=MOMENTUM,
                          weight_decay=WEIGHT_DECAY)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS)

    # train_model yields losses before accuracies; the history tuple below
    # re-orders them to the (accs, losses, times) layout plot_history expects.
    model, train_losses, val_losses, train_accs, val_accs, times = train_model(
        model, criterion, optimizer, scheduler, num_epochs=NUM_EPOCHS
    )

    # Create the checkpoint directory so torch.save does not fail on a fresh checkout.
    os.makedirs(os.path.dirname(checkpoint_path) or ".", exist_ok=True)
    torch.save(model.state_dict(), checkpoint_path)

    test_acc = evaluate_model(model, testloader)
    print(f"{name} Test Accuracy: {test_acc:.2f}%")

    visualize_filters(model, vis_name)
    visualize_feature_maps(model, vis_name)

    history = (train_accs, val_accs, train_losses, val_losses, times)
    summary = {
        'parameters': count_parameters(model),
        'test_accuracy': test_acc,
        'train_accuracy': train_accs[-1],
        'val_accuracy': val_accs[-1],
        'time_per_epoch': sum(times) / len(times),
    }
    return history, summary


def run_experiments():
    """Train and compare every model enabled by the ``TRAIN_*`` flags.

    For each enabled model this trains it, saves a checkpoint under
    ``results/``, reports test accuracy and writes filter and feature-map
    visualizations. It then plots the collected training histories, prints a
    summary table and writes ``results/experiment_results.json``.

    Returns:
        dict mapping each trained model's display name to a metrics dict with
        keys ``parameters``, ``test_accuracy``, ``train_accuracy``,
        ``val_accuracy`` and ``time_per_epoch``.
    """
    import json
    import os

    histories = []
    model_names = []
    results = {}

    criterion = nn.CrossEntropyLoss()

    # (enabled, model factory, display name, banner, checkpoint path, visualization label)
    # Factories are lambdas so a model class is only referenced/instantiated when enabled.
    experiments = (
        (TRAIN_ALEXNET, lambda: AlexNetCIFAR(), "AlexNet",
         "Training AlexNet", 'results/alexnet_cifar10.pth', "AlexNet"),
        (TRAIN_VGG16, lambda: VGG16CIFAR(), "VGG-16",
         "Training VGG-16", 'results/vgg16_cifar10.pth', "VGG-16"),
        (TRAIN_VGG16_BN, lambda: VGG16BN(), "VGG-16 BN",
         "Training VGG-16 with Batch Normalization", 'results/vgg16bn_cifar10.pth', "VGG-16-BN"),
        (TRAIN_VGG8, lambda: VGG8(), "VGG-8",
         "Training VGG-8 (Reduced Depth)", 'results/vgg8_cifar10.pth', "VGG-8"),
    )

    for enabled, make_model, name, title, checkpoint_path, vis_name in experiments:
        if not enabled:
            continue
        history, summary = _train_and_evaluate(
            make_model(), criterion,
            name=name,
            title=title,
            checkpoint_path=checkpoint_path,
            vis_name=vis_name,
        )
        histories.append(history)
        model_names.append(name)
        results[name] = summary

    if histories:
        plot_history(histories, model_names)

    print("\n" + "="*50)
    print("SUMMARY OF RESULTS")
    print("="*50)

    for name, result in results.items():
        print(f"{name}:")
        print(f"  - Parameters: {result['parameters']:,}")
        print(f"  - Test Accuracy: {result['test_accuracy']:.2f}%")
        print(f"  - Train Accuracy: {result['train_accuracy']:.2f}%")
        print(f"  - Validation Accuracy: {result['val_accuracy']:.2f}%")
        print(f"  - Avg. Time per Epoch: {result['time_per_epoch']:.2f}s")
        print()

    # Convert tensor / NumPy scalars and serialize BEFORE opening the file, so a
    # conversion error cannot leave a truncated experiment_results.json behind.
    serializable_results = {
        model_name: {k: _to_json_scalar(v) for k, v in metrics.items()}
        for model_name, metrics in results.items()
    }
    payload = json.dumps(serializable_results, indent=4)

    os.makedirs('results', exist_ok=True)
    with open('results/experiment_results.json', 'w') as f:
        f.write(payload)

    print("Results saved to 'results/experiment_results.json'")

    return results

## RUN EXPERIMENTS

In [None]:
# Run all enabled experiments only when this file is executed as a script (not on import).
# `results` maps each trained model's display name to its metrics dict.
if __name__ == "__main__":
    results = run_experiments()