In [5]:
# XLA reads XLA_FLAGS when its backend starts, so this must run before jax is
# first imported. The flags expose the host CPU as 10 separate XLA devices and
# turn on XLA's CPU fast-math mode.
# NOTE(review): fast math may relax strict IEEE float semantics — confirm it
# does not affect the Fisher-matrix inversions later in the notebook.
import os
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=10 --xla_cpu_enable_fast_math=true'

import jax
# Sanity check that the device-count flag took effect (expect 10 CpuDevice entries).
print(jax.devices())

import warnings
# Ignore warnings whose message starts with "Wswiglal-redir-stdio"
# (presumably emitted by LALSuite's SWIG bindings, pulled in via pycbc — TODO confirm).
warnings.filterwarnings("ignore", "Wswiglal-redir-stdio")

# Plain NumPy is aliased `nnp` so that `jnp` unambiguously means jax.numpy.
import numpy as nnp
import jax.numpy as jnp
# Redundant re-import (jax is already imported above); harmless.
import jax 
from jax import grad, vmap
# Use 64-bit floats by default (the Fisher outputs below are dtype=float64);
# set here, before any JAX arrays are created in this notebook.
jax.config.update("jax_enable_x64", True)


import matplotlib.pyplot as plt

import scipy.interpolate as interp
import scipy.integrate as integ
import scipy.linalg as sla

# Project-local waveform / Fisher-matrix helpers (source not shown here).
import fisher_jim_tgr_2par as lib
import pycbc.conversions

import astropy.units as u
from astropy import constants as const

# One solar mass expressed as a time, G*M_sun/c^3 in SI seconds (~4.93e-6 s).
Ms = (u.Msun * const.G / const.c**3 ).si.value

[CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3), CpuDevice(id=4), CpuDevice(id=5), CpuDevice(id=6), CpuDevice(id=7), CpuDevice(id=8), CpuDevice(id=9)]


In [6]:
# Global setup: frequency grid, event catalogue, detector noise PSD, and a
# GW150914-like reference event.

# Log-spaced frequency grid shared by every waveform / Fisher evaluation.
n_freq = 2000
fmin = 10.
fmax = 1000.
freqs = jnp.logspace(jnp.log10(fmin), jnp.log10(fmax), num=n_freq)

# Column names of the catalogue file; only the first 9 are the waveform
# parameters that enter the Fisher matrix.
full_str = ["M_c", "eta", "d_L", "ra", "dec", "iota", "psi", "t_c", "phase_c", "zco", "rng", "Mc_source", "snr_H1", "snr_L1", "snr_V1", "snr_t"]
names = full_str[0:9]

# Input files. Alternatives used previously:
#   EVENTS_FILE = "data/events.txt"
#   EVENTS_FILE = "data/5-year-data-CE/5-year-data-CE-events.txt"
#   PSD_FILE    = "../curves/ce1.txt"
EVENTS_FILE = "data/5-year-data-O3/5-year-data-O3-events.txt"
PSD_FILE = "curves/o3_l1.txt"

data = jnp.array(nnp.loadtxt(EVENTS_FILE))
# NOTE(review): read_mag presumably returns an amplitude spectral density
# sampled on `freqs`; squaring it gives the PSD — confirm in the lib module.
psd = lib.read_mag(freqs, PSD_FILE)**2

# GW150914-like event: keep every column of the first catalogue row (sky
# position, angles, t_c, phase, ...) and overwrite masses and distance.
# Columns are looked up by name so this stays correct if `names` is reordered.
# Values are in the catalogue's units (presumably M_sun and Mpc — TODO confirm).
data_150914 = nnp.array(data[0])  # writable NumPy copy of the row
q  = 0.79                         # mass ratio
data_150914[names.index("M_c")] = 30
# Symmetric mass ratio eta = m1*m2/(m1+m2)^2 = q/(1+q)^2 (same for q or 1/q).
data_150914[names.index("eta")] = q/(1+q)**2
data_150914[names.index("d_L")] = 390


In [7]:
# JIT-wrap the per-detector waveform (get_h_*) and waveform-derivative
# (get_dh_*) functions, then call each one once on the first catalogue event.
# That way tracing/compilation for these argument shapes/dtypes happens here
# rather than on first use inside the Fisher-matrix calculation.
red_param = dict(zip(names, jnp.array(data[0]).T))

get_dh_H1, get_dh_L1, get_dh_V1 = map(jax.jit, (lib.get_dh_H1, lib.get_dh_L1, lib.get_dh_V1))
get_h_H1, get_h_L1, get_h_V1 = map(jax.jit, (lib.get_h_H1, lib.get_h_L1, lib.get_h_V1))

# Warm-up calls, in the same order as before; only the last result is kept in `a`.
for jitted in (get_dh_H1, get_dh_L1, get_dh_V1, get_h_H1, get_h_L1, get_h_V1):
    a = jitted(red_param, freqs)
del jitted
# dh_L1  = get_dh_L1(red_param, freqs)
# dh_V1  = get_dh_V1(red_param, freqs)

In [8]:
# ppE phase orders to include: each k adds one coefficient "phi_k" after the
# 9 waveform parameters.
kk = [-2, 3]
names_ppe = names + [f"phi_{k}" for k in kk]

# Row/column of each parameter in the Fisher matrix.
idx_par = {name: pos for pos, name in enumerate(names_ppe)}

# 1 means "differentiate with respect to the log of this parameter" (wanted
# for M_c and d_L, i.e. ln M_c and ln d_L); 0 means the parameter itself.
log_flag = dict.fromkeys(names_ppe, 0)
log_flag.update(M_c=1, d_L=1)
def calc_FI_main(idx):
    """Network (H1 + L1 + V1) Fisher matrix for one event, including ppE terms.

    The parameters are the waveform parameters in ``names`` followed by one
    ppE phase coefficient ``phi_k`` for every ``k`` in ``kk``. The index map
    ``idx_par``, the ``psd`` and the ``log_flag`` settings are passed through
    unchanged to ``lib.fish``.

    Parameters
    ----------
    idx : str or int
        ``'150914'`` selects the GW150914-like parameters in ``data_150914``;
        an integer (anything accepted by ``operator.index``) selects that row
        of the catalogue array ``data``.

    Returns
    -------
    The sum of the per-detector ``lib.fish`` results, accumulated in the
    order H1, L1, V1.

    Raises
    ------
    ValueError
        If ``idx`` is neither ``'150914'`` nor an integer row index.
    """
    import operator

    if idx == '150914':
        dat = jnp.array(data_150914)
    else:
        # Any other idx used to leave `dat` unbound (UnboundLocalError);
        # integer indices now pick a catalogue row instead.
        try:
            row = operator.index(idx)
        except TypeError:
            raise ValueError(
                f"unknown event {idx!r}: expected '150914' or an integer row index into data"
            ) from None
        dat = jnp.array(data[row])

    # zip stops after len(names) entries, so the extra catalogue columns
    # (presumably zco, rng, Mc_source, snr_* per full_str) are dropped.
    red_param = dict(zip(names, dat))

    # get_dpsi_ppe takes no detector argument, so it is evaluated once per k
    # and the result is reused for every detector.
    dpsi_ppe = {k: lib.get_dpsi_ppe(freqs, red_param, k) for k in kk}

    detectors = (
        (get_dh_H1, get_h_H1),
        (get_dh_L1, get_h_L1),
        (get_dh_V1, get_h_V1),
    )

    fi = None
    for get_dh, get_h in detectors:
        # Copy so that adding the phi_k entries never mutates the returned dict.
        dh = dict(get_dh(red_param, freqs))
        h = get_h(red_param, freqs)
        for k in kk:
            # NOTE(review): assumes the ppE term enters as h * exp(1j*phi_k*dpsi_k),
            # so dh/dphi_k evaluated at phi_k = 0 is 1j * dpsi_k * h.
            dh[f"phi_{k}"] = 1j * dpsi_ppe[k] * h
        fi_det = lib.fish(freqs, dh, red_param, idx_par, psd, log_flag)
        fi = fi_det if fi is None else fi + fi_det
    return fi

# Network Fisher matrix for the GW150914-like event. The diagonal of its
# inverse holds the marginalised variances, in idx_par order (for M_c and d_L
# presumably of ln M_c and ln d_L, per log_flag — TODO confirm in lib.fish).
fi = calc_FI_main('150914')
jnp.diagonal(sla.inv(fi))

power error defn
power error defn


Array([5.04986065e-03, 2.68149579e-04, 3.25700841e-01,
       2.79666069e-04, 2.63782366e-04, 7.25581574e-01,
       1.05245587e+00, 1.13504149e-07, 4.67058320e+00,
       3.85677333e-05, 1.88171057e-04], dtype=float64)

In [None]:
# def calc_FI(idx, k):
#     red_param = dict(zip(names, jnp.array(data[idx]).T))
    
#     fi_H1, fi_L1, fi_V1 = lib.get_FI_ppe(freqs, red_param, idx_par, psd, log_flag, k)

#     fi = fi_H1 + fi_L1 + fi_V1
#     return fi

In [76]:
# kk = [0,3]
# fun = lambda idx : calc_FI(idx, k)
# fun_jit = jax.jit(fun)

In [86]:
# red_param = dict(zip(names, jnp.array(data[1]).T))


Array([2.34987845e-04, 3.24705141e-05, 3.18730617e-01,
       2.79666069e-04, 2.63782366e-04, 7.25581574e-01,
       1.05245587e+00, 1.08872664e-07, 1.24659970e+00,
       1.96662835e-04], dtype=float64)

Array([2.44653615e-01, 1.20713558e-02, 3.32785966e+00,
       2.29410149e-03, 2.31567156e-02, 8.74057987e-01,
       8.74938930e-01, 5.15006984e-06, 1.28418445e+01,
       5.17096541e-01, 1.15039510e-01], dtype=float64)