In [1]:
import os
import torch
from torch import nn
import torch.nn.functional as F
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset, Subset
from torchvision.datasets import ImageFolder
import pytorch_lightning as pl
from torchvision.models import resnet50
from sklearn.model_selection import StratifiedShuffleSplit
from dataclasses import dataclass
from torch.multiprocessing import Pool, set_start_method
from pytorch_lightning.loggers import TensorBoardLogger
from torch.nn.functional import softmax
from sklearn.metrics import precision_recall_fscore_support, roc_auc_score
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torchstain
from PIL import Image
import matplotlib.pyplot as plt 

In [2]:
import sys
sys.path.append("../")
import warnings
warnings.filterwarnings("ignore", category=UserWarning, module="torchstain")

In [3]:
set_start_method("spawn")

In [4]:
# %load_ext autoreload
# %autoreload 2
from src.utils import set_random_seed
from src.configs import ConfigsClass
from tumor_classifier import TileDataset, MacenkoNormalizerTransform

In [5]:
@dataclass
class Configs:
    DATA_ROOT = '../data/tumor_labeled_tiles'
    LOG_DIR = os.getcwd()
    COLOR_NORM_REF_IMG = ConfigsClass.COLOR_NORM_REF_IMG
    DEVICE = 'gpu'
    NUM_DEVICES = 1  # in jupyter multiple devices is not supported
    INIT_LR = 1e-4
    BATCH_SIZE = 16
    RANDOM_SEED = 123
    NUM_CLASSES = 3
    EXPERIMENT = 'tc_V1'
    TUMOR_CLASS = 'TUMSTU'
    NON_TUMOR_CLASSES = ['STRMUS', 'ADIMUC']
    TUMOR_IND = None
    NON_TUMOR_INDS = None

In [6]:
set_random_seed(Configs.RANDOM_SEED)

In [7]:
dataset = ImageFolder(Configs.DATA_ROOT)
Configs.TUMOR_IND = dataset.class_to_idx[Configs.TUMOR_CLASS]
Configs.NON_TUMOR_INDS = [dataset.class_to_idx[class_name] for class_name in Configs.NON_TUMOR_CLASSES]

In [8]:
train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),  # reverse 50% of images
        transforms.RandomVerticalFlip(),  # reverse 50% of images  
        transforms.Resize(224),            
        transforms.CenterCrop(224),         
        transforms.ToTensor(),
        MacenkoNormalizerTransform(Configs.COLOR_NORM_REF_IMG),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225])
    ])

valid_transform = transforms.Compose([
        transforms.Resize(224),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        MacenkoNormalizerTransform(Configs.COLOR_NORM_REF_IMG),
        transforms.Normalize([0.485, 0.456, 0.406],
                             [0.229, 0.224, 0.225])
    ])

In [9]:
# train test split
train_test_split = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=Configs.RANDOM_SEED)
train_inds, test_inds = next(train_test_split.split(dataset, y=dataset.targets))
# train validation split
train_valid_split = StratifiedShuffleSplit(n_splits=1, test_size=0.1, random_state=Configs.RANDOM_SEED)
train_inds, valid_inds = next(train_valid_split.split(train_inds, y=[dataset.targets[i] for i in train_inds]))

In [10]:
train_dataset = TileDataset(Subset(dataset, train_inds), transform=train_transform)
valid_dataset = TileDataset(Subset(dataset, valid_inds), transform=valid_transform)
test_dataset = TileDataset(Subset(dataset, test_inds), transform=valid_transform)
len(train_dataset), len(valid_dataset), len(test_dataset)

(8622, 959, 2396)

In [11]:
train_loader = DataLoader(train_dataset, batch_size=Configs.BATCH_SIZE, shuffle=True, num_workers=16)
valid_loader = DataLoader(valid_dataset, batch_size=Configs.BATCH_SIZE, shuffle=False, num_workers=16)
test_loader = DataLoader(test_dataset, batch_size=Configs.BATCH_SIZE, shuffle=False, num_workers=16)

In [12]:
class TumorClassifier(pl.LightningModule):
    def __init__(self, num_classes):
        super().__init__()
        backbone = resnet50(weights="IMAGENET1K_V2")
        num_filters = backbone.fc.in_features
        layers = list(backbone.children())[:-1]
        layers.append(nn.Flatten())
        layers.append(nn.Linear(num_filters, num_classes))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)
    
    def loss(self, scores, targets):
        return F.cross_entropy(scores, targets)
    
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=Configs.INIT_LR)
        scheduler = ReduceLROnPlateau(optimizer, factor=0.1, patience=2)
        return {'optimizer': optimizer, 'lr_scheduler': scheduler, 'monitor': 'val_loss'}
    
    def general_loop(self, batch, batch_idx):
        x, y = batch
        scores = self.forward(x)
        loss = self.loss(scores, y)
        return loss, scores, y

    def training_step(self, batch, batch_idx):
        train_loss, scores, y =  self.general_loop(batch, batch_idx)
        self.log("train_loss", train_loss, on_step=True, on_epoch=True)
        return {"loss": train_loss}

    def validation_step(self, batch, batch_idx):
        val_loss, scores, y =  self.general_loop(batch, batch_idx)
        self.log("val_loss", val_loss, on_step=False, on_epoch=True)
        return {"scores": scores, "y": y}
    
    def test_step(self, batch, batch_idx):
        test_loss, scores, y =  self.general_loop(batch, batch_idx)
        self.log("test_loss", test_loss, on_step=False, on_epoch=True)
        return {"scores": scores, "y": y}
    
    def log_epoch_level_metrics(self, outputs, dataset_str):
        scores = torch.concat([out["scores"] for out in outputs])
        logits = softmax(scores, dim=1)
        tumor_prob = logits[:, Configs.TUMOR_IND].cpu().numpy()
        y_pred = (torch.argmax(scores, dim=1)==Configs.TUMOR_IND).int().cpu().numpy()
        y = (torch.concat([out["y"] for out in outputs])==Configs.TUMOR_IND).int().cpu().numpy()
        precision, recall, f1, _ = precision_recall_fscore_support(y, y_pred, average='binary')
        auc = roc_auc_score(y, tumor_prob)
        self.log(f"{dataset_str}_precision", precision)
        self.log(f"{dataset_str}_recall", recall)
        self.log(f"{dataset_str}_f1", f1)
        self.log(f"{dataset_str}_auc", auc)
        
    def validation_epoch_end(self, outputs):
        self.log_epoch_level_metrics(outputs, dataset_str='valid')
    
    def test_epoch_end(self, outputs):
        self.log_epoch_level_metrics(outputs, dataset_str='test')

In [13]:
model = TumorClassifier(Configs.NUM_CLASSES)
logger= TensorBoardLogger(f"logs/{Configs.EXPERIMENT}")
trainer = pl.Trainer(devices=Configs.NUM_DEVICES, accelerator=Configs.DEVICE, 
                     deterministic=True, 
                     check_val_every_n_epoch=1,
                     default_root_dir=Configs.LOG_DIR,
                     enable_checkpointing=True,
                     logger=True,
                     num_sanity_val_steps=2,
                     max_epochs=10)
                     
#                      limit_train_batches=10)
#                      limit_val_batches=5
#                      fast_dev_run=2)
#                      profiler=True)

  rank_zero_warn(
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [14]:
trainer.fit(model, train_loader, valid_loader,
            ckpt_path=None)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name  | Type       | Params
-------------------------------------
0 | model | Sequential | 23.5 M
-------------------------------------
23.5 M    Trainable params
0         Non-trainable params
23.5 M    Total params
94.057    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=10` reached.


In [None]:
model = TumorClassifier.load_from_checkpoint('lightning_logs/version_4/checkpoints/epoch=9-step=5390.ckpt',
                                            num_classes=Configs.NUM_CLASSES)

In [15]:
trainer.test(model, dataloaders=test_loader)

  rank_zero_warn(
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Testing: 0it [00:00, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_auc            0.9996671365914787
         test_f1            0.9452867501647989
        test_loss           0.10564085096120834
     test_precision                 1.0
       test_recall                0.89625
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_loss': 0.10564085096120834,
  'test_precision': 1.0,
  'test_recall': 0.89625,
  'test_f1': 0.9452867501647989,
  'test_auc': 0.9996671365914787}]

In [15]:
# y_pred = trainer.pre(model, dataloaders=test_loader)

a
