diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_trtllm_moe.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_trtllm_moe.py
index 0717fc292a8..5eb6bcfaa38 100644
--- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_trtllm_moe.py
+++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_trtllm_moe.py
@@ -119,7 +119,7 @@ def relu2(x: torch.Tensor) -> torch.Tensor:
 
 
 def _get_test_data(
-    otype, wtype, batch_size, hidden_size, num_experts, intermediate_size, X_GEN_SCALE
+    otype, wtype, batch_size, hidden_size, num_experts, intermediate_size, X_GEN_SCALE, W_GEN_SCALE
 ):
     input_shape = (batch_size, hidden_size)
     w31_shape = (num_experts, 2 * intermediate_size, hidden_size)
@@ -127,8 +127,8 @@ def _get_test_data(
 
     x = cast_to_representable(gen_tensor(input_shape, otype, scale=X_GEN_SCALE))
     router_logits = gen_tensor((batch_size, num_experts), otype)
-    w31_weight = gen_tensor(w31_shape, otype, wtype)
-    w2_weight = gen_tensor(w2_shape, otype, wtype)
+    w31_weight = gen_tensor(w31_shape, otype, wtype, W_GEN_SCALE)
+    w2_weight = gen_tensor(w2_shape, otype, wtype, W_GEN_SCALE)
     w31_empty_scales = torch.empty(num_experts, 2, dtype=otype).cuda()
     w2_empty_scales = torch.empty(num_experts, 1, dtype=otype).cuda()
     return x, router_logits, w31_weight, w2_weight, w31_empty_scales, w2_empty_scales
@@ -203,9 +203,17 @@ def test_trtllm_fused_moe(
         X_GEN_SCALE = 1.0
     else:
         X_GEN_SCALE = 0.5
+    W_GEN_SCALE = 0.1
 
     x, router_logits, w31_weight, w2_weight, w31_scales, w2_scales = _get_test_data(
-        otype, wtype, batch_size, hidden_size, num_experts, intermediate_size, X_GEN_SCALE
+        otype,
+        wtype,
+        batch_size,
+        hidden_size,
+        num_experts,
+        intermediate_size,
+        X_GEN_SCALE,
+        W_GEN_SCALE,
     )
 
     routing_weights, selected_experts = compute_routing(router_logits, top_k)
@@ -278,14 +286,14 @@ def get_fc1_expert_weights(
         w1_weight.contiguous(),
         w2_weight.contiguous(),
     )[0].view(x.shape)
-    torch.testing.assert_close(output_triton_moe, ad_test_output, rtol=1e-1, atol=1e-1)
+    torch.testing.assert_close(output_triton_moe, ad_test_output, rtol=1e-2, atol=1e-2)
 
     diff = (ref_output - ad_test_output).abs()
     print(f"max diff: {diff.max()}")
     torch.testing.assert_close(ad_test_output, trtllm_test_output, rtol=1e-6, atol=1e-6)
 
     _print_diff_if(lambda diff: diff.max() > 1e-1, diff, ad_test_output, ref_output)
-    torch.testing.assert_close(ref_output, ad_test_output, rtol=1e-1, atol=1e-1)
+    torch.testing.assert_close(ref_output, ad_test_output, rtol=1e-2, atol=1e-2)
 
 
 FP8_TEST_DTYPES = [
@@ -305,7 +313,7 @@ def get_fc1_expert_weights(
     not fp8_compatible() or not trtllm_ops_available(),
     reason="Requires fp8 and trtllm support",
 )
-def test_trtllm_fused_fp8moe(
+def test_trtllm_fused_moe_fp8(
     batch_size,
     hidden_size,
     num_experts,
@@ -333,7 +341,9 @@ def test_trtllm_fused_fp8moe(
     else:
         X_GEN_SCALE = 0.5
 
-    def dequantize_weights(w31_weight, w2_weight, w31_scales, w2_scales):
+    W_GEN_SCALE = 0.1
+
+    def dequantize_weights(w31_weight, w2_weight, w31_scales, w2_scales, W_GEN_SCALE):
         # input_shape = (batch_size, hidden_size)
         w31_shape = (num_experts, 2 * intermediate_size, hidden_size)
         w2_shape = (num_experts, hidden_size, intermediate_size)
@@ -341,8 +351,8 @@ def dequantize_weights(w31_weight, w2_weight, w31_scales, w2_scales):
         w31_dequantized = gen_tensor(w31_weight.shape, otype)
         w2_dequantized = gen_tensor(w2_weight.shape, otype)
         for expert_id in range(num_experts):
-            w31 = cast_to_representable(gen_tensor(w31_shape[1:], otype, scale=0.1))
-            w2 = cast_to_representable(gen_tensor(w2_shape[1:], otype, scale=0.09))
+            w31 = cast_to_representable(gen_tensor(w31_shape[1:], otype, scale=W_GEN_SCALE))
+            w2 = cast_to_representable(gen_tensor(w2_shape[1:], otype, scale=W_GEN_SCALE))
             w31_quant, s31 = dynamic_per_tensor_fp8_quant(w31)
             w2_quant, s2 = dynamic_per_tensor_fp8_quant(w2)
             w31_weight.data[expert_id].copy_(w31_quant)
@@ -354,11 +364,18 @@ def dequantize_weights(w31_weight, w2_weight, w31_scales, w2_scales):
         return w31_dequantized, w2_dequantized
 
     x, router_logits, w31_weight, w2_weight, w31_scales, w2_scales = _get_test_data(
-        otype, wtype, batch_size, hidden_size, num_experts, intermediate_size, X_GEN_SCALE
+        otype,
+        wtype,
+        batch_size,
+        hidden_size,
+        num_experts,
+        intermediate_size,
+        X_GEN_SCALE,
+        W_GEN_SCALE,
    )
 
     w31_dequantized, w2_dequantized = dequantize_weights(
-        w31_weight, w2_weight, w31_scales, w2_scales
+        w31_weight, w2_weight, w31_scales, w2_scales, W_GEN_SCALE
     )
 
     routing_weights, selected_experts = compute_routing(router_logits, top_k)
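
Both tests end by calling compute_routing(router_logits, top_k) to derive per-token expert ids and mixing weights. A common softmax-then-top-k formulation is sketched below purely as an assumption about what the helper computes; the actual compute_routing in this test file may differ (e.g. in whether it renormalizes):

    import torch

    def compute_routing(router_logits: torch.Tensor, top_k: int):
        # Softmax over the expert dimension, then keep the top-k experts per token.
        probs = torch.softmax(router_logits.float(), dim=-1)
        routing_weights, selected_experts = torch.topk(probs, top_k, dim=-1)
        # Renormalize so each token's selected weights sum to 1.
        routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
        return routing_weights, selected_experts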
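
For the FP8 path, dynamic_per_tensor_fp8_quant returns a quantized tensor plus one dynamic scale per tensor, and dequantize_weights recovers the otype reference weights from that pair. Below is a minimal plain-PyTorch sketch of that quantize/dequantize round-trip, assuming E4M3 quantization; per_tensor_fp8_quant here is a hypothetical stand-in for illustration, not the helper under test:

    import torch

    def per_tensor_fp8_quant(w: torch.Tensor):
        # One dynamic scale for the whole tensor, sized so the absolute max
        # maps to the FP8 E4M3 maximum (448.0).
        fp8_max = torch.finfo(torch.float8_e4m3fn).max
        scale = w.abs().max().float() / fp8_max
        w_q = (w.float() / scale).clamp(-fp8_max, fp8_max).to(torch.float8_e4m3fn)
        return w_q, scale

    # Round trip: quantize, then dequantize back to the original dtype.
    w = torch.randn(256, 128, dtype=torch.bfloat16) * 0.1  # W_GEN_SCALE-sized weights
    w_q, s = per_tensor_fp8_quant(w)
    w_dq = (w_q.to(torch.float32) * s).to(torch.bfloat16)
    print((w.float() - w_dq.float()).abs().max())  # round-trip error stays small

Drawing weights at W_GEN_SCALE = 0.1 keeps this round-trip error small relative to the MoE outputs, which is presumably what allows the tolerances above to tighten from rtol/atol=1e-1 to 1e-2.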