In [1]:
import jax
import jax.numpy as jnp
import jax.scipy.stats as stats
import matplotlib.pyplot as plt
import numpy as np
from jax import random, jit

from pycbc.catalog import Merger
from pycbc.filter import resample_to_delta_t, highpass
from pycbc.psd import interpolate, inverse_spectrum_truncation

import gwjax
import gwjax.imrphenom

import optax

from jax.config import config
config.update("jax_debug_nans", True)

  from .autonotebook import tqdm as notebook_tqdm
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [2]:
# Define the conditiion data function
dynfac = 1.0e23
def condition(strain, sampling_rate):
    tmp = resample_to_delta_t(highpass(strain, 15.0), 1.0/sampling_rate).crop(2,2)
    return (dynfac*tmp)

# Define the PSD function 
def estimate_psd(strain, delta_f):
    psd = strain.psd(4)
    psd = interpolate(psd, delta_f)
    psd = inverse_spectrum_truncation(psd, int(4 * strain.sample_rate),
                                  low_frequency_cutoff=15, trunc_method='hann')
    return psd

In [3]:
# Get the data and estimate the PSDs
merger = Merger("GW150914")
sampling_rate = 2048 # Sampling rate in Hz
data = condition(merger.strain('H1'), sampling_rate)
signal_duration = float(data.duration)  # Duration of the signal in seconds
delta_freq = data.delta_f

invpsd = estimate_psd(data, data.delta_f)**(-1)

fcore = data.to_frequencyseries()*invpsd

#Get the frequency range
nyquist = sampling_rate//2
freqs = jnp.arange(1+(nyquist*signal_duration))/signal_duration


#Determine the low freq and high frequency cut off
kmin, kmax = int(15*signal_duration), int(900*signal_duration)

#Make everything the same length
data = jnp.array([data.to_frequencyseries()[kmin:kmax]])
fcore = jnp.asarray(fcore[kmin:kmax])
freqs = freqs[kmin:kmax]
invpsd = jnp.asarray(invpsd[kmin:kmax]) # can't take fft because it needs to start at 0 to nyquist 

In [11]:
def loglikelihood_fn(params, *_):
    m1, m2, theta1, theta2 = params
    #return jax.lax.cond( 0 < m1 < 100 and 0 < m2 < 100, loglikelihood_yes, loglikelihood_no, m1, m2, theta1, theta2) 
    #phase = jnp.tanh(conc_phase)*jnp.pi
    #m1, m2 = 100/(1+jnp.exp(-sigmoidm1)), 100/(1+jnp.exp(-sigmoidm2))
    #m1, m2 = -jnp.log(sigmoidm1)+3.0, -jnp.log(sigmoidm2)+3.0
    #s1, s2 = jnp.arccosh(theta1), jnp.arccosh(theta2)
    s1, s2 = 0.99*jnp.arctanh(theta1), 0.99*jnp.arctanh(theta2)
    
    parameters = {'phase': 0., 
              'geocent_time': 0., 
              'luminosity_distance': 1, 
              'theta_jn': 0., 
              'm1': m1, 'm2': m2, 
              'spin1': s1, 'spin2': s2, 
              'ra': 0., 'dec': 0., 
              'pol': 0.}
    
    hp, hc = gwjax.imrphenom.IMRPhenomD(freqs, parameters)
    d_inner_h = 0.
    optimal_snr_squared = 0.
    fp = -0.456852978678261
    fc = 0.36204310587763466
    h = hp*fp + hc*fc
    h_star = jnp.conj(h)
    integrand = jnp.nan_to_num(h_star * data * invpsd, 0.)
    d_inner_h += 4.0 / 32.0 * jnp.sum(integrand)
    integrand = jnp.nan_to_num(h_star * h * invpsd, 0.)
    optimal_snr_squared += (4.0 / 32.0 * jnp.sum(integrand)).real
    log_l = d_inner_h.real - optimal_snr_squared / 2
    return log_l.real
    

In [12]:
#loglikelihood_fn([jnp.log((30.0/100.0)-1), jnp.log((30.0/100.0)-1), jnp.arctanh(0.1/0.99),  jnp.arctanh(0.1/0.99), 0.01, jnp.arctanh(0.1/jnp.pi)])
#loglikelihood_fn([jnp.log(30.0/(100.0-30.0)), jnp.log(30.0/(100.0-30.0)), jnp.arctanh(0.1/0.99), jnp.arctanh(0.1/0.99)])
loglikelihood_fn([30.0, 30.0, jnp.cosh(0.0), jnp.cosh(0.0)])

