# Mid-Sagittal plane algorithm

### Import relevant packages

In [None]:
import gzip
import json
import os
import pickle
import re
import tempfile
import time
import zlib

import ipywidgets as widgets
import joblib
import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
import pandas as pd
import seaborn as sns
import skimage
from joblib import Parallel, delayed, dump, load
from matplotlib.lines import Line2D
from matplotlib.widgets import Slider
from scipy import ndimage
from scipy.interpolate import RegularGridInterpolator
from scipy.ndimage import center_of_mass
from scipy.optimize import curve_fit
from scipy.optimize import minimize
from scipy.spatial import ConvexHull, distance

## Regex patterns for structure file names

In [2]:
# Regex used to recognise each structure from a NIfTI file name. The loader
# applies re.search (case-insensitive) to the file name with the .nii/.nii.gz
# extension and an optional "mask_" prefix removed, trying the keys in the
# order listed here.
patterns = dict([
    # Main CT volume: any name containing "image".
    ("image", r"image"),

    # Primary gross tumour volume (GTVp). Tolerates optional prefixes
    # (klin, vorschlag, pre-op variants, v1), separators, digits and
    # suffixes such as "new", "rimary", pre-op variants or dose tags.
    ("gtvp", r"\b(?:klin|vorschlag|pr[ae]?eop|v1)?[\s._-]*gtv[\s._-]*p[\s._-]*t?[\s._-]*\d*(?:new|rimary|pr[ae]?eop|v1|xxgy|74\.4|70|ptv1)?[\s._-]*(gy)?[\s._-]*(1a)?[\s._-]*(1b)?\b"),

    # Whole name "body" or "skin", optionally followed by a separator and one digit.
    ("body", r"(?:^body[\s._-]*\d?$|^skin[\s._-]*\d?$)"),

    # Spinal cord / canal, including the "myelon" naming and its 5mm variant.
    ("spinal cord", r"(?:spinal[\s._-]*cord$|^myelon$|myelon[\s+]*5mm|spinal[\s._-]*canal)"),

    # Any name starting with "mandib".
    ("mandibula", r"^mandib"),
])

## Data Reading

In [3]:
def pad_to_eight_characters(input_list):
    """
    Left-pad every string with zeros so that it is at least 8 characters long.

    Strings that already have 8 or more characters are left unchanged. Padding
    is done with str.zfill, so a leading sign stays in front of the zeros.

    Parameters:
    input_list (iterable of str): The strings to pad.

    Returns:
    list: A new list with the padded strings, in the original order.
    """
    return list(map(lambda text: text.zfill(8), input_list))

def load_nifti_file(file_path):
    """
    Load a NIfTI file and return its voxel data together with the voxel size.

    Parameters
    ----------
    file_path : str
        Path to the NIfTI file.

    Returns
    -------
    tuple
        (img_data, voxel_size) where img_data is the floating-point array
        returned by ``get_fdata()`` and voxel_size is the tuple returned by
        ``header.get_zooms()``.
    """
    import nibabel as nib

    nifti_img = nib.load(file_path)
    img_data = nifti_img.get_fdata()
    voxel_size = nifti_img.header.get_zooms()

    return img_data, voxel_size

def open_gzip_file(gzip_file_path):
    """
    Read a gzip-compressed file and return its decompressed content.

    Parameters
    ----------
    gzip_file_path : str
        Path to the gzip file.

    Returns
    -------
    bytes or None
        The decompressed bytes, or None if the file could not be opened,
        read or decompressed (the reason is printed).
    """
    try:
        with gzip.open(gzip_file_path, 'rb') as f_in:
            return f_in.read()
    except (OSError, EOFError, zlib.error) as e:
        # Best effort: report and signal failure with None. Only I/O and
        # decompression errors are handled here (gzip.BadGzipFile is an
        # OSError); programming errors such as a missing import propagate
        # instead of masquerading as an unreadable file.
        print(f'Error opening {gzip_file_path}: {e}')
        return None
    
def get_image_and_voxel_size_from_gzip(gzip_file_path):
    """
    Get the image array and voxel size from a gzipped NIfTI file.

    The file is decompressed with open_gzip_file, written to a uniquely named
    temporary ``.nii`` file (so concurrent calls cannot overwrite each other),
    loaded with load_nifti_file, and the temporary file is removed afterwards
    even if loading raises.

    Parameters
    ----------
    gzip_file_path : str
        Path to the gzipped NIfTI file.

    Returns
    -------
    tuple
        (img_data, voxel_size) on success; (None, None) if the file could not
        be read or its decompressed content is empty.
    """
    file_content = open_gzip_file(gzip_file_path)
    if file_content is None:
        print(f"❌ Error: Failed to read file content from '{gzip_file_path}'")
        return None, None

    fd, temp_path = tempfile.mkstemp(suffix=".nii")
    try:
        with os.fdopen(fd, 'wb') as temp_file:
            temp_file.write(file_content)

        # Debugging aid: confirm the decompressed data actually reached disk.
        file_size = os.path.getsize(temp_path)
        if file_size == 0:
            print(f"❌ Error: '{temp_path}' was written but is empty! ({gzip_file_path})")
            return None, None
        print(f"✅ Success: '{temp_path}' was written successfully ({file_size} bytes).")

        return load_nifti_file(temp_path)
    finally:
        # Always clean up; ignore failures (e.g. the file is still in use on
        # platforms that forbid deleting open files) rather than crash.
        try:
            os.remove(temp_path)
        except OSError:
            pass
    
def load_patient_structures(patient_folder, patterns):
    """
    Find and load one NIfTI file per structure pattern from a patient folder.

    The patient folder is searched recursively. Every ``.nii`` / ``.nii.gz``
    file name is reduced to a search string by removing the extension and an
    optional ``mask_`` prefix, then tested with ``re.search`` (case-insensitive)
    against the patterns in dict order. The first pattern that matches claims
    the file; patterns already loaded are skipped. Matching files are loaded
    (via get_image_and_voxel_size_from_gzip for ``.nii.gz``, load_nifti_file
    otherwise) and transposed with ``np.transpose(array, (1, 0, 2))``.

    Directories and file names are visited in sorted order, so "the first
    matching file" is reproducible regardless of the file system's listing
    order. The search stops once every pattern has been loaded.

    Parameters
    ----------
    patient_folder : str
        The folder path for the patient.
    patterns : dict
        Maps structure names (e.g. "image", "gtvp") to regex patterns used to
        match file names.

    Returns
    -------
    dict
        Maps each loaded pattern key to (img_array, voxel_size). Keys with no
        matching (or no successfully loaded) file are absent.
    """
    struct_dict = {}

    for root, dirs, files in os.walk(patient_folder):
        # Sorting dirs in place makes os.walk descend in a deterministic order.
        dirs.sort()
        for f in sorted(files):
            # Only consider NIfTI files.
            if f.endswith(".nii.gz"):
                suffix = ".nii.gz"
            elif f.endswith(".nii"):
                suffix = ".nii"
            else:
                continue

            # Bare structure name: drop the extension and an optional "mask_" prefix.
            f_search = f[:-len(suffix)]
            if f_search.startswith("mask_"):
                f_search = f_search[len("mask_"):]

            for key, pat in patterns.items():
                if key in struct_dict:
                    continue  # already loaded from an earlier file
                if re.search(pat, f_search, flags=re.IGNORECASE):
                    file_path = os.path.join(root, f)
                    if suffix == ".nii.gz":
                        img, voxel_size = get_image_and_voxel_size_from_gzip(file_path)
                    else:
                        img, voxel_size = load_nifti_file(file_path)
                    if img is not None:
                        img = np.transpose(img, (1, 0, 2))
                        struct_dict[key] = (img, voxel_size)
                        print(f"Loaded '{key}' from {file_path}")
                    # A file is assigned to at most one pattern.
                    break

            if len(struct_dict) == len(patterns):
                return struct_dict  # every structure found; no need to keep walking
    return struct_dict

