# Combine_forecast steps

This notebook breaks the `Combine_forecast.py` pipeline into sequential steps for easier debugging.

In [None]:
import os
import sys
from pathlib import Path

# Work from the nested "LensedUniverse" checkout when the notebook is started
# from its parent directory; otherwise stay in the current directory.
repo_root = Path.cwd()
nested_checkout = repo_root / "LensedUniverse"
workdir = nested_checkout if nested_checkout.is_dir() else repo_root

# Make relative data paths and project-local imports resolve against workdir.
os.chdir(workdir)
sys.path.insert(0, str(workdir))

# Default test mode and data locations. ``setdefault`` leaves any value that is
# already exported in the environment untouched.
for env_name, env_default in (
    ("COMBINE_FORECAST_TEST", "1"),
    ("SLCOSMO_DATA_DIR", "../slcosmo"),
    ("OTHER_FORECAST_DIR", "../SLCOSMO/other_forecast"),
):
    os.environ.setdefault(env_name, env_default)

## 1. Imports and setup

In [None]:
import jax
import jax.numpy as jnp
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import corner
import arviz as az

import numpyro
import numpyro.distributions as dist
from numpyro.infer import NUTS, MCMC
from jax import random

from SLCOSMO import SLCOSMO, SLmodel, tool

# Run-time switches; every one of them can be overridden from the environment.
TEST_MODE = os.environ.get("COMBINE_FORECAST_TEST") == "1"
DATA_DIR = os.environ.get("SLCOSMO_DATA_DIR", os.path.join("..", "slcosmo"))
OTHER_FORECAST_DIR = os.environ.get("OTHER_FORECAST_DIR", os.path.join("..", "SLCOSMO", "other_forecast"))
# Compute backend handed to NumPyro. Defaults to "gpu" (the previously
# hard-coded value); export COMBINE_FORECAST_PLATFORM=cpu to step through the
# notebook on a machine without an accelerator.
PLATFORM = os.environ.get("COMBINE_FORECAST_PLATFORM", "gpu")

# Backend and precision must be configured before any JAX array is created,
# hence before the SLCOSMO objects below are instantiated.
jax.config.update("jax_enable_x64", True)
numpyro.set_platform(PLATFORM)
numpyro.enable_x64()

slcosmo = SLCOSMO()
model_instance = SLmodel(slcosmo)

# Two seeded NumPy random streams are used by the mock-data cells: the
# Generator ``rng_np`` and the legacy global ``np.random`` state.
SEED = 42
rng_np = np.random.default_rng(SEED)
np.random.seed(SEED)

## 2. Cosmology model & priors

In [None]:
def cosmology_model(kind, cosmo_prior, sample_h0=True):
    """Draw the cosmological parameters of the requested model family.

    ``Omegam`` is always sampled.  ``w0`` is sampled for "wcdm", "owcdm",
    "waw0cdm" and "owaw0cdm"; ``wa`` for "waw0cdm" and "owaw0cdm"; ``Omegak``
    for "owcdm" and "owaw0cdm".  Parameters that are not sampled keep their
    fixed values (Omegak=0, w0=-1, wa=0, h0=70).

    Args:
        kind: Name of the cosmological model family.
        cosmo_prior: Mapping holding the ``<name>_low`` / ``<name>_up`` bounds
            of the uniform priors.
        sample_h0: When true, ``h0`` is sampled instead of being fixed to 70.

    Returns:
        Dict with keys "Omegam", "Omegak", "w0", "wa" and "h0".
    """

    def _flat(site, low_key, up_key):
        # Uniform prior between the two bounds stored in ``cosmo_prior``.
        return numpyro.sample(site, dist.Uniform(cosmo_prior[low_key], cosmo_prior[up_key]))

    # Sample sites are declared in the order Omegam, w0, wa, Omegak, h0.
    omegam = _flat("Omegam", "omegam_low", "omegam_up")
    w0 = _flat("w0", "w0_low", "w0_up") if kind in ("wcdm", "owcdm", "waw0cdm", "owaw0cdm") else -1.0
    wa = _flat("wa", "wa_low", "wa_up") if kind in ("waw0cdm", "owaw0cdm") else 0.0
    omegak = _flat("Omegak", "omegak_low", "omegak_up") if kind in ("owcdm", "owaw0cdm") else 0.0
    h0 = _flat("h0", "h0_low", "h0_up") if sample_h0 else 70.0

    return {"Omegam": omegam, "Omegak": omegak, "w0": w0, "wa": wa, "h0": h0}

# Bounds of the uniform priors read by ``cosmology_model``; each parameter has
# an ``_up`` and a ``_low`` entry.
cosmo_prior = dict(
    w0_up=0.0,
    w0_low=-2.0,
    wa_up=2.0,
    wa_low=-2.0,
    omegak_up=1.0,
    omegak_low=-1.0,
    h0_up=80.0,
    h0_low=60.0,
    omegam_up=0.5,
    omegam_low=0.1,
)

# Fiducial cosmology used to generate every mock data set in this notebook.
cosmo_true = {
    "Omegam": 0.32,
    "Omegak": 0.0,
    "w0": -1.0,
    "wa": 0.0,
    "h0": 70.0,
}

## 3. DSPL mock data

In [None]:
# Build the double-source-plane-lens (DSPL) mock sample from the catalogue file.
# Rows whose column 5 is >= 0.95 are discarded.
# NOTE(review): the physical meaning of column 5 is not visible here — confirm.
data_dspl = np.loadtxt(os.path.join(DATA_DIR, "EuclidDSPLs_1.txt"))
data_dspl = data_dspl[(data_dspl[:, 5] < 0.95)]

# Columns 0/1/2 are used as the lens, first-source and second-source redshifts.
zl_dspl  = data_dspl[:, 0]
zs1_dspl = data_dspl[:, 1]
zs2_true_cat = data_dspl[:, 2]

# Column 6 is used as the beta uncertainty, column 11 as the model velocity.
beta_err_dspl = data_dspl[:, 6]
model_vel_dspl = data_dspl[:, 11]

# Keep only systems whose second source lies behind the first one.
m_ok = (zs2_true_cat > zs1_dspl)
zl_dspl  = zl_dspl[m_ok]
zs1_dspl = zs1_dspl[m_ok]
zs2_true_cat = zs2_true_cat[m_ok]
beta_err_dspl = beta_err_dspl[m_ok]
model_vel_dspl = model_vel_dspl[m_ok]

# Each system is flagged photometric with probability 0.60; flagged systems get
# a redshift error of 0.1, the others 1e-4.
N_dspl = len(zl_dspl)
is_photo = (rng_np.random(N_dspl) < 0.60)
zs2_err = np.where(is_photo, 0.1, 1e-4)
zs2_obs = zs2_true_cat + rng_np.normal(0.0, zs2_err)

# Redraw (at most 20 rounds) any noisy zs2 that does not exceed zs1 + eps, then
# clamp whatever is still too low so that zs2_obs >= zs1 + eps always holds.
eps = 1e-3
bad = zs2_obs <= (zs1_dspl + eps)
for _ in range(20):
    if not np.any(bad):
        break
    zs2_obs[bad] = zs2_true_cat[bad] + rng_np.normal(0.0, zs2_err[bad])
    bad = zs2_obs <= (zs1_dspl + eps)
zs2_obs = np.maximum(zs2_obs, zs1_dspl + eps)

# Distance ratio of the two source planes under the fiducial cosmology,
# evaluated at the catalogue (noise-free) second-source redshift.
Dl1, Ds1, Dls1 = tool.compute_distances(zl_dspl, zs1_dspl, cosmo_true)
Dl2, Ds2, Dls2 = tool.compute_distances(zl_dspl, zs2_true_cat, cosmo_true)
beta_geom_dspl = Dls1 * Ds2 / (Ds1 * Dls2)

