# CLASS calls we need to replace (with jax_cosmo)

- `HubbleInvMpcNow = cosmo.Hubble(0) # Hubble constant in Mpc-1`
- ```
  z_points = cosmo.get_background()['z']
  chi_points = cosmo.get_background()['comov. dist.']
  ```
  in order to build the comoving-distance function chi(z)
- ```
  matterPk = cosmo.get_pk_array(k_vec, z_vec, 
                                  len(k_vec), len(z_vec), 
                                  nonlinear=isnonlinear)
  ```
- ```
  cosmo = Class()
  cosmo.set(params_cosmo)
  cosmo.compute()
  ```


In [None]:
import jax
# enable 64-bit precision
from jax.config import config
config.update("jax_enable_x64", True)
import jax.numpy as jnp
import jax_cosmo as jc
from jax_cosmo.power import linear_matter_power, nonlinear_matter_power

import matplotlib.pyplot as plt

In [None]:
# Fiducial Planck 2015 cosmology from jax_cosmo; it replaces the CLASS
# `Class()` / `set` / `compute` sequence.
cosmo = jc.Planck15()

In [None]:
# Inspect the attributes stored on the cosmology object.
cosmo.__dict__

In [None]:
# Display the cosmology's repr.
cosmo

In [None]:
# Physical constants and unit conversions used by the energy-density helpers.
# Key names are kept as-is because other cells look them up by string.
constants = dict(
    JoulestoErg=1.0e7,           # erg per joule
    metersTocm=1.0e2,            # cm per meter
    metersToMpc=3.2408e-23,      # Mpc per meter
    gravityConstantG=6.673e-11,  # G in m^3 / s^2 / kg
    speedOfLightC=299792458.0,   # c in m / s
    cMpcInvSec=9.72e-15,         # speed of light in Mpc / s
)

In [None]:
def rho_crit(cosmo):
    """Calculate the critical energy density today, in erg/cm^3.

    rho_crit c^2 = 3 H0^2 c^2 / (8 pi G). The Hubble rate is first expressed
    as an inverse length, H0/c in Mpc^-1 (the unit CLASS's ``Hubble(0)``
    returns). The SI expression then reads 3 (H0/c)^2 c^4 / (8 pi G) and is
    converted from J/m^3 to erg/cm^3. For reference, the result is
    approximately 1.69e-8 h^2 erg/cm^3.

    Args:
        cosmo: A jax_cosmo cosmology. Only its dimensionless Hubble
            parameter ``cosmo.h`` (H0 = 100 h km/s/Mpc) is used.

    Returns:
        float: Critical energy density at present time in erg/cm^3.
    """
    km_to_m = 1.0e3
    # 100 h km/s/Mpc -> m/s/Mpc (x 1e3) -> 1/s (x Mpc per m) -> 1/Mpc (/ c in Mpc/s).
    # The km -> m factor was previously missing, making H0/c 1000x too small.
    hubble_inv_mpc_now = (cosmo.h * 100.0 * km_to_m * constants['metersToMpc']
                          / constants['cMpcInvSec'])
    # Mpc^-2 -> m^-2 (x metersToMpc^2), then J/m^3 -> erg/cm^3.
    conversions = (constants['metersToMpc']**2 / constants['metersTocm']**3) * constants['JoulestoErg']
    rho_critical = (conversions * 3.0 * hubble_inv_mpc_now**2 * constants['speedOfLightC']**4
                    / (8.0 * jnp.pi * constants['gravityConstantG']))
    return rho_critical

In [None]:
# Critical energy density today for the fiducial cosmology, in erg/cm^3.
rho_crit(cosmo)

In [None]:
def chi_of_z(cosmo, z, in_mpc=False):
    """Calculate the radial comoving distance at the given redshift(s).

    jax_cosmo's ``radial_comoving_distance`` returns distances in Mpc/h, so
    by default this function returns Mpc/h. Pass ``in_mpc=True`` to divide
    by ``cosmo.h`` and obtain Mpc, the unit of CLASS's ``'comov. dist.'``
    background column.

    Args:
        cosmo: A jax_cosmo cosmology.
        z: Redshift, scalar or array.
        in_mpc (bool, optional): Return Mpc instead of Mpc/h.
            Defaults to False (Mpc/h, the original behavior).

    Returns:
        Comoving distance(s) evaluated at each ``z``, in Mpc/h (default) or Mpc.
    """
    # jax_cosmo works in scale factor a = 1 / (1 + z).
    chi = jc.background.radial_comoving_distance(cosmo, 1.0 / (1.0 + z))
    if in_mpc:
        chi = chi / cosmo.h
    return chi

In [None]:
# Comoving distance on a redshift grid 0 <= z <= 10 (units as returned by chi_of_z).
z = jnp.linspace(0.0, 10.0, 1000)
chi = chi_of_z(cosmo, z)

