Add muon optimizer #2546
Conversation
RissyRan left a comment:
Thanks for the work and test!
🤖 Hi @shuningjin, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.
📋 Review Summary
This PR introduces the muon optimizer from optax.contrib. The integration is well-structured, with the core logic for determining Muon dimension numbers encapsulated in the new src/MaxText/muon_utils.py file. The changes are supported by a new unit test that verifies the dimension number calculations for a range of models.
🔍 General Feedback
- The use of a rule-based approach in `muon_utils.py` to determine dimension numbers is a reasonable approach for the initial integration.
- The inclusion of a command-line utility within `muon_utils.py` to inspect model structures is a great addition for developers.
- The updates to the configuration and various training scripts to support the new optimizer are thorough.
- The new unit tests in `tests/muon_test.py` provide good coverage for the supported models.
Overall, this is a solid contribution that adds a valuable new optimizer to MaxText. The code is clean, well-documented, and tested.
RissyRan left a comment:
Thanks for the great work!
Description
This PR integrates the Muon optimizer into MaxText training. We use the implementation from Optax, together with a reshaping interface.
Special notes on the Muon optimizer: (1) Muon is used together with AdamW. (2) Muon only applies to 2D parameters, so we need to reshape 3D/4D parameters to 2D when applicable.
Fix: b/437908829
0 Prerequisite
The changes will need optax >= 0.2.7.
- As of 2025-12-18, the only available Optax release is 0.2.6 (from 2025-09); for now, a manual install from head is needed. Optax 0.2.7 is released on 2025-02-05.
- We need the latest changes in `optax.contrib.muon`. Importantly, we recently modified it to have [...] or [...].
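One common way to do the manual install from head (an assumption about your setup, not necessarily the exact command used for this PR) is:

```bash
# Install Optax from GitHub head to pick up the latest optax.contrib.muon changes.
pip install --upgrade git+https://github.com/google-deepmind/optax.git
```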
1 Code change
Basic integration:
- `src/MaxText/configs/base.yml`: add muon config, while reusing adamw config
- `src/MaxText/optimizers.py`: use `optax.contrib.muon`

Reshape interface (see Sec 3):
- `src/MaxText/muon_utils.py`:
  - `maxtext.get_abstract_param` to get the abstract structure without materializing the weights. This does not have memory or FLOP overhead.
  - `get_transform_tree` to get the Muon dimension numbers.
  - Command-line utility: `python3 -m MaxText.muon_utils qwen3-4b True`. This is helpful for integrating Muon with more models (see Sec 3.3).
- `optimizers.get_optimizer` (`train_utils.py`, `train_compile.py`, `sft_trainer.py`)
- `muon_test.py`: unit test to ensure generated Muon dimension numbers match the hardcoded reference

2 User guide
- Model-specific support (reshape): `DecoderBlockType.DEEPSEEK`, `DecoderBlockType.QWEN3`, `DecoderBlockType.GEMMA3`, `DecoderBlockType.LLAMA2`
- Sharding: works with different sharding (e.g., tested FSDP, DP, TP)
- Configs: `opt_type=muon`; optionally `muon_beta=0.95`, `muon_weight_decay=0.1`, `muon_consistent_rms=0.2`
- Pretrain command (an illustrative example is sketched below)
- Train compile command (see the same sketch)
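For illustration, a Muon pretrain run might look like the following; everything other than `opt_type` and the `muon_*` flags is a placeholder for your usual run settings, not a command taken from this PR:

```bash
# Illustrative pretrain command; run_name, model_name, dataset, and paths are placeholders.
python3 -m MaxText.train src/MaxText/configs/base.yml \
  run_name=muon_pretrain model_name=qwen3-4b \
  base_output_directory=gs://<your-bucket>/muon dataset_type=synthetic \
  opt_type=muon muon_beta=0.95 muon_weight_decay=0.1 muon_consistent_rms=0.2

# Train compile follows the same pattern via MaxText.train_compile
# (compile-topology flags omitted here).
python3 -m MaxText.train_compile src/MaxText/configs/base.yml \
  run_name=muon_compile model_name=qwen3-4b opt_type=muon
```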
3 Reshape Interface
3.1 Why do we need this?
3D and 4D parameters are logically 2D. We use the MuonDimensionNumber (mdn) to specify the reshaping.
- `reduction_dim` marks the in-feature dims, `output_dim` marks the out-feature dims, and the remaining dims are batched over. Dims can be negative, e.g., 0 is the 0th dim, -1 is the last dim, -2 is second to last. Dims grouped together are flattened.
- `decoder.moe_layers.DeepSeekMoeBlock_0.MoeBlock_0.wo`, shape `(num_experts, num_layer, base_moe_mlp_dim, base_emb_dim)`: `reduction_dim = (-2,)`, `output_dim = (-1,)`
- `decoder.moe_layers.self_attention.out.kernel`, shape `(base_num_query_heads, num_layer, v_head_dim, base_emb_dim)`: `reduction_dim = (0, -2)`, `output_dim = (-1,)`

On the Optax side, this is achieved via google-deepmind/optax#1407.
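To make the mdn semantics concrete, here is a minimal, runnable sketch of how a `reduction_dim`/`output_dim` pair maps a higher-rank parameter to a (batched) 2D matrix. This is a conceptual illustration only, not the Optax or MaxText implementation:

```python
import numpy as np

def reshape_to_2d(w, reduction_dim, output_dim):
  """Conceptual sketch of the mdn semantics: flatten reduction dims into rows,
  output dims into columns, and batch over the remaining dims."""
  ndim = w.ndim
  red = [d % ndim for d in reduction_dim]                  # in-feature dims
  out = [d % ndim for d in output_dim]                     # out-feature dims
  batch = [d for d in range(ndim) if d not in red + out]   # batched-over dims
  # Move batch dims to the front, then reduction dims, then output dims,
  # and flatten each feature group so Muon can act on the trailing 2D block.
  wt = np.transpose(w, batch + red + out)
  batch_shape = [w.shape[d] for d in batch]
  rows = int(np.prod([w.shape[d] for d in red]))
  cols = int(np.prod([w.shape[d] for d in out]))
  return wt.reshape(batch_shape + [rows, cols])

# wo: (num_experts, num_layer, base_moe_mlp_dim, base_emb_dim),
# reduction_dim=(-2,), output_dim=(-1,) -> already a stack of 2D matrices.
wo = np.zeros((8, 4, 1024, 512))
print(reshape_to_2d(wo, (-2,), (-1,)).shape)            # (8, 4, 1024, 512)

# out.kernel: (base_num_query_heads, num_layer, v_head_dim, base_emb_dim),
# reduction_dim=(0, -2), output_dim=(-1,) -> heads and v_head_dim flatten into rows.
out_kernel = np.zeros((16, 4, 128, 512))
print(reshape_to_2d(out_kernel, (0, -2), (-1,)).shape)  # (4, 2048, 512)
```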
3.2 How do we integrate it into MaxText?
Determine the mdn for a parameter
To determine the mdn for a parameter, we can inspect the weight shape. For example:
- `self.query = DenseGeneral(in_features_shape=self.config.emb_dim, out_features_shape=(self.num_query_heads, self.qk_head_dim), ...)` produces `decoder.moe_layers.self_attention.query.kernel`. The shape is `(base_emb_dim, L, base_num_query_heads, qk_head_dim)`, where `qk_head_dim = qk_nope_head_dim + qk_rope_head_dim`; this would give `reduction_dim = (0,)` and `output_dim = (-2, -1)`, with the layer dim batched over.
- Note that many parameters are shared across models, so we only need to go through this process when there is a new component.
Assemble the mdn for a model
One way is to hardcode the dimension numbers for each model; see the examples in `tests.muon_test` (a sketch of what this looks like is shown below). To generalize the reshaping, we instead choose a rule-based approach: `get_transform_tree` in `MaxText.muon_utils` derives the dimension numbers from the abstract parameter structure.
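As an illustration of the hardcoded alternative, such a reference could look roughly like the following, keyed by parameter path. This is only a sketch: the `MDN` type is a hypothetical stand-in for the Optax dimension-number type, and the two entries are the DeepSeek examples from Sec 3.1, not the full table in `tests/muon_test.py`:

```python
from collections import namedtuple

# Hypothetical stand-in for the dimension-number type; the real type lives in Optax.
MDN = namedtuple("MDN", ["reduction_dim", "output_dim"])

# Hardcoded per-model reference keyed by parameter path; other entries (and the
# None/AdamW entries for embeddings, norms, biases) are omitted.
DEEPSEEK_REFERENCE_MDN = {
    "decoder.moe_layers.DeepSeekMoeBlock_0.MoeBlock_0.wo":
        MDN(reduction_dim=(-2,), output_dim=(-1,)),
    "decoder.moe_layers.self_attention.out.kernel":
        MDN(reduction_dim=(0, -2), output_dim=(-1,)),
}
```

Maintaining a table like this for every model is exactly what the rule-based `get_transform_tree` avoids.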
3.3 How to accommodate more models
Step 1: Check the hard-coded references in `tests/muon_test.py` to get a sense of the expected dimension numbers.
Step 2: Print the model structure and the automatically derived dimension numbers, and feed the output to Gemini. Example: `python3 -m MaxText.muon_utils qwen3-4b True` (as in Sec 1).
Step 3: If Gemini says the result is not reasonable, copy-paste the `MaxText.muon_utils.transform_logic` function and ask it to revise.
Step 4: Double-check the final Muon dimension numbers and add them to the test (e.g., compare with a similar model from Step 1, or double-check individual components as in Sec 3.2).
Tests
- Unit test for reshape: `python3 -m pytest -v --pyargs tests.muon_test -rP -s`
- End-to-end test
Checklist
Before submitting this PR, please make sure (put X in square brackets):
`gemini-review` label.