# Convergence

> Scripts to perform the refinement on the generated orbits

In [1]:
#| default_exp convergence

In [2]:
#| export
#| hide
from julia.api import Julia
jl = Julia(compiled_modules=False)
from julia import Main
import logging
import numpy as np
import pandas as pd

In [3]:
#| export
#| hide
# Configure logging at INFO level and create the module-level logger used by the
# functions below for progress and error messages.
# NOTE(review): this cell is marked `#| export`, so `basicConfig` presumably runs when
# the generated module is imported — confirm that configuring logging there is intended.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

## Julia Wrapper

In [4]:
#| export
# Load the Julia source file into Julia's `Main` module. The wrapper defined below
# calls `Main.differential_correction`, so that function is expected to come from this file.
# NOTE(review): the path is relative — presumably resolved against the current working
# directory; confirm it still resolves when the exported module is imported from elsewhere.
Main.include("../julia/convergence_algorithm.jl")

def differential_correction(
    orbit: np.ndarray,
    μ: float,
    variable_time: bool = True,
    time_flight: float = None,
    jacobi_constant: float = None,
    X_end: np.ndarray = None,
    tol: float = 1e-9,
    max_iter: int = 20,
    printout: bool = False,
    DX_0: np.ndarray = None,
    X_big_0: np.ndarray = None,
    δ: float = None
):
    """
    Call the Julia `differential_correction` routine on a single orbit.

    Python arguments are converted to Julia-friendly values, the Julia function
    is invoked through `Main`, and its first five outputs are converted back to
    NumPy / built-in Python types.

    Parameters:
        orbit (np.ndarray): Orbit data with shape [num_timesteps, 7].
        μ (float): Gravitational parameter.
        variable_time (bool): Whether to use variable time nodes.
        time_flight (float, optional): Total time of flight.
        jacobi_constant (float, optional): Jacobi constant.
        X_end (np.ndarray, optional): Terminal state vector (shape: [6]).
        tol (float): Tolerance for convergence.
        max_iter (int): Maximum number of iterations.
        printout (bool): Whether to print iteration logs.
        DX_0 (np.ndarray, optional): Initial guess for state vector correction.
        X_big_0 (np.ndarray, optional): Auxiliary initial guess.
        δ (float, optional): Step size or perturbation parameter.

    Returns:
        tuple: (X_corrected, t_corrected, norm_F_or_G, iterations, success) where
            X_corrected is reshaped to [-1, 6], t_corrected is a flat array,
            norm_F_or_G is a float, and iterations / success are ints.
    """

    def _scalar_or_empty(value):
        # Unset scalars are forwarded as an empty Python list.
        return [] if value is None else float(value)

    def _vector_or_empty(value):
        # Unset arrays are forwarded as an empty Python list; otherwise as a Julia Vector.
        return [] if value is None else Main.Vector(value.tolist())

    # NOTE(review): an empty list is what reaches Julia for every unset optional
    # argument; presumably the Julia routine treats an empty collection as
    # "not provided" — confirm against convergence_algorithm.jl.
    # The conversions run in the same order as the keyword arguments below,
    # with the orbit matrix converted last.
    flight_arg = _scalar_or_empty(time_flight)
    jacobi_arg = _scalar_or_empty(jacobi_constant)
    end_state_arg = _vector_or_empty(X_end)
    dx0_arg = _vector_or_empty(DX_0)
    x_big0_arg = _vector_or_empty(X_big_0)
    delta_arg = _scalar_or_empty(δ)

    orbit_matrix = Main.Matrix(orbit.tolist())

    outcome = Main.differential_correction(
        orbit_matrix,
        μ,
        variable_time=variable_time,
        time_flight=flight_arg,
        jacobi_constant=jacobi_arg,
        X_end=end_state_arg,
        tol=tol,
        max_iter=max_iter,
        printout=printout,
        DX_0=dx0_arg,
        X_big_0=x_big0_arg,
        δ=delta_arg
    )

    # The Julia result is read positionally:
    # (X_corrected, t_corrected, norm, iterations, success).
    return (
        np.array(outcome[0]).reshape(-1, 6),
        np.array(outcome[1]).flatten(),
        float(outcome[2]),
        int(outcome[3]),
        int(outcome[4]),
    )