# Per-system lambda with a 6% error.
# NOTE(review): the positional arguments of tool.truncated_normal are presumably
# (mean, sigma, low, high, size) — confirm against the helper.
# The noisy realisations below draw from the global ``np.random`` state, while
# the other draws in this cell use the ``rng_np`` Generator.
lambda_true_dspl = tool.truncated_normal(1.0, 0.05, 0.85, 1.15, N_dspl, random_state=rng_np)
lambda_err_dspl = lambda_true_dspl * 0.06
lambda_obs_dspl = lambda_true_dspl + np.random.normal(0.0, lambda_err_dspl)

# Velocity scaled by sqrt(lambda), with a 3% error.
true_vel_dspl = model_vel_dspl * jnp.sqrt(lambda_true_dspl)
vel_err_dspl = 0.03 * true_vel_dspl
obs_vel_dspl = true_vel_dspl + np.random.normal(0.0, vel_err_dspl)

# beta after applying lambda through tool.beta_antimst, plus a noisy draw
# restricted to [0, 1].
beta_true_dspl = tool.beta_antimst(beta_geom_dspl, mst=lambda_true_dspl)
beta_obs_dspl = tool.truncated_normal(beta_true_dspl, beta_err_dspl, 0.0, 1.0, random_state=rng_np)

# NOTE(review): "zs2_obs", "beta_obs" and "lambda_obs" store the noise-free
# values; the noisy zs2_obs, beta_obs_dspl and lambda_obs_dspl computed above are
# not used.  Only "v_obs" carries noise.  Presumably a deliberate noise-free
# forecast — confirm.
dspl_data = {
    "zl": zl_dspl,
    "zs1": zs1_dspl,
    "zs2_cat": zs2_true_cat,
    "zs2_obs": zs2_true_cat,
    "zs2_err": zs2_err,
    "is_photo": is_photo.astype(np.int32),
    "beta_obs": beta_true_dspl,
    "beta_err": beta_err_dspl,
    "v_model": model_vel_dspl,
    "v_obs": obs_vel_dspl,
    "v_err": vel_err_dspl,
    "lambda_err": lambda_err_dspl,
    "lambda_obs": lambda_true_dspl,
}

# Module-level switch read by joint_model: when true, zs2 is sampled as a latent
# variable instead of being fixed to the catalogue value.
photo_z = True

## 4. Lens + kinematics mock data

In [None]:
# Build the lens + kinematics mock sample.
# The 4-D look-up table is wrapped in an interpolant over evenly spaced grids,
# one grid per table axis.
# NOTE(review): axis meanings (thetaE, gamma, Re, beta) are taken from the
# variable names; units and the grid ranges are not verifiable from here.
LUT = np.load(os.path.join(DATA_DIR, "velocity_disp_table.npy"))
N1, N2, N3, N4 = LUT.shape
thetaE_grid = np.linspace(0.5, 3.0, N1)
gamma_grid  = np.linspace(1.2, 2.8, N2)
Re_grid     = np.linspace(0.15, 3.0, N3)
beta_grid   = np.linspace(-0.5, 0.8, N4)
jampy_interp = tool.make_4d_interpolant(thetaE_grid, gamma_grid, Re_grid, beta_grid, LUT)

# Catalogue columns 0/1/2/5 are used as zl, zs, Einstein radius and Re.
Euclid_GG_data = np.loadtxt(os.path.join(DATA_DIR, "Euclid_len.txt"))
zl_lens = Euclid_GG_data[:, 0]
zs_lens = Euclid_GG_data[:, 1]
Ein_lens = Euclid_GG_data[:, 2]
re_lens = Euclid_GG_data[:, 5]

# Selection: Einstein radius >= 0.6 and 0.25 <= Re <= 2.8.
# NOTE(review): there is no upper cut on the Einstein radius although
# thetaE_grid stops at 3.0 — confirm how the interpolant treats larger values.
mask_lens = (Ein_lens >= 0.6) & (re_lens >= 0.25) & (re_lens <= 2.8)
zl_lens = zl_lens[mask_lens]
zs_lens = zs_lens[mask_lens]
thetaE_lens = Ein_lens[mask_lens]
re_lens = re_lens[mask_lens]

