<a href="https://colab.research.google.com/github/ahmrina/UNet-for-Brain-Tumor-Segmentation/blob/main/Preprocess_BraTS20_.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Preprocessing BraTS20 dataset

In [None]:
import os
from google.colab import drive
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
import numpy as np
from nibabel.testing import data_path
import nibabel as nib
from tqdm.auto import tqdm

In [None]:
# Mount Google Drive at /content/drive so the dataset paths used below resolve.
# This only works inside Google Colab (drive comes from google.colab).
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# Project root on the mounted Drive; both dataset folders live underneath it.
path = '/content/drive/My Drive/MRI_segmentation_UNet'
training_set = f'{path}/BraTS2020_TrainingData/MICCAI_BraTS2020_TrainingData'
validating_set = f'{path}/BraTS2020_ValidationData/MICCAI_BraTS2020_ValidationData'

# Create the project root if it is missing (no-op when it already exists).
os.makedirs(path, exist_ok=True)

In [None]:
class load_BraTS20():
  """Loads BraTS20 patient folders whose scans are stored as ``.nii`` files.

  Each patient folder ``<patient>`` is expected to contain files named
  ``<patient>_<modality>.nii``. The volumes of all selected patients are
  concatenated along axis 0, so callers have to reshape the result to
  recover a per-patient axis.
  """

  def __init__(self):
    pass

  def load_nifti(self, file_path):
    """Return the voxel data of the NIfTI file at ``file_path`` (via ``get_fdata()``)."""
    img = nib.load(file_path)
    return img.get_fdata()

  def _load_modalities(self, path, start, end, modalities):
    """Load the given modalities for the entries ``sorted(os.listdir(path))[start:end]``.

    Entries that are not directories are skipped silently. Returns a list
    with one array per modality, in the order of ``modalities``; each array
    is the axis-0 concatenation of the loaded volumes, or an empty array
    when no patient folder was found in the selected range.
    """
    # Collect volumes in lists and concatenate once at the end; concatenating
    # inside the loop would re-copy the growing array for every patient.
    volumes = {modality: [] for modality in modalities}

    for patient in sorted(os.listdir(path))[start:end]:
      patient_dir = os.path.join(path, patient)
      if not os.path.isdir(patient_dir):
        continue

      for modality in modalities:
        file_path = os.path.join(patient_dir, f'{patient}_{modality}.nii')
        volumes[modality].append(self.load_nifti(file_path))
      print(patient_dir)

    return [
        np.concatenate(volumes[modality], axis=0) if volumes[modality] else np.array([])
        for modality in modalities
    ]

  def load_training(self, path, start, end):
    """Load t1ce, t2, flair and segmentation volumes for the selected patients.

    Args:
      path: directory containing one sub-folder per patient.
      start, end: slice bounds (end exclusive) into the sorted directory listing.

    Returns:
      Tuple ``(t1ce, t2, flair, mask)`` of arrays concatenated along axis 0.
    """
    t1ce_train, t2_train, flair_train, mask_train = self._load_modalities(
        path, start, end, ('t1ce', 't2', 'flair', 'seg'))
    return t1ce_train, t2_train, flair_train, mask_train

  def load_validation(self, path, start, end):
    """Load t1ce, t2 and flair volumes for the selected patients.

    No segmentation masks are read for the validation set.

    Args:
      path: directory containing one sub-folder per patient.
      start, end: slice bounds (end exclusive) into the sorted directory listing.

    Returns:
      Tuple ``(t1ce, t2, flair)`` of arrays concatenated along axis 0.
    """
    t1ce_val, t2_val, flair_val = self._load_modalities(
        path, start, end, ('t1ce', 't2', 'flair'))
    return t1ce_val, t2_val, flair_val


In [None]:
# Loader instance used by the load_training / load_validation calls below.
dataset = load_BraTS20()

In [None]:
# Load one chunk of the training set: entries 300..368 of the sorted directory
# listing (end is exclusive). Each returned array holds the selected patients'
# volumes concatenated along axis 0.
t1ce_train, t2_train, flair_train, mask_train = dataset.load_training(training_set, start = 300, end = 369)

