<a href="https://colab.research.google.com/github/Debayan2004/BR-Tumor-Segmentation/blob/main/input_data_preprocessing.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
!pip install SimpleITK

Collecting SimpleITK
  Downloading SimpleITK-2.4.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (7.9 kB)
Downloading SimpleITK-2.4.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (52.4 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m52.4/52.4 MB[0m [31m11.9 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: SimpleITK
Successfully installed SimpleITK-2.4.0


In [4]:
import os
import numpy as np
import nibabel as nib
import SimpleITK as sitk
from sklearn.model_selection import train_test_split

In [5]:
def load_nii_file(file_path):
    """Read a NIfTI volume from disk and return its voxel data.

    Args:
        file_path: path of the image file, passed straight to ``nib.load``.

    Returns:
        The voxel data as a NumPy array, obtained via ``get_fdata``.
    """
    volume = nib.load(file_path)
    voxels = volume.get_fdata()
    return voxels


In [7]:
def preprocess_data(ir_images, t1_images, flair_images, masks, target_shape=(128, 128, 128)):
    """
    Resample MRI volumes and masks to a common shape and normalize the images.

    Args:
        ir_images, t1_images, flair_images: sequences of NumPy volumes, one per
            subject, in NumPy axis order (e.g. (z, y, x)).
        masks: sequence of label volumes aligned with the images.
        target_shape: output shape per volume, in the same NumPy axis order.

    Returns:
        Tuple ``(ir, t1, flair, masks)``, each an array of shape
        ``(n_subjects, *target_shape)``. Image volumes are min-max scaled to
        [0, 1] (constant volumes become all zeros). Masks are resampled with
        nearest-neighbour interpolation, so they only contain label values
        already present in the input.
    """
    def resize_image(image, target_shape, interpolator=sitk.sitkLinear):
        """Resample a NumPy volume to ``target_shape`` (NumPy axis order)."""
        sitk_image = sitk.GetImageFromArray(image)
        # SimpleITK sizes are (x, y, z) while NumPy shapes are (z, y, x); take
        # the input size from SimpleITK and reverse the NumPy-ordered target.
        original_size = sitk_image.GetSize()
        target_size = [int(t_sz) for t_sz in reversed(target_shape)]
        # GetImageFromArray yields unit spacing, so this spacing makes the output
        # grid span the same physical extent as the input grid.
        output_spacing = [o_sz / t_sz for o_sz, t_sz in zip(original_size, target_size)]
        # Align the outer voxel corners of both grids (not the first voxel
        # centres) so the resampled volume stays centred on the input.
        # Assumes unit input spacing, as set by GetImageFromArray above.
        output_origin = sitk_image.TransformContinuousIndexToPhysicalPoint(
            [0.5 * (spacing - 1.0) for spacing in output_spacing]
        )

        resampler = sitk.ResampleImageFilter()
        resampler.SetOutputSpacing(output_spacing)
        resampler.SetSize(target_size)
        resampler.SetOutputDirection(sitk_image.GetDirection())
        resampler.SetOutputOrigin(output_origin)
        resampler.SetTransform(sitk.Transform())
        resampler.SetDefaultPixelValue(0)
        resampler.SetInterpolator(interpolator)
        return sitk.GetArrayFromImage(resampler.Execute(sitk_image))

    def normalize_image(image):
        """Min-max scale to [0, 1]; a constant volume maps to zeros instead of NaN."""
        lowest = np.min(image)
        value_range = np.max(image) - lowest
        if value_range == 0:
            return np.zeros_like(image, dtype=np.float64)
        return (image - lowest) / value_range

    def resize_and_normalize(volumes):
        """Linearly resample every volume to ``target_shape``, then min-max scale it."""
        return np.array([normalize_image(resize_image(vol, target_shape)) for vol in volumes])

    ir_images_normalized = resize_and_normalize(ir_images)
    t1_images_normalized = resize_and_normalize(t1_images)
    flair_images_normalized = resize_and_normalize(flair_images)

    # Linear interpolation would blend neighbouring labels into fractional or
    # wrong class ids; nearest neighbour keeps masks on the original label set.
    masks_resized = np.array(
        [resize_image(mask, target_shape, sitk.sitkNearestNeighbor) for mask in masks]
    )

    return ir_images_normalized, t1_images_normalized, flair_images_normalized, masks_resized




In [8]:
# Extract patches
def extract_patches(images, patch_size=(16, 16, 16)):
    """Tile each 3-D volume into non-overlapping patches.

    Patches are emitted volume by volume. Within a volume the order is z first,
    then h, then w, with w varying fastest. A volume whose dimensions are not
    multiples of ``patch_size`` is zero-padded at the high end of each axis, so
    every patch has the full ``patch_size`` shape. Without padding, the
    trailing patches would be smaller and could not be stacked into one array.

    Args:
        images: iterable of 3-D arrays (volumes may differ in shape).
        patch_size: ``(dz, dh, dw)`` patch shape.

    Returns:
        Array of shape ``(num_patches, dz, dh, dw)``. An empty ``images`` gives
        shape ``(0, dz, dh, dw)``.
    """
    dz, dh, dw = patch_size
    per_volume = []
    for image in images:
        image = np.asarray(image)
        # Pad each axis up to the next multiple of its patch extent.
        pad_width = [(0, -size % step) for size, step in zip(image.shape, patch_size)]
        if any(after for _, after in pad_width):
            image = np.pad(image, pad_width, mode="constant")
        z, h, w = image.shape
        # Split each axis into (n_blocks, block) and move the block-index axes to
        # the front; row-major flattening then yields the z -> h -> w patch order.
        tiles = (
            image.reshape(z // dz, dz, h // dh, dh, w // dw, dw)
            .transpose(0, 2, 4, 1, 3, 5)
            .reshape(-1, dz, dh, dw)
        )
        per_volume.append(tiles)
    if not per_volume:
        return np.empty((0, dz, dh, dw))
    return np.concatenate(per_volume, axis=0)

In [9]:
# Preprocess labels
def preprocess_labels(masks, num_classes=8):
    """One-hot encode an integer label array.

    Labels 0, 9 and 10 are first forced to background (0). Each label
    ``k`` in ``1..num_classes`` then lights up channel ``k - 1``. Any other
    value (including 9/10 even when ``num_classes >= 9``) yields an all-zero
    vector.

    Args:
        masks: NumPy array of labels (any shape).
        num_classes: number of foreground channels to emit.

    Returns:
        ``uint8`` array of shape ``(*masks.shape, num_classes)``.
    """
    cleaned = np.where(np.isin(masks, (0, 9, 10)), 0, masks)
    encoded = np.zeros(masks.shape + (num_classes,), dtype=np.uint8)
    for channel in range(num_classes):
        # Boolean comparison result is cast to 0/1 on assignment into uint8.
        encoded[..., channel] = cleaned == channel + 1
    return encoded

In [10]:
# Train/validation split
def split_train_valid(ir_patches, t1_patches, flair_patches, mask_patches, valid_size=0.2, random_state=42):
    """
    Stack the three modalities as channels and split into train/validation sets.

    Args:
        ir_patches, t1_patches, flair_patches: equally shaped arrays whose first
            axis indexes patches.
        mask_patches: labels aligned with the image patches along the first axis.
        valid_size: fraction of patches placed in the validation set.
        random_state: seed for the shuffle (default 42 preserves the previous
            fixed split); pass ``None`` for a different split on each call.

    Returns:
        ``(X_train, X_valid, y_train, y_valid)``. ``X_*`` carry the modalities
        on a new last axis in the order (IR, T1, FLAIR).

    NOTE(review): the split is per patch. If the patches come from several
    subjects, patches of one subject can land in both sets, which likely
    inflates validation scores. Consider splitting by subject before patch
    extraction.
    """
    input_images = np.stack([ir_patches, t1_patches, flair_patches], axis=-1)
    X_train, X_valid, y_train, y_valid = train_test_split(
        input_images, mask_patches, test_size=valid_size, random_state=random_state
    )
    return X_train, X_valid, y_train, y_valid

In [11]:
# Check normalization
def check_normalization(data, name):
    """Print a heading for ``name`` followed by min/max/mean/std of ``data`` (5 decimals)."""
    print(f"Checking {name}:")
    lowest, highest = np.min(data), np.max(data)
    average, spread = np.mean(data), np.std(data)
    print(f"Min: {lowest:.5f}, Max: {highest:.5f}, Mean: {average:.5f}, Std: {spread:.5f}")
