JAXlaxy is a curated, opinionated, and constantly updated map of the JAX ecosystem.
In 2026, JAX has evolved from a research curiosity into the backbone of global AI. It powers the world's largest Foundation Models, real-time physics simulators, and structural biology breakthroughs. While the original DeepMind "JAX Ecosystem" post (2020) is a classic archive, the universe has expanded. JAXlaxy filters the noise to show you the brightest stars.
Every entry ships with a health indicator – 🟢 Active · 🟡 Stable · 🔴 Legacy – so you can tell at a glance whether a library is the right launchpad in 2026 or a relic better admired than flown.
- ☀️ The Sun – Core & Kernels
- 🪐 The Giants – Neural Network Frameworks
- 🛰️ The Satellites – Training Infrastructure
- 🌌 Constellations – LLM & Foundation-Model Training
- 🧪 Scientific Computing & Simulation
- 🎲 Probabilistic Programming
- 🤖 Reinforcement Learning & Evolution
- 🌐 Domain Libraries – Graphs, Vision, Brain Dynamics
- 🧭 The Pathfinder – What should you choose?
- 🛸 Onramps – Reference Implementations
- ⚠️ Legacy Radar – The warning zone
## ☀️ The Sun – Core & Kernels

The foundational technologies that make high-performance JAX computing possible. If your stack has JAX, it has these.
- 🟢 JAX: Autograd and XLA-powered numerical computing – the sun of our universe.
- 🟢 Pallas: JAX-native kernel authoring language for writing custom TPU/GPU kernels without leaving Python.
- 🟢 OpenXLA: The open-source compiler that turns your JAX program into optimized accelerator code.
- 🟢 torchax: Run PyTorch model code directly on JAX – the bridge for teams migrating in from the PyTorch universe.
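Everything above composes through the same few transforms. A minimal plain-JAX sketch (no ecosystem libraries; the toy loss is invented for illustration) of `jit`, `grad`, and `vmap` stacking:

```python
import jax
import jax.numpy as jnp

def loss(w, x, y):
    # Simple squared error for a linear model.
    return jnp.mean((x @ w - y) ** 2)

# Compose transforms: JIT-compile the gradient of the loss,
# then vectorize it over a batch of weight candidates with vmap.
grad_loss = jax.jit(jax.grad(loss))
batched_grad = jax.vmap(grad_loss, in_axes=(0, None, None))

x = jnp.ones((8, 3))
y = jnp.zeros(8)
ws = jnp.stack([jnp.zeros(3), jnp.ones(3)])
g = batched_grad(ws, x, y)
print(g.shape)  # one gradient per weight candidate: (2, 3)
```

This transform-composition style, not any particular API, is what every library below builds on.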
## 🪐 The Giants – Neural Network Frameworks

The primary solar systems for building and training neural networks in JAX.
- 🟢 Flax (NNX): The 2026 standard for DeepMind and Google-scale research. NNX brings ergonomic, object-oriented state on top of JAX's functional core.
- 🟢 Equinox: Everything is a PyTree. Minimal magic, maximum transparency – the default pick for scientific ML and PyTorch converts.
- 🟢 Penzai: DeepMind's toolkit for legible, introspectable, surgically editable neural networks – the interpretability researcher's first choice.

🟡 Pragmatic multi-framework options: teams already invested in the Keras or HuggingFace ecosystem can use Keras 3 with its JAX backend or HuggingFace Transformers' Flax models – not JAX-native in design, but battle-tested bridges for real production stacks.
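All of these frameworks wrap the same functional core: parameters live in a PyTree, and the forward pass is a pure function of (params, inputs). A hand-rolled sketch of that underlying pattern, in plain JAX (the linear layer and names are illustrative only):

```python
import jax
import jax.numpy as jnp

# Parameters as a PyTree (here, a plain dict).
def init_params(key, in_dim, out_dim):
    wkey, _ = jax.random.split(key)
    return {
        "w": jax.random.normal(wkey, (in_dim, out_dim)) * 0.01,
        "b": jnp.zeros(out_dim),
    }

# Forward pass as a pure function of (params, inputs).
def apply(params, x):
    return x @ params["w"] + params["b"]

params = init_params(jax.random.PRNGKey(0), 4, 2)
x = jnp.ones((3, 4))
out = apply(params, x)

# Gradients flow through the whole PyTree at once.
grads = jax.grad(lambda p: apply(p, x).sum())(params)
```

Flax NNX, Equinox, and Penzai differ mainly in how they dress up this pattern: as mutable modules, as callable PyTrees, or as named, inspectable structures.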
## 🛰️ The Satellites – Training Infrastructure

The utilities that keep your training loop stable, your weights safe, and your data moving at TPU speed.
- 🟢 Optax: The undisputed king of gradient processing – composable optimizer transforms used in every serious JAX training loop.
- 🟢 Orbax: The 2026 standard for multi-host checkpointing, model exporting, and resumable training at scale.
- 🟢 Grain: Deterministic, JAX-native high-throughput data loading – the TPU-era replacement for `tf.data`.
- 🟢 Chex: DeepMind's assertions and testing toolkit for writing JAX code you can actually trust.
- 🟢 jaxtyping: Shape-and-dtype-aware type hints that catch bugs before the first JIT compile.
- 🟡 safejax: Serialize Flax / Haiku / Equinox parameters via `safetensors` – safer than pickling, portable across frameworks.
- 🟢 jax-tqdm: Add a real progress bar to JIT-compiled `jax.lax.scan` and training loops – one decorator, zero friction.
- 🟢 JAX Toolbox: NVIDIA's nightly CI and optimized container images for running JAX workloads on H100/B200 GPUs.
- 🟢 mpi4jax: Zero-copy MPI collectives inside JIT-compiled JAX code – the bridge for classical HPC clusters.
## 🌌 Constellations – LLM & Foundation-Model Training

Large-scale training stacks. Pick by workload scale and how much of the stack you want to own.
- 🟢 MaxText: Google's flagship pure-JAX LLM reference – scales from single-TPU to multi-pod without leaving Python.
- 🟢 Tunix: Google's post-training toolkit on JAX – SFT, RLHF (PPO/GRPO/DAPO), and agentic RL with tool use, built on Flax NNX.
- 🟢 Levanter: Stanford CRFM's scalable, reproducible foundation-model trainer with named tensors and bit-level determinism.
- 🟢 EasyDeL: Opinionated training and serving for Llama/Mixtral/Falcon/Qwen families in JAX – ergonomics-first.
- 🟢 kvax: A production-grade FlashAttention implementation for JAX with document-mask and context-parallel support.
- 🟡 Lorax: Automatic LoRA injection for any JAX model (Flax, Haiku, Equinox) – one line to fine-tune at a fraction of the memory.
- 🟢 FlaxDiff: Multi-node, multi-device diffusion-model training on TPUs – the LLM world has a sibling.
- 🟡 EasyLM: Pre-train, fine-tune, evaluate, and serve LLMs in JAX/Flax – the original reference used in several early open models; stable but slower-moving.
## 🧪 Scientific Computing & Simulation

Where JAX truly outshines PyTorch: differentiable physics, biology, cosmology, and inverse problems.
- 🟢 AlphaFold 3: DeepMind's state-of-the-art predictor for protein, nucleic-acid, and ligand structure. Weights gated; non-commercial research license.
- 🟡 jax-unirep: Fast UniRep protein-embedding models – the pragmatic starting point for sequence-level protein ML.
- 🟢 Diffrax: Numerical ODE / SDE / CDE solvers in JAX – the canonical answer for Neural ODEs and continuous-time models.
- 🟢 Optimistix: Root-finding, minimization, fixed points, and nonlinear least-squares – the "SciPy optimize" of JAX.
- 🟢 JAXopt: Hardware-accelerated, batchable, differentiable optimizers – great for bi-level and implicit-differentiation problems.
- 🟢 Brax: Massively parallel differentiable rigid-body physics – training humanoid policies on a single GPU in minutes.
- 🟢 JAX-MD: Differentiable molecular dynamics at accelerator speed – end-to-end backprop through an MD trajectory.
- 🟢 dynamiqs: High-performance, differentiable simulation of open and closed quantum systems in JAX.
- 🟢 XLB: Autodesk's differentiable, massively parallel Lattice-Boltzmann fluid solver for ML-in-the-loop CFD.
- 🟢 FDTDX: Finite-Difference Time-Domain electromagnetic simulation in JAX – design photonic devices with autograd.
- 🟢 JaxDF: Write differentiable PDE simulators with arbitrary discretizations – the building block for inverse-problem science.
- 🟢 JAX-in-Cell: Self-consistent particle-in-cell plasma simulations – classical HPC physics meets JAX autodiff.
- 🟢 foragax: Agent-based modelling framework in JAX – fast, vectorized, auto-differentiable social and ecological sims.
- 🟢 jax-cosmo: Differentiable cosmology – end-to-end gradients through large-scale-structure likelihoods.
- 🟢 astronomix: Differentiable (magneto)hydrodynamics for astrophysics – simulate galaxy-scale flows with autograd.
- 🟢 exojax: Automatically differentiable spectrum modelling of exoplanets and brown dwarfs.
- 🟢 jwave: Differentiable acoustic wave simulation – for medical-imaging and photoacoustic inverse problems.
- 🟢 SCICO: Los Alamos' scientific computational imaging – plug-and-play priors and inverse problems in JAX.
- 🟢 MBIRJAX: High-performance tomographic reconstruction – CT and 3D imaging with modern regularizers.
- 🟢 DiffeRT: Differentiable ray tracing for radio propagation – wireless-channel modelling with gradients.
- 🟢 tmmax: Vectorized transfer-matrix method for thin-film optics – the Swiss Army knife for photonic stacks.
- 🟢 vivsim: Fluid-structure interaction via the Immersed-Boundary Lattice-Boltzmann method – engineering-grade FSI with autograd.
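The trick shared by most libraries in this section is backpropagating through a simulator. A toy plain-JAX sketch makes the pattern concrete (explicit Euler on dy/dt = -ky; for real problems, use Diffrax or a domain library above):

```python
import jax
import jax.numpy as jnp

def simulate(k, y0, dt=0.01, steps=100):
    # Explicit-Euler integration of dy/dt = -k * y, written as a scan
    # so the whole trajectory stays inside one differentiable program.
    def step(y, _):
        y = y + dt * (-k * y)
        return y, y
    y_final, trajectory = jax.lax.scan(step, y0, None, length=steps)
    return y_final, trajectory

# Gradient of the final state with respect to the physical parameter k:
# "backprop through the simulator" is the core move behind
# differentiable-physics and inverse-problem libraries.
grad_k = jax.grad(lambda k: simulate(k, 1.0)[0])(2.0)
```

Swap the toy `step` for rigid-body dynamics, an MD force field, or a PDE stencil and you have the shape of Brax, JAX-MD, and JaxDF respectively.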
## 🎲 Probabilistic Programming

Bayesian inference, sampling, and uncertainty – JAX's vectorized scans make MCMC fly.
- 🟢 NumPyro: The mainstream full-DSL probabilistic programming language – Pyro semantics on a JAX engine.
- 🟢 BlackJAX: Composable samplers – NUTS, HMC, SMC, VI – with no DSL lock-in. Bring your own log-density.
- 🟢 Distrax: DeepMind's lightweight distributions and bijectors library – a pragmatic alternative to TFP when you want minimal dependencies.
- 🟢 Dynamax: Probabilistic state-space models – HMMs, LGSSMs, nonlinear filters – with Kevin Murphy's seal of approval.
- 🟢 GPJax: Gaussian Processes in JAX – a didactic, extensible framework for kernel machines.
- 🟢 tinygp: The tiniest GP library – fast, elegant, and built for astronomers by Dan Foreman-Mackey.
- 🟢 flowjax: Normalizing flows built as Equinox modules – density estimation with a clean PyTree interface.
- 🟢 bayex: Bayesian optimization powered by JAX – hyperparameter tuning that runs inside your training job.
- 🟡 Oryx: Probabilistic programming via program transformations – lives inside TensorFlow Probability; niche but powerful for researchers.
🟡 Also worth knowing: `tfp.substrates.jax` – TensorFlow Probability's distributions, MCMC, and VI running on a pure JAX substrate.
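BlackJAX's "bring your own log-density" philosophy can be sketched in plain JAX with a hand-rolled random-walk Metropolis kernel (toy code for intuition only; use BlackJAX or NumPyro for real inference):

```python
import jax
import jax.numpy as jnp

def log_density(x):
    # Standard normal target; swap in any log-density you like.
    return -0.5 * jnp.sum(x ** 2)

def mh_step(carry, key):
    x, lp = carry
    k1, k2 = jax.random.split(key)
    proposal = x + 0.5 * jax.random.normal(k1, x.shape)
    lp_new = log_density(proposal)
    # Metropolis accept/reject in log space.
    accept = jnp.log(jax.random.uniform(k2)) < lp_new - lp
    x = jnp.where(accept, proposal, x)
    lp = jnp.where(accept, lp_new, lp)
    return (x, lp), x

keys = jax.random.split(jax.random.PRNGKey(0), 2000)
x0 = jnp.zeros(2)
(_, _), samples = jax.lax.scan(mh_step, (x0, log_density(x0)), keys)
```

The whole chain is one `lax.scan`, which is exactly why JAX MCMC is fast: the sampler compiles to a single fused loop, and `vmap` gives you parallel chains for free.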
## 🤖 Reinforcement Learning & Evolution

End-to-end JIT'd training loops and hardware-accelerated environments – the JAX RL stack trains in minutes what TensorFlow took hours to simulate.
- 🟡 PureJaxRL: Fully vectorized, end-to-end JIT'd RL pipelines – PPO on 2048 envs without leaving JAX. Low recent commit activity, but remains the canonical reference for the JAX-native RL-loop pattern.
- 🟢 Jumanji: InstaDeep's suite of industry-driven, hardware-accelerated RL environments – from bin-packing to routing.
- 🟢 gymnax: Classic Gym environments re-implemented in JAX – CartPole, Atari-lite, bsuite, and more, all JIT-compatible.
- 🟢 Pgx: Vectorized board-game environments with an AlphaZero reference – Chess, Shogi, Go at scale.
- 🟢 NAVIX: MiniGrid reimplemented in pure JAX – RL gridworlds that train in seconds, not hours.
- 🟢 QDax: Quality-Diversity optimization – MAP-Elites and neuro-evolution at accelerator speed.
- 🟢 evosax: JAX-based evolutionary strategies – CMA-ES, OpenAI-ES, NSLC, ready to vectorize.
- 🟢 RLax: DeepMind's RL building blocks – value functions, distributional losses, exploration – the LEGO set, not the agent.
- 🟢 Mctx: DeepMind's Monte-Carlo Tree Search primitives in native JAX – MuZero-style planning, vectorized.
🟡 For continuous-control physics environments, see Brax in the Scientific Computing section – it doubles as a world-class RL env suite.

🟡 EvoJAX – the original "put ES on TPU" toolkit – has been archived; see the Legacy Radar. Use evosax or QDax above.
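The end-to-end JIT'd pattern PureJaxRL popularized: environment, policy, and rollout are all pure functions, so the entire loop compiles and vectorizes. A toy sketch with a hypothetical 1-D environment (invented here for illustration):

```python
import jax
import jax.numpy as jnp

# Toy 1-D env: the state drifts by the action; reward for staying
# near the origin. A pure function, so it can live inside jit/vmap.
def env_step(state, action):
    next_state = state + action
    reward = -jnp.abs(next_state)
    return next_state, reward

def policy(state):
    return -0.5 * state  # hand-coded proportional controller

def rollout(state0, steps=64):
    def step(state, _):
        action = policy(state)
        state, reward = env_step(state, action)
        return state, reward
    _, rewards = jax.lax.scan(step, state0, None, length=steps)
    return rewards.sum()

# 1024 parallel environments, all inside one compiled call.
returns = jax.jit(jax.vmap(rollout))(jnp.linspace(-1.0, 1.0, 1024))
```

Replace the toy `env_step` with a Jumanji or gymnax environment and the same structure scales to thousands of simultaneous rollouts on one accelerator.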
## 🌐 Domain Libraries – Graphs, Vision, Brain Dynamics

Specialized stellar systems: graphs, vision, neuroscience, and privacy-preserving compute.
- 🟡 PGMax: Discrete probabilistic graphical models with loopy-BP and smooth-minimum-sum inference in JAX.

⚠️ Jraph – the de facto GNN library in JAX – was archived by DeepMind. See the Legacy Radar; no drop-in JAX-native successor exists yet.

- 🟢 Scenic: Google Research's JAX/Flax library for vision transformers, video, and multi-modal research – the living vision codebase in JAX.
- 🟢 dm_pix: DeepMind's image-processing primitives for JAX – JIT-friendly augmentations and color ops.

⚠️ Note on vision model zoos: most Flax/Equinox pre-trained-weight repos (FlaxVision, jax-models, Eqxvision) have gone dormant. See the Legacy Radar and prefer Scenic or HuggingFace Transformers' Flax models for new work.

- 🟢 BrainPy: Computational neuroscience and brain-inspired computing – differentiable spiking networks and neural dynamics.
- 🟢 brainunit: Physical units and unit-aware arithmetic inside JAX – make your neuroscience code dimensionally safe.
- 🟢 brainstate: State-based program compilation for brain-dynamics models – augmenting JAX's functional core with stateful ergonomics.
- 🟢 dendritex: Compartmental dendritic neuron modelling in JAX – cable-equation dynamics at GPU speed.
- 🟢 Spyx: Spiking Neural Networks in JAX – neuromorphic-style learning with modern accelerators.
- 🟢 OTT-JAX: Optimal transport – Sinkhorn, low-rank Gromov-Wasserstein, and neural OT – the reference toolkit in JAX.
- 🟢 Coreax: GCHQ's coreset algorithms for compressing large datasets while preserving statistical structure.
- 🟢 SPU: A compiler + runtime for running JAX programs under Secure Multi-Party Computation – privacy-preserving ML, the compiler way.
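Since Jraph is archived (see the Legacy Radar), one round of graph message passing can be composed directly from JAX primitives. A toy sketch on a hypothetical 4-node ring graph (the graph and residual update are illustrative only):

```python
import jax
import jax.numpy as jnp

# Hypothetical toy graph: 4 nodes in a directed ring.
senders = jnp.array([0, 1, 2, 3])
receivers = jnp.array([1, 2, 3, 0])
node_feats = jnp.arange(4.0).reshape(4, 1)

def message_pass(feats):
    # Gather each edge's sender features, then sum incoming
    # messages per receiver node with a segment reduction.
    messages = feats[senders]
    aggregated = jax.ops.segment_sum(messages, receivers, num_segments=4)
    return feats + aggregated  # residual node update

out = message_pass(node_feats)
```

Gather plus `segment_sum` is the whole message-passing core; learned edge/node functions and `vmap` over batches of graphs layer on top of exactly this skeleton.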
## 🧭 The Pathfinder – What should you choose?

Real 2026 user journeys – recommended stars. Pick the row that matches your mission.
| Your Mission | Recommended | Status | Why |
|---|---|---|---|
| Large-scale LLM training on TPU | MaxText | 🟢 | Pure-JAX, scales multi-pod, battle-tested on Gemini-class workloads |
| Post-training LLMs (SFT + RLHF) | Tunix | 🟢 | PPO/GRPO/DAPO with tool-using agents, on Flax NNX |
| Foundation-model research (custom arch) | Flax NNX | 🟢 | DeepMind's new OO-ergonomic API – Haiku's successor |
| Scientific ML / PyTorch-style transparency | Equinox | 🟢 | Callable PyTrees, minimal magic, strong sci-ML adoption |
| Neural ODEs / continuous-time models | Diffrax | 🟢 | The canonical differential-equation solver in JAX |
| Differentiable physics simulation | Brax or JAX-MD | 🟢 | Brax = rigid body, JAX-MD = molecular dynamics |
| Probabilistic modelling (full DSL) | NumPyro | 🟢 | Pyro-lineage, fast, mainstream |
| Sampling only (MCMC / SMC / VI) | BlackJAX | 🟢 | Composable samplers, no DSL lock-in |
| RL research | PureJaxRL + Jumanji / gymnax | 🟡 | End-to-end JIT'd loops (PureJaxRL is in maintenance but still canonical) + 🟢 accelerator-native envs |
| Protein / biomolecular structure | AlphaFold 3 | 🟢 | DeepMind canonical (research license) |
| Interpretability / model surgery | Penzai | 🟢 | DeepMind's introspective modelling library |
| TPU/GPU kernel authoring | Pallas | 🟢 | JAX-native – write kernels without leaving Python |
Whatever your mission, the shared infrastructure layer is the same:

| Concern | Use | Why |
|---|---|---|
| Optimizers | Optax | Composable gradient transforms β universal adoption |
| Checkpointing | Orbax | Multi-host, resumable, the 2026 standard |
| Data loading | Grain | Deterministic, JAX-native, replaces `tf.data` |
| Testing & invariants | Chex | DeepMind's assertions library for JAX code |
| Shape-safe types | jaxtyping | Catches shape bugs before JIT |
## 🛸 Onramps – Reference Implementations

Canonical codebases to read when you're learning how idiomatic JAX is written. Pick one near your mission and study it.
- 🚀 MaxText Examples – see how a production-grade, multi-pod LLM training loop is structured in pure JAX.
- 🧠 PureJaxRL Tutorials – the clearest demonstration of end-to-end JIT'd training-loop design in the JAX world.
- 🧪 JAX-MD Notebooks – differentiable molecular dynamics from first principles, with narrative tutorials.
- 🔬 Penzai Tutorials – model introspection and editing – a great example of composable PyTree APIs.
- 📊 NumPyro Examples – Bayesian inference recipes, from linear regression to deep GPs.
- 🛸 Flax NNX Guides – the official "how to think in NNX" walkthrough; the best starting point if you're coming from Haiku or PyTorch.
## ⚠️ Legacy Radar – The warning zone

These pioneers lit the way, but the galaxy has moved on. Each entry explains what to use instead – if you're reading old tutorials, read this first.
- 🔴 Haiku (dm-haiku) – In maintenance mode. DeepMind's new research has shifted to Flax NNX. → Use Flax NNX.
- 🔴 Trax – Effectively abandoned (Google). → Use MaxText for scale, Flax NNX for research.
- 🔴 Objax – No longer actively developed. → Use Equinox (similar OO feel) or Flax NNX.
- 🔴 Elegy – Unmaintained. → Use Flax NNX or Equinox directly.
- 🔴 SymJAX – Superseded by native JAX tracing. → Use JAX itself.
- 🔴 Parallax – Archived experimental project. → Use Equinox for "immutable modules" ergonomics.
- 🔴 mcx – Sampling DSL superseded by the community. → Use BlackJAX.
- 🔴 Coax – Development has slowed to a crawl. → Use PureJaxRL.
- 🔴 EvoJAX – Archived by Google. → Use evosax (general ES) or QDax (Quality-Diversity).
- 🔴 Jraph – Archived by DeepMind in 2024. → No drop-in JAX-native successor. For new GNN work, consider PyTorch Geometric or compose message-passing primitives with Equinox.
- 🔴 FlaxVision / jax-models / Eqxvision – Dormant vision model zoos. → Use Scenic or HuggingFace Transformers' Flax models.
- 🟡 Pax / Praxis – Still ships, but Google's external narrative has moved to MaxText + Flax NNX. Google-internal lineage; new users should not start here.
- 🟡 Neural Tangents – Low activity. Stays here because it's the canonical library for its (niche) infinite-width-network research area.
- 🟡 FedJAX – Federated-learning niche with minimal recent activity. No clear successor in JAX; still worth knowing if you work on FL.
The JAX multiverse is expanding. If you see a new star β or see a dying one β open an Issue or a PR. See CONTRIBUTING.md for cosmic guidelines.
To audit the health of the stars yourself, run `python scripts/health_check.py`. This fetches last-commit dates and star counts for every linked repo and flags classifications that look stale.
The galaxy is also audited automatically every Monday via GitHub Actions β dying stars get flagged in an Issue before you ever need to look. See CONTRIBUTING.md for details.
Prefer to see the galaxy? The Observatory is a cinematic 3D map of every entry – size-by-stars, color-by-health, click-to-explore. Run `cd observatory && npm run dev` from a fresh clone.
Maintained with ❤️ by the JAX community. 2026 Edition. Let's map the stars. 🌌