In [None]:
import random
from lib2to3.pygram import pattern_symbols
from pathlib import Path
from tracemalloc import get_traceback_limit

import monai
import numpy as np
import pandas as pd
import toml
import torch
from torch.utils.data.dataset import Dataset

def get_train_tfms(seed=42):
    """Build the randomized training augmentation pipeline for the ``'img'`` key.

    Steps, in order:

    * random spatial crop (``random_size=True``);
    * resize back to 128x128x128 with ``'area'`` interpolation;
    * independent random flips along spatial axes 0, 1 and 2 (p=0.5 each);
    * random gamma contrast adjustment and random intensity shift (p=0.7 each);
    * conversion to a float tensor.

    Args:
        seed: Seed passed to ``Compose.set_random_state`` so the augmentation
            sequence is reproducible.

    Returns:
        A ``monai.transforms.Compose`` that is called with a dict holding an
        ``'img'`` entry.
    """
    keys = ['img']
    transforms = monai.transforms.Compose([
        # Per MONAI's RandSpatialCropd docs, with random_size=True roi_size is
        # the *minimum* crop size.
        monai.transforms.RandSpatialCropd(
            keys=keys, roi_size=(120, 120, 120), random_size=True
        ),
        # MONAI expands `mode` to one entry per key and rejects a sequence of a
        # different length. The former ['area', 'nearest'] (two modes for the
        # single 'img' key) therefore raised a ValueError when the pipeline was
        # built. 'nearest' was presumably meant for a label key; add it back
        # together with that key if labels are ever added to `keys`.
        monai.transforms.Resized(
            keys=keys, spatial_size=(128, 128, 128), mode='area'
        ),
        monai.transforms.RandFlipd(keys=keys, prob=0.5, spatial_axis=0),
        monai.transforms.RandFlipd(keys=keys, prob=0.5, spatial_axis=1),
        monai.transforms.RandFlipd(keys=keys, prob=0.5, spatial_axis=2),
        monai.transforms.RandAdjustContrastd(keys=keys, prob=0.7, gamma=(0.5, 2.5)),
        monai.transforms.RandShiftIntensityd(keys=keys, offsets=0.125, prob=0.7),
        monai.transforms.ToTensor(dtype=torch.float),
    ])
    transforms.set_random_state(seed)
    return transforms


def get_test_tfms(seed=42):
    """Build the evaluation pipeline, which has no random augmentation.

    The only step is ``monai.transforms.ToTensor(dtype=torch.float)``.

    Args:
        seed: Forwarded to ``Compose.set_random_state``. This mirrors
            ``get_train_tfms``; the single transform here is not a random one.

    Returns:
        A ``monai.transforms.Compose`` wrapping the tensor conversion.
    """
    to_float_tensor = monai.transforms.ToTensor(dtype=torch.float)
    pipeline = monai.transforms.Compose([to_float_tensor])
    pipeline.set_random_state(seed)
    return pipeline

def get_datasets(val_split=0.05, seed=42, *,
                 data_dir=Path('/dhc/home/youngbin.ko/brain_data/data')):
    """Build the BraTS training dataset from per-patient sub-folders.

    Every ``BraTS20_Training_*`` sub-folder of
    ``<data_dir>/train_for_guassian_120`` is searched for its T1 file
    (``*_t1.nii.gz``). ``BratsDataset`` derives the T2, T1ce and FLAIR paths
    from each T1 path.

    Args:
        val_split: Currently unused; no validation split is created. Kept for
            backward compatibility with existing callers.
        seed: Seed forwarded to the transform pipeline.
        data_dir: Root data directory. Defaults to the previously hard-coded
            location.

    Returns:
        A ``BratsDataset`` over the training patients, using the
        ``get_test_tfms`` pipeline (tensor conversion only, no augmentation).
    """
    train_root = Path(data_dir) / 'train_for_guassian_120'
    # Only collect T1 files. BratsDataset treats every path as a T1 volume and
    # loads the other modalities from the same folder by replacing the
    # 't1.nii.gz' suffix. Globbing '*.nii.gz' would therefore also pick up the
    # t2/t1ce/flair files, and each of them would become a bogus sample.
    # Sorting makes the sample order independent of filesystem iteration order.
    train_paths = sorted(
        t1_path
        for subfolder in train_root.glob('BraTS20_Training_*')
        for t1_path in subfolder.glob('*_t1.nii.gz')
    )
    return BratsDataset(train_paths, get_test_tfms(seed=seed))


