## Analyse synthesis results and uncertainty estimates

### Set parameters

In [None]:
# Identifier of the training run to analyse: artifacts are read from the
# run_<RUN_ID> folder under ROOT_DIR and figures go to a matching folder.
RUN_ID: int = 22
# Seed handed to the validation dataset so the same split is reproduced.
RANDOM_SEED: int = 0
# Root folder holding one sub-folder per training run.
ROOT_DIR: str = "/scratch1/sachinsa/cont_syn"
# True for quantile-regression runs (12 output channels), False for the
# mean/spread model (8 output channels); also selects which rows are plotted.
QR_REGRESSION: bool = True

In [None]:
import os
import matplotlib.pyplot as plt

import pdb
import numpy as np
import pickle
from utils.logger import Logger

# Notebook-wide logger, configured at the DEBUG level.
logger = Logger(log_level="DEBUG")

In [None]:
# Both the run artifacts (under ROOT_DIR) and the figures (relative to the
# current working directory) live in a folder named after the run.
run_folder = f"run_{RUN_ID}"
load_dir = os.path.join(ROOT_DIR, run_folder)
fig_save_dir = os.path.join("..", "figs", run_folder)
os.makedirs(fig_save_dir, exist_ok=True)

In [None]:
# Loss/metric history written by the training run.
# NOTE: pickle can execute arbitrary code on load; only open files produced
# by your own runs.
with open(os.path.join(load_dir, 'training_info.pkl'), 'rb') as info_file:
    training_info = pickle.load(info_file)
epoch_loss_values = training_info['epoch_loss_values']
metric_values = training_info['metric_values']

# Echo the analysis parameters so the log records which run was inspected.
logger.info("PARAMETERS\n-----------------")
logger.info(f"RUN_ID: {RUN_ID}")
logger.info(f"QR_REGRESSION: {QR_REGRESSION}")
logger.info(f"ROOT_DIR: {ROOT_DIR}")
print("")

In [None]:
max_epochs = len(epoch_loss_values)
# Infer how often validation ran from how many metric values were recorded
# (assumes validation every `val_interval` epochs — see the x-axis of the
# validation plot). A run with no recorded validation would otherwise
# raise ZeroDivisionError here; its validation plot is simply empty.
num_val_points = len(metric_values)
val_interval = max_epochs // num_val_points if num_val_points else 0
print("Max epochs:", max_epochs)

In [None]:
# False: one log-scaled loss panel. True: the full loss curve plus zoomed-in
# panels that skip the first 20 / 100 epochs.
MULTI_TRAINING_FIGURE = True

# x axis uses 1-based epoch numbers.
# NOTE(review): the titles call the loss "Gaussian Log Likelihood" (GLL); if
# QR_REGRESSION runs are trained with a different (quantile) loss these
# titles are misleading — confirm against the training script.
x = [i + 1 for i in range(len(epoch_loss_values))]
y = epoch_loss_values

if not MULTI_TRAINING_FIGURE:
    plt.figure("train", (6, 4))
    plt.xlabel("epoch")
    plt.ylabel("loss - log")
    plt.yscale('log')
    plt.plot(x, y, color="red")
    plt.title("Training: Gaussian Log Likelihood Loss", fontsize=25)
    plt.savefig(os.path.join(fig_save_dir, "train_plot.png"), facecolor='white')
    plt.show()
else:
    plt.figure("train", (18, 4))
    plt.subplot(1, 3, 1)
    plt.xlabel("epoch", fontsize=15)
    plt.ylabel("loss", fontsize=15)  # add plt.yscale('log') for a log-scaled axis
    plt.plot(x, y, color="red")
    plt.suptitle("Training: GLL Loss", fontsize=20)

    # Panels 2 and 3 drop the first `zoom` epochs so the flatter tail of the
    # curve is readable; a panel stays empty when the run is too short.
    for k, zoom in enumerate([20, 100], start=2):
        if len(x) > zoom:
            plt.subplot(1, 3, k)
            plt.ylabel("loss", fontsize=15)
            plt.xlabel(f"epoch (from ep. {zoom})", fontsize=15)
            plt.plot(x[zoom:], y[zoom:], color="red")
    plt.savefig(os.path.join(fig_save_dir, "train_plot.png"), facecolor='white')
    plt.show()

# Validation metric, placed at epochs val_interval, 2*val_interval, ...
plt.figure("val", (6, 4))
plt.title("Validation: 1-MSE", fontsize=20)
x_val = [val_interval * (i + 1) for i in range(len(metric_values))]
y_val = metric_values
plt.xlabel("epoch", fontsize=15)
plt.plot(x_val, y_val, color="green")
# Lay out after plotting so the final tick labels are taken into account.
plt.tight_layout()
plt.savefig(os.path.join(fig_save_dir, "val_plot.png"), facecolor='white')
plt.show()

