In [42]:
import numpy as np
import pickle
import json
from enterprise.pulsar import Tempo2Pulsar, Pulsar
from enterprise.signals import selections, parameter, white_signals, gp_priors, gp_signals, signal_base, utils
from enterprise_extensions import deterministic
import libstempo as T2
import scipy.linalg as sl
import scipy.constants as sc
import scipy.sparse as sps
from sksparse.cholmod import cholesky, CholmodError

## Relevant functions

In [30]:
# Pulsar-term phase for a CW source, adapted from the QuickCW code.

def get_pphase(psr, gwtheta, gwphi, mc, fgw):
    """Return the pulsar-term phase of a CW source for one pulsar, wrapped mod 2*pi.

    Parameters
    ----------
    psr : enterprise Pulsar
        Only ``psr.theta``, ``psr.phi`` and ``psr.pdist[0]`` are read. ``pdist[0]``
        is multiplied by a kpc->seconds factor, so it is presumably the distance in kpc.
    gwtheta, gwphi : float
        Polar and azimuthal sky angles of the GW source, in radians.
    mc : float
        Chirp mass in solar masses (converted to seconds with ``SOLAR2S``).
    fgw : float
        GW frequency (assumed Hz, since it is combined with quantities in seconds).

    Returns
    -------
    float
        ``(-pphase) % (2*pi)``, where ``pphase`` is the phase accumulated over the
        extra light-travel path to the pulsar.
    """
    KPC2S = sc.parsec / sc.c * 1e3           # light-travel time across 1 kpc [s]
    SOLAR2S = sc.G / sc.c ** 3 * 1.98855e30  # G * M_sun / c^3 [s]

    # unit vector along the GW propagation direction (away from the source)
    sin_th, cos_th = np.sin(gwtheta), np.cos(gwtheta)
    sin_ph, cos_ph = np.sin(gwphi), np.cos(gwphi)
    propagation = np.array([-sin_th * cos_ph, -sin_th * sin_ph, -cos_th])

    # unit vector from Earth toward the pulsar
    psr_dir = np.array([
        np.sin(psr.theta) * np.cos(psr.phi),
        np.sin(psr.theta) * np.sin(psr.phi),
        np.cos(psr.theta),
    ])

    # cosine of the angular separation between GW source and pulsar
    cos_mu = -np.dot(propagation, psr_dir)

    # same operation order as the QuickCW expression, via named factors
    mc_sec = mc * SOLAR2S
    omega = np.pi * fgw
    growth = 256/5 * mc_sec**(5/3) * omega**(8/3) * psr.pdist[0] * KPC2S * (1 - cos_mu)
    phase = ((1 + growth) ** (5/8) - 1) / (32 * mc_sec**(5/3) * omega**(5/3))

    return -phase % (2 * np.pi)

In [31]:
# Helpers adapted from enterprise's signal_base.py: the residual vector is replaced
# by the deterministic CW delay so the chi-squared machinery yields an S/N.

def get_TNr_PTA(pta, params):
    """Return a list of ``T^T N^-1 s`` (s = CW delay), one entry per signal collection."""
    per_collection = []
    for collection in pta._signalcollections:
        per_collection.append(get_TNr_signal(collection, params))
    return per_collection

def get_TNr_signal(signal, params):
    """Return ``T^T N^-1 s`` for a single signal collection, or None if it has no basis.

    Unlike the enterprise original, the right-hand vector ``s`` is the CW delay from
    ``signal.get_delay(params)``, not the delay-subtracted residuals: the S/N needs
    the signal template itself.
    """
    basis = signal.get_basis(params)
    if basis is not None:
        noise = signal.get_ndiag(params)
        cw_delay = signal.get_delay(params)
        return noise.solve(cw_delay, left_array=basis)
    return None

def get_rNr_logdet_PTA(pta, params):
    """Return ``get_rNr_logdet_signal`` evaluated on every signal collection of ``pta``."""
    return list(map(lambda collection: get_rNr_logdet_signal(collection, params),
                    pta._signalcollections))

def get_rNr_logdet_signal(signal, params):
    """Return ``Nvec.solve(s, left_array=s, logdet=True)`` with s the CW delay.

    The CW delay replaces the residuals used by the enterprise original. The return
    value is presumably an ``(s^T N^-1 s, logdet N)`` pair; get_SNR reads only
    element ``[0]``.
    """
    noise = signal.get_ndiag(params)
    cw_delay = signal.get_delay(params)
    return noise.solve(cw_delay, left_array=cw_delay, logdet=True)

def _block_TNT(TNTs, cholesky_sparse=True):
    if cholesky_sparse:
        return sps.block_diag(TNTs, "csc")
    else:
        return sl.block_diag(*TNTs)
    
def _block_TNr(TNrs):
    return np.concatenate(TNrs)

In [37]:
# main S/N function modified from chi_squared.py in enterprise_extensions
# takes in a PTA object and dictionary of injected parameters

def get_SNR(pta, xs, cholesky_sparse=True):
    """Compute the optimal S/N of the CW signal in ``pta`` at parameters ``xs``.

    With s the deterministic CW delay, N the white-noise covariance, T the GP basis
    and B (phi) its prior covariance:

        rho^2 = s^T (N + T B T^T)^-1 s
              = s^T N^-1 s - s^T N^-1 T (T^T N^-1 T + B^-1)^-1 T^T N^-1 s

    This is the generalized chi-squared of enterprise_extensions' chi_squared.py
    with the residuals replaced by s; the function returns rho = sqrt(rho^2).

    Parameters
    ----------
    pta : enterprise PTA
    xs : dict or array-like
        Parameter values. A dict is used as-is; anything else goes through
        ``pta.map_params``.
    cholesky_sparse : bool, default True
        Only used when the PTA has common (correlated) signals. If True, the joint
        Sigma matrix is built sparse and factorized with sksparse's ``cholesky``;
        otherwise it is built dense and factorized with ``scipy.linalg.cho_factor``.

    Returns
    -------
    float
        The S/N rho. NOTE: if a Cholesky factorization fails this returns ``-np.inf``
        (kept from the likelihood code it was adapted from), so treat non-finite
        results as failures. If round-off makes rho^2 negative the result is nan.
    """
    params = xs if isinstance(xs, dict) else pta.map_params(xs)

    TNrs = get_TNr_PTA(pta, params)  # uses modified function above
    TNTs = pta.get_TNT(params)
    phiinvs = pta.get_phiinv(params, logdet=True, method='cliques')

    # s^T N^-1 s summed over signal collections. Builtin sum: np.sum over a
    # generator is deprecated in NumPy.
    rho2 = sum(ell[0] for ell in get_rNr_logdet_PTA(pta, params))  # uses modified function above

    if pta._commonsignals:
        phiinv, _ = phiinvs
        # Build TNT in the format matching the factorization below; the dense
        # cho_factor branch must not receive a scipy.sparse matrix.
        TNT = _block_TNT(TNTs, cholesky_sparse=cholesky_sparse)
        TNr = _block_TNr(TNrs)

        if cholesky_sparse:
            try:
                cf = cholesky(TNT + sps.csc_matrix(phiinv))  # cf(Sigma)
                expval = cf(TNr)
            except CholmodError:  # pragma: no cover
                return -np.inf
        else:
            try:
                cf = sl.cho_factor(TNT + phiinv)  # cf(Sigma)
                expval = sl.cho_solve(cf, TNr)
            except sl.LinAlgError:
                return -np.inf
        rho2 = rho2 - np.dot(TNr, expval)

    else:
        # no correlated signals: Sigma is block diagonal, so solve pulsar by pulsar
        for TNr, TNT, pl in zip(TNrs, TNTs, phiinvs):
            if TNr is None:  # collection without a GP basis: nothing to subtract
                continue

            phiinv, _ = pl
            # a 1-D phiinv holds the diagonal of the prior precision
            Sigma = TNT + (np.diag(phiinv) if phiinv.ndim == 1 else phiinv)

            try:
                cf = sl.cho_factor(Sigma)
                expval = sl.cho_solve(cf, TNr)
            except sl.LinAlgError:  # pragma: no cover
                return -np.inf

            rho2 = rho2 - np.dot(TNr, expval)

    return np.sqrt(rho2)

## PTA configuration

In [None]:
# Input locations -- replace the placeholders with real paths.
pklfile = '/path/to/pickle'
noisefile = '/path/to/noisefile'

# Pickled enterprise Pulsar objects.
# NOTE: pickle.load can execute arbitrary code -- only load pickles you trust.
with open(pklfile, 'rb') as pkl_handle:
    psrs = pickle.load(pkl_handle)

# JSON noise dictionary; presumably maps noise-parameter names to the values that
# pta.set_default_params assigns to the Constant parameters later on.
with open(noisefile, 'rb') as noise_handle:
    noisedict = json.load(noise_handle)

In [39]:
# Model-building knobs -- modify as needed
N_FREQS = 30   # number of Fourier frequencies for every GP (intrinsic red noise and GWB)
USE_HD = True  # True: Hellings-Downs-correlated GWB; False: uncorrelated common red noise (CURN)

# timing model
tm = gp_signals.TimingModel()

# white noise parameters -- modify as needed (EFAC fixed to 1, no backend selection)
efac = parameter.Constant(1)
selection = selections.Selection(selections.no_selection)
ef = white_signals.MeasurementNoise(efac=efac, selection=selection)

# intrinsic red noise parameters -- modify as needed.
# Constant() without a value: expected to be filled from noisedict by
# pta.set_default_params below. Same powerlaw helper as the GWB for consistency.
irn_log10_A = parameter.Constant()
irn_gamma = parameter.Constant()
irn_pl = gp_priors.powerlaw(log10_A=irn_log10_A, gamma=irn_gamma)
rn = gp_signals.FourierBasisGP(irn_pl, components=N_FREQS)

