# Extensions
This notebook runs the extension experiments, which include:
- Power spectrum plots
- Inference for different clipping values

First, we run HMC for different values of $l_{max}$ on the full dataset, as we did to produce Figure 7 in the original paper. This time, the goal is to save the samples and their diagnostics so that they can be reused for further extensions.
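The dimension of `theta` grows quickly with $l_{max}$, which explains why the later runs are much slower. The sketch below is a minimal illustration that assumes `count_vsh_coeffs` follows the usual real-coefficient convention: for each degree $l$ there are $2l+1$ real toroidal and $2l+1$ real spheroidal coefficients, giving $2\sum_{l=1}^{l_{max}}(2l+1) = 2\,l_{max}(l_{max}+2)$ in total. The helper `n_vsh_params` is hypothetical; check it against `count_vsh_coeffs` before relying on it.

```python
def n_vsh_params(lmax: int) -> int:
    """Hypothetical stand-in for count_vsh_coeffs (assumed convention).

    For each degree l = 1..lmax there are (2l + 1) real toroidal and
    (2l + 1) real spheroidal coefficients.
    """
    return sum(2 * (2 * l + 1) for l in range(1, lmax + 1))


# The sum has the closed form 2 * lmax * (lmax + 2)
for lmax in range(1, 8):
    assert n_vsh_params(lmax) == 2 * lmax * (lmax + 2)
    print(f"lmax = {lmax}: {n_vsh_params(lmax)} coefficients")
```

Under this convention $l_{max}=1$ gives 6 coefficients, which is consistent with the notebook later indexing entries 1, 4 and 5 of `theta`.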


In [2]:
import jax
import jax.numpy as jnp
import numpy as np
from jax import jit
import numpyro
from numpyro.infer import NUTS, MCMC, Predictive
from numpyro import handlers
from numpyro.diagnostics import summary, autocorrelation
import numpyro.distributions as dist
from src.models.vsh_model import *
from src.models.configuration import *
from src.data.data_utils import *
from src.save_load_pkl.save_load import *
import gc

In [3]:
df = load_filtered_qso_df() # load filtered data
angles, obs, error = config_data(df)

In [4]:
def chi2_jit(angles, obs, error, theta, lmax):
    return least_square(angles, obs, error, theta, lmax=lmax, grid=False)
chi2_jit = jit(chi2_jit, static_argnames=['lmax'])


def model_for_HMC(angles, obs, error, lmax):
    total_params = count_vsh_coeffs(lmax)
    
    # Prior on all VSH coefficients (both toroidal and spheroidal)
    theta = numpyro.sample("theta", dist.Normal(0.0, 1.0).expand([total_params]))
    # Least-squares residuals: we assume Gaussian-distributed residuals
    chi2_val = chi2_jit(angles, obs, error, theta, lmax=lmax)

    # Up to an additive constant, the Gaussian log-likelihood is -0.5*chi^2
    numpyro.factor("likelihood", -0.5*chi2_val)

n_s = 5000  # number of posterior samples per chain
n_warmup = 2000  # number of warmup steps per chain
n_chains = 6  # number of chains
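The `numpyro.factor` term adds $-\tfrac{1}{2}\chi^2$ to the log density. For independent Gaussian residuals, this differs from the full Gaussian log-likelihood only by a constant that depends on the errors and not on `theta`, so HMC targets the same posterior. The numpy check below illustrates this on toy arrays (`model`, `obs`, `err` are made up, not the Gaia data), assuming an uncorrelated $\chi^2=\sum_i (r_i/\sigma_i)^2$; if `least_square` also includes the RA/Dec proper-motion correlation the constant changes, but the argument is the same.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy data (illustrative only): true prediction, 1-sigma errors, noisy observations
n = 1000
model = rng.normal(size=n)
err = rng.uniform(0.5, 2.0, size=n)
obs = model + err * rng.normal(size=n)


def gaussian_loglik(obs, pred, err):
    """Full log-likelihood of independent Gaussian residuals."""
    return np.sum(-0.5 * ((obs - pred) / err) ** 2
                  - np.log(err) - 0.5 * np.log(2 * np.pi))


def chi2(obs, pred, err):
    return np.sum(((obs - pred) / err) ** 2)


pred_a = model          # one parameter value
pred_b = model + 0.1    # another parameter value

# Differences between parameter values are identical, so the constant drops out
d_loglik = gaussian_loglik(obs, pred_a, err) - gaussian_loglik(obs, pred_b, err)
d_factor = -0.5 * chi2(obs, pred_a, err) + 0.5 * chi2(obs, pred_b, err)
print(np.isclose(d_loglik, d_factor))
```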

In [5]:
rng_key = jax.random.key(0)

kernel = NUTS(model_for_HMC, target_accept_prob=0.75)  # below numpyro's default of 0.8, so step-size adaptation picks larger steps; intended to keep acceptance below 90%

posterior_samples = [] # collect posterior samples based on l
iat_values = []
for l in range(1, 8):
    print(f'l = {l}')
    # Run the sampling algorithm (HMC)
    mcmc = MCMC(kernel, num_warmup=n_warmup, num_samples=n_s, num_chains=n_chains, chain_method='sequential', progress_bar=True)
    mcmc.run(rng_key, angles=angles, obs=obs, error=error, lmax=l)
    ps = mcmc.get_samples()
    posterior_samples.append(ps)
    
    diagnostics = summary(mcmc.get_samples(group_by_chain=True))
    n_eff = diagnostics['theta']['n_eff']
    iat = estimate_iat(n_s, n_chains, n_eff, index=[1,4,5])
    iat_values.append(iat)
    print(f'Autocorrelation length estimate: {iat}')

    r_hats = diagnostics['theta']['r_hat']
    avg_r_hat = np.mean(r_hats)
    print("Average r_hat:", avg_r_hat)

    divergences = mcmc.get_extra_fields()["diverging"]  # shape: (num_samples * num_chains,)
    num_divergences = divergences.sum()
    print("Number of divergences:", num_divergences)

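    # == Save results ==
    # Note: `posterior_samples` is the cumulative list for l = 1..l, so the file
    # 'lmax_{l}' holds l entries and index [l - 1] selects the current run
    save_pickle(f'lmax_{l}', posterior_samples, dir='hmc_samples/posterior_samples')
    save_pickle(f'lmax_{l}', diagnostics, dir='hmc_samples/diagnostic_hmc')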
    
    # Free memory after each iteration
    del mcmc
    gc.collect()
    jax.clear_caches()

l = 1


  0%|          | 0/7000 [00:00<?, ?it/s]2025-06-21 15:20:49.040540: E external/xla/xla/service/slow_operation_alarm.cc:73] Constant folding an instruction is taking > 1s:

  %multiply.1986 = f32[1212154]{0} multiply(%constant.2838, %broadcast.1601), metadata={op_name="jit(_body_fn)/jit(main)/while/body/while/body/jvp(jit(chi2_jit))/jit(least_square)/vmap(jit(model_vsh))/jit(T_lm)/jit(T_lm_scalar)/jvp(jit(Y_lm))/mul" source_file="/home/riccardo_mancini/Gaia_EDR3/src/models/vsh_model.py" source_line=198}

This isn't necessarily a bug; constant-folding is inherently a trade-off between compilation time and speed at runtime. XLA has some guards that attempt to keep constant folding from taking too long, but fundamentally you'll always be able to come up with an input program that takes a long time.

If you'd like to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
2025-06-21 15:20:49.042940: E external/xla/xla/service/slow_operation_alarm.cc:140] The ope

Autocorrelation length estimate: 2
Average r_hat: 1.0000461
Number of divergences: 0
l = 2


Autocorrelation length estimate: 1
Average r_hat: 0.99993956
Number of divergences: 0
l = 3


Autocorrelation length estimate: 2
Average r_hat: 0.99999595
Number of divergences: 0
l = 4


sample: 100%|██████████| 7000/7000 [05:40<00:00, 20.53it/s, 15 steps of size 1.33e-01. acc. prob=0.85]
sample: 100%|██████████| 7000/7000 [04:30<00:00, 25.84it/s, 7 steps of size 1.36e-01. acc. prob=0.85] 
sample: 100%|██████████| 7000/7000 [04:22<00:00, 26.67it/s, 15 steps of size 1.41e-01. acc. prob=0.83]
sample: 100%|██████████| 7000/7000 [04:03<00:00, 28.69it/s, 7 steps of size 1.44e-01. acc. prob=0.83] 
sample: 100%|██████████| 7000/7000 [04:23<00:00, 26.52it/s, 15 steps of size 1.39e-01. acc. prob=0.84] 
sample: 100%|██████████| 7000/7000 [04:47<00:00, 24.31it/s, 7 steps of size 1.33e-01. acc. prob=0.85] 


Autocorrelation length estimate: 2
Average r_hat: 1.0000026
Number of divergences: 0
l = 5


sample: 100%|██████████| 7000/7000 [08:23<00:00, 13.89it/s, 15 steps of size 1.24e-01. acc. prob=0.84]   
sample: 100%|██████████| 7000/7000 [07:14<00:00, 16.12it/s, 15 steps of size 1.27e-01. acc. prob=0.83]  
sample: 100%|██████████| 7000/7000 [07:17<00:00, 16.01it/s, 15 steps of size 1.27e-01. acc. prob=0.84]  
sample: 100%|██████████| 7000/7000 [07:31<00:00, 15.50it/s, 15 steps of size 1.24e-01. acc. prob=0.84]  
sample: 100%|██████████| 7000/7000 [07:57<00:00, 14.66it/s, 15 steps of size 1.08e-01. acc. prob=0.88]  
sample: 100%|██████████| 7000/7000 [07:23<00:00, 15.79it/s, 7 steps of size 1.26e-01. acc. prob=0.84]   


Autocorrelation length estimate: 2
Average r_hat: 0.999972
Number of divergences: 0
l = 6


sample: 100%|██████████| 7000/7000 [12:18<00:00,  9.48it/s, 15 steps of size 1.05e-01. acc. prob=0.87]   
sample: 100%|██████████| 7000/7000 [11:09<00:00, 10.45it/s, 15 steps of size 1.08e-01. acc. prob=0.85]  
sample: 100%|██████████| 7000/7000 [11:07<00:00, 10.49it/s, 31 steps of size 1.06e-01. acc. prob=0.87]  
sample: 100%|██████████| 7000/7000 [10:57<00:00, 10.64it/s, 15 steps of size 1.18e-01. acc. prob=0.83]  
sample: 100%|██████████| 7000/7000 [10:55<00:00, 10.67it/s, 15 steps of size 1.10e-01. acc. prob=0.85]  
sample: 100%|██████████| 7000/7000 [10:46<00:00, 10.83it/s, 7 steps of size 1.19e-01. acc. prob=0.83]   


Autocorrelation length estimate: 2
Average r_hat: 0.9999835
Number of divergences: 0
l = 7


sample: 100%|██████████| 7000/7000 [16:42<00:00,  6.98it/s, 15 steps of size 1.04e-01. acc. prob=0.85]   
sample: 100%|██████████| 7000/7000 [15:31<00:00,  7.52it/s, 15 steps of size 1.18e-01. acc. prob=0.80]  
sample: 100%|██████████| 7000/7000 [15:12<00:00,  7.67it/s, 31 steps of size 1.05e-01. acc. prob=0.84]  
sample: 100%|██████████| 7000/7000 [15:08<00:00,  7.70it/s, 15 steps of size 1.05e-01. acc. prob=0.84]  
sample: 100%|██████████| 7000/7000 [15:00<00:00,  7.78it/s, 15 steps of size 1.04e-01. acc. prob=0.84]  
sample: 100%|██████████| 7000/7000 [15:51<00:00,  7.36it/s, 31 steps of size 1.19e-01. acc. prob=0.80]   


Autocorrelation length estimate: 3
Average r_hat: 1.000038
Number of divergences: 0
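The loop reports the average $\hat R$ over all coefficients. For reference, here is a minimal numpy sketch of the classic (non-split) Gelman-Rubin statistic for one parameter, computed from an array shaped `(n_chains, n_samples)` like the per-parameter slices of `mcmc.get_samples(group_by_chain=True)`. It is a simplified version: numpyro's `summary` reports the split $\hat R$, which also halves each chain before comparing. Because an average can hide a single poorly mixed coefficient, the maximum over `r_hats` is a stricter check than the mean.

```python
import numpy as np


def gelman_rubin(chains: np.ndarray) -> float:
    """Classic (non-split) R-hat for one scalar parameter.

    chains: array of shape (n_chains, n_samples).
    """
    m, n = chains.shape
    chain_means = chains.mean(axis=1)
    chain_vars = chains.var(axis=1, ddof=1)
    W = chain_vars.mean()                  # mean within-chain variance
    B = n * chain_means.var(ddof=1)        # between-chain variance
    var_hat = (n - 1) / n * W + B / n      # pooled estimate of the posterior variance
    return float(np.sqrt(var_hat / W))


rng = np.random.default_rng(1)
good = rng.normal(size=(6, 5000))                  # 6 well-mixed chains
bad = good + np.arange(6)[:, None] * 0.5           # chains stuck at different offsets
print(round(gelman_rubin(good), 4))
print(round(gelman_rubin(bad), 4))
```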


In [21]:
posterior_sample_norm = load_pickle('lmax_1', dir = 'hmc_samples/posterior_samples')[0]
diagnostics = load_pickle('lmax_1', dir = 'hmc_samples/diagnostic_hmc')
n_eff = diagnostics['theta']['n_eff']
iat = estimate_iat(n_s, n_chains, n_eff, index=[1,4,5])
print(iat)

2
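`estimate_iat` is defined in `src`, so its exact rule is not shown here. A common definition, which the sketch below assumes, is $\tau \approx N_{tot}/n_{eff}$ with $N_{tot} = n_s \times n_{chains}$ (numpyro's `n_eff` is pooled over chains), taking the worst case over the selected indices and rounding up to an integer so it can serve as a thinning step. The function `iat_from_ess` and the `n_eff` values are hypothetical.

```python
import math

import numpy as np


def iat_from_ess(n_samples, n_chains, n_eff, index):
    """Hypothetical IAT estimate from the effective sample size.

    tau ~ total draws / n_eff, using the smallest n_eff among the selected
    coefficients, rounded up and never below 1.
    """
    n_total = n_samples * n_chains
    worst_ess = float(np.min(np.asarray(n_eff)[index]))
    return max(1, math.ceil(n_total / worst_ess))


# Made-up n_eff values for a 6-coefficient (lmax = 1) run
n_eff = np.array([21000.0, 16000.0, 30500.0, 25000.0, 18000.0, 17000.0])
print(iat_from_ess(5000, 6, n_eff, index=[1, 4, 5]))  # 30000 / 16000 = 1.875 -> 2
```

If NUTS yields `n_eff` larger than the number of draws (possible with anticorrelated chains), the ratio drops below 1, which is why the sketch clamps it at 1.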


In [22]:
cov_m = cov_matrix_hmc(posterior_sample_norm['theta'][::iat], indices=[1,4,5])
result_uni = jnp.mean(posterior_sample_norm['theta'], axis = 0)
params = [result_uni[1], result_uni[4], result_uni[5]]

summary_norm, v_vec, v_Sigma,_ = vsh_vector_summary(params, cov_m)
summary_norm_gal, v_vec_gal, v_Sigma_gal, _ = vsh_vector_summary_galactic(v_vec, v_Sigma)

lb_summary_ = lb_summary(v_vec_gal, v_Sigma_gal)
ra_dec_summary_ = ra_dec_summary(v_vec, v_Sigma)
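`cov_matrix_hmc` is a project helper. A minimal numpy version of what the call above is assumed to do (keep every `iat`-th draw, select the columns for coefficients 1, 4 and 5, and form their sample covariance) is sketched below; the function `thinned_cov` and the synthetic draws are hypothetical. Note that `result_uni` is computed from all draws, while the covariance uses only the thinned ones.

```python
import numpy as np


def thinned_cov(theta_samples, iat, indices):
    """Hypothetical sketch: covariance of selected coefficients after thinning.

    theta_samples: array of shape (n_draws, n_params)
    iat: thinning step (keep every iat-th draw)
    indices: columns to keep, e.g. [1, 4, 5]
    """
    thinned = np.asarray(theta_samples)[::iat]
    sub = thinned[:, indices]
    return np.cov(sub, rowvar=False)  # shape (len(indices), len(indices))


# Synthetic draws with a known covariance on the three selected columns
rng = np.random.default_rng(2)
true_cov = np.array([[0.15, -0.01, 0.0],
                     [-0.01, 0.10, -0.004],
                     [0.0, -0.004, 0.04]])
draws = np.zeros((30000, 6))
draws[:, [1, 4, 5]] = rng.multivariate_normal(np.zeros(3), true_cov, size=30000)

cov = thinned_cov(draws, iat=2, indices=[1, 4, 5])
print(cov.shape)  # (3, 3)
print(np.round(cov, 3))
```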

In [23]:
print_summary(summary_norm, title='Result, Equatorial Coordinates')
print_summary(ra_dec_summary_)
print('')
print_summary(summary_norm_gal, title='Results, Galactic Coordinates')
print_summary(lb_summary_)

Result, Equatorial Coordinates
------------------------------
  |g| (μas/yr)        : 5.4654
  g (μas/yr)          : [ 0.02362784 -5.229347   -1.5886097 ]
  |sigma_g| (μas/yr)  : 0.3019
  sigma_g (μas/yr)    : [0.39327432 0.31327942 0.1983548 ]
  Corr_gx_gy          : -0.0731
  Corr_gx_gz          : -0.0149
  Corr_gy_gz          : -0.0609
  RA (deg)            : 270.2589
  Sigma_RA (deg)      : 4.3077
  Dec (deg)           : -16.8980
  Sigma_Dec (deg)     : 2.2590
  Corr_RA_dec         : 0.0078

Results, Galactic Coordinates
-----------------------------
  |g|_gal (μas/yr)    : 5.4654
  g_gal (μas/yr)      : [5.33483435 1.1511802  0.29092886]
  |sigma_g_gal| (μas/yr): 0.3019
  sigma_g_gal (μas/yr): [0.28359563 0.29112217 0.35633511]
  Corr_g_galx_g_galy  : 0.3072
  Corr_g_galx_g_galz  : 0.0933
  Corr_g_galy_g_galz  : -0.4341
  l (deg)             : 12.1769
  Sigma_l (deg)       : 2.8578
  b (deg)             : 3.0514
  Sigma_b (deg)       : 3.7346
  Corr_l_b            : -0.0085
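The summary rows can be cross-checked by hand. The sketch below takes the printed equatorial glide vector, rebuilds its covariance from the printed standard deviations and correlations, and propagates it to $|g|$, RA and Dec with a first-order (delta-method) Jacobian. It then rotates $g$ and its covariance to galactic coordinates with the standard ICRS-to-Galactic rotation matrix (the Hipparcos $A_G'$ matrix) and computes $(l, b)$. Whether `vsh_vector_summary`, `ra_dec_summary` and `lb_summary` use exactly this linearised propagation is an assumption; with the printed inputs, the sketch reproduces the printed values up to the rounding of the correlations. It also suggests that the `|sigma_g|` row is the propagated uncertainty of $|g|$, not the norm of `sigma_g`.

```python
import numpy as np

# Printed equatorial results (muas/yr) from the cell above
g = np.array([0.02362784, -5.229347, -1.5886097])
sig = np.array([0.39327432, 0.31327942, 0.1983548])
corr = np.array([[1.0, -0.0731, -0.0149],
                 [-0.0731, 1.0, -0.0609],
                 [-0.0149, -0.0609, 1.0]])
cov = corr * np.outer(sig, sig)


def ra_dec_with_errors(g, cov):
    """|g|, RA, Dec (deg) and their 1-sigma errors via first-order propagation."""
    x, y, z = g
    rho2 = x**2 + y**2
    rho = np.sqrt(rho2)
    r2 = rho2 + z**2
    r = np.sqrt(r2)
    ra = np.arctan2(y, x) % (2 * np.pi)
    dec = np.arctan2(z, rho)
    # Jacobian rows: d|g|/dg, dRA/dg, dDec/dg
    J = np.array([
        [x / r, y / r, z / r],
        [-y / rho2, x / rho2, 0.0],
        [-x * z / (rho * r2), -y * z / (rho * r2), rho / r2],
    ])
    errs = np.sqrt(np.diag(J @ cov @ J.T))
    return (r, np.degrees(ra), np.degrees(dec),
            errs[0], np.degrees(errs[1]), np.degrees(errs[2]))


mag, ra, dec, s_mag, s_ra, s_dec = ra_dec_with_errors(g, cov)
print(f"|g| = {mag:.4f} +/- {s_mag:.4f} muas/yr")
print(f"RA  = {ra:.4f} +/- {s_ra:.4f} deg")
print(f"Dec = {dec:.4f} +/- {s_dec:.4f} deg")

# ICRS -> Galactic rotation (Hipparcos A_G' matrix)
A_G = np.array([
    [-0.0548755604, -0.8734370902, -0.4838350155],
    [0.4941094279, -0.4448296300, 0.7469822445],
    [-0.8676661490, -0.1980763734, 0.4559837762],
])
g_gal = A_G @ g
sig_gal = np.sqrt(np.diag(A_G @ cov @ A_G.T))
l_deg = np.degrees(np.arctan2(g_gal[1], g_gal[0])) % 360
b_deg = np.degrees(np.arcsin(g_gal[2] / np.linalg.norm(g_gal)))
print("g_gal     =", np.round(g_gal, 4))
print("sigma_gal =", np.round(sig_gal, 4))
print(f"l = {l_deg:.4f} deg, b = {b_deg:.4f} deg")
```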
