
[PyT] [Common] Enable sm120 support for fused attn if cuDNN is 9.18.1+#2693

Merged
KshitijLakhani merged 21 commits into NVIDIA:main from KshitijLakhani:klakhani/maint/sm120-thd-flash-support on Mar 22, 2026

Conversation

Collaborator

@KshitijLakhani KshitijLakhani commented Feb 20, 2026

Description

Enable sm120 support for the THD format in fused attention when cuDNN 9.18.1+ is available.

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • For sm120, change the shape of the stats tensors to BHS1 instead of TH1, and propagate the change.
  • For sm120, disable fused attention in nvte_get_fused_attention_backend() if the T3HD or TH3D layout is used, as cuDNN does not support them; a sketch of this gating follows this list. Also, warn users running sm120 with cuDNN < 9.18.1.
  • For sm120, disable fused and flash attention for the KV cache in get_attention_backends() (until fully supported).
  • NOTE: No changes made to test code (no skips for sm120, etc.): any skips are achieved by disabling the backend attention type rather than the heavy-handed approach of disabling tests.
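For reference, a minimal sketch of that gating, assuming the enum names from TE's public fused-attention header; the helper name sm120_gate is hypothetical, and in the PR the checks are inlined in nvte_get_fused_attn_backend() rather than factored out:

```cpp
#include <transformer_engine/fused_attn.h>  // NVTE enums; header path per the TE repo layout

// Hypothetical helper illustrating the sm120 gate described above. The actual
// checks are inlined in nvte_get_fused_attn_backend(), not a separate function.
NVTE_Fused_Attn_Backend sm120_gate(int sm_arch, int cudnn_runtime_version,
                                   NVTE_QKV_Layout qkv_layout) {
  if (sm_arch == 120) {
    // cuDNN releases before 9.18.1 (runtime version 91801) cannot run fused
    // attention for THD on sm120: report no backend and warn the user to upgrade.
    if (cudnn_runtime_version < 91801) {
      return NVTE_Fused_Attn_Backend::NVTE_No_Backend;
    }
    // T3HD and TH3D packed layouts remain unsupported by cuDNN on sm120.
    if (qkv_layout == NVTE_QKV_Layout::NVTE_T3HD ||
        qkv_layout == NVTE_QKV_Layout::NVTE_TH3D) {
      return NVTE_Fused_Attn_Backend::NVTE_No_Backend;
    }
  }
  // Otherwise sm120 is eligible for the F16 arbitrary-seqlen backend,
  // subject to the existing (unchanged) selection logic.
  return NVTE_Fused_Attn_Backend::NVTE_F16_arbitrary_seqlen;
}
```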

Test results:

Ran the PyTorch attention tests on sm120 with no failures:

klakhani@alon-ts1-iec-15:~/TE$ pytest tests/pytorch/attention/test_attention_with_cp.py 
===================================================================================================== test session starts ======================================================================================================
platform linux -- Python 3.12.3, pytest-8.1.1, pluggy-1.6.0
rootdir: /home/klakhani/TE
configfile: pyproject.toml
plugins: typeguard-4.5.1, anyio-4.12.1, xdist-3.8.0, shard-0.1.2, flakefinder-1.1.0, hypothesis-6.130.8, rerunfailures-16.1
collected 8488 items                                                                                                                                                                                                           
Running 8488 items in this shard

tests/pytorch/attention/test_attention_with_cp.py .s.ss.s.ss...sssssss...ss...ss.s.sssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [  1%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [  4%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [  7%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [  9%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 12%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 14%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 17%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 19%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 22%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 24%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 27%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 29%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 32%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 35%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 37%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 40%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 42%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 45%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 47%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 50%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 52%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 55%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 57%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 60%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 63%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss.s.s.sss.s.s.s.s.s.sss.s.sssss.sssssssssssss.s.sss.s.sssssssssssssssss [ 65%]
ssssssssssssssssss.s.sss.s.sssssssssss.s.s.sss.s.sssssssssss.s.s.sssss.sssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 68%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 70%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 73%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 75%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 78%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 80%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 83%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 85%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 88%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 91%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 93%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 96%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 98%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss                                                                                                       [100%]

======================================================================================================= warnings summary =======================================================================================================
../../../usr/local/lib/python3.12/dist-packages/torch/jit/_script.py:1487
../../../usr/local/lib/python3.12/dist-packages/torch/jit/_script.py:1487
  /usr/local/lib/python3.12/dist-packages/torch/jit/_script.py:1487: DeprecationWarning: `torch.jit.script` is deprecated. Please switch to `torch.compile` or `torch.export`.
    warnings.warn(

transformer_engine/pytorch/attention/dot_product_attention/utils.py:2087
  /home/klakhani/TE/transformer_engine/pytorch/attention/dot_product_attention/utils.py:2087: UserWarning: window_size should be (-1, 0) or (>=0, 0) for attn_mask_type=causal
    warnings.warn(

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
=================================================================================== 44 passed, 8444 skipped, 3 warnings in 576.35s (0:09:36) ===================================================================================
klakhani@alon-ts1-iec-15:~/TE$ pytest tests/pytorch/attention/test_attention.py 
===================================================================================================== test session starts ======================================================================================================
platform linux -- Python 3.12.3, pytest-8.1.1, pluggy-1.6.0
rootdir: /home/klakhani/TE
configfile: pyproject.toml
plugins: typeguard-4.5.1, anyio-4.12.1, xdist-3.8.0, shard-0.1.2, flakefinder-1.1.0, hypothesis-6.130.8, rerunfailures-16.1
collected 2607 items                                                                                                                                                                                                           
Running 2607 items in this shard

tests/pytorch/attention/test_attention.py ......................................................................ss............................................................................s.....s........ssss....... [  6%]
.ssss.s.................................................ss..sss...ss..sss....s...s.....s...s.....s...s....ss..sss...ss..sss....s...s.....s...s.....s...s...ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.ss.......................... [ 14%]
................................ss..................ss..............ssssss..ssss....sssss.s....sssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 23%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 31%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 39%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 48%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 56%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 64%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 72%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 81%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 89%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 97%]
sssssssssssssssssssssssssssssssssssssssssssssssssssssssss                                                                                                                                                                [100%]

======================================================================================================= warnings summary =======================================================================================================
../../../usr/local/lib/python3.12/dist-packages/torch/jit/_script.py:1487
../../../usr/local/lib/python3.12/dist-packages/torch/jit/_script.py:1487
  /usr/local/lib/python3.12/dist-packages/torch/jit/_script.py:1487: DeprecationWarning: `torch.jit.script` is deprecated. Please switch to `torch.compile` or `torch.export`.
    warnings.warn(

transformer_engine/pytorch/attention/dot_product_attention/utils.py:2087
  /home/klakhani/TE/transformer_engine/pytorch/attention/dot_product_attention/utils.py:2087: UserWarning: window_size should be (-1, 0) or (>=0, 0) for attn_mask_type=causal
    warnings.warn(

transformer_engine/pytorch/attention/dot_product_attention/utils.py:2087
  /home/klakhani/TE/transformer_engine/pytorch/attention/dot_product_attention/utils.py:2087: UserWarning: window_size should be (-1, 0) or (>=0, 0) for attn_mask_type=padding_causal
    warnings.warn(

transformer_engine/pytorch/attention/dot_product_attention/utils.py:2087
  /home/klakhani/TE/transformer_engine/pytorch/attention/dot_product_attention/utils.py:2087: UserWarning: window_size should be (-1, 0) or (>=0, 0) for attn_mask_type=causal_bottom_right
    warnings.warn(

transformer_engine/pytorch/attention/dot_product_attention/utils.py:2087
  /home/klakhani/TE/transformer_engine/pytorch/attention/dot_product_attention/utils.py:2087: UserWarning: window_size should be (-1, 0) or (>=0, 0) for attn_mask_type=padding_causal_bottom_right
    warnings.warn(

tests/pytorch/attention/test_attention.py::test_dot_product_attention[False-False-None-True-False-base_1_0-model_configs0-dtype0]
  /usr/local/lib/python3.12/dist-packages/torch/autograd/graph.py:869: UserWarning: Attempting to run cuBLAS, but there was no current CUDA context! Attempting to set the primary context... (Triggered internally at /opt/pytorch/pytorch/aten/src/ATen/cuda/CublasHandlePool.cpp:335.)
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass

tests/pytorch/attention/test_attention.py: 14 warnings
  /usr/local/lib/python3.12/dist-packages/torch/jit/_script.py:365: DeprecationWarning: `torch.jit.script_method` is deprecated. Please switch to `torch.compile` or `torch.export`.
    warnings.warn(

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
================================================================================== 396 passed, 2211 skipped, 21 warnings in 134.70s (0:02:14) ==================================================================================
klakhani@alon-ts1-iec-15:~/TE$ pytest tests/pytorch/attention/test_cp_utils.py 
===================================================================================================== test session starts ======================================================================================================
platform linux -- Python 3.12.3, pytest-8.1.1, pluggy-1.6.0
rootdir: /home/klakhani/TE
configfile: pyproject.toml
plugins: typeguard-4.5.1, anyio-4.12.1, xdist-3.8.0, shard-0.1.2, flakefinder-1.1.0, hypothesis-6.130.8, rerunfailures-16.1
collected 9 items                                                                                                                                                                                                                                 
Running 9 items in this shard

tests/pytorch/attention/test_cp_utils.py .........                                                                                                                                                                                          [100%]

================================================================================================================ warnings summary =================================================================================================================
../../../usr/local/lib/python3.12/dist-packages/torch/jit/_script.py:1487
../../../usr/local/lib/python3.12/dist-packages/torch/jit/_script.py:1487
  /usr/local/lib/python3.12/dist-packages/torch/jit/_script.py:1487: DeprecationWarning: `torch.jit.script` is deprecated. Please switch to `torch.compile` or `torch.export`.
    warnings.warn(

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
========================================================================================================== 9 passed, 2 warnings in 3.93s ==========================================================================================================
klakhani@alon-ts1-iec-15:~/TE$ pytest tests/pytorch/attention/test_kv_cache.py 
========================================================================================================================== test session starts ===========================================================================================================================
platform linux -- Python 3.12.3, pytest-8.1.1, pluggy-1.6.0
rootdir: /home/klakhani/TE
configfile: pyproject.toml
plugins: typeguard-4.5.1, anyio-4.12.1, xdist-3.8.0, shard-0.1.2, flakefinder-1.1.0, hypothesis-6.130.8, rerunfailures-16.1
collected 576 items                                                                                                                                                                                                                                                      
Running 576 items in this shard

tests/pytorch/attention/test_kv_cache.py ssssssssssssssssssssssssssssssssssssssssssssssss........................ssssssssssssssssssssssssssssssssssssssssssssssss........................sssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 37%]
ssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss [ 82%]
sssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssssss                                                                                                                                                              [100%]

============================================================================================================================ warnings summary ============================================================================================================================
../../../usr/local/lib/python3.12/dist-packages/torch/jit/_script.py:1487
../../../usr/local/lib/python3.12/dist-packages/torch/jit/_script.py:1487
  /usr/local/lib/python3.12/dist-packages/torch/jit/_script.py:1487: DeprecationWarning: `torch.jit.script` is deprecated. Please switch to `torch.compile` or `torch.export`.
    warnings.warn(

tests/pytorch/attention/test_kv_cache.py: 576 warnings
  /home/klakhani/TE/tests/pytorch/attention/test_kv_cache.py:86: UserWarning: torch.range is deprecated and will be removed in a future release because its behavior is inconsistent with Python's range builtin. Instead, use torch.arange, which produces values in [start, end).
    self.seq_ids = torch.range(0, total_requests - 1, dtype=torch.int32, device="cpu")

tests/pytorch/attention/test_kv_cache.py: 288 warnings
  /home/klakhani/TE/transformer_engine/pytorch/attention/inference.py:435: UserWarning: torch.range is deprecated and will be removed in a future release because its behavior is inconsistent with Python's range builtin. Instead, use torch.arange, which produces values in [start, end).
    self.batch_indices_post_step = torch.range(

tests/pytorch/attention/test_kv_cache.py: 14 warnings
  /usr/local/lib/python3.12/dist-packages/torch/jit/_script.py:365: DeprecationWarning: `torch.jit.script_method` is deprecated. Please switch to `torch.compile` or `torch.export`.
    warnings.warn(

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
======================================================================================================== 48 passed, 528 skipped, 880 warnings in 62.17s (0:01:02) ========================================================================================================

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@KshitijLakhani KshitijLakhani force-pushed the klakhani/maint/sm120-thd-flash-support branch from 674394b to 998b3b8 on February 20, 2026 18:40
KshitijLakhani and others added 3 commits March 2, 2026 15:31
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
…pe instead of TH1 for sm120

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
@KshitijLakhani KshitijLakhani force-pushed the klakhani/maint/sm120-thd-flash-support branch from dc282ea to b2f5864 on March 2, 2026 23:31
pre-commit-ci bot and others added 9 commits March 2, 2026 23:32
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
…incorrect max logit calculation (includes padded tokens in max calculation)

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
…pa arbitrary kernel call

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
…clude a check for sm120

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
@KshitijLakhani KshitijLakhani changed the title Enable sm120 support for fused attn if cuDNN is 9.18.1+ [PyT] [Common] Enable sm120 support for fused attn if cuDNN is 9.18.1+ Mar 11, 2026
@KshitijLakhani KshitijLakhani self-assigned this Mar 11, 2026
@KshitijLakhani KshitijLakhani marked this pull request as ready for review March 12, 2026 21:51
@greptile-apps
Contributor

greptile-apps bot commented Mar 12, 2026

Greptile Summary

This PR enables SM120 (Blackwell) support for fused attention with THD (packed/ragged) sequence formats when cuDNN ≥ 9.18.1 is present. The core insight is that cuDNN on SM120 rejects the "token-count" dimensional layout used on earlier architectures and instead requires BHSD-like dimensions with max_seqlen at graph build time, relying on ragged offsets for variable-length boundaries.

Key changes:

  • fused_attn.cpp: nvte_get_fused_attn_backend now gates SM120 + NVTE_F16_arbitrary_seqlen behind cuDNN ≥ 9.18.1 and additionally blocks T3HD/TH3D layouts which remain unsupported on SM120.
  • fused_attn_f16_arbitrary_seqlen.cu: Both forward and backward impls introduce a use_ragged_stats boolean (always false on SM120) that controls whether the softmax stats tensor uses the packed TH1 ragged layout or the dense BHS1 layout (see the sketch after this list). On SM120 the batch/seqlen dimensions are not replaced with token counts, keeping the graph description in BHSD form while ragged offsets still govern Q/K/V/O variable boundaries.
  • utils.py: KV-cache path disables both fused and flash attention on SM120; the THD layout filter is refined to allow SM120 + cuDNN ≥ 9.18.1 while still blocking T3HD/TH3D.
  • context_parallel.py: softmax_lse_in_packed_format is set to False for SM120 (BHS1 shape) even when cuDNN ≥ 9.6, keeping CP bookkeeping consistent with the new stats layout.
  • fused_attn.py (fused_attn_fwd): The return_max_logit path now detects 4-D (BHS1) Max/Sum_Exp tensors from SM120 and applies a sequence-length validity mask before reducing to a per-head max logit, preventing padded positions from corrupting the result.
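To make the fused_attn_f16_arbitrary_seqlen.cu item concrete, here is a hedged sketch of the stats-shape choice, reusing the b/h/s_q/max_t_q names from the flowchart below; the helper function itself is illustrative, not the actual implementation:

```cpp
#include <cstdint>
#include <vector>

// Illustrative sketch of how use_ragged_stats (false on sm120) selects the
// softmax-stats layout described above; not the literal source.
std::vector<int64_t> stats_shape(bool use_ragged_stats, int64_t b, int64_t h,
                                 int64_t s_q, int64_t max_t_q) {
  if (use_ragged_stats) {
    // Pre-sm120 path: packed TH1 layout over the total token count, with a
    // ragged offset tensor marking per-sequence boundaries in the stats.
    return {max_t_q, h, 1};
  }
  // sm120 path: dense BHS1 layout addressed by (batch, head, max_seqlen);
  // ragged offsets still govern Q/K/V/O variable-length boundaries.
  return {b, h, s_q, 1};
}
```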

Confidence Score: 4/5

  • The PR is safe to merge with one minor documentation fix needed; the functional logic is correct and well-tested on SM120 hardware.
  • The architectural approach (BHSD dimensions + ragged offsets for SM120) is sound and consistently applied across both the forward and backward C++ impls and the Python wrappers. The Python-level and C++-level filtering for unsupported layouts (T3HD/TH3D) on SM120 are symmetric. The return_max_logit masking fix correctly handles padded positions. The only finding is a stale inline comment in context_parallel.py that does not document the SM120 BHS1 exception. No functional bugs were identified.
  • transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu — the dual path (token-count vs. max-seqlen dimensions) is subtle; reviewers should verify the cache-key (FADescriptor_v1) behaviour on SM120 does not cause unexpected cache thrashing when batch size or seqlen change across calls.

Important Files Changed

  • transformer_engine/common/fused_attn/fused_attn.cpp: Adds SM120 guard in nvte_get_fused_attn_backend: requires cuDNN ≥ 9.18.1, then additionally blocks T3HD/TH3D layouts that cuDNN still does not support on SM120. Logic is consistent with Python-side filtering and correctly falls back to NVTE_No_Backend.
  • transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu: Introduces the use_ragged_stats flag (false on SM120) in both fwd and bwd impls. On SM120 the batch/seqlen dimensions are kept in BHSD form rather than being replaced by token counts, and the stats tensor uses a dense BHS1 layout instead of the ragged TH1 layout. Shape allocation in fused_attn_arbitrary_seqlen_fwd is updated consistently. The backward asymmetry for set_max_total_seq_len_kv (explicit inline check vs. relying on use_ragged_stats) is intentional and correct but slightly inconsistent in style.
  • transformer_engine/pytorch/attention/dot_product_attention/utils.py: KV-cache path now disables both fused and flash attention for SM120 (not just sm89). The THD layout path is refined: cuDNN < 9.18.1 disables fused attn entirely on SM120; cuDNN ≥ 9.18.1 only disables it for T3HD/TH3D. Logging messages are appropriately updated.
  • transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py: softmax_lse_in_packed_format now also returns False for SM120 devices (even with cuDNN ≥ 9.6), correctly signalling BHS1 format to the CP forward/backward kernels. The existing inline comment above this flag is stale and does not document the SM120 exception.
  • transformer_engine/pytorch/cpp_extensions/fused_attn.py: When return_max_logit=True and qkv_format is THD, the code now detects the 4-D (BHS1) Max/Sum_Exp tensor that SM120 (or older cuDNN) produces and applies a validity mask before computing the per-head max logit, preventing inf/garbage at padded positions from contaminating the result.
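To make the last item concrete, a loop-level sketch of the masking idea in plain C++ (the real code operates on PyTorch tensors in fused_attn.py; the function and argument names here are hypothetical):

```cpp
#include <algorithm>
#include <limits>
#include <vector>

// Illustrative only: reduce a dense BHS1 Max tensor to per-head max logits,
// skipping positions at or beyond each sequence's valid length so that
// padded entries cannot contaminate the result.
std::vector<float> per_head_max_logit(const std::vector<float>& max_bhs1,  // b*h*s values
                                      const std::vector<int>& seqlens,     // b entries
                                      int b, int h, int s) {
  std::vector<float> out(h, -std::numeric_limits<float>::infinity());
  for (int bi = 0; bi < b; ++bi) {
    for (int hi = 0; hi < h; ++hi) {
      for (int si = 0; si < seqlens[bi]; ++si) {  // mask: ignore padded positions
        out[hi] = std::max(out[hi], max_bhs1[(static_cast<size_t>(bi) * h + hi) * s + si]);
      }
    }
  }
  return out;
}
```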

Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[nvte_get_fused_attn_backend] --> B{sm_arch == 120?}
    B -- No --> C[Existing backend selection logic]
    B -- Yes --> D{cudnn_runtime_version >= 91801?}
    D -- No --> E[Return NVTE_No_Backend\n warn: upgrade cuDNN]
    D -- Yes --> F{qkv_layout is T3HD or TH3D?}
    F -- Yes --> G[Return NVTE_No_Backend\n warn: unsupported layout]
    F -- No --> H[Return NVTE_F16_arbitrary_seqlen]

    H --> I[fused_attn_arbitrary_seqlen_fwd_impl]
    I --> J{sm_arch == 120?}
    J -- No --> K[b = max_b\n s_q = max_t_q\n s_kv = max_t_kv\n use_ragged_stats = true]
    J -- Yes --> L[b/s_q/s_kv unchanged\n use_ragged_stats = false]
    K --> M[Stats shape: TH1 packed\n ragged offset on stats]
    L --> N[Stats shape: BHS1\n no ragged offset on stats\n ragged offset on Q K V O only]

    I --> O[fused_attn_fwd Python wrapper]
    O --> P{qkv_format=thd AND max_tensor.ndim==4?}
    P -- Yes --> Q[Mask padded positions to -inf\n before computing max_logit]
    P -- No --> R[Compute max_logit directly]
```

Last reviewed commit: "[pre-commit.ci] auto..."

Comment on lines +639 to +641

```cpp
      NVTE_ERROR(
          "T3HD and TH3D QKV layouts are not supported by cuDNN on SM120 "
          "Use thd_thd_thd or other THD layouts instead.");
```
Contributor

Missing period in forward error message

The forward error message is missing a period that is present in the corresponding backward error message (line 748). Minor inconsistency but worth fixing for uniformity.

Suggested change

```diff
 NVTE_ERROR(
-    "T3HD and TH3D QKV layouts are not supported by cuDNN on SM120 "
+    "T3HD and TH3D QKV layouts are not supported by cuDNN on SM120. "
     "Use thd_thd_thd or other THD layouts instead.");
```

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

@KshitijLakhani
Collaborator Author

/te-ci L0 L1

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
@KshitijLakhani
Collaborator Author

/te-ci L0 L1

@KshitijLakhani KshitijLakhani requested a review from cyanguwa March 13, 2026 21:21
```cpp
      NVTE_ERROR(
          "T3HD and TH3D QKV layouts are not supported by cuDNN on SM120. "
          "Use thd_thd_thd or other THD layouts instead.");
    }
```
Collaborator

Should we add this logic to nvte_get_fused_attn_backend(), so we don't error here but instead treat it as a "not supported" case?

Collaborator Author

@KshitijLakhani KshitijLakhani Mar 18, 2026


My initial thought process here was:

  1. Disable fused attn in the Python layer (PyT-specific, here), since we already have an sm120 check for cuDNN < 9.18.1, which is why I added it here: https://github.com/KshitijLakhani/TransformerEngine/blob/bcfef909681c24a95163ecf987fbf952a4f4eb4a/transformer_engine/pytorch/attention/dot_product_attention/utils.py#L705
     (already in this PR)

  2. Disallow the call to the cuDNN kernels from the C++ (common) layer, so that TE produces an error with an easier-to-understand message for the user rather than a harder-to-understand one.
     (already in this PR; these are the lines of interest for this comment)

If I understand right, you are suggesting we keep #1 and replace #2 with setting the backend to NVTE_No_Backend in nvte_get_fused_attention_backend() for cuDNN >= 9.18.1 + t3hd/th3d + sm120 + NVTE_F16_arbitrary_seqlen.
Sounds okay?

Collaborator

I feel the sm120 check, the T3HD/TH3D check, cuDNN version check can all be in nvte_get_fused_attention_backend() so it produces NVTE_No_Backend. Sometimes we put checks in utils.py because it's not easy to do on the C side, like with the KV cache feature. But for the logic you have in utils.py and here, they can probably go into nvte_get_fused_attention_backend?

Collaborator

Also, greptile seems to say there are duplicate device queries (device_id/sm_arch_) in your code (3 of them). Can you check if it's true?

Collaborator Author

Also, greptile seems to say there are duplicate device queries (device_id/sm_arch_) in your code (3 of them). Can you check if it's true?

I investigated this. Seems like I've added two new calls to cuda::current_device() in fused_attn_f16_arbitrary_seqlen.cu:

  • fused_attn_arbitrary_seqlen_fwd()
  • fused_attn_arbitrary_seqlen_fwd_impl()

There was one from before:

  • fused_attn_arbitrary_seqlen_bwd_impl()

I could pass the sm_arch from fused_attn_arbitrary_seqlen_fwd() to fused_attn_arbitrary_seqlen_fwd_impl(), but that would just add an extra argument. Do you think we should consolidate it that way?

Collaborator Author

@KshitijLakhani KshitijLakhani Mar 20, 2026


Moved the T3HD/TH3D check to nvte_get_fused_attention_backend().
Re-ran all tests on sm120 successfully @cyanguwa

```diff
     output_S->data.dptr = nullptr;
-    if (q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) {
+    if ((q_format == NVTE_QKV_Format::NVTE_THD && cudnn_runtime_version >= 90600) &&
+        !(sm_arch_ >= 120)) {
```
Collaborator

I wonder if we should do sm_arch_ != 120 instead; our feature support is not monotonic in SM number. I made the mistake of writing >sm100 sometimes, but sm103 and sm120 have different support matrices.

Collaborator Author

I agree. Will push a commit for it.

Collaborator Author

Made this change everywhere in the PR
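For the record, a compilable sketch of the resulting condition, with identifiers following the diff quoted above (the helper function is hypothetical; in the source the condition is inlined):

```cpp
// Sketch: gate the packed (TH1) stats path on an exact sm120 match rather
// than a range check, since SM feature support is not monotone in SM number
// (sm103 and sm120 have different support matrices).
bool use_packed_stats(bool is_thd_q, int cudnn_runtime_version, int sm_arch_) {
  // Previously: ... && !(sm_arch_ >= 120)
  return is_thd_q && cudnn_runtime_version >= 90600 && sm_arch_ != 120;
}
```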

KshitijLakhani and others added 2 commits March 19, 2026 18:00
KshitijLakhani and others added 4 commits March 19, 2026 20:54
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
…nstead of higher layers in TE stack

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
@KshitijLakhani KshitijLakhani requested a review from cyanguwa March 20, 2026 00:15
@KshitijLakhani
Collaborator Author

/te-ci L0 L1


@KshitijLakhani
Collaborator Author

The PyT and common jobs had passed in previous CI runs.
The JAX jobs failed due to a CI issue unrelated to this PR.
Re-ran the JAX jobs manually (46638097) once the CI issue was fixed, and they pass. Going ahead with the merge.

@KshitijLakhani KshitijLakhani merged commit 487d68c into NVIDIA:main Mar 22, 2026
33 of 44 checks passed
KshitijLakhani added a commit that referenced this pull request Mar 25, 2026
#2693)

* Enable sm120 support for fused attn if cuDNN is 9.18.1+

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Force intermediate tensors such as S, Sum_Exp, and Max to be BHS1 shape instead of TH1 for sm120

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add support for sm120 correct batch, seq dims

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* Add support for sm120 BHS1 style max logit even QKV are THD to avoid incorrect max logit calculation (includes padded tokens in max calculation)

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* Disable fused and flash attn for sm120 filter:kv cache

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* For CP P2P attn, set softmax_lse_in_packed_format to False if sm120+

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* Assert in TE if T3HD/TH3D layout is used on sm120 before cuDNN F16 sdpa arbitrary kernel call

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* Modify is_ragged_q && cudnn_runtime_version >= 90600 check to also include a check for sm120

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* nit: Code clean up

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* Disable fused attn for T3HD and TH3D

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* nit: Add missed sm120 guard

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* Modify sm120 condition to be very specific to sm120 and not generalized to sm120+

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* nit: Fix missing sm120 check in fwd

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* Move the check for sm120 T3HD/TH3D to nvte_get_fused_attn_backend() instead of higher layers in TE stack

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* nit: Check for matching sm120 and not sm120+

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
vthumbe1503 pushed a commit to ksivaman/TransformerEngine-1 that referenced this pull request Apr 1, 2026
NVIDIA#2693)
