In [1]:
import jax
import jax.numpy as jnp
import numpyro
from numpyro.infer import NUTS, MCMC, Predictive
from numpyro import handlers
from numpyro.diagnostics import summary, autocorrelation
import numpyro.distributions as dist
from src.models.vsh_model import*
from src.models.configuration import*
from src.data.data_utils import*
from src.plot.plots import*
from src.save_load_pkl.save_load import*

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def chi2_jit(angles, obs, error, theta, lmax):
    """Chi-square of the VSH coefficient vector ``theta`` against the data.

    Thin wrapper around ``least_square(..., grid=False)`` so it can be
    JIT-compiled with ``lmax`` marked static. ``lmax`` presumably sets array
    sizes inside ``least_square``, so it must be a Python int, not a tracer;
    the meaning of ``grid=False`` is defined in ``least_square`` (not visible
    here).

    Parameters
    ----------
    angles, obs, error
        Data arrays forwarded unchanged to ``least_square``.
    theta : array
        VSH coefficients being evaluated.
    lmax : int
        Maximum VSH degree (static argument of the compiled function).

    Returns
    -------
    Whatever ``least_square`` returns; the HMC model uses it as chi^2 in
    ``-0.5 * chi^2``.
    """
    return least_square(angles, obs, error, theta, lmax=lmax, grid=False)


# Use ``jax.jit`` explicitly: a bare ``jit`` is not imported anywhere in the
# imports cell and only resolved via a wildcard ``import *``.
chi2_jit = jax.jit(chi2_jit, static_argnames=['lmax'])


def model_for_HMC(angles, obs, error, lmax):
    """NumPyro model: independent N(0, 1) priors on all VSH coefficients plus
    a chi-square (Gaussian-residual) likelihood.

    Parameters
    ----------
    angles, obs, error
        Data arrays forwarded unchanged to ``chi2_jit``.
    lmax : int
        Maximum VSH degree; sets the number of coefficients via
        ``count_vsh_coeffs`` and is passed to ``chi2_jit``.

    Notes
    -----
    Declares a single sample site ``"theta"`` and adds ``-chi^2 / 2`` to the
    joint log-density through ``numpyro.factor("likelihood", ...)``.
    """
    n_coeffs = count_vsh_coeffs(lmax)

    # One standard-normal prior per coefficient (toroidal and spheroidal alike).
    unit_normal = dist.Normal(0.0, 1.0)
    theta = numpyro.sample("theta", unit_normal.expand([n_coeffs]))

    # Gaussian residuals => log-likelihood = -chi^2 / 2 (up to a constant).
    log_likelihood = -0.5 * chi2_jit(angles, obs, error, theta, lmax=lmax)
    numpyro.factor("likelihood", log_likelihood)

# MCMC run configuration (all counts are per chain).
n_chains = 3    # number of chains
n_warmup = 200  # warmup (adaptation) iterations
n_s = 1000      # post-warmup samples

In [3]:
# Load the filtered catalogue, then split it into the three arrays the model
# consumes (their layout is defined by config_data, not visible here).
df = load_filtered_qso_df()
(angles, obs, error) = config_data(df)

In [4]:
LMAX = 3  # maximum VSH degree fitted in this run

# Positions of s_10, s_11r and s_11i inside ``theta``.
# NOTE(review): assumes the coefficient ordering used by the VSH model code — TODO confirm.
IDX_S10, IDX_S11R, IDX_S11I = 1, 4, 5

rng_key = jax.random.key(0)

# target_accept_prob is the acceptance rate that step-size adaptation aims for
# during warmup; it is a target, not an upper bound (this seed realised 0.82-0.87).
kernel_norm = NUTS(model_for_HMC, target_accept_prob=0.75)
mcmc_norm = MCMC(
    kernel_norm,
    num_warmup=n_warmup,
    num_samples=n_s,
    num_chains=n_chains,
    chain_method='sequential',
    progress_bar=True,
)
mcmc_norm.run(rng_key, angles=angles, obs=obs, error=error, lmax=LMAX)
posterior_sample_norm = mcmc_norm.get_samples()

# summary() works on samples with a leading chain axis (needed for r_hat / n_eff).
diagnostic_norm = summary(mcmc_norm.get_samples(group_by_chain=True))

r_hat_norm = diagnostic_norm['theta']['r_hat']
n_eff_norm = diagnostic_norm['theta']['n_eff']
iat = estimate_iat(n_s, n_chains, n_eff_norm)

print('Coefficients of interest, s_10, s_11r and s_11i')
print(f'Their respective r_hat values are: {r_hat_norm[IDX_S10]}, {r_hat_norm[IDX_S11R]} and {r_hat_norm[IDX_S11I]}')
print(f'Their respective effective sample sizes are: {n_eff_norm[IDX_S10]}, {n_eff_norm[IDX_S11R]} and {n_eff_norm[IDX_S11I]}')

mcmc_norm.print_summary()

sample: 100%|██████████| 1200/1200 [01:25<00:00, 14.00it/s, 7 steps of size 3.58e-02. acc. prob=0.85]  
sample: 100%|██████████| 1200/1200 [00:37<00:00, 31.63it/s, 7 steps of size 3.90e-02. acc. prob=0.82] 
sample: 100%|██████████| 1200/1200 [00:43<00:00, 27.81it/s, 15 steps of size 3.40e-02. acc. prob=0.87] 


Coefficients of interest, s_10, s_11r and s_11i
Their respsecive r_hat values are: 1.0000629425048828, 0.9995251297950745 and 0.999615490436554
Their respective effective sample size are : 1987.2052544820829, 2642.5333554772883 and 3845.410651009035

                mean       std    median      5.0%     95.0%     n_eff     r_hat
  theta[0]      0.00      0.00      0.00     -0.00      0.00   1380.22      1.00
  theta[1]     -0.01      0.00     -0.01     -0.01     -0.01   1987.21      1.00
  theta[2]     -0.01      0.00     -0.01     -0.01     -0.00   3124.36      1.00
  theta[3]      0.00      0.00      0.00      0.00      0.00   3976.41      1.00
  theta[4]      0.00      0.00      0.00     -0.00      0.00   2642.53      1.00
  theta[5]     -0.01      0.00     -0.01     -0.01     -0.01   3845.41      1.00
  theta[6]      0.00      0.00      0.00      0.00      0.01   1953.43      1.00
  theta[7]     -0.00      0.00     -0.00     -0.01     -0.00   2103.70      1.00
  theta[8]      0.00

In [5]:
# Value returned by estimate_iat in the sampling cell; a bare last expression
# lets Jupyter render it via rich display instead of print().
iat

3


In [7]:
# Same table as printed at the end of the sampling cell (90% interval -> 5%/95% columns).
mcmc_norm.print_summary(prob=0.9)


                mean       std    median      5.0%     95.0%     n_eff     r_hat
  theta[0]      0.00      0.00      0.00     -0.00      0.00   1380.22      1.00
  theta[1]     -0.01      0.00     -0.01     -0.01     -0.01   1987.21      1.00
  theta[2]     -0.01      0.00     -0.01     -0.01     -0.00   3124.36      1.00
  theta[3]      0.00      0.00      0.00      0.00      0.00   3976.41      1.00
  theta[4]      0.00      0.00      0.00     -0.00      0.00   2642.53      1.00
  theta[5]     -0.01      0.00     -0.01     -0.01     -0.01   3845.41      1.00
  theta[6]      0.00      0.00      0.00      0.00      0.01   1953.43      1.00
  theta[7]     -0.00      0.00     -0.00     -0.01     -0.00   2103.70      1.00
  theta[8]      0.00      0.00      0.00      0.00      0.00   4341.47      1.00
  theta[9]      0.00      0.00      0.00     -0.00      0.00   3570.90      1.00
 theta[10]      0.00      0.00      0.00      0.00      0.01   4609.03      1.00
 theta[11]      0.01      0

In [None]:
# Persist the whole MCMC object under the name 'test'.
# NOTE(review): in the saved notebook this cell was never executed (In [None]),
# yet the next cell loads 'test' — it must run first on a fresh kernel.
save_pickle('test', mcmc_norm)


In [9]:
# Reload the object saved under the name 'test' (written by save_pickle('test', ...)).
# NOTE(review): presumably unpickles — only load files this project wrote itself.
# `mcmc` is not used again in the visible cells.
mcmc = load_pickle('test')

In [17]:
import os

# This cell sits before the cell that writes test2.pkl, so on a fresh
# Restart & Run All the file may not exist yet; report None instead of
# raising FileNotFoundError.
posterior_pkl_path = 'hmc_samples/posterior_samples/test2.pkl'
os.path.getsize(posterior_pkl_path) if os.path.exists(posterior_pkl_path) else None

360232

In [18]:
# Persist the chain-merged posterior draws (get_samples()) and the per-site
# diagnostics dict from summary() into separate directories.
posterior_dir = 'hmc_samples/posterior_samples'
diagnostic_dir = 'hmc_samples/diagnostic_hmc'
save_pickle('test2', posterior_sample_norm, dir=posterior_dir)
save_pickle('diag', diagnostic_norm, dir=diagnostic_dir)

In [19]:
# Size in bytes of the saved diagnostics pickle (`os` was imported in an earlier cell).
os.stat('hmc_samples/diagnostic_hmc/diag.pkl').st_size

1421