In [1]:
import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"

import discovery as ds
import json
import numpy as np
import jax.numpy as jnp
import glob
import matplotlib.pyplot as plt
import jax
jax.config.update('jax_enable_x64', True)
import argparse
from functools import partial
import json
import discovery.samplers.numpyro as ds_numpyro
from pathlib import Path
import pickle
import corner
import inspect
import typing

In [2]:
feather_dir = "./from_polina/"

In [3]:
psrs = [ds.Pulsar.read_feather(psrfile) for psrfile in sorted(glob.glob(feather_dir+'*.feather'))]

## Psr signal injection as follows:
- gwtheta=1.0471975511965979
- gwphi=3.141592653589793
- dist=69.75
- log10_mc=9
- fgw=2e-08
- inc=0
- rzn=0
- curn

### Truth values of the pulsar distances are contained in psr.pdist

The contention is that the hessian of the lnL should be greatly maximised where the psr distance search values are at the true values

### Set up signal model below:

In [None]:
# Custom function to use the p_dist rather than p_phase

def makefourier_binary_pdist(pulsarterm=True):
    def fourier_binary_pdist(f, df, mintoa, pos, log10_h0, log10_f0, ra, sindec, cosinc, psi, phi_earth, psr_dist):
        """BBH residuals from Ellis et. al 2012, 2013"""

        h0 = 10**log10_h0
        f0 = 10**log10_f0

        dec, inc = jnp.arcsin(sindec), jnp.arccos(cosinc)

        # calculate antenna pattern (note: pos is pulsar sky position unit vector)
        fplus, fcross = fpc_fast(pos, 0.5 * jnp.pi - dec, ra)  # careful with dec -> gwtheta conversion


        
        cosgwtheta, cosgwphi = np.cos(gwtheta), np.cos(gwphi)
        singwtheta, singwphi = np.sin(gwtheta), np.sin(gwphi)
        
        omhat = jnp.array([-singwtheta * cosgwphi, -singwtheta * singwphi, -cosgwtheta])
        
        cosMu = -jnp.dot(omhat, phat)
        
        if pulsarterm:
            phi_psr = psr_dist * (2 * np.pi * f0 * (1 - cosMu)) * eu.KPC2S
            phi_avg  = 0.5 * (phi_earth + phi_psr)
        else:
            phi_avg = phi_earth

        tref = 86400.0 * 51544.5  # MJD J2000 in seconds

        cphase = cos2comp(f, df, 1.0, f0, phi_avg - 2.0 * jnp.pi * f0 * tref, mintoa)
        sphase = cos2comp(f, df, 1.0, f0, phi_avg - 2.0 * jnp.pi * f0 * tref - 0.5*jnp.pi, mintoa)

        # fix this for no pulsarterm

        if pulsarterm:
            phi_diff = 0.5 * (phi_earth - phi_psr)
            sin_diff = jnp.sin(phi_diff)

            delta_sin =  2.0 * cphase * sin_diff
            delta_cos = -2.0 * sphase * sin_diff
        else:
            delta_sin = sphase
            delta_cos = cphase

        At = -1.0 * (1.0 + jnp.cos(inc)**2) * delta_sin
        Bt =  2.0 * jnp.cos(inc) * delta_cos

        alpha = h0 / (2 * jnp.pi * f0)

        # calculate rplus and rcross
        rplus  = alpha * (-At * jnp.cos(2 * psi) + Bt * jnp.sin(2 * psi))
        rcross = alpha * ( At * jnp.sin(2 * psi) + Bt * jnp.cos(2 * psi))

        # calculate residuals
        res = -fplus * rplus - fcross * rcross

        return res

    if not pulsarterm:
        fourier_binary = functools.partial(fourier_binary, phi_psr=jnp.nan)

    return fourier_binary

def cos2comp(f, df, A, f0, phi, t0):
    """Project signal A * cos(2pi f t + phi) onto Fourier basis
    cos(2pi k t/T), sin(2pi k t/T) for t in [t0, t0+T]."""

    T = 1.0 / df[0]

    Delta_omega = 2.0 * jnp.pi * (f0 - f[::2])
    Sigma_omega = 2.0 * jnp.pi * (f0 + f[::2])

    phase_Delta_start = phi + Delta_omega * t0
    phase_Delta_end   = phi + Delta_omega * (t0 + T)

    phase_Sigma_start = phi + Sigma_omega * t0
    phase_Sigma_end   = phi + Sigma_omega * (t0 + T)

    ck = (A / T) * (
        (jnp.sin(phase_Delta_end) - jnp.sin(phase_Delta_start)) / Delta_omega +
        (jnp.sin(phase_Sigma_end) - jnp.sin(phase_Sigma_start)) / Sigma_omega
    )

    sk = (A / T) * (
        (jnp.cos(phase_Delta_end) - jnp.cos(phase_Delta_start)) / Delta_omega -
        (jnp.cos(phase_Sigma_end) - jnp.cos(phase_Sigma_start)) / Sigma_omega
    )

    return jnp.stack((sk, ck), axis=1).reshape(-1)



