# Load Needed Packages

#### Note: if you want to change the precision or the device used by JAX, the most reliable way is to set environment variables in your conda environment, for example:
#### `conda activate <your env name>`
#### `conda env config vars set JAX_ENABLE_X64=True`
#### `conda env config vars set JAX_PLATFORM_NAME=cpu`
#### `conda activate <your env name>` (re-activate the environment so that the new variables take effect)
#### Make sure to restart VS Code or Jupyter Notebook after this! `jax.config.update()` may or may not work because I define default JAX arrays in different places.

In [1]:
import jax
import jax.numpy as jnp
import jax.scipy as jsp
import jax.random as jar
# jax.config.update('jax_platform_name', 'cuda')
# jax.config.update("jax_enable_x64", False)

import numpy as np
import models, utils, GWBFunctions
import LikelihoodCalculator as LC
from enterprise_extensions.model_utils import get_tspan
import pickle, json, os, corner, glob, random, copy, time, inspect
import matplotlib.pyplot as plt
from matplotlib.ticker import AutoMinorLocator
import matplotlib.lines as mlines

plt.style.use('dark_background')
# Shared keyword arguments for histogram plots (unpack with `**hist_settings`).
hist_settings = {
    'bins': 40,
    'histtype': 'step',
    'lw': 3,
    'density': True,
}

%matplotlib inline
%config InlineBackend.figure_format = 'retina'
%load_ext autoreload
%autoreload 2
# %load_ext line_profiler



The fastshermanmorrison package is not found. The single pulsar analyses are much faster with the fast algorithm!
Optional mpi4py package is not installed.  MPI support is not available.


# Choose a data set

In [2]:
# Names of the available data sets. The order matters because the next cell
# selects an entry by position; 'NIMCast' intentionally stays listed twice so
# that existing indices are unchanged.
dsets = [
    'SIMG45', 'SIMG', 'GSIM', 'GSIMNULL',
    'GSIMRED', 'SIMRed', 'SIMBase', 'NIMCast',
    '15SIM', '12p5', 'A4Cast', 'SIMNULL',
    'NIMCast', '15yr', 'AJAXPaper', 'SIMCurn',
]

In [3]:
# Position of the chosen data set inside `dsets` (-3 selects '15yr').
# Named `dset_index` rather than `id` so that the built-in `id()` is not
# shadowed for the rest of the kernel session.
dset_index = -3
cname = dsets[dset_index]
cname

'15yr'

In [4]:
# Settings consumed by the data-loading cell that follows.
noise_dict = dict()  # white-noise parameter values; filled in when the data are loaded
Npulsars = 67        # the '15yr' branch keeps only the first `Npulsars` pulsars of the pickle
crn_bins = 5         # number of frequency-bins for the GWB
cname

'15yr'

In [5]:
# Load the pulsar objects (`psrs`), their names (`psrlist`) and the white-noise
# dictionary (`noise_dict`) for the data set selected by `cname`.
#
# Every branch builds `noise_dict` from scratch, so re-running this cell after
# changing `cname` never leaves stale entries from a previously loaded data set.
#
# SECURITY NOTE: `pickle.load` can execute arbitrary code stored in the file.
# Only load pickles that come from a source you trust.
if cname == '12p5':
    rr = 0
    with open('../Data/Pickle/channelized_12yr_v3_partim_DE438.pkl', 'rb') as fin:
        psrs = pickle.load(fin)
    psrlist = [psr.name for psr in psrs]
    with open('../Data/Pickle/channelized_12p5yr_v3_full_noisedict.json', 'r') as fin:
        nd = json.load(fin)
    # Keys containing 'equad' are renamed to use 'tnequad'; `str.replace`
    # leaves every other key untouched.
    noise_dict = {k.replace('equad', 'tnequad'): v for k, v in nd.items()}
elif cname == '15yr':
    rr = 0
    with open('../Data/Pickle/v1p1_de440_pint_bipm2019.pkl', 'rb') as fin:
        psrs = pickle.load(fin)[:Npulsars]  # keep only the first `Npulsars` pulsars
    psrlist = [psr.name for psr in psrs]
    with open('../Data/Pickle/v1p1_all_dict.json', 'r') as fin:
        noise_dict = json.load(fin)
    inc_ecorr = True
    backend = 'backend'
    tnequad = False
else:
    # Any other data set is read from the DEMO directory, file number `rr`.
    rr = 2
    with open(f'../DEMO/Data/Pickle/{cname}/{rr}.pkl', 'rb') as fin:
        psrs = pickle.load(fin)
    psrlist = [psr.name for psr in psrs]
    # No noise dictionary is read for these data sets: every pulsar gets
    # EFAC = 1 and log10_t2equad = -inf.
    noise_dict = {}
    for pname in psrlist:
        noise_dict[pname + '_efac'] = 1.0
        noise_dict[pname + '_log10_t2equad'] = -np.inf



# Step 1: Model Construction

## Frequency-bins

In [6]:
Tspan = get_tspan(psrs) # The time-span of the entire PTA
crn_bins = 5 # number of frequency-bins for the GWB
int_bins = 5 # number of frequency-bins for the non-GWB (IRN) red noise
assert int_bins >= crn_bins, '`f_common` is sliced from `f_intrin`, so int_bins must be >= crn_bins'
# Harmonics k/Tspan for k = 1, ..., int_bins. An integer range scaled by 1/Tspan
# gives exactly `int_bins` entries without depending on floating-point round-off
# at the end point (a float-step `arange` needs a fudge factor on its stop value).
f_intrin = jnp.arange(1, int_bins + 1) / Tspan # an array of frequency-bins for the IRN process
f_common = f_intrin[:crn_bins] # an array of frequency-bins for the common process
renorm_const = 1. # the factor by which the units are going to change (divided by). Set it to `1` for no unit change (seconds), or let it be `1e9` (nano seconds) for better performance.

## GWB PSD model

### A dictionary is used to store the necessary information about GWB PSD. You can either use the `utils.param_order_help` to make your own dictionary or use one of the pre-made ones. Take a look at `GWBFunctions.py` for a list of supported PSD as well as GWB ORF functions.

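### For example, the cell below chooses a GWB with HD correlations and a free-spectrum PSD (one `halflog10_rho` parameter per frequency-bin). The commented-out lines show pre-made alternatives, such as a powerlaw PSD with the spectral index fixed at 13/3.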

In [7]:
# Pre-made GWB model builders from `utils`. Each returns the PSD function, the
# ORF function and the helper dictionary describing the GWB parameters.
# Exactly one assignment should be active; uncomment another to switch models.
#
# chosen_psd_model, chosen_orf_model, gwb_helper_dictionary = utils.fixed_gamma_hd_pl(renorm_const=renorm_const)
# chosen_psd_model, chosen_orf_model, gwb_helper_dictionary = utils.varied_gamma_hd_pl(renorm_const=renorm_const)
# chosen_psd_model, chosen_orf_model, gwb_helper_dictionary = utils.varied_gamma_gt_pl(renorm_const=renorm_const)
# chosen_psd_model, chosen_orf_model, gwb_helper_dictionary = utils.varied_gamma_bin_orf_pl(renorm_const=renorm_const)
(
    chosen_psd_model,
    chosen_orf_model,
    gwb_helper_dictionary,
) = utils.hd_spectrum(
    renorm_const = renorm_const,
    crn_bins = crn_bins,
)
gwb_helper_dictionary

{'ordered_gwb_psd_model_params': array(['halflog10_rho'], dtype='<U13'),
 'varied_gwb_psd_params': ['halflog10_rho'],
 'gwb_psd_param_lower_lim': Array([-9., -9., -9., -9., -9.], dtype=float64),
 'gwb_psd_param_upper_lim': Array([-1., -1., -1., -1., -1.], dtype=float64)}

### Now, construct the model using `models.UniformPrior`

In [8]:
help(models.UniformPrior)

Help on class UniformPrior in module models:

class UniformPrior(builtins.object)
 |  UniformPrior(gwb_psd_func, orf_func, crn_bins, int_bins, f_common, f_intrin, df, psr_pos, Tspan, Npulsars, gwb_helper_dictionary, gamma_min=0, gamma_max=7, log10A_min=Array(-13.5, dtype=float64), log10A_max=Array(-6.5, dtype=float64), renorm_const=1000000000.0)
 |  
 |  A class to take care of prior and the phi-matrix construction based on uniform/log-uniform priors.
 |  
 |  :param gwb_psd_func: a PSD function from the `GWBFunctions` class
 |  :param orf_func: an orf function from the `GWBFunctions` class
 |  :param crn_bins: number of frequency-bins for the GWB
 |  :param int_bins: number of frequency-bins for the non-GWB (IRN) red noise
 |  :param `f_common`: an array of frequency-bins for the common process
 |  :param `f_intrin`: an array of frequency-bins for the IRN process
 |  :param df: the diffence between consecutive frequency-bins. It is usually 1/Tspan
 |  :param psr_pos: an array of pulsa

In [9]:
# Build the prior / phi-matrix object. Keywords follow the order of the
# `models.UniformPrior` signature shown by `help` above.
# `0.5 * jnp.log10(renorm_const)` is added to both log10A limits to account
# for the change in units.
o = models.UniformPrior(
    gwb_psd_func = chosen_psd_model,
    orf_func = chosen_orf_model,
    crn_bins = crn_bins,
    int_bins = int_bins,
    f_common = f_common,
    f_intrin = f_intrin,
    df = 1/Tspan,
    psr_pos = [psr.pos for psr in psrs],
    Tspan = Tspan,
    Npulsars = len(psrs),
    gwb_helper_dictionary = gwb_helper_dictionary,
    gamma_min = 0,
    gamma_max = 7,
    log10A_min = -20 + 0.5 * jnp.log10(renorm_const),
    log10A_max = -11 + 0.5 * jnp.log10(renorm_const),
    renorm_const = renorm_const,
)

In [10]:
x0 = o.make_initial_guess(key = jar.key(100)) # Some random draw from the prior given an RNG key

In [11]:
# Build the phi-matrix for the prior draw `x0` and display its shape.
phimat = o.get_phi_mat(x0)
phimat.shape

(5, 67, 67)

### Here is the phi-matrix. Note that this matrix is batched over `max(int_bins, crn_bins)` frequency-bins and only its lower triangle is populated: phi is symmetric positive definite, so filling in the upper triangle as well would be wasted computation.

In [12]:
phimat[0]

Array([[ 1.13721304e-07,  0.00000000e+00,  0.00000000e+00, ...,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00],
       [ 4.38168166e-08,  1.13721316e-07,  0.00000000e+00, ...,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00],
       [ 3.29142536e-08,  5.19604098e-08,  1.13721304e-07, ...,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00],
       ...,
       [-1.15913598e-08,  1.51887022e-09,  1.03131491e-08, ...,
         1.13752718e-07,  0.00000000e+00,  0.00000000e+00],
       [-1.19664332e-08, -3.05603017e-09,  3.25103762e-10, ...,
         2.37859771e-08,  1.13721304e-07,  0.00000000e+00],
       [-1.25210477e-08, -2.78271737e-09,  1.75695204e-09, ...,
         3.23319078e-08,  5.36935502e-08,  1.22934541e-06]],      dtype=float64)

# Step 2: Likelihood Calculation

In [15]:
help(LC.MultiPulsarModel)

Help on class MultiPulsarModel in module LikelihoodCalculator:

class MultiPulsarModel(builtins.object)
 |  MultiPulsarModel(run_type_object, psrs, TNr=Array([False], dtype=bool), TNT=Array([False], dtype=bool), noise_dict=None, backend='none', tnequad=False, inc_ecorr=False, del_pta_after_init=True, matrix_stabilization=True)
 |  
 |  A class to calculate likelihood based on given IRN + GWB models (no deterministic signal)
 |  
 |  :param run_type_object: a class from `run_types.py`
 |  :param psrs: an enterprise `psrs` object. Ignored if `TNr` and `TNT` is supplied
 |  :param TNr: the so-called TNr matrix. It is the product of the basis matrix `F`
 |  with the inverse of the timing marginalized white noise covaraince matrix D and the timing residulas `r`.
 |  The naming convension should read FD^-1r but TNr is a more well-known name for this quantity!
 |  :param TNT: the so-called TNT matrix. It is the product of the basis matrix `F`
 |  with the inverse of the timing marginalized wh

### Making the likelihood calculator object is easy. We never change the white noise for multi-pulsar analyses, so the `TNr` and `TNT` matrices are all the data the likelihood really needs. Thus, you can either supply a `psrs` object from which `TNr` and `TNT` are built, or you can supply both `TNr` and `TNT` directly. `noise_dict`, `backend`, `tnequad`, and `inc_ecorr` are only used when `psrs` is supplied.

In [13]:
# Build the likelihood calculator from the `psrs` object. Keywords follow the
# order of the `LC.MultiPulsarModel` signature.
# `TNr` / `TNT` are passed as `jnp.array([False])`, the same placeholder values
# the signature uses as defaults.
# NOTE(review): `del_pta_after_init = False` presumably keeps the enterprise PTA
# available as `m.pta`, which is displayed here - confirm in
# `LikelihoodCalculator`.
m = LC.MultiPulsarModel(
    run_type_object = o,
    psrs = psrs,
    TNr = jnp.array([False]),
    TNT = jnp.array([False]),
    noise_dict = noise_dict,
    backend = 'backend',
    tnequad = False,
    inc_ecorr = True,
    del_pta_after_init = False,
    matrix_stabilization = False,
)
m.pta.params

[B1855+09_red_noise_gamma:Uniform(pmin=0, pmax=7),
 B1855+09_red_noise_log10_A:Uniform(pmin=-20, pmax=-11),
 B1937+21_red_noise_gamma:Uniform(pmin=0, pmax=7),
 B1937+21_red_noise_log10_A:Uniform(pmin=-20, pmax=-11),
 B1953+29_red_noise_gamma:Uniform(pmin=0, pmax=7),
 B1953+29_red_noise_log10_A:Uniform(pmin=-20, pmax=-11),
 J0023+0923_red_noise_gamma:Uniform(pmin=0, pmax=7),
 J0023+0923_red_noise_log10_A:Uniform(pmin=-20, pmax=-11),
 J0030+0451_red_noise_gamma:Uniform(pmin=0, pmax=7),
 J0030+0451_red_noise_log10_A:Uniform(pmin=-20, pmax=-11),
 J0340+4130_red_noise_gamma:Uniform(pmin=0, pmax=7),
 J0340+4130_red_noise_log10_A:Uniform(pmin=-20, pmax=-11),
 J0406+3039_red_noise_gamma:Uniform(pmin=0, pmax=7),
 J0406+3039_red_noise_log10_A:Uniform(pmin=-20, pmax=-11),
 J0437-4715_red_noise_gamma:Uniform(pmin=0, pmax=7),
 J0437-4715_red_noise_log10_A:Uniform(pmin=-20, pmax=-11),
 J0509+0856_red_noise_gamma:Uniform(pmin=0, pmax=7),
 J0509+0856_red_noise_log10_A:Uniform(pmin=-20, pmax=-11),
 J05

In [14]:
m.get_lnliklihood(x0)

Array(45864.21266335, dtype=float64)

# Comparisons to `enterprise`

In [15]:
# Draw from the prior again with RNG key 100 (the same key used for the
# initial guess in Step 1) and show the length of the parameter vector.
x0 = o.make_initial_guess(key = jar.key(100))
x0.shape

(139,)

## phi-mat construction comparison

In [18]:
%timeit phi = o.get_phi_mat(x0)

37.7 µs ± 1.02 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [19]:
%timeit ent_phi = m.pta.get_phi(m.pta.map_params(x0))

7.88 ms ± 792 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


## Likelihood comparison

In [16]:
# Evaluate the enterprise log-likelihood for prior draws made with RNG seeds 1-4.
# The loop variable is named `seed` rather than `_`, which is conventionally
# reserved for unused values and is also IPython's "last output" variable.
# NOTE: `x0` and `params` from the final iteration (seed 4) stay in the
# namespace and are reused by the cells below.
for seed in range(1, 5):
    x0 = o.make_initial_guess(key = jar.key(seed))
    params = m.pta.map_params(x0)
    print(seed, m.pta.get_lnlikelihood(params))

1 7969178.804954649
2 7970463.753592478
3 7955323.971957422
4 7970182.545376894


### White noise is fixed, so `rNr` is just a constant and does not need to be involved in likelihood calculations.

In [17]:
def lnlike_offset():
    """Return the terms of the enterprise log-likelihood that are left out of `m.get_lnliklihood`.

    Reads the notebook-level `m` (for its enterprise PTA, `m.pta`) and `params`.
    The result is -0.5 times the sum of the entries returned by
    `m.pta.get_rNr_logdet(params)` plus the sum of the entries returned by
    `m.pta.get_logsignalprior(params)`.
    """
    rnr_logdet_terms = list(m.pta.get_rNr_logdet(params))
    signal_prior_terms = m.pta.get_logsignalprior(params)
    return -0.5 * np.sum(rnr_logdet_terms) + sum(signal_prior_terms)

In [18]:
%timeit lnlike_offset

7.12 ns ± 0.0214 ns per loop (mean ± std. dev. of 7 runs, 100,000,000 loops each)


### The likelihood values MUST match when both are computed at the same precision!

In [19]:
# `x0` is the last draw made in the comparison loop (RNG key 4), i.e. the draw
# that `params` was mapped from.
m.get_lnliklihood(x0)

Array(44979.52526641, dtype=float64)

In [20]:
# enterprise log-likelihood for the same draw, minus the terms computed by
# `lnlike_offset()`; this is the number to compare with `m.get_lnliklihood(x0)`.
m.pta.get_lnlikelihood(params) - lnlike_offset()

44979.525266406126

### The timing varies given precision, device (CPU or GPU), number of frequency-bins, and number of pulsars! To keep the comparison fair, you should use CPU and double precision. To get the best performance, you should use GPU and single-precision. Issues related to single-precision will go away if you renormalize the `TNT` matrix and stabilize it numerically (both of which can be done using the classes shown above!).
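### Also, `%timeit` is not a good way to time a GPU accelerated function! JAX dispatches work asynchronously, so a call can return before the computation has finished; call `.block_until_ready()` on the result to time the actual computation.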

In [21]:
%timeit o.get_phi_mat(x0); m.get_lnliklihood(x0)

4.04 ms ± 10.7 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [22]:
%timeit m.pta.get_lnlikelihood(m.pta.map_params(x0))

188 ms ± 14.9 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
