In [1]:
import matplotlib.pyplot as plt
from einops import rearrange

import torch
import torch.nn.functional as F


from monai.transforms import ScaleIntensity
from monai.config import print_config

from scripts.train import LitUnet
from scripts.data import CentreDataModule
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning import seed_everything

%load_ext autoreload
%autoreload 2

from scripts.utils import *

#print_config()


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
seed_everything(42, workers=True)

model = LitUnet(lr=1e-3)

Global seed set to 42


In [3]:
transform =  ScaleIntensity(minv=0.0, maxv=1.0, channel_wise=True)
dm = CentreDataModule("A", split_ratio=0.7, load_transform=transform, batch_size=8)
# Training and Validation
trainer = pl.Trainer(max_epochs=50, 
                     deterministic=True, 
                     logger=True,
                     log_every_n_steps=1, 
                     enable_model_summary=False,
                     callbacks=[EarlyStopping('val_loss', patience=2)], 
                     fast_dev_run=False)
trainer.fit(model, datamodule=dm)
trainer.save_checkpoint("checkpoints/benchmark_deterministic.ckpt")

  rank_zero_warn(
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(


total number of samples: 95, train samples: 66, Validation: 29


Loading Images/Labels:: 100%|██████████| 66/66 [00:02<00:00, 32.02File/s]
Loading Images/Labels:: 100%|██████████| 66/66 [00:00<00:00, 69.18File/s]
Loading Images/Labels:: 100%|██████████| 29/29 [00:00<00:00, 33.69File/s]
Loading Images/Labels:: 100%|██████████| 29/29 [00:00<00:00, 70.76File/s]
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Epoch 10: 100%|██████████| 195/195 [00:06<00:00, 29.40it/s, v_num=39260, val_loss=0.0591, train_loss=0.0395]


In [4]:
model = LitUnet.load_from_checkpoint("checkpoints/benchmark_deterministic.ckpt")

In [32]:
trainer.test(ckpt_path="best", datamodule=dm)

Restoring states from the checkpoint path at /home/ids/mahdi-22/M-M/lightning_logs/version_39260/checkpoints/epoch=10-step=2145-v1.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from the checkpoint at /home/ids/mahdi-22/M-M/lightning_logs/version_39260/checkpoints/epoch=10-step=2145-v1.ckpt


Testing DataLoader 5: 100%|██████████| 50/50 [00:18<00:00,  2.64it/s]


[{'test_loss/dataloader_idx_0': 0.03938731258060311},
 {'test_loss/dataloader_idx_1': 0.05907154347426737},
 {'test_loss/dataloader_idx_2': 0.07059815582702761},
 {'test_loss/dataloader_idx_3': 0.05236804174457443},
 {'test_loss/dataloader_idx_4': 0.06769312220450134},
 {'test_loss/dataloader_idx_5': 0.09842617480538117}]

In [26]:
results = trainer.model.results

In [27]:
results[["Centre", "batch_idx"]] = results.iloc[:, :2].astype(int)

In [28]:
results

Unnamed: 0,Centre,batch_idx,Dice_BG_ED,Dice_LV_ED,Dice_MYO_ED,Dice_RV_ED,Dice_BG_ES,Dice_LV_ES,Dice_MYO_ES,Dice_RV_ES,IoU_BG_ED,IoU_LV_ED,IoU_MYO_ED,IoU_RV_ED,IoU_BG_ES,IoU_LV_ES,IoU_MYO_ES,IoU_RV_ES
0,0,0,0.988747,0.831898,0.799259,0.905956,0.977744,0.712179,0.665638,0.828080,0.990050,0.771506,0.790560,0.863945,0.980297,0.628009,0.653658,0.760477
1,0,1,0.996239,0.966309,0.876440,0.948900,0.992506,0.934813,0.780056,0.902769,0.994879,0.828447,0.837788,0.933456,0.989810,0.707136,0.720856,0.875216
2,0,2,0.994562,0.944106,0.875126,0.913029,0.989184,0.894130,0.777977,0.839975,0.993235,0.875850,0.797721,0.884424,0.986561,0.779122,0.663508,0.792795
3,0,3,0.991603,0.958498,0.851777,0.895070,0.983346,0.920303,0.741821,0.810069,0.993229,0.966327,0.838127,0.836163,0.986549,0.934847,0.721359,0.718454
4,0,4,0.992211,0.953330,0.804658,0.912997,0.984543,0.910821,0.673161,0.839922,0.996332,0.874359,0.842740,0.897663,0.992691,0.776766,0.728220,0.814327
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
312,5,45,0.985826,0.842664,0.710956,0.876743,0.972048,0.728107,0.551537,0.780536,0.981116,0.727385,0.515544,0.670845,0.962932,0.571567,0.347295,0.504716
313,5,46,0.978904,0.844623,0.760518,0.784623,0.958679,0.731036,0.613577,0.645581,0.981182,0.721175,0.646057,0.634754,0.963059,0.563936,0.477167,0.464938
314,5,47,0.989500,0.933935,0.756554,0.916756,0.979219,0.876058,0.608433,0.846307,0.991364,0.917394,0.768554,0.847928,0.982876,0.847395,0.624107,0.736002
315,5,48,0.992171,0.943271,0.752622,0.911907,0.984463,0.892633,0.603363,0.838078,0.988841,0.894372,0.752053,0.757256,0.977928,0.808926,0.602632,0.609342


In [29]:
results.to_csv("results/results.csv", index=False)

In [30]:
results = pd.read_csv("results/results.csv")