# Pathology Deep Learning Hands-On

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/KatherLab/stamp_demo/blob/master/stamp_hands_on.ipynb)

Welcome to the 2025 Clinicum Digitale digital pathology hands-on session.
In this session we will have a look at what a typical machine learning workflow in our lab looks like.
We will predict the TP53 gene alteration in breast cancer from histopathologic whole-slide images.

# Prerequisites

## 1. Change runtime type
First, switch to a GPU-enabled Colab runtime: within Google Colab, go to *Runtime* $\to$ *Change runtime type*, and select a GPU hardware accelerator, as indicated in the screenshot below.

<img src="https://github.com/KatherLab/stamp_demo/blob/master/colab_runtime.png?raw=true" width=500 />

## 2. Install dependencies
Here, we will install [STAMP](https://github.com/KatherLab/STAMP), a pipeline for computational pathology developed in [our lab](https://jnkather.github.io/).

In [None]:
!pip install "stamp @ git+https://github.com/KatherLab/STAMP@feature/validation-config"

Due to a weird bug in Google Colab, please now restart the kernel (*Runtime* $\to$ *Restart session*), and continue from here.

In [None]:
import lightning  # if this throws an error, please restart the kernel

## 3. Download data
Let's download our dataset of extracted features. This will take a few minutes.

In [None]:
import requests
from tqdm.notebook import tqdm
from pathlib import Path
import hashlib

_DOWNLOAD_PARTS = {
    "TCGA_BRCA_10x_UNI_features.tar.gz.part_aa": "6ff1600f3dcdc6344d3a5c46eca481c4",
    "TCGA_BRCA_10x_UNI_features.tar.gz.part_ab": "7b4c7bb21ac365ee86be86e10f6e4efa",
}


def md5(fname: str, chunk_size=8192) -> str:
    hash_md5 = hashlib.md5()
    with open(fname, "rb") as f:
        while chunk := f.read(chunk_size):
            hash_md5.update(chunk)
    return hash_md5.hexdigest()


def download(url: str, output_file: Path, checksum: str, chunk_size=1024):
    if output_file.exists():
        if md5(output_file) == checksum:
            print(f"{output_file} already downloaded, skipping...")
            return
        else:
            output_file.unlink()

    resp = requests.get(url, stream=True)
    resp.raise_for_status()  # fail early on HTTP errors instead of saving an error page
    total = int(resp.headers.get("content-length", 0))
    with (
        output_file.open("wb") as f,
        tqdm(
            desc=str(output_file),
            total=total,
            unit="iB",
            unit_scale=True,
            unit_divisor=1024,
        ) as bar,
    ):
        for data in resp.iter_content(chunk_size=chunk_size):
            size = f.write(data)
            bar.update(size)


if not Path("TCGA_BRCA_10x_UNI_features.tar.gz").exists():
    for filename, checksum in _DOWNLOAD_PARTS.items():
        download(
            f"https://github.com/KatherLab/stamp_demo/releases/download/data-release/{filename}",
            Path(filename),
            checksum,
        )

    !cat TCGA_BRCA_10x_UNI_features.tar.gz.part_* > TCGA_BRCA_10x_UNI_features.tar.gz

Now, let's extract the tar archive.

In [None]:
!test -d TCGA_BRCA_10x_UNI_features || \
    (mkdir -p TCGA_BRCA_10x_UNI_features && \
     tar -xzf TCGA_BRCA_10x_UNI_features.tar.gz -C TCGA_BRCA_10x_UNI_features)

As a sanity check, ensure there are exactly `242` files.

In [None]:
!ls TCGA_BRCA_10x_UNI_features | wc -l

# Overview

## From whole slide image to classification output

Our goal is to classify whole slide images (WSIs).
In particular, we want to train a deep learning model that, given a WSI, predicts the presence of a genetic mutation in TP53.
In other words, our model will map a whole slide image (in $\mathbb{R}^{H \times W \times 3}$) to a scalar prediction (in $[0,1]$).
To do this, we follow a multi-step workflow consisting of:
1. split whole slide image into tiles
2. extract features
3. aggregate features
4. classify

The workflow can be summarised as:
$$
\mathbb{R}^{H \times W \times 3}
\xrightarrow{\text{tiling}} \mathbb{R}^{n \times p \times p \times 3}
\xrightarrow{\text{extract features}} \mathbb{R}^{n \times d}
\xrightarrow{\text{aggregate}} \mathbb{R}^{1 \times d}
\xrightarrow{\text{classify}} [0, 1]
$$
where $H,W$ are the dimensions of the original whole slide image,
$p=224$ is the patch size, and
$d=1024$ is the dimensionality of the feature extractor.
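
To make the shapes in this workflow concrete, here is a minimal sketch that traces a toy "slide" through all four steps with NumPy. Everything in it is an assumption made for illustration: the toy sizes, the random linear map standing in for the feature extractor, mean pooling standing in for the learned aggregator, and the random classifier weights. It is not how STAMP implements these steps.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy sizes so the example runs instantly; the notebook uses p=224 and d=1024.
H, W, p, d = 64, 96, 16, 8

wsi = rng.random((H, W, 3))  # stand-in for a whole slide image

# 1. Tiling: cut the image into non-overlapping p x p patches.
rows, cols = H // p, W // p
tiles = (
    wsi[: rows * p, : cols * p]
    .reshape(rows, p, cols, p, 3)
    .swapaxes(1, 2)
    .reshape(rows * cols, p, p, 3)
)

# 2. Feature extraction: a random linear map stands in for the real network.
extractor = rng.normal(size=(p * p * 3, d))
features = tiles.reshape(len(tiles), -1) @ extractor

# 3. Aggregation: mean pooling stands in for the learned aggregator.
slide_embedding = features.mean(axis=0, keepdims=True)

# 4. Classification: a logistic head with random (untrained) weights.
w = rng.normal(size=(d, 1))
logit = (slide_embedding @ w).item()
probability = 1.0 / (1.0 + np.exp(-logit))

print("tiles:", tiles.shape)
print("features:", features.shape)
print("slide embedding:", slide_embedding.shape)
print("prediction:", probability)
```

The printed shapes correspond to $n \times p \times p \times 3$, $n \times d$ and $1 \times d$ from the formula above, with $n = 24$ for these toy sizes.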

<img src="https://github.com/georg-wolflein/good-features/blob/master/assets/overview.png?raw=true" width=800 />


For simplicity, we will only do the **downstream training** part in this notebook.


## The structure of our data

Let's first have a look at our data.
The dataset we are using today consists of three major components:

1. The clini table contains clinical data for each patient
2. The slide table maps each slide to a patient
3. The slide features contain a condensed, machine-learning-ready representation of the slides

### Clini table

The clini table contains clinical information for each patient.
Each row of the clini table describes exactly one patient.

* The column `PATIENT` contains a patient ID in the form `TCGA-site-patient` (`site` tells us which hospital the patient is from)
* The remaining columns contain other clinical information on the patient
  * Among these, the `TP53` column indicates if there is a mutation of TP53. We will try to predict this.

In [None]:
!test -f TCGA-BRCA-DX_CLINI.csv || wget https://raw.githubusercontent.com/KatherLab/stamp_demo/refs/heads/master/TCGA-BRCA-DX_CLINI.csv -q -O TCGA-BRCA-DX_CLINI.csv

import pandas as pd
clini_df = pd.read_csv("TCGA-BRCA-DX_CLINI.csv")
clini_df

### Slide table
We often have multiple slides per patient.
The slide table matches each slide to its patient.
If a patient has multiple slides, they will appear multiple times, once for each slide they have.
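
To illustrate how the two tables relate, the sketch below joins a made-up slide table with a made-up clini table on the `PATIENT` column. The patient IDs, file names and labels are hypothetical; only the column names `PATIENT`, `FILENAME` and `TP53` are taken from this notebook.

```python
import pandas as pd

# Hypothetical toy tables; the real ones are loaded from the CSV files.
toy_clini = pd.DataFrame({"PATIENT": ["P1", "P2"], "TP53": [0, 1]})
toy_slide = pd.DataFrame(
    {"PATIENT": ["P1", "P2", "P2"], "FILENAME": ["a.h5", "b.h5", "c.h5"]}
)

# Each slide row picks up the clinical data of its patient.
# validate="many_to_one" raises if a patient appears twice in the clini table.
merged = toy_slide.merge(toy_clini, on="PATIENT", how="left", validate="many_to_one")

# Number of distinct slides per patient.
slides_per_patient = merged.groupby("PATIENT")["FILENAME"].nunique()

print(merged)
print(slides_per_patient)
```

Patient `P2` has two slides here, so it shows up in two rows of the merged table, and both rows carry the same `TP53` label.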

In [None]:
!test -f TCGA-BRCA-DX_SLIDE.csv || wget https://raw.githubusercontent.com/KatherLab/stamp_demo/refs/heads/master/TCGA-BRCA-DX_SLIDE.csv -q -O TCGA-BRCA-DX_SLIDE.csv
slide_df = pd.read_csv("TCGA-BRCA-DX_SLIDE.csv")
slide_df

In [None]:
for _, row in (
    slide_df.groupby("PATIENT").nunique().value_counts().reset_index().iterrows()
):
    print(f"{row['count']:3d} patients have {row['FILENAME']} slides")

### Features
Finally, we have the slide features themselves.
Since whole slide images are too large to do machine learning on them directly, we first reduce them to a more manageable form with a feature extractor.
The feature extractor is itself a neural network.
It takes a whole slide image and reduces it to a more condensed form.
While the exact mechanism by which it does so is outside the scope of this course, it compresses the size of an input 10-fold, allowing us to use neural networks to analyze it.

Since this process does take quite some time, we have already extracted the features for today's dataset in advance.
Let's have a look at the features for one particular whole slide image.

Below is a thumbnail of this WSI. We removed background areas (shown in red).

![Slide Image](https://raw.githubusercontent.com/KatherLab/stamp_demo/refs/heads/master/TCGA-BH-A0HU-01Z-00-DX1.73B38904-E4F8-4F45-BD75-A27EC833B6DE.jpg)

In [None]:
import h5py

with h5py.File(
    "TCGA_BRCA_10x_UNI_features/TCGA-BH-A0HU-01Z-00-DX1.73B38904-E4F8-4F45-BD75-A27EC833B6DE.h5",
    "r",
) as f:
    feats = f["feats"][:]
    coords = f["coords"][:]
    print("Shape of features array:", feats.shape)
    print("Shape of coordinates array:", coords.shape)


As we can see, we have $n=5597$ feature vectors in this slide. This means that the WSI was split into 5597 patches. From each patch, we extracted a $d=1024$-dimensional feature vector.

#### Features
Let's have a look at one of the feature vectors. As we will see, it consists of 1024 floating point numbers.

In [None]:
import matplotlib.pyplot as plt

print(feats[0])
plt.figure(figsize=(10, 2))
plt.bar(range(1024), feats[0])
plt.xlabel("Feature dimension")
plt.ylabel("Feature value")
pass

#### Coordinates
Let's also visualize the coordinates of the patches. We will see the shape of the WSI. Compare this to the image of the WSI above.


In [None]:
plt.plot(coords[:, 0], coords[:, 1], "o", markersize=1)
plt.xlabel("x coordinate")
plt.ylabel("y coordinate")
plt.gca().invert_yaxis()
plt.axis("equal")
pass

## Splitting our data

We will split our dataset into two subsets: one part for training, one for testing.
Specifically, we will use patients from the largest site as the training set and the second largest site as the test set.

Oftentimes, pathological slides contain artifacts like staining differences that make it possible to infer where slides originate from.
If, for example, certain hospitals have a higher rate of severe cases, the network may base its prediction on these artifacts instead of on medically relevant features.
By ensuring that our testing set is from another site, we will be able to determine if our network is able to generalize to new sites.

**TASK:**
Split `clini_df` and `slide_df` into a train and test set. The train set should contain all patients from the site `"BH"`, and the test set should contain all patients from the site `"A2"`.

You should create four dataframes:
1. `train_clini_df`
2. `train_slide_df`
3. `test_clini_df`
4. `test_slide_df`

Hint: the Patient IDs are in the format `TCGA-site-ID`.
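
As a small nudge, here is one possible way to pull the site code out of IDs in this format with pandas string methods. The IDs below are made up for illustration and are not taken from the dataset; applying the idea to `clini_df` and `slide_df` is left to you.

```python
import pandas as pd

# Made-up IDs in the same TCGA-site-ID format; not taken from the dataset.
toy_ids = pd.Series(["TCGA-BH-0001", "TCGA-A2-0002", "TCGA-BH-0003"])

# Split on "-" and keep the middle part, which is the site code.
toy_sites = toy_ids.str.split("-").str[1]

print(toy_sites.value_counts())
```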

In [None]:
# Insert your code here...


train_clini_df = ...
train_slide_df = ...
test_clini_df = ...
test_slide_df = ...

Let's save the dataframes as CSV.

In [None]:
train_clini_df.to_csv("TCGA-BRCA-DX_CLINI_train.csv", index=False)
train_slide_df.to_csv("TCGA-BRCA-DX_SLIDE_train.csv", index=False)
test_clini_df.to_csv("TCGA-BRCA-DX_CLINI_test.csv", index=False)
test_slide_df.to_csv("TCGA-BRCA-DX_SLIDE_test.csv", index=False)

print(f"Training set: {len(train_clini_df)} patients")
print(f"Testing set: {len(test_clini_df)} patients")
train_clini_df

## Inspecting our data

Before training any models, it is often worthwhile to inspect the data to ensure that there are no glaring problems with it.
Let's look at the TP53 column.

In [None]:
train_clini_df["TP53"].value_counts()

As we can see, one class is less frequent than the other.
This can lead to problems while training our network.
To see why, consider a strongly imbalanced dataset with only two classes, one making up 90% of the dataset.
The network can trivially reach an accuracy of 90% by always choosing the more frequent class.

One approach to combat this is to weight the classes.
In STAMP, the classes are automatically weighted in such a way that each class contributes equally overall.
For a nine-to-one imbalanced two-class dataset, each sample of the rare class would thus be given nine times the weight of a sample of the more common class.
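
The nine-to-one example can be checked with a few lines of NumPy. The sketch below uses inverse-frequency weighting, which is one common scheme with this property; we assume it here for illustration, and STAMP's exact formula may differ. The labels are made up.

```python
import numpy as np

# Hypothetical labels with a nine-to-one imbalance.
labels = np.array([0] * 90 + [1] * 10)

classes, counts = np.unique(labels, return_counts=True)

# Inverse-frequency weights, scaled so the weights of all samples sum to len(labels).
class_weights = len(labels) / (len(classes) * counts)

# Total weight contributed by each class (weight per sample times number of samples).
total_per_class = class_weights * counts

print("weight per sample:", dict(zip(classes.tolist(), class_weights.tolist())))
print("rare / common weight ratio:", class_weights[1] / class_weights[0])
print("total weight per class:", total_per_class)
```

Each rare-class sample ends up with nine times the weight of a common-class sample, so both classes contribute the same total weight to the loss.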

This can of course still lead to instabilities in training, especially if one of the rare classes is one we don't particularly care about.
In that case the network may spend too much time learning how to correctly classify the unimportant class at the cost of more interesting classes.

However, we should be fine here because the imbalance isn't too severe.


# Training a model

We will now train our model. This will take a few minutes.

In [None]:
import os
from pathlib import Path
from stamp.modeling.train import train_categorical_model_
import torch

output_dir = Path("output")
output_dir.mkdir(exist_ok=True, parents=True)

train_categorical_model_(
    output_dir=output_dir,
    clini_table=Path("TCGA-BRCA-DX_CLINI_train.csv"),
    slide_table=Path("TCGA-BRCA-DX_SLIDE_train.csv"),
    feature_dir=Path("TCGA_BRCA_10x_UNI_features"),
    patient_label="PATIENT",
    ground_truth_label="TP53",
    filename_label="FILENAME",
    categories=["0", "1"],
    # Dataset and -loader parameters
    bag_size=512,
    val_bag_size=2048,
    num_workers=min(os.cpu_count() or 1, 16),
    # Training parameters
    batch_size=64,
    max_epochs=64,
    patience=16,
    accelerator="gpu" if torch.cuda.is_available() else "cpu",
    # Experimental features
    use_vary_precision_transform=False,
    use_alibi=False,
)

Let's plot the training and validation loss.

In [None]:
import matplotlib.pyplot as plt

lightning_logs_dir = sorted(output_dir.joinpath("lightning_logs").glob("version_*"))[-1]

history = pd.read_csv(lightning_logs_dir / "metrics.csv")
history = history.groupby(["epoch", "step"]).first().reset_index()
history.plot(x="epoch", y=["training_loss", "validation_loss"])

As we can see, the training loss decreases a lot faster and further than the validation loss.
This is to be expected:
since the network is trained on the training set, it does not only learn to recognize features relevant for classifying the target, but also learns to recognize the training images themselves.

Many of the features the network learns will not generalize.
In general, the longer we train a network, the more likely it is that it will pick up small, non-generalizing details uniquely identifying a single image from the training set.
This is why we have a validation set:
By checking how well the network performs on the validation set, we can determine whether the network is still learning generalizable features.

In [None]:
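# Keep only the rows in which a validation AUROC was logged, and plot against the epoch
auroc_history = history.dropna(subset=["validation_auroc"])
plt.plot(auroc_history["epoch"], auroc_history["validation_auroc"])
plt.title("Validation ROC AUC Score")
plt.xlabel("Epoch")
plt.ylabel("AUROC")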

The same should be visible in the AUROC over the course of training:
initially, the ROC AUC score on the validation set rises sharply while the network learns to recognize well-generalizing features.
Then, as these easy-to-recognize features have been exhausted, improvement quickly becomes slower and stagnates.

If we train for too long, the performance on the validation set may even regress, as the only thing the network is doing during training is learning how to best classify the training set and one way of doing that is to just "memorize" all the specific training samples.
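
A common remedy is early stopping: we keep the model from the epoch with the best validation loss and stop once it has not improved for a fixed number of epochs. The sketch below shows the idea on made-up loss values. The function `best_epoch_with_patience` is a hypothetical helper written for this illustration, and we only assume that the `patience` argument of the training call above works along these lines.

```python
def best_epoch_with_patience(val_losses, patience):
    """Return (best_epoch, stopped_epoch) for a sequence of validation losses."""
    best_epoch, best_loss = 0, float("inf")
    for epoch, loss in enumerate(val_losses):
        if loss < best_loss:
            best_epoch, best_loss = epoch, loss
        elif epoch - best_epoch >= patience:
            return best_epoch, epoch  # no improvement for `patience` epochs
    return best_epoch, len(val_losses) - 1  # ran through all epochs


# Made-up validation losses: improve, plateau, then get worse.
toy_val_losses = [0.70, 0.62, 0.58, 0.57, 0.59, 0.60, 0.63, 0.66]

best, stopped = best_epoch_with_patience(toy_val_losses, patience=3)
print(f"best epoch: {best}, training stopped after epoch: {stopped}")
```

With a larger `patience`, training continues for longer after the best epoch, which costs time but makes it less likely that we stop on a temporary plateau.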

# Deploying our model on external data

**As soon as we use the testing set, we are not allowed to change the experimental setup any more**.

We will now _deploy_ our model on the testing set, that is, see how well it can predict never-before seen data.
This is different from our validation set in that, while the network was not _trained_ on the validation data, we did determine which epoch's model was the best based on the validation set.
Furthermore, parts of the validation set were sampled from the same cohorts as the training set.
We can thus expect the validation set to be more akin to the training set, and thus expect the network to perform better for the validation set than the testing set.

In [None]:
from stamp.modeling.deploy import deploy_categorical_model_

deploy_categorical_model_(
    output_dir=output_dir,
    checkpoint_paths=[output_dir / "model.ckpt"],
    clini_table=Path("TCGA-BRCA-DX_CLINI_test.csv"),
    slide_table=Path("TCGA-BRCA-DX_SLIDE_test.csv"),
    feature_dir=Path("TCGA_BRCA_10x_UNI_features"),
    patient_label="PATIENT",
    ground_truth_label="TP53",
    filename_label="FILENAME",
    num_workers=min(os.cpu_count() or 1, 16),
    accelerator="gpu" if torch.cuda.is_available() else "cpu",
    bag_size=2048
)

Let's have a look at the predictions that the model makes on the test set.

In [None]:
pred_df = pd.read_csv(output_dir / "patient-preds.csv")
pred_df = pred_df[~pred_df.TP53.isna()]  # remove rows with unknown groundtruth
pred_df = pred_df.rename(columns={"TP53_1": "TP53_pred"})[["PATIENT", "TP53", "TP53_pred"]]
pred_df

As you can see, the neural network actually doesn't give us a decision, but a probability for our classes.
Depending on what we use the network for, it may actually be useful to select a higher or lower threshold:
for a screening test for example we may use a very low threshold to ensure that we definitely include all patients that have a specific illness.

One tool that can help us quantify the quality of our classifier is the receiver operating characteristic curve, or ROC curve for short:

In [None]:
from stamp.statistics.roc import plot_single_decorated_roc_curve

plot_single_decorated_roc_curve(
    ax=plt.gca(),
    y_true=(pred_df.TP53 == 1.).values,
    y_score=pred_df.TP53_pred.values,
    n_bootstrap_samples=1000,
    threshold_cmap=None,
    title="TP53",
)

The ROC curve plots the true positive rate (also called sensitivity) against the false positive rate (1 - specificity).
We can evidently force our classifier to have perfect sensitivity by classifying _every_ sample as positive (i.e. setting the classification threshold to 0).
Similarly, we can make its specificity perfect by classifying every sample as negative.
Clearly, these classifiers are not particularly useful.
The ROC curve shows us how sensitivity and specificity trade off against each other for different cutoffs.
The area under that curve (AUC) is often used as a quick-and-easy way to compare classifiers' performance.
For a more detailed explanation, check out [this visual explanation of ROC curves](https://mlu-explain.github.io/roc-auc/).

While ROC curves are one of the major endpoints for judging the quality of a classifier, they have one problem to be aware of, especially if the question of applicability of machine learning in the real world comes up.
A ROC curve only contrasts the model's sensitivity with its specificity.
This means that, as long as the model is able to separate the classes in a dataset, it will maintain a relatively high AUROC.
While this does show that the features learned by our network _are_ transferable, it can still pose problems when actually deploying our model.

Oftentimes, artifacts introduced by the way histopathological slides are prepared in a hospital can consistently affect how the network classifies a sample.
A difference in staining, for example, may consistently cause samples to be scored too highly.
This means that, for example, the threshold we selected to reach a certain sensitivity may not be transferable between our validation and testing set.
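
The following sketch illustrates this effect on purely synthetic scores. We assume a hypothetical second site whose scores are all shifted upwards by the same amount: the ranking of the samples, and therefore the AUROC, stays the same, but sensitivity and specificity at a fixed threshold change noticeably. The helper functions and all numbers are made up for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)


def auroc(y_true, y_score):
    """Probability that a random positive scores higher than a random negative."""
    pos, neg = y_score[y_true], y_score[~y_true]
    return (pos[:, None] > neg[None, :]).mean()


def sensitivity_specificity(y_true, y_score, threshold):
    """Sensitivity and specificity when calling everything >= threshold positive."""
    pred = y_score >= threshold
    sensitivity = (pred & y_true).sum() / y_true.sum()
    specificity = (~pred & ~y_true).sum() / (~y_true).sum()
    return sensitivity, specificity


# Synthetic scores: positives tend to score higher than negatives.
y_true = np.array([False] * 200 + [True] * 200)
site_a = np.concatenate([rng.normal(0.35, 0.1, 200), rng.normal(0.65, 0.1, 200)])
# Hypothetical second site: identical ranking, but every score is shifted upwards.
site_b = site_a + 0.15

threshold = 0.5
for name, scores in [("site A", site_a), ("site B", site_b)]:
    sens, spec = sensitivity_specificity(y_true, scores, threshold)
    print(f"{name}: AUROC={auroc(y_true, scores):.3f} sens={sens:.2f} spec={spec:.2f}")
```

In a situation like this, the threshold would have to be recalibrated for the new site even though the AUROC suggests that the model transfers well.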

# Bonus

If you have time, complete the three challenges below.

*Hint:* Use ChatGPT to help you code --- we use it in the lab all the time:)

## Challenge 1: computing statistics at a threshold

In [None]:
#@title { run: "auto" }
#@markdown Move the slider to see how the threshold (decision boundary) impacts accuracy.
threshold = 0.9 #@param {type:"slider", min:0, max:1, step:0.1}

probabilities = pred_df.TP53_pred.values
predictions = probabilities >= threshold
groundtruth = pred_df.TP53.values == 1.

accuracy_score = (predictions == groundtruth).mean()
print(f"Accuracy: {accuracy_score:.2f}")

Write code in the cell below to compute the following metrics on the test set:
- sensitivity (true positive rate)
- specificity (true negative rate)


*Hints*:
- Use the code for computing accuracy above for inspiration
- Inspect the contents of the `probabilities`, `predictions` and `groundtruth` variables.

In [None]:
true_positives = ...
true_negatives = ...

sensitivity = ...
specificity = ...

print(f"Sensitivity: {sensitivity}")
print(f"Specificity: {specificity}")

## Challenge 2: plot the precision-recall curve

Use `matplotlib` to plot the precision-recall curve.

In [None]:
# Insert code here...

plt.plot()

## Challenge 3: dummy classifier

Earlier we talked about the dangers of imbalanced datasets.
Imagine we had a model that always predicts `TP53 = 1` no matter its input.
Write code to compute how this "dummy classifier" would fare on the test set, regarding:
- accuracy
- balanced accuracy
- sensitivity
- specificity

In [None]:
# write code here...

This concludes our hands-on for deep learning in histopathology.
As you have seen, deep learning can be used to answer a variety of histopathological questions.
However, while current research is promising, there are still a lot of steps remaining to make it a reliable part of medical practice.