Merged
5 changes: 0 additions & 5 deletions docs/reference/core_concepts/moe_configuration.md
@@ -96,11 +96,6 @@ Dropping:

## 2. Sharding

`expert_shard_attention_option`: Determines how the "expert" axis is interpreted when sharding attention layers. Options include:

- `fsdp`: Treats the expert axis as a FSDP axis.
- `context`: Treats the expert axis as a context parallelism axis, useful for long context.

`use_ring_of_experts` (experimental): This feature requires expert parallelism. If enabled, it replaces the standard two All-to-All communications with an All-Gather in dispatch and a Reduce-Scatter in collect. Gathering inputs across all shards allows routing and Top-K calculations to run locally, with results aggregated via Reduce-Scatter. This approach is particularly effective for models with a large Top-K, since activations are gathered before they are replicated k times, reducing communication volume.
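The dispatch/collect pattern described here can be sketched with a single-host NumPy simulation. This is a hedged illustration only: the shard counts, the toy expert function, and all names are made up and are not MaxText's API. All-Gather is modeled as concatenation across shards, and Reduce-Scatter as a sum of partial outputs followed by a split.

```python
import numpy as np

# Illustrative shapes only.
num_shards, tokens_per_shard, hidden = 2, 4, 8
rng = np.random.default_rng(0)
# Per-shard activations before dispatch.
local_inputs = [rng.standard_normal((tokens_per_shard, hidden)) for _ in range(num_shards)]

# Dispatch: All-Gather replaces the first All-to-All -- every shard sees all tokens.
gathered = np.concatenate(local_inputs, axis=0)  # (num_shards * tokens_per_shard, hidden)

# Each shard now routes and computes locally (here: a trivial stand-in "expert"
# that scales its copy; real routing/Top-K would happen at this step).
def local_expert(shard_idx, x):
    return (shard_idx + 1.0) * x

partial_outputs = [local_expert(i, gathered) for i in range(num_shards)]

# Collect: Reduce-Scatter replaces the second All-to-All -- sum the partials,
# then each shard keeps only its own token slice.
summed = np.sum(partial_outputs, axis=0)
collected = np.split(summed, num_shards, axis=0)
```

Because the gather happens before any k-way replication, the communicated volume scales with the raw token count rather than k times it, which is the stated benefit for large Top-K.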

`moe_fsdp_use_two_stage_all_gather`: If enabled, splits the All-Gather operation for MoE weights into two separate stages when using FSDP/FSDP-transpose sharding. This is preferred when 3D All-Gather support is unavailable.
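As a hedged sketch of what the two stages do (pure NumPy, all names illustrative): a weight sharded over a (fsdp, fsdp_transpose) grid can be reassembled by gathering one axis at a time, which yields the same result as a single combined gather.

```python
import numpy as np

# Illustrative only: a 4x4 "weight" sharded over a 2x2 (fsdp, fsdp_transpose) grid.
fsdp, fsdp_t = 2, 2
full = np.arange(16.0).reshape(4, 4)
shards = [[full[i * 2:(i + 1) * 2, j * 2:(j + 1) * 2] for j in range(fsdp_t)]
          for i in range(fsdp)]

# Stage 1: All-Gather along the fsdp_transpose axis (modeled as column concat).
stage1 = [np.concatenate(shards[i], axis=1) for i in range(fsdp)]
# Stage 2: All-Gather along the fsdp axis (modeled as row concat).
reassembled = np.concatenate(stage1, axis=0)
```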
4 changes: 0 additions & 4 deletions src/maxtext/common/common_types.py
@@ -66,10 +66,6 @@
MODEL_MODE_PREFILL = "prefill"
MODEL_MODE_TRAIN = "train"

# expert_shard_attention_option
EP_AS_CONTEXT = "context"
EP_AS_FSDP = "fsdp"

DECODING_ACTIVE_SEQUENCE_INDICATOR = 1

# A large negative mask value is used for masking to ensure that the
9 changes: 4 additions & 5 deletions src/maxtext/configs/base.yml
@@ -237,11 +237,6 @@ merge_gating_gmm: False

norm_topk_prob: false # Boolean to enable top-k probability normalization (Qwen3-specific normalization of router weights).

# how the expert axis is used to shard attention weights and activations
# "fsdp" (ep acts as fsdp parallelism)
# "context" (ep acts as context parallelism, training only)
expert_shard_attention_option: "fsdp"

# when moe weight matrices are sharded on both fsdp and fsdp-transpose axes, use two separate all-gather calls
moe_fsdp_use_two_stage_all_gather: false
# Shard the expert dimension of the MLP weights on the FSDP axis.
@@ -521,6 +516,7 @@ logical_axis_rules: [
# ==========================================
# Dense Activations
['activation_mlp', ['tensor', 'tensor_transpose', 'tensor_sequence']],
# Note: activation_batch and activation_length are also used in attention and vocab
['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
['activation_length', ['sequence', 'context']],
['activation_length', ['context']],
@@ -569,6 +565,9 @@ logical_axis_rules: [
# Axes used for DCN must be earlier in this list than ICI, see (b/339009148) for details
data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']]
input_data_sharding_logical_axes: ['activation_embed_and_logits_batch', 'activation_norm_length']
# Determines which physical axis plays the role of context parallelism for input data processing and load balancing.
# Only supports "context" or "expert" (the latter when custom_mesh_and_rule=ep-as-cp).
context_sharding: "context"
Collaborator:

What values can this take? Can we remove this as a field, since it's implied by the logical axis rules? E.g. we need a function that takes the rules as input and outputs the value of context_sharding.

Collaborator:

Could we list the other options, if any?

Collaborator Author:

It is hard to infer which physical axis is used for CP from reading the logical rules. For example, both sequence and context are used to shard activation_length, but only context is used for data processing.

I will add comments indicating the possible values of context_sharding and add checks.

Collaborator Author:

> Could we list other options? if any

Done!
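The derivation discussed in this thread (looking up the ici/dcn parallelism degrees by the configured axis name) can be sketched as follows. The helper name and the SimpleNamespace stand-in config are hypothetical; the real logic lives in MaxText's config validation.

```python
from types import SimpleNamespace

def derive_context_parallel_size(config):
    """Hypothetical helper mirroring the PR's getattr-based derivation."""
    axis = config.context_sharding  # expected: "context" or "expert"
    if axis not in ("context", "expert"):
        raise ValueError(f"Assigned context_sharding {axis} is not supported.")
    # Unset parallelism fields default to a degree of 1.
    return getattr(config, f"ici_{axis}_parallelism", 1) * getattr(
        config, f"dcn_{axis}_parallelism", 1
    )

# With ep-as-cp, the expert axis degree becomes the context-parallel size.
cfg = SimpleNamespace(context_sharding="expert",
                      ici_expert_parallelism=2, dcn_expert_parallelism=1)
```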


# sharding tolerance: float between 0.0 and 1.0 representing the allowed percentage of non-sharded parameters.
sharding_tolerance: 0.02
77 changes: 77 additions & 0 deletions src/maxtext/configs/custom_mesh_and_rule/ep-as-cp.yml
@@ -0,0 +1,77 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This rule set uses the data, stage, fsdp, and expert axes. The expert axis acts as context
# parallelism in all components except the core dMoE part (between the EP all-to-alls).
mesh_axes: ['data', 'stage', 'fsdp', 'expert']
data_sharding: [['data', 'stage', 'fsdp', 'expert']]
context_sharding: 'expert'
logical_axis_rules: [
# ==========================================
# Vocabulary Embedding
# ==========================================
# Vocab Activations
['activation_embed_and_logits_batch', ['data', 'stage', 'fsdp']],
['activation_embed_and_logits_batch_sequence', ['data', 'stage', 'fsdp', 'expert']],
# Vocab Weights
['vocab', []],
['embed_vocab', ['fsdp', 'expert']],
# ==========================================
# Attention
# ==========================================
# Attention Activations
['activation_heads', []],
['activation_kv_heads', []],
['activation_attn_length', ['expert']],
['activation_q_length', ['expert']],
['activation_kv_length', []],
['activation_attn_embed', []],
['activation_kv', []],
['activation_kv_batch', ['data', 'fsdp']],
['activation_kv_head_dim', []],
# Attention Weights
['heads', []],
['q_heads', []],
['kv_heads', []],
['qkv', []],
['kv', []],
['kv_head_dim', []],
['q_lora', ['fsdp']],
["q_lora_up_proj", []],
['kv_lora', ['fsdp']],
["kv_lora_up_proj", []],
# ==========================================
# Mixture of Experts (MoE)
# ==========================================
# MoE Activations
['activation_batch_moe', ['data', 'fsdp']],
['activation_exp', ['expert']],
# MoE Weights
['exp', 'expert'],
['embed_moe', ['fsdp']],
# ==========================================
# Standard MLP / Dense Layers / Model Structure
# ==========================================
# Dense Activations
['activation_mlp', []],
['activation_batch', ['data', 'fsdp']],
Collaborator (@gobbleturk, Apr 15, 2026):

activation_batch is also used for attention and is a key dimension. Maybe note this as a comment in the attention section above.

Collaborator Author:

Good point! Added comments in base.yml.

['activation_length', ['expert']],
['activation_norm_length', ['expert']],
['activation_embed', []],
['activation_stage', 'stage'],
# General Weights
['mlp', []],
['layers', 'stage'],
['embed', ['fsdp', 'expert']],
]
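Under these rules, sequence-like logical axes resolve to the physical expert axis. A toy first-match resolver (names only illustrative of how flax-style logical_axis_rules are consulted; this is not MaxText's actual lookup code) makes the mapping concrete:

```python
# A small subset of the rules above, as (logical_axis, mesh_axes) pairs.
rules = [
    ("activation_attn_length", ["expert"]),
    ("activation_length", ["expert"]),
    ("activation_batch", ["data", "fsdp"]),
    ("mlp", []),
]

def logical_to_mesh(logical_axis, rules):
    # First matching rule wins, mirroring flax-style rule resolution.
    for name, mesh_axes in rules:
        if name == logical_axis:
            return mesh_axes
    return []  # unmatched axes stay unsharded
```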
16 changes: 8 additions & 8 deletions src/maxtext/configs/types.py
@@ -661,10 +661,6 @@ class MoEGeneral(BaseModel):
)
use_random_routing: bool = Field(False, description="Whether to use random routing for debugging.")
interleave_moe_layer_step: int = Field(1, description="Frequency of MoE layers, e.g., 2 means every 2nd layer is MoE.")
expert_shard_attention_option: Literal["fsdp", "context"] = Field(
"fsdp",
description="How the expert axis is used to shard attention weights and activations.",
)
moe_fsdp_use_two_stage_all_gather: bool = Field(
False,
description="Use two separate All-Gather calls for MoE weights sharded on both FSDP and FSDP-transpose.",
@@ -842,6 +838,7 @@ class LayoutAndSharding(BaseModel):

logical_axis_rules: Any = Field([], description="Rules for mapping logical axes to physical mesh axes.")
data_sharding: Any = Field([], description="Sharding for input data.")
context_sharding: str = Field("context", description="Physical axis name for context parallelism.")
input_data_sharding_logical_axes: list[str] = Field(
["activation_embed_and_logits_batch", "activation_norm_length"],
description="Logical axes for sharding input data.",
@@ -2116,6 +2113,8 @@ def set_derived_and_validate_values(self) -> "MaxTextConfig":
self.logical_axis_rules = custom_mesh_config["logical_axis_rules"]
if "data_sharding" in custom_mesh_config:
self.data_sharding = custom_mesh_config["data_sharding"]
if "context_sharding" in custom_mesh_config:
self.context_sharding = custom_mesh_config["context_sharding"]
else:
raise NotImplementedError(f"Custom mesh config file not found at {custom_mesh_path}")

@@ -2398,10 +2397,9 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
self.tensors_on_device = [t for t in tensors if getattr(self, t) == "device"]
self.tensors_to_offload = [t for t in tensors if getattr(self, t) == "offload"]

cp_size = self.ici_context_parallelism * self.dcn_context_parallelism
if self.expert_shard_attention_option == "context":
cp_size *= self.ici_expert_parallelism * self.dcn_expert_parallelism
self.context_parallel_size = cp_size
self.context_parallel_size = getattr(self, f"ici_{self.context_sharding}_parallelism", 1) * getattr(
self, f"dcn_{self.context_sharding}_parallelism", 1
)
if self.pipeline_parallel_layers == -1:
if self.decoder_block == DecoderBlockType.DEEPSEEK:
moe_layers = self.num_decoder_layers - self.first_num_dense_layers
@@ -2603,6 +2601,8 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
)
if self.quantization:
raise ValueError("Quantization is not supported with 'explicit' sharding.")
if self.context_sharding not in ("context", "expert"):
raise ValueError(f"Assigned context_sharding {self.context_sharding} is not supported.")
if (
self.per_device_batch_size > 0
and (self.per_device_batch_size * self.max_target_length) % self.num_vocab_tiling != 0
9 changes: 1 addition & 8 deletions src/maxtext/layers/attention_mla.py
@@ -48,7 +48,6 @@
D_KV,
DType,
EMBED,
EP_AS_CONTEXT,
HEAD,
Q_LORA_UP_PROJ,
KV_BATCH,
@@ -905,9 +904,6 @@ def mla_get_key_value(self, low_rank_main, key_rope, model_mode):
if model_mode == MODEL_MODE_PREFILL:
key_logical_name = self.prefill_key_axis_names
value_logical_name = self.prefill_value_axis_names
elif model_mode == MODEL_MODE_TRAIN and self.config.expert_shard_attention_option == EP_AS_CONTEXT:
key_logical_name = self.ep_key_axis_names
value_logical_name = self.ep_value_axis_names
else:
key_logical_name = self.key_axis_names
value_logical_name = self.value_axis_names
@@ -1227,11 +1223,8 @@ def __call__(
record_max_logits=use_qk_clip,
)

out = self._maybe_shard_with_logical(out, self.out_axis_names)
out = jax.ad_checkpoint.checkpoint_name(out, "attention_out")
if model_mode == MODEL_MODE_TRAIN and self.config.expert_shard_attention_option == EP_AS_CONTEXT:
out = self._maybe_shard_with_logical(out, self.ep_out_axis_names)
else:
out = self._maybe_shard_with_logical(out, self.out_axis_names)

out_sharding = create_sharding(self.mesh, out_logical_name)
out = self.out_projection(out, out_sharding=out_sharding)
7 changes: 3 additions & 4 deletions src/maxtext/layers/attention_op.py
@@ -55,7 +55,6 @@
DEFAULT_MASK_VALUE,
DType,
D_KV,
EP_AS_FSDP,
HEAD,
KV_LENGTH,
LENGTH,
@@ -1270,7 +1269,7 @@ def wrap_splash_kernel(single_head_mask):

splash_kernel = wrap_splash_kernel(single_head_mask)
segment_axis_names_splash_kernel = self._logical_to_mesh_axes((Q_LENGTH,))
elif self.config.use_jax_splash and self.config.expert_shard_attention_option == EP_AS_FSDP:
elif self.config.use_jax_splash:
if self.config.use_max_logit_estimate > 0:
sa_config = dataclasses.replace(sa_config, max_logit_const=self.config.use_max_logit_estimate)
segment_axis_names_splash_kernel = nn.logical_to_mesh_axes((Q_LENGTH,))
@@ -1517,7 +1516,7 @@ def cudnn_flash_attention(

_, _, _, head_dim = query.shape # pylint: disable=unused-variable

using_context_parallelism = self.mesh.shape["context"] > 1
using_context_parallelism = self.mesh.shape[self.config.context_sharding] > 1

# Initialize default attention configuration
sliding_window_size = None
@@ -1575,7 +1574,7 @@ def cudnn_flash_attention(
transpose_batch_sequence=False,
window_size=sliding_window_size,
context_parallel_causal_load_balanced=self.config.context_parallel_load_balance,
context_parallel_axis="context",
context_parallel_axis=self.config.context_sharding,
context_parallel_strategy=self.config.context_parallel_strategy,
max_segments_per_seq=max_segments_per_seq,
)
11 changes: 6 additions & 5 deletions src/maxtext/utils/train_utils.py
@@ -230,9 +230,8 @@ def setup_train_loop(config, recorder, devices=None):
data_iterator, eval_data_iterator = create_data_iterator(config, mesh)
rampup_manager = create_rampup_manager(config, checkpoint_manager)
data_loader = create_dataloader(config, mesh, data_iterator, recorder, rampup_manager)
context_parallel_size = mesh.shape.get("context", 1)
# Check if context parallelism is being used with sequence packing
if context_parallel_size > 1 and config.packing and config.dataset_type != "synthetic":
if config.context_parallel_size > 1 and config.packing and config.dataset_type != "synthetic":
raise ValueError(
"Context parallelism cannot be used with sequence packing. "
"Disable sequence packing (set packing=False). "
@@ -241,11 +240,13 @@

# Apply reordering wrapper to data iterators if context parallelism is enabled
with jax.set_mesh(mesh):
if context_parallel_size > 1 and config.context_parallel_load_balance:
data_iterator = map(maxtext_utils.get_reorder_callable(context_parallel_size, config.shard_mode), data_iterator)
if config.context_parallel_size > 1 and config.context_parallel_load_balance:
data_iterator = map(
maxtext_utils.get_reorder_callable(config.context_parallel_size, config.shard_mode), data_iterator
)
if eval_data_iterator:
eval_data_iterator = map(
maxtext_utils.get_reorder_callable(context_parallel_size, config.shard_mode),
maxtext_utils.get_reorder_callable(config.context_parallel_size, config.shard_mode),
eval_data_iterator,
)

7 changes: 7 additions & 0 deletions tests/utils/sharding_dump.py
@@ -48,6 +48,13 @@
"pipeline-large-moe",
("ici_fsdp_parallelism=-1", "ici_expert_parallelism=4", "use_ring_of_experts=true"),
),
(
"deepseek2-16b",
"tpu7x-8",
1,
"ep-as-cp",
("ici_fsdp_parallelism=-1", "ici_expert_parallelism=2"),
),
("qwen3-0.6b", "tpu7x-16", 1, "", ()),
("gpt-oss-20b", "tpu7x-16", 1, "", ()),
("gpt-oss-20b", "tpu7x-16", 1, "", ("ici_fsdp_parallelism=-1", "ici_expert_parallelism=2")),