# Weakly-Supervised Deep Learning for Cancer Diagnosis in Computational Pathology

- Presenter:
   - Guillaume Jaume (gjaume@bwh.harvard.edu)
   - Postdoctoral Researcher at Harvard Medical School and Brigham and Women's Hospital
- Initially proposed and written by Richard J. Chen (richardchen@g.harvard.edu)

![](https://user-images.githubusercontent.com/10300839/232984533-c22822b8-df80-4b95-80e2-93dde2409bbf.png)

**Definitions:**

- *Computational pathology (CPath):* Computational methods based on the microscopic analysis of cells and tissues for the study of disease.

- *Digital pathology:* A set of tools and systems for the acquisition, management and diagnosis of pathology glass slides in a digital setting.

- *Whole slide image (WSI):* An image obtained by digitizing a glass slide at high-resolution using a scanner.

- *Hematoxylin and eosin (H&E) staining:* The reference stain for histological analysis of tissues, in which hematoxylin colors cell nuclei purple and eosin colors the cytoplasm and extracellular matrix pink.


**Background**:

Computational Pathology aims to automate, assist, and augment the clinical practice of pathology using computational tools based on Artificial Intelligence.

Tissue phenotyping, the characterization of histopathologic features for cancer diagnosis, prognosis, and prediction of treatment response, is a fundamental problem in computational pathology (CPath). Unlike natural images, whole-slide images are a challenging computer vision domain: image resolutions can be as large as $150{,}000 \times 150{,}000$ pixels (>50 GB to load the entire image in RAM).
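
As a rough check on the memory figure above, the sketch below computes the size of a single image at that resolution. It assumes uncompressed 8-bit RGB (3 bytes per pixel); real slide files are compressed and stored as multi-resolution pyramids, so on-disk sizes are smaller.

```python
# Back-of-the-envelope memory footprint of a whole-slide image held in RAM.
# Assumption: uncompressed RGB with 1 byte per channel (3 bytes per pixel).
width, height = 150_000, 150_000
bytes_per_pixel = 3

total_bytes = width * height * bytes_per_pixel
print(f"{total_bytes / 1e9:.1f} GB")     # 67.5 GB (decimal gigabytes)
print(f"{total_bytes / 2**30:.1f} GiB")  # 62.9 GiB (binary gibibytes)
```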

To address this computational and memory bottleneck, the majority of state-of-the-art methods use a three-stage, weakly-supervised pipeline based on multiple instance learning (MIL):
1. Tissue patching at a single magnification objective ("zoom"), e.g., 20x magnification

2. Patch-level feature extraction to construct a set of patch embeddings (compressing each patch by a factor of roughly 100 to 500)

3. Global pooling of embeddings to construct a slide-level representation for weak-supervision using slide-level labels (e.g., subtype, grade, stage, survival, origin).

**Notebook Objective**: The following tutorial aims to distinguish Lung Adenocarcinoma (LUAD, 40% of all lung cancer) from Lung Squamous Cell Carcinoma (LUSC, 30% of all lung cancer); see [Lu et al., Data-efficient and weakly supervised computational pathology on whole-slide images, Nature BME 2021](https://www.nature.com/articles/s41551-020-00682-w) and the [CLAM](https://github.com/mahmoodlab/CLAM) codebase. Specifically, we will:
- Train and evaluate a "naive" MIL algorithm called `AverageMIL`, which takes the average of patch embeddings (as the global pooling operator).

- Implement a more sophisticated algorithm called Attention-Based Multiple Instance Learning (`ABMIL`), which learns attention weights for computing a weighted average of patch embeddings.

- Compare and contrast `AverageMIL` and `ABMIL`, discussing which algorithm performs better and potential limitations.

**About this notebook**:
- Model implementation and training are directly adapted from [CLAM](https://github.com/mahmoodlab/CLAM). CLAM includes many additional features (e.g., letting users set up optimizers, model types, logging information, and other hyper-parameters) that are left out here to keep this notebook as simple as possible to run for teaching purposes. To use all features, please see CLAM.

- Though this notebook is based on CLAM, the method you will be implementing is not CLAM but a different method called ABMIL, from [Ilse et al., Attention-Based Multiple Instance Learning, ICML 2018](https://arxiv.org/abs/1802.04712), from which CLAM is derived.

- Though the pre-extracted features were generated using the CLAM codebase, the encoder was not a truncated ResNet-50 pretrained on ImageNet (dimension 1024) at 20 $\times$ resolution. Instead, we extracted features with a much smaller CNN encoder (dimension 320) at 10 $\times$ resolution, which shrinks the size of the dataset from ~11 GB to ~3.96 GB of storage (**download link for pre-extracted features in the cell below**). In addition, a random seed is set with `torch.manual_seed` for reproducibility (all outputs should be deterministic).
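
The effect of seeding can be checked in isolation. The training cells in this notebook call `torch.manual_seed(2023)`; the minimal sketch below only shows that re-seeding replays the same random numbers on CPU. Some GPU operations may remain non-deterministic even with a fixed seed, which is not covered here.

```python
import torch

# Re-seeding the global RNG replays the same sequence of random numbers.
torch.manual_seed(2023)
a = torch.randn(3)
torch.manual_seed(2023)
b = torch.randn(3)

print(torch.equal(a, b))  # True
```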

### Colab Installation, Data Download, & Dependencies

- Gets clinical metadata csv for tcga-luad and tcga-lusc with predefined train/val/test splits
- Gets pre-extracted features for tcga-luad and tcga-lusc diagnostic WSIs (1043 WSIs total, ~3.96 GB in size, ~67 seconds to download)

Alternatively, you can download the data directly from Dropbox to your local computer and run this Colab Notebook locally.


In [1]:
use_drive = False

if use_drive:
  from google.colab import drive
  drive.mount('/content/drive')
  !mkdir -p "/content/drive/My Drive/ai4healthsummerschool/"

In [None]:
# either download in Colab (data will be deleted when the runtime restarts) or mount your Google Drive (preferred, but requires ~4 GB of storage)
if use_drive:
  !wget https://www.dropbox.com/s/5wuvu791vwntg9o/tcga_lung_splits.csv -P "/content/drive/My Drive/ai4healthsummerschool"
  !wget https://www.dropbox.com/s/euepd2owxvuwr7v/feats_pt.zip
  !unzip -q feats_pt.zip
  !mv feats_pt "/content/drive/My Drive/ai4healthsummerschool"
else:
  !wget https://www.dropbox.com/s/5wuvu791vwntg9o/tcga_lung_splits.csv
  !wget https://www.dropbox.com/s/euepd2owxvuwr7v/feats_pt.zip
  !unzip -q feats_pt.zip


In [1]:
import os
import copy
import matplotlib.pyplot as plt
import seaborn
import numpy as np
import pandas as pd
import sklearn.metrics
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
print(torch.__version__)

2.0.1


### WSI data preprocessing for histology slides in the TCGA-Lung cohort


![](https://user-images.githubusercontent.com/10300839/232984010-7f8e3a6f-e0c5-4847-8d0f-747460055528.png)

To process WSIs, tools such as [CLAM](https://github.com/mahmoodlab/CLAM) are typically used for tissue patching and feature extraction from non-overlapping patches. Though CLAM is easy to use, running it for feature extraction would require downloading the gigapixel WSIs themselves (>1000 WSIs in TCGA-LUAD and TCGA-LUSC), which take up more than 100 GB of storage. To alleviate this issue, this problem set provides pre-extracted features (processed via CLAM, but using a much smaller vision encoder with $D = 320$). To still illustrate how CLAM preprocessing works, the cell below gives a high-level overview of how a WSI is formulated as an `[M x D]`-dim bag of patch embeddings, where `M` is the number of tissue patches and `D` is the hidden dimension size of your encoder. Again, please use [CLAM](https://github.com/mahmoodlab/CLAM) if you are interested in re-generating these features.

**Note:** This cell doesn't need to be run to train the final models!

In [2]:
# Let's say we have a set of M [256 x 256 x 3] image patches, taken from non-overlapping regions of the WSI (M = 2 here to keep the demo fast).
M = 2
X = torch.randn(M, 3, 256, 256) # Arranged in (Batch, Channel, Height, Width) format or (B, C, H, W) for short
print("WSI Shape:", X.shape)

# We would for instance use a CNN model (pretrained on ImageNet) as our vision encoder for pre-extracting "compressed" representations from each patch.
# Note: called without a `weights` argument, as below, torchvision builds the architecture with randomly initialized weights, which is enough to illustrate the tensor shapes.
cnn = torchvision.models.mobilenet_v3_small()
cnn.eval()

# A torchvision classification model outputs un-normalized class scores (logits) for the ImageNet classes (1000 classes total).
# To extract useful features from each patch, we instead use the output of the penultimate layer(s) of the CNN, before the final classifier.
print("Probability Scores for ImageNet:", cnn.forward(X[:1]).shape)

# To extract the penultimate features, we can define a new function that returns the features
# before they are passed to the internal classifier layer within the model.
# Again, we want to use the pretrained features on ImageNet, but don't want the classification scores for "ImageNet" classes!
# See the documentation below for how the forward pass in MobileNetV3 works.
# https://pytorch.org/vision/main/_modules/torchvision/models/mobilenetv3.html#mobilenet_v3_small
encoder = lambda x: torch.flatten(cnn.avgpool(cnn.features(x)), 1)
print("Feature Embedding Shape:", encoder(X[:1]).shape)

# We can now use our encoder to extract features for each patch.
# Typically, the # of non-overlapping patches in a WSI is ~15,000. Thus, we often have to extract patch features in mini-batches.
batch_size = 32
H = []
for bag_idx in range(0, M, batch_size):
    H.append(encoder(X[bag_idx:(bag_idx+batch_size)]).cpu().detach().numpy())
print("Bag Shape", np.vstack(H).shape)

WSI Shape: torch.Size([2, 3, 256, 256])
Probability Scores for ImageNet: torch.Size([1, 1000])
Feature Embedding Shape: torch.Size([1, 576])
Bag Shape (2, 576)
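
The patching step that precedes feature extraction can be sketched in a few lines. The helper `patch_coords` below is hypothetical and written only for illustration (it is not CLAM's API); it enumerates the top-left corners of non-overlapping patches over an image of a given size. Real pipelines additionally discard background patches with a tissue segmentation step, which is not shown.

```python
# Enumerate top-left corners of non-overlapping patches over an image.
# Hypothetical helper for illustration; not CLAM's actual API.
def patch_coords(width, height, patch_size=256):
    return [(x, y)
            for y in range(0, height - patch_size + 1, patch_size)
            for x in range(0, width - patch_size + 1, patch_size)]

coords = patch_coords(1024, 768)
print(len(coords))   # 12 (4 columns x 3 rows)
print(coords[:3])    # [(0, 0), (256, 0), (512, 0)]
```

Each coordinate would correspond to one row of the `[M x D]` bag once the patch at that location has been encoded.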


### Data Exploration

**Note:** This cell doesn't need to be run to train the final models!

In [3]:
# where we downloaded the features and label csv to
use_drive = False
if use_drive:
  feats_dirpath, csv_fpath = '/content/drive/My Drive/ai4healthsummerschool/feats_pt/', '/content/drive/My Drive/ai4healthsummerschool/tcga_lung_splits.csv'
else:
  feats_dirpath, csv_fpath = './data/processed/feats_pt/', './data/processed/tcga_lung_splits.csv'

# label csv matches case_id (patient), slide_id (WSI image filename), and diagnosis (LUAD vs LUSC)
# as well as pre-defined splits (train / val / test)
df = pd.read_csv(csv_fpath)
display(df)
display(df[['split', 'OncoTreeCode']].value_counts())

# extracted feature filenames + slide_id column match
feats_pt_fnames = pd.Series(os.listdir(feats_dirpath))
print("Example filenames for extracted features:", list(feats_pt_fnames[:5]))
print("Overlap of extracted feature filenames + slide_id column:",
      len(set(df['slide_id']).intersection(set(feats_pt_fnames.str[:-3]))))

# statistics about the size of each bag
bag_sizes = []
for e in os.scandir(feats_dirpath):
    feats_pt = torch.load(e.path)    # [M x d]-dim tensor
    bag_sizes.append(feats_pt.shape[0])
print('Mean Bag Size:', np.mean(bag_sizes))
print('Std Bag Size:', np.std(bag_sizes))

Unnamed: 0,case_id,slide_id,tumor_type,OncoTreeSiteCode,main_cancer_type,sex,project_id,Diagnosis,OncoTreeCode,OncoTreeCode_Binarized,split
0,TCGA-73-4676,TCGA-73-4676-01Z-00-DX1.4d781bbc-a45e-4f9d-b6b...,Primary,LUNG,Non-Small Cell Lung Cancer,M,TCGA-LUAD,Lung Adenocarcinoma,LUAD,0,train
1,TCGA-MP-A4T6,TCGA-MP-A4T6-01Z-00-DX1.085C4F5A-DB1B-434A-9D6...,Primary,LUNG,Non-Small Cell Lung Cancer,F,TCGA-LUAD,Lung Adenocarcinoma,LUAD,0,train
2,TCGA-78-7167,TCGA-78-7167-01Z-00-DX1.f79e1a9b-a3eb-4c91-a1f...,Primary,LUNG,Non-Small Cell Lung Cancer,M,TCGA-LUAD,Lung Adenocarcinoma,LUAD,0,train
3,TCGA-L9-A444,TCGA-L9-A444-01Z-00-DX1.88CF6F01-0C1F-4572-81E...,Primary,LUNG,Non-Small Cell Lung Cancer,F,TCGA-LUAD,Lung Adenocarcinoma,LUAD,0,train
4,TCGA-55-8097,TCGA-55-8097-01Z-00-DX1.2f847b65-a5dc-41be-9dd...,Primary,LUNG,Non-Small Cell Lung Cancer,F,TCGA-LUAD,Lung Adenocarcinoma,LUAD,0,train
...,...,...,...,...,...,...,...,...,...,...,...
1038,TCGA-21-A5DI,TCGA-21-A5DI-01Z-00-DX1.E9123261-ADE7-468C-9E9...,Primary,LUNG,Non-Small Cell Lung Cancer,M,TCGA-LUSC,Lung Squamous Cell Carcinoma,LUSC,1,test
1039,TCGA-77-7465,TCGA-77-7465-01Z-00-DX1.25e4b0b4-4948-432f-801...,Primary,LUNG,Non-Small Cell Lung Cancer,M,TCGA-LUSC,Lung Squamous Cell Carcinoma,LUSC,1,test
1040,TCGA-34-8454,TCGA-34-8454-01Z-00-DX1.A2308ED3-E430-4448-853...,Primary,LUNG,Non-Small Cell Lung Cancer,F,TCGA-LUSC,Lung Squamous Cell Carcinoma,LUSC,1,test
1041,TCGA-77-7138,TCGA-77-7138-01Z-00-DX1.8c912762-0829-4692-92a...,Primary,LUNG,Non-Small Cell Lung Cancer,M,TCGA-LUSC,Lung Squamous Cell Carcinoma,LUSC,1,test


split  OncoTreeCode
train  LUAD            433
       LUSC            415
test   LUAD             49
       LUSC             49
val    LUAD             49
       LUSC             48
Name: count, dtype: int64

Example filenames for extracted features: ['TCGA-85-8052-01Z-00-DX1.26b66ae7-b73f-4263-85f7-e82dab5b657b.pt', 'TCGA-85-8353-01Z-00-DX1.2A333CFA-3D8A-41B2-9D08-8AAF431BFE54.pt', 'TCGA-55-6987-01Z-00-DX1.0c52b721-2209-4818-af8f-b22d37e6e81e.pt', 'TCGA-60-2703-01Z-00-DX1.13cdede5-0135-4e05-9478-3b728cad247e.pt', 'TCGA-55-7573-01Z-00-DX1.43a4bbd2-6a3b-4910-9356-2d750a736817.pt']
Overlap of extracted feature filenames + slide_id column: 1043
Mean Bag Size: 3259.9090038314175
Std Bag Size: 2133.97437395412


### Model 1:  AverageMIL

Implemented below is a minimalistic training setup that performs weakly-supervised learning via `AverageMIL` on LUAD vs. LUSC subtyping, using 1043 diagnostic H\&E tissue slides from The Cancer Genome Atlas (the pre-extracted features and the clinical metadata for all case and slide IDs were downloaded during installation).

You can run the cells in the Google Colab Notebook and see how well this algorithm performs in 20 epochs.
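
Before looking at the model, it may help to see what average pooling does to a bag. The NumPy sketch below uses a toy bag with made-up sizes (5 patches, 4 features) and shows that the pooled vector has one entry per feature and does not depend on patch order; every patch contributes with the same weight `1/M`, whether or not it contains tumor.

```python
import numpy as np

rng = np.random.default_rng(0)
H = rng.normal(size=(5, 4))          # toy bag: M = 5 patches, D = 4 features

z = H.mean(axis=0)                   # bag-level embedding, shape (D,)
z_shuffled = H[rng.permutation(5)].mean(axis=0)

print(z.shape)                       # (4,)
print(np.allclose(z, z_shuffled))    # True: patch order does not matter
```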


In [4]:
class AverageMIL(nn.Module):
    def __init__(self, input_dim=320, hidden_dim=64, dropout=0.25, n_classes=2):
        r"""
        AverageMIL, a naive MIL algorithm that average pools all patch features.

        Args:
            input_dim (int): input feature dimension.
            hidden_dim (int): hidden layer dimension.
            dropout (float): Dropout probability.
            n_classes (int): Number of classes.
        """
        super(AverageMIL, self).__init__()
        self.inst_level_fc = nn.Sequential(*[nn.Linear(input_dim, hidden_dim), nn.ReLU(), nn.Dropout(dropout)]) # Fully-Connected Layer, applied "instance-wise" to each embedding
        self.bag_level_classifier = nn.Linear(hidden_dim, n_classes)                                            # Bag-Level Classifier

    def forward(self, H):
        r"""
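        Takes as input a [M x D]-dim bag of patch features (representing a WSI), and outputs: 1) logits for classification, 2) None (average pooling has no attention scores; the second value keeps the return signature consistent with attention-based models).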

        Args:
            H (torch.Tensor): [M x D]-dim bag of patch features (representing a WSI)

        Returns:
            logits (torch.Tensor): [1 x n_classes]-dim tensor of un-normalized logits for classification task.
            None (no attention scores to return)
        """
        H = self.inst_level_fc(H)                   # 1. Preprocesses each "instance-level" embedding to be "hidden-dim"-dim size
        z = H.mean(dim=0).unsqueeze(dim=0)          # 2. Average of Patch Embeddings
        logits = self.bag_level_classifier(z)       # 3. Bag-Level Classifier
        return logits, None


class MILDataset(torch.utils.data.dataset.Dataset):
    r"""
    torch.utils.data.dataset.Dataset object that loads pre-extracted features per WSI from a CSV.

    Args:
        feats_dirpath (str): Path to pre-extracted patch features (assumes that these features are saved as a *.pt object with its corresponding slide_id as the filename)
        csv_fpath (str): Path to CSV file which contains: 1) Case ID, 2) Slide ID, 3) split information (train / val / test), and 4) label columns of interest for classification.
        which_split (str): Split that is used for subsetting the CSV (choices: ['train', 'val', 'test'])
        which_labelcol (str): Name of the CSV column that holds the class label (default: 'OncoTreeCode_Binarized', i.e. 0 = LUAD, 1 = LUSC)
    """
    def __init__(self, feats_dirpath='./data/processed/', csv_fpath='./data/processed/tcga_lung_splits.csv', which_split='train', which_labelcol='OncoTreeCode_Binarized'):
        self.feats_dirpath, self.csv, self.which_labelcol = feats_dirpath, pd.read_csv(csv_fpath), which_labelcol
        self.csv_split = self.csv[self.csv['split']==which_split]

    def __getitem__(self, index):
        features = torch.load(os.path.join(self.feats_dirpath, self.csv_split.iloc[index]['slide_id']+'.pt'))
        label = self.csv_split.iloc[index][self.which_labelcol]
        return features, label

    def __len__(self):
        return self.csv_split.shape[0]


def traineval_epoch(epoch, model, loader, optimizer=None, loss_fn=nn.CrossEntropyLoss(), split='train', device=torch.device("cuda" if torch.cuda.is_available() else "cpu"), verbose=1, print_every=300):
    r"""
    Function that performs one epoch of training / evaluation with torch.nn model over torch.utils.data.DataLoader object.
    Typically, these functions are defined separately for training and validation, but to save line space, we have combined the two.

    Args:
        epoch (int): Current epoch of training / evaluation (used for logging).
        model (torch.nn.Module): MIL model for processing bag of patch features.
        loader (torch.utils.data.DataLoader): Object for getting bag of patch features per WSI.
        optimizer (torch.optim.Optimizer): Optimizer used to update the model (only needed when split == 'train').
        loss_fn (torch.nn.Module): Loss function.
        split (str): Which split, used for setting up model + calculating loss + calculating gradients.
        device (torch.device): Object representing the device on which a torch.Tensor will be allocated.
        verbose (int): Whether to print summary epoch results (verbose >=1) and iteration info (verbose >=2).
        print_every (int): Print iteration info every <print_every> batch iterations (only if verbose >= 2).

    Returns:
        log_dict (dict): Dictionary for logging loss and performance for train / val / test split.
    """
    model.train() if (split == 'train') else model.eval()       # turning on whether model should be used for training or evaluation
    total_loss, Y_probs, labels = 0.0, [], []                   # tracking loss + logits/labels for performance metrics
    for batch_idx, (X_bag, label) in enumerate(loader):
        # Since we assume batch size == 1, we want to prevent torch from collating our bag of patch features as [1 x M x D] torch tensors.
        X_bag, label = X_bag[0].to(device), label.to(device)

        if (split == 'train'):
            logits, A_norm = model(X_bag)
            loss = loss_fn(logits, label)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
        else:
            with torch.no_grad(): logits, A_norm = model(X_bag)
            loss = loss_fn(logits, label)

        # Track total loss, logits, and current progress
        total_loss += loss.item()
        Y_probs.append(torch.softmax(logits, dim=-1).cpu().detach().numpy())
        labels.append(label.cpu().detach().numpy())
        if ((batch_idx + 1) % print_every == 0) and (verbose >= 2):
            print(f'Epoch {epoch}:\t Batch {batch_idx}\t Avg Loss: {total_loss / (batch_idx+1):.04f}\t Label: {label.item()}\t Bag Size: {X_bag.shape[0]}')

    # Compute balanced accuracy and AUC-ROC from saved class probabilities / labels
    Y_probs, labels = np.vstack(Y_probs), np.concatenate(labels)
    log_dict = {f'{split} loss': total_loss/len(loader),
                f'{split} acc': sklearn.metrics.balanced_accuracy_score(labels, Y_probs.argmax(axis=1)),
                f'{split} auc': sklearn.metrics.roc_auc_score(labels, Y_probs[:, 1])}

    # Print out end-of-epoch information
    if (verbose >= 1):
        print(f'### ({split.capitalize()} Summary) ###')
        print(f'Epoch {epoch}:\t' + f'\t'.join([f'{k.capitalize().rjust(10)}: {log_dict[k]:.04f}' for k,v in log_dict.items()]))
    return log_dict
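
The `acc` entry in `log_dict` is a balanced accuracy, i.e. the average of the per-class recalls, which differs from plain accuracy when classes are imbalanced. The sketch below recomputes it by hand in NumPy on made-up labels and predictions (the values are hypothetical and chosen only to make the difference visible).

```python
import numpy as np

# Toy labels and predictions (hypothetical values, for illustration only).
labels = np.array([0, 0, 0, 0, 1, 1])
preds = np.array([0, 0, 0, 1, 1, 0])

# Recall per class: fraction of samples of that class predicted correctly.
recalls = [float(np.mean(preds[labels == c] == c)) for c in np.unique(labels)]
balanced_acc = float(np.mean(recalls))
plain_acc = float(np.mean(preds == labels))

print(recalls)        # [0.75, 0.5]
print(balanced_acc)   # 0.625
print(round(plain_acc, 3))
```

Here plain accuracy (4 of 6 correct) is higher than balanced accuracy because the majority class is predicted more reliably than the minority class.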

In [5]:
# Sets the random seed (for reproducibility)
torch.manual_seed(2023)

# Get data loaders for train-val-test split evaluation
feats_dirpath, csv_fpath = './data/processed/feats_pt/', './data/processed/tcga_lung_splits.csv'
loader_kwargs = {
    'batch_size': 1,
    'num_workers': 0,
    'pin_memory': False
}
train_dataset, val_dataset, test_dataset = [MILDataset(feats_dirpath, csv_fpath, which_split=split) for split in ['train', 'val', 'test']]
train_loader = torch.utils.data.DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = torch.utils.data.DataLoader(val_dataset, shuffle=False, **loader_kwargs)
test_loader = torch.utils.data.DataLoader(test_dataset, shuffle=False, **loader_kwargs)

# Get model, optimizer, and loss function
device = torch.device('cpu')
model = AverageMIL(input_dim=320, hidden_dim=64).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)
loss_fn = nn.CrossEntropyLoss()

# Set-up train-validation loop and early stopping
num_epochs, min_early_stopping, patience, counter = 20, 10, 5, 0
lowest_val_loss, best_model = np.inf, None
all_train_logs, all_val_logs = [], []
for epoch in range(num_epochs):
    train_log = traineval_epoch(epoch, model, train_loader, optimizer=optimizer, split='train', device=device, verbose=2, print_every=200)
    val_log = traineval_epoch(epoch, model, val_loader, optimizer=None, split='val', device=device, verbose=1)
    val_loss = val_log['val loss']
    # Early stopping: If validation loss does not go down for <patience> epochs after <min_early_stopping> epochs, stop model training early
    if (epoch > min_early_stopping):
        if (val_loss < lowest_val_loss):
            print(f'Resetting early-stopping counter: {lowest_val_loss:.04f} -> {val_loss:.04f}...')
            lowest_val_loss, counter, best_model = val_loss, 0, copy.deepcopy(model)
        else:
            print(f'Early-stopping counter updating: {counter}/{patience} -> {counter+1}/{patience}...')
            counter += 1

    if counter >= patience: break
    print()

# Report best model (lowest validation loss) on test split
best_model = model if (best_model is None) else best_model
test_log = traineval_epoch(epoch, best_model, test_loader, optimizer=None, split='test', device=device, verbose=1)


Epoch 0:	 Batch 199	 Avg Loss: 0.7174	 Label: 1	 Bag Size: 957
Epoch 0:	 Batch 399	 Avg Loss: 0.7104	 Label: 1	 Bag Size: 1131
Epoch 0:	 Batch 599	 Avg Loss: 0.7104	 Label: 1	 Bag Size: 1342
Epoch 0:	 Batch 799	 Avg Loss: 0.7077	 Label: 1	 Bag Size: 2319
### (Train Summary) ###
Epoch 0:	Train loss: 0.7071	 Train acc: 0.5226	 Train auc: 0.5201
### (Val Summary) ###
Epoch 0:	  Val loss: 0.6809	   Val acc: 0.5731	   Val auc: 0.6832

Epoch 1:	 Batch 199	 Avg Loss: 0.6903	 Label: 0	 Bag Size: 3735
Epoch 1:	 Batch 399	 Avg Loss: 0.6977	 Label: 0	 Bag Size: 4140
Epoch 1:	 Batch 599	 Avg Loss: 0.6950	 Label: 1	 Bag Size: 4322
Epoch 1:	 Batch 799	 Avg Loss: 0.6915	 Label: 0	 Bag Size: 3104
### (Train Summary) ###
Epoch 1:	Train loss: 0.6906	 Train acc: 0.5377	 Train auc: 0.5630
### (Val Summary) ###
Epoch 1:	  Val loss: 0.6722	   Val acc: 0.6305	   Val auc: 0.7105

Epoch 2:	 Batch 199	 Avg Loss: 0.6762	 Label: 1	 Bag Size: 933
Epoch 2:	 Batch 399	 Avg Loss: 0.6767	 Label: 1	 Bag Size: 1869
...
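
The early-stopping rule in the training loop above can be isolated and tested on its own. The function `early_stopping_epoch` below is a hypothetical helper that mirrors the same logic on a list of made-up validation losses. Note that, as in the loop, a best model is only recorded for epochs strictly after `min_early_stopping`.

```python
def early_stopping_epoch(val_losses, min_early_stopping=10, patience=5):
    """Return (epoch at which the loop stops, epoch of the recorded best model)."""
    lowest, counter, best_epoch = float("inf"), 0, None
    for epoch, val_loss in enumerate(val_losses):
        if epoch > min_early_stopping:
            if val_loss < lowest:
                lowest, counter, best_epoch = val_loss, 0, epoch
            else:
                counter += 1
        if counter >= patience:
            return epoch, best_epoch
    return len(val_losses) - 1, best_epoch

# Hypothetical validation losses: improve until epoch 12, then plateau.
losses = [0.70 - 0.01 * e for e in range(13)] + [0.60] * 7
print(early_stopping_epoch(losses))  # (17, 12)
```

With these values the counter starts increasing at epoch 13 and reaches the patience of 5 at epoch 17, so the model from epoch 12 would be the one evaluated on the test split.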

### Model 2. Implement Attention-Based Multiple Instance Learning (ABMIL)

Following your experimentation with `AverageMIL`, you are ready to implement a more sophisticated model for LUAD vs. LUSC subtyping. Formally, let $\mathbf{H}=\left\{\mathbf{h}_1, \ldots, \mathbf{h}_M\right\} \in \mathbb{R}^{M \times D}$  be a bag of $M$ patch embeddings, with each embedding having dimension size $D$. Ilse et al. 2018 proposed the following attention-based MIL pooling operation:

$$
\mathbf{z} =\sum_{i=1}^M a_i \mathbf{h}_i, \quad \text{where} \enspace a_i=\frac{\exp \left\{\mathbf{w}^{\top}\left(\tanh \left(\mathbf{V h}_{i} ^ { \top }\right) \odot \operatorname{sigm}\left(\mathbf{U h}_i^{\top}\right)\right)\right\}}{\sum_{j=1}^M \exp \left\{\mathbf{w}^{\top}\left(\tanh \left(\mathbf{V} \mathbf{h}_j^{\top}\right) \odot \operatorname{sigm}\left(\mathbf{U h}_j^{\top}\right)\right)\right\}}
$$

where $\mathbf{w} \in \mathbb{R}^{L \times 1}$, $\mathbf{V} \in \mathbb{R}^{L \times D}$, and $\mathbf{U} \in \mathbb{R}^{L \times D}$ are learnable neural network parameters (implemented as fully-connected layers), and $\mathbf{z} \in \mathbb{R}^{D}$ is the weighted average of all patch embeddings in $\mathbf{H}$. The hyperbolic tangent $\tanh (\cdot)$ element-wise non-linearity and sigmoid non-linearity are utilized for proper gradient flow.
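
As a sanity check on the formula above, consider the special case $\mathbf{w} = \mathbf{0}$ (an illustrative assumption, not a setting used in this notebook). Every exponent is then zero, so all patches receive the same weight and the pooling reduces to the plain average used by `AverageMIL`:

$$
a_i = \frac{\exp \{0\}}{\sum_{j=1}^M \exp \{0\}} = \frac{1}{M}, \qquad \mathbf{z} = \sum_{i=1}^M \frac{1}{M} \mathbf{h}_i = \frac{1}{M} \sum_{i=1}^M \mathbf{h}_i
$$

Attention-based pooling can therefore be read as a learned generalization of average pooling.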

In PyTorch, the mathematical expression for computing $a_i$ is implemented as the `torch.nn` module `AttentionTanhSigmoidGating`, which we use as a layer in `ABMIL` for calculating the weighted average of patch embeddings.
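
To make the pooling formula concrete, here is a minimal NumPy sketch of the same computation on a toy bag. The sizes and the random matrices standing in for the learned parameters $\mathbf{V}$, $\mathbf{U}$, and $\mathbf{w}$ are assumptions for illustration; biases and dropout, which the PyTorch layers include, are left out.

```python
import numpy as np

rng = np.random.default_rng(0)
M, D, L = 6, 4, 3                      # toy sizes: patches, feature dim, attention dim
H = rng.normal(size=(M, D))            # bag of patch embeddings
V = rng.normal(size=(L, D))            # random stand-ins for learned parameters
U = rng.normal(size=(L, D))
w = rng.normal(size=(L,))

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

scores = (np.tanh(H @ V.T) * sigmoid(H @ U.T)) @ w   # raw scores, shape (M,)
scores -= scores.max()                                # stabilise the softmax
a = np.exp(scores) / np.exp(scores).sum()             # attention weights, sum to 1
z = a @ H                                             # weighted average, shape (D,)

print(a.shape, z.shape)        # (6,) (4,)
print(np.isclose(a.sum(), 1))  # True
```

Subtracting the maximum score before exponentiating does not change the softmax output; it only avoids overflow for large scores.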


In [6]:
class AttentionTanhSigmoidGating(nn.Module):
    def __init__(self, D=64, L=64, dropout=0.25):
        r"""
        Global attention pooling layer with tanh non-linearity and sigmoid gating (Ilse et al. 2018).

        Args:
            D (int): input feature dimension.
            L (int): hidden layer dimension. Renamed from M in Ilse et al. 2018, as M is used here for the # of patch embeddings in a WSI.
            dropout (float): Dropout probability.

        Returns:
            A_norm (torch.Tensor): [M x 1]-dim tensor of normalized attention scores (sum to 1)
        """
        super(AttentionTanhSigmoidGating, self).__init__()
        self.tanhV = nn.Sequential(*[nn.Linear(D, L), nn.Tanh(), nn.Dropout(dropout)])
        self.sigmU = nn.Sequential(*[nn.Linear(D, L), nn.Sigmoid(), nn.Dropout(dropout)])
        self.w = nn.Linear(L, 1)

    def forward(self, H):
        A_raw = self.w(self.tanhV(H).mul(self.sigmU(H))) # exponent term
        A_norm = F.softmax(A_raw, dim=0)                 # apply softmax to normalize weights to 1
        assert abs(A_norm.sum() - 1) < 1e-3              # Assert statement to check sum(A) ~= 1
        return A_norm


class ABMIL(nn.Module):
    def __init__(self, input_dim=320, hidden_dim=64, dropout=0.25, n_classes=2):
        r"""
        Attention-Based Multiple Instance Learning (Ilse et al. 2018).

        Args:
            input_dim (int): input feature dimension.
            hidden_dim (int): hidden layer dimension.
            dropout (float): Dropout probability.
            n_classes (int): Number of classes.
        """
        super(ABMIL, self).__init__()
        self.inst_level_fc = nn.Sequential(*[nn.Linear(input_dim, hidden_dim), nn.ReLU(), nn.Dropout(dropout)]) # Fully-Connected Layer, applied "instance-wise" to each embedding
        self.global_attn = AttentionTanhSigmoidGating(L=hidden_dim, D=hidden_dim)                              # Attention Function
        self.bag_level_classifier = nn.Linear(hidden_dim, n_classes)                                            # Bag-Level Classifier

    def forward(self, X: torch.Tensor):
        r"""
        Takes as input a [M x D]-dim bag of patch features (representing a WSI), and outputs: 1) logits for classification, 2) normalized attention scores.

        Args:
            X (torch.Tensor): [M x D]-dim bag of patch features (representing a WSI)

        Returns:
            logits (torch.Tensor): [1 x n_classes]-dim tensor of un-normalized logits for classification task.
            A_norm (torch.Tensor): [M x 1]-dim tensor of normalized attention scores (sum to 1).
        """
        H_inst = self.inst_level_fc(X)         # 1. Process each feature embedding to be of size "hidden-dim"
        A_norm = self.global_attn(H_inst)      # 2. Get normalized attention scores for each embedding (s.t. sum(A_norm) ~= 1)
        z = torch.sum(A_norm * H_inst, dim=0)  # 3. Output of global attention pooling over the bag
        logits = self.bag_level_classifier(z).unsqueeze(dim=0)   # 4. Get un-normalized logits for classification task
        try:
            assert logits.shape == (1,2)
        except AssertionError:
            print(f"Logit tensor shape is not formatted correctly. Should output [1 x 2] shape, but got {logits.shape} shape")
        return logits, A_norm


In [7]:

# Sets the random seed (for reproducibility)
torch.manual_seed(2023)

# Get data loaders for train-val-test split evaluation
if use_drive:
  feats_dirpath = '/content/drive/My Drive/ai4healthsummerschool/feats_pt/'
  csv_fpath = '/content/drive/My Drive/ai4healthsummerschool/tcga_lung_splits.csv'
else:
  feats_dirpath, csv_fpath = './data/processed/feats_pt/', './data/processed/tcga_lung_splits.csv'

display(pd.read_csv(csv_fpath).head(10)) # visualize data
loader_kwargs = {'batch_size': 1, 'num_workers': 0, 'pin_memory': False} # Batch size set to 1 due to variable bag sizes. Hard to collate.
train_dataset, val_dataset, test_dataset = [MILDataset(feats_dirpath, csv_fpath, which_split=split) for split in ['train', 'val', 'test']]
train_loader = torch.utils.data.DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = torch.utils.data.DataLoader(val_dataset, shuffle=False, **loader_kwargs)
test_loader = torch.utils.data.DataLoader(test_dataset, shuffle=False, **loader_kwargs)

# Get model, optimizer, and loss function
device = torch.device('cpu')
model = ABMIL(input_dim=320, hidden_dim=64).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)
loss_fn = nn.CrossEntropyLoss()

# Set-up train-validation loop and early stopping
num_epochs, min_early_stopping, patience, counter = 20, 10, 5, 0
lowest_val_loss, best_model = np.inf, None
all_train_logs, all_val_logs = [], [] # TODO: do something with train_log / val_log every epoch to help visualize performance curves?
for epoch in range(num_epochs):
    train_log = traineval_epoch(epoch, model, train_loader, optimizer=optimizer, split='train', device=device, verbose=2, print_every=200)
    val_log = traineval_epoch(epoch, model, val_loader, optimizer=None, split='val', device=device, verbose=1)
    val_loss = val_log['val loss']

    # Early stopping: If validation loss does not go down for <patience> epochs after <min_early_stopping> epochs, stop model training early
    if (epoch > min_early_stopping):
        if (val_loss < lowest_val_loss):
            print(f'Resetting early-stopping counter: {lowest_val_loss:.04f} -> {val_loss:.04f}...')
            lowest_val_loss, counter, best_model = val_loss, 0, copy.deepcopy(model)
        else:
            print(f'Early-stopping counter updating: {counter}/{patience} -> {counter+1}/{patience}...')
            counter += 1

    if counter >= patience: break
    print()

# Report best model (lowest validation loss) on test split
best_model = model if (best_model is None) else best_model
test_log = traineval_epoch(epoch, best_model, test_loader, optimizer=None, split='test', device=device, verbose=1)



Unnamed: 0,case_id,slide_id,tumor_type,OncoTreeSiteCode,main_cancer_type,sex,project_id,Diagnosis,OncoTreeCode,OncoTreeCode_Binarized,split
0,TCGA-73-4676,TCGA-73-4676-01Z-00-DX1.4d781bbc-a45e-4f9d-b6b...,Primary,LUNG,Non-Small Cell Lung Cancer,M,TCGA-LUAD,Lung Adenocarcinoma,LUAD,0,train
1,TCGA-MP-A4T6,TCGA-MP-A4T6-01Z-00-DX1.085C4F5A-DB1B-434A-9D6...,Primary,LUNG,Non-Small Cell Lung Cancer,F,TCGA-LUAD,Lung Adenocarcinoma,LUAD,0,train
2,TCGA-78-7167,TCGA-78-7167-01Z-00-DX1.f79e1a9b-a3eb-4c91-a1f...,Primary,LUNG,Non-Small Cell Lung Cancer,M,TCGA-LUAD,Lung Adenocarcinoma,LUAD,0,train
3,TCGA-L9-A444,TCGA-L9-A444-01Z-00-DX1.88CF6F01-0C1F-4572-81E...,Primary,LUNG,Non-Small Cell Lung Cancer,F,TCGA-LUAD,Lung Adenocarcinoma,LUAD,0,train
4,TCGA-55-8097,TCGA-55-8097-01Z-00-DX1.2f847b65-a5dc-41be-9dd...,Primary,LUNG,Non-Small Cell Lung Cancer,F,TCGA-LUAD,Lung Adenocarcinoma,LUAD,0,train
5,TCGA-44-8119,TCGA-44-8119-01Z-00-DX1.1EBEBFA7-22DB-4365-9DF...,Primary,LUNG,Non-Small Cell Lung Cancer,M,TCGA-LUAD,Lung Adenocarcinoma,LUAD,0,train
6,TCGA-49-AAR2,TCGA-49-AAR2-01Z-00-DX1.1F09F896-446E-4C55-8D0...,Primary,LUNG,Non-Small Cell Lung Cancer,M,TCGA-LUAD,Lung Adenocarcinoma,LUAD,0,train
7,TCGA-L9-A743,TCGA-L9-A743-01Z-00-DX1.27ED2955-E8B5-4A3C-ADA...,Primary,LUNG,Non-Small Cell Lung Cancer,M,TCGA-LUAD,Lung Adenocarcinoma,LUAD,0,train
8,TCGA-99-8032,TCGA-99-8032-01Z-00-DX1.7380b78f-ea25-43e0-ac9...,Primary,LUNG,Non-Small Cell Lung Cancer,M,TCGA-LUAD,Lung Adenocarcinoma,LUAD,0,train
9,TCGA-55-6972,TCGA-55-6972-01Z-00-DX1.0b441ad0-c30f-4f63-849...,Primary,LUNG,Non-Small Cell Lung Cancer,M,TCGA-LUAD,Lung Adenocarcinoma,LUAD,0,train


Epoch 0:	 Batch 199	 Avg Loss: 0.7337	 Label: 0	 Bag Size: 507
Epoch 0:	 Batch 399	 Avg Loss: 0.7257	 Label: 0	 Bag Size: 1163
Epoch 0:	 Batch 599	 Avg Loss: 0.7049	 Label: 1	 Bag Size: 4417
Epoch 0:	 Batch 799	 Avg Loss: 0.7070	 Label: 0	 Bag Size: 1598
### (Train Summary) ###
Epoch 0:	Train loss: 0.7061	 Train acc: 0.5028	 Train auc: 0.5272
### (Val Summary) ###
Epoch 0:	  Val loss: 0.6894	   Val acc: 0.4989	   Val auc: 0.6361

Epoch 1:	 Batch 199	 Avg Loss: 0.7111	 Label: 1	 Bag Size: 5048
Epoch 1:	 Batch 399	 Avg Loss: 0.7017	 Label: 0	 Bag Size: 3107
Epoch 1:	 Batch 599	 Avg Loss: 0.6974	 Label: 0	 Bag Size: 488
Epoch 1:	 Batch 799	 Avg Loss: 0.6935	 Label: 0	 Bag Size: 709
### (Train Summary) ###
Epoch 1:	Train loss: 0.6935	 Train acc: 0.5581	 Train auc: 0.5682
### (Val Summary) ###
Epoch 1:	  Val loss: 0.6819	   Val acc: 0.5793	   Val auc: 0.6658

Epoch 2:	 Batch 199	 Avg Loss: 0.6859	 Label: 0	 Bag Size: 3900
Epoch 2:	 Batch 399	 Avg Loss: 0.6848	 Label: 0	 Bag Size: 2445
...
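
The training loop above initializes `all_train_logs` and `all_val_logs` but never fills them. One possible way to use them, sketched below, is to append `train_log` and `val_log` at the end of every epoch and then regroup the entries per metric. The log values here are made up for illustration; only the dictionary keys follow the format returned by `traineval_epoch`.

```python
# Hypothetical per-epoch logs in the same format as the dictionaries
# returned by traineval_epoch (values are made up for illustration).
all_val_logs = [
    {"val loss": 0.69, "val acc": 0.50, "val auc": 0.64},
    {"val loss": 0.68, "val acc": 0.58, "val auc": 0.67},
    {"val loss": 0.66, "val acc": 0.63, "val auc": 0.71},
]

# Turn the list of dicts into one list per metric, indexed by epoch.
curves = {key: [log[key] for log in all_val_logs] for key in all_val_logs[0]}
best_epoch = min(range(len(curves["val loss"])), key=curves["val loss"].__getitem__)

print(curves["val auc"])   # [0.64, 0.67, 0.71]
print(best_epoch)          # 2
```

Each list in `curves` can then be passed to `plt.plot` to draw a performance curve over epochs.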

### Discussion. Compare and Contrast AverageMIL and ABMIL

Compare and contrast the **validation** and **test** performance of `AverageMIL` and `ABMIL`. In particular:

2. Which model performed better on overall AUC and balanced accuracy on the **test split**? Which class (LUAD or LUSC) was more prone to mis-classification by each model?
3. The following link at [http://clam.mahmoodlab.org](http://clam.mahmoodlab.org) visualizes high-attention heatmaps for LUAD vs LUSC subtyping via CLAM (similar to `ABMIL`) and confidence scores for each slide. If you were a clinical pathologist looking at these visualizations, what insights or concerns would you have in letting an AI algorithm assist your medical diagnoses?
4. The experimental setup in this problem set is limited to only evaluating on data from TCGA. List three techniques used in Lu et al. 2021 (or other relevant biomedical imaging $\times$ AI studies) that could be used in assessing 1) data efficiency, 2) generalization performance, and 3) concordance of attention-based interpretability of `ABMIL`.
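
When thinking about attention-based interpretability, a simple starting point is to rank patches by their attention score. The sketch below does this in NumPy for a hypothetical vector of normalized attention scores; mapping the selected indices back to regions of the slide would additionally require the patch coordinates from preprocessing, which are assumed to be available and are not shown.

```python
import numpy as np

# Hypothetical normalized attention scores for a bag of 8 patches.
A_norm = np.array([0.02, 0.30, 0.05, 0.01, 0.25, 0.07, 0.20, 0.10])

k = 3
top_idx = np.argsort(A_norm)[::-1][:k]   # indices sorted by descending attention
top_share = float(A_norm[top_idx].sum()) # share of attention on the top-k patches

print(top_idx.tolist())                  # [1, 4, 6]
print(round(top_share, 2))               # 0.75
```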

In [8]:
# save best_model for next session
if use_drive:
  ckpt_dirpath = '/content/drive/My Drive/ai4healthsummerschool/'
else:
  ckpt_dirpath = './data/checkpoints/'
os.makedirs(ckpt_dirpath, exist_ok=True)  # torch.save does not create missing directories
torch.save(best_model.state_dict(), os.path.join(ckpt_dirpath, 'abmil.ckpt'))
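
For the next session, a saved state dict can be loaded back into a freshly constructed model of the same architecture. The sketch below shows the round trip with a small `nn.Linear` standing in for the trained model and a temporary directory standing in for the checkpoint folder; both are assumptions made so that the example runs on its own.

```python
import os
import tempfile

import torch
import torch.nn as nn

model = nn.Linear(4, 2)                       # stand-in for the trained ABMIL model
with tempfile.TemporaryDirectory() as tmpdir:
    ckpt_path = os.path.join(tmpdir, "abmil.ckpt")
    torch.save(model.state_dict(), ckpt_path)

    restored = nn.Linear(4, 2)                # must match the saved architecture
    restored.load_state_dict(torch.load(ckpt_path, map_location="cpu"))

print(torch.equal(model.weight, restored.weight))  # True
```

In the notebook itself, the restored model would be created with `ABMIL(input_dim=320, hidden_dim=64)` and switched to evaluation mode with `.eval()` before inference.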
