In [17]:
import random
import os
import nibabel as nib
import numpy as np
import torch
from torch.utils.data.dataset import Dataset
import monai
from pathlib import Path
import glob
import matplotlib.pyplot as plt

In [33]:
class BratsDataset(Dataset):
    """Map-style dataset that loads one BraTS-style sample of NIfTI volumes per index.

    Args:
        data_dir: Either a directory whose sorted child entries (sub-directories and
            ``*.nii.gz`` files) are the samples, or an iterable of sample paths.
            A sample path is a directory holding that sample's ``*.nii.gz`` volumes,
            or a single ``.nii.gz`` file.
        transform: Optional callable applied to a sample dict with key ``'img'``
            (plus ``'mask'`` when ``mask=True``). It must return a mapping with the
            same keys, as MONAI dictionary transforms do.
        mask: If True, also load the segmentation volume and return it with the image.

    Volumes whose file name ends with ``seg.nii.gz`` are treated as the label map and
    are never stacked into the image. This assumes BraTS file naming; confirm it for
    other data.

    ``__getitem__`` returns the channel-first image array ``(C, *spatial)`` when
    ``mask`` is False. Otherwise it returns an ``(img, mask)`` tuple, where ``mask``
    has shape ``(1, *spatial)``. With a transform, the returned values are whatever
    the transform stores under those keys.

    Raises (from ``__getitem__``):
        FileNotFoundError: the sample has no image volume, or ``mask=True`` and the
            sample has no ``*seg.nii.gz`` volume.
    """

    _SEG_SUFFIX = 'seg.nii.gz'

    def __init__(self, data_dir, transform=None, mask=False):
        self.data_dir = data_dir
        self.transform = transform
        self.mask = mask
        if isinstance(data_dir, (str, os.PathLike)):
            # One directory: each sub-directory / NIfTI file in it is a sample.
            # Stray files (CSV metadata, .DS_Store, ...) are skipped.
            self.filenames = sorted(
                entry for entry in Path(data_dir).iterdir()
                if entry.is_dir() or entry.name.endswith('.nii.gz')
            )
        else:
            # An explicit collection of sample paths (e.g. a list of globbed files).
            self.filenames = [Path(entry) for entry in data_dir]

    def __len__(self):
        return len(self.filenames)

    @staticmethod
    def _volume_paths(sample_path):
        """Return the sorted ``*.nii.gz`` files making up one sample."""
        if sample_path.is_dir():
            # Sorting fixes the modality (channel) order across samples.
            return sorted(sample_path.glob('*.nii.gz'))
        return [sample_path]

    def __getitem__(self, idx):
        sample_path = self.filenames[idx]
        volume_paths = self._volume_paths(sample_path)
        seg_paths = [p for p in volume_paths if p.name.endswith(self._SEG_SUFFIX)]
        img_paths = [p for p in volume_paths if not p.name.endswith(self._SEG_SUFFIX)]
        if not img_paths:
            raise FileNotFoundError(f'no image volumes found for sample {sample_path}')
        if self.mask and not seg_paths:
            raise FileNotFoundError(
                f'mask=True but no *{self._SEG_SUFFIX} volume found for sample {sample_path}'
            )

        # Stack modalities on a new leading axis: MONAI transforms expect channel-first arrays.
        sample = {'img': np.stack([nib.load(str(p)).get_fdata() for p in img_paths], axis=0)}
        if self.mask:
            sample['mask'] = nib.load(str(seg_paths[0])).get_fdata()[np.newaxis]

        if self.transform is not None:
            # The transform receives the whole dict so spatial augmentations keep img/mask aligned.
            sample = self.transform(sample)
        if self.mask:
            return sample['img'], sample['mask']
        return sample['img']

