In [1]:
import gw_ml_priors

In [2]:
import multiprocessing
import os
import shutil
import h5py
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from bilby.core.prior import Interped
from joblib import Parallel, delayed
from numpy.random import uniform as unif
from tqdm.auto import tqdm

from gw_ml_priors.conversions import calc_a2
from gw_ml_priors.regressors.scikit_regressor import ScikitRegressor

# Worker count handed to joblib's ``Parallel`` in ``get_a1_prior``:
# one job per CPU reported by ``multiprocessing.cpu_count()``.
NUM_CORES = multiprocessing.cpu_count()


def get_a1_prior(xeff, q, mcmc_n=int(1e4)):
    """Build an interpolated prior on ``a1`` for a fixed ``xeff`` and ``q``.

    A grid of 500 ``a1`` values on [0, 1] is evaluated in parallel (one joblib
    job per grid point, ``NUM_CORES`` workers) with
    ``get_p_a1_given_xeff_q``. The resulting curve is normalised and wrapped in
    a bilby ``Interped`` prior whose limits are the first and last grid points
    with a non-zero density.

    Parameters
    ----------
    xeff, q
        Conditioning values, forwarded unchanged to ``get_p_a1_given_xeff_q``.
    mcmc_n : int, optional
        Number of Monte Carlo draws used for each grid point.

    Returns
    -------
    bilby.core.prior.Interped
        Prior named ``"a_1"`` defined on the ``a1`` grid.

    Notes
    -----
    If every Monte Carlo estimate is zero, their sum is zero and the first
    normalisation step divides by zero; that case is not special-cased here.
    """
    a1_grid = np.linspace(0, 1, 500)
    grid_step = a1_grid[1] - a1_grid[0]

    pool = Parallel(n_jobs=NUM_CORES, verbose=1)
    estimates = pool(
        delayed(get_p_a1_given_xeff_q)(a1_value, xeff, q, mcmc_n)
        for a1_value in tqdm(a1_grid, desc="Building a1 cache")
    )

    # Rectangle-rule normalisation first, then the trapezoidal one applied by
    # ``norm_values``; both steps are kept so the numbers match exactly.
    density = estimates / np.sum(estimates) / grid_step
    density = norm_values(density, a1_grid)

    lower, upper = find_boundary(a1_grid, density)
    return Interped(
        xx=a1_grid,
        yy=density,
        minimum=lower,
        maximum=upper,
        name="a_1",
        latex_label=r"$a_1$",
    )


def get_p_a1_given_xeff_q(a1, xeff, q, n=int(1e4)):
    """Monte Carlo estimate of the (unnormalised) density of ``a1``.

    Draws ``n`` uniform samples on [-1, 1] for each of ``cos1`` and ``cos2``,
    passes them with ``a1``, ``xeff`` and ``q`` to ``calc_a2``, and returns the
    mean of ``a2_interpreter_function`` over the resulting ``a2`` values, i.e.
    the fraction of draws whose ``a2`` lies strictly between 0 and 1.

    Parameters
    ----------
    a1, xeff, q
        Values forwarded to ``calc_a2``.
    n : int, optional
        Number of random draws.

    Returns
    -------
    float
        Mean of the 0/1 indicator over the draws.
    """
    # cos1 is drawn before cos2 so the random stream is consumed in the same
    # order as a single tuple assignment would consume it.
    cos1 = unif(-1, 1, n)
    cos2 = unif(-1, 1, n)
    a2_samples = calc_a2(xeff=xeff, q=q, cos1=cos1, cos2=cos2, a1=a1)
    indicator = a2_interpreter_function(a2_samples)
    return np.mean(indicator)


def find_nearest(array, value):
    """Return the element of ``array`` closest to ``value``.

    ``array`` is converted with ``np.asarray``; when several elements are
    equally close, the first one (lowest index) is returned.
    """
    candidates = np.asarray(array)
    distances = np.abs(candidates - value)
    return candidates[distances.argmin()]


def find_boundary_idx(x):
    """Return the indices of the first and last non-zero entries of ``x``.

    The span between the two indices is assumed to contain no zero gaps; that
    assumption is not checked. Indexing fails with ``IndexError`` when ``x``
    has no non-zero entry.
    """
    nonzero_positions = np.nonzero(x)[0]
    first = nonzero_positions[0]
    last = nonzero_positions[-1]
    return first, last


