In [1]:
import jax
import jax.numpy as jnp
import jax.scipy as jsp
import jax.random as jar
# jax.config.update('jax_platform_name', 'cuda')
import numpy as np
import scipy.linalg as sl
# Enable 64-bit floats here, *before* `pandora` is imported: JAX expects this flag to be
# set at startup, before any arrays are created, and pandora may build arrays at import
# time (unverified). The outputs further down report dtype=float64, confirming it took effect.
jax.config.update("jax_enable_x64", True)

from pandora import models, utils, GWBFunctions
from pandora import LikelihoodCalculator as LC

# NOTE(review): several of these modules (e.g. pickle, json, glob, corner) are not used in
# the cells visible here; prune them if the rest of the notebook does not need them.
import pickle, json, os, corner, glob, random, copy, time, inspect

import matplotlib.pyplot as plt
from matplotlib.ticker import AutoMinorLocator
import matplotlib.lines as mlines
plt.style.use('dark_background')
# Shared histogram keyword arguments (presumably splatted into `plt.hist(..., **hist_settings)`):
# 40 unfilled step-style bins, thick lines, normalised to a probability density.
hist_settings = {
    'bins': 40,
    'histtype': 'step',
    'lw': 3,
    'density': True,
}
# IPython display configuration: inline figures rendered at high (retina) resolution.
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
# Automatically reload edited modules (e.g. a locally developed pandora) before each cell executes.
%load_ext autoreload
%autoreload 2
# %load_ext line_profiler



Optional mpi4py package is not installed.  MPI support is not available.


In [None]:
# Directory that contains a `kde/` sub-folder holding `density.npy` and `log10rhogrid.npy`.
# `...` is a placeholder: replace it with a real directory (str or pathlib.Path) before running.
path = ...
if path is Ellipsis:
    # Fail with an actionable message instead of the opaque `Ellipsis + str` TypeError.
    raise ValueError(
        "Set `path` to the directory containing kde/density.npy and kde/log10rhogrid.npy"
    )
kde_dir = os.path.join(path, 'kde')
# No `mmap_mode`: jnp.load converts the result of np.load into a JAX array, so a
# memory-mapped file would be copied into memory in full anyway.
den = jnp.load(os.path.join(kde_dir, 'density.npy'))
grid = jnp.load(os.path.join(kde_dir, 'log10rhogrid.npy'))

In [3]:
# GWB PSD and ORF models with log10_A and gamma both varied; the returned helper
# dictionary echoes the parameter names and their lower/upper limits.
chosen_psd_model, chosen_orf_model, gwb_helper_dictionary = utils.varied_gamma_hd_pl(
    renorm_const=1,
    lower_amp=-18.0,
    upper_amp=-11.0,
    lower_gamma=0.0,
    upper_gamma=7.0,
)
gwb_helper_dictionary

{'ordered_gwb_psd_model_params': array(['log10_A', 'gamma'], dtype='<U7'),
 'varied_gwb_psd_params': [np.str_('log10_A'), np.str_('gamma')],
 'gwb_psd_param_lower_lim': Array([-18.,   0.], dtype=float64),
 'gwb_psd_param_upper_lim': Array([-11.,   7.], dtype=float64)}

In [4]:
Tspan = 20 * 365.25 * 86400  # time-span of the entire PTA: 20 years of 365.25 days, in seconds
crn_bins = 5  # number of frequency-bins for the GWB (common process)
int_bins = 5  # number of frequency-bins for the non-GWB (IRN) red noise
assert int_bins >= crn_bins, "the common-process bins must be a subset of the IRN bins"
# Harmonics k / Tspan for k = 1..int_bins. Building the grid from an integer range avoids a
# float-step `arange` (and the `+0.01` stop fudge it needed), whose length can be off by one
# because of floating-point rounding.
f_intrin = jnp.arange(1, int_bins + 1) / Tspan  # frequency-bins for the IRN process
f_common = f_intrin[:crn_bins]  # the lowest crn_bins of those, used for the common process

In [5]:
# Run-configuration constants, named once so they cannot drift apart.
N_PULSARS = 45
GAMMA_MIN, GAMMA_MAX = 0, 7
# NOTE(review): these log10_A bounds evaluate to [-19.5, -10.5], whereas the helper
# dictionary built by `utils.varied_gamma_hd_pl` uses [-18, -11]. Confirm that both the
# +0.5 offset and the different range are intentional.
LOG10A_MIN = -20 + 0.5
LOG10A_MAX = -11 + 0.5

o = models.UniformPrior(
    gwb_psd_func=chosen_psd_model,
    orf_func=chosen_orf_model,
    crn_bins=crn_bins,
    int_bins=int_bins,
    f_common=f_common,
    f_intrin=f_intrin,
    df=1 / Tspan,
    Tspan=Tspan,
    Npulsars=N_PULSARS,
    # NOTE(review): placeholder positions (one zero per pulsar); presumably not used by the
    # KDE likelihood -- confirm before reusing this object with a position-dependent ORF.
    psr_pos=jnp.zeros(N_PULSARS),
    gwb_helper_dictionary=gwb_helper_dictionary,
    gamma_min=GAMMA_MIN,
    gamma_max=GAMMA_MAX,
    log10A_min=LOG10A_MIN,
    log10A_max=LOG10A_MAX,
    renorm_const=1,
)

In [6]:
# KDE-based likelihood built from the loaded density/grid and the UniformPrior run object.
# NOTE(review): the device is hard-coded to 'cuda'; this presumably fails on a CPU-only machine.
m = LC.KDE(
    den=den,
    grid=grid,
    run_type_object=o,
    device_to_run_likelihood_on='cuda',
)

In [7]:
x0 = o.make_initial_guess(key=jar.key(100))  # fixed PRNG seed (100) for the starting point

In [8]:
# Log-likelihood at the initial guess (the pandora method is spelled `get_lnliklihood`).
m.get_lnliklihood(x0)

Array(-4485.70895428, dtype=float64)

In [9]:
%timeit m.get_lnliklihood(x0)

41.9 μs ± 1.02 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