def load_patient_data_from_csv(csv_file, root_folder, pat_id = None):
    """
    Load each patient listed in a CSV file and display its structures over the image.

    The CSV file is expected to have headers:
      Patient ID, Extention, Position, GTVp, Body, Mandible, Spinal Cord
    with one row per patient.

    For each processed row:
      - Rows whose Patient ID is not integer-like, or whose GTVp cell is
        empty, are skipped.
      - The patient folder is root_folder/<Patient ID with decimals removed,
        zero-padded to 8 digits>; rows without that folder are skipped.
      - The main image is the first "image.nii.gz" found (recursively) in the
        patient folder; rows without it, or whose image fails to load, are skipped.
      - For each structure column (GTVp, Body, Mandible, Spinal Cord) with a
        non-empty cell, the first file in the patient folder whose name
        equals the cell value (case-insensitively) is loaded. Missing,
        unreadable or empty structures are reported and skipped.
      - Every loaded array is transposed with np.transpose(array, (1, 0, 2)).
      - The collected data, a dict with keys "Image" and the loaded structure
        column names, each mapped to (array, voxel_size), is passed to
        display_patient_overlay_structures.

    Parameters
    ----------
    csv_file : str
        Path to the CSV file.
    root_folder : str
        Root folder that contains the patient folders (named by padded ID).
    pat_id : int, float or str, optional
        If given, only rows for this patient are processed. Integer-like IDs
        are compared after the same 8-digit normalization used for folder
        names, so 123, 123.0, "123" and "00000123" select the same patient.

    Returns
    -------
    None
        Nothing is returned; use load_patient_data_from_row to obtain arrays.
    """

    def _normalize_id(value):
        # 8-digit zero-padded ID string (decimals dropped), or None when the
        # value is not integer-like (NaN, empty cell, free text, ...).
        try:
            return str(int(float(value))).zfill(8)
        except (TypeError, ValueError, OverflowError):
            return None

    def _find_file(folder, file_name):
        # First file below *folder* (recursive) whose name equals *file_name*
        # ignoring case, or None.
        target = file_name.lower()
        for root, _dirs, files in os.walk(folder):
            for f in files:
                if f.lower() == target:
                    return os.path.join(root, f)
        return None

    df = pd.read_csv(csv_file)

    # Structure columns (everything except Patient ID, Extention and Position).
    structure_columns = ["GTVp", "Body", "Mandible", "Spinal Cord"]

    if pat_id is not None:
        wanted_id = _normalize_id(pat_id)
        if wanted_id is None:
            # Not an integer-like ID: fall back to plain equality.
            df = df[df["Patient ID"] == pat_id]
        else:
            # Compare normalized IDs so the match does not depend on whether
            # pandas parsed the column as int, float or string.
            df = df[df["Patient ID"].map(_normalize_id) == wanted_id]

    n_rows = len(df)
    for position, (idx, row) in enumerate(df.iterrows(), start=1):
        # idx is the CSV row label; position counts the rows being processed.
        print(f"Processing row {position} out of {n_rows} ...")
        patient_id = _normalize_id(row["Patient ID"])
        if patient_id is None:
            print(f"Invalid Patient ID at row {idx}: {row['Patient ID']}. Skipping.")
            continue
        print(f"Processing patient {patient_id} ...")
        if pd.isna(row["GTVp"]):
            print(f"Patient {patient_id} has no GTVp. Skipping.")
            continue

        patient_folder = os.path.join(root_folder, patient_id)
        if not os.path.isdir(patient_folder):
            print(f"Patient folder {patient_folder} not found. Skipping patient {patient_id}.")
            continue

        pdata = {}

        # 1. Main image: first "image.nii.gz" (exact, case-sensitive name).
        main_image_path = None
        for root, _dirs, files in os.walk(patient_folder):
            if "image.nii.gz" in files:
                main_image_path = os.path.join(root, "image.nii.gz")
                break
        if main_image_path is None:
            print(f"Main image 'image.nii.gz' not found for patient {patient_id}. Skipping.")
            continue
        img_data, voxel_size = get_image_and_voxel_size_from_gzip(main_image_path)
        if img_data is None:
            print(f"Failed to load main image for patient {patient_id}.")
            continue
        img_data = np.transpose(img_data, (1, 0, 2))
        print(f"Hu value range: {np.min(img_data)}, {np.max(img_data)}")
        pdata["Image"] = (img_data, voxel_size)

        # 2. Structures named in the CSV row.
        for col in structure_columns:
            file_name = str(row[col]).strip()
            if not file_name or file_name.lower() == 'nan':
                continue  # empty cell (str() of a missing value is 'nan')

            structure_file_path = _find_file(patient_folder, file_name)
            if structure_file_path is None:
                print(f"File for structure '{col}' with name '{file_name}' not found for patient {patient_id}.")
                continue

            if structure_file_path.endswith(".nii.gz"):
                struct_img, struct_voxel_size = get_image_and_voxel_size_from_gzip(structure_file_path)
            else:
                struct_img, struct_voxel_size = load_nifti_file(structure_file_path)
            if struct_img is None:
                print(f"Failed to load structure '{col}' for patient {patient_id} from file {structure_file_path}.")
                continue
            struct_img = np.transpose(struct_img, (1, 0, 2))
            if struct_img.size == 0:
                print(f"Structure '{col}' for patient {patient_id} is empty. Skipping.")
                continue
            pdata[col] = (struct_img, struct_voxel_size)
            print(f"Loaded {col} from {structure_file_path}")

        display_patient_overlay_structures(pdata, title=f"Patient {patient_id} Overlay")

def load_patient_data_from_row(row, root_folder):
    """
    Load one patient's main image and structure files as described by a CSV row.

    The CSV is expected to have headers:
      Patient ID, Extention, Position, GTVp, Body, Mandible, Spinal Cord.

    For the patient:
      - The patient folder is root_folder/<Patient ID with decimals removed,
        zero-padded to 8 digits>.
      - The main image is the first "image.nii.gz" found (recursively).
      - For each structure column (GTVp, Body, Mandible, Spinal Cord) with a
        non-empty cell, the first file whose name equals the cell value
        (case-insensitively) is loaded. Missing, unreadable or empty
        structures are reported and left out.
      - Every loaded array is transposed with np.transpose(array, (1, 0, 2)).

    Parameters
    ----------
    row : pandas.Series
        One row from the CSV file.
    root_folder : str
        The root folder that contains patient folders (named by padded IDs).

    Returns
    -------
    tuple
        (patient_data, patient_id) on success, where patient_data maps "Image"
        and every successfully loaded structure column to (image_array, voxel_size).
        (None, patient_id) if the patient folder or main image is missing or
        the main image cannot be loaded.
        (None, None) if the Patient ID is not integer-like.
    """

    def _find_file(folder, file_name):
        # First file below *folder* (recursive) whose name equals *file_name*
        # ignoring case, or None.
        target = file_name.lower()
        for root_dir, _dirs, files in os.walk(folder):
            for f in files:
                if f.lower() == target:
                    return os.path.join(root_dir, f)
        return None

    # Drop decimals (pandas may parse IDs as float) and pad to 8 digits.
    raw_id = row["Patient ID"]
    try:
        patient_id = str(int(float(raw_id))).zfill(8)
    except (TypeError, ValueError, OverflowError):
        print(f"Invalid Patient ID '{raw_id}' in row, skipping.")
        return None, None

    patient_folder = os.path.join(root_folder, patient_id)
    if not os.path.isdir(patient_folder):
        print(f"Patient folder {patient_folder} not found for patient {patient_id}.")
        return None, patient_id

    pdata = {}

    # 1. Main image: first "image.nii.gz" (exact, case-sensitive name).
    main_image_path = None
    for root_dir, _dirs, files in os.walk(patient_folder):
        if "image.nii.gz" in files:
            main_image_path = os.path.join(root_dir, "image.nii.gz")
            break
    if main_image_path is None:
        print(f"Main image 'image.nii.gz' not found for patient {patient_id}.")
        return None, patient_id

    img_data, voxel_size = get_image_and_voxel_size_from_gzip(main_image_path)
    if img_data is None:
        print(f"Failed to load main image for patient {patient_id}.")
        return None, patient_id
    img_data = np.transpose(img_data, (1, 0, 2))
    pdata["Image"] = (img_data, voxel_size)

    # 2. Structures named in the row.
    structure_columns = ["GTVp", "Body", "Mandible", "Spinal Cord"]
    for col in structure_columns:
        file_name = str(row[col]).strip()
        if not file_name or file_name.lower() == 'nan':
            continue  # empty cell (str() of a missing value is 'nan')

        structure_file_path = _find_file(patient_folder, file_name)
        if structure_file_path is None:
            print(f"File for structure '{col}' with name '{file_name}' not found for patient {patient_id}.")
            continue

        if structure_file_path.endswith(".nii.gz"):
            struct_img, struct_voxel_size = get_image_and_voxel_size_from_gzip(structure_file_path)
        else:
            struct_img, struct_voxel_size = load_nifti_file(structure_file_path)
        if struct_img is None:
            print(f"Failed to load structure '{col}' for patient {patient_id} from {structure_file_path}.")
            continue

        struct_img = np.transpose(struct_img, (1, 0, 2))
        # Same rule as load_patient_data_from_csv: never store empty arrays.
        if struct_img.size == 0:
            print(f"Structure '{col}' for patient {patient_id} is empty. Skipping.")
            continue
        pdata[col] = (struct_img, struct_voxel_size)
        print(f"Loaded {col} from {structure_file_path}")

    return pdata, patient_id



## Image processing

In [4]:
def mask_via_threshold(ct_image, HU_range=(900, 2500)):
    """
    Generate a bone mask from a CT image by thresholding within a specified HU range.

    Parameters:
    ct_image (numpy.ndarray): The CT image data (typically 3D).
    HU_range (tuple): Inclusive (lower, upper) range of Hounsfield Units
        treated as bone. Default is (900, 2500).

    Returns:
    numpy.ndarray: Mask with the same shape and dtype as ct_image: 1 where
        lower <= value <= upper, 0 elsewhere (NaN voxels are 0).
    """
    lower_bound, upper_bound = HU_range
    # zeros_like keeps ct_image's dtype, so the mask matches the input type.
    bone_mask = np.zeros_like(ct_image)
    bone_mask[(ct_image >= lower_bound) & (ct_image <= upper_bound)] = 1
    return bone_mask

def get_nonzero_slice_range(image_data, slice_dir_indx=2):
    """
    Return the first and last slice index along one axis that contain any non-zero voxel.

    Parameters
    ----------
    image_data : numpy.ndarray
        The 3D image data.
    slice_dir_indx : int, optional
        Axis along which slices are taken (0, 1, or 2). Default is 2.

    Returns
    -------
    tuple
        (start, end): inclusive indices of the first and last occupied slice,
        as Python ints.

    Raises
    ------
    ValueError
        If image_data is not 3D, slice_dir_indx is not 0, 1, or 2, or no
        slice contains a non-zero value.
    """
    if image_data.ndim != 3:
        raise ValueError("image_data must be a 3D array.")
    if slice_dir_indx not in (0, 1, 2):
        raise ValueError("slice_dir_indx must be 0, 1, or 2.")

    # Reduce over the two in-slice axes: one flag per slice, True if occupied.
    in_slice_axes = tuple(axis for axis in range(3) if axis != slice_dir_indx)
    occupied = np.flatnonzero(np.any(image_data, axis=in_slice_axes))

    if not occupied.size:
        raise ValueError("No non-zero slices found in the specified direction.")
    return int(occupied.min()), int(occupied.max())

from skimage.filters import threshold_otsu

