From 1bf5789192d281756a6a5332d2dd134dbdee1e98 Mon Sep 17 00:00:00 2001
From: kevin
Date: Mon, 27 Oct 2025 21:36:48 +0800
Subject: [PATCH 1/3] update v1 prefill batch

---
 .../engine/sched/resource_manager_v1.py | 36 ++++++++++++++++---
 fastdeploy/worker/worker_process.py     |  5 +--
 2 files changed, 32 insertions(+), 9 deletions(-)

diff --git a/fastdeploy/engine/sched/resource_manager_v1.py b/fastdeploy/engine/sched/resource_manager_v1.py
index 98943436763..e8a28816563 100644
--- a/fastdeploy/engine/sched/resource_manager_v1.py
+++ b/fastdeploy/engine/sched/resource_manager_v1.py
@@ -331,13 +331,33 @@ def _update_mm_hashes(self, request):
             inputs["mm_positions"] = []
             inputs["mm_hashes"] = []
 
+    def _is_mm_request(self, request):
+        inputs = request.multimodal_inputs
+        if inputs is None or len(inputs) == 0:
+            return False
+
+        if (
+            (inputs.get("video_feature_urls") is not None and len(inputs["video_feature_urls"]) > 0)
+            or (inputs.get("image_feature_urls") is not None and len(inputs["image_feature_urls"]) > 0)
+            or (inputs.get("audio_feature_urls") is not None and len(inputs["audio_feature_urls"]) > 0)
+        ):
+            return True
+        elif (
+            inputs.get("images", None) is not None
+            and inputs.get("image_patch_id", None) is not None
+            and inputs.get("grid_thw", None) is not None
+        ):
+            return True
+
+        return False
+
     def _get_num_new_tokens(self, request, token_budget):
         # TODO: set condition to new _get_num_new_tokens
         num_new_tokens = request.need_prefill_tokens - request.num_computed_tokens
         num_new_tokens = min(num_new_tokens, token_budget)
         request.with_image = False
 
-        if not self.config.model_config.enable_mm:
+        if not self._is_mm_request(request):
             return num_new_tokens
 
         request.with_image = False
@@ -465,6 +485,12 @@ def _get_num_new_tokens(self, request, token_budget):
         # Compatible with scenarios without images and videos.
         return num_new_tokens
 
+    def exist_mm_prefill(self, scheduled_reqs):
+        for request in scheduled_reqs:
+            if request.task_type == RequestType.PREFILL and self._is_mm_request(request):
+                return True
+        return False
+
     def exist_prefill(self, scheduled_reqs):
         for request in scheduled_reqs:
             if request.task_type == RequestType.PREFILL:
@@ -610,12 +636,12 @@ def _allocate_decode_and_extend():
         while self.waiting and token_budget > 0:
             if len(self.running) == self.max_num_seqs:
                 break
-            if not self.enable_max_prefill and (
-                (self.config.model_config.enable_mm or paddle.is_compiled_with_xpu())
-                and self.exist_prefill(scheduled_reqs)
+
+            request = self.waiting[0]
+            if (self._is_mm_request(request) and self.exist_mm_prefill(scheduled_reqs)) or (
+                paddle.is_compiled_with_xpu() and self.exist_prefill(scheduled_reqs)
             ):
                 break
-            request = self.waiting[0]
             if request.status == RequestStatus.WAITING:
                 self._update_mm_hashes(request)
                 # Enable prefix caching
diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py
index 00457966bed..a51ba729040 100644
--- a/fastdeploy/worker/worker_process.py
+++ b/fastdeploy/worker/worker_process.py
@@ -283,10 +283,7 @@ def event_loop_normal(self) -> None:
             # The first worker detects whether there are tasks in the task queue
             if local_rank == 0:
                 if self.task_queue.num_tasks() > 0:
-                    # VL only support 1 batch to prefill
-                    if envs.ENABLE_V1_KVCACHE_SCHEDULER or not (
-                        self.fd_config.model_config.enable_mm and self.worker.exist_prefill()
-                    ):
+                    if envs.ENABLE_V1_KVCACHE_SCHEDULER:
                         if self.nnode > 1 and self.parallel_config.tensor_parallel_size > self.max_chips_per_node:
                             self.task_queue.read_finish_flag.set(1)
                         else:

From 12575c6a94b263ed253f15964fa0033b14b53ded Mon Sep 17 00:00:00 2001
From: kevin
Date: Wed, 29 Oct 2025 22:16:20 +0800
Subject: [PATCH 2/3] update code

---
 fastdeploy/engine/sched/resource_manager_v1.py | 2 +-
 fastdeploy/worker/gpu_model_runner.py          | 3 +--
 2 files changed, 2 insertions(+), 3 deletions(-)

diff --git a/fastdeploy/engine/sched/resource_manager_v1.py b/fastdeploy/engine/sched/resource_manager_v1.py
index e8a28816563..64a8068e570 100644
--- a/fastdeploy/engine/sched/resource_manager_v1.py
+++ b/fastdeploy/engine/sched/resource_manager_v1.py
@@ -357,7 +357,7 @@ def _get_num_new_tokens(self, request, token_budget):
         num_new_tokens = min(num_new_tokens, token_budget)
         request.with_image = False
 
-        if not self._is_mm_request(request):
+        if not self.config.model_config.enable_mm:
             return num_new_tokens
 
         request.with_image = False
diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py
index db56abf57ea..46c836342fc 100644
--- a/fastdeploy/worker/gpu_model_runner.py
+++ b/fastdeploy/worker/gpu_model_runner.py
@@ -470,8 +470,6 @@ def _apply_mm_inputs(self, request: Request, multi_vision_inputs: dict, rope_3d_
             else:
                 raise ValueError(f"multiple modalities model {self.model_config.model_type} is not supported")
             self.share_inputs["image_features"] = image_features[-actual_image_token_num:]
-        else:
-            self.share_inputs["image_features"] = None
 
         position_ids = request.multimodal_inputs["position_ids"]
         rope_3d_position_ids["position_ids_idx"].append(request.idx)
@@ -494,6 +492,7 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int =
         req_len = len(req_dicts)
         has_prefill_task = False
         has_decode_task = False
+        self.share_inputs["image_features"] = None
         multi_vision_inputs = {"images_lst": [], "grid_thw_lst": [], "vit_position_ids_lst": [], "cu_seqlens": [0]}
         rope_3d_position_ids = {
             "position_ids_idx": [],

From 19e2cc7812088e4d7bae3757c103dc1cb5ae03f2 Mon Sep 17 00:00:00 2001
From: kevin
Date: Thu, 30 Oct 2025 19:58:29 +0800
Subject: [PATCH 3/3] update code

---
 fastdeploy/worker/worker_process.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py
index a51ba729040..3e9152304f0 100644
--- a/fastdeploy/worker/worker_process.py
+++ b/fastdeploy/worker/worker_process.py
@@ -283,7 +283,9 @@ def event_loop_normal(self) -> None:
             # The first worker detects whether there are tasks in the task queue
             if local_rank == 0:
                 if self.task_queue.num_tasks() > 0:
-                    if envs.ENABLE_V1_KVCACHE_SCHEDULER:
+                    if envs.ENABLE_V1_KVCACHE_SCHEDULER or not (
+                        self.fd_config.model_config.enable_mm and self.worker.exist_prefill()
+                    ):
                         if self.nnode > 1 and self.parallel_config.tensor_parallel_size > self.max_chips_per_node:
                             self.task_queue.read_finish_flag.set(1)
                         else: