## **Run model training and validate performance on the validation data**
### **Author:** Alessandro Ulivi (ale.ulivi@gmail.com)
### **Start day (yyyy/mm/dd):** 2024/10/18
### **Description**
#### The notebook:
#### - Loads train and validation data from the pip2_segmentation dataset (refer to README.txt).
#### - Sets up and applies augmentation transformations.
#### - Sets up the hyperparameters used to define and train the UNet model.
#### - Trains the model on the (augmented) train data while validating its performance on the validation data.
#### - While training, saves checkpoints.
#### - While training, logs in TensorBoard: the training loss function, the validation loss function and metric, input and output images for the train and validation data.
#### - Logs the hyperparameters used for training in TensorBoard.

### **Requirements**
#### The notebook runs in the pip2_segmentation environment and uses the scripts of the pip2_segmentation project. Refer to https://github.com/AlessandroUlivi/pip2_segmentation.
#### In addition, a "runs" folder and a "checkpoints" folder are expected to store, respectively, the TensorBoard summaries of individual runs and the checkpoints of model training.
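
The two folders can be created ahead of time. A minimal sketch, assuming they live in the notebook's working directory; the helper name `ensure_output_dirs` is hypothetical and not part of the pip2_segmentation project:

```python
from pathlib import Path


def ensure_output_dirs(base_dir=".", names=("runs", "checkpoints")):
    """Create the output folders if they are missing and return their paths."""
    created = []
    for name in names:
        folder = Path(base_dir) / name
        # exist_ok=True makes the call safe to repeat across notebook runs
        folder.mkdir(parents=True, exist_ok=True)
        created.append(folder)
    return created
```

Calling `ensure_output_dirs()` once before training should be enough, since existing folders are left untouched.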

In [1]:
# load tensorboard extension
%load_ext tensorboard

In [2]:
# import required modules
# import matplotlib.pyplot as plt
from functools import partial
import datetime
import numpy as np  # needed for np.arange when building the data subsets
import torch
import torch.nn as nn
from torch.utils.data import RandomSampler, DataLoader, Subset
from data_preparation import make_dataset, compose, random_flip, random_translation, random_gaussian_or_uniform_noise, add_channel, normalize, to_tensor
from utils import dict2mdtable
from unet import UNet
from train_model import run_training
from metric import DiceCoefficient, DiceLoss, DiceBCELoss

from torch.utils.tensorboard import SummaryWriter
# import torchvision.transforms.v2 as transforms_v2



In [3]:
train_input_data_dir = ""
train_label_data_dir = ""
val_input_data_dir = ""
val_label_data_dir = ""
test_input_data_dir = ""
test_label_data_dir = ""

In [4]:
# Indicate transformations
data_augmentation_transformations = [random_flip, random_translation, random_gaussian_or_uniform_noise, add_channel, normalize, to_tensor]
trafos = partial(compose, transforms=data_augmentation_transformations)

#create the train and validation datasets
train_dataset = make_dataset(train_input_data_dir, train_label_data_dir, transform=trafos, shuffle_data=True, stack_axis=0)
val_dataset = make_dataset(val_input_data_dir, val_label_data_dir, transform=None, shuffle_data=True, stack_axis=0)
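
The `partial(compose, transforms=...)` pattern above fixes the list of transformations once, so that the dataset can later call `trafos` with only an image and its label. The sketch below illustrates the idea with toy stand-ins. The signature `compose(image, target, transforms)` and the two toy transformations are assumptions made for illustration; the actual `compose` in `data_preparation` may differ.

```python
from functools import partial


def compose(image, target, transforms):
    """Apply each transformation in order to an (image, target) pair."""
    for transform in transforms:
        image, target = transform(image, target)
    return image, target


def toy_flip(image, target):
    # hypothetical stand-in: reverse both sequences so they stay aligned
    return image[::-1], target[::-1]


def toy_scale(image, target):
    # hypothetical stand-in: rescale the image only, leave the label untouched
    peak = max(image) or 1
    return [value / peak for value in image], target


# bind the transformation list once; the result only needs image and target
toy_trafos = partial(compose, transforms=[toy_flip, toy_scale])
image_out, target_out = toy_trafos([0, 2, 4], [0, 0, 1])
print(image_out, target_out)
```

The order of the list matters: each transformation receives the output of the previous one, which is presumably why `to_tensor` comes last in the notebook's list.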

In [5]:
# open tensorboard inside of our notebook
%tensorboard --logdir runs

Reusing TensorBoard on port 6007 (pid 18120), started 8 days, 20:20:34 ago. (Use '!kill 18120' to kill it.)

In [7]:
#=========
# pass data to DataLoader
batch_size = 4
# train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# val_loader = DataLoader(val_dataset, batch_size=batch_size)

#only work on a small subset of the data, for the moment
num_train_samples = 4
train_sample_ds = Subset(train_dataset, np.arange(num_train_samples))
train_sample_sampler = RandomSampler(train_sample_ds)
train_loader = DataLoader(train_sample_ds, sampler=train_sample_sampler, batch_size=batch_size)

num_val_samples = 4
val_sample_ds = Subset(val_dataset, np.arange(num_val_samples))
val_sample_sampler = RandomSampler(val_sample_ds)
val_loader = DataLoader(val_sample_ds, sampler=val_sample_sampler, batch_size=batch_size)

#=========
# pass to device
# if torch.cuda.is_available():
#     print("using gpu")
#     device = torch.device("cuda")
# else:
#     print("using cpu")
#     device = torch.device("cpu")
device = torch.device("cpu")

#=========
# set model's parameters
final_activation="Sigmoid"
depth = 3
num_fmaps = 64
fmap_inc_factor = 4
downsample_factor = 2
kernel_size = 3
padding = "valid"
upsample_mode = "nearest"
unet_model = UNet(depth=depth,
                  in_channels=1,
                  out_channels=1,
                  final_activation=final_activation,
                  num_fmaps=num_fmaps,
                  fmap_inc_factor=fmap_inc_factor,
                  downsample_factor=downsample_factor,
                  kernel_size=kernel_size,
                  padding=padding,
                  upsample_mode=upsample_mode).to(device)

#=========
# set loss function
# loss_function = nn.BCELoss()
# loss_function = DiceLoss()
loss_function = DiceBCELoss() #for the moment it seems that this is the best

#=========
# set optimizer
lr = 1e-4
optimizer = torch.optim.Adam(unet_model.parameters(), lr=lr)

#=========
# set metrics
metric = DiceCoefficient()

#=========
# indicate key
my_key  = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")

# set logger's parameters
logger = SummaryWriter(f"runs/{my_key}")
log_interval=1

#=========
# model's training
n_epochs = 2
lr_scheduler_flag = True
lr_kwargs={"mode":"min", "factor": 0.1, "patience":2}
checkpoint_save_path = "checkpoints"
run_training(model=unet_model,
            optimizer=optimizer,
            metric=metric, 
            n_epochs=n_epochs,
            train_loader=train_loader,
            val_loader=val_loader,
            loss_function=loss_function,
            logger=logger,
            log_interval=log_interval,
            device=device,
            key=my_key,
            path=checkpoint_save_path,
            lr_scheduler_flag=lr_scheduler_flag,
            lr_kwargs=lr_kwargs,
            x_dim=[-2,-1],
            y_dim=[-2,-1],
            best_metric_init=0)

#=========
#log all hyperparameters as text in TensorBoard

#transform the kwargs of the lr scheduler into a string
lr_kwargs_str = ""
for k in lr_kwargs:
    lr_kwargs_str = lr_kwargs_str + f"{k}:{lr_kwargs[k]},"

#form a dictionary with all hyperparameters to be logged
hparam_dict = {"batch_size":str(batch_size),
                      "final_activation":final_activation,
                      "depth":str(depth),
                      "num_fmaps":str(num_fmaps),
                      "fmap_inc_factor":str(fmap_inc_factor),
                      "downsample_factor":str(downsample_factor),
                      "kernel_size":str(kernel_size),
                      "padding":padding,
                      "upsample_mode":upsample_mode,
                      "loss_function":str(loss_function),
                      "optimizer":str(optimizer),
                      "metric":str(metric),
                      "n_epochs":str(n_epochs),
                      "lr_scheduler_flag":str(lr_scheduler_flag),
                      "lr_kwargs":lr_kwargs_str}

#transform the dictionary into a table-like string object
hparam_table_like = dict2mdtable(hparam_dict, key='Name', val='Value', transform_2_string=False)

#log the text in the TensorBoard summary of the run
logger.add_text('Hyperparams', hparam_table_like, 1)




Validate: Average loss: 1.5618, Average Metric: 0.0000


Validate: Average loss: 1.4866, Average Metric: 0.0000
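
The metric reported above is a Dice coefficient. For reference, here is a minimal pure-Python sketch of the standard definition, 2 * |A ∩ B| / (|A| + |B|), on flattened binary masks. The function name, the 0.5 threshold and the `eps` term are assumptions for illustration; the project's `DiceCoefficient` class may be implemented differently.

```python
def dice_coefficient(prediction, target, threshold=0.5, eps=1e-6):
    """Dice coefficient between a flat list of probabilities and a binary mask."""
    # binarize the predicted probabilities (assumed threshold)
    binary = [1 if p > threshold else 0 for p in prediction]
    intersection = sum(b * t for b, t in zip(binary, target))
    total = sum(binary) + sum(target)
    # eps avoids a division by zero when both masks are empty
    return 2 * intersection / (total + eps)


# no predicted pixel passes the threshold, so the score is 0
score_empty = dice_coefficient([0.1, 0.2, 0.3, 0.4], [0, 1, 1, 0])
# two of the three predicted foreground pixels overlap with the target
score_partial = dice_coefficient([0.9, 0.8, 0.1, 0.7], [1, 1, 0, 0])
print(round(score_empty, 4), round(score_partial, 4))
```

Under this definition, a score of 0 means that the thresholded prediction shares no foreground pixel with the label.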

