In [None]:
from jax_gw.pipes.N_ell import get_N_ell_BBO

import matplotlib.pyplot as plt
import numpy as np

# BBO noise angular power spectrum N_ell, evaluated for several spectral
# indices alpha_I.
# NOTE(review): the unit of t_obs is not visible here (3.16e-5 is not an
# obvious number of seconds or years) — confirm against get_N_ell_BBO.
l_max = 10
spectral_indices = [-2.3, 0, -3]
N_ell_BBO = get_N_ell_BBO(
    N_times = 4,
    N_freqs = 64,
    N_theta = 300,
    N_phi = 40,
    l_max = l_max,
    t_obs = 3.16e-5,
    spectral_indices = spectral_indices,
)

# Plot the curve for the first spectral index. The legend text is derived from
# `spectral_indices` so it cannot drift from the data actually computed.
# Assumes N_ell_BBO[i] holds N_ell for ell = 0..l_max of spectral_indices[i]
# (TODO confirm the axis order returned by get_N_ell_BBO).
l_array = np.arange(0, l_max+1)
plt.plot(l_array, N_ell_BBO[0]*(l_array+0.5), label=rf"$\alpha_I={spectral_indices[0]}$")
plt.yscale("log")
plt.xlabel(r"$\ell$")
plt.ylabel(r"$N_\ell\;(\ell + 1/2)\;\;\mathrm{BBO}$")
plt.legend()
plt.show()

In [None]:
from jax_gw.signal.agwb import *
# NOTE(review): this star import presumably supplies most of the names used in
# this cell and below (os, jax, jc, create_array, ...) and may also rebind
# names defined in earlier cells (e.g. `np`) — replace with explicit imports
# once the required names are pinned down.
parser = parser_with_arguments()
args = parser.parse_args("./src/jax_gw/data/stochastic_GW/ --preBessel --overwriteKernel".split())

# f=1E-1 Hz fit. power-law in frequency up to f_max = 1E1 Hz.
A_max = 0.510579E-36
mean_z = 0.5784058
sigma_z = 0.6766768

samples = np.array([A_max, mean_z, sigma_z])
f_value = None  # set to a single frequency [Hz] to evaluate only that frequency
f_ref = 1E-1    # reference frequency of the fit [Hz]
f_min = 1E-2    # lower edge of the frequency grid [Hz]
f_max = 1E1     # upper edge of the frequency grid [Hz]
verbose = True

# Override the kernel-fit parameters on the parsed arguments with the values above.
args.A_max = samples[0]
args.mean_z = samples[1]
args.sigma_z = samples[2]
nonlinear = 'Halofit'
# Cosmological parameters used to calculate the data, as key/value pairs in the
# form accepted by a Class `set()` call. In this notebook they are only
# referenced by the commented-out Class code; the active pipeline receives
# `nonlinear` directly.
params_cosmo = dict([
    ('output', 'mPk'),
    ('z_pk', '0., 3.0, 7.0, 10.0'),
    ('P_k_max_1/Mpc', '70'),
    ('non linear', nonlinear),
    # ('gauge', 'Newtonian'),  #TODO: commented this as it should be the same. Check!
])

# Fill in default file names for outputs that were not given on the command line.
if not args.output_path:
    kernel_fname = ('data_cl_f_l_GAUSS.fits' if args.overwriteKernel
                    else 'data_cl_f_l_TABLES.fits')
    args.output_path = os.path.join(args.input_dir, kernel_fname)

if not args.bessel_path:
    args.bessel_path = os.path.join(args.input_dir, 'sph_bessel_k_z_l_TEST_gwtools.fits')

# Resolve every file used below to an absolute path. Inputs f/z/A live in
# input_dir under args.<key>_fname; output and bessel use their own paths.
input_dir = os.path.abspath(args.input_dir)
path = {key: os.path.join(input_dir, getattr(args, f'{key}_fname'))
        for key in ('f', 'z', 'A')}
path['output'] = os.path.abspath(args.output_path)
path['bessel'] = os.path.abspath(args.bessel_path)

### NOTE: temporarily commented out; restore when interpolating the kernel from files.

