In [1]:
import jax.numpy as jnp
from ripple.waveforms import IMRPhenomXAS
from ripple import ms_to_Mc_eta
import copy
import jax

from fastprogress import progress_bar

import blackjax
import blackjax.sgmcmc.gradients as gradients

from pycbc.catalog import Merger
from pycbc.filter import resample_to_delta_t, highpass
from pycbc.psd import interpolate, inverse_spectrum_truncation
from pycbc.waveform import get_fd_waveform
from pycbc.filter import matched_filter, sigmasq, get_cutoff_indices

from jax.scipy.stats import multivariate_normal

import blackjax
import blackjax.smc.resampling as resampling

import numpyro
from numpyro import infer
from numpyro import distributions as dist

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Directory that holds the saved .npy inputs. Kept in one place so the notebook
# can be pointed at another location by editing a single line.
DATA_DIR = "/users/sgreen/gwtuna/LVK/Paper/BBHSearch/P.E."

# `data` and `psd` are keyed by detector name; only L1 is loaded here (the H1
# entries are left commented out). `freqs` is the frequency grid the waveform
# is evaluated on in `log_likelihood`.
data = {"L1": jnp.load(f"{DATA_DIR}/L1_data.npy")}  # , "H1": jnp.load("H1_data.npy")}
freqs = jnp.load(f"{DATA_DIR}/L1_freqs.npy")
psd = {"L1": jnp.load(f"{DATA_DIR}/L1_psd.npy")}  # , "H1": jnp.load("H1_psd.npy")}

In [3]:
# Strain-conditioning helper
dynfac = 1.0e23  # constant factor applied to the conditioned strain
def condition(strain, sampling_rate):
    """High-pass, resample, crop and rescale a strain series.

    The input goes through ``highpass`` with a cutoff of 15.0, is resampled to
    a sample spacing of ``1.0 / sampling_rate``, trimmed with ``crop(2, 2)``
    and finally multiplied by ``dynfac``.

    Args:
        strain: series accepted by ``pycbc.filter.highpass``
            (presumably a PyCBC TimeSeries; confirm against the caller).
        sampling_rate: target sampling rate; its reciprocal is the new
            sample spacing.

    Returns:
        The conditioned series scaled by ``dynfac``.
    """
    filtered = highpass(strain, 15.0)
    resampled = resample_to_delta_t(filtered, 1.0 / sampling_rate)
    trimmed = resampled.crop(2, 2)
    return dynfac * trimmed

# PSD estimation helper
def estimate_psd(strain, delta_f):
    """Estimate a PSD from ``strain`` on a grid with spacing ``delta_f``.

    Three steps, in order: ``strain.psd(4)``, interpolation to ``delta_f``,
    then ``inverse_spectrum_truncation`` with a Hann window, a maximum filter
    length of ``int(4 * strain.sample_rate)`` samples and a low-frequency
    cutoff of 15.

    Args:
        strain: object exposing ``psd`` and ``sample_rate``
            (presumably a PyCBC TimeSeries; confirm against the caller).
        delta_f: frequency spacing of the returned PSD.

    Returns:
        The interpolated, truncated PSD.
    """
    raw_estimate = strain.psd(4)
    on_target_grid = interpolate(raw_estimate, delta_f)
    max_filter_len = int(4 * strain.sample_rate)
    return inverse_spectrum_truncation(
        on_target_grid,
        max_filter_len,
        low_frequency_cutoff=15,
        trunc_method='hann',
    )

@jax.jit
def log_likelihood(mchirp, eta, chi1, chi2, tc, phic, dist_mpc, inclination):
    """Log-likelihood of the loaded data given one set of source parameters.

    For every detector in the module-level ``data`` dict the plus and cross
    polarizations are combined into a detector response ``h`` and the value
    ``Re<d|h> - <h|h> / 2`` is accumulated, with both inner products weighted
    by ``1 / psd[ifo]``.

    Args:
        mchirp, eta, chi1, chi2, tc, phic, dist_mpc, inclination: waveform
            parameters, forwarded unchanged to ``waveform_generator``.

    Returns:
        Real scalar log-likelihood summed over the detectors in ``data``.
    """
    hp, hc = waveform_generator(
        freqs, mchirp, eta, chi1, chi2, tc, phic, dist_mpc, inclination,
    )
    # Prefactor shared by both inner products.
    # NOTE(review): looks like 4 * delta_f with delta_f = 1/32; confirm that
    # this matches the spacing of `freqs`.
    inner_product_norm = 4.0 / 32.0
    d_inner_h = 0.
    optimal_snr_squared = 0.
    # `data` is an ordinary dict, so this loop is unrolled when jit traces it.
    for ifo in data.keys():
        # Hard-coded weights for the plus / cross polarizations.
        # NOTE(review): presumably antenna-pattern factors for one fixed sky
        # position and time; every detector other than L1 gets the second pair.
        if ifo == "L1":
            fp, fc = -0.456852978678261, 0.36204310587763466
        else:
            fp, fc = 0.45529254427236665, -0.6283981252126967
        h = hp*fp + hc*fc
        h_star = jnp.conj(h)
        # `nan` must be passed by keyword: the second positional parameter of
        # jnp.nan_to_num is `copy`, not the NaN replacement value.
        integrand = jnp.nan_to_num(h_star * data[ifo] / psd[ifo], nan=0.0)
        d_inner_h += inner_product_norm * jnp.sum(integrand)
        integrand = jnp.nan_to_num(h_star * h / psd[ifo], nan=0.0)
        optimal_snr_squared += (inner_product_norm * jnp.sum(integrand)).real
    return d_inner_h.real - optimal_snr_squared / 2

