Skip to content

[Megatron-FSDP] MaxPoolAllocator for double-buffering hybrid architectures.#5462

Merged
cspades merged 11 commits into
NVIDIA:mainfrom
cspades:cye/maxpool-dbuf
Jul 2, 2026
Merged

[Megatron-FSDP] MaxPoolAllocator for double-buffering hybrid architectures.#5462
cspades merged 11 commits into
NVIDIA:mainfrom
cspades:cye/maxpool-dbuf

Conversation

@cspades

@cspades cspades commented Jun 23, 2026

Copy link
Copy Markdown
Member
  • I, the PR author, have personally reviewed every line of this PR.

What does this PR do ?

image

Iterating through all FSDP units, data buckets are categorized by data-type, sorted from small to large, and compared to the current MaxPool. If there are not enough buckets in the pool to support the unit, buckets are added to the pool (with size 0). If the largest buckets of the pool are not large enough to support the buckets in the unit (assigned to the pool from smallest to largest), the buckets in the pool are enlarged. After this process, we arrive at a minimal set of buckets that can symmetrically double-buffer every FSDP unit in the model.

  • Adds hybrid architecture double buffering via FSDP unit max-pooling for Megatron-FSDP. (V1)
    • Opens up CG or NCCL UBR support for hybrid architectures, which will help support users for a while.
  • Adds the strict_assignment state to attempt to assign the same bucket previously assigned to an FSDP unit before warning the user and assigning a different bucket to the unit.
    • If this warning appears during warmup or CUDA graph capture, likely some memory is being orphaned and you will hit numerical errors.
  • Fixes an issue where parameters / buckets that are not members of an FSDP unit will pre-fetch subsequent buckets that aren't subsequently used, exhausting buffers in the double buffer allocator and causing an allocation error.
    • Only necessary for double buffer allocators, which require careful management of the 2 buffers in the pool.
  • Deprecates --grad-reduce-in-bf16 / reduce_grad_in_fp32 for Megatron-FSDP, which has been incredibly confusing to use. Default arguments (auto) assume BF16 for both, so will not OOM any existing user's configs.
  • Deprecate fsdp_unit_id == -1. It is never set to -1.
  • Adds a call to torch.autograd.graph.set_override_stale_capture_stream(True) (only supported on new PyTorch versions since Detect and fix stale stream references in autograd during CUDA graph capture pytorch/pytorch#180090) to prevent full-iteration CG errors like this:
[rank0]: RuntimeError: During CUDA graph capture, autograd node 'torch::autograd::AccumulateGrad' has a stale reference to the default stream (stream 0) from warmup. This will invalidate the capture because cudaStreamWaitEvent on the default stream pulls a non-capturing stream into the graph.

[rank0]: To fix, either:
[rank0]:   (a) Run warmup on the same stream that capture will use, or
[rank0]:   (b) Delete references to the loss / autograd graph (e.g. `del loss`) before capture, or
[rank0]:   (c) Call torch.autograd.graph.set_override_stale_capture_stream(True) to automatically redirect stale nodes to the capturing stream.

^ (a) is annoying to implement, (b) is dirty, and (c) is EZ-PZ and recommended.

⚠️ For major changes (either in lines of code or in its impact), please make sure to first share a design doc with the team. If you're unsure what's the best way to do so, contact @NVIDIA/mcore-oncall.

Issue tracking

For PRs from open-source community contributors:

  • New features: a linked issue is required. Please open a feature request and reference it here before submitting the PR.
  • Small updates (bug fixes, minor improvements): a linked issue is recommended and will accelerate the PR review process.

Linked issue:

Contribution process

Pre-checks

  • I have added relevant unit tests
  • I have added relevant functional tests
  • I have added proper typing to my code Typing guidelines
  • I have added relevant documentation
  • I have run the autoformatter.sh on my PR

Code review

Feel free to message or comment @NVIDIA/mcore-oncall to help accelerate your merge into main. The less complex your PR is, the faster it will be approved and merged!

All PRs start as draft. If you open a non-draft PR, it will be automatically converted to draft.

Step 1: Mark PR as "Ready for Review"

  1. When your PR is ready, click Ready for Review.
  2. An oncall reviewer is auto-assigned and expert reviewers are notified based on your changes.
    • Some PRs may jump straight to step 2. This is determined by .github/CODEOWNERS.

⚠️ Only mark as ready once merge-conflicts are resolved and the CI is passing.
Final Review might get declined if these requirements are not fulfilled.

Step 2: Final Review

For PRs that change megatron/core, once all expert reviewers have approved, the Final Review label is applied automatically and final reviewers are assigned.

For PRs outside megatron/core, this step is skipped.

Step 3: Approved

Once all required reviewers have approved, the Approved label is applied automatically.

Merge

Any member of mcore-engineers will be able to merge your PR.

@cspades cspades self-assigned this Jun 23, 2026
@copy-pr-bot

copy-pr-bot Bot commented Jun 23, 2026

Copy link
Copy Markdown

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

@cspades cspades force-pushed the cye/maxpool-dbuf branch from af4ad72 to b81af2c Compare June 23, 2026 22:21
@cspades cspades marked this pull request as ready for review June 25, 2026 01:24
@cspades cspades requested review from a team as code owners June 25, 2026 01:24
Comment thread megatron/core/distributed/fsdp/src/megatron_fsdp/param_and_grad_buffer.py Outdated
Comment on lines +967 to +970
def _build_fixed_max_pool(self):
"""
Compute the maximum double-buffer pool required to support all FSDP units.
"""

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The max pooling algorithm is here. The rest of the code is similar to FixedPoolAllocator.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do max pooling decisions depend on prefetching/overlapping? Conceptually, more aggressive prefetching needs more memory and therefore affects the max pooling algorithm?

@cspades cspades Jun 29, 2026

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Max-pooling is a hard-requirement on memory efficiency unless we have a more sophisticated FSDP scheduling algorithm to pre-compute all possible paths for this set of asymmetric model layers.

Once you consider the max pool the representative size of the unit in this model, pre-fetch sizing can be controlled using suggested_communication_unit_size which controls how many max pool buckets we pre-fetch. It defaults to 500M or 1B numel.

Finally, if you have performance requirements where the size of a max pool bucket lands you in an awkward position with the required communication size, and you don't want to increase the communication size, then the last thing you can do is to use fine-grained AG or fine-grained RS to allow for more checkpoints where we permit an AG or RS to be launched. This will allow multiple episodes of AG or RS to be called within a single FSDP unit.

Comment thread megatron/core/distributed/fsdp/mcore_fsdp_adapter.py
Comment on lines +210 to +218
if hasattr(torch.autograd.graph, 'set_override_stale_capture_stream'):
torch.autograd.graph.set_override_stale_capture_stream(True)
else:
logger.warning(
'torch.autograd.graph.set_override_stale_capture_stream is not '
'available in this PyTorch version; CUDA graph capture may fail '
'if autograd nodes hold stale references to non-capturing streams. '
'Upgrade to a PyTorch build that includes pytorch/pytorch#180090.'
)

@cspades cspades Jun 25, 2026

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should just be something that we should call if we have a new enough PyTorch version: pytorch/pytorch#180090 (The PyTorch version has not been published yet.)

It harmlessly makes things a lot easier w.r.t. stragglers on the Autograd / accumulate stream. cc @nanz-nv

@wujingyue wujingyue left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Deprecates --grad-reduce-in-bf16 / reduce_grad_in_fp32 for Megatron-FSDP, which has been incredibly confusing to use. Default arguments (auto) assume BF16 for both, so will not OOM any existing user's configs.
Adds a call to torch.autograd.graph.set_override_stale_capture_stream(True) (only supported on new PyTorch versions since pytorch/pytorch#180090) to prevent full-iteration CG errors like this:

Thanks for the PR and the figures!

While I'm still reviewing the rest, can these two changes go to a separate PR(s)? https://google.github.io/eng-practices/review/developer/small-cls.html

@cspades

cspades commented Jun 26, 2026

Copy link
Copy Markdown
Member Author

While I'm still reviewing the rest, can these two changes go to a separate PR(s)? https://google.github.io/eng-practices/review/developer/small-cls.html

@wujingyue Considering this exact commit needs to be merged for the NeMo release code freeze in a few days, could we make an exception in this case? These three features are all needed for Nemotron benchmarks. I'm concerned that waiting on 3 PR's to be merged in a few work days is not feasible.

Comment thread examples/megatron_fsdp/train_llama3_8b_fsdp_h100_fp8.sh Outdated
Comment thread examples/megatron_fsdp/train_llama3_8b_fsdp_h100_fp8.sh
Comment thread megatron/core/distributed/fsdp/src/megatron_fsdp/param_and_grad_buffer.py Outdated
@wujingyue

Copy link
Copy Markdown
Contributor

I'm concerned that waiting on 3 PR's to be merged in a few work days is not feasible.

In my experience, reviewing three stacked PRs is usually faster than reviewing a single large PR. Stacked PRs can also be reviewed in parallel, though I may be missing something about how the review process works in Megatron-LM.

As a less ideal alternative, you could keep everything in a single PR but split it into three well-structured commits. GitHub's UI supports reviewing commits individually, which provides a similar incremental review experience.

Comment thread megatron/core/distributed/fsdp/src/megatron_fsdp/param_and_grad_buffer.py Outdated
Comment thread megatron/core/distributed/fsdp/src/megatron_fsdp/param_and_grad_buffer.py Outdated
Comment thread megatron/core/distributed/fsdp/src/megatron_fsdp/param_and_grad_buffer.py Outdated
Comment on lines +967 to +970
def _build_fixed_max_pool(self):
"""
Compute the maximum double-buffer pool required to support all FSDP units.
"""

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do max pooling decisions depend on prefetching/overlapping? Conceptually, more aggressive prefetching needs more memory and therefore affects the max pooling algorithm?

@wujingyue wujingyue left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM otherwise

Comment thread megatron/core/distributed/fsdp/src/megatron_fsdp/param_and_grad_buffer.py Outdated
# If more buckets are needed for this unit, extend the pool with 0's.
if len(bucket_sizes) > len(max_bucket_sizes):
extend_len = len(bucket_sizes) - len(max_bucket_sizes)
max_bucket_sizes.extend([0] * extend_len)

@wujingyue wujingyue Jun 28, 2026

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't max_bucket_sizes already sorted so we can prepend 0s without having to sort max_bucket_sizes again?

@cspades cspades Jun 29, 2026

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, we have already assigned the previous bucket ID's and I'm using the enumerated index of this list as a bucket offset. If I prepend, it will shift all the buckets to the right by one relative to their bucket offset, and break this algorithm.

sorted(enumerate(max_bucket_sizes), key=lambda x: x[1])

We can avoid this by reversing the zip, adding the new buckets to the end of the pool but getting the largest N buckets from the top of the pool and assigning them to the largest N buckets of the unit (so also bucket_sizes.sort() -> bucket_sizes.sort(reverse=True). I think that should preserve a reversed sorting order.

if ddp_config.grad_reduce_in_fp32
else ddp_config.megatron_fsdp_grad_comm_dtype
),
main_grads_dtype=ddp_config.megatron_fsdp_main_grads_dtype,

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@gautham-kollu , can you update bridge?

FYI @yaoyu-33 and @cuichenx

@cspades cspades Jul 1, 2026

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The migration action items are:

--megatron-fsdp-main-grads-dtype fp32 / ddp.megatron_fsdp_main_grads_dtype=torch.float32
--megatron-fsdp-grad-comm-dtype fp32 / ddp.megatron_fsdp_grad_comm_dtype=torch.float32

for any recipe that uses grad_reduce_in_fp32=True (i.e. does not use --grad-reduce-in-bf16).

For completeness, if --grad-reduce-in-bf16 / grad_reduce_in_fp32=False, then the default megatron_fsdp_grad_comm_dtype and megatron_fsdp_main_grads_dtype are both BF16 so that's also aligned with turning that argument on and does not need any changes. (This is the logical spaghetti I was talking about, two levels of arguments.)

cc @gautham-kollu if you can hit this in your next benchmark update. 🙏🏻 IMO low-ish priority because this will not OOM anyone's script.

@ericharper ericharper left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Approve but bridge needs to be updated.

if ddp_config.grad_reduce_in_fp32
else ddp_config.megatron_fsdp_grad_comm_dtype
),
main_grads_dtype=ddp_config.megatron_fsdp_main_grads_dtype,

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@gautham-kollu , can you update bridge?

FYI @yaoyu-33 and @cuichenx

@svcnvidia-nemo-ci

Copy link
Copy Markdown

🔄 Merge queue validation started!

You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/28484555475

@svcnvidia-nemo-ci

Copy link
Copy Markdown

🔄 Merge queue validation started!

You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/28485237545

@svcnvidia-nemo-ci

Copy link
Copy Markdown

🔄 Merge queue validation started!

You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/28532337643

cspades added 11 commits July 1, 2026 16:14
Signed-off-by: Cory Ye <cye@nvidia.com>
Signed-off-by: Cory Ye <cye@nvidia.com>
… later, and grad_comm_dtype not respected during FixedPool/MaxPool bucket planning.

Signed-off-by: Cory Ye <cye@nvidia.com>
Signed-off-by: Cory Ye <cye@nvidia.com>
…ction.

Signed-off-by: Cory Ye <cye@nvidia.com>
Signed-off-by: Cory Ye <cye@nvidia.com>
Signed-off-by: Cory Ye <cye@nvidia.com>
…nits.

Signed-off-by: Cory Ye <cye@nvidia.com>
Signed-off-by: Cory Ye <cye@nvidia.com>
@svcnvidia-nemo-ci

Copy link
Copy Markdown

🔄 Merge queue validation started!

You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/28556300115

@svcnvidia-nemo-ci

Copy link
Copy Markdown

🔄 Merge queue validation started!

You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/28557199610

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

8 participants