# Level 0: Batched MAP Shear Inference

This notebook demonstrates SHINE's **Level 0 sanity check** on a batch of 10 galaxies
sharing the same true shear. Level 0 is the noiseless self-consistency test: when the
data are generated by the same forward model with effectively zero noise, the MAP
estimate should recover the truth to within optimizer and numerical tolerance.

All 10 galaxies are inferred **simultaneously** in a single batched MAP call using
`build_batched_model()`, which vmaps the rendering over the batch dimension. This
compiles the entire batch into one XLA program -- far more efficient than looping.

Since there is no noise, MAP (point estimation) is the natural choice -- full MCMC
is unnecessary and much slower.
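
A toy problem shows why a prior on the shear barely moves the MAP estimate when the noise is this small. The sketch below is not SHINE's forward model: it assumes a single scalar datum `d = g_true + noise` with a Gaussian likelihood and a zero-mean Gaussian prior of width 0.05, for which the MAP estimate has a closed form. The name `toy_map_estimate` is made up for illustration.

```python
def toy_map_estimate(d, sigma_noise, sigma_prior):
    """MAP of g for d ~ Normal(g, sigma_noise), g ~ Normal(0, sigma_prior).

    Maximizing the log posterior gives g_map = d * w with
    w = sigma_prior^2 / (sigma_prior^2 + sigma_noise^2).
    """
    w = sigma_prior**2 / (sigma_prior**2 + sigma_noise**2)
    return d * w


g_true = 0.01
# Noise-free datum, so any offset comes from the prior alone
g_map = toy_map_estimate(d=g_true, sigma_noise=1e-6, sigma_prior=0.05)
shrinkage = g_map / g_true - 1.0  # fractional pull toward the prior mean

print(f"g_map = {g_map:.12f}, fractional shrinkage = {shrinkage:.1e}")
```

In this toy setting the fractional shrinkage is of order $\sigma_{\rm noise}^2 / \sigma_{\rm prior}^2 = 4 \times 10^{-10}$, so the prior has no practical effect on the point estimate.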

**Configuration:** Exponential galaxy (hlr=0.5"), Gaussian PSF ($\sigma$=0.382"),
pixel scale 0.263"/px, noise $\sigma = 10^{-6}$, true shear $g_1=0.01$, $g_2=0.00$.

**What we do:**
1. Generate 10 synthetic observations with the same shear
2. Run **batched** MAP inference on all realizations simultaneously
3. Check that MAP estimates match truth with negligible bias

In [None]:
import os
# Workaround: XLA autotuner fails on tiny batched f64 GEMMs (2x2 Jacobian dot)
# on some GPUs. Disabling autotuning selects a default kernel instead.
os.environ.setdefault("XLA_FLAGS", "--xla_gpu_autotune_level=0")

import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
import arviz as az

from shine.config import (
    ShineConfig,
    ImageConfig,
    NoiseConfig,
    PSFConfig,
    GalaxyConfig,
    ShearConfig,
    EllipticityConfig,
    PositionConfig,
    InferenceConfig,
    MAPConfig,
    DistributionConfig,
)
from shine.scene import SceneBuilder
from shine.inference import Inference
from shine.validation.simulation import generate_batch_observations
from shine.validation.extraction import (
    extract_convergence_diagnostics,
    extract_shear_estimates,
    check_convergence,
)
from shine.validation.bias_config import ConvergenceThresholds
from shine.validation.statistics import compute_bias_single_point

# Use 64-bit precision for accurate shear recovery
jax.config.update("jax_enable_x64", True)

print(f"JAX devices: {jax.devices()}")

## 1. Configuration

Parameters matched to the [ngmix metacal example](https://github.com/esheldon/ngmix/blob/master/examples/metacal/metacal.py):
- **Galaxy**: Exponential, hlr=0.5" (metacal `gal_hlr=0.5`)
- **PSF**: Gaussian, $\sigma = 0.382"$ (same FWHM as a 0.9" Moffat: $0.9/2.355$)
- **Pixel scale**: 0.263"/px (metacal `scale=0.263`)
- **Noise**: $\sigma = 10^{-6}$ (metacal default `noise=1e-6`)
- **Shear**: $g_1 = 0.01$, $g_2 = 0.00$ (metacal `shear_true=[0.01, 0.00]`)
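- **Position**: near the stamp center; the config below uses a Uniform prior on $x, y \in [23.5, 24.5]$ px rather than a fixed value (metacal uses random subpixel offsets)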
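
The PSF width quoted above follows from the standard Gaussian relation $\mathrm{FWHM} = 2\sqrt{2\ln 2}\,\sigma \approx 2.355\,\sigma$. A minimal sketch of that arithmetic, together with the PSF and stamp size in pixel units; the helper name `gaussian_sigma_from_fwhm` is hypothetical and not part of SHINE:

```python
import math


def gaussian_sigma_from_fwhm(fwhm):
    """Convert a Gaussian FWHM to its standard deviation (same units)."""
    return fwhm / (2.0 * math.sqrt(2.0 * math.log(2.0)))


PIXEL_SCALE = 0.263  # arcsec/pixel, as in this notebook's configuration
PSF_FWHM = 0.9       # arcsec, FWHM of the PSF being matched
STAMP_SIZE = 48      # pixels per side, as in this notebook's configuration

psf_sigma = gaussian_sigma_from_fwhm(PSF_FWHM)  # arcsec
psf_fwhm_px = PSF_FWHM / PIXEL_SCALE            # PSF FWHM in pixels
stamp_arcsec = STAMP_SIZE * PIXEL_SCALE         # stamp side length in arcsec

print(f"sigma = {psf_sigma:.3f} arcsec")
print(f"FWHM  = {psf_fwhm_px:.2f} px")
print(f"stamp = {stamp_arcsec:.2f} arcsec")
```

The PSF is sampled at a little over three pixels per FWHM, and the stamp is roughly 14 PSF FWHM across.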

In [None]:
# Ground truth shear (matches metacal shear_true=[0.01, 0.00])
G1_TRUE = 0.01
G2_TRUE = 0.00
N_BATCH = 10

config = ShineConfig(
    image=ImageConfig(
        pixel_scale=0.263,       # arcsec/pixel (metacal scale=0.263)
        size_x=48,
        size_y=48,
        n_objects=1,
        fft_size=128,
        noise=NoiseConfig(type="Gaussian", sigma=1e-6),  # metacal noise=1e-6
    ),
    psf=PSFConfig(
        type="Gaussian",
        sigma=0.382,             # arcsec (equivalent to Moffat FWHM=0.9: 0.9/2.355)
    ),
    gal=GalaxyConfig(
        type="Exponential",      # metacal galsim.Exponential
        flux=1.0,                # metacal default flux=1
        half_light_radius=0.5,   # arcsec (metacal gal_hlr=0.5)
        ellipticity=EllipticityConfig(type="E1E2", e1=0.0, e2=0.0),
        shear=ShearConfig(
            type="G1G2",
            g1=DistributionConfig(type="Normal", mean=0.0, sigma=0.05),
            g2=DistributionConfig(type="Normal", mean=0.0, sigma=0.05),
        ),
        position=PositionConfig(
            type="Uniform",
            x_min=23.5, x_max=24.5,
            y_min=23.5, y_max=24.5,
        ),
    ),
    inference=InferenceConfig(
        method="map",
        map_config=MAPConfig(num_steps=200, learning_rate=0.1),
        rng_seed=42,
    ),
)

print(f"Image: {config.image.size_x}x{config.image.size_y} px, "
      f"scale={config.image.pixel_scale}\"/px")
print(f"Galaxy: {config.gal.type}, flux={config.gal.flux}, "
      f"hlr={config.gal.half_light_radius}\"")
print(f"PSF: {config.psf.type}, sigma={config.psf.sigma}\"")
print(f"Noise sigma: {config.image.noise.sigma}")
print(f"True shear: g1={G1_TRUE}, g2={G2_TRUE}")
print(f"Batch size: {N_BATCH}")
print(f"Inference method: {config.inference.method}")

## 2. Generate Observations

All 10 galaxies share the same true shear but have independent (effectively zero) noise
realizations. We use `generate_batch_observations()` to generate and stack all images
into a single `(N_BATCH, nx, ny)` array with a shared PSF.

In [None]:
# Generate N_BATCH observations with the same true shear, stacked into a batch
seeds = list(range(100, 100 + N_BATCH))
shear_pairs = [(G1_TRUE, G2_TRUE)] * N_BATCH

batch_sim = generate_batch_observations(
    config, shear_pairs=shear_pairs, seeds=seeds, run_id_prefix="level0"
)

print(f"Generated {N_BATCH} observations")
print(f"Stacked image shape: {batch_sim.images.shape}")
print(f"PSF type: {type(batch_sim.psf_model).__name__}")
print(f"Run IDs: {batch_sim.run_ids}")

In [None]:
# Visualize the generated images on a 2-row grid
n_cols = (N_BATCH + 1) // 2
fig, axes = plt.subplots(2, n_cols, figsize=(3 * n_cols, 6), squeeze=False)
for i, ax in enumerate(axes.flat):
    ax.set_xticks([])
    ax.set_yticks([])
    if i >= N_BATCH:
        ax.axis("off")  # hide the unused panel when N_BATCH is odd
        continue
    ax.imshow(batch_sim.images[i], origin="lower", cmap="gray_r")
    ax.set_title(f"Galaxy {i}", fontsize=10)
fig.suptitle(
    f"Level 0: {N_BATCH} Exponential galaxies, g1={G1_TRUE}, g2={G2_TRUE}, "
    f"noise={config.image.noise.sigma}",
    fontsize=12,
)
plt.tight_layout()
plt.show()

## 3. Run Batched MAP Inference

Instead of looping over realizations, we use `build_batched_model()`, which creates
a single NumPyro model with a `plate("batch", N_BATCH)` over all parameters and
vmaps the rendering. This compiles to one XLA program and evaluates all 10 galaxies
in a single batched call (in parallel on a GPU, if one is available).

In [None]:
# Build batched model and run a single MAP inference over all realizations
import time

scene = SceneBuilder(config)
model_fn = scene.build_batched_model(N_BATCH)

map_cfg = config.inference.map_config
print(f"Inference method: {config.inference.method}")
print(f"MAP: {map_cfg.num_steps} steps, lr={map_cfg.learning_rate}")
print(f"Batch size: {N_BATCH}\n")

rng_key = jax.random.PRNGKey(config.inference.rng_seed)
engine = Inference(model=model_fn, config=config.inference)

t0 = time.perf_counter()
batched_estimates = engine.run_map(
    rng_key=rng_key,
    observed_data=batch_sim.images,
    extra_args={"psf": batch_sim.psf_model},
    map_config=map_cfg,
)
elapsed = time.perf_counter() - t0

print(f"\nBatched MAP completed in {elapsed:.2f} s")
for k, v in batched_estimates.items():
    print(f"  {k}: shape={np.array(v).shape}")

# Split batched MAP estimates into per-realization InferenceData objects.
# We do this manually because _map_estimates_to_idata + split_batched_idata
# doesn't work for MAP: ArviZ names the batch dim "g1_dim_0" not "batch".
run_ids = batch_sim.run_ids
idata_list = []
for i in range(N_BATCH):
    single = {}
    for k, v in batched_estimates.items():
        arr = np.atleast_1d(np.array(v))
        # Index batched arrays along the leading axis; pass shared values through
        single[k] = arr[i] if arr.shape[0] == N_BATCH else arr
    idata = Inference._map_estimates_to_idata(single)
    idata_list.append(idata)

# Print per-realization MAP estimates
for i, run_id in enumerate(run_ids):
    g1_val = float(idata_list[i].posterior.g1.values.flatten()[0])
    g2_val = float(idata_list[i].posterior.g2.values.flatten()[0])
    print(f"  {run_id}: g1={g1_val:+.6f}, g2={g2_val:+.6f}")

## 4. Extract Estimates

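Inspect the per-realization `InferenceData` objects built from the batched MAP
result: the variables they contain, the recorded inference method, and the shapes
of the stored `g1` and `g2` point estimates.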

In [None]:
print(f"Created {N_BATCH} per-realization InferenceData objects from batched MAP")
print(f"Run IDs: {run_ids}")
print(f"Example: {run_ids[0]}, posterior vars: {list(idata_list[0].posterior.data_vars)}")
print(f"  inference_method: {idata_list[0].posterior.attrs.get('inference_method')}")
print(f"  g1 shape: {idata_list[0].posterior.g1.values.shape}")
print(f"  g2 shape: {idata_list[0].posterior.g2.values.shape}")

## 5. Extract Diagnostics

For each realization, extract:
- **Shear estimates**: MAP point estimate (mean = median = value, std = 0)
- **Convergence diagnostics**: sentinel values for MAP (rhat=1, ess=1)

Level 0 acceptance criterion for MAP: the estimate should be very close to truth.

In [None]:
# Extract results for all realizations
results = []
for run_id, single_idata in zip(run_ids, idata_list):
    g1_est = extract_shear_estimates(single_idata, "g1")
    g2_est = extract_shear_estimates(single_idata, "g2")
    diag = extract_convergence_diagnostics(single_idata)
    method = single_idata.posterior.attrs.get("inference_method", "nuts")
    passed = check_convergence(diag, ConvergenceThresholds(), method=method)
    results.append({
        "run_id": run_id,
        "g1_est": g1_est,
        "g2_est": g2_est,
        "diagnostics": diag,
        "passed": passed,
    })

# Summary table
print(f"{'Run ID':<14} {'g1 estimate':>14} {'g2 estimate':>14} {'Pass':>5}")
print("-" * 55)
for r in results:
    print(
        f"{r['run_id']:<14} "
        f"{r['g1_est'].mean:>14.6f} "
        f"{r['g2_est'].mean:>14.6f} "
        f"{'OK' if r['passed'] else 'FAIL':>5}"
    )

## 6. Acceptance Criteria Check

For Level 0 MAP, each realization must satisfy:
1. MAP estimate close to truth (absolute offset $< 10^{-3}$)
2. Convergence check passes (trivially true for MAP, which has no sampling diagnostics)

In [None]:
MAX_ABS_OFFSET = 1e-3

all_passed = True
print("Level 0 Acceptance Criteria (MAP)")
print("=" * 80)

for r in results:
    run_id = r["run_id"]
    g1, g2 = r["g1_est"], r["g2_est"]

    # For MAP: check absolute offset from truth
    g1_offset = abs(g1.mean - G1_TRUE)
    g2_offset = abs(g2.mean - G2_TRUE)
    offset_ok = g1_offset < MAX_ABS_OFFSET and g2_offset < MAX_ABS_OFFSET

    passed = offset_ok and r["passed"]
    r["accepted"] = passed  # full Level 0 verdict: offset check AND convergence check
    all_passed = all_passed and passed

    status = "PASS" if passed else "FAIL"
    print(f"\n{run_id} [{status}]")
    print(f"  g1: truth={G1_TRUE:+.4f}  MAP={g1.mean:+.6f}  "
          f"|offset|={g1_offset:.2e}  {'ok' if g1_offset < MAX_ABS_OFFSET else 'FAIL'}")
    print(f"  g2: truth={G2_TRUE:+.4f}  MAP={g2.mean:+.6f}  "
          f"|offset|={g2_offset:.2e}  {'ok' if g2_offset < MAX_ABS_OFFSET else 'FAIL'}")

print("\n" + "=" * 80)
print(f"Overall Level 0 result: {'ALL PASSED' if all_passed else 'SOME FAILED'}")
# Count the full acceptance verdict, not just the convergence flag
print(f"  {sum(1 for r in results if r['accepted'])}/{len(results)} "
      f"realizations passed")

## 7. Multiplicative Bias

For each realization, compute the multiplicative bias:
$$m = \frac{\bar{g}_{\rm est}}{g_{\rm true}} - 1$$

At Level 0 (noiseless, self-consistent model), we expect $m \approx 0$.

Since $g_2^{\rm true} = 0$ (matching metacal), we cannot compute $m$ for $g_2$
(division by zero). Instead we report the additive residual $c_2 = \bar{g}_2 - 0$.
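
The two quantities can be sketched as plain functions. This is an illustrative stand-in, not the implementation of `compute_bias_single_point`; the function names and the input values are assumptions chosen to make the arithmetic easy to follow.

```python
def multiplicative_bias(g_est, g_true):
    """Return m = g_est / g_true - 1; undefined when g_true == 0."""
    if g_true == 0.0:
        raise ValueError("m is undefined for g_true = 0; report c instead")
    return g_est / g_true - 1.0


def additive_residual(g_est, g_true=0.0):
    """Return c = g_est - g_true, usable even when g_true == 0."""
    return g_est - g_true


# Illustrative numbers only, not results from this notebook.
# An estimate that is off by 1e-3 in absolute terms on g_true = 0.01:
m_at_tolerance = multiplicative_bias(g_est=0.011, g_true=0.01)
c_example = additive_residual(g_est=2e-7)

print(f"m = {m_at_tolerance:.3f}, c = {c_example:.1e}")
```

Because $g_1^{\rm true} = 0.01$ is small, an absolute offset of $10^{-3}$ already corresponds to $|m| = 0.1$. The multiplicative bias is therefore a much more sensitive diagnostic than the absolute-offset acceptance threshold.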

In [None]:
bias_g1_list = []

print(f"{'Run ID':<14} {'m(g1)':>12} {'c(g2)':>12}")
print("-" * 42)

for r in results:
    b1 = compute_bias_single_point(G1_TRUE, r["g1_est"].mean, r["g1_est"].std, "g1")
    bias_g1_list.append(b1)
    # g2_true=0 so we report additive residual instead of multiplicative bias
    c2 = r["g2_est"].mean - G2_TRUE
    print(f"{r['run_id']:<14} {b1.m:>12.6f} {c2:>12.2e}")

# Ensemble average
m_g1_vals = np.array([b.m for b in bias_g1_list])
c_g2_vals = np.array([r["g2_est"].mean - G2_TRUE for r in results])
print("-" * 42)
print(f"{'Ensemble mean':<14} {m_g1_vals.mean():>12.6f} {c_g2_vals.mean():>12.2e}")

## 8. Diagnostic Plots

In [None]:
# Collect all g1, g2 MAP estimates
g1_means = np.array([r["g1_est"].mean for r in results])
g2_means = np.array([r["g2_est"].mean for r in results])
indices = np.arange(N_BATCH)

fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# g1 estimates
axes[0].scatter(indices, g1_means, marker="o", s=60, color="steelblue",
                zorder=5, label="MAP estimate")
axes[0].axhline(G1_TRUE, color="red", ls="--", lw=2, label=f"Truth = {G1_TRUE}")
axes[0].fill_between([-0.5, N_BATCH - 0.5],
                     G1_TRUE - MAX_ABS_OFFSET,
                     G1_TRUE + MAX_ABS_OFFSET,
                     color="red", alpha=0.1, label=f"$\\pm${MAX_ABS_OFFSET}")
axes[0].set_xlabel("Realization")
axes[0].set_ylabel("$g_1$")
axes[0].set_title("$g_1$ Recovery (MAP)")
axes[0].legend(fontsize=9)
axes[0].set_xticks(indices)

# g2 estimates
axes[1].scatter(indices, g2_means, marker="o", s=60, color="coral",
                zorder=5, label="MAP estimate")
axes[1].axhline(G2_TRUE, color="red", ls="--", lw=2, label=f"Truth = {G2_TRUE}")
axes[1].fill_between([-0.5, N_BATCH - 0.5],
                     G2_TRUE - MAX_ABS_OFFSET,
                     G2_TRUE + MAX_ABS_OFFSET,
                     color="red", alpha=0.1, label=f"$\\pm${MAX_ABS_OFFSET}")
axes[1].set_xlabel("Realization")
axes[1].set_ylabel("$g_2$")
axes[1].set_title("$g_2$ Recovery (MAP)")
axes[1].legend(fontsize=9)
axes[1].set_xticks(indices)

fig.suptitle(f"Level 0: MAP Shear Recovery Across {N_BATCH} Realizations", fontsize=13)
plt.tight_layout()
plt.show()

In [None]:
# Bias plots: m(g1) and c(g2)
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# g1: multiplicative bias
axes[0].scatter(indices, m_g1_vals, marker="s", s=60, color="steelblue")
axes[0].axhline(0, color="red", ls="--", lw=2, label="$m = 0$ (no bias)")
axes[0].axhline(m_g1_vals.mean(), color="green", ls=":", lw=1.5,
                label=f"Mean $m$ = {m_g1_vals.mean():.2e}")
axes[0].set_xlabel("Realization")
axes[0].set_ylabel("$m_{g_1}$")
axes[0].set_title("Multiplicative Bias $g_1$")
axes[0].legend(fontsize=9)
axes[0].set_xticks(indices)

# g2: additive bias (g2_true = 0)
axes[1].scatter(indices, c_g2_vals, marker="s", s=60, color="coral")
axes[1].axhline(0, color="red", ls="--", lw=2, label="$c = 0$ (no bias)")
axes[1].axhline(c_g2_vals.mean(), color="green", ls=":", lw=1.5,
                label=f"Mean $c$ = {c_g2_vals.mean():.2e}")
axes[1].set_xlabel("Realization")
axes[1].set_ylabel("$c_{g_2}$")
axes[1].set_title("Additive Bias $g_2$ ($g_2^{\\rm true} = 0$)")
axes[1].legend(fontsize=9)
axes[1].set_xticks(indices)

fig.suptitle("Level 0: Bias per Realization (MAP)", fontsize=13)
plt.tight_layout()
plt.show()

In [None]:
# g1 vs g2 MAP estimates across all realizations
fig, ax = plt.subplots(1, 1, figsize=(6, 6))
ax.scatter(g1_means, g2_means, s=60, color="steelblue", zorder=5, label="MAP estimates")
ax.scatter([G1_TRUE], [G2_TRUE], c="red", s=150, marker="*",
           zorder=10, label="Truth")
ax.set_xlabel("$g_1$")
ax.set_ylabel("$g_2$")
ax.set_title("MAP Estimates: $g_1$ vs $g_2$")
ax.legend()
fig.tight_layout()
plt.show()

## 9. Summary

This Level 0 test validates that SHINE's forward model is **self-consistent**: when
the data is generated from the same model with no noise, the MAP estimate recovers
the true shear values with negligible bias.

Since Level 0 is noiseless, MAP is the natural and fastest inference method --
full MCMC is unnecessary. For higher validation levels (Level 1+) with realistic
noise, NUTS or VI should be used instead.

In [None]:
# Final summary
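# Count realizations that pass both the convergence check and the offset criterion
n_passed = sum(
    1 for r in results
    if r["passed"]
    and abs(r["g1_est"].mean - G1_TRUE) < MAX_ABS_OFFSET
    and abs(r["g2_est"].mean - G2_TRUE) < MAX_ABS_OFFSET
)

print("Level 0 MAP Inference Summary")
print("=" * 50)
print(f"  Batch size:           {N_BATCH}")
print(f"  True shear:           g1={G1_TRUE}, g2={G2_TRUE}")
print("  Inference method:     MAP")
print(f"  Passed:               {n_passed}/{N_BATCH}")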
print(f"  Mean m(g1):           {m_g1_vals.mean():.2e}")
print(f"  Mean c(g2):           {c_g2_vals.mean():.2e}")
print(f"  Max |g1 offset|:      {max(abs(r['g1_est'].mean - G1_TRUE) for r in results):.2e}")
print(f"  Max |g2 offset|:      {max(abs(r['g2_est'].mean - G2_TRUE) for r in results):.2e}")
print("=" * 50)