From cb46e7692eceb3180843203df13561358ff0ee46 Mon Sep 17 00:00:00 2001 From: zhoutianzi666 <17801055074@163.com> Date: Thu, 30 Oct 2025 11:27:04 +0800 Subject: [PATCH] add real gate_correction_bias weight to mock un-balanced dispatch --- tests/layers/test_fusedmoe.py | 399 +++++++++++++++++++++++++++++++++- 1 file changed, 396 insertions(+), 3 deletions(-) diff --git a/tests/layers/test_fusedmoe.py b/tests/layers/test_fusedmoe.py index 8037d9c31a3..ae2e1a4b631 100644 --- a/tests/layers/test_fusedmoe.py +++ b/tests/layers/test_fusedmoe.py @@ -25,6 +25,396 @@ paddle.set_default_dtype("bfloat16") +gate_correction_bias_real_data = paddle.to_tensor( + [ + 32.8339, + 32.8231, + 32.8151, + 32.8131, + 32.8317, + 32.8343, + 32.8356, + 32.8270, + 32.8344, + 32.8342, + 32.8126, + 32.8299, + 32.8282, + 32.8254, + 32.8320, + 32.8280, + 32.8303, + 32.8351, + 32.8364, + 32.8347, + 32.8179, + 32.8349, + 32.8322, + 32.8323, + 32.8360, + 32.8351, + 32.8059, + 32.8352, + 32.8303, + 32.8334, + 32.8283, + 32.8265, + 32.8344, + 32.8307, + 32.8271, + 32.8343, + 32.8326, + 32.8327, + 32.8349, + 32.8356, + 32.8303, + 32.8327, + 32.8310, + 32.8363, + 32.8274, + 32.8335, + 32.8350, + 32.8255, + 32.8298, + 32.8141, + 32.8218, + 32.8362, + 32.8126, + 32.7902, + 32.8314, + 32.8356, + 32.8177, + 32.8333, + 32.8352, + 32.8354, + 32.8334, + 32.8325, + 32.7971, + 32.8319, + 32.8222, + 32.8284, + 32.8288, + 32.8355, + 32.8351, + 32.8356, + 32.8338, + 32.8346, + 32.7737, + 32.8317, + 32.8357, + 32.8345, + 32.8347, + 32.8360, + 32.8289, + 32.8268, + 32.8164, + 32.8324, + 32.8363, + 32.8308, + 32.8352, + 32.8302, + 32.8345, + 32.8298, + 32.8057, + 32.8229, + 32.8355, + 32.8325, + 32.8350, + 32.8357, + 32.8315, + 32.8327, + 32.8263, + 32.8342, + 32.8165, + 32.8349, + 32.8310, + 32.8101, + 32.8101, + 32.8081, + 32.8341, + 32.8313, + 32.8331, + 32.8299, + 32.8320, + 32.7941, + 32.8277, + 32.8287, + 32.8326, + 32.8331, + 32.8360, + 32.8295, + 32.8255, + 32.8330, + 32.8279, + 32.8210, + 32.7921, + 32.8348, + 32.8271, + 32.8297, + 32.8211, + 32.8353, + 32.8339, + 32.8335, + 32.8275, + 32.8245, + 32.8287, + 32.8352, + 32.8318, + 32.8354, + 32.8110, + 32.8347, + 32.8340, + 32.8322, + 32.8341, + 32.8316, + 32.8328, + 32.8341, + 32.8354, + 32.8264, + 32.8362, + 32.8352, + 32.8293, + 32.8292, + 32.8328, + 32.8316, + 32.8329, + 32.8308, + 32.8307, + 32.8170, + 32.8345, + 32.8356, + 32.8176, + 32.8326, + 32.8288, + 32.8355, + 32.8346, + 32.8337, + 32.8049, + 32.8315, + 32.8337, + 32.8352, + 32.7991, + 32.8304, + 32.8348, + 32.8316, + 32.8358, + 32.8279, + 32.8348, + 32.8326, + 32.8215, + 32.8281, + 32.8344, + 32.8309, + 32.8355, + 32.8337, + 32.8276, + 32.8250, + 32.8340, + 32.8322, + 32.8317, + 32.8274, + 32.8363, + 32.8277, + 32.8345, + 32.8342, + 32.8343, + 32.8355, + 32.8326, + 32.8299, + 32.8322, + 32.8351, + 32.8356, + 32.7925, + 32.8362, + 32.8170, + 32.8323, + 32.8335, + 32.8339, + 32.8193, + 32.8340, + 32.8362, + 32.8323, + 32.8328, + 32.8328, + 32.8296, + 32.8297, + 32.8344, + 32.8254, + 32.8341, + 32.8345, + 32.7967, + 32.8228, + 32.8363, + 32.8356, + 32.8317, + 32.8362, + 32.8302, + 32.8356, + 32.8239, + 32.8304, + 32.8323, + 32.8335, + 32.8196, + 32.8354, + 32.6991, + 32.8350, + 32.8337, + 32.8314, + 32.8274, + 32.8232, + 32.8305, + 32.8349, + 32.8246, + 32.8343, + 32.8339, + 32.7849, + 32.8359, + 32.8353, + 32.8352, + 32.8348, + 32.8095, + 32.8301, + 32.8350, + 32.8340, + 32.8353, + 32.8343, + 32.8344, + 32.8312, + 32.8350, + 32.8327, + 32.8231, + 32.8325, + 32.8352, + 32.8352, + 32.8293, + 32.8357, + 32.8337, + 32.8335, + 32.8348, + 32.8321, + 32.8153, + 32.8352, + 32.8265, + 32.8326, + 32.8361, + 32.8357, + 32.8312, + 32.8347, + 32.8152, + 32.8340, + 32.8272, + 32.8352, + 32.8331, + 32.8324, + 32.7952, + 32.8170, + 32.8356, + 32.8360, + 32.8298, + 32.8356, + 32.8331, + 32.8317, + 32.8349, + 32.8269, + 32.8323, + 32.8354, + 32.8350, + 32.8226, + 32.8002, + 32.8205, + 32.8329, + 32.8319, + 32.8297, + 32.8282, + 32.8356, + 32.8303, + 32.8349, + 32.8337, + 32.8247, + 32.8279, + 32.8309, + 32.8225, + 32.8337, + 32.8356, + 32.8105, + 32.8353, + 32.8361, + 32.8297, + 32.8313, + 32.8313, + 32.8363, + 32.8357, + 32.8357, + 32.8363, + 32.7806, + 32.8306, + 32.8347, + 32.8248, + 32.8334, + 32.8356, + 32.8324, + 32.8327, + 32.8284, + 32.8351, + 32.8349, + 32.8351, + 32.8171, + 32.8317, + 32.8363, + 32.8346, + 32.8335, + 32.8307, + 32.7907, + 32.8229, + 32.8346, + 32.8298, + 32.8336, + 32.8313, + 32.8349, + 32.8219, + 32.8354, + 32.8337, + 32.8294, + 32.8306, + 32.8322, + 32.8290, + 32.8333, + 32.8327, + 32.8279, + 32.8283, + 32.8338, + 32.8310, + 32.8351, + 32.8171, + 32.8310, + 32.8323, + 32.8324, + 32.8215, + 32.8314, + 32.8333, + 32.8353, + 32.8184, + 32.8344, + 32.8280, + 32.8352, + 32.8361, + 32.8308, + 32.8271, + 32.8335, + 32.8236, + 32.8350, + 32.8325, + 32.8330, + 32.8228, + 32.8352, + 32.8258, + 32.8343, + 32.8338, + 32.8292, + ], + dtype="float32", +) + class FuseMoEWrapper(paddle.nn.Layer): def __init__( @@ -90,6 +480,7 @@ def __init__( topk_group=4, n_group=8, gate_correction_bias=paddle.zeros([self.fd_config.model_config.moe_num_experts], paddle.float32), + # gate_correction_bias = gate_correction_bias_real_data ) moe_layer = self.fused_moe @@ -179,14 +570,16 @@ def test_fused_moe(self): nnodes = (ep_size + 7) // 8 - fused_moe = FuseMoEWrapper(self.model_config, tp_size, tp_rank, ep_size, ep_rank, nnodes=nnodes) - # 这行代码必须保留,否则影响均匀性! paddle.seed(ep_rank + 100) + fused_moe = FuseMoEWrapper(self.model_config, tp_size, tp_rank, ep_size, ep_rank, nnodes=nnodes) + moe_cuda_graphs = [None] * 100 cache_hidden_states = [None] * 100 - for idx, num_tokens in enumerate([10, 20, 40, 60, 80, 100, 128, 160, 192, 256]): + test_token_nums = [10, 20, 40, 60, 80, 100, 128, 160, 192, 256] + # test_token_nums = [1024 * i for i in [1,2,4,8,16,32]] + for idx, num_tokens in enumerate(test_token_nums): cache_hidden_states[idx] = paddle.rand((num_tokens, self.model_config.hidden_size), dtype=paddle.bfloat16)