In [None]:
from jax import config

# Enable 64-bit floats in JAX (32-bit is the default). A later cell adds
# Earth-radius-scale offsets to AU-scale positions and explicitly requests
# float64 arrays, which only takes effect with this flag set.
config.update("jax_enable_x64", True)

import jax.numpy as jnp
import matplotlib.pyplot as plt

from jax_gw.detector.orbits import (
    EARTH_TILT,
    axial_tilt,
    create_cartwheel_arm_lengths,
    create_cartwheel_orbit,
    create_circular_orbit_xy,
    ecliptic_timeshift,
    equatorial_timeshift,
    flatten_pairs,
    get_arm_lengths,
    get_receiver_positions,
    get_separations,
    lat_lon_to_cartesian,
    earthbound_ifo_pipeline,
)

In [None]:
# Unit conversion factor. NOTE(review): despite its name, the value is the
# number of billion meters in one AU (1 AU ~ 149.6e9 m); it is used below as
# a multiplier on distances that are presumably expressed in AU.
AU_per_billion_meters = 149.597871

# Target constellation parameters.
L_target = 2.5  # presumably the target arm length in billion meters
R_target = 1.0  # presumably the orbital radius in AU
N_LISA = 3  # number of spacecraft

# Eccentricity chosen so that the arm length matches L_target at radius
# R_target: ecc = L / (2 * sqrt(3) * R), with L converted to the units of R.
ecc = L_target / (AU_per_billion_meters * 2 * jnp.sqrt(3) * R_target)

In [None]:
# 1000 evenly spaced samples on [0, 1] (presumably one year — TODO confirm the time unit).
times = jnp.linspace(0, 1, 1000)
# Spacecraft positions at every sample; indexed below as orbits[spacecraft, coordinate, time].
orbits = create_cartwheel_orbit(ecc, R_target, N_LISA, times)

In [None]:
# Display the array shape as a quick sanity check.
orbits.shape

In [None]:
# 3D view of the trajectories, one curve per spacecraft.
fig = plt.figure()
ax = fig.add_subplot(111, projection="3d")
for i in range(N_LISA):
    # Second axis of `orbits` holds the three Cartesian coordinates.
    xs, ys, zs = orbits[i, :3]
    ax.plot(xs, ys, zs, label=f"Spacecraft {i + 1}")

# Use the same limits on all three axes so the orbits are not visually distorted.
axis_limits = (-1.2, 1.2)
ax.set_xlim(*axis_limits)
ax.set_ylim(*axis_limits)
ax.set_zlim(*axis_limits)
ax.legend()
plt.show()


In [None]:
# Pairwise separations between the spacecraft, then the arm lengths derived
# from them (presumably the norms of the separation vectors — TODO confirm).
r = get_separations(orbits)
L = get_arm_lengths(r)
# Print both shapes, separations first, as a sanity check.
for pairwise in (r, L):
    print(pairwise.shape)

In [None]:
# Pick out the three distinct spacecraft pairs (1-2, 1-3, 2-3).
# NOTE(review): this indexing assumes the pair indices are the last two axes
# of L and the leading axis is time, whereas arm_lengths[0, :] further down
# treats time as the trailing axis — confirm the layout returned by get_arm_lengths.
L_12, L_13, L_23 = (L[:, i, j] for i, j in ((0, 1), (0, 2), (1, 2)))

In [None]:
# Cross-check: Euclidean distance between the first two spacecraft at the first
# time sample, scaled by AU_per_billion_meters (presumably converting AU to
# billion meters — TODO confirm the units of `orbits`).
jnp.linalg.norm(orbits[0,:,0]- orbits[1,:,0]) * AU_per_billion_meters

In [None]:
# Analytic counterpart of the arm lengths computed numerically above,
# evaluated with the same parameters and time samples.
d_analytic = create_cartwheel_arm_lengths(ecc, R_target, N_LISA, times)
d_analytic.shape


In [None]:
# Same three spacecraft pairs (1-2, 1-3, 2-3) as for the numerical arm lengths.
# NOTE(review): assumes the pair indices are the last two axes of d_analytic — confirm.
L12_analytic, L13_analytic, L23_analytic = (
    d_analytic[:, i, j] for i, j in ((0, 1), (0, 2), (1, 2))
)