def norm_values(y, x):
    """Scale ``y`` so that its trapezoidal integral over ``x`` equals one.

    Parameters
    ----------
    y : array_like
        Sampled function values.
    x : array_like
        Sample positions corresponding to ``y``.

    Returns
    -------
    ``y`` divided by its trapezoidal integral over ``x``.

    Notes
    -----
    A zero integral is not special-cased, so the division then produces
    non-finite values.
    """
    # ``np.trapz`` is deprecated since NumPy 2.0 in favour of ``np.trapezoid``.
    # Prefer the new name and only fall back on older NumPy releases, so the
    # function neither warns nor breaks once the old alias is removed.
    integrate = getattr(np, "trapezoid", None)
    if integrate is None:
        integrate = np.trapz
    return y / integrate(y, x)


def find_boundary(x, y):
    """Return the ``x`` values bounding the non-zero region of ``y``.

    The first and last non-zero positions of ``y`` are looked up with
    ``find_boundary_idx``; the matching ``x`` values are returned as
    ``(smaller, larger)``.
    """
    first_idx, last_idx = find_boundary_idx(y)
    edge_values = (x[first_idx], x[last_idx])
    return min(edge_values), max(edge_values)


def a2_interpreter_function(a2):
    """Indicator of the open interval (0, 1).

    Returns 1 where ``a2`` lies strictly between 0 and 1 and 0 elsewhere
    (both end points map to 0).
    """
    strictly_inside = (a2 > 0) & (a2 < 1)
    return np.where(strictly_inside, 1, 0)




In [4]:
# Evaluation grids: 100 evenly spaced points for q on [0, 1] and for xeff on [-1, 1].
q_range = np.linspace(0, 1, 100)
xeff_range = np.linspace(-1, 1, 100)
# Discard entries that are exactly zero. With an even number of points the
# symmetric xeff grid has no sample at exactly 0, so only q loses a point.
q_range = q_range[q_range != 0]
xeff_range = xeff_range[xeff_range != 0]
print(q_range)
print(xeff_range)

[0.01010101 0.02020202 0.03030303 0.04040404 0.05050505 0.06060606
 0.07070707 0.08080808 0.09090909 0.1010101  0.11111111 0.12121212
 0.13131313 0.14141414 0.15151515 0.16161616 0.17171717 0.18181818
 0.19191919 0.2020202  0.21212121 0.22222222 0.23232323 0.24242424
 0.25252525 0.26262626 0.27272727 0.28282828 0.29292929 0.3030303
 0.31313131 0.32323232 0.33333333 0.34343434 0.35353535 0.36363636
 0.37373737 0.38383838 0.39393939 0.4040404  0.41414141 0.42424242
 0.43434343 0.44444444 0.45454545 0.46464646 0.47474747 0.48484848
 0.49494949 0.50505051 0.51515152 0.52525253 0.53535354 0.54545455
 0.55555556 0.56565657 0.57575758 0.58585859 0.5959596  0.60606061
 0.61616162 0.62626263 0.63636364 0.64646465 0.65656566 0.66666667
 0.67676768 0.68686869 0.6969697  0.70707071 0.71717172 0.72727273
 0.73737374 0.74747475 0.75757576 0.76767677 0.77777778 0.78787879
 0.7979798  0.80808081 0.81818182 0.82828283 0.83838384 0.84848485
 0.85858586 0.86868687 0.87878788 0.88888889 0.8989899  0.90909

In [None]:
# Sweep the full (q, xeff) grid: build the a1 prior for every pair, save a
# plot of it under ``outdir`` and keep the curve for the export cell.
outdir = 'out'
os.makedirs(outdir, exist_ok=True)
data = []
for q_value in q_range:
    for xeff_value in xeff_range:
        print(f"q={q_value}, xeff={xeff_value}")
        a1_prior = get_a1_prior(q=q_value, xeff=xeff_value)
        a1_grid, a1_density = a1_prior.xx, a1_prior.yy

        # One figure per grid point, closed right after saving so figures do
        # not accumulate over the sweep.
        figure_path = f"{outdir}/p_a1_given_q_xeff_{q_value}_{xeff_value}.png"
        plt.plot(a1_grid, a1_density)
        plt.xlabel("a1")
        plt.ylabel(f"p(a1|q={q_value},xeff={xeff_value})")
        plt.savefig(figure_path)
        plt.close()

        # Row layout: [q, xeff, a1 grid, density on that grid].
        data.append([q_value, xeff_value, a1_grid, a1_density])
        

In [None]:
import pickle

# Collect the per-grid-point rows into one table and name its four columns.
column_names = ["q", "xeff", "a1", "p_a1"]
df = pd.DataFrame(data)
df.columns = column_names

# Serialise the whole DataFrame with pickle next to the saved figures.
pickle_path = f"{outdir}/p_a1_given_q_xeff.pkl"
with open(pickle_path, "wb") as handle:
    pickle.dump(df, handle)
