# Quasar time-delay cosmology inference (from static datavectors)

This notebook mirrors `test_sne_forecast_time_delay.ipynb` but uses quasar time-delay blocks with
block-dependent sizes. It expects a preprocessed NPZ produced by `prepare_quasar_datavectors.py`.

Assumptions:
- `z_lens`, `z_src` are constants per lens.
- `fpd_true` is fixed from chain means; `fpd_err` is fractional (std/|mean|).
- `td_err` is fractional (std/|mean|) from the time-delay chain.
- `lambda_true` is generated from Normal(1, 0.05).
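- `td_true` is computed as `Ddt_geom * lambda_true * fpd_true / c` (expressed in days), where `Ddt_geom` is the time-delay distance evaluated in the fiducial cosmology.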
- Observables are set to true values (no added noise); likelihood uses measurement errors.
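- The MST measurement error is the per-lens fractional velocity-dispersion error (sigma_v_likelihood_prec / sigma_v_measured), stored as `sigma_v_frac_err` and applied as a fractional error on the observed lambda.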
- Missing MST uses the prior only (no likelihood).


In [None]:
import numpy as np
import jax
import jax.numpy as jnp
import numpyro

from slcosmo.tools import tool

# Run everything in double precision, both on the JAX side and in NumPyro.
jax.config.update("jax_enable_x64", True)
numpyro.enable_x64()

# Use the GPU when JAX lists one among its devices; otherwise stay on the CPU.
numpyro.set_platform(
    "gpu" if any(dev.platform == "gpu" for dev in jax.devices()) else "cpu"
)

print("JAX devices:", jax.devices())
print("JAX default backend:", jax.default_backend())

# Single NumPy generator shared by the cells below (mock lambda draws and
# optional subsampling), seeded for reproducibility.
SEED = 42
rng = np.random.default_rng(SEED)


In [None]:
# ---------------------------
# Load processed NPZ
# ---------------------------
DATA_NPZ = '/users/tianli/Temp_data/quasar_datavectors_seed6_processed.npz'

# np.load on an .npz archive returns a lazy NpzFile that keeps the file open.
# Every array is read inside the `with` block so the handle is released as soon
# as the data is in memory.  After the block `d` is closed and must not be
# indexed again; use the module-level arrays below instead.
with np.load(DATA_NPZ) as d:
    z_lens = d['z_lens']
    z_src = d['z_src']
    fpd_true = d['fpd_true']
    fpd_err = d['fpd_err']
    td_err = d['td_err']
    sigma_v_obs = d['sigma_v_obs']
    sigma_v_frac_err = d['sigma_v_frac_err']
    block_id = d['block_id']
    lens_id = d['lens_id']
    pair_id = d['pair_id']

# The fractional velocity-dispersion error doubles as the MST error; entries
# that are not finite mark measurements without an MST constraint.
mst_err = sigma_v_frac_err

print('Total time-delay measurements:', z_lens.size)
print('Finite MST entries:', np.isfinite(mst_err).sum())


In [None]:
# ---------------------------
# Lens-level MST error and generated MST true
# ---------------------------
# One integer key per (block, lens) pair; lens ids are assumed to stay below
# 1_000_000 so that keys from different blocks cannot collide.
lens_uid = lens_id.astype(np.int64) + 1_000_000 * block_id.astype(np.int64)

# first_idx: row of the first measurement of each unique lens.
# inv: for every measurement, the position of its lens in unique_uid.
unique_uid, first_idx, inv = np.unique(
    lens_uid,
    return_index=True,
    return_inverse=True,
)

# Per-lens fractional error, taken from each lens's first measurement row.
lens_frac_err = sigma_v_frac_err[first_idx]
mst_mask = np.isfinite(lens_frac_err)

# Mock "true" lambda per lens; the observed value equals the truth (no noise).
lambda_true = rng.normal(1.0, 0.05, size=mst_mask.shape)
lambda_obs = lambda_true
lambda_err = lens_frac_err * np.abs(lambda_obs)

# Replace entries of lenses without an MST measurement by finite placeholders
# (0 for the value, 1 for the error) so the arrays contain no NaN/inf.
lambda_obs_safe = np.where(mst_mask, lambda_obs, 0.0)
lambda_err_safe = np.where(mst_mask, lambda_err, 1.0)

print('Unique lenses:', unique_uid.size)
print('MST observed lenses:', mst_mask.sum())


In [None]:
# ---------------------------
# Compute time-delay true from fpd_true and lambda_true
# ---------------------------
# Fiducial cosmology used to generate the mock time delays.
cosmo_true = {
    'Omegam': 0.32,
    'Omegak': 0.0,
    'w0': -1.0,
    'wa': 0.0,
    'h0': 70.0,
}

zl_j = jnp.asarray(z_lens)
zs_j = jnp.asarray(z_src)

# Distances to the lens, to the source, and between them (presumably angular
# diameter distances in Mpc -- confirm against slcosmo.tools.tool.dldsdls).
Dl, Ds, Dls = tool.dldsdls(zl_j, zs_j, cosmo_true, n=20)

# Time-delay distance without the MST factor, converted back to a NumPy array.
Ddt_geom = np.asarray((1.0 + zl_j) * Dl * Ds / Dls)

# Unit conversion factors: 86400 seconds per day turns a per-second speed into
# a per-day speed; tool.Mpc / 1000 presumably converts metres to kilometres.
c_km_day = tool.c_km_s * 86400.0
Mpc_km = tool.Mpc / 1000.0

# Broadcast the per-lens lambda onto every time-delay measurement of that lens.
lambda_true_obs = lambda_true[inv]
td_true = (Ddt_geom * lambda_true_obs * Mpc_km / c_km_day) * fpd_true


In [None]:
# ---------------------------
# Scale Fermat potentials for stability
# ---------------------------
def scale_phi(phi_obs):
    """Rescale an array by a power of ten so its typical magnitude is near one.

    The typical magnitude is the median absolute value over the entries that
    are finite and non-zero.

    Parameters
    ----------
    phi_obs : numpy.ndarray
        Values to rescale.

    Returns
    -------
    tuple
        ``(scaled, scale)`` with ``scaled = phi_obs * scale`` and ``scale`` a
        power of ten.  When no entry is finite and non-zero (or the median is
        unusable) the input is returned unchanged together with ``1.0``.
    """
    usable = np.isfinite(phi_obs) & (phi_obs != 0)
    if np.any(usable):
        typical = np.median(np.abs(phi_obs[usable]))
        if np.isfinite(typical) and typical != 0:
            # Nearest integer power of ten that maps `typical` to roughly 1.
            power = int(np.round(-np.log10(typical)))
            factor = 10.0 ** power
            return phi_obs * factor, factor
    # Nothing usable to estimate a scale from: leave the data untouched.
    return phi_obs, 1.0

# Bring the Fermat potential differences to order unity and keep the factor so
# the model can undo the scaling.
phi_obs_scaled, phi_scale = scale_phi(fpd_true)

# fpd_err is fractional, so the absolute error scales with |phi_obs_scaled|.
# NOTE: the MCMC cell recomputes this error from fpd_err for the selected
# observations rather than reading phi_err_scaled.
phi_err_scaled = np.abs(phi_obs_scaled) * fpd_err

print('phi_scale:', phi_scale)
print('phi_obs_scaled median:', np.median(np.abs(phi_obs_scaled)))


In [None]:
import numpyro.distributions as dist
from numpyro.infer import NUTS, MCMC
from jax import random

# Bounds of the uniform priors read by cosmology_model, stored as
# '<parameter>_up' / '<parameter>_low' pairs.
cosmo_prior = {
    # Dark-energy equation of state today.
    'w0_up': 0.0,
    'w0_low': -2.0,
    # Evolution of the dark-energy equation of state.
    'wa_up': 2.0,
    'wa_low': -2.0,
    # Curvature density.
    'omegak_up': 1.0,
    'omegak_low': -1.0,
    # Hubble constant.
    'h0_up': 80.0,
    'h0_low': 60.0,
    # Matter density.
    'omegam_up': 0.5,
    'omegam_low': 0.1,
}

def cosmology_model(kind, cosmo_prior, sample_h0=True):
    """Sample cosmological parameters from uniform priors in a NumPyro model.

    ``Omegam`` is always sampled.  The other parameters start from fixed
    values (``Omegak=0.0``, ``w0=-1.0``, ``wa=0.0``, ``h0=70.0``) and are
    replaced by sampled values depending on ``kind`` and ``sample_h0``.

    Parameters
    ----------
    kind : str
        Model label.  ``w0`` is sampled for 'wcdm', 'owcdm', 'waw0cdm' and
        'owaw0cdm'; ``wa`` for 'waw0cdm' and 'owaw0cdm'; ``Omegak`` for
        'owcdm' and 'owaw0cdm'.  Any other label leaves all three fixed.
    cosmo_prior : dict
        Uniform bounds stored under ``'<name>_low'`` / ``'<name>_up'`` for the
        names ``omegam``, ``w0``, ``wa``, ``omegak`` and ``h0``.  Only the
        bounds of parameters that are actually sampled are looked up.
    sample_h0 : bool, optional
        Sample ``h0`` when true (default); otherwise keep it at 70.0.

    Returns
    -------
    dict
        Mapping with keys 'Omegam', 'Omegak', 'w0', 'wa' and 'h0'.
    """
    def _uniform_site(site, prefix):
        # Draw `site` from Uniform(<prefix>_low, <prefix>_up).
        low = cosmo_prior[prefix + '_low']
        high = cosmo_prior[prefix + '_up']
        return numpyro.sample(site, dist.Uniform(low, high))

    cosmo = {
        'Omegam': _uniform_site('Omegam', 'omegam'),
        'Omegak': 0.0,
        'w0': -1.0,
        'wa': 0.0,
        'h0': 70.0,
    }

    # The sample sites are created in the order w0, wa, Omegak, h0.
    if kind in ('wcdm', 'owcdm', 'waw0cdm', 'owaw0cdm'):
        cosmo['w0'] = _uniform_site('w0', 'w0')
    if kind in ('waw0cdm', 'owaw0cdm'):
        cosmo['wa'] = _uniform_site('wa', 'wa')
    if kind in ('owcdm', 'owaw0cdm'):
        cosmo['Omegak'] = _uniform_site('Omegak', 'omegak')
    if sample_h0:
        cosmo['h0'] = _uniform_site('h0', 'h0')
    return cosmo

def quasar_td_model(zl, zs, t_obs, t_err, phi_obs, phi_err, lens_index, lambda_obs, lambda_err, mst_mask, phi_scale):
    """NumPyro model linking time delays, Fermat potentials and MST factors.

    Cosmological parameters are drawn through ``cosmology_model`` with the
    'waw0cdm' label and the module-level ``cosmo_prior``.  The module-level
    ``Mpc_km`` and ``c_km_day`` conversion factors and ``tool.dldsdls`` are
    also read from the enclosing scope.

    Parameters
    ----------
    zl, zs : array-like
        Lens and source redshift of every time-delay measurement.
    t_obs, t_err : array-like
        Observed time delays and the standard deviations of their Gaussian
        likelihood (presumably in days, matching the model prediction).
    phi_obs, phi_err : array-like
        Scaled Fermat potential differences and their standard deviations;
        they define the Gaussian distribution of ``phi_true_scaled``.
    lens_index : array-like of int
        For every measurement, the index of its lens in the per-lens arrays.
    lambda_obs, lambda_err : array-like
        Per-lens observed MST factor and its standard deviation.
    mst_mask : array-like of bool
        Per-lens flag; the MST likelihood term is masked out where false.
    phi_scale : scalar or array-like
        Factor by which ``phi_obs`` was multiplied; the sampled scaled
        potential is divided by it before predicting the delay.
    """
    cosmo = cosmology_model('waw0cdm', cosmo_prior, sample_h0=True)

    # Population-level mean and scatter of the MST factor.
    lambda_mean = numpyro.sample('lambda_mean', dist.Uniform(0.9, 1.1))
    lambda_sig = numpyro.sample('lambda_sig', dist.TruncatedNormal(0.05, 0.5, low=0.0, high=0.2))

    # Convert every input to a JAX array in one pass.
    (
        zl, zs, t_obs, t_err, phi_obs, phi_err,
        lens_index, lambda_obs, lambda_err, mst_mask, phi_scale,
    ) = (
        jnp.asarray(value)
        for value in (
            zl, zs, t_obs, t_err, phi_obs, phi_err,
            lens_index, lambda_obs, lambda_err, mst_mask, phi_scale,
        )
    )

    with numpyro.plate('lens', lambda_obs.shape[0]):
        lambda_true = numpyro.sample('lambda_true', dist.Normal(lambda_mean, lambda_sig))
        # Lenses without an MST measurement contribute no likelihood term, so
        # their lambda is constrained by the population prior only.
        mst_likelihood = dist.Normal(lambda_true, lambda_err).mask(mst_mask)
        numpyro.sample('lambda_like', mst_likelihood, obs=lambda_obs)

    Dl, Ds, Dls = tool.dldsdls(zl, zs, cosmo, n=20)
    # Geometric time-delay distance times the MST factor of each
    # measurement's lens.
    Ddt_true = ((1.0 + zl) * Dl * Ds / Dls) * lambda_true[lens_index]

    with numpyro.plate('td_obs', zl.shape[0]):
        phi_true_scaled = numpyro.sample('phi_true_scaled', dist.Normal(phi_obs, phi_err))
        # Undo the numerical rescaling before predicting the delay.
        phi_true = phi_true_scaled / phi_scale
        t_model_days = (Ddt_true * Mpc_km / c_km_day) * phi_true
        numpyro.sample('t_delay_like', dist.Normal(t_model_days, t_err), obs=t_obs)


In [None]:
# ---------------------------
# Run MCMC (optional)
# ---------------------------
RUN_MCMC = False

MAX_OBS = None  # e.g., 2000

# Pick the measurements fed to the sampler: a random subset without
# replacement when MAX_OBS is set and smaller than the data set, otherwise
# everything.  Lens indices still refer to the full per-lens arrays.
if MAX_OBS is not None and z_lens.size > MAX_OBS:
    idx = rng.choice(z_lens.size, size=MAX_OBS, replace=False)
    z_lens_s, z_src_s = z_lens[idx], z_src[idx]
    t_obs_s, phi_obs_s = td_true[idx], phi_obs_scaled[idx]
    td_frac_s, fpd_frac_s = td_err[idx], fpd_err[idx]
    lens_index_s = inv[idx]
else:
    z_lens_s, z_src_s = z_lens, z_src
    t_obs_s, phi_obs_s = td_true, phi_obs_scaled
    td_frac_s, fpd_frac_s = td_err, fpd_err
    lens_index_s = inv

# Fractional errors become absolute errors on the (noise-free) observables.
t_err_s = td_frac_s * np.abs(t_obs_s)
phi_err_s = fpd_frac_s * np.abs(phi_obs_s)

if RUN_MCMC:
    # NUTS sampler, four chains run in vectorized mode.
    nuts = NUTS(quasar_td_model, target_accept_prob=0.8)
    mcmc = MCMC(
        nuts,
        num_warmup=500,
        num_samples=1000,
        num_chains=4,
        chain_method='vectorized',
    )
    key = random.PRNGKey(0)

    # Keyword arguments forwarded to quasar_td_model.
    model_inputs = {
        'zl': z_lens_s,
        'zs': z_src_s,
        't_obs': t_obs_s,
        't_err': t_err_s,
        'phi_obs': phi_obs_s,
        'phi_err': phi_err_s,
        'lens_index': lens_index_s,
        'lambda_obs': lambda_obs_safe,
        'lambda_err': lambda_err_safe,
        'mst_mask': mst_mask,
        'phi_scale': phi_scale,
    }
    mcmc.run(key, **model_inputs)

    # Plotting dependencies are only needed once sampling has finished.
    import arviz as az
    import matplotlib.pyplot as plt
    import corner

    # Report only the global parameters, not the per-lens / per-measurement
    # latent variables.
    var_names = ['h0', 'Omegam', 'w0', 'wa', 'lambda_mean', 'lambda_sig']
    posterior = mcmc.get_samples(group_by_chain=True)
    inf_data = az.from_dict(posterior=posterior)

    summary = az.summary(inf_data, var_names=var_names)
    print(summary)

    az.plot_trace(inf_data, var_names=var_names, compact=True)
    plt.tight_layout()
    plt.show()

    corner_df = az.extract(inf_data, var_names=var_names).to_dataframe()
    corner.corner(corner_df, labels=var_names)
    plt.show()