class BratsDataset(Dataset):
    """Map-style dataset returning the stacked BraTS modalities of one patient.

    Each entry of ``paths`` must point at a patient's T1 file, i.e. a name
    ending in ``t1.nii.gz``. The T2, T1ce and FLAIR files are looked up in the
    same folder, under the same name with ``t1.nii.gz`` replaced by
    ``t2.nii.gz`` / ``t1ce.nii.gz`` / ``flair.nii.gz``.

    Args:
        paths: Indexable collection of path objects. The code uses ``.parent``,
            ``.name`` and ``.stem``, so ``pathlib.Path`` objects are assumed.
        tfms: Optional callable applied to ``{'img': img}``. Its ``'img'`` entry
            replaces the stacked image.
    """

    def __init__(self, paths, tfms=None):
        self.paths = paths
        self.tfms = tfms

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        """Return ``(img, patient_id)`` for the ``idx``-th T1 path.

        Channels are stacked in the order [t1, t2, t1ce, flair]. Each channel
        is loaded and min-max scaled by ``load_and_scale``.
        """
        t1_file = self.paths[idx]
        folder, t1_name = t1_file.parent, t1_file.name
        modality_files = [t1_file] + [
            folder / t1_name.replace('t1.nii.gz', suffix)
            for suffix in ('t2.nii.gz', 't1ce.nii.gz', 'flair.nii.gz')
        ]

        # e.g. 'BraTS20_Training_001_t1.nii.gz' -> stem 'BraTS20_Training_001_t1.nii'
        # -> 'BraTS20_Training_001'
        patient_id = t1_file.stem.split('_t1')[0]

        img = np.stack([load_and_scale(f) for f in modality_files])

        if self.tfms:
            img = self.tfms({'img': img})['img']

        return img, patient_id


def load_and_scale(path):
    """Load an array with ``np.load`` and min-max scale it to ``[0, 1]``.

    Args:
        path: Anything accepted by ``np.load`` (file path or file object).

    Returns:
        ``(img - img.min()) / (img.max() - img.min())``. A constant array
        (max == min) yields all zeros instead of the NaNs produced by a 0/0
        division.

    Raises:
        ValueError: From ``min()``/``max()`` if the loaded array is empty.
    """
    img = np.load(path)
    lo = img.min()
    span = img.max() - lo
    if span == 0:
        # Constant volume: dividing by 1 maps every voxel to 0 and keeps the
        # same result dtype as the normal branch, instead of producing
        # NaNs (0/0) that would propagate into training.
        span = 1
    return (img - lo) / span


if __name__ == '__main__':
    # Smoke test: build the training dataset, then index one sample. Indexing
    # runs loading, scaling, stacking and the transform pipeline, so any
    # failure raises before the message below is printed.
    train_dataset = get_datasets()
    train_dataset[0]
    print('succes')


In [None]:
# Collect the T1 file of every BraTS validation patient.
# BratsDataset treats each path as a T1 volume and derives the t2/t1ce/flair
# paths from it. Globbing '*.nii.gz' would also list those other modality files
# as separate entries, so only '*_t1.nii.gz' files are gathered. Sorting makes
# the order reproducible across filesystems.
test_paths = []
brats_dir = Path('/dhc/home/youngbin.ko/brain_data/data/val')
for subfolder in sorted(brats_dir.glob('BraTS20_Validation_*')):
    print(subfolder)
    test_paths.extend(sorted(subfolder.glob('*_t1.nii.gz')))