# test_joint_hmc

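Notebook for joint inference combining four probes (`dspl`, `lens_kinematic`, `sne`, `quasar`).
It runs `hmc_scripts/run_joint_hmc.py` with that script's default settings (clean-only inference, `target_accept_prob=0.99`), then prints a summary of the clean joint posterior and saves a corner plot.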


In [None]:
import os
# Must be set before any HDF5-backed library (used later when reading the
# .nc posterior file) is imported; HDF5 file locking often fails on shared
# parallel filesystems such as the Lustre result directory used below.
os.environ.setdefault('HDF5_USE_FILE_LOCKING', 'FALSE')

from pathlib import Path
import sys
import runpy

# Repository root: use the current directory when it contains `hmc_scripts/`,
# otherwise fall back to a fixed checkout (overridable via LENSED_UNIVERSE_DIR).
workdir = Path.cwd()
if not (workdir / 'hmc_scripts').exists():
    workdir = Path(os.environ.get('LENSED_UNIVERSE_DIR', '/users/tianli/LensedUniverse'))
# Later cells use paths relative to the repository root.
os.chdir(workdir)
# Make repository-local modules importable from this notebook.
if str(workdir) not in sys.path:
    sys.path.insert(0, str(workdir))

# Where the inference script's outputs are read from (overridable via
# LENSED_UNIVERSE_RESULT_DIR); figures go to `result/` under the repo root.
RESULT_DIR = Path(os.environ.get('LENSED_UNIVERSE_RESULT_DIR', '/mnt/lustre/tianli/LensedUniverse_result'))
FIG_DIR = Path('result')
FIG_DIR.mkdir(parents=True, exist_ok=True)

print('workdir =', workdir)
print('result dir =', RESULT_DIR)


In [None]:
# Run the joint inference script as if invoked from the shell as
# `python hmc_scripts/run_joint_hmc.py` (paths are relative to the repo root).
# Note: current script default is clean-only (RUN_NOISY_INFERENCE=False).
RUN_INFERENCE = True

if RUN_INFERENCE:
    script_path = Path('hmc_scripts/run_joint_hmc.py')
    if not script_path.exists():
        raise FileNotFoundError(f'Not found: {script_path}')
    print('[RUN] executing', script_path)
    # runpy.run_path only swaps sys.argv[0]; inside a Jupyter kernel the rest
    # of sys.argv holds the kernel's own arguments (e.g. `-f <connection file>`),
    # which any command-line parsing in the script would choke on. Present a
    # clean argv for the duration of the run and always restore the original.
    saved_argv = sys.argv
    sys.argv = [str(script_path)]
    try:
        runpy.run_path(str(script_path), run_name='__main__')
    finally:
        sys.argv = saved_argv
    print('[DONE] joint inference script finished')
else:
    print('Skip inference run (RUN_INFERENCE=False)')


In [None]:
import arviz as az
import numpy as np

# Load the clean-only joint posterior from the result directory.
clean_path = RESULT_DIR / 'joint_clean.nc'
if not clean_path.exists():
    raise FileNotFoundError(f'joint clean file not found: {clean_path}')

idata_clean = az.from_netcdf(clean_path)

# Summarize only the expected hyperparameters that this run actually sampled.
expected_vars = ['h0', 'Omegam', 'w0', 'wa', 'lambda_mean', 'lambda_sigma', 'gamma_mean', 'gamma_sigma', 'beta_mean', 'beta_sigma']
summary_vars = [v for v in expected_vars if v in idata_clean.posterior]
missing_vars = [v for v in expected_vars if v not in idata_clean.posterior]
if missing_vars:
    print('Not in posterior (skipped):', missing_vars)
# Fail loudly rather than asking az.summary for an empty variable list.
if not summary_vars:
    raise RuntimeError('No expected vars found in joint_clean posterior.')
print(az.summary(idata_clean, var_names=summary_vars))


In [None]:
import corner
import matplotlib.pyplot as plt

# Corner plot of the clean joint posterior, saved as PDF and PNG under FIG_DIR.
# Only the expected variables present in this run are plotted.
plot_vars = [v for v in ['h0', 'Omegam', 'w0', 'wa', 'lambda_mean', 'lambda_sigma', 'gamma_mean', 'gamma_sigma', 'beta_mean', 'beta_sigma'] if v in idata_clean.posterior]
if not plot_vars:
    raise RuntimeError('No expected vars found in joint_clean posterior.')

# Each variable is flattened over all of its dims into one corner-plot column.
# Rows only correspond to the same draw if every variable holds the same number
# of values (assumed: one per (chain, draw), in the same dim order — TODO
# confirm), so check that up front instead of letting np.column_stack fail
# with a generic shape error.
flat = {v: np.asarray(idata_clean.posterior[v]).reshape(-1) for v in plot_vars}
sizes = {v: arr.size for v, arr in flat.items()}
if len(set(sizes.values())) != 1:
    raise RuntimeError(f'Posterior variables have mismatched sample counts: {sizes}')
samples = np.column_stack([flat[v] for v in plot_vars])

# Drop draws where any plotted variable is NaN/inf, and say how many were
# dropped: silently discarding draws can hide sampler problems.
finite_mask = np.isfinite(samples).all(axis=1)
n_dropped = int((~finite_mask).sum())
if n_dropped:
    print(f'[WARN] dropping {n_dropped} of {samples.shape[0]} draws with non-finite values')
samples = samples[finite_mask]
if samples.shape[0] < 10:
    raise RuntimeError('Too few finite samples for corner plot.')

fig = corner.corner(
    samples,
    labels=plot_vars,
    show_titles=True,
    title_fmt='.3f',
    quantiles=[0.16, 0.5, 0.84],
)

pdf_out = FIG_DIR / 'joint_clean_corner_from_test_notebook.pdf'
png_out = FIG_DIR / 'joint_clean_corner_from_test_notebook.png'
fig.savefig(pdf_out, dpi=200, bbox_inches='tight')
fig.savefig(png_out, dpi=200, bbox_inches='tight')
plt.show()

print('Saved:', pdf_out)
print('Saved:', png_out)
