In [1]:
import os
import torch
from torch import nn
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset, Subset
from torchvision.datasets import ImageFolder
import pytorch_lightning as pl
from torchvision.models import resnet50
from sklearn.model_selection import StratifiedShuffleSplit
from dataclasses import dataclass
from torch.multiprocessing import Pool, set_start_method
from pytorch_lightning.loggers import TensorBoardLogger
from torch.nn.functional import softmax
from sklearn.metrics import precision_recall_fscore_support, roc_auc_score
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torchstain
from PIL import Image
import matplotlib.pyplot as plt 

In [2]:
import sys
sys.path.append("../")
import warnings
warnings.filterwarnings("ignore", category=UserWarning, module="torchstain")

In [3]:
# set_start_method("spawn")

In [4]:
# %load_ext autoreload
# %autoreload 2
from src.utils import set_random_seed
from src.configs import ConfigsClass
from tumor_classifier import TileDataset, MacenkoNormalizerTransform

In [5]:
@dataclass
class Configs:
    DATA_ROOT = '../data/tumor_labeled_tiles'
    LOG_DIR = os.path.join(os.getcwd(), 'lightning_logs')
    COLOR_NORM_REF_IMG = ConfigsClass.COLOR_NORM_REF_IMG
    TRAINED_MODEL_PATH = os.path.join(LOG_DIR, "tumor_classifier_resnet50_10_epochs.ckpt")
    DEVICE = 'gpu'
    NUM_DEVICES = 1  # in jupyter multiple devices is not supported
    INIT_LR = 1e-4
    BATCH_SIZE = 16
    NUM_WORKERS = 16
    RANDOM_SEED = 123
    NUM_CLASSES = 3
    EXPERIMENT = 'tc_V1'
    TUMOR_CLASS = 'TUMSTU'
    NON_TUMOR_CLASSES = ['STRMUS', 'ADIMUC']
    TUMOR_IND = None
    NON_TUMOR_INDS = None

In [6]:
set_random_seed(Configs.RANDOM_SEED)

In [7]:
dataset = ImageFolder(Configs.DATA_ROOT)
Configs.TUMOR_IND = dataset.class_to_idx[Configs.TUMOR_CLASS]
Configs.NON_TUMOR_INDS = [dataset.class_to_idx[class_name] for class_name in Configs.NON_TUMOR_CLASSES]

In [8]:
train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),  # reverse 50% of images
        transforms.RandomVerticalFlip(),  # reverse 50% of images  
        transforms.Resize(224),            
        transforms.CenterCrop(224),         
        transforms.ToTensor(),
        MacenkoNormalizerTransform(Configs.COLOR_NORM_REF_IMG),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225])
    ])

valid_transform = transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        MacenkoNormalizerTransform(Configs.COLOR_NORM_REF_IMG),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225])
    ])

In [9]:
# train test split
train_test_split = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=Configs.RANDOM_SEED)
train_inds, test_inds = next(train_test_split.split(dataset, y=dataset.targets))
# train validation split
train_valid_split = StratifiedShuffleSplit(n_splits=1, test_size=0.1, random_state=Configs.RANDOM_SEED)
train_inds, valid_inds = next(train_valid_split.split(train_inds, y=[dataset.targets[i] for i in train_inds]))

In [10]:
train_dataset = TileDataset(Subset(dataset, train_inds), transform=train_transform)
valid_dataset = TileDataset(Subset(dataset, valid_inds), transform=valid_transform)
test_dataset = TileDataset(Subset(dataset, test_inds), transform=valid_transform)
len(train_dataset), len(valid_dataset), len(test_dataset)

(8622, 959, 2396)

In [11]:
train_loader = DataLoader(train_dataset, batch_size=Configs.BATCH_SIZE, shuffle=True, num_workers=Configs.NUM_WORKERS)
valid_loader = DataLoader(valid_dataset, batch_size=Configs.BATCH_SIZE, shuffle=False, num_workers=Configs.NUM_WORKERS)
test_loader = DataLoader(test_dataset, batch_size=Configs.BATCH_SIZE, shuffle=False, num_workers=Configs.NUM_WORKERS)

In [12]:
class TumorClassifier(pl.LightningModule):
    def __init__(self, num_classes, tumor_class_ind, learning_rate):
        super().__init__()
        assert tumor_class_ind is not None
        self.tumor_class_ind = tumor_class_ind
        self.learning_rate = learning_rate
        backbone = resnet50(weights="IMAGENET1K_V2")
        num_filters = backbone.fc.in_features
        layers = list(backbone.children())[:-1]
        layers.append(nn.Flatten())
        layers.append(nn.Linear(num_filters, num_classes))
        self.model = nn.Sequential(*layers)
        self.save_hyperparameters()

    def forward(self, x):
        return self.model(x)
    
    def loss(self, scores, targets):
        return F.cross_entropy(scores, targets)
    
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        scheduler = ReduceLROnPlateau(optimizer, factor=0.1, patience=2)
        return {'optimizer': optimizer, 'lr_scheduler': scheduler, 'monitor': 'val_loss'}
    
    def general_loop(self, batch, batch_idx):
        x, y = batch
        scores = self.forward(x)
        loss = self.loss(scores, y)
        return loss, scores, y

    def training_step(self, batch, batch_idx):
        train_loss, scores, y =  self.general_loop(batch, batch_idx)
        self.log("train_loss", train_loss, on_step=True, on_epoch=True)
        return {"loss": train_loss}

    def validation_step(self, batch, batch_idx):
        val_loss, scores, y =  self.general_loop(batch, batch_idx)
        self.log("val_loss", val_loss, on_step=False, on_epoch=True)
        return {"scores": scores, "y": y}
    
    def test_step(self, batch, batch_idx):
        test_loss, scores, y =  self.general_loop(batch, batch_idx)
        self.log("test_loss", test_loss, on_step=False, on_epoch=True)
        return {"scores": scores, "y": y}
    
    def log_epoch_level_metrics(self, outputs, dataset_str):
        scores = torch.concat([out["scores"] for out in outputs])
        logits = softmax(scores, dim=1)
        tumor_prob = logits[:, self.tumor_class_ind].cpu().numpy()
        y_pred = (torch.argmax(scores, dim=1)==self.tumor_class_ind).int().cpu().numpy()
        y = (torch.concat([out["y"] for out in outputs])==self.tumor_class_ind).int().cpu().numpy()
        precision, recall, f1, _ = precision_recall_fscore_support(y, y_pred, average='binary')
        auc = roc_auc_score(y, tumor_prob)
        self.log(f"{dataset_str}_precision", precision)
        self.log(f"{dataset_str}_recall", recall)
        self.log(f"{dataset_str}_f1", f1)
        self.log(f"{dataset_str}_auc", auc)
        
    def validation_epoch_end(self, outputs):
        self.log_epoch_level_metrics(outputs, dataset_str='valid')
    
    def test_epoch_end(self, outputs):
        self.log_epoch_level_metrics(outputs, dataset_str='test')
    
    def predict_step(self, batch, batch_idx):
        x, y = batch
        return self(x)

In [13]:
model = TumorClassifier(Configs.NUM_CLASSES, Configs.TUMOR_IND, Configs.INIT_LR)
logger= TensorBoardLogger(f"logs/{Configs.EXPERIMENT}")
trainer = pl.Trainer(devices=Configs.NUM_DEVICES, accelerator=Configs.DEVICE, 
                     deterministic=True, 
                     check_val_every_n_epoch=1,
                     default_root_dir=Configs.LOG_DIR,
                     enable_checkpointing=True,
                     logger=True,
                     num_sanity_val_steps=2,
                     max_epochs=10)
                     
#                      limit_train_batches=10)
#                      limit_val_batches=5
#                      fast_dev_run=2)
#                      profiler=True)

  rank_zero_warn(
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [14]:
# trainer.fit(model, train_loader, valid_loader,
#             ckpt_path=None)

In [15]:
# trainer.save_checkpoint(Configs.TRAINED_MODEL_PATH)

In [None]:
model = TumorClassifier.load_from_checkpoint(Configs.TRAINED_MODEL_PATH,
                                             num_classes=Configs.NUM_CLASSES,
                                             tumor_class_ind=Configs.TUMOR_IND,
                                             learning_rate=Configs.INIT_LR)

In [22]:
trainer.save_checkpoint("/home/sharonpe/microsatellite-instability-classification/models/tumor_classifier_resnet50_10_epochs_V1.ckpt")

In [23]:
TumorClassifier.load_from_checkpoint("/home/sharonpe/microsatellite-instability-classification/models/tumor_classifier_resnet50_10_epochs_V1.ckpt")

TumorClassifier(
  (model): Sequential(
    (0): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (4): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (downsample): Sequential(
          (0): Conv2

In [32]:
# trainer.test(model, dataloaders=test_loader)

In [21]:
trainer.predict(model, valid_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Predicting: 0it [00:00, ?it/s]

[tensor([[ -5.1957,   8.9623,  -9.5944],
         [  7.6110,  -5.2342,  -5.7597],
         [ -5.0368,   7.2854,  -5.6657],
         [  7.2404,  -4.7327,  -5.1989],
         [  7.1492,  -4.5806,  -5.5456],
         [ -1.0326,   2.7279,  -4.7821],
         [ -3.3674,   6.1981,  -5.9174],
         [ -5.4135,   5.8016,  -3.4038],
         [ -6.0110,  -4.2572,   6.9915],
         [-10.4216,   8.1384,  -3.3810],
         [ -6.1634,   4.7349,  -1.8041],
         [ 11.4824,  -8.3061,  -7.7308],
         [ -4.1327,   6.4059,  -5.7957],
         [ -8.3006,   7.3237,  -3.7305],
         [ -6.9310,   7.7202,  -4.8088],
         [ -7.9247,   6.6957,  -3.6528]]),
 tensor([[ 11.2224,  -8.1687,  -7.3587],
         [ -7.7341,   0.2064,   3.4276],
         [ -6.7837,  -3.8305,   6.9998],
         [ -1.7673,   4.4989,  -5.2018],
         [  5.9979,  -1.8032,  -6.8798],
         [  7.9663,  -5.5291,  -5.2802],
         [  5.1906,  -0.3227,  -8.0683],
         [  7.7734,  -3.4486,  -7.4926],
         [ 15.

In [36]:
from pytorch_lightning.callbacks import BasePredictionWriter


class CustomWriter(BasePredictionWriter):

    def __init__(self, output_dir, write_interval):
        super().__init__(write_interval)
        self.output_dir = output_dir

    def write_on_epoch_end(self, trainer, pl_module, predictions, batch_indices):
        # this will create N (num processes) files in `output_dir` each containing
        # the predictions of it's respective rank
#         torch.save(predictions, os.path.join(self.output_dir, f"predictions_{trainer.global_rank}.pt"))

#         # optionally, you can also save `batch_indices` to get the information about the data index
#         # from your prediction data
#         torch.save(batch_indices, os.path.join(self.output_dir, f"batch_indices_{trainer.global_rank}.pt"))
        print(batch_indices)


# or you can set `writer_interval="batch"` and override `write_on_batch_end` to save
# predictions at batch level
pred_writer = CustomWriter(output_dir="pred_path", write_interval="epoch")
trainer = pl.Trainer(accelerator="gpu", devices=1, callbacks=[pred_writer])
trainer.predict(model, valid_loader, return_predictions=False)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Predicting: 0it [00:00, ?it/s]

[[[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15], [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31], [32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47], [48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63], [64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79], [80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95], [96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111], [112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127], [128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143], [144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159], [160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175], [176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191], [192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207], [208, 209, 210, 211, 212, 213, 214, 215, 2