## **Run model training and validate performance on the validation data**
### **Author:** Alessandro Ulivi (ale.ulivi@gmail.com)
### **Start day (yyyy/mm/dd):** 2024/10/18
### **Description**
#### The notebook:
#### - Loads train and validation data from the pip2_segmentation dataset (refer to README.txt).
#### - Sets up and applies augmentation transformations.
#### - Sets up the hyperparameters used to define and train the UNet model.
#### - Trains the model on the (augmented) train data while validating training performance on the validation data.
#### - Saves checkpoints during training.
#### - During training, logs to TensorBoard the training loss, the validation loss and metric, and the input and output images of the train and validation data.
#### - Logs to TensorBoard the hyperparameters used for training.

### **Requirements**
#### The notebook runs in the pip2_segmentation environment and uses the scripts of the pip2_segmentation project. Refer to https://github.com/AlessandroUlivi/pip2_segmentation.
#### In addition, a "runs" folder and a "checkpoints" folder are expected, to store TensorBoard summaries of individual runs and checkpoints of model training, respectively.
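If the two folders do not exist yet, they can be created from the notebook's working directory before training. The snippet below is a minimal sketch: the helper name `ensure_output_dirs` is hypothetical and not part of the pip2_segmentation project, and it assumes both folders sit directly under the working directory, as the relative paths `runs/...` and `checkpoints` used in the training cells suggest.

```python
import os

# Folder names expected by this notebook: TensorBoard summaries and model checkpoints.
REQUIRED_DIRS = ("runs", "checkpoints")


def ensure_output_dirs(base_dir=".", dirs=REQUIRED_DIRS):
    """Create each folder in `dirs` under `base_dir` if missing and return the full paths."""
    paths = []
    for name in dirs:
        path = os.path.join(base_dir, name)
        os.makedirs(path, exist_ok=True)  # does nothing if the folder already exists
        paths.append(path)
    return paths
```

Calling `ensure_output_dirs()` once at the top of the notebook is enough; because of `exist_ok=True`, re-running the cell is harmless.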

In [5]:
# load tensorboard extension
%load_ext tensorboard

In [1]:
# import required modules
import matplotlib.pyplot as plt
from functools import partial
import datetime
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import RandomSampler, DataLoader, Subset
from dataprep.data_preparation import make_dataset, compose, random_flip, random_translation, random_gaussian_or_uniform_noise, add_channel, normalize, to_tensor
from utils.utils_funct import dict2mdtable
from models.unet import UNet
from modeltrain.train_model import run_training, run_training_no_val
from metrics.metric import DiceCoefficient, DiceLoss, DiceBCELoss, FocalLoss, BCE_EdgeDiceLoss

from torch.utils.tensorboard import SummaryWriter



In [2]:
import os

# root folder of the pip2_segmentation data (adjust to the local copy of the dataset)
data_root = r"C:\Users\aless\OneDrive\Desktop\Ale\personal\projects\pip2_segmentation\data"
train_input_data_dir = os.path.join(data_root, "train", "raw")
train_label_data_dir = os.path.join(data_root, "train", "label")
val_input_data_dir = os.path.join(data_root, "validation", "raw")
val_label_data_dir = os.path.join(data_root, "validation", "label")
test_input_data_dir = os.path.join(data_root, "test", "raw")
test_label_data_dir = os.path.join(data_root, "test", "label")

In [3]:
# Define the transformations (augmentation is applied to the train data only)
train_data_transformations_w_augmentation = [random_flip, random_translation, random_gaussian_or_uniform_noise, add_channel, to_tensor]
train_trafos = partial(compose, transforms=train_data_transformations_w_augmentation)

val_data_transformations = [add_channel, to_tensor] #NOTE: data are not normalized here because normalization was done when the dataset was created, before chunking the images
val_trafos = partial(compose, transforms=val_data_transformations)

#create the train and validation datasets
train_dataset = make_dataset(train_input_data_dir, train_label_data_dir, transform=train_trafos, shuffle_data=True, stack_axis=0)
val_dataset = make_dataset(val_input_data_dir, val_label_data_dir, transform=val_trafos, shuffle_data=True, stack_axis=0)
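The `partial(compose, transforms=...)` pattern above binds a list of transformations into a single callable that the dataset applies to each sample. The project's `compose`, `random_flip` and `add_channel` live in `dataprep.data_preparation` and their signatures are not shown here, so the sketch below assumes, hypothetically, that every transformation takes and returns an `(image, mask)` pair of NumPy arrays. The `*_sketch` functions are illustrative stand-ins; the point they show is that geometric augmentations must be applied identically to the raw image and to its label.

```python
from functools import partial

import numpy as np

# Hypothetical stand-ins, NOT the project's functions: each one maps (image, mask) -> (image, mask).
_rng = np.random.default_rng(0)


def compose_sketch(image, mask, transforms):
    """Apply each transformation in order to the (image, mask) pair."""
    for transform in transforms:
        image, mask = transform(image, mask)
    return image, mask


def random_flip_sketch(image, mask, p=0.5):
    """Flip image and mask together along each spatial axis with probability p."""
    for axis in (-2, -1):
        if _rng.random() < p:
            # .copy() removes the negative strides of np.flip, which torch.from_numpy rejects
            image = np.flip(image, axis=axis).copy()
            mask = np.flip(mask, axis=axis).copy()
    return image, mask


def add_channel_sketch(image, mask):
    """Prepend a channel axis: (H, W) -> (1, H, W)."""
    return image[np.newaxis, ...], mask[np.newaxis, ...]


trafos_sketch = partial(compose_sketch, transforms=[random_flip_sketch, add_channel_sketch])

image = np.arange(16, dtype=np.float32).reshape(4, 4)
mask = (image > 7).astype(np.float32)  # label derived from the image, so alignment can be checked
out_image, out_mask = trafos_sketch(image, mask)
print(out_image.shape, out_mask.shape)  # (1, 4, 4) (1, 4, 4)
```

Because the mask is computed from the image, `out_mask` still equals `out_image > 7` only if both arrays were flipped in the same way; flipping them independently would break this correspondence.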


In [6]:
# open tensorboard inside of our notebook
%tensorboard --logdir runs


In [7]:
#=========
# pass data to DataLoader
batch_size = 4
# train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# val_loader = DataLoader(val_dataset, batch_size=batch_size)

# for now, only work on a small subset of the data
num_train_samples = 12
train_sample_ds = Subset(train_dataset, np.arange(num_train_samples))
train_sample_sampler = RandomSampler(train_sample_ds)
train_loader = DataLoader(train_sample_ds, sampler=train_sample_sampler, batch_size=batch_size)

num_val_samples = 12
val_sample_ds = Subset(val_dataset, np.arange(num_val_samples))
val_sample_sampler = RandomSampler(val_sample_ds)
val_loader = DataLoader(val_sample_ds, sampler=val_sample_sampler, batch_size=batch_size)

#=========
# select the device: the GPU branch is commented out and the CPU is used for now
# if torch.cuda.is_available():
#     print("using gpu")
#     device = torch.device("cuda")
# else:
#     print("using cpu")
#     device = torch.device("cpu")
device = torch.device("cpu")

#=========
# set model's parameters
final_activation="Sigmoid"
depth = 4
num_fmaps = 64
fmap_inc_factor = 4
downsample_factor = 2
kernel_size = 3
padding = "same"
upsample_mode = "nearest"
unet_model = UNet(depth=depth,
                  in_channels=1,
                  out_channels=1,
                  final_activation=final_activation,
                  num_fmaps=num_fmaps,
                  fmap_inc_factor=fmap_inc_factor,
                  downsample_factor=downsample_factor,
                  kernel_size=kernel_size,
                  padding=padding,
                  upsample_mode=upsample_mode).to(device)

#=========
# set loss function
# loss_function = nn.BCELoss() # second-best option so far: predicted values do not seem to increase, they remain low and the Sigmoid output then fails
# loss_function = DiceLoss() # works at the very beginning of training but then quickly leads to large "positive pixel" structures
# loss_function = DiceBCELoss() # so far, this seems to be the best option
loss_function = BCE_EdgeDiceLoss()
bce_weight = 1
dice_weight = 1
# it might be worth testing FocalLoss (https://arxiv.org/pdf/1708.02002, https://www.kaggle.com/code/bigironsphere/loss-function-library-keras-pytorch)

#=========
# set optimizer
lr = 1e-4
optimizer = torch.optim.Adam(unet_model.parameters(), lr=lr)

#=========
# set metrics
bin_threshold=0.5
metric = DiceCoefficient()

#=========
# define a unique run key from the current date and time
# runs_counter = get_var_value(filename="varstore.dat")
my_key = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")

# set logger's parameters
logger = SummaryWriter(f"runs/{my_key}")
log_interval=1
log_image_interval=20


# unrelated note: it could be useful to look into batch normalization (https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html)

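The cell above restricts training and validation to the first 12 samples of each dataset with `Subset` and draws them in random order with `RandomSampler`. The toy example below uses a synthetic `TensorDataset` (an assumption standing in for the pip2_segmentation data) to show what this produces: with 12 samples and a batch size of 4, each epoch yields 3 batches, and `RandomSampler`, which samples without replacement by default, visits every selected sample exactly once per epoch.

```python
import torch
from torch.utils.data import DataLoader, RandomSampler, Subset, TensorDataset

torch.manual_seed(0)

# Synthetic stand-in for the dataset: 100 images (1x32x32), binary masks and each sample's index.
toy_images = torch.rand(100, 1, 32, 32)
toy_masks = (torch.rand(100, 1, 32, 32) > 0.5).float()
toy_dataset = TensorDataset(toy_images, toy_masks, torch.arange(100))

toy_num_samples = 12
toy_batch_size = 4
toy_subset = Subset(toy_dataset, list(range(toy_num_samples)))  # keep the first 12 samples only
toy_loader = DataLoader(toy_subset, sampler=RandomSampler(toy_subset), batch_size=toy_batch_size)

batches = list(toy_loader)  # one epoch
seen = sorted(idx.item() for _, _, indices in batches for idx in indices)
print(len(batches), batches[0][0].shape)  # 3 torch.Size([4, 1, 32, 32])
print(seen == list(range(toy_num_samples)))  # True
```

Since the subset always takes the first indices, it is a random selection only if the dataset order was already shuffled, which is presumably what `make_dataset(..., shuffle_data=True)` provides.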

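A back-of-the-envelope check of the UNet hyperparameters set above can help anticipate memory use and input-size constraints. The sketch assumes, without having checked `models.unet`, a common UNet convention: level `l` has `num_fmaps * fmap_inc_factor**l` feature maps, the image is downsampled by `downsample_factor` between consecutive levels, and each level uses 3x3 convolutions with bias. Under these assumptions, `depth = 4` gives 3 downsampling steps, so with `padding = "same"` the input height and width should be multiples of `2**3 = 8` for the upsampled maps to match the skip connections. The helper functions are hypothetical.

```python
def unet_fmaps_per_level(depth, num_fmaps, fmap_inc_factor):
    """Feature maps at each level, assuming num_fmaps * fmap_inc_factor**level."""
    return [num_fmaps * fmap_inc_factor ** level for level in range(depth)]


def required_divisor(depth, downsample_factor):
    """Input height/width must be a multiple of this, assuming depth - 1 downsampling steps."""
    return downsample_factor ** (depth - 1)


def conv2d_params(in_channels, out_channels, kernel_size):
    """Number of weights plus biases of a single nn.Conv2d layer."""
    return in_channels * out_channels * kernel_size ** 2 + out_channels


fmaps = unet_fmaps_per_level(depth=4, num_fmaps=64, fmap_inc_factor=4)
print(fmaps)  # [64, 256, 1024, 4096]
print(required_divisor(depth=4, downsample_factor=2))  # 8
# one 3x3 convolution going from the third level into the bottleneck
print(conv2d_params(fmaps[-2], fmaps[-1], kernel_size=3))  # 37752832
```

Under these assumptions the bottleneck reaches 4096 channels and a single 3x3 convolution into it already holds about 37.8 million parameters; with `fmap_inc_factor = 2` the same level would have 512 channels.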
In [8]:
#=========
# model's training with validation
n_epochs = 2
lr_scheduler_flag = True
lr_kwargs={"mode":"min", "factor": 0.1, "patience":2}
checkpoint_save_path = "checkpoints"
run_training(model=unet_model,
            optimizer=optimizer,
            metric=metric, 
            n_epochs=n_epochs,
            bce_weight=bce_weight,
            dice_weight=dice_weight,
            train_loader=train_loader,
            val_loader=val_loader,
            loss_function=loss_function,
            bin_threshold=bin_threshold,
            logger=logger,
            log_interval=log_interval,
            log_image_interval=log_image_interval,
            device=device,
            key=my_key,
            path=checkpoint_save_path,
            lr_scheduler_flag=lr_scheduler_flag,
            lr_kwargs=lr_kwargs,
            x_dim=[-2,-1],
            y_dim=[-2,-1],
            best_metric_init=0)

#=========
#log all hyperparameters as text in TensorBoard

#transform the kwargs of the learning-rate scheduler into a string
lr_kwargs_str = "".join(f"{k}:{v}," for k, v in lr_kwargs.items())

#form a dictionary with all hyperparameters to be logged
hparam_dict = {"batch_size": str(batch_size),
               "final_activation": final_activation,
               "depth": str(depth),
               "num_fmaps": str(num_fmaps),
               "fmap_inc_factor": str(fmap_inc_factor),
               "downsample_factor": str(downsample_factor),
               "kernel_size": str(kernel_size),
               "padding": padding,
               "upsample_mode": upsample_mode,
               "loss_function": str(loss_function),
               "bce_weight": str(bce_weight),
               "dice_weight": str(dice_weight),
               "bin_threshold": str(bin_threshold),
               "optimizer": str(optimizer),
               "metric": str(metric),
               "n_epochs": str(n_epochs),
               "lr_scheduler_flag": str(lr_scheduler_flag),
               "lr_kwargs": lr_kwargs_str}

#transform the dictionary into a Markdown table string
hparam_table_like = dict2mdtable(hparam_dict, key='Name', val='Value', transform_2_string=False)

#log the text in the TensorBoard summary of the run
logger.add_text('Hyperparams', hparam_table_like, 1)


Validate: Average loss: 0.9578, Average Metric: 0.0031


Validate: Average loss: 0.5149, Average Metric: 0.0000

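The training cell uses `BCE_EdgeDiceLoss` together with `bce_weight` and `dice_weight`, and the loss comments in the setup cell describe BCE and Dice behaving differently on these images. How the project combines the terms, and what its edge component does, is not shown here. The sketch below is therefore a generic weighted BCE + soft Dice loss, a hypothetical `WeightedBCEDiceLoss` written for illustration only; it expects probabilities (the UNet ends with a Sigmoid) and binary targets. The Dice term is commonly added because it is less sensitive than per-pixel BCE to a large fraction of background pixels.

```python
import torch
import torch.nn as nn


class WeightedBCEDiceLoss(nn.Module):
    """Hypothetical sketch: bce_weight * BCE + dice_weight * (1 - soft Dice)."""

    def __init__(self, bce_weight=1.0, dice_weight=1.0, eps=1e-6):
        super().__init__()
        self.bce = nn.BCELoss()  # expects probabilities in [0, 1], i.e. after a Sigmoid
        self.bce_weight = bce_weight
        self.dice_weight = dice_weight
        self.eps = eps  # avoids division by zero when prediction and target are empty

    def forward(self, prediction, target):
        bce = self.bce(prediction, target)
        intersection = (prediction * target).sum()
        dice = (2 * intersection + self.eps) / (prediction.sum() + target.sum() + self.eps)
        return self.bce_weight * bce + self.dice_weight * (1 - dice)


target = torch.tensor([[[[0.0, 1.0], [1.0, 0.0]]]])  # shape (batch, channel, H, W)
loss_fn = WeightedBCEDiceLoss(bce_weight=1, dice_weight=1)
good = loss_fn(target.clone(), target)  # perfect prediction
poor = loss_fn(torch.full_like(target, 0.5), target)  # uninformative prediction
print(good.item(), poor.item())
```

For the uninformative prediction, BCE equals ln 2 and the soft Dice is 0.5, so the total is about 1.19; the perfect prediction scores 0. In this sketch the weights are constructor arguments, whereas the notebook passes them to `run_training`.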

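The `Average Metric` reported above comes from `DiceCoefficient` with `bin_threshold = 0.5`; exactly how the project's class and `run_training` apply the threshold is assumed here rather than shown. The minimal sketch below (a hypothetical `dice_coefficient_sketch`) implements the standard definition `2 * |A ∩ B| / (|A| + |B|)` on masks binarized at the threshold. It also shows that the value is close to zero whenever no predicted pixel exceeds the threshold while the label contains foreground, a case worth ruling out when the average metric stays near 0.

```python
import torch


def dice_coefficient_sketch(prediction, target, threshold=0.5, eps=1e-6):
    """Dice = 2 * |A intersect B| / (|A| + |B|) after binarizing `prediction` at `threshold`."""
    pred_bin = (prediction > threshold).float()
    target_bin = (target > 0.5).float()  # assumes labels are stored as 0/1
    intersection = (pred_bin * target_bin).sum()
    dice = (2 * intersection + eps) / (pred_bin.sum() + target_bin.sum() + eps)
    return dice.item()


label = torch.tensor([1.0, 0.0, 0.0, 1.0])
half_right = dice_coefficient_sketch(torch.tensor([0.9, 0.2, 0.7, 0.1]), label)
all_low = dice_coefficient_sketch(torch.tensor([0.4, 0.3, 0.2, 0.1]), label)
print(round(half_right, 4), round(all_low, 4))  # 0.5 0.0
```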

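The keys of `lr_kwargs` (`mode`, `factor`, `patience`) match the arguments of PyTorch's `torch.optim.lr_scheduler.ReduceLROnPlateau`. The sketch below assumes, without having checked `modeltrain.train_model`, that `run_training` builds that scheduler when `lr_scheduler_flag` is set and steps it once per epoch with the validation loss (`mode="min"` fits a loss rather than the Dice metric). A toy linear model stands in for the UNet, and a constant loss simulates a plateau.

```python
import torch

toy_model = torch.nn.Linear(4, 1)  # stands in for the UNet
toy_optimizer = torch.optim.Adam(toy_model.parameters(), lr=1e-4)
example_lr_kwargs = {"mode": "min", "factor": 0.1, "patience": 2}
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(toy_optimizer, **example_lr_kwargs)

lrs = []
for val_loss in [1.0, 1.0, 1.0, 1.0, 1.0]:  # a validation loss that stopped improving
    scheduler.step(val_loss)
    lrs.append(toy_optimizer.param_groups[0]["lr"])
print(lrs)
```

Here the learning rate drops from 1e-4 to 1e-5 only at the fourth step: the first step sets the best value, and the reduction happens once more than `patience = 2` further steps fail to improve on it. Assuming one scheduler step per epoch, a run with `n_epochs = 2` is too short for the scheduler to ever change the learning rate.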
In [8]:
# #=========
# # model's training without validation
# n_epochs = 8
# lr_scheduler_flag = True
# lr_kwargs={"mode":"min", "factor": 0.1, "patience":2}
# checkpoint_save_path = "checkpoints"
# run_training_no_val(model=unet_model,
#                     optimizer=optimizer,
#                     metric=metric,
#                     n_epochs=n_epochs,
#                     train_loader=train_loader,
#                     loss_function=loss_function,
#                     bin_threshold=bin_threshold,
#                     logger=logger,
#                     log_interval=log_interval,
#                     log_image_interval=log_image_interval,
#                     device=device,
#                     key=my_key,
#                     path=checkpoint_save_path,
#                     lr_scheduler_flag=lr_scheduler_flag,
#                     lr_kwargs=lr_kwargs,
#                     x_dim=[-2,-1],
#                     y_dim=[-2,-1],
#                     best_metric_init=-1)

# #=========
# #log all hyperparameters as text in TensorBoard

# #transform the kwargs of the learning-rate scheduler into a string
# lr_kwargs_str = "".join(f"{k}:{v}," for k, v in lr_kwargs.items())

# #form a dictionary with all hyperparameters to be logged
# hparam_dict = {"batch_size": str(batch_size),
#                "final_activation": final_activation,
#                "depth": str(depth),
#                "num_fmaps": str(num_fmaps),
#                "fmap_inc_factor": str(fmap_inc_factor),
#                "downsample_factor": str(downsample_factor),
#                "kernel_size": str(kernel_size),
#                "padding": padding,
#                "upsample_mode": upsample_mode,
#                "loss_function": str(loss_function),
#                "bin_threshold": str(bin_threshold),
#                "optimizer": str(optimizer),
#                "metric": str(metric),
#                "n_epochs": str(n_epochs),
#                "lr_scheduler_flag": str(lr_scheduler_flag),
#                "lr_kwargs": lr_kwargs_str}

# #transform the dictionary into a Markdown table string
# hparam_table_like = dict2mdtable(hparam_dict, key='Name', val='Value', transform_2_string=False)

# #log the text in the TensorBoard summary of the run
# logger.add_text('Hyperparams', hparam_table_like, 1)

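Both training cells turn the hyperparameters into a Markdown table with `dict2mdtable` before logging it with `logger.add_text`; TensorBoard's Text dashboard renders Markdown, so the table appears formatted. The project's `dict2mdtable` is not shown, so the following is a hypothetical equivalent, `dict_to_md_table_sketch`, producing a two-column table with the same `Name`/`Value` headers used in those calls.

```python
def dict_to_md_table_sketch(d, key="Name", val="Value"):
    """Render a flat dict as a two-column Markdown table (values formatted with f-strings)."""
    lines = [f"| {key} | {val} |", "| --- | --- |"]
    for k, v in d.items():
        lines.append(f"| {k} | {v} |")
    return "\n".join(lines)


example_hparams = {"batch_size": 4, "depth": 4, "lr_kwargs": "mode:min,factor:0.1,patience:2,"}
print(dict_to_md_table_sketch(example_hparams))
```

The resulting string can be passed to `logger.add_text('Hyperparams', table, 1)` as in the notebook. PyTorch's `SummaryWriter` also provides `add_hparams(hparam_dict, metric_dict)`, which logs hyperparameters to TensorBoard's HParams dashboard together with final metrics, an alternative when runs need to be compared side by side.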