# Mid-Sagittal plane algorithm

### Import relevant packages

In [None]:
import os
import time
import gzip
from multiprocessing import Pool, cpu_count

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import imageio

import nibabel as nib
from nibabel.orientations import aff2axcodes

import joblib

from scipy.ndimage import center_of_mass, find_objects
from scipy.optimize import minimize
from scipy.interpolate import RegularGridInterpolator
from scipy.stats import norm

import ipywidgets as widgets
from matplotlib.lines import Line2D
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas

In [293]:
def load_nifti_file(file_path):
    """
    Read a NIfTI volume and permute its axes to (coronal, sagittal, axial).

    The permutation is derived from the orientation codes returned by
    ``aff2axcodes`` for the image affine:
    - the data axis whose code is 'A'/'P' becomes axis 0;
    - the 'L'/'R' axis becomes axis 1;
    - the 'S'/'I' axis becomes axis 2.

    Only the axis order changes; no axis is flipped.

    Parameters
    ----------
    file_path : str
        Path to the NIfTI file.

    Returns
    -------
    tuple
        data : np.ndarray
            Image data from ``get_fdata`` with axes ordered
            (coronal, sagittal, axial).
        voxel_size : np.ndarray
            Header zooms permuted into the same axis order as ``data``.
    """
    img = nib.load(file_path)
    volume = img.get_fdata()
    zooms = np.array(img.header.get_zooms())

    # Orientation code -> world axis it runs along
    # (0 = left/right, 1 = posterior/anterior, 2 = inferior/superior).
    world_axis_of_code = {'L': 0, 'R': 0, 'P': 1, 'A': 1, 'I': 2, 'S': 2}
    axis_world = [world_axis_of_code[code] for code in aff2axcodes(img.affine)]

    # Data axes carrying world Y (coronal), world X (sagittal), world Z (axial).
    permutation = [axis_world.index(1), axis_world.index(0), axis_world.index(2)]

    reordered = np.transpose(volume, permutation)
    reordered_zooms = np.array(tuple(zooms[permutation]))
    return reordered, reordered_zooms

def open_gzip_file(gzip_file_path):
    """
    Decompress a gzip file completely into memory.

    Parameters
    ----------
    gzip_file_path : str
        Location of the ``.gz`` file.

    Returns
    -------
    bytes or None
        The decompressed payload. If opening or reading raises any exception,
        the error is printed and ``None`` is returned instead.
    """
    try:
        with gzip.open(gzip_file_path, 'rb') as stream:
            payload = stream.read()
    except Exception as e:
        print(f'Error opening {gzip_file_path}: {e}')
        payload = None
    return payload

def get_image_and_voxel_size_from_gzip(gzip_file_path):
    """
    Extract the image array and voxel size from a gzipped NIfTI file.

    The file is decompressed with ``open_gzip_file`` and written to a uniquely
    named temporary ``.nii`` file, which is then read with ``load_nifti_file``.
    The unique name means concurrent calls (e.g. worker processes) cannot
    overwrite each other's data. The temporary file is always deleted, even if
    loading fails.

    Parameters
    ----------
    gzip_file_path : str
        Path to the gzipped NIfTI file.

    Returns
    -------
    tuple
        img_data : np.ndarray or None
            The image data array, or None if the gzip could not be read or
            decompressed to zero bytes.
        voxel_size : np.ndarray or None
            The voxel size for each axis, or None in the same cases.

    Raises
    ------
    Exception
        Whatever ``load_nifti_file`` raises for an unreadable NIfTI payload.
    """
    import tempfile

    file_content = open_gzip_file(gzip_file_path)
    if file_content is None:
        print(f"❌ Error: Failed to read file content from '{gzip_file_path}'")
        return None, None

    fd, temp_path = tempfile.mkstemp(suffix='.nii')
    try:
        with os.fdopen(fd, 'wb') as temp_file:
            temp_file.write(file_content)

        # An empty decompressed payload cannot be a valid NIfTI image.
        if os.path.getsize(temp_path) == 0:
            print(f"❌ Error: '{temp_path}' was written but is empty! ({gzip_file_path})")
            return None, None

        img_data, voxel_size = load_nifti_file(temp_path)
        return img_data, voxel_size
    finally:
        # Remove the temp file on every path (success, empty file, or exception).
        if os.path.exists(temp_path):
            os.remove(temp_path)

def load_patient_structures(patient_folder: str, structure_names=None):
    """
    Load specified structures from a patient folder containing NIfTI or gzipped NIfTI files.

    The folder is walked recursively. A file matches a structure when its
    name without the ``.nii`` / ``.nii.gz`` extension equals the structure
    name (case-insensitive). If several files match the same structure, the
    last one visited wins.

    Parameters
    ----------
    patient_folder : str
        Path to the folder containing patient structure files.
    structure_names : list of str, optional
        List of structure names to load (default: ['Image']).

    Returns
    -------
    dict
        Dictionary mapping structure names (as spelled in ``structure_names``)
        to ``(image, voxel_size)`` tuples. Files that fail to load are logged
        via ``logging.error`` and skipped.
    """
    # Local import: the module-level imports of this notebook do not include logging.
    import logging

    if structure_names is None:
        structure_names = ['Image']

    struct_dict = {}
    for root, _, files in os.walk(patient_folder):
        for f in files:
            # Only NIfTI files; strip the extension to get the structure name.
            if f.endswith('.nii.gz'):
                base = f[:-7]
            elif f.endswith('.nii'):
                base = f[:-4]
            else:
                continue

            for struct_name in structure_names:
                if base.lower() != struct_name.lower():
                    continue

                file_path = os.path.join(root, f)
                try:
                    if f.endswith('.nii.gz'):
                        img, voxel_size = get_image_and_voxel_size_from_gzip(file_path)
                    else:
                        img, voxel_size = load_nifti_file(file_path)
                except Exception as e:
                    logging.error(f"Error loading '{struct_name}' from {file_path}: {e}")
                    continue

                if img is not None:
                    struct_dict[struct_name] = (img, voxel_size)
                    print(f"Loaded '{struct_name}' from {file_path}")
                break
    return struct_dict

### Image Processing

In [294]:
def mask_via_threshold(ct_image, HU_range=(300, 1500)):
    """
    Build a 0/1 mask of voxels whose value lies in a closed HU interval.

    Parameters
    ----------
    ct_image : np.ndarray
        CT volume (any shape).
    HU_range : tuple of (int, int)
        Inclusive ``(low, high)`` bounds.

    Returns
    -------
    np.ndarray
        Array shaped and typed like ``ct_image``: 1 where
        ``low <= value <= high``, 0 elsewhere.
    """
    low, high = HU_range
    inside = np.logical_and(ct_image >= low, ct_image <= high)

    mask = np.zeros_like(ct_image)
    mask[inside] = 1
    return mask

def crop_patient_volumes(struct_dict, slice_axis=2, slice_range=None):
    """
    Crop all structures in the patient dictionary to a common index range along one axis.

    The inclusive crop range ``[start, end]`` is chosen in this order of
    precedence:

    1. the extent of a non-empty ``'GTVp'`` mask along ``slice_axis``
       (``slice_range`` is then ignored);
    2. ``slice_range``, if it is not None;
    3. the full extent of ``'Image'``.

    If a non-empty ``'Mandible'`` mask is present and ``end`` is at most
    5 slices past the mandible's first slice, the range is widened to include
    ``[mandible_start, mandible_start + 5]`` (clamped to the volume).

    Parameters
    ----------
    struct_dict : dict
        Maps structure names to ``(image, voxel_size)`` tuples or bare arrays.
        Must contain ``'Image'``. The dict is modified in place.
    slice_axis : int
        Axis along which to crop (default: 2).
    slice_range : tuple of (int, int) or None
        Inclusive ``(start, end)`` indices, used when no non-empty GTVp mask
        is available.

    Returns
    -------
    tuple
        struct_dict : dict
            The same dict, with every value replaced by its cropped array
            (voxel sizes are dropped).
        interpolation_image : np.ndarray
            ``'Image'`` cropped to ``[start - 10, end + 10]``. Indices outside
            the volume are clamped to the nearest edge slice, so there are
            always exactly 10 padding slices on each side. 10 matches the
            default ``pad_slices`` of ``get_cached_interpolator``.
    """
    pad = 10

    def _array(value):
        # Entries may be (image, voxel_size) tuples or bare arrays.
        return value[0] if isinstance(value, tuple) else value

    def _extent(mask):
        # Inclusive (first, last) index of nonzero voxels along slice_axis, or None if empty.
        axis = slice_axis % mask.ndim
        other_axes = tuple(ax for ax in range(mask.ndim) if ax != axis)
        hits = np.flatnonzero(np.any(mask, axis=other_axes))
        return (int(hits[0]), int(hits[-1])) if hits.size else None

    image = _array(struct_dict['Image'])
    n_slices = image.shape[slice_axis]

    gtvp_extent = _extent(_array(struct_dict['GTVp'])) if 'GTVp' in struct_dict else None
    if gtvp_extent is not None:
        start, end = gtvp_extent
    elif slice_range is not None:
        start, end = slice_range
    else:
        start, end = 0, n_slices - 1

    if 'Mandible' in struct_dict:
        mandible_extent = _extent(_array(struct_dict['Mandible']))
        if mandible_extent is not None:
            mandible_start = mandible_extent[0]
            if end <= mandible_start + 5:
                start = min(start, mandible_start)
                # Never extend past the last slice, or np.take would raise.
                end = max(end, min(mandible_start + 5, n_slices - 1))

    idx = np.arange(start, end + 1)
    for name in struct_dict:
        struct_dict[name] = np.take(_array(struct_dict[name]), idx, axis=slice_axis)

    # mode='clip' replicates the edge slices. Without it, indices < 0 would
    # silently wrap to the far end of the volume and indices >= n_slices
    # would raise IndexError.
    idx_interpolation = np.arange(start - pad, end + 1 + pad)
    interpolation_image = np.take(image, idx_interpolation, axis=slice_axis, mode='clip')

    return struct_dict, interpolation_image

def preprocess_bone_image(struct_dict, HU_range):
    """
    Convert the CT to int16, optionally blank non-body voxels, and isolate bone.

    Parameters
    ----------
    struct_dict : dict
        Must map ``'Image'`` to an array (``.astype`` is called on it), e.g.
        the cropped output of ``crop_patient_volumes``. If ``'Body'`` is
        present, voxels where it does not equal 1 are replaced by the
        image minimum.
    HU_range : tuple of (int, int)
        Inclusive bone bounds passed to ``mask_via_threshold``.

    Returns
    -------
    tuple
        proc_image : np.ndarray
            int16 CT, body-masked if ``'Body'`` was given.
        bone_ct : np.ndarray
            ``proc_image`` times the uint16 bone mask (int32 under NumPy type
            promotion): original values inside ``HU_range``, 0 elsewhere.
    """
    ct = struct_dict['Image'].astype(np.int16)

    if 'Body' in struct_dict:
        # Outside the body mask, use the volume's minimum as background.
        fill = np.min(ct)
        ct = np.where(struct_dict['Body'] == 1, ct, fill)

    in_range = mask_via_threshold(ct, HU_range).astype(np.uint16)
    return ct, ct * in_range

### Parametrization

In [295]:
def vector_to_angles(vector):
    """
    Express a 3D vector in spherical coordinates.

    Parameters
    ----------
    vector : array-like of shape (3,)
        The (x, y, z) components.

    Returns
    -------
    np.ndarray
        ``[azimuthal, polar, R]``:
        - ``azimuthal`` is the angle in the xy-plane from +x (``arctan2(y, x)``);
        - ``polar`` is the angle from +z (``arctan2(sqrt(x²+y²), z)``);
        - ``R`` is the Euclidean length.
    """
    x, y, z = vector
    radius = np.linalg.norm(vector)
    in_plane = np.sqrt(x**2 + y**2)
    return np.array([np.arctan2(y, x), np.arctan2(in_plane, z), radius])

def angles_to_vector(azimuthal, polar, R):
    """
    Build a 3D vector from spherical coordinates (inverse of ``vector_to_angles``).

    Parameters
    ----------
    azimuthal : float
        Angle in the xy-plane measured from +x, in radians.
    polar : float
        Angle measured from +z, in radians.
    R : float
        Length of the vector.

    Returns
    -------
    np.ndarray
        ``[x, y, z]``.
    """
    sin_polar = np.sin(polar)
    components = (
        R * sin_polar * np.cos(azimuthal),
        R * sin_polar * np.sin(azimuthal),
        R * np.cos(polar),
    )
    return np.array(components)

def generate_normal(theta: float, phi: float) -> np.ndarray:
    """
    Unit normal vector(s) for azimuthal angle ``theta`` and polar angle ``phi``.

    Parameters
    ----------
    theta : float or np.ndarray
        Azimuthal angle(s) in radians, measured from +x in the xy-plane.
    phi : float or np.ndarray
        Polar angle(s) in radians, measured from +z. Must broadcast with ``theta``.

    Returns
    -------
    np.ndarray
        Shape ``broadcast(theta, phi).shape + (3,)``, with the (x, y, z)
        components along the last axis.
    """
    sin_phi = np.sin(phi)
    return np.stack(
        [sin_phi * np.cos(theta), sin_phi * np.sin(theta), np.cos(phi)],
        axis=-1,
    )

### Interpolation

In [296]:
def get_cached_interpolator(output_dir, image, voxel_size,
                            filename='interpolator.joblib',
                            method='cubic',
                            pad_slices=10):
    """
    Load or build and cache a RegularGridInterpolator for a cropped+padded volume.

    The z-coordinates are shifted so that z=0 aligns with the first slice of
    the original crop.

    A cached interpolator is reused only if its method, sample grids and
    values match the current arguments. Otherwise, or if the cache file
    cannot be read, the interpolator is rebuilt and the cache overwritten.
    Without this check, a cache left over from a different patient, crop,
    voxel size or method would be returned silently.

    Parameters
    ----------
    output_dir : str
        Directory to look for/save the cached interpolator.
    image : np.ndarray
        Cropped+padded image of shape (Ny, Nx, Nz_pad).
    voxel_size : sequence of 3 floats
        Physical spacing along (Y, X, Z) axes.
    filename : str, optional
        Name of the .joblib file to load/save under output_dir.
    method : str, optional
        Interpolation method: 'cubic', 'linear', etc.
    pad_slices : int, optional
        Number of padded slices at the beginning of the volume. These
        will be assigned negative z-coordinates.

    Returns
    -------
    RegularGridInterpolator
        Interpolator mapping (y_mm, x_mm, z_mm_shifted) to intensity,
        where z_mm_shifted = 0 at the first slice *after* the pad. Queries
        outside the grid are extrapolated (``bounds_error=False``,
        ``fill_value=None``).

    Notes
    -----
    ``joblib.load`` unpickles the cache file; only use trusted ``output_dir``s.
    """
    os.makedirs(output_dir, exist_ok=True)
    interpolator_path = os.path.join(output_dir, filename)

    Ny, Nx, Nz_pad = image.shape
    dy, dx, dz = voxel_size

    # Physical-space sample grids; z shifted so the first un-padded slice is at z=0.
    grid_y = np.arange(Ny) * dy
    grid_x = np.arange(Nx) * dx
    grid_z = (np.arange(Nz_pad) - pad_slices) * dz
    grids = (grid_y, grid_x, grid_z)

    def _matches(interp):
        # True only if the cached interpolator was built from exactly these inputs.
        cached_grids = getattr(interp, 'grid', None)
        cached_values = getattr(interp, 'values', None)
        if cached_grids is None or cached_values is None or len(cached_grids) != 3:
            return False
        return (
            getattr(interp, 'method', None) == method
            and all(np.array_equal(c, g) for c, g in zip(cached_grids, grids))
            and np.shape(cached_values) == image.shape
            and np.array_equal(cached_values, image)
        )

    if os.path.exists(interpolator_path):
        try:
            cached = joblib.load(interpolator_path)
        except Exception as e:
            print(f"Could not read cached interpolator {interpolator_path} ({e}); rebuilding")
        else:
            if _matches(cached):
                return cached
            print(f"Cached interpolator at {interpolator_path} does not match the current volume; rebuilding")

    start = time.time()
    interpolator = RegularGridInterpolator(
        grids,
        image,
        method=method,
        bounds_error=False,
        fill_value=None
    )
    elapsed = time.time() - start
    print(f"Interpolator built in {elapsed:.2f}s; cached to {interpolator_path}")
    joblib.dump(interpolator, interpolator_path)

    return interpolator


### Objective Function

In [297]:
def compute_signed_distances(params_array, image, voxel_size):
    """
    Signed distance from every nonzero voxel of ``image`` to a plane.

    The plane is ``{x : n · x = L}``, with unit normal
    ``n = generate_normal(azimuthal, polar)``. Voxel indices are first scaled
    by ``voxel_size`` (image axis order). The columns are then permuted to
    (image axis 1, image axis 0, image axis 2): the (x, y, z) frame in which
    the normal is expressed.

    Parameters
    ----------
    params_array : array-like of shape (3,)
        Plane parameters [azimuthal, polar, L].
    image : np.ndarray
        3D array; only its nonzero voxels are used.
    voxel_size : array-like of shape (3,)
        Physical voxel size for each image axis.

    Returns
    -------
    tuple
        d : np.ndarray, shape (N,)
            Signed distances ``n · x - L``.
        n : np.ndarray, shape (3,)
            Plane normal.
        indices_coord_syst_phy : np.ndarray, shape (N, 3)
            Physical (x, y, z) coordinates of the voxels.
        indices_image : np.ndarray, shape (N, 3)
            Integer indices of the nonzero voxels, in image axis order.
    """
    azimuthal, polar, L = params_array

    voxel_indices = np.argwhere(image)
    physical = voxel_indices * voxel_size
    # Swap the first two columns: image axis order -> plane coordinate system.
    points_xyz = physical[:, [1, 0, 2]]

    normal = generate_normal(azimuthal, polar)
    signed = points_xyz.dot(normal) - L

    return signed, normal, points_xyz, voxel_indices

