## Analysis of synthesis and uncertainty

### Set parameters

In [None]:
# Run identifier: selects the `run_<RUN_ID>` folders used below for loading
# results and saving figures. Use a distinct value per run to prevent
# overlapped saving of model and data.
RUN_ID = 24 # set this to prevent overlapped saving of model and data
# Seed handed to the validation dataset.
RANDOM_SEED = 0
# Root folder that holds one `run_<RUN_ID>` sub-folder per training run.
ROOT_DIR = "/scratch1/sachinsa/cont_syn"
# Selects the number of model output channels (12 when True, 8 when False).
QR_REGRESSION = True

In [None]:
import os
import matplotlib.pyplot as plt

import pdb
import numpy as np
import pickle
from utils.logger import Logger

# Notebook-wide logger from the project's utils package, created at DEBUG level.
logger = Logger(log_level='DEBUG')

In [None]:
# Both the results folder and the figure folder are keyed by the run id.
run_name = f"run_{RUN_ID}"

# Where the training artefacts of this run are read from.
load_dir = os.path.join(ROOT_DIR, run_name)

# Where figures produced by this notebook are written; created on demand.
fig_save_dir = os.path.join("..", "figs", "runs", run_name)
os.makedirs(fig_save_dir, exist_ok=True)

In [None]:
# Read the pickled training record of this run; only the file read happens
# while the handle is open.
training_info_path = os.path.join(load_dir, 'training_info.pkl')
with open(training_info_path, 'rb') as f:
    training_info = pickle.load(f)

# Per-epoch training losses and the recorded validation metric values.
epoch_loss_values = training_info['epoch_loss_values']
metric_values = training_info['metric_values']

# Echo the notebook parameters so the output records which run was analysed.
logger.info("PARAMETERS\n-----------------")
logger.info(f"RUN_ID: {RUN_ID}")
logger.info(f"QR_REGRESSION: {QR_REGRESSION}")
logger.info(f"ROOT_DIR: {ROOT_DIR}")
print()

In [None]:
# The run length is taken as the number of recorded epoch losses.
max_epochs = len(epoch_loss_values)

# Spacing (in epochs) between validation scores, derived from how many scores
# were recorded. A run that stopped before its first validation has no scores;
# guard that case instead of raising ZeroDivisionError. The explicit length is
# used (not truthiness) so the check also works if the values are an array.
n_validations = len(metric_values)
val_interval = max_epochs // n_validations if n_validations else max_epochs
print("Max epochs:", max_epochs)

In [None]:
# True: three-panel figure (full loss curve plus two zoomed views).
# False: a single loss curve on a logarithmic y-axis.
MULTI_TRAINING_FIGURE = True

# 1-based epoch numbers for the x-axis, and the loss recorded for each epoch.
x = list(range(1, len(epoch_loss_values) + 1))
y = epoch_loss_values

if MULTI_TRAINING_FIGURE:
    plt.figure("train", (18, 4))

    # Panel 1: the complete loss curve on a linear axis
    # (add plt.yscale('log') here for a log axis).
    plt.subplot(1, 3, 1)
    plt.xlabel("epoch", fontsize=15)
    plt.ylabel("loss", fontsize=15)
    plt.plot(x, y, color="red")
    plt.suptitle("Training: Loss", fontsize=20)

    # Panels 2 and 3: the same curve with the first `zoom` epochs dropped.
    # Each zoom level keeps its own slot; a slot stays blank when the run is
    # not longer than `zoom` epochs.
    for k, zoom in enumerate([20, 100], start=2):
        if len(x) <= zoom:
            continue
        plt.subplot(1, 3, k)
        plt.ylabel("loss", fontsize=15)
        plt.xlabel(f"epoch (from ep. {zoom})", fontsize=15)
        plt.plot(x[zoom:], y[zoom:], color="red")
else:
    plt.figure("train", (6, 4))
    plt.xlabel("epoch")
    plt.ylabel("loss - log")
    plt.yscale('log')
    plt.plot(x, y, color="red")
    plt.title("Training: Gaussian Log Likelihood Loss", fontsize=25)

# Both layouts are written to the same file and then displayed.
plt.savefig(os.path.join(fig_save_dir, "train_plot.png"), facecolor='white')
plt.show()

# x-positions of the validation scores: the k-th score is placed at epoch
# k * val_interval.
x_val = [val_interval * step for step in range(1, len(metric_values) + 1)]
y_val = metric_values

plt.figure("val", (6, 4))
plt.title("Validation: 1-MSE", fontsize=20)
plt.xlabel("epoch", fontsize=15)
plt.tight_layout()
plt.plot(x_val, y_val, color="green")
# The figure is saved before it is shown.
plt.savefig(os.path.join(fig_save_dir, "val_plot.png"), facecolor='white')
plt.show()

## Inference on the model

In [None]:
import torch

from utils.model import create_UNet3D, inference
from utils.transforms import contr_syn_transform_2 as data_transform
from utils.dataset import BraTSDataset

In [None]:
# Use the first GPU when one is available; otherwise fall back to the CPU so
# the notebook can still run (slowly) on a machine without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Number of output channels requested from the network: 12 when QR_REGRESSION
# is enabled, 8 otherwise.
out_channels = 12 if QR_REGRESSION else 8
model = create_UNet3D(out_channels, device)

In [None]:
# Validation section of the 2017 dataset version, using the 'val' entry of the
# imported transform and the notebook's random seed.
val_dataset = BraTSDataset(
    version="2017",
    section="validation",
    seed=RANDOM_SEED,
    transform=data_transform["val"],
)

In [None]:
# Load the best checkpoint of this run. `map_location=device` remaps the
# stored tensors onto the device the model lives on, so loading does not
# depend on the device the checkpoint was saved from. `weights_only=True`
# restricts unpickling to tensors and plain containers.
checkpoint = torch.load(
    os.path.join(load_dir, 'best_checkpoint.pth'),
    map_location=device,
    weights_only=True,
)
model.load_state_dict(checkpoint['model_state_dict'])
# Switch to evaluation mode before running inference.
model.eval()
# Names of the four image channels (presumably in the dataset's channel
# order — confirm against the dataset/transform) and of the label classes.
channels = ["FLAIR", "T1w", "T1Gd", "T2w"]
label_list = ["TC", "WT", "ET"]
# One flag per channel: channels flagged True are zeroed before inference
# (see the masking step below). Alternative mask kept for quick switching:
# input_mask = [True, False, True, False]
input_mask = [True, False, False, True]

# interesting ids with tumor: 328, 448

# No gradients are needed for a single forward pass.
with torch.no_grad():
    # Alternative: pick a random validation sample instead of a fixed id.
    # this_input = val_dataset.get_random()
    this_input = val_dataset.get_with_id(328)
    # Add a leading batch dimension and move the image to the model's device.
    input_image = this_input["image"].unsqueeze(0).to(device)
    input_label = this_input["label"]
    # Mask a copy so the original image stays intact for the plots.
    input_image_copy = input_image.clone()
    # Zero every channel (dimension 1) whose input_mask entry is True.
    input_image_copy[:, input_mask, ...] = 0
    this_output = inference(input_image_copy, model)

In [None]:
from utils.plot import BrainPlot

In [None]:
plt.figure("image", (24, 18))
b_plot = BrainPlot(input_image, this_output, input_mask)

# Request each panel type four times in a row, in this order:
# "input", then "q0", then "diff".
for panel in ("input", "q0", "diff"):
    for _ in range(4):
        b_plot.plot(panel)