# Main notebook
This notebook shows how to train a 3D U-Net for liver segmentation (Task03_Liver) using MONAI and PyTorch.
It relies on the helper functions (`prepare`, `train`) defined in the accompanying `utilities.py` module.


### Settings

#### Including Libraries

In [1]:
import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
from monai.networks.nets import UNet
from monai.networks.layers import Norm
from monai.losses import DiceLoss, DiceCELoss

import torch
from utilities import train, prepare

from monai.utils import first, set_determinism
from monai.transforms import(
    Compose,
    EnsureChannelFirstD,
    LoadImaged,
    Resized,
    ToTensord,
    Spacingd,
    Orientationd,
    ScaleIntensityRanged,
    CropForegroundd,
    Activations,
)

from monai.data import CacheDataset, DataLoader, Dataset

import matplotlib.pyplot as plt

from glob import glob
import numpy as np

from monai.inferers import sliding_window_inference

#### Preparation (directories, model, loss function, optimizer and data)

In [2]:

# Paths. data_dir can be overridden with the LIVER_DATA_DIR environment variable
# so the notebook is not tied to one machine's drive letter.
data_dir = os.environ.get("LIVER_DATA_DIR", 'E:\\Task03_Liver\\')
# os.path.join keeps the results folder portable: it yields '.\\Results' on
# Windows and './Results' elsewhere (a literal backslash is not a separator on POSIX).
model_dir = os.path.join('.', 'Results')

# Seed the random generators through MONAI so weight initialisation is
# reproducible across runs (data shuffling too, assuming prepare's loaders use
# the global RNGs -- TODO confirm).
set_determinism(seed=0)

# Fall back to the CPU when no GPU is available instead of failing at .to(device).
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# 3D U-Net with 1 input channel and 2 output channels; the loss one-hot encodes
# the label (to_onehot_y=True) to match the 2 output channels.
model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
    norm=Norm.BATCH,
).to(device)

# NOTE(review): with 2 one-hot output channels softmax=True is the more usual
# choice; sigmoid is kept because the inference cell also applies a sigmoid.
loss_function = DiceLoss(to_onehot_y=True, sigmoid=True, squared_pred=True)
optimizer = torch.optim.Adam(model.parameters(), 1e-5, weight_decay=1e-5, amsgrad=True)

data_in = prepare(data_dir, cache=True)

# NOTE(review): presumably data_in == (train_loader, test_loader). The training
# run failed with ZeroDivisionError at epoch 1, which is consistent with an
# empty loader (e.g. a wrong data_dir), so fail early with a clear message.
for _split, _loader in zip(("train", "test"), data_in):
    if len(_loader) == 0:
        raise RuntimeError(f"The {_split} loader is empty; check data_dir={data_dir!r}")
def debug_test_loader(test_loader):
    """Fetch every batch of a loader and report, per batch, whether it loaded.

    Failures in file reading or transforms surface when a batch is *fetched*,
    so the fetch itself happens inside the ``try``. The original version only
    wrapped the ``print`` call, which cannot fail, so loader errors escaped.

    Args:
        test_loader: any iterable of batches (e.g. a MONAI/PyTorch DataLoader).

    Returns:
        None. Progress is printed; iteration stops at the first failing batch,
        and the exception type and message are printed alongside the index.
    """
    batches = iter(test_loader)
    idx = 0
    while True:
        try:
            next(batches)
        except StopIteration:
            break
        except Exception as e:
            print(f"Errore durante l'elaborazione del batch {idx}")
            print(f"  {type(e).__name__}: {e}")
            break
        print(f"Elaborazione batch {idx} completata con successo")
        idx += 1
# Sanity check before training: try to load every batch of data_in[1]
# (presumably the test loader returned by prepare -- confirm in utilities.py).
debug_test_loader(data_in[1])



### Training

#### Train the model with the `train` function

In [3]:
# Number of epochs handed to utilities.train. Its outputs (per-epoch histories
# and presumably the best checkpoint) land in model_dir for the cells below.
epoch = 600
train(
    model, data_in, loss_function, optimizer,
    epoch, model_dir,
)

----------
epoch 1/600
--------------------


ZeroDivisionError: division by zero

#### Load the training history (losses and Dice metrics) saved by the `train` function

In [4]:
# Per-epoch loss/metric histories saved into model_dir by the training run,
# unpacked in the order: train loss, train metric, test loss, test metric.
train_loss, train_metric, test_loss, test_metric = (
    np.load(os.path.join(model_dir, history_file))
    for history_file in (
        'loss_train.npy',
        'metric_train.npy',
        'loss_test.npy',
        'metric_test.npy',
    )
)

#### Plot the Dice loss and Dice metric curves

In [None]:
def plot_history(ax, values, title, ylabel, color, label):
    """Draw one per-epoch curve (loss or Dice metric) on ``ax``.

    Args:
        ax: matplotlib Axes to draw on.
        values: 1-D sequence with one value per epoch.
        title: panel title.
        ylabel: y-axis label.
        color: line and marker colour.
        label: legend entry.

    Returns:
        The same Axes, for chaining.
    """
    epochs = range(1, len(values) + 1)  # epochs are shown 1-based
    ax.plot(epochs, values, color=color, linestyle='-', marker='o', label=label)
    ax.set_title(title)
    ax.set_xlabel("Epoch")
    ax.set_ylabel(ylabel)
    ax.grid(True)
    ax.legend()
    return ax


