In [None]:
import glob
import os
import time

import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from jax import grad, jit
# Directory holding the stochastic-GW input tables. It is resolved against the
# current working directory, so the kernel must be started from the directory
# that contains `src/` (presumably the repository root).
_DATA_SUBDIRS = ('src', 'jax_gw', 'data', 'stochastic_GW')
path_to_data = os.path.join(os.getcwd(), *_DATA_SUBDIRS)

In [None]:
def _data_file(name):
    """Return the path of the data file ``name`` inside ``path_to_data``.

    A plain path join is used instead of ``glob.glob(...)[0]`` for two reasons.
    The file names are literal, with no wildcards, so glob would only
    misinterpret any ``[``, ``*`` or ``?`` in the working-directory path.
    And a missing file would surface as an opaque ``IndexError``.

    Raises:
        FileNotFoundError: if ``name`` is not an existing file in ``path_to_data``.
    """
    path = os.path.join(path_to_data, name)
    if not os.path.isfile(path):
        raise FileNotFoundError(f"'{name}' not found in {path_to_data}")
    return path


f_vec_file = _data_file('fvec_wide.dat')
z_vec_file = _data_file('zvec.dat')
LumA_file = _data_file('LumA_wide.dat')
print(f_vec_file)
print(z_vec_file)
print(LumA_file)

In [None]:

# Frequency grid, redshift grid and the A table (rows indexed by z, columns by f).
f_vec, z_vec, LumA = (np.loadtxt(fname) for fname in (f_vec_file, z_vec_file, LumA_file))
assert LumA.shape == (len(z_vec), len(f_vec))
print(*(arr.shape for arr in (f_vec, z_vec, LumA)), sep='\n')

In [None]:
# A(z) for every 10th frequency bin, on a logarithmic y axis.
# (matplotlib.pyplot is imported in the top import cell.)
f_indices = np.arange(0, len(f_vec), 10)
fig, ax = plt.subplots(figsize=(6, 4))
for f_ind in f_indices:
    ax.plot(z_vec, LumA[:, f_ind], label=f'f={f_vec[f_ind]} Hz')
ax.set_xlabel('z')
ax.set_ylabel(r'$\mathcal{A}(z)$')
ax.set_yscale('log')
# Anchor the legend outside the axes (x = 1.4) so it does not cover the curves.
ax.legend(loc='upper right', bbox_to_anchor=(1.4, 1.0))
plt.show()

In [None]:
# Zoom in on a single frequency bin.
f_ind = 60
fig, ax = plt.subplots(figsize=(6, 4))
ax.plot(z_vec, LumA[:, f_ind], label=f'f={f_vec[f_ind]} Hz')
ax.set(xlabel='z', ylabel=r'$\mathcal{A}(z)$', yscale='log')
ax.legend()
plt.show()


In [None]:
# Omega_GW(f) = f / rho_crit * integral over z of A(z, f) / H0.
# The integral runs along axis 0 of LumA, which is the z axis (shape checked on load).
rho_crit = 7.68E-9 # erg/cm^3 (critical density expressed as an energy density)
H0_Hz = 2.27e-18 # Hz (~70 km/s/Mpc)
# TODO: replace the Hubble constant with the redshift-dependent Hubble parameter H(z)
# NOTE(review): the usual Phinney-type formula also has a 1/(1+z) factor in the
# integrand; confirm whether LumA already includes it.

# np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid;
# pick whichever this NumPy provides.
trapezoid = np.trapezoid if hasattr(np, 'trapezoid') else np.trapz
Omega_gw = f_vec / rho_crit * trapezoid(LumA / H0_Hz, z_vec, axis=0)
assert Omega_gw.shape == f_vec.shape

fig, ax = plt.subplots(figsize=(6, 4))
ax.plot(f_vec, Omega_gw)
ax.set_xlabel('f [Hz]')
ax.set_ylabel(r'$\Omega_{\rm GW}(f)$')
ax.set_xscale('log')
ax.set_yscale('log')
plt.show()

In [None]:
# Spot-check Omega_GW at a single frequency bin.
f_ind = 60
print('f={} Hz: Omega_gw={}'.format(f_vec[f_ind], Omega_gw[f_ind]))

In [None]:
# Fit a Gaussian A * exp(-(z - mu)^2 / (2 sigma^2)) to A(z) at one frequency
# bin with plain gradient descent. The fit runs twice, with a jitted and a
# non-jitted gradient, to compare wall-clock time.
f_ind = 60
# Rescale so the fitted amplitude is O(1).
# NOTE(review): assumes LumA is ~1e-36 at this bin; confirm for other bins.
y_data = LumA[:, f_ind] * 1e36
x_data = z_vec


def gaussian_model(params, x):
    """Evaluate A * exp(-(x - mu)**2 / (2 * sigma**2)) for params = (A, mu, sigma)."""
    A, mu, sigma = params
    return A * jnp.exp(-(x - mu)**2 / (2 * sigma**2))


def loss(params, x, y):
    """Mean squared error between ``y`` and ``gaussian_model(params, x)``."""
    return jnp.mean((y - gaussian_model(params, x))**2)


grad_loss = jit(grad(loss))
non_jitted_grad_loss = grad(loss)


def _gradient_descent(grad_fn, params, x, y, lr, n_iter):
    """Run ``n_iter`` fixed-step gradient-descent updates starting at ``params``.

    Prints the loss about ten times during the run.

    Returns:
        (final_params, elapsed_seconds). The timer stops only after the final
        parameters are materialised, because JAX dispatches work asynchronously.
    """
    log_every = max(1, n_iter // 10)  # avoid modulo-by-zero when n_iter < 10
    start = time.time()
    for i in range(n_iter):
        params = params - lr * grad_fn(params, x, y)
        if i % log_every == 0:
            print(f'loss = {loss(params, x, y)}')
    params = params.block_until_ready()
    return params, time.time() - start


# initialize the parameters
A = 1
mu = 0.5
sigma = 0.4

params_0 = jnp.array([A, mu, sigma])
n_iter = 1000
lr = 0.1

# Warm-up call: trigger XLA compilation now, so the jitted timing measures
# execution only and not one-off compile time.
grad_loss(params_0, x_data, y_data).block_until_ready()

params, elapsed = _gradient_descent(grad_loss, params_0, x_data, y_data, lr, n_iter)
print(f'jitted elapsed time = {elapsed} seconds')
print(f'best fit parameters: {params}')

params, elapsed = _gradient_descent(non_jitted_grad_loss, params_0, x_data, y_data, lr, n_iter)
print(f'elapsed time = {elapsed} seconds')
print(f'best fit parameters: {params}')


In [None]:
# Compare the data with the initial guess and the fitted Gaussian.
print(f'best fit parameters: {params}')
fig, ax = plt.subplots(figsize=(6, 4))
ax.plot(x_data, y_data, label='data')
ax.plot(x_data, gaussian_model(params_0, x_data), label='initial guess')
ax.plot(x_data, gaussian_model(params, x_data), label='best fit')
ax.set_xlabel('z')
# y_data is LumA scaled by 1e36 in the fit cell, so label the rescaled quantity.
ax.set_ylabel(r'$10^{36}\,\mathcal{A}(z)$')
ax.legend(title=f'f = {f_vec[f_ind]} Hz')
plt.show()