Add Packing Support for Context Parallelism (Ring Attention)#2906
Conversation
richjames0
left a comment
a couple of nits but lgtm
gobbleturk
left a comment
Thanks for the tests and great comments illustrating the two reorder strategies!
RissyRan
left a comment
LGTM, just minor comments.
This PR has been automatically marked as stale because it has not had recent activity. It will be closed soon if no further activity occurs. Thank you for your contributions.
This PR was closed because it has been inactive for a while. Please reopen it if you are still working on it.
Force-pushed from 9d04e5a to e9d76dc.
Force-pushed from 2935ff5 to 132bf4c.
Enable CP + packing for context_parallel_strategy="ring" with load
balancing. On GPU, uses Transformer Engine's striped reorder for
THD-packed sequences. On TPU/CPU, falls back to pure-JAX reorder_sequence
and never imports TE.
Changes:
- common_types: Add ReorderStrategy enum (AUTO, DUAL_CHUNK_SWAP, STRIPED).
- configs: Add context_parallel_reorder_strategy (default "auto"). Reject
explicit STRIPED on non-GPU at config validation time.
- attention_op: Thread segment_positions through apply_attention,
cudnn_flash_attention, and __call__. Use segment_positions in TE's
SequenceDescriptor for packing. Restrict packing+CP to load-balanced
ring only. Note TE version constraint.
- attentions.py, attention_mla.py, gpt3.py: Pass inputs_positions into
attention_op calls (None for gpt3).
- max_utils: Hardware-dispatched reorder_causal_load_balanced. GPU uses
TE's reorder_causal_load_balancing; TPU/CPU uses reorder_sequence.
TE import is lazy and GPU-only.
- maxtext_utils: Thread reorder_strategy and hardware through
shard_reorder_causal_load_balanced and get_reorder_callable. Default
hardware="tpu" never triggers TE import.
- train_utils: Allow ring+packing; forbid all_gather+packing and
  synthetic+packing. Resolve AUTO to STRIPED when packing is enabled,
  else to DUAL_CHUNK_SWAP. Pass config.hardware to the reorder callable.
  Build data_loader after the reorder wrapper is applied.
- attention_test_util: Pass cfg_cp.hardware so TPU tests use pure-JAX
reorder. Helper is TPU-oriented and does not model GPU packed behavior.
- tests: Add test_gpu_ring_attention_with_packing (sm90+).
Requires TE with reorder_causal_load_balancing; works with TE <=2.11 or
>=2.14 (incompatible with 2.12 and 2.13 due to a known bug).
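The DUAL_CHUNK_SWAP reorder described above can be illustrated with a small standalone sketch. This is a hypothetical NumPy version for illustration only, not MaxText's actual `reorder_sequence` implementation; the function name and signature are assumptions:

```python
import numpy as np

def reorder_dual_chunk_swap(x: np.ndarray, cp_size: int) -> np.ndarray:
    """Hypothetical dual-chunk-swap load-balancing reorder.

    The sequence is split into 2*cp_size chunks; CP rank r receives
    chunk r and chunk (2*cp_size - 1 - r), so every rank holds one
    early and one late chunk and causal-attention work is balanced
    across ranks in ring attention.
    """
    chunks = np.split(x, 2 * cp_size)
    reordered = []
    for r in range(cp_size):
        reordered.append(chunks[r])                     # early chunk for rank r
        reordered.append(chunks[2 * cp_size - 1 - r])   # matching late chunk
    return np.concatenate(reordered)

tokens = np.arange(8)  # token positions 0..7, cp_size=2 -> 4 chunks of 2
print(reorder_dual_chunk_swap(tokens, cp_size=2).tolist())  # [0, 1, 6, 7, 2, 3, 4, 5]
```

After the reorder, rank 0 holds tokens [0, 1, 6, 7] and rank 1 holds [2, 3, 4, 5]: each rank's causal workload (number of attended positions) is roughly equal, which is the point of the swap.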
Force-pushed from 132bf4c to eb52c0b.
Merged 9608068 into AI-Hypercomputer:main.
Description
Enables sequence packing for context parallelism with the `ring` strategy using Transformer Engine's DotProductAttention and its `reorder_causal_load_balancing` API. Includes comprehensive GPU tests for ring attention with packing on sm90+.

Current support matrix for context parallelism on GPU:

| Strategy | Reorder strategy |
| --- | --- |
| all_gather | dual_chunk_swap |
| ring | dual_chunk_swap (non-packed), striped (packed) |

Tests
Added a GPU integration test that works for sm90+.
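The striped pattern used for the packed case can be sketched in the same style. This is an illustrative stripe-size-1 version under assumed names, not TE's actual `reorder_causal_load_balancing`, which may use larger stripes for THD-packed sequences:

```python
import numpy as np

def reorder_striped(x: np.ndarray, cp_size: int) -> np.ndarray:
    """Hypothetical striped load-balancing reorder (stripe size 1).

    CP rank r takes tokens r, r + cp_size, r + 2*cp_size, ...
    Because the stripes interleave finely, every packed segment's
    causal work is spread roughly evenly across ranks regardless of
    where segment boundaries fall, which is why packing resolves the
    AUTO strategy to STRIPED rather than DUAL_CHUNK_SWAP.
    """
    return np.concatenate([x[r::cp_size] for r in range(cp_size)])

tokens = np.arange(8)
print(reorder_striped(tokens, cp_size=2).tolist())  # [0, 2, 4, 6, 1, 3, 5, 7]
```

Note how, unlike dual-chunk-swap, no contiguous chunk ends up entirely on one rank; this is what makes the striped layout robust to arbitrary segment boundaries in packed batches.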
Checklist
Before submitting this PR, please make sure (put X in square brackets):
`gemini-review` label.