Add lm_loss/perplexity metrics and fix MoE load balance loss extraction#3664
copybara-service[bot] merged 1 commit into main
Conversation
🤖 Hi @gagika, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.
The pull request introduces valuable metrics for tracking language model performance (lm_loss and perplexity) and improves the robustness of loss extraction for MoE and Indexer models. It also fixes an issue with synthetic data evaluation. However, the renaming of total_loss to xent_sum was not applied to the DPO loss function, which will cause crashes when DPO is enabled.
🔍 General Feedback
- **Consistency:** The renaming of `total_loss` to `xent_sum` in the main `loss_fn` is a good improvement for clarity, but it must be applied consistently across all loss functions, including `dpo_loss_fn`.
- **Modularity:** The addition of `collect_intermediates_by_suffix` in `maxtext_utils.py` is a great architectural improvement that makes the trainer more resilient to changes in model structure.
- **Testing:** Ensure that the new metrics are verified with a full training run (including DPO if possible) to confirm that no `KeyError` or `AttributeError` occurs during logging.
This Pull Request introduces valuable metrics (lm_loss, perplexity) and improves the robustness of loss extraction for MoE and indexer models across various architectures. It also addresses a bug in synthetic data evaluation. However, the review identified significant issues regarding loss normalization and gradient scaling when using gradient accumulation, particularly for DPO and auxiliary losses, which should be addressed to ensure training stability and correctness.
🔍 General Feedback
- **Metric Clarity:** The renaming of `total_loss` to `xent_sum` is a positive change that reduces ambiguity between the unnormalized cross-entropy sum and the full training objective.
- **Robustness:** The new `collect_intermediates_by_suffix` utility is a clean solution for handling dynamic layer indices in scanned model architectures.
- **Gradient Accumulation Consistency:** There is a recurring pattern where averaged losses (DPO, MoE LB, Indexer) are added to unnormalized sums when gradient accumulation is enabled, leading to incorrect gradient scaling. Aligning these components to use unnormalized sums during GA would improve consistency and correctness.
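The scaling concern in the last bullet can be seen with a toy calculation. This sketch assumes a GA scheme where unnormalized sums are accumulated over microbatches and divided by the total token count at the end; all names and numbers are illustrative, not taken from the MaxText code.

```python
# Toy numbers: K microbatches of `tokens` tokens each.
K = 4
tokens = 256
xent_mean = 2.0   # per-token cross-entropy, constant for simplicity
aux_mean = 0.1    # auxiliary loss (e.g. MoE LB), already an average

total_weights = K * tokens
xent_sums = [xent_mean * tokens] * K  # unnormalized per-microbatch sums

# Inconsistent: the averaged aux loss is added to the unnormalized running
# sum, so after normalization its contribution shrinks ~`tokens`-fold.
mixed = (sum(xent_sums) + aux_mean * K) / total_weights

# Consistent: accumulate the aux loss as an unnormalized sum as well,
# preserving its intended weight in the final objective.
consistent = (sum(xent_sums) + aux_mean * tokens * K) / total_weights

print(mixed)       # aux term shrunk ~tokens-fold
print(consistent)  # intended weighting
```

With these numbers the mixed form contributes only 0.1/256 of auxiliary loss to the normalized objective instead of the intended 0.1, which is the incorrect gradient scaling the reviewer describes.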
This PR introduces valuable metric improvements and a robust suffix-based intermediate collection helper. The changes significantly improve the maintainability of loss extraction across different model architectures.
🔍 General Feedback
- The renaming of `total_loss` to `xent_sum` is a great clarification.
- Please ensure `SyntheticDataIterator` is updated to support the new DPO fields to avoid `KeyError` during benchmarking.
- Be aware of potential gradient scaling issues for auxiliary losses (MoE LB, Indexer) when gradient accumulation is used, as these losses are currently added to the objective in their averaged form.
Thanks Gagik! For the balance loss, you may need to set it for dropping, which needs `capacity_factor=1` and `sparse_matmul=False`. After 50-100 training steps, you should be able to see something like this. Could you help with a sanity check? I see the loss is stable in your logs.
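Assuming MaxText-style YAML configuration, the dropping setup the reviewer suggests might look like the fragment below; the key names are taken from the reviewer's comment, not verified against the current config schema.

```yaml
# Hypothetical config fragment: enable token dropping so the MoE load
# balance loss is actually exercised during training.
capacity_factor: 1
sparse_matmul: false
```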
Force-pushed from a35514b to 0e26968
done
JamesDeng42
left a comment
Generally LGTM, left one comment.
Description
Adds `lm_loss` and `perplexity` metrics to pre-training, fixes MoE load balance loss extraction for scanned architectures, and fixes eval with synthetic data.
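The lm_loss/perplexity relationship this PR reports can be sketched as follows; the values are illustrative, and only the metric names and formulas come from the PR description.

```python
import math

# Illustrative values: xent_sum is the unnormalized sum of per-token
# cross-entropies; total_weights counts the tokens contributing to it.
xent_sum = 5120.0
total_weights = 2048.0

lm_loss = xent_sum / total_weights  # reported as learning/lm_loss
perplexity = math.exp(lm_loss)      # reported as learning/perplexity

print(lm_loss)  # → 2.5
```

Because the sum and the weight count are accumulated separately, the reported `lm_loss` stays a true per-token mean regardless of batch or sequence shape.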
Metric improvements:

- Renamed `aux["total_loss"]` → `aux["xent_sum"]` to clarify it is the unnormalized sum of per-token cross-entropies, distinct from the full training objective (`loss` includes MoE lb, indexer, and MTP auxiliary losses)
- Added `learning/lm_loss` (= xent_sum / total_weights) and `learning/perplexity` (= exp(lm_loss)) to TensorBoard and console output
- Added `eval/avg_perplexity` and `eval/avg_z_loss` to cumulative eval metrics
- Show `moe_lb_loss` in console output when `num_experts > 1`

MoE load balance loss fix:
The previous implementation looked up the MoE lb loss via a fixed set of hard-coded intermediate paths. This broke for Gemma4 and other models (like Llama4) that use scanned blocks, where the path includes dynamic layer indices. Replaced with `collect_intermediates_by_suffix()`, a new `maxtext_utils` helper that matches leaves by key-path suffix, working correctly across scanned, scannable-block, and standard layer layouts. The same fix is applied to indexer loss extraction.
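A minimal sketch of such a suffix-matching helper is shown below. This is not the actual MaxText implementation: plain nested dicts stand in for the intermediates pytree, and the function name and signature are assumed from the description.

```python
def collect_intermediates_by_suffix(tree, suffix, path=()):
    """Return all leaves whose key path ends with `suffix`, regardless of
    where (or at what depth) they sit in the tree. Matching by suffix
    tolerates the dynamic layer indices inserted by scanned blocks."""
    matches = []
    if isinstance(tree, dict):
        for key, value in tree.items():
            matches.extend(
                collect_intermediates_by_suffix(value, suffix, path + (key,))
            )
    elif path[-len(suffix):] == tuple(suffix):
        matches.append(tree)
    return matches

# A scanned model nests the same collection under a dynamic layer index,
# so a fixed lookup like t["decoder"]["layers_0"]["moe_lb_loss"] breaks.
intermediates = {
    "decoder": {
        "layers_0": {"moe_lb_loss": 0.01},
        "scanned": {"3": {"moe_block": {"moe_lb_loss": 0.02}}},
    }
}
print(collect_intermediates_by_suffix(intermediates, ["moe_lb_loss"]))
# → [0.01, 0.02]
```

Because only the tail of the key path is compared, the same call works whether a layer's losses live at a fixed path or under a scan-generated index.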
Synthetic eval fix:

Fixed synthetic evals to work with `eval_interval > 0`.

Tests
Verified a run on Gemma4-26B:
output: https://paste.googleplex.com/5891139803676672
DPO test:
Checklist
Before submitting this PR, please make sure (put X in square brackets):
`gemini-review` label.