In [None]:
# Export alpha‑stable SV raw data for NPE‑RS (θ, x, x_obs) as .npy

from pathlib import Path
import sys

import jax
import jax.numpy as jnp
import numpy as np

# Locate the repository root once: if the working directory is named
# "notebooks" the repo is taken to be its parent, otherwise the cwd itself.
ROOT = Path.cwd()
REPO = ROOT.parent if ROOT.name == "notebooks" else ROOT

# Try each project import as-is; if the module cannot be found, extend
# sys.path with the repo locations and retry.
try:
    from precond_npe_misspec.examples.alpha_stable_sv import assumed_dgp, prior_sample
except ModuleNotFoundError:
    sys.path.extend([str(REPO / "src"), str(REPO)])
    from precond_npe_misspec.examples.alpha_stable_sv import assumed_dgp, prior_sample

try:
    from data.markets import load_sp500_returns_yahoo
except ModuleNotFoundError:
    sys.path.append(str(REPO))
    from data.markets import load_sp500_returns_yahoo

# ---------------- Config ----------------
N_SIMS = 20_000  # total number of (θ, x) pairs to simulate; adjust as needed
SEED = 0  # seed for the base PRNG key
THETA1 = 0.0  # forwarded to assumed_dgp as `theta1`
Y_START = "2013-01-02"  # `start` argument for the returns loader
Y_END = "2017-02-07"  # `end` argument for the returns loader
BATCH = 2048  # simulations per jitted batch call
OUTDIR = REPO.joinpath("data", "alpha_sv_raw_npe_rs")  # where the .npy files go
OUTDIR.mkdir(parents=True, exist_ok=True)

# --------------- Observed x ---------------
# Log returns over [Y_START, Y_END], requested without standardisation.
r_obs = load_sp500_returns_yahoo(
    start=Y_START,
    end=Y_END,
    log_returns=True,
    standardise=False,
)
T = int(r_obs.shape[0])  # observed length; simulations are run with the same T
x_obs = np.asarray(r_obs[:T], dtype=np.float32)  # (T,)

# --------------- Simulate (θ, x) ---------------
# Split the seed key into two base keys: one for parameter draws, one for data.
key = jax.random.key(SEED)
k_th_base, k_x_base = jax.random.split(key, 2)


@jax.jit
def _simulate_batch(
    th_keys: jax.Array, x_keys: jax.Array
) -> tuple[jax.Array, jax.Array]:
    """Sample a batch of parameters and simulate one series per sample.

    Args:
        th_keys: Batch of PRNG keys; ``prior_sample`` is vmapped over them.
        x_keys: Batch of PRNG keys; ``assumed_dgp`` is vmapped over them,
            paired element-wise with the sampled parameters.

    Returns:
        A ``(thetas, xs)`` tuple holding the batched ``prior_sample`` output
        and the batched ``assumed_dgp`` output. The module-level ``T`` and
        ``THETA1`` are read when the function is traced.
    """

    def _simulate_one(sim_key, theta_i):
        # Single-draw simulator with the global length and theta1 fixed.
        return assumed_dgp(sim_key, theta_i, T=T, theta1=THETA1)

    sampled_thetas = jax.vmap(prior_sample)(th_keys)
    simulated_xs = jax.vmap(_simulate_one)(x_keys, sampled_thetas)
    return sampled_thetas, simulated_xs


# Each simulation's keys come from folding its global index into the base
# keys, so the key for draw i depends only on i and not on the batch layout.
fold_many = jax.vmap(jax.random.fold_in, in_axes=(None, 0))

theta_parts, x_parts = [], []
for lo in range(0, N_SIMS, BATCH):
    hi = min(lo + BATCH, N_SIMS)  # the final batch may be shorter than BATCH
    sim_ids = jnp.arange(lo, hi, dtype=jnp.uint32)
    batch_thetas, batch_xs = _simulate_batch(
        fold_many(k_th_base, sim_ids),
        fold_many(k_x_base, sim_ids),
    )
    # Convert each batch to a float32 NumPy array before accumulating.
    theta_parts.append(np.asarray(batch_thetas, dtype=np.float32))
    x_parts.append(np.asarray(batch_xs, dtype=np.float32))

# Stitch the per-batch pieces together along the simulation axis.
theta = np.concatenate(theta_parts, axis=0)  # (N_SIMS,3)
x = np.concatenate(x_parts, axis=0)  # (N_SIMS,T)

# --------------- Save .npy ---------------
# File name -> array, in the order they are written and reported.
outputs = {"theta.npy": theta, "x.npy": x, "x_obs.npy": x_obs}

for fname, arr in outputs.items():
    np.save(OUTDIR / fname, arr)

print("Saved:", OUTDIR)
for fname, arr in outputs.items():
    print(fname, arr.shape, arr.dtype)

In [None]:
# Display x standardised by its global NaN-ignoring mean and standard
# deviation (one mean/std over the whole array, not per series).
np.divide(np.subtract(x, np.nanmean(x)), np.nanstd(x))