In [5]:
#| export
def create_converged_orbits_df(
    converged_indices, 
    orbit_array, 
    converged_orbits, 
    errors, 
    iterations
):
    """
    Creates a DataFrame containing detailed information about converged orbits.

    One row is produced per entry of `converged_indices`. The column set and
    order are fixed, so an empty `converged_indices` yields an empty DataFrame
    that still carries all 17 columns.

    Parameters:
        converged_indices (list): Indices into `orbit_array` of the orbits that converged.
        orbit_array (np.ndarray): Original orbits; each orbit is indexed as
            [timestep, component] with time in component 0 and the state in components 1..6.
        converged_orbits (np.ndarray): Corrected orbits, in the same order as
            `converged_indices` and with the same per-orbit layout.
        errors (np.ndarray): Norm value for each converged orbit (same order).
        iterations (np.ndarray): Iteration count for each converged orbit (same order).

    Returns:
        pd.DataFrame: Columns are `id`, `old_period`, `new_period`,
            `iterations_convergence`, `error`, then `initial_*` (first-node state of the
            original orbit) and `final_*` (first-node state of the corrected orbit).
            The periods are the time value stored at the last timestep of each orbit.

    Raises:
        IndexError: If any positional input is shorter than `converged_indices`
            or an orbit has fewer than 7 components.
    """
    state_labels = ('pos_x0', 'pos_y0', 'pos_z0', 'vel_x0', 'vel_y0', 'vel_z0')
    columns = (
        ['id', 'old_period', 'new_period', 'iterations_convergence', 'error']
        + [f'initial_{label}' for label in state_labels]
        + [f'final_{label}' for label in state_labels]
    )

    rows = []
    for position, orbit_index in enumerate(converged_indices):
        original_orbit = orbit_array[orbit_index]
        corrected_orbit = converged_orbits[position]

        # State at the first node, before and after correction.
        original_state = original_orbit[0, 1:7]
        corrected_state = corrected_orbit[0, 1:7]

        row = {
            'id': orbit_index,
            # The time stamp of the last node is taken as the orbit period.
            'old_period': original_orbit[-1, 0],
            'new_period': corrected_orbit[-1, 0],
            'iterations_convergence': iterations[position],
            'error': errors[position],
        }
        # Explicit indexing (rather than zip) so a short state vector raises IndexError
        # instead of silently producing missing values.
        for k, label in enumerate(state_labels):
            row[f'initial_{label}'] = original_state[k]
        for k, label in enumerate(state_labels):
            row[f'final_{label}'] = corrected_state[k]
        rows.append(row)

    # Passing `columns` keeps the schema stable even when `rows` is empty.
    return pd.DataFrame(rows, columns=columns)

