## Analyse the brain segmentation results

In [1]:
# --- Analysis configuration -------------------------------------------------
RUN_ID = 810  # training run to analyse; read from ROOT_DIR/run_<RUN_ID, 3 digits>
RANDOM_SEED = 0  # forwarded to BraTSDataset as its seed
ROOT_DIR = "/scratch1/sachinsa/brats_seg"

# MASK_CODE indexes all_subsets([0, 1, 2, 3]) to choose the masked input
# channels (0 is the empty subset, i.e. nothing masked).
T1GD_SYNTH = True
MASK_CODE = 0
if T1GD_SYNTH:
    # With T1GD_SYNTH set, the mask code is pinned to 0 and the model is fed
    # every input channel.
    MASK_CODE = 0

In [None]:
import os
import matplotlib.pyplot as plt

import pdb
import numpy as np
import pickle
from utils.logger import Logger
from utils.plot import *
from itertools import chain, combinations

logger = Logger(log_level="DEBUG")  # verbose project logger for this notebook

In [3]:
# Training artefacts are read from the run folder; figures go to ../figs/run_NNN.
load_dir, fig_save_dir = (
    os.path.join(ROOT_DIR, f"run_{RUN_ID:03d}"),
    os.path.join("..", "figs", f"run_{RUN_ID:03d}"),
)
os.makedirs(fig_save_dir, exist_ok=True)

In [None]:
# Curves recorded during training. pickle can execute arbitrary code while
# loading, so only point this at files you trust.
with open(os.path.join(load_dir, 'training_info.pkl'), 'rb') as f:
    training_info = pickle.load(f)

(
    epoch_loss_values,
    metric_values,
    metric_values_tc,
    metric_values_wt,
    metric_values_et,
) = (
    training_info[key]
    for key in (
        'epoch_loss_values',
        'metric_values',
        'metric_values_tc',
        'metric_values_wt',
        'metric_values_et',
    )
)

In [None]:
max_epochs = len(epoch_loss_values)
# Ratio of per-epoch loss entries to metric entries; assumes one metric entry
# was stored per validation pass.
val_interval = max_epochs // len(metric_values)
logger.info(f"Total epochs: {max_epochs}")

### Plot the training loss and validation metrics

In [None]:
# Loss curve plus overall / TC / WT / ET validation curves.
plot_training_tumor_seg(
    epoch_loss_values,
    metric_values,
    metric_values_tc,
    metric_values_wt,
    metric_values_et,
    val_interval,
)

In [None]:
# Report the per-region scores from the validation point with the highest
# overall metric.
arg_max = np.argmax(metric_values)
metric = np.max(metric_values)
metric_tc, metric_wt, metric_et = (
    per_region[arg_max]
    for per_region in (metric_values_tc, metric_values_wt, metric_values_et)
)

print(f"Epochs  Total	TC	WT	ET")
print(f"{len(epoch_loss_values)}	{100*metric:.1f}	{100*metric_tc:.1f}	{100*metric_wt:.1f}	{100*metric_et:.1f}")

fig, axs = plt.subplots(1, 4, figsize=(10, 3), gridspec_kw={'wspace': 0, 'hspace': 0})
donut_specs = (
    (metric, "Total", "green"),
    (metric_tc, "TC", "blue"),
    (metric_wt, "WT", "brown"),
    (metric_et, "ET", "purple"),
)
for donut_ax, (donut_value, donut_title, donut_color) in zip(axs, donut_specs):
    plot_donut(donut_value, donut_title, donut_color, donut_ax)
plt.show()


## Run inference with the trained model

In [63]:
import torch
from monai.transforms import (
    Compose,
)
from monai.config import print_config
from monai.handlers.utils import from_engine
from monai.losses import DiceLoss
from monai.inferers import sliding_window_inference
from monai.metrics import DiceMetric
from monai.networks.nets import SegResNet
from monai.transforms import (
    Activations,
    AsDiscrete,
    Compose,
)
from monai.utils import set_determinism

