In [1]:
import os
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=10 --xla_cpu_enable_fast_math=true'

import jax
print(jax.devices())

import warnings
warnings.filterwarnings("ignore", "Wswiglal-redir-stdio")

import numpy as np
import jax.numpy as jnp
import jax 
from jax import grad, vmap
jax.config.update("jax_enable_x64", True)


import matplotlib.pyplot as plt

import scipy.interpolate as interp
import scipy.integrate as integ
import scipy.linalg as sla

import fisher_jim_tgr_2par as lib
import pycbc.conversions

import astropy.units as u
from astropy import constants as const

Ms = (u.Msun * const.G / const.c**3 ).si.value
from datetime import datetime
datestr = datetime.now().strftime('%m-%d-%y')

import matplotlib as mpl
from matplotlib.legend_handler import HandlerLine2D, HandlerPatch

def reset_matplotlib():
    # Reset all matplotlib settings to defaults
    mpl.rcdefaults()

    # Create a default handler map and update it globally
    default_handler_map = {
        mpl.lines.Line2D: HandlerLine2D(numpoints=1),
        mpl.patches.Patch: HandlerPatch()
    }

    # Update the default handler map globally
    mpl.legend.Legend.update_default_handler_map(default_handler_map)

# Call this function at the start of your script to globally reset settings
reset_matplotlib()


An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


[CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3), CpuDevice(id=4), CpuDevice(id=5), CpuDevice(id=6), CpuDevice(id=7), CpuDevice(id=8), CpuDevice(id=9)]


In [2]:
# Ground stuff
n_freq = 2000
fmin = 10.
fmax = 1000.
freqs = jnp.logspace(jnp.log10(fmin), jnp.log10(fmax), num = int(n_freq))


full_str = ["M_c", "eta", "d_L", "ra", "dec", "iota", "psi", "t_c", "phase_c", "zco", "rng", "Mc_source", "snr_H1", "snr_L1", "snr_V1", "snr_t"]
names = full_str[0:9]
# data = np.loadtxt("data/events.txt")
data = jnp.array(np.loadtxt("data/5-year-data-O3/5-year-data-O3-events.txt"))
# data = jnp.array(nnp.loadtxt("data/5-year-data-CE/5-year-data-CE-events.txt"))

psd = lib.read_mag(freqs, "curves/o3_l1.txt")**2
# psd = lib.read_mag(freqs, "../curves/ce1.txt")**2

data_150914 = np.array(data[0])
data_150914[0] = 30
q  = 0.79
data_150914[1] = q/(1+q)**2
data_150914[2] = 390

idx = '150914'
if idx=='150914':
    dat = jnp.array(data_150914)

red_param = dict(zip(names, jnp.array(dat).T))

In [3]:
from jimgw.detector import H1, L1, V1
from jimgw.waveform import RippleIMRPhenomPv2, RippleIMRPhenomD
import jax
import jax.numpy as jnp
from jax import grad, vmap

jax.config.update("jax_enable_x64", True)
jax.config.update("jax_debug_nans", True)
# Assuming 'waveform' is defined elsewhere, or you can import it
waveform = RippleIMRPhenomPv2(f_ref=20)
# waveform = RippleIMRPhenomD(f_ref=20)

def get_h_slow(x, f, detector):
    # Set default parameters
    

    ff = jnp.array([f])
    h_sky = waveform(ff, x)
    align_time = jnp.exp(-1j * 2 * jnp.pi * ff * (x['epoch'] + x['t_c']))
    signal = detector.fd_response(ff, h_sky, x) * align_time
    return signal[0]

# Create generic functions for each detector
def get_h_nojit(x, f, detector):
    return vmap(lambda f_single: get_h_slow(x, f_single, detector))(f)



red_param = dict(zip(names, jnp.array(dat).T))
epsilon = jnp.array(1e-6)
for key in ['s1_x', 's1_y', 's1_z', 's2_x', 's2_y', 's2_z']:
    red_param[key] = epsilon
zero_vector = jnp.array(0.)
for key in ['gmst', 'epoch']:
    red_param[key] = zero_vector

x = red_param.copy()
# x['s1_z'] = 0.2
gr_param_diff = names + ['s1_z', 's1_x']

get_h1 = jax.jit(lambda x : get_h_nojit(x, freqs, H1))
get_h2 = jax.jit(lambda x : get_h_nojit(x, freqs, L1))
get_h3 = jax.jit(lambda x : get_h_nojit(x, freqs, V1))


get_h = lambda x : {'H1': get_h1(x), 'L1': get_h2(x), 'V1': get_h3(x)}


h = get_h(x)

keys = list(x.keys())
idx_diff = tuple(i for i, key in enumerate(keys) if key in gr_param_diff)

def f_wrapped(*args, freqs = None, det= None):
    x = dict(zip(keys, args)) 
    return get_h_nojit(x, freqs, H1)

get_dh1 = jax.jit(jax.jacfwd(lambda *x: f_wrapped(*x, freqs = freqs, det = H1), argnums = idx_diff))
get_dh2 = jax.jit(jax.jacfwd(lambda *x: f_wrapped(*x, freqs = freqs, det = H1), argnums = idx_diff))
get_dh3 = jax.jit(jax.jacfwd(lambda *x: f_wrapped(*x, freqs = freqs, det = H1), argnums = idx_diff))

def get_dh_gr(x):
    xvalues = list(x.values())

    dh = {'H1': dict(zip(gr_param_diff, get_dh1(*xvalues))) , 'L1': dict(zip(gr_param_diff, get_dh2(*xvalues))), 'V1': dict(zip(gr_param_diff, get_dh3(*xvalues)))}
    return dh

dh = get_dh_gr(x)

##  Inner Product Minimum Working Ex

In [5]:
paramx = ['M_c', 'eta', 'd_L', 'ra', 'dec', 'iota', 'psi', 't_c', 'phase_c']
idx_x = {paramx[i] : i for i in range(len(paramx))}
# %timeit compute_bias(dh["H1"], h["H1"], psd, freqs, idx_x)
log_flag =  {paramx[i] : 0 for i in range(len(paramx))}; log_flag["M_c"] = 1; log_flag["d_L"] = 1
logmult = {paramx[i]: (x[paramx[i]] if log_flag[paramx[i]] else 1) for i in range(len(paramx))}



def innprodslow(hf1, hf2, psd, freqs):
    integrand = (jnp.conj(hf1) * hf2 + hf1 * jnp.conj(hf2)) / psd
    prod = 2. * jnp.trapezoid(integrand, freqs)
    return prod

freqs = jnp.array(freqs, dtype=jnp.float64)
psd = jnp.array(psd, dtype=jnp.float64)
innprodfast = jax.jit(innprodslow)

dh_dim  = {key: logmult[key] * dh["H1"][key] for key in logmult}
i1 = "d_L"; i2 = 't_c'
hf1 = dh_dim[i1].copy(); hf2 = dh_dim[i2].copy()
innprodslow(hf1,hf2, psd, freqs), innprodfast(hf1,hf2, psd, freqs)

(Array(-1.81415985e-12+0.j, dtype=complex128),
 Array(-2.12628272e-12+0.j, dtype=complex128))

## Fisher Matrix Demo

fish slow uses the slow inner product with no jit

fish fast uses the fast inner product with jit

In [10]:
def fishslow(freqs, dh, par, idx_par, psd, log_flag):
    n_pt = len(freqs)
    n_dof = len(idx_par)

    dh_arr = jnp.zeros([n_dof, n_pt], dtype=jnp.complex128)

    # Convert idx_par to a list for static looping
    idx_list = list(idx_par.keys())
    for idx in idx_list:
        idx_position = idx_par[idx]
        dh_arr = dh_arr.at[idx_position, :].set(dh[idx])

        # Use jax.lax.cond for conditional multiplication
        dh_arr = dh_arr.at[idx_position, :].set(
            jax.lax.cond(
                log_flag[idx],
                lambda x: x * par[idx],
                lambda x: x,
                dh_arr[idx_position, :]
            )
        )

    gamma = jnp.zeros([n_dof, n_dof], dtype=jnp.float64)

    # Use static loops
    for i in range(n_dof):
        for j in range(i, n_dof):
            gamma = gamma.at[i, j].set(
                jnp.real(innprodslow(dh_arr[i, :], dh_arr[j, :], psd, freqs))
            )
        for j in range(i):
            gamma = gamma.at[i, j].set(jnp.conj(gamma[j, i]))

    return gamma

