fix: broadcast cu_seqlens/max_seqlen to intermediate PP stages for SFT packed sequences #4150
meinie0826 wants to merge 1 commit into NVIDIA:main
Conversation
This PR has been automatically converted to draft because all PRs must start as drafts. When you are ready for review, click "Ready for Review" to begin the review process. See the contribution guide for more details.
Pull request overview
Fixes SFT packed-sequence training with PP>2 by ensuring packed-sequence metadata is broadcast to TP ranks on intermediate pipeline stages, and adds a unit-test regression suite for the broadcast behavior (issue #4092).
Changes:
- Add intermediate-stage broadcast/receive logic in `get_batch_on_this_tp_rank` for `attention_mask`, `cu_seqlens`, and `max_seqlen`.
- Add new CUDA/distributed unit tests validating `cu_seqlens`/`max_seqlen` propagation across TP/PP configurations.
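For context on the two metadata fields being propagated: `cu_seqlens` is the cumulative boundary vector of the sequences packed into one row, and `max_seqlen` is the longest individual sequence — the inputs FlashAttention's varlen kernels expect. A minimal sketch (hypothetical helper, not part of this PR):

```python
def packed_seq_metadata(seq_lens):
    """Derive packed-sequence metadata: cumulative sequence boundaries
    (cu_seqlens) and the longest individual sequence (max_seqlen)."""
    cu_seqlens = [0]
    for n in seq_lens:
        cu_seqlens.append(cu_seqlens[-1] + n)
    return cu_seqlens, max(seq_lens)

# Three sequences of lengths 3, 5 and 2 packed into one row of 10 tokens.
cu, mx = packed_seq_metadata([3, 5, 2])
# cu == [0, 3, 8, 10], mx == 5
```

If a TP rank never receives these values, attention has no way to tell where one packed sequence ends and the next begins — which is exactly the bug this PR addresses.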
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| megatron/training/utils.py | Adds intermediate PP-stage broadcast/receive of packed-sequence metadata in `get_batch_on_this_tp_rank`. |
| tests/unit_tests/test_get_batch_on_this_tp_rank.py | Introduces regression tests for packed-sequence metadata broadcast across TP ranks and PP stages. |
```python
else:
    # Intermediate pipeline stages: broadcast attention_mask, cu_seqlens and max_seqlen
    # so that RoPE and FlashAttention on TP ranks > 0 receive the packed-sequence metadata.
    _broadcast(batch['attention_mask'])
    _broadcast_cu_seqlens(batch['cu_seqlens'])
    _broadcast(batch['max_seqlen'])
```
In the PP-last-stage path, TP ranks > 0 still never receive cu_seqlens/max_seqlen (the code only broadcasts labels, loss_mask, attention_mask). For SFT/packed sequences, the last stage still runs transformer layers and needs the packed-sequence metadata on every TP rank (see pretrain_gpt.py:get_batch / pretrain_mamba.py:get_batch, which build packed-seq params on first/last stages). Consider broadcasting cu_seqlens (via _broadcast_cu_seqlens) and max_seqlen here as well, and mirroring the receive-side logic so TP>0 doesn’t fall back to non-packed execution.
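The reviewer's point can be reduced to a per-stage decision table. The sketch below uses a hypothetical helper (the real code calls `_broadcast`/`_broadcast_cu_seqlens` per key rather than returning a list): packed-sequence metadata belongs to every stage that runs transformer layers, which is all of them, including the last.

```python
def keys_to_broadcast(is_first_stage, is_last_stage, packed_sft):
    """Which batch entries TP rank 0 should broadcast on a given PP stage.
    Illustrative model of get_batch_on_this_tp_rank's stage split."""
    keys = ["attention_mask"]
    if is_first_stage:
        keys += ["tokens", "position_ids"]
    if is_last_stage:
        keys += ["labels", "loss_mask"]
    if packed_sft:
        # The review suggestion: do NOT gate these on the stage --
        # every stage's attention layers need the packed-seq metadata.
        keys += ["cu_seqlens", "max_seqlen"]
    return keys
```

Under this model, `keys_to_broadcast(is_first_stage=False, is_last_stage=True, packed_sft=True)` still includes `cu_seqlens` and `max_seqlen`, whereas the PR's current last-stage path omits them.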
```python
else:
    # Intermediate pipeline stages: receive attention_mask, cu_seqlens and max_seqlen
    # from TP rank 0 so that RoPE and FlashAttention work correctly with packed sequences.
    tokens = None
    labels = None
    loss_mask = None
    position_ids = None
    local_cp_size = None

    _broadcast(attention_mask)
    cu_seqlens = _broadcast_cu_seqlens()
    _broadcast(max_seqlen)
```
Receive-side intermediate-stage logic now pulls cu_seqlens/max_seqlen, but the PP-last-stage branch above explicitly sets cu_seqlens = None and max_seqlen = None and never receives them. That leaves TP ranks > 0 on the last stage without packed-sequence metadata, causing pretrain_gpt.py:get_batch / pretrain_mamba.py:get_batch to treat the batch as non-packed on those ranks. Align the PP-last-stage receive logic with the sender side (and with first/intermediate stages) so cu_seqlens/max_seqlen are available on all TP ranks when args.sft is enabled.
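The failure mode described here — TP ranks silently treating the batch as non-packed — ultimately comes down to a `None` check when building attention parameters. A sketch of that downstream check (field names modeled on Megatron's `PackedSeqParams`; treat the helper itself as illustrative):

```python
def build_packed_seq_params(cu_seqlens, max_seqlen):
    """Return varlen-attention parameters, or None to signal the caller
    to fall back to regular (non-packed) attention."""
    if cu_seqlens is None or max_seqlen is None:
        # This is the branch TP ranks > 0 on the last stage hit
        # when the metadata broadcast is missing.
        return None
    return {
        "cu_seqlens_q": cu_seqlens,
        "cu_seqlens_kv": cu_seqlens,
        "max_seqlen_q": max_seqlen,
        "max_seqlen_kv": max_seqlen,
    }
```

Because the fallback is silent, the bug manifests as wrong attention masking across packed-sequence boundaries rather than an error, which is why a regression test is valuable.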
```python
world_size = tp * pp
if torch.cuda.device_count() < world_size:
    pytest.skip(f"Need {world_size} GPUs (tp={tp}, pp={pp})")

Utils.initialize_model_parallel(
    tensor_model_parallel_size=tp,
    pipeline_model_parallel_size=pp,
)
```
These tests assume that having enough visible GPUs (torch.cuda.device_count()) implies the test is running with WORLD_SIZE == tp*pp, but Utils.initialize_model_parallel(tp, pp) requires the process group world size to match (or exceed) the requested topology. When running pytest without torchrun, WORLD_SIZE defaults to 1 and this will fail instead of skipping. Consider skipping unless int(os.environ.get('WORLD_SIZE','1')) >= tp*pp (or torch.distributed.get_world_size() >= tp*pp once initialized).
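A guard along the lines suggested could look like the following sketch (`skip_reason` is a hypothetical helper, not existing test infrastructure):

```python
import os

def skip_reason(tp, pp, env=None):
    """Return a pytest skip reason unless the launcher (e.g. torchrun)
    provides at least tp * pp ranks; None means the test can run."""
    env = os.environ if env is None else env
    need = tp * pp
    world_size = int(env.get("WORLD_SIZE", "1"))
    if world_size < need:
        return f"Need WORLD_SIZE >= {need} (tp={tp}, pp={pp}), got {world_size}"
    return None
```

In the test this would run before the GPU-count check, e.g. `reason = skip_reason(tp, pp)` followed by `if reason: pytest.skip(reason)`, so a plain `pytest` invocation skips cleanly instead of failing inside `initialize_model_parallel`.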
```python
tp_rank = mpu.get_tensor_model_parallel_rank()
pp_rank = mpu.get_pipeline_model_parallel_rank()
is_first = mpu.is_pipeline_first_stage()
is_last = mpu.is_pipeline_last_stage()

data_iterator = _make_data_iterator() if tp_rank == 0 else None
batch = get_batch_on_this_tp_rank(data_iterator)

if is_last and not is_first:
    Utils.destroy_model_parallel()
    return  # last stage: cu_seqlens is intentionally None

assert batch["cu_seqlens"] is not None
assert batch["max_seqlen"] is not None

# Broadcast the TP-0 copy to all TP ranks and verify every rank agrees.
cu_ref = batch["cu_seqlens"].clone()
mx_ref = batch["max_seqlen"].clone()
torch.distributed.broadcast(
    cu_ref,
    src=mpu.get_tensor_model_parallel_src_rank(),
    group=mpu.get_tensor_model_parallel_group(),
)
torch.distributed.broadcast(
    mx_ref,
    src=mpu.get_tensor_model_parallel_src_rank(),
    group=mpu.get_tensor_model_parallel_group(),
)

assert torch.equal(batch["cu_seqlens"], cu_ref), (
    f"cu_seqlens mismatch on pp_rank={pp_rank} tp_rank={tp_rank}"
)
assert torch.equal(batch["max_seqlen"], mx_ref), (
    f"max_seqlen mismatch on pp_rank={pp_rank} tp_rank={tp_rank}"
)

Utils.destroy_model_parallel()
```
This test returns early for the last PP stage under the assumption that cu_seqlens is intentionally None there. For packed sequences, last-stage TP ranks also need the metadata to build packed-seq params for attention. Once the last-stage broadcast is fixed, remove this early-return and validate consistency across TP ranks on the last stage as well.
Suggested change (wrap the test body in try/finally and drop the last-stage early return):

```python
try:
    tp_rank = mpu.get_tensor_model_parallel_rank()
    pp_rank = mpu.get_pipeline_model_parallel_rank()
    data_iterator = _make_data_iterator() if tp_rank == 0 else None
    batch = get_batch_on_this_tp_rank(data_iterator)
    assert batch["cu_seqlens"] is not None
    assert batch["max_seqlen"] is not None
    # Broadcast the TP-0 copy to all TP ranks and verify every rank agrees.
    cu_ref = batch["cu_seqlens"].clone()
    mx_ref = batch["max_seqlen"].clone()
    torch.distributed.broadcast(
        cu_ref,
        src=mpu.get_tensor_model_parallel_src_rank(),
        group=mpu.get_tensor_model_parallel_group(),
    )
    torch.distributed.broadcast(
        mx_ref,
        src=mpu.get_tensor_model_parallel_src_rank(),
        group=mpu.get_tensor_model_parallel_group(),
    )
    assert torch.equal(batch["cu_seqlens"], cu_ref), (
        f"cu_seqlens mismatch on pp_rank={pp_rank} tp_rank={tp_rank}"
    )
    assert torch.equal(batch["max_seqlen"], mx_ref), (
        f"max_seqlen mismatch on pp_rank={pp_rank} tp_rank={tp_rank}"
    )
finally:
    Utils.destroy_model_parallel()
```
Force-pushed: 2b04b29 to 1b5fe0f
Force-pushed: 1b5fe0f to cea9d00
@chtruong814 Hi, could you take a look at this PR and give some advice?
What does this PR do?
Fix: Add an `else` branch for intermediate stages on both the sender and receiver sides of `get_batch_on_this_tp_rank` to broadcast `attention_mask`, `cu_seqlens`, and `max_seqlen`.
Fixes #4092
Contribution process
Pre-checks
Code review
Feel free to message or comment @mcore-oncall to help accelerate your merge into main. The less complex your PR is, the faster it will be approved and merged!
All PRs start as draft. If you open a non-draft PR, it will be automatically converted to draft.
Step 1: Mark PR as "Ready for Review"
Reviewers are assigned according to `.github/CODEOWNERS`. Final Review might get declined if these requirements are not fulfilled.
Step 2: Final Review
For PRs that change `megatron/core`, once all expert reviewers have approved, the Final Review label is applied automatically and final reviewers are assigned. For PRs outside `megatron/core`, this step is skipped.
Step 3: Approved
Once all required reviewers have approved, the Approved label is applied automatically.
Merge
Any member of mcore-engineers will be able to merge your PR.
For MRs into the `dev` branch
The proposed review process for the `dev` branch is under active discussion. MRs are mergeable after one approval by either eharper@nvidia.com or zijiey@nvidia.com.