Merged
test/test_transformers.py (4 additions, 0 deletions)
Expand Up @@ -3779,6 +3779,10 @@ def test_causal_variants(self, device, causal_variant: CausalVariant, shape: Lis
if TEST_WITH_ROCM and causal_variant == CausalVariant.LOWER_RIGHT:
self.skipTest("No support for LOWER_RIGHT variant for now")
return
if (TEST_WITH_ROCM
and "gfx12" in torch.cuda.get_device_properties(0).gcnArchName.split(":")[0]
and self._testMethodName == "test_causal_variants_causal_variant_CausalVariant_UPPER_LEFT_shape3_cuda"):
self.skipTest(f"Failed on Navi4x in release/2.5 for shape {shape}")

bsz, num_heads, seq_len_q, seq_len_kv, head_dim = shape
make_q_tensor = partial(make_tensor, SdpaShape(bsz, num_heads, seq_len_q, head_dim))
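The new skip keys off the GPU architecture string. On ROCm builds, `get_device_properties(0).gcnArchName` returns the base architecture name, optionally followed by ":"-separated feature flags, and the gfx12 family corresponds to Navi4x (RDNA4). A minimal sketch of the same check, assuming a ROCm build; the example string in the comment is illustrative, not taken from the PR:

```python
# Sketch of the arch check used in the skip above (assumes a ROCm build;
# torch.version.hip is None on CUDA builds). gcnArchName looks like e.g.
# "gfx1201:sramecc+:xnack-": the base name precedes any feature flags.
import torch

if torch.version.hip is not None and torch.cuda.is_available():
    base_arch = torch.cuda.get_device_properties(0).gcnArchName.split(":")[0]
    is_navi4x = "gfx12" in base_arch  # gfx12xx == RDNA4 / Navi4x
    print(base_arch, is_navi4x)
```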
torch/nn/attention/__init__.py (26 additions, 9 deletions)
@@ -1,21 +1,14 @@
 # mypy: allow-untyped-defs
 """ This module contains functions and classes that alter the behavior of torch.nn.functional.scaled_dot_product_attention """
 import contextlib
-from typing import List, Union
+from typing import Iterable, List, Union
 from warnings import warn
 
 import torch.backends.cuda
 from torch._C import _SDPBackend as SDPBackend
 from torch.backends.cuda import (
     can_use_efficient_attention,
     can_use_flash_attention,
-    cudnn_sdp_enabled,

Review comment on the cudnn_sdp_enabled line:

Collaborator:
Are these imports not used? Safe to remove?

Author:
The tests raise no errors about unknown names, and I can't find any other references to these names:

  • cudnn_sdp_enabled
  • enable_cudnn_sdp
  • enable_flash_sdp
  • enable_math_sdp
  • enable_mem_efficient_sdp
  • flash_sdp_enabled
  • math_sdp_enabled
  • mem_efficient_sdp_enabled

-    enable_cudnn_sdp,
-    enable_flash_sdp,
-    enable_math_sdp,
-    enable_mem_efficient_sdp,
-    flash_sdp_enabled,
-    math_sdp_enabled,
-    mem_efficient_sdp_enabled,
     SDPAParams,
 )
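On the reviewer's question above: the removed names are plain re-exports that still live on `torch.backends.cuda`, which is what lets the new helpers further down reach them via `getattr` instead of direct imports. A quick sanity-check sketch:

```python
# Sketch verifying the review thread's claim: the names dropped from the
# import list still resolve on torch.backends.cuda; this module simply
# stops re-importing them and uses getattr(torch.backends.cuda, ...) below.
import torch.backends.cuda

removed_names = [
    "cudnn_sdp_enabled", "enable_cudnn_sdp",
    "enable_flash_sdp", "enable_math_sdp", "enable_mem_efficient_sdp",
    "flash_sdp_enabled", "math_sdp_enabled", "mem_efficient_sdp_enabled",
]

for name in removed_names:
    assert hasattr(torch.backends.cuda, name), f"{name} missing"
```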

@@ -66,6 +59,30 @@ def _raise_kernel_warnings(params: SDPAParams) -> None:
             warn("Flash attention can't be used because:")
             can_use_flash_attention(params, True)
 
+_backend_names = {
+    "cudnn": "CUDNN_ATTENTION",
+    "flash": "FLASH_ATTENTION",
+    "mem_efficient": "EFFICIENT_ATTENTION",
+    "math": "MATH",
+}
+
+
+def _backend_from_string(name: str):
+    return getattr(SDPBackend, name)
+
+
+def _cur_sdpa_kernel_backends():
+    backends: List[SDPBackend] = []
+    for name, val in _backend_names.items():
+        if getattr(torch.backends.cuda, f"{name}_sdp_enabled")():
+            backends.append(getattr(SDPBackend, val))
+    return backends
+
+
+def _sdpa_kernel(backends: Iterable[SDPBackend]):
+    for name, val in _backend_names.items():
+        enabled = getattr(SDPBackend, val) in backends
+        getattr(torch.backends.cuda, f"enable_{name}_sdp")(enabled)
 
 @contextlib.contextmanager
 def sdpa_kernel(
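The body of `sdpa_kernel` is collapsed in this diff view, but the helpers above suggest the usual save/set/restore pattern for such a context manager. A hedged sketch of that pattern, assuming the `SDPBackend`, `_cur_sdpa_kernel_backends`, and `_sdpa_kernel` definitions from the diff; this is an assumption about the collapsed code, not the merged implementation:

```python
# Hedged sketch: snapshot the enabled SDPA backends, enable only the
# requested set, and restore the snapshot on exit. Not the merged code,
# which is collapsed in this view.
import contextlib
from typing import List, Union


@contextlib.contextmanager
def sdpa_kernel_sketch(backends: Union[List[SDPBackend], SDPBackend]):
    if isinstance(backends, SDPBackend):
        backends = [backends]  # accept a single backend for convenience
    previous = _cur_sdpa_kernel_backends()  # snapshot current state
    try:
        _sdpa_kernel(backends)  # enable exactly the requested backends
        yield
    finally:
        _sdpa_kernel(previous)  # restore the prior configuration

# Usage mirrors the public API, e.g.:
#   with sdpa_kernel_sketch(SDPBackend.FLASH_ATTENTION):
#       out = F.scaled_dot_product_attention(q, k, v)
```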