Unexpected exception formatting exception. Falling back to standard exception


Traceback (most recent call last):
  File "/mnt/lustre/shared_conda/envs/sgreen/PyCBCandOptuna/lib/python3.11/site-packages/jax/_src/pjit.py", line 1252, in _pjit_call_impl
    return compiled.unsafe_call(*args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/lustre/shared_conda/envs/sgreen/PyCBCandOptuna/lib/python3.11/site-packages/jax/_src/profiler.py", line 314, in wrapper
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/lustre/shared_conda/envs/sgreen/PyCBCandOptuna/lib/python3.11/site-packages/jax/_src/interpreters/pxla.py", line 1920, in __call__
    dispatch.check_special(self.name, arrays)
  File "/mnt/lustre/shared_conda/envs/sgreen/PyCBCandOptuna/lib/python3.11/site-packages/jax/_src/dispatch.py", line 436, in check_special
    _check_special(name, buf.dtype, buf)
  File "/mnt/lustre/shared_conda/envs/sgreen/PyCBCandOptuna/lib/python3.11/site-packages/jax/_src/dispatch.py", line 441, in _check_special
    raise FloatingPointError(f"invalid v

In [None]:
from typing import NamedTuple

class ScheduleState(NamedTuple):
    step_size: float
    do_sample: bool


def build_schedule(
    num_training_steps,
    num_cycles=4,
    initial_step_size=1e-4,
    exploration_ratio=0.25,
):
    cycle_length = num_training_steps // num_cycles

    def schedule_fn(step_id):
        do_sample = False
        if ((step_id % cycle_length)/cycle_length) >= exploration_ratio:
            do_sample = True

        cos_out = jnp.cos(jnp.pi * (step_id % cycle_length) / cycle_length) + 1
        step_size = 0.5 * cos_out * initial_step_size

        return ScheduleState(step_size, do_sample)

    return schedule_fn

In [None]:
import jax
import jax.numpy as jnp

import matplotlib.pyplot as plt
import numpy as np

schedule_fn = build_schedule(20000, 4, 1e-3)
schedule = [schedule_fn(i) for i in range(20000)]

step_sizes = np.array([step.step_size for step in schedule])
do_sample = np.array([step.do_sample for step in schedule])

sampling_points = np.ma.masked_where(~do_sample, step_sizes)

fig, ax = plt.subplots(figsize=(12,8))
ax.plot(step_sizes, lw=2, ls="--", color="r", label="Exploration stage")
ax.plot(sampling_points, lw=2, ls="-", color="k", label="Sampling stage")

ax.spines.right.set_visible(False)
ax.spines.top.set_visible(False)

ax.set_xlabel("Training steps", fontsize=20)
ax.set_ylabel("Step size", fontsize=20)
plt.legend()
plt.title("Training schedule for Cyclical SGLD")

In [None]:
from typing import NamedTuple

import blackjax
import optax

from jax.tree_util import tree_structure
from optax._src.base import OptState


class CyclicalSGMCMCState(NamedTuple):
    """State of the Cyclical SGMCMC sampler.
    """
    position: tree_structure
    opt_state: OptState


def cyclical_sgld(grad_estimator_fn, loglikelihood_fn):

    # Initialize the SgLD step function
    sgld = blackjax.sgld(grad_estimator_fn)
    sgd = optax.sgd(1e-3)

    def init_fn(position):
        opt_state = sgd.init(position)
        print(CyclicalSGMCMCState(position, opt_state))
        return CyclicalSGMCMCState(position, opt_state)

    def step_fn(rng_key, state, minibatch, schedule_state):
        """Cyclical SGLD kernel."""

        def step_with_sgld(current_state):
            rng_key, state, minibatch, step_size = current_state
            new_position = sgld.step(rng_key, state.position, minibatch, step_size)
            print(new_position)
            return CyclicalSGMCMCState(new_position, state.opt_state)

        def step_with_sgd(current_state):
            _, state, minibatch, step_size = current_state
            grads = grad_estimator_fn(state.position, 0)
            rescaled_grads = - 1. * step_size * grads
            updates, new_opt_state = sgd.update(rescaled_grads, state.opt_state, state.position)
            new_position = optax.apply_updates(state.position, updates)
            return CyclicalSGMCMCState(new_position, new_opt_state)

        new_state = jax.lax.cond(
            schedule_state.do_sample,
            step_with_sgld,
            step_with_sgd,
            (rng_key, state, minibatch, schedule_state.step_size)
        )

        return new_state

    return init_fn, step_fn

In [None]:
import jax
from fastprogress import progress_bar


# 50k iterations
# M = 30
# initial step size = 0.09
# ratio exploration = 1/4
num_training_steps = 5000
schedule_fn = build_schedule(num_training_steps, 30.0, 0.09, 0.25)
schedule = [schedule_fn(i) for i in range(num_training_steps)]


grad_fn = lambda x, _: jax.grad(loglikelihood_fn)(x)
init, step = cyclical_sgld(grad_fn, loglikelihood_fn)

rng_key = jax.random.PRNGKey(3)
init_position = jnp.array([30.0, 30.0, 0.0, 0.0])
init_state = init(init_position)
print(init_state)

In [13]:
state = init_state
cyclical_samples = []
for i in progress_bar(range(num_training_steps)):
    _, rng_key = jax.random.split(rng_key)
    state = jax.jit(step)(rng_key, state, 0, schedule[i])
    if schedule[i].do_sample:
        cyclical_samples.append(state.position)

Traced<ShapedArray(float64[4])>with<DynamicJaxprTrace(level=2/0)>
Invalid nan value encountered in the output of a C++-jit/pmap function. Calling the de-optimized version.
Unexpected exception formatting exception. Falling back to standard exception


Traceback (most recent call last):
  File "/mnt/lustre/shared_conda/envs/sgreen/PyCBCandOptuna/lib/python3.11/site-packages/jax/_src/api.py", line 131, in _nan_check_posthook
    dispatch.check_special(pjit.pjit_p, buffers)
  File "/mnt/lustre/shared_conda/envs/sgreen/PyCBCandOptuna/lib/python3.11/site-packages/jax/_src/dispatch.py", line 436, in check_special
    _check_special(name, buf.dtype, buf)
  File "/mnt/lustre/shared_conda/envs/sgreen/PyCBCandOptuna/lib/python3.11/site-packages/jax/_src/dispatch.py", line 441, in _check_special
    raise FloatingPointError(f"invalid value (nan) encountered in {name}")
FloatingPointError: invalid value (nan) encountered in pjit

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/mnt/lustre/shared_conda/envs/sgreen/PyCBCandOptuna/lib/python3.11/site-packages/jax/_src/pjit.py", line 1252, in _pjit_call_impl
    return compiled.unsafe_call(*args)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^


In [None]:
state

In [None]:
cyclical_samples

In [None]:
fig = plt.figure()
ax = fig.add_subplot(111)
x = jnp.array([sample[0] for sample in cyclical_samples])
y = jnp.array([sample[1] for sample in cyclical_samples])
z = jnp.array([sample[2] for sample in cyclical_samples])
a = jnp.array([sample[3] for sample in cyclical_samples])


ax.plot(x, y, 'k-', lw=0.1, alpha=0.5)

plt.axis('off')
plt.title("Trajectory with Cyclical SGLD")

In [None]:
x

In [None]:
print(jnp.shape(cyclical_samples))

In [None]:
import corner

figure = corner.corner(np.array([cyclical_samples]), labels=[
        r"sigmoidm1",
        r"sigmoidm2",
        r"theta1",
        r"theta2",
    
    ], show_titles=True,
    quantiles=[0.16, 0.5, 0.84],
    title_kwargs={"fontsize": 12})

In [None]:
print(100 / (1+jnp.exp(-159)))
print(100 / (1+jnp.exp(-118)))
print(jnp.tanh(92)*jnp.pi)
print(jnp.tanh(115)*jnp.pi)

In [None]:
mass1 = 100 / (1+jnp.exp(-x))
mass2 = 100 / (1+jnp.exp(-y))
spin1 = jnp.tanh(z)*jnp.pi
spin2 = jnp.tanh(a)*jnp.pi



revamped_cyclicalsamples = jnp.reshape(jnp.stack([mass1, mass2, spin1, spin2]), newshape=(3720, 4))

In [None]:
mass1

In [None]:
jnp.shape(revamped_cyclicalsamples)

In [None]:
figure = corner.corner(np.array([revamped_cyclicalsamples]), labels=[
        r"mass1",
        r"mass1",
        r"spin1",
        r"spin1",
    
    ], show_titles=True,
    quantiles=[0.16, 0.5, 0.84],
    title_kwargs={"fontsize": 12})