tests/jax/test_custom_call_compute.py (38 changes: 19 additions & 19 deletions)
@@ -57,8 +57,8 @@
FP8_COMPUTE_TYPE = [jnp.float8_e4m3fn, jnp.float8_e5m2]
LN_CASES = [(256, 128), (128, 256)]
DTYPES = [jnp.bfloat16, jnp.float32]
-is_fp8_supported, reason = helper.is_fp8_available()
-is_mxfp8_supported, reason = helper.is_fp8_available(ScalingMode.MXFP8_1D_SCALING)
+is_fp8_supported, fp8_unsupported_reason = helper.is_fp8_available()
+is_mxfp8_supported, mxfp8_unsupported_reason = helper.is_fp8_available(ScalingMode.MXFP8_1D_SCALING)

supported_scaling_modes = []
""" Find supported scaling modes"""
@@ -209,7 +209,7 @@ def test_act_grad(self, shape, activation_type):
assert_allclose(prim_out, ref_out, dtype=x.dtype)
assert_allclose(prim_grad, ref_grad, dtype=x.dtype)

-@pytest.mark.skipif(not is_fp8_supported, reason=reason)
+@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
@pytest_parametrize_wrapper("shape", ALL_ACTIVATION_SHAPES)
@pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
@pytest_parametrize_wrapper("output_type", [jnp.float8_e4m3fn, jnp.float8_e5m2])
@@ -240,7 +240,7 @@ def test_act_grad_with_tensor_scaling_fp8(
assert_allclose(prim_out, ref_out, dtype=output_type)
assert_allclose(prim_grad, ref_grad, dtype=output_type)

-@pytest.mark.skipif(not is_mxfp8_supported, reason=reason)
+@pytest.mark.skipif(not is_mxfp8_supported, reason=mxfp8_unsupported_reason)
@pytest_parametrize_wrapper("shape", ALL_ACTIVATION_SHAPES)
@pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
@pytest_parametrize_wrapper("output_type", [jnp.float8_e4m3fn, jnp.float8_e5m2])
@@ -270,7 +270,7 @@ def test_act_forward_with_tensor_scaling_fp8(

assert_bitwise_scaled_tensors(te_output, jax_output)

-@pytest.mark.skipif(not is_mxfp8_supported, reason=reason)
+@pytest.mark.skipif(not is_mxfp8_supported, reason=mxfp8_unsupported_reason)
@pytest_parametrize_wrapper("shape", [(2, 64, 1, 256)])
@pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
@pytest_parametrize_wrapper("output_type", [jnp.float8_e4m3fn, jnp.float8_e5m2])
@@ -391,7 +391,7 @@ def test_norm_grad(self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp
n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, quantizer=None
)

-@pytest.mark.skipif(not is_fp8_supported, reason=reason)
+@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
# No Norm FWD E5M2 in TE backend
@pytest_parametrize_wrapper("out_dtype", [jnp.float8_e4m3fn])
@pytest_parametrize_wrapper(
@@ -506,7 +506,7 @@ def _test_norm_forward(
if norm_type == "layernorm":
assert_allclose(mu, ref_mu, dtype=inp_dtype)

-@pytest.mark.skipif(not is_fp8_supported, reason=reason)
+@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
# No Norm FWD E5M2 in TE backend
@pytest_parametrize_wrapper("out_dtype", [jnp.float8_e4m3fn])
@pytest_parametrize_wrapper(
@@ -542,7 +542,7 @@ def test_norm_forward_with_tensor_scaling_fp8(
q_layout=q_layout,
)

-@pytest.mark.skipif(not is_mxfp8_supported, reason=reason)
+@pytest.mark.skipif(not is_mxfp8_supported, reason=mxfp8_unsupported_reason)
@pytest.mark.parametrize("out_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
def test_norm_forward_with_block_scaling_fp8(
self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, out_dtype
@@ -591,7 +591,7 @@ def test_norm_forward_with_block_scaling_fp8(
}


-@pytest.mark.skipif(not is_fp8_supported, reason=reason)
+@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
@pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE)
@pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
@pytest_parametrize_wrapper("input_shape,flatten_axis", ALL_QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES)
@@ -638,7 +638,7 @@ def test_quantize_bitwise(
assert_bitwise_scaled_tensors(te_output, jax_output)


-@pytest.mark.skipif(not is_fp8_supported, reason=reason)
+@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
@pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE)
@pytest_parametrize_wrapper("input_shape", [(8, 16, 32)])
@pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn])
@@ -692,7 +692,7 @@ def test_grouped_qdq(
@pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE)
class TestFusedQuantize:

-@pytest.mark.skipif(not is_fp8_supported, reason=reason)
+@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
@pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
@pytest_parametrize_wrapper("input_shape,flatten_axis", QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES)
@pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_DTYPES)
@@ -793,7 +793,7 @@ def test_quantize_dact_dbias_no_quantization(
q_layout=QuantizeLayout.ROWWISE,
)

-@pytest.mark.skipif(not is_fp8_supported, reason=reason)
+@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
@pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
@pytest_parametrize_wrapper("input_shape", ALL_ACTIVATION_SHAPES)
@pytest_parametrize_wrapper("out_dtype", QUANTIZE_OUTPUT_DTYPES)
@@ -817,7 +817,7 @@ def test_quantize_dact_dbias_tensor_scaling(
q_layout=q_layout,
)

-@pytest.mark.skipif(not is_mxfp8_supported, reason=reason)
+@pytest.mark.skipif(not is_mxfp8_supported, reason=mxfp8_unsupported_reason)
@pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
@pytest_parametrize_wrapper(
"input_shape", [s for s in ALL_ACTIVATION_SHAPES if is_shape_supported_by_mxfp8(s)]
@@ -886,7 +886,7 @@ def test_gemm_bf16(self, m, n, k, data_layout):

assert_allclose(primitive_out, ref_out, dtype=jnp.bfloat16)

-@pytest.mark.skipif(not is_fp8_supported, reason=reason)
+@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
@pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
@pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
@pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
@@ -928,7 +928,7 @@ def ref_func(x, w, data_layout):
assert_allclose(primitive_x_grad, ref_x_grad, dtype=jnp.bfloat16)
assert_allclose(primitive_w_grad, ref_w_grad, dtype=jnp.bfloat16)

-@pytest.mark.skipif(not is_fp8_supported, reason=reason)
+@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
@pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
@pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
@pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
@@ -992,7 +992,7 @@ def _ref_jax_norm_impl(x, gamma, beta, norm_type, zero_centered_gamma, eps, quan


class TestFusedDense:
-@pytest.mark.skipif(not is_fp8_supported, reason=reason)
+@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
@pytest.mark.parametrize("m,n,k", [(64, 32, 64)])
@pytest.mark.parametrize("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
@pytest.mark.parametrize("scaling_mode", supported_scaling_modes)
@@ -1077,7 +1077,7 @@ def ref_func(x, w, gamma, beta):
if beta is not None:
assert_allclose(prim_beta_grad, ref_beta_grad, dtype=q_dtype)

-@pytest.mark.skipif(not is_fp8_supported, reason=reason)
+@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
@pytest.mark.parametrize("m,n,k", [(64, 32, 64)])
@pytest.mark.parametrize("activation_type", [("gelu",), ("gelu", "linear")])
@pytest.mark.parametrize("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2])
@@ -1284,7 +1284,7 @@ def test_grouped_gemm_fp16(self, dtype, input_shape, layout):
prim_out = tex.grouped_gemm(lhs, rhs, group_sizes, contracting_dims)
self._assert_grouped_gemm_output(prim_out, group_sizes, ref_out, dtype)

-@pytest.mark.skipif(not is_fp8_supported, reason=reason)
+@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
@pytest.mark.parametrize("fwd_bwd_dtype", fwd_bwd_dtypes)
@pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
@pytest_parametrize_wrapper("layout", ["NN"])
@@ -1360,7 +1360,7 @@ def test_grouped_dense_grad_fp16(self, dtype, input_shape):
assert_allclose(prim_wgrad, ref_wgrad, dtype=dtype)
assert_allclose(prim_dbias, ref_dbias, dtype=dtype)

-@pytest.mark.skipif(not is_fp8_supported, reason=reason)
+@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
@pytest.mark.parametrize(
"fwd_bwd_dtype",
[(jnp.float8_e4m3fn, jnp.float8_e4m3fn), (jnp.float8_e4m3fn, jnp.float8_e5m2)],
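
A closing note on `supported_scaling_modes`, which several hunks above use to parametrize tests: the diff's context shows it initialized to an empty list near the top of the file, and a plausible population from the two availability flags is sketched below. This is an assumption about code outside the diff; of the enum members, only `ScalingMode.MXFP8_1D_SCALING` is confirmed by the hunks above.

# Sketch (assumed, continuing the module-level code from the first hunk):
# populate supported_scaling_modes from the availability flags.
# DELAYED_TENSOR_SCALING is a guess at the non-MXFP8 member's name.
supported_scaling_modes = []
if is_fp8_supported:
    supported_scaling_modes.append(ScalingMode.DELAYED_TENSOR_SCALING)
if is_mxfp8_supported:
    supported_scaling_modes.append(ScalingMode.MXFP8_1D_SCALING)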