# Check existence of input files and warning before overwriting output
# assert os.path.isfile(path['f']), 'File {} not found!'.format(path['f'])
# assert os.path.isfile(path['z']), 'File {} not found!'.format(path['z'])
# assert os.path.isfile(path['A']), 'File {} not found!'.format(path['A'])
# if os.path.isfile(path['output']) and args.storeCl:
#     print('WARNING! I am going to overwrite a pre-existing data file!')
    


# # Import files
# # initially 71 redshifts and 141 frequencies. Kernel A in erg/cm^3.
# f_vec  = np.genfromtxt(path['f'], delimiter='\t')
# z_vec = np.genfromtxt(path['z'], delimiter='\t')
# A_z_f  = np.genfromtxt(path['A'])
# # Check that the imported arrays have the right dimensions and consistent with input parameters
# assert z_vec.shape+f_vec.shape==A_z_f.shape, 'The dimensions of the imported arrays are wrong!'
# assert args.z_min>=z_vec.min() and args.z_max<=z_vec.max(), 'Check redshift boundaries!'
# # Interpolate the kernel and check that the arguments are z and f in this order
# A_z_f_interp = interp2d(z_vec, f_vec, A_z_f.T, kind='cubic')
# assert z_vec.min()==A_z_f_interp.x_min and z_vec.max()==A_z_f_interp.x_max
# assert f_vec.min()==A_z_f_interp.y_min and f_vec.max()==A_z_f_interp.y_max
# Multipoles at which C_ell is reported (l_vec) and at which it is actually
# computed (l_compute). Unless args.full_ell is set, l_compute is the sparse
# subset returned by make_sparse, and interpolate_cl later maps the result
# back onto l_vec.
l_vec = np.arange(args.l_max+1)
l_compute = (np.arange(args.l_max+1) if args.full_ell
             else jax.jit(make_sparse, static_argnums=(0,))(args.l_max))

### DENSE
# Dense x grid (the same grid is stored in / compared against the Bessel cache).
x_vec = get_x_full(
    ell_max=args.l_max,
    x_min=args.x_min,
    after=args.num_after_max,
    points_pp=args.points_pp,
)

# Log-spaced k grid with args.k_density points per decade between k_min and k_max.
n_decades_k = np.log10(args.k_max) - np.log10(args.k_min)
k_num = int(args.k_density * n_decades_k)
k_vec = create_array(args.k_min, args.k_max, k_num, "log")


### SPARSE
# NOTE(review): the sparse k grid extends to k_max+1 while k_vec stops at k_max
# — presumably padding for later interpolation; confirm this is intended.
k_sparse = create_array(args.k_min, args.k_max+1, args.k_sparse_num, 'log')
z_sparse = create_array(args.z_min, args.z_max, args.z_sparse_num, 'log')

intermediate_grids = get_intermediate_grids(k_vec, x_vec, k_sparse)

# Frequencies [Hz] at which the signal is evaluated: the full grid between
# f_min and f_max, or only f_value when it is set. Always an array, so later
# cells can call f_vec.min()/f_vec.max() and index f_vec[0]/f_vec[-1].
if f_value is None:
    f_vec = create_array(f_min, f_max, args.f_num, args.f_spacing)
else:
    f_vec = np.array([f_value])

# Calculate the matter power spectrum. This is frequency independent.
# This is the only place where we need Class
# cosmo = Class()
# cosmo.set(params_cosmo)
# cosmo.compute()

# Fiducial cosmology (presumably jax_cosmo's Planck 2015 parameters via `jc`).
cosmo = jc.Planck15()

# Effective quantities on the sparse redshift grid; b_eff / deltaM_eff are
# presumably effective bias and matter perturbation (TODO confirm in agwb).
b_eff, deltaM_eff, assorted_grids = get_cosmo_eff(
    cosmo, z_sparse, intermediate_grids, args, nonlinear,
)

# Distance at the redshift mid-points; used in the evaluation of noise.
chi_mid = chi_from_z(cosmo, assorted_grids["z_mid"])
    
    
# Precompute Spherical Bessel Function
def _bessel_cache_is_consistent(bessel_path, x_grid):
    """Return True if a precomputed spherical-Bessel FITS file can be reused.

    The file must exist at `bessel_path`, its 'l' and 'x' data must be
    readable, and its stored x grid must match `x_grid` in size, minimum and
    maximum. Explicit checks are used instead of `assert` so the validation is
    not silently skipped when Python runs with -O (which strips asserts).
    """
    if not os.path.isfile(bessel_path):
        return False
    # TODO: re-implement in jax the check that the stored l values cover l_compute.
    read_data_from_fits(bessel_path, 'l')
    cached_x = read_data_from_fits(bessel_path, 'x')
    return bool(
        cached_x.size == x_grid.size
        and cached_x.min() == x_grid.min()
        and cached_x.max() == x_grid.max()
    )


