# Summer School on Biomedical Imaging with Deep Learning

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/albarqounilab/BILD-Summer-School/blob/main/notebooks/day1/classification.ipynb)

![alt_text](https://raw.githubusercontent.com/albarqounilab/BILD-Summer-School/refs/heads/main/images/helpers/notebook-banner.png)

BILD 2025 is organized under the umbrella of the [Strategic Arab-German Network for Affordable and Democratized AI in Healthcare (SANAD)](https://albarqouni.github.io/funded/sanad/), uniting academic excellence and technological innovation across borders. This year’s edition is organized by the [Albarqouni Lab](https://albarqouni.github.io/) at the [University Hospital Bonn](https://www.ukbonn.de/) and the [University of Bonn](https://www.uni-bonn.de/en). We are proud to partner with leading institutions in the region (Lebanese American University, University of Tunis El Manar, and Duhok Polytechnic University) to deliver a truly international learning experience. Over five intensive days in Tunis, you will explore cutting-edge deep-learning techniques for medical imaging through expert lectures, hands-on labs, and collaborative case studies. Engage with peers and faculty from Germany, Lebanon, Iraq, and Tunisia as you develop practical skills in building and deploying AI models for real-world healthcare challenges. We look forward to an inspiring week of interdisciplinary exchange and a shared commitment to advancing affordable, life-saving AI in medicine.


## Chest-X-Ray Classification [60 mins]

### Today's Goals

This session is a practical journey into the world of medical image classification. By the end of this notebook, you will be able to:

- Prepare Medical Imaging Data: Load, preprocess, and normalize chest X-ray images for deep learning models.

- Understand Classification Data Pipelines: Create a custom PyTorch Dataset and DataLoader tailored for image classification tasks.

- Train a Classifier Model: Fine-tune state-of-the-art CNN architectures (e.g., DenseNet, EfficientNet) to classify chest X-rays into diagnostic categories.

- Master Classification Metrics: Use and interpret evaluation metrics such as Accuracy, Precision, Recall, F1-score, and ROC-AUC.

- Perform Model Explainability & Quality Control: Apply advanced interpretability techniques (Grad-CAM, Score-CAM, LIME) to visualize what regions of the image the model relies on, ensuring predictions are trustworthy.

### Objectives

You’ll see how AI can be trained to identify pathological findings in chest X-rays, a crucial step in computer-assisted diagnosis. You’ll also apply your classification skills to a challenging real-world problem in medical imaging, while learning how to interpret and validate model predictions beyond raw accuracy scores.

### Dataset

The [NIH ChestX-ray-14](https://nihcc.app.box.com/v/ChestXray-NIHCC) dataset is a large collection of chest X-ray images. Each image comes with information about the patient and labels that tell us which diseases (if any) are present. This dataset is widely used in medical AI research because it helps us train and test models to recognize diseases from X-ray images.

**What does the dataset contain?**
1. Over 100,000 chest X-ray images, each in PNG format. These are pictures of the inside of the chest, showing the lungs and heart.
2. A metadata file (`Data_Entry_2017.csv`) that lists information about each image, such as:
   - Which diseases are present (if any)
   - Patient age and gender
   - How the image was taken
3. A file with bounding boxes (`BBox_List_2017.csv`) for about 1,000 images. These boxes show exactly where a disease is located in the image.
4. Files that split the data into training and test sets. This is important because we want to train our model on some images and test it on others to see how well it works on new data.

**Why do we use this dataset?**
- It is large and diverse, which helps our model learn better.
- It has real medical labels, making our project more realistic.
- It allows us to practice both classification (is there a disease?) and detection (where is the disease?).

In this notebook, we will use a smaller sample of this dataset and pre-trained models to make the exercises faster and easier to follow.


## 1. Environment Setup

We install and import required libraries. Run this once per new environment.

> **Note:** The cell will install packages (internet required). If you're offline, skip installation and ensure the environment already has the packages.


In [None]:
#@title import libraries (2 minutes)
# Install third-party packages that may be missing from the environment
!pip install -q grad-cam lime huggingface_hub pydicom SimpleITK torchmetrics

# Standard library
import os
import random
import re
import time
import warnings
from glob import glob

warnings.filterwarnings('ignore')

# Data handling and plotting
import cv2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from PIL import Image
from tqdm import tqdm

# Medical image I/O
import pydicom
import SimpleITK as sitk

# PyTorch and torchvision
import torch
import torch.nn as nn
import torchvision
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import DataLoader, Dataset
from torchvision import models, transforms, tv_tensors

# Data splitting and metrics
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    RocCurveDisplay, accuracy_score, auc, f1_score,
    precision_score, recall_score, roc_auc_score, roc_curve,
)
from torchmetrics.classification import BinaryAUROC

# Dataset and model-weight downloads
from huggingface_hub import hf_hub_download

## 2. The Dataset: ChestX-ray14 (NIH) (3 minutes)

We’ll use the NIH ChestX-ray14 dataset, which contains over 112,000 frontal chest X-ray images from 30,805 patients. Each image is labeled with one or more of 14 thoracic disease categories (including pneumonia, emphysema, fibrosis, and hernia) or with 'No Finding'.

- Size: ~42 GB

- Format: PNG images + accompanying CSV files with labels

- Source: NIH Clinical Center

### 2.1 Downloading the Data (3 minutes)
Before we can work with the data, we need to download and unzip it. This means we are copying the files from the internet to our computer and making them ready to use.

**Why do we do this?**
- Machine learning models need data to learn from. Downloading the dataset gives us the images and labels we need for our project.
- Unzipping extracts the files from a compressed format so we can access them easily in our code.

**Instructions:**
- If you have not downloaded the dataset yet, run the following cells to download and unzip the files.
- If you already have the data, you can skip these steps by adding a `#` before the `!` in the code (this comments out the line so it won't run).
- You can also change the `DATA_PATH` variable if you want to store the data in a different folder.

> **Tip:** Downloading large datasets can take a while, depending on your internet speed.


In [None]:
# Download data from Hugging Face
!pip install -q huggingface_hub
from huggingface_hub import hf_hub_download
import os

CWD = '.'
DATA_PATH = f"{CWD}/Classification"
REPO_ID = 'albarqouni/bild-dataset'
SUBFOLDER = 'Classification'
os.makedirs(CWD, exist_ok=True)

# Download csv.zip
csv_zip_path = hf_hub_download(
    repo_id=REPO_ID,
    repo_type="dataset",
    filename="csv.zip",
    subfolder=SUBFOLDER,
    local_dir=CWD
)

# Download data_cxr8.zip
data_zip_path = hf_hub_download(
    repo_id=REPO_ID,
    repo_type="dataset",
    filename="data_cxr8.zip",
    subfolder=SUBFOLDER,
    local_dir=CWD
)

print("Download complete.")

In [None]:
!unzip -q {DATA_PATH}/csv.zip -d {DATA_PATH}
!unzip -q {DATA_PATH}/data_cxr8.zip -d {DATA_PATH}

> **Note:** The RSNA dataset is not included in the Hugging Face Classification dataset. If needed, add similar Hugging Face download logic here for RSNA or other datasets.

### 2.2 Downloading Pretrained Model Weights
To save time and resources, we’ll use pretrained models (DenseNet121, EfficientNet, Swin Transformer).


These weights are downloaded from Hugging Face and will be used for transfer learning.

🔹 Why do we use pretrained models?

- Training a deep network from scratch on a huge dataset (like 42GB chest X-rays) would take days/weeks and require lots of GPUs.

- Instead, we use models that have already been pretrained on ImageNet (1M+ images).

- We then fine-tune them on our medical dataset (transfer learning).

- This makes training faster and often improves performance.

🔹 Common Pretrained Models Used

1. DenseNet121

    - Dense connections between layers → improves gradient flow.

    - Popular for medical imaging tasks (for example, CheXNet is a DenseNet121 model trained on ChestX-ray14).

    <img src="https://drive.google.com/uc?export=view&id=1Zr1-ni4pqjiZLHrFJNyXRGdwCFJp1Zrx" width="600">


2. EfficientNet

    - Balances model depth, width, and resolution efficiently.

    - Often achieves better accuracy with fewer parameters.
    <img src="https://drive.google.com/uc?export=view&id=173iLsgRvocshQRSG1DICALVakaJm-XxM" width="600">


3. Swin Transformer

    - A Vision Transformer (ViT) variant.

    - Uses shifted windows for efficient self-attention.

    - Very strong performance on classification & detection.
   <img src="https://drive.google.com/uc?id=15Q75XaSSVRMHIg4GfUNNNAjAG61QWRJT" width="500">


In [None]:
# Download pretrained model weights from Hugging Face

# DenseNet121
hf_hub_download(
    repo_id=REPO_ID,
    repo_type="dataset",
    filename="densenet121-classification.pth",
    subfolder=SUBFOLDER,
    local_dir=CWD
)
# EfficientNet
hf_hub_download(
    repo_id=REPO_ID,
    repo_type="dataset",
    filename="efficientnet-classification.pth",
    subfolder=SUBFOLDER,
    local_dir=CWD
)
# Swin Transformer
hf_hub_download(
    repo_id=REPO_ID,
    repo_type="dataset",
    filename="swintransformer-classification.pth",
    subfolder=SUBFOLDER,
    local_dir=CWD
)
print("Model weights downloaded.")

### 2.3 Data Exploration

Understanding the Data Structure

After downloading and extracting the dataset, we expect the following structure inside the Classification/ folder:
```
Classification/
  images/                              # Contains chest X-ray PNG images
    00000005_003.png
    00000005_006.png
    00000005_007.png
    ...
  metadata.csv                         # Full metadata: image IDs, labels, patient info
  metadata_filtered.csv                # Processed/filtered metadata
  train_df.csv                         # Training set split
  val_df.csv                           # Validation set split
  test_df.csv                          # Test set split
  train_val_list.txt                   # Combined train/val list
  test_list.txt                        # Test image list
  densenet121-classification.pth       # Pretrained DenseNet121 model weights
  efficientnet-classification.pth      # Pretrained EfficientNet model weights
  swintransformer-classification.pth   # Pretrained Swin Transformer model weights

```

### Load dataframe metadata

A **dataframe** is a table of data, like a spreadsheet, that we can easily work with in Python using the pandas library. Here, we load the metadata for all our images. This metadata tells us important information about each image, such as which diseases are present, the patient ID, and more. Loading this information helps us organize and prepare our data for training and testing our model.

In [None]:
# Load and observe available data
metadata_df = pd.read_csv(f'{DATA_PATH}/metadata.csv')
metadata_df  # use metadata_df.head() to print only the first 5 rows of the dataframe

In our dataset, we have two types of information:

- A metadata file (metadata.csv) that lists more than 112,000 entries, one for each expected X-ray.

- A folder of actual images, which contains only 24,502 files.

Now the question is: do all metadata entries have a corresponding image file? Let’s visualize this relationship.

In [None]:
import matplotlib.patches as patches
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(9,5))

# Big box = Metadata
ax.add_patch(patches.Rectangle((0.1,0.2),0.6,0.6,fill=None,
                               edgecolor="blue",linewidth=2))
ax.text(0.4,0.8,"Metadata.csv\n(112,120 entries)",ha="center",fontsize=11,color="blue")

# Smaller box = Images
ax.add_patch(patches.Rectangle((0.5,0.3),0.35,0.4,fill=None,
                               edgecolor="orange",linewidth=2))
ax.text(0.675,0.7,"Images/\n(24,502)",ha="center",fontsize=11,color="orange")

# Overlap area
ax.add_patch(patches.Rectangle((0.5,0.3),0.2,0.4,fill=True,
                               facecolor="lightgreen",alpha=0.4,edgecolor="green"))
ax.text(0.6,0.5,"Overlap\n(24,502 usable)",ha="center",fontsize=11,color="green")

# Annotations
ax.annotate("Only in Metadata", xy=(0.2,0.5), xytext=(0.05,0.9),
            arrowprops=dict(arrowstyle="->",color="blue"), fontsize=10, color="blue")

ax.annotate("Only in Images", xy=(0.8,0.45), xytext=(0.9,0.2),
            arrowprops=dict(arrowstyle="->",color="orange"), fontsize=10, color="orange")

ax.annotate("Both Metadata + Images", xy=(0.58,0.45), xytext=(0.35,0.25),
            arrowprops=dict(arrowstyle="->",color="green"), fontsize=10, color="green")

ax.axis("off")
plt.title("Relationship between Metadata and Images", fontsize=13)
plt.show()


This view makes it even clearer:

- Most of the data exists only in the metadata file, but without images, we can’t use them.

- The overlap gives us the real working dataset we’ll use for training and evaluation: 24,502 chest X-rays.

- The orange area for ‘Only in Images’ is empty, meaning every image in the folder has a metadata entry — which is good for consistency.

Now we need to make sure that the information in our dataframe matches the images we actually downloaded. This step filters out any entries in the metadata that do not have a corresponding image file, so we only work with images that are available on our computer.

In [None]:
imgs = glob(f'{DATA_PATH}/images/*')
imgs_basename = [os.path.basename(i) for i in imgs]

metadata_df = metadata_df.loc[metadata_df['Image Index'].isin(imgs_basename)]
metadata_df.shape
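
The diagram above uses hard-coded counts. A minimal sketch of how the three regions could be computed directly from file names, assuming both the metadata names and the folder listing are available as plain lists (the helper `compare_names` and the toy file names are illustrative and not part of the notebook):

```python
def compare_names(metadata_names, image_names):
    """Count names found only in the metadata, only on disk, and in both."""
    meta, files = set(metadata_names), set(image_names)
    return {
        "only_in_metadata": len(meta - files),
        "only_in_images": len(files - meta),
        "in_both": len(meta & files),
    }

# Toy example with hypothetical file names
toy_metadata_names = ["a.png", "b.png", "c.png", "d.png"]
toy_image_names = ["b.png", "c.png"]
overlap = compare_names(toy_metadata_names, toy_image_names)
print(overlap)
```

In the notebook, the same helper could be called with the unfiltered `Image Index` column and `imgs_basename` to confirm that the "Only in Images" region really is empty.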

In [None]:
#@title Plot Patient Images

import matplotlib.pyplot as plt
import random

# Show a few random patient images
def plot_random_images(metadata_df, data_path, n=6):
    """Plots n random images with their labels."""
    sample_df = metadata_df.sample(n)
    plt.figure(figsize=(15, 8))

    for i, (_, row) in enumerate(sample_df.iterrows()):
        img_path = os.path.join(data_path, "images", row["Image Index"])
        img = plt.imread(img_path)

        plt.subplot(2, n//2, i+1)
        plt.imshow(img, cmap="gray")
        plt.title(f"Patient {row['Patient ID']}\nLabel: {row['Finding Labels']}", fontsize=9)
        plt.axis("off")

    plt.tight_layout()
    plt.show()

# Call the function
plot_random_images(metadata_df, DATA_PATH, n=6)


## 3. Data Splitting & Preparation

### 3.1 Load patient splits

To train and evaluate our model properly, we need to split our data into different groups:
- **Training set:** Used to teach the model.
- **Validation set:** Used to check how well the model is learning during training.
- **Test set:** Used to see how well the model works on completely new data.

In this step, we load lists of which images belong to each group. This helps us make sure that the model is tested on images it has never seen before, which is important for getting a fair measure of its performance.

In [None]:
train_val_patients = pd.read_csv(f'{DATA_PATH}/train_val_list.txt', header=None, names=['patientId'])
test_patients = pd.read_csv(f'{DATA_PATH}/test_list.txt', header=None, names=['patientId'])

# Note: despite the `patientId` column name, each row holds an image file name
print(f"Number of images in train/val list: {len(train_val_patients)}")
print(f"Number of images in test list: {len(test_patients)}")

The `.txt` files contain lists of image names that belong to the training/validation or test sets. To use these splits, we need to match the image names in these files with the information in our main database (`metadata.csv`). This way, we know which images and labels go into each group for training and testing.

### 3.2 Handle targets

In machine learning, a **target** is what we want the model to predict. For this project, the target is the disease label for each image. In this step, we prepare the target labels so that our model can learn to predict them. This may involve simplifying the labels or grouping them in a way that makes the problem easier to solve.

In the next step, we look at how many times each disease label appears in our data. Some diseases are very rare, which can make it hard for the model to learn about them. To keep things simple and make sure our model has enough examples to learn from, we will remove labels that appear fewer than 1,500 times.

In [None]:
label_counts = metadata_df['Finding Labels'].value_counts()
label_counts
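
Note that `value_counts()` on `Finding Labels` counts each distinct label string, so a combination such as `Effusion|Infiltration` is counted as its own category. A minimal sketch of counting individual findings instead, assuming labels are joined with `|` (the helper name and the toy labels are illustrative):

```python
from collections import Counter

def count_individual_labels(finding_labels, sep="|"):
    """Count how often each single finding occurs across multi-label strings."""
    counts = Counter()
    for entry in finding_labels:
        counts.update(label.strip() for label in entry.split(sep))
    return counts

# Toy example: five images, some with more than one finding
toy_labels = ["No Finding", "Effusion", "Effusion|Infiltration", "No Finding", "Infiltration"]
per_label = count_individual_labels(toy_labels)
print(per_label.most_common())
```

With pandas, a similar result could be obtained by splitting the column and calling `explode()` before `value_counts()`.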

Next, we identify the rare labels: label strings (including multi-label combinations) that appear in fewer than 1,500 images. Removing them leaves the model with enough examples of each remaining category to learn from, at the cost of ignoring the less common findings.

In [None]:
label_counts = metadata_df['Finding Labels'].value_counts()
rare_labels = label_counts[label_counts < 1500].index

Now we update our data table (DataFrame) to remove any images with rare disease labels. This makes sure our model only sees images with the most common labels, which helps it learn better.

In [None]:
metadata_df_filtered = metadata_df[~metadata_df['Finding Labels'].isin(rare_labels)].copy()

print(f"Original shape: {metadata_df.shape}")
print(f"Filtered shape: {metadata_df_filtered.shape}")

In [None]:
metadata_df_filtered['Finding Labels'].value_counts()

To make our task easier, we will turn the problem into a **binary classification** problem. This means the model will learn to answer a simple question: Is this X-ray healthy or does it show signs of disease?

- **Class 0 (Negative):** Images labeled as 'No Finding' (healthy)
- **Class 1 (Positive):** Images with any disease label (pathology present)

This approach is common in deep learning when starting out, because it is easier for the model to learn to distinguish between just two categories. The category we want the model to predict is called the **target class**. In the cell below, we narrow the task further by keeping only images labeled 'No Finding' or containing 'Effusion'. You can edit the `keep` set to focus on a different disease, or experiment with more classes to see how the model behaves.

In [None]:
keep = {
    'No Finding', 'Effusion',
}

# split each cell into a list, then keep rows where at least one element is in `keep`
df_filtered = metadata_df_filtered[
    metadata_df_filtered['Finding Labels']
      .str.split('|')                         # or .str.split(',') if comma‑separated
      .apply(lambda labels: any(lbl in keep for lbl in labels))
].copy()
df_filtered['Finding Labels'].value_counts()

Now we create a new column called `Binary Label` in our data. This column will have a value of 0 for healthy images and 1 for images with any disease. This process is called **label encoding** and is very common in deep learning, because models work best with numbers instead of text.

In [None]:
df_filtered['Binary Label'] = (df_filtered['Finding Labels'] != 'No Finding').astype(int)
df_filtered['Binary Label'].value_counts()

We can further clean our dataset by selecting only one **view acquisition** type for our classifier. 'View acquisition' refers to the way the X-ray image was taken (for example, from the front or the side). Using only one type (like 'PA' for posteroanterior) helps the model learn more consistently, because all images will look similar in terms of orientation.

In [None]:
df_filtered = df_filtered[df_filtered["View Position"] == 'PA']
df_filtered['View Position'].value_counts()

### 3.3 Train / Test Split
Now we use the image lists loaded in Section 3.1 to split our data into a **training set** (used to teach the model) and a **test set** (used to check how well the model works on new, unseen data). This is called a **train-test split** and is a key step in building reliable machine learning models.

In [None]:
# Split df_filtered by matching image file names against the loaded lists
train_val_df = df_filtered[df_filtered['Image Index'].isin(train_val_patients['patientId'])].copy()
test_df = df_filtered[df_filtered['Image Index'].isin(test_patients['patientId'])].copy()

print(f"Train val shape: {train_val_df.shape}")
print(f"Test set shape: {test_df.shape}")
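
Because one patient can have several X-rays, a useful sanity check is that no patient contributes images to both splits. A minimal sketch, assuming image names follow the `<patient>_<follow-up>.png` pattern seen in the folder listing (for example `00000005_003.png`); the helper names and toy file names are illustrative:

```python
def patient_id_from_image_name(image_name):
    """Assumes names look like '00000005_003.png': patient ID, then follow-up number."""
    return image_name.split("_")[0]

def shared_patients(train_names, test_names):
    """Return the set of patient IDs that appear in both lists of image names."""
    train_ids = {patient_id_from_image_name(n) for n in train_names}
    test_ids = {patient_id_from_image_name(n) for n in test_names}
    return train_ids & test_ids

# Toy example: patient 00000013 appears on both sides, which would be a leak
toy_train = ["00000005_003.png", "00000005_006.png", "00000013_000.png"]
toy_test = ["00000021_001.png", "00000013_004.png"]
leaked = shared_patients(toy_train, toy_test)
print(sorted(leaked))
```

In the notebook, the `Patient ID` column of `train_val_df` and `test_df` could be compared directly instead of parsing file names.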

### 3.4 Balance Classes & Prepare Subset

Now we balance the dataset by sampling at most 3,000 examples per class. As long as each class has at least 3,000 images, the model sees the same number of positive (disease) and negative (healthy) examples during training; if one class has fewer, all of its images are kept and the subset stays somewhat imbalanced. Balancing the classes is important because an unbalanced dataset can bias the model toward the majority class, reducing its ability to correctly predict the minority class.

In [None]:
df = train_val_df.copy()
pos = df[df['Binary Label'] == 1]
neg = df[df['Binary Label'] == 0]

# Sample up to n_samples images from each class
n_samples = 3000
pos = pos.sample(n=min(len(pos), n_samples), random_state=42)
neg = neg.sample(n=min(len(neg), n_samples), random_state=42)

subset = pd.concat([pos, neg]).reset_index(drop=True)
print("Subset size:", subset.shape)
print(subset['Binary Label'].value_counts())

### 3.5 Final Train/Validation Split

Now we split the balanced subset in two steps. First, we hold out 20% as a test set; note that this replaces the `test_df` created from the official test list in Section 3.3. Second, we split the remaining 80% into training (90%) and validation (10%) sets. The training set is used to teach the model, while the validation set is used to monitor the model’s learning during training and to tune hyperparameters. We use the `stratify` parameter so that each split keeps the same proportion of positive and negative examples. Keeping these sets separate is essential for building a reliable model and detecting overfitting.

In [None]:
train_validation_df, test_df = train_test_split(
    subset,
    test_size=0.2,
    stratify=subset['Binary Label'],
    random_state=42
)
train_df, val_df = train_test_split(
    train_validation_df,
    test_size=0.1,
    stratify=train_validation_df['Binary Label'],
    random_state=42
)
print("Train:", train_df.shape, "Validation:", val_df.shape)

### 3.6 Deep Learning Data Terminology

Before we start building datasets and data loaders in PyTorch, it is important to understand these key concepts:

- **Batch:** A batch is a small group of samples processed together by the model before updating its parameters. Using batches makes training faster and more stable.
- **Epoch:** One epoch means the model has seen all the training data once. Training usually takes many epochs.
- **DataLoader:** In PyTorch, a DataLoader helps us load data in batches, shuffle it, and use multiple CPU cores to speed up the process. This is essential for efficient deep learning training.
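
The relationship between these terms comes down to simple arithmetic. A minimal sketch, assuming a hypothetical training set of 1,000 images (the numbers and the helper name are illustrative):

```python
import math

def steps_per_epoch(num_samples, batch_size, drop_last=False):
    """Number of batches needed to see every sample once."""
    if drop_last:
        # The final, smaller batch is discarded
        return num_samples // batch_size
    # The final batch may be smaller than batch_size
    return math.ceil(num_samples / batch_size)

n_images = 1000   # hypothetical training-set size
bs = 64           # samples per batch
n_epochs = 10     # full passes over the training data

steps = steps_per_epoch(n_images, bs)
total_updates = steps * n_epochs
print(f"{steps} batches per epoch, {total_updates} parameter updates in total")
```

The same count is what `len()` returns for a PyTorch `DataLoader` built over a dataset of that size.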

## 4. Download Pretrained Model & Prepare Dataset


This week, we delve into the power of deep models like CNNs, leveraging the PyTorch library as our framework. PyTorch provides the flexibility and tools necessary to explore and implement these complex architectures for challenging tasks.

### 4.1 Define Custom Dataset in PyTorch
In PyTorch, datasets are represented as classes that inherit from `torch.utils.data.Dataset`. Next, we define a custom `ChestXrayDataset` class to handle image loading and preprocessing.


This dataset class does three main things:

-  Loads the X-ray images from a directory.

-  Applies any preprocessing or transformations (resizing, normalization, augmentation) specified by transform.

-  Returns the image and its corresponding label as a PyTorch tensor.

In [None]:
class ChestXrayDataset(Dataset):
    def __init__(self, df, img_dir, transform=None):
        self.df = df
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img_path = os.path.join(self.img_dir, row['Image Index'])
        image = Image.open(img_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        label = torch.tensor(row['Binary Label'], dtype=torch.float32)
        return image, label

### 4.2 Image Transforms

**Transforms** are changes we make to images as we load them. This can include resizing, flipping, rotating, or normalizing the images. When we do these changes randomly during training, it is called **data augmentation**. Data augmentation helps the model learn to recognize patterns in different situations, making it more robust and less likely to memorize the training data (a problem called overfitting).

In [None]:
mean = [0.485, 0.456, 0.406]
std  = [0.229, 0.224, 0.225]

image_size_= 224

train_transforms = transforms.Compose([
    transforms.Resize((image_size_,image_size_)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])


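Now we define `val_transforms`. They apply only the deterministic steps (resize, tensor conversion, normalization) and no random augmentation, so validation results are comparable from one epoch to the next.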

In [None]:
val_transforms   = transforms.Compose([
    transforms.Resize((image_size_,image_size_)),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])
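
`transforms.Normalize` subtracts the per-channel mean and divides by the per-channel standard deviation, so a normalized tensor has to be inverted before it can be displayed as an image. A minimal NumPy sketch of both directions, assuming channel-first arrays with values in [0, 1] (the helper names and the random toy image are illustrative):

```python
import numpy as np

# ImageNet statistics, one value per RGB channel
mean_rgb = np.array([0.485, 0.456, 0.406], dtype=np.float32)
std_rgb = np.array([0.229, 0.224, 0.225], dtype=np.float32)

def normalize(img_chw):
    """(x - mean) / std, broadcast over height and width."""
    return (img_chw - mean_rgb[:, None, None]) / std_rgb[:, None, None]

def denormalize(img_chw):
    """Inverse of normalize: x * std + mean."""
    return img_chw * std_rgb[:, None, None] + mean_rgb[:, None, None]

rng = np.random.default_rng(0)
toy_img = rng.random((3, 4, 4), dtype=np.float32)   # toy 3x4x4 image in [0, 1)
restored = denormalize(normalize(toy_img))
print(np.allclose(toy_img, restored, atol=1e-5))
```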

### 4.3 Dataloaders

A **DataLoader** is a tool in PyTorch that helps us load data in small groups called **mini-batches**. Instead of giving the model one image at a time, we give it a batch of images. This makes training faster and helps the model learn more stable patterns. Dataloaders also make it easy to shuffle the data and use multiple CPU cores for loading.

In [None]:
img_dir  = f'{DATA_PATH}/images'

train_ds = ChestXrayDataset(train_df, img_dir, transform=train_transforms)
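# Use a modest number of worker processes; many more workers than CPU cores slows loading down
train_loader = DataLoader(train_ds, batch_size=64, shuffle=True,
                          num_workers=min(8, os.cpu_count() or 1), pin_memory=True)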

We pass the `Dataset` as an argument to `DataLoader`. This wraps an iterable over our dataset, and supports automatic batching, sampling, shuffling and multiprocess data loading. Here we define a batch size of 64, *i.e.* each element in the dataloader iterable will return a batch of 64 images and labels.

In [None]:
for X, y in train_loader:
    print(f"Shape of X [N, C, H, W]: {X.shape}")
    print(f"Shape of y: {y.shape} {y.dtype}")
    break

In [None]:
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD  = np.array([0.229, 0.224, 0.225], dtype=np.float32)

In [None]:
val_ds   = ChestXrayDataset(val_df,   img_dir, transform=val_transforms)
val_loader   = DataLoader(val_ds,   batch_size=32,
                          num_workers=min(8, os.cpu_count() or 1), pin_memory=True)

### 4.4 Pretrained Models & Transfer Learning

A **pretrained model** is a model that has already been trained on a large dataset (like ImageNet) and has learned useful features. The structure of the model is called its **architecture** (for example, DenseNet, ResNet, EfficientNet). Using a pretrained model and adapting it to our own data is called **transfer learning**. This is very helpful because it allows us to get good results even with smaller datasets and less training time.
In `torchvision.models` we can find many popular pretrained models and architectures.

In [None]:
import torchvision
print(torchvision.__version__)

In [None]:
torchvision.models.list_models()[::30]

### 4.5  Understanding Model Layers

When looking at a deep learning model, you will see several types of layers. Here is what to look for in each:

- **Convolutional layers:** These are the building blocks of most image models. They scan the input image with small filters (sliding windows) to detect patterns like edges, shapes, or textures. The first convolutional layer takes the raw image (with 1 channel for grayscale or 3 for RGB) and produces feature maps.
- **Normalization layers (BatchNorm):** These layers help the model train faster and more reliably by keeping the outputs of previous layers at a similar scale. Batch Normalization (BatchNorm) is the most common type. It makes training more stable and helps the model generalize better.
- **Pooling layers:** Pooling reduces the size of the feature maps, making the model faster and helping it focus on the most important features. The most common is Max Pooling, which keeps only the largest value in each region.
- **Activation functions:** After each convolution, the model uses an activation function (like ReLU) to introduce non-linearity. This helps the model learn complex patterns, not just straight lines.

When you print a specific model, also check the layers at these positions:

- **First layer:** This is usually a convolutional layer that takes the input image. Check its input dimension (number of channels, usually 1 for grayscale or 3 for RGB images).
- **Second layer:** Often another convolutional, normalization, activation, or pooling layer, building on the features from the first.
- **Second to last layer:** This is typically a feature layer just before the classifier. Its output dimension shows the number of features passed to the final classifier.
- **Last layer:** This is the classifier or output layer. Its output dimension should match the number of classes (1 for binary classification).

By examining these layers, you can understand how the model processes the input and what features are used for the final prediction.
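
How much a convolution or pooling layer shrinks a feature map follows from its kernel size, stride, and padding. A minimal sketch of that arithmetic; the kernel, stride, and padding values are assumed to match the first convolution and pooling layers of torchvision's DenseNet121, which you can confirm by printing `model.features`:

```python
def conv_output_size(n, kernel, stride=1, padding=0):
    """Spatial output size of a convolution or pooling layer along one dimension."""
    return (n + 2 * padding - kernel) // stride + 1

input_size = 224
# Assumed stem: 7x7 convolution with stride 2 and padding 3
after_conv = conv_output_size(input_size, kernel=7, stride=2, padding=3)
# Assumed stem: 3x3 max pooling with stride 2 and padding 1
after_pool = conv_output_size(after_conv, kernel=3, stride=2, padding=1)
print(input_size, "->", after_conv, "->", after_pool)
```

Under these assumptions the two layers reduce each side of the image by a factor of four before the first dense block.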

### 4.6 Load Pretrained Model

- We load DenseNet121 pretrained on ImageNet.

- Replace the classifier with a single output unit for binary classification.

- Printing first, second, second-to-last, and last layers gives insight into the model structure.

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

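# `pretrained=True` is deprecated in recent torchvision; request the ImageNet weights explicitly
model = models.densenet121(weights=models.DenseNet121_Weights.IMAGENET1K_V1)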
model.classifier = nn.Linear(model.classifier.in_features, 1)
model = model.to(device)

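# DenseNet has only two top-level children (`features` and `classifier`),
# so we look inside `features` to print the first and last two layer blocks
layers = list(model.features.children()) + [model.classifier]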
print('First layer block:')
print(layers[0])
print('\n---')
print('Second layer block:')
print(layers[1])
print('\n...')
print('Second to last layer block:')
print(layers[-2])
print('\n---')
print('Last layer block:')
print(layers[-1])

## 5. Train the Model (10 minutes)


### 5.1 Hyperparameters

**Hyperparameters** are settings that you choose before training your model. They control how the learning process works. Common hyperparameters include:
- **Number of epochs:** How many times the model sees the whole training set.
- **Batch size:** How many samples are in each batch.
- **Learning rate:** How big the steps are when updating the model's weights.

Tuning hyperparameters is important because it can make a big difference in how well your model learns.

In [None]:
learning_rate = 1e-3
batch_size = 64
epochs = 10

### 5.2 Optimization loop

Training a deep learning model involves an **optimization loop**. Each time the model sees the whole training set, it completes one **epoch**. The process has two main parts:
- **Train loop:** The model learns from the training data and updates its parameters.
- **Validation loop:** The model is tested on validation data to see how well it is learning.

A **loss function** measures how far the model's predictions are from the true answers. The goal of training is to minimize this loss. The optimization loop repeats for many epochs until the model performs well.

Inside the training loop, the model learns by adjusting its parameters using **gradients**. Gradients show how much each parameter should change to reduce the loss. They are computed by **backpropagation**, and the optimizer then uses them to update the parameters.

- **optimizer.zero_grad():** Resets the gradients to zero before each batch.
- **loss.backward():** Calculates the gradients using backpropagation.
- **optimizer.step():** Updates the model's parameters using the gradients.
- **Learning rate scheduler (like OneCycleLR):** Adjusts the learning rate during training to help the model learn better and faster.
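
The four calls listed above can be seen together in a minimal sketch on toy data. Everything prefixed with `toy_` is an assumption made for illustration (a tiny linear model and random inputs), not the notebook's DenseNet or chest X-ray data.

```python
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import OneCycleLR

torch.manual_seed(0)

# Toy stand-ins: 8 random "samples" with 4 features each and random binary labels
toy_x = torch.randn(8, 4)
toy_y = torch.randint(0, 2, (8,)).float()

toy_model = nn.Linear(4, 1)
toy_criterion = nn.BCEWithLogitsLoss()
toy_optimizer = torch.optim.Adam(toy_model.parameters(), lr=1e-3)
toy_total_steps = 20
toy_scheduler = OneCycleLR(toy_optimizer, max_lr=1e-2, total_steps=toy_total_steps)

initial_weight = toy_model.weight.detach().clone()
lrs, losses = [], []
for step in range(toy_total_steps):
    toy_optimizer.zero_grad()                 # reset gradients from the previous step
    logits = toy_model(toy_x).squeeze(1)      # forward pass
    loss = toy_criterion(logits, toy_y)       # how far predictions are from the labels
    loss.backward()                           # backpropagation: compute gradients
    toy_optimizer.step()                      # update the parameters
    toy_scheduler.step()                      # adjust the learning rate once per batch
    lrs.append(toy_scheduler.get_last_lr()[0])
    losses.append(loss.item())

print(f"first loss {losses[0]:.4f}, last loss {losses[-1]:.4f}")
print(f"peak learning rate {max(lrs):.4f}")
```

Note that `OneCycleLR` is stepped after every batch rather than after every epoch, which is why it needs the total number of steps up front.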

In [None]:
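# Re-initialize the model after cleanup
model = models.densenet121(weights=models.DenseNet121_Weights.DEFAULT)  # replaces the deprecated pretrained=True
model.classifier = nn.Linear(model.classifier.in_features, 1)
model = model.to(device)


optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# train_loader must be defined before this cell runs (see Section 5.4).
# When total_steps is given, it takes precedence over epochs and steps_per_epoch.
scheduler = OneCycleLR(optimizer, max_lr=learning_rate, epochs=epochs, steps_per_epoch=len(train_loader), total_steps=epochs * len(train_loader))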

In [None]:
pos_frac = train_df['Binary Label'].mean()
pos_weight = torch.tensor([(1 - pos_frac) / pos_frac]).to(device)

criterion  = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
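
To see what `pos_weight` does, the sketch below evaluates the weighted binary cross-entropy for a single sample by hand. It assumes a positive fraction of one third purely for illustration, and `weighted_bce_with_logits` is a hypothetical helper that mirrors the per-sample formula used by `BCEWithLogitsLoss`.

```python
import math

def weighted_bce_with_logits(logit: float, label: float, pos_weight: float) -> float:
    """Per-sample loss: -[pos_weight * y * log(sigmoid(z)) + (1 - y) * log(1 - sigmoid(z))]."""
    p = 1.0 / (1.0 + math.exp(-logit))
    return -(pos_weight * label * math.log(p) + (1.0 - label) * math.log(1.0 - p))

# Hypothetical positive fraction: two negatives for every positive
toy_pos_frac = 1 / 3
toy_pos_weight = (1 - toy_pos_frac) / toy_pos_frac   # approximately 2

toy_logit = 0.0  # an undecided prediction: sigmoid(0) = 0.5
loss_pos = weighted_bce_with_logits(toy_logit, 1.0, toy_pos_weight)
loss_neg = weighted_bce_with_logits(toy_logit, 0.0, toy_pos_weight)
print(f"loss on a positive: {loss_pos:.3f}, loss on a negative: {loss_neg:.3f}")
```

With this weighting, an equally uncertain prediction costs about twice as much on a positive sample as on a negative one, which compensates for positives being seen half as often.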

### 5.3 Define Training & Validation Functions

Next, we define our **training function** and **validation function**. The training function teaches the model using the training data, while the validation function checks how well the model is doing on data it hasn't seen before. Keeping these functions separate helps us monitor the model's progress and avoid overfitting (when the model memorizes the training data but doesn't generalize well to new data).

In [None]:
def train_loop(model, loader, criterion, optimizer, scheduler, device):
    model.train()
    running_loss = 0.0
    for imgs, labels in tqdm(loader, desc="  Training", leave=False):
        imgs, labels = imgs.to(device), labels.to(device).float()  # BCEWithLogitsLoss expects float targets
        optimizer.zero_grad()
        logits = model(imgs).squeeze(1)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
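        try:
            scheduler.step()
        except ValueError:
            # OneCycleLR raises ValueError once it is stepped beyond total_steps;
            # ignore it so training continues with the last learning rate.
            pass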

        running_loss += loss.item() * imgs.size(0)

    avg_loss = running_loss / len(loader.dataset)
    return avg_loss


def val_loop(model, loader, criterion, auroc, device):
    model.eval()
    auroc.reset()
    running_preds = []
    running_labels = []

    with torch.no_grad():
        for imgs, labels in tqdm(loader, desc="  Validation", leave=False):
            imgs = imgs.to(device)
            logits = model(imgs).squeeze(1)
            probs = torch.sigmoid(logits)

            preds = (probs > 0.5).int().cpu().numpy()
            running_preds.extend(preds.tolist())
            running_labels.extend(labels.int().tolist())

            auroc.update(probs, labels.int().to(device))

    acc = accuracy_score(running_labels, running_preds)
    val_auroc = auroc.compute().item()
    return acc, val_auroc


### 5.4 Load Dataset

Define the train and validation datasets and dataloaders.

In [None]:

img_dir  = f'{DATA_PATH}/images'

train_ds = ChestXrayDataset(train_df, img_dir, transform=train_transforms)
train_loader = DataLoader(train_ds, batch_size=64, shuffle=True, num_workers=90, pin_memory=True)

val_ds   = ChestXrayDataset(val_df,   img_dir, transform=val_transforms)
val_loader   = DataLoader(val_ds,   batch_size=32, num_workers=32, pin_memory=True)

auroc = BinaryAUROC().to(device)

### 5.5 Train the DenseNet121 Model

We now start the actual training loop for DenseNet121.

In [None]:
for epoch in range(1, epochs + 1):
    print(f"Epoch {epoch}/{epochs}")

    train_loss = train_loop(model, train_loader, criterion, optimizer, scheduler, device)
    val_acc, val_auroc = val_loop(model, val_loader, criterion, auroc, device)

    print(f"  Train Loss: {train_loss:.4f}  |  Val Acc: {val_acc:.4f}  |  Val AUROC: {val_auroc:.4f}")
# ---- Save only the final model ----
torch.save(model.state_dict(), "DenseNet121_final.pth")

## 6. Benchmarking model architectures (20 minutes)

A **model architecture** is the specific design or structure of a neural network. Different architectures (CNNs such as ResNet, DenseNet, and EfficientNet, or vision transformers such as Swin Transformer) use different building blocks:
- **Skip connections:** Allow information to skip layers, helping very deep networks learn better (used in ResNet).
- **Dense connections:** Connect each layer to all subsequent layers within a block, improving information flow (used in DenseNet).
- **Normalization layers:** Help stabilize and speed up training by keeping the data flowing through the network at a similar scale.
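
The three building blocks above can be sketched in a few lines of PyTorch. The tensor sizes and layer names here are assumptions chosen for illustration and do not correspond to actual ResNet or DenseNet layers.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

toy_input = torch.randn(1, 16, 32, 32)    # (batch, channels, height, width)
conv_same = nn.Conv2d(16, 16, kernel_size=3, padding=1)
conv_grow = nn.Conv2d(16, 8, kernel_size=3, padding=1)

# Skip (residual) connection: the input is added to the layer output, so the shape is unchanged
residual_out = toy_input + torch.relu(conv_same(toy_input))

# Dense connection: the layer output is concatenated to the input along the channel axis
dense_out = torch.cat([toy_input, torch.relu(conv_grow(toy_input))], dim=1)

# Normalization layer: rescales each channel to roughly zero mean and unit variance
normed = nn.BatchNorm2d(16)(toy_input)

print(residual_out.shape, dense_out.shape)
```

The residual output keeps 16 channels, while the dense output grows to 24 channels because earlier features are carried forward alongside the new ones.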

Trying different architectures is important because some may work better for your specific problem. In this section, you will train and compare several architectures to see which performs best on your data.

<div class="alert alert-block alert-info">
<b>Q1.</b> In deep learning, different <b>model architectures</b> can have a big impact on performance. Complete the following cells to train and compare these models: <br>
    - EfficientNet <br>
    - Swin Transformer <br>
</div>

Comparing different models helps you understand which design works best for your specific task and data.

### EfficientNet

In [None]:
from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model_enb0 = ... # COMPLETE
model_enb0.classifier[1] = nn.Linear(1280, 1)

model = ... # COMPLETE to device
print(model)

learning_rate = ... # COMPLETE
batch_size = ... # COMPLETE
epochs = ... # COMPLETE

pos_frac = train_df['Binary Label'].mean()
pos_weight = torch.tensor([(1 - pos_frac) / pos_frac]).to(device)
criterion  = ... # COMPLETE

optimizer = ... # COMPLETE
scheduler = OneCycleLR(optimizer, max_lr=learning_rate, epochs=epochs, steps_per_epoch=len(train_loader), total_steps=epochs * len(train_loader)) # Explicitly set total_steps

img_dir  = f'{DATA_PATH}/images'

train_ds = ChestXrayDataset(train_df, img_dir, transform=train_transforms)
train_loader = DataLoader(train_ds, batch_size=64, shuffle=True, num_workers=90, pin_memory=True)

val_ds   = ChestXrayDataset(val_df,   img_dir, transform=val_transforms)
val_loader   = DataLoader(val_ds,   batch_size=32, num_workers=32, pin_memory=True)

auroc = BinaryAUROC().to(device)

for epoch in range(1, epochs + 1):
    print(f"Epoch {epoch}/{epochs}")

    train_loss = train_loop(model, train_loader, criterion, optimizer, scheduler, device)
    val_acc, val_auroc = val_loop(model, val_loader, criterion, auroc, device)

    print(f"  Train Loss: {train_loss:.4f}  |  Val Acc: {val_acc:.4f}  |  Val AUROC: {val_auroc:.4f}")

Save the trained model

In [None]:
torch.save(model.state_dict(), './efficientnet-classification.pth')
print("Model saved successfully!")

### SwinTransformer

In [None]:
from torchvision.models import swin_t, Swin_T_Weights
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model_swin = ... # COMPLETE
model_swin.head = nn.Linear(in_features=768, out_features=1, bias=True)
model = ... # COMPLETE to device
print(model)

learning_rate = ... # COMPLETE
batch_size = ... # COMPLETE
epochs = ... # COMPLETE

pos_frac = train_df['Binary Label'].mean()
pos_weight = torch.tensor([(1 - pos_frac) / pos_frac]).to(device)
criterion  = ... # COMPLETE

optimizer = ... # COMPLETE
scheduler = OneCycleLR(optimizer, max_lr=learning_rate, epochs=epochs, steps_per_epoch=len(train_loader), total_steps=epochs * len(train_loader)) # Explicitly set total_steps

img_dir  = f'{DATA_PATH}/CXR8/images'

train_ds = ChestXrayDataset(train_df, img_dir, transform=train_transforms)
train_loader = DataLoader(train_ds, batch_size=64, shuffle=True, num_workers=90, pin_memory=True)

val_ds   = ChestXrayDataset(val_df,   img_dir, transform=val_transforms)
val_loader   = DataLoader(val_ds,   batch_size=32, num_workers=32, pin_memory=True)

auroc = BinaryAUROC().to(device)

for epoch in range(1, epochs + 1):
    print(f"Epoch {epoch}/{epochs}")

    train_loss = train_loop(model, train_loader, criterion, optimizer, scheduler, device)
    val_acc, val_auroc = val_loop(model, val_loader, criterion, auroc, device)

    print(f"  Train Loss: {train_loss:.4f}  |  Val Acc: {val_acc:.4f}  |  Val AUROC: {val_auroc:.4f}")

The Swin Transformer initially learned useful features (Val Acc ~0.83, AUROC ~0.86) but quickly collapsed. From epoch 3 onward, AUROC dropped to ~0.5 while accuracy reflected the majority class, indicating the model was no longer distinguishing classes. This behavior is most likely caused by a learning rate that is too high for fine-tuning, which destabilizes the pretrained weights, combined with the class imbalance in the dataset. AUROC is the more informative metric here: a value near 0.5 means the model ranks positive cases no better than chance, even though accuracy still looks acceptable.
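
The gap between accuracy and AUROC for a collapsed model can be reproduced with a toy example. The labels and scores below are invented for illustration, and `toy_accuracy` and `pairwise_auroc` are hypothetical helpers written from the definitions of the two metrics.

```python
def toy_accuracy(labels, scores, threshold=0.5):
    """Fraction of samples whose thresholded score matches the label."""
    preds = [1 if s > threshold else 0 for s in scores]
    return sum(p == y for p, y in zip(preds, labels)) / len(labels)

def pairwise_auroc(labels, scores):
    """Probability that a random positive is scored above a random negative (ties count 0.5)."""
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))

# Hypothetical validation set: 8 negatives and 4 positives
toy_val_labels = [0] * 8 + [1] * 4
collapsed_scores = [0.3] * 12   # the model outputs the same score for every image

print(f"accuracy: {toy_accuracy(toy_val_labels, collapsed_scores):.3f}")
print(f"AUROC:    {pairwise_auroc(toy_val_labels, collapsed_scores):.3f}")
```

Accuracy equals the share of the majority class (8 of 12), whereas AUROC is exactly 0.5, which exposes that the scores carry no ranking information.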

In [None]:
print(train_df['Binary Label'].value_counts(normalize=True))


This shows a class imbalance in the dataset, with the majority class (0) being about twice the size of the minority class (1).

We fix this by:

- Using a WeightedRandomSampler to create balanced batches during training.

- Applying a class-weighted loss (BCEWithLogitsLoss with pos_weight) to give more importance to the minority class.

- Adjusting learning rate and batch size for more stable training.
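
As a rough sketch of the first point, the example below builds a `WeightedRandomSampler` for a toy imbalanced dataset. The labels, features, and all `toy_` names are assumptions for illustration; in the notebook the per-sample weights would be derived from the training labels instead.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

torch.manual_seed(0)

# Toy imbalanced labels: 200 negatives and 100 positives
toy_labels = torch.cat([torch.zeros(200), torch.ones(100)]).long()
toy_features = torch.randn(len(toy_labels), 4)
toy_ds = TensorDataset(toy_features, toy_labels)

# Weight each sample by the inverse frequency of its class
class_counts = torch.bincount(toy_labels)                    # counts per class: [200, 100]
sample_weights = (1.0 / class_counts.float())[toy_labels]

sampler = WeightedRandomSampler(sample_weights, num_samples=len(toy_labels), replacement=True)

# A sampler takes the place of shuffle=True; DataLoader does not accept both together
toy_loader = DataLoader(toy_ds, batch_size=64, sampler=sampler)

drawn = torch.cat([y for _, y in toy_loader])
print(f"positive fraction after resampling: {drawn.float().mean():.2f}")
```

Because sampling is done with replacement, minority-class samples are repeated within an epoch, so the positive fraction in the drawn batches moves from one third toward one half.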

In [None]:
from torchvision.models import swin_t, Swin_T_Weights
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model_swin = ... # COMPLETE
model_swin.head = nn.Linear(in_features=768, out_features=1, bias=True)
model = ... # COMPLETE to device
print(model)

learning_rate = ... # COMPLETE
batch_size = ... # COMPLETE
epochs = ... # COMPLETE

pos_frac = train_df['Binary Label'].mean()
pos_weight = torch.tensor([(1 - pos_frac) / pos_frac]).to(device)
criterion  = ... # COMPLETE

optimizer = ... # COMPLETE
scheduler = OneCycleLR(optimizer, max_lr=learning_rate, epochs=epochs, steps_per_epoch=len(train_loader), total_steps=epochs * len(train_loader)) # Explicitly set total_steps

img_dir  = f'{DATA_PATH}/CXR8/images'

train_ds = ChestXrayDataset(train_df, img_dir, transform=train_transforms)
train_loader = DataLoader(train_ds, batch_size=64, shuffle=True, num_workers=90, pin_memory=True)

val_ds   = ChestXrayDataset(val_df,   img_dir, transform=val_transforms)
val_loader   = DataLoader(val_ds,   batch_size=32, num_workers=32, pin_memory=True)

auroc = BinaryAUROC().to(device)

for epoch in range(1, epochs + 1):
    print(f"Epoch {epoch}/{epochs}")

    train_loss = train_loop(model, train_loader, criterion, optimizer, scheduler, device)
    val_acc, val_auroc = val_loop(model, val_loader, criterion, auroc, device)

    print(f"  Train Loss: {train_loss:.4f}  |  Val Acc: {val_acc:.4f}  |  Val AUROC: {val_auroc:.4f}")

After these changes, the model achieves a high Val AUROC (~0.91), showing it can discriminate between classes effectively despite the imbalance.

Save the trained model

In [None]:
torch.save(model.state_dict(), './swintransformer-classification.pth')
print("Model saved successfully!")

### Model Benchmarking Summary

We trained three different architectures on our chest X-ray dataset: **DenseNet121**, **EfficientNet-B0**, and **Swin Transformer**. Key results:

| Model            | Train Loss (final) | Val Acc (final) | Val AUROC (final) |
|-----------------|-----------------|----------------|------------------|
| DenseNet121      | 0.260           | 0.875          | 0.911            |
| EfficientNet-B0  | 0.198           | 0.853          | 0.904            |
| Swin Transformer | 0.435           | 0.834          | 0.915            |

**Observations:**
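- DenseNet121 achieves the highest validation accuracy (0.875) with an AUROC (0.911) close to the best.
- EfficientNet-B0 reaches the lowest training loss (0.198) but also the lowest validation AUROC (0.904), so its lower loss does not translate into better generalization.
- Swin Transformer obtains the highest AUROC (0.915) but the lowest accuracy (0.834) and the highest training loss (0.435); it captures global features but requires careful tuning.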



After completing the experiments, summarize the results:

- Which architecture achieved the highest validation performance (e.g., best accuracy or AUROC) on the classification task?

- Why do you think that architecture worked better than the others for this problem?

- Did the more complex architecture (Swin Transformer) provide a clear benefit, or did a simpler CNN perform just as well or even better?


## 7. Final Evaluation

We have trained our models (DenseNet121, EfficientNet-B0, Swin Transformer) and used the validation set to guide hyperparameter tuning. Now, we evaluate them on the held-out test set, which the models have never seen before. This gives the most honest estimate of performance on new, unseen patients — the final verdict on model capability.


<div class="alert alert-block alert-info">
<b>Q2.</b> Compare the performance of the three models using appropriate metrics: <br>
    - Plot the ROC curve <br>
    - Accuracy score <br>
    - Precision and Recall <br>
    - F1-score <br>
</div>


Load the saved models

In [None]:
model_densenet = ... # COMPLETE
model_densenet.classifier = ... # COMPLETE

model_enb0 = ... # COMPLETE
model_enb0.classifier[1] = ... # COMPLETE

model_swin = ... # COMPLETE
model_swin.head = ... # COMPLETE

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
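model_densenet.load_state_dict(torch.load(..., map_location=device))  # COMPLETE: path to the saved weights
model_enb0.load_state_dict(torch.load(..., map_location=device))  # COMPLETE: path to the saved weights
model_swin.load_state_dict(torch.load(..., map_location=device))  # COMPLETE: path to the saved weights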




Evaluate your models on the test set using the following metrics:

- Accuracy: Fraction of correctly classified cases. Gives a quick overview but can be misleading with imbalanced classes.

- AUROC (Area Under the ROC Curve): Measures the model’s ability to discriminate positive vs. negative classes; robust to class imbalance.

- Precision: Of all predicted positives, how many are truly positive? High precision avoids unnecessary alarm.

- Recall (Sensitivity): Of all true positives, how many were detected? High recall avoids missing urgent cases.

- F1-score: Harmonic mean of precision and recall, balancing false positives and negatives.

>Tip: For imbalanced medical datasets, AUROC and F1-score are usually more reliable than plain accuracy.
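
The definitions above can be checked by hand on a small example. The labels and predictions below are invented for illustration, and `binary_metrics` is a hypothetical helper written directly from the definitions; it is a sketch, not a replacement for a tested metrics library.

```python
def binary_metrics(labels, preds):
    """Compute accuracy, precision, recall and F1 from binary labels and predictions."""
    tp = sum(1 for y, p in zip(labels, preds) if y == 1 and p == 1)
    fp = sum(1 for y, p in zip(labels, preds) if y == 0 and p == 1)
    fn = sum(1 for y, p in zip(labels, preds) if y == 1 and p == 0)
    tn = sum(1 for y, p in zip(labels, preds) if y == 0 and p == 0)

    accuracy = (tp + tn) / len(labels)
    precision = tp / (tp + fp) if (tp + fp) else 0.0
    recall = tp / (tp + fn) if (tp + fn) else 0.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0
    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}

# Hypothetical test set: 3 positives and 7 negatives
toy_test_labels = [1, 1, 1, 0, 0, 0, 0, 0, 0, 0]
toy_test_preds  = [1, 0, 0, 0, 0, 0, 0, 0, 0, 1]

toy_metrics = binary_metrics(toy_test_labels, toy_test_preds)
for metric_name, value in toy_metrics.items():
    print(f"{metric_name}: {value:.3f}")
```

In this toy case accuracy is 0.7 even though only one of the three positive cases is detected (recall of about 0.33), which is the kind of gap the tip above warns about.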

In [None]:
test_ds     = ChestXrayDataset(test_df, img_dir, transform=val_transforms)
test_loader = DataLoader(
    test_ds,
    batch_size=32,
    num_workers=32,
    pin_memory=True,
    shuffle=False
)

model_list  = [model_densenet, model_enb0, model_swin]
model_names = ['DenseNet 121', 'EfficientNet B0', 'Swin Transformer']

all_preds_proba = {}
all_labels      = None

for model, name in zip(model_list, model_names):
    model.eval()
    model.to(device)

    probs_list, labels_list = [], []

    with torch.no_grad():
        for X, y in test_loader:
            X = X.to(device)
            out = model(X)
            probs = torch.sigmoid(out)
            probs = probs.squeeze(1)
            probs_list.extend(probs.cpu().numpy())
            labels_list.extend(y.numpy())

    preds_proba = np.array(probs_list)
    labels      = np.array(labels_list)

    all_preds_proba[name] = preds_proba
    if all_labels is None:
        all_labels = labels

    preds_binary = (preds_proba > 0.5).astype(int)

    acc     = ... # COMPLETE
    prec    = ... # COMPLETE
    rec     = ... # COMPLETE
    f1      = ... # COMPLETE
    roc_auc = ... # COMPLETE

... # COMPLETE WITH ROC Curve display, and other metrics


Conclude: Which model gives the best performance and should be selected?

### Part II: Quality Control

Once a deep learning model is trained, it is crucial not only to evaluate its performance using metrics like accuracy, precision, or ROC-AUC, but also to understand why the model makes certain predictions. This is especially important in medical imaging, where trust and interpretability are critical.

To achieve this, we use visual attribution methods that highlight the regions of the input image that most influenced the model's decision. In this project, we consider the following explainable AI techniques; the CAM-based methods are applied with the `pytorch_grad_cam` library:

- Grad-CAM (Gradient-weighted Class Activation Mapping):
Generates a heatmap that highlights the important regions of the image by computing the gradients of the target class with respect to the feature maps of the last convolutional layer.

- Score-CAM:
Unlike Grad-CAM, it does not rely on gradients. Instead, it evaluates the importance of each activation map by measuring how much it increases the score for the target class, making it more stable and sometimes more visually interpretable.

- Grad-CAM++:
An improved version of Grad-CAM that can better handle multiple occurrences of the target class in the same image and generally produces sharper localization maps.

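- LIME (Local Interpretable Model-agnostic Explanations):
A model-agnostic method that perturbs the input image and observes the change in prediction, creating a weighted map of important superpixels. LIME is not part of `pytorch_grad_cam` and is not applied in this notebook; it is listed for comparison.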

Importance of this step:

- Provides interpretability, helping clinicians understand which regions of the image are influencing predictions.

- Detects potential biases in the model.

- Builds trust in AI-assisted diagnosis by making the model’s decision process transparent.

- Supports error analysis, helping improve model design and dataset quality.

By applying Grad-CAM, Score-CAM, and Grad-CAM++ on our test set, we generate visual explanations for the model predictions and identify the most relevant regions in the X-ray images that contributed to each prediction. This step bridges the gap between model performance metrics and actionable insights for real-world medical applications.
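
The Grad-CAM description in the list above boils down to three array operations. The sketch below uses random NumPy arrays as stand-ins for the feature maps and their gradients (an assumption for illustration); in practice a library such as `pytorch_grad_cam` collects these from the network with hooks.

```python
import numpy as np

rng = np.random.default_rng(0)

# Random stand-ins: 4 feature maps of size 7x7 from the last convolutional layer,
# and the gradients of the class score with respect to those feature maps
toy_activations = rng.random((4, 7, 7))
toy_gradients = rng.standard_normal((4, 7, 7))

# 1) Channel weights: average the gradients over the spatial dimensions
channel_weights = toy_gradients.mean(axis=(1, 2))                       # shape (4,)

# 2) Weighted sum of the feature maps, then ReLU to keep only positive evidence
toy_cam = np.maximum((channel_weights[:, None, None] * toy_activations).sum(axis=0), 0.0)

# 3) Rescale to [0, 1] so the map can be overlaid on the image
toy_cam = (toy_cam - toy_cam.min()) / (toy_cam.max() - toy_cam.min() + 1e-8)

print(toy_cam.shape)
```

The resulting map has the spatial size of the feature maps (7x7 here), so it has to be upsampled to the image resolution before it is drawn as an overlay.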

In [None]:
#@title Utility functions
import numpy as np
import matplotlib.pyplot as plt
import cv2
from pytorch_grad_cam.utils.image import show_cam_on_image

def visualize_gradcam(
    original_images: torch.Tensor,  # (N, C, H, W)
    heatmaps:         np.ndarray,    # (N, H_cam, W_cam)
    true_labels:      np.ndarray,    # (N,)
    num_to_show:      int = 5,
    mean:             float|list|tuple = 0.5,
    std:              float|list|tuple = 0.2
):
    """
    - original_images: Tensor(N, C, H, W), normalized via (x-mean)/std
    - heatmaps:        ndarray(N, H_cam, W_cam) in [0,1]
    - true_labels:     ndarray(N,)
    """
    N = min(num_to_show, original_images.shape[0])

    # prepare mean/std arrays for un-normalization
    if isinstance(mean, (list, tuple, np.ndarray)):
        mean_arr = np.array(mean)[:, None, None]
        std_arr  = np.array(std)[:,  None, None]
    else:
        mean_arr = mean
        std_arr  = std

    for i in range(N):
        # 1) pull & un-normalize the i-th image
        img = original_images[i].cpu().numpy()          # (C, H, W)
        img = img * std_arr + mean_arr                  # broadcast over C,H,W
        img = np.clip(img, 0, 1)

        # 2) convert to H×W×C for plotting
        img_hwc = np.transpose(img, (1, 2, 0))         # (H, W, C)
        H, W, _ = img_hwc.shape

        # 3) resize heatmap to match image size
        hm = heatmaps[i]                                # (H_cam, W_cam)
        hm_resized = cv2.resize(hm, (W, H), interpolation=cv2.INTER_LINEAR)

        # 4) overlay CAM
        cam_overlay = show_cam_on_image(img_hwc, hm_resized, use_rgb=True)

        # 5) plot side by side
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
        ax1.imshow(img_hwc)
        ax1.set_title(f"Original (Label: {true_labels[i]})")
        ax1.axis('off')

        ax2.imshow(cam_overlay)
        ax2.set_title("GradCAM Overlay")
        ax2.axis('off')

        plt.tight_layout()
        plt.show()

#### Grad-CAM


In [None]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from torch.utils.data import DataLoader

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

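# ----- 1. List of models -----
# Swin Transformer is left out: its only Conv2d is the patch embedding at the input,
# so get_last_conv() would not return a meaningful Grad-CAM target layer for it.
model_list  = [model_densenet, model_enb0]
model_names = ['DenseNet 121', 'EfficientNet B0']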

# ----- 2. Function to get last conv layer -----
def get_last_conv(model):
    last_name, last_conv = None, None
    for name, module in model.named_modules():
        if isinstance(module, nn.Conv2d):
            last_name, last_conv = name, module
    return last_name, last_conv

# ----- 3. Create test loader (already prepared) -----
test_ds = ChestXrayDataset(test_df, img_dir, transform=val_transforms)
test_loader = DataLoader(
    test_ds,
    batch_size=16,  # smaller batch for Grad-CAM
    shuffle=False,
    num_workers=4,
    pin_memory=True
)

# ----- 4. Loop through models -----
for model, name in zip(model_list, model_names):
    model.to(device).eval()

    # Get last conv layer
    layer_name, target_layer = get_last_conv(model)
    print(f"Grad-CAM for {name}")
    print(f"Target layer: {layer_name}")

    # Initialize Grad-CAM
    cam = GradCAM(model=model, target_layers=[target_layer])


    heatmaps, orig_images, true_labels = [], [], []

    # ----- 5. Loop through test set -----
    for X, y in test_loader:
        X = X.to(device)
        y = y.numpy()

        # Forward pass and predictions
        with torch.no_grad():
            out = model(X)
            if out.shape[1] == 2:
                probs = torch.softmax(out, dim=1)[:, 1]
            else:
                probs = torch.sigmoid(out).squeeze(1)

        preds_binary = (probs > 0.5).cpu().numpy()

        # Keep only correctly predicted samples
        correct_mask = (preds_binary == y)
        if correct_mask.sum() == 0:
            continue  # skip if no correct predictions in batch

        X_correct = X[correct_mask]
        y_correct = y[correct_mask]

        # Compute Grad-CAM
        hm_batch = cam(input_tensor=X_correct)
        heatmaps.extend(hm_batch)
        orig_images.extend(X_correct.cpu())
        true_labels.extend(y_correct)

    # ----- 6. Convert to arrays -----
    orig_images = torch.stack(orig_images).detach().cpu().numpy()
    heatmaps    = np.stack(heatmaps, axis=0)
    true_labels = np.array(true_labels, dtype=int)

    # ----- 7. Visualize -----
    num_to_show = min(5, len(orig_images))
    mean = np.array([0.485, 0.456, 0.406])
    std  = np.array([0.229, 0.224, 0.225])

    for i in range(num_to_show):
        orig = orig_images[i].transpose(1, 2, 0)  # CHW -> HWC
        orig = std * orig + mean
        orig = np.clip(orig, 0, 1)

        cam_image = show_cam_on_image(orig, heatmaps[i], use_rgb=True)

        plt.figure(figsize=(8,4))
        plt.subplot(1,2,1)
        plt.imshow(orig)
        plt.title(f"Original - Label: {true_labels[i]}")
        plt.axis('off')

        plt.subplot(1,2,2)
        plt.imshow(cam_image)
        plt.title("Grad-CAM")
        plt.axis('off')
        plt.show()


<div class="alert alert-block alert-info">
<b>Q3.</b> In the previous cell, we created the <code>GradCAM heatmaps</code>. Observe the code and the plots, and interpret how well the model is performing.
</div>

#### Score-CAM

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from pytorch_grad_cam import ScoreCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
import torch.nn as nn


# --- 1. Prepare test dataset and loader ---
test_ds = ChestXrayDataset(test_df, img_dir, transform=val_transforms)
test_loader = DataLoader(
    test_ds,
    batch_size=32,
    num_workers=4,
    pin_memory=True,
    shuffle=False
)

# --- 1. Set device ---
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# --- 2. Prepare model list ---
# Swin Transformer is left out: its only Conv2d is the patch embedding, not a useful CAM target
model_list  = [model_densenet, model_enb0]
model_names = ['DenseNet 121', 'EfficientNet B0']

# --- 3. Function to get last conv layer ---
def get_last_conv(model):
    last_name, last_conv = None, None
    for name, module in model.named_modules():
        if isinstance(module, nn.Conv2d):
            last_name, last_conv = name, module
    return last_name, last_conv

# --- 4. Loop over models ---
for model, name in zip(model_list, model_names):
    model.to(device).eval()

    # Get target layer
    _, target_layer = get_last_conv(model)
    print(f"Using Score-CAM on {name} layer: {target_layer}")

    cam = ScoreCAM(model=model, target_layers=[target_layer])

    # Process a few images to save memory
    inputs, labels = next(iter(test_loader))

    for i in range(min(5, inputs.size(0))):  # one image at a time
        input_img = inputs[i:i+1].to(device)  # batch size 1

        with torch.no_grad():
            heatmap = cam(input_tensor=input_img)[0]  # get single heatmap

        # Original image for display
        orig_img = input_img.cpu()[0].permute(1,2,0).numpy()
        orig_img = (orig_img - orig_img.min()) / (orig_img.max() - orig_img.min())  # normalize 0-1

        cam_img = show_cam_on_image(orig_img, heatmap, use_rgb=True)

        plt.figure(figsize=(8,4))
        plt.subplot(1,2,1)
        plt.imshow(orig_img)
        plt.title('Original')
        plt.axis('off')

        plt.subplot(1,2,2)
        plt.imshow(cam_img)
        plt.title(f'{name} Score-CAM')
        plt.axis('off')
        plt.show()

        # Free memory
        del input_img, heatmap, cam_img
        torch.cuda.empty_cache()


<div class="alert alert-block alert-info">
<b>Q4.</b> In the previous cell, we created the <code>ScoreCAM heatmaps</code>. What differences do you observe compared with the <code>GradCAM</code> heatmaps?
</div>

---



#### Grad-CAM++

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from pytorch_grad_cam import GradCAMPlusPlus
from pytorch_grad_cam.utils.image import show_cam_on_image
import torch.nn as nn

# --- 1. Set device ---
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

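# --- 2. Move model to device and set eval mode ---
# `model` is whatever the previous loop left behind (the last model it processed);
# assign it explicitly, e.g. model = model_densenet, to choose which network to explain.
model.to(device)
model.eval()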

# --- 3. Function to get last convolutional layer ---
def get_last_conv(model):
    last_name, last_conv = None, None
    for name, module in model.named_modules():
        if isinstance(module, nn.Conv2d):
            last_name, last_conv = name, module
    return last_name, last_conv

# --- 4. Get target layer ---
_, target_layer = get_last_conv(model)
print(f"Using Grad-CAM++ on layer: {target_layer}")

# --- 5. Initialize GradCAM++ (no use_cuda argument) ---
cam = GradCAMPlusPlus(model=model, target_layers=[target_layer])

# --- 6. Parameters ---
N = 10  # number of images to visualize
mean = [0.5]  # modify according to your normalization
std  = [0.2]

heatmaps, orig_images, true_labels = [], [], []

# --- 7. Process images individually ---
processed = 0
for inputs, labels in test_loader:
    for i in range(inputs.size(0)):
        if processed >= N:
            break

        input_img = inputs[i:i+1].to(device)
        label = labels[i].item()

        # GradCAM++ requires gradients for input
        input_img.requires_grad_()

        # Compute GradCAM++
        hm = cam(input_tensor=input_img)

        # Store results
        heatmaps.append(hm[0])
        orig_images.append(input_img.detach().cpu()[0])
        true_labels.append(label)

        del input_img, hm
        torch.cuda.empty_cache()
        processed += 1
    if processed >= N:
        break

# --- 8. Convert lists to arrays/tensors ---
orig_images = torch.stack(orig_images, dim=0)  # (N, C, H, W)
heatmaps    = np.stack(heatmaps, axis=0)
true_labels = np.array(true_labels, dtype=int)

# --- 9. Visualization ---
def visualize_gradcam(original_images, heatmaps, true_labels, num_to_show=10, mean=[0.5], std=[0.2]):
    for i in range(num_to_show):
        # Detach & permute
        orig_img = original_images[i].permute(1,2,0).numpy()
        orig_img = (orig_img * std[0]) + mean[0]
        orig_img = np.clip(orig_img, 0, 1)

        heatmap = heatmaps[i]

        plt.figure(figsize=(8,4))
        plt.subplot(1,2,1)
        plt.imshow(orig_img)
        plt.title(f'Original - Label: {true_labels[i]}')
        plt.axis('off')

        plt.subplot(1,2,2)
        cam_img = show_cam_on_image(orig_img, heatmap, use_rgb=True)
        plt.imshow(cam_img)
        plt.title('Grad-CAM++')
        plt.axis('off')
        plt.show()

# --- 10. Visualize ---
visualize_gradcam(
    original_images=orig_images,
    heatmaps=heatmaps,
    true_labels=true_labels,
    num_to_show=N,
    mean=mean,
    std=std
)


<div class="alert alert-block alert-info">
<b>Q5.</b> In the previous cell, we created the <code>GradCAM++ heatmaps</code>. Can you spot the lesions in the image under the predicted class activation maps? Where do the difficulties arise?
</div>

# Conclusion and Final Thoughts

Congratulations on completing this comprehensive, hands-on journey through medical image classification!

Over the course of this notebook, you have successfully built, trained, and evaluated deep learning models from start to finish. You have gone beyond mere training—you have engaged in critical practices that transform a simple experiment into a robust, interpretable AI system.

## Key Achievements

- You successfully handled and prepared a real-world medical imaging dataset, including preprocessing and normalization.
- You built custom PyTorch `Dataset` and `DataLoader` pipelines tailored for classification tasks.
- You trained state-of-the-art models (DenseNet, EfficientNet, Swin Transformer) and evaluated them on unseen test data.
- You implemented advanced evaluation metrics beyond simple accuracy, including precision, recall, F1-score, and AUROC.
- You applied explainable AI techniques such as Grad-CAM, Grad-CAM++, and Score-CAM to interpret model predictions and identify the regions of the X-rays that influenced each decision.

## Key Takeaways

- **Data is foundational:** A thorough understanding of your dataset, including class balance and preprocessing, is essential for successful model training.  
- **Evaluation is multi-dimensional:** Relying on a single metric is insufficient. Metrics like AUROC, F1-score, and visual inspection of correctly and incorrectly classified images give a complete picture of model performance.  
- **Interpretability builds trust:** In high-stakes domains like medicine, it’s crucial to understand why models make their decisions. Techniques like Grad-CAM help clinicians trust AI-assisted predictions.

## Next Steps

This notebook provides a strong foundation for further exploration. Here are some directions to extend your work:

1. **Include Negative Samples:** Incorporate non-diseased cases to create a more comprehensive diagnostic model capable of both detection and classification.  
2. **Experiment with Architectures:** Try alternative architectures or ensemble approaches to improve classification performance.  
3. **Hyperparameter Optimization:** Systematically tune learning rates, batch sizes, and data augmentation strategies to improve metrics like AUROC and F1-score.  
4. **3D Medical Imaging:** Extend these methods to 3D datasets (CT or MRI scans) to handle more complex structures and richer diagnostic information.  
5. **Model Explainability:** Explore LIME or integrated gradients alongside Grad-CAM to provide multi-faceted interpretability for clinical decision support.

By following these practices, you have not only trained accurate models but also built a framework for safe, interpretable, and clinically useful AI in medical imaging.


> 60 minutes