# Load Needed Packages

In [1]:
import jax
import jax.numpy as jnp
import jax.scipy as jsp
import jax.random as jar
# jax.config.update('jax_platform_name', 'cuda')
import numpy as np
jax.config.update("jax_enable_x64", True)

from pandora import models, utils, GWBFunctions
from pandora import LikelihoodCalculator as LC

from enterprise_extensions import blocks
from enterprise.signals import signal_base, gp_signals
from enterprise.signals import gp_priors as gpp
from enterprise.signals import parameter
from enterprise_extensions.model_utils import get_tspan

import numpyro
from numpyro import distributions as dist
from numpyro import infer

import pickle, json, os, corner, glob, random, copy, time, inspect

import matplotlib.pyplot as plt
from matplotlib.ticker import AutoMinorLocator
import matplotlib.lines as mlines
plt.style.use('dark_background')
hist_settings = dict(
    bins = 40,
    histtype = 'step',
    lw = 3,
    density = True
)
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
%load_ext autoreload
%autoreload 2
# %load_ext line_profiler



Optional mpi4py package is not installed.  MPI support is not available.


  from .autonotebook import tqdm as notebook_tqdm


# Choose a data set

In [2]:
datadir = '/home/koonima/FAST/Data/Pickle/'
with open(datadir + f'v1p1_de440_pint_bipm2019.pkl', 'rb') as fin:
    psrs = pickle.load(fin)
psrlist = [psr.name for psr in psrs]
with open(datadir + f'v1p1_all_dict.json', 'r') as fin:
    noise_dict = json.load(fin)
inc_ecorr = True
backend = 'backend'
tnequad = False

libstempo not installed. PINT or libstempo are required to use par and tim files.


## Frequency-bins

In [3]:
Tspan = get_tspan(psrs) # The time-span of the entire PTA
crn_bins = 30 # number of frequency-bins for the GWB
int_bins = 30 # number of frequency-bins for the non-GWB (IRN) red noise
assert int_bins >= crn_bins
f_intrin = jnp.arange(1/Tspan, (int_bins + 0.01)/Tspan, 1/Tspan) # an array of frequency-bins for the IRN process
f_common = f_intrin[:crn_bins] # an array of frequency-bins for the common process

# Building the Run in `enteprise`

In [8]:
tm = gp_signals.MarginalizingTimingModel(use_svd=True)
wn = blocks.white_noise_block(
    vary=False,
    inc_ecorr=True,
    gp_ecorr=False,
    select='backend',
    tnequad=tnequad,
)
rn = blocks.red_noise_block(
    psd="powerlaw",
    prior="log-uniform",
    Tspan=Tspan,
    components=int_bins,
    gamma_val=None,
)
gwb = blocks.common_red_noise_block(
    psd="spectrum",
    prior="log-uniform",
    Tspan=Tspan,
    components=int_bins,
    gamma_val=None,
)
s = tm + wn + rn + gwb

pta = signal_base.PTA(
    [s(p) for p in psrs], signal_base.LogLikelihoodDenseCholesky
)
pta.set_default_params(noise_dict)

In [9]:
pta.params

