# UNETR Inference Pipeline
This notebook demonstrates the inference process for UNETR-based models, showing how to load NIfTI files, preprocess data, apply transforms, and visualize results. The code is structured for clarity and reusability.

## 1. Setup

### 1.1 Import Libraries

We import necessary libraries for handling medical images, preprocessing, model inference, and visualization.

In [None]:
import os
import torch
import subprocess
import numpy as np
import nibabel as nib
import SimpleITK as sitk
import scipy.ndimage as ndi
import matplotlib.pyplot as plt
from skimage.measure import label

# MONAI imports
from monai.transforms import (
    Compose, LoadImaged, ScaleIntensityRanged,
    EnsureChannelFirstd, ResizeWithPadOrCropd
)
from monai.networks.nets import UNETR
from monai.inferers import sliding_window_inference
from monai.metrics import DiceMetric

# tqdm for progress bar
from tqdm.notebook import tqdm
from calculate_result_metrics import calculate_result_metrics

# Shared Dice metric: the background channel is included and scores are
# mean-reduced when aggregate() is called.
dice_score_metric = DiceMetric(
    reduction="mean",
    include_background=True,
)

### 1.2 Define Paths and Directories
Paths for the model, input MRIs, ground truth masks, and output predictions are set here.

In [None]:
# Placeholder paths: replace every "<...>" value with a real location before running the notebook
model_path = "<MODEL_PATH>"                     # Trained model weights, e.g., 'path/to/best_model.pth'
test_mri_dir = "<MRI_DIR>"                      # Test MRI volumes, e.g., 'path/to/imagesTs'
ground_truth_dir = "<GROUND_TRUTH_DIR>"         # Ground truth masks, e.g., 'path/to/labelsTs'
predicted_masks_dir = "<PREDICTED_MASKS_DIR>"   # Where predicted masks are written, e.g., 'path/to/labelsTs_pred'

# Create the prediction output directory up front (no error if it already exists)
os.makedirs(predicted_masks_dir, exist_ok=True)

## 2. Transformations

### 2.1 Preprocessing Transformations
Preprocessing transformations are applied to normalize intensity values, adjust channels, and resize images to the model's expected input size.

