# Save uncertainty results (median-filled mixed volumes) for downstream tasks

In [None]:
# Notebook configuration.
# RUN_ID picks the run_<RUN_ID> folder names used for loading and saving;
# RANDOM_SEED is forwarded to BraTSDataset as its `seed`;
# ROOT_DIR is the parent of the run folder that load_dir points to.
RUN_ID = 22
RANDOM_SEED = 0
ROOT_DIR = (
    "/scratch1/sachinsa/cont_syn"
)

In [None]:
import os
import pandas as pd

import pdb
import numpy as np
from utils.logger import Logger

import torch
from torch.utils.data import DataLoader

from utils.model import create_UNet3D, inference
from utils.transforms import contr_syn_transform_3 as data_transform
from utils.dataset import BraTSDataset

# Project logger at DEBUG level, presumably so the logger.debug(...) calls in
# the data-loading cell are emitted.
logger = Logger(log_level="DEBUG")

In [None]:
# Run folder under ROOT_DIR; only the commented-out checkpoint load reads it.
load_dir = os.path.join(ROOT_DIR, f"run_{RUN_ID}")

# Output folder for the mixed NIfTI volumes. It is created up front (no error
# if it already exists) so every later nib.save has a valid destination.
save_dir = os.path.join(
    '/scratch1/sachinsa/data/contr_generated',
    f"run_{RUN_ID}_mixed",
)
os.makedirs(save_dir, exist_ok=True)

In [None]:
# Use the first GPU when CUDA is available. Otherwise fall back to the CPU so
# the notebook still runs (more slowly) instead of failing on device creation.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# NOTE(review): `model` is not used by the active export loop (checkpoint
# loading and inference are commented out). It is kept so those cells can be
# re-enabled. Confirm that create_UNet3D accepts a CPU device.
model = create_UNet3D(out_channels=12, device=device)

In [None]:
# Earlier variant: a single 'all' loader with the validation transform. Kept
# for reference; the commented-out inference cell at the end iterates it.
# all_dataset = BraTSDataset(
#     version='2017',
#     section='all',
#     seed = RANDOM_SEED,
#     transform = data_transform['val']
# )
# all_loader = DataLoader(all_dataset, batch_size=1, shuffle=False, num_workers=8)


def _brats_all_split(processed, transform):
    """Build the BraTS 2017 'all' section and an unshuffled batch-size-1 loader.

    The export loop zips two of these loaders together, so both must keep
    shuffle=False to visit subjects in the same order.
    """
    ds = BraTSDataset(
        version='2017',
        processed=processed,
        section='all',
        seed=RANDOM_SEED,
        transform=transform,
    )
    dl = DataLoader(ds, batch_size=1, shuffle=False, num_workers=8)
    return ds, dl


# Unprocessed volumes, passed through the 'val' transform.
dataset_orig, loader_orig = _brats_all_split(False, data_transform['val'])
# Processed volumes (used as the saved medians), passed through the 'basic' transform.
dataset_median, loader_median = _brats_all_split(True, data_transform['basic'])

logger.debug("Data loaded")
logger.debug(f"Length of dataset: {len(dataset_orig)}, {len(dataset_median)}")

In [None]:
# Load the per-subject channel masks for the train and val splits and stack
# them into one table indexed by the CSVs' first column (the subject ID that
# the export loop looks rows up by).
# verify_integrity=True makes pd.concat raise ValueError if an ID appears in
# both files. Without it, .loc would return several rows for that subject and
# silently change the mask's batch shape.
mask_root_dir = "/scratch1/sachinsa/data/masks/brats2017"
train_mask_df = pd.read_csv(os.path.join(mask_root_dir, "train_mask.csv"), index_col=0)
val_mask_df = pd.read_csv(os.path.join(mask_root_dir, "val_mask.csv"), index_col=0)
all_mask_df = pd.concat([train_mask_df, val_mask_df], axis=0, verify_integrity=True)
all_mask_df.head(2)

In [None]:
# Checkpoint loading is disabled because the active export loop never runs the
# model. Re-enable these lines together with the inference cell.
# checkpoint = torch.load(os.path.join(load_dir, 'best_checkpoint.pth'), weights_only=True)
# model.load_state_dict(checkpoint['model_state_dict'])
# model.eval()

# Contrast names by channel index (assumed order of the image channels). The
# plotting cell uses these to label a channel.
channels = [
    "FLAIR",
    "T1w",
    "T1Gd",
    "T2w",
]

In [None]:
import nibabel as nib

