In [13]:
import glob
import os

import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
from sklearn.preprocessing import MinMaxScaler
from tensorflow.keras.utils import to_categorical

In [14]:
# Shared scaler instance; (0, 1) is MinMaxScaler's default target range, spelled out here.
scaler = MinMaxScaler(feature_range=(0, 1))

In [38]:
# Dataset locations.
# Forward slashes are used as separators: they are accepted on Windows as well as on
# Linux/macOS, and they avoid backslash sequences such as "\B", "\M", "\i" and "\m",
# which are invalid escape sequences in a normal (non-raw) string literal.
# The output paths end with a separator because file names are appended to them directly.
DATA_ROOT = "./BraTS20/BraTS2020_TrainingData"
OUTPUT_ROOT = f"{DATA_ROOT}/input_data_4channels"

TRAIN_DATASET_PATH = f"{DATA_ROOT}/MICCAI_BraTS2020_TrainingData"
TRAIN_OUTPUT_IMAGE_PATH = f"{OUTPUT_ROOT}/train/images/"
TRAIN_OUTPUT_MASK_PATH = f"{OUTPUT_ROOT}/train/masks/"
VAL_OUTPUT_IMAGE_PATH = f"{OUTPUT_ROOT}/val/images/"
VAL_OUTPUT_MASK_PATH = f"{OUTPUT_ROOT}/val/masks/"

In [39]:
import os

# Create every output folder up front; exist_ok=True makes re-running this cell harmless.
for output_dir in (
    TRAIN_OUTPUT_IMAGE_PATH,
    TRAIN_OUTPUT_MASK_PATH,
    VAL_OUTPUT_IMAGE_PATH,
    VAL_OUTPUT_MASK_PATH,
):
    os.makedirs(output_dir, exist_ok=True)

In [40]:
# Collect the file paths of each modality and of the segmentation masks.
# Every list is sorted so that the same index refers to the same case in all five lists.
def _find_volumes(suffix):
    """Return the sorted paths of all ``*<suffix>.nii`` files one folder below TRAIN_DATASET_PATH."""
    return sorted(glob.glob(f"{TRAIN_DATASET_PATH}/*/*{suffix}.nii"))


t1_list = _find_volumes("t1")
t2_list = _find_volumes("t2")
t1ce_list = _find_volumes("t1ce")
flair_list = _find_volumes("flair")
# NOTE: the mask of training case 355 has to be renamed (Seg -> seg) to match this pattern.
mask_list = _find_volumes("seg")

# Fail fast when the lists do not line up. They are paired purely by position, so one
# missing or misnamed file would silently pair every later image with the wrong mask.
_volume_lists = (t1_list, t2_list, t1ce_list, flair_list, mask_list)
if len({len(paths) for paths in _volume_lists}) != 1:
    raise ValueError(
        "Modality/mask file counts differ (t1, t2, t1ce, flair, seg): "
        f"{[len(paths) for paths in _volume_lists]}"
    )
for _case_paths in zip(*_volume_lists):
    # All five files of one case are expected to sit in the same case folder.
    if len({os.path.dirname(path) for path in _case_paths}) != 1:
        raise ValueError(f"Files from different case folders were paired: {_case_paths}")


In [41]:
# Show how many files were found, in the order: t1, t1ce, t2, flair, mask.
print(*map(len, (t1_list, t1ce_list, t2_list, flair_list, mask_list)))

369 369 369 369 369


In [42]:
# Crop window: (start, stop) index range kept along each of the first three axes.
_CROP_RANGES = ((56, 184), (56, 184), (13, 141))
CROP_BOUNDARIES = tuple(slice(start, stop) for start, stop in _CROP_RANGES)

**Why Crop?**
- **Focus on Relevant Areas:** By cropping, you ensure that the model processes only the most relevant portions of the image, which can improve training efficiency.
- **Memory and Computation:** 3D medical images are large and require significant computational resources. Cropping reduces the size of the data, making it more manageable for GPU memory and speeding up training.
- **Patch Size Compatibility:** Many deep learning architectures, especially those dealing with 3D data, require input sizes divisible by a specific number (such as 64). Cropping each 240×240×155 volume to 128×128×128, as done in this notebook, ensures the volume meets that requirement.

**Drawbacks of Cropping:**
- **Risk of Losing Important Information:** If the cropping is too aggressive or poorly designed, important parts of the image (like a tumor) could be excluded, which could negatively impact the model's performance.
- **Assumption of Consistency:** Cropping assumes that the region of interest (ROI) is consistently located within the same region across all images. This might not hold true for all patients or images, potentially leading to information loss.


In [43]:
# Load a NIfTI file and min-max scale its voxel data
def load_and_scale_image(filepath):
    """Load the NIfTI file at ``filepath`` and return its data scaled by the shared ``scaler``.

    The data is flattened to 2D while keeping its last axis as the columns, passed
    through ``scaler.fit_transform`` (the scaler is refit on every call), and then
    restored to the shape it was loaded with.

    NOTE(review): MinMaxScaler scales each column independently, so every index
    along the last axis gets its own min/max rather than one min/max for the
    whole volume -- confirm that this is the intended normalisation.
    """
    volume = nib.load(filepath).get_fdata()
    original_shape = volume.shape

    columns = volume.reshape(-1, original_shape[-1])
    rescaled = scaler.fit_transform(columns)
    return rescaled.reshape(original_shape)

In [45]:
import torch
from colorama import Fore, Style

def save_as_torch_tensor(tensor, path):
    """Write ``tensor`` to ``path`` with ``torch.save``; returns nothing."""
    torch.save(obj=tensor, f=path)

