diff --git a/docs/parameters.md b/docs/parameters.md
index 7a41dac37d6..16ac313498a 100644
--- a/docs/parameters.md
+++ b/docs/parameters.md
@@ -49,6 +49,7 @@ When using FastDeploy to deploy models (including offline inference and service
 | ```enable_expert_parallel``` | `bool` | Whether to enable expert parallel |
 | ```enable_logprob``` | `bool` | Whether to enable return log probabilities of the output tokens or not. If true, returns the log probabilities of each output token returned in the content of message.If logrpob is not used, this parameter can be omitted when starting |
 | ```logprobs_mode``` | `str` | Indicates the content returned in the logprobs. Supported mode: `raw_logprobs`, `processed_logprobs`, `raw_logits`, `processed_logits`. Raw means the values before applying logit processors, like bad words. Processed means the values after applying such processors. |
+| ```max_logprobs``` | `int` | Maximum number of log probabilities to return per token, default: 20. -1 means vocab_size. |
 | ```served_model_name```| `str`| The model name used in the API. If not specified, the model name will be the same as the --model argument |
 | ```revision``` | `str` | The specific model version to use. It can be a branch name, a tag name, or a commit id. If unspecified, will use the default version. |
 | ```chat_template``` | `str` | Specify the template used for model concatenation, It supports both string input and file path input. The default value is None. If not specified, the model's default template will be used. |
diff --git a/docs/zh/parameters.md b/docs/zh/parameters.md
index 859685ff6b1..5eec25ce28c 100644
--- a/docs/zh/parameters.md
+++ b/docs/zh/parameters.md
@@ -47,6 +47,7 @@
 | ```enable_expert_parallel``` | `bool` | 是否启用专家并行 |
 | ```enable_logprob``` | `bool` | 是否启用输出token返回logprob。如果未使用 logrpob,则在启动时可以省略此参数。 |
 | ```logprobs_mode``` | `str` | 指定logprobs中返回的内容。支持的模式:`raw_logprobs`、`processed_logprobs'、`raw_logits`,`processed_logits'。processed表示logits应用温度、惩罚、禁止词处理后计算的logprobs。|
+| ```max_logprobs``` | `int` | 服务支持返回的最大logprob数量,默认20。-1表示词表大小。 |
 | ```served_model_name``` | `str` | API 中使用的模型名称,如果未指定,模型名称将与--model参数相同 |
 | ```revision``` | `str` | 自动下载模型时,用于指定模型的Git版本,分支名或tag |
 | ```chat_template``` | `str` | 指定模型拼接使用的模板,支持字符串与文件路径,默认为None,如未指定,则使用模型默认模板 |
diff --git a/fastdeploy/config.py b/fastdeploy/config.py
index 28753292210..c9228d9f619 100644
--- a/fastdeploy/config.py
+++ b/fastdeploy/config.py
@@ -183,6 +183,7 @@ def __init__(
         self.max_model_len = 0
         self.dtype = "bfloat16"
         self.enable_logprob = False
+        self.max_logprobs = 20
         self.logprobs_mode = "raw_logprobs"
         self.enable_redundant_experts = False
         self.redundant_experts_num = 0
@@ -227,6 +228,8 @@ def __init__(
         self.think_end_id = args.get("think_end_id", -1)
         self.im_patch_id = args.get("image_patch_id", -1)
         self.line_break_id = args.get("line_break_id", -1)
+        if self.max_logprobs == -1 and hasattr(self, "vocab_size"):
+            self.max_logprobs = self.vocab_size
 
         self._post_init()
 
diff --git a/fastdeploy/engine/args_utils.py b/fastdeploy/engine/args_utils.py
index 2ce7482d05b..cd80e1d614f 100644
--- a/fastdeploy/engine/args_utils.py
+++ b/fastdeploy/engine/args_utils.py
@@ -44,6 +44,7 @@
 from fastdeploy.utils import (
     DeprecatedOptionWarning,
     FlexibleArgumentParser,
+    console_logger,
     is_port_available,
     parse_quantization,
 )
@@ -387,6 +388,12 @@ class EngineArgs:
     Must be explicitly enabled via the `--enable-logprob` startup parameter to output logprob values.
     """
 
+    max_logprobs: int = 20
+    """
+    Maximum number of log probabilities to return per token when `enable_logprob` is True. The default value matches the default of the
+    OpenAI Chat Completions API. -1 means no cap, i.e. up to (output_length * vocab_size) logprobs may be returned, which can cause OOM.
+    """
+
     logprobs_mode: str = "raw_logprobs"
     """
     Indicates the content returned in the logprobs.
@@ -453,6 +460,13 @@ def __post_init__(self):
                 raise NotImplementedError("Only CUDA platform supports logprob.")
             if self.speculative_config is not None and self.logprobs_mode.startswith("processed"):
                 raise NotImplementedError("processed_logprobs not support in speculative.")
+            if self.speculative_config is not None and self.max_logprobs == -1:
+                raise NotImplementedError("max_logprobs=-1 is not supported with speculative decoding.")
+            if not envs.FD_USE_GET_SAVE_OUTPUT_V1:
+                self.max_logprobs = 20
+                console_logger.warning("Setting max_logprobs=20 because FD_USE_GET_SAVE_OUTPUT_V1=0")
+            if self.max_logprobs == -1 and not envs.ENABLE_V1_KVCACHE_SCHEDULER:
+                raise NotImplementedError("max_logprobs=-1 requires ENABLE_V1_KVCACHE_SCHEDULER=1")
 
         if self.splitwise_role != "mixed" and self.cache_transfer_protocol != "rdma":
             envs.ENABLE_V1_KVCACHE_SCHEDULER = 0
@@ -667,6 +681,12 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
             default=EngineArgs.enable_logprob,
             help="Enable output of token-level log probabilities.",
         )
+        model_group.add_argument(
+            "--max-logprobs",
+            type=int,
+            default=EngineArgs.max_logprobs,
+            help="Maximum number of log probabilities to return per token; -1 means vocab_size.",
+        )
         model_group.add_argument(
             "--logprobs-mode",
             type=str,
diff --git a/fastdeploy/engine/async_llm.py b/fastdeploy/engine/async_llm.py
index b600290191d..249fe2d94b8 100644
--- a/fastdeploy/engine/async_llm.py
+++ b/fastdeploy/engine/async_llm.py
@@ -830,6 +830,7 @@ def _start_worker_service(self):
             f" --convert {self.cfg.model_config.convert}"
             f" --override-pooler-config {self.cfg.model_config.override_pooler_config}"
             f" --logprobs_mode {self.cfg.model_config.logprobs_mode}"
+            f" --max_logprobs {self.cfg.model_config.max_logprobs}"
         )
 
         worker_store_true_flag = {
diff --git a/fastdeploy/engine/engine.py b/fastdeploy/engine/engine.py
index e4c0b717ad8..8e500df4a96 100644
--- a/fastdeploy/engine/engine.py
+++ b/fastdeploy/engine/engine.py
@@ -551,6 +551,7 @@ def _start_worker_service(self):
             f" --convert {self.cfg.model_config.convert}"
             f" --override-pooler-config {self.cfg.model_config.override_pooler_config}"
             f" --logprobs_mode {self.cfg.model_config.logprobs_mode}"
+            f" --max_logprobs {self.cfg.model_config.max_logprobs}"
         )
         if self.cfg.structured_outputs_config.logits_processors is not None:
             arguments += f" --logits-processors {' '.join(self.cfg.structured_outputs_config.logits_processors)}"
diff --git a/fastdeploy/engine/sampling_params.py b/fastdeploy/engine/sampling_params.py
index ef26cd380e0..23b77c7be95 100644
--- a/fastdeploy/engine/sampling_params.py
+++ b/fastdeploy/engine/sampling_params.py
@@ -97,6 +97,7 @@ class SamplingParams:
     reasoning_max_tokens: Optional[int] = None
     min_tokens: int = 1
     logprobs: Optional[int] = None
+    prompt_logprobs: Optional[int] = None
     # For logits and logprobs post processing
     temp_scaled_logprobs: bool = False
     top_p_normalized_logprobs: bool = False
diff --git a/fastdeploy/model_executor/layers/sample/sampler.py b/fastdeploy/model_executor/layers/sample/sampler.py
index b51efd67d69..622c1994670 100644
--- a/fastdeploy/model_executor/layers/sample/sampler.py
+++ b/fastdeploy/model_executor/layers/sample/sampler.py
@@ -255,9 +255,11 @@ def post_process(self, next_tokens: paddle.Tensor, skip_idx_list: List[int] = []
     def compute_logprobs(
         self,
         logits: paddle.Tensor,
-        sampling_metadata: SamplingMetadata,
+        sampling_metadata: Optional[SamplingMetadata] = None,
     ) -> paddle.Tensor:
         """ """
+        if sampling_metadata is None:
+            return F.log_softmax(logits, axis=-1)
         last_logits = logits
         real_bsz = last_logits.shape[0]
         temp_scaled_logprobs = sampling_metadata.temp_scaled_logprobs
@@ -317,6 +319,8 @@ def gather_logprobs(
         assert token_ids.dtype == paddle.int64
         logprobs.clip_(min=paddle.finfo(logprobs.dtype).min)
         # Get with the logprob of the prompt or sampled token.
+        if len(token_ids.shape) < len(logprobs.shape):
+            token_ids = token_ids.unsqueeze(-1)
         token_logprobs = paddle.take_along_axis(logprobs, token_ids, axis=-1)
 
         # Compute the ranks of the actual token.
diff --git a/fastdeploy/model_executor/pre_and_post_process.py b/fastdeploy/model_executor/pre_and_post_process.py
index 3467e7eafe1..a5354da8d32 100644
--- a/fastdeploy/model_executor/pre_and_post_process.py
+++ b/fastdeploy/model_executor/pre_and_post_process.py
@@ -91,7 +91,12 @@
 
 from fastdeploy.output.pooler import PoolerOutput, PoolingSequenceGroupOutput
 from fastdeploy.output.stream_transfer_data import DecoderState, StreamTransferData
-from fastdeploy.worker.output import ModelOutputData, ModelRunnerOutput, SamplerOutput
+from fastdeploy.worker.output import (
+    LogprobsTensors,
+    ModelOutputData,
+    ModelRunnerOutput,
+    SamplerOutput,
+)
 
 DISABLE_RECOVER = envs.FD_DISABLED_RECOVER == "1"
 
@@ -253,7 +258,12 @@ def pre_process(
     )
 
 
-def _build_stream_transfer_data(output_tokens: np.ndarray, pooler_outputs: List[PoolingSequenceGroupOutput] = None):
+def _build_stream_transfer_data(
+    output_tokens: paddle.Tensor,
+    pooler_outputs: List[PoolingSequenceGroupOutput] = None,
+    logprobs: Optional[LogprobsTensors] = None,
+    prompt_logprobs_list: Optional[List[LogprobsTensors]] = None,
+):
     """Split output_tokens and output"""
     stream_transfer_datas = []
 
@@ -266,6 +276,11 @@ def _build_stream_transfer_data(output_tokens: np.ndarray, pooler_outputs: List[
             stream_transfer_data = StreamTransferData(
                 decoder_state=DecoderState.TEXT, tokens=output_token_per_sample, batch_id=bid
             )
+            if logprobs:
+                sample_logprobs = logprobs.slice_rows(bid, bid + 1)
+                stream_transfer_data.logprobs = sample_logprobs
+            if prompt_logprobs_list:
+                stream_transfer_data.prompt_logprobs = prompt_logprobs_list[bid]
             stream_transfer_datas.append(stream_transfer_data)
     elif pooler_outputs is not None:
         for bid, pooler_output in enumerate(pooler_outputs):
@@ -390,27 +405,31 @@ def post_process_normal(
     # 3. Transmit the model's output and stop generation signal via message queue.
     # In the future, we will abandon this approach.
     if not skip_save_output:
-        if sampler_output.logprobs_tensors is None:
-            if envs.FD_USE_GET_SAVE_OUTPUT_V1:
-                if save_each_rank or model_output.mp_rank == 0:
-                    output = _build_stream_transfer_data(sampler_output.sampled_token_ids)
-                    async_output_queue.put(output)
-            else:
+        if envs.FD_USE_GET_SAVE_OUTPUT_V1:
+            if save_each_rank or model_output.mp_rank == 0:
+                output = _build_stream_transfer_data(
+                    sampler_output.sampled_token_ids,
+                    logprobs=sampler_output.logprobs_tensors,
+                    prompt_logprobs_list=model_output.prompt_logprobs_list,
+                )
+                async_output_queue.put(output)
+        else:
+            if sampler_output.logprobs_tensors is None:
                 save_output(
                     sampler_output.sampled_token_ids,
                     model_output.not_need_stop,
                     model_output.mp_rank,
                     save_each_rank,
                 )
-        else:
-            save_output_topk(
-                sampler_output.sampled_token_ids,
-                sampler_output.logprobs_tensors.logprob_token_ids,
-                sampler_output.logprobs_tensors.logprobs,
-                sampler_output.logprobs_tensors.selected_token_ranks,
-                model_output.not_need_stop,
-                model_output.mp_rank,
-            )
+            else:
+                save_output_topk(
+                    sampler_output.sampled_token_ids,
+                    sampler_output.logprobs_tensors.logprob_token_ids,
+                    sampler_output.logprobs_tensors.logprobs,
+                    sampler_output.logprobs_tensors.selected_token_ranks,
+                    model_output.not_need_stop,
+                    model_output.mp_rank,
+                )
 
 
 def post_process_specualate(
diff --git a/fastdeploy/output/stream_transfer_data.py b/fastdeploy/output/stream_transfer_data.py
index f2d71f9fc2d..b32e01c954f 100644
--- a/fastdeploy/output/stream_transfer_data.py
+++ b/fastdeploy/output/stream_transfer_data.py
@@ -20,6 +20,8 @@
 
 import numpy as np
 
+from fastdeploy.worker.output import LogprobsTensors
+
 
 class DecoderState(Enum):
     """DecoderState"""
@@ -38,7 +40,8 @@ class StreamTransferData:
     batch_id: int
     tokens: Optional[np.array] = None
     speculaive_decoding: bool = False
-    logprobs: Optional[np.array] = None
+    logprobs: Optional[LogprobsTensors] = None
+    prompt_logprobs: Optional[LogprobsTensors] = None
     accept_tokens: Optional[np.array] = None
     accept_num: Optional[np.array] = None
     # [num_reqs, hidden_size]
diff --git a/fastdeploy/worker/gpu_model_runner.py b/fastdeploy/worker/gpu_model_runner.py
index 2a6aa553eba..1d031bda428 100644
--- a/fastdeploy/worker/gpu_model_runner.py
+++ b/fastdeploy/worker/gpu_model_runner.py
@@ -92,7 +92,7 @@
 from fastdeploy.model_executor.models.interfaces_base import FdModelForPooling
 from fastdeploy.output.pooler import PoolerOutput
 from fastdeploy.worker.model_runner_base import ModelRunnerBase
-from fastdeploy.worker.output import ModelOutputData, ModelRunnerOutput
+from fastdeploy.worker.output import LogprobsTensors, ModelOutputData, ModelRunnerOutput
 
 
 class GPUModelRunner(ModelRunnerBase):
@@ -112,8 +112,12 @@ def __init__(
         self.speculative_method = self.fd_config.speculative_config.method
         self.speculative_decoding = self.speculative_method is not None
         self.enable_logprob = fd_config.model_config.enable_logprob
+        self.max_logprobs = fd_config.model_config.max_logprobs
         self.enable_early_stop = self.fd_config.early_stop_config.enable_early_stop
         self.is_pooling_model = self.fd_config.model_config.runner_type == "pooling"
+        self.vocab_size = self.fd_config.model_config.vocab_size
+        self.prompt_logprobs_reqs: dict[str, Request] = {}
+        self.in_progress_prompt_logprobs: dict[str, LogprobsTensors] = {}
 
         # VL model config:
         if self.enable_mm:
@@ -561,6 +565,9 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int =
                     len(request.output_token_ids) if prefill_end_index >= len(input_ids) else 0
                 )
                 self.share_inputs["pre_ids"][idx : idx + 1] = -1
+                # For pooling model requests, sampling_params is None
+                if request.sampling_params is not None and request.sampling_params.prompt_logprobs is not None:
+                    self.prompt_logprobs_reqs[request.request_id] = request
                 has_prefill_task = True
             elif request.task_type.value == RequestType.DECODE.value:  # decode task
                 logger.debug(f"Handle decode request {request} at idx {idx}")
@@ -581,6 +588,8 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int =
                 self.share_inputs["seq_lens_decoder"][idx : idx + 1] = 0
                 self.share_inputs["seq_lens_encoder"][idx : idx + 1] = 0
                 self.share_inputs["is_block_step"][idx : idx + 1] = False
+                self.prompt_logprobs_reqs.pop(request.request_id, None)
+                self.in_progress_prompt_logprobs.pop(request.request_id, None)
                 continue
 
             assert len(request.eos_token_ids) == self.model_config.eos_tokens_lens
@@ -1282,7 +1291,7 @@ def _prepare_inputs(self) -> None:
             min_dec_lens=self.share_inputs["min_dec_len"],
             bad_words_token_ids=self.share_inputs["bad_tokens"][:, :max_bad_tokens_len],
             eos_token_ids=self.share_inputs["eos_token_id"],
-            max_num_logprobs=20 if self.enable_logprob else None,
+            max_num_logprobs=self.max_logprobs if self.enable_logprob else None,
             enable_early_stop=self.enable_early_stop,
             stop_flags=self.share_inputs["stop_flags"],
             temp_scaled_logprobs=self.share_inputs["temp_scaled_logprobs"],
@@ -2086,6 +2095,8 @@ class at the server level, which is too granular for ModelRunner.
         if self.use_cudagraph:
             model_output = model_output[: self.real_token_num]
 
+        prompt_logprobs_list = self._get_prompt_logprobs_list(model_output)
+
         if self.is_pooling_model:
             hidden_states = model_output
             pooler_output = self._pool(hidden_states, num_running_requests)
@@ -2228,6 +2239,7 @@ class at the server level, which is too granular for ModelRunner.
             stop_seqs_len=self.share_inputs["stop_seqs_len"],
             prompt_lens=self.share_inputs["prompt_lens"],
             mask_rollback=self.share_inputs["mask_rollback"],
+            prompt_logprobs_list=prompt_logprobs_list,
         )
 
         if self.speculative_config.method in ["mtp"] and self.scheduler_config.splitwise_role == "prefill":
@@ -2485,6 +2497,9 @@ def clear_parameters(self, pid):
     def clear_requests(self):
        """Dynamic model loader use to clear requests use for RL"""
         self.share_inputs["stop_flags"][:] = True
+        # Clear prompt_logprobs bookkeeping
+        self.prompt_logprobs_reqs.clear()
+        self.in_progress_prompt_logprobs.clear()
 
     def update_parameters(self, pid):
         """Dynamic model loader use to update parameters use for RL"""
@@ -2694,3 +2709,69 @@ def prepare_rope3d(
             cumsum_seqlens=cumsum_seqlens,
         )
         return rope_emb_lst
+
+    def _get_prompt_logprobs_list(
+        self,
+        hidden_states: paddle.Tensor,
+    ) -> list[Optional[LogprobsTensors]]:
+        if len(self.prompt_logprobs_reqs) > 0:
+            assert (
+                not self.fd_config.cache_config.enable_prefix_caching
+            ), "prompt_logprobs requires prefix caching to be disabled (--no-enable-prefix-caching)."
+        logprobs_mode = self.fd_config.model_config.logprobs_mode
+        prompt_logprobs_list: list[Optional[LogprobsTensors]] = self.scheduler_config.max_num_seqs * [None]
+        completed_prefill_reqs: list[Request] = []
+        for req_id, request in self.prompt_logprobs_reqs.items():
+            num_prompt_logprobs = request.sampling_params.prompt_logprobs
+            if request.prompt_token_ids is None or num_prompt_logprobs is None:
+                continue
+            if num_prompt_logprobs == -1:
+                num_prompt_logprobs = self.vocab_size
+
+            num_tokens = request.prefill_end_index - request.prefill_start_index
+            num_prompt_tokens = len(request.prompt_token_ids)
+
+            logprobs_tensors = self.in_progress_prompt_logprobs.get(req_id)
+            if not logprobs_tensors:
+                logprobs_tensors = LogprobsTensors.empty(num_prompt_tokens - 1, num_prompt_logprobs + 1)
+                self.in_progress_prompt_logprobs[req_id] = logprobs_tensors
+            start_idx = request.prefill_start_index
+            start_tok = start_idx + 1
+            num_remaining_tokens = num_prompt_tokens - start_tok
+            if num_tokens <= num_remaining_tokens:
+                # This is a chunk, more tokens remain.
+                # In the == case, there are no more prompt logprobs to produce
+                # but we want to defer returning them to the next step where we
+                # have new generated tokens to return.
+                num_logits = num_tokens
+            else:
+                # This is the last chunk of prompt tokens to return.
+                num_logits = num_remaining_tokens
+                completed_prefill_reqs.append(request)
+                prompt_logprobs_list[request.idx] = logprobs_tensors
+            if num_logits <= 0:
+                # This can happen for the final chunk if we prefilled exactly
+                # (num_prompt_tokens - 1) tokens for this request in the prior
+                # step. There are no more prompt logprobs to produce.
+                continue
+            offset = self.share_inputs["cu_seqlens_q"][request.idx]
+            prompt_hidden_states = hidden_states[offset : offset + num_logits]
+            logits = self.model.compute_logits(prompt_hidden_states)
+            prompt_token_ids = request.prompt_token_ids[start_tok : start_tok + num_logits]
+            prompt_token_ids_tensor = paddle.to_tensor(prompt_token_ids, dtype="int64")
+            if logprobs_mode == "raw_logprobs":
+                raw_logprobs = self.sampler.compute_logprobs(logits)
+            elif logprobs_mode == "raw_logits":
+                raw_logprobs = logits
+            token_ids, logprobs, ranks = self.sampler.gather_logprobs(
+                raw_logprobs, num_prompt_logprobs, prompt_token_ids_tensor
+            )
+            chunk_slice = slice(start_idx, start_idx + num_logits)
+            logprobs_tensors.logprob_token_ids[chunk_slice].copy_(token_ids, False)
+            logprobs_tensors.logprobs[chunk_slice].copy_(logprobs, False)
+            logprobs_tensors.selected_token_ranks[chunk_slice].copy_(ranks, False)
+
+        for req in completed_prefill_reqs:
+            del self.prompt_logprobs_reqs[req.request_id]
+            del self.in_progress_prompt_logprobs[req.request_id]
+        return prompt_logprobs_list
diff --git a/fastdeploy/worker/output.py b/fastdeploy/worker/output.py
index 7d3a006cacc..9121f85261f 100644
--- a/fastdeploy/worker/output.py
+++ b/fastdeploy/worker/output.py
@@ -63,6 +63,17 @@ def slice_columns(self, start: int, end: int):
             self.sampled_token_ranks,  # unchanged
         )
 
+    def slice_rows(self, start: int, end: int):
+        """
+        Slice rows.
+        The number of logprobs per position (max_num_logprobs) is unchanged.
+        """
+        return LogprobsLists(
+            self.logprob_token_ids[start:end],
+            self.logprobs[start:end],
+            self.sampled_token_ranks[start:end],
+        )
+
 
 class LogprobsTensors(NamedTuple):
     """ """
@@ -87,7 +98,7 @@ def empty_cpu(num_positions: int, num_tokens_per_position: int) -> "LogprobsTens
         """Create empty LogprobsTensors on CPU."""
 
         logprob_token_ids = paddle.empty([num_positions, num_tokens_per_position], dtype=paddle.int64).cpu()
-        logprobs = paddle.empty_like(logprob_token_ids, dtype=paddle.float32)
+        logprobs = paddle.empty_like(logprob_token_ids, dtype=paddle.float32).cpu()
         selected_token_ranks = paddle.empty([num_positions], dtype=paddle.int64).cpu()
         return LogprobsTensors(
             logprob_token_ids=logprob_token_ids,
@@ -95,6 +106,30 @@ def empty_cpu(num_positions: int, num_tokens_per_position: int) -> "LogprobsTens
             selected_token_ranks=selected_token_ranks,
         )
 
+    @staticmethod
+    def empty(num_positions: int, num_tokens_per_position: int) -> "LogprobsTensors":
+        """Create empty LogprobsTensors on the default device."""
+
+        logprob_token_ids = paddle.empty([num_positions, num_tokens_per_position], dtype=paddle.int64)
+        logprobs = paddle.empty_like(logprob_token_ids, dtype=paddle.float32)
+        selected_token_ranks = paddle.empty([num_positions], dtype=paddle.int64)
+        return LogprobsTensors(
+            logprob_token_ids=logprob_token_ids,
+            logprobs=logprobs,
+            selected_token_ranks=selected_token_ranks,
+        )
+
+    def slice_rows(self, start: int, end: int):
+        """
+        Slice rows.
+        The number of logprobs per position (max_num_logprobs) is unchanged.
+        """
+        return LogprobsTensors(
+            self.logprob_token_ids[start:end],
+            self.logprobs[start:end],
+            self.selected_token_ranks[start:end],
+        )
+
 
 @dataclass
 class SamplerOutput:
@@ -242,6 +277,11 @@ class ModelOutputData:
     """
     mask_rollback: paddle.Tensor = None
 
+    """
+    Per-request prompt logprobs, indexed by batch slot; None for requests that did not request prompt_logprobs.
+    """
+    prompt_logprobs_list: Optional[list[LogprobsTensors]] = None
+
 
 @dataclass
 class ModelRunnerOutput:
diff --git a/fastdeploy/worker/worker_process.py b/fastdeploy/worker/worker_process.py
index 8cedfea3762..1466ba2eae0 100644
--- a/fastdeploy/worker/worker_process.py
+++ b/fastdeploy/worker/worker_process.py
@@ -740,6 +740,12 @@ def parse_args():
         action="store_true",
         help="Enable output of token-level log probabilities.",
     )
+    parser.add_argument(
+        "--max_logprobs",
+        type=int,
+        default=20,
+        help="Maximum number of log probabilities to return per token; -1 means vocab_size.",
+    )
     parser.add_argument(
         "--logprobs_mode",
        type=str,