In [1]:
import jax
import jax.numpy as jnp
import jax.scipy.stats as stats
import matplotlib.pyplot as plt
import numpy as np

from functools import partial
from sgmcmcjax.samplers import build_sgld_sampler

from jax import random, jit

import gwjax
import gwjax.imrphenom

from fastprogress import progress_bar

from jax.config import config
config.update("jax_debug_nans", True)

  from .autonotebook import tqdm as notebook_tqdm
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [2]:
data = {"L1": jnp.load("/users/sgreen/gwtuna/LVK/Paper/BBHSearch/P.E./L1_data.npy")}#, "H1": jnp.load("H1_data.npy")}
freqs = jnp.load("/users/sgreen/gwtuna/LVK/Paper/BBHSearch/P.E./L1_freqs.npy")
psd = {"L1": jnp.load("/users/sgreen/gwtuna/LVK/Paper/BBHSearch/P.E./L1_psd.npy")}#, "H1": jnp.load("H1_psd.npy")}

In [3]:
data = data['L1']
psd = psd['L1']

Xdata = jnp.array([data, psd, freqs])

In [4]:
def waveform_template(freqs, params):
    # Create the waveform template using the mass parameters
    # Adjust this function based on your specific waveform model
    #may want to do partial on this
    return gwjax.imrphenom.IMRPhenomD(freqs, params)

def log_likelihood(theta, data):
    m1, m2 = theta
    params = {'phase': 0., 
              'geocent_time': 0.,
              'luminosity_distance': 1., 
              'theta_jn': 0., 
              'm1': m1, 'm2': m2, 
              'spin1': 0., 'spin2': 0., 
              'ra': 0., 'dec': 0., 
              'pol': 0.}
    hp, hc = waveform_template(freqs, params)
    d_inner_h = 0.
    optimal_snr_squared = 0.
    fp = -0.456852978678261
    fc = 0.36204310587763466
    h = hp*fp + hc*fc
    h_star = jnp.conj(h)
    integrand = jnp.nan_to_num(h_star * data / psd, 0.)
    d_inner_h += 4.0 / 32.0 * jnp.sum(integrand)
    integrand = jnp.nan_to_num(h_star * h / psd, 0.)
    optimal_snr_squared += (4.0 / 32.0 * jnp.sum(integrand)).real
    log_l = d_inner_h.real - optimal_snr_squared / 2
    return log_l.real

def log_prior(theta):
    m1, m2 = theta
    return jnp.log(m1*m2)

In [5]:
# build sampler
batch_size = len(data)
dt = 1e-5
sampler = build_sgld_sampler(dt, log_likelihood, log_prior, (Xdata,), batch_size)

In [6]:
# jit the sampler
sampler = partial(jit, static_argnums=(1,))(sampler)

In [7]:
# run sampler
Nsamples = 2000
key = random.PRNGKey(0)
initial_theta = jnp.array([30.0, 30.0])
samples = sampler(key, Nsamples, initial_theta)

Running for 2,000 iterations: 100%|██████████| 2000/2000 [02:18<00:00, 14.40it/s]
Running for 2,000 iterations: 100%|██████████| 2000/2000 [02:21<00:00, 14.12it/s]


Unexpected exception formatting exception. Falling back to standard exception


Traceback (most recent call last):
  File "/mnt/lustre/shared_conda/envs/sgreen/PyCBCandOptuna/lib/python3.11/site-packages/jax/_src/pjit.py", line 1252, in _pjit_call_impl
    return compiled.unsafe_call(*args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/lustre/shared_conda/envs/sgreen/PyCBCandOptuna/lib/python3.11/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/lustre/shared_conda/envs/sgreen/PyCBCandOptuna/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 1920, in __call__
    dispatch.check_special(self.name, arrays)
  File "/mnt/lustre/shared_conda/envs/sgreen/PyCBCandOptuna/lib/python3.11/site-packages/jax/_src/dispatch.py", line 436, in check_special
    _check_special(name, buf.dtype, buf)
  File "/mnt/lustre/shared_conda/envs/sgreen/PyCBCandOptuna/lib/python3.11/site-packages/jax/_src/dispatch.py", line 441, in _check_special
    raise FloatingPointError(f"invalid v

In [None]:
jnp.shape(samples)

In [None]:
plt.plot(samples[:, 0])