
Commit 84024f9

Add torch.backends.cuda.math_sdp.fp32_precision (#2844)
**Overview**

This PR adds a new float32 precision API, torch.backends.cuda.math_sdp.fp32_precision, to configure the fp32 precision behavior of SDPBackend.MATH.

**Rationale**

The test/test_transformers.py suite derives its numerical tolerance by comparing output tensors computed at the same precision ("reference") and at higher precision ("golden"), both produced by SDPBackend.MATH. However, the golden output is currently computed with TF32 rather than FP32, which can actually be less accurate than the FA/ME backends when they use IEEE rather than TF32 accumulation. This loss of precision causes false negatives in SDPA tests such as TestSDPACudaOnlyCUDA.test_flash_attention_vs_math_ref_grads_batch_size_8_seq_len_q_143_seq_len_k_4_head_dim_203_is_causal_False_dropout_p_0_22_float16_scale_l1_enable_gqa_True_n_heads1_cuda_float16, at least on the ROCm platform. The false negative disappears after forcing higher_precision_dtype = torch.float64.

**Major Changes**

To restore the precision of the golden output, a new API torch.backends.cuda.math_sdp.fp32_precision is introduced, which allows the "matmul" precision used by SDPBackend.MATH to be configured, and a new decorator @math_sdp_precision("ieee") is added to all tests that use check_out_and_grad. Finally, an assert is added to the innermost helper _check_equal as a sanity check to ensure math_sdp has the right precision configured for torch.float32 golden tensors.

**Known Issues**

The backward pass honors the configuration in effect when backward() is called, regardless of the configuration in effect when the graph was created.

---------

Co-authored-by: Xinya Zhang <Xinya.Zhang@amd.com>
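For context, here is a minimal usage sketch of the new knob (assuming a CUDA build with this change applied; tensor shapes are illustrative). It forces the math backend to accumulate its fp32 matmuls in IEEE precision so its output can serve as a golden reference for the fused kernels.

    import torch
    import torch.nn.functional as F
    from torch.nn.attention import SDPBackend, sdpa_kernel

    # Illustrative query/key/value tensors.
    q, k, v = (torch.randn(2, 8, 128, 64, device="cuda", dtype=torch.float32) for _ in range(3))

    # New knob from this PR: "ieee" disables TF32 accumulation inside SDPBackend.MATH.
    torch.backends.cuda.math_sdp.fp32_precision = "ieee"
    with sdpa_kernel(SDPBackend.MATH):
        golden = F.scaled_dot_product_attention(q, k, v)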
1 parent a47ec2b commit 84024f9

File tree

7 files changed: +75, −1 lines changed

aten/src/ATen/Context.cpp

Lines changed: 2 additions & 0 deletions
@@ -70,6 +70,8 @@ Float32Op str2op(const std::string& name) {
     return Float32Op::RNN;
   else if (name == "matmul")
     return Float32Op::MATMUL;
+  else if (name == "math_sdp")
+    return Float32Op::MATH_SDP;
   TORCH_CHECK(false, "Unknown op: ", name);
 }

aten/src/ATen/Context.h

Lines changed: 32 additions & 1 deletion
@@ -46,7 +46,7 @@ enum class CuBLASReductionOption : uint8_t {
   DisallowReducedPrecisionDisallowSplitK = 2,
 };
 enum class TORCH_API Float32Backend { GENERIC, CUDA, MKLDNN };
-enum class TORCH_API Float32Op { ALL, CONV, RNN, MATMUL };
+enum class TORCH_API Float32Op { ALL, CONV, RNN, MATMUL, MATH_SDP };
 enum class TORCH_API Float32Precision { NONE, IEEE, TF32, BF16 };

 TORCH_API Float32Backend str2backend(const std::string& name);
@@ -512,6 +512,7 @@ class TORCH_API Context {
       float32_matmul_precision == at::Float32MatmulPrecision::HIGHEST
           ? Float32Precision::NONE
           : Float32Precision::TF32},
+      {{Float32Backend::CUDA, Float32Op::MATH_SDP}, Float32Precision::NONE},
   };

   Allocator* prev_allocator_ptr_{nullptr};
@@ -684,6 +685,36 @@ struct TORCH_API NoTF32Guard {
   bool changed = false;
 };

+template <Float32Backend target_backend, Float32Op target_op>
+struct Fp32PrecisonGuard {
+  Fp32PrecisonGuard(const Float32Precision new_precision) {
+    if (new_precision == Float32Precision::NONE) {
+      return;
+    }
+    saved_precision =
+        globalContext().float32Precision(target_backend, target_op);
+    changed = (new_precision != saved_precision);
+    if (changed) {
+      globalContext().setFloat32Precision(
+          target_backend, target_op, new_precision);
+    }
+  }
+  Fp32PrecisonGuard(Fp32PrecisonGuard&& other) = delete;
+  Fp32PrecisonGuard(const Fp32PrecisonGuard&) = delete;
+  Fp32PrecisonGuard& operator=(const Fp32PrecisonGuard&) = delete;
+  Fp32PrecisonGuard& operator=(Fp32PrecisonGuard&&) = delete;
+  ~Fp32PrecisonGuard() {
+    if (changed) {
+      globalContext().setFloat32Precision(
+          target_backend, target_op, saved_precision);
+    }
+  }
+
+ private:
+  Float32Precision saved_precision;
+  bool changed = false;
+};
+
 struct TORCH_API ROCmBackwardPassGuard {
   ROCmBackwardPassGuard();
   ROCmBackwardPassGuard(ROCmBackwardPassGuard&& other) = delete;

aten/src/ATen/native/transformers/attention.cpp

Lines changed: 5 additions & 0 deletions
@@ -868,6 +868,11 @@ std::tuple<Tensor, Tensor> _scaled_dot_product_attention_math(
           ? value.to(at::kFloat)
           : value;
   auto attn_mask = attn_mask_;
+  const auto math_sdp_precision = at::globalContext().float32Precision(at::Float32Backend::CUDA, at::Float32Op::MATH_SDP);
+  // Temporarily override the matmul precision with the value from cuda.math_sdp.
+  // IEEE should be used when the fp32 math backend serves as the golden reference.
+  at::Fp32PrecisonGuard<at::Float32Backend::CUDA, at::Float32Op::MATMUL> fp32guard(math_sdp_precision);
+
   // Naive, composite implementation defined here.

   // Scale q, k before matmul for stability see https://tinyurl.com/sudb9s96 for

docs/source/conf.py

Lines changed: 1 addition & 0 deletions
@@ -2140,6 +2140,7 @@
     "PropModule",
     # torch.backends.cuda
     "cuBLASModule",
+    "MathSDPModule",
     "cuFFTPlanCache",
     "cuFFTPlanCacheAttrContextProp",
     "cuFFTPlanCacheManager",

test/test_transformers.py

Lines changed: 12 additions & 0 deletions
@@ -52,6 +52,7 @@
     PLATFORM_SUPPORTS_CUDNN_ATTENTION,
     tf32_on_and_off,
     tf32_enabled,
+    math_sdp_precision,
 )

 if TEST_FAIRSEQ:
@@ -126,6 +127,12 @@ def _check_equal(
             _check_equal(gold, ref, tst, fudge_factor, tensor_name)
         return

+    if golden.is_cuda and golden.dtype == torch.float32:
+        assert torch.backends.cuda.math_sdp.fp32_precision == "ieee", (
+            "Testing script error: FP32 golden tensor must be calculated with IEEE"
+            " precision. Add @math_sdp_precision('ieee') to related tests to fix it."
+        )
+
     # Compute error between golden
     test_error = (golden - test).abs().max()
     ref_error = (golden - reference).abs().max()
@@ -3383,6 +3390,7 @@ def test_mem_eff_backwards_determinism(self, device):
     )
     @parametrize("scale", [None, "l1"])
     @tf32_enabled()
+    @math_sdp_precision("ieee")
     def test_mem_efficient_attention_vs_math_ref_grads(self, device, batch_size: int, seq_len_q: int, seq_len_k: int,
                                                        head_dim: int, is_causal: bool, dropout_p: float, dtype: torch.dtype,
                                                        scale: str):
@@ -3498,6 +3506,7 @@ def _get_mem_eff_drop_mask(batch_size, n_heads, q_len, kv_len, p, seed, offset,
     )
     @parametrize("scale", [None, "l1"])
     @tf32_enabled()
+    @math_sdp_precision("ieee")
     def test_mem_efficient_attention_attn_mask_vs_math_ref_grads(self, device, batch_size: int, seq_len_q: int,
                                                                  seq_len_k: int, head_dim: int, is_causal: bool,
                                                                  dropout_p: float, dtype: torch.dtype,
@@ -3611,6 +3620,7 @@ def _get_mem_eff_drop_mask(batch_size, n_heads, q_len, kv_len, p, seed, offset,
     @parametrize("enable_gqa", [True, False])
     @parametrize("n_heads", [[16, 8], [10, 2]])
     @tf32_enabled()
+    @math_sdp_precision("ieee")
     def test_flash_attention_vs_math_ref_grads(self, device, batch_size: int, seq_len_q: int, seq_len_k: int,
                                                head_dim: int, is_causal: bool, dropout_p: float, dtype: torch.dtype,
                                                scale: str, enable_gqa: bool, n_heads: list[int]):
@@ -3756,6 +3766,7 @@ def test_flash_attention_vs_math_ref_grads(self, device, batch_size: int, seq_le
     @parametrize("scale", [None, "l1"])
     @parametrize("fused_kernel", PLATFORM_SPECIFIC_SDPA)
     @tf32_enabled()
+    @math_sdp_precision("ieee")
     def test_fused_attention_vs_math_ref_grads_cudagraph(self, device, batch_size: int,
                                                          seq_len_q: int, seq_len_k: int,
                                                          head_dim: int,
@@ -4070,6 +4081,7 @@ def test_fused_kernels_nested_broadcasting_query_dense(self, device):
     @parametrize("dtype", [torch.float16])
     @parametrize("scale", [None, "l1"])
     @parametrize("is_causal", [True, False])
+    @math_sdp_precision("ieee")
     def test_flash_attention_vs_math_ref_grads_nestedtensor(self, device, batch_size: int, max_seq_len_q: int, max_seq_len_kv: int,
                                                             head_dim: int, dropout_p: float, dtype: torch.dtype,
                                                             scale: str, is_causal: bool):
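To illustrate why the golden tensor's precision matters for these tests, below is a rough sketch of the comparison pattern _check_equal follows (the exact tolerance rule and fudge_factor handling in the suite may differ; the assert shown is an assumed simplification):

    import torch

    def check_against_golden(golden: torch.Tensor, reference: torch.Tensor,
                             test: torch.Tensor, fudge_factor: float) -> None:
        # "golden" comes from SDPBackend.MATH at higher precision, "reference" from
        # SDPBackend.MATH at the test dtype, and "test" from the fused kernel under test.
        test_error = (golden - test).abs().max()
        ref_error = (golden - reference).abs().max()
        # If golden itself was computed with TF32, both error measurements are skewed
        # and no longer reflect the fused kernel's true accuracy, which is what
        # produced the false negatives this PR fixes.
        assert test_error <= fudge_factor * ref_error, (
            f"test_error={test_error} exceeds {fudge_factor} * ref_error={ref_error}"
        )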

torch/backends/cuda/__init__.py

Lines changed: 14 additions & 0 deletions
@@ -12,6 +12,7 @@
     "cuFFTPlanCache",
     "cuFFTPlanCacheManager",
     "cuBLASModule",
+    "MathSDPModule",
     "preferred_linalg_library",
     "preferred_blas_library",
     "preferred_rocm_fa_library",
@@ -206,6 +207,18 @@ def __setattr__(self, name, value):
         raise AttributeError("Unknown attribute " + name)


+class MathSDPModule:
+    def __getattr__(self, name):
+        if name == "fp32_precision":
+            return torch._C._get_fp32_precision_getter("cuda", "math_sdp")
+        raise AttributeError("Unknown attribute " + name)
+
+    def __setattr__(self, name, value):
+        if name == "fp32_precision":
+            return torch._C._set_fp32_precision_setter("cuda", "math_sdp", value)
+        raise AttributeError("Unknown attribute " + name)
+
+
 _LinalgBackends = {
     "default": torch._C._LinalgBackend.Default,
     "cusolver": torch._C._LinalgBackend.Cusolver,
@@ -591,3 +604,4 @@ def sdp_kernel(

 cufft_plan_cache = cuFFTPlanCacheManager()
 matmul = cuBLASModule()
+math_sdp = MathSDPModule()

torch/testing/_internal/common_cuda.py

Lines changed: 9 additions & 0 deletions
@@ -221,6 +221,15 @@ def tf32_enabled():
         torch.backends.cuda.matmul.allow_tf32 = old_allow_tf32_matmul


+@contextlib.contextmanager
+def math_sdp_precision(target_precision: str):
+    saved_precision = torch.backends.cuda.math_sdp.fp32_precision
+    try:
+        torch.backends.cuda.math_sdp.fp32_precision = target_precision
+        yield
+    finally:
+        torch.backends.cuda.math_sdp.fp32_precision = saved_precision
+
 # This is a wrapper that wraps a test to run this test twice, one with
 # allow_tf32=True, another with allow_tf32=False. When running with
 # allow_tf32=True, it will use reduced precision as specified by the
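Because the helper above is built with @contextlib.contextmanager, it works both as a with-block and as a decorator, which is how the tests in test/test_transformers.py apply it. A minimal sketch (the test function shown is hypothetical, and assumes this patch is applied):

    import torch
    from torch.testing._internal.common_cuda import math_sdp_precision

    # As a context manager around a golden-reference computation:
    with math_sdp_precision("ieee"):
        assert torch.backends.cuda.math_sdp.fp32_precision == "ieee"

    # As a decorator on a test (hypothetical example for illustration):
    @math_sdp_precision("ieee")
    def test_something_vs_math_golden():
        ...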
