From 4d5c847d2d58ffdaa3d55dfadcf9393d7f7064d1 Mon Sep 17 00:00:00 2001 From: zhupengyang Date: Tue, 14 Oct 2025 05:48:53 +0000 Subject: [PATCH] [XPU] fix ep --- fastdeploy/engine/expert_service.py | 4 +- .../layers/backends/xpu/moe/ep.py | 86 ++++++++++--------- .../layers/backends/xpu/moe/fused_moe.py | 2 +- .../layers/moe/fused_moe_backend_base.py | 19 ++-- 4 files changed, 64 insertions(+), 47 deletions(-) diff --git a/fastdeploy/engine/expert_service.py b/fastdeploy/engine/expert_service.py index 4a08fe07514..f3573d22371 100644 --- a/fastdeploy/engine/expert_service.py +++ b/fastdeploy/engine/expert_service.py @@ -103,7 +103,7 @@ def start( llm_logger.info(f"start expert service {local_data_parallel_id}") - if self.cfg.splitwise_role != "mixed" or self.cfg.cache_config.enable_prefix_caching: + if self.cfg.scheduler_config.splitwise_role != "mixed" or self.cfg.cache_config.enable_prefix_caching: if self.do_profile: get_profile_block_num = np.zeros([1], dtype=np.int32) while True: @@ -123,7 +123,7 @@ def start( self.cache_manager_processes = self.engine.start_cache_service( self.cfg.local_device_ids, ipc_signal_suffix_cache ) - if self.cfg.splitwise_role != "mixed": + if self.cfg.scheduler_config.splitwise_role != "mixed": self.engine.split_mode_get_tasks() if self.cfg.scheduler_config.name == "splitwise": diff --git a/fastdeploy/model_executor/layers/backends/xpu/moe/ep.py b/fastdeploy/model_executor/layers/backends/xpu/moe/ep.py index c2ec2b9ae98..71c2dd600ff 100644 --- a/fastdeploy/model_executor/layers/backends/xpu/moe/ep.py +++ b/fastdeploy/model_executor/layers/backends/xpu/moe/ep.py @@ -23,12 +23,11 @@ import fastdeploy from fastdeploy.config import MoEPhase -from fastdeploy.model_executor.layers.moe.ep import DeepEPEngineBase, EPRunner from fastdeploy.utils import singleton @singleton -class DeepEPEngine(DeepEPEngineBase): +class DeepEPEngine: """ A wrapper class for DeepEP engine. """ @@ -36,7 +35,7 @@ class DeepEPEngine(DeepEPEngineBase): def __init__( self, num_max_dispatch_tokens_per_rank: int, - hidden: int, + hidden_size: int, num_experts: int, ep_size: int, ep_rank: int, @@ -52,20 +51,24 @@ def __init__( ep_size: The number of ranks. rank_id: The rank id. num_max_dispatch_tokens_per_rank: The maximum number of tokens per rank to dispatch. - hidden: The hidden dimension of the model. + hidden_size: The hidden_size dimension of the model. num_experts: The number of experts. """ - super().__init__( - num_max_dispatch_tokens_per_rank, - hidden, - num_experts, - ep_size, - ep_rank, - splitwise_role, - moe_phase, - async_finish, - group, - ) + self.num_max_dispatch_tokens_per_rank = num_max_dispatch_tokens_per_rank + self.hidden_size = hidden_size + self.num_experts = num_experts + self.ep_size = ep_size + self.rank_id = ep_rank + self.splitwise_role = splitwise_role + self.moe_phase = moe_phase + self.async_finish = async_finish + # TODO(@wufeisheng): Support configurable EP size​ + if group is None: + group = paddle.distributed.new_group(range(ep_size)) + self.group = group + self.num_local_experts = num_experts // ep_size + self.deepep_engine = None + self.init_deepep_engine() def init_deepep_engine(self): if self.splitwise_role == "mixed" or self.moe_phase.phase == "prefill": @@ -89,14 +92,14 @@ def get_low_latency_buffer(self): Args: group: The MPI group object. num_max_dispatch_tokens_per_rank: The maximum number of tokens per rank to dispatch. - hidden: The hidden dimension of the model. + hidden_size: The hidden_size dimension of the model. """ # NOTES: the low-latency mode will consume much more space than the normal mode # So we recommend that `num_max_dispatch_tokens_per_rank` # (the actual batch size in the decoding engine) should be less than 256 num_rdma_bytes = deep_ep.Buffer.get_low_latency_rdma_size_hint( self.num_max_dispatch_tokens_per_rank, - self.hidden, + self.hidden_size, self.ep_size, self.num_experts, ) @@ -127,12 +130,12 @@ def low_latency_dispatch( ): """ Args: - hidden_states: [token_num, hidden] 'bfloat16/int8' + hidden_states: [token_num, hidden_size] 'bfloat16/int8' topk_idx: [token_num, num_topk] 'int64' Returns: recv_hidden_states: [num_local_experts, - num_max_dispatch_tokens_per_rank * ep_size, hidden] + num_max_dispatch_tokens_per_rank * ep_size, hidden_size] ep_size * num_local_experts = num_experts recv_count: [num_local_experts] recv_count: a tensor shaped `[num_local_experts]` with type `torch.int`, indicating how many tokens each @@ -171,7 +174,7 @@ def low_latency_combine( """ Return: - combined_hidden_states: [num_tokens, hidden] + combined_hidden_states: [num_tokens, hidden_size] """ combined_hidden_states, combine_hook = self.deepep_engine.low_latency_combine( hidden_states, @@ -196,7 +199,7 @@ def barrier_all(self): self.deepep_engine.barrier_all() -class XPUEPRunner(EPRunner): +class XPUEPRunner: """ EPRunnerBase """ @@ -204,7 +207,7 @@ class XPUEPRunner(EPRunner): def __init__( self, top_k: int, - hidden: int, + hidden_size: int, num_experts: int, splitwise_role: str, moe_phase: MoEPhase, @@ -214,23 +217,22 @@ def __init__( redundant_experts_num: int = 0, ep_group=None, ): - super().__init__( - top_k, - hidden, - num_experts, - splitwise_role, - moe_phase, - num_max_dispatch_tokens_per_rank, - ep_size, - ep_rank, - redundant_experts_num, - ep_group, - ) + self.top_k = top_k + self.hidden_size = hidden_size + self.num_experts = num_experts + self.splitwise_role = splitwise_role + self.moe_phase = moe_phase + self.num_max_dispatch_tokens_per_rank = num_max_dispatch_tokens_per_rank + self.ep_size = ep_size + self.ep_rank = ep_rank + self.redundant_experts_num = redundant_experts_num + self.ep_group = ep_group + self.init_ep_engine() def init_ep_engine(self): self.ep_engine = DeepEPEngine( num_max_dispatch_tokens_per_rank=self.num_max_dispatch_tokens_per_rank, - hidden=self.hidden, + hidden_size=self.hidden_size, num_experts=self.num_experts + self.redundant_experts_num, ep_size=self.ep_size, ep_rank=self.ep_rank, @@ -285,6 +287,12 @@ def combine(self, *args, **kwargs): """ raise NotImplementedError + def clean_low_latency_buffer(self): + self.ep_engine.clean_low_latency_buffer() + + def barrier_all(self): + self.ep_engine.barrier_all() + class XPUEPPrefillRunner(XPUEPRunner): """ @@ -294,7 +302,7 @@ class XPUEPPrefillRunner(XPUEPRunner): def __init__( self, top_k: int, - hidden: int, + hidden_size: int, num_experts: int, splitwise_role: str, num_max_dispatch_tokens_per_rank: int, @@ -306,7 +314,7 @@ def __init__( ): super().__init__( top_k, - hidden, + hidden_size, num_experts, splitwise_role, moe_phase, @@ -358,7 +366,7 @@ class XPUEPDecoderRunner(XPUEPRunner): def __init__( self, top_k: int, - hidden: int, + hidden_size: int, num_experts: int, splitwise_role: str, num_max_dispatch_tokens_per_rank: int, @@ -370,7 +378,7 @@ def __init__( ): super().__init__( top_k, - hidden, + hidden_size, num_experts, splitwise_role, moe_phase, diff --git a/fastdeploy/model_executor/layers/backends/xpu/moe/fused_moe.py b/fastdeploy/model_executor/layers/backends/xpu/moe/fused_moe.py index 4276b89a1a5..a7ac1845a21 100644 --- a/fastdeploy/model_executor/layers/backends/xpu/moe/fused_moe.py +++ b/fastdeploy/model_executor/layers/backends/xpu/moe/fused_moe.py @@ -366,7 +366,7 @@ def create_weights(self, layer: nn.Layer, **extra_weight_attrs): self.down_proj_weight_shape = [ layer.num_local_experts, layer.hidden_size, - layer.moe_intermediate // 2, + layer.moe_intermediate_size // 2, ] else: raise ValueError(f"Unsupported moe quant type: {self.moe_quant_type}") diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py b/fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py index d1be7af8036..ef8d2d836c9 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py @@ -43,6 +43,15 @@ def __init__(self, quant_config): self.ep_prefill_runner = None self.ep_decoder_runner = None + def import_backend_ep_runner(self) -> None: + """ + Different platform has different ep runner. Override this method to import the corresponding EP runner. + """ + from .ep import EPDecoderRunner, EPPrefillRunner + + self.EPPrefillRunner = EPPrefillRunner + self.EPDecoderRunner = EPDecoderRunner + def init_ep(self, layer: nn.Layer) -> None: """ Initialize EP (Expert Parallel) related modules. @@ -51,7 +60,7 @@ def init_ep(self, layer: nn.Layer) -> None: return # Lazy import to avoid circular dependency or unnecessary loading - from .ep import EPDecoderRunner, EPPrefillRunner + self.import_backend_ep_runner() # Common arguments for both runners common_args = { @@ -76,16 +85,16 @@ def init_ep(self, layer: nn.Layer) -> None: # for RL init model without deepep buff return else: - self.ep_prefill_runner = EPPrefillRunner(**common_args) - self.ep_decoder_runner = EPDecoderRunner(**common_args) + self.ep_prefill_runner = self.EPPrefillRunner(**common_args) + self.ep_decoder_runner = self.EPDecoderRunner(**common_args) return # For non-mixed ep phase = config.model_config.moe_phase.phase if phase == "prefill": - self.ep_prefill_runner = EPPrefillRunner(**common_args) + self.ep_prefill_runner = self.EPPrefillRunner(**common_args) else: - self.ep_decoder_runner = EPDecoderRunner(**common_args) + self.ep_decoder_runner = self.EPDecoderRunner(**common_args) def process_loaded_weights(self, layer, weights) -> None: """