In [1]:
import numpy as np
import os
import torch
import argparse
from glob import glob
from torch import nn
from torch.nn import functional as F
from torch.optim.lr_scheduler import StepLR, ReduceLROnPlateau
from torchvision import datasets, transforms
from torchvision.utils import make_grid
from CNNs.unet import UNet
from Utils.transforms import My_transforms
import Utils.datasets as my_datasets
from torch.optim import Adam, SGD

from Utils.loss import SoftDiceLoss
from Utils.loss import DiceLoss
from Utils.loss import DiceLoss_chavg
from Utils.loss import DiceLoss_weighs
from Utils.loss import CombinedLoss
from Utils.Metrics import DiceMetric_weighs

import time
import multiprocessing as mp

%load_ext tensorboard

import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, Callback
# from pytorch_lightning.loggers import NeptuneLogger

print(f"Pytorch Lightning Version: {pl.__version__}")
print(f"Torch Version: {torch.__version__}")

Pytorch Lightning Version: 1.3.8
Torch Version: 1.8.1


# Step 1 - Definition of traning and model parameters

In [3]:
#Definitions of map combination

# input_channels=['evalue1', 'FA', 'RD', 'MD', 'T1']
# input_channels=['evalue1', 'FA', 'RD', 'MD']
# input_channels=['T1']
# input_channels=['FA', 'RD']
# input_channels=['FA', 'T1']
# input_channels=['FA', 'MD']
# input_channels=['FA', 'evalue1']
# input_channels=['RD', 'MD']
# input_channels=['FA']
# input_channels=['MD']
# input_channels=['RD']
# input_channels=['evalue1']
# input_channels=['MD', 'T1']
input_channels=['evalue1', 'T1']

output_channels=['staple_oh']


w = 1
lossweighs = [[0, w]]

loss_funcs = ['DiceLoss_weighs']
func_weights = [1]

metricweight = [0,1]
train_metric = 'DiceMetric_weighs'
val_metric = 'DiceMetric_weighs'

view = "all" # TRAINS ALL DATA VIEWS AUTOMATICALLY
# view = "sagittal"
# view = "coronal"
# view = "axial"

hyperparameters = {"experiment_name": "unet_single_label", #experiment params
                   "description": f"unet_single_label_psz064_{view}_{''.join(input_channels)}",
                   "data_view": view,  
                   "dataset_folder": './Data/Patches/single_label_psz064/',
                   "subjects_list": './training_subjects_randomized.txt',
           
                   "in_channels":input_channels,
                   "masks":output_channels,
# CNN architecture
                   "cnn_architecture": 'unet',
                   "input_size": 32,
                   "n_inchannels": len(input_channels),
                   "n_outchannels": len(output_channels)*2,
                   "init_features": 32,              
#augentation                                     
                   "taug_angle": 5,
                   "taug_flip_prob": 0.5,       
# train_params                 
                   "max_epochs": 2000,
                   "patience": 20,  # patience for early stop
                   "learning_threshold": 0.01, # default = 0.01
                   "batch_size": 256,  # Bigger is faster. limited by RAM or VRAM
#                    "batch_size": 512,
#                    "batch_size": 128,
#                    "batch_size": 64, 
#                    "batch_size": 16, 
                   "split_train_val": 0.2, 
                   
                   "opt_name": "Adam", 
#                    "opt_name": "SGD", 
                   "min_lr": 1e-06,
                   "eps": 1e-05,
                   "monitor": 'val_loss',
                   "lr": 1e-3,  
#                    "lr": 1e-5,  
                   "scheduling_patience_lrepochs": 3, # patience for lr deacay
                   "lr_decay_factor": 0.1, 
                   "lr_decay_policy": 'plateau', # plateau or step
                   "lr_decay_mode": 'min',
                   "lr_decay_threshold_mode" : 'abs', #rel, abs (plateau only)
                   
                   
                   "lossweighs": lossweighs,
                   "func_weights": func_weights,
                   "train_loss_funcs": loss_funcs,
                   "train_metric": train_metric,
                   "train_metricweighs": metricweight,
                   "val_loss_funcs": loss_funcs,
                   "val_metric": val_metric,
                   "val_metricweighs": metricweight,
                   "nworkers": 8,
                   "val_transform": {},
                   
                   
                  }
