Merged
6 changes: 3 additions & 3 deletions .github/CODEOWNERS
@@ -18,9 +18,9 @@ src/MaxText/elastic_train.py @lukebaumann @shauryagup @richjames0 @shralex
src/MaxText/layers/quantizations.py @khatwanimohit @jshin1394 @liudangyi @richjames0 @shralex

# Inference
src/MaxText/tests/inference @vipannalla @mitalisi @gpolovets1 @mailvijayasingh @jrplatin @patemotter @lumosis @richjames0
src/MaxText/inference @vipannalla @mitalisi @gpolovets1 @mailvijayasingh @jrplatin @patemotter @lumosis @richjames0
src/MaxText/inference_mlperf @vipannalla @mitalisi @gpolovets1 @mailvijayasingh @jrplatin @patemotter @lumosis @richjames0
src/maxtext/tests/inference @vipannalla @mitalisi @gpolovets1 @mailvijayasingh @jrplatin @patemotter @lumosis @richjames0
src/maxtext/inference @vipannalla @mitalisi @gpolovets1 @mailvijayasingh @jrplatin @patemotter @lumosis @richjames0
src/maxtext/inference_mlperf @vipannalla @mitalisi @gpolovets1 @mailvijayasingh @jrplatin @patemotter @lumosis @richjames0

# Dockerfiles and dependencies
*.Dockerfile @bvandermoon @parambole @richjames0 @shralex
12 changes: 6 additions & 6 deletions .vscode/launch.json
@@ -8,7 +8,7 @@
"console": "integratedTerminal",
"justMyCode": false,
"python": "python3",
"module": "MaxText.decode",
"module": "maxtext.decode",
"args": ["src/MaxText/configs/base.yml",
"run_name=runner_$(date +%Y-%m-%d-%H-%M)",
"base_output_directory=gs://test-maxtext-output",
@@ -35,9 +35,9 @@
"console": "integratedTerminal",
"justMyCode": false,
"python": "python3",
"module": "MaxText.decode",
"module": "maxtext.decode",
"args": ["src/MaxText/configs/base.yml",
"run_name=runner_$(date +%Y-%m-%d-%H-%M)",
"run_name=runner_$(date +%Y-%m-%d-%H-%M)",
"base_output_directory=gs://test-maxtext-output",
"dataset_path=gs://test-maxtext-dataset",
"steps=2",
@@ -53,7 +53,7 @@
"python": "python3",
"module": "MaxText.train",
"args": ["src/MaxText/configs/base.yml",
"run_name=runner_$(date +%Y-%m-%d-%H-%M)",
"run_name=runner_$(date +%Y-%m-%d-%H-%M)",
"base_output_directory=gs://test-maxtext-output",
"dataset_path=gs://test-maxtext-dataset",
"steps=2",
@@ -66,7 +66,7 @@
"console": "integratedTerminal",
"justMyCode": false,
"python": "python3",
"module": "MaxText.inference_microbenchmark",
"module": "maxtext.inference_microbenchmark",
"args": [
"src/MaxText/configs/base.yml",
"model_name=llama2-7b",
@@ -82,7 +82,7 @@
"inference_microbenchmark_prefill_lengths=32,64,128,256,512,1024",
"inference_microbenchmark_stages=generate",
"inference_microbenchmark_loop_iters=1",
"run_name=runner_$(date +%Y-%m-%d-%H-%M)",
"run_name=runner_$(date +%Y-%m-%d-%H-%M)",
"base_output_directory=gs://test-maxtext-output",
"prefill_cache_axis_order=0,2,1,3",
"ar_cache_axis_order=0,2,1,3",
6 changes: 3 additions & 3 deletions codecov.yml
@@ -24,7 +24,7 @@
# During scheduled runs, the 'regular' flag is carried forward from the last PR.

# Exclude non-source code, deprecated and experimental folders from coverage tracking
codecov:
codecov:
token: 35742a22-fb1f-4839-97ff-b54da5588689
# By default file names in the coverage report will have their path in the file system, which in our
# runners would be /__w/maxtext/maxtext/src/MaxText/* but Codecov expects src/MaxText/* so we need to fix the path
@@ -36,8 +36,8 @@ ignore:
- "src/MaxText/configs"
- "src/MaxText/examples"
- "src/MaxText/experimental"
- "src/MaxText/inference"
- "src/MaxText/inference_mlperf"
- "src/maxtext/inference"
- "src/maxtext/inference_mlperf"
- "src/MaxText/scratch_code"
- "src/MaxText/distillation" # code moved to src/maxtext/trainers/post_train/distillation
- "src/MaxText/sft" # code moved to src/maxtext/trainers/post_train/sft
86 changes: 47 additions & 39 deletions docs/guides/optimization/pallas_kernels_performance.md
@@ -26,8 +26,8 @@ This guide explains **when** to consider Pallas, a **workflow** for developing a

Think in **roofline** terms ([All About Rooflines](https://jax-ml.github.io/scaling-book/roofline/)) and in terms of **structure the compiler can’t see**:

* **Roofline framing.** Is your op **compute-limited** (MXU at or near peak) or **bandwidth-limited** (HBM↔on-chip transfers dominate)? Pallas tends to shine when you can reduce bandwidth pressure or avoid wasted work via better tiling and scheduling.
* **Compiler invisibles.** Irregular sparsity, ragged batch shapes, non-contiguous memory access, and domain-specific invariants are all signals that a custom kernel could help.
- **Roofline framing.** Is your op **compute-limited** (MXU at or near peak) or **bandwidth-limited** (HBM↔on-chip transfers dominate)? Pallas tends to shine when you can reduce bandwidth pressure or avoid wasted work via better tiling and scheduling.
- **Compiler invisibles.** Irregular sparsity, ragged batch shapes, non-contiguous memory access, and domain-specific invariants are all signals that a custom kernel could help.

**Know when XLA is enough.** Before writing a custom kernel, always [profile your baseline](#1-high-level-profiling). If a standard operation (like a dense [`jnp.matmul`](https://docs.jax.dev/en/latest/_autosummary/jax.numpy.matmul.html)) is already performing well, the XLA compiler is doing its job. In these cases, a Pallas kernel will increase code complexity and maintenance burden with minimal performance improvement.
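To make the roofline framing concrete, here is a small sketch that estimates the arithmetic intensity of a bf16 matmul and compares it to an accelerator's ridge point. The peak-FLOPs and bandwidth numbers are illustrative assumptions, not specs for any particular TPU.

```python
# Roofline sketch: is a matmul compute-limited or bandwidth-limited?
# Hardware numbers below are illustrative assumptions, not real TPU specs.
PEAK_FLOPS = 9.0e14   # hypothetical peak matmul throughput, FLOP/s
HBM_BW = 1.2e12       # hypothetical HBM bandwidth, bytes/s
BYTES_PER_EL = 2      # bf16


def arithmetic_intensity(m, n, k):
    """FLOPs per byte of HBM traffic for C[m,n] = A[m,k] @ B[k,n]."""
    flops = 2 * m * n * k                                  # multiply-accumulates
    bytes_moved = BYTES_PER_EL * (m * k + k * n + m * n)   # read A, B; write C
    return flops / bytes_moved


# Intensity needed to saturate the MXU under these assumed specs.
ridge = PEAK_FLOPS / HBM_BW  # 750 FLOPs/byte

for dim in (256, 1024, 8192):
    ai = arithmetic_intensity(dim, dim, dim)
    regime = "compute-limited" if ai >= ridge else "bandwidth-limited"
    print(f"{dim}^3 matmul: {ai:.0f} FLOPs/byte -> {regime}")
```

For square matmuls the intensity grows as `dim / 3`, which is why small matmuls are bandwidth-limited and large ones are compute-limited under the same roofline.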

@@ -42,29 +42,34 @@ it is very difficult to automatically infer the dual of the memory pipeline.

For dense, regular GEMMs, XLA’s libraries are hard to beat. The exception is **Mixture-of-Experts (MoE)** MLPs with **ragged token→expert layouts** (some tokens routed to different experts; shapes are irregular). Zero-padding to make dense tensors wastes FLOPs; a custom kernel can operate only on the actually-selected tokens.

* In MaxText, we use Grouped Matrix Multiplication (GMM) via **Megablox** to compute per-expert matmuls on ragged batches. Precomputed metadata (e.g., token→expert indices and ranges) guides the grouped computation and avoids work on padded regions.
- In MaxText, we use Grouped Matrix Multiplication (GMM) via **Megablox** to compute per-expert matmuls on ragged batches. Precomputed metadata (e.g., token→expert indices and ranges) guides the grouped computation and avoids work on padded regions.

**Note:** *Megablox* is an efficient, non-capped MoE implementation in JAX. *Megablocks* refers to the equivalent PyTorch implementation. See [arXiv:2211.15841](https://arxiv.org/abs/2211.15841) for more details.
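As a reference for what the grouped computation contracts to, the sketch below is a plain-JAX stand-in for GMM: `group_sizes` plays the role of the precomputed ragged metadata, and only the tokens actually routed to each expert are multiplied. This models the interface, not the Megablox Pallas kernel itself; all names here are hypothetical.

```python
import jax.numpy as jnp
import numpy as np


def gmm_reference(tokens, expert_weights, group_sizes):
    """Plain-JAX reference for a grouped matmul over ragged expert groups.

    tokens:         [T, D], pre-sorted so each expert's tokens are contiguous.
    expert_weights: [E, D, F], one weight matrix per expert.
    group_sizes:    ragged metadata; group_sizes[e] tokens go to expert e.
    """
    outs, start = [], 0
    for e, size in enumerate(group_sizes):
        # Dense matmul only over the tokens actually routed to expert e;
        # no zero-padding to a rectangular [T, E, ...] tensor.
        outs.append(tokens[start:start + size] @ expert_weights[e])
        start += size
    return jnp.concatenate(outs, axis=0)


rng = np.random.default_rng(0)
tokens = jnp.asarray(rng.normal(size=(6, 4)), dtype=jnp.float32)
weights = jnp.asarray(rng.normal(size=(2, 4, 3)), dtype=jnp.float32)
# Ragged routing: 2 tokens to expert 0, 4 tokens to expert 1.
out = gmm_reference(tokens, weights, group_sizes=[2, 4])
```

The real kernel fuses these per-group matmuls on device and reads the metadata from scalar memory, but the input/output contract is the same.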

### 2. Memory-Access-Bound work (attention)

Attention kernels are classically **bandwidth-limited** if you materialize the full \[L,L\] score matrix. A Pallas kernel can block **Q/K/V** into tiles that fit on-chip and perform **online softmax accumulation**, never storing the massive intermediate.
Attention kernels are classically **bandwidth-limited** if you materialize the full [L,L] score matrix. A Pallas kernel can block **Q/K/V** into tiles that fit on-chip and perform **online softmax accumulation**, never storing the massive intermediate.

* MaxText uses a Pallas attention kernel for training (Flash/Splash-style) and **paged/ragged** attention for inference to efficiently fetch KV cache pages and handle non-contiguous layouts.
- MaxText uses a Pallas attention kernel for training (Flash/Splash-style) and **paged/ragged** attention for inference to efficiently fetch KV cache pages and handle non-contiguous layouts.
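The online-softmax accumulation mentioned above can be sketched in plain JAX: K/V are visited in tiles while a running max and denominator are maintained, so the full [L,L] score matrix is never materialized. This is a numerical sketch of the idea, not MaxText's Pallas kernel.

```python
import jax.numpy as jnp
import numpy as np


def online_softmax_attention(q, k, v, block=4):
    """Tile K/V and keep running softmax statistics; never forms full [L, L] scores."""
    m = jnp.full((q.shape[0],), -jnp.inf)       # running row-max
    l = jnp.zeros((q.shape[0],))                # running softmax denominator
    acc = jnp.zeros((q.shape[0], v.shape[1]))   # running weighted sum of V
    for s in range(0, k.shape[0], block):
        scores = q @ k[s:s + block].T           # only a [Lq, block] tile of scores
        m_new = jnp.maximum(m, scores.max(axis=1))
        scale = jnp.exp(m - m_new)              # rescale previously accumulated stats
        p = jnp.exp(scores - m_new[:, None])
        l = l * scale + p.sum(axis=1)
        acc = acc * scale[:, None] + p @ v[s:s + block]
        m = m_new
    return acc / l[:, None]


rng = np.random.default_rng(0)
q = jnp.asarray(rng.normal(size=(5, 8)), dtype=jnp.float32)
k = jnp.asarray(rng.normal(size=(10, 8)), dtype=jnp.float32)
v = jnp.asarray(rng.normal(size=(10, 3)), dtype=jnp.float32)
out = online_softmax_attention(q, k, v, block=4)
```

In the real kernel the same accumulation runs per Q/K/V tile in VMEM, which is what turns a bandwidth-limited op into a tractable one.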

## 🛠️ Pallas kernels in MaxText

To maximize performance, MaxText uses custom Pallas kernels for memory-bandwidth-bound or structurally irregular operations that a general-purpose compiler cannot optimize as effectively. Below are the key kernels we use. **Note**: Examples evolve; treat this list as guidance.

* **Training Attention (Flash/Splash-style):** This kernel is the default for training Transformer models in MaxText, such as DeepSeek, Gemma and Llama. It avoids creating the large \[L,L\] attention matrix to save memory, processing data in smaller, tiled chunks with online softmax accumulation.
* [`src/MaxText/kernels/splash_attention_kernel.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/kernels/splash_attention_kernel.py)
* **Serving Attention (Paged & Ragged):** For high-throughput inference, this kernel efficiently fetches non-contiguous "pages" of the KV cache from memory. It is a key optimization for our serving stack and is used for models running on MaxText's inference engine.
* [`src/MaxText/inference/paged_attention.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/inference/paged_attention.py)
* [`src/MaxText/inference/paged_attention_kernel_v2.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/inference/paged_attention_kernel_v2.py)
* **MoE Grouped Matmul (Megablox GMM):** Sparse/irregular grouped GEMMs driven by host-built metadata.
- **Training Attention (Flash/Splash-style):** This kernel is the default for training Transformer models in MaxText, such as DeepSeek, Gemma and Llama. It avoids creating the large [L,L] attention matrix to save memory, processing data in smaller, tiled chunks with online softmax accumulation.

> This is an efficient computation method for Mixture-of-Experts (MoE) models like DeepSeek, Llama 4, Mixtral and Qwen-MoE. In MoE, each token is processed by only a few "experts," which is inefficient for standard matrix multiplication. Megablox solves this by having the CPU (**host**) first create a routing plan (**metadata**) that assigns tokens to experts. The accelerator (**device**) then uses this plan to perform many small, dense matrix multiplications in parallel (**Grouped Matrix Multiplication**), avoiding wasted work on unused experts.
* [`src/MaxText/kernels/megablox/gmm.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/kernels/megablox/gmm.py)
- [`src/MaxText/kernels/splash_attention_kernel.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/kernels/splash_attention_kernel.py)

- **Serving Attention (Paged & Ragged):** For high-throughput inference, this kernel efficiently fetches non-contiguous "pages" of the KV cache from memory. It is a key optimization for our serving stack and is used for models running on MaxText's inference engine.

- [`src/maxtext/inference/paged_attention.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/inference/paged_attention.py)
- [`src/maxtext/inference/paged_attention_kernel_v2.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/maxtext/inference/paged_attention_kernel_v2.py)

- **MoE Grouped Matmul (Megablox GMM):** Sparse/irregular grouped GEMMs driven by host-built metadata.

> This is an efficient computation method for Mixture-of-Experts (MoE) models like DeepSeek, Llama 4, Mixtral and Qwen-MoE. In MoE, each token is processed by only a few "experts," which is inefficient for standard matrix multiplication. Megablox solves this by having the CPU (**host**) first create a routing plan (**metadata**) that assigns tokens to experts. The accelerator (**device**) then uses this plan to perform many small, dense matrix multiplications in parallel (**Grouped Matrix Multiplication**), avoiding wasted work on unused experts.

- [`src/MaxText/kernels/megablox/gmm.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/kernels/megablox/gmm.py)

**Note:** Megablox accelerates the grouped **matmul**; **routing/gating** is separate code ([`src/MaxText/layers/moe.py`](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/layers/moe.py)).

@@ -74,7 +79,7 @@

Give the kernel a clear name in traces and capture a profile. Always use [`jax.block_until_ready()`](https://docs.jax.dev/en/latest/_autosummary/jax.block_until_ready.html) when timing your operations.

``` python
```python
import jax
from jax import profiler

Expand Down Expand Up @@ -104,24 +109,24 @@ For a more automated approach, consider using libraries like [tune-jax](https://

Pallas exposes the underlying hardware primitives for you to control.

* **HBM:** High-Bandwidth Memory (standard device memory).
* **VMEM:** On-chip vector SRAM for array tiles; your kernel primarily reads/writes VMEM refs.
* **SMEM:** On-chip scalar SRAM for control/metadata (e.g., counters, small tables).
* **Semaphores:** Available for advanced async/barrier patterns in manual pipelines.
* **MXU:** The Matrix Unit, optimized for large block GEMMs/convolutions.
* **VPU:** The Vector Processing Unit, used for elementwise/vector work.
- **HBM:** High-Bandwidth Memory (standard device memory).
- **VMEM:** On-chip vector SRAM for array tiles; your kernel primarily reads/writes VMEM refs.
- **SMEM:** On-chip scalar SRAM for control/metadata (e.g., counters, small tables).
- **Semaphores:** Available for advanced async/barrier patterns in manual pipelines.
- **MXU:** The Matrix Unit, optimized for large block GEMMs/convolutions.
- **VPU:** The Vector Processing Unit, used for elementwise/vector work.

**Alignment & Constraints:** Respect TPU BlockSpec constraints (divisibility/shape rules for trailing dimensions and supported block shapes). Start with tile shapes that fit in VMEM and meet these requirements, then sweep different sizes to find the optimum. Let profiling guide you; don't assume powers of two are always best.
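A quick back-of-the-envelope check before sweeping: the sketch below estimates the VMEM footprint of a candidate tile configuration. The 16 MiB budget and the double-buffering factor are illustrative assumptions; actual limits vary by TPU generation.

```python
# Estimate whether a candidate tile configuration fits in VMEM.
# Budget and buffering factor are illustrative assumptions, not hardware specs.
VMEM_BUDGET = 16 * 1024 * 1024  # hypothetical ~16 MiB of VMEM per core


def tile_footprint_bytes(block_shapes, bytes_per_el=2, n_buffers=2):
    """block_shapes: per-operand tile shapes; n_buffers=2 models double-buffering."""
    total = 0
    for shape in block_shapes:
        n = 1
        for d in shape:
            n *= d
        total += n * bytes_per_el
    return total * n_buffers


# e.g. a [256, 512] x [512, 256] -> [256, 256] matmul tile in bf16, double-buffered
fp = tile_footprint_bytes([(256, 512), (512, 256), (256, 256)])
assert fp <= VMEM_BUDGET, "tile too large; shrink block shapes"
```

A tile choice that passes this check is only a starting point; the sweep still decides which legal shape is fastest.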

## 🧱 Core Pallas design patterns

These are the common techniques used in MaxText's Pallas kernels.

* **Tiling & Blocking:** Move just a tile that fits on-chip, compute on it, and write it back.
* **Explicit Pipelining:** Overlap HBM↔VMEM loads with compute to hide latency (e.g., double-buffering).
* **Online Accumulation:** Combine partial results as you go; don’t materialize huge intermediate arrays.
* **Auxiliary Metadata:** Precompute control tables (e.g., token-to-expert ranges) and keep them in fast scalar memory.
* **Compute↔Communication Overlap:** In distributed runs, overlap local work with cross-device traffic when possible.
- **Tiling & Blocking:** Move just a tile that fits on-chip, compute on it, and write it back.
- **Explicit Pipelining:** Overlap HBM↔VMEM loads with compute to hide latency (e.g., double-buffering).
- **Online Accumulation:** Combine partial results as you go; don’t materialize huge intermediate arrays.
- **Auxiliary Metadata:** Precompute control tables (e.g., token-to-expert ranges) and keep them in fast scalar memory.
- **Compute↔Communication Overlap:** In distributed runs, overlap local work with cross-device traffic when possible.

## ✍️ Writing & integrating a kernel

@@ -136,9 +141,11 @@
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def add_vectors_kernel(x_ref, y_ref, o_ref):
o_ref[:] = x_ref[:] + y_ref[:]


def add_vectors(x: jax.Array, y: jax.Array) -> jax.Array:
assert x.shape == y.shape
return pl.pallas_call(
@@ -156,14 +163,16 @@
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def tile_add_kernel(x_ref, y_ref, o_ref):
# Operate on the tile slices handed in by BlockSpecs (already in VMEM on TPU).
o_ref[:, :] = x_ref[:, :] + y_ref[:, :]


def tile_add(x: jax.Array, y: jax.Array) -> jax.Array:
assert x.shape == y.shape and x.ndim == 2
B0 = min(128, x.shape[0]) # Example choice; tune this with a sweep
B1 = x.shape[1] # Full width tile (for illustration)
B1 = x.shape[1] # Full width tile (for illustration)

# Map program id (tile index) -> tile origin in the full (HBM) array.
# NOTE: The runtime advances origins by `block_shape`, so `i` is already a tile
@@ -192,29 +201,28 @@

Prefer `pl.pallas_call` with scratch buffers allocated in the appropriate memory space (VMEM/SMEM) and use multi-buffering to overlap HBM loads with compute. Advanced pipelining to consider: custom prefetch block order via a scalar prefetch grid (for details see [here](https://docs.jax.dev/en/latest/pallas/tpu/sparse.html)), which lets you control block execution order based on runtime values.


## 🌐 Distributed execution

Dispatch a kernel on multiple devices with `jax.shard_map`. It’s usually simpler and more maintainable than in-kernel cross-device communication. While Pallas supports low-level comms, `shard_map` is the right first choice for multi-device parallelism, and you can **communicate with `shard_map` collectives** when needed.
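A minimal sketch of the recommended pattern: wrap a per-device function with `shard_map` over a one-axis mesh. The mesh axis name and the local function are arbitrary choices here, and the same wrapper can dispatch a `pallas_call` instead of the elementwise function shown.

```python
import jax
import jax.numpy as jnp
import numpy as np

try:
    shard_map = jax.shard_map  # newer JAX exposes it at the top level
except AttributeError:
    from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

# One mesh axis over however many devices are available (works on a single CPU device too).
mesh = Mesh(np.array(jax.devices()), axis_names=("x",))


def local_fn(block):
    # Runs once per device on its local shard; a pallas_call could go here instead.
    return block * 2.0


sharded_fn = shard_map(local_fn, mesh=mesh, in_specs=P("x"), out_specs=P("x"))

x = jnp.arange(8.0)
y = sharded_fn(x)  # each device doubles its shard of x
```

Cross-device reductions slot in as `jax.lax.psum`-style collectives over the mesh axis inside `local_fn`, which keeps the kernel itself single-device.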

## 🐞 Debugging tips

* Use `interpret=True` in `pallas_call` to run the kernel body in a Python interpreter backend, simulating device execution on CPU without lowering through XLA.
* Start with a tiny problem size and assert on invariants inside the kernel.
* Add `jax.named_scope` liberally so kernels are easy to spot in performance traces.
- Use `interpret=True` in `pallas_call` to run the kernel body in a Python interpreter backend, simulating device execution on CPU without lowering through XLA.
- Start with a tiny problem size and assert on invariants inside the kernel.
- Add `jax.named_scope` liberally so kernels are easy to spot in performance traces.
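For the first tip, a minimal interpret-mode harness (a sketch using the quickstart-style `pallas_call` API; it runs on CPU without a TPU):

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import pallas as pl


def add_one_kernel(x_ref, o_ref):
    o_ref[:] = x_ref[:] + 1.0


def add_one(x):
    # interpret=True executes the kernel body in the Pallas interpreter on CPU,
    # so you can insert prints/asserts and debug without lowering through XLA.
    return pl.pallas_call(
        add_one_kernel,
        out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
        interpret=True,
    )(x)


y = add_one(jnp.arange(4, dtype=jnp.float32))
```

Once the interpreted version is correct on a tiny input, drop `interpret=True` and re-validate on device.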

## ✅ Putting it all together (checklist)

1. **Profile** the baseline using `named_scope` and `block_until_ready`.
2. **Tile arrays into smaller chunks using BlockSpecs** (virtually always necessary, even for simple kernels).
3. Build a **sweep harness** for block shapes (and optionally scalar prefetch grid choices).
4. **Validate** end-to-end performance in the model, not just microbenchmarks.
5. Consider **maintainability** and guard the new kernel with tests.
6. Consider applying **`jax.vmap`** to a Pallas kernel to simplify implementation; think of it as prepending grid dimensions automatically.
1. **Tile arrays into smaller chunks using BlockSpecs** (virtually always necessary, even for simple kernels).
1. Build a **sweep harness** for block shapes (and optionally scalar prefetch grid choices).
1. **Validate** end-to-end performance in the model, not just microbenchmarks.
1. Consider **maintainability** and guard the new kernel with tests.
1. Consider applying **`jax.vmap`** to a Pallas kernel to simplify implementation; think of it as prepending grid dimensions automatically.

## 📚 References

* **Pallas Docs & Quickstart:** [docs.jax.dev/en/latest/pallas/index.html](https://docs.jax.dev/en/latest/pallas/index.html)
* **JAX Profiling Guides:** [jax.readthedocs.io/en/latest/profiling.html](https://jax.readthedocs.io/en/latest/profiling.html)
* **Manual Parallelism (shard_map):** [docs.jax.dev/en/latest/notebooks/shard_map.html](https://docs.jax.dev/en/latest/notebooks/shard_map.html)
* **Distributed Pallas on TPU:** [docs.jax.dev/en/latest/pallas/tpu/distributed.html](https://docs.jax.dev/en/latest/pallas/tpu/distributed.html)
- **Pallas Docs & Quickstart:** [docs.jax.dev/en/latest/pallas/index.html](https://docs.jax.dev/en/latest/pallas/index.html)
- **JAX Profiling Guides:** [jax.readthedocs.io/en/latest/profiling.html](https://jax.readthedocs.io/en/latest/profiling.html)
- **Manual Parallelism (shard_map):** [docs.jax.dev/en/latest/notebooks/shard_map.html](https://docs.jax.dev/en/latest/notebooks/shard_map.html)
- **Distributed Pallas on TPU:** [docs.jax.dev/en/latest/pallas/tpu/distributed.html](https://docs.jax.dev/en/latest/pallas/tpu/distributed.html)