from tqdm import tqdm
from utils.transforms import tumor_seg_transform_default as data_transform
from utils.model import create_SegResNet, inference
from utils.dataset import BraTSDataset

In [64]:
def all_subsets(arr):
    """Return every subset of ``arr`` except ``arr`` itself, each as a list.

    Subsets are grouped by size, smallest first, starting with the empty list.
    Within one size they appear in ``itertools.combinations`` order. The full
    set is not included because sizes run only from 0 to ``len(arr) - 1``.
    Presumably this is because masking every channel would leave no input.
    """
    return [
        list(combo)
        for size in range(len(arr))
        for combo in combinations(arr, size)
    ]

channels = ["FLAIR", "T1w", "T1Gd", "T2w"]
label_list = ["TC", "WT", "ET"]

# Channels hidden from the model for this MASK_CODE, and the ones it sees.
mask_indices = all_subsets([0, 1, 2, 3])[MASK_CODE]
show_indices = [idx for idx in range(4) if idx not in mask_indices]

In [65]:
# Use the first GPU when one is present; otherwise fall back to the CPU so the
# notebook still runs (slowly) instead of failing on a machine without CUDA.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# The model takes one input channel per unmasked contrast.
in_channels = len(show_indices)
model = create_SegResNet(in_channels, device)

In [66]:
# Dice metrics that include the background channel: one fully averaged
# ("mean"), one with "mean_batch" reduction (presumably per-channel scores
# averaged over the batch).
dice_metric, dice_metric_batch = (
    DiceMetric(include_background=True, reduction=mode)
    for mode in ("mean", "mean_batch")
)

# Turn raw network output into binary masks: sigmoid, then threshold at 0.5.
post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])

In [67]:
# NOTE(review): section="all" is used for the "val" dataset. If that includes
# training cases, the visualised subject may be one the model was trained on;
# confirm this is intended.
val_dataset = BraTSDataset(
    version="2017",
    section="all",
    seed=RANDOM_SEED,
    transform=data_transform["val"],
)

In [68]:
# Restore the weights stored in best_checkpoint.pth. map_location places the
# tensors on the current device, so a checkpoint written on a GPU still loads
# on another device (e.g. a CPU-only machine). weights_only=True refuses
# arbitrary pickled objects.
checkpoint = torch.load(
    os.path.join(load_dir, 'best_checkpoint.pth'),
    map_location=device,
    weights_only=True,
)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval();  # trailing ";" stops Jupyter from echoing the model repr

In [69]:
# Cursor into the sorted dataset ids; the subject-selection cell advances it
# by one on every execution.
this_index = 0

In [None]:
# Subject to visualise. A fixed id shows the same case on every run. Set
# FIXED_SUBJECT_ID = None to step through the dataset in sorted-id order
# instead, one subject per execution of this cell.
# NOTE(review): the stepping branch uses val_dataset.get_ids(), which was only
# referenced from commented-out code before; confirm it exists.
FIXED_SUBJECT_ID = 449
if FIXED_SUBJECT_ID is not None:
    id_ = FIXED_SUBJECT_ID
else:
    id_ = np.sort(val_dataset.get_ids())[this_index]
this_index += 1
print(id_)

with torch.no_grad():
    this_input = val_dataset.get_with_id(id_)
    # Add a leading batch axis and move the image to the model's device.
    input_image = this_input["image"].unsqueeze(0).to(device)
    if not T1GD_SYNTH:
        # Keep only the unmasked channels the model was built with.
        input_image = input_image[:, show_indices, ...]
    input_label = this_input["label"]
    this_output = inference(input_image, model)
    # Take the single batch item, then sigmoid + 0.5 threshold via post_trans.
    this_output = post_trans(this_output[0])

In [None]:
# Centroid of label channel 0 (TC, tumour core); its last coordinate selects
# the slice drawn in the figure below.
label_centroid = find_centroid_3d(input_label[0])
print(label_centroid)

