In [1]:
%matplotlib inline
%config InlineBackend.figure_format = 'retina'

%load_ext autoreload
%autoreload 2

import numpy as np
from matplotlib import pyplot as plt
from glob import glob

import enterprise
from enterprise.pulsar import Pulsar
from enterprise.signals import parameter,white_signals,gp_signals,signal_base

from enterprise_extensions import blocks
from enterprise_extensions.model_utils import get_tspan

from la_forge.core import Core

import sys
sys.path.append('../')
sys.path.append('../../PerFreqOS')

from OSplusplus import OSplusplus as ospp

from PFOS.optimal_statistic import OptimalStatistic



## Load the Pulsars (MDC1 pulsars in enterprise)

In [2]:
datadir = enterprise.__path__[0] + '/datafiles/mdc_open1/'
parfiles = sorted(glob(datadir + '/*.par'))
timfiles = sorted(glob(datadir + '/*.tim'))

psrs = [Pulsar(par,tim) for par,tim in zip(parfiles,timfiles)]
    
inj_params = {'gw_log10_A':np.log10(5e-14),'gw_gamma':(13./3.)}








## Create a PTA object to sample from

In [3]:
Tspan = get_tspan(psrs)

efac = parameter.Constant(1.0)
ef = white_signals.MeasurementNoise(efac=efac)

curn = blocks.common_red_noise_block(psd='powerlaw',Tspan=Tspan,components=10,
                                     gamma_val=13./3.,logmin=-18,logmax=-12,name='gw')

tm = gp_signals.TimingModel(use_svd=True)

model = tm + ef + curn

pta = signal_base.PTA([model(psr) for psr in psrs])

## Sample from the PTA

In [4]:
run_sampler = False

savepath = 'MDC1_FG_CURN_samples'
if run_sampler:
    from enterprise_extensions.hypermodel import HyperModel

    hyper = HyperModel({0:pta})
    sampler = hyper.setup_sampler(resume=True,outdir=savepath)

    N = int(1e6)
    x0 = hyper.initial_sample()
    sampler.sample(x0, N, AMweight=25, SCAMweight=40, DEweight=55, writeHotChains=False)
else:
    # Load the chain in with la_forge
    lfcore = Core(chaindir=savepath)
    lfcore.save(savepath+'.core')


## Create an OSplusplus object

In [7]:
t1 = ospp(psrs,pta,core=lfcore)
A2,A2s,total_snr = t1.compute_os(N=1,pair_covariant=False)
print(A2, A2s, total_snr)


og_os = OptimalStatistic(psrs,pta=pta)
_,_,_,old_A2,old_A2s = og_os.compute_os(t1.max_likelihood_params)
print(old_A2, old_A2s)



2.508467506388377e-27 1.9360098416715863e-28 12.956894393793577
2.508467506388959e-27 1.9360098416715867e-28


In [20]:
sf, sfs, total_snr = t1.compute_pfos(pair_covariant=False)
plt.errorbar(t1.frequencies,sf,sfs,fmt='.')
plt.loglog()
plt.show()

 30%|███       | 3/10 [00:00<00:01,  4.56it/s]


LinAlgError: SVD did not converge

In [38]:
t1.set_orf(['hd','dipole'])

A2,A2s,total_snr = t1.compute_os(pair_covariant=False)
print(A2,A2s,total_snr)

A2,A2s,total_snr = t1.compute_os(pair_covariant=True)
print(A2,A2s,total_snr)



(2, 630) (630, 630) (630,)
[2.50508927e-27 4.71505657e-30] [[ 3.87938594e-56 -1.83190212e-57]
 [-1.83190212e-57  2.55681413e-57]] 12.957229929018965


Pairs of pairs: 100%|███████████████| 198765/198765 [00:00<00:00, 585433.17it/s]


(2, 630) (630, 630) (630,)
[ 2.53234577e-27 -2.48233981e-29] [[ 2.77737486e-55 -5.89849858e-58]
 [-5.89849858e-58  2.29968317e-57]] 4.822224732076541


## Free spectral models

In [12]:
Tspan = get_tspan(psrs)

efac = parameter.Constant(1.0)
ef = white_signals.MeasurementNoise(efac=efac)

curn = blocks.common_red_noise_block(psd='spectrum',Tspan=Tspan,components=10,
                                     gamma_val=13./3.,logmin=-14,logmax=-9,name='gw')

tm = gp_signals.TimingModel(use_svd=True)

model = tm + ef + curn

fs_pta = signal_base.PTA([model(psr) for psr in psrs])

In [13]:
run_sampler = False

savepath = 'MDC1_FS_CURN_samples'
if run_sampler:
    from enterprise_extensions.hypermodel import HyperModel

    hyper = HyperModel({0:fs_pta})
    sampler = hyper.setup_sampler(resume=True,outdir=savepath)

    N = int(1e6)
    x0 = hyper.initial_sample()
    sampler.sample(x0, N, AMweight=25, SCAMweight=40, DEweight=55, writeHotChains=False)
else:
    # Load the chain in with la_forge
    fs_lfcore = Core(chaindir=savepath)
    fs_lfcore.save(savepath+'.core')


In [22]:
fs_pta.param_names

['gw_log10_rho_0',
 'gw_log10_rho_1',
 'gw_log10_rho_2',
 'gw_log10_rho_3',
 'gw_log10_rho_4',
 'gw_log10_rho_5',
 'gw_log10_rho_6',
 'gw_log10_rho_7',
 'gw_log10_rho_8',
 'gw_log10_rho_9']

In [47]:
t2 = ospp(psrs,fs_pta,corepath='MDC1_FS_CURN_samples.core')
A2,A2s,total_snr = t2.compute_os(pair_covariant=True,gamma=13./3.)
print(A2, A2s, total_snr)




Loading data from HDF5 file....


Pairs of pairs: 100%|███████████████| 198765/198765 [00:00<00:00, 590899.66it/s]


7.984925469881207e-27 1.3771303835428785e-15 5.79823491319592e-12


In [51]:
from OSplusplus.utils import get_gwb_a2
get_gwb_a2(fs_pta,t2.max_likelihood_params,'gw')

-5.508135585274247e-15

In [49]:
inj_params

{'gw_log10_A': -13.301029995663981, 'gw_gamma': 4.333333333333333}