In [6]:
#| export
def process_diferential_correction_orbits(
    orbit_array: np.ndarray, 
    μ: float, 
    variable_time: bool = True, 
    tol: float = 1e-9, 
    max_iter: int = 20, 
    printout: bool = False
):
    """
    Run differential correction on every orbit of `orbit_array` and collect the ones that converge.

    Parameters:
        orbit_array (np.ndarray): Orbit data with shape [num_orbits, num_timesteps, 7];
                                   component 0 of the last axis is treated as time.
        μ (float): Gravitational parameter.
        variable_time (bool, optional): Whether to use variable time nodes for correction. Default is True.
        tol (float, optional): Convergence tolerance forwarded to the corrector. Default is 1e-9.
        max_iter (int, optional): Iteration limit forwarded to the corrector. Default is 20.
        printout (bool, optional): Whether the corrector prints iteration logs. Default is False.

    Returns:
        tuple: (converged_orbits_array, converged_orbits_df)
               - converged_orbits_array: NumPy array of the corrected orbits, each row being
                 [time, 6 state components]; shape [0, num_timesteps, 7] when nothing converged.
               - converged_orbits_df: pandas DataFrame with one row of details per converged orbit
                 (empty, but with all columns, when nothing converged).

    Raises:
        ValueError: If `orbit_array` is not 3-D or its last axis does not have length 7.
    """
    num_orbits, num_timesteps, data_length = orbit_array.shape

    if data_length != 7:
        logger.error("Data length is incorrect. Expected 7 elements (1 time + 6 state components).")
        raise ValueError("Data length is incorrect. Expected 7 elements (1 time + 6 state components).")

    # Parallel accumulators: entry k of each list describes the k-th converged orbit.
    kept_indices = []
    kept_orbits = []
    kept_errors = []
    kept_iterations = []

    for orbit_idx in range(num_orbits):
        logger.info(f"Processing orbit {orbit_idx+1}/{num_orbits}")

        try:
            # NOTE(review): max_iter is cast to float here although the wrapper annotates
            # it as int — confirm which type the Julia routine expects.
            states, times, residual_norm, n_iterations, status = differential_correction(
                orbit=orbit_array[orbit_idx],
                μ=μ,
                variable_time=variable_time,
                tol=tol,
                max_iter=float(max_iter),
                printout=printout
            )
        except Exception as e:
            # A failure on one orbit is logged and does not stop the batch.
            logger.error(f"Orbit {orbit_idx}: Differential correction failed with error: {e}")
            continue

        logger.info(f"Orbit {orbit_idx}: Success={status}, Norm={residual_norm}, Iterations={n_iterations}")

        # Only a status of exactly 1 counts as converged; everything else is dropped.
        if status != 1:
            continue

        # Re-attach the time column in front of the corrected states.
        kept_orbits.append(np.hstack((times[:, np.newaxis], states)))
        kept_errors.append(residual_norm)
        kept_iterations.append(n_iterations)
        kept_indices.append(orbit_idx)

    if not kept_orbits:
        # Nothing converged: return correctly shaped empty containers.
        no_orbits = np.empty((0, num_timesteps, 7))
        no_rows = pd.DataFrame(columns=[
            'id', 'old_period', 'new_period', 'iterations_convergence', 'error',
            'initial_pos_x0', 'initial_pos_y0', 'initial_pos_z0',
            'initial_vel_x0', 'initial_vel_y0', 'initial_vel_z0',
            'final_pos_x0', 'final_pos_y0', 'final_pos_z0',
            'final_vel_x0', 'final_vel_y0', 'final_vel_z0'
        ])
        return no_orbits, no_rows

    stacked_orbits = np.array(kept_orbits)
    summary_df = create_converged_orbits_df(
        converged_indices=kept_indices,
        orbit_array=orbit_array,
        converged_orbits=stacked_orbits,
        errors=np.array(kept_errors),
        iterations=np.array(kept_iterations)
    )
    return stacked_orbits, summary_df

## Python (not working)

In [7]:
import numpy as np
from scipy.integrate import solve_ivp
from typing import List, Tuple, Dict, Optional
from orbit_generation.propagation import jacobi_constant, prop_node
from orbit_generation.constants import EM_MU