# Distances under the fiducial cosmology.
dl_lens, ds_lens, dls_lens = tool.dldsdls(zl_lens, zs_lens, cosmo_true, n=20)
N_lens = len(zl_lens)

# Per-lens gamma, beta and lambda, and the resulting velocity:
# interpolated value * sqrt(ds / dls) * sqrt(lambda).
gamma_true_lens = tool.truncated_normal(2.0, 0.2, 1.5, 2.5, N_lens, random_state=rng_np)
beta_true_lens  = tool.truncated_normal(0.0, 0.2, -0.4, 0.4, N_lens, random_state=rng_np)
vel_model_lens = jampy_interp(thetaE_lens, gamma_true_lens, re_lens, beta_true_lens) * jnp.sqrt(ds_lens / dls_lens)
lambda_true_lens = tool.truncated_normal(1.0, 0.05, 0.8, 1.2, N_lens, random_state=rng_np)
vel_true_lens = vel_model_lens * jnp.sqrt(lambda_true_lens)

# Noisy realisations: 1% error on theta_E and 10% on the velocity; the gamma
# offset comes from tool.truncated_normal.  The theta_E and velocity draws use
# the global ``np.random`` state.
gamma_obs_lens = gamma_true_lens + tool.truncated_normal(0.0, 0.05, -0.2, 0.2, N_lens, random_state=rng_np)
theta_E_err = 0.01 * thetaE_lens
thetaE_lens_obs = thetaE_lens + np.random.normal(0.0, theta_E_err)
vel_err_lens = 0.10 * vel_true_lens
vel_obs_lens = np.random.normal(vel_true_lens, vel_err_lens)

# NOTE(review): "theta_E", "gamma_obs" and "vel_obs" store the noise-free
# values; thetaE_lens_obs, gamma_obs_lens and vel_obs_lens computed above are not
# used.  Presumably a deliberate noise-free forecast — confirm.
lens_data = {
    "zl": zl_lens,
    "zs": zs_lens,
    "theta_E": thetaE_lens,
    "theta_E_err": theta_E_err,
    "re": re_lens,
    "gamma_obs": gamma_true_lens,
    "vel_obs": vel_true_lens,
    "vel_err": vel_err_lens,
}

## 5. Lensed SNe mock data

In [None]:
# Build the lensed-supernova mock sample: keep rows with 5 <= tmax <= 80, then
# the 70 rows with the largest tmax.
sn_data = pd.read_csv(os.path.join(DATA_DIR, "Euclid_150SNe.csv"))
sn_data = sn_data[(sn_data["tmax"] >= 5) & (sn_data["tmax"] <= 80)]
sn_data = sn_data.nlargest(70, "tmax")
zl_sne = np.array(sn_data["zl"])
zs_sne = np.array(sn_data["z_host"])
tmax_sne = np.array(sn_data["tmax"])

# Ddt = (1 + zl) * Dl * Ds / Dls under the fiducial cosmology.
Dl_sne, Ds_sne, Dls_sne = tool.dldsdls(zl_sne, zs_sne, cosmo_true, n=20)
Ddt_geom_sne = (1.0 + zl_sne) * Dl_sne * Ds_sne / Dls_sne

# Per-system lambda rescales the geometric Ddt.
N_sne = len(zl_sne)
lambda_true_sne = tool.truncated_normal(1.0, 0.05, 0.8, 1.2, N_sne, random_state=rng_np)
Ddt_true_sne = Ddt_geom_sne * lambda_true_sne

# Fractional Ddt error: 1/tmax and a 5% floor added in quadrature.
# The noisy draws below use the global ``np.random`` state.
frac_err_Ddt = np.sqrt((1.0 / tmax_sne) ** 2 + 0.05 ** 2)
Ddt_obs_sne = Ddt_true_sne * np.random.normal(1.0, frac_err_Ddt)
Ddt_err_sne = Ddt_true_sne * frac_err_Ddt

