## **Tests a trained model**
### **Author:** Alessandro Ulivi (ale.ulivi@gmail.com)
### **Start day (yyyy/mm/dd):** 2024/10/21
### **Description**
#### The notebook loads a model and tests it on a test set.

### **Requirements**
#### The notebook expects a folder named "tests" in the work directory, to save summary writer's outputs.
#### The notebook runs in the pip2_segmentation environment and uses the scripts of the pip2_segmentation project. Refer to https://github.com/AlessandroUlivi/pip2_segmentation.
In addition, a "runs" folder and a "checkpoints" folder are expected to store, respectively, the TensorBoard summaries of individual runs and the checkpoints of model training.

In [5]:
# Load the TensorBoard notebook extension so that the `%tensorboard` magic
# used further down in this notebook is available.
%load_ext tensorboard

In [None]:
#Import required modules
import datetime
from functools import partial
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import RandomSampler, DataLoader, Subset
from data_preparation import make_dataset, add_channel, to_tensor, compose
from unet import UNet
from utils import dict2mdtable, load_checkpoint
from test_model import test_model
from metric import DiceCoefficient, DiceLoss, DiceBCELoss

from torch.utils.tensorboard import SummaryWriter
# import torchvision.transforms.v2 as transforms_v2

In [2]:
# Root folder of the test split; the raw images and their labels live in two sibling sub-folders.
test_data_root = r"C:\Users\aless\OneDrive\Desktop\Ale\personal\projects\pip2_segmentation\data\test"
test_input_data_dir = test_data_root + r"\raw"
test_label_data_dir = test_data_root + r"\label"

In [3]:
# Transformations applied to every test sample.
# NOTE: no normalization step is listed here because the data were already normalized
# when the dataset was created, before the images were chunked.
test_data_transformations = [add_channel, to_tensor]

# Bundle the transformations into a single callable (also exposed under the name `trafos`).
test_trafos = partial(compose, transforms=test_data_transformations)
trafos = test_trafos

# Build the test dataset from the raw/label folders.
test_dataset = make_dataset(
    test_input_data_dir,
    test_label_data_dir,
    transform=test_trafos,
    shuffle_data=True,
    stack_axis=0,
)


In [6]:
# Open TensorBoard inside the notebook, pointing it at the "tests" log directory
# (the same folder the SummaryWriter of this notebook writes to).
%tensorboard --logdir tests

In [7]:
# ------------------------------------------------------------------
# Data loading
# ------------------------------------------------------------------
batch_size = 1

# For the moment only a small subset of the test data is used.
num_test_samples = 4
test_sample_ds = Subset(test_dataset, np.arange(num_test_samples))
test_sample_sampler = RandomSampler(test_sample_ds)
test_loader = DataLoader(
    test_sample_ds,
    sampler=test_sample_sampler,
    batch_size=batch_size,
)

# ------------------------------------------------------------------
# Device
# ------------------------------------------------------------------
# The CPU is currently hard-coded. To select a GPU when one is present use:
#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# (`is_available` has to be *called*: the bare function object is always truthy.)
device = torch.device("cpu")

# ------------------------------------------------------------------
# Model
# ------------------------------------------------------------------
# Hyperparameters are kept as individual variables because they are also
# logged as text to TensorBoard later in the notebook.
final_activation = "Sigmoid"
depth = 3
num_fmaps = 64
fmap_inc_factor = 4
downsample_factor = 2
kernel_size = 3
padding = "valid"
upsample_mode = "nearest"

unet_kwargs = dict(
    depth=depth,
    in_channels=1,
    out_channels=1,
    final_activation=final_activation,
    num_fmaps=num_fmaps,
    fmap_inc_factor=fmap_inc_factor,
    downsample_factor=downsample_factor,
    kernel_size=kernel_size,
    padding=padding,
    upsample_mode=upsample_mode,
)
unet_model = UNet(**unet_kwargs)
unet_model = unet_model.to(device)

# ------------------------------------------------------------------
# Loss function
# ------------------------------------------------------------------
# Author's notes on the alternatives tried so far:
#   - nn.BCELoss(): second best; output values seem to stay low and the Sigmoid then fails.
#   - DiceLoss(): works at the very beginning of training, but quickly leads to large
#     "positive pixels" structures.
#   - DiceBCELoss(): for the moment it appears to be the best option.
#   - FocalLoss might be worth testing (https://arxiv.org/pdf/1708.02002,
#     https://www.kaggle.com/code/bigironsphere/loss-function-library-keras-pytorch).
loss_function = DiceBCELoss()

# ------------------------------------------------------------------
# Optimizer
# ------------------------------------------------------------------
lr = 1e-4
optimizer = torch.optim.Adam(unet_model.parameters(), lr=lr)

# ------------------------------------------------------------------
# Metric
# ------------------------------------------------------------------
bin_threshold = 0.5
metric = DiceCoefficient()

# ------------------------------------------------------------------
# Run key and TensorBoard logger
# ------------------------------------------------------------------
# The key is "test_" followed by the current timestamp; it names the sub-folder
# of "tests" that receives the summary writer's output.
run_timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
my_key = f"test_{run_timestamp}"

logger = SummaryWriter("tests/" + my_key)



In [None]:
#=========
#=========
#=========
#THIS PART OF THE CODE IS NOT PROPERLY TESTED. IN PARTICULAR, NOTHING IS PROVIDED TO step IN test_model
#=========
#=========
#=========


#=========
# load the checkpoint of the model to test
# load_checkpoint returns the model, the optimizer and an epoch value;
# presumably the epoch at which the checkpoint was saved - TODO confirm against utils.load_checkpoint
checkpoint_save_path = "checkpoints"
checkpoint_key = "20241113-2343746"
model, optimizer, epoch = load_checkpoint(model=unet_model,
                                          path=checkpoint_save_path,
                                          optimizer=optimizer,
                                          key=checkpoint_key)


#=========
# evaluate the loaded model on the test set
# NOTE: the device selected earlier in the notebook (the one the model was moved to) is passed
# explicitly, so that the test runs on the same device instead of leaving the choice to test_model
avg_loss_val, avg_metric_val = test_model(model=model,
                                          loader=test_loader,
                                          loss_function=loss_function,
                                          metric=metric,
                                          bin_threshold=bin_threshold,
                                          step=None,
                                          tb_logger=None,
                                          device=device,
                                          x_dim=[-2,-1],
                                          y_dim=[-2,-1])

# TODO: log avg_loss_val, avg_metric_val to tensorboard


#=========
# log all hyperparameters as text in Tensorboard
# form a dictionary with all the hyperparameters to be logged (all values are strings)
hparam_dict = {"train_checkpoint_key": checkpoint_key,
               "batch_size": str(batch_size),
               "final_activation": final_activation,
               "depth": str(depth),
               "num_fmaps": str(num_fmaps),
               "fmap_inc_factor": str(fmap_inc_factor),
               "downsample_factor": str(downsample_factor),
               "kernel_size": str(kernel_size),
               "padding": padding,
               "upsample_mode": upsample_mode,
               "loss_function": str(loss_function),
               "bin_threshold": str(bin_threshold),
               "optimizer": str(optimizer),
               "metric": str(metric),
               # no `n_epochs` variable exists in this notebook (using it raised a NameError):
               # the epoch value returned by load_checkpoint is logged instead
               "n_epochs": str(epoch)}

# transform the dictionary in a table-like string object
hparam_table_like = dict2mdtable(hparam_dict, key='Name', val='Value', transform_2_string=False)

# log the text in the Tensorboard summary of the run
logger.add_text('Hyperparams', hparam_table_like, 1)

# make sure the pending events are written to disk, so that they show up in TensorBoard right away
logger.flush()
