# test_lensed_quasar_time_delay_8sample_clean_parent500

Clean-only 8-lens quasar time-delay forecast from `data/test_data.py`.

Key logic:
- Use only the longest time-delay pair (`max |dt|`) for each target.
- LambdaCDM inference samples only `Omegam` and `H0` for cosmology.
- MST truth is generated from population:
  - `lambda_mean_true=1.0`
  - `lambda_sigma_true=0.04`
- Per-lens external convergence is included:
  - `kext_true ~ N(0, sigma=0.01)` (0.01 is the standard deviation)
  - `kext_err = 0.01`
  - `lambda_eff = (1-kext)*lambda`
- The Fermat potential difference (FPD) is inferred from the time-delay relation `t = (Ddt / c) * phi * lambda_eff` and is assigned a 3% measurement error.
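- No noisy data mode: all runs use clean (noise-free) per-lens observations. The only noisy quantities are the parent-population MST observations of scenario 3, which are drawn with 1% Gaussian scatter around their true values.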

Three inference scenarios:
1) `fiducial_tdcosmo`: use individual MST measurements with the kinematics-driven error budget (fractional error `2 * sigma_v_err / sigma_v` plus an additional 4% budget).
2) `individual_mst_8pct`: use individual MST measurements with additional +8% MST error added in quadrature.
3) `parent500_no_individual`: no individual MST measurements, but add 500 parent-population MST observations, each with 1% constraint.


In [None]:
import os
import sys
from pathlib import Path
import importlib.util

import numpy as np
import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import NUTS, MCMC, init_to_value
from jax import random
import arviz as az
from corner import corner

os.environ.setdefault('HDF5_USE_FILE_LOCKING', 'FALSE')

# ---------------------------------------------------------------
# Locate the repository root. Accepted layouts, in priority order:
#   1. the current directory holds both slcosmo/ and data/
#   2. the current directory holds a LensedUniverse/ checkout with slcosmo/
#   3. some ancestor directory holds both slcosmo/ and data/
# ---------------------------------------------------------------
cwd = Path.cwd().resolve()
nested_checkout = cwd / 'LensedUniverse'

if (cwd / 'slcosmo').is_dir() and (cwd / 'data').is_dir():
    repo_root = cwd
elif (nested_checkout / 'slcosmo').is_dir():
    repo_root = nested_checkout
else:
    # cwd itself was rejected by the first test, so only ancestors remain.
    repo_root = next(
        (
            ancestor
            for ancestor in cwd.parents
            if (ancestor / 'slcosmo').is_dir() and (ancestor / 'data').is_dir()
        ),
        None,
    )

if repo_root is None:
    raise RuntimeError(f'Cannot locate LensedUniverse repo root from cwd={cwd}')

# Work from the repo root and make the in-repo package importable.
workdir = repo_root
os.chdir(workdir)
workdir_entry = str(workdir)
if workdir_entry not in sys.path:
    sys.path.insert(0, workdir_entry)

from slcosmo.tools import tool

# Double precision is opt-in through the SLCOSMO_USE_X64 environment variable.
truthy_flags = {'1', 'true', 'yes', 'y', 'on'}
USE_X64 = os.environ.get('SLCOSMO_USE_X64', '0').strip().lower() in truthy_flags
jax.config.update('jax_enable_x64', USE_X64)
if USE_X64:
    numpyro.enable_x64()

# Ask NumPyro for the GPU platform when JAX reports one, otherwise for the CPU.
gpu_available = any(device.platform == 'gpu' for device in jax.devices())
numpyro.set_platform('gpu' if gpu_available else 'cpu')

print('Precision mode:', 'FP64' if USE_X64 else 'FP32')
print('Repo root:', workdir)
print('JAX devices:', jax.devices())

# One seed drives both the Generator used for the mock data and the legacy
# global NumPy random state.
SEED = 42
rng_np = np.random.default_rng(SEED)
np.random.seed(SEED)

# All output files of this notebook are written below <repo>/test.
RESULT_DIR = workdir / 'test'
RESULT_DIR.mkdir(parents=True, exist_ok=True)


In [None]:
# ---------------------------
# Load 8-lens dataset and keep only the longest delay pair per target
# ---------------------------
# The dataset is a plain Python file under data/; it is executed from its file
# path and the `tdcosmo_8lens` mapping (lens name -> record) is read from it.
data_path = workdir / 'data' / 'test_data.py'
spec = importlib.util.spec_from_file_location('test_data', data_path)
test_data = importlib.util.module_from_spec(spec)
spec.loader.exec_module(test_data)
tdcosmo_8lens = test_data.tdcosmo_8lens

# Sorted names give a reproducible lens order for every per-lens array below.
lens_names = sorted(tdcosmo_8lens.keys())

z_lens = []
z_src = []
sigma_v = []
sigma_v_err = []

pair_name = []
t_base = []
t_err = []

for name in lens_names:
    d = tdcosmo_8lens[name]
    z_lens.append(float(d['zl']))
    z_src.append(float(d['zs']))
    sigma_v.append(float(d['sigma_ap_los_kms']))
    sigma_v_err.append(float(d['sigma_ap_los_err_kms']))

    td_dict = d['time_delays_days']
    if not td_dict:
        # Without this check max() fails with a message that does not name the lens.
        raise ValueError(f'{name}: time_delays_days contains no image pairs')

    # Keep the image pair with the largest absolute delay.
    p_long = max(td_dict, key=lambda p: abs(float(td_dict[p]['dt'])))
    td_long = td_dict[p_long]
    pair_name.append(f'{name}:{p_long}')
    dt_val = float(td_long['dt'])

    # Symmetrised uncertainty: mean of the lower and upper errors. A missing
    # entry counts as 0.0, so the result must be checked before it is used as
    # the scale of a Gaussian likelihood.
    de = 0.5 * (float(td_long.get('err_minus', 0.0)) + float(td_long.get('err_plus', 0.0)))
    if not de > 0.0:
        raise ValueError(
            f'{name}:{p_long}: time-delay uncertainty must be positive, got {de!r} '
            "(check 'err_minus' / 'err_plus')"
        )

    # Only the magnitude of the delay is used downstream.
    t_base.append(abs(dt_val))
    t_err.append(de)

z_lens = np.asarray(z_lens, dtype=float)
z_src = np.asarray(z_src, dtype=float)
sigma_v = np.asarray(sigma_v, dtype=float)
sigma_v_err = np.asarray(sigma_v_err, dtype=float)

t_base = np.asarray(t_base, dtype=float)
t_err = np.asarray(t_err, dtype=float)

# One time-delay observation per lens, so both counts coincide.
n_lens = z_lens.size
n_obs = t_base.size

print('N lens:', n_lens)
print('N time-delay observations (longest pair only):', n_obs)
for i, name in enumerate(lens_names):
    print(f'  {i:02d} {name:15s} pair={pair_name[i]} sigma_v={sigma_v[i]:.1f}+-{sigma_v_err[i]:.1f}')



In [None]:
# ---------------------------
# Build mock truth and clean observations
# ---------------------------
# Fiducial cosmology used to generate the mock data.
cosmo_true = {'Omegam': 0.3, 'Omegak': 0.0, 'w0': -1.0, 'wa': 0.0, 'h0': 70.0}

# Time-delay distance per lens: Ddt = (1 + z_lens) * Dl * Ds / Dls.
Dl, Ds, Dls = tool.dldsdls(jnp.asarray(z_lens), jnp.asarray(z_src), cosmo_true, n=20)
Ddt_lens = np.asarray((1.0 + z_lens) * Dl * Ds / Dls)
Ddt_obs = Ddt_lens

# Unit conversions for the delay relation.
# NOTE(review): assumes tool.c_km_s is in km/s and tool.Mpc in metres -- confirm in slcosmo.tools.
c_km_day = tool.c_km_s * 86400.0
Mpc_km = tool.Mpc / 1000.0

# Fixed MST population truth used to generate mock data
lambda_mean_true = 1.0
lambda_sigma_true = 0.04
lambda_low_model = 0.6
lambda_high_model = 1.6
lambda_true_raw = tool.truncated_normal(
    lambda_mean_true,
    lambda_sigma_true,
    lambda_low_model,
    lambda_high_model,
    n_lens,
    random_state=rng_np,
)

# Recenter the finite-sample draw so the mock ensemble mean stays close to the
# configured population mean. The target is lambda_mean_true rather than a
# literal 1.0, so changing the population mean keeps the mock self-consistent.
# Clipping can shift the mean again, hence the second recenter/clip pass.
lambda_true = lambda_true_raw - np.mean(lambda_true_raw) + lambda_mean_true
lambda_true = np.clip(lambda_true, lambda_low_model + 1e-3, lambda_high_model - 1e-3)
lambda_true = lambda_true - np.mean(lambda_true) + lambda_mean_true
lambda_true = np.clip(lambda_true, lambda_low_model + 1e-3, lambda_high_model - 1e-3)
assert np.abs(np.mean(lambda_true) - lambda_mean_true) < 0.01, (
    'recentred lambda_true mean deviates from lambda_mean_true by more than 0.01'
)

# External convergence: one true value per lens, a fixed measurement error,
# and a noise-free "observed" value equal to the truth.
kext_sigma = 0.01
kext_true = rng_np.normal(loc=0.0, scale=kext_sigma, size=n_lens)
kext_err = np.full(n_lens, 0.01)
kext_obs = np.array(kext_true, copy=True)

# Effective MST factor entering the time delay.
lambda_eff_true = lambda_true * (1.0 - kext_true)

# Fiducial per-lens MST uncertainty: twice the fractional velocity-dispersion
# error plus a flat additional budget, expressed as a fraction of |lambda_true|.
ADDITIONAL_MST_BUDGET = 0.04
lambda_err_frac_fiducial = 2.0 * (sigma_v_err / sigma_v) + ADDITIONAL_MST_BUDGET
lambda_err_fiducial = np.abs(lambda_true) * lambda_err_frac_fiducial

# Degraded scenario: an extra 8% of |lambda_true| combined in quadrature with
# the fiducial error.
lambda_err_8pct = np.sqrt(
    np.square(lambda_err_fiducial) + np.square(0.08 * np.abs(lambda_true))
)

# Parent-population sample for the scenario without individual MST data:
# N_PARENT draws from the same truncated-normal population, each observed with
# a 1% Gaussian error.
N_PARENT = 500
lambda_parent_true = tool.truncated_normal(
    lambda_mean_true, lambda_sigma_true,
    lambda_low_model, lambda_high_model,
    N_PARENT,
    random_state=rng_np,
)
lambda_parent_err = np.abs(lambda_parent_true) * 0.01
lambda_parent_obs = rng_np.normal(loc=lambda_parent_true, scale=lambda_parent_err)

# Time-delay measured true value: unbiased base delays transformed by lambda_eff truth
t_measured_true = t_base * lambda_eff_true

# Infer FPD true from t = (Ddt/c) * phi * lambda_eff
phi_true = (c_km_day * t_measured_true) / (Ddt_obs * Mpc_km * lambda_eff_true)

# Fractional FPD measurement error. Defined once so the unscaled and scaled
# error arrays below cannot drift apart.
PHI_ERR_FRAC = 0.03
phi_err = PHI_ERR_FRAC * np.abs(phi_true)

# Clean observation set only
t_obs = t_measured_true.copy()
phi_obs = phi_true.copy()

# Scale phi for numerical stability: multiply by the power of ten that brings
# the median |phi| to order unity.
phi_median = np.median(np.abs(phi_true))
if not (np.isfinite(phi_median) and phi_median > 0.0):
    # log10 of a zero or non-finite median would otherwise fail inside int().
    raise ValueError(f'Cannot derive phi_scale from median |phi_true| = {phi_median!r}')
phi_scale = 10.0 ** int(np.round(-np.log10(phi_median)))
phi_obs_scaled = phi_obs * phi_scale
phi_err_scaled = PHI_ERR_FRAC * np.abs(phi_true * phi_scale)

# Per-scenario switches and individual MST "observations". The observations
# are noise-free (equal to the truth); the scenarios differ in the attached
# error and in whether the parent-population sample is used. The placeholder
# arrays of the parent-only scenario are not read by the model because
# use_mst_measurement is False there.
scenario_data = {
    'fiducial_tdcosmo': dict(
        use_mst_measurement=True,
        use_parent_population=False,
        lambda_obs=lambda_true.copy(),
        lambda_err=lambda_err_fiducial.copy(),
    ),
    'individual_mst_8pct': dict(
        use_mst_measurement=True,
        use_parent_population=False,
        lambda_obs=lambda_true.copy(),
        lambda_err=lambda_err_8pct.copy(),
    ),
    'parent500_no_individual': dict(
        use_mst_measurement=False,
        use_parent_population=True,
        lambda_obs=np.zeros_like(lambda_true),
        lambda_err=np.ones_like(lambda_true),
    ),
}

# Diagnostics: echo the mock configuration, one "label value" line each.
mock_report = [
    ('phi_scale:', phi_scale),
    ('lambda_mean_true:', lambda_mean_true),
    ('lambda_sigma_true:', lambda_sigma_true),
    ('lambda_true_raw_mean:', float(np.mean(lambda_true_raw))),
    ('lambda_true_centered_mean:', float(np.mean(lambda_true))),
    ('kext_sigma (true distribution):', kext_sigma),
    ('kext_err (measurement):', np.unique(kext_err)),
    ('kext_true (%):', np.round(100.0 * kext_true, 3)),
    ('N_PARENT:', N_PARENT),
    ('lambda_err_frac_fiducial (%):', np.round(100.0 * lambda_err_frac_fiducial, 2)),
    ('lambda_err_8pct_quad (%):', np.round(100.0 * (lambda_err_8pct / np.abs(lambda_true)), 2)),
]
for report_label, report_value in mock_report:
    print(report_label, report_value)

scenario_switches = {
    scenario_key: (settings['use_mst_measurement'], settings['use_parent_population'])
    for scenario_key, settings in scenario_data.items()
}
print('scenario settings:', scenario_switches)


In [None]:
# ---------------------------
# LambdaCDM model: sample Omegam/H0 and infer MST population
# ---------------------------
def quasar_td_lcdm_model(
    z_lens,
    z_src,
    t_obs,
    t_err,
    phi_obs_scaled,
    phi_err_scaled,
    phi_scale,
    lambda_obs,
    lambda_err,
    kext_obs,
    kext_err,
    lambda_parent_obs,
    lambda_parent_err,
    use_mst_measurement,
    use_parent_population,
):
    """NumPyro model for time-delay lenses with a hierarchical MST population.

    Sampled sites: ``Omegam`` and ``h0`` (curvature fixed to 0, ``w0 = -1``,
    ``wa = 0``), the MST population ``lambda_mean`` / ``lambda_sigma``, the
    per-lens ``lambda_true`` and ``kext``, the scaled Fermat potential
    difference ``phi_true_scaled`` and, when the parent sample is enabled,
    ``lambda_parent_true``.

    The predicted delay is ``(Ddt * Mpc_km / c_km_day) * phi * lambda_eff``
    with ``lambda_eff = (1 - kext) * lambda_true``.

    Args:
        z_lens, z_src: Lens and source redshifts, one entry per lens.
        t_obs, t_err: Observed time delays and their Gaussian errors.
        phi_obs_scaled, phi_err_scaled: Scaled FPD measurements and errors,
            used as the centre and width of the ``phi_true_scaled`` prior.
        phi_scale: Factor that was applied to obtain the scaled FPD values.
        lambda_obs, lambda_err: Individual MST measurements and errors; only
            used when ``use_mst_measurement`` is true.
        kext_obs, kext_err: External-convergence measurements and errors.
        lambda_parent_obs, lambda_parent_err: Parent-population MST
            measurements and errors; only used when ``use_parent_population``
            is true.
        use_mst_measurement: Python bool, adds the ``lambda_like`` site.
        use_parent_population: Python bool, adds the parent-population plate.

    Reads the module-level ``lambda_low_model``, ``lambda_high_model``,
    ``kext_sigma`` and ``tool``.
    """
    # --- cosmology ---------------------------------------------------------
    omega_m = numpyro.sample('Omegam', dist.Uniform(0.1, 0.5))
    hubble = numpyro.sample('h0', dist.Uniform(0.0, 150.0))
    cosmo = {'Omegam': omega_m, 'Omegak': 0.0, 'w0': -1.0, 'wa': 0.0, 'h0': hubble}

    # --- make every data input a JAX array ---------------------------------
    z_lens, z_src = jnp.asarray(z_lens), jnp.asarray(z_src)
    t_obs, t_err = jnp.asarray(t_obs), jnp.asarray(t_err)
    phi_obs_scaled, phi_err_scaled = jnp.asarray(phi_obs_scaled), jnp.asarray(phi_err_scaled)
    phi_scale = jnp.asarray(phi_scale)
    lambda_obs, lambda_err = jnp.asarray(lambda_obs), jnp.asarray(lambda_err)
    kext_obs, kext_err = jnp.asarray(kext_obs), jnp.asarray(kext_err)
    lambda_parent_obs = jnp.asarray(lambda_parent_obs)
    lambda_parent_err = jnp.asarray(lambda_parent_err)

    # --- MST population hyper-parameters -----------------------------------
    lambda_mean = numpyro.sample('lambda_mean', dist.Uniform(0.5, 1.5))
    lambda_sigma = numpyro.sample('lambda_sigma', dist.LogUniform(0.001, 0.5))

    def mst_population():
        # Truncated-normal population shared by the lenses and the parent sample.
        return dist.TruncatedNormal(
            lambda_mean,
            lambda_sigma,
            low=lambda_low_model,
            high=lambda_high_model,
        )

    # --- per-lens latent MST and external convergence ----------------------
    with numpyro.plate('lens', z_lens.shape[0]):
        lambda_lens = numpyro.sample('lambda_true', mst_population())
        kext_lens = numpyro.sample('kext', dist.Normal(0.0, kext_sigma))
        if use_mst_measurement:
            numpyro.sample('lambda_like', dist.Normal(lambda_lens, lambda_err), obs=lambda_obs)
        numpyro.sample('kext_like', dist.Normal(kext_lens, kext_err), obs=kext_obs)

    # --- optional parent-population constraint on the MST hyper-parameters --
    if use_parent_population:
        with numpyro.plate('parent_obs', lambda_parent_obs.shape[0]):
            parent_lambda = numpyro.sample('lambda_parent_true', mst_population())
            numpyro.sample(
                'lambda_parent_like',
                dist.Normal(parent_lambda, lambda_parent_err),
                obs=lambda_parent_obs,
            )

    lambda_eff = (1.0 - kext_lens) * lambda_lens

    # --- time-delay distance under the sampled cosmology -------------------
    d_l, d_s, d_ls = tool.dldsdls(z_lens, z_src, cosmo, n=20)
    ddt = (1.0 + z_lens) * d_l * d_s / d_ls

    c_km_day = tool.c_km_s * 86400.0
    Mpc_km = tool.Mpc / 1000.0

    # --- time-delay likelihood ---------------------------------------------
    with numpyro.plate('td_obs', t_obs.shape[0]):
        phi_scaled = numpyro.sample('phi_true_scaled', dist.Normal(phi_obs_scaled, phi_err_scaled))
        phi_unscaled = phi_scaled / phi_scale
        t_model_days = (ddt * Mpc_km / c_km_day) * phi_unscaled * lambda_eff
        numpyro.sample('t_delay_like', dist.Normal(t_model_days, t_err), obs=t_obs)


def build_init_values(phi_obs_scaled, use_parent_population):
    """Return initial values for the latent sites, set to the mock truth.

    Args:
        phi_obs_scaled: Scaled FPD observations, used to initialise
            ``phi_true_scaled``.
        use_parent_population: When true, ``lambda_parent_true`` is also
            initialised (from the module-level ``lambda_parent_true``).

    Returns:
        dict mapping sample-site names to JAX arrays.
    """
    starting_points = {
        'Omegam': cosmo_true['Omegam'],
        'h0': cosmo_true['h0'],
        'lambda_mean': lambda_mean_true,
        'lambda_sigma': lambda_sigma_true,
        'lambda_true': lambda_true,
        'kext': kext_true,
        'phi_true_scaled': phi_obs_scaled,
    }
    if use_parent_population:
        starting_points['lambda_parent_true'] = lambda_parent_true
    return {site: jnp.asarray(start) for site, start in starting_points.items()}


In [None]:
# ---------------------------
# Run clean-only scenarios
# ---------------------------
# Master switch: the scenario loop and all plots only run when this is True.
RUN_MCMC = True
# Passed to NUTS as target_accept_prob.
TARGET_ACCEPT = 0.99
# Warm-up iterations per chain (MCMC num_warmup).
NUM_WARMUP = 500
# Posterior draws kept per chain (MCMC num_samples).
NUM_SAMPLES = 1000
# Number of chains (MCMC num_chains); run_mcmc uses chain_method='vectorized'.
NUM_CHAINS = 4


def run_mcmc(scenario_name, scenario, key):
    """Sample one scenario with NUTS and return the posterior as InferenceData.

    Args:
        scenario_name: Label used in the divergence report.
        scenario: Mapping with ``use_mst_measurement``,
            ``use_parent_population``, ``lambda_obs`` and ``lambda_err``.
        key: JAX PRNG key handed to ``MCMC.run``.

    Returns:
        ArviZ InferenceData whose posterior group holds the chain-grouped
        samples.

    All remaining model inputs are taken from module-level variables.
    """
    with_parent = bool(scenario['use_parent_population'])
    with_individual_mst = bool(scenario['use_mst_measurement'])

    # Chains start from the mock truth.
    kernel = NUTS(
        quasar_td_lcdm_model,
        target_accept_prob=TARGET_ACCEPT,
        init_strategy=init_to_value(values=build_init_values(phi_obs_scaled, with_parent)),
    )
    sampler = MCMC(
        kernel,
        num_warmup=NUM_WARMUP,
        num_samples=NUM_SAMPLES,
        num_chains=NUM_CHAINS,
        chain_method='vectorized',
        progress_bar=True,
    )

    model_inputs = {
        'z_lens': z_lens,
        'z_src': z_src,
        't_obs': t_obs,
        't_err': t_err,
        'phi_obs_scaled': phi_obs_scaled,
        'phi_err_scaled': phi_err_scaled,
        'phi_scale': phi_scale,
        'lambda_obs': scenario['lambda_obs'],
        'lambda_err': scenario['lambda_err'],
        'kext_obs': kext_obs,
        'kext_err': kext_err,
        'lambda_parent_obs': lambda_parent_obs,
        'lambda_parent_err': lambda_parent_err,
        'use_mst_measurement': with_individual_mst,
        'use_parent_population': with_parent,
    }
    sampler.run(key, **model_inputs)

    # Report the number of divergent transitions summed over all chains.
    diverging = sampler.get_extra_fields(group_by_chain=True)['diverging']
    n_div = int(np.asarray(diverging).sum())
    print(f'[{scenario_name}] divergences:', n_div)

    return az.from_dict(posterior=sampler.get_samples(group_by_chain=True))


def make_single_corner(idata, var_names, outfile, color):
    """Draw a corner plot of the selected variables, save it, show it, close it.

    Args:
        idata: Posterior container passed to ``corner``.
        var_names: Variables to plot; also used verbatim as axis labels.
        outfile: Destination path for ``Figure.savefig``.
        color: Colour for contours and histograms.
    """
    plot_options = {
        'var_names': var_names,
        'labels': var_names,
        'color': color,
        'show_titles': False,
        'levels': [0.68, 0.95],
        'fill_contours': True,
        'plot_datapoints': False,
        'smooth': 0.2,
        'use_math_text': True,
        'contour_kwargs': {'linewidths': 2.5},
        'hist_kwargs': {'density': True, 'linewidth': 2.5},
    }
    figure = corner(idata, **plot_options)
    figure.savefig(outfile, bbox_inches='tight')
    plt.show()
    plt.close(figure)


if RUN_MCMC:
    # Scenario run order and the colour used for each in every plot.
    order = [
        'fiducial_tdcosmo',
        'individual_mst_8pct',
        'parent500_no_individual',
    ]
    color_map = dict(zip(order, ['#2f8aed', '#f48c06', '#2ca25f']))

    # The same four parameters feed the trace plot, the summary table, the
    # per-scenario corner plot and the overlay.
    trace_vars = ['h0', 'Omegam', 'lambda_mean', 'lambda_sigma']
    corner_vars = list(trace_vars)
    overlay_vars = list(trace_vars)
    overlay_ranges = [(60.0, 80.0), (0.1, 0.5), (0.8, 1.2), (0.0, 0.3)]

    # One independent PRNG key per scenario.
    keys = random.split(random.PRNGKey(SEED), len(order))
    inference_results = {}

    for i, name in enumerate(order):
        scenario = scenario_data[name]
        inf_data = run_mcmc(name, scenario, keys[i])
        inference_results[name] = inf_data

        # Persist the posterior.
        nc_path = RESULT_DIR / f'test_quasar8_clean3_{name}.nc'
        az.to_netcdf(inf_data, nc_path)

        # Trace plot, written to disk only.
        axes = az.plot_trace(inf_data, var_names=trace_vars, compact=True)
        fig_trace = axes.ravel()[0].figure
        trace_path = RESULT_DIR / f'test_quasar8_clean3_{name}_trace.pdf'
        fig_trace.savefig(trace_path, bbox_inches='tight')
        plt.close(fig_trace)

        print(f"\n[{name}] arviz summary")
        print(az.summary(inf_data, var_names=trace_vars, round_to=4))

        # Per-scenario corner plot.
        corner_path = RESULT_DIR / f'test_quasar8_clean3_{name}_corner.pdf'
        make_single_corner(inf_data, corner_vars, corner_path, color_map[name])

        for saved_path in (nc_path, trace_path, corner_path):
            print('Saved:', saved_path)

    def _overlay_options(color, filled):
        # Fresh option dicts on every call, so no mutable state is shared
        # between successive corner() invocations.
        hist_style = {'density': True, 'linewidth': 2.5}
        options = {
            'var_names': overlay_vars,
            'labels': overlay_vars,
            'range': overlay_ranges,
            'color': color,
            'show_titles': False,
            'levels': [0.68, 0.95],
            'fill_contours': filled,
            'plot_datapoints': False,
            'smooth': 0.2,
            'use_math_text': True,
            'contour_kwargs': {'linewidths': 2.5},
            'hist_kwargs': hist_style,
        }
        if not filled:
            # Outline-only layers for the scenarios drawn on top of the first.
            options['no_fill_contours'] = True
            hist_style['histtype'] = 'step'
        return options

    # Overlay: the first scenario is drawn filled and creates the figure, the
    # remaining scenarios are added to it as outlines.
    first = order[0]
    fig = corner(
        inference_results[first],
        **_overlay_options(color_map[first], filled=True),
    )

    for name in order[1:]:
        corner(
            inference_results[name],
            fig=fig,
            **_overlay_options(color_map[name], filled=False),
        )

    overlay_path = RESULT_DIR / 'test_quasar8_clean3_overlay_full.pdf'
    fig.savefig(overlay_path, bbox_inches='tight')
    plt.show()
    plt.close(fig)
    print('Saved:', overlay_path)
