# Split Data

In [None]:
# Root folder of the dataset; every entry directly beneath it is listed.
imgs_dir = './BraTs'

# Sorted so the listing order does not depend on the filesystem.
all_dirs = sorted(glob(f'{imgs_dir}/*'))

# Notebook display: number of entries and a preview of the first ten.
len(all_dirs), all_dirs[:10]

In [None]:
def shuffle_split(all_dirs, val_pct=0.15, seed=99):
    """Shuffle ``all_dirs`` reproducibly and split it into train/validation parts.

    Args:
        all_dirs: Sequence of items to split (anything ``np.array`` accepts).
        val_pct: Fraction of the items placed in the validation part; its size
            is ``int(len(all_dirs) * val_pct)``.
        seed: Value passed to ``np.random.seed`` before drawing the permutation.

    Returns:
        Tuple ``(train, valid)`` of NumPy arrays. The first ``n_val`` shuffled
        items form ``valid``; the remaining items form ``train``.

    Note:
        Seeds NumPy's global random state as a side effect.
    """
    total = len(all_dirs)
    holdout = int(total * val_pct)

    np.random.seed(seed)
    order = np.random.permutation(total)
    shuffled = np.array(all_dirs)[order]

    train_part = shuffled[holdout:]
    valid_part = shuffled[:holdout]
    return train_part, valid_part

# Fixed seed so the train/validation partition is the same on every run.
train_dirs, valid_dirs = shuffle_split(all_dirs, seed=1)

# Notebook display: validation size first, then training size.
len(valid_dirs), len(train_dirs)

# Data Class

In [None]:
class BratsDataset(Dataset):
    """Dataset yielding a stacked multi-modality image and a 4-channel mask.

    Each entry of ``img_dirs`` is treated as a directory that contains one
    ``<modality>.jpg`` file per requested modality plus a ``seg.jpg`` mask.
    """

    def __init__(self, img_dirs, modality_types, transform=None):
        """Store the sample directories, modality names and optional transform.

        Args:
            img_dirs: Indexable collection of sample directory paths.
            modality_types: Names of the modality images to load; the stacked
                image follows this order.
            transform: Optional callable invoked as
                ``transform(image=..., mask=...)`` that returns a mapping with
                ``'image'`` and ``'mask'`` entries.
        """
        self.img_dirs = img_dirs
        # Kept on the instance so concat_imgs uses the modalities given to this
        # dataset instead of depending on a module-level global.
        self.modality_types = list(modality_types)
        self.transform = transform

    def __len__(self):
        """Return the number of sample directories."""
        return len(self.img_dirs)

    def __getitem__(self, index):
        """Load the image stack and mask for the sample at ``index``.

        Returns:
            Tuple ``(image, mask)``, each cast with ``astype(float)``.
        """
        imgs_path = self.img_dirs[index]
        image = self.concat_imgs(imgs_path)

        with Image.open(f'{imgs_path}/seg.jpg') as seg_file:
            mask = np.array(seg_file)
        # Map pixel intensities (assumed to span 0..255) onto labels 0..4.
        mask = (mask / 255 * 4).round()
        mask = self.preprocess_mask_labels(mask)

        if self.transform is not None:
            # NOTE(review): image and mask are stacked along axis 0 at this
            # point, and the return below calls .astype, which needs NumPy
            # arrays. Confirm the transform's expected layout and output type.
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']

        return image.astype(float), mask.astype(float)

    def concat_imgs(self, path: str):
        """Load and normalise every modality image found under ``path``.

        Args:
            path: Directory holding one ``<modality>.jpg`` per modality.

        Returns:
            Array stacking the normalised images along a new leading axis, in
            the order of ``self.modality_types``.
        """
        channels = []
        for modality_type in self.modality_types:
            with Image.open(f'{path}/{modality_type}.jpg') as img_file:
                img = np.array(img_file)
            channels.append(self.normalize(img))
        return np.array(channels)

    def preprocess_mask_labels(self, mask: np.ndarray):
        """Convert a label mask into four stacked binary channels.

        Channel order along axis 0 is label 2 (WT), label 1 (TC), label 3 (ET)
        and label 0 (BG). Each channel is 1.0 where the mask equals that label
        and 0.0 elsewhere.

        NOTE(review): values other than 0-3 (for example 4, which the rescaling
        in ``__getitem__`` can produce) fall into no channel. Confirm this is
        intended.
        """
        channel_labels = (2, 1, 3, 0)
        return np.stack(
            [(mask == label).astype(float) for label in channel_labels]
        )

    def normalize(self, data: np.ndarray):
        """Min-max scale ``data`` to the range [0, 1].

        Special cases:
            * maximum equal to 0: ``data`` is returned unchanged;
            * constant non-zero array: ``data`` is divided by that constant.
        """
        data_min = np.min(data)
        data_max = np.max(data)
        if data_max == 0:
            return data

        value_range = data_max - data_min
        if value_range == 0:
            return data / data_min

        return (data - data_min) / value_range

# Data transformation

In [None]:
# Transformation
# Training pipeline: resize, always rotate (up to 35 degrees), random flips,
# normalise, then convert to a tensor.
trn_tfms = A.Compose(
    [
        A.Resize(height=240, width=240),
        A.Rotate(limit=35, p=1.0),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.1),
        # Normalize computes:
        #   img = (img - mean * max_pixel_value) / (std * max_pixel_value)
        A.Normalize(mean=0.5, std=0.5, max_pixel_value=255.0),
        ToTensorV2(),
    ]
)


# Validation pipeline: resize, normalise and convert to a tensor, with no
# random augmentation.
val_tfms = A.Compose(
    [
        A.Resize(height=240, width=240),
        A.Normalize(mean=0.5, std=0.5, max_pixel_value=255.0),
        ToTensorV2(),
    ]
)

# Generate Data

In [None]:
# Modality image names loaded for every sample, in stacking order.
modality_types = ['flair', 't1', 't1ce', 't2']

batch_size = 64  # total batch size

# Both datasets are built without a transform.
train_ds = BratsDataset(train_dirs, modality_types)
valid_ds = BratsDataset(valid_dirs, modality_types)

# NOTE(review): the training loader keeps a fixed order (shuffle=False).
# Confirm this is intended.
train_dl = DataLoader(
    train_ds,
    batch_size=batch_size,
    shuffle=False,
    num_workers=12,
    pin_memory=True,
)
valid_dl = DataLoader(
    valid_ds,
    batch_size=batch_size,
    shuffle=False,
    num_workers=12,
    pin_memory=True,
)