In [8]:
class MultipleShooting:
    """
    Iterative orbit adjuster working on orbits stored as [7, num_timesteps] arrays
    (row 0 is time, rows 1..3 position, rows 4..6 velocity).

    NOTE(review): a second class with the same name is defined in a later cell and
    replaces this one once that cell runs.
    """

    def __init__(self, mu: float, max_iterations: int, period: Optional[float] = None, energy: Optional[float] = None):
        """Store the mass parameter, the iteration cap and the optional period / energy targets."""
        self.mu = mu
        self.max_iterations = max_iterations
        self.period = period
        self.energy = energy

    def adjust_orbit(self, X: np.ndarray) -> Tuple[np.ndarray, List[int]]:
        """
        Repeatedly propagate (and optionally energy-scale) each orbit until it passes
        the convergence check or the iteration cap is reached.

        Parameters:
            X (np.ndarray): A single orbit [7, T] or a batch [N, 7, T].

        Returns:
            tuple: (adjusted orbits with shape [N, 7, T], list of indices that converged).

        Raises:
            ValueError: If X is neither 2-D nor 3-D.
        """
        if X.ndim not in (2, 3):
            raise ValueError("Input X must be 2D or 3D array")
        if X.ndim == 2:
            # Promote a single orbit to a batch of one.
            X = X.reshape(1, *X.shape)

        n_orbits, _, n_steps = X.shape
        adjusted = np.copy(X)

        # Without an explicit period, unit time steps are assumed. The value is stored
        # on the instance, so it is reused by later calls.
        if self.period is None:
            self.period = n_steps - 1

        converged = []
        for idx in range(n_orbits):
            for _attempt in range(self.max_iterations):
                adjusted[idx] = self.propagate_orbit(adjusted[idx])

                if self.energy is not None:
                    adjusted[idx] = self.adjust_energy(adjusted[idx])

                if self.check_convergence(self.calculate_errors(adjusted[idx])):
                    converged.append(idx)
                    break

        print(f"{len(converged)} out of {n_orbits} orbits converged.")
        return adjusted, converged

    def propagate_orbit(self, orbit: np.ndarray) -> np.ndarray:
        """
        Rebuild an orbit by propagating every node of the input one step of
        length period / (T - 1) with `prop_node`.

        Column i of the result is the propagation of column i-1 of the *input* orbit;
        the first column is copied unchanged and the last column is then overwritten
        with the first one.
        """
        n_steps = orbit.shape[1]
        dt = self.period / (n_steps - 1)

        rebuilt = np.zeros_like(orbit)
        rebuilt[:, 0] = orbit[:, 0]
        for col in range(1, n_steps):
            previous = orbit[:, col - 1]
            advanced = prop_node(previous[1:], dt, self.mu)
            rebuilt[:, col] = np.concatenate(([previous[0] + dt], advanced))

        # Enforce periodicity by copying the whole first column (time included).
        # NOTE(review): after this line the position/velocity closure errors computed by
        # calculate_errors are zero by construction — confirm that is intended.
        rebuilt[:, -1] = rebuilt[:, 0]
        return rebuilt

    def adjust_energy(self, orbit: np.ndarray) -> np.ndarray:
        """
        Scale the velocity rows in place by sqrt(target / current) and return the same array.

        `current` is element [1] of what `jacobi_constant` returns for the first node.
        """
        current = jacobi_constant(orbit[1:, 0], self.mu)[1]
        orbit[4:, :] *= np.sqrt(self.energy / current)
        return orbit

    def calculate_errors(self, orbit: np.ndarray) -> Dict[str, float]:
        """
        Measure how far the last node is from the first one.

        Returns a dict with 'position_error' and 'velocity_error' (Euclidean norms), plus
        'energy_error' when an energy target is set.
        """
        start, end = orbit[:, 0], orbit[:, -1]
        errors = {
            'position_error': np.linalg.norm(end[1:4] - start[1:4]),
            'velocity_error': np.linalg.norm(end[4:] - start[4:]),
        }
        if self.energy is not None:
            errors['energy_error'] = np.abs(jacobi_constant(orbit[1:, 0], self.mu)[1] - self.energy)
        return errors

    def check_convergence(self, errors: Dict[str, float]) -> bool:
        """Return a truthy value when every tracked error is strictly below 1e-6."""
        position_threshold = 1e-6
        velocity_threshold = 1e-6
        energy_threshold = 1e-6

        kinematics_ok = (errors['position_error'] < position_threshold and
                         errors['velocity_error'] < velocity_threshold)

        if self.energy is None:
            return kinematics_ok
        return kinematics_ok and (errors['energy_error'] < energy_threshold)

