## Analyse the training results

### Set parameters

In [None]:
run_id: int = 0  # selects the run_<id> folder, so different runs never overwrite each other's model/data

In [None]:
import os
import shutil
import tempfile
import time
import matplotlib.pyplot as plt
from tqdm import tqdm

import pdb
import os
import numpy as np
import json

In [None]:
# Default data location on the cluster. A MONAI_DATA_DIRECTORY value that is
# already set in the environment takes precedence instead of being overwritten.
os.environ.setdefault("MONAI_DATA_DIRECTORY", "/scratch1/sachinsa/monai_data_1")
directory = os.environ.get("MONAI_DATA_DIRECTORY")
# An empty value would make os.makedirs("") fail; fall back to a temp dir instead.
if directory:
    os.makedirs(directory, exist_ok=True)
root_dir = directory or tempfile.mkdtemp()
print(root_dir)

In [None]:
# Per-run folder from which the model checkpoint and training curves are read.
save_dir = os.path.join(root_dir, "run_{}".format(run_id))

In [None]:
# Load the saved training curves (per-epoch loss, validation metric) as plain
# Python lists and echo them.
epoch_loss_values, metric_values = (
    np.load(os.path.join(save_dir, fname)).tolist()
    for fname in ('epoch_loss_values.npy', 'metric_values.npy')
)

for curve in (epoch_loss_values, metric_values):
    print(curve)

In [None]:
max_epochs = len(epoch_loss_values)
# Validation is assumed to have run every `val_interval` epochs, so the
# interval is recovered from the ratio of the two curve lengths. This is exact
# when training stopped right after a validation step; otherwise it can
# overestimate the interval. If no validation was recorded, the interval is
# irrelevant (nothing is plotted), so avoid dividing by zero.
val_interval = max_epochs // len(metric_values) if metric_values else max_epochs

In [None]:
# Fraction of epochs whose recorded average loss is NaN.
nan_ratio = np.isnan(epoch_loss_values).mean()
print(f"{nan_ratio * 100:.1f}% of values are nan!!!")

In [None]:
plt.figure("train", (16, 5))

# Panel 1: the full loss curve on a logarithmic y-axis.
x = list(range(1, len(epoch_loss_values) + 1))
y = epoch_loss_values
plt.subplot(1, 3, 1)
plt.title("Epoch Average Loss")
plt.xlabel("epoch")
plt.ylabel("loss - log")
plt.plot(x, y, color="red")
plt.yscale('log')

# Panels 2 and 3: linear-scale loss with the first `zoom` epochs dropped, so
# the later part of the curve is readable. A panel stays empty when training
# did not run for more than `zoom` epochs.
for panel, zoom in enumerate((5, 10), start=2):
    if len(x) <= zoom:
        continue
    plt.subplot(1, 3, panel)
    plt.title("Epoch Average Loss (Zoomed in)")
    plt.xlabel(f"epoch (from ep. {zoom})")
    plt.ylabel("loss")
    plt.plot(x[zoom:], y[zoom:], color="red")

# Validation metric, placed at every `val_interval`-th epoch.
plt.figure("val", (5,5))
plt.title("Val Mean MSE")
y_val = metric_values
x_val = [val_interval * (n + 1) for n in range(len(y_val))]
plt.xlabel("epoch")
plt.plot(x_val, y_val, color="green")
plt.show()

## Run inference with the trained model

In [None]:
import torch
from monai.networks.nets import UNet
from monai.transforms import (
    LoadImage,
    NormalizeIntensity,
    Orientation,
    RandFlip,
    RandScaleIntensity,
    RandShiftIntensity,
    RandSpatialCrop,
    Spacing,
    EnsureType,
    EnsureChannelFirst,
)
from monai.transforms import (
    Compose,
)
from torch.utils.data import Dataset

In [None]:
# Use the first GPU when available; fall back to CPU so the notebook still runs.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = UNet(
    spatial_dims=3, # 3D
    in_channels=4,
    out_channels=8, # we will output estimated mean and estimated std dev for all 4 image channels
    channels=(4, 8, 16),
    strides=(2, 2),
    num_res_units=2
).to(device)

# Run inference under automatic mixed precision (only applied on CUDA).
VAL_AMP = True


def inference(input):
    """Run ``model`` on ``input`` and return its raw output.

    ``input`` must already be on ``device``. When ``VAL_AMP`` is set and the
    model runs on CUDA, the forward pass uses autocast, so the output may be
    half precision.
    """
    use_amp = VAL_AMP and device.type == "cuda"
    # torch.autocast supersedes the deprecated torch.cuda.amp.autocast;
    # with enabled=False it is a no-op, so both cases share one code path.
    with torch.autocast(device_type=device.type, enabled=use_amp):
        return model(input)