def prior_model():
    """Draw the eight waveform parameters from independent uniform priors.

    Each parameter is a NumPyro sample site named after itself. The sites are
    declared in the order of ``log_likelihood``'s positional arguments, and
    the drawn values are returned as a tuple in that same order.

    NOTE(review): the ``tc`` bounds are about 1.1e9 and only 4 apart, which
    float32 cannot resolve; confirm that 64-bit mode is enabled in JAX.
    """
    # (site name, lower bound, upper bound)
    uniform_bounds = (
        ('mchirp', 10., 50.),
        ('eta', 0.1, 0.25),
        ('chi1', -0.9, 0.9),
        ('chi2', -0.9, 0.9),
        ('tc', 1126259460, 1126259464),
        ('phic', 0, 3.14),
        ('dist_mpc', 100., 2000.),
        ("inclination", 0.0, 3.14),
    )
    draws = [
        numpyro.sample(site, dist.Uniform(low, high))
        for site, low, high in uniform_bounds
    ]
    return tuple(draws)

@jax.jit
def waveform(fs, theta, fref):
    """Jitted pass-through to ``IMRPhenomXAS.gen_IMRPhenomXAS_polar``.

    Args:
        fs: frequencies at which the waveform is evaluated.
        theta: parameter array in the order the ripple generator expects.
        fref: reference frequency.

    Returns:
        Whatever the ripple generator returns; the caller in this notebook
        unpacks it into two polarizations.
    """
    polarizations = IMRPhenomXAS.gen_IMRPhenomXAS_polar(fs, theta, fref)
    return polarizations

@jax.jit
def waveform_generator(
    fs, mchirp, eta, chi1, chi2, tc, phic, dist_mpc, inclination,
    **kwargs
):
    """Pack the scalar parameters into one array and evaluate ``waveform``.

    Args:
        fs: frequencies at which the waveform is evaluated.
        mchirp, eta, chi1, chi2, tc, phic, dist_mpc, inclination: source
            parameters.
        **kwargs: only ``fref`` is read; it defaults to 20 when absent.

    Returns:
        Tuple ``(hp, hc)`` as produced by ``waveform``.
    """
    fref = kwargs["fref"] if "fref" in kwargs else 20
    # The packed order differs from the argument order: distance comes before
    # tc and phic.
    packed = [mchirp, eta, chi1, chi2, dist_mpc, tc, phic, inclination]
    plus, cross = waveform(fs, jnp.array(packed), fref)
    return plus, cross

In [4]:
def inference_loop(rng_key, mcmc_kernel, initial_state, num_samples):
    """Run ``mcmc_kernel`` for ``num_samples`` steps and return every state.

    Args:
        rng_key: PRNG key, split into one sub-key per step.
        mcmc_kernel: callable ``(key, state) -> (new_state, info)``; the
            ``info`` part is discarded.
        initial_state: state the chain starts from.
        num_samples: number of transitions to perform.

    Returns:
        The states visited after each step, stacked along a leading axis by
        ``jax.lax.scan``.
    """
    step_keys = jax.random.split(rng_key, num_samples)

    @jax.jit
    def advance(current, step_key):
        following, _info = mcmc_kernel(step_key, current)
        # scan wants (carry, per-step output); both are the new state here.
        return following, following

    _last, trajectory = jax.lax.scan(advance, initial_state, step_keys)
    return trajectory


def full_logdensity(params):
    return prior_model() + prior_log_prob(params)


inv_mass_matrix = jnp.eye(1)
n_samples = 10_000

In [5]:
%%time

key = jax.random.PRNGKey(42)

hmc_parameters = dict(
    step_size=1e-4, inverse_mass_matrix=inv_mass_matrix, num_integration_steps=50
)

hmc = blackjax.hmc(full_logdensity, **hmc_parameters)
hmc_state = hmc.init(jnp.ones((1,)))
hmc_samples = inference_loop(key, hmc.step, hmc_state, n_samples)

TypeError: unexpected PRNG key type <class 'NoneType'>