In [13]:
import os
import time
import matplotlib.pyplot as plt
import numpy as np
import pdb
import pandas as pd
import pickle

from monai.data import DataLoader, decollate_batch
from monai.utils import set_determinism
from tqdm import tqdm

import torch
from torch.utils.data import Subset

from utils.dataset import BraTSDataset
from utils.model import create_SegResNet, inference
from utils.transforms import contr_syn_transform_scale as data_transform
from utils.plot import *

In [None]:
# Notebook-wide logger, constructed with log_level='DEBUG'; later cells call
# logger.debug(...) on it.
from utils.logger import Logger
logger = Logger(log_level='DEBUG')

In [15]:
# One seed for the whole notebook: passed to MONAI's set_determinism here and
# to every BraTSDataset constructor in the data-loading cell.
RANDOM_SEED = 0
set_determinism(seed=RANDOM_SEED)

In [None]:
# Constructor arguments shared by every dataset built in this cell.
_split_kwargs = dict(version='2017', section='validation', seed=RANDOM_SEED)
# Loader settings shared by both loaders: one case per batch, fixed order.
_loader_kwargs = dict(batch_size=1, shuffle=False, num_workers=8)

# Original scans, using the 'val' transform.
dataset_orig = BraTSDataset(transform=data_transform['val'], **_split_kwargs)
loader_orig = DataLoader(dataset_orig, **_loader_kwargs)

# Synthetic variant (synth=True), using the 'basic' transform.
dataset_synth = BraTSDataset(
    synth=True,
    transform=data_transform['basic'],
    **_split_kwargs,
)
loader_synth = DataLoader(dataset_synth, **_loader_kwargs)

# Dataset where t1gd is an avg of all t1gd: run_40.
# No loader is built for it; later cells fetch cases from it by ID.
dataset_t1gd_mean = BraTSDataset(
    synth=True,
    processed_path='/scratch1/sachinsa/data/contr_generated/run_40',
    transform=data_transform['basic'],
    **_split_kwargs,
)

logger.debug("Data loaded")
logger.debug(f"Length of dataset: {len(dataset_orig)}, {len(dataset_synth)}")

In [17]:
# Case IDs as reported by the original dataset, plus a cursor into that list.
# Re-running this cell rewinds the cursor to the first case.
index_list = dataset_orig.get_ids()
id_index = 0

### Run from here: each run of the next cell selects the next case ID

In [18]:
# Pick the case under the cursor and advance the cursor in one step, so every
# re-run of this cell moves on to the next ID (IndexError once the list is exhausted).
id_, id_index = index_list[id_index], id_index + 1

Find the slice (height index) that passes through the centroid of the Tumor Core (TC) label; this slice is used for all plots below.

In [None]:
# Fetch the current case once and keep its label.
_case_orig = dataset_orig.get_with_id(id_)
label_orig = _case_orig['label']

# Centroid of label channel 0 -- TC (Tumor Core); its last coordinate is used
# as the slice index for the plots that follow.
label_centroid = find_centroid_3d(label_orig[0])
h_index = label_centroid[-1]
print(f"h_index: {h_index}")


In [None]:
# Display options common to the three figures in this cell.
_plot_kwargs = dict(h_index=h_index, horiz=True, no_batch=True)
# Channel 2 (labelled T1Gd below), taken as a length-1 slice so the channel
# axis is kept.
_t1gd_only = slice(2, 3)

print(f"ID: {id_}")
print("Original")

# All four channels of the original case.
image_orig = dataset_orig.get_with_id(id_)['image']
plot_brainmri(image_orig, channels=["FLAIR", "T1w", "T1Gd", "T2w"], **_plot_kwargs)

# Channel 2 of the synthetic dataset for the same case.
_case_synth = dataset_synth.get_with_id(id_)
image_synth = _case_synth['image'][_t1gd_only]
plot_brainmri(image_synth, channels=["T1Gd-synth"], **_plot_kwargs)

# Channel 2 of the averaged-T1Gd dataset for the same case.
_case_mean = dataset_t1gd_mean.get_with_id(id_)
image_mean = _case_mean['image'][_t1gd_only]
plot_brainmri(image_mean, channels=["T1Gd-mean"], **_plot_kwargs)

In [None]:
# Intensity summary of the original image for the current case, printed as
# "mean ± std [min, max]".
brain_img = image_orig.detach().cpu()
_mean = brain_img.mean().item()
_std = brain_img.std().item()
_lo = brain_img.min().item()
_hi = brain_img.max().item()
print(f"{_mean:.3f} ± {_std:.3f} [{_lo:.3f}, {_hi:.3f}]")

In [None]:
# Plot the label of the current case at slice index h_index (the same index
# passed to the image plots).
plot_label(label_orig, h_index=h_index)

In [23]:
from monai.metrics import MSEMetric
from utils.loss import mse_loss

# Use the first GPU when one is present; otherwise fall back to the CPU so the
# evaluation cells still run on a machine without CUDA instead of failing on
# the first `.to(device)` call.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# MONAI MSE metric with mean reduction.
mse_metric = MSEMetric(reduction="mean")

In [None]:
# Number of cases to average over (the original cut-off `run_idx > 5` also
# processed exactly six cases).
MAX_CASES = 6

# Accumulated MSE against the real T1Gd channel, one slot per candidate image.
# Slot order matches the labels printed in the next cell:
#   0: synthetic T1Gd, 1: all-zero image, 2: real T1w, 3: mean T1Gd.
mse_loss_combined = np.zeros(4)
run_idx = 0  # number of cases actually processed

with torch.no_grad():
    # The two loaders are unshuffled, so they are expected to yield the same
    # case at each step; the ID check below enforces that.
    for orig_data, synth_data in zip(loader_orig, loader_synth):
        if run_idx >= MAX_CASES:
            break

        orig_img = orig_data["image"][0].to(device)
        orig_id = orig_data["id"][0].item()
        synth_img = synth_data["image"][0].to(device)
        synth_id = synth_data["id"][0].item()
        assert orig_id == synth_id, (
            f"loader mismatch: original id {orig_id} vs synthetic id {synth_id}"
        )

        # Fetch the mean-T1Gd image for *this* case rather than reusing the
        # `image_mean` global left behind by the plotting cell: that global
        # belongs to whichever case was last plotted and only exists if that
        # cell has been run.
        # NOTE(review): assumes get_with_id accepts the integer ID taken from
        # the batch -- confirm it matches the type returned by get_ids().
        mean_t1gd = dataset_t1gd_mean.get_with_id(orig_id)["image"][2].to(device)

        target_t1gd = orig_img[2]  # real T1Gd channel: the reference
        synth_img_combined = torch.stack((
            synth_img[2],
            torch.zeros_like(target_t1gd),
            orig_img[1],
            mean_t1gd,
        ))

        for i in range(4):
            # float() turns a 0-d tensor (CPU or CUDA) or a plain number into
            # a Python float, so the NumPy accumulator never receives a
            # device tensor.
            mse_loss_combined[i] += float(mse_loss(synth_img_combined[i], target_t1gd))
        run_idx += 1

# Average over the processed cases; an empty loader would otherwise yield 0/0.
if run_idx == 0:
    raise RuntimeError("No cases were processed; cannot average the MSE values.")
mse_loss_combined /= run_idx

In [None]:
# Report the averaged MSE of each candidate image, in accumulator slot order.
_candidate_labels = ("synth", "zero", "T1w", "T1Gd-mean")
for label, _avg_mse in zip(_candidate_labels, mse_loss_combined):
    print(f"MSE {label}: {_avg_mse:.3f}")
