# HMC

In [1]:
import jax 
import jax.numpy as jnp
from jax import jit, vmap
from jax import random
from src.models.vsh_model import*
from src.models.configuration import*
from src.data.data_utils import*
from numpyro.infer import MCMC, NUTS
import numpyro
import numpyro.distributions as dist

  """
  """
  """
  from .autonotebook import tqdm as notebook_tqdm


In [2]:
df = load_qso_dataframe()
angles, obs, error = config_data(df)

In [3]:
def model_for_HMC(angles, obs, error, lmax):
    total_params = count_vsh_coeffs(lmax)//2
    
    # Prior on all VSH coefficients (both toroidal and spheroidal)
    theta_t = numpyro.sample("theta_t", dist.Uniform(-0.05, 0.05).expand([total_params]))
    theta_s = numpyro.sample("theta_s", dist.Uniform(-0.008, 0.008).expand([total_params]))
    # Least-squares residuals: we assume Gaussian-distributed residuals
    chi2_val = least_square_hmc(angles, obs, error, theta_t, theta_s, lmax=lmax, grid=False)

    # The log-likelihood is proportional to -0.5*chi^2
    numpyro.factor("likelihood", -0.5*chi2_val)


In [4]:
kernel = NUTS(model_for_HMC, target_accept_prob=0.8)
rng_key = jax.random.key(0)
mcmc = MCMC(kernel, num_warmup=100, num_samples=2000, num_chains=4, progress_bar=True)
mcmc.run(rng_key, angles=angles, obs=obs, error=error, lmax=2)

posterior_sample = mcmc.get_samples()

  mcmc = MCMC(kernel, num_warmup=100, num_samples=2000, num_chains=4, progress_bar=True)
  0%|          | 0/2100 [00:00<?, ?it/s]2025-05-18 23:53:49.869989: E external/xla/xla/service/slow_operation_alarm.cc:73] Constant folding an instruction is taking > 1s:

  %multiply.4844 = f32[1215942]{0} multiply(%constant.5562, %broadcast.3368), metadata={op_name="jit(_body_fn)/jit(main)/while/body/while/body/jvp(jit(least_square_hmc))/vmap(jit(model_vsh_hmc))/jit(T_lm)/jit(T_lm_scalar)/jvp(jit(Y_lm))/exp" source_file="/root/Document/Gaia_Project/mem97/src/models/vsh_model.py" source_line=196}

This isn't necessarily a bug; constant-folding is inherently a trade-off between compilation time and speed at runtime. XLA has some guards that attempt to keep constant folding from taking too long, but fundamentally you'll always be able to come up with an input program that takes a long time.

If you'd like to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
2025-05

In [5]:
mcmc.print_summary()


                mean       std    median      5.0%     95.0%     n_eff     r_hat
theta_s[0]     -0.01      0.00     -0.01     -0.01     -0.01   5961.28      1.00
theta_s[1]      0.00      0.00      0.00      0.00      0.01   8988.67      1.00
theta_s[2]     -0.01      0.00     -0.01     -0.01     -0.01  10613.53      1.00
theta_s[3]     -0.00      0.00     -0.00     -0.00     -0.00   8346.74      1.00
theta_s[4]      0.01      0.00      0.01      0.01      0.01   6399.99      1.00
theta_s[5]      0.01      0.00      0.01      0.01      0.01   5728.62      1.00
theta_s[6]     -0.00      0.00     -0.00     -0.01     -0.00   5907.73      1.00
theta_s[7]      0.01      0.00      0.01      0.00      0.01   5828.49      1.00
theta_t[0]     -0.00      0.00     -0.00     -0.00     -0.00   7829.94      1.00
theta_t[1]     -0.01      0.00     -0.01     -0.01     -0.01   9833.46      1.00
theta_t[2]      0.01      0.00      0.01      0.00      0.01   9484.70      1.00
theta_t[3]      0.01      0

In [6]:
s_lm = jnp.mean(posterior_sample['theta_s'], axis=0)
t_lm = jnp.mean(posterior_sample['theta_t'], axis=0)
std_s = jnp.std(posterior_sample['theta_s'], axis=0)
std_t = jnp.std(posterior_sample['theta_t'], axis=0)

In [7]:
spheroidal_vector_summary(s_lm, std_s**2, index = np.array([0,1,2]))

Equatorial components:
G_vec = [-1.4547749 -3.8274574 -2.1835685] +/- [0.6650398  0.08050711 0.27697843](μas/yr)
Magnitude = 4.640449523925781 +/- 0.585536539554596 (μas/yr)
RA = 249.18875122070312 +/- 17.822141647338867 (deg)
Dec = -28.070072174072266 +/- 2.8098304271698 (deg)

Galactic components:
G_vec = [ 4.47936163 -0.64733841  1.024716  ] +/- [0.15567797 0.38996001 0.59090767](μas/yr)
l = 351.77678753104533 +/- 4.89408390807209 (deg)
d = 12.757351775152259 +/- 0.42132514388443626 (deg)


In [8]:
toroidal_vector_summary(t_lm, std_t, index = np.array([0,1,2]))

R_vec = [ 5.5001535 -2.4638968 -0.7120509] +/- [17.297695 16.898705 10.669681](μas/yr)
Magnitude = 6.06873083114624 +/- 35.2105598449707 (μas/yr)
RA = 335.8691101074219 +/- 330.11083984375 (deg)
Dec = -6.738097667694092 +/- 35.832176208496094 (deg)


Try to extract covariance of VSH coefficients after determining posterior distribution

In [14]:
def cov_matrix_hmc(posterior_sample, indices=None):
    theta_samples = np.array(posterior_sample)
    cov_matrix = np.cov(theta_samples, rowvar=False)
    
    if indices is not None:
        cov_matrix = cov_matrix[np.ix_(indices, indices)]

    return cov_matrix


In [15]:
cov_slm = cov_matrix_hmc(posterior_sample["theta_s"], indices=[0,1,2])

In [16]:
correlation_slm = rho_matrix(cov_slm)
print(correlation_slm.shape)
print("Correlation matrix of spheroidal coefficients of VSH")
print(correlation_slm)
print('')
print("Example:")
print(f"rho(s11r,s11i) = {correlation_slm[0][1]}")

(3, 3)
Correlation matrix of spheroidal coefficients of VSH
[[ 1.         -0.02017899 -0.00758782]
 [-0.02017899  1.         -0.00990544]
 [-0.00758782 -0.00990544  1.        ]]

Example:
rho(s11r,s11i) = -0.020178985647500453


In [17]:
summary_equatorial, v_vec, Sigma_v = vsh_vector_summary(s_lm, cov_slm)

In [18]:
summary_equatorial

{'|g| (μas/yr)': np.float32(4.6404495),
 'g (μas/yr)': array([-1.4547749, -3.8274574, -2.1835685], dtype=float32),
 '|sigma_g| (μas/yr)': np.float64(0.5649745527185548),
 'sigma_g (μas/yr)': array([0.39173117, 0.66508133, 0.05693069]),
 'Corr_gx_gy': np.float64(0.020178985647500453),
 'Corr_gx_gz': np.float64(0.007587822208358225),
 'Corr_gy_gz': np.float64(-0.009905442567999235)}

In [19]:
alpha_delta_result = alpha_delta_summary(v_vec, Sigma_v)

In [20]:
alpha_delta_result

{'RA (deg)': np.float32(249.18875),
 'Sigma_RA (deg)': np.float64(3.9594852777931036),
 'Dec (deg)': np.float32(-28.070072),
 'Sigma_Dec (deg)': np.float64(3.9594852777931036),
 'Corr_RA_dec': np.float64(0.9721501459177418)}