# Power Spectrum Inference

This notebook demonstrates **power-spectrum-level** Bayesian inference of cosmological
parameters from convergence maps using the Knox-formula likelihood.

Compared with the full-field approach (notebook 11), this is computationally cheaper
because the data compression step (maps → C_ell) is done once, up front. The workflow
has four steps:

1. Load observed convergence maps from a Parquet catalog
2. Compress to all auto- (and optionally cross-) angular power spectra
3. Build the NumPyro power-spectrum model with the Knox-formula likelihood
4. Sample the posterior with MCMC (`batched_sampling`)

## Setup

In [None]:
import os

# Set before importing JAX: lets XLA preallocate up to 97% of device memory
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.97"

import jax
import jax.numpy as jnp
import jax_cosmo as jc
import jax_fli as jfli
import numpy as np
from numpyro.handlers import condition

print(f"Number of devices: {jax.device_count()}")
print(f"Devices: {jax.devices()}")
# Keep single precision (float32) to reduce memory use
jax.config.update("jax_enable_x64", False)

## Load Catalog

Set `CATALOG_PATH` to a Parquet file produced by a previous simulation run.
The catalog contains convergence maps for each redshift bin as a batched
`SphericalDensity` (or `FlatDensity`) field.

In [None]:
CATALOG_PATH = "output/fields/born_catalog.parquet"  # <- update to your path

catalog = jfli.io.Catalog.from_parquet(CATALOG_PATH)
print(f"Catalog entries: {len(catalog)}")

# Take the first catalog entry: a batched SphericalDensity with shape (n_bins, npix)
kappa_field = catalog.field[0]
print(f"Kappa field type:  {type(kappa_field).__name__}")
print(f"Kappa field shape: {kappa_field.array.shape}")

## Compute Observed C_ell

`cross_angular_cl()` computes all K = B*(B+1)/2 auto- and cross-spectra for a
batched field with B bins in a single healpy call.  The returned `PowerSpectrum`
stores the ell-array in `.wavenumber` and all spectra in `.array` with shape
`(K, n_ell)`, ordered `(0,0), (0,1), …, (B-1,B-1)`.
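
To map a row of `.array` back to its bin pair, the ordering can be reproduced with the
standard library. This is a minimal sketch that assumes the rows follow exactly the
upper-triangular, row-major order listed above; the helper name `pair_index` is
hypothetical and not part of `jax_fli`.

```python
from itertools import combinations_with_replacement


def pair_index(n_bins):
    """Return the (i, j) bin pairs with i <= j, in row-major order."""
    return list(combinations_with_replacement(range(n_bins), 2))


pairs = pair_index(3)
# K = B * (B + 1) / 2 spectra for B bins
assert len(pairs) == 3 * (3 + 1) // 2
# Rows that hold auto-spectra are those with i == j
auto_rows = [row for row, (i, j) in enumerate(pairs) if i == j]
print(pairs)
print(auto_rows)
```

With such a lookup, `ps.array[auto_rows]` would select only the auto-spectra, provided
the ordering assumption holds.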

In [None]:
# For spherical maps use lmax; for flat maps use ell_edges instead
LMAX = 512

ps = kappa_field.cross_angular_cl(lmax=LMAX)

print(f"wavenumber shape: {ps.wavenumber.shape}")
print(f"spectra shape:    {ps.array.shape}")
print(f"ell range: [{ps.wavenumber[0]:.0f}, {ps.wavenumber[-1]:.0f}]")

## Build Configuration

We pass `ells=ps.wavenumber` so the theory model evaluates at exactly the
same multipoles as the observed spectra. `f_sky` corrects the Knox formula
for partial sky coverage; set it to the observed sky fraction.
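
For reference, the Knox formula gives the Gaussian variance of an auto-spectrum
estimate as `2 (C_ell + N_ell)^2 / ((2 ell + 1) f_sky)` per multipole. The sketch below
evaluates it with NumPy. The function name `knox_variance` and its `noise` argument are
illustrative assumptions, and the likelihood built by `jfli.ppl.powerspec_probmodel`
may differ in detail, for example in how cross-spectra or binning are treated.

```python
import numpy as np


def knox_variance(ells, cl, noise=0.0, f_sky=1.0):
    """Gaussian variance of an auto-spectrum estimate at each multipole."""
    ells = np.asarray(ells, dtype=float)
    cl = np.asarray(cl, dtype=float)
    # (2 ell + 1) * f_sky is the effective number of independent modes at ell
    n_modes = (2.0 * ells + 1.0) * f_sky
    return 2.0 * (cl + noise) ** 2 / n_modes


ells = np.array([2.0, 10.0, 100.0])
cl = np.array([1e-8, 1e-9, 1e-10])
var_full = knox_variance(ells, cl, f_sky=1.0)
var_half = knox_variance(ells, cl, f_sky=0.5)
# Halving the sky fraction doubles the variance at every multipole
print(var_half / var_full)
```

This is why an overestimated `f_sky` makes the inferred parameter constraints look
tighter than the data support.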

In [None]:
nz_sources = jfli.io.get_stage3_nz_shear()

priors = {
    "Omega_c": jfli.infer.dist.PreconditionnedUniform(0.1, 0.5),
    "sigma8": jfli.infer.dist.PreconditionnedUniform(0.6, 1.0),
}

config = jfli.ppl.Configurations(
    mesh_size=(64, 64, 64),  # not used by power-spec model but required by dataclass
    box_size=(500.0, 500.0, 500.0),
    nside=kappa_field.nside,
    nz_shear=nz_sources,
    fiducial_cosmology=jc.Planck18,
    sigma_e=0.26,
    priors=priors,
    geometry="spherical",
    ells=ps.wavenumber,
    f_sky=1.0,  # update to the actual observed sky fraction
)

print(f"nside:     {config.nside}")
print(f"f_sky:     {config.f_sky}")
print(f"n_ells:    {len(config.ells)}")

## Condition on Observed Data

`numpyro.handlers.condition` fixes the power-spectrum sample site (`c_ell` in the
cell below) to the observed spectra, turning the joint model into an unnormalized
posterior over the cosmological parameters.

In [None]:
model = jfli.ppl.powerspec_probmodel(config)
# Flatten the (K, n_ell) spectra into a single vector for the observed site
obs = {"c_ell": ps.array.flatten()}
conditioned = condition(model, data=obs)

# Starting point for HMC: fiducial cosmology values
fiducial = jc.Planck18()
init_params = {
    "Omega_c": jnp.array(fiducial.Omega_c),
    "sigma8": jnp.array(fiducial.sigma8),
}

print("Init params:")
for k, v in init_params.items():
    print(f"  {k}: {float(v):.4f}")

## Define Save Callback

For power-spectrum inference the samples are small (just the cosmological parameters),
so we save each batch as a simple `.npz` file instead of a Parquet catalog.
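
A callback of this kind could look like the sketch below. The signature that
`batched_sampling` expects from `save_callback` is not shown in this notebook, so the
`(samples, path, batch_id)` arguments and the name `save_samples_npz` are assumptions
to be adapted to the real interface. The `samples_*.npz` file pattern is chosen to
match the loader at the end of the notebook.

```python
import os
import tempfile

import numpy as np


def save_samples_npz(samples, path, batch_id):
    """Write one batch of samples to <path>/samples_<batch_id>.npz."""
    os.makedirs(path, exist_ok=True)
    # Zero-pad the batch index so a lexicographic sort keeps batches in order
    filename = os.path.join(path, f"samples_{batch_id:04d}.npz")
    # np.asarray also copies JAX arrays back to host memory as NumPy arrays
    np.savez(filename, **{k: np.asarray(v) for k, v in samples.items()})
    return filename


# Usage with synthetic values standing in for one batch of posterior samples
fake = {"Omega_c": np.linspace(0.2, 0.3, 5), "sigma8": np.linspace(0.7, 0.9, 5)}
with tempfile.TemporaryDirectory() as tmp:
    out = save_samples_npz(fake, tmp, batch_id=0)
    with np.load(out) as data:
        loaded = {k: data[k] for k in data.files}
print(sorted(loaded))
```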

## Run MCMC Sampling

`batched_sampling` runs NUTS in sequential batches, checkpointing state after
each batch so a long run can be interrupted and resumed.

For production use, increase `num_warmup`, `num_samples`, and `batch_count`.
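
The resume pattern can be illustrated independently of `jax_fli`. The sketch below is
not the library's implementation: `run_batches` is a hypothetical name, and a seeded
random draw stands in for a batch of NUTS samples. A real sampler would also restore
the chain state from the last checkpoint; the sketch only shows how skipping batches
whose output already exists makes a run resumable.

```python
import os
import tempfile

import numpy as np


def run_batches(path, batch_count, num_samples, seed=0):
    """Run batches in order, skipping any whose output file already exists."""
    os.makedirs(path, exist_ok=True)
    executed = []
    for batch_id in range(batch_count):
        filename = os.path.join(path, f"samples_{batch_id:04d}.npz")
        if os.path.exists(filename):
            continue  # this batch finished in an earlier run
        # Stand-in for one batch of MCMC, seeded per batch for reproducibility
        rng = np.random.default_rng([seed, batch_id])
        np.savez(filename, theta=rng.normal(size=num_samples))
        executed.append(batch_id)
    return executed


with tempfile.TemporaryDirectory() as tmp:
    first = run_batches(tmp, batch_count=2, num_samples=5)   # run stops after 2 batches
    second = run_batches(tmp, batch_count=4, num_samples=5)  # picks up at batch 2
print(first, second)
```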

In [None]:
OUTPUT_PATH = "output/ps_inference"
sample_key = jax.random.PRNGKey(42)

jfli.infer.batched_sampling(
    conditioned,
    init_params=init_params,
    path=OUTPUT_PATH,
    rng_key=sample_key,
    num_warmup=100,
    num_samples=500,
    batch_count=10,
    sampler="NUTS",
    backend="blackjax",
    progress_bar=True,
    # NOTE: the loader below expects samples_*.npz files; make sure this callback writes that format
    save_callback=jfli.infer.sample2catalog(config),
)

## Load and Inspect Results

In [None]:
import glob

npz_files = sorted(glob.glob(f"{OUTPUT_PATH}/samples_*.npz"))
print(f"Found {len(npz_files)} batch file(s): {npz_files}")

all_samples = {}
for f in npz_files:
    # Use a context manager so each archive is closed after reading
    with np.load(f) as data:
        for k in data.files:
            all_samples.setdefault(k, []).append(data[k])

# Concatenate batches along the sample axis
all_samples = {k: np.concatenate(v, axis=0) for k, v in all_samples.items()}

print("\nParameter shapes and posterior means:")
for k, v in all_samples.items():
    if v.ndim >= 1 and np.issubdtype(v.dtype, np.floating):
        print(f"  {k}: shape={v.shape}  mean={v.mean():.4f}  std={v.std():.4f}")
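
Means and standard deviations can hide skewed posteriors, so a percentile summary is
often worth printing as well. The sketch below uses synthetic draws so that it runs on
its own; the helper name `summarize` is hypothetical, and in this notebook you would
pass the arrays in `all_samples` instead.

```python
import numpy as np


def summarize(samples):
    """Return the median and the central 68% interval of a 1-D sample array."""
    lo, med, hi = np.percentile(samples, [16, 50, 84])
    return med, lo, hi


rng = np.random.default_rng(0)
# Synthetic stand-in for posterior draws of sigma8
draws = rng.normal(loc=0.8, scale=0.05, size=20_000)
med, lo, hi = summarize(draws)
print(f"sigma8 = {med:.3f} (+{hi - med:.3f} / -{med - lo:.3f})")
```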