In [11]:
from adam_core.utils import get_perturber_state
from adam_core.time import Timestamp
from adam_core.coordinates.origin import OriginCodes
import numpy as np


launch_time = Timestamp.from_mjd([60600.0], scale="tai")
tof = 30 * 6
arrival_time = launch_time.add_days(tof)
launch_body = get_perturber_state(OriginCodes.EARTH, launch_time, frame="ecliptic", origin=OriginCodes.SUN)
arrival_body = get_perturber_state(OriginCodes.MARS_BARYCENTER, arrival_time, frame="ecliptic", origin=OriginCodes.SUN)


def create_time_grid(earlist_launch_time: Timestamp, maximum_arrival_time: Timestamp, step_size: float) -> np.ndarray:

    time_grid = np.arange(earlist_launch_time.rescale("tdb").mjd()[0].as_py(), maximum_arrival_time.rescale("tdb").mjd()[0].as_py(), step_size)
    x, y = np.meshgrid(time_grid, time_grid)
    return x.flatten(), y.flatten()

time_grid = create_time_grid(launch_time, arrival_time, 5)
time_grid

(array([60600.00037248, 60605.00037248, 60610.00037248, ...,
        60770.00037248, 60775.00037248, 60780.00037248]),
 array([60600.00037248, 60600.00037248, 60600.00037248, ...,
        60780.00037248, 60780.00037248, 60780.00037248]))

In [None]:
import jax.numpy as jnp
import numpy as np

from typing import Tuple
from adam_core.constants import Constants

MU = Constants.MU

def solve_lambert(r1: np.ndarray, r2: np.ndarray, tof: float, mu: float, algorithm: str = "izzo") -> Tuple[np.ndarray, np.ndarray, int]:

    if algorithm == "izzo":
        return solve_lambert_izzo(r1, r2, tof)
    else:
        raise ValueError(f"Algorithm {algorithm} not supported")



def solve_lambert_izzo(r1: np.ndarray, r2: np.ndarray, tof: float, mu: float) -> Tuple[np.ndarray, np.ndarray, int]:
    """
    Solve Lambert's problem using the Izzo algorithm.
    """
    assert tof > 0, "TOF must be positive"
    assert mu > 0, "MU must be positive"

    c = r2 - r1
    c_mag = np.linalg.norm(c)
    r1_mag = np.linalg.norm(r1)
    r2_mag = np.linalg.norm(r2)

    s = 0.5 * (r1_mag + r2_mag + c_mag)
   
    r1_hat = r1 / r1_mag
    r2_hat = r2 / r2_mag

    r1xr2_hat = np.cross(r1_hat, r2_hat)
    lambd2 = 1 - c_mag / s
    lambd = np.sqrt(lambd2)

    # TODO: What about z?
    cond = (r1[0] * r2[1]) - (r1[1] * r2[0])
    cond2 = np.dot(r1[:2], r2[:2])
    print(cond, cond2)
    if cond < 0:
        lambd = -lambd
        i_t1 = np.cross(r1_hat, r1xr2_hat)
        i_t2 = np.cross(r2_hat, r2_hat)
    else:
        i_t1 = np.cross(r1xr2_hat, r1_hat)
        i_t2 = np.cross(r1xr2_hat, r1xr2_hat)
            
    T = np.sqrt(2 * mu / s**3) * tof
    
    # some call to magical find xy
    gamma = np.sqrt(mu * s / 2)
    rho = (r1_mag - r2_mag) / c_mag
    sigma = np.sqrt(1 - rho**2)
    
    # More todo here

def findxy(lambd, T):
    assert np.abs(lambd) < 1, "lambd must be between -1 and 1"
    assert T < 0, "T must be negative"

    Mmax = np.floor(T / np.pi)
    T00 = np.arccos(lambd) + lambd * np.sqrt(1 - lambd**2)
    if T < T00 + Mmax * np.pi and Mmax > 0:


In [None]:
import jax.numpy as jnp
from jax import grad, jit, lax

@jit
def halley_method(f, x0, max_iter=100, tol=1e-12):
    f_prime = grad(f)
    f_double_prime = grad(f_prime)
    
    def cond_fun(state):
        x, delta, i = state
        return (jnp.abs(delta) > tol) & (i < max_iter)
    
    def body_fun(state):
        x, _, i = state
        fx = f(x)
        fpx = f_prime(x)
        fppx = f_double_prime(x)
        
        denominator = 2 * fpx**2 - fx * fppx
        denominator = jnp.where(jnp.abs(denominator) < 1e-14, 1e-14 * jnp.sign(denominator), denominator)
        
        delta = (2 * fx * fpx) / denominator
        new_x = x - delta
        
        return new_x, delta, i + 1
    
    init_state = (x0, jnp.inf, 0)
    final_state = lax.while_loop(cond_fun, body_fun, init_state)
    
    return final_state[0]