def quadratic_loss_function(diff):
    """
    Element-wise squared loss, ``diff² / 2``.

    Parameters
    ----------
    diff : np.ndarray
        Residuals.

    Returns
    -------
    np.ndarray
        Half the squared residuals, same shape as ``diff``.
    """
    squared = diff ** 2
    return squared / 2

def huber_loss_function(diff, delta=300):
    """
    Element-wise Huber loss.

    The loss is ``diff²/2`` for ``|diff| <= delta`` and
    ``delta * (|diff| - delta/2)`` beyond that.

    Parameters
    ----------
    diff : np.ndarray
        Residuals.
    delta : float, optional
        Transition point between the quadratic and linear regimes.

    Returns
    -------
    np.ndarray
        Huber loss per element.
    """
    magnitude = np.abs(diff)
    quadratic_part = 0.5 * diff**2
    linear_part = delta * (magnitude - 0.5 * delta)
    return np.where(magnitude <= delta, quadratic_part, linear_part)

def p_huber_loss(r, c=300.0, p=1.4):
    """
    Generalised Huber loss with a power-law tail.

        rho(r) = 0.5 * r**2                                   for |r| <= c
        rho(r) = (c**(2-p) / p) * |r|**p + c**2 * (1/2 - 1/p)  for |r| >  c

    The tail constants make rho and its first derivative continuous at
    ``|r| = c``: both branches give ``c**2 / 2`` and slope ``c`` there.
    ``p = 1`` reproduces the classical Huber loss ``c|r| - c**2/2``, and
    ``p = 2`` gives a pure quadratic.

    Parameters
    ----------
    r : float or array-like
        Residual(s).
    c : float, optional
        Transition point (> 0).
    p : float, optional
        Tail exponent (> 0).

    Returns
    -------
    np.ndarray
        Float array of losses with the same shape as ``r``.

    Raises
    ------
    ValueError
        If ``p <= 0``.
    """
    if p <= 0:
        raise ValueError(f"p must be positive, got {p}")
    r = np.asarray(r, dtype=float)
    a = np.abs(r)
    tail = (c ** (2 - p) / p) * a ** p + c ** 2 * (0.5 - 1.0 / p)
    return np.where(a <= c, 0.5 * r ** 2, tail)


def welsch_loss(r, c=200):
    """
    Welsch (Leclerc) loss: ``c²/2 * (1 - exp(-(r/c)²))``.

    It behaves like ``r²/2`` for small ``|r|`` and saturates at ``c²/2`` for
    large ``|r|``.

    Parameters
    ----------
    r : float or np.ndarray
        Residual(s).
    c : float, optional
        Scale parameter.

    Returns
    -------
    float or np.ndarray
        Loss per element.
    """
    half_c_sq = 0.5 * c ** 2
    return half_c_sq * (1.0 - np.exp(-((r / c) ** 2)))

def piecewise_loss(r):
    """
    Custom continuous piecewise loss.

    - ``r²`` for ``|r| <= 200``;
    - ``1700|r| - 300000`` for ``200 < |r| < 1500``, the chord of ``r²``
      between 200 and 1500;
    - ``r²`` again for ``|r| >= 1500``.

    Parameters
    ----------
    r : float or np.ndarray
        Residual value(s).

    Returns
    -------
    loss : np.ndarray
        Computed loss value(s) (0-d array for scalar input).
    """
    r = np.asarray(r)
    magnitude = np.abs(r)

    squared = r**2
    chord = 1700 * magnitude - 300_000
    in_linear_band = (magnitude > 200) & (magnitude < 1500)

    return np.where(in_linear_band, chord, squared)



def piecewise_intensity_loss(I_orig, I_m, threshold=2800):
    """
    Squared loss that switches to a shallow linear loss for bright mirrored voxels.

    For each voxel, with ``r = I_orig - I_m`` and ``r_t = I_orig - threshold``:
      - if ``I_m < threshold``: ``0.5 * r**2``;
      - otherwise: ``0.1 * r + (0.5 * r_t**2 - 0.1 * r_t)``. The offset makes
        both branches agree at ``I_m = threshold``.

    Parameters
    ----------
    I_orig : array-like
        Original intensities.
    I_m : array-like
        Mirrored/interpolated intensities.
    threshold : float, optional
        Mirrored-intensity threshold where the loss switches branch.

    Returns
    -------
    loss : np.ndarray
        Element-wise loss, broadcast shape of the inputs.
    """
    original = np.asarray(I_orig)
    mirrored = np.asarray(I_m)

    residual = original - mirrored
    boundary_residual = original - threshold
    slope = 0.1

    quadratic = 0.5 * residual**2
    linear = slope * residual + (0.5 * boundary_residual**2 - slope * boundary_residual)

    return np.where(mirrored < threshold, quadratic, linear)

def compute_delta(diff, dim, tol=1e-6, max_iter=100):
    """
    Estimate a Huber threshold δ from residuals by bisection.

    Solves the truncated-moment equation

        (1 / (n δ²)) · Σ_i min(r_i², δ²) = (dim + log n) / n

    for δ, where ``r = diff`` and ``n = diff.size``.

    Parameters
    ----------
    diff : np.ndarray
        Residuals.
    dim : int
        Number of fitted parameters.
    tol : float, optional
        Stop when the bracket width drops below ``tol`` times its midpoint.
    max_iter : int, optional
        Maximum number of bisection steps.

    Returns
    -------
    float
        Midpoint of the final bracket.
    """
    count = diff.size
    sq = diff**2
    log_count = np.log(count)

    def excess(candidate):
        # Left side minus right side of the equation; decreases as candidate grows.
        truncated = np.minimum(sq, candidate**2).sum()
        return (truncated / (count * candidate**2)) - ((dim + log_count) / count)

    lower = np.finfo(float).eps
    upper = np.sqrt(sq.sum() / (dim + log_count)) + tol

    step = 0
    while step < max_iter:
        midpoint = (lower + upper) * 0.5
        if excess(midpoint) > 0:
            lower = midpoint
        else:
            upper = midpoint
        if (upper - lower) < tol * midpoint:
            break
        step += 1

    return (lower + upper) * 0.5

def compute_objective(params_array, bone, interpolator_intensity, voxel_size, delta):
    """
    Compute the objective function value for plane symmetry optimization.

    Every nonzero voxel of ``bone`` is reflected across the plane described
    by ``params_array``. The intensity at the reflected position is sampled
    from ``interpolator_intensity`` and compared with the voxel's own value
    in ``bone``. The residuals are scored with ``huber_loss_function``.

    Parameters
    ----------
    params_array : array-like of shape (3,)
        Plane parameters [azimuthal, polar, L].
    bone : np.ndarray
        Bone intensity image; only its nonzero voxels are evaluated.
    interpolator_intensity : callable
        Interpolator (e.g. RegularGridInterpolator) taking (N, 3) points in
        image axis order and physical units.
    voxel_size : array-like of shape (3,)
        Physical voxel size for each axis.
    delta : float
        Huber loss delta parameter.

    Returns
    -------
    float
        The SUM (not the mean) of the Huber losses over all nonzero voxels of
        ``bone``. The voxel set does not depend on the plane, so this differs
        from the mean only by a constant factor.
    """
    d, n, coords_xyz, indices_image = compute_signed_distances(params_array, bone, voxel_size)

    # Reflect each point across the plane: x' = x - 2 d n.
    mirrored_xyz = coords_xyz - 2 * d[:, None] * n[None, :]
    # Back to image axis order, which is what the interpolator expects.
    mirrored_image_phy = mirrored_xyz[:, [1, 0, 2]]

    I_m = interpolator_intensity(mirrored_image_phy)
    I_orig = bone[tuple(indices_image.T)]
    diff = I_orig - I_m

    # Alternative losses defined in this module (quadratic_loss_function,
    # welsch_loss, p_huber_loss, piecewise_loss, piecewise_intensity_loss)
    # can be swapped in here. compute_delta(diff, len(params_array)) gives an
    # adaptive delta from these residuals.
    return np.sum(huber_loss_function(diff, delta=delta))



def compute_objective_initialization(params_array, bone, interpolator_intensity, voxel_size,
                                     plot_folder=None):
    """
    Quadratic symmetry objective for one plane, plus a robust Huber-δ estimate.

    Every nonzero voxel of ``bone`` is reflected across the plane. The
    intensity at the reflected position is sampled from
    ``interpolator_intensity``, and the residual ``r = I_orig - I_m`` is formed.

    Parameters
    ----------
    params_array : array-like of shape (3,)
        Plane parameters [azimuthal, polar, L].
    bone : np.ndarray
        Bone intensity image; only its nonzero voxels are evaluated.
    interpolator_intensity : callable
        Interpolator taking (N, 3) points in image axis order and physical units.
    voxel_size : array-like of shape (3,)
        Physical voxel size for each axis.
    plot_folder : str or None, optional
        If given, a histogram of the residuals with the estimated δ is saved
        as ``shifted_residuals_histogram.pdf`` in this folder (created if
        needed) and shown. Default None: no plotting.

    Returns
    -------
    tuple
        f : float
            Sum (not mean) of ``0.5 * r**2`` over all nonzero bone voxels.
        delta : float
            ``1.345 * 1.4826 * median(|r - mean(r)|)``.
    """
    d, n, coords_xyz, indices_image = compute_signed_distances(params_array, bone, voxel_size)

    # Reflect across the plane (x' = x - 2 d n) and return to image axis order.
    mirrored_xyz = coords_xyz - 2 * d[:, None] * n[None, :]
    I_m = interpolator_intensity(mirrored_xyz[:, [1, 0, 2]])
    I_orig = bone[tuple(indices_image.T)]
    diff = I_orig - I_m

    # Robust scale -> Huber threshold. 1.4826 is the conventional MAD-to-sigma
    # factor and 1.345 the conventional Huber tuning constant.
    # NOTE(review): residuals are centred on the MEAN here, while the textbook
    # MAD (and the diagnostic plot) centre on the MEDIAN. Confirm this is intended.
    diff_shifted = diff - np.mean(diff)
    mad = np.median(np.abs(diff_shifted))
    sigma_hat = 1.4826 * mad
    delta = 1.345 * sigma_hat

    if plot_folder is not None:
        _plot_shifted_residuals_hist(diff, delta, folderpath=plot_folder)

    f = np.sum(quadratic_loss_function(diff))
    return f, delta


def _plot_shifted_residuals_hist(diff, delta, folderpath=None):
    """
    Plot a histogram of median-shifted residuals with a normal fit.

    Marks the Huber cutoff ±δ̂ with magenta dashed lines and ±σ of the normal
    fit with blue dotted lines.

    Parameters
    ----------
    diff : array-like
        The raw residuals.
    delta : float
        The Huber loss threshold (δ̂).
    folderpath : str or None, optional
        If given, the figure is saved as ``shifted_residuals_histogram.pdf``
        in this folder (created if needed).
    """
    shifted_diff = diff - np.median(diff)
    mu, std = shifted_diff.mean(), shifted_diff.std()

    fig = plt.figure(figsize=(7, 4))
    plt.hist(shifted_diff, bins=50, density=True, alpha=0.6, label='Shifted residuals')

    x = np.linspace(shifted_diff.min(), shifted_diff.max(), 200)
    plt.plot(x, norm.pdf(x, mu, std), linewidth=2, label='Normal fit')

    plt.axvline(delta, color='magenta', linestyle='--', label=f'δ̂ = {delta:.2f}')
    plt.axvline(-delta, color='magenta', linestyle='--')
    plt.axvline(std, color='blue', linestyle=':', label=f'σ = {std:.2f}')
    plt.axvline(-std, color='blue', linestyle=':')

    plt.xlabel('Residuals r')
    plt.ylabel('Probability density of voxels')
    plt.legend()
    plt.tight_layout()

    # Save BEFORE show(): with inline/interactive backends show() may close the
    # figure, and saving afterwards would write an empty page.
    if folderpath is not None:
        os.makedirs(folderpath, exist_ok=True)
        fig.savefig(os.path.join(folderpath, 'shifted_residuals_histogram.pdf'), bbox_inches='tight')
    plt.show()

### Parameter Initialization

In [298]:
def parameter_initialization(
    image,
    bone,
    output_path,
    interpolator_intensity,
    voxel_size,
    azimuthal_deg_range=(0, 90),
    polar_deg_range=(90, 10),
    initialization_steps=10
):
    """
    Grid-search initial plane parameters for the symmetry optimisation.

    Planes are sampled on an ``initialization_steps`` x ``initialization_steps``
    grid of (azimuthal, polar) angles. Each plane passes through the
    intensity-weighted centre of mass of ``bone`` (L = n · com). Each
    candidate is scored with ``compute_objective_initialization``, and the
    lowest score wins.

    Results (objective values, plane parameters, δ estimates) are cached as
    .npy files in ``output_path``. They are reused only if all three files
    exist and the cached planes equal the planes requested now; otherwise the
    search is re-run.

    Parameters
    ----------
    image : np.ndarray
        Unused; kept for interface compatibility.
    bone : np.ndarray
        Bone intensity volume, in the same axis order as ``voxel_size``.
    output_path : str
        Directory for the cache files and the heatmap.
    interpolator_intensity : callable
        Interpolator passed through to ``compute_objective_initialization``.
    voxel_size : array-like of shape (3,)
        Physical voxel size per image axis.
    azimuthal_deg_range : tuple of (float, float)
        (centre, half-width) of the azimuthal search range, in degrees.
    polar_deg_range : tuple of (float, float)
        (centre, half-width) of the polar search range, in degrees.
    initialization_steps : int
        Number of samples per angle.

    Returns
    -------
    tuple
        best_params : np.ndarray
            [azimuthal (rad), polar (rad), L (physical units of voxel_size)].
        best_delta : float
            δ estimate belonging to ``best_params``.
    """
    start = time.time()

    # center_of_mass returns indices in image axis order. Scale by voxel_size
    # in that SAME order, then permute to the (x, y, z) = (axis 1, axis 0,
    # axis 2) frame used by compute_signed_distances. Permuting before
    # scaling would pair the wrong spacings whenever the in-plane voxel
    # sizes differ.
    com_vox = np.asarray(center_of_mass(bone))
    com_phy = (com_vox * np.asarray(voxel_size))[[1, 0, 2]]

    # Angular search ranges: (centre, half-width) in degrees -> radians.
    az_cent_rad, az_half_rad = np.deg2rad(azimuthal_deg_range)
    pol_cent_rad, pol_half_rad = np.deg2rad(polar_deg_range)
    azimuthal_angles = np.linspace(
        az_cent_rad - az_half_rad, az_cent_rad + az_half_rad, initialization_steps
    )
    polar_angles = np.linspace(
        pol_cent_rad - pol_half_rad, pol_cent_rad + pol_half_rad, initialization_steps
    )

    # Candidate planes: every (θ, φ) pair, each passing through the centre of mass.
    theta_grid, phi_grid = np.meshgrid(azimuthal_angles, polar_angles, indexing="ij")
    normals = generate_normal(theta_grid, phi_grid)                 # (N, N, 3)
    L_grid = np.tensordot(normals, com_phy, axes=([-1], [0]))       # (N, N)
    planes = np.stack(
        [theta_grid.ravel(), phi_grid.ravel(), L_grid.ravel()], axis=1
    )                                                               # (N*N, 3)
    n_planes = planes.shape[0]

    os.makedirs(output_path, exist_ok=True)
    mse_file = os.path.join(output_path, "Initialization_obj_fun.npy")
    params_file = os.path.join(output_path, "Initialization_plane_params.npy")
    delta_file = os.path.join(output_path, "Initialization_delta.npy")
    cache_files = (mse_file, params_file, delta_file)

    # Reuse cached results only if complete and computed for exactly these planes.
    mse_array = delta_array = None
    if all(os.path.exists(path) for path in cache_files):
        cached_mse, cached_planes, cached_delta = (np.load(path) for path in cache_files)
        if (cached_planes.shape == planes.shape
                and np.allclose(cached_planes, planes)
                and cached_mse.shape == (n_planes,)
                and cached_delta.shape == (n_planes,)):
            mse_array, delta_array = cached_mse, cached_delta
        else:
            print("Cached initialization does not match the current grid/volume; recomputing...")

    if mse_array is None:
        print("Starting parameter initialization...")
        results = [
            compute_objective_initialization(plane, bone, interpolator_intensity, voxel_size)
            for plane in planes
        ]
        mse_vals, delta_vals = zip(*results)
        mse_array = np.array(mse_vals)        # (N*N,)
        delta_array = np.array(delta_vals)    # (N*N,)

        best_idx = int(np.argmin(mse_array))
        print(f"Minimum objective f = {mse_array[best_idx]:.4f} at index {best_idx}")
        print(f"Corresponding δ = {delta_array[best_idx]:.4f}")

        np.save(mse_file, mse_array)
        np.save(params_file, planes)
        np.save(delta_file, delta_array)

        print(f"Time taken for initialization: {time.time() - start:.2f} s")

    best_idx = int(np.argmin(mse_array))
    best_params = planes[best_idx]  # [azimuthal, polar, L]
    best_delta = delta_array[best_idx]

    mse_grid = mse_array.reshape(initialization_steps, initialization_steps)
    plot_mse_heatmap(
        azimuthal_angles,
        polar_angles,
        mse_grid,
        title=(
            f"MSE Heatmap (θ, φ) – optimum at:\n"
            f"θ={np.rad2deg(best_params[0]):.2f}°, "
            f"φ={np.rad2deg(best_params[1]):.2f}°, "
            f"L={best_params[2]:.2f}\n"
            f"MSE={mse_grid.flat[best_idx]:.2f}"
        ),
        output_path=output_path
    )
    print(f"Best estimate of initial plane parameters: θ = {np.rad2deg(best_params[0]):.2f}°, φ = {np.rad2deg(best_params[1]):.2f}°, L = {best_params[2]:.2f}")

    return best_params, best_delta

