In [None]:
import os
import time
import matplotlib.pyplot as plt
import numpy as np
import pdb
import pandas as pd
import pickle

from monai.apps import DecathlonDataset
from monai.config import print_config
from monai.data import DataLoader, decollate_batch
from monai.handlers.utils import from_engine
from monai.losses import DiceLoss
from monai.inferers import sliding_window_inference
from monai.metrics import DiceMetric
from monai.networks.nets import SegResNet
from monai.transforms import (
    Activations,
    AsDiscrete,
    Compose,
)
from monai.metrics import MSEMetric
from monai.utils import set_determinism
from tqdm import tqdm

import torch
from torch.utils.data import Subset

from utils.dataset import BraTSDataset
from utils.model import create_SegResNet, inference

# print_config()

In [None]:
from utils.logger import Logger

# Notebook-wide logger at DEBUG level so the debug messages emitted below are shown.
logger = Logger(log_level="DEBUG")

In [None]:
# Single seed for the notebook: passed to MONAI's `set_determinism` and to every
# `BraTSDataset(seed=...)` below.
RANDOM_SEED: int = 0

In [None]:
# Fix random seeds through MONAI for reproducible runs (first positional
# parameter of `monai.utils.set_determinism` is the seed).
set_determinism(RANDOM_SEED)

In [None]:
from utils.transforms import contr_syn_transform_3 as data_transform

In [None]:
# Load the BraTS 2017 validation split twice: the original volumes
# (`processed=False`) and the saved median volumes (`processed=True`).
# Later cells pair the two by position (indexing and `zip` of the loaders),
# so both are built identically apart from `processed` and the transform.
NUM_WORKERS = 8


def make_brats_val(processed, transform, num_workers=NUM_WORKERS):
    """Build the BraTS 2017 validation dataset and an unshuffled loader for it.

    Args:
        processed: forwarded to ``BraTSDataset``; this notebook uses ``False``
            for the original volumes and ``True`` for the saved median volumes.
        transform: transform applied to each sample.
        num_workers: number of DataLoader worker processes.

    Returns:
        ``(dataset, loader)``. The loader uses ``batch_size=1`` and
        ``shuffle=False`` so its iteration order follows dataset indexing.
    """
    dataset = BraTSDataset(
        version='2017',
        processed=processed,
        section='validation',
        seed=RANDOM_SEED,
        transform=transform,
    )
    loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=num_workers)
    return dataset, loader


dataset_orig, loader_orig = make_brats_val(processed=False, transform=data_transform['val'])
dataset_median, loader_median = make_brats_val(processed=True, transform=data_transform['basic'])

logger.debug("Data loaded")
logger.debug(f"Length of dataset: {len(dataset_orig)}, {len(dataset_median)}")

# Samples are paired by index, and `zip(loader_orig, loader_median)` silently
# truncates to the shorter loader, so fail loudly if the splits differ in size.
assert len(dataset_orig) == len(dataset_median), (
    f"dataset length mismatch: {len(dataset_orig)} vs {len(dataset_median)}"
)

In [None]:
# Load the per-subject contrast masks for the train and validation splits and
# stack them into one frame indexed by the first CSV column (subject id).
# The directory can be overridden with the BRATS2017_MASK_DIR environment variable.
mask_root_dir = os.environ.get("BRATS2017_MASK_DIR", "/scratch1/sachinsa/data/masks/brats2017")
train_mask_df = pd.read_csv(os.path.join(mask_root_dir, "train_mask.csv"), index_col=0)
val_mask_df = pd.read_csv(os.path.join(mask_root_dir, "val_mask.csv"), index_col=0)
all_mask_df = pd.concat([train_mask_df, val_mask_df], axis=0)

# Masks are looked up with `.loc[[id], :]`; a repeated id would return several
# rows and silently broadcast against a single-sample batch.
assert all_mask_df.index.is_unique, "subject ids repeat across the mask CSVs"

# Downstream code inverts masks with `~`. On integer data that is bitwise NOT
# (0 -> -1, 1 -> -2), so require 0/1 (or bool) values and normalise to bool.
assert all_mask_df.isin([0, 1, True, False]).to_numpy().all(), (
    "mask CSVs must contain only 0/1 or True/False values"
)
all_mask_df = all_mask_df.astype(bool)
all_mask_df.head(2)

In [None]:
import warnings

# Use the first GPU when CUDA is available; otherwise fall back to CPU so the
# notebook still runs (slowly) instead of failing on the first `.to(device)`.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if device.type == "cpu":
    warnings.warn("CUDA is not available; running on CPU.")

In [None]:
# Disabled: builds a UNet3D, loads `best_checkpoint.pth` from run_{RUN_ID}, and
# creates an MSE metric.
# NOTE(review): the commented-out cells further down reference `model` and
# `mse_metric`, which only exist if this cell is re-enabled.
# NOTE(review): re-enabling it also reassigns RANDOM_SEED (same value as the
# config cell); consider dropping that line to keep a single source of truth.
# from utils.model import create_UNet3D, inference

# RUN_ID = 22
# RANDOM_SEED = 0
# ROOT_DIR = "/scratch1/sachinsa/cont_syn"
# load_dir = os.path.join(ROOT_DIR, f"run_{RUN_ID}")

# model = create_UNet3D(out_channels=12, device=device)
# checkpoint = torch.load(os.path.join(load_dir, 'best_checkpoint.pth'), weights_only=True)
# model.load_state_dict(checkpoint['model_state_dict'])

# mse_metric = MSEMetric(reduction="mean")

In [None]:
# Fetch the same positional sample from the original and the median dataset.
id_ = 1
this_data, median_data = dataset_orig[id_], dataset_median[id_]

In [None]:
# Sanity check: subject 328 must be retrievable from both the original and the
# median dataset. The returned samples are intentionally discarded.
for _dataset in (dataset_orig, dataset_median):
    _dataset.get_with_id(328)
del _dataset

In [None]:
# Prepare one validation subject for comparison: the original image (with a
# leading batch dim, on `device`), its contrast-mask row, and the first four
# channels of the saved median image.
id_ = 1
this_data = dataset_orig[id_]
this_inputs = this_data["image"].unsqueeze(0).to(device)
this_id = this_data["id"]
# Cast to bool so `~this_mask` is a logical NOT; on an integer tensor `~` is
# bitwise NOT (0 -> -1, 1 -> -2), which would corrupt masked products.
this_mask = torch.from_numpy(all_mask_df.loc[[this_id], :].to_numpy(dtype=bool)).to(device)

# NOTE(review): pairing by position assumes dataset_median orders subjects the
# same way as dataset_orig — confirm the two splits are aligned.
median_data = dataset_median[id_]
this_saved_median = median_data["image"].unsqueeze(0)[:, :4, ...].to(device)

In [None]:
# Disabled: single-sample inference with the masked contrasts zeroed out, then
# MSE between the first four output channels and the saved median.
# NOTE(review): needs `model` and `mse_metric` from the disabled model-loading
# cell. `~this_mask` is a logical NOT only if `this_mask` is a bool tensor; on an
# integer tensor it is bitwise NOT (0 -> -1, 1 -> -2) — confirm the mask dtype.
# with torch.no_grad():
#     this_target = this_inputs.clone()
#     this_inputs = this_inputs*~this_mask[:,:,None,None,None]
#     this_outputs = inference(this_inputs, model)

# this_output_median = this_outputs[:,:4,...]
# mse_metric(y_pred=this_output_median, y=this_saved_median)

# metric = mse_metric.aggregate().item()
# mse_metric.reset()
# print(f"mse error: {metric}")

In [None]:
# Visually compare one 2-D slice of the original input and of the saved median
# for a single contrast.
h_index = 77  # index along the last axis of the 5-D volume to display
c_index = 1  # channel (contrast) index into `channels`
channels = ["FLAIR", "T1w", "T1Gd", "T2w"]


def plot_brain_slice(volume, title, c_index, h_index, ax=None):
    """Show ``volume[0, c_index, :, :, h_index]`` (transposed) as a grayscale image.

    Also prints the slice's mean/min/max so intensity ranges can be compared.

    Args:
        volume: 5-D tensor indexed as (batch, channel, dim1, dim2, dim3); only
            the first batch element is plotted.
        title: axes title.
        c_index: channel to display.
        h_index: index along the last axis.
        ax: optional matplotlib Axes; a new figure is created when omitted.

    Returns:
        The Axes the slice was drawn on.
    """
    brain_slice = volume.detach().cpu().numpy()[0, c_index, :, :, h_index].T
    print(f"{title}: mean={brain_slice.mean()}, min={brain_slice.min()}, max={brain_slice.max()}")
    if ax is None:
        _, ax = plt.subplots()
    image = ax.imshow(brain_slice, cmap='gray')
    ax.set_title(title)
    ax.figure.colorbar(image, ax=ax)
    return ax


