# In [None]:
#%% [markdown]
# # Hepato-Synth: Exploratory Data Analysis & Visualization Notebook
# 
# **Objective:** This notebook serves as an interactive environment for:
# 1.  **Data Quality Control:** Inspecting raw, registered, and normalized images for a single case.
# 2.  **Model Input Verification:** Visualizing the physical parameter maps ($K^{trans}, v_e$).
# 3.  **Result Analysis:** Qualitatively evaluating the model's outputs (generated images, segmentations, uncertainty maps).
#
# ---
# **Instructions:**
# 1.  Make sure your Python environment is activated and has all dependencies from `requirements.txt` installed.
# 2.  Update `PATIENT_ID`, `SESSION_ID`, and the path variables in the "Configuration" section of the setup cell below.
# 3.  Run the cells sequentially to perform the analysis.

#%%
# ==============================================================================
# Cell 1: Setup - Imports and Configuration
# ==============================================================================
import SimpleITK as sitk
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

# --- Configuration ---
# EDIT THE VALUES BELOW SO THEY MATCH YOUR LOCAL DATA LAYOUT
# ------------------------------------------------------------------------------
# Subject and session labels of the case to inspect
PATIENT_ID = "sub-001"
SESSION_ID = "ses-20231101"

# Derivatives folder written by the preprocessing pipeline (BIDS-like layout).
# The relative path assumes this notebook is run from notebooks/.
DERIVATIVES_ROOT = Path("../outputs/hepato-synth_derivatives/")

# Inference output folder of one specific model run; it is expected to hold
# the final predictions for the selected case.
INFERENCE_OUTPUT_ROOT = Path(
    "../outputs/study_1_acceleration/2025-12-25_10-00-00/inference_results/"
)

# Global matplotlib styling: grayscale style sheet, then figure/label sizes
plt.style.use('grayscale')
plt.rcParams.update({
    'figure.figsize': (18, 6),
    'axes.titlesize': 16,
    'axes.labelsize': 14,
})


# --- Helper Function ---
def plot_slices(images_dict: dict, title: str, slice_idx: int = -1):
    """Plot one slice of each image side by side in a single figure.

    Args:
        images_dict: Mapping of panel title -> 3-D image array indexed as
            ``[slice, row, col]``. One subplot is drawn per entry, in
            insertion order. Axis 0 is treated as the Z axis (presumably the
            layout returned by ``sitk.GetArrayFromImage`` -- confirm for
            arrays from other sources).
        title: Figure-level title.
        slice_idx: Index along axis 0 to display. The default ``-1`` is a
            sentinel meaning "the central slice of each image"
            (``shape[0] // 2``), NOT the last slice.

    Raises:
        ValueError: If ``images_dict`` is empty.
    """
    # Fail early with a clear message instead of an opaque matplotlib error
    # about a zero-column subplot grid.
    if not images_dict:
        raise ValueError("plot_slices() requires at least one image to display.")

    num_images = len(images_dict)
    # squeeze=False always yields a 2-D array of axes, so the single-image
    # case needs no special handling.
    fig, axes = plt.subplots(
        1, num_images, figsize=(6 * num_images, 6), squeeze=False
    )

    for ax, (name, img_np) in zip(axes[0], images_dict.items()):
        # The central slice is computed per image because the volumes may
        # differ in depth.
        idx = img_np.shape[0] // 2 if slice_idx == -1 else slice_idx

        ax.imshow(img_np[idx, :, :], origin='lower')
        ax.set_title(name)
        ax.axis('off')

    fig.suptitle(title, fontsize=20)
    # Leave headroom at the top so the suptitle does not overlap the panels.
    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.show()

# Signal that the setup cell finished and remind the user to double-check the
# configured paths before running the cells that follow.
print("Setup complete. Please verify the paths above.")


#%%
# ==============================================================================
# Cell 2: Data Loading
# ==============================================================================
# Load all relevant images for the selected case.

# Define paths
normalized_dir = DERIVATIVES_ROOT / "normalization" / PATIENT_ID / SESSION_ID / "anat"
perfusion_dir = DERIVATIVES_ROOT / "perfusion_modeling" / PATIENT_ID / SESSION_ID / "anat"
inference_dir = INFERENCE_OUTPUT_ROOT / f"{PATIENT_ID}_{SESSION_ID}"


def _find_image(directory, pattern, required=True):
    """Return the first file in *directory* matching glob *pattern*.

    Matches are sorted so the choice is deterministic when several files
    match. If nothing matches, raise FileNotFoundError when *required* is
    true, otherwise return None.
    """
    matches = sorted(directory.glob(pattern))
    if matches:
        return matches[0]
    if required:
        raise FileNotFoundError(f"No file matching '{pattern}' found in {directory}")
    return None


