In [None]:
import numpy as np
import nibabel as nib
import glob
from tensorflow.keras.utils import to_categorical
import matplotlib.pyplot as plt
from tifffile import imsave
from sklearn.preprocessing import MinMaxScaler
import os
import random

# Shared MinMaxScaler; it is re-fitted on every volume that is scaled below
scaler = MinMaxScaler()

# Root folder of the BraTS2020 training data
TRAIN_DATASET_PATH = './brats20-dataset-training-validation/BraTS2020_TrainingData/MICCAI_BraTS2020_TrainingData/'

# Case 355 may still carry a non-standard segmentation file name; give it the
# `_seg.nii` name used by the other cases (skipped when the old name is absent)
seg_file_path = TRAIN_DATASET_PATH + 'BraTS20_Training_355/W39_1998.09.19_Segm.nii'
if os.path.exists(seg_file_path):
    renamed_seg_path = TRAIN_DATASET_PATH + 'BraTS20_Training_355/BraTS20_Training_355_seg.nii'
    os.rename(seg_file_path, renamed_seg_path)


def _load_scaled_volume(nii_path):
    """Load a NIfTI volume and min-max scale it with the shared ``scaler``.

    The volume is flattened to 2-D (last axis as columns) because
    ``fit_transform`` expects a 2-D array, then restored to its original shape.
    MinMaxScaler scales each column separately, so every index along the last
    axis gets its own min/max.
    """
    volume = nib.load(nii_path).get_fdata()
    as_columns = volume.reshape(-1, volume.shape[-1])
    return scaler.fit_transform(as_columns).reshape(volume.shape)


# Load the four modalities of the sample case, each scaled to [0, 1]
test_image_flair = _load_scaled_volume(TRAIN_DATASET_PATH + 'BraTS20_Training_355/BraTS20_Training_355_flair.nii')
test_image_t1 = _load_scaled_volume(TRAIN_DATASET_PATH + 'BraTS20_Training_355/BraTS20_Training_355_t1.nii')
test_image_t1ce = _load_scaled_volume(TRAIN_DATASET_PATH + 'BraTS20_Training_355/BraTS20_Training_355_t1ce.nii')
test_image_t2 = _load_scaled_volume(TRAIN_DATASET_PATH + 'BraTS20_Training_355/BraTS20_Training_355_t2.nii')

# Load the segmentation mask as uint8 labels
test_mask = nib.load(TRAIN_DATASET_PATH + 'BraTS20_Training_355/BraTS20_Training_355_seg.nii').get_fdata().astype(np.uint8)
# Relabel 4 -> 3 so the label values become contiguous (0..3)
test_mask[test_mask == 4] = 3

# Visualize one random slice (along the third axis) of the sample images and mask.
# random.randint includes BOTH endpoints, so the upper bound must be the last
# valid index; using shape[2] itself could raise IndexError.
n_slice = random.randint(0, test_mask.shape[2] - 1)

plt.figure(figsize=(12, 8))

# (subplot position, volume, title, colormap); the mask uses the default colormap
sample_panels = [
    (231, test_image_flair, 'Image flair', 'gray'),
    (232, test_image_t1, 'Image t1', 'gray'),
    (233, test_image_t1ce, 'Image t1ce', 'gray'),
    (234, test_image_t2, 'Image t2', 'gray'),
    (235, test_mask, 'Mask', None),
]
for panel_position, panel_volume, panel_title, panel_cmap in sample_panels:
    plt.subplot(panel_position)
    plt.imshow(panel_volume[:, :, n_slice], cmap=panel_cmap)
    plt.title(panel_title)
plt.show()

# Combine images into a single multichannel image.
# Only FLAIR, T1ce and T2 are stacked (as a new last axis); T1 is not included.
combined_x = np.stack([test_image_flair, test_image_t1ce, test_image_t2], axis=3)
# Crop to 128 x 128 x 128 voxels (a size divisible by 64)
combined_x = combined_x[56:184, 56:184, 13:141]

# Crop the mask with the same window so it stays aligned with the images
test_mask = test_mask[56:184, 56:184, 13:141]

# Visualize one random slice of the cropped images and mask.
# random.randint includes BOTH endpoints, so the upper bound must be the last
# valid index; using shape[2] itself could raise IndexError.
n_slice = random.randint(0, test_mask.shape[2] - 1)
plt.figure(figsize=(12, 8))

