diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py
index 9ff0c11757..b81f3fb9bf 100644
--- a/tests/jax/test_custom_call_compute.py
+++ b/tests/jax/test_custom_call_compute.py
@@ -57,8 +57,8 @@ FP8_COMPUTE_TYPE = [jnp.float8_e4m3fn, jnp.float8_e5m2]
 LN_CASES = [(256, 128), (128, 256)]
 DTYPES = [jnp.bfloat16, jnp.float32]
 
-is_fp8_supported, reason = helper.is_fp8_available()
-is_mxfp8_supported, reason = helper.is_fp8_available(ScalingMode.MXFP8_1D_SCALING)
+is_fp8_supported, fp8_unsupported_reason = helper.is_fp8_available()
+is_mxfp8_supported, mxfp8_unsupported_reason = helper.is_fp8_available(ScalingMode.MXFP8_1D_SCALING)
 
 supported_scaling_modes = []
 """ Find supported scaling modes"""
@@ -209,7 +209,7 @@ def test_act_grad(self, shape, activation_type):
         assert_allclose(prim_out, ref_out, dtype=x.dtype)
         assert_allclose(prim_grad, ref_grad, dtype=x.dtype)
 
-    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
+    @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
     @pytest_parametrize_wrapper("shape", ALL_ACTIVATION_SHAPES)
     @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
     @pytest_parametrize_wrapper("output_type", [jnp.float8_e4m3fn, jnp.float8_e5m2])
@@ -240,7 +240,7 @@ def test_act_grad_with_tensor_scaling_fp8(
         assert_allclose(prim_out, ref_out, dtype=output_type)
         assert_allclose(prim_grad, ref_grad, dtype=output_type)
 
-    @pytest.mark.skipif(not is_mxfp8_supported, reason=reason)
+    @pytest.mark.skipif(not is_mxfp8_supported, reason=mxfp8_unsupported_reason)
     @pytest_parametrize_wrapper("shape", ALL_ACTIVATION_SHAPES)
     @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
     @pytest_parametrize_wrapper("output_type", [jnp.float8_e4m3fn, jnp.float8_e5m2])
@@ -270,7 +270,7 @@ def test_act_forward_with_tensor_scaling_fp8(
 
         assert_bitwise_scaled_tensors(te_output, jax_output)
 
-    @pytest.mark.skipif(not is_mxfp8_supported, reason=reason)
+    @pytest.mark.skipif(not is_mxfp8_supported, reason=mxfp8_unsupported_reason)
     @pytest_parametrize_wrapper("shape", [(2, 64, 1, 256)])
     @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
     @pytest_parametrize_wrapper("output_type", [jnp.float8_e4m3fn, jnp.float8_e5m2])
@@ -391,7 +391,7 @@ def test_norm_grad(self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp
             n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, quantizer=None
         )
 
-    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
+    @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
     # No Norm FWD E5M2 in TE backend
     @pytest_parametrize_wrapper("out_dtype", [jnp.float8_e4m3fn])
     @pytest_parametrize_wrapper(
@@ -506,7 +506,7 @@ def _test_norm_forward(
         if norm_type == "layernorm":
             assert_allclose(mu, ref_mu, dtype=inp_dtype)
 
-    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
+    @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
     # No Norm FWD E5M2 in TE backend
     @pytest_parametrize_wrapper("out_dtype", [jnp.float8_e4m3fn])
     @pytest_parametrize_wrapper(
@@ -542,7 +542,7 @@ def test_norm_forward_with_tensor_scaling_fp8(
             q_layout=q_layout,
         )
 
-    @pytest.mark.skipif(not is_mxfp8_supported, reason=reason)
+    @pytest.mark.skipif(not is_mxfp8_supported, reason=mxfp8_unsupported_reason)
     @pytest.mark.parametrize("out_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
     def test_norm_forward_with_block_scaling_fp8(
         self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, out_dtype
@@ -591,7 +591,7 @@ def test_norm_forward_with_block_scaling_fp8(
 }
 
 
-@pytest.mark.skipif(not is_fp8_supported, reason=reason)
+@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
 @pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE)
 @pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
 @pytest_parametrize_wrapper("input_shape,flatten_axis", ALL_QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES)
@@ -638,7 +638,7 @@ def test_quantize_bitwise(
     assert_bitwise_scaled_tensors(te_output, jax_output)
 
 
-@pytest.mark.skipif(not is_fp8_supported, reason=reason)
+@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
 @pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE)
 @pytest_parametrize_wrapper("input_shape", [(8, 16, 32)])
 @pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn])
@@ -692,7 +692,7 @@ def test_grouped_qdq(
 
 @pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE)
 class TestFusedQuantize:
-    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
+    @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
     @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
     @pytest_parametrize_wrapper("input_shape,flatten_axis", QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES)
     @pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_DTYPES)
@@ -793,7 +793,7 @@ def test_quantize_dact_dbias_no_quantization(
             q_layout=QuantizeLayout.ROWWISE,
         )
 
-    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
+    @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
     @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
     @pytest_parametrize_wrapper("input_shape", ALL_ACTIVATION_SHAPES)
     @pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_DTYPES)
@@ -817,7 +817,7 @@ def test_quantize_dact_dbias_tensor_scaling(
             q_layout=q_layout,
         )
 
-    @pytest.mark.skipif(not is_mxfp8_supported, reason=reason)
+    @pytest.mark.skipif(not is_mxfp8_supported, reason=mxfp8_unsupported_reason)
     @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
     @pytest_parametrize_wrapper(
         "input_shape", [s for s in ALL_ACTIVATION_SHAPES if is_shape_supported_by_mxfp8(s)]
@@ -886,7 +886,7 @@ def test_gemm_bf16(self, m, n, k, data_layout):
 
         assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16)
 
-    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
+    @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
     @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
     @pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
     @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
@@ -928,7 +928,7 @@ def ref_func(x, w, data_layout):
         assert_allclose(primitive_x_grad, ref_x_grad, dtype=jnp.bfloat16)
         assert_allclose(primitive_w_grad, ref_w_grad, dtype=jnp.bfloat16)
 
-    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
+    @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
     @pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
     @pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
     @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
@@ -992,7 +992,7 @@ def _ref_jax_norm_impl(x, gamma, beta, norm_type, zero_centered_gamma, eps, quan
 
 
 class TestFusedDense:
-    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
+    @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
     @pytest.mark.parametrize("m,n,k", [(64, 32, 64)])
     @pytest.mark.parametrize("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
     @pytest.mark.parametrize("scaling_mode", supported_scaling_modes)
@@ -1077,7 +1077,7 @@ def ref_func(x, w, gamma, beta):
         if beta is not None:
             assert_allclose(prim_beta_grad, ref_beta_grad, dtype=q_dtype)
 
-    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
+    @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
     @pytest.mark.parametrize("m,n,k", [(64, 32, 64)])
    @pytest.mark.parametrize("activation_type", [("gelu",), ("gelu", "linear")])
     @pytest.mark.parametrize("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
@@ -1284,7 +1284,7 @@ def test_grouped_gemm_fp16(self, dtype, input_shape, layout):
         prim_out = tex.grouped_gemm(lhs, rhs, group_sizes, contracting_dims)
         self._assert_grouped_gemm_output(prim_out, group_sizes, ref_out, dtype)
 
-    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
+    @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
     @pytest.mark.parametrize("fwd_bwd_dtype", fwd_bwd_dtypes)
     @pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
     @pytest_parametrize_wrapper("layout", ["NN"])
@@ -1360,7 +1360,7 @@ def test_grouped_dense_grad_fp16(self, dtype, input_shape):
         assert_allclose(prim_wgrad, ref_wgrad, dtype=dtype)
         assert_allclose(prim_dbias, ref_dbias, dtype=dtype)
 
-    @pytest.mark.skipif(not is_fp8_supported, reason=reason)
+    @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
     @pytest.mark.parametrize(
         "fwd_bwd_dtype",
         [(jnp.float8_e4m3fn, jnp.float8_e4m3fn), (jnp.float8_e4m3fn, jnp.float8_e5m2)],
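
Note on the rename: `helper.is_fp8_available()` returns a `(supported, reason)` pair, and the old code unpacked both the FP8 and the MXFP8 queries into the same `reason` name, so the second assignment overwrote the first and every `skipif` marker reported the MXFP8 message regardless of which capability it was checking. Giving each reason its own name keeps the skip message tied to the check that produced it. A minimal, self-contained sketch of the pattern follows; the stubbed `is_fp8_available` and its messages are placeholders, not the real `transformer_engine.jax` helper.

import pytest


# Placeholder stand-in for helper.is_fp8_available(); it mirrors only the
# return shape implied by the diff: a (supported, reason) pair.
def is_fp8_available(scaling_mode=None):
    if scaling_mode == "MXFP8_1D_SCALING":
        return False, "MXFP8 not supported on this device (placeholder message)"
    return False, "FP8 not supported on this device (placeholder message)"


# Distinct names keep each skip reason paired with its own capability check;
# reusing a single `reason` lets the second unpacking overwrite the first.
is_fp8_supported, fp8_unsupported_reason = is_fp8_available()
is_mxfp8_supported, mxfp8_unsupported_reason = is_fp8_available("MXFP8_1D_SCALING")


@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
def test_requires_fp8():
    ...


@pytest.mark.skipif(not is_mxfp8_supported, reason=mxfp8_unsupported_reason)
def test_requires_mxfp8():
    ...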