In [None]:
import os
import numpy as np
import nibabel as nib
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.keras.models import load_model

def load_modalities(subject_path, modalities=('flair', 't1', 't1ce', 't2')):
    """Load the MRI modalities of one subject and stack them into one array.

    Each modality is read from ``<subject_path>/<subject_id>_<mod>.nii.gz``,
    where ``<subject_id>`` is the name of the subject directory itself.

    Args:
        subject_path: Path to the subject directory. A trailing path separator
            is tolerated.
        modalities: Iterable of modality suffixes to load, in channel order.

    Returns:
        A NumPy array with the loaded volumes stacked along a new last axis
        (one channel per entry in ``modalities``).
    """
    # normpath strips a trailing separator; without it basename("/a/b/") is ''
    # and every file name would degenerate to "_<mod>.nii.gz".
    subject_id = os.path.basename(os.path.normpath(subject_path))
    vols = [
        nib.load(os.path.join(subject_path, f"{subject_id}_{mod}.nii.gz")).get_fdata()
        for mod in modalities
    ]
    return np.stack(vols, axis=-1)  # shape: (H, W, Slices, len(modalities))

def load_mask(subject_path):
    """Load the segmentation mask of one subject.

    The mask is read from ``<subject_path>/<subject_id>_seg.nii.gz``, where
    ``<subject_id>`` is the name of the subject directory itself.

    Args:
        subject_path: Path to the subject directory. A trailing path separator
            is tolerated.

    Returns:
        The mask volume as returned by nibabel's ``get_fdata()``.
    """
    # normpath strips a trailing separator; without it basename("/a/b/") is ''
    # and the file name would degenerate to "_seg.nii.gz".
    subject_id = os.path.basename(os.path.normpath(subject_path))
    nii = nib.load(os.path.join(subject_path, f"{subject_id}_seg.nii.gz"))
    return nii.get_fdata()  # shape: (H, W, Slices)

def preprocess_volume(volume, mask, target_size=(128, 128)):
    """Turn a multi-channel volume and its mask into 2-D training-style slices.

    The volume is walked along axis 2. Slices whose mask maximum is 0 are
    skipped. Each kept image slice is min-max normalised per channel over its
    two spatial axes and resized to ``target_size`` with ``tf.image.resize``
    (default interpolation); the matching mask slice is resized with
    nearest-neighbour so label values are not blended.

    Args:
        volume: Array indexed as ``[row, col, slice, channel]``.
        mask: Array indexed as ``[row, col, slice]``.
        target_size: ``(height, width)`` of the output slices.

    Returns:
        Tuple ``(X, y)`` where ``X`` has shape ``(n, height, width, channels)``
        and ``y`` has shape ``(n, height, width)``; ``n`` is the number of kept
        slices. When no slice is kept, both arrays are empty but still carry
        the full trailing shape (``n == 0``).
    """
    X, y = [], []
    for i in range(volume.shape[2]):
        img_slice = volume[:, :, i, :]  # (H, W, C)
        mask_slice = mask[:, :, i]      # (H, W)
        if np.max(mask_slice) == 0:
            continue
        # Per-channel min-max scaling; the epsilon guards constant channels.
        img_slice = (img_slice - np.min(img_slice, axis=(0,1))) / (np.ptp(img_slice, axis=(0,1)) + 1e-8)
        img_slice = tf.image.resize(img_slice, target_size).numpy()
        mask_slice = tf.image.resize(mask_slice[..., None], target_size, method='nearest').numpy().squeeze()
        X.append(img_slice)
        y.append(mask_slice)
    if not X:
        # np.array([]) would collapse to shape (0,); keep the spatial/channel
        # dimensions so callers can still rely on X.shape[1:] and y.shape[1:].
        height, width = target_size
        return (
            np.zeros((0, height, width, volume.shape[-1]), dtype=np.float32),
            np.zeros((0, height, width), dtype=np.asarray(mask).dtype),
        )
    return np.array(X), np.array(y)

def dice_coef(y_true, y_pred, smooth=1e-6):
    """Return the smoothed Dice coefficient of two arrays.

    Both arrays are flattened, then the score is
    ``(2 * sum(true * pred) + smooth) / (sum(true) + sum(pred) + smooth)``.

    Args:
        y_true: NumPy array of reference values.
        y_pred: NumPy array of predicted values with the same number of elements.
        smooth: Constant added to numerator and denominator to avoid 0/0.

    Returns:
        The Dice coefficient as a scalar.
    """
    truth = y_true.flatten()
    pred = y_pred.flatten()
    overlap = (truth * pred).sum()
    denominator = truth.sum() + pred.sum() + smooth
    return (2. * overlap + smooth) / denominator

def iou_score(y_true, y_pred, smooth=1e-6):
    """Return the smoothed intersection-over-union of two arrays.

    Both arrays are flattened; the intersection is ``sum(true * pred)`` and the
    union is ``sum(true) + sum(pred) - intersection``.

    Args:
        y_true: NumPy array of reference values.
        y_pred: NumPy array of predicted values with the same number of elements.
        smooth: Constant added to numerator and denominator to avoid 0/0.

    Returns:
        ``(intersection + smooth) / (union + smooth)`` as a scalar.
    """
    truth = y_true.flatten()
    pred = y_pred.flatten()
    overlap = (truth * pred).sum()
    combined = truth.sum() + pred.sum() - overlap
    return (overlap + smooth) / (combined + smooth)

