In [1]:
import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"

import discovery as ds
import json
import numpy as np
import jax.numpy as jnp
import glob
import matplotlib.pyplot as plt
import jax
jax.config.update('jax_enable_x64', True)
import argparse
from functools import partial
import json
import discovery.samplers.numpyro as ds_numpyro
from pathlib import Path
import pickle
import corner
import inspect
import typing

In [2]:
feather_dir = "./from_polina/"

In [83]:
psrs = [ds.Pulsar.read_feather(psrfile) for psrfile in sorted(glob.glob(feather_dir+'*.feather'))][:10] #For now just grabbing 10 psrs

## Psr signal injection as follows:
- gwtheta=1.0471975511965979
- gwphi=3.141592653589793
- dist=69.75
- log10_mc=9
- fgw=2e-08
- inc=0
- rzn=0
- curn

### Truth values of the pulsar distances are contained in psr.pdist

The contention is that the hessian of the lnL should be greatly maximised where the psr distance search values are at the true values

### Set up signal model below:

In [157]:
# Custom function to use the p_dist rather than p_phase


def fpc_fast(pos, gwtheta, gwphi):
    x, y, z = pos

    sin_phi = jnp.sin(gwphi)
    cos_phi = jnp.cos(gwphi)
    sin_theta = jnp.sin(gwtheta)
    cos_theta = jnp.cos(gwtheta)

    m_dot_pos = sin_phi * x - cos_phi * y
    n_dot_pos = -cos_theta * cos_phi * x - cos_theta * sin_phi * y + sin_theta * z
    omhat_dot_pos = -sin_theta * cos_phi * x - sin_theta * sin_phi * y - cos_theta * z

    denom = 1.0 + omhat_dot_pos

    fplus = 0.5 * (m_dot_pos**2 - n_dot_pos**2) / denom
    fcross = (m_dot_pos * n_dot_pos) / denom

    return fplus, fcross
    
def cos2comp(f, df, A, f0, phi, t0):
    """Project signal A * cos(2pi f t + phi) onto Fourier basis
    cos(2pi k t/T), sin(2pi k t/T) for t in [t0, t0+T]."""

    T = 1.0 / df[0]

    Delta_omega = 2.0 * jnp.pi * (f0 - f[::2])
    Sigma_omega = 2.0 * jnp.pi * (f0 + f[::2])

    phase_Delta_start = phi + Delta_omega * t0
    phase_Delta_end   = phi + Delta_omega * (t0 + T)

    phase_Sigma_start = phi + Sigma_omega * t0
    phase_Sigma_end   = phi + Sigma_omega * (t0 + T)

    ck = (A / T) * (
        (jnp.sin(phase_Delta_end) - jnp.sin(phase_Delta_start)) / Delta_omega +
        (jnp.sin(phase_Sigma_end) - jnp.sin(phase_Sigma_start)) / Sigma_omega
    )

    sk = (A / T) * (
        (jnp.cos(phase_Delta_end) - jnp.cos(phase_Delta_start)) / Delta_omega -
        (jnp.cos(phase_Sigma_end) - jnp.cos(phase_Sigma_start)) / Sigma_omega
    )

    return jnp.stack((sk, ck), axis=1).reshape(-1)



def makefourier_binary_pdist(pulsarterm=True):
    def fourier_binary_pdist(f, df, mintoa, pos, log10_h0, log10_f0, ra, sindec, cosinc, psi, phi_earth, p_dist):
        h0 = 10**log10_h0
        f0 = 10**log10_f0

        pos = jnp.array(pos)
        
        dec, inc = jnp.arcsin(sindec), jnp.arccos(cosinc)

        # calculate antenna pattern
        fplus, fcross = fpc_fast(pos, 0.5 * jnp.pi - dec, ra)

        c = 2.99792458e8 
        omega_hat = jnp.array([ -jnp.cos(dec) * jnp.cos(ra), 
                                -jnp.cos(dec) * jnp.sin(ra),
                                -jnp.sin(dec)
                              ])

        phi_psr = phi_earth - 2.0 * jnp.pi * f0 * p_dist / c * (1.0 + jnp.dot(omega_hat, pos))

        if pulsarterm:
            phi_avg = 0.5 * (phi_earth + phi_psr)
        else:
            phi_avg = phi_earth

        tref = 86400.0 * 51544.5  # MJD J2000 in seconds

        cphase = cos2comp(f, df, 1.0, f0, phi_avg - 2.0 * jnp.pi * f0 * tref, mintoa)
        sphase = cos2comp(f, df, 1.0, f0, phi_avg - 2.0 * jnp.pi * f0 * tref - 0.5 * jnp.pi, mintoa)

        if pulsarterm:
            phi_diff = 0.5 * (phi_earth - phi_psr)
            sin_diff = jnp.sin(phi_diff)

            delta_sin =  2.0 * cphase * sin_diff
            delta_cos = -2.0 * sphase * sin_diff
        else:
            delta_sin = sphase
            delta_cos = cphase

        At = -1.0 * (1.0 + jnp.cos(inc)**2) * delta_sin
        Bt =  2.0 * jnp.cos(inc) * delta_cos

        alpha = h0 / (2 * jnp.pi * f0)

        rplus  = alpha * (-At * jnp.cos(2 * psi) + Bt * jnp.sin(2 * psi))
        rcross = alpha * ( At * jnp.sin(2 * psi) + Bt * jnp.cos(2 * psi))

        res = -fplus * rplus - fcross * rcross

        return res

    if not pulsarterm:
        fourier_binary_pdist = functools.partial(fourier_binary, p_dist=jnp.nan)

    return fourier_binary_pdist

In [158]:
# Injected with effectively no WN

noisedict = {}
for psr in psrs:
    noisedict[psr.name+"_KAT_MKBF_efac"] = 1 # EFAC=1 means no EFAC
    noisedict[psr.name+"_KAT_MKBF_log10_ecorr"] = -19 # Bringing these down to effectively 0
    noisedict[psr.name+"_KAT_MKBF_log10_t2equad"] = -19

In [159]:

# Fourier based CW delay
fourdelay = makefourier_binary_pdist(pulsarterm=True)
cwcommon = ['cw_sindec', 'cw_cosinc', 'cw_log10_f0', 'cw_log10_h0', 'cw_phi_earth', 'cw_psi', 'cw_ra']

T = ds.getspan(psrs)
fml = ds.ArrayLikelihood([ds.PulsarLikelihood([psr.residuals,
                                               ds.makenoise_measurement(psr, noisedict=noisedict),
                                               ds.makegp_timing(psr, svd=True),
                                              ]) for psr in psrs],
                         commongp = ds.makecommongp_fourier(psrs, ds.makepowerlaw_crn(14), 30, T, means=fourdelay,
                                                            common=['crn_gamma', 'crn_log10_A'] + cwcommon, name='rednoise', meansname='cw'))

In [262]:
logl = fml.logL

In [220]:
logl

<function discovery.matrix.VectorWoodburyKernel_varP.make_kernelproduct.<locals>.kernelproduct(params)>

In [162]:
psr.pdist

[1.45, 0.29]

In [163]:
## Params the likelihood expects
logl.params

['B1855+09_cw_p_dist',
 'B1855+09_rednoise_gamma',
 'B1855+09_rednoise_log10_A',
 'B1937+21_cw_p_dist',
 'B1937+21_rednoise_gamma',
 'B1937+21_rednoise_log10_A',
 'B1953+29_cw_p_dist',
 'B1953+29_rednoise_gamma',
 'B1953+29_rednoise_log10_A',
 'J0023+0923_cw_p_dist',
 'J0023+0923_rednoise_gamma',
 'J0023+0923_rednoise_log10_A',
 'J0030+0451_cw_p_dist',
 'J0030+0451_rednoise_gamma',
 'J0030+0451_rednoise_log10_A',
 'J0125-2327_cw_p_dist',
 'J0125-2327_rednoise_gamma',
 'J0125-2327_rednoise_log10_A',
 'J0340+4130_cw_p_dist',
 'J0340+4130_rednoise_gamma',
 'J0340+4130_rednoise_log10_A',
 'J0406+3039_cw_p_dist',
 'J0406+3039_rednoise_gamma',
 'J0406+3039_rednoise_log10_A',
 'J0437-4715_cw_p_dist',
 'J0437-4715_rednoise_gamma',
 'J0437-4715_rednoise_log10_A',
 'J0509+0856_cw_p_dist',
 'J0509+0856_rednoise_gamma',
 'J0509+0856_rednoise_log10_A',
 'crn_gamma',
 'crn_log10_A',
 'cw_cosinc',
 'cw_log10_f0',
 'cw_log10_h0',
 'cw_phi_earth',
 'cw_psi',
 'cw_ra',
 'cw_sindec']

