[Speculative Decoding] fix mtp stop_seqs and limit thinking bugs#7166
Conversation
|
Thanks for your contribution! |
ba88df0 to
0f4325c
Compare
0f4325c to
41a8185
Compare
41a8185 to
8dea198
Compare
Codecov Report✅ All modified and coverable lines are covered by tests. Additional details and impacted files@@ Coverage Diff @@
## develop #7166 +/- ##
==========================================
Coverage ? 74.46%
==========================================
Files ? 383
Lines ? 53588
Branches ? 8405
==========================================
Hits ? 39905
Misses ? 10966
Partials ? 2717
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Sentry. 🚀 New features to boost your workflow:
|
8dea198 to
ae2f9f4
Compare
52711f9 to
b37c463
Compare
b37c463 to
dd2326a
Compare
dd2326a to
4ab41f1
Compare
4ab41f1 to
a0be6ee
Compare
a0be6ee to
99b5c45
Compare
a0a3014 to
17066c5
Compare
17066c5 to
d81c52c
Compare
There was a problem hiding this comment.
Pull request overview
该 PR 旨在修复投机解码(speculative decoding)中因 step_idx 语义变更导致的 stop sequences 截断与 thinking 长度限制相关 CUDA kernel 索引错误,并同步更新对应单测以覆盖新行为。
Changes:
- 修复
speculate_set_stop_value_multi_seqs的 stop 判定与匹配/截断逻辑(含延迟检测与 eos 追加策略)。 - 调整
speculate_limit_thinking_content_length里step_idx的只读语义与current_base_step计算方式,移除 kernel 内回退step_idx行为。 - 更新
test_speculate_set_stop_value_multi_seqs.py以适配新step_idx语义并补充边界用例。
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 6 comments.
| File | Description |
|---|---|
| tests/operators/test_speculate_set_stop_value_multi_seqs.py | 更新 stop_seq 匹配的 Python 参考实现与测试用例,覆盖新索引/延迟检测语义 |
| custom_ops/gpu_ops/speculate_decoding/unified_update_model_status.cu | 调整 token 写回历史的 base 计算以适配新的 step_idx 语义(当前实现存在关键 off-by-one 风险) |
| custom_ops/gpu_ops/speculate_decoding/speculate_set_stop_value_multi_seqs.cu | 更新 stop_seq 匹配/截断逻辑(当前实现存在 pre_ids 索引偏移与并发写竞争问题) |
| custom_ops/gpu_ops/speculate_decoding/speculate_limit_thinking_content_length.cu | step_idx 改为只读并修正 base step 计算,删除 kernel 内 step_idx 回退逻辑 |
| accept_idx 表示 stop_seq 最后 token 在 accept_tokens 中的位置 (0-based) | ||
| accept_idx = -1 表示 stop_seq 最后 token 在 pre_ids 的末尾 | ||
| (pre_ids[step_idx_now - 1]),即上一轮延迟匹配的最后一个 token。 | ||
| 为防止在 stop_seqs 后面追加 eos 越界,跳过 accept_tokens[accept_num-1] | ||
| (当前轮最后一个 token),该 token 延迟到下一轮匹配。 | ||
| 循环范围:accept_num > 0 时为 [-1, accept_num-2]; | ||
| accept_num = 0 时为 [-1](仅检查 pre_ids 末尾)。 |
There was a problem hiding this comment.
这里的注释与仓库中 step_idx/pre_ids 的既有语义不一致:在 decoder 阶段其它 kernel 都把最后一个输出 token 放在 pre_ids[step_idx](pre_ids[0] 预留),不是 pre_ids[step_idx-1]。建议修正注释并确保后续索引计算与该语义一致,否则会产生 off-by-one 的匹配/截断错误。
| } else { | ||
| int pre_ids_idx = step_idx_now + accept_tokens_idx; | ||
| #ifdef DEBUG_SPEC_STOP_SEQS | ||
| printf( | ||
| "PreIds bid:%d. tid:%d, step_idx_now:%ld. " | ||
| "accept_idx:%d. " | ||
| "pre_id_idx: %ld\n", | ||
| "accept_idx:%d. pre_id_idx: %d\n", | ||
| bid, | ||
| tid, | ||
| step_idx_now, | ||
| accept_idx, | ||
| step_idx_now - accept_num + accept_idx - | ||
| (stop_seq_len - 1 - i)); | ||
| pre_ids_idx); | ||
| #endif | ||
| int pre_ids_idx = | ||
| step_idx_now + accept_idx - (stop_seq_len - 1 - i); | ||
| // EC3 | ||
| // 特殊拼接会导致input_ids最后一位无特殊token,即pre_ids[0]可能为23, | ||
| // 导致异常结束 | ||
| if (pre_ids_idx <= 0) { | ||
| break; | ||
| } | ||
| if (pre_ids_idx < 0) break; | ||
| cur_token_idx = pre_ids_now[pre_ids_idx]; |
There was a problem hiding this comment.
pre_ids 的索引计算这里少了 +1 偏移且边界条件也不对:如果 pre_ids[0] 是预留位置、最后一个输出 token 在 pre_ids[step_idx_now],那么 accept_tokens_idx=-1 应该读取 pre_ids[step_idx_now];当前实现会读到 pre_ids[step_idx_now-1]。同时 pre_ids_idx==0 也应视为越界以避免把预留位参与 stop_seq 匹配。建议按 pre_ids_idx = step_idx_now + accept_tokens_idx + 1 计算,并将越界判断调整为 <=0。
| if (is_end) { | ||
| #ifdef DEBUG_SPEC_STOP_SEQS | ||
| printf("bid:%d end with accept_idx %d", bid, accept_idx); | ||
| printf("bid:%d end with accept_idx %d\n", bid, accept_idx); | ||
| #endif | ||
|
|
||
| accept_nums[bid] = accept_idx; | ||
| accept_tokens_now[accept_idx - 1] = end_ids[0]; | ||
| // stop_flags[bid] = true; | ||
| // accept_idx 在循环退出时已递增,指向 stop_seq 最后 token 的下一个位置 | ||
| accept_nums[bid] = accept_idx + 1; | ||
| accept_tokens_now[accept_idx] = end_ids[0]; | ||
| } |
There was a problem hiding this comment.
这里是多线程(tid=stop_seq 维度)并发写 accept_nums/accept_tokens_now[accept_idx],但没有任何同步/原子操作;当多个 stop_seq 同时匹配或 stop_seqs 中存在重复/前缀重叠时,会产生写竞争,导致结果不确定。建议改为单线程遍历 stop_seqs,或用原子/规约选择“最先触发”的匹配(例如最小截断位置),并确保只有一个线程负责最终写回。
| #ifdef DEBUG_SPEC_STOP_SEQS | ||
| printf("num %d < stop_seq_len %d\n", | ||
| step_idx_now - accept_num + accept_idx + 1, | ||
| step_idx_now + accept_idx + 1, | ||
| stop_seq_len); |
There was a problem hiding this comment.
DEBUG_SPEC_STOP_SEQS 分支里的 printf 格式符与参数类型不匹配:step_idx_now 是 int64_t,但这里用的是 %d。即使只在 debug 宏打开时编译,也可能引入未定义行为/错误日志。建议改用 %ld/%lld 并做显式类型转换。
| // Bounds check: highest write index is prompt_len + cur_step_idx | ||
| if (prompt_len + cur_step_idx < max_model_len) { | ||
| int64_t *token_ids_all_now = | ||
| &token_ids_all[batch_id * max_model_len + prompt_len]; | ||
| int64_t *output_ids = &step_output_ids[batch_id * max_step_tokens]; | ||
| int64_t base = cur_step_idx - output_len + 1; | ||
| int64_t base = cur_step_idx - output_len; | ||
| for (int i = 0; i < output_len; i++) { | ||
| token_ids_all_now[base + i] = output_ids[i]; | ||
| } |
There was a problem hiding this comment.
base 的 off-by-one 看起来会把本轮输出写到错误的位置:cur_step_idx 在循环内是先 ++ 再写入/计数,且其它 speculate 写历史的逻辑普遍假设最后一个输出 token 在 pre_ids/token_ids_all 的索引 step_idx 处(0 号位预留)。这里从 “- output_len + 1” 改成 “- output_len” 会把写入整体左移 1,可能覆盖预留位或丢失最后一个 token。建议重新对齐 step_idx 的 1-based 语义(大概率需要恢复 +1),并同步更新对应的 Python reference/单测以避免静默错写历史。
| token_ids_all 布局: | ||
| pre_ids_now[0] = prompt 最后一个 token(预留位置) | ||
| pre_ids_now[k] = 第 k 个 output token (k >= 1) | ||
| 最后一个 output token 在 pre_ids_now[step_idx] | ||
|
|
||
| 核心设计: | ||
| 1. 主循环只检查 accept_idx <= accept_num-2 | ||
| 2. 如果 stop_seq 最后 token 在 accept_num-1,延迟到下一轮 | ||
| 3. 下一轮通过 pre_ids_end 检测,输出 eos | ||
| 4. 匹配成功时: 保留 stop_seq 所有 token,在其后追加 eos |
There was a problem hiding this comment.
该测试文件新增的大段说明目前是中文(docstring/注释),而同文件其余部分以英文为主。为保持测试用例的一致性与可维护性,建议将这段说明改为英文(代码注释保持英文)。
| token_ids_all 布局: | |
| pre_ids_now[0] = prompt 最后一个 token(预留位置) | |
| pre_ids_now[k] = 第 k 个 output token (k >= 1) | |
| 最后一个 output token 在 pre_ids_now[step_idx] | |
| 核心设计: | |
| 1. 主循环只检查 accept_idx <= accept_num-2 | |
| 2. 如果 stop_seq 最后 token 在 accept_num-1,延迟到下一轮 | |
| 3. 下一轮通过 pre_ids_end 检测,输出 eos | |
| 4. 匹配成功时: 保留 stop_seq 所有 token,在其后追加 eos | |
| Layout of token_ids_all: | |
| pre_ids_now[0] = the last prompt token (reserved slot) | |
| pre_ids_now[k] = the k-th output token (k >= 1) | |
| The last output token is stored at pre_ids_now[step_idx] | |
| Core design: | |
| 1. The main loop only checks accept_idx <= accept_num - 2 | |
| 2. If the last token of stop_seq is at accept_num - 1, defer handling to the next round | |
| 3. In the next round, detect it through pre_ids_end and emit eos | |
| 4. On a successful match, keep all stop_seq tokens and append eos after them |
d81c52c to
2e35a0b
Compare
…_stop_value kernels - speculate_limit_thinking_content_length: update current_base_step to step_idx+1 (step_idx now records history count before current round); remove incorrect step_idx decrement on accept_num truncation; mark step_idx param as const. - speculate_set_stop_value_multi_seqs: fix can_stop gate to use step_idx_now+accept_num>=min_token_limit; fix skip check and pre_ids_idx formula (remove stale -accept_num offset); use <= condition so accept_idx maps directly to the accepted token that ends the stop sequence; fix accept_tokens index (remove -1). - Update unit tests for speculate_set_stop_value_multi_seqs kernel.
dba48e8 to
064fe85
Compare
PaddlePaddle-bot
left a comment
There was a problem hiding this comment.
🤖 AI Code Review |
2026-04-13 13:09 CST
📋 Review 摘要
PR 概述:修复投机解码中 speculate_set_stop_value_multi_seqs 和 speculate_limit_thinking_content_length 两个 kernel 因 step_idx 语义变更引起的索引错误
变更范围:custom_ops/gpu_ops/speculate_decoding/
影响面 Tag:[Speculative Decoding] [OP]
问题
| 级别 | 文件 | 概述 |
|---|---|---|
| 无 | 无 | 未发现阻塞性问题 |
总体评价
PR 正确修复了因 step_idx 语义变更(从"含当前 round tokens"改为"仅含历史 tokens")引起的索引错误:
-
speculate_set_stop_value_multi_seqs.cu:
- can_stop 判断修复:
step_idx_now + accept_num >= min_token_limit✓ - accept_idx 从 -1 开始,支持检测 pre_ids 末尾的延迟匹配 ✓
- loop_end = accept_num - 2,跳过最后一个位置防止越界 ✓
- pre_ids_idx 计算修复为
step_idx_now + accept_tokens_idx✓
- can_stop 判断修复:
-
speculate_limit_thinking_content_length.cu:
- step_idx 参数改为 const,声明为只读 ✓
- current_base_step 修复为
step_idx[bid] + 1,适配新语义 ✓ - 移除 step_idx 回退逻辑,由 unified_update_model_status 统一管理 ✓
-
unified_update_model_status.cu:
- base 修复为
cur_step_idx - output_len,适配新 step_idx 语义 ✓
- base 修复为
-
测试覆盖:
- test_speculate_set_stop_value_multi_seqs.py 已全面更新,覆盖多种边界情况 ✓
注:XPU 平台使用 speculate_verify 而非 verify_draft_tokens,调用链和 step_idx 更新方式不同,可能需要独立评估(但不在本次 PR 范围内)。
…_stop_value kernels (PaddlePaddle#7166) - speculate_limit_thinking_content_length: update current_base_step to step_idx+1 (step_idx now records history count before current round); remove incorrect step_idx decrement on accept_num truncation; mark step_idx param as const. - speculate_set_stop_value_multi_seqs: fix can_stop gate to use step_idx_now+accept_num>=min_token_limit; fix skip check and pre_ids_idx formula (remove stale -accept_num offset); use <= condition so accept_idx maps directly to the accepted token that ends the stop sequence; fix accept_tokens index (remove -1). - Update unit tests for speculate_set_stop_value_multi_seqs kernel.
…7402, #7445 to release/online/20260415 (#7447) * [Speculate Decoding] Fix step_idx semantics in limit_thinking and set_stop_value kernels (#7166) - speculate_limit_thinking_content_length: update current_base_step to step_idx+1 (step_idx now records history count before current round); remove incorrect step_idx decrement on accept_num truncation; mark step_idx param as const. - speculate_set_stop_value_multi_seqs: fix can_stop gate to use step_idx_now+accept_num>=min_token_limit; fix skip check and pre_ids_idx formula (remove stale -accept_num offset); use <= condition so accept_idx maps directly to the accepted token that ends the stop sequence; fix accept_tokens index (remove -1). - Update unit tests for speculate_set_stop_value_multi_seqs kernel. * [Speculate Decoding] Fix bug of reasoning_phase_token_constraint kernel (#7349) Co-authored-by: guanshihui] <guanshihui@baidu.com> * [Speculate Decoding] Fix reasoning_phase_token_constraint call args in SpeculativeSampler (#7402) * [Interrupt reasoning] Add interrupt_requests control command support --------- Co-authored-by: guanshihui] <guanshihui@baidu.com>
Motivation
本 PR 修复投机解码中 speculate_set_stop_value_multi_seqs 和 speculate_limit_thinking_content_length 两个 kernel 因 step_idx 语义变更引起的索引错误。
Modifications
speculate_set_stop_value_multi_seqs
1、修复 can_stop 判断:step_idx_now >= min_token_limit → step_idx_now + accept_num >= min_token_limit,因为 step_idx 不再包含本轮 tokens。
2、添加 pre_ids_end 检测:新增检测上一轮延迟的 stop_seq 是否在本轮 pre_ids 末尾完整匹配,适配 pre_ids[1] 布局(+1 偏移)。
3、修改主循环条件:accept_num - 1 → accept_num - 2,不检查最后一个位置,防止写入 eos 时越界,延迟到下一轮 pre_ids_end 检测。
4、修复跳过条件和 pre_ids_idx 计算:去除旧语义遗留的 -accept_num 偏移,适配 pre_ids[1] 布局(+1 偏移)。
5、修复 accept_tokens 索引:重新计算 offset 和 accept_tokens_idx,使逻辑更清晰。
6、修复匹配成功后的输出逻辑:accept_idx → accept_idx + 1,保留 stop_seq 所有 token,在其后追加 eos
speculate_limit_thinking_content_length
1、修复 current_base_step 计算:step_idx[bid] - original_accept_num + 1 → step_idx[bid] + 1,适配新 step_idx 语义。
2、移除 step_idx 回退逻辑:截断 accept_num 时不再修改 step_idx,该操作由 unified_update_model_status 负责。
3、step_idx 参数改为 const:该 kernel 只读取 step_idx,调用处移除 const_cast。
测试
更新 test_speculate_set_stop_value_multi_seqs.py,同步适配新 step_idx 语义下的索引和匹配逻辑。
Usage or Command
无新增接口,修复已有逻辑。可通过投机解码推理验证 stop sequences 截断行为及 thinking 长度限制是否正确。
Accuracy Tests
单元测试通过。
Checklist
pre-commitbefore commit.test_speculate_set_stop_value_multi_seqs.py。