# PSF Modelling with SVI (STPSF + Pixel Correction)

这个 notebook 使用 `Herculens_Tian_JWST.ipynb` 同款的 SVI 思路，
对 29 个 `psf_data/*SCIERR.fits` 联合建模。

建模结构：
- `PSF_model_ss = STPSF_base_ss + correction_ss`
- `correction_ss` 来自 `matern_power_spectrum + white-noise Fourier modes`
- 在 supersampled 网格建模后，`resize` 到 detector 101x101
- 每个观测星点的采样自由度：`x_pos, y_pos, background`；flux 不参与采样，而是对每颗星用加权最小二乘解析求解


In [None]:
%load_ext autoreload
%autoreload 2

import os
import glob
import warnings
warnings.simplefilter("ignore")

import numpy as np
import jax
import jax.numpy as jnp
from jax.scipy.ndimage import map_coordinates

import numpyro
import numpyro.distributions as dist
import numpyro.infer as infer
import numpyro.infer.autoguide as autoguide
import optax

import matplotlib.pyplot as plt
import arviz as az
from astropy.io import fits

# Use 64-bit floats in both JAX and NumPyro.
jax.config.update("jax_enable_x64", True)
numpyro.enable_x64()

# Project-local helpers used by the cells below.
from herculens_import_main import (
    matern_power_spectrum,
    split_scheduler,
    SVI_vec,
    K_grid,
    get_pixel_grid,
)
from herculens.PointSourceModel.point_source_model import PointSourceModel

# Show which devices JAX will run on.
print("JAX devices:", jax.devices())

In [None]:
# -----------------------------
# Config
# -----------------------------
# Pixel scale handed to get_pixel_grid (presumably arcsec/pixel — TODO confirm units).
PIX_SCALE = 0.031
DATA_DIR = "../Data/WFI2033"
# Directory that is globbed for the per-star "*SCIERR.fits" cutouts.
PSF_DATA_DIR = "./psf_data"
BASE_PSF_PATH = os.path.join(DATA_DIR, "F115W_PSF_stpsf_ss2.fits")
BASE_PSF_EXT = "OVERSAMP"   # supersampled extension in STPSF file
RESULT_DIR = "./result"
# All FITS/NPZ products of this notebook are written here.
OUTPUT_DIR = RESULT_DIR

# Seed for the JAX PRNG key used by SVI, posterior sampling and HMC.
SEED = 123
# Number of SVI optimisation steps (also passed to the learning-rate scheduler).
MAX_ITERATIONS = 8000
# Number of SVI runs passed to SVI_vec.run; the one with the lowest final loss is kept.
NUM_CHAINS = 4
# Number of draws taken from the fitted guide for posterior summaries.
NUM_POST_SAMPLES = 300

# Block-averaging factor between the supersampled PSF grid and the detector grid.
SS_FACTOR = 2

# Truncated-normal position priors, centred on zero: scale and hard bounds.
# Units follow the coordinate system of pixel_grid (presumably arcsec — TODO confirm).
XPOS_PRIOR_SIGMA = 0.03
YPOS_PRIOR_SIGMA = 0.03
XPOS_BOUNDS = (-0.30, 0.30)
YPOS_BOUNDS = (-0.30, 0.30)

os.makedirs(OUTPUT_DIR, exist_ok=True)
print("OUTPUT_DIR:", OUTPUT_DIR)

In [None]:
# -----------------------------
# Load 29 SCIERR cutouts
# -----------------------------
psf_files = sorted(glob.glob(os.path.join(PSF_DATA_DIR, "*SCIERR.fits")))
if not psf_files:
    raise FileNotFoundError(f"No SCIERR FITS found in {PSF_DATA_DIR}")

print(f"Found {len(psf_files)} cutouts")

# Width (in pixels) of the edge frame used to estimate each cutout's background.
border_width = 5

sci_list = []
err_list = []
flux_loc_list = []
bkg_loc_list = []
bkg_scale_list = []

for fp in psf_files:
    with fits.open(fp, memmap=True) as hdul:
        # Copy the data out of the HDUs before the file is closed.
        sci = np.array(hdul["SCI"].data, dtype=np.float64)
        err = np.array(hdul["ERR"].data, dtype=np.float64)

    if sci.shape != err.shape:
        raise ValueError(f"Shape mismatch in {fp}: SCI{ sci.shape } vs ERR{ err.shape }")
    # np.stack below needs identical shapes; fail here so the offending file is named.
    if sci_list and sci.shape != sci_list[0].shape:
        raise ValueError(
            f"Cutout shape {sci.shape} in {fp} differs from first cutout {sci_list[0].shape}"
        )

    valid = np.isfinite(sci) & np.isfinite(err) & (err > 0)
    if not np.any(valid):
        raise ValueError(f"No valid pixels in {fp}")

    med_err = float(np.nanmedian(err[valid]))

    # Background level and scatter from the edge frame. Only valid pixels are used
    # (zero-filled bad pixels would bias the median/MAD) and every edge pixel is
    # counted once (no duplicated corners).
    edge = np.zeros(sci.shape, dtype=bool)
    edge[:border_width, :] = True
    edge[-border_width:, :] = True
    edge[:, :border_width] = True
    edge[:, -border_width:] = True
    border = sci[edge & valid]
    if border.size == 0:
        # No usable edge pixel: fall back to all valid pixels of the cutout.
        border = sci[valid]
    bkg = float(np.median(border))
    mad = float(np.median(np.abs(border - bkg)))
    # 1.4826 * MAD is the Gaussian-equivalent sigma; floors keep the prior scale positive.
    bkg_sigma = max(1.4826 * mad, med_err * 0.1, 1e-6)

    # Rough total flux from valid pixels only (positive part above the background).
    flux_est = float(np.sum(np.clip(sci[valid] - bkg, 0.0, None)))
    flux_est = max(flux_est, 1e-6)

    # NOTE(review): invalid pixels are still passed to the likelihood as value 0 with
    # the median error. If bad pixels can fall on the star, consider masking them or
    # inflating their error instead — confirm intent.
    sci = np.where(valid, sci, 0.0)
    err = np.where(valid, err, med_err)
    # Floor the errors at a quarter of the median so no pixel dominates the fit.
    err = np.clip(err, med_err * 0.25, np.inf)

    sci_list.append(sci)
    err_list.append(err)
    flux_loc_list.append(np.log10(flux_est))
    bkg_loc_list.append(bkg)
    bkg_scale_list.append(bkg_sigma)

