1 change: 1 addition & 0 deletions docs/parameters.md
@@ -49,6 +49,7 @@ When using FastDeploy to deploy models (including offline inference and service
| ```enable_expert_parallel``` | `bool` | Whether to enable expert parallelism |
| ```enable_logprob``` | `bool` | Whether to return log probabilities of the output tokens. If true, the log probabilities of each output token are included in the message content. If logprob is not used, this parameter can be omitted at startup. |
| ```logprobs_mode``` | `str` | Indicates the content returned in logprobs. Supported modes: `raw_logprobs`, `processed_logprobs`, `raw_logits`, `processed_logits`. Raw means the values before applying logit processors, such as bad words; processed means the values after applying such processors. |
| ```max_logprobs``` | `int` | Maximum number of log probabilities to return. Default: 20. -1 means vocab_size. |
| ```served_model_name``` | `str` | The model name used in the API. If not specified, the model name will be the same as the --model argument |
| ```revision``` | `str` | The specific model version to use. It can be a branch name, a tag name, or a commit id. If unspecified, will use the default version. |
| ```chat_template``` | `str` | Specify the template used for model concatenation, It supports both string input and file path input. The default value is None. If not specified, the model's default template will be used. |
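For illustration only: once the server is launched with `--enable-logprob` (and optionally `--max-logprobs`), per-token log probabilities can be requested through the OpenAI-compatible endpoint. The host, port, and model name below are assumptions, not values taken from this PR.

```python
# Hypothetical client call; substitute the base_url and model of your deployment.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8180/v1", api_key="EMPTY")

resp = client.chat.completions.create(
    model="my-served-model",   # whatever --served_model_name / --model resolves to
    messages=[{"role": "user", "content": "Hello!"}],
    logprobs=True,             # requires the server to be started with --enable-logprob
    top_logprobs=5,            # must not exceed the server-side --max-logprobs (default 20)
)

# Each returned token carries its own logprob plus the top-5 alternatives.
for tok in resp.choices[0].logprobs.content:
    print(tok.token, tok.logprob)
```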
1 change: 1 addition & 0 deletions docs/zh/parameters.md
@@ -47,6 +47,7 @@
| ```enable_expert_parallel``` | `bool` | Whether to enable expert parallelism |
| ```enable_logprob``` | `bool` | Whether to return the logprob of each output token. If logprob is not used, this parameter can be omitted at startup. |
| ```logprobs_mode``` | `str` | Specifies the content returned in logprobs. Supported modes: `raw_logprobs`, `processed_logprobs`, `raw_logits`, `processed_logits`. Processed means the logprobs are computed after temperature, penalty, and bad-word processing have been applied to the logits. |
| ```max_logprobs``` | `int` | Maximum number of logprobs the service can return. Default: 20. -1 means the vocabulary size. |
| ```served_model_name``` | `str` | The model name used in the API. If not specified, the model name is the same as the --model argument |
| ```revision``` | `str` | The Git revision (branch name or tag) to use when the model is downloaded automatically |
| ```chat_template``` | `str` | Specifies the template used for prompt concatenation; accepts a string or a file path. Defaults to None; if not specified, the model's default template is used |
3 changes: 3 additions & 0 deletions fastdeploy/config.py
@@ -183,6 +183,7 @@ def __init__(
self.max_model_len = 0
self.dtype = "bfloat16"
self.enable_logprob = False
self.max_logprobs = 20
self.logprobs_mode = "raw_logprobs"
self.enable_redundant_experts = False
self.redundant_experts_num = 0
@@ -227,6 +228,8 @@ def __init__(
self.think_end_id = args.get("think_end_id", -1)
self.im_patch_id = args.get("image_patch_id", -1)
self.line_break_id = args.get("line_break_id", -1)
if self.max_logprobs == -1 and hasattr(self, "vocab_size"):
self.max_logprobs = self.vocab_size

self._post_init()

20 changes: 20 additions & 0 deletions fastdeploy/engine/args_utils.py
@@ -44,6 +44,7 @@
from fastdeploy.utils import (
DeprecatedOptionWarning,
FlexibleArgumentParser,
console_logger,
is_port_available,
parse_quantization,
)
@@ -387,6 +388,12 @@ class EngineArgs:
Must be explicitly enabled via the `--enable-logprob` startup parameter to output logprob values.
"""

max_logprobs: int = 20
"""
Maximum number of log probabilities to return when `enable_logprob` is True. The default value comes from the default of the
OpenAI Chat Completions API. -1 means no cap, i.e. all (output_length * vocab_size) logprobs may be returned, which can cause OOM.
"""

logprobs_mode: str = "raw_logprobs"
"""
Indicates the content returned in the logprobs.
@@ -453,6 +460,13 @@ def __post_init__(self):
raise NotImplementedError("Only CUDA platform supports logprob.")
if self.speculative_config is not None and self.logprobs_mode.startswith("processed"):
raise NotImplementedError("processed_logprobs not support in speculative.")
if self.speculative_config is not None and self.max_logprobs == -1:
raise NotImplementedError("max_logprobs=-1 not support in speculative.")
if not envs.FD_USE_GET_SAVE_OUTPUT_V1:
self.max_logprobs = 20
console_logger.warning("Forcing max_logprobs=20 because FD_USE_GET_SAVE_OUTPUT_V1=0.")
if self.max_logprobs == -1 and not envs.ENABLE_V1_KVCACHE_SCHEDULER:
raise NotImplementedError("Only ENABLE_V1_KVCACHE_SCHEDULER=1 support max_logprobs=-1")

if self.splitwise_role != "mixed" and self.cache_transfer_protocol != "rdma":
envs.ENABLE_V1_KVCACHE_SCHEDULER = 0
@@ -667,6 +681,12 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
default=EngineArgs.enable_logprob,
help="Enable output of token-level log probabilities.",
)
model_group.add_argument(
"--max-logprobs",
type=int,
default=EngineArgs.max_logprobs,
help="Maximum number of log probabilities.",
)
model_group.add_argument(
"--logprobs-mode",
type=str,
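A rough sense of why `max_logprobs=-1` can exhaust memory, as the docstring above warns: with the full vocabulary returned per generated token, the payload grows as output_length × vocab_size. The vocabulary size, output length, and value widths below are illustrative assumptions only.

```python
# Back-of-envelope estimate of the logprobs payload for max_logprobs=-1.
# One float32 logprob plus one int64 token id per vocabulary entry per step.
def logprobs_bytes(vocab_size: int, output_len: int) -> int:
    return output_len * vocab_size * (4 + 8)

print(f"{logprobs_bytes(vocab_size=100_000, output_len=2_048) / 1e9:.2f} GB per request")
# ~2.46 GB for a single request, before batching -- hence the default cap of 20.
```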
1 change: 1 addition & 0 deletions fastdeploy/engine/async_llm.py
@@ -830,6 +830,7 @@ def _start_worker_service(self):
f" --convert {self.cfg.model_config.convert}"
f" --override-pooler-config {self.cfg.model_config.override_pooler_config}"
f" --logprobs_mode {self.cfg.model_config.logprobs_mode}"
f" --max_logprobs {self.cfg.model_config.max_logprobs}"
)

worker_store_true_flag = {
1 change: 1 addition & 0 deletions fastdeploy/engine/engine.py
@@ -551,6 +551,7 @@ def _start_worker_service(self):
f" --convert {self.cfg.model_config.convert}"
f" --override-pooler-config {self.cfg.model_config.override_pooler_config}"
f" --logprobs_mode {self.cfg.model_config.logprobs_mode}"
f" --max_logprobs {self.cfg.model_config.max_logprobs}"
)
if self.cfg.structured_outputs_config.logits_processors is not None:
arguments += f" --logits-processors {' '.join(self.cfg.structured_outputs_config.logits_processors)}"
1 change: 1 addition & 0 deletions fastdeploy/engine/sampling_params.py
@@ -97,6 +97,7 @@ class SamplingParams:
reasoning_max_tokens: Optional[int] = None
min_tokens: int = 1
logprobs: Optional[int] = None
prompt_logprobs: Optional[int] = None
# For logits and logprobs post processing
temp_scaled_logprobs: bool = False
top_p_normalized_logprobs: bool = False
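An illustrative use of the new `prompt_logprobs` field. The field names come from the dataclass in this diff; the assumption that all other `SamplingParams` arguments have defaults is not verified here.

```python
# Hypothetical request-level configuration; only logprobs/prompt_logprobs are
# taken from this diff, everything else relies on dataclass defaults.
from fastdeploy.engine.sampling_params import SamplingParams

params = SamplingParams(
    logprobs=5,          # top-5 logprobs per generated token (capped by --max-logprobs)
    prompt_logprobs=5,   # top-5 logprobs per prompt token; -1 would mean the full vocabulary
)
```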
6 changes: 5 additions & 1 deletion fastdeploy/model_executor/layers/sample/sampler.py
@@ -255,9 +255,11 @@ def post_process(self, next_tokens: paddle.Tensor, skip_idx_list: List[int] = []
def compute_logprobs(
self,
logits: paddle.Tensor,
sampling_metadata: SamplingMetadata,
sampling_metadata: Optional[SamplingMetadata] = None,
) -> paddle.Tensor:
""" """
if sampling_metadata is None:
return F.log_softmax(logits, axis=-1)
last_logits = logits
real_bsz = last_logits.shape[0]
temp_scaled_logprobs = sampling_metadata.temp_scaled_logprobs
@@ -317,6 +319,8 @@ def gather_logprobs(
assert token_ids.dtype == paddle.int64
logprobs.clip_(min=paddle.finfo(logprobs.dtype).min)
# Get with the logprob of the prompt or sampled token.
if len(token_ids.shape) < len(logprobs.shape):
token_ids = token_ids.unsqueeze(-1)
token_logprobs = paddle.take_along_axis(logprobs, token_ids, axis=-1)

# Compute the ranks of the actual token.
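A minimal paddle sketch of the path exercised above when `sampling_metadata` is `None`: plain log-softmax, then the top-k gather plus the logprob and rank of the sampled token, mirroring what `compute_logprobs` and `gather_logprobs` do. Shapes and values are made up for the demo.

```python
import paddle
import paddle.nn.functional as F

logits = paddle.randn([2, 8])                             # [batch, vocab]
token_ids = paddle.to_tensor([[3], [5]], dtype="int64")   # sampled token per row

logprobs = F.log_softmax(logits, axis=-1)                 # the sampling_metadata=None fallback

# Logprob of the sampled token (the take_along_axis gather above).
token_logprobs = paddle.take_along_axis(logprobs, token_ids, axis=-1)

# Rank: how many vocabulary entries score strictly higher than the sampled token.
ranks = (logprobs > token_logprobs).astype("int64").sum(axis=-1)

# Top-k candidates returned alongside the sampled token.
topk_logprobs, topk_ids = paddle.topk(logprobs, k=5, axis=-1)
print(token_logprobs.shape, ranks.shape, topk_ids.shape)
```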
53 changes: 36 additions & 17 deletions fastdeploy/model_executor/pre_and_post_process.py
@@ -91,7 +91,12 @@

from fastdeploy.output.pooler import PoolerOutput, PoolingSequenceGroupOutput
from fastdeploy.output.stream_transfer_data import DecoderState, StreamTransferData
from fastdeploy.worker.output import ModelOutputData, ModelRunnerOutput, SamplerOutput
from fastdeploy.worker.output import (
LogprobsTensors,
ModelOutputData,
ModelRunnerOutput,
SamplerOutput,
)

DISABLE_RECOVER = envs.FD_DISABLED_RECOVER == "1"

@@ -253,7 +258,12 @@ def pre_process(
)


def _build_stream_transfer_data(output_tokens: np.ndarray, pooler_outputs: List[PoolingSequenceGroupOutput] = None):
def _build_stream_transfer_data(
output_tokens: paddle.Tensor,
pooler_outputs: List[PoolingSequenceGroupOutput] = None,
logprobs: Optional[LogprobsTensors] = None,
prompt_logprobs_list: Optional[LogprobsTensors] = None,
):
"""Split output_tokens and output"""

stream_transfer_datas = []
@@ -266,6 +276,11 @@ def _build_stream_transfer_data(output_tokens: np.ndarray, pooler_outputs: List[
stream_transfer_data = StreamTransferData(
decoder_state=DecoderState.TEXT, tokens=output_token_per_sample, batch_id=bid
)
if logprobs:
    stream_transfer_data.logprobs = logprobs.slice_rows(bid, bid + 1)
if prompt_logprobs_list:
    stream_transfer_data.prompt_logprobs = prompt_logprobs_list[bid]
stream_transfer_datas.append(stream_transfer_data)
elif pooler_outputs is not None:
for bid, pooler_output in enumerate(pooler_outputs):
@@ -390,27 +405,31 @@ def post_process_normal(
# 3. Transmit the model's output and stop generation signal via message queue.
# In the future, we will abandon this approach.
if not skip_save_output:
if sampler_output.logprobs_tensors is None:
if envs.FD_USE_GET_SAVE_OUTPUT_V1:
if save_each_rank or model_output.mp_rank == 0:
output = _build_stream_transfer_data(sampler_output.sampled_token_ids)
async_output_queue.put(output)
else:
if envs.FD_USE_GET_SAVE_OUTPUT_V1:
if save_each_rank or model_output.mp_rank == 0:
output = _build_stream_transfer_data(
sampler_output.sampled_token_ids,
logprobs=sampler_output.logprobs_tensors,
prompt_logprobs_list=model_output.prompt_logprobs_list,
)
async_output_queue.put(output)
else:
if sampler_output.logprobs_tensors is None:
save_output(
sampler_output.sampled_token_ids,
model_output.not_need_stop,
model_output.mp_rank,
save_each_rank,
)
else:
save_output_topk(
sampler_output.sampled_token_ids,
sampler_output.logprobs_tensors.logprob_token_ids,
sampler_output.logprobs_tensors.logprobs,
sampler_output.logprobs_tensors.selected_token_ranks,
model_output.not_need_stop,
model_output.mp_rank,
)
else:
save_output_topk(
sampler_output.sampled_token_ids,
sampler_output.logprobs_tensors.logprob_token_ids,
sampler_output.logprobs_tensors.logprobs,
sampler_output.logprobs_tensors.selected_token_ranks,
model_output.not_need_stop,
model_output.mp_rank,
)


def post_process_specualate(
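A toy sketch of the per-request split performed by `_build_stream_transfer_data`: each request receives a one-row slice of the batched logprobs while the batched object is left untouched, so later iterations still see every row (the loop above now assigns the slice directly instead of rebinding `logprobs`). The `ToyLogprobs` class is a stand-in, not the real `LogprobsTensors`.

```python
from dataclasses import dataclass

@dataclass
class ToyLogprobs:
    rows: list  # one entry per request in the batch

    def slice_rows(self, start: int, end: int) -> "ToyLogprobs":
        return ToyLogprobs(rows=self.rows[start:end])

batched = ToyLogprobs(rows=[{"req": 0}, {"req": 1}, {"req": 2}])

# One single-row view per request; `batched` keeps all three rows throughout.
per_request = [batched.slice_rows(bid, bid + 1) for bid in range(len(batched.rows))]
assert all(len(p.rows) == 1 for p in per_request)
```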
5 changes: 4 additions & 1 deletion fastdeploy/output/stream_transfer_data.py
@@ -20,6 +20,8 @@

import numpy as np

from fastdeploy.worker.output import LogprobsTensors


class DecoderState(Enum):
"""DecoderState"""
@@ -38,7 +40,8 @@ class StreamTransferData:
batch_id: int
tokens: Optional[np.array] = None
speculaive_decoding: bool = False
logprobs: Optional[np.array] = None
logprobs: Optional[LogprobsTensors] = None
prompt_logprobs: Optional[LogprobsTensors] = None
accept_tokens: Optional[np.array] = None
accept_num: Optional[np.array] = None
# [num_reqs, hidden_size]
85 changes: 83 additions & 2 deletions fastdeploy/worker/gpu_model_runner.py
@@ -92,7 +92,7 @@
from fastdeploy.model_executor.models.interfaces_base import FdModelForPooling
from fastdeploy.output.pooler import PoolerOutput
from fastdeploy.worker.model_runner_base import ModelRunnerBase
from fastdeploy.worker.output import ModelOutputData, ModelRunnerOutput
from fastdeploy.worker.output import LogprobsTensors, ModelOutputData, ModelRunnerOutput


class GPUModelRunner(ModelRunnerBase):
@@ -112,8 +112,12 @@ def __init__(
self.speculative_method = self.fd_config.speculative_config.method
self.speculative_decoding = self.speculative_method is not None
self.enable_logprob = fd_config.model_config.enable_logprob
self.max_logprobs = fd_config.model_config.max_logprobs
self.enable_early_stop = self.fd_config.early_stop_config.enable_early_stop
self.is_pooling_model = self.fd_config.model_config.runner_type == "pooling"
self.vocab_size = self.fd_config.model_config.vocab_size
self.prompt_logprobs_reqs: dict[str, Request] = {}
self.in_progress_prompt_logprobs: dict[str, LogprobsTensors] = {}

# VL model config:
if self.enable_mm:
@@ -561,6 +565,9 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int =
len(request.output_token_ids) if prefill_end_index >= len(input_ids) else 0
)
self.share_inputs["pre_ids"][idx : idx + 1] = -1
# For pooling models, request.sampling_params is None
if request.sampling_params is not None and request.sampling_params.prompt_logprobs is not None:
self.prompt_logprobs_reqs[request.request_id] = request
has_prefill_task = True
elif request.task_type.value == RequestType.DECODE.value: # decode task
logger.debug(f"Handle decode request {request} at idx {idx}")
@@ -581,6 +588,8 @@ def insert_tasks_v1(self, req_dicts: List[Request], num_running_requests: int =
self.share_inputs["seq_lens_decoder"][idx : idx + 1] = 0
self.share_inputs["seq_lens_encoder"][idx : idx + 1] = 0
self.share_inputs["is_block_step"][idx : idx + 1] = False
self.prompt_logprobs_reqs.pop(request.request_id, None)
self.in_progress_prompt_logprobs.pop(request.request_id, None)
Comment on lines +591 to +592
gongshaotian (Collaborator) commented on Nov 4, 2025:
Don't we also need to `del self.prompt_logprobs_reqs[req.request_id]` on preemption?
ckl117 (Collaborator, Author) replied on Nov 4, 2025:
The preemption path above already has logic that clears prompt_logprobs_reqs.
continue

assert len(request.eos_token_ids) == self.model_config.eos_tokens_lens
Expand Down Expand Up @@ -1282,7 +1291,7 @@ def _prepare_inputs(self) -> None:
min_dec_lens=self.share_inputs["min_dec_len"],
bad_words_token_ids=self.share_inputs["bad_tokens"][:, :max_bad_tokens_len],
eos_token_ids=self.share_inputs["eos_token_id"],
max_num_logprobs=20 if self.enable_logprob else None,
max_num_logprobs=self.max_logprobs if self.enable_logprob else None,
enable_early_stop=self.enable_early_stop,
stop_flags=self.share_inputs["stop_flags"],
temp_scaled_logprobs=self.share_inputs["temp_scaled_logprobs"],
@@ -2086,6 +2095,8 @@ class at the server level, which is too granular for ModelRunner.
if self.use_cudagraph:
model_output = model_output[: self.real_token_num]

prompt_logprobs_list = self._get_prompt_logprobs_list(model_output)

if self.is_pooling_model:
hidden_states = model_output
pooler_output = self._pool(hidden_states, num_running_requests)
@@ -2228,6 +2239,7 @@ class at the server level, which is too granular for ModelRunner.
stop_seqs_len=self.share_inputs["stop_seqs_len"],
prompt_lens=self.share_inputs["prompt_lens"],
mask_rollback=self.share_inputs["mask_rollback"],
prompt_logprobs_list=prompt_logprobs_list,
)

if self.speculative_config.method in ["mtp"] and self.scheduler_config.splitwise_role == "prefill":
@@ -2485,6 +2497,9 @@ def clear_parameters(self, pid):
def clear_requests(self):
"""Dynamic model loader use to clear requests use for RL"""
self.share_inputs["stop_flags"][:] = True
# Reset prompt_logprobs bookkeeping
self.prompt_logprobs_reqs.clear()
self.in_progress_prompt_logprobs.clear()

def update_parameters(self, pid):
"""Dynamic model loader use to update parameters use for RL"""
@@ -2694,3 +2709,69 @@ def prepare_rope3d(
cumsum_seqlens=cumsum_seqlens,
)
return rope_emb_lst

def _get_prompt_logprobs_list(
self,
hidden_states: paddle.Tensor,
) -> list[Optional[LogprobsTensors]]:
if len(self.prompt_logprobs_reqs) > 0:
assert (
not self.fd_config.cache_config.enable_prefix_caching
), "prompt_logprobs must disable prefix caching, --no-enable-prefix-caching."
logprobs_mode = self.fd_config.model_config.logprobs_mode
prompt_logprobs_list: list[Optional[LogprobsTensors]] = self.scheduler_config.max_num_seqs * [None]
completed_prefill_reqs: list[Request] = []
for req_id, request in self.prompt_logprobs_reqs.items():
num_prompt_logprobs = request.sampling_params.prompt_logprobs
if request.prompt_token_ids is None or num_prompt_logprobs is None:
continue
if num_prompt_logprobs == -1:
num_prompt_logprobs = self.vocab_size

num_tokens = request.prefill_end_index - request.prefill_start_index
num_prompt_tokens = len(request.prompt_token_ids)

logprobs_tensors = self.in_progress_prompt_logprobs.get(req_id)
if not logprobs_tensors:
logprobs_tensors = LogprobsTensors.empty(num_prompt_tokens - 1, num_prompt_logprobs + 1)
self.in_progress_prompt_logprobs[req_id] = logprobs_tensors
start_idx = request.prefill_start_index
start_tok = start_idx + 1
num_remaining_tokens = num_prompt_tokens - start_tok
if num_tokens <= num_remaining_tokens:
# This is a chunk, more tokens remain.
# In the == case, there are no more prompt logprobs to produce
# but we want to defer returning them to the next step where we
# have new generated tokens to return.
num_logits = num_tokens
else:
# This is the last chunk of prompt tokens to return.
num_logits = num_remaining_tokens
completed_prefill_reqs.append(request)
prompt_logprobs_list[request.idx] = logprobs_tensors
if num_logits <= 0:
# This can happen for the final chunk if we prefilled exactly
# (num_prompt_tokens - 1) tokens for this request in the prior
# step. There are no more prompt logprobs to produce.
continue
offset = self.share_inputs["cu_seqlens_q"][request.idx]
prompt_hidden_states = hidden_states[offset : offset + num_logits]
logits = self.model.compute_logits(prompt_hidden_states)
prompt_token_ids = request.prompt_token_ids[start_tok : start_tok + num_logits]
prompt_token_ids_tensor = paddle.to_tensor(prompt_token_ids, dtype="int64")
if logprobs_mode == "raw_logprobs":
raw_logprobs = self.sampler.compute_logprobs(logits)
elif logprobs_mode == "raw_logits":
raw_logprobs = logits
token_ids, logprobs, ranks = self.sampler.gather_logprobs(
raw_logprobs, num_prompt_logprobs, prompt_token_ids_tensor
)
chunk_slice = slice(start_idx, start_idx + num_logits)
logprobs_tensors.logprob_token_ids[chunk_slice].copy_(token_ids, False)
logprobs_tensors.logprobs[chunk_slice].copy_(logprobs, False)
logprobs_tensors.selected_token_ranks[chunk_slice].copy_(ranks, False)

for req in completed_prefill_reqs:
del self.prompt_logprobs_reqs[req.request_id]
del self.in_progress_prompt_logprobs[req.request_id]
return prompt_logprobs_list
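A pure-Python sketch of the chunked-prefill bookkeeping in `_get_prompt_logprobs_list`: given a prompt of `num_prompt_tokens` and a prefill chunk `[prefill_start, prefill_end)`, how many prompt-logprob rows does this step fill? Only the arithmetic from the function is reproduced; tensor handling and request state are omitted.

```python
def chunk_num_logits(num_prompt_tokens: int, prefill_start: int, prefill_end: int) -> int:
    num_tokens = prefill_end - prefill_start
    start_tok = prefill_start + 1                  # prompt logprobs start at the second token
    num_remaining = num_prompt_tokens - start_tok
    # Middle chunks fill num_tokens rows; the final chunk fills whatever remains.
    return num_tokens if num_tokens <= num_remaining else num_remaining

# A 10-token prompt prefilled in chunks of 4 tokens: 4 + 4 + 1 = 9 rows == num_prompt_tokens - 1.
assert chunk_num_logits(10, 0, 4) == 4
assert chunk_num_logits(10, 4, 8) == 4
assert chunk_num_logits(10, 8, 10) == 1
```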