original_hyperparameters = hyperparameters.copy()
hyperparameters

{'experiment_name': 'unet_single_label',
 'description': 'unet_single_label_psz064_all_evalue1T1',
 'data_view': 'all',
 'dataset_folder': './Data/Patches/single_label_psz064/',
 'subjects_list': './training_subjects_randomized.txt',
 'in_channels': ['evalue1', 'T1'],
 'masks': ['staple_oh'],
 'cnn_architecture': 'unet',
 'input_size': 32,
 'n_inchannels': 2,
 'n_outchannels': 2,
 'init_features': 32,
 'taug_angle': 5,
 'taug_flip_prob': 0.5,
 'max_epochs': 2000,
 'patience': 20,
 'learning_threshold': 0.01,
 'batch_size': 256,
 'split_train_val': 0.2,
 'opt_name': 'Adam',
 'min_lr': 1e-06,
 'eps': 1e-05,
 'monitor': 'val_loss',
 'lr': 0.001,
 'scheduling_patience_lrepochs': 3,
 'lr_decay_factor': 0.1,
 'lr_decay_policy': 'plateau',
 'lr_decay_mode': 'min',
 'lr_decay_threshold_mode': 'abs',
 'lossweighs': [[0, 1]],
 'func_weights': [1],
 'train_loss_funcs': ['DiceLoss_weighs'],
 'train_metric': 'DiceMetric_weighs',
 'train_metricweighs': [0, 1],
 'val_loss_funcs': ['DiceLoss_weighs'],

# Step 2 - Defining modules and classes (PyTorch Lightning Framework)

In [5]:
class MyDataModule(pl.LightningDataModule):
    def __init__(self, hparams):
        super().__init__()
        self.hparams = hparams

    def setup(self, stage=None):
        pass

    def train_dataloader(self):   
        train_dataset = my_datasets.PatchDataSet(subject_list=self.hparams['subjects_list'], root=self.hparams['dataset_folder'] + self.hparams['data_view'] + '/',
                                         channels=self.hparams['in_channels'],
                                         masks=self.hparams['masks'],
                                         transform=self.train_transforms,
                                         valid_split=self.hparams['split_train_val']
                                         )        
        train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=self.hparams['batch_size'],
                                           shuffle=True,
                                           num_workers=self.hparams['nworkers'], pin_memory=False)       
        
        return train_loader

    def val_dataloader(self):
        val_dataset = my_datasets.PatchDataSet(subject_list=self.hparams['subjects_list'], root=self.hparams['dataset_folder'] + self.hparams['data_view'] + '/',
                                         channels=self.hparams['in_channels'],
                                         masks=self.hparams['masks'],
                                         transform=self.val_transforms,
                                         valid_split=self.hparams['split_train_val'],
                                         validation=True
                                         )        
        val_loader = torch.utils.data.DataLoader(dataset=val_dataset,
                                           batch_size=self.hparams['batch_size'],
                                           shuffle=False,
                                           num_workers=self.hparams['nworkers'], pin_memory=False)   
        return val_loader