In [36]:
def get_train_tfms(seed=42, roi_size=(120, 120, 120), spatial_size=(128, 128, 128)):
    """Build the seeded random-augmentation pipeline used for training.

    The pipeline consumes and returns sample dicts with keys ``'img'`` and ``'mask'``.
    Spatial transforms (crop, resize, flips) are applied to both keys so they stay
    aligned. Intensity transforms are applied to ``'img'`` only.

    Args:
        seed: Passed to ``Compose.set_random_state`` so the augmentations are reproducible.
        roi_size: ``roi_size`` of ``RandSpatialCropd``. With ``random_size=True``, MONAI
            treats it as the minimum crop size.
        spatial_size: Size every crop is resized to.

    Returns:
        A seeded ``monai.transforms.Compose``.
    """
    keys = ['img', 'mask']
    transforms = monai.transforms.Compose([
        monai.transforms.RandSpatialCropd(keys=keys, roi_size=roi_size, random_size=True),
        # 'area' averages image intensities; 'nearest' keeps mask labels discrete.
        monai.transforms.Resized(keys=keys, spatial_size=spatial_size, mode=['area', 'nearest']),
        # Independent random flips along each of the three spatial axes.
        *[monai.transforms.RandFlipd(keys=keys, prob=0.5, spatial_axis=axis) for axis in range(3)],
        monai.transforms.RandAdjustContrastd(keys="img", prob=0.7, gamma=(0.5, 2.5)),
        monai.transforms.RandShiftIntensityd(keys="img", offsets=0.125, prob=0.7),
        # Use the dictionary variant: every other step in this pipeline is dict-based.
        monai.transforms.ToTensord(keys=keys, dtype=torch.float),
    ])
    transforms.set_random_state(seed)
    return transforms

def get_test_tfms(seed=42):
    """Return the evaluation pipeline, which only converts its input to a float tensor.

    Args:
        seed: Forwarded to ``Compose.set_random_state``. The pipeline contains no random
            transform, so this presumably just keeps the signature parallel to
            ``get_train_tfms``.

    Returns:
        A seeded ``monai.transforms.Compose``.
    """
    to_float_tensor = monai.transforms.ToTensor(dtype=torch.float)
    pipeline = monai.transforms.Compose([to_float_tensor])
    pipeline.set_random_state(seed)
    return pipeline

def get_datasets(brats_dir, val_split=0.05, seed=42,
                 train_folder='train_for_nf_248', val_folder='val', test_folder='test_15',
                 pattern='*t1.nii.gz'):
    """Build the train / validation / test BraTS datasets from split sub-folders.

    Args:
        brats_dir: Root directory containing the three split folders.
        val_split: Currently unused. The validation set is read from ``val_folder``
            instead of being split off the training data. It is kept so existing
            callers still work.
        seed: Seed forwarded to the transform builders.
        train_folder: Name of the training sub-folder under ``brats_dir``.
        val_folder: Name of the validation sub-folder under ``brats_dir``.
        test_folder: Name of the test sub-folder under ``brats_dir``.
        pattern: Glob pattern used to collect the sample paths in each split folder.

    Returns:
        ``(train_dataset, valid_dataset, test_dataset)``. Train and validation are
        built with ``mask=True``; test is built with ``mask=False``.

    Raises:
        FileNotFoundError: A split folder does not exist. Without this check,
            ``Path.glob`` would silently return nothing.
    """
    brats_dir = Path(brats_dir)

    def _sample_paths(folder):
        split_dir = brats_dir / folder
        if not split_dir.is_dir():
            raise FileNotFoundError(f'split folder not found: {split_dir}')
        # Path.glob order is filesystem-dependent; sort for reproducible runs.
        return sorted(split_dir.glob(pattern))

    test_dataset = BratsDataset(_sample_paths(test_folder), get_test_tfms(seed=seed), mask=False)
    train_dataset = BratsDataset(_sample_paths(train_folder), get_train_tfms(seed=seed), mask=True)
    valid_dataset = BratsDataset(_sample_paths(val_folder), get_test_tfms(seed=seed), mask=True)

    return train_dataset, valid_dataset, test_dataset