[PyT] [Common] Enable sm120 support for fused attn if cuDNN is 9.18.1+ (#2693)
Conversation
Greptile Summary

This PR enables SM120 (Blackwell) support for fused attention with THD (packed/ragged) sequence formats when cuDNN ≥ 9.18.1 is present. The core insight is that cuDNN on SM120 rejects the "token-count" dimensional layout used on earlier architectures and instead requires BHSD-like dimensions.

Key changes:
Confidence Score: 4/5
Important Files Changed
Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[nvte_get_fused_attn_backend] --> B{sm_arch == 120?}
B -- No --> C[Existing backend selection logic]
B -- Yes --> D{cudnn_runtime_version >= 91801?}
D -- No --> E[Return NVTE_No_Backend\n warn: upgrade cuDNN]
D -- Yes --> F{qkv_layout is T3HD or TH3D?}
F -- Yes --> G[Return NVTE_No_Backend\n warn: unsupported layout]
F -- No --> H[Return NVTE_F16_arbitrary_seqlen]
H --> I[fused_attn_arbitrary_seqlen_fwd_impl]
I --> J{sm_arch == 120?}
J -- No --> K[b = max_b\n s_q = max_t_q\n s_kv = max_t_kv\n use_ragged_stats = true]
J -- Yes --> L[b/s_q/s_kv unchanged\n use_ragged_stats = false]
K --> M[Stats shape: TH1 packed\n ragged offset on stats]
L --> N[Stats shape: BHS1\n no ragged offset on stats\n ragged offset on Q K V O only]
I --> O[fused_attn_fwd Python wrapper]
O --> P{qkv_format=thd AND max_tensor.ndim==4?}
P -- Yes --> Q[Mask padded positions to -inf\n before computing max_logit]
P -- No --> R[Compute max_logit directly]
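The padded-token masking in the last branch above can be illustrated with a small sketch (plain Python; the function name and list-based tensors are illustrative stand-ins, not the TE implementation):

```python
def masked_max_logit(logits, seqlens, fill=float("-inf")):
    """Per-sequence max logits that ignore padded positions.

    logits:  per-sequence rows of logits, padded to a common length
    seqlens: actual (unpadded) length of each sequence
    """
    maxes = []
    for row, n in zip(logits, seqlens):
        # Set padded positions to -inf so they never win the max,
        # mirroring "mask padded positions to -inf before max_logit".
        masked = [v if i < n else fill for i, v in enumerate(row)]
        maxes.append(max(masked))
    return maxes
```

Without the masking step, stale values in padded slots could dominate the max, which is the incorrect max-logit behavior this PR avoids for THD inputs.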
Last reviewed commit: "[pre-commit.ci] auto..."
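The selection branch in the flowchart can be sketched as a hypothetical Python mirror of the C++ logic (backend names are plain strings here purely for illustration; this is not the real TE function):

```python
NO_BACKEND = "NVTE_No_Backend"
F16_ARBITRARY_SEQLEN = "NVTE_F16_arbitrary_seqlen"

def select_backend_sm120(sm_arch, cudnn_runtime_version, qkv_layout):
    """Illustrative mirror of the SM120 branch in the flowchart."""
    if sm_arch != 120:
        return None  # fall through to the existing backend selection logic
    if cudnn_runtime_version < 91801:
        # warn: upgrade cuDNN to 9.18.1+ for fused attention on SM120
        return NO_BACKEND
    if qkv_layout in ("t3hd", "th3d"):
        # warn: interleaved T3HD/TH3D layouts are unsupported on SM120
        return NO_BACKEND
    return F16_ARBITRARY_SEQLEN
```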
NVTE_ERROR(
    "T3HD and TH3D QKV layouts are not supported by cuDNN on SM120 "
    "Use thd_thd_thd or other THD layouts instead.");
Missing period in forward error message
The forward error message is missing a period that is present in the corresponding backward error message (line 748). Minor inconsistency but worth fixing for uniformity.
 NVTE_ERROR(
-    "T3HD and TH3D QKV layouts are not supported by cuDNN on SM120 "
+    "T3HD and TH3D QKV layouts are not supported by cuDNN on SM120. "
     "Use thd_thd_thd or other THD layouts instead.");
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
/te-ci L0 L1

/te-ci L0 L1
NVTE_ERROR(
    "T3HD and TH3D QKV layouts are not supported by cuDNN on SM120. "
    "Use thd_thd_thd or other THD layouts instead.");
}
Should we add this logic to nvte_get_fused_backend(), so we don't error here but instead report it as a "not supported" case?
My initial thought process was that this use case was:

1. To disable fused attn in the Python layer (PyT specific, here), as we already have a check for sm120 for cuDNN version < 9.18.1, and hence I added it in here: https://github.com/KshitijLakhani/TransformerEngine/blob/bcfef909681c24a95163ecf987fbf952a4f4eb4a/transformer_engine/pytorch/attention/dot_product_attention/utils.py#L705 (already in this PR)
2. To not allow a call to cuDNN kernels from the C++ (common) layer, so that TE can produce an error for the user with an easier-to-understand message rather than a difficult-to-understand low-level message (already in this PR; these are the lines of interest for this comment)

If I understand right, you are suggesting we let #1 stay and replace #2 with a call to set the backend to NVTE_No_Backend in nvte_get_fused_attention_backend() for cuDNN >= 9.18.1 + T3HD/TH3D + sm120 + NVTE_F16_arbitrary_seqlen.

Sounds okay?
I feel the sm120 check, the T3HD/TH3D check, and the cuDNN version check can all live in nvte_get_fused_attention_backend() so it produces NVTE_No_Backend. Sometimes we put checks in utils.py because they are not easy to do on the C side, like with the KV cache feature. But the logic you have in utils.py and here can probably go into nvte_get_fused_attention_backend?
Also, greptile seems to say there are duplicate device queries (device_id/sm_arch_) in your code (3 of them). Can you check if it's true?
> Also, greptile seems to say there are duplicate device queries (device_id/sm_arch_) in your code (3 of them). Can you check if it's true?
I investigated this. It seems I've added two new calls to cuda::current_device() in fused_attn_f16_arbitrary_seqlen.cu:

- fused_attn_arbitrary_seqlen_fwd()
- fused_attn_arbitrary_seqlen_fwd_impl()

There was one from before:

- fused_attn_arbitrary_seqlen_bwd_impl()

I could pass the sm_arch from fused_attn_arbitrary_seqlen_fwd() to fused_attn_arbitrary_seqlen_fwd_impl(), but that would just add an additional arg. Do you think we should consolidate it in this way?
Moved the T3HD/TH3D check to nvte_get_fused_attention_backend()
Re-ran all tests on sm120 successfully @cyanguwa
 output_S->data.dptr = nullptr;
-if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
+if ((q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) &&
+    !(sm_arch_ >= 120)) {
I wonder if we should do sm_arch_ != 120 instead. I feel feature support across our SM numbers is not monotonic. I made the mistake of doing > sm100 sometimes, but then sm103 and sm120 had different support matrices.
I agree. Will push a commit for it.
Made this change everywhere in the PR
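The exact-match guard agreed on above can be sketched as a hypothetical Python rendering of the C++ condition (the function name is illustrative):

```python
def uses_packed_softmax_stats(sm_arch, cudnn_runtime_version):
    # Exact match on 120: SM120 opts out of packed (THD-style) softmax
    # stats, but later architectures may not, so avoid `sm_arch >= 120`.
    return cudnn_runtime_version >= 90600 and sm_arch != 120
```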
/te-ci L0 L1
PyT and common jobs had passed in previous CI runs.
#2693: squashed commit history —

* Enable sm120 support for fused attn if cuDNN is 9.18.1+
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* Force intermediate tensors such as S, Sum_Exp, and Max to be BHS1 shape instead of TH1 for sm120
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* Add support for sm120 correct batch, seq dims
* Add support for sm120 BHS1 style max logit even when QKV are THD, to avoid incorrect max logit calculation (includes padded tokens in max calculation)
* Disable fused and flash attn for sm120 filter:kv cache
* For CP P2P attn, set softmax_lse_in_packed_format to False if sm120+
* Assert in TE if T3HD/TH3D layout is used on sm120 before cuDNN F16 sdpa arbitrary kernel call
* Modify is_ragged_q && cudnn_runtime_version >= 90600 check to also include a check for sm120
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* nit: Code clean up
* Disable fused attn for T3HD and TH3D
* nit: Add missed sm120 guard
* Modify sm120 condition to be very specific to sm120 and not generalized to sm120+
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* nit: Fix missing sm120 check in fwd
* Move the check for sm120 T3HD/TH3D to nvte_get_fused_attn_backend() instead of higher layers in TE stack
* nit: Check for matching sm120 and not sm120+
* [pre-commit.ci] auto fixes from pre-commit.com hooks

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
NVIDIA#2693: same squashed commit history as above. Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Description
Enable sm120 support for THD for fused attn for cuDNN 9.18.1+
Type of change
Changes
- Return NVTE_No_Backend from nvte_get_fused_attention_backend() if T3HD or TH3D shapes are used, as cuDNN does not support them. Also, warn the user if they are using sm120 with cuDNN < 9.18.1
- get_attention_backends() (until fully supported)

Test results:
Ran PyT attention tests on sm120 and no failures:
Checklist: