Add muon optimizer #2546

Merged
copybara-service[bot] merged 1 commit into main from shuningjin-opt6 on Dec 22, 2025

Conversation


@shuningjin shuningjin commented Oct 25, 2025

Description

This PR integrates the Muon optimizer into MaxText training. We use the implementation from Optax together with a reshaping interface.

Special notes for the Muon optimizer: (1) Muon is used together with AdamW. (2) Muon applies only to 2D parameters, so 3D/4D parameters must be reshaped to 2D where applicable. A conceptual sketch of this routing follows.
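
This is a simplified sketch, not the MaxText wiring: the actual integration uses optax.contrib.muon (Sec 1), which also excludes embeddings and reshapes 3D/4D parameters via dimension numbers (Sec 3).

# Conceptual sketch only: label 2D kernels for Muon and everything else
# for AdamW. The ndim-based rule is a simplification of the real logic.
import jax
import optax
from optax import contrib

def make_optimizer(abstract_params, learning_rate=5e-4):
  labels = jax.tree_util.tree_map(
      lambda p: "muon" if p.ndim == 2 else "adamw", abstract_params)
  return optax.multi_transform(
      {"muon": contrib.muon(learning_rate=learning_rate),
       "adamw": optax.adamw(learning_rate=learning_rate, weight_decay=0.1)},
      labels)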

Fix: b/437908829

0 Prerequisites

The changes require optax >= 0.2.7. As of 2025-12-18, the latest optax release was 0.2.6 (from 2025-09), so for now we need a manual install from head. (Update: optax 0.2.7 was released on 2026-02-05.)

We need the latest changes in optax.contrib.muon; importantly, it was recently modified to support reshaping via dimension numbers (see google-deepmind/optax#1407 and Sec 3).

# Install the specific commit
pip install git+https://github.com/google-deepmind/optax@9858013795e22958fc2b318fb59f254bf700b10e

or

# uninstall the old version first
pip uninstall optax
# Install the latest 'main' branch 
pip install git+https://github.com/google-deepmind/optax

1 Code change

Basic integration:

  • src/MaxText/configs/base.yml: add the Muon config, reusing the AdamW config where possible
  • src/MaxText/optimizers.py: use optax.contrib.muon

Reshape interface (see Sec 3):

  • src/MaxText/muon_utils.py:
    • We use maxtext.get_abstract_param to get the abstract parameter structure without materializing the weights; this incurs no memory or FLOP overhead (see the sketch after this list).
    • From this structure, get_transform_tree derives the Muon dimension numbers.
    • To review the Muon dimension numbers of a model, run e.g. python3 -m MaxText.muon_utils qwen3-4b True. This is helpful when integrating Muon with more models (see Sec 3.3).
  • Pass the model into optimizers.get_optimizer (train_utils.py, train_compile.py, sft_trainer.py).
  • muon_test.py: unit test to ensure the generated Muon dimension numbers match hardcoded references.
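
For the abstract-structure step above, a minimal sketch of the underlying idea, assuming a Flax-style model with an init function (maxtext.get_abstract_param wraps similar logic):

# jax.eval_shape traces init without allocating or computing anything,
# returning a pytree of jax.ShapeDtypeStruct leaves, hence no memory or
# FLOP overhead. The model/init signature here is illustrative.
import functools
import jax

def get_abstract_params(model, rng, example_inputs):
  init = functools.partial(model.init, rng, example_inputs)
  return jax.eval_shape(init)  # shapes and dtypes only, no weights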

2 User guide

Model-Specific Support: Reshape

  • deepseek2, deepseek3, kimi-k2, qwen3, gemma3, llama2 / llama3 (i.e., DecoderBlockType.DEEPSEEK, DecoderBlockType.QWEN3, DecoderBlockType.GEMMA3, DecoderBlockType.LLAMA2).
  • The reshaping for these models is verified by a unit test against hardcoded references.
  • Other models raise an error.
  • For integrating Muon with more models, see Sec 3.3

Sharding: works with different sharding schemes (FSDP, DP, and TP were tested).

Configs: opt_type=muon; optional: muon_beta=0.95, muon_weight_decay=0.1, muon_consistent_rms=0.2

Pretrain command

BASE_OUTPUT_PATH=gs://runner-maxtext-logs
RUN_NAME=muon-$(date +%Y-%m-%d-%H-%M-%S)
python3 -m MaxText.train MaxText/configs/base.yml \
base_output_directory=$BASE_OUTPUT_PATH run_name=$RUN_NAME \
model_name=gemma3-4b \
tokenizer_type=sentencepiece tokenizer_path=src/MaxText/assets/tokenizer.gemma3 \
dataset_type=tfds dataset_path='gs://mlperf-llm-public2' dataset_name='c4/en:3.0.4' train_split='train2' \
enable_checkpointing=false dtype=bfloat16 weight_dtype=bfloat16 \
opt_type=muon learning_rate=5e-4 adam_weight_decay=0.1 muon_weight_decay=0.1 muon_consistent_rms=0.2 \
per_device_batch_size=16 max_target_length=2048 steps=20 \
ici_fsdp_parallelism=4 ici_data_parallelism=1 ici_tensor_parallelism=1 \
cosine_learning_rate_final_fraction=0.1 warmup_steps_fraction=0.1 learning_rate_schedule_steps=-1 \
override_model_config=true enable_dropout=false \
profiler=xplane skip_first_n_steps_for_profiler=5 profiler_steps=3

Train compile command

BASE_OUTPUT_PATH=gs://runner-maxtext-logs
RUN_NAME=muon-$(date +%Y-%m-%d-%H-%M-%S)
python3 -m MaxText.train_compile MaxText/configs/base.yml \
base_output_directory=$BASE_OUTPUT_PATH run_name=$RUN_NAME \
model_name=gemma3-4b \
tokenizer_type=sentencepiece tokenizer_path=src/MaxText/assets/tokenizer.gemma3 \
dataset_type=tfds dataset_path='gs://mlperf-llm-public2' dataset_name='c4/en:3.0.4' train_split='train2' \
enable_checkpointing=false dtype=bfloat16 weight_dtype=bfloat16 \
opt_type=muon learning_rate=5e-4 adam_weight_decay=0.1 muon_weight_decay=0.1 muon_consistent_rms=0.2 \
per_device_batch_size=16 max_target_length=2048 steps=20 \
ici_fsdp_parallelism=2 ici_data_parallelism=1 ici_tensor_parallelism=2 \
cosine_learning_rate_final_fraction=0.1 warmup_steps_fraction=0.1 learning_rate_schedule_steps=-1 \
override_model_config=true enable_dropout=false \
compile_topology=v5p-8 compile_topology_num_slices=1

3 Reshape Interface

3.1 Why do we need this?

3D and 4D parameters are logically 2D. We use MuonDimensionNumbers (mdn) to specify the reshaping.

  • Note: reduction_dim is the in-feature dimension, output_dim the out-feature dimension, and the remaining dims are batched over. Dims can be negative, e.g., 0 is the 0th dim, -1 the last, -2 the second to last. Dims grouped together are flattened.
  • E.g., decoder.moe_layers.DeepSeekMoeBlock_0.MoeBlock_0.wo, shape (num_experts, num_layer, base_moe_mlp_dim, base_emb_dim): reduction_dim = (-2,), output_dim = (-1,).
  • E.g., decoder.moe_layers.self_attention.out.kernel, shape (base_num_query_heads, num_layer, v_head_dim, base_emb_dim): reduction_dim = (0, -2), output_dim = (-1,).
  • As Muon is designed for 2D matrices, we do not apply it to scalars and vectors (norms, biases). We also do not apply it to the embedding and unembedding, as previous work suggests this is empirically better.

On the Optax side, this is achieved via google-deepmind/optax#1407.
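
For intuition, here is a standalone sketch (not the optax implementation) of what such a spec means operationally:

# Reduction axes are flattened into the "in" dimension, output axes into
# the "out" dimension, and all remaining axes are batched over, turning a
# 3D/4D update into a stack of 2D matrices for orthogonalization.
import numpy as np

def to_batched_2d(update, reduction_dims, output_dims):
  ndim = update.ndim
  red = [a % ndim for a in reduction_dims]
  out = [a % ndim for a in output_dims]
  batch = [a for a in range(ndim) if a not in red + out]
  moved = np.transpose(update, batch + red + out)
  r = int(np.prod([update.shape[a] for a in red]))
  o = int(np.prod([update.shape[a] for a in out]))
  return moved.reshape(-1, r, o)  # orthogonalize each (r, o) slice, then undo

# E.g. the wo example above, (num_experts, num_layer, base_moe_mlp_dim,
# base_emb_dim) with reduction_dim=(-2,), output_dim=(-1,): experts and
# layers stay batch dims, giving (num_experts * num_layer,
# base_moe_mlp_dim, base_emb_dim).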

3.2 How do we integrate it into MaxText?

Determine mdn for parameter

To determine the mdn for a parameter, we can inspect the weight shape. For example:

  • Check the code: self.query = DenseGeneral(in_features_shape=self.config.emb_dim, out_features_shape=(self.num_query_heads, self.qk_head_dim)).
  • Check the parameter: decoder.moe_layers.self_attention.query.kernel. The shape is (base_emb_dim, L, base_num_query_heads, qk_head_dim), where qk_head_dim = qk_nope_head_dim + qk_rope_head_dim.
  • Determine the Muon dimension numbers: in_features is axis 0, out_features are axes -2 and -1. Thus mdn(reduction_axes=(0,), output_axes=(-2, -1)). (Remark: we do not set output_axes=(2, 3) because axis 1 is present only with scan.)

Note that many parameters are shared across models, so we only need to go through this process when there is a new component. The query example is spelled out below as a sanity check.
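
Field names follow the MuonDimensionNumbers repr printed in Sec 3.3; the NamedTuple here merely mirrors the optax class for illustration:

# Illustration only: a stand-in for optax's MuonDimensionNumbers.
from typing import NamedTuple, Tuple

class MuonDimensionNumbers(NamedTuple):
  reduction_axis: Tuple[int, ...]
  output_axis: Tuple[int, ...]

# query.kernel with shape (base_emb_dim, L, base_num_query_heads,
# qk_head_dim): reduce over axis 0, output over the last two axes, batch
# over the scan axis L, giving one (base_emb_dim,
# base_num_query_heads * qk_head_dim) matrix per layer.
query_mdn = MuonDimensionNumbers(reduction_axis=(0,), output_axis=(-2, -1))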

Assemble mdn for model

One way is to hardcode the dimension numbers for each model; see tests.muon_test for an example.

To generalize the reshaping, we take another approach:

  • Given the abstract model params, we extract a static tree of mdn using name-based rules (see the sketch after this list).
  • (Alternatively, we could pass a callable mdn to Muon, but this can add overhead to each update.)
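
A hedged sketch of what such name-based rules could look like; the patterns and helper names are illustrative, not the actual MaxText.muon_utils logic:

# Walk the abstract param tree and assign a dimension-number spec by
# regex-matching the parameter path; None means the parameter falls
# through to AdamW. The rules cover the qwen3-style names printed in
# Sec 3.3 and are not exhaustive.
import re
import jax

RULES = [
    (r"(query|key|value).*kernel", ((0,), (-2, -1))),  # fused head dims
    (r"out.*kernel", ((0, -2), (-1,))),                # heads folded into "in"
    (r"(wi_0|wi_1|wo).*kernel", ((0,), (-1,))),        # plain MLP matrices
    (r"(scale|bias|embedding)", None),                 # AdamW handles these
]

def spec_for(path, _leaf):
  name = jax.tree_util.keystr(path)  # e.g. "['decoder']['layers']['mlp']['wo']['kernel']"
  for pattern, spec in RULES:
    if re.search(pattern, name):
      return spec
  raise ValueError(f"no Muon rule for parameter {name!r}")

def get_transform_tree(abstract_params):
  return jax.tree_util.tree_map_with_path(spec_for, abstract_params)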

3.3 How to accommodate more models

Step 1: check the hardcoded references in tests/muon_test.py to get a sense of the expected output.

Step 2: print the model structure and the automatically derived dimension numbers, then feed the output to Gemini. Example:

# example: model_name=qwen3-4b, scan_layers=True
python3 -m MaxText.muon_utils qwen3-4b True
=== Model Structure ===
{'params': {'decoder': {'decoder_norm': {'scale': {'shape': (2560,), 'names': ('norm',)}}, 'layers': {'mlp': {'wi_0': {'kernel': {'shape': (2560, 36, 9728), 'names': ('embed', 'layers', 'mlp')}}, 'wi_1': {'kernel': {'shape': (2560, 36, 9728), 'names': ('embed', 'layers', 'mlp')}}, 'wo': {'kernel': {'shape': (9728, 36, 2560), 'names': ('mlp', 'layers', 'embed')}}}, 'post_self_attention_layer_norm': {'scale': {'shape': (2560, 36), 'names': ('norm', 'layers')}}, 'pre_self_attention_layer_norm': {'scale': {'shape': (2560, 36), 'names': ('norm', 'layers')}}, 'self_attention': {'key': {'kernel': {'shape': (2560, 36, 8, 128), 'names': ('embed', 'layers', 'kv_heads', 'kv_head_dim')}}, 'key_norm': {'scale': {'shape': (128, 36), 'names': ('norm', 'layers')}}, 'out': {'kernel': {'shape': (32, 36, 128, 2560), 'names': ('heads', 'layers', 'kv', 'embed')}}, 'query': {'kernel': {'shape': (2560, 36, 32, 128), 'names': ('embed', 'layers', 'q_heads', 'kv')}}, 'query_norm': {'scale': {'shape': (128, 36), 'names': ('norm', 'layers')}}, 'value': {'kernel': {'shape': (2560, 36, 8, 128), 'names': ('embed', 'layers', 'kv_heads', 'kv_head_dim')}}}}}, 'token_embedder': {'embedding': {'shape': (151936, 2560), 'names': ('vocab', 'embed')}}}}

=== Muon Dimension Numbers ===
{'params': {'decoder': {'decoder_norm': {'scale': None}, 'layers': {'mlp': {'wi_0': {'kernel': MuonDimensionNumbers(reduction_axis=(0,), output_axis=(-1,))}, 'wi_1': {'kernel': MuonDimensionNumbers(reduction_axis=(0,), output_axis=(-1,))}, 'wo': {'kernel': MuonDimensionNumbers(reduction_axis=(0,), output_axis=(-1,))}}, 'post_self_attention_layer_norm': {'scale': None}, 'pre_self_attention_layer_norm': {'scale': None}, 'self_attention': {'key': {'kernel': MuonDimensionNumbers(reduction_axis=(0,), output_axis=(-2, -1))}, 'key_norm': {'scale': None}, 'out': {'kernel': MuonDimensionNumbers(reduction_axis=(0, -2), output_axis=(-1,))}, 'query': {'kernel': MuonDimensionNumbers(reduction_axis=(0,), output_axis=(-2, -1))}, 'query_norm': {'scale': None}, 'value': {'kernel': MuonDimensionNumbers(reduction_axis=(0,), output_axis=(-2, -1))}}}}, 'token_embedder': {'embedding': None}}}

Is this reasonable?

Step 3: if Gemini says the result is not reasonable, copy-paste the MaxText.muon_utils.transform_logic function and ask it to revise.

# gemini answer for this example
Yes, this configuration looks highly reasonable and correct for a scanned (layer-stacked) Transformer implementation using the Muon optimizer.

The configuration correctly identifies the fan-in (reduction axis) and fan-out (output axis) for the weight matrices, while correctly excluding vector parameters (Norms) and the Embedding table.

Here is the detailed verification of why this works: ...

Step 4: double-check the final Muon dimension numbers and add them to the test (e.g., compare with a similar model from Step 1, or verify components as in Sec 3.2).

Tests

Unit test for reshape: python3 -m pytest -v --pyargs tests.muon_test -rP -s

  • This test can run on CPU.
  • Covers deepseek2, deepseek3, kimi-k2, gemma3, llama2 / llama3, qwen3.

End-to-end test

  • gemma3-4b: b/437908829#comment21

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.


@RissyRan RissyRan left a comment


Thanks for the work and test!

@github-actions

🤖 Hi @shuningjin, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.


@github-actions github-actions Bot left a comment


📋 Review Summary

This PR introduces the muon optimizer from optax.contrib. The integration is well-structured, with the core logic for determining Muon dimension numbers encapsulated in the new src/MaxText/muon_utils.py file. The changes are supported by a new unit test that verifies the dimension number calculations for a range of models.

🔍 General Feedback

  • The rule-based approach in muon_utils.py for determining dimension numbers is reasonable for the initial integration.
  • The inclusion of a command-line utility within muon_utils.py to inspect model structures is a great addition for developers.
  • The updates to the configuration and various training scripts to support the new optimizer are thorough.
  • The new unit tests in tests/muon_test.py provide good coverage for the supported models.

Overall, this is a solid contribution that adds a valuable new optimizer to MaxText. The code is clean, well-documented, and tested.


@gagika gagika left a comment


Thanks

@shuningjin shuningjin force-pushed the shuningjin-opt6 branch 2 times, most recently from fc38dca to 475d3da on December 18, 2025 20:56

@RissyRan RissyRan left a comment


Thanks for the great work!

@shuningjin shuningjin requested a review from jacoguzo as a code owner December 19, 2025 01:59
@shuningjin shuningjin force-pushed the shuningjin-opt6 branch 3 times, most recently from 52f39dc to 043bc21 on December 22, 2025 19:16
@copybara-service copybara-service Bot merged commit 948d302 into main Dec 22, 2025
18 of 19 checks passed
@copybara-service copybara-service Bot deleted the shuningjin-opt6 branch December 22, 2025 23:32