In [None]:
import os
import torch
import random
import timm
import transformers
import pandas as pd
import pytorch_lightning as pl

from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision.io import read_image
from torchvision.transforms import Normalize
from torchmetrics import Accuracy, ConfusionMatrix 

from fastcore.foundation import L
from collections import OrderedDict
from tqdm.notebook import tqdm

from PIL import Image
from torchvision.transforms import ToTensor
import seaborn as sns

from torchvision import transforms

In [None]:
# --- Data columns ---------------------------------------------------------------
tile_column = 'tiles'   # DataFrame column with the path of each tile image
label_column = 'label'  # DataFrame column with the class label (cast to int by SlideDataSet)

# --- Training hyper-parameters ----------------------------------------------------
gpu = '0,'              # forwarded to pl.Trainer(gpus=...)
epochs = 100
learning_rate = 1e-05
batch_size = 700
exp_name = "TissueClassification_fine_tuning_DACHS_all_norm_CJ"

# Training is stochastic (shuffled loaders, random augmentations, dropout). Seed the
# Python and torch RNGs once, before any model or DataLoader is built, so a
# Restart & Run All reproduces the same run.
seed = 42
random.seed(seed)
torch.manual_seed(seed)

# --- Pre-training data: NCT-CRC-HE-100K -------------------------------------------
train_pkl = '/.../datasets/df_train.pkl'
val_pkl = '/.../df_val.pkl'  # NOTE(review): unlike its siblings this path has no datasets/ segment — confirm
kather_int_test_pkl = '/.../datasets/df_test.pkl'  # internal NCT-CRC-HE-100K test split
kather_ext = '/.../datasets/TissueClasses_CRC_VAL_7K.pkl'  # external test set (CRC-VAL-HE-7K, going by the file name)

# --- Fine-tuning data: DACHS patient cohort, annotated in-house --------------------
finetune_val_pkl = '/.../datasets/df_finetune_val_dachs_ideal_norm.pkl'
finetune_train_pkl = '/.../datasets/df_finetune_train_dachs_all_norm.pkl'

In [None]:
# Per-channel normalisation: (x - 0.5) / 0.5, i.e. ToTensor's [0, 1] range -> [-1, 1].
norm = Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))

# Training-time augmentation. The random ops are applied (and draw from the RNG)
# in list order, so keep the order stable for reproducibility.
train_tfms = transforms.Compose(
    [
        # Crop 80-100% of the tile area and resize to 224x224.
        transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
        # Strong colour jitter, applied with probability 0.9.
        transforms.RandomApply(
            nn.ModuleList(
                [transforms.ColorJitter(brightness=0.25, contrast=0.75, saturation=0.25, hue=0.5)]
            ),
            p=0.9,
        ),
        # Orientation is arbitrary for tiles: flip on both axes independently.
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.5),
        transforms.RandomGrayscale(p=0.1),
        # Occasional Gaussian blur, applied with probability 0.3.
        transforms.RandomApply(
            nn.ModuleList([transforms.GaussianBlur(kernel_size=(5, 5), sigma=(0.1, 5))]),
            p=0.3,
        ),
        ToTensor(),
        norm,
    ]
)

# Evaluation pipeline: no augmentation, only tensor conversion + the same normalisation.
test_tfms = transforms.Compose([ToTensor(), norm])


class SlideDataSet(Dataset):
    """Map-style dataset yielding ``(transformed_tile, label)`` pairs from a DataFrame.

    Args:
        dataframe: table with one row per tile.
        tile_column: column holding the file path of each tile image.
        label_column: column holding the class label; cast to ``int``.
        tile_tfms: callable applied to each PIL image (e.g. a torchvision ``Compose``).
    """

    def __init__(self, dataframe, tile_column, label_column, tile_tfms):
        self.df = dataframe
        # Plain indexable list of paths (no need to unpack the whole column as
        # positional arguments into a container constructor).
        self.tiles = list(self.df[tile_column])
        self.labels = torch.as_tensor(self.df[label_column].astype(int).values)
        self.tile_tfms = tile_tfms

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):
        # PIL opens lazily; convert() inside the ``with`` forces the pixel data to be
        # read so the file handle is closed right away (many DataLoader workers would
        # otherwise hold many files open). Converting to RGB also makes grayscale,
        # palette or RGBA tiles 3-channel, as the 3-channel Normalize expects.
        with Image.open(self.tiles[index]) as img:
            image = img.convert("RGB")
        return self.tile_tfms(image), self.labels[index]
    