def estimate_and_plot_bone_range(image, mandible_mask=None, spinal_cord_mask=None, plot_hist=True):
    """
    Estimate a bone Hounsfield Unit (HU) range from a CT image using Otsu thresholds.

    Otsu's threshold is computed on the image intensities inside each supplied
    mask, and only voxels strictly above that threshold are used afterwards:

      - Both masks given: (minimum above-threshold spinal-cord intensity,
        mean above-threshold mandible intensity).
      - Exactly one mask given: (min, max) of that mask's above-threshold
        intensities.
      - No mask, an empty mask, or no voxel above a threshold: None.

    When plot_hist is True, a 50-bin intensity histogram is shown for each mask
    used, with its Otsu threshold drawn as a dashed vertical line.

    Parameters
    ----------
    image : numpy.ndarray
        3D CT image array.
    mandible_mask : numpy.ndarray, optional
        Mandible mask (non-zero = inside), same shape as image.
    spinal_cord_mask : numpy.ndarray, optional
        Spinal cord mask (non-zero = inside), same shape as image.
    plot_hist : bool, optional
        Show the histograms. Default is True.

    Returns
    -------
    tuple or None
        (min_HU, max_HU), or None as described above.
    """
    def voxels_in(mask):
        # Image intensities selected by the mask.
        return image[mask.astype(bool)]

    def draw_hist(ax, values, threshold, color, title):
        # Intensity histogram with the Otsu threshold as a dashed line.
        ax.hist(values, bins=50, color=color, alpha=0.7)
        ax.axvline(threshold, color='black', linestyle='--',
                   label=f'Otsu thresh = {threshold:.1f}')
        ax.set_title(title)
        ax.legend()

    def single_mask_range(mask, color, title):
        # (min, max) of the voxels above this mask's Otsu threshold.
        values = voxels_in(mask)
        if values.size == 0:
            return None
        threshold = threshold_otsu(values)
        above = values[values > threshold]
        if above.size == 0:
            return None
        if plot_hist:
            plt.figure(figsize=(6,5))
            draw_hist(plt.gca(), values, threshold, color, title)
            plt.show()
        return (np.min(above), np.max(above))

    has_mandible = mandible_mask is not None
    has_spinal = spinal_cord_mask is not None

    if has_mandible and has_spinal:
        mandible_voxels = voxels_in(mandible_mask)
        spinal_voxels = voxels_in(spinal_cord_mask)
        if mandible_voxels.size == 0 or spinal_voxels.size == 0:
            return None
        thresh_mandib = threshold_otsu(mandible_voxels)
        thresh_spinal = threshold_otsu(spinal_voxels)
        mandible_above = mandible_voxels[mandible_voxels > thresh_mandib]
        spinal_above = spinal_voxels[spinal_voxels > thresh_spinal]
        if mandible_above.size == 0 or spinal_above.size == 0:
            return None
        if plot_hist:
            fig, (ax_left, ax_right) = plt.subplots(1, 2, figsize=(12,5))
            draw_hist(ax_left, mandible_voxels, thresh_mandib, 'blue', 'Mandible Intensity Histogram')
            draw_hist(ax_right, spinal_voxels, thresh_spinal, 'green', 'Spinal Cord Intensity Histogram')
            plt.show()
        # Lower bound from the spinal cord, upper bound from the mandible.
        return (np.min(spinal_above), np.mean(mandible_above))

    if has_mandible:
        return single_mask_range(mandible_mask, 'blue', 'Mandible Intensity Histogram')
    if has_spinal:
        return single_mask_range(spinal_cord_mask, 'green', 'Spinal Cord Intensity Histogram')
    return None


## Parametrization

In [5]:
def vector_to_angles(vector):
    """
    Convert a Cartesian vector to the (angle_xy, angle_xz, distance) parametrization.

    The angles follow the same convention as angles_to_vector and the plane
    normal n = [cos(phi)cos(theta), cos(phi)sin(theta), sin(phi)]:
        x = distance * cos(angle_xy) * cos(angle_xz)
        y = distance * sin(angle_xy) * cos(angle_xz)
        z = distance * sin(angle_xz)
    so angle_xy is the azimuth in the xy plane, measured from the x-axis, and
    angle_xz is the elevation of the vector above the xy plane.

    Parameters:
    vector (sequence of 3 numbers): The vector (x, y, z).

    Returns:
    tuple: (angle_xy, angle_xz, distance), angles in radians, with
        angle_xy in (-pi, pi] and angle_xz in [-pi/2, pi/2].
    """
    x, y, z = vector

    distance = np.sqrt(x**2 + y**2 + z**2)

    # Azimuth in the xy plane, measured from the x-axis.
    angle_xy = np.arctan2(y, x)

    # Elevation above the xy plane. Using the in-plane radius sqrt(x^2 + y^2)
    # makes this the exact inverse of the formulas above; arctan2(z, x) alone
    # would only agree when y == 0 and x > 0.
    angle_xz = np.arctan2(z, np.hypot(x, y))

    return angle_xy, angle_xz, distance

def angles_to_vector(angle_xy, angle_xz, distance):
    """
    Build the Cartesian vector described by an azimuth, an elevation and a length.

    Parameters:
    angle_xy (float): Azimuth in radians, measured in the xy plane from the x-axis.
    angle_xz (float): Elevation in radians of the vector above the xy plane.
    distance (float): Length of the vector.

    Returns:
    tuple: The vector components (x, y, z).
    """
    # cos(elevation) scales the in-plane (x, y) part of the vector.
    cos_xz = np.cos(angle_xz)
    components = (
        distance * np.cos(angle_xy) * cos_xz,
        distance * np.sin(angle_xy) * cos_xz,
        distance * np.sin(angle_xz),
    )
    return components

## Objective function


In [6]:
def assign_intensity_to_mirror_voxels(I_orig, x, x_m):
    """
    Build a mirror image by copying each voxel's intensity to its mirrored position.

    Each continuous mirror position in x_m is rounded to the nearest voxel and
    clipped to the volume bounds; the intensity of the corresponding original
    voxel in x is written there. All other voxels are 0. When several mirror
    positions round to the same voxel, the last one in x wins.

    Parameters
    ----------
    I_orig : numpy.ndarray
        The original 3D image with shape (X, Y, Z).
    x : numpy.ndarray
        Integer voxel coordinates of the source points, shape (N, 3).
    x_m : numpy.ndarray
        Continuous (float) mirror coordinates, shape (N, 3), row-aligned with x.

    Returns
    -------
    I_m : numpy.ndarray
        New array with the same shape and dtype as I_orig holding the mirrored intensities.
    """
    upper = np.array([I_orig.shape[0], I_orig.shape[1], I_orig.shape[2]]) - 1
    # Snap to the voxel grid, then keep every index inside the volume.
    targets = np.clip(np.rint(x_m).astype(int), 0, upper)

    I_m = np.zeros_like(I_orig)
    I_m[tuple(targets.T)] = I_orig[tuple(x.T)]
    return I_m

# def compute_objective(params_array, image, interpolator_intensity, interpolators_gradient):
#     """
#     Compute the objective function:
#         f(theta, phi, L) = (1/N) * sum_i [ I(x_i) - I(x_m,i) ]^2,
#     where the mirror voxel is defined as:
#         x_m = x - 2 * alpha * n,
#     with
#         n = [cos(phi)*cos(theta), cos(phi)*sin(theta), sin(phi)]
#         alpha = cos(phi)*cos(theta)*x + cos(phi)*sin(theta)*y + sin(phi)*z - L.
    
#     The voxel indices x are taken from the nonzero elements of the image.
    
#     Parameters:
#       theta : float
#           Rotation angle around the z-axis.
#       phi : float
#           Rotation angle around the y-axis.
#       L : float
#           Offset along the normal.
#       image : 3D numpy array
#           The intensity image.
#       interpolator_intensity : RegularGridInterpolator
#           Interpolator to get the intensity I at any (x,y,z) location.
#       interpolators_gradient : dict
#           Dictionary with keys 'x', 'y', 'z' containing RegularGridInterpolator
#           objects for the gradient components (not used in this function).
    
#     Returns:
#       f : float
#           The value of the objective function.
#     """
#     # Extract parameters
#     theta, phi, L = params_array[0], params_array[1], params_array[2]
#     # Extract voxel indices (nonzero elements) as an array of shape (N,3)
#     indices_image = np.array(np.nonzero(image)).T  # each row is [x, y, z]
#     indices_coord_syst = np.array([indices_image[:, 1], indices_image[:, 0], indices_image[:, 2]]).T
#     N = indices_image.shape[0]
    
#     # Define the unit normal vector n based on theta and phi.
#     n = np.array([np.cos(phi)*np.cos(theta),
#                   np.cos(phi)*np.sin(theta),
#                   np.sin(phi)])
    
#     # Compute d for each voxel.
#     d = (np.cos(phi)*np.cos(theta)*indices_coord_syst[:, 0] +
#              np.cos(phi)*np.sin(theta)*indices_coord_syst[:, 1] +
#              np.sin(phi)*indices_coord_syst[:, 2] - L)  # shape (N,)
    
#     # Compute mirror voxel coordinates vectorized.
#     x_m_coord_syst = indices_coord_syst - 2 * d[:, None] * n[None, :]
#     x_m_image = np.array([x_m_coord_syst[:, 1], x_m_coord_syst[:, 0], x_m_coord_syst[:, 2]]).T
#     # Evaluate intensity at mirror voxel positions.
#     I_m = interpolator_intensity(x_m_image)
    
    
#     # Get original intensity values from the image.
#     I_orig = image[indices_image[:, 0], indices_image[:, 1], indices_image[:, 2]]
    
#     # Compute the mean square error.
#     diff = I_orig - I_m
#     f = (1.0 / N) * np.sum(diff ** 2)

#     # # Visualize the mirror image on middle slice
#     # I_mirror = assign_intensity_to_mirror_voxels(image, indices_image, x_m_image)

#     # plt.imshow(I_mirror[:, :, image.shape[2] // 2] + image[:, :, image.shape[2] // 2], cmap='gray')
#     # plt.show()


#     return f

def compute_signed_distances(params_array, image):
    """
    Compute the signed distances of the nonzero voxels of a 3D image from a plane.

    The plane is parameterized by [theta, phi, L]:
        n  = [cos(phi)*cos(theta), cos(phi)*sin(theta), sin(phi)]   (unit normal)
        P0 = L * n
    and the signed distance of a point x (in the plane coordinate system) is
        d = n . x - L.

    The plane coordinate system is obtained from the array indices
    [axis0, axis1, axis2] (row, column, slice as displayed by imshow) by
    swapping the first two axes, i.e. x = [axis1, axis0, axis2]
    (column, row, slice).

    Parameters
    ----------
    params_array : array-like, shape (3,)
        The parameters [theta, phi, L], with theta and phi in radians and L in voxel units.
    image : numpy.ndarray
        The 3D image (e.g., a CT image). Only nonzero voxels are used.

    Returns
    -------
    d : numpy.ndarray, shape (N,)
        Signed distance of each nonzero voxel from the plane.
    n : numpy.ndarray, shape (3,)
        Unit normal vector of the plane in the plane coordinate system.
    indices_coord_syst : numpy.ndarray, shape (N, 3)
        Voxel positions in the plane coordinate system [column, row, slice].
    indices_image : numpy.ndarray, shape (N, 3)
        Voxel indices in array order [row, column, slice] (as from np.nonzero),
        in the same order as `d`.
    """
    theta, phi, L = params_array[0], params_array[1], params_array[2]

    # Array indices of nonzero voxels, one row per voxel: [axis0, axis1, axis2].
    indices_image = np.array(np.nonzero(image)).T

    # Swap the first two axes to get the (column, row, slice) coordinate system.
    indices_coord_syst = indices_image[:, [1, 0, 2]]

    n = np.array([np.cos(phi) * np.cos(theta),
                  np.cos(phi) * np.sin(theta),
                  np.sin(phi)])

    # d = n . x - L for every voxel (shape (N,); empty if no voxel is nonzero).
    d = indices_coord_syst @ n - L

    return d, n, indices_coord_syst, indices_image


