# Tutorial: A General-Purpose Photon Geodesic Integrator

## Author: Dalton Moone

## This notebook constructs a complete C-language project that numerically integrates the path of a photon. Its modular design handles either a purely analytic spacetime metric or a numerical one represented by data on a grid. For numerical metrics, it uses `nrpy`'s finite-difference capabilities to compute the Christoffel symbols. The system integrates nine variables: eight for the geodesic equation (four coordinates and four momentum components) and a ninth for the "proper distance" L.

**Notebook Status:** <font color='orange'><b>In Development</b></font>

**Validation Notes:** This notebook implements a flexible geodesic integrator. The user sets the `MetricType` parameter to switch between an "analytic" solution (where derivatives are computed symbolically) and a "numerical" solution (where derivatives are computed using finite differencing).

# Table of Contents

This notebook is organized as follows:

1.  [Step 1: Initialize `nrpy` and Project Parameters](#initialize)
2.  [Step 2: The "Smart" Symbolic Engine](#symbolic_engine)
3.  [Step 3: Define Spacetime Metrics](#define_metrics)
4.  [Step 4: Generate All C Functions](#generate_c_funcs)
5.  [Step 5: Assemble and Build the Final C Project](#assemble_project)

<a id='initialize'></a>
# Step 1: Initialize `nrpy` and Project Parameters

First, we import the necessary modules and set up our project directory. We also define a key parameter, `MetricType`, which will control the behavior of the entire notebook.

*   `MetricType = "analytic"`: The Christoffel symbols will be computed using symbolic derivatives (`sp.diff()`). This is useful for known, exact solutions.
*   `MetricType = "numerical"`: The Christoffel symbols will be computed using finite difference derivatives. This is the standard approach for numerical relativity simulations where the metric is given as data on a grid.
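
To make the distinction concrete, the sketch below compares the exact radial derivative of the Schwarzschild $g_{tt}$ with a fourth-order centered finite difference of the same function. It is plain Python rather than `nrpy` code: the helper names (`g_tt`, `dgtt_dr_exact`, `dgtt_dr_fd4`) and the grid spacings are illustrative assumptions, and the stencil is the textbook fourth-order centered one, which is not necessarily what `nrpy` emits.

```python
def g_tt(r, M=1.0):
    """Schwarzschild g_tt = -(1 - 2M/r)."""
    return -(1.0 - 2.0 * M / r)


def dgtt_dr_exact(r, M=1.0):
    """Exact derivative: d/dr [-(1 - 2M/r)] = -2M / r**2."""
    return -2.0 * M / r**2


def dgtt_dr_fd4(r, h, M=1.0):
    """Textbook fourth-order centered stencil with spacing h (illustrative)."""
    return (
        -g_tt(r + 2 * h, M) + 8 * g_tt(r + h, M)
        - 8 * g_tt(r - h, M) + g_tt(r - 2 * h, M)
    ) / (12 * h)


r0 = 10.0
err_coarse = abs(dgtt_dr_fd4(r0, 0.2) - dgtt_dr_exact(r0))
err_fine = abs(dgtt_dr_fd4(r0, 0.1) - dgtt_dr_exact(r0))
print(f"error(h=0.2) = {err_coarse:.3e}, error(h=0.1) = {err_fine:.3e}")
```

Halving the spacing should shrink the error by roughly a factor of 16 for a fourth-order stencil, which is the trade-off the `numerical` path accepts in exchange for working with gridded data.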

In [None]:
# =================================================================
# STEP 1: INITIALIZE NRPY AND PROJECT PARAMETERS
# =================================================================
import os
import shutil
import sympy as sp

# NRPy imports
import nrpy.c_function as cfc
import nrpy.c_codegen as ccg
import nrpy.params as par
import nrpy.indexedexp as ixp
import nrpy.finite_difference as fin
import nrpy.infrastructures.BHaH.BHaH_defines_h as Bdefines_h
import nrpy.infrastructures.BHaH.Makefile_helpers as Makefile
import nrpy.helpers.generic as gh

# Set project name and clean the output directory
project_name = "photon_geodesic_integrator"
project_dir = os.path.join("project", project_name)
shutil.rmtree(project_dir, ignore_errors=True)

# Set NRPy parameters
par.set_parval_from_str("Infrastructure", "BHaH")

# The crucial parameter that controls the notebook's behavior.
# Change to "numerical" to generate finite-difference code.
MetricType = "analytic" 

# Finite-difference order for the numerical-metric path.
# This parameter is only used when MetricType = "numerical".
par.set_parval_from_str("fd_order", 4)

# The symbolic finite difference operators like fin.dD are automatically 
# available after importing nrpy.finite_difference.

<a id='symbolic_engine'></a>
# Step 2: The "Smart" Symbolic Engine

This is the core of our flexible design. The function `construct_geodesic_odes` branches on `MetricType`:
*   If `analytic`, it uses `sp.diff()` for exact derivatives.
*   If `numerical`, it assumes the metric components are grid functions and uses `nrpy`'s symbolic finite difference operators to compute the derivatives.
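
For reference, the nine right-hand sides assembled by the function in this step correspond to the system below, with $\lambda$ as the integration parameter and Latin indices running over the three spatial coordinates. The notation is an assumption read off from the code, not something the original text states:

$$
\frac{dx^\mu}{d\lambda} = p^\mu, \qquad
\frac{dp^\mu}{d\lambda} = -\Gamma^\mu_{\alpha\beta}\, p^\alpha p^\beta, \qquad
\frac{dL}{d\lambda} = \sqrt{g_{ij}\, p^i p^j},
$$

$$
\Gamma^\mu_{\alpha\beta} = \frac{1}{2}\, g^{\mu\nu} \left( \partial_\alpha g_{\beta\nu} + \partial_\beta g_{\alpha\nu} - \partial_\nu g_{\alpha\beta} \right).
$$

Only the evaluation of $\partial_\alpha g_{\beta\nu}$ differs between the two metric types; the rest of the construction is shared.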

In [None]:
# =================================================================
# STEP 2: THE "SMART" SYMBOLIC ENGINE
# =================================================================

def construct_geodesic_odes(g4DD, xx, metric_type="analytic"):
    """
    Constructs the symbolic RHS expressions for the 9 geodesic ODEs.
    This function can handle both analytic metrics (using sp.diff) and
    numerical metrics (using nrpy.finite_difference).
    """
    DIM = 4
    pU = ixp.declarerank1('pU', dimension=DIM)
    
    g4UU, _ = ixp.symm_matrix_inverter4x4(g4DD)

    g4DD_dD = ixp.zerorank3(dimension=DIM)
    if metric_type == "analytic":
        print("Using ANALYTIC derivatives for the Christoffel symbols.")
        for i in range(DIM):
            for j in range(DIM):
                for k in range(DIM):
                    g4DD_dD[i][j][k] = sp.diff(g4DD[i][j], xx[k])
    elif metric_type == "numerical":
        print("Using NUMERICAL finite difference derivatives for the Christoffel symbols.")
        for i in range(DIM):
            for j in range(DIM):
                for k in range(DIM):
                    g4DD_dD[i][j][k] = fin.dD(g4DD[i][j], k)
    else:
        raise ValueError(f"Unknown MetricType: {metric_type}")

    Gamma4UDD = ixp.zerorank3(dimension=DIM)
    for i in range(DIM):
        for j in range(DIM):
            for k in range(DIM):
                for l in range(DIM):
                    Gamma4UDD[k][i][j] += sp.Rational(1, 2) * g4UU[k][l] * (g4DD_dD[j][l][i] + g4DD_dD[i][l][j] - g4DD_dD[i][j][l])

    pos_rhs = [pU[0], pU[1], pU[2], pU[3]]
    mom_rhs = ixp.zerorank1(dimension=DIM)
    for a in range(DIM):
        for mu in range(DIM):
            for nu in range(DIM):
                mom_rhs[a] += -Gamma4UDD[a][mu][nu] * pU[mu] * pU[nu]

    dL_dlambda_squared = sum(g4DD[i][j] * pU[i] * pU[j] for i in range(1, DIM) for j in range(1, DIM))
    dL_dlambda = sp.sqrt(dL_dlambda_squared)

    return [
        pos_rhs[0], pos_rhs[1], pos_rhs[2], pos_rhs[3],
        mom_rhs[0], mom_rhs[1], mom_rhs[2], mom_rhs[3],
        dL_dlambda
    ]

<a id='define_metrics'></a>
# Step 3: Define Spacetime Metrics

This section contains functions that define the analytic form of various spacetime metrics. Each function returns the metric tensor `g4DD` and the coordinate system `xx`. In the next step, we will decide whether to use this analytic form directly or to convert it into a numerical grid function representation.
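
As a quick numerical sanity check on the Schwarzschild line element defined in this step, the sketch below evaluates the four diagonal components at a sample point. It is standalone Python, independent of `sympy` and `nrpy`; the function name `schwarzschild_diag` is a hypothetical helper introduced only for illustration.

```python
import math


def schwarzschild_diag(M, r, th):
    """Return (g_tt, g_rr, g_thth, g_phph) in Schwarzschild coordinates."""
    f = 1.0 - 2.0 * M / r
    return (-f, 1.0 / f, r**2, r**2 * math.sin(th) ** 2)


g_tt, g_rr, g_thth, g_phph = schwarzschild_diag(M=1.0, r=10.0, th=math.pi / 2)
print(g_tt, g_rr, g_thth, g_phph)

# For this metric, g_tt * g_rr = -1 at every radius outside the horizon.
print(g_tt * g_rr)
```

Checks of this kind are a cheap way to catch a sign or index slip before the metric is handed to the symbolic engine.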


In [None]:
# =================================================================
# STEP 3: DEFINE SPACETIME METRICS
# =================================================================

def define_schwarzschild_metric_analytic():
    """
    Defines and returns the ANALYTIC Schwarzschild metric tensor and its 
    corresponding coordinate system.
    """
    M, r, th = sp.symbols("M r th", real=True)
    
    # THE FIX IS HERE: For the analytic case, xx should be a simple list of symbols.
    xx = [sp.Symbol("t"), r, th, sp.Symbol("ph")]

    g4DD = ixp.zerorank2(dimension=4)
    g4DD[0][0] = -(1 - 2*M/r)
    g4DD[1][1] = 1 / (1 - 2*M/r)
    g4DD[2][2] = r**2
    g4DD[3][3] = r**2 * sp.sin(th)**2
    
    return g4DD, xx


<a id='generate_c_funcs'></a>
# Step 4: Generate All C Functions

This is the main "factory" section. It orchestrates the entire Python-to-C process:
1. Read the `MetricType` parameter.
2. Call the appropriate metric function to get the symbolic metric tensor.
3. If the type is `numerical`, replace the analytic metric components with symbolic grid functions registered under the name `g4DD`.
4. Pass the metric and coordinates to our "smart" symbolic engine.
5. Convert the resulting symbolic expressions into optimized C code (with common-subexpression elimination enabled) using `nrpy`'s `c_codegen` function, which automatically handles finite differencing if needed.
6. Register all other necessary C functions for the project.
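
Two of the utility functions registered in this step are a quadratic root-finder for radial crossings and a three-point Lagrange interpolator. The sketch below mirrors their arithmetic in plain Python so it can be tested without compiling anything. The Python function names and the sample parabola are hypothetical, chosen only for illustration.

```python
import math


def lagrange_basis(s):
    """Quadratic Lagrange basis for nodes at s = -1, 0, +1."""
    return ((s * s - s) / 2.0, 1.0 - s * s, (s * s + s) / 2.0)


def interpolate(y0, y1, y2, s):
    """Interpolate three samples taken at s = -1, 0, +1."""
    L0, L1, L2 = lagrange_basis(s)
    return y0 * L0 + y1 * L1 + y2 * L2


def find_crossing_s(r0, r1, r2, target_r):
    """Return s in [0, 1] where the interpolant equals target_r, else -1."""
    a = (r0 - 2.0 * r1 + r2) / 2.0
    b = (r2 - r0) / 2.0
    c = r1 - target_r
    if abs(a) < 1e-14:  # Interpolant is (nearly) linear.
        return -1.0 if abs(b) < 1e-14 else -c / b
    disc = b * b - 4.0 * a * c
    if disc < 0.0:
        return -1.0
    root = math.sqrt(disc)
    for s in ((-b + root) / (2.0 * a), (-b - root) / (2.0 * a)):
        if 0.0 <= s <= 1.0:
            return s
    return -1.0


# Three radii sampled at s = -1, 0, +1 from r(s) = 2.5 - 0.8 s - 0.1 s**2.
r_samples = [2.5 - 0.8 * s - 0.1 * s * s for s in (-1.0, 0.0, 1.0)]
s_hit = find_crossing_s(*r_samples, target_r=2.0)
r_hit = interpolate(*r_samples, s_hit)
print(s_hit, r_hit)
```

Placing the nodes at $s = -1, 0, +1$ treats the last three states as equally spaced in $\lambda$. With an adaptive stepper that is only approximately true, so the interpolated crossing should be read as an estimate.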

In [None]:
# =================================================================
# STEP 4: GENERATE ALL C FUNCTIONS
# =================================================================

# Helper function to register the quadratic root-finder for r-crossings.
def register_CFunction_r_root_finder_quadratic():
    print("Registering C function for quadratic r-coordinate root-finding...")
    c_params = "double r0, double r1, double r2, double target_r"
    c_func_body = r"""
    const double a = (r0 - 2.0 * r1 + r2) / 2.0;
    const double b = (r2 - r0) / 2.0;
    const double c = r1 - target_r;
    if (fabs(a) < 1e-14) {
        if (fabs(b) < 1e-14) { return -1.0; }
        return -c / b;
    }
    const double discriminant = b*b - 4.0*a*c;
    if (discriminant < 0.0) { return -1.0; }
    const double sqrt_discriminant = sqrt(discriminant);
    const double s_plus = (-b + sqrt_discriminant) / (2.0 * a);
    const double s_minus = (-b - sqrt_discriminant) / (2.0 * a);
    if (s_plus >= 0.0 && s_plus <= 1.0) { return s_plus; }
    if (s_minus >= 0.0 && s_minus <= 1.0) { return s_minus; }
    return -1.0;
    """
    cfc.register_CFunction(
        name="find_r_crossing_s_quadratic",
        desc="Finds 's' where r crosses a target value using the quadratic formula.",
        cfunc_type="double", params=c_params, body=c_func_body, includes=["math.h"]
    )

# Helper function to symbolically generate and register the Lagrange interpolation function.
def register_CFunction_Lagrange_Interpolation_symbolic():
    print("Registering C function for Lagrange interpolation (from SymPy)...")
    s = sp.symbols("s")
    y0 = ixp.declarerank1("y0", dimension=9)
    y1 = ixp.declarerank1("y1", dimension=9)
    y2 = ixp.declarerank1("y2", dimension=9)
    
    L0 = (s * s - s) / 2
    L1 = 1 - s * s
    L2 = (s * s + s) / 2
    
    y_out_expressions = []
    for i in range(9):
        y_out_expressions.append(y0[i] * L0 + y1[i] * L1 + y2[i] * L2)
        
    c_code_body = ccg.c_codegen(
        y_out_expressions,
        [f"y_out[{i}]" for i in range(9)],
        enable_cse=True, cse_varprefix="interp_expr"
    )
    
    for i in range(9):
        c_code_body = c_code_body.replace(f"y0{i}", f"y0[{i}]")
        c_code_body = c_code_body.replace(f"y1{i}", f"y1[{i}]")
        c_code_body = c_code_body.replace(f"y2{i}", f"y2[{i}]")
    
    cfc.register_CFunction(
        name="interpolation_state_at_s_lagrange",
        desc="Uses 2nd-order Lagrange interpolation (from SymPy) for a 9-variable state.",
        params="const double y0[9], const double y1[9], const double y2[9], const double s, double y_out[9]",
        body=c_code_body,
        includes=["BHaH_defines.h"]
    )

# Main orchestrator function for this step.
def generate_c_functions(metric_type="analytic"):
    print(f"Generating C functions for MetricType = '{metric_type}'...")
    
    # --- 1. Get the metric and coordinate system ---
    g4DD_analytic, xx = define_schwarzschild_metric_analytic()
    
    # --- 2. Get the symbolic ODE RHS expressions ---
    if metric_type == "analytic":
        rhs_expressions = construct_geodesic_odes(g4DD_analytic, xx, metric_type)
        c_declarations = "const REAL r = y[1];\nconst REAL th = y[2];\n"
    elif metric_type == "numerical":
        ixp.prefix_for_symbolic_expressions = "g4DD" # Set a prefix for the gridfunctions
        g4DD_gridfuncs = ixp.register_gridfunctions_for_rank2("g4DD", "sym", DIM=4)
        rhs_expressions = construct_geodesic_odes(g4DD_gridfuncs, xx, metric_type)
        c_declarations = ""
    else:
        raise ValueError(f"Unknown MetricType: {metric_type}")
    
    # --- 3. Generate and register the ODE RHS C function ---
    pU_sym = ixp.declarerank1('pU', dimension=4)
    sub_dict = {pU_sym[0]: sp.Symbol("y[4]"), pU_sym[1]: sp.Symbol("y[5]"), 
                pU_sym[2]: sp.Symbol("y[6]"), pU_sym[3]: sp.Symbol("y[7]")}
    
    rhs_expressions_subbed = [expr.subs(sub_dict) for expr in rhs_expressions]
    
    results = ccg.c_codegen(
        rhs_expressions_subbed,
        [f"rhs_out[{i}]" for i in range(9)],
        enable_cse=True, cse_varprefix="expr"
    )
    
    c_code_body_rhs = c_declarations + results
    # Register the ODE right-hand-side function.
    cfc.register_CFunction(
        name="calculate_ode_rhs",
        desc="Calculates the RHS of the 9 ODEs.",
        params="const double M, const double y[9], double rhs_out[9]",
        body=c_code_body_rhs, includes=["BHaH_defines.h", "math.h"]
    )
    
    # --- 4. Generate and register the GSL wrapper ---
    c_func_body_gsl = r"""
    (void)lambda;
    const photon_params *p = (const photon_params *)params;
    calculate_ode_rhs(p->M, y, f);
    return GSL_SUCCESS;
    """
    # Register the wrapper with the callback signature GSL expects.
    cfc.register_CFunction(
        name="ode_gsl_wrapper",
        desc="A GSL-compatible wrapper for the ODE RHS function.",
        cfunc_type="int",
        params="double lambda, const double y[], double f[], void *params",
        body=c_func_body_gsl, includes=["BHaH_defines.h", "BHaH_function_prototypes.h", "gsl/gsl_errno.h"]
    )
    
    # --- 5. Generate and register other C utility functions ---
    register_CFunction_Lagrange_Interpolation_symbolic()
    register_CFunction_r_root_finder_quadratic()

    # --- 6. Generate and register the main integration loop ---
    c_params_main_loop = """
    const double M, const double initial_lambda,
    const double d_lambda_initial, const int num_steps,
    const double start_y[9], const double r_max,
    const double t_max, path_struct *path_out
    """
    c_func_body_main_loop = r"""
    const gsl_odeiv2_step_type * T = gsl_odeiv2_step_rkf45;
    gsl_odeiv2_step * s = gsl_odeiv2_step_alloc(T, 9);
    gsl_odeiv2_control * c = gsl_odeiv2_control_y_new(1e-9, 1e-9);
    gsl_odeiv2_evolve * e = gsl_odeiv2_evolve_alloc(9);
    
    photon_params p = {.M = M};
    gsl_odeiv2_system sys = {ode_gsl_wrapper, NULL, 9, &p};

    double lambda = initial_lambda;
    double d_lambda = d_lambda_initial;
    double y[9], y_prev[9], y_prev_prev[9];

    for (int j = 0; j < 9; j++) { y[j] = start_y[j]; y_prev[j] = start_y[j]; y_prev_prev[j] = start_y[j]; }

    path_out->lambda[0] = lambda; path_out->t[0] = y[0]; path_out->r[0] = y[1];
    path_out->theta[0] = y[2]; path_out->phi[0] = y[3]; path_out->L[0] = y[8];
    path_out->actual_steps = 0;
    path_out->reason = IN_FLIGHT;

    int i = 0;
    while (i < num_steps) {
        for(int j=0; j<9; j++) { y_prev_prev[j] = y_prev[j]; y_prev[j] = y[j]; }
        
        double lambda_prev_step = lambda;
        int status = gsl_odeiv2_evolve_apply(e, c, s, &sys, &lambda, 1e10, &d_lambda, y);
        
        if (status != GSL_SUCCESS) {
            // The enum has no dedicated failure code, so a GSL error is
            // reported as MAX_STEPS_REACHED.
            path_out->reason = MAX_STEPS_REACHED;
            break;
        }
        i++;
        
        path_out->actual_steps = i;
        path_out->lambda[i] = lambda;
        path_out->t[i] = y[0]; path_out->r[i] = y[1];
        path_out->theta[i] = y[2]; path_out->phi[i] = y[3];
        path_out->L[i] = y[8];
        
        if (i > 1 && y[1] <= 2.0 * M && y_prev[1] > 2.0 * M) {
            path_out->reason = HIT_HORIZON;
            const double s_r = find_r_crossing_s_quadratic(y_prev_prev[1], y_prev[1], y[1], 2.0 * M);
            if (s_r >= 0.0 && s_r <= 1.0) {
                double y_horizon[9];
                interpolation_state_at_s_lagrange(y_prev_prev, y_prev, y, s_r, y_horizon);
                path_out->lambda[i] = lambda_prev_step + s_r * (lambda - lambda_prev_step);
                path_out->t[i] = y_horizon[0];
                path_out->r[i] = y_horizon[1];
                path_out->theta[i] = y_horizon[2];
                path_out->phi[i] = y_horizon[3];
                path_out->L[i] = y_horizon[8];
            }
            break;
        } else if (y[1] > r_max || y[0] > t_max) {
            path_out->reason = (y[1] > r_max) ? ESCAPED_R_MAX : ESCAPED_T_MAX;
            break;
        }
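    }
    // If the step budget ran out without a termination event, report it
    // instead of leaving the status as IN_FLIGHT.
    if (path_out->reason == IN_FLIGHT) { path_out->reason = MAX_STEPS_REACHED; }
    gsl_odeiv2_evolve_free(e); gsl_odeiv2_control_free(c); gsl_odeiv2_step_free(s);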
    """
    cfc.register_CFunction(
        name="integrate_single_photon",
        desc="Integrates 9 ODEs (geodesic + proper distance) using an adaptive loop.",
        params=c_params_main_loop, body=c_func_body_main_loop,
        includes=["BHaH_defines.h", "BHaH_function_prototypes.h", "gsl/gsl_errno.h", "gsl/gsl_odeiv2.h"]
    )

<a id='assemble_project'></a>
# Step 5: Assemble and Build the Final C Project

This final step brings everything together. We define the C data structures needed by our functions, generate the `main()` C function that sets up the initial conditions and calls the integrator, and finally, construct the `Makefile` to compile the entire project.
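
The initial photon momentum must satisfy the null condition $g_{\mu\nu} p^\mu p^\nu = 0$, which fixes $p^t$ once the spatial components are chosen. The sketch below reproduces that calculation in plain Python for the starting values used in the generated `main()`; the variable names are illustrative and not part of the generated C project.

```python
import math

M = 1.0
r, th = 10.0 * M, math.pi / 2.0
p_r, p_th, p_ph = -1.0, 0.0, 3.85 / r

# Diagonal Schwarzschild metric components at the starting point.
f = 1.0 - 2.0 * M / r
g_tt, g_rr, g_thth, g_phph = -f, 1.0 / f, r**2, r**2 * math.sin(th) ** 2

# Solve g_tt (p^t)^2 + (spatial part) = 0 for the positive root.
spatial = g_rr * p_r**2 + g_thth * p_th**2 + g_phph * p_ph**2
p_t = math.sqrt(-spatial / g_tt)

# The residual of the null condition should vanish up to round-off.
norm = g_tt * p_t**2 + spatial
print(f"p^t = {p_t:.6f}, g_mu_nu p^mu p^nu = {norm:.2e}")
```

Monitoring the same residual along the integrated path is one way to gauge how well the integrator preserves the null constraint.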

In [None]:
# =================================================================
# STEP 5: ASSEMBLE AND BUILD THE FINAL C PROJECT
# =================================================================

def register_CFunction_main():
    print("Registering main() C function...")
    c_func_body_main = r"""
#ifndef M_PI
#define M_PI 3.14159265358979323846
#endif
(void)argc; (void)argv;

const double M = 1.0;
const double r_max = 50.0;
const double t_max = 200.0;
double start_y[9];

// State vector: y[0..3] = (t, r, theta, phi), y[4..7] = p^mu, y[8] = L.
start_y[0] = 0.0;
start_y[1] = 10.0 * M;
start_y[2] = M_PI / 2.0;
start_y[3] = 0.0;
start_y[5] = -1.0;
start_y[6] = 0.0;
start_y[7] = 3.85 / start_y[1];
// Solve the null condition g_{mu nu} p^mu p^nu = 0 for p^t.
// The g22 term is omitted because p^theta = 0 here.
const double r_start = start_y[1];
const double g00 = -(1.0 - 2.0*M/r_start);
const double g11 = 1.0 / (1.0 - 2.0*M/r_start);
const double g33 = r_start*r_start * sin(start_y[2])*sin(start_y[2]);
start_y[4] = sqrt((-g11*start_y[5]*start_y[5] - g33*start_y[7]*start_y[7]) / g00);
start_y[8] = 0.0;

const double initial_lambda = 0.0;
const double d_lambda_initial = 0.01;
const int num_steps = 80000;

path_struct path_results;
path_results.num_steps = num_steps;
path_results.lambda = (double*)malloc(sizeof(double) * (num_steps + 1));
path_results.t      = (double*)malloc(sizeof(double) * (num_steps + 1));
path_results.r      = (double*)malloc(sizeof(double) * (num_steps + 1));
path_results.theta  = (double*)malloc(sizeof(double) * (num_steps + 1));
path_results.phi    = (double*)malloc(sizeof(double) * (num_steps + 1));
path_results.L      = (double*)malloc(sizeof(double) * (num_steps + 1));

printf("Starting photon integration...\n");
integrate_single_photon(M, initial_lambda, d_lambda_initial, num_steps, start_y, 
                        r_max, t_max, &path_results);
printf("Integration finished after %d steps. Reason code: %d\n", path_results.actual_steps, path_results.reason);

FILE *fp = fopen("photon_path.txt", "w");
if (fp == NULL) { return 1; }
fprintf(fp, "# lambda\tt\tr\tphi\tL\n");
for (int i = 0; i <= path_results.actual_steps; i++) {
    fprintf(fp, "%.4f\t%.4f\t%.4f\t%.4f\t%.4f\n",
           path_results.lambda[i], path_results.t[i], path_results.r[i],
           path_results.phi[i], path_results.L[i]);
}
fclose(fp);
printf("Path data written to photon_path.txt\n");

free(path_results.lambda); free(path_results.t); free(path_results.r);
free(path_results.theta); free(path_results.phi); free(path_results.L);
return 0;
"""
    cfc.register_CFunction(
        name="main", desc="Sets up and runs a single photon integration.", cfunc_type="int",
        params="int argc, char *argv[]", body=c_func_body_main,
        includes=["BHaH_defines.h", "BHaH_function_prototypes.h", "math.h", "stdio.h", "stdlib.h"]
    )

# --- Main execution block ---
print("\nAssembling and building C project...")

os.makedirs(project_dir, exist_ok=True)

# Call the main orchestrator function from Step 4, passing the global MetricType.
generate_c_functions(MetricType)
register_CFunction_main()

photon_params_struct_str = "typedef struct { double M; } photon_params;"
termination_reason_enum_str = """
typedef enum {
    IN_FLIGHT, ESCAPED_R_MAX, ESCAPED_T_MAX,
    HIT_HORIZON, MAX_STEPS_REACHED
} termination_reason_t;
"""
photon_path_struct_str = """
typedef struct {
    int num_steps;
    int actual_steps;
    termination_reason_t reason;
    double *lambda, *t, *r, *theta, *phi, *L;
} path_struct;
"""

custom_structs_and_enums = f"{photon_params_struct_str}\n\n{termination_reason_enum_str}\n\n{photon_path_struct_str}"
with open(os.path.join(project_dir, "custom_structs.h"), "w") as file:
    file.write(custom_structs_and_enums)

Bdefines_h.output_BHaH_defines_h(
    project_dir=project_dir,
    additional_includes=["custom_structs.h"]
)

print("Copying SIMD intrinsics header...")
gh.copy_files(
    package="nrpy.helpers",
    filenames_list=["simd_intrinsics.h"],
    project_dir=project_dir,
    subdirectory="simd",
)

Makefile.output_CFunctions_function_prototypes_and_construct_Makefile(
    project_dir=project_dir,
    project_name=project_name,
    exec_or_library_name=project_name,
    addl_CFLAGS=["-Wall -Wextra -g $(shell gsl-config --cflags)"],
    addl_libraries=["$(shell gsl-config --libs)"],
)

print(f"\nFinished! A C project has been generated in {project_dir}/")
print("To build, navigate to this directory in your terminal and type 'make'.")
print(f"To run, type './{project_name}'.")