In [1]:
import jax
import jax.numpy as jnp
import jax.scipy.stats as stats
import matplotlib.pyplot as plt
import numpy as np

import blackjax

from ripple.waveforms import IMRPhenomXAS
from ripple import ms_to_Mc_eta

from fastprogress import progress_bar

import blackjax

from pycbc.catalog import Merger
from pycbc.filter import resample_to_delta_t, highpass
from pycbc.psd import interpolate, inverse_spectrum_truncation
from pycbc.waveform import get_fd_waveform
from pycbc.filter import matched_filter, sigmasq, get_cutoff_indices

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [2]:
data = {"L1": jnp.load("/users/sgreen/gwtuna/LVK/Paper/BBHSearch/P.E./L1_data.npy")}#, "H1": jnp.load("H1_data.npy")}
freqs = jnp.load("/users/sgreen/gwtuna/LVK/Paper/BBHSearch/P.E./L1_freqs.npy")
psd = {"L1": jnp.load("/users/sgreen/gwtuna/LVK/Paper/BBHSearch/P.E./L1_psd.npy")}#, "H1": jnp.load("H1_psd.npy")}

In [3]:
# Define the conditiion data function
dynfac = 1.0e23
def condition(strain, sampling_rate):
    tmp = resample_to_delta_t(highpass(strain, 15.0), 1.0/sampling_rate).crop(2,2)
    return (dynfac*tmp)

# Define the PSD function 
def estimate_psd(strain, delta_f):
    psd = strain.psd(4)
    psd = interpolate(psd, delta_f)
    psd = inverse_spectrum_truncation(psd, int(4 * strain.sample_rate),
                                  low_frequency_cutoff=15, trunc_method='hann')
    return psd

@jax.jit
def log_likelihood(mchirp, eta, chi1, chi2, tc, phic, dist_mpc, inclination):
    hp, hc = waveform_generator(
        freqs, mchirp, eta, chi1, chi2, tc, phic, dist_mpc, inclination,
    )
    d_inner_h = 0.
    optimal_snr_squared = 0.
    for ifo in data.keys():
        if ifo == "L1":
            fp = -0.456852978678261
            fc = 0.36204310587763466
        else:
            fp = 0.45529254427236665
            fc = -0.6283981252126967
        h = hp*fp + hc*fc
        h_star = jnp.conj(h)
        integrand = jnp.nan_to_num(h_star * data[ifo] / psd[ifo], 0.)
        d_inner_h += 4.0 / 32.0 * jnp.sum(integrand)
        integrand = jnp.nan_to_num(h_star * h / psd[ifo], 0.)
        optimal_snr_squared += (4.0 / 32.0 * jnp.sum(integrand)).real
    log_l = d_inner_h.real - optimal_snr_squared / 2
    return log_l.real

def log_probability(mchirp, eta, chi1, chi2, tc, phic, dist_mpc, inclination):
    return log_likelihood(mchirp, eta, chi1, chi2, tc, phic, dist_mpc, inclination)

@jax.jit
def waveform(fs, theta, fref):
    return IMRPhenomXAS.gen_IMRPhenomXAS_polar(fs, theta, fref)

@jax.jit
def waveform_generator(
    fs, mchirp, eta, chi1, chi2, tc, phic, dist_mpc, inclination,
    **kwargs
):
    fref = kwargs.get("fref", 20)
    theta_ripple = jnp.array(
        [mchirp, eta, chi1, chi2, dist_mpc, tc, phic, inclination]
    )
    hp, hc = waveform(fs, theta_ripple, fref)
    return hp, hc

In [4]:
logdensity = lambda a: log_probability(**a)

In [6]:
inv_mass_matrix = np.array([0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01, 0.01])
num_integration_steps = 60
step_size = 1e-3

nuts = blackjax.nuts(logdensity, step_size, inv_mass_matrix)
#hmc = blackjax.hmc(logdensity, step_size, inv_mass_matrix, num_integration_steps)

