## Verify the masking algorithms

In [None]:
from mask import Mask1

# Mask-Algo 1 check: construct Mask1 from the expected miss probabilities,
# run its verify() routine, and leave `miss_info.miss_info` as the final
# expression so the notebook displays it.
miss_prob_expected = [0.40, 0.12, 0.30, 0.15]
miss_info = Mask1(
    num_samples=10000,
    miss_prob_expected=miss_prob_expected,
    seed=0,
)
miss_info.verify()
miss_info.miss_info

In [None]:
from mask import Mask0

# Mask-Algo 0 check: construct Mask0 for four contrasts, run its verify()
# routine, and leave `miss_info.miss_info` as the final expression so the
# notebook displays it.
miss_info = Mask0(
    num_samples=10000,
    num_contrasts=4,
    seed=0,
)
miss_info.verify()
miss_info.miss_info

## Create mask info and save it

In [None]:
# Seed passed to BraTSDataset (as `seed`) further down in this notebook.
RANDOM_SEED = 0
# Project root directory. In the visible cells it is only referenced by the
# commented-out `load_dir` line.
ROOT_DIR = "/scratch1/sachinsa/cont_syn"

In [None]:
import os
import matplotlib.pyplot as plt

import pdb
import numpy as np
import pickle
from logger import Logger

import torch
from monai.networks.nets import UNet
from transforms import contr_syn_transform_2 as data_transform
# from dataset import BraTSDataset

In [None]:
# Project logger created with the 'DEBUG' log level.
logger = Logger(log_level='DEBUG')
# Disabled: RUN_ID is not defined in the visible cells, so this line would
# raise NameError if re-enabled as-is.
# load_dir = os.path.join(ROOT_DIR, f"run_{RUN_ID}")

In [None]:
import os
import numpy as np
import json
from torch.utils.data import Dataset, DataLoader, random_split
import torch

In [None]:
def loadBRATS2017(root_dir):
    """Load the dataset description stored in ``<root_dir>/dataset.json``.

    Args:
        root_dir: Directory that contains the ``dataset.json`` file.

    Returns:
        The parsed JSON content (whatever ``json.load`` produces for the file).

    Raises:
        FileNotFoundError: If ``dataset.json`` does not exist in ``root_dir``.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    json_file_path = os.path.join(root_dir, "dataset.json")
    # JSON is UTF-8 by specification; do not depend on the locale's default
    # encoding, which differs between machines.
    with open(json_file_path, 'r', encoding='utf-8') as file:
        return json.load(file)

# Dataset root directories, keyed by the 'BraTS_<version>' string that
# BraTSDataset.__init__ builds from its `version` argument.
DATA_PATH = {
    'BraTS_2017': '/scratch1/sachinsa/data/Task01_BrainTumour'
}

# TODO: Move out the mask logic from inside this class
# TODO: Study MONAI's DecathlonDataset and CacheDataset classes to improve this class
class BraTSDataset(Dataset):
    """Dataset over the images listed under ``'training'`` in ``dataset.json``.

    Each item is a dict with keys ``"id"``, ``"image"`` and ``"mask"``: the
    case id sliced from the relative image path, the result of applying
    ``transform`` to the image file path, and the mask returned by ``Mask0``
    for that position converted to a torch tensor.

    Args:
        version: Dataset release; only ``'2017'`` has a loader.
        section: Accepted for API compatibility; currently unused.
        transform: Callable applied to the image file path. It is required
            for item access because no default loader exists yet.
        seed: Seed forwarded to ``Mask0``.
        has_mask: Stored on the instance; it does not change the items yet.
        has_label: Accepted for API compatibility; currently unused.

    Raises:
        KeyError: If ``DATA_PATH`` has no entry for ``version``.
        ValueError: If ``version`` has a data path but no loader.
    """

    def __init__(self, version, section, transform=None, seed=0, has_mask=True, has_label=True):
        self.root_dir = DATA_PATH[f'BraTS_{version}']

        if version != '2017':
            # Fail here with a clear message instead of an AttributeError on
            # `image_filenames` further down.
            raise ValueError(f"Unsupported BraTS version: {version!r}")

        self.properties = loadBRATS2017(self.root_dir)

        # removing BRATS_065 dataset as it disrupts training
        self.image_filenames = [
            item for item in self.properties['training']
            if 'BRATS_065' not in item['image']
        ]

        # Numeric case ids: the slice drops the first 17 characters and the
        # last 7 of each relative path and parses what remains as an int.
        # NOTE(review): assumes paths shaped like './imagesTr/BRATS_001.nii.gz'.
        self.indices = np.array([int(filepath['image'][17:-7]) for filepath in self.image_filenames])

        self.has_mask = has_mask
        self.mask = Mask0(num_samples=len(self.image_filenames), seed=seed)

        # TODO: if transform == None, use a default transform that simply loads the data
        self.transform = transform

    def __len__(self):
        """Return the number of images kept after filtering."""
        return len(self.image_filenames)

    def __getitem__(self, idx):
        """Return the ``{"id", "image", "mask"}`` dict for position ``idx``."""
        img_name = os.path.normpath(os.path.join(self.root_dir, self.image_filenames[idx]['image']))
        mask = self.mask.get(idx)
        image = self.transform(img_name)
        mask = torch.from_numpy(mask)

        # Characters 11..19 of the relative path form the case id
        # (e.g. 'BRATS_001' for './imagesTr/BRATS_001.nii.gz').
        return {"id": self.image_filenames[idx]['image'][11:20], "image": image, "mask": mask}

    def get_with_id(self, id):
        """Return the item whose numeric case id equals ``id``.

        Raises:
            ValueError: If no image with that case id is in the dataset.
        """
        # `self.indices` is a NumPy array, which has no `.index()` method;
        # locate the position with a vectorized comparison instead.
        matches = np.flatnonzero(self.indices == id)
        if matches.size == 0:
            raise ValueError(f"{id!r} is not a case id of this dataset")
        return self.__getitem__(int(matches[0]))

    def get_indices(self) -> np.ndarray:
        """
        Get the indices of datalist used in this dataset.

        """
        return self.indices

    def get_properties(self):
        """Return the parsed content of ``dataset.json``."""
        return self.properties

In [None]:
# Instantiate the 2017 dataset with the 'test' entry of the transform table.
all_dataset = BraTSDataset(
    version="2017",
    section="all",
    seed=RANDOM_SEED,
    transform=data_transform["test"],
)

# Report how many case ids were collected, the largest one, and the full array.
indices = all_dataset.get_indices()
print("len: ", len(indices))
print("max: ", max(indices))
print(indices)

# properties = all_dataset.get_properties()
# properties

In [None]:
from monai.apps import DecathlonDataset
from transforms import tumor_seg_transform

# Build the training section of Task01_BrainTumour through MONAI's
# DecathlonDataset, using the project's tumor-segmentation train transform.
# NOTE(review): `download=True` presumably fetches the archive only when it is
# missing under `root_dir`, and `cache_rate=0.0` presumably disables caching;
# confirm both against the MONAI documentation.
train_ds = DecathlonDataset(
    root_dir='/scratch1/sachinsa/data',
    task="Task01_BrainTumour",
    transform=tumor_seg_transform['train'],
    section="training",
    download=True,
    cache_rate=0.0,
    num_workers=4,
)

In [None]:
# indices = train_ds.get_indices()
# print(len(indices))
# print(indices)

# Fetch the dataset properties and leave them as the cell's final expression
# so the notebook displays them.
properties = train_ds.get_properties()
properties