In [None]:
# chi(z) for the fiducial cosmology.
plt.plot(z, chi)
plt.show()

In [None]:
def z_from_chi(cosmo, chi_vec, z_max=10.0, num_points=1000):
    """Invert chi(z): return the redshift at each comoving distance in ``chi_vec``.

    chi(z) is tabulated with ``chi_of_z`` on a linear grid of ``num_points``
    redshifts in [0, z_max], and z(chi) is then linearly interpolated.
    ``chi_vec`` must be in the units ``chi_of_z`` returns by default. It should
    lie within [0, chi(z_max)]; outside that range the result is whatever
    jax_cosmo's ``interp`` does, which is not checked here.

    Args:
        cosmo: A jax_cosmo cosmology.
        chi_vec: Comoving distance(s) at which to evaluate the redshift.
        z_max (float, optional): Upper end of the tabulation grid. Defaults to 10.
        num_points (int, optional): Number of grid points. Defaults to 1000.

    Returns:
        Redshift(s) corresponding to ``chi_vec``.
    """
    from jax_cosmo.scipy.interpolate import interp
    # Tabulate chi(z) and interpolate with the roles of x and y swapped;
    # chi(z) is monotonically increasing, so it can be inverted this way.
    z_grid = jnp.linspace(0.0, z_max, num_points)
    chi_grid = chi_of_z(cosmo, z_grid)
    return interp(chi_vec, chi_grid, z_grid)

In [None]:
# Check the inversion: redshift as a function of comoving distance.
# NOTE(review): 6700 presumably approximates chi(z=10), the top of
# z_from_chi's internal grid, in chi_of_z's default units — confirm.
chi = jnp.linspace(0.0, 6700.0, 1000)
z = z_from_chi(cosmo, chi)
plt.plot(chi, z)
plt.show()

In [None]:
# #Let's have a look at the linear power
# k = jnp.logspace(-3,-0.5, 50).reshape(-1,1)
# a = jnp.linspace(0.1, 1, 10).T
# print(k.shape)
# print(a.shape)

# pk = linear_matter_power(cosmo, k/cosmo.h, a)
# pk_nonlin = nonlinear_matter_power(cosmo, k/cosmo.h, a)

# print(pk.shape)
# print(pk_nonlin.shape)
# plt.figure(figsize=(7,5))
# for i in range(a.shape[0]):
#     plt.loglog(k,jnp.sqrt(pk[:,i]/cosmo.h**3), label=f'a={a[i]:.2f}')
# plt.legend()
# plt.xlabel('k [Mpc]')
# plt.ylabel('delta_k');
# plt.title('Non-linear power spectrum')
# plt.show()

# plt.figure(figsize=(7,5))
# for i in range(a.shape[0]):
#     plt.loglog(k,jnp.sqrt(pk_nonlin[:,i]/cosmo.h**3), label=f'a={a[i]:.2f}')
# plt.legend()
# plt.xlabel('k [Mpc]')
# plt.ylabel('delta_k');
# plt.title('Non-linear power spectrum')
# plt.show()


In [None]:
from jax_gw.signal.agwb import compute_cl, parser_with_arguments

In [None]:
parser = parser_with_arguments()
# LIGO BAND:
# A_max = 6E-38
# z_peak = 0.6
# z_sigma = 0.7
# LISA BAND:
# TODO: LISA-band values for A_max, z_peak and z_sigma are not filled in yet.
# A_max = 
# z_peak = 
# z_sigma = 
# C_ell at f = 63.1 Hz for the LIGO-band parameters, ordered [A_max, z_peak, z_sigma].
args_data = parser.parse_args("./jax_gw/data/stochastic_GW/ --preBessel --overwriteKernel".split())
cls = compute_cl(jnp.array([6E-38, 0.6, 0.7]), args_data, f_value=63.1)

In [None]:
# Display further entries of compute_cl's output (their meaning is defined in
# jax_gw.signal.agwb).
cls[2], cls[3], cls[4]

In [None]:
# Plot ell(ell+1) C_ell / 2pi for cls[1][0]; ell = 0 is skipped because it
# cannot be drawn on a logarithmic axis.
ell_arr = jnp.arange(0, len(cls[1][0]))
plt.loglog(ell_arr[1:], ell_arr[1:]*(ell_arr[1:]+1)*cls[1][0][1:]/(2*jnp.pi), label='LIGO band')
plt.xlabel(r'$\ell$')
plt.ylabel(r'$\ell(\ell+1)C_\ell/(2\pi)$')
plt.legend()
plt.show()

In [None]:
from jax import random
import jax.numpy as jnp
import numpyro
# import numpyro handlers
from numpyro import handlers
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, Predictive, HMC

