In [None]:
import os
import numpy as np
import nibabel as nib
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.keras.models import load_model

# The subject_folder should contain the following files (BraTS format):
#   - <subject_id>_flair.nii.gz   : FLAIR modality MRI
#   - <subject_id>_t1.nii.gz      : T1 modality MRI
#   - <subject_id>_t1ce.nii.gz    : T1ce modality MRI
#   - <subject_id>_t2.nii.gz      : T2 modality MRI
#   - <subject_id>_seg.nii.gz     : Segmentation mask (labels typically {0,1,2,4})
# All files must be named with the same <subject_id> prefix as the folder name.

# MRI modality suffixes, in the order they are stacked along the channel axis
# by load_modalities.
MODALITIES = [
    "flair",
    "t1",
    "t1ce",
    "t2",
]

# Number of segmentation classes (0..3); BraTS label 4 is remapped to class 3.
NUM_CLASSES = 4

def load_modalities(subject_path, modalities=MODALITIES):
    """Load every requested modality of one subject and stack them channel-last.

    Each modality is read from ``<subject_path>/<subject_id>_<modality>.nii.gz``,
    where ``<subject_id>`` is the name of the subject folder.

    Args:
        subject_path: Path to the subject folder. A trailing path separator
            is tolerated.
        modalities: Iterable of modality suffixes; the channel order of the
            result follows the order given here.

    Returns:
        Array with the modality volumes stacked along a new last axis; for
        3-D volumes this is (H, W, Slices, len(modalities)).
    """
    # normpath strips any trailing separator; without it basename("dir/")
    # would be "" and every file name would degrade to "_<modality>.nii.gz".
    sid = os.path.basename(os.path.normpath(subject_path))
    vols = [
        nib.load(os.path.join(subject_path, f"{sid}_{mod}.nii.gz")).get_fdata()
        for mod in modalities
    ]
    return np.stack(vols, axis=-1)

def load_mask(subject_path):
    """Load the segmentation mask of one subject.

    The mask is read from ``<subject_path>/<subject_id>_seg.nii.gz``, where
    ``<subject_id>`` is the name of the subject folder.

    Args:
        subject_path: Path to the subject folder. A trailing path separator
            is tolerated.

    Returns:
        The mask volume as returned by ``get_fdata()``; for BraTS data this
        is expected to be (H, W, Slices).
    """
    # normpath strips any trailing separator; without it basename("dir/")
    # would be "" and the file name would degrade to "_seg.nii.gz".
    sid = os.path.basename(os.path.normpath(subject_path))
    nii = nib.load(os.path.join(subject_path, f"{sid}_seg.nii.gz"))
    return nii.get_fdata()

def preprocess_volume_multiclass(volume, mask, target_size=(128, 128)):
    """Turn a multi-modal volume and its mask into per-slice model inputs.

    Slices whose mask is entirely zero are skipped. Every kept slice is
    min-max normalised per channel, resized to ``target_size``, and paired
    with its mask resized by nearest-neighbour interpolation.

    Args:
        volume: Array indexed as (H, W, Slices, Channels).
        mask: Array indexed as (H, W, Slices) holding integer-valued labels.
        target_size: (height, width) of the output slices.

    Returns:
      X: (N, H, W, Channels) float32
      y: (N, H, W) uint8 with classes {0,1,2,3} where original label 4 -> 3

      When no slice contains a non-zero label, N is 0 and both arrays still
      carry the full trailing dimensions.
    """
    X, y = [], []
    for i in range(volume.shape[2]):
        img_slice = volume[:, :, i, :]   # (H, W, Channels)
        mask_slice = mask[:, :, i]       # (H, W)

        # only slices that contain at least one labelled voxel are kept
        if np.max(mask_slice) == 0:
            continue

        # min-max normalise each channel over the slice; the epsilon keeps
        # constant channels from dividing by zero
        lo = np.min(img_slice, axis=(0, 1))
        span = np.ptp(img_slice, axis=(0, 1))
        img_slice = (img_slice - lo) / (span + 1e-8)
        img_slice = tf.image.resize(img_slice, target_size).numpy().astype(np.float32)

        # resize mask with nearest so labels are never blended; drop only
        # the channel axis that was added (squeeze() would also remove a
        # spatial axis of length 1)
        mask_slice = tf.image.resize(mask_slice[..., None], target_size, method="nearest").numpy()[..., 0]
        mask_slice = np.rint(mask_slice).astype(np.uint8)

        # BraTS commonly uses labels {0,1,2,4}; remap 4 -> 3 to make classes contiguous
        mask_slice[mask_slice == 4] = 3
        mask_slice = np.clip(mask_slice, 0, 3).astype(np.uint8)

        X.append(img_slice)
        y.append(mask_slice)

    if not X:
        # np.array([]) would collapse to shape (0,); keep the documented rank
        # so downstream shape-based code still works on an empty result.
        height, width = target_size
        return (
            np.zeros((0, height, width, volume.shape[-1]), dtype=np.float32),
            np.zeros((0, height, width), dtype=np.uint8),
        )

    return np.array(X, dtype=np.float32), np.array(y, dtype=np.uint8)

