
Add lm_loss/perplexity metrics and fix MoE load balance loss extraction#3664

Merged
copybara-service[bot] merged 1 commit into main from agagik-loss-fix on Apr 16, 2026
Conversation

@gagika
Collaborator

@gagika gagika commented Apr 14, 2026

Description

Adds lm_loss and perplexity metrics to pre-training, fixes MoE load
balance loss extraction for scanned architectures, and fixes eval with
synthetic data.

Metric improvements:

  • Rename aux["total_loss"] → aux["xent_sum"] to clarify that it is the
    unnormalized sum of per-token cross-entropies, distinct from the full
    training objective (loss includes the MoE load-balance, indexer, and MTP auxiliary losses)
  • Add learning/lm_loss (= xent_sum / total_weights) and
    learning/perplexity (= exp(lm_loss)) to TensorBoard and console output
  • Add eval/avg_perplexity and eval/avg_z_loss to cumulative eval metrics
  • Log moe_lb_loss in console output when num_experts > 1
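The new metrics above can be sketched as follows. This is a minimal illustration, not the MaxText implementation: the function name lm_metrics and the plain-float inputs are assumptions; only the formulas (lm_loss = xent_sum / total_weights, perplexity = exp(lm_loss)) come from the description above.

```python
import math

def lm_metrics(xent_sum: float, total_weights: float) -> dict:
  """Derive per-token LM metrics from the unnormalized cross-entropy sum.

  Hypothetical sketch: `xent_sum` and `total_weights` mirror the aux
  values described above; the actual field names and types in MaxText
  may differ.
  """
  lm_loss = xent_sum / total_weights  # mean cross-entropy per target token
  return {
      "learning/lm_loss": lm_loss,
      "learning/perplexity": math.exp(lm_loss),  # exp of the mean NLL
  }
```

Because perplexity is the exponential of the mean (not the mean of per-batch exponentials), it has to be computed from the aggregated xent_sum, which is why the unnormalized sum is the quantity worth carrying in aux.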

MoE load balance loss fix:
The previous implementation looked up MoE lb loss via a fixed set of hard-coded
intermediate paths. This broke for Gemma4 and other models (like Llama4) that use scanned
blocks, where the path includes dynamic layer indices. Replaced with
collect_intermediates_by_suffix(), a new maxtext_utils helper that matches
leaves by key-path suffix, working correctly across scanned, scannable-block,
and standard layer layouts. Same fix applied to indexer loss extraction.
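The suffix-matching idea can be sketched with plain nested dicts. This is a hypothetical re-sketch, not the actual maxtext_utils helper: the real collect_intermediates_by_suffix operates on JAX key paths and its signature may differ. The point it illustrates is that matching on the trailing keys makes dynamic layer indices ("layers_0", "layers_1", ...) in the middle of the path irrelevant.

```python
def collect_intermediates_by_suffix(tree, suffix):
  """Collect leaves whose key path ends with the given suffix.

  Toy version over nested dicts: walks the tree, tracking the key path,
  and keeps any leaf whose trailing path keys equal `suffix`.
  """
  found = []

  def walk(node, path):
    if isinstance(node, dict):
      for key, value in node.items():
        walk(value, path + (key,))
    elif path[-len(suffix):] == tuple(suffix):
      found.append(node)  # leaf matched by key-path suffix

  walk(tree, ())
  return found

# Works regardless of the scanned-layer index in the middle of the path:
intermediates = {
    "decoder": {
        "layers_0": {"moe": {"load_balance_loss": 0.1}},
        "layers_1": {"moe": {"load_balance_loss": 0.2}},
    }
}
lb_losses = collect_intermediates_by_suffix(intermediates, ["load_balance_loss"])
```

A fixed-path lookup like intermediates["decoder"]["layers"]["moe"]["load_balance_loss"] would raise KeyError on this layout; the suffix match finds both leaves.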

Synthetic eval fix:
Fixed synthetic evals to work with eval_interval > 0.

Tests

Verified a run on Gemma4-26B:

python3 -m maxtext.trainers.pre_train.train src/maxtext/configs/base.yml \
  model_name=gemma4-26b \
  base_output_directory=${BASE_OUTPUT_DIRECTORY?} \
  dataset_type=hf \
  hf_path=allenai/c4 \
  hf_name=en \
  hf_eval_split=validation \
  tokenizer_type=huggingface \
  tokenizer_path=google/gemma-4-26b-a4b-it \
  per_device_batch_size=1 \
  run_name=runner_finetune_gemma4_26b_2 \
  steps=100 \
  enable_checkpointing=false \
  sharding_tolerance=0.03 \
  eval_interval=20 \
  eval_steps=5 \
  load_balance_loss_weight=0.1

output: https://paste.googleplex.com/5891139803676672

DPO test:

python3 -m maxtext.trainers.pre_train.train src/maxtext/configs/post_train/dpo.yml \
  run_name=dpo_smoke \
  base_emb_dim=256 \
  base_num_query_heads=4 \
  base_num_kv_heads=4 \
  base_mlp_dim=512 \
  base_num_decoder_layers=4 \
  head_dim=64 \
  per_device_batch_size=1 \
  max_target_length=128 \
  steps=10 \
  eval_interval=5 \
  eval_steps=2 \
  enable_checkpointing=false

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@github-actions

🤖 Hi @gagika, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.


@github-actions github-actions Bot left a comment


## 📋 Review Summary

The pull request introduces valuable metrics for tracking language model performance (lm_loss and perplexity) and improves the robustness of loss extraction for MoE and Indexer models. It also fixes an issue with synthetic data evaluation. However, the renaming of total_loss to xent_sum was not applied to the DPO loss function, which will cause crashes when DPO is enabled.

🔍 General Feedback

  • Consistency: The renaming of total_loss to xent_sum in the main loss_fn is a good improvement for clarity, but it must be applied consistently across all loss functions, including dpo_loss_fn.
  • Modularity: The addition of collect_intermediates_by_suffix in maxtext_utils.py is a great architectural improvement that makes the trainer more resilient to changes in model structure.
  • Testing: Ensure that the new metrics are verified with a full training run (including DPO if possible) to confirm that no KeyError or AttributeError occurs during logging.

Comment thread src/maxtext/common/metric_logger.py
Comment thread src/maxtext/common/metric_logger.py Outdated
Comment thread src/maxtext/trainers/pre_train/train.py
Comment thread src/maxtext/trainers/pre_train/train.py
@codecov

codecov Bot commented Apr 14, 2026



@github-actions github-actions Bot left a comment


## 📋 Review Summary

This Pull Request introduces valuable metrics (lm_loss, perplexity) and improves the robustness of loss extraction for MoE and indexer models across various architectures. It also addresses a bug in synthetic data evaluation. However, the review identified significant issues regarding loss normalization and gradient scaling when using gradient accumulation, particularly for DPO and auxiliary losses, which should be addressed to ensure training stability and correctness.

🔍 General Feedback

  • Metric Clarity: The renaming of total_loss to xent_sum is a positive change that reduces ambiguity between the unnormalized cross-entropy sum and the full training objective.
  • Robustness: The new collect_intermediates_by_suffix utility is a clean solution for handling dynamic layer indices in scanned model architectures.
  • Gradient Accumulation Consistency: There is a recurring pattern where averaged losses (DPO, MoE LB, Indexer) are added to unnormalized sums when gradient accumulation is enabled, leading to incorrect gradient scaling. Aligning these components to use unnormalized sums during GA would improve consistency and correctness.
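The scaling mismatch flagged in that last point can be made concrete with a toy numeric sketch. All numbers and the function name are hypothetical; the sketch only illustrates the arithmetic, not MaxText's gradient_accumulation.py.

```python
def ga_loss_mismatch(xent_sums, weights, aux_means):
  """Illustrate mixing unnormalized sums with pre-averaged aux losses.

  `xent_sums` / `weights` are per-microbatch unnormalized cross-entropy
  sums and token counts; `aux_means` are per-microbatch *averaged*
  auxiliary losses. Adding the averaged aux terms into the running sum
  and then dividing by total weights shrinks the aux contribution.
  """
  # Mixed accumulation: averaged aux added to an unnormalized sum.
  total = sum(x + a for x, a in zip(xent_sums, aux_means))
  mixed = total / sum(weights)  # aux term shrunk by ~1/total_weights
  # Consistent accumulation: normalize each component on its own scale.
  correct = sum(xent_sums) / sum(weights) + sum(aux_means) / len(aux_means)
  return mixed, correct
```

With two microbatches of 100 tokens, xent sums of 200 each, and an aux loss of 0.5 per microbatch, the mixed form yields 2.005 while the consistent form yields 2.5: the aux loss is effectively down-weighted by the token count, which changes gradient scaling as the review notes.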

Comment thread src/maxtext/trainers/post_train/dpo/dpo_utils.py
Comment thread src/maxtext/common/metric_logger.py
Comment thread src/maxtext/trainers/pre_train/train.py
Comment thread src/maxtext/utils/maxtext_utils.py
Comment thread src/maxtext/utils/gradient_accumulation.py Outdated


@github-actions github-actions Bot left a comment


## 📋 Review Summary

This PR introduces valuable metric improvements and a robust suffix-based intermediate collection helper. The changes significantly improve the maintainability of loss extraction across different model architectures.

🔍 General Feedback

  • The renaming of total_loss to xent_sum is a great clarification.
  • Please ensure SyntheticDataIterator is updated to support the new DPO fields to avoid KeyError during benchmarking.
  • Be aware of potential gradient scaling issues for auxiliary losses (MoE LB, Indexer) when gradient accumulation is used, as these losses are currently added to the objective in their averaged form.

Comment thread src/maxtext/trainers/pre_train/train.py
Comment thread src/maxtext/common/metric_logger.py
Comment thread src/maxtext/input_pipeline/synthetic_data_processing.py
Comment thread src/maxtext/trainers/pre_train/train.py
Collaborator

@RissyRan RissyRan left a comment


Thanks Gagik! For the balance loss, you may need to set it for dropping, which needs capacity_factor=1 sparse_matmul=False. After 50-100 training steps, you should be able to see something like this. Could you help have a sanity check? I see the loss is stable in your logs.

Comment thread src/maxtext/trainers/pre_train/train.py
@gagika gagika force-pushed the agagik-loss-fix branch 2 times, most recently from a35514b to 0e26968 Compare April 15, 2026 20:39
@gagika
Collaborator Author

gagika commented Apr 15, 2026

Thanks Gagik! For the balance loss, you may need to set it for dropping, which needs capacity_factor=1 sparse_matmul=False. After 50-100 training steps, you should be able to see something like this. Could you help have a sanity check? I see the loss is stable in your logs.

done:

https://screenshot.googleplex.com/3FcUcvjd4YooesH

Comment thread src/maxtext/configs/types.py Outdated
Collaborator

@JamesDeng42 JamesDeng42 left a comment


Generally LGTM, left one comment.

Comment thread src/maxtext/common/metric_logger.py
@copybara-service copybara-service Bot merged commit 1ffb7d9 into main Apr 16, 2026
88 of 90 checks passed
@copybara-service copybara-service Bot deleted the agagik-loss-fix branch April 16, 2026 20:40