# clear=True makes a re-run redraw the "Results" figure instead of drawing over
# the axes left from a previous run of this cell.
fig, axes = plt.subplots(2, 2, figsize=(12, 6), num="Results", clear=True)

# One entry per panel, in row-major order (top-left, top-right, bottom-left, bottom-right).
panels = [
    (train_loss, "Train Dice Loss", "Loss", 'blue', 'Train Loss'),
    (train_metric, "Train Metric DICE", "DICE", 'green', 'Train DICE'),
    (test_loss, "Test Dice Loss", "Loss", 'red', 'Test Loss'),
    (test_metric, "Test Metric DICE", "DICE", 'purple', 'Test DICE'),
]
for ax, (values, title, ylabel, color, label) in zip(axes.flat, panels):
    plot_history(ax, values, title, ylabel, color, label)

fig.tight_layout()
plt.show()

### Testing

In [13]:
# Volumes and masks are paired by sorted file name, so both folders must hold
# matching, identically ordered files.
path_test_volumes = sorted(glob(os.path.join(data_dir, "TestVolumes_splitted", "*.nii.gz")))
path_test_segmentation = sorted(glob(os.path.join(data_dir, "TestSegmentation_splitted", "*.nii.gz")))

# zip() would silently drop unmatched files, and an empty list would only show
# up later when the inference cell indexes the first batch, so check up front.
if not path_test_volumes:
    raise FileNotFoundError(
        f"No test volumes found under {os.path.join(data_dir, 'TestVolumes_splitted')}"
    )
if len(path_test_volumes) != len(path_test_segmentation):
    raise ValueError(
        f"Found {len(path_test_volumes)} test volumes but "
        f"{len(path_test_segmentation)} test segmentations"
    )

test_files = [{"vol": image_name, "seg": label_name} for image_name, label_name in zip(path_test_volumes, path_test_segmentation)]

In [None]:
# Preprocessing for the held-out test set. Intensities of "vol" are clipped to
# [-200, 200] and rescaled to [0, 1].
# Labels must not be blended by interpolation: Spacingd already uses "nearest"
# for "seg", and Resized now does too (its default "area" mode would average
# label values along mask borders).
# NOTE(review): keep these in sync with the validation transforms used by
# utilities.prepare, since the model was trained on that pipeline.
test_transforms = Compose(
    [
        LoadImaged(keys=["vol", "seg"]),
        EnsureChannelFirstD(keys=["vol", "seg"]),
        Spacingd(keys=["vol", "seg"], pixdim=(1.5, 1.5, 1.0), mode=("bilinear", "nearest")),
        Orientationd(keys=["vol", "seg"], axcodes="RAS"),
        ScaleIntensityRanged(keys=["vol"], a_min=-200, a_max=200, b_min=0.0, b_max=1.0, clip=True),
        CropForegroundd(keys=['vol', 'seg'], source_key='vol'),
        Resized(keys=["vol", "seg"], spatial_size=[128, 128, 64], mode=("area", "nearest")),
        ToTensord(keys=["vol", "seg"]),
    ]
)

In [15]:
# Un-cached dataset over the test file pairs, served one patient per batch.
test_ds = Dataset(transform=test_transforms, data=test_files)
test_loader = DataLoader(dataset=test_ds, batch_size=1)

In [16]:
# Rebuild the network for inference. The architecture must match the one used
# for training exactly, otherwise the saved state_dict will not load.
# Fall back to the CPU when no GPU is available.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
    norm=Norm.BATCH,
).to(device)

In [None]:
# map_location restores a checkpoint saved during a GPU run onto whichever
# device `device` refers to (e.g. a CPU-only machine).
# torch.load unpickles the file: only load checkpoints you produced yourself.
checkpoint_path = os.path.join(model_dir, "best_metric_model.pth")
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
# Switch layers such as BatchNorm to inference behaviour; the trailing ';'
# stops the notebook from printing the whole model repr.
model.eval();

In [None]:
# Sliding-window settings; roi_size equals the Resized spatial_size of the test transforms.
sw_batch_size = 4
roi_size = (128, 128, 64)
# Probability cut-off applied to the sigmoid output to obtain a binary mask.
prob_threshold = 0.53

with torch.no_grad():
    test_patient = first(test_loader)
    if test_patient is None:
        raise RuntimeError("test_loader is empty; check the test file lists")
    t_volume = test_patient['vol']

    test_outputs = sliding_window_inference(t_volume.to(device), roi_size, sw_batch_size, model)
    sigmoid_activation = Activations(sigmoid=True)
    test_outputs = sigmoid_activation(test_outputs)
    test_outputs = test_outputs > prob_threshold

# Move the predicted mask to the CPU once rather than on every slice.
output_mask = test_outputs.detach().cpu()
# The last axis is the slice axis indexed below; use its real length instead of
# assuming 64 slices.
num_slices = t_volume.shape[-1]

for i in range(num_slices):
    plt.figure("check", (18, 6))
    plt.subplot(1, 3, 1)
    plt.title(f"image {i}")
    plt.imshow(test_patient["vol"][0, 0, :, :, i], cmap="gray")
    plt.subplot(1, 3, 2)
    plt.title(f"label {i}")
    plt.imshow(test_patient["seg"][0, 0, :, :, i] != 0)
    plt.subplot(1, 3, 3)
    plt.title(f"output {i}")
    # Channel 1 of the 2-channel output is the one shown as the prediction.
    plt.imshow(output_mask[0, 1, :, :, i])
    plt.show()