In [None]:
def generate_cl_data(A_max, z_peak, z_sigma):
    """Simulate mock data C_ell at f = 63.1 Hz for the given kernel parameters.

    The returned spectrum is ``cls[1][0] + cls[1][1]`` from ``compute_cl``.
    The second term is treated as the noise contribution.

    Args:
        A_max (float): Kernel amplitude.
        z_peak (float): Redshift of the kernel peak.
        z_sigma (float): Width of the kernel in redshift.

    Returns:
        Mock data C_ell.
    """
    cli_args = "./jax_gw/data/stochastic_GW/ --preBessel --overwriteKernel".split()
    spectra = compute_cl(jnp.array([A_max, z_peak, z_sigma]),
                         parser.parse_args(cli_args),
                         f_value=63.1)
    # add noise
    return spectra[1][0] + spectra[1][1]

cl_data = generate_cl_data(6E-38, 0.5, 0.6)

In [None]:
def compute_loglkl_from_cls(data_cl, theory_cl, l_vec):
    """Log-likelihood of data C_ell given theory C_ell (eq. 3 of arXiv 1811.11584).

        -2 ln L = sum_{ell > 0} [(2 ell + 1) (data / theory + ln theory)
                                 - (2 ell - 1) ln data]

    Theory-independent constants are dropped. data_cl already includes noise.

    The monopole is excluded with masks rather than a Python ``if`` on
    ``l_vec[0]``, so the function can be traced by ``jax.jit`` and NUTS.
    Excluded entries are replaced by 1 *before* ratios and logs are taken,
    so a vanishing C_0 cannot inject NaNs into the gradient.

    Args:
        data_cl: Observed C_ell (signal plus noise) on the ``l_vec`` grid.
        theory_cl: Model C_ell on the same grid.
        l_vec: Multipoles of the C_ell entries.

    Returns:
        Scalar log-likelihood, -chi2 / 2.
    """
    l_vec = jnp.asarray(l_vec)
    keep = l_vec > 0
    # Neutral placeholders for excluded multipoles keep forward and backward finite.
    safe_data = jnp.where(keep, data_cl, 1.0)
    safe_theory = jnp.where(keep, theory_cl, 1.0)
    chi2_l = (2.0 * l_vec + 1.0) * (safe_data / safe_theory + jnp.log(safe_theory)) \
        - (2.0 * l_vec - 1.0) * jnp.log(safe_data)
    chi2 = jnp.sum(jnp.where(keep, chi2_l, 0.0))
    loglklhood = -0.5 * chi2
    return loglklhood


def likelihood_fn(A_max=None, z_peak=None, z_sigma=None):
    """NumPyro model for the astrophysical GW stochastic background (AGWB).

    The likelihood is a Wishart distribution with a covariance given by the
    AGWB power spectrum. It is evaluated by ``compute_loglkl_from_cls``
    against the module-level mock data ``cl_data`` and added to the model
    with ``numpyro.factor``.

    Args:
        A_max (float): Unused; kept for interface compatibility. A_max is
            sampled from its prior.
        z_peak (float): Unused; z_peak is sampled from its prior.
        z_sigma (float): Unused; z_sigma is sampled from its prior.
    """
    # Sample the parameters. There is deliberately no `handlers.seed` here:
    # a fixed seed inside the model takes precedence over the RNG key that
    # MCMC / Predictive supply for these sites. Prior draws (initialisation,
    # predictive sampling) would then ignore the driver's key.
    A_max_sample = numpyro.sample("A_max", dist.Uniform(1E-40, 1E-35))
    z_peak_sample = numpyro.sample("z_peak", dist.Uniform(0.0, 2.0))
    z_sigma_sample = numpyro.sample("z_sigma", dist.Uniform(0.0, 3.0))
    args_data = parser.parse_args(
        "./jax_gw/data/stochastic_GW/ --preBessel --overwriteKernel".split())
    cls = compute_cl(
        jnp.array([A_max_sample, z_peak_sample, z_sigma_sample]),
        args_data,
        f_value=63.1)
    ell_arr = jnp.arange(0, len(cls[1][0]))

    # Compute the log likelihood.
    # NOTE(review): cl_data was generated as cls[1][0] + cls[1][1] (signal plus
    # "noise"), but the theory here is cls[1][0] alone. Confirm whether the
    # theory should include the same noise term.
    loglklhood = compute_loglkl_from_cls(cl_data, cls[1][0], ell_arr)
    numpyro.factor("loglklhood", loglklhood)

In [None]:
# Run NUTS on the model above: 1000 warm-up and 1000 posterior samples, one chain.
rng_key = random.PRNGKey(0)
rng_key, rng_key_ = random.split(rng_key)
mcmc = MCMC(NUTS(likelihood_fn), num_warmup=1000, num_samples=1000, num_chains=1)
mcmc.run(rng_key=rng_key_)