In [9]:
class MultipleShooting:
    """
    Multiple-shooting differential corrector for the circular restricted
    three-body problem (CR3BP).

    Internally every node is a 7-vector [t, x, y, z, vx, vy, vz]. The free-variable
    vector stacks all nodes (7 * n entries) followed, when `variable_time` is set,
    by the n - 1 segment durations. The constraint vector stacks the n - 1
    continuity defects, one 7-component boundary condition and, optionally, one
    scalar constraint (time of flight or Jacobi constant).
    """

    def __init__(self, mu: float, max_iterations: int, period: Optional[float] = None, energy: Optional[float] = None):
        """
        Parameters:
            mu (float): CR3BP mass parameter.
            max_iterations (int): Maximum number of Newton iterations per orbit.
            period (float, optional): Target time of flight; also enables variable segment times.
            energy (float, optional): Value forwarded as the target Jacobi constant.
        """
        self.mu = mu
        self.max_iterations = max_iterations
        self.period = period
        self.energy = energy
        self.tolerance = 1e-9

    def adjust_orbit(self, X: np.ndarray) -> Tuple[np.ndarray, List[int]]:
        """
        Correct a batch of orbits.

        Parameters:
            X (np.ndarray): Batch with shape [N, 7, T] (component axis before the node axis).

        Returns:
            tuple: (corrected batch with the same [N, 7, T] layout as floats,
                    indices of the orbits whose correction reported success).
        """
        N, F, T = X.shape
        t_vec = np.linspace(0, self.period or T-1, T)

        X_corrected = np.zeros_like(X, dtype=float)
        converged_indices = []

        for i in range(N):
            # The corrector works node-major ([T, 7]), so transpose on the way in and out.
            nodes, _t_vec, _error, _iterations, success = self.differential_correction(
                X[i].T, t_vec, variable_time=self.period is not None,
                time_flight=self.period, jacobi_constant=self.energy
            )
            X_corrected[i] = nodes.T
            if success == 1:
                converged_indices.append(i)

        return X_corrected, converged_indices

    def differential_correction(self, X_old, t_vec_old, variable_time=True, time_flight=None,
                                jacobi_constant=None, X_end=None, printout=False):
        """
        Newton iteration (minimum-norm update through the pseudo-inverse) on the
        multiple-shooting constraints.

        Parameters:
            X_old (np.ndarray): Initial nodes, shape [n, 7].
            t_vec_old (np.ndarray): Node times, length n.
            variable_time (bool): Treat the segment durations as free variables.
            time_flight (float, optional): Required total time of flight (needs variable_time).
            jacobi_constant (float, optional): Required average Jacobi constant; ignored
                when time_flight is given.
            X_end (np.ndarray, optional): Fixed final node; when omitted the last node
                is required to equal the first one.
            printout (bool): Print the residual norm at each iteration.

        Returns:
            tuple: (nodes, node_times, residual_norm, iterations, status) where status is
                1 on convergence and -1 on divergence or when the iteration cap is hit.
                residual_norm is the norm of the last evaluated constraint vector
                (infinity if no iteration ran).

        Raises:
            ValueError: If time_flight is given while variable_time is False.
        """
        if not variable_time and time_flight is not None:
            raise ValueError("Set variable_time to True to specify the time of flight.")

        n = len(t_vec_old)
        # Defined up front so the final return is valid even when max_iterations <= 0.
        residual = np.inf
        k = 0

        while k < self.max_iterations:
            k += 1
            X_big, F, DF = self.constraints(X_old, t_vec_old, X_end, time_flight, jacobi_constant, variable_time)
            residual = np.linalg.norm(F)

            # A vanishing final time is rejected as a (trivial) solution.
            if residual <= self.tolerance and t_vec_old[-1] > 1e-6:
                if printout:
                    print(f"Converged in {k} iterations")
                return X_old, t_vec_old, residual, k, 1

            if residual >= 10 and k > 1:
                if printout:
                    print(f"Solution diverged after {k} iterations")
                return X_old, t_vec_old, residual, k, -1

            X_big_new = X_big - np.linalg.pinv(DF) @ F

            if variable_time:
                # Node times are rebuilt from the updated segment durations, starting at 0.
                t_vec_new = np.zeros(n)
                t_vec_new[1:] = np.cumsum(X_big_new[7*n:])
                t_vec_old = t_vec_new

            X_old = X_big_new[:7*n].reshape(n, 7).copy()

            if printout:
                print(f"{k} | {residual}")

        return X_old, t_vec_old, residual, self.max_iterations, -1

    def constraints(self, X, t_vec, X_end=None, time_flight=None, jacobi_constant=None, variable_time=True):
        """
        Build the free-variable vector, the constraint vector and its Jacobian.

        Row layout of F / DF: rows [7*i, 7*i+7) hold the continuity defect of segment i
        (i < n-1), rows [7*(n-1), 7*n) hold the boundary condition, and one extra last
        row holds the time-of-flight or Jacobi constraint when requested.
        Column layout: columns [7*i, 7*i+7) belong to node i, columns 7*n + i to the
        duration of segment i (only when variable_time is set).

        Returns:
            tuple: (X_big, F, DF)
        """
        X = np.asarray(X, dtype=float)
        n = len(t_vec)
        T_vec = np.diff(t_vec)

        has_scalar_constraint = time_flight is not None or jacobi_constant is not None
        dim_F = 7 * n + (1 if has_scalar_constraint else 0)

        X_big = np.concatenate([X.flatten(), T_vec] if variable_time else [X.flatten()])
        F = np.zeros(dim_F)
        DF = np.zeros((dim_F, len(X_big)))

        # Continuity: the propagated node i must coincide with node i + 1.
        for i in range(n - 1):
            rows = slice(7 * i, 7 * (i + 1))
            cols = slice(7 * i, 7 * (i + 2))
            Xf_i, Phi = self.get_state(X[i], T_vec[i])
            F[rows] = Xf_i - X[i + 1]
            DF[rows, cols] = np.hstack((Phi, -np.eye(7)))
            if variable_time:
                # Derivative of the propagated node with respect to the segment duration
                # is the vector field evaluated at the propagated state.
                DF[rows, 7 * n + i] = np.concatenate(([1.0], self.dynamics_cr3bp(Xf_i[1:])))

        # Boundary condition on the last node. Explicit slices are used because negative
        # indexing would land on the scalar-constraint row / segment-duration columns.
        last_rows = slice(7 * (n - 1), 7 * n)
        last_cols = slice(7 * (n - 1), 7 * n)
        if X_end is None:
            # NOTE(review): periodicity is imposed on all 7 components, time included —
            # confirm against the reference implementation.
            F[last_rows] = X[-1] - X[0]
            DF[last_rows, last_cols] += np.eye(7)
            DF[last_rows, :7] -= np.eye(7)
        else:
            # NOTE(review): X_end must broadcast against a 7-component node here.
            F[last_rows] = X[-1] - X_end
            DF[last_rows, last_cols] = np.eye(7)

        if time_flight is not None:
            F[-1] = np.sum(T_vec) - time_flight
            DF[-1, 7 * n:] = 1.0
        elif jacobi_constant is not None:
            # The sum of the nodes' Jacobi constants must equal n times the target.
            J = 0
            for i in range(n):
                _, J_i, DJ_i = self.jacobi(X[i, 1:])
                J += J_i
                DF[-1, i*7+1:(i+1)*7] = DJ_i
            F[-1] = J - n * jacobi_constant

        return X_big, F, DF

    def _propagate_with_stm(self, state, dt):
        """Integrate a 6-state and its 6x6 state-transition matrix over `dt`; return both at the end time."""
        def variational_equations(_t, y):
            phi = y[6:].reshape(6, 6)
            dxdt = self.dynamics_cr3bp(y[:6])
            dphi_dt = self.compute_jacobian(y[:6]) @ phi
            return np.concatenate([dxdt, dphi_dt.flatten()])

        y0 = np.concatenate([state, np.eye(6).flatten()])
        sol = solve_ivp(variational_equations, [0, dt], y0,
                        rtol=1e-12, atol=1e-12, method='Radau')
        y_end = sol.y[:, -1]
        return y_end[:6], y_end[6:].reshape(6, 6)

    @staticmethod
    def _embed_stm(stm):
        """Embed a 6x6 state-transition matrix into the 7x7 matrix that also carries time."""
        Phi = np.eye(7)
        Phi[1:, 1:] = stm
        return Phi

    def get_state(self, X0, dt):
        """
        Propagate one node over `dt`.

        Parameters:
            X0 (array-like): Node [t, x, y, z, vx, vy, vz].
            dt (float): Segment duration.

        Returns:
            tuple: (Xf, Phi) with Xf the propagated node (time advanced by dt) and
                Phi the 7x7 sensitivity of Xf with respect to X0.
        """
        t = X0[0]
        state_f, stm = self._propagate_with_stm(X0[1:], dt)
        Xf = np.concatenate(([t + dt], state_f))
        return Xf, self._embed_stm(stm)

    def dynamics_cr3bp(self, X):
        """Return the CR3BP vector field [vx, vy, vz, ax, ay, az] at the 6-state X (rotating frame)."""
        x, y, z, v_x, v_y, v_z = X
        r1 = np.sqrt((x + self.mu)**2 + y**2 + z**2)
        r2 = np.sqrt((x - (1 - self.mu))**2 + y**2 + z**2)
        x_dot  = v_x
        y_dot  = v_y
        z_dot  = v_z
        x_ddot = x + 2 * v_y - (1 - self.mu) * (x + self.mu) / r1**3 - self.mu * (x - (1 - self.mu)) / r2**3
        y_ddot = y - 2 * v_x - y * ((1 - self.mu) / r1**3 + self.mu / r2**3)
        z_ddot = -z * ((1 - self.mu) / r1**3 + self.mu / r2**3)
        return np.array([x_dot, y_dot, z_dot, x_ddot, y_ddot, z_ddot])

    def jacobi(self, X):
        """
        Jacobi constant of a 6-state and its gradient.

        Returns:
            tuple: (0, J, DJ) where J = -2 * (kinetic + effective potential) and DJ is the
                gradient of J with respect to [x, y, z, vx, vy, vz]. The leading 0 is a
                placeholder kept for the existing three-value return shape.
        """
        x, y, z, xp, yp, zp = X
        mu1 = 1 - self.mu
        mu2 = self.mu
        r1 = np.sqrt((x + mu2)**2 + y**2 + z**2)
        r2 = np.sqrt((x - mu1)**2 + y**2 + z**2)
        K = 0.5 * (xp**2 + yp**2 + zp**2)
        Ubar = -0.5 * (x**2 + y**2) - mu1 / r1 - mu2 / r2 - 0.5 * mu1 * mu2
        E = K + Ubar
        J = -2 * E
        # Analytic gradient of J = (x^2 + y^2) + 2*mu1/r1 + 2*mu2/r2 + mu1*mu2 - |v|^2.
        DJ = np.array([
            2 * x - 2 * mu1 * (x + mu2) / r1**3 - 2 * mu2 * (x - mu1) / r2**3,
            2 * y - 2 * mu1 * y / r1**3 - 2 * mu2 * y / r2**3,
            -2 * mu1 * z / r1**3 - 2 * mu2 * z / r2**3,
            -2 * xp,
            -2 * yp,
            -2 * zp,
        ])
        return 0, J, DJ

    def compute_stm(self, X0, dt):
        """Return the 7x7 sensitivity matrix (time row/column plus the 6x6 STM) for the 6-state X0 over `dt`."""
        _, stm = self._propagate_with_stm(X0, dt)
        return self._embed_stm(stm)

    def compute_jacobian(self, X):
        """Return the 6x6 Jacobian of the CR3BP vector field at the 6-state X."""
        x, y, z, _, _, _ = X
        mu1 = 1 - self.mu
        mu2 = self.mu
        r1 = np.sqrt((x + mu2)**2 + y**2 + z**2)
        r2 = np.sqrt((x - mu1)**2 + y**2 + z**2)

        # Second derivatives of the pseudo-potential.
        Uxx = 1 - mu1/r1**3 - mu2/r2**3 + 3*mu1*(x+mu2)**2/r1**5 + 3*mu2*(x-mu1)**2/r2**5
        Uyy = 1 - mu1/r1**3 - mu2/r2**3 + 3*mu1*y**2/r1**5 + 3*mu2*y**2/r2**5
        Uzz = -mu1/r1**3 - mu2/r2**3 + 3*mu1*z**2/r1**5 + 3*mu2*z**2/r2**5
        Uxy = 3*mu1*(x+mu2)*y/r1**5 + 3*mu2*(x-mu1)*y/r2**5
        Uxz = 3*mu1*(x+mu2)*z/r1**5 + 3*mu2*(x-mu1)*z/r2**5
        Uyz = 3*mu1*y*z/r1**5 + 3*mu2*y*z/r2**5

        return np.array([
            [0, 0, 0, 1, 0, 0],
            [0, 0, 0, 0, 1, 0],
            [0, 0, 0, 0, 0, 1],
            [Uxx, Uxy, Uxz, 0, 2, 0],
            [Uxy, Uyy, Uyz, -2, 0, 0],
            [Uxz, Uyz, Uzz, 0, 0, 0]
        ])

In [10]:
# Example usage with both the period and the energy target supplied.
# `MultipleShooting` here is the most recent definition of the class (the one from the
# cell directly above), since that definition rebinds the name.
mu = EM_MU
period = 2 * np.pi
# NOTE(review): `energy` is forwarded by adjust_orbit as the `jacobi_constant` target,
# and that target is only used when no time of flight is given — confirm that -1.5 is
# meant as a Jacobi constant and that it should be ignored when `period` is set.
energy = -1.5
max_iterations = 20

ms_full = MultipleShooting(mu, max_iterations, period, energy)
# The actual run is left disabled; `generation` is not defined in this notebook.
# adjusted_orbits_full = ms_full.adjust_orbit(generation)

In [11]:
#| hide
# Run nbdev's export step; presumably this writes the cells marked `#| export`
# to the module named by the `#| default_exp` directive at the top of the notebook.
import nbdev; nbdev.nbdev_export()