def _load_array(path):
    """Read the image at *path* with SimpleITK and return it as a NumPy array.

    The existence check is explicit because SimpleITK reports a missing file
    as a generic RuntimeError, which the handler below would not catch.
    """
    if not path.is_file():
        raise FileNotFoundError(f"Image file not found: {path}")
    return sitk.GetArrayFromImage(sitk.ReadImage(str(path)))


# `images` is filled incrementally; if loading fails part-way it holds only the
# volumes read so far, and later cells that index it will raise KeyError.
images = {}
try:
    # Load normalized input images
    for phase in ('T1w_precontrast', 'T1w_arterial', 'T1w_portal'):
        images[phase] = _load_array(_find_image(normalized_dir, f"*{phase}*.nii.gz"))

    # Load physical parameter maps
    for param in ('Ktrans', 've'):
        images[param] = _load_array(_find_image(perfusion_dir, f"*phys_{param}.nii.gz"))

    # Load model prediction and (optional) ground truth
    # This part depends on which study's output you are inspecting
    images['Predicted_HBP'] = _load_array(inference_dir / "predicted_hbp.nii.gz")

    # Try to load ground truth if it exists (e.g., from Cohort-A for Study 1)
    gt_hbp_path = _find_image(normalized_dir, "*T1w_hbp*.nii.gz", required=False)
    if gt_hbp_path is not None:
        images['GroundTruth_HBP'] = _load_array(gt_hbp_path)

    print(f"Successfully loaded {len(images)} images for case {PATIENT_ID}/{SESSION_ID}.")
    for name, arr in images.items():
        print(f"  - {name}: {arr.shape}")

except FileNotFoundError as e:
    # Only missing files are reported here; unreadable or corrupt files still
    # surface as SimpleITK errors so they are not silently hidden.
    print("Error: Could not load all required files for the case. Please check the paths.")
    print(e)

#%%
# ==============================================================================
# Cell 3: Preprocessing Quality Check - Registration Visualization
# ==============================================================================
# Let's check how well the arterial phase was registered to the portal phase.
# We will overlay their edges.

# Load the UNREGISTERED arterial phase for comparison.
# The directory is built from PATIENT_ID / SESSION_ID so it always follows the
# case selected in the configuration cell (it was previously hard-coded).
unregistered_dir = DERIVATIVES_ROOT.parent / PATIENT_ID / SESSION_ID / "anat"  # Path to original BIDS
unregistered_candidates = sorted(unregistered_dir.glob("*T1w_arterial.nii.gz"))
if not unregistered_candidates:
    # Raise a descriptive error instead of the bare StopIteration that
    # next() on an empty glob would produce.
    raise FileNotFoundError(
        f"No unregistered arterial image (*T1w_arterial.nii.gz) found in {unregistered_dir}"
    )
unregistered_arterial = sitk.GetArrayFromImage(sitk.ReadImage(str(unregistered_candidates[0])))

portal_phase = images['T1w_portal']
registered_arterial = images['T1w_arterial']
# NOTE(review): the same slice index is used for the unregistered volume; this
# assumes it has at least as many slices as the portal phase -- confirm.
z_slice = portal_phase.shape[0] // 2

fig, axes = plt.subplots(1, 2, figsize=(12, 6))

# Before Registration: the contour is drawn at the volume's 95th-percentile intensity
axes[0].imshow(portal_phase[z_slice, :, :], origin='lower')
axes[0].contour(
    unregistered_arterial[z_slice, :, :],
    levels=[np.percentile(unregistered_arterial, 95)],
    colors='r',
    linewidths=0.8,
)
axes[0].set_title("Before Registration (Red=Arterial Edge)")
axes[0].axis('off')

# After Registration
axes[1].imshow(portal_phase[z_slice, :, :], origin='lower')
axes[1].contour(
    registered_arterial[z_slice, :, :],
    levels=[np.percentile(registered_arterial, 95)],
    colors='g',
    linewidths=0.8,
)
axes[1].set_title("After Registration (Green=Arterial Edge)")
axes[1].axis('off')

fig.suptitle("Registration Quality Check", fontsize=20)
plt.show()

#%%
# ==============================================================================
# Cell 4: Perfusion Map Visualization
# ==============================================================================
# Visualize the physical parameter maps generated by the perfusion modeling pipeline.

# Side-by-side view of the portal phase and both perfusion parameter maps
perfusion_panels = {
    "Portal Phase": images['T1w_portal'],
    "Ktrans Map": images['Ktrans'],
    "ve Map": images['ve'],
}
plot_slices(perfusion_panels, title="Physical Parameter Maps Visualization")

# Histograms of the strictly positive voxels of each map (log-scaled counts)
# to check the value distributions
fig, axes = plt.subplots(1, 2, figsize=(12, 5))
for hist_ax, map_name in zip(axes, ('Ktrans', 've')):
    param_map = images[map_name]
    positive_values = param_map[param_map > 0].flatten()
    hist_ax.hist(positive_values, bins=100, log=True)
    hist_ax.set_title(f"{map_name} Distribution (log scale)")
plt.show()

#%%
# ==============================================================================
# Cell 5: Model Output Visualization (for Study 1/2)
# ==============================================================================
# Compare the model's generated HBP with the input and ground truth.

# Build the panel set once, then draw it: with a ground-truth HBP available the
# figure shows input / ground truth / prediction, otherwise input / prediction.
has_ground_truth = 'GroundTruth_HBP' in images

hbp_panels = {"Input (Portal)": images['T1w_portal']}
if has_ground_truth:
    hbp_panels["Ground Truth HBP"] = images['GroundTruth_HBP']
    hbp_panels["Predicted HBP"] = images['Predicted_HBP']
    hbp_title = "Generation Result vs. Ground Truth"
else:
    hbp_panels["Predicted Virtual HBP"] = images['Predicted_HBP']
    hbp_title = "Virtual HBP Generation Result"

plot_slices(hbp_panels, title=hbp_title)

#%%
# ==============================================================================
# Cell 6: Trustworthy AI Visualization (for Study 3)
# ==============================================================================
# This cell would load and visualize the outputs of the diagnostic system.

# Locate the Study 3 outputs up front. SimpleITK reports a missing file as a
# generic RuntimeError (not FileNotFoundError), so existence is checked
# explicitly instead of relying on an except clause.
seg_path = inference_dir / "predicted_segmentation.nii.gz"
uncertainty_path = inference_dir / "prediction_uncertainty.nii.gz"
missing_outputs = [p for p in (seg_path, uncertainty_path) if not p.is_file()]

# Uncertainty values below this are masked out of the overlay so the
# underlying prediction stays visible there.
UNCERTAINTY_DISPLAY_THRESHOLD = 0.05

if missing_outputs:
    print("Could not find diagnostic system outputs (segmentation/uncertainty maps).")
    print("Please ensure you have run inference for a Study 3 model.")
    for missing_path in missing_outputs:
        print(f"  - missing: {missing_path}")
else:
    # Load segmentation and uncertainty map
    seg_mask = sitk.GetArrayFromImage(sitk.ReadImage(str(seg_path)))
    uncertainty_map = sitk.GetArrayFromImage(sitk.ReadImage(str(uncertainty_path)))

    # NOTE(review): the same slice index is applied to the prediction and the
    # uncertainty map; assumes all three volumes share one grid -- confirm.
    z_slice = seg_mask.shape[0] // 2
    prediction_slice = images['Predicted_HBP'][z_slice, :, :]
    uncertainty_slice = uncertainty_map[z_slice, :, :]

    fig, axes = plt.subplots(1, 3, figsize=(18, 6))

    # Plot Prediction + Segmentation Contour (level 0.5 outlines the mask boundary)
    axes[0].imshow(prediction_slice, origin='lower')
    axes[0].contour(seg_mask[z_slice, :, :], levels=[0.5], colors='lime', linewidths=1.0)
    axes[0].set_title("Prediction with Segmentation")
    axes[0].axis('off')

    # Plot Uncertainty Map
    im = axes[1].imshow(uncertainty_slice, cmap='magma', origin='lower')
    axes[1].set_title("Uncertainty Map")
    axes[1].axis('off')
    fig.colorbar(im, ax=axes[1], fraction=0.046, pad=0.04)

    # Plot Prediction with Uncertainty Overlay
    axes[2].imshow(prediction_slice, origin='lower')
    # Overlay the uncertainty map with transparency, hiding low-uncertainty voxels
    uncertainty_overlay = np.ma.masked_where(
        uncertainty_slice < UNCERTAINTY_DISPLAY_THRESHOLD, uncertainty_slice
    )
    axes[2].imshow(uncertainty_overlay, cmap='hot', alpha=0.6, origin='lower')
    axes[2].set_title("Prediction with Uncertainty Overlay")
    axes[2].axis('off')

    fig.suptitle("Trustworthy AI Diagnostic Visualization", fontsize=20)
    plt.show()