In [None]:
def make_sparse(l_max, l_min=0, intervals=None, sample_distances=None):
    """Build a sparse, piecewise-uniformly sampled grid of multipoles.

    Starting at ``l_min``, multipoles below ``intervals[i]`` are taken every
    ``sample_distances[i]`` steps. Each interval continues from where the
    previous one stopped. Sampling stops once ``l_max`` is exceeded, and
    ``l_max`` itself is always appended. Multipoles between
    ``max(intervals)`` and ``l_max`` are not sampled, apart from ``l_max``.
    TODO: The default values probably need refinement.

    Args:
        l_max (int): Largest multipole; always the last entry.
        l_min (int, optional): First multipole. Defaults to 0.
        intervals (list or tuple, optional): Exclusive upper limit of each
            interval. Unsorted intervals are sorted together with their sample
            distances. Defaults to [30, 40, 240, 1000].
        sample_distances (list or tuple, optional): The distance between
            samples in each interval. Defaults to [1, 10, 20, 40].

    Raises:
        ValueError: If the specifications of the interval are inconsistent.

    Returns:
        ndarray: array of sparse ell.
    """
    # None sentinels instead of mutable list defaults.
    if intervals is None:
        intervals = [30, 40, 240, 1000]
    if sample_distances is None:
        sample_distances = [1, 10, 20, 40]
    if not isinstance(intervals, (list, tuple)) or not isinstance(sample_distances, (list, tuple)):
        raise ValueError("intervals and samples_distances should both be lists or tuples.")
    if len(intervals) != len(sample_distances):
        raise ValueError("Arrays `intervals` and `sample_distances` should have same length")
    if any(x < 1 for x in sample_distances):
        raise ValueError("All sample distances have to be at least 1.")

    # Sort (upper bound, step) pairs together so each step stays with its interval;
    # sorting the bounds alone would mismatch them with their sampling distances.
    bounds_and_steps = sorted(zip(intervals, sample_distances), key=lambda pair: pair[0])

    ell_value = l_min
    ell_list = []
    for upper_bound, step in bounds_and_steps:
        while ell_value < upper_bound and ell_value <= l_max:
            ell_list.append(ell_value)
            ell_value += step
    if l_max not in ell_list:
        ell_list.append(l_max)

    return jnp.array(ell_list)

In [None]:
# Default sparse multipole grid up to l_max = 1000.
make_sparse(1000)

In [None]:
# make_sparse compiled with jax.jit. l_max must be static because it drives
# Python control flow inside make_sparse (loop bounds and the `in` test).
jitted_make_sparse = jax.jit(make_sparse, static_argnums=(0,))

jitted_make_sparse(1000)

# Decoupling precompute from likelihood

In [None]:
# samples contains the values of the parameters of the kernel
# args contains the arguments of the code
# transfer the samples to args
from jax_gw.signal.agwb import *
# NOTE(review): `np` and `os` are used below but not imported in this notebook;
# presumably they are provided by the star import above — confirm.
parser = parser_with_arguments()
# NOTE(review): this cell uses "./src/jax_gw/..." while earlier cells use
# "./jax_gw/..." — confirm which matches the working directory.
args = parser.parse_args(f"./src/jax_gw/data/stochastic_GW/ --preBessel --overwriteKernel".split())
# Kernel parameters, copied below into args.A_max, args.mean_z, args.sigma_z.
samples = np.array([1E-38, 0.7, 0.6])
# f_value=None -> build a frequency grid from f_min to f_max (below);
# otherwise use only that single frequency.
f_value = None
f_ref = 63.1
f_min = 20
f_max = 500
verbose = True


args.A_max = samples[0]
args.mean_z = samples[1]
args.sigma_z = samples[2]
nonlinear = 'Halofit'
# Write here cosmological parameters used to calculate the data
# NOTE(review): params_cosmo is only referenced by the commented-out CLASS
# calls below; the active code uses jax_cosmo's Planck15() instead.
params_cosmo = {
        'output': 'mPk',
        'z_pk': '0., 3.0, 7.0, 10.0',
        'P_k_max_1/Mpc': '70',
        'non linear': nonlinear
#         'gauge' : 'Newtonian' #TODO: commented this as it should be the same. Check!
    }

# Default output file name depends on whether the kernel is recomputed
# (GAUSS) or read from tables (TABLES).
if not args.output_path: # Assign default name
    if args.overwriteKernel:
        args.output_path = os.path.join(args.input_dir,'data_cl_f_l_GAUSS.fits')
    else:
        args.output_path = os.path.join(args.input_dir,'data_cl_f_l_TABLES.fits')
    
if not args.bessel_path: # Assign default name
    args.bessel_path = os.path.join(args.input_dir,'sph_bessel_k_z_l_TEST_gwtools.fits')

