diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py
index 77def0feb4d..3f7952fc5aa 100644
--- a/fastdeploy/engine/args_utils.py
+++ b/fastdeploy/engine/args_utils.py
@@ -535,8 +535,6 @@ def __post_init__(self):
 
         if not (current_platform.is_cuda() or current_platform.is_xpu() or current_platform.is_maca()):
             envs.ENABLE_V1_KVCACHE_SCHEDULER = 0
-        if self.guided_decoding_backend != "off":
-            envs.ENABLE_V1_KVCACHE_SCHEDULER = 0
 
         if "PaddleOCR" in get_model_architecture(self.model, self.model_config_name):
             envs.FD_ENABLE_MAX_PREFILL = 1
diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py
index a63876d4afb..81f44ff813f 100644
--- a/fastdeploy/worker/gpu_model_runner.py
+++ b/fastdeploy/worker/gpu_model_runner.py
@@ -521,7 +521,28 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int =
             if hasattr(request, "pooling_params") and request.pooling_params is not None:
                 batch_pooling_params.append(request.pooling_params)
 
+            logits_info = None
+            prefill_tokens = []
             if request.task_type.value == RequestType.PREFILL.value:  # prefill task
+                # guided decoding
+                if (
+                    request.guided_json is not None
+                    or request.guided_regex is not None
+                    or request.structural_tag is not None
+                    or request.guided_grammar is not None
+                ):
+                    logits_info, schemata_key = self._init_logits_processor(request)
+                    request.schemata_key = schemata_key
+
+                if self.scheduler_config.splitwise_role == "decode":
+                    if (
+                        hasattr(request, "prefill_end_index")
+                        and hasattr(request, "prompt_token_ids")
+                        and request.prefill_end_index > len(request.prompt_token_ids)
+                    ):
+                        if hasattr(request, "output_token_ids"):
+                            prefill_tokens.extend(request.output_token_ids)
+
                 prefill_start_index = request.prefill_start_index
                 prefill_end_index = request.prefill_end_index
                 length = prefill_end_index - prefill_start_index
@@ -657,6 +678,8 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int =
             # For logits processors
             self.share_inputs["logits_processors_args"][idx] = request.get("logits_processors_args") or {}
 
+            self.sampler.apply_logits_processor(idx, logits_info, prefill_tokens)
+
         if len(multi_vision_inputs["images_lst"]) > 0:
             self.share_inputs["image_features"] = self.extract_vision_features(multi_vision_inputs)
 
@@ -2059,6 +2082,21 @@ def _get_p_done_idxs_gd(self, model_forward_batch: Optional[List[Request]], num_
                 if self.share_inputs["step_idx"][idx] == 0:
                     prefill_done_idxs.append(idx)
 
+        if envs.ENABLE_V1_KVCACHE_SCHEDULER:
+            if model_forward_batch is None:
+                return prefill_done_idxs
+
+            for task in model_forward_batch:
+                if task.task_type.value != RequestType.PREFILL.value:
+                    continue
+                # in chunk prefill
+                if self.cache_config.enable_chunked_prefill:
+                    if hasattr(task, "prefill_end_index") and hasattr(task, "prompt_token_ids"):
+                        if len(task.prompt_token_ids) > task.prefill_end_index and task.idx in prefill_done_idxs:
+                            prefill_done_idxs.remove(task.idx)
+
+            return prefill_done_idxs
+
         if self.cache_config.enable_chunked_prefill:
             if model_forward_batch is not None:
                 for task in model_forward_batch:
diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py
index c48908718b5..f836e07c60d 100644
--- a/fastdeploy/worker/worker_process.py
+++ b/fastdeploy/worker/worker_process.py
@@ -968,9 +968,6 @@ def initialize_fd_config(args, ranks: int = 1, local_rank: int = 0) -> FDConfig:
     if not (current_platform.is_cuda() or current_platform.is_xpu() or current_platform.is_maca()):
         logger.info("Set ENABLE_V1_KVCACHE_SCHEDULER to 0 due to not supported.")
         envs.ENABLE_V1_KVCACHE_SCHEDULER = 0
-    if structured_outputs_config.guided_decoding_backend != "off":
-        logger.info("Set ENABLE_V1_KVCACHE_SCHEDULER to 0 due to not supported guided_decoding.")
-        envs.ENABLE_V1_KVCACHE_SCHEDULER = 0
 
     if envs.ENABLE_V1_KVCACHE_SCHEDULER and args.splitwise_role == "prefill":
         os.environ["PREFILL_NODE_ONE_STEP_STOP_V1"] = "1"