## Analyse the training results

### Set parameters

In [None]:
# Identifier of the training run to analyse. It selects the run-specific
# "run_<id>" directories used below, so models, data and figures from
# different runs do not overwrite each other.
run_id: int = 8

In [None]:
import os
import shutil
import tempfile
import time
import matplotlib.pyplot as plt
from tqdm import tqdm

import pdb
import os
import numpy as np
import json

In [None]:
# Root directory of the MONAI data. A MONAI_DATA_DIRECTORY already present in the
# environment is respected; otherwise it is set to the cluster path below.
# If the variable exists but is empty, a fresh temporary directory is used instead.
directory = os.environ.setdefault("MONAI_DATA_DIRECTORY", "/scratch1/sachinsa/monai_data_1")
if directory:
    os.makedirs(directory, exist_ok=True)
root_dir = directory if directory else tempfile.mkdtemp()
print(root_dir)

In [None]:
# Per-run directories: trained artefacts are read from load_dir and the
# generated figures are written to fig_save_dir (created if missing).
load_dir, fig_save_dir = (
    os.path.join(root_dir, f"run_{run_id}"),
    os.path.join("..", "figs", f"run_{run_id}"),
)
os.makedirs(fig_save_dir, exist_ok=True)

In [None]:
# Per-epoch training losses and per-validation-step metric values saved during
# training, converted from NumPy arrays to plain Python lists.
epoch_loss_values, metric_values = (
    np.load(os.path.join(load_dir, fname)).tolist()
    for fname in ("epoch_loss_values.npy", "metric_values.npy")
)

In [None]:
# Number of training epochs, and the validation interval recovered from how many
# validation metrics were recorded.
# NOTE(review): the recovered interval may be off when the epoch count is not a
# multiple of the interval used during training — confirm against the training config.
max_epochs = len(epoch_loss_values)
# Guard against an empty metric file (integer division by zero) and keep the
# interval at least 1 so the validation x-axis stays meaningful.
val_interval = max(1, max_epochs // len(metric_values)) if metric_values else 1

In [None]:
# Fraction of epochs whose recorded training loss is NaN (mean of a boolean mask).
nan_ratio = np.isnan(epoch_loss_values).mean()
print(f"{100*nan_ratio:.1f}% of values are nan!!!")

In [None]:
# Plot the per-epoch training loss and the validation metric, saving both figures
# to fig_save_dir. With MULTI_TRAINING_FIGURE, two extra linear-scale panels show
# the loss from a later epoch onwards so the early, large losses do not dominate.
MULTI_TRAINING_FIGURE = True

x = [i + 1 for i in range(len(epoch_loss_values))]  # 1-based epoch numbers
y = epoch_loss_values

if not MULTI_TRAINING_FIGURE:
    plt.figure("train", (6, 4))
    plt.xlabel("epoch")
    plt.ylabel("loss - log")
    plt.yscale('log')
    plt.plot(x, y, color="red")
    plt.title("Training: Gaussian Log Likelihood Loss", fontsize=25)
    plt.savefig(os.path.join(fig_save_dir, "train_plot.png"), facecolor='white')
    plt.show()
else:
    plt.figure("train", (18, 4))
    plt.subplot(1, 3, 1)
    plt.xlabel("epoch", fontsize=15)
    plt.ylabel("loss - log", fontsize=15)
    plt.yscale('log')
    plt.plot(x, y, color="red")
    plt.suptitle("Training: Gaussian Log Likelihood Loss", fontsize=20)

    # Panels 2 and 3: linear-scale loss starting at epoch `zoom`. A panel is left
    # empty when training ran for too few epochs.
    for panel, zoom in enumerate([10, 20], start=2):
        if len(x) > zoom:
            plt.subplot(1, 3, panel)
            plt.ylabel("loss", fontsize=15)
            plt.xlabel(f"epoch (from ep. {zoom})", fontsize=15)
            # x[zoom - 1] is epoch `zoom`, so the curve starts where the label says.
            plt.plot(x[zoom - 1:], y[zoom - 1:], color="red")
    plt.savefig(os.path.join(fig_save_dir, "train_plot.png"), facecolor='white')
    plt.show()

plt.figure("val", (6, 4))
plt.title("Validation: MSE", fontsize=20)
# Validation ran every val_interval epochs, so metric k belongs to epoch val_interval * k.
x_val = [val_interval * (i + 1) for i in range(len(metric_values))]
y_val = metric_values
plt.xlabel("epoch", fontsize=15)
plt.ylabel("MSE", fontsize=15)
plt.plot(x_val, y_val, color="green")
plt.savefig(os.path.join(fig_save_dir, "val_plot.png"), facecolor='white')
plt.show()

## Run inference with the trained model

In [None]:
import torch
from monai.networks.nets import UNet
from monai.transforms import (
    LoadImage,
    NormalizeIntensity,
    Orientation,
    RandFlip,
    RandScaleIntensity,
    RandShiftIntensity,
    RandSpatialCrop,
    Spacing,
    EnsureType,
    EnsureChannelFirst,
)
from monai.transforms import (
    Compose,
)
from torch.utils.data import Dataset

In [None]:
# Use the first CUDA device when available; fall back to CPU so the notebook can
# still run (slowly) on a machine without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = UNet(
    spatial_dims=3,  # 3D
    in_channels=4,
    out_channels=8,  # estimated mean and estimated log std dev for each of the 4 image channels
    channels=(4, 8, 16),
    strides=(2, 2),
    num_res_units=2,
).to(device)