# Assign absolute path for all files
input_dir  = os.path.abspath(args.input_dir)
path = {
    'f'      : os.path.join(input_dir, args.f_fname),
    'z'      : os.path.join(input_dir, args.z_fname),
    'A'      : os.path.join(input_dir, args.A_fname),
    'output' : os.path.abspath(args.output_path),
    'bessel' : os.path.abspath(args.bessel_path)
}

### NOTE: temporary comment out. put back in when interpolating from files

# Check existence of input files and warning before overwriting output
# assert os.path.isfile(path['f']), 'File {} not found!'.format(path['f'])
# assert os.path.isfile(path['z']), 'File {} not found!'.format(path['z'])
# assert os.path.isfile(path['A']), 'File {} not found!'.format(path['A'])
# if os.path.isfile(path['output']) and args.storeCl:
#     print('WARNING! I am going to overwrite a pre-existing data file!')
    


# # Import files
# # initially 71 redshifts and 141 frequencies. Kernel A in erg/cm^3.
# f_vec  = np.genfromtxt(path['f'], delimiter='\t')
# z_vec = np.genfromtxt(path['z'], delimiter='\t')
# A_z_f  = np.genfromtxt(path['A'])
# # Check that the imported arrays have the right dimensions and consistent with input parameters
# assert z_vec.shape+f_vec.shape==A_z_f.shape, 'The dimensions of the imported arrays are wrong!'
# assert args.z_min>=z_vec.min() and args.z_max<=z_vec.max(), 'Check redshift boundaries!'
# # Interpolate the kernel and check that the arguments are z and f in this order
# A_z_f_interp = interp2d(z_vec, f_vec, A_z_f.T, kind='cubic')
# assert z_vec.min()==A_z_f_interp.x_min and z_vec.max()==A_z_f_interp.x_max
# assert f_vec.min()==A_z_f_interp.y_min and f_vec.max()==A_z_f_interp.y_max
# l_vec: every multipole 0..l_max. l_compute: the multipoles where C_ell is
# actually computed (all of them with --full_ell, otherwise a sparse grid).
# Results are later interpolated from l_compute onto l_vec.
l_vec = np.arange(args.l_max+1)
if args.full_ell:
    l_compute = np.arange(args.l_max+1)
else:
    l_compute = jax.jit(make_sparse, static_argnums=(0,))(args.l_max)

### DENSE

x_vec = get_x_full(ell_max=args.l_max, x_min=args.x_min, 
                    after=args.num_after_max, points_pp=args.points_pp)

# Create vectors
# k_num = k_density points per decade times the number of decades in [k_min, k_max].
k_num = int(args.k_density * (np.log10(args.k_max) - np.log10(args.k_min)))
k_vec = create_array(args.k_min, args.k_max, k_num, "log")


### SPARSE
# NOTE(review): the sparse k grid extends to k_max+1, while the dense one stops
# at k_max — confirm this is intended.
k_sparse = create_array(args.k_min, args.k_max+1, args.k_sparse_num, 'log')
z_sparse = create_array(args.z_min, args.z_max, args.z_sparse_num, 'log')    

intermediate_grids = get_intermediate_grids(k_vec, x_vec, k_sparse)

#     print("NOTE: choosing very narrow frequency interval")
# Frequency grid, or a single frequency when f_value is set.
if f_value is None:
    f_vec = create_array(f_min, f_max, args.f_num, args.f_spacing)
else:
    f_vec = [f_value,]

# Calculate the matter power spectrum. This is frequency independent.
# This is the only place where we need Class (now replaced by jax_cosmo)
# cosmo = Class()
# cosmo.set(params_cosmo)
# cosmo.compute()

cosmo = jc.Planck15()


# Effective bias / matter-fluctuation quantities and the auxiliary grids
# (including "z_mid" and "chi_mask") used by the C_ell functions below.
b_eff, deltaM_eff, assorted_grids = \
                get_cosmo_eff(cosmo, z_sparse, intermediate_grids, 
                                args, nonlinear)

# used in the evaluation of noise
# NOTE(review): chi_from_z comes from jax_gw.signal.agwb, not from this
# notebook's chi_of_z; the two may differ in units (Mpc vs Mpc/h) — confirm.
chi_mid = chi_from_z(cosmo, assorted_grids["z_mid"])
    
    
# Precompute Spherical Bessel Function
# Reuse the Bessel file at path['bessel'] if it exists and matches x_vec;
# otherwise (re)write it when --preBessel is set.
# NOTE(review): the consistency checks use `assert`, which is stripped under
# `python -O`; the cache would then always be treated as valid.
try: 
    if verbose:
        print("Checking for pre-computed Spherical Bessel")
    assert os.path.isfile(path['bessel'])
    _l = read_data_from_fits(path['bessel'], 'l')
    _x = read_data_from_fits(path['bessel'], 'x')
    assert _x.size == x_vec.size
    assert _x.min() == x_vec.min()
    assert _x.max() == x_vec.max()
    # TODO: re-implement this check in jax
    # assert set(_l) >= set(l_compute)
