# Inference Visualization

This notebook visualizes model predictions on test volumes.

## Objectives
- Load a test volume
- Load or generate a prediction mask
- Overlay volume + segmentation
- Display 3 standard views (axial, coronal, sagittal)
- Calculate Dice and IoU against a ground-truth mask (when one is available)


In [None]:
import sys
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
import torch
import nibabel as nib
from typing import Optional, Tuple

sys.path.insert(0, str(Path().absolute().parent))

from src.inference import predict
from src.inference import visualize
from src.data import preprocessing


## 1. Load Test Volume

Load a test volume from NIfTI file or DICOM directory.


In [None]:
# Configuration: specify path to your test volume
# Option 1: NIfTI file (.nii or .nii.gz)
volume_path = Path("../data/processed/volume_001.nii.gz")

# Option 2: DICOM directory (uncomment to use)
# volume_path = Path("../data/raw/patient_001")

# Both stay None unless a loader below succeeds; later cells test `volume is None`.
volume = None
metadata = None

print(f"Loading volume from: {volume_path}")

if not volume_path.exists():
    # Checked first so a missing DICOM directory is reported as missing,
    # rather than falling through to "Unsupported file format".
    print(f"ERROR: Volume file not found at {volume_path}")
    print("Please specify correct path or run preprocessing first.")
elif volume_path.is_dir():
    # DICOM format (directory input)
    volume, metadata = preprocessing.load_dicom_volume(volume_path)
    print("Loaded DICOM volume")
elif volume_path.name.lower().endswith((".nii", ".nii.gz")):
    # NIfTI format. Match the full extension: Path.suffix of "x.nii.gz" is just
    # ".gz", so a suffix test would also accept unrelated gzip files.
    volume, metadata = preprocessing.load_nifti_volume(volume_path)
    print("Loaded NIfTI volume")
else:
    print(f"ERROR: Unsupported file format: {volume_path}")

# Display volume information
if volume is not None:
    print(f"\nVolume shape: {volume.shape}")
    # assumes metadata is dict-like (it is read with .get) — TODO confirm against preprocessing
    print(f"Volume spacing: {metadata.get('spacing', 'N/A')}")
    print(f"Volume intensity range: [{volume.min():.2f}, {volume.max():.2f}]")
    print(f"Volume dtype: {volume.dtype}")
else:
    print("Cannot proceed without volume.")


## 2. Load or Generate Prediction Mask

Load a saved prediction mask or generate one using a trained model.


In [None]:
# Configuration
# Option 1 (default): load a saved prediction mask; if that file is missing,
# fall back to running inference with the checkpoint below.
prediction_path = Path("../data/processed/prediction_001.nii.gz")
checkpoint_path = Path("../checkpoints/best_model.pth")

# Option 2: always generate the prediction with the trained model (uncomment to use)
# prediction_path = None

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def generate_prediction(
    checkpoint_path: Path,
    volume_path: Path,
    volume_loaded: bool,
    device: torch.device,
):
    """Run model inference on ``volume_path`` via ``predict.predict_from_file``.

    Args:
        checkpoint_path: Trained model checkpoint to load.
        volume_path: Input volume path handed to ``predict.predict_from_file``.
        volume_loaded: Whether the volume cell loaded successfully; inference
            is skipped when it did not.
        device: Torch device used for inference.

    Returns:
        The first element returned by ``predict.predict_from_file`` (used below
        as an array-like mask), or ``None`` when the checkpoint is missing or
        the volume was not loaded.
    """
    if not (checkpoint_path.exists() and volume_loaded):
        print("ERROR: Cannot generate prediction - checkpoint not found or volume not loaded")
        return None
    print(f"Generating prediction using checkpoint: {checkpoint_path}")
    prediction, _ = predict.predict_from_file(
        checkpoint_path=checkpoint_path,
        volume_path=volume_path,
        device=device,
        threshold=0.5,
        apply_morphology=True,
    )
    print("Generated prediction")
    return prediction


# Load or generate prediction. `checkpoint_path` comes from the configuration
# above, so editing it there is honoured in every branch.
if prediction_path is not None and prediction_path.exists():
    print(f"Loading prediction from: {prediction_path}")
    prediction_img = nib.load(str(prediction_path))
    segmentation = prediction_img.get_fdata().astype(np.float32)
    print("Loaded saved prediction")
else:
    if prediction_path is not None:
        print(f"Prediction file not found at {prediction_path}")
        print("Attempting to generate prediction...")
    segmentation = generate_prediction(checkpoint_path, volume_path, volume is not None, device)