In [None]:
# Preprocessing for test cases: load image and label, make both channel-first,
# rescale the image intensities only (the label is left untouched) and pad/crop
# both to the fixed spatial size used by the network.
test_transforms = Compose(
    transforms=[
        LoadImaged(allow_missing_keys=False, keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        # Map the intensity window [16, 1668] to [0, 1] with clipping;
        # the window is dataset specific and must be adapted to other data.
        ScaleIntensityRanged(keys=["image"], a_min=16, a_max=1668, b_min=0.0, b_max=1.0, clip=True),
        # Image and label are padded/cropped identically so they stay aligned
        ResizeWithPadOrCropd(spatial_size=(224, 224, 64), keys=["image", "label"]),
    ]
)

### 2.2 Postprocessing Transformations

In [None]:
def fill_3d_holes(mask_3d):
    """
    Close the enclosed cavities of a 3D binary mask.

    Args:
        mask_3d (np.ndarray): 3D binary mask (non-zero = object, 0 = background).

    Returns:
        np.ndarray: uint8 mask of the same shape in which every background
        region fully enclosed by the object has been set to 1.
    """
    # binary_fill_holes returns a boolean array; convert it to a 0/1 uint8 mask
    holes_closed = ndi.binary_fill_holes(mask_3d)
    return holes_closed.astype(np.uint8)

def keep_largest_connected_component_3d(mask_3d):
    """
    Reduce a 3D binary mask to its single largest connected component.

    Args:
        mask_3d (np.ndarray): 3D binary mask (1 for object, 0 for background).

    Returns:
        np.ndarray: uint8 mask containing only the largest component. If the
        input has no foreground at all, the input mask is returned unchanged.
    """
    # connectivity=3 is the maximum connectivity for a 3D array
    components = label(mask_3d, connectivity=3)

    # Label 0 is background; a maximum of 0 means there is nothing to keep
    if components.max() == 0:
        return mask_3d

    # Voxel count per foreground label (index 0 corresponds to label 1)
    foreground_sizes = np.bincount(components.ravel())[1:]
    biggest_label = np.argmax(foreground_sizes) + 1

    return (components == biggest_label).astype(np.uint8)


def postprocess_predicted_mask(predicted_mask):
    """
    Clean up a predicted mask with two 3D post-processing steps, in this order:
    1. Fill internal holes (`fill_3d_holes`).
    2. Keep only the largest connected component
       (`keep_largest_connected_component_3d`).

    Args:
        predicted_mask (np.ndarray): 3D binary mask (1 for object, 0 for background).

    Returns:
        np.ndarray: Post-processed binary mask.
    """
    # Holes are filled first, so the component sizes are measured on the filled mask
    return keep_largest_connected_component_3d(fill_3d_holes(predicted_mask))


## 3. Load Model

Check for GPU availability and set the device accordingly for efficient computation. Then load the model.


In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Build the UNETR architecture; these hyper-parameters must match the ones the
# checkpoint was trained with, otherwise load_state_dict will fail.
model = UNETR(
    in_channels=1,
    out_channels=2,
    img_size=(224, 224, 64),
    feature_size=16,
    hidden_size=768,
    mlp_dim=3072,
    num_heads=12,
    proj_type="perceptron",
    norm_name="instance",
    res_block=True,
    dropout_rate=0.0,
)
model = model.to(device)

# NOTE(review): torch.load unpickles the checkpoint -- only load files from a trusted source.
model.load_state_dict(torch.load(model_path, map_location=device))

# Switch to inference mode
model.eval()

## 4. Model Inference

This section performs segmentation on the test dataset using a pre-trained UNETR model. The function processes MRI images, applies transformations, runs model inference with a sliding window, computes Dice scores, and saves the predicted segmentation masks as NIfTI files.

In [None]:
# Function to process the test set
def process_test_set(test_mri_dir, predicted_masks_dir, ground_truth_dir=None):
    """
    Segment every ``.nii.gz`` volume in ``test_mri_dir``, score it against its
    ground-truth mask and save the post-processed prediction as NIfTI.

    Relies on the notebook-level ``test_transforms``, ``model``, ``device``,
    ``dice_score_metric`` and ``postprocess_predicted_mask``.

    Args:
        test_mri_dir (str): Directory containing the test MRI volumes.
        predicted_masks_dir (str): Directory the predicted masks are written to.
        ground_truth_dir (str, optional): Directory containing the ground-truth
            masks. When omitted, the mask path is derived from the MRI path by
            replacing "imagesTs" with "labelsTs" (the previous behaviour).

    Returns:
        list[float]: One Dice score per processed volume, in sorted file order.
    """
    mri_files = sorted(
        os.path.join(test_mri_dir, f) for f in os.listdir(test_mri_dir) if f.endswith('.nii.gz')
    )

    dice_scores = []

    for mri_file in tqdm(mri_files):
        # Ground-truth files carry the same name as the image without the "_0000" suffix
        if ground_truth_dir is None:
            gt_file = mri_file.replace("imagesTs", "labelsTs").replace("_0000", "")
        else:
            gt_file = os.path.join(ground_truth_dir, os.path.basename(mri_file).replace("_0000", ""))

        # Load and preprocess MRI image and GT label (both channel-first after the transforms)
        val_data = test_transforms({"image": mri_file, "label": gt_file})
        mri_img = val_data["image"]
        gt_mask = val_data["label"].to(device)

        # Binarise the ground truth: the smallest value becomes background (0),
        # every other value becomes foreground (1).
        unique_vals = np.unique(gt_mask.cpu().numpy())
        if len(unique_vals) > 1:
            small_val = float(np.min(unique_vals))
            # A plain comparison keeps the result on `device`; the previous torch.where
            # call mixed CPU scalar tensors with a device-side condition.
            gt_mask_binary = (gt_mask != small_val).float()
        else:
            print(f"GT mask contains a single unique value: {unique_vals}. Ensure the data is correct.")
            gt_mask_binary = gt_mask

        # Add the batch dimension expected by the network: (C, H, W, D) -> (1, C, H, W, D)
        mri_inputs = mri_img.unsqueeze(0).to(device)

        # Perform sliding window inference
        with torch.no_grad():
            pred_outputs = sliding_window_inference(mri_inputs, (224, 224, 64), 4, model)
            pred_outputs = torch.softmax(pred_outputs, dim=1)
        pred_mask_argmax = torch.argmax(pred_outputs, dim=1)

        # Post-process the predicted mask in 3D (batch entry 0 -> (H, W, D) array)
        processed_pred_mask = postprocess_predicted_mask(pred_mask_argmax.cpu().numpy()[0])

        # DiceMetric reads dim 0 as batch and dim 1 as channel. Both tensors therefore
        # need an explicit batch AND channel axis; without the channel axis the first
        # spatial axis is treated as channels and the score becomes a mean over
        # per-slice Dice values instead of a volumetric Dice.
        processed_pred_tensor = torch.as_tensor(
            processed_pred_mask, dtype=torch.float32, device=device
        )[None, None]
        gt_for_metric = gt_mask_binary.unsqueeze(0)

        # Calculate Dice score (reset so each case is scored independently)
        dice_score_metric(y_pred=processed_pred_tensor, y=gt_for_metric)
        dice_scores.append(dice_score_metric.aggregate().item())
        dice_score_metric.reset()

        # Save the post-processed mask -- the same mask the Dice score was computed on --
        # so that metrics computed later from the saved files agree with the score above.
        image_meta = val_data["image"].meta
        affine_matrix = image_meta["affine"] if "affine" in image_meta else np.eye(4)
        affine_matrix = np.asarray(affine_matrix, dtype=np.float64)
        pred_nifti = nib.Nifti1Image(processed_pred_mask.astype(np.uint8), affine_matrix)

        pred_nifti_name = os.path.basename(mri_file).replace("_0000", "")
        pred_nifti_path = os.path.join(predicted_masks_dir, pred_nifti_name)
        nib.save(pred_nifti, pred_nifti_path)

    return dice_scores

# Process the dataset
print("Processing Test Set...")
dice_scores = process_test_set(test_mri_dir, predicted_masks_dir)
if dice_scores:
    print(f"Inference complete. Mean Dice Score: {np.mean(dice_scores):.4f}")
else:
    # np.mean([]) would emit a RuntimeWarning and print "nan"; report the actual problem instead
    print(f"Inference complete, but no '.nii.gz' files were found in: {test_mri_dir}")

### 4.1 Evaluate Model Performance

This section uses the `calculate_result_metrics.py` script to evaluate the model's segmentation performance. The script calculates metrics (DICE, Hausdorff, Hausdorff95, S2S, RVE) and groups tumors by size using specified thresholds.

In [None]:
# Size thresholds used to group tumors:
#   - <200:    Small tumors
#   - 200-400: Medium tumors
#   - >400:    Large tumors
thresholds = "[200,400]"

# Compute the result metrics on the saved predictions. Failures are reported
# instead of raised so the remaining cells can still be run.
try:
    calculate_result_metrics(
        gt_folder=ground_truth_dir,
        pred_folder=predicted_masks_dir,
        thresholds=thresholds,
    )
except Exception as e:
    print(f"An error occurred during metrics calculation: {e}")

## 5. Visualize Results

We visualize three key slices from the MRI:
1. First Quarter Slice,
2. Middle Slice,
3. Third Quarter Slice.

For each slice, the following are shown:
1. Original MRI Image.
2. Ground Truth Overlay.
3. Predicted Mask Overlay.
4. Combined Ground Truth and Prediction Overlay.

The visualizations are presented in a grid layout for easy comparison.

In [None]:
# Function to load a NIfTI image
def load_nifti_image(filepath):
    """Read the image at `filepath` with SimpleITK and return its voxel data as a NumPy array."""
    return sitk.GetArrayFromImage(sitk.ReadImage(filepath))

# Function to run one image/label pair through a dictionary-based transform pipeline
def apply_test_transforms(image_file, label_file, test_transforms):
    """
    Apply `test_transforms` to a single image/label pair.

    Args:
        image_file: Value passed to the pipeline under the "image" key.
        label_file: Value passed to the pipeline under the "label" key.
        test_transforms: Callable taking a dict with "image" and "label" keys
            and returning a dict with the same keys.

    Returns:
        tuple: The transformed ("image", "label") entries.
    """
    result = test_transforms({"image": image_file, "label": label_file})
    return result["image"], result["label"]

def visualize_slices_grid(image, ground_truth, prediction):
    """
    Visualize three slices (taken along axis 0) in a 3 x 4 grid:
    1st row: First quarter slice
    2nd row: Middle slice
    3rd row: Third quarter slice

    Each row contains:
    1. Original Image
    2. Ground Truth Overlay
    3. Predicted Mask Overlay
    4. Combined Overlay

    Args:
        image (np.ndarray): 3D volume; slices are taken along its first axis.
        ground_truth (np.ndarray): Ground-truth mask, indexed like `image`.
        prediction (np.ndarray): Predicted mask, indexed like `image`.

    Raises:
        ValueError: If `image` is not 3-dimensional.
    """
    # Validate before any figure is created so invalid input does not leave an
    # empty figure behind (the check used to run inside the plotting loop).
    if len(image.shape) != 3:
        raise ValueError(f"Unexpected image shape: {image.shape}")

    # Calculate slice indices dynamically
    num_slices = image.shape[0]
    slice_indices = [num_slices // 4, num_slices // 2, 3 * num_slices // 4]

    # One entry per column: (title prefix, overlays drawn on top of the image),
    # where each overlay is (mask name, colormap).
    panels = [
        ("Original Image", []),
        ("Ground Truth", [("gt", "Reds")]),
        ("Predicted Mask", [("pred", "Blues")]),
        ("Combined Overlay", [("gt", "Reds"), ("pred", "Blues")]),
    ]

    plt.figure(figsize=(24, 18))  # Adjust figure size for three rows

    for row_idx, slice_idx in enumerate(slice_indices):
        image_slice = image[slice_idx, :, :]
        mask_slices = {
            "gt": ground_truth[slice_idx, :, :],
            "pred": prediction[slice_idx, :, :],
        }

        for col_idx, (title, overlays) in enumerate(panels):
            plt.subplot(3, 4, row_idx * 4 + col_idx + 1)
            if overlays:
                # Slightly transparent background so the mask colours stay visible
                plt.imshow(image_slice, cmap='gray', alpha=0.8)
            else:
                plt.imshow(image_slice, cmap='gray')
            for mask_name, cmap_name in overlays:
                plt.imshow(mask_slices[mask_name], cmap=cmap_name, alpha=0.2)
            plt.title(f'{title} - Slice {slice_idx}')
            plt.axis('off')

    plt.tight_layout()
    plt.show()

In [None]:
# Pick the first case in sorted order so the choice is reproducible
# (os.listdir returns entries in arbitrary order).
basename = next((f for f in sorted(os.listdir(test_mri_dir)) if f.endswith(".nii.gz")), None)
if not basename:
    raise FileNotFoundError("No NIfTI files found in the MRI directory.")

# Generate file paths
image_file_path = os.path.join(test_mri_dir, basename)
gt_file_path = os.path.join(ground_truth_dir, basename.replace("_0000", ""))
pred_file_path = os.path.join(predicted_masks_dir, basename.replace("_0000", ""))

# Load the raw volumes as numpy arrays to report their on-disk shapes.
# The ground truth is read from `ground_truth_dir` itself; the previous lookup in a
# 'labelsTs_transformed' directory depended on files this notebook never creates.
mri_nifti = load_nifti_image(image_file_path)
ground_truth_nifti = load_nifti_image(gt_file_path)
predicted_nifti = load_nifti_image(pred_file_path)

print(f"Image shape: {mri_nifti.shape}")
print(f"Ground Truth shape: {ground_truth_nifti.shape}")
print(f"Prediction shape: {predicted_nifti.shape}")

# Bring image, ground truth and prediction onto the same grid with the test transforms
img_transformed, gt_transformed = apply_test_transforms(image_file_path, gt_file_path, test_transforms)
_, pred_transformed = apply_test_transforms(image_file_path, pred_file_path, test_transforms)

# Ensure that the masks have the same shape
if gt_transformed.shape != pred_transformed.shape:
    print(f"Shape mismatch: GT shape: {gt_transformed.shape}, Pred shape: {pred_transformed.shape}")
else:
    # Plot the transformed volumes so that image and both masks are guaranteed to be
    # aligned. Drop the channel axis and reverse the spatial axes so the last axis
    # (size 64 after the transforms, assumed to be the slice axis -- TODO confirm)
    # comes first, which is where visualize_slices_grid takes its slices from.
    image_vol, gt_vol, pred_vol = (
        np.asarray(vol)[0].transpose(2, 1, 0)
        for vol in (img_transformed, gt_transformed, pred_transformed)
    )
    visualize_slices_grid(image_vol, gt_vol, pred_vol)