# ### TODO: Outline image here###

# Table of contents

1. [Libraries & Environment](#Libraries-&-Environment)
1. [Data Preprocessing](#Data-Preprocessing)
    1. Tiling
    1. Filtering out background tiles
    1. Macenko normalization
    1. Tumor detection
1. [Training Deep Learning Models](#Training-Deep-Learning-Models)
    1. Data splitting
    1. Model and data loading
    1. Common hardware bottlenecks
    1. Real-time performance monitoring
    1. Misc.
1. [Evaluating Performance](#Evaluating-Performance)
    1. Patient-level vs. tile-level evaluation
    1. AUROC vs. accuracy
    1. On improving performance
1. [Visualizing Results](#Visualizing-Results)
    1. TODO: outline

# Libraries & Environment

The base environment that I use can be installed using the create_conda_env.sh bash script.

NB: As of June 2021, OpenSlide on Linux does not work correctly with some image types due to a broken dependency (I've noticed this problem for .mrxs images in particular). To repair this issue, install version 0.40.0 of the pixman library (this is done automatically by the create_conda_env.sh script). If your slide images look like the image below, or throw an error when you view them, try this solution.

TODO: insert image

In [18]:
import numpy as np
from openslide import OpenSlide, OpenSlideError
import pandas as pd
from pathlib import Path
from PIL import Image
import re
from scipy import ndimage
from sklearn.model_selection import train_test_split
import shutil
import time
import tqdm
import traceback
import warnings

# Pytorch imports
import torch
from torch.utils import data
from torchvision import datasets, models, transforms

# Custom imports
from library.MacenkoNormalizer import MacenkoNormalizer
from library.model_utils import load_saved_model_for_inference, load_model_arch

# DEVICE determines which GPU (or CPU) the deep learning code will be run on.
# The first GPU is used when CUDA is available; otherwise we fall back to the CPU
# so that the notebook still runs (slowly) on machines without a GPU.
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Data Preprocessing

In order to prepare the WSI images for deep learning training and inference, a number of preprocessing steps must be applied:

1. Images are broken into many small tiles (usually 256x256 microns)
1. Tiles are filtered to exclude non-tissue background regions
1. Tiles are Macenko-normalized
1. Tiles are filtered to exclude non-tumorous tissue regions

These steps are laid out in example code below. However, when applying this pipeline at scale, the implementation should include multiprocessing and/or CuPy (for Macenko normalization) as these additions provide enormous speedups.

In [2]:
# MICRONS_PER_TILE defines the tile edge length used when breaking WSIs into smaller images
MICRONS_PER_TILE = 256.

# Initialize the Macenko Normalizer (the reference image file is closed once it has been read)
with Image.open('library/macenko_reference_img.png') as reference_file:
    reference_img = np.array(reference_file.convert('RGB'))
normalizer = MacenkoNormalizer()
normalizer.fit(reference_img)

# Find all WSIs and check for errors opening the file or finding the microns-per-pixel values
base_path = Path('WSIs')
base_save_path = Path('tiled_WSIs')
wsi_paths = base_path.rglob('*.svs')
save_paths = []
wsi_paths_to_normalize = []
total_num_tiles = 0
for wsi_path in wsi_paths:
    # Mirror the slide's location below base_path, dropping only the final '.svs' suffix
    save_path = base_save_path / wsi_path.relative_to(base_path).with_suffix('')

    # Finished.txt is written once a slide has been fully tiled, so there is no need
    # to open slides that already have one
    if (save_path / 'Finished.txt').exists():
        print(f'Ignoring {wsi_path}, as it has already been processed.')
        continue

    try:
        with OpenSlide(str(wsi_path)) as wsi:
            pixels_per_tile_x = int(MICRONS_PER_TILE / float(wsi.properties['openslide.mpp-x']))
            pixels_per_tile_y = int(MICRONS_PER_TILE / float(wsi.properties['openslide.mpp-y']))
            slide_width, slide_height = wsi.dimensions
    except OpenSlideError:
        print(f'Ignoring {wsi_path}, as it cannot be opened by OpenSlide.')
        continue
    except KeyError:
        print(f'Ignoring {wsi_path}, as it does not have a defined microns-per-pixel value')
        continue

    wsi_paths_to_normalize.append(wsi_path)
    save_paths.append(save_path)
    save_path.mkdir(parents=True, exist_ok=True)
    # Same tile grid as the tiling loop: one tile-width border is left out on every side
    total_num_tiles += (
        len(range(pixels_per_tile_x, slide_width - pixels_per_tile_x, pixels_per_tile_x)) *
        len(range(pixels_per_tile_y, slide_height - pixels_per_tile_y, pixels_per_tile_y)))

print(f'Masking and normalizing {total_num_tiles} tiles from {len(wsi_paths_to_normalize)} whole slide images.')

Ignoring WSIs/MSS/TCGA-4N-A93T-01Z-00-DX1.82E240B1-22C3-46E3-891F-0DCE35C43F8B.svs, as it has already been processed.
Ignoring WSIs/MSS/TCGA-3L-AA1B-01Z-00-DX1.8923A151-A690-40B7-9E5A-FCBEDFC2394F.svs, as it has already been processed.
Ignoring WSIs/MSI-H/TCGA-5M-AATE-01Z-00-DX1.483FFD2F-61A1-477E-8F94-157383803FC7.svs, as it has already been processed.
Ignoring WSIs/MSI-H/TCGA-5M-AAT6-01Z-00-DX1.8834C952-14E3-4491-8156-52FC917BB014.svs, as it has already been processed.
Masking and normalizing 0 tiles from 0 whole slide images.


This function, given a whole slide image path and a target save path, masks and normalizes all tissue tiles and then saves them as PNG files.

In [3]:
def mask_and_normalize_wsi(wsi_path, save_path, pbar):
    """Tile one whole slide image and save its normalized tissue tiles as PNGs.

    The slide is read at level 0 in tiles of MICRONS_PER_TILE microns per edge,
    leaving out a one-tile border on every side. A tile is kept when more than
    half of its pixels belong to the hole-filled mask of grayscale values
    strictly between 10 and 220. Kept tiles are normalized with the module-level
    ``normalizer``, resized to 224x224 and written to ``save_path``.

    Args:
        wsi_path: Path of the whole slide image to process.
        save_path: Existing directory that receives the tile PNGs and, on
            success, a ``Finished.txt`` marker holding the number of kept tiles.
        pbar: Progress bar; ``pbar.update()`` is called once per tile examined.

    Returns:
        The number of tiles saved, or 0 if OpenSlide failed on the slide (in
        which case ``save_path`` is deleted).
    """
    num_tiles_kept = 0
    try:
        with OpenSlide(str(wsi_path)) as wsi:
            pptx = int(MICRONS_PER_TILE / float(wsi.properties['openslide.mpp-x']))
            ppty = int(MICRONS_PER_TILE / float(wsi.properties['openslide.mpp-y']))
            # Leave out border of image
            for x in range(pptx, wsi.dimensions[0] - pptx, pptx):
                for y in range(ppty, wsi.dimensions[1] - ppty, ppty):
                    tile = wsi.read_region((x, y), level=0, size=(pptx, ppty)).convert('RGB')
                    # Mask away all-white and all-black background regions
                    mask = tile.convert(mode='L').point(lut=lambda p: 220 > p > 10, mode='1')
                    mask = ndimage.binary_fill_holes(mask)
                    if np.sum(mask).astype(float) / mask.size > 0.5:
                        with warnings.catch_warnings():
                            warnings.simplefilter('ignore')
                            try:
                                # Normalize the tile
                                tile = normalizer.transform(np.array(tile))
                                tile = Image.fromarray(tile)
                                # Resize the image to 224x224
                                tile = tile.resize((224, 224), Image.LANCZOS)
                                filename = f'{wsi_path.stem}__x{x}_y{y}_dx{pptx}_dy{ppty}.png'
                                tile.save(save_path / filename, format='PNG')
                                num_tiles_kept += 1
                            except np.linalg.LinAlgError:
                                # Best effort: a tile whose normalization fails is skipped
                                pass
                    pbar.update()
    except OpenSlideError as ex:
        print('\nUnable to process {}:'.format(wsi_path))
        # Positional arguments: the `etype` keyword was removed in Python 3.10
        print(''.join(traceback.format_exception(type(ex), ex, ex.__traceback__)))
        # Remove the partial output so that the slide is retried on the next run
        shutil.rmtree(save_path)
        return 0

    # Marker file: its presence means that the slide was tiled completely
    with open(save_path / 'Finished.txt', 'w+') as file:
        file.write('Kept and processed {} tiles.'.format(num_tiles_kept))
    return num_tiles_kept

In [4]:
# zip() silently truncates to the shorter list, so make sure the two lists line up
assert len(wsi_paths_to_normalize) == len(save_paths)
with tqdm.tqdm(total=total_num_tiles) as pbar:
    for wsi_path, save_path in zip(wsi_paths_to_normalize, save_paths):
        mask_and_normalize_wsi(wsi_path, save_path, pbar)
# Wait a moment for pbar to close
time.sleep(0.25)

# Summarize every tile directory on disk, including slides processed in earlier runs
all_save_paths = [p for p in base_save_path.glob('*/*') if p.is_dir()]
total_tiles_kept = 0
for save_path in all_save_paths:
    try:
        with open(save_path / 'Finished.txt', 'r') as f:
            info = f.readline()
    except FileNotFoundError:
        # E.g. tiling of this slide was interrupted before the marker file was written
        print(f'Skipping {save_path}: no Finished.txt found, tiling did not complete.')
        continue
    match = re.search(r'processed ([0-9]+) tiles', info)
    if match is None:
        print(f'Skipping {save_path}: could not read the tile count from Finished.txt.')
        continue
    num_tiles_kept = int(match.group(1))
    total_tiles_kept += num_tiles_kept
    print(f'{num_tiles_kept} tiles from patient {save_path.stem} saved to {save_path}')
print(f'{total_tiles_kept} tiles were saved and normalized')

0it [00:00, ?it/s]


2051 tiles from patient TCGA-4N-A93T-01Z-00-DX1 saved to tiled_WSIs/MSS/TCGA-4N-A93T-01Z-00-DX1.82E240B1-22C3-46E3-891F-0DCE35C43F8B
2960 tiles from patient TCGA-3L-AA1B-01Z-00-DX1 saved to tiled_WSIs/MSS/TCGA-3L-AA1B-01Z-00-DX1.8923A151-A690-40B7-9E5A-FCBEDFC2394F
4172 tiles from patient TCGA-5M-AAT6-01Z-00-DX1 saved to tiled_WSIs/MSI-H/TCGA-5M-AAT6-01Z-00-DX1.8834C952-14E3-4491-8156-52FC917BB014
6951 tiles from patient TCGA-5M-AATE-01Z-00-DX1 saved to tiled_WSIs/MSI-H/TCGA-5M-AATE-01Z-00-DX1.483FFD2F-61A1-477E-8F94-157383803FC7
16134 tiles were saved and normalized


Now that WSIs have been broken into normalized tiles, we load these images for tumor detection.

NB: Make sure to use `with torch.no_grad():` at inference time; otherwise PyTorch keeps tracking gradients for every batch and you may run out of memory.

In [5]:
print('Loading images for tumor detection...')
# Images must be of size 224x224 to be passed to most deep learning vision models
tumor_detection_transform = transforms.Compose([transforms.Resize(224), transforms.ToTensor()])
img_dataset = datasets.ImageFolder(base_save_path, transform=tumor_detection_transform)
# shuffle=False keeps the batches in the same order as img_dataset.samples
img_dataloader = data.DataLoader(
    img_dataset, batch_size=128, shuffle=False, num_workers=8, pin_memory=True
)
tumor_detection_model = load_saved_model_for_inference(
    'saved_models/resnet18_tumor_detection_exp9.pt', num_classes=2
)
tumor_detection_model = tumor_detection_model.to(DEVICE)

print(f'Getting tumor predictions for {len(img_dataset)} tiles in {len(img_dataloader)} batches.')
time.sleep(0.25)

# Gradients are not needed for inference; each batch of outputs is moved back to the
# CPU and the batches are concatenated into one (num_tiles, num_classes) tensor
with torch.no_grad():
    all_preds = torch.cat(
        [
            tumor_detection_model(inputs.to(DEVICE, non_blocking=True)).cpu()
            for inputs, _ in tqdm.tqdm(img_dataloader)
        ],
        dim=0,
    )
time.sleep(0.25)

# The predicted class of a tile is the index of its largest output (1 = tumorous)
tumorous_tiles = torch.argmax(all_preds, dim=1).flatten()
print(f'{tumorous_tiles.sum()}/{len(img_dataset)} tiles contain tumorous tissue')

Loading images for tumor detection...
Getting tumor predictions for 16134 tiles in 127 batches.


100%|██████████| 127/127 [00:08<00:00, 14.73it/s]


9806/16134 tiles contain tumorous tissue


In [6]:
# Tiles are stored as <base_save_path>/<MSI status>/<slide name>/<tile>.png, so the
# MSI status is the name of the grandparent directory of every tile
sample_paths = [Path(sample_path) for sample_path, _ in img_dataset.samples]
tile_ids = [p.name for p in sample_paths]
# Tile file names start with the slide name followed by '__'
patient_ids = [t.split('__')[0] for t in tile_ids]
msi_status = [p.parents[1].name for p in sample_paths]
tile_info_df = pd.DataFrame({
    'tile_id': tile_ids,
    'patient_id': patient_ids,
    'tumor_pred_val': all_preds[:, 1].numpy(),
    'tumor_pred_class': tumorous_tiles.numpy(),
    'MSI_status': msi_status,
})
tile_info_df.set_index('tile_id', inplace=True)
tile_df_save_path = base_save_path / 'tile_info.csv'
tile_info_df.to_csv(tile_df_save_path)
print(f'Saved tile_info_df to "{tile_df_save_path}"')

patient_info_df = pd.DataFrame({
    'patient_id': tile_info_df['patient_id'].unique(),
    'MSI_status': '',
})
patient_info_df.set_index('patient_id', inplace=True)
# Series.items() replaces Series.iteritems(), which was removed in pandas 2.0
for patient_id, patient_statuses in tile_info_df.groupby('patient_id')['MSI_status'].unique().items():
    # Make sure that MSI status is the same for all tiles within a patient:
    # single-element unpacking raises ValueError if there is more than one status
    patient_status, = patient_statuses
    patient_info_df.loc[patient_id, 'MSI_status'] = patient_status
patient_df_save_path = base_save_path / 'patient_info.csv'
patient_info_df.to_csv(patient_df_save_path)
print(f'Saved patient_info_df to "{patient_df_save_path}"')

display(patient_info_df.style.set_caption('Patient info dataframe'))
tile_info_df.groupby('patient_id').head(2).style.set_caption('Tile info dataframe example rows')

Saved tile_info_df to "tiled_WSIs/tile_info.csv"
Saved patient_info_df to "tiled_WSIs/patient_info.csv"


Unnamed: 0_level_0,MSI_status
patient_id,Unnamed: 1_level_1
TCGA-5M-AAT6-01Z-00-DX1.8834C952-14E3-4491-8156-52FC917BB014,MSI-H
TCGA-5M-AATE-01Z-00-DX1.483FFD2F-61A1-477E-8F94-157383803FC7,MSI-H
TCGA-3L-AA1B-01Z-00-DX1.8923A151-A690-40B7-9E5A-FCBEDFC2394F,MSS
TCGA-4N-A93T-01Z-00-DX1.82E240B1-22C3-46E3-891F-0DCE35C43F8B,MSS


Unnamed: 0_level_0,patient_id,tumor_pred_val,tumor_pred_class,MSI_status
tile_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
TCGA-5M-AAT6-01Z-00-DX1.8834C952-14E3-4491-8156-52FC917BB014__x100287_y12156_dx1013_dy1013.png,TCGA-5M-AAT6-01Z-00-DX1.8834C952-14E3-4491-8156-52FC917BB014,0.010808,0,MSI-H
TCGA-5M-AAT6-01Z-00-DX1.8834C952-14E3-4491-8156-52FC917BB014__x100287_y13169_dx1013_dy1013.png,TCGA-5M-AAT6-01Z-00-DX1.8834C952-14E3-4491-8156-52FC917BB014,0.045893,0,MSI-H
TCGA-5M-AATE-01Z-00-DX1.483FFD2F-61A1-477E-8F94-157383803FC7__x100287_y10130_dx1013_dy1013.png,TCGA-5M-AATE-01Z-00-DX1.483FFD2F-61A1-477E-8F94-157383803FC7,0.999991,1,MSI-H
TCGA-5M-AATE-01Z-00-DX1.483FFD2F-61A1-477E-8F94-157383803FC7__x100287_y11143_dx1013_dy1013.png,TCGA-5M-AATE-01Z-00-DX1.483FFD2F-61A1-477E-8F94-157383803FC7,0.999997,1,MSI-H
TCGA-3L-AA1B-01Z-00-DX1.8923A151-A690-40B7-9E5A-FCBEDFC2394F__x10130_y33429_dx1013_dy1013.png,TCGA-3L-AA1B-01Z-00-DX1.8923A151-A690-40B7-9E5A-FCBEDFC2394F,1.0,1,MSS
TCGA-3L-AA1B-01Z-00-DX1.8923A151-A690-40B7-9E5A-FCBEDFC2394F__x10130_y34442_dx1013_dy1013.png,TCGA-3L-AA1B-01Z-00-DX1.8923A151-A690-40B7-9E5A-FCBEDFC2394F,1.0,1,MSS
TCGA-4N-A93T-01Z-00-DX1.82E240B1-22C3-46E3-891F-0DCE35C43F8B__x10130_y31403_dx1013_dy1013.png,TCGA-4N-A93T-01Z-00-DX1.82E240B1-22C3-46E3-891F-0DCE35C43F8B,0.010911,0,MSS
TCGA-4N-A93T-01Z-00-DX1.82E240B1-22C3-46E3-891F-0DCE35C43F8B__x10130_y32416_dx1013_dy1013.png,TCGA-4N-A93T-01Z-00-DX1.82E240B1-22C3-46E3-891F-0DCE35C43F8B,0.668856,1,MSS


# Training Deep Learning Models

The deep learning model pipeline consists of three main steps:
1. Data splitting
1. Model and data loading
1. Training loop
    1. Performing inference
    1. Calculating loss
    1. Backpropagating loss
    1. Updating parameters
    1. Logging results

This notebook will also cover a few other important things to consider:
1. Common hardware bottlenecks
1. Real-time performance monitoring
1. Misc.

First, we split the patients into a train/validation set and a test set.

Normally, 10-20% of patients would be assigned to the test set, but since we only have 4 patients in our example dataset, we will perform a 50/50 split.

In [7]:
# Split at the patient level, stratified by MSI status. The seed is fixed so that
# re-running this cell assigns the same patients to the test set every time;
# otherwise a re-run could move test patients into the training data.
train_val_set, test_set = train_test_split(
    patient_info_df.index.values,
    test_size=0.5,
    stratify=patient_info_df['MSI_status'].values,
    random_state=0,
)
patient_info_df.loc[train_val_set, 'data_subset'] = 'train/validation'
patient_info_df.loc[test_set, 'data_subset'] = 'test'
# Drop any 'data_subset' column left over from an earlier run before joining the
# patient-level assignment onto the tiles
tile_info_df = tile_info_df.drop(
    columns='data_subset',
    errors='ignore'
).join(
    patient_info_df['data_subset'],
    on='patient_id'
)

Then, we split the tiles from the train/validation patients into a train set and a validation set.

In [8]:
# Every tile that does not belong to a test patient is split 90/10 into train and
# validation, stratified by patient. The seed is fixed so that re-running this cell
# reproduces the same split.
train_val_mask = tile_info_df['data_subset'] != 'test'
train_set, val_set = train_test_split(
    tile_info_df.index.values[train_val_mask],
    train_size=0.9,
    stratify=tile_info_df['patient_id'].values[train_val_mask],
    random_state=0,
)
tile_info_df.loc[train_set, 'data_subset'] = 'train'
tile_info_df.loc[val_set, 'data_subset'] = 'validation'

tile_info_df.to_csv(tile_df_save_path)
print(f'Saved updated tile_info_df to "{tile_df_save_path}"')
patient_info_df.to_csv(patient_df_save_path)
print(f'Saved updated patient_info_df to "{patient_df_save_path}"')

Saved updated tile_info_df to "tiled_WSIs/tile_info.csv"
Saved updated patient_info_df to "tiled_WSIs/patient_info.csv"


Verify that there are no patients with tiles in both the train and test set.

In [9]:
# List the data subsets that each patient's tiles were assigned to
tile_info_df['data_subset'].groupby(tile_info_df['patient_id']).unique()

patient_id
TCGA-3L-AA1B-01Z-00-DX1.8923A151-A690-40B7-9E5A-FCBEDFC2394F                 [test]
TCGA-4N-A93T-01Z-00-DX1.82E240B1-22C3-46E3-891F-0DCE35C43F8B    [train, validation]
TCGA-5M-AAT6-01Z-00-DX1.8834C952-14E3-4491-8156-52FC917BB014    [train, validation]
TCGA-5M-AATE-01Z-00-DX1.483FFD2F-61A1-477E-8F94-157383803FC7                 [test]
Name: data_subset, dtype: object

Here we load the data and model, and move the model to the correct device.

In order to split the train and test sets, and only include tumor images, we define a function to check that a given image path is tumorous and in the correct data subset.

The model architectures I've tried and had the most success with are, in no particular order:
1. densenet201
1. resnet18
1. shufflenet_v2_x1_0
1. squeezenet1_1

However, this is a decision that depends on the amount of data and compute available. For a list of all ImageNet pretrained models available through PyTorch, see https://pytorch.org/vision/stable/models.html.

Here we'll use SqueezeNet since it is the smallest and fastest.

In [44]:
def get_subset_func(data_subset):
    """Build a file filter that accepts the tumorous tiles of one data subset.

    The set of accepted tile names is computed once from ``tile_info_df`` when
    this function is called, so that the returned filter does a cheap set
    lookup instead of a DataFrame row lookup for every file.

    Args:
        data_subset: Value of the 'data_subset' column to select, e.g. 'train'.

    Returns:
        A function that takes an image path and returns True if the file is a
        '.png' whose name is a tile in ``tile_info_df`` with
        ``tumor_pred_class == 1`` and the requested ``data_subset``. Files that
        are not listed in ``tile_info_df`` are rejected rather than raising.
    """
    is_selected = (
        (tile_info_df['tumor_pred_class'] == 1)
        & (tile_info_df['data_subset'] == data_subset)
    )
    valid_tile_ids = frozenset(tile_info_df.index[is_selected])

    def is_valid_img(path):
        path = Path(path)
        return path.suffix == '.png' and path.name in valid_tile_ids

    return is_valid_img

PHASES = ['train', 'validation', 'test']
# Validation and test images are only resized and converted to tensors (no augmentation)
eval_transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor()
])
data_transforms = {
    'train': transforms.Compose([
        transforms.Resize(224),
        # We perform a number of random operations to the images in order to augment the training data
        transforms.RandomAffine(180, translate=(0.1, 0.1)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.ToTensor()
    ]),
    'validation': eval_transform,
    'test': eval_transform,
}

print('Loading images for training...')
# Every dataset reads the same folder; the is_valid_file filter restricts each one
# to the tumorous tiles of its own data subset
img_datasets = {
    phase: datasets.ImageFolder(
        base_save_path,
        transform=data_transforms[phase],
        is_valid_file=get_subset_func(phase)
    ) for phase in PHASES
}
img_dataloaders = {
    phase: data.DataLoader(
        img_datasets[phase],
        batch_size=32,
        num_workers=8,
        # Only the training data is shuffled. Validation and test batches keep the
        # order of img_datasets[phase].samples so that predictions can be matched
        # back to their tiles (and patients) during evaluation.
        shuffle=(phase == 'train'),
        pin_memory=True
    ) for phase in PHASES
}
for phase in PHASES:
    print(f'Loaded {len(img_datasets[phase])} {phase} images.')

MODEL_ARCHITECTURE = models.squeezenet1_1
model = load_model_arch(MODEL_ARCHITECTURE, pretrained=True, num_classes=2).to(DEVICE)
# Count only the parameters that will be updated during training
n_train_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'Loaded model "{MODEL_ARCHITECTURE.__name__}" with {n_train_params:,d} trainable parameters.')

Loading images for training...
Loaded 3358 train images.
Loaded 373 validation images.
Loaded 6075 test images.
Loaded model "squeezenet1_1" with 723,522 trainable parameters.


# Evaluating Performance

In [None]:
# NOTE(review): this cell only prints a message; it looks like a placeholder for the
# evaluation code of this section - confirm
print('Testing')