def dice_iou_per_class(y_true, y_pred, num_classes=NUM_CLASSES, smooth=1e-6):
    """Compute the Dice coefficient and IoU separately for each class label.

    Args:
        y_true: Array of ground-truth class labels.
        y_pred: Array of predicted class labels, comparable element-wise
            with ``y_true``.
        num_classes: Labels 0..num_classes-1 are scored.
        smooth: Added to numerator and denominator of both scores; a class
            that is absent from both arrays therefore scores 1.0.

    Returns:
        (dices, ious): two lists of Python floats, one entry per class.
    """

    def _score(label):
        # binary membership maps for this label
        truth = (y_true == label).astype(np.float32)
        guess = (y_pred == label).astype(np.float32)

        overlap = np.sum(truth * guess)
        truth_area = np.sum(truth)
        guess_area = np.sum(guess)

        dice_value = (2.0 * overlap + smooth) / (truth_area + guess_area + smooth)
        iou_value = (overlap + smooth) / (truth_area + guess_area - overlap + smooth)
        return float(dice_value), float(iou_value)

    per_class = [_score(label) for label in range(num_classes)]
    dices = [pair[0] for pair in per_class]
    ious = [pair[1] for pair in per_class]
    return dices, ious

def plot_sample_multiclass(X, y_true, y_pred, idx):
    """Show one sample as three side-by-side panels.

    The panels are: channel 0 of the input slice (titled "FLAIR"), the
    ground-truth label map, and the predicted label map. Both label maps
    share the same colour scale so classes are directly comparable.

    Args:
        X: Input slices, indexed as X[idx][..., channel].
        y_true: Ground-truth label maps, indexed by sample.
        y_pred: Predicted label maps, indexed by sample.
        idx: Index of the sample to display.
    """
    label_style = {"cmap": "tab10", "vmin": 0, "vmax": NUM_CLASSES - 1}
    # (title, source array, channel to pick or None, imshow keyword arguments)
    panels = (
        ("FLAIR", X, 0, {"cmap": "gray"}),
        ("Ground Truth (0..3)", y_true, None, label_style),
        ("Prediction (0..3)", y_pred, None, label_style),
    )

    plt.figure(figsize=(12, 4))
    for position, (title, source, channel, style) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        image = source[idx] if channel is None else source[idx][..., channel]
        plt.imshow(image, **style)
        plt.title(title)
        plt.axis("off")

    plt.show()

def test_mri_data_multiclass(subject_folder, model_path="unet_brats_multiclass.h5", target_size=(128, 128), n_plots=10):
    """Evaluate a saved multiclass segmentation model on one subject.

    Loads the model and the subject's volumes, predicts a label map for every
    slice that contains tumour labels, prints pixel accuracy plus per-class
    Dice/IoU, and plots a random selection of slices.

    Args:
        subject_folder: Folder holding the subject's modality and
            segmentation files.
        model_path: Path to the saved Keras model.
        target_size: (height, width) the slices are resized to before
            prediction.
        n_plots: Maximum number of distinct slices to plot.

    Raises:
        ValueError: If the subject has no slice with a non-zero label, or if
            the model output is not (N, H, W, NUM_CLASSES).
    """
    # Load without compiling (avoids needing custom_objects)
    model = load_model(model_path, compile=False)

    mri = load_modalities(subject_folder)
    mask = load_mask(subject_folder)
    X_test, y_test = preprocess_volume_multiclass(mri, mask, target_size=target_size)

    n_slices = X_test.shape[0]
    if n_slices == 0:
        raise ValueError("No non-empty tumor slices found in this subject (all mask slices were 0).")

    y_prob = np.asarray(model.predict(X_test, verbose=0))   # expected (N, H, W, C)

    # A model with a different channel count (e.g. a single-channel binary
    # model) would pass through argmax silently and give meaningless metrics.
    if y_prob.ndim != 4 or y_prob.shape[-1] != NUM_CLASSES:
        raise ValueError(
            f"Model output has shape {y_prob.shape}; expected (N, H, W, {NUM_CLASSES}). "
            f"Is '{model_path}' a {NUM_CLASSES}-class segmentation model?"
        )

    y_pred = np.argmax(y_prob, axis=-1).astype(np.uint8)    # (N, H, W)

    acc = float(np.mean(y_pred == y_test))
    dices, ious = dice_iou_per_class(y_test, y_pred, num_classes=NUM_CLASSES)

    print(f"Accuracy: {acc:.4f}")
    for c in range(NUM_CLASSES):
        print(f"Class {c}: Dice={dices[c]:.4f}, IoU={ious[c]:.4f}")
    print(f"Mean Dice (tumor classes 1-3): {np.mean(dices[1:]):.4f}")

    # Sample without replacement so no slice is shown twice, and never ask
    # for more slices than exist; sorting shows them in slice order.
    n_shown = max(0, min(int(n_plots), n_slices))
    chosen = np.sort(np.random.choice(n_slices, size=n_shown, replace=False))
    for idx in chosen:
        plot_sample_multiclass(X_test, y_test, y_pred, int(idx))

In [None]:
# Example: evaluate one BraTS 2021 subject from the kagglehub download cache.
# expanduser("~") keeps the path independent of the local user name.
SUBJECT_DIR = os.path.expanduser(
    "~/.cache/kagglehub/datasets/dschettler8845/brats-2021-task1/versions/1/BraTS2021_00495"
)
MODEL_PATH = "unet_brats_segmentation.h5"

test_mri_data_multiclass(SUBJECT_DIR, model_path=MODEL_PATH, n_plots=10)

# To evaluate your own data, pass a folder that follows the BraTS file layout:
# test_mri_data_multiclass("/path/to/your/mri/data", model_path="unet_brats_multiclass.h5", n_plots=10)