Merged
10 changes: 7 additions & 3 deletions test/inductor/test_aot_inductor.py
@@ -30,6 +30,7 @@
 from torch.testing._internal import common_utils
 from torch.testing._internal.common_cuda import (
     PLATFORM_SUPPORTS_FLASH_ATTENTION,
+    PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
     SM80OrLater,
     SM90OrLater,
 )
@@ -929,7 +930,10 @@ def forward(self, q, k, v):
     @unittest.skipIf(IS_FBCODE, "Not yet runnable in fbcode")
     @unittest.skipIf(not SM80OrLater, "bfloat16 only supported in sm80+")
     @unittest.skipIf(
-        not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Some archs don't support SDPA"
+        # for archs where this isn't lowered to flash attention, the math
+        # backend will be used and it doesn't work for bfloat16
+        not PLATFORM_SUPPORTS_FLASH_ATTENTION,
+        "Some archs don't support SDPA with bfloat16",
     )
     def test_sdpa_2(self):
         class Model(torch.nn.Module):
@@ -1039,7 +1043,7 @@ def forward(self, x, y):
 
     @skipIfXpu(msg="_scaled_dot_product_flash_attention is not supported on XPU yet")
     @unittest.skipIf(
-        not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Some archs don't support SDPA"
+        not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Some archs don't support flash SDPA"
     )
     def test_fallback_kernel_with_symexpr_output(self):
         if self.device != GPU_TYPE:
@@ -3036,7 +3040,7 @@ def grid(meta):
         )
 
     @unittest.skipIf(
-        not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Some archs don't support SDPA"
+        not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Some archs don't support mem eff SDPA"
     )
     def test_scaled_dot_product_efficient_attention(self):
         if self.device != GPU_TYPE:
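Note on the new skip message in test_sdpa_2: the inline comment explains that on architectures where SDPA is not lowered to flash attention, the math backend is used instead, and that path does not work for bfloat16 in this test. A minimal sketch of the same gating pattern outside the suite (illustrative only; the flag comes from the internal torch.testing._internal.common_cuda helpers, and the test name, shapes, and dtype are assumptions, not taken from the PR):

import unittest

import torch
import torch.nn.functional as F
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION


class SDPAGatingExample(unittest.TestCase):
    # Skip rather than exercise the math-backend fallback, mirroring the PR's decorator.
    @unittest.skipIf(
        not PLATFORM_SUPPORTS_FLASH_ATTENTION,
        "Some archs don't support SDPA with bfloat16",
    )
    def test_sdpa_bfloat16(self):
        # Illustrative shapes: (batch, heads, seq_len, head_dim) on a CUDA device.
        q = torch.randn(1, 8, 128, 64, dtype=torch.bfloat16, device="cuda")
        k = torch.randn(1, 8, 128, 64, dtype=torch.bfloat16, device="cuda")
        v = torch.randn(1, 8, 128, 64, dtype=torch.bfloat16, device="cuda")
        out = F.scaled_dot_product_attention(q, k, v)
        self.assertEqual(out.shape, q.shape)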
3 changes: 3 additions & 0 deletions test/inductor/test_torchinductor.py
@@ -10611,6 +10611,9 @@ def fn(q, k, v):
         )
 
     @expectedFailureXPU
+    @unittest.skipIf(
+        not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Some archs don't support mem eff SDPA"
+    )
     def test_scaled_dot_product_efficient_attention(self):
         if self.device == "cpu":
             raise unittest.SkipTest(f"requires {GPU_TYPE}")
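Note on the guards for test_scaled_dot_product_efficient_attention: the first file switches its check from the flash-attention flag to PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, and the second file adds the same guard, since this test exercises the memory-efficient kernel rather than flash attention. For local verification, one way to confirm the memory-efficient backend is actually available is to restrict SDPA to it; a rough sketch assuming PyTorch 2.3+ (where torch.nn.attention.sdpa_kernel exists) and a CUDA device:

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# Illustrative shapes: (batch, heads, seq_len, head_dim).
q = torch.randn(2, 8, 256, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 8, 256, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 8, 256, 64, device="cuda", dtype=torch.float16)

# Allow only the memory-efficient backend: if the architecture lacks it,
# scaled_dot_product_attention raises instead of silently falling back
# to another backend.
with sdpa_kernel([SDPBackend.EFFICIENT_ATTENTION]):
    out = F.scaled_dot_product_attention(q, k, v)

print(out.shape)  # torch.Size([2, 8, 256, 64])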