From 569972e40832b69605f3b45a6fdbdee9b5754fe4 Mon Sep 17 00:00:00 2001 From: rainyfly <1435317881@qq.com> Date: Wed, 5 Nov 2025 16:12:53 +0800 Subject: [PATCH 1/7] fix bug for PD EP --- fastdeploy/engine/engine.py | 4 +- fastdeploy/engine/request.py | 1 + fastdeploy/envs.py | 1 + .../inter_communicator/engine_worker_queue.py | 38 +++++++++++++------ fastdeploy/worker/gpu_model_runner.py | 4 ++ 5 files changed, 35 insertions(+), 13 deletions(-) diff --git a/fastdeploy/engine/engine.py b/fastdeploy/engine/engine.py index fb57f2abe9b..dd14ecdafdd 100644 --- a/fastdeploy/engine/engine.py +++ b/fastdeploy/engine/engine.py @@ -709,7 +709,9 @@ def launch_components(self): for i in range(self.cfg.parallel_config.data_parallel_size): request_queues_for_dp_ipc.append(multiprocessing.Queue()) self.engine.scheduler.start( - self.cfg.node_rank * self.cfg.worker_num_per_node, request_queues_for_dp_ipc, result_queue_for_dp_ipc + self.cfg.node_rank * self.cfg.worker_num_per_node % self.cfg.worker_num_per_node, + request_queues_for_dp_ipc, + result_queue_for_dp_ipc, ) if not envs.FD_ENABLE_MULTI_API_SERVER: diff --git a/fastdeploy/engine/request.py b/fastdeploy/engine/request.py index d0906685502..8f5a82ff53b 100644 --- a/fastdeploy/engine/request.py +++ b/fastdeploy/engine/request.py @@ -161,6 +161,7 @@ def __init__( self.extend_block_tables = [] # dp self.dp_rank = dp_rank + self.llm_engine_recv_req_timestamp = time.time() @classmethod def from_dict(cls, d: dict): diff --git a/fastdeploy/envs.py b/fastdeploy/envs.py index 1eb7af39490..754daf37dd7 100644 --- a/fastdeploy/envs.py +++ b/fastdeploy/envs.py @@ -153,6 +153,7 @@ "ENCODE_FEATURE_BOS_SK": lambda: os.getenv("ENCODE_FEATURE_BOS_SK"), # Enable offline perf test mode for PD disaggregation "FD_OFFLINE_PERF_TEST_FOR_PD": lambda: int(os.getenv("FD_OFFLINE_PERF_TEST_FOR_PD", "0")), + "FD_ENABLE_E2W_TENSOR_CONVERT": lambda: int(os.getenv("FD_ENABLE_E2W_TENSOR_CONVERT", "1")), } diff --git a/fastdeploy/inter_communicator/engine_worker_queue.py b/fastdeploy/inter_communicator/engine_worker_queue.py index be4880e17a5..2fd7b7162e3 100644 --- a/fastdeploy/inter_communicator/engine_worker_queue.py +++ b/fastdeploy/inter_communicator/engine_worker_queue.py @@ -485,24 +485,38 @@ def _connect_with_retry(self, max_retries: int = 5, interval: int = 3) -> None: @staticmethod def to_tensor(tasks): """ - Convert NumPy arrays in multimodal inputs to PaddlePaddle tensors. + Convert NumPy arrays in multimodal inputs to Paddle tensors. Args: - tasks: List of tasks containing multimodal inputs. + tasks (tuple): ([request], bsz) """ + if (not envs.FD_ENABLE_MAX_PREFILL) or (not envs.FD_ENABLE_E2W_TENSOR_CONVERT): + return try: - if envs.FD_ENABLE_MAX_PREFILL: - llm_logger.debug(f"Convert image to tensor, type: {type(tasks)}") - batch_tasks, _ = tasks - for task in batch_tasks: - if not hasattr(task, "multimodal_inputs"): + batch_tasks, _ = tasks + for task in batch_tasks: + multimodal_inputs = getattr(task, "multimodal_inputs", None) + if not multimodal_inputs: + continue + # tensor keys + tensor_keys = [ + "images", + "patch_idx", + "token_type_ids", + "position_ids", + "attention_mask_offset", + ] + + llm_logger.debug(f"Converting multimodal inputs to tensor...{tensor_keys}") + + for key in tensor_keys: + value = multimodal_inputs.get(key) + if value is None: continue - images = task.multimodal_inputs["images"] - if isinstance(images, np.ndarray): - llm_logger.debug(f"Convert image to tensor, shape: {images.shape}") - task.multimodal_inputs["images"] = paddle.to_tensor(images) + if not isinstance(value, paddle.Tensor): + multimodal_inputs[key] = paddle.to_tensor(value) except Exception as e: - llm_logger.warning(f"Failed to convert to tensor: {e}") + llm_logger.warning(f"Tensor conversion failed: {type(e).__name__}: {e}") @staticmethod def to_numpy(tasks): diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index 67a09f5c5d4..e4dee856677 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -569,6 +569,10 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int = if request.sampling_params is not None and request.sampling_params.prompt_logprobs is not None: self.prompt_logprobs_reqs[request.request_id] = request has_prefill_task = True + if ( + self.fd_config.scheduler_config.splitwise_role == "decode" + ): # In PD, we continue to decode after P generate first token + self.share_inputs["seq_lens_encoder"][idx : idx + 1] = 0 elif request.task_type.value == RequestType.DECODE.value: # decode task logger.debug(f"Handle decode request {request} at idx {idx}") encoder_block_num = len(request.block_tables) From 66016ea66f372f2ac2d1781fefebd08fdfc2e8ae Mon Sep 17 00:00:00 2001 From: rainyfly <1435317881@qq.com> Date: Wed, 5 Nov 2025 16:18:07 +0800 Subject: [PATCH 2/7] fix --- fastdeploy/inter_communicator/engine_worker_queue.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fastdeploy/inter_communicator/engine_worker_queue.py b/fastdeploy/inter_communicator/engine_worker_queue.py index 2fd7b7162e3..42423b6c5f3 100644 --- a/fastdeploy/inter_communicator/engine_worker_queue.py +++ b/fastdeploy/inter_communicator/engine_worker_queue.py @@ -490,7 +490,7 @@ def to_tensor(tasks): Args: tasks (tuple): ([request], bsz) """ - if (not envs.FD_ENABLE_MAX_PREFILL) or (not envs.FD_ENABLE_E2W_TENSOR_CONVERT): + if (not envs.FD_ENABLE_MAX_PREFILL) and (not envs.FD_ENABLE_E2W_TENSOR_CONVERT): return try: batch_tasks, _ = tasks From 9f51845aa4aac763a8ba7bea2d064d091d152964 Mon Sep 17 00:00:00 2001 From: rainyfly <1435317881@qq.com> Date: Wed, 5 Nov 2025 17:20:15 +0800 Subject: [PATCH 3/7] optimize perf for engine worker queue --- fastdeploy/cache_manager/cache_messager.py | 10 +++++-- fastdeploy/engine/async_llm.py | 11 +++++--- fastdeploy/engine/common_engine.py | 31 ++++++++++++++------- fastdeploy/engine/engine.py | 12 +++++--- fastdeploy/envs.py | 1 + fastdeploy/splitwise/splitwise_connector.py | 7 ++++- fastdeploy/worker/worker_process.py | 11 +++++--- 7 files changed, 58 insertions(+), 25 deletions(-) diff --git a/fastdeploy/cache_manager/cache_messager.py b/fastdeploy/cache_manager/cache_messager.py index e6e6aa15218..78bed925c69 100644 --- a/fastdeploy/cache_manager/cache_messager.py +++ b/fastdeploy/cache_manager/cache_messager.py @@ -132,7 +132,10 @@ def __init__( self.gpu_cache_kvs = gpu_cache_kvs self.rank = rank self.nranks = nranks - address = (pod_ip, engine_worker_queue_port) + if not envs.FD_ENGINE_TASK_QUEUE_WITH_SHM: + address = (pod_ip, engine_worker_queue_port) + else: + address = f"/dev/shm/fd_task_queue_{engine_worker_queue_port}.sock" self.engine_worker_queue = EngineWorkerQueue( address=address, is_server=False, @@ -420,7 +423,10 @@ def __init__( self.gpu_cache_kvs = gpu_cache_kvs self.rank = rank self.nranks = nranks - address = (pod_ip, engine_worker_queue_port) + if not envs.FD_ENGINE_TASK_QUEUE_WITH_SHM: + address = (pod_ip, engine_worker_queue_port) + else: + address = f"/dev/shm/fd_task_queue_{engine_worker_queue_port}.sock" self.engine_worker_queue = EngineWorkerQueue( address=address, is_server=False, diff --git a/fastdeploy/engine/async_llm.py b/fastdeploy/engine/async_llm.py index 11f809b995d..936f834777f 100644 --- a/fastdeploy/engine/async_llm.py +++ b/fastdeploy/engine/async_llm.py @@ -923,10 +923,13 @@ def launch_components(self): 1, self.cfg.parallel_config.data_parallel_size // self.cfg.nnode, ): - address = ( - self.cfg.master_ip, - int(self.cfg.parallel_config.engine_worker_queue_port[i]), - ) + if not envs.FD_ENGINE_TASK_QUEUE_WITH_SHM: + address = ( + self.cfg.master_ip, + int(self.cfg.parallel_config.engine_worker_queue_port[i]), + ) + else: + address = f"/dev/shm/fd_task_queue_{self.cfg.parallel_config.engine_worker_queue_port[i]}.sock" llm_logger.info(f"dp start queue service {address}") self.dp_engine_worker_queue_server.append( EngineWorkerQueue( diff --git a/fastdeploy/engine/common_engine.py b/fastdeploy/engine/common_engine.py index bc5f86853d6..b872455f069 100644 --- a/fastdeploy/engine/common_engine.py +++ b/fastdeploy/engine/common_engine.py @@ -263,10 +263,16 @@ def start_worker_queue_service(self, start_queue): """ start queue service for engine worker communication """ - address = ( - self.cfg.master_ip, - int(self.cfg.parallel_config.engine_worker_queue_port[self.cfg.parallel_config.local_data_parallel_id]), - ) + + if not envs.FD_ENGINE_TASK_QUEUE_WITH_SHM: + address = ( + self.cfg.master_ip, + int( + self.cfg.parallel_config.engine_worker_queue_port[self.cfg.parallel_config.local_data_parallel_id] + ), + ) + else: + address = f"/dev/shm/fd_task_queue_{self.cfg.parallel_config.engine_worker_queue_port[self.cfg.parallel_config.local_data_parallel_id]}.sock" if start_queue and (self.cfg.host_ip == self.cfg.master_ip or self.cfg.master_ip == "0.0.0.0"): self.llm_logger.info(f"Starting engine worker queue server service at {address}") @@ -280,12 +286,17 @@ def start_worker_queue_service(self, start_queue): self.cfg.parallel_config.engine_worker_queue_port[self.cfg.parallel_config.local_data_parallel_id] = str( self.engine_worker_queue_server.get_server_port() ) - address = ( - self.cfg.master_ip, - int( - self.cfg.parallel_config.engine_worker_queue_port[self.cfg.parallel_config.local_data_parallel_id] - ), - ) + if not envs.FD_ENGINE_TASK_QUEUE_WITH_SHM: + address = ( + self.cfg.master_ip, + int( + self.cfg.parallel_config.engine_worker_queue_port[ + self.cfg.parallel_config.local_data_parallel_id + ] + ), + ) + else: + address = f"/dev/shm/fd_task_queue_{self.cfg.parallel_config.engine_worker_queue_port[self.cfg.parallel_config.local_data_parallel_id]}.sock" if self.cfg.cache_config.enable_prefix_caching or self.cfg.scheduler_config.splitwise_role != "mixed": self.cache_task_queue = EngineCacheQueue( diff --git a/fastdeploy/engine/engine.py b/fastdeploy/engine/engine.py index dd14ecdafdd..5aa1041a41a 100644 --- a/fastdeploy/engine/engine.py +++ b/fastdeploy/engine/engine.py @@ -723,10 +723,14 @@ def launch_components(self): 1, self.cfg.parallel_config.data_parallel_size // self.cfg.nnode, ): - address = ( - self.cfg.master_ip, - int(self.cfg.parallel_config.engine_worker_queue_port[i]), - ) + if not envs.FD_ENGINE_TASK_QUEUE_WITH_SHM: + address = ( + self.cfg.master_ip, + int(self.cfg.parallel_config.engine_worker_queue_port[i]), + ) + else: + address = f"/dev/shm/fd_task_queue_{self.cfg.parallel_config.engine_worker_queue_port[i]}.sock" + llm_logger.info(f"dp start queue service {address}") self.dp_engine_worker_queue_server.append( EngineWorkerQueue( diff --git a/fastdeploy/envs.py b/fastdeploy/envs.py index 754daf37dd7..26942afec72 100644 --- a/fastdeploy/envs.py +++ b/fastdeploy/envs.py @@ -154,6 +154,7 @@ # Enable offline perf test mode for PD disaggregation "FD_OFFLINE_PERF_TEST_FOR_PD": lambda: int(os.getenv("FD_OFFLINE_PERF_TEST_FOR_PD", "0")), "FD_ENABLE_E2W_TENSOR_CONVERT": lambda: int(os.getenv("FD_ENABLE_E2W_TENSOR_CONVERT", "1")), + "FD_ENGINE_TASK_QUEUE_WITH_SHM": lambda: int(os.getenv("FD_ENGINE_TASK_QUEUE_WITH_SHM", "0")), } diff --git a/fastdeploy/splitwise/splitwise_connector.py b/fastdeploy/splitwise/splitwise_connector.py index ea402239072..a5bbdb6be7e 100644 --- a/fastdeploy/splitwise/splitwise_connector.py +++ b/fastdeploy/splitwise/splitwise_connector.py @@ -329,8 +329,13 @@ def create_connection(self, port): Parameters: port (int): Port number. """ + if not envs.FD_ENGINE_TASK_QUEUE_WITH_SHM: + address = ("0.0.0.0", int(port)) + else: + address = f"/dev/shm/fd_task_queue_{port}.sock" + self.connect_innode_instances[port] = EngineWorkerQueue( - address=("0.0.0.0", int(port)), + address=address, num_client=self.cfg.parallel_config.tensor_parallel_size, client_id=0, ) diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py index f84ad66d239..160972caec0 100644 --- a/fastdeploy/worker/worker_process.py +++ b/fastdeploy/worker/worker_process.py @@ -554,10 +554,13 @@ def init_device(self) -> None: def start_task_queue_service(self): # Initialize task queue - task_address = ( - self.parallel_config.pod_ip, - self.parallel_config.engine_worker_queue_port, - ) + if not envs.FD_ENGINE_TASK_QUEUE_WITH_SHM: + task_address = ( + self.parallel_config.pod_ip, + self.parallel_config.engine_worker_queue_port, + ) + else: + task_address = f"/dev/shm/fd_task_queue_{self.parallel_config.engine_worker_queue_port}.sock" logger.info(f"connect task queue address {task_address}") self.task_queue = TaskQueue( address=task_address, From 46a0b98a743b0812e585d9dc5a3a89f88520aad5 Mon Sep 17 00:00:00 2001 From: rainyfly <1435317881@qq.com> Date: Wed, 5 Nov 2025 19:16:27 +0800 Subject: [PATCH 4/7] fix bug --- fastdeploy/engine/common_engine.py | 15 ++------------- 1 file changed, 2 insertions(+), 13 deletions(-) diff --git a/fastdeploy/engine/common_engine.py b/fastdeploy/engine/common_engine.py index b872455f069..72efede205c 100644 --- a/fastdeploy/engine/common_engine.py +++ b/fastdeploy/engine/common_engine.py @@ -283,21 +283,10 @@ def start_worker_queue_service(self, start_queue): local_data_parallel_size=self.cfg.parallel_config.data_parallel_size, ) # Dynamically updates the port value if an anonymous port is used - self.cfg.parallel_config.engine_worker_queue_port[self.cfg.parallel_config.local_data_parallel_id] = str( - self.engine_worker_queue_server.get_server_port() - ) if not envs.FD_ENGINE_TASK_QUEUE_WITH_SHM: - address = ( - self.cfg.master_ip, - int( - self.cfg.parallel_config.engine_worker_queue_port[ - self.cfg.parallel_config.local_data_parallel_id - ] - ), + self.cfg.parallel_config.engine_worker_queue_port[self.cfg.parallel_config.local_data_parallel_id] = ( + str(self.engine_worker_queue_server.get_server_port()) ) - else: - address = f"/dev/shm/fd_task_queue_{self.cfg.parallel_config.engine_worker_queue_port[self.cfg.parallel_config.local_data_parallel_id]}.sock" - if self.cfg.cache_config.enable_prefix_caching or self.cfg.scheduler_config.splitwise_role != "mixed": self.cache_task_queue = EngineCacheQueue( address=( From ad096dfdb57aa5d0a174f82d076dc8f001dddc93 Mon Sep 17 00:00:00 2001 From: rainyfly <1435317881@qq.com> Date: Wed, 5 Nov 2025 19:48:13 +0800 Subject: [PATCH 5/7] fix internode ll two stage --- .../layers/moe/fused_moe_backend_base.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py b/fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py index 41b06962da0..dd9dc65f0e2 100644 --- a/fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py +++ b/fastdeploy/model_executor/layers/moe/fused_moe_backend_base.py @@ -108,10 +108,22 @@ def init_ep(self, layer: nn.Layer) -> None: # For non-mixed ep phase = config.model_config.moe_phase.phase - if phase == "prefill": - self.ep_prefill_runner = self.EPPrefillRunner(**common_args) + if current_platform.is_cuda(): + if phase == "prefill": + self.ep_prefill_runner = self.EPPrefillRunner( + **common_args, + use_internode_ll_two_stage=layer.fd_config.parallel_config.use_internode_ll_two_stage, + ) + else: + self.ep_decoder_runner = self.EPDecoderRunner( + **common_args, + use_internode_ll_two_stage=layer.fd_config.parallel_config.use_internode_ll_two_stage, + ) else: - self.ep_decoder_runner = self.EPDecoderRunner(**common_args) + if phase == "prefill": + self.ep_prefill_runner = self.EPPrefillRunner(**common_args) + else: + self.ep_decoder_runner = self.EPDecoderRunner(**common_args) def process_loaded_weights(self, layer, weights) -> None: """ From d5b806f21d7163aa45fd4721dae1299a9ec18dd6 Mon Sep 17 00:00:00 2001 From: rainyfly <1435317881@qq.com> Date: Wed, 5 Nov 2025 21:38:17 +0800 Subject: [PATCH 6/7] fix for ci --- fastdeploy/envs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fastdeploy/envs.py b/fastdeploy/envs.py index 26942afec72..b0d3ee78676 100644 --- a/fastdeploy/envs.py +++ b/fastdeploy/envs.py @@ -153,7 +153,7 @@ "ENCODE_FEATURE_BOS_SK": lambda: os.getenv("ENCODE_FEATURE_BOS_SK"), # Enable offline perf test mode for PD disaggregation "FD_OFFLINE_PERF_TEST_FOR_PD": lambda: int(os.getenv("FD_OFFLINE_PERF_TEST_FOR_PD", "0")), - "FD_ENABLE_E2W_TENSOR_CONVERT": lambda: int(os.getenv("FD_ENABLE_E2W_TENSOR_CONVERT", "1")), + "FD_ENABLE_E2W_TENSOR_CONVERT": lambda: int(os.getenv("FD_ENABLE_E2W_TENSOR_CONVERT", "0")), "FD_ENGINE_TASK_QUEUE_WITH_SHM": lambda: int(os.getenv("FD_ENGINE_TASK_QUEUE_WITH_SHM", "0")), } From bc5a50f3bd22011c7ab5cafe95405796da30a6e1 Mon Sep 17 00:00:00 2001 From: rainyfly <1435317881@qq.com> Date: Thu, 6 Nov 2025 20:38:26 +0800 Subject: [PATCH 7/7] fix bug --- fastdeploy/engine/common_engine.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/fastdeploy/engine/common_engine.py b/fastdeploy/engine/common_engine.py index 72efede205c..9caf8074b57 100644 --- a/fastdeploy/engine/common_engine.py +++ b/fastdeploy/engine/common_engine.py @@ -287,6 +287,15 @@ def start_worker_queue_service(self, start_queue): self.cfg.parallel_config.engine_worker_queue_port[self.cfg.parallel_config.local_data_parallel_id] = ( str(self.engine_worker_queue_server.get_server_port()) ) + address = ( + self.cfg.master_ip, + int( + self.cfg.parallel_config.engine_worker_queue_port[ + self.cfg.parallel_config.local_data_parallel_id + ] + ), + ) + if self.cfg.cache_config.enable_prefix_caching or self.cfg.scheduler_config.splitwise_role != "mixed": self.cache_task_queue = EngineCacheQueue( address=(