class Segmentor(pl.LightningModule):
    def __init__(self, hparams: argparse.Namespace):
        super().__init__()

        self.save_hyperparameters(hparams)        

        if "unet" in self.hparams.cnn_architecture:
            architecture = UNet(nin_channels=self.hparams.n_inchannels, 
                                nout_channels=self.hparams.n_outchannels, 
                                init_features=self.hparams.init_features)
        elif self.hparams.cnn_architecture == "coedet":
            architecture = CoEDET(nin=self.hparams.n_inchannels, nout=self.hparams.n_outchannels, 
                                  apply_sigmoid=self.hparams.apply_sigmoid)
        else:
            raise ValueError(f"Unsupported cnn_architecture {self.hparams.cnn_architecture}")

        self.model = architecture
    
        
        ttransform_scale=None
        ttransform_angle=None
        ttransform_flip_prob=None
        ttransform_sigma=None
        ttransform_ens_treshold=None
        if "taug_scale" in self.hparams:
            ttransform_scale = self.hparams.taug_scale
        if "taug_angle" in self.hparams:
            ttransform_angle = self.hparams.taug_angle
        if "taug_flip_prob" in self.hparams:
            ttransform_flip_prob = self.hparams.taug_flip_prob
        if "taug_sigma" in self.hparams:
            ttransform_sigma = self.hparams.taug_sigma
        if "taug_ens_treshold" in self.hparams:
            ttransform_ens_treshold = self.hparams.aug_ens_treshold
        self.train_transforms = My_transforms(scale=ttransform_scale,
                                         angle=ttransform_angle,
                                         flip_prob=ttransform_flip_prob,
                                         sigma=ttransform_sigma,
                                         ens_treshold=ttransform_ens_treshold
                                        )
        vtransform_scale=None
        vtransform_angle=None
        vtransform_flip_prob=None
        vtransform_sigma=None
        vtransform_ens_treshold=None
        if "vaug_scale" in self.hparams:
            ttransform_scale = self.hparams.vaug_scale
        if "vaug_angle" in self.hparams:
            ttransform_angle = self.hparams.vaug_angle
        if "vaug_flip_prob" in self.hparams:
            ttransform_flip_prob = self.hparams.vaug_flip_prob
        if "vaug_sigma" in self.hparams:
            ttransform_sigma = self.hparams.vaug_sigma
        if "vaug_ens_treshold" in self.hparams:
            ttransform_ens_treshold = self.hparams.vaug_ens_treshold
        self.val_transforms = My_transforms(scale=vtransform_scale,
                                         angle=vtransform_angle,
                                         flip_prob=vtransform_flip_prob,
                                         sigma=vtransform_sigma,
                                         ens_treshold=vtransform_ens_treshold
                                        )
        
  

    def forward(self, x):
        return self.model(x)

    def training_step(self, train_batch, batch_idx):
        loss = None

        x, y = train_batch
        logits = self.forward(x)
#         print('x.shape = ', x.shape)
#         print('y.shape = ', y.shape)
#         print('logits.shape = ', logits.shape)

        loss = CombinedLoss(logits, y, 
                            self.hparams.train_loss_funcs, 
                            self.hparams.lossweighs,
                            func_weights=self.hparams.func_weights)

        if self.hparams.train_metric == 'DiceMetric_weighs':
            train_metric = DiceMetric_weighs(y_pred=logits, y_true=y,
                                             weights=self.hparams.train_metricweighs, treshold=0.5)
        else:
            raise ValueError(f"Unsupported metric {self.hparams.train_metric}")

        self.log("loss", loss, on_epoch=True, on_step=True)
        self.log("train_metric", train_metric, on_epoch=True, on_step=False)

        return loss

    def validation_step(self, val_batch, batch_idx):
        logits = None

        x, y = val_batch
        logits = self.forward(x)