# Channel order follows the np.stack call above: 0=FLAIR, 1=T1ce, 2=T2
for channel_index, channel_title in enumerate(['Image flair', 'Image t1ce', 'Image t2']):
    plt.subplot(221 + channel_index)
    plt.imshow(combined_x[:, :, n_slice, channel_index], cmap='gray')
    plt.title(channel_title)
plt.subplot(224)
plt.imshow(test_mask[:, :, n_slice])
plt.title('Mask')
plt.show()

# Save the combined sample image as both TIFF and NumPy files.
# The target folder is created first so saving does not fail with
# FileNotFoundError on a fresh working directory.
# NOTE(review): the files are named "combined255" although the sample case
# loaded above is 355 -- confirm whether the name is intentional.
os.makedirs('BraTS2020_TrainingData', exist_ok=True)
imsave('BraTS2020_TrainingData/combined255.tif', combined_x)
np.save('BraTS2020_TrainingData/combined255.npy', combined_x)

# One-hot encode the (already loaded and cropped) mask into 4 classes
test_mask = to_categorical(test_mask, num_classes=4)

# Process all images in the dataset
t2_list = sorted(glob.glob(TRAIN_DATASET_PATH + '*/BraTS20_Training_*_t2.nii'))
t1ce_list = sorted(glob.glob(TRAIN_DATASET_PATH + '*/BraTS20_Training_*_t1ce.nii'))
flair_list = sorted(glob.glob(TRAIN_DATASET_PATH + '*/BraTS20_Training_*_flair.nii'))
mask_list = sorted(glob.glob(TRAIN_DATASET_PATH + '*/BraTS20_Training_*_seg.nii'))

# Volumes are paired purely by their position in the sorted lists, so a single
# missing file would silently pair images with the wrong mask. Fail early.
if not (len(t2_list) == len(t1ce_list) == len(flair_list) == len(mask_list)):
    raise ValueError(
        "Modality/mask file counts differ "
        f"(t2={len(t2_list)}, t1ce={len(t1ce_list)}, "
        f"flair={len(flair_list)}, seg={len(mask_list)}); "
        "cannot pair volumes by index."
    )

# np.save does not create missing folders, so make sure the targets exist
images_out_dir = TRAIN_DATASET_PATH + 'input_data_3channels/images/'
masks_out_dir = TRAIN_DATASET_PATH + 'input_data_3channels/masks/'
os.makedirs(images_out_dir, exist_ok=True)
os.makedirs(masks_out_dir, exist_ok=True)


def _load_min_max_scaled(nii_path):
    """Load a NIfTI volume and min-max scale it with the shared ``scaler``.

    The volume is flattened to 2-D (last axis as columns) for ``fit_transform``
    and reshaped back afterwards; the scaler is re-fitted on every call.
    """
    volume = nib.load(nii_path).get_fdata()
    return scaler.fit_transform(volume.reshape(-1, volume.shape[-1])).reshape(volume.shape)


for img, (t2_path, t1ce_path, flair_path, mask_path) in enumerate(
        zip(t2_list, t1ce_list, flair_list, mask_list)):
    print("Now preparing image and masks number:", img)

    # Load and scale the three modalities that are used as channels
    temp_image_t2 = _load_min_max_scaled(t2_path)
    temp_image_t1ce = _load_min_max_scaled(t1ce_path)
    temp_image_flair = _load_min_max_scaled(flair_path)

    # Load the mask as uint8 labels and relabel 4 -> 3 (contiguous 0..3)
    temp_mask = nib.load(mask_path).get_fdata()
    temp_mask = temp_mask.astype(np.uint8)
    temp_mask[temp_mask == 4] = 3

    # Combine images into a single multichannel image (FLAIR, T1ce, T2)
    temp_combined_images = np.stack([temp_image_flair, temp_image_t1ce, temp_image_t2], axis=3)
    # Crop images and mask to the same 128 x 128 x 128 window
    temp_combined_images = temp_combined_images[56:184, 56:184, 13:141]
    temp_mask = temp_mask[56:184, 56:184, 13:141]

    # Keep the volume only if more than 1% of its voxels carry a non-zero label.
    # Counting non-zero voxels directly stays correct even if label 0 happens
    # to be absent from the cropped mask.
    labelled_fraction = np.count_nonzero(temp_mask) / temp_mask.size
    if labelled_fraction > 0.01:
        print("Save Me")
        temp_mask = to_categorical(temp_mask, num_classes=4)
        np.save(images_out_dir + 'image_' + str(img) + '.npy', temp_combined_images)
        np.save(masks_out_dir + 'mask_' + str(img) + '.npy', temp_mask)
    else:
        print("I am useless")