In [None]:
def _optimize_single_initialization(args):
    """
    Worker function for multiprocessing optimization of a single initialization point.
    
    Parameters
    ----------
    args : tuple
        (index, plane_params, bone, interpolator_intensity, voxel_size, 
         optimal_plane_params, theta_range, phi_range, L_range)
         
    Returns
    -------
    dict
        Optimization result dictionary
    """
    (index, p, bone, interpolator_intensity, voxel_size, 
     optimal_plane_params, theta_range, phi_range, L_range) = args
    
    try:
        # Compute initial objective and delta
        f, delta = compute_objective_initialization(p, bone, interpolator_intensity, voxel_size)
        
        # Run optimization from this initialization
        res = optimize_plane(p, bone, interpolator_intensity, voxel_size, delta)
        
        optimized_parameters = res.x
        optimized_objective_value = res.fun
        
        # Compute normalized error metric relative to optimal parameters
        E = np.sqrt(
            (optimized_parameters[0] - optimal_plane_params[0])**2 / theta_range**2 +
            (optimized_parameters[1] - optimal_plane_params[1])**2 / phi_range**2 +
            (optimized_parameters[2] - optimal_plane_params[2])**2 / L_range**2
        )
        
        return {
            'index': index,
            'initial_params': p.copy(),
            'optimized_params': optimized_parameters.copy(),
            'initial_objective': f,
            'optimized_objective': optimized_objective_value,
            'delta': delta,
            'success': res.success,
            'nit': res.nit,
            'error_metric': E
        }
        
    except Exception as e:
        return {
            'index': index,
            'initial_params': p.copy(),
            'optimized_params': None,
            'initial_objective': f if 'f' in locals() else np.inf,
            'optimized_objective': np.inf,
            'delta': delta if 'delta' in locals() else np.inf,
            'success': False,
            'nit': 0,
            'error_metric': np.inf,
            'error': str(e)
        }


def _load_cached_robustness_results(
    results_file,
    metadata_file,
    azimuthal_deg_range,
    polar_deg_range,
    initialization_steps,
    optimal_plane_params
):
    """
    Load cached robustness-test results if they match the requested settings.

    Parameters
    ----------
    results_file : str
        Path to the ``.npz`` archive written by ``_save_robustness_results``.
    metadata_file : str
        Path to the JSON metadata file written alongside the archive.
    azimuthal_deg_range, polar_deg_range : tuple
        (center, half_range) in degrees requested by the current call.
    initialization_steps : int
        Grid resolution requested by the current call.
    optimal_plane_params : array-like
        Reference plane parameters requested by the current call.

    Returns
    -------
    dict or None
        A results dictionary with the same layout ``robustness_test`` returns.
        Returns None when the cached grid settings or reference parameters differ.

    Raises
    ------
    Exception
        I/O, JSON and missing-key errors propagate to the caller.
        ``robustness_test`` catches them and recomputes.
    """
    import json

    def normalize_param(param):
        # JSON round-trips tuples as lists, so compare both sides as lists.
        if isinstance(param, (tuple, list)):
            return list(param)
        return param

    with open(metadata_file, 'r') as f:
        existing_metadata = json.load(f)

    current_params = {
        'azimuthal_deg_range': azimuthal_deg_range,
        'polar_deg_range': polar_deg_range,
        'initialization_steps': initialization_steps
    }
    existing_params = existing_metadata.get('parameters', {})

    params_match = (
        normalize_param(existing_params.get('azimuthal_deg_range')) == normalize_param(azimuthal_deg_range) and
        normalize_param(existing_params.get('polar_deg_range')) == normalize_param(polar_deg_range) and
        existing_params.get('initialization_steps') == initialization_steps
    )
    if not params_match:
        print("⚠ Parameters don't match existing results")
        print(f"  Current: {current_params}")
        print(f"  Existing: {existing_params}")
        print("  Regenerating with new parameters...")
        return None

    print("✓ Found existing results with matching parameters!")
    print(f"  - Results from: {existing_metadata.get('timestamp', 'unknown time')}")
    print(f"  - Grid shape: {existing_metadata.get('grid_shape', 'unknown')}")
    print(f"  - Total optimizations: {existing_metadata.get('total_optimizations', 'unknown')}")

    # The context manager releases the archive's file handle once we are done.
    with np.load(results_file) as data:
        # Allow for small numerical differences in the stored reference plane.
        if not np.allclose(data['optimal_plane_params'], optimal_plane_params, atol=1e-6):
            print("⚠ Optimal plane parameters don't match existing results")
            print("  Regenerating with new optimal parameters...")
            return None
        print("✓ Optimal plane parameters match existing results")

        # Every data[key] access decompresses the whole member again, so
        # materialize each array exactly once rather than once per grid point.
        arrays = {key: data[key] for key in data.files}

    optimization_results = []
    for i, initial in enumerate(arrays['initial_params']):
        opt_params = arrays['optimized_params'][i]
        # Failed optimizations were stored as an all-NaN row.
        opt_params = None if np.all(np.isnan(opt_params)) else opt_params.copy()
        optimization_results.append({
            'initial_params': initial.copy(),
            'optimized_params': opt_params,
            'initial_objective': arrays['initial_objectives'][i],
            'optimized_objective': arrays['optimized_objectives'][i],
            'delta': arrays['deltas'][i],
            'success': bool(arrays['success_flags'][i]),
            'nit': int(arrays['nit_counts'][i]),
            'error_metric': arrays['error_metrics'][i]
        })

    # Recompute normalization ranges the same way a fresh run does.
    _, az_half_deg = azimuthal_deg_range
    _, pol_half_deg = polar_deg_range
    theta_range = 2 * np.deg2rad(az_half_deg)
    phi_range = 2 * np.deg2rad(pol_half_deg)
    L_range = np.max(arrays['L_grid']) - np.min(arrays['L_grid'])

    results = {
        'optimization_results': optimization_results,
        'error_metrics': arrays['error_metrics'],
        'azimuthal_angles': arrays['azimuthal_angles'],
        'polar_angles': arrays['polar_angles'],
        'theta_grid': arrays['theta_grid'],
        'phi_grid': arrays['phi_grid'],
        'L_grid': arrays['L_grid'],
        'optimal_plane_params': arrays['optimal_plane_params'],
        'statistics': existing_metadata['statistics'],
        'ranges': {
            'theta_range': theta_range,
            'phi_range': phi_range,
            'L_range': L_range
        },
        'parameters': existing_metadata['parameters']
    }

    print("✓ Successfully loaded existing results!")
    print(f"  - Success rate: {results['statistics']['success_rate']:.1%}")
    print(f"  - Mean error: {results['statistics']['mean_error']:.4f}")
    return results


def _save_robustness_results(results_file, metadata_file, results):
    """
    Persist a ``robustness_test`` results dictionary to disk.

    The arrays go to a compressed ``.npz`` archive. Parameters, statistics and
    ranges go to a human-readable JSON file. Failed optimizations
    (``optimized_params is None``) are stored as a row of NaNs.

    Parameters
    ----------
    results_file : str
        Destination ``.npz`` path.
    metadata_file : str
        Destination JSON path.
    results : dict
        Dictionary in the layout returned by ``robustness_test``.
    """
    import json

    optimization_results = results['optimization_results']
    np.savez_compressed(
        results_file,
        error_metrics=results['error_metrics'],
        azimuthal_angles=results['azimuthal_angles'],
        polar_angles=results['polar_angles'],
        theta_grid=results['theta_grid'],
        phi_grid=results['phi_grid'],
        L_grid=results['L_grid'],
        optimal_plane_params=results['optimal_plane_params'],
        initial_params=np.array([r['initial_params'] for r in optimization_results]),
        optimized_params=np.array([
            r['optimized_params'] if r['optimized_params'] is not None else np.full(3, np.nan)
            for r in optimization_results
        ]),
        initial_objectives=np.array([r['initial_objective'] for r in optimization_results]),
        optimized_objectives=np.array([r['optimized_objective'] for r in optimization_results]),
        deltas=np.array([r['delta'] for r in optimization_results]),
        success_flags=np.array([r['success'] for r in optimization_results]),
        nit_counts=np.array([r['nit'] for r in optimization_results])
    )

    metadata = {
        'parameters': results['parameters'],
        'statistics': results['statistics'],
        'ranges': results['ranges'],
        'grid_shape': [len(results['azimuthal_angles']), len(results['polar_angles'])],
        'total_optimizations': len(optimization_results),
        'timestamp': time.strftime('%Y-%m-%d %H:%M:%S')
    }
    with open(metadata_file, 'w') as f:
        json.dump(metadata, f, indent=2)


def robustness_test(
    image,
    bone,
    output_path,
    interpolator_intensity,
    voxel_size,
    azimuthal_deg_range=(0, 90),
    polar_deg_range=(90, 10),
    initialization_steps=10,
    optimal_plane_params=None,
    save_results=True,
    force_regenerate=False
):
    """
    Test optimizer robustness from a grid of initial plane orientations.

    For every (azimuthal, polar) pair on the grid, the plane offset is set so the
    plane passes through the bone's center of mass. ``optimize_plane`` is then run
    from that start, and the result is compared with ``optimal_plane_params``
    through a range-normalized error metric.

    Existing results in ``output_path`` are reused when their grid settings and
    reference parameters match, unless ``force_regenerate=True``.

    Parameters
    ----------
    image : np.ndarray
        Input image data. Kept for API compatibility; not used by the computation.
    bone : np.ndarray
        Bone mask or segmentation.
    output_path : str
        Directory to save results.
    interpolator_intensity : RegularGridInterpolator
        Interpolator for the original image.
    voxel_size : array-like
        Physical voxel size for each axis.
    azimuthal_deg_range : tuple
        (center, half_range) for azimuthal angles in degrees.
    polar_deg_range : tuple
        (center, half_range) for polar angles in degrees.
    initialization_steps : int
        Number of steps in each angular direction (grid is N x N). Must be >= 1.
    optimal_plane_params : array-like
        Known optimal plane parameters [azimuthal, polar, L] for comparison.
    save_results : bool, optional
        Whether to save results to files for later analysis.
    force_regenerate : bool, optional
        If True, regenerate results even if existing files are found.

    Returns
    -------
    dict
        Optimization results, error metrics, grid arrays, statistics, ranges and
        the parameters used.

    Raises
    ------
    ValueError
        If ``optimal_plane_params`` is None or ``initialization_steps`` < 1.
    """
    start_time = time.time()

    if optimal_plane_params is None:
        raise ValueError("optimal_plane_params must be provided for robustness testing")
    if initialization_steps < 1:
        # Otherwise this only fails later, as Pool(processes=0).
        raise ValueError("initialization_steps must be at least 1")

    os.makedirs(output_path, exist_ok=True)

    results_file = os.path.join(output_path, "robustness_test_results.npz")
    metadata_file = os.path.join(output_path, "robustness_test_metadata.json")
    results_exist = os.path.exists(results_file) and os.path.exists(metadata_file)

    if results_exist and not force_regenerate:
        try:
            print("Checking for existing robustness test results...")
            cached = _load_cached_robustness_results(
                results_file, metadata_file,
                azimuthal_deg_range, polar_deg_range,
                initialization_steps, optimal_plane_params
            )
            if cached is not None:
                print("Generating heatmap from existing results...")
                plot_robustness_heatmap(
                    azimuthal_angles=cached['azimuthal_angles'],
                    polar_angles=cached['polar_angles'],
                    error_array=cached['error_metrics'],
                    title="Robustness Test - Parameter Error vs Initialization (Loaded)",
                    output_path=output_path
                )
                return cached
        except Exception as e:
            # Best-effort cache: any problem falls through to recomputation.
            print(f"⚠ Error loading existing results: {e}")
            print("  Regenerating results...")
    elif force_regenerate:
        print("🔄 Force regeneration requested - computing new results...")
    else:
        print("📊 No existing results found - computing new results...")

    # NOTE(review): center_of_mass returns indices in array-axis order. If `bone`
    # came from load_nifti_file, that order is (coronal, sagittal, axial). The swap
    # below multiplies the reordered indices by voxel_size in the *original* axis
    # order, which only matters when voxel_size[0] != voxel_size[1]. Confirm this
    # against the coordinate convention of generate_normal / compute_objective.
    com_vox = center_of_mass(bone)
    com_phy = np.array([com_vox[1], com_vox[0], com_vox[2]]) * voxel_size

    # Angular search ranges in radians.
    az_cent_deg, az_half_deg = azimuthal_deg_range
    pol_cent_deg, pol_half_deg = polar_deg_range
    az_cent_rad, az_half_rad = np.deg2rad([az_cent_deg, az_half_deg])
    pol_cent_rad, pol_half_rad = np.deg2rad([pol_cent_deg, pol_half_deg])

    azimuthal_angles = np.linspace(
        az_cent_rad - az_half_rad,
        az_cent_rad + az_half_rad,
        initialization_steps
    )
    polar_angles = np.linspace(
        pol_cent_rad - pol_half_rad,
        pol_cent_rad + pol_half_rad,
        initialization_steps
    )

    # With "ij" indexing, axis 0 follows azimuth and axis 1 follows polar.
    # After ravel, flat index k therefore maps to (k // N, k % N).
    theta_grid, phi_grid = np.meshgrid(azimuthal_angles, polar_angles, indexing="ij")

    n = generate_normal(theta_grid, phi_grid)                    # shape (N, N, 3)
    # Offsets L_ij = n_ij · com_phy, so every initial plane passes through the COM.
    L_grid = np.tensordot(n, com_phy, axes=([-1], [0]))          # shape (N, N)

    planes = np.stack([
        theta_grid.ravel(),
        phi_grid.ravel(),
        L_grid.ravel()
    ], axis=1)                                                    # shape (N*N, 3)

    # Ranges used to normalize the error metric.
    theta_range = 2 * az_half_rad
    phi_range = 2 * pol_half_rad
    L_range = np.max(L_grid) - np.min(L_grid)

    print(f"Starting robustness test with {len(planes)} initialization points...")
    print(f"Angular ranges: θ ∈ [{np.rad2deg(az_cent_rad - az_half_rad):.1f}°, {np.rad2deg(az_cent_rad + az_half_rad):.1f}°], "
          f"φ ∈ [{np.rad2deg(pol_cent_rad - pol_half_rad):.1f}°, {np.rad2deg(pol_cent_rad + pol_half_rad):.1f}°]")

    # Never start more worker processes than there are initialization points.
    n_processes = min(cpu_count(), len(planes))
    print(f"Using {n_processes} parallel processes for optimization...")

    task_args = [
        (index, p, bone, interpolator_intensity, voxel_size,
         optimal_plane_params, theta_range, phi_range, L_range)
        for index, p in enumerate(planes)
    ]

    # NOTE(review): under the "spawn" start method (the Windows/macOS default),
    # workers must be able to import _optimize_single_initialization. Functions
    # defined in a notebook cell are not importable that way; confirm the platform.
    with Pool(processes=n_processes) as pool:
        optimization_results = pool.map(_optimize_single_initialization, task_args)

    # Keep results in grid order.
    optimization_results.sort(key=lambda x: x['index'])

    print("\nOptimization Results:")
    for result in optimization_results:
        idx = result['index']
        initial_p = result['initial_params']

        print(f"Init {idx+1}/{len(planes)}: θ={np.rad2deg(initial_p[0]):.1f}°, φ={np.rad2deg(initial_p[1]):.1f}°, "
              f"L={initial_p[2]:.2f}, Obj={result['initial_objective']:.2e}, Delta={result['delta']:.2e}")

        if result['success'] and result['optimized_params'] is not None:
            opt_p = result['optimized_params']
            print(f"Optimized {idx+1}/{len(planes)}: θ={np.rad2deg(opt_p[0]):.1f}°, "
                  f"φ={np.rad2deg(opt_p[1]):.1f}°, L={opt_p[2]:.2f}, "
                  f"Obj={result['optimized_objective']:.2e}, Error={result['error_metric']:.4f}")
        else:
            error_msg = result.get('error', 'optimization failed')
            print(f"Failed {idx+1}/{len(planes)}: {error_msg}")

    # The index was only needed to restore ordering.
    for result in optimization_results:
        del result['index']

    error_metrics = np.array([result['error_metric'] for result in optimization_results])

    # Summary statistics. Non-finite errors mark failed optimizations.
    finite_errors = error_metrics[np.isfinite(error_metrics)]
    if len(finite_errors) > 0:
        min_error = np.min(finite_errors)
        max_error = np.max(finite_errors)
        mean_error = np.mean(finite_errors)
        std_error = np.std(finite_errors)
        success_rate = len(finite_errors) / len(error_metrics)
    else:
        min_error = max_error = mean_error = std_error = np.inf
        success_rate = 0.0

    elapsed_time = time.time() - start_time

    print(f"Robustness test completed in {elapsed_time:.2f} seconds")
    print(f"Success rate: {success_rate:.1%}")
    print(f"Error statistics: min={min_error:.4f}, max={max_error:.4f}, "
          f"mean={mean_error:.4f}±{std_error:.4f}")

    results = {
        'optimization_results': optimization_results,
        'error_metrics': error_metrics,
        'azimuthal_angles': azimuthal_angles,
        'polar_angles': polar_angles,
        'theta_grid': theta_grid,
        'phi_grid': phi_grid,
        'L_grid': L_grid,
        'optimal_plane_params': optimal_plane_params,
        'statistics': {
            'min_error': min_error,
            'max_error': max_error,
            'mean_error': mean_error,
            'std_error': std_error,
            'success_rate': success_rate,
            'elapsed_time': elapsed_time
        },
        'ranges': {
            'theta_range': theta_range,
            'phi_range': phi_range,
            'L_range': L_range
        },
        'parameters': {
            'azimuthal_deg_range': azimuthal_deg_range,
            'polar_deg_range': polar_deg_range,
            'initialization_steps': initialization_steps
        }
    }

    if save_results:
        print("Saving robustness test results...")
        _save_robustness_results(results_file, metadata_file, results)
        print("Results saved to:")
        print(f"  - {results_file}")
        print(f"  - {metadata_file}")

    print("Generating robustness heatmap...")
    plot_robustness_heatmap(
        azimuthal_angles=azimuthal_angles,
        polar_angles=polar_angles,
        error_array=error_metrics,
        title="Robustness Test - Parameter Error vs Initialization",
        output_path=output_path
    )

    return results