[B1855+09_red_noise_gamma:Uniform(pmin=0, pmax=7),
 B1855+09_red_noise_log10_A:Uniform(pmin=-20, pmax=-11),
 B1937+21_red_noise_gamma:Uniform(pmin=0, pmax=7),
 B1937+21_red_noise_log10_A:Uniform(pmin=-20, pmax=-11),
 B1953+29_red_noise_gamma:Uniform(pmin=0, pmax=7),
 B1953+29_red_noise_log10_A:Uniform(pmin=-20, pmax=-11),
 J0023+0923_red_noise_gamma:Uniform(pmin=0, pmax=7),
 J0023+0923_red_noise_log10_A:Uniform(pmin=-20, pmax=-11),
 J0030+0451_red_noise_gamma:Uniform(pmin=0, pmax=7),
 J0030+0451_red_noise_log10_A:Uniform(pmin=-20, pmax=-11),
 J0340+4130_red_noise_gamma:Uniform(pmin=0, pmax=7),
 J0340+4130_red_noise_log10_A:Uniform(pmin=-20, pmax=-11),
 J0406+3039_red_noise_gamma:Uniform(pmin=0, pmax=7),
 J0406+3039_red_noise_log10_A:Uniform(pmin=-20, pmax=-11),
 J0437-4715_red_noise_gamma:Uniform(pmin=0, pmax=7),
 J0437-4715_red_noise_log10_A:Uniform(pmin=-20, pmax=-11),
 J0509+0856_red_noise_gamma:Uniform(pmin=0, pmax=7),
 J0509+0856_red_noise_log10_A:Uniform(pmin=-20, pmax=-11),
 J05

## To save on GPU memory, save `TNT` and `TNr`

In [26]:
np.save('./TNT.npy', np.array(pta.get_TNT(params={})))
np.save('./TNr.npy', np.array(pta.get_TNr(params={}))[..., None])

# Building the Run in `pandora`

In [15]:
chosen_psd_model, chosen_orf_model, gwb_helper_dictionary = utils.hd_spectrum(renorm_const = 1, 
                                                                            crn_bins = crn_bins, lower_halflog10_rho=-9, upper_halflog10_rho=-4)
gwb_helper_dictionary

{'ordered_gwb_psd_model_params': array(['halflog10_rho'], dtype='<U13'),
 'varied_gwb_psd_params': [np.str_('halflog10_rho')],
 'gwb_psd_param_lower_lim': Array([-9., -9., -9., -9., -9., -9., -9., -9., -9., -9., -9., -9., -9.,
        -9., -9., -9., -9., -9., -9., -9., -9., -9., -9., -9., -9., -9.,
        -9., -9., -9., -9.], dtype=float64),
 'gwb_psd_param_upper_lim': Array([-4., -4., -4., -4., -4., -4., -4., -4., -4., -4., -4., -4., -4.,
        -4., -4., -4., -4., -4., -4., -4., -4., -4., -4., -4., -4., -4.,
        -4., -4., -4., -4.], dtype=float64)}

In [17]:
pnames = [pname for pname in pta.param_names]
pnames

['B1855+09_red_noise_gamma',
 'B1855+09_red_noise_log10_A',
 'B1937+21_red_noise_gamma',
 'B1937+21_red_noise_log10_A',
 'B1953+29_red_noise_gamma',
 'B1953+29_red_noise_log10_A',
 'J0023+0923_red_noise_gamma',
 'J0023+0923_red_noise_log10_A',
 'J0030+0451_red_noise_gamma',
 'J0030+0451_red_noise_log10_A',
 'J0340+4130_red_noise_gamma',
 'J0340+4130_red_noise_log10_A',
 'J0406+3039_red_noise_gamma',
 'J0406+3039_red_noise_log10_A',
 'J0437-4715_red_noise_gamma',
 'J0437-4715_red_noise_log10_A',
 'J0509+0856_red_noise_gamma',
 'J0509+0856_red_noise_log10_A',
 'J0557+1551_red_noise_gamma',
 'J0557+1551_red_noise_log10_A',
 'J0605+3757_red_noise_gamma',
 'J0605+3757_red_noise_log10_A',
 'J0610-2100_red_noise_gamma',
 'J0610-2100_red_noise_log10_A',
 'J0613-0200_red_noise_gamma',
 'J0613-0200_red_noise_log10_A',
 'J0636+5128_red_noise_gamma',
 'J0636+5128_red_noise_log10_A',
 'J0645+5158_red_noise_gamma',
 'J0645+5158_red_noise_log10_A',
 'J0709+0458_red_noise_gamma',
 'J0709+0458_red_nois

In [18]:
# save paramter list
with open(os.path.join('./', "pars.txt"), "w") as fout:
    for pname in pnames:
        fout.write(pname + "\n")

### Now, construct the model using `models.UniformPrior`

In [19]:
o = models.UniformPrior(gwb_psd_func = chosen_psd_model,
                orf_func = chosen_orf_model,
                crn_bins = crn_bins,
                int_bins = int_bins,
                f_common = f_common, 
                f_intrin = f_intrin,
                df = 1/Tspan,
                Tspan = Tspan, 
                Npulsars = len(psrs),
                psr_pos = [psr.pos for psr in psrs],
                gwb_helper_dictionary = gwb_helper_dictionary,
                gamma_min = 0,
                gamma_max = 7,
                log10A_min = -20 + 0.5,
                log10A_max = -11 + 0.5,
                renorm_const = 1)

In [20]:
m  = LC.CURN(psrs = psrs,
            device_to_run_likelihood_on = 'cuda',
            TNr = jnp.load('./TNr.npy', mmap_mode = 'r'),
            TNT = jnp.load('./TNT.npy', mmap_mode = 'r'),
            run_type_object = o,
            noise_dict = None, 
            backend = None, 
            tnequad = False, 
            inc_ecorr = True, 
            del_pta_after_init = True,
            matrix_stabilization = False)

In [21]:
x0 = o.make_initial_guess(key = jar.key(100)) # Some random draw from the prior given an RNG key

# Likelihood Comparison

In [22]:
def lnlike_offset(params):
    likelihood_adjust = 0
    likelihood_adjust += -0.5 * np.sum([ell for ell in pta.get_rNr_logdet(params)])
    likelihood_adjust += sum(pta.get_logsignalprior(params))
    return likelihood_adjust
y0 = np.array(x0)
y0[-1] = x0[134] 
y0[134] = x0[-1] 
params = pta.map_params(y0)
ln_offset = lnlike_offset(params)

In [23]:
for _ in np.arange(1, 10, 1, dtype = int):
    x0 = o.make_initial_guess(key = jar.key(_))
    y0 = np.array(x0)
    y0[-1] = x0[134] 
    y0[134] = x0[-1] 
    print(_, m.get_lnliklihood(x0), pta.get_lnlikelihood(y0) - ln_offset)

1 43492.495810113716 43493.491041834466
2 42892.10318886241 42240.14647618402
3 43693.465875158734 43699.55953341909
4 43303.74356677311 41589.608523029834
5 43763.25767917525 43812.545691027306
6 42054.93184478585 42716.040726311505
7 36637.330118466394 36953.639802444726
8 40656.320667804335 40912.45881380234
9 42901.70968149616 42068.74924931396


In [30]:
%timeit m.get_lnliklihood(x0)

496 μs ± 27.3 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [31]:
rhos, _ = o.get_phi_mat_CURN(x0)

In [40]:
ent_rhos = np.array(pta.get_phi(pta.map_params(x0)))

In [44]:
np.allclose(rhos, ent_rhos.T[0::2])

True