# Extensions
This notebook runs the extension experiments, which include:
-	Power spectrum plots
-	Inference for different clipping values

First, we run HMC for different values of $l_{\max}$ on the full dataset, just as we did to produce Figure 7 in the original paper. This time, the goal is to save the samples and their diagnostics so that later extensions can reuse them without re-running the sampler.


In [6]:
import jax
import jax.numpy as jnp
import numpyro
from numpyro.infer import NUTS, MCMC, Predictive
from numpyro import handlers
from numpyro.diagnostics import summary, autocorrelation
import numpyro.distributions as dist
from src.models.vsh_model import*
from src.models.configuration import*
from src.data.data_utils import*
from src.plot.plots import*
from src.save_load_pkl.save_load import*
import gc

In [3]:
df = load_filtered_qso_df() # load the filtered QSO table through the project helper
# Split the table into the three inputs that every sampler run below receives.
# NOTE(review): presumably sky positions, observed values and their uncertainties — confirm against config_data.
angles, obs, error = config_data(df)

In [None]:
def chi2_jit(angles, obs, error, theta, lmax):
    """Evaluate the least-squares objective for one coefficient vector.

    Thin wrapper around ``least_square`` that always passes ``grid=False``.
    It exists so the objective can be JIT-compiled with ``lmax`` treated as a
    static (compile-time) argument.

    Args:
        angles, obs, error: Data inputs, forwarded unchanged to ``least_square``.
        theta: Coefficient vector at which the objective is evaluated.
        lmax: Forwarded to ``least_square`` as ``lmax``; must be hashable
            because it is declared static for ``jax.jit``.

    Returns:
        Whatever ``least_square`` returns (used as the chi-square value by
        ``model_for_HMC``).
    """
    return least_square(angles, obs, error, theta, lmax=lmax, grid=False)


# Use the explicitly imported `jax` namespace: a bare `jit` is only defined if
# one of the wildcard imports happens to re-export it.
# `lmax` is static, so each distinct value triggers its own compilation.
chi2_jit = jax.jit(chi2_jit, static_argnames=['lmax'])


def model_for_HMC(angles, obs, error, lmax):
    """NumPyro model: standard-normal prior on the coefficients plus a chi-square term.

    Args:
        angles, obs, error: Data inputs, forwarded unchanged to ``chi2_jit``.
        lmax: Determines the number of coefficients via ``count_vsh_coeffs``
            and is forwarded to ``chi2_jit``.

    The model registers one sample site, ``"theta"``, and one factor site,
    ``"likelihood"``; it returns nothing.
    """
    n_coeffs = count_vsh_coeffs(lmax)

    # N(0, 1) prior over all VSH coefficients (toroidal and spheroidal alike).
    coeff_prior = dist.Normal(0.0, 1.0).expand([n_coeffs])
    theta = numpyro.sample("theta", coeff_prior)

    # Residuals are assumed Gaussian, so the log-likelihood is -chi^2 / 2
    # up to an additive constant.
    log_like = -0.5*chi2_jit(angles, obs, error, theta, lmax=lmax)
    numpyro.factor("likelihood", log_like)

# Sampler budget shared by the runs in this notebook.
n_s = 5_000       # posterior draws kept (passed to MCMC as num_samples)
n_warmup = 2_000  # warm-up iterations (passed to MCMC as num_warmup)
n_chains = 6      # number of chains

In [None]:
# The same key is passed to every run, so each lmax run is reproducible on its own.
rng_key = jax.random.key(0)

# target_accept_prob is the acceptance rate that step-size adaptation aims for;
# it is a target, not a hard upper bound on the realised acceptance probability.
kernel = NUTS(model_for_HMC, target_accept_prob=0.75)

posterior_samples = []  # one dict of posterior draws per lmax, in run order
iat_values = []         # one autocorrelation-length estimate per lmax
for l in range(1, 8):
    print(f'l = {l}')
    # Run the sampling algorithm (HMC/NUTS); chains are run one after another.
    mcmc = MCMC(kernel, num_warmup=n_warmup, num_samples=n_s, num_chains=n_chains, chain_method='sequential', progress_bar=True)
    mcmc.run(rng_key, angles = angles, obs = obs, error = error, lmax=l)
    ps = mcmc.get_samples()
    posterior_samples.append(ps)

    # summary() works on draws that keep the leading chain dimension.
    diagnostics = summary(mcmc.get_samples(group_by_chain=True))
    n_eff = diagnostics['theta']['n_eff']
    iat = estimate_iat(n_s, n_chains, n_eff, index=[1,4,5])
    iat_values.append(iat)
    print(f'Autocorrelation length estimate: {iat}')

    r_hats = diagnostics['theta']['r_hat']
    # Array method instead of np.sum(...) / len(...): no reliance on `np`
    # being leaked into the namespace by a wildcard import.
    avg_r_hat = r_hats.mean()
    print("Average r_hat:", avg_r_hat)

    divergences = mcmc.get_extra_fields()["diverging"]  # shape: (num_samples * num_chains,)
    num_divergences = divergences.sum()
    print("Number of divergences:", num_divergences)

    # == Save results ==
    # Store only this run's samples, wrapped in a one-element list so that
    # `load_pickle(f'lmax_{l}', ...)[0]` returns the draws for *this* lmax.
    # Pickling the cumulative `posterior_samples` list instead would make
    # element 0 of every file the lmax = 1 draws and grow each file with l.
    save_pickle(f'lmax_{l}', [ps], dir = 'hmc_samples/posterior_samples')
    save_pickle(f'lmax_{l}', diagnostics, dir = 'hmc_samples/diagnostic_hmc')

    # Free memory after each iteration
    del mcmc
    gc.collect()
    jax.clear_caches()

