|
52 | 52 | PLATFORM_SUPPORTS_CUDNN_ATTENTION, |
53 | 53 | tf32_on_and_off, |
54 | 54 | tf32_enabled, |
| 55 | + math_sdp_precision, |
55 | 56 | ) |
56 | 57 |
|
57 | 58 | if TEST_FAIRSEQ: |
@@ -126,6 +127,12 @@ def _check_equal( |
126 | 127 | _check_equal(gold, ref, tst, fudge_factor, tensor_name) |
127 | 128 | return |
128 | 129 |
|
| 130 | + if golden.is_cuda and golden.dtype == torch.float32: |
| 131 | + assert torch.backends.cuda.math_sdp.fp32_precision == "ieee", ( |
| 132 | +            "Testing script error: the FP32 golden tensor must be computed with IEEE"
| 133 | +            " precision. Add @math_sdp_precision('ieee') to the affected test to fix this."
| 134 | + ) |
| 135 | + |
129 | 136 | # Compute error between golden |
130 | 137 | test_error = (golden - test).abs().max() |
131 | 138 | ref_error = (golden - reference).abs().max() |
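
For context: `math_sdp_precision` is imported above from the common CUDA test utilities alongside `tf32_enabled`, but its implementation is not part of this diff. Below is a minimal sketch of what such a set-and-restore decorator could look like, assuming `torch.backends.cuda.math_sdp.fp32_precision` (the attribute the new assertion reads) is also writable; the actual helper in `torch.testing._internal.common_cuda` may differ.

```python
# Hypothetical sketch only -- not the implementation used in this PR.
import functools

import torch


def math_sdp_precision(precision: str):
    """Run the decorated test with math-SDP fp32 precision forced to `precision`."""
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            # Assumes torch.backends.cuda.math_sdp.fp32_precision is readable and
            # writable, as the assertion in _check_equal above implies.
            old_precision = torch.backends.cuda.math_sdp.fp32_precision
            torch.backends.cuda.math_sdp.fp32_precision = precision
            try:
                return fn(*args, **kwargs)
            finally:
                torch.backends.cuda.math_sdp.fp32_precision = old_precision
        return wrapper
    return decorator
```

Applied as `@math_sdp_precision("ieee")`, a decorator of this shape would force the math reference path to full IEEE FP32 for the duration of the test and restore the previous setting afterwards, which is the condition the golden-tensor assertion above enforces.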
@@ -3383,6 +3390,7 @@ def test_mem_eff_backwards_determinism(self, device): |
3383 | 3390 | ) |
3384 | 3391 | @parametrize("scale", [None, "l1"]) |
3385 | 3392 | @tf32_enabled() |
| 3393 | + @math_sdp_precision("ieee") |
3386 | 3394 | def test_mem_efficient_attention_vs_math_ref_grads(self, device, batch_size: int, seq_len_q: int, seq_len_k: int, |
3387 | 3395 | head_dim: int, is_causal: bool, dropout_p: float, dtype: torch.dtype, |
3388 | 3396 | scale: str): |
@@ -3498,6 +3506,7 @@ def _get_mem_eff_drop_mask(batch_size, n_heads, q_len, kv_len, p, seed, offset, |
3498 | 3506 | ) |
3499 | 3507 | @parametrize("scale", [None, "l1"]) |
3500 | 3508 | @tf32_enabled() |
| 3509 | + @math_sdp_precision("ieee") |
3501 | 3510 | def test_mem_efficient_attention_attn_mask_vs_math_ref_grads(self, device, batch_size: int, seq_len_q: int, |
3502 | 3511 | seq_len_k: int, head_dim: int, is_causal: bool, |
3503 | 3512 | dropout_p: float, dtype: torch.dtype, |
@@ -3611,6 +3620,7 @@ def _get_mem_eff_drop_mask(batch_size, n_heads, q_len, kv_len, p, seed, offset, |
3611 | 3620 | @parametrize("enable_gqa", [True, False]) |
3612 | 3621 | @parametrize("n_heads", [[16, 8], [10, 2]]) |
3613 | 3622 | @tf32_enabled() |
| 3623 | + @math_sdp_precision("ieee") |
3614 | 3624 | def test_flash_attention_vs_math_ref_grads(self, device, batch_size: int, seq_len_q: int, seq_len_k: int, |
3615 | 3625 | head_dim: int, is_causal: bool, dropout_p: float, dtype: torch.dtype, |
3616 | 3626 | scale: str, enable_gqa: bool, n_heads: list[int]): |
@@ -3756,6 +3766,7 @@ def test_flash_attention_vs_math_ref_grads(self, device, batch_size: int, seq_le |
3756 | 3766 | @parametrize("scale", [None, "l1"]) |
3757 | 3767 | @parametrize("fused_kernel", PLATFORM_SPECIFIC_SDPA) |
3758 | 3768 | @tf32_enabled() |
| 3769 | + @math_sdp_precision("ieee") |
3759 | 3770 | def test_fused_attention_vs_math_ref_grads_cudagraph(self, device, batch_size: int, |
3760 | 3771 | seq_len_q: int, seq_len_k: int, |
3761 | 3772 | head_dim: int, |
@@ -4070,6 +4081,7 @@ def test_fused_kernels_nested_broadcasting_query_dense(self, device): |
4070 | 4081 | @parametrize("dtype", [torch.float16]) |
4071 | 4082 | @parametrize("scale", [None, "l1"]) |
4072 | 4083 | @parametrize("is_causal", [True, False]) |
| 4084 | + @math_sdp_precision("ieee") |
4073 | 4085 | def test_flash_attention_vs_math_ref_grads_nestedtensor(self, device, batch_size: int, max_seq_len_q: int, max_seq_len_kv: int, |
4074 | 4086 | head_dim: int, dropout_p: float, dtype: torch.dtype, |
4075 | 4087 | scale: str, is_causal: bool): |
|