# 8% error on lambda.
lambda_obs_sne = lambda_true_sne * np.random.normal(1.0, 0.08, N_sne)
lambda_err_sne = 0.08 * lambda_true_sne

# NOTE(review): "Ddt_obs" and "lambda_obs" store the noise-free values;
# Ddt_obs_sne and lambda_obs_sne computed above are not used.  Presumably a
# deliberate noise-free forecast — confirm.
sne_data = {
    "zl": zl_sne,
    "zs": zs_sne,
    "Ddt_obs": Ddt_true_sne,
    "Ddt_err": Ddt_err_sne,
    "lambda_obs": lambda_true_sne,
    "lambda_err": lambda_err_sne,
}

## 6. Joint model

In [None]:
def joint_model(dspl_data=None, lens_data=None, sne_data=None):
    """NumPyro model combining the DSPL, lens-kinematics and lensed-SNe likelihoods.

    Cosmological parameters come from ``cosmology_model("waw0cdm", ...)`` with
    ``h0`` sampled.  The hyper-parameters ``lambda_mean`` / ``lambda_sig`` are
    shared by all three probes; the ``gamma_*`` and ``beta_*`` hyper-parameters
    are used only by the lens block.  A probe is skipped when its argument is
    ``None``.

    Reads the module-level names ``cosmo_prior``, ``photo_z``, ``jampy_interp``
    and ``tool``.

    Args:
        dspl_data: Mapping with keys "zl", "zs1", "zs2_obs", "zs2_err",
            "lambda_obs", "lambda_err", "beta_obs", "beta_err", and "zs2_cat"
            (read only when ``photo_z`` is false).
        lens_data: Mapping with keys "zl", "zs", "theta_E", "theta_E_err",
            "re", "gamma_obs", "vel_obs", "vel_err".
        sne_data: Mapping with keys "zl", "zs", "Ddt_obs", "Ddt_err",
            "lambda_obs", "lambda_err".

    Returns:
        None; the function only registers NumPyro sample sites.
    """
    # Cosmology and population-level hyper-parameters.
    cosmo = cosmology_model("waw0cdm", cosmo_prior, sample_h0=True)
    lambda_mean = numpyro.sample("lambda_mean", dist.Uniform(0.9, 1.1))
    lambda_sigma = numpyro.sample("lambda_sig", dist.TruncatedNormal(0.05, 0.5, low=0.0, high=0.2))

    gamma_mean = numpyro.sample("gamma_mean", dist.Uniform(1.8, 2.2))
    gamma_sigma = numpyro.sample("gamma_sigma", dist.TruncatedNormal(0.2, 0.5, low=0.0, high=0.4))
    beta_mean  = numpyro.sample("beta_mean", dist.Uniform(-0.1, 0.1))
    beta_sigma = numpyro.sample("beta_sigma", dist.TruncatedNormal(0.2, 0.5, low=0.0, high=0.4))

    # --- DSPL block ---
    if dspl_data is not None:
        N_dspl = len(dspl_data["zl"])

        zl  = jnp.asarray(dspl_data["zl"])
        zs1 = jnp.asarray(dspl_data["zs1"])
        zs2_obs = jnp.asarray(dspl_data["zs2_obs"])
        zs2_err = jnp.asarray(dspl_data["zs2_err"])

        Dl1, Ds1, Dls1 = tool.compute_distances(zl, zs1, cosmo)

        if photo_z:
            # Latent second-source redshift, forced to lie behind the first
            # source (zs1 + eps) and below z = 10.  Declared as a single
            # vector-valued site via to_event(1), outside the "dspl" plate.
            eps = 1e-3
            zs2_true = numpyro.sample(
                "zs2_true",
                dist.TruncatedNormal(zs2_obs, zs2_err, low=zs1 + eps, high=10.0).to_event(1)
            )
        else:
            # Redshift fixed to the catalogue value.
            zs2_true = dspl_data["zs2_cat"]
        Dl2, Ds2, Dls2 = tool.compute_distances(zl, zs2_true, cosmo)
        beta_geom = Dls1 * Ds2 / (Ds1 * Dls2)

        with numpyro.plate("dspl", N_dspl):
            # Per-system lambda drawn from the shared population distribution.
            lambda_dspl = numpyro.sample(
                "lambda_dspl",
                dist.TruncatedNormal(lambda_mean, lambda_sigma, low=0.8, high=1.2),
            )
            numpyro.sample(
                "lambda_dspl_like",
                dist.Normal(lambda_dspl, jnp.asarray(dspl_data["lambda_err"])),
                obs=jnp.asarray(dspl_data["lambda_obs"]),
            )
            # beta likelihood restricted to [0, 1].
            beta_mst = tool.beta_antimst(beta_geom, lambda_dspl)
            numpyro.sample(
                "beta_dspl_like",
                dist.TruncatedNormal(beta_mst, jnp.asarray(dspl_data["beta_err"]), low=0.0, high=1.0),
                obs=jnp.asarray(dspl_data["beta_obs"]),
            )

    # --- Lens + kinematics block ---
    if lens_data is not None:
        dl_lens, ds_lens, dls_lens = tool.dldsdls(lens_data["zl"], lens_data["zs"], cosmo, n=20)
        N_lens = len(lens_data["zl"])
        with numpyro.plate("lens", N_lens):
            gamma_i = numpyro.sample(
                "gamma_i",
                dist.TruncatedNormal(gamma_mean, gamma_sigma, low=1.6, high=2.4),
            )
            beta_i = numpyro.sample(
                "beta_i",
                dist.TruncatedNormal(beta_mean, beta_sigma, low=-0.4, high=0.4),
            )
            lambda_lens = numpyro.sample(
                "lambda_lens",
                dist.TruncatedNormal(lambda_mean, lambda_sigma, low=0.8, high=1.2),
            )
            # theta_E is a latent site centred on the supplied value (no obs=).
            theta_E_i = numpyro.sample(
                "theta_E_i",
                dist.Normal(lens_data["theta_E"], lens_data["theta_E_err"]),
            )
            # Predicted velocity: interpolated table value scaled by
            # sqrt(ds / dls) and sqrt(lambda).
            v_interp = jampy_interp(theta_E_i, gamma_i, lens_data["re"], beta_i)
            vel_pred = v_interp * jnp.sqrt(ds_lens / dls_lens) * jnp.sqrt(lambda_lens)

            numpyro.sample(
                "gamma_obs_lens",
                dist.Normal(gamma_i, 0.05),
                obs=lens_data["gamma_obs"],
            )
            numpyro.sample(
                "vel_lens_like",
                dist.Normal(vel_pred, lens_data["vel_err"]),
                obs=lens_data["vel_obs"],
            )

    # --- Lensed SNe block ---
    if sne_data is not None:
        Dl_sne, Ds_sne, Dls_sne = tool.dldsdls(sne_data["zl"], sne_data["zs"], cosmo, n=20)
        Ddt_geom = (1.0 + sne_data["zl"]) * Dl_sne * Ds_sne / Dls_sne
        N_sne = len(sne_data["zl"])
        with numpyro.plate("sne", N_sne):
            # NOTE(review): lambda_sne uses an untruncated Normal, whereas
            # lambda_dspl and lambda_lens use TruncatedNormal(low=0.8, high=1.2)
            # — confirm whether the difference is intentional.
            lambda_sne = numpyro.sample("lambda_sne", dist.Normal(lambda_mean, lambda_sigma))
            numpyro.sample(
                "lambda_sne_like",
                dist.Normal(lambda_sne, sne_data["lambda_err"]),
                obs=sne_data["lambda_obs"],
            )
            Ddt_true = Ddt_geom * lambda_sne
            numpyro.sample(
                "Ddt_sne_like",
                dist.Normal(Ddt_true, sne_data["Ddt_err"]),
                obs=sne_data["Ddt_obs"],
            )