sci_stack = jnp.array(np.stack(sci_list), dtype=jnp.float64)
err_stack = jnp.array(np.stack(err_list), dtype=jnp.float64)
flux_loc = jnp.array(np.array(flux_loc_list), dtype=jnp.float64)
bkg_loc = jnp.array(np.array(bkg_loc_list), dtype=jnp.float64)
bkg_scale = jnp.array(np.array(bkg_scale_list), dtype=jnp.float64)

n_star, ny, nx = sci_stack.shape
pixel_grid, xgrid, ygrid, x_axis, y_axis, extent, nx_grid, ny_grid = get_pixel_grid(np.zeros((ny, nx)), PIX_SCALE)

print("data shape:", sci_stack.shape)
print("first 5 log10(flux) loc:", np.array(flux_loc[:5]))
print("first 5 bkg loc:", np.array(bkg_loc[:5]))

In [None]:
# -----------------------------
# Load STPSF base and define render helpers
# -----------------------------
with fits.open(BASE_PSF_PATH, memmap=True) as hdul:
    # Prefer the named supersampled extension; otherwise fall back to the first HDU.
    base_ext = BASE_PSF_EXT if BASE_PSF_EXT in hdul else 0
    base_psf_ss_np = np.array(hdul[base_ext].data, dtype=np.float64)

# The base PSF must be a square 2D image whose side is a multiple of SS_FACTOR.
is_square_2d = base_psf_ss_np.ndim == 2 and base_psf_ss_np.shape[0] == base_psf_ss_np.shape[1]
if not is_square_2d:
    raise ValueError(f"Base PSF must be square 2D, got {base_psf_ss_np.shape}")

if base_psf_ss_np.shape[0] % SS_FACTOR != 0:
    raise ValueError(f"Base PSF size {base_psf_ss_np.shape[0]} not divisible by SS_FACTOR={SS_FACTOR}")


def normalize_kernel(kernel):
    """Return a cleaned copy of ``kernel`` scaled to unit sum.

    NaN and +/-inf entries are replaced by 0 and negative entries are clipped
    to 0. If the cleaned kernel sums to a positive value it is divided by that
    sum; otherwise the cleaned kernel is returned unscaled.
    """
    cleaned = jnp.nan_to_num(kernel, nan=0.0, posinf=0.0, neginf=0.0)
    non_negative = jnp.clip(cleaned, 0.0, jnp.inf)
    norm = jnp.sum(non_negative)
    scaled = non_negative / norm
    # `where` picks the unscaled kernel when the sum is not positive.
    return jnp.where(norm > 0.0, scaled, non_negative)


