In [1]:
import jax
import jax.numpy as jnp
import jax.scipy.stats as stats
import matplotlib.pyplot as plt
import numpy as np
from functools import partial

import blackjax

from ripple.waveforms import IMRPhenomXAS
from ripple import ms_to_Mc_eta

from fastprogress import progress_bar

import blackjax

from pycbc.catalog import Merger
from pycbc.filter import resample_to_delta_t, highpass
from pycbc.psd import interpolate, inverse_spectrum_truncation
from pycbc.waveform import get_fd_waveform
from pycbc.filter import matched_filter, sigmasq, get_cutoff_indices 

from jax.config import config
config.update("jax_debug_nans", True)

jax.config.update("jax_enable_x64", True)

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [2]:
data = {"L1": jnp.load("/users/sgreen/gwtuna/LVK/Paper/BBHSearch/P.E./L1_data.npy")}#, "H1": jnp.load("H1_data.npy")}
freqs = jnp.load("/users/sgreen/gwtuna/LVK/Paper/BBHSearch/P.E./L1_freqs.npy")
psd = {"L1": jnp.load("/users/sgreen/gwtuna/LVK/Paper/BBHSearch/P.E./L1_psd.npy")}#, "H1": jnp.load("H1_psd.npy")}

In [3]:
chi1 = 0.0
chi2 = 0.0
tc = 1126259462
phic = jnp.pi/2
dist_mpc = 100.
inclination = jnp.pi/2

In [4]:
def log_likelihood(mchirp, eta, fref=20.0, freqs=freqs, psd=psd, data=data, chi1=chi1, chi2=chi2, tc=tc, phic=phic, dist_mpc=dist_mpc, inclination=inclination):
#    hp, hc = waveform_generator(
#        freqs, mchirp, eta, chi1, chi2, tc, phic, dist_mpc, inclination,
#    )
    
    theta = jnp.array([mchirp, eta, chi1, chi2, tc, phic, dist_mpc, inclination])
    hp, hc = IMRPhenomXAS.gen_IMRPhenomXAS_polar(freqs, theta, fref)
    d_inner_h = 0.
    optimal_snr_squared = 0.
    for ifo in data.keys():
        if ifo == "L1":
            fp = -0.456852978678261
            fc = 0.36204310587763466
        else:
            fp = 0.45529254427236665
            fc = -0.6283981252126967
        h = hp*fp + hc*fc
        h_star = jnp.conj(h)
        integrand = jnp.nan_to_num(h_star * data[ifo] / psd[ifo], 0.)
        d_inner_h += 4.0 / 32.0 * jnp.sum(integrand)
        integrand = jnp.nan_to_num(h_star * h / psd[ifo], 0.)
        optimal_snr_squared += (4.0 / 32.0 * jnp.sum(integrand)).real
    log_l = d_inner_h.real - optimal_snr_squared / 2
    print(type(log_l.real))
    return log_l.real

In [5]:
#my_log_likelihood = partial(log_likelihood, freqs, psd, data)

In [7]:
def log_probability(mchirp, eta, fref=20.0, freqs=freqs, psd=psd, data=data, chi1=chi1, chi2=chi2, tc=tc, phic=phic, dist_mpc=dist_mpc, inclination=inclination):
    return log_likelihood(mchirp, eta, fref, freqs, psd, data, chi1, chi2, tc, phic, dist_mpc, inclination)

In [8]:
logdensity = lambda a: log_probability(**a)

In [22]:
inv_mass_matrix = np.array([0.0000001, 0.00000001])
num_integration_steps = 1e-5
step_size = 1e-5

nuts = blackjax.nuts(logdensity, step_size, inv_mass_matrix)
#hmc = blackjax.hmc(logdensity, step_size, inv_mass_matrix, num_integration_steps)

In [23]:
initial_position = {"mchirp": 30.0, "eta": 0.15}
initial_state = nuts.init(initial_position)
initial_state

Unexpected exception formatting exception. Falling back to standard exception


Traceback (most recent call last):
  File "/mnt/lustre/shared_conda/envs/sgreen/PyCBCandOptuna/lib/python3.11/site-packages/jax/_src/pjit.py", line 1252, in _pjit_call_impl
    return compiled.unsafe_call(*args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/lustre/shared_conda/envs/sgreen/PyCBCandOptuna/lib/python3.11/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/lustre/shared_conda/envs/sgreen/PyCBCandOptuna/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 1920, in __call__
    dispatch.check_special(self.name, arrays)
  File "/mnt/lustre/shared_conda/envs/sgreen/PyCBCandOptuna/lib/python3.11/site-packages/jax/_src/dispatch.py", line 436, in check_special
    _check_special(name, buf.dtype, buf)
  File "/mnt/lustre/shared_conda/envs/sgreen/PyCBCandOptuna/lib/python3.11/site-packages/jax/_src/dispatch.py", line 441, in _check_special
    raise FloatingPointError(f"invalid v

In [None]:
nuts_kernel = jax.jit(nuts.step)

In [None]:
def inference_loop(rng_key, kernel, initial_state, num_samples):
    @jax.jit
    def one_step(state, rng_key):
        state, _ = kernel(rng_key, state)
        return state, state

    keys = jax.random.split(rng_key, num_samples)
    _, states = jax.lax.scan(one_step, initial_state, keys)

    return states

In [None]:
%%time
rng_key = jax.random.PRNGKey(0)
states = inference_loop(rng_key, nuts_kernel, initial_state, 10_000)

m_samples = states.position["mchirp"].block_until_ready()
b_samples = states.position["eta"].block_until_ready()
#scale_samples = jnp.exp(states.position["log_scale"])

In [None]:
fig, (ax, ax1) = plt.subplots(ncols=2, figsize=(15, 6))
ax.plot(m_samples)
ax.set_xlabel("Samples")
ax.set_ylabel("m")

ax1.plot(b_samples)
ax1.set_xlabel("Samples")
ax1.set_ylabel("b")

In [None]:
inv_mass_matrix = np.array([0.01, 0.01, 0.01])
step_size = 1e-3

nuts = blackjax.nuts(logdensity, step_size, inv_mass_matrix)

In [None]:
initial_position = {"m": -1.0, "b": 5.0, "log_f": jnp.log(0.5)}
initial_state = nuts.init(initial_position)
initial_state

In [None]:
%%time
rng_key = jax.random.PRNGKey(0)
states = inference_loop(rng_key, nuts.step, initial_state, 4_000)

m_samples = states.position["m"].block_until_ready()
b_samples = states.position["b"].block_until_ready()
#loc_samples = states.position["loc"].block_until_ready()
#scale_samples = jnp.exp(states.position["log_scale"])

In [None]:
fig, (ax, ax1) = plt.subplots(ncols=2, figsize=(15, 6))
ax.plot(m_samples)
ax.set_xlabel("Samples")
ax.set_ylabel("m")

ax1.plot(b_samples)
ax1.set_xlabel("Samples")
ax1.set_ylabel("b")