## Run inference with the trained model

In [None]:
import torch

from utils.model import create_UNet3D, inference
from utils.transforms import contr_syn_transform_3
from utils.dataset import BraTSDataset

In [None]:
# Prefer the first GPU but fall back to CPU so the notebook also runs on a
# machine without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# 12 output channels in quantile-regression mode, 8 otherwise. Presumably
# 3 maps (q0, qL, qH) x 4 contrasts vs. 2 maps (mean, spread) x 4 contrasts,
# judging by how plot_brain slices the output — confirm against the model.
out_channels = 12 if QR_REGRESSION else 8
model = create_UNet3D(out_channels, device)

In [None]:
# BraTS 2017 validation section, seeded with RANDOM_SEED and preprocessed
# with the "val" entry of contr_syn_transform_3.
val_dataset = BraTSDataset(
    version="2017",
    section="validation",
    seed=RANDOM_SEED,
    transform=contr_syn_transform_3["val"],
)

In [None]:
# Restore the best weights saved by the training run. map_location lets a
# checkpoint written on a GPU be loaded onto whichever `device` is in use.
checkpoint = torch.load(
    os.path.join(load_dir, 'best_checkpoint.pth'),
    map_location=device,
    weights_only=True,
)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()

# Contrast order of the image channels.
channels = ["FLAIR", "T1w", "T1Gd", "T2w"]
# True marks a contrast that is zeroed in the copy fed to the model
# (presumably the contrasts the model has to synthesise); plot_brain tags
# these columns with "(X)".
input_mask = [True, False, True, False]

with torch.no_grad():
    # One validation case, with a leading batch dimension added.
    this_input = val_dataset.get_random()
    input_image = this_input["image"].unsqueeze(0).to(device)
    # Keep the full input (every contrast) for comparison and run the model
    # on a copy whose masked contrasts are zeroed.
    input_image_copy = input_image.clone()
    input_image_copy[:, input_mask, ...] = 0
    this_output = inference(input_image_copy, model)

In [None]:
# Spatial extent of the (batch, channel, length, width, height) input; the
# plotted slice sits at the midpoint of the last spatial axis.
im_length, im_width, im_height = input_image.shape[2:]
h_index = im_height // 2
# Preset display ranges for plot_brain's clamped mode, keyed by contrast and
# then by plot row. Each value is [vmin, vmax]; the pairs used to be stored
# high-first while imshow received element 0 as vmin, which inverted the
# range. Rows without an entry here ("mean", "var") have no preset.
extremes = {  # contrast -> row label -> [vmin, vmax]
    "FLAIR": {
        "input": [-3, 3],
        "q0": [-3, 3],
        "q1": [0, 2],
        "q2": [-3, 3],
        "q3": [0, 1],
    },
    "T1w": {
        "input": [-3, 3],
        "q0": [-3, 3],
        "q1": [0, 3],
        "q2": [-5, 5],
        "q3": [0, 1],
    },
    "T1Gd": {
        "input": [-3, 3],
        "q0": [-3, 3],
        "q1": [0, 2],
        "q2": [-4, 4],
        "q3": [0, 1],
    },
    "T2w": {
        "input": [-3, 4],
        "q0": [-3, 4],
        "q1": [0, 3],
        "q2": [-4, 4],
        "q3": [0, 1],
    },
}

