From 4eb5066c7e862210b7fed7175f561d944e8a6c18 Mon Sep 17 00:00:00 2001 From: Dmitry Nikolaev Date: Tue, 1 Jul 2025 14:19:07 +0000 Subject: [PATCH] fix and skip test_causal_variants on Navi4x --- test/test_transformers.py | 4 ++++ torch/nn/attention/__init__.py | 35 +++++++++++++++++++++++++--------- 2 files changed, 30 insertions(+), 9 deletions(-) diff --git a/test/test_transformers.py b/test/test_transformers.py index a71aa6be428e5..42145f6e09ebf 100644 --- a/test/test_transformers.py +++ b/test/test_transformers.py @@ -3779,6 +3779,10 @@ def test_causal_variants(self, device, causal_variant: CausalVariant, shape: Lis if TEST_WITH_ROCM and causal_variant == CausalVariant.LOWER_RIGHT: self.skipTest("No support for LOWER_RIGHT variant for now") return + if (TEST_WITH_ROCM + and "gfx12" in torch.cuda.get_device_properties(0).gcnArchName.split(":")[0] + and self._testMethodName == "test_causal_variants_causal_variant_CausalVariant_UPPER_LEFT_shape3_cuda"): + self.skipTest(f"Failed on Navi4x in release/2.5 for shape {shape}") bsz, num_heads, seq_len_q, seq_len_kv, head_dim = shape make_q_tensor = partial(make_tensor, SdpaShape(bsz, num_heads, seq_len_q, head_dim)) diff --git a/torch/nn/attention/__init__.py b/torch/nn/attention/__init__.py index 756c31fa08f5c..74492f04cf62d 100644 --- a/torch/nn/attention/__init__.py +++ b/torch/nn/attention/__init__.py @@ -1,21 +1,14 @@ # mypy: allow-untyped-defs """ This module contains functions and classes that alter the behavior of torch.nn.functional.scaled_dot_product_attention """ import contextlib -from typing import List, Union +from typing import Iterable, List, Union from warnings import warn +import torch.backends.cuda from torch._C import _SDPBackend as SDPBackend from torch.backends.cuda import ( can_use_efficient_attention, can_use_flash_attention, - cudnn_sdp_enabled, - enable_cudnn_sdp, - enable_flash_sdp, - enable_math_sdp, - enable_mem_efficient_sdp, - flash_sdp_enabled, - math_sdp_enabled, - mem_efficient_sdp_enabled, SDPAParams, ) @@ -66,6 +59,30 @@ def _raise_kernel_warnings(params: SDPAParams) -> None: warn("Flash attention can't be used because:") can_use_flash_attention(params, True) +_backend_names = { + "cudnn": "CUDNN_ATTENTION", + "flash": "FLASH_ATTENTION", + "mem_efficient": "EFFICIENT_ATTENTION", + "math": "MATH", +} + + +def _backend_from_string(name: str): + return getattr(SDPBackend, name) + + +def _cur_sdpa_kernel_backends(): + backends: List[SDPBackend] = [] + for name, val in _backend_names.items(): + if getattr(torch.backends.cuda, f"{name}_sdp_enabled")(): + backends.append(getattr(SDPBackend, val)) + return backends + + +def _sdpa_kernel(backends: Iterable[SDPBackend]): + for name, val in _backend_names.items(): + enabled = getattr(SDPBackend, val) in backends + getattr(torch.backends.cuda, f"enable_{name}_sdp")(enabled) @contextlib.contextmanager def sdpa_kernel(