SAJAX is a JAX-accelerated reimplementation of SAGE (Chakraborty et al. 2024), a code that models the stellar contamination of exoplanet transmission spectra caused by active regions (spots and faculae) on the stellar surface.
The key innovation over the original SAGE is that SAJAX vectorises the
spectral loop with `jax.vmap`, so the same code runs efficiently on both
CPU and GPU without any change to the calling code. SAJAX is also fully
differentiable, which enables gradient-based inference with tools such as
NumPyro or Optax.
Documentation is available at sajax.readthedocs.io.
```bash
pip install sajax
```

Or, in development mode, from a local clone:

```bash
git clone https://github.com/SamMerc/sajax.git
cd sajax
pip install -e ".[dev]"
```

Repository layout:

```
sajax/
├── sajax/
│   ├── __init__.py      # public API
│   ├── core.py          # JAX light-curve engine
│   ├── planet.py        # planet orbital dynamics
│   └── geometry.py      # rotation matrices, coordinate transforms
├── docs/
│   ├── quickstart.ipynb
│   ├── comparison.ipynb
│   └── inference.ipynb
├── tests/
│   ├── test_core.py
│   └── test_planet.py
├── pyproject.toml
├── .gitignore
└── README.md
```
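The layout lists `geometry.py` as holding rotation matrices and coordinate transforms. The sketch below illustrates the kind of transform such a module typically needs. It is hypothetical: the function names and the conventions (observer along +z, stellar spin axis along +y when the inclination is 90°, longitude 0 facing the observer at phase 0) are assumptions, not SAJAX's actual API. It spins a surface point about the stellar axis by the rotation phase, tilts the star by its inclination, and reads off the sky-plane position `(x, y)` and `mu`, the cosine of the angle between the surface normal and the line of sight. `jax.vmap` then tracks one point over a full rotation.

```python
import jax
import jax.numpy as jnp

def rot_x(angle):
    """Right-handed rotation matrix about the x axis (angle in radians)."""
    c, s = jnp.cos(angle), jnp.sin(angle)
    return jnp.array([[1.0, 0.0, 0.0],
                      [0.0, c, -s],
                      [0.0, s, c]])

def rot_y(angle):
    """Right-handed rotation matrix about the y axis (angle in radians)."""
    c, s = jnp.cos(angle), jnp.sin(angle)
    return jnp.array([[c, 0.0, s],
                      [0.0, 1.0, 0.0],
                      [-s, 0.0, c]])

def surface_point(lat, lon):
    """Unit vector to (lat, lon) on the star; spin axis along +y, lon = 0 faces +z."""
    return jnp.array([jnp.cos(lat) * jnp.sin(lon),
                      jnp.sin(lat),
                      jnp.cos(lat) * jnp.cos(lon)])

def sky_position(lat, lon, phase, inc):
    """Hypothetical helper returning (x, y, mu) for an observer along +z.

    phase: stellar rotation phase [rad]; inc: stellar inclination [rad]
    (inc = pi/2 puts the spin axis in the sky plane, inc = 0 is pole-on).
    """
    p = rot_y(phase) @ surface_point(lat, lon)  # rotate the star about its spin axis
    p = rot_x(jnp.pi / 2 - inc) @ p             # tilt the pole towards the observer
    return p                                    # p[2] is mu; the point is visible if mu > 0

# Track a point at 30 deg latitude over one rotation of a star inclined at 60 deg.
phases = jnp.linspace(0.0, 2.0 * jnp.pi, 100)
track = jax.vmap(sky_position, in_axes=(None, None, 0, None))(
    jnp.deg2rad(30.0), 0.0, phases, jnp.deg2rad(60.0))
mu = track[:, 2]
visible = mu > 0.0
```

Keeping each transform as a small pure function of scalars means it can be vectorised over phases (or over many active regions) with `jax.vmap` and differentiated with `jax.grad`, in the same way as the spectral calculation.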