def _select_brain_slice(index, label):
    """Pick the 2-D slice, grid offset and row title for one panel.

    Slices are taken at ``h_index`` on the last spatial axis of the global
    ``input_image`` (full, unmasked input) and ``this_output`` (model output).

    Args:
        index: Column (contrast) index into ``channels``.
        label: Row kind; one of "input", "mean", "var", "q0", "q1", "q2", "q3".

    Returns:
        ``(start_index, row_title, brain_slice)``: the subplot offset of the
        row (row number times the number of contrasts), the y-label shown on
        the first column, and the tensor to display.

    Raises:
        ValueError: If ``label`` is not a supported row kind.
    """
    nc = len(channels)
    input_sub = input_image[0, :, :, :, h_index]
    output_sub = this_output[0, :, :, :, h_index]

    if label == "input":
        return 0 * nc, "Input", input_sub[index]
    if label == "mean":
        return 1 * nc, r"Output: $\mu$", output_sub[index]
    if label == "var":
        # NOTE(review): the row is titled sigma and the second block of output
        # channels is exponentiated; whether exp() yields sigma or a variance
        # depends on what the model predicts (log-sigma vs. log-variance).
        sigma = torch.exp(output_sub[index + nc])
        # Cap at the 95th percentile so a few extreme voxels don't dominate
        # the colour scale.
        var_threshold = torch.quantile(sigma.float(), 0.95).item()
        return 2 * nc, r"Output: $\sigma$", sigma.clamp(max=var_threshold)
    if label == "q0":
        return 1 * nc, "Output: q0", output_sub[index]
    if label == "q1":
        # Half-width of the predicted [qL, qH] interval (title matches the data).
        half_width = 0.5 * (output_sub[index + 2 * nc] - output_sub[index + nc])
        return 2 * nc, "Output: (qH-qL)/2", half_width
    if label == "q2":
        return 3 * nc, "Output: qH", output_sub[index + 2 * nc]
    if label == "q3":
        # True where the true intensity falls outside the predicted [qL, qH] band.
        below = input_sub[index] < output_sub[index + nc]
        above = input_sub[index] > output_sub[index + 2 * nc]
        return 4 * nc, "Outlier", torch.logical_or(below, above)
    raise ValueError(f"unknown plot_brain label: {label!r}")


def _display_range(brain_slice, channel, label, use_percentile):
    """Return ``(vmin, vmax)`` for ``imshow`` when clamping is enabled.

    Uses the preset in ``extremes`` unless ``use_percentile`` is set or no
    preset exists for ``(channel, label)``; otherwise falls back to the 1st
    and 99th percentiles of ``brain_slice`` (a CPU tensor).
    """
    preset = None if use_percentile else extremes.get(channel, {}).get(label)
    if preset is None:
        # float() so boolean (outlier) and half-precision slices work too.
        values = brain_slice.float().numpy()
        return np.percentile(values, 1), np.percentile(values, 99)
    # Order the pair explicitly so presets stored high-first still work.
    return min(preset), max(preset)


def plot_brain(index, label, *, clamp_vis=False, use_percentile=True):
    """Draw one panel of the inference grid into the current figure.

    The grid has one column per contrast in ``channels`` and 5 rows when
    ``QR_REGRESSION`` is set (3 otherwise). The "input" row also draws the
    column titles, and column 0 draws the row title.

    Args:
        index: Column (contrast) index into ``channels``.
        label: Row kind; one of "input", "mean", "var", "q0", "q1", "q2", "q3".
        clamp_vis: If True, pass an explicit ``vmin``/``vmax`` to ``imshow``
            instead of letting matplotlib auto-scale.
        use_percentile: With ``clamp_vis``, use the slice's 1st/99th
            percentiles instead of the presets in ``extremes``.

    Raises:
        ValueError: If ``label`` is not a supported row kind.
    """
    start_index, row_title, brain_slice = _select_brain_slice(index, label)
    num_rows = 5 if QR_REGRESSION else 3
    plt.subplot(num_rows, len(channels), start_index + index + 1)

    if label == "input":
        col_title = channels[index]
        if input_mask[index]:
            col_title = f"{col_title} (X)"
        plt.title(col_title, fontsize=30)

    # Transpose (length, width) -> (width, length) for display.
    brain_slice = brain_slice.detach().cpu().T

    if clamp_vis:
        vmin, vmax = _display_range(brain_slice, channels[index], label, use_percentile)
        plt.imshow(brain_slice, cmap="gray", vmin=vmin, vmax=vmax)
    else:
        plt.imshow(brain_slice, cmap="gray")

    plt.xlabel('')
    if index == 0:
        plt.ylabel(row_title, fontsize=30)

    # After the transpose, image columns (x axis) span im_length and image
    # rows (y axis) span im_width; label only the last pixel on each axis.
    plt.xticks([im_length - 1], [im_length], fontsize=15)
    plt.yticks([im_width - 1], [im_width], fontsize=15)
    cbar = plt.colorbar(shrink=0.7)
    cbar.ax.tick_params(labelsize=20)

# One row per output kind, one column per contrast. In QR mode plot_brain
# reserves 5 rows; "q2" (qH) and "q3" (outlier map) can be appended to the
# row list to fill the remaining two.
if QR_REGRESSION:
    row_labels = ["input", "q0", "q1"]
else:
    row_labels = ["input", "mean", "var"]

plt.figure("image", (24, 18))
for row_label in row_labels:
    for i in range(4):
        plot_brain(i, row_label)

volume_id = this_input['id']
plt.suptitle(f"BRATS_{volume_id} (h={h_index}/{im_height})", fontsize=20)
plt.tight_layout()
plt.savefig(os.path.join(fig_save_dir, "model_inference.png"), facecolor='white')
plt.show()