In [None]:
import os
import time
import matplotlib.pyplot as plt
import numpy as np
import pdb
import pandas as pd
import pickle

from monai.apps import DecathlonDataset
from monai.config import print_config
from monai.data import DataLoader, decollate_batch
from monai.handlers.utils import from_engine
from monai.losses import DiceLoss
from monai.inferers import sliding_window_inference
from monai.metrics import DiceMetric
from monai.networks.nets import SegResNet
from monai.transforms import (
    Activations,
    AsDiscrete,
    Compose,
)
from monai.utils import set_determinism
from tqdm import tqdm

import torch
from torch.utils.data import Subset

from utils.dataset import BraTSDataset
from utils.model import create_SegResNet, inference

# print_config()

In [None]:
# Notebook-wide logger from the project's `utils.logger` module, created with
# log_level='DEBUG'; later cells report progress through logger.debug(...).
from utils.logger import Logger
logger = Logger(log_level='DEBUG')

In [None]:
# Single seed value passed both to set_determinism(seed=...) and to every
# BraTSDataset(seed=...) call in this notebook.
RANDOM_SEED = 0

In [None]:
# Apply the shared seed through MONAI's determinism helper before any dataset
# or transform is constructed, so a Restart & Run All starts from the same state.
set_determinism(seed=RANDOM_SEED)

In [None]:
from utils.transforms import tumor_seg_transform_2

In [None]:
# train_dataset = BraTSDataset(
#     version='2017',
#     processed = USE_PROCESSED,
#     section = 'training',
#     seed = RANDOM_SEED,
#     transform = tumor_seg_transform_2['train']
# )

# Build the two validation datasets. They are constructed with identical
# arguments except for `processed`: False for the "orig" variant and True for
# the "median" variant.
shared_val_args = dict(
    version='2017',
    section='validation',
    seed=RANDOM_SEED,
    transform=tumor_seg_transform_2['val'],
)
val_dataset_orig = BraTSDataset(processed=False, **shared_val_args)
val_dataset_median = BraTSDataset(processed=True, **shared_val_args)

logger.debug("Data loaded")
logger.debug(f"Length of dataset: {len(val_dataset_orig)}, {len(val_dataset_median)}")

# Probe the first sample for its image shape and remember the middle index of
# the last axis; the plotting cell below slices with `h_index`.
brain_slice = val_dataset_orig[0]['image']
volume_shape = brain_slice.shape
print(volume_shape)
h_index = volume_shape[-1] // 2

In [None]:
def show_mid_slice(sample, key, title_prefix, depth_index, channel=1):
    """Plot one slice of ``sample[key]`` in grayscale with a colorbar.

    Replaces three copy-pasted figure/title/imshow/colorbar sequences.

    Args:
        sample: Dataset item; must provide ``sample[key]`` and ``sample["id"]``.
        key: Entry to plot ('image' or 'label' in this notebook).
        title_prefix: Text placed before the sample id in the figure title.
        depth_index: Index along the last axis at which the slice is taken.
        channel: Index along the first axis of ``sample[key]`` (the original
            cell always used 1).

    Returns:
        The transposed slice that was drawn.
    """
    plane = sample[key][channel][..., depth_index].T
    # Explicit fig/ax instead of the pyplot state machine, so each call owns
    # its own figure.
    fig, ax = plt.subplots()
    ax.set_title(f'{title_prefix}: {sample["id"]}')
    drawn = ax.imshow(plane, cmap='gray')
    fig.colorbar(drawn, ax=ax)
    return plane


idx_ = 1

# Fetch each sample once so the image and its label come from the same
# dataset access. `h_index` is the mid-slice index computed in the previous cell.
sample_orig = val_dataset_orig[idx_]
print(show_mid_slice(sample_orig, 'image', 'Original', h_index).shape)
show_mid_slice(sample_orig, 'label', 'Label', h_index)

sample_median = val_dataset_median[idx_]
# Trailing ';' keeps the cell from echoing a stray repr after the figures.
print(show_mid_slice(sample_median, 'image', 'Median', h_index).shape);

In [None]:
# NOTE(review): `USE_PROCESSED` was referenced here but never defined in any
# earlier cell, so this cell raised NameError on a fresh kernel. It is defined
# explicitly now; False selects the same variant as `val_dataset_orig`.
# Confirm which variant is intended for training.
USE_PROCESSED = False
# Position of the training sample to preview (the original cell used 4).
TRAIN_SAMPLE_INDEX = 4

# The original passed `tumor_seg_transform['train']`, but only
# `tumor_seg_transform_2` is imported in this notebook.
train_dataset = BraTSDataset(
    version='2017',
    processed=USE_PROCESSED,
    section='training',
    seed=RANDOM_SEED,
    transform=tumor_seg_transform_2['train'],
)

# Take the slice index from the sample that is actually displayed, rather than
# from sample 0, so the index is valid even if per-sample shapes differ.
train_sample = train_dataset[TRAIN_SAMPLE_INDEX]
train_volume = train_sample['image']
print(train_volume.shape)
h_index = train_volume.shape[-1] // 2

brain_slice = train_volume[1][..., h_index].T
print(brain_slice.shape)
plt.figure()
plt.title(f'Training sample {TRAIN_SAMPLE_INDEX}')
plt.imshow(brain_slice, cmap='gray');

In [None]:
# `val_dataset` is not defined anywhere in this notebook; the validation
# datasets built earlier are `val_dataset_orig` and `val_dataset_median`.
# NOTE(review): using the unprocessed variant here — confirm that is the intent.
brain_slice = val_dataset_orig[0]['image']
print(brain_slice.shape)

In [None]:
# Look up one validation sample by its id and show a single slice of it.
# `np` and `plt` come from the import cell at the top of the notebook.
id_ = 75
# NOTE(review): hard-coded index along the last axis, kept from the original
# cell; presumably the mid-slice of the volume — confirm.
SLICE_INDEX = 77

# `val_dataset` is undefined in this notebook; use the unprocessed validation
# dataset built earlier. np.asarray makes the comparison element-wise even if
# `.ids` is a plain list.
ids_ = np.asarray(val_dataset_orig.ids)
matches = np.where(ids_ == id_)[0]
if matches.size == 0:
    raise ValueError(f"id {id_} not found in val_dataset_orig.ids")

# The original fetched this sample and then plotted val_dataset[0] instead;
# plot the sample that was actually looked up.
this_data = val_dataset_orig[int(matches[0])]
brain_slice = this_data['image'][1][..., SLICE_INDEX].T
print(brain_slice.shape)
plt.figure()
plt.title(f'Original: {this_data["id"]}')
plt.imshow(brain_slice, cmap='gray');