def huber_loss(diff, delta=100):
    """
    Elementwise Huber loss.

        loss(r) = 0.5 * r**2                   if |r| <= delta
                  delta * (|r| - 0.5 * delta)   otherwise

    The two branches meet with equal value and slope at |r| = delta, so the loss
    is quadratic for small residuals and grows only linearly for outliers.

    Parameters
    ----------
    diff : numpy.ndarray
        Residuals, e.g. original minus mirrored intensities.
    delta : float, optional
        Transition point between the quadratic and the linear regime (default 100).

    Returns
    -------
    numpy.ndarray
        The Huber loss of each element of `diff`.
    """
    abs_diff = np.abs(diff)
    quadratic_part = 0.5 * diff ** 2
    linear_part = delta * (abs_diff - 0.5 * delta)
    return np.where(abs_diff <= delta, quadratic_part, linear_part)

def compute_objective(params_array, image, interpolator_intensity, interpolators_gradient, delta=None):
    """
    Compute the mirror-symmetry objective of a plane.

    By default (``delta=None``) this is the mean squared error

        f(theta, phi, L) = (1/N) * sum_i ( I(x_i) - I(x_m,i) )^2.

    If ``delta`` is given, the Huber loss is used instead:

        f(theta, phi, L) = (1/N) * sum_i huber_loss( I(x_i) - I(x_m,i), delta ).

    The mirror voxel is defined as
        x_m = x - 2 * alpha * n,
    with
        n     = [cos(phi)*cos(theta), cos(phi)*sin(theta), sin(phi)]
        alpha = n . x - L.

    The voxels x are the nonzero elements of `image`. The plane geometry is
    evaluated in the (column, row, slice) coordinate system returned by
    compute_signed_distances. Mirror points are converted back to array order
    (row, column, slice) before `interpolator_intensity` is queried.

    Parameters
    ----------
    params_array : array-like, shape (3,)
        Array containing the parameters [theta, phi, L].
    image : numpy.ndarray
        The 3D intensity image.
    interpolator_intensity : RegularGridInterpolator
        Interpolator giving the intensity I at arbitrary (N, 3) positions in array order.
    interpolators_gradient : dict
        Dictionary with keys 'x', 'y', and 'z' containing RegularGridInterpolator
        objects for the gradient components. Not used here; accepted so that the
        signature matches compute_gradient.
    delta : float or None, optional
        Huber threshold. None (the default) selects the mean squared error.

    Returns
    -------
    f : float
        The value of the objective function.

    Raises
    ------
    ValueError
        If `image` has no nonzero voxels, which would make the mean undefined.
    """
    d, n, indices_coord_syst, indices_image = compute_signed_distances(params_array, image)
    N = len(d)
    if N == 0:
        raise ValueError("compute_objective: image has no nonzero voxels to mirror.")

    # Reflect every voxel across the plane (plane coordinate system).
    x_m_coord_syst = indices_coord_syst - 2 * d[:, None] * n[None, :]
    # Back to array order by swapping the first two axes again.
    x_m_image = x_m_coord_syst[:, [1, 0, 2]]

    # Intensity at the mirror positions and at the original voxels.
    I_m = interpolator_intensity(x_m_image)
    I_orig = image[indices_image[:, 0], indices_image[:, 1], indices_image[:, 2]]

    diff = I_orig - I_m
    if delta is None:
        residual = diff ** 2
    else:
        residual = huber_loss(diff, delta=delta)
    return (1.0 / N) * np.sum(residual)


## Gradient

In [7]:
def compute_gradient(param_array, image, interpolator_intensity, interpolators_gradient):
    """
    Compute the analytic gradient of the mirror-symmetry MSE objective

        f(theta, phi, L) = (1/N) * sum_i [ I(x_i) - I(x_m,i) ]^2

    with respect to [theta, phi, L].

    The same coordinate convention as compute_objective is used. The array
    indices [row, column, slice] of the nonzero voxels are mapped to
    x = [column, row, slice] before the plane geometry is evaluated. Mirror
    points are mapped back to [row, column, slice] before the interpolators
    are queried.

    Mirror voxel:
        x_m   = x - 2 * alpha * n
        n     = [cos(phi)*cos(theta), cos(phi)*sin(theta), sin(phi)]
        alpha = n . x - L

    Chain rule:
        df/dp    = -(2/N) * sum_i d_i * ( grad I(x_m,i) . dx_m,i/dp ),  d_i = I(x_i) - I(x_m,i)
        dx_m/dp  = -2 * ( dalpha/dp * n + alpha * dn/dp )   for p in {theta, phi}
        dx_m/dL  = 2 * n

    Parameters
    ----------
    param_array : array-like, shape (3,)
        The parameters [theta, phi, L] (radians, radians, voxels).
    image : numpy.ndarray
        The 3D intensity image; only nonzero voxels contribute.
    interpolator_intensity : RegularGridInterpolator
        Interpolator giving the intensity at (N, 3) positions in array order.
    interpolators_gradient : dict
        Dictionary with keys 'x', 'y', 'z'. Assumed to hold the image derivative
        along array axes 0, 1 and 2 respectively (e.g. built from
        np.gradient(image)), evaluated at positions in array order. TODO confirm
        against where these interpolators are constructed.

    Returns
    -------
    grad : numpy.ndarray, shape (3,)
        The gradient [df/dtheta, df/dphi, df/dL].

    Raises
    ------
    ValueError
        If `image` has no nonzero voxels.
    """
    theta, phi, L = param_array[0], param_array[1], param_array[2]

    indices_image = np.array(np.nonzero(image)).T  # [row, column, slice]
    N = indices_image.shape[0]
    if N == 0:
        raise ValueError("compute_gradient: image has no nonzero voxels.")

    # Permutation between array order and the plane coordinate system; it is its own inverse.
    swap = [1, 0, 2]
    coords = indices_image[:, swap].astype(float)

    cos_t, sin_t = np.cos(theta), np.sin(theta)
    cos_p, sin_p = np.cos(phi), np.sin(phi)
    n = np.array([cos_p * cos_t, cos_p * sin_t, sin_p])
    dn_dtheta = np.array([-cos_p * sin_t, cos_p * cos_t, 0.0])
    dn_dphi = np.array([-sin_p * cos_t, -sin_p * sin_t, cos_p])

    # Signed distances and mirror points in the plane coordinate system.
    alpha = coords @ n - L
    x_m = coords - 2.0 * alpha[:, None] * n[None, :]
    x_m_image = x_m[:, swap]

    I_m = interpolator_intensity(x_m_image)
    I_orig = image[indices_image[:, 0], indices_image[:, 1], indices_image[:, 2]]
    diff = I_orig - I_m  # shape (N,)

    # Image gradient at the mirror points (array order), re-expressed in the
    # plane coordinate system so it can be dotted with dx_m/dp.
    grad_image = np.stack([interpolators_gradient['x'](x_m_image),
                           interpolators_gradient['y'](x_m_image),
                           interpolators_gradient['z'](x_m_image)], axis=1)
    grad_I = grad_image[:, swap]

    # Derivatives of the mirror points with respect to each parameter.
    d_xm_dtheta = -2.0 * ((coords @ dn_dtheta)[:, None] * n[None, :]
                          + alpha[:, None] * dn_dtheta[None, :])
    d_xm_dphi = -2.0 * ((coords @ dn_dphi)[:, None] * n[None, :]
                        + alpha[:, None] * dn_dphi[None, :])
    d_xm_dL = 2.0 * n

    dot_theta = np.einsum('ij,ij->i', grad_I, d_xm_dtheta)
    dot_phi = np.einsum('ij,ij->i', grad_I, d_xm_dphi)
    dot_L = grad_I @ d_xm_dL

    scale = -2.0 / N
    return np.array([scale * np.sum(diff * dot_theta),
                     scale * np.sum(diff * dot_phi),
                     scale * np.sum(diff * dot_L)])

## Parameter Initialization