In [None]:
# Load one chunk of the validation set: entries 62..124 of the sorted directory
# listing (end is exclusive). The validation loader returns no segmentation masks.
t1ce_val, t2_val, flair_val = dataset.load_validation(validating_set, start = 62, end = 125)

/content/drive/My Drive/MRI_segmentation_UNet/BraTS2020_ValidationData/MICCAI_BraTS2020_ValidationData/BraTS20_Validation_063
/content/drive/My Drive/MRI_segmentation_UNet/BraTS2020_ValidationData/MICCAI_BraTS2020_ValidationData/BraTS20_Validation_064
/content/drive/My Drive/MRI_segmentation_UNet/BraTS2020_ValidationData/MICCAI_BraTS2020_ValidationData/BraTS20_Validation_065
/content/drive/My Drive/MRI_segmentation_UNet/BraTS2020_ValidationData/MICCAI_BraTS2020_ValidationData/BraTS20_Validation_066
/content/drive/My Drive/MRI_segmentation_UNet/BraTS2020_ValidationData/MICCAI_BraTS2020_ValidationData/BraTS20_Validation_067
/content/drive/My Drive/MRI_segmentation_UNet/BraTS2020_ValidationData/MICCAI_BraTS2020_ValidationData/BraTS20_Validation_068
/content/drive/My Drive/MRI_segmentation_UNet/BraTS2020_ValidationData/MICCAI_BraTS2020_ValidationData/BraTS20_Validation_069
/content/drive/My Drive/MRI_segmentation_UNet/BraTS2020_ValidationData/MICCAI_BraTS2020_ValidationData/BraTS20_Validat

In [None]:
# The loader concatenates patient volumes along axis 0, so split that axis back
# into (patient, 240, 240, 155). Passing -1 lets NumPy infer the patient count
# from the data instead of hard-coding it, which would raise a ValueError
# whenever the number of loaded patients differs from the literal.
t1ce_train = t1ce_train.reshape(-1, 240, 240, 155)
t2_train = t2_train.reshape(-1, 240, 240, 155)
flair_train = flair_train.reshape(-1, 240, 240, 155)
mask_train = mask_train.reshape(-1, 240, 240, 155)

In [None]:
# Combine the three modalities into a trailing channel axis:
# (patient, 240, 240, 155, 3).
# The channel axis must be inserted LAST (axis=-1). With 4-D inputs, axis=3
# produces (patient, 240, 240, 3, 155); reshaping that to (..., 155, 3) does not
# transpose anything, it just reinterprets memory and mixes modalities with
# depth slices. The reshape with -1 is a no-op for 4-D inputs and also handles
# inputs that are still concatenated along axis 0.
multi_channel_img = np.stack([t1ce_train, t2_train, flair_train], axis = -1).reshape(-1, 240, 240, 155, 3)
# plt.imshow(multi_channel_img[0, :, :, 115, 2], cmap = 'gray')

In [None]:
# Sanity check: images should be (patients, 240, 240, 155, 3) and masks
# (patients, 240, 240, 155), with matching patient counts.
multi_channel_img.shape, mask_train.shape

((68, 240, 240, 155, 3), (68, 240, 240, 155))

## Preprocess Training Portion

In [None]:
def normalize(img):
  """Min-max scale ``img`` so that its values lie between 0 and 1.

  The minimum and maximum are taken over the whole array, so every element
  is scaled with the same two numbers.

  Args:
    img: NumPy array with at least one element.

  Returns:
    Array of the same shape as ``img``. A constant input (max == min) yields
    an all-zero float array instead of the NaNs a 0/0 division would produce.
  """
  lo = img.min()
  span = img.max() - lo
  if span == 0:
    # Every value equals the minimum; map them all to 0 rather than dividing by zero.
    return np.zeros(img.shape, dtype=float)
  return (img - lo) / span

