Onboard explicit sharding to pipeline module #2744
Merged
copybara-service[bot] merged 1 commit into main on Dec 30, 2025
Conversation
0d8601c to 83a0f00
🤖 Hi @NuojCheng, I've received your request, and I'm working on it now! You can track my progress in the logs for more details.
📋 Review Summary
This Pull Request successfully integrates explicit sharding into the pipeline module, significantly enhancing control over tensor distribution and optimizing rotation operations. The refactoring has led to a cleaner, more modular codebase, particularly through the removal of deprecated features and the centralization of sharding logic.
🔍 General Feedback
- The consistent use of `logical_partition_spec` and the new sharding utilities across various files greatly improves clarity and adherence to the explicit sharding paradigm.
- The adoption of `jax.lax.ppermute` for rotation operations is a notable performance improvement.
- The removal of deprecated flags, classes, and utility functions simplifies the codebase and reduces technical debt.
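For readers unfamiliar with the explicit sharding paradigm referenced above, here is a minimal, self-contained JAX sketch of pinning a tensor to an explicit layout with `NamedSharding`. This is a general illustration only; MaxText's `logical_partition_spec` utilities resolve logical axis names to mesh axes before any spec like this is applied, and the mesh axis names below are assumptions.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Build a 2-D device mesh; the axis names ("stage", "data") are illustrative,
# not MaxText's actual mesh axes. Assumes an even device count.
devices = np.array(jax.devices()).reshape(2, -1)
mesh = Mesh(devices, axis_names=("stage", "data"))

x = jnp.zeros((8, 1024))
# Explicitly place rows over "stage" and columns over "data".
x = jax.device_put(x, NamedSharding(mesh, P("stage", "data")))
print(x.sharding)  # NamedSharding(..., spec=PartitionSpec('stage', 'data'))
```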
83a0f00 to 737bf10
gobbleturk reviewed on Dec 18, 2025
gobbleturk reviewed on Dec 18, 2025
gobbleturk reviewed on Dec 18, 2025
gobbleturk reviewed on Dec 19, 2025
gobbleturk approved these changes on Dec 19, 2025
857a184 to 5b1ad95
5b1ad95 to e0009eb
e0009eb to 112f8c3
Description
This PR onboards explicit sharding to the pipeline module, providing direct control over tensor distribution and layouts. It also introduces performance optimizations for collective operations and simplifies the codebase by removing old, unused modules and functions.
Key Changes
Explicit Sharding for Pipeline Module
- Enabled `shard_mode=explicit` with pipeline parallelism. The `init_states` method now utilizes `out_sharding` for `state_io` and `circ_storage` when operating in explicit mode, ensuring consistent data distribution across the mesh.
- Refactored `shard_dim_by_stages` to remove hardcoded `P.UNCONSTRAINED` choices. It now directly imports parameter physical shardings and a `None` placeholder for explicit sharding, allowing more granular sharding specifications.
- Optimized `_rotate_right`/`_rotate_left` and `_shift_right`/`_shift_left`. These functions now leverage `jax.lax.ppermute` for circular shift operations, replacing the previous slice-and-concatenate approach and improving code readability (see the sketch below).
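To make the rotation change concrete, below is a hedged sketch of a circular "rotate right" across pipeline stages built on `jax.lax.ppermute` inside `shard_map`. The function and axis names are assumptions for illustration; the actual `_rotate_right` in the pipeline module may differ in structure.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

n = jax.device_count()
mesh = Mesh(np.array(jax.devices()), axis_names=("stage",))

def rotate_right(x):
  # Send each stage's shard to the next stage; the last wraps to stage 0.
  # One collective replaces the old slice-and-concatenate pattern.
  perm = [(i, (i + 1) % n) for i in range(n)]
  return jax.lax.ppermute(x, axis_name="stage", perm=perm)

rotate = shard_map(rotate_right, mesh=mesh, in_specs=P("stage"), out_specs=P("stage"))
x = jnp.arange(float(n * 4)).reshape(n, 4)
print(rotate(x))  # row blocks shifted one stage to the right, circularly
```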
Deprecation & Code Simplification
- `ZeroOneTransformer` has been fully removed. Its functionality is now natively supported by the `shard_optimizer_over_data` flag (illustrated in the sketch after this list).
- Removed the `model_fsdp_ag_once` flag.
- Removed deprecated functions in `maxtext_utils.py`.
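As background on the `shard_optimizer_over_data` flag, the ZeRO-1 idea it enables is to shard optimizer state over the data-parallel axis instead of replicating it alongside the parameters. The sketch below illustrates that idea with plain JAX shardings; the names and shapes are assumptions, and this is not MaxText's implementation.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), axis_names=("data",))

params = jnp.zeros((jax.device_count() * 128, 256))
replicated = NamedSharding(mesh, P())             # params stay replicated
over_data = NamedSharding(mesh, P("data", None))  # optimizer state sharded

params = jax.device_put(params, replicated)
# Adam moments sharded over the data axis (the ZeRO-1 memory saving).
mu = jax.device_put(jnp.zeros_like(params), over_data)
nu = jax.device_put(jnp.zeros_like(params), over_data)
```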
Tests
Environment:
Test Case 1: FSDP + PP
Ran `smoke_train model_name=llama2-7b ici_pipeline_parallelism=2 num_pipeline_microbatches=8 pipeline_fsdp_ag_once=true` with both:
- `shard_mode=auto`
- `shard_mode=explicit`
Test Case 2: ZeRO-1 + PP
Ran `smoke_train model_name=llama2-7b ici_pipeline_parallelism=2 ici_data_parallelism=2 shard_optimizer_over_data=true num_pipeline_microbatches=8 pipeline_fsdp_ag_once=true` with both:
- `shard_mode=auto`
- `shard_mode=explicit`

Checklist
Before submitting this PR, please make sure (put X in square brackets):
- `gemini-review` label.