In [1]:
import numpy as np
import matplotlib.pyplot as plt
import pycbc
from pycbc.catalog import Merger
from pycbc.filter import resample_to_delta_t, highpass
from pycbc.psd import interpolate, inverse_spectrum_truncation

import jax
import jax.numpy as jnp
from functools import partial
import gwjax
import gwjax.imrphenom
from gwjax import CplxNormal

from fastprogress import progress_bar

import blackjax

import numpyro
from numpyro import infer
from numpyro import distributions as dist
from numpyro.infer.util import initialize_model

  from .autonotebook import tqdm as notebook_tqdm
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [2]:
ifos = ['H1', 'L1']

In [3]:
def condition(strain):
    tmp = resample_to_delta_t(highpass(strain, 15.0), 1.0/2048).crop(2,2)
    return (1e23*tmp)

def estimate_psd(strain, delta_f):
    psd = strain.psd(4)
    psd = interpolate(psd, delta_f)
    psd = inverse_spectrum_truncation(psd, int(4 * strain.sample_rate),
                                  low_frequency_cutoff=15, trunc_method='hann')
    return psd

def whiten(strain):
    psd = estimate_psd(strain, strain.delta_f)
    return (strain.to_frequencyseries() / np.sqrt(psd)).to_timeseries()

In [4]:
merger = Merger("GW150914")
gps_time = merger.time

# Get the data and estimate the PSDs
strain = {ifo: condition(merger.strain(ifo)) for ifo in ifos}

wdata = {}
invasd = {}
for ifo in ifos:
    data = strain[ifo]
    
    srate = int(1/data.delta_t)
    idx = int(srate * (gps_time - data.start_time))
    
    wdata[ifo] = whiten(data)[idx-4*srate:idx+4*srate]
    invasd[ifo] = estimate_psd(data, 1/8)**-0.5

kmin, kmax = 15*8, 900*8

In [5]:
gmst = gwjax.gmst(gps_time)

In [6]:
freqs = jnp.arange(1+1024*8, dtype=jnp.float32)/8

In [7]:
params = {'phase': 0., 'geocent_time': 0.,
            'luminosity_distance': 1, 'theta_jn': 0., 
            'm1': 36, 'm2': 29, 'spin1': 0., 'spin2': 0.,
            'ra': 0., 'dec': 0., 'pol': 0.}

In [8]:
def project_to_detector(params, htup, freqs, invasd, detector, gmst):
    deltat = gwjax.time_delay_from_earth_center(detector, gmst,
                                                params['ra'], params['dec'])
    fp, fx = gwjax.antenna_pattern(detector, gmst,
                                    params['ra'], params['dec'], params['pol'])
    return invasd*(fp*htup[0] + fx*htup[1])*jnp.exp(-1.j*2*jnp.pi*freqs*deltat)

def waveform_projections(params, freqs, project_dict, gmst):
    htup = gwjax.imrphenom.IMRPhenomD(freqs, params)
    return {ifo: project(params, htup) for ifo, project in project_dict.items()}

In [9]:
fdata = {}
for ifo in ifos:
    fdata[ifo] = jnp.fft.rfft(jnp.array(np.roll(wdata[ifo], -4*srate), dtype=jnp.float32))

invasd = {ifo: jnp.array(iasd, dtype=jnp.float32) for ifo, iasd in invasd.items()}

project = {}
for ifo in ifos:
    project[ifo] = jax.jit(partial(project_to_detector,
                                        freqs=freqs[kmin:kmax],
                                        invasd=invasd[ifo][kmin:kmax],
                                        detector=gwjax.detectors[ifo],
                                        gmst=gmst))

waveform_projections = jax.jit(partial(waveform_projections,
                                        freqs=freqs[kmin:kmax],
                                        project_dict=project,
                                        gmst=gmst))

In [22]:
def cbc_model(fdata=None):
    amp = numpyro.sample("amp", dist.Normal(0, 10))
    t = numpyro.sample("time", dist.Normal(0, 0.01))
    phase = numpyro.sample('phase', dist.Uniform(-jnp.pi, jnp.pi))
    m1 = numpyro.sample('m1', dist.Uniform(20, 50))
    m2 = numpyro.sample('m2', dist.Uniform(20, 50))
    ra = numpyro.sample('ra', dist.Uniform(-jnp.pi, jnp.pi))
    dec = numpyro.sample('dec', dist.Uniform(0., jnp.pi))
    pol = numpyro.sample('psi', dist.Uniform(0, 2*jnp.pi))

    proj_wfs = waveform_projections({'phase': phase, 'geocent_time': t,
            'luminosity_distance': amp, 'theta_jn': 0., 
            'm1': 36, 'm2': 29, 'spin1': 0., 'spin2': 0.,
            'ra': ra, 'dec': dec, 'pol': pol})

    numpyro.sample("y1", CplxNormal(proj_wfs['H1']), obs=fdata['H1'])
    numpyro.sample("y2", CplxNormal(proj_wfs['L1']), obs=fdata['L1'])

In [23]:
rng_key = jax.random.PRNGKey(0)
init_params, potential_fn_gen, *_ = initialize_model(
    rng_key,
    cbc_model,
    model_args=(fdata),
    dynamic_args=True,
)

TypeError: cbc_model() takes from 0 to 1 positional arguments but 2 were given

In [None]:
logdensity_fn = lambda position: -potential_fn_gen(fdata)(position)
initial_position = init_params.z

In [None]:
import blackjax

num_warmup = 2000

adapt = blackjax.window_adaptation(
    blackjax.nuts, logdensity_fn, target_acceptance_rate=0.8
)
(last_state, parameters), _ = adapt.run(rng_key, initial_position, num_warmup)
kernel = blackjax.nuts(logdensity_fn, **parameters).step

In [None]:
def inference_loop(rng_key, kernel, initial_state, num_samples):
    @jax.jit
    def one_step(state, rng_key):
        state, info = kernel(rng_key, state)
        return state, (state, info)

    keys = jax.random.split(rng_key, num_samples)
    _, (states, infos) = jax.lax.scan(one_step, initial_state, keys)

    return states, (
        infos.acceptance_rate,
        infos.is_divergent,
        infos.num_integration_steps,
    )

In [None]:
num_sample = 1000

states, infos = inference_loop(rng_key, kernel, last_state, num_sample)
_ = states.position["mu"].block_until_ready()

In [None]:
acceptance_rate = np.mean(infos[0])
num_divergent = np.mean(infos[1])

print(f"\Average acceptance rate: {acceptance_rate:.2f}")
print(f"There were {100*num_divergent:.2f}% divergent transitions")

In [None]:
import seaborn as sns
from matplotlib import pyplot as plt

samples = states.position

fig, axes = plt.subplots(ncols=2)
fig.set_size_inches(12, 5)
sns.kdeplot(samples["mu"], ax=axes[0])
sns.kdeplot(samples["tau"], ax=axes[1])
axes[0].set_xlabel("mu")
axes[1].set_xlabel("tau")
fig.tight_layout()