def plot_sample(X, y_true, y_pred, idx):
    """Show one sample as three side-by-side grayscale panels.

    The panels are, left to right: channel 0 of ``X[idx]``, ``y_true[idx]`` and
    ``y_pred[idx]`` with singleton dimensions squeezed out.

    Args:
        X: Indexable collection of multi-channel images (channels last).
        y_true: Indexable collection of reference masks.
        y_pred: Indexable collection of predicted masks.
        idx: Index of the sample to display.
    """
    plt.figure(figsize=(12,4))
    panels = (
        (X[idx][...,0], 'FLAIR (example)'),
        (y_true[idx], 'Ground Truth'),
        (y_pred[idx].squeeze(), 'Prediction'),
    )
    # One subplot per panel, numbered 1..3 in a single row.
    for position, (image, caption) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.imshow(image, cmap='gray')
        plt.title(caption)
        plt.axis('off')
    plt.show()

def test_mri_data(subject_folder, model_path="unet_brats_segmentation.h5"):
    """Evaluate a saved segmentation model on one subject and plot examples.

    Loads the model, loads and preprocesses the subject's modalities and mask,
    predicts a binary mask (threshold 0.5), prints Dice, IoU and pixel
    accuracy against the binarised ground truth (any label > 0 counts as
    foreground), and plots up to 10 randomly chosen distinct slices.

    Args:
        subject_folder: Path to the subject directory (see ``load_modalities``
            and ``load_mask`` for the expected file names).
        model_path: Path to the saved Keras model file.

    Raises:
        ValueError: If preprocessing yields no slices (the mask has no
            non-zero slice), since there is nothing to evaluate or plot.
    """
    def dice_loss(y_true, y_pred, smooth=1):
        y_true = tf.keras.backend.flatten(y_true)
        y_pred = tf.keras.backend.flatten(y_pred)
        intersection = tf.keras.backend.sum(y_true * y_pred)
        return 1 - (2. * intersection + smooth) / (tf.keras.backend.sum(y_true) + tf.keras.backend.sum(y_pred) + smooth)
    def bce_dice_loss(y_true, y_pred):
        return tf.keras.backend.binary_crossentropy(y_true, y_pred) + dice_loss(y_true, y_pred)
    # Only predict() is used below, so the training configuration is not
    # restored (compile=False); custom_objects is still passed in case the
    # saved file references the loss by name elsewhere.
    model = load_model(
        model_path,
        custom_objects={'bce_dice_loss': bce_dice_loss},
        compile=False,
    )

    mri = load_modalities(subject_folder)
    mask = load_mask(subject_folder)
    X_test, Y_test = preprocess_volume(mri, mask)
    n_slices = X_test.shape[0]
    if n_slices == 0:
        raise ValueError(
            f"No slices with a non-empty mask found in {subject_folder!r}; nothing to evaluate."
        )
    Y_test_bin = (Y_test > 0).astype(np.float32)

    Y_pred = model.predict(X_test)
    Y_bin_pred = (Y_pred > 0.5).astype(np.float32)
    # Drop only a trailing singleton channel axis. A bare np.squeeze would also
    # drop the batch axis when there is exactly one slice, which breaks the
    # per-slice indexing in plot_sample.
    if Y_bin_pred.ndim == Y_test_bin.ndim + 1 and Y_bin_pred.shape[-1] == 1:
        Y_bin_pred = Y_bin_pred[..., 0]

    dice = dice_coef(Y_test_bin, Y_bin_pred)
    iou = iou_score(Y_test_bin, Y_bin_pred)
    accuracy = np.mean(Y_test_bin == Y_bin_pred)
    print(f"Dice Score: {dice:.4f}, IoU: {iou:.4f}, Accuracy: {accuracy:.4f}")

    # Plot up to 10 random samples, without repeating a slice.
    n_plots = min(10, n_slices)
    for index in np.random.choice(n_slices, size=n_plots, replace=False):
        plot_sample(X_test, Y_test_bin, Y_bin_pred, index)


In [None]:
# Run the evaluation on a single subject.
# NOTE: "/path/to/your/mri/data" is a placeholder -- replace it with the path to
# a real subject directory before running this cell. The model file defaults to
# "unet_brats_segmentation.h5", resolved relative to the current working directory.
test_mri_data("/path/to/your/mri/data")
# The subject_folder must contain the following files (BraTS naming), where
# <subject_id> is the name of the subject folder itself:
#   - <subject_id>_flair.nii.gz   : FLAIR modality MRI
#   - <subject_id>_t1.nii.gz      : T1 modality MRI
#   - <subject_id>_t1ce.nii.gz    : T1ce modality MRI
#   - <subject_id>_t2.nii.gz      : T2 modality MRI
#   - <subject_id>_seg.nii.gz     : Segmentation mask
# Example folder structure:
#   BraTS2021_00495/
#       BraTS2021_00495_flair.nii.gz
#       BraTS2021_00495_t1.nii.gz
#       BraTS2021_00495_t1ce.nii.gz
#       BraTS2021_00495_t2.nii.gz
#       BraTS2021_00495_seg.nii.gz

# Example usage (dataset downloaded via kagglehub into the user's cache directory):
# test_mri_data(os.path.expanduser("~/.cache/kagglehub/datasets/dschettler8845/brats-2021-task1/versions/1/BraTS2021_00495"))