# Functions file for midplane detection
This file defines all functions needed for midplane detection given a head CT and contours such as the head mask, primary GTV, spinal cord and mandible. The midplane is found by optimization: the objective is the mean squared error between the intensities of voxels and of their mirror voxels with respect to the plane. The plane is defined by the vector (A,B,C), which points from the coordinate origin to the plane and is perpendicular to it. The optimization parameters are either the plane parameters (A,B,C) or the rotational angles ($\theta$,$\phi$) together with L, the distance of the plane from the origin (the length of the vector (A,B,C)). The optimization can be gradient-descent based or of Nelder–Mead type.

## Import required packages

In [None]:

import joblib
import nibabel as nib
import matplotlib.pyplot as plt
import os
import numpy as np
from scipy.ndimage import center_of_mass
from scipy.optimize import minimize
import imageio
import re
import pandas as pd
import gzip
from scipy.optimize import curve_fit
import csv
from scipy.interpolate import RegularGridInterpolator
from scipy.ndimage import map_coordinates
from joblib import Parallel, delayed, dump, load
import numpy as np
from skimage.morphology import binary_opening, binary_closing, disk, ball
from matplotlib.widgets import Slider
import random
import ipywidgets as widgets
from IPython.display import display
from matplotlib.lines import Line2D
import time
import timeit
from scipy.ndimage import label, sobel, rotate, zoom, binary_erosion, binary_dilation, binary_opening
from scipy import ndimage
import json
from tqdm import tqdm
import cProfile
import numpy as np
import plotly.graph_objects as go

## Read CT data
Read image arrays and voxel-size information from NIfTI (.nii.gz) files.

In [None]:
def load_nifti_file(file_path):
    """
    Load a NIfTI file and return its voxel data together with its voxel size.

    Parameters:
    file_path (str): Path to the NIfTI file.

    Returns:
    tuple: ``(img_data, voxel_size)`` where ``img_data`` is the array returned
    by nibabel's ``get_fdata()`` and ``voxel_size`` is the tuple returned by
    ``header.get_zooms()``.
    """
    import nibabel as nib

    volume = nib.load(file_path)
    # Tuple elements are evaluated left to right: data first, then zooms.
    return volume.get_fdata(), volume.header.get_zooms()

def open_gzip_file(gzip_file_path):
    """
    Read and decompress an entire gzip file.

    Parameters:
    gzip_file_path (str): The path to the gzip file.

    Returns:
    bytes or None: The decompressed content, or None when opening or
    decompressing fails (the error is printed, not raised).
    """
    try:
        with gzip.open(gzip_file_path, 'rb') as stream:
            payload = stream.read()
    except Exception as err:
        # Best-effort reader: report and signal failure with None.
        print(f'Error opening {gzip_file_path}: {err}')
        return None
    return payload
    
def get_image_and_voxel_size_from_gzip(gzip_file_path):
    """
    Get the image array and voxel size from a gzipped NIfTI file.

    The archive is decompressed in memory (``open_gzip_file``), written to a
    uniquely named temporary ``.nii`` file, loaded with ``load_nifti_file`` and
    the temporary file is then deleted, also when loading raises.

    Parameters:
    gzip_file_path (str): Path to the gzipped NIfTI file.

    Returns:
    tuple: ``(img_data, voxel_size)``, or ``(None, None)`` if the file could not
    be read or decompressed to an empty payload.
    """
    import tempfile

    file_content = open_gzip_file(gzip_file_path)
    if file_content is None:
        print(f"❌ Error: Failed to read file content from '{gzip_file_path}'")
        return None, None

    # A unique temp file (instead of a fixed name in the current directory)
    # keeps concurrent calls, e.g. joblib workers, from clobbering each other.
    fd, temp_path = tempfile.mkstemp(suffix='.nii')
    try:
        with os.fdopen(fd, 'wb') as temp_file:
            temp_file.write(file_content)

        # 🔹 Debugging: Check if the file was actually written
        file_size = os.path.getsize(temp_path)
        if file_size == 0:
            print(f"❌ Error: '{temp_path}' was written but is empty! ({gzip_file_path})")
            return None, None
        print(f"✅ Success: '{temp_path}' was written successfully ({file_size} bytes).")

        img_data, voxel_size = load_nifti_file(temp_path)
        return img_data, voxel_size
    finally:
        # Always remove the temp file, including on the empty/failed paths.
        os.remove(temp_path)
    
def load_structure_files(patient_folder_path, structures, structures_exact):
    """
    Load the NIfTI files for the given structures from the patient folder.

    The folder tree is walked recursively. The first subfolder that contains a
    ``<name>.nii.gz`` file for every name in ``structures_exact`` is used.
    NOTE: as a side effect, every file in each visited subfolder is renamed on
    disk to its lowercase name before the check, and the expected names are
    lowercased accordingly.

    Parameters:
    patient_folder_path (str): The path to the patient folder.
    structures (list): Output keys, aligned index-by-index with ``structures_exact``.
    structures_exact (list): File base names (without ``.nii.gz``) to load.

    Returns:
    tuple: ``(structure_images, voxel_size)``. ``structure_images`` is a dict
    mapping each key of ``structures`` to the loaded image data.
    ``voxel_size`` is the voxel size of the last loaded file, or None when
    ``structures_exact`` is empty.

    Raises:
    FileNotFoundError: If no subfolder anywhere under ``patient_folder_path``
    contains all required files (including when the folder does not exist).
    """
    for folderpath, subfolders, _files in os.walk(patient_folder_path):
        for subfolder in subfolders:
            subfolder_path = os.path.join(folderpath, subfolder)
            # Normalise all file names in the subfolder to lowercase on disk.
            for file_name in os.listdir(subfolder_path):
                lower_file_name = file_name.lower()
                if lower_file_name != file_name:
                    os.rename(os.path.join(subfolder_path, file_name),
                              os.path.join(subfolder_path, lower_file_name))
            # The files are now lowercase, so the lookup names must be too;
            # otherwise an uppercase structure name never matches on a
            # case-sensitive file system.
            file_paths = [os.path.join(subfolder_path, structure_exact.lower() + '.nii.gz')
                          for structure_exact in structures_exact]
            if all(os.path.exists(file_path) for file_path in file_paths):
                structure_images = {}
                voxel_size = None
                for index, file_path in enumerate(file_paths):
                    structure_images[structures[index]], voxel_size = get_image_and_voxel_size_from_gzip(file_path)
                return structure_images, voxel_size
    # Give up only after the whole tree has been searched, so nested
    # subfolders are considered and a missing folder is reported explicitly.
    raise FileNotFoundError("None of the subfolders contain all the required structure files.")

def process_patient_by_id(patient_id, row, base_folder_path):
    """
    Load the NIfTI structure files listed in a CSV row for one patient.

    Parameters:
    patient_id (str): The patient ID; also the name of the patient's folder.
    row (pandas.Series or pandas.DataFrame): The CSV row(s) for this patient.
        Columns 1 up to (but excluding) the last are read as structure names.
        Each non-null name gets the prefix ``mask_`` to form a file base name,
        and ``image`` is appended.
    base_folder_path (str): The base folder path containing patient folders.

    Returns:
    tuple: ``(structure_images, voxel_size)`` when exactly as many structures
    as in ('GTVp', 'body', 'spinal cord', 'mandible', 'image') were loaded,
    otherwise ``(None, None)``.
    """
    structures = ['GTVp', 'body', 'spinal cord', 'mandible', 'image']
    if row.empty:
        print(f"Patient ID {patient_id} not found in the CSV file")
        return None, None

    # A single row arrives as a Series; turn it into a one-row frame.
    frame = row.to_frame().T if isinstance(row, pd.Series) else row
    names = frame.iloc[:, 1:-1].dropna().values.flatten().tolist()
    structures_exact = ['mask_' + name for name in names] + ['image']

    patient_folder_path = os.path.join(base_folder_path, str(patient_id))
    if not os.path.exists(patient_folder_path):
        print(f"Patient folder {patient_id} does not exist")
        return None, None

    structure_images, voxel_size = load_structure_files(patient_folder_path, structures, structures_exact)
    if len(structure_images) != len(structures):
        print(f"Missing structures for patient {patient_id}")
        return None, None

    print(f"Loaded all structures for patient {patient_id}")
    return structure_images, voxel_size

def process_patient_data(row, base_path, output_path, pat_num):
    """
    Processes patient data by loading and transforming medical images and structures.

    The patient ID is zero-padded to eight characters and used as the folder
    name under ``base_path``. The output folder ``output_path/pat_<pat_num>``
    is always created. Every loaded volume has its first two axes swapped
    (``np.transpose(..., (1, 0, 2))``).

    Args:
        row (pd.Series): A row from a DataFrame containing patient data, including 'Patient_ID' and 'Extention'.
        base_path (str): The base directory path where patient data is stored.
        output_path (str): The directory path where processed patient data will be saved.
        pat_num (int): The patient number used for naming the output directory.
    Returns:
        tuple: A 12-tuple containing:
            - image (np.ndarray): The image data.
            - gtvp (np.ndarray): The GTVp (Gross Tumor Volume primary) mask.
            - body (np.ndarray): The body structure mask.
            - spinalcord (np.ndarray): The spinal cord structure mask.
            - mandibula (np.ndarray): The mandible structure mask.
            - structure_images (dict): All loaded structure images (untransposed).
            - voxel_size (tuple): The voxel size of the images.
            - patient_folder_path (str): The path to the patient's folder.
            - output_path_patient (str): The path to the output folder for the patient.
            - patient_id (str): The patient ID padded to eight characters.
            - extension (str): The value of ``row['Extention']`` (not validated).
            - pat_num (int): The patient number.
    Notes:
        - If the patient folder does not exist, or the structures could not be
          loaded (``process_patient_by_id`` returned ``(None, None)``), the
          function returns eleven ``None`` values followed by ``pat_num``.
        - Extension validation and the GTVp connectivity check are currently
          disabled.
        - ``FileNotFoundError`` from ``load_structure_files`` is not caught.
    """
    # Uniform failure result; keeps the same 12-element shape as the success path.
    failure = (None,) * 11 + (pat_num,)

    patient_id = pad_to_eight_characters([str(row['Patient_ID'])])[0]
    extension = row['Extention']
    patient_folder_path = os.path.join(base_path, patient_id)
    output_path_patient = os.path.join(output_path, f"pat_{pat_num}")
    os.makedirs(output_path_patient, exist_ok=True)
    print(f"Processing patient number: {pat_num}")
    print(f"Patient folder path: {patient_folder_path}")
    if not os.path.isdir(patient_folder_path):
        print(f"Patient folder does not exist for patient: {patient_id}")
        return failure

    structure_images, voxel_size = process_patient_by_id(patient_id, row, base_path)
    if structure_images is None:
        # process_patient_by_id has already printed the reason; indexing None
        # below would raise a TypeError.
        return failure

    def _swap_first_two_axes(volume):
        return np.transpose(volume, (1, 0, 2))

    image = _swap_first_two_axes(structure_images['image'])
    gtvp = _swap_first_two_axes(structure_images['GTVp'])
    body = _swap_first_two_axes(structure_images['body'])
    spinalcord = _swap_first_two_axes(structure_images['spinal cord'])
    mandibula = _swap_first_two_axes(structure_images['mandible'])
    print(f"Loaded structures: {structure_images.keys()}")
    return image, gtvp, body, spinalcord, mandibula, structure_images, voxel_size, patient_folder_path, output_path_patient, patient_id, extension, pat_num


## CSV data preparation

In [None]:
def pad_to_eight_characters(input_list):
    """
    Left-pad every string in ``input_list`` with zeros to a width of 8.

    Padding uses ``str.zfill``. Strings that are already 8 or more characters
    long are returned unchanged, and a leading sign stays in front of the zeros.

    Parameters:
    input_list (list): A list of strings.

    Returns:
    list: A new list with the padded strings, in the original order.
    """
    width = 8
    padded = []
    for text in input_list:
        padded.append(text.zfill(width))
    return padded


## Image processing

In [None]:
def get_nonzero_slice_range(image_data, slice_dir_indx=2):
    """
    Find the first and last slice along one axis that contain a non-zero value.

    Parameters:
    image_data (numpy.ndarray): The 3D image data.
    slice_dir_indx (int): Slice direction; 0 or 1 select that axis, any other
        value slices along axis 2.

    Returns:
    tuple: ``(start, end)``, both inclusive indices.

    Raises:
    ValueError: If every slice is entirely zero.
    """
    # 0 and 1 pick that axis; everything else falls back to the last axis.
    axis = slice_dir_indx if slice_dir_indx in (0, 1) else 2
    occupied = [
        idx
        for idx in range(image_data.shape[slice_dir_indx])
        if np.any(np.take(image_data, idx, axis=axis))
    ]
    if not occupied:
        raise ValueError("No non-zero slices found in the specified direction.")
    return occupied[0], occupied[-1]

def mask_via_threshold(ct_image, HU_range=(700, 2000)):
    """
    Build a binary mask of the voxels whose value lies in a closed HU interval.

    Parameters:
    ct_image (numpy.ndarray): The 3D CT image data.
    HU_range (tuple): Inclusive ``(lower, upper)`` Hounsfield bounds. Default is
        (700, 2000), a bone window.

    Returns:
    numpy.ndarray: Array with the same shape and dtype as ``ct_image``,
    1 where ``lower <= value <= upper`` and 0 elsewhere.
    """
    lower_bound, upper_bound = HU_range
    within_range = np.logical_and(ct_image >= lower_bound, ct_image <= upper_bound)
    return within_range.astype(ct_image.dtype)

def calculate_center_of_mass(image_data):
    """
    Return the intensity-weighted center of mass of ``image_data``.

    Thin wrapper around ``scipy.ndimage.center_of_mass``. The whole array is
    used; no slice range is applied.

    Parameters:
    image_data (numpy.ndarray): The image data.

    Returns:
    tuple: One coordinate per array axis, in index order.
    """
    from scipy.ndimage import center_of_mass as _scipy_center_of_mass

    return _scipy_center_of_mass(image_data)

def check_connectivity(image_mask):
    """
    Tell whether the non-zero voxels of a mask form exactly one connected component.

    Components are found with ``scipy.ndimage.label`` using its default
    structuring element. An empty mask has zero components and yields False.

    Parameters:
    image_mask (numpy.ndarray): The 3D mask data.

    Returns:
    bool: True if there is exactly one component, False otherwise.
    """
    _, component_count = label(image_mask)
    return component_count == 1

def count_voxels_per_slice(image_data, plot=False):
    """
    Count the non-zero voxels in every slice along axis 2 (z).

    Parameters:
    image_data (numpy.ndarray): The 3D image data.
    plot (bool): If True, also show a line plot of the counts.

    Returns:
    list: Number of non-zero voxels for each slice index along axis 2.
    """
    voxel_counts = [
        np.count_nonzero(image_data[:, :, z_index])
        for z_index in range(image_data.shape[2])
    ]
    if not plot:
        return voxel_counts

    plt.figure(figsize=(10, 6))
    plt.plot(np.arange(len(voxel_counts)), np.array(voxel_counts))
    plt.title('Barplot of Voxel Counts per Slice')
    plt.xlabel('z-axis (Slice Index)')
    plt.ylabel('Number of nonzero voxels of body mask')
    plt.show()
    return voxel_counts

def gradient_descent_voxel_counts(voxel_counts, step_size=0.01, max_iter=1000, tolerance=1e-6, plot=False, start_index=20):
    """
    Run a discrete gradient descent over slice indices on the per-slice voxel count.

    The gradient is a central difference in the interior and a one-sided
    difference at either end. Each update ``index - step_size * gradient`` is
    clipped to the valid range and truncated to an integer. Iteration stops when
    a step changes the index by less than ``tolerance`` or after ``max_iter``
    steps, so the result is a local minimum reachable from ``start_index``.

    Parameters:
    voxel_counts (list): Number of non-zero voxels for each slice along the z-axis.
    step_size (float): Multiplier applied to the gradient at each step.
    max_iter (int): The maximum number of iterations.
    tolerance (float): Convergence threshold on the change of the slice index.
    plot (bool): If True, show a bar plot with the selected slice in red.
    start_index (int): Initial slice index (default 20); clipped into range.

    Returns:
    int: The slice index at which the descent stopped.

    Raises:
    ValueError: If ``voxel_counts`` is empty.
    """
    n_slices = len(voxel_counts)
    if n_slices == 0:
        raise ValueError("voxel_counts must contain at least one slice.")

    def gradient(slice_index):
        # One-sided differences at both ends avoid indexing past the boundary.
        if n_slices == 1:
            return 0
        if slice_index == 0:
            return voxel_counts[1] - voxel_counts[0]
        if slice_index == n_slices - 1:
            return voxel_counts[slice_index] - voxel_counts[slice_index - 1]
        return (voxel_counts[slice_index + 1] - voxel_counts[slice_index - 1]) / 2

    # Clamp the start so short scans do not raise IndexError.
    slice_index = int(np.clip(start_index, 0, n_slices - 1))
    for _ in range(max_iter):
        grad = gradient(slice_index)
        new_slice_index = int(np.clip(slice_index - step_size * grad, 0, n_slices - 1))

        if abs(new_slice_index - slice_index) < tolerance:
            break

        slice_index = new_slice_index

    if plot:
        bars = plt.bar(range(n_slices), voxel_counts)
        bars[slice_index].set_color('red')
        plt.title('Barplot of Voxel Counts per Slice')
        plt.xlabel('Slice Index')
        plt.ylabel('Number of Nonzero Voxels in Body Mask')
        plt.show()
    return slice_index

def select_slices(start_head, image, body, spinal_cord=None, mandible=None, gtvp=None):
    """
    Crop the given volumes along axis 2 to the head region.

    The kept range runs from ``start_head`` up to and including the last slice
    that contains body-mask voxels.

    Parameters:
    start_head (int): First slice index to keep along axis 2.
    image (numpy.ndarray): The 3D image data.
    body (numpy.ndarray): The 3D body mask (also defines the upper end).
    spinal_cord (numpy.ndarray, optional): The 3D spinal cord mask.
    mandible (numpy.ndarray, optional): The 3D mandible mask.
    gtvp (numpy.ndarray, optional): The 3D GTVp mask.

    Returns:
    tuple: ``(image, body[, spinal_cord][, mandible][, gtvp])`` cropped. An
    optional mask appears only when it was passed (not None).

    Raises:
    ValueError: If the body mask has no non-zero slice.
    """
    slice_dim_indx = 2
    _, end_body = get_nonzero_slice_range(body, slice_dir_indx=slice_dim_indx)
    # get_nonzero_slice_range returns an inclusive end, while
    # select_slice_range uses Python's exclusive slice end; add one so the
    # top-most body slice is kept.
    slice_range = (start_head, end_body + 1)

    volumes = [image, body] + [mask for mask in (spinal_cord, mandible, gtvp) if mask is not None]
    return tuple(select_slice_range(volume, slice_dim_indx, slice_range) for volume in volumes)

def select_slice_range(image_data, slice_dir_indx, slice_index_range):
    """
    Return the sub-volume of ``image_data`` between two slice indices along one axis.

    Parameters:
    image_data (numpy.ndarray): The 3D image data.
    slice_dir_indx (int): Axis to slice along; 0 or 1 select that axis, any
        other value selects axis 2.
    slice_index_range (tuple): ``(start_slice, end_slice)`` with Python slice
        semantics, i.e. ``end_slice`` is exclusive.

    Returns:
    numpy.ndarray: The selected slices (basic slicing, so a view).
    """
    start_slice, end_slice = slice_index_range
    axis = slice_dir_indx if slice_dir_indx in (0, 1) else 2
    selector = [slice(None)] * 3
    selector[axis] = slice(start_slice, end_slice)
    return image_data[tuple(selector)]
    

def rotate_3d_array(image, angle, fill_value=0):
    """
    Rotate a 3D volume in the plane of axes 0 and 1 (i.e. about axis 2).

    The output keeps the input shape (``reshape=False``). Samples from outside
    the input are set to ``fill_value``. The result is rounded and cast to int.

    Parameters:
    image (numpy.ndarray): The 3D image data.
    angle (float): Rotation angle in degrees.
    fill_value (int, optional): Value for regions rotated in from outside. Default is 0.

    Returns:
    numpy.ndarray: The rotated volume with the same shape as the input, integer dtype.
    """
    turned = rotate(
        image,
        angle,
        axes=(1, 0),
        reshape=False,
        mode='constant',
        cval=fill_value,
    )
    return np.round(turned).astype(int)

def build_image_pyramid(image, num_levels=4):
    """
    Build a multi-resolution pyramid of a 3D image.

    Level 0 is the input itself. Each further level halves every axis of the
    previous level with cubic-spline ``scipy.ndimage.zoom``.

    Parameters:
    image (ndarray): The input 3D image.
    num_levels (int): Total number of levels; values below 1 still return
        just the input.

    Returns:
    list of ndarray: The images from finest (index 0) to coarsest.
    """
    levels = [image]
    while len(levels) < num_levels:
        levels.append(zoom(levels[-1], (0.5, 0.5, 0.5), order=3))
    return levels


def create_ellipsoid(size=(100, 100, 100), a=50, b=30, c=20, center=(50, 50, 50), z_range=(30, 70)):
    """
    Create a synthetic 3D ellipsoid phantom.

    Let ``r = ((x-cx)/a)^2 + ((y-cy)/b)^2 + ((z-cz)/c)^2``. Inside the
    ellipsoid (``r <= 1``) the intensity is ``200 + 1800 * r``. It is therefore
    200 at the center and 2000 on the surface. Every voxel outside is -1000.

    Parameters:
    size (tuple): The size of the 3D grid (x, y, z).
    a, b, c (float): The semi-principal axes of the ellipsoid.
    center (tuple): The center of the ellipsoid in voxel coordinates.
    z_range (tuple or None): ``(start, stop)`` slice kept along axis 2
        (Python slice semantics, default ``(30, 70)``). ``None`` returns the
        full grid.

    Returns:
    ndarray: Float intensity array with shape ``size`` cropped along axis 2.
    """
    x = np.linspace(0, size[0] - 1, size[0])
    y = np.linspace(0, size[1] - 1, size[1])
    z = np.linspace(0, size[2] - 1, size[2])
    X, Y, Z = np.meshgrid(x, y, z, indexing='ij')

    # Normalised squared radius: <= 1 inside the ellipsoid.
    X_norm = (X - center[0]) / a
    Y_norm = (Y - center[1]) / b
    Z_norm = (Z - center[2]) / c
    ellipsoid = X_norm**2 + Y_norm**2 + Z_norm**2
    mask = ellipsoid <= 1

    # Intensity grows towards the surface: 200 at the center, 2000 at r = 1.
    intensity = 200 + (1800 * ellipsoid)
    intensity[~mask] = -1000

    if z_range is not None:
        start, stop = z_range
        intensity = intensity[:, :, start:stop]

    return intensity

def create_asymmetric_head(size=(100, 100, 100), a=50, b=30, c=40, center=(50, 50, 50)):
    """
    Create a synthetic head-like 3D phantom whose mirror symmetry is broken by a bright sphere.

    - Inside the ellipsoid with semi-axes ``a, b, c`` the intensity is
      ``200 + 1800 * r`` (``r`` = normalised squared radius) plus
      ``300 * exp(-((x - center_x) / 10)^2)``. NOTE: that Gaussian term is
      symmetric about ``x = center_x``; it brightens a central band rather than
      one side.
    - Outside the ellipsoid the intensity is -1000.
    - Every voxel strictly within radius 15 of ``center + (15, 5, 0)`` is then
      set to 2200. This sphere is what makes the phantom asymmetric.

    Parameters:
    - size (tuple): Size of the 3D grid (x, y, z).
    - a, b, c (float): Semi-principal axes of the ellipsoid shape.
    - center (tuple): Center of the ellipsoid in the grid.

    Returns:
    - ndarray: Float array of shape ``size``.
    """
    grids = [np.linspace(0, n - 1, n) for n in (size[0], size[1], size[2])]
    X, Y, Z = np.meshgrid(*grids, indexing='ij')
    cx, cy, cz = center[0], center[1], center[2]

    dx = X - cx
    radius_sq = (dx / a) ** 2 + ((Y - cy) / b) ** 2 + ((Z - cz) / c) ** 2
    inside = radius_sq <= 1

    volume = 200 + 1800 * radius_sq
    volume[~inside] = -1000

    central_band = np.exp(-(dx / 10) ** 2)
    volume[inside] += central_band[inside] * 300

    bright_sphere = (X - (cx + 15)) ** 2 + (Y - (cy + 5)) ** 2 + (Z - cz) ** 2 < 15 ** 2
    volume[bright_sphere] = 2200

    return volume

def generate_symmetric_ct_like_volume(shape=(128, 128, 128), rng=None):
    """
    Generate a synthetic CT-like test volume.

    The volume covers a ``[-1, 1]^3`` grid. Voxels inside an ellipsoid
    (normalised semi-axes 0.7, 1.0, 0.8) get independent uniform random
    intensities in [50, 150); all other voxels are 0. Then 100 is added to the
    box ``x > 0.2, |y| < 0.2, |z| < 0.2``.

    NOTE: despite the function name, the result is NOT mirror-symmetric. Both
    the per-voxel noise and the one-sided box break symmetry. The box also
    reaches past the ellipsoid surface (x > 0.7), where it yields 100 in the
    otherwise empty background.

    Parameters:
    shape (tuple): The shape of the 3D volume (default is 128x128x128).
    rng (numpy.random.Generator or numpy.random.RandomState, optional): Source
        of the random intensities. ``None`` (default) uses NumPy's global random
        state via ``np.random.uniform``. Pass a seeded generator for
        reproducible volumes.

    Returns:
    np.ndarray: float32 volume of the given shape.
    """
    x_dim, y_dim, z_dim = shape
    volume = np.zeros(shape, dtype=np.float32)
    uniform = np.random.uniform if rng is None else rng.uniform

    X, Y, Z = np.meshgrid(np.linspace(-1, 1, x_dim),
                          np.linspace(-1, 1, y_dim),
                          np.linspace(-1, 1, z_dim), indexing='ij')

    # Head-like ellipsoid filled with random "tissue" intensities.
    ellipsoid = (X**2 / 0.7**2) + (Y**2 / 1.0**2) + (Z**2 / 0.8**2) <= 1
    volume[ellipsoid] = uniform(50, 150, size=ellipsoid.sum())

    # One-sided bright bone-like box.
    asymmetry = (X > 0.2) & (Y > -0.2) & (Y < 0.2) & (Z > -0.2) & (Z < 0.2)
    volume[asymmetry] += 100

    return volume


## Plot projection of plane on image slice
Plot the projection of the midplane onto a slice of the 3D head CT.

In [None]:

def show_ct_slice(image_data, slice_index, save_path=None, title=''):
    """
    Show one axial slice (index along axis 2) in grayscale, optionally saving it.

    Parameters:
    image_data (numpy.ndarray): The 3D image data.
    slice_index (int): Index of the slice along axis 2.
    save_path (str, optional): Directory to write
        ``ct_slice_<slice_index>_<title>.png`` into. Default is None (no file).
    title (str, optional): Extra text shown under the slice number.
    """
    section = image_data[:, :, slice_index]

    plt.figure(figsize=(5, 5))
    plt.imshow(section, cmap='gray')
    plt.title(f'CT Image Slice {slice_index} \n {title}')

    # Save before plt.show(), which may clear the current figure.
    if save_path is not None:
        file_name = f"ct_slice_{slice_index}_{title}.png"
        plt.savefig(os.path.join(save_path, file_name))

    plt.show()
    plt.close()



def plot_plane_on_middle_slice(image_data, plane_coeffs, title=None, save_path=None, pat=None, com=None):
    """
    Plot where the plane ``A*x + B*y + C*z + D = 0`` crosses the middle axial slice.

    ``x`` runs along the displayed columns, ``y`` along the rows, and ``z`` is
    the slice index along axis 2. The intersection is drawn in red.

    Parameters:
    image_data (numpy.ndarray): The 3D image data.
    plane_coeffs (tuple): The coefficients (A, B, C, D) of the plane equation.
    title (str): The title of the plot.
    save_path (str): Directory to save
        ``plane_projection_optimized_patient_<pat>.png`` into. If falsy, the plot
        is not saved.
    pat (str): Patient identifier used in the file name, if applicable.
    com (tuple): Center of mass, drawn at ``(com[0], com[1])``, if applicable.
    """
    A, B, C, D = plane_coeffs
    middle_slice_index = image_data.shape[2] // 2
    slice_array = image_data[:, :, middle_slice_index]

    fig, ax = plt.subplots(figsize=(5, 5))
    ax.imshow(slice_array, cmap='gray')
    ax.set_title(title)

    legend_handles = []
    if com is not None:
        # NOTE(review): com[0] is plotted as the column (x) and com[1] as the
        # row (y); confirm the caller passes coordinates in that order.
        legend_handles.append(ax.scatter(com[0], com[1], color='blue', marker='x', label='Center of Mass'))

    x = np.linspace(0, slice_array.shape[1], 100)
    y = np.linspace(0, slice_array.shape[0], 100)
    X, Y = np.meshgrid(x, y)

    if C == 0:
        # The plane does not depend on z: draw A*x + B*y + D = 0 directly.
        ax.contour(X, Y, A * X + B * Y + D, levels=[0], colors='red')
    else:
        # Solve the plane equation for z and draw its iso-line at the slice index.
        z_of_xy = (-A * X - B * Y - D) / C
        ax.contour(X, Y, z_of_xy, levels=[middle_slice_index], colors='red')

    # contour() does not reliably honour `label=`, so add a proxy legend entry.
    legend_handles.append(Line2D([0], [0], color='red', lw=2, label='Plane Contour'))
    ax.legend(handles=legend_handles)

    # Save before plt.show(), which may clear the current figure.
    if save_path:
        plt.savefig(os.path.join(save_path, f"plane_projection_optimized_patient_{pat}.png"))

    plt.show()

def plot_middle_slice_with_planes(image_data, plane_params_list, title='Middle Slice with Plane Projections', com=None, coronal=False, output_path=None,
                                  filename="middle_slice_with_planes.png"):
    """
    Plot the middle axial slice with the intersections of several planes.

    Parameters:
    image_data (numpy.ndarray): The 3D image data.
    plane_params_list (list): Plane parameters, each an (A, B, C, D) tuple or
        array for ``A*x + B*y + C*z + D = 0``. The inputs are not modified.
    title (str): The title of the plot.
    com (tuple, optional): Center of mass, drawn at ``(com[0], com[1])``.
    coronal (bool): If True, additionally show ``image_data`` with a 'bwr' colormap.
    output_path (str, optional): Directory to save the plot into as ``filename``.
    filename (str): File name used when ``output_path`` is given.
    """
    middle_slice_index = image_data.shape[2] // 2
    slice_array = image_data[:, :, middle_slice_index]

    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(slice_array, cmap='gray')
    ax.set_title(title)

    if com is not None:
        ax.scatter(com[0], com[1], color='blue', marker='x', label='Center of Mass')

    x = np.linspace(0, slice_array.shape[1], 100)
    y = np.linspace(0, slice_array.shape[0], 100)
    X, Y = np.meshgrid(x, y)

    for plane_coeffs in plane_params_list:
        # Normalise a float copy: an in-place `/=` would mutate the caller's
        # arrays and fails for tuples or integer arrays.
        coeffs = np.asarray(plane_coeffs, dtype=float)
        coeffs = coeffs / np.linalg.norm(coeffs[:3])
        A, B, C, D = coeffs

        if C == 0:
            # Plane independent of z: draw A*x + B*y + D = 0.
            ax.contour(X, Y, A * X + B * Y + D, levels=[0], colors='red')
        else:
            # Solve for z and draw the iso-line at the slice index. A separate
            # name keeps the slice index intact for the following planes.
            z_of_xy = (-A * X - B * Y - D) / C
            ax.contour(X, Y, z_of_xy, levels=[middle_slice_index], colors='red')
    plt.axis('off')

    # Save before plt.show(), which may clear the current figure.
    if output_path is not None:
        plt.savefig(os.path.join(output_path, filename))
    plt.show()

    if coronal:
        # NOTE(review): imshow() needs 2D or RGB(A) data; passing the full 3D
        # image_data here only works if its last axis has length 3 or 4 —
        # confirm which difference image was intended.
        plt.figure(figsize=(10, 10))
        plt.imshow(image_data, cmap='bwr')
        plt.colorbar(label='Difference Intensity')
        plt.title('Difference Image')
        plt.axis('off')
        plt.show()
def _draw_plane_trace(ax, A, B, C, D, x, X, Y, z_level, color, linestyle):
    """Draw the intersection of the plane A*x + B*y + C*z + D = 0 with the slice z = z_level."""
    if C != 0:
        # Solve for z and draw its iso-line at the slice index.
        ax.contour(X, Y, (-A * X - B * Y - D) / C, levels=[z_level], colors=color)
    elif B != 0:
        # Plane independent of z: the trace is the straight line y = (-A*x - D) / B.
        ax.plot(x, (-A * x - D) / B, color=color, linestyle=linestyle)
    elif A != 0:
        # Plane x = -D / A.
        ax.axvline(x=-D / A, color=color, linestyle=linestyle)
    # A == B == C == 0 is not a plane; nothing is drawn.


def create_gif_with_plane_projections(image_data, plane_params_list, output_path, pat=None, com=None, objective_values=None):
    """
    Create a GIF showing each plane projection on the middle slice of the image.

    The first three entries of each element of ``plane_params_list`` are
    converted by ``angles_to_vector`` into v = (A, B, C). The plane is
    ``v . p - |v|^2 = 0``: it has normal v and passes through the point v.
    Every frame shows the current plane in red and the first (initial) plane in
    green on the middle axial slice. Temporary PNG frames are written to
    ``output_path`` and removed afterwards, also if an error occurs.

    Parameters:
    image_data (numpy.ndarray): The 3D image data.
    plane_params_list (list): Per-iteration plane (angle) parameters.
    output_path (str): Directory in which the GIF
        ``plane_projections_patient_<pat>.gif`` is saved.
    pat (str): Patient identifier, if applicable.
    com (tuple): Center of mass, drawn at ``(com[0], com[1])``, if applicable.
    objective_values (list, optional): Objective value per frame, shown in the
        frame title when provided.
    """
    middle_slice_index = image_data.shape[2] // 2
    slice_array = image_data[:, :, middle_slice_index]

    # Sampling grid over the slice; identical for every frame.
    x = np.linspace(0, slice_array.shape[1], 100)
    y = np.linspace(0, slice_array.shape[0], 100)
    X, Y = np.meshgrid(x, y)

    def plane_from_params(params):
        A, B, C = angles_to_vector(params[0], params[1], params[2])
        D = -np.dot([A, B, C], [A, B, C])
        return A, B, C, D

    initial_plane = plane_from_params(plane_params_list[0])
    legend_elements = [
        Line2D([0], [0], color='red', lw=2, label='Optimized Plane'),
        Line2D([0], [0], color='green', lw=2, label='Initial Plane'),
    ]

    images = []
    temp_paths = []
    try:
        for idx, plane_params in enumerate(plane_params_list):
            fig, ax = plt.subplots(figsize=(5, 5))
            ax.imshow(slice_array, cmap='gray')
            xlim, ylim = ax.get_xlim(), ax.get_ylim()

            if com is not None:
                ax.scatter(com[0], com[1], color='blue', marker='x', label='Center of Mass')

            _draw_plane_trace(ax, *plane_from_params(plane_params), x, X, Y, middle_slice_index, 'red', '--')
            _draw_plane_trace(ax, *initial_plane, x, X, Y, middle_slice_index, 'green', '--')
            # Keep the view on the image even if a straight-line trace extends past it.
            ax.set_xlim(xlim)
            ax.set_ylim(ylim)

            ax.legend(handles=legend_elements, loc='upper right')
            frame_title = f'Plane Projection {idx + 1}'
            if objective_values is not None:
                frame_title += f' \n Objective Value: {objective_values[idx]:.2f}'
            ax.set_title(frame_title)

            temp_path = os.path.join(output_path, f"temp_plane_{idx}.png")
            temp_paths.append(temp_path)
            fig.savefig(temp_path)
            plt.close(fig)
            images.append(imageio.imread(temp_path))

        gif_path = os.path.join(output_path, f"plane_projections_patient_{pat}.gif")
        imageio.mimsave(gif_path, images, duration=0.5)
    finally:
        for temp_path in temp_paths:
            if os.path.exists(temp_path):
                os.remove(temp_path)


def create_gif_of_objective_updates(objective_values, output_path, gif_name='objective_updates.gif'):
    """
    Create a GIF showing how the objective value evolves over the iterations.

    Frame i (for i = 1 .. len(objective_values) - 1) plots objective_values[:i + 1]
    against the iteration index. Each frame is rendered to a temporary PNG in
    output_path, the PNGs are combined into output_path/gif_name (0.5 s per frame),
    and the temporary PNGs are deleted afterwards -- also when rendering or writing
    fails part-way, so no stray files are left behind.

    Parameters:
    objective_values (list): A list of objective values, one per iteration.
    output_path (str): Directory for the temporary PNGs and the resulting GIF.
    gif_name (str): The name of the GIF file.
    """
    images = []
    temp_paths = []
    try:
        for i in range(1, len(objective_values)):
            temp_path = os.path.join(output_path, f"temp_objective_{i}.png")
            # Register before writing so a partially written file is cleaned up too.
            temp_paths.append(temp_path)

            fig, ax = plt.subplots(figsize=(5, 5))
            ax.plot(range(i + 1), objective_values[:i + 1], marker='.', color='blue', label='Objective Value')
            ax.set_title(f'Iteration {i}')
            ax.set_xlabel('Iteration')
            ax.set_ylabel('Objective Value')
            ax.legend()

            # Save this figure explicitly rather than whatever pyplot considers "current".
            fig.savefig(temp_path)
            plt.close(fig)

            images.append(imageio.imread(temp_path))

        gif_path = os.path.join(output_path, gif_name)
        imageio.mimsave(gif_path, images, duration=0.5)
    finally:
        for temp_path in temp_paths:
            if os.path.exists(temp_path):
                os.remove(temp_path)

def save_image_gifs(image, pat_num, output_path, structure = 'image'):
    """
    Save the slices of a 3D image (along axis 2) as an animated GIF.

    Each slice image[:, :, i] is rendered in grayscale to a temporary PNG in
    output_path. The PNGs are combined into
    ``image_slices_patient_{pat_num}_{structure}.gif`` (duration=2 per frame), and
    the temporary PNGs are removed afterwards -- also when rendering or writing
    fails part-way.

    Parameters:
    image (numpy.ndarray): The 3D image data; axis 2 is iterated as the slice index.
    pat_num: Patient identifier inserted into the GIF file name.
    output_path (str): Directory for the temporary PNGs and the resulting GIF.
    structure (str): Label appended to the GIF file name. Default is 'image'.
    """
    images = []
    temp_paths = []
    try:
        for i in range(image.shape[2]):
            temp_path = os.path.join(output_path, f"temp_image_{i}.png")
            # Register before writing so a partially written file is cleaned up too.
            temp_paths.append(temp_path)

            fig, ax = plt.subplots(figsize=(5, 5))
            ax.imshow(image[:, :, i], cmap='gray')
            ax.set_title(f'Image Slice {i}')
            ax.axis('off')

            fig.savefig(temp_path)
            plt.close(fig)

            images.append(imageio.imread(temp_path))

        gif_path = os.path.join(output_path, f"image_slices_patient_{pat_num}_{structure}.gif")
        imageio.mimsave(gif_path, images, duration=2)
    finally:
        for temp_path in temp_paths:
            if os.path.exists(temp_path):
                os.remove(temp_path)



def plot_slice_with_planes(image, body_mask, gtv_mask, mandible_mask, spinal_cord_mask,
                           plane_coeffs_list, optimization_methods_list, slice_index,
                           orientation='axial', save_path=None, filename=None):
    """
    Plot a single slice of the 3D image with overlaid masks and plane contours, and optionally save it.

    Coordinate convention (as used for the plane coefficients): x = axis 1 (columns),
    y = axis 0 (rows), z = axis 2 (axial slice index).

    Parameters:
        image (numpy.ndarray): The 3D image data.
        body_mask (numpy.ndarray): The 3D body mask; used only for the coronal y-limits.
        gtv_mask (numpy.ndarray): The 3D GTV mask (drawn in yellow).
        mandible_mask (numpy.ndarray): The 3D mandible mask (drawn in blue).
        spinal_cord_mask (numpy.ndarray): The 3D spinal cord mask (drawn in green).
        plane_coeffs_list (list): Plane coefficients (A, B, C, D) of A*x + B*y + C*z + D = 0,
            one entry per plane. Colors cycle through red, purple, cyan.
        optimization_methods_list (list): Legend labels for the planes, in the same order.
        slice_index (int): Index of the slice to plot (along axis 2 for 'axial', axis 0 for 'coronal').
        orientation (str): 'axial' or 'coronal' (case-insensitive). Default is 'axial'.
        save_path (str, optional): Directory to save the figure into. If omitted, the figure is shown.
        filename (str, optional): File name inside save_path. Defaults to
            'slice_{slice_index}_{orientation}.png' when save_path is given.

    Raises:
        ValueError: If orientation is neither 'axial' nor 'coronal'.
    """
    view = orientation.lower()
    if view not in ('axial', 'coronal'):
        raise ValueError(f"orientation must be 'axial' or 'coronal', got {orientation!r}")

    plane_colors = ['red', 'purple', 'cyan']

    def plane_color(index):
        # Cycle so that more than three planes do not raise an IndexError.
        return plane_colors[index % len(plane_colors)]

    plt.figure(figsize=(10, 10))

    if view == 'axial':
        plt.imshow(image[:, :, slice_index], cmap='gray')
        plt.contour(gtv_mask[:, :, slice_index], colors='yellow', linewidths=1)
        plt.contour(mandible_mask[:, :, slice_index], colors='blue', linewidths=1)
        plt.contour(spinal_cord_mask[:, :, slice_index], colors='green', linewidths=1)

        # In-plane grid: horizontal = x (axis 1), vertical = y (axis 0).
        X, Y = np.meshgrid(np.linspace(0, image.shape[1], 100), np.linspace(0, image.shape[0], 100))
        for plane_index, (A, B, C, D) in enumerate(plane_coeffs_list):
            C = C if C != 0 else 1e-6  # avoid division by zero
            # z-value of the plane at every (x, y); the level z = slice_index is the intersection line.
            contour = (-A * X - B * Y - D) / C
            plt.contour(X, Y, contour, levels=[slice_index], colors=plane_color(plane_index), linewidths=1)
        legend_fontsize = 18
    else:
        plt.imshow(image[slice_index, :, :], cmap='gray')
        plt.contour(gtv_mask[slice_index, :, :], colors='yellow', linewidths=1)
        plt.contour(mandible_mask[slice_index, :, :], colors='blue', linewidths=1)
        plt.contour(spinal_cord_mask[slice_index, :, :], colors='green', linewidths=1)

        # The displayed slice is image[slice_index] with shape (axis 1, axis 2):
        # vertical = x (axis 1, hence shape[1]), horizontal = z (axis 2).
        X, Z = np.meshgrid(np.linspace(0, image.shape[1], 100), np.linspace(0, image.shape[2], 100))
        for plane_index, (A, B, C, D) in enumerate(plane_coeffs_list):
            B = B if B != 0 else 1e-6  # avoid division by zero
            # y-value of the plane at every (x, z); the level y = slice_index is the intersection line.
            contour = (-A * X - C * Z - D) / B
            plt.contour(Z, X, contour, levels=[slice_index], colors=plane_color(plane_index), linewidths=1)
        legend_fontsize = 10

    legend_handles = [
        Line2D([0], [0], color='yellow', lw=2, label='Primary GTV'),
        Line2D([0], [0], color='blue', lw=2, label='Mandible'),
        Line2D([0], [0], color='green', lw=2, label='Spinal Cord'),
    ]
    # One entry per plane that has a label (no IndexError when fewer labels than planes).
    for plane_index, method in enumerate(optimization_methods_list[:len(plane_coeffs_list)]):
        legend_handles.append(Line2D([0], [0], color=plane_color(plane_index), lw=2, label=method))
    plt.legend(handles=legend_handles, loc='upper right', fontsize=legend_fontsize)

    if view == 'coronal':
        # Restrict the vertical axis to the body's extent.
        # NOTE(review): the range is taken along axis 0 while the vertical axis shows axis 1 -- confirm intent.
        start_slice, end_slice = get_nonzero_slice_range(body_mask, slice_dir_indx=0)
        plt.ylim(start_slice, end_slice)

    plt.axis('off')

    if save_path:
        if filename is None:
            filename = f"slice_{slice_index}_{view}.png"
        plt.savefig(os.path.join(save_path, filename), bbox_inches='tight', pad_inches=0)
    else:
        plt.show()

    plt.close()


## Scrollable widget
Display interactive, scrollable widgets that step through the axial and coronal slices of the 3D head CT, overlaying the projection of the midplane(s) onto each slice together with the primary GTV, mandible and spinal cord contours.

In [None]:


def display_scrollable_slices(image, gtv_mask, body_mask=None, mandible_mask=None, spinal_cord_mask=None,
                              plane_coeffs_list=None, optimization_methods_list=None):
    """
    Display an interactive widget to scroll through slices of the 3D image (in two views) with
    overlays of masks and plane contours.

    Coordinate convention (as used for the plane coefficients): x = axis 1 (columns),
    y = axis 0 (rows), z = axis 2 (axial slice index).

    Parameters
    ----------
    image : numpy.ndarray
        The 3D image data.
    gtv_mask : numpy.ndarray
        The 3D primary GTV mask (mandatory).
    body_mask : numpy.ndarray, optional
        The 3D body mask. Also restricts the coronal slider range.
    mandible_mask : numpy.ndarray, optional
        The 3D mandible mask.
    spinal_cord_mask : numpy.ndarray, optional
        The 3D spinal cord mask.
    plane_coeffs_list : list, optional
        Plane coefficients (A, B, C, D), one per plane. Defaults to no planes.
    optimization_methods_list : list, optional
        Method names corresponding to each plane, used for the legend. Defaults to no labels.
    """
    # None sentinels instead of shared mutable [] defaults.
    if plane_coeffs_list is None:
        plane_coeffs_list = []
    if optimization_methods_list is None:
        optimization_methods_list = []

    num_slices = image.shape[2]
    plane_colors = ['red', 'purple', 'cyan']

    def plane_color(idx):
        return plane_colors[idx % len(plane_colors)]

    def legend_handles():
        # Legend entries for the masks that are shown plus one per labelled plane.
        handles = [Line2D([0], [0], color='yellow', lw=2, label='GTVp')]
        if body_mask is not None:
            handles.append(Line2D([0], [0], color='orange', lw=2, label='Body'))
        if mandible_mask is not None:
            handles.append(Line2D([0], [0], color='blue', lw=2, label='Mandible'))
        if spinal_cord_mask is not None:
            handles.append(Line2D([0], [0], color='green', lw=2, label='Spinal Cord'))
        for idx, method in enumerate(optimization_methods_list[:len(plane_coeffs_list)]):
            handles.append(Line2D([0], [0], color=plane_color(idx), lw=2, label=method))
        return handles

    def view_slice_axial(slice_index):
        fig, ax = plt.subplots(figsize=(10, 10))
        ax.imshow(image[:, :, slice_index], cmap='gray', interpolation='none')

        # Mandatory GTV overlay, then the optional masks.
        ax.contour(gtv_mask[:, :, slice_index], colors='yellow', linewidths=1)
        if body_mask is not None:
            ax.contour(body_mask[:, :, slice_index], colors='orange', linewidths=1)
        if mandible_mask is not None:
            ax.contour(mandible_mask[:, :, slice_index], colors='blue', linewidths=1)
        if spinal_cord_mask is not None:
            ax.contour(spinal_cord_mask[:, :, slice_index], colors='green', linewidths=1)

        # In-plane grid: horizontal = x (axis 1), vertical = y (axis 0).
        X, Y = np.meshgrid(np.linspace(0, image.shape[1], 100), np.linspace(0, image.shape[0], 100))
        for idx, (A, B, C, D) in enumerate(plane_coeffs_list):
            C_val = C if C != 0 else 1e-6  # avoid division by zero
            contour = (-A * X - B * Y - D) / C_val
            ax.contour(X, Y, contour, levels=[slice_index], colors=plane_color(idx), linewidths=1)

        ax.legend(handles=legend_handles(), loc='upper right')

        # Status-bar text: voxel indices plus the mask values under the cursor.
        def format_coord(x, y):
            col = int(round(x))
            row = int(round(y))
            info = f"x={col}, y={row}, slice={slice_index}"
            if 0 <= row < gtv_mask.shape[0] and 0 <= col < gtv_mask.shape[1]:
                info += f", GTVp={gtv_mask[row, col, slice_index]}"
            if body_mask is not None and 0 <= row < body_mask.shape[0] and 0 <= col < body_mask.shape[1]:
                info += f", Body={body_mask[row, col, slice_index]}"
            if mandible_mask is not None and 0 <= row < mandible_mask.shape[0] and 0 <= col < mandible_mask.shape[1]:
                info += f", Mandible={mandible_mask[row, col, slice_index]}"
            if spinal_cord_mask is not None and 0 <= row < spinal_cord_mask.shape[0] and 0 <= col < spinal_cord_mask.shape[1]:
                info += f", Spinal Cord={spinal_cord_mask[row, col, slice_index]}"
            return info
        ax.format_coord = format_coord

        ax.axis('off')
        plt.show()

    # Coronal slider range along axis 0: the body's extent if a body mask is given.
    if body_mask is not None:
        start_slice, end_slice = get_nonzero_slice_range(body_mask, slice_dir_indx=0)
    else:
        start_slice = 0
        end_slice = image.shape[0] - 1

    def view_slice_coronal(slice_index):
        fig, ax = plt.subplots(figsize=(10, 10))
        ax.imshow(image[slice_index, :, :], cmap='gray', interpolation='none')
        ax.contour(gtv_mask[slice_index, :, :], colors='yellow')
        if body_mask is not None:
            ax.contour(body_mask[slice_index, :, :], colors='orange')
        if mandible_mask is not None:
            ax.contour(mandible_mask[slice_index, :, :], colors='blue')
        if spinal_cord_mask is not None:
            ax.contour(spinal_cord_mask[slice_index, :, :], colors='green')

        # The displayed slice has shape (axis 1, axis 2): vertical = x (axis 1, hence
        # shape[1]), horizontal = z (axis 2).
        X, Z = np.meshgrid(np.linspace(0, image.shape[1], 100), np.linspace(0, image.shape[2], 100))
        for idx, (A, B, C, D) in enumerate(plane_coeffs_list):
            B_val = B if B != 0 else 1e-6  # avoid division by zero
            contour = (-A * X - C * Z - D) / B_val
            ax.contour(Z, X, contour, levels=[slice_index], colors=plane_color(idx))

        ax.legend(handles=legend_handles(), loc='upper right')
        ax.set_ylim(start_slice, end_slice)
        ax.axis('off')
        plt.show()

    # widgets.interact displays its controls itself and returns the wrapped function;
    # wrapping it in display() would additionally print the function's repr.
    slice_slider_axial = widgets.IntSlider(min=0, max=num_slices-1, step=1, value=num_slices//2, description='Axial Slice')
    widgets.interact(view_slice_axial, slice_index=slice_slider_axial)

    slice_slider_coronal = widgets.IntSlider(min=start_slice, max=end_slice, step=1, value=(start_slice+end_slice)//2, description='Coronal Slice')
    widgets.interact(view_slice_coronal, slice_index=slice_slider_coronal)



def display_scrollable_slices_params_slider(image, body_mask, gtv_mask, mandible_mask, spinal_cord_mask, 
                              bone, soft_tissue, image_plot, pat, interpolation, interpolation_method, interpolator,
                              plane_params_init = None):
    """
    Interactive axial viewer with a hand-tunable plane.

    Sliders control the axial slice index, the plane angles theta and phi (in degrees,
    converted with np.deg2rad before angles_to_vector) and the plane distance L. For each
    setting the plane A*x + B*y + C*z + D = 0 with (A, B, C) = angles_to_vector(theta, phi, L)
    and D = -(A^2 + B^2 + C^2) is drawn in red over the slice, together with the GTV
    (yellow), mandible (blue) and spinal cord (green) contours.

    Parameters:
    image (numpy.ndarray): 3D image; axis 2 indexes the axial slices.
    body_mask, bone: Accepted for interface compatibility; not used by this function.
    gtv_mask, mandible_mask, spinal_cord_mask (numpy.ndarray): 3D masks drawn as contours.
    soft_tissue, image_plot, pat, interpolation, interpolation_method, interpolator:
        Handed to the per-slice renderer as fixed widget values; only the (currently
        disabled) objective-value computation would use them.
    plane_params_init (tuple, optional): Starting plane (A, B, C, D); defaults to
        (1, 0, 0, -image.shape[0] // 2).
    """
    if plane_params_init is None:
        plane_params_init = (1,0,0,-image.shape[0]//2)

    A0, B0, C0, D0 = plane_params_init
    start_vec = np.array([A0, B0, C0])
    if np.linalg.norm(start_vec) < 2:
        # A short vector is read as a unit normal and stretched to length |D|.
        start_vec = start_vec * np.abs(D0)
    theta_start, phi_start, L_start = vector_to_angles(start_vec)

    # L may range from half an image height below to a full image height above its start value.
    L_lo = L_start - image.shape[0] // 2
    L_hi = L_start + image.shape[0]
    n_axial = image.shape[2]
    mask_overlays = ((gtv_mask, 'yellow'), (mandible_mask, 'blue'), (spinal_cord_mask, 'green'))

    def render(slice_index, theta, phi, L, image, soft_tissue, image_plot, pat,
            interpolation=interpolation, interpolation_method = interpolation_method, interpolator = interpolator):
        plt.figure(figsize=(10, 10))
        plt.imshow(image[:, :, slice_index], cmap='gray')
        for mask, colour in mask_overlays:
            plt.contour(mask[:, :, slice_index], colors=colour, linewidths=1)

        # Slider angles are in degrees; the plane passes through its own foot point (A, B, C).
        A, B, C = angles_to_vector(np.deg2rad(theta), np.deg2rad(phi), L)
        foot = np.array([A, B, C])
        D = -np.dot(foot, foot)

        if C == 0:
            C = 1e-6  # keep the division below finite
        X, Y = np.meshgrid(np.linspace(0, image.shape[1], 100), np.linspace(0, image.shape[0], 100))
        plt.contour(X, Y, (-A * X - B * Y - D) / C, levels=[slice_index], colors='red', linewidths=1)

        # The objective value is not computed here (it was disabled in the title).
        plt.legend(handles=[
            Line2D([0], [0], color='yellow', lw=2, label='GTV'),
            Line2D([0], [0], color='blue', lw=2, label='Mandible'),
            Line2D([0], [0], color='green', lw=2, label='Spinal Cord'),
            Line2D([0], [0], color='red', lw=2, label='Custom Plane'),
        ], loc='upper right')
        plt.title(f'Slice {slice_index}')
        plt.axis('off')
        plt.show()

    controls = dict(
        slice_index=widgets.IntSlider(min=0, max=n_axial - 1, step=1, value=n_axial // 2, description='Axial Slice'),
        theta=widgets.FloatSlider(min=-90, max=90, step=1, value=np.rad2deg(theta_start), description='θ (xy-plane)'),
        phi=widgets.FloatSlider(min=-90, max=90, step=1, value=np.rad2deg(phi_start), description='φ (xz-plane)'),
        L=widgets.FloatSlider(min=L_lo, max=L_hi, step=1, value=L_start, description='L'),
    )
    pass_through = {
        name: widgets.fixed(value)
        for name, value in (
            ('image', image),
            ('soft_tissue', soft_tissue),
            ('image_plot', image_plot),
            ('pat', pat),
            ('interpolation', interpolation),
            ('interpolation_method', interpolation_method),
            ('interpolator', interpolator),
        )
    }

    display(widgets.interactive(render, **controls, **pass_through))  # renders in Jupyter


def display_scrollable_image_with_values(image, title='Image with Hover Values'):
    """
    Display a slider over the slices (axis 2) of a 3D image; hovering shows pixel values.

    The status bar shows the column (x), row (y), slice (z) and, inside the image,
    the intensity image[y, x, z].

    Parameters:
    image (numpy.ndarray): The 3D image data.
    title (str): Title of the plotted slice.
    """
    num_slices = image.shape[2]
    n_rows, n_cols = image.shape[0], image.shape[1]

    def view_slice(slice_index):
        fig, ax = plt.subplots(figsize=(10, 10))
        im = ax.imshow(image[:, :, slice_index], cmap='gray')
        fig.colorbar(im, ax=ax, label='Intensity')

        def format_coord(x, y):
            # Pixel c spans [c - 0.5, c + 0.5). floor() is needed because int() truncates
            # toward zero and would map e.g. x = -0.7 onto column 0.
            col = int(np.floor(x + 0.5))
            row = int(np.floor(y + 0.5))
            # Bound before the branch; the out-of-bounds message needs it too.
            z = slice_index
            if 0 <= col < n_cols and 0 <= row < n_rows:
                return f"x={col}, y={row}, z={z}, value={image[row, col, z]:.2f}"
            return f"x={col}, y={row}, z={z}"

        ax.format_coord = format_coord
        ax.set_title(title)
        plt.show()

    slice_slider = widgets.IntSlider(min=0, max=num_slices - 1, step=1, description='Slice')
    widgets.interact(view_slice, slice_index=slice_slider)



def verify_plane_params_rotated(image, soft_tissue, image_plot, pat, interpolation, interpolation_method, interpolator, output_path, verification_list, 
                                obj_val_update, body_mask, optimization_methods_list):
    """
    Evaluate and visualise verification planes obtained by rotating the x-normal plane.

    For every (theta_deg, phi_deg) pair in verification_list:
    - the unit vector (1, 0, 0) is rotated by the rotation matrix built from the two angles;
    - the plane with that normal through the body-mask centre of mass is formed;
    - the plane is converted to (theta, phi, L) with vector_to_angles and evaluated with
      objective_function (plot_differences=True, plot written to
      output_path/verification_plots), which returns the MSE and a difference image;
    - two interactive sliders (axial and coronal) show that difference image with the
      contours of all verification planes built up to and including the current one.

    Parameters:
    image, soft_tissue, image_plot, pat, interpolation, interpolation_method, interpolator,
    obj_val_update: Forwarded unchanged to objective_function. image also defines the plotting
        grid (axis 0 = y rows, axis 1 = x columns, axis 2 = axial slices).
    output_path (str): Base directory; plots go to its 'verification_plots' subfolder.
    verification_list (iterable of (float, float)): (theta, phi) pairs in degrees.
    body_mask (numpy.ndarray): 3D body mask used for the centre of mass and the coronal range.
    optimization_methods_list (list of str): Legend labels for the planes, in order.
    """
    verification_path = os.path.join(output_path, 'verification_plots')
    os.makedirs(verification_path, exist_ok=True)
    com = calculate_center_of_mass(body_mask)
    # Reorder to the (x = column, y = row, z = slice) order used by the plane coefficients.
    com = np.array([com[1], com[0], com[2]])
    param_vec_norm = np.array([1, 0, 0])

    # Loop-invariant plotting settings.
    num_slices = image.shape[2]
    plane_colors = ['red', 'purple', 'cyan']
    start_slice, end_slice = get_nonzero_slice_range(body_mask, slice_dir_indx=0)

    def plane_color(index):
        return plane_colors[index % len(plane_colors)]

    def legend_handles(planes):
        return [Line2D([0], [0], color=plane_color(i), lw=2, label=method)
                for i, method in enumerate(optimization_methods_list[:len(planes)])]

    # Factories bind this iteration's data explicitly. A plain closure over the loop
    # variables would make every earlier slider redraw the *last* plane's results.
    def make_axial_viewer(diff_image, planes, mse, theta_deg, phi_deg):
        def view_slice_with_plane_z(slice_index):
            plt.figure(figsize=(10, 10))
            plt.imshow(diff_image[:, :, slice_index], cmap='hot')
            plt.colorbar(label='Difference Intensity')

            # In-plane grid: horizontal = x (axis 1), vertical = y (axis 0).
            X, Y = np.meshgrid(np.linspace(0, image.shape[1], 100), np.linspace(0, image.shape[0], 100))
            for plane_index, (A, B, C, D) in enumerate(planes):
                C = C if C != 0 else 1e-6  # Avoid divide by zero
                contour = (-A * X - B * Y - D) / C
                plt.contour(X, Y, contour, levels=[slice_index], colors=plane_color(plane_index), linewidths=1)

            plt.scatter(com[0], com[1], color='blue', marker='x', label='Center of Mass')
            handles = legend_handles(planes)
            # The scatter label is only shown if it is part of the explicit handle list.
            handles.append(Line2D([0], [0], color='blue', marker='x', linestyle='None', label='Center of Mass'))
            plt.legend(handles=handles, loc='upper left')
            plt.title(f'Slice {slice_index} \n MSE: {mse:.4f} \n Theta: {theta_deg}°, Phi: {phi_deg}°')
            plt.axis('off')
            plt.show()
        return view_slice_with_plane_z

    def make_coronal_viewer(diff_image, planes, mse, theta_deg, phi_deg):
        def view_slice_with_plane_x(slice_index):
            plt.figure(figsize=(10, 10))
            plt.imshow(diff_image[slice_index, :, :], cmap='hot')
            plt.colorbar(label='Difference Intensity')

            # Displayed slice has shape (axis 1, axis 2): vertical = x (axis 1, hence
            # shape[1]), horizontal = z (axis 2).
            X, Z = np.meshgrid(np.linspace(0, image.shape[1], 100), np.linspace(0, image.shape[2], 100))
            for plane_index, (A, B, C, D) in enumerate(planes):
                B = B if B != 0 else 1e-6
                contour = (-A * X - C * Z - D) / B
                plt.contour(Z, X, contour, levels=[slice_index], colors=plane_color(plane_index))

            plt.legend(handles=legend_handles(planes), loc='upper right')
            plt.title(f'Slice {slice_index} \n MSE: {mse:.4f} \n Theta: {theta_deg}°, Phi: {phi_deg}°')
            plt.ylim(start_slice, end_slice)
            plt.show()
        return view_slice_with_plane_x

    plane_params_list = []

    for theta_deg, phi_deg in verification_list:
        theta = np.deg2rad(theta_deg)
        phi = np.deg2rad(phi_deg)

        rotation_matrix = np.array([
                [np.cos(theta) * np.cos(phi), -np.sin(theta), np.cos(theta) * np.sin(phi)],
                [np.sin(theta) * np.cos(phi), np.cos(theta), np.sin(theta) * np.sin(phi)],
                [-np.sin(phi), 0, np.cos(phi)]
            ])

        # Rotated unit normal; the plane is placed through the centre of mass.
        normal = np.dot(rotation_matrix, param_vec_norm)
        D = -np.dot(normal, com)
        plane_params_list.append((normal[0], normal[1], normal[2], D))

        # The foot point of the plane normal . p + D = 0 is -D * normal. (Scaling by |D|
        # instead would yield the mirrored plane whenever D > 0.)
        foot_point = -D * normal
        theta_sph, phi_sph, l = vector_to_angles(foot_point)
        rot_params = np.array([theta_sph, phi_sph, l])

        mse, diff_image = objective_function(
            rot_params, image, soft_tissue, image_plot, pat,
            interpolation=interpolation, obj_val_update=obj_val_update, interpolation_method=interpolation_method, interpolator=interpolator, 
            plot_differences=True, output_path=os.path.join(verification_path, f"verification_plane_{theta_deg}_{phi_deg}.png"))

        # Snapshot: each viewer keeps showing the planes that existed when it was created.
        planes_so_far = list(plane_params_list)

        slice_slider = widgets.IntSlider(min=0, max=num_slices - 1, step=1, description='Axial Slice')
        widgets.interact(make_axial_viewer(diff_image, planes_so_far, mse, theta_deg, phi_deg),
                         slice_index=slice_slider)

        slice_slider_x = widgets.IntSlider(min=start_slice, max=end_slice, step=1, description='Coronal Slice')
        widgets.interact(make_coronal_viewer(diff_image, planes_so_far, mse, theta_deg, phi_deg),
                         slice_index=slice_slider_x)


            # plt.imshow(mse_array, extent=[np.rad2deg(thetas[0]), np.rad2deg(thetas[-1]), 
            #                         np.rad2deg(phis[0]), np.rad2deg(phis[-1])], 
            #     aspect='auto', origin='lower', cmap='viridis', vmin = vmin, vmax = vmax_initial)
            # plt.colorbar(label='Mean Squared Error')
            # plt.scatter(theta_deg, phi_deg, color='red', marker='x', s=100, label='Verification Plane')
            # plt.title('MSE vs. Polar and Azimuthal Angles')
            # plt.xlabel('Polar Angle (θ)°')
            # plt.ylabel('Azimuthal Angle (φ)°')

            # # Plot a red cross at the bin center
            # #plt.scatter(min_theta, min_phi, color='red', marker='x', s=100, label='Min MSE')
            # plt.legend()
            # plt.savefig(os.path.join(verification_path, f"verification_mse_{theta_deg}_{phi_deg}.png"))
            # plt.close()


## Parametrization $(A,B,C) \rightarrow (\theta, \phi, L)$
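Re-parametrize a plane from its normal vector $\vec{n}=(A,B,C)$, which points from the origin to the plane, to two angles and a distance:
- $\theta$, the azimuth of $\vec{n}$ in the xy-plane, measured from the x-axis;
- $\phi$, the tilt of $\vec{n}$ out of the xy-plane towards the z-axis;
- $L = \|\vec{n}\|$, the distance from the origin to the plane along the normal direction $\hat{\vec{n}}$.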


In [None]:

def vector_to_angles(vector):
    """
    Convert a vector (x, y, z) into the angle parametrization (angle_xy, angle_xz, distance).

    This is the exact inverse of ``angles_to_vector``, which builds
    x = d*cos(angle_xy)*cos(angle_xz), y = d*sin(angle_xy)*cos(angle_xz), z = d*sin(angle_xz).
    Hence ``angles_to_vector(*vector_to_angles(v))`` reproduces ``v`` up to rounding.

    Parameters:
    vector (tuple): A tuple representing the vector (x, y, z), e.g. the plane foot point (A, B, C).

    Returns:
    tuple: (angle_xy, angle_xz, distance), where
        angle_xy is the azimuth in radians (angle of the projection onto the xy-plane,
            measured from the x-axis), in (-pi, pi];
        angle_xz is the elevation in radians (angle between the vector and the xy-plane,
            positive towards +z), in [-pi/2, pi/2];
        distance is the Euclidean length of the vector.
    """
    x, y, z = vector

    distance = np.sqrt(x**2 + y**2 + z**2)

    angle_xy = np.arctan2(y, x)

    # Elevation out of the xy-plane, matching z = distance * sin(angle_xz) in
    # angles_to_vector. arctan2(z, x) only agrees with that when y == 0, so the
    # angle round trip did not reproduce the original plane.
    angle_xz = np.arctan2(z, np.hypot(x, y))

    return angle_xy, angle_xz, distance

def angles_to_vector(angle_xy, angle_xz, distance):
    """
    Build the vector (x, y, z) of length ``distance`` from its two angles.

    Parameters:
    angle_xy (float): Azimuth in radians -- angle of the projection onto the xy-plane,
        measured from the x-axis.
    angle_xz (float): Elevation in radians -- tilt of the vector out of the xy-plane
        towards +z.
    distance (float): Length of the vector (distance of the plane from the origin).

    Returns:
    tuple: The vector components (x, y, z).
    """
    # cos(elevation) scales the in-plane (xy) part; sin(elevation) gives the z part.
    horizontal = np.cos(angle_xz)
    return (
        distance * np.cos(angle_xy) * horizontal,
        distance * np.sin(angle_xy) * horizontal,
        distance * np.sin(angle_xz),
    )

def plane_params_rescale(plane_params, scaling_factor):
    """
    Scale a plane given as (A, B, C, D) by multiplying its foot-point vector.

    The vector (A, B, C) is taken as the foot point of the plane. If its norm is
    below 2 it is treated as a unit normal and first stretched to length |D|. The
    scaled foot point defines the new plane, whose offset is recomputed as
    D = -|(A, B, C)|^2.

    Parameters:
    plane_params (tuple): The original plane parameters (A, B, C, D).
    scaling_factor (float): Factor applied to the foot-point vector.

    Returns:
    tuple: The rescaled plane parameters (A_rescaled, B_rescaled, C_rescaled, D_rescaled).
    """
    A, B, C, D = plane_params
    foot = np.array([A, B, C])
    foot = foot * abs(D) if np.linalg.norm(foot) < 2 else foot

    scaled = foot * scaling_factor
    A_new, B_new, C_new = scaled
    return A_new, B_new, C_new, -np.dot(scaled, scaled)
    

## Objective function
Given a 3D CT image, this computes the mean squared error between the intensities of the voxels in the image and those of their mirror voxels with respect to the plane. Each mirror voxel is obtained by reflecting the voxel across the plane, i.e. by moving it twice its signed distance to the plane along the negative normal direction.

$$
f(A, B, C) = \sum_{i} \big(I(\vec{x}_i) - I(\vec{x}_i^m)\big)^2
$$

The mirrored voxel $\vec{x}_i^m$ is defined as:
$$
\vec{x}_i^m = \vec{x}_i - 2 \cdot d \cdot \hat{\vec{n}},
$$

where:

- $d$ is the signed distance from the point $\vec{x}_i = (x_i, y_i, z_i)$ to the plane, computed as:
  $$
  d = \frac{A x_i + B y_i + C z_i + D}{\sqrt{A^2 + B^2 + C^2}}.
  $$

- $\hat{\vec{n}}$ is the normalized normal vector of the plane, defined as:
  $$
  \hat{\vec{n}} = \frac{\vec{n}}{\|\vec{n}\|}, \quad \text{with } \vec{n} = (A, B, C), \text{ and } \|\vec{n}\| = \sqrt{A^2 + B^2 + C^2}.
  $$

- $D$ is computed from the plane equation $A x + B y + C z + D = 0$ by substituting the point $\vec{n} = (A, B, C)$, which lies on the plane by construction:
  $$
  D = -(A^2 + B^2 + C^2).
  $$

Substituting $D$ into $d$, the mirrored voxel can also be expressed as:
$$
\begin{aligned}
\vec{x}_i^m &= \vec{x}_i - 2 \cdot \left(\frac{A x_i + B y_i + C z_i}{\sqrt{A^2 + B^2 + C^2}} - \sqrt{A^2 + B^2 + C^2}\right) \cdot \hat{\vec{n}} \\
&= \vec{x}_i - 2 \cdot \left(\frac{A x_i + B y_i + C z_i}{A^2 + B^2 + C^2} - 1 \right) \cdot \vec{n}
\end{aligned}
$$

$$
f(A, B, C) = \sum_{i} \Big(I(\vec{x}_i) - I\big(\vec{x}_i - 2 \cdot \left(\frac{A x_i + B y_i + C z_i}{A^2 + B^2 + C^2} - 1 \right) \cdot \vec{n}\big)\Big)^2
$$







In [None]:
def real_params(plane_params, voxel_size):
    """
    Convert a plane from voxel coordinates to physical (voxel-size scaled) coordinates.

    The voxel-space plane is given by its foot point n = (A, B, C), with
    A*x + B*y + C*z + D = 0 and D = -|n|^2. If |n| < 2, (A, B, C) is treated as a unit
    normal and stretched to length |D|.

    Under the per-axis scaling p_real = p * s, the point n * s lies on the transformed
    plane. The normal, however, transforms inversely, as n / s. Using n * s as the
    normal (as before) tilts the plane whenever the voxels are anisotropic.

    Parameters:
    plane_params (tuple): Plane (A, B, C, D) in voxel coordinates.
    voxel_size (sequence of 3 floats): Scale factor per coordinate, applied element-wise to (A, B, C).

    Returns:
    tuple: (A_real, B_real, C_real, D_real) with a unit-length normal, so that
        A_real*x + B_real*y + C_real*z + D_real is the signed distance in physical units.
    """
    A, B, C, D = plane_params
    foot_vox = np.array([A, B, C], dtype=float)
    if np.linalg.norm(foot_vox) < 2:
        foot_vox = foot_vox * abs(D)

    spacing = np.asarray(voxel_size, dtype=float)
    foot_real = foot_vox * spacing

    # The normal transforms with the inverse scaling.
    normal_real = foot_vox / spacing
    normal_real = normal_real / np.linalg.norm(normal_real)

    # Offset chosen so that the transformed foot point lies on the plane.
    D_real_normalized = -np.dot(normal_real, foot_real)

    return (normal_real[0], normal_real[1], normal_real[2], D_real_normalized)

def compute_distances_and_indices(mask, plane_coeffs, plane_coeffs_real=None, voxel_size=None):
    """
    Signed distances from every non-zero voxel of ``mask`` to a plane (vectorised).

    Voxel positions come from ``np.argwhere(mask != 0)`` in array order (row, column, slice)
    and are mapped to coordinate order x = column, y = row, z = slice before evaluating
    (A*x + B*y + C*z + D) / |(A, B, C)|.

    Parameters:
    mask (numpy.ndarray): 3D array; every non-zero entry is a voxel to process.
    plane_coeffs (tuple): (A, B, C, D) of the plane in voxel coordinates.
    plane_coeffs_real (tuple, optional): (A, B, C, D) of the plane in physical coordinates.
    voxel_size (tuple, optional): Per-axis voxel size; x, y and z are multiplied by
        voxel_size[0], voxel_size[1] and voxel_size[2] respectively.

    Returns:
    tuple: If either real-space argument is missing, ``(indices, distances)`` where ``indices``
    is a list of (x, y, z) tuples in coordinate order and ``distances`` a list of floats.
    If both are given, ``(indices, distances, distances_real)`` where -- unlike the two-value
    form -- ``indices`` is a list of [row, column, slice] lists in array order.
    """
    voxel_idx = np.argwhere(mask != 0)
    row, col, sl = voxel_idx[:, 0], voxel_idx[:, 1], voxel_idx[:, 2]

    a, b, c, d = plane_coeffs
    dist = (a * col + b * row + c * sl + d) / np.sqrt(a**2 + b**2 + c**2)

    if voxel_size is None or plane_coeffs_real is None:
        return list(zip(col, row, sl)), dist.tolist()

    # NOTE(review): this branch returns array-order indices, unlike the coordinate-order
    # tuples above; kept unchanged for existing callers.
    ar, br, cr, dr = plane_coeffs_real
    col_phys = col * voxel_size[0]
    row_phys = row * voxel_size[1]
    sl_phys = sl * voxel_size[2]
    dist_real = (ar * col_phys + br * row_phys + cr * sl_phys + dr) / np.sqrt(ar**2 + br**2 + cr**2)
    return voxel_idx.tolist(), dist.tolist(), dist_real.tolist()

def compute_squared_intensity_difference(bone, soft_tissue, indices, distances, plane_params, interpolation='bone', mirror_plots = False, interpolation_method='linear'):
    """
    Compute the squared intensity difference between original and mirror voxels across a plane in a 3D image.

    Each voxel (coordinate order x, y, z) is reflected via x_m = x - 2 * d * n_hat, where d is its
    signed distance and n_hat the unit normal of (A, B, C). Original intensities are read from
    ``bone``; mirror intensities are interpolated at the reflected positions.

    Parameters:
    bone (ndarray): 3D numpy array representing the bone intensity values.
    soft_tissue (ndarray): 3D numpy array representing the soft tissue intensity values.
    indices (list of tuples): List of voxel indices (x, y, z) to be mirrored.
    distances (list of floats): List of distances from the plane for each voxel.
    plane_params (tuple): Parameters (A, B, C, D) of the plane equation Ax + By + Cz + D = 0.
    interpolation (str, optional): Image sampled at mirror positions: 'bone' (bone only) or
        'full' (bone + soft_tissue). Defaults to 'bone'.
    mirror_plots (bool, optional): If True, generates plots of the mirror voxels. Defaults to False.
    interpolation_method (str, optional): Interpolation method for RegularGridInterpolator. Defaults to 'linear'.

    Returns:
    tuple: (sum of squared differences / number of non-zero voxels in ``bone``,
            sum of squared differences / sum(bone**2) * 100).

    Raises:
    ValueError: If ``interpolation`` is not 'bone' or 'full'.
    """
    if interpolation not in ('bone', 'full'):
        raise ValueError("Invalid interpolation type. Use 'bone' or 'full'.")

    # Normalize plane parameters (normal vector)
    A, B, C, D = plane_params
    normal_vector = np.array([A, B, C]) / np.linalg.norm([A, B, C])

    start = time.time()
    sampled_image = bone if interpolation == 'bone' else bone + soft_tissue
    grid = tuple(np.arange(n) for n in sampled_image.shape)
    interpolator = RegularGridInterpolator(grid, sampled_image, method=interpolation_method, bounds_error=False, fill_value=0)
    if interpolation == 'full':
        print(f"Interpolation time: {time.time() - start}")

    # (N, 3) even for an empty index list, so the column swap below cannot fail.
    indices_array = np.asarray(indices, dtype=int).reshape(-1, 3)  # coordinate order (x, y, z)
    indices_array_image = indices_array[:, [1, 0, 2]]  # array order (row, col, slice)
    distances_array = np.asarray(distances, dtype=float)

    # NOTE(review): originals always come from `bone`, even when interpolation == 'full'
    # samples bone + soft_tissue for the mirrors -- confirm this asymmetry is intended.
    original_intensities = bone[indices_array_image[:, 0], indices_array_image[:, 1], indices_array_image[:, 2]]

    mirror_voxels = indices_array - 2 * distances_array[:, None] * normal_vector
    mirror_voxels_image = mirror_voxels[:, [1, 0, 2]]
    mirror_intensities = interpolator(mirror_voxels_image)

    if mirror_plots:
        mirror_image = generate_mirror_voxel_array(bone, indices_array_image, mirror_voxels_image, mirror_intensities)
        mirror_and_original_image = bone + mirror_image
        mirror_points(bone, indices, mirror_voxels, plane_params)
        plot_plane_on_middle_slice(mirror_and_original_image, plane_params, title="Mirror and Original Image")

    squared_differences = (original_intensities - mirror_intensities) ** 2
    squared_diff_sum = np.sum(squared_differences)

    mean_squared_diff_sum = squared_diff_sum / np.count_nonzero(bone)
    # Percentage of the total squared intensity (same definition as compute_intensity_metric);
    # previously the per-voxel mean was divided by the total, which is not a percentage.
    percentage_diff = (squared_diff_sum / np.sum(bone ** 2)) * 100

    return mean_squared_diff_sum, percentage_diff


def compute_intensity_metric(bone, soft_tissue, indices, distances, plane_params, metric_type='mse', interpolation='bone', interpolation_method='linear', 
                             mirror_plots=False, interpolator=None, plot_differences=False, output_path = None):
    """
    Score the symmetry of ``bone`` about a plane with MSE or negated NCC.

    Every voxel in ``indices`` (coordinate order x, y, z) is reflected via
    x_m = x - 2 * d * n_hat, with d its signed distance from ``distances`` and n_hat the unit
    normal of (A, B, C). Original intensities are read directly from ``bone``; mirror
    intensities are obtained by calling ``interpolator`` on the reflected positions in array
    order (row, col, slice).

    Parameters:
    bone (ndarray): 3D volume whose voxels are compared with their mirrors.
    soft_tissue (ndarray): Accepted for interface compatibility; not used here.
    indices (list of tuples): Voxel positions (x, y, z) to mirror.
    distances (list of floats): Signed distance of each voxel to the plane.
    plane_params (tuple): (A, B, C, D) of the plane Ax + By + Cz + D = 0.
    metric_type (str): 'mse' or 'ncc'.
    interpolation, interpolation_method, output_path: Accepted for interface compatibility; not used here.
    mirror_plots (bool, optional): If True, draw diagnostic plots of the mirrored voxels.
    interpolator (callable): Required; maps an (N, 3) array of positions to N intensities.
    plot_differences (bool, optional): For 'mse', return the squared-difference volume instead of the percentage.

    Returns:
    tuple: For 'mse', (sum of squared differences / count of non-zero voxels in bone, second),
    where second is either sum of squared differences / sum(bone**2) * 100 or, when
    ``plot_differences`` is true, a volume holding the squared difference at each voxel.
    For 'ncc', (-NCC, 0) -- negated so that minimising maximises correlation.

    Raises:
    ValueError: If ``interpolator`` is None or ``metric_type`` is unknown.
    """
    a, b, c, _ = plane_params
    unit_normal = np.array([a, b, c]) / np.linalg.norm([a, b, c])

    if interpolator is None:
        raise ValueError("Interpolator must be passed")

    coords = np.array(indices)                 # (x, y, z)
    voxel_pos = coords[:, [1, 0, 2]]           # (row, col, slice)
    signed_dist = np.array(distances)

    original = bone[voxel_pos[:, 0], voxel_pos[:, 1], voxel_pos[:, 2]]

    reflected = coords - 2 * signed_dist[:, None] * unit_normal
    reflected_pos = reflected[:, [1, 0, 2]]
    mirrored = interpolator(reflected_pos)

    if mirror_plots:
        mirror_image = generate_mirror_voxel_array(bone, voxel_pos, reflected_pos, mirrored)
        combined = bone + mirror_image
        mirror_points(bone, indices, reflected, plane_params)
        plot_plane_on_middle_slice(combined, plane_params, title="Mirror and Original Image")

    if metric_type == 'mse':
        sq_diff = (original - mirrored) ** 2
        total_sq_diff = np.sum(sq_diff)
        mse = total_sq_diff / np.count_nonzero(bone)
        pct = (total_sq_diff / np.sum(bone ** 2)) * 100
        if not plot_differences:
            return mse, pct
        diff_volume = np.zeros_like(bone)
        diff_volume[voxel_pos[:, 0], voxel_pos[:, 1], voxel_pos[:, 2]] = sq_diff
        return mse, diff_volume

    if metric_type == 'ncc':
        centred_orig = original - np.mean(original)
        centred_mirr = mirrored - np.mean(mirrored)
        numerator = np.sum(centred_orig * centred_mirr)
        denominator = np.sqrt(np.sum(centred_orig ** 2) * np.sum(centred_mirr ** 2))
        ncc = numerator / denominator if denominator != 0 else 0
        return -ncc, 0

    raise ValueError("Invalid metric type. Use 'ncc' for Normalized Cross-Correlation or 'mse' for Mean Squared Error.")


def generate_mirror_voxel_array(image, indices, mirror_voxels_image, mirror_intensities):
    """
    Build a volume in which each source voxel's intensity is written at its rounded mirror position.

    Parameters:
    image (numpy.ndarray): Volume supplying both the intensities and the output shape/dtype.
    indices (numpy.ndarray): (N, 3) integer source positions in array order. Must be a NumPy
        array, because it is filtered with a boolean mask.
    mirror_voxels_image (numpy.ndarray): (N, 3) mirror positions in array order; rounded to the
        nearest voxel.
    mirror_intensities: Unused; kept for interface compatibility.

    Returns:
    numpy.ndarray: Array shaped like ``image``. It is zero everywhere except at in-bounds rounded
    mirror positions, which hold ``image`` at the corresponding source voxel. Mirror positions
    outside the volume are dropped. If several sources round to the same position, which value is
    kept is not specified.
    """
    target = np.round(mirror_voxels_image).astype(int)
    extent = np.asarray(image.shape[:3])

    # A row is kept only if all three of its coordinates lie in [0, extent).
    keep = np.all((target >= 0) & (target < extent), axis=1)
    src = indices[keep]
    dst = target[keep]

    out = np.zeros_like(image)
    out[tuple(dst.T)] = image[tuple(src.T)]
    return out

def mirror_points(image, indices, mirror_voxels_image, plane_coeffs, num_points=10):
    """
    Plot randomly sampled voxels and their mirror positions on the middle slice, together with the plane.

    Parameters:
    image (numpy.ndarray): 3D volume; ``image[:, :, image.shape[2] // 2]`` is displayed.
    indices (sequence): Voxel positions; element [0] is plotted as horizontal, [1] as vertical,
        and [2] is compared with the displayed slice index.
    mirror_voxels_image (sequence): Mirror positions aligned with ``indices``; elements [0] and [1]
        are plotted the same way.
    plane_coeffs (tuple): (A, B, C, D); the plane's intersection with the displayed slice is drawn.
    num_points (int, optional): Maximum number of pairs to sample (default 10). Fewer are drawn if
        fewer candidates exist; previously this raised ValueError.
    """
    slice_index = image.shape[2] // 2

    # Candidate pairs exclude voxels lying ON the displayed slice (as in the original code).
    # NOTE(review): points from other slices are therefore drawn over the middle slice -- confirm intent.
    candidates = [(idx, mirror) for idx, mirror in zip(indices, mirror_voxels_image) if idx[2] != slice_index]
    sample = random.sample(candidates, min(num_points, len(candidates)))

    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(image[:, :, slice_index], cmap='gray')
    ax.set_title(f'Slice {slice_index} with Random Original and Mirror Points')

    for i, (original, _) in enumerate(sample):
        ax.scatter(original[0], original[1], color='blue', marker=f'${i}$', label=f'Original Point {i}')
    for i, (_, mirrored) in enumerate(sample):
        ax.scatter(mirrored[0], mirrored[1], color='red', marker=f'${i}$', label=f'Mirror Point {i}')

    # Plane contour: z on the plane as a function of (x, y), drawn at the displayed slice.
    A, B, C, D = plane_coeffs
    if C == 0:
        # Avoid division by zero; the contour then approximates the line A*x + B*y + D = 0.
        C = 1e-6
    X, Y = np.meshgrid(np.linspace(0, image.shape[1], 100), np.linspace(0, image.shape[0], 100))
    z_on_plane = (-A * X - B * Y - D) / C
    ax.contour(X, Y, z_on_plane, levels=[slice_index], colors='red')

    ax.legend()
    plt.show()
    plt.close(fig)

def objective_function(plane_vec, bone, soft_tissue, image_plot, pat, interpolation = 'bone', obj_val_update = None, interpolation_method='linear', metric = 'mse', 
                       interpolator = None, plot_differences=False, output_path = None):
    """
    Objective value for a plane parametrised as (alpha, beta, L).

    ``angles_to_vector(alpha, beta, L)`` gives the plane vector n. The plane is n . x + D = 0 with
    D = -|n|^2, i.e. it passes through n. Distances of all non-zero voxels of ``bone`` to this plane
    are computed and scored with ``compute_intensity_metric``.

    Parameters:
    plane_vec (sequence): (alpha, beta, L) as consumed by ``angles_to_vector``.
    bone (ndarray): Volume whose non-zero voxels are mirrored and scored.
    soft_tissue (ndarray): Forwarded to ``compute_intensity_metric``.
    image_plot, pat: Not used in the computation; kept so the positional signature used by
        ``scipy.optimize.minimize`` callers stays unchanged.
    interpolation (str, optional): Forwarded to ``compute_intensity_metric``. Default 'bone'.
    obj_val_update (list, optional): Currently unused (recording of values is disabled).
    interpolation_method (str, optional): Forwarded. Default 'linear'.
    metric (str, optional): 'mse' or 'ncc', forwarded as ``metric_type``. Default 'mse'.
    interpolator (callable): Forwarded; ``compute_intensity_metric`` requires it.
    plot_differences (bool, optional): If true, also return the second value produced by
        ``compute_intensity_metric`` (the squared-difference volume for 'mse').
    output_path (str, optional): Forwarded.

    Returns:
    float, or (float, second value) when ``plot_differences`` is true.
    """
    alpha, beta, length = plane_vec[0], plane_vec[1], plane_vec[2]
    vec = angles_to_vector(alpha, beta, length)
    plane_params = (vec[0], vec[1], vec[2], -np.dot(vec, vec))

    indices_list, distances_list = compute_distances_and_indices(bone, plane_params)  # coordinate-order indices
    value, extra = compute_intensity_metric(
        bone, soft_tissue, indices_list, distances_list, plane_params,
        metric_type=metric, interpolation=interpolation, mirror_plots=False,
        interpolation_method=interpolation_method, interpolator=interpolator,
        plot_differences=plot_differences, output_path=output_path,
    )
    # Truthiness matches compute_intensity_metric, which returns the volume whenever plot_differences is truthy.
    if plot_differences:
        return value, extra
    return value

def optimize_plane_parameters(initial_plane_params, image, soft_tissue, image_plot, pat, optimization_method, interpolation = 'full', output_path = None, 
                              interpolation_method='linear', metric = 'mse', interpolator = None):
    """
    Optimize the mid-sagittal plane of one image with ``scipy.optimize.minimize``.

    The search runs in the (alpha, beta, L) parametrisation of ``objective_function``. The initial
    (A, B, C, D) is converted with ``vector_to_angles`` and the result is converted back with
    ``angles_to_vector``.

    Parameters:
    initial_plane_params (tuple): Initial plane (A, B, C, D). If |(A, B, C)| < 2 it is treated as a
        unit normal and scaled by |D| to obtain the plane vector.
    image (ndarray): Volume whose non-zero voxels are mirrored (passed as ``bone``).
    soft_tissue (ndarray): Forwarded to the objective.
    image_plot (ndarray): Volume used only for plots and GIFs.
    pat: Patient identifier used in messages and file names.
    optimization_method (str): ``method`` of ``scipy.optimize.minimize`` (e.g. 'Nelder-Mead').
    interpolation (str, optional): Forwarded to the objective. Default 'full'.
    output_path (str, optional): Directory for the GIFs and the objective plot. The plot is saved
        only when this is not None.
    interpolation_method (str, optional): Forwarded to the objective. Default 'linear'.
    metric (str, optional): 'mse' or 'ncc'. Default 'mse'.
    interpolator (callable, optional): Forwarded to the objective, which requires it.

    Returns:
    tuple: ((A, B, C, D) with D = -|(A, B, C)|^2, final objective value ``result.fun``).
    """
    print(f"Start optimization for patient {pat} ...")
    start_optimization = time.time()

    A, B, C, D = initial_plane_params
    vec = np.array([A, B, C])
    if np.linalg.norm(vec) < 2:
        # Unit-normal input -> plane vector (abs(D) only yields a point on the plane when D <= 0).
        vec = vec * abs(D)
    alpha, beta, L = vector_to_angles(vec)
    initial_params = np.array([alpha, beta, L])

    objective_values_update = []
    # Positional arguments after plane_vec, shared by minimize and the callback so they cannot drift apart.
    objective_args = (image, soft_tissue, image_plot, pat, interpolation, objective_values_update,
                      interpolation_method, metric, interpolator)

    objective_values = [objective_function(initial_params, *objective_args)]
    params_values = [initial_params]

    plot_plane_on_middle_slice(image_plot, initial_plane_params, title="Initial Plane", pat=pat)

    def callback(xk):
        # Re-evaluates the objective at each accepted iterate (one extra evaluation per iteration)
        # to record the trajectory; copy xk in case the optimiser reuses its buffer.
        objective_values.append(objective_function(xk, *objective_args))
        params_values.append(np.copy(xk))

    result = minimize(
        objective_function,
        x0=initial_params,
        args=objective_args,
        method=optimization_method,
        jac=None,
        options=None,
        callback=callback
    )
    end_optimization = time.time()
    print(f"Time taken for optimization of patient {pat}: {end_optimization - start_optimization:.2f} seconds")

    create_gif_with_plane_projections(image_plot, params_values, output_path, pat=pat, com=None, objective_values=objective_values)
    create_gif_of_objective_updates(objective_values, output_path, gif_name=f"objective_updates_patient_{pat}.gif")

    print(f"Optimization ended with message: {result['message']}")
    print(f"Why the optimization stopped: {result['success']}")
    # Not every scipy method reports 'nit'.
    print(f"Number of iterations: {result.get('nit')}")
    print(f"Final objective function value: {result['fun']}")
    print(f"All objective values: {objective_values}")

    # Dedicated figure so the curve is not drawn onto a figure left open by the plotting helpers.
    fig, ax = plt.subplots()
    ax.plot(objective_values)
    ax.set_xlabel('Iteration')
    ax.set_ylabel('Objective Function Value')
    ax.set_title('Objective Function Value vs. Iteration')
    if output_path is not None:
        fig.savefig(os.path.join(output_path, f"objective_function_patient_{pat}.png"))
    plt.show()
    plt.close(fig)

    A, B, C = angles_to_vector(result["x"][0], result["x"][1], result["x"][2])
    D = -np.dot([A, B, C], [A, B, C])
    plane_params_opt = (A, B, C, D)
    print(f"Optimized plane parameters: {plane_params_opt}")
    plot_plane_on_middle_slice(image_plot, plane_params_opt, title="Optimized Plane", pat=pat)
    return plane_params_opt, result["fun"]


def multi_resolution_optimization(plane, bone_ct, soft_tissue, image_plot, pat_num, optimization_method, 
                                   interpolation, opt_image, output_path, interpolation_method, metric, num_levels=4, interpolator=None):
    """
    Coarse-to-fine optimization of the mid-sagittal plane over an image pyramid.

    Parameters:
    plane (tuple): Initial plane (A, B, C, D), expressed at the resolution of pyramid level num_levels - 1.
    bone_ct (ndarray): 3D bone CT image.
    soft_tissue (ndarray): 3D soft tissue image.
    image_plot (ndarray): 3D image used for plots.
    pat_num: Patient identifier.
    optimization_method (str): ``method`` passed to ``scipy.optimize.minimize``.
    interpolation (str): 'bone' or 'full' (bone + soft tissue is sampled at mirror positions).
    opt_image (str): Not used by the optimization; kept for backward compatibility.
    output_path (str): Directory; one sub-folder per resolution level is created.
    interpolation_method (str): Interpolation method used in optimization.
    metric (str): 'mse' or 'ncc'.
    num_levels (int): Number of pyramid levels.
    interpolator (callable, optional): Interpolator over the full-resolution image. It is used only at
        levels whose bone volume has the same shape as ``bone_ct``. Every other level gets an
        interpolator built on that level's own image, because mirror positions are sampled in the
        level's voxel-index space.

    Returns:
    tuple: (plane parameters (A, B, C, D) from the last level processed, objective value there).

    Raises:
    ValueError: If ``num_levels`` < 1.
    """
    if num_levels < 1:
        raise ValueError("num_levels must be at least 1")

    bone_pyramid = build_image_pyramid(bone_ct, num_levels)
    soft_tissue_pyramid = build_image_pyramid(soft_tissue, num_levels)
    image_plot_pyramid = build_image_pyramid(image_plot, num_levels)

    current_guess = plane
    obj_fun = None
    # Processing starts at level num_levels - 1 and ends at level 0; the guess is rescaled by 2
    # between levels (presumably level num_levels - 1 is the coarsest -- depends on build_image_pyramid).
    for level in range(num_levels - 1, -1, -1):
        print(f"Optimizing at resolution level {level + 1}")
        folder_path = os.path.join(output_path, f"Resolution_level_{level + 1}")
        os.makedirs(folder_path, exist_ok=True)

        bone_current_res = bone_pyramid[level]
        soft_tissue_current_res = soft_tissue_pyramid[level]
        image_plot_res = image_plot_pyramid[level]

        fig, ax = plt.subplots()
        ax.imshow(image_plot_res[:, :, image_plot_res.shape[2] // 2], cmap='gray')
        ax.set_title(f'Image Plot at Resolution Level {level + 1}')
        ax.axis('off')
        fig.savefig(os.path.join(folder_path, f"image_plot_level_{level + 1}.png"))
        plt.close(fig)

        if level != num_levels - 1:
            current_guess = plane_params_rescale(current_guess, 2)

        # The interpolator must live on the same voxel grid as this level's bone volume.
        if interpolator is not None and bone_current_res.shape == bone_ct.shape:
            level_interpolator = interpolator
        else:
            sampled = bone_current_res + soft_tissue_current_res if interpolation == 'full' else bone_current_res
            grid = tuple(np.arange(n) for n in sampled.shape)
            level_interpolator = RegularGridInterpolator(grid, sampled, method=interpolation_method,
                                                         bounds_error=False, fill_value=0)

        # Keyword arguments: the former positional call passed 12 arguments to an 11-parameter function.
        best_plane_params, obj_fun = optimize_plane_parameters(
            current_guess, bone_current_res, soft_tissue_current_res, image_plot_res, pat_num, optimization_method,
            interpolation=interpolation, output_path=folder_path, interpolation_method=interpolation_method,
            metric=metric, interpolator=level_interpolator,
        )

        current_guess = best_plane_params

    return current_guess, obj_fun

## Deformed Plane

In [None]:
import numpy as np
from scipy.interpolate import Rbf
from scipy.ndimage import center_of_mass
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from scipy.spatial.transform import Rotation as R

import numpy as np
from scipy import ndimage

def center_of_mass_structures(body, mandible, spinalcord):
    """
    Per-slice centres of mass of the body, mandible and spinal-cord masks.

    For every index i along axis 2, each mask whose slice has a positive sum contributes one point
    (x, y, z) = (column COM, row COM, i). Points are appended in the order body, mandible, spinal cord.

    Parameters:
    body, mandible, spinalcord (numpy.ndarray): 3D masks of equal shape; the slice count is taken from
        ``mandible.shape[2]``.

    Returns:
    numpy.ndarray: float64 array of shape (N, 3). It has shape (0, 3) when no slice has any
    structure, so column indexing by callers still works.

    Raises:
    ValueError: If any computed coordinate is NaN or infinite.
    """
    points = []
    for i in range(mandible.shape[2]):
        for mask in (body, mandible, spinalcord):
            slice_2d = mask[:, :, i]
            if np.sum(slice_2d) > 0:
                row_com, col_com = ndimage.center_of_mass(slice_2d)
                points.append((col_com, row_com, i))  # (x, y, z) ordering

    com_slices_array = np.array(points, dtype=np.float64).reshape(-1, 3)

    if not np.all(np.isfinite(com_slices_array)):
        raise ValueError("NaN or Inf values detected in computed center of mass array!")

    return com_slices_array



import plotly.graph_objects as go


def compute_deformed_plane(image, body, mandible_mask, spinalcord_mask, rigid_plane_params):
    """
    Compute the deformed mid-sagittal plane using Thin-Plate Splines (TPS) and visualize it interactively.
    """
    A, B, C, D = rigid_plane_params
    
    # Compute center of mass per slice
    center_of_mass_points = center_of_mass_structures(body, mandible_mask, spinalcord_mask)

    # Extract coordinates
    x_vals, y_vals, z_vals = center_of_mass_points[:, 0], center_of_mass_points[:, 1], center_of_mass_points[:, 2]

    # Check for consistency
    if len(x_vals) != len(y_vals) or len(y_vals) != len(z_vals):
        raise ValueError(f"Array length mismatch: x({len(x_vals)}), y({len(y_vals)}), z({len(z_vals)})")

    # Compute normal vector
    normal = np.array([A, B, C])
    normal = normal / np.linalg.norm(normal)

    # Compute rotation to align normal with X-axis (1,0,0)
    rotation_vector = np.cross(normal, [1, 0, 0])
    rotation_angle = np.arccos(np.clip(np.dot(normal, [1, 0, 0]), -1.0, 1.0))

    if np.isnan(rotation_angle) or np.any(np.isnan(rotation_vector)):
        raise ValueError("Invalid rotation: NaN detected in angle or vector.")

    rotation_matrix = R.from_rotvec(rotation_angle * rotation_vector / np.linalg.norm(rotation_vector)).as_matrix()

    # Rotate center of mass points
    rotated_points = center_of_mass_points @ rotation_matrix.T
    x_rot, y_rot, z_rot = rotated_points[:, 0], rotated_points[:, 1], rotated_points[:, 2]

    # Check for mismatched lengths
    if len(x_rot) != len(y_rot) or len(y_rot) != len(z_rot):
        raise ValueError(f"Rotated array length mismatch: x({len(x_rot)}), y({len(y_rot)}), z({len(z_rot)})")

    # Interpolation
    tps_x = Rbf(z_rot, y_rot, x_rot, function='thin_plate')

    # Generate a smooth deformed plane in rotated coordinates
    z_range = np.linspace(min(z_rot), max(z_rot), 100)
    y_range = np.linspace(min(y_rot), max(y_rot), 100)
    Z_grid, Y_grid = np.meshgrid(z_range, y_range)

    # Compute X-displacements using TPS
    X_displacement = tps_x(Z_grid, Y_grid)
    X_grid_rotated = X_displacement

    # Stack deformed points
    deformed_points_rotated = np.column_stack((X_grid_rotated.ravel(), Y_grid.ravel(), Z_grid.ravel()))

    # Rotate back to original coordinate system
    deformed_points = deformed_points_rotated @ rotation_matrix.T
    X_grid, Y_grid, Z_grid = deformed_points[:, 0].reshape(Z_grid.shape), deformed_points[:, 1].reshape(Z_grid.shape), deformed_points[:, 2].reshape(Z_grid.shape)

    # Create interactive 3D plot with Plotly
    fig = go.Figure()

    # Add center of mass points
    fig.add_trace(go.Scatter3d(
        x=x_vals, y=y_vals, z=z_vals,
        mode='markers',
        marker=dict(size=5, color='red'),
        name="Center of Mass"
    ))

    # Add deformed plane as a surface
    fig.add_trace(go.Surface(
        x=X_grid, y=Y_grid, z=Z_grid,
        colorscale='Viridis',
        opacity=0.7
    ))

    # Set labels and title
    fig.update_layout(
        title="Interactive 3D Deformed Mid-Sagittal Plane",
        scene=dict(
            xaxis_title="X-axis (Deformed Plane Width)",
            yaxis_title="Y-axis (Height)",
            zaxis_title="Z-axis (Slices)"
        ),
        margin=dict(l=0, r=0, b=0, t=40)
    )

    fig.show()






## Plot Objective Function
Plot the objective function with respect to two parameters (heatmap) or one parameter (curve) to visualize how the objective function depends on the plane parameters.

In [None]:
def param_initialization_1d(image, soft_tissue, image_plot, pat, output_path, interpolation=False, obj_val_update=None, plot = False, interpolation_method='linear', interpolator=None,
                            max_angle_deg=10, num_angles=100):
    """
    Grid-search an initial mid-sagittal plane through the image's centre of mass.

    Candidate planes pass through the centre of mass of ``image``. Their unit normals are
    (cos theta, sin theta, 0), i.e. the x-axis rotated about the z-axis, for ``num_angles`` values of
    theta evenly spaced in [-max_angle_deg, max_angle_deg]. Each candidate is scored with
    ``objective_function`` (default metric, MSE), and the candidate with the lowest score is returned.

    The objective is evaluated for exactly the plane that is stored and returned. Previously it scored a
    plane through the rotated point (image.shape[0] // 2, 0, 0), which generally differs from the
    returned plane through the centre of mass.

    Parameters:
    image (ndarray): Volume used for the centre of mass and passed as ``bone`` to the objective.
    soft_tissue, image_plot, pat, interpolation, obj_val_update, interpolation_method, interpolator:
        Forwarded to ``objective_function``; ``pat`` is also used in the plot file name.
    output_path (str): Directory for the optional plot.
    plot (bool, optional): Save an MSE-vs-angle plot.
    max_angle_deg (float, optional): Half-width of the searched angle range in degrees. Default 10.
    num_angles (int, optional): Number of candidate angles. Default 100.

    Returns:
    tuple: (A, B, C, D) with unit normal (A, B, C) and D = -(A, B, C) . com.
    """
    start_initialization = time.time()
    com = calculate_center_of_mass(image)
    com = np.array([com[1], com[0], com[2]])  # swap first two axes to (x, y, z), as elsewhere in this file

    max_angle = np.deg2rad(max_angle_deg)
    thetas = np.linspace(-max_angle, max_angle, num_angles)  # rotation angles about the z-axis

    mse_list = []
    plane_params_list = []

    for theta in thetas:
        # x-axis rotated about z by theta (already unit length).
        normal = np.array([np.cos(theta), np.sin(theta), 0.0])
        D = -np.dot(normal, com)  # plane through the centre of mass
        plane_params = (normal[0], normal[1], normal[2], D)
        plane_params_list.append(plane_params)

        print(f"Plane parameters for theta = {np.rad2deg(theta):.2f} degrees: {plane_params}")

        # Score this very plane: its point closest to the origin (plane vector) is -D * normal.
        # Separate names so the loop variable `theta` is not overwritten (the printout below used to show the wrong angle).
        alpha, beta, length = vector_to_angles(-D * normal)
        rot_params = np.array([alpha, beta, length])

        mse = objective_function(
            rot_params, image, soft_tissue, image_plot, pat,
            interpolation=interpolation, obj_val_update=obj_val_update, interpolation_method = interpolation_method, interpolator = interpolator
        )
        mse_list.append(mse)

        print(f"MSE for theta = {np.rad2deg(theta):.2f} degrees: {mse}")

    if plot:
        fig, ax = plt.subplots()
        ax.plot(np.rad2deg(thetas), mse_list, color='red')
        ax.set_title('MSE vs. Polar Angle')
        ax.set_xlabel('Polar Angle (degrees)')
        ax.set_ylabel('Mean Squared Error')
        ax.grid(True)
        fig.savefig(os.path.join(output_path, f"MSE_vs_polar_angle_patient_{pat}.png"))
        plt.close(fig)

    best_plane_params = plane_params_list[np.argmin(mse_list)]
    print(f"Best plane parameters: {best_plane_params}")

    end_initialization = time.time()
    print(f"Time taken for initialization: {end_initialization - start_initialization:.2f} seconds")

    return best_plane_params

def rotated_shift_1d(image, soft_tissue, image_plot, pat, output_path, interpolation=False, obj_val_update=None, plot=False, interpolation_method='linear', interpolator=None,
                     l_half_width=10, num_steps=100):
    """
    Scan the plane offset L for one fixed plane orientation and return the plane with the lowest MSE.

    The orientation is derived from the sagittal normal (1, 0, 0). The plane is placed through
    the image's center of mass, and its distance from the origin L is then sampled over
    ``[L0 - l_half_width, L0 + l_half_width]``, where L0 is the distance of that COM plane.

    Parameters:
    -----------
    image : ndarray
        3D image forwarded to ``objective_function``; also used for the center of mass.
    soft_tissue : ndarray
        Soft-tissue data forwarded to ``objective_function``.
    image_plot : ndarray
        Image forwarded to ``objective_function``.
    pat : str
        Patient identifier (used in the plot filename).
    output_path : str
        Directory where the MSE-vs-L plot is saved when ``plot`` is True.
    interpolation : bool, optional
        Forwarded to ``objective_function`` (default is False).
    obj_val_update : callable, optional
        Forwarded to ``objective_function`` (default is None).
    plot : bool, optional
        Whether to save the MSE vs. L plot (default is False).
    interpolation_method : str, optional
        Forwarded to ``objective_function`` (default is 'linear').
    interpolator : optional
        Forwarded to ``objective_function`` (default is None).
    l_half_width : float, optional
        Half-width of the sampled L interval (default is 10).
    num_steps : int, optional
        Number of L samples (default is 100).

    Returns:
    --------
    best_plane_params : tuple
        (a, b, c, L): the unit plane normal and the sampled L with the lowest MSE.
    """
    start_time = time.time()

    # Center of mass with its first two components swapped to match the plane coordinates.
    com = calculate_center_of_mass(image)
    com = np.array([com[1], com[0], com[2]])

    # Rotate the sagittal normal (1, 0, 0) by 10 degrees about the x-axis.
    # NOTE(review): a rotation about the x-axis leaves (1, 0, 0) unchanged, so the normal
    # stays (1, 0, 0). Confirm whether a different rotation axis was intended.
    angle_rad = np.deg2rad(10)
    rotation_matrix = np.array([
        [1, 0, 0],
        [0, np.cos(angle_rad), -np.sin(angle_rad)],
        [0, np.sin(angle_rad), np.cos(angle_rad)]
    ])
    rotated_normal_vector = rotation_matrix @ np.array([1.0, 0.0, 0.0])
    rotated_normal_vector /= np.linalg.norm(rotated_normal_vector)

    # Plane through the COM: n . x + D = 0. Scaling the unit normal by |D| yields a vector
    # whose length is the plane's distance from the origin. Converting the bare unit normal
    # would always give L = 1 and center the scan on the wrong offset.
    rotated_D = -np.dot(rotated_normal_vector, com)
    theta, phi, rotated_L = vector_to_angles(rotated_normal_vector * abs(rotated_D))
    L_range = np.linspace(rotated_L - l_half_width, rotated_L + l_half_width, num_steps)

    # Score each offset using the (theta, phi, L) parameterisation expected by objective_function.
    mse_list = [
        objective_function(
            [theta, phi, L], image, soft_tissue, image_plot, pat,
            interpolation=interpolation, obj_val_update=obj_val_update,
            interpolation_method=interpolation_method, interpolator=interpolator
        )
        for L in L_range
    ]

    if plot:
        plt.plot(L_range, mse_list, color='blue', label='MSE')
        plt.title('MSE vs. L (Shift Along Rotated Plane)')
        plt.xlabel('L')
        plt.ylabel('Mean Squared Error')
        plt.legend()
        plt.grid(True)
        plt.savefig(os.path.join(output_path, f"MSE_vs_L_patient_{pat}.png"))
        plt.close()

    best_L = L_range[np.argmin(mse_list)]
    best_plane_params = (rotated_normal_vector[0], rotated_normal_vector[1], rotated_normal_vector[2], best_L)

    end_time = time.time()
    print(f"Time taken for initialization: {end_time - start_time:.2f} seconds")
    print(f"Best plane parameters: {best_plane_params}")

    return best_plane_params


def plot_mse_vs_parameters(image, soft_tissue, image_plot, pat, output_path, interpolation=False, obj_val_update=None,
                            interpolation_method='cubic', interpolator=None, best_plane_params=None, com=None,
                            angle_half_width_deg=2, l_half_width=4, num_steps=100):
    """
    Save 1D plots of the MSE as a function of each plane parameter: theta, phi and L.

    Each curve varies one parameter around the reference plane and holds the others fixed:
      - theta sweep: phi = phi_best; every plane passes through ``com`` (L = |D| of that plane).
      - phi sweep:   theta = theta_best; every plane passes through ``com``.
      - L sweep:     theta = theta_best and phi = phi_best, L varied directly.

    Parameters
    ----------
    image, soft_tissue, image_plot : ndarray
        Volumes forwarded to ``objective_function``.
    pat : str
        Patient identifier used in the output filenames.
    output_path : str
        Directory where ``mse_vs_theta_{pat}.png``, ``mse_vs_phi_{pat}.png`` and
        ``mse_vs_L_{pat}.png`` are written.
    interpolation, obj_val_update, interpolation_method, interpolator
        Forwarded to ``objective_function``.
    best_plane_params : tuple, optional
        Reference plane (A, B, C, D). If (A, B, C) has norm < 2 it is treated as a unit normal
        and scaled by |D| before conversion to angles. Defaults to the plane
        x = image.shape[0] // 2.
    com : array-like, optional
        Point the theta/phi-swept planes pass through. Defaults to the center of mass of
        ``image_plot`` with its first two components swapped.
    angle_half_width_deg : float, optional
        Half-width (degrees) of the theta and phi sweeps (default 2).
    l_half_width : float, optional
        Half-width of the L sweep (default 4).
    num_steps : int, optional
        Number of samples per sweep (default 100).
    """
    print("Generating MSE plots...")
    if best_plane_params is None:
        # Sagittal plane through the middle of the first axis, as used for initial planes elsewhere.
        middle_x = image.shape[0] // 2
        A, B, C, D = (1, 0, 0, -middle_x)
    else:
        A, B, C, D = best_plane_params
    vec = np.array([A, B, C], dtype=float)
    if np.linalg.norm(vec) < 2:
        # Unit normal given: scale to the foot-point vector whose length is the plane distance.
        vec = vec * abs(D)
    theta_best, phi_best, L_best = vector_to_angles(vec)
    theta_best_deg = np.rad2deg(theta_best)
    phi_best_deg = np.rad2deg(phi_best)
    print(f"Best plane parameters: theta = {theta_best_deg:.2f}°, phi = {phi_best_deg:.2f}°, L = {L_best:.2f}")

    if com is None:
        com = calculate_center_of_mass(image_plot)
        com = np.array([com[1], com[0], com[2]])
    com = np.asarray(com, dtype=float)

    half_width = np.deg2rad(angle_half_width_deg)
    theta_range = np.linspace(theta_best - half_width, theta_best + half_width, num_steps)
    phi_range = np.linspace(phi_best - half_width, phi_best + half_width, num_steps)
    L_range = np.linspace(L_best - l_half_width, L_best + l_half_width, num_steps)

    def plane_through_com(theta, phi):
        """Return the unit normal for (theta, phi) and the D that makes the plane contain com."""
        normal = np.asarray(angles_to_vector(theta, phi, 1.0), dtype=float)
        normal = normal / np.linalg.norm(normal)
        return normal, -np.dot(normal, com)

    def evaluate(params):
        """Objective value for a (theta, phi, L) parameter vector."""
        return objective_function(params, image, soft_tissue, image_plot, pat,
                                  interpolation=interpolation, obj_val_update=obj_val_update,
                                  interpolation_method=interpolation_method, interpolator=interpolator)

    mse_theta, plane_param_list_theta = [], []
    for theta in theta_range:
        # The normal is built from the absolute angles, so the plane scored with L = |D| is
        # exactly the (theta, phi_best) plane through com.
        normal, D = plane_through_com(theta, phi_best)
        mse_theta.append(evaluate([theta, phi_best, np.abs(D)]))
        plane_param_list_theta.append((normal[0], normal[1], normal[2], D))

    mse_phi, plane_param_list_phi = [], []
    for phi in phi_range:
        normal, D = plane_through_com(theta_best, phi)
        mse_phi.append(evaluate([theta_best, phi, np.abs(D)]))
        plane_param_list_phi.append((normal[0], normal[1], normal[2], D))

    mse_L, plane_param_list_L = [], []
    for L in L_range:
        mse_L.append(evaluate([theta_best, phi_best, L]))
        A, B, C = angles_to_vector(theta_best, phi_best, L)
        D = -np.dot([A, B, C], [A, B, C])
        plane_param_list_L.append((A, B, C, D))

    # plot_middle_slice_with_planes(image_plot, plane_param_list_theta, title='Middle Slice with Rotated Planes by Theta', com=com, output_path=output_path, filename=f"middle_slice_rotated_theta_{pat}.png")
    # plot_middle_slice_with_planes(image_plot, plane_param_list_phi, title='Middle Slice with Rotated Planes by Phi', com=com, coronal = True, output_path=output_path, filename=f"middle_slice_rotated_phi_{pat}.png")
    # plot_middle_slice_with_planes(image_plot, plane_param_list_L, title='Middle Slice shifted Planes along L', com=com, output_path=output_path, filename=f"middle_slice_shifted_L_{pat}.png")

    def save_curve(x, y, label, xlabel, title, filename):
        """Plot one MSE curve, save it into output_path and close the figure."""
        fig = plt.figure()
        plt.plot(x, y, label=label)
        plt.xlabel(xlabel)
        plt.ylabel('MSE')
        plt.title(title)
        plt.legend()
        plt.grid()
        fig.savefig(os.path.join(output_path, filename))
        # Close explicitly: this runs once per patient and open figures otherwise accumulate.
        plt.close(fig)

    save_curve(np.degrees(theta_range), mse_theta, 'MSE vs Theta', 'Theta (degrees)',
               'MSE vs Theta \n best theta = {:.2f} degrees'.format(np.degrees(theta_best)),
               f'mse_vs_theta_{pat}.png')
    save_curve(np.degrees(phi_range), mse_phi, 'MSE vs Phi', 'Phi (degrees)',
               'MSE vs Phi \n best phi = {:.2f} degrees'.format(np.degrees(phi_best)),
               f'mse_vs_phi_{pat}.png')
    save_curve(L_range, mse_L, 'MSE vs L', 'L',
               'MSE vs L \n best L = {:.2f}'.format(L_best),
               f'mse_vs_L_{pat}.png')

    print("Plots saved successfully.")



def _generate_init_mse_gif(mse_array, phis, thetas, output_path="/mnt/data/mse_heatmap.gif"):
    """
    Write an animated heatmap of ``mse_array`` whose colour-scale maximum shrinks frame by frame.

    The colour range starts at [min, min + 1e7], and vmax decreases linearly to the minimum
    over 100 frames. A red cross marks the (theta, phi) grid point with the lowest MSE.
    Rows of ``mse_array`` index ``phis`` and columns index ``thetas`` (both in radians).

    Returns the path of the written GIF.
    """
    vmin = np.min(mse_array)
    vmax_steps = np.linspace(vmin + 10**7, vmin, 100)  # 100 frames, vmax decreasing to vmin

    min_index = np.unravel_index(np.argmin(mse_array), mse_array.shape)
    min_phi = np.rad2deg(phis[min_index[0]])
    min_theta = np.rad2deg(thetas[min_index[1]])
    extent = [np.rad2deg(thetas[0]), np.rad2deg(thetas[-1]), np.rad2deg(phis[0]), np.rad2deg(phis[-1])]

    gif_frames = []
    for vmax in vmax_steps:
        fig, ax = plt.subplots()
        im = ax.imshow(mse_array, extent=extent, aspect='auto', origin='lower',
                       cmap='viridis', vmin=vmin, vmax=vmax)
        plt.colorbar(im, label='Mean Squared Error')
        plt.title('MSE vs. Polar and Azimuthal Angles')
        plt.xlabel('Polar Angle (θ)°')
        plt.ylabel('Azimuthal Angle (φ)°')
        plt.scatter(min_theta, min_phi, color='red', marker='x', s=100, label='Min MSE')
        plt.legend()

        # Rasterise the figure and keep its RGBA buffer as one GIF frame.
        fig.canvas.draw()
        gif_frames.append(np.array(fig.canvas.renderer.buffer_rgba()))
        plt.close(fig)

    imageio.mimsave(output_path, gif_frames, fps=10)
    print(f"GIF saved at: {output_path}")
    return output_path


def _plot_init_mse_heatmap(mse_array, phis, thetas, output_path=None, filename="mse_heatmap.png"):
    """
    Plot a static heatmap of the MSE values as a function of polar and azimuthal angles.

    Parameters
    ----------
    mse_array : numpy.ndarray
        2D array of MSE values; rows index ``phis``, columns index ``thetas``.
    phis : array-like
        Azimuthal angles (radians).
    thetas : array-like
        Polar angles (radians).
    output_path : str, optional
        Directory where the plot is saved. If None, the plot is not saved.
    filename : str, optional
        Name of the saved file.
    """
    vmin = np.min(mse_array)
    vmax = vmin + 10**7  # Fixed colour range above the minimum

    min_index = np.unravel_index(np.argmin(mse_array), mse_array.shape)
    min_phi = np.rad2deg(phis[min_index[0]])
    min_theta = np.rad2deg(thetas[min_index[1]])

    fig, ax = plt.subplots(figsize=(10, 10))
    im = ax.imshow(
        mse_array,
        extent=[np.rad2deg(thetas[0]), np.rad2deg(thetas[-1]), np.rad2deg(phis[0]), np.rad2deg(phis[-1])],
        aspect='auto', origin='lower', cmap='viridis', vmin=vmin, vmax=vmax
    )
    cbar = plt.colorbar(im, ax=ax)
    cbar.set_label('Mean Squared Error')
    ax.set_xlabel('Polar Angle (θ)°')
    ax.set_ylabel('Azimuthal Angle (φ)°')
    ax.scatter(min_theta, min_phi, color='red', marker='x', s=100, label='Min MSE')
    ax.legend()

    if output_path is not None:
        full_path = os.path.join(output_path, filename)
        fig.savefig(full_path)
        print(f"Plot saved at: {full_path}")

    plt.show()
    # Release the figure; otherwise one figure per patient stays open in non-interactive runs.
    plt.close(fig)


def param_initialization_2d(image, soft_tissue, image_plot, pat, output_path, interpolation=False, obj_val_update=None, plot = False, gradient = False, 
                            interpolation_method='linear', interpolator = None, multi_resolution_optimization = False, num_levels = 4, verification_list = None, body = None):
    """
    Grid-search the plane orientation (theta, phi) over planes through the center of mass.

    The sagittal reference direction (1, 0, 0) is rotated by Rz(theta) @ Ry(phi) for 10 theta
    values in [-25°, 25°] and 10 phi values in [-10°, 10°]. Each rotated plane is placed
    through the center of mass and scored with ``objective_function`` using L equal to that
    plane's distance from the origin, so the plane that is scored is the plane that is returned.
    The MSE grid and the plane list are cached as .npy files in ``output_path`` and reused
    when both exist. A middle-slice overlay of all candidate planes, a heatmap (PNG) and an
    animated heatmap (GIF) are written to ``output_path``.

    Parameters
    ----------
    image, soft_tissue, image_plot : ndarray
        Volumes forwarded to ``objective_function``. ``image_plot`` also provides the default
        center of mass.
    pat : str
        Patient identifier used in cache and plot filenames.
    output_path : str
        Directory for cached arrays and plots.
    interpolation, obj_val_update, interpolation_method, interpolator
        Forwarded to ``objective_function``.
    plot : bool
        Not used by this function; kept for interface compatibility.
    gradient : bool
        If True, ``gradient_theta_phi_L`` is also evaluated for every candidate
        (the gradients are collected locally but not returned).
    multi_resolution_optimization : bool
        If True, the last level of ``build_image_pyramid(..., num_levels)`` is used for
        ``image``, ``soft_tissue`` and ``image_plot``.
    num_levels : int
        Pyramid depth when ``multi_resolution_optimization`` is True.
    verification_list
        Not used by this function; kept for interface compatibility.
    body : ndarray, optional
        If given, its center of mass replaces the one computed from ``image_plot``.

    Returns
    -------
    tuple
        (A, B, C, D) of the best plane: unit normal (A, B, C) and D with A*x + B*y + C*z + D = 0.
    """
    print("Starting initialization...")
    start_initialization = time.time()
    if multi_resolution_optimization:
        # Work on the last pyramid level of each volume.
        image = build_image_pyramid(image, num_levels)[-1]
        soft_tissue = build_image_pyramid(soft_tissue, num_levels)[-1]
        image_plot = build_image_pyramid(image_plot, num_levels)[-1]

    if body is not None:
        # NOTE(review): body is not passed through build_image_pyramid, so with
        # multi_resolution_optimization this COM is in body's own (unreduced) grid — confirm.
        com = center_of_mass(body)
    else:
        com = calculate_center_of_mass(image_plot)
    com = np.array([com[1], com[0], com[2]])  # Swap first two components to match plane axes

    angle_rad_theta = np.deg2rad(25)
    angle_rad_phi = np.deg2rad(10)
    thetas = np.linspace(-angle_rad_theta, angle_rad_theta, 10)  # Polar angles
    phis = np.linspace(-angle_rad_phi, angle_rad_phi, 10)  # Azimuthal angles

    # Sagittal reference vector; only its direction matters after normalisation.
    param_vec = np.array([image.shape[0] // 2, 0, 0])

    mse_array_path = os.path.join(output_path, f"mse_array_heatmap_com_rot_patient_{pat}.npy")
    plane_params_path = os.path.join(output_path, f"plane_params_list_com_rot_patient_{pat}.npy")

    if os.path.exists(mse_array_path) and os.path.exists(plane_params_path):
        # NOTE: the cache is keyed only on `pat`; delete the .npy files to recompute after
        # changing the angle grid or the objective.
        mse_array = np.load(mse_array_path)
        mse_list = mse_array.flatten().tolist()
        plane_params_list = [tuple(p) for p in np.load(plane_params_path).tolist()]
    else:
        mse_list = []
        plane_params_list = []
        grad_list = []
        for phi_idx, phi in enumerate(phis):  # Rows of the MSE grid
            for theta_idx, theta in enumerate(thetas):  # Columns of the MSE grid
                # Rz(theta) @ Ry(phi)
                rotation_matrix = np.array([
                    [np.cos(theta) * np.cos(phi), -np.sin(theta), np.cos(theta) * np.sin(phi)],
                    [np.sin(theta) * np.cos(phi), np.cos(theta), np.sin(theta) * np.sin(phi)],
                    [-np.sin(phi), 0, np.cos(phi)]
                ])
                rotated = np.dot(rotation_matrix, param_vec)
                normal = rotated / np.linalg.norm(rotated)

                # Plane through the COM: normal . x + D = 0.
                D = -np.dot(normal, com)
                plane_params = (normal[0], normal[1], normal[2], D)
                plane_params_list.append(plane_params)

                # Angles of the normal with L = |D|, i.e. the same plane that is stored above.
                theta_sph, phi_sph, l = vector_to_angles(normal * abs(D))
                rot_params = np.array([theta_sph, phi_sph, l])

                mse = objective_function(
                    rot_params, image, soft_tissue, image_plot, pat,
                    interpolation=interpolation, obj_val_update=obj_val_update, interpolation_method = interpolation_method, interpolator = interpolator
                )
                mse_list.append(mse)

                if gradient:
                    grad = gradient_theta_phi_L(rot_params, image, soft_tissue, image_plot, pat, interpolation, obj_val_update, interpolation_method)
                    grad_list.append((theta_idx, phi_idx, grad[0], grad[1]))

        mse_array = np.array(mse_list).reshape(len(phis), len(thetas))
        np.save(mse_array_path, mse_array)
        np.save(plane_params_path, np.array(plane_params_list))

    plot_middle_slice_with_planes(image, plane_params_list, title=None, com=None, output_path=output_path, filename=f"middle_slice_rotated_planes_patient_{pat}.png")

    _plot_init_mse_heatmap(mse_array, phis, thetas, output_path=output_path, filename=f"mse_heatmap_patient_{pat}.png")
    _generate_init_mse_gif(mse_array, phis, thetas, output_path=os.path.join(output_path, f"mse_heatmap_patient_{pat}.gif"))

    best_plane_index = int(np.argmin(mse_list))
    # Always return a tuple, whether the planes were computed or loaded from the cache.
    best_plane_params = tuple(plane_params_list[best_plane_index])
    print(f"Best plane parameters: {best_plane_params}")
    print(f"Best MSE: {mse_list[best_plane_index]}")

    end_initialization = time.time()
    print(f"Time taken for initialization: {end_initialization - start_initialization:.2f} seconds")

    return best_plane_params


# def param_evaluation_3d(image, soft_tissue, image_plot, pat, output_path, interpolation=False, obj_val_update=None, 
#                         interpolation_method='linear', param_ranges=(10, 5, 10), num_steps=(10, 10, 10), 
#                         fixed_ref_params=(0, 0, 256)):
#     """
#     Evaluates the MSE over a 3D parameter space and generates heatmaps for visualization.
    
#     Parameters:
#         param_ranges (tuple): Range for (theta, phi, L) perturbations in degrees (theta, phi) and units (L).
#         num_steps (tuple): Number of steps for (theta, phi, L).
#         fixed_ref_params (tuple): Fixed (theta, phi, L) reference for visualization.
#     """
#     start_time = time.time()
    
#     A, B, C, D = fixed_ref_params
#     vec = np.array([A, B, C])
#     if np.linalg.norm(vec) < 2:
#         vec = vec * abs(D)
#     initial_theta, initial_phi, initial_l = vector_to_angles(vec)
#     initial_params = np.array([initial_theta, initial_phi, initial_l])
    
#     theta_range_deg, phi_range_deg, l_range = param_ranges
#     theta_steps, phi_steps, l_steps = num_steps
    
#     theta_range = np.deg2rad(theta_range_deg)
#     phi_range = np.deg2rad(phi_range_deg)
    
#     thetas = np.linspace(initial_theta - theta_range, initial_theta + theta_range, theta_steps)
#     phis = np.linspace(initial_phi - phi_range, initial_phi + phi_range, phi_steps)
#     ls = np.linspace(initial_l - l_range, initial_l + l_range, l_steps)
    
#     results = []
#     mse_cache = {}
    
#     def get_mse(theta, phi, l):
#         key = (theta, phi, l)
#         if key not in mse_cache:
#             mse_cache[key] = objective_function(
#                 [theta, phi, l], image, soft_tissue, image_plot, pat,
#                 interpolation=interpolation, obj_val_update=obj_val_update, interpolation_method=interpolation_method
#             )
#         return mse_cache[key]
        
#     # Compute MSE for all parameter combinations
#     for theta_index, theta in enumerate(thetas):
#         print(f"Progress: {(theta_index + 1) * 100 / len(thetas)} %")
#         for phi in phis:
#             for l in ls:
#                 mse = get_mse(theta, phi, l)
#                 results.append({"theta": theta, "phi": phi, "L": l, "mse": mse})


#     # Save results to JSON
#     results_path = os.path.join(output_path, f"mse_results_patient_{pat}.json")
#     with open(results_path, "w") as f:
#         json.dump(results, f, indent=4)

#     # Save results to CSV
#     csv_path = os.path.join(output_path, f"mse_results_patient_{pat}.csv")
#     with open(csv_path, mode='w', newline='') as file:
#         writer = csv.DictWriter(file, fieldnames=["theta", "phi", "L", "mse"])
#         writer.writeheader()
#         for result in results:
#             writer.writerow(result)

#     min_mse = min(mse_cache.values())
#     best_params = [key for key, value in mse_cache.items() if value == min_mse][0]
#     min_mse_point = {"theta": best_params[0], "phi": best_params[1], "L": best_params[2], "mse": min_mse}
    
#     return best_params, min_mse



def param_evaluation_3d(image, soft_tissue, image_plot, pat, output_path, interpolation=False, obj_val_update=None, 
                        interpolation_method='linear', param_ranges=(10, 5, 10), num_steps=(10, 10, 10), 
                        fixed_ref_params=(0, 0, 256), interpolator=None, use_parallel=False, n_jobs=15):
    """
    Evaluate the MSE on a regular (theta, phi, L) grid around a reference plane and save the results.

    Parameters:
        param_ranges (tuple): Half-widths of the grid for (theta, phi, L); theta and phi in
            degrees, L in plane-distance units.
        num_steps (tuple): Number of grid points for (theta, phi, L).
        fixed_ref_params (tuple): Grid center. Either three values (theta, phi, L) with angles
            in radians, or four plane values (A, B, C, D). In the four-value form, (A, B, C)
            with norm < 2 is treated as a unit normal and scaled by |D| before conversion
            to angles.
        interpolator: Forwarded to objective_function.
        use_parallel (bool): Whether to evaluate the grid with joblib.Parallel.
        n_jobs (int): Number of jobs for parallel execution.

    Returns:
        tuple: (best_result, best_mse). best_result is the dict {"theta", "phi", "L", "mse"}
        with the lowest MSE (angles in radians).

    Side effects:
        Writes mse_results_patient_{pat}.json and mse_results_patient_{pat}.csv into
        output_path (created if missing).
    """
    print("Starting parameter sampling...")
    start_time = time.time()

    if len(fixed_ref_params) == 3:
        # (theta, phi, L) form, as documented; the default (0, 0, 256) takes this path.
        initial_theta, initial_phi, initial_l = fixed_ref_params
    else:
        A, B, C, D = fixed_ref_params
        vec = np.array([A, B, C])
        if np.linalg.norm(vec) < 2:
            vec = vec * abs(D)
        initial_theta, initial_phi, initial_l = vector_to_angles(vec)

    theta_range_deg, phi_range_deg, l_range = param_ranges
    theta_steps, phi_steps, l_steps = num_steps

    theta_range = np.deg2rad(theta_range_deg)
    phi_range = np.deg2rad(phi_range_deg)

    thetas = np.linspace(initial_theta - theta_range, initial_theta + theta_range, theta_steps)
    phis = np.linspace(initial_phi - phi_range, initial_phi + phi_range, phi_steps)
    ls = np.linspace(initial_l - l_range, initial_l + l_range, l_steps)
    total = len(thetas) * len(phis) * len(ls)

    mse_cache = {}

    def get_mse(theta, phi, l):
        key = (theta, phi, l)
        if key not in mse_cache:
            mse_cache[key] = objective_function(
                [theta, phi, l], image, soft_tissue, image_plot, pat,
                interpolation=interpolation, obj_val_update=obj_val_update,
                interpolation_method=interpolation_method, interpolator=interpolator
            )
        return mse_cache[key]

    def compute_mse_for_combination(theta, phi, l):
        # float() so json.dump also works when the objective returns a numpy scalar
        # that is not a Python float subclass (e.g. np.float32).
        return {"theta": theta, "phi": phi, "L": l, "mse": float(get_mse(theta, phi, l))}

    parameter_combinations = ((theta, phi, l) for theta in thetas for phi in phis for l in ls)

    if use_parallel:
        batch_size = max(1, total // (10 * n_jobs))  # Reduce job overhead
        with tqdm(total=total, desc="Processing") as pbar:
            results = Parallel(n_jobs=n_jobs, batch_size=batch_size)(
                delayed(compute_mse_for_combination)(theta, phi, l)
                for theta, phi, l in parameter_combinations
            )
            pbar.update(len(results))
    else:
        results = [
            compute_mse_for_combination(theta, phi, l)
            for theta, phi, l in tqdm(parameter_combinations, total=total, desc="Processing")
        ]

    os.makedirs(output_path, exist_ok=True)
    results_path = os.path.join(output_path, f"mse_results_patient_{pat}.json")
    csv_path = os.path.join(output_path, f"mse_results_patient_{pat}.csv")

    with open(results_path, "w") as f_json:
        json.dump(results, f_json, indent=4)

    with open(csv_path, mode='w', newline='') as f_csv:
        writer = csv.DictWriter(f_csv, fieldnames=["theta", "phi", "L", "mse"])
        writer.writeheader()
        writer.writerows(results)

    best_result = min(results, key=lambda x: x['mse'])

    end_time = time.time()
    print(f"Time taken: {end_time - start_time:.2f} seconds")
    print(f"Minimum MSE: {best_result['mse']:.2f}, Best Parameters: θ={best_result['theta']:.2f}, φ={best_result['phi']:.2f}, L={best_result['L']:.2f}")

    return best_result, best_result['mse']



def _plane_params_to_angles(plane_params):
    """Convert plane parameters (A, B, C, D) to (theta, phi, L) with ``vector_to_angles``.

    If (A, B, C) looks like a unit normal (norm < 2) it is first rescaled by |D>, so that the
    vector points from the origin to the plane, as in the (theta, phi, L) parameterisation.
    """
    A, B, C, D = plane_params
    vec = np.array([A, B, C])
    if np.linalg.norm(vec) < 2:
        vec = vec * abs(D)
    return vector_to_angles(vec)


def _plot_mse_heatmap(matrix, title, xlabel, ylabel, extent, v_min, v_max,
                      opt_point, opt_label, ref_point, ref_label):
    """Open a new figure showing one MSE heatmap with two marked points.

    ``opt_point`` is drawn in green and ``ref_point`` in blue. The figure is left open so the
    caller decides whether to save, show or close it.
    """
    plt.figure(figsize=(10, 8))
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.imshow(matrix, cmap='hot', interpolation='nearest', aspect='auto', origin='lower',
               extent=extent, vmin=v_min, vmax=v_max)
    plt.colorbar(label='MSE')
    plt.scatter(*opt_point, color='green', label=opt_label)
    plt.scatter(*ref_point, color='blue', label=ref_label)
    plt.legend()


def _save_figure_as_frame(path, frames, remove=False):
    """Save and close the current figure, then append the saved image to ``frames``.

    When ``remove`` is True the PNG is deleted after it has been read back.
    """
    plt.savefig(path)
    plt.close()
    frames.append(imageio.imread(path))
    if remove:
        os.remove(path)


def obj_fun_heatmap(sampling_results_csv_path, output_path, optimized_params, param_ranges, num_steps, method = 'Nelder-Mead', sampling_params = None, heatmap_range = None,
                    slice_gifs = False):
    """
    Visualise the sampled objective-function (MSE) space around an optimized plane.

    Three 2-D slices through the sampled (theta, phi, L) space are drawn. Each slice is taken
    at the sampled value closest to the optimized parameters along one axis. Both the optimized
    parameters and the sampled MSE minimum are marked. In addition, a GIF is written for each
    slice in which the upper colour limit sweeps from v_max down to v_min.

    Parameters:
    sampling_results_csv_path (str): CSV with columns 'theta', 'phi', 'L' (angles in radians) and 'mse'.
    output_path (str): Directory where the figures and GIFs are written.
    optimized_params (tuple): Optimized plane parameters (A, B, C, D).
    param_ranges (tuple): Half-widths (theta_range_deg, phi_range_deg, l_range) of the plotted grid.
    num_steps (tuple): Grid sizes (theta_steps, phi_steps, l_steps). The grid is only used for the
        axis extents and slice selection; it is assumed to match the grid that produced the CSV.
    method (str): Label used for the optimizer marker and the file names of the slice GIFs.
    sampling_params (tuple, optional): Plane parameters (A, B, C, D) at the centre of the grid.
        Defaults to ``optimized_params``.
    heatmap_range (tuple, optional): (vmin, vmax) colour limits. Defaults to the sampled
        (min, max) MSE, or to (min MSE, MSE at the closest sample) when ``slice_gifs`` is True.
    slice_gifs (bool, optional): If True, write GIFs that sweep through the slices between the
        sampled minimum and the optimized parameters instead of showing the three closest slices.

    Returns:
    tuple: (best_params, min_mse, diff_params_deg), where
        - best_params is the (A, B, C, D) array of the sampled minimum;
        - min_mse is its MSE;
        - diff_params_deg is (|Δθ| in degrees, |Δφ| in degrees, |ΔL|) between the sampled
          minimum and the optimized parameters.
    """
    print("Starting heatmap generation...")
    start_time = time.time()

    # Work in (theta, phi, L) space, the parameterisation used in the sampling CSV.
    optimized_angles = _plane_params_to_angles(optimized_params)
    # Without explicit sampling parameters, centre the grid on the optimized plane
    # (previously unpacking the default None raised a TypeError).
    if sampling_params is None:
        sampling_params = optimized_params
    sampling_angles = _plane_params_to_angles(sampling_params)

    theta_range_deg, phi_range_deg, l_range = param_ranges
    theta_steps, phi_steps, l_steps = num_steps
    theta_range = np.deg2rad(theta_range_deg)
    phi_range = np.deg2rad(phi_range_deg)

    # Parameter grids (angles in radians) centred on the sampling parameters.
    thetas = np.linspace(sampling_angles[0] - theta_range, sampling_angles[0] + theta_range, theta_steps)
    phis = np.linspace(sampling_angles[1] - phi_range, sampling_angles[1] + phi_range, phi_steps)
    ls = np.linspace(sampling_angles[2] - l_range, sampling_angles[2] + l_range, l_steps)

    df = pd.read_csv(sampling_results_csv_path)

    # Sampled minimum (reference) and maximum of the objective function.
    min_mse_row = df.loc[df['mse'].idxmin()]
    min_mse = min_mse_row['mse']
    max_mse = df.loc[df['mse'].idxmax(), 'mse']
    min_mse_plane_params = np.array([min_mse_row['theta'], min_mse_row['phi'], min_mse_row['L']])

    # Absolute difference between sampled minimum and optimized parameters (theta/phi in degrees).
    diff_params = np.abs(min_mse_plane_params - np.array(optimized_angles))
    diff_params_deg = np.array([np.rad2deg(diff_params[0]), np.rad2deg(diff_params[1]), diff_params[2]])

    theta_closest_deg = np.rad2deg(optimized_angles[0])
    phi_closest_deg = np.rad2deg(optimized_angles[1])
    L_closest = optimized_angles[2]

    # Sampled values nearest to the optimized parameters, one per axis.
    # idxmin returns an index *label*, so it must be looked up with .loc (not .iloc).
    closest_theta = df.loc[(df['theta'] - optimized_angles[0]).abs().idxmin(), 'theta']
    closest_phi = df.loc[(df['phi'] - optimized_angles[1]).abs().idxmin(), 'phi']
    closest_L = df.loc[(df['L'] - optimized_angles[2]).abs().idxmin(), 'L']

    thetas_extent = [np.rad2deg(min(thetas)), np.rad2deg(max(thetas))]
    phis_extent = [np.rad2deg(min(phis)), np.rad2deg(max(phis))]
    ls_extent = [min(ls), max(ls)]

    ref_theta_deg = np.rad2deg(min_mse_plane_params[0])
    ref_phi_deg = np.rad2deg(min_mse_plane_params[1])
    ref_L = min_mse_plane_params[2]

    # One entry per fixed axis: how to slice the data and how to draw the slice.
    views = {
        'theta': {
            'column': 'theta', 'grid': thetas, 'closest': closest_theta, 'ref': min_mse_plane_params[0],
            'pivot': ('L', 'phi'), 'to_display': np.rad2deg,
            'slice_title': "Heatmap of MSE for phi vs L \nθ = {:.2f}°",
            'title': f"Heatmap of MSE for phi vs L \nθ = {theta_closest_deg:.2f}° \nΔ = {np.rad2deg(diff_params[0]):.2f}°",
            'xlabel': "φ [°]", 'ylabel': "L",
            'extent': [phis_extent[0], phis_extent[1], ls_extent[0], ls_extent[1]],
            'opt_point': (phi_closest_deg, L_closest), 'ref_point': (ref_phi_deg, ref_L),
        },
        'phi': {
            'column': 'phi', 'grid': phis, 'closest': closest_phi, 'ref': min_mse_plane_params[1],
            'pivot': ('L', 'theta'), 'to_display': np.rad2deg,
            'slice_title': "Heatmap of MSE for theta vs L \nφ = {:.2f}°",
            'title': f"Heatmap of MSE for theta vs L \nφ = {phi_closest_deg:.2f}° \nΔ = {np.rad2deg(diff_params[1]):.2f}°",
            'xlabel': "θ [°]", 'ylabel': "L",
            'extent': [thetas_extent[0], thetas_extent[1], ls_extent[0], ls_extent[1]],
            'opt_point': (theta_closest_deg, L_closest), 'ref_point': (ref_theta_deg, ref_L),
        },
        'L': {
            'column': 'L', 'grid': ls, 'closest': closest_L, 'ref': min_mse_plane_params[2],
            'pivot': ('phi', 'theta'), 'to_display': lambda value: value,
            'slice_title': "Heatmap of MSE for theta vs phi \nL = {:.2f}",
            'title': f"Heatmap of MSE for theta vs phi \nL = {L_closest:.2f} \nΔ = {diff_params[2]:.2f}",
            'xlabel': "θ [°]", 'ylabel': "φ [°]",
            'extent': [thetas_extent[0], thetas_extent[1], phis_extent[0], phis_extent[1]],
            'opt_point': (theta_closest_deg, phi_closest_deg), 'ref_point': (ref_theta_deg, ref_phi_deg),
        },
    }

    # 2-D slices at the sampled values closest to the optimized parameters. These are always
    # built, so the colour-bar GIFs below work in both modes.
    for view in views.values():
        index, columns = view['pivot']
        view['matrix'] = df[df[view['column']] == view['closest']].pivot(index=index, columns=columns, values='mse')

    # Default colour range; an explicit heatmap_range always takes precedence.
    if slice_gifs:
        print("Generating heatmaps for theta, phi, and L...")
        closest_rows = df[(df['theta'] == closest_theta) & (df['phi'] == closest_phi) & (df['L'] == closest_L)]
        if not closest_rows.empty:
            mse_closest_opt = closest_rows['mse'].values[0]
            print(f"MSE for closest optimized parameters: {mse_closest_opt}")
        else:
            print("No matching row found for the closest optimized parameters.")
            mse_closest_opt = max_mse
        default_range = (min_mse, mse_closest_opt)
    else:
        default_range = (min_mse, max_mse)
    v_min, v_max = default_range if heatmap_range is None else heatmap_range

    if slice_gifs:
        # Sweep through the grid values lying between the sampled minimum and the closest sample.
        for axis, view in views.items():
            axis_dir = os.path.join(output_path, f"heatmaps_{axis}")
            os.makedirs(axis_dir, exist_ok=True)
            lower, upper = sorted([view['ref'], view['closest']])
            grid = view['grid']
            index, columns = view['pivot']
            frames = []
            for value in grid[(grid >= lower) & (grid <= upper)]:
                nearest = df[view['column']].iloc[(df[view['column']] - value).abs().argmin()]
                matrix = df[df[view['column']] == nearest].pivot(index=index, columns=columns, values='mse')
                shown = view['to_display'](value)
                _plot_mse_heatmap(matrix, view['slice_title'].format(shown), view['xlabel'], view['ylabel'],
                                  view['extent'], v_min, v_max,
                                  view['opt_point'], method, view['ref_point'], 'Exhaustive-Search minimum')
                _save_figure_as_frame(os.path.join(axis_dir, f"heatmap_{method}_{axis}_{shown:.2f}.png"), frames)
            if frames:
                imageio.mimsave(os.path.join(output_path, f"heatmaps_{axis}.gif"), frames, duration=2.5)
    else:
        for axis, view in views.items():
            _plot_mse_heatmap(view['matrix'], view['title'], view['xlabel'], view['ylabel'], view['extent'],
                              v_min, v_max, view['opt_point'], 'Optimized Parameters',
                              view['ref_point'], 'Minimum sampled space')
            plt.savefig(os.path.join(output_path, f"heatmap_{axis}_closest.png"))
            plt.show()

    # GIFs in which the upper colour limit sweeps from v_max down to v_min.
    heatmap_gif_dir = os.path.join(output_path, "heatmap_gifs_colorbar_adjust")
    os.makedirs(heatmap_gif_dir, exist_ok=True)
    num_frames = 100
    vmax_values = np.linspace(v_max, v_min, num_frames)
    for axis, view in views.items():
        frames = []
        for vmax in vmax_values:
            _plot_mse_heatmap(view['matrix'], view['title'], view['xlabel'], view['ylabel'], view['extent'],
                              v_min, vmax, view['opt_point'], 'Optimized Parameters',
                              view['ref_point'], 'Minimum sampled space')
            # The per-frame PNG is only a temporary file for the GIF.
            _save_figure_as_frame(os.path.join(heatmap_gif_dir, f"heatmap_{axis}_{vmax:.2f}.png"), frames, remove=True)
        imageio.mimsave(os.path.join(heatmap_gif_dir, f"heatmap_{axis}_colorbar_adjust.gif"), frames, duration=0.5)

    print("GIFs generated successfully!")

    # Plane parameters (A, B, C, D) of the sampled minimum; D = -|v|^2 so that v lies on the plane.
    A, B, C = angles_to_vector(min_mse_plane_params[0], min_mse_plane_params[1], min_mse_plane_params[2])
    D = -np.dot([A, B, C], [A, B, C])
    best_params = (A, B, C, D)

    end_time = time.time()
    print(f"Time taken for heatmap generation: {end_time - start_time:.2f} seconds")
    print(f"Minimum MSE: {min_mse:.2f}, Best Parameters: θ={min_mse_plane_params[0]:.2f}, φ={min_mse_plane_params[1]:.2f}, L={min_mse_plane_params[2]:.2f}")

    return np.array(best_params), min_mse, diff_params_deg

def generate_gif_with_varying_vmax(matrix, title, xlabel, ylabel, extent, output_path, filename_prefix, v_min, v_max, steps=20,
                                   opt_point=None, ref_point=None,
                                   opt_label='Optimized Parameters', ref_label='Minimum sampled space'):
    """
    Generate a GIF of one heatmap whose upper colour limit (vmax) steps from v_min to v_max.

    The previous version referenced ``theta_closest_deg``, ``phi_closest_deg`` and
    ``min_mse_plane_params``, which are not defined in this scope, and raised NameError.
    Marker positions are now passed in explicitly.

    Parameters:
    matrix (array-like): 2-D data passed to ``plt.imshow``.
    title (str): Figure title.
    xlabel (str): X-axis label.
    ylabel (str): Y-axis label.
    extent (list): [x_min, x_max, y_min, y_max] passed to ``plt.imshow``.
    output_path (str): Directory for the per-frame PNG files and the GIF.
    filename_prefix (str): Prefix of the frame files; the GIF is written as '<prefix>.gif'.
    v_min (float): Fixed lower colour limit (also the first vmax).
    v_max (float): Last upper colour limit.
    steps (int, optional): Number of frames. Defaults to 20.
    opt_point (tuple, optional): (x, y) of the optimized parameters, drawn in green if given.
    ref_point (tuple, optional): (x, y) of the sampled minimum, drawn in blue if given.
    opt_label (str, optional): Legend label for ``opt_point``.
    ref_label (str, optional): Legend label for ``ref_point``.

    The frame PNGs are left on disk next to the GIF.
    """
    # Only draw (and label) the markers that were actually supplied.
    markers = [
        (point, color, label)
        for point, color, label in ((opt_point, 'green', opt_label), (ref_point, 'blue', ref_label))
        if point is not None
    ]
    gif_images = []
    vmax_values = np.linspace(v_min, v_max, steps)

    for vmax in vmax_values:
        plt.figure(figsize=(10, 8))
        plt.title(title)
        plt.xlabel(xlabel)
        plt.ylabel(ylabel)
        plt.imshow(matrix, cmap='hot', interpolation='nearest', aspect='auto', origin='lower', extent=extent, vmin=v_min, vmax=vmax)
        plt.colorbar(label='MSE')
        for point, color, label in markers:
            plt.scatter(*point, color=color, label=label)
        if markers:
            plt.legend()
        heatmap_path = os.path.join(output_path, f"{filename_prefix}_vmax_{vmax:.2f}.png")
        plt.savefig(heatmap_path)
        plt.close()
        gif_images.append(imageio.imread(heatmap_path))

    if len(gif_images) > 0:
        imageio.mimsave(os.path.join(output_path, f"{filename_prefix}.gif"), gif_images, duration=0.5)

def shift(bone, soft_tissue, range = None, steps = None, interpolation='bone', interpolation_method='cubic', metric = 'mse', interpolator = None):
    """
    Shift a plane with normal (1, 0, 0) across the image and evaluate the objective for each shift.

    The plane is placed relative to ``calculate_center_of_mass(bone)[1]`` (``middle_x``). For a
    shift ``s`` the evaluated plane is ``(1, 0, 0, -middle_x + s)``, i.e. the plane ``x = middle_x - s``.
    The objective values are plotted against that plane position, and the evaluated planes are
    drawn on the middle slice.

    Parameters:
    bone (ndarray): The bone image data.
    soft_tissue (ndarray): The soft tissue image data.
    range (tuple, optional): (start, stop) of the shifts. Defaults to (-150, 150).
    steps (int, optional): Number of shifts within the range. Defaults to 20.
    interpolation (str, optional): The type of interpolation to use. Defaults to 'bone'.
    interpolation_method (str, optional): The method of interpolation. Defaults to 'cubic'.
    metric (str, optional): Metric passed to ``compute_intensity_metric``. Defaults to 'mse'.
    interpolator (optional): Interpolator passed to ``compute_intensity_metric``.

    Returns:
    tuple: The plane equation (A, B, C, D) with the shift that minimizes the objective function value.
    """
    com = calculate_center_of_mass(bone)
    middle_x = com[1]

    start, stop = range if range else (-150, 150)
    # Previously a user-given `steps` was overwritten with 20 whenever `range` was omitted,
    # and `steps=None` with an explicit `range` made np.linspace raise.
    if steps is None:
        steps = 20
    shift_values = np.linspace(start, stop, steps)
    # Actual x-position of each evaluated plane (1*x + D = 0  ->  x = -D = middle_x - shift).
    x_positions = middle_x - shift_values

    planes = []
    sum_values = []
    for offset in shift_values:
        shifted_plane = (1, 0, 0, -middle_x + offset)
        indices_shifted, distances_shifted = compute_distances_and_indices(bone, shifted_plane)
        sum_shifted, _percentage_shifted = compute_intensity_metric(bone, soft_tissue, indices_shifted, distances_shifted, shifted_plane, metric_type=metric, interpolation = interpolation,
                                                                    interpolation_method = interpolation_method, interpolator = interpolator)
        sum_values.append(sum_shifted)
        planes.append(shifted_plane)

    plt.plot(x_positions, sum_values, color = 'red')
    plt.xlabel('X Shift')
    plt.ylabel('Objective Function Value')
    plt.title('Objective Function Value vs. X Shift')
    plt.grid(True)
    plt.show()

    plot_middle_slice_with_planes(bone, planes, title='Middle Slice with Shifted Plane')
    return planes[np.argmin(sum_values)]

def plot_objective_value_colormap(image, soft_tissue, image_plot, pat, initial_plane_params, shift_range_angle=(-5, 5), shift_range_L = (-10,10), shift_range_A = (-20,20), shift_range_B = (-20,20), 
                                  shift_steps=20, interpolation=False, interpolation_method='linear', interpolator = None):
    """
    Plot the objective value as a colormap over a grid of the plane-vector components A and B.

    The plane vector (A, B, C) of the initial plane is offset in A and B, with C held fixed.
    Each grid point is converted to angles with ``vector_to_angles`` and scored with
    ``objective_function``. The grid is drawn with ``plot_objective_value_colormap_from_values``,
    and every evaluated plane is drawn on the middle slice.

    Parameters:
    image (ndarray): The image data.
    soft_tissue (ndarray): The soft tissue mask data.
    image_plot (ndarray): The image data for plotting.
    pat (int): Patient identifier.
    initial_plane_params (tuple): Initial plane parameters (A, B, C, D). A unit-length (A, B, C)
        is rescaled by |D| so that it points from the origin to the plane.
    shift_range_angle (tuple): Currently unused; kept for backward compatibility.
    shift_range_L (tuple): Currently unused; kept for backward compatibility.
    shift_range_A (tuple): (min, max) offsets added to component A.
    shift_range_B (tuple): (min, max) offsets added to component B.
    shift_steps (int): Number of grid points per component.
    interpolation (bool): Whether to use interpolation in the objective function.
    interpolation_method (str): Interpolation method passed to ``objective_function``.
    interpolator (optional): Interpolator passed to ``objective_function``.
    """
    A, B, C, D = initial_plane_params

    vec = np.array([A, B, C])
    vec_len = np.linalg.norm(vec)
    # Normalized vectors rarely have a norm of exactly 1.0 (e.g. 0.9999999999999999), so the
    # former `vec_len == 1` test could miss them and centre the grid on the unit normal.
    if np.isclose(vec_len, 1.0):
        vec = vec / vec_len * abs(D)
    A, B, C = vec[0], vec[1], vec[2]

    shift_values_A = np.linspace(A + shift_range_A[0], A + shift_range_A[1], shift_steps)
    shift_values_B = np.linspace(B + shift_range_B[0], B + shift_range_B[1], shift_steps)
    objective_values = np.zeros((shift_steps, shift_steps))
    shifted_plane_params_list = []

    for i, shift_A in enumerate(shift_values_A):
        for j, shift_B in enumerate(shift_values_B):
            candidate = np.array([shift_A, shift_B, C])
            # D = -|v|^2 so that the point v lies on the plane, as elsewhere in this file.
            D_candidate = -np.dot(candidate, candidate)
            shifted_plane_params_list.append((shift_A, shift_B, C, D_candidate))
            plane = vector_to_angles((shift_A, shift_B, C))
            objective_values[i, j] = objective_function(plane, image, soft_tissue, image_plot, pat, interpolation, obj_val_update=[], interpolation_method = interpolation_method, interpolator = interpolator)

    plot_objective_value_colormap_from_values(shift_values_A, shift_values_B, objective_values, "Parameter A", "Parameter B")

    plot_middle_slice_with_planes(image, shifted_plane_params_list, title='Middle Slice with Shifted Plane')

### Statistics
Computes several metrics for the planes found by the different optimization methods. It compares them with the reference plane, i.e. the sampled minimum of the objective function found by an Exhaustive-Search of the parameter space around the optimized point.


In [None]:

def _gtvp_side_metrics(gtvp, plane_params, voxel_size, extension):
    """Compute GTVp distance extremes and volumes on both sides of one plane.

    Returns a dict with the keys 'min_distance_ipsi', 'max_distance_ipsi', 'volume_ipsi'
    and the corresponding '_contra' keys. Values come from ``gtvp_max_min_distance``
    (real-space distances) and ``gtvp_volume``.
    """
    plane_params_real = real_params(plane_params, voxel_size)
    indices, distances, distances_real = compute_distances_and_indices(gtvp, plane_params, plane_params_real, voxel_size=voxel_size)
    assigned_indices, _assigned_distances, assigned_distances_real = assign_distances_and_indices(extension, indices, distances, distances_real=distances_real)

    metrics = {}
    for side, suffix in (('ipsilateral', 'ipsi'), ('contralateral', 'contra')):
        _minmax_indices, minmax_distances = gtvp_max_min_distance(assigned_indices[side], assigned_distances_real[side])
        metrics[f'min_distance_{suffix}'] = minmax_distances[0]
        metrics[f'max_distance_{suffix}'] = minmax_distances[1]
        metrics[f'volume_{suffix}'] = gtvp_volume(assigned_indices[side], voxel_size)
    return metrics


def statistics(csv_filepath, base_path, pat_num, best_plane_params_path_list, best_plane_params_csv_reference_path,
                                          methods_list=('Exhaustive-Search', 'Nelder-Mead', 'Gradient Descent'), output_path='.',
                                          param_ranges=(10, 5, 10), num_steps=(10, 10, 10), sampling_params = None):    
    """
    Compute GTVp distance/volume metrics for several optimization methods and save them as a CSV.

    The reference plane is the minimum-MSE row of ``best_plane_params_csv_reference_path``.
    For every method, the plane loaded from the matching ``.npy`` path is compared with the
    reference. ``obj_fun_heatmap`` is also run for each method, and its parameter differences
    are recorded.

    Parameters:
    csv_filepath (str): Patient CSV; row ``pat_num`` is passed to ``process_patient_data``.
    base_path (str): Base directory of the patient data.
    pat_num (int): Row index of the patient in ``csv_filepath``.
    best_plane_params_path_list (list): ``.npy`` files with plane parameters (A, B, C, D), one per method.
    best_plane_params_csv_reference_path (str): Sampling CSV with columns 'theta', 'phi', 'L', 'mse'.
    methods_list (sequence of str): Method names, in the same order as the paths.
    output_path (str): Output root; replaced by the patient output folder from ``process_patient_data``.
    param_ranges (tuple): Passed to ``obj_fun_heatmap``.
    num_steps (tuple): Passed to ``obj_fun_heatmap``.
    sampling_params (tuple, optional): Passed to ``obj_fun_heatmap``.

    Raises:
    ValueError: If fewer parameter paths than method names are given.

    Writes 'statistics.csv' with one row per metric, one column per method, and a final
    'Exhaustive-Search' column holding the reference values.
    """
    methods_list = list(methods_list)
    best_plane_params_path_list = list(best_plane_params_path_list)
    # Fail before the expensive processing instead of with a KeyError while writing the CSV.
    if len(best_plane_params_path_list) < len(methods_list):
        raise ValueError(
            f"Got {len(best_plane_params_path_list)} parameter paths for {len(methods_list)} methods."
        )

    df = pd.read_csv(csv_filepath)
    row = df.iloc[pat_num]
    image, gtvp, body, spinalcord, mandibula, structure_images, voxel_size, patient_folder_path, output_path_patient, patient_id, extension, pat_num = process_patient_data(row, base_path, output_path, pat_num)
    output_path = output_path_patient

    # Reference plane: sampled minimum of the objective function, converted to (A, B, C, D).
    reference_df = pd.read_csv(best_plane_params_csv_reference_path)
    min_mse_row = reference_df.loc[reference_df['mse'].idxmin()]
    A, B, C = angles_to_vector(min_mse_row['theta'], min_mse_row['phi'], min_mse_row['L'])
    D = -np.dot([A, B, C], [A, B, C])
    best_plane_params_reference = (A, B, C, D)
    reference_metrics = _gtvp_side_metrics(gtvp, best_plane_params_reference, voxel_size, extension)

    results = {}
    for best_plane_params_path, method in zip(best_plane_params_path_list, methods_list):
        best_plane_params = np.load(best_plane_params_path)
        method_path = os.path.join(output_path, method)
        os.makedirs(method_path, exist_ok=True)
        _params_sampled, _obj_fun_sampled, difference = obj_fun_heatmap(best_plane_params_csv_reference_path, method_path, best_plane_params, param_ranges=param_ranges,
                                                                        num_steps=num_steps, method=method, sampling_params=sampling_params)

        metrics = _gtvp_side_metrics(gtvp, best_plane_params, voxel_size, extension)
        results[method] = {
            'difference_theta': difference[0],
            'difference_phi': difference[1],
            'difference_L': difference[2],
            **metrics,
            'difference_ipsi_shortest': metrics['min_distance_ipsi'] - reference_metrics['min_distance_ipsi'],
            'difference_ipsi_longest': metrics['max_distance_ipsi'] - reference_metrics['max_distance_ipsi'],
            'difference_contra_shortest': metrics['min_distance_contra'] - reference_metrics['min_distance_contra'],
            'difference_contra_longest': metrics['max_distance_contra'] - reference_metrics['max_distance_contra'],
            'volume_ipsi_diff': metrics['volume_ipsi'] - reference_metrics['volume_ipsi'],
            'volume_contra_diff': metrics['volume_contra'] - reference_metrics['volume_contra'],
        }

    metric_names = [
        'difference_theta', 'difference_phi', 'difference_L',
        'min_distance_ipsi', 'max_distance_ipsi',
        'min_distance_contra', 'max_distance_contra',
        'volume_ipsi', 'volume_contra',
        'difference_ipsi_shortest', 'difference_ipsi_longest',
        'difference_contra_shortest', 'difference_contra_longest',
        'volume_ipsi_diff', 'volume_contra_diff'
    ]

    # The reference is kept separate from `results`, so a method that is itself named
    # 'Exhaustive-Search' keeps its own values. Difference metrics are blank for the reference.
    reference_results = {name: reference_metrics.get(name, '') for name in metric_names}

    csv_file_path = os.path.join(output_path, 'statistics.csv')
    with open(csv_file_path, mode='w', newline='') as file:
        writer = csv.writer(file)
        # Methods as columns, followed by the 'Exhaustive-Search' reference as the last column.
        writer.writerow(['Metric'] + methods_list + ['Exhaustive-Search'])
        for metric in metric_names:
            writer.writerow([metric] + [results[method].get(metric, '') for method in methods_list] + [reference_results[metric]])

    print(f"CSV file saved at {csv_file_path}")


## Optimization pipeline 
Optimizes the midplane separating the ipsilateral and contralateral sides of the neck in oropharynx patients. The inputs are a spreadsheet listing the patients and one folder per patient containing that patient's images and masks. The number of patients to process can be set through a function argument.

In [None]:
def _midline_preprocess(image, body, gtvp, spinalcord, mandibula, image_range='full', rotation_angle=None):
    """
    Prepare a CT volume and its structure masks for midplane optimization.

    Steps, in this order:
      1. Optionally rotate all volumes by `rotation_angle` (rotate_3d_array, fill value 0).
      2. For image_range == 'full', crop to the slices chosen by select_slices from the
         valley of the per-slice voxel-count curve. The structure masks are cropped together
         with the image so they stay aligned with it.
      3. Erode the body mask (2 iterations) and set every voxel outside it to -1000 HU,
         which removes fiducial markers outside the body.
      4. Build a bone-only image (900-2500 HU, everything else 0) and replace dental
         fillings (2500-5000 HU) by a tooth-like value of 2000 HU, both in the bone image
         and in the full image.
      5. For image_range == 'gtvp', crop all volumes to the slice range of the GTVp
         reported by get_nonzero_slice_range.

    Returns:
    dict: keys 'image', 'image_plot', 'bone_ct', 'soft_tissue', 'body', 'gtvp',
    'spinalcord', 'mandibula'.
    """
    if rotation_angle is not None:
        image, body, gtvp, spinalcord, mandibula = (
            rotate_3d_array(volume, angle=rotation_angle, fill_value=0)
            for volume in (image, body, gtvp, spinalcord, mandibula)
        )

    if image_range == 'full':
        voxel_counts = count_voxels_per_slice(image, plot=False)
        valley_index = gradient_descent_voxel_counts(voxel_counts, plot=False)
        image, body, spinalcord, mandibula, gtvp = select_slices(
            valley_index, image, body, spinal_cord=spinalcord, mandible=mandibula, gtvp=gtvp)

    image = image.astype(np.int16)  # Avoid overflow errors in the HU arithmetic below

    # Erode the body mask and blank everything outside it (gets rid of fiducial markers)
    body = binary_erosion(body, iterations=2).astype(np.uint8)
    image = np.where(body == 1, image, -1000)

    # Bone-only image: everything outside the bone HU window becomes 0
    bone_mask = mask_via_threshold(image, HU_range=(900, 2500)).astype(np.uint16)
    bone_ct = image * bone_mask

    # Substitute dental fillings by a tooth-like HU value in the bone and full images
    dental_fillings_mask = mask_via_threshold(image, HU_range=(2500, 5000)).astype(np.uint16)
    dental_bone_ct = 2000 * dental_fillings_mask
    bone_ct = bone_ct + dental_bone_ct
    image = image * (1 - dental_fillings_mask) + dental_bone_ct

    image_plot = np.copy(image)
    soft_tissue = mask_via_threshold(image, HU_range=(-2000, 900)).astype(np.uint16)

    volumes = {
        'image': image,
        'image_plot': image_plot,
        'bone_ct': bone_ct,
        'soft_tissue': soft_tissue,
        'body': body,
        'gtvp': gtvp,
        'spinalcord': spinalcord,
        'mandibula': mandibula,
    }
    if image_range == 'gtvp':
        start, end = get_nonzero_slice_range(gtvp)
        volumes = {name: volume[:, :, start:end] for name, volume in volumes.items()}
    return volumes


def _midline_load_or_build_interpolator(bone_ct, image, output_path_patient, pat_num, interpolation, interpolation_method):
    """
    Return a RegularGridInterpolator over the bone-only or full image, cached with joblib.

    If ``<output_path_patient>/interpolator.joblib`` exists it is loaded as-is, regardless
    of `interpolation`, `interpolation_method` or the image range used to build it; delete
    the file after changing any of these.

    Raises:
    ValueError: if no cache exists and `interpolation` is neither 'bone' nor 'full'.
    """
    cache_path = os.path.join(output_path_patient, 'interpolator.joblib')
    if os.path.exists(cache_path):
        # joblib.load unpickles the file: only load caches written by this pipeline.
        return joblib.load(cache_path)

    volumes = {'bone': bone_ct, 'full': image}
    if interpolation not in volumes:
        raise ValueError(f"interpolation must be 'bone' or 'full', got {interpolation!r}")
    volume = volumes[interpolation]

    start_interpolation = time.time()
    print(f"Starting {interpolation_method} interpolation for patient {pat_num}...")
    grid = tuple(np.arange(size) for size in volume.shape)
    interpolator = RegularGridInterpolator(grid, volume, method=interpolation_method, bounds_error=False, fill_value=None)
    end_interpolation = time.time()
    print(f"Interpolation time: {end_interpolation - start_interpolation} seconds")

    joblib.dump(interpolator, cache_path)
    return interpolator


def midline_optimized(base_path, csv_filepath, output_path, interpolation = 'bone', patient = None, image_range = 'full', optimization_method = 'Nelder-Mead', 
                    params_filename = 'parameters_optimized.npy', objective_function_filename = 'objective_function_values.npy', opt_image = 'bone', 
                    interpolation_method = 'linear', max_pat_num = 1, verify_optimum = False, rotation_angle = None,
                    metric = 'mse', multiresolution_opt = False, verification_list = None, optimization_methods_list = None, results_path_list = None):
    """
    Optimize the midplane parameters for the patients listed in a CSV file.

    For every patient the CT is preprocessed (rotation, slice cropping, body masking, bone
    extraction, dental-filling substitution). An interpolator over the bone-only or full
    image is loaded from / cached to ``<output_path_patient>/interpolator.joblib``. An initial
    plane is estimated with param_initialization_2d and refined with the selected optimizer.
    Per-patient results are cached as ``best_plane_params_patient_<n>_<method>.npy`` and
    ``objective_function_patient_<n>_<method>.npy`` and reused when they already exist.

    Parameters:
    base_path (str): The base directory path where patient data is stored.
    csv_filepath (str): The file path to the CSV file containing patient data.
    output_path (str): The directory path where output files will be saved.
    interpolation (str, optional): Image the objective is interpolated on: 'bone' (bone-only image) or 'full'. Default is 'bone'.
    patient (int, optional): Row index of a single patient to process. If None, patients are processed in CSV order. Default is None.
    image_range (str, optional): The range of the image to use ('full' or 'gtvp'). Default is 'full'.
    optimization_method (str, optional): Optimizer passed to optimize_plane_parameters /
        multi_resolution_optimization, 'exhaustive-search', or 'all' (runs Nelder-Mead and BFGS). Default is 'Nelder-Mead'.
    params_filename (str, optional): File name (in output_path) for the aggregated optimized parameters. Default is 'parameters_optimized.npy'.
    objective_function_filename (str, optional): File name (in output_path) for the aggregated objective values. Default is 'objective_function_values.npy'.
    opt_image (str, optional): Currently unused; kept for backward compatibility.
    interpolation_method (str, optional): Method of the RegularGridInterpolator and of the optimizers. Default is 'linear'.
    max_pat_num (int or None, optional): Stop after this many CSV rows (rows with missing data count too).
        None processes every row. Default is 1.
    verify_optimum (bool, optional): If True, sample the objective around the optimum (param_evaluation_3d)
        and compare it with the optimized parameters (obj_fun_heatmap). Default is False.
    rotation_angle (float, optional): If given, all volumes are rotated by this angle before processing.
    metric (str, optional): Objective metric passed to the optimizers. Default is 'mse'.
    multiresolution_opt (bool, optional): Use multi_resolution_optimization instead of optimize_plane_parameters. Default is False.
    verification_list, results_path_list: Currently unused; kept for backward compatibility.
    optimization_methods_list (list of str, optional): Methods run when optimization_method == 'all'
        (then always ['Nelder-Mead', 'BFGS']). Default is ['BFGS'].

    Returns:
    list: The best plane parameters for each processed patient (empty when `patient` is given or optimization_method == 'all').
    list: The objective function values for each processed patient (same conditions).
    """
    start_pipeline = time.time()

    if optimization_method == 'all':
        optimization_methods_list = ['Nelder-Mead', 'BFGS']
    elif optimization_methods_list is None:
        optimization_methods_list = ['BFGS']

    # Aggregated results; the *_NM / *_BFGS lists are only filled when optimization_method == 'all'.
    list_best_plane_params = []
    list_obj_fun = []
    list_best_plane_params_NM, list_obj_fun_NM = [], []
    list_best_plane_params_BFGS, list_obj_fun_BFGS = [], []

    df = pd.read_csv(csv_filepath)
    # A single requested patient is processed exactly once, also when its data is missing.
    rows = [df.iloc[patient]] if patient is not None else (row for _, row in df.iterrows())

    pat_num = 0
    for row in rows:
        # Checked before each patient so that skipped patients also count towards the limit.
        if patient is None and max_pat_num is not None and pat_num >= max_pat_num:
            break
        if patient is not None:
            pat_num = patient

        (image, gtvp, body, spinalcord, mandibula, structure_images, voxel_size, patient_folder_path,
         output_path_patient, patient_id, extension, pat_num) = process_patient_data(row, base_path, output_path, pat_num)
        if image is None or gtvp is None or body is None or spinalcord is None or mandibula is None:
            print(f"Patient number {pat_num} has missing data. Skipping...")
            pat_num += 1
            continue

        volumes = _midline_preprocess(image, body, gtvp, spinalcord, mandibula,
                                      image_range=image_range, rotation_angle=rotation_angle)
        image, image_plot = volumes['image'], volumes['image_plot']
        bone_ct, soft_tissue, body = volumes['bone_ct'], volumes['soft_tissue'], volumes['body']
        gtvp, spinalcord, mandibula = volumes['gtvp'], volumes['spinalcord'], volumes['mandibula']

        interpolator = _midline_load_or_build_interpolator(bone_ct, image, output_path_patient, pat_num,
                                                           interpolation, interpolation_method)

        # Initial plane estimate, used as the starting point of the optimizers below.
        plane = param_initialization_2d(bone_ct, soft_tissue, image_plot, pat_num, output_path_patient, interpolation = interpolation, obj_val_update=[], plot = True, 
                                        gradient=False, interpolation_method = 'cubic', interpolator = interpolator, multi_resolution_optimization=multiresolution_opt, 
                                        num_levels=4, verification_list=None, body=body)

        params_path = os.path.join(output_path_patient, f'best_plane_params_patient_{pat_num}_{optimization_method}.npy')
        obj_fun_path = os.path.join(output_path_patient, f"objective_function_patient_{pat_num}_{optimization_method}.npy")

        if os.path.exists(params_path):
            # Reuse the results of a previous run for this patient and method.
            best_plane_params = np.load(params_path)
            obj_fun = np.load(obj_fun_path)

        elif optimization_method == 'exhaustive-search':
            start_exhaustive_search = time.time()
            # TODO: adjust this call to the current param_evaluation_3d signature
            best_plane_params, obj_fun = param_evaluation_3d(bone_ct, soft_tissue, image_plot, pat_num, output_path_patient, interpolation=interpolation, 
                                obj_val_update=[], plot=True, 
                                gradient=False, interpolation_method='linear',
                                initial_params=plane, param_ranges=(3, 3, 10), num_steps=(100, 100, 100), opt_method=optimization_method, interpolator = interpolator)
            end_exhaustive_search = time.time()
            print(f"Exhaustive search plane parameters: {best_plane_params} with objective function value: {obj_fun}")
            print(f"Time taken for exhaustive search: {end_exhaustive_search - start_exhaustive_search:.2f} seconds")

        elif optimization_method == 'all':
            for method in optimization_methods_list:
                output_path_method = os.path.join(output_path_patient, method)
                os.makedirs(output_path_method, exist_ok=True)
                start_optimization = time.time()
                print(f"Optimizing with {method} method")
                best_plane_params, obj_fun = optimize_plane_parameters(plane, bone_ct, soft_tissue, image_plot, pat_num, method, interpolation = interpolation, 
                                                            output_path = output_path_method, interpolation_method = interpolation_method, metric=metric, interpolator = interpolator)
                end_optimization = time.time()
                print(f"Time taken for optimization using {method}: {end_optimization - start_optimization:.2f} seconds")
                print(f"Optimized plane parameters using {method}: {best_plane_params} with objective function value: {obj_fun}")
                np.save(os.path.join(output_path_method, f'best_plane_params_patient_{pat_num}_{method}.npy'), best_plane_params)
                np.save(os.path.join(output_path_method, f'objective_function_patient_{pat_num}_{method}.npy'), obj_fun)
                if method == 'Nelder-Mead':
                    list_best_plane_params_NM.append(best_plane_params)
                    list_obj_fun_NM.append(obj_fun)
                elif method == 'BFGS':
                    list_best_plane_params_BFGS.append(best_plane_params)
                    list_obj_fun_BFGS.append(obj_fun)

        elif multiresolution_opt:
            best_plane_params, obj_fun = multi_resolution_optimization(plane, bone_ct, soft_tissue, image_plot, pat_num, optimization_method, interpolation = interpolation, 
                                                    output_path = output_path_patient, interpolation_method = interpolation_method, metric=metric, num_levels=4, interpolator = interpolator)

        else:
            start_optimization = time.time()
            best_plane_params, obj_fun = optimize_plane_parameters(plane, bone_ct, soft_tissue, image, pat_num, optimization_method, interpolation = interpolation, 
                                                        output_path = output_path_patient, interpolation_method = interpolation_method, metric=metric, interpolator = interpolator)
            end_optimization = time.time()
            print(f"Time taken for optimization using {optimization_method}: {end_optimization - start_optimization:.2f} seconds")
            print(f"Optimized plane parameters using {optimization_method}: {best_plane_params} with objective function value: {obj_fun}")
            np.save(params_path, best_plane_params)
            np.save(obj_fun_path, obj_fun)

        if patient is None and optimization_method != 'all':
            list_best_plane_params.append(best_plane_params)
            list_obj_fun.append(obj_fun)

        if verify_optimum:
            verification_start = time.time()
            param_ranges = (4, 4, 10)
            num_steps = (100, 100, 100)
            sampling_results_csv_path = os.path.join(output_path_patient, f"mse_results_patient_{pat_num}.csv")
            if not os.path.exists(sampling_results_csv_path):
                # Sample the objective around the optimum; the samples are read back from the CSV below.
                sampling_start = time.time()
                param_evaluation_3d(bone_ct, soft_tissue, image_plot, pat_num, output_path_patient, interpolation=interpolation, obj_val_update=[], 
                            interpolation_method='linear', param_ranges=param_ranges, num_steps=num_steps, 
                            fixed_ref_params=best_plane_params, interpolator = interpolator)
                sampling_end = time.time()
                print(f"Time taken for sampling: {sampling_end - sampling_start:.2f} seconds")

            best_plane_params_verified, obj_fun_verified, difference_parameters = obj_fun_heatmap(sampling_results_csv_path, output_path_patient, 
                                                                                    best_plane_params, param_ranges, num_steps, method = 'BFGS', sampling_params = best_plane_params,
                                                                                    heatmap_range = None)
            verification_end = time.time()

            print(f"Verified best plane parameters: {best_plane_params_verified} with objective function value: {obj_fun_verified}")
            print(f"Time taken for verification: {verification_end - verification_start:.2f} seconds")
            print(f"Difference between optimized and sampled parameters:\n"
                f"Δθ = {difference_parameters[0]} °\n"
                f"Δφ = {difference_parameters[1]} °\n"
                f"ΔL = {difference_parameters[2]} [voxel units]")

        # In 'all' mode best_plane_params holds the result of the last method that was run.
        displayed_method = optimization_methods_list[-1] if optimization_method == 'all' else optimization_method
        display_scrollable_slices(image, body, gtvp, mandibula, spinalcord, [best_plane_params], [displayed_method])

        pat_num += 1

    # Aggregated results are saved for multi-patient runs, also when the CSV ends before max_pat_num.
    if patient is None:
        if optimization_method == 'all':
            np.save(os.path.join(output_path, 'best_plane_params_NM.npy'), list_best_plane_params_NM)
            np.save(os.path.join(output_path, 'best_plane_params_BFGS.npy'), list_best_plane_params_BFGS)
            np.save(os.path.join(output_path, 'objective_function_NM.npy'), list_obj_fun_NM)
            np.save(os.path.join(output_path, 'objective_function_BFGS.npy'), list_obj_fun_BFGS)
        else:
            np.save(os.path.join(output_path, params_filename), list_best_plane_params)
            np.save(os.path.join(output_path, objective_function_filename), list_obj_fun)
        end_pipeline = time.time()
        print(f"Time taken for {pat_num} patients: {end_pipeline - start_pipeline:.2f} seconds")

    return list_best_plane_params, list_obj_fun

## Visualize midplane of patients
Uses a scrollable widget to visualize the optimized plane(s) for a specific patient.

In [None]:
def _load_saved_plane_params(results_path, pat_num, methods):
    """
    Load the saved optimized plane parameters of one patient for several methods.

    Files are read from
    ``<results_path>/pat_<pat_num>/best_plane_params_patient_<pat_num>_<method>.npy``.
    Methods whose file does not exist are reported and skipped.

    Returns:
    tuple(list, list): the loaded parameter arrays and the matching method labels.
    """
    plane_params, labels = [], []
    for method in methods:
        params_path = os.path.join(results_path, f"pat_{pat_num}", f'best_plane_params_patient_{pat_num}_{method}.npy')
        if not os.path.exists(params_path):
            print(f"No saved plane parameters for method {method} of patient {pat_num} ({params_path}). Skipping method...")
            continue
        plane_params.append(np.load(params_path))
        labels.append(method)
    return plane_params, labels


def visualize_planes(base_path, csv_filepath, output_path, results_path_1, optimization_methods_list, patient = None, 
                     image_range = 'full', csv_sample_path = None, show_images = False):
    """
    Visualize saved midplanes on the CT and bone images of one or all patients.

    Parameters:
    base_path (str): The base directory path where patient data is stored.
    csv_filepath (str): The file path to the CSV file containing patient data.
    output_path (str): Output directory passed on to process_patient_data.
    results_path_1 (str): Directory holding per-patient results, read from
        ``<results_path_1>/pat_<n>/best_plane_params_patient_<n>_<method>.npy``.
    optimization_methods_list (list of str): Methods whose saved planes are loaded and shown;
        also used as the plane labels.
    patient (int, optional): Row index of a single patient to visualize. If None, all patients are visualized. Default is None.
    image_range (str, optional): 'full' (crop slices at the voxel-count valley) or 'gtvp'
        (GTVp slice range plus a 2-slice margin). Default is 'full'.
    csv_sample_path (str, optional): CSV with sampled 'theta', 'phi', 'L' and 'mse' columns; the
        plane with the smallest 'mse' is added to every display as 'Min-MSE sample'.
    show_images (bool, optional): If True, show the preprocessed image, bone CT and body mask of the
        first patient with complete data and stop without drawing planes. Default is False.

    Returns:
    list: The preprocessed images that were displayed.
    list: The plane parameters displayed for the last patient.
    """
    min_mse_plane_params = None
    if csv_sample_path is not None:
        samples = pd.read_csv(csv_sample_path)
        best_sample = samples.loc[samples['mse'].idxmin()]
        A, B, C = angles_to_vector(best_sample['theta'], best_sample['phi'], best_sample['L'])
        # (A, B, C) is the plane point closest to the origin and also its normal: A*x + B*y + C*z + D = 0
        D = -np.dot([A, B, C], [A, B, C])
        min_mse_plane_params = np.array((A, B, C, D))

    images = []
    list_best_plane_params = []
    pat_num = 0
    df = pd.read_csv(csv_filepath)
    # A single requested patient is processed exactly once.
    rows = [df.iloc[patient]] if patient is not None else (row for _, row in df.iterrows())

    for row in rows:
        if patient is not None:
            pat_num = patient

        (image, gtvp, body, spinalcord, mandibula, structure_images, voxel_size, patient_folder_path,
         output_path_patient, patient_id, extension, pat_num) = process_patient_data(row, base_path, output_path, pat_num)

        # Check before any processing: every step below needs all of these volumes.
        required = {'Image': image, 'Body': body, 'Spinal cord': spinalcord, 'Mandibula': mandibula, 'GTVP': gtvp}
        missing = [name for name, volume in required.items() if volume is None]
        if missing:
            for name in missing:
                print(f"{name} was not found for patient {pat_num}. Skipping...")
            pat_num += 1
            continue

        if image_range == 'full':
            voxel_counts = count_voxels_per_slice(image, plot = False)
            valley_index = gradient_descent_voxel_counts(voxel_counts, plot = False)
            image, body, spinalcord, mandibula, gtvp = select_slices(valley_index, image, body, spinal_cord = spinalcord, mandible = mandibula, gtvp = gtvp)

        image = image.astype(np.int16)  # Avoid overflow errors

        # Erode the body mask and blank everything outside it (gets rid of fiducial markers)
        body = binary_erosion(body, iterations=2).astype(np.uint8)
        image = np.where(body == 1, image, -1000)

        # Bone-only image: everything outside the bone HU window becomes 0
        bone_mask = mask_via_threshold(image, HU_range=(900, 2500)).astype(np.uint8)
        bone_ct = image * bone_mask

        # Dental fillings get a tooth-like HU value in the bone image
        dental_fillings_mask = mask_via_threshold(image, HU_range=(2500, 5000)).astype(np.uint16)
        dental_bone_ct = 2000 * dental_fillings_mask
        bone_ct = bone_ct + dental_bone_ct

        image_plot = np.copy(image)
        soft_tissue = mask_via_threshold(image, HU_range = (-2000, 900)).astype(np.uint8)

        if image_range == 'gtvp':
            start, end = get_nonzero_slice_range(gtvp)
            # Clamp the lower margin at 0: a negative start would wrap around to the last slices.
            z_range = slice(max(start - 2, 0), end + 2)
            body = body[:, :, z_range]
            image = image[:, :, z_range]
            image_plot = image_plot[:, :, z_range]
            bone_ct = bone_ct[:, :, z_range]
            soft_tissue = soft_tissue[:, :, z_range]
            gtvp = gtvp[:, :, z_range]
            mandibula = mandibula[:, :, z_range]
            spinalcord = spinalcord[:, :, z_range]

        if show_images:
            display_scrollable_image_with_values(image, title='Original Image')
            display_scrollable_image_with_values(bone_ct, title='Bone CT')
            display_scrollable_image_with_values(body, title='Body Mask')
            images.append(image)
            break

        list_best_plane_params, labels = _load_saved_plane_params(results_path_1, pat_num, optimization_methods_list)
        if min_mse_plane_params is not None:
            list_best_plane_params.append(min_mse_plane_params)
            labels.append('Min-MSE sample')

        images.append(image)
        if patient is not None:
            display_scrollable_slices(image, body, gtvp, mandibula, spinalcord, list_best_plane_params, labels)
            break

        display_scrollable_slices(image_plot, body, gtvp, mandibula, spinalcord, list_best_plane_params, labels)
        display_scrollable_slices(bone_ct, body, gtvp, mandibula, spinalcord, list_best_plane_params, labels)
        pat_num += 1

    return images, list_best_plane_params

def plot_ct_with_plane(ct_scan, plane_params):
    """
    Plot a 3D CT volume and a plane interactively using Plotly.

    The plane follows the convention used in this file: (A, B, C) is the vector from the
    origin to the closest point of the plane. The plane is therefore n . x = |(A, B, C)|
    with n = (A, B, C) / |(A, B, C)|. A fourth entry D, if given, is ignored and recomputed
    from (A, B, C).

    Parameters:
        ct_scan (numpy.ndarray): 3D array holding the CT volume (any size); its axes are
            plotted as X, Y and Z.
        plane_params (sequence): (A, B, C) or (A, B, C, D).

    Raises:
        ValueError: if (A, B, C) is the zero vector.
    """
    a, b, c = plane_params[:3]
    normal = np.array([a, b, c], dtype=float)
    length = np.linalg.norm(normal)
    if length == 0:
        raise ValueError("plane_params must have a non-zero normal vector (A, B, C).")
    normal = normal / length
    d = -length  # n . x + d = 0

    shape = ct_scan.shape
    nx, ny, nz = shape

    # Solve the plane equation for the coordinate with the largest normal component.
    # Solving for z only divides by C, which is ~0 for the (near-)sagittal midplanes.
    solve_axis = int(np.argmax(np.abs(normal)))
    free_axes = [axis for axis in range(3) if axis != solve_axis]
    u = np.linspace(0, shape[free_axes[0]] - 1, 100)
    v = np.linspace(0, shape[free_axes[1]] - 1, 100)
    U, V = np.meshgrid(u, v)
    W = (-normal[free_axes[0]] * U - normal[free_axes[1]] * V - d) / normal[solve_axis]
    # Keep the solved coordinate within the volume bounds
    W = np.clip(W, 0, shape[solve_axis] - 1)

    coords = [None, None, None]
    coords[free_axes[0]] = U
    coords[free_axes[1]] = V
    coords[solve_axis] = W
    X, Y, Z = coords

    # Isosurface from the median intensity upwards; coordinates match the C-order flatten
    volume = go.Isosurface(
        x=np.repeat(np.arange(nx), ny * nz),
        y=np.tile(np.repeat(np.arange(ny), nz), nx),
        z=np.tile(np.arange(nz), nx * ny),
        value=ct_scan.flatten(),
        isomin=np.percentile(ct_scan, 50),  # Mid-range value to enhance contrast
        isomax=ct_scan.max(),
        opacity=0.5,
        colorscale="gray",
        surface_count=5,
    )

    # Semi-transparent plane so the CT stays visible behind it
    plane = go.Surface(
        x=X, y=Y, z=Z,
        colorscale="Viridis",
        opacity=0.4
    )

    fig = go.Figure(data=[volume, plane])
    fig.update_layout(
        title="3D CT Scan with Overlayed Plane",
        scene=dict(
            xaxis_title="X-axis",
            yaxis_title="Y-axis",
            zaxis_title="Z-axis",
        ),
        margin=dict(l=0, r=0, b=0, t=40),
    )

    fig.show()


### Primary Gross Target Volume (GTV-p) characteristics

In [None]:
def gtvp_max_min_distance(indices, distances):
    """
    Find the GTVp voxels closest to and farthest from the plane on one side.

    Parameters:
    indices (sequence or numpy.ndarray): Voxel indices, aligned element-wise with `distances`.
    distances (sequence or numpy.ndarray): Signed voxel-to-plane distances; only their magnitude is used.

    Returns:
    tuple(list, list): ([index_of_min, index_of_max], [min_distance, max_distance]),
    or ([], []) if either input is empty.
    """
    # len() instead of truthiness so NumPy arrays are accepted (`not array` is ambiguous for >1 element)
    if len(indices) == 0 or len(distances) == 0:
        print("No gtvp extention on this side")
        return [], []

    abs_distances = np.abs(np.asarray(distances))
    i_min = int(np.argmin(abs_distances))
    i_max = int(np.argmax(abs_distances))

    return [indices[i_min], indices[i_max]], [abs_distances[i_min], abs_distances[i_max]]

def gtvp_volume(indices, voxel_size):
    """
    Compute the GTVp volume on one side of the plane.

    Parameters:
    indices (sequence or numpy.ndarray): One entry per GTVp voxel, e.g. a list of (x, y, z)
        tuples or an (N, 3) array.
    voxel_size (sequence of float): Voxel edge lengths; the voxel volume is their product.

    Returns:
    float or None: number of voxels times the voxel volume, or None if `indices` is empty.
    """
    # NOTE(review): len(indices) is the voxel count only when every entry is one voxel. A tuple
    # of per-axis coordinate arrays (as returned by np.nonzero) would give 3; confirm with callers.
    # len() instead of truthiness so NumPy arrays are accepted.
    if len(indices) == 0:
        print("No volume on this side")
        return None

    voxel_volume = np.prod(voxel_size)
    return voxel_volume * len(indices)

def assign_distances_and_indices(position, indices, distances, distances_real = None):
    """
    Split entries into ipsilateral and contralateral groups by the sign of their distance.

    Parameters:
    position (str): If 'positive', entries with distance >= 0 are ipsilateral.
        For any other value, entries with distance < 0 are ipsilateral.
    indices (iterable): Indices corresponding to the distances.
    distances (iterable of float): Signed distances used for the split.
    distances_real (iterable, optional): A second per-entry value that is split
        with the same assignment as `distances`. Defaults to None.

    Returns:
    tuple: (assigned_indices, assigned_distances) or, when distances_real is
        given, (assigned_indices, assigned_distances, assigned_distances_real).
        Each element is a dict with keys 'ipsilateral' and 'contralateral'
        holding lists in input order. Iteration stops at the shortest input.
    """
    from itertools import repeat

    has_real = distances_real is not None
    assigned_indices = {'ipsilateral': [], 'contralateral': []}
    assigned_distances = {'ipsilateral': [], 'contralateral': []}
    assigned_distances_real = {'ipsilateral': [], 'contralateral': []} if has_real else None

    # Without distances_real, pair every entry with a None placeholder; zipping
    # over None directly raised TypeError and made the default call unusable.
    real_values = distances_real if has_real else repeat(None)
    ipsilateral_is_positive = position == 'positive'

    for idx, dist, dist_real in zip(indices, distances, real_values):
        is_ipsilateral = dist >= 0 if ipsilateral_is_positive else dist < 0
        side = 'ipsilateral' if is_ipsilateral else 'contralateral'
        assigned_indices[side].append(idx)
        assigned_distances[side].append(dist)
        if has_real:
            assigned_distances_real[side].append(dist_real)

    if has_real:
        return assigned_indices, assigned_distances, assigned_distances_real
    return assigned_indices, assigned_distances

### Local playground

In [None]:
#%matplotlib widget
#%matplotlib inline

# Local playground: ad-hoc driver cell for a single midline_optimized run.
# All paths below are absolute, machine-specific locations; adjust them before
# running elsewhere. data_path and csv_path are also read by the
# visualize_planes cell further down.
data_path = r"/home/loriskeller/Documents/Master Project/Patient data/patient_data_complete/06_midline_extraction"
csv_path = r"/home/loriskeller/Documents/Master Project/Patient data/patient_data_complete/updated_struc_and_extension.csv"

# Stored BFGS plane parameters for patient 15 (head range vs. GTVp range).
# NOTE(review): not used by the active call below, which passes results_path_list=None.
result_path_list = [r"/home/loriskeller/Documents/Master Project/VS/Data_extract_and_midline/results_17_03_25/pat_15 head range/pat_15/best_plane_params_patient_15_BFGS.npy",
                    r"/home/loriskeller/Documents/Master Project/VS/Data_extract_and_midline/results_17_03_25/pat_15 GTVp range/pat_15/best_plane_params_patient_15_BFGS.npy"]

# NOTE(review): the active call below passes verification_list=None, so this list
# is currently unused. The commented lines are leftovers of an earlier, longer list.
verification_list = [(0,0)]
#     (60, 0)]
#     (-20, 60), (-20, -60),
#     (-40, 20), (-40, -20),
#     (0, 20), (0, -20),
#     (-40, 0),
#     (-10, 50), (-10, -50),
#     (30, 0)
# ]

# Earlier run configurations, kept for reference.
# midline_optimized(data_path, csv_path, r"/home/loriskeller/Documents/Master Project/VS/Data_extract_and_midline/results_27_01_25/exhaustive search/new", 
#                   interpolation = 'full', patient = None, image_range = 'gtvp', optimization_method = 'exhaustive-search', opt_image = 'bone', interpolation_method = 'linear',
#                   max_pat_num = 1, verify_optimum=True, rotation_angle=None, metric = 'mse')

# midline_optimized(data_path, csv_path, r"/home/loriskeller/Documents/Master Project/VS/Data_extract_and_midline/results_24_02_25/All", 
#                   interpolation = 'full', patient = 15, image_range = 'gtvp', optimization_method = 'Nelder-Mead', opt_image = 'bone', interpolation_method = 'cubic',
#                   max_pat_num = 16, verify_optimum=False, rotation_angle=None, results_sampling=False,
#                   results_opt = False, metric = 'mse', multiresolution_opt=False)

# Active run; the keyword arguments below define the configuration.
# NOTE(review): the output directory is named 'pat_15 GTVp range' while patient = 0
# is passed; confirm this combination is intended.
midline_optimized(data_path, csv_path, r"/home/loriskeller/Documents/Master Project/VS/Data_extract_and_midline/results_17_03_25/pat_15 GTVp range", 
                  interpolation = 'full', patient = 0, image_range = 'gtvp', optimization_method = 'BFGS', opt_image = 'bone', interpolation_method = 'cubic',
                  max_pat_num = None, verify_optimum=False, rotation_angle=None,
                  metric = 'mse', verification_list=None, optimization_methods_list = ['Plane A', 'Plane B'], results_path_list = None)

In [None]:
# Plot stored plane results for patient 15 with two labels
# ('Thresh. (700-2500)', 'Thresh. (900-2500)').
# Relies on data_path and csv_path defined in the playground cell above.
path = r"/home/loriskeller/Documents/Master Project/VS/Data_extract_and_midline/results_03_03_25/check patient 15/900 to 2500 with dental fillings"

# NOTE(review): the same directory is passed in two positional slots; confirm
# against visualize_planes' signature which roles these two arguments play.
visualize_planes(data_path, csv_path, path,
                    path,
                    ['Thresh. (700-2500)', 'Thresh. (900-2500)'], show_images = False,
                    patient = 15, image_range = 'gtvp', csv_sample_path=None)

In [None]:
# Inputs for the statistics() comparison kept (commented out) below.
# base_path holds the same directory string as data_path in the playground cell,
# and csv_path is rebound here to the same value it had there.
base_path = r"/home/loriskeller/Documents/Master Project/Patient data/patient_data_complete/06_midline_extraction"
csv_path = r"/home/loriskeller/Documents/Master Project/Patient data/patient_data_complete/updated_struc_and_extension.csv"

#parasm_list = np.load(r"/home/loriskeller/Documents/Master Project/VS/Data_extract_and_midline/results_01_02_25/NM_linear/parameters_optimized.npy")

# sampling_params = np.load(r"/home/loriskeller/Documents/Master Project/VS/Data_extract_and_midline/results_01_02_25/NM_linear/pat_15/best_plane_params_patient_15.npy")

# params_path_list = [r"/home/loriskeller/Documents/Master Project/VS/Data_extract_and_midline/results_01_02_25/GD_linear/pat_15/best_plane_params_patient_15.npy",
#                     r"/home/loriskeller/Documents/Master Project/VS/Data_extract_and_midline/results_01_02_25/GD_cubic/pat_15/best_plane_params_patient_15.npy"]
# methods_list = ['Gradient Descent linear', 'Gradient Descent cubic']

# statistics(csv_path, base_path, 15, params_path_list, r"/home/loriskeller/Documents/Master Project/VS/Data_extract_and_midline/results_01_02_25/NM_linear/pat_15/mse_results_patient_15.csv",
#             methods_list=methods_list, output_path=r"/home/loriskeller/Documents/Master Project/VS/Data_extract_and_midline/results_01_02_25/GD linear vs GD cubic",
#               param_ranges=(2,2,4), num_steps=(100,100,100), sampling_params=sampling_params)


# Gradient Computations for Objective Function

The objective function is given by:

$$
F = \sum_i \big(I(\vec{x}_i) - I(\vec{x}_i^m)\big)^2,
$$

where the mirrored point $\vec{x}_i^m$ is:

$$
\vec{x}_i^m = \vec{x}_i - 2 \cdot \left( \frac{A x_i + B y_i + C z_i}{A^2 + B^2 + C^2} - 1 \right) \cdot (A, B, C).
$$

## Gradient with Respect to $(A, B, C)$

The gradient of $F$ with respect to $p \in \{A, B, C\}$ is:

$$
\frac{\partial F}{\partial p} = 2 \sum_i \big(I(\vec{x}_i) - I(\vec{x}_i^m)\big) \cdot \left(-\nabla I(\vec{x}_i^m) \cdot \frac{\partial \vec{x}_i^m}{\partial p}\right).
$$

This follows from the chain rule: the derivative of the scalar field $I(\vec{x}_i^m)$ with respect to $p$ is

$$
\frac{\partial}{\partial p}(I(\vec{x}_i^m)) = \nabla I(\vec{x}_i^m) \cdot \frac{\partial \vec{x}_i^m}{\partial p},
$$

To evaluate $\partial \vec{x}_i^m / \partial p$, write the mirrored point $\vec{x}_i^m$ as

$$
\vec{x}_i^m = \vec{x}_i - 2 \cdot t \cdot (A, B, C),
$$

where:

$$
t = \frac{A x_i + B y_i + C z_i}{A^2 + B^2 + C^2} - 1.
$$

The derivative of $\vec{x}_i^m$ with respect to $p$ is:

$$
\frac{\partial \vec{x}_i^m}{\partial p} = -2 \cdot \frac{\partial t}{\partial p} \cdot (A, B, C) - 2 \cdot t \cdot \frac{\partial (A, B, C)}{\partial p}.
$$

The derivative of $t$ with respect to $p$ is:

$$
\frac{\partial t}{\partial p} = \frac{(x_i, y_i, z_i)_p \cdot (A^2 + B^2 + C^2) - (A x_i + B y_i + C z_i) \cdot 2p}{(A^2 + B^2 + C^2)^2}.
$$

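Substituting these derivatives gives the gradient below. The factor $-2$ in $\partial \vec{x}_i^m / \partial p$ combines with the leading factor $2$ and the minus sign in front of $\nabla I$ to give the factor $4$: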

$$
\frac{\partial F}{\partial p} = 4 \sum_i \big(I(\vec{x}_i) - I(\vec{x}_i^m)\big) \cdot \nabla I(\vec{x}_i^m) \cdot 
\left[
\frac{(x_i, y_i, z_i)_p \cdot (A^2 + B^2 + C^2) - (A x_i + B y_i + C z_i) \cdot 2p}{(A^2 + B^2 + C^2)^2} \cdot (A, B, C)
+ \left(\frac{A x_i + B y_i + C z_i}{A^2 + B^2 + C^2} - 1\right) \cdot \frac{\partial (A, B, C)}{\partial p}
\right].
$$

where 

$$
p = A: \Rightarrow (x_i, y_i, z_i)_A = x_i, \frac{\partial (A, B, C)}{\partial A} = (1, 0, 0),
$$

$$
p = B: \Rightarrow (x_i, y_i, z_i)_B = y_i, \frac{\partial (A, B, C)}{\partial B} = (0, 1, 0),
$$

$$
p = C: \Rightarrow (x_i, y_i, z_i)_C = z_i, \frac{\partial (A, B, C)}{\partial C} = (0, 0, 1).
$$


## Derivative of $F$ with Respect to $A$

For $p = A$, we substitute:

- $(x_i, y_i, z_i)_A = x_i$,
- $\frac{\partial (A, B, C)}{\partial A} = (1, 0, 0)$.

The derivative is:

$$
\frac{\partial F}{\partial A} = 4 \sum_i \big(I(\vec{x}_i) - I(\vec{x}_i^m)\big) \cdot \nabla I(\vec{x}_i^m) \cdot 
\left[
\frac{x_i \cdot (A^2 + B^2 + C^2) - (A x_i + B y_i + C z_i) \cdot 2A}{(A^2 + B^2 + C^2)^2} \cdot (A, B, C)
+ \left(\frac{A x_i + B y_i + C z_i}{A^2 + B^2 + C^2} - 1\right) \cdot (1, 0, 0)
\right].
$$

---

## Derivative of $F$ with Respect to $B$

For $p = B$, we substitute:

- $(x_i, y_i, z_i)_B = y_i$,
- $\frac{\partial (A, B, C)}{\partial B} = (0, 1, 0)$.

The derivative is:

$$
\frac{\partial F}{\partial B} = 4 \sum_i \big(I(\vec{x}_i) - I(\vec{x}_i^m)\big) \cdot \nabla I(\vec{x}_i^m) \cdot 
\left[
\frac{y_i \cdot (A^2 + B^2 + C^2) - (A x_i + B y_i + C z_i) \cdot 2B}{(A^2 + B^2 + C^2)^2} \cdot (A, B, C)
+ \left(\frac{A x_i + B y_i + C z_i}{A^2 + B^2 + C^2} - 1\right) \cdot (0, 1, 0)
\right].
$$

---

## Derivative of $F$ with Respect to $C$

For $p = C$, we substitute:

- $(x_i, y_i, z_i)_C = z_i$,
- $\frac{\partial (A, B, C)}{\partial C} = (0, 0, 1)$.

The derivative is:

$$
\frac{\partial F}{\partial C} = 4 \sum_i \big(I(\vec{x}_i) - I(\vec{x}_i^m)\big) \cdot \nabla I(\vec{x}_i^m) \cdot 
\left[
\frac{z_i \cdot (A^2 + B^2 + C^2) - (A x_i + B y_i + C z_i) \cdot 2C}{(A^2 + B^2 + C^2)^2} \cdot (A, B, C)
+ \left(\frac{A x_i + B y_i + C z_i}{A^2 + B^2 + C^2} - 1\right) \cdot (0, 0, 1)
\right].
$$


In [None]:
def gradient_A_B_C(param_vec, *args):
    """
    Compute the analytic gradient of the mirror-symmetry objective with respect to
    the plane parameters A, B and C.

    The plane is n . x = |n|^2 with n = (A, B, C). Every nonzero voxel x_i is
    reflected to x_i^m = x_i - 2 * t_i * n with t_i = n . x_i / |n|^2 - 1, and for
    p in {A, B, C}
        dF/dp = 4 * sum_i (I(x_i) - I(x_i^m)) * grad I(x_i^m) . (dt_i/dp * n + t_i * e_p),
        dt_i/dp = (x_i,p * |n|^2 - 2 * p * (n . x_i)) / |n|^4.

    Parameters
    ----------
    param_vec : array-like of 3 floats
        Plane parameters [A, B, C] in (x, y, z) order, where x is image axis 1,
        y is image axis 0 and z is image axis 2.
    *args : tuple
        Extra optimizer arguments. Only two entries are read:
        - args[3] : ndarray
            The 3D image.
        - args[7] : str
            The interpolation method passed to RegularGridInterpolator.
        NOTE(review): gradient_theta_phi_L reads the image from args[0]; confirm
        which position the optimizer's args tuple actually uses.

    Returns
    -------
    gradient : ndarray of shape (3,)
        [dF/dA, dF/dB, dF/dC].

    Notes
    -----
    Image gradients come from scipy.ndimage.sobel, a smoothed and unnormalized
    finite difference, so their magnitude is scaled relative to the exact
    derivative of I. Points mirrored outside the grid are extrapolated
    (fill_value=None).
    """
    start_gradient = time.time()

    # Extract parameters
    A, B, C = param_vec
    image = args[3]
    interpolation_method = args[7]

    grid = tuple(np.arange(size) for size in image.shape)

    def _interpolator(values):
        """Build an extrapolating RegularGridInterpolator on the image grid."""
        return RegularGridInterpolator(grid, values, method=interpolation_method,
                                       bounds_error=False, fill_value=None)

    interpolator_intensity = _interpolator(image)
    # Image derivatives along axes 0, 1, 2. output=float avoids integer
    # overflow/truncation in sobel when the image has an integer dtype.
    gradient_interpolators = [
        _interpolator(sobel(image, axis=axis, output=float)) for axis in range(3)
    ]

    # Nonzero voxels, with the first two columns swapped to (x, y, z) = (axis 1, axis 0, axis 2).
    mask_nonzero = np.argwhere(image != 0)
    x_i = mask_nonzero[:, [1, 0, 2]]
    I_x_i = image[mask_nonzero[:, 0], mask_nonzero[:, 1], mask_nonzero[:, 2]]

    normal = np.array([A, B, C], dtype=float)
    denominator = normal @ normal        # |n|^2
    projection = x_i @ normal            # n . x_i, shape (N,)
    t = projection / denominator - 1.0   # shape (N,)

    # Mirror every voxel through the plane, then go back to image-axis order for lookups.
    x_m = x_i - 2.0 * t[:, None] * normal
    x_m_image = x_m[:, [1, 0, 2]]
    I_x_m = interpolator_intensity(x_m_image)

    # The Sobel gradient is sampled in image-axis order; permute it to (x, y, z) so it
    # lives in the same frame as x_i and n before taking dot products.
    grad_image_axes = np.column_stack([interp(x_m_image) for interp in gradient_interpolators])
    grad_xyz = grad_image_axes[:, [1, 0, 2]]

    # (I(x_i) - I(x_i^m)) * grad I(x_i^m), shape (N, 3)
    weighted_grad = (I_x_i - I_x_m)[:, None] * grad_xyz

    gradient = np.empty(3)
    for k, p in enumerate(normal):
        dt_dp = (x_i[:, k] * denominator - 2.0 * p * projection) / denominator**2
        # dt/dp * n + t * e_p, shape (N, 3)
        bracket = dt_dp[:, None] * normal
        bracket[:, k] += t
        # Full contraction: dot product per voxel, then sum over voxels.
        gradient[k] = 4.0 * np.sum(weighted_grad * bracket)

    end_gradient = time.time()
    print(f"Time taken to compute gradient: {end_gradient - start_gradient:.2f} seconds")

    return gradient

## Gradient with Respect to $ \theta, \phi, L $

The parameters $ A, B, C $ are redefined as:

$$
A = L \cdot \cos(\phi) \cdot \cos(\theta),
$$

$$
B = L \cdot \cos(\phi) \cdot \sin(\theta),
$$

$$
C = L \cdot \sin(\phi).
$$

### Partial Derivative with Respect to $ \theta $:

The derivatives of $ A, B, C $ with respect to $ \theta $ are:

$$
\frac{\partial A}{\partial \theta} = -L \cdot \cos(\phi) \cdot \sin(\theta),
$$

$$
\frac{\partial B}{\partial \theta} = L \cdot \cos(\phi) \cdot \cos(\theta),
$$

$$
\frac{\partial C}{\partial \theta} = 0.
$$

The mirrored point $ \vec{x}_m $ depends on $ A, B, C $, so:

$$
\frac{\partial \vec{x}_m}{\partial \theta} = \frac{\partial \vec{x}_m}{\partial A} \cdot \frac{\partial A}{\partial \theta} + \frac{\partial \vec{x}_m}{\partial B} \cdot \frac{\partial B}{\partial \theta} + \frac{\partial \vec{x}_m}{\partial C} \cdot \frac{\partial C}{\partial \theta}.
$$

### Partial Derivative with Respect to $ \phi $:

The derivatives of $ A, B, C $ with respect to $ \phi $ are:

$$
\frac{\partial A}{\partial \phi} = -L \cdot \sin(\phi) \cdot \cos(\theta),
$$

$$
\frac{\partial B}{\partial \phi} = -L \cdot \sin(\phi) \cdot \sin(\theta),
$$

$$
\frac{\partial C}{\partial \phi} = L \cdot \cos(\phi).
$$

Thus:

$$
\frac{\partial \vec{x}_m}{\partial \phi} = \frac{\partial \vec{x}_m}{\partial A} \cdot \frac{\partial A}{\partial \phi} + \frac{\partial \vec{x}_m}{\partial B} \cdot \frac{\partial B}{\partial \phi} + \frac{\partial \vec{x}_m}{\partial C} \cdot \frac{\partial C}{\partial \phi}.
$$

### Partial Derivative with Respect to $ L $:

The derivatives of $ A, B, C $ with respect to $ L $ are:

$$
\frac{\partial A}{\partial L} = \cos(\phi) \cdot \cos(\theta),
$$

$$
\frac{\partial B}{\partial L} = \cos(\phi) \cdot \sin(\theta),
$$

$$
\frac{\partial C}{\partial L} = \sin(\phi).
$$

Thus:

$$
\frac{\partial \vec{x}_m}{\partial L} = \frac{\partial \vec{x}_m}{\partial A} \cdot \frac{\partial A}{\partial L} + \frac{\partial \vec{x}_m}{\partial B} \cdot \frac{\partial B}{\partial L} + \frac{\partial \vec{x}_m}{\partial C} \cdot \frac{\partial C}{\partial L}.
$$

### Gradient of $ F $:

The full gradient of $ F $ with respect to $ \theta, \phi, L $ is:

$$
\frac{\partial F}{\partial \theta} = 2 \sum_i \big(I(\vec{x}_i) - I(\vec{x}_i^m)\big) \cdot \left(-\nabla I(\vec{x}_i^m) \cdot \frac{\partial \vec{x}_i^m}{\partial \theta} \right),
$$

$$
\frac{\partial F}{\partial \phi} = 2 \sum_i \big(I(\vec{x}_i) - I(\vec{x}_i^m)\big) \cdot \left(-\nabla I(\vec{x}_i^m) \cdot \frac{\partial \vec{x}_i^m}{\partial \phi} \right),
$$

$$
\frac{\partial F}{\partial L} = 2 \sum_i \big(I(\vec{x}_i) - I(\vec{x}_i^m)\big) \cdot \left(-\nabla I(\vec{x}_i^m) \cdot \frac{\partial \vec{x}_i^m}{\partial L} \right).
$$



### Gradient of $ F $ with Respect to $ \theta, \phi, L $

Using the redefined parameters:

$$
A = L \cdot \cos(\phi) \cdot \cos(\theta),
$$

$$
B = L \cdot \cos(\phi) \cdot \sin(\theta),
$$

$$
C = L \cdot \sin(\phi),
$$

and substituting these partial derivatives into the expression for the gradient of $F$ gives:

#### Partial Derivative of $ F $ with Respect to $ \theta $:

Substitute:

$$
\frac{\partial A}{\partial \theta} = -L \cdot \cos(\phi) \cdot \sin(\theta),
$$

$$
\frac{\partial B}{\partial \theta} = L \cdot \cos(\phi) \cdot \cos(\theta),
$$

$$
\frac{\partial C}{\partial \theta} = 0,
$$

into:

$$
\frac{\partial \vec{x}_m}{\partial \theta} = \frac{\partial \vec{x}_m}{\partial A} \cdot \frac{\partial A}{\partial \theta} + \frac{\partial \vec{x}_m}{\partial B} \cdot \frac{\partial B}{\partial \theta} + \frac{\partial \vec{x}_m}{\partial C} \cdot \frac{\partial C}{\partial \theta}.
$$

The derivative of $ F $ is:

$$
\frac{\partial F}{\partial \theta} = 2 \sum_i \big(I(\vec{x}_i) - I(\vec{x}_i^m)\big) \cdot \left(-\nabla I(\vec{x}_i^m) \cdot \left[ \frac{\partial \vec{x}_m}{\partial A} \cdot (-L \cdot \cos(\phi) \cdot \sin(\theta)) + \frac{\partial \vec{x}_m}{\partial B} \cdot (L \cdot \cos(\phi) \cdot \cos(\theta)) \right] \right).
$$

#### Partial Derivative of $ F $ with Respect to $ \phi $:

Substitute:

$$
\frac{\partial A}{\partial \phi} = -L \cdot \sin(\phi) \cdot \cos(\theta),
$$

$$
\frac{\partial B}{\partial \phi} = -L \cdot \sin(\phi) \cdot \sin(\theta),
$$

$$
\frac{\partial C}{\partial \phi} = L \cdot \cos(\phi),
$$

into:

$$
\frac{\partial \vec{x}_m}{\partial \phi} = \frac{\partial \vec{x}_m}{\partial A} \cdot \frac{\partial A}{\partial \phi} + \frac{\partial \vec{x}_m}{\partial B} \cdot \frac{\partial B}{\partial \phi} + \frac{\partial \vec{x}_m}{\partial C} \cdot \frac{\partial C}{\partial \phi}.
$$

The derivative of $ F $ is:

$$
\frac{\partial F}{\partial \phi} = 2 \sum_i \big(I(\vec{x}_i) - I(\vec{x}_i^m)\big) \cdot \left(-\nabla I(\vec{x}_i^m) \cdot \left[ \frac{\partial \vec{x}_m}{\partial A} \cdot (-L \cdot \sin(\phi) \cdot \cos(\theta)) + \frac{\partial \vec{x}_m}{\partial B} \cdot (-L \cdot \sin(\phi) \cdot \sin(\theta)) + \frac{\partial \vec{x}_m}{\partial C} \cdot (L \cdot \cos(\phi)) \right] \right).
$$

#### Partial Derivative of $ F $ with Respect to $ L $:

Substitute:

$$
\frac{\partial A}{\partial L} = \cos(\phi) \cdot \cos(\theta),
$$

$$
\frac{\partial B}{\partial L} = \cos(\phi) \cdot \sin(\theta),
$$

$$
\frac{\partial C}{\partial L} = \sin(\phi),
$$

into:

$$
\frac{\partial \vec{x}_m}{\partial L} = \frac{\partial \vec{x}_m}{\partial A} \cdot \frac{\partial A}{\partial L} + \frac{\partial \vec{x}_m}{\partial B} \cdot \frac{\partial B}{\partial L} + \frac{\partial \vec{x}_m}{\partial C} \cdot \frac{\partial C}{\partial L}.
$$

The derivative of $ F $ is:

$$
\frac{\partial F}{\partial L} = 2 \sum_i \big(I(\vec{x}_i) - I(\vec{x}_i^m)\big) \cdot \left(-\nabla I(\vec{x}_i^m) \cdot \left[ \frac{\partial \vec{x}_m}{\partial A} \cdot (\cos(\phi) \cdot \cos(\theta)) + \frac{\partial \vec{x}_m}{\partial B} \cdot (\cos(\phi) \cdot \sin(\theta)) + \frac{\partial \vec{x}_m}{\partial C} \cdot (\sin(\phi)) \right] \right).
$$


In [None]:
def gradient_theta_phi_L(param_vec, *args):
    """
    Compute the gradient with respect to theta (alpha), phi (beta), and L for the given image.

    The plane normal n = (A, B, C) is obtained from angles_to_vector(alpha, beta, L).
    The gradient with respect to (A, B, C) is computed first,
        dF/dp = 4 * sum_i (I(x_i) - I(x_i^m)) * grad I(x_i^m) . (dt_i/dp * n + t_i * e_p),
    with t_i = n . x_i / |n|^2 - 1 and x_i^m = x_i - 2 * t_i * n. It is then mapped
    to (alpha, beta, L) with the chain rule.

    Parameters
    ----------
    param_vec : array-like of 3 floats
        [alpha, beta, L].
    *args : tuple
        Extra optimizer arguments. Only two entries are read:
        - args[0] : ndarray
            The 3D image.
        - args[7] : str
            The interpolation method passed to RegularGridInterpolator.

    Returns
    -------
    ndarray of shape (3,)
        [dF/dalpha, dF/dbeta, dF/dL].

    Notes
    -----
    The chain-rule Jacobian below assumes A = L cos(beta) cos(alpha),
    B = L cos(beta) sin(alpha), C = L sin(beta). NOTE(review): confirm this matches
    angles_to_vector. Voxel coordinates use (x, y, z) = (axis 1, axis 0, axis 2), and
    mirrored points outside the grid read as 0 (fill_value=0).
    """
    start_gradient = time.time()

    # Extract parameters
    alpha, beta, L = param_vec
    image = args[0]
    interpolation_method = args[7]

    A, B, C = angles_to_vector(alpha, beta, L)
    normal = np.array([A, B, C], dtype=float)

    grid = tuple(np.arange(size) for size in image.shape)

    def _interpolator(values):
        """Build a RegularGridInterpolator on the image grid that returns 0 outside it."""
        return RegularGridInterpolator(grid, values, method=interpolation_method,
                                       bounds_error=False, fill_value=0)

    interpolator_intensity = _interpolator(image)
    # Image derivatives along axes 0, 1, 2. output=float avoids integer
    # overflow/truncation in sobel when the image has an integer dtype.
    gradient_interpolators = [
        _interpolator(sobel(image, axis=axis, output=float)) for axis in range(3)
    ]

    # Nonzero voxels, converted to (x, y, z) = (axis 1, axis 0, axis 2).
    mask_nonzero = np.argwhere(image != 0)
    x_i = mask_nonzero[:, [1, 0, 2]]
    I_x_i = image[mask_nonzero[:, 0], mask_nonzero[:, 1], mask_nonzero[:, 2]]

    denominator = normal @ normal        # |n|^2
    projection = x_i @ normal            # n . x_i, shape (N,)
    t = projection / denominator - 1.0   # shape (N,)

    # Mirror every voxel through the plane, then go back to image-axis order for lookups.
    x_m = x_i - 2.0 * t[:, None] * normal
    x_m_image = x_m[:, [1, 0, 2]]
    I_x_m = interpolator_intensity(x_m_image)

    # Permute the image-axis-order Sobel gradient to (x, y, z) so it matches x_i and n.
    grad_image_axes = np.column_stack([interp(x_m_image) for interp in gradient_interpolators])
    grad_xyz = grad_image_axes[:, [1, 0, 2]]

    # (I(x_i) - I(x_i^m)) * grad I(x_i^m), shape (N, 3)
    weighted_grad = (I_x_i - I_x_m)[:, None] * grad_xyz

    # Gradient with respect to (A, B, C).
    grad_abc = np.empty(3)
    for k, p in enumerate(normal):
        dt_dp = (x_i[:, k] * denominator - 2.0 * p * projection) / denominator**2
        # dt/dp * n + t * e_p, shape (N, 3)
        bracket = dt_dp[:, None] * normal
        bracket[:, k] += t
        grad_abc[k] = 4.0 * np.sum(weighted_grad * bracket)

    # Rows: derivatives of (A, B, C) with respect to alpha, beta and L.
    jacobian = np.array([
        [-L * np.cos(beta) * np.sin(alpha), L * np.cos(alpha) * np.cos(beta), 0.0],
        [-L * np.sin(beta) * np.cos(alpha), -L * np.sin(alpha) * np.sin(beta), L * np.cos(beta)],
        [np.cos(alpha) * np.cos(beta), np.sin(alpha) * np.cos(beta), np.sin(beta)],
    ])
    gradient = jacobian @ grad_abc

    end_gradient = time.time()
    print(f"Time taken to compute gradient: {end_gradient - start_gradient:.2f} seconds")

    return gradient

## Direct gradient of the parametrized objective function

In [None]:
def compute_gradient(theta, phi, L, image, interpolator_intensity, interpolators_gradient):
    """
    Compute the gradient of the objective function
        f(theta, phi, L) = (1/N)*sum_{i}[ I(x_i) - I(x_m,i) ]^2,
    with respect to the parameters theta, phi, and L.

    The mirror voxel for each voxel x is defined as:
        x_m = x - 2 * alpha * n,
    where
        n = [cos(phi)*cos(theta), cos(phi)*sin(theta), sin(phi)]
        alpha = cos(phi)*cos(theta)*x + cos(phi)*sin(theta)*y + sin(phi)*z - L,
    i.e. alpha is the signed distance of x from the plane n . x = L.

    The voxels x are the indices of the nonzero elements of `image`, taken in
    image-axis order (x = axis 0, y = axis 1, z = axis 2). The interpolators
    must be defined on the same index grid.

    Parameters:
      theta : float
          Rotation angle around the z-axis.
      phi : float
          Rotation angle around the y-axis.
      L : float
          Offset along the normal.
      image : 3D numpy array
          The intensity image.
      interpolator_intensity : RegularGridInterpolator
          Interpolator to get the intensity I at any (x,y,z) location.
      interpolators_gradient : dict
          Dictionary with keys 'x', 'y', and 'z' containing RegularGridInterpolator
          objects for the gradient components of the image. For the result to be
          the gradient of f, they must hold the derivatives along axes 0, 1 and 2.

    Returns:
      grad : numpy array of shape (3,)
          The gradient [df/dtheta, df/dphi, df/dL]. A zero vector is returned
          when `image` has no nonzero voxels, because f is then an empty mean.
    """
    indices = np.argwhere(image)  # (N, 3), each row is [x, y, z]
    N = indices.shape[0]
    if N == 0:
        # Avoid ZeroDivisionError in the 1/N normalization below.
        return np.zeros(3)

    cos_phi, sin_phi = np.cos(phi), np.sin(phi)
    cos_theta, sin_theta = np.cos(theta), np.sin(theta)
    x, y, z = indices[:, 0], indices[:, 1], indices[:, 2]

    # Unit normal and signed distance of every voxel from the plane.
    n = np.array([cos_phi * cos_theta, cos_phi * sin_theta, sin_phi])
    alpha = cos_phi * cos_theta * x + cos_phi * sin_theta * y + sin_phi * z - L  # (N,)

    # Mirror voxels and intensity difference d = I(x) - I(x_m).
    x_m = indices - 2 * alpha[:, None] * n[None, :]
    d = image[x, y, z] - interpolator_intensity(x_m)  # (N,)

    # Intensity gradient at the mirror positions, shape (N, 3).
    grad_I = np.stack([interpolators_gradient[key](x_m) for key in ('x', 'y', 'z')], axis=1)

    # Derivatives of x_m = x - 2 * alpha * n with respect to each parameter.
    d_xm_dL = 2 * n  # (3,)
    term_theta = cos_phi * sin_theta * x - cos_phi * cos_theta * y
    d_xm_dtheta = 2 * (term_theta[:, None] * n[None, :] +
                       alpha[:, None] * np.array([cos_phi * sin_theta, -cos_phi * cos_theta, 0.0]))
    term_phi = sin_phi * cos_theta * x + sin_phi * sin_theta * y - cos_phi * z
    d_xm_dphi = 2 * (term_phi[:, None] * n[None, :] +
                     alpha[:, None] * np.array([sin_phi * cos_theta, sin_phi * sin_theta, -cos_phi]))

    # grad_I . dx_m/dp for every voxel.
    dot_theta = np.sum(grad_I * d_xm_dtheta, axis=1)
    dot_phi = np.sum(grad_I * d_xm_dphi, axis=1)
    dot_L = grad_I @ d_xm_dL

    # df/dp = -(2/N) * sum_i d_i * (grad_I . dx_m/dp)
    scale = -2.0 / N
    return np.array([scale * np.sum(d * dot_theta),
                     scale * np.sum(d * dot_phi),
                     scale * np.sum(d * dot_L)])
