From c1cee940d227404917c020a351e7865853a14adb Mon Sep 17 00:00:00 2001
From: Sampsa
Date: Wed, 18 Jun 2025 11:45:10 +0000
Subject: [PATCH 1/2] some additional skips

---
 test/inductor/test_aot_inductor.py  | 10 +++++++---
 test/inductor/test_torchinductor.py |  3 +++
 2 files changed, 10 insertions(+), 3 deletions(-)

diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py
index a41738aaea774..d9b6f941d466d 100644
--- a/test/inductor/test_aot_inductor.py
+++ b/test/inductor/test_aot_inductor.py
@@ -30,6 +30,7 @@ from torch.testing._internal import common_utils
 from torch.testing._internal.common_cuda import (
     PLATFORM_SUPPORTS_FLASH_ATTENTION,
+    PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
     SM80OrLater,
     SM90OrLater,
 )
 
@@ -929,8 +930,11 @@ def forward(self, q, k, v):
     @unittest.skipIf(IS_FBCODE, "Not yet runnable in fbcode")
     @unittest.skipIf(not SM80OrLater, "bfloat16 only supported in sm80+")
     @unittest.skipIf(
-        not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Some archs don't support SDPA"
+        # for archs where this isn't lowered to flash attention, the math
+        # backend will be used and it doesn't work for bfloat16
+        not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Some archs don't support SDPA with bfloat16"
     )
+
     def test_sdpa_2(self):
         class Model(torch.nn.Module):
             def __init__(self) -> None:
@@ -1039,7 +1043,7 @@ def forward(self, x, y):
 
     @skipIfXpu(msg="_scaled_dot_product_flash_attention is not supported on XPU yet")
     @unittest.skipIf(
-        not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Some archs don't support SDPA"
+        not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Some archs don't support flash SDPA"
     )
     def test_fallback_kernel_with_symexpr_output(self):
         if self.device != GPU_TYPE:
@@ -3036,7 +3040,7 @@ def grid(meta):
         )
 
     @unittest.skipIf(
-        not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Some archs don't support SDPA"
+        not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Some archs don't support mem eff SDPA"
     )
     def test_scaled_dot_product_efficient_attention(self):
         if self.device != GPU_TYPE:
diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py
index 39861a364ccf2..a97a92b274b62 100644
--- a/test/inductor/test_torchinductor.py
+++ b/test/inductor/test_torchinductor.py
@@ -10611,6 +10611,9 @@ def fn(q, k, v):
         )
 
     @expectedFailureXPU
+    @unittest.skipIf(
+        not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Some archs don't support SDPA"
+    )
     def test_scaled_dot_product_efficient_attention(self):
         if self.device == "cpu":
             raise unittest.SkipTest(f"requires {GPU_TYPE}")

From d81911fc23301c7e82db40178bd519057c601a2c Mon Sep 17 00:00:00 2001
From: Sampsa
Date: Wed, 18 Jun 2025 11:53:21 +0000
Subject: [PATCH 2/2] some additional skips

---
 test/inductor/test_aot_inductor.py  | 4 ++--
 test/inductor/test_torchinductor.py | 2 +-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py
index d9b6f941d466d..4cbf7ef38c399 100644
--- a/test/inductor/test_aot_inductor.py
+++ b/test/inductor/test_aot_inductor.py
@@ -932,9 +932,9 @@ def forward(self, q, k, v):
     @unittest.skipIf(
         # for archs where this isn't lowered to flash attention, the math
         # backend will be used and it doesn't work for bfloat16
-        not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Some archs don't support SDPA with bfloat16"
+        not PLATFORM_SUPPORTS_FLASH_ATTENTION,
+        "Some archs don't support SDPA with bfloat16",
     )
-
     def test_sdpa_2(self):
         class Model(torch.nn.Module):
             def __init__(self) -> None:
diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py
index a97a92b274b62..67e3466371e28 100644
--- a/test/inductor/test_torchinductor.py
+++ b/test/inductor/test_torchinductor.py
@@ -10612,7 +10612,7 @@ def fn(q, k, v):
 
     @expectedFailureXPU
     @unittest.skipIf(
-        not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Some archs don't support SDPA"
+        not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Some archs don't support mem eff SDPA"
     )
     def test_scaled_dot_product_efficient_attention(self):
         if self.device == "cpu":