In [8]:
def param_initialization_2d(bone: np.ndarray,
                            image: np.ndarray,
                            theta_deg: float,
                            phi_deg: float,
                            output_path: str,
                            pat: str,
                            interpolator_intensity,  # RegularGridInterpolator for intensity
                            interpolators_gradient: dict,
                            plot: bool = True) -> np.ndarray:
    """
    Initialize the plane parameters [theta, phi, L] with a coarse grid search.

    The center of mass (COM) of `image` is computed and reordered to the
    (column, row, slice) coordinate system used by compute_objective.

    Candidate normals are obtained by rotating the reference vector [1, 0, 0]
    with 10 theta angles in [-theta_deg, theta_deg] and 10 phi angles in
    [-phi_deg, phi_deg]. Each candidate plane passes through the COM and is
    converted to [theta, phi, L] with vector_to_angles.

    compute_objective is evaluated on `bone` for every candidate, and the
    candidate with the lowest value is returned.

    Parameters
    ----------
    bone : np.ndarray
        3D image (e.g. thresholded bone) on which the objective is evaluated.
    image : np.ndarray
        3D image used to compute the center of mass.
    theta_deg : float
        Half-width, in degrees, of the theta search range.
    phi_deg : float
        Half-width, in degrees, of the phi search range.
    output_path : str
        Directory, created if missing, that caches the grid search:
        "Initialization_obj_fun.npy" holds the objective grid (rows = phi,
        columns = theta). "Initialization_plane_params.npy" holds one
        [theta, phi, L] row per candidate. If both files exist they are
        loaded instead of recomputed.
        NOTE(review): the cache is not keyed on theta_deg/phi_deg/bone/pat, so
        reusing a directory with different inputs returns stale results.
    pat : str
        Patient identifier; currently unused.
    interpolator_intensity : RegularGridInterpolator
        Passed through to compute_objective.
    interpolators_gradient : dict
        Passed through to compute_objective.
    plot : bool, optional
        If True, show a heatmap of the objective grid. Default is True.

    Returns
    -------
    best_plane_params : np.ndarray, shape (3,)
        [theta, phi, L] of the candidate with the lowest objective value. The
        same type is returned whether the grid is computed or loaded from cache.
    """
    import os

    print("Starting initialization...")
    start_initialization = time.time()

    # COM of `image`, reordered from array order to (column, row, slice).
    com = center_of_mass(image)
    com = np.array([com[1], com[0], com[2]])

    # Angular search grids (radians).
    thetas = np.linspace(-np.deg2rad(theta_deg), np.deg2rad(theta_deg), 10)
    phis = np.linspace(-np.deg2rad(phi_deg), np.deg2rad(phi_deg), 10)

    mse_list = []
    plane_params_list = []

    # Reference vector that is rotated to produce each candidate normal.
    param_vec = np.array([1, 0, 0])

    os.makedirs(output_path, exist_ok=True)
    mse_array_file = os.path.join(output_path, "Initialization_obj_fun.npy")
    plane_params_file = os.path.join(output_path, "Initialization_plane_params.npy")

    if not (os.path.exists(mse_array_file) and os.path.exists(plane_params_file)):
        # phi is the outer loop so the flattened lists match mse_array's row-major layout.
        for phi_val in phis:
            for theta_val in thetas:
                rotation_matrix = np.array([
                    [np.cos(theta_val)*np.cos(phi_val), -np.sin(theta_val), np.cos(theta_val)*np.sin(phi_val)],
                    [np.sin(theta_val)*np.cos(phi_val),  np.cos(theta_val), np.sin(theta_val)*np.sin(phi_val)],
                    [-np.sin(phi_val),                   0,                np.cos(phi_val)]
                ])
                rotated_normal_vector = rotation_matrix.dot(param_vec)

                # Plane through the COM: n . x + D = 0. The plane is encoded as
                # |D| * n and converted to (theta, phi, L).
                # NOTE(review): |D| equals n . com only when n . com >= 0; confirm
                # that holds for the angle ranges used.
                D = -np.dot(rotated_normal_vector, com)
                theta_sph, phi_sph, l = vector_to_angles(np.abs(D) * rotated_normal_vector)

                plane_params = np.array([theta_sph, phi_sph, l])
                plane_params_list.append(plane_params)
                mse_list.append(compute_objective(plane_params, bone, interpolator_intensity, interpolators_gradient))

        mse_array = np.array(mse_list).reshape(len(phis), len(thetas))
        plane_params_array = np.array(plane_params_list)
        np.save(mse_array_file, mse_array)
        np.save(plane_params_file, plane_params_array)
    else:
        mse_array = np.load(mse_array_file)
        mse_list = mse_array.flatten().tolist()
        plane_params_list = np.load(plane_params_file).tolist()

    if plot:
        vmin = np.min(mse_array)
        vmax_initial = vmin + 1e7  # Upper color limit; adjust if needed.
        min_index = np.unravel_index(np.argmin(mse_array), mse_array.shape)
        min_phi = np.rad2deg(phis[min_index[0]])
        min_theta = np.rad2deg(thetas[min_index[1]])

        plt.figure()
        plt.imshow(mse_array, extent=[np.rad2deg(thetas[0]), np.rad2deg(thetas[-1]),
                                      np.rad2deg(phis[0]), np.rad2deg(phis[-1])],
                   aspect='auto', origin='lower', cmap='viridis', vmin=vmin, vmax=vmax_initial)
        plt.colorbar(label='Mean Squared Error')
        plt.title('MSE vs. Polar and Azimuthal Angles')
        plt.xlabel('Polar Angle (θ)°')
        plt.ylabel('Azimuthal Angle (φ)°')
        plt.scatter(min_theta, min_phi, color='red', marker='x', s=100, label='Min MSE')
        plt.legend()
        plt.show()

    best_plane_index = np.argmin(mse_list)
    # Cached params come back as Python lists; normalize to an ndarray for both paths.
    best_plane_params = np.asarray(plane_params_list[best_plane_index], dtype=float)

    end_initialization = time.time()
    print(f"Time taken for initialization: {end_initialization - start_initialization:.2f} seconds")

    return best_plane_params

## Optimization

In [9]:
def optimize_plane(initial_params_array, image, interpolator_intensity, interpolators_gradient):
    """
    Fit the plane parameters [theta, phi, L] by minimizing compute_objective with BFGS.

    The objective is
        f(theta, phi, L) = (1/N) * sum_i [ I(x_i) - I(x_m,i) ]^2,
    where x_m = x - 2 * alpha * n with
        n     = [cos(phi)*cos(theta), cos(phi)*sin(theta), sin(phi)]
        alpha = n . x - L,
    and x ranges over the nonzero voxels of `image`. No analytic Jacobian is
    supplied (jac=None), so scipy approximates the gradient numerically.

    Parameters
    ----------
    initial_params_array : array-like, shape (3,)
        Starting point [theta, phi, L].
    image : 3D numpy array
        The intensity image.
    interpolator_intensity : RegularGridInterpolator
        Interpolator giving the intensity at arbitrary positions.
    interpolators_gradient : dict
        Gradient interpolators keyed 'x', 'y', 'z'; forwarded to compute_objective.

    Returns
    -------
    res : OptimizeResult
        The scipy.optimize.minimize result, extended with two histories:
        `res.objective_value_list` holds the objective recorded at each callback
        plus the final `res.fun`. `res.params_list` holds a copy of each iterate
        plus the final `res.x`.
    """
    objective_history = []
    param_history = []
    extra_args = (image, interpolator_intensity, interpolators_gradient)

    def record_iterate(xk):
        # The single-argument callback only receives the iterate, so the
        # objective is re-evaluated here for the history.
        current_value = compute_objective(xk, *extra_args)
        objective_history.append(current_value)
        param_history.append(xk.copy())

    res = minimize(
        compute_objective,
        x0=initial_params_array,
        args=extra_args,
        method='BFGS',
        jac=None,
        callback=record_iterate,
    )

    # Always record the final point, even if no callback fired.
    param_history.append(res.x.copy())
    objective_history.append(res.fun)

    res.objective_value_list = objective_history
    res.params_list = param_history
    return res

## Verification Plots

In [10]:
def plot_middle_slice_with_planes(image_data, plane_params_list, title='Middle Slice with Plane Projections', com=None, output_path=None, filename="middle_slice_with_planes.svg"):
    """
    Plot the middle axial slice of a 3D image with the intersection lines of several planes.

    Each plane [theta, phi, L] is converted with angles_to_vector to a vector
    v = (A, B, C) and drawn as A*x + B*y + C*z + D = 0 with D = -|v|^2, i.e.
    the plane {p : v . p = |v|^2}. This equals n . p = L when angles_to_vector
    returns v = L * n (presumably; confirm against angles_to_vector). Here x is
    the column index, y the row index and z the slice index.

    Parameters
    ----------
    image_data : numpy.ndarray
        The 3D image data; the slice image_data[:, :, shape[2] // 2] is shown.
    plane_params_list : list
        A list of plane parameters, each a tuple or array [theta, phi, L].
    title : str, optional
        The title of the plot.
    com : array-like, optional
        The center of mass (as [x, y], i.e. [column, row]) to be marked.
    output_path : str, optional
        Existing directory in which to save the figure. If None, nothing is saved.
    filename : str, optional
        Filename for the saved figure.
    """
    import os

    middle_slice_index = image_data.shape[2] // 2
    slice_array = image_data[:, :, middle_slice_index]

    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(slice_array, cmap='gray')
    ax.set_title(title)
    ax.set_xlabel("x (pixels)")
    ax.set_ylabel("y (pixels)")

    if com is not None:
        ax.scatter(com[0], com[1], color='blue', marker='x', s=100, label='Center of Mass')

    # Sampling grid over the slice (x = column, y = row) for contour plotting.
    ny, nx = slice_array.shape
    X, Y = np.meshgrid(np.linspace(0, nx, 100), np.linspace(0, ny, 100))

    for plane_coeffs in plane_params_list:
        A, B, C = angles_to_vector(plane_coeffs[0], plane_coeffs[1], plane_coeffs[2])
        D = -np.dot([A, B, C], [A, B, C])

        # A plane perpendicular to the axial slices has C ~ 0. Clamp the
        # magnitude to avoid dividing by zero, but keep the sign of C.
        if np.abs(C) < 1e-8:
            C = np.copysign(1e-8, C)

        # z of the plane above each (x, y); its level set at the slice index is the intersection line.
        z_on_plane = (-A * X - B * Y - D) / C
        ax.contour(X, Y, z_on_plane, levels=[middle_slice_index], colors='red')

    if com is not None:
        ax.legend()

    if output_path is not None:
        save_path = os.path.join(output_path, filename)
        fig.savefig(save_path)
        print(f"Figure saved to {save_path}")

    plt.show()

## Scrollable Widget

In [11]:
def sample_random_points(coords, distances, n_points):
    """
    Draw a random subset of points (without replacement) together with their distances.

    Parameters
    ----------
    coords : numpy.ndarray
        Array whose first dimension indexes points, e.g. shape (N, 3).
    distances : numpy.ndarray
        1D array of length N with the value associated with each point.
    n_points : int
        How many points to draw; capped at N.

    Returns
    -------
    sampled_coords : numpy.ndarray
        The selected rows of `coords`.
    sampled_distances : numpy.ndarray
        The entries of `distances` for the same rows, in the same order.
    """
    total = len(distances)
    n_points = min(n_points, total)
    chosen = np.random.choice(total, size=n_points, replace=False)
    return coords[chosen], distances[chosen]

