Differentiable cosmological forward modeling on JAX
jax-fli is a JAX toolkit for end-to-end differentiable cosmological simulations. It chains initial conditions, Lagrangian Perturbation Theory, Particle-Mesh N-body integration, lightcone painting (3D, flat-sky, HEALPix), gravitational lensing (Born and ray-tracing), and angular power spectrum estimation into a single differentiable pipeline. The library supports multi-GPU distribution via JAX sharding, reversible solvers for memory-efficient backpropagation, and probabilistic inference with BlackJAX/NumPyro.
ICs ──> LPT ──> PM N-body ──> Lightcone Painting ──> Lensing ──> Power Spectra
│ │
symplectic KDK/KKD 3D / flat-sky / HEALPix
+ PGD correction + interpolation kernels
- N-body solvers -- Symplectic DKD (
EfficientDriftDoubleKick) and reversible KKD (ReversibleDoubleKickDrift) integrators - Painting targets -- 3D density, flat-sky 2D projection, and HEALPix spherical maps with CIC / bilinear / NGP / RBF schemes
- Interpolation kernels --
DriftInterp,OnionTiler, andTelephotoInterpfor on-the-fly lightcone construction beyond the box boundary - Correction kernels -- PGD (position-based) and Sharpening (velocity-based, reversible) for sub-grid halo correction
- Gravitational lensing -- Born approximation (fully JIT-able) and ray-tracing (via Dorian) convergence maps
- Multi-GPU -- Distributed simulations via JAX sharding with automatic halo exchange
- Immutable Field PyTrees --
DensityField,ParticleField,FlatDensity,SphericalDensitycarrying arrays + metadata through the pipeline - Power spectra -- 3D P(k), angular C_ell, transfer functions, and theory predictions with Halofit
- Probabilistic inference -- Deterministic forward model builder + NumPyro wrappers + BlackJAX batched sampling
- I/O -- Orbax checkpointing, Parquet serialization, HuggingFace Dataset integration, CosmoGrid/GowerStreet loaders
pip install -e ".[all]"This installs all optional dependencies (lensing, ray-tracing, catalogs, sampling). For specific extras:
pip install -e ".[dev]" # Development tools (pytest, ruff, pre-commit)
pip install -e ".[raytrace]" # Ray-tracing via Dorian
pip install -e ".[catalog]" # Parquet / HuggingFace catalog supportNote: This package depends on custom forks of
jaxpmandjax_cosmothat are git-pinned inpyproject.toml.
import jax
import jax.numpy as jnp
import jax_cosmo as jc
import jax_fli as jfli
key = jax.random.PRNGKey(42)
cosmo = jc.Planck18()
# 1. Gaussian initial conditions
initial_field = jfli.gaussian_initial_conditions(
key, mesh_size=(256, 256, 256), box_size=(1000.0, 1000.0, 1000.0),
cosmo=cosmo, nside=256,
)
# 2. LPT displacement + momentum
dx, p = jfli.lpt(cosmo, initial_field, ts=0.1, order=1)
# 3. PM N-body with spherical lightcone output
solver = jfli.ReversibleDoubleKickDrift(
interp_kernel=jfli.NoInterp(painting=jfli.PaintingOptions(target="spherical")),
)
lightcone = jfli.nbody(cosmo, dx, p, t1=1.0, dt0=0.05, nb_shells=4, solver=solver)
# 4. Born lensing convergence
nz = [jfli.tophat_z(0.0, 0.5, gals_per_arcmin2=1.0)]
kappa = jfli.born(cosmo, lightcone, nz_shear=nz)
# 5. Angular power spectrum
cl = kappa.angular_cl(method="healpy")| # | Notebook | Description |
|---|---|---|
| 01 | Basics | Core objects, field types, painting targets, power spectra, and I/O |
| 02 | LPT Simulation | Full lightcone using LPT with arrays of scale factors and spherical shells |
| 03 | PM Simulation | PM N-body solvers, correction kernels, and painting targets |
| 04 | Distributed PM | Multi-GPU distributed PM using JAX sharding and device meshes |
| 05 | PM Interpolation | Lightcone painting with TelephotoInterp and OnionTiler kernels |
| 06 | Advanced PM | Production pipeline combining PGD correction, OnionTiler, and theory validation |
| 07 | Lensing | Born approximation and ray-tracing convergence maps |
| 08 | External Catalogs | Loading CosmoGrid and GowerStreet simulation data |
| 09 | Comparison | Validation against CosmoGrid through power spectra and lensing maps |
| 10 | Probabilistic Modeling | Custom MCMC distributions and reparameterization transforms |
| 11 | Full-Field Inference | Bayesian inference of initial conditions and cosmological parameters |
| Module | Purpose |
|---|---|
fields/ |
Immutable PyTree containers (DensityField, ParticleField, FlatDensity, SphericalDensity) |
initial.py |
Gaussian initial conditions and interpolation to mesh |
pm/ |
Particle-mesh engine: lpt(), nbody(), symplectic solvers, PGD correction, integration loop |
lensing/ |
Born approximation and ray-tracing convergence maps |
power/ |
Power spectrum estimation (P(k), C_ell, transfer, coherence), theory predictions |
probabilistic_models/ |
Deterministic forward model builder, NumPyro wrappers, Configurations dataclass |
sampling/ |
BlackJAX batched sampling, DistributedNormal, chain plotting |
io/ |
Checkpoint persistence, HuggingFace catalog, CosmoGrid/GowerStreet loaders |
utils.py |
Lightcone geometry helpers, comoving distances, scale factors |
parameters.py |
Predefined cosmologies (e.g. Planck18) |
# Install dev dependencies
pip install -e ".[dev]"
# Run tests
pytest
# Lint
ruff check .
# Format + import sort (pre-commit uses yapf + isort)
pre-commit run --all-filesMIT