In [7]:
initial_position = {"mchirp": 30.0, "eta": 50.15, "chi1": 0.0, "chi2": 0.0, "tc": 112625946.0, "phic": 2.0, "dist_mpc": 100., "inclination": 2.0}
initial_state = nuts.init(initial_position)
initial_state

Exception ignored in: <function _xla_gc_callback at 0x7fcfec421f80>
Traceback (most recent call last):
  File "/mnt/lustre/shared_conda/envs/sgreen/PyCBCandOptuna/lib/python3.11/site-packages/jax/_src/lib/__init__.py", line 97, in _xla_gc_callback
    def _xla_gc_callback(*args):
    
KeyboardInterrupt: 
2023-07-04 16:33:39.937588: E external/xla/xla/service/slow_operation_alarm.cc:65] 
********************************
[Compiling module jit_log_likelihood] Very slow compile? If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
********************************
2023-07-04 16:34:14.792893: E external/xla/xla/service/slow_operation_alarm.cc:133] The operation took 2m34.855382209s

********************************
[Compiling module jit_log_likelihood] Very slow compile? If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
********************************


HMCState(position={'mchirp': 30.0, 'eta': 50.15, 'chi1': 0.0, 'chi2': 0.0, 'tc': 112625946.0, 'phic': 2.0, 'dist_mpc': 100.0, 'inclination': 2.0}, logdensity=Array(0., dtype=float64), logdensity_grad={'chi1': Array(nan, dtype=float64, weak_type=True), 'chi2': Array(nan, dtype=float64, weak_type=True), 'dist_mpc': Array(nan, dtype=float64, weak_type=True), 'eta': Array(nan, dtype=float64, weak_type=True), 'inclination': Array(nan, dtype=float64, weak_type=True), 'mchirp': Array(nan, dtype=float64, weak_type=True), 'phic': Array(nan, dtype=float64, weak_type=True), 'tc': Array(nan, dtype=float64, weak_type=True)})

In [None]:
hmc_kernel = jax.jit(nuts.step)

In [None]:
def inference_loop(rng_key, kernel, initial_state, num_samples):
    @jax.jit
    def one_step(state, rng_key):
        state, _ = kernel(rng_key, state)
        return state, state

    keys = jax.random.split(rng_key, num_samples)
    _, states = jax.lax.scan(one_step, initial_state, keys)

    return states

In [None]:
%%time
rng_key = jax.random.PRNGKey(0)
states = inference_loop(rng_key, hmc_kernel, initial_state, 10_000)

m_samples = states.position["m"].block_until_ready()
b_samples = states.position["b"].block_until_ready()
#scale_samples = jnp.exp(states.position["log_scale"])

In [None]:
fig, (ax, ax1) = plt.subplots(ncols=2, figsize=(15, 6))
ax.plot(m_samples)
ax.set_xlabel("Samples")
ax.set_ylabel("m")

ax1.plot(b_samples)
ax1.set_xlabel("Samples")
ax1.set_ylabel("b")

In [None]:
inv_mass_matrix = np.array([0.01, 0.01, 0.01])
step_size = 1e-3

nuts = blackjax.nuts(logdensity, step_size, inv_mass_matrix)

In [None]:
initial_position = {"m": -1.0, "b": 5.0, "log_f": jnp.log(0.5)}
initial_state = nuts.init(initial_position)
initial_state

In [None]:
%%time
rng_key = jax.random.PRNGKey(0)
states = inference_loop(rng_key, nuts.step, initial_state, 4_000)

m_samples = states.position["m"].block_until_ready()
b_samples = states.position["b"].block_until_ready()
#loc_samples = states.position["loc"].block_until_ready()
#scale_samples = jnp.exp(states.position["log_scale"])

In [None]:
fig, (ax, ax1) = plt.subplots(ncols=2, figsize=(15, 6))
ax.plot(m_samples)
ax.set_xlabel("Samples")
ax.set_ylabel("m")

ax1.plot(b_samples)
ax1.set_xlabel("Samples")
ax1.set_ylabel("b")