In [300]:
def load_and_plot_robustness_results(results_path, title="Robustness Test Heatmap", vmax_percentile=95.0):
    """
    Load saved robustness test results and regenerate the heatmap.

    Reads ``robustness_test_results.npz`` and ``robustness_test_metadata.json``
    from ``results_path``. Both files are written by ``robustness_test`` when
    ``save_results=True``. The heatmap is saved back into the same directory.

    Parameters
    ----------
    results_path : str
        Path to the directory containing saved robustness test results.
    title : str, optional
        Title for the regenerated heatmap.
    vmax_percentile : float, optional
        Percentile for color scale maximum.

    Returns
    -------
    dict
        Arrays from the ``.npz`` archive keyed by name, plus ``'metadata'``
        holding the parsed JSON.

    Raises
    ------
    FileNotFoundError
        If either the ``.npz`` archive or the JSON metadata file is missing.
    """
    import json

    results_file = os.path.join(results_path, "robustness_test_results.npz")
    metadata_file = os.path.join(results_path, "robustness_test_metadata.json")

    if not os.path.exists(results_file):
        raise FileNotFoundError(f"Results file not found: {results_file}")

    if not os.path.exists(metadata_file):
        raise FileNotFoundError(f"Metadata file not found: {metadata_file}")

    print(f"Loading robustness test results from {results_path}")

    array_keys = (
        'error_metrics',
        'azimuthal_angles',
        'polar_angles',
        'theta_grid',
        'phi_grid',
        'L_grid',
        'optimal_plane_params',
        'initial_params',
        'optimized_params',
        'initial_objectives',
        'optimized_objectives',
        'deltas',
        'success_flags',
        'nit_counts',
    )
    # Read every member once, inside a context manager, so the archive's file
    # handle is closed instead of leaking.
    with np.load(results_file) as data:
        results = {key: data[key] for key in array_keys}

    with open(metadata_file, 'r') as f:
        results['metadata'] = json.load(f)

    print("Regenerating robustness heatmap...")
    plot_robustness_heatmap(
        azimuthal_angles=results['azimuthal_angles'],
        polar_angles=results['polar_angles'],
        error_array=results['error_metrics'],
        title=title,
        output_path=results_path,
        vmax_percentile=vmax_percentile
    )

    return results




In [301]:
# Usage examples for enhanced robustness testing with automatic result checking
#
# The snippets below are wrapped in a string literal, so running this cell does
# NOT execute them. Copy individual examples out to use them.
# NOTE(review): the snippets assume `image`, `bone_ct`, `output_path_patient`,
# `interpolator`, `voxel_size` and `optimized_parameters` already exist in the
# notebook namespace.
# NOTE(review): Example 6 calls `json.load`, but in this section `json` is only
# imported locally inside functions, so run `import json` first.
# NOTE(review): `analyze_robustness_convergence` (Example 5) is not defined in
# this section; confirm it exists elsewhere before running that snippet.

"""
# Example 1: Basic robustness test with automatic result checking (recommended)
results = robustness_test(
    image=image,
    bone=bone_ct,
    output_path=output_path_patient,
    interpolator_intensity=interpolator,
    voxel_size=voxel_size,
    azimuthal_deg_range=(0, 45),
    polar_deg_range=(90, 20),
    initialization_steps=8,
    optimal_plane_params=optimized_parameters
)

# This automatically:
# 1. Checks if results already exist with matching parameters
# 2. Loads existing results if found (much faster!)
# 3. Or computes new results if not found
# 4. Always generates the heatmap

# Example 2: Force regeneration of results (ignore existing files)
results = robustness_test(
    image=image,
    bone=bone_ct,
    output_path=output_path_patient,
    interpolator_intensity=interpolator,
    voxel_size=voxel_size,
    azimuthal_deg_range=(0, 45),
    polar_deg_range=(90, 20),
    initialization_steps=8,
    optimal_plane_params=optimized_parameters,
    force_regenerate=True  # This forces new computation
)

# Example 3: Different parameters will automatically trigger regeneration
results = robustness_test(
    image=image,
    bone=bone_ct,
    output_path=output_path_patient,
    interpolator_intensity=interpolator,
    voxel_size=voxel_size,
    azimuthal_deg_range=(0, 60),      # Different from previous run
    polar_deg_range=(90, 30),         # Different from previous run
    initialization_steps=10,          # Different from previous run
    optimal_plane_params=optimized_parameters
)

# Example 4: Manual loading and heatmap regeneration (for custom analysis)
results_loaded = load_and_plot_robustness_results(
    results_path=output_path_patient,
    title="Custom Analysis Title",
    vmax_percentile=90.0
)

# Example 5: Detailed convergence analysis from saved results
convergence_results = analyze_robustness_convergence(
    results_path=output_path_patient
)

# Example 6: Check what the function would do without running it
results_file = os.path.join(output_path_patient, "robustness_test_results.npz")
metadata_file = os.path.join(output_path_patient, "robustness_test_metadata.json")

if os.path.exists(results_file) and os.path.exists(metadata_file):
    print("✓ Existing results found - will load from cache")
    with open(metadata_file, 'r') as f:
        metadata = json.load(f)
    print(f"  - Generated: {metadata.get('timestamp', 'unknown')}")
    print(f"  - Parameters: {metadata.get('parameters', {})}")
else:
    print("⚠ No existing results - will compute from scratch")

# Benefits of the automatic checking:
# - Massive time savings on repeated runs
# - Consistent results across multiple analyses  
# - Easy parameter comparison (automatic regeneration when parameters change)
# - Robust caching with parameter validation
# - Always gets fresh heatmap even from cached data
"""

'\n# Example 1: Basic robustness test with automatic result checking (recommended)\nresults = robustness_test(\n    image=image,\n    bone=bone_ct,\n    output_path=output_path_patient,\n    interpolator_intensity=interpolator,\n    voxel_size=voxel_size,\n    azimuthal_deg_range=(0, 45),\n    polar_deg_range=(90, 20),\n    initialization_steps=8,\n    optimal_plane_params=optimized_parameters\n)\n\n# This automatically:\n# 1. Checks if results already exist with matching parameters\n# 2. Loads existing results if found (much faster!)\n# 3. Or computes new results if not found\n# 4. Always generates the heatmap\n\n# Example 2: Force regeneration of results (ignore existing files)\nresults = robustness_test(\n    image=image,\n    bone=bone_ct,\n    output_path=output_path_patient,\n    interpolator_intensity=interpolator,\n    voxel_size=voxel_size,\n    azimuthal_deg_range=(0, 45),\n    polar_deg_range=(90, 20),\n    initialization_steps=8,\n    optimal_plane_params=optimized_paramete

In [302]:
def _robustness_heatmap_data(azimuthal_angles, polar_angles, error_array, vmax_percentile):
    """
    Arrange flat robustness errors into a (polar, azimuthal) grid and pick color limits.

    ``robustness_test`` builds its grid with
    ``np.meshgrid(azimuthal, polar, indexing="ij")`` and ravels it, so flat
    index k corresponds to azimuthal index ``k // P`` and polar index ``k % P``.
    ``pcolormesh`` with ``np.meshgrid(theta_deg, phi_deg)`` needs rows to follow
    the polar angle, so the array is reshaped to (T, P) and then transposed.

    Parameters
    ----------
    azimuthal_angles : np.ndarray
        Azimuthal grid values (length T).
    polar_angles : np.ndarray
        Polar grid values (length P).
    error_array : np.ndarray
        Flat error values of length T*P in the ravel order described above.
    vmax_percentile : float
        Percentile (0-100) of the finite errors used as the color-scale maximum.

    Returns
    -------
    tuple
        error_2d_plot : np.ndarray, shape (P, T)
            Errors with non-finite entries replaced by ``vmax``, so failed cells
            render in the top color.
        finite_mask : np.ndarray of bool, shape (P, T)
            True where the original error was finite.
        vmin, vmax : float
            Color limits. ``vmin`` is 0 because the error metric is a square
            root of a sum of squares.
    """
    P = len(polar_angles)
    T = len(azimuthal_angles)
    error_2d = np.asarray(error_array, dtype=float).reshape(T, P).T
    finite_mask = np.isfinite(error_2d)

    vmin, vmax = 0.0, 1.0
    if np.any(finite_mask):
        vmax = float(np.percentile(error_2d[finite_mask], vmax_percentile))
        if not vmax > vmin:
            # All finite errors are zero, so fall back to a unit scale.
            vmax = vmin + 1.0

    error_2d_plot = np.where(finite_mask, error_2d, vmax)
    return error_2d_plot, finite_mask, vmin, vmax


def plot_robustness_heatmap(
    azimuthal_angles: np.ndarray,
    polar_angles: np.ndarray,
    error_array: np.ndarray,
    title: str = "Robustness Test Heatmap",
    output_path: str = None,
    vmax_percentile: float = 95.0
):
    """
    Plot and save a heatmap of robustness test error metrics over a grid of initialization angles.

    Parameters
    ----------
    azimuthal_angles : np.ndarray
        Array of azimuthal angles (radians), length T.
    polar_angles : np.ndarray
        Array of polar angles (radians), length P.
    error_array : np.ndarray
        Flattened array of error metric values (length = T*P), raveled from an
        ``indexing="ij"`` (azimuthal, polar) grid as produced by ``robustness_test``.
    title : str
        Title for the plot. Accepted for API compatibility; title drawing is
        currently disabled.
    output_path : str or None
        Directory to save the heatmap PDF. If None, displays interactively.
    vmax_percentile : float
        Percentile (0-100) of the finite error values to use as the maximum of
        the color scale. Values above this are clipped to the highest color.
        Failed (non-finite) cells are drawn in the highest color and marked
        with a white "x".
    """
    error_2d_plot, finite_mask, vmin, vmax = _robustness_heatmap_data(
        azimuthal_angles, polar_angles, error_array, vmax_percentile
    )

    # meshgrid default ("xy") indexing gives (P, T) arrays: rows follow polar,
    # columns follow azimuth, matching error_2d_plot.
    theta_deg = np.rad2deg(azimuthal_angles)
    phi_deg = np.rad2deg(polar_angles)
    TH, PH = np.meshgrid(theta_deg, phi_deg)

    fig, ax = plt.subplots(figsize=(10, 8))

    # Reversed Red-Yellow-Blue: low errors (good) are blue, high errors are red.
    mesh = ax.pcolormesh(
        TH, PH, error_2d_plot,
        shading='auto',
        cmap='RdYlBu_r',
        vmin=vmin,
        vmax=vmax
    )

    cbar = fig.colorbar(mesh, ax=ax, shrink=0.8)
    cbar.set_label('E', fontsize=12)

    if np.any(~finite_mask):
        ax.scatter(TH[~finite_mask], PH[~finite_mask], c='white', s=50,
                   marker='x', linewidths=2, label='Failed optimizations')
        ax.legend()

    ax.set_xlabel('Azimuthal angle θ (°)', fontsize=12)
    ax.set_ylabel('Polar angle φ (°)', fontsize=12)
    ax.grid(True, linestyle='--', alpha=0.3)

    if output_path:
        os.makedirs(output_path, exist_ok=True)
        out_file = os.path.join(output_path, "robustness_heatmap.pdf")
        fig.savefig(out_file, format='pdf', bbox_inches='tight', dpi=300)
        print(f"Robustness heatmap saved to: {out_file}")
        plt.close(fig)
    else:
        plt.show()

In [303]:
# def plot_mse_heatmap(azimuthal_angles, polar_angles, mse_array, title="MSE Heatmap", output_path=None):
#     """
#     Plot and optionally save a heatmap of the objective function values over a grid of angles.

#     Parameters
#     ----------
#     azimuthal_angles : np.ndarray
#         Array of azimuthal angles (radians).
#     polar_angles : np.ndarray
#         Array of polar angles (radians).
#     mse_array : np.ndarray
#         Array of objective function values.
#     title : str, optional
#         Title for the plot.
#     output_path : str or None
#         Directory to save the heatmap image, if provided.
#     """
#     P = len(polar_angles)
#     T = len(azimuthal_angles)

#     mse_2d = mse_array.reshape(P, T)

#     theta_deg = np.rad2deg(azimuthal_angles)
#     phi_deg   = np.rad2deg(polar_angles)

#     TH, PH = np.meshgrid(theta_deg, phi_deg)
#     plt.figure(figsize=(8, 6))
#     plt.pcolormesh(TH, PH, mse_2d, shading='auto', cmap='viridis', vmin = np.min(mse_2d) , vmax = np.min(mse_2d) + 10**13)
#     plt.colorbar(label='MSE')
#     plt.xlabel('Azimuthal angle θ (°)')
#     plt.ylabel('Polar angle φ (°)')
#     #plt.title(title)
#     plt.grid(visible=True, linestyle='--', alpha=0.7)
#     if output_path is not None:
#         plt.savefig(os.path.join(output_path, "mse_heatmap.pdf"), bbox_inches='tight')



from matplotlib.colors import LogNorm

def _mse_heatmap_data(azimuthal_angles, polar_angles, mse_array, vmax_percentile):
    """
    Reshape flat MSE values to a (P, T) grid and compute log-scale color limits.

    A machine epsilon is added so exact zeros survive the logarithmic scale.
    Limits come from finite values only. NaN/inf cells are masked so they
    cannot break ``LogNorm``.

    Returns
    -------
    tuple
        mse_plot : np.ma.MaskedArray, shape (P, T)
        vmin, vmax : float
    """
    P = len(polar_angles)
    T = len(azimuthal_angles)
    eps = np.finfo(float).eps
    mse_plot = np.asarray(mse_array, dtype=float).reshape(P, T) + eps

    finite = np.isfinite(mse_plot)
    if np.any(finite):
        vmin = mse_plot[finite].min()
        vmax = np.percentile(mse_plot[finite], vmax_percentile)
    else:
        vmin, vmax = eps, 1.0
    return np.ma.masked_invalid(mse_plot), vmin, vmax


def plot_mse_heatmap(
    azimuthal_angles: np.ndarray,
    polar_angles: np.ndarray,
    mse_array: np.ndarray,
    title: str = "MSE Heatmap",
    output_path: str = None,
    vmax_percentile: float = 99.0
):
    """
    Plot and optionally save a heatmap of MSE values over a grid of angles.

    The colormap uses LogNorm (logarithmic scaling). Values above the
    percentile-based vmax are clipped to the same top color, while lower values
    are spread out for better differentiation. Non-finite values are masked
    (left blank) and ignored when choosing the color limits.

    Parameters
    ----------
    azimuthal_angles : np.ndarray
        Array of azimuthal angles (radians), length T.
    polar_angles : np.ndarray
        Array of polar angles (radians), length P.
    mse_array : np.ndarray
        Flattened array of MSE values (length = P*T), reshaped as (P, T):
        flat index = polar_index * T + azimuthal_index.
        NOTE(review): an array raveled from
        ``np.meshgrid(az, pol, indexing="ij")`` uses the opposite order and would
        appear transposed. Confirm against the code that produces ``mse_array``.
    title : str
        Title for the plot. Accepted for API compatibility; title drawing is
        currently disabled.
    output_path : str or None
        Directory to save the heatmap image (PDF). If None, displays interactively.
    vmax_percentile : float
        Percentile (0-100) of MSE to use as the maximum of the color scale.
        Values above this are clipped to the highest color.
    """
    mse_plot, vmin, vmax = _mse_heatmap_data(
        azimuthal_angles, polar_angles, mse_array, vmax_percentile
    )

    theta_deg = np.rad2deg(azimuthal_angles)
    phi_deg = np.rad2deg(polar_angles)
    TH, PH = np.meshgrid(theta_deg, phi_deg)

    # Named log_norm so it does not shadow the module-level scipy.stats.norm.
    log_norm = LogNorm(vmin=vmin, vmax=vmax, clip=True)

    fig, ax = plt.subplots(figsize=(8, 6))
    mesh = ax.pcolormesh(
        TH, PH, mse_plot,
        shading='auto',
        cmap='viridis',
        norm=log_norm
    )
    cbar = fig.colorbar(mesh, ax=ax)
    cbar.set_label("MSE log scale")

    ax.set_xlabel('Azimuthal angle θ (°)')
    ax.set_ylabel('Polar angle φ (°)')
    ax.grid(True, linestyle='--', alpha=0.7)

    if output_path:
        os.makedirs(output_path, exist_ok=True)
        out_file = os.path.join(output_path, "mse_heatmap.pdf")
        fig.savefig(out_file, format='pdf', bbox_inches='tight')
        plt.close(fig)
    else:
        plt.show()