l = 1


sample: 100%|██████████| 1500/1500 [00:33<00:00, 44.50it/s, 3 steps of size 1.08e-01. acc. prob=0.85] 
sample: 100%|██████████| 1500/1500 [00:25<00:00, 58.48it/s, 3 steps of size 8.72e-02. acc. prob=0.90]


Autocorrelation length estimate: 2
Average r_hat: 1.0005512
Number of divergences: 0


In [21]:
# Reload the lmax = 1 run from disk. The samples pickle holds a list; element 0
# is the dict of posterior draws for this run (indexed by 'theta' further down).
posterior_sample_norm = load_pickle('lmax_1', dir = 'hmc_samples/posterior_samples')[0]
diagnostics = load_pickle('lmax_1', dir = 'hmc_samples/diagnostic_hmc')
# Recompute the autocorrelation-length estimate from the stored effective sample sizes.
# n_s and n_chains come from the settings cell and must match the values used for the saved run.
n_eff = diagnostics['theta']['n_eff']
iat = estimate_iat(n_s, n_chains, n_eff, index=[1,4,5])
print(iat)

2


In [22]:
# Covariance of the theta components at indices 1, 4 and 5, estimated from
# draws thinned by the autocorrelation length `iat`.
cov_m = cov_matrix_hmc(
    posterior_sample_norm['theta'][::iat],
    indices=[1,4,5],
)

# Posterior mean of every coefficient (all draws, no thinning), then the same
# three components picked out as the point estimate.
result_uni = jnp.mean(posterior_sample_norm['theta'], axis=0)
params = [result_uni[idx] for idx in (1, 4, 5)]

# Vector summary from the point estimate and covariance, then its galactic
# counterpart derived from the resulting vector and covariance.
summary_norm, v_vec, v_Sigma, _ = vsh_vector_summary(params, cov_m)
summary_norm_gal, v_vec_gal, v_Sigma_gal, _ = vsh_vector_summary_galactic(v_vec, v_Sigma)

# Direction summaries: (l, b) from the galactic vector, (RA, Dec) from the other one.
lb_summary_ = lb_summary(v_vec_gal, v_Sigma_gal)
ra_dec_summary_ = ra_dec_summary(v_vec, v_Sigma)

In [23]:
# Equatorial block: vector summary followed by the RA/Dec direction summary.
print_summary(summary_norm, title='Result, Equatorial Coordinates')
print_summary(ra_dec_summary_)
print('')  # blank line between the two coordinate systems
# Galactic block: vector summary followed by the (l, b) direction summary.
print_summary(summary_norm_gal, title='Results, Galactic Coordinates')
print_summary(lb_summary_)

Result, Equatorial Coordinates
------------------------------
  |g| (μas/yr)        : 5.4654
  g (μas/yr)          : [ 0.02362784 -5.229347   -1.5886097 ]
  |sigma_g| (μas/yr)  : 0.3019
  sigma_g (μas/yr)    : [0.39327432 0.31327942 0.1983548 ]
  Corr_gx_gy          : -0.0731
  Corr_gx_gz          : -0.0149
  Corr_gy_gz          : -0.0609
  RA (deg)            : 270.2589
  Sigma_RA (deg)      : 4.3077
  Dec (deg)           : -16.8980
  Sigma_Dec (deg)     : 2.2590
  Corr_RA_dec         : 0.0078

Results, Galactic Coordinates
-----------------------------
  |g|_gal (μas/yr)    : 5.4654
  g_gal (μas/yr)      : [5.33483435 1.1511802  0.29092886]
  |sigma_g_gal| (μas/yr): 0.3019
  sigma_g_gal (μas/yr): [0.28359563 0.29112217 0.35633511]
  Corr_g_galx_g_galy  : 0.3072
  Corr_g_galx_g_galz  : 0.0933
  Corr_g_galy_g_galz  : -0.4341
  l (deg)             : 12.1769
  Sigma_l (deg)       : 2.8578
  b (deg)             : 3.0514
  Sigma_b (deg)       : 3.7346
  Corr_l_b            : -0.0085
