diff --git a/test/inductor/test_aot_inductor.py b/test/inductor/test_aot_inductor.py
index a41738aaea77..4cbf7ef38c39 100644
--- a/test/inductor/test_aot_inductor.py
+++ b/test/inductor/test_aot_inductor.py
@@ -30,6 +30,7 @@
 from torch.testing._internal import common_utils
 from torch.testing._internal.common_cuda import (
     PLATFORM_SUPPORTS_FLASH_ATTENTION,
+    PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
     SM80OrLater,
     SM90OrLater,
 )
@@ -929,7 +930,10 @@ def forward(self, q, k, v):
     @unittest.skipIf(IS_FBCODE, "Not yet runnable in fbcode")
     @unittest.skipIf(not SM80OrLater, "bfloat16 only supported in sm80+")
     @unittest.skipIf(
-        not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Some archs don't support SDPA"
+        # for archs where this isn't lowered to flash attention, the math
+        # backend will be used and it doesn't work for bfloat16
+        not PLATFORM_SUPPORTS_FLASH_ATTENTION,
+        "Some archs don't support SDPA with bfloat16",
     )
     def test_sdpa_2(self):
         class Model(torch.nn.Module):
@@ -1039,7 +1043,7 @@ def forward(self, x, y):

     @skipIfXpu(msg="_scaled_dot_product_flash_attention is not supported on XPU yet")
     @unittest.skipIf(
-        not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Some archs don't support SDPA"
+        not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Some archs don't support flash SDPA"
     )
     def test_fallback_kernel_with_symexpr_output(self):
         if self.device != GPU_TYPE:
@@ -3036,7 +3040,7 @@ def grid(meta):
        )

     @unittest.skipIf(
-        not PLATFORM_SUPPORTS_FLASH_ATTENTION, "Some archs don't support SDPA"
+        not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Some archs don't support mem eff SDPA"
     )
     def test_scaled_dot_product_efficient_attention(self):
         if self.device != GPU_TYPE:
diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py
index 39861a364ccf..67e3466371e2 100644
--- a/test/inductor/test_torchinductor.py
+++ b/test/inductor/test_torchinductor.py
@@ -10611,6 +10611,9 @@ def fn(q, k, v):
         )

     @expectedFailureXPU
+    @unittest.skipIf(
+        not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION, "Some archs don't support mem eff SDPA"
+    )
     def test_scaled_dot_product_efficient_attention(self):
         if self.device == "cpu":
             raise unittest.SkipTest(f"requires {GPU_TYPE}")
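
For context, here is a minimal standalone sketch (not part of the patch) of the skip pattern the change adopts: efficient-attention tests are gated on PLATFORM_SUPPORTS_MEM_EFF_ATTENTION rather than PLATFORM_SUPPORTS_FLASH_ATTENTION. The test class and body below are hypothetical; only the imported flag and the skipIf usage mirror the diff.

```python
# Illustrative only -- not part of the patch above. Shows how a test can be
# gated on the mem-efficient SDPA capability flag instead of the flash flag.
import unittest

import torch
import torch.nn.functional as F
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_MEM_EFF_ATTENTION


class SDPASkipExample(unittest.TestCase):  # hypothetical test class
    @unittest.skipIf(not torch.cuda.is_available(), "requires CUDA")
    @unittest.skipIf(
        not PLATFORM_SUPPORTS_MEM_EFF_ATTENTION,
        "Some archs don't support mem eff SDPA",
    )
    def test_efficient_attention_smoke(self):
        # (batch, heads, seq, head_dim) inputs; fp16 keeps fused backends eligible
        q, k, v = (
            torch.randn(1, 4, 8, 16, device="cuda", dtype=torch.float16)
            for _ in range(3)
        )
        out = F.scaled_dot_product_attention(q, k, v)
        self.assertEqual(out.shape, q.shape)


if __name__ == "__main__":
    unittest.main()
```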