In [None]:
import jax
import jax.numpy as jnp
import scipy.constants as sc
from typing import NamedTuple




In [None]:
class CircularSatelliteState(NamedTuple):
    """A class to represent the state of a satellite.

    Attributes:
        position: A 3D vector representing the position of the satellite.
        velocity: A 3D vector representing the velocity of the satellite.
        inclination: The inclination of the orbit in radians.
        right_ascension: The right ascension of the orbit in radians.
        eccentricity: The eccentricity of the orbit.
        true_anomaly: The true anomaly of the orbit in radians.
        semi_major_axis: The semi-major axis of the orbit in meters.
    """
    position: jax.numpy.ndarray
    velocity: jax.numpy.ndarray
    
    semi_major_axis: float
    inclination: float = 0.0
    right_ascension: float = 0.0
    eccentricity: float = 0
    true_anomaly: float = 0.0
    period : float = 0.0
    Re : float = 6.371e6  # meters
    Me : float =  5.972e24  # kg

    Specific_gravity : float = sc.G

    
    

def initialize_circular_orbit(semi_major_axis: float, true_anomaly = 0.0, right_ascension : float = 0, inclination: float = 0.0, ) -> CircularSatelliteState:
    """Initialize the state of a circular orbit.

    Args:
        semi_major_axis: The semi-major axis of the orbit in meters.
        true_anomaly: The true anomaly of the orbit in radians.
        right_ascension: The right ascension of the orbit in radians.
        inclination: The inclination of the orbit in radians.
    Returns:
        A CircularSatelliteState object representing the state of the satellite.
    """
    Re = 6.371e6  # meters
    Me =  5.972e24  # kg

    # Ensure the semi-major axis is positive
    assert semi_major_axis > 0, "Semi-major axis must be positive."

    period = 2 * sc.pi * jnp.sqrt((semi_major_axis+Re)**3/(sc.G*Me))	# Orbital period 

    assert period > 0, "Period must be positive."

    im_state = CircularSatelliteState(position=jnp.array([0,0,0]), true_anomaly = true_anomaly,period=period, velocity=jnp.array([0,0,0]), 
                                      right_ascension=right_ascension, inclination=inclination, semi_major_axis=semi_major_axis)


    state = propegate(im_state, 0.0)

    assert jnp.all(jnp.isfinite(state.position)), "Position is not finite."
    assert jnp.all(jnp.isfinite(state.velocity)), "Velocity is not finite."

    assert jnp.isclose(jnp.linalg.norm(state.position), state.semi_major_axis, atol=1e-5), f"Position is not on the orbit."
    assert jnp.isclose(jnp.linalg.norm(state.velocity), jnp.sqrt(sc.G*state.Me/(state.semi_major_axis)), atol=1), "Velocity is not on the orbit."


    return state


def propegate(state: CircularSatelliteState, delta_t: float) -> CircularSatelliteState:
    """Propegate the orbit by a given time.

    Args:
        state: A CircularSatelliteState object representing the current state of the satellite.
        delta_t: The timeo over

    Returns:
        A new CircularSatelliteState object with the rotated orbit.
    """
    pi = 3.141592653589793


    # Propegate the position in the circle, given delta t
    true_anomaly = (state.true_anomaly+2 *pi * delta_t / state.period)%(2*pi)

    # Position vector
    circle_position = state.semi_major_axis * jnp.array([jnp.cos(state.true_anomaly), jnp.sin(state.true_anomaly), 0])
   
    # Velocity vector
    circle_velocity = jnp.array([-state.semi_major_axis * jnp.sin(state.true_anomaly), state.semi_major_axis * jnp.cos(state.true_anomaly), 0])

    # Rotate the circle postion vector by the inclination and right ascension
    inclination_matrix = jnp.array([[1, 0, 0],
                                     [0, jnp.cos(state.inclination), -jnp.sin(state.inclination)],
                                     [0, jnp.sin(state.inclination), jnp.cos(state.inclination)]])

    # Rotation matrix of the satellite orbit in around the z-axis 
    right_ascension_matrix = jnp.array([[jnp.cos(state.right_ascension), -jnp.sin(state.right_ascension), 0],
                                  [jnp.sin(state.right_ascension), jnp.cos(state.right_ascension), 0],
                                  [0, 0, 1]])
    
    
    # Rotate the position vector by the inclination and right ascension
    new_position = circle_position @ inclination_matrix @ right_ascension_matrix
    # Rotate the velocity vector by the inclination and right ascension
    new_velocity = circle_velocity @ inclination_matrix @ right_ascension_matrix




    return state._replace(true_anomaly = true_anomaly,position=new_position, velocity=new_velocity)

# Test code

In [44]:
simple_500_km_orbit = initialize_circular_orbit(500e3, 0,0.15,0.15)

print("Position:", simple_500_km_orbit.position)

print("Velocity:", simple_500_km_orbit.velocity)

simple_500_km_orbit = propegate(simple_500_km_orbit, 1000)

print("Position after 1000 seconds:", simple_500_km_orbit.position)
print("Velocity after 1000 seconds:", simple_500_km_orbit.velocity)

AssertionError: Velocity is not on the orbit.