# Summer School on Biomedical Imaging with Deep Learning

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/albarqounilab/BILD-Summer-School/blob/main/notebooks/day1/cnn_classification_improved_explanations2.ipynb)

![alt_text](https://raw.githubusercontent.com/albarqounilab/BILD-Summer-School/refs/heads/main/images/helpers/notebook-banner.png)

BILD 2025 is organized under the umbrella of the [Strategic Arab-German Network for Affordable and Democratized AI in Healthcare (SANAD)](https://albarqouni.github.io/funded/sanad/), uniting academic excellence and technological innovation across borders. This year’s edition is organized by the [Albarqouni Lab](https://albarqouni.github.io/) at the [University Hospital Bonn](https://www.ukbonn.de/) and the [University of Bonn](https://www.uni-bonn.de/en). We are proud to partner with leading institutions in the region—Lebanese American University, University of Tunis El Manar, and Duhok Polytechnic University — to deliver a truly international learning experience. Over five intensive days in Tunis, you will explore cutting-edge deep-learning techniques for medical imaging through expert lectures, hands-on labs, and collaborative case studies. Engage with peers and faculty from Germany, Lebanon, Iraq, and Tunisia as you develop practical skills in building and deploying AI models for real-world healthcare challenges. We look forward to an inspiring week of interdisciplinary exchange and the shared commitment to advancing affordable, life-saving AI in medicine.


## Chest-X-Ray Classification [78 mins]

### Introduction

This notebook will guide you step-by-step through practical exercises on using Convolutional Neural Networks (CNNs) for classifying and detecting diseases in chest X-ray images. If you are new to machine learning, PyTorch, or Python, don't worry—each section will explain what is happening and why it is important.

We will use two real-world chest X-ray datasets to learn how to:
- Recognize (classify) if a chest X-ray shows signs of disease or is healthy.
- Find (detect) the location of disease in the image using bounding boxes.

**Why do we do this?**
- Medical images like X-rays are used by doctors to diagnose diseases. Automating this process with AI can help doctors make faster and more accurate decisions.

**The datasets:**
- **NIH ChestX-ray14:** Over 100,000 X-ray images labeled with 14 different diseases. Some images also have boxes showing where the disease is located.
- **RSNA Pneumonia Detection Challenge:** About 30,000 X-rays with expert-drawn boxes around pneumonia, perfect for learning detection.

**What will you learn?**
1. **How to work with medical image datasets.**
   - Loading images and reading their labels (what disease, if any, is present).
   - Understanding the structure of the data and why we split it into training, validation, and test sets.
2. **How to build and train a model to classify images.**
   - Using powerful pre-trained models and adapting them to our problem.
   - Measuring how well our model is doing and how to improve it.
3. **How to detect disease locations in images.**
   - Using models that can draw boxes around areas of interest.
   - Evaluating how accurate these detections are.
4. **How to report and interpret results.**
   - Understanding key metrics like accuracy and ROC-AUC.
   - Visualizing results to see what the model got right and wrong.

By the end of this notebook, you will have hands-on experience with the full process of building, training, and evaluating deep learning models for medical image analysis, even if you are just starting out!

## Dataset

The [NIH ChestX-ray-14](https://nihcc.app.box.com/v/ChestXray-NIHCC) dataset is a large collection of chest X-ray images. Each image comes with information about the patient and labels that tell us which diseases (if any) are present. This dataset is widely used in medical AI research because it helps us train and test models to recognize diseases from X-ray images.

**What does the dataset contain?**
1. Over 100,000 chest X-ray images, each in PNG format. These are pictures of the inside of the chest, showing the lungs and heart.
2. A metadata file (`Data_Entry_2017.csv`) that lists information about each image, such as:
   - Which diseases are present (if any)
   - Patient age and gender
   - How the image was taken (the view position, for example PA or AP)
3. A file with bounding boxes (`BBox_List_2017.csv`) for about 1,000 images. These boxes show exactly where a disease is located in the image.
4. Files that split the data into training and test sets. This is important because we want to train our model on some images and test it on others to see how well it works on new data.

**Why do we use this dataset?**
- It is large and diverse, which helps our model learn better.
- It has real medical labels, making our project more realistic.
- It allows us to practice both classification (is there a disease?) and detection (where is the disease?).

In this notebook, we will use a smaller sample of this dataset and pre-trained models to make the exercises faster and easier to follow.

In [None]:
# Install / update the huggingface hub CLI
!pip install -q "huggingface_hub[cli]"

# Download the dataset files via CLI
!hf download albarqouni/bild-dataset --repo-type dataset \
    --include "Classification/csv.zip" \
    --local-dir ./

!hf download albarqouni/bild-dataset --repo-type dataset \
    --include "Classification/data_cxr8.zip" \
    --local-dir ./

# Unzip the downloaded files
!unzip -q ./Classification/csv.zip -d ./Classification
!unzip -q ./Classification/data_cxr8.zip -d ./Classification

print("Download and extraction complete.")


# Classification

## Import essentials

In [None]:
#@title import libraries
!pip install pydicom SimpleITK torchmetrics -q

# Standard library
import os
import re
import time
import random
import warnings
from glob import glob
warnings.filterwarnings('ignore')

# Data handling and visualization
import cv2
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import pydicom
import SimpleITK as sitk
from PIL import Image
from tqdm import tqdm

# PyTorch
import torch
import torch.nn as nn
import torchvision
from torchvision import tv_tensors, transforms, models
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import OneCycleLR
from torchmetrics.classification import BinaryAUROC

# Data splitting, evaluation, and downloads
from sklearn.model_selection import train_test_split
from sklearn.metrics import (accuracy_score, precision_score, recall_score, f1_score,
                             roc_curve, auc, roc_auc_score, RocCurveDisplay)
from huggingface_hub import hf_hub_download


### Download

Before we can work with the data and the models, we need to download the files and unzip them. This copies the files from the internet to the machine running this notebook and makes them ready to use.

**Why do we do this?**
- Machine learning models need data to learn from. Downloading the dataset gives us the images and labels we need for our project.
- Unzipping extracts the files from a compressed format so we can access them easily in our code.
- Downloading pretrained weights lets us start from models that have already been trained instead of training from scratch.

**Instructions:**
- The dataset was downloaded and unzipped at the top of this notebook. The next cell downloads the pretrained model weights (DenseNet-121, EfficientNet, and Swin Transformer).
- If you already have these files, you can skip the download by adding a `#` before the `!` in the code (this comments out the line so it won't run).
- You can also change the `DATA_PATH` variable (set in the next section) if you store the data in a different folder.


In [None]:
# Install / update huggingface hub CLI
!pip install -q "huggingface_hub[cli]"

# Download pretrained model weights via CLI
!hf download albarqouni/bild-dataset --repo-type dataset \
    --include "Classification/densenet121-classification.pth" \
    --local-dir ./

!hf download albarqouni/bild-dataset --repo-type dataset \
    --include "Classification/efficientnet-classification.pth" \
    --local-dir ./

!hf download albarqouni/bild-dataset --repo-type dataset \
    --include "Classification/swintransformer-classification.pth" \
    --local-dir ./

print("Model weights downloaded.")


### Load patient splits

To train and evaluate our model properly, we need to split our data into different groups:
- **Training set:** Used to teach the model.
- **Validation set:** Used to check how well the model is learning during training.
- **Test set:** Used to see how well the model works on completely new data.

In this step, we load lists of which images belong to each group. This helps us make sure that the model is tested on images it has never seen before, which is important for getting a fair measure of its performance.

In [None]:
DATA_PATH="./Classification"
# Each line of these files is an image file name, even though the column is called 'patientId'
train_val_patients = pd.read_csv(f'{DATA_PATH}/train_val_list.txt', header=None, names=['patientId'])
test_patients = pd.read_csv(f'{DATA_PATH}/test_list.txt', header=None, names=['patientId'])

print(f"Number of images in train/val list: {len(train_val_patients)}")
print(f"Number of images in test list: {len(test_patients)}")

The `.txt` files contain lists of image names that belong to the training/validation or test sets. To use these splits, we need to match the image names in these files with the information in our main database (`metadata.csv`). This way, we know which images and labels go into each group for training and testing.

### Load dataframe metadata

A **dataframe** is a table of data, like a spreadsheet, that we can easily work with in Python using the pandas library. Here, we load the metadata for all our images. This metadata tells us important information about each image, such as which diseases are present, the patient ID, and more. Loading this information helps us organize and prepare our data for training and testing our model.

In [None]:
# Load and observe available data
metadata_df = pd.read_csv(f'{DATA_PATH}/metadata.csv')
metadata_df.head()  # Show the first 5 rows of the dataframe

Now we need to make sure that the information in our dataframe matches the images we actually downloaded. This step filters out any entries in the metadata that do not have a corresponding image file, so we only work with images that are available on our computer.

In [None]:
imgs = glob(f'{DATA_PATH}/images/*')
imgs_basename = [os.path.basename(i) for i in imgs]

metadata_df = metadata_df.loc[metadata_df['Image Index'].isin(imgs_basename)]
metadata_df.shape

### Handle targets

In machine learning, a **target** is what we want the model to predict. For this project, the target is the disease label for each image. In this step, we prepare the target labels so that our model can learn to predict them. This may involve simplifying the labels or grouping them in a way that makes the problem easier to solve.

In the next step, we count how many images have each value of `Finding Labels`. Because one image can have several findings joined by `|` (for example `Effusion|Infiltration`), each *combination* of findings is counted as its own label. Some labels are very rare, which makes it hard for the model to learn about them. To keep things simple and make sure the model has enough examples to learn from, we will remove labels that appear in fewer than 1,500 images.

In [None]:
label_counts = metadata_df['Finding Labels'].value_counts()
label_counts
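To see why so many rare entries appear in this table, note that `value_counts()` on the raw strings counts each *combination* of findings separately. The sketch below uses a small hypothetical DataFrame (not the real metadata) to compare this with counting individual findings via `str.split('|')` and `explode()`.

```python
import pandas as pd

# Hypothetical stand-in for the `Finding Labels` column
toy = pd.DataFrame({'Finding Labels': [
    'No Finding', 'Effusion', 'Effusion|Infiltration', 'No Finding', 'Infiltration',
]})

# Counting whole strings: every combination is its own "label"
combo_counts = toy['Finding Labels'].value_counts()

# Splitting on '|' and exploding: each individual finding is counted once per image
finding_counts = toy['Finding Labels'].str.split('|').explode().value_counts()

print(combo_counts.to_dict())
print(finding_counts.to_dict())
```

With combination counts, `Effusion` appears once; with per-finding counts it appears twice, because the `Effusion|Infiltration` image also contributes to it.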

Next, we mark every label that appears in fewer than 1,500 images as rare, so that we only keep categories with enough examples for the model to learn from. The counts also show whether some findings are much more common than others, which affects how well the model learns.

In [None]:
label_counts = metadata_df['Finding Labels'].value_counts()
rare_labels = label_counts[label_counts < 1500].index

Now we update our data table (DataFrame) to remove any images with rare disease labels. This makes sure our model only sees images with the most common labels, which helps it learn better.

In [None]:
metadata_df_filtered = metadata_df[~metadata_df['Finding Labels'].isin(rare_labels)].copy()

print(f"Original shape: {metadata_df.shape}")
print(f"Filtered shape: {metadata_df_filtered.shape}")

In [None]:
metadata_df_filtered['Finding Labels'].value_counts()

To make our task easier, we turn it into a **binary classification** problem. This means the model learns to answer a single question: is this X-ray healthy, or does it show the target disease?

- **Class 0 (Negative):** Images labeled as 'No Finding' (healthy)
- **Class 1 (Positive):** Images whose labels include a disease listed in `keep` (here 'Effusion', possibly together with other findings)

Starting with two classes is common because it is easier for the model to distinguish between just two categories. The disease we want the model to detect is called the **target class**. In the cell below, `keep` contains only 'No Finding' and 'Effusion', so all other images are dropped. You can add more disease names to `keep` to build a healthy-versus-disease classifier over several findings, or experiment with more classes to see how the model behaves.

In [None]:
keep = {
    'No Finding', 'Effusion',
}

# split each cell into a list, then keep rows where at least one element is in `keep`
df_filtered = metadata_df_filtered[
    metadata_df_filtered['Finding Labels']
      .str.split('|')                         # or .str.split(',') if comma‑separated
      .apply(lambda labels: any(lbl in keep for lbl in labels))
].copy()
df_filtered['Finding Labels'].value_counts()

Now we create a new column called `Binary Label` in our data. This column has a value of 0 for healthy images and 1 for all other remaining images (here, images with Effusion). This process is called **label encoding** and is very common in deep learning, because models compute with numbers, not text.

In [None]:
df_filtered['Binary Label'] = (df_filtered['Finding Labels'] != 'No Finding').astype(int)
df_filtered['Binary Label'].value_counts()
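The filtering and encoding steps above can be checked on a tiny hypothetical table before trusting them on the real data. This sketch reuses the same `keep` set and the same pandas expressions; the rows are made up for illustration.

```python
import pandas as pd

keep = {'No Finding', 'Effusion'}
toy = pd.DataFrame({'Finding Labels': [
    'No Finding', 'Effusion', 'Effusion|Infiltration', 'Infiltration', 'Atelectasis|Nodule',
]})

# Keep a row if at least one of its findings is in `keep`
mask = toy['Finding Labels'].str.split('|').apply(lambda labels: any(lbl in keep for lbl in labels))
toy_kept = toy[mask].copy()

# 0 = 'No Finding' (healthy), 1 = any other row that survived the filter
toy_kept['Binary Label'] = (toy_kept['Finding Labels'] != 'No Finding').astype(int)
print(toy_kept)
```

Note that `Effusion|Infiltration` is kept and becomes a positive, while `Infiltration` alone is dropped.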

We can further clean our dataset by keeping only one **view position** for our classifier. The view position describes the direction in which the X-ray beam passed through the patient. ChestX-ray14 contains frontal images taken either PA (posteroanterior, back to front, usually with the patient standing) or AP (anteroposterior, front to back, often used at the bedside for patients who cannot stand). Using only PA images helps the model learn more consistently, because all images look similar in terms of projection and magnification.

In [None]:
df_filtered = df_filtered[df_filtered["View Position"] == 'PA']
df_filtered['View Position'].value_counts()

Now we use the official split lists to divide our data into a **training/validation pool** (used to teach and tune the model) and a **test set** (used to check how well the model works on new, unseen data). This is called a **train-test split** and is a key step in building reliable machine learning models.

In [None]:
# Split df_filtered based on patient IDs from the loaded lists
train_val_df = df_filtered[df_filtered['Image Index'].isin(train_val_patients['patientId'])].copy()
test_df = df_filtered[df_filtered['Image Index'].isin(test_patients['patientId'])].copy()

print(f"Train val shape: {train_val_df.shape}")
print(f"Test set shape: {test_df.shape}")

### Deep Learning Data Terminology

- **Batch:** A batch is a small group of samples processed together by the model before updating its parameters. Using batches makes training faster and more stable.
- **Epoch:** One epoch means the model has seen all the training data once. Training usually takes many epochs.
- **DataLoader:** In PyTorch, a DataLoader helps us load data in batches, shuffle it, and use multiple CPU cores to speed up the process. This is essential for efficient deep learning training.
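The relation between dataset size, batch size, and batches per epoch can be checked with a toy `TensorDataset` filled with random tensors (not X-rays). With 130 samples and a batch size of 64, one epoch has three batches and the last one is smaller. This count, `len(loader)`, is also the number of steps per epoch that the notebook later passes to the learning rate scheduler.

```python
import math

import torch
from torch.utils.data import TensorDataset, DataLoader

# Hypothetical toy dataset: 130 "images" of shape 3x8x8 with binary labels
X = torch.randn(130, 3, 8, 8)
y = torch.randint(0, 2, (130,)).float()
loader = DataLoader(TensorDataset(X, y), batch_size=64, shuffle=True)

# One epoch = one pass over all batches; the last batch may be smaller
batch_sizes = [xb.shape[0] for xb, _ in loader]
print(len(loader), batch_sizes)     # 3 [64, 64, 2]
print(math.ceil(len(X) / 64))       # 3, the same count computed by hand
```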

In [None]:
df = train_val_df.copy()
pos = df[df['Binary Label'] == 1]
neg = df[df['Binary Label'] == 0]

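# Sample up to n_samples images per class to get a smaller, roughly balanced subset
n_samples = 3000
pos = pos.sample(n=min(len(pos), n_samples), random_state=42)
neg = neg.sample(n=min(len(neg), n_samples), random_state=42)

subset = pd.concat([pos, neg]).reset_index(drop=True)
print("Subset size:", subset.shape)
print(subset['Binary Label'].value_counts())

# Split the subset into 80% train/validation and 20% test, stratified by label.
# NOTE: this overwrites the `test_df` built from the official test list above, and the
# split is done per image, so images of the same patient can end up in different splits.
train_validation_df, test_df = train_test_split(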
    subset,
    test_size=0.2,
    stratify=subset['Binary Label'],
    random_state=42
)
train_df, val_df = train_test_split(
    train_validation_df,
    test_size=0.1,
    stratify=train_validation_df['Binary Label'],
    random_state=42
)
print("Train:", train_df.shape, "Validation:", val_df.shape)
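Splitting per image, as `train_test_split` does above, can place images of the same patient in both the training and test sets, which tends to make test scores optimistic. A minimal sketch of a patient-level alternative uses scikit-learn's `GroupShuffleSplit` with the patient ID as the group. It assumes a `Patient ID` column as in the original NIH `Data_Entry_2017.csv`; the toy table below is hypothetical. `GroupShuffleSplit` does not stratify by label; scikit-learn's `StratifiedGroupKFold` is one option if you need both.

```python
import pandas as pd
from sklearn.model_selection import GroupShuffleSplit

# Hypothetical toy metadata: 10 images from 5 patients (2 images each)
toy = pd.DataFrame({
    'Image Index': [f'img_{i:02d}.png' for i in range(10)],
    'Patient ID': [p for p in range(5) for _ in range(2)],
    'Binary Label': [0, 0, 1, 1, 0, 1, 1, 1, 0, 0],
})

# test_size refers to the fraction of *groups* (patients), not images
splitter = GroupShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
train_idx, test_idx = next(splitter.split(toy, toy['Binary Label'], groups=toy['Patient ID']))
toy_train, toy_test = toy.iloc[train_idx], toy.iloc[test_idx]

# Every patient ends up on exactly one side of the split
shared = set(toy_train['Patient ID']) & set(toy_test['Patient ID'])
print("Train patients:", sorted(toy_train['Patient ID'].unique()))
print("Test patients:", sorted(toy_test['Patient ID'].unique()))
print("Patients in both splits:", shared)
```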

### Deep learning with PyTorch

In this notebook, we use deep models such as CNNs, with the PyTorch library as our framework. PyTorch provides the flexibility and tools needed to build, train, and adapt these architectures for challenging tasks.

### Datasets in PyTorch
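Next, we define our custom `ChestXrayDataset` by subclassing `Dataset` from `torch.utils.data`. A map-style dataset implements `__len__`, which returns the number of samples, and `__getitem__`, which returns one sample for a given index. Here, each sample is an image loaded from disk and converted to RGB (with any transforms applied), together with its binary label as a float tensor.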

In [None]:
class ChestXrayDataset(Dataset):
    def __init__(self, df, img_dir, transform=None):
        self.df = df
        self.img_dir = img_dir
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        img_path = os.path.join(self.img_dir, row['Image Index'])
        image = Image.open(img_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        label = torch.tensor(row['Binary Label'], dtype=torch.float32)
        return image, label

**Transforms** are changes we make to images as we load them. This can include resizing, flipping, rotating, or normalizing the images. When we do these changes randomly during training, it is called **data augmentation**. Data augmentation helps the model learn to recognize patterns in different situations, making it more robust and less likely to memorize the training data (a problem called overfitting).

In [None]:
mean = [0.485, 0.456, 0.406]
std  = [0.229, 0.224, 0.225]

image_size_= 224

train_transforms = transforms.Compose([
    transforms.Resize((image_size_,image_size_)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])


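Validation images get the same resizing and normalization as the training images in `val_transforms`, but no random augmentation, so every evaluation sees exactly the same inputs.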

In [None]:
val_transforms   = transforms.Compose([
    transforms.Resize((image_size_,image_size_)),
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

### Dataloaders

A **DataLoader** is a tool in PyTorch that helps us load data in small groups called **mini-batches**. Instead of giving the model one image at a time, we give it a batch of images. This makes training faster and helps the model learn more stable patterns. Dataloaders also make it easy to shuffle the data and use multiple CPU cores for loading.

In [None]:
img_dir  = f'{DATA_PATH}/images'

train_ds = ChestXrayDataset(train_df, img_dir, transform=train_transforms)
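# Use at most 8 worker processes, and never more than the available CPU cores
train_loader = DataLoader(train_ds, batch_size=64, shuffle=True, num_workers=min(8, os.cpu_count() or 1), pin_memory=True)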

We pass the `Dataset` as an argument to `DataLoader`. This wraps an iterable over our dataset and supports automatic batching, sampling, shuffling, and multiprocess data loading. Here we use a batch size of 64 for training, *i.e.* each element of the dataloader iterable returns a batch of 64 images and their labels (the last batch may be smaller).

In [None]:
for X, y in train_loader:
    print(f"Shape of X [N, C, H, W]: {X.shape}")
    print(f"Shape of y: {y.shape} {y.dtype}")
    break

In [None]:
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD  = np.array([0.229, 0.224, 0.225], dtype=np.float32)

In [None]:
val_ds   = ChestXrayDataset(val_df,   img_dir, transform=val_transforms)
val_loader   = DataLoader(val_ds,   batch_size=32, num_workers=min(8, os.cpu_count() or 1), pin_memory=True)

## Loading pretrained models with PyTorch

A **pretrained model** is a model that has already been trained on a large dataset (like ImageNet) and has learned useful features. The structure of the model is called its **architecture** (for example, DenseNet, ResNet, EfficientNet). Using a pretrained model and adapting it to our own data is called **transfer learning**. This is very helpful because it allows us to get good results even with smaller datasets and less training time.
The `torchvision.models` module provides many popular architectures, most of them with pretrained weights.

In [None]:
torchvision.models.list_models()[::30]

#### Understanding Model Layers

When looking at a deep learning model, you will see several types of layers. Here is what to look for in each:

- **Convolutional layers:** These are the building blocks of most image models. They scan the input image with small filters (sliding windows) to detect patterns like edges, shapes, or textures. The first convolutional layer takes the raw image (with 1 channel for grayscale or 3 for RGB) and produces feature maps.
- **Normalization layers (BatchNorm):** These layers help the model train faster and more reliably by keeping the outputs of previous layers at a similar scale. Batch Normalization (BatchNorm) is the most common type. It makes training more stable and helps the model generalize better.
- **Pooling layers:** Pooling reduces the size of the feature maps, making the model faster and helping it focus on the most important features. The most common is Max Pooling, which keeps only the largest value in each region.
- **Activation functions:** After each convolution, the model uses an activation function (like ReLU) to introduce non-linearity. This lets the model learn complex patterns rather than only linear relationships.

When inspecting a specific model, check these positions:

- **First layer:** This is usually a convolutional layer that takes the input image. Check its input dimension (number of channels, usually 1 for grayscale or 3 for RGB images).
- **Second layer:** Often another convolutional, normalization, activation, or pooling layer, building on the features from the first.
- **Second to last layer:** This is typically a feature layer just before the classifier. Its output dimension shows the number of features passed to the final classifier.
- **Last layer:** This is the classifier or output layer. Its output dimension should match the number of model outputs: here 1, because binary classification uses a single logit.

By examining these layers, you can understand how the model processes the input and what features are used for the final prediction.
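The layer types above can be chained into a tiny network to follow how tensor shapes change. This is a hypothetical toy stem, far smaller than DenseNet, followed by global average pooling and a one-logit head.

```python
import torch
import torch.nn as nn

# Hypothetical minimal stem: convolution -> BatchNorm -> ReLU -> max pooling
stem = nn.Sequential(
    nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3, padding=1),  # 3 -> 8 feature maps
    nn.BatchNorm2d(8),             # re-centres and re-scales each of the 8 channels
    nn.ReLU(),                     # non-linearity
    nn.MaxPool2d(kernel_size=2),   # halves height and width
)
head = nn.Linear(8, 1)             # one logit for binary classification

x = torch.randn(4, 3, 32, 32)      # batch of 4 RGB images, 32x32 pixels
features = stem(x)                 # -> [4, 8, 16, 16]
pooled = features.mean(dim=(2, 3)) # global average pooling -> [4, 8]
logits = head(pooled)              # -> [4, 1]
print(features.shape, pooled.shape, logits.shape)
```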

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

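model = models.densenet121(weights=models.DenseNet121_Weights.IMAGENET1K_V1)  # ImageNet-pretrained weights
model.classifier = nn.Linear(model.classifier.in_features, 1)  # replace the 1000-class head with one logit
model = model.to(device)

# DenseNet has only two top-level children (`features` and `classifier`), so we list the
# individual layers inside `features` and append the classifier.
# Note: DenseNet applies ReLU and global average pooling between `features` and `classifier`.
layers = list(model.features.children()) + [model.classifier]
print('First layer:')
print(layers[0])
print('\n---')
print('Second layer:')
print(layers[1])
print('\n...')
print('Second to last layer:')
print(layers[-2])
print('\n---')
print('Last layer:')
print(layers[-1])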

### Hyperparameters

**Hyperparameters** are settings that you choose before training your model. They control how the learning process works. Common hyperparameters include:
- **Number of epochs:** How many times the model sees the whole training set.
- **Batch size:** How many samples are in each batch.
- **Learning rate:** How big the steps are when updating the model's weights.

Tuning hyperparameters is important because it can make a big difference in how well your model learns.

In [None]:
learning_rate = 1e-3
batch_size = 64
epochs = 10

### Optimization loop

Training a deep learning model involves an **optimization loop**. Each time the model sees the whole training set, it completes one **epoch**. The process has two main parts:
- **Train loop:** The model learns from the training data and updates its parameters.
- **Validation loop:** The model is tested on validation data to see how well it is learning.

A **loss function** measures how far the model's predictions are from the true answers. The goal of training is to minimize this loss. The optimization loop repeats for many epochs until the model performs well.

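Inside the training loop, the model learns by adjusting its parameters using **gradients**. The gradient of the loss with respect to a parameter tells us how the loss changes when that parameter changes, so the optimizer can move each parameter a small step in the direction that reduces the loss. Computing these gradients layer by layer, from the output back to the input, is called **backpropagation**; the update itself is done by the optimizer.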

- **optimizer.zero_grad():** Resets the gradients to zero before each batch.
- **loss.backward():** Calculates the gradients using backpropagation.
- **optimizer.step():** Updates the model's parameters using the gradients.
- **Learning rate scheduler (like OneCycleLR):** Adjusts the learning rate during training to help the model learn better and faster.
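The four steps listed above can be seen in isolation in a minimal sketch on synthetic data: a single linear layer learns whether the first of four random features is positive, using `BCEWithLogitsLoss`, `Adam`, and `OneCycleLR`. The data and learning rate are hypothetical choices for this toy problem, not settings for chest X-rays. As in the training loop below, the scheduler is stepped once per batch (here the whole toy dataset is one batch).

```python
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import OneCycleLR

torch.manual_seed(0)
# Hypothetical toy problem: label is 1 when the first feature is positive
X = torch.randn(256, 4)
y = (X[:, 0] > 0).float()

model = nn.Linear(4, 1)
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
steps = 50
scheduler = OneCycleLR(optimizer, max_lr=0.1, total_steps=steps)

losses = []
for step in range(steps):
    optimizer.zero_grad()                     # 1. clear gradients from the previous step
    loss = criterion(model(X).squeeze(1), y)
    loss.backward()                           # 2. backpropagation computes new gradients
    optimizer.step()                          # 3. update the weights
    scheduler.step()                          # 4. advance the one-cycle learning rate schedule
    losses.append(loss.item())

print(f"first loss {losses[0]:.3f}, last loss {losses[-1]:.3f}")
```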

In [None]:
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
scheduler = OneCycleLR(optimizer, max_lr=learning_rate, epochs=epochs, steps_per_epoch=len(train_loader), total_steps=epochs * len(train_loader))

In [None]:
pos_frac = train_df['Binary Label'].mean()
pos_weight = torch.tensor([(1 - pos_frac) / pos_frac]).to(device)

criterion  = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
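`pos_weight` multiplies the loss terms of positive samples. With a positive fraction `p`, the cell above uses `(1 - p) / p`, the ratio of negatives to positives, so both classes contribute roughly equally to the loss. The sketch below checks this on four made-up labels (one positive, three negatives) by comparing `BCEWithLogitsLoss` with the weighted binary cross-entropy written out by hand.

```python
import torch
import torch.nn as nn

# Hypothetical imbalanced labels: 1 positive for every 3 negatives
labels = torch.tensor([1., 0., 0., 0.])
pos_frac = labels.mean().item()                          # 0.25
pos_weight = torch.tensor([(1 - pos_frac) / pos_frac])   # 3.0 = negatives / positives

logits = torch.tensor([0.2, -1.0, 0.5, -0.3])
weighted = nn.BCEWithLogitsLoss(pos_weight=pos_weight)(logits, labels)

# Same loss by hand: the positive term is multiplied by pos_weight before averaging
p = torch.sigmoid(logits)
manual = -(pos_weight * labels * torch.log(p) + (1 - labels) * torch.log(1 - p)).mean()
print(pos_weight.item(), weighted.item(), manual.item())
```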

Next, we define our **training function** and **validation function**. The training function teaches the model using the training data, while the validation function checks how well the model is doing on data it hasn't seen before. Keeping these functions separate helps us monitor the model's progress and avoid overfitting (when the model memorizes the training data but doesn't generalize well to new data).

In [None]:
def train_loop(model, loader, criterion, optimizer, scheduler, device):
    model.train()
    running_loss = 0.0
    for imgs, labels in tqdm(loader, desc="  Training", leave=False):
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer.zero_grad()
        logits = model(imgs).squeeze(1)
        loss = criterion(logits, labels)
        loss.backward()
        optimizer.step()
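        # OneCycleLR raises a ValueError if it is stepped more than `total_steps` times
        # (e.g. if the loader length changed); skip the extra steps instead of crashing
        try: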
            scheduler.step()
        except ValueError:
            pass

        running_loss += loss.item() * imgs.size(0)

    avg_loss = running_loss / len(loader.dataset)
    return avg_loss


def val_loop(model, loader, criterion, auroc, device):
    model.eval()
    auroc.reset()
    running_preds = []
    running_labels = []

    with torch.no_grad():
        for imgs, labels in tqdm(loader, desc="  Validation", leave=False):
            imgs = imgs.to(device)
            logits = model(imgs).squeeze(1)
            probs = torch.sigmoid(logits)

            preds = (probs > 0.5).int().cpu().numpy()
            running_preds.extend(preds.tolist())
            running_labels.extend(labels.int().tolist())

            auroc.update(probs, labels.int().to(device))

    acc = accuracy_score(running_labels, running_preds)
    val_auroc = auroc.compute().item()
    return acc, val_auroc


In [None]:

img_dir  = f'{DATA_PATH}/images'

train_ds = ChestXrayDataset(train_df, img_dir, transform=train_transforms)
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=min(8, os.cpu_count() or 1), pin_memory=True)

val_ds   = ChestXrayDataset(val_df,   img_dir, transform=val_transforms)
val_loader   = DataLoader(val_ds,   batch_size=32, num_workers=min(8, os.cpu_count() or 1), pin_memory=True)

auroc = BinaryAUROC().to(device)
model = model.to(device)

for epoch in range(1, epochs + 1):
    print(f"Epoch {epoch}/{epochs}")

    train_loss = train_loop(model, train_loader, criterion, optimizer, scheduler, device)
    val_acc, val_auroc = val_loop(model, val_loader, criterion, auroc, device)

    print(f"  Train Loss: {train_loss:.4f}  |  Val Acc: {val_acc:.4f}  |  Val AUROC: {val_auroc:.4f}")

In [None]:
import gc

# Free GPU memory before training the next model
del model
del optimizer
del scheduler
del train_loader, val_loader, train_ds, val_ds
gc.collect()              # release unreferenced Python objects first
torch.cuda.empty_cache()  # then return cached, now-unused GPU memory


### Benchmarking model architectures

A **model architecture** is the specific design or structure of a neural network. Different architectures, such as the CNNs ResNet, DenseNet, and EfficientNet or the vision transformer Swin Transformer, use different building blocks:
- **Skip connections:** Allow information to skip layers, helping very deep networks learn better (used in ResNet).
- **Dense connections:** Each layer receives the feature maps of all preceding layers in its block as input, improving information flow and feature reuse (used in DenseNet).
- **Normalization layers:** Help stabilize and speed up training by keeping the data flowing through the network at a similar scale.

Trying different architectures is important because some may work better for your specific problem. In this section, you will train and compare several architectures to see which performs best on your data.
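The two connection patterns can be sketched as small PyTorch modules. These are simplified, hypothetical blocks for illustration, not the exact layers used in torchvision's ResNet or DenseNet (DenseNet, for example, also uses a 1x1 bottleneck convolution). The residual block adds its input back to its output, so the channel count stays the same; the dense layer concatenates new feature maps onto its input, so the channels grow by `growth_rate` per layer.

```python
import torch
import torch.nn as nn


class ResidualBlock(nn.Module):
    """Skip connection: output = F(x) + x, so the block only has to learn a correction."""

    def __init__(self, channels):
        super().__init__()
        self.body = nn.Sequential(
            nn.Conv2d(channels, channels, 3, padding=1), nn.BatchNorm2d(channels), nn.ReLU(),
            nn.Conv2d(channels, channels, 3, padding=1), nn.BatchNorm2d(channels),
        )

    def forward(self, x):
        return torch.relu(self.body(x) + x)


class DenseLayer(nn.Module):
    """Dense connection: new feature maps are concatenated to all previous ones."""

    def __init__(self, in_channels, growth_rate):
        super().__init__()
        self.body = nn.Sequential(
            nn.BatchNorm2d(in_channels), nn.ReLU(),
            nn.Conv2d(in_channels, growth_rate, 3, padding=1),
        )

    def forward(self, x):
        return torch.cat([x, self.body(x)], dim=1)


x = torch.randn(2, 16, 32, 32)
res_out = ResidualBlock(16)(x)                          # channels unchanged: [2, 16, 32, 32]
dense_out = DenseLayer(16, growth_rate=8)(x)            # channels grow: [2, 24, 32, 32]
dense_out2 = DenseLayer(24, growth_rate=8)(dense_out)   # and grow again: [2, 32, 32, 32]
print(res_out.shape, dense_out.shape, dense_out2.shape)
```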

<div class="alert alert-block alert-info">
<b>Q3.</b> In deep learning, different <b>model architectures</b> can have a big impact on performance. Complete the following cells to train and compare these models:
<ul>
<li>EfficientNet</li>
<li>Swin Transformer</li>
</ul>
</div>

Comparing different models helps you understand which design works best for your specific task and data.

### EfficientNet

In [None]:
from torchvision.models import efficientnet_b0, EfficientNet_B0_Weights
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model_enb0 = ... # COMPLETE
model_enb0.classifier[1] = nn.Linear(1280, 1)

model = ... # COMPLETE to device

learning_rate = ... # COMPLETE
batch_size = ... # COMPLETE
epochs = ... # COMPLETE

pos_frac = train_df['Binary Label'].mean()
pos_weight = torch.tensor([(1 - pos_frac) / pos_frac]).to(device)
criterion  = ... # COMPLETE

img_dir  = f'{DATA_PATH}/images'

# Build the loaders before the scheduler: the previous loaders were deleted in the
# clean-up cell, and OneCycleLR needs len(train_loader) to know the total number of steps
train_ds = ChestXrayDataset(train_df, img_dir, transform=train_transforms)
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=min(8, os.cpu_count() or 1), pin_memory=True)

val_ds   = ChestXrayDataset(val_df,   img_dir, transform=val_transforms)
val_loader   = DataLoader(val_ds,   batch_size=32, num_workers=min(8, os.cpu_count() or 1), pin_memory=True)

optimizer = ... # COMPLETE
scheduler = OneCycleLR(optimizer, max_lr=learning_rate, total_steps=epochs * len(train_loader))

auroc = BinaryAUROC().to(device)

for epoch in range(1, epochs + 1):
    print(f"Epoch {epoch}/{epochs}")

    train_loss = train_loop(model, train_loader, criterion, optimizer, scheduler, device)
    val_acc, val_auroc = val_loop(model, val_loader, criterion, auroc, device)

    print(f"  Train Loss: {train_loss:.4f}  |  Val Acc: {val_acc:.4f}  |  Val AUROC: {val_auroc:.4f}")

In [None]:
import gc

# After finishing training a model: drop every reference to it, including `model_enb0`,
# which points to the same network object as `model`
del model, model_enb0
del optimizer
del scheduler
del train_loader, val_loader, train_ds, val_ds
gc.collect()              # release unreferenced Python objects first
torch.cuda.empty_cache()  # then return cached, now-unused GPU memory


### SwinTransformer

In [None]:
from torchvision.models import swin_t, Swin_T_Weights
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model_swin = ... # COMPLETE
model_swin.head = nn.Linear(in_features=768, out_features=1, bias=True)
model = ... # COMPLETE to device
print(model)

learning_rate = ... # COMPLETE
batch_size = ... # COMPLETE
epochs = ... # COMPLETE

pos_frac = train_df['Binary Label'].mean()
pos_weight = torch.tensor([(1 - pos_frac) / pos_frac]).to(device)
criterion  = ... # COMPLETE

img_dir  = f'{DATA_PATH}/images'

# Build the loaders before the scheduler: the previous loaders were deleted in the
# clean-up cell, and OneCycleLR needs len(train_loader) to know the total number of steps
train_ds = ChestXrayDataset(train_df, img_dir, transform=train_transforms)
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=min(8, os.cpu_count() or 1), pin_memory=True)

val_ds   = ChestXrayDataset(val_df,   img_dir, transform=val_transforms)
val_loader   = DataLoader(val_ds,   batch_size=32, num_workers=min(8, os.cpu_count() or 1), pin_memory=True)

optimizer = ... # COMPLETE
scheduler = OneCycleLR(optimizer, max_lr=learning_rate, total_steps=epochs * len(train_loader))

auroc = BinaryAUROC().to(device)

for epoch in range(1, epochs + 1):
    print(f"Epoch {epoch}/{epochs}")

    train_loss = train_loop(model, train_loader, criterion, optimizer, scheduler, device)
    val_acc, val_auroc = val_loop(model, val_loader, criterion, auroc, device)

    print(f"  Train Loss: {train_loss:.4f}  |  Val Acc: {val_acc:.4f}  |  Val AUROC: {val_auroc:.4f}")

In [None]:
import gc
import torch

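# After finishing training a model, free its memory before moving on
del model, model_swin   # `model = model_swin.to(device)` makes both names point to the same network
del optimizer
del scheduler
del train_loader, val_loader, train_ds, val_ds
gc.collect()              # first let Python release the deleted objects
torch.cuda.empty_cache()  # then return cached, now-unused GPU memory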


### Metrics

After training, we need to measure how well our models perform. In deep learning, we use different **metrics** to evaluate models. Below, TP, FP, TN and FN are the numbers of true positives, false positives, true negatives and false negatives, where "positive" means diseased:
- **Accuracy:** The fraction of correct predictions, (TP+TN)/(TP+TN+FP+FN).
- **Precision:** Answers the question: "Of all the images the model labeled as diseased, what fraction were actually diseased?" High precision means the model makes few false positive errors. Clinically, this translates to fewer false alarms and fewer unnecessary follow-up examinations. (TP/(TP+FP))
- **Recall:** Answers the question: "Of all the images that were actually diseased, what fraction did the model correctly identify?" High recall means the model makes few false negative errors. This is often critically important in medicine, as it relates to not missing a diagnosis. (TP/(TP+FN))
- **F1-score:** Answers the question: "How well does the model balance precision and recall?" It is the harmonic mean of the two, 2·Precision·Recall/(Precision+Recall), so a model must perform well on both rather than excelling in only one.
- **ROC curve:** A plot of the true positive rate (recall) against the false positive rate, FP/(FP+TN), as the decision threshold varies. The area under this curve (ROC AUC) summarizes how well the model separates healthy from diseased images across all thresholds.

Using multiple metrics gives a more complete picture of model performance, especially when the data is imbalanced: if only 10% of the images are diseased, a model that always predicts "healthy" already reaches 90% accuracy.
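To make the formulas concrete, here is a minimal plain-Python sketch that counts TP, FP, FN and TN from binary labels and derives the four threshold-based metrics. The labels are invented for illustration; in the evaluation cell you would more likely call the corresponding `sklearn.metrics` functions on `all_labels` and `preds_binary`.

```python
def confusion_counts(y_true, y_pred):
    """Count TP, FP, FN, TN for binary labels (1 = diseased, 0 = healthy)."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)
    tn = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 0)
    return tp, fp, fn, tn


def binary_metrics(y_true, y_pred):
    """Accuracy, precision, recall and F1 from hard (0/1) predictions."""
    tp, fp, fn, tn = confusion_counts(y_true, y_pred)
    accuracy = (tp + tn) / (tp + fp + fn + tn)
    # Guard against division by zero, e.g. when the model never predicts "diseased"
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    # F1 is the harmonic mean of precision and recall
    f1 = (2 * precision * recall / (precision + recall)
          if (precision + recall) > 0 else 0.0)
    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}


# Toy example: 3 diseased and 5 healthy images (invented for illustration)
y_true = [1, 1, 1, 0, 0, 0, 0, 0]
y_pred = [1, 1, 0, 1, 0, 0, 0, 0]
print(confusion_counts(y_true, y_pred))
print(binary_metrics(y_true, y_pred))
```

Here TP = 2, FP = 1, FN = 1 and TN = 4, so accuracy is 6/8 = 0.75, while precision, recall and F1 are all 2/3.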

<div class="alert alert-block alert-info">
<b>Q4.</b> Compare the performance of the different models using these metrics:
    - Plot the **ROC curve**
    - Accuracy score
    - Precision and Recall
    - F1-score
</div>

Comparing models with these metrics helps you choose the best one for your task.
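The ROC curve is obtained by sweeping the decision threshold from high to low and recording the false positive rate and true positive rate at each step; the area under it can be approximated with the trapezoidal rule. The NumPy sketch below does exactly that on a tiny invented example and assumes both classes are present. In practice, `sklearn.metrics.roc_curve` and `roc_auc_score` compute these quantities for you.

```python
import numpy as np


def roc_points(y_true, scores):
    """Return (fpr, tpr) arrays obtained by sweeping the decision threshold.

    Assumes y_true contains at least one positive and one negative label.
    """
    y_true = np.asarray(y_true, dtype=int)
    scores = np.asarray(scores, dtype=float)
    n_pos = int((y_true == 1).sum())
    n_neg = int((y_true == 0).sum())
    # Start above the highest score (nothing predicted positive), then lower the
    # threshold through every distinct score until everything is predicted positive
    thresholds = np.concatenate(([np.inf], np.unique(scores)[::-1]))
    fpr, tpr = [], []
    for t in thresholds:
        predicted_pos = scores >= t
        tpr.append((predicted_pos & (y_true == 1)).sum() / n_pos)
        fpr.append((predicted_pos & (y_true == 0)).sum() / n_neg)
    return np.array(fpr), np.array(tpr)


def auc_trapezoid(fpr, tpr):
    """Area under the ROC curve via the trapezoidal rule."""
    return float(np.sum(np.diff(fpr) * (tpr[1:] + tpr[:-1]) / 2.0))


y_true = [0, 0, 1, 1]
scores = [0.1, 0.4, 0.35, 0.8]   # predicted P(diseased), invented for illustration
fpr, tpr = roc_points(y_true, scores)
print(fpr, tpr, auc_trapezoid(fpr, tpr))
```

For these four scores the AUC is 0.75: of the four (diseased, healthy) pairs, three are ranked correctly. This matches the reading of AUC as the probability that a randomly chosen diseased image receives a higher score than a randomly chosen healthy one.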

Load the saved model weights:

In [None]:
model_densenet = ... # COMPLETE
model_densenet.classifier = ... # COMPLETE

model_enb0 = ... # COMPLETE
model_enb0.classifier[1] = ... # COMPLETE

model_swin = ... # COMPLETE
model_swin.head = ... # COMPLETE

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# map_location lets weights saved on a GPU be loaded on a CPU-only machine
model_densenet.load_state_dict(torch.load('...', map_location=device))  # COMPLETE: path to saved weights
model_enb0.load_state_dict(torch.load('...', map_location=device))      # COMPLETE: path to saved weights
model_swin.load_state_dict(torch.load('...', map_location=device))      # COMPLETE: path to saved weights


Now let's compute the metrics and plot the ROC curves for all models (8 mins):

---



In [None]:
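import numpy as np
from sklearn.metrics import (accuracy_score, precision_score, recall_score,
                             f1_score, roc_auc_score, RocCurveDisplay)

test_ds     = ChestXrayDataset(test_df, img_dir, transform=val_transforms)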
test_loader = DataLoader(
    test_ds,
    batch_size=32,
    num_workers=32,
    pin_memory=True,
    shuffle=False
)

model_list  = [model_densenet, model_enb0, model_swin]
model_names = ['DenseNet 121', 'EfficientNet B0', 'Swin Transformer']

all_preds_proba = {}
all_labels      = None

for model, name in zip(model_list, model_names):
    model.eval()
    model.to(device)

    probs_list, labels_list = [], []

    with torch.no_grad():
        for X, y in test_loader:
            X = X.to(device)
            out = model(X)
            probs = torch.sigmoid(out)
            probs = probs.squeeze(1)
            probs_list.extend(probs.cpu().numpy())
            labels_list.extend(y.numpy())

    preds_proba = np.array(probs_list)
    labels      = np.array(labels_list)

    all_preds_proba[name] = preds_proba
    if all_labels is None:
        all_labels = labels

    preds_binary = (preds_proba > 0.5).astype(int)

    acc     = ... # COMPLETE
    prec    = ... # COMPLETE
    rec     = ... # COMPLETE
    f1      = ... # COMPLETE
    roc_auc = ... # COMPLETE (use the probabilities preds_proba, not preds_binary)

    print(f"Model: {name}")
    print(f"  Accuracy : {acc:.4f}")
    print(f"  Precision: {prec:.4f}")
    print(f"  Recall   : {rec:.4f}")
    print(f"  F1-score : {f1:.4f}")
    print(f"  ROC AUC  : {roc_auc:.4f}")
    print("-" * 30)

... # COMPLETE WITH ROC Curve display, and other metrics

### Summary: Key Deep Learning Terms

- **Dataset:** The collection of images and labels we use to train and test our model.
- **Label:** The answer we want the model to predict (e.g., healthy or diseased).
- **Model architecture:** The design or structure of the neural network.
- **Training:** Teaching the model using known data.
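- **Validation:** Checking the model's performance on held-out data during training, e.g. to choose hyperparameters or the best epoch.
- **Test set:** Data the model has never seen, used to measure final performance.
- **Batch/Epoch:** A batch is a small group of samples processed together in one parameter update; an epoch is one full pass through the training set.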
- **Loss function:** Measures how wrong the model's predictions are.
- **Optimizer:** Algorithm that updates the model's parameters to reduce loss.
- **Metric:** A way to measure how well the model is doing.

Reflect on these terms as you work through the notebook—they are the foundation of deep learning!
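To make Batch/Epoch concrete: with PyTorch's default `drop_last=False`, `len(train_loader)` is the number of training images divided by the batch size, rounded up, and the `OneCycleLR` scheduler in the training cells is given `total_steps = epochs * len(train_loader)`. A small sketch of that arithmetic; the dataset size and settings are hypothetical, and the variables are prefixed with `example_` so running the cell does not overwrite the notebook's own `batch_size` or `epochs`.

```python
import math


def steps_per_epoch(num_samples: int, batch_size: int, drop_last: bool = False) -> int:
    """Number of batches (= optimizer steps) in one epoch."""
    if drop_last:
        return num_samples // batch_size          # the last, incomplete batch is skipped
    return math.ceil(num_samples / batch_size)    # the last batch may be smaller


# Hypothetical values, for illustration only
example_num_images = 10_000
example_batch_size = 64
example_epochs = 5

example_steps = steps_per_epoch(example_num_images, example_batch_size)
example_total_steps = example_epochs * example_steps
print(example_steps, example_total_steps)  # prints: 157 785
```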

# Advanced Topic: Quality Control  

---




In medical AI, accuracy alone is not enough. A reliable system must also be transparent about what drives its predictions (and, ideally, quantify its uncertainty). We will now explore explainability-based Quality Control techniques to check that predictions are interpretable and rest on plausible image evidence.

For chest-X-ray **classification**, we’ll use **Grad-CAM**, **Grad-CAM++**, **Score-CAM**, and **LIME** to turn model predictions into **heatmaps** (or, for LIME, superpixel maps) that highlight the image regions driving each decision.

###  Grad-CAM (Gradient-weighted Class Activation Mapping)
- Uses the **gradients of the target class** flowing into the last convolutional layer.  
- Produces **coarse heatmaps** that highlight the most important regions influencing the prediction.  
- Helps verify whether the model is focusing on the lungs rather than irrelevant areas.  

---
###  Grad-CAM++
- An **improvement over Grad-CAM**.  
- Better at handling **multiple occurrences** of the same pathology (e.g., lesions in different lung areas).  

---
### Score-CAM
- Does **not rely on gradients**.  
- Uses the **model’s confidence scores** on masked copies of the input to generate explanations.  
- Avoids problems like noisy or vanishing gradients.  

---

### LIME (Local Interpretable Model-agnostic Explanations)

- Perturbs the input image (e.g., hides or alters superpixels) and observes how predictions change.  
- Produces **superpixel-based explanations**, showing which regions increase or decrease the prediction probability.  

---

In [None]:
!pip install git+https://github.com/jacobgil/pytorch-grad-cam.git
!pip install lime -q

In [None]:
import cv2
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn
from pytorch_grad_cam import GradCAM, ScoreCAM, GradCAMPlusPlus
from pytorch_grad_cam.utils.image import show_cam_on_image

In [None]:
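def get_last_conv(model):
    """Return (name, module) of the last nn.Conv2d registered in the model."""
    last_name, last_conv = None, None
    for name, module in model.named_modules():
        if isinstance(module, nn.Conv2d):
            last_name, last_conv = name, module
    return last_name, last_conv

name, target_layer = get_last_conv(model_densenet)
# For DenseNet-121 this is the last conv inside the final dense layer, which outputs only
# the 32 newly grown channels. The pytorch-grad-cam README lists model.features[-1]
# for DenseNet models as a target layer, which you may want to compare against.
print(name)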

Apply GradCAM, ScoreCAM, and GradCAM++ from the `pytorch_grad_cam` library to images from `test_loader` using the loaded DenseNet model, and visualize the results.

In [None]:
test_ds     = ChestXrayDataset(test_df, img_dir, transform=val_transforms)
test_loader = DataLoader(
    test_ds,
    batch_size=32,
    num_workers=32,
    pin_memory=True,
    shuffle=False
)


In [None]:
def visualize_gradcam(
    original_images: torch.Tensor,  # (N, C, H, W)
    heatmaps:         np.ndarray,    # (N, H_cam, W_cam)
    true_labels:      np.ndarray,    # (N,)
    num_to_show:      int = 5,
    mean:             float|list|tuple = 0.5,
    std:              float|list|tuple = 0.2,
    title:            str = "GradCAM Overlay",
):
    """
    Plot each original image next to its CAM overlay.

    - original_images: Tensor(N, C, H, W), normalized via (x-mean)/std
    - heatmaps:        ndarray(N, H_cam, W_cam) in [0,1]
    - true_labels:     ndarray(N,)
    - title:           title of the overlay panel, e.g. "ScoreCAM Overlay"
    """
    N = min(num_to_show, original_images.shape[0])

    # prepare mean/std arrays for un-normalization
    if isinstance(mean, (list, tuple, np.ndarray)):
        mean_arr = np.array(mean)[:, None, None]
        std_arr  = np.array(std)[:,  None, None]
    else:
        mean_arr = mean
        std_arr  = std

    for i in range(N):
        # 1) pull & un-normalize the i-th image
        img = original_images[i].cpu().numpy()          # (C, H, W)
        img = img * std_arr + mean_arr                  # broadcast over C,H,W
        img = np.clip(img, 0, 1)                        # show_cam_on_image expects values in [0, 1]

        # 2) convert to H×W×C for plotting
        img_hwc = np.transpose(img, (1, 2, 0))         # (H, W, C)
        H, W, _ = img_hwc.shape

        # 3) resize heatmap to match image size (pytorch-grad-cam already returns
        #    maps at the input resolution, so this is a safeguard)
        hm = heatmaps[i]                                # (H_cam, W_cam)
        hm_resized = cv2.resize(hm, (W, H), interpolation=cv2.INTER_LINEAR)

        # 4) overlay CAM
        cam_overlay = show_cam_on_image(img_hwc, hm_resized, use_rgb=True)

        # 5) plot side by side
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
        ax1.imshow(img_hwc)
        ax1.set_title(f"Original (Label: {true_labels[i]})")
        ax1.axis('off')

        ax2.imshow(cam_overlay)
        ax2.set_title(title)
        ax2.axis('off')

        plt.tight_layout()
        plt.show()

### Grad-CAM

#### Deep Dive: How Grad-CAM Works

Grad-CAM creates its heatmap by combining two key pieces of information from the model:

- **Feature Maps from a Convolutional Layer**: Deep inside the network, convolutional layers produce feature maps that highlight abstract patterns like textures, edges, and shapes. The final convolutional layers capture the most high-level, class-specific information.

- **Gradients**: It calculates the gradient (the importance signal) of the model's final prediction score with respect to each feature map. A high gradient for a particular feature map means that map was very influential in the final decision.
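The two ingredients are combined in a simple way: the gradients of each feature map are averaged over all spatial positions to give one weight per channel, the feature maps are summed with these weights, and a ReLU keeps only the regions that increase the class score. The NumPy sketch below shows this combination step on hand-made toy arrays. It assumes you already have the target layer's activations and gradients (pytorch-grad-cam collects them with hooks); for our single-logit models the "class score" is that logit.

```python
import numpy as np


def gradcam_from_arrays(activations: np.ndarray, gradients: np.ndarray) -> np.ndarray:
    """Grad-CAM map from target-layer activations and gradients.

    activations, gradients: arrays of shape (K, H, W), where K is the number of
    feature maps and gradients hold d(score)/d(activations).
    Returns an (H, W) map scaled to [0, 1].
    """
    # 1) One weight per channel: spatial average of its gradient
    weights = gradients.mean(axis=(1, 2))                    # (K,)
    # 2) Weighted sum of the feature maps
    cam = np.tensordot(weights, activations, axes=(0, 0))    # (H, W)
    # 3) ReLU: keep only evidence in favour of the class
    cam = np.maximum(cam, 0)
    # 4) Min-max scaling for display
    cam = cam - cam.min()
    return cam / (cam.max() + 1e-7)


# Toy example: channel 0 fires top-left and raises the score,
# channel 1 fires bottom-right and lowers it
A = np.zeros((2, 2, 2))
A[0, 0, 0] = 1.0
A[1, 1, 1] = 1.0
G = np.stack([np.ones((2, 2)), -np.ones((2, 2))])
print(gradcam_from_arrays(A, G))
```

Only the top-left position survives: channel 1 has a negative weight, and the ReLU removes the region it would mark.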

Visualize the original images and the generated heatmaps from GradCAM, ScoreCAM, and GradCAM++ for a subset of the test data.

GradCAM computes heatmaps from the gradients of the target class score with respect to the feature maps of the last convolutional layer, producing a heatmap that highlights the regions that most strongly influence the model’s prediction.

In [None]:
import gc
import torch

# Delete variables that occupy RAM
del orig_images, heatmaps, true_labels, hm_batch, inputs, labels
gc.collect()

# If using GPU, clear CUDA cache
torch.cuda.empty_cache()


### Score-CAM [6 mins]


#### Deep Dive: How Score-CAM Works

Score-CAM avoids using gradients and instead relies on forward passes to assess importance:

- **Perturbation-Based Activation**: Each activation map from the target convolutional layer is upsampled to the input size, normalized to [0, 1], and used as a soft mask on the input image.

- **Forward Pass Evaluation**: The masked image is passed through the model, and the resulting target class score (relative to a baseline input) determines the importance of that activation map.

- **Heatmap Construction**: The final heatmap is a weighted combination of the activation maps, followed by a ReLU, highlighting the regions that most influence the model’s prediction.
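Here is a minimal NumPy sketch of these steps for a single grayscale image. The `score_fn` stands in for the network (a hypothetical linear model that only "looks" at the top-left corner), and the activation maps are assumed to be already upsampled to the image size. Turning the per-channel scores into weights with a softmax follows common implementations; subtracting a baseline score, as in the Score-CAM "increase of confidence", would not change a softmax over channels.

```python
import numpy as np


def minmax(a: np.ndarray) -> np.ndarray:
    """Scale an array to [0, 1]; a constant array becomes all zeros."""
    a = a - a.min()
    return a / (a.max() + 1e-7)


def scorecam_from_arrays(image, activations, score_fn):
    """Score-CAM map for one image.

    image:       (H, W) input image
    activations: (K, H, W) target-layer maps, assumed upsampled to image size
    score_fn:    callable mapping an (H, W) image to the target-class score
    """
    # 1) Each normalized activation map becomes a soft mask on the input
    masks = np.stack([minmax(a) for a in activations])           # (K, H, W)
    # 2) One forward pass per mask: score of the masked image
    scores = np.array([score_fn(image * m) for m in masks])      # (K,)
    # 3) Softmax over channels turns scores into weights summing to 1
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()
    # 4) Weighted combination of activation maps, ReLU, rescaling
    cam = np.maximum(np.tensordot(weights, activations, axes=(0, 0)), 0)
    return minmax(cam)


# Toy setup: hypothetical "model" that sums pixels in the top-left 2x2 block
img = np.ones((4, 4))
w = np.zeros((4, 4))
w[:2, :2] = 1.0
acts = np.zeros((2, 4, 4))
acts[0, :2, :2] = 1.0   # channel 0 covers the region the model uses
acts[1, 2:, 2:] = 1.0   # channel 1 covers an irrelevant region
cam = scorecam_from_arrays(img, acts, lambda x: float((x * w).sum()))
print(np.round(cam, 3))
```

Channel 0 keeps the score high when used as a mask, so it dominates the final map; channel 1 contributes only a faint trace.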

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_densenet.to(device).eval()
cam = ScoreCAM(model=model_densenet, target_layers=[target_layer])

N = 10  # number of images to visualize

orig_images_sc = []
heatmaps_sc    = []
true_labels    = []

for inputs, labels in test_loader:
    inputs = inputs.to(device)

    # targets=None explains the highest-scoring output; with our single-logit
    # model this is always the "diseased" logit
    hm_batch = cam(input_tensor=inputs)
    heatmaps_sc.extend(hm_batch)

    orig_images_sc.extend(inputs.cpu())
    true_labels.extend(labels.numpy())

    # Score-CAM runs extra forward passes for every activation channel of every image,
    # so stop as soon as enough images have been collected
    if len(orig_images_sc) >= N:
        break

orig_images_sc = torch.stack(orig_images_sc, dim=0)
heatmaps_sc    = np.stack(heatmaps_sc, axis=0)
true_labels    = np.array(true_labels, dtype=int)

visualize_gradcam(
    original_images=orig_images_sc,
    heatmaps=heatmaps_sc,
    true_labels=true_labels,
    num_to_show=N,
    mean=[0.5],
    std =[0.2],
    title="ScoreCAM Overlay"
)

ScoreCAM uses each (upsampled, normalized) feature map of the target layer as a mask on the input image, runs the masked images through the network, and weights each feature map by the resulting score for the target class.

In [None]:
import torch
import gc

# Delete large objects
del orig_images_sc
del heatmaps_sc
del true_labels
del inputs, hm_batch
# Run garbage collection first to free up Python memory
gc.collect()

# Then release cached, now-unused GPU memory
torch.cuda.empty_cache()


### Grad-CAM++

#### Deep Dive: How Grad-CAM++ Works

Grad-CAM++ builds on Grad-CAM but improves the localization of small or multiple objects:

- **Weighted Combination of Feature Maps**: Instead of a simple global average of gradients, Grad-CAM++ computes a separate coefficient for every spatial position of each feature map and sums only the positive gradients, scaled by these coefficients.

- **Better Fine-Grained Localization**: This allows the heatmap to more precisely highlight smaller regions that are critical to the prediction, especially when multiple objects or details matter.
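Under the simplification made in the Grad-CAM++ paper (the explained score is taken as the exponential of the logit, so higher-order derivatives reduce to powers of the first-order gradient `g`), the coefficient for position (i, j) of channel k is `alpha = g**2 / (2 * g**2 + A_sum * g**3)`, where `A_sum` is the sum of that channel's activations; the channel weight is the sum of `alpha * relu(g)`. The NumPy sketch below implements just this weighting on toy arrays; it is an illustration, not the pytorch-grad-cam source.

```python
import numpy as np


def gradcam_pp_weights(activations: np.ndarray, gradients: np.ndarray,
                       eps: float = 1e-7) -> np.ndarray:
    """Per-channel Grad-CAM++ weights from (K, H, W) activations and gradients."""
    g2 = gradients ** 2
    g3 = g2 * gradients
    sum_act = activations.sum(axis=(1, 2), keepdims=True)       # (K, 1, 1)
    # Pixel-wise coefficients alpha_ij^k
    alpha = g2 / (2 * g2 + sum_act * g3 + eps)
    alpha = np.where(gradients != 0, alpha, 0.0)
    # Only positive gradients contribute, each scaled by its own alpha
    return (alpha * np.maximum(gradients, 0)).sum(axis=(1, 2))  # (K,)


def gradcam_pp_map(activations: np.ndarray, gradients: np.ndarray) -> np.ndarray:
    """Weighted sum of feature maps, ReLU, and min-max scaling to [0, 1]."""
    weights = gradcam_pp_weights(activations, gradients)
    cam = np.maximum(np.tensordot(weights, activations, axes=(0, 0)), 0)
    cam = cam - cam.min()
    return cam / (cam.max() + 1e-7)


# Toy example: one channel with activations and gradients all equal to 1
A = np.ones((1, 2, 2))
G = np.ones((1, 2, 2))
print(gradcam_pp_weights(A, G))    # each alpha is about 1/6, summed over 4 pixels
print(gradcam_pp_weights(A, -G))   # negative gradients give zero weight
```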

In [None]:
from pytorch_grad_cam import GradCAMPlusPlus

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_densenet.to(device).eval()
cam = GradCAMPlusPlus(model=model_densenet, target_layers=[target_layer])

N = 10  # number of images to visualize

heatmaps    = []
orig_images = []
true_labels = []

for inputs, labels in test_loader:
    inputs = inputs.to(device)

    hm_batch = cam(input_tensor=inputs)
    heatmaps.extend(hm_batch)

    orig_images.extend(inputs.cpu())
    true_labels.extend(labels.numpy())

    # Only N images are plotted, so stop once enough have been collected
    if len(orig_images) >= N:
        break

orig_images = torch.stack(orig_images, dim=0)
heatmaps    = np.stack(heatmaps,    axis=0)
true_labels = np.array(true_labels, dtype=int)

visualize_gradcam(
    original_images=orig_images,
    heatmaps=heatmaps,
    true_labels=true_labels,
    num_to_show=N,
    mean=mean,
    std=std
)

GradCAMPlusPlus is an improved version of Grad-CAM.

Because it weights each spatial location separately, it can highlight several important regions in an image instead of being dominated by the single strongest one.

It is especially useful when:

- An image contains multiple instances of the same finding (e.g., several lesions).

- You want finer-grained localization than standard Grad-CAM provides.

In [None]:
import torch
import gc

# Delete large objects
del orig_images
del heatmaps
del true_labels
del inputs, hm_batch
# Run garbage collection first to free up Python memory
gc.collect()

# Then release cached, now-unused GPU memory
torch.cuda.empty_cache()


### LIME [5 mins]

#### Deep Dive: How LIME Works

LIME provides explanations by approximating the model locally with interpretable models:

- **Superpixel Segmentation**: The input image is divided into superpixels—small contiguous regions of similar color or texture.

- **Perturbation and Prediction**: LIME creates many perturbed versions of the image by hiding or altering superpixels and records the model’s predictions on these variants.

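- **Linear Approximation**: It fits a simple interpretable model (typically a weighted linear model, where perturbed samples closer to the original image count more) on the perturbed dataset to determine which superpixels most strongly influence the prediction.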

- **Explanation**: The highlighted superpixels are those that contributed most to the model’s decision for the target class.
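These four steps can be reproduced in a few lines of NumPy. This is a simplified sketch, not the `lime` library's implementation: the superpixels are fixed quadrants instead of a segmentation algorithm's output, the classifier is a hypothetical function of two quadrants, and the surrogate is a plain weighted least-squares fit (the library uses a ridge regressor by default). Samples are weighted with an exponential kernel on the cosine distance to the unperturbed image, similar in spirit to LIME's default.

```python
import numpy as np


def lime_superpixel_weights(image, segments, classifier_fn, num_samples=500,
                            hide_color=0.0, kernel_width=0.25, seed=0):
    """Minimal LIME-style explanation for one image and one class.

    image:         (H, W) array
    segments:      (H, W) int array of superpixel ids 0..S-1
    classifier_fn: maps an (N, H, W) batch to (N,) probabilities for the class
    Returns (intercept, per-superpixel weights).
    """
    rng = np.random.default_rng(seed)
    n_seg = int(segments.max()) + 1
    # Each row marks which superpixels stay visible (1) or are hidden (0)
    z = rng.integers(0, 2, size=(num_samples, n_seg))
    z[0] = 1                                     # first sample = original image
    masks = z[:, segments]                       # (N, H, W), 1 where visible
    perturbed = image * masks + hide_color * (1 - masks)
    probs = classifier_fn(perturbed)             # (N,)

    # Cosine distance to the all-visible vector, then an exponential kernel
    dist = 1.0 - np.sqrt(z.sum(axis=1) / n_seg)
    sample_w = np.sqrt(np.exp(-dist ** 2 / kernel_width ** 2))

    # Weighted least squares: probs ~ intercept + z @ coef
    design = np.hstack([np.ones((num_samples, 1)), z])
    sw = np.sqrt(sample_w)
    sol, *_ = np.linalg.lstsq(design * sw[:, None], probs * sw, rcond=None)
    return sol[0], sol[1:]


# 8x8 toy image split into 4 quadrant "superpixels": 0 1 / 2 3
segments = np.repeat(np.repeat(np.arange(4).reshape(2, 2), 4, axis=0), 4, axis=1)
image = np.ones((8, 8))


def toy_classifier(batch):
    # Hypothetical model: relies mostly on quadrant 2, a little on quadrant 0
    v0 = batch[:, segments == 0].mean(axis=1)
    v2 = batch[:, segments == 2].mean(axis=1)
    return 0.1 + 0.2 * v0 + 0.6 * v2


intercept, coef = lime_superpixel_weights(image, segments, toy_classifier)
print(round(intercept, 3), np.round(coef, 3))
```

Because the toy classifier is exactly linear in the visibility of each quadrant, the fit recovers its coefficients (about 0.6 for quadrant 2, 0.2 for quadrant 0 and 0 for the others), so quadrant 2 would be the most highlighted superpixel.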

In [None]:
#@title helpers
from lime import lime_image
from skimage.segmentation import mark_boundaries
from torchvision.transforms.functional import to_pil_image
import numpy as np
import torch
import matplotlib.pyplot as plt

# (a) model → probability bridge for LIME
def batch_predict(np_imgs: list[np.ndarray]) -> np.ndarray:
    """
    Accepts a list of H×W×C RGB images in [0,255] (dtype uint8).
    Returns an (N, 2) array of class probabilities for LIME.
    """
    model_densenet.eval()
    with torch.no_grad():
        batch = torch.stack(
            [
                val_transforms(                     # same transforms you used at eval time
                    to_pil_image(img.astype(np.uint8))
                )
                for img in np_imgs
            ],
            dim=0,
        ).to(device)

        logits = model_densenet(batch)              # (N, 1)
        probs_pos = torch.sigmoid(logits)  # (N, 1)  – P(class = 1)
        probs_neg = 1 - probs_pos          # (N, 1)  – P(class = 0)
        probs = torch.cat([probs_neg, probs_pos], dim=1)  # (N, 2)

    return probs.cpu().numpy()

# (b) tensor → uint8 numpy (unnormalised) for LIME visualisation
def tensor_to_uint8(img_tensor: torch.Tensor) -> np.ndarray:
    """
    img_tensor: (3, H, W) – normalised
    Returns  H×W×3 uint8 image in RGB.
    """
    img = img_tensor.cpu().clone().numpy()
    img = img * np.array(std)[:, None, None] + np.array(mean)[:, None, None]
    img = np.clip(img, 0, 1)
    img = (np.transpose(img, (1, 2, 0)) * 255).astype(np.uint8)
    return img
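LIME expects `classifier_fn` to return one probability column per class, while our models output a single logit. `batch_predict` therefore applies a sigmoid to get P(class = 1) and places 1 - P next to it. The same conversion in isolation, as a small NumPy sketch with made-up logits:

```python
import numpy as np


def logits_to_two_class_probs(logits: np.ndarray) -> np.ndarray:
    """Convert (N, 1) binary logits into (N, 2) probabilities [P(class 0), P(class 1)]."""
    p_pos = 1.0 / (1.0 + np.exp(-logits))          # sigmoid -> P(class = 1)
    return np.concatenate([1.0 - p_pos, p_pos], axis=1)


print(logits_to_two_class_probs(np.array([[0.0], [2.0], [-3.0]])))
```

Column 1 is the one that `label=1` refers to when calling `get_image_and_mask`.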

Instantiate the LIME explainer (`lime_image.LimeImageExplainer`) and explain a few test images:

In [None]:
explainer = lime_image.LimeImageExplainer(random_state=42)

num_samples_to_show = 5      # how many images you want to inspect
lime_samples        = 1000   # neighbourhood size: higher = slower but smoother
labels_to_explain   = (1,)   # we care about the "positive" class (index 1)

for i in range(num_samples_to_show):
    img_tensor, true_label = test_ds[i]
    img_uint8              = tensor_to_uint8(img_tensor)

    # LIME explanation
    explanation = explainer.explain_instance(
        image=img_uint8,
        classifier_fn=batch_predict,
        labels=labels_to_explain,
        top_labels=None,          # explain exactly these labels (the default explains the top-5 predicted ones)
        hide_color=0,
        num_samples=lime_samples
    )

    lime_img, lime_mask = explanation.get_image_and_mask(
        label=1,
        positive_only=False,
        hide_rest=False,
        num_features=8,
        min_weight=0.0
    )

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))

    ax1.imshow(img_uint8)
    ax1.set_title(f"Original  –  label = {int(true_label)}")
    ax1.axis("off")

    ax2.imshow(mark_boundaries(lime_img / 255.0, lime_mask))
    ax2.set_title("LIME explanation")
    ax2.axis("off")

    plt.tight_layout()
    plt.show()


LIME divides the image into superpixels, perturbs them, and fits a local surrogate model; the highlighted superpixels are those that most strongly influenced the model’s prediction for the class you’re explaining.

With `positive_only=False`, superpixels that push the prediction towards the positive class are tinted green and those that push against it are tinted red; only the `num_features=8` strongest ones are marked.

### Conclusion
By applying multiple quality control methods, we can better understand what drives the model’s predictions:

- **GradCAM / GradCAM++ / ScoreCAM**: Highlight the regions of the image that contributed most to the model’s decision. These maps are relatively coarse (their resolution is set by the target layer) and show where the model “looks” when predicting a class. GradCAM++ tends to emphasize smaller, more precise areas compared to GradCAM. ScoreCAM avoids gradient computation and often produces smoother, less noisy heatmaps.

- **LIME**: Highlights superpixels that most strongly influence the prediction by perturbing them and observing changes in output. Unlike CAM-based methods, LIME works in a model-agnostic way and can provide complementary explanations.