#         loss = self.lossfunc(logits, y)

        loss = CombinedLoss(logits, y, 
                            self.hparams.val_loss_funcs, 
                            self.hparams.lossweighs,
                            func_weights=self.hparams.func_weights)
    
        if self.hparams.val_metric == 'DiceMetric_weighs':
            val_metric = DiceMetric_weighs(y_pred=logits, y_true=y,
                                             weights=self.hparams.val_metricweighs, treshold=0.5)
        else:
            raise ValueError(f"Unsupported metric {self.hparams.val_metric}")

        self.log("val_loss", loss, on_epoch=True, on_step=False, prog_bar=True)
        self.log("val_metric", val_metric, on_epoch=True, on_step=False, prog_bar=True)        
        self.log("learning_rate_test", self.optimizer.param_groups[0]['lr'], on_epoch=True, on_step=False, prog_bar=False)

    def get_optimizer_by_name(self, name, lr):
        if name == "Adam":
            return Adam(self.model.parameters(), lr=lr)
        elif name == "SGD":
            return SGD(self.model.parameters(), lr=lr)
        else:
            raise ValueError(f"Unsupported optimizer: {name}")
            

    def configure_optimizers(self):
        '''
        Select optimizer and scheduling strategy according to hparams.
        '''
        optimizer = self.get_optimizer_by_name(self.hparams.opt_name, 
                                               self.hparams.lr)

        if self.hparams.lr_decay_policy == 'step':
            scheduler = StepLR(optimizer, self.hparams.scheduling_patience_lrepochs, self.hparams.lr_decay_factor, verbose=True)
            print('STEP - scheduling_patience_lrepochs = ', self.hparams.scheduling_patience_lrepochs, ' lr_decay_factor = ', self.hparams.lr_decay_factor)
        elif self.hparams.lr_decay_policy == 'plateau':
            print('PLATEAU - scheduling_patience_lrepochs = ', self.hparams.scheduling_patience_lrepochs, ' lr_decay_factor = ', self.hparams.lr_decay_factor)
            
            self.optimizer = optimizer
            lr_scheduler =  {
                           'scheduler': ReduceLROnPlateau(optimizer),
                           'mode': self.hparams.lr_decay_mode,
                           'factor': self.hparams.lr_decay_factor,
                           'patience': self.hparams.scheduling_patience_lrepochs,
                           'threshold': self.hparams.learning_threshold,
                           'threshold_mode': self.hparams.lr_decay_threshold_mode,
                           'cooldown': 0,
                           'min_lr': self.hparams.min_lr,
                           'eps': self.hparams.eps,
                           'monitor': self.hparams.monitor,
                           'verbose': True
                           }
            return {"optimizer": optimizer, "lr_scheduler": lr_scheduler}
        
         
            
        else:
            raise ValueError(f"Unsupported lr_decay_policy {self.hparams.lr_decay_policy}")
            

        return [optimizer], [scheduler]
     

# Step 3 - Test training hyperperameters and framework

In [6]:
%%time

original_desc = original_hyperparameters["description"]
original_view = original_hyperparameters["data_view"]
if original_view == "all":
    views = ["coronal", "sagittal", "axial"]
else:
    views = [original_view]
    
for view in views:
    new_desc = original_desc.replace("all", view)
    hyperparameters = original_hyperparameters.copy()
    
    hyperparameters["data_view"] = view
    hyperparameters["description"] = new_desc
    if view == "sagittal":
        print("Disabling sagital flip aug")
        hyperparameters["taug_flip_prob"] = 0  # disable flip on sagittal
    print(f"Starting training for view {view}, description {new_desc}")
    
    model = Segmentor(hparams=hyperparameters)
    data = MyDataModule(hparams=hyperparameters)

    # Setting the trainer for fast_dev_run
    trainer_just_1batch = pl.Trainer(fast_dev_run=True, 
                                     profiler=None,
                                     gpus=1,  # GPU number
                                     precision=32,  
                                     logger=False,  
                                     callbacks=None,  #
                                     checkpoint_callback=False,
                                     )

    trainer_just_1batch.fit(model, data)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
Running in fast_dev_run mode: will run a full train, val and test loop using 1 batch(es).
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Starting training for view coronal, description unet_single_label_psz064_coronal_evalue1T1



  | Name  | Type | Params
-------------------------------
0 | model | UNet | 7.8 M 
-------------------------------
7.8 M     Trainable params
0         Non-trainable params
7.8 M     Total params
31.051    Total estimated model params size (MB)


PLATEAU - scheduling_patience_lrepochs =  3  lr_decay_factor =  0.1


HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
Running in fast_dev_run mode: will run a full train, val and test loop using 1 batch(es).
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type | Params
-------------------------------
0 | model | UNet | 7.8 M 
-------------------------------
7.8 M     Trainable params
0         Non-trainable params
7.8 M     Total params
31.051    Total estimated model params size (MB)



Disabling sagital flip aug
Starting training for view sagittal, description unet_single_label_psz064_sagittal_evalue1T1
PLATEAU - scheduling_patience_lrepochs =  3  lr_decay_factor =  0.1


HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
Running in fast_dev_run mode: will run a full train, val and test loop using 1 batch(es).
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type | Params
-------------------------------
0 | model | UNet | 7.8 M 
-------------------------------
7.8 M     Trainable params
0         Non-trainable params
7.8 M     Total params
31.051    Total estimated model params size (MB)



Starting training for view axial, description unet_single_label_psz064_axial_evalue1T1
PLATEAU - scheduling_patience_lrepochs =  3  lr_decay_factor =  0.1


HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…


CPU times: user 12.8 s, sys: 2.28 s, total: 15.1 s
Wall time: 14.4 s


# Step 3 - Start Training

In [7]:
%%time
       
original_desc = original_hyperparameters["description"]
original_view = original_hyperparameters["data_view"]
if original_view == "all":
    views = ["coronal", "sagittal", "axial"]
else:
    views = [original_view]
    
for view in views:
    new_desc = original_desc.replace("all", view)
    hyperparameters = original_hyperparameters.copy()
    
    hyperparameters["data_view"] = view
    hyperparameters["description"] = new_desc
    if view == "sagittal":
        print("Disabling sagital flip aug")
        hyperparameters["taug_flip_prob"] = 0  # disable flip on sagittal view
    print(f"Starting training for view {view}, description {new_desc}")
    model = Segmentor(hparams=hyperparameters)
    data = MyDataModule(hparams=hyperparameters)
 
    #callbacks configuration
    prefix = hyperparameters["description"] + '_' + time.strftime("%d-%m-%Y_%H-%M")
    ckpt_path = os.path.join("checkpoints", hyperparameters["experiment_name"])
    print(ckpt_path)
    callbacks = [EarlyStopping(monitor="val_loss",  # logging variable
                               patience=hyperparameters['patience'], 
                               verbose=True, 
                               mode=hyperparameters['lr_decay_mode'] 
                               ),
                 ModelCheckpoint(dirpath=ckpt_path, 
                                 filename=prefix + '-{epoch:02d}-{val_loss:.2f}',
                                 verbose=True,
                                 monitor="val_loss", 
                                 save_top_k=1,
                                 mode="min") ]
    trainer_normal = pl.Trainer(max_epochs=hyperparameters["max_epochs"],
                                gpus=1,
                                precision=32,  
                                callbacks=callbacks,  
                                checkpoint_callback=True,  
                                accumulate_grad_batches=2,  
                                resume_from_checkpoint=None,  
                                progress_bar_refresh_rate=50  
                                                              
                                )

    print("Hyperparameters:\n")
    for k, v in hyperparameters.items():
        print(f'{k}: {v}')

    trainer_normal.fit(model, data)

print('DONE!')

  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Starting training for view coronal, description unet_single_label_psz064_coronal_evalue1T1
checkpoints/unet_single_label
Hyperparameters:

