How to use the TileLevelDataModule
----------------------------------

In [1]:
import torch
import time
import os
from torch import nn
import cv2
import albumentations as A
from torchvision import transforms as T
import numpy as np
import pandas as pd
import slide_tools

cv2.setNumThreads(1)  # Limit OpenCV to a single thread; otherwise cv2 uses all threads within every worker

In [2]:
# Dataset root on disk and the training-split CSV (path given relative to the root).
root = "/mnt/data/Lennard/gyn"
csv_train = "ago-tr1/csv/finetune_train_1.csv"

csv_train_path = os.path.join(root, csv_train)
frame = pd.read_csv(csv_train_path)
frame  # last expression of the cell -> rendered as the cell output

Unnamed: 0,SlideNr,HRD(BRCA1),label,slide,annotation
0,10,0,ago-tr1/labels/10.json,ago-tr1/slides/10.svs,ago-tr1/annotations/10.geojson
1,100,1,ago-tr1/labels/100.json,ago-tr1/slides/100.svs,ago-tr1/annotations/100.geojson
2,101,1,ago-tr1/labels/101.json,ago-tr1/slides/101.svs,ago-tr1/annotations/101.geojson
3,103,1,ago-tr1/labels/103.json,ago-tr1/slides/103.svs,ago-tr1/annotations/103.geojson
4,106,1,ago-tr1/labels/106.json,ago-tr1/slides/106.svs,ago-tr1/annotations/106.geojson
...,...,...,...,...,...
127,84,1,ago-tr1/labels/84.json,ago-tr1/slides/84.svs,ago-tr1/annotations/84.geojson
128,85,1,ago-tr1/labels/85.json,ago-tr1/slides/85.svs,ago-tr1/annotations/85.geojson
129,9,0,ago-tr1/labels/9.json,ago-tr1/slides/9.svs,ago-tr1/annotations/9.geojson
130,92,1,ago-tr1/labels/92.json,ago-tr1/slides/92.svs,ago-tr1/annotations/92.geojson


### Create a PyTorch Lightning DataModule from slides, annotations, tile labels and slide labels
Note that the argument names carry prefixes (e.g. `epoch_`, `regions_`, `slide_`) to make it more visible which part of the pipeline they influence.

In [3]:
# All settings in one place. The whole dict is later unpacked into
# TileLevelDataModule(**hparams); a few entries are also read by the transforms.
hparams = {
    "batch_size": 128,
    # column_* / columns_*: names of columns in the CSV files
    "column_annotation": "annotation",
    "column_label": "label",
    "column_slide": "slide",
    "columns_global_label": ["HRD(BRCA1)"],
    # csv_*: split files
    "csv_test": "ago-tr1/csv/test.csv",
    "csv_train": "ago-tr1/csv/finetune_train_1.csv",
    "csv_valid": "ago-tr1/csv/finetune_eval_1.csv",
    # epoch_*: balancing / shuffling options
    "epoch_balance_label_bins": 2,
    "epoch_balance_label_key": "HRD(BRCA1)",
    "epoch_balance_size_by": "median",
    "epoch_shuffle": True,
    "epoch_shuffle_chunk_size": 16,
    # used by the transforms (tfms)
    "image_size": 240,
    "norm_mean": [0.5, 0.5, 0.5],  # for tfms
    "norm_std": [0.5, 0.5, 0.5],  # for tfms
    # dataloader options
    "num_workers": 8,
    "pin_memory": False,
    # regions_*: region/tile options
    "regions_annotation_align": False,
    "regions_centroid_in_annotation": True,
    "regions_level": 0,
    "regions_region_overlap": 0.0,
    "regions_return_index": False,
    "regions_return_labels": "HRD(BRCA1)",
    "regions_size": None,
    "regions_unit": "pixel",
    "regions_with_labels": True,
    "root": "/mnt/data/Lennard/gyn",
    # slide_*: slide options
    "slide_interpolation": "linear",
    "slide_linear_fill_value": 0.0,
    "slide_load_keys": None,
    "slide_simplify_tolerance": 100,
    "verbose": False,
}

In [4]:
# Albumentations has the fastest augmentations acting on numpy arrays;
# only batched transforms on the GPU are faster.

class AlbumentationWrapper(nn.Module):
    """Expose an albumentations pipeline as a ``torch.nn.Module``.

    The transforms passed in are combined with ``A.Compose``; calling the
    module runs that composed pipeline on one image.

    Args:
        tfms: List of albumentations transforms handed to ``A.Compose``.
    """

    def __init__(self, tfms):
        super().__init__()
        self.tfms = A.Compose(tfms)

    @torch.no_grad()
    def forward(self, image):
        """Run the composed pipeline on ``image`` and return the ``"image"`` entry of its result."""
        augmented = self.tfms(image=image)
        return augmented["image"]

    def __repr__(self):
        # Class name followed by the repr of the wrapped pipeline.
        return f"{type(self).__name__}({self.tfms!r})"


# TrivialAugment uses only one augmentation per sample and is therefore very efficient;
# convince yourself of the results here: https://arxiv.org/pdf/2103.10158.pdf
# NOTE(review): A.OneOf is constructed with its default `p` here - confirm whether the
# group is meant to fire on every sample (e.g. p=1.0) to match the description above.
_trivial_candidates = [
    A.Flip(),
    A.GaussNoise(),
    A.ShiftScaleRotate(
        shift_limit=0.0625,
        scale_limit=0.2,
        rotate_limit=45,
    ),
    A.ElasticTransform(
        alpha=100,
        sigma=100 * 0.1,
        alpha_affine=100 * 0.03,
    ),
    A.Blur(p=0.5),
    A.ColorJitter(
        brightness=0.25,
        contrast=0.75,
        saturation=0.25,
        hue=0.5,
    ),
    A.RandomBrightnessContrast(),
    A.CLAHE(),
    A.Equalize(),
    A.Solarize(),
    A.Sharpen(),
    A.Posterize(),
    A.RandomGamma(),
    A.CoarseDropout(),
]
trivial_augment = A.OneOf(_trivial_candidates)

# Training pipeline: random resized crop to image_size, the OneOf augmentation group,
# normalisation with norm_mean / norm_std, and finally T.ToTensor().
_train_albumentations = [
    A.RandomResizedCrop(
        hparams["image_size"],
        hparams["image_size"],
        scale=(0.2, 1.0),
        interpolation=cv2.INTER_CUBIC,
    ),
    trivial_augment,
    A.Normalize(hparams["norm_mean"], hparams["norm_std"]),
]
tfms_train = T.Compose(
    [
        AlbumentationWrapper(tfms=_train_albumentations),
        T.ToTensor(),
    ]
)

# Validation/test pipeline: resize to image_size and normalise with
# norm_mean / norm_std (no random augmentation), then T.ToTensor().
_valid_albumentations = [
    A.Resize(
        hparams["image_size"],
        hparams["image_size"],
        interpolation=cv2.INTER_CUBIC,
    ),
    A.Normalize(hparams["norm_mean"], hparams["norm_std"]),
]
tfms_valid = T.Compose(
    [
        AlbumentationWrapper(tfms=_valid_albumentations),
        T.ToTensor(),
    ]
)

In [5]:
%%time
dm = slide_tools.tile_level.TileLevelDataModule(
    **hparams,
    tfms_train=tfms_train,
    tfms_valid=tfms_valid,
    tfms_test=tfms_valid,
)
dm.setup("fit")  # will prepare train_dataloader and valid_dataloader

# The runtime is mostly dominated by reading the json labels.
# Saving/loading of labels will likely change in the future to a more performant variant.

[Plugin: cucim.kit.cuslide] Loading the dynamic library from: /home/caduser/anaconda3/envs/tmmae/lib/python3.8/site-packages/cucim/clara/cucim.kit.cuslide@22.02.00.so
Initializing plugin: cucim.kit.cuslide (interfaces: [cucim::io::IImageFormat v0.1]) (impl: cucim.kit.cuslide)
[Plugin: cucim.kit.cumed] Loading the dynamic library from: /home/caduser/anaconda3/envs/tmmae/lib/python3.8/site-packages/cucim/clara/cucim.kit.cumed@22.02.00.so
Initializing plugin: cucim.kit.cumed (interfaces: [cucim::io::IImageFormat v0.1]) (impl: cucim.kit.cumed)


CPU times: user 5min 2s, sys: 3.71 s, total: 5min 6s
Wall time: 5min 6s


In [6]:
%%time
dm.setup("test") # will prepare test_dataloader

# Reference timing from a previous run:
# CPU times: user 1min 8s, sys: 524 ms, total: 1min 9s
# Wall time: 1min 9s

CPU times: user 1min 8s, sys: 524 ms, total: 1min 9s
Wall time: 1min 9s


In [7]:
# Throughput of the training dataloader: skip a few warm-up batches, then time
# up to N batches and report tiles per second.
batch_size = 128
num_workers = 8
dl_train = dm.train_dataloader(batch_size=batch_size, num_workers=num_workers)

N = 100  # number of batches to time
n_warmup = 20  # batches consumed before the clock starts

# Reset the timing state explicitly so that a short dataloader can never fall
# back to a stale `t` / `t0` left behind by another cell.
t0 = None
t = None
n_timed = 0
for i, batch in enumerate(dl_train):
    if i == n_warmup:
        t0 = time.perf_counter()
    elif i > n_warmup:
        # Updated on every timed batch so the result is valid even if the
        # dataloader is exhausted before N timed batches were seen.
        n_timed = i - n_warmup
        t = time.perf_counter() - t0
        if n_timed == N:
            break

if n_timed == 0:
    raise RuntimeError(
        f"dataloader yielded too few batches to time: need more than {n_warmup + 1}"
    )
print(f"{batch_size*n_timed/t:.0f} tiles per second with {num_workers=} at train time (random tiles)")

# Reference result from a previous run:
# 538 tiles per second with num_workers=8 at train time (random tiles)

538 tiles per second with num_workers=8 at train time (random tiles)


In [8]:
# Throughput of the validation dataloader: skip a few warm-up batches, then time
# up to N batches and report tiles per second.
batch_size = 128
num_workers = 8
dl_valid = dm.val_dataloader(batch_size=batch_size, num_workers=num_workers)

N = 100  # number of batches to time
n_warmup = 20  # batches consumed before the clock starts

# Reset the timing state explicitly: without this, a dataloader with too few
# batches would silently reuse `t` from an earlier benchmark cell.
t0 = None
t = None
n_timed = 0
for i, batch in enumerate(dl_valid):
    if i == n_warmup:
        t0 = time.perf_counter()
    elif i > n_warmup:
        # Updated on every timed batch so the result is valid even if the
        # dataloader is exhausted before N timed batches were seen.
        n_timed = i - n_warmup
        t = time.perf_counter() - t0
        if n_timed == N:
            break

if n_timed == 0:
    raise RuntimeError(
        f"dataloader yielded too few batches to time: need more than {n_warmup + 1}"
    )
print(f"{batch_size*n_timed/t:.0f} tiles per second with {num_workers=} at test time (sequential tiles)")

# Reference result from a previous run:
# 581 tiles per second with num_workers=8 at test time (sequential tiles)

581 tiles per second with num_workers=8 at test time (sequential tiles)