# Run inference under CUDA automatic mixed precision (only effective on a GPU).
VAL_AMP = True


def inference(input):
    """Run ``model`` on ``input``, under CUDA autocast when ``VAL_AMP`` is set.

    Args:
        input: batch tensor, expected to already be on ``device``.

    Returns:
        The raw model output. Under autocast it may be half precision — TODO
        confirm whether callers need float32.
    """
    # autocast(enabled=False) is a no-op context, which replaces a duplicated
    # AMP / non-AMP branch; AMP is skipped entirely when running on CPU.
    # NOTE: recent PyTorch versions prefer torch.autocast("cuda", ...) over
    # torch.cuda.amp.autocast.
    with torch.cuda.amp.autocast(enabled=VAL_AMP and device.type == "cuda"):
        return model(input)

In [None]:
# Preprocessing + augmentation pipeline applied to a single image file path.
# NOTE(review): this is the randomized *training* pipeline (random crop, flips,
# intensity jitter), so every access to the dataset yields a different view of
# the sample — confirm that this is intended for the inference figures below.
train_transform = Compose(
    [
        # load the (multi-channel) image volume from one file path and put the
        # channel axis first
        LoadImage(),
        EnsureChannelFirst(),
        EnsureType(),
        Orientation(axcodes="RAS"),
        # The array-based Spacing takes a single interpolation mode for the whole
        # image; a per-key tuple such as ("bilinear", "nearest") only applies to
        # the dictionary transform Spacingd (image + label).
        Spacing(
            pixdim=(1.0, 1.0, 1.0),
            mode="bilinear",
        ),
        RandSpatialCrop(roi_size=[224, 224, 144], random_size=False),
        RandFlip(prob=0.5, spatial_axis=0),
        RandFlip(prob=0.5, spatial_axis=1),
        RandFlip(prob=0.5, spatial_axis=2),
        NormalizeIntensity(nonzero=True, channel_wise=True),
        RandScaleIntensity(factors=0.1, prob=1.0),
        RandShiftIntensity(offsets=0.1, prob=1.0),
    ]
)

In [None]:
class BrainMRIDataset(Dataset):
    """Map-style dataset over the ``training`` entries of ``Task01_BrainTumour/dataset.json``.

    Each item is ``{"image": ..., "mask": ...}``:
    - ``image`` is ``transform`` applied to the image file path (or the path
      itself when no transform is given);
    - ``mask`` is a boolean tensor of 4 flags for that sample, drawn once with a
      fixed seed at construction time.
    """

    def __init__(self, root_dir, transform=None):
        """
        Args:
            root_dir: directory that contains the ``Task01_BrainTumour`` folder.
            transform: optional callable applied to each image file path.
        """
        self.root_dir = os.path.join(root_dir, "Task01_BrainTumour")
        json_file_path = os.path.join(self.root_dir, "dataset.json")
        with open(json_file_path, 'r', encoding="utf-8") as file:
            data_json = json.load(file)

        self.image_filenames = data_json['training']

        # Each flag is True with probability 0.2. A private RandomState(0) gives the
        # same draws as np.random.seed(0); np.random.rand(...), so the masks are
        # reproducible without reseeding the global NumPy RNG as a side effect.
        rng = np.random.RandomState(0)
        self.seq_mask = rng.rand(len(self.image_filenames), 4) < 0.2

        self.transform = transform

    def __len__(self):
        """Returns the total number of samples in the dataset."""
        return len(self.image_filenames)

    def __getitem__(self, idx):
        """Return ``{"image": transformed image, "mask": bool tensor of 4 flags}`` for sample ``idx``."""
        img_name = os.path.normpath(os.path.join(self.root_dir, self.image_filenames[idx]['image']))
        # Without a transform, return the file path instead of failing on an
        # unbound ``image`` variable.
        image = self.transform(img_name) if self.transform is not None else img_name
        mask = torch.from_numpy(self.seq_mask[idx])

        return {"image": image, "mask": mask}
    


