[XPU] Unify Spec and non-spec branch (#6947) #7180

Merged
Jiang-Jia-Jun merged 11 commits into PaddlePaddle:develop from Jiajun-Ji:mtp-unify-v4 on Apr 16, 2026
Conversation

Contributor

@Jiajun-Ji Jiajun-Ji commented Apr 3, 2026

Motivation

💡 If this PR is a Cherry Pick, the PR title needs to follow the format by adding the [Cherry-Pick] label at the very beginning and appending the original PR ID at the end. For example, [Cherry-Pick][CI] Add check trigger and logic(#5191)


Modifications

Usage or Command

Accuracy Tests

  • No obvious anomaly observed in output lengths.
(screenshot of accuracy results omitted)

Checklist

  • Add at least a tag in the PR title.
    • Tag list: [[FDConfig],[APIServer],[Engine], [Scheduler], [PD Disaggregation], [Executor], [Graph Optimization], [Speculative Decoding], [RL], [Models], [Quantization], [Loader], [OP], [KVCache], [DataProcessor], [BugFix], [Docs], [CI], [Optimization], [Feature], [Benchmark], [Others], [XPU], [HPU], [GCU], [DCU], [Iluvatar], [Metax]]
    • You can add new tags based on the PR content, but the semantics must be clear.
  • Format your code, run pre-commit before commit.
  • Add unit tests. Please write the reason in this PR if no unit tests.
  • Provide accuracy results.
  • If the current PR is submitting to the release branch, make sure the PR has been submitted to the develop branch, then cherry-pick it to the release branch with the [Cherry-Pick] PR tag.

Copilot AI review requested due to automatic review settings April 3, 2026 04:45

paddle-bot bot commented Apr 3, 2026

Thanks for your contribution!

Contributor

Copilot AI left a comment


Pull request overview

This PR is a cherry-pick of #6685. It unifies the execution and post-processing paths of the spec / non-spec branches on the XPU backend and fills in XPU's draft-token verification capability, aligning with the unified architecture on the GPU side.

Changes:

  • Unifies the speculative-method field and the proposer initialization/invocation path on the XPU ModelRunner side, and hooks post-processing into unified_update_model_status.
  • Splits the XPU SpeculativeSampler into a "naive sampling" route and a "verify + sampling" route, adding a verify_draft_tokens call chain.
  • Adds the XPU custom op verify_draft_tokens (C++ wrapper + XPU3 kernel) and a corresponding unit test.
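For readers unfamiliar with speculative decoding, the verification step that an op like verify_draft_tokens performs can be sketched in pure Python. This is an illustrative greedy variant only — the function name mirrors the op, but the logic is an assumption for exposition; the real XPU kernel also handles top-p sampling, stop flags, and batched tensor layouts.

```python
# Illustrative sketch only: a greedy draft-token verification loop of the kind
# speculative decoding uses. Not the PR's actual kernel logic.


def verify_draft_tokens(draft_tokens, target_tokens):
    """Accept draft tokens while they match the target model's choices.

    draft_tokens:  tokens proposed by the draft model for one request.
    target_tokens: tokens the target model emits at each position, one longer
                   than draft_tokens (the extra entry is the "bonus" token).
    Returns the accepted sequence: the matching prefix plus either the first
    disagreeing target token, or the bonus token if every draft matched.
    """
    accepted = []
    for i, draft in enumerate(draft_tokens):
        if draft != target_tokens[i]:
            accepted.append(target_tokens[i])  # replace mismatch, drop the rest
            return accepted
        accepted.append(draft)  # draft verified, keep it
    accepted.append(target_tokens[-1])  # all drafts accepted: take bonus token
    return accepted


print(verify_draft_tokens([5, 7, 9], [5, 7, 2, 4]))  # [5, 7, 2]
print(verify_draft_tokens([5, 7, 9], [5, 7, 9, 4]))  # [5, 7, 9, 4]
```

The payoff of unifying the branches is that this accept/reject bookkeeping and the plain (non-spec) sampling path can share one post-processing route instead of two divergent ones.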

Reviewed changes

Copilot reviewed 9 out of 9 changed files in this pull request and generated 7 comments.

File Description
fastdeploy/worker/xpu_model_runner.py Unifies spec_method naming and the proposer initialization/run logic; adjusts share_inputs (adding reasoning_status etc.) and switches post-processing to unified_update_model_status
fastdeploy/model_executor/xpu_pre_and_post_process.py Migrates speculative post-processing from speculate_update/speculate_set_value_by_flags_and_idx to unified_update_model_status, adding is_naive_mode/prefill_one_step_stop parameters
fastdeploy/model_executor/layers/sample/sampler.py Refactors the XPU speculative sampling path, splitting naive sampling from verify_draft_tokens verification sampling
custom_ops/xpu_ops/test/test_verify_draft_tokens.py Adds a test comparing the verify_draft_tokens kernel against a reference implementation
custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/verify_draft_tokens.cpp Adds the XPU plugin wrapper for verify_draft_tokens (including a CPU wrapper and the XPU3 launch)
custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/verify_draft_tokens.xpu Adds the XPU3 verify_draft_tokens kernel implementation
custom_ops/xpu_ops/src/plugin/include/xpu/plugin.h Exports the verify_draft_tokens plugin API declaration
custom_ops/xpu_ops/src/ops/pybind/pybind.cc Exposes verify_draft_tokens on the Python side
custom_ops/xpu_ops/src/ops/mtp/verify_draft_token.cc Adds the Paddle extension-op wrapper and argument validation for verify_draft_tokens

Comment thread fastdeploy/worker/xpu_model_runner.py
Comment on lines +528 to +539
WRAPPER_CHECK_PTR(ctx, float, real_bsz, curand_states);
WRAPPER_CHECK_PTR(ctx, float, real_bsz, topp);
WRAPPER_CHECK_PTR(ctx, bool, real_bsz, stop_flags);
WRAPPER_CHECK_PTR(ctx, int, real_bsz, seq_lens_encoder);
WRAPPER_CHECK_PTR(ctx, float, real_bsz, seq_lens_this_time);
WRAPPER_CHECK_PTR(ctx, int64_t, end_length, end_tokens);

WRAPPER_CHECK_PTR(ctx, bool, real_bsz, is_block_step);
WRAPPER_CHECK_PTR(ctx, bool, real_bsz, cu_seqlens_q_output);
WRAPPER_CHECK_PTR(ctx, bool, real_bsz, reasoning_status);
WRAPPER_CHECK_PTR(ctx, bool, real_bsz, max_dec_len);
WRAPPER_CHECK_PTR(ctx, bool, real_bsz, step_idx);

Copilot AI Apr 3, 2026


In the new verify_draft_tokens wrapper, the type arguments to WRAPPER_CHECK_PTR do not match the actual pointer types: seq_lens_this_time is an int* but is checked as float, and cu_seqlens_q_output/reasoning_status/max_dec_len/step_idx are all checked as bool. The wrapper's argument validation is therefore wrong and can, in the worst case, cause compile- or run-time problems. Suggest correcting each check to the actual type (int for seq_lens_this_time/cand_lens/seq_lens_encoder/…; int for cu_seqlens_q_output/reasoning_status; int64_t for max_dec_len/step_idx).

output_token =
topp_sampling_kernel(candidate_ids_now + i * max_candidate_len,
candidate_scores_now + i * max_candidate_len,
curand_states + i,

Copilot AI Apr 3, 2026


In the cpu_wrapper, the TOPP Phase 2 call to topp_sampling_kernel uses curand_states + i (where i is the token position), but the host-side VerifyDraftTokens only allocates bsz random states (one per batch row), so any i > 0 reads out of bounds and the CPU fallback path hits undefined behavior or a crash. Suggest unifying the semantics of curand_states (per batch vs per step position) and keeping the allocation length consistent with the access pattern (e.g. switch to curand_states + bid, or allocate max_step_tokens states and offset by i in the XPU3 kernel as well).

Suggested change
curand_states + i,
curand_states + bid,

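The allocation/indexing mismatch in that comment can be reproduced in a few lines. All names below (pick_state, index_by) are hypothetical, for illustration only; the point is that a buffer sized per batch row must not be indexed by the flattened token position.

```python
# Hypothetical sketch of the indexing hazard: the random-state buffer is
# allocated with one entry per batch row (bsz), so indexing it by the token
# position i (as `curand_states + i` does) runs past the end once i >= bsz.


def pick_state(states, bid, i, index_by="batch"):
    """Return the RNG state for (batch row bid, token position i)."""
    idx = bid if index_by == "batch" else i  # "token" models the buggy access
    if idx >= len(states):
        raise IndexError(f"state index {idx} out of range for {len(states)} states")
    return states[idx]


bsz = 2
states = [object() for _ in range(bsz)]  # one state per batch row, as allocated

pick_state(states, bid=1, i=5, index_by="batch")  # safe: reads states[1]
try:
    pick_state(states, bid=1, i=5, index_by="token")  # models curand_states + i
except IndexError as exc:
    print("out-of-bounds read:", exc)
```

In C++ the out-of-bounds read would not raise; it would silently return garbage or crash, which is why the reviewer flags it as undefined behavior on the CPU fallback path.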
Comment thread custom_ops/xpu_ops/src/ops/mtp/verify_draft_token.cc
Comment on lines +693 to +695
pass
# if not paddle.is_compiled_with_cuda():
# self.skipTest("Requires CUDA")

Copilot AI Apr 3, 2026


This unit test's setUp does not skip based on whether XPU support was compiled in (the relevant logic is commented out), yet the top of the file degrades CUDA_PLACE to CPUPlace when XPU is absent and then still calls the verify_draft_tokens custom op, so the test readily fails or crashes in environments without XPU. Suggest restoring the check in setUp as paddle.is_compiled_with_xpu() (or at minimum skipping on non-XPU environments).

Suggested change
pass
# if not paddle.is_compiled_with_cuda():
# self.skipTest("Requires CUDA")
if not paddle.is_compiled_with_xpu():
self.skipTest("Requires XPU")

Comment thread fastdeploy/model_executor/layers/sample/sampler.py
Comment thread fastdeploy/worker/xpu_model_runner.py
@codecov-commenter

codecov-commenter commented Apr 3, 2026

Codecov Report

❌ Patch coverage is 16.21622% with 31 lines in your changes missing coverage. Please review.
⚠️ Please upload report for BASE (develop@e53f518). Learn more about missing BASE report.

Files with missing lines Patch % Lines
fastdeploy/model_executor/layers/sample/sampler.py 10.34% 26 Missing ⚠️
fastdeploy/worker/input_batch.py 28.57% 5 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             develop    #7180   +/-   ##
==========================================
  Coverage           ?   73.83%           
==========================================
  Files              ?      394           
  Lines              ?    54779           
  Branches           ?     8581           
==========================================
  Hits               ?    40446           
  Misses             ?    11609           
  Partials           ?     2724           
Flag Coverage Δ
GPU 73.83% <16.21%> (?)

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.


Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 10 out of 10 changed files in this pull request and generated 8 comments.

Comment thread fastdeploy/worker/xpu_model_runner.py
Comment thread fastdeploy/worker/xpu_model_runner.py
Comment thread custom_ops/xpu_ops/test/test_verify_draft_tokens.py
Comment thread custom_ops/xpu_ops/test/test_verify_draft_tokens.py
Comment thread custom_ops/xpu_ops/test/test_verify_draft_tokens.py
Comment thread fastdeploy/worker/xpu_model_runner.py

Comment thread fastdeploy/model_executor/xpu_pre_and_post_process.py Outdated

Copilot AI review requested due to automatic review settings April 14, 2026 08:50
Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 13 out of 13 changed files in this pull request and generated 7 comments.

Comments suppressed due to low confidence (1)

custom_ops/xpu_ops/src/ops/gather_next_token.cc:134

  • In GatherNextTokenInferShape, the non-speculative branch pins bsz to 0 (and ignores the max_bsz op attribute), so static shape inference is wrong (the out shape becomes [0, dim]). Suggest passing max_bsz through as an infer_shape argument and using it for the returned shape (or returning [-1, dim] as a dynamic dimension).
  int64_t bsz = 0;
  int64_t dim_embed = x_shape[1];
  if (is_speculative) {
    return {{-1, dim_embed}};
  } else {
    return {{bsz, dim_embed}};
  }
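A hedged Python mock of that shape-inference decision makes the suggested fallback concrete. The real GatherNextTokenInferShape is C++; the function name and the max_bsz default below are assumptions for illustration only.

```python
# Hypothetical mock of the InferShape decision flagged above: pinning the
# batch dim to 0 yields a wrong static shape, while -1 (dynamic) or the
# max_bsz attribute is the safer return value.


def gather_next_token_infer_shape(x_shape, is_speculative, max_bsz=None):
    dim_embed = x_shape[1]
    if is_speculative:
        return [[-1, dim_embed]]  # token count only known at runtime
    # Non-speculative: use max_bsz when the attribute is available; otherwise
    # mark the dim dynamic rather than pinning it to 0.
    bsz = max_bsz if max_bsz is not None else -1
    return [[bsz, dim_embed]]


print(gather_next_token_infer_shape([128, 4096], False))             # [[-1, 4096]]
print(gather_next_token_infer_shape([128, 4096], False, max_bsz=8))  # [[8, 4096]]
```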

Comment thread custom_ops/xpu_ops/src/ops/mtp/verify_draft_token.cc
Comment thread custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/verify_draft_tokens.cpp Outdated
Comment thread custom_ops/xpu_ops/test/test_verify_draft_tokens.py
Comment thread custom_ops/xpu_ops/test/test_verify_draft_tokens.py
Comment thread fastdeploy/model_executor/layers/sample/sampler.py

@cmcamdy cmcamdy changed the title [XPU] [Cherry-Pick] Unify Spec and non-spec branch.(#6947) [XPU] Unify Spec and non-spec branch.(#6947) Apr 15, 2026

cmcamdy previously approved these changes Apr 15, 2026
Collaborator

@cmcamdy cmcamdy left a comment


LGTM

Contributor

Copilot AI left a comment


Pull request overview

Copilot reviewed 9 out of 9 changed files in this pull request and generated 2 comments.

Comment thread custom_ops/xpu_ops/test/test_adjust_batch_and_gather_next_token.py
Comment thread custom_ops/xpu_ops/src/ops/pybind/pybind.cc


@PaddlePaddle-bot PaddlePaddle-bot left a comment


📋 Review summary

PR overview: unifies the speculative and non-speculative branches on the XPU platform, aligning with the GPU implementation.

Scope of changes: custom_ops/xpu_ops/, fastdeploy/model_executor/, fastdeploy/worker/

Impact tags: [XPU] [Speculative Decoding]

📝 PR convention check

Item Status
Title tag ✅ contains [XPU]
Motivation ✅ filled in
Modifications ✅ filled in
Checklist ❌ not filled in

Suggest completing the testing and accuracy-test items in the Checklist.

Issues

Level File Summary
🟡 Suggestion fastdeploy/worker/xpu_model_runner.py:1467 The XPU platform does not support the NGRAM/SUFFIX methods; a compatibility check is missing
❓ Question fastdeploy/worker/xpu_model_runner.py:1084 The reasoning_status parameter is added but never updated

Overall assessment

The refactor is clean overall: it unifies the XPU speculative and non-speculative branches and aligns them with the GPU implementation. There is a platform-compatibility gap, however: NgramProposer and SuffixProposer use CUDA-specific code, so configuring the NGRAM or SUFFIX method on XPU fails at runtime. Suggest adding a platform/method compatibility check.

Comment thread fastdeploy/worker/xpu_model_runner.py
Comment thread fastdeploy/worker/xpu_model_runner.py
@Jiang-Jia-Jun Jiang-Jia-Jun merged commit 29495b2 into PaddlePaddle:develop Apr 16, 2026
34 of 37 checks passed