except AssertionError:
    print("Could not find consistent precomputed Bessel")
    if args.preBessel:
        print('WARNING! I am going to overwrite the precomputed bessel file!')
        print("Writing Bessel. This might take a while time and it might require a lot of memory")

        after_func = lambda l: max(l, args.min_after_nu)
        write_sph_bessel(path['bessel'], l_compute, x_vec=x_vec,
                            before=args.num_before_nu, after=after_func)
        print('Finished writing Bessel')
    else:
        path['bessel'] = None
        print('Not going to use precomputed Bessel Function')
        print("Use --preBessel True to store and use spherical Bessel Functions for these k and z vectors")
else:
    if verbose:
        print("Found Precomputed Bessels")
if verbose:
    print("Recovering Bessel")
# Leading axis of length 1 added with [None, ...].
bessel_x_l = get_bessel_x_l(path['bessel'], l_compute)[None,...]

# Indices where chi_mask is False; passed to compute_kernel_on_grid.
not_chi_mask_nonzero = (~assorted_grids["chi_mask"]).nonzero()


In [None]:
# Range and size of the frequency grid built in the setup cell.
f_vec.min(), f_vec.max(), f_vec.size

In [None]:
def cl_narrowband_from_grids(params, f, args,):
    """Compute the data C_ell (clustering + shot noise) at a single frequency.

    Uses module-level quantities precomputed in the setup cell: ``cosmo``,
    ``assorted_grids``, ``not_chi_mask_nonzero``, ``chi_mid``, ``b_eff``,
    ``deltaM_eff``, ``bessel_x_l``, ``intermediate_grids``, ``k_vec``,
    ``l_compute`` and ``l_vec``.

    Args:
        params: Kernel parameters forwarded to ``compute_kernel_on_grid``
            (presumably ``[A_max, mean_z, sigma_z]`` — confirm).
        f (float): Frequency in Hz.
        args: Parsed command-line arguments.

    Returns:
        Array of shape ``(1, len(l_vec))``: clustering C_ell on ``l_vec`` plus
        the shot noise, which is the same for every ell.
    """
    f_len = 1
    nf = 0
    # Use jnp explicitly: the functional `.at[...].set(...)` updates below exist
    # only on JAX arrays. A NumPy ndarray would raise AttributeError.
    clustering = jnp.zeros((f_len, len(l_vec)))
    noise = jnp.zeros(f_len)
    data = jnp.zeros((f_len, len(l_vec)))
    print(f"\r{nf} {f:.4f}\tHz", end=" ")
    sys.stdout.flush()

    if args.overwriteKernel:
        A_eff, A_sparse = compute_kernel_on_grid(
            params,
            freq=f,
            assorted_grids=assorted_grids,
            args=args,
            not_chi_mask_nonzero=not_chi_mask_nonzero,
            A_kernel_interp2d=None)
        # Narrowband weighting: multiply the kernel by the frequency.
        A_eff = A_eff * f
        A_sparse = A_sparse * f
    else:
        # NOTE(review): A_z_f_interp is only defined in commented-out code, and this
        # call passes neither `params` nor `not_chi_mask_nonzero`. Check it against
        # compute_kernel_on_grid's signature before using tabulated kernels.
        A_eff, A_sparse = compute_kernel_on_grid(freq=f,
                                                 assorted_grids=assorted_grids,
                                                 args=args,
                                                 A_kernel_interp2d=A_z_f_interp)

    noise = noise.at[nf].set(compute_spatial_shot_noise(cosmo,
                                                        A_z=A_sparse,
                                                        chi_vec=chi_mid,
                                                        n_G=args.n_G))
    clustering_l = compute_clustering_cl(cosmo, A_eff, b_eff, deltaM_eff,
                                         bessel_x_l, intermediate_grids["chi_grid"],
                                         k_vec)
    # C_ell is computed on l_compute (possibly sparse) and interpolated onto l_vec.
    clustering = clustering.at[nf].set(interpolate_cl(clustering_l, l_compute, l_vec))

    data = data.at[nf].set(clustering[nf, :] + noise[nf])
    return data