# Dataset used for the qualitative inference below; it applies train_transform
# (including its random augmentations) to each sample.
sample_ds = BrainMRIDataset(root_dir=root_dir, transform=train_transform)

In [None]:
def GaussianLikelihood(expected_img, output_img):
    """Mean Gaussian negative log-likelihood, up to an additive constant.

    ``output_img`` holds the predicted means in the first half of its channels
    (dim 1) and the predicted log standard deviations in the second half, e.g.
    8 channels -> 4 means + 4 log std devs for a 4-channel ``expected_img``.

    Args:
        expected_img: target tensor with channels on dim 1.
        output_img: prediction with twice as many channels as the target.

    Returns:
        Scalar tensor: the mean over all elements of
        (x - mu)^2 / (2 * sigma^2) + log(sigma).
    """
    n_mean = output_img.shape[1] // 2
    output_img_mean = output_img[:, :n_mean, ...]
    output_img_log_std = output_img[:, n_mean:, ...]

    # exp(2 * log sigma) == sigma^2
    cost1 = (expected_img - output_img_mean) ** 2 / (2 * torch.exp(2 * output_img_log_std))
    cost2 = output_img_log_std

    return torch.mean(cost1 + cost2)

In [None]:
# Load the best checkpoint of this run and switch the model to evaluation mode.
# map_location places the weights on the selected device, so loading also works
# when the checkpoint was saved on a different device.
model.load_state_dict(
    torch.load(os.path.join(load_dir, "best_metric_model.pth"), map_location=device)
)
model.eval()
# Display names of the 4 input channels; assumed to match the channel order of
# the loaded volumes — TODO confirm against dataset.json.
channels = ["FLAIR", "T1w", "T1gd", "T2w"]

def plot_brain(index, label):
    """Draw the middle slice (last spatial axis) of one channel into the 3x4 figure grid.

    Row 1 shows the input channels, row 2 the predicted means and row 3 the
    predicted log standard deviations; the column is the channel ``index``.
    Reads the notebook globals ``val_input``, ``val_output`` and ``channels``.

    Args:
        index: channel index in ``range(4)``.
        label: ``"input"``, ``"mean"`` or ``"logstd"``; selects the grid row.

    Raises:
        ValueError: if ``label`` is not one of the values above.
    """
    _, _, im_length, im_width, im_height = val_input.shape
    h_index = im_height // 2

    if label == "input":
        row, title = 0, "Input"
        brain_slice = val_input[0, index, :, :, h_index]
    elif label == "mean":
        row, title = 1, "Output: " + r"$\mu$"
        brain_slice = val_output[0, index, :, :, h_index]
    elif label == "logstd":
        row, title = 2, "Output: log(" + r"$\sigma$" + ")"
        # The second half of the output channels holds the log std devs.
        brain_slice = val_output[0, index + 4, :, :, h_index]
    else:
        raise ValueError(f"unknown label {label!r}; expected 'input', 'mean' or 'logstd'")

    plt.subplot(3, 4, 4 * row + index + 1)
    if label == "input":
        plt.title(channels[index], fontsize=30)
    # Upcast to float32 (AMP output may be half precision) and transpose: the
    # displayed image has im_width rows (y-axis) and im_length columns (x-axis).
    brain_slice = brain_slice.detach().float().cpu().T
    plt.imshow(brain_slice, cmap="gray")
    plt.xlabel('')
    if index == 0:
        plt.ylabel(title, fontsize=30)
    plt.xticks([0, im_length - 1], [0, im_length - 1], fontsize=15)
    plt.yticks([0, im_width - 1], [0, im_width - 1], fontsize=15)
    cbar = plt.colorbar()
    cbar.ax.tick_params(labelsize=20)


# Run the model on one sample (without gradients) and plot, row by row, the
# input channels, the predicted means and the predicted log std devs.
with torch.no_grad():
    val_input = sample_ds[6]["image"].unsqueeze(0).to(device)
    # Channels to zero out before inference; all False, so nothing is masked here.
    mask_indices = [False] * 4
    val_input[:, mask_indices, ...] = 0
    val_output = inference(val_input)

    plt.figure("image", (24, 18))
    # The loop variable is kept as the global `i` in case plot_brain relies on it.
    for label in ("input", "mean", "logstd"):
        for i in range(4):
            plot_brain(i, label)

inference_fig_path = os.path.join(fig_save_dir, "model_inference.png")
plt.suptitle("")
plt.tight_layout()
plt.savefig(inference_fig_path, facecolor='white')
plt.show()