print(f"Channel: {channels[c_index]}")
print(f"ID: {this_id}")
plot_brain_slice(this_inputs, f'Original: {this_id}', c_index, h_index)
plot_brain_slice(this_saved_median, f'Saved Median: {this_id}', c_index, h_index)
# Requires `this_output_median` from the disabled inference cell:
# plot_brain_slice(this_output_median, f'Output Median: {this_id}', c_index, h_index)
plt.show()

### MSE over unmasked contrasts

In [None]:
# Disabled: MSE between the original inputs and the saved median, restricted to
# the unmasked contrasts (both tensors multiplied by `~this_mask`).
# NOTE(review): requires `mse_metric` from the disabled model-loading cell.
# `if i>1:break` stops after the first batch, so only one subject is evaluated.
# `model.eval()` has no effect here because the loop never calls the model.
# `this_ids.tolist()` works only if the collated ids are a tensor; string ids
# collate to a plain list — confirm. Zeroed channels presumably still count in
# MSEMetric's per-sample mean, diluting the value — verify against MONAI docs.
# Pairing via `zip` assumes both loaders yield subjects in the same order.
# model.eval()
# i = 0
# with torch.no_grad():
#     for this_data, median_data in zip(loader_orig,loader_median):
#         i+=1
#         if i>1:break
#         this_inputs, this_ids = (
#             this_data["image"].to(device),
#             this_data["id"],
#         )
#         this_mask = torch.from_numpy(all_mask_df.loc[this_ids.tolist(), :].values).to(device)[:,:,None,None,None]
#         this_saved_median = median_data["image"][:,:4,...].to(device)
#         this_inputs = this_inputs*~this_mask
#         this_saved_median = this_saved_median*~this_mask
#         mse_metric(y_pred=this_inputs, y=this_saved_median)
# 
#     metric = mse_metric.aggregate().item()
#     mse_metric.reset()
# print(f"mse error: {metric}")

### MSE over masked contrasts

In [None]:
# Disabled: MSE between the original inputs and the saved median, restricted to
# the masked contrasts (both tensors multiplied by `this_mask`).
# NOTE(review): requires `mse_metric` from the disabled model-loading cell.
# `if i>1:break` stops after the first batch, so only one subject is evaluated.
# `model.eval()` has no effect here because the loop never calls the model.
# `this_ids.tolist()` works only if the collated ids are a tensor — confirm.
# Zeroed channels presumably still count in MSEMetric's per-sample mean —
# verify against MONAI docs.
# model.eval()
# i = 0
# with torch.no_grad():
#     for this_data, median_data in zip(loader_orig,loader_median):
#         i+=1
#         if i>1:break
#         this_inputs, this_ids = (
#             this_data["image"].to(device),
#             this_data["id"],
#         )
#         this_mask = torch.from_numpy(all_mask_df.loc[this_ids.tolist(), :].values).to(device)[:,:,None,None,None]
#         this_saved_median = median_data["image"][:,:4,...].to(device)
#         this_inputs = this_inputs*this_mask
#         this_saved_median = this_saved_median*this_mask
#         mse_metric(y_pred=this_inputs, y=this_saved_median)
# 
#     metric = mse_metric.aggregate().item()
#     mse_metric.reset()
# print(f"mse error: {metric}")

### MSE over all contrasts

In [None]:
# Disabled: MSE between the original inputs and the saved median over every
# channel (the masking lines are doubly commented out).
# NOTE(review): requires `mse_metric` from the disabled model-loading cell.
# `if i>1:break` stops after the first batch, so only one subject is evaluated.
# NOTE(review): if the masking is ever re-enabled, the second masking line
# assigns `this_inputs*this_mask` to `this_saved_median` — almost certainly a
# typo for `this_saved_median*this_mask`.
# model.eval()
# i = 0
# with torch.no_grad():
#     for this_data, median_data in zip(loader_orig,loader_median):
#         i+=1
#         if i>1:break
#         this_inputs, this_ids = (
#             this_data["image"].to(device),
#             this_data["id"],
#         )
#         # this_mask = torch.from_numpy(all_mask_df.loc[this_ids.tolist(), :].values).to(device)[:,:,None,None,None]
#         this_saved_median = median_data["image"][:,:4,...].to(device)
#         # this_inputs = this_inputs*this_mask
#         # this_saved_median = this_inputs*this_mask
#         mse_metric(y_pred=this_inputs, y=this_saved_median)
# 
#     metric = mse_metric.aggregate().item()
#     mse_metric.reset()
# print(f"mse error: {metric}")