## Verify the masking algorithms

In [None]:
from mask import Mask1

# Mask-Algo 1 verification: build a Mask1 over 10,000 samples with a fixed seed
# and run its self-check.
# NOTE(review): verify() presumably compares the generated masks against
# miss_prob_expected -- confirm in mask.py.
miss_prob_expected = [0.40, 0.12, 0.30, 0.15]
mask_obj = Mask1(
    num_samples=10_000,
    miss_prob_expected=miss_prob_expected,
    seed=0,
)
mask_obj.verify()

In [None]:
from mask import Mask0

# Mask-Algo 0 verification: build a Mask0 for 4 contrasts over 10,000 samples
# with a fixed seed and run its self-check.
mask_obj = Mask0(
    num_samples=10_000,
    num_contrasts=4,
    seed=0,
)
mask_obj.verify()

## Create mask info and save it

In [None]:
import os
import matplotlib.pyplot as plt

import pdb
import numpy as np
import pandas as pd
from logger import Logger

from dataset import BraTSDataset
from transforms import tumor_seg_transform

In [None]:
# Contrast names; used below as the column labels of the saved mask tables.
mri_contrasts = [
    "FLAIR",
    "T1w",
    "T1Gd",
    "T2w",
]
# One value per entry of mri_contrasts.
# NOTE(review): presumably the target probability that each contrast is
# missing -- confirm against Mask1.
miss_prob_expected = [
    0.40,
    0.12,
    0.30,
    0.15,
]
# Directory the mask CSV files are written to and read back from.
mask_root_dir = "/scratch1/sachinsa/data/masks/brats2017"
# Seed handed to the dataset constructors.
RANDOM_SEED = 0

logger = Logger(log_level="DEBUG")

## Generate and save masks

In [None]:
# For each data section, build the BraTS dataset, draw a Mask1 mask with one
# row per sample id, and save it as "<section>_mask.csv" under mask_root_dir.

# Per-section settings: (BraTSDataset `section` argument, tumor_seg_transform key).
section_config = {
    'train': ('training', 'train'),
    'val': ('validation', 'val'),
}

# DataFrame.to_csv does not create missing parent directories, so make sure
# the output directory exists before the first write.
os.makedirs(mask_root_dir, exist_ok=True)

for section in ['train', 'val']:
    logger.debug(section)
    dataset_section, transform_key = section_config[section]
    dataset = BraTSDataset(
        version='2017',
        section=dataset_section,
        seed=RANDOM_SEED,
        transform=tumor_seg_transform[transform_key],
    )

    ids = dataset.get_ids()
    num_samples = len(ids)
    # Seed the mask generator from the shared RANDOM_SEED constant rather than
    # a separate literal, so changing the constant changes the masks as well.
    # NOTE(review): both sections use the same seed; whether that correlates
    # the train and val masks depends on Mask1 -- confirm if that matters.
    mask_obj = Mask1(
        num_samples=num_samples,
        miss_prob_expected=miss_prob_expected,
        seed=RANDOM_SEED,
    )
    miss_info = mask_obj.miss_info

    # Rows are indexed by sample id, columns are the contrast names.
    mask_df = pd.DataFrame(miss_info, index=ids, columns=mri_contrasts)
    logger.debug(mask_df.shape)
    print(mask_df.head())

    # save masking information (index=True keeps the sample ids in the file)
    mask_df.to_csv(os.path.join(mask_root_dir, f"{section}_mask.csv"), index=True)

## Load masks

In [None]:
from mask import verify_mask_algo

# Read each saved mask table back from disk (first CSV column as the index),
# show its shape and first rows, and pass its values to verify_mask_algo
# together with miss_prob_expected.
mask_csv_names = {
    'train': "train_mask.csv",
    'val': "val_mask.csv",
}
for section in mask_csv_names:
    logger.debug(section)
    mask_csv_path = os.path.join(mask_root_dir, mask_csv_names[section])
    mask_df = pd.read_csv(mask_csv_path, index_col=0)
    logger.debug(mask_df.shape)
    print(mask_df.head())
    verify_mask_algo(mask_df.values, miss_prob_expected)