# Mass Integrator: Visualization and Analysis Toolkit

**Notebook Status:** <font color='green'><b>Validated</b></font>

## Introduction and Goals

This notebook is the primary tool for visualizing, analyzing, and validating the data produced by the `mass_integrator` C code project. While the `mass_integrator` notebook is responsible for *generating* the simulation code, this notebook is responsible for *processing its output*.

The tools provided here serve three main purposes:

1.  **Validation (Debug Mode):** When the C code is run in "debug mode," it produces a detailed text file (`massive_particle_path.txt`) tracing the trajectory of a single particle. The functions in **Part 1** of this notebook read this file and generate plots to validate the numerical accuracy of the integrator against known physics, such as the conservation of orbital radius and the rates of relativistic precession.

2.  **Visualization (Production Mode):** When the C code is run in "production mode," it generates a series of binary snapshot files (`mass_blueprint_t_xxxx.bin`) representing the state of the entire accretion disk at different times. The functions in **Part 2** read these binary files to produce 3D visualizations of the disk and stitch them together into animations of the disk's evolution.

3.  **Pre-processing for the Photon Integrator:** The `photon_geodesic_integrator` requires the disk snapshot data to be in a highly optimized format for fast spatial queries. The functions in **Part 3** perform this crucial pre-processing step, converting the raw `.bin` snapshots into query-ready `.kdtree.bin` files.

### How to Use This Notebook

This notebook is intended to be used *after* you have successfully compiled and run the `mass_integrator` C code.

*   To use the **Part 1** validation tools, run the C code with the `run_in_debug_mode = True` parameter.
*   To use the **Part 2 & 3** visualization and pre-processing tools, run the C code with `run_in_debug_mode = False`.

# Table of Contents

This notebook provides tools for visualizing and validating the output of the `mass_integrator` C code.

*   [Imports and Setup](#imports)

**Part 1: Debug Mode Analysis (Single Particle Trajectory)**
*   [1.a: 3D Trajectory Plot](#part_1a)
*   [1.b: Trajectory Component Plots](#part_1b)
*   [1.c: Orbital Validation Plots](#part_1c)

**Part 2: Production Mode Analysis (Disk Snapshots)**
*   [2.a: Disk Snapshot Visualizer](#part_2a)
*   [2.b: Create Animation of Disk Evolution](#part_2b)

**Part 3: K-d Tree Pre-processor**
*   [3.a: Build K-d Trees from Snapshots](#part_3a)
*   [3.b: Inspect a K-d Tree File](#part_3b)

**Part 4: Verification Suite for Keplerian Orbits**
*   [4.a: Analytical Solution and Verification Driver](#part_4a)

<a id='imports'></a>
### **Imports and Setup**

This cell imports all Python libraries required for the notebook to function. By centralizing all imports here, we can easily see the dependencies of the project and avoid redundant `import` statements in later cells.

In [None]:
# General utilities
import os
import glob
import time
import shutil
import subprocess  # used by encode_video_from_frames to call ffmpeg
from collections import deque
from functools import partial
from typing import List, Tuple

# Core data science and plotting libraries
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from scipy.stats import linregress

# Performance optimization
import numba
from multiprocessing import Pool, cpu_count

# Progress bar
from tqdm import tqdm

# The nrpy.params module is needed to read the snapshot_every_t parameter
# when generating animation frames.
try:
    import nrpy.params as par
except ImportError:
    print("Warning: nrpy.params not found. Animation generation might use a default time step.")
    # Define a dummy class if nrpy is not available, so the notebook doesn't crash.
    class Par:
        def parval_from_str(self, param_name):
            if param_name == "snapshot_every_t":
                return 10.0 # A reasonable default
            raise ValueError(f"Parameter {param_name} not found.")
    par = Par()

<a id='part_1a'></a>
### 1.a: 3D Trajectory Plot

This section provides the function `plot_particle_trajectory`, which reads the `massive_particle_path.txt` file generated in debug mode. It creates a 3D visualization of the particle's orbit, showing its path relative to the black hole's event horizon. This is useful for getting an intuitive, qualitative sense of the trajectory.
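The horizon sphere in the plot uses the outer Kerr horizon radius r₊ = M(1 + √(1 − a*²)), where a* is the dimensionless spin (the role `a_spin` plays in these plotting functions). Below is a minimal standalone sketch of that formula. `kerr_horizon_radius` is a hypothetical helper name, and the range check is an addition the notebook's code does not have: for |a*| > 1, `np.sqrt` would return NaN with only a runtime warning.

```python
import math


def kerr_horizon_radius(M: float = 1.0, a_star: float = 0.0) -> float:
    """Outer event-horizon radius r_+ = M * (1 + sqrt(1 - a*^2)).

    Hypothetical helper; a_star is the dimensionless spin a/M.
    """
    if not -1.0 <= a_star <= 1.0:
        # |a*| > 1 has no horizon at all, so there is nothing to draw.
        raise ValueError(f"dimensionless spin must satisfy |a*| <= 1, got {a_star}")
    return M * (1.0 + math.sqrt(1.0 - a_star**2))


for spin in (0.0, 0.9, 1.0):
    print(f"a* = {spin}: r_+ = {kerr_horizon_radius(1.0, spin):.4f} M")
```

For a* = 0 this gives the Schwarzschild value 2M, and for an extremal hole (a* = 1) it drops to M.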

In [None]:
def plot_particle_trajectory(
    project_dir: str = "project/mass_integrator",
    input_filename: str = "massive_particle_path.txt",
    M_scale: float = 1.0,
    a_spin: float = 0.9
) -> None:
    """
    Reads the trajectory data from the C code's output file and generates
    a 3D plot of the particle's orbit around the black hole.

    Args:
        project_dir: The root directory of the C project where the output file is located.
        input_filename: The name of the trajectory data file.
        M_scale: The mass of the black hole, used to plot the event horizon.
        a_spin: The spin of the black hole, used to plot the event horizon.
    """
    print("--- Generating Particle Trajectory Plot ---")
    
    # --- 1. Construct the full path and load the data ---
    full_path = os.path.join(project_dir, input_filename)
    
    if not os.path.exists(full_path):
        print(f"ERROR: Trajectory file not found at '{full_path}'")
        print("Please ensure you have compiled and run the C code successfully.")
        return

    try:
        # Load the data, skipping the header row
        data = np.loadtxt(full_path, skiprows=1)
        # Columns: 0:τ, 1:t, 2:x, 3:y, 4:z, 5:u^t, 6:u^x, 7:u^y, 8:u^z
        x_coords = data[:, 2]
        y_coords = data[:, 3]
        z_coords = data[:, 4]
        print(f"Successfully loaded {len(x_coords)} data points from trajectory file.")
    except Exception as e:
        print(f"ERROR: Failed to load or parse the data file '{full_path}'.")
        print(f"Exception: {e}")
        return

    # --- 2. Set up the 3D plot ---
    plt.style.use('dark_background')
    fig = plt.figure(figsize=(12, 10))
    ax = fig.add_subplot(111, projection='3d')

    # --- 3. Plot the particle's trajectory ---
    ax.plot(x_coords, y_coords, z_coords, label='Particle Orbit', color='cyan', lw=2)
    
    # Mark the start and end points
    ax.scatter(x_coords[0], y_coords[0], z_coords[0], color='lime', s=100, label='Start', marker='o')
    ax.scatter(x_coords[-1], y_coords[-1], z_coords[-1], color='red', s=100, label='End', marker='X')

    # --- 4. Plot the black hole's event horizon ---
    # The radius of the event horizon for a Kerr black hole
    r_horizon = M_scale * (1 + np.sqrt(1 - a_spin**2))
    
    # Create a sphere for the event horizon
    u = np.linspace(0, 2 * np.pi, 100)
    v = np.linspace(0, np.pi, 100)
    x_bh = r_horizon * np.outer(np.cos(u), np.sin(v))
    y_bh = r_horizon * np.outer(np.sin(u), np.sin(v))
    z_bh = r_horizon * np.outer(np.ones(np.size(u)), np.cos(v))
    
    ax.plot_surface(x_bh, y_bh, z_bh, color='black', alpha=0.9, rstride=4, cstride=4)
    # Add a grey wireframe for better visibility
    ax.plot_wireframe(x_bh, y_bh, z_bh, color='dimgrey', alpha=0.2, rstride=10, cstride=10)

    # --- 5. Customize the plot ---
    ax.set_xlabel('X (M)', fontsize=12, labelpad=10)
    ax.set_ylabel('Y (M)', fontsize=12, labelpad=10)
    ax.set_zlabel('Z (M)', fontsize=12, labelpad=10)
    
    # Set equal aspect ratio
    max_range = np.array([x_coords.max()-x_coords.min(), y_coords.max()-y_coords.min(), z_coords.max()-z_coords.min()]).max() / 2.0
    mid_x = (x_coords.max()+x_coords.min()) * 0.5
    mid_y = (y_coords.max()+y_coords.min()) * 0.5
    mid_z = (z_coords.max()+z_coords.min()) * 0.5
    ax.set_xlim(mid_x - max_range, mid_x + max_range)
    ax.set_ylim(mid_y - max_range, mid_y + max_range)
    ax.set_zlim(mid_z - max_range, mid_z + max_range)

    ax.set_title(f"Massive Particle Trajectory (M={M_scale}, a={a_spin})", fontsize=16)
    ax.legend()
    ax.view_init(elev=30., azim=45) # Set a nice viewing angle
    
    plt.show()


<a id='part_1b'></a>
### 1.b: Trajectory Component Plots

The `plot_trajectory_components` function also reads the debug output file but generates a set of 2D plots. It shows the evolution of each Cartesian coordinate (`x`, `y`, `z`) as a function of coordinate time `t`, and also plots the relationship between the particle's proper time `τ` and coordinate time `t`. These plots are useful for analyzing oscillations and time dilation effects.
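For a circular orbit, the τ(t) panel should be a straight line with slope dτ/dt = 1/u^t. The sketch below computes that expected slope for a prograde, equatorial, circular Kerr orbit. It uses the standard Boyer-Lindquist result u^t = (r^{3/2} + a√M) / [r^{3/4} (r^{3/2} − 3M√r + 2a√M)^{1/2}], with the spin a in units of M (the same convention as the a√M terms in section 1.c). The helper name is hypothetical. Whether the C code's time coordinate is Boyer-Lindquist t is an assumption here; for any time coordinate that differs from it only by a function of r, the slope at fixed r is the same.

```python
import math


def kerr_circular_dtau_dt(r: float, M: float = 1.0, a: float = 0.0) -> float:
    """Expected slope dτ/dt = 1/u^t for a prograde, equatorial, circular Kerr orbit.

    Hypothetical helper; r and a are in units of M (Boyer-Lindquist radius).
    """
    sqrt_M = math.sqrt(M)
    denom_sq = r**1.5 - 3.0 * M * math.sqrt(r) + 2.0 * a * sqrt_M
    if denom_sq <= 0.0:
        # At or inside the photon orbit there is no timelike circular orbit.
        raise ValueError(f"no timelike circular orbit at r = {r}")
    u_t = (r**1.5 + a * sqrt_M) / (r**0.75 * math.sqrt(denom_sq))
    return 1.0 / u_t


print(kerr_circular_dtau_dt(10.0, a=0.9))
```

For a = 0 this reduces to √(1 − 3M/r), e.g. √0.7 ≈ 0.837 at r = 10M. You can read the measured slope off with `np.polyfit(coord_time, proper_time, 1)[0]` and compare it with this value.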

In [None]:
def plot_trajectory_components(
    project_dir: str = "project/mass_integrator",
    input_filename: str = "massive_particle_path.txt"
) -> None:
    """
    Reads trajectory data and creates four plots:
    x vs t, y vs t, z vs t, and proper time (τ) vs t.
    """
    print("--- Generating Trajectory Component Plots ---")
    
    # --- 1. Load the data ---
    full_path = os.path.join(project_dir, input_filename)
    
    if not os.path.exists(full_path):
        print(f"ERROR: Trajectory file not found at '{full_path}'")
        return

    try:
        data = np.loadtxt(full_path, skiprows=1)
        # Columns: 0:τ, 1:t, 2:x, 3:y, 4:z, ...
        proper_time = data[:, 0]
        coord_time = data[:, 1]
        x_coords = data[:, 2]
        y_coords = data[:, 3]
        z_coords = data[:, 4]
        print(f"Successfully loaded {len(coord_time)} data points.")
    except Exception as e:
        print(f"ERROR: Failed to load or parse data file '{full_path}'. Exception: {e}")
        return

    # --- 2. Create the plots ---
    plt.style.use('seaborn-v0_8-whitegrid')
    fig, axes = plt.subplots(2, 2, figsize=(16, 12))
    fig.suptitle('Particle Trajectory Components vs. Coordinate Time', fontsize=20)

    # Plot 1: x(t)
    axes[0, 0].plot(coord_time, x_coords, color='cyan')
    axes[0, 0].set_title('X Coordinate vs. Time', fontsize=14)
    axes[0, 0].set_xlabel('Coordinate Time (t) [M]', fontsize=12)
    axes[0, 0].set_ylabel('x [M]', fontsize=12)
    axes[0, 0].grid(True)

    # Plot 2: y(t)
    axes[0, 1].plot(coord_time, y_coords, color='magenta')
    axes[0, 1].set_title('Y Coordinate vs. Time', fontsize=14)
    axes[0, 1].set_xlabel('Coordinate Time (t) [M]', fontsize=12)
    axes[0, 1].set_ylabel('y [M]', fontsize=12)
    axes[0, 1].grid(True)

    # Plot 3: z(t)
    axes[1, 0].plot(coord_time, z_coords, color='lime')
    axes[1, 0].set_title('Z Coordinate vs. Time', fontsize=14)
    axes[1, 0].set_xlabel('Coordinate Time (t) [M]', fontsize=12)
    axes[1, 0].set_ylabel('z [M]', fontsize=12)
    axes[1, 0].grid(True)

    # Plot 4: τ(t)
    axes[1, 1].plot(coord_time, proper_time, color='gold')
    axes[1, 1].set_title('Proper Time vs. Coordinate Time', fontsize=14)
    axes[1, 1].set_xlabel('Coordinate Time (t) [M]', fontsize=12)
    axes[1, 1].set_ylabel('Proper Time (τ) [M]', fontsize=12)
    axes[1, 1].grid(True)

    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.show()



<a id='part_1c'></a>
### 1.c: Orbital Validation Plots

The following functions perform quantitative validation of the integrator's accuracy by comparing the numerical results against known theoretical predictions from General Relativity for stable orbits.
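These comparisons only make sense for orbits that are actually stable, i.e. outside the innermost stable circular orbit (ISCO). The sketch below computes the prograde ISCO radius from the Bardeen, Press and Teukolsky (1972) formula. It is useful for checking that `r_initial` (and, for production runs, `disk_r_min`) lies outside the ISCO. The function name is hypothetical, and the sketch is restricted to prograde orbits with 0 ≤ a* ≤ 1.

```python
import math


def _cbrt(x: float) -> float:
    """Real cube root of a non-negative number."""
    return x ** (1.0 / 3.0)


def kerr_isco_radius(M: float = 1.0, a_star: float = 0.0) -> float:
    """Prograde ISCO radius for dimensionless spin 0 <= a* <= 1 (hypothetical helper)."""
    if not 0.0 <= a_star <= 1.0:
        raise ValueError("this sketch covers prograde orbits with 0 <= a* <= 1 only")
    z1 = 1.0 + _cbrt(1.0 - a_star**2) * (_cbrt(1.0 + a_star) + _cbrt(1.0 - a_star))
    z2 = math.sqrt(3.0 * a_star**2 + z1**2)
    return M * (3.0 + z2 - math.sqrt((3.0 - z1) * (3.0 + z1 + 2.0 * z2)))


r_initial = 10.0  # default used by plot_precession_validation
for spin in (0.0, 0.9, 0.95):
    r_isco = kerr_isco_radius(1.0, spin)
    print(f"a* = {spin}: r_ISCO = {r_isco:.4f} M, r_initial outside ISCO: {r_initial > r_isco}")
```

The formula returns 6M for a* = 0 and M for a* = 1; for a* = 0.9 it gives about 2.32M.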

In [None]:
def plot_radius_vs_time(
    project_dir: str = "project/mass_integrator",
    input_filename: str = "massive_particle_path.txt"
) -> None:
    """
    Reads trajectory data and plots the particle's radial distance (r)
    as a function of coordinate time (t) to validate circularity.
    """
    print("--- Generating Radius vs. Time Validation Plot ---")
    
    full_path = os.path.join(project_dir, input_filename)
    if not os.path.exists(full_path):
        print(f"ERROR: Trajectory file not found at '{full_path}'")
        return

    try:
        data = np.loadtxt(full_path, skiprows=1)
        coord_time = data[:, 1]
        x_coords = data[:, 2]
        y_coords = data[:, 3]
        z_coords = data[:, 4]
    except Exception as e:
        print(f"ERROR: Failed to load data. Exception: {e}")
        return

    # Calculate the radius at each time step
    radius = np.sqrt(x_coords**2 + y_coords**2 + z_coords**2)
    
    # Calculate statistics on the radius
    mean_radius = np.mean(radius)
    min_radius = np.min(radius)
    max_radius = np.max(radius)
    percent_variation = 100 * (max_radius - min_radius) / mean_radius

    print("Radius Statistics:")
    print(f"  Mean Radius: {mean_radius:.6f} M")
    print(f"  Min Radius:  {min_radius:.6f} M")
    print(f"  Max Radius:  {max_radius:.6f} M")
    print(f"  Total Variation: {percent_variation:.4f}%")

    # Create the plot
    plt.style.use('seaborn-v0_8-whitegrid')
    plt.figure(figsize=(12, 6))
    
    plt.plot(coord_time, radius, label='Particle Radius r(t)', color='cyan')
    
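    # Add lines for the mean, min, and max radius to visualize the variation
    plt.axhline(mean_radius, color='lime', linestyle='--', label=f'Mean r = {mean_radius:.4f}')
    plt.axhline(min_radius, color='gray', linestyle=':', label=f'Min r = {min_radius:.4f}')
    plt.axhline(max_radius, color='gray', linestyle=':', label=f'Max r = {max_radius:.4f}')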
    
    plt.title('Validation: Particle Radius vs. Coordinate Time', fontsize=16)
    plt.xlabel('Coordinate Time (t) [M]', fontsize=12)
    plt.ylabel('Radius (r) [M]', fontsize=12)
    plt.legend()
    plt.grid(True)
    
    # Use a "tight" y-axis to emphasize any small variations
    plt.ylim(min_radius * 0.999, max_radius * 1.001)
    
    plt.show()



### Plotting Precession
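The theoretical value in the cell below is built on the coordinate angular velocity of a prograde, equatorial, circular Kerr orbit: Ω_φ = √M / (r^{3/2} + a√M), with a in units of M. For a = 0 this reduces to Kepler's √(M/r³), and a prograde spin lowers Ω_φ at fixed r. Here is a quick standalone evaluation at r = 10M (the helper name is hypothetical):

```python
import math


def kerr_circular_omega_phi(r: float, M: float = 1.0, a: float = 0.0) -> float:
    """dφ/dt of a prograde, equatorial, circular Kerr orbit (hypothetical helper)."""
    return math.sqrt(M) / (r**1.5 + a * math.sqrt(M))


for spin in (0.0, 0.5, 0.9):
    omega = kerr_circular_omega_phi(10.0, a=spin)
    period = 2.0 * math.pi / omega
    print(f"a = {spin:.1f} M: Omega_phi = {omega:.6f} rad/M, orbital period = {period:.2f} M")
```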

In [None]:
def plot_precession_validation(
    project_dir: str = "project/mass_integrator",
    input_filename: str = "massive_particle_path.txt",
    M_scale: float = 1.0,
    a_spin: float = 0.9,
    r_initial: float = 10.0
) -> None:
    """
    Reads trajectory data and validates the measured azimuthal frequency dφ/dt
    against the analytic Kerr value for a prograde, equatorial, circular orbit,
    which already includes the frame-dragging (spin) correction.
    """
    print("--- Generating Orbital Frequency (dφ/dt) Validation Plot ---")
    
    full_path = os.path.join(project_dir, input_filename)
    if not os.path.exists(full_path):
        print(f"ERROR: Trajectory file not found at '{full_path}'")
        return

    try:
        data = np.loadtxt(full_path, skiprows=1)
        coord_time = data[:, 1]
        x_coords = data[:, 2]
        y_coords = data[:, 3]
    except Exception as e:
        print(f"ERROR: Failed to load data. Exception: {e}")
        return

    # Calculate the azimuthal angle phi at each time step
    phi = np.arctan2(y_coords, x_coords)
    
    # arctan2 wraps from +pi to -pi; unwrap it into a continuous phi(t).
    phi_unwrapped = np.unwrap(phi)

    # --- Theoretical Calculation ---
    # Omega_phi = sqrt(M) / (r^{3/2} + a sqrt(M)) is the exact coordinate angular velocity of a
    # prograde circular equatorial Kerr orbit; frame dragging enters through the a*sqrt(M) term.
    # The weak-field Lense-Thirring rate 2*M*a/r^3 describes the nodal precession of tilted
    # orbits and must not be added on top of it.
    r = r_initial
    Omega_phi_theory = (M_scale**0.5) / (r**1.5 + a_spin * M_scale**0.5)
    
    # --- Measurement from Simulation Data ---
    # Perform a linear regression to find the slope of phi(t)
    regression = linregress(coord_time, phi_unwrapped)
    Omega_phi_measured = regression.slope
    
    percent_error = 100 * abs(Omega_phi_measured - Omega_phi_theory) / Omega_phi_theory

    print("Orbital Frequency (dφ/dt) Validation:")
    print(f"  Theoretical Ω_φ: {Omega_phi_theory:.6f} rad/M")
    print(f"  Measured Ω_φ (from data): {Omega_phi_measured:.6f} rad/M")
    print(f"  Relative Error: {percent_error:.4f}%")

    # --- Create the Plot ---
    plt.style.use('seaborn-v0_8-whitegrid')
    plt.figure(figsize=(12, 6))
    
    plt.plot(coord_time, phi_unwrapped, label='Measured φ(t) from Simulation', color='cyan', lw=2)
    # Anchor the theoretical line at the measured starting angle so only the slopes are compared.
    phi_theory = phi_unwrapped[0] + Omega_phi_theory * (coord_time - coord_time[0])
    plt.plot(coord_time, phi_theory, label=f'Theoretical φ(t) (slope={Omega_phi_theory:.4f})', color='lime', linestyle='--', lw=2)
    
    plt.title('Validation: Orbital Frequency dφ/dt (Kerr, incl. Frame Dragging)', fontsize=16)
    plt.xlabel('Coordinate Time (t) [M]', fontsize=12)
    plt.ylabel('Azimuthal Angle (φ) [radians]', fontsize=12)
    plt.legend()
    plt.grid(True)
    
    plt.show()



### Plot apsidal precession
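For a nearly circular orbit, the periapsis advances because the radial (epicyclic) frequency Ω_r is smaller than Ω_φ. The function below uses the same relation as the cell that follows, Ω_r² = Ω_φ² (1 − 6M/r + 8a√M/r^{3/2} − 3a²/r²), and returns the advance per radial period, 2π(Ω_φ/Ω_r − 1). It is a standalone sketch with a hypothetical name. In the weak-field limit the advance approaches the familiar 6πM/r.

```python
import math


def kerr_apsidal_precession_per_orbit(r: float, M: float = 1.0, a: float = 0.0) -> float:
    """Periapsis advance per radial period for a nearly circular prograde equatorial orbit.

    Hypothetical helper; r and a are in units of M.
    """
    omega_phi = math.sqrt(M) / (r**1.5 + a * math.sqrt(M))
    ratio_sq = 1.0 - 6.0 * M / r + 8.0 * a * math.sqrt(M) / r**1.5 - 3.0 * a**2 / r**2
    if ratio_sq <= 0.0:
        raise ValueError(f"no stable circular orbit at r = {r} (Omega_r^2 <= 0)")
    omega_r = omega_phi * math.sqrt(ratio_sq)
    # In one radial period T_r = 2*pi/Omega_r the particle sweeps Omega_phi*T_r radians;
    # whatever exceeds a full turn (2*pi) is the apsidal advance.
    return 2.0 * math.pi * (omega_phi / omega_r - 1.0)


for spin in (0.0, 0.9):
    print(f"r = 10 M, a = {spin}: {kerr_apsidal_precession_per_orbit(10.0, a=spin):.4f} rad/orbit")
```

At r = 10M with a = 0, the sketch gives about 3.65 rad per radial period, a large effect. A prograde spin of a = 0.9 reduces this to about 1.81 rad.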

In [None]:
def plot_apsidal_precession_validation(
    project_dir: str = "project/mass_integrator",
    input_filename: str = "massive_particle_path.txt",
    M_scale: float = 1.0,
    a_spin: float = 0.9
) -> None:
    """
    Reads trajectory data and validates the apsidal precession rate against
    the theoretical GR formula for nearly circular orbits.
    """
    print("--- Generating Apsidal Precession Validation Plot ---")
    
    full_path = os.path.join(project_dir, input_filename)
    if not os.path.exists(full_path):
        print(f"ERROR: Trajectory file not found at '{full_path}'")
        return

    try:
        data = np.loadtxt(full_path, skiprows=1)
        x_coords = data[:, 2]
        y_coords = data[:, 3]
    except Exception as e:
        print(f"ERROR: Failed to load data. Exception: {e}")
        return

    radius = np.sqrt(x_coords**2 + y_coords**2)
    phi = np.unwrap(np.arctan2(y_coords, x_coords))
    
    # Find the angles where the particle is at periapsis (minimum radius).
    # Only interior local minima count: the first and last samples are where the data
    # start and stop, not genuine turning points of the orbit.
    is_local_min = (radius[1:-1] < radius[:-2]) & (radius[1:-1] <= radius[2:])
    periapsis_indices = np.nonzero(is_local_min)[0] + 1
    
    if len(periapsis_indices) < 2:
        print("Could not find at least two periapsis points. Cannot calculate precession.")
        return
        
    # Calculate the measured precession angle per orbit
    delta_phi_measured = phi[periapsis_indices[1]] - phi[periapsis_indices[0]]
    precession_per_orbit_measured = delta_phi_measured - 2 * np.pi

    # --- Theoretical Calculation at the average radius of the orbit ---
    r_avg = np.mean(radius)
    M = M_scale
    a = a_spin
    
    Omega_phi_theory = (M**0.5) / (r_avg**1.5 + a * M**0.5)
    Omega_r_theory_sq = Omega_phi_theory**2 * (1 - (6*M)/r_avg + (8*a*M**0.5)/r_avg**1.5 - (3*a**2)/r_avg**2)
    
    if Omega_r_theory_sq < 0:
        print("Theoretical orbit is unstable (Ω_r^2 < 0). Cannot calculate precession.")
        return
        
    Omega_r_theory = np.sqrt(Omega_r_theory_sq)
    
    # Precession per unit time
    Omega_precession_theory = Omega_phi_theory - Omega_r_theory
    
    # Period of one radial oscillation
    T_r = 2 * np.pi / Omega_r_theory
    
    # Total precession angle over one radial period
    precession_per_orbit_theory = Omega_precession_theory * T_r
    
    percent_error = 100 * abs(precession_per_orbit_measured - precession_per_orbit_theory) / precession_per_orbit_theory

    print(f"Apsidal Precession Validation (at average radius r={r_avg:.3f} M):")
    print(f"  Measured precession per orbit:   {precession_per_orbit_measured:.6f} radians")
    print(f"  Theoretical precession per orbit: {precession_per_orbit_theory:.6f} radians")
    print(f"  Relative Error: {percent_error:.4f}%")

    # --- Create the Plot (Polar Plot) ---
    plt.style.use('dark_background')
    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(111, polar=True)
    
    ax.plot(phi, radius, color='cyan', label='Particle Orbit')
    ax.scatter(phi[periapsis_indices], radius[periapsis_indices], color='lime', s=100, label='Periapsis Points', zorder=5)
    
    ax.set_title('Validation: Apsidal Precession of a Nearly Circular Orbit', fontsize=16)
    ax.set_xlabel('Azimuthal Angle (φ)', fontsize=12)
    ax.set_ylabel('Radius (r) [M]', fontsize=12, labelpad=-50)
    ax.legend()
    
    plt.show()
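The cell above measures the advance from the first two periapses only. Averaging over every consecutive pair of interior minima gives a steadier estimate, and the spread between pairs is a rough indicator of noise. The sketch below implements this variant under a hypothetical function name. It is exercised on a synthetic orbit with Ω_φ = 1.1 and Ω_r = 1.0, where the exact advance is 2π(1.1 − 1) = 0.2π. For real, nearly circular trajectories, integrator noise can create spurious shallow minima, so some smoothing or a minimum-separation rule may still be needed.

```python
import numpy as np


def measured_apsidal_precession(phi, radius):
    """Mean and spread of the periapsis advance per radial period.

    phi must already be unwrapped. Periapses are interior local minima of radius;
    the first and last samples are never counted.
    """
    phi = np.asarray(phi, dtype=float)
    radius = np.asarray(radius, dtype=float)
    is_min = (radius[1:-1] < radius[:-2]) & (radius[1:-1] <= radius[2:])
    peri = np.nonzero(is_min)[0] + 1
    if len(peri) < 2:
        raise ValueError("need at least two periapsis passages")
    # Each consecutive pair of periapses spans one radial period.
    advances = np.diff(phi[peri]) - 2.0 * np.pi
    return advances.mean(), advances.std(), peri


# Synthetic check: Omega_phi = 1.1, Omega_r = 1.0, periapsis at t = 2*pi*k.
t = np.linspace(0.0, 50.0, 200001)
phi_demo = 1.1 * t
r_demo = 10.0 - 0.1 * np.cos(t)
mean_adv, std_adv, peri_idx = measured_apsidal_precession(phi_demo, r_demo)
print(f"{len(peri_idx)} periapses, advance = {mean_adv:.5f} +/- {std_adv:.1e} rad")
```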



<a id='part_2a'></a>
## Part 2: Production Mode Analysis (Disk Snapshots)

### 2.a: Disk Snapshot Visualizer

When the `mass_integrator` is run in production mode, it outputs binary `.bin` files containing the state of all particles at specific time intervals. The `visualize_disk_snapshot` function reads one of these files, performs several sanity checks on the data, and generates a 3D scatter plot of the particle positions. This allows for a quick visual inspection of the disk's structure at any given time.
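The snapshot is read with `np.fromfile` and a structured dtype. If the C records are larger than that dtype, no error is raised; every field is simply misread. A cheap guard is to compare the file size with `4 + N * itemsize`. The sketch below assumes the layout used in `visualize_disk_snapshot`: a 4-byte `int32` header followed by packed records. NumPy packs those fields into 72 bytes, whereas a C struct with the same members and natural alignment would typically occupy 80 bytes. We have not checked how the C code writes its records, and the size check is exactly what would reveal a mismatch. All helper names are hypothetical.

```python
import os
import tempfile

import numpy as np

# Assumed record layout, copied from visualize_disk_snapshot.
FIELDS = [('id', np.int32), ('pos', 'f8', (3,)), ('u', 'f8', (4,)),
          ('lambda_rest', 'f8'), ('j_intrinsic', 'f4')]
PACKED = np.dtype(FIELDS)               # no padding between fields
ALIGNED = np.dtype(FIELDS, align=True)  # C-style natural alignment, for comparison


def read_snapshot_checked(path, dtype=PACKED):
    """Read a snapshot and fail loudly if the file size disagrees with the layout."""
    with open(path, 'rb') as f:
        num_particles = int(np.fromfile(f, dtype=np.int32, count=1)[0])
        data = np.fromfile(f, dtype=dtype, count=num_particles)
    expected = 4 + num_particles * dtype.itemsize  # int32 header + records
    actual = os.path.getsize(path)
    if actual != expected:
        raise ValueError(f"{path}: {actual} bytes on disk, expected {expected} "
                         f"for {num_particles} records of {dtype.itemsize} bytes")
    return data


def write_snapshot(path, records):
    """Write header + records in the same layout (handy for tests)."""
    with open(path, 'wb') as f:
        np.array([len(records)], dtype=np.int32).tofile(f)
        records.tofile(f)


# Round trip with three fake particles.
tmpdir = tempfile.mkdtemp()
good = os.path.join(tmpdir, "mass_blueprint_t_0000.bin")
recs = np.zeros(3, dtype=PACKED)
recs['id'] = [0, 1, 2]
recs['pos'][:, 0] = [6.0, 7.0, 8.0]
write_snapshot(good, recs)
loaded = read_snapshot_checked(good)
print(PACKED.itemsize, ALIGNED.itemsize, loaded['pos'][:, 0])
```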

In [None]:
def visualize_disk_snapshot(
    project_dir: str = "project/mass_integrator",
    output_folder: str = "output",
    snapshot_index: int = -1, # -1 means the last available snapshot
    M_scale: float = 1.0,
    a_spin: float = 0.9
) -> None:
    """
    Reads a specific mass blueprint snapshot file, performs sanity checks,
    and generates a 3D plot of the particle disk.
    
    Expects the binary format that stores the full 4-velocity u^μ for each particle.
    """
    print("--- Visualizing Mass Blueprint Snapshot ---")
    
    # --- 1. Find and Load the Snapshot File ---
    snapshot_dir = os.path.join(project_dir, output_folder)
    if not os.path.isdir(snapshot_dir):
        print(f"ERROR: Snapshot directory not found at '{snapshot_dir}'")
        return

    snapshot_files = sorted(glob.glob(os.path.join(snapshot_dir, "mass_blueprint_t_*.bin")))
    
    if not snapshot_files:
        print(f"ERROR: No snapshot .bin files found in '{snapshot_dir}'")
        return

    if snapshot_index == -1:
        snapshot_to_load = snapshot_files[-1]
    elif snapshot_index < len(snapshot_files):
        snapshot_to_load = snapshot_files[snapshot_index]
    else:
        print(f"ERROR: Snapshot index {snapshot_index} is out of bounds. Only {len(snapshot_files)} snapshots exist.")
        return

    print(f"Loading snapshot file: '{snapshot_to_load}'")

    # Record layout: must match the particle struct written by the C code (with u[4]).
    snapshot_dtype = np.dtype([
        ('id', np.int32),
        ('pos', 'f8', (3,)),
        ('u', 'f8', (4,)),   # full 4-velocity (u^t, u^x, u^y, u^z)
        ('lambda_rest', 'f8'),
        ('j_intrinsic', 'f4')
    ])

    try:
        # The header is a 4-byte int, not part of the dtype
        with open(snapshot_to_load, 'rb') as f:
            num_particles = np.fromfile(f, dtype=np.int32, count=1)[0]
            data = np.fromfile(f, dtype=snapshot_dtype, count=num_particles)
        print(f"Successfully loaded data for {num_particles} particles.")
    except Exception as e:
        print(f"ERROR: Failed to load or parse the data file '{snapshot_to_load}'.")
        print(f"Exception: {e}")
        return

    # --- 2. Perform Sanity Checks ---
    positions = data['pos']
    # 'u' holds the full 4-velocity (u^t, u^x, u^y, u^z).
    velocities = data['u']

    # Terminated particles are stored with NaN positions. Exclude them from the statistics
    # and the plot; otherwise the NaNs propagate into the means and the axis limits.
    valid = ~np.isnan(positions).any(axis=1)
    num_valid = np.count_nonzero(valid)
    
    radii = np.sqrt(positions[:, 0]**2 + positions[:, 1]**2)
    # Euclidean norm of the SPATIAL components of the 4-velocity
    speeds = np.sqrt(velocities[:, 1]**2 + velocities[:, 2]**2 + velocities[:, 3]**2)

    print("\n--- Data Sanity Checks ---")
    print(f"  Particle count: {num_particles} ({num_valid} active)")
    if num_valid < num_particles:
        print(f"  WARNING: Found {num_particles - num_valid} terminated (NaN) particles in this snapshot.")
    if num_valid == 0:
        print("  ERROR: No active particles in this snapshot; nothing to plot.")
        return
    print(f"  Mean radius: {np.mean(radii[valid]):.3f} M (should be between disk_r_min and disk_r_max)")
    print(f"  Mean spatial 4-velocity magnitude: {np.nanmean(speeds[valid]):.3f}")
    print(f"  Max |z|: {np.max(np.abs(positions[valid, 2])):.2e} (should be close to zero)")
    
    # --- 3. Create the 3D Plot ---
    plt.style.use('dark_background')
    fig = plt.figure(figsize=(12, 10))
    ax = fig.add_subplot(111, projection='3d')

    # Plot a random subset of at most 2000 active particles to keep rendering fast.
    valid_indices = np.nonzero(valid)[0]
    num_to_plot = min(num_valid, 2000)
    plot_indices = np.random.choice(valid_indices, num_to_plot, replace=False)
    
    colors = radii[plot_indices]
    sc = ax.scatter(positions[plot_indices, 0], positions[plot_indices, 1], positions[plot_indices, 2], 
                    c=colors, cmap='plasma', s=5, label='Disk Particles')
    
    r_horizon = M_scale * (1 + np.sqrt(1 - a_spin**2))
    u = np.linspace(0, 2 * np.pi, 100)
    v = np.linspace(0, np.pi, 100)
    x_bh = r_horizon * np.outer(np.cos(u), np.sin(v))
    y_bh = r_horizon * np.outer(np.sin(u), np.sin(v))
    z_bh = r_horizon * np.outer(np.ones(np.size(u)), np.cos(v))
    ax.plot_surface(x_bh, y_bh, z_bh, color='black', alpha=0.9, rstride=4, cstride=4)
    ax.plot_wireframe(x_bh, y_bh, z_bh, color='dimgrey', alpha=0.2, rstride=10, cstride=10)

    ax.set_xlabel('X (M)', fontsize=12, labelpad=10)
    ax.set_ylabel('Y (M)', fontsize=12, labelpad=10)
    ax.set_zlabel('Z (M)', fontsize=12, labelpad=10)
    
    max_radius_plot = np.max(radii[valid]) * 1.1
    ax.set_xlim(-max_radius_plot, max_radius_plot)
    ax.set_ylim(-max_radius_plot, max_radius_plot)
    ax.set_zlim(-max_radius_plot/2, max_radius_plot/2)

    ax.set_title(f"Accretion Disk Snapshot from {os.path.basename(snapshot_to_load)}", fontsize=16)
    fig.colorbar(sc, ax=ax, shrink=0.6, aspect=10, label='Particle Radius (M)')
    ax.view_init(elev=45., azim=45)
    
    plt.show()

<a id='part_2b'></a>
### 2.b: Create Animation of Disk Evolution

This section provides the functions necessary to convert the entire sequence of snapshot files into a video. The process involves two steps:
1.  **`generate_animation_frames`**: This orchestrator function finds all snapshot files, then calls a helper (`_plot_single_frame`) to generate a `.png` image for each one.
2.  **`encode_video_from_frames`**: After all frames are generated, this function uses the external command-line tool `ffmpeg` to stitch the individual images into a single `.mp4` video file.
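Two things commonly break the second step. First, ffmpeg's image-sequence input (`frame_%04d.png`) expects the numbers to be sequential, so a frame that was never written ends the video early. Second, libx264 with `-pix_fmt yuv420p` requires an even frame width and height, because 4:2:0 chroma subsampling halves both dimensions. The sketch below is a small pre-flight check that uses only the standard library: it reads each PNG's width and height straight from its IHDR header. The function names are hypothetical, and the demo uses fake header-only files rather than real frames.

```python
import os
import re
import struct
import tempfile

PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"


def png_size(path):
    """Return (width, height) from a PNG's IHDR chunk without decoding the image."""
    with open(path, "rb") as f:
        header = f.read(24)  # 8-byte signature + chunk length + b"IHDR" + width + height
    if len(header) < 24 or header[:8] != PNG_SIGNATURE or header[12:16] != b"IHDR":
        raise ValueError(f"{path} is not a PNG file")
    return struct.unpack(">II", header[16:24])


def check_frames(folder, pattern=r"frame_(\d{4})\.png"):
    """Return (missing frame numbers, paths of frames with an odd width or height)."""
    frames = {}
    for name in os.listdir(folder):
        match = re.fullmatch(pattern, name)
        if match:
            frames[int(match.group(1))] = os.path.join(folder, name)
    missing = [i for i in range(max(frames, default=-1) + 1) if i not in frames]
    odd = [path for _, path in sorted(frames.items())
           if any(dim % 2 for dim in png_size(path))]
    return missing, odd


def _write_fake_png_header(path, width, height):
    """Write just enough of a PNG for png_size (demo only)."""
    with open(path, "wb") as f:
        f.write(PNG_SIGNATURE + struct.pack(">I", 13) + b"IHDR" + struct.pack(">II", width, height))


demo_dir = tempfile.mkdtemp()
for idx, (w, h) in {0: (1370, 1028), 1: (1371, 1028), 3: (1370, 1028)}.items():
    _write_fake_png_header(os.path.join(demo_dir, f"frame_{idx:04d}.png"), w, h)
missing, odd = check_frames(demo_dir)
print("missing:", missing, "odd-sized:", [os.path.basename(p) for p in odd])
```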

In [None]:
def _plot_single_frame(
    positions: np.ndarray,
    radii: np.ndarray,
    M_scale: float,
    a_spin: float,
    current_time: float,
    output_filename: str,
    fig_width_inches: float,
    fig_dpi: int
) -> None:
    """
    Plots a single frame of the disk animation and saves it to a file,
    guaranteeing the output image has even dimensions for video encoding.
    """
    # Calculate the desired pixel dimensions
    desired_width_px = fig_width_inches * fig_dpi
    
    # Get the aspect ratio from the data to calculate height
    x_range = np.max(positions[:, 0]) - np.min(positions[:, 0])
    y_range = np.max(positions[:, 1]) - np.min(positions[:, 1])
    aspect_ratio = y_range / x_range if x_range > 0 else 1.0
    desired_height_px = desired_width_px * aspect_ratio

    # Round down to the nearest even number
    final_width_px = int(desired_width_px // 2 * 2)
    final_height_px = int(desired_height_px // 2 * 2)
    
    # Recalculate figure size in inches to match the final pixel dimensions
    final_fig_width_inches = final_width_px / fig_dpi
    final_fig_height_inches = final_height_px / fig_dpi
    
    fig = plt.figure(figsize=(final_fig_width_inches, final_fig_height_inches))
    ax = fig.add_subplot(111, projection='3d')
    
    num_particles = len(positions)
    num_to_plot = min(num_particles, 5000)
    plot_indices = np.random.choice(num_particles, num_to_plot, replace=False)
    
    colors = radii[plot_indices]
    sc = ax.scatter(positions[plot_indices, 0], positions[plot_indices, 1], positions[plot_indices, 2], 
                    c=colors, cmap='plasma', s=5)
    
    r_horizon = M_scale * (1 + np.sqrt(1 - a_spin**2))
    u = np.linspace(0, 2 * np.pi, 100)
    v = np.linspace(0, np.pi, 100)
    x_bh = r_horizon * np.outer(np.cos(u), np.sin(v))
    y_bh = r_horizon * np.outer(np.sin(u), np.sin(v))
    z_bh = r_horizon * np.outer(np.ones(np.size(u)), np.cos(v))
    ax.plot_surface(x_bh, y_bh, z_bh, color='black', alpha=0.9, rstride=4, cstride=4)
    ax.plot_wireframe(x_bh, y_bh, z_bh, color='dimgrey', alpha=0.2, rstride=10, cstride=10)

    ax.set_xlabel('X (M)', fontsize=12, labelpad=10)
    ax.set_ylabel('Y (M)', fontsize=12, labelpad=10)
    ax.set_zlabel('Z (M)', fontsize=12, labelpad=10)
    
    max_radius_plot = np.max(radii) * 1.1
    ax.set_xlim(-max_radius_plot, max_radius_plot)
    ax.set_ylim(-max_radius_plot, max_radius_plot)
    ax.set_zlim(-max_radius_plot/2, max_radius_plot/2)

    ax.set_title(f"Accretion Disk Snapshot at t = {current_time:.2f} M", fontsize=16)
    fig.colorbar(sc, ax=ax, shrink=0.6, aspect=10, label='Particle Radius (M)')
    ax.view_init(elev=45., azim=45)
    
    # Use tight_layout() rather than savefig(bbox_inches='tight'): the latter crops the
    # canvas to an arbitrary size, which can produce odd pixel dimensions that libx264
    # (with yuv420p) refuses to encode.
    plt.tight_layout()
    plt.savefig(output_filename, dpi=fig_dpi)

    plt.close(fig)

def generate_animation_frames(
    project_dir: str = "project/mass_integrator",
    output_folder: str = "output_GrandDesign_ISCO",
    frames_output_dir: str = "animation_frames",
    M_scale: float = 1.0,
    a_spin: float = 0.95,
    fig_width_inches: float = 12.0,
    fig_dpi: int = 150,
    overwrite_existing_frames: bool = False
) -> None:
    """
    Finds all snapshot files in project_dir/output_folder, sorts them by snapshot
    number, and renders one PNG per snapshot (frame_0000.png, frame_0001.png, ...)
    into project_dir/frames_output_dir using _plot_single_frame.
    """
    print("--- Starting Animation Frame Generation ---")
    plt.style.use('dark_background')
    
    snapshot_dir = os.path.join(project_dir, output_folder)
    if not os.path.isdir(snapshot_dir):
        print(f"ERROR: Snapshot directory not found at '{snapshot_dir}'")
        return

    snapshot_files = glob.glob(os.path.join(snapshot_dir, "mass_blueprint_t_*.bin"))
    if not snapshot_files:
        print(f"ERROR: No snapshot .bin files found in '{snapshot_dir}'")
        return

    snapshot_files.sort(key=lambda f: int(os.path.basename(f).split('_t_')[1].split('.bin')[0]))
    print(f"Found {len(snapshot_files)} snapshots to process.")

    full_frames_dir = os.path.join(project_dir, frames_output_dir)
    os.makedirs(full_frames_dir, exist_ok=True)

    if not overwrite_existing_frames:
        print("Overwrite is OFF. Will skip any frames that already exist.")

    snapshot_dtype = np.dtype([
        ('id', np.int32), ('pos', 'f8', (3,)), ('u', 'f8', (4,)),
        ('lambda_rest', 'f8'), ('j_intrinsic', 'f4')
    ])
    
    try:
        snapshot_every_t = par.parval_from_str("snapshot_every_t")
    except ValueError:
        snapshot_every_t = 1.0

    # Frames are numbered by a separate counter that only advances when a frame is
    # assigned, so skipped snapshots never leave gaps in the sequence ffmpeg reads.
    frame_index = 0
    for i, snapshot_file in enumerate(snapshot_files):
        with open(snapshot_file, 'rb') as f:
            num_particles = np.fromfile(f, dtype=np.int32, count=1)[0]
            data = np.fromfile(f, dtype=snapshot_dtype, count=num_particles)

        # Drop terminated (NaN) particles so they cannot poison the plot limits.
        positions = data['pos']
        positions = positions[~np.isnan(positions).any(axis=1)]
        if len(positions) == 0:
            print(f"  Skipping {os.path.basename(snapshot_file)}: no active particles.")
            continue

        frame_filename = os.path.join(full_frames_dir, f"frame_{frame_index:04d}.png")
        frame_index += 1
        
        if not overwrite_existing_frames and os.path.exists(frame_filename):
            print(f"  Skipping frame {i+1}/{len(snapshot_files)}: {os.path.basename(frame_filename)} already exists.")
            continue

        print(f"  Processing frame {i+1}/{len(snapshot_files)}: {os.path.basename(snapshot_file)}")
        radii = np.sqrt(positions[:, 0]**2 + positions[:, 1]**2)
        
        snapshot_number = int(os.path.basename(snapshot_file).split('_t_')[1].split('.bin')[0])
        current_time = snapshot_number * snapshot_every_t
        
        _plot_single_frame(
            positions, radii, M_scale, a_spin, current_time, frame_filename,
            fig_width_inches=fig_width_inches, fig_dpi=fig_dpi
        )
    
    print("\n--- Frame generation complete. ---")

In [None]:
# --- Example: render all frames, then encode them into a video ---
# encode_video_from_frames is defined in a later cell of this section; run that cell
# before this one.

# 1. Render one PNG per snapshot. overwrite_existing_frames=True replaces any frames
#    left over from a previous run (e.g. frames with odd pixel dimensions).
print("STEP 1: Generating all frames...")
generate_animation_frames(
    project_dir="/home/daltonm/Documents/project/mass_integrator",
    output_folder="output_GrandDesign_ISCO",
    frames_output_dir="animation_frames_grand_design",
    M_scale=1.0,
    a_spin=0.95,
    overwrite_existing_frames=True
)
print("STEP 1 COMPLETE.\n")

# 2. Encode the video from the newly created, valid frames.
print("STEP 2: Encoding the new frames into a video...")
encode_video_from_frames(
    image_folder="/home/daltonm/Documents/project/mass_integrator/animation_frames_grand_design",
    output_video_path="/home/daltonm/Documents/project/mass_integrator/grand_design_disk_evolution.mp4",
    frame_rate=15
)
print("STEP 2 COMPLETE.")

# --- Production-mode visualization ---
# 1. Run your C code in PRODUCTION mode (set run_in_debug_mode = False in the .par file).
# 2. This writes the .bin snapshots to the output folder (here 'output_GrandDesign_ISCO').
# 3. visualize_disk_snapshot automatically finds and plots the LAST snapshot.
output_folder = "output_GrandDesign_ISCO"
visualize_disk_snapshot(output_folder=output_folder)

# --- Debug-mode validation (Part 1) ---
# The validation plots below read the debug trajectory file (massive_particle_path.txt),
# which is only written when the C code runs in DEBUG mode (run_in_debug_mode = True).
# Re-run the C code in debug mode before executing them.
#plot_particle_trajectory(
#    project_dir="project/mass_integrator",
#    input_filename="massive_particle_path.txt",
#    M_scale=1.0,
#    a_spin=0.9
#)
plot_trajectory_components()
plot_radius_vs_time()
plot_precession_validation()
plot_apsidal_precession_validation()



### Video of disk movement

In [None]:
import os
import subprocess

def encode_video_from_frames(
    image_folder: str,
    output_video_path: str,
    frame_rate: int = 30,
    crf: int = 18
) -> None:
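    """
    Encode a numbered PNG sequence (frame_0000.png, frame_0001.png, ...) into an
    H.264 MP4 with FFmpeg.

    Args:
        image_folder: Folder containing the frame_%04d.png files.
        output_video_path: Destination .mp4 path; its parent folder is created if needed.
        frame_rate: Frames per second of the output video.
        crf: libx264 constant rate factor (lower = higher quality, larger file).
    """
    print("--- Starting Video Encoding ---")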
    output_dir = os.path.dirname(output_video_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    command = [
        'ffmpeg', '-y', '-framerate', str(frame_rate),
        '-i', os.path.join(image_folder, 'frame_%04d.png'),
        '-c:v', 'libx264', '-pix_fmt', 'yuv420p',
        '-r', str(frame_rate), '-crf', str(crf),
        output_video_path
    ]
    print(f"Running FFmpeg command:\n{' '.join(command)}")
    try:
        subprocess.run(command, capture_output=True, text=True, check=True)
        print(f"\n[✓] Video encoding successful. File saved to '{output_video_path}'")
    except FileNotFoundError:
        print("\n[!] ERROR: FFmpeg not found. Install FFmpeg and make sure 'ffmpeg' is on your PATH.")
    except subprocess.CalledProcessError as e:
        print(f"\n[!] ERROR: FFmpeg failed with exit code {e.returncode}.")
        print("\n--- FFmpeg stderr ---")
        print(e.stderr)
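
The `libx264` encoder with `-pix_fmt yuv420p` needs an even frame width and height, because the chroma planes are subsampled by 2 in each direction. Frames with an odd pixel size therefore have to be fixed before encoding. The sketch below is a minimal pre-flight check. It assumes the frames are PNG files named `frame_*.png` and reads each width and height straight from the PNG IHDR header, so no imaging library is needed. Both function names are illustrative.

```python
import glob
import os
import struct

PNG_SIGNATURE = b'\x89PNG\r\n\x1a\n'

def png_dimensions(path):
    """Return (width, height) from a PNG's IHDR chunk (bytes 16-24 of the file)."""
    with open(path, 'rb') as f:
        header = f.read(24)
    if len(header) < 24 or header[:8] != PNG_SIGNATURE or header[12:16] != b'IHDR':
        raise ValueError(f"{path} does not look like a PNG file")
    width, height = struct.unpack('>II', header[16:24])
    return width, height

def find_odd_dimension_frames(image_folder, pattern='frame_*.png'):
    """List (filename, width, height) for frames that libx264 with yuv420p would reject."""
    bad = []
    for path in sorted(glob.glob(os.path.join(image_folder, pattern))):
        w, h = png_dimensions(path)
        if w % 2 or h % 2:
            bad.append((os.path.basename(path), w, h))
    return bad
```

If the list is not empty, either regenerate the frames or add FFmpeg's `-vf scale=trunc(iw/2)*2:trunc(ih/2)*2` filter to the command. That filter rounds each dimension down to the nearest even value.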



<a id='part_3a'></a>
## Part 3: K-d Tree Pre-processor

### 3.a: Build K-d Trees from Snapshots

The `photon_geodesic_integrator` needs to perform millions of fast spatial queries to check for intersections with the accretion disk. To enable this, each raw snapshot is pre-processed into an **implicit k-d tree**. The particles are reordered so that the node stored at index `i` has its children at indices `2i+1` and `2i+2`, and a parallel array records the split axis of every node. No explicit pointers are stored.

The process is handled by three functions:
*   **`build_implicit_kdtree_kernel`**: A `numba`-compiled kernel that takes the particle positions and partitions them breadth-first, using an array-based queue rather than recursion. Each node is split along the axis of largest spatial spread.
*   **`build_and_save_kdtree_snapshot_fast`**: Loads one raw snapshot, calls the kernel, reorders the full particle records, and writes the `.kdtree.bin` file.
*   **`run_kdtree_preprocessor_parallel`**: The main driver. It finds all raw `.bin` snapshots and uses a multiprocessing pool to process them in parallel.
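
In this implicit layout, all `N` slots are filled only if every subtree has the shape of a complete binary tree. The split point for a node holding `n` particles must therefore equal the size of the left subtree of a complete binary tree with `n` nodes. That size equals `n // 2` only for some `n`: for `n = 5` the left subtree holds 3 nodes, not 2. The sketch below computes the size in closed form and cross-checks it with a brute-force count over the heap indices. Both function names are illustrative and do not appear in the pipeline.

```python
def left_subtree_size(n):
    """Number of nodes in the root's left subtree of a complete binary tree with n nodes."""
    if n <= 1:
        return 0
    h = n.bit_length() - 1            # depth of the last level, floor(log2(n))
    half_last = 1 << (h - 1)          # last-level slots that belong to the left subtree
    last_level = n - ((1 << h) - 1)   # nodes actually present on the last level
    return (half_last - 1) + min(last_level, half_last)

def left_subtree_size_bruteforce(n):
    """Count heap indices < n reachable from slot 1 via the children 2i+1 and 2i+2."""
    count = 0
    stack = [1] if n > 1 else []
    while stack:
        i = stack.pop()
        count += 1
        stack.extend(c for c in (2 * i + 1, 2 * i + 2) if c < n)
    return count
```

For example, `[left_subtree_size(n) for n in range(1, 10)]` gives `[0, 1, 1, 2, 3, 3, 3, 4, 5]`. Compare this with `n // 2`, which gives `[0, 1, 1, 2, 2, 3, 3, 4, 4]`.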

In [None]:
import numba
import numpy as np
import os

# ==============================================================================
#  STEP 1: Create the JIT-compiled "kernel" for the heavy lifting.
# ==============================================================================
# Numba works best with simple data types, so we pass NumPy arrays.
# The @numba.jit decorator is the key. 'nopython=True' ensures it compiles to fast machine code.
@numba.jit(nopython=True)
def build_implicit_kdtree_kernel(particle_pos_array, num_particles):
    """
    Numba-compiled kernel to perform the heavy lifting of building the implicit k-d tree.
    This function contains only the numerical logic that Numba can optimize.
    
    Args:
        particle_pos_array (np.ndarray): A NumPy array of shape (N, 3) with just particle positions.
        num_particles (int): The number of particles.
        
    Returns:
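        tuple: (reordering_indices, node_metadata)
               - reordering_indices: For each tree slot i, the ORIGINAL index of the particle
                 stored there (new index -> old index), so particle_data[reordering_indices]
                 is the particle array in tree order.
               - node_metadata: The split axis (0=x, 1=y, 2=z) of each node; -1 marks a slot
                 that was never filled.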
    """
    # These will store the final results
    reordering_indices = np.zeros(num_particles, dtype=np.int32)
    node_metadata = np.full(num_particles, -1, dtype=np.int32)

    # Numba doesn't have deque, but we can simulate a queue with a NumPy array and pointers.
    # This is a common pattern for JIT-compiling queue-based algorithms.
    # We need to pre-allocate a large enough queue. num_particles is a safe upper bound.
    q_target_idx = np.zeros(num_particles, dtype=np.int32)
    # This queue will store slices (start, end) into a temporary index array
    q_slice_start = np.zeros(num_particles, dtype=np.int32)
    q_slice_end = np.zeros(num_particles, dtype=np.int32)
    
    q_head = 0
    q_tail = 0

    # The array of original indices that we will sort and partition
    indices_array = np.arange(num_particles, dtype=np.int32)
    
    if num_particles > 0:
        # Enqueue the first item: target index 0, and the slice representing all particles
        q_target_idx[q_tail] = 0
        q_slice_start[q_tail] = 0
        q_slice_end[q_tail] = num_particles
        q_tail += 1

    while q_head < q_tail:
        # Dequeue an item
        target_idx = q_target_idx[q_head]
        start = q_slice_start[q_head]
        end = q_slice_end[q_head]
        q_head += 1

        # --- Adaptive Splitting Logic ---
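        # indices_array[start:end] is a view (no copy). Indexing particle_pos_array with
        # that index array is fancy indexing and allocates a copy of this node's positions.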
        current_indices = indices_array[start:end]
        current_positions = particle_pos_array[current_indices]
        
        min_x, min_y, min_z = current_positions[0]
        max_x, max_y, max_z = current_positions[0]
        for i in range(1, len(current_positions)):
            pos = current_positions[i]
            min_x, max_x = min(min_x, pos[0]), max(max_x, pos[0])
            min_y, max_y = min(min_y, pos[1]), max(max_y, pos[1])
            min_z, max_z = min(min_z, pos[2]), max(max_z, pos[2])

        spread_x = max_x - min_x
        spread_y = max_y - min_y
        spread_z = max_z - min_z
        
        # Find split axis (0=x, 1=y, 2=z)
        split_axis = 0
        if spread_y > spread_x: split_axis = 1
        if spread_z > spread_y and spread_z > spread_x: split_axis = 2

        # --- Partitioning ---
        # Sort the current slice of the main index array based on the split axis
        # This is the most performance-critical part
        sorted_indices_on_axis = current_indices[np.argsort(particle_pos_array[current_indices, split_axis])]
        indices_array[start:end] = sorted_indices_on_axis
        
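        # Split point for a left-balanced implicit tree: the left subtree must hold as many
        # nodes as the left subtree of a complete binary tree with n_slice nodes. Using
        # n_slice // 2 instead sends some children to slots >= num_particles (e.g. with
        # 5 particles), silently dropping particles and leaving unfilled slots.
        n_slice = end - start
        median_offset = 0
        if n_slice > 1:
            h = 0
            while (1 << (h + 1)) <= n_slice:   # h = floor(log2(n_slice))
                h += 1
            half_last = 1 << (h - 1)          # last-level slots belonging to the left subtree
            last_level = n_slice - ((1 << h) - 1)
            median_offset = (half_last - 1) + min(last_level, half_last)
        median_original_index = sorted_indices_on_axis[median_offset]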

        # --- Place the Pivot Particle's Metadata ---
        reordering_indices[target_idx] = median_original_index
        node_metadata[target_idx] = split_axis
        
        # --- Prepare for Next Level (Enqueue children) ---
        left_child_idx = 2 * target_idx + 1
        right_child_idx = 2 * target_idx + 2
        
        # Left child points to the slice before the median
        if median_offset > 0:
            if left_child_idx < num_particles:
                q_target_idx[q_tail] = left_child_idx
                q_slice_start[q_tail] = start
                q_slice_end[q_tail] = start + median_offset
                q_tail += 1
        
        # Right child points to the slice after the median
        if median_offset + 1 < len(sorted_indices_on_axis):
            if right_child_idx < num_particles:
                q_target_idx[q_tail] = right_child_idx
                q_slice_start[q_tail] = start + median_offset + 1
                q_slice_end[q_tail] = end
                q_tail += 1
                
    return reordering_indices, node_metadata

# ==============================================================================
#  STEP 2: Create the main driver function that calls the kernel.
# ==============================================================================
def build_and_save_kdtree_snapshot_fast(raw_snapshot_file, output_dir):
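    """
    Loads one raw snapshot, builds its implicit k-d tree with the Numba kernel, and
    writes '<snapshot name>.kdtree.bin' to output_dir.

    Output layout: uint64 num_particles, uint64 dimensions (3), int32 node_metadata[N],
    followed by the N particle records reordered into tree order.
    """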
    raw_dtype = np.dtype([
        ('id', np.int32), ('pos', 'f8', (3,)), ('u', 'f8', (4,)),
        ('lambda_rest', 'f8'), ('j_intrinsic', 'f4')
    ])

    # --- Step 1: Load the Raw Data (Fast I/O) ---
    with open(raw_snapshot_file, 'rb') as f:
        num_particles = np.fromfile(f, dtype=np.int32, count=1)[0]
        if num_particles == 0:
            print("  Snapshot is empty, skipping.")
            return
        particle_data = np.fromfile(f, dtype=raw_dtype, count=num_particles)

    # --- Step 2: Call the Fast JIT-Compiled Kernel ---
    # We only pass the position data to the kernel for efficiency.
    particle_positions = np.ascontiguousarray(particle_data['pos'])
    reordering_map, node_metadata = build_implicit_kdtree_kernel(particle_positions, num_particles)
    
    # Use the returned map to reorder the full, original particle data array
    reordered_particles = particle_data[reordering_map]

    # --- Step 3: Save the Two Parallel Arrays (Fast I/O) ---
    output_filename = os.path.join(output_dir, os.path.basename(raw_snapshot_file).replace('.bin', '.kdtree.bin'))
    with open(output_filename, 'wb') as f:
        f.write(np.uint64(num_particles).tobytes())
        f.write(np.uint64(3).tobytes())  # dimensions
        f.write(node_metadata.tobytes())
        f.write(reordered_particles.tobytes())
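
For small inputs it helps to have a slow, pure-NumPy reference for the same implicit layout, plus a checker for the k-d tree invariant. For every node with split axis `k`, particles in its left subtree must not exceed the pivot along `k`, and particles in its right subtree must not fall below it. This is a sketch, not part of the original pipeline. The reference builder uses a left-balanced split: the left subtree gets as many nodes as the left subtree of a complete binary tree of the same size, so that all N heap slots are filled. That choice is made here for the sketch. As with the kernel's `reordering_indices[target_idx] = median_original_index`, `reordering[slot]` holds the original particle index.

```python
import numpy as np

def _left_size(n):
    """Left-subtree size of a complete binary tree with n nodes (illustrative helper)."""
    if n <= 1:
        return 0
    h = n.bit_length() - 1             # floor(log2(n))
    half_last = 1 << (h - 1)           # last-level slots that belong to the left subtree
    return (half_last - 1) + min(n - ((1 << h) - 1), half_last)

def build_reference_kdtree(pos):
    """Slow reference build; returns (reordering, split_axes) in heap order."""
    pos = np.asarray(pos, dtype=float)
    n = len(pos)
    reordering = np.full(n, -1, dtype=np.int64)
    axes = np.full(n, -1, dtype=np.int32)
    stack = [(0, np.arange(n))] if n > 0 else []   # (tree slot, original indices)
    while stack:
        slot, idx = stack.pop()
        spread = pos[idx].max(axis=0) - pos[idx].min(axis=0)
        axis = int(np.argmax(spread))              # widest axis; ties go to the lower axis
        idx = idx[np.argsort(pos[idx, axis], kind='stable')]
        k = _left_size(len(idx))
        reordering[slot] = idx[k]
        axes[slot] = axis
        if k > 0:
            stack.append((2 * slot + 1, idx[:k]))
        if k + 1 < len(idx):
            stack.append((2 * slot + 2, idx[k + 1:]))
    return reordering, axes

def check_kdtree(pos, reordering, axes):
    """True if reordering is a permutation and every node satisfies the split invariant."""
    pos = np.asarray(pos, dtype=float)
    n = len(pos)
    if not np.array_equal(np.sort(reordering), np.arange(n)):
        return False                               # a particle was dropped or duplicated
    p = pos[reordering]                            # positions in heap (tree) order
    for slot in range(n):
        axis = axes[slot]
        if axis not in (0, 1, 2):
            return False
        pivot = p[slot, axis]
        # sign = +1 for the left subtree (values must be <= pivot),
        # sign = -1 for the right subtree (values must be >= pivot).
        for child, sign in ((2 * slot + 1, 1.0), (2 * slot + 2, -1.0)):
            stack = [child] if child < n else []
            while stack:
                j = stack.pop()
                if sign * (p[j, axis] - pivot) > 0:
                    return False
                stack.extend(c for c in (2 * j + 1, 2 * j + 2) if c < n)
    return True
```

Before processing full snapshots, you can run a cheap regression test. Call `check_kdtree` on the `reordering_map` and `node_metadata` returned by `build_implicit_kdtree_kernel` for a few hundred random points. The permutation test alone shows whether any particle was dropped or duplicated.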

In [None]:
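import glob
import os
import time
from functools import partial
from multiprocessing import Pool, cpu_count

from tqdm.auto import tqdm

def run_kdtree_preprocessor_parallel(
    input_folder: str = os.path.join("project", "mass_integrator", "output"),
    processed_folder: str = os.path.join("project", "processed_snapshots")
) -> None:
    """
    Finds all raw snapshot .bin files in input_folder and converts them in parallel
    (one worker process per CPU core) into .kdtree.bin files in processed_folder.

    Pass input_folder explicitly if the C code wrote its snapshots to a different
    output folder (e.g. 'output_GrandDesign_ISCO').
    """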

    print(f"--- Starting K-d Tree Pre-processing (Parallel) ---")
    print(f"Input directory:  '{input_folder}'")
    print(f"Shared Output directory: '{processed_folder}'")

    os.makedirs(processed_folder, exist_ok=True)

    snapshot_files = sorted(glob.glob(os.path.join(input_folder, "mass_blueprint_t_*.bin")))
    
    if not snapshot_files:
        print("\nWARNING: No raw snapshot files found. Did you run the mass_integrator C code first?")
        return

    num_processes = cpu_count()
    print(f"\nFound {len(snapshot_files)} files. Processing in parallel using {num_processes} CPU cores...")

    task_function = partial(build_and_save_kdtree_snapshot_fast, output_dir=processed_folder)

    start_time = time.time()
    with Pool(processes=num_processes) as pool:
        list(tqdm(pool.imap_unordered(task_function, snapshot_files), total=len(snapshot_files)))
    
    end_time = time.time()
    
    print(f"\n--- All snapshots processed successfully in {end_time - start_time:.2f} seconds. ---")
    print(f"The query-ready .kdtree.bin files are now in the shared directory: '{processed_folder}'.")


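# --- How to Run ---
# In Jupyter, __name__ is '__main__', so the guard below simply runs the call; it matters
# when this code is run as a script. The pool relies on the 'fork' start method (historically
# the Linux default). Under 'spawn' or 'forkserver', worker processes cannot look up
# functions that were defined in a notebook.
if __name__ == '__main__':
    run_kdtree_preprocessor_parallel()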
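
Each snapshot is processed independently. A long pre-processing run can therefore be resumed by giving the pool only the snapshots that do not yet have a `.kdtree.bin` counterpart. This is a minimal sketch. It assumes the output naming used by `build_and_save_kdtree_snapshot_fast`, where the raw basename has `.bin` replaced by `.kdtree.bin`. `find_unprocessed_snapshots` is a hypothetical helper.

```python
import glob
import os

def find_unprocessed_snapshots(input_folder, processed_folder):
    """Return raw snapshot paths whose .kdtree.bin output is missing or empty (hypothetical helper)."""
    pending = []
    for raw in sorted(glob.glob(os.path.join(input_folder, "mass_blueprint_t_*.bin"))):
        if raw.endswith(".kdtree.bin"):
            continue  # skip processed files in case both folders are the same
        name = os.path.basename(raw).replace('.bin', '.kdtree.bin')
        out = os.path.join(processed_folder, name)
        if not os.path.exists(out) or os.path.getsize(out) == 0:
            pending.append(raw)
    return pending
```

The returned list can replace `snapshot_files` in the `pool.imap_unordered` call. Snapshots with zero particles never produce an output file, because the builder returns early, so they will always appear in the list.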

<a id='part_3b'></a>
### 3.b: Inspect a K-d Tree File

After building the k-d trees, it is useful to have a tool to verify that the output files are structured correctly. The `view_kdtree_snapshot` function reads a single `.kdtree.bin` file and prints a human-readable summary of its header and the first few particle records, confirming that the pre-processing was successful.
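
The header alone is enough to predict the file size before any records are parsed. The file holds 16 header bytes (two `uint64` values), 4 bytes of `int32` split-axis metadata per node, and one packed particle record per node. With the record layout used by the pre-processor, each record is 72 bytes (4 + 24 + 32 + 8 + 4). The sketch below implements that check. It assumes NumPy's default packed (unaligned) structured dtype, as in the cells of this notebook. `kdtree_file_is_consistent` is a hypothetical name.

```python
import os
import numpy as np

# ASSUMPTION: same record layout as raw_dtype / particle_dtype in this notebook (packed, 72 bytes).
PARTICLE_DTYPE = np.dtype([
    ('id', np.int32), ('pos', 'f8', (3,)), ('u', 'f8', (4,)),
    ('lambda_rest', 'f8'), ('j_intrinsic', 'f4')
])

HEADER_BYTES = 2 * np.dtype(np.uint64).itemsize   # num_particles, dimensions

def expected_kdtree_size(num_particles):
    """Bytes in a .kdtree.bin file that holds num_particles nodes."""
    per_node = np.dtype(np.int32).itemsize + PARTICLE_DTYPE.itemsize
    return HEADER_BYTES + num_particles * per_node

def kdtree_file_is_consistent(path):
    """Compare the on-disk file size with the size implied by the header."""
    with open(path, 'rb') as f:
        header = np.fromfile(f, dtype=np.uint64, count=2)
    if len(header) < 2:
        return False
    return os.path.getsize(path) == expected_kdtree_size(int(header[0]))
```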

In [None]:
def view_kdtree_snapshot(
    project_dir: str = "project",
    processed_folder: str = "processed_snapshots",
    snapshot_index: int = 0, # 0 for the first, -1 for the last
    max_nodes_to_print: int = 15
) -> None:
    """
    Reads a single, processed .kdtree.bin file and prints a detailed,
    human-readable summary (including the full 4-velocity u^mu) to verify
    its contents and structure.

    The defaults point at 'project/processed_snapshots', the shared folder
    written by run_kdtree_preprocessor_parallel.
    """
    print("--- K-d Tree Blueprint Inspector ---")
    
    # --- Step 1: Find and select the snapshot file ---
    snapshot_dir = os.path.join(project_dir, processed_folder)
    if not os.path.isdir(snapshot_dir):
        print(f"ERROR: Processed snapshot directory not found at '{snapshot_dir}'")
        return

    snapshot_files = sorted(glob.glob(os.path.join(snapshot_dir, "*.kdtree.bin")))
    if not snapshot_files:
        print(f"ERROR: No .kdtree.bin files found in '{snapshot_dir}'")
        print("Please run the k-d tree pre-processor cell first.")
        return

    try:
        file_to_load = snapshot_files[snapshot_index]
    except IndexError:
        print(f"ERROR: Snapshot index {snapshot_index} is out of bounds. Only {len(snapshot_files)} files exist.")
        return
        
    print(f"Loading and inspecting file: '{os.path.basename(file_to_load)}'")

    # --- Step 2: Define the dtypes to read the file ---
    # This dtype must exactly match raw_dtype in build_and_save_kdtree_snapshot_fast
    # (and therefore the particle struct written by the C code)
    particle_dtype = np.dtype([
        ('id', np.int32), 
        ('pos', 'f8', (3,)), 
        ('u', 'f8', (4,)), # full 4-velocity (u^t, u^x, u^y, u^z)
        ('lambda_rest', 'f8'),
        ('j_intrinsic', 'f4')
    ])
    metadata_dtype = np.int32

    # --- Step 3: Read the binary file according to the custom format ---
    try:
        with open(file_to_load, 'rb') as f:
            num_particles = np.fromfile(f, dtype=np.uint64, count=1)[0]
            dimensions = np.fromfile(f, dtype=np.uint64, count=1)[0]
            node_metadata = np.fromfile(f, dtype=metadata_dtype, count=num_particles)
            particle_data = np.fromfile(f, dtype=particle_dtype, count=num_particles)
    except Exception as e:
        print(f"\nERROR: Failed to read or parse the binary file. Exception: {e}")
        return

    # --- Step 4: Print a summary and verify the contents ---
    print("\n--- File Header ---")
    print(f"  Number of Particles: {num_particles}")
    print(f"  Dimensions:          {dimensions}")
    
    if len(node_metadata) != num_particles or len(particle_data) != num_particles:
        print("\nERROR: Mismatch between header particle count and data array lengths!")
        return

    print("\n--- Data Verification ---")
    
    axis_map = {0: 'X', 1: 'Y', 2: 'Z', -1: 'LEAF'}
    
    # Table header, including all four components of the 4-velocity
    header = (f"{'Index':<6} | {'Split Axis':<10} | {'Particle ID':<12} | "
              f"{'Position (x, y, z)':<25} | {'4-Velocity (ut, ux, uy, uz)':<40}")
    print(header)
    print("-" * len(header))

    for i in range(min(num_particles, max_nodes_to_print)):
        split_axis_val = node_metadata[i]
        particle = particle_data[i]
        
        split_axis_str = axis_map.get(split_axis_val, 'INVALID')
        pos_str = (f"({particle['pos'][0]:>6.2f}, {particle['pos'][1]:>6.2f}, "
                   f"{particle['pos'][2]:>6.2f})")
        
        # Format the full 4-velocity for display
        vel_str = (f"({particle['u'][0]:>6.2f}, {particle['u'][1]:>6.2f}, "
                   f"{particle['u'][2]:>6.2f}, {particle['u'][3]:>6.2f})")
        
        print(f"{i:<6} | {split_axis_str:<10} | {particle['id']:<12} | "
              f"{pos_str:<25} | {vel_str:<40}")

    if num_particles > max_nodes_to_print:
        print("...")
        print(f"(... and {num_particles - max_nodes_to_print} more nodes)")

    print("\n--- Verification Complete ---")
    print("Check that the 4-Velocity column contains four components and looks reasonable.")
view_kdtree_snapshot()
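
Printing the first few rows cannot show whether a particle was silently dropped or duplicated while the tree was built. A quick whole-file check is to confirm that the particle ids in the k-d tree file are still unique and that no split axis lies outside {0, 1, 2}. This assumes the ids in the raw snapshot are unique, which is an assumption about the C output. The sketch below works on the arrays that `view_kdtree_snapshot` reads in its Step 3. `summarize_kdtree_arrays` is a hypothetical helper.

```python
import numpy as np

def summarize_kdtree_arrays(node_metadata, particle_ids):
    """Count duplicated particle ids and split-axis values outside {0, 1, 2}."""
    ids = np.asarray(particle_ids)
    _, counts = np.unique(ids, return_counts=True)
    axes = np.asarray(node_metadata)
    return {
        'num_nodes': int(ids.size),
        'duplicated_ids': int(np.sum(counts > 1)),
        'invalid_axes': int(np.sum((axes < 0) | (axes > 2))),
    }
```

A call such as `summarize_kdtree_arrays(node_metadata, particle_data['id'])` should report zero for both counts. A slot left at `-1` by the kernel shows up as an invalid axis.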

<a id='part_4a'></a>
## Part 4: Verification Suite for Keplerian Orbits

### 4.a: Analytical Solution and Verification Driver

This section provides a comprehensive verification suite to test the numerical accuracy of the `mass_integrator` for stable, circular, equatorial (Keplerian) orbits.

*   **`get_analytical_solution`**: This function computes the exact trajectory of a massive particle on a prograde, circular, equatorial orbit in Kerr spacetime. It evaluates the 4-velocity component u^t directly from the Boyer-Lindquist metric components and the angular velocity Ω, which avoids catastrophic cancellation near the ISCO. It returns NaNs when no timelike circular orbit exists at the requested radius.
*   **`verify_mass_integrator`**: This is the main driver function. For each test case (black hole spin `a` and orbital radius `r`), it:
    1.  Writes the analytical initial conditions to `particle_debug_initial_conditions.txt` and writes a debug-mode parameter file, `mass_integrator.par`.
    2.  Runs the compiled C code in debug mode.
    3.  Reads the resulting numerical trajectory from `massive_particle_path.txt`.
    4.  Compares the numerical result against the analytical solution point-by-point to calculate the error in position, radius, and orbital phase, and reports PASS or FAIL against a tolerance of 1e-5.
    5.  Generates plots visualizing the trajectory and the error growth over time.
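
For reference, `get_analytical_solution` computes the standard prograde circular-orbit relations in Boyer-Lindquist coordinates on the equatorial plane. They are written here with the same symbols as the code:

```latex
\begin{aligned}
\Omega &= \frac{d\phi}{dt} = \frac{\sqrt{M}}{r^{3/2} + a\sqrt{M}}, \\
g_{tt} &= -\left(1 - \frac{2M}{r}\right), \qquad
g_{t\phi} = -\frac{2aM}{r}, \qquad
g_{\phi\phi} = r^2 + a^2 + \frac{2Ma^2}{r}, \\
u^t &= \left[-\left(g_{tt} + 2\,g_{t\phi}\,\Omega + g_{\phi\phi}\,\Omega^2\right)\right]^{-1/2}, \\
u^\phi &= \Omega\,u^t, \qquad \phi(\tau) = u^\phi\,\tau, \qquad t(\tau) = u^t\,\tau .
\end{aligned}
```

As a check, for a = 0 the bracket reduces to -(1 - 3M/r), so u^t = (1 - 3M/r)^{-1/2}. This diverges at the photon sphere r = 3M, which is exactly where the code's NaN branch begins.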

In [None]:
def get_analytical_solution(tau_values, r_initial, M_scale, a_spin):
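    """
    Calculates the analytical trajectory (t,x,y,z) for a prograde, circular,
    equatorial orbit in Kerr spacetime using a NUMERICALLY STABLE method.

    Steps 1-3 compute Omega = d(phi)/dt, the equatorial Boyer-Lindquist metric
    components, and u^t in a form that avoids catastrophic cancellation for
    orbits near the ISCO; Step 4 builds the trajectory.

    Args:
        tau_values (np.ndarray): Array of proper time values.
        r_initial (float): The constant radius of the orbit.
        M_scale (float): The mass of the black hole.
        a_spin (float): The spin parameter of the black hole.

    Returns:
        tuple: (U_t, Omega_tau, analytical_positions), where U_t = dt/dtau,
               Omega_tau = d(phi)/dtau, and analytical_positions is an (N, 4)
               array of (t, x, y, z). All entries are NaN if no timelike
               circular orbit exists at r_initial.
    """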
    # Use shorter variable names for clarity
    r = r_initial
    M = M_scale
    a = a_spin
    
    # --- Step 1: Calculate the stable angular velocity Omega = d(phi)/dt ---
    sqrt_M = np.sqrt(M)
    Omega = sqrt_M / (r**1.5 + a * sqrt_M)
    
    # --- Step 2: Calculate the required metric components in Boyer-Lindquist ---
    # These are specialized for the equatorial plane (theta=pi/2)
    g_tt = -(1 - 2*M/r)
    g_tphi = -2*a*M/r
    g_phiphi = r**2 + a**2 + (2*M*a**2)/r
    
    # --- Step 3: Solve for u^t and u^phi using the stable formulas ---
    ut_squared_inv_denom = g_tt + 2*g_tphi*Omega + g_phiphi*Omega**2
    
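    # If the bracket is non-negative, u^t is not real: no timelike circular orbit exists at
    # this radius (it lies at or inside the prograde circular photon orbit). Radii between the
    # photon orbit and the ISCO pass this check even though those orbits are unstable.
    if ut_squared_inv_denom >= 0:
        # Return NaNs to signal that no circular orbit exists at this radius.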
        nan_array = np.full((len(tau_values), 4), np.nan)
        return np.nan, np.nan, nan_array

    ut_squared = -1.0 / ut_squared_inv_denom
    U_t = np.sqrt(ut_squared)
    
    u_phi = Omega * U_t
    Omega_tau = u_phi # d(phi)/d(tau) is u^phi

    # --- Step 4: Calculate the trajectory over the given proper time values ---
    phi_of_tau = Omega_tau * tau_values
    t_of_tau = U_t * tau_values
    x_of_tau = r_initial * np.cos(phi_of_tau)
    y_of_tau = r_initial * np.sin(phi_of_tau)
    z_of_tau = np.zeros_like(tau_values)
    
    analytical_positions = np.vstack([t_of_tau, x_of_tau, y_of_tau, z_of_tau]).T
    
    return U_t, Omega_tau, analytical_positions
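
The NaN branch above fires only when no timelike circular orbit exists. Radii between the photon orbit and the ISCO still give finite values, even though those orbits are dynamically unstable and the integrator is expected to drift away from them. The sketch below classifies a radius before a test case is run. It uses the standard Bardeen-Press-Teukolsky expressions for prograde equatorial orbits. It is not part of the original suite and assumes 0 <= a <= M.

```python
import numpy as np

def prograde_photon_orbit_radius(M, a):
    """Prograde circular photon orbit: r_ph = 2M[1 + cos(2/3 * arccos(-a/M))]."""
    chi = a / M
    return 2.0 * M * (1.0 + np.cos(2.0 / 3.0 * np.arccos(-chi)))

def prograde_isco_radius(M, a):
    """Prograde ISCO radius from the Bardeen-Press-Teukolsky formula."""
    chi = a / M
    z1 = 1.0 + np.cbrt(1.0 - chi**2) * (np.cbrt(1.0 + chi) + np.cbrt(1.0 - chi))
    z2 = np.sqrt(3.0 * chi**2 + z1**2)
    return M * (3.0 + z2 - np.sqrt((3.0 - z1) * (3.0 + z1 + 2.0 * z2)))

def classify_circular_orbit(r, M, a):
    """'none' (no timelike circular orbit), 'unstable', or 'stable' (ISCO counts as stable)."""
    if r <= prograde_photon_orbit_radius(M, a):
        return 'none'
    return 'unstable' if r < prograde_isco_radius(M, a) else 'stable'
```

For a = 0 this gives r_ph = 3M and r_ISCO = 6M. A test case at r = 4M therefore passes the NaN check in `get_analytical_solution` but is flagged here as unstable.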

In [None]:
import numpy as np
import subprocess
import os
import matplotlib.pyplot as plt
from typing import List, Tuple

def verify_mass_integrator(
    test_cases: List[Tuple[float, float]],
    M_scale: float = 1.0,
    tau_max: float = 2000.0,
    project_dir: str = "project/mass_integrator",
    executable_name: str = "mass_integrator",
    output_folder: str = "verification_plots"
) -> None:
    """
    Runs a verification suite for the massive particle integrator for specific test cases.

    Each test case is an (a_spin, r_initial) pair. Cases for which no circular orbit
    exists, or for which the integrator returns NaNs, are reported and skipped.
    """
    print("--- Starting Mass Integrator Verification Suite ---")
    
    full_output_folder = os.path.join(project_dir, output_folder)
    os.makedirs(full_output_folder, exist_ok=True)
    print(f"Plots will be saved to: {full_output_folder}")
    
    for a_spin, r_initial in test_cases:
        print(f"\n--- Verifying r_initial={r_initial:.2f}, a_spin={a_spin:.2f} ---")

        # --- 1. Generate Initial Conditions ---
        U_t_initial, Omega_tau_initial, _ = get_analytical_solution(np.array([0.0]), r_initial, M_scale, a_spin)
        
        # *** ROBUSTNESS CHECK 1: Check that a circular orbit exists at this radius ***
        if np.isnan(U_t_initial):
            print("  [✗] ANALYTICAL FAILURE: No timelike circular orbit exists at this r_initial")
            print("      (it lies at or inside the prograde circular photon orbit).")
            print("      Skipping this test case.")
            continue
        # *** END OF CHECK ***

        u_x_initial = 0.0
        u_y_initial = r_initial * Omega_tau_initial
        u_z_initial = 0.0
        
        ic_filename = os.path.join(project_dir, "particle_debug_initial_conditions.txt")
        with open(ic_filename, "w") as f:
            f.write("# Format: t_initial pos_x pos_y pos_z u_x u_y u_z\n")
            f.write(f"0.0 {r_initial:.10f} 0.0 0.0   {u_x_initial:.10f} {u_y_initial:.10f} {u_z_initial:.10f}\n")
        
        # --- 2. Run the C Integrator ---
        par_filename = os.path.join(project_dir, "mass_integrator.par")
        with open(par_filename, "w") as f:
            f.write(f"run_in_debug_mode = True\n")
            f.write(f"a_spin = {a_spin:.10f}\n")
            f.write(f"M_scale = {M_scale:.10f}\n")
            f.write(f"t_max_integration = {tau_max * 2 * U_t_initial}\n")
            f.write(f"metric_choice = 0\n")

        output_path = os.path.join(project_dir, "massive_particle_path.txt")
        if os.path.exists(output_path):
            os.remove(output_path)

        try:
            subprocess.run(
                [f"./{executable_name}"], capture_output=True, text=True, check=True, cwd=project_dir
            )
            print("  [✓] C Integrator ran successfully.")
        except OSError as e:
            print(f"  [✗] ERROR: Could not execute '{executable_name}' in '{project_dir}': {e}")
            continue
        except subprocess.CalledProcessError as e:
            print(f"  [✗] ERROR: C Integrator failed for r={r_initial}, a={a_spin}.")
            print(e.stderr)
            continue

        # --- 3. Load Numerical and Generate Analytical Results ---
        try:
            # ndmin=2 keeps a single-row output file two-dimensional
            numerical_data = np.loadtxt(output_path, skiprows=1, ndmin=2)
            if numerical_data.size == 0:
                print("  [✗] NUMERICAL FAILURE: Output file is empty.")
                continue
            numerical_data = numerical_data[numerical_data[:, 0] <= tau_max]
            print(f"  [✓] Loaded {len(numerical_data)} numerical data points.")
        except Exception as e:
            print(f"  [✗] ERROR: Could not load or parse output file '{output_path}'. Error: {e}")
            continue

        # *** ROBUSTNESS CHECK 2: Check whether the numerical integrator produced NaNs ***
        if np.any(np.isnan(numerical_data)):
            print("  [✗] NUMERICAL FAILURE: The C integrator produced NaN values.")
            print("      The integration broke down (e.g. on an unstable orbit). Skipping statistics.")
            continue
        # *** END OF CHECK ***

        tau_values = numerical_data[:, 0]
        _, _, analytical_data = get_analytical_solution(tau_values, r_initial, M_scale, a_spin)
        print("  [✓] Generated analytical ground truth.")

        # --- 4. Calculate Statistics ---
        pos_error = np.sqrt(np.sum((numerical_data[:, 2:5] - analytical_data[:, 1:4])**2, axis=1))
        radius_error = np.abs(np.sqrt(numerical_data[:, 2]**2 + numerical_data[:, 3]**2) - r_initial)
        
        phi_num = np.unwrap(np.arctan2(numerical_data[:, 3], numerical_data[:, 2]))
        phi_ana = np.unwrap(np.arctan2(analytical_data[:, 2], analytical_data[:, 1]))
        phase_error = np.abs(phi_num - phi_ana)

        print("\n  --- STATISTICAL REPORT ---")
        print(f"  Final Position Error:      {pos_error[-1]:.3e} M")
        print(f"  Max Radius Error (Drift):  {np.max(radius_error):.3e} M")
        print(f"  Mean Radius Error:         {np.mean(radius_error):.3e} M")
        print(f"  Max Phase Error (Timing):  {np.max(phase_error):.3e} radians")
        print(f"  Mean Phase Error:          {np.mean(phase_error):.3e} radians")
        
        # *** ROBUSTNESS CHECK 3: Verdict against a fixed tolerance of 1e-5 ***
        # (radius error in M, phase error in radians)
        if np.max(radius_error) > 1e-5 or np.max(phase_error) > 1e-5:
            print("\n  --- VERDICT ---\n  FAIL: Error metrics exceed tolerance.")
        else:
            print("\n  --- VERDICT ---\n  PASS: All error metrics are within tolerance.")
        # *** END OF CHECK ***

        # --- 5. Generate and Save Plots ---
        fig, axes = plt.subplots(1, 2, figsize=(18, 8))
        fig.suptitle(f"Verification for r_initial={r_initial:.2f}, a_spin={a_spin:.2f}", fontsize=16)

        axes[0].plot(analytical_data[:, 1], analytical_data[:, 2], 'r--', label='Analytical', lw=2)
        axes[0].plot(numerical_data[:, 2], numerical_data[:, 3], 'c-', label='Numerical', lw=1, alpha=0.8)
        axes[0].set_title("Orbit Trajectory (Top-Down View)")
        axes[0].set_xlabel("x (M)"); axes[0].set_ylabel("y (M)")
        axes[0].set_aspect('equal', 'box')
        axes[0].legend(); axes[0].grid(True)

        axes[1].plot(tau_values, radius_error, label='Radius Error |r_num - r_initial|')
        axes[1].plot(tau_values, phase_error, label='Phase Error |φ_num - φ_ana|')
        axes[1].set_title("Error Growth over Proper Time")
        axes[1].set_xlabel("Proper Time (τ) [M]"); axes[1].set_ylabel("Error")
        axes[1].set_yscale('log')
        axes[1].legend(); axes[1].grid(True)
        
        plot_filename = f"verification_r{r_initial:.2f}_a{a_spin:.2f}.png"
        full_plot_path = os.path.join(full_output_folder, plot_filename)
        plt.savefig(full_plot_path, dpi=150, bbox_inches='tight')
        print(f"  [✓] Plot saved to '{full_plot_path}'")
        
        plt.show()
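
The error measures in the statistical report can be tried on synthetic data before you rely on them for integrator output. The sketch below factors the Step 4 computation into a function for illustration; `orbit_error_metrics` is not used elsewhere in this notebook. It assumes the numerical and analytical orbits are given as (N, 2) arrays of x and y sampled at the same proper times. z is omitted because the analytical orbit is equatorial.

```python
import numpy as np

def orbit_error_metrics(num_xy, ana_xy, r_initial):
    """Position, radius, and unwrapped phase errors, as in the statistical report."""
    num_xy = np.asarray(num_xy, dtype=float)
    ana_xy = np.asarray(ana_xy, dtype=float)
    pos_error = np.linalg.norm(num_xy - ana_xy, axis=1)
    radius_error = np.abs(np.hypot(num_xy[:, 0], num_xy[:, 1]) - r_initial)
    # Unwrapping both phases keeps their difference continuous across the +/-pi branch cut.
    phi_num = np.unwrap(np.arctan2(num_xy[:, 1], num_xy[:, 0]))
    phi_ana = np.unwrap(np.arctan2(ana_xy[:, 1], ana_xy[:, 0]))
    phase_error = np.abs(phi_num - phi_ana)
    return pos_error, radius_error, phase_error
```

Suppose the numerical orbit is the analytical one rotated by a constant 1e-6 rad. Then the radius error should be at round-off level, the phase error should be 1e-6 everywhere, and the position error should be about r * 1e-6.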

### Example: Schwarzschild Test Cases

In [None]:
# Zero-spin (Schwarzschild) orbits (a=0.0)
# ISCO is at r = 6.0 M
schwarzschild_test_cases = [
    # (a_spin, r_initial)
    (0.0, 6.0),    # At the ISCO
    (0.0, 6.5),    # Near-ISCO
    (0.0, 7.0),
    (0.0, 8.0),
    (0.0, 10.0),   # Intermediate orbit
    (0.0, 15.0),
    (0.0, 20.0),
    (0.0, 30.0),
    (0.0, 50.0),
    (0.0, 100.0)   # Weak-field orbit
]

# Example function call:
verify_mass_integrator(
    test_cases=schwarzschild_test_cases,
    output_folder="verification_schwarzschild_a0.00"
)