def makefourier_binary(pulsarterm=True):
    def fourier_binary(f, df, mintoa, pos, log10_h0, log10_f0, ra, sindec, cosinc, psi, phi_earth, L_psr):
        h0 = 10**log10_h0
        f0 = 10**log10_f0

        dec, inc = jnp.arcsin(sindec), jnp.arccos(cosinc)

        # calculate antenna pattern
        fplus, fcross = fpc_fast(pos, 0.5 * jnp.pi - dec, ra)

        c = 2.99792458e8 
        omega_hat = jnp.array([ -jnp.cos(dec) * jnp.cos(ra), 
                                -jnp.cos(dec) * jnp.sin(ra),
                                -jnp.sin(dec)
                              ])

        phi_psr = phi_earth - 2.0 * jnp.pi * f0 * L_psr / c * (1.0 + jnp.dot(omega_hat, pos))

        if pulsarterm:
            phi_avg = 0.5 * (phi_earth + phi_psr)
        else:
            phi_avg = phi_earth

        tref = 86400.0 * 51544.5  # MJD J2000 in seconds

        cphase = cos2comp(f, df, 1.0, f0, phi_avg - 2.0 * jnp.pi * f0 * tref, mintoa)
        sphase = cos2comp(f, df, 1.0, f0, phi_avg - 2.0 * jnp.pi * f0 * tref - 0.5 * jnp.pi, mintoa)

        if pulsarterm:
            phi_diff = 0.5 * (phi_earth - phi_psr)
            sin_diff = jnp.sin(phi_diff)

            delta_sin =  2.0 * cphase * sin_diff
            delta_cos = -2.0 * sphase * sin_diff
        else:
            delta_sin = sphase
            delta_cos = cphase

        At = -1.0 * (1.0 + jnp.cos(inc)**2) * delta_sin
        Bt =  2.0 * jnp.cos(inc) * delta_cos

        alpha = h0 / (2 * jnp.pi * f0)

        rplus  = alpha * (-At * jnp.cos(2 * psi) + Bt * jnp.sin(2 * psi))
        rcross = alpha * ( At * jnp.sin(2 * psi) + Bt * jnp.cos(2 * psi))

        res = -fplus * rplus - fcross * rcross

        return res

    if not pulsarterm:
        fourier_binary = functools.partial(fourier_binary, L_psr=jnp.nan)

    return fourier_binary

In [31]:
# Injected with effectively no WN

noisedict = {}
for psr in psrs:
    noisedict[psr.name+"_KAT_MKBF_efac"] = 1 # EFAC=1 means no EFAC
    noisedict[psr.name+"_KAT_MKBF_log10_ecorr"] = -19 # Bringing these down to effectively 0
    noisedict[psr.name+"_KAT_MKBF_log10_t2equad"] = -19

In [56]:

# Fourier based CW delay
fourdelay = ds.makefourier_binary(pulsarterm=True)
cwcommon = ['cw_sindec', 'cw_cosinc', 'cw_log10_f0', 'cw_log10_h0', 'cw_phi_earth', 'cw_psi', 'cw_ra']

T = ds.getspan(psrs)
fml = ds.ArrayLikelihood([ds.PulsarLikelihood([psr.residuals,
                                               ds.makenoise_measurement(psr, noisedict=noisedict),
                                               ds.makegp_timing(psr, svd=True),
                                              ]) for psr in psrs],
                         commongp = ds.makecommongp_fourier(psrs, ds.makepowerlaw_crn(14), 30, T, means=fourdelay,
                                                            common=['crn_gamma', 'crn_log10_A'] + cwcommon, name='rednoise', meansname='cw'))

In [59]:
logl = fml.logL

In [60]:
logl

<function discovery.matrix.VectorWoodburyKernel_varP.make_kernelproduct.<locals>.kernelproduct(params)>

In [61]:
## Params the likelihood expects
logl.params

['B1855+09_cw_phi_psr',
 'B1855+09_rednoise_gamma',
 'B1855+09_rednoise_log10_A',
 'B1937+21_cw_phi_psr',
 'B1937+21_rednoise_gamma',
 'B1937+21_rednoise_log10_A',
 'B1953+29_cw_phi_psr',
 'B1953+29_rednoise_gamma',
 'B1953+29_rednoise_log10_A',
 'J0023+0923_cw_phi_psr',
 'J0023+0923_rednoise_gamma',
 'J0023+0923_rednoise_log10_A',
 'J0030+0451_cw_phi_psr',
 'J0030+0451_rednoise_gamma',
 'J0030+0451_rednoise_log10_A',
 'J0125-2327_cw_phi_psr',
 'J0125-2327_rednoise_gamma',
 'J0125-2327_rednoise_log10_A',
 'J0340+4130_cw_phi_psr',
 'J0340+4130_rednoise_gamma',
 'J0340+4130_rednoise_log10_A',
 'J0406+3039_cw_phi_psr',
 'J0406+3039_rednoise_gamma',
 'J0406+3039_rednoise_log10_A',
 'J0437-4715_cw_phi_psr',
 'J0437-4715_rednoise_gamma',
 'J0437-4715_rednoise_log10_A',
 'J0509+0856_cw_phi_psr',
 'J0509+0856_rednoise_gamma',
 'J0509+0856_rednoise_log10_A',
 'J0557+1551_cw_phi_psr',
 'J0557+1551_rednoise_gamma',
 'J0557+1551_rednoise_log10_A',
 'J0605+3757_cw_phi_psr',
 'J0605+3757_rednoise_ga

In [None]:
phi_psr = psr_dist * (2 * np.pi * f0 * (1 - cosMu)) * eu.KPC2S

In [64]:
np.log(2e-08)

np.float64(-17.72753356339242)

In [67]:
10**-17.72753356339242

1.8726923509527988e-18

In [69]:
10**-14

1e-14

np.float64(-17.72753356339242)