def display_scrollable_slices_with_plane(image, gtv_mask, body_mask=None, mandible_mask=None, spinal_cord_mask=None,
                                         plane_coeffs_list=None, optimization_methods_list=None, points=None, distances=None):
    """
    Display interactive axial and coronal slice viewers with mask and plane overlays.

    Two ipywidgets sliders are created: one scrolls axial slices (image[:, :, k])
    and one scrolls coronal slices (image[k, :, :]). Each view overlays the GTVp
    contour, any optional masks provided, and the intersection of every plane in
    `plane_coeffs_list` with the displayed slice.

    Parameters
    ----------
    image : numpy.ndarray
        The 3D image data.
    gtv_mask : numpy.ndarray
        The 3D primary GTV mask (mandatory).
    body_mask : numpy.ndarray, optional
        The 3D body mask. When given, it also sets the coronal slider range via
        get_nonzero_slice_range along axis 0.
    mandible_mask : numpy.ndarray, optional
        The 3D mandible mask.
    spinal_cord_mask : numpy.ndarray, optional
        The 3D spinal cord mask.
    plane_coeffs_list : list, optional
        Plane parameters (theta, phi, L), converted with angles_to_vector.
        Defaults to no planes.
    optimization_methods_list : list, optional
        Legend labels for the planes (axial view), matched by index. Defaults to none.
    points : numpy.ndarray, optional
        (N, 3) array of positions as (x = column, y = row, z = slice). Points whose
        z equals the current axial slice are drawn as red dots.
    distances : numpy.ndarray, optional
        Length-N values printed next to each drawn point; required when `points` is given.
    """
    # None defaults avoid one mutable list being shared between calls.
    if plane_coeffs_list is None:
        plane_coeffs_list = []
    if optimization_methods_list is None:
        optimization_methods_list = []

    num_slices = image.shape[2]
    plane_colors = ['red', 'purple', 'cyan']

    def view_slice_axial(slice_index):
        fig, ax = plt.subplots(figsize=(10, 10))
        ax.imshow(image[:, :, slice_index], cmap='gray', interpolation='none')

        if points is not None:
            on_slice = points[:, 2] == slice_index
            for (px, py, _), dist in zip(points[on_slice], distances[on_slice]):
                ax.plot(px, py, 'ro', markersize=2)
                ax.text(px, py, f"{round(dist,1)}", color='green', fontsize=6)

        # Mandatory GTV contour, then the optional masks.
        ax.contour(gtv_mask[:, :, slice_index], colors='yellow', linewidths=1)
        if body_mask is not None:
            ax.contour(body_mask[:, :, slice_index], colors='orange', linewidths=1)
        if mandible_mask is not None:
            ax.contour(mandible_mask[:, :, slice_index], colors='blue', linewidths=1)
        if spinal_cord_mask is not None:
            ax.contour(spinal_cord_mask[:, :, slice_index], colors='green', linewidths=1)

        # Plane intersections: solve A*x + B*y + C*z + D = 0 for z over (x = column, y = row).
        X, Y = np.meshgrid(np.linspace(0, image.shape[1], 100), np.linspace(0, image.shape[0], 100))
        for idx, coeffs in enumerate(plane_coeffs_list):
            A, B, C = angles_to_vector(coeffs[0], coeffs[1], coeffs[2])
            D = - np.dot([A, B, C], [A, B, C])
            C_val = C if C != 0 else 1e-6  # Avoid division by zero.
            contour = (-A * X - B * Y - D) / C_val
            ax.contour(X, Y, contour, levels=[slice_index],
                       colors=plane_colors[idx % len(plane_colors)], linewidths=1)

        legend_handles = [Line2D([0], [0], color='yellow', lw=2, label='GTVp')]
        if body_mask is not None:
            legend_handles.append(Line2D([0], [0], color='orange', lw=2, label='Body'))
        if mandible_mask is not None:
            legend_handles.append(Line2D([0], [0], color='blue', lw=2, label='Mandible'))
        if spinal_cord_mask is not None:
            legend_handles.append(Line2D([0], [0], color='green', lw=2, label='Spinal Cord'))
        for idx, coeffs in enumerate(plane_coeffs_list):
            if idx < len(optimization_methods_list):
                legend_handles.append(Line2D([0], [0], color=plane_colors[idx % len(plane_colors)],
                                             lw=2, label=optimization_methods_list[idx]))
        ax.legend(handles=legend_handles, loc='upper right')

        # Hover readout: pixel position plus the value of each available mask.
        def format_coord(x, y):
            col = int(round(x))
            row = int(round(y))
            info = f"x={col}, y={row}, slice={slice_index}"
            if 0 <= row < gtv_mask.shape[0] and 0 <= col < gtv_mask.shape[1]:
                info += f", GTVp={gtv_mask[row, col, slice_index]}"
            if body_mask is not None and 0 <= row < body_mask.shape[0] and 0 <= col < body_mask.shape[1]:
                info += f", Body={body_mask[row, col, slice_index]}"
            if mandible_mask is not None and 0 <= row < mandible_mask.shape[0] and 0 <= col < mandible_mask.shape[1]:
                info += f", Mandible={mandible_mask[row, col, slice_index]}"
            if spinal_cord_mask is not None and 0 <= row < spinal_cord_mask.shape[0] and 0 <= col < spinal_cord_mask.shape[1]:
                info += f", Spinal Cord={spinal_cord_mask[row, col, slice_index]}"
            return info
        ax.format_coord = format_coord

        ax.axis('off')
        plt.show()

    # Coronal slider range along axis 0.
    if body_mask is not None:
        start_slice, end_slice = get_nonzero_slice_range(body_mask, slice_dir_indx=0)
    else:
        start_slice = 0
        end_slice = image.shape[0] - 1

    def view_slice_coronal(slice_index):
        fig, ax = plt.subplots(figsize=(10, 10))
        # image[k, :, :]: displayed rows index axis 1 (x = column coordinate), columns index axis 2 (z).
        ax.imshow(image[slice_index, :, :], cmap='gray', interpolation='none')
        ax.contour(gtv_mask[slice_index, :, :], colors='yellow')
        if body_mask is not None:
            ax.contour(body_mask[slice_index, :, :], colors='orange')
        if mandible_mask is not None:
            ax.contour(mandible_mask[slice_index, :, :], colors='blue')
        if spinal_cord_mask is not None:
            ax.contour(spinal_cord_mask[slice_index, :, :], colors='green')

        # Plane intersections: solve for y (the axis-0 index fixed by this slice)
        # over the (x, z) grid. x spans axis 1, matching the displayed rows.
        X, Z = np.meshgrid(np.linspace(0, image.shape[1], 100), np.linspace(0, image.shape[2], 100))
        for idx, coeffs in enumerate(plane_coeffs_list):
            A, B, C = angles_to_vector(coeffs[0], coeffs[1], coeffs[2])
            D = - np.dot([A, B, C], [A, B, C])
            B_val = B if B != 0 else 1e-6  # Avoid division by zero.
            contour = (-A * X - C * Z - D) / B_val
            ax.contour(Z, X, contour, levels=[slice_index], colors=plane_colors[idx % len(plane_colors)])

        # NOTE(review): start_slice/end_slice come from the body extent along
        # axis 0 (the slider axis), but here they limit the displayed rows, which
        # index axis 1. Confirm this crop is intended.
        ax.set_ylim(start_slice, end_slice)
        ax.axis('off')
        plt.show()

    slice_slider_axial = widgets.IntSlider(min=0, max=num_slices-1, step=1, value=num_slices//2, description='Axial Slice')
    display(widgets.interact(view_slice_axial, slice_index=slice_slider_axial))

    slice_slider_coronal = widgets.IntSlider(min=start_slice, max=end_slice, step=1, value=(start_slice+end_slice)//2, description='Coronal Slice')
    display(widgets.interact(view_slice_coronal, slice_index=slice_slider_coronal))

def display_patient_overlay_structures(struct_dict, title="Overlay Structures"):
    """
    Create an interactive slice viewer that overlays structure contours on the main image.

    The main image is expected under the key "Image" and the primary GTV mask under
    "GTVp" in ``struct_dict``; "Body", "Mandible" and "Spinal Cord" are optional.
    For each slice along axis 2 the main image is shown in grayscale and every
    available mask that has foreground on that slice is drawn as a contour at
    level 0.5 in a fixed color. The legend lists GTVp plus every optional mask
    present in ``struct_dict``, independent of the current slice.

    Parameters
    ----------
    struct_dict : dict
        Dictionary mapping structure names to ``(array, voxel_size)`` tuples.
        Must contain non-None entries for "Image" and "GTVp"; arrays are indexed
        as ``array[:, :, slice_index]``.
    title : str, optional
        Title prefix for the displayed plot.

    Returns
    -------
    None
        Prints a message and returns early when a mandatory entry is missing.
    """
    # A key mapped to None counts as missing, exactly like the optional masks below;
    # otherwise the tuple unpacking further down would raise a TypeError.
    if struct_dict.get("Image") is None:
        print("Main image (key 'Image') not found in the structure dictionary.")
        return
    if struct_dict.get("GTVp") is None:
        print("Primary GTV mask (key 'GTVp') not found in the structure dictionary.")
        return

    main_img, _ = struct_dict["Image"]
    gtv_img, _ = struct_dict["GTVp"]
    num_slices = main_img.shape[2]

    # Contour color per structure.
    mask_colors = {
        "GTVp": "yellow",   # mandatory
        "Body": "red",
        "Mandible": "blue",
        "Spinal Cord": "green",
    }

    # (name, 3D mask) pairs to overlay: the mandatory GTVp first, then every
    # optional mask that is present and not None.
    overlays = [("GTVp", gtv_img)]
    for key in ("Body", "Mandible", "Spinal Cord"):
        if struct_dict.get(key) is not None:
            overlays.append((key, struct_dict[key][0]))

    def view_slice(slice_index):
        _, ax = plt.subplots(figsize=(8, 8))
        ax.imshow(main_img[:, :, slice_index], cmap='gray', interpolation='none')

        for key, mask_img in overlays:
            mask_slice = mask_img[:, :, slice_index]
            # Contouring an all-zero slice finds nothing at level 0.5 and only makes
            # matplotlib warn, so empty slices are skipped (GTVp included).
            if np.any(mask_slice):
                ax.contour(mask_slice, levels=[0.5], colors=mask_colors[key], linewidths=2)

        # Legend built from dummy line handles so it is identical on every slice.
        legend_handles = [Line2D([0], [0], color=mask_colors[key], lw=2, label=key)
                          for key, _ in overlays]
        ax.legend(handles=legend_handles, loc='upper right')

        ax.set_title(f"{title} - Slice {slice_index}")
        ax.axis('off')
        plt.show()

    slider = widgets.IntSlider(min=0, max=num_slices-1, step=1, value=num_slices//2, description='Slice')
    # interact() renders the widget itself and returns view_slice; wrapping it in
    # display() would only print the function's repr below the widget.
    widgets.interact(view_slice, slice_index=slider)

def process_all_patient_folders(root_folder, patterns, display_widgets=True):
    """
    Load the structures of every patient folder directly under ``root_folder``.

    Each immediate sub-directory of ``root_folder`` (visited in sorted name order)
    is treated as one patient whose ID is the folder name; plain files next to
    those folders are ignored. For each patient folder the function:
      - calls ``load_patient_structures(patient_folder, patterns)``,
      - stores a non-empty result under the patient ID (an empty/falsy result is
        reported and skipped),
      - optionally shows the interactive overlay viewer
        (``display_patient_overlay_structures``) for that patient.

    This function itself performs no transposition or reshaping of the loaded data.

    Parameters
    ----------
    root_folder : str
        Directory containing one sub-folder per patient.
    patterns : dict
        Structure-name -> regex mapping, forwarded unchanged to
        ``load_patient_structures``.
    display_widgets : bool, optional
        If True (default), display the interactive overlay widget for each
        successfully loaded patient.

    Returns
    -------
    dict
        Maps each patient ID (folder name) to exactly what
        ``load_patient_structures`` returned for that folder.
        NOTE(review): the overlay viewer expects a dict of ``(array, voxel_size)``
        tuples keyed "Image", "GTVp", ... — confirm against load_patient_structures.
    """
    all_patient_data = {}

    for patient_id in sorted(os.listdir(root_folder)):
        patient_folder = os.path.join(root_folder, patient_id)
        if not os.path.isdir(patient_folder):
            continue  # stray files next to the patient folders are not patients

        print(f"Processing patient folder: {patient_id}")
        struct_data = load_patient_structures(patient_folder, patterns)
        if not struct_data:
            print(f"No matching files found for patient {patient_id}.")
            continue

        all_patient_data[patient_id] = struct_data
        if display_widgets:
            display_patient_overlay_structures(struct_data, title=f"Patient {patient_id} Overlay")

    return all_patient_data


## Optimization Pipeline

In [12]:
def _crop_slices(volume, start_slice, end_slice):
    """Return ``volume[:, :, start_slice:end_slice + 1]`` (end inclusive), or None for a missing volume."""
    if volume is None:
        return None
    return volume[:, :, start_slice:end_slice + 1]


def _preprocess_ct(image, body, HU_range):
    """Body-mask the CT, build the bone-only CT and flatten dental fillings; returns (image, bone_ct, body)."""
    # Convert image to int16 to avoid overflow errors in the arithmetic below.
    image = image.astype(np.int16)

    if body is not None:
        # Shrink the body mask by two erosion iterations, then set everything
        # outside it to -1000 HU.
        body = binary_erosion(body, iterations=2).astype(np.uint8)
        image = np.where(body == 1, image, -1000)

    # Bone CT: intensities kept only where mask_via_threshold selects HU_range.
    bone_mask = mask_via_threshold(image, HU_range=HU_range).astype(np.uint16)
    bone_ct = image * bone_mask

    # Dental fillings: voxels selected by mask_via_threshold(image, (HU_range[1], 5000))
    # are replaced by a constant 1500 HU, both in the bone CT and in the image.
    dental_fillings_mask = mask_via_threshold(image, HU_range=(HU_range[1], 5000)).astype(np.uint16)
    dental_bone_ct = 1500 * dental_fillings_mask
    bone_ct = bone_ct + dental_bone_ct
    image = image * (1 - dental_fillings_mask) + dental_bone_ct
    return image, bone_ct, body


def _load_or_build_interpolator(image, interpolator_path):
    """Load the cached cubic intensity interpolator at ``interpolator_path`` or build, cache and return one."""
    if os.path.exists(interpolator_path):
        # NOTE(review): the cache is keyed only by path, so it is not invalidated
        # when HU_range or slice_range change for the same output folder.
        return joblib.load(interpolator_path)

    grid = tuple(np.arange(n) for n in image.shape)
    start_interpolator = time.time()
    interpolator = RegularGridInterpolator(grid, image,
                                           method='cubic', bounds_error=False, fill_value=None)
    end_interpolator = time.time()
    print(f"Cubic interpolator took {end_interpolator - start_interpolator:.2f} seconds.")
    joblib.dump(interpolator, interpolator_path)
    return interpolator


def midline_optimized(csv_filepath, base_path, output_path,
                      patient=None,
                      theta_deg=25, phi_deg=10,
                      optimization_method='BFGS',
                      results_path_list=None,
                      optimization_methods_list=None,
                      HU_range=None,
                      slice_range=None,
                      patient_range=None):
    """
    Run the mid-sagittal plane pipeline for the patients listed in a CSV file.

    For every selected CSV row this function:
      - loads the patient's image and structures via load_patient_data_from_row,
      - masks the CT with the eroded body contour, builds a bone-only CT from
        ``HU_range`` and replaces dental fillings by a constant 1500 HU,
      - crops all volumes to the GTVp slice range (or to ``slice_range``),
      - loads or builds a cubic intensity interpolator (cached as interpolator.joblib),
      - computes an initial plane with param_initialization_2d,
      - if ``params_array.npy`` already exists in the patient's output folder,
        displays the last saved plane; otherwise optimizes the plane with
        optimize_plane and saves ``params_array.npy`` and ``objective_array.npy``.

    Parameters
    ----------
    csv_filepath : str
        CSV with at least the columns 'Patient ID' and 'GTVp'; rows where either
        value is missing are skipped.
    base_path : str
        Root data folder, forwarded to load_patient_data_from_row.
    output_path : str
        Root output folder; one sub-folder per patient ID, zero-padded to 8 characters.
    patient : optional
        If given, only rows whose 'Patient ID' equals this value are processed.
    theta_deg, phi_deg : float
        Forwarded to param_initialization_2d (presumably its angular search
        ranges in degrees — confirm there).
    optimization_method : str
        Only used as a label in messages and the viewer legend; optimize_plane is
        always called the same way.
    results_path_list, optimization_methods_list : optional
        Currently unused; kept for backward compatibility.
    HU_range : tuple(int, int), optional
        Bone window for mask_via_threshold, default (900, 2500).
    slice_range : tuple(int, int), optional
        Inclusive (first, last) slice indices along axis 2, overriding the GTVp range.
    patient_range : tuple(int, int), optional
        Inclusive (first, last) CSV row indices (after the ``patient`` filter).

    Returns
    -------
    None
    """
    start_pipeline = time.time()
    n_processed = 0

    if HU_range is None:
        HU_range = (900, 2500)

    df = pd.read_csv(csv_filepath)
    if patient is not None:
        df = df[df['Patient ID'] == patient].reset_index(drop=True)

    for idx, row in df.iterrows():
        if patient_range is not None:
            if idx < patient_range[0]:
                continue
            if idx > patient_range[1]:
                break

        # Check for missing values *before* formatting the ID: str(nan).zfill(8)
        # is '00000nan', which pd.isna() no longer recognises.
        if pd.isna(row['Patient ID']) or pd.isna(row['GTVp']):
            continue
        raw_id = row['Patient ID']
        if isinstance(raw_id, float) and raw_id.is_integer():
            # Missing values elsewhere in the column make pandas store IDs as floats;
            # avoid folder names like '10683066.0'.
            raw_id = int(raw_id)
        pat_id = str(raw_id).zfill(8)
        print(f"Processing patient {pat_id} at CSV row {idx}...")

        output_path_patient = os.path.join(output_path, f'{pat_id}')
        os.makedirs(output_path_patient, exist_ok=True)

        # Load patient data from the selected row.
        pdata, pid = load_patient_data_from_row(row, base_path)
        if pdata is None or pid is None:
            print(f"Failed to load data for patient at CSV row {idx}.")
            continue

        # Each entry is an (array, voxel_size) tuple; absent structures yield None.
        image = pdata.get("Image", (None,))[0]
        body = pdata.get("Body", (None,))[0]
        gtvp = pdata.get("GTVp", (None,))[0]
        if image is None or gtvp is None:
            print(f"Patient {pid} is missing required data. Skipping...")
            continue
        mandibula = pdata.get("Mandible", (None,))[0]
        spinalcord = pdata.get("Spinal Cord", (None,))[0]
        print(f"Patient {pid} data loaded successfully.")

        image, bone_ct, body = _preprocess_ct(image, body, HU_range)

        # Crop every volume to the GTVp slice range, or to the user's inclusive
        # (first, last) slice_range. The previous `slice_range[1] + 1` combined with
        # the `+ 1` in the slicing kept one slice too many.
        start_slice, end_slice = get_nonzero_slice_range(gtvp)
        if slice_range is not None:
            start_slice, end_slice = slice_range[0], slice_range[1]
        image, bone_ct, gtvp, body, mandibula, spinalcord = (
            _crop_slices(volume, start_slice, end_slice)
            for volume in (image, bone_ct, gtvp, body, mandibula, spinalcord)
        )

        interpolator = _load_or_build_interpolator(
            image, os.path.join(output_path_patient, 'interpolator.joblib'))
        # Gradient interpolators are currently disabled; downstream calls receive None.
        interpolators_gradient = None

        # Initialize candidate plane parameters using a grid search.
        initial_plane = param_initialization_2d(bone_ct, image, theta_deg, phi_deg, output_path_patient, idx,
                                                interpolator, interpolators_gradient, plot=False)

        params_file = os.path.join(output_path_patient, "params_array.npy")
        if os.path.exists(params_file):
            # Results already exist: show the last saved parameter set instead of re-optimizing.
            plane_params = np.load(params_file)
            best_plane_params = plane_params[-1]
            display_scrollable_slices_with_plane(image, gtvp, body, mandibula, spinalcord,
                                                 [best_plane_params], [optimization_method])
        else:
            start_optimization = time.time()
            opt_result = optimize_plane(initial_plane, bone_ct, interpolator, interpolators_gradient)
            end_optimization = time.time()
            best_plane_params = opt_result.x  # Optimized [theta, phi, L].
            obj_fun = opt_result.fun
            print(f"Optimization ({optimization_method}) took {end_optimization - start_optimization:.2f} seconds.")
            print(f"Optimized parameters: {np.rad2deg(best_plane_params[0])}, {np.rad2deg(best_plane_params[1])}, {best_plane_params[2]} with MSE {round(obj_fun, 2)}")
            # Save the recorded parameter / objective histories of the optimization run.
            np.save(params_file, np.array(opt_result.params_list))
            np.save(os.path.join(output_path_patient, "objective_array.npy"),
                    np.array(opt_result.objective_value_list))

        n_processed += 1

    end_pipeline = time.time()
    # Count patients actually processed (skipped rows excluded; 0 for an empty CSV).
    print(f"Total time for processing {n_processed} patients: {end_pipeline - start_pipeline:.2f} seconds")

## Mapping to real space

In [13]:
def physical_plane_params(theta, phi, L, voxel_size):
    """
    Convert plane parameters from voxel space to physical space.

    In voxel coordinates x the plane is n·x = L with unit normal
    n = (cos(phi)cos(theta), cos(phi)sin(theta), sin(phi)). Physical coordinates
    are p = s * x element-wise (s = voxel_size), so the same plane reads
    m·p = L with m = n / s. Its foot of the perpendicular from the origin is
    L * m / (m·m). That vector is passed to ``vector_to_angles``, presumably the
    inverse of the (theta, phi, L) -> L*n mapping; confirm there.

    Simply scaling the voxel-space foot point L*n by s gives a point on the
    plane but not its foot (and hence a wrong normal and offset) unless the
    voxels are isotropic. For isotropic voxels both formulas coincide.

    Parameters:
        theta, phi, L : float
                Plane parameters in voxel space (angles in radians).
        voxel_size : array-like, shape (3,)
                Voxel dimensions [s_x, s_y, s_z]; all must be non-zero.

    Returns:
        params_real : np.array
                [theta_real, phi_real, L_real] as returned by vector_to_angles
                for the physical-space foot-of-perpendicular vector.
    """
    voxel_size = np.asarray(voxel_size, dtype=float)
    normal_voxel = np.array([np.cos(phi) * np.cos(theta),
                             np.cos(phi) * np.sin(theta),
                             np.sin(phi)])
    # Normals transform with the inverse scaling: n·x = (n / s)·(s * x).
    normal_phys = normal_voxel / voxel_size
    vec_real = L * normal_phys / np.dot(normal_phys, normal_phys)
    theta_real, phi_real, L_real = vector_to_angles(vec_real)
    params_real = np.array([theta_real, phi_real, L_real])

    return params_real

def real_distance_to_plane(x_voxel, theta, phi, L, voxel_size):
    """
    Signed physical distance from voxel points to a plane given in voxel space.

    The points are scaled by ``voxel_size`` into physical coordinates p, the
    plane parameters are converted with ``physical_plane_params``, and the
    result is n_real · p - L_real, where n_real is the unit normal built from
    the converted angles.

    Parameters:
      x_voxel : np.array of shape (n,3)
          Voxel coordinates, one [x, y, z] point per row.
      theta, phi, L : float
          Plane parameters in voxel space.
      voxel_size : array-like, shape (3,)
          Voxel dimensions [s_x, s_y, s_z] (mm per voxel for a result in mm).

    Returns:
      d : np.array
          1D array with one signed distance per input point.
    """
    scale = np.array(voxel_size)
    x_phys = x_voxel * scale  # broadcasting scales each coordinate column

    theta_r, phi_r, L_r = physical_plane_params(theta, phi, L, voxel_size)

    # Components of the physical-space unit normal.
    n_x = np.cos(phi_r) * np.cos(theta_r)
    n_y = np.cos(phi_r) * np.sin(theta_r)
    n_z = np.sin(phi_r)

    return n_x * x_phys[:, 0] + n_y * x_phys[:, 1] + n_z * x_phys[:, 2] - L_r


## Verification Plots

In [14]:
def plot_mse_vs_parameters(image, interpolator_intensity, interpolators_gradient,
                           theta_best, phi_best, L_best, output_path, pat,
                           angle_delta_deg=2, L_delta=4, num_points=100):
    """
    Generate plots of the mean squared error (MSE) as a function of θ, φ, and L.

    For each parameter, ``num_points`` values are sampled in a symmetric range
    around the best value, and compute_objective is evaluated while holding the
    other two parameters at their best values. Each curve is saved as an SVG in
    ``output_path``, all figures are shown, and the file paths are printed.

    Parameters
    ----------
    image : numpy.ndarray
        The 3D image, forwarded to compute_objective.
    interpolator_intensity : RegularGridInterpolator
        Interpolator for the image intensity.
    interpolators_gradient : dict or None
        Forwarded unchanged to compute_objective (dict with keys 'x', 'y', 'z'
        of gradient interpolators, or None when gradients are not used).
    theta_best : float
        Best theta value (in radians).
    phi_best : float
        Best phi value (in radians).
    L_best : float
        Best L value.
    output_path : str
        Directory where the plots will be saved.
    pat : str or int
        Identifier (e.g. patient number) used in filenames.
    angle_delta_deg : float, optional
        Half-width of the theta and phi sweeps in degrees (default 2).
    L_delta : float, optional
        Half-width of the L sweep in L units (default 4).
    num_points : int, optional
        Number of samples per sweep (default 100).
    """
    delta_rad = np.deg2rad(angle_delta_deg)
    sweeps = [
        # (index in [theta, phi, L], name, sampled values, plotted in degrees?, title line 2, file tag)
        (0, 'Theta', np.linspace(theta_best - delta_rad, theta_best + delta_rad, num_points), True,
         'Best Theta = {:.2f}°'.format(np.degrees(theta_best)), 'theta'),
        (1, 'Phi', np.linspace(phi_best - delta_rad, phi_best + delta_rad, num_points), True,
         'Best Phi = {:.2f}°'.format(np.degrees(phi_best)), 'phi'),
        (2, 'L', np.linspace(L_best - L_delta, L_best + L_delta, num_points), False,
         'Best L = {:.2f}'.format(L_best), 'L'),
    ]

    # Evaluate all sweeps first (theta, then phi, then L), varying one parameter at a time.
    mse_curves = []
    for index, _, values, _, _, _ in sweeps:
        curve = []
        for value in values:
            params = [theta_best, phi_best, L_best]
            params[index] = value
            curve.append(compute_objective(np.array(params), image,
                                           interpolator_intensity, interpolators_gradient))
        mse_curves.append(curve)

    filenames = []
    for (_, name, values, in_degrees, best_text, tag), curve in zip(sweeps, mse_curves):
        plt.figure()
        x_values = np.degrees(values) if in_degrees else values
        plt.plot(x_values, curve, label=f'MSE vs. {name}')
        plt.xlabel(f'{name} (degrees)' if in_degrees else name)
        plt.ylabel('MSE')
        plt.title(f'MSE vs. {name}\n{best_text}')
        plt.legend()
        plt.grid(True)
        filename = os.path.join(output_path, f'mse_vs_{tag}_{pat}.svg')
        plt.savefig(filename)
        filenames.append(filename)

    plt.show()
    print("MSE plots saved successfully:")
    print(f"  Theta plot: {filenames[0]}")
    print(f"  Phi plot:   {filenames[1]}")
    print(f"  L plot:     {filenames[2]}")


## Output

In [None]:
# Notebook cell: run the mid-sagittal plane pipeline.
# `%matplotlib widget` selects the interactive (ipympl) backend; presumably chosen
# for the slider-based slice viewers. Swap with the commented line for static plots.
%matplotlib widget
#%matplotlib inline

# Input folder with per-patient structure files, the patient CSV, and the output root
# (one sub-folder per patient is created inside output_path).
data_path = r"/home/loriskeller/Documents/Master Project/Patient data/patient_data_complete/Patient_structures_clean"
csv_path = r"/home/loriskeller/Documents/Master Project/filtered_patients_extention_position_structures_final.csv"
output_path = r"/home/loriskeller/Documents/Master Project/VS/Data_extract_and_midline/Results/14.04.25/Midsagittalplanes Huber delta 300, threshold 300-1500"

# Process only the CSV row(s) whose 'Patient ID' equals `patient`, using a
# bone HU window of (300, 1500) and the GTVp-derived slice range.
midline_optimized(csv_path, data_path, output_path, theta_deg=25, phi_deg=10, optimization_method='BFGS', patient = 10683066 , HU_range=(300, 1500), 
                  slice_range=None, patient_range=None)

#load_patient_data_from_csv(csv_path, data_path, pat_id=10161216)

# Other patient IDs kept for reference (presumably candidates for `patient=` above):
# 10687063
# 10376522
# 10587029
# 10027159
# 10621910
# 10774767



Processing patient 10683066 at CSV row 0...
✅ Success: 'temp_nifti.nii' was written successfully (679383904 bytes).
✅ Success: 'temp_nifti.nii' was written successfully (679383904 bytes).
Loaded GTVp from /home/loriskeller/Documents/Master Project/Patient data/patient_data_complete/Patient_structures_clean/10683066/GTVp.nii.gz
✅ Success: 'temp_nifti.nii' was written successfully (679383904 bytes).
Loaded Body from /home/loriskeller/Documents/Master Project/Patient data/patient_data_complete/Patient_structures_clean/10683066/mask_BODY.nii.gz
✅ Success: 'temp_nifti.nii' was written successfully (679383904 bytes).
Loaded Mandible from /home/loriskeller/Documents/Master Project/Patient data/patient_data_complete/Patient_structures_clean/10683066/mask_Mandible.nii.gz
✅ Success: 'temp_nifti.nii' was written successfully (679383904 bytes).
Loaded Spinal Cord from /home/loriskeller/Documents/Master Project/Patient data/patient_data_complete/Patient_structures_clean/10683066/mask_SpinalCord.nii

interactive(children=(IntSlider(value=19, description='Axial Slice', max=38), Output()), _dom_classes=('widget…

<function __main__.display_scrollable_slices_with_plane.<locals>.view_slice_axial(slice_index)>

interactive(children=(IntSlider(value=270, description='Coronal Slice', max=386, min=155), Output()), _dom_cla…

<function __main__.display_scrollable_slices_with_plane.<locals>.view_slice_coronal(slice_index)>

Total time for processing 1 patients: 12.10 seconds
