
Add support for z-loss in pre-training #3211

Merged
copybara-service[bot] merged 1 commit into main from agagik-z-loss on Feb 27, 2026

Conversation

@gagika (Collaborator) commented Feb 21, 2026

Description

This PR implements Z-loss to improve numerical stability and prevent runaway logits. Alongside techniques like QK-normalization and logit soft-capping, it is a key mechanism for stabilizing low-precision (BF16/FP8) training.

Key Changes:

  • Configuration: Added a z_loss_multiplier parameter to types.py (defaults to 0.0).
  • Integration: Wired the existing Z-loss utility (max_utils.cross_entropy_with_logits) into the standard training loop (loss_fn) and the vocabulary tiling path (vocab_tiling_linen_loss).
  • Logging: Normalized Z-loss is now exported to the aux dictionary and logged to TensorBoard as learning/z_loss.
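
For context, the z-loss term penalizes the squared log of the softmax normalizer, which keeps logits from drifting to large magnitudes in low-precision training. Below is a minimal sketch of how the combined loss can be computed; the function and variable names are illustrative assumptions, not the exact max_utils.cross_entropy_with_logits signature.

```python
import jax.numpy as jnp
from jax.nn import one_hot
from jax.scipy.special import logsumexp


def cross_entropy_with_z_loss_sketch(logits, targets, z_loss_multiplier=0.0):
  """Per-token cross-entropy plus a z-loss penalty (illustrative only)."""
  log_z = logsumexp(logits, axis=-1)              # log of the softmax normalizer, one value per token
  log_probs = logits - log_z[..., None]           # log-softmax
  labels = one_hot(targets, logits.shape[-1])
  xent = -jnp.sum(labels * log_probs, axis=-1)    # standard cross-entropy per token
  z_loss = z_loss_multiplier * jnp.square(log_z)  # penalty on a runaway normalizer
  return xent + z_loss, z_loss
```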

Tests

  • Math Verification: Added test_cross_entropy_with_z_loss in max_utils_test.py to verify the penalty calculation is mathematically correct.
  • Integration/Tiling: Added test_vocab_tiling_gradient_with_z_loss in tiling_test.py to ensure loss and gradients match exactly between standard and vocabulary-tiled computations when Z-loss is enabled.
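
As a rough illustration of the math-verification idea (not the actual test code; it reuses the hypothetical helper sketched above), the penalty can be checked against its closed form on a tiny deterministic example:

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def test_z_loss_penalty_matches_closed_form():
  # Tiny deterministic example: 2 tokens over a vocabulary of 4.
  logits = jnp.array([[1.0, 2.0, 3.0, 4.0],
                      [0.5, 0.5, 0.5, 0.5]])
  targets = jnp.array([3, 0])
  multiplier = 1e-4
  # Closed-form penalty per token: multiplier * log(sum(exp(logits)))**2.
  expected = multiplier * jnp.square(logsumexp(logits, axis=-1))
  _, z_loss = cross_entropy_with_z_loss_sketch(logits, targets, z_loss_multiplier=multiplier)
  assert jnp.allclose(z_loss, expected, atol=1e-8)
```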

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@codecov bot commented Feb 21, 2026

Codecov Report

❌ Patch coverage is 71.42857% with 8 lines in your changes missing coverage. Please review.

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| src/maxtext/trainers/pre_train/train.py | 50.00% | 6 Missing ⚠️ |
| src/maxtext/utils/vocabulary_tiling.py | 87.50% | 2 Missing ⚠️ |


@github-actions

🤖 Hi @gagika, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

@github-actions

🤖 I'm sorry @gagika, but I was unable to process your request. Please see the logs for more details.

@github-actions

🤖 Hi @aireenmei, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.

@github-actions bot left a comment


## 📋 Review Summary

This PR successfully implements and integrates z-loss into both the standard and vocabulary-tiling training loops to improve numerical stability in low-precision setups. The code is well-tested, correctly computes gradients (including the exact mathematical verifications for the z-loss penalty), and properly scales to vocabulary-tiling workflows.

🔍 General Feedback

  • Great job adding the comprehensive mathematical assertions in max_utils_test.py as well as end-to-end integration tests in tiling_test.py.
  • The implementation efficiently leverages max_utils.cross_entropy_with_logits to gracefully handle the auxiliary loss gradient computation in a mathematically correct way.
  • I left a couple of inline comments regarding the z_loss metric reporting; the auxiliary loss currently logged to TensorBoard tracks the unnormalized sum across the batch, which will scale incorrectly with your batch size and sequence length without normalization.

Comment thread src/maxtext/trainers/pre_train/train.py
Comment thread src/maxtext/trainers/pre_train/train.py
Comment thread src/maxtext/configs/types.py
Comment thread src/maxtext/trainers/pre_train/train.py
@RissyRan (Collaborator) left a comment


LGTM! Just minor comments.

Comment thread src/maxtext/configs/types.py
Comment thread src/maxtext/trainers/pre_train/train.py
@gagika force-pushed the agagik-z-loss branch 2 times, most recently from 12ea159 to a3475bf, on February 25, 2026 22:41
@RissyRan (Collaborator) left a comment


LGTM! Please consider the gradient accumulation case.

total_z_loss = jnp.sum(z_loss)

total_weights = jnp.sum(data["targets_segmentation"] != 0)
# If gradient accumulation is enabled, we don't need to divide total_loss
Collaborator


Sorry I missed this part earlier. Shall we also handle gradient accumulation?

If not, we could add an assertion in types.py stating that gradient_accumulation_steps and z_loss_multiplier cannot be set together (a rough sketch follows the quoted code below).

if config.gradient_accumulation_steps > 1 and not config.use_tunix_gradient_accumulation:
  loss = total_loss
else:
  # When using Tunix gradient accumulation, we revert to standard normalization.
  # Unlike the manual accumulation path above, Tunix (via optax.MultiSteps) expects
  # a normalized loss for each step. It handles the accumulation state
  # updates and scaling internally.
  loss = total_loss / (total_weights + EPS)
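
If the two options were instead made mutually exclusive, the guard could look roughly like the following (an illustrative sketch only; the validation hook and field names are assumptions based on this PR's config, not existing MaxText code):

```python
def validate_z_loss_config(config):
  """Illustrative guard: reject z-loss combined with manual gradient accumulation."""
  if config.z_loss_multiplier > 0.0 and config.gradient_accumulation_steps > 1:
    raise ValueError(
        "z_loss_multiplier cannot be used with gradient_accumulation_steps > 1 "
        "until their interaction is handled explicitly."
    )
```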

Collaborator Author


loss already includes the z_loss portion (loss += total_z_loss inside cross_entropy_with_logits), and that is what is used for the gradient/backward pass.

The standalone total_z_loss is only used for metric logging to TensorBoard. Because it doesn't affect the gradients, we don't need to worry about the gradient-accumulation division optimization; we just normalize it by total_weights so the logged metric is always a consistent per-token average.
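
A minimal sketch of that logging-only normalization, assuming the total_z_loss and total_weights names from the quoted snippet (illustrative, not the exact train.py code):

```python
import jax.numpy as jnp

EPS = 1e-8  # small constant to avoid division by zero (illustrative value)


def z_loss_metric(z_loss, targets_segmentation):
  """Per-token z-loss reported to TensorBoard; it does not feed the backward pass."""
  total_z_loss = jnp.sum(z_loss)                      # batch-summed penalty
  total_weights = jnp.sum(targets_segmentation != 0)  # number of non-padding target tokens
  return total_z_loss / (total_weights + EPS)         # consistent per-token average
```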

copybara-service bot merged commit 17d805e into main on Feb 27, 2026
30 checks passed
copybara-service bot deleted the agagik-z-loss branch on February 27, 2026 05:10