
[Dev][FSDP] TE Partial CUDA Graph + Megatron-FSDP compatibility#4231

Draft
buptzyb wants to merge 1 commit into NVIDIA:dev from buptzyb:fsdp_cudagraph

Conversation

Contributor

@buptzyb buptzyb commented Apr 9, 2026

What does this PR do?

Main PR: #4232. Depends on TE PR NVIDIA/TransformerEngine#2831.

Summary

Enables TransformerEngine's partial (layerwise) CUDA Graph to work correctly with
Megatron-FSDP. Prior to this PR, FSDP's forward hooks that wrap backward handlers
were passed directly to TE's make_graphed_callables, which would capture them at
the wrong phase, causing numerical divergence and illegal memory access during
graph replay.

Changes

Core (megatron/core/transformer/cuda_graphs.py)

  • Two-phase hook extraction before CUDA graph capture:
    • Phase 1 (_extract_module_hooks): copies all 4 PyTorch hook dicts uniformly and
      clears them from the module.
    • Phase 2 (_apply_fsdp_hook_transforms): detects FSDP-specific hook wrappers via
      sentinel attributes (_cuda_graph_backward_handler,
      _cuda_graph_backward_pre_handler) and reroutes their inner backward handlers into
      the correct TE-facing key, withholding the wrappers from TE entirely.
  • *_restore keys are populated independently in Phase 1 and never touched in Phase 2,
    ensuring clean teardown after graph deletion.

FSDP (megatron/core/distributed/fsdp/src/megatron_fsdp/megatron_fsdp.py)

  • Tags FSDP forward hooks that wrap backward handlers with the sentinel attributes read
    by cuda_graphs.py.
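A hedged sketch of what that tagging looks like on the FSDP side: `make_forward_hook` and its body are invented for illustration; only the sentinel attribute name `_cuda_graph_backward_handler` is taken from this PR.

```python
# Illustrative sketch of tagging an FSDP forward hook with the sentinel
# attribute that cuda_graphs.py reads. All names except the sentinel
# attribute are hypothetical.

def make_forward_hook(backward_handler):
    def forward_hook(module, args, output):
        # In real Megatron-FSDP this wrapper arranges for backward_handler
        # to fire during the backward pass (elided here); the hook itself
        # passes the output through unchanged.
        return output

    # The sentinel lets cuda_graphs.py recover the inner backward handler
    # and withhold this wrapper from TE's make_graphed_callables.
    forward_hook._cuda_graph_backward_handler = backward_handler
    return forward_hook
```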

Argument validation (megatron/training/arguments.py)

  • When using CUDA Graph with Megatron-FSDP (any non-full_iteration scope), require:
    • --fsdp-double-buffer (prevents dynamic buffer addresses across iterations)
    • --fsdp-db-use-persist-buf-on-alloc-fail (prevents failed allocations falling back
      to dynamic buffers during graph replay, which would cause illegal memory access)

Tests

Unit test (tests/unit_tests/distributed/megatron_fsdp/)

  • test_cudagraph_alignment_with_fsdp: verifies Δloss == 0 (bit-for-bit) between
    eager FSDP and FSDP+CUDA graph across:
    • Parallelism configs: default (DP=8), TP=2, EP=2+ETP=2
    • CUDA graph scopes: attn, attn+moe_router+moe_preprocess, moe_router
  • utils.py: consolidated training defaults into base_args so callers only need to
    pass overrides.
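The core assertion of the alignment test can be sketched as below. `run_training` is a hypothetical stand-in for the real harness in `tests/unit_tests/distributed/megatron_fsdp/`; the point is that the comparison is exact equality, not a tolerance check.

```python
# Sketch of the bit-for-bit loss comparison the unit test performs.
# run_training is a hypothetical callable returning per-step loss values.

def assert_cudagraph_matches_eager(run_training, base_args, scope):
    eager_losses = run_training(base_args, cuda_graph_scope=None)
    graph_losses = run_training(base_args, cuda_graph_scope=scope)
    assert len(eager_losses) == len(graph_losses)
    for step, (eager, graphed) in enumerate(zip(eager_losses, graph_losses)):
        # Bit-for-bit: Δloss must be exactly zero, no tolerance allowed.
        assert eager == graphed, f"Δloss != 0 at step {step}"
```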

⚠️ For major changes (either in lines of code or in its impact), please make sure to first share a design doc with the team. If you're unsure what's the best way to do so, contact the @mcore-oncall.

Contribution process

Pre-checks

  • I have added relevant unit tests
  • I have added relevant functional tests
  • I have added proper typing to my code Typing guidelines
  • I have added relevant documentation
  • I have run the autoformatter.sh on my PR

Code review

Feel free to message or comment the @mcore-oncall to help accelerate your merge into main. The less complex your PR is, the faster it will be approved and merged!

All PRs start as draft. If you open a non-draft PR, it will be automatically converted to draft.

Step 1: Mark PR as "Ready for Review"

  1. When your PR is ready, click Ready for Review.
  2. An oncall reviewer is auto-assigned and expert reviewers are notified based on your changes.
    • Some PRs may jump straight to step 2. This is determined by .github/CODEOWNERS.

⚠️ Only mark as ready once merge-conflicts are resolved and the CI is passing.
Final Review might get declined if these requirements are not fulfilled.

Step 2: Final Review

For PRs that change megatron/core, once all expert reviewers have approved, the Final Review label is applied automatically and final reviewers are assigned.

For PRs outside megatron/core, this step is skipped.

Step 3: Approved

Once all required reviewers have approved, the Approved label is applied automatically.

Merge

Any member of mcore-engineers will be able to merge your PR.

For MRs into the `dev` branch: the proposed review process for `dev` is under active discussion.

MRs are mergeable after one approval by either eharper@nvidia.com or zijiey@nvidia.com.

Signed-off-by: Robin Zhang <robinz@nvidia.com>

copy-pr-bot bot commented Apr 9, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.