In [164]:
truths = {}
for psr in psrs:
    truths[psr.name+"_cw_p_dist"] = psr.pdist[0]

In [165]:
# Ignores uncertainty but the CW would be injected given the pdist value
truths

{'B1855+09_cw_p_dist': 1.18,
 'B1937+21_cw_p_dist': 3.1,
 'B1953+29_cw_p_dist': 4.64,
 'J0023+0923_cw_p_dist': 1.02,
 'J0030+0451_cw_p_dist': 0.3296,
 'J0125-2327_cw_p_dist': 0.873,
 'J0340+4130_cw_p_dist': 1.71,
 'J0406+3039_cw_p_dist': 1.72,
 'J0437-4715_cw_p_dist': 0.1549,
 'J0509+0856_cw_p_dist': 1.45}

In [166]:
for param in logl.params:
    if "rednoise_log10_A" in param:
        truths[param] = -20
    if "crn_log10_A" in param:
        truths[param] = -14.619788758288394
    if "rednoise_gamma" in param:
        truths[param] = 3
    if "crn_gamma" in param:
        truths[param] = 13/3
    if "cw_cosinc" in param:
        truths[param] = 1
    if "cw_log10_f0" in param:
        truths[param] = -7.698970004336019



In [167]:
touse = truths.copy()



gwtheta=1.0471975511965979
gwphi=3.141592653589793
dist=69.75
log10_mc=9
fgw=2e-08
inc=0
rzn=0
curn


In [168]:
ds.priordict_standard

{'(.*_)?efac': [0.9, 1.1],
 '(.*_)?t2equad': [-8.5, -5],
 '(.*_)?tnequad': [-8.5, -5],
 '(.*_)?log10_ecorr': [-8.5, -5],
 '(.*_)?rednoise_log10_A.*': [-20, -11],
 '(.*_)?rednoise_gamma.*': [0, 7],
 '(.*_)?rednoise_log10_fb': [-9, -6],
 '(.*_)?red_noise_log10_A.*': [-20, -11],
 '(.*_)?red_noise_gamma.*': [0, 7],
 '(.*_)?red_noise_log10_fb': [-9, -6],
 'crn_log10_A.*': [-18, -11],
 'crn_gamma.*': [0, 7],
 'crn_log10_fb': [-9, -6],
 'gw_(.*_)?log10_A': [-18, -11],
 'gw_(.*_)?gamma': [0, 7],
 'gw_log10_fb': [-9, -6],
 '(.*_)?dmgp_log10_A': [-20, -11],
 '(.*_)?dmgp_gamma': [0, 7],
 '(.*_)?dmgp_alpha': [1, 3],
 '(.*_)?chromgp_log10_A': [-20, -11],
 '(.*_)?chromgp_gamma': [0, 7],
 '(.*_)?chromgp_alpha': [1, 7],
 '(.*_)?dm_gp_log10_A': [-20, -11],
 '(.*_)?dm_gp_gamma': [0, 7],
 '(.*_)?dm_gp_alpha': [1, 3],
 '(.*_)?chrom_gp_log10_A': [-20, -11],
 '(.*_)?chrom_gp_gamma': [0, 7],
 '(.*_)?chrom_gp_alpha': [1, 3],
 'crn_log10_rho': [-9, -4],
 'gw_(.*_)?log10_rho': [-9, -4],
 '(.*_)?red_noise_log10_

In [278]:
# I think it's better not to have all the parameters correct, so these are incorrect. 
# I want to make sure that Hessian changes because of the pdists and not anything else

touse.update({'cw_log10_h0': np.random.uniform(-18,-11), 'cw_phi_earth': np.pi, 'cw_psi': np.pi/2, 'cw_ra': np.pi/2, 'cw_sindec': 0, 'cw_cosinc': 0}) 

In [279]:
order_map = {key: i for i, key in enumerate(logl.params)}

In [280]:
sorted_items = sorted(touse.items(), key=lambda item: order_map.get(item[0], float('inf')))

In [281]:
sorted_dict = dict(sorted_items)

In [282]:
sorted_dict

{'B1855+09_cw_p_dist': 1.18,
 'B1855+09_rednoise_gamma': 3,
 'B1855+09_rednoise_log10_A': -20,
 'B1937+21_cw_p_dist': 3.1,
 'B1937+21_rednoise_gamma': 3,
 'B1937+21_rednoise_log10_A': -20,
 'B1953+29_cw_p_dist': 4.64,
 'B1953+29_rednoise_gamma': 3,
 'B1953+29_rednoise_log10_A': -20,
 'J0023+0923_cw_p_dist': 1.02,
 'J0023+0923_rednoise_gamma': 3,
 'J0023+0923_rednoise_log10_A': -20,
 'J0030+0451_cw_p_dist': 0.3296,
 'J0030+0451_rednoise_gamma': 3,
 'J0030+0451_rednoise_log10_A': -20,
 'J0125-2327_cw_p_dist': 0.873,
 'J0125-2327_rednoise_gamma': 3,
 'J0125-2327_rednoise_log10_A': -20,
 'J0340+4130_cw_p_dist': 1.71,
 'J0340+4130_rednoise_gamma': 3,
 'J0340+4130_rednoise_log10_A': -20,
 'J0406+3039_cw_p_dist': 1.72,
 'J0406+3039_rednoise_gamma': 3,
 'J0406+3039_rednoise_log10_A': -20,
 'J0437-4715_cw_p_dist': 0.1549,
 'J0437-4715_rednoise_gamma': 3,
 'J0437-4715_rednoise_log10_A': -20,
 'J0509+0856_cw_p_dist': 1.45,
 'J0509+0856_rednoise_gamma': 3,
 'J0509+0856_rednoise_log10_A': -20,
 'cr

In [283]:
sorted_dict.values()

dict_values([1.18, 3, -20, 3.1, 3, -20, 4.64, 3, -20, 1.02, 3, -20, 0.3296, 3, -20, 0.873, 3, -20, 1.71, 3, -20, 1.72, 3, -20, 0.1549, 3, -20, 1.45, 3, -20, 4.333333333333333, -14.619788758288394, 0, -7.698970004336019, -16.754489168663948, 3.141592653589793, 1.5707963267948966, 1.5707963267948966, 0])

In [284]:
# Get the keys in the same order as values
param_keys = list(sorted_dict.keys())

# Define the wrapper
def logl_wrapped(x_array):
    # Convert array -> dict
    params = {k: v for k, v in zip(param_keys, x_array)}
    return logl(params)

In [285]:
#x0 = np.array(list(sorted_dict.values()), dtype=np.float64)
x0 = jnp.array(list(sorted_dict.values()), dtype=jnp.float64)
grad_fn = jax.jit(jax.grad(logl_wrapped))
hess_fn = jax.jit(jax.hessian(logl_wrapped))

In [286]:
print("logl(x0) =", logl_wrapped(x0))

logl(x0) = 33490.17741467983


In [287]:
G = grad_fn(x0)

In [292]:
G

Array([ 3.09410931e-20, -1.14745171e-12, -4.08098997e-12,  4.48142884e-19,
        1.37630359e-09,  6.51560921e-09,  9.27131882e-21,  3.91286036e-10,
        1.36636387e-09,  3.52386755e-18,  3.09707167e-11,  2.95969666e-10,
        1.93079380e-18,  3.76848153e-13, -1.26766725e-11,  4.10447078e-19,
       -1.81050993e-12, -7.57020620e-12,  1.71978744e-19, -1.77868402e-13,
        1.15671611e-12,  1.09085054e-19,  1.05452060e-12,  1.52160700e-11,
       -3.89507306e-18,  9.36271529e-12,  6.19457573e-11,  3.75850032e-21,
       -7.91457264e-13, -3.36172795e-12,  7.64959095e+02,  2.38412649e+03,
        1.56061219e-17,  3.10904923e-16,  1.52770685e-17,  1.73472348e-18,
       -5.22041306e-17, -6.18724367e-17, -6.19361461e-17], dtype=float64)

In [288]:
H = hess_fn(x0)

In [293]:
H.shape

(39, 39)

In [290]:
H

Array([[-8.40036550e-37, -4.04215397e-33, -2.89344666e-32, ...,
        -1.79292460e-19, -5.11793510e-19, -3.64776212e-19],
       [-4.04215397e-33, -1.38243262e-12, -5.28358029e-12, ...,
         2.53560210e-32,  7.23318566e-32,  5.16213834e-32],
       [-2.89344666e-32, -5.28351769e-12, -1.87594415e-11, ...,
         1.81502968e-31,  5.17764465e-31,  3.69515165e-31],
       ...,
       [-1.79292460e-19,  2.53560210e-32,  1.81502968e-31, ...,
        -2.65389862e-17,  1.21683851e-16, -1.22698645e-16],
       [-5.11793510e-19,  7.23318566e-32,  5.17764465e-31, ...,
         1.21683851e-16,  9.42961725e-17, -2.32086263e-17],
       [-3.64776212e-19,  5.16213834e-32,  3.69515165e-31, ...,
        -1.22698645e-16, -2.32086263e-17, -1.07711991e-16]],      dtype=float64)

In [242]:
val = logl_wrapped(x0)
print(val)

33490.17741467983


In [243]:
jax.debug.print("params: {}", x0)

params: [  1.18         3.         -20.           3.1          3.
 -20.           4.64         3.         -20.           1.02
   3.         -20.           0.3296       3.         -20.
   0.873        3.         -20.           1.71         3.
 -20.           1.72         3.         -20.           0.1549
   3.         -20.           1.45         3.         -20.
   4.33333333 -14.61978876   1.          -7.69897    -16.2529779
   3.14159265   1.57079633   1.57079633   0.        ]


In [147]:
touse

{'B1855+09_cw_p_dist': 1.18,
 'B1937+21_cw_p_dist': 3.1,
 'B1953+29_cw_p_dist': 4.64,
 'J0023+0923_cw_p_dist': 1.02,
 'J0030+0451_cw_p_dist': 0.3296,
 'J0125-2327_cw_p_dist': 0.873,
 'J0340+4130_cw_p_dist': 1.71,
 'J0406+3039_cw_p_dist': 1.72,
 'J0437-4715_cw_p_dist': 0.1549,
 'J0509+0856_cw_p_dist': 1.45,
 'B1855+09_rednoise_gamma': 3,
 'B1855+09_rednoise_log10_A': -20,
 'B1937+21_rednoise_gamma': 3,
 'B1937+21_rednoise_log10_A': -20,
 'B1953+29_rednoise_gamma': 3,
 'B1953+29_rednoise_log10_A': -20,
 'J0023+0923_rednoise_gamma': 3,
 'J0023+0923_rednoise_log10_A': -20,
 'J0030+0451_rednoise_gamma': 3,
 'J0030+0451_rednoise_log10_A': -20,
 'J0125-2327_rednoise_gamma': 3,
 'J0125-2327_rednoise_log10_A': -20,
 'J0340+4130_rednoise_gamma': 3,
 'J0340+4130_rednoise_log10_A': -20,
 'J0406+3039_rednoise_gamma': 3,
 'J0406+3039_rednoise_log10_A': -20,
 'J0437-4715_rednoise_gamma': 3,
 'J0437-4715_rednoise_log10_A': -20,
 'J0509+0856_rednoise_gamma': 3,
 'J0509+0856_rednoise_log10_A': -20,
 'cr

In [148]:
sindec = 0
dec = jnp.arcsin(sindec)
ra = 1.5707963267948966
omega_hat = jnp.array([ -jnp.cos(dec) * jnp.cos(ra), 
                        -jnp.cos(dec) * jnp.sin(ra),
                        -jnp.sin(dec)
                      ])

In [156]:
type(psrs[0].pos)

list

In [245]:
## Minimal test


def logl(params):
    # deliberately simple example
    return -0.5 * jnp.sum(params**2)

def logl_wrapped(x):
    keys = ["a", "b"]
    params = {k: v for k, v in zip(keys, x)}
    return logl(params["a"] + params["b"])  # or whatever you have

x0 = jnp.array([0.1, 0.2], dtype=jnp.float64)

print("logl:", logl_wrapped(x0))
print("grad:", jax.grad(logl_wrapped)(x0))
print("hess:", jax.hessian(logl_wrapped)(x0))

logl: -0.04500000000000001
grad: [-0.3 -0.3]
hess: [[-1. -1.]
 [-1. -1.]]