experiment_name: unet_single_label
description: unet_single_label_psz064_coronal_evalue1T1
data_view: coronal
dataset_folder: ./Data/Patches/single_label_psz064/
subjects_list: ./training_subjects_randomized.txt
in_channels: ['evalue1', 'T1']
masks: ['staple_oh']
cnn_architecture: unet
input_size: 32
n_inchannels: 2
n_outchannels: 2
init_features: 32
taug_angle: 5
taug_flip_prob: 0.5
max_epochs: 2000
patience: 20
learning_threshold: 0.01
batch_size: 256
split_train_val: 0.2
opt_name: Adam
min_lr: 1e-06
eps: 1e-05
monitor: val_loss
lr: 0.001
scheduling_patience_lrepochs: 3
lr_decay_factor: 0.1
lr_decay_policy: plateau
lr_decay_mode: min
lr_decay_threshold_mode: abs
lossweighs: [[0, 1]]
func_weights: [1]
train_loss_funcs: ['DiceLoss_weighs']
train_metric: DiceMetric_weighs
train_metricweighs: [0, 1]
val_loss_funcs: ['DiceLoss_weighs']
val_metric: Dice


  | Name  | Type | Params
-------------------------------
0 | model | UNet | 7.8 M 
-------------------------------
7.8 M     Trainable params
0         Non-trainable params
7.8 M     Total params
31.051    Total estimated model params size (MB)


HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Metric val_loss improved. New best score: 0.252
Epoch 0, global step 90: val_loss reached 0.25241 (best 0.25241), saving model to "/home/miclab/Python_codes/tahalmus_benchmark_diffusion_dev-main/code/checkpoints/unet_single_label/unet_single_label_psz064_coronal_evalue1T1_17-12-2021_15-24-epoch=00-val_loss=0.25.ckpt" as top 1


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Metric val_loss improved by 0.165 >= min_delta = 0.0. New best score: 0.087
Epoch 1, global step 181: val_loss reached 0.08712 (best 0.08712), saving model to "/home/miclab/Python_codes/tahalmus_benchmark_diffusion_dev-main/code/checkpoints/unet_single_label/unet_single_label_psz064_coronal_evalue1T1_17-12-2021_15-24-epoch=01-val_loss=0.09.ckpt" as top 1


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Metric val_loss improved by 0.030 >= min_delta = 0.0. New best score: 0.057
Epoch 2, global step 272: val_loss reached 0.05714 (best 0.05714), saving model to "/home/miclab/Python_codes/tahalmus_benchmark_diffusion_dev-main/code/checkpoints/unet_single_label/unet_single_label_psz064_coronal_evalue1T1_17-12-2021_15-24-epoch=02-val_loss=0.06.ckpt" as top 1


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Metric val_loss improved by 0.007 >= min_delta = 0.0. New best score: 0.050
Epoch 3, global step 363: val_loss reached 0.05038 (best 0.05038), saving model to "/home/miclab/Python_codes/tahalmus_benchmark_diffusion_dev-main/code/checkpoints/unet_single_label/unet_single_label_psz064_coronal_evalue1T1_17-12-2021_15-24-epoch=03-val_loss=0.05.ckpt" as top 1


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Metric val_loss improved by 0.004 >= min_delta = 0.0. New best score: 0.046
Epoch 4, global step 454: val_loss reached 0.04598 (best 0.04598), saving model to "/home/miclab/Python_codes/tahalmus_benchmark_diffusion_dev-main/code/checkpoints/unet_single_label/unet_single_label_psz064_coronal_evalue1T1_17-12-2021_15-24-epoch=04-val_loss=0.05.ckpt" as top 1


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Metric val_loss improved by 0.000 >= min_delta = 0.0. New best score: 0.046
Epoch 5, global step 545: val_loss reached 0.04587 (best 0.04587), saving model to "/home/miclab/Python_codes/tahalmus_benchmark_diffusion_dev-main/code/checkpoints/unet_single_label/unet_single_label_psz064_coronal_evalue1T1_17-12-2021_15-24-epoch=05-val_loss=0.05.ckpt" as top 1


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Metric val_loss improved by 0.004 >= min_delta = 0.0. New best score: 0.041
Epoch 6, global step 636: val_loss reached 0.04142 (best 0.04142), saving model to "/home/miclab/Python_codes/tahalmus_benchmark_diffusion_dev-main/code/checkpoints/unet_single_label/unet_single_label_psz064_coronal_evalue1T1_17-12-2021_15-24-epoch=06-val_loss=0.04.ckpt" as top 1
  rank_zero_warn('Detected KeyboardInterrupt, attempting graceful shutdown...')
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type | Params
-------------------------------
0 | model | UNet | 7.8 M 
-------------------------------
7.8 M     Trainable params
0         Non-trainable params
7.8 M     Total params
31.051    Total estimated model params size (MB)


