Skip to content

Fix MaxText 22b.sh AOT compilation OOM on v4-128#4081

Merged
copybara-service[bot] merged 1 commit into
mainfrom
fix-22b-sh-oom
Jun 5, 2026
Merged

Fix MaxText 22b.sh AOT compilation OOM on v4-128#4081
copybara-service[bot] merged 1 commit into
mainfrom
fix-22b-sh-oom

Conversation

@darisoy
Copy link
Copy Markdown
Collaborator

@darisoy darisoy commented Jun 5, 2026

Description

This PR resolves the TPU HBM Out-of-Memory (OOM) error during Ahead-of-Time (AOT) compilation for the v4-128 topology with per_device_batch_size=13 by enabling Vocabulary Tiling and forcing Flash Attention in the 22b.sh configuration.

Details

  • The default configuration of 22b.sh fails with an HBM OOM during compilation.
  • Enabling vocabulary tiling (num_vocab_tiling=8) and forcing flash attention (attention=flash) reduces peak memory usage and compiler fragmentation, allowing the compilation to succeed.

FIXED: b/517329766
BUGS: b/517329766

Tests

The fix was verified by running the AOT compilation script directly on a TPU VM targeting the 2-slice v4-128 topology.

Command to Reproduce

bash src/maxtext/configs/tpu/v4/22b.sh \
  EXECUTABLE=train_compile \
  M_COMPILE_TOPOLOGY=v4-128 \
  M_COMPILE_TOPOLOGY_NUM_SLICES=2 \
  DATASET_PATH=dummy-dataset \
  OUTPUT_PATH=dummy-output-dir \
  RUN_PREFLIGHT=false

Results

Compilation completed successfully:

Jitting and compilation complete!
Finished train_compile.py successfully!

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

This change enables vocabulary tiling (num_vocab_tiling=8) and forces flash attention (attention=flash) in 22b.sh.
This resolves the HBM OOM when running AOT compilation for v4-128 with per_device_batch_size=13.
By using vocabulary tiling, we save ~1.4 GB HBM, allowing the model to fit without requiring TP=2 or activation offloading, maintaining the original performance intent of the config.
@codecov
Copy link
Copy Markdown

codecov Bot commented Jun 5, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.

📢 Thoughts on this report? Let us know!

@copybara-service copybara-service Bot merged commit b2153a3 into main Jun 5, 2026
48 checks passed
@copybara-service copybara-service Bot deleted the fix-22b-sh-oom branch June 5, 2026 22:45
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants