In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys

# Put the parent directory on the import path so the project modules imported
# below (e.g. `ml_core`) can be resolved -- presumably they live one level up.
# The membership check keeps the cell idempotent: re-running it no longer
# appends a duplicate "../" entry to sys.path each time.
if "../" not in sys.path:
    sys.path.append("../")

In [4]:
from pathlib import Path
from functools import partial


from ml_core.preprocessing.dataset import create_dataloader
from ml_core.modeling.unet import UNet


import torch
from torch.optim import Adam
from torch.utils.data import DataLoader, ConcatDataset
import albumentations as A
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping

from matplotlib import pyplot as plt
import numpy as np

## Define dataloaders

In [5]:
# Patch size used to pick the `patch_<size>` data directory further down.
patch_size = 256

# Target class and training dataset for this run. Both values are interpolated
# into the data paths, the output directory and the saved checkpoint name.
class_name = "Glomerulus"
dataset = "MultiStain"

# Other (class_name, dataset) combinations -- swap the two values above:
#   "Glomerulus", "Collage"
#   "Artery",     "MultiStain"
#   "Artery",     "Collage"
#   "Tubules",    "MultiStain"
#   "Tubules",    "Collage"

In [6]:
# Training-time augmentation pipeline, assembled from three groups that are
# applied in order: flips/rotation, one colour perturbation, elastic warp.
_flip_and_rotate = [
    A.VerticalFlip(p=0.5),
    A.HorizontalFlip(p=0.5),
    A.Rotate(),
]

# Exactly one of these two colour transforms is picked when OneOf fires.
_colour_perturbation = A.OneOf([A.HueSaturationValue(), A.ToGray()])

# Elastic deformation is always applied (p=1); sigma and alpha_affine are
# expressed as fractions of alpha.
_elastic_alpha = 120
_elastic_warp = A.ElasticTransform(
    p=1,
    alpha=_elastic_alpha,
    sigma=_elastic_alpha * 0.05,
    alpha_affine=_elastic_alpha * 0.03,
)

aug_transform = A.Compose([*_flip_and_rotate, _colour_perturbation, _elastic_warp])

In [7]:
# Train/validation HDF5 files come from the selected dataset and patch size.
data_root = Path(f"../../data/{dataset}/hdf5_data/patch_{patch_size}/")
train_fname, val_fname = (
    data_root / f"{class_name}_train.h5",
    data_root / f"{class_name}_val.h5",
)

# The test split is always read from the AAPI dataset, regardless of `dataset`.
# It is kept as a one-element list of paths, which is what gets passed to
# create_dataloader for the test loader.
test_data_root = Path(f"../../data/AAPI/hdf5_data/patch_{patch_size}/")
test_fname = [test_data_root / f"{class_name}_test.h5"]

In [8]:
# Training loader: augmented samples, otherwise create_dataloader's defaults.
train_dataloader, train_dataset = create_dataloader(
    train_fname,
    transform=aug_transform,
    return_dataset=True,
)

# Evaluation loaders share the same settings: no augmentation, no shuffling,
# batches of 64.
val_dataloader, val_dataset = create_dataloader(
    val_fname,
    batch_size=64,
    shuffle=False,
    return_dataset=True,
)
test_dataloader, test_dataset = create_dataloader(
    test_fname,
    batch_size=64,
    shuffle=False,
    return_dataset=True,
)

## Training starts here

In [9]:
# Both callbacks watch the same quantity: validation F1, higher is better.
_watch_val_f1 = dict(monitor="val/f1_score", mode="max")

# Keep only the single best checkpoint according to the monitored metric.
checkpoint_callback = ModelCheckpoint(save_top_k=1, **_watch_val_f1)

# Stop once the metric has not improved for 10 consecutive validation checks.
early_stopping = EarlyStopping(patience=10, min_delta=0, **_watch_val_f1)

In [10]:
# Root directory handed to the Trainer for this dataset/class combination.
_output_root = Path(f"../../output/{dataset}/{class_name}")

# NOTE(review): max_epochs equals the early-stopping patience (10), so early
# stopping is unlikely to trigger before training ends -- confirm intended.
# NOTE(review): passing a ModelCheckpoint instance via `checkpoint_callback`
# is Lightning-version dependent; newer releases expect it inside `callbacks`.
# Confirm against the installed version before upgrading.
_trainer_options = dict(
    gpus=1,
    default_root_dir=_output_root,
    max_epochs=10,
    callbacks=[early_stopping],
    checkpoint_callback=checkpoint_callback,
    log_every_n_steps=1,
    flush_logs_every_n_steps=50,
)
trainer = Trainer(**_trainer_options)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


In [11]:
# Optimizer factory: Adam with the learning rate pre-bound. The model receives
# it wrapped in a list.
_adam_factory = partial(Adam, lr=1e-3)

# U-Net configured for 3-channel input and 2 output classes.
unet = UNet(
    in_channels=3,
    n_classes=2,
    depth=5,  # this depth may need to be modified
    wf=3,
    padding=True,
    batch_norm=True,
    up_mode="upconv",
    optimizer=[_adam_factory],
    edge_weight=1.2,
)

In [12]:
# Run training: fit `unet` on the training loader, using the validation loader
# for the metrics monitored by the checkpoint / early-stopping callbacks.
trainer.fit(unet, train_dataloader, val_dataloader)


  | Name      | Type             | Params
-----------------------------------------------
0 | down_path | ModuleList       | 296 K 
1 | up_path   | ModuleList       | 191 K 
2 | last      | Conv2d           | 18    
3 | criterion | CrossEntropyLoss | 0     


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…




1

In [13]:
def fetch_best_model_path(trainer):
    """Return the path of the best checkpoint produced by ``trainer``.

    The path recorded by the trainer's checkpoint callback
    (``trainer.checkpoint_callback.best_model_path``) is used when it is
    available and points at an existing file. Otherwise the function falls
    back to scanning ``<trainer.logger.log_dir>/checkpoints`` for ``*.ckpt``
    files and returns the most recently modified one, so the result is
    deterministic even if several checkpoints are present.

    Args:
        trainer: Object exposing ``logger.log_dir`` and, optionally,
            ``checkpoint_callback.best_model_path``.

    Returns:
        str: Path to the selected checkpoint file.

    Raises:
        FileNotFoundError: If no checkpoint can be located (the previous
            implementation surfaced this as a bare ``StopIteration``).
    """
    # Prefer what the checkpoint callback itself reports as the best model.
    # getattr keeps this safe when the attribute is missing or is a bool.
    checkpoint_cb = getattr(trainer, "checkpoint_callback", None)
    recorded = getattr(checkpoint_cb, "best_model_path", None)
    if recorded and Path(recorded).is_file():
        return str(recorded)

    # Fallback: look inside the logger's checkpoint directory.
    checkpoint_dir = Path(trainer.logger.log_dir) / "checkpoints"
    candidates = sorted(
        checkpoint_dir.glob("*.ckpt"),
        key=lambda ckpt: (ckpt.stat().st_mtime, ckpt.name),
    )
    if not candidates:
        raise FileNotFoundError(f"No '*.ckpt' files found in {checkpoint_dir}")
    return str(candidates[-1])

In [14]:
# Rebuild a UNet from the checkpoint file located by fetch_best_model_path.
best_model = UNet.load_from_checkpoint(fetch_best_model_path(trainer))

In [15]:
# Evaluate the reloaded model on the test loader (built from the AAPI test file).
trainer.test(best_model, test_dataloader)

HBox(children=(HTML(value='Testing'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=…

--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test/f1_score': tensor(0.2861, device='cuda:0'),
 'test/loss': tensor(1.4463, device='cuda:0'),
 'test/precision': tensor(0.1745, device='cuda:0'),
 'test/recall': tensor(0.8863, device='cuda:0'),
 'train/f1_score': tensor(0.7494, device='cuda:0'),
 'train/loss': tensor(0.3391, device='cuda:0'),
 'train/precision': tensor(0.7325, device='cuda:0'),
 'train/recall': tensor(0.7670, device='cuda:0'),
 'val/f1_score': tensor(0.6242, device='cuda:0'),
 'val/loss': tensor(0.5620, device='cuda:0'),
 'val/precision': tensor(0.5116, device='cuda:0'),
 'val/recall': tensor(0.8592, device='cuda:0')}
--------------------------------------------------------------------------------



[{'train/loss': 0.339130699634552,
  'train/f1_score': 0.7493833303451538,
  'train/precision': 0.7325277924537659,
  'train/recall': 0.7670329809188843,
  'val/loss': 0.5620375275611877,
  'val/f1_score': 0.6241845488548279,
  'val/precision': 0.5115647912025452,
  'val/recall': 0.8592203259468079,
  'test/loss': 1.4463034868240356,
  'test/f1_score': 0.2861465513706207,
  'test/precision': 0.17450568079948425,
  'test/recall': 0.8863489627838135}]

In [17]:
# Echo the active configuration so the saved file can be matched to this run.
print(f"Dataset: {dataset}; Class: {class_name}.")

# Write a checkpoint next to the per-class output directory.
# NOTE(review): save_checkpoint stores whatever state the trainer currently
# holds -- confirm this is the best model rather than the last-epoch weights.
_best_model_ckpt = f"../../output/{dataset}/{class_name}_best_model.ckpt"
trainer.save_checkpoint(_best_model_ckpt)

Dataset: MultiStain; Class: Glomerulus.
