## Final_Project.ipynb Contents

This notebook covers the complete workflow for brain tumor segmentation using deep learning (U-Net) on the BraTS dataset:

- **Data Loading & Preprocessing**
  - Loads multi-modal MRI data (FLAIR, T1, T1ce, T2) and segmentation masks for each subject.
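  - Preprocesses each 2D slice: skips slices whose mask is empty, applies per-channel min-max normalization, and resizes the slice to 128x128.
  - Merges all tumor labels into a single binary (tumor vs. background) mask.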

- **Model Architecture**
  - Defines a four-level U-Net with batch normalization after every convolution and dropout at the bottleneck.

- **Loss Functions**
  - Implements Dice loss and a combined BCE + Dice loss for training.

- **Training**
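  - Splits the slices into training (80%) and validation (20%) sets. The split is made per slice, not per subject, so slices from the same subject can appear in both sets.
  - Trains the U-Net for 10 epochs and saves the full model to `unet_brats_segmentation.h5`.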

- **Evaluation**
  - Predicts on the validation set.
  - Calculates Dice and IoU metrics.
  - Visualizes random samples of input, ground truth, and predictions.

- **Testing**
  - For testing the trained model on new data, use the provided `test_model.ipynb`.

> **Note:**  
> The MRI subject folders should contain the following files (BraTS format):  
> - `<subject_id>_flair.nii.gz`  
> - `<subject_id>_t1.nii.gz`  
> - `<subject_id>_t1ce.nii.gz`  
> - `<subject_id>_t2.nii.gz`  
> - `<subject_id>_seg.nii.gz`  
>
> The `<subject_id>` prefix of every file must match the name of the subject folder.

In [None]:
import os
import numpy as np
import nibabel as nib
import tensorflow as tf
from tensorflow.keras import layers, models, backend as K

In [None]:
def load_modalities(subject_path, modalities=('flair', 't1', 't1ce', 't2')):
    """Load the MRI modalities of one subject and stack them as channels.

    Each modality is read from ``<subject_dir>/<subject_id>_<mod>.nii.gz``,
    where ``<subject_id>`` is the name of the subject directory.

    Args:
        subject_path: Path to the subject directory. A trailing path
            separator is tolerated.
        modalities: Modality suffixes to load, in channel order. The default
            is a tuple so it cannot be mutated between calls; any sequence of
            strings is accepted.

    Returns:
        ``np.ndarray`` with the loaded volumes stacked along a new last axis
        (one channel per entry of ``modalities``, in the given order).
    """
    # normpath strips a trailing separator; without it basename("dir/") == ""
    # and the file name would be built as "_flair.nii.gz".
    subject_path = os.path.normpath(subject_path)
    subject_id = os.path.basename(subject_path)
    volumes = [
        nib.load(os.path.join(subject_path, f"{subject_id}_{mod}.nii.gz")).get_fdata()
        for mod in modalities
    ]
    return np.stack(volumes, axis=-1)

def load_mask(subject_path):
    """Load the segmentation volume of one subject.

    Reads ``<subject_dir>/<subject_id>_seg.nii.gz``, where ``<subject_id>`` is
    the name of the subject directory.

    Args:
        subject_path: Path to the subject directory. A trailing path
            separator is tolerated.

    Returns:
        The label volume as returned by nibabel's ``get_fdata()``.
    """
    # normpath strips a trailing separator so basename() yields the subject id
    # instead of "" (which would produce the file name "_seg.nii.gz").
    subject_path = os.path.normpath(subject_path)
    subject_id = os.path.basename(subject_path)
    return nib.load(os.path.join(subject_path, f"{subject_id}_seg.nii.gz")).get_fdata()

def preprocess_volume(volume, mask, target_size=(128, 128)):
    """Turn a multi-channel volume and its label volume into 2-D training slices.

    For every index ``i`` along the third axis, ``volume[:, :, i, :]`` is kept
    only if ``mask[:, :, i]`` has a non-zero maximum. Each kept image slice is
    min-max normalised independently per channel (to approximately [0, 1])
    and resized bilinearly to ``target_size``. The matching mask slice is
    resized with nearest-neighbour interpolation so label values are never
    blended.

    Args:
        volume: Array indexed as ``volume[h, w, slice, channel]``.
        mask: Array indexed as ``mask[h, w, slice]``, aligned with ``volume``.
        target_size: ``(height, width)`` of the output slices.

    Returns:
        ``(X, y)`` with shapes ``(n_kept, *target_size, channels)`` and
        ``(n_kept, *target_size)``. If no slice is kept, both are empty arrays
        with those trailing dimensions (not shape ``(0,)``), so they can still
        be concatenated with other subjects' results along axis 0.
    """
    X, y = [], []
    for i in range(volume.shape[2]):
        mask_slice = mask[:, :, i]
        if np.max(mask_slice) == 0:
            continue  # no labelled pixels in this slice
        img_slice = volume[:, :, i, :]
        # Per-channel min-max scaling; the epsilon avoids dividing by zero
        # when a channel is constant across the slice.
        channel_min = np.min(img_slice, axis=(0, 1))
        channel_range = np.ptp(img_slice, axis=(0, 1))
        img_slice = (img_slice - channel_min) / (channel_range + 1e-8)
        img_slice = tf.image.resize(img_slice, target_size).numpy()
        # Index the channel axis instead of squeeze() so a target dimension of
        # size 1 is not dropped as well.
        mask_slice = tf.image.resize(
            mask_slice[..., None], target_size, method='nearest'
        ).numpy()[..., 0]
        X.append(img_slice)
        y.append(mask_slice)

    if not X:
        # Match the shapes/dtypes of the non-empty case: bilinear
        # tf.image.resize yields float32, nearest keeps the mask's dtype.
        n_channels = volume.shape[-1]
        return (
            np.empty((0, *target_size, n_channels), dtype=np.float32),
            np.empty((0, *target_size), dtype=mask.dtype),
        )
    return np.array(X), np.array(y)

In [None]:
# Root of the BraTS 2021 training data. Set the BRATS_DATA_DIR environment
# variable to point elsewhere instead of editing this machine-specific path.
data_dir = os.environ.get(
    "BRATS_DATA_DIR",
    "/home/jyotirya-agrawal/.cache/kagglehub/datasets/dschettler8845/brats-2021-task1/versions/1/BraTS2021_Training_Data",
)
# Maximum number of subjects to load (None = all). Lower it for quick
# experiments: the slices of every loaded subject are kept in memory at once.
MAX_SUBJECTS = None

# Sorted so the subject order (and therefore the seeded random split later)
# does not depend on the arbitrary order of os.listdir; only directories are
# kept so stray files sharing the prefix are ignored.
subjects = sorted(
    os.path.join(data_dir, d)
    for d in os.listdir(data_dir)
    if d.startswith("BraTS2021") and os.path.isdir(os.path.join(data_dir, d))
)
if MAX_SUBJECTS is not None:
    subjects = subjects[:MAX_SUBJECTS]

X_all, y_all = [], []
for subj in subjects:
    vol = load_modalities(subj)
    mask = load_mask(subj)
    X, y = preprocess_volume(vol, mask)
    if len(X) == 0:
        # No labelled slices: nothing to add, and an empty result may not
        # have the same number of dimensions as the other subjects' arrays.
        continue
    X_all.append(X)
    y_all.append(y)
if not X_all:
    raise RuntimeError(f"No labelled slices found under {data_dir!r}")
X_all = np.concatenate(X_all, axis=0)
y_all = np.concatenate(y_all, axis=0)
y_all = (y_all > 0).astype(np.float32)  # Binary mask: any non-zero label -> 1
print("X_all shape:", X_all.shape, "y_all shape:", y_all.shape)

In [None]:
def conv_block(x, filters, use_bn=True):
    """Apply two 3x3 convolution units, each Conv2D -> [BatchNorm] -> ReLU.

    Args:
        x: Input Keras tensor.
        filters: Number of output channels of each convolution.
        use_bn: Insert BatchNormalization after each convolution. When it is
            enabled the convolutions are built without a bias term, because
            the following BatchNormalization has its own learned offset.

    Returns:
        The output tensor of the second unit.
    """
    for _ in range(2):
        x = layers.Conv2D(
            filters,
            kernel_size=3,
            padding='same',
            kernel_initializer='he_normal',
            use_bias=not use_bn,
        )(x)
        if use_bn:
            x = layers.BatchNormalization()(x)
        x = layers.ReLU()(x)
    return x


def build_unet(input_shape=(128, 128, 4), base_filters=16, depth=4, dropout_rate=0.3):
    """Build a U-Net with a single sigmoid output channel.

    The encoder has ``depth`` levels. Level ``k`` applies ``conv_block`` with
    ``base_filters * 2**k`` filters, then 2x2 max pooling. The bottleneck is a
    ``conv_block`` with ``base_filters * 2**depth`` filters followed by
    dropout. The decoder mirrors the encoder; at each level it applies:

    - a kernel-2, stride-2 transposed convolution;
    - concatenation with the matching encoder output (skip connection);
    - a ``conv_block``.

    A 1x1 convolution with sigmoid activation produces one value per pixel.
    The defaults (``depth=4``, ``dropout_rate=0.3``) create the same layers,
    in the same order, as the original fixed four-level version.

    Args:
        input_shape: ``(height, width, channels)``. Height and width may be
            ``None`` for variable-size input. Known sizes must be divisible by
            ``2**depth`` so the upsampled maps line up with the skip tensors.
        base_filters: Number of filters at the first encoder level.
        depth: Number of down-/up-sampling levels (>= 0).
        dropout_rate: Dropout rate applied after the bottleneck block.

    Returns:
        An uncompiled ``keras.Model``.

    Raises:
        ValueError: If ``depth`` is negative or a known spatial size is not
            divisible by ``2**depth``. Without this check, such shapes fail
            later with an opaque shape mismatch in ``Concatenate``.
    """
    if depth < 0:
        raise ValueError(f"depth must be >= 0, got {depth}")
    factor = 2 ** depth
    for dim in input_shape[:2]:
        if dim is not None and dim % factor != 0:
            raise ValueError(
                f"Spatial size {dim} in input_shape={tuple(input_shape)} is not "
                f"divisible by 2**depth = {factor}"
            )

    inputs = layers.Input(shape=input_shape)

    # ---------- Encoder ----------
    skips = []
    x = inputs
    for level in range(depth):
        x = conv_block(x, base_filters * 2 ** level)
        skips.append(x)
        x = layers.MaxPooling2D()(x)

    # ---------- Bottleneck ----------
    x = conv_block(x, base_filters * 2 ** depth)
    x = layers.Dropout(dropout_rate)(x)

    # ---------- Decoder ----------
    for level in reversed(range(depth)):
        filters = base_filters * 2 ** level
        x = layers.Conv2DTranspose(filters, 2, strides=2, padding='same')(x)
        x = layers.Concatenate()([x, skips[level]])
        x = conv_block(x, filters)

    outputs = layers.Conv2D(1, kernel_size=1, activation='sigmoid')(x)

    return models.Model(inputs, outputs)


In [None]:
def dice_loss(y_true, y_pred, smooth=1):
    """Soft Dice loss pooled over the whole batch.

    Both tensors are flattened first, so overlap and mask sizes are summed
    over every pixel of every sample rather than averaged per sample.

    Args:
        y_true: Ground-truth mask tensor.
        y_pred: Predicted probabilities with the same number of elements.
        smooth: Additive smoothing that keeps the ratio defined when both
            masks are empty.

    Returns:
        ``1 - (2*sum(t*p) + smooth) / (sum(t) + sum(p) + smooth)``.
    """
    truth = K.flatten(y_true)
    pred = K.flatten(y_pred)
    overlap = K.sum(truth * pred)
    denominator = K.sum(truth) + K.sum(pred) + smooth
    dice = (2. * overlap + smooth) / denominator
    return 1 - dice


def bce_dice_loss(y_true, y_pred):
    """Binary cross-entropy plus Dice loss.

    ``K.binary_crossentropy`` yields an element-wise tensor, and the scalar
    returned by ``dice_loss`` is broadcast onto it. Keras then reduces the sum
    to the reported loss value.
    """
    bce = K.binary_crossentropy(y_true, y_pred)
    dice = dice_loss(y_true, y_pred)
    return bce + dice


In [None]:
unet = build_unet()
unet.compile(
    optimizer='adam',
    loss=bce_dice_loss,
    # Pixel accuracy is typically dominated by background pixels and can look
    # high even for an all-background prediction, so also track the IoU of
    # the tumor class (threshold 0.5, the same cut-off used when binarising
    # predictions for evaluation).
    metrics=[
        'accuracy',
        tf.keras.metrics.BinaryIoU(target_class_ids=[1], threshold=0.5, name='tumor_iou'),
    ],
)

unet.summary()


In [None]:
from sklearn.model_selection import train_test_split

# NOTE(review): this splits individual slices, so slices of the same subject
# can land in both the training and the validation set, which likely makes
# the validation scores optimistic. Splitting by subject would avoid that.
X_train, X_val, y_train, y_val = train_test_split(
    X_all, y_all, test_size=0.2, random_state=42
)

# The network's output has one channel, so the masks get a trailing channel
# axis: (N, H, W) -> (N, H, W, 1).
train_masks = y_train[..., None]
val_masks = y_val[..., None]
history = unet.fit(
    X_train,
    train_masks,
    validation_data=(X_val, val_masks),
    epochs=10,
    batch_size=8,
)

In [None]:
# The ".h5" extension selects the HDF5 format. The loss is a custom function,
# so reloading presumably needs custom_objects={"bce_dice_loss": bce_dice_loss}
# (or compile=False) -- TODO confirm against the loading notebook.
model_path = "unet_brats_segmentation.h5"
unet.save(model_path)

In [None]:
def dice_coef(y_true, y_pred, smooth=1e-6):
    """Dice coefficient of two NumPy arrays, pooled over all elements.

    Args:
        y_true: Ground-truth mask array.
        y_pred: Predicted (binarised) mask array with the same number of
            elements. The shapes may differ, e.g. by a trailing channel axis
            of size 1, because both are flattened.
        smooth: Small constant so two empty masks score 1 instead of 0/0.

    Returns:
        ``(2*sum(t*p) + smooth) / (sum(t) + sum(p) + smooth)``.
    """
    t, p = y_true.flatten(), y_pred.flatten()
    overlap = np.sum(t * p)
    numerator = 2. * overlap + smooth
    denominator = np.sum(t) + np.sum(p) + smooth
    return numerator / denominator

def iou_score(y_true, y_pred, smooth=1e-6):
    """Intersection-over-Union (Jaccard index), pooled over all elements.

    Args:
        y_true: Ground-truth mask array.
        y_pred: Predicted (binarised) mask array with the same number of
            elements; both are flattened before comparison.
        smooth: Small constant so two empty masks score 1 instead of 0/0.

    Returns:
        ``(inter + smooth) / (sum(t) + sum(p) - inter + smooth)`` where
        ``inter = sum(t * p)``.
    """
    t, p = (arr.flatten() for arr in (y_true, y_pred))
    inter = np.sum(t * p)
    union = np.sum(t) + np.sum(p) - inter
    numerator = inter + smooth
    return numerator / (union + smooth)

# Predict on validation set: one sigmoid probability per pixel, with a
# trailing channel axis of size 1.
y_pred = unet.predict(X_val)
# Binarise at 0.5 to obtain hard masks.
y_pred_bin = (y_pred > 0.5).astype(np.float32)

# y_val has no channel axis, but both metrics flatten their inputs, so the
# element order still lines up.
dice, iou = dice_coef(y_val, y_pred_bin), iou_score(y_val, y_pred_bin)
print(f"Dice Score: {dice:.4f}, IoU: {iou:.4f}")

In [None]:
import matplotlib.pyplot as plt

def plot_sample(X, y_true, y_pred, idx):
    """Show one sample's first input channel, ground truth and prediction.

    The three panels are drawn side by side in grayscale.

    Args:
        X: Input images; ``X[idx][..., 0]`` (the first channel) is displayed.
        y_true: Ground-truth masks; ``y_true[idx]`` is displayed as-is.
        y_pred: Predicted masks; ``y_pred[idx]`` is squeezed before display
            (e.g. to drop a trailing channel axis of size 1).
        idx: Index of the sample to display.
    """
    plt.figure(figsize=(12, 4))
    panels = (
        (X[idx][..., 0], 'FLAIR (example)'),
        (y_true[idx], 'Ground Truth'),
        (y_pred[idx].squeeze(), 'Prediction'),
    )
    for position, (image, caption) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.imshow(image, cmap='gray')
        plt.title(caption)
        plt.axis('off')
    plt.show()
# Show up to 10 distinct, randomly chosen validation samples. Taking a prefix
# of a random permutation samples without replacement, so no slice is plotted
# twice. It also copes with fewer than 10 samples, including an empty set,
# where np.random.randint(0, 0) would raise.
for index in np.random.permutation(X_val.shape[0])[:10]:
    plot_sample(X_val, y_val, y_pred_bin, index)