In [None]:
# Spatial extents of the batched input (dims 2, 3 and 4 of the 5-D tensor).
im_length, im_width, im_height = input_image.shape[2:]
# Slice index along the last spatial axis, taken from the TC centroid.
h_index = label_centroid[-1]

# Panel titles: three label channels plus a fourth "Combined" panel.
label_list = ["TC", "WT", "ET", "Combined"]
channels = ["FLAIR", "T1w", "T1gd", "T2w"]

def plot_label(index, label):
    """Draw one label panel into the current 3x4 subplot grid.

    Args:
        index: 0-2 shows that single label channel; 3 (or higher) shows the
            sum over all label channels. The panel title is ``label_list[index]``.
        label: ``"ground_truth"`` draws ``input_label`` in the second row;
            ``"prediction"`` draws ``this_output`` in the third row. Row
            placement assumes ``len(channels)`` equals the grid width of 4.

    Raises:
        ValueError: if ``label`` is not one of the two accepted values.

    Reads the notebook globals ``input_label``, ``this_output``, ``h_index``,
    ``channels`` and ``label_list``.
    """
    if label == "ground_truth":
        start_index = 1 * len(channels)
        brain_slice = input_label
        title = "Ground Truth"
    elif label == "prediction":
        start_index = 2 * len(channels)
        brain_slice = this_output
        title = "Prediction"
    else:
        # Fail with a clear message instead of an UnboundLocalError below.
        raise ValueError(
            f"label must be 'ground_truth' or 'prediction', got {label!r}"
        )

    brain_slice = brain_slice[..., h_index].detach().cpu()
    if index < 3:
        brain_slice = brain_slice[index, ...]
    else:
        brain_slice = brain_slice.sum(axis=0)
    brain_slice = brain_slice.T
    # Size the ticks from the array actually displayed. imshow maps rows to the
    # y axis and columns to the x axis, and the transpose swapped the two
    # spatial dims, so global extents would be mislabelled on non-square slices.
    n_rows, n_cols = brain_slice.shape
    plt.subplot(3, 4, start_index + index + 1)
    plt.title(label_list[index], fontsize=30)
    if index == 0:
        plt.ylabel(title, fontsize=30)
    plt.xticks([0, n_cols - 1], [0, n_cols - 1], fontsize=15)
    plt.yticks([0, n_rows - 1], [0, n_rows - 1], fontsize=15)
    cmap = "gray" if index < 3 else "magma"
    plt.imshow(brain_slice, cmap=cmap)


plt.figure("image", (24, 18))

# Row 1 of the 3x4 grid: one panel per channel in show_indices, at slice
# h_index of the last spatial axis.
for col, channel_idx in enumerate(show_indices):
    brain_slice = input_image[0, col, :, :, h_index].detach().cpu().T
    # Size the ticks from the displayed array. After the transpose, rows run
    # along the y axis and columns along the x axis (the previous global-extent
    # ticks were swapped for non-square slices).
    n_rows, n_cols = brain_slice.shape
    plt.subplot(3, 4, col + 1)
    plt.title(channels[channel_idx], fontsize=30)
    if col == 0:
        plt.ylabel("Input", fontsize=30)
    plt.xticks([0, n_cols - 1], [0, n_cols - 1], fontsize=15)
    plt.yticks([0, n_rows - 1], [0, n_rows - 1], fontsize=15)
    plt.imshow(brain_slice, cmap="gray", vmin=-3, vmax=4)
    plt.colorbar(shrink=0.7)

# Rows 2 and 3: ground-truth and predicted label panels.
for i in range(len(label_list)):
    plot_label(i, "ground_truth")
for i in range(len(label_list)):
    plot_label(i, "prediction")
plt.suptitle(f"BRATS_{this_input['id']} (h={h_index}/{im_height})", y=0.9, fontsize=20)
plt.show()