if verbose:
    print("Checking for pre-computed Spherical Bessel")
if _bessel_cache_is_consistent(path['bessel'], x_vec):
    if verbose:
        print("Found Precomputed Bessels")
else:
    print("Could not find consistent precomputed Bessel")
    if args.preBessel:
        print('WARNING! I am going to overwrite the precomputed bessel file!')
        print("Writing Bessel. This might take a while and it might require a lot of memory")

        def after_func(ell):
            # Passed to write_sph_bessel as `after`; never below args.min_after_nu.
            return max(ell, args.min_after_nu)

        write_sph_bessel(path['bessel'], l_compute, x_vec=x_vec,
                         before=args.num_before_nu, after=after_func)
        print('Finished writing Bessel')
    else:
        path['bessel'] = None
        print('Not going to use precomputed Bessel Function')
        print("Use --preBessel to store and use spherical Bessel Functions for these k and z vectors")

if verbose:
    print("Recovering Bessel")
# path['bessel'] is None here when no usable cache exists and --preBessel was
# not given; get_bessel_x_l presumably evaluates the functions directly then.
bessel_x_l = get_bessel_x_l(path['bessel'], l_compute)[None,...]

not_chi_mask_nonzero = (~assorted_grids["chi_mask"]).nonzero()

In [None]:
def cl_broadband_from_grids(params, f_vec, f_ref, args, f_slope=2/3):
    """Broadband angular power spectrum (clustering + shot noise) of the signal.

    The kernel is evaluated once at `f_ref` and extrapolated across the band
    [f_vec[0], f_vec[-1]] assuming it scales as (f / f_ref)**f_slope. Both the
    effective and sparse kernels are multiplied by the band weight
    integral of (f / f_ref)**f_slope * f df over that band.

    Parameters
    ----------
    params : array-like
        Kernel-fit parameters passed straight to compute_kernel_on_grid
        (in this notebook: A_max, mean_z, sigma_z).
    f_vec : sequence of float
        Frequency grid in Hz. Only f_vec[0] and f_vec[-1] are used, as the
        lower and upper band edges (i.e. f_vec is assumed sorted ascending).
    f_ref : float
        Frequency [Hz] at which the kernel is evaluated.
    args : argparse.Namespace
        Run configuration; args.n_G is used here, and args is also forwarded
        to compute_kernel_on_grid.
    f_slope : float, optional
        Power-law index of the kernel in frequency. Defaults to 2/3, the value
        previously hard-coded. f_slope == -2 uses the logarithmic limit.

    Returns
    -------
    data : array of shape (1, len(l_vec))
        Clustering C_ell interpolated onto l_vec, plus the shot-noise level
        (added to every ell).

    Notes
    -----
    Reads module-level objects built in earlier cells: cosmo, assorted_grids,
    not_chi_mask_nonzero, chi_mid, b_eff, deltaM_eff, bessel_x_l,
    intermediate_grids, k_vec, l_compute and l_vec.
    """
    nf = 0      # index of the single (broad) frequency bin
    f_len = 1
    clustering = np.zeros((f_len, len(l_vec)))
    noise = np.zeros(f_len)
    data = np.zeros((f_len, len(l_vec)))

    A_eff, A_sparse = compute_kernel_on_grid(
                                        params,
                                        freq=f_ref,
                                        assorted_grids=assorted_grids,
                                        args=args,
                                        not_chi_mask_nonzero=not_chi_mask_nonzero,
                                        A_kernel_interp2d=None)

    f_lo, f_hi = f_vec[0], f_vec[-1]
    if f_slope == -2:
        # Limit of the power-law integral: f_ref**2 * integral of df / f.
        broadband = np.log(f_hi / f_lo) / f_ref**f_slope
    else:
        broadband = 1 / (f_slope + 2) / f_ref**f_slope * \
            (f_hi**(f_slope+2) - f_lo**(f_slope+2))
    A_eff = A_eff * broadband
    A_sparse = A_sparse * broadband

    noise = noise.at[nf].set(compute_spatial_shot_noise(cosmo,
                                            A_z=A_sparse,
                                            chi_vec=chi_mid,
                                            n_G=args.n_G))
    clustering_l = compute_clustering_cl(cosmo, A_eff, b_eff, deltaM_eff,
                                            bessel_x_l, intermediate_grids["chi_grid"],
                                            k_vec)
    # C_ell is computed on l_compute and interpolated onto the full l_vec.
    clustering = clustering.at[nf].set(
        interpolate_cl(
            clustering_l,
            l_compute,
            l_vec,
    ))

    data = data.at[nf].set(clustering[nf,:] + noise[nf])
    return data

In [None]:
# Band edges of the frequency grid [Hz].
f_band = (f_vec.min(), f_vec.max())
f_band

In [None]:
# Kernel-fit parameters, in the order (A_max, mean_z, sigma_z).
kernel_params = (args.A_max, args.mean_z, args.sigma_z)
params = np.array(kernel_params)
cl_broad = cl_broadband_from_grids(params, f_vec, f_ref, args)
print(cl_broad.shape)

In [None]:
import matplotlib.pyplot as plt

# Broadband C_ell weighted by (ell + 1/2), for ell = 0..l_max_plot.
# NOTE: cl_broad includes the shot-noise term as well as clustering.
l_max_plot = 10
l_plot = np.arange(l_max_plot+1)

fig, ax = plt.subplots()
ax.plot(
    l_plot,
    cl_broad[0,:l_max_plot+1]*(l_plot+0.5),
    label=f"f={f_vec[0]:.3f}-{f_vec[-1]:.3f}Hz",
)
ax.set_yscale("log")
ax.set_xlabel(r"$\ell$")
ax.set_ylabel(r"$\mathrm{Broadband}\;\;C_\ell^{cluster}\;(\ell + 1/2)\;\;\mathrm{BBO}$")
ax.legend()
plt.show()


In [None]:
def cl_narrowband_from_grids(params, f, args):
    """Angular power spectrum (clustering + shot noise) at a single frequency.

    The kernel returned by compute_kernel_on_grid at frequency `f` is weighted
    by `f` before computing the shot-noise level and the clustering C_ell.

    Parameters
    ----------
    params : array-like
        Kernel-fit parameters forwarded to compute_kernel_on_grid.
    f : float
        Frequency [Hz] at which the spectrum is evaluated.
    args : argparse.Namespace
        Run configuration; args.n_G is used here and args is forwarded.

    Returns
    -------
    data : array of shape (1, len(l_vec))
        Clustering C_ell interpolated onto l_vec, plus the shot-noise level
        (added to every ell).

    Notes
    -----
    Reads module-level objects built in earlier cells: cosmo, assorted_grids,
    not_chi_mask_nonzero, chi_mid, b_eff, deltaM_eff, bessel_x_l,
    intermediate_grids, k_vec, l_compute and l_vec.
    """
    n_bins = 1
    row = 0
    clustering = np.zeros((n_bins, len(l_vec)))
    noise = np.zeros(n_bins)
    data = np.zeros((n_bins, len(l_vec)))

    kernel_eff, kernel_sparse = compute_kernel_on_grid(
        params,
        freq=f,
        assorted_grids=assorted_grids,
        args=args,
        not_chi_mask_nonzero=not_chi_mask_nonzero,
        A_kernel_interp2d=None,
    )
    # Weight both kernels by the frequency itself.
    kernel_eff = kernel_eff * f
    kernel_sparse = kernel_sparse * f

    shot_noise = compute_spatial_shot_noise(
        cosmo, A_z=kernel_sparse, chi_vec=chi_mid, n_G=args.n_G,
    )
    noise = noise.at[row].set(shot_noise)

    cl_on_l_compute = compute_clustering_cl(
        cosmo, kernel_eff, b_eff, deltaM_eff,
        bessel_x_l, intermediate_grids["chi_grid"], k_vec,
    )
    # Map C_ell from the (possibly sparse) l_compute onto the full l_vec.
    clustering = clustering.at[row].set(
        interpolate_cl(cl_on_l_compute, l_compute, l_vec)
    )

    data = data.at[row].set(clustering[row, :] + noise[row])
    return data

In [None]:
# Kernel-fit parameters (A_max, mean_z, sigma_z); evaluated only at f_ref.
kernel_params = (args.A_max, args.mean_z, args.sigma_z)
params = np.array(kernel_params)
cl_narrow = cl_narrowband_from_grids(params, f_ref, args)
print(cl_narrow.shape)

In [None]:
import matplotlib.pyplot as plt

# Narrowband C_ell at f_ref weighted by (ell + 1/2), for ell = 0..l_max_plot.
# NOTE: cl_narrow includes the shot-noise term as well as clustering.
l_max_plot = 10
l_plot = np.arange(l_max_plot+1)

fig, ax = plt.subplots()
ax.plot(
    l_plot,
    cl_narrow[0,:l_max_plot+1]*(l_plot+0.5),
    label=f"f_ref={f_ref:.3f}Hz",
)
ax.set_yscale("log")
ax.set_xlabel(r"$\ell$")
ax.set_ylabel(r"$\mathrm{Narrowband}\;\;C_\ell^{cluster}\;(\ell + 1/2)\;\;\mathrm{BBO}$")
ax.legend()
plt.show()

In [None]:
import matplotlib.pyplot as plt

# Narrowband C_ell at f_ref in the ell(ell+1) convention. ell = 0 is skipped:
# its weight ell(ell+1) is zero and cannot be shown on a log axis.
l_min_plot = 1
l_max_plot = 10
l_plot = np.arange(l_min_plot,l_max_plot+1)

fig, ax = plt.subplots()
ax.plot(
    l_plot,
    cl_narrow[0,l_min_plot:l_max_plot+1]*(l_plot)*(l_plot+1),
    label=f"f_ref={f_ref:.3f}Hz",
)
ax.set_yscale("log")
ax.set_xlabel(r"$\ell$")
ax.set_ylabel(r"$\mathrm{Narrowband}\;\;C_\ell^{cluster}\;\ell(\ell + 1)\;\;\mathrm{BBO}$")
ax.legend()
plt.show()

In [None]:
# Hubble constant expressed in Hz (2.27e-18 s^-1 is about 70 km/s/Mpc).
H0_Hz = 2.27e-18
# Conversion factor 2 pi^2 f_ref^3 / (3 H0^2), with units of Hz. Presumably
# from the relation Omega_GW = (2 pi^2 / 3 H0^2) f^3 S_h(f) — confirm.
omega_factor = (2 * np.pi**2 * f_ref**3) / (3*H0_Hz**2)
cl_narrow_omega = 1/omega_factor**2 * cl_narrow  # Omega^2 / Hz^2

In [None]:
import matplotlib.pyplot as plt

# Narrowband C_ell converted with omega_factor, in the ell(ell+1) convention.
l_min_plot = 1
l_max_plot = 10
l_plot = np.arange(l_min_plot,l_max_plot+1)
ell_weight = (l_plot)*(l_plot+1)

fig, ax = plt.subplots()
ax.plot(
    l_plot,
    cl_narrow_omega[0,l_min_plot:l_max_plot+1]*ell_weight,
    label=f"f_ref={f_ref:.3f}Hz",
)
ax.set_yscale("log")
ax.set_xlabel(r"$\ell$")
# TODO: the units in the label below are still marked as uncertain ("??").
ax.set_ylabel(r"$\mathrm{Narrowband}\;\;C_\ell^{cluster}\;\ell(\ell + 1)\;\;\mathrm{BBO}"
              r"\;\; [\Omega_{\mathrm{GW}}^2\;/\;\mathrm{Hz}^2\;{??}]$")
ax.legend()
plt.show()

In [None]:
# Convert the broadband spectrum with the same factor used for the narrowband
# case, i.e. omega_factor evaluated at f_ref.
# NOTE(review): cl_broad is integrated over the whole band, so a single
# f_ref conversion factor is an approximation — confirm this is intended.
inv_omega_factor_sq = 1/omega_factor**2
cl_broad_omega = inv_omega_factor_sq * cl_broad

In [None]:
import matplotlib.pyplot as plt

# Broadband C_ell converted with omega_factor, weighted by (ell + 1/2).
l_min_plot = 0
l_max_plot = 10
l_plot = np.arange(l_min_plot, l_max_plot+1)