def head_dict(data_dict, N_use=None):
    """Return a copy of ``data_dict`` keeping only the leading entries of each value.

    Every value is converted with ``np.asarray`` and cut to its first ``N_use``
    elements along the leading axis; ``N_use=None`` keeps everything.

    Args:
        data_dict: Mapping from name to an array-like value.
        N_use: Number of leading entries to keep, or ``None`` for all.

    Returns:
        New dict with the same keys (same order) and NumPy array values.
    """
    head = slice(None, N_use)
    truncated = {}
    for key, values in data_dict.items():
        truncated[key] = np.asarray(values)[head]
    return truncated

## 7. Run MCMC (test-mode defaults)

In [None]:
# Subsample sizes and sampler settings: a small, quick configuration in test
# mode, the full-size run otherwise.
if TEST_MODE:
    N_DSPL_USE = 50
    N_LENS_USE = 200
    N_SNE_USE = 10
    num_warmup = 200
    num_samples = 200
    num_chains = 2
    chain_method = "sequential"
else:
    N_DSPL_USE = 1200
    N_LENS_USE = 5000
    N_SNE_USE = 50
    num_warmup = 2000
    num_samples = 5000
    num_chains = 8
    chain_method = "vectorized"

# Resolve the output stem and make sure its directory exists *before* sampling,
# so that a missing or unwritable results directory fails immediately instead
# of after the whole MCMC run has finished.
nc_filename = "combine_forecast_test_output" if TEST_MODE else "/mnt/lustre/tianli/slcosmo_result/Lens_revolution"
os.makedirs(os.path.dirname(nc_filename) or ".", exist_ok=True)