In [None]:
# Plot the numerical arm lengths against the analytic ones.
# Each curve gets its own legend entry; previously all six curves shared the
# two labels "numerical" / "analytic", so the legend could not tell arms apart.
fig = plt.figure()
ax = fig.add_subplot(111)
pair_names = ("12", "13", "23")
for name, dist in zip(pair_names, (L_12, L_13, L_23)):
    ax.plot(
        times,
        dist * AU_per_billion_meters,
        linestyle="-",
        linewidth=3,
        label=f"L{name} numerical",
    )
for name, dist in zip(pair_names, (L12_analytic, L13_analytic, L23_analytic)):
    ax.plot(
        times,
        dist * AU_per_billion_meters,
        linestyle="--",
        label=f"L{name} analytic",
    )
ax.set_xlabel("Time")
# The AU_per_billion_meters factor rescales the arm lengths; assuming they are
# in AU, the plotted values are in billion meters — TODO confirm.
ax.set_ylabel("Separation (billion meters)")
ax.legend()
plt.tight_layout()
plt.show()


In [None]:
# Receiver positions derived from the spacecraft orbits (presumably one entry
# per ordered spacecraft pair — TODO confirm against get_receiver_positions).
receiver_orbits = get_receiver_positions(orbits)
print(receiver_orbits.shape)
# Collapse the pair axes into a single axis and display the resulting shape.
receiver_positions = flatten_pairs(receiver_orbits)
receiver_positions.shape

In [None]:
# Same call as the one that produced `r` above, stored under a descriptive name.
separations = get_separations(orbits)
# Shape after flattening the pair axes.
flatten_pairs(separations).shape

In [None]:
# Shape of the arm-length array after flattening the pair axes.
flatten_pairs(L).shape

In [None]:
%timeit flatten_pairs(L)

In [None]:
# Earth-bound detector: orbit of the guiding center plus the daily rotation.
FREQ_CENTER_ORBIT = 1  # in 1/year
# NOTE(review): 365.25 is the number of solar days per year; the rotation rate
# relative to a non-rotating frame would be about 366.25 per year — confirm
# which one is intended here.
FREQ_ROTATION = 365.25  # in 1/year

# Sample exactly one rotation period.
times = jnp.linspace(0, 1 / FREQ_ROTATION, 1000)
r = 1.0  # in AU
r_orbital = create_circular_orbit_xy(r, FREQ_CENTER_ORBIT, times)

# Detector latitude and longitude, converted from degrees to radians.
detector_lat = 46.455140209119214 * jnp.pi / 180
detector_lon = -119.40746331631823 * jnp.pi / 180

# Cartesian position of the detector with respect to the guiding center at t=0.
r_detector_initial_equatorial = lat_lon_to_cartesian(detector_lat, detector_lon)
print(r_detector_initial_equatorial)

# Rotation phase accumulated at each time sample, then the rotated position.
hour_angle = 2.0 * jnp.pi * FREQ_ROTATION * times
r_detector = equatorial_timeshift(r_detector_initial_equatorial, hour_angle)
r_detector.shape

In [None]:
# Apply the Earth's axial tilt to the rotating detector position.
# The equatorial position is recomputed from its inputs instead of reading
# `r_detector` back, so re-running this cell does not apply the tilt twice.
r_detector = axial_tilt(
    equatorial_timeshift(r_detector_initial_equatorial, hour_angle),
    EARTH_TILT,
)
# Tilted position at the first time sample.
print(r_detector[:,0])
r_detector.shape

In [None]:
# Plot the first three components of the detector position against time.
for component, axis_label in zip(r_detector, ("x", "y", "z")):
    plt.plot(times, component, label=axis_label)
plt.legend()
plt.show()

In [None]:
L_arm = 4  # in km
r_earth_in_km = 6371.0

# Arm orientation: psi is the rotation angle of the x arm away from local
# east, beta_arm the additional rotation from the x arm to the y arm.
psi_H = (90 + 36) * jnp.pi / 180
psi_L = None
psi = psi_H
beta_arm = jnp.pi / 2.0


def _rodrigues_rotation(cross_matrix, angle):
    """Build a rotation matrix from a cross-product matrix via Rodrigues' formula."""
    return (
        jnp.eye(3)
        + jnp.sin(angle) * cross_matrix
        + (1 - jnp.cos(angle)) * cross_matrix @ cross_matrix
    )