def downsample_mean(kernel_ss, factor=2):
    """Block-average a 2D array over non-overlapping ``factor`` x ``factor`` tiles.

    Both dimensions of ``kernel_ss`` must be divisible by ``factor``; the
    result has shape ``(rows // factor, cols // factor)``.
    """
    n_rows, n_cols = kernel_ss.shape
    # Axes 1 and 3 index the position inside each tile.
    tiles = kernel_ss.reshape(n_rows // factor, factor, n_cols // factor, factor)
    return tiles.mean(axis=(1, 3))


# Shared point-source model of type "IMAGE_POSITIONS"; render_single_star queries it
# for the image position/amplitude of each star.
point_source_model = PointSourceModel(["IMAGE_POSITIONS"])


def render_point_sources_from_kernel(kernel_det, theta_x, theta_y, amplitude):
    """Paint shifted, scaled copies of a detector-grid kernel onto the image grid.

    Each source contributes ``amp * kernel_det`` resampled with first-order
    (bilinear) ``map_coordinates``; samples that fall outside the kernel take
    the nearest edge value (``mode="nearest"``).

    Args:
        kernel_det: 2D kernel on the detector pixel grid.
        theta_x, theta_y: source positions, scalar or 1D, converted to pixel
            coordinates with the module-level ``pixel_grid.map_coord2pix``.
        amplitude: per-source amplitude, scalar or 1D.

    Returns:
        2D image of shape ``(ny_det, nx_det)``, taking
        ``pixel_grid.num_pixel_axes`` to be ``(nx, ny)`` as the unpacking below
        does.
    """
    theta_x = jnp.atleast_1d(theta_x)
    theta_y = jnp.atleast_1d(theta_y)
    amplitude = jnp.atleast_1d(amplitude)

    x_pix, y_pix = pixel_grid.map_coord2pix(theta_x, theta_y)
    # Transposed so that the first interpolation coordinate is x.
    kernel_t = kernel_det.T

    nx_det, ny_det = pixel_grid.num_pixel_axes
    # Offsetting by half the kernel size makes (range - source pixel) an index
    # into the kernel.
    xrange = jnp.arange(nx_det) + kernel_t.shape[0] // 2
    yrange = jnp.arange(ny_det) + kernel_t.shape[1] // 2

    # meshgrid with default "xy" indexing returns (ny_det, nx_det) arrays, so the
    # accumulator must have that shape (the former (nx, ny) only worked for
    # square grids).
    result = jnp.zeros((ny_det, nx_det), dtype=kernel_det.dtype)
    for x0, y0, amp in zip(x_pix, y_pix, amplitude):
        xy_grid = jnp.meshgrid(xrange - x0, yrange - y0)
        result = result + amp * map_coordinates(kernel_t, xy_grid, order=1, mode="nearest")
    return result


def render_single_star(kernel_det, x_pos, y_pos, flux, background):
    """Render one star as ``flux``-scaled PSF at ``(x_pos, y_pos)`` plus ``background``.

    The position and amplitude are routed through the module-level
    ``point_source_model`` and the first returned image is painted with
    ``render_point_sources_from_kernel``.
    """
    star_kwargs = [{"ra": x_pos, "dec": y_pos, "amp": flux}]
    images = point_source_model.get_multiple_images(
        star_kwargs,
        kwargs_lens=None,
        kwargs_solver=None,
        k=0,
        with_amplitude=True,
        zero_amp_duplicates=False,
    )
    theta_x_list, theta_y_list, amp_list = images
    star_image = render_point_sources_from_kernel(
        kernel_det,
        theta_x_list[0],
        theta_y_list[0],
        amp_list[0],
    )
    return star_image + background


# Normalised base PSF on the supersampled grid and on the detector grid.
base_psf_ss = normalize_kernel(jnp.array(base_psf_ss_np, dtype=jnp.float64))
base_psf_det = normalize_kernel(downsample_mean(base_psf_ss, factor=SS_FACTOR))

# Fourier wavenumber grid matching the supersampled PSF shape.
k_grid_ss = K_grid(base_psf_ss.shape)
k_values_ss = jnp.array(k_grid_ss.k, dtype=jnp.float64)

print("base ss shape:", base_psf_ss.shape)
print("base det shape:", base_psf_det.shape)

fig, axes = plt.subplots(1, 2, figsize=(10, 4))
base_panels = (
    (base_psf_ss, "Base STPSF (ss grid)"),
    (base_psf_det, "Base STPSF (detector grid)"),
)
for ax, (image, title) in zip(axes, base_panels):
    ax.imshow(np.array(image), origin="lower", norm="log")
    ax.set_title(title)
plt.tight_layout()

In [None]:
# -----------------------------
# Model: STPSF + (Matern + WN Fourier correction) + resize + per-star nuisance
# Flux is solved analytically per star (weighted least squares), not sampled.
# Correction field is projected to have zero 0th/1st moments (no total-flux or centroid drift).
# -----------------------------
def weighted_ls_flux(unit_model_stack, data_minus_bkg, err_stack, eps=1e-12):
    """Per-star inverse-variance weighted least-squares amplitude.

    For each entry along axis 0, returns
    ``sum(m * d / var) / sum(m**2 / var)`` over axes (1, 2), where ``m`` is the
    unit-flux model, ``d`` the background-subtracted data and
    ``var = err**2 + eps``. ``eps`` is also added to the denominator and used
    as the lower bound of the returned flux, so the result is always positive.
    """
    weights = 1.0 / (err_stack**2 + eps)
    cross_term = unit_model_stack * data_minus_bkg * weights
    model_power = (unit_model_stack**2) * weights
    best_fit = jnp.sum(cross_term, axis=(1, 2)) / (jnp.sum(model_power, axis=(1, 2)) + eps)
    # Point-source flux must be positive: floor at eps.
    return jnp.maximum(best_fit, eps)


def project_zero_moments(corr, eps=1e-12):
    """Remove the least-squares plane ``a + b*x + c*y`` from a 2D field.

    ``x`` and ``y`` are pixel indices measured from the array centre. The plane
    coefficients come from the 3x3 normal equations, regularised by ``eps`` on
    the diagonal. The residual therefore has (up to ``eps``) zero sum and zero
    first moments in x and y.
    """
    n_rows, n_cols = corr.shape
    row_idx, col_idx = jnp.indices((n_rows, n_cols), dtype=corr.dtype)
    x_centered = col_idx - (n_cols - 1) / 2.0
    y_centered = row_idx - (n_rows - 1) / 2.0

    flat = corr.reshape(-1)
    # Design matrix with columns [1, x, y], shape (N, 3).
    basis = jnp.stack(
        [jnp.ones_like(corr).reshape(-1), x_centered.reshape(-1), y_centered.reshape(-1)],
        axis=1,
    )

    gram = basis.T @ basis
    rhs = basis.T @ flat
    plane_coeff = jnp.linalg.solve(gram + eps * jnp.eye(3, dtype=corr.dtype), rhs)
    residual = flat - basis @ plane_coeff
    return residual.reshape(n_rows, n_cols)


def model_psf_svi(sci_data, err_data, flux_loc, bkg_loc, bkg_scale, base_psf_ss, k_values):
    """NumPyro model: one shared PSF (base + correction) fitted jointly to all stars.

    Args:
        sci_data: stack of star cutouts indexed as (star, y, x).
        err_data: per-pixel Gaussian sigma, same shape as ``sci_data``.
        flux_loc: unused; kept so the call signature stays stable.
        bkg_loc, bkg_scale: per-star loc/scale of the Normal background prior.
        base_psf_ss: 2D base PSF on the supersampled grid.
        k_values: forwarded to ``matern_power_spectrum``
            (presumably the Fourier wavenumber grid — confirm against the helper).

    Sample sites defined here: ``x_pos``, ``y_pos``, ``background`` (plate
    ``stars``) and the observed site ``obs``. Any latent sites of the correction
    field are created inside ``matern_power_spectrum`` and are not visible here.

    Deterministic sites: ``pixels_psf_corr_proj``, ``psf_ss_model``,
    ``psf_det_model``, ``psf_corr_ss_eff``, ``psf_corr_det_eff``, ``flux_opt``,
    ``log10_flux_opt``.
    """
    _ = flux_loc  # kept for function interface compatibility

    # Correction field on the supersampled grid, drawn by the project helper.
    corr_dict = matern_power_spectrum(
        "PSF correction",
        "psf_corr",
        k_values,
        n_value=None,
        positive=False,
    )
    corr_ss_raw = corr_dict["pixels"]
    # Remove the best-fit plane so the correction has no net flux or centroid term.
    corr_ss = project_zero_moments(corr_ss_raw)
    numpyro.deterministic("pixels_psf_corr_proj", corr_ss)

    psf_ss_raw = base_psf_ss + corr_ss
    # Sharp softplus: smooth, strictly positive replacement for clipping at zero.
    psf_ss_pos = jax.nn.softplus(100.0 * psf_ss_raw) / 100.0
    psf_ss_model = psf_ss_pos / jnp.sum(psf_ss_pos)
    numpyro.deterministic("psf_ss_model", psf_ss_model)

    # Block-average to the detector grid and renormalise to unit sum.
    psf_det_model = downsample_mean(psf_ss_model, factor=SS_FACTOR)
    psf_det_model = psf_det_model / jnp.sum(psf_det_model)
    numpyro.deterministic("psf_det_model", psf_det_model)

    # Effective corrections: final model minus the base PSF on each grid.
    base_det = downsample_mean(base_psf_ss, factor=SS_FACTOR)
    base_det = base_det / jnp.sum(base_det)
    numpyro.deterministic("psf_corr_ss_eff", psf_ss_model - base_psf_ss)
    numpyro.deterministic("psf_corr_det_eff", psf_det_model - base_det)

    n_star = sci_data.shape[0]
    with numpyro.plate("stars", n_star):
        x_pos = numpyro.sample(
            "x_pos",
            dist.TruncatedNormal(
                loc=jnp.zeros(n_star),
                scale=XPOS_PRIOR_SIGMA,
                low=XPOS_BOUNDS[0],
                high=XPOS_BOUNDS[1],
            ),
        )
        y_pos = numpyro.sample(
            "y_pos",
            dist.TruncatedNormal(
                loc=jnp.zeros(n_star),
                scale=YPOS_PRIOR_SIGMA,
                low=YPOS_BOUNDS[0],
                high=YPOS_BOUNDS[1],
            ),
        )
        background = numpyro.sample(
            "background",
            dist.Normal(loc=bkg_loc, scale=bkg_scale),
        )

    # Render unit-flux stars on detector grid; solve amplitudes analytically
    unit_model_stack = jax.vmap(render_single_star, in_axes=(None, 0, 0, 0, 0))(
        psf_det_model,
        x_pos,
        y_pos,
        jnp.ones_like(x_pos),
        jnp.zeros_like(background),
    )
    data_minus_bkg = sci_data - background[:, None, None]
    flux_opt = weighted_ls_flux(unit_model_stack, data_minus_bkg, err_data)
    numpyro.deterministic("flux_opt", flux_opt)
    numpyro.deterministic("log10_flux_opt", jnp.log10(flux_opt))

    # Full per-star model image and independent Gaussian pixel likelihood.
    model_stack = unit_model_stack * flux_opt[:, None, None] + background[:, None, None]
    numpyro.sample("obs", dist.Normal(model_stack, err_data), obs=sci_data)


rng_key = jax.random.PRNGKey(SEED)

# Mean-field Gaussian guide, initialised at a prior-sample median with a small scale.
init_fun = infer.init_to_median(num_samples=15)
guide = autoguide.AutoDiagonalNormal(model_psf_svi, init_loc_fn=init_fun, init_scale=0.02)

scheduler = split_scheduler(MAX_ITERATIONS, init_value=0.01, transition_steps=[200, 20])
optim = optax.adabelief(learning_rate=scheduler)
loss = infer.TraceMeanField_ELBO()

svi = SVI_vec(model_psf_svi, guide, optim, loss)

# Run NUM_CHAINS independent SVI optimisations for MAX_ITERATIONS steps each.
svi_results = svi.run(
    rng_key,
    NUM_CHAINS,
    MAX_ITERATIONS,
    sci_stack,
    err_stack,
    flux_loc,
    bkg_loc,
    bkg_scale,
    base_psf_ss,
    k_values_ss,
    stable_update=True,
)

losses_np = np.array(jax.device_get(svi_results.losses))
final_losses = losses_np[:, -1]
print("Final losses:", final_losses)

# np.argmin returns the index of a NaN entry if one is present, which would pick a
# diverged chain as "best". Rank only the finite final losses.
finite_final = np.isfinite(final_losses)
if not np.any(finite_final):
    raise RuntimeError("All SVI chains ended with a non-finite loss")
best_chain = int(np.argmin(np.where(finite_final, final_losses, np.inf)))
print("Best chain:", best_chain)

# Guide parameters and posterior medians of the selected chain.
params_best = jax.tree.map(lambda v: v[best_chain], svi_results.params)
median_best = guide.median(params_best)

In [None]:
# -----------------------------
# Posterior samples and required posteriors
# - pixel grid posterior: pixels_psf_corr (raw) and pixels_psf_corr_proj (projected)
# - power-spectrum posterior: n_psf_corr, rho_psf_corr (and sigma_psf_corr)
# - deterministic model PSF images for direct posterior median (SVI)
# -----------------------------
# Separate keys for drawing from the guide and for the predictive pass.
rng_key, sample_key, pred_key = jax.random.split(rng_key, 3)

# NUM_POST_SAMPLES draws of the latent variables from the best chain's fitted guide.
posterior_latent = guide.sample_posterior(
    sample_key,
    params_best,
    sample_shape=(NUM_POST_SAMPLES,),
)

# Sites to collect. The *_psf_corr names other than pixels_psf_corr_proj are not
# defined in model_psf_svi itself; presumably they are created inside
# matern_power_spectrum — verify against that helper.
return_sites = [
    "psf_ss_model",
    "psf_det_model",
    "pixels_psf_corr",
    "pixels_psf_corr_proj",
    "psf_corr_ss_eff",
    "psf_corr_det_eff",
    "n_psf_corr",
    "rho_psf_corr",
    "sigma_psf_corr",
    "x_pos",
    "y_pos",
    "background",
    "flux_opt",
    "log10_flux_opt",
]

# Re-run the model on each latent draw to evaluate the requested sites.
predictive = infer.Predictive(
    model_psf_svi,
    posterior_samples=posterior_latent,
    return_sites=return_sites,
)

posterior = predictive(
    pred_key,
    sci_stack,
    err_stack,
    flux_loc,
    bkg_loc,
    bkg_scale,
    base_psf_ss,
    k_values_ss,
)

# Hyper-parameter draws: first column of each (draw, ...) array.
n_post, rho_post, sigma_post = (
    np.array(jax.device_get(posterior[site]))[:, 0]
    for site in ("n_psf_corr", "rho_psf_corr", "sigma_psf_corr")
)

# Pixel-wise medians over the posterior draws (axis 0).
psf_ss_median, psf_det_median, pixcorr_median = (
    np.median(np.array(jax.device_get(posterior[site])), axis=0)
    for site in ("psf_ss_model", "psf_det_model", "pixels_psf_corr_proj")
)

# Median and 16th/84th percentiles for each hyper-parameter.
summary = {}
for prefix, draws in (("n", n_post), ("rho", rho_post), ("sigma", sigma_post)):
    summary[f"{prefix}_median"] = float(np.median(draws))
    summary[f"{prefix}_p16"] = float(np.percentile(draws, 16))
    summary[f"{prefix}_p84"] = float(np.percentile(draws, 84))

summary

In [None]:
# -----------------------------
# Visual diagnostics
# -----------------------------
fig, axes = plt.subplots(2, 3, figsize=(14, 8))

# Loss curve of every SVI chain; asinh scale copes with negative and huge values.
for i, chain_losses in enumerate(losses_np):
    axes[0, 0].plot(chain_losses, alpha=0.7, label=f"chain {i}")
axes[0, 0].set_yscale("asinh")
axes[0, 0].set_title("SVI loss")
axes[0, 0].legend(loc="best", fontsize=8)

im1 = axes[0, 1].imshow(np.array(base_psf_det), origin="lower", norm="log")
axes[0, 1].set_title("Base STPSF (det)")
plt.colorbar(im1, ax=axes[0, 1], fraction=0.046, pad=0.04)

im2 = axes[0, 2].imshow(psf_det_median, origin="lower", norm="log")
axes[0, 2].set_title("Model PSF median (det)")
plt.colorbar(im2, ax=axes[0, 2], fraction=0.046, pad=0.04)

# pixcorr_median is the median of the *projected* correction field
# (site "pixels_psf_corr_proj"), so the title names that site.
im3 = axes[1, 0].imshow(pixcorr_median, origin="lower")
axes[1, 0].set_title("pixels_psf_corr_proj median (ss)")
plt.colorbar(im3, ax=axes[1, 0], fraction=0.046, pad=0.04)

axes[1, 1].hist(n_post, bins=30, alpha=0.8)
axes[1, 1].set_title("Posterior of n_psf_corr")

axes[1, 2].hist(rho_post, bins=30, alpha=0.8)
axes[1, 2].set_xscale("log")
axes[1, 2].set_title("Posterior of rho_psf_corr")

plt.tight_layout()

In [None]:
# -----------------------------
# Save outputs
# -----------------------------
output_fits = os.path.join(OUTPUT_DIR, "F115W_PSF_svi_stpsf_plus_corr.fits")
output_npz = os.path.join(OUTPUT_DIR, "psf_svi_posterior_summary.npz")

# Primary header: run configuration and hyper-parameter medians.
hdr = fits.Header()
header_cards = (
    ("NSTAR", int(n_star)),
    ("PIXSCALE", float(PIX_SCALE)),
    ("SSFACT", int(SS_FACTOR)),
    ("NPOST", int(NUM_POST_SAMPLES)),
    ("NMED", float(summary["n_median"])),
    ("RHOMED", float(summary["rho_median"])),
    ("SIGMED", float(summary["sigma_median"])),
)
for card_key, card_value in header_cards:
    hdr[card_key] = card_value

# One float32 image extension per product, in this order.
image_products = (
    ("SS_PSF_MED", psf_ss_median),
    ("DET_PSF_MED", psf_det_median),
    ("SS_CORR_MED", pixcorr_median),
    ("K_GRID", k_values_ss),
)
image_hdus = [
    fits.ImageHDU(data=np.array(image, dtype=np.float32), name=ext_name)
    for ext_name, image in image_products
]
hdul = fits.HDUList([fits.PrimaryHDU(header=hdr), *image_hdus])
hdul.writeto(output_fits, overwrite=True)

# Posterior draws: hyper-parameters plus per-star nuisance parameters.
npz_arrays = {
    "n_post": n_post,
    "rho_post": rho_post,
    "sigma_post": sigma_post,
}
per_star_sites = (
    ("x_pos_post", "x_pos"),
    ("y_pos_post", "y_pos"),
    ("flux_opt_post", "flux_opt"),
    ("log10_flux_opt_post", "log10_flux_opt"),
    ("background_post", "background"),
)
for npz_key, site in per_star_sites:
    npz_arrays[npz_key] = np.array(jax.device_get(posterior[site]))
np.savez(output_npz, **npz_arrays)

print("saved:", output_fits)
print("saved:", output_npz)

In [None]:
# -----------------------------
# Result display requested:
# - STPSF psf and modeled psf
# - global correction field
# - all stars: star, model, residual
# -----------------------------
def _posterior_median(site):
    """Median over posterior draws (axis 0) of one predictive site, as a NumPy array."""
    return np.median(np.array(jax.device_get(posterior[site])), axis=0)


x_med = _posterior_median("x_pos")
y_med = _posterior_median("y_pos")
flux_med = _posterior_median("flux_opt")
bkg_med = _posterior_median("background")

# Model PSF-only image for each star (background fixed to 0)
render_all_stars = jax.vmap(render_single_star, in_axes=(None, 0, 0, 0, 0))
psf_only_med = render_all_stars(
    jnp.array(psf_det_median),
    jnp.array(x_med),
    jnp.array(y_med),
    jnp.array(flux_med),
    jnp.zeros_like(jnp.array(bkg_med)),
)
psf_only_med = np.array(jax.device_get(psf_only_med))

data_np = np.array(jax.device_get(sci_stack))
err_np = np.array(jax.device_get(err_stack))
# Broadcast the per-star background over each cutout.
bkg_cube = bkg_med[:, None, None]

star_img = data_np - bkg_cube
# Residual in units of the per-pixel error.
residual = (data_np - psf_only_med - bkg_cube) / err_np

# (A) Base STPSF and fitted model PSF on the detector grid, log-scaled.
fig, axes = plt.subplots(1, 2, figsize=(9, 4))
psf_panels = (
    (base_psf_det, "STPSF (detector)"),
    (psf_det_median, "Model PSF (detector)"),
)
for ax, (image, title) in zip(axes, psf_panels):
    handle = ax.imshow(np.array(image), origin="lower", norm="log")
    ax.set_title(title)
    ax.set_xticks([])
    ax.set_yticks([])
    plt.colorbar(handle, ax=ax, fraction=0.046, pad=0.04)
plt.tight_layout()

# (B) Global correction field (one field shared by all stars): model minus base.
global_corr_ss = np.array(psf_ss_median) - np.array(base_psf_ss)
global_corr_det = np.array(psf_det_median) - np.array(base_psf_det)


def _symmetric_limit(field):
    """Largest absolute value of ``field``, replaced by 1.0 when it is not positive."""
    peak = np.max(np.abs(field))
    return 1.0 if peak <= 0 else peak


abs_corr_ss = _symmetric_limit(global_corr_ss)
abs_corr_det = _symmetric_limit(global_corr_det)

fig, axes = plt.subplots(1, 2, figsize=(10, 4))
corr_panels = (
    (global_corr_ss, abs_corr_ss, "Global correction field (SS)"),
    (global_corr_det, abs_corr_det, "Global correction field (detector)"),
)
for ax, (field, limit, title) in zip(axes, corr_panels):
    # Diverging colormap centred on zero.
    handle = ax.imshow(field, origin="lower", cmap="coolwarm", vmin=-limit, vmax=limit)
    ax.set_title(title)
    ax.set_xticks([])
    ax.set_yticks([])
    plt.colorbar(handle, ax=ax, fraction=0.046, pad=0.04)
plt.tight_layout()

# (C) Every star as a row of three panels: data minus background / model / residual.
n_show = star_img.shape[0]
chunk_size = 7  # rows per figure

# Common colour limits across all stars so panels are comparable.
star_vmin, star_vmax = np.percentile(star_img, [1.0, 99.5])
model_vmin, model_vmax = np.percentile(psf_only_med, [1.0, 99.5])

for start in range(0, n_show, chunk_size):
    stop = min(start + chunk_size, n_show)
    nrows = stop - start

    # squeeze=False keeps `axes` 2D even when the last figure has a single row.
    fig, axes = plt.subplots(
        nrows, 3, figsize=(9.5, 2.6 * nrows), constrained_layout=True, squeeze=False
    )

    for r, i in enumerate(range(start, stop)):
        row_panels = (
            (star_img[i], f"star {i}: data-bkg", dict(cmap="viridis", vmin=star_vmin, vmax=star_vmax)),
            (psf_only_med[i], "model", dict(cmap="viridis", vmin=model_vmin, vmax=model_vmax)),
            (residual[i], "residual", dict(cmap="bwr", vmin=-3, vmax=3)),
        )
        row_handles = []
        for ax, (image, title, style) in zip(axes[r], row_panels):
            row_handles.append(ax.imshow(image, origin="lower", **style))
            ax.set_title(title, fontsize=9)
            ax.set_xticks([])
            ax.set_yticks([])

    # One colourbar per column, using the last row's image handles.
    for col, handle in enumerate(row_handles):
        fig.colorbar(handle, ax=axes[:, col].tolist(), fraction=0.02, pad=0.01)

    fig.suptitle(f"Stars {start} - {stop-1}: star / model / residual", fontsize=12)

In [None]:
# -----------------------------
# HMC continuation from SVI (sample full PSF + star nuisance parameter set)
# Debug version: step-by-step logging to locate kernel crash.
# -----------------------------
import time

def _hlog(msg):
    """Print ``msg`` prefixed with ``[HMC-DEBUG HH:MM:SS]`` and flush stdout at once."""
    stamp = time.strftime("%H:%M:%S")
    print(f"[HMC-DEBUG {stamp}] {msg}", flush=True)

_hlog("Step 0: configure HMC hyperparameters")
HMC_WARMUP = 500
HMC_SAMPLES = 1500
HMC_CHAINS = 4
HMC_TARGET_ACCEPT = 0.9
HMC_CHAIN_METHOD = "vectorized"  # use vectorized chains for GPU execution


def _all_finite(values):
    """True when every element of ``values`` is finite."""
    return bool(np.array(jnp.all(jnp.isfinite(values))))


_hlog(f"Step 0.1: devices={jax.devices()}")

sci_shape = tuple(sci_stack.shape)
err_shape = tuple(err_stack.shape)
base_shape = tuple(base_psf_ss.shape)
_hlog(f"Step 0.2: sci_stack={sci_shape}, err_stack={err_shape}, base_psf_ss={base_shape}")

sci_ok = _all_finite(sci_stack)
err_ok = _all_finite(err_stack)
k_ok = _all_finite(k_values_ss)
_hlog(f"Step 0.3: finite checks: sci={sci_ok}, err={err_ok}, k={k_ok}")

_hlog("Step 1: build per-SVI-chain medians for initialization")
svi_chain_medians = []
for ic in range(NUM_CHAINS):
    _hlog(f"Step 1.{ic+1}: median from SVI chain {ic}")
    params_ic = jax.tree.map(lambda v, ic=ic: v[ic], svi_results.params)
    # guide.median returns values in the constrained (model) space.
    median_ic = guide.median(params_ic)
    svi_chain_medians.append(median_ic)
_hlog("Step 1 done")

if HMC_CHAINS > len(svi_chain_medians):
    raise ValueError(f"HMC_CHAINS={HMC_CHAINS} > available SVI chains={len(svi_chain_medians)}")

# Full set requested: pixels_wn_psf_corr + all 29 star nuisances + n/rho/sigma
HMC_SAMPLE_SITES = (
    "n_psf_corr",
    "rho_psf_corr",
    "sigma_psf_corr",
    "pixels_wn_psf_corr",
    "x_pos",
    "y_pos",
    "background",
)

_hlog("Step 2: validate required sites")
for k in HMC_SAMPLE_SITES:
    if k not in svi_chain_medians[0]:
        raise KeyError(f"Missing site in SVI median for HMC init: {k}")
_hlog("Step 2 done")

_hlog("Step 3: build init_params_hmc")
# MCMC.run(init_params=...) expects *unconstrained* values when the kernel is built
# from a model, whereas the SVI medians are constrained. Passing them directly would
# re-interpret e.g. bounded positions or positive hyper-parameters on the wrong
# scale, so map every chain's medians through the model's inverse transforms first.
from numpyro.infer.util import unconstrain_fn

_hmc_model_args = (sci_stack, err_stack, flux_loc, bkg_loc, bkg_scale, base_psf_ss, k_values_ss)
# Seeding is harmless here (all listed sites are substituted) and guards against a
# latent site that is not part of HMC_SAMPLE_SITES.
_hmc_seeded_model = numpyro.handlers.seed(model_psf_svi, rng_seed=SEED)


def _unconstrained_init(chain_median):
    """Map one chain's constrained SVI medians to unconstrained space."""
    constrained = {k: jnp.asarray(chain_median[k]) for k in HMC_SAMPLE_SITES}
    return unconstrain_fn(_hmc_seeded_model, _hmc_model_args, {}, constrained)


_init_per_chain = [_unconstrained_init(svi_chain_medians[i]) for i in range(HMC_CHAINS)]
# Leading axis = HMC chain.
init_params_hmc = {
    k: jnp.stack([chain_init[k] for chain_init in _init_per_chain], axis=0)
    for k in HMC_SAMPLE_SITES
}
for k in HMC_SAMPLE_SITES:
    arr = init_params_hmc[k]
    _hlog(f"  init[{k}] shape={tuple(arr.shape)} dtype={arr.dtype}")

_hlog("Step 3.1: block_until_ready on init params")
_ = jax.tree.map(lambda x: x.block_until_ready() if hasattr(x, "block_until_ready") else x, init_params_hmc)
_hlog("Step 3.1 done")

_hlog("Step 4: create NUTS kernel")
model_hmc = model_psf_svi
nuts_kernel = infer.NUTS(model_hmc, target_accept_prob=HMC_TARGET_ACCEPT)
_hlog("Step 4 done")

_hlog("Step 5: create MCMC object")
mcmc_settings = dict(
    num_warmup=HMC_WARMUP,
    num_samples=HMC_SAMPLES,
    num_chains=HMC_CHAINS,
    chain_method=HMC_CHAIN_METHOD,
    progress_bar=True,
)
mcmc_hmc = infer.MCMC(nuts_kernel, **mcmc_settings)
_hlog("Step 5 done")

_hlog("Step 6: split rng and start mcmc_hmc.run")
rng_key, hmc_key = jax.random.split(rng_key)
_hlog(f"Step 6.1: hmc_key shape={tuple(hmc_key.shape)}")

# Positional model arguments, in the order model_psf_svi expects them.
hmc_run_args = (
    sci_stack,
    err_stack,
    flux_loc,
    bkg_loc,
    bkg_scale,
    base_psf_ss,
    k_values_ss,
)
try:
    mcmc_hmc.run(hmc_key, *hmc_run_args, init_params=init_params_hmc)
    _hlog("Step 6 done: mcmc_hmc.run finished")
except Exception as e:
    # Log the failure for the step-by-step trace, then let it propagate.
    _hlog(f"Step 6 FAILED: {type(e).__name__}: {e}")
    raise

# For large high-dimensional posteriors, avoid az.from_numpyro() heavy conversion.
_hlog("Step 7: get raw posterior samples")
# group_by_chain=True keeps a leading chain axis on every site.
hmc_samples = mcmc_hmc.get_samples(group_by_chain=True)
for k in ["n_psf_corr", "rho_psf_corr", "sigma_psf_corr", "pixels_wn_psf_corr", "x_pos", "y_pos", "background", "psf_det_model", "psf_ss_model", "flux_opt"]:
    if k in hmc_samples:
        # Read .shape directly: converting to NumPy would copy the whole
        # (potentially multi-GB) array to the host just to log its shape.
        _hlog(f"  hmc_samples[{k}] shape={tuple(hmc_samples[k].shape)}")
_hlog("Step 7 done")

_hlog("Step 8: build lightweight ArviZ idata for trace only")
def _trace_2d(arr):
    """Return ``arr`` as a 2D NumPy array, dropping trailing length-1 axes.

    Trailing singleton axes are removed one at a time while the array has more
    than two dimensions.

    Raises:
        ValueError: if the result is not 2D.
    """
    host = np.array(jax.device_get(arr))
    while host.ndim > 2 and host.shape[-1] == 1:
        host = host[..., 0]
    if host.ndim == 2:
        return host
    raise ValueError(f"Trace array expected 2D after squeeze, got {host.shape}")

# Hyper-parameter traces as 2D arrays for ArviZ.
n_trace = _trace_2d(hmc_samples["n_psf_corr"])
rho_trace = _trace_2d(hmc_samples["rho_psf_corr"])
sigma_trace = _trace_2d(hmc_samples["sigma_psf_corr"])

inf_data_hmc = az.from_dict(
    posterior={
        "n_psf_corr": n_trace,
        "rho_psf_corr": rho_trace,
        "sigma_psf_corr": sigma_trace,
    }
)
_hlog("Step 8 done")

_hlog("Step 9: save HMC outputs to result/")
hmc_trace_npz = os.path.join(OUTPUT_DIR, "hmc_trace_n_rho_sigma.npz")
np.savez(hmc_trace_npz, n_trace=n_trace, rho_trace=rho_trace, sigma_trace=sigma_trace)

# Copy the supersampled PSF cube to the host once and reuse it for all three
# statistics (it was previously transferred/converted three times).
psf_ss_hmc = np.asarray(jax.device_get(hmc_samples["psf_ss_model"]))
psf_ss_hmc_median = np.median(psf_ss_hmc, axis=(0, 1))
psf_ss_hmc_p16, psf_ss_hmc_p84 = np.percentile(psf_ss_hmc, [16, 84], axis=(0, 1))
del psf_ss_hmc  # release the large host copy before the next one
psf_det_hmc_median = np.median(np.asarray(jax.device_get(hmc_samples["psf_det_model"])), axis=(0, 1))

hmc_fits_path = os.path.join(OUTPUT_DIR, "F115W_PSF_hmc_median.fits")
fits.HDUList([
    fits.PrimaryHDU(),
    fits.ImageHDU(data=np.array(psf_ss_hmc_median, dtype=np.float32), name="SS_PSF_HMC_MED"),
    fits.ImageHDU(data=np.array(psf_det_hmc_median, dtype=np.float32), name="DET_PSF_HMC_MED"),
    fits.ImageHDU(data=np.array(psf_ss_hmc_p16, dtype=np.float32), name="SS_PSF_HMC_P16"),
    fits.ImageHDU(data=np.array(psf_ss_hmc_p84, dtype=np.float32), name="SS_PSF_HMC_P84"),
]).writeto(hmc_fits_path, overwrite=True)

_hlog(f"Step 9 done: saved {hmc_trace_npz}")
_hlog(f"Step 9 done: saved {hmc_fits_path}")

_hlog("Step 10: print summary")
# print_summary() prints the table itself and returns None, so it must not be
# wrapped in print() (that emitted a stray "None").
mcmc_hmc.print_summary()
_hlog("Step 10 done")

In [None]:
# -----------------------------
# HMC traces for n / rho / sigma (4 chains) via ArviZ (lightweight idata)
# -----------------------------
plt.rcParams["figure.constrained_layout.use"] = True

# Select every HMC chain and draw the hyper-parameter traces.
chains_to_plot = np.arange(HMC_CHAINS)
trace_var_names = ["n_psf_corr", "rho_psf_corr", "sigma_psf_corr"]
idata_selected = inf_data_hmc.sel(chain=chains_to_plot)
_ = az.plot_trace(idata_selected, var_names=trace_var_names, figsize=(10, 8))
plt.tight_layout()

In [None]:
# -----------------------------
# Compare star0: 4 chain-median models/residuals + median over 4 chains
# Use deterministic outputs directly from HMC samples.
# -----------------------------
# chain-wise medians for sampled parameters
# Chain-wise medians (over the draw axis) of every collected site, computed on the
# host one site at a time. This avoids a device-side median over the large image
# cubes and avoids computing the PSF/flux medians twice.
hmc_chain_median = {
    site: np.median(np.asarray(jax.device_get(draws)), axis=1)
    for site, draws in hmc_samples.items()
}

# chain-wise medians of deterministic PSF image and solved flux
psf_det_chain_median = hmc_chain_median["psf_det_model"]
flux_opt_chain_median = hmc_chain_median["flux_opt"]

star0_data = np.array(jax.device_get(sci_stack))[0]
star0_err = np.array(jax.device_get(err_stack))[0]

chain_models = []
chain_residuals = []
chain_bkg0 = []
for c in range(HMC_CHAINS):
    psf_det_c = psf_det_chain_median[c]

    # Per-chain median nuisance parameters of star 0.
    x0 = float(hmc_chain_median["x_pos"][c][0])
    y0 = float(hmc_chain_median["y_pos"][c][0])
    bkg0 = float(hmc_chain_median["background"][c][0])
    flux0 = float(flux_opt_chain_median[c][0])

    # PSF-only model (background 0); the background is subtracted in the residual.
    model0_c = np.array(jax.device_get(render_single_star(jnp.array(psf_det_c), x0, y0, flux0, 0.0)))
    resid0_c = (star0_data - model0_c - bkg0) / star0_err

    chain_models.append(model0_c)
    chain_residuals.append(resid0_c)
    chain_bkg0.append(bkg0)

chain_models = np.array(chain_models)
chain_residuals = np.array(chain_residuals)

# plot 4 chain models and residuals
# squeeze=False keeps `axes` 2D so axes[row, c] also works when HMC_CHAINS == 1.
fig, axes = plt.subplots(2, HMC_CHAINS, figsize=(3.0 * HMC_CHAINS, 6), constrained_layout=True, squeeze=False)
model_vmin, model_vmax = np.percentile(chain_models, [1.0, 99.5])
for c in range(HMC_CHAINS):
    im_m = axes[0, c].imshow(chain_models[c], origin="lower", cmap="viridis", vmin=model_vmin, vmax=model_vmax)
    axes[0, c].set_title(f"chain {c} model")
    im_r = axes[1, c].imshow(chain_residuals[c], origin="lower", cmap="bwr", vmin=-3, vmax=3)
    axes[1, c].set_title(f"chain {c} residual")
    for row in (0, 1):
        axes[row, c].set_xticks([])
        axes[row, c].set_yticks([])
fig.colorbar(im_m, ax=axes[0, :].tolist(), fraction=0.02, pad=0.01)
fig.colorbar(im_r, ax=axes[1, :].tolist(), fraction=0.02, pad=0.01)

# median model/residual over the 4 chain medians
model0_median4 = np.median(chain_models, axis=0)
bkg0_median4 = np.median(chain_bkg0)
resid0_median4 = (star0_data - model0_median4 - bkg0_median4) / star0_err

fig, axes = plt.subplots(1, 2, figsize=(8, 4), constrained_layout=True)
im_a = axes[0].imshow(model0_median4, origin="lower", cmap="viridis", vmin=model_vmin, vmax=model_vmax)
axes[0].set_title("star0 model (median of 4 chains)")
im_b = axes[1].imshow(resid0_median4, origin="lower", cmap="bwr", vmin=-3, vmax=3)
axes[1].set_title("star0 residual (median of 4 chains)")
for ax in axes:
    ax.set_xticks([])
    ax.set_yticks([])
fig.colorbar(im_a, ax=axes[0], fraction=0.046, pad=0.04)
fig.colorbar(im_b, ax=axes[1], fraction=0.046, pad=0.04)