def fishfast(freqs, dh, par, idx_par, psd, log_flag):
    n_pt = len(freqs)
    n_dof = len(idx_par)

    dh_arr = jnp.zeros([n_dof, n_pt], dtype=jnp.complex128)

    # Convert idx_par to a list for static looping
    idx_list = list(idx_par.keys())
    for idx in idx_list:
        idx_position = idx_par[idx]
        dh_arr = dh_arr.at[idx_position, :].set(dh[idx])

        # Use jax.lax.cond for conditional multiplication
        dh_arr = dh_arr.at[idx_position, :].set(
            jax.lax.cond(
                log_flag[idx],
                lambda x: x * par[idx],
                lambda x: x,
                dh_arr[idx_position, :]
            )
        )

    gamma = jnp.zeros([n_dof, n_dof], dtype=jnp.float64)

    # Use static loops
    for i in range(n_dof):
        for j in range(i, n_dof):
            gamma = gamma.at[i, j].set(
                jnp.real(innprodfast(dh_arr[i, :], dh_arr[j, :], psd, freqs))
            )
        for j in range(i):
            gamma = gamma.at[i, j].set(jnp.conj(gamma[j, i]))

    return gamma

In [11]:
import pandas as pd 

display(pd.DataFrame(fishslow(freqs, dh["H1"], x, idx_x, psd, log_flag) ,paramx,paramx))
pd.DataFrame(np.abs(fishslow(freqs, dh["H1"], x, idx_x, psd, log_flag)- fishfast(freqs, dh["H1"], x, idx_x, psd, log_flag))/fishslow(freqs, dh["H1"], x, idx_x, psd, log_flag),paramx,paramx)

Unnamed: 0,M_c,eta,d_L,ra,dec,iota,psi,t_c,phase_c
M_c,194111.7,-1449215.0,-552.6458,34807.38,62480.83,28.508156,-16661.577969,-3821259.0,17634.043381
eta,-1449215.0,15020910.0,-1260.883,-478321.8,-1023653.0,4749.994527,189497.146149,60402390.0,-200777.338549
d_L,-552.6458,-1260.883,691.232,-445.637,-181.9569,-464.401857,19.028877,-1.81416e-12,0.000149
ra,34807.38,-478321.8,-445.637,21121.28,50590.44,162.432783,-6663.681278,-2901309.0,7046.060381
dec,62480.83,-1023653.0,-181.9569,50590.44,127089.0,-183.354449,-14845.701722,-7248934.0,15721.231998
iota,28.50816,4749.995,-464.4019,162.4328,-183.3544,313.051593,37.952096,17862.13,-53.747032
psi,-16661.58,189497.1,19.02888,-6663.681,-14845.7,37.952096,2464.400518,867424.8,-2610.065618
t_c,-3821259.0,60402390.0,-1.81416e-12,-2901309.0,-7248934.0,17862.134437,867424.827852,414204500.0,-918891.554292
phase_c,17634.04,-200777.3,0.0001488438,7046.06,15721.23,-53.747032,-2610.065618,-918891.6,2764.928385


Unnamed: 0,M_c,eta,d_L,ra,dec,iota,psi,t_c,phase_c
M_c,0.0,-0.0,-0.0,0.0,0.0,0.0,-0.0,-0.0,0.0
eta,-0.0,0.0,-1.80329e-16,-0.0,-0.0,0.0,0.0,0.0,-0.0
d_L,-0.0,-1.80329e-16,0.0,-0.0,-0.0,-0.0,0.0,-0.172048,4.39508e-12
ra,0.0,-0.0,-0.0,0.0,0.0,1.749752e-16,-1.364853e-16,-0.0,0.0
dec,0.0,-0.0,-0.0,0.0,0.0,-0.0,-0.0,-0.0,0.0
iota,0.0,0.0,-0.0,1.749752e-16,-0.0,0.0,0.0,0.0,-0.0
psi,-0.0,0.0,0.0,-1.364853e-16,-0.0,0.0,0.0,0.0,-0.0
t_c,-0.0,0.0,-0.1720482,-0.0,-0.0,0.0,0.0,0.0,-0.0
phase_c,0.0,-0.0,4.39508e-12,0.0,0.0,-0.0,-0.0,-0.0,0.0


In [7]:
dh

NameError: name 'dh' is not defined

* First entry is 