
Onboard explicit sharding to pipeline module#2744

Merged
copybara-service[bot] merged 1 commit into main from chengnuojin-explicit-pipeline
Dec 30, 2025
Conversation


@NuojCheng commented Nov 24, 2025

Description

This PR onboards explicit sharding to the pipeline module, providing direct control over tensor distribution and layouts. It also introduces performance optimizations for collective operations and simplifies the codebase by removing deprecated, unused modules and functions.

Key Changes

Explicit Sharding for Pipeline Module

  • Introduced shard_mode=explicit for pipeline parallelism. The init_states method now uses out_sharding for state_io and circ_storage when operating in explicit mode, ensuring consistent data distribution across the mesh.
  • Refactored shard_dim_by_stages to remove hardcoded P.UNCONSTRAINED choices. It now takes the parameters' physical shardings directly, with a None placeholder under explicit sharding, allowing more granular sharding specifications.
  • Refactored the rotation and shift primitives _rotate_right / _rotate_left and _shift_right / _shift_left to use jax.lax.ppermute for circular shifts, replacing the previous slice-and-concatenate approach and improving readability.

Deprecation & Code Simplification

  • The ZeroOneTransformer class has been fully removed; its functionality is now natively supported by the shard_optimizer_over_data flag.
  • Removed the model_fsdp_ag_once flag.
  • Removed deprecated functions in maxtext_utils.py.
  • Updated the corresponding tests.

Tests

Environment:

  • Device: v5p-8
  • JAX Version: 0.8.0
  • Model: Llama2-7b

Test Case 1: FSDP + PP

`smoke_train model_name=llama2-7b ici_pipeline_parallelism=2 num_pipeline_microbatches=8 pipeline_fsdp_ag_once=true`

| Implementation / Mode | Commit/Flag | Step Time | Memory | Profile (xprof) |
| --- | --- | --- | --- | --- |
| Baseline | 7991534 | 4756 ms | 50062 MB | Link |
| Auto Sharding | shard_mode=auto | 4567 ms | 48219 MB | Link |
| Explicit Sharding | shard_mode=explicit | 6822 ms | 50913 MB | Link |

Test Case 2: ZeRO-1 + PP

`smoke_train model_name=llama2-7b ici_pipeline_parallelism=2 ici_data_parallelism=2 shard_optimizer_over_data=true num_pipeline_microbatches=8 pipeline_fsdp_ag_once=true`

| Implementation / Mode | Commit/Flag | Step Time | Memory | Profile (xprof) |
| --- | --- | --- | --- | --- |
| Baseline | 7991534 | 4759 ms | 50062 MB | Link |
| Auto Sharding | shard_mode=auto | 4636 ms | 48315 MB | Link |
| Explicit Sharding | shard_mode=explicit | 5272 ms | 47997 MB | Link |

Checklist

Before submitting this PR, please make sure:

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@NuojCheng added the draft Draft PR label Nov 24, 2025
@NuojCheng force-pushed the chengnuojin-explicit-pipeline branch 14 times, most recently from 0d8601c to 83a0f00 on December 18, 2025 05:53
@NuojCheng added gemini-review and removed draft Draft PR labels Dec 18, 2025
@NuojCheng marked this pull request as ready for review December 18, 2025 05:54
@github-actions Bot left a comment


📋 Review Summary

This Pull Request successfully integrates explicit sharding into the pipeline module, significantly enhancing control over tensor distribution and optimizing rotation operations. The refactoring efforts have led to a cleaner and more modular codebase, particularly with the removal of deprecated features and the centralization of sharding logic.

🔍 General Feedback

  • The consistent use of logical_partition_spec and the new sharding utilities across various files greatly improves clarity and adherence to the explicit sharding paradigm.
  • The adoption of jax.lax.ppermute for rotation operations is a notable performance improvement.
  • The removal of deprecated flags, classes, and utility functions simplifies the codebase and reduces technical debt.

@NuojCheng force-pushed the chengnuojin-explicit-pipeline branch from 83a0f00 to 737bf10 on December 18, 2025 07:10
@NuojCheng force-pushed the chengnuojin-explicit-pipeline branch 2 times, most recently from 857a184 to 5b1ad95 on December 23, 2025 22:19
codecov Bot commented Dec 23, 2025

@NuojCheng force-pushed the chengnuojin-explicit-pipeline branch from 5b1ad95 to e0009eb on December 24, 2025 01:30
@NuojCheng force-pushed the chengnuojin-explicit-pipeline branch from e0009eb to 112f8c3 on December 29, 2025 22:48

@khatwanimohit left a comment


lgtm

@copybara-service Bot merged commit fbcee8f into main Dec 30, 2025
23 of 24 checks passed
@copybara-service Bot deleted the chengnuojin-explicit-pipeline branch December 30, 2025 06:16