# Load Needed Packages

#### Note: to change JAX's floating-point precision or default device, the most reliable approach is to set conda environment variables, for example:
#### `conda activate <your env name>`
#### `conda env config vars set JAX_ENABLE_X64=True`
#### `conda env config vars set JAX_PLATFORMS=cpu`
#### `conda activate <your env name>`
#### Restart VS Code or the Jupyter kernel afterwards! Calling `jax.config.update()` inside the notebook may not take effect, because default JAX arrays are defined in several places in the code.

In [1]:
import numpy as np
import jax
import jax.numpy as jnp
import jax.scipy as jsp
import jax.random as jar
# jax.config.update('jax_platform_name', 'cuda')
# jax.config.update("jax_enable_x64", False)

from pandora import models, utils, GWBFunctions
from pandora import LikelihoodCalculator as LC
from enterprise_extensions.model_utils import get_tspan
import pickle, json, os, corner, glob, random, copy, time, inspect
import matplotlib.pyplot as plt
from matplotlib.ticker import AutoMinorLocator
import matplotlib.lines as mlines

plt.style.use('dark_background')
hist_settings = dict(
    bins = 40,
    histtype = 'step',
    lw = 3,
    density = True
)

%matplotlib inline
%config InlineBackend.figure_format = 'retina'
%load_ext autoreload
%autoreload 2
# %load_ext line_profiler



Optional mpi4py package is not installed.  MPI support is not available.


# Choose a data set

In [None]:
# Load the pulsar objects and the noise dictionary for the chosen data set.
datadir = ...  # TODO: set to the directory that contains the two files below
# os.path.join works whether or not `datadir` ends with a path separator.
with open(os.path.join(datadir, 'v1p1_de440_pint_bipm2019.pkl'), 'rb') as fin:
    # NOTE(review): pickle.load can execute arbitrary code -- only load files from a trusted source.
    psrs = pickle.load(fin)
psrlist = [psr.name for psr in psrs]
with open(os.path.join(datadir, 'v1p1_all_dict.json'), 'r') as fin:
    noise_dict = json.load(fin)

# Noise-model options for the likelihood models.
inc_ecorr = True
backend = 'backend'
tnequad = False

# Step 1: Model Construction

## Frequency bins

In [6]:
# Frequency grid: harmonics k / Tspan (k = 1, 2, ...) of the full PTA time-span.
Tspan = get_tspan(psrs)  # the time-span of the entire PTA
crn_bins = 5  # number of frequency bins for the GWB (common) process
int_bins = 5  # number of frequency bins for the non-GWB intrinsic red noise (IRN)
assert int_bins >= crn_bins, 'the common-process bins must be a subset of the IRN bins'
# Build the harmonics from an integer range so exactly `int_bins` frequencies are produced.
# A floating-point `arange` stop value can gain or lose an element to rounding.
f_intrin = jnp.arange(1, int_bins + 1) / Tspan  # IRN frequencies: 1/Tspan, ..., int_bins/Tspan
f_common = f_intrin[:crn_bins]  # common-process frequencies: the lowest `crn_bins` IRN bins
assert f_intrin.shape == (int_bins,)
# Factor by which the units change (quantities are divided by it).
# Use 1 for no unit change (seconds), or 1e9 for nanoseconds.
renorm_const = 1

## GWB PSD Model 1

In [7]:
# GWB model 1: gamma is held fixed and only log10_A is varied, within [-16, -13].
# The helper name suggests a power-law PSD with Hellings-Downs ORF -- confirm in pandora.utils.
(chosen_psd_model1,
 chosen_orf_model1,
 gwb_helper_dictionary1) = utils.fixed_gamma_hd_pl(
    renorm_const=renorm_const,
    lower_amp=-16.0,
    upper_amp=-13.0,
)
gwb_helper_dictionary1

{'ordered_gwb_psd_model_params': array(['log10_A', 'gamma'], dtype='<U7'),
 'fixed_gwb_psd_params': ['gamma'],
 'varied_gwb_psd_params': ['log10_A'],
 'gwb_psd_param_lower_lim': Array([-16.], dtype=float32),
 'gwb_psd_param_upper_lim': Array([-13.], dtype=float32),
 'fixed_gwb_psd_param_indices': Array([1], dtype=int32),
 'fixed_gwb_psd_param_values': Array([4.3333335], dtype=float32)}

In [8]:
# Uniform-prior run configuration for GWB model 1.
# The log10A bounds are shifted by 0.5 * log10(renorm_const) to account for the unit change
# (the shift is zero when renorm_const == 1).
o1 = models.UniformPrior(
    gwb_psd_func=chosen_psd_model1,
    orf_func=chosen_orf_model1,
    crn_bins=crn_bins,
    int_bins=int_bins,
    f_common=f_common,
    f_intrin=f_intrin,
    df=1 / Tspan,
    Tspan=Tspan,
    Npulsars=len(psrs),
    psr_pos=[pulsar.pos for pulsar in psrs],
    gwb_helper_dictionary=gwb_helper_dictionary1,
    gamma_min=0,
    gamma_max=7,
    log10A_min=-20 + 0.5 * jnp.log10(renorm_const),
    log10A_max=-11 + 0.5 * jnp.log10(renorm_const),
    renorm_const=renorm_const,
)