# Keep only the leading N_*_USE systems of each probe.
dspl_data = head_dict(dspl_data, N_DSPL_USE)
lens_data = head_dict(lens_data, N_LENS_USE)
sne_data  = head_dict(sne_data,  N_SNE_USE)

# Starting point for the global parameters; sites not listed here fall back to
# the default initialisation of ``init_to_value``.
init_values = {
    "h0": 70.0,
    "Omegam": 0.32,
    "w0": -1.0,
    "wa": 0.0,
    "lambda_mean": 1.0,
    "lambda_sig": 0.05,
    "gamma_mean": 2.0,
    "gamma_sigma": 0.2,
    "beta_mean": 0.0,
    "beta_sigma": 0.2,
}

from numpyro.infer import init_to_value
init_strategy = init_to_value(values=init_values)

# Dense mass-matrix block for the listed global parameters; every other site
# keeps a diagonal mass matrix.
nuts_kernel = NUTS(
    joint_model,
    target_accept_prob=0.8,
    dense_mass=[("wa", "w0", "h0", "Omegam", "lambda_mean")],
    init_strategy=init_strategy,
)
mcmc = MCMC(
    nuts_kernel,
    num_warmup=num_warmup,
    num_samples=num_samples,
    num_chains=num_chains,
    chain_method=chain_method,
    progress_bar=True,
)

rng_key = random.PRNGKey(0)
mcmc.run(rng_key, dspl_data=dspl_data, sne_data=sne_data, lens_data=lens_data)

# Pull the samples and sampler statistics to host memory, grouped by chain, and
# wrap them in an ArviZ InferenceData object.
posterior = jax.device_get(mcmc.get_samples(group_by_chain=True))
sample_stats = jax.device_get(mcmc.get_extra_fields(group_by_chain=True))
inf_data = az.from_dict(posterior=posterior, sample_stats=sample_stats)

# Persist the raw chains first: the summary is computed over every sampled site
# and must not be able to prevent the samples from being written.
az.to_netcdf(inf_data, nc_filename + ".nc")
summary_df = az.summary(inf_data)
summary_df.to_csv(nc_filename + "_summary.csv")

print("Saved:", nc_filename + "_summary.csv")
print("Saved:", nc_filename + ".nc")