From 01e55aeb72e0ed1117a42cfbb2331545cd3165b3 Mon Sep 17 00:00:00 2001 From: Aristotle Martin Date: Thu, 21 May 2026 11:00:41 -0400 Subject: [PATCH] Skip precise MXFP8 JAX comparison on HIP --- tests/jax/test_custom_call_compute.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 3d9362cb4..113efbaab 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -560,7 +560,7 @@ def _test_norm_forward( precise_comparison = True - if get_cudnn_version() < (9, 10, 0) and scaling_mode == ScalingMode.MXFP8_1D_SCALING: + if (get_cudnn_version() < (9, 10, 0) or is_hip_extension()) and scaling_mode == ScalingMode.MXFP8_1D_SCALING: # Reduce precision of test as we don't use fused norm below this version CuDNN for MXFP8 and instead # do an unfused norm and quantize with an intermediate cast into in_dtype which can reduce precision precise_comparison = False