In [9]:
# Likelihood model for GWB model 1 (run configuration `o1`).
# The noise-model options reuse the variables set in the data-loading cell instead of
# repeating literal copies of them, so the configuration has a single source of truth.
# NOTE(review): TNr/TNT = jnp.array([False]) appear to be "not precomputed" sentinels -- confirm in pandora.
m1 = LC.MultiPulsarModel(
    psrs=psrs,
    device_to_run_likelihood_on='cuda',
    TNr=jnp.array([False]),
    TNT=jnp.array([False]),
    run_type_object=o1,
    noise_dict=noise_dict,
    backend=backend,
    tnequad=tnequad,
    inc_ecorr=inc_ecorr,
    del_pta_after_init=True,
    matrix_stabilization=True,
)

The delta is 1e-06
Condition number of the TNT matrix before stabilizing is: 1.3969157e+19
Condition number of the TNT matrix after stabilizing is: 3.8208222e+16


## GWB PSD Model 2

In [10]:
# GWB model 2: both log10_A and gamma are varied (limits shown in the displayed dictionary).
# The helper name suggests a power-law PSD with Hellings-Downs ORF -- confirm in pandora.utils.
(chosen_psd_model2,
 chosen_orf_model2,
 gwb_helper_dictionary2) = utils.varied_gamma_hd_pl(
    renorm_const=renorm_const,
)
gwb_helper_dictionary2

{'ordered_gwb_psd_model_params': array(['log10_A', 'gamma'], dtype='<U7'),
 'varied_gwb_psd_params': ['log10_A', 'gamma'],
 'gwb_psd_param_lower_lim': Array([-18.,   0.], dtype=float32),
 'gwb_psd_param_upper_lim': Array([-11.,   7.], dtype=float32)}

In [11]:
# Uniform-prior run configuration for GWB model 2.
# The log10A bounds are shifted by 0.5 * log10(renorm_const) to account for the unit change
# (the shift is zero when renorm_const == 1).
o2 = models.UniformPrior(
    gwb_psd_func=chosen_psd_model2,
    orf_func=chosen_orf_model2,
    crn_bins=crn_bins,
    int_bins=int_bins,
    f_common=f_common,
    f_intrin=f_intrin,
    df=1 / Tspan,
    Tspan=Tspan,
    Npulsars=len(psrs),
    psr_pos=[pulsar.pos for pulsar in psrs],
    gwb_helper_dictionary=gwb_helper_dictionary2,
    gamma_min=0,
    gamma_max=7,
    log10A_min=-20 + 0.5 * jnp.log10(renorm_const),
    log10A_max=-11 + 0.5 * jnp.log10(renorm_const),
    renorm_const=renorm_const,
)

In [12]:
# Likelihood model for GWB model 2 (run configuration `o2`).
# The noise-model options reuse the variables set in the data-loading cell instead of
# repeating literal copies of them, so the configuration has a single source of truth.
# NOTE(review): TNr/TNT = jnp.array([False]) appear to be "not precomputed" sentinels -- confirm in pandora.
m2 = LC.MultiPulsarModel(
    psrs=psrs,
    device_to_run_likelihood_on='cuda',
    TNr=jnp.array([False]),
    TNT=jnp.array([False]),
    run_type_object=o2,
    noise_dict=noise_dict,
    backend=backend,
    tnequad=tnequad,
    inc_ecorr=inc_ecorr,
    del_pta_after_init=True,
    matrix_stabilization=True,
)

The delta is 1e-06
Condition number of the TNT matrix before stabilizing is: 1.3969157e+19
Condition number of the TNT matrix after stabilizing is: 3.8208222e+16


# Construct the hypermodel (HM)

In [52]:
# Combine the two likelihood models into a single two-model hypermodel.
hm_object = LC.TwoModelHyperModel(
    model1=m1,
    model2=m2,
)

In [53]:
# Generate an initial parameter vector from a fixed PRNG key (seed 100), so it is reproducible,
# then display its shape.
x0 = hm_object.make_initial_guess(
    jar.key(100)
)
x0.shape

(138,)

In [54]:
# Set the last parameter to 0.9 and display (log-likelihood, log-prior) at that point.
# NOTE(review): the last entry is presumably the hypermodel's model-index parameter -- confirm.
x0 = x0.at[-1].set(0.9)
(hm_object.get_lnliklihood(x0),
 hm_object.get_lnprior(x0))

(Array(46429.242, dtype=float32), Array(-8.01, dtype=float32, weak_type=True))

In [47]:
# Set the last parameter to 0.1 and display (log-likelihood, log-prior) at that point.
# NOTE(review): the last entry is presumably the hypermodel's model-index parameter -- confirm.
x0 = x0.at[-1].set(0.1)
(hm_object.get_lnliklihood(x0),
 hm_object.get_lnprior(x0))

(Array(46690.426, dtype=float32), Array(-8.01, dtype=float32, weak_type=True))

In [49]:
%timeit hm_object.get_lnliklihood(x0)

2.79 ms ± 38.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [18]:
%timeit hm_object.get_lnprior(x0)

82 µs ± 7.69 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


## Sampling

In [None]:
# Run the hypermodel sampler from `x0` (converted from a JAX array to a NumPy array).
# Requires `import numpy as np` in the imports cell; without it this cell raises NameError on a fresh kernel.
hm_object.sample(
    x0=np.array(x0),
    niter=1_000_000,
    savedir='../testnew/HM/',
    resume=False,
)