@jit
def householder_method(f, x0, order=3, max_iter=100, tol=1e-12):
    """
    Generalized Householder method for root finding with automatic differentiation.
    
    Parameters
    ----------
    f : callable
        Function to find root of
    x0 : float or jnp.ndarray
        Initial guess
    order : int
        Order of the method (3 = Halley's method, 2 = Newton's method)
    max_iter : int
        Maximum number of iterations
    tol : float
        Convergence tolerance
        
    Returns
    -------
    x : float or jnp.ndarray
        Root of the function
    """
    # Compute derivatives up to order-1
    derivatives = [f]
    for i in range(1, order):
        derivatives.append(grad(derivatives[i-1]))
    
    def cond_fun(state):
        x, delta, i = state
        return (jnp.abs(delta) > tol) & (i < max_iter)
    
    def body_fun(state):
        x, _, i = state
        
        # Compute function and derivative values
        f_values = [d(x) for d in derivatives]
        
        if order == 2:  # Newton's method
            delta = f_values[0] / f_values[1]
        elif order == 3:  # Halley's method
            denominator = 2 * f_values[1]**2 - f_values[0] * f_values[2]
            denominator = jnp.where(jnp.abs(denominator) < 1e-14, 1e-14 * jnp.sign(denominator), denominator)
            delta = (2 * f_values[0] * f_values[1]) / denominator
        else:  # Higher-order Householder
            # Compute the terms for the Householder update
            # We use logarithmic derivatives approach for numerical stability
            L1 = f_values[1] / f_values[0]
            terms = [1.0]
            
            for k in range(1, order-1):
                term = 0
                for j in range(1, k+1):
                    # Using Bell polynomial structure for higher derivatives
                    prod = f_values[j+1] / f_values[1]
                    for m in range(1, j):
                        prod *= (k - m + 1) / (j - m + 1)
                    term += prod * terms[k-j]
                terms.append(term)
            
            # Compute the update
            sum_term = 0
            for k in range(1, order-1):
                sum_term += terms[k] / (k+1)
            
            delta = f_values[0] / f_values[1] * (1 / (1 - f_values[0] * sum_term))
        
        new_x = x - delta
        return new_x, delta, i + 1
    
    init_state = (x0, jnp.inf, 0)
    final_state = lax.while_loop(cond_fun, body_fun, init_state)
    
    return final_state[0]

In [12]:
import jax.numpy as jnp
from jax import grad, jit, lax

@jit
def householder_method(f, x0, order=3, max_iter=100, tol=1e-12):
    """
    Generalized Householder method using lax primitives for root finding.
    
    Parameters
    ----------
    f : callable
        Function to find root of
    x0 : float or jnp.ndarray
        Initial guess
    order : int
        Order of the method (3 = Halley's method, 2 = Newton's method)
    max_iter : int
        Maximum number of iterations
    tol : float
        Convergence tolerance
        
    Returns
    -------
    x : float or jnp.ndarray
        Root of the function
    """
    # Compute derivatives with lax.scan
    def deriv_step(prev_deriv, _):
        return grad(prev_deriv), None
    
    derivatives = [f]
    _, _ = lax.scan(
        lambda prev_d, i: (grad(prev_d), None),
        f,
        jnp.arange(order-1)
    )
    
    # Precompute derivatives to avoid retracing in the loop
    derivatives = [f]
    for i in range(1, order):
        derivatives.append(grad(derivatives[i-1]))
    
    def cond_fun(state):
        x, delta, i = state
        return (jnp.abs(delta) > tol) & (i < max_iter)
    
    def newton_update(x, f_vals):
        return f_vals[0] / f_vals[1]
    
    def halley_update(x, f_vals):
        denominator = 2 * f_vals[1]**2 - f_vals[0] * f_vals[2]
        denominator = jnp.where(jnp.abs(denominator) < 1e-14, 1e-14 * jnp.sign(denominator), denominator)
        return (2 * f_vals[0] * f_vals[1]) / denominator
    
    def householder_update(x, f_vals, order):
        # Initialize terms
        L1 = f_vals[1] / f_vals[0]
        terms = [1.0]
        
        # Compute terms recursively with lax.fori_loop
        def outer_body(k, terms):
            # Initialize term for this k
            term = 0.0
            
            # Inner loop with lax.fori_loop
            def inner_body(j, term):
                prod = f_vals[j+1] / f_vals[1]
                
                # Innermost product loop with lax.fori_loop
                def prod_body(m, prod):
                    return prod * (k - m + 1) / (j - m + 1)
                
                prod = lax.fori_loop(1, j, prod_body, prod)
                return term + prod * terms[k-j]
            
            term = lax.fori_loop(1, k+1, inner_body, term)
            return terms + [term]
        
        terms = lax.fori_loop(1, order-1, outer_body, terms)
        
        # Compute final sum term
        def sum_body(i, acc):
            return acc + terms[i] / (i+1)
        
        sum_term = lax.fori_loop(1, order-1, sum_body, 0.0)
        
        return f_vals[0] / f_vals[1] * (1 / (1 - f_vals[0] * sum_term))
    
    def body_fun(state):
        x, _, i = state
        
        # Compute function and derivative values
        f_values = [d(x) for d in derivatives]
        
        # Choose update method based on order
        delta = lax.cond(
            order == 2,
            lambda: newton_update(x, f_values),
            lambda: lax.cond(
                order == 3,
                lambda: halley_update(x, f_values),
                lambda: householder_update(x, f_values, order)
            )
        )
        
        new_x = x - delta
        return new_x, delta, i + 1
    
    init_state = (x0, jnp.inf, 0)
    final_state = lax.while_loop(cond_fun, body_fun, init_state)
    
    return final_state[0]

In [13]:
def complex_func(x):
    return x**3 - 2*x + 2

x0 = jnp.array(1.0)

householder_method(complex_func, x0)

TypeError: Error interpreting argument to <function householder_method at 0x16ced7060> as an abstract array. The problematic value is of type <class 'function'> and was passed to the function at path f.
This typically means that a jit-wrapped function was called with a non-array argument, and this argument was not marked as static using the static_argnums or static_argnames parameters of jax.jit.