In [None]:
def _as_id_list(ids):
    """Return a batch's collated subject IDs as a plain Python list.

    The default DataLoader collate stacks numeric IDs into a tensor but keeps
    string IDs as a list of str. Both forms are accepted here.
    """
    return ids.tolist() if torch.is_tensor(ids) else list(ids)


# zip() silently stops at the shorter loader, which would drop subjects
# without warning, so fail loudly on a length mismatch instead.
if len(dataset_orig) != len(dataset_median):
    raise ValueError(
        f"Dataset length mismatch: {len(dataset_orig)} original vs "
        f"{len(dataset_median)} median volumes"
    )

# For each subject, keep an original channel where its mask entry is False and
# substitute the saved median channel where it is True. Write the result,
# channels-last, to save_dir as BRATS_<id>.nii.gz.
with torch.no_grad():
    for this_data, median_data in zip(loader_orig, loader_median):
        this_inputs = this_data["image"].to(device)
        id_list = _as_id_list(this_data["id"])

        # Both loaders use shuffle=False, so their batches should line up.
        # Verify this whenever the median batch also carries IDs.
        if "id" in median_data:
            median_ids = _as_id_list(median_data["id"])
            if median_ids != id_list:
                raise ValueError(
                    f"Loader misalignment: original IDs {id_list} vs "
                    f"median IDs {median_ids}"
                )

        # The mask has one row per subject and one column per channel,
        # broadcast over the spatial dims. Cast to bool so the selection is a
        # true logical mask even if the CSV stored 0/1 integers (bitwise ~ on
        # ints yields -1/-2, not 1/0).
        mask_np = all_mask_df.loc[id_list, :].to_numpy(dtype=bool)
        this_mask = torch.from_numpy(mask_np).to(device)[:, :, None, None, None]

        this_saved_median = median_data["image"][:, :4, ...].to(device)
        # Take the median where masked and the original elsewhere. Unlike
        # inputs*~mask + median*mask, this never multiplies the unused operand,
        # so inf/NaN there cannot leak into the result.
        this_mixed = torch.where(this_mask, this_saved_median, this_inputs)

        # Save every sample in the batch. The loaders use batch_size=1, but
        # this stays correct if that changes.
        for sample_idx, subject_id in enumerate(id_list):
            mri_array = this_mixed[sample_idx].detach().permute(1, 2, 3, 0).cpu().numpy()
            nifti_img = nib.Nifti1Image(mri_array, affine=np.eye(4))
            output_filename = os.path.join(save_dir, f'BRATS_{subject_id}.nii.gz')
            print(output_filename)
            nib.save(nifti_img, output_filename)

In [None]:
# print(this_mask.squeeze())
# this_id = this_data["id"].item()
# this_id

In [None]:
# import matplotlib.pyplot as plt
# print(this_mixed.shape)

# this_target = this_inputs

# h_index = 77
# c_index = 1 # channel
# channels = ["FLAIR", "T1w", "T1Gd", "T2w"]
# print(f"Channel: {channels[c_index]}")
# print(f"ID: {this_id}")
# brain_slice = this_target.detach().cpu().numpy()
# brain_slice = brain_slice[0,c_index,:,:,h_index].T
# plt.figure()
# plt.title(f'Original: {this_id}')
# plt.imshow(brain_slice, cmap='gray')
# plt.colorbar()

# brain_slice = this_saved_median.detach().cpu().numpy()
# brain_slice = brain_slice[0,c_index,:,:,h_index].T
# print(brain_slice.mean(), brain_slice.min(), brain_slice.max())
# plt.figure()
# plt.title(f'Saved Median: {this_id}')
# plt.imshow(brain_slice, cmap='gray')
# plt.colorbar()

In [None]:
# with torch.no_grad():
#     for this_data in all_loader:
#         this_inputs, this_ids = (
#             this_data["image"].to(device),
#             this_data["id"],
#         )
#         this_mask = torch.from_numpy(all_mask_df.loc[this_ids.tolist(), :].values).to(device)
#         this_target = this_inputs.clone()
#         this_inputs = this_inputs*~this_mask[:,:,None,None,None]
#         this_outputs = inference(this_inputs, model)
        
#         mri_array = this_outputs[0].detach().permute(1, 2, 3, 0).cpu().numpy()
#         nifti_img = nib.Nifti1Image(mri_array,affine=np.eye(4))
#         output_filename = os.path.join(save_dir, f'BRATS_{this_ids[0]}.nii.gz')
#         print(output_filename)
#         nib.save(nifti_img, output_filename)