### Optimization

In [304]:
def optimize_plane(initial_params_array, image, interpolator_intensity, voxel_size, delta):
    """
    Minimize the plane objective with BFGS, starting from ``initial_params_array``.

    Gradients are estimated numerically (``jac=None``). After every accepted
    iterate, the objective is re-evaluated and recorded together with the
    parameters. The optimizer's final solution is always appended at the end
    of both histories.

    Parameters
    ----------
    initial_params_array : array-like of shape (3,)
        Starting plane parameters [azimuthal, polar, L].
    image : np.ndarray
        The image or mask to optimize over.
    interpolator_intensity : RegularGridInterpolator
        Interpolator for the original image.
    voxel_size : array-like of shape (3,)
        Physical voxel size for each axis.
    delta : float
        Huber loss delta parameter.

    Returns
    -------
    OptimizeResult
        The ``scipy.optimize.minimize`` result, extended with:
        - objective_value_list: objective value after each iteration, then the final value.
        - params_list: parameter copies after each iteration, then the final solution.
    """
    objective_args = (image, interpolator_intensity, voxel_size, delta)
    history_values = []
    history_params = []

    def _record_iteration(current_params):
        # Called by BFGS with the current iterate. Evaluate first, then store a
        # copy, because scipy may reuse the array buffer.
        history_values.append(compute_objective(current_params, *objective_args))
        history_params.append(current_params.copy())

    result = minimize(
        compute_objective,
        x0=initial_params_array,
        args=objective_args,
        method='BFGS',
        jac=None,
        callback=_record_iteration,
    )

    # Close the histories with the solution the optimizer reports.
    history_params.append(result.x.copy())
    history_values.append(result.fun)

    result.objective_value_list = history_values
    result.params_list = history_params
    return result

def run_or_load_optimization(output_path_patient,
                             image,
                             voxel_size_image,
                             delta,
                             initial_plane,
                             bone_ct,
                             interpolator,
                             optimized_parameter_list,
                             optimized_objective_value_list):
    """
    Run or load the optimization of plane parameters for a single patient.

    If both ``parameter_array.npy`` and ``objective_value_array.npy`` exist in
    ``output_path_patient``, the cached history is loaded and its last entry is
    used as the optimum. Otherwise the optimisation is run and its history is
    saved to those files. In both cases a GIF of the trajectory is rendered and
    the optimum is appended to the two caller-supplied lists.

    Parameters
    ----------
    output_path_patient : str
        Directory to save/load optimization results (created if missing).
    image : np.ndarray
        The original image array.
    voxel_size_image : array-like of shape (3,)
        Physical voxel size for each axis.
    delta : float
        Huber loss delta parameter.
    initial_plane : array-like of shape (3,)
        Initial guess for plane parameters.
    bone_ct : np.ndarray
        Bone mask or intensity image.
    interpolator : RegularGridInterpolator
        Interpolator for the original image.
    optimized_parameter_list : list
        List to append optimized parameters to (mutated in place).
    optimized_objective_value_list : list
        List to append optimized objective values to (mutated in place).

    Returns
    -------
    tuple
        optimized_parameters : np.ndarray
            Optimized plane parameters [azimuthal, polar, L].
        optimized_objective_value : float
            Final objective function value.
    """
    param_path = os.path.join(output_path_patient, "parameter_array.npy")
    obj_path   = os.path.join(output_path_patient, "objective_value_array.npy")

    if os.path.exists(param_path) and os.path.exists(obj_path):
        # Reuse the cached optimisation history; its last entry is the optimum.
        params_arr = np.load(param_path)
        obj_arr    = np.load(obj_path)
        optimized_parameters      = params_arr[-1]
        optimized_objective_value = obj_arr[-1]
        print(f"Optimized parameters: "
              f"{np.rad2deg(optimized_parameters[0]):.2f}°, "
              f"{np.rad2deg(optimized_parameters[1]):.2f}°, "
              f"{optimized_parameters[2]:.2f} mm "
              f"with MSE {optimized_objective_value:.2f}")

        make_plane_gif(image, voxel_size_image, params_arr, obj_arr, output_path_patient)
    else:
        start = time.perf_counter()
        res = optimize_plane(initial_plane,
                             bone_ct,
                             interpolator,
                             voxel_size_image,
                             delta)
        end = time.perf_counter()
        print(f"Optimization took {end - start:.2f} seconds.")

        optimized_parameters      = res.x
        optimized_objective_value = res.fun
        print(f"Optimized parameters: "
              f"θ = {np.rad2deg(optimized_parameters[0]):.2f}°, "
              f"φ = {np.rad2deg(optimized_parameters[1]):.2f}°, "
              f"L = {optimized_parameters[2]:.2f} mm "
              f"with MSE {optimized_objective_value:.2f}")

        params_arr = np.array(res.params_list)
        obj_arr    = np.array(res.objective_value_list)

        # Persist the (expensive) optimisation result before rendering the GIF,
        # so a plotting failure does not discard it.
        if output_path_patient:
            os.makedirs(output_path_patient, exist_ok=True)
        np.save(param_path, params_arr)
        np.save(obj_path, obj_arr)

        make_plane_gif(image, voxel_size_image, params_arr, obj_arr, output_path_patient)

    # Record the optimum for fresh AND cached runs, so callers aggregating over
    # patients see every patient rather than only the newly optimised ones.
    optimized_parameter_list.append(optimized_parameters)
    optimized_objective_value_list.append(optimized_objective_value)

    return optimized_parameters, optimized_objective_value

### Verification Plots

In [305]:
def make_plane_gif(image_3d, voxel_size, plane_params, objective_values, output_path, duration=2):
    """
    Create a GIF showing the evolution of plane contours on the middle slice of a 3D image.

    Each frame draws the intersection of one plane with the middle axial slice
    (index ``D // 2``). A plane ``(theta, phi, L)`` is the set of points with
    ``nx*x + ny*y + nz*z = L``, where
    ``n = (sin(phi) cos(theta), sin(phi) sin(theta), cos(phi))``. Here x, y, z
    are physical positions in mm along the sagittal, coronal and axial array
    axes.

    Parameters
    ----------
    image_3d : np.ndarray
        The 3D image array (coronal, sagittal, axial).
    voxel_size : tuple of float
        Spacing in mm along (coronal=y, sagittal=x, axial=z).
    plane_params : array-like
        Plane parameters for each frame, shape (N, 3): angles in radians, L in mm.
    objective_values : array-like
        Objective value for each frame, shape (N,); shown in each frame title.
    output_path : str
        Directory in which ``plane_optimization.gif`` is written (created if missing).
    duration : float, optional
        Per-frame display time, forwarded unchanged to ``imageio.get_writer``.
        NOTE(review): older imageio backends read this as seconds, newer
        Pillow-based ones as milliseconds — confirm against the installed version.
    """
    plane_params = np.asarray(plane_params, dtype=float)
    H, W, D = image_3d.shape
    sy, sx, sz = voxel_size

    # Middle axial slice and its physical z position.
    z0 = D // 2
    z0_mm = z0 * sz
    slice_img = image_3d[:, :, z0]

    # Physical grid (mm) on which the plane equation is evaluated;
    # meshgrid(x, y) yields arrays of shape (len(y), len(x)), matching slice_img.
    y_mm = np.arange(H) * sy
    x_mm = np.arange(W) * sx
    X_mm, Y_mm = np.meshgrid(x_mm, y_mm)
    extent = [x_mm[0], x_mm[-1], y_mm[-1], y_mm[0]]

    if output_path:
        os.makedirs(output_path, exist_ok=True)
    gif_path = os.path.join(output_path, "plane_optimization.gif")

    # The context manager finalises and closes the GIF even if a frame fails.
    with imageio.get_writer(gif_path, mode='I', duration=duration) as writer:
        for idx, (theta, phi, L_mm) in enumerate(plane_params):
            objective_value = objective_values[idx]

            # Plane normal from spherical angles.
            nx = np.sin(phi) * np.cos(theta)
            ny = np.sin(phi) * np.sin(theta)
            nz = np.cos(phi)

            # On the slice z = z0_mm the plane reduces to the line nx*x + ny*y = C.
            C = L_mm - nz * z0_mm
            F = nx * X_mm + ny * Y_mm - C

            fig = plt.figure(figsize=(6, 6))
            try:
                canvas = FigureCanvas(fig)
                ax = fig.add_subplot(111)
                ax.imshow(slice_img, cmap='gray', origin='upper', extent=extent)
                ax.contour(X_mm, Y_mm, F, levels=[0], colors='red')
                ax.set_title(f'Plane {idx+1}: θ={np.rad2deg(theta):.1f}°, '
                             f'φ={np.rad2deg(phi):.1f}°, L={L_mm:.1f} mm\nMSE={objective_value:.2f}')
                ax.set_xlabel('x (mm)')
                ax.set_ylabel('y (mm)')
                # print_to_buffer renders the figure itself; no separate draw() needed.
                buf, (w, h) = canvas.print_to_buffer()
            finally:
                # Always release the figure, even if rendering failed.
                plt.close(fig)

            # RGBA bytes -> (h, w, 3) RGB frame.
            frame = np.frombuffer(buf, dtype='uint8').reshape(h, w, 4)[..., :3]
            writer.append_data(frame)

    print(f"Saved GIF to {output_path}")

