From e87f5baeb3944791ad11bfc601ae8ebc5bba6707 Mon Sep 17 00:00:00 2001 From: ckl117 Date: Fri, 31 Oct 2025 19:21:37 +0800 Subject: [PATCH 1/8] init temp --- fastdeploy/config.py | 3 + fastdeploy/engine/args_utils.py | 20 +++++- fastdeploy/engine/async_llm.py | 1 + fastdeploy/engine/engine.py | 1 + fastdeploy/engine/sampling_params.py | 1 + .../model_executor/pre_and_post_process.py | 49 ++++++++----- fastdeploy/output/stream_transfer_data.py | 5 +- fastdeploy/output/token_processor.py | 24 ++++++- fastdeploy/worker/gpu_model_runner.py | 68 ++++++++++++++++++- fastdeploy/worker/output.py | 11 +++ fastdeploy/worker/worker_process.py | 6 ++ 11 files changed, 167 insertions(+), 22 deletions(-) diff --git a/fastdeploy/config.py b/fastdeploy/config.py index 8cd82fbc655..6f470716cd8 100644 --- a/fastdeploy/config.py +++ b/fastdeploy/config.py @@ -183,6 +183,7 @@ def __init__( self.max_model_len = 0 self.dtype = "bfloat16" self.enable_logprob = False + self.max_logprobs = 20 self.logprobs_mode = "raw_logprobs" self.enable_redundant_experts = False self.redundant_experts_num = 0 @@ -227,6 +228,8 @@ def __init__( self.think_end_id = args.get("think_end_id", -1) self.im_patch_id = args.get("image_patch_id", -1) self.line_break_id = args.get("line_break_id", -1) + if self.max_logprobs == -1 and hasattr(self, "vocab_size"): + self.max_logprobs = self.vocab_size self._post_init() diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py index 9112efa7d30..feda67b6b29 100644 --- a/fastdeploy/engine/args_utils.py +++ b/fastdeploy/engine/args_utils.py @@ -43,6 +43,7 @@ from fastdeploy.utils import ( DeprecatedOptionWarning, FlexibleArgumentParser, + console_logger, is_port_available, parse_quantization, ) @@ -386,6 +387,12 @@ class EngineArgs: Must be explicitly enabled via the `--enable-logprob` startup parameter to output logprob values. """ + max_logprobs: int = 20 + """ + Maximum number of log probabilities to return when `enable_logprob` is True. The default value comes the default for the + OpenAI Chat Completions API. -1 means no cap, i.e. all (output_length * vocab_size) logprobs are allowed to be returned and it may cause OOM. + """ + logprobs_mode: str = "raw_logprobs" """ Indicates the content returned in the logprobs. 
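
A minimal illustration of the `max_logprobs` semantics documented in the docstring above (standalone Python with an example vocabulary size; the -1 resolution mirrors the `ModelConfig` change in this patch, while the helper name and numbers are invented for the sketch):

    # -1 is a sentinel for "no cap": it resolves to the vocabulary size, so every
    # generated position may carry logprobs for the whole vocabulary, which is why
    # the docstring warns that output_length * vocab_size entries can exhaust memory.
    def resolve_max_logprobs(max_logprobs: int, vocab_size: int) -> int:
        return vocab_size if max_logprobs == -1 else max_logprobs

    print(resolve_max_logprobs(20, 103424))  # 20 -> OpenAI Chat Completions-compatible default cap
    print(resolve_max_logprobs(-1, 103424))  # 103424 -> full-vocabulary logprobs per position
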
@@ -452,6 +459,11 @@ def __post_init__(self): raise NotImplementedError("Only CUDA platform supports logprob.") if self.speculative_config is not None and self.logprobs_mode.startswith("processed"): raise NotImplementedError("processed_logprobs not support in speculative.") + if self.speculative_config is not None and self.max_logprobs == -1: + raise NotImplementedError("max_logprobs=-1 not support in speculative.") + if envs.FD_USE_GET_SAVE_OUTPUT_V1 == 0: + self.max_logprobs = 20 + console_logger.warning("Set max_logprobs=20 when FD_USE_GET_SAVE_OUTPUT_V1=0") if self.splitwise_role != "mixed" and self.cache_transfer_protocol != "rdma": envs.ENABLE_V1_KVCACHE_SCHEDULER = 0 @@ -666,10 +678,16 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: default=EngineArgs.enable_logprob, help="Enable output of token-level log probabilities.", ) + model_group.add_argument( + "--max-logprobs", + type=int, + default=EngineArgs.max_logprobs, + help="Maximum number of log probabilities.", + ) model_group.add_argument( "--logprobs-mode", type=str, - choices=["raw_logprobs", "processed_logprobs", "processed_logits"], + choices=["raw_logprobs", "raw_logits", "processed_logprobs", "processed_logits"], default=EngineArgs.logprobs_mode, help="Indicates the content returned in the logprobs.", ) diff --git a/fastdeploy/engine/async_llm.py b/fastdeploy/engine/async_llm.py index 46a701f14ec..7411514b1b0 100644 --- a/fastdeploy/engine/async_llm.py +++ b/fastdeploy/engine/async_llm.py @@ -831,6 +831,7 @@ def _start_worker_service(self): f" --convert {self.cfg.model_config.convert}" f" --override-pooler-config {self.cfg.model_config.override_pooler_config}" f" --logprobs_mode {self.cfg.model_config.logprobs_mode}" + f" --max_logprobs {self.cfg.model_config.max_logprobs}" ) worker_append_flag = { diff --git a/fastdeploy/engine/engine.py b/fastdeploy/engine/engine.py index f1a734f426e..7de02fa989a 100644 --- a/fastdeploy/engine/engine.py +++ b/fastdeploy/engine/engine.py @@ -533,6 +533,7 @@ def _start_worker_service(self): f" --convert {self.cfg.model_config.convert}" f" --override-pooler-config {self.cfg.model_config.override_pooler_config}" f" --logprobs_mode {self.cfg.model_config.logprobs_mode}" + f" --max_logprobs {self.cfg.model_config.max_logprobs}" ) if self.cfg.structured_outputs_config.logits_processors is not None: arguments += f" --logits-processors {' '.join(self.cfg.structured_outputs_config.logits_processors)}" diff --git a/fastdeploy/engine/sampling_params.py b/fastdeploy/engine/sampling_params.py index ef26cd380e0..67c1980ddd1 100644 --- a/fastdeploy/engine/sampling_params.py +++ b/fastdeploy/engine/sampling_params.py @@ -97,6 +97,7 @@ class SamplingParams: reasoning_max_tokens: Optional[int] = None min_tokens: int = 1 logprobs: Optional[int] = None + prompt_logprobs: Optional[int] = -1 # For logits and logprobs post processing temp_scaled_logprobs: bool = False top_p_normalized_logprobs: bool = False diff --git a/fastdeploy/model_executor/pre_and_post_process.py b/fastdeploy/model_executor/pre_and_post_process.py index bddb12b496b..43d38a1b894 100644 --- a/fastdeploy/model_executor/pre_and_post_process.py +++ b/fastdeploy/model_executor/pre_and_post_process.py @@ -238,7 +238,11 @@ def pre_process( ) -def _build_stream_transfer_data(output_tokens: np.ndarray): +def _build_stream_transfer_data( + output_tokens: paddle.Tensor, + logprobs: paddle.Tensor = None, + prompt_logprobs: paddle.Tensor = None, +): """Split output_tokens and output""" output_tokens = 
output_tokens.reshape([-1]).numpy() output_tokens_lists = np.split(output_tokens, output_tokens.shape[0]) @@ -248,6 +252,11 @@ def _build_stream_transfer_data(output_tokens: np.ndarray): stream_transfer_data = StreamTransferData( decoder_state=DecoderState.TEXT, tokens=output_token_per_sample, batch_id=bid ) + if logprobs: + logprobs = logprobs.tolists().slice_rows(bid, bid + 1) + stream_transfer_data.logprobs = logprobs + if prompt_logprobs: + raise NotImplementedError("current dont spport prompt_logprobs") stream_transfer_datas.append(stream_transfer_data) return stream_transfer_datas @@ -262,6 +271,7 @@ def post_process_normal( async_output_queue: queue.Queue = None, think_end_id: int = -1, line_break_id: int = -1, + max_logprobs: int = 20, ) -> ModelRunnerOutput: """Post-processing steps after completing a single token generation.""" if think_end_id > 0: @@ -356,30 +366,35 @@ def post_process_normal( sampler_output.sampled_token_ids, model_output.is_block_step, ) + prompt_logprobs_tensors = None # 3. Transmit the model's output and stop generation signal via message queue. # In the future, we will abandon this approach. if not skip_save_output: - if sampler_output.logprobs_tensors is None: - if envs.FD_USE_GET_SAVE_OUTPUT_V1: - if save_each_rank or model_output.mp_rank == 0: - output = _build_stream_transfer_data(sampler_output.sampled_token_ids) - async_output_queue.put(output) - else: + if envs.FD_USE_GET_SAVE_OUTPUT_V1: + if save_each_rank or model_output.mp_rank == 0: + output = _build_stream_transfer_data( + sampler_output.sampled_token_ids, + logprobs=sampler_output.logprobs_tensors, + prompt_logprobs=prompt_logprobs_tensors, + ) + async_output_queue.put(output) + else: + if sampler_output.logprobs_tensors is None: save_output( sampler_output.sampled_token_ids, model_output.not_need_stop, model_output.mp_rank, save_each_rank, ) - else: - save_output_topk( - sampler_output.sampled_token_ids, - sampler_output.logprobs_tensors.logprob_token_ids, - sampler_output.logprobs_tensors.logprobs, - sampler_output.logprobs_tensors.selected_token_ranks, - model_output.not_need_stop, - model_output.mp_rank, - ) + else: + save_output_topk( + sampler_output.sampled_token_ids, + sampler_output.logprobs_tensors.logprob_token_ids, + sampler_output.logprobs_tensors.logprobs, + sampler_output.logprobs_tensors.selected_token_ranks, + model_output.not_need_stop, + model_output.mp_rank, + ) def post_process_specualate( @@ -468,6 +483,7 @@ def post_process( async_output_queue: queue.Queue = None, think_end_id: int = -1, line_break_id: int = -1, + max_logprobs: int = 20, ) -> None: """Post-processing steps after completing a single token generation.""" if speculative_decoding: @@ -491,6 +507,7 @@ def post_process( async_output_queue, think_end_id, line_break_id, + max_logprobs, ) diff --git a/fastdeploy/output/stream_transfer_data.py b/fastdeploy/output/stream_transfer_data.py index 6241a28d990..dacd3cc7b29 100644 --- a/fastdeploy/output/stream_transfer_data.py +++ b/fastdeploy/output/stream_transfer_data.py @@ -20,6 +20,8 @@ import numpy as np +from fastdeploy.worker.output import LogprobsLists + class DecoderState(Enum): """DecoderState""" @@ -38,7 +40,8 @@ class StreamTransferData: tokens: np.array batch_id: int speculaive_decoding: bool = False - logprobs: Optional[np.array] = None + logprobs: Optional[LogprobsLists] = None + prompt_logprobs: Optional[LogprobsLists] = None accept_tokens: Optional[np.array] = None accept_num: Optional[np.array] = None # [num_reqs, hidden_size] diff --git 
a/fastdeploy/output/token_processor.py b/fastdeploy/output/token_processor.py index 5d94a2c2b90..497fab1f231 100644 --- a/fastdeploy/output/token_processor.py +++ b/fastdeploy/output/token_processor.py @@ -241,6 +241,11 @@ def _process_batch_output_use_zmq(self, receive_datas): task_id = task.request_id token_ids = stream_data.tokens # numpy.array + logprobs = stream_data.logprobs + prompt_logprobs = stream_data.prompt_logprobs + llm_logger.info( + f"batch_id = {i}, task_id={task_id},token_ids={token_ids},logprobs={logprobs}, prompt_logprobs={prompt_logprobs}" + ) current_time = time.time() if self.tokens_counter[task_id] == 0: @@ -285,6 +290,21 @@ def _process_batch_output_use_zmq(self, receive_datas): finished=False, metrics=metrics, ) + if self.use_logprobs: + result.outputs.logprob = logprobs.logprobs[0] + topk_token_ids = logprobs.logprob_token_ids[1:] + topk_logprobs = logprobs.logprobs[1:] + sampled_rank = logprobs.sampled_token_ranks[0] + if result.outputs.top_logprobs is None: + result.outputs.top_logprobs = LogprobsLists( + logprob_token_ids=[topk_token_ids], + logprobs=[topk_logprobs], + sampled_token_ranks=[sampled_rank], + ) + else: + result.outputs.top_logprobs.logprob_token_ids.extend([topk_token_ids]) + result.outputs.top_logprobs.logprobs.extend([topk_logprobs]) + result.outputs.top_logprobs.sampled_token_ranks.extend([sampled_rank]) if self.tokens_counter[task_id] == 0: if task.messages is not None: result.prompt = task.messages @@ -306,8 +326,8 @@ def process_sampling_results_use_zmq(self): """ if self.speculative_decoding: raise NotImplementedError("GET_SAVE_OUTPUT_V1 does not support speculative decoding") - if self.use_logprobs: - raise NotImplementedError("GET_SAVE_OUTPUT_V1 does not support use_logprobs") + # if self.use_logprobs: + # raise NotImplementedError("GET_SAVE_OUTPUT_V1 does not support use_logprobs") rank_id = self.cfg.parallel_config.local_data_parallel_id while True: try: diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index 598388bf9dc..a8670ff4a98 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -92,7 +92,7 @@ from fastdeploy.model_executor.models.interfaces_base import FdModelForPooling from fastdeploy.output.pooler import PoolerOutput from fastdeploy.worker.model_runner_base import ModelRunnerBase -from fastdeploy.worker.output import ModelOutputData, ModelRunnerOutput +from fastdeploy.worker.output import LogprobsTensors, ModelOutputData, ModelRunnerOutput class GPUModelRunner(ModelRunnerBase): @@ -112,8 +112,13 @@ def __init__( self.speculative_method = self.fd_config.speculative_config.method self.speculative_decoding = self.speculative_method is not None self.enable_logprob = fd_config.model_config.enable_logprob + self.max_logprobs = fd_config.model_config.max_logprobs self.enable_early_stop = self.fd_config.early_stop_config.enable_early_stop self.is_pooling_model = self.fd_config.model_config.runner_type == "pooling" + self.vocal_size = self.fd_config.model_config.vocab_size + self.running_reqs: list[Request] = self.scheduler_config.max_num_seqs * [None] + self.prompt_logprobs_reqs: list[Request] = [] + self.in_progress_prompt_logprobs: dict[str, LogprobsTensors] = {} # VL model config: if self.enable_mm: @@ -552,6 +557,12 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int = len(request.output_token_ids) if prefill_end_index >= len(input_ids) else 0 ) self.share_inputs["pre_ids"][idx : idx + 1] = -1 + # 
self.running_reqs[idx] = request + prompt_logprobs = request.sampling_params.prompt_logprobs + if prompt_logprobs is not None: + self.num_prompt_logprobs[request.request_id] = ( + self.vocal_size if prompt_logprobs == -1 else prompt_logprobs + ) has_prefill_task = True elif request.task_type.value == RequestType.DECODE.value: # decode task logger.debug(f"Handle decode request {request} at idx {idx}") @@ -572,6 +583,8 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int = self.share_inputs["seq_lens_decoder"][idx : idx + 1] = 0 self.share_inputs["seq_lens_encoder"][idx : idx + 1] = 0 self.share_inputs["is_block_step"][idx : idx + 1] = False + self.num_prompt_logprobs.pop(request.request_id, None) + self.in_progress_prompt_logprobs.pop(request.request_id, None) continue assert len(request.eos_token_ids) == self.model_config.eos_tokens_lens @@ -1269,7 +1282,7 @@ def _prepare_inputs(self) -> None: min_dec_lens=self.share_inputs["min_dec_len"], bad_words_token_ids=self.share_inputs["bad_tokens"][:, :max_bad_tokens_len], eos_token_ids=self.share_inputs["eos_token_id"], - max_num_logprobs=20 if self.enable_logprob else None, + max_num_logprobs=self.max_logprobs if self.enable_logprob else None, enable_early_stop=self.enable_early_stop, stop_flags=self.share_inputs["stop_flags"], temp_scaled_logprobs=self.share_inputs["temp_scaled_logprobs"], @@ -2001,6 +2014,24 @@ class at the server level, which is too granular for ModelRunner. # 1. Prepare inputs of model and sampler. skip_idx_list = self._get_skip_idx(model_forward_batch) self._prepare_inputs() + # print(f'model_forward_batch = {model_forward_batch}') + # print(f'self.running_reqs = {self.running_reqs}') + # for bid, req in enumerate(self.running_reqs): + # if req is None or self.share_inputs["stop_flags"][bid,0]: + # self.running_reqs[bid] = None # stop_flags = true + # self.num_prompt_logprobs[bid] = None + # continue + # print(f'req: {req.to_dict()}') + # if req.sampling_params.prompt_logprobs is not None: + # self.num_prompt_logprobs[bid] = self.vocal_size if req.sampling_params.prompt_logprobs == -1 else req.sampling_params.prompt_logprobs + print(f"self.in_progress_prompt_logprobs = {self.in_progress_prompt_logprobs}") + # print(f'input_ids = {self.share_inputs["input_ids"]}') + print(f'ids_remove_padding = {self.share_inputs["ids_remove_padding"]}') + print(f"batch_id_per_token = {self.forward_meta.batch_id_per_token}") + print(f'cu_seqlens_q = {self.share_inputs["cu_seqlens_q"]}') + print(f'seq_lens_encoder = {self.share_inputs["seq_lens_encoder"]}') + print(f'seq_lens_decoder = {self.share_inputs["seq_lens_decoder"]}') + print(f'seq_lens_this_time = {self.share_inputs["seq_lens_this_time"]}') self.sampler.pre_process(skip_idx_list) # 1.1 Update state of logits processor @@ -2040,6 +2071,29 @@ class at the server level, which is too granular for ModelRunner. 
(self.share_inputs["output_padding_offset"] if self.speculative_decoding else None), self.model_config.max_model_len, ) + # 遍历需要计算prompt_logprobs的请求 + completed_prefill_reqs = [] + for idx, request in enumerate(self.prompt_logprobs_reqs): + # 1.判断当前请求是否已经计算完prompt_logprobs + num_prompt_logprobs = request.sampling_params.prompt_logprobs + if request.prompt_token_ids is None or num_prompt_logprobs is None: + continue + if num_prompt_logprobs == -1: + num_prompt_logprobs = self.vocal_size + num_prompt_tokens = len(request.prompt_token_ids) + logprobs_tensors = self.in_progress_prompt_logprobs.get(request.idx) + if not logprobs_tensors: + logprobs_tensors = LogprobsTensors.empty_cpu(num_prompt_tokens, num_prompt_logprobs + 1) + self.in_progress_prompt_logprobs[request.idx] = logprobs_tensors + # 2.如果已经计算完prompt_logprobs,记录到completed_prefill_reqs,跳过 + # 3.判断需要chunked_prefill部分的prompt_logprobs + + # if self.in_progress_prompt_logprobs[bid] is None: + # self.in_progress_prompt_logprobs[bid] = self.share_inputs["input_ids"][:, :prompt_logprobs] + + # 清除已经计算完prompt_logprobs的请求 + + # 将prompt_logprob组装成batch,通过zmq返回 # 4. Compute logits, Sample logits = None @@ -2144,6 +2198,7 @@ class at the server level, which is too granular for ModelRunner. async_output_queue=self.async_output_queue, think_end_id=self.model_config.think_end_id, line_break_id=self.model_config.line_break_id, + max_logprobs=self.max_logprobs, ) if self.guided_backend is not None and sampler_output is not None: self.sampler.post_process(sampler_output.sampled_token_ids, skip_idx_list) @@ -2559,3 +2614,12 @@ def prepare_rope3d( cumsum_seqlens=cumsum_seqlens, ) return rope_emb_lst + + def _get_prompt_logprobs_dict( + self, + hidden_states: paddle.Tensor, + num_scheduled_tokens: dict[str, int], + ) -> dict[str, Optional[LogprobsTensors]]: + # in_progress_dict = self.in_progress_prompt_logprobs_cpu + # prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]] = {} + pass diff --git a/fastdeploy/worker/output.py b/fastdeploy/worker/output.py index b4192e88269..60d298f7faf 100644 --- a/fastdeploy/worker/output.py +++ b/fastdeploy/worker/output.py @@ -63,6 +63,17 @@ def slice_columns(self, start: int, end: int): self.sampled_token_ranks, # unchanged ) + def slice_rows(self, start: int, end: int): + """ + Slice rows. + Keeps the number of max_num_logprobs unchanged. 
+ """ + return LogprobsLists( + self.logprob_token_ids[start:end], + self.logprobs[start:end], + self.sampled_token_ranks[start:end], + ) + class LogprobsTensors(NamedTuple): """ """ diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py index 575913cfbe3..0468ee95634 100644 --- a/fastdeploy/worker/worker_process.py +++ b/fastdeploy/worker/worker_process.py @@ -614,6 +614,12 @@ def parse_args(): action="store_true", help="Enable output of token-level log probabilities.", ) + parser.add_argument( + "--max_logprobs", + type=int, + default=20, + help="Maximum number of log probabilities.", + ) parser.add_argument( "--logprobs_mode", type=str, From 6df86f1e0dc6072f6f5be1034902aca760e5d904 Mon Sep 17 00:00:00 2001 From: ckl117 Date: Mon, 3 Nov 2025 16:10:38 +0800 Subject: [PATCH 2/8] gpu worker support prompt_logprob --- fastdeploy/engine/args_utils.py | 4 +- fastdeploy/engine/sampling_params.py | 2 +- .../model_executor/layers/sample/sampler.py | 6 +- .../model_executor/pre_and_post_process.py | 19 ++- fastdeploy/worker/gpu_model_runner.py | 121 ++++++++++-------- fastdeploy/worker/output.py | 17 ++- 6 files changed, 105 insertions(+), 64 deletions(-) diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py index feda67b6b29..e01e79ae6d0 100644 --- a/fastdeploy/engine/args_utils.py +++ b/fastdeploy/engine/args_utils.py @@ -461,9 +461,11 @@ def __post_init__(self): raise NotImplementedError("processed_logprobs not support in speculative.") if self.speculative_config is not None and self.max_logprobs == -1: raise NotImplementedError("max_logprobs=-1 not support in speculative.") - if envs.FD_USE_GET_SAVE_OUTPUT_V1 == 0: + if not envs.FD_USE_GET_SAVE_OUTPUT_V1: self.max_logprobs = 20 console_logger.warning("Set max_logprobs=20 when FD_USE_GET_SAVE_OUTPUT_V1=0") + if self.max_logprobs == -1 and not envs.ENABLE_V1_KVCACHE_SCHEDULER: + raise NotImplementedError("Only ENABLE_V1_KVCACHE_SCHEDULER=1 support max_logprobs=-1") if self.splitwise_role != "mixed" and self.cache_transfer_protocol != "rdma": envs.ENABLE_V1_KVCACHE_SCHEDULER = 0 diff --git a/fastdeploy/engine/sampling_params.py b/fastdeploy/engine/sampling_params.py index 67c1980ddd1..23b77c7be95 100644 --- a/fastdeploy/engine/sampling_params.py +++ b/fastdeploy/engine/sampling_params.py @@ -97,7 +97,7 @@ class SamplingParams: reasoning_max_tokens: Optional[int] = None min_tokens: int = 1 logprobs: Optional[int] = None - prompt_logprobs: Optional[int] = -1 + prompt_logprobs: Optional[int] = None # For logits and logprobs post processing temp_scaled_logprobs: bool = False top_p_normalized_logprobs: bool = False diff --git a/fastdeploy/model_executor/layers/sample/sampler.py b/fastdeploy/model_executor/layers/sample/sampler.py index c60ec83604a..bfbbe14c7e0 100644 --- a/fastdeploy/model_executor/layers/sample/sampler.py +++ b/fastdeploy/model_executor/layers/sample/sampler.py @@ -247,9 +247,11 @@ def post_process(self, next_tokens: paddle.Tensor, skip_idx_list: List[int] = [] def compute_logprobs( self, logits: paddle.Tensor, - sampling_metadata: SamplingMetadata, + sampling_metadata: Optional[SamplingMetadata] = None, ) -> paddle.Tensor: """ """ + if sampling_metadata is None: + return F.log_softmax(logits, axis=-1) last_logits = logits real_bsz = last_logits.shape[0] temp_scaled_logprobs = sampling_metadata.temp_scaled_logprobs @@ -309,6 +311,8 @@ def gather_logprobs( assert token_ids.dtype == paddle.int64 logprobs.clip_(min=paddle.finfo(logprobs.dtype).min) # Get with the logprob of the 
prompt or sampled token. + if len(token_ids.shape) < len(logprobs.shape): + token_ids = token_ids.unsqueeze(-1) token_logprobs = paddle.take_along_axis(logprobs, token_ids, axis=-1) # Compute the ranks of the actual token. diff --git a/fastdeploy/model_executor/pre_and_post_process.py b/fastdeploy/model_executor/pre_and_post_process.py index 43d38a1b894..5f1cd69f1f4 100644 --- a/fastdeploy/model_executor/pre_and_post_process.py +++ b/fastdeploy/model_executor/pre_and_post_process.py @@ -86,7 +86,12 @@ ) from fastdeploy.output.stream_transfer_data import DecoderState, StreamTransferData -from fastdeploy.worker.output import ModelOutputData, ModelRunnerOutput, SamplerOutput +from fastdeploy.worker.output import ( + LogprobsTensors, + ModelOutputData, + ModelRunnerOutput, + SamplerOutput, +) DISABLE_RECOVER = envs.FD_DISABLED_RECOVER == "1" @@ -240,8 +245,8 @@ def pre_process( def _build_stream_transfer_data( output_tokens: paddle.Tensor, - logprobs: paddle.Tensor = None, - prompt_logprobs: paddle.Tensor = None, + logprobs: Optional[LogprobsTensors] = None, + prompt_logprobs_list: Optional[LogprobsTensors] = None, ): """Split output_tokens and output""" output_tokens = output_tokens.reshape([-1]).numpy() @@ -255,8 +260,8 @@ def _build_stream_transfer_data( if logprobs: logprobs = logprobs.tolists().slice_rows(bid, bid + 1) stream_transfer_data.logprobs = logprobs - if prompt_logprobs: - raise NotImplementedError("current dont spport prompt_logprobs") + if prompt_logprobs_list: + stream_transfer_data.prompt_logprobs = prompt_logprobs_list[bid] stream_transfer_datas.append(stream_transfer_data) return stream_transfer_datas @@ -366,7 +371,7 @@ def post_process_normal( sampler_output.sampled_token_ids, model_output.is_block_step, ) - prompt_logprobs_tensors = None + prompt_logprobs_list = model_output.prompt_logprobs_list # 3. Transmit the model's output and stop generation signal via message queue. # In the future, we will abandon this approach. 
if not skip_save_output: @@ -375,7 +380,7 @@ def post_process_normal( output = _build_stream_transfer_data( sampler_output.sampled_token_ids, logprobs=sampler_output.logprobs_tensors, - prompt_logprobs=prompt_logprobs_tensors, + prompt_logprobs=prompt_logprobs_list, ) async_output_queue.put(output) else: diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index a8670ff4a98..9b7122d02e2 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -116,8 +116,7 @@ def __init__( self.enable_early_stop = self.fd_config.early_stop_config.enable_early_stop self.is_pooling_model = self.fd_config.model_config.runner_type == "pooling" self.vocal_size = self.fd_config.model_config.vocab_size - self.running_reqs: list[Request] = self.scheduler_config.max_num_seqs * [None] - self.prompt_logprobs_reqs: list[Request] = [] + self.prompt_logprobs_reqs: dict[str, Request] = {} self.in_progress_prompt_logprobs: dict[str, LogprobsTensors] = {} # VL model config: @@ -557,12 +556,8 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int = len(request.output_token_ids) if prefill_end_index >= len(input_ids) else 0 ) self.share_inputs["pre_ids"][idx : idx + 1] = -1 - # self.running_reqs[idx] = request - prompt_logprobs = request.sampling_params.prompt_logprobs - if prompt_logprobs is not None: - self.num_prompt_logprobs[request.request_id] = ( - self.vocal_size if prompt_logprobs == -1 else prompt_logprobs - ) + if request.sampling_params.prompt_logprobs is not None: + self.prompt_logprobs_reqs[request.request_id] = request has_prefill_task = True elif request.task_type.value == RequestType.DECODE.value: # decode task logger.debug(f"Handle decode request {request} at idx {idx}") @@ -2014,24 +2009,8 @@ class at the server level, which is too granular for ModelRunner. # 1. Prepare inputs of model and sampler. skip_idx_list = self._get_skip_idx(model_forward_batch) self._prepare_inputs() - # print(f'model_forward_batch = {model_forward_batch}') - # print(f'self.running_reqs = {self.running_reqs}') - # for bid, req in enumerate(self.running_reqs): - # if req is None or self.share_inputs["stop_flags"][bid,0]: - # self.running_reqs[bid] = None # stop_flags = true - # self.num_prompt_logprobs[bid] = None - # continue - # print(f'req: {req.to_dict()}') - # if req.sampling_params.prompt_logprobs is not None: - # self.num_prompt_logprobs[bid] = self.vocal_size if req.sampling_params.prompt_logprobs == -1 else req.sampling_params.prompt_logprobs - print(f"self.in_progress_prompt_logprobs = {self.in_progress_prompt_logprobs}") - # print(f'input_ids = {self.share_inputs["input_ids"]}') + # print(f"self.in_progress_prompt_logprobs = {self.in_progress_prompt_logprobs}") print(f'ids_remove_padding = {self.share_inputs["ids_remove_padding"]}') - print(f"batch_id_per_token = {self.forward_meta.batch_id_per_token}") - print(f'cu_seqlens_q = {self.share_inputs["cu_seqlens_q"]}') - print(f'seq_lens_encoder = {self.share_inputs["seq_lens_encoder"]}') - print(f'seq_lens_decoder = {self.share_inputs["seq_lens_decoder"]}') - print(f'seq_lens_this_time = {self.share_inputs["seq_lens_this_time"]}') self.sampler.pre_process(skip_idx_list) # 1.1 Update state of logits processor @@ -2062,6 +2041,9 @@ class at the server level, which is too granular for ModelRunner. 
) if self.use_cudagraph: model_output = model_output[: self.real_token_num] + + prompt_logprobs_list = self._get_prompt_logprobs_dict(model_output) + hidden_states = rebuild_padding( model_output, self.share_inputs["cu_seqlens_q"], @@ -2071,29 +2053,6 @@ class at the server level, which is too granular for ModelRunner. (self.share_inputs["output_padding_offset"] if self.speculative_decoding else None), self.model_config.max_model_len, ) - # 遍历需要计算prompt_logprobs的请求 - completed_prefill_reqs = [] - for idx, request in enumerate(self.prompt_logprobs_reqs): - # 1.判断当前请求是否已经计算完prompt_logprobs - num_prompt_logprobs = request.sampling_params.prompt_logprobs - if request.prompt_token_ids is None or num_prompt_logprobs is None: - continue - if num_prompt_logprobs == -1: - num_prompt_logprobs = self.vocal_size - num_prompt_tokens = len(request.prompt_token_ids) - logprobs_tensors = self.in_progress_prompt_logprobs.get(request.idx) - if not logprobs_tensors: - logprobs_tensors = LogprobsTensors.empty_cpu(num_prompt_tokens, num_prompt_logprobs + 1) - self.in_progress_prompt_logprobs[request.idx] = logprobs_tensors - # 2.如果已经计算完prompt_logprobs,记录到completed_prefill_reqs,跳过 - # 3.判断需要chunked_prefill部分的prompt_logprobs - - # if self.in_progress_prompt_logprobs[bid] is None: - # self.in_progress_prompt_logprobs[bid] = self.share_inputs["input_ids"][:, :prompt_logprobs] - - # 清除已经计算完prompt_logprobs的请求 - - # 将prompt_logprob组装成batch,通过zmq返回 # 4. Compute logits, Sample logits = None @@ -2181,6 +2140,7 @@ class at the server level, which is too granular for ModelRunner. stop_token_ids=self.share_inputs["stop_seqs"], stop_seqs_len=self.share_inputs["stop_seqs_len"], prompt_lens=self.share_inputs["prompt_lens"], + prompt_logprobs_list=prompt_logprobs_list, ) if self.speculative_config.method in ["mtp"] and self.scheduler_config.splitwise_role == "prefill": @@ -2618,8 +2578,63 @@ def prepare_rope3d( def _get_prompt_logprobs_dict( self, hidden_states: paddle.Tensor, - num_scheduled_tokens: dict[str, int], - ) -> dict[str, Optional[LogprobsTensors]]: - # in_progress_dict = self.in_progress_prompt_logprobs_cpu - # prompt_logprobs_dict: dict[str, Optional[LogprobsTensors]] = {} - pass + ) -> list[Optional[LogprobsTensors]]: + logprobs_mode = self.fd_config.model_config.logprobs_mode + prompt_logprobs_list: list[Optional[LogprobsTensors]] = self.scheduler_config.max_num_seqs * [None] + completed_prefill_reqs: list[Request] = [] + for req_id, request in self.prompt_logprobs_reqs.items(): + num_prompt_logprobs = request.sampling_params.prompt_logprobs + if request.prompt_token_ids is None or num_prompt_logprobs is None: + continue + if num_prompt_logprobs == -1: + num_prompt_logprobs = self.vocal_size + + num_tokens = request.prefill_end_index - request.prefill_start_index + num_prompt_tokens = len(request.prompt_token_ids) + + logprobs_tensors = self.in_progress_prompt_logprobs.get(req_id) + if not logprobs_tensors: + logprobs_tensors = LogprobsTensors.empty(num_prompt_tokens - 1, num_prompt_logprobs + 1) + self.in_progress_prompt_logprobs[req_id] = logprobs_tensors + start_idx = request.prefill_start_index + start_tok = start_idx + 1 + num_remaining_tokens = num_prompt_tokens - start_tok + if num_tokens <= num_remaining_tokens: + # This is a chunk, more tokens remain. + # In the == case, there are no more prompt logprobs to produce + # but we want to defer returning them to the next step where we + # have new generated tokens to return. 
+ num_logits = num_tokens + else: + # This is the last chunk of prompt tokens to return. + num_logits = num_remaining_tokens + completed_prefill_reqs.append(request) + prompt_logprobs_list[request.idx] = logprobs_tensors + if num_logits <= 0: + # This can happen for the final chunk if we prefilled exactly + # (num_prompt_tokens - 1) tokens for this request in the prior + # step. There are no more prompt logprobs to produce. + continue + offset = self.share_inputs["cu_seqlens_q"][request.idx] + prompt_hidden_states = hidden_states[offset : offset + num_logits] + logits = self.model.compute_logits(prompt_hidden_states) + prompt_token_ids = request.prompt_token_ids[start_tok : start_tok + num_logits] + if isinstance(prompt_token_ids, np.ndarray): + prompt_token_ids = prompt_token_ids.tolist() + prompt_token_ids_tensor = paddle.to_tensor(prompt_token_ids, dtype="int64") + if logprobs_mode == "raw_logprobs": + raw_logprobs = self.sampler.compute_logprobs(logits) + elif logprobs_mode == "raw_logits": + raw_logprobs = logits + token_ids, logprobs, ranks = self.sampler.gather_logprobs( + raw_logprobs, num_prompt_logprobs, prompt_token_ids_tensor + ) + chunk_slice = slice(start_idx, start_idx + num_logits) + logprobs_tensors.logprob_token_ids[chunk_slice].copy_(token_ids, False) + logprobs_tensors.logprobs[chunk_slice].copy_(logprobs, False) + logprobs_tensors.selected_token_ranks[chunk_slice].copy_(ranks, False) + + for req in completed_prefill_reqs: + del self.prompt_logprobs_reqs[req.request_id] + del self.in_progress_prompt_logprobs[req.request_id] + return prompt_logprobs_list diff --git a/fastdeploy/worker/output.py b/fastdeploy/worker/output.py index 60d298f7faf..af16edbfa70 100644 --- a/fastdeploy/worker/output.py +++ b/fastdeploy/worker/output.py @@ -98,7 +98,7 @@ def empty_cpu(num_positions: int, num_tokens_per_position: int) -> "LogprobsTens """Create empty LogprobsTensors on CPU.""" logprob_token_ids = paddle.empty([num_positions, num_tokens_per_position], dtype=paddle.int64).cpu() - logprobs = paddle.empty_like(logprob_token_ids, dtype=paddle.float32) + logprobs = paddle.empty_like(logprob_token_ids, dtype=paddle.float32).cpu() selected_token_ranks = paddle.empty([num_positions], dtype=paddle.int64).cpu() return LogprobsTensors( logprob_token_ids=logprob_token_ids, @@ -106,6 +106,19 @@ def empty_cpu(num_positions: int, num_tokens_per_position: int) -> "LogprobsTens selected_token_ranks=selected_token_ranks, ) + @staticmethod + def empty(num_positions: int, num_tokens_per_position: int) -> "LogprobsTensors": + """Create empty LogprobsTensors on default device.""" + + logprob_token_ids = paddle.empty([num_positions, num_tokens_per_position], dtype=paddle.int64) + logprobs = paddle.empty_like(logprob_token_ids, dtype=paddle.float32) + selected_token_ranks = paddle.empty([num_positions], dtype=paddle.int64) + return LogprobsTensors( + logprob_token_ids=logprob_token_ids, + logprobs=logprobs, + selected_token_ranks=selected_token_ranks, + ) + @dataclass class SamplerOutput: @@ -248,6 +261,8 @@ class ModelOutputData: """ prompt_lens: paddle.Tensor = None + prompt_logprobs_list: Optional[LogprobsTensors] = None + @dataclass class ModelRunnerOutput: From c42dd80a1ebf620b409718de1e5a81f47631e896 Mon Sep 17 00:00:00 2001 From: ckl117 Date: Mon, 3 Nov 2025 16:28:12 +0800 Subject: [PATCH 3/8] add doc for max_logprobs and code check --- docs/parameters.md | 1 + docs/zh/parameters.md | 1 + .../model_executor/pre_and_post_process.py | 2 +- fastdeploy/output/token_processor.py | 24 ++----------------- 
fastdeploy/worker/gpu_model_runner.py | 2 -- fastdeploy/worker/output.py | 11 +++++++++ 6 files changed, 16 insertions(+), 25 deletions(-) diff --git a/docs/parameters.md b/docs/parameters.md index 7a41dac37d6..16ac313498a 100644 --- a/docs/parameters.md +++ b/docs/parameters.md @@ -49,6 +49,7 @@ When using FastDeploy to deploy models (including offline inference and service | ```enable_expert_parallel``` | `bool` | Whether to enable expert parallel | | ```enable_logprob``` | `bool` | Whether to enable return log probabilities of the output tokens or not. If true, returns the log probabilities of each output token returned in the content of message.If logrpob is not used, this parameter can be omitted when starting | | ```logprobs_mode``` | `str` | Indicates the content returned in the logprobs. Supported mode: `raw_logprobs`, `processed_logprobs`, `raw_logits`, `processed_logits`. Raw means the values before applying logit processors, like bad words. Processed means the values after applying such processors. | +| ```max_logprobs``` | `int` | Maximum number of log probabilities to return, default: 20. -1 means vocab_size. | | ```served_model_name```| `str`| The model name used in the API. If not specified, the model name will be the same as the --model argument | | ```revision``` | `str` | The specific model version to use. It can be a branch name, a tag name, or a commit id. If unspecified, will use the default version. | | ```chat_template``` | `str` | Specify the template used for model concatenation, It supports both string input and file path input. The default value is None. If not specified, the model's default template will be used. | diff --git a/docs/zh/parameters.md b/docs/zh/parameters.md index 859685ff6b1..5eec25ce28c 100644 --- a/docs/zh/parameters.md +++ b/docs/zh/parameters.md @@ -47,6 +47,7 @@ | ```enable_expert_parallel``` | `bool` | 是否启用专家并行 | | ```enable_logprob``` | `bool` | 是否启用输出token返回logprob。如果未使用 logrpob,则在启动时可以省略此参数。 | | ```logprobs_mode``` | `str` | 指定logprobs中返回的内容。支持的模式:`raw_logprobs`、`processed_logprobs'、`raw_logits`,`processed_logits'。processed表示logits应用温度、惩罚、禁止词处理后计算的logprobs。| +| ```max_logprobs``` | `int` | 服务支持返回的最大logprob数量,默认20。-1表示词表大小。 | | ```served_model_name``` | `str` | API 中使用的模型名称,如果未指定,模型名称将与--model参数相同 | | ```revision``` | `str` | 自动下载模型时,用于指定模型的Git版本,分支名或tag | | ```chat_template``` | `str` | 指定模型拼接使用的模板,支持字符串与文件路径,默认为None,如未指定,则使用模型默认模板 | diff --git a/fastdeploy/model_executor/pre_and_post_process.py b/fastdeploy/model_executor/pre_and_post_process.py index 5f1cd69f1f4..a6fc1c02b2e 100644 --- a/fastdeploy/model_executor/pre_and_post_process.py +++ b/fastdeploy/model_executor/pre_and_post_process.py @@ -258,7 +258,7 @@ def _build_stream_transfer_data( decoder_state=DecoderState.TEXT, tokens=output_token_per_sample, batch_id=bid ) if logprobs: - logprobs = logprobs.tolists().slice_rows(bid, bid + 1) + logprobs = logprobs.slice_rows(bid, bid + 1) stream_transfer_data.logprobs = logprobs if prompt_logprobs_list: stream_transfer_data.prompt_logprobs = prompt_logprobs_list[bid] diff --git a/fastdeploy/output/token_processor.py b/fastdeploy/output/token_processor.py index 497fab1f231..5d94a2c2b90 100644 --- a/fastdeploy/output/token_processor.py +++ b/fastdeploy/output/token_processor.py @@ -241,11 +241,6 @@ def _process_batch_output_use_zmq(self, receive_datas): task_id = task.request_id token_ids = stream_data.tokens # numpy.array - logprobs = stream_data.logprobs - prompt_logprobs = stream_data.prompt_logprobs - llm_logger.info( - f"batch_id = 
{i}, task_id={task_id},token_ids={token_ids},logprobs={logprobs}, prompt_logprobs={prompt_logprobs}" - ) current_time = time.time() if self.tokens_counter[task_id] == 0: @@ -290,21 +285,6 @@ def _process_batch_output_use_zmq(self, receive_datas): finished=False, metrics=metrics, ) - if self.use_logprobs: - result.outputs.logprob = logprobs.logprobs[0] - topk_token_ids = logprobs.logprob_token_ids[1:] - topk_logprobs = logprobs.logprobs[1:] - sampled_rank = logprobs.sampled_token_ranks[0] - if result.outputs.top_logprobs is None: - result.outputs.top_logprobs = LogprobsLists( - logprob_token_ids=[topk_token_ids], - logprobs=[topk_logprobs], - sampled_token_ranks=[sampled_rank], - ) - else: - result.outputs.top_logprobs.logprob_token_ids.extend([topk_token_ids]) - result.outputs.top_logprobs.logprobs.extend([topk_logprobs]) - result.outputs.top_logprobs.sampled_token_ranks.extend([sampled_rank]) if self.tokens_counter[task_id] == 0: if task.messages is not None: result.prompt = task.messages @@ -326,8 +306,8 @@ def process_sampling_results_use_zmq(self): """ if self.speculative_decoding: raise NotImplementedError("GET_SAVE_OUTPUT_V1 does not support speculative decoding") - # if self.use_logprobs: - # raise NotImplementedError("GET_SAVE_OUTPUT_V1 does not support use_logprobs") + if self.use_logprobs: + raise NotImplementedError("GET_SAVE_OUTPUT_V1 does not support use_logprobs") rank_id = self.cfg.parallel_config.local_data_parallel_id while True: try: diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index 9b7122d02e2..c9a8bca4eae 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -2009,8 +2009,6 @@ class at the server level, which is too granular for ModelRunner. # 1. Prepare inputs of model and sampler. skip_idx_list = self._get_skip_idx(model_forward_batch) self._prepare_inputs() - # print(f"self.in_progress_prompt_logprobs = {self.in_progress_prompt_logprobs}") - print(f'ids_remove_padding = {self.share_inputs["ids_remove_padding"]}') self.sampler.pre_process(skip_idx_list) # 1.1 Update state of logits processor diff --git a/fastdeploy/worker/output.py b/fastdeploy/worker/output.py index af16edbfa70..37fad10ee1f 100644 --- a/fastdeploy/worker/output.py +++ b/fastdeploy/worker/output.py @@ -119,6 +119,17 @@ def empty(num_positions: int, num_tokens_per_position: int) -> "LogprobsTensors" selected_token_ranks=selected_token_ranks, ) + def slice_rows(self, start: int, end: int): + """ + Slice rows. + Keeps the number of max_num_logprobs unchanged. + """ + return LogprobsTensors( + self.logprob_token_ids[start:end], + self.logprobs[start:end], + self.selected_token_ranks[start:end], + ) + @dataclass class SamplerOutput: From fc70fd5cd0619e7fb5655c4d5d7d16f0ed64404b Mon Sep 17 00:00:00 2001 From: ckl117 Date: Mon, 3 Nov 2025 16:45:34 +0800 Subject: [PATCH 4/8] code check --- fastdeploy/model_executor/pre_and_post_process.py | 3 +-- fastdeploy/output/stream_transfer_data.py | 6 +++--- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/fastdeploy/model_executor/pre_and_post_process.py b/fastdeploy/model_executor/pre_and_post_process.py index fa4f6abce18..61e59625ea3 100644 --- a/fastdeploy/model_executor/pre_and_post_process.py +++ b/fastdeploy/model_executor/pre_and_post_process.py @@ -398,7 +398,6 @@ def post_process_normal( sampler_output.sampled_token_ids, model_output.is_block_step, ) - prompt_logprobs_list = model_output.prompt_logprobs_list # 3. 
Transmit the model's output and stop generation signal via message queue. # In the future, we will abandon this approach. if not skip_save_output: @@ -407,7 +406,7 @@ def post_process_normal( output = _build_stream_transfer_data( sampler_output.sampled_token_ids, logprobs=sampler_output.logprobs_tensors, - prompt_logprobs=prompt_logprobs_list, + prompt_logprobs=model_output.prompt_logprobs_list, ) async_output_queue.put(output) else: diff --git a/fastdeploy/output/stream_transfer_data.py b/fastdeploy/output/stream_transfer_data.py index 290ec793412..b32e01c954f 100644 --- a/fastdeploy/output/stream_transfer_data.py +++ b/fastdeploy/output/stream_transfer_data.py @@ -20,7 +20,7 @@ import numpy as np -from fastdeploy.worker.output import LogprobsLists +from fastdeploy.worker.output import LogprobsTensors class DecoderState(Enum): @@ -40,8 +40,8 @@ class StreamTransferData: batch_id: int tokens: Optional[np.array] = None speculaive_decoding: bool = False - logprobs: Optional[LogprobsLists] = None - prompt_logprobs: Optional[LogprobsLists] = None + logprobs: Optional[LogprobsTensors] = None + prompt_logprobs: Optional[LogprobsTensors] = None accept_tokens: Optional[np.array] = None accept_num: Optional[np.array] = None # [num_reqs, hidden_size] From 58d6466408f9874d044eb95fb2396118569a2f66 Mon Sep 17 00:00:00 2001 From: ckl117 Date: Mon, 3 Nov 2025 20:01:29 +0800 Subject: [PATCH 5/8] check --- fastdeploy/model_executor/pre_and_post_process.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fastdeploy/model_executor/pre_and_post_process.py b/fastdeploy/model_executor/pre_and_post_process.py index 61e59625ea3..8c9a01aad2f 100644 --- a/fastdeploy/model_executor/pre_and_post_process.py +++ b/fastdeploy/model_executor/pre_and_post_process.py @@ -406,7 +406,7 @@ def post_process_normal( output = _build_stream_transfer_data( sampler_output.sampled_token_ids, logprobs=sampler_output.logprobs_tensors, - prompt_logprobs=model_output.prompt_logprobs_list, + prompt_logprobs_list=model_output.prompt_logprobs_list, ) async_output_queue.put(output) else: From ba284852e8a9e4d71a01462dfbc4f3256e813fda Mon Sep 17 00:00:00 2001 From: ckl117 Date: Tue, 4 Nov 2025 11:23:15 +0800 Subject: [PATCH 6/8] code check --- fastdeploy/worker/gpu_model_runner.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index 40354453660..4149632a826 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -587,7 +587,7 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int = self.share_inputs["seq_lens_decoder"][idx : idx + 1] = 0 self.share_inputs["seq_lens_encoder"][idx : idx + 1] = 0 self.share_inputs["is_block_step"][idx : idx + 1] = False - self.num_prompt_logprobs.pop(request.request_id, None) + self.prompt_logprobs_reqs.pop(request.request_id, None) self.in_progress_prompt_logprobs.pop(request.request_id, None) continue @@ -2094,7 +2094,7 @@ class at the server level, which is too granular for ModelRunner. 
if self.use_cudagraph: model_output = model_output[: self.real_token_num] - prompt_logprobs_list = self._get_prompt_logprobs_dict(model_output) + prompt_logprobs_list = self._get_prompt_logprobs_list(model_output) if self.is_pooling_model: hidden_states = model_output @@ -2706,10 +2706,13 @@ def prepare_rope3d( ) return rope_emb_lst - def _get_prompt_logprobs_dict( + def _get_prompt_logprobs_list( self, hidden_states: paddle.Tensor, ) -> list[Optional[LogprobsTensors]]: + assert ( + not self.fd_config.cache_config.enable_prefix_caching + ), "prompt_logprobs must disable prefix caching, --no-enable-prefix-caching." logprobs_mode = self.fd_config.model_config.logprobs_mode prompt_logprobs_list: list[Optional[LogprobsTensors]] = self.scheduler_config.max_num_seqs * [None] completed_prefill_reqs: list[Request] = [] From 4c2f9ecd404ae2239d26d6d39d65c575c22672cc Mon Sep 17 00:00:00 2001 From: ckl117 Date: Tue, 4 Nov 2025 15:21:01 +0800 Subject: [PATCH 7/8] clear_requests clear prompt_logprob and _get_prompt_logprobs_list() check --- fastdeploy/worker/gpu_model_runner.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index 4149632a826..6c2b68ed880 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -2496,6 +2496,9 @@ def clear_parameters(self, pid): def clear_requests(self): """Dynamic model loader use to clear requests use for RL""" self.share_inputs["stop_flags"][:] = True + # prompt_logprobs + self.prompt_logprobs_reqs.clear() + self.in_progress_prompt_logprobs.clear() def update_parameters(self, pid): """Dynamic model loader use to update parameters use for RL""" @@ -2710,9 +2713,10 @@ def _get_prompt_logprobs_list( self, hidden_states: paddle.Tensor, ) -> list[Optional[LogprobsTensors]]: - assert ( - not self.fd_config.cache_config.enable_prefix_caching - ), "prompt_logprobs must disable prefix caching, --no-enable-prefix-caching." + if len(self.prompt_logprobs_reqs) > 0: + assert ( + not self.fd_config.cache_config.enable_prefix_caching + ), "prompt_logprobs must disable prefix caching, --no-enable-prefix-caching." 
logprobs_mode = self.fd_config.model_config.logprobs_mode prompt_logprobs_list: list[Optional[LogprobsTensors]] = self.scheduler_config.max_num_seqs * [None] completed_prefill_reqs: list[Request] = [] @@ -2753,8 +2757,6 @@ def _get_prompt_logprobs_list( prompt_hidden_states = hidden_states[offset : offset + num_logits] logits = self.model.compute_logits(prompt_hidden_states) prompt_token_ids = request.prompt_token_ids[start_tok : start_tok + num_logits] - if isinstance(prompt_token_ids, np.ndarray): - prompt_token_ids = prompt_token_ids.tolist() prompt_token_ids_tensor = paddle.to_tensor(prompt_token_ids, dtype="int64") if logprobs_mode == "raw_logprobs": raw_logprobs = self.sampler.compute_logprobs(logits) From 1a4edad18d2284d107adda144be7ed2f95fb6c83 Mon Sep 17 00:00:00 2001 From: ckl117 Date: Tue, 4 Nov 2025 19:39:01 +0800 Subject: [PATCH 8/8] check pooling model req sampling_params is None --- fastdeploy/worker/gpu_model_runner.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py index 6c2b68ed880..1d031bda428 100644 --- a/fastdeploy/worker/gpu_model_runner.py +++ b/fastdeploy/worker/gpu_model_runner.py @@ -565,7 +565,8 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int = len(request.output_token_ids) if prefill_end_index >= len(input_ids) else 0 ) self.share_inputs["pre_ids"][idx : idx + 1] = -1 - if request.sampling_params.prompt_logprobs is not None: + # pooling model request.sampling_params is None + if request.sampling_params is not None and request.sampling_params.prompt_logprobs is not None: self.prompt_logprobs_reqs[request.request_id] = request has_prefill_task = True elif request.task_type.value == RequestType.DECODE.value: # decode task
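
The chunked-prefill accumulation that `_get_prompt_logprobs_list` implements in this series can be summarized with a small standalone sketch (numpy only, invented shapes and token ids; the real code works on paddle tensors via `LogprobsTensors.empty`, `self.sampler.compute_logprobs`, and `self.sampler.gather_logprobs`): a per-request buffer of shape [num_prompt_tokens - 1, num_prompt_logprobs + 1] is created lazily on the first prefill chunk, each chunk fills the rows it covers, and the buffer is attached to the stream output once the last prompt position has been scored.

    import numpy as np

    def gather_prompt_logprobs(chunk_logprobs, next_token_ids, k):
        """chunk_logprobs: [num_logits, vocab] log-probs; next_token_ids: the prompt
        token that actually follows each scored position."""
        topk_ids = np.argsort(-chunk_logprobs, axis=-1)[:, :k]             # [num_logits, k]
        topk_vals = np.take_along_axis(chunk_logprobs, topk_ids, axis=-1)  # [num_logits, k]
        tok = np.asarray(next_token_ids)[:, None]                          # [num_logits, 1]
        tok_vals = np.take_along_axis(chunk_logprobs, tok, axis=-1)        # logprob of the prompt token
        ranks = (chunk_logprobs >= tok_vals).sum(axis=-1)                  # 1-based rank of that token
        # Column 0 holds the prompt token itself, columns 1..k hold the top-k entries.
        return np.concatenate([tok, topk_ids], -1), np.concatenate([tok_vals, topk_vals], -1), ranks

    num_prompt_tokens, k, vocab = 7, 3, 50
    buf_ids = np.empty((num_prompt_tokens - 1, k + 1), dtype=np.int64)     # ~ LogprobsTensors.logprob_token_ids
    buf_vals = np.empty((num_prompt_tokens - 1, k + 1), dtype=np.float32)  # ~ LogprobsTensors.logprobs
    buf_rank = np.empty(num_prompt_tokens - 1, dtype=np.int64)             # ~ LogprobsTensors.selected_token_ranks

    # One prefill chunk covering prompt positions [start_idx, start_idx + num_logits).
    start_idx, num_logits = 0, 4
    logits = np.random.randn(num_logits, vocab).astype(np.float32)
    logprobs = logits - np.log(np.exp(logits).sum(-1, keepdims=True))      # log_softmax
    next_tokens = [5, 9, 2, 7]                                             # prompt_token_ids[start_idx + 1:][:num_logits]
    ids, vals, ranks = gather_prompt_logprobs(logprobs, next_tokens, k)
    buf_ids[start_idx:start_idx + num_logits] = ids
    buf_vals[start_idx:start_idx + num_logits] = vals
    buf_rank[start_idx:start_idx + num_logits] = ranks
    # When start_idx + num_logits reaches num_prompt_tokens - 1, the request is complete
    # and its buffers are attached to the StreamTransferData for that batch id.

This sketch also suggests why the patch asserts that prefix caching is disabled for prompt_logprobs requests: presumably cached prompt positions are never re-scored by the forward pass, which would leave rows of this buffer unfilled.
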