if segmentation is None:
    print("Cannot proceed without segmentation.")
else:
    # Put the mask on the volume's grid first, so the statistics printed below
    # describe the same array every later cell uses.
    if volume is not None and volume.shape != segmentation.shape:
        print(f"\nWARNING: Volume shape {volume.shape} != Segmentation shape {segmentation.shape}")
        print("Resampling segmentation to match volume...")
        from scipy.ndimage import zoom
        zoom_factors = [v / s for v, s in zip(volume.shape, segmentation.shape)]
        # Linear interpolation (order=1); the metrics/statistics cells re-threshold at 0.5.
        segmentation = zoom(segmentation, zoom_factors, order=1)
        assert segmentation.shape == volume.shape, (
            f"Resampling produced {segmentation.shape}, expected {volume.shape}"
        )
        print(f"Resampled segmentation shape: {segmentation.shape}")

    num_foreground = np.sum(segmentation > 0.5)
    print(f"\nSegmentation shape: {segmentation.shape}")
    print(f"Segmentation value range: [{segmentation.min():.2f}, {segmentation.max():.2f}]")
    print(f"Segmentation dtype: {segmentation.dtype}")
    print(f"Segmented voxels: {num_foreground:,} / {segmentation.size:,} ({100 * num_foreground / segmentation.size:.2f}%)")