Disabling sagital flip aug
Starting training for view sagittal, description unet_single_label_psz064_sagittal_evalue1T1
checkpoints/unet_single_label
Hyperparameters:

experiment_name: unet_single_label
description: unet_single_label_psz064_sagittal_evalue1T1
data_view: sagittal
dataset_folder: ./Data/Patches/single_label_psz064/
subjects_list: ./training_subjects_randomized.txt
in_channels: ['evalue1', 'T1']
masks: ['staple_oh']
cnn_architecture: unet
input_size: 32
n_inchannels: 2
n_outchannels: 2
init_features: 32
taug_angle: 5
taug_flip_prob: 0
max_epochs: 2000
patience: 20
learning_threshold: 0.01
batch_size: 256
split_train_val: 0.2
opt_name: Adam
min_lr: 1e-06
eps: 1e-05
monitor: val_loss
lr: 0.001
scheduling_patience_lrepochs: 3
lr_decay_factor: 0.1
lr_decay_policy: plateau
lr_decay_mode: min
lr_decay_threshold_mode: abs
lossweighs: [[0, 1]]
func_weights: [1]
train_loss_funcs: ['DiceLoss_weighs']
train_metric: DiceMetric_weighs
train_metricweighs: [0, 1]
val_loss_funcs: ['DiceL

HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…

HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…

HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Metric val_loss improved. New best score: 0.208
Epoch 0, global step 90: val_loss reached 0.20770 (best 0.20770), saving model to "/home/miclab/Python_codes/tahalmus_benchmark_diffusion_dev-main/code/checkpoints/unet_single_label/unet_single_label_psz064_sagittal_evalue1T1_17-12-2021_15-31-epoch=00-val_loss=0.21.ckpt" as top 1


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Metric val_loss improved by 0.135 >= min_delta = 0.0. New best score: 0.073
Epoch 1, global step 181: val_loss reached 0.07262 (best 0.07262), saving model to "/home/miclab/Python_codes/tahalmus_benchmark_diffusion_dev-main/code/checkpoints/unet_single_label/unet_single_label_psz064_sagittal_evalue1T1_17-12-2021_15-31-epoch=01-val_loss=0.07.ckpt" as top 1


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Metric val_loss improved by 0.016 >= min_delta = 0.0. New best score: 0.057
Epoch 2, global step 272: val_loss reached 0.05703 (best 0.05703), saving model to "/home/miclab/Python_codes/tahalmus_benchmark_diffusion_dev-main/code/checkpoints/unet_single_label/unet_single_label_psz064_sagittal_evalue1T1_17-12-2021_15-31-epoch=02-val_loss=0.06.ckpt" as top 1


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Metric val_loss improved by 0.007 >= min_delta = 0.0. New best score: 0.050
Epoch 3, global step 363: val_loss reached 0.05022 (best 0.05022), saving model to "/home/miclab/Python_codes/tahalmus_benchmark_diffusion_dev-main/code/checkpoints/unet_single_label/unet_single_label_psz064_sagittal_evalue1T1_17-12-2021_15-31-epoch=03-val_loss=0.05.ckpt" as top 1


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Metric val_loss improved by 0.005 >= min_delta = 0.0. New best score: 0.045
Epoch 4, global step 454: val_loss reached 0.04539 (best 0.04539), saving model to "/home/miclab/Python_codes/tahalmus_benchmark_diffusion_dev-main/code/checkpoints/unet_single_label/unet_single_label_psz064_sagittal_evalue1T1_17-12-2021_15-31-epoch=04-val_loss=0.05.ckpt" as top 1


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 5, global step 545: val_loss was not in top 1


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Metric val_loss improved by 0.003 >= min_delta = 0.0. New best score: 0.043
Epoch 6, global step 636: val_loss reached 0.04274 (best 0.04274), saving model to "/home/miclab/Python_codes/tahalmus_benchmark_diffusion_dev-main/code/checkpoints/unet_single_label/unet_single_label_psz064_sagittal_evalue1T1_17-12-2021_15-31-epoch=06-val_loss=0.04.ckpt" as top 1


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Metric val_loss improved by 0.001 >= min_delta = 0.0. New best score: 0.042
Epoch 7, global step 727: val_loss reached 0.04164 (best 0.04164), saving model to "/home/miclab/Python_codes/tahalmus_benchmark_diffusion_dev-main/code/checkpoints/unet_single_label/unet_single_label_psz064_sagittal_evalue1T1_17-12-2021_15-31-epoch=07-val_loss=0.04.ckpt" as top 1


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Metric val_loss improved by 0.001 >= min_delta = 0.0. New best score: 0.041
Epoch 8, global step 818: val_loss reached 0.04057 (best 0.04057), saving model to "/home/miclab/Python_codes/tahalmus_benchmark_diffusion_dev-main/code/checkpoints/unet_single_label/unet_single_label_psz064_sagittal_evalue1T1_17-12-2021_15-31-epoch=08-val_loss=0.04.ckpt" as top 1


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 9, global step 909: val_loss was not in top 1


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Metric val_loss improved by 0.001 >= min_delta = 0.0. New best score: 0.039
Epoch 10, global step 1000: val_loss reached 0.03930 (best 0.03930), saving model to "/home/miclab/Python_codes/tahalmus_benchmark_diffusion_dev-main/code/checkpoints/unet_single_label/unet_single_label_psz064_sagittal_evalue1T1_17-12-2021_15-31-epoch=10-val_loss=0.04.ckpt" as top 1


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 11, global step 1091: val_loss was not in top 1


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 12, global step 1182: val_loss was not in top 1


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 13, global step 1273: val_loss was not in top 1


HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…

Epoch 14, global step 1364: val_loss was not in top 1
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type | Params
-------------------------------
0 | model | UNet | 7.8 M 
-------------------------------
7.8 M     Trainable params
0         Non-trainable params
7.8 M     Total params
31.051    Total estimated model params size (MB)


Starting training for view axial, description unet_single_label_psz064_axial_evalue1T1
checkpoints/unet_single_label
Hyperparameters:

experiment_name: unet_single_label
description: unet_single_label_psz064_axial_evalue1T1
data_view: axial
dataset_folder: ./Data/Patches/single_label_psz064/
subjects_list: ./training_subjects_randomized.txt
in_channels: ['evalue1', 'T1']
masks: ['staple_oh']
cnn_architecture: unet
input_size: 32
n_inchannels: 2
n_outchannels: 2
init_features: 32
taug_angle: 5
taug_flip_prob: 0.5
max_epochs: 2000
patience: 20
learning_threshold: 0.01
batch_size: 256
split_train_val: 0.2
opt_name: Adam
min_lr: 1e-06
eps: 1e-05
monitor: val_loss
lr: 0.001
scheduling_patience_lrepochs: 3
lr_decay_factor: 0.1
lr_decay_policy: plateau
lr_decay_mode: min
lr_decay_threshold_mode: abs
lossweighs: [[0, 1]]
func_weights: [1]
train_loss_funcs: ['DiceLoss_weighs']
train_metric: DiceMetric_weighs
train_metricweighs: [0, 1]
val_loss_funcs: ['DiceLoss_weighs']
val_metric: DiceMetric_w

HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…











HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…









DONE!
CPU times: user 22min 10s, sys: 26.8 s, total: 22min 37s
Wall time: 22min 27s


In [None]:
# Open tensor board monitor
%tensorboard --logdir lightning_logs