# Process and save each image-mask pair as Torch tensors.
NUM_CLASSES = 4            # one-hot depth; labels are 0-3 once label 4 has been remapped to 3
MIN_NON_ZERO_RATIO = 0.01  # a case is kept only if its cropped mask exceeds this labelled fraction
TRAIN_FRACTION = 0.8       # leading share of the cases that is written to the training folders


def _non_zero_ratio(label_volume):
    """Return the fraction of voxels in ``label_volume`` whose label is not 0.

    The non-zero voxels are counted directly. Taking the first count returned by
    ``np.unique`` is only correct when label 0 actually occurs in the volume;
    otherwise that count belongs to the smallest foreground label.
    An empty volume yields 0.0.
    """
    if label_volume.size == 0:
        return 0.0
    return np.count_nonzero(label_volume) / label_volume.size


# Loop-invariant split point, derived from the same list that drives the loop:
# cases whose index is below it are saved as training data, the rest as validation data.
train_cutoff = len(t1_list) * TRAIN_FRACTION

# The lists are indexed (rather than zipped) so that a list that is too short raises an
# IndexError instead of silently dropping cases.
for img_idx in range(len(t1_list)):
    print(Fore.CYAN + f"Processing image and mask number: {img_idx}" + Style.RESET_ALL)

    # Load the four modalities, each scaled by load_and_scale_image.
    image_t1 = load_and_scale_image(t1_list[img_idx])
    image_t2 = load_and_scale_image(t2_list[img_idx])
    image_t1ce = load_and_scale_image(t1ce_list[img_idx])
    image_flair = load_and_scale_image(flair_list[img_idx])
    print(f"Loaded images with shapes: {image_t1.shape}, {image_t2.shape}, {image_t1ce.shape}, {image_flair.shape}")

    # Load the segmentation mask and cast its labels to uint8.
    mask = nib.load(mask_list[img_idx]).get_fdata().astype(np.uint8)
    print(f"Loaded mask with shape: {mask.shape}" )

    # Remap label 4 to 3 so the labels fit into NUM_CLASSES one-hot channels.
    mask[mask == 4] = 3

    # Stack the modalities as channels on a new last axis (order: t1, flair, t1ce, t2).
    combined_image = np.stack([image_t1, image_flair, image_t1ce, image_t2], axis=-1)
    print(f"Combined image shape: {combined_image.shape}")

    # Crop image and mask with the same window so they stay aligned.
    cropped_image = combined_image[CROP_BOUNDARIES]
    cropped_mask = mask[CROP_BOUNDARIES]
    print(f"Cropped image shape: {cropped_image.shape}, Cropped mask shape: {cropped_mask.shape}")

    # Skip cases whose cropped mask contains too few labelled voxels.
    non_zero_ratio = _non_zero_ratio(cropped_mask)
    print(f"Non-zero ratio: {non_zero_ratio}")
    if non_zero_ratio <= MIN_NON_ZERO_RATIO:
        print(Fore.RED + f"Skipped image {img_idx} due to low non-zero ratio." + Style.RESET_ALL)
        continue

    # One-hot encode the mask (adds a class axis at the end).
    categorical_mask = to_categorical(cropped_mask, num_classes=NUM_CLASSES)
    print(f"Categorical mask shape: {categorical_mask.shape}")

    # Convert to float Torch tensors and move the last (channel/class) axis to the front.
    torch_image = torch.from_numpy(cropped_image).permute(3, 0, 1, 2).float()
    torch_mask = torch.from_numpy(categorical_mask).permute(3, 0, 1, 2).float()

    # Pick the destination folders, then save both tensors under the case index.
    if img_idx < train_cutoff:
        image_dir, mask_dir = TRAIN_OUTPUT_IMAGE_PATH, TRAIN_OUTPUT_MASK_PATH
    else:
        image_dir, mask_dir = VAL_OUTPUT_IMAGE_PATH, VAL_OUTPUT_MASK_PATH
    save_as_torch_tensor(torch_image, f'{image_dir}image_{img_idx}.pt')
    save_as_torch_tensor(torch_mask, f'{mask_dir}mask_{img_idx}.pt')
    print(Fore.GREEN + f"Saved image and mask for index {img_idx}." + Style.RESET_ALL)


[36mProcessing image and mask number: 0[0m
Loaded images with shapes: (240, 240, 155), (240, 240, 155), (240, 240, 155), (240, 240, 155)
Loaded mask with shape: (240, 240, 155)
Combined image shape: (240, 240, 155, 4)
Cropped image shape: (128, 128, 128, 4), Cropped mask shape: (128, 128, 128)
Non-zero ratio: 0.10062837600708008
Categorical mask shape: (128, 128, 128, 4)
[32mSaved image and mask for index 0.[0m
[36mProcessing image and mask number: 1[0m
Loaded images with shapes: (240, 240, 155), (240, 240, 155), (240, 240, 155), (240, 240, 155)
Loaded mask with shape: (240, 240, 155)
Combined image shape: (240, 240, 155, 4)
Cropped image shape: (128, 128, 128, 4), Cropped mask shape: (128, 128, 128)
Non-zero ratio: 0.031950950622558594
Categorical mask shape: (128, 128, 128, 4)
[32mSaved image and mask for index 1.[0m
[36mProcessing image and mask number: 2[0m
Loaded images with shapes: (240, 240, 155), (240, 240, 155), (240, 240, 155), (240, 240, 155)
Loaded mask with shape