1 change: 1 addition & 0 deletions README.md
@@ -41,6 +41,7 @@ See our guide on running MaxText in decoupled mode, without any GCP dependencies

## 🔥 Latest news 🔥

* \[December 22, 2025\] [Muon optimizer](https://kellerjordan.github.io/posts/muon) is now supported.
* \[December 10, 2025\] DeepSeek V3.1 is now supported. Use the existing configs for [DeepSeek V3 671B](https://github.com/AI-Hypercomputer/maxtext/blob/main/src/MaxText/configs/models/deepseek3-671b.yml) and load a V3.1 checkpoint to use the model.
* \[December 9, 2025\] [New RL and SFT Notebook tutorials](https://github.com/AI-Hypercomputer/maxtext/tree/main/src/MaxText/examples) are available.
* \[December 4, 2025\] The [ReadTheDocs documentation site](https://maxtext.readthedocs.io/en/latest/index.html) has been reorganized.
10 changes: 9 additions & 1 deletion src/MaxText/configs/base.yml
@@ -704,7 +704,7 @@ gradient_clipping_threshold: 1.0
# batch by accumulating the gradient over a set of steps.
gradient_accumulation_steps: 1

opt_type: "adamw" # one of "adamw", "adam_pax" or "sgd"
opt_type: "adamw" # one of "adamw", "adam_pax", "sgd", or "muon"

# AdamW optimizer parameters
# We use AdamW following Llama2's training details, see https://arxiv.org/pdf/2307.09288.pdf section 2.2
@@ -717,6 +717,14 @@ mu_dtype: "" # data type to store "mu" of AdamW tracking the first moment. Inher
# Setting nu_dtype is not yet supported by optax, instead nu_dtype is always inherited from weights.
# See b/399961932 for more.

# Muon optimizer parameters
# https://github.com/google-deepmind/optax/blob/main/optax/contrib/_muon.py
# "mu_dtype", "adam_eps" are shared by AdamW
# "nesterov", "ns_coeffs", "ns_steps", "weight_decay_mask", "adaptive" use default
muon_beta: 0.95 # Decay rate for the exponentially weighted average of grads.
muon_weight_decay: 0 # Strength of the weight decay regularization. This is multiplied with the learning rate.
muon_consistent_rms: None # If None, apply width scaling to updates. If float, apply consistent rms scaling (recommend 0.2).

# Stack trace parameters
collect_stack_trace: False
stack_trace_to_cloud: False # Uploads to cloud logging if True, else to the console if False.
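As context for the options added above, here is a minimal sketch (not part of this PR) of building a config with the new Muon settings through `pyconfig`, mirroring the `key=value` argv pattern that `muon_utils.py` in this PR already uses. The model name and the 0.2 value are illustrative choices.

```python
# Minimal sketch: enable Muon via pyconfig overrides (model name is illustrative).
import os

from MaxText import pyconfig
from MaxText.globals import MAXTEXT_PKG_DIR

config = pyconfig.initialize([
    None,
    os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml"),
    "model_name=qwen3-4b",
    "opt_type=muon",
    "muon_beta=0.95",
    "muon_weight_decay=0.0",
    "muon_consistent_rms=0.2",  # parsed to float by the pyconfig change in this PR
])
print(config.opt_type, config.muon_consistent_rms)
```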
24 changes: 24 additions & 0 deletions src/MaxText/configs/types.py
@@ -121,6 +121,7 @@ class OptimizerType(str, Enum):
ADAMW = "adamw"
ADAM_PAX = "adam_pax"
SGD = "sgd"
MUON = "muon"


class RopeType(str, Enum):
@@ -1040,6 +1041,18 @@ class AdamW(BaseModel):
)


class Muon(BaseModel):
"""Configuration specific to the Muon optimizer."""

muon_beta: float = Field(0.95, description="Decay rate for the exponentially weighted average of grads.")
muon_weight_decay: float = Field(
0, description="Strength of the weight decay regularization. This is multiplied with the learning rate."
)
muon_consistent_rms: None | float = Field(
None, description="If None, apply width scaling to updates. If float, apply consistent rms scaling (recommend 0.2)."
)


class PositionalEmbedding(BaseModel):
"""General configuration for positional embeddings."""

@@ -1617,6 +1630,7 @@ class MaxTextConfig(
TrainingLoop,
Optimizer,
AdamW,
Muon,
FineTuning,
# Reinforcement Learning
RLHardware,
@@ -2118,6 +2132,16 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
self.use_grpo = True
else:
self.use_grpo = False
if self.opt_type == "muon" and self.decoder_block not in [
DecoderBlockType.DEEPSEEK,
DecoderBlockType.QWEN3,
DecoderBlockType.GEMMA3,
DecoderBlockType.LLAMA2,
]:
raise ValueError(
"Muon dimension numbers haven't been tested for this model. Run this command first: "
f"`python3 -m MaxText.muon_utils {self.model_name} True`"
)

# I. FINAL TYPE CONVERSIONS AND DERIVED LISTS
# Create the ici_parallelism and dcn_parallelism lists for legacy compatibility.
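If useful, the new `Muon` settings class added above can be instantiated on its own to confirm its defaults (a small sketch, not part of the PR; `MaxTextConfig` is what actually composes it together with `AdamW` and the other mixins).

```python
# Quick sanity check of the Muon defaults declared above.
from MaxText.configs.types import Muon

m = Muon()
print(m.muon_beta, m.muon_weight_decay, m.muon_consistent_rms)  # expected: 0.95 0.0 None
print(Muon(muon_consistent_rms=0.2).muon_consistent_rms)        # expected: 0.2
```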
25 changes: 13 additions & 12 deletions src/MaxText/maxtext_utils.py
@@ -752,18 +752,19 @@ def init_initial_state(model, tx, config, is_training, key):

def get_abstract_param(model, config):
"""Get abstract model structure (name, shape) without materializing the weights to save memory"""
key = jax.random.PRNGKey(0)
input_shape = (config.micro_batch_size_to_train_on, config.max_target_length)
image_shape = multimodal_utils.get_dummy_image_shape_for_init(
config.model_name, batch_size=config.micro_batch_size_to_train_on
)
abstract_vars = jax.eval_shape(
model.init,
{"params": key, "dropout": key, "aqt": key},
jnp.ones(input_shape, dtype=jnp.int32),
jnp.ones(input_shape, dtype=jnp.int32),
encoder_images=np.ones(image_shape, dtype=jnp.int32) if config.use_multimodal else None,
)
with model.mesh, nn_partitioning.axis_rules(config.logical_axis_rules):
key = jax.random.PRNGKey(0)
input_shape = (config.micro_batch_size_to_train_on, config.max_target_length)
image_shape = multimodal_utils.get_dummy_image_shape_for_init(
config.model_name, batch_size=config.micro_batch_size_to_train_on
)
abstract_vars = jax.eval_shape(
model.init,
{"params": key, "dropout": key, "aqt": key},
jnp.ones(input_shape, dtype=jnp.int32),
jnp.ones(input_shape, dtype=jnp.int32),
encoder_images=np.ones(image_shape, dtype=jnp.int32) if config.use_multimodal else None,
)
return abstract_vars


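The key idea in `get_abstract_param` is `jax.eval_shape`, which traces `model.init` and returns shapes and dtypes without allocating parameter memory. A toy, self-contained illustration of that pattern (the names below are invented for the example):

```python
# Toy illustration of the jax.eval_shape pattern: shapes are computed
# abstractly, so no parameter memory is materialized.
import jax
import jax.numpy as jnp


def init_fn(tokens):
  # Stand-in for model.init: returns a parameter "tree".
  return {"kernel": jnp.zeros((tokens.shape[-1], 8), dtype=jnp.float32)}


abstract = jax.eval_shape(init_fn, jnp.ones((4, 16), dtype=jnp.int32))
print(abstract)  # {'kernel': ShapeDtypeStruct(shape=(16, 8), dtype=float32)}
```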
174 changes: 174 additions & 0 deletions src/MaxText/muon_utils.py
@@ -0,0 +1,174 @@
# Copyright 2023–2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""Utilities for Muon optimizer integration and dimension number generation.

This module provides functions to automatically generate MuonDimensionNumbers
for various MaxText models. These dimension numbers are crucial for the Muon
optimizer to correctly apply its update rules.

This module can also be run as a script to inspect the generated dimension
numbers for a specific model. Example:
python3 -m MaxText.muon_utils qwen3-4b True
"""


import os
import sys
from typing import Optional, Tuple

import flax.linen as nn
import jax
from optax.contrib._muon import MuonDimensionNumbers as mdn

from MaxText import maxtext_utils, pyconfig
from MaxText.globals import MAXTEXT_PKG_DIR
from MaxText.layers import models, quantizations


Transformer = models.transformer_as_linen


def _is_path_contain_any(tuples, path):
return any(x in path for x in tuples)


def transform_logic(path: Tuple[str, ...]) -> Optional[mdn]:
"""
Determines Muon dimension numbers based on the parameter's hierarchical path.

This function defines the mapping from a parameter's logical path within the model
to its corresponding MuonDimensionNumbers (mdn). The strategy is applied in
a specific order to handle general cases and then more specific ones, allowing
for fall-through logic in nested structures.

Strategy:
1. Exclusions: Parameters not suitable for Muon (e.g., scalars, embeddings,
unembedding) are explicitly returned as `None`.
2. Special Weights:
2.1 MoE Block Specific Weights
2.2 Self-Attention Specific Weights
3. Standard Weights: Default mapping for most other 3D weight shapes.

Args:
path: A tuple of strings representing the hierarchical path of the parameter.

Returns:
An instance of `MuonDimensionNumbers` if a specific mapping is found,
`None` for excluded parameters, or a default `mdn` for standard weights.
"""

# 1 Exclude parameters not suitable for Muon (scalar, embeddings, unembedding)
if _is_path_contain_any(("scale", "bias", "embedding", "logits_dense"), path):
return None

# 2 Special weights
# 2.1 Special weights: MoE, [0, L, -2, -1]
# L (optional) stands for layer when scan_layers=True
if "MoeBlock_0" in path:
# exclude gate
if _is_path_contain_any(("wi_0", "wi_1", "wo"), path):
return mdn((-2,), (-1,))

# 2.2 Special weights: Self attention
elif "self_attention" in path:
# Attention output projection: [0, L, -2, -1]
if "out" in path:
return mdn((0, -2), (-1,))
# Attention qkv projection: [0, L, -2, -1]
# MLA, exclude wq_a / wkv_a
elif _is_path_contain_any(("query", "key", "value", "wq_b", "wkv_b"), path):
return mdn((0,), (-2, -1))

# 3 Standard weights, [0, L, -1]
return mdn((0,), (-1,))


def get_transform_tree(tree, path=()):
"""Extraction utility via recursion."""
if isinstance(tree, dict):
return {k: get_transform_tree(v, path + (k,)) for k, v in tree.items()}
else:
return transform_logic(path)


def get_muon_weight_dimension_numbers(model, config, verbose=False):
"""Extract muon dimension number from model structure."""
# quickly get param structure without materialization
abstract_param = maxtext_utils.get_abstract_param(model, config)
# get muon dimension number from param
muon_weight_dimension_numbers = get_transform_tree(abstract_param)
if verbose:
_print_structure_debug(abstract_param, muon_weight_dimension_numbers)
return muon_weight_dimension_numbers


def _print_structure_debug(abstract_param, muon_weight_dimension_numbers):
"""Prints the model structure and the resulting Muon config."""
# Access the shape from the inner ShapeDtypeStruct and names from the wrapper
# Return a new tree with the same structure containing only shapes/names
info_tree = jax.tree_util.tree_map(
lambda leaf: {"shape": leaf.value.shape, "names": leaf.names},
abstract_param,
is_leaf=lambda x: isinstance(x, nn.LogicallyPartitioned),
)
print(f"\n=== Model Structure ===\n{info_tree}")
print(f"\n=== Muon Dimension Numbers ===\n{muon_weight_dimension_numbers}")
print("\nIs this reasonable?")


def get_model_mdn(model_name, scan_layers=True, verbose=False):
"""Initializes a model and retrieves its Muon dimension numbers.

This function sets up the configuration for a given model, initializes the
transformer model, and then extracts the Muon dimension numbers for the model's
weights. It can optionally print verbose debug information.

Args:
model_name: The name of the model to be initialized.
scan_layers: Whether to use layer scanning in the model configuration.
verbose: If True, prints detailed debugging information about the model
structure and Muon dimension numbers.

Returns:
A tree structure containing the Muon dimension numbers for the model's
parameters.
"""
# Setup config
argv = [
None,
os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml"),
f"model_name={model_name}",
f"scan_layers={scan_layers}",
"attention=dot_product",
]
config = pyconfig.initialize(argv)
# Setup model
devices_array = maxtext_utils.create_device_mesh(config)
mesh = jax.sharding.Mesh(devices_array, config.mesh_axes)
quant = quantizations.configure_quantization(config)
model = Transformer(config, mesh=mesh, quant=quant)
# Get dimension number
muon_weight_dimension_numbers = get_muon_weight_dimension_numbers(model, config, verbose=verbose)
return muon_weight_dimension_numbers


if __name__ == "__main__":
if len(sys.argv) != 3:
print("Usage: python3 -m MaxText.muon_utils <model_name> <scan_layers:True/False>")
sys.exit(1)
model_name_arg = sys.argv[1]
scan_layers_arg = sys.argv[2].lower() == "true"
get_model_mdn(model_name_arg, scan_layers_arg, verbose=True)
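A hedged example of how `transform_logic` behaves on a few parameter paths. The paths below are hypothetical and only exercise the exclusion, MoE, attention, and standard branches; real MaxText parameter paths may differ. Parameters that map to `None` are excluded from Muon, which is why the AdamW hyperparameters are also forwarded to the optimizer in `optimizers.py`.

```python
# Illustrative only: hypothetical parameter paths exercising each branch.
from MaxText.muon_utils import transform_logic

# Excluded (scalars / embeddings / unembedding) -> None
print(transform_logic(("params", "token_embedder", "embedding")))

# MoE expert weights -> MuonDimensionNumbers((-2,), (-1,))
print(transform_logic(("params", "decoder", "MoeBlock_0", "wi_0", "kernel")))

# Attention q/k/v projections -> MuonDimensionNumbers((0,), (-2, -1))
print(transform_logic(("params", "decoder", "self_attention", "query", "kernel")))

# Everything else falls through to the standard mapping -> MuonDimensionNumbers((0,), (-1,))
print(transform_logic(("params", "decoder", "mlp", "wi_0", "kernel")))
```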
27 changes: 26 additions & 1 deletion src/MaxText/optimizers.py
@@ -19,9 +19,11 @@
import jax.numpy as jnp

import optax
from optax.contrib._muon import muon
from MaxText.muon_utils import get_muon_weight_dimension_numbers


def get_optimizer(config, learning_rate_schedule):
def get_optimizer(config, learning_rate_schedule, model=None):
"""Create optimizer."""
if config.opt_type == "adamw":
# Create AdamW Optimizer following Llama2's training details, see https://arxiv.org/pdf/2307.09288.pdf section 2.2
@@ -45,6 +47,29 @@ def get_optimizer(config, learning_rate_schedule):
)
elif config.opt_type == "sgd":
return optax.sgd(learning_rate_schedule)
elif config.opt_type == "muon":
# extract muon dimension number from model structure
if model is not None:
muon_weight_dimension_numbers = get_muon_weight_dimension_numbers(model, config)
else:
raise ValueError("Please specify model to extract muon dimension number.")
muon_kwargs = {
# Shared parameters: "nesterov" uses default
"learning_rate": learning_rate_schedule,
"eps": config.adam_eps,
"mu_dtype": config.mu_dtype,
# Muon-specific parameters: "ns_coeffs", "ns_steps", "weight_decay_mask", "adaptive" use their defaults
"beta": config.muon_beta,
"weight_decay": config.muon_weight_decay,
"muon_weight_dimension_numbers": muon_weight_dimension_numbers,
"consistent_rms": config.muon_consistent_rms,
# AdamW-specific parameters
"adam_b1": config.adam_b1,
"adam_b2": config.adam_b2,
"adam_eps_root": config.adam_eps_root,
"adam_weight_decay": config.adam_weight_decay,
}
return muon(**muon_kwargs)
else:
raise ValueError(f"{config.opt_type=} is not a supported.")

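For readers unfamiliar with the optax side, a toy, self-contained sketch of the `muon` transformation as it is invoked above, using a hand-written dimension-number tree for a single 2D weight. This is not from the PR: the real tree comes from `get_muon_weight_dimension_numbers`, and entries set to `None` appear to be left to muon's internal AdamW path, which is why the `adam_*` hyperparameters are forwarded above.

```python
# Toy sketch: calling optax's muon directly with a hand-written
# dimension-number tree. "dense/kernel" gets the Muon update;
# "dense/bias" (None) is excluded from the Muon branch.
import jax
import jax.numpy as jnp
from optax.contrib._muon import muon, MuonDimensionNumbers

params = {"dense": {"kernel": jnp.ones((16, 8)), "bias": jnp.zeros((8,))}}
dim_numbers = {"dense": {"kernel": MuonDimensionNumbers((0,), (-1,)), "bias": None}}

tx = muon(learning_rate=1e-3, muon_weight_dimension_numbers=dim_numbers)
state = tx.init(params)
grads = jax.tree_util.tree_map(jnp.ones_like, params)
updates, state = tx.update(grads, state, params)
params = jax.tree_util.tree_map(lambda p, u: p + u, params, updates)
```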
19 changes: 12 additions & 7 deletions src/MaxText/pyconfig.py
@@ -132,6 +132,16 @@ def _prepare_for_pydantic(raw_keys: dict[str, Any]) -> dict[str, Any]:
if key == "run_name" and new_value is None:
new_value = ""

# Preprocess muon_consistent_rms to be None or float
if key == "muon_consistent_rms":
if value in ["None", "none"]:
new_value = None
else:
try:
new_value = float(value)
except ValueError as e:
raise ValueError("muon_consistent_rms should be None or float") from e

pydantic_kwargs[key] = new_value

return pydantic_kwargs
@@ -293,13 +303,8 @@ def initialize_pydantic(argv: list[str], **kwargs) -> MaxTextConfig:

pydantic_kwargs = _prepare_for_pydantic(raw_keys_dict)

if pydantic_kwargs.get("use_tokamax_splash") and pydantic_kwargs.get(
"use_jax_splash"
):
raise ValueError(
"At most one of `use_tokamax_splash` and `use_jax_splash` can be set to"
" True."
)
if pydantic_kwargs.get("use_tokamax_splash") and pydantic_kwargs.get("use_jax_splash"):
raise ValueError("At most one of `use_tokamax_splash` and `use_jax_splash` can be set to True.")

# Initialize JAX distributed system before device backend is initialized.
if pydantic_kwargs.get("jax_debug_log_modules"):
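The preprocessing above is what lets `muon_consistent_rms` be passed either as the string `None`/`none` or as a numeric string on the command line. A tiny, hypothetical stand-alone version of that parsing, for illustration only:

```python
# Hypothetical helper mirroring the preprocessing above; not part of the PR.
def parse_consistent_rms(value: str):
  if value in ["None", "none"]:
    return None
  try:
    return float(value)
  except ValueError as e:
    raise ValueError("muon_consistent_rms should be None or float") from e


assert parse_consistent_rms("None") is None
assert parse_consistent_rms("0.2") == 0.2
```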
3 changes: 2 additions & 1 deletion src/MaxText/sft/sft_trainer.py
@@ -144,7 +144,8 @@ def setup_trainer_state(mt_config, goodput_recorder=None):
with maybe_record_goodput(goodput_recorder, GoodputEvent.TPU_INIT):
model, mesh = model_creation_utils.create_nnx_model(mt_config)
learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(mt_config)
optimizer = optimizers.get_optimizer(mt_config, learning_rate_schedule)
# pass in model for muon
optimizer = optimizers.get_optimizer(mt_config, learning_rate_schedule, model)

with maybe_record_goodput(goodput_recorder, GoodputEvent.TRAINING_PREPARATION):
training_hooks = hooks.SFTTrainingHooks(mt_config, mesh, learning_rate_schedule, goodput_recorder)
3 changes: 2 additions & 1 deletion src/MaxText/train_compile.py
@@ -89,7 +89,8 @@ def get_shaped_inputs(topology_mesh, config):
model = Transformer(config, topology_mesh, quant=quant, model_mode=MODEL_MODE_TRAIN)
# The learning_rate_schedule is baked into the compiled object.
learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(config)
tx = optimizers.get_optimizer(config, learning_rate_schedule)
# pass in model for muon
tx = optimizers.get_optimizer(config, learning_rate_schedule, model)

# Shaped RNG keys
_, example_rng = jax.random.split(jax.random.PRNGKey(0), 2)
3 changes: 2 additions & 1 deletion src/MaxText/train_utils.py
@@ -35,7 +35,8 @@ def create_training_tools(config, model, mesh):
"""Creates the init_rng, optimizer, learning rate schedule, and checkpoint manager."""
init_rng = jax.random.PRNGKey(config.init_weights_seed)
learning_rate_schedule = maxtext_utils.create_learning_rate_schedule(config)
tx = optimizers.get_optimizer(config, learning_rate_schedule)
# pass in model for muon
tx = optimizers.get_optimizer(config, learning_rate_schedule, model)
logger = checkpointing.setup_checkpoint_logger(config)
if config.enable_multi_tier_checkpointing:
checkpoint_manager = checkpointing.create_orbax_emergency_replicator_checkpoint_manager(