class SlideDataModule(pl.LightningDataModule):
    """LightningDataModule building ``SlideDataSet`` splits from pickled DataFrames.

    Args:
        train, val, test: paths to pickled DataFrames, read with ``pd.read_pickle``
            in ``setup`` (only load pickles you trust — unpickling can run code).
        tile_column, label_column: column names forwarded to ``SlideDataSet``.
        batch_size: batch size used by every DataLoader.
        train_tfms: transform for the training split.
        test_tfms: transform for the validation and test splits.
        num_workers: DataLoader worker processes per loader (default 6).
    """

    def __init__(self, train, val, test, tile_column, label_column, batch_size,
                 train_tfms, test_tfms, num_workers=6):
        super().__init__()
        self.train = train
        self.val = val
        self.test = test
        self.tile_column = tile_column
        self.label_column = label_column
        self.bs = batch_size
        self.train_tfms = train_tfms
        self.test_tfms = test_tfms
        self.num_workers = num_workers

    def prepare_data(self):
        # Nothing to download; all splits are read from local pickles in setup().
        pass

    def setup(self, stage=None):
        """Read the DataFrames and build the datasets needed for ``stage``.

        ``stage`` is one of ``"fit"``, ``"validate"``, ``"test"``; ``None`` builds
        every split. "validate" only needs the validation split.
        """
        if stage in (None, "fit"):
            self.train_df = pd.read_pickle(self.train)
            self.train_ds = SlideDataSet(self.train_df, self.tile_column, self.label_column, self.train_tfms)
        if stage in (None, "fit", "validate"):
            self.valid_df = pd.read_pickle(self.val)
            self.valid_ds = SlideDataSet(self.valid_df, self.tile_column, self.label_column, self.test_tfms)
        if stage in (None, "test"):
            self.test_df = pd.read_pickle(self.test)
            self.test_ds = SlideDataSet(self.test_df, self.tile_column, self.label_column, self.test_tfms)

    def _make_loader(self, dataset, shuffle):
        # Single place for the DataLoader settings shared by all splits.
        return DataLoader(dataset, shuffle=shuffle, batch_size=self.bs, num_workers=self.num_workers)

    def train_dataloader(self):
        return self._make_loader(self.train_ds, shuffle=True)

    def val_dataloader(self):
        return self._make_loader(self.valid_ds, shuffle=False)

    def test_dataloader(self):
        return self._make_loader(self.test_ds, shuffle=False)
    

In [None]:
class SlideModel(pl.LightningModule):
    """Tile classifier: timm ResNet-18 backbone + dropout/linear head with 9 outputs.

    The backbone is initialised from ``nvidia-resnet18.pt`` and its own classifier
    is removed with ``reset_classifier(0)``; the new head maps 512 features to 9
    logits and is trained with cross-entropy.
    """

    def __init__(self):
        super().__init__()
        self.criterion = nn.CrossEntropyLoss()

        # Metrics: one instance per stage so accumulated state never mixes.
        self.train_acc = Accuracy()
        self.valid_acc = Accuracy()
        self.test_acc = Accuracy()

        model = timm.create_model('resnet18')
        # Presumably gives the model a 512->1 conv "fc" so its state_dict lines up
        # entry-for-entry with the checkpoint, which is then loaded by position.
        model.fc = nn.Conv2d(512, 1, 1)
        model_dir = '/.../'  # add path to nvidia-checkpoint
        ckpt = torch.load(os.path.join(model_dir, "nvidia-resnet18.pt"), map_location="cpu")
        # NOTE(review): keys are paired purely by position and zip() silently drops any
        # surplus checkpoint entries — re-check if the checkpoint or timm version changes.
        model.load_state_dict(OrderedDict(zip(model.state_dict().keys(), ckpt.values())))
        model.reset_classifier(0)

        self.model = model
        self.classifier = nn.Sequential(nn.Dropout(0.70), nn.Linear(512 * 1, 9))

    def forward(self, tiles, label):
        """Return ``(logits, label)``; ``label`` is passed through unchanged."""
        rep = self.model(tiles)
        y_prob = self.classifier(rep)
        return y_prob, label

    def _shared_step(self, batch, stage, acc):
        """Forward + loss + accuracy for one batch; logs ``{stage}/loss`` and ``{stage}/acc``.

        Args:
            batch: ``(tiles, labels)`` tuple from the DataLoader.
            stage: log prefix, one of "train", "valid", "test".
            acc: the ``Accuracy`` metric owned by that stage.
        """
        tiles, y = batch
        logits, y = self(tiles, y)
        logits = logits.float()  # float32 loss even under mixed precision
        loss = self.criterion(logits, y)
        # The head is trained with CrossEntropyLoss (softmax), so softmax gives the
        # class probabilities. argmax — and therefore accuracy — is unchanged.
        probs = logits.detach().softmax(dim=1)
        # Log the metric object itself: Lightning computes the exact epoch-level
        # value and resets the metric at the end of each epoch.
        acc(probs, y)
        self.log(f"{stage}/acc", acc, on_step=False, on_epoch=True, logger=True)
        self.log(f"{stage}/loss", loss, on_step=False, on_epoch=True,
                 prog_bar=(stage == "train"), logger=True)
        return {
            "loss": loss,
            "slide_score": probs,
            "y": y.detach(),
        }

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, "train", self.train_acc)

    def training_epoch_end(self, outs):
        pass

    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch, "valid", self.valid_acc)

    def validation_epoch_end(self, outs):
        pass

    def test_step(self, batch, batch_idx):
        # Own metric and "test/" prefix so test results are not reported as validation.
        return self._shared_step(batch, "test", self.test_acc)

    def test_epoch_end(self, outs):
        return self.validation_epoch_end(outs)

    def configure_optimizers(self):
        # Uses the notebook-level `learning_rate`.
        opt = torch.optim.AdamW(self.parameters(), lr=learning_rate, weight_decay=0.000001)
        return opt

In [None]:
class SlideModel_Finetune(pl.LightningModule):
    """Fine-tunes a trained ``SlideModel`` backbone and classifier end to end.

    Args:
        model: a ``SlideModel`` (anything exposing ``.model``, the timm backbone,
            and ``.classifier``). Its modules are reused, not copied, so the wrapped
            object's weights change during fine-tuning.

    After each epoch, the confusion matrices are stored in ``train_cm_results``,
    ``valid_cm_results`` and ``test_cm_results``.
    """

    def __init__(self, model):
        super().__init__()
        self.criterion = nn.CrossEntropyLoss()

        # Metrics: one instance per stage so accumulated state never mixes.
        self.train_acc = Accuracy()
        self.valid_acc = Accuracy()
        self.test_acc = Accuracy()
        self.train_cm = ConfusionMatrix(num_classes=9)
        self.cm = ConfusionMatrix(num_classes=9)  # validation confusion matrix
        self.test_cm = ConfusionMatrix(num_classes=9)
        # Latest per-epoch confusion matrices, filled in by the *_epoch_end hooks.
        self.train_cm_results = None
        self.valid_cm_results = None
        self.test_cm_results = None

        self.model = model.model
        # Split the backbone into separately addressable parts (so individual parts
        # can be frozen / fine-tuned). Assumes timm's ResNet child order
        # [conv1, bn1, act1, maxpool, layer1..layer4, global_pool, fc]:
        #   children[:-3] -> stem + layer1..layer3, children[7] -> layer4,
        #   children[8:]  -> global_pool + fc.
        # NOTE(review): confirm this order for the installed timm version.
        children = list(self.model.children())
        self.model_part1 = nn.Sequential(*children[:-3])
        last_resblock = nn.Sequential(*children[7])
        self.basicblock0 = nn.Sequential(list(last_resblock.children())[0])
        self.basicblock1 = nn.Sequential(list(last_resblock.children())[1])
        self.model_part2 = nn.Sequential(*children[8:])
        self.classifier = model.classifier

        # The source model may have been frozen (freeze() sets requires_grad=False);
        # re-enable gradients so every part is fine-tuned.
        self.model_part1.requires_grad_(True)
        self.basicblock0.requires_grad_(True)
        self.basicblock1.requires_grad_(True)
        self.model_part2.requires_grad_(True)
        self.classifier.requires_grad_(True)

    def forward(self, tiles, label):
        """Return ``(logits, label)``; ``label`` is passed through unchanged."""
        tiles = self.model_part1(tiles)
        tiles = self.basicblock0(tiles)
        tiles = self.basicblock1(tiles)
        rep = self.model_part2(tiles)
        y_prob = self.classifier(rep)
        return y_prob, label

    def _shared_step(self, batch, stage, acc, cm):
        """Forward + loss + metrics for one batch; logs ``{stage}/loss`` and ``{stage}/acc``.

        Args:
            batch: ``(tiles, labels)`` tuple from the DataLoader.
            stage: log prefix, one of "train", "valid", "test".
            acc: the stage's ``Accuracy`` metric.
            cm: the stage's ``ConfusionMatrix`` metric (computed/reset at epoch end).
        """
        tiles, y = batch
        logits, y = self(tiles, y)
        logits = logits.float()  # float32 loss even under mixed precision
        loss = self.criterion(logits, y)
        # Softmax matches the CrossEntropyLoss head; argmax (accuracy, confusion
        # matrix) is the same as with the previous sigmoid scores.
        probs = logits.detach().softmax(dim=1)
        cm.update(probs, y)
        # Log the metric object: Lightning computes the exact epoch value and resets it.
        acc(probs, y)
        self.log(f"{stage}/acc", acc, on_step=False, on_epoch=True, logger=True)
        self.log(f"{stage}/loss", loss, on_step=False, on_epoch=True,
                 prog_bar=(stage == "train"), logger=True)
        return {
            "loss": loss,
            "slide_score": probs,
            "y": y.detach(),
        }

    @staticmethod
    def _compute_and_reset(cm):
        """Return the epoch's confusion matrix and clear the metric for the next epoch."""
        result = cm.compute().detach()
        cm.reset()
        return result

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, "train", self.train_acc, self.train_cm)

    def training_epoch_end(self, outs):
        self.train_cm_results = self._compute_and_reset(self.train_cm)

    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch, "valid", self.valid_acc, self.cm)

    def validation_epoch_end(self, outs):
        self.valid_cm_results = self._compute_and_reset(self.cm)
        return {"cm_results": self.valid_cm_results}

    def test_step(self, batch, batch_idx):
        # Own metrics and "test/" prefix so test results are not reported as validation.
        return self._shared_step(batch, "test", self.test_acc, self.test_cm)

    def test_epoch_end(self, outs):
        self.test_cm_results = self._compute_and_reset(self.test_cm)
        return {"cm_results": self.test_cm_results}

    def configure_optimizers(self):
        # Uses the notebook-level `learning_rate`; weight decay fixed at 1e-07.
        opt = torch.optim.AdamW(self.parameters(), lr=learning_rate, weight_decay=1e-07)
        return opt

