diff --git a/test/test_transformers.py b/test/test_transformers.py
index 8dd54ae00dd2..e7ed97eefdc4 100644
--- a/test/test_transformers.py
+++ b/test/test_transformers.py
@@ -66,7 +66,7 @@ def use_deterministic_algorithims(mode: bool, warn_only: bool):
 default_atol = {torch.float16: 1e-3, torch.bfloat16: 1e-3, torch.float32: 1e-5}
 default_rtol = {torch.float16: 1e-3, torch.bfloat16: 1.6e-2, torch.float32: 1.3e-6}
 
-isSM86or89Device = torch.cuda.is_available() and torch.cuda.get_device_capability() in [(8, 6), (8, 9)]
+isSM8XDevice = torch.cuda.is_available() and torch.cuda.get_device_capability() in [(8, 6), (8, 7), (8, 9)]
 isSM90Device = torch.cuda.is_available() and torch.cuda.get_device_capability() == (9, 0)
 isSM5xDevice = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] == 5
 isLessThanSM80Device = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] < 8
@@ -1322,7 +1322,7 @@ class TestSDPAFailureModes(NNTestCase):
     _do_cuda_non_default_stream = True
 
     @onlyCUDA
-    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION or not isSM86or89Device,
+    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION or not isSM8XDevice,
                      "Does not support fused SDPA or not SM86+ hardware")
     @parametrize("head_dim", [193, 204, 256])
     def test_flash_backward_failure_sm86plus(self, device, head_dim: int):
@@ -2676,8 +2676,8 @@ def is_power_of_2(n):
             if head_dim > 128:
                 self.skipTest("Flash attention on ROCM only supports power of two headdim <= 128, for now.")
 
-        if isSM86or89Device and head_dim in range(193, 256 + 1):
-            self.skipTest("Flash attention on sm86 and sm89 for headdim > 192 currently disabled")
+        if isSM8XDevice and head_dim in range(193, 256 + 1):
+            self.skipTest("Flash attention on sm86, sm87, and sm89 for headdim > 192 currently disabled")
         if is_causal and seq_len_q != seq_len_k:
             self.skipTest("Flash V2 does not accept is_casual when seq_len_q != seq_len_k")
 
@@ -2742,8 +2742,8 @@ def is_power_of_2(n):
 
         upstream_grad = torch.rand_like(out, requires_grad=False)
 
-        # backward for flash attention on sm86 and sm89 for headdim > 64 currently disabled
-        if isSM86or89Device and head_dim in range(193, 256):
+        # backward for flash attention on sm86, sm87, and sm89 for headdim >= 193 currently disabled
+        if isSM8XDevice and head_dim in range(193, 256):
             self.assertRaises(RuntimeError, lambda: out.backward(upstream_grad))
             return
         out.backward(upstream_grad)
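
For context, a minimal standalone sketch (not part of the patch, with a hypothetical helper name) of the capability check this diff renames: torch.cuda.get_device_capability() returns a (major, minor) compute-capability tuple for the current device, so membership in [(8, 6), (8, 7), (8, 9)] now covers sm86, sm87 (Jetson Orin), and sm89 parts rather than only sm86 and sm89.

    import torch

    def is_sm8x_device() -> bool:
        # Hypothetical helper mirroring the renamed isSM8XDevice flag:
        # (major, minor) compute capability of the current CUDA device,
        # e.g. (8, 6) for sm86, (8, 7) for sm87, (8, 9) for sm89.
        return torch.cuda.is_available() and \
            torch.cuda.get_device_capability() in [(8, 6), (8, 7), (8, 9)]

    # The patched tests then skip the forward pass, or expect a RuntimeError
    # from backward, for head dims in 193..256 on these devices, as in:
    # if is_sm8x_device() and head_dim in range(193, 256 + 1): ...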