From 96410691e6bd98002564a8f570efb1fe3ac47f46 Mon Sep 17 00:00:00 2001 From: ST-XX <15625257+ST-XX@users.noreply.github.com> Date: Thu, 20 Nov 2025 14:26:07 +0800 Subject: [PATCH 1/2] enable guided decoding ENABLE_V1_KVCACHE_SCHEDULER = 1 --- fastdeploy/engine/args_utils.py | 2 -- fastdeploy/worker/gpu_model_runner.py | 38 +++++++++++++++++++++++++++ fastdeploy/worker/worker_process.py | 3 --- 3 files changed, 38 insertions(+), 5 deletions(-) diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py index 5af1ddfd224..9c1d8bba7b1 100644 --- a/fastdeploy/engine/args_utils.py +++ b/fastdeploy/engine/args_utils.py @@ -524,8 +524,6 @@ def __post_init__(self): if not current_platform.is_cuda() and not current_platform.is_xpu(): envs.ENABLE_V1_KVCACHE_SCHEDULER = 0 - if self.guided_decoding_backend != "off": - envs.ENABLE_V1_KVCACHE_SCHEDULER = 0 if "PaddleOCR" in get_model_architecture(self.model, self.model_config_name): envs.FD_ENABLE_MAX_PREFILL = 1 diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index e547da97df7..1e1e1794a5b 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -521,7 +521,28 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int = if hasattr(request, "pooling_params") and request.pooling_params is not None: batch_pooling_params.append(request.pooling_params) + logits_info = None + prefill_tokens = [] if request.task_type.value == RequestType.PREFILL.value: # prefill task + # guided decoding + if ( + request.guided_json is not None + or request.guided_regex is not None + or request.structural_tag is not None + or request.guided_grammar is not None + ): + logits_info, schemata_key = self._init_logits_processor(request) + request.schemata_key = schemata_key + + if self.scheduler_config.splitwise_role == "decode": + if ( + hasattr(request, "prefill_end_index") + and hasattr(request, "prompt_token_ids") + and request.prefill_end_index > len(request.prompt_token_ids) + ): + if hasattr(request, "output_token_ids"): + prefill_tokens.extend(request.output_token_ids) + prefill_start_index = request.prefill_start_index prefill_end_index = request.prefill_end_index length = prefill_end_index - prefill_start_index @@ -657,6 +678,8 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int = # For logits processors self.share_inputs["logits_processors_args"][idx] = request.get("logits_processors_args") or {} + self.sampler.apply_logits_processor(idx, logits_info, prefill_tokens) + if len(multi_vision_inputs["images_lst"]) > 0: self.share_inputs["image_features"] = self.extract_vision_features(multi_vision_inputs) @@ -2041,6 +2064,21 @@ def _get_p_done_idxs_gd(self, model_forward_batch: Optional[List[Request]], num_ if self.share_inputs["step_idx"][idx] == 0: prefill_done_idxs.append(idx) + if envs.ENABLE_V1_KVCACHE_SCHEDULER: + if model_forward_batch is None: + return prefill_done_idxs + + for task in model_forward_batch: + if task.task_type.value != RequestType.PREFILL.value: + continue + # in chunk prefill + if self.cache_config.enable_chunked_prefill: + if hasattr(task, "prefill_end_index") and hasattr(task, "prompt_token_ids"): + if len(task.prompt_token_ids) > task.prefill_end_index and idx in prefill_done_idxs: + prefill_done_idxs.remove(idx) + + return prefill_done_idxs + if self.cache_config.enable_chunked_prefill: if model_forward_batch is not None: for task in model_forward_batch: diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py index b30db63a7ec..27480e7e1e7 100644 --- a/fastdeploy/worker/worker_process.py +++ b/fastdeploy/worker/worker_process.py @@ -932,9 +932,6 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig: if not current_platform.is_cuda() and not current_platform.is_xpu(): logger.info("Set ENABLE_V1_KVCACHE_SCHEDULER to 0 due to not supported.") envs.ENABLE_V1_KVCACHE_SCHEDULER = 0 - if structured_outputs_config.guided_decoding_backend != "off": - logger.info("Set ENABLE_V1_KVCACHE_SCHEDULER to 0 due to not supported guided_decoding.") - envs.ENABLE_V1_KVCACHE_SCHEDULER = 0 if envs.ENABLE_V1_KVCACHE_SCHEDULER and args.splitwise_role == "prefill": os.environ["PREFILL_NODE_ONE_STEP_STOP_V1"] = "1" From 295916b82b56272e96f4d5d589154040ba95a17c Mon Sep 17 00:00:00 2001 From: Daci <15625257+ST-XX@users.noreply.github.com> Date: Mon, 24 Nov 2025 17:15:25 +0800 Subject: [PATCH 2/2] Apply suggestions from code review Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- fastdeploy/worker/gpu_model_runner.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index 1e1e1794a5b..54c220c6d06 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -2074,8 +2074,8 @@ def _get_p_done_idxs_gd(self, model_forward_batch: Optional[List[Request]], num_ # in chunk prefill if self.cache_config.enable_chunked_prefill: if hasattr(task, "prefill_end_index") and hasattr(task, "prompt_token_ids"): - if len(task.prompt_token_ids) > task.prefill_end_index and idx in prefill_done_idxs: - prefill_done_idxs.remove(idx) + if len(task.prompt_token_ids) > task.prefill_end_index and task.idx in prefill_done_idxs: + prefill_done_idxs.remove(task.idx) return prefill_done_idxs