In [None]:
print('Experiment name', exp_name)
# Only read here to report split sizes; the datamodules re-read what they need.
df_train = pd.read_pickle(train_pkl)
df_val = pd.read_pickle(val_pkl)
df_test = pd.read_pickle(kather_int_test_pkl)
print('Size of Training/Validation/Testset: {}/{}/{}'.format(len(df_train), len(df_val), len(df_test)))

# Load the checkpoint trained on NCT-CRC-HE-100K.
ckpt = '/.../checkpoints/epoch=99-step=66699.ckpt'
# load_from_checkpoint is a classmethod that builds its own SlideModel; calling it on a
# throw-away instance would run __init__ (and the NVIDIA checkpoint load) twice.
# NOTE(review): strict=False silently skips missing/unexpected keys — confirm the
# checkpoint matches SlideModel, otherwise some weights stay at their initial values.
m0 = SlideModel.load_from_checkpoint(ckpt, strict=False)
m0.eval()
m0.freeze()

# Evaluate the pre-trained model on the external test set (kather_ext).
dm1 = SlideDataModule(train_pkl, val_pkl, kather_ext, tile_column, label_column, batch_size, train_tfms, test_tfms)
dm1.setup('test')
print('Size of external test set: {}'.format(len(dm1.test_df)))
trainer = pl.Trainer(gpus=gpu, callbacks=False)
metrics = trainer.test(m0, datamodule=dm1, verbose=True)
metrics

In [None]:
# Fine-tuning datamodule: DACHS train/val splits, NCT-CRC-HE-100K internal test split.
dm = SlideDataModule(
    train=finetune_train_pkl,
    val=finetune_val_pkl,
    test=kather_int_test_pkl,
    tile_column=tile_column,
    label_column=label_column,
    batch_size=batch_size,
    train_tfms=train_tfms,
    test_tfms=test_tfms,
)
dm.setup(stage='fit')

In [None]:
m = SlideModel_Finetune(m0)

logg_path = os.path.join(os.getcwd(), "logs", exp_name)
logger = pl.loggers.TensorBoardLogger(logg_path, name=None)
monitor = "valid/loss"
monitor_mode = "min"
lr_monitor = pl.callbacks.lr_monitor.LearningRateMonitor()
# Keep the best epoch by validation loss plus last.ckpt. Without `monitor`,
# ModelCheckpoint only keeps the most recent epoch, so monitor/monitor_mode must be passed.
checkpoint_callback = pl.callbacks.ModelCheckpoint(monitor=monitor, mode=monitor_mode, save_last=True)
trainer = pl.Trainer(
    gpus=gpu,
    callbacks=[checkpoint_callback, lr_monitor],
    logger=logger,
    num_sanity_val_steps=0,
    max_epochs=epochs,
)

In [None]:
# Fine-tune `m` on the DACHS splits held by `dm`, using the trainer configured above.
trainer.fit(model=m, datamodule=dm)