Add and verify support for deterministic fp8 dpa/mha on SM100
#2621
base: main
Conversation
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
for more information, see https://pre-commit.ci
/te-ci pytorch L1
Greptile Summary
Extends deterministic FP8 attention support from FP16/BF16 (PR #2584) to FP8 data types on the SM100 (Blackwell) architecture. The implementation threads the deterministic argument through fused_attn_fp8.cu. Key changes: the backend filter in pytorch/attention/dot_product_attention/utils.py allows fp8 + deterministic kernels on SM100 only, and test_attention.py gains fp8 tests with deterministic=True.
The version check at cudnn_runtime_version >= 91900 is flagged in the inline comment below (cuDNN 9.19.0 vs. the 9.18.1+ requirement cited for PR #2584).
Confidence Score: 4/5
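To make the summary's SM100 filter concrete, here is a minimal sketch of the gating predicate. It is written in C++ purely for illustration; the actual check lives in the Python file pytorch/attention/dot_product_attention/utils.py, and the function name and signature below are hypothetical.

```cpp
// Hypothetical sketch of the SM100 gate described in the summary; the real
// filter is implemented in utils.py (Python), not in C++.
#include <cstdio>

// Deterministic FP8 fused attention is allowed only on SM100+ (Blackwell),
// i.e. compute capability >= 10.0.
static bool allow_deterministic_fp8_fused_attn(int cc_major) {
  return cc_major >= 10;
}

int main() {
  std::printf("SM90  -> %d\n", allow_deterministic_fp8_fused_attn(9));   // 0: filtered out
  std::printf("SM100 -> %d\n", allow_deterministic_fp8_fused_attn(10));  // 1: allowed
  return 0;
}
```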
Important Files Changed
Sequence Diagram
sequenceDiagram
participant Test as test_mha_fp8_vs_f16/<br/>test_dpa_fp8_vs_f16
participant Utils as utils.py:<br/>get_attention_backend
participant CPP as fused_attn.cpp:<br/>nvte_fused_attn_bwd*
participant FP8 as fused_attn_fp8.cu:<br/>fused_attn_fp8_bwd
participant Impl as fused_attn_fp8.cu:<br/>fused_attn_fp8_bwd_impl_v1
participant cuDNN as cuDNN Backend
Note over Test: deterministic param<br/>added to tests
Test->>Utils: check backend with<br/>deterministic flag
Note over Utils: Filter: Allow FP8+deterministic<br/>only on SM100+ (arch >= 10.0)
Utils-->>Test: backend available
Test->>CPP: nvte_fused_attn_bwd_*<br/>(deterministic)
Note over CPP: QKV packed/<br/>KV packed/<br/>separate paths
CPP->>FP8: fused_attn_fp8_bwd<br/>(deterministic)
FP8->>Impl: fused_attn_fp8_bwd_impl_v1<br/>(deterministic)
Note over Impl: Check cuDNN version
alt cudnn_runtime_version >= 91900
Impl->>cuDNN: set_deterministic_algorithm<br/>(deterministic)
else version < 91900
Note over Impl: deterministic flag ignored
end
Impl->>cuDNN: execute backward pass
cuDNN-->>Impl: gradients
Impl-->>FP8: return
FP8-->>CPP: return
CPP-->>Test: return
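The diagram above can be summarized as a simple call chain. The sketch below mirrors it in C++: the function names follow the files named in the diagram, but the signatures are simplified stand-ins, not the actual NVTE prototypes, and SdpaBackwardOptions is a placeholder for the cudnn_frontend attributes object.

```cpp
// Illustrative-only sketch of the deterministic-flag threading shown above.
namespace sketch {

struct SdpaBackwardOptions {
  bool deterministic = false;
  // Stand-in for cudnn_frontend's SDPA backward attributes; the real code
  // calls sdpa_backward_options.set_deterministic_algorithm(deterministic).
  void set_deterministic_algorithm(bool d) { deterministic = d; }
};

// fused_attn_fp8.cu: the v1 implementation gates the option on cuDNN version.
void fused_attn_fp8_bwd_impl_v1(int cudnn_runtime_version, bool deterministic) {
  SdpaBackwardOptions sdpa_backward_options;
  if (cudnn_runtime_version >= 91900) {  // threshold questioned in the review below
    sdpa_backward_options.set_deterministic_algorithm(deterministic);
  }
  // ... build and execute the cuDNN backward graph, producing gradients ...
}

// fused_attn_fp8.cu: public FP8 backward entry point forwards the flag.
void fused_attn_fp8_bwd(int cudnn_runtime_version, bool deterministic) {
  fused_attn_fp8_bwd_impl_v1(cudnn_runtime_version, deterministic);
}

// fused_attn.cpp: the nvte_fused_attn_bwd* paths (QKV-packed / KV-packed /
// separate) all pass `deterministic` down to the FP8 backward.
void nvte_fused_attn_bwd(int cudnn_runtime_version, bool deterministic) {
  fused_attn_fp8_bwd(cudnn_runtime_version, deterministic);
}

}  // namespace sketch

int main() {
  sketch::nvte_fused_attn_bwd(/*cudnn_runtime_version=*/91900, /*deterministic=*/true);
  return 0;
}
```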
5 files reviewed, 1 comment
if (cudnn_runtime_version >= 91900) {
  sdpa_backward_options.set_deterministic_algorithm(deterministic);
}
logic: Version check uses 91900 (cuDNN 9.19.0), but related PR #2584 and description mention 9.18.1+ requirement. Should this be 91810 instead?
Suggested change:
- if (cudnn_runtime_version >= 91900) {
+ if (cudnn_runtime_version >= 91810) {
    sdpa_backward_options.set_deterministic_algorithm(deterministic);
  }
Is there a specific reason FP8 requires cuDNN 9.19.0+ while FP16/BF16 only needs 9.18.1+?
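As background for this question, and assuming cudnn_runtime_version follows the usual cuDNN 9.x encoding (CUDNN_MAJOR * 10000 + CUDNN_MINOR * 100 + CUDNN_PATCHLEVEL), the thresholds decode as shown in this small compile-time check; the helper name is illustrative only.

```cpp
// Hedged aside: this assumes the standard cuDNN 9.x version encoding; how TE
// actually computes cudnn_runtime_version should be confirmed in the source.
constexpr int encode_cudnn_version(int major, int minor, int patch) {
  return major * 10000 + minor * 100 + patch;
}
static_assert(encode_cudnn_version(9, 19, 0) == 91900, "9.19.0 -> 91900");
static_assert(encode_cudnn_version(9, 18, 1) == 91801, "9.18.1 -> 91801");
```

Under that assumption, the suggested 91810 would correspond to 9.18.10 rather than 9.18.1, so the exact threshold may be worth double-checking against how TE derives cudnn_runtime_version.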
Description
Follow-up to #2584 to add and verify support for "deterministic" fp8 dpa/mha cuDNN attention kernels.
Type of change
Changes
Please list the changes introduced in this PR:
- Pass the deterministic argument through fused_attn_fp8.cu
- Update pytorch/attention/dot_product_attention/utils.py to allow fp8 + deterministic kernels on SM100
- Update test_attention.py to check fp8 with deterministic=True

Checklist: