Add support for z-loss in pre-training #3211
Conversation
Force-pushed from c8b18fa to cf917cc (Compare)
This PR successfully implements and integrates z-loss into both the standard and vocabulary-tiling training loops to improve numerical stability in low-precision setups. The code is well-tested, correctly computes gradients (including the exact mathematical verifications for the z-loss penalty), and properly scales to vocabulary-tiling workflows.
🔍 General Feedback
- Great job adding the comprehensive mathematical assertions in `max_utils_test.py` as well as end-to-end integration tests in `tiling_test.py`.
- The implementation efficiently leverages `max_utils.cross_entropy_with_logits` to gracefully handle the auxiliary loss gradient computation in a mathematically correct way.
- I left a couple of inline comments regarding the `z_loss` metric reporting; the auxiliary loss currently logged to TensorBoard tracks the unnormalized sum across the batch, which will scale incorrectly with your batch size and sequence length without normalization (see the sketch just below for why).
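For context, here is a minimal sketch of the z-loss penalty and of why summing it is batch-size dependent. This is illustrative code, not the actual `max_utils` implementation; the function and variable names are assumptions.

```python
import jax
import jax.numpy as jnp

def z_loss_penalty(logits, z_loss_multiplier):
  """Illustrative per-token z-loss: penalizes a large log-partition value log(Z)."""
  log_z = jax.nn.logsumexp(logits, axis=-1)      # shape [batch, seq_len]
  return z_loss_multiplier * jnp.square(log_z)

# jnp.sum over this array grows with batch_size * seq_len, so a logged value
# should be divided by the number of non-padding tokens to stay comparable across runs.
```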
RissyRan
left a comment
LGTM! Just minor comments.
Force-pushed from 12ea159 to a3475bf (Compare)
RissyRan
left a comment
LGTM! Please consider the gradient accumulation case.
```python
total_z_loss = jnp.sum(z_loss)

total_weights = jnp.sum(data["targets_segmentation"] != 0)
# If gradient accumulation is enabled, we don't need to divide total_loss
```
Sorry I missed this part earlier. Shall we also handle gradient accumulation?
If not, we could add an assertion in `types.py` mentioning that `gradient_accumulation_steps` and `z_loss_multiplier` cannot be set together.
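If that route were taken, the check could look roughly like this; this is a sketch only, and the real config object in `types.py` may expose these fields differently.

```python
def validate_z_loss_config(config) -> None:
  """Hypothetical config check: reject the unsupported combination up front."""
  accum_steps = getattr(config, "gradient_accumulation_steps", 1)
  z_loss_mult = getattr(config, "z_loss_multiplier", 0.0)
  if accum_steps > 1 and z_loss_mult != 0.0:
    raise ValueError(
        "z_loss_multiplier cannot be combined with gradient_accumulation_steps > 1."
    )
```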
maxtext/src/maxtext/trainers/pre_train/train.py
Lines 200 to 207 in a3475bf
The loss already has the z-loss portion in it (`loss += total_z_loss` inside `cross_entropy_with_logits`), and that is what gets used for the gradient/backward pass.
The standalone `total_z_loss` is only used for metric logging to TensorBoard. Because it doesn't affect the gradients, we don't need to worry about the gradient-accumulation division optimization; we just normalize it by `total_weights` so the logged metric is always a consistent per-token average.
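In other words, something along these lines; this is a sketch with illustrative names rather than the exact `train.py` wiring.

```python
import jax.numpy as jnp

def z_loss_metric(z_loss, targets_segmentation):
  """Per-token average of the auxiliary z-loss, used only for TensorBoard logging."""
  total_z_loss = jnp.sum(z_loss)                       # summed over batch and sequence
  total_weights = jnp.sum(targets_segmentation != 0)   # count of non-padding tokens
  # The training loss already contains the z-loss term, so this value never feeds
  # the backward pass; dividing by total_weights only affects the logged metric.
  return total_z_loss / total_weights
```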
Force-pushed from a3475bf to 25b5de7 (Compare)
Description
This PR implements Z-loss to improve numerical stability and prevent runaway logits. Alongside techniques like QK-normalization and logit soft-capping, it is a key mechanism for stabilizing low-precision (BF16/FP8) training.
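For reference, the penalty in question is the standard z-loss regularizer used in PaLM-style training; the PR text does not spell out the exact form, so take this as background rather than the precise MaxText expression:

$$
\mathcal{L}_{z} = \lambda_{z}\,\log^{2} Z, \qquad Z = \sum_{i} e^{\,\ell_{i}}
$$

where $\ell_i$ are the logits and $\lambda_z$ corresponds to the new `z_loss_multiplier` flag.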
Key Changes:
- Added a `z_loss_multiplier` parameter to `types.py` (defaults to `0.0`).
- Integrated the Z-loss computation (in `max_utils.cross_entropy_with_logits`) into the standard training loop (`loss_fn`) and the vocabulary tiling path (`vocab_tiling_linen_loss`).
- The Z-loss value is returned in the `aux` dictionary and logged to TensorBoard as `learning/z_loss`.

Tests
- Added `test_cross_entropy_with_z_loss` in `max_utils_test.py` to verify the penalty calculation is mathematically correct.
- Added `test_vocab_tiling_gradient_with_z_loss` in `tiling_test.py` to ensure loss and gradients match exactly between standard and vocabulary-tiled computations when Z-loss is enabled.

Checklist
Before submitting this PR, please make sure (put X in square brackets):
`gemini-review` label.