## 3D MRI Brain Tumor Segmentation with U-Net
This script implements a 3D U-Net model for brain tumor segmentation using the BraTS dataset.
It includes data preprocessing, model training, and an interactive Gradio interface.

### Libraries

In [1]:
pip install gradio

Note: you may need to restart the kernel to use updated packages.


In [76]:
import numpy as np
import nibabel as nib
import tensorflow as tf
import gradio as gr
from tensorflow.keras.utils import to_categorical
import matplotlib.pyplot as plt
from sklearn.preprocessing import MinMaxScaler
from tensorflow.keras.models import Model, load_model  # load_model is needed by UNetModel
from tensorflow.keras.layers import Input, Conv3D, MaxPooling3D, UpSampling3D, concatenate, Dropout
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
import threading
import random

### Constants
These constants define key parameters for the segmentation task:
- `IMG_SIZE = (128, 128, 128)`: Resize all images to 128x128x128 to balance memory usage and detail retention. The BraTS dataset typically has 240x240x155 volumes, but downsizing reduces computational load while preserving sufficient anatomical information.
- `NUM_CLASSES = 4`: Represents the four classes in BraTS (0: background, 1: necrosis, 2: edema, 3: enhancing tumor). This matches the dataset's label structure after remapping label 4 to 3.
- `CHANNELS = 4`: Corresponds to the four MRI modalities (flair, t1, t1ce, t2) used as input channels, providing multi-modal information for better segmentation.
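
As a rough check on the memory argument, the sketch below compares the size of one four-modality volume at the native BraTS resolution and at `IMG_SIZE`. It assumes `float32` storage, which is an assumption rather than something the notebook enforces (`get_fdata()` returns `float64` by default, which would double both figures). The helper `volume_mib` is illustrative only.

```python
# Assumes float32 storage (4 bytes per voxel); names below are illustrative.
BYTES_PER_VOXEL = 4
CHANNELS = 4

def volume_mib(shape, channels=CHANNELS, bytes_per_voxel=BYTES_PER_VOXEL):
    """Size in MiB of one multi-channel volume with the given spatial shape."""
    voxels = 1
    for dim in shape:
        voxels *= dim
    return voxels * channels * bytes_per_voxel / 2**20

native = volume_mib((240, 240, 155))   # typical BraTS volume
resized = volume_mib((128, 128, 128))  # IMG_SIZE
print(f"native:  {native:.1f} MiB")
print(f"resized: {resized:.1f} MiB")
print(f"ratio:   {native / resized:.2f}x")
```

Under these assumptions the resized input is 32 MiB per sample against roughly 136 MiB at native resolution. Intermediate feature maps in a 3D U-Net scale with the same voxel count, so the saving during training is larger than the input size alone suggests.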

In [9]:
# Constants
IMG_SIZE = (128, 128, 128)  # Target size for resizing images
NUM_CLASSES = 4  # Number of classes (0: background, 1: necrosis, 2: edema, 3: enhancing tumor)
CHANNELS = 4  # Number of MRI modalities (flair, t1, t1ce, t2)

### DatasetHandler Class
This class manages loading and preprocessing of MRI data from the BraTS dataset:
- **Why MinMaxScaler?**: Normalizes pixel intensities to [0, 1], which stabilizes training by ensuring consistent input ranges across modalities.
- **Preprocessing Choices**: Resizing to `IMG_SIZE` ensures uniformity, and one-hot encoding of masks prepares them for categorical cross-entropy loss in the U-Net.
- **Multi-Modal Loading**: The `load_sample` method stacks four modalities into a 4D array, leveraging complementary information from each scan type.
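
The normalization and mask-encoding steps can be sketched with NumPy alone. This is a minimal illustration on toy arrays, not the class's implementation: `min_max_normalize` and `one_hot_mask` are hypothetical helpers, and indexing `np.eye` is used here as a NumPy-only stand-in that produces the same one-hot layout as `to_categorical` for integer labels.

```python
import numpy as np

NUM_CLASSES = 4

def min_max_normalize(volume):
    """Scale intensities to [0, 1]; a constant volume maps to all zeros."""
    v_min, v_max = volume.min(), volume.max()
    if v_max == v_min:
        return np.zeros_like(volume, dtype=np.float32)
    return ((volume - v_min) / (v_max - v_min)).astype(np.float32)

def one_hot_mask(mask, num_classes=NUM_CLASSES):
    """Remap BraTS label 4 to 3, then one-hot encode along a new last axis."""
    mask = mask.astype(np.uint8)  # astype returns a copy, so the input is untouched
    mask[mask == 4] = 3
    return np.eye(num_classes, dtype=np.float32)[mask]

volume = np.array([[[0.0, 50.0], [100.0, 200.0]]])  # toy (1, 2, 2) "scan"
mask = np.array([[[0, 1], [2, 4]]])                  # toy labels, including 4

norm = min_max_normalize(volume)
encoded = one_hot_mask(mask)
print(norm.ravel())    # intensities rescaled into [0, 1]
print(encoded.shape)   # (1, 2, 2, 4)
```

Because the scaling is fitted per volume, each modality of each patient is stretched to the full [0, 1] range independently; absolute intensity differences between scans are discarded.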

In [12]:
from scipy.ndimage import zoom  # Interpolating resize; np.resize only truncates or repeats values

class DatasetHandler:
    """Handles loading and preprocessing of MRI data."""
    def __init__(self, dataset_path):
        self.dataset_path = dataset_path
        self.scaler = MinMaxScaler()  # Normalizes data to [0, 1]
    
    def load_nifti(self, file_path):
        """Loads a NIfTI file and returns its data as a numpy array."""
        try:
            return nib.load(file_path).get_fdata()
        except Exception as e:
            raise ValueError(f"Failed to load NIfTI file {file_path}: {str(e)}")
    
    @staticmethod
    def _resize(volume, order):
        """Resamples a 3D volume to IMG_SIZE (order=1: linear, order=0: nearest neighbour)."""
        factors = [target / size for target, size in zip(IMG_SIZE, volume.shape)]
        return zoom(volume, factors, order=order)
    
    def preprocess_image(self, image):
        """Normalizes image data to [0, 1] range and resizes."""
        if image.size == 0:
            raise ValueError("Empty image provided for preprocessing.")
        image = self.scaler.fit_transform(image.reshape(-1, 1)).reshape(image.shape)
        return self._resize(image, order=1)  # Linear interpolation to consistent dimensions
    
    def preprocess_mask(self, mask):
        """Preprocesses segmentation mask: converts to uint8, remaps labels, and one-hot encodes."""
        mask = mask.astype(np.uint8)
        mask[mask == 4] = 3  # Remap label 4 to 3 (BraTS convention)
        mask = self._resize(mask, order=0)  # Nearest neighbour keeps labels as valid class indices
        return to_categorical(mask, num_classes=NUM_CLASSES)  # One-hot encoding for multi-class
    
    def load_sample(self, sample_dir):
        """Loads and preprocesses a single BraTS sample with all modalities."""
        modalities = ['flair', 't1', 't1ce', 't2']
        images = []
        for modality in modalities:
            file_path = f"{sample_dir}/BraTS20_Training_{sample_dir.split('_')[-1]}_{modality}.nii"
            img = self.load_nifti(file_path)
            img = self.preprocess_image(img)
            images.append(img)
        image_stack = np.stack(images, axis=-1)  # Shape: (128, 128, 128, 4)
        
        mask_path = f"{sample_dir}/BraTS20_Training_{sample_dir.split('_')[-1]}_seg.nii"
        mask = self.load_nifti(mask_path)
        mask = self.preprocess_mask(mask)
        
        return image_stack, mask
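
A caveat on resizing: `np.resize` does not interpolate. It flattens the array and truncates or repeats values to fill the new shape, so the result no longer corresponds spatially to the input. The sketch below contrasts it with a simple nearest-neighbour resampler on a toy 2D array. `nearest_resize` is a hypothetical helper written for illustration; `scipy.ndimage.zoom` is a common library choice when interpolation is wanted.

```python
import numpy as np

def nearest_resize(volume, target_shape):
    """Nearest-neighbour resize: sample evenly spaced indices along each axis."""
    index_grids = [
        np.round(np.linspace(0, size - 1, target)).astype(int)
        for size, target in zip(volume.shape, target_shape)
    ]
    return volume[np.ix_(*index_grids)]

image = np.arange(16).reshape(4, 4)      # toy 2D "image"
truncated = np.resize(image, (2, 2))     # first four flattened values
sampled = nearest_resize(image, (2, 2))  # the four corners of the image

print(truncated.tolist())  # [[0, 1], [2, 3]]
print(sampled.tolist())    # [[0, 3], [12, 15]]
```

The truncated result contains only the start of the first row, whereas the resampled result still spans the whole image. The same distinction applies to 3D volumes and to label masks, where nearest-neighbour sampling also avoids inventing fractional class values.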

### UNetModel Class
This class manages the 3D U-Net model for brain tumor segmentation:
- **Purpose**: Provides a reference implementation of the U-Net architecture and loads a pre-trained model for inference.
- **Implementation**: 
  - **`build_unet` Method**: Defines a 3D U-Net architecture as a reference (not used in this script). It includes an encoder-decoder structure with skip connections:
    - **Filter Sizes (32, 64, 128, 256)**: Increase with depth for hierarchical feature extraction, then decrease symmetrically in the decoder.
    - **Kernel Size (3,3,3)**: Small kernels reduce parameters while capturing local 3D patterns.
    - **Dropout (0.3)**: Added in the bottleneck to prevent overfitting.
    - **Pooling/Upsampling (2,2,2)**: Standard resolution adjustment for 3D U-Net.
    - **Softmax Output**: Produces probabilities for 4 classes per voxel.
  - This implementation is included as an example of how the model could be built if training from scratch, but it is not executed here.
- **Pre-trained Model**: The `__init__` method loads `'brats_3d.hdf5'` using `load_model(compile=False)`. This is the model actually used for segmentation in the interactive interface; `build_unet` is never called.
- **Prediction**: The `predict` method uses the loaded model to generate segmentation masks from input data.
- **Why This Structure?**: Including the `build_unet` method provides a clear reference for the U-Net architecture, while loading `'brats_3d.hdf5'` leverages an existing, optimized model for practical use.
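
To make the encoder-decoder symmetry concrete, the following sketch traces the spatial shape and channel count of each stage without building the network. It assumes the configuration described above (`'same'` padding, 2x2x2 pooling and upsampling, filters 32 to 256); `unet_shape_trace` is a hypothetical helper, and the stage names mirror the variable names in `build_unet`.

```python
def unet_shape_trace(spatial=(128, 128, 128), filters=(32, 64, 128, 256), num_classes=4):
    """Return (stage name, spatial shape, channels) for each stage of the network."""
    trace = []
    shape = tuple(spatial)
    # Encoder: 'same' convolutions keep the spatial shape; pooling then halves it.
    for i, f in enumerate(filters[:-1], start=1):
        trace.append((f"c{i}", shape, f))
        shape = tuple(s // 2 for s in shape)
    # Bottleneck: deepest stage, most filters.
    trace.append((f"c{len(filters)}", shape, filters[-1]))
    # Decoder: upsampling doubles the shape; the following convolutions reduce channels.
    stage = len(filters)
    for f in reversed(filters[:-1]):
        stage += 1
        shape = tuple(s * 2 for s in shape)
        trace.append((f"c{stage}", shape, f))
    # Final 1x1x1 convolution maps to one channel per class.
    trace.append(("outputs", shape, num_classes))
    return trace

for name, shape, channels in unet_shape_trace():
    print(f"{name:8s} {shape} x {channels}")
```

With three pooling steps the bottleneck operates on 16x16x16 feature maps, and each input dimension must be divisible by 8 for the skip connections to line up after upsampling.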

In [78]:
class UNetModel:
    """Loads a pre-trained 3D U-Net model for brain tumor segmentation with reference implementation."""
    def __init__(self):
        self.model = self.load_pretrained_model()
    
    def load_pretrained_model(self):
        """Loads the pre-trained model from 'brats_3d.hdf5'."""
        try:
            return load_model('brats_3d.hdf5', compile=False)
        except Exception as e:
            raise ValueError(f"Failed to load pre-trained model 'brats_3d.hdf5': {str(e)}")
    
    def build_unet(self):
        """Reference implementation of the 3D U-Net architecture (not used in inference)."""
        # This is an example of how the U-Net could be implemented if not using a pre-trained model
        inputs = Input((*IMG_SIZE, CHANNELS))
        
        # Encoder
        c1 = Conv3D(32, (3, 3, 3), activation='relu', padding='same')(inputs)
        c1 = Conv3D(32, (3, 3, 3), activation='relu', padding='same')(c1)
        p1 = MaxPooling3D((2, 2, 2))(c1)
        
        c2 = Conv3D(64, (3, 3, 3), activation='relu', padding='same')(p1)
        c2 = Conv3D(64, (3, 3, 3), activation='relu', padding='same')(c2)
        p2 = MaxPooling3D((2, 2, 2))(c2)
        
        c3 = Conv3D(128, (3, 3, 3), activation='relu', padding='same')(p2)
        c3 = Conv3D(128, (3, 3, 3), activation='relu', padding='same')(c3)
        p3 = MaxPooling3D((2, 2, 2))(c3)
        
        # Bottleneck
        c4 = Conv3D(256, (3, 3, 3), activation='relu', padding='same')(p3)
        c4 = Conv3D(256, (3, 3, 3), activation='relu', padding='same')(c4)
        c4 = Dropout(0.3)(c4)
        
        # Decoder
        u5 = UpSampling3D((2, 2, 2))(c4)
        u5 = concatenate([u5, c3])
        c5 = Conv3D(128, (3, 3, 3), activation='relu', padding='same')(u5)
        c5 = Conv3D(128, (3, 3, 3), activation='relu', padding='same')(c5)
        
        u6 = UpSampling3D((2, 2, 2))(c5)
        u6 = concatenate([u6, c2])
        c6 = Conv3D(64, (3, 3, 3), activation='relu', padding='same')(u6)
        c6 = Conv3D(64, (3, 3, 3), activation='relu', padding='same')(c6)
        
        u7 = UpSampling3D((2, 2, 2))(c6)
        u7 = concatenate([u7, c1])
        c7 = Conv3D(32, (3, 3, 3), activation='relu', padding='same')(u7)
        c7 = Conv3D(32, (3, 3, 3), activation='relu', padding='same')(c7)
        
        outputs = Conv3D(NUM_CLASSES, (1, 1, 1), activation='softmax')(c7)
        return Model(inputs=inputs, outputs=outputs)
    
    def predict(self, image):
        """Predicts segmentation mask for a given image."""
        return self.model.predict(image)

### SegmentationApp Class
This class provides the application logic for segmentation:
- **Single-Volume Handling**: `segment_image` takes one preloaded `.npy` volume with 3 channels, adds a batch dimension, and runs the model on it.
- **Visualization**: Displays a random 2D slice from the 3D prediction using Matplotlib's default colormap.
- **Error Handling**: Catches exceptions raised during prediction or plotting and reports them as an error message instead of propagating them.
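
The post-processing step, turning per-voxel class probabilities into a label map and a displayable slice, can be sketched on a toy array. The sizes and the random probabilities are illustrative assumptions; the per-class voxel count and the fixed middle slice are additions for the example and are not part of `SegmentationApp`.

```python
import numpy as np

rng = np.random.default_rng(0)
# Toy softmax-like output: batch of 1, 4x4x4 volume, 4 classes (illustrative sizes).
prediction = rng.random((1, 4, 4, 4, 4))
prediction /= prediction.sum(axis=-1, keepdims=True)

labels = np.argmax(prediction, axis=-1)[0]               # (4, 4, 4) integer label map
voxel_counts = np.bincount(labels.ravel(), minlength=4)  # voxels assigned to each class
slice_index = labels.shape[2] // 2                       # middle slice instead of a random one
label_slice = labels[:, :, slice_index]                  # 2D slice suitable for imshow

print(labels.shape, label_slice.shape)  # (4, 4, 4) (4, 4)
print(voxel_counts.sum())               # 64
```

Counting voxels over the whole volume gives a summary that does not depend on which slice happens to be displayed.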

In [80]:
class SegmentationApp:
    """Application for interactive brain tumor segmentation."""
    def __init__(self, model):
        self.model = model
    
    def segment_image(self, npy_data):
        """Segments a 3-channel volume of shape (H, W, D, 3).

        Returns (figure, prediction, input_slice, error); error is None on success.
        """
        try:
            test_img_input = np.expand_dims(npy_data, axis=0)  # Add batch dimension
            prediction = self.model.predict(test_img_input)
            pred_argmax = np.argmax(prediction, axis=4)[0, :, :, :]
            
            slice_num = random.randint(0, pred_argmax.shape[2] - 1)
            
            # Display the segmented image
            fig_segmented, ax_segmented = plt.subplots(figsize=(8, 8))
            ax_segmented.imshow(pred_argmax[:, :, slice_num])
            ax_segmented.set_title("Segmentation")
            ax_segmented.axis('off')
            
            # Return plot, prediction, input slice, and no error
            return fig_segmented, pred_argmax, npy_data[:, :, slice_num, 0], None
        except Exception as e:
            # Same arity as the success path so callers can always unpack four values
            return None, None, None, f"Error: {str(e)}"

### Interactive Segmentation Interface
This section defines the Gradio interface integrated with the defined classes:
- **Why Gradio?**: Provides an easy-to-use web interface for interactive segmentation, accessible to both technical and non-technical users.
- **Training Input**: Uses `.nii` files from the BraTS dataset with 4 modalities (FLAIR, T1, T1ce, T2) processed by `DatasetHandler` for training reference (not executed here as the model is pre-trained).
- **Interactive Input**: Accepts a single `.npy` file with 3 channels (FLAIR, T1ce, T2), directly processed without additional preprocessing to match the pre-trained model’s expected input.
- **Output**: Uses `gr.Plot` to display two Matplotlib figures: one for the input image (FLAIR) and one for the segmented result, with a `gr.Textbox` for the diagnosis message.
- **Class Integration**: 
  - `DatasetHandler`: Retained for training reference but not used in inference; the `.npy` input is loaded directly.
  - `UNetModel`: Loads the pre-trained model from `'brats_3d.hdf5'` using `load_model(compile=False)`; includes `build_unet` as a reference implementation (see UNetModel Class).
  - `SegmentationApp`: Wraps the same predict-and-plot logic for a `.npy` input. `interactive_segmentation` instantiates it but performs the prediction and plotting inline.
- **Pre-trained Model**: Loads `'brats_3d.hdf5'` directly. If loading fails, an error message is returned, prompting the user to ensure the model file is present.
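- **Interactive Control**: Includes a "Stop Processing" button intended to let users interrupt processing. As written, the stop flag is checked only once, after prediction has finished, and `stop_processing` is defined but not bound to the button in the `gr.Interface` call, so the button has no effect unless its click event is wired up explicitly.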
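
One way to make a stop button effective is cooperative cancellation: the job checks a shared `threading.Event` between units of work, and a separate click handler sets it. The sketch below is a minimal illustration under stated assumptions, not the notebook's interface. `run_job`, `request_stop`, and `build_demo` are hypothetical names, the job is a stand-in for the segmentation call, and it assumes Gradio can run the two handlers concurrently.

```python
import threading

stop_flag = threading.Event()

def request_stop():
    """Click handler: ask the running job to stop at its next checkpoint."""
    stop_flag.set()
    return "Stop requested."

def run_job(n_steps=3, work=lambda step: None):
    """Hypothetical stand-in for segmentation, split into checkpointed steps."""
    stop_flag.clear()  # Reset at the start of each run
    for step in range(n_steps):
        if stop_flag.is_set():
            return f"Stopped before step {step}."
        work(step)  # One unit of work, e.g. predicting on one patch
    return "Finished."

def build_demo():
    """Builds a Blocks UI in which each button is bound to its own handler."""
    import gradio as gr  # Imported lazily so the helpers above work without Gradio
    with gr.Blocks(title="Brain Tumor Segmentation") as demo:
        status = gr.Textbox(label="Status")
        run_btn = gr.Button("Run")
        stop_btn = gr.Button("Stop Processing", variant="secondary")
        run_btn.click(fn=run_job, inputs=None, outputs=status)
        stop_btn.click(fn=request_stop, inputs=None, outputs=status)
    return demo

# To try it interactively: build_demo().launch()
```

A single `model.predict` call cannot be interrupted this way, since the flag is only consulted between steps. Cancellation becomes useful when the work is divided, for example into patches or batches.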


In [82]:
# Variable to stop the processing
stop_flag = threading.Event()

def interactive_segmentation(npy_file, stop_button):
    """Interactive segmentation with a 3-channel .npy file using defined classes."""
    global stop_flag
    stop_flag.clear()  # Reset the flag at the start of each execution

    if npy_file is None:
        return None, None, "No file uploaded."

    try:
        # Load the model inside the try block so a missing 'brats_3d.hdf5' yields an error message
        model = UNetModel()
        app = SegmentationApp(model)

        # Load the .npy image
        test_img = np.load(npy_file.name)
        if test_img.ndim != 4 or test_img.shape[-1] != 3:
            return None, None, "Invalid .npy file. It must have shape (H, W, D, 3)."

        # Add a batch dimension for the model
        test_img_input = np.expand_dims(test_img, axis=0)

        # Make a prediction using the model
        test_prediction = model.predict(test_img_input)
        test_prediction_argmax = np.argmax(test_prediction, axis=4)[0, :, :, :]

        # Select a random slice
        n_slice = random.randint(0, test_prediction_argmax.shape[2] - 1)

        # Check if the user clicked "Stop Processing"
        if stop_flag.is_set():
            return None, None, "Processing stopped."

        # Check the selected slice only for tumor core classes (1: necrosis, 3: enhancing tumor); edema (2) is not counted
        tumor_found = np.any(np.isin(test_prediction_argmax[:, :, n_slice], [1, 3]))

        # Display the input image (FLAIR)
        fig_input, ax_input = plt.subplots(figsize=(8, 8))
        ax_input.imshow(test_img[:, :, n_slice, 0], cmap='gray')
        ax_input.set_title("Input Image (FLAIR)")
        ax_input.axis('off')  # Hide the axes

        # Display the segmented image
        fig_segmented, ax_segmented = plt.subplots(figsize=(8, 8))
        ax_segmented.imshow(test_prediction_argmax[:, :, n_slice])
        ax_segmented.set_title("Segmentation")
        ax_segmented.axis('off')  # Hide the axes

        # Message based on the presence of the tumor
        message = "All good, no problem!" if not tumor_found else "Problem detected, a tumor has been identified!"

        # Return the figures and the message
        return (fig_input, fig_segmented, message)

    except Exception as e:
        return None, None, f"Error: {str(e)}"

def stop_processing():
    """Function to stop the processing."""
    global stop_flag
    stop_flag.set()
    return "Processing stopped."

# Gradio interface with a "Stop Processing" button
iface = gr.Interface(
    fn=interactive_segmentation,
    inputs=[
        gr.File(label="Upload MRI Scan (.npy)", file_types=[".npy"]),
        gr.Button("Stop Processing", variant="secondary", elem_id="stop-btn")
    ],
    outputs=[
        gr.Plot(label="Input Image (FLAIR)"),
        gr.Plot(label="Segmentation Result"),
        gr.Textbox(label="Diagnosis Message")
    ],
    title="Brain Tumor Segmentation",
    description="Upload an MRI scan in .npy format (3 channels: FLAIR, T1ce, T2) to segment the brain tumor using a pre-trained 3D U-Net model from 'brats_3d.hdf5'.",
    live=False  # Set to False to avoid continuous execution
)

# Launch the Gradio interface
iface.launch()

* Running on local URL:  http://127.0.0.1:7882

To create a public link, set `share=True` in `launch()`.


