[CUDA] [CI] Disable flash attention for sm87 architecture when the he…
nWEIdia authored and jeffdaily committed Feb 8, 2024
1 parent b8c7da0 commit 8758ee4
Showing 1 changed file with 6 additions and 6 deletions: test/test_transformers.py
@@ -66,7 +66,7 @@ def use_deterministic_algorithims(mode: bool, warn_only: bool):
 default_atol = {torch.float16: 1e-3, torch.bfloat16: 1e-3, torch.float32: 1e-5}
 default_rtol = {torch.float16: 1e-3, torch.bfloat16: 1.6e-2, torch.float32: 1.3e-6}
 
-isSM86or89Device = torch.cuda.is_available() and torch.cuda.get_device_capability() in [(8, 6), (8, 9)]
+isSM8XDevice = torch.cuda.is_available() and torch.cuda.get_device_capability() in [(8, 6), (8, 7), (8, 9)]
 isSM90Device = torch.cuda.is_available() and torch.cuda.get_device_capability() == (9, 0)
 isSM5xDevice = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] == 5
 isLessThanSM80Device = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] < 8
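
For reference, a minimal standalone sketch of the capability probe behind the renamed isSM8XDevice flag (not part of this commit; the helper name is illustrative):

    import torch

    def is_sm8x_device() -> bool:
        # torch.cuda.get_device_capability() returns a (major, minor) tuple;
        # sm86, sm87, and sm89 correspond to (8, 6), (8, 7), and (8, 9).
        if not torch.cuda.is_available():
            return False
        return torch.cuda.get_device_capability() in [(8, 6), (8, 7), (8, 9)]

    if __name__ == "__main__":
        print("sm8x device:", is_sm8x_device())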
@@ -1322,7 +1322,7 @@ class TestSDPAFailureModes(NNTestCase):
     _do_cuda_non_default_stream = True
 
     @onlyCUDA
-    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION or not isSM86or89Device,
+    @unittest.skipIf(not PLATFORM_SUPPORTS_FLASH_ATTENTION or not isSM8XDevice,
                      "Does not support fused SDPA or not SM86+ hardware")
     @parametrize("head_dim", [193, 204, 256])
     def test_flash_backward_failure_sm86plus(self, device, head_dim: int):
@@ -2676,8 +2676,8 @@ def is_power_of_2(n):
         if head_dim > 128:
             self.skipTest("Flash attention on ROCM only supports power of two headdim <= 128, for now.")
 
-        if isSM86or89Device and head_dim in range(193, 256 + 1):
-            self.skipTest("Flash attention on sm86 and sm89 for headdim > 192 currently disabled")
+        if isSM8XDevice and head_dim in range(193, 256 + 1):
+            self.skipTest("Flash attention on sm86, sm87, and sm89 for headdim > 192 currently disabled")
         if is_causal and seq_len_q != seq_len_k:
             self.skipTest("Flash V2 does not accept is_casual when seq_len_q != seq_len_k")
 
@@ -2742,8 +2742,8 @@ def is_power_of_2(n):
 
         upstream_grad = torch.rand_like(out, requires_grad=False)
 
-        # backward for flash attention on sm86 and sm89 for headdim > 64 currently disabled
-        if isSM86or89Device and head_dim in range(193, 256):
+        # backward for flash attention on sm86, sm87, and sm89 for headdim >= 193 currently disabled
+        if isSM8XDevice and head_dim in range(193, 256):
             self.assertRaises(RuntimeError, lambda: out.backward(upstream_grad))
             return
         out.backward(upstream_grad)
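
A standalone sketch of the behavior this hunk asserts, assuming an sm86/sm87/sm89 GPU and a flash-attention-enabled build; the tensor shapes and the head_dim of 200 (inside the disabled 193-256 backward range) are illustrative, not taken from the test:

    import torch
    import torch.nn.functional as F

    # Query/key/value in half precision with head_dim = 200 (assumed value).
    q, k, v = (torch.rand(1, 8, 64, 200, device="cuda", dtype=torch.float16,
                          requires_grad=True) for _ in range(3))
    # Restrict SDPA to the flash attention backend for the forward pass.
    with torch.backends.cuda.sdp_kernel(enable_flash=True,
                                        enable_math=False,
                                        enable_mem_efficient=False):
        out = F.scaled_dot_product_attention(q, k, v)
    upstream_grad = torch.rand_like(out)
    try:
        # Per the test above, backward is expected to raise on sm86/sm87/sm89.
        out.backward(upstream_grad)
    except RuntimeError as err:
        print("flash attention backward disabled for this head_dim:", err)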