
Introduce ep-as-cp customized logical rule #3656

Merged
copybara-service[bot] merged 3 commits into main from chengnuojin-ep-cp
Apr 16, 2026

Conversation


@NuojCheng NuojCheng commented Apr 14, 2026

Description

Note: This PR should be merged after #3607.

This PR introduces a new custom mesh and rule, ep-as-cp, and deprecates the expert_shard_attention_option flag.

Key Changes

  • Enabled via custom_mesh_and_rule=ep-as-cp, this rule supports DP, PP, FSDP, and EP. Under this setup, EP functions as CP everywhere except within the core MoE components (specifically, the layers between EP all-to-all communications).
  • Introduces a new string flag, context_sharding (defaulting to "context"), to explicitly designate which physical axis serves as context parallelism. This is required because CP is used in the input data pipeline and attention load balancing, and this axis mapping cannot easily be inferred from the custom logical rules alone.
  • The new ep-as-cp mesh and rule are validated by the dump sharding tests.
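As a rough sketch, enabling the new rule could look like the following base.yml fragment. The keys custom_mesh_and_rule and context_sharding both appear in this PR; setting context_sharding to "expert" under ep-as-cp is an assumption here, based on the description that the expert axis takes the CP role outside the core MoE components.

```yaml
# Sketch only — key names follow the base.yml fragments quoted in this PR.
custom_mesh_and_rule: "ep-as-cp"  # enables the new custom mesh and logical rules
# Physical axis that plays the role of context parallelism for input data
# processing and load balancing (default: "context"). Under ep-as-cp the
# expert axis is assumed to take this role:
context_sharding: "expert"
```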

Tests

Functionality Check

The most straightforward way to verify the ep-as-cp implementation is to run an experiment with a fractional batch size. We ran the following experiments:

Experiment 1: Real Training

Experiment 2: Training Compilation

Experiment 3: Default vs. ep-as-cp (debug_sharding diff)

Experiment 4: Training compilation large


Correctness Check

To verify loss correctness, we evaluated the following configuration:

  • Topology: v5p-8
  • Sharding Comparison: FSDP=2 + CP=2 vs. FSDP=2 + EP (as CP)=2
  • Batch Size: per_device_batch_size=1
  • Model: model_name=deepseek3-test

Note: We compare FSDP=2 + CP=2 against FSDP=2 + EP (as CP)=2 to ensure an apples-to-apples comparison, as CP directly impacts input data pipelining (CP load balancing).

Result: The losses match within a reasonable tolerance.
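A minimal sketch of the kind of tolerance check behind "the losses match" (hypothetical helper, not code from this PR; the tolerance values are assumptions, since the PR does not state them):

```python
# Hypothetical helper (not from this PR): compare two per-step loss curves
# element-wise within a combined absolute/relative tolerance.
def losses_match(baseline, candidate, rtol=1e-3, atol=1e-4):
  """Return True if the two loss curves agree step-by-step within tolerance."""
  if len(baseline) != len(candidate):
    return False  # curves must cover the same number of steps
  return all(abs(b - c) <= atol + rtol * abs(b) for b, c in zip(baseline, candidate))
```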


Performance Check

To evaluate the performance implications, we compared FSDP + EP (as FSDP) against FSDP + EP (as CP) using the following configuration:

Result: Using EP as CP incurs a general performance penalty. However, this configuration is specifically designed for, and highly beneficial in, two scenarios:

  1. Long-context training: To prevent OOM errors.
  2. Strong scaling: When a fractional batch size is strictly required.

Diff with Previous Implementation

Finally, we compared the sharding and performance of this PR against the previous implementation (using the expert_shard_attention_option flag). For an accurate baseline, we compared against commit 9777a4cf9574f3d10c591e25450cea1b1dde7e01 (April 3rd), which is isolated from recent changes.

Configuration:

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@NuojCheng NuojCheng added the draft Draft PR label Apr 14, 2026

codecov Bot commented Apr 14, 2026

Codecov Report

❌ Patch coverage is 16.66667% with 5 lines in your changes missing coverage. Please review.

Files with missing lines          | Patch % | Lines
src/maxtext/utils/train_utils.py  | 0.00%   | 1 Missing and 2 partials ⚠️
src/maxtext/layers/attention_op.py | 0.00%  | 1 Missing and 1 partial ⚠️


# ==========================================
# Dense Activations
['activation_mlp', []],
['activation_batch', ['data', 'fsdp']],
Collaborator

@gobbleturk gobbleturk Apr 15, 2026

activation_batch is also used for attention and is a key dimension. Maybe note this as a comment in the attention section above

Collaborator Author

Good point! Added comments in base.yml

data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']]
input_data_sharding_logical_axes: ['activation_embed_and_logits_batch', 'activation_norm_length']
# Determines which physical axis plays the role of context parallelism for input data processing and load balancing
context_sharding: "context"
Collaborator

What values can this take? Can we remove this field and instead have it implied by the logical axis rules? E.g., we'd need a function that takes the rules as input and outputs the value of context_sharding.

Collaborator

Could we list other options? if any

Collaborator Author

It is hard to infer which physical axis is used for CP just from reading the logical rules. For example, both sequence and context are used to shard activation_length, but only context is used for data processing.

I will add comments indicating possible values of context sharding and add checks.
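The check mentioned above might look something like the following sketch (illustrative only, not the PR's actual code; the set of valid values is an assumption based on the axes discussed in this thread):

```python
# Illustrative sketch (not from this PR): validate that the configured
# context_sharding value names a physical mesh axis that can act as
# context parallelism. The allowed values are an assumption based on the
# axes discussed in this PR ("context" by default, "expert" for ep-as-cp).
VALID_CONTEXT_SHARDING_AXES = ("context", "expert")

def validate_context_sharding(context_sharding: str, mesh_axes: list) -> str:
  """Raise ValueError if context_sharding is unsupported or absent from the mesh."""
  if context_sharding not in VALID_CONTEXT_SHARDING_AXES:
    raise ValueError(
        f"context_sharding={context_sharding!r} is not supported; "
        f"expected one of {VALID_CONTEXT_SHARDING_AXES}"
    )
  if context_sharding not in mesh_axes:
    raise ValueError(
        f"context_sharding={context_sharding!r} is not an axis of the mesh {mesh_axes}"
    )
  return context_sharding
```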

Collaborator Author

> Could we list other options? if any

Done!

Collaborator

@RissyRan RissyRan left a comment

Thank you! Wondering if we should add this test to train_compile to prevent breakage. But I am also fine keeping it as a scheduled test since this is not frequently used. Both work for me.


# This rule uses data, FSDP, and expert. Expert axis acts as context parallelism in
# components except core dMoE part (between EP all2all).
mesh_axes: ['data', 'fsdp', 'expert']
Collaborator

Wondering if you will have a README about those custom mesh and rule supported?

Collaborator Author

I plan to add a doc explaining the custom mesh and rules. I probably won't include it in this PR since more changes to the custom rules are planned.

# General Weights
['mlp', []],
['embed', ['fsdp', 'expert']],
]
Collaborator

nit: extra line

Collaborator Author

Done


@copybara-service copybara-service Bot merged commit 9c68b1a into main Apr 16, 2026
29 of 31 checks passed
@copybara-service copybara-service Bot deleted the chengnuojin-ep-cp branch April 16, 2026 19:40