def crop_image(img):
  """Crop a batch to the window 56:184 x 56:184 x 13:141 on axes 1, 2 and 3.

  Axis 0 is kept whole. A 5-D input additionally keeps its trailing axis
  whole; any other input is indexed with the four leading slices only.
  For 240 x 240 x 155 volumes the cropped block is 128 x 128 x 128.
  """
  window = (slice(None), slice(56, 184), slice(56, 184), slice(13, 141))
  if img.ndim == 5:
    window = window + (slice(None),)
  return img[window]

def to_categorial(a, num_classes=4):
  """One-hot encode an integer label array along a new trailing axis.

  The number of channels is fixed by ``num_classes`` rather than derived from
  the labels present in ``a``; deriving it from the data fails for any mask
  that happens to lack one of the classes.

  Args:
    a: array of labels in ``0 .. num_classes - 1``, of any rank.
    num_classes: number of output channels (default 4, i.e. labels 0-3).

  Returns:
    Float array of shape ``a.shape + (num_classes,)`` where channel ``k`` is 1
    wherever ``a == k`` and 0 elsewhere. Labels outside the range are encoded
    as all zeros.
  """
  a = np.asarray(a)
  ohe = np.zeros(a.shape + (num_classes,))
  for label in range(num_classes):
    ohe[..., label] = (a == label)
  return ohe

In [None]:
# Remap label 4 to 3 so the mask labels become the contiguous set [0, 1, 2, 3].
mask_train = np.where(mask_train == 4, 3, mask_train) # turning labels from [0, 1, 2, 4] to [0, 1, 2, 3]
# Min-max scale the stacked image array. normalize() uses a single min/max
# taken over the whole array, i.e. shared by all patients and all channels.
normalized_img = normalize(multi_channel_img)

In [None]:
# Crop every image volume to the fixed window defined in crop_image (channel axis kept whole).
cropped_multi_channel_img = crop_image(normalized_img)

In [None]:
# Apply the same crop window to the masks so they stay aligned with the images.
cropped_mask = crop_image(mask_train)

In [None]:
# Sanity check: expect (patients, 128, 128, 128, 3) images and (patients, 128, 128, 128) masks.
cropped_multi_channel_img.shape, cropped_mask.shape

((68, 128, 128, 128, 3), (68, 128, 128, 128))

In [None]:
# Output files for this training chunk (image stack and mask stack); both share
# the same Drive folder and the same "_7" suffix.
save_path_img, save_path_mask = (
    f'/content/drive/My Drive/MRI_segmentation_UNet/{stem}_7.npy'
    for stem in ('cropped_multi_channel_img', 'cropped_mask')
)

In [None]:
# Write the cropped, normalized image stack to Drive as a .npy file.
np.save(save_path_img, cropped_multi_channel_img)

In [None]:
# Write the cropped mask stack to Drive as a .npy file.
np.save(save_path_mask, cropped_mask)

In [None]:
# Read the two files back from Drive (presumably a round-trip check of the
# saved arrays; neither variable is used again in this notebook).
train_split = np.load(save_path_img)
mask_split = np.load(save_path_mask)

## Preprocess Validation Portion

In [None]:
# The validation arrays are still concatenated along axis 0 at this point
# (they are never reshaped per patient above), so append the channel axis last
# and then split axis 0 into (patient, 240). Using -1 infers the patient count
# instead of hard-coding it, so a different start/end range does not raise.
multi_channel_img_val = np.stack([t1ce_val, t2_val, flair_val], axis = -1).reshape(-1, 240, 240, 155, 3)

In [None]:
# Min-max scale the validation stack (one min/max over the whole array), then
# crop it with the same window used for the training data.
# Note: this overwrites multi_channel_img_val with its normalized version.
multi_channel_img_val = normalize(multi_channel_img_val)
cropped_multi_channel_img_val = crop_image(multi_channel_img_val)

In [None]:
# Output file on Drive for this validation chunk.
save_path_img_val = '/content/drive/My Drive/MRI_segmentation_UNet/cropped_multi_channel_val_2.npy'

In [None]:
# Write the cropped, normalized validation stack to Drive as a .npy file.
np.save(save_path_img_val, cropped_multi_channel_img_val)

In [None]:
# Read the file back from Drive (presumably a round-trip check of the saved array).
valid_split = np.load(save_path_img_val)