fig, ax = plt.subplots()
ax.plot(
    l_plot,
    cl_broad_omega[0,l_min_plot:l_max_plot+1]*(l_plot+0.5),
    label=f"f = {f_vec[0]:.3f}-{f_vec[-1]:.3f}Hz",
)
ax.set_yscale("log")
ax.set_xlabel(r"$\ell$")
# TODO: the units in the label below are still marked as uncertain ("??").
ax.set_ylabel(r"$\mathrm{Broadband}\;\;C_\ell^{cluster}\;(\ell + 0.5)\;\;\mathrm{BBO}$"
              r"$\;\; [\Omega_{\mathrm{GW}}^2\mathrm{/sr^2}\;{??}]$")
ax.legend()
plt.show()

In [None]:
import matplotlib.pyplot as plt

# Broadband C_ell rescaled by omega_gw**2 (an assumed Omega_GW amplitude).
omega_gw = 4E-12
cl_broad_omega_new = omega_gw**2/omega_factor**2 * cl_broad

# Legend title built from omega_gw so it always matches the value used above
# (e.g. 4e-12 -> "4\times10^{-12}").
mantissa, exponent = f"{omega_gw:.1e}".split("e")
omega_gw_title = rf"$\Omega_{{\mathrm{{GW}}}}={float(mantissa):g}\times10^{{{int(exponent)}}}$"

l_min_plot = 0
l_max_plot = 10
l_plot = np.arange(l_min_plot, l_max_plot+1)
plt.plot(
    l_plot,
    cl_broad_omega_new[0,l_min_plot:l_max_plot+1]*(l_plot+0.5),
    label=f"f = {f_vec[0]:.3f}-{f_vec[-1]:.3f}Hz")
plt.yscale("log")
plt.xlabel(r"$\ell$")
plt.ylabel(r"$\mathrm{Broadband}\;\;C_\ell^{cluster}\;(\ell + 0.5)\;\;\mathrm{BBO}$"
           r"$\;\; [1\mathrm{/sr^2}\;{??}]$")
plt.legend(title=omega_gw_title)
plt.show()

In [None]:
import matplotlib.pyplot as plt

# Compare the BBO noise N_ell (alpha_I = -2.3 curve) with the broadband
# clustering C_ell, before and after rescaling by omega_gw**2; all curves are
# weighted by (ell + 1/2).
l_min_plot = 0
l_max_plot = 10
l_plot = np.arange(l_min_plot, l_max_plot+1)
# Slice every curve to the same ell range so lengths always match l_plot.
ell_slice = slice(l_min_plot, l_max_plot+1)

# Omega_GW value rendered from omega_gw (e.g. 4e-12 -> "4\times10^{-12}").
mantissa, exponent = f"{omega_gw:.1e}".split("e")
omega_gw_tex = rf"\Omega_{{\mathrm{{GW}}}}={float(mantissa):g}\times10^{{{int(exponent)}}}"

fig, ax = plt.subplots()
ax.plot(l_plot, N_ell_BBO[0][ell_slice]*(l_plot+0.5), label=r"$N_\ell (\alpha_I=-2.3)$")
ax.plot(
    l_plot,
    cl_broad_omega[0, ell_slice]*(l_plot+0.5),
    label=r"$\mathrm{Broadband}\;\;C_\ell^{cluster}\;(\mathrm{no}\;\Omega_{\mathrm{GW}}^2\;\mathrm{factor})$",
)
ax.plot(
    l_plot,
    cl_broad_omega_new[0, ell_slice]*(l_plot+0.5),
    label=rf"$\mathrm{{Broadband}}\;\;C_\ell^{{cluster}}\;({omega_gw_tex})$",
)

ax.set_yscale("log")
ax.set_xlabel(r"$\ell$")
ax.set_ylabel(r"$\mathrm{Broadband}\;\;C_\ell\;(\ell + 0.5)\;\;\mathrm{BBO}$"
              r"$\;\; [1\mathrm{/sr^2}\;{??}]$")
# Frequency band taken from f_vec rather than hard-coded.
ax.legend(title=f"$f = [{f_vec[0]:.3g}, {f_vec[-1]:.3g}]$ Hz")
plt.show()