# Split the saved volumes into training (75%) and validation (25%) folders
import splitfolders

input_folder = TRAIN_DATASET_PATH + 'input_data_3channels/'
output_folder = TRAIN_DATASET_PATH + 'input_data_128/'
splitfolders.ratio(
    input_folder,
    output=output_folder,
    seed=42,  # fixed seed
    ratio=(0.75, 0.25),
    group_prefix=None,
)


In [None]:
import os
import numpy as np
from matplotlib import pyplot as plt
import random

# Define function to load images from a directory
def load_img(img_dir, img_list):
    """Load the ``.npy`` files named in ``img_list`` from ``img_dir``.

    Args:
        img_dir: Directory that contains the files.
        img_list: File names (not full paths) to consider, in the desired order.

    Returns:
        A NumPy array built from the loaded arrays, in ``img_list`` order.
        Names that do not end in ``.npy`` are skipped, including names without
        any extension, which previously raised IndexError.
    """
    images = [
        np.load(os.path.join(img_dir, image_name))
        for image_name in img_list
        if image_name.endswith('.npy')
    ]
    return np.array(images)

# Define custom data generator
def imageLoader(img_dir, img_list, mask_dir, mask_list, batch_size):
    """Yield ``(images, masks)`` batches forever, cycling through the file lists.

    Each pass walks ``img_list`` and ``mask_list`` in steps of ``batch_size``
    and loads the same index range from both, so the two lists must be in
    matching order. The final batch of a pass may hold fewer samples.
    """
    total = len(img_list)
    # Endless outer loop: Keras requires the generator to be infinite
    while True:
        start = 0
        while start < total:
            # Clamp the end of the window to the number of available files
            stop = min(start + batch_size, total)
            X = load_img(img_dir, img_list[start:stop])
            Y = load_img(mask_dir, mask_list[start:stop])
            yield (X, Y)
            start += batch_size

# Define dataset paths for training images and masks
train_img_dir = "./brats20-dataset-training-validation/BraTS2020_TrainingData/MICCAI_BraTS2020_TrainingData/input_data_128/train/images/"
train_mask_dir = "./brats20-dataset-training-validation/BraTS2020_TrainingData/MICCAI_BraTS2020_TrainingData/input_data_128/train/masks/"

# Get list of all image and mask files.
# os.listdir returns names in arbitrary order, and the generator pairs images
# and masks by position, so both lists are sorted to keep them aligned.
# Assumes matching files differ only in their image_/mask_ prefix -- TODO confirm.
train_img_list = sorted(os.listdir(train_img_dir))
train_mask_list = sorted(os.listdir(train_mask_dir))

# Define batch size
batch_size = 2

# Initialize the custom data generator
train_img_datagen = imageLoader(train_img_dir, train_img_list, train_mask_dir, train_mask_list, batch_size)

# Test the generator by pulling one batch
img, msk = next(train_img_datagen)

# Randomly select an image and its corresponding mask from the batch
img_num = random.randint(0, img.shape[0] - 1)
test_img = img[img_num]
test_mask = msk[img_num]

# Convert mask from categorical (one-hot on the last axis) to original labels
test_mask = np.argmax(test_mask, axis=3)

# Randomly select a slice to visualize.
# random.randint includes BOTH endpoints, so the upper bound must be the last
# valid index; using shape[2] itself could raise IndexError.
n_slice = random.randint(0, test_mask.shape[2] - 1)
plt.figure(figsize=(12, 8))

# Visualize the different channels of the selected image slice and its corresponding mask
for channel_index, channel_title in enumerate(['Image flair', 'Image t1ce', 'Image t2']):
    plt.subplot(221 + channel_index)
    plt.imshow(test_img[:, :, n_slice, channel_index], cmap='gray')
    plt.title(channel_title)
plt.subplot(224)
plt.imshow(test_mask[:, :, n_slice])
plt.title('Mask')
plt.show()