In [None]:
def plot_three_views(
    volume: np.ndarray,
    segmentation: np.ndarray,
    slice_indices: Optional[Tuple[int, int, int]] = None,
    save_path: Optional[Path] = None,
):
    """Plot axial, coronal and sagittal slices with the segmentation overlaid in red.

    Args:
        volume: Input volume array (D, H, W).
        segmentation: Segmentation mask array with the same shape as ``volume``.
        slice_indices: Optional tuple of (axial_idx, coronal_idx, sagittal_idx).
                       If None, uses middle slices. Out-of-range indices are
                       clamped to the nearest valid slice.
        save_path: Optional path to save the figure. When given, the figure is
                   saved instead of shown.

    Raises:
        ValueError: If ``volume`` and ``segmentation`` shapes differ.
    """
    if volume.shape != segmentation.shape:
        raise ValueError(
            f"volume shape {volume.shape} != segmentation shape {segmentation.shape}"
        )
    d, h, w = volume.shape

    if slice_indices is None:
        slice_indices = (d // 2, h // 2, w // 2)
    # Clamp like plot_slice_comparison does, so a bad index never wraps
    # (negative) or raises IndexError (too large).
    axial_idx, coronal_idx, sagittal_idx = (
        max(0, min(idx, size - 1)) for idx, size in zip(slice_indices, (d, h, w))
    )

    views = [
        ("Axial", volume[axial_idx, :, :], segmentation[axial_idx, :, :], axial_idx, d),
        ("Coronal", volume[:, coronal_idx, :], segmentation[:, coronal_idx, :], coronal_idx, h),
        ("Sagittal", volume[:, :, sagittal_idx], segmentation[:, :, sagittal_idx], sagittal_idx, w),
    ]

    fig, axes = plt.subplots(1, 3, figsize=(18, 6))
    fig.suptitle("Segmentation Visualization - 3 Standard Views", fontsize=16, y=1.02)

    for ax, (view_name, vol_slice, seg_slice, idx, size) in zip(axes, views):
        ax.imshow(vol_slice, cmap="gray", aspect="auto")
        # Mask background voxels: otherwise the near-white low end of "Reds" is
        # blended over every pixel and washes out the image. Fixing vmin/vmax
        # keeps an all-foreground slice red instead of normalising it to white.
        ax.imshow(
            np.ma.masked_less_equal(seg_slice, 0),
            cmap="Reds",
            alpha=0.5,
            vmin=0.0,
            vmax=1.0,
            aspect="auto",
        )
        ax.set_title(f"{view_name} View (Slice {idx}/{size-1})", fontsize=14)
        ax.axis("off")

    plt.tight_layout()

    if save_path:
        save_path = Path(save_path)
        save_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(save_path, dpi=150, bbox_inches="tight")
        print(f"Figure saved to {save_path}")
    else:
        plt.show()

    plt.close(fig)

if volume is None or segmentation is None:
    print("Cannot plot views: volume or segmentation not loaded.")
else:
    # Default call: middle slice along each axis, displayed inline (no save_path).
    plot_three_views(volume, segmentation)


In [None]:
def plot_slice_comparison(
    volume: np.ndarray,
    segmentation: np.ndarray,
    axial_idx: int,
    coronal_idx: int,
    sagittal_idx: int,
):
    """Plot specific slices for detailed inspection.

    Top row shows the volume alone; bottom row shows the same slice with the
    segmentation overlaid in red. Out-of-range indices are clamped to the
    nearest valid slice.

    Args:
        volume: Input volume array (D, H, W).
        segmentation: Segmentation mask array with the same shape as ``volume``.
        axial_idx: Axial slice index.
        coronal_idx: Coronal slice index.
        sagittal_idx: Sagittal slice index.

    Raises:
        ValueError: If ``volume`` and ``segmentation`` shapes differ.
    """
    if volume.shape != segmentation.shape:
        raise ValueError(
            f"volume shape {volume.shape} != segmentation shape {segmentation.shape}"
        )
    d, h, w = volume.shape

    # Clamp indices to valid range
    axial_idx = max(0, min(axial_idx, d - 1))
    coronal_idx = max(0, min(coronal_idx, h - 1))
    sagittal_idx = max(0, min(sagittal_idx, w - 1))

    fig, axes = plt.subplots(2, 3, figsize=(18, 12))
    fig.suptitle("Detailed Slice Comparison", fontsize=16, y=0.995)

    views = [
        ("Axial", volume[axial_idx, :, :], segmentation[axial_idx, :, :], axial_idx, d),
        ("Coronal", volume[:, coronal_idx, :], segmentation[:, coronal_idx, :], coronal_idx, h),
        ("Sagittal", volume[:, :, sagittal_idx], segmentation[:, :, sagittal_idx], sagittal_idx, w),
    ]

    for col, (view_name, vol_slice, seg_slice, idx, max_idx) in enumerate(views):
        # Volume only
        axes[0, col].imshow(vol_slice, cmap="gray", aspect="auto")
        axes[0, col].set_title(f"{view_name} - Volume Only\n(Slice {idx}/{max_idx-1})", fontsize=12)
        axes[0, col].axis("off")

        # Volume + Segmentation. Background voxels are masked so the pale low
        # end of "Reds" is not blended over the whole slice, and vmin/vmax are
        # fixed so an all-foreground slice is still drawn red.
        axes[1, col].imshow(vol_slice, cmap="gray", aspect="auto")
        axes[1, col].imshow(
            np.ma.masked_less_equal(seg_slice, 0),
            cmap="Reds",
            alpha=0.6,
            vmin=0.0,
            vmax=1.0,
            aspect="auto",
        )
        axes[1, col].set_title(f"{view_name} - Volume + Segmentation", fontsize=12)
        axes[1, col].axis("off")

    plt.tight_layout()
    plt.show()
    plt.close(fig)

if volume is None or segmentation is None:
    print("Cannot plot slice comparison: volume or segmentation not loaded.")
else:
    d, h, w = volume.shape
    print("Plotting middle slices...")
    # Centre slice along each axis. Pass other indices to explore the volume, e.g.
    # plot_slice_comparison(volume, segmentation, axial_idx=50, coronal_idx=100, sagittal_idx=80)
    plot_slice_comparison(
        volume,
        segmentation,
        axial_idx=d // 2,
        coronal_idx=h // 2,
        sagittal_idx=w // 2,
    )


## 5. Calculate Dice Metric

Calculate Dice and IoU scores by comparing the prediction with a ground-truth mask (if one is available).


In [None]:
def _binarize_prediction(pred: np.ndarray, threshold: float) -> np.ndarray:
    """Turn a prediction (probabilities or logits) into a boolean foreground mask.

    If any value lies outside [0, 1], the array is treated as logits and a voxel
    is foreground when ``sigmoid(pred) > threshold``. Otherwise it is foreground
    when ``pred > threshold``.
    """
    if pred.size and (pred.max() > 1.0 or pred.min() < 0.0):
        # Cast integer/bool input to float first: negating an unsigned mask
        # (e.g. a 0/255 uint8 mask) wraps around instead of flipping sign.
        logits = pred if np.issubdtype(pred.dtype, np.floating) else pred.astype(np.float32)
        # Large negative logits overflow np.exp to inf; 1 / (1 + inf) == 0 is
        # the correct limit, so only the RuntimeWarning is silenced.
        with np.errstate(over="ignore"):
            probs = 1.0 / (1.0 + np.exp(-logits))
        return probs > threshold
    return pred > threshold


def _binarize_pair(
    pred: np.ndarray,
    target: np.ndarray,
    threshold: float,
) -> Tuple[np.ndarray, np.ndarray]:
    """Validate shapes and return (prediction mask, target mask) as booleans.

    Raises:
        ValueError: If ``pred`` and ``target`` shapes differ.
    """
    pred = np.asarray(pred)
    target = np.asarray(target)
    if pred.shape != target.shape:
        raise ValueError(f"pred shape {pred.shape} != target shape {target.shape}")
    return _binarize_prediction(pred, threshold), target > 0.5


def calculate_dice_numpy(
    pred: np.ndarray,
    target: np.ndarray,
    threshold: float = 0.5,
    smooth: float = 1e-5,
) -> float:
    """Calculate Dice score from numpy arrays.

    Args:
        pred: Predicted mask array (D, H, W) as probabilities, logits or labels.
              Arrays with values outside [0, 1] are treated as logits.
        target: Ground truth mask array with the same shape; voxels > 0.5 are foreground.
        threshold: Threshold for binarization (default: 0.5).
        smooth: Smoothing factor to avoid division by zero (default: 1e-5).

    Returns:
        Dice score (float). Two empty masks score 1.0.

    Raises:
        ValueError: If ``pred`` and ``target`` shapes differ.
    """
    pred_mask, target_mask = _binarize_pair(pred, target, threshold)
    # Integer voxel counts stay exact for any volume size (float32 sums do not above 2**24).
    intersection = np.count_nonzero(pred_mask & target_mask)
    denominator = np.count_nonzero(pred_mask) + np.count_nonzero(target_mask)
    return float((2.0 * intersection + smooth) / (denominator + smooth))


def calculate_iou_numpy(
    pred: np.ndarray,
    target: np.ndarray,
    threshold: float = 0.5,
    smooth: float = 1e-5,
) -> float:
    """Calculate IoU (Jaccard) score from numpy arrays.

    Args:
        pred: Predicted mask array (D, H, W) as probabilities, logits or labels.
              Arrays with values outside [0, 1] are treated as logits.
        target: Ground truth mask array with the same shape; voxels > 0.5 are foreground.
        threshold: Threshold for binarization (default: 0.5).
        smooth: Smoothing factor to avoid division by zero (default: 1e-5).

    Returns:
        IoU score (float). Two empty masks score 1.0.

    Raises:
        ValueError: If ``pred`` and ``target`` shapes differ.
    """
    pred_mask, target_mask = _binarize_pair(pred, target, threshold)
    intersection = np.count_nonzero(pred_mask & target_mask)
    union = np.count_nonzero(pred_mask | target_mask)
    return float((intersection + smooth) / (union + smooth))


# Load ground truth mask (if available) and score the prediction against it.
ground_truth_path = Path("../data/processed/mask_001.nii.gz")
# Stays None unless a ground truth is loaded and compared; the statistics cell accepts None.
ground_truth = None

if not ground_truth_path.exists():
    print(f"Ground truth not found at {ground_truth_path}")
    print("Skipping Dice calculation. To calculate Dice, provide a ground truth mask.")
elif segmentation is None:
    # Without a prediction there is nothing to compare (and no target shape to resample to).
    print("Skipping Dice calculation: no prediction available to compare against.")
else:
    print(f"Loading ground truth from: {ground_truth_path}")
    gt_img = nib.load(str(ground_truth_path))
    ground_truth = gt_img.get_fdata().astype(np.float32)
    print("Loaded ground truth mask")

    # Resample if needed
    if ground_truth.shape != segmentation.shape:
        print(f"Resampling ground truth from {ground_truth.shape} to {segmentation.shape}")
        from scipy.ndimage import zoom
        zoom_factors = [s / g for s, g in zip(segmentation.shape, ground_truth.shape)]
        # Nearest-neighbour (order=0) keeps the mask's original label values;
        # linear interpolation would invent fractional labels along boundaries.
        ground_truth = zoom(ground_truth, zoom_factors, order=0)

    # Calculate metrics
    dice_score = calculate_dice_numpy(segmentation, ground_truth, threshold=0.5)
    iou_score = calculate_iou_numpy(segmentation, ground_truth, threshold=0.5)

    print("\n" + "=" * 60)
    print("SEGMENTATION METRICS")
    print("=" * 60)
    print(f"Dice Score: {dice_score:.4f}")
    print(f"IoU Score: {iou_score:.4f}")
    print("=" * 60)

    # Visualize comparison
    if volume is not None:
        print("\nVisualizing prediction vs ground truth...")
        visualize.visualize_2d_slices(
            volume,
            segmentation,
            ground_truth=ground_truth,
        )


In [None]:
def print_segmentation_statistics(
    volume: np.ndarray,
    segmentation: np.ndarray,
    ground_truth: Optional[np.ndarray] = None,
):
    """Print summary statistics about the volume, segmentation and optional ground truth.

    Voxels with value > 0.5 count as foreground in both masks.

    Args:
        volume: Input volume array (D, H, W).
        segmentation: Segmentation mask array (D, H, W).
        ground_truth: Optional ground truth mask with the same shape as ``segmentation``.

    Raises:
        ValueError: If ``ground_truth`` is given and its shape differs from
            ``segmentation`` (otherwise the overlap would be silently broadcast).
    """
    print("=" * 60)
    print("SEGMENTATION STATISTICS")
    print("=" * 60)

    # Volume statistics
    print("\nVolume:")
    print(f"  Shape: {volume.shape}")
    print(f"  Intensity range: [{volume.min():.2f}, {volume.max():.2f}]")
    print(f"  Mean intensity: {volume.mean():.2f}")

    # Segmentation statistics. Boolean masks + integer counts are exact for any
    # volume size (a float32 sum cannot represent every count above 2**24) and
    # print as "1,234" rather than "1,234.0".
    pred_mask = segmentation > 0.5
    num_voxels = segmentation.size
    num_segmented = int(np.count_nonzero(pred_mask))
    volume_ratio = num_segmented / num_voxels * 100

    print("\nSegmentation:")
    print(f"  Shape: {segmentation.shape}")
    print(f"  Value range: [{segmentation.min():.2f}, {segmentation.max():.2f}]")
    print(f"  Segmented voxels: {num_segmented:,} / {num_voxels:,} ({volume_ratio:.2f}%)")

    if num_segmented > 0:
        # Inclusive bounding box of foreground voxels, per axis
        coords = np.argwhere(pred_mask)
        bbox_min = coords.min(axis=0)
        bbox_max = coords.max(axis=0)
        bbox_size = bbox_max - bbox_min + 1

        print(f"  Bounding box: {bbox_min} to {bbox_max}")
        print(f"  Bounding box size: {bbox_size}")

    # Ground truth comparison
    if ground_truth is not None:
        if ground_truth.shape != segmentation.shape:
            raise ValueError(
                f"ground_truth shape {ground_truth.shape} != segmentation shape {segmentation.shape}"
            )
        gt_mask = ground_truth > 0.5
        num_gt = int(np.count_nonzero(gt_mask))
        gt_ratio = num_gt / num_voxels * 100

        print("\nGround Truth:")
        print(f"  Segmented voxels: {num_gt:,} / {num_voxels:,} ({gt_ratio:.2f}%)")

        # Overlap statistics
        num_overlap = int(np.count_nonzero(pred_mask & gt_mask))

        print("\nOverlap:")
        print(f"  Overlapping voxels: {num_overlap:,}")
        if num_segmented > 0:
            precision = num_overlap / num_segmented * 100
            print(f"  Precision: {precision:.2f}%")
        if num_gt > 0:
            recall = num_overlap / num_gt * 100
            print(f"  Recall: {recall:.2f}%")

    print("=" * 60)

# Summarise only when both arrays exist; ground_truth is None when no mask was loaded.
if volume is None or segmentation is None:
    print("Cannot print statistics: volume or segmentation not loaded.")
else:
    print_segmentation_statistics(volume, segmentation, ground_truth=ground_truth)


## Summary

This notebook visualized model predictions by:

1. **Loading test volume**: Supported NIfTI and DICOM formats
2. **Loading/generating predictions**: From saved file or model inference
3. **3-view visualization**: Axial, coronal, and sagittal views with overlay
4. **Slice comparison**: Inspect chosen slices side by side (edit the indices to explore other slices)
5. **Dice calculation**: Compare prediction with ground truth (if available)
6. **Statistics**: Summary of segmentation properties

### Key Features:
- Visual overlay of volume and segmentation in 3 standard planes
- Dice and IoU metrics for quantitative evaluation
- Detailed statistics including bounding box and overlap analysis
- Support for both saved predictions and on-the-fly inference
