Conversation

@ckl117 ckl117 commented Nov 3, 2025

Motivation

GPUModelRunner now supports max_logprobs=-1 and prompt_logprobs:

  • Computes logprobs over the full vocab_size (including prompt_logprobs); see the sketch below.
  • Computes prompt_logprobs (prefix caching is disabled when this is requested).
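
A minimal sketch of the full-vocabulary logprob computation (illustrative only; shapes and names are assumptions, not this PR's actual code):

import paddle
import paddle.nn.functional as F

# Illustrative sketch; shapes and names are assumptions, not this PR's code.
# max_logprobs=-1 means returning logprobs over the whole vocabulary rather
# than a fixed top-k slice.
logits = paddle.randn([5, 32000])          # [num_prompt_tokens, vocab_size]
logprobs = F.log_softmax(logits, axis=-1)  # full-vocab logprobs per position

# prompt_logprobs: look up the logprob assigned to each actual prompt token.
prompt_token_ids = paddle.to_tensor([[101], [2043], [318], [257], [1332]], dtype="int64")
token_logprobs = paddle.take_along_axis(logprobs, prompt_token_ids, axis=1)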

Modifications

Usage or Command

export FD_USE_GET_SAVE_OUTPUT_V1=1

python -m fastdeploy.entrypoints.openai.api_server \
    --model ./ERNIE-4.5-0.3B-PT \
    --max-model-len 32768 \
    --max-num-seqs 128 \
    --tensor-parallel-size 1 \
    --enable-logprob \
    --max-logprobs -1 \
    --no-enable-prefix-caching
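
Once the server is up, an ordinary logprobs request against the OpenAI-compatible endpoint looks roughly like the sketch below (the port and model name are assumptions; per the TODO under Accuracy Tests, top_logprobs=-1 and prompt_logprobs are not yet exposed at the server layer):

import openai

# Illustrative request sketch; base_url port and model name are assumptions.
client = openai.OpenAI(base_url="http://127.0.0.1:8000/v1", api_key="EMPTY")
resp = client.chat.completions.create(
    model="ERNIE-4.5-0.3B-PT",
    messages=[{"role": "user", "content": "Hello"}],
    logprobs=True,
    top_logprobs=5,
)
print(resp.choices[0].logprobs)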

Accuracy Tests

TODO: The server layer should also support top_logprobs=-1 and prompt_logprobs.

Checklist

  • Add at least one tag in the PR title.
    • Tag list: [[FDConfig],[APIServer],[Engine], [Scheduler], [PD Disaggregation], [Executor], [Graph Optimization], [Speculative Decoding], [RL], [Models], [Quantization], [Loader], [OP], [KVCache], [DataProcessor], [BugFix], [Docs], [CI], [Optimization], [Feature], [Benchmark], [Others], [XPU], [HPU], [GCU], [DCU], [Iluvatar], [Metax]]
    • You can add new tags based on the PR content, but the semantics must be clear.
  • Format your code; run pre-commit before committing.
  • Add unit tests; if none are added, please explain why in this PR.
  • Provide accuracy results.
  • If the current PR is submitting to the release branch, make sure the PR has been submitted to the develop branch, then cherry-pick it to the release branch with the [Cherry-Pick] PR tag.

paddle-bot bot commented Nov 3, 2025

Thanks for your contribution!

Comment on lines 568 to 569

if request.sampling_params.prompt_logprobs is not None:
    self.prompt_logprobs_reqs[request.request_id] = request

Collaborator commented: Have you considered memory growth after very long stress testing?

Collaborator commented: Also, when this is used in RL scenarios, the clear_requests function needs to clear some of the model_runner's objects, including this one; while you're at it, please check whether any other objects need to be cleared.

Collaborator Author replied: Longer prompts and higher concurrency both increase GPU memory usage, but there is no memory leak; this is expected behavior.

Comment on lines +590 to +591
self.prompt_logprobs_reqs.pop(request.request_id, None)
self.in_progress_prompt_logprobs.pop(request.request_id, None)
@gongshaotian (Collaborator) commented Nov 4, 2025: On preemption, don't we need to del self.prompt_logprobs_reqs[req.request_id]?

@ckl117 (Collaborator Author) replied Nov 4, 2025: There is already logic above that clears prompt_logprobs_reqs on preemption.
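
For context, the lifecycle under discussion reduces to the pattern below (a minimal reconstruction from the quoted diff; the surrounding class and method names are assumptions for illustration):

# Minimal reconstruction of the cleanup pattern from the quoted diff; the
# surrounding class and method names are assumptions for illustration.
class GPUModelRunnerSketch:
    def __init__(self):
        self.prompt_logprobs_reqs = {}          # request_id -> request
        self.in_progress_prompt_logprobs = {}   # request_id -> partial results

    def release_request(self, request_id: str):
        # dict.pop with a default is a no-op if the id was already removed,
        # so the same call is safe on both completion and preemption paths.
        self.prompt_logprobs_reqs.pop(request_id, None)
        self.in_progress_prompt_logprobs.pop(request_id, None)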

@ckl117 ckl117 changed the title [Feature] GPU Model Runner Supports prompt_logprobs and max_logprobs for Text model [Feature][Executor] GPU Model Runner Supports prompt_logprobs and max_logprobs for Text model Nov 4, 2025
@ckl117 ckl117 changed the title [Feature][Executor] GPU Model Runner Supports prompt_logprobs and max_logprobs for Text model [Feature][Executor] GPU Model Runner Supports prompt_logprobs and max_logprobs Nov 4, 2025
Comment on lines 2756 to 2758
if isinstance(prompt_token_ids, np.ndarray):
    prompt_token_ids = prompt_token_ids.tolist()
prompt_token_ids_tensor = paddle.to_tensor(prompt_token_ids, dtype="int64")
Collaborator commented: Doesn't paddle support converting an ndarray directly to a tensor?
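
For reference, paddle.to_tensor does accept a numpy ndarray directly, so the tolist() round-trip in the quoted snippet is avoidable (a small standalone check, not the PR's code):

import numpy as np
import paddle

prompt_token_ids = np.array([101, 2043, 318], dtype=np.int64)

# paddle.to_tensor accepts numpy arrays directly; no tolist() round-trip needed.
direct = paddle.to_tensor(prompt_token_ids, dtype="int64")
via_list = paddle.to_tensor(prompt_token_ids.tolist(), dtype="int64")
assert bool((direct == via_list).all())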

@gongshaotian (Collaborator) left a comment: LGTM

@ckl117 ckl117 merged commit 1c3ca48 into PaddlePaddle:develop Nov 5, 2025
29 of 33 checks passed