🌌 JAXlaxy

Your compass for the JAX multiverse


JAXlaxy is a curated, opinionated, and constantly updated map of the JAX ecosystem.

In 2026, JAX has evolved from a research curiosity into the backbone of global AI. It powers the world's largest Foundation Models, real-time physics simulators, and structural biology breakthroughs. While the original DeepMind "JAX Ecosystem" post (2020) is a classic archive, the universe has expanded. JAXlaxy filters the noise to show you the brightest stars.

Every entry ships with a health indicator — 🟢 Active · 🟡 Stable · 🔴 Legacy — so you can tell at a glance whether a library is the right launchpad in 2026 or a relic better admired than flown.


πŸ›°οΈ Navigation


The foundational technologies that make high-performance JAX computing possible. If your stack has JAX, it has these.

  • 🟢 JAX: Autograd and XLA-powered numerical computing — the sun of our universe.
  • 🟢 Pallas: JAX-native kernel authoring language for writing custom TPU/GPU kernels without leaving Python.
  • 🟢 OpenXLA: The open-source compiler that turns your JAX program into blazing accelerator code.
  • 🟢 torchax: Run PyTorch model code directly on JAX — the bridge for teams migrating in from the PyTorch universe.
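What this core buys you fits in a few lines: write a NumPy-style function, then compose jax.grad and jax.jit to differentiate and compile it. A minimal sketch with a toy linear least-squares loss (nothing beyond JAX itself):

```python
import jax
import jax.numpy as jnp

# A plain NumPy-style loss: mean squared error of a linear model.
def loss(w, x, y):
    return jnp.mean((x @ w - y) ** 2)

# Differentiate with respect to w, then JIT-compile for XLA.
grad_loss = jax.jit(jax.grad(loss))

x = jnp.ones((4, 3))
y = jnp.zeros(4)
w = jnp.ones(3)
g = grad_loss(w, x, y)   # gradient w.r.t. w, shape (3,)
```

The same composition works with jax.vmap for batching and jax.lax.scan for loops, which is the pattern every library below builds on.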

🌞 Neural Network Frameworks

The primary solar systems for building and training neural networks in JAX.

  • 🟢 Flax (NNX): The 2026 standard for DeepMind and Google-scale research. NNX brings ergonomic object-oriented state on top of JAX's functional core.
  • 🟢 Equinox: Everything is a PyTree. Minimal magic, maximum transparency — the default pick for scientific ML and PyTorch converts.
  • 🟢 Penzai: DeepMind's toolkit for legible, introspectable, surgically-editable neural networks — the interpretability researcher's first choice.
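The "everything is a PyTree" philosophy these frameworks share is easiest to see with no framework at all: a model is just a nested container of arrays that jax.grad and jax.tree_util traverse uniformly. A hand-rolled sketch (plain dicts standing in for real modules; not any library's actual API):

```python
import jax
import jax.numpy as jnp

# A "model" as a plain PyTree: a nested dict of arrays.
params = {
    "linear": {"w": jnp.ones((3, 1)), "b": jnp.zeros(1)},
}

def forward(params, x):
    p = params["linear"]
    return x @ p["w"] + p["b"]

def loss(params, x, y):
    return jnp.mean((forward(params, x) - y) ** 2)

# jax.grad returns gradients with the SAME tree structure as params.
grads = jax.grad(loss)(params, jnp.ones((4, 3)), jnp.zeros((4, 1)))

# One SGD step as a tree_map over matching leaves.
params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
```

Flax NNX wraps this in stateful modules; Equinox keeps modules literally as PyTrees like the dict above, which is why its transparency appeals to scientific ML.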

💡 Pragmatic multi-framework options. Teams already invested in the Keras or HuggingFace ecosystem can use Keras 3 with its JAX backend or HuggingFace Transformers Flax models — not JAX-native in design, but battle-tested bridges for real production stacks.


🧰 Training Utilities

The utilities that keep your training loop stable, your weights safe, and your data moving at TPU speed.

  • 🟢 Optax: The undisputed king of gradient processing — composable optimizer transforms used in every serious JAX training loop.
  • 🟢 Orbax: The 2026 standard for multi-host checkpointing, model exporting, and resumable training at scale.
  • 🟢 Grain: Deterministic, JAX-native high-throughput data loading — the TPU-era replacement for tf.data.
  • 🟢 Chex: DeepMind's assertions and testing toolkit for writing JAX code you can actually trust.
  • 🟢 jaxtyping: Shape-and-dtype-aware type hints that catch bugs before the first JIT compile.
  • 🟡 safejax: Serialize Flax / Haiku / Equinox parameters via safetensors — safer than pickling, portable across frameworks.
  • 🟢 jax-tqdm: Add a real progress bar to JIT-compiled jax.lax.scan and training loops — one decorator, zero friction.
  • 🟢 JAX Toolbox: NVIDIA's nightly CI and optimized container images for running JAX workloads on H100/B200 GPUs.
  • 🟢 mpi4jax: Zero-copy MPI collectives inside JIT-compiled JAX code — the bridge for classical HPC clusters.
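Optax's composability comes from one small contract: an optimizer is a pair of pure functions, init and update, that thread their state explicitly. A framework-free sketch of that pattern (hand-written SGD-with-momentum illustrating the contract, not Optax's actual implementation):

```python
import jax
import jax.numpy as jnp

def sgd_with_momentum(lr=0.1, beta=0.9):
    def init(params):
        # Optimizer state: one velocity buffer per parameter leaf.
        return jax.tree_util.tree_map(jnp.zeros_like, params)

    def update(grads, state, params=None):
        # Accumulate momentum, then emit parameter updates.
        state = jax.tree_util.tree_map(lambda v, g: beta * v + g, state, grads)
        updates = jax.tree_util.tree_map(lambda v: -lr * v, state)
        return updates, state

    return init, update

params = {"w": jnp.ones(3)}
grads = {"w": jnp.full(3, 2.0)}

init, update = sgd_with_momentum()
state = init(params)
updates, state = update(grads, state, params)
params = jax.tree_util.tree_map(lambda p, u: p + u, params, updates)
```

Because every transform follows this shape, Optax can chain clipping, weight decay, and schedules into a single optimizer with optax.chain.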

🏗️ Large-Scale Training

Large-scale training stacks. Pick by workload scale and how much of the stack you want to own.

  • 🟢 MaxText: Google's flagship pure-JAX LLM reference — scales from single-TPU to multi-pod without leaving Python.
  • 🟢 Tunix: Google's post-training toolkit on JAX — SFT, RLHF (PPO/GRPO/DAPO), and agentic RL with tool use, built on Flax NNX.
  • 🟢 Levanter: Stanford CRFM's scalable, reproducible foundation-model trainer with named tensors and bit-level determinism.
  • 🟢 EasyDeL: Opinionated training and serving for Llama/Mixtral/Falcon/Qwen families in JAX — ergonomics-first.
  • 🟢 kvax: A production-grade FlashAttention implementation for JAX with document-mask and context-parallel support.
  • 🟡 Lorax: Automatic LoRA injection for any JAX model (Flax, Haiku, Equinox) — one line to fine-tune at a fraction of the memory.
  • 🟢 FlaxDiff: Multi-node, multi-device diffusion model training on TPUs — the LLM world has a sibling.
  • 🟡 EasyLM: Pre-train, fine-tune, evaluate and serve LLMs in JAX/Flax — the original reference used in several early open models; stable but slower-moving.

🔬 Scientific Computing

Where JAX truly outshines PyTorch: differentiable physics, biology, cosmology, and inverse problems.

🧬 Life Sciences

  • 🟢 AlphaFold 3: DeepMind's state-of-the-art predictor for protein, nucleic-acid and ligand structure. Weights gated; non-commercial research license.
  • 🟡 jax-unirep: Fast UniRep protein-embedding models — the pragmatic starting point for sequence-level protein ML.

➗ Differentiable Solvers & Optimization

  • 🟢 Diffrax: Numerical ODE / SDE / CDE solvers in JAX — the canonical answer for Neural ODEs and continuous-time models.
  • 🟢 Optimistix: Root-finding, minimization, fixed-points and nonlinear least-squares — the "SciPy optimize" of JAX.
  • 🟢 JAXopt: Hardware-accelerated, batchable, differentiable optimizers — great for bi-level and implicit-differentiation problems.
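The idea behind differentiable solvers in miniature: a fixed-step integrator written with jax.lax.scan is itself differentiable, so gradients flow through the entire trajectory. A toy sketch (hand-rolled Euler on dy/dt = -a·y; Diffrax layers adaptive, production-grade solvers and adjoint methods on top of this idea):

```python
import jax
import jax.numpy as jnp

def solve(a, y0):
    """Terminal state of dy/dt = -a*y after 100 fixed Euler steps."""
    dt, steps = 0.01, 100
    def euler_step(y, _):
        y = y + dt * (-a * y)
        return y, None
    y_final, _ = jax.lax.scan(euler_step, y0, None, length=steps)
    return y_final

y1 = solve(1.0, 1.0)               # approximately exp(-1)
dy_da = jax.grad(solve)(1.0, 1.0)  # sensitivity of y(1) to the decay rate a
```

That last line is the whole point: gradients of a solution with respect to model parameters, which is what makes Neural ODEs and inverse problems trainable.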

βš›οΈ Physics & Molecular Dynamics

  • 🟢 Brax: Massively parallel differentiable rigid-body physics — training humanoid policies on a single GPU in minutes.
  • 🟢 JAX-MD: Differentiable molecular dynamics at accelerator speed — end-to-end backprop through an MD trajectory.
  • 🟢 dynamiqs: High-performance, differentiable simulation of open and closed quantum systems in JAX.
  • 🟢 XLB: Autodesk's differentiable, massively parallel Lattice-Boltzmann fluid solver for ML-in-the-loop CFD.
  • 🟢 FDTDX: Finite-Difference Time-Domain electromagnetic simulation in JAX — design photonic devices with autograd.
  • 🟢 JaxDF: Write differentiable PDE simulators with arbitrary discretizations — the building block for inverse-problem science.
  • 🟢 JAX-in-Cell: Self-consistent particle-in-cell plasma simulations — classical HPC physics meets JAX autodiff.
  • 🟢 foragax: Agent-based modelling framework in JAX — fast, vectorized, auto-differentiable social and ecological sims.

🔭 Cosmology & Astrophysics

  • 🟢 jax-cosmo: Differentiable cosmology — end-to-end gradients through large-scale-structure likelihoods.
  • 🟢 astronomix: Differentiable (magneto)hydrodynamics for astrophysics — simulate galaxy-scale flows with autograd.
  • 🟢 exojax: Automatically differentiable spectrum modelling of exoplanets and brown dwarfs.

📡 Imaging, Signals & Tomography

  • 🟢 jwave: Differentiable acoustic wave simulation — for medical-imaging and photoacoustic inverse problems.
  • 🟢 SCICO: Los Alamos' scientific computational imaging — plug-and-play priors and inverse problems in JAX.
  • 🟢 MBIRJAX: High-performance tomographic reconstruction — CT and 3D imaging with modern regularizers.
  • 🟢 DiffeRT: Differentiable ray tracing for radio propagation — wireless-channel modelling with gradients.
  • 🟢 tmmax: Vectorized transfer-matrix method for thin-film optics — the Swiss Army knife for photonic stacks.
  • 🟢 vivsim: Fluid-structure interaction via the Immersed-Boundary Lattice-Boltzmann method — engineering-grade FSI with autograd.

🎲 Probabilistic Programming

Bayesian inference, sampling, and uncertainty — JAX's vectorized scans make MCMC fly.

  • 🟢 NumPyro: The mainstream full-DSL probabilistic programming language — Pyro semantics on a JAX engine.
  • 🟢 BlackJAX: Composable samplers — NUTS, HMC, SMC, VI — with no DSL lock-in. Bring your own log-density.
  • 🟢 Distrax: DeepMind's lightweight distributions and bijectors library — a pragmatic alternative to TFP when you want minimum dependencies.
  • 🟢 Dynamax: Probabilistic state-space models — HMMs, LGSSMs, nonlinear filters — with Kevin Murphy's seal of approval.
  • 🟢 GPJax: Gaussian Processes in JAX — a didactic, extensible framework for kernel machines.
  • 🟢 tinygp: The tiniest GP library — fast, elegant, and built for astronomers by Dan Foreman-Mackey.
  • 🟢 flowjax: Normalizing flows built as Equinox modules — density estimation with a clean PyTree interface.
  • 🟢 bayex: Bayesian optimization powered by JAX — hyperparameter tuning that runs inside your training job.
  • 🟡 Oryx: Probabilistic programming via program transformations — inside TensorFlow Probability, niche but powerful for researchers.

💡 Also worth knowing: tfp.substrates.jax — TensorFlow Probability's distributions, MCMC, and VI running on a pure JAX substrate.
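The "bring your own log-density" style is easy to sketch: a random-walk Metropolis chain written as a jax.lax.scan, so the whole loop JIT-compiles. A deliberately minimal example of the pattern (not BlackJAX's actual API):

```python
import jax
import jax.numpy as jnp

N_STEPS = 2000

def logdensity(x):
    return -0.5 * x ** 2   # unnormalized standard normal

@jax.jit
def run_chain(key, x0, step_size=0.5):
    def mh_step(x, key):
        k1, k2 = jax.random.split(key)
        proposal = x + step_size * jax.random.normal(k1)
        log_ratio = logdensity(proposal) - logdensity(x)
        accept = jnp.log(jax.random.uniform(k2)) < log_ratio
        x = jnp.where(accept, proposal, x)   # accept or keep current state
        return x, x
    keys = jax.random.split(key, N_STEPS)
    _, samples = jax.lax.scan(mh_step, x0, keys)
    return samples

samples = run_chain(jax.random.PRNGKey(0), jnp.array(0.0))
```

Swapping logdensity for a real posterior is all it takes; vmap over keys then gives you many chains in parallel, which is exactly the trick that makes JAX MCMC fast.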


🕹️ Reinforcement Learning

End-to-end JIT'd training loops and hardware-accelerated environments — the JAX RL stack trains in minutes what TensorFlow took hours to simulate.

  • 🟡 PureJaxRL: Fully vectorized, end-to-end JIT'd RL pipelines — PPO on 2048 envs without leaving JAX. Low recent commit activity, but remains the canonical reference for the JAX-native RL-loop pattern.
  • 🟢 Jumanji: InstaDeep's suite of industry-driven, hardware-accelerated RL environments — from bin-packing to routing.
  • 🟢 gymnax: Classic Gym environments re-implemented in JAX — CartPole, Atari-lite, bsuite, and more, all JIT-compatible.
  • 🟢 Pgx: Vectorized board-game environments with an AlphaZero reference — Chess, Shogi, Go at scale.
  • 🟢 NAVIX: MiniGrid reimplemented in pure JAX — RL gridworlds that train in seconds, not hours.
  • 🟢 QDax: Quality-Diversity optimization — MAP-Elites and neuro-evolution at accelerator speed.
  • 🟢 evosax: JAX-based evolutionary strategies — CMA-ES, OpenAI-ES, NSLC, ready to vectorize.
  • 🟢 RLax: DeepMind's RL building blocks — value functions, distributional losses, exploration — the LEGO set, not the agent.
  • 🟢 Mctx: DeepMind's Monte-Carlo Tree Search primitives in native JAX — MuZero-style planning, vectorized.

💡 For continuous-control physics environments, see Brax in the Scientific Computing section — it doubles as a world-class RL env suite.

💡 EvoJAX — the original "put ES on TPU" toolkit — has been archived; see the Legacy Radar. Use evosax or QDax above.
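The PureJaxRL pattern in miniature: when the environment step and the policy are both pure functions, an entire batched rollout becomes one jax.lax.scan inside jit. A toy sketch (the 1-D environment and linear policy are invented here purely for illustration):

```python
import jax
import jax.numpy as jnp

def env_step(state, action):
    # Hypothetical toy dynamics: the state drifts toward the action.
    new_state = state + 0.1 * (action - state)
    reward = -jnp.abs(new_state)      # reward for staying near zero
    return new_state, reward

def policy(params, state):
    return params * state             # a one-parameter linear "policy"

@jax.jit
def rollout(params, init_states):
    def step(states, _):
        actions = policy(params, states)          # vectorized across envs
        states, rewards = env_step(states, actions)
        return states, rewards
    _, rewards = jax.lax.scan(step, init_states, None, length=50)
    return rewards.sum(axis=0)                    # per-env return

returns = rollout(-0.5, jnp.linspace(-1.0, 1.0, 8))   # 8 parallel envs
```

Because rollout is pure, you can jax.grad through it, vmap it over policy parameters, and run thousands of environments on one accelerator — the core reason JAX RL loops are so fast.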


🌟 Specialized Domains

Specialized stellar systems: graphs, vision, neuroscience, and privacy-preserving compute.

πŸ•ΈοΈ Graphs & Structured Models

  • 🟡 PGMax: Discrete probabilistic graphical models with loopy-BP and smooth-minimum-sum inference in JAX.

⚠️ Jraph — the de facto GNN library in JAX — was archived by DeepMind. See the Legacy Radar; no drop-in JAX-native successor exists yet.

πŸ–ΌοΈ Vision

  • 🟢 Scenic: Google Research's JAX/Flax library for vision transformers, video, and multi-modal research — the living vision codebase in JAX.
  • 🟢 dm_pix: DeepMind's image-processing primitives for JAX — JIT-friendly augmentations and color ops.

⚠️ Note on vision model zoos: most Flax/Equinox pre-trained-weight repos (FlaxVision, jax-models, Eqxvision) have gone dormant. See the Legacy Radar and prefer Scenic or HuggingFace Transformers' Flax models for new work.

🧠 Brain Dynamics Programming

  • 🟢 BrainPy: Computational neuroscience and brain-inspired computing — differentiable spiking networks and neural dynamics.
  • 🟢 brainunit: Physical units and unit-aware arithmetic inside JAX — make your neuroscience code dimensionally safe.
  • 🟢 brainstate: State-based program compilation for brain-dynamics models — augmenting JAX's functional core with stateful ergonomics.
  • 🟢 dendritex: Compartmental dendritic neuron modelling in JAX — cable-equation dynamics at GPU speed.
  • 🟢 Spyx: Spiking Neural Networks in JAX — neuromorphic-style learning with modern accelerators.

πŸ›‘οΈ Specialty

  • 🟢 OTT-JAX: Optimal transport — Sinkhorn, low-rank Gromov-Wasserstein, and neural OT — the reference toolkit in JAX.
  • 🟢 Coreax: GCHQ's coreset algorithms for compressing large datasets while preserving statistical structure.
  • 🟢 SPU: A compiler + runtime for running JAX programs under Secure Multi-Party Computation — privacy-preserving ML, the compiler way.

🧭 Mission Guide

Real 2026 user journeys → recommended stars. Pick the row that matches your mission.

| Your Mission | Recommended | Status | Why |
|---|---|---|---|
| Large-scale LLM training on TPU | MaxText | 🟢 | Pure-JAX, scales multi-pod, battle-tested on Gemini-class workloads |
| Post-training LLMs (SFT + RLHF) | Tunix | 🟢 | PPO/GRPO/DAPO with tool-using agents, on Flax NNX |
| Foundation-model research (custom arch) | Flax NNX | 🟢 | DeepMind's new OO-ergonomic API — Haiku's successor |
| Scientific ML / PyTorch-style transparency | Equinox | 🟢 | Callable PyTrees, minimal magic, strong sci-ML adoption |
| Neural ODEs / continuous-time models | Diffrax | 🟢 | The canonical differential-equation solver in JAX |
| Differentiable physics simulation | Brax or JAX-MD | 🟢 | Brax = rigid body, JAX-MD = molecular dynamics |
| Probabilistic modelling (full DSL) | NumPyro | 🟢 | Pyro-lineage, fast, mainstream |
| Sampling only (MCMC / SMC / VI) | BlackJAX | 🟢 | Composable samplers, no DSL lock-in |
| RL research | PureJaxRL + Jumanji / gymnax | 🟡 + 🟢 | End-to-end JIT'd loops (PureJaxRL is in maintenance but still canonical) plus accelerator-native envs |
| Protein / biomolecular structure | AlphaFold 3 | 🟢 | DeepMind canonical (research license) |
| Interpretability / model surgery | Penzai | 🟢 | DeepMind's introspective modelling library |
| TPU/GPU kernel authoring | Pallas | 🟢 | JAX-native — write kernels without leaving Python |

🔧 Plumbing (pick these regardless of mission)

| Concern | Use | Why |
|---|---|---|
| Optimizers | Optax | Composable gradient transforms — universal adoption |
| Checkpointing | Orbax | Multi-host, resumable, the 2026 standard |
| Data loading | Grain | Deterministic, JAX-native, replaces tf.data |
| Testing & invariants | Chex | DeepMind's assertions library for JAX code |
| Shape-safe types | jaxtyping | Catches shape bugs before JIT |

📚 Learning Resources

Canonical codebases to read when you're learning how idiomatic JAX is written. Pick one near your mission and study it.

  • 🚀 MaxText Examples — see how a production-grade, multi-pod LLM training loop is structured in pure JAX.
  • 🧭 PureJaxRL Tutorials — the clearest demonstration of end-to-end JIT'd training-loop design in the JAX world.
  • 🧪 JAX-MD Notebooks — differentiable molecular dynamics from first principles, with narrative tutorials.
  • 🔬 Penzai Tutorials — model introspection and editing — a great example of composable PyTree APIs.
  • 📊 NumPyro Examples — Bayesian inference recipes, from linear regression to deep GPs.
  • 🌸 Flax NNX Guides — the official "how to think in NNX" walkthrough; the best starting point if you're coming from Haiku or PyTorch.

🪐 Legacy Radar

These pioneers lit the way, but the galaxy has moved on. Each entry explains what to use instead — if you're reading old tutorials, read this first.

  • 🔴 Haiku (dm-haiku) — In maintenance mode. DeepMind's new research has shifted to Flax NNX. → Use Flax NNX.
  • 🔴 Trax — Effectively abandoned (Google). → Use MaxText for scale, Flax NNX for research.
  • 🔴 Objax — No longer actively developed. → Use Equinox (similar OO feel) or Flax NNX.
  • 🔴 Elegy — Unmaintained. → Use Flax NNX or Equinox directly.
  • 🔴 SymJAX — Superseded by native JAX tracing. → Use JAX itself.
  • 🔴 Parallax — Archived experimental project. → Use Equinox for "immutable modules" ergonomics.
  • 🔴 mcx — Sampling DSL superseded by the community. → Use BlackJAX.
  • 🔴 Coax — Slowed to a crawl. → Use PureJaxRL.
  • 🔴 EvoJAX — Archived by Google. → Use evosax (general ES) or QDax (Quality-Diversity).
  • 🔴 Jraph — Archived by DeepMind in 2024. → No drop-in JAX-native successor. For new GNN work, consider PyTorch Geometric or compose message-passing primitives with Equinox.
  • 🔴 FlaxVision / jax-models / Eqxvision — Dormant vision model zoos. → Use Scenic or HuggingFace Transformers' Flax models.
  • 🟡 Pax / Praxis — Still ships, but Google's external narrative has moved to MaxText + Flax NNX. Google-internal lineage; new users should not start here.
  • 🟡 Neural Tangents — Low activity. Stays here because it's the canonical library for its (niche) infinite-width-network research area.
  • 🟡 FedJAX — Federated-learning niche with minimal recent activity. No clear successor in JAX; still worth knowing if you work on FL.

πŸ› οΈ Contributing

The JAX multiverse is expanding. If you see a new star — or see a dying one — open an Issue or a PR. See CONTRIBUTING.md for cosmic guidelines.

To audit the health of the stars yourself, run:

python scripts/health_check.py

This fetches last-commit dates and star counts for every linked repo and flags classifications that look stale.

The galaxy is also audited automatically every Monday via GitHub Actions — dying stars get flagged in an Issue before you ever need to look. See CONTRIBUTING.md for details.

Prefer to see the galaxy? The Observatory is a cinematic 3D map of every entry — size-by-stars, color-by-health, click-to-explore. Run cd observatory && npm run dev from a fresh clone.


Maintained with ❤️ by the JAX community. 2026 Edition. Let's map the stars. 🌌