# Local east at the detector: north pole crossed with the detector position, normalized.
north_pole_equatorial = jnp.array([0.0, 0.0, 1.0])
east_unnormalized = jnp.cross(north_pole_equatorial, r_detector_initial_equatorial)
local_east = east_unnormalized / jnp.linalg.norm(east_unnormalized)

# Cross-product matrix of the detector position, which is the rotation axis
# (assumes r_detector_initial_equatorial is a unit vector — TODO confirm).
r_x, r_y, r_z = r_detector_initial_equatorial
K_matrix = jnp.array(
    [
        [0.0, -r_z, r_y],
        [r_z, 0.0, -r_x],
        [-r_y, r_x, 0.0],
    ]
)
rotation_matrix_psi = _rodrigues_rotation(K_matrix, psi)
rotation_matrix_beta = _rodrigues_rotation(K_matrix, beta_arm)

# Rotate east by psi to get the x arm, then the x arm by beta_arm to get the y arm.
x_arm_direction = rotation_matrix_psi @ local_east
y_arm_direction = rotation_matrix_beta @ x_arm_direction

print(x_arm_direction)
print(y_arm_direction)

# Arm length expressed in Earth radii, applied to both unit directions.
arm_length = L_arm / r_earth_in_km
x_arm_local_equatorial_initial = arm_length * x_arm_direction
y_arm_local_equatorial_initial = arm_length * y_arm_direction

# Apply the axial tilt to move both arm vectors into ecliptic coordinates.
x_arm_ecliptic_initial = axial_tilt(x_arm_local_equatorial_initial, +EARTH_TILT)
print(x_arm_ecliptic_initial)
y_arm_ecliptic_initial = axial_tilt(y_arm_local_equatorial_initial, +EARTH_TILT)
print(y_arm_ecliptic_initial)


In [None]:
# Show the detector location and arm directions on a 2D latitude/longitude map.
plt.figure()
ax = plt.gca()

# Detector coordinates in degrees, shared by the marker and all arrows.
lon_deg = detector_lon * 180 / jnp.pi
lat_deg = detector_lat * 180 / jnp.pi

plt.scatter(lon_deg, lat_deg, marker="o", color="red", label="detector", s=1)

# Angle of each arm with respect to local east; the sign is taken from the
# z component of the arm direction (positive z is treated as northward).
abs_phi_x_arm = jnp.arccos(jnp.dot(x_arm_direction, local_east))
abs_phi_y_arm = jnp.arccos(jnp.dot(y_arm_direction, local_east))
phi_x_arm = jnp.sign(x_arm_direction[2]) * abs_phi_x_arm
phi_y_arm = jnp.sign(y_arm_direction[2]) * abs_phi_y_arm
print(phi_x_arm * 180 / jnp.pi)
print(phi_y_arm * 180 / jnp.pi)

# Arrow components in the (longitude, latitude) plane.
x_arm_lat_lon_dir_x = jnp.cos(phi_x_arm)
x_arm_lat_lon_dir_y = jnp.sin(phi_x_arm)
y_arm_lat_lon_dir_x = jnp.cos(phi_y_arm)
y_arm_lat_lon_dir_y = jnp.sin(phi_y_arm)

# Each arm is drawn as two opposite arrows. The y-arm arrows are anchored at
# their tip, the x-arm arrows use the default pivot.
arrow_specs = (
    (x_arm_lat_lon_dir_x, x_arm_lat_lon_dir_y, {}),
    (-x_arm_lat_lon_dir_x, -x_arm_lat_lon_dir_y, {}),
    (y_arm_lat_lon_dir_x, y_arm_lat_lon_dir_y, {"pivot": "tip"}),
    (-y_arm_lat_lon_dir_x, -y_arm_lat_lon_dir_y, {"pivot": "tip"}),
)
for arrow_u, arrow_v, extra_kwargs in arrow_specs:
    plt.quiver(
        lon_deg,
        lat_deg,
        arrow_u,
        arrow_v,
        color="red",
        width=0.005,
        **extra_kwargs,
    )

plt.xlim(-180, 180)
plt.ylim(-90, 90)

plt.xlabel("Longitude (degrees)")
plt.ylabel("Latitude (degrees)")
plt.legend()
plt.grid(linestyle="--")
plt.show()



In [None]:
# Propagate both initial arm vectors through the rotation, using the same
# hour angles and tilt for each.
x_arm, y_arm = (
    ecliptic_timeshift(initial_arm, hour_angle, EARTH_TILT)
    for initial_arm in (x_arm_ecliptic_initial, y_arm_ecliptic_initial)
)

