[PyTorch] Fix CP A2A F16 when NVTE_FP8_DPA_BWD=1 #2917
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
/te-ci torch L1
Greptile Summary
This PR fixes a bug in CP A2A (
Confidence Score: 5/5
This PR is safe to merge; the fix is targeted and correct, and the cleanup changes are logically consistent. The root cause is well diagnosed: the raw env-var int (1) was used as a boolean without the fp8 gate, which caused FP8 quantization of the output tensor in a non-FP8 path and led to corrupt backward tensors. The one-line fix and the consistent cleanup across all three CP classes are correct. No new edge cases are introduced. No files require special attention.
Important Files Changed
Flowchart
%%{init: {'theme': 'neutral'}}%%
flowchart TD
A["CP Forward (A2A/P2P/AllGather)"] --> B{"fp8?"}
B -- No --> C["is_bwd_fp8 = False\n(BEFORE: 1 from env var)"]
B -- Yes --> D{"NVTE_FP8_DPA_BWD?"}
D -- "0" --> E["is_bwd_fp8 = False"]
D -- "1 (default)" --> F["is_bwd_fp8 = True"]
C --> G["bwd_requires_o_fp8 = False\nbwd_requires_o_f16 = True\nout_part = F16 tensor ✅"]
E --> G
F --> H["bwd_requires_o_fp8 = True\nout_part = FP8 tensor ✅"]
G --> I["ctx.fp8 = False\nf16_tensors saved ✅"]
H --> J["ctx.fp8 = True\nfp8_tensors saved ✅"]
Reviews (1): Last reviewed commit: "fix fp8 and is_bwd_fp8 relationship"
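The decision logic in the flowchart can be sketched in plain Python. This is only an illustration under stated assumptions, not the code in context_parallel.py: the helper name `resolve_bwd_precision` and its return shape are hypothetical, and the environment is passed in as a dict so the example does not depend on process state. Only the variable name NVTE_FP8_DPA_BWD, its default of 1, and the fp8 gate come from the PR.

```python
def resolve_bwd_precision(fp8, env):
    """Hypothetical helper mirroring the flowchart; not Transformer Engine's code."""
    # NVTE_FP8_DPA_BWD is treated as 1 when unset, as the flowchart states.
    env_flag = int(env.get("NVTE_FP8_DPA_BWD", "1"))
    # The fp8 gate: the env var only matters when the forward pass is FP8.
    is_bwd_fp8 = bool(fp8) and env_flag == 1
    # Which flavor of the output tensor gets saved for backward.
    saved_output = "fp8" if is_bwd_fp8 else "f16"
    return is_bwd_fp8, saved_output


# Walk the four leaves of the flowchart.
for fp8 in (False, True):
    for value in ("0", "1"):
        result = resolve_bwd_precision(fp8, {"NVTE_FP8_DPA_BWD": value})
        print(f"fp8={fp8} NVTE_FP8_DPA_BWD={value} -> {result}")
```

Only the combination fp8=True with NVTE_FP8_DPA_BWD=1 (or unset) selects the FP8 branch; every other combination saves F16 tensors.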
fix fp8 and is_bwd_fp8 relationship Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Description
This PR fixes an issue with CP A2A when NVTE_FP8_DPA_BWD=1 (default) but fp8_dpa=False. The bug comes from some refactoring work in #2719 for CP. The unit tests didn't catch this because every test explicitly sets NVTE_FP8_DPA_BWD to an appropriate value (0 or 1), whereas real users may leave it unset. We will not change that test logic but will consider adding more tests regarding the recipe control in the future.

The necessary change for the bug fix is only one line in A2A, but this PR also cleans up a few other places in context_parallel.py.
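To make the test gap concrete, here is a rough sketch of a check that leaves the variable unset, the way a real user might. The two helper functions are hypothetical stand-ins for the behavior described in this PR (an ungated read of the env var versus a read gated on fp8); they are not taken from context_parallel.py, and the existing test suite may be structured quite differently.

```python
import os
from unittest import mock


def ungated_is_bwd_fp8(fp8):
    # Stand-in for the pre-fix behavior: the raw env-var int, fp8 ignored.
    return int(os.getenv("NVTE_FP8_DPA_BWD", "1"))


def gated_is_bwd_fp8(fp8):
    # Stand-in for the fixed behavior: the env var only counts when fp8 is on.
    return fp8 and bool(int(os.getenv("NVTE_FP8_DPA_BWD", "1")))


# patch.dict restores os.environ on exit, so removing the key here is safe.
with mock.patch.dict(os.environ):
    os.environ.pop("NVTE_FP8_DPA_BWD", None)  # simulate a user who never sets it
    before = ungated_is_bwd_fp8(fp8=False)
    after = gated_is_bwd_fp8(fp8=False)

print("ungated:", before, "gated:", after)
```

With the variable unset and fp8 disabled, the ungated read is truthy while the gated read is not, which is the situation a test that always sets NVTE_FP8_DPA_BWD to match the recipe would never exercise.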
Type of change
Changes
See Description.
Checklist: