<a href="https://colab.research.google.com/github/Rajcannotcode/CNN-Experimentation-Framework-CIFAR10/blob/main/CIFAR_10_experimentation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# CIFAR-10 Image Classification with CNNs

This project presents a modular experimentation framework for training and evaluating convolutional neural networks (CNNs) on the CIFAR-10 dataset using PyTorch.

Key components include:
- Data augmentation pipelines to reduce overfitting and improve generalization
- A centralized configuration block for easily adjusting hyperparameters
- A structured experiment runner that logs results to a pandas DataFrame for easy analysis

The notebook is designed for extensibility and reproducibility, enabling users to test and compare multiple training setups with minimal code changes.


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

In [None]:
from google.colab import files
import os
import re
import shutil

In [None]:
from collections import OrderedDict
from collections import namedtuple
from itertools import product

from IPython.display import clear_output, display
import time

from sklearn.metrics import confusion_matrix
from sklearn.metrics import ConfusionMatrixDisplay

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

if torch.cuda.is_available():
  print("GPU name: ",torch.cuda.get_device_name(0))
else:
  print('No GPU available')

No GPU available


In [None]:
raw_trainset = torchvision.datasets.CIFAR10(
    root='./data/CIFAR10',
    train = True,
    download = True,
    transform= transforms.Compose([
        transforms.ToTensor()
    ])
)

raw_testset = torchvision.datasets.CIFAR10(
    root='./data/CIFAR10',
    train = False,
    download = True,
    transform= transforms.Compose([
        transforms.ToTensor()
    ])
)



In [None]:
normalizer_loader = torch.utils.data.DataLoader(raw_trainset,batch_size = len(raw_trainset), shuffle=False)
raw_train_data = next(iter(normalizer_loader))
train_std = raw_train_data[0].std()
train_mean = raw_train_data[0].mean()

## Data Augmentation and Normalization

In this section, we define transformations to prepare the CIFAR-10 dataset for training and evaluation.

- **Normalization** uses a single mean and standard deviation computed over every pixel and channel of the training set, so inputs are roughly zero-centered with unit variance, which tends to improve convergence during training.
- **Data Augmentation** is applied to artificially increase the dataset diversity and reduce overfitting. We use:
  - `RandomHorizontalFlip` to simulate mirrored objects
  - `RandomCrop` with padding to introduce spatial jitter
  - `RandomRotation` and `RandomVerticalFlip` (in the heavily augmented variant) to expose the model to rotated and flipped perspectives

Two levels of augmentation are created:
- `augmented_trainset_transformations` for general-purpose training
- `heavily_augmented_trainset_transformations` for experimentation with stronger regularization
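
To make the normalization step concrete, here is a minimal pure-Python sketch of what `transforms.Normalize` computes when it is given one scalar mean and one scalar standard deviation, as in this notebook: every value becomes `(x - mean) / std`. The pixel values and the helper name `normalize` are made up for illustration; the real pipeline operates on tensors. Note also that `Tensor.std()` applies Bessel's correction by default, whereas this sketch uses the population standard deviation; at the size of the CIFAR-10 training set the difference is negligible.

```python
import statistics

def normalize(values, mean, std):
    """Apply (x - mean) / std to every value, mirroring Normalize with scalar statistics."""
    return [(v - mean) / std for v in values]

# Hypothetical pixel intensities, already scaled to [0, 1] as ToTensor would do.
pixels = [0.0, 0.25, 0.5, 0.75, 1.0]

mean = statistics.fmean(pixels)
std = statistics.pstdev(pixels)  # population standard deviation

normalized = normalize(pixels, mean, std)

# After normalization the values are centred on 0 with unit spread.
print(round(statistics.fmean(normalized), 6))
print(round(statistics.pstdev(normalized), 6))
```

The test set is normalized with the training statistics as well, so that both splits pass through the same transformation.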


In [None]:
normalized_trainset= torchvision.datasets.CIFAR10(
    root='./data/CIFAR10',
    train =True,
    download =True,
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=train_mean, std=train_std)
    ])
)

normalized_testset = torchvision.datasets.CIFAR10(
    root='./data/CIFAR10',
    train = False,
    download = True,
    transform= transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=train_mean, std= train_std)
    ])
)

normalized_trainset.data.shape

(50000, 32, 32, 3)

In [None]:
heavily_augmented_trainset_transformations = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomCrop(32, padding=4),
    transforms.RandomRotation(degrees=(30,120)),
    transforms.RandomVerticalFlip(p=0.25),
    transforms.ToTensor(),
    transforms.Normalize(mean=train_mean, std=train_std)
])

augmented_trainset_transformations = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomCrop(32, padding=4),
    transforms.ToTensor(),
    transforms.Normalize(mean=train_mean, std=train_std)
])


In [None]:
heavily_augmented_trainset = torchvision.datasets.CIFAR10(
    root='./data/CIFAR10',
    train=True,
    download=True,
    transform = heavily_augmented_trainset_transformations
)

augmented_trainset = torchvision.datasets.CIFAR10(
    root='./data/CIFAR10',
    train=True,
    download=True,
    transform = augmented_trainset_transformations
)

In [None]:
test_prediction_loader = torch.utils.data.DataLoader(
      dataset = normalized_testset,
      batch_size = 10000,
      num_workers = 1,
      shuffle = True
      )

## Network Factory: Custom CNN Architectures

This section defines a modular CNN network factory used during experimentation.

- The architecture is customizable and designed to work with the CIFAR-10 image size (32×32).
- In the regularized variants, convolutional layers are followed by ReLU activations, with Batch Normalization and Dropout added for regularization.
- The final output layer is adapted for 10-class classification.

### How to Use Your Own Model

To test a custom architecture:
1. Add an `elif` branch to `NetworkFactory.get_network` with the name you want to give your model
2. Ensure the final output layer is: `nn.Linear(..., 10)`
3. Make sure your model accepts inputs of shape `[batch_size, 3, 32, 32]` (channels first, as produced by `ToTensor` for RGB CIFAR-10 images)
4. Call `torch.manual_seed()` before constructing your model; this makes the initial weights identical every time the model is rebuilt, so different hyperparameters can be compared on the same starting point
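
When adding a model, the `in_features` of the first `nn.Linear` after `nn.Flatten` has to match the flattened convolution output. The following sketch computes that number for stride-1, unpadded convolutions like the ones used in this notebook. The helper names `conv_out_size` and `flattened_features` are hypothetical and not part of the notebook.

```python
def conv_out_size(size, kernel_size, stride=1, padding=0):
    """Output height/width of a Conv2d layer (dilation 1) for a square input."""
    return (size + 2 * padding - kernel_size) // stride + 1

def flattened_features(input_size, conv_layers):
    """Return in_features for the first Linear layer after nn.Flatten.

    conv_layers is a list of (out_channels, kernel_size) pairs applied in order,
    each with stride 1 and no padding.
    """
    size = input_size
    channels = None
    for out_channels, kernel_size in conv_layers:
        size = conv_out_size(size, kernel_size)
        channels = out_channels
    return channels * size * size

# Two 5x5 convolutions: 32 -> 28 -> 24, with 12 output channels.
vanilla = flattened_features(32, [(6, 5), (12, 5)])
# Three 5x5 convolutions: 32 -> 28 -> 24 -> 20, with 24 output channels.
deep = flattened_features(32, [(6, 5), (12, 5), (24, 5)])

print(vanilla)  # equals 12 * 24 * 24
print(deep)     # equals 24 * 20 * 20
```

These values match the `12*24*24` and `24*20*20` expressions used in the factory below; adding pooling or padding to a custom model changes the arithmetic accordingly.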


In [None]:
class NetworkFactory():
  @staticmethod
  def get_network(name):
    if(name == 'CNN_model_vanilla'):
      torch.manual_seed(50)
      return nn.Sequential(
          nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5),
          nn.ReLU(),
          nn.Conv2d(in_channels=6, out_channels=12, kernel_size=5),
          nn.ReLU(),
          nn.Flatten(start_dim=1),
          nn.Linear(in_features=12*24*24, out_features=120),
          nn.ReLU(),
          nn.Linear(in_features=120, out_features=60),
          nn.ReLU(),
          nn.Linear(in_features=60, out_features=10)
      )

    elif(name == 'CNN_model_Dropout_BNorm_Mpool'):
        torch.manual_seed(50)
        return nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.BatchNorm2d(6),

            nn.Conv2d(in_channels=6, out_channels=12, kernel_size=5),
            nn.ReLU(),

            nn.Conv2d(in_channels=12, out_channels=24, kernel_size=5),
            nn.ReLU(),
            nn.Dropout(0.3),

            nn.Flatten(start_dim=1),
            nn.Linear(in_features=24*20*20, out_features=400),
            nn.ReLU(),
            nn.BatchNorm1d(400),

            nn.Linear(in_features=400, out_features=120),
            nn.ReLU(),
            nn.BatchNorm1d(120),

            nn.Linear(in_features=120, out_features=60),
            nn.ReLU(),

            nn.Linear(in_features=60, out_features=10)
      )
    elif(name == 'CNN_model_BNorm_Deep'):
        torch.manual_seed(50)
        return nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5),
            nn.ReLU(),
            nn.BatchNorm2d(6),
            nn.Dropout(0.3),


            nn.Conv2d(in_channels=6, out_channels=12, kernel_size=5),
            nn.ReLU(),
            nn.Conv2d(in_channels=12, out_channels=24, kernel_size=5),
            nn.ReLU(),

            nn.Flatten(start_dim=1),
            nn.Linear(in_features=24*20*20, out_features=1200),
            nn.ReLU(),

            nn.Linear(in_features=1200, out_features=400),
            nn.ReLU(),
            nn.Linear(in_features=400, out_features=120),
            nn.ReLU(),

            nn.BatchNorm1d(120),
            nn.Linear(in_features=120, out_features=60),
            nn.ReLU(),
            nn.Linear(in_features=60, out_features=10)
      )

    else:
      raise ValueError(f"Unknown network name: {name}")

## Experiment Logging with RunManager

The `RunManager` class automates the training workflow and logs key metrics for each hyperparameter configuration.

- Each run is tracked with a unique ID and its parameter set (e.g., learning rate, batch size).
- During training, loss and accuracy metrics are recorded per epoch.
- After training, the results are compiled into a `pandas.DataFrame` for easy analysis, filtering, and visualization.
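
Because `F.cross_entropy` returns the mean loss of a batch by default, a per-sample epoch loss is obtained by weighting each batch mean by the number of samples in that batch and dividing by the total. A minimal sketch of that arithmetic, with made-up loss values and a hypothetical helper name `epoch_mean_loss`:

```python
def epoch_mean_loss(batch_losses, batch_sizes):
    """Sample-weighted average of per-batch mean losses."""
    total = sum(loss * n for loss, n in zip(batch_losses, batch_sizes))
    return total / sum(batch_sizes)

# Hypothetical batch-mean losses; the last batch is smaller than the others.
batch_losses = [2.30, 2.10, 1.90]
batch_sizes = [1000, 1000, 500]

weighted = epoch_mean_loss(batch_losses, batch_sizes)
unweighted = sum(batch_losses) / len(batch_losses)

print(round(weighted, 4))
print(round(unweighted, 4))
```

The two averages differ only when batches have unequal sizes, which happens whenever the dataset size is not a multiple of the batch size.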

### Output and Reusability

- The final DataFrame includes the hyperparameters and test accuracy of each run; the model's state dict is saved separately as `model_run_<n>.pth`.
- This enables reproducibility and comparison across experiments.
- Users can sort, export, or plot the DataFrame to evaluate which configuration performed best.
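
As a rough illustration of ranking configurations, the sketch below works on a plain list of dictionaries shaped like the rows `RunManager` appends to `run.data`. The two rows reuse values from the sample output shown further down in this notebook; the helper names `final_epoch_records` and `rank_by_test_accuracy` are hypothetical.

```python
def final_epoch_records(records):
    """Keep only the last logged epoch of each run."""
    last = {}
    for row in records:
        run = row["run"]
        if run not in last or row["epoch"] >= last[run]["epoch"]:
            last[run] = row
    return list(last.values())

def rank_by_test_accuracy(records):
    """Sort the final-epoch rows from best to worst test accuracy."""
    return sorted(final_epoch_records(records),
                  key=lambda row: row["test_acc"], reverse=True)

records = [
    {"run": 1, "epoch": 1, "lr": 0.001, "trainset": "augmented", "test_acc": 0.3128},
    {"run": 2, "epoch": 1, "lr": 0.002, "trainset": "augmented", "test_acc": 0.3172},
]

ranked = rank_by_test_accuracy(records)
best = ranked[0]
print(best["run"], best["lr"], best["test_acc"])
```

With the pandas DataFrame itself, `df.sort_values("test_acc", ascending=False)` gives the same ordering.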



In [None]:
class Epochstats():
  def __init__(self):
    self.number = 0
    self.loss = 0
    self.num_correct = 0
    self.start_time = None
    self.duration = None

In [None]:
class Runstats():
  def __init__(self):
    self.params = None
    self.number = 0
    self.data = []
    self.start_time = None
    self.duration = None
    self.test_accuracy = None

In [None]:
class RunManager():
  def __init__(self):
    self.model = None
    self.loader= None
    self.epoch = Epochstats()
    self.run = Runstats()
    self.model_files = []

  def begin_run(self, run, model, loader):
    self.run.start_time = time.time()
    self.run.params = run
    self.run.number += 1

    self.run.test_accuracy = None

    self.model = model
    self.loader = loader


  def end_run(self):
    self.epoch.number = 0

    model_filename = f"model_run_{self.run.number}.pth"
    torch.save(self.model.state_dict(), model_filename)
    self.model_files.append(model_filename)

  def begin_epoch(self):
    self.epoch.start_time= time.time()

    self.epoch.number += 1
    self.epoch.loss = 0
    self.epoch.num_correct = 0


  def end_epoch(self):
    self.epoch.duration = time.time() - self.epoch.start_time
    self.run.duration = time.time() - self.run.start_time

    loss = self.epoch.loss/len(self.loader.dataset)
    accuracy = self.epoch.num_correct/len(self.loader.dataset)

    results = OrderedDict()
    results["run"] = self.run.number
    results["epoch"] = self.epoch.number
    results["loss"] = loss
    results["accuracy"] = accuracy
    results["epoch_duration"] = self.epoch.duration
    results["run_duration"] = self.run.duration

    for key,value in self.run.params._asdict().items():
      results[key] = value

    results["test_acc"] = self.run.test_accuracy

    self.run.data.append(results)

    df = pd.DataFrame.from_dict(self.run.data, orient='columns')

    clear_output(wait = True)
    display(df)


  def track_loss(self, loss):
    # loss is the batch mean, so weight it by the batch size (not the dataset size)
    # before end_epoch divides by len(dataset). Assumes every batch is full.
    self.epoch.loss += loss.item() * self.loader.batch_size


  def track_num_correct(self, preds, labels):
    self.epoch.num_correct += self._get_num_correct(preds, labels)


  @torch.no_grad()
  def _get_num_correct(self, preds, labels):
    return preds.argmax(dim=1).eq(labels).sum().item()


  @torch.no_grad()
  def track_testing_accuracy(self, model, testloader):
    # Remember the current mode so Dropout/BatchNorm are re-enabled afterwards.
    was_training = model.training
    model.eval()
    correct = 0
    total = 0

    for images, labels in testloader:
        images, labels = images.to(device), labels.to(device)
        preds = model(images)

        predictions = preds.argmax(dim=1)
        total += len(labels)
        correct += (predictions == labels).sum().item()

    # Restore the mode the model was in before evaluation.
    model.train(was_training)

    accuracy = correct / total
    self.run.test_accuracy = accuracy

In [None]:
class RunBuilder():
  @staticmethod
  def get_runs(parameters):
    Run = namedtuple('Run', parameters.keys())

    runs = []

    for v in product(*parameters.values()):
      runs.append(Run(*v))

    return runs

## Experiment Configuration Block

This section provides a single, centralized place to configure and modify your experiment settings. Users can adjust the following:

- `network`: Choose the architecture name as a string (e.g., `'CNN_model_vanilla'`, `'CNN_model_BNorm_Deep'`, or any name registered in `NetworkFactory`).
- `lr`: List of learning rates to try. You can provide multiple values for grid search (e.g., `[0.001, 0.0005]`).
- `trainset`: Select the data preprocessing variant (`'normalized'`, `'not_normalized'`, `'augmented'`, or `'heavily_augmented'`).

Additional settings:
- `batch_size`: Number of samples per training batch
- `num_epochs`: Number of epochs to train each model
- `weight_decay`: L2 regularization strength
- `shuffle`: Whether to shuffle training data per epoch
- `num_workers`: Number of subprocesses used for data loading

### How to Use

To test a new configuration:
1. Modify values in the `parameters` dictionary.
2. Add or remove entries in the lists to control grid search.
3. Change `trainset` to control how data is preprocessed.
4. Adjust `num_epochs`, `batch_size`, or `weight_decay` as needed.
5. Run the notebook cells below to launch the new experiments.
6. You can also add new fields to the `parameters` dictionary to sweep another hyperparameter in the same grid, provided the training loop reads that value from `run`.

This design allows quick experimentation without needing to edit core logic or function signatures.
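
To see how the lists in `parameters` expand into individual runs, the following sketch reproduces the Cartesian-product logic of `RunBuilder.get_runs` as a standalone function. The extra `batch_size` field in the second dictionary is a hypothetical addition used only to show how the grid grows.

```python
from collections import namedtuple
from itertools import product

def get_runs(parameters):
    """Return one named tuple per combination of the listed values."""
    Run = namedtuple("Run", parameters.keys())
    return [Run(*values) for values in product(*parameters.values())]

parameters = dict(
    network=["CNN_model_vanilla"],
    lr=[0.001, 0.002],
    trainset=["augmented"],
)

runs = get_runs(parameters)
for run in runs:
    print(run)

# Hypothetical extra field: two batch sizes double the number of runs.
extended = dict(parameters, batch_size=[500, 1000])
extended_runs = get_runs(extended)
print(len(runs), len(extended_runs))
```

Each field is available as an attribute (for example `run.lr`), so a newly added field only takes effect once the training loop uses `run.<field>` instead of a global setting.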


In [None]:
trainsets = {
    'normalized': normalized_trainset,
    'not_normalized': raw_trainset,
    'augmented': augmented_trainset,
    'heavily_augmented':heavily_augmented_trainset
}

parameters = dict(
    network = ['CNN_model_vanilla'],
    lr = [0.001, 0.002],
    trainset = ['augmented']
)
shuffle = True
num_workers = 1 #@param {type: "integer"}
batch_size = 1000 #@param {type: "integer"}
num_epochs = 1 #@param {type: "integer"}
weight_decay = 5e-2 #@param {type: "number"}

## Training and Experimentation

This section controls the core training loop and experiment execution. It uses a custom `RunManager` class to:

- Automate training over multiple hyperparameter combinations
- Track and log loss, training accuracy, and testing accuracy (to check for overfitting) per run
- Reset and reinitialize model weights for each configuration
- Output the results into a well-formatted pandas DataFrame

Each run corresponds to one unique combination of the values entered in `parameters` in the configuration block.

The structured logging system allows users to:
- Identify top-performing configurations
- Compare test accuracy across models and settings
- Extend the framework by plugging in new optimizers, loss functions, or models


In [None]:
manager = RunManager()

for run in RunBuilder.get_runs(parameters):
  model = NetworkFactory.get_network(run.network).to(device)

  train_loader = torch.utils.data.DataLoader(
      dataset = trainsets[run.trainset],
      batch_size = batch_size,
      shuffle = shuffle,  # use the setting from the configuration block
      num_workers = num_workers
  )

  optimizer = optim.Adam(model.parameters(), lr=run.lr, weight_decay=weight_decay)

  manager.begin_run(run=run, model=model, loader=train_loader)

  for epoch in range(num_epochs):

    manager.begin_epoch()

    for batch in train_loader:
      images = batch[0].to(device)
      labels = batch[1].to(device)
      preds = model(images)

      loss = F.cross_entropy(preds, labels)
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()

      manager.track_loss(loss)
      manager.track_num_correct(preds, labels)

    # Evaluate on the test set after every epoch so end_epoch can log test_acc.
    manager.track_testing_accuracy(model, test_prediction_loader)
    manager.end_epoch()
  manager.end_run()

for filepath in manager.model_files:
  print(f"Model saved: {filepath}")

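|   | run | epoch | loss | accuracy | epoch_duration | run_duration | model | lr | trainset | test_acc |
|---|-----|-------|------|----------|----------------|--------------|-------|----|----------|----------|
| 0 | 1 | 1 | 105.579627 | 0.22642 | 38.186993 | 38.187014 | CNN_model_vanilla | 0.001 | augmented | 0.3128 |
| 1 | 2 | 1 | 103.934432 | 0.23786 | 38.201097 | 38.201121 | CNN_model_vanilla | 0.002 | augmented | 0.3172 |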


Model saved: model_run_1.pth
Model saved: model_run_2.pth


##  Model Download and Cleanup Interface

After training multiple models using the experimentation framework, this cell lets users selectively download specific model checkpoints and then frees the disk space they occupy in Colab.

###  Features:
- Input a run number (e.g., `1`) to fetch `model_run_1.pth`
- Provide a custom name (e.g., `CNN_vanilla`) to rename the download file to `CNN_vanilla.pth`
- File is downloaded and immediately deleted from the Colab disk after user confirmation
- Continue downloading more files, or exit anytime
- Automatically clears all original `.pth` files after process ends

###  How to Use:
1. **When prompted**:  
   `Enter model run to download (E to exit):`  
     Type the run number (e.g., `3`) to download `model_run_3.pth`

2. **When prompted**:  
   `Rename the file to (do not include the .pth extension):`  
     Enter your custom name (e.g., `my_cnn_model`)  
     Final file will be downloaded as `my_cnn_model.pth`

3. Once the download starts, confirm it's complete

4. To exit the loop, type `E` at the run prompt
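
A minimal sketch of a case-insensitive exit check for prompts like these, using a hypothetical helper `wants_exit`. It compares the cleaned-up answer against a single value, because an expression such as `answer == "E" or "e"` is always truthy in Python (the non-empty string `"e"` counts as true on its own).

```python
def wants_exit(answer):
    """Return True when the user typed E or e (surrounding spaces ignored)."""
    return answer.strip().lower() == "e"

# The pitfall: `or "e"` makes the whole expression truthy for any input.
always_truthy = bool("3" == "E" or "e")

print(wants_exit("E"), wants_exit(" e "), wants_exit("3"))
print(always_truthy)
```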

###  Notes:
- Invalid characters or spaces in filenames are automatically replaced with `_`
- Files are copied (not renamed) to preserve originals until final cleanup
- This keeps Colab storage clean and avoids filling the disk with stale checkpoints
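
The filename clean-up can be sketched as follows, using the same regular expression as the download cell: every character that is not a word character, hyphen, underscore, or dot is replaced with `_`. The helper name `sanitize_filename` and the sample inputs are illustrative only.

```python
import re

def sanitize_filename(name):
    """Replace characters outside [A-Za-z0-9_], '-', and '.' with underscores."""
    return re.sub(r'[^\w\-_.]', '_', name.strip())

# Spaces and slashes are replaced; hyphens and dots are kept.
print(sanitize_filename(" my cnn/model v2 ") + ".pth")
print(sanitize_filename("CNN-vanilla.best") + ".pth")
```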



In [None]:
all_model_files = [f for f in os.listdir() if f.endswith('.pth')]
downloaded = set()
print("Available model files:")
print(all_model_files)

Available model files:
[]


In [None]:
while len(downloaded) < len(all_model_files):
  run_number = input("\nEnter model run to download (E to exit): ").strip()

  # Compare against one value; `== "E" or "e"` would always be truthy.
  if run_number.lower() == "e":
    break

  fileaddress = f"model_run_{run_number}.pth"

  if fileaddress not in all_model_files or fileaddress in downloaded:
    print(f"{fileaddress} does not exist or has already been downloaded")
    continue

  new_name = input("Rename the file to (do not include the .pth extension): ").strip()
  new_name = re.sub(r'[^\w\-_.]', '_', new_name.strip())
  new_fileaddress = f"{new_name}.pth"
  shutil.copy(fileaddress, new_fileaddress)

  print(f"Downloading run {run_number} to local machine as {new_fileaddress}")

  files.download(new_fileaddress)
  input("Press Enter AFTER the file has downloaded to continue: ")

  if os.path.exists(new_fileaddress):
    os.remove(new_fileaddress)
    print(f"Deleted {new_fileaddress} from Colab disk")

  downloaded.add(fileaddress)
  print(f"{fileaddress} saved to local device as {new_fileaddress}")

  if len(downloaded) == len(all_model_files):
    print("All files downloaded")
    break

  next_action = input("Press A to add more files, or E to exit: ").strip()

  if next_action.lower() == "e":
    break

for f in all_model_files:
  if os.path.exists(f):
    os.remove(f)

print("\nAll original model files have been removed from Colab.")


All original model files have been removed from Colab.