In [None]:
def cl_broadband_from_grids(params, f_vec, f_ref, args,):
    """Compute broadband data C_ell (clustering + shot noise) over ``f_vec``.

    The kernel is evaluated once at ``f_ref`` and rescaled by a broadband
    factor covering ``[f_vec[0], f_vec[-1]]``. Uses the same module-level
    precomputed quantities as ``cl_narrowband_from_grids``.

    Args:
        params: Kernel parameters forwarded to ``compute_kernel_on_grid``
            (presumably ``[A_max, mean_z, sigma_z]`` — confirm).
        f_vec: Frequencies in Hz; only the first and last entries are used,
            so the grid is expected to be sorted.
        f_ref (float): Reference frequency at which the kernel is evaluated.
        args: Parsed command-line arguments.

    Returns:
        Array of shape ``(1, len(l_vec))``: clustering C_ell on ``l_vec`` plus
        the shot noise, which is the same for every ell.

    Raises:
        NotImplementedError: If ``args.overwriteKernel`` is False. Broadband
            spectra from tabulated kernels are not implemented; the previous
            code path referenced an undefined ``f`` and failed with NameError.
    """
    if not args.overwriteKernel:
        raise NotImplementedError(
            "Broadband C_ell from tabulated kernels is not implemented; "
            "run with --overwriteKernel.")

    f_len = 1
    nf = 0
    # Use jnp explicitly: the functional `.at[...].set(...)` updates below exist
    # only on JAX arrays.
    clustering = jnp.zeros((f_len, len(l_vec)))
    noise = jnp.zeros(f_len)
    data = jnp.zeros((f_len, len(l_vec)))
    sys.stdout.flush()

    A_eff, A_sparse = compute_kernel_on_grid(
        params,
        freq=f_ref,
        assorted_grids=assorted_grids,
        args=args,
        not_chi_mask_nonzero=not_chi_mask_nonzero,
        A_kernel_interp2d=None)
    # Integrate the narrowband weight f over [f_vec[0], f_vec[-1]], with the kernel
    # presumably scaling as (f / f_ref)**f_slope:
    #   int f (f / f_ref)^s df = (f1^(s+2) - f0^(s+2)) / ((s + 2) f_ref^s)
    f_slope = 2 / 3
    broadband = 1 / (f_slope + 2) / f_ref**f_slope * \
        (f_vec[-1]**(f_slope + 2) - f_vec[0]**(f_slope + 2))
    A_eff = A_eff * broadband
    A_sparse = A_sparse * broadband

    noise = noise.at[nf].set(compute_spatial_shot_noise(cosmo,
                                                        A_z=A_sparse,
                                                        chi_vec=chi_mid,
                                                        n_G=args.n_G))
    clustering_l = compute_clustering_cl(cosmo, A_eff, b_eff, deltaM_eff,
                                         bessel_x_l, intermediate_grids["chi_grid"],
                                         k_vec)
    # C_ell is computed on l_compute (possibly sparse) and interpolated onto l_vec.
    clustering = clustering.at[nf].set(interpolate_cl(clustering_l, l_compute, l_vec))

    data = data.at[nf].set(clustering[nf, :] + noise[nf])
    return data

In [None]:
from jax import random
import jax.numpy as jnp
import numpyro
# import numpyro handlers
from numpyro import handlers
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, Predictive, HMC

In [None]:
def generate_cl_data(A_max, z_peak, z_sigma, f_vec, f_ref):
    """Simulate broadband mock data C_ell for the given kernel parameters.

    Args:
        A_max (float): Kernel amplitude in units of 1e-37 (rescaled before
            the kernel is evaluated).
        z_peak (float): Redshift of the kernel peak.
        z_sigma (float): Width of the kernel in redshift.
        f_vec: Frequencies in Hz spanning the band.
        f_ref (float): Reference frequency in Hz.

    Returns:
        1-d array of data C_ell (clustering + shot noise) on ``l_vec``.
    """
    args_data = parser.parse_args(
        "./src/jax_gw/data/stochastic_GW/ --preBessel --overwriteKernel".split())
    data = cl_broadband_from_grids(
        jnp.array([
            A_max * 1E-37,
            z_peak,
            z_sigma,
        ]),
        f_vec,
        f_ref,
        args=args_data,
    )
    # cl_broadband_from_grids returns shape (1, len(l_vec)) with the noise already
    # added. The old `cls[1][0] + cls[1][1]` indexed past that length-1 axis: JAX
    # clamps out-of-bounds indices, so it silently returned the scalar C_0 + C_1.
    return data[0]

f_min, f_max = 20, 500
f_ref = 63.1
f_vec = np.linspace(f_min, f_max, 10)
cl_data = generate_cl_data(0.6, 0.5, 0.6, f_vec, f_ref)

In [None]:
# Range of the mock data C_ell.
print(cl_data.min(), cl_data.max())

