
[Speculative Decoding] Add MTP logprob support for PD disaggregation#7442

Merged
Jiang-Jia-Jun merged 5 commits into PaddlePaddle:develop from Deleter-D:dev_mtp_pd_logprob
Apr 17, 2026

Conversation

Collaborator

@Deleter-D Deleter-D commented Apr 16, 2026

Motivation

Enable logprob return for MTP speculative decoding under PD disaggregation architecture, particularly for handling the first token at prefill nodes.

Modifications

  1. Add mtp_save_first_token_with_topk.cc to support logprob saving for the first token at prefill nodes.
  2. Add speculate_logprob_msg.h to extract the common message structure definitions.
  3. Refactor the save_output_specualate function to differentiate prefill and decode node processing logic.
  4. Move the mtp_save_first_token call from mtp.py to pre_and_post_process.py.
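
A rough sketch of the prefill/decode split described in item 3 (illustrative Python; `is_mtp_prefill` and the helper callables are hypothetical stand-ins, not the actual FastDeploy signatures):

```python
# Hypothetical sketch of the refactored dispatch; names and signatures are
# illustrative, not the actual FastDeploy API.
def save_output_specualate(sampler_output, model_output, is_mtp_prefill,
                           return_logprobs, save_first_token, save_decode):
    if is_mtp_prefill:
        # Prefill node: only the first token (plus its top-k logprobs,
        # when requested) is written to the message queue.
        return save_first_token(sampler_output, model_output,
                                with_topk=return_logprobs)
    # Decode node: save accepted draft tokens as before.
    return save_decode(sampler_output, model_output,
                       with_topk=return_logprobs)
```

The point of the split is that the two node roles write different message shapes, so routing them through one entry point with a single flag keeps the caller in gpu_model_runner.py unchanged.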

Usage or Command

Accuracy Tests

Checklist

  • Add at least one tag to the PR title.
    • Tag list: [[FDConfig],[APIServer],[Engine], [Scheduler], [PD Disaggregation], [Executor], [Graph Optimization], [Speculative Decoding], [RL], [Models], [Quantization], [Loader], [OP], [KVCache], [DataProcessor], [BugFix], [Docs], [CI], [Optimization], [Feature], [Benchmark], [Others], [XPU], [HPU], [GCU], [DCU], [Iluvatar], [Metax]]
    • You can add new tags based on the PR content, but the semantics must be clear.
  • Format your code; run pre-commit before committing.
  • Add unit tests. Please write the reason in this PR if no unit tests.
  • Provide accuracy results.
  • If the current PR is submitting to the release branch, make sure the PR has been submitted to the develop branch, then cherry-pick it to the release branch with the [Cherry-Pick] PR tag.

@paddle-bot

paddle-bot bot commented Apr 16, 2026

Thanks for your contribution!

@Deleter-D Deleter-D changed the title [Speculative Decoding] Support mtp logprob in pd [Speculative Decoding] Add MTP logprob support for PD disaggregation Apr 16, 2026
PaddlePaddle-bot

This comment was marked as outdated.


@codecov-commenter

codecov-commenter commented Apr 16, 2026

Codecov Report

❌ Patch coverage is 26.31579% with 14 lines in your changes missing coverage. Please review.
⚠️ Please upload report for BASE (develop@e952720). Learn more about missing BASE report.

Files with missing lines Patch % Lines
fastdeploy/model_executor/pre_and_post_process.py 9.09% 9 Missing and 1 partial ⚠️
fastdeploy/spec_decode/mtp.py 0.00% 3 Missing and 1 partial ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             develop    #7442   +/-   ##
==========================================
  Coverage           ?   73.31%           
==========================================
  Files              ?      398           
  Lines              ?    54985           
  Branches           ?     8616           
==========================================
  Hits               ?    40312           
  Misses             ?    11978           
  Partials           ?     2695           
Flag Coverage Δ
GPU 73.31% <26.31%> (?)

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.


@PaddlePaddle-bot PaddlePaddle-bot left a comment


🤖 AI Code Review | 2026-04-17 16:37 CST

📋 Review Summary

PR overview: Adds logprob support for MTP speculative decoding under the PD disaggregation architecture, mainly covering first-token logprob saving at prefill nodes.
Scope of changes: custom_ops/gpu_ops/speculate_decoding/ (C++ operators), fastdeploy/model_executor/pre_and_post_process.py (Python dispatch logic), fastdeploy/spec_decode/mtp.py (MTP post-processing), fastdeploy/worker/gpu_model_runner.py (call site)
Impact tags: [OP], [Speculative Decoding], [PD Disaggregation]

Issues

Severity File Summary
🟡 Suggestion pre_and_post_process.py:25-28 Unconditional top-level import of GPU ops may affect non-GPU platforms
🟡 Suggestion speculate_logprob_msg.h:26 Inconsistent macro name prefix
🟡 Suggestion PR Checklist Missing unit tests and accuracy test results

🟡 Suggestion 1: unconditional top-level import of GPU ops in pre_and_post_process.py

pre_and_post_process.py adds an unconditional top-level import:

from fastdeploy.model_executor.ops.gpu import (
    mtp_save_first_token,
    mtp_save_first_token_with_topk,
)

This module is imported not only by gpu_model_runner.py but also by gcu_model_runner.py and metax_model_runner.py. Although the tolerant_import_error mechanism in ops.gpu.__init__.py sets missing symbols to None, a non-GPU platform that ever hits the MTP prefill path will fail at runtime with TypeError: 'NoneType' object is not callable instead of a meaningful error message.

Suggestion: make this a lazy import inside the is_mtp_prefill branch of the save_output_specualate function, or add a platform guard (consistent with the existing if current_platform.is_iluvatar() pattern in this file).
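
The failure mode and the suggested fix can be sketched generically (illustrative Python; the module-level None binding stands in for what tolerant_import_error produces, and the function names are hypothetical):

```python
# Stand-in for a GPU op that a tolerant import machinery bound to None
# on a platform where the symbol does not exist.
mtp_save_first_token_with_topk = None

def save_first_token_eager(*args):
    # Problem: calling the None placeholder raises an unhelpful
    # "TypeError: 'NoneType' object is not callable".
    return mtp_save_first_token_with_topk(*args)

def save_first_token_lazy(*args):
    # Fix: resolve the op only when the MTP prefill path is taken, and
    # fail with a message that names the missing capability.
    op = mtp_save_first_token_with_topk
    if op is None:
        raise RuntimeError(
            "mtp_save_first_token_with_topk is unavailable on this "
            "platform; MTP prefill logprob saving requires GPU ops")
    return op(*args)
```

Either the lazy-import form or a platform guard achieves the same goal: non-GPU platforms never touch the GPU symbol at import time and get an actionable error if the path is reached.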


🟡 Suggestion 2: inconsistent macro prefix in speculate_logprob_msg.h

In the new header, SPEC_LOGPROB_MAX_BSZ and SPEC_LOGPROB_K carry the SPEC_LOGPROB_ prefix, but MAX_DRAFT_TOKEN_NUM has none. Consider unifying it as SPEC_LOGPROB_MAX_DRAFT_TOKEN_NUM or similar to avoid the risk of global macro name collisions.


🟡 Suggestion 3: missing unit tests and accuracy results

Neither "Add unit tests" nor "Provide accuracy results" is checked in the PR Checklist, and no reason is given. Since this PR adds the C++ operator mtp_save_first_token_with_topk and refactors the control flow of save_output_specualate, it should at least include:

  • an explanation of why no unit tests were added (e.g. dependence on a specific hardware environment)
  • correctness results for logprobs in the MTP + PD disaggregation scenario

Overall assessment

The overall design is sound: extracting the common message structures into speculate_logprob_msg.h removes duplicated definitions, moving the GPU-platform mtp_save_first_token call from mtp.py into the model runner call chain clarifies responsibilities, and XPU compatibility is preserved via the current_platform.is_xpu() guard. The C++ operator implementation follows the existing SpeculateSaveOutMmsgTopK pattern. No blocking P0 issues were found; the non-GPU import compatibility concern above deserves attention.

Contributor

Copilot AI left a comment


Pull request overview

This PR completes logprob return support for the prefill-node first-token scenario under FastDeploy's speculative decoding (MTP) with the PD disaggregation architecture, and migrates part of the output-saving logic from the proposer side into the model runner / unified post-processing path.

Changes:

  • save_output_specualate now distinguishes the MTP prefill and regular decode output-saving paths, and uses the topk message format when logprobs are requested.
  • Adds/extracts a header with the SysV message queue logprob message structures, shared by the get/save topk output operators.
  • Adds the mtp_save_first_token_with_topk custom operator for writing the topk logprobs of the prefill first token.
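
For illustration, a per-batch message entry of the kind described above could be modeled like this (a hedged ctypes sketch; the constant values, field names, and sizes here are assumptions for illustration, not FastDeploy's actual layout):

```python
import ctypes

# Hypothetical constants; the real values live in speculate_logprob_msg.h.
SPEC_LOGPROB_K = 20   # assumed top-k width
MAX_TOKENS = 8        # assumed max tokens per batch slot

class BatchLogprobMsg(ctypes.Structure):
    """Sketch of one batch slot's entry: each token row is
    (SPEC_LOGPROB_K + 1) wide so the sampled token sits at index 0
    followed by its top-k candidates, mirroring the layout the
    save/get topk operators share."""
    _fields_ = [
        ("tokens", ctypes.c_int * (MAX_TOKENS * (SPEC_LOGPROB_K + 1))),
        ("scores", ctypes.c_float * (MAX_TOKENS * (SPEC_LOGPROB_K + 1))),
        ("ranks", ctypes.c_int * MAX_TOKENS),
    ]
```

Sharing one such struct definition between the sender and receiver operators is exactly what extracting speculate_logprob_msg.h buys: both sides agree on the byte layout by construction.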

Reviewed changes

Copilot reviewed 7 out of 7 changed files in this pull request and generated 4 comments.

Show a summary per file
File Description
fastdeploy/worker/gpu_model_runner.py Adjusts the speculative output-saving call, additionally passing proposer inputs and rank info to support MTP prefill first-token saving
fastdeploy/spec_decode/mtp.py Moves the GPU-platform mtp_save_first_token call out of proposer post-processing, keeping only the XPU path
fastdeploy/model_executor/pre_and_post_process.py Adds an MTP prefill branch to save_output_specualate and calls mtp_save_first_token_with_topk when logprobs are requested
custom_ops/gpu_ops/speculate_decoding/speculate_save_output_with_topk.cc Reuses the constants and struct definitions extracted into speculate_logprob_msg.h
custom_ops/gpu_ops/speculate_decoding/speculate_get_output_with_topk.cc Same as above; reuses the common message structure definitions
custom_ops/gpu_ops/speculate_decoding/speculate_logprob_msg.h New header with the common topk logprob message structures and constant definitions
custom_ops/gpu_ops/speculate_decoding/draft_model/mtp_save_first_token_with_topk.cc New operator that writes the prefill first-token topk logprobs to the message queue

Comment on lines 2517 to +2521
save_output_specualate(
    sampler_output=sampler_output,
    model_output=model_output_data,
    share_inputs=self.share_inputs,
    proposer_share_inputs=self.proposer.model_inputs,
Comment on lines +124 to +172
int max_num_logprobs = logprob_token_ids.shape()[1];
for (int i = 0; i < bsz; i++) {
  int cur_token_num;
  if (seq_lens_decoder_data[i] < prompt_lens_data[i] ||
      token_num_per_batch_data[i] == 0) {
    // chunk prefill or stop slots
    cur_token_num = 0;
  } else {
    cur_token_num = token_num_per_batch_data[i] + 1;
  }
  msg_sed.meta[3 + i] = cur_token_num;
  if (preempted_idx_data[i] == 1) {
    msg_sed.meta[3 + i] = -9;
  }

  auto* cur_batch_msg_sed = &msg_sed.mtext[i];
  int token_offset = cu_batch_token_offset_data[i];
  for (int j = 0; j < cur_token_num; j++) {
    auto* cur_tokens = &cur_batch_msg_sed->tokens[j * (SPEC_LOGPROB_K + 1)];
    auto* cur_scores = &cur_batch_msg_sed->scores[j * (SPEC_LOGPROB_K + 1)];
    if (j == 0) {
      // first token has full logprobs
      for (int k = 0; k < SPEC_LOGPROB_K + 1; k++) {
        if (k == 0) {
          cur_tokens[k] = (int)sampled_token_ids_data[i * max_draft_tokens + j];
          cur_scores[k] =
              logprob_scores_data[(token_offset + j) * (SPEC_LOGPROB_K + 1) + k];
        } else if (k < max_num_logprobs) {
          // only for first token
          cur_tokens[k] =
              (int)logprob_token_ids_data[(token_offset + j) * (SPEC_LOGPROB_K + 1) + k];
          cur_scores[k] =
              logprob_scores_data[(token_offset + j) * (SPEC_LOGPROB_K + 1) + k];
        } else {
          cur_tokens[k] = -1;
          cur_scores[k] = 0.0;
        }
      }
      cur_batch_msg_sed->ranks[j] = (int)logprob_ranks_data[token_offset + j];
    } else {
      // draft token only has token_id
      cur_tokens[0] = (int)sampled_token_ids_data[i * max_draft_tokens + j];
    }
  }
}
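
For intuition, the per-slot token-count bookkeeping above can be mirrored in a small Python reference model (a sketch for reasoning about the loop, not part of the PR):

```python
def batch_meta(seq_lens_decoder, prompt_lens, token_num_per_batch,
               preempted_idx):
    """Mirror of the C++ meta computation: 0 for chunk-prefill/stopped
    slots, accepted-token count + 1 otherwise, and the -9 sentinel for
    preempted slots."""
    meta = []
    for i in range(len(seq_lens_decoder)):
        if seq_lens_decoder[i] < prompt_lens[i] or token_num_per_batch[i] == 0:
            cur = 0  # chunk prefill or stop slots
        else:
            cur = token_num_per_batch[i] + 1
        if preempted_idx[i] == 1:
            cur = -9  # preemption overrides the count
        meta.append(cur)
    return meta
```

Note that the preemption check runs after the count is computed, so a preempted slot always reports -9 regardless of its token count.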
Comment on lines +584 to +597
mtp_save_first_token_with_topk(
    recover_proposer_share_inputs_map["base_model_draft_tokens"],
    sampler_output.logprobs_tensors.logprob_token_ids,
    sampler_output.logprobs_tensors.logprobs,
    sampler_output.logprobs_tensors.selected_token_ranks,
    recover_share_inputs_map["accept_num_cpu"],
    sampler_output.cu_batch_token_offset,
    model_output.not_need_stop,
    recover_share_inputs_map["seq_lens_decoder_cpu"],
    recover_share_inputs_map["prompt_lens_cpu"],
    recover_share_inputs_map["last_preempted_idx"],
    3,  # mtype
    model_output.mp_rank,
    save_each_rank,
@Jiang-Jia-Jun Jiang-Jia-Jun merged commit df3b4e1 into PaddlePaddle:develop Apr 17, 2026
36 of 42 checks passed
Jiang-Jia-Jun pushed a commit that referenced this pull request Apr 17, 2026
…saggregation (#7442) (#7464)

* support mtp logprob in pd

* fix

* fix

* fix

* fix xpu bugs

6 participants