From bdfc4a4a08046aa9b3cee52c5663ee56af75c0ac Mon Sep 17 00:00:00 2001
From: zhoutianzi666 <17801055074@163.com>
Date: Mon, 27 Oct 2025 14:24:43 +0800
Subject: [PATCH 1/2] remove dev sync in prefill

---
 fastdeploy/model_executor/layers/moe/ep.py                  | 1 +
 .../model_executor/layers/moe/fused_moe_deepgemm_backend.py | 3 +--
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/fastdeploy/model_executor/layers/moe/ep.py b/fastdeploy/model_executor/layers/moe/ep.py
index 910c5dd87f6..20ae045a8b5 100644
--- a/fastdeploy/model_executor/layers/moe/ep.py
+++ b/fastdeploy/model_executor/layers/moe/ep.py
@@ -461,6 +461,7 @@ def dispatch(
             "async_finish": self.ep_engine.async_finish,
             "topk_idx": topk_idx,
             "topk_weights": topk_weights,
+            "expert_alignment": 128,
         }
 
         return buffer.dispatch(**dispatch_args)
diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py
index 06cc3294915..75eca3396b7 100644
--- a/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py
+++ b/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py
@@ -345,7 +345,6 @@ def apply_ep_prefill(
             (recv_x, recv_x_scale) = recv_x
 
         token_nums_this_rank = count_tokens_per_expert_func(recv_topk_idx, layer.num_local_experts)
-        token_nums_this_rank_padded = sum(token_nums_this_rank[1].numpy().tolist())
 
         (
             permute_input,
@@ -365,7 +364,7 @@ def apply_ep_prefill(
             token_nums_this_rank[0],
             token_nums_this_rank[1],
             True,  # use_in_ep
-            token_nums_this_rank_padded,
+            token_all_num,
         )
 
         permute_scale = permute_scale.transpose([1, 0]).contiguous()

From 335c511cac6becf72b2dab009b75bfc6f55b1470 Mon Sep 17 00:00:00 2001
From: zhoutianzi666 <17801055074@163.com>
Date: Mon, 27 Oct 2025 14:48:39 +0800
Subject: [PATCH 2/2] remove dev sync in prefill

---
 fastdeploy/model_executor/layers/moe/ep.py                  | 3 ++-
 .../model_executor/layers/moe/fused_moe_deepgemm_backend.py | 4 +++-
 2 files changed, 5 insertions(+), 2 deletions(-)

diff --git a/fastdeploy/model_executor/layers/moe/ep.py b/fastdeploy/model_executor/layers/moe/ep.py
index 20ae045a8b5..eb5742d98f7 100644
--- a/fastdeploy/model_executor/layers/moe/ep.py
+++ b/fastdeploy/model_executor/layers/moe/ep.py
@@ -435,6 +435,7 @@ def dispatch(
         x: paddle.Tensor,
         topk_idx: paddle.Tensor,
         topk_weights: paddle.Tensor,
+        expert_alignment: int = 1,
         *args,
         **kwargs,
     ):
@@ -461,7 +462,7 @@ def dispatch(
             "async_finish": self.ep_engine.async_finish,
             "topk_idx": topk_idx,
             "topk_weights": topk_weights,
-            "expert_alignment": 128,
+            "expert_alignment": expert_alignment,
         }
 
         return buffer.dispatch(**dispatch_args)
diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py b/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py
index 75eca3396b7..cd20d7eaf07 100644
--- a/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py
+++ b/fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py
@@ -335,7 +335,9 @@ def apply_ep_prefill(
             recv_num_tokens_per_expert_list,
             handle,
             _,
-        ) = self.ep_prefill_runner.dispatch(x, topk_idx, topk_weights, x_scale_tensor=x_scale_tensor)
+        ) = self.ep_prefill_runner.dispatch(
+            x, topk_idx, topk_weights, x_scale_tensor=x_scale_tensor, expert_alignment=128
+        )
 
         token_all_num = sum(recv_num_tokens_per_expert_list)
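
Reviewer note (not part of the patch): the sketch below illustrates the device sync the patch removes. It assumes a CUDA build of Paddle and that recv_num_tokens_per_expert_list returned by dispatch is a host-side Python list of ints, as in DeepEP; expert_alignment=128 presumably makes those counts already padded, so summing them can stand in for the old GPU-derived padded total. The concrete tensor values are invented for demonstration.

# Illustrative sketch only -- not FastDeploy code.
import paddle

# Stand-in for token_nums_this_rank[1]: per-expert token counts living on the GPU.
token_nums_gpu = paddle.to_tensor([3, 5, 0, 7], dtype="int64")

# Before the patch: .numpy() copies device -> host and blocks the CPU until the
# queued GPU work finishes, stalling the prefill pipeline on every MoE layer.
token_nums_padded = sum(token_nums_gpu.numpy().tolist())

# After the patch: dispatch is called with expert_alignment=128, so the
# per-expert counts it returns (assumed here to be plain Python ints) are
# already aligned, and their sum -- the patch's token_all_num -- is computed
# on the host with no extra device synchronization.
recv_num_tokens_per_expert_list = [3, 5, 0, 7]  # host-side list from dispatch
token_all_num = sum(recv_num_tokens_per_expert_list)

Keeping expert_alignment's default at 1 in ep.py means only the prefill path, which explicitly passes 128, changes behavior; other callers of dispatch are untouched.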