# Deep Learning and Its Applications to Signal and Image Processing and Analysis - Assignment 3


## Introduction
In this assignment, you will perform an image classification task on the CIFAR-10 dataset using two
model families: Convolutional Neural Networks (CNNs) and Vision Transformers (ViTs). The objectives
are to develop models, apply explainability tools (Grad-CAM and attention visualization), and compare
their performance using confusion matrices and other metrics. You will also learn to use the
pytorch-lightning library, which simplifies model building and training and supports automatic logging
to Weights & Biases. A complementary notebook is provided with this assignment for your convenience;
you are free to modify it as needed.

### Imports and Mount Drive

In [None]:
!pip install pytorch-lightning
# 📦 Imports
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import random_split, DataLoader

import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint

import wandb

import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

# 🧹 Set seeds and configs
pl.seed_everything(42)

In [None]:
from google.colab import drive
drive.mount('/content/drive')



In [None]:
import os
os.chdir('')  # TODO: set this to your working directory (for example, a folder in your mounted Drive)

## 1. CNN Classification and Grad-CAM Explainability

In this section, you will implement a CNN from scratch and apply Grad-CAM to explain the model predictions.

### 1.1. Load and Preprocess CIFAR-10

a.

In [None]:
# Define a transform to normalize the data
transform = ... #raise NotImplementedError("TODO: Implement this part")

#⬇️ Load dataset
dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

def show_example_per_label(dataset):
    raise NotImplementedError("TODO: Implement this part")


b.

In [None]:
# Split the dataset into training, validation, and test sets.
raise NotImplementedError("TODO: Implement this part")
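
One possible approach for part b, assuming the validation set is carved out of the 50,000 training images and the official 10,000-image CIFAR-10 test split (`test_dataset`) serves as the test set. The 10% validation fraction, batch size, and worker count are assumptions, and `split_train_val` / `make_loaders` are hypothetical names. A seeded `torch.Generator` makes the split reproducible across runs.

```python
import torch
from torch.utils.data import random_split, DataLoader


def split_train_val(dataset, val_fraction=0.1, seed=42):
    """Hypothetical helper: randomly split `dataset` into (train, val) Subsets; the seed fixes the split."""
    n_val = int(len(dataset) * val_fraction)
    n_train = len(dataset) - n_val
    generator = torch.Generator().manual_seed(seed)
    return random_split(dataset, [n_train, n_val], generator=generator)


def make_loaders(train_set, val_set, test_set, batch_size=128, num_workers=2):
    # Shuffle only the training data; evaluation order does not matter.
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    val_loader = DataLoader(val_set, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    test_loader = DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    return train_loader, val_loader, test_loader
```

`train_set, val_set = split_train_val(dataset)` and then `train_loader, val_loader, test_loader = make_loaders(train_set, val_set, test_dataset)` produce the loader names the training cells expect. Because both subsets wrap the same `dataset`, they share its transform, so any augmentation added there would also be applied to validation images.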

c.

In [None]:
# Plot a histogram of the class distribution for the train, validation, and test splits.
raise NotImplementedError("TODO: Implement this part")
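
A sketch for part c, assuming the splits come from `random_split` (so train and validation are `Subset` objects wrapping a torchvision `CIFAR10`) and using hypothetical helpers `class_counts`, `subset_targets`, and `plot_class_histograms`. Reading labels from `.targets` avoids loading and transforming every image just to count classes.

```python
from collections import Counter

import matplotlib.pyplot as plt


def class_counts(targets, num_classes):
    """Hypothetical helper: number of samples per class, including classes with zero samples."""
    counts = Counter(int(t) for t in targets)
    return [counts.get(c, 0) for c in range(num_classes)]


def subset_targets(subset):
    # torch.utils.data.Subset stores the wrapped dataset and the selected indices;
    # CIFAR10 keeps all labels in its .targets list, so no images are loaded here.
    return [subset.dataset.targets[i] for i in subset.indices]


def plot_class_histograms(splits, class_names):
    """Hypothetical helper. splits: dict mapping split name -> list of integer labels."""
    fig, axes = plt.subplots(1, len(splits), figsize=(5 * len(splits), 3.5), squeeze=False)
    for ax, (name, targets) in zip(axes[0], splits.items()):
        ax.bar(range(len(class_names)), class_counts(targets, len(class_names)))
        ax.set_xticks(range(len(class_names)))
        ax.set_xticklabels(class_names, rotation=45, ha="right")
        ax.set_title(f"{name} (n={len(targets)})")
        ax.set_ylabel("images")
    fig.tight_layout()
    plt.show()
```

Usage: `plot_class_histograms({"train": subset_targets(train_set), "validation": subset_targets(val_set), "test": test_dataset.targets}, test_dataset.classes)`. `random_split` is not stratified, so per-class counts in train and validation will be close to, but not exactly, proportional.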

### 1.2. Define CNN in PyTorch Lightning

a.

In [None]:
class SimpleCNN(pl.LightningModule):
    def __init__(self, lr):
        raise NotImplementedError("TODO: Implement this part")

    def forward(self, x):
        raise NotImplementedError("TODO: Implement this part")

    def training_step(self, batch, batch_idx):
        raise NotImplementedError("TODO: Implement this part")

    def validation_step(self, batch, batch_idx):
        raise NotImplementedError("TODO: Implement this part")

    def test_step(self, batch, batch_idx):
        raise NotImplementedError("TODO: Implement this part")

    def configure_optimizers(self):
        raise NotImplementedError("TODO: Implement this part")

b. Your model should achieve an accuracy of at least 0.80 on the training set, and at least 0.70 on both the validation and test sets.

In [None]:
# 🪄 Init wandb logger
wandb_logger = WandbLogger(project="CNN-CIFAR10", log_model=True)

# ⚡ Instantiate model and trainer
model = SimpleCNN(lr=...)

# Define the checkpoint callback
checkpoint_callback = ModelCheckpoint(
    # raise NotImplementedError("TODO: Implement this part")
)

trainer = pl.Trainer(
    #raise NotImplementedError("TODO: Implement this part")
)

# 🏋️‍♂️ Train
trainer.fit(model, train_loader, val_loader)

c.

In [None]:
# 🔍 Evaluate
trainer.test(model, dataloaders=test_loader)

# Show the F1 score and confusion matrix. You can compute them in on_test_epoch_end(self) and log them to wandb.

# 💾 Save the weights from the last epoch (the checkpoint callback may have kept a different, best-by-metric epoch)
torch.save(model.state_dict(), "cnn_cifar10_checkpoint.ckpt")

### 1.3 Explainability with Grad-CAM

a.

In [None]:
!pip install grad-cam --quiet

In [None]:
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image, preprocess_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

# Define the target layer for Grad-CAM (adjust if needed)
target_layer = ... # raise NotImplementedError("TODO: Implement this part")

cam = GradCAM(model=model, target_layers=[target_layer])
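
For the target layer, a common choice is the last convolutional layer, whose feature maps are the most class-specific while still being spatial. The hypothetical helper `last_conv_layer` below returns the last `nn.Conv2d` in module-registration order, which matches forward order for purely sequential models but is an assumption for models with branches.

```python
import torch.nn as nn


def last_conv_layer(model: nn.Module) -> nn.Conv2d:
    """Hypothetical helper: return the last nn.Conv2d registered in `model`."""
    last = None
    for module in model.modules():  # depth-first, in registration order
        if isinstance(module, nn.Conv2d):
            last = module
    if last is None:
        raise ValueError("model contains no nn.Conv2d layer")
    return last
```

Usage: `target_layer = last_conv_layer(model)`. If your network downsamples 32x32 inputs three times with 2x2 pooling, the deepest maps are only 4x4, so heatmaps are coarse; choosing an earlier layer trades class specificity for spatial detail.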


b. Generate Grad-CAM heatmaps for several test images.

In [None]:
# Generate Grad-CAM heatmaps for several test images
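
A sketch for generating heatmaps on a batch of test images, assuming `images` is a normalized batch from `test_loader` and that `mean`/`std` are the values passed to `transforms.Normalize`. `denormalize` and `show_gradcam_grid` are hypothetical helpers. Each heatmap explains the model's predicted class via `ClassifierOutputTarget`, and because `show_cam_on_image` expects an RGB float image in [0, 1], the input is de-normalized first. The grad-cam imports sit inside the function so the de-normalization helper can be used on its own.

```python
import numpy as np
import torch
import matplotlib.pyplot as plt


def denormalize(img, mean, std):
    """Hypothetical helper: invert transforms.Normalize, CHW tensor -> HWC float32 array in [0, 1]."""
    mean = torch.as_tensor(mean, dtype=img.dtype).view(-1, 1, 1)
    std = torch.as_tensor(std, dtype=img.dtype).view(-1, 1, 1)
    x = (img.detach().cpu() * std + mean).clamp(0, 1)
    return x.permute(1, 2, 0).numpy().astype(np.float32)


def show_gradcam_grid(cam, model, images, labels, class_names, mean, std):
    """Hypothetical helper: top row = input images, bottom row = Grad-CAM overlay for the predicted class."""
    # Imported here so that denormalize() does not require the grad-cam package.
    from pytorch_grad_cam.utils.image import show_cam_on_image
    from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

    model.eval()
    images = images.to(next(model.parameters()).device)
    with torch.no_grad():
        preds = model(images).argmax(dim=1).cpu()
    targets = [ClassifierOutputTarget(int(p)) for p in preds]
    heatmaps = cam(input_tensor=images, targets=targets)  # numpy array, shape (N, H, W), values in [0, 1]

    n = images.shape[0]
    fig, axes = plt.subplots(2, n, figsize=(2 * n, 4.4), squeeze=False)
    for i in range(n):
        rgb = denormalize(images[i], mean, std)
        axes[0, i].imshow(rgb)
        axes[0, i].set_title(f"true: {class_names[int(labels[i])]}", fontsize=8)
        axes[1, i].imshow(show_cam_on_image(rgb, heatmaps[i], use_rgb=True))
        axes[1, i].set_title(f"pred: {class_names[int(preds[i])]}", fontsize=8)
        for ax in axes[:, i]:
            ax.axis("off")
    fig.tight_layout()
    plt.show()
```

Usage: `images, labels = next(iter(test_loader))` then `show_gradcam_grid(cam, model, images[:8], labels[:8], test_dataset.classes, mean, std)`. Passing `ClassifierOutputTarget(int(labels[i]))` instead would explain the true class, which can be informative for misclassified images.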

## 2. Vision Transformer (ViT) and Attention Visualization


In this section, you will implement a Vision Transformer (ViT) from scratch and compare it to the CNN
model developed in Section 1. Additionally, you will visualize attention maps to gain insight into the
model’s decision process.

### 2.1 Implementing the Vision Transformer

a.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class PatchEmbedding(nn.Module):
    def __init__(self, in_channels=..., patch_size=..., emb_size=..., img_size=...):
        raise NotImplementedError("TODO: Implement this part")

    def forward(self, x):
        raise NotImplementedError("TODO: Implement this part")

class ViTWithAttention(nn.Module):
    def __init__(self, img_size=..., patch_size=..., in_channels=..., num_classes=...):
        raise NotImplementedError("TODO: Implement this part")

    def forward(self, x):
        raise NotImplementedError("TODO: Implement this part")

class ViTLightningModule(pl.LightningModule):
    def __init__(self, lr=...):
        raise NotImplementedError("TODO: Implement this part")

    def forward(self, x):
        raise NotImplementedError("TODO: Implement this part")

    def training_step(self, batch, batch_idx):
        raise NotImplementedError("TODO: Implement this part")

    def validation_step(self, batch, batch_idx):
        raise NotImplementedError("TODO: Implement this part")

    def test_step(self, batch, batch_idx):
        raise NotImplementedError("TODO: Implement this part")

    def configure_optimizers(self):
        raise NotImplementedError("TODO: Implement this part")

b. Your model should achieve an accuracy of at least 0.70 on the training set, and at least 0.60 on both the validation and test sets.

In [None]:
# 🪄 Init wandb logger
wandb_logger = WandbLogger(project="ViT-CIFAR10", log_model=True)

# ⚡ Instantiate model and trainer
model = ViTLightningModule(lr=...)

# Define the checkpoint callback
checkpoint_callback = ModelCheckpoint(
    # raise NotImplementedError("TODO: Implement this part")
)

trainer = pl.Trainer(
    # raise NotImplementedError("TODO: Implement this part")
)

# 🏋️‍♂️ Train
trainer.fit(model, train_loader, val_loader)

c.

In [None]:
# 🔍 Evaluate
trainer.test(model, dataloaders=test_loader)

# Show the F1 score and confusion matrix. You can compute them in on_test_epoch_end(self) and log them to wandb.

### 2.2 Visualizing Attention Maps

In [None]:
def visualize_attention(model, image_tensor, patch_size=4):
    raise NotImplementedError("TODO: Implement this part")
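
A sketch of `visualize_attention`, assuming the model's forward accepts `return_attention=True` and returns `(logits, attentions)` with one `(batch, heads, tokens, tokens)` tensor per layer and the CLS token at index 0 (for a Lightning wrapper, pass the inner network). By default it uses attention rollout (Abnar & Zuidema, 2020): head-averaged attention matrices, each mixed with the identity to account for residual connections and re-normalized, are multiplied across layers, and the CLS row of the product is reshaped to the patch grid. Setting `rollout=False` shows only the last layer's CLS attention. `attention_rollout` and `cls_attention_map` are hypothetical helpers, square inputs are assumed, and the image is min-max rescaled only for display.

```python
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt


def attention_rollout(attentions):
    """Hypothetical helper: multiply head-averaged, residual-adjusted attention across layers."""
    result = None
    for a in attentions:                       # each a: (B, heads, T, T)
        a = a.mean(dim=1)                      # average over heads -> (B, T, T)
        a = a + torch.eye(a.shape[-1], device=a.device, dtype=a.dtype)  # residual connection
        a = a / a.sum(dim=-1, keepdim=True)    # re-normalize rows to sum to 1
        result = a if result is None else a @ result
    return result


def cls_attention_map(attentions, grid_size, rollout=True):
    """Hypothetical helper: CLS-to-patch attention as (B, grid, grid), each map scaled to max 1."""
    if rollout:
        cls = attention_rollout(attentions)[:, 0, 1:]
    else:
        cls = attentions[-1].mean(dim=1)[:, 0, 1:]  # last layer only
    maps = cls.reshape(-1, grid_size, grid_size)
    return maps / maps.amax(dim=(1, 2), keepdim=True)


@torch.no_grad()
def visualize_attention(model, image_tensor, patch_size=4, rollout=True):
    model.eval()
    if image_tensor.dim() == 3:                # single CHW image -> batch of one
        image_tensor = image_tensor.unsqueeze(0)
    device = next(model.parameters()).device
    _, attentions = model(image_tensor.to(device), return_attention=True)  # assumed interface
    H, W = image_tensor.shape[-2:]
    maps = cls_attention_map([a.cpu() for a in attentions], H // patch_size, rollout)
    # Upsample the patch-level map to pixel resolution for overlaying.
    maps = F.interpolate(maps.unsqueeze(1), size=(H, W), mode="bilinear", align_corners=False)[:, 0]

    img = image_tensor[0].cpu()
    img = (img - img.min()) / (img.max() - img.min() + 1e-8)  # display-only rescaling
    img = img.permute(1, 2, 0).numpy()
    fig, axes = plt.subplots(1, 2, figsize=(6, 3))
    axes[0].imshow(img)
    axes[0].set_title("input")
    axes[1].imshow(img)
    axes[1].imshow(maps[0].numpy(), cmap="jet", alpha=0.5)
    axes[1].set_title("CLS attention" + (" (rollout)" if rollout else " (last layer)"))
    for ax in axes:
        ax.axis("off")
    fig.tight_layout()
    plt.show()
    return maps[0]
```

Rollout tends to give smoother maps than a single layer because it accounts for how information is mixed through every block; comparing both settings on the same image is a quick sanity check.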
