# Load Needed Packages

#### Note: to change the precision or the device used by JAX, the most reliable way is to set environment variables in your conda environment, for example:
#### `conda activate <your env name>`
#### `conda env config vars set JAX_ENABLE_X64=True`
#### `conda env config vars set JAX_PLATFORM_NAME=cpu`
#### `conda activate <your env name>`
#### Make sure to restart VS Code or the Jupyter server after this! Calling `jax.config.update("jax_enable_x64", False)` inside the notebook may or may not work, because the project's Python modules define default JAX arrays themselves.

In [1]:
import jax
import jax.numpy as jnp
import jax.scipy as jsp
import jax.random as jar
import nf_dist
# jax.config.update('jax_platform_name', 'cuda')
# jax.config.update("jax_enable_x64", False)
import torch
import numpy as np
from scipy.stats import norm
import models, utils, GWBFunctions
import LikelihoodCalculator as LC
from enterprise_extensions.model_utils import get_tspan
import pickle, json, os, corner, glob, random, copy, time, inspect
import matplotlib.pyplot as plt
from matplotlib.ticker import AutoMinorLocator
import matplotlib.lines as mlines

import torch.utils.dlpack as td
import jax.dlpack as jd

# Plot styling: dark theme for every figure, plus a shared set of keyword
# arguments for step-style, density-normalized histograms.
plt.style.use('dark_background')
hist_settings = {
    'bins': 40,
    'histtype': 'step',
    'lw': 3,
    'density': True,
}

%matplotlib inline
%config InlineBackend.figure_format = 'retina'
%load_ext autoreload
%autoreload 2
# %load_ext line_profiler



The fastshermanmorrison package is not found. The single pulsar analyses are much faster with the fast algorithm!
Optional mpi4py package is not installed.  MPI support is not available.


# Choose a data set

In [2]:
# Names of the available data sets; the next cell picks one of them by index.
# NOTE(review): 'NIMCast' is listed twice. It is kept as-is because removing
# one entry would shift the positive indices of everything after it.
dsets = [
    'SIMG45', 'SIMG', 'GSIM', 'GSIMNULL',
    'GSIMRED', 'SIMRed', 'SIMBase', 'NIMCast',
    '15SIM', '12p5', 'A4Cast', 'SIMNULL',
    'NIMCast', '15yr', 'AJAXPaper', 'SIMCurn',
]

In [3]:
# Position of the chosen data set inside `dsets` (negative values count from the end).
# The index is called `dset_idx` rather than `id` so that the built-in `id()`
# is not shadowed for the rest of the kernel session.
dset_idx = -3
cname = dsets[dset_idx]
cname

'15yr'

In [4]:
# Run-wide settings.
#   crn_bins   : number of frequency-bins for the GWB
#   Npulsars   : how many pulsars are kept when the '15yr' data set is loaded
#   noise_dict : white-noise parameters, filled in by the data-loading cell
crn_bins, Npulsars = 5, 67
noise_dict = {}
cname

'15yr'

In [5]:
# Load the pulsar objects (`psrs`), their names (`psrlist`) and the white-noise
# dictionary (`noise_dict`) for the data set selected by `cname`.
# SECURITY: `pickle.load` can execute arbitrary code; only open pickle files
# that come from a source you trust.
if cname == '12p5':
    rr = 0
    with open('../Data/Pickle/channelized_12yr_v3_partim_DE438.pkl', 'rb') as fin:
        psrs = pickle.load(fin)
    psrlist = [psr.name for psr in psrs]
    with open('../Data/Pickle/channelized_12p5yr_v3_full_noisedict.json', 'r') as fin:
        nd = json.load(fin)
    # Rename 'equad' to 'tnequad' inside every key; keys without 'equad' are
    # returned unchanged by `str.replace`, so they are copied as they are.
    noise_dict.update({k.replace('equad', 'tnequad'): v for k, v in nd.items()})
elif cname == '15yr':
    rr = 0
    with open('../Data/Pickle/v1p1_de440_pint_bipm2019.pkl', 'rb') as fin:
        # Keep only the first `Npulsars` pulsars of the pickled list.
        psrs = pickle.load(fin)[:Npulsars]
    psrlist = [psr.name for psr in psrs]
    with open('../Data/Pickle/v1p1_all_dict.json', 'r') as fin:
        noise_dict = json.load(fin)
    # NOTE(review): these three settings are only defined in this branch, and the
    # `LC.AstroInferenceModel(...)` call further down passes the same values as literals.
    inc_ecorr = True
    backend = 'backend'
    tnequad = False
else:
    # Any other name is read from the DEMO folder; `rr` selects which pickle file.
    rr = 2
    with open(f'../DEMO/Data/Pickle/{cname}/{rr}.pkl', 'rb') as fin:
        psrs = pickle.load(fin)
    psrlist = [psr.name for psr in psrs]
    # No noise file is read in this branch: every pulsar gets EFAC = 1 and
    # log10_t2equad = -inf.
    for pname in psrlist:
        noise_dict[pname + '_efac'] = 1.0
        noise_dict[pname + '_log10_t2equad'] = -np.inf



# Step 1: Model Construction for GWB

## Frequency-bins

In [6]:
Tspan = get_tspan(psrs)  # The time-span of the entire PTA
crn_bins = 5   # number of frequency-bins for the GWB
int_bins = 30  # number of frequency-bins for the non-GWB (IRN) red noise
assert int_bins > crn_bins, "int_bins must be larger than crn_bins"
# Frequencies k / Tspan for k = 1, ..., int_bins. The grid is built from an integer
# range so that the number of bins is exact by construction; a float-step `arange`
# needs a fudge factor on the end point to be sure the last bin is included.
f_intrin = jnp.arange(1, int_bins + 1) / Tspan  # an array of frequency-bins for the IRN process
f_common = f_intrin[:crn_bins]  # an array of frequency-bins for the common process
# The factor by which the units are going to change (divided by). Set it to `1` for
# no unit change (seconds), or let it be `1e9` (nano seconds) for better performance.
renorm_const = 1e9

## GWB PSD model

### A dictionary is used to store the necessary information about GWB PSD. You can either use the `utils.param_order_help` to make your own dictionary or use one of the pre-made ones. Take a look at `GWBFunctions.py` for a list of supported PSD as well as GWB ORF functions.

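### For example, the cell below chooses a GWB with HD correlations and a free-spectrum PSD (one `halflog10_rho` parameter per GWB frequency-bin). The commented-out lines show the alternatives, e.g. a powerlaw PSD with the spectral index fixed at 13/3.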

In [7]:
# Pick the GWB model. To switch models, comment out the active call and
# uncomment exactly one of the alternatives listed here:
# chosen_psd_model, chosen_orf_model, gwb_helper_dictionary = utils.fixed_gamma_hd_pl(renorm_const=renorm_const)
# chosen_psd_model, chosen_orf_model, gwb_helper_dictionary = utils.varied_gamma_hd_pl(renorm_const=renorm_const)
# chosen_psd_model, chosen_orf_model, gwb_helper_dictionary = utils.varied_gamma_gt_pl(renorm_const=renorm_const)
# chosen_psd_model, chosen_orf_model, gwb_helper_dictionary = utils.varied_gamma_bin_orf_pl(renorm_const=renorm_const)
chosen_psd_model, chosen_orf_model, gwb_helper_dictionary = utils.hd_spectrum(
    renorm_const=renorm_const,
    crn_bins=crn_bins,
)
gwb_helper_dictionary

{'ordered_gwb_psd_model_params': array(['halflog10_rho'], dtype='<U13'),
 'varied_gwb_psd_params': ['halflog10_rho'],
 'gwb_psd_param_lower_lim': Array([-4.5, -4.5, -4.5, -4.5, -4.5], dtype=float64),
 'gwb_psd_param_upper_lim': Array([3.5, 3.5, 3.5, 3.5, 3.5], dtype=float64)}

### Now, construct the model using `models.UniformPrior`

In [8]:
# Print the constructor signature and the parameter documentation of the prior class.
help(models.UniformPrior)

Help on class UniformPrior in module models:

class UniformPrior(builtins.object)
 |  UniformPrior(gwb_psd_func, orf_func, crn_bins, int_bins, f_common, f_intrin, df, psr_pos, Tspan, Npulsars, gwb_helper_dictionary, gamma_min=0, gamma_max=7, log10A_min=Array(-13.5, dtype=float64), log10A_max=Array(-6.5, dtype=float64), renorm_const=1000000000.0)
 |  
 |  A class to take care of prior and the phi-matrix construction based on uniform/log-uniform priors.
 |  
 |  :param gwb_psd_func: a PSD function from the `GWBFunctions` class
 |  :param orf_func: an orf function from the `GWBFunctions` class
 |  :param crn_bins: number of frequency-bins for the GWB
 |  :param int_bins: number of frequency-bins for the non-GWB (IRN) red noise
 |  :param `f_common`: an array of frequency-bins for the common process
 |  :param `f_intrin`: an array of frequency-bins for the IRN process
 |  :param df: the diffence between consecutive frequency-bins. It is usually 1/Tspan
 |  :param psr_pos: an array of pulsa

In [9]:
# Build the prior / phi-matrix object for the chosen GWB model.
# `0.5 * jnp.log10(renorm_const)` is added to both amplitude bounds to account
# for the change in units.
o = models.UniformPrior(
    gwb_psd_func=chosen_psd_model,
    orf_func=chosen_orf_model,
    crn_bins=crn_bins,
    int_bins=int_bins,
    f_common=f_common,
    f_intrin=f_intrin,
    df=1/Tspan,
    psr_pos=[psr.pos for psr in psrs],
    Tspan=Tspan,
    Npulsars=len(psrs),
    gwb_helper_dictionary=gwb_helper_dictionary,
    gamma_min=0,
    gamma_max=7,
    log10A_min=-20 + 0.5 * jnp.log10(renorm_const),
    log10A_max=-11 + 0.5 * jnp.log10(renorm_const),
    renorm_const=renorm_const,
)



In [10]:
# The key is fixed at 100, so this cell uses the same RNG key on every run.
# Note that `x0` is re-assigned later from the full inference model.
x0 = o.make_initial_guess(key = jar.key(100)) # Some random draw from the prior given an RNG key

In [11]:
# Evaluate the phi-matrix at the drawn parameters and display its shape.
phimat = o.get_phi_mat(x0)
phimat.shape

(30, 67, 67)

### Here is the phi-matrix. Note that this matrix is batched by `max(int_bins, crn_bins)` and is stored as lower-triangular, because fully populating phi (a positive definite matrix) would be wasted computation!

In [12]:
# Show the first matrix of the batch.
phimat[0]

Array([[ 1.13757125e+02,  0.00000000e+00,  0.00000000e+00, ...,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00],
       [ 4.38168166e+01,  1.27119917e+04,  0.00000000e+00, ...,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00],
       [ 3.29142536e+01,  5.19604098e+01,  1.13721931e+02, ...,
         0.00000000e+00,  0.00000000e+00,  0.00000000e+00],
       ...,
       [-1.15913598e+01,  1.51887022e+00,  1.03131491e+01, ...,
         3.14141127e+07,  0.00000000e+00,  0.00000000e+00],
       [-1.19664332e+01, -3.05603017e+00,  3.25103762e-01, ...,
         2.37859771e+01,  1.13737831e+02,  0.00000000e+00],
       [-1.25210477e+01, -2.78271737e+00,  1.75695204e+00, ...,
         3.23319078e+01,  5.36935502e+01,  1.11562410e+12]],      dtype=float64)

# Step 2: Likelihood Calculation and Addition of Astro Parameters

In [13]:
# Print the constructor signature and the parameter documentation of the inference class.
help(LC.AstroInferenceModel)

Help on class AstroInferenceModel in module LikelihoodCalculator:

class AstroInferenceModel(builtins.object)
 |  AstroInferenceModel(nf_dist, num_astro_params, astr_prior_lower_lim, astr_prior_upper_lim, astro_additional_prior_func, run_type_object, psrs, TNr=Array([False], dtype=bool), TNT=Array([False], dtype=bool), noise_dict=None, backend='none', tnequad=False, inc_ecorr=False, del_pta_after_init=True, matrix_stabilization=True)
 |  
 |  A class to calculate likelihood based on a given IRN + GWB model (no deterministic signal) as well as as a normalizing flow astroemulator
 |  
 |  :param nf_dist: the normalizing flow object
 |  :param num_astro_params: the number of astro parameters 
 |  :param astr_prior_lower_lim: the lower bound for the astro parameters
 |  :param astr_prior_upper_lim: the upper bound for the astro parameters
 |  :param astro_additional_prior_func: if you want non-uniform prior on the astro parameters, supply a numpy-compatible function to calculate
 |        

### Making the AstroInference object is easy. In addition to your normalizing flow object, you need to define your prior function for the astro-parameters. For example:

In [14]:
def astro_prior_func_uniform(xs):
    '''
    Additional log-prior term for the case of a purely uniform prior on the astro parameters.

    The uniform prior is already included in the Bayesian inference object, so this
    function adds nothing: it ignores its input and always returns zero.

    :param xs: an array of length equal to the number of astro params (unused)
    :return: the constant `0`
    '''
    return 0

# Mean and standard deviation of the normal prior, one entry per astro parameter.
# They are deliberately NOT called `mean`/`std`: a later cell binds the name `mean`
# to a value loaded together with the normalizing flow, which would silently
# replace the prior mean.
astro_prior_mean = None  # e.g. np.array([...])
astro_prior_std = None   # e.g. np.array([...])

def astro_additional_prior_func_normal(xs, mean=None, std=None):
    '''
    this function is for the inclusion of the normal priors. You can change it to a multi-variate normal if needed.
    just choose your mean and std for the astro parameters. BE CAREFUL! The order of `xs` is the same order as your
    trained normalizing flow object.

    The independent per-parameter log-densities are summed over the last axis, so the
    result is the joint log-prior: a scalar for a 1-D `xs`, or one value per row when
    `xs` carries leading batch dimensions.

    :param xs: an array of length equal to the number of astro params
    :param mean: prior mean per astro parameter; defaults to the module-level `astro_prior_mean`
    :param std: prior standard deviation per astro parameter; defaults to the module-level `astro_prior_std`
    :raises ValueError: if the mean or the standard deviation has not been provided
    '''
    loc = astro_prior_mean if mean is None else mean
    scale = astro_prior_std if std is None else std
    if loc is None or scale is None:
        raise ValueError(
            'Set `astro_prior_mean` and `astro_prior_std` (or pass `mean`/`std`) '
            'before using the normal prior.')
    return np.sum(norm.logpdf(xs, loc=loc, scale=scale), axis=-1)

In [15]:
# Uniform-prior bounds of the astro parameters: one row per parameter,
# column 0 is the lower bound and column 1 is the upper bound.
astro_prior_bounds = np.array([[0.1, 11.0],
          [-3.5, -1.5],
          [10.5, 12.5],
          [7.6, 9.0],
          [0.0, 0.9],
          [-1.5, 0.0]])

# One random point inside the bounds, with a leading batch axis of length 1.
# A seeded generator is used so that "Restart & Run All" reproduces the same draw.
ASTRO_X0_SEED = 0
astro_rng = np.random.default_rng(ASTRO_X0_SEED)
astro_x0 = astro_rng.uniform(
            astro_prior_bounds[:, 0],
            astro_prior_bounds[:, 1])[None]
astro_x0, astro_x0.shape

(array([[ 7.55645124, -1.55699849, 12.13215913,  8.49829871,  0.09407345,
         -1.30706308]]),
 (1, 6))

### Load your normalizing flow object

In [16]:
# Load the trained normalizing flow together with the values saved alongside it
# (six objects in total; the last two are not used in this notebook), mapped to the CPU.
# `weights_only=False` is stated explicitly: the file holds full Python objects, not
# just tensors, and newer PyTorch releases default to `weights_only=True`.
# SECURITY: this unpickles arbitrary objects and can execute code; only load
# checkpoints that come from a source you trust.
# NOTE: this binds the global name `mean` to the value stored in the checkpoint.
nf, half_range, B, mean, _, _ = torch.load('../Data/AstroData/PaperFinal/chosenone/lr0.0001_bs1000_decay0/condflow_at_40000.pkl',
                                            map_location = 'cpu',
                                            weights_only = False)

  nf, half_range, B, mean, _, _ = torch.load('../Data/AstroData/PaperFinal/chosenone/lr0.0001_bs1000_decay0/condflow_at_40000.pkl',


### Specify the number of astro parameters

In [20]:
# Number of astro parameters. `astro_prior_bounds` above has six rows, one per parameter;
# presumably this must also match the trained normalizing flow -- verify.
n_astro_params = 6

In [21]:
# Wrap the loaded flow for use in the likelihood. The index arrays place the
# astro parameters at positions 0 .. n_astro_params - 1 and the GWB frequency-bins
# at the `crn_bins` positions that follow them.
nf_dist_object = nf_dist.NFastroinference(
    pyro_nf_object=nf,
    mean=np.array(mean),
    half_range=np.array(half_range),
    scale=B,
    ast_param_idxs=np.arange(n_astro_params, dtype=int),
    gwb_freq_idxs=np.arange(n_astro_params, n_astro_params + crn_bins, dtype=int),
)

### Make the AstroInference object

In [22]:
# Assemble the full inference model: normalizing flow + astro prior + GWB/IRN model + pulsar data.
# Keyword arguments are listed in the order of the constructor signature.
m = LC.AstroInferenceModel(
    nf_dist=nf_dist_object,
    num_astro_params=n_astro_params,
    astr_prior_lower_lim=astro_prior_bounds[:, 0],
    astr_prior_upper_lim=astro_prior_bounds[:, 1],
    astro_additional_prior_func=astro_prior_func_uniform,
    run_type_object=o,
    psrs=psrs,
    TNr=jnp.array([False]),
    TNT=jnp.array([False]),
    noise_dict=noise_dict,
    backend='backend',
    tnequad=False,
    inc_ecorr=True,
    del_pta_after_init=True,
    matrix_stabilization=True,
)

Condition number of the TNT matrix before stabilizing is: 5.5744446445107334e+20
Condition number of the TNT matrix after stabilizing is: 6.926691275258731e+14


### It is VERY important that the initial guess of the model parameters `x0` leads to a finite likelihood! Keep generating new `x0` until the likelihood is finite


In [28]:
# Keep drawing starting points until one gives a finite log-likelihood, instead of
# failing on the first unlucky draw. The number of attempts is bounded so that the
# cell cannot loop forever.
# NOTE(review): assumes `m.make_initial_guess()` returns a new random point on every
# call -- confirm; otherwise every attempt evaluates the same point.
MAX_INITIAL_GUESS_TRIES = 100
for _attempt in range(MAX_INITIAL_GUESS_TRIES):
    x0 = m.make_initial_guess()
    ans = m.get_lnliklihood(x0)
    if np.all(np.isfinite(ans)):
        break
else:
    raise RuntimeError(
        f'No initial guess with a finite likelihood was found in {MAX_INITIAL_GUESS_TRIES} tries.')
ans

array([35317.62], dtype=float32)

In [29]:
%timeit m.get_lnliklihood(x0)

58.2 ms ± 5.21 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


### Now perform sampling... 

In [None]:
# Run the sampler for one million iterations starting from `x0`, using '../testnew/'
# as the save directory; `resume=True` is passed through to the sampler.
m.sample(
    x0=x0,
    niter=1_000_000,
    savedir='../testnew/',
    resume=True,
)