In [None]:
def compute_loglkl_from_cls(data_cl, theory_cl, l_vec, offset=1225505.9):
    """Log-likelihood of data C_ell given theory C_ell (eq. 3 of arXiv 1811.11584).

        -2 ln L = sum over entries 1.. of
                  [(2 ell + 1) (data / theory + ln theory)
                   - (2 ell - 1) ln data - 2 ln data]

    The first entry (ell = 0 for the grids used here) is always excluded.
    data_cl already includes noise and may be an array or a scalar; it is
    broadcast against theory_cl. ``offset`` is added to the result. It does
    not depend on theory_cl, so it changes neither the posterior nor the
    gradients; the default is presumably an empirical shift to keep printed
    values small — confirm.

    The first entry is excluded with a mask, and replaced by 1 *before*
    ratios and logs are taken, so a vanishing C_0 cannot inject NaNs into
    the gradient through the discarded term.

    Args:
        data_cl: Observed C_ell (signal plus noise).
        theory_cl: Model C_ell on the ``l_vec`` grid.
        l_vec: Multipoles of the C_ell entries.
        offset (float, optional): Additive constant. Defaults to 1225505.9.

    Returns:
        Scalar log-likelihood, -chi2 / 2 + offset.
    """
    l_vec = jnp.asarray(l_vec)
    shape = jnp.broadcast_shapes(jnp.shape(data_cl), jnp.shape(theory_cl), jnp.shape(l_vec))
    # Drop index 0 along the leading axis (as the former `chi2_l[1:]` did).
    keep = (jnp.arange(shape[0]) > 0).reshape((-1,) + (1,) * (len(shape) - 1))
    safe_data = jnp.where(keep, data_cl, 1.0)
    safe_theory = jnp.where(keep, theory_cl, 1.0)

    chi2_l = (2.0 * l_vec + 1.0) * \
                ( (safe_data / safe_theory) + jnp.log(safe_theory) ) \
            - (2.0 * l_vec - 1.0) * jnp.log(safe_data) \
                - 2.0 * jnp.log(safe_data)

    chi2 = jnp.sum(jnp.where(keep, chi2_l, 0.0))
    loglklhood = - 0.5 * chi2 + offset
    return loglklhood

def likelihood_fn(A_max=None, z_peak=None, z_sigma=None):
    """NumPyro model for the astrophysical GW stochastic background (AGWB).

    The likelihood is a Wishart distribution with a covariance given by the
    AGWB power spectrum. Broadband C_ell over 20-500 Hz (reference 63.1 Hz)
    are compared with the module-level mock data ``cl_data`` through
    ``compute_loglkl_from_cls``. A_max is sampled in units of 1e-37.

    Args:
        A_max (float): Unused; kept for interface compatibility. A_max is
            sampled from its prior.
        z_peak (float): Unused; z_peak is sampled from its prior.
        z_sigma (float): Unused; z_sigma is sampled from its prior.
    """
    # Sample the parameters (the driver, not a seed handler, supplies the RNG).
    A_max_sample = numpyro.sample("A_max", dist.Uniform(0.4, 0.8))
    z_peak_sample = numpyro.sample("z_peak", dist.Uniform(0.2, 0.8))
    z_sigma_sample = numpyro.sample("z_sigma", dist.Uniform(0.2, 0.8))

    str_formatted = "./src/jax_gw/data/stochastic_GW/ --preBessel --overwriteKernel"
    args_data = parser.parse_args(str_formatted.split())
    f_min, f_max = 20, 500
    f_ref = 63.1
    f_vec = np.linspace(f_min, f_max, 10)
    cls = cl_broadband_from_grids(
        jnp.array([
            A_max_sample*1E-37,
            z_peak_sample,
            z_sigma_sample
        ]),
        f_vec,
        f_ref,
        args=args_data,
    )
    ell_arr = jnp.arange(0, len(cls[0]))

    # Compute the log likelihood. No print() here: NUTS traces this model, so a
    # print would only show an abstract tracer, not the numerical value.
    loglklhood = compute_loglkl_from_cls(cl_data, cls[0], ell_arr)
    numpyro.factor("loglklhood", loglklhood)


In [None]:
from numpyro.infer import init_to_feasible, init_to_median, init_to_sample
rng_key = random.PRNGKey(0)
rng_key, rng_key_ = random.split(rng_key)
# Initialise the chain at the prior medians of the latent sites.
kernel = NUTS(
    likelihood_fn,
    init_strategy=init_to_median(),
    )
# Short run: 100 warm-up and 100 posterior samples, one chain.
mcmc = MCMC(kernel, num_warmup=100, num_samples=100, num_chains=1)
mcmc.run(rng_key=rng_key_)
mcmc.print_summary()

In [None]:
from numpyro.infer import init_to_feasible, init_to_median, init_to_sample
rng_key = random.PRNGKey(0)
rng_key, rng_key_ = random.split(rng_key)
# Initialise the chain at the prior medians of the latent sites.
kernel = NUTS(
    likelihood_fn,
    init_strategy=init_to_median(),
    )
# Short run: 100 warm-up and 100 posterior samples, one chain.
mcmc = MCMC(kernel, num_warmup=100, num_samples=100, num_chains=1)
mcmc.run(rng_key=rng_key_)
mcmc.print_summary()
# Dict of posterior samples keyed by site name ("A_max", "z_peak", "z_sigma").
posterior = mcmc.get_samples()

# Instrumental Noise

# Runs

- One run with f=0.001