In [None]:
import jax
from jax import config

import jax.numpy as jnp

from jax_gw.detector.orbits import (
    EARTH_TILT,
    EARTH_Z_LAT,
    EARTH_Z_LON,
    create_cartwheel_orbit,
    flat_index,
    flatten_pairs,
    get_arm_lengths,
    get_receiver_positions,
    get_separations,
    lat_lon_to_cartesian,
    path_from_indices,
    earthbound_ifo_pipeline,
)
from jax_gw.detector.response import (
    antenna_pattern, 
    transfer_function, 
    response_function, 
    get_path_response,
    get_differential_strain_response,
    jitted_vmapped_transfer_function,
    sky_vmapped_antenna_pattern,
    response_pipe,
)

from jax_gw.signal.plotting import plot_response
from jax_gw.sources.waveforms import WD_binary_source

from jax_gw.detector.pixel import (
    get_directional_basis,
    flat_to_matrix_sky_indices,
    unflatten_sky_axis,
    get_sph_harm_values,
    get_solid_angle_theta_phi,
    pixel_to_lm,
)

# Enable 64-bit (double) precision in JAX; by default JAX uses 32-bit types.
config.update("jax_enable_x64", True)

# Small toy time/frequency grids for the white-dwarf binary example.
N_times = 2
N_freqs = 5
# NOTE(review): units of t_obs / times are not stated here — confirm.
t_obs = 3.16e-5
times = jnp.linspace(0, t_obs, N_times)
freqs = jnp.linspace(1e-4, 2e-3, N_freqs)
# Source parameters passed positionally to WD_binary_source (names suggest
# amplitude, initial frequency, frequency derivative, initial phase).
A = 1.0
f_0 = 1e-3
f_0_dot = 1e-10
phi_0 = 0.0
# Indexed below as waveforms[:, 0] / waveforms[:, 1], i.e. (time, 2) —
# presumably the two polarizations.
waveforms = WD_binary_source(A, f_0, f_0_dot, phi_0, times)

In [None]:
waveforms.shape

In [None]:
import matplotlib.pyplot as plt

# Plot both waveform columns against time (presumably the + and x
# polarizations — confirm against WD_binary_source).
plt.plot(times, waveforms[:, 0])
plt.plot(times, waveforms[:, 1])
plt.show()

In [None]:
# Use an independent PRNG key for every draw. Reusing PRNGKey(0) for both
# angles would make enc_beta and enc_lambda affine functions of the *same*
# uniform sample, so the "random" source direction would be confined to a
# 1-D curve on the sky instead of being random in 2-D.
key_arm, key_beta, key_lambda = jax.random.split(jax.random.PRNGKey(0), 3)

# create random 3D unit vector
arm_direction = jax.random.normal(key_arm, (3,))
arm_direction = arm_direction / jnp.linalg.norm(arm_direction)

# create random source direction: beta in [-pi/2, pi/2], lambda in [-pi, pi].
# NOTE(review): get_directional_basis is later called with colatitudes in
# (0, pi); here the first angle is drawn in [-pi/2, pi/2] — confirm convention.
enc_beta = jax.random.uniform(key_beta, (1,), minval=-jnp.pi / 2.0, maxval=jnp.pi / 2.0).item()
enc_lambda = jax.random.uniform(key_lambda, (1,), minval=-jnp.pi, maxval=jnp.pi).item()

enc_beta, enc_lambda

In [None]:
# Directional basis for the single random direction: k plus the pair (u, v)
# that antenna_pattern consumes (presumably the transverse plane — confirm).
k, u, v = get_directional_basis(enc_beta, enc_lambda)
antenna = antenna_pattern(u, v, arm_direction)
antenna.shape

In [None]:
# vmap over the arm direction only (third argument, axis 0); u and v are shared.
time_vmapped_antenna_pattern = jax.vmap(antenna_pattern, in_axes=(None, None, 0), out_axes=0)

In [None]:
# One random unit arm direction per time sample, shape (N_times, 3).
arm_directions = jax.random.normal(jax.random.PRNGKey(0), (N_times, 3))
arm_directions = arm_directions / jnp.linalg.norm(arm_directions, axis=-1, keepdims=True)

antennae = time_vmapped_antenna_pattern(u, v, arm_directions)
antennae.shape

In [None]:
# Same inputs without vmap — presumably relies on antenna_pattern broadcasting
# over leading axes; compare this shape with the vmapped one above.
antennas = antenna_pattern(u, v, arm_directions)
antennas.shape

In [None]:
# enc_betas = jax.random.uniform(jax.random.PRNGKey(0), (50,), minval=-jnp.pi / 2.0, maxval=jnp.pi / 2.0)
# enc_lambdas = jax.random.uniform(jax.random.PRNGKey(0), (50,), minval=-jnp.pi, maxval=jnp.pi)
# or linspace

# Sky grid: N_theta colatitudes x N_phi longitudes, flattened to N_sky pixels.
N_theta = 100
N_phi = 120
N_sky = N_theta * N_phi
# enc_cos_betas = jnp.linspace(1.0, -1.0, N_beta)
# enc_betas_reduced = jnp.arccos(enc_cos_betas)
delta_phi = 2 * jnp.pi / N_phi
# Colatitudes are uniform in theta (NOT in cos(theta)) and stay 1/N_theta away
# from the poles, so pixels do not all subtend the same solid angle.
ecl_thetas_reduced = jnp.linspace(1/N_theta, jnp.pi-1/N_theta, N_theta)
# Longitudes uniform in [0, 2*pi); [:N_phi] guards against arange returning an
# extra point because of floating-point round-off in the step.
ecl_phis_reduced = jnp.arange(0, 2 * jnp.pi, delta_phi)[:N_phi]
print(max(ecl_phis_reduced), min(ecl_phis_reduced))

# Map each flat sky index to its (theta index, phi index) and expand the
# reduced grids into per-pixel angles.
flat_to_m_sky = flat_to_matrix_sky_indices(N_theta, N_phi)
ecl_thetas = ecl_thetas_reduced[flat_to_m_sky[:,0]]
ecl_phis = ecl_phis_reduced[flat_to_m_sky[:,1]]
# Per-pixel directional basis, reused throughout the notebook.
sky_basis = get_directional_basis(ecl_thetas, ecl_phis)
k_hat, u_hat, v_hat = sky_basis
print(k_hat.shape, u_hat.shape, v_hat.shape)

In [None]:
# now vmap over both the direction in the sky (beta, lambda) and the arm direction (arm_direction)

antennae = sky_vmapped_antenna_pattern(u_hat, v_hat, arm_directions)
antennae.shape

In [None]:
waveforms.shape

In [None]:
# Contract the last axis of the waveforms with the last axis of the antenna
# pattern (presumably polarization), broadcasting the leading axes.
jnp.einsum("...i,...i->...", waveforms, antennae).shape

In [None]:
# Transfer function for the single direction k at one frequency, with a single
# arm vector 1e2 * (1, 1, 1).
transfer_function(k, freq=1E-3, arms=1E2*jnp.array([1.0, 1.0, 1.0]))

# LISA

In [None]:
# NOTE: despite the name, this is the number of 1e9 m in one AU
# (1 AU = 149.597871e9 m).
AU_per_billion_meters = 149.597871
L_target = 2.5  # arm length in units of 1e9 m
R_target = 1.0  # orbital radius (AU)
# Cartwheel eccentricity e = L / (2 * sqrt(3) * R), with L converted to AU.
ecc = L_target / (AU_per_billion_meters * 2 * jnp.sqrt(3) * R_target)
# Number of nodes/spacecraft (paths below visit nodes 0, 1, 2).
N = 3
orbits = create_cartwheel_orbit(ecc, R_target, N, times)
separations = get_separations(orbits)
# One arm vector per ordered pair of nodes (presumably N * (N - 1) links, as
# in the commented-out code below).
arms = flatten_pairs(separations)
print(arms.shape)
arm_lengths = get_arm_lengths(arms) 
print(arm_lengths.shape)


# N_arms = N * (N-1) 
# arms = jax.random.uniform(jax.random.PRNGKey(0), (N_arms, N_times, 3), minval=-100.0, maxval=100.0)
# print(arms.shape)
transfer_function(k, freq=jnp.array([1E-3, 2E-3, 3E-3, 4E-3]), arms=arms).shape

In [None]:
# apparently not putting the argument names in the vmapped fuction is necessary, otherwise it doesn't work
%timeit vmapped_transfer_function(k_hat, freqs, arms).shape

In [None]:
%timeit jitted_vmapped_transfer_function(k_hat, freqs, arms)

In [None]:
# Transfer function for every sky pixel, frequency and arm.
full_transfer = jitted_vmapped_transfer_function(k_hat, freqs, arms)
full_transfer.shape

In [None]:
# Unit arm directions: arm vectors divided by their lengths.
arm_directions = arms / arm_lengths[..., None]
antennae = sky_vmapped_antenna_pattern(u_hat, v_hat, arm_directions)
antennae.shape

In [None]:
receiver_orbits = get_receiver_positions(orbits)
receiver_positions = flatten_pairs(receiver_orbits)

In [None]:
# Assemble the response from its pieces; the second displayed value is the
# array size in GB, to keep an eye on memory use.
response = response_function(
    k_hat.T,
    freqs,
    receiver_positions,
    full_transfer,
    antennae,
)
response.shape, response.nbytes / 1E9

In [None]:
# Same response via the packaged pipeline (overwrites `response` and `antennae`).
response, antennae = response_pipe(
    orbits,
    freqs,
    sky_basis,
)
print(response.shape, response.nbytes / 1E9)

In [None]:
# Half the sum over the last axis of response * response.
# NOTE(review): this is response**2, not |response|**2 — if response is complex,
# a conj may be intended; confirm.
auto_arm_overlap = 0.5 * jnp.einsum("...i,...i->...", response, response)
auto_arm_overlap.shape

In [None]:
# Closed light paths through the nodes (each starts and ends at the same node).
# Pairs (1, 2), (3, 4), (5, 6) are combined below into Michelson-like
# responses centered on nodes 0, 1 and 2 respectively.
path_1 = jnp.array([0, 1, 0, 2, 0])
path_2 = jnp.array([0, 2, 0, 1, 0])
path_3 = jnp.array([1, 2, 1, 0, 1])
path_4 = jnp.array([1, 0, 1, 2, 1])
path_5 = jnp.array([2, 0, 2, 1, 2])
path_6 = jnp.array([2, 1, 2, 0, 2])

In [None]:
def get_michelson_responses(
    path_A, path_B, response, freq_grid=None, link_lengths=None, n_spacecraft=None
):
    """Differential (Michelson-like) response of two closed light paths.

    Each path is a sequence of node indices. It is converted to a sequence of
    links, and every link's response is phase-shifted by the light travel time
    accumulated before that link. The link contributions are then summed, and
    the two path responses are subtracted.

    Parameters
    ----------
    path_A, path_B : int arrays of equal length
        Node sequences, e.g. ``[0, 1, 0, 2, 0]``.
    response : array
        Per-link response. It is indexed by flat link index on axis 0 and
        contracted via the einsum below (presumably (link, time, sky, freq, pol)
        — confirm against response_pipe).
    freq_grid, link_lengths, n_spacecraft : optional
        Frequency grid, per-link lengths (indexed by flat link index) and
        number of nodes. Each defaults to the notebook globals ``freqs``,
        ``arm_lengths`` and ``N`` for backward compatibility. Pass them
        explicitly, because ``freqs`` and ``arm_lengths`` are reassigned in
        later sections of the notebook.

    Returns
    -------
    array
        response(path_A) - response(path_B), with the layout given by the
        einsum output "kjmin" minus the leading path axis.
    """
    if freq_grid is None:
        freq_grid = freqs
    if link_lengths is None:
        link_lengths = arm_lengths
    if n_spacecraft is None:
        n_spacecraft = N

    indices = path_from_indices(jnp.stack([path_A, path_B], axis=0))

    # (from, to) node pairs -> flat link index, shape (2 paths, n_links_in_path)
    flat_indices = jnp.apply_along_axis(
        lambda pair: flat_index(*pair, n_spacecraft),
        axis=-1,
        arr=indices,
    )

    path_separations = link_lengths[flat_indices]
    # Delay before link i is the sum of the lengths of links 0..i-1: prepend a
    # zero and drop the last link of each path before the cumulative sum.
    path_separations = jnp.insert(path_separations, 0, 0.0, axis=1)
    path_separations = path_separations[:, :-1, ...]
    # move the (presumed) time axis first: (time, path, link)
    path_separations = jnp.moveaxis(path_separations, -1, 0)
    cumul_path_separations = jnp.cumsum(path_separations, axis=-1)

    # Phase exp(-2*pi*i*f*delay) for every frequency and accumulated delay.
    cumul_path_phases = -2 * jnp.pi * jnp.outer(freq_grid, cumul_path_separations)
    cumul_path_phases = cumul_path_phases.reshape(
        freq_grid.shape + cumul_path_separations.shape
    )
    cumul_path_exp = jnp.exp(1j * cumul_path_phases)

    # Sum the delayed link responses along each path (l = link index).
    path_responses = jnp.einsum("ijkl,kljmin->kjmin", cumul_path_exp, response[flat_indices])

    return path_responses[0] - path_responses[1]

In [None]:
# One Michelson-like combination per vertex (paths centered on nodes 0, 1, 2).
michelson_response_1 = get_michelson_responses(path_1, path_2, response)
michelson_response_2 = get_michelson_responses(path_3, path_4, response)
michelson_response_3 = get_michelson_responses(path_5, path_6, response)

In [None]:
# |R|^2 summed over the last axis (presumably polarization); the result is
# indexed as (time, sky, freq), matching how [0, :, 0] is used further below.
response_sq_pol_summed = jnp.sum(michelson_response_1 * jnp.conj(michelson_response_1), axis=-1)
# Sky average (1 / 4pi) * sum_pixels |R|^2 dOmega. The theta grid is uniform
# in theta, not in cos(theta), so each pixel must be weighted by its actual
# solid angle. A uniform 4*pi/N_sky weight would over-weight the poles.
d_Omega_sky = get_solid_angle_theta_phi(ecl_thetas, ecl_phis, N_theta, N_phi)
isotropic_avg_response = jnp.sum(response_sq_pol_summed * d_Omega_sky[..., None], axis=1) / (4 * jnp.pi)
isotropic_avg_response.shape

In [None]:
# |R| of the first entry of the last axis (presumably the plus polarization).
response_plus_abs = jnp.abs(michelson_response_1[..., 0])
response_plus_abs.shape

In [None]:
import matplotlib.pyplot as plt

fig = plt.figure()
ax = fig.add_subplot(projection='3d')
# Make data for sphere of radius 10 (sanity check of plot_response: a
# constant radius should draw a sphere)
plot_response(
    10*jnp.ones((N_theta*N_phi)),
    ecl_thetas_reduced,
    ecl_phis_reduced,
    ax,
)

# NOTE: no rotation is applied; the view is matplotlib's default.

plt.show()


In [None]:
# now plot_surface in 3D the antenna pattern of response_plus_abs

fig = plt.figure()
ax = fig.add_subplot(projection="3d")

# Make data: |R_+| at index 0 of the first and last remaining axes
# (presumably time 0, first frequency), reshaped to (N_theta, N_phi).
plotted_response = response_plus_abs[0, :, 0]
plotted_response = unflatten_sky_axis(
    plotted_response, axis=0, N_theta=N_theta, N_phi=N_phi
)
# NOTE: this overwrites the basis vectors `u`, `v` from the earlier
# get_directional_basis cell; re-running those cells after this one will break.
u = ecl_phis_reduced
v = ecl_thetas_reduced
# Spherical -> Cartesian, with the response as the radius.
x = plotted_response * jnp.outer(jnp.sin(v), jnp.cos(u))
y = plotted_response * jnp.outer(jnp.sin(v), jnp.sin(u))
z = plotted_response * jnp.outer(jnp.cos(v), jnp.ones(jnp.size(u)))

# Plot the surface
ax.plot_surface(x, y, z)

# Set an equal aspect ratio
ax.set_aspect("equal")

plt.show()

# Noise PSDs for LISA

In [None]:
S_shot = 4.84e-42 # Hz^-1
S_acc = 2.31e-52 # Hz^-1
# Light travel time along one arm; equals 2.5e9 m / c (L_target above).
L_over_c =  8.33910238 # s
# Single-link noise PSD: shot noise plus an acceleration-noise term scaling as f^-4.
S_n = 4 * S_shot + 8 * (1 + jnp.cos(freqs *2 * jnp.pi * L_over_c)**2) * S_acc * jnp.power(freqs, -4)

plt.loglog(freqs, jnp.sqrt(S_n))
plt.title("Noise PSD for LISA")
plt.xlabel("Frequency [Hz]")
plt.ylabel(r"$\sqrt{S_n}\;\;\;[Hz^{-1/2}]$")

In [None]:
# Michelson-X noise PSD: S_n times 4 sin^2(2 pi f L / c).
S_X = 4 * jnp.power(jnp.sin(freqs * 2 * jnp.pi * L_over_c), 2) * S_n

plt.loglog(freqs, jnp.sqrt(S_X))
plt.xlabel("Frequency [Hz]")
plt.ylabel(r"$\sqrt{S_X}\;\;\;[Hz^{-1/2}]$")

In [None]:
h_eff = jnp.sqrt(S_n / jnp.abs(isotropic_avg_response[0]))
plt.loglog(freqs, h_eff)
plt.xlabel("Frequency [Hz]")
plt.ylabel(r"$h_{eff}\;\;\;[Hz^{-1/2}]$")
plt.show()

In [None]:
plt.loglog(freqs, S_n)
plt.xlabel("Frequency [Hz]")
plt.ylabel(r"$h_{eff}\;\;\;[Hz^{-1/2}]$")
plt.show()

# Angular Noise Power Spectrum for LISA

In [None]:
# Superseded manual construction of the spherical harmonics, kept for
# reference only: `lpmn_values` is not imported in this notebook, and these
# values are computed with get_sph_harm_values in the next cell instead.
# l_max = 10
# lpmn_l_max_jitted = jax.jit(lambda x: lpmn_values(l_max, l_max, x, is_normalized=True))

In [None]:
# alp_normed = lpmn_l_max_jitted(jnp.cos(ecl_thetas_reduced))
# # swap the first two axes to have the l axis first
# alp_normed = jnp.swapaxes(alp_normed, 0, 1)
# alp_normed.shape

In [None]:
# exp_1j_m_phi = jnp.exp(1j * jnp.outer(jnp.arange(0,l_max + 1), ecl_phis_reduced))
# exp_1j_m_phi.shape

In [None]:
# sph_harm_values = alp_normed[...,None] * exp_1j_m_phi[None,:,None,:]
# sph_harm_values = sph_harm_values.reshape(*sph_harm_values.shape[:-2], -1)
# sph_harm_values.shape

In [None]:
# Spherical harmonics up to l_max on the sky grid (leading axes presumably
# (l, m), with the flat sky axis last — confirm against get_sph_harm_values).
l_max = 10
sph_harm_values = get_sph_harm_values(l_max, ecl_thetas_reduced, ecl_phis_reduced)
sph_harm_values.shape

In [None]:
# plot the spherical harmonic values
fig = plt.figure()

ax = fig.add_subplot(projection="3d")


# Make data
# |Y_{l=1, m=0}| (index [1, 0]; Y_00 would be index [0, 0])
plotted_response = jnp.abs(sph_harm_values[1, 0])
plotted_response = unflatten_sky_axis(
    plotted_response, axis=0, N_theta=N_theta, N_phi=N_phi
)
u = ecl_phis_reduced
v = ecl_thetas_reduced
x = plotted_response * jnp.outer(jnp.sin(v), jnp.cos(u))
y = plotted_response * jnp.outer(jnp.sin(v), jnp.sin(u))
z = plotted_response * jnp.outer(jnp.cos(v), jnp.ones(jnp.size(u)))

# Plot the surface
ax.plot_surface(x, y, z)

# Set an equal aspect ratio
ax.set_aspect("equal")

plt.show()

In [None]:
# Spherical-harmonic coefficients of each Michelson response; the second
# argument (1) is presumably the index of the sky axis — TODO confirm.
response_lm_1 = pixel_to_lm(michelson_response_1, 1, N_theta, N_phi, ecl_thetas, ecl_phis, sph_harm_values)
response_lm_2 = pixel_to_lm(michelson_response_2, 1, N_theta, N_phi, ecl_thetas, ecl_phis, sph_harm_values)
response_lm_3 = pixel_to_lm(michelson_response_3, 1, N_theta, N_phi, ecl_thetas, ecl_phis, sph_harm_values)

In [None]:
overlap_iso_unnormed = response_lm_1[...,0,0] * jnp.conj(response_lm_2[...,0,0])
overlap_iso_unnormed = 0.5 * jnp.sum(overlap_iso_unnormed, axis=-1)
plt.plot(freqs, jnp.real(overlap_iso_unnormed[0]), label="real")
plt.plot(freqs, jnp.imag(overlap_iso_unnormed[0]), label="imag")
plt.xlabel("Frequency [Hz]")
plt.ylabel(r"$\gamma_{12}$")
plt.show()

In [None]:
# sum over axis=2
overlap_lm_11 = 0.5 * jnp.sum(response_lm_1 * jnp.conj(response_lm_1), axis=2)
print(overlap_lm_11.shape)
overlap_lm_22 = 0.5 * jnp.sum(response_lm_2 * jnp.conj(response_lm_2), axis=2)
overlap_lm_33 = 0.5 * jnp.sum(response_lm_3 * jnp.conj(response_lm_3), axis=2)
overlap_lm_12 = 0.5 * jnp.sum(response_lm_1 * jnp.conj(response_lm_2), axis=2)
print(overlap_lm_12.shape)
overlap_lm_13 = 0.5 * jnp.sum(response_lm_1 * jnp.conj(response_lm_3), axis=2)
overlap_lm_23 = 0.5 * jnp.sum(response_lm_2 * jnp.conj(response_lm_3), axis=2)


In [None]:
plt.plot(freqs, jnp.real(overlap_lm_12[0,...,0,0]), label="real")
plt.plot(freqs, jnp.imag(overlap_lm_12[0,...,0,0]), label="imag")
plt.xlabel("Frequency [Hz]")
plt.ylabel(r"$\gamma_{12}$ [unnormalized]")
plt.show()

In [None]:
def get_noise_ell(overlap_lm, N_f, l_max, f_ref, spectral_index, t_obs, freq_grid=None):
    """Angular noise power spectrum N_ell from (l, m) overlap functions.

    Parameters
    ----------
    overlap_lm : array
        Overlap in (l, m) space. Only ``overlap_lm[0]`` is used (presumably the
        first time sample); its axes are taken as (freq, l, m).
    N_f : array or float
        Noise PSD on the frequency grid.
    l_max : int
        Maximum multipole; the l axis must have ``l_max + 1`` entries.
    f_ref : float
        Reference frequency of the power-law spectral shape (f / f_ref)**index.
    spectral_index : float
        Power-law index of the spectral shape.
    t_obs : float
        Observation time multiplying the frequency integral.
    freq_grid : array, optional
        Frequency grid matching the frequency axis of ``overlap_lm``. Defaults
        to the notebook-global ``freqs`` for backward compatibility; pass it
        explicitly, because ``freqs`` is reassigned in later sections.

    Returns
    -------
    array of shape (l_max + 1,)
        1 / [t_obs / 2 * int df sum_m |overlap_lm|^2 (2/5 S(f) / N_f)^2 / (2l + 1)].
    """
    # jnp.trapz is deprecated (and removed in recent JAX); use the scipy-style
    # trapezoid integrator instead.
    from jax.scipy.integrate import trapezoid

    if freq_grid is None:
        freq_grid = freqs

    spectral_shape = jnp.power(freq_grid / f_ref, spectral_index)
    l_array = jnp.arange(0, l_max + 1)
    overlap_init = overlap_lm[0]
    # sum the square complex norm over m (axis=-1) -> (freq, l)
    overlap_sq_init = jnp.sum(overlap_init * jnp.conj(overlap_init), axis=-1)
    integrand = (
        overlap_sq_init
        * ((2.0 / 5.0 * spectral_shape / N_f) ** 2)[:, None]
        / (2 * l_array + 1)
    )
    # integrate over frequency (axis 0)
    noise_ell_inv = t_obs / 2 * trapezoid(integrand, x=freq_grid, axis=0)

    return noise_ell_inv ** (-1)


spectral_index = 0
f_ref = 0.01
N_f = S_X

print("Frequency resolution: ", freqs[1] - freqs[0])
noise_ell_12 = get_noise_ell(overlap_lm_12, N_f, l_max, f_ref, spectral_index, 1)
print(noise_ell_12)
plt.plot(jnp.arange(0, l_max + 1), jnp.real(noise_ell_12))
plt.yscale("log")
plt.xlabel(r"$\ell$")
plt.ylabel(r"$N_\ell^{12}$")
plt.show()

In [None]:
noise_ell_11 = get_noise_ell(overlap_lm_11, N_f, l_max, f_ref, spectral_index, t_obs)
print(noise_ell_11)
plt.plot(jnp.arange(0, l_max + 1), jnp.real(noise_ell_11))
plt.yscale("log")
plt.xlabel(r"$\ell$")
plt.ylabel(r"$N_\ell^{12}$")
plt.show()

In [None]:
noise_ell_22 = get_noise_ell(overlap_lm_22, N_f, l_max, f_ref, spectral_index, t_obs)
print(noise_ell_22)
plt.plot(jnp.arange(0, l_max + 1), jnp.real(noise_ell_22))
plt.yscale("log")
plt.xlabel(r"$\ell$")
plt.ylabel(r"$N_\ell^{12}$")
plt.show()

# Ground Based IFOs

In [None]:
FREQ_CENTER_ORBIT = 1  # in 1/year
FREQ_ROTATION = 365.25  # in 1/year
N_times = 2
# Two samples spanning one rotation period (1 / FREQ_ROTATION years).
times = jnp.linspace(0, 1 / FREQ_ROTATION, N_times)
r = 1 # in AU
L_arm = 4 # in km
# Single interferometer at (EARTH_Z_LAT, EARTH_Z_LON).
orbits = earthbound_ifo_pipeline(
    EARTH_Z_LAT,
    EARTH_Z_LON,
    times,
    r,
    L_arm,
)
orbits.shape

In [None]:
separations = get_separations(orbits)

receiver_orbits = get_receiver_positions(orbits)
receiver_positions = flatten_pairs(receiver_orbits)

arms = flatten_pairs(separations)
print(arms.shape)
arm_lengths = get_arm_lengths(arms)
print(arm_lengths.shape)
# Speed of light in AU/s (299792458 m/s / 1.495978707e11 m).
c_in_AU_per_s = 0.0020039888
# c / L for the first arm at the first time (assumes arm_lengths are in AU).
# The grid below goes up to half of it and includes f = 0.
f_star = c_in_AU_per_s / arm_lengths[0, 0]
N_freqs = 5
freqs = jnp.linspace(0, 0.5 * f_star, N_freqs)
# Manual pieces, for shape inspection only; response_pipe below recomputes them.
full_transfer = jitted_vmapped_transfer_function(k_hat, freqs, arms)
arm_directions = arms / arm_lengths[..., None]
antennae = sky_vmapped_antenna_pattern(u_hat, v_hat, arm_directions)
print(receiver_positions.shape)
print(antennae.shape)
print(full_transfer.shape)

In [None]:
response, antennae = response_pipe(
    orbits,
    freqs,
    sky_basis=sky_basis,
)

response.shape

In [None]:
# Two-arm Michelson: out and back along 0 -> 1 and along 0 -> 2.
path_1 = jnp.array([0, 1, 0,])
path_2 = jnp.array([0, 2, 0,])
paths = jnp.stack([path_1, path_2], axis=0)

path_responses, cumul_path_separations = get_path_response(
    paths,
    freqs,
    arm_lengths,
    response,
)

# Difference of the two path responses.
michelson_response = get_differential_strain_response(
    path_responses,
    path_idx_1=0,
    path_idx_2=1,
    cumul_path_separations=cumul_path_separations,
)
michelson_response.shape

In [None]:
# now plot_surface in 3D the antenna pattern of response_plus_abs

fig = plt.figure(figsize=(10,10))
ax1 = fig.add_subplot(221, projection='3d')
ax2 = fig.add_subplot(222, projection='3d')
ax3 = fig.add_subplot(223, projection='3d')
ax4 = fig.add_subplot(224, projection='3d')
# Make data: |R| for the two last-axis components (presumably + and x) at
# index 0 of the first axis, for the first and last entries of the final axis
# (presumably the first and last frequency).
response_plus_abs = jnp.abs(michelson_response[..., 0])
print(response_plus_abs.min(), response_plus_abs.max())
response_cross_abs = jnp.abs(michelson_response[..., 1])
plotted_response_1 = response_plus_abs[0,:,0]
plotted_response_2 = response_plus_abs[0,:,-1]
plotted_response_3 = response_cross_abs[0,:,0]
plotted_response_4 = response_cross_abs[0,:,-1]
axes = [ax1, ax2, ax3, ax4]
responses = [plotted_response_1, plotted_response_2, plotted_response_3, plotted_response_4]
for idx in range(len(axes)):
    plotted_response = responses[idx]
    plot_response(
        plotted_response,
        ecl_thetas_reduced,
        ecl_phis_reduced,
        axes[idx],
    )
plt.show()

In [None]:
ecl_phis.shape

# LIGO Baseline

In [None]:
# Same time grid and arm length as the generic ground-based section above.
FREQ_CENTER_ORBIT = 1  # in 1/year
FREQ_ROTATION = 365.25  # in 1/year
N_times = 2
times = jnp.linspace(0, 1 / FREQ_ROTATION, N_times)
r = 1 # in AU
L_arm = 4 # in km

# Hanford (H) and Livingston (L) site coordinates, converted from degrees to radians.
# https://www.ligo.org/scientists/GW100916/detectors.txt
H_lat, H_lon = (
    46.455140209119214 * jnp.pi / 180,
    -119.40746331631823 * jnp.pi / 180,
)
L_lat, L_lon = (
    30.562932349951804 * jnp.pi / 180,
    -90.77416707625777 * jnp.pi / 180,
)

# psi_*: passed as `psi` (presumably the arm orientation angle);
# beta_*: passed as `beta_arm` (presumably the 90-degree opening angle between
# the arms). Both are converted from degrees to radians.
psi_H = (90+36) * jnp.pi / 180
psi_L = (180+18) * jnp.pi / 180
beta_H = 90 * jnp.pi / 180
beta_L = 90 * jnp.pi / 180

orbits_H = earthbound_ifo_pipeline(
    H_lat,
    H_lon,
    times,
    r,
    L_arm,
    psi=psi_H,
    beta_arm=beta_H,
)
print(orbits_H.shape)

orbits_L = earthbound_ifo_pipeline(
    L_lat,
    L_lon,
    times,
    r,
    L_arm,
    psi=psi_L,
    beta_arm=beta_L,
)
print(orbits_L.shape)


In [None]:

# Receiver positions and arm lengths per detector. `receiver_orbits`,
# `separations` and `arms` are scratch names and end up holding the L values.
receiver_orbits = get_receiver_positions(orbits_H)
receiver_positions_H = flatten_pairs(receiver_orbits)
separations = get_separations(orbits_H)
arms = flatten_pairs(separations)
arm_lengths_H = get_arm_lengths(arms)


receiver_orbits = get_receiver_positions(orbits_L)
receiver_positions_L = flatten_pairs(receiver_orbits)
separations = get_separations(orbits_L)
arms = flatten_pairs(separations)
arm_lengths_L = get_arm_lengths(arms)

# use the mean arm length
# init_ifo_arm_lengths = [
#     arm_lengths_H[0, 0],
#     arm_lengths_L[0, 0],
# ]
# mean_init_arm_length = jnp.mean(jnp.array(init_ifo_arm_lengths))
# c_in_AU_per_s = 0.0020039888
# f_star = c_in_AU_per_s / mean_init_arm_length
# print(f"f_star = {f_star:.2e} Hz")
# freqs = jnp.linspace(0, 0.5 * f_star, 200)
# Fixed 10-500 Hz band, 50 points (replaces the f_star-based grid above).
freqs = jnp.linspace(10, 500, 50)
print(f"min freq = {freqs[0]:.2e} Hz")
print(f"max freq = {freqs[-1]:.2e} Hz")

In [None]:
# Full per-link response of each detector on the shared sky grid.
response_H, antennae_H = response_pipe(
    orbits_H,
    freqs,
    sky_basis=sky_basis,
)
print(response_H.shape)
response_L, antennae_L = response_pipe(
    orbits_L,
    freqs,
    sky_basis=sky_basis,
)
print(response_L.shape)

In [None]:
# Two-arm Michelson paths (out and back along each arm), shared by H and L.
path_1 = jnp.array([0, 1, 0,])
path_2 = jnp.array([0, 2, 0,])
paths = jnp.stack([path_1, path_2,], axis=0)

path_responses_H, cumul_path_separations_H = get_path_response(
    paths,
    freqs,
    arm_lengths_H,
    response_H,
)

print("Path response shape: ", path_responses_H.shape)

michelson_response_H = get_differential_strain_response(
    path_responses_H,
    path_idx_1=0,
    path_idx_2=1,
    cumul_path_separations=cumul_path_separations_H,
)
print("H", michelson_response_H.shape)

# Same computation for Livingston.
path_responses_L, cumul_path_separations_L = get_path_response(
    paths,
    freqs,
    arm_lengths_L,
    response_L,
)

michelson_response_L = get_differential_strain_response(
    path_responses_L,
    path_idx_1=0,
    path_idx_2=1,
    cumul_path_separations=cumul_path_separations_L,
)
print(michelson_response_L.shape)

In [None]:
# Spherical-harmonic coefficients of each detector's Michelson response
# (second argument 1: presumably the sky axis).
michelson_response_H_lm = pixel_to_lm(michelson_response_H, 1, N_theta, N_phi, ecl_thetas, ecl_phis, sph_harm_values)
michelson_response_L_lm = pixel_to_lm(michelson_response_L, 1, N_theta, N_phi, ecl_thetas, ecl_phis, sph_harm_values)

In [None]:
# Product of the (0, 0) coefficients of H and L, summed over the last axis / 2.
# NOTE(review): by Parseval the full sky overlap sums over all (l, m); confirm
# that only the monopole product is intended.
overlap_from_lm = michelson_response_H_lm * jnp.conj(michelson_response_L_lm)
iso_overlap_from_lm_unnormed = overlap_from_lm[...,0,0]
iso_overlap_from_lm_unnormed = 0.5 * jnp.sum(iso_overlap_from_lm_unnormed, axis=-1)
plt.plot(freqs, jnp.real(iso_overlap_from_lm_unnormed[0]),label='real')
plt.plot(freqs, jnp.imag(iso_overlap_from_lm_unnormed[0]),label='imag')
plt.xlabel('Frequency [Hz]')
plt.ylabel('Overlap')
plt.legend()
plt.show()

In [None]:
# 3-D antenna patterns of H (top row) and L (bottom row): |R_+|, |R_x| and the
# polarization-combined sqrt(|R_+|^2 + |R_x|^2).

fig = plt.figure(figsize=(10,6))
ax1 = fig.add_subplot(231, projection='3d')
ax2 = fig.add_subplot(232, projection='3d')
ax3 = fig.add_subplot(233, projection='3d')
ax4 = fig.add_subplot(234, projection='3d')
ax5 = fig.add_subplot(235, projection='3d')
ax6 = fig.add_subplot(236, projection='3d')

# Make data
response_plus_abs_H = jnp.abs(michelson_response_H[..., 0])
response_cross_abs_H = jnp.abs(michelson_response_H[..., 1])
unpolarized_response_abs_H = jnp.sqrt(response_plus_abs_H**2 + response_cross_abs_H**2)

response_plus_abs_L = jnp.abs(michelson_response_L[..., 0])
response_cross_abs_L = jnp.abs(michelson_response_L[..., 1])
unpolarized_response_abs_L = jnp.sqrt(response_plus_abs_L**2 + response_cross_abs_L**2)

# Use the same frequency index (last axis) for both detectors, so the two
# rows are directly comparable.
panel_freq_idx = 0
plotted_response_1 = response_plus_abs_H[0, :, panel_freq_idx]
plotted_response_2 = response_cross_abs_H[0, :, panel_freq_idx]
plotted_response_3 = unpolarized_response_abs_H[0, :, panel_freq_idx]
plotted_response_4 = response_plus_abs_L[0, :, panel_freq_idx]
plotted_response_5 = response_cross_abs_L[0, :, panel_freq_idx]
plotted_response_6 = unpolarized_response_abs_L[0, :, panel_freq_idx]

axes = [ax1, ax2, ax3, ax4, ax5, ax6]
responses = [
    plotted_response_1, plotted_response_2, plotted_response_3,
    plotted_response_4, plotted_response_5, plotted_response_6,
]
panel_titles = [
    r"H $|R_+|$", r"H $|R_\times|$", "H unpolarized",
    r"L $|R_+|$", r"L $|R_\times|$", "L unpolarized",
]

assert len(axes) == len(responses) == len(panel_titles), "axes and responses must be the same length"
for ax_panel, plotted_response, title in zip(axes, responses, panel_titles):
    plot_response(
        plotted_response,
        ecl_thetas_reduced,
        ecl_phis_reduced,
        ax_panel,
    )
    ax_panel.set_title(title)
    # ax_panel.view_init(0,180)
plt.show()

In [None]:
# Direction-dependent H-L overlap: 1/2 * sum over the last axis (presumably
# polarization) of R_H * conj(R_L). Indexed below as (time, sky, freq).
anisotropic_ORF_HL = 0.5 * jnp.sum(
    michelson_response_H * jnp.conj(michelson_response_L),
    axis=-1,
)
anisotropic_ORF_HL.shape

In [None]:
from scipy.interpolate import RectSphereBivariateSpline

plot_freq_idx = -1

# Spline interpolators of the real and imaginary parts of the anisotropic ORF
# (time index 0, frequency index plot_freq_idx) on the (theta, phi) grid.
# NOTE(review): longitudes are shifted by -pi when building the splines, while
# the query longitudes below are unshifted angles in (-pi, pi] — confirm that
# this offset is intended.
re_interpolator = RectSphereBivariateSpline(
    ecl_thetas_reduced,
    ecl_phis_reduced - jnp.pi,
    unflatten_sky_axis(jnp.real(anisotropic_ORF_HL[0,:,plot_freq_idx]), 0, N_theta, N_phi),
)

imag_interpolator = RectSphereBivariateSpline(
    ecl_thetas_reduced,
    ecl_phis_reduced - jnp.pi,
    unflatten_sky_axis(jnp.imag(anisotropic_ORF_HL[0,:,plot_freq_idx]), 0, N_theta, N_phi),
)

ecl_latitudes = jnp.pi/2 - ecl_thetas
# declination plays the role of latitude in the equatorial coordinate system
# sin_declinations = jnp.sin(ecl_latitudes)*jnp.cos(EARTH_TILT)-jnp.cos(ecl_latitudes)*jnp.sin(EARTH_TILT)*jnp.sin(ecl_phis)
# tan_right_ascensions = (jnp.sin(ecl_phis)*jnp.cos(EARTH_TILT)-jnp.tan(ecl_latitudes)*jnp.sin(EARTH_TILT))/jnp.cos(ecl_phis)
# print(sin_declinations.shape, tan_right_ascensions.shape)
# Unit vector of every pixel rotated about the x axis by EARTH_TILT
# (ecliptic -> equatorial).
x_equatorial = jnp.cos(ecl_latitudes)*jnp.cos(ecl_phis)
y_equatorial = jnp.cos(ecl_latitudes)*jnp.sin(ecl_phis)*jnp.cos(EARTH_TILT) - jnp.sin(ecl_latitudes)*jnp.sin(EARTH_TILT)
z_equatorial = jnp.cos(ecl_latitudes)*jnp.sin(ecl_phis)*jnp.sin(EARTH_TILT) + jnp.sin(ecl_latitudes)*jnp.cos(EARTH_TILT)
equat_thetas = jnp.arccos(z_equatorial)

# Longitude in (-pi, pi]. arctan2 matches sign(y) * arccos(x / sqrt(x^2 + y^2))
# wherever y != 0. It also returns pi (not 0) for y == 0, x < 0, and it
# avoids the 0/0 at the equatorial poles (x = y = 0).
equat_phis = jnp.arctan2(y_equatorial, x_equatorial)
print(equat_thetas.shape, equat_phis.shape)

equat_ani_ORF_re = re_interpolator(equat_thetas, equat_phis, grid=False)
equat_ani_ORF_im = imag_interpolator(equat_thetas, equat_phis, grid=False)
equat_ani_ORF = equat_ani_ORF_re + 1j*equat_ani_ORF_im
print(equat_ani_ORF_re.shape, equat_ani_ORF_im.shape)


In [None]:
# WIKIPEDIA: https://en.wikipedia.org/wiki/Ecliptic_coordinate_system


# Plotting grid in degrees: latitude from colatitude, longitude shifted to [-180, 180).
ecl_lat_reduced = (jnp.pi/2 - ecl_thetas_reduced) * 180 / jnp.pi
# invert latitudes to start at the south pole, so that latitudes are monotonically increasing
ecl_lat_reduced = ecl_lat_reduced[::-1]
ecl_lon_reduced = (ecl_phis_reduced - jnp.pi) * 180 / jnp.pi
lon2d, lat2d = jnp.meshgrid(ecl_lon_reduced, ecl_lat_reduced)
# Toggle between plotting the real and imaginary parts.
realPart = True
if realPart:
    plot_ani_ORF_HL = equat_ani_ORF_re
else:
    plot_ani_ORF_HL = equat_ani_ORF_im
plot_ani_ORF_HL = unflatten_sky_axis(plot_ani_ORF_HL, axis=0, N_theta=N_theta, N_phi=N_phi)
# Flip the theta axis to match the reversed latitude order above.
plot_ani_ORF_HL = plot_ani_ORF_HL[::-1, :]
# NOTE: despite the "equat" names, these are the (ecliptic-grid) plotting axes.
plot_equat_lats = ecl_lat_reduced
plot_equat_lons = ecl_lon_reduced
print(plot_equat_lats.shape, plot_equat_lons.shape,)
print(plot_equat_lats.min(), plot_equat_lats.max(), plot_equat_lats[0], plot_equat_lats[-1], plot_equat_lats[1] - plot_equat_lats[0])
print(plot_equat_lons.min(), plot_equat_lons.max(), plot_equat_lons[0], plot_equat_lons[-1])
# unflatten_sky_axis(k_hat, axis=0, N_theta=N_theta, N_phi=N_phi)

In [None]:
import cartopy.crs as ccrs
# plot the Mollweide projection of anisotropic overlap reduction function
# for the two detectors

# NOTE: `key` is only used by the commented-out random test data below.
key = jax.random.PRNGKey(0)
plt.figure(figsize=(10,6))
# Data are given on a lon/lat grid in degrees.
data_crs = ccrs.PlateCarree()
ax = plt.axes(projection=ccrs.Mollweide())
# ax.set_extent(
#     [
#         -120,-80,
#         20, 60
#     ], 
#     crs=ccrs.PlateCarree(),
# )
ax.set_global()
ax.coastlines()
# data = jax.random.normal(key, shape=(N_theta, N_phi))
# print(data.shape)
contourORF = plt.contourf(
    lon2d, 
    lat2d, 
    jnp.real(plot_ani_ORF_HL),
    transform=data_crs,
    levels=100,
    cmap='jet',
)
ax.gridlines()
colorbarORF = plt.colorbar(
    contourORF, 
    orientation='vertical', 
    shrink=0.5, 
    pad=0.05,
)
print(freqs[plot_freq_idx])
str_re_imag = r"\Re" if realPart else r"\Im"
colorbarORF.set_label(r'$%s\left[\gamma(t=0,f=%.1f \mathrm{Hz},\hat{\Omega})\right]$' % (str_re_imag, freqs[plot_freq_idx]))
plt.show()

In [None]:
# verify that the sum of d_omega is 4 pi (the printed difference should be ~0)
d_Omega = get_solid_angle_theta_phi(ecl_thetas, ecl_phis, N_theta, N_phi)
print(jnp.sum(d_Omega) - 4 * jnp.pi)

In [None]:
# (l, m) coefficients of the anisotropic H-L ORF (sky axis 1).
gamma_lm_HL = pixel_to_lm(anisotropic_ORF_HL, 1, N_theta, N_phi, ecl_thetas, ecl_phis, sph_harm_values)
gamma_HL_00 = gamma_lm_HL[0,:,0,0]
# 5 / sin^2(angle between the arms); the arms are at 90 degrees here.
normalization = 5 / (jnp.sin(jnp.pi/2.0))**2
Y_00 = sph_harm_values[0,0,0]
# With a constant Y_00, the sky average (1/4pi) * int gamma dOmega equals
# gamma_00 / (4 pi Y_00) (assuming pixel_to_lm projects onto Y_lm^*).
gamma_iso_from_aniso = normalization * (1 / (4 * jnp.pi * Y_00)) * gamma_HL_00
print(gamma_iso_from_aniso.shape)

In [None]:
labels = {
    "real_part": r'$\Re\left[\gamma_{00}(t=0,f)\right]$',
    "imag_part": r'$\Im\left[\gamma_{00}(t=0,f)\right]$',
}
print(gamma_iso_from_aniso[0])
# horizontal reference line at 0
plt.axhline(y=0, color='k')
plt.plot(freqs, jnp.real(gamma_iso_from_aniso), label=labels["real_part"], linestyle="", marker='o', markersize=10)
plt.plot(freqs, jnp.imag(gamma_iso_from_aniso), label=labels["imag_part"], linestyle="", marker='o', markersize=10, alpha=0.5)
# (the horizontal line at 0 is drawn by axhline above)

plt.title("SHD l=0,m=0 of anisotropic ORF in ecliptic coordinates")
plt.xlabel(r'$f\;\mathrm{[Hz]}$')
plt.ylabel(r'$\mathcal{N}\;\gamma_{00}(t=0, f)$')
plt.legend()
plt.show()

In [None]:
# average over directions in the sky
# as a pixel sum weighted by each pixel's solid angle (not a trapezoidal rule):
# gamma_IJ = sum (delta_omega * gamma_IJ)

print(d_Omega.shape)
# Normalization 5 / sin^2(angle between the arms).
arm_angle = jnp.pi/2
normalization = 5 / (jnp.sin(arm_angle)**2)
isotropic_ORF_HL = normalization / (4.0 * jnp.pi) * jnp.sum(
    anisotropic_ORF_HL * d_Omega[..., None],
    axis=1
)
print(isotropic_ORF_HL.shape)


In [None]:
# Compare the first (index 0) and second (index 1, "later") time samples.
fig = plt.figure(figsize=(8, 6))
print(isotropic_ORF_HL[0, 0])
plt.plot(freqs, jnp.real(isotropic_ORF_HL[0, :]), label=r"$\Re\gamma_{HL}$",  linewidth=5, color="black")
plt.plot(freqs, jnp.imag(isotropic_ORF_HL[0, :]), label=r"$\Im\gamma_{HL}$", linewidth=5, color="red")
plt.plot(freqs, jnp.real(isotropic_ORF_HL[1, :]), label=r"$\Re\gamma_{HL}$ later", linestyle="--", color="pink", linewidth=1)
plt.plot(freqs, jnp.imag(isotropic_ORF_HL[1, :]), label=r"$\Im\gamma_{HL}$ later", linestyle="--", color="blue", linewidth=1)
plt.xlabel("Frequency [Hz]")
plt.ylabel("Isotropic ORF")
plt.legend()
plt.show()

In [None]:
# use the Wigner D matrix to rotate the anisotropic ORF to the equatorial frame
# get the Wigner D matrix
# D^l_{m,m'}(\alpha, \beta, \gamma) = e^{-i m \alpha} d^l_{m,m'}(\beta) e^{-i m' \gamma}
# euler angles to rotate from ecliptic to equatorial frame:
# alpha = ?
# beta = ?
# gamma = ?
# Chosen values: alpha = -pi/2, beta = EARTH_TILT, gamma = pi/2.
alpha, beta, gamma = -jnp.pi/2, EARTH_TILT, jnp.pi/2
# Broadcast layout (l, m, m'): m runs over 0..l_max, m' over -l_max..l_max.
m_array = jnp.arange(0, l_max + 1)[None, :, None]
m_prime_array = jnp.arange(-l_max, l_max + 1)[None, None, :]
# NOTE(review): alpha is paired with m' and gamma with m here, which is the
# reverse of the formula above — confirm the intended convention.
exp_alpha = jnp.exp(-1j * m_prime_array * alpha)
exp_gamma = jnp.exp(-1j * m_array * gamma)

from jax.scipy.special import gammaln

def gamma_func(x):
    """Gamma function evaluated as exp(gammaln(x)).

    gamma(x) = int_0^inf t^{x-1} e^{-t} dt

    gammaln is the log of |Gamma(x)|, so the sign of Gamma is not recovered
    for negative non-integer x.
    """
    return jnp.exp(gammaln(x))

def binomial_coefficient(n, k):
    """Binomial coefficient (elementwise, broadcasting n against k):
    (n k) = n! / (k! (n-k)!)
    """
    return gamma_func(n + 1) / (gamma_func(k + 1) * gamma_func(n - k + 1))

def get_jacobi_polynomial(n, a, b, cos_theta):
    """Jacobi polynomial P_n^{(a, b)}(cos_theta), evaluated elementwise.

    Uses the explicit finite sum

        P_n^{(a,b)}(x) = sum_{s=0}^{n} C(n+a, n-s) C(n+b, s)
                         ((x-1)/2)^s ((x+1)/2)^{n-s}.

    ``n``, ``a``, ``b`` may be Python scalars or arrays. The summation index s
    runs over 0..max(n) on a new leading axis, and terms with s > n are masked
    out, so every element uses its own degree. ``a``, ``b`` and ``cos_theta``
    are expected to broadcast against ``n`` without adding leading axes
    (e.g. scalars, or arrays of n's shape).

    Returns an array with n's (broadcast) shape.
    """
    n = jnp.asarray(n)
    a = jnp.asarray(a)
    b = jnp.asarray(b)
    n_max = int(jnp.max(n))
    # Summation index on a new leading axis, with singleton axes so it
    # broadcasts against n. (jnp.reshape allows at most one -1, so a shape
    # such as (n_max + 1, -1, -1, -1) is rejected for multi-dimensional n.)
    s_array = jnp.arange(0, n_max + 1).reshape((n_max + 1,) + (1,) * n.ndim)
    physical_sum = jnp.where(n >= s_array, 1, 0)
    # Clamp out-of-range s to 0 so the binomials stay finite; those terms are
    # zeroed out below.
    s_array_phys = jnp.where(physical_sum, s_array, 0)

    comb_1 = binomial_coefficient(n + a, n - s_array_phys)
    comb_2 = binomial_coefficient(n + b, s_array_phys)
    trig_1 = jnp.power((cos_theta - 1) / 2, s_array_phys)
    trig_2 = jnp.power((cos_theta + 1) / 2, n - s_array_phys)
    terms = comb_1 * comb_2 * trig_1 * trig_2

    terms = jnp.where(physical_sum, terms, 0)
    return jnp.sum(terms, axis=0)


def get_d_Wigner_small(l_max, beta):
    """Small Wigner d-matrix d^l(beta) for l = 0..l_max, built from Jacobi polynomials.

    Parameters
    ----------
    l_max : int
        Maximum multipole.
    beta : float
        Second Euler angle in radians.

    Returns
    -------
    array of shape (l_max + 1, l_max + 1, 2 * l_max + 1)
        Indexed as [l, m, m'] with m in 0..l_max and m' in -l_max..l_max
        (m' = index - l_max). Entries with |m| > l or |m'| > l are zero.

    Notes
    -----
    Uses the Jacobi-polynomial closed form

        (-1)^lambda sqrt(C(2l-k, k+a) / C(k+b, b))
            sin^a(beta/2) cos^b(beta/2) P_k^{(a,b)}(cos beta),

    with k = min(l - |m|, l - |m'|) and b = 2l - 2k - a. (a, lambda) depend on
    which of l+m, l-m, l+m', l-m' attains k (see the case selection below).
    """
    l_array = jnp.arange(0, l_max + 1)[:, None, None]
    m_array = jnp.arange(0, l_max + 1)[None, :, None]
    m_prime_array = jnp.arange(-l_max, l_max + 1)[None, None, :]
    # Masks of physical entries. Unphysical m, m' are replaced by 0, presumably
    # so the intermediate binomials and powers stay finite; those entries are
    # zeroed at the end.
    physical_m = jnp.where(l_array >= jnp.abs(m_array), 1, 0)
    physical_m_prime = jnp.where(l_array >= jnp.abs(m_prime_array), 1, 0)
    m_array = jnp.where(physical_m, m_array, 0)
    m_prime_array = jnp.where(physical_m_prime, m_prime_array, 0)
    # k = min(l + m, l - m, l + m', l - m') = min(l - |m|, l - |m'|)
    k = jnp.minimum(l_array - jnp.abs(m_array), l_array - jnp.abs(m_prime_array))

    # Case selection:
    #   k == l + m  or  k == l - m'  ->  a = m' - m, lambda = m' - m
    #   otherwise                    ->  a = m - m', lambda = 0
    case_plus_m = jnp.where(k == l_array + m_array, 1, 0)
    case_minus_m_prime = jnp.where(k == l_array - m_prime_array, 1, 0)
    a_param = jnp.where(case_plus_m, m_prime_array - m_array, m_array - m_prime_array)
    a_param = jnp.where(case_minus_m_prime, m_prime_array - m_array, a_param)

    lambda_param = jnp.where(case_plus_m, m_prime_array - m_array, 0)
    lambda_param = jnp.where(case_minus_m_prime, m_prime_array - m_array, lambda_param)

    b_param = 2 * l_array - 2 * k - a_param

    sin_beta_over_2 = jnp.sin(beta / 2.0)
    cos_beta_over_2 = jnp.cos(beta / 2.0)
    cos_beta = jnp.cos(beta)

    sin_pow = jnp.power(sin_beta_over_2, a_param)
    cos_pow = jnp.power(cos_beta_over_2, b_param)
    alternate = jnp.power(-1, lambda_param)
    comb_1 = binomial_coefficient(2 * l_array - k, k + a_param)
    comb_2 = binomial_coefficient(k + b_param, b_param)
    coeff_sqrt = jnp.sqrt(comb_1 / comb_2)
    prefactor = alternate * coeff_sqrt * sin_pow * cos_pow

    # Jacobi polynomial P^(a,b)_k(cos(beta)), one per (l, m, m') entry
    jacobi_poly = get_jacobi_polynomial(k, a_param, b_param, cos_beta)
    d_Wigner_small = prefactor * jacobi_poly
    # zero out the entries with |m| > l or |m'| > l
    d_Wigner_small = jnp.where(physical_m, d_Wigner_small, 0)
    d_Wigner_small = jnp.where(physical_m_prime, d_Wigner_small, 0)

    return d_Wigner_small


In [None]:
# binomial_coefficient(-1, -1)

In [None]:
# print("Test jacobi polynomial")
# # P(n=0, a, b, cos_theta) = 1
# print("Test P(0, a, b, cos_theta)==1")
# print(get_jacobi_polynomial(0, 2, 2, 2), get_jacobi_polynomial(0, 3, 4, 0), get_jacobi_polynomial(0, 4, 10, 0))
# print("Test P(1, a, b, cos_theta)==(a+1) + (a+b+2) (cos_theta-1)/2")
# a, b, cos_theta = 0, 0, -0.5
# print(get_jacobi_polynomial(1, a, b, cos_theta), (a+1) + (a+b+2) * (cos_theta-1)/2)
# a, b, cos_theta = 1, 1, 3/4
# print(get_jacobi_polynomial(1, a, b, cos_theta), (a+1) + (a+b+2) * (cos_theta-1)/2)
# a, b, cos_theta = 1, 0, 0
# print(get_jacobi_polynomial(1, a, b, cos_theta), (a+1) + (a+b+2) * (cos_theta-1)/2)

print("test wigner d matrix")
beta_test = jnp.pi / 3.0
l_max_test = 0
d_Wigner_small = get_d_Wigner_small(l_max_test, beta_test)
print(f"{l_max_test=}, {beta_test=}, {d_Wigner_small[0]=}")
l_max_test = 6
d_Wigner_small = get_d_Wigner_small(l_max_test, beta_test)
# print(d_Wigner_small, d_Wigner_small.shape)
p25_sqrt_6 = jnp.sqrt(6) / 4.0
print(f"{d_Wigner_small[1,0,l_max_test-1]=}, {p25_sqrt_6=}")

In [None]:
# d-matrix for the Euler angle beta (= EARTH_TILT) defined above.
d_Wigner_small = get_d_Wigner_small(l_max, beta)

# Full D-matrix, broadcast over (l, m, m').
wigner_D = exp_alpha * d_Wigner_small * exp_gamma

print(f"{wigner_D.shape=}")
print(f"{gamma_lm_HL.shape=}")
# generate the negative m values of gamma_lm_HL, and sum them over the last axis
# of the Wigner D matrix, using
# Y_{l,-m}(\theta, \phi) = (-1)^m Y_{l,m}^*(\theta, \phi)
# NOTE(review): the (-1)^m factor is not applied below, and conj-based
# negative-m coefficients are only valid for a real-valued sky map, while
# gamma is complex in general — confirm.
gamma_lm_HL_neg = jnp.conj(gamma_lm_HL[..., ::-1])
# keep only the non-zero m values (now ordered m = l_max, ..., 1)
gamma_lm_HL_neg = gamma_lm_HL_neg[..., :l_max]
# stack the positive and negative m values of gamma_lm_HL
gamma_lm_HL_full = jnp.concatenate([gamma_lm_HL_neg, gamma_lm_HL], axis=-1)
print(f"{gamma_lm_HL_full.shape=}")


# rotate the anisotropic ORF to the equatorial frame (sum over m' = p)
gamma_lm_HL_equatorial = jnp.einsum(
    "lmp,tklp->tklm",
    jnp.conj(wigner_D),
    gamma_lm_HL_full)
print(f"{gamma_lm_HL_equatorial.shape=}")

In [None]:
# Isotropic ORF from the rotated monopole, gamma_00 / (4 pi Y_00), with the
# 5 / sin^2(arm angle) normalization (arms at 90 degrees).
normalization = 5 / (jnp.sin(jnp.pi / 2) ** 2)
gamma_iso_from_Wigner = (
    normalization 
    * (1 / (4 * jnp.pi * Y_00)) 
    * gamma_lm_HL_equatorial[0, :, 0, 0]
    )
print(f"{gamma_iso_from_Wigner[0]=}")
plt.plot(freqs, gamma_iso_from_Wigner.real, label=r"$\Re(\gamma_{iso})$")
plt.plot(freqs, gamma_iso_from_Wigner.imag, label=r"$\Im(\gamma_{iso})$")
plt.xlabel("Frequency [Hz]")
plt.ylabel(r"$\gamma_{iso}$")
plt.legend()
plt.show()


In [None]:
from scipy.special import spherical_jn

# l = 1, m = 0 coefficient of the HL overlap in the equatorial frame.
gamma_10_from_Wigner = gamma_lm_HL_equatorial[0, :, 1, 0]

# Light travel time between the Hanford and Livingston receivers.
HL_separation = receiver_positions_H - receiver_positions_L
HL_time_delay = jnp.linalg.norm(HL_separation, axis=-1) / c_in_AU_per_s
print(f"{HL_time_delay=}, {HL_time_delay.shape=}")

# Analytic reference: fixed complex coefficients times spherical Bessel
# functions j_n(x) with x = 2 pi f tau.
# NOTE(review): if freqs contains 0, the /x terms divide by zero.
x = 2 * jnp.pi * freqs * HL_time_delay[0,0]
gamma_10_analytic = (
    -0.0608j * spherical_jn(1, x)
    - 2.6982j * spherical_jn(2, x) / x
    + 7.7217j * spherical_jn(3, x) / x**2
)

print(f"{gamma_10_from_Wigner[1]=}")
print(f"{gamma_10_analytic[1]=}")
analytic_style = dict(linestyle="", marker="x")
plt.plot(freqs, gamma_10_from_Wigner.real, label=r"$\Re(\gamma_{10})$")
plt.plot(freqs, gamma_10_from_Wigner.imag, label=r"$\Im(\gamma_{10})$")
plt.plot(freqs, gamma_10_analytic.real, label=r"$\Re(\gamma_{10})$ analytic", **analytic_style)
plt.plot(freqs, gamma_10_analytic.imag, label=r"$\Im(\gamma_{10})$ analytic", **analytic_style)
plt.xlabel("Frequency [Hz]")
plt.ylabel(r"$\gamma_{10}$")
plt.legend()
plt.show()

In [None]:
from scipy.special import spherical_jn

# l = 1, m = 1 coefficient of the HL overlap in the equatorial frame.
gamma_11_from_Wigner = gamma_lm_HL_equatorial[0, :, 1, 1]

# Light travel time between the Hanford and Livingston receivers.
HL_separation = receiver_positions_H - receiver_positions_L
HL_time_delay = jnp.linalg.norm(HL_separation, axis=-1) / c_in_AU_per_s
print(f"{HL_time_delay=}, {HL_time_delay.shape=}")

# Analytic reference in terms of spherical Bessel functions of x = 2 pi f tau.
x = 2 * jnp.pi * freqs * HL_time_delay[0,0]
gamma_11_analytic = (
    -(0.0519 + 0.0652j) * spherical_jn(1, x)
    - (1.8621 + 1.0517j) * spherical_jn(2, x) / x
    + (4.0108 - 2.4933j) * spherical_jn(3, x) / x**2
)

# Independent estimate: project the pixelized equatorial ORF onto Y_lm.
gamma_lm_interpolated = pixel_to_lm(
    equat_ani_ORF,
    0,
    N_theta,
    N_phi,
    ecl_thetas,
    ecl_phis,
    sph_harm_values,
)
print(f"{gamma_lm_interpolated.shape=}")

print(f"{gamma_11_from_Wigner[1]=}")
print(f"{gamma_11_analytic[1]=}")
print(f"{gamma_11_from_Wigner[-1]=}")
print(f"{gamma_11_analytic[-1]=}")
print(f"{-gamma_lm_interpolated[1,1]=}")

for part_values, part_label in (
    (gamma_11_from_Wigner.real, r"$\Re(\gamma_{11})$"),
    (gamma_11_from_Wigner.imag, r"$\Im(\gamma_{11})$"),
):
    plt.plot(freqs, part_values, label=part_label)
for part_values, part_label in (
    (gamma_11_analytic.real, r"$\Re(\gamma_{11})$ analytic"),
    (gamma_11_analytic.imag, r"$\Im(\gamma_{11})$ analytic"),
):
    plt.plot(freqs, part_values, label=part_label, linestyle="", marker="x")
# The pixel-based coefficient is shown with a flipped sign.
# NOTE(review): confirm which sign convention (pixel_to_lm vs. Wigner
# rotation) is the reference.
plt.scatter(freqs[plot_freq_idx], -gamma_lm_interpolated[1,1].real, label=r"$\Re(\gamma_{11})$ interpolated", marker="o")
plt.scatter(freqs[plot_freq_idx], -gamma_lm_interpolated[1,1].imag, label=r"$\Im(\gamma_{11})$ interpolated", marker="o")
plt.xlabel("Frequency [Hz]")
plt.ylabel(r"$\gamma_{11}$")
plt.legend()
plt.show()

# Angular noise power spectrum for LIGO

In [None]:
import numpy as np

# A+ sensitivity table: column 0 is used as frequency, column 1 as the ASD.
noise_Aplus = np.loadtxt("jax_gw/detector/detectors/aplus.txt")
freqs_noise = noise_Aplus[:, 0]
noise_Aplus_asd = noise_Aplus[:, 1]
noise_Aplus_psd = noise_Aplus_asd**2
# Linear interpolation onto the analysis grid; jnp.interp clamps to the end
# values of the table outside its frequency range.
noise_interp = jnp.interp(freqs, freqs_noise, noise_Aplus_psd)

plt.plot(freqs_noise, noise_Aplus_psd, label="A+")
print(freqs[1:])
# The first grid frequency is skipped in the overlay.
plt.scatter(freqs[1:], noise_interp[1:], label="A+ interpolated", marker="o", color="red")
plt.xlabel("Frequency [Hz]")
plt.ylabel("PSD [Hz$^{-1}$]")
plt.xscale("log")
plt.yscale("log")
plt.legend()
plt.show()

In [None]:
def angular_noise_ell(overlap_lm, N_f, ell_array, freqs, f_ref, spectral_index, t_obs):
    """Angular noise power spectrum N_ell of a detector pair.

    Evaluates

        1 / N_ell = t_obs / 2 * int df  sum_m |gamma_lm(f)|^2
                    * (2/5 * S(f) / N_f(f))^2 / (2 ell + 1),

    with the power-law spectral shape S(f) = (f / f_ref)**spectral_index and
    the frequency integral done with the trapezoidal rule on ``freqs``.

    Args:
        overlap_lm: spherical-harmonic overlap coefficients. Only index 0 of
            the leading axis is used; the remaining axes are treated as
            (freq, ell, m) by the broadcasting below — TODO confirm at callers.
        N_f: noise PSD sampled on ``freqs``.
        ell_array: multipoles ell, matching the ell axis of ``overlap_lm``.
        freqs: frequency grid for the integration.
        f_ref: reference frequency of the power-law spectral shape.
        spectral_index: power-law index of the spectral shape.
        t_obs: observation time in the prefactor (units as supplied by caller).

    Returns:
        Real array of N_ell, one entry per element of ``ell_array``.
    """
    spectral_shape = jnp.power(freqs / f_ref, spectral_index)
    overlap_init = overlap_lm[0]
    # sum_m |gamma_lm|^2, computed as a real quantity so the result does not
    # carry a spurious (exactly zero) imaginary part.
    overlap_sq = jnp.sum(overlap_init.real**2 + overlap_init.imag**2, axis=-1)
    weight = (2.0 / 5.0 * spectral_shape / N_f) ** 2
    integrand = overlap_sq * weight[:, None] / (2 * ell_array + 1)
    # jnp.trapz is deprecated/removed in recent JAX; prefer jnp.trapezoid.
    trapezoid = getattr(jnp, "trapezoid", None) or jnp.trapz
    noise_ell_inv = t_obs / 2 * trapezoid(integrand, freqs, axis=0)
    return 1.0 / noise_ell_inv
# Angular noise spectrum of the H-L pair for three spectral indices alpha_I.
ell_array = jnp.arange(0, l_max + 1)
f_ref = 63.0
t_obs_noise = 4.0
print(f"{ell_array=}")
print(f"{freqs=}")
print(f"{f_ref=}")
# (The former print of `spectral_index` referenced a name not defined in this
# cell and failed on a fresh kernel; the indices are listed explicitly below.)
print(f"{t_obs_noise=}")

HL_N_ell_by_alpha = {
    alpha: angular_noise_ell(
        gamma_lm_HL_equatorial,
        noise_interp,
        ell_array,
        freqs,
        f_ref,
        spectral_index=alpha,
        t_obs=t_obs_noise,
    )
    for alpha in (-2.3, 0, -3)
}
HL_N_ell_alpha_m2p3 = HL_N_ell_by_alpha[-2.3]
HL_N_ell_alpha_0 = HL_N_ell_by_alpha[0]
HL_N_ell_alpha_m3 = HL_N_ell_by_alpha[-3]

In [None]:
# N_ell (ell + 1/2) for each spectral index; ell = 0 is dropped because the
# x axis is logarithmic.
for HL_N_ell_curve, curve_label in (
    (HL_N_ell_alpha_m2p3, r"$\alpha_I=-2.3$"),
    (HL_N_ell_alpha_0, r"$\alpha_I=0$"),
    (HL_N_ell_alpha_m3, r"$\alpha_I=-3$"),
):
    plt.plot(ell_array[1:], (jnp.abs(HL_N_ell_curve)*(ell_array+1/2))[1:], label=curve_label)
plt.xlabel(r"$\ell$")
plt.ylabel(r"$N_\ell\;(\ell + 1/2)$")
plt.xscale("log")
plt.yscale("log")
plt.legend()
plt.show()

In [None]:
import matplotlib.pyplot as plt

sphere_radius = 10

fig = plt.figure(figsize=(8, 10))
ax = fig.add_subplot(projection='3d')
# Reference sphere of constant radius drawn over the sky grid.
plot_response(
    sphere_radius * jnp.ones((N_theta*N_phi)),
    ecl_thetas_reduced,
    ecl_phis_reduced,
    ax,
)

# Initial positions of the two LIGO sites, projected onto the sphere.
for site_lat, site_lon, site_color, site_label in (
    (H_lat, H_lon, 'red', 'Initial position of LIGO Hanford'),
    (L_lat, L_lon, 'green', 'Initial position of LIGO Livingston'),
):
    plot_x, plot_y, plot_z = sphere_radius * lat_lon_to_cartesian(site_lat, site_lon)
    print(plot_x, plot_y, plot_z)
    ax.scatter(
        plot_x, plot_y, plot_z,
        color=site_color,
        marker='o',
        s=10,
        label=site_label,
    )

# Polar axis: arrow from the origin to the north pole.
ax.quiver(0, 0, 0, 0, 0, sphere_radius, color='black', label='Polar axis')

# Equator: circle of the sphere's radius in the z = 0 plane.
ax.plot(
    sphere_radius * jnp.cos(ecl_phis_reduced),
    sphere_radius * jnp.sin(ecl_phis_reduced),
    0,
    color='black',
    linestyle='--',
    label='Equatorial plane',
)

# Prime meridian: 100 points of zero longitude from latitude -pi/2 to +pi/2.
points_meridian = sphere_radius * jnp.array([
    lat_lon_to_cartesian(meridian_lat, 0)
    for meridian_lat in jnp.linspace(-jnp.pi/2, +jnp.pi/2, 100)
])
ax.plot(
    points_meridian[:, 0],
    points_meridian[:, 1],
    points_meridian[:, 2],
    color='black',
    linestyle='--',
    label='Greenwich meridian',
)

ax.legend(loc='upper left')
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_zlabel('z')

plt.tight_layout()
plt.show()

# Small-Antenna Limit

In [None]:
# Small-antenna response: half of (antenna patterns 0 + 1) minus (2 + 3)
# along axis 1 of `antennae`.
response_sa = 0.5 * (antennae[:, 0] + antennae[:, 1] - antennae[:, 2] - antennae[:, 3])
print(response_sa.shape)
# Last axis: index 0 -> plus, index 1 -> cross polarization.
response_plus_abs, response_cross_abs = (jnp.abs(response_sa[..., pol]) for pol in (0, 1))
unpolarized_response_abs = jnp.sqrt(response_plus_abs**2 + response_cross_abs**2)

In [None]:
# 3D surface plots of |R_+|, |R_x| and the unpolarized response, one panel each.
fig = plt.figure(figsize=(15,5))
axes = [fig.add_subplot(131 + panel, projection='3d') for panel in range(3)]
ax1, ax2, ax3 = axes

responses = [
    response_plus_abs[:,0],
    response_cross_abs[:,0],
    unpolarized_response_abs[:,0],
]

for panel_ax, panel_response in zip(axes, responses):
    plot_response(panel_response, ecl_thetas_reduced, ecl_phis_reduced, panel_ax)

plt.show()

In [None]:
# u: azimuthal angle phi of spherical coordinates, from 0 to 2pi.
# v: polar angle theta of spherical coordinates, from 0 to pi.
u, v = ecl_phis_reduced, ecl_thetas_reduced
print(u.min(), u.max(), v.min(), v.max())
for grid_values, grid_name in ((u, 'u'), (v, 'v')):
    plt.plot(grid_values, label=grid_name)
plt.legend()
plt.show()

In [None]:
fig = plt.figure()
ax = fig.add_subplot(projection='3d')

# Analytic small-antenna response on the (theta = v rows, phi = u columns) grid:
#   (cos^2 u - sin^2 u) * (cos^2 v + 1)
cos_u = jnp.cos(u[None,:])
sin_u = jnp.sin(u[None,:])
plotted_response = (cos_u**2 - sin_u**2) * (jnp.cos(v[:,None])**2 + 1)
plot_response(plotted_response, ecl_thetas_reduced, ecl_phis_reduced, ax, unflatten=False)

plt.show()

In [None]:
# Small-antenna response of the H detector: half of (antenna patterns 0 + 1)
# minus (2 + 3) along axis 1 of `antennae_H`.
response_H_sa = 0.5 * (antennae_H[:, 0] + antennae_H[:, 1] - antennae_H[:, 2] - antennae_H[:, 3])
print(response_H_sa.shape)
# Last axis: index 0 -> plus, index 1 -> cross polarization.
response_plus_abs, response_cross_abs = (jnp.abs(response_H_sa[..., pol]) for pol in (0, 1))
unpolarized_response_abs = jnp.sqrt(response_plus_abs**2 + response_cross_abs**2)

In [None]:
# 3D surface plots of |R_+|, |R_x| and the unpolarized response of H.
fig = plt.figure(figsize=(15,5))
axes = [fig.add_subplot(131 + panel, projection='3d') for panel in range(3)]
ax1, ax2, ax3 = axes

responses = [
    response_plus_abs[:,0],
    response_cross_abs[:,0],
    unpolarized_response_abs[:,0],
]

for panel_ax, panel_response in zip(axes, responses):
    plot_response(panel_response, ecl_thetas_reduced, ecl_phis_reduced, panel_ax)

plt.show()

In [None]:
fig = plt.figure()
ax = fig.add_subplot(projection='3d')

# Analytic small-antenna response for comparison with H:
#   (cos^2 u - sin^2 u) * (cos^2 v + 1) on the (theta = v, phi = u) grid.
# NOTE(review): this is the same expression as for the x/y-aligned detector;
# it only matches H in H's local frame — confirm intent.
cos_u = jnp.cos(u[None,:])
sin_u = jnp.sin(u[None,:])
plotted_response = (cos_u**2 - sin_u**2) * (jnp.cos(v[:,None])**2 + 1)
plot_response(plotted_response, ecl_thetas_reduced, ecl_phis_reduced, ax, unflatten=False)

plt.show()

# Ground-Based One-Way Arm

In [None]:
# One-way path: light travels from receiver 0 to receiver 1 only.
path_1 = jnp.array([0, 1,])
paths = jnp.stack([path_1], axis=0)

path_response_output = get_path_response(
    paths,
    freqs,
    arm_lengths,
    response,
)
# get_path_response returns (path_responses, cumulative_separations) in the
# version used by the BBO section below; older versions returned the array
# alone. Indexing the tuple with [0] would give ALL paths, not path 0.
if isinstance(path_response_output, tuple):
    path_responses = path_response_output[0]
else:
    path_responses = path_response_output

# Response of the single path (leading axis of path_responses indexes paths).
oneway_response = path_responses[0]

# Last axis index 0 is the plus polarization.
response_plus_abs = jnp.abs(oneway_response[..., 0])
response_plus_abs.shape

In [None]:
# Inspect the one-way response shape (leftover debug print of pi/2 removed).
oneway_response.shape

In [None]:
# Cross-check of the numerical one-way response against analytic estimates
# (antenna pattern times a sinc, and antenna pattern times full_transfer) along
# a slice of sky pixels at a single frequency.
import numpy as np  # NOTE(review): not used in this cell

# Slice of N_phi consecutive flat sky pixels centred on the middle of the grid.
# Presumably one constant-theta ring (assumes flat index = i_theta * N_phi + i_phi
# and N_phi even, since the slice is plotted against ecl_phis_reduced) — TODO confirm.
sky_idx_mid = N_sky//2
sky_idx_min = sky_idx_mid-N_phi//2
sky_idx_max = sky_idx_mid+N_phi//2


# Print (theta, phi) at the slice ends. sky_idx_max is an exclusive slice bound,
# so the second line is the pixel just past the slice.
for idx in sky_idx_min, sky_idx_max:
    print(ecl_thetas[idx], ecl_phis[idx])
# Fixed frequency index; assumes freqs has more than 40 entries.
freq_idx = 40
# Numerical plus-polarization response (last index 0) for the first entry of
# the leading axis: on the slice at freq_idx, and at the middle pixel for all
# frequencies (r_f_from_y is not used in this cell).
r_fidx_from_xy = oneway_response[0,sky_idx_min:sky_idx_max,freq_idx,0]
r_f_from_y = oneway_response[0,sky_idx_mid,:,0]
plt.plot(
    ecl_phis_reduced,
    jnp.abs(r_fidx_from_xy), 
    linewidth=3,
    label='response_plus_abs_numerical')
# Antenna patterns of arm direction arm_directions[0,0] for the slice's
# sky directions.
enc_thetas_idx_minmax = ecl_thetas[sky_idx_min:sky_idx_max]
enc_phis_idx_minmax = ecl_phis[sky_idx_min:sky_idx_max]
_, u_hat_idx_minmax, v_hat_idx_minmax = get_directional_basis(enc_thetas_idx_minmax, enc_phis_idx_minmax)
antennae_idx_minmax = sky_vmapped_antenna_pattern(u_hat_idx_minmax, v_hat_idx_minmax, arm_directions[0,0])
print(arm_directions.shape)
print(antennae_idx_minmax.shape)
antennae_idx_minmax_plus = antennae_idx_minmax[:,0]
antennae_idx_minmax_cross = antennae_idx_minmax[:,1]
# x-component of the unit vector at (theta, phi): sin(theta) cos(phi).
n_x_xy = jnp.sin(enc_thetas_idx_minmax) * jnp.cos(enc_phis_idx_minmax)

# Only the x-component of the arm direction enters the dot product, i.e. this
# assumes the arm lies along x — TODO confirm.
u_dot_n_from_xy = arm_directions[0,0,0]*n_x_xy

# jnp.sinc is the normalized sinc, sin(pi x) / (pi x).
sinc_oneway_from_xy = jnp.sinc(freqs[freq_idx]*arm_lengths[0][0].item()/c_in_AU_per_s*(1+u_dot_n_from_xy))

print(sinc_oneway_from_xy.shape)
plt.plot(
    ecl_phis_reduced, 
    0.5*jnp.abs(antennae_idx_minmax_plus*sinc_oneway_from_xy), 
    linestyle='--',
    label='response_plus_abs_analytic')

# Same estimate with full_transfer (computed elsewhere) in place of the sinc.
print(full_transfer.shape)
transfer_idx_from_xy = full_transfer[sky_idx_min:sky_idx_max,freq_idx,0,0]



plt.plot(
    ecl_phis_reduced,
    0.5*jnp.abs(antennae_idx_minmax_plus*transfer_idx_from_xy),
    linestyle=':',
    label='antenna_pattern*transfer_function')
plt.legend()
plt.show()

In [None]:
# Frequency dependence at one pixel, 8 theta-rows past sky_idx_min in the same
# phi column (assuming flat index = i_theta * N_phi + i_phi — TODO confirm).
sky_idx_from_x = sky_idx_min + 8 * N_phi
transfer_idx_from_y = full_transfer[sky_idx_from_x,:,0,0]

# Analytic one-way sinc factor using the x-components of the direction and arm.
pixel_theta = ecl_thetas[sky_idx_from_x]
pixel_phi = ecl_phis[sky_idx_from_x]
n_x_y = jnp.sin(pixel_theta) * jnp.cos(pixel_phi)
u_dot_n_from_y = arm_directions[0,0,0]*n_x_y
sinc_oneway_from_y = jnp.sinc(freqs*arm_lengths[0][0].item()/c_in_AU_per_s*(1+u_dot_n_from_y))

plt.plot(freqs, jnp.abs(sinc_oneway_from_y), linewidth=3)
plt.plot(freqs, jnp.abs(transfer_idx_from_y), linestyle='--')
plt.show()

In [None]:
# Characteristic frequency f_* = c / L, with L taken as |x-component| of the
# 0 -> 1 separation at the first time step (presumably the arm lies along x).
f_star = c_in_AU_per_s/ jnp.abs(separations[0,0,1,0])
print(f_star)
freq_ratio = freqs/f_star
plt.plot(freq_ratio, response_plus_abs[0,0,:], linewidth=4, label='response_plus_abs_numerical')
# One-way analytic response at normal incidence: |sinc(f / f_*)| / 2.
plt.plot(freq_ratio, 0.5*jnp.abs(jnp.sinc(freqs/f_star)), label='response_plus_abs_analytic', linestyle='dashed')
plt.xscale('log')
plt.grid(True)
plt.legend()
plt.show()

# Ground-Based Two-Way Arm

In [None]:
# Two-way (round-trip) path: receiver 0 -> 1 -> 0.
path_1 = jnp.array([0, 1, 0])
paths = jnp.stack([path_1], axis=0)

path_response_output = get_path_response(
    paths,
    freqs,
    arm_lengths,
    response,
)
# get_path_response returns (path_responses, cumulative_separations) in the
# version used by the BBO section below; older versions returned the array
# alone. Indexing the tuple with [0] would give ALL paths, not path 0.
if isinstance(path_response_output, tuple):
    path_responses = path_response_output[0]
else:
    path_responses = path_response_output

# Response of the single path (leading axis of path_responses indexes paths).
twoway_response = path_responses[0]

# Last axis index 0 is the plus polarization.
response_plus_abs = jnp.abs(twoway_response[..., 0])
response_plus_abs.shape

In [None]:
# Round-trip arm: the response scales with 2f/f_*, f_* = c / L.
f_star = c_in_AU_per_s / jnp.abs(separations[0,0,1,0])
print(f_star)
twoway_axis = freqs/(f_star/2)
plt.title('Response of a round-trip arm to normal incidence')
plt.plot(twoway_axis, response_plus_abs[0,0,:], label='2-way numerical', linestyle='--', linewidth=3)
plt.plot(twoway_axis, jnp.abs(jnp.sinc(2*freqs/f_star)), label='2-way analytic')
plt.xlabel(r'$2f/f_*$')
plt.ylabel(r'$|R_+|$')
plt.xscale('log')
plt.grid(True)
plt.legend()
plt.show()

In [None]:
fig = plt.figure()
ax = fig.add_subplot(projection='3d')

# |R_+| over the sky for the first entries of the leading and frequency axes.
plot_response(
    response_plus_abs[0,:,0],
    ecl_thetas_reduced,
    ecl_phis_reduced,
    ax,
)
plotted_response = response_plus_abs[0,:,0]

plt.show()

# BBO

In [None]:
import jax
import jax.numpy as jnp

from jax_gw.detector.orbits import (
    create_cartwheel_orbit, 
    get_vertex_angle,
    get_receiver_positions, 
    get_separations, 
    get_arm_lengths, 
    flatten_pairs,
)

from jax_gw.detector.pixel import (
    get_directional_basis,
    flat_to_matrix_sky_indices,
    get_sph_harm_values,
    pixel_to_lm,
)

from jax_gw.detector.response import (
    create_cyclic_permutation_paths,
    response_pipe,
    get_path_response,
    get_pairwise_differential_strain_response,
)

from jax_gw.detector.overlap import (
    unpolarized_cross_overlap,
    overlap_angular_noise_ell,
)

import matplotlib.pyplot as plt

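# Inputs

**Orbits**

- `N_times`: number of time samples in `[0, t_obs]`
- `N_freqs`: number of log-spaced frequencies between `f_min` and `f_max`
- `t_obs`: time span covered by the orbit samples
- `L_target`: target arm length, in units of $10^9$ m
- `R_target`: orbital radius of the constellation (AU, judging by the eccentricity formula)
- `N_vertices`: number of spacecraft per constellation (`N` in the code)
- `N_triplets`: number of constellations (two here: `orbits_1` and `orbits_2`, offset by `timeshift=pi`)

**Response**

- `N_theta`: number of polar-angle grid points
- `N_phi`: number of azimuthal grid points
- `path_base`: base laser path, cyclically permuted over the vertices (`path_1` in the code)
- `path_synthesis_type`: how path responses are combined (here pairwise differences, i.e. Michelson)
- `l_max`: maximum multipole of the spherical-harmonic expansion

**Noise**

- `delta_x_sq`: position-noise PSD [m$^2$/Hz]
- `delta_a_sq`: acceleration-noise PSD (enters as `delta_a_sq / (2 pi f)^4`)
- `f_ref`: reference frequency of the power-law spectral shape $(f/f_\mathrm{ref})^{\alpha_I}$
- `t_obs_noise`: observation time used in the angular-noise integral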

In [None]:
# Orbit time sampling
N_times = 4
t_obs = 3.16e-5
times = jnp.linspace(0, t_obs, N_times)

# Log-spaced frequency grid [Hz]
N_freqs = 30
f_min = 1E-2
f_max = 1E1
freqs = jnp.logspace(jnp.log10(f_min), jnp.log10(f_max), N_freqs)

# Constellation: arm length L_target (units of 1e9 m), radius R_target, N vertices
L_target = 0.05
R_target = 1.0
N = 3

# Sky pixelization and maximum multipole
N_theta = 100
N_phi = 120
l_max = 10

# Base path of the synthesized signal: round trip 0 -> 1 -> 0
path_1 = jnp.array([0, 1, 0,])

# Noise model
f_ref = 1E-1
delta_x_sq = 2E-34 # m^2/Hz
delta_a_sq = 9E-34
t_obs_noise = 4.0

In [None]:
N_sky = N_theta * N_phi

# Polar grid: N_theta points strictly inside (0, pi), offset 1/N_theta from the poles.
ecl_thetas_reduced = jnp.linspace(1/N_theta, jnp.pi-1/N_theta, N_theta)
# Azimuthal grid: N_phi equally spaced points in [0, 2pi); the slice guards
# against arange yielding an extra point from floating-point rounding.
delta_phi = 2 * jnp.pi / N_phi
ecl_phis_reduced = jnp.arange(0, 2 * jnp.pi, delta_phi)[:N_phi]
print(max(ecl_phis_reduced), min(ecl_phis_reduced))

# Expand to one (theta, phi) pair per flat sky pixel and build the basis there.
flat_to_m_sky = flat_to_matrix_sky_indices(N_theta, N_phi)
theta_rows, phi_cols = flat_to_m_sky[:,0], flat_to_m_sky[:,1]
ecl_thetas = ecl_thetas_reduced[theta_rows]
ecl_phis = ecl_phis_reduced[phi_cols]
sky_basis = get_directional_basis(ecl_thetas, ecl_phis)

In [None]:
# 1 AU expressed in units of 1e9 m.
AU_per_billion_meters = 149.597871

# Cartwheel eccentricity giving arm length L_target (1e9 m) at radius R_target
# (AU): L = 2 sqrt(3) e R.
ecc = L_target / (AU_per_billion_meters * 2 * jnp.sqrt(3) * R_target)
ecc = ecc.item()

# Constellation 1.
orbits_1 = create_cartwheel_orbit(ecc, R_target, N, times, timeshift=0)
receiver_orbits = get_receiver_positions(orbits_1)
receiver_positions_1 = flatten_pairs(receiver_orbits)
separations = get_separations(orbits_1)
arms = flatten_pairs(separations)
arm_lengths_1 = get_arm_lengths(arms)

# Constellation 2, shifted by pi along the orbit. Its receiver positions must
# come from orbits_2 (they were previously copied from constellation 1).
orbits_2 = create_cartwheel_orbit(ecc, R_target, N, times, timeshift=jnp.pi)
receiver_orbits_2 = get_receiver_positions(orbits_2)
receiver_positions_2 = flatten_pairs(receiver_orbits_2)
separations = get_separations(orbits_2)
arms = flatten_pairs(separations)
arm_lengths_2 = get_arm_lengths(arms)

In [None]:
# Responses and antenna patterns of both constellations on the shared sky basis.
response_1, antennae_1 = response_pipe(orbits_1, freqs, sky_basis=sky_basis)
print(response_1.shape)
response_2, antennae_2 = response_pipe(orbits_2, freqs, sky_basis=sky_basis)
print(response_2.shape)

In [None]:

# Cyclic permutations of the base path over the N vertices.
paths_triplet = create_cyclic_permutation_paths(
    path_1,
    N,
)
print(paths_triplet)


In [None]:
# Path responses of constellation 1; shapes
# (N_path, N_times, N_sky, N_freq, N_pol,), (N_times, N_path, N_depth).
path_responses_1, cumul_path_separations_1 = get_path_response(paths_triplet, freqs, arm_lengths_1, response_1)
print(path_responses_1.shape, cumul_path_separations_1.shape)

# Pairwise differences of the path responses (Michelson combinations).
michelson_1 = get_pairwise_differential_strain_response(
    path_response=path_responses_1,
    cumul_path_separations=cumul_path_separations_1,
)
print(michelson_1.shape)

In [None]:
# Same construction for constellation 2.
path_responses_2, cumul_path_separations_2 = get_path_response(paths_triplet, freqs, arm_lengths_2, response_2)

michelson_2 = get_pairwise_differential_strain_response(
    path_response=path_responses_2,
    cumul_path_separations=cumul_path_separations_2,
)

In [None]:
# A_channel_direction = 1/3 * jnp.array([2, -1, -1])
# A_channel_1 = jnp.einsum('i...,i...->...', A_channel_direction, michelson_1)
# A_channel_2 = jnp.einsum('i...,i...->...', A_channel_direction, michelson_2)
# A_channel_1.shape, A_channel_2.shape

In [None]:
# First Michelson channel of each constellation. New names are used so the
# full arrays stay intact and re-running this cell does not index them again.
michelson_X_1 = michelson_1[0]
michelson_X_2 = michelson_2[0]
anisotropic_ORF_BBO_12 = unpolarized_cross_overlap(michelson_X_1, michelson_X_2)

In [None]:
# Spherical-harmonic values up to l_max on the reduced (theta, phi) grid.
sph_harm_values = get_sph_harm_values(
    l_max,
    ecl_thetas_reduced,
    ecl_phis_reduced,
)

In [None]:
# Vertex angle of constellation 1 (enters the isotropic-ORF normalization).
vertex_angle = get_vertex_angle(
    orbits_1
)

In [None]:
vertex_angle = get_vertex_angle(orbits_1)
print(f"{vertex_angle=} rad, {vertex_angle*180/jnp.pi=} deg")

# Spherical-harmonic coefficients of the anisotropic overlap
# (second positional argument presumably selects the sky axis — confirm).
gamma_BBO_12_lm = pixel_to_lm(
    anisotropic_ORF_BBO_12,
    1,
    N_theta,
    N_phi,
    ecl_thetas=ecl_thetas,
    ecl_phis=ecl_phis,
    sph_harm_values=sph_harm_values,
)

# Monopole -> isotropic ORF:
#   gamma_iso = 5 / sin^2(vertex angle) * gamma_00 / (4 pi Y_00)
gamma_BBO_12_00 = gamma_BBO_12_lm[0, :, 0, 0]
Y_00 = sph_harm_values[0, 0, 0]
normalization = 5 / (jnp.sin(vertex_angle)) ** 2
gamma_iso_from_aniso = normalization * (1 / (4 * jnp.pi * Y_00)) * gamma_BBO_12_00
print(gamma_iso_from_aniso.shape)

In [None]:
print(gamma_iso_from_aniso[0])
for part_values, part_label in (
    (gamma_iso_from_aniso.real, "real"),
    (gamma_iso_from_aniso.imag, "imag"),
):
    plt.plot(freqs, part_values, label=part_label)
plt.xlabel("Frequency [Hz]")
plt.ylabel("Isotropic ORF")
plt.xscale("log")
plt.legend()
plt.show()

In [None]:
# BBO noise PSD: position noise plus acceleration noise converted to
# displacement by 1 / (2 pi f)^4, normalized by the arm length in meters.
L_SI = L_target * 1E9
acceleration_term = delta_a_sq / (2 * jnp.pi * freqs)**4
noise_psd_bbo = 4 / (L_SI)**2 * (delta_x_sq + acceleration_term)

In [None]:
fig_noise, ax_noise = plt.subplots()
ax_noise.plot(freqs, noise_psd_bbo)
ax_noise.set_xscale('log')
ax_noise.set_yscale('log')
ax_noise.set_xlabel('Frequency [Hz]')
ax_noise.set_ylabel('Noise PSD BBO [1/Hz]')
plt.show()

In [None]:
# Multipoles 0..l_max, and a dump of the settings entering N_ell.
ell_array = jnp.arange(0, l_max + 1)
for setting_line in (f"{ell_array=}", f"{freqs=}", f"{f_ref=}", f"{t_obs_noise=}"):
    print(setting_line)

In [None]:
# Angular noise spectrum N_ell of the BBO pair for several spectral indices.
index_start = 0

spectral_indices = [-2.3, 0, -3]
BBO_N_ell_alphas = jnp.zeros((len(spectral_indices), l_max+1))
print(f"{BBO_N_ell_alphas.shape=}")
for idx, spectral_index in enumerate(spectral_indices):
    BBO_N_ell_alpha = overlap_angular_noise_ell(
        gamma_BBO_12_lm,
        noise_psd_bbo,
        ell_array,
        freqs,
        f_ref,
        spectral_index=spectral_index,
        t_obs=t_obs_noise)
    BBO_N_ell_alphas = BBO_N_ell_alphas.at[idx].set(BBO_N_ell_alpha)

plt.figure(figsize=(8, 6))
plot_x = ell_array[index_start:]
for idx, spectral_index in enumerate(spectral_indices):
    BBO_N_ell_alpha = BBO_N_ell_alphas[idx]
    plot_y = (jnp.abs(BBO_N_ell_alpha)*(ell_array+1/2))[index_start:]
    plt.plot(plot_x, plot_y, label=f"$\\alpha_I={spectral_index}$")
plt.xlabel(r"$\ell$")
# Fixed label typo: "BBO" (letter O), not "BB0" (zero).
plt.ylabel(r"$\mathrm{BBO}\;\;N_\ell\;(\ell + 1/2)$")
plt.yscale("log")
plt.legend()
plt.show()

# BBO pipeline

In [None]:
from jax_gw.pipes.N_ell import get_N_ell_BBO

# End-to-end BBO pipeline; row 0 is plotted below as alpha_I = -2.3.
l_max = 10
pipeline_settings = dict(
    N_times=4,
    N_freqs=64,
    N_theta=300,
    N_phi=40,
    l_max=l_max,
    t_obs=3.16e-5,
    spectral_indices=[-2.3, 0, -3],
)
N_ell_BBO = get_N_ell_BBO(**pipeline_settings)

N_ell_BBO.shape

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Multipole axis 0..l_max and the alpha_I = -2.3 curve (row 0).
l_array = np.arange(0, l_max+1)
scaled_N_ell = N_ell_BBO[0]*(l_array+0.5)
plt.plot(l_array, scaled_N_ell, label=r"$\alpha_I=-2.3$")
plt.yscale("log")
plt.xlabel(r"$\ell$")
plt.ylabel(r"$N_\ell\;(\ell + 1/2)\;\;\mathrm{BBO}$")
plt.legend()
plt.show()

In [None]:
def test_N_freq_robust(N_freq_choices=(16, 32, 64), l_max=10):
    """Visual convergence check of the BBO N_ell pipeline in ``N_freqs``.

    Runs ``get_N_ell_BBO`` once per entry of ``N_freq_choices`` (all other
    settings fixed) and overlays N_ell (ell + 1/2) for the first spectral index
    (alpha_I = -2.3). Overlapping curves indicate converged frequency sampling.

    Args:
        N_freq_choices: frequency-grid sizes to compare.
        l_max: maximum multipole passed to the pipeline and used for the x axis.
    """
    import numpy as np

    # Defined locally so the x axis always matches l_max instead of depending
    # on a global defined in another cell.
    l_array = np.arange(0, l_max + 1)
    N_ell_curves = []
    for N_freqs in N_freq_choices:
        print(f"{N_freqs=}")
        N_ell_BBO = get_N_ell_BBO(
            N_times = 4,
            N_freqs = N_freqs,
            N_theta = 300,
            N_phi = 40,
            l_max = l_max,
            t_obs = 3.16e-5,
            spectral_indices = [-2.3, 0, -3],
        )
        N_ell_curves.append(N_ell_BBO[0]*(l_array+0.5))

    # One curve per choice, so any number of choices can be compared.
    for N_freqs, N_ell_curve in zip(N_freq_choices, N_ell_curves):
        plt.plot(l_array, N_ell_curve, label=f"N_freqs={N_freqs}")
    plt.yscale("log")
    plt.xlabel(r"$\ell$")
    plt.ylabel(r"$N_\ell\;(\ell + 1/2)\;\;\mathrm{BBO}$")
    plt.legend()
    plt.show()

test_N_freq_robust()

In [None]:
def test_N_theta_robust(N_theta_choices=(100, 200, 300), l_max=10):
    """Visual convergence check of the BBO N_ell pipeline in ``N_theta``.

    Runs ``get_N_ell_BBO`` once per entry of ``N_theta_choices`` (all other
    settings fixed) and overlays N_ell (ell + 1/2) for the first spectral index
    (alpha_I = -2.3). Overlapping curves indicate a converged polar grid.

    Args:
        N_theta_choices: polar-grid sizes to compare.
        l_max: maximum multipole passed to the pipeline and used for the x axis.
    """
    import numpy as np

    # Defined locally so the x axis always matches l_max instead of depending
    # on a global defined in another cell.
    l_array = np.arange(0, l_max + 1)
    N_ell_curves = []
    for N_theta in N_theta_choices:
        print(f"{N_theta=}")
        N_ell_BBO = get_N_ell_BBO(
            N_times = 4,
            N_freqs = 64,
            N_theta = N_theta,
            N_phi = 40,
            l_max = l_max,
            t_obs = 3.16e-5,
            spectral_indices = [-2.3, 0, -3],
        )
        N_ell_curves.append(N_ell_BBO[0]*(l_array+0.5))

    # One curve per choice, so any number of choices can be compared.
    for N_theta, N_ell_curve in zip(N_theta_choices, N_ell_curves):
        plt.plot(l_array, N_ell_curve, label=f"N_theta={N_theta}")
    plt.yscale("log")
    plt.xlabel(r"$\ell$")
    plt.ylabel(r"$N_\ell\;(\ell + 1/2)\;\;\mathrm{BBO}$")
    plt.legend()
    plt.show()

test_N_theta_robust()

In [None]:
def test_N_phi_robust(N_phi_choices=(10, 15, 20, 40), l_max=10):
    """Visual convergence check of the BBO N_ell pipeline in ``N_phi``.

    Runs ``get_N_ell_BBO`` once per entry of ``N_phi_choices`` (all other
    settings fixed) and overlays N_ell (ell + 1/2) for the first spectral index
    (alpha_I = -2.3). Overlapping curves indicate a converged azimuthal grid.

    Args:
        N_phi_choices: azimuthal-grid sizes to compare.
        l_max: maximum multipole passed to the pipeline and used for the x axis.
    """
    import numpy as np

    # Defined locally so the x axis always matches l_max instead of depending
    # on a global defined in another cell.
    l_array = np.arange(0, l_max + 1)
    N_ell_curves = []
    for N_phi in N_phi_choices:
        print(f"{N_phi=}")
        N_ell_BBO = get_N_ell_BBO(
            N_times = 4,
            N_freqs = 64,
            N_theta = 300,
            N_phi = N_phi,
            l_max = l_max,
            t_obs = 3.16e-5,
            spectral_indices = [-2.3, 0, -3],
        )
        N_ell_curves.append(N_ell_BBO[0]*(l_array+0.5))

    for N_phi, N_ell_curve in zip(N_phi_choices, N_ell_curves):
        plt.plot(l_array, N_ell_curve, label=f"N_phi={N_phi}")
    plt.yscale("log")
    plt.xlabel(r"$\ell$")
    plt.ylabel(r"$N_\ell\;(\ell + 1/2)\;\;\mathrm{BBO}$")
    plt.legend()
    plt.show()

test_N_phi_robust()