# In [None]:
import os
import cv2
import numpy as np
import nibabel as nib
import matplotlib.pyplot as plt
from tensorflow.keras.models import load_model
# === Load trained model ===
# The path is relative, so it resolves against the current working directory.
# compile=False: the model is not compiled after loading; this file only calls
# model.predict on it.
model = load_model("../models/unet1.h5", compile=False)
# === Load and normalize NIfTI slices ===
def load_nifti_slices(nifti_path, target_shape=(128, 128)):
    """Read a NIfTI volume, min-max scale it, and resize each slice.

    Args:
        nifti_path: Path of the NIfTI file, handed to ``nib.load``.
        target_shape: Output size for each slice. It is passed unchanged to
            ``cv2.resize`` as ``dsize``, which OpenCV reads as
            ``(width, height)``.

    Returns:
        A ``(slices, original_shape)`` tuple. ``slices`` stacks the resized
        planes taken along the third axis of the volume (one per index);
        ``original_shape`` is the shape of the volume before resizing.
    """
    data = nib.load(nifti_path).get_fdata()

    # Min-max scaling over the whole volume (not per slice); the epsilon keeps
    # a constant-valued volume from dividing by zero.
    lowest = np.min(data)
    highest = np.max(data)
    data = (data - lowest) / (highest - lowest + 1e-8)

    planes = []
    for index in range(data.shape[2]):
        planes.append(cv2.resize(data[:, :, index], target_shape))
    return np.array(planes), data.shape
# === Apply color mask ===
def apply_mask(mask, class_val, color):
    """Paint the pixels of one class onto a black 3-channel canvas.

    Args:
        mask: Array of class labels.
        class_val: Label whose pixels are coloured.
        color: Value written along the last axis wherever ``mask`` equals
            ``class_val`` (callers in this file pass an RGB triple).

    Returns:
        A ``uint8`` array of shape ``mask.shape + (3,)`` that is zero
        everywhere except at the matching pixels.
    """
    canvas_shape = tuple(mask.shape) + (3,)
    canvas = np.zeros(canvas_shape, dtype=np.uint8)
    selected = mask == class_val
    canvas[selected] = color
    return canvas
# === Display predictions with overlays ===
def predict_and_display(nifti_path, max_slices=151):
    """Segment the slices of a NIfTI volume and plot per-class overlays.

    The volume is loaded with ``load_nifti_slices``, the first ``max_slices``
    slices are run through the module-level ``model``, and one four-panel
    figure is shown per slice: the input slice followed by one colour mask
    for each of the class labels 3, 1 and 2.

    Args:
        nifti_path: Path of the NIfTI file to segment.
        max_slices: Upper bound on the number of slices that are predicted
            and displayed (the original hard-coded limit was 151).

    Returns:
        None. The function only displays matplotlib figures.
    """
    slices, _ = load_nifti_slices(nifti_path)
    count = min(max_slices, slices.shape[0])
    if count <= 0:
        # Nothing to show; also avoids handing an empty batch to the model.
        return

    # Predict every slice in one batched call instead of one predict() call
    # per slice, which is slow and prints a progress bar for each slice.
    batch = slices[:count, ..., np.newaxis].astype(np.float32)
    probabilities = model.predict(batch)   # presumably (count, 128, 128, 4) - confirm against the model
    pred_masks = np.argmax(probabilities, axis=-1)

    # (panel title, class label, RGB colour) for each overlay panel.
    overlays = (
        ("Enhancing Tumor (Red)", 3, [255, 0, 0]),
        ("Non-Enhancing Tumor (Blue)", 1, [0, 0, 255]),
        ("FLAIR Hyperintensity (Yellow)", 2, [255, 255, 0]),
    )

    for i, (input_img, pred_mask) in enumerate(zip(slices[:count], pred_masks)):
        # Replicate the normalized grayscale slice into three 8-bit channels.
        input_rgb = np.stack([input_img] * 3, axis=-1)
        input_rgb = (input_rgb * 255).astype(np.uint8)

        fig, axs = plt.subplots(1, 4, figsize=(20, 5))
        axs[0].imshow(input_rgb)
        axs[0].set_title(f"Input Slice {i}")
        for ax, (title, class_val, color) in zip(axs[1:], overlays):
            ax.imshow(apply_mask(pred_mask, class_val, color))
            ax.set_title(title)
        for ax in axs:
            ax.axis('off')
        plt.tight_layout()
        plt.show()
        # Release the figure so figures do not pile up across slices.
        plt.close(fig)
# === Run prediction display ===
# NOTE(review): this call executes as soon as the cell/module runs and uses a
# machine-specific absolute path; point it at a local NIfTI file before
# running on another machine.
predict_and_display("C:/Users/manju/OneDrive/Desktop/Brain Tumour/data/brats-men-train/BraTS-MEN-00010-000/BraTS-MEN-00010-000-t1c.nii")
