[PyTorch] Fix CP A2A F16 when NVTE_FP8_DPA_BWD=1 #2917

Merged: cyanguwa merged 1 commit into NVIDIA:main from cyanguwa:fix_2719 on Apr 23, 2026

cyanguwa (Collaborator) commented Apr 22, 2026

Description

This PR fixes an issue with CP A2A when NVTE_FP8_DPA_BWD=1 (the default) but fp8_dpa=False. The bug was introduced by the CP refactoring in #2719. The unit tests did not catch it because every test explicitly sets NVTE_FP8_DPA_BWD to an appropriate value (0 or 1), whereas in practice users may leave it unset. We will not change that test logic here, but will consider adding more tests for recipe control in the future.

The bug fix itself needs only a one-line change in A2A, but this PR also cleans up a few other places in context_parallel.py.
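The gap between the tests and real usage can be illustrated with a minimal, standard-library-only sketch. It is not the code in context_parallel.py: `is_bwd_fp8_ungated`, `is_bwd_fp8_gated`, and `evaluate` are hypothetical names, `fp8` stands in for the fp8_dpa setting, and the default of `1` is taken from the description above.

```python
import os
from typing import Optional, Tuple
from unittest import mock

ENV_VAR = "NVTE_FP8_DPA_BWD"


def is_bwd_fp8_ungated() -> bool:
    """Pre-fix behaviour as described above: the env var alone decides."""
    # Default "1" mirrors the documented default of NVTE_FP8_DPA_BWD.
    return bool(int(os.getenv(ENV_VAR, "1")))


def is_bwd_fp8_gated(fp8: bool) -> bool:
    """Post-fix behaviour: the env var only matters when FP8 attention is on."""
    return fp8 and bool(int(os.getenv(ENV_VAR, "1")))


def evaluate(env_value: Optional[str], fp8: bool) -> Tuple[bool, bool]:
    """Return (ungated, gated) with ENV_VAR set to env_value, or unset if None."""
    env = {k: v for k, v in os.environ.items() if k != ENV_VAR}
    if env_value is not None:
        env[ENV_VAR] = env_value
    # clear=True swaps in `env` for the duration of the block,
    # then restores the original os.environ on exit.
    with mock.patch.dict(os.environ, env, clear=True):
        return is_bwd_fp8_ungated(), is_bwd_fp8_gated(fp8)


if __name__ == "__main__":
    # Include the "unset" case (None), which the existing tests never exercise.
    for env_value in (None, "0", "1"):
        for fp8 in (False, True):
            print(env_value, fp8, evaluate(env_value, fp8))
```

The case of interest is `evaluate(None, False)`: with the variable unset and FP8 attention disabled, the ungated flag is still true while the gated one is false.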

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

See Description.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
@cyanguwa cyanguwa marked this pull request as ready for review April 22, 2026 19:13
@cyanguwa cyanguwa requested a review from ksivaman April 22, 2026 19:13
cyanguwa (Collaborator, Author) commented:

/te-ci torch L1

greptile-apps (Bot, Contributor) commented Apr 22, 2026

Greptile Summary

This PR fixes a bug in CP A2A (AttnFuncWithCPAndQKVOA2A) where is_bwd_fp8 was set to the raw NVTE_FP8_DPA_BWD env-var integer (default 1) without gating on the fp8 flag. When fp8=False but the env var defaulted to 1, bwd_requires_o_fp8 would be True and bwd_requires_o_f16 would be False, so out_part was quantized to FP8 yet saved as an F16 tensor in the backward context, corrupting gradient computation. The fix (is_bwd_fp8 = fp8 and _use_fp8_dpa_bwd) gates the env var on the fp8 flag in all three CP variants (P2P, AllGather, A2A), and the follow-on cleanups to the ctx.fp8 and out_fp8 conditions are consistent with it.
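The mismatch described above can be modelled with a toy example that uses plain strings instead of tensors. This is only a sketch under stated assumptions: `SavedOutput`, `save_out_part`, and `non_fp8_run` are hypothetical names, the two requirement flags are treated as exact complements, and the storage slot is assumed to follow ctx.fp8. The real logic in context_parallel.py may depend on more inputs.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class SavedOutput:
    slot: str     # list the tensor is stored under: "f16_tensors" or "fp8_tensors"
    payload: str  # what the tensor actually holds: "f16" or "fp8"

    @property
    def consistent(self) -> bool:
        # The slot name should begin with the payload format it holds.
        return self.slot.startswith(self.payload)


def save_out_part(ctx_fp8: bool, is_bwd_fp8) -> SavedOutput:
    """Toy model of how the output is prepared and stored for backward."""
    # Simplifying assumption: the two requirement flags are exact complements.
    bwd_requires_o_fp8 = bool(is_bwd_fp8)
    bwd_requires_o_f16 = not bwd_requires_o_fp8
    payload = "f16" if bwd_requires_o_f16 else "fp8"
    # Assumption: the storage slot follows ctx.fp8.
    slot = "fp8_tensors" if ctx_fp8 else "f16_tensors"
    return SavedOutput(slot=slot, payload=payload)


def non_fp8_run(env_flag: int, gated: bool) -> SavedOutput:
    """Model a run with fp8=False, with or without gating the env flag on fp8."""
    fp8 = False
    is_bwd_fp8 = (fp8 and bool(env_flag)) if gated else env_flag
    return save_out_part(ctx_fp8=fp8, is_bwd_fp8=is_bwd_fp8)


if __name__ == "__main__":
    print("ungated:", non_fp8_run(env_flag=1, gated=False))
    print("gated:  ", non_fp8_run(env_flag=1, gated=True))
```

In this model the ungated run stores an FP8 payload in the F16 slot, which is the inconsistency the summary attributes to the pre-fix code, while the gated run keeps slot and payload aligned.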

Confidence Score: 5/5

This PR is safe to merge; the fix is targeted, correct, and the cleanup changes are logically consistent.

The root cause is well-diagnosed: raw env-var int (1) used as a boolean without the fp8 gate caused FP8 quantization of the output tensor in a non-FP8 path, leading to corrupt backward tensors. The one-line fix and consistent cleanup across all three CP classes are correct. No new edge cases are introduced.

No files require special attention.

Important Files Changed

Filename: transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py
Overview: Bug fix: is_bwd_fp8 now correctly includes the fp8 flag across all three CP classes (P2P, AllGather, A2A); also cleans up ctx.fp8 and out_fp8 condition redundancies introduced in #2719.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["CP Forward (A2A/P2P/AllGather)"] --> B{"fp8?"}
    B -- No --> C["is_bwd_fp8 = False\n(BEFORE: 1 from env var)"]
    B -- Yes --> D{"NVTE_FP8_DPA_BWD?"}
    D -- "0" --> E["is_bwd_fp8 = False"]
    D -- "1 (default)" --> F["is_bwd_fp8 = True"]
    C --> G["bwd_requires_o_fp8 = False\nbwd_requires_o_f16 = True\nout_part = F16 tensor ✅"]
    E --> G
    F --> H["bwd_requires_o_fp8 = True\nout_part = FP8 tensor ✅"]
    G --> I["ctx.fp8 = False\nf16_tensors saved ✅"]
    H --> J["ctx.fp8 = True\nfp8_tensors saved ✅"]

Reviews (1): Last reviewed commit: "fix fp8 and is_bwd_fp8 relationship"

@cyanguwa cyanguwa requested a review from ptrendx April 22, 2026 22:43
@cyanguwa cyanguwa merged commit 424b031 into NVIDIA:main Apr 23, 2026
46 of 53 checks passed
YigongQin pushed a commit to YigongQin/TransformerEngine that referenced this pull request Apr 23, 2026
fix fp8 and is_bwd_fp8 relationship

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
@ptrendx ptrendx added this to the 2.15 milestone Apr 23, 2026
KshitijLakhani pushed a commit that referenced this pull request Apr 27, 2026
fix fp8 and is_bwd_fp8 relationship

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
faradawn pushed a commit to faradawn/TransformerEngine that referenced this pull request May 14, 2026
fix fp8 and is_bwd_fp8 relationship

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
3 participants