In [None]:
import os
import matplotlib.pyplot as plt

import pdb
import numpy as np
import pandas as pd
from utils.logger import Logger

from utils.dataset import BraTSDataset
from utils.transforms import tumor_seg_transform

In [None]:
# MRI contrasts; used below as labels for the mask columns
# (presumably the column order of the mask CSVs — TODO confirm against utils.mask).
mri_contrasts = ["FLAIR", "T1w", "T1Gd", "T2w"]
# Expected missing probability per contrast, aligned index-by-index with mri_contrasts.
miss_prob_expected = [0.40, 0.12, 0.30, 0.15]
assert len(miss_prob_expected) == len(mri_contrasts), "expected one missing probability per contrast"

# Mask CSV directory; override with the BRATS_MASK_DIR environment variable on other machines.
mask_root_dir = os.environ.get("BRATS_MASK_DIR", "/scratch1/sachinsa/data/masks/brats2017")
RANDOM_SEED = 0  # not used by the cells shown here
fig_save_dir = os.path.join("..", "figs", "mask")

logger = Logger(log_level='DEBUG')

### Load and verify masks

In [None]:
from utils.mask import verify_mask_algo

# For each split, load the saved mask CSV and run verify_mask_algo against the
# expected per-contrast missing probabilities (the check itself lives in utils.mask).
for section in ['train', 'val']:
    logger.debug(section)
    # Files are named "<section>_mask.csv"; the first CSV column is used as the row index.
    mask_df = pd.read_csv(os.path.join(mask_root_dir, f"{section}_mask.csv"), index_col=0)
    logger.debug(mask_df.shape)
    # Each column is compared against miss_prob_expected / labelled with mri_contrasts,
    # so the column count must match; fail early with a readable message otherwise.
    assert mask_df.shape[1] == len(mri_contrasts), (
        f"{section}: expected {len(mri_contrasts)} mask columns, got {mask_df.shape[1]}"
    )
    print(mask_df.head())
    verify_mask_algo(mask_df.values, miss_prob_expected)

In [None]:
# Reload the training split on its own, so the plotting cells below get a plain
# NumPy array that does not depend on the verification loop above.
train_mask_csv = os.path.join(mask_root_dir, "train_mask.csv")
mask_df = pd.read_csv(train_mask_csv, index_col=0)
mask_vals = mask_df.to_numpy()
print(mask_vals.shape)

In [None]:
# Overview image of the training mask, saved to fig_save_dir/mask.png.
# Layout: VERTICAL=True puts samples on the y axis and contrasts on the x axis;
# VERTICAL=False (default) puts contrasts on the y axis and the first N_SHOW samples on x.
VERTICAL = False
N_SHOW = 200  # number of samples shown in the horizontal layout
colors = ['#4CAF50', '#F44336']
cmap = plt.matplotlib.colors.ListedColormap(colors)

scale = 2
fontsize = scale*15
contrast_ticks = np.arange(len(mri_contrasts))

# vmin/vmax pin the colour scale so a slice containing only one value is still coloured
# consistently (assumes a binary 0/1 mask: 0 -> first colour, 1 -> second colour).
if VERTICAL:
    plt.figure(figsize=(2, 12))
    plt.imshow(mask_vals, cmap=cmap, aspect='auto', vmin=0, vmax=1)
    plt.yticks(fontsize=fontsize)
    plt.ylabel('Index', fontsize=fontsize)
    plt.xticks(ticks=contrast_ticks, labels=mri_contrasts, fontsize=fontsize, rotation=90)
else:
    plt.figure(figsize=(scale*12, scale*2))
    plt.imshow(mask_vals.T[:, :N_SHOW], cmap=cmap, aspect='auto', vmin=0, vmax=1)
    plt.xticks(fontsize=fontsize)
    plt.xlabel('Index', fontsize=fontsize)
    plt.yticks(ticks=contrast_ticks, labels=mri_contrasts, fontsize=fontsize)

# savefig raises FileNotFoundError if the output directory does not exist yet.
os.makedirs(fig_save_dir, exist_ok=True)
plt.savefig(os.path.join(fig_save_dir, "mask.png"), facecolor='white')
plt.show()

In [None]:
# Compact image of the full training mask, with contrast labels on the contrast axis.
VERTICAL = True
colors = ['green', 'red']
cmap = plt.matplotlib.colors.ListedColormap(colors)

# Transpose into a local name instead of rebinding mask_vals: rebinding the global made
# this cell non-idempotent and silently changed mask_vals for any later cell.
shown_vals = mask_vals if VERTICAL else mask_vals.T
plt.figure(figsize=(2, 12) if VERTICAL else (12, 2))
# Pin the colour scale (assumes a binary 0/1 mask) so colours stay consistent.
plt.imshow(shown_vals, cmap=cmap, aspect='auto', vmin=0, vmax=1)
plt.title('Training mask', fontsize=14)
ax = plt.gca()
contrast_ticks = np.arange(len(mri_contrasts))
if VERTICAL:
    # Contrasts run along x; put their labels at the top, under the title.
    ax.xaxis.set_ticks_position('top')
    ax.xaxis.set_label_position('top')
    plt.xticks(ticks=contrast_ticks, labels=mri_contrasts, fontsize=10)
else:
    # Contrasts run along y in the transposed layout.
    plt.yticks(ticks=contrast_ticks, labels=mri_contrasts, fontsize=10)

plt.show()