In [None]:
# Plot the first three components of both arm vectors against time.
for arm_name, arm in (("x_arm", x_arm), ("y_arm", y_arm)):
    for axis_name, component in zip("xyz", arm):
        plt.plot(times, component, label=f"{arm_name} ({axis_name})")
plt.legend()
plt.show()

In [None]:
# Add the rotation around the guiding center, treating the Earth as a solid body.
# NOTE(review): despite its name, AU_per_earth_radius is the number of Earth
# radii in one AU; dividing by it converts Earth radii to AU.
earth_radius_per_km = 6371.0
meters_per_AU = AU_per_billion_meters * 1e9
meters_per_earth_radius = earth_radius_per_km * 1e3
AU_per_earth_radius = meters_per_AU / meters_per_earth_radius
print(AU_per_earth_radius)

# Beam splitter position: orbital position plus the rescaled detector offset,
# held in double precision.
r_beam_splitter = jnp.asarray(
    r_orbital + r_detector / AU_per_earth_radius, dtype=jnp.float64
)

# Arm end points: rescale the arm vectors and attach them to the beam splitter.
# NOTE(review): x_arm and y_arm are overwritten in place, so this cell is not
# idempotent — re-running it rescales and shifts the arms a second time.
x_arm = r_beam_splitter + jnp.asarray(x_arm, dtype=jnp.float64) / AU_per_earth_radius
y_arm = r_beam_splitter + jnp.asarray(y_arm, dtype=jnp.float64) / AU_per_earth_radius

# Stack the beam splitter, the x arm, and the y arm into a 3x3xlen(times) array.
orbits = jnp.stack((r_beam_splitter, x_arm, y_arm), axis=0)
orbits.dtype

In [None]:
# Display the x-arm end-point positions (the full array is shown).
x_arm

In [None]:
# Positions relative to the beam splitter, projected onto the x-y plane.
rot_orbits = orbits - r_beam_splitter
for vertex, arm_label in ((1, 'x arm'), (2, 'y arm')):
    plt.plot(rot_orbits[vertex, 0, :], rot_orbits[vertex, 1, :], label=arm_label)
plt.legend()
plt.show()

In [None]:
# Pairwise separations between the three detector vertices stacked in `orbits`.
separations = get_separations(orbits)
# Shape after flattening the pair axes.
flatten_pairs(separations).shape

In [None]:
# Arm lengths for every vertex pair, flattened so the pairs share one axis
# (indexed below as arm_lengths[pair, time]).
arm_lengths = flatten_pairs(get_arm_lengths(separations))

In [None]:
# Compare the two arm lengths with the length between the two arm ends, which
# should equal sqrt(2) times an arm length for a right-angle detector.
# NOTE(review): which flattened index belongs to which pair depends on
# flatten_pairs' ordering — confirm indices 0, 2 and 4.
curves = (
    (arm_lengths[0, :], 'x arm length', {"linewidth": 3}),
    (arm_lengths[2, :], 'y arm length', {}),
    (arm_lengths[4, :], 'unphysical arm', {"linewidth": 5, "alpha": 0.5}),
    (jnp.sqrt(2) * arm_lengths[0, :], 'x arm * sqrt2', {"linestyle": '--'}),
)
for values, curve_label, line_kwargs in curves:
    plt.plot(times, values, label=curve_label, **line_kwargs)
plt.legend()
plt.show()

In [None]:
# Rebuild the detector vertex positions with the packaged pipeline, using the
# same inputs as the step-by-step construction in the cells above.
# (Presumably it reproduces that construction — TODO confirm against the library.)
orbits = earthbound_ifo_pipeline(
    lat=detector_lat,
    lon=detector_lon,
    times=times,
    r=r,
    L_arm=L_arm,
    psi=psi,
    beta_arm=beta_arm,
)
orbits.shape
    

In [None]:
# Arm vectors at the first time sample: each arm end minus the beam splitter.
x_arm = orbits[1, :, 0] - orbits[0, :, 0]
y_arm = orbits[2, :, 0] - orbits[0, :, 0]

# Angle between the arms from the normalized dot product, printed in degrees.
cos_beta = jnp.dot(x_arm, y_arm) / (jnp.linalg.norm(x_arm) * jnp.linalg.norm(y_arm))
angle_beta = jnp.arccos(cos_beta)
print(angle_beta * 180 / jnp.pi)