def display_scrollable_views(
        struct_dict: dict,
        voxel_size,
        plane_coeffs_list=None,
        optimization_methods_list=None,
):
    """
    Display interactive axial and coronal views of the image and structures, with optional plane overlays.

    The Y/X extent is cropped to a window of up to 2 * 150 voxels.
    - The window is centred on the mean voxel position of the 'Body' mask.
    - It is centred on the geometric centre instead when 'Body' is absent,
      empty or unreadable.
    - The full Z range is kept.

    Two sliders select the axial slice and the coronal row within the cropped
    volume.

    Parameters
    ----------
    struct_dict : dict
        Dictionary of structures, must include 'Image'. Values may be arrays or
        tuples/lists whose first element is the array. Every non-'Image' entry
        is drawn as a contour and is expected to share the image's shape.
    voxel_size : tuple of float
        Spacing in mm along each axis, ordered (Y, X, Z) like the array axes.
    plane_coeffs_list : list of array-like, optional
        List of plane parameter triples (theta, phi, L) to overlay.
    optimization_methods_list : list of str, optional
        List of method names for legend.
    """
    if plane_coeffs_list is None:
        plane_coeffs_list = []
    if optimization_methods_list is None:
        optimization_methods_list = []

    def _array(obj):
        # Structures may be stored bare or as the first element of a tuple/list.
        return obj[0] if isinstance(obj, (tuple, list)) else obj

    image = _array(struct_dict['Image'])

    H, W, D = image.shape
    sy, sx, sz = voxel_size

    # Calculate center of mass using the Body structure for cropping
    # Use Body structure if available, otherwise fallback to geometric center
    if 'Body' in struct_dict:
        try:
            body_mask = _array(struct_dict['Body'])
            if body_mask.any():
                # Get indices where Body structure is present
                coords = np.where(body_mask > 0)
                center_y = int(np.mean(coords[0]))
                center_x = int(np.mean(coords[1]))
                print(f"Using Body structure center of mass: Y={center_y}, X={center_x}")
            else:
                # Body structure exists but is empty
                center_y = H // 2
                center_x = W // 2
                print(f"Body structure is empty, using geometric center: Y={center_y}, X={center_x}")
        except Exception as e:
            # Best-effort: an unreadable Body mask falls back to the geometric centre.
            center_y = H // 2
            center_x = W // 2
            print(f"Error accessing Body structure ({e}), using geometric center: Y={center_y}, X={center_x}")
    else:
        # No Body structure available, use geometric center
        center_y = H // 2
        center_x = W // 2
        print(f"No Body structure found, using geometric center: Y={center_y}, X={center_x}")

    # Define cropping around center of mass (150 voxels in each direction)
    crop_radius = 150

    # Calculate crop boundaries
    y_start = max(0, center_y - crop_radius)
    y_end = min(H, center_y + crop_radius)
    x_start = max(0, center_x - crop_radius)
    x_end = min(W, center_x + crop_radius)
    z_start = 0  # Keep full Z range
    z_end = D

    # When the window hits an image border, shift it so it keeps the full
    # 2 * crop_radius size (if the image is large enough).
    if y_end - y_start < 2 * crop_radius and H >= 2 * crop_radius:
        if y_start == 0:
            y_end = min(H, 2 * crop_radius)
        elif y_end == H:
            y_start = max(0, H - 2 * crop_radius)

    if x_end - x_start < 2 * crop_radius and W >= 2 * crop_radius:
        if x_start == 0:
            x_end = min(W, 2 * crop_radius)
        elif x_end == W:
            x_start = max(0, W - 2 * crop_radius)

    # Ensure minimum dimensions for very small images
    min_size = 50
    if y_end - y_start < min_size:
        center_y = H // 2
        y_start = max(0, center_y - min_size // 2)
        y_end = min(H, y_start + min_size)
    if x_end - x_start < min_size:
        center_x = W // 2
        x_start = max(0, center_x - min_size // 2)
        x_end = min(W, x_start + min_size)

    print(f"Cropping around center of mass: Y=[{center_y}], X=[{center_x}]")
    print(f"Crop bounds: Y=[{y_start}:{y_end}], X=[{x_start}:{x_end}], Z=[{z_start}:{z_end}]")
    print(f"Cropped size: {y_end-y_start} x {x_end-x_start} x {z_end-z_start} voxels")

    # Crop the image
    image_cropped = image[y_start:y_end, x_start:x_end, z_start:z_end]
    H_crop, W_crop, D_crop = image_cropped.shape

    # Physical axes (mm) of the cropped image, in absolute image coordinates.
    y_mm = np.arange(y_start, y_end) * sy
    x_mm = np.arange(x_start, x_end) * sx
    z_mm = np.arange(z_start, z_end) * sz  # Full Z range since not cropped

    # precompute meshgrids
    X_ax, Y_ax = np.meshgrid(x_mm, y_mm)   # for axial view
    X_cor, Z_cor = np.meshgrid(x_mm, z_mm) # for coronal view

    # colors
    plane_colors = ['red', 'purple', 'cyan', 'lime', 'magenta']
    struct_names = [s for s in struct_dict if s != 'Image']
    # plt.cm.get_cmap was deprecated in Matplotlib 3.7 and removed in 3.9;
    # plt.get_cmap is the supported spelling.
    cmap10 = plt.get_cmap('tab10').colors
    # Cycle through the palette so more than 10 structures do not raise IndexError.
    color_cycle = {n: cmap10[i % len(cmap10)] for i, n in enumerate(struct_names)}

    def view_slice_axial(z_index: int, y_line: int):
        """Render axial slice ``z_index`` and coronal row ``y_line`` (cropped indices)."""
        fig, ax = plt.subplots(1, 2, figsize=(12, 6), constrained_layout=True)
        # --- Axial ---
        ax0 = ax[0]
        ax0.imshow(image_cropped[:, :, z_index], cmap='gray',
                   extent=[x_mm[0], x_mm[-1], y_mm[-1], y_mm[0]])
        ax0.set_title(f'Axial z={z_index} (z={z_index*sz:.1f} mm)')
        ax0.set_xlabel('X (sagittal) mm'); ax0.set_ylabel('Y (coronal) mm')
        # horizontal guide at Y = y_line
        y0 = y_mm[y_line]
        ax0.axhline(y=y0, color='yellow', linestyle='--')
        # Structure contours: a single 0.5 level traces the boundary of a
        # binary mask (auto levels would stack several lines on the edge).
        for name in struct_names:
            mask = _array(struct_dict[name])
            # Crop the structure mask to match the image cropping
            mask_cropped = mask[y_start:y_end, x_start:x_end, z_start:z_end]
            sl = mask_cropped[:, :, z_index]
            if sl.any():
                ax0.contour(X_ax, Y_ax, sl, levels=[0.5],
                            colors=[color_cycle[name]], linewidths=1)
        # Planes: on the slice z the plane reduces to nx*x + ny*y = L - nz*z.
        for i, (theta, phi, L) in enumerate(plane_coeffs_list):
            nx = np.sin(phi)*np.cos(theta)
            ny = np.sin(phi)*np.sin(theta)
            nz = np.cos(phi)
            C = L - nz*(z_index*sz)
            F = nx*X_ax + ny*Y_ax - C
            ax0.contour(X_ax, Y_ax, F, levels=[0],
                        colors=[plane_colors[i % len(plane_colors)]], linewidths=1)
        # legend
        handles = []
        for n in struct_names:
            handles.append(Line2D([0], [0], color=color_cycle[n], lw=2, label=n))
        for i, method in enumerate(optimization_methods_list):
            handles.append(Line2D([0], [0], color=plane_colors[i % len(plane_colors)],
                                  lw=2, label=method))
        ax0.legend(handles=handles, loc='upper right')

        # --- Coronal View ---
        ax1 = ax[1]

        # Extract coronal slice (using same cropping as axial view)
        sl_full = image_cropped[y_line, :, :].T  # (Z, X) so Z runs vertically

        # Display the coronal slice
        ax1.imshow(
            sl_full,
            cmap='gray',
            origin='lower',
            extent=[x_mm[0], x_mm[-1], z_mm[0], z_mm[-1]]
        )
        ax1.set_title(
            f'Coronal y={y_line + y_start} '
            f'(y={(y_line + y_start)*sy:.1f} mm)'
        )
        ax1.set_xlabel('X (sagittal) mm')
        ax1.set_ylabel('Z (axial) mm')

        # Guide line at current Z slice
        z0 = z_mm[z_index]
        ax1.axhline(y=z0, color='yellow', linestyle='--')

        # Overlay structure contours (same 0.5 boundary level as the axial view)
        for name in struct_names:
            mask = _array(struct_dict[name])
            mask_cropped = mask[y_start:y_end, x_start:x_end, z_start:z_end]
            slm_full = mask_cropped[y_line, :, :].T  # (Z, X) like sl_full
            if slm_full.any():
                ax1.contour(
                    X_cor,
                    Z_cor,
                    slm_full,
                    levels=[0.5],
                    colors=[color_cycle[name]],
                    linewidths=1
                )

        # Plane intersections at the absolute physical y of the coronal row.
        y_phys = (y_line + y_start) * sy
        for i, (theta, phi, L) in enumerate(plane_coeffs_list):
            nx = np.sin(phi) * np.cos(theta)
            ny = np.sin(phi) * np.sin(theta)
            nz = np.cos(phi)
            F = nx * X_cor + ny * y_phys + nz * Z_cor - L
            ax1.contour(
                X_cor,
                Z_cor,
                F,
                levels=[0],
                colors=[plane_colors[i % len(plane_colors)]],
                linewidths=1
            )

        plt.show()

    # interactive sliders for Z (axial) and Y (coronal) - adjusted for cropped dimensions
    widgets.interact(
        view_slice_axial,
        z_index = widgets.IntSlider(min=0, max=D_crop-1, step=1, value=D_crop//2,
                                    description='Z slice', continuous_update=False),
        y_line  = widgets.IntSlider(min=0, max=H_crop-1, step=1, value=H_crop//2,
                                    description='Y slice', continuous_update=False),
    )


def plot_axial_coronal(
    struct_dict: dict,
    voxel_size: tuple,
    z_index: int,
    y_line: int,
    save_path: str,
    plane_coeffs_list: list = None,
    optimization_methods_list: list = None,
    axial_crop: tuple = None,
    coronal_crop: tuple = None,
    crop_size: int = None,
    figsize: tuple = (12, 6)
):
    """
    Plot axial and coronal views side by side and save as PDF, with adjustable cropping.

    ``z_index`` and ``y_line`` are absolute (uncropped) voxel indices.
    - The axial view is cropped in Y/X only.
    - The coronal view is cropped in X/Z only.
    - Each view carries a yellow guide line marking the other view's slice.
      The line is drawn only when that slice lies inside the current crop.

    Parameters
    ----------
    struct_dict : dict
        Must contain 'Image' (ny,nx,nz array) and any mask arrays.
    voxel_size : tuple
        (sy, sx, sz) spacing in mm along Y, X, Z axes.
    z_index : int
        Axial slice index.
    y_line : int
        Y-index for coronal view.
    save_path : str
        Path to output PDF.
    plane_coeffs_list : list of (theta, phi, L), optional
        Planes to overlay.
    optimization_methods_list : list of str, optional
        Labels for plane overlays (missing labels default to ``Plane{i}``).
    axial_crop : (x_min, x_max, y_min, y_max), optional
        Pixel indices to crop axial view: X lower/upper, Y lower/upper.
        If None and crop_size is provided, uses center of mass based cropping.
    coronal_crop : (x_min, x_max, z_min, z_max), optional
        Pixel indices to crop coronal view: X lower/upper, Z lower/upper.
        If None and crop_size is provided, uses center of mass based cropping.
    crop_size : int, optional
        Half-width in pixels around center of mass of Body mask for automatic cropping.
        Only used if axial_crop/coronal_crop are None and 'Body' is in struct_dict.
    figsize : (w,h)
        Figure size in inches.
    """
    if plane_coeffs_list is None:
        plane_coeffs_list = []
    if optimization_methods_list is None:
        optimization_methods_list = []

    def _arr(obj):
        # Structures may be stored bare or as the first element of a tuple/list.
        return obj[0] if isinstance(obj, (tuple, list)) else obj

    # Load image
    image = _arr(struct_dict['Image'])
    H, W, D = image.shape
    sy, sx, sz = voxel_size

    # Center of mass based cropping if crop_size is provided and Body mask exists
    if crop_size and 'Body' in struct_dict and (axial_crop is None or coronal_crop is None):
        body = _arr(struct_dict['Body'])
        # Marginal sums give the per-axis centre of mass of the body mask.
        sum_y = body.sum(axis=(1, 2)); sum_x = body.sum(axis=(0, 2)); sum_z = body.sum(axis=(0, 1))
        cy = int(round((np.arange(H)*sum_y).sum()/(sum_y.sum() or 1)))
        cx = int(round((np.arange(W)*sum_x).sum()/(sum_x.sum() or 1)))
        cz = int(round((np.arange(D)*sum_z).sum()/(sum_z.sum() or 1)))

        auto_axial_crop = (cx-crop_size, cx+crop_size, cy-crop_size, cy+crop_size)
        auto_coronal_crop = (cx-crop_size, cx+crop_size, cz-crop_size, cz+crop_size)
    else:
        auto_axial_crop = None
        auto_coronal_crop = None

    # Explicit crops win over auto-generated ones; default is the full extent.
    ax_x0, ax_x1, ax_y0, ax_y1 = axial_crop or auto_axial_crop or (0, W, 0, H)
    ax_x0, ax_x1 = max(0, ax_x0), min(W, ax_x1)
    ax_y0, ax_y1 = max(0, ax_y0), min(H, ax_y1)

    cor_x0, cor_x1, cor_z0, cor_z1 = coronal_crop or auto_coronal_crop or (0, W, 0, D)
    cor_x0, cor_x1 = max(0, cor_x0), min(W, cor_x1)
    cor_z0, cor_z1 = max(0, cor_z0), min(D, cor_z1)

    # Crop arrays (axial keeps full Z, coronal keeps full Y).
    ax_img = image[ax_y0:ax_y1, ax_x0:ax_x1, :]
    cor_img = image[:, cor_x0:cor_x1, cor_z0:cor_z1]

    # Physical coords (mm), in absolute image coordinates.
    x_ax = np.arange(ax_x0, ax_x1)*sx
    y_ax = np.arange(ax_y0, ax_y1)*sy
    X_ax, Y_ax = np.meshgrid(x_ax, y_ax)

    x_cor = np.arange(cor_x0, cor_x1)*sx
    z_cor = np.arange(cor_z0, cor_z1)*sz
    X_cor, Z_cor = np.meshgrid(x_cor, z_cor)

    # Setup figure
    fig, (ax0, ax1) = plt.subplots(1, 2, figsize=figsize, constrained_layout=True)

    # ----- Axial -----
    ax0.imshow(
        ax_img[:, :, z_index], cmap='gray',
        extent=[x_ax[0], x_ax[-1], y_ax[-1], y_ax[0]]
    )
    ax0.set_xlabel('X (mm)'); ax0.set_ylabel('Y (mm)')
    # Guide at the coronal row; y_line is absolute, so use its absolute mm
    # position, and skip it if the row is outside the axial crop (a negative
    # relative index would otherwise wrap around silently).
    if ax_y0 <= y_line < ax_y1:
        ax0.axhline(y=y_line * sy, color='yellow', linestyle='--')

    # Overlay contours
    struct_names = [k for k in struct_dict if k != 'Image']
    # plt.cm.get_cmap was removed in Matplotlib 3.9; plt.get_cmap is supported.
    cmap10 = plt.get_cmap('tab10').colors
    col = {n: cmap10[i % len(cmap10)] for i, n in enumerate(struct_names)}
    for name in struct_names:
        mask = _arr(struct_dict[name])[ax_y0:ax_y1, ax_x0:ax_x1, :]
        sl = mask[:, :, z_index]
        if sl.any():
            ax0.contour(X_ax, Y_ax, sl, levels=[0.5], colors=[col[name]], linewidths=1)

    # Planes: on the slice z the plane reduces to nx*x + ny*y = L - nz*z.
    plane_cols = ['red', 'purple', 'cyan']
    for i, (theta, phi, L) in enumerate(plane_coeffs_list):
        nx = np.sin(phi)*np.cos(theta); ny = np.sin(phi)*np.sin(theta); nz = np.cos(phi)
        C = L - nz*(z_index*sz)
        F = nx*X_ax + ny*Y_ax - C
        ax0.contour(X_ax, Y_ax, F, levels=[0],
                    colors=[plane_cols[i % len(plane_cols)]], linewidths=1)

    # ----- Coronal -----
    cor_slice = cor_img[y_line, :, :].T  # (Z, X) so Z runs vertically
    ax1.imshow(
        cor_slice,
        cmap='gray', origin='lower',
        extent=[x_cor[0], x_cor[-1], z_cor[0], z_cor[-1]]
    )
    ax1.set_xlabel('X (mm)'); ax1.set_ylabel('Z (mm)')
    # Guide at the axial slice. The old code indexed an *uncropped* z axis with
    # a crop-relative index, misplacing the line whenever cor_z0 > 0.
    if cor_z0 <= z_index < cor_z1:
        ax1.axhline(y=z_index * sz, color='yellow', linestyle='--')

    for name in struct_names:
        mask = _arr(struct_dict[name])
        m2 = mask[y_line, cor_x0:cor_x1, cor_z0:cor_z1].T
        if m2.any():
            ax1.contour(X_cor, Z_cor, m2, levels=[0.5], colors=[col[name]], linewidths=1)

    # Plane intersections at the absolute physical y of the coronal row.
    y_phys = y_line*sy
    for i, (theta, phi, L) in enumerate(plane_coeffs_list):
        nx = np.sin(phi)*np.cos(theta); ny = np.sin(phi)*np.sin(theta); nz = np.cos(phi)
        F2 = nx*X_cor + ny*y_phys + nz*Z_cor - L
        ax1.contour(X_cor, Z_cor, F2, levels=[0],
                    colors=[plane_cols[i % len(plane_cols)]], linewidths=1)

    # Legend (plane colours cycle so more than three planes do not raise IndexError)
    hds = [Line2D([0], [0], color=col[n], lw=2, label=n) for n in struct_names]
    for i in range(len(plane_coeffs_list)):
        label = optimization_methods_list[i] if i < len(optimization_methods_list) else f'Plane{i}'
        hds.append(Line2D([0], [0], color=plane_cols[i % len(plane_cols)], lw=2, label=label))
    ax0.legend(handles=hds, loc='upper right', fontsize=11, frameon=True)

    # Save this figure explicitly rather than whatever pyplot considers current.
    fig.savefig(save_path, format='pdf', bbox_inches='tight', pad_inches=0)
    plt.close(fig)



def save_3d_slice_as_vector(array3d, save_dir, index, factor, filename):
    """
    Save a 2D slice of a 3D array as a grayscale vector graphic (PDF).

    Parameters
    ----------
    array3d : numpy.ndarray
        3D array of shape (M, N, P).
    save_dir : str
        Directory where the vector graphic will be saved (created if missing).
    index : int
        Index along the third dimension specifying which slice to plot.
    factor : int
        Number of pixels to crop from each border (both top/bottom and left/right).
        Values <= 0 disable cropping.
    filename : str
        Output file name without extension; ``.pdf`` is appended.

    Raises
    ------
    ValueError
        If ``index`` is out of range, or if ``factor`` would crop away the
        entire slice. Both checks run before anything is written to disk.
    """
    # Validate index
    if not (0 <= index < array3d.shape[2]):
        raise ValueError(f"Index {index} is out of bounds for array with depth {array3d.shape[2]}.")

    # Cropping `factor` pixels from both sides leaves nothing when
    # 2 * factor >= the smaller in-plane dimension; fail clearly instead of
    # letting matplotlib choke on an empty image.
    rows, cols = array3d.shape[:2]
    if factor > 0 and 2 * factor >= min(rows, cols):
        raise ValueError(
            f"Crop factor {factor} removes the whole slice of shape {(rows, cols)}."
        )

    # Ensure output directory exists
    os.makedirs(save_dir, exist_ok=True)

    # Extract the specified slice
    slice_2d = array3d[:, :, index]

    # Crop borders by factor if needed
    if factor > 0:
        slice_2d = slice_2d[factor:-factor, factor:-factor]

    # Create the figure and axes
    fig, ax = plt.subplots()
    ax.imshow(slice_2d, cmap='gray', interpolation='none', aspect='equal')
    ax.axis('off')  # Hide axes for a clean image

    # Adjust layout to remove padding/margins
    fig.tight_layout(pad=0)

    # Define output filepath
    filepath = os.path.join(save_dir, f"{filename}.pdf")

    # Save as vector graphic (PDF)
    fig.savefig(filepath, format='pdf', bbox_inches='tight')
    plt.close(fig)

    print(f"Saved vector graphic {filename} to {filepath}")




# ─── Master mapping: every structure → specific color ─────────────────────────
# Fixed matplotlib colour per named structure. DEFAULT_STRUCT_COLOR is
# presumably the fallback for structures not listed here — confirm at call sites.
STRUCTURE_COLOR_MAP = {
    "GTV-T": "tab:orange",
    "Mandible": "tab:pink",
    "Body": "tab:blue",
    "Spinal Cord": "tab:green",
}
DEFAULT_STRUCT_COLOR = "purple"


def plot_and_save_axial_slice(
    struct_dict: dict,
    voxel_size: tuple,
    z_indices,
    save_path: str,
    crop_size: int = 50,
    plane_coeffs_list: list = None,
    optimization_methods_list: list = None,
    figsize_per: tuple = (6, 6)
):
    """
    Plot one or multiple axial slices. If `z_indices` is a list or tuple, a single figure
    with side-by-side subplots is created. If an integer, only that slice is plotted.

    Parameters
    ----------
    struct_dict : dict
        Must include 'Image' (ny,nx,nz array) and any structure masks keyed by name.
    voxel_size : tuple
        (sy, sx, sz) spacing in mm for Y, X, Z axes.
    z_indices : int or sequence of int
        One or multiple axial slice indices.
    save_path : str
        Output file path (including .pdf).
    crop_size : int, optional
        Half-window in voxels around the body center-of-mass for cropping. Default 50.
        The 'Body' mask is used when present, otherwise ``Image > 0``.
    plane_coeffs_list : list of (theta, phi, L), optional
        Planes to overlay on each slice.
    optimization_methods_list : list of str, optional
        Legend labels for the plane overlays.
    figsize_per : (width, height), optional
        Size in inches for each subplot.
    """
    # Normalize z_indices
    if isinstance(z_indices, (int, np.integer)):
        z_list = [z_indices]
    else:
        z_list = list(z_indices)
    n = len(z_list)

    # Defaults
    if plane_coeffs_list is None:
        plane_coeffs_list = []
    if optimization_methods_list is None:
        optimization_methods_list = []

    def _arr(obj):
        # Structures may be stored bare or as the first element of a tuple/list.
        return obj[0] if isinstance(obj, (tuple, list)) else obj

    # Load image
    image = _arr(struct_dict['Image'])
    ny, nx, nz = image.shape
    sy, sx, sz = voxel_size

    # Compute crop bounds around the Y/X centre of mass of the body mask.
    if crop_size and crop_size > 0:
        # Only build the `image > 0` fallback when no Body mask is supplied.
        body = _arr(struct_dict['Body']) if 'Body' in struct_dict else image > 0
        sum_y = body.sum(axis=(1, 2)); sum_x = body.sum(axis=(0, 2))
        cy = int(round((np.arange(ny)*sum_y).sum()/(sum_y.sum() or 1)))
        cx = int(round((np.arange(nx)*sum_x).sum()/(sum_x.sum() or 1)))
        y0, y1 = max(0, cy-crop_size), min(ny, cy+crop_size)
        x0, x1 = max(0, cx-crop_size), min(nx, cx+crop_size)
    else:
        y0, y1, x0, x1 = 0, ny, 0, nx

    # Crop data (full Z range kept, so z indices stay absolute)
    img_c = image[y0:y1, x0:x1, :]

    # Physical coordinates for plotting (absolute mm)
    y_mm = np.arange(y0, y1) * sy
    x_mm = np.arange(x0, x1) * sx
    Xg, Yg = np.meshgrid(x_mm, y_mm)

    # Prepare figure and axes
    fig_w, fig_h = figsize_per[0]*n, figsize_per[1]
    fig, axes = plt.subplots(1, n, figsize=(fig_w, fig_h), constrained_layout=True)
    if n == 1:
        axes = [axes]

    # Precompute structure color mapping.
    # plt.cm.get_cmap was removed in Matplotlib 3.9; plt.get_cmap is supported.
    struct_names = [name for name in struct_dict if name != 'Image']
    cmap10 = plt.get_cmap('tab10').colors
    struct_colors = {name: cmap10[i % len(cmap10)] for i, name in enumerate(struct_names)}

    # A single colour is used for every plane (cycled by index). Defined once,
    # outside the loop, because the legend below also needs it.
    plane_colors = ['red']

    # Loop through each slice index
    for ax, z in zip(axes, z_list):
        # Show image
        ax.imshow(
            img_c[:, :, z], cmap='gray',
            extent=[x_mm[0], x_mm[-1], y_mm[-1], y_mm[0]]
        )
        # Overlay structure contours
        for name in struct_names:
            mask = _arr(struct_dict[name])
            mc = mask[y0:y1, x0:x1, :]
            sl = mc[:, :, z]
            if sl.any():
                ax.contour(
                    Xg, Yg, sl,
                    levels=[0.5],
                    colors=[struct_colors[name]],
                    linewidths=1.5
                )
        # Overlay plane contours: on slice z the plane is nx*x + ny*y = L - nz*z.
        for i, (theta, phi, L) in enumerate(plane_coeffs_list):
            nx_f = np.sin(phi)*np.cos(theta)
            ny_f = np.sin(phi)*np.sin(theta)
            nz_f = np.cos(phi)
            C = L - nz_f*(z*sz)
            F = nx_f*Xg + ny_f*Yg - C
            ax.contour(
                Xg, Yg, F,
                levels=[0],
                colors=[plane_colors[i % len(plane_colors)]],
                linewidths=1.5
            )
        # Fix aspect and remove axes
        ax.set_xlim(x_mm[0], x_mm[-1])
        ax.set_ylim(y_mm[-1], y_mm[0])
        ax.set_aspect('equal')
        ax.axis('off')

    # Build legend on first axis
    handles = [Line2D([0], [0], color=struct_colors[n], lw=2, label=n)
               for n in struct_names]
    for i, method in enumerate(optimization_methods_list):
        c = plane_colors[i % len(plane_colors)]
        handles.append(Line2D([0], [0], color=c, lw=2, label=method))
    axes[0].legend(handles=handles, loc='upper right', fontsize=11, frameon=True)

    # Save figure
    fig.savefig(save_path, format='pdf', bbox_inches='tight', pad_inches=0)
    plt.close(fig)



import numpy as np
import matplotlib.pyplot as plt

# Contour colours for the three GTV-T positions, matched in order to the
# shifts [max, mid, 0], i.e. left to right in plot_gtvt_positions.
POSITION_COLORS = [
    "tab:blue",
    "tab:orange",
    "tab:green",
]
# Colour of the plane intersection line.
PLANE_COLOR = "red"


def plot_gtvt_positions(
    struct_dict: dict,
    voxel_size: tuple,
    z_index: int,
    plane_coeffs: tuple,
    save_path: str,
    shifts: list = None,
    crop_size: int = None,
    figsize_per: tuple = (5, 5)
):
    """
    Save a three-panel PDF of one axial slice with the GTV-T contour displaced by different amounts.

    Each panel shows ``Image[:, :, z_index]`` with the GTV-T contour moved
    against the in-plane (XY) direction of the plane normal. The shift is
    ``shifts[k]`` scaled by the X/Y voxel spacing, and ``POSITION_COLORS[k]``
    gives the contour colour. The plane's intersection with the slice is
    drawn in ``PLANE_COLOR`` on every panel. Axis limits are fixed to the
    (optionally cropped) image extent, so the line is clipped to it visually.

    Parameters
    ----------
    struct_dict : dict
        Must contain 'Image' (ny,nx,nz array) and 'GTV-T' mask.
        Optionally 'Body' for cropping around center of mass.
    voxel_size : tuple
        (sy, sx, sz) voxel spacing in mm.
    z_index : int
        Axial slice index to display.
    plane_coeffs : tuple (theta, phi, L)
        Plane normal spherical coords (radians) and offset L in mm.
    save_path : str
        Output PDF file path.
    shifts : list of int, optional
        Voxel offsets along negative normal for the three positions. Default [20,10,0].
    crop_size : int, optional
        Half-window (voxels) around body COM to crop slice. If None, no crop.
        Cropping also requires a 'Body' entry.
    figsize_per : (w,h), optional
        Figure size (inches) per subplot.

    Raises
    ------
    ValueError
        If ``shifts`` does not contain exactly three values.
    """
    def _unwrap(entry):
        # Structures may be stored bare or as the first item of a tuple/list.
        if isinstance(entry, (tuple, list)):
            return entry[0]
        return entry

    shifts = [20, 10, 0] if shifts is None else shifts
    if len(shifts) != 3:
        raise ValueError("shifts must be a list of three values")

    image = _unwrap(struct_dict['Image'])
    gtv_mask = _unwrap(struct_dict['GTV-T'])
    n_y, n_x, _ = image.shape
    sp_y, sp_x, sp_z = voxel_size

    # Y/X window around the body centre of mass, or the whole slice.
    if crop_size and 'Body' in struct_dict:
        body = _unwrap(struct_dict['Body'])
        prof_y = body.sum(axis=(1, 2))
        prof_x = body.sum(axis=(0, 2))
        com_y = int(round((np.arange(n_y) * prof_y).sum() / (prof_y.sum() or 1)))
        com_x = int(round((np.arange(n_x) * prof_x).sum() / (prof_x.sum() or 1)))
        row_lo, row_hi = max(0, com_y - crop_size), min(n_y, com_y + crop_size)
        col_lo, col_hi = max(0, com_x - crop_size), min(n_x, com_x + crop_size)
    else:
        row_lo, row_hi, col_lo, col_hi = 0, n_y, 0, n_x

    image_crop = image[row_lo:row_hi, col_lo:col_hi, :]
    gtv_crop = gtv_mask[row_lo:row_hi, col_lo:col_hi, :]

    # Absolute physical coordinates (mm) of the cropped pixels.
    ys_mm = np.arange(row_lo, row_hi) * sp_y
    xs_mm = np.arange(col_lo, col_hi) * sp_x
    grid_x, grid_y = np.meshgrid(xs_mm, ys_mm)

    # Unit plane normal from spherical angles.
    theta, phi, offset = plane_coeffs
    n_vec_x = np.sin(phi) * np.cos(theta)
    n_vec_y = np.sin(phi) * np.sin(theta)
    n_vec_z = np.cos(phi)

    # Direction of the normal projected onto the XY plane (zero if the normal is axial).
    xy_len = np.hypot(n_vec_x, n_vec_y)
    if xy_len > 0:
        dir_x = n_vec_x / xy_len
        dir_y = n_vec_y / xy_len
    else:
        dir_x = 0
        dir_y = 0

    # On the slice z = z_index * sp_z the plane reduces to n_x*x + n_y*y = line_const.
    line_const = offset - n_vec_z * (z_index * sp_z)

    x_lo, x_hi = xs_mm.min(), xs_mm.max()
    y_lo, y_hi = ys_mm.min(), ys_mm.max()

    # The intersection line does not depend on the shift, so build it once.
    if abs(n_vec_y) > 1e-6:
        line_x = np.array([x_lo, x_hi])
        line_y = (line_const - n_vec_x * line_x) / n_vec_y
    else:
        # Nearly vertical line (or degenerate normal): x is constant.
        x_at = line_const / n_vec_x if abs(n_vec_x) > 1e-6 else 0
        line_x = np.array([x_at, x_at])
        line_y = np.array([y_lo, y_hi])

    panel_count = 3
    fig, axes = plt.subplots(
        1, panel_count,
        figsize=(figsize_per[0] * panel_count, figsize_per[1]),
        constrained_layout=True,
    )

    for panel_ax, shift, gtv_color in zip(axes, shifts, POSITION_COLORS):
        panel_ax.imshow(
            image_crop[:, :, z_index], cmap='gray',
            extent=[x_lo, x_hi, y_hi, y_lo]
        )

        # Move the GTV-T contour against the projected normal.
        offset_x = -dir_x * shift * sp_x
        offset_y = -dir_y * shift * sp_y
        gtv_slice = gtv_crop[:, :, z_index]
        if gtv_slice.any():
            panel_ax.contour(
                grid_x + offset_x, grid_y + offset_y, gtv_slice,
                levels=[0.5], colors=[gtv_color], linewidths=2
            )

        panel_ax.plot(line_x, line_y, '-', color=PLANE_COLOR, lw=2)

        # Lock the view to the image extent (y reversed to match imshow).
        panel_ax.set_xlim(x_lo, x_hi)
        panel_ax.set_ylim(y_hi, y_lo)
        panel_ax.set_aspect('equal')
        panel_ax.axis('off')

    fig.savefig(save_path, format='pdf', bbox_inches='tight', pad_inches=0)
    plt.close(fig)




def plot_huber_loss(deltas, save_path, residual_range=(-5.0, 5.0), num_points=500):
    """
    Plot the Huber loss function for different delta values and save as a PDF.

    Parameters
    ----------
    deltas : array-like of float
        Sequence of delta thresholds for the Huber loss; one curve is drawn
        per value.
    save_path : str
        Filesystem path (including .pdf extension) where the figure will be saved.
    residual_range : tuple of float, optional
        (min, max) range of residual values over which to evaluate the loss.
        Default is (-5.0, 5.0).
    num_points : int, optional
        Number of points in the residual grid. Default is 500.

    Returns
    -------
    None
        The figure is written to ``save_path`` and then closed.

    Notes
    -----
    The Huber loss Lδ(r) is defined as::

        Lδ(r) = 0.5 * r²                 if |r| ≤ δ
                δ * (|r| − 0.5 * δ)      if |r| > δ
    """
    # Residual grid and the δ-independent terms are computed once, outside
    # the per-δ loop.
    r = np.linspace(residual_range[0], residual_range[1], num_points)
    abs_r = np.abs(r)
    quadratic = 0.5 * r**2

    # Explicit Figure/Axes so saving and closing target exactly this figure,
    # independent of pyplot's "current figure" state.
    fig, ax = plt.subplots()
    try:
        for delta in deltas:
            linear = delta * (abs_r - 0.5 * delta)
            huber = np.where(abs_r <= delta, quadratic, linear)
            ax.plot(r, huber, label=f'δ = {delta}')

        ax.set_title('Huber Loss for Various δ')
        ax.set_xlabel('Residual (r)')
        ax.set_ylabel('Huber Loss Lδ(r)')
        ax.legend()
        fig.tight_layout()

        # Save figure as PDF
        fig.savefig(save_path, format='pdf')
    finally:
        # Release the figure even if saving fails (e.g. invalid path), so
        # repeated calls do not accumulate open figures.
        plt.close(fig)
    


## Plane Parameter Tweaking


In [306]:
# Convenience helper: call get_params() in a cell to read the widget state.
def get_params():
    """
    Print and return the plane parameters currently set on the interactive widget.

    The sliders are read from ``display_interactive_plane_widget.sliders``,
    which exists once ``display_interactive_plane_widget`` has been run. The
    values are printed in a readable block plus copy-ready tuple formats.

    Returns
    -------
    dict or None
        Keys 'theta_rad', 'phi_rad', 'L_mm', 'theta_deg', 'phi_deg',
        'z_index', 'y_line', 'tuple_rad', 'tuple_deg'. ``None`` when the
        widget has not been created or any error occurs (the reason is printed).
    """
    not_found = object()
    try:
        slider_map = getattr(display_interactive_plane_widget, 'sliders', not_found)
        if slider_map is not_found:
            for message in (
                "❌ Widget not found!",
                "Make sure you have:",
                "1. Executed the display_interactive_plane_widget function",
                "2. The widget is currently active",
            ):
                print(message)
            return None

        # Slider values, in the same order the widget defines them.
        theta_deg, phi_deg, L_mm, z_idx, y_idx = (
            slider_map[key].value for key in ('theta', 'phi', 'L', 'z', 'y')
        )
        theta_rad, phi_rad = np.deg2rad(theta_deg), np.deg2rad(phi_deg)

        heavy_rule = "=" * 60
        light_rule = "-" * 60

        print(heavy_rule)
        print("CURRENT WIDGET PARAMETERS")
        print(heavy_rule)
        print(f"θ (Theta): {theta_deg:7.2f}° = {theta_rad:8.4f} rad")
        print(f"φ (Phi):   {phi_deg:7.2f}° = {phi_rad:8.4f} rad")
        print(f"L (Dist):  {L_mm:7.2f} mm")
        print(f"Z slice:   {z_idx}")
        print(f"Y line:    {y_idx}")
        print(light_rule)
        print("COPY-READY FORMATS:")
        print(light_rule)
        print("Radians: " + f"({theta_rad:.4f}, {phi_rad:.4f}, {L_mm:.2f})")
        print("Degrees: " + f"({theta_deg:.2f}, {phi_deg:.2f}, {L_mm:.2f})")
        print(heavy_rule)

        return dict(
            theta_rad=theta_rad,
            phi_rad=phi_rad,
            L_mm=L_mm,
            theta_deg=theta_deg,
            phi_deg=phi_deg,
            z_index=z_idx,
            y_line=y_idx,
            tuple_rad=(theta_rad, phi_rad, L_mm),
            tuple_deg=(theta_deg, phi_deg, L_mm),
        )
    except Exception as exc:
        # Best-effort notebook helper: report the problem instead of raising.
        print(f"❌ Error: {exc}")
        return None


def display_interactive_plane_widget(
    struct_dict: dict,
    voxel_size: tuple,
    initial_plane_params: tuple,
    theta_range: tuple = (-180, 180),
    phi_range: tuple = (0, 180),
    L_range: tuple = None,
    crop_radius: int = 150,
):
    """
    Display interactive axial and coronal views with sliders to adjust plane parameters in real-time.

    The image is cropped in-plane (Y/X) around the centre of mass of the
    'Body' structure when available, otherwise around the geometric centre.
    The full Z range is kept. Every structure other than 'Image' is drawn
    as a contour. The plane

        nx*x + ny*y + nz*z = L,   n = (sin φ cos θ, sin φ sin θ, cos φ)

    is drawn in red on both views.

    Parameters
    ----------
    struct_dict : dict
        Dictionary of structures, must include 'Image'. Values are arrays, or
        tuples/lists whose first element is the array.
    voxel_size : tuple of float
        Spacing in mm along each axis (sy, sx, sz).
    initial_plane_params : tuple
        Initial plane parameters (theta_rad, phi_rad, L_mm).
    theta_range : tuple
        Range for theta slider in degrees (min, max). Default (-180, 180).
    phi_range : tuple
        Range for phi slider in degrees (min, max). Default (0, 180).
    L_range : tuple, optional
        Range for L slider in mm (min, max). If None, ±(largest physical
        image extent) is used.
    crop_radius : int, optional
        Half-width, in voxels, of the in-plane crop window. Default 150.

    Notes
    -----
    After the call, the sliders are attached as
    ``display_interactive_plane_widget.sliders`` (keys 'theta', 'phi', 'L',
    'z', 'y'). The current plane (radians, mm) is returned by
    ``display_interactive_plane_widget.get_current_params()``.
    """
    import ipywidgets as widgets

    def _array(obj):
        # Structures may be stored as (array, voxel_size) pairs; keep the array.
        return obj[0] if isinstance(obj, (tuple, list)) else obj

    def _axis_crop(center, size, min_size=50):
        """Return (start, end, center) of the crop window along one in-plane axis."""
        start = max(0, center - crop_radius)
        end = min(size, center + crop_radius)
        # Window clipped by an image border: slide it inward to keep the
        # full width when the image is large enough.
        if end - start < 2 * crop_radius and size >= 2 * crop_radius:
            if start == 0:
                end = min(size, 2 * crop_radius)
            elif end == size:
                start = max(0, size - 2 * crop_radius)
        # Degenerate window: fall back to a min_size window around the
        # geometric centre.
        if end - start < min_size:
            center = size // 2
            start = max(0, center - min_size // 2)
            end = min(size, start + min_size)
        return start, end, center

    # Load image and get dimensions
    image = _array(struct_dict['Image'])
    H, W, D = image.shape
    sy, sx, sz = voxel_size

    # Sliders work in degrees; the plane parameters are given in radians.
    theta_init_deg = np.rad2deg(initial_plane_params[0])
    phi_init_deg = np.rad2deg(initial_plane_params[1])
    L_init = initial_plane_params[2]

    if L_range is None:
        # Estimate a reasonable L range from the image's physical extent.
        max_dim = max(H * sy, W * sx, D * sz)
        L_range = (-max_dim, max_dim)

    # Crop centre: mean voxel position of the Body mask when available,
    # otherwise the geometric centre.
    center_y, center_x = H // 2, W // 2
    if 'Body' in struct_dict:
        try:
            body_mask = _array(struct_dict['Body'])
            if body_mask.any():
                coords = np.where(body_mask > 0)
                center_y = int(np.mean(coords[0]))
                center_x = int(np.mean(coords[1]))
                print(f"Using Body structure center of mass: Y={center_y}, X={center_x}")
            else:
                print(f"Body structure is empty, using geometric center: Y={center_y}, X={center_x}")
        except Exception as e:
            # Best effort: any problem with the Body mask falls back to the
            # geometric centre instead of aborting the widget.
            center_y, center_x = H // 2, W // 2
            print(f"Error accessing Body structure ({e}), using geometric center: Y={center_y}, X={center_x}")
    else:
        print(f"No Body structure found, using geometric center: Y={center_y}, X={center_x}")

    y_start, y_end, center_y = _axis_crop(center_y, H)
    x_start, x_end, center_x = _axis_crop(center_x, W)
    z_start, z_end = 0, D  # keep the full Z range

    print(f"Cropping around center of mass: Y=[{center_y}], X=[{center_x}]")
    print(f"Crop bounds: Y=[{y_start}:{y_end}], X=[{x_start}:{x_end}], Z=[{z_start}:{z_end}]")
    print(f"Cropped size: {y_end-y_start} x {x_end-x_start} x {z_end-z_start} voxels")

    image_cropped = image[y_start:y_end, x_start:x_end, z_start:z_end]
    H_crop, W_crop, D_crop = image_cropped.shape

    # Physical coordinates (mm) of the cropped voxels
    y_mm = np.arange(y_start, y_end) * sy
    x_mm = np.arange(x_start, x_end) * sx
    z_mm = np.arange(z_start, z_end) * sz

    # Meshgrids for contouring
    X_ax, Y_ax = np.meshgrid(x_mm, y_mm)    # axial view
    X_cor, Z_cor = np.meshgrid(x_mm, z_mm)  # coronal view

    # Structures to overlay, each cropped once (slicing does not depend on
    # the slider values, so there is no need to redo it on every redraw).
    struct_names = [s for s in struct_dict if s != 'Image']
    cropped_masks = {
        name: _array(struct_dict[name])[y_start:y_end, x_start:x_end, z_start:z_end]
        for name in struct_names
    }
    cmap10 = plt.colormaps['tab10'].colors
    color_cycle = {n: cmap10[i % len(cmap10)] for i, n in enumerate(struct_names)}

    def update_plot(theta_deg, phi_deg, L_mm, z_index, y_line):
        """Redraw the axial and coronal views for the given slider values."""
        # Safeguard clamps (the sliders already restrict the range); done up
        # front so both views use the same indices.
        if z_index >= D_crop:
            z_index = D_crop // 2
        if y_line >= H_crop:
            y_line = H_crop // 2

        # Plane normal from spherical angles
        theta_rad = np.deg2rad(theta_deg)
        phi_rad = np.deg2rad(phi_deg)
        nx = np.sin(phi_rad) * np.cos(theta_rad)
        ny = np.sin(phi_rad) * np.sin(theta_rad)
        nz = np.cos(phi_rad)

        fig, ax = plt.subplots(1, 2, figsize=(15, 7), constrained_layout=True)

        # --- AXIAL VIEW ---
        ax0 = ax[0]
        ax0.imshow(image_cropped[:, :, z_index], cmap='gray',
                   extent=[x_mm[0], x_mm[-1], y_mm[-1], y_mm[0]])
        ax0.set_title(f'Axial z={z_index} (z={z_index*sz:.1f} mm)  ' +
                      f'θ={theta_deg:.1f}°, φ={phi_deg:.1f}°, L={L_mm:.1f}mm')
        ax0.set_xlabel('X (sagittal) mm')
        ax0.set_ylabel('Y (coronal) mm')

        # Horizontal guide line at the current coronal (Y) slice
        ax0.axhline(y=y_mm[y_line], color='yellow', linestyle='--', alpha=0.7)

        for name, mask_cropped in cropped_masks.items():
            if z_index < mask_cropped.shape[2]:
                sl = mask_cropped[:, :, z_index]
                if sl.any():
                    # levels=[0.5] draws the single mask boundary; default
                    # auto-levels would draw several offset lines.
                    ax0.contour(X_ax, Y_ax, sl, levels=[0.5],
                                colors=[color_cycle[name]],
                                linewidths=1.5, alpha=0.8)

        # Plane ∩ axial slice: nx*x + ny*y = L - nz*z
        C = L_mm - nz * z_mm[z_index]
        F_axial = nx * X_ax + ny * Y_ax - C
        ax0.contour(X_ax, Y_ax, F_axial, levels=[0], colors=['red'], linewidths=2)

        # --- CORONAL VIEW ---
        ax1 = ax[1]
        # Transpose so rows are Z and columns are X
        sl_full = image_cropped[y_line, :, :].T
        ax1.imshow(sl_full, cmap='gray', origin='lower',
                   extent=[x_mm[0], x_mm[-1], z_mm[0], z_mm[-1]])
        ax1.set_title(f'Coronal y={y_line + y_start} (y={(y_line + y_start)*sy:.1f} mm)')
        ax1.set_xlabel('X (sagittal) mm')
        ax1.set_ylabel('Z (axial) mm')

        # Horizontal guide line at the current axial (Z) slice
        ax1.axhline(y=z_mm[z_index], color='yellow', linestyle='--', alpha=0.7)

        for name, mask_cropped in cropped_masks.items():
            if y_line < mask_cropped.shape[0]:
                slm_full = mask_cropped[y_line, :, :].T
                if slm_full.any():
                    ax1.contour(X_cor, Z_cor, slm_full, levels=[0.5],
                                colors=[color_cycle[name]],
                                linewidths=1.5, alpha=0.8)

        # Plane ∩ coronal slice at fixed physical y
        y_phys = (y_line + y_start) * sy
        F_coronal = nx * X_cor + ny * y_phys + nz * Z_cor - L_mm
        ax1.contour(X_cor, Z_cor, F_coronal, levels=[0], colors=['red'], linewidths=2)

        handles = [plt.Line2D([0], [0], color=color_cycle[name], lw=2, label=name)
                   for name in struct_names]
        handles.append(plt.Line2D([0], [0], color='red', lw=2, label='MSP'))
        ax0.legend(handles=handles, loc='upper right', fontsize=10)

        plt.show()

    # Sliders; continuous_update=False redraws only when a slider is released.
    theta_slider = widgets.FloatSlider(
        value=theta_init_deg, min=theta_range[0], max=theta_range[1], step=0.1,
        description='θ (deg):', continuous_update=False, style={'description_width': 'initial'}
    )
    phi_slider = widgets.FloatSlider(
        value=phi_init_deg, min=phi_range[0], max=phi_range[1], step=0.1,
        description='φ (deg):', continuous_update=False, style={'description_width': 'initial'}
    )
    L_slider = widgets.FloatSlider(
        value=L_init, min=L_range[0], max=L_range[1], step=1.0,
        description='L (mm):', continuous_update=False, style={'description_width': 'initial'}
    )
    z_slider = widgets.IntSlider(
        value=D_crop//2, min=0, max=D_crop-1, step=1,
        description='Z slice:', continuous_update=False, style={'description_width': 'initial'}
    )
    y_slider = widgets.IntSlider(
        value=H_crop//2, min=0, max=H_crop-1, step=1,
        description='Y slice:', continuous_update=False, style={'description_width': 'initial'}
    )

    widgets.interact(update_plot,
                     theta_deg=theta_slider,
                     phi_deg=phi_slider,
                     L_mm=L_slider,
                     z_index=z_slider,
                     y_line=y_slider)

    def get_current_params():
        """Get current plane parameters in radians"""
        return (np.deg2rad(theta_slider.value),
                np.deg2rad(phi_slider.value),
                L_slider.value)

    # Expose the sliders for external access (used by get_params()).
    display_interactive_plane_widget.sliders = {
        'theta': theta_slider,
        'phi': phi_slider,
        'L': L_slider,
        'z': z_slider,
        'y': y_slider
    }
    display_interactive_plane_widget.get_current_params = get_current_params

### Mid-Sagittal Plane Detection Pipeline

In [307]:
def MSP_pipeline(base_path,
                 output_path,
                 structure_names=None,
                 slice_axis=2,
                 HU_range=None,
                 slice_range=None,
                 azimuthal=(0, 90),
                 polar=(90, 45),
                 initialization_steps=10,
                 verification_widget=False,
                 parameter_tweaking=False,
                 patient=None
                 ):
    """
    Complete pipeline for mid-sagittal plane detection across multiple patients.

    For each patient folder the pipeline:
    1. loads the structures;
    2. crops them and builds the bone image;
    3. initialises and optimises the plane;
    4. runs the robustness test;
    5. optionally opens the interactive widgets.

    Parameters
    ----------
    base_path : str
        Directory containing one sub-folder per patient. Folder names must be
        integers; non-numeric entries and plain files are ignored.
    output_path : str
        Directory where results are saved (one sub-folder per patient).
    structure_names : list of str, optional
        Structure names to load for each patient. Default ["Image"].
    slice_axis : int, optional
        Axis along which to crop (default: 2, axial).
    HU_range : list or tuple of (int, int), optional
        HU range for bone thresholding. Default [300, 1500].
    slice_range : tuple or None, optional
        (start, end) indices for cropping. If None, uses GTVp mask or full range.
    azimuthal : tuple of (float, float), optional
        Center and half-width of the azimuthal angle range in degrees used for
        parameter initialisation.
    polar : tuple of (float, float), optional
        Center and half-width of the polar angle range in degrees used for
        parameter initialisation.
    initialization_steps : int, optional
        Number of steps for each angle in initialization.
    verification_widget : bool, optional
        If True, show the scrollable verification views with the optimised plane.
    parameter_tweaking : bool, optional
        If True, show the slider widget for manually adjusting the optimised plane.
    patient : int or str, optional
        If given, process only this patient's folder. In that case the
        aggregated ``optimized_parameters.npy`` / ``objective_values.npy``
        files are not written.

    Returns
    -------
    None
    """
    # Fresh lists per call instead of shared mutable defaults.
    if structure_names is None:
        structure_names = ["Image"]
    if HU_range is None:
        HU_range = [300, 1500]

    def _use_widget_backend():
        """Equivalent of the `%matplotlib widget` magic, as plain Python."""
        from IPython import get_ipython
        shell = get_ipython()
        if shell is not None:
            shell.run_line_magic('matplotlib', 'widget')

    start_pipeline = time.perf_counter()
    optimized_parameter_list = []
    optimized_objective_value_list = []

    if patient is not None:
        folders = [str(patient)]
    else:
        # Only numeric patient directories; stray files (e.g. .DS_Store)
        # would otherwise make the int sort key raise ValueError.
        folders = sorted(
            (d for d in os.listdir(base_path)
             if d.isdigit() and os.path.isdir(os.path.join(base_path, d))),
            key=int,
        )

    for folder in folders:
        folder_path = os.path.join(base_path, folder)
        print(f"Processing folder: {folder_path}")
        output_path_patient = os.path.join(output_path, folder)
        os.makedirs(output_path_patient, exist_ok=True)

        struct_dict = load_patient_structures(folder_path, structure_names)
        # Check for a failed load before indexing into the result.
        if struct_dict is None:
            print(f"No matching files found in {folder_path}. Skipping...")
            continue
        voxel_size = struct_dict["Image"][1]

        struct_dict_cropped, interpolation_image = crop_patient_volumes(struct_dict, slice_axis=slice_axis, slice_range=slice_range)

        image, bone_ct = preprocess_bone_image(struct_dict_cropped, HU_range=HU_range)

        interpolator = get_cached_interpolator(output_path_patient, interpolation_image, voxel_size)

        initial_plane, delta = parameter_initialization(image, bone_ct, output_path_patient, interpolator, voxel_size,
                                                        azimuthal_deg_range=azimuthal, polar_deg_range=polar,
                                                        initialization_steps=initialization_steps)

        #f = compute_objective_initialization(initial_plane, bone_ct, interpolator, voxel_size)

        #plot_huber_loss([1000, 500, 300], os.path.join(output_path_patient, "huber_loss.pdf"),residual_range=(-3000, 3000), num_points=500)

        optimized_parameters, optimized_objective_values = run_or_load_optimization(output_path_patient,
                                                                                     image,
                                                                                     voxel_size,
                                                                                     delta,
                                                                                     initial_plane,
                                                                                     bone_ct,
                                                                                     interpolator,
                                                                                     optimized_parameter_list,
                                                                                     optimized_objective_value_list)

        # optimized_parameters_1, optimized_objective_values_1 = run_or_load_optimization(r"/home/loriskeller/Documents/Master Project/Results/03.08.25/Pure qudratic objective/azimuthal/-20°/4927494",
        #                                                                              image,
        #                                                                              voxel_size,
        #                                                                              delta,
        #                                                                              initial_plane,
        #                                                                              bone_ct,
        #                                                                              interpolator,
        #                                                                              optimized_parameter_list,
        #                                                                              optimized_objective_value_list)

        # NOTE(review): the robustness sweep uses hard-coded angle ranges and
        # step count; it does NOT follow the `azimuthal`, `polar` or
        # `initialization_steps` arguments. Confirm this is intended.
        robustness_test(
            image=image,
            bone=bone_ct,
            output_path=output_path_patient,
            interpolator_intensity=interpolator,
            voxel_size=voxel_size,
            azimuthal_deg_range=(0, 45),
            polar_deg_range=(90, 45),
            initialization_steps=20,
            optimal_plane_params=optimized_parameters   # optimised plane used as reference
        )

        # save_3d_slice_as_vector(image, output_path_patient, index=6, factor=100, filename = f"image")
        # save_3d_slice_as_vector(bone_ct, output_path_patient, index=6, factor=100, filename = f"bone")

        # Rename "GTVp" to "GTV-T" when it was loaded (it is absent with the
        # default structure_names).
        if 'GTVp' in struct_dict:
            struct_dict['GTV-T'] = struct_dict.pop('GTVp')
        # plot_and_save_axial_slice(
        #     struct_dict_cropped,
        #     voxel_size,
        #     z_indices=[2,7],
        #     save_path=os.path.join(r"/home/loriskeller/Documents/Master Project/Results/Images for Report", "no_mandible_vs_mandible_results.pdf"),
        #     crop_size=200,
        #     plane_coeffs_list=[optimized_parameters, optimized_parameters_1],
        #     optimization_methods_list=["MSP A", "MSP B"]
        # )

        # plot_gtvt_positions(
        #     struct_dict_cropped,
        #     voxel_size,
        #     z_index=12,
        #     plane_coeffs=optimized_parameters,
        #     save_path=os.path.join(r"/home/loriskeller/Documents/Master Project/Results/Images for Report", "gtvt_positions.pdf"),
        #     shifts=[25, 10, 0],
        #     crop_size=150,
        #     figsize_per=(5, 5)
        # )

        # optimized_parameters_manual_10587029 = (-0.06283185307179585, 1.6244956418729801, 262)
        # optimized_parameters_manual_10774767 = (-0.0017453292519943233, 1.594660423674561, 240)
        # optimized_parameters_manual_10780163 = (-0.03490658503988659, 1.5982842763770915, 269)

        # plot_axial_coronal(
        #     struct_dict_cropped,
        #     voxel_size,
        #     z_index=11,
        #     y_line=257,
        #     save_path=os.path.join(r"/home/loriskeller/Documents/Master Project/Results/Images for Report", "plot.pdf"),
        #     plane_coeffs_list=[optimized_parameters],
        #     optimization_methods_list=["MSP optimized"],
        #     crop_size=150,
        #     coronal_crop=(200, 360, 0, 512),
        #     figsize=(12, 6)
        # )

        if verification_widget:
            _use_widget_backend()
            display_scrollable_views(struct_dict_cropped, voxel_size,
                                     plane_coeffs_list=[optimized_parameters],
                                     optimization_methods_list=["MSP optimized"])

        if parameter_tweaking:
            _use_widget_backend()
            # Interactive widget with sliders to adjust plane parameters
            display_interactive_plane_widget(
                struct_dict=struct_dict_cropped,
                voxel_size=voxel_size,
                initial_plane_params=optimized_parameters,
                theta_range=(-10, 10),
                phi_range=(80, 100),
                L_range=(200, 300)
            )

    if patient is None:
        np.save(os.path.join(output_path, "optimized_parameters.npy"), np.array(optimized_parameter_list))
        np.save(os.path.join(output_path, "objective_values.npy"), np.array(optimized_objective_value_list))

    elapsed = time.perf_counter() - start_pipeline
    if patient is None:
        print(f"Pipeline completed for {len(folders)} patients in {elapsed:.2f} seconds.")
    else:
        print(f"Pipeline completed for patient {patient} in {elapsed:.2f} seconds.")

## Playground

In [308]:
# Run the pipeline for a single patient (4927494); the absolute paths below
# point at the author's machine and must be adapted elsewhere.
MSP_pipeline(
    base_path=r"/home/loriskeller/Documents/Master Project/Patient data/patient_data_complete/Patient_structures_clean",
    output_path=r"/home/loriskeller/Documents/Master Project/Results/03.08.25/test robustness/heatmaps/huber loss/20 steps",
    structure_names=["Image", "Body", "GTVp", "Spinal Cord", "Mandible"],
    slice_axis=2,
    HU_range=[300, 2800],      # bone threshold window
    slice_range=None,
    azimuthal=(0, 20),         # (center, half-width) in degrees
    polar=(90, 20),            # (center, half-width) in degrees
    initialization_steps=10,
    verification_widget=False,
    parameter_tweaking=False,
    patient=4927494,
)

Processing folder: /home/loriskeller/Documents/Master Project/Patient data/patient_data_complete/Patient_structures_clean/4927494
Loaded 'Body' from /home/loriskeller/Documents/Master Project/Patient data/patient_data_complete/Patient_structures_clean/4927494/Body.nii.gz
Loaded 'Body' from /home/loriskeller/Documents/Master Project/Patient data/patient_data_complete/Patient_structures_clean/4927494/Body.nii.gz
Loaded 'Mandible' from /home/loriskeller/Documents/Master Project/Patient data/patient_data_complete/Patient_structures_clean/4927494/Mandible.nii.gz
Loaded 'Mandible' from /home/loriskeller/Documents/Master Project/Patient data/patient_data_complete/Patient_structures_clean/4927494/Mandible.nii.gz
Loaded 'Image' from /home/loriskeller/Documents/Master Project/Patient data/patient_data_complete/Patient_structures_clean/4927494/image.nii.gz
Loaded 'Image' from /home/loriskeller/Documents/Master Project/Patient data/patient_data_complete/Patient_structures_clean/4927494/image.nii.g

KeyboardInterrupt: 

In [223]:
# Print and return the plane parameters currently set on the interactive
# widget (display_interactive_plane_widget must have been run first).
get_params()

CURRENT WIDGET PARAMETERS
θ (Theta):   -0.10° =  -0.0017 rad
φ (Phi):     91.37° =   1.5947 rad
L (Dist):   240.00 mm
Z slice:   13
Y line:    150
------------------------------------------------------------
COPY-READY FORMATS:
------------------------------------------------------------
Radians: (-0.0017, 1.5947, 240.00)
Degrees: (-0.10, 91.37, 240.00)


{'theta_rad': np.float64(-0.0017453292519943233),
 'phi_rad': np.float64(1.594660423674561),
 'L_mm': 240.0,
 'theta_deg': -0.09999999999999964,
 'phi_deg': 91.36731203309608,
 'z_index': 13,
 'y_line': 150,
 'tuple_rad': (np.float64(-0.0017453292519943233),
  np.float64(1.594660423674561),
  240.0),
 'tuple_deg': (-0.09999999999999964, 91.36731203309608, 240.0)}