In [None]:
# Pipeline applied to a single image file path (see BrainMRIDataset.__getitem__).
# NOTE(review): this reuses the training-time augmentations (random crop,
# random flips, random intensity scale/shift), so the sample visualised below
# differs between runs; a deterministic pipeline would make inference
# visualisations reproducible.
train_transform = Compose(
    [
        # load the image file; presumably one 4D NIfTI volume holding all
        # 4 MRI sequences (TODO confirm), with channels moved to axis 0
        LoadImage(),
        EnsureChannelFirst(),
        EnsureType(),
        Orientation(axcodes="RAS"),
        # The array-based Spacing takes a single interpolation mode; the
        # ("bilinear", "nearest") pair is the per-key form of Spacingd
        # (image, label). There is no label here, so the image uses bilinear.
        Spacing(
            pixdim=(1.0, 1.0, 1.0),
            mode="bilinear",
        ),
        RandSpatialCrop(roi_size=[224, 224, 144], random_size=False),
        RandFlip(prob=0.5, spatial_axis=0),
        RandFlip(prob=0.5, spatial_axis=1),
        RandFlip(prob=0.5, spatial_axis=2),
        NormalizeIntensity(nonzero=True, channel_wise=True),
        RandScaleIntensity(factors=0.1, prob=1.0),
        RandShiftIntensity(offsets=0.1, prob=1.0),
    ]
)

In [None]:
class BrainMRIDataset(Dataset):
    """Map-style dataset over the ``training`` entries of Task01_BrainTumour.

    Each item is ``{"image": ..., "mask": ...}``:
      * ``image`` is ``transform`` applied to the case's file path, or the path
        itself when no transform is given;
      * ``mask`` is row ``idx`` of ``seq_mask`` as a bool tensor of length 4,
        where each entry is True with probability 0.2. Masks are drawn once,
        with seed 0, so they are the same every time the dataset is built.

    Args:
        root_dir: directory that contains the ``Task01_BrainTumour`` folder
            (with its ``dataset.json``).
        transform: optional callable applied to the image file path.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = os.path.join(root_dir, "Task01_BrainTumour")
        json_file_path = os.path.join(self.root_dir, "dataset.json")
        with open(json_file_path, 'r', encoding="utf-8") as file:
            data_json = json.load(file)

        # List of dicts; only each entry's 'image' path (relative to
        # self.root_dir) is used.
        self.image_filenames = data_json['training']

        # A private RandomState(0) yields exactly the same values as
        # np.random.seed(0) followed by np.random.rand(), without resetting
        # the global NumPy RNG for the rest of the program.
        rng = np.random.RandomState(0)
        self.seq_mask = rng.rand(len(self.image_filenames), 4) < 0.2

        self.transform = transform

    def __len__(self):
        """Returns the total number of samples in the dataset."""
        return len(self.image_filenames)

    def __getitem__(self, idx):
        """Return the (optionally transformed) image and the mask row for case ``idx``."""
        img_name = os.path.normpath(os.path.join(self.root_dir, self.image_filenames[idx]['image']))
        # Without a transform, hand back the file path rather than failing
        # with an unbound ``image`` variable.
        image = self.transform(img_name) if self.transform is not None else img_name
        mask = torch.from_numpy(self.seq_mask[idx])
        return {"image": image, "mask": mask}
    


# Dataset used below to pick a case for visualisation (uses train_transform).
sample_ds = BrainMRIDataset(root_dir, transform=train_transform)

In [None]:
# Load the best checkpoint and visualise the model output for one case in
# which input channel 0 has been zeroed out.
SLICE_IDX = 70  # index along the last spatial axis (the random crop depth is 144)


def _show_slices(volume, fig_name, title_fmt, channels):
    """Plot slice SLICE_IDX of the given channels of ``volume[0]`` side by side."""
    plt.figure(fig_name, (24, 6))
    for col, channel in enumerate(channels):
        plt.subplot(1, len(channels), col + 1)
        plt.title(title_fmt.format(col))
        # .float(): tensors produced under autocast may be half precision
        plt.imshow(volume[0, channel, :, :, SLICE_IDX].detach().float().cpu(), cmap="gray")


# map_location lets a checkpoint saved on another device load onto `device`.
model.load_state_dict(torch.load(os.path.join(save_dir, "best_metric_model.pth"), map_location=device))
model.eval()
with torch.no_grad():
    # select one image to evaluate and visualize the model output
    val_input = sample_ds[6]["image"].unsqueeze(0).to(device)
    # boolean selection over the 4 input channels: zero out channel 0 only
    mask_indices = [True, False, False, False]
    val_input[:, mask_indices, ...] = 0
    val_output = inference(val_input)

    _show_slices(val_input, "image", "image channel {}", range(4))
    plt.show()
    # visualize the 8 output channels: 0-3 are shown as the estimated means,
    # 4-7 as the estimated std devs
    _show_slices(val_output, "output mean", "output channel {}", range(4))
    _show_slices(val_output, "output std", "output channel {}", range(4, 8))
    plt.show()