In [48]:
import os
import argparse
import logging
import functools
import itertools

import tqdm
import torchvision
from torch.utils.data import DataLoader
import zarr
import dask
import dask.array as da
import zarrdataset
import random
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim

## Hyper parameters

In [63]:
num_epochs = 1
batch_size = 4
patch_size = 518
num_workers = 0

num_classes = 1
bias = True

rng_seed = -1

log_dir = "/fastscratch/cervaf/logs/wsi_classifiers/"
log_identifier = "ViT_H_14"
print_log = True

trn_filenames_list = "/projects/researchit/cervaf/s3_bucket_data/tcga_kirc_train.txt"
val_filenames_list = "/projects/researchit/cervaf/s3_bucket_data/tcga_kirc_val.txt"

### Set a random number generator seed for reproducibility

In [36]:
if rng_seed < 0:
    rng_seed = np.random.randint(1, 100000)

torch.manual_seed(rng_seed)
np.random.seed(rng_seed + 1)
random.seed(rng_seed+1, 2)

### Helper function to load the dataset filenames

In [49]:
def parse_filenames_list(filenames_list, input_format):
    if (isinstance(filenames_list, str)
      and not (filenames_list.lower().endswith(input_format.lower())
             or filenames_list.lower().endswith(".txt"))):
        return []

    if (isinstance(filenames_list, str)
      and filenames_list.lower().endswith(input_format.lower())):
        return [filenames_list]

    if (isinstance(filenames_list, str)
      and filenames_list.lower().endswith(".txt")):
        with open(filenames_list, "r") as fp:
            filenames_list = [fn.strip("\n ") for fn in  fp.readlines()]

    if isinstance(filenames_list, list):
        filenames_list = functools.reduce(lambda l1, l2: l1 + l2,
                                          map(parse_filenames_list,
                                              filenames_list,
                                              itertools.repeat(input_format)),
                                          [])
    return filenames_list

In [64]:
logger = logging.getLogger('train_log')
logger.setLevel(logging.INFO)

formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')

fh = logging.FileHandler(os.path.join(log_dir, "train_vit_classifier%s.log" % log_identifier), mode='w')
fh.setFormatter(formatter)
logger.addHandler(fh)

if print_log:
    console = logging.StreamHandler()
    console.setFormatter(formatter)
    logger.addHandler(console)
    console.setLevel(logging.DEBUG)

# Load the pretrained Vision Transoformer (ViT)

### Setup a ViT model using the pre-trained weights provided by torchvision at
https://pytorch.org/vision/main/models/generated/torchvision.models.vit_h_14.html#vit-h-14

In [65]:
logger = logging.getLogger('train_log')

logger.info("Loading vision transformer from torchvision.models.vit_b_16")

vit_checkpoint = torchvision.models.ViT_H_14_Weights.IMAGENET1K_SWAG_E2E_V1
model = torchvision.models.vit_h_14(weights=vit_checkpoint, progress=False)

# Freeze the model parameters to perform fine tuning only on the last layer (classifier)
for par in model.parameters():
    par.requires_grad = False

model.heads = nn.Sequential(
    nn.Linear(in_features=1280, out_features=num_classes, bias=bias)
)

logger.debug("Model\n%s" % str(model))
extra_transforms = vit_checkpoint.transforms

logger.info("Preprocessing transforms\n%s" % extra_transforms)

2023-06-26 15:56:58,505 - INFO - Loading vision transformer from torchvision.models.vit_b_16
2023-06-26 15:56:58,505 - INFO - Loading vision transformer from torchvision.models.vit_b_16
2023-06-26 15:56:58,505 - INFO - Loading vision transformer from torchvision.models.vit_b_16
2023-06-26 15:57:07,468 - INFO - Preprocessing transforms
functools.partial(<class 'torchvision.transforms._presets.ImageClassification'>, crop_size=518, resize_size=518, interpolation=<InterpolationMode.BICUBIC: 'bicubic'>)
2023-06-26 15:57:07,468 - INFO - Preprocessing transforms
functools.partial(<class 'torchvision.transforms._presets.ImageClassification'>, crop_size=518, resize_size=518, interpolation=<InterpolationMode.BICUBIC: 'bicubic'>)
2023-06-26 15:57:07,468 - INFO - Preprocessing transforms
functools.partial(<class 'torchvision.transforms._presets.ImageClassification'>, crop_size=518, resize_size=518, interpolation=<InterpolationMode.BICUBIC: 'bicubic'>)


## Define the training and validation datasets

In [66]:
# This is an auxiliary class that extracts patches from the zarr files without having to save them separately
patch_sampler = zarrdataset.GridPatchSampler(patch_size=patch_size)

transforms_pipeline = [
    zarrdataset.SelectAxes("TCZYX", {"T":0, "Z":0}, "YXC"),
    zarrdataset.ZarrToArray(dtype=np.uint8),
    torchvision.transforms.ToTensor(),
]

if extra_transforms is not None:
    transforms_pipeline.append(extra_transforms())

input_transforms = torchvision.transforms.Compose(transforms_pipeline)

### Get the training and validation file names

In [67]:
trn_filenames = parse_filenames_list(trn_filenames_list, ".zarr")
val_filenames = parse_filenames_list(val_filenames_list, ".zarr")

In [68]:
logger.debug(f"Training files\n{trn_filenames}")
logger.debug(f"Validation files\n{val_filenames}")

In [69]:
trn_ds = zarrdataset.LabeledZarrDataset(
    trn_filenames[:2],
    data_group="0/0", data_axes="TCZYX",
    mask_data_group="masks/0/0", mask_data_axes="YX",
    labels_data_group="masks/1/1", labels_data_axes="C",
    transform=input_transforms,
    shuffle=True,
    patch_sampler=patch_sampler,
    draw_same_chunk=True,
    progress_bar=False,
    use_dask=False)

val_ds = zarrdataset.LabeledZarrDataset(
    val_filenames[:2],
    data_group="0/0", data_axes="TCZYX",
    mask_data_group="masks/0/0", mask_data_axes="YX",
    labels_data_group="masks/1/1", labels_data_axes="C",
    transform=input_transforms,
    shuffle=False,
    patch_sampler=patch_sampler,
    draw_same_chunk=True,
    progress_bar=False,
    use_dask=False)

trn_dl = DataLoader(
    trn_ds,
    batch_size=batch_size,
    num_workers=num_workers,
    worker_init_fn=zarrdataset.zarrdataset_worker_init,
    pin_memory=True,
    persistent_workers=num_workers > 0
)

val_dl = DataLoader(
    val_ds,
    batch_size=batch_size,
    num_workers=num_workers,
    worker_init_fn=zarrdataset.zarrdataset_worker_init,
    pin_memory=True,
    persistent_workers=num_workers > 0
)

## Define the optimizer and the criterion

In [70]:
if torch.cuda.is_available():
    model.cuda()

# Optimize only the classifier head of the model, since everything else is frozen
optimizer = optim.Adam(model.heads.parameters(), lr=1e-4)

# For one-hot encoding use the BCE with logits loss function
criterion = nn.BCEWithLogitsLoss()

logger.info("Optimizer: Adam, lr=1e-4")
logger.info("Criterion: BCE with logits")

2023-06-26 15:57:08,115 - INFO - Optimizer: Adam, lr=1e-4
2023-06-26 15:57:08,115 - INFO - Optimizer: Adam, lr=1e-4
2023-06-26 15:57:08,115 - INFO - Optimizer: Adam, lr=1e-4
2023-06-26 15:57:08,366 - INFO - Criterion: BCE with logits
2023-06-26 15:57:08,366 - INFO - Criterion: BCE with logits
2023-06-26 15:57:08,366 - INFO - Criterion: BCE with logits


## Define the training and validation steps

In [71]:
def train_step(trn_dl, model, criterion, optimizer):
    model.train()
    total_loss = 0
    total_samples = 0

    for i, (x, t) in enumerate(trn_dl):
        optimizer.zero_grad()

        y_hat = model(x.cuda())

        loss = criterion(y_hat, t.to(y_hat.device).float())
        loss.backward()

        optimizer.step()
        
        total_loss += loss.item()
        total_samples += x.size(0)

        if i % 10 == 0:
            logger.debug(f"Training step {i}, avg. training loss={total_loss / total_samples}")

    # Return the average training loss of ths epoch
    return total_loss / total_samples

In [72]:
def validation_step(val_dl, model, criterion):
    model.eval()
    total_loss = 0
    total_samples = 0

    with torch.no_grad():
        for i, (x, t) in enumerate(val_dl):
            y_hat = model(x.cuda())
            loss = criterion(y_hat, t.to(y_hat.device).float())

            total_loss += loss.item()
            total_samples += x.size(0)

            if i % 10 == 0:
                logger.debug(f"Validation step {i}, avg. validation loss={total_loss / total_samples}")

    # Return the average validation loss of this epoch
    return total_loss / total_samples

# The main training loop

In [73]:
best_val_loss = float('inf')
last_checkpoint_fn = os.path.join(log_dir, "last_vit_classifier%s.pth" % log_identifier)
best_checkpoint_fn = os.path.join(log_dir, "best_vit_classifier%s.pth" % log_identifier)

trn_loss_list = []
val_loss_list = []

for e in range(num_epochs):
    trn_loss = train_step(trn_dl, model, criterion, optimizer)
    val_loss = validation_step(val_dl, model, criterion)

    trn_loss_list.append(trn_loss)
    val_loss_list.append(val_loss)

    last_checkpoint = dict(
        model=model.state_dict(),
        epoch=e,
        trn_loss=trn_loss_list,
        val_loss=val_loss_list,
        best_val_loss=best_val_loss)
    torch.save(last_checkpoint, last_checkpoint_fn)

    if best_val_loss > val_loss:
        best_val_loss = val_loss
        best_checkpoint = dict(
            model=model.state_dict(),
            epoch=e,
            trn_loss=trn_loss_list,
            val_loss=val_loss_list,
            best_val_loss=best_val_loss)
        torch.save(best_checkpoint, best_checkpoint_fn)

    logger.info(f"Epoch {e + 1}, avg. training loss={trn_loss}, avg. validation loss={val_loss}")

2023-06-26 16:20:23,964 - INFO - Epoch 1, avg. training loss=0.002234800590497446, avg. validation loss=0.0002953684325151049
2023-06-26 16:20:23,964 - INFO - Epoch 1, avg. training loss=0.002234800590497446, avg. validation loss=0.0002953684325151049
2023-06-26 16:20:23,964 - INFO - Epoch 1, avg. training loss=0.002234800590497446, avg. validation loss=0.0002953684325151049