# GWB parameters -- modify as needed (currently set to the NG15 measurement)
gwb_logA = parameter.Constant(np.log10(2.4e-15))
gwb_gamma = parameter.Constant(13/3)
gwb_pl = gp_priors.powerlaw(log10_A=gwb_logA, gamma=gwb_gamma)
curn = gp_signals.FourierBasisGP(spectrum=gwb_pl, components=N_FREQS, name='gwb')
hd = gp_signals.FourierBasisCommonGP(spectrum=gwb_pl, orf=utils.hd_orf(), components=N_FREQS, name='gwb')

# continuous wave signal
# The priors below do not affect the S/N: get_SNR evaluates the signal at the explicit
# values in injDict, so only the parameter *names* matter (injDict is keyed by them).
# p_dist / p_phase are unnamed, per-pulsar parameters.
tmin = np.min([p.toas.min() for p in psrs])  # earliest TOA in the array, used as tref
cw_wf = deterministic.cw_delay(cos_gwtheta=parameter.Uniform(-1, 1)('cos_gwtheta'),
                               gwphi=parameter.Uniform(0, 2*np.pi)('gwphi'),
                               cos_inc=parameter.Uniform(-1, 1)('cos_inc'),
                               log10_mc=parameter.Uniform(7, 10)('log10_mc'),
                               log10_fgw=parameter.Uniform(-9, -7)('log10_fgw'),
                               log10_dist=parameter.Uniform(1, 4)('log10_dL'),
                               log10_h=None,
                               phase0=parameter.Uniform(0, 2*np.pi)('phase0'),
                               psi=parameter.Uniform(0, np.pi)('psi'),
                               p_dist=parameter.Uniform(0, 1),
                               p_phase=parameter.Uniform(0, 2*np.pi),
                               evolve=False,
                               phase_approx=True,
                               psrTerm=True,
                               tref=tmin)
cw = deterministic.CWSignal(cw_wf, psrTerm=True, name='cw0')

# complete signal: choose the GWB model via USE_HD instead of commented-out code
gwb_signal = hd if USE_HD else curn
s = tm + ef + rn + gwb_signal + cw

# initialize PTA and assign the noisedict values to the Constant parameters
pta = signal_base.PTA([s(p) for p in psrs])
pta.set_default_params(noisedict)

In [40]:
# create injected parameter dictionary
# example parameters -- modify as needed

# Sky position RA 13h00m08.09s, Dec +27d58m37.2s, converted to the polar angle
# theta (from the north pole) and azimuth phi, in radians.
ra = (15*(13 + 0/60 + 8.09/3600))  # degrees
dec = 27 + 58/60 + 37.2/3600        # degrees
gwphi = ra*np.pi/180
gwtheta = np.pi/2 - dec*np.pi/180

mc = 1e9        # chirp mass [solar masses]
dist = 100      # distance, injected as log10_dL (presumably Mpc -- confirm in cw_delay)
fgw = 1e-8      # GW frequency [Hz]
inc = 0.0       # inclination; 0 = face-on
phase0 = 0.0
psi = np.pi/4.0

# Build the injection by parameter *name*. Zipping a positional list against
# pta.param_names only works if the list order (including the order of psrs)
# exactly matches param_names, and zip silently drops any extras; a mismatch
# would put pulsar phases on the wrong pulsars without any error.
injDict = {
    'cos_gwtheta': np.cos(gwtheta),
    'cos_inc': np.cos(inc),
    'gwphi': gwphi,
    'log10_dL': np.log10(dist),
    'log10_fgw': np.log10(fgw),
    'log10_mc': np.log10(mc),
    'phase0': phase0,
    'psi': psi,
}
for p in psrs:
    # per-pulsar CW parameters of the 'cw0' signal (presumed naming: <psr>_cw0_<param>)
    # p_dist is an offset from the pulsar distance -- should be 0
    injDict[p.name + '_cw0_p_dist'] = 0.0
    # pulsar phase (calculated with get_pphase)
    injDict[p.name + '_cw0_p_phase'] = get_pphase(p, gwtheta, gwphi, mc, fgw)

# fail loudly if the names above do not match the PTA's free parameters
_missing = set(pta.param_names) - set(injDict)
_unknown = set(injDict) - set(pta.param_names)
assert not _missing and not _unknown, (
    'injDict does not match pta.param_names: missing={}, unknown={}'.format(
        sorted(_missing), sorted(_unknown)))

# injected values in pta.param_names order (kept for code that used the old list)
inj_vals = [injDict[name] for name in pta.param_names]

In [43]:
# S/N of the CW signal evaluated at the injected parameters
snr = get_SNR(pta, xs=injDict, cholesky_sparse=True)
print(snr)

  rho2 = np.sum(ell[0] for ell in get_rNr_logdet_PTA(pta, params)) #uses modified function above


6.864392820668058
