mtp: combined-forward speculative decode beats plain on GB10 (+2.4 t/s) (stacked on #13) #14
Open
TrevorS wants to merge 1 commit into
Makes `--mtp` faster than plain decode on DGX Spark / GB10 by replacing
the canonical MTP draft+verify sequence (eval first_token, then run the
verifier separately) with a single batched-N=2 forward over
[first_token, drafts[0]]. The verifier reads both tokens' logits in one
graph eval, so the cost amortizes.
Strict mode (--quality or DS4_MTP_STRICT=1) falls back to canonical
decode2_exact for byte-equality with plain decode.
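
The accept rule described above can be sketched on the host side as follows. This is a minimal illustration, not the PR's implementation: `spec_step`, `argmax_f32`, and `verify_combined` are hypothetical names, and the layout of the logits (two contiguous rows, one per batched token) is an assumption.

```c
#include <stddef.h>

/* Index of the largest logit; the first maximum wins on ties. */
static int argmax_f32(const float *logits, size_t n_vocab) {
    size_t best = 0;
    for (size_t i = 1; i < n_vocab; i++) {
        if (logits[i] > logits[best]) best = i;
    }
    return (int)best;
}

/* Outcome of one speculative step (hypothetical type for this sketch). */
typedef struct {
    int n_committed; /* 1 if the draft was rejected, 2 if it was accepted */
    int next_token;  /* token that seeds the next step */
} spec_step;

/*
 * `logits` holds two rows of n_vocab floats produced by ONE batched
 * forward over [first_token, draft] (assumed layout):
 *   row 0 predicts the token that follows first_token,
 *   row 1 predicts the token that follows draft.
 */
static spec_step verify_combined(const float *logits, size_t n_vocab, int draft) {
    spec_step out;
    int target = argmax_f32(logits, n_vocab);
    if (target == draft) {
        /* Draft matches the target argmax: keep it and reuse row 1. */
        out.n_committed = 2;
        out.next_token = argmax_f32(logits + n_vocab, n_vocab);
    } else {
        /* Mismatch: discard row 1 and continue from the target token. */
        out.n_committed = 1;
        out.next_token = target;
    }
    return out;
}
```

On a match the step yields two tokens for one verifier forward; on a mismatch the second row is wasted work, which is why the gain depends on the accept rate.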
Speed in t/s (ds4-bench standard sweep, promessi_sposi.txt, gen=128, GB10):

    ctx    plain   --mtp   delta
   2048    14.24   16.13   +1.89
  10240    14.04   15.56   +1.52
  18432    13.88   14.82   +0.94
  26624    13.59   14.77   +1.18
  34816    12.95   14.91   +1.96
  43008    12.79   13.79   +1.00
  51200    12.57   12.95   +0.38
  59392    12.32   12.91   +0.59

MTP is faster at every context; the margin tracks the speculative
accept rate, which varies with how predictable the continuation is
(prose like promessi_sposi sits on the lower end; chat-style prompts
accept more drafts and see +2 to +6 t/s). Full sweeps for both paths
are in speed-bench/gb10.csv (plain) and speed-bench/gb10_mtp.csv (--mtp).
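
To make the link between accept rate and margin concrete, here is a back-of-the-envelope model. It is an assumption for illustration, not something measured in this PR: each speculative step commits one token plus the single draft with probability `accept_rate`, and costs `step_cost_ratio` plain-decode steps (draft plus batched verify). All function names are hypothetical.

```c
/* Expected speedup over plain decode under the simple model above. */
static double spec_speedup(double accept_rate, double step_cost_ratio) {
    return (1.0 + accept_rate) / step_cost_ratio;
}

/* Smallest accept rate at which speculation breaks even (speedup == 1). */
static double breakeven_accept_rate(double step_cost_ratio) {
    return step_cost_ratio - 1.0;
}

/* Speedup implied by two measured throughputs, both in t/s. */
static double measured_speedup(double plain_tps, double mtp_tps) {
    return mtp_tps / plain_tps;
}
```

Under this model, a step that costs 1.25 plain steps needs an accept rate above 0.25 to win. For the 2048-context row, `measured_speedup(14.24, 16.13)` comes out at roughly 1.13.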
What's in this commit:
1. Small-N matmul kernel polish (cuda):
- Pair-fuse Q_A + KV_A in qkv_rms_fused decode (one weight load, two
outputs).
- Fuse head_rms_norm + rope_tail on Q (decode + batched paths).
2. Q8 share-weight batched matmul (cuda):
- matmul_q8_0_preq_batch_share_warp_kernel<N_TOK>: one Q8 weight row
per warp, N_TOK F32 dot-products against N_TOK token activations.
Bit-equal to N=1 matmul_q8_0_preq_warp8_kernel at any block count
(same per-lane stride, same warp_sum_f32, explicit __fmaf_rn locks
the FMA contraction to match the N=1 SASS). Dispatched at
n_tok=2..4 under !DS4_MTP_STRICT; cuBLAS Q8 path under strict.
3. Combined-forward verifier (mtp):
- metal_graph_eval_mtp_draft_n_from_hc: batched MTP-draft primitive.
- ds4_session_eval_speculative_argmax_combined: single batched
verifier forward over [first_token, drafts[0]], accept drafts that
match target argmax, strict-mode fallback to decode2_exact.
- combined_prev_hc bootstrap from canonical eval for the first iter.
4. Dispatch + default (cli):
- K4 share-warp dispatch gate: replace the cuBLAS-cache-availability
check (always-false on Spark, where every Q8 weight has a cached
F16 copy) with a !DS4_MTP_STRICT gate, so the share-warp kernel
actually fires at decode time.
- mtp_draft_tokens default 1 -> 2. Combined-forward needs
mtp_draft_tokens==2 to fire; the previous default of 1 routed
--mtp through canonical decode2 with no measurable win.
5. Bench tooling (ds4-bench):
- Add --mtp FILE / --mtp-draft N so ds4-bench can drive the
speculative decode path (mirrors the CLI decode loop) and report
real --mtp throughput. Logs the chosen decode path. Default
--mtp-draft 2. speed-bench/gb10_mtp.csv generated with it.
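
The share-weight idea from item 2 can be illustrated with a CPU reference. This is a sketch under stated assumptions, not the CUDA kernel: the block layout is simplified (an F32 scale, where ggml's Q8_0 stores F16), activations are kept in F32, and `q8_block`, `dot_q8_row`, and `dot_q8_row_shared` are hypothetical names. The point it shows is that each weight block is read once and reused for every token, while each token's accumulation order stays the same as in the N=1 path.

```c
#include <stddef.h>
#include <stdint.h>

#define QK 32 /* values per quantized block (assumed for this sketch) */

typedef struct {
    float  scale;  /* per-block scale; simplified to F32 here */
    int8_t q[QK];  /* quantized weights */
} q8_block;

/* N=1 reference: one quantized weight row against one activation vector. */
static float dot_q8_row(const q8_block *row, size_t n_blocks, const float *x) {
    float acc = 0.0f;
    for (size_t b = 0; b < n_blocks; b++) {
        float part = 0.0f;
        for (size_t i = 0; i < QK; i++)
            part += (float)row[b].q[i] * x[b * QK + i];
        acc += row[b].scale * part;
    }
    return acc;
}

/*
 * Shared-weight variant: x holds n_tok activation vectors back to back.
 * Each weight block is fetched once and applied to all tokens; per token,
 * the sequence of multiplies and adds is the same as in dot_q8_row.
 */
static void dot_q8_row_shared(const q8_block *row, size_t n_blocks,
                              const float *x, size_t n_tok, float *out) {
    const size_t n_cols = n_blocks * QK;
    for (size_t t = 0; t < n_tok; t++) out[t] = 0.0f;
    for (size_t b = 0; b < n_blocks; b++) {
        const q8_block *blk = &row[b]; /* single weight load per block */
        for (size_t t = 0; t < n_tok; t++) {
            const float *xt = x + t * n_cols + b * QK;
            float part = 0.0f;
            for (size_t i = 0; i < QK; i++)
                part += (float)blk->q[i] * xt[i];
            out[t] += blk->scale * part;
        }
    }
}
```

Keeping the per-token operation order identical is what makes bit-equality possible at all. A compiler remains free to fuse multiply-adds differently in the two functions, which is the kind of divergence the PR avoids on the GPU by pinning `__fmaf_rn` explicitly.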
Plain decode is unchanged by this commit (byte-identical to upstream
on the chat-formatted test prompts; the ds4-bench plain sweep matches
the no-MTP baseline at every context).
Tested:
- make clean && make cuda-spark: builds clean
- make cpu: builds clean
- ./ds4_test --long-context, --tool-call-quality, --server,
  --metal-kernels: OK
- ds4-bench plain + --mtp full 2048..65536 sweeps (CSVs included)
- Two ds4_test checks fail identically on stock upstream/main
  (f91c12b), not introduced here. Both are root-caused to upstream:
  - --logprob-vectors short_code_completion: the only divergence
    across the 4 steps is the case of the code-fence tag. The
    full-precision official API emits "```c", the IQ2XXS (2-bit)
    local model emits "```C"; the code (return snprintf) is
    byte-identical. A quant near-tie, not a bug.
  - --metal-tensor-equivalence: long-context only. Upstream's MoE
    routed-expert down-projection accumulates via float atomicAdd
    when n_tokens >= 128 (use_atomic_down, moe_down_expert_tile*);
    atomicAdd order is scheduling-dependent, so two runs of the
    same config drift at ulp scale and occasionally flip a greedy
    argmax at long ctx. Setting DS4_CUDA_MOE_NO_ATOMIC_DOWN=1
    (an upstream flag) makes it bit-exact, confirming the cause.
    Stock upstream and the no-MTP PR1 branch flake identically.
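
The order sensitivity behind the atomicAdd flake can be reproduced on the CPU with three floats. This is only an illustration of non-associative float addition; `sum_in_order` is a hypothetical helper and the values are picked to exaggerate an effect that shows up at ulp scale in the real kernel.

```c
/* Accumulate v[order[0]], v[order[1]], ... into a single float. */
static float sum_in_order(const float *v, const int *order, int n) {
    float acc = 0.0f;
    for (int i = 0; i < n; i++) acc += v[order[i]];
    return acc;
}
```

With `v = {1e8f, 1.0f, -1e8f}` and IEEE-754 single-precision evaluation, the order `{0, 2, 1}` cancels the large terms first and returns 1.0f, while `{0, 1, 2}` absorbs the 1.0f into 1e8f (whose spacing is 8) and returns 0.0f. A reduction whose order depends on thread scheduling can land on either result from run to run.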
Hardware: NVIDIA DGX Spark (GB10 / sm_121), driver 580.142, CUDA 13.0
Model: DeepSeek-V4-Flash-IQ2XXS-w2Q2K-AProjQ8-SExpQ8-OutQ8-chat-v2-imatrix.gguf
MTP: DeepSeek-V4-Flash-MTP-Q4K-Q8_0-F32.gguf
# mtp: combined-forward speculative decode beats plain on GB10

Stacked on #13 (`gb10-hbm-resident-model`).

## Summary

Makes `--mtp` faster than plain decode on DGX Spark / GB10 by replacing the canonical MTP draft+verify sequence (eval `first_token`, then run the verifier separately) with a single batched-N=2 forward over `[first_token, drafts[0]]`. The verifier reads both tokens' logits in one graph eval, so the cost amortizes.

Strict mode (`--quality` or `DS4_MTP_STRICT=1`) falls back to canonical `decode2_exact` for byte-equality with plain decode.

## Speed: `ds4-bench` standard sweep (`promessi_sposi.txt`, gen=128, GB10)

`ds4-bench` now takes `--mtp` (see below), so this is the same harness CONTRIBUTING.md specifies, just with speculative decode enabled. Full CSVs: `speed-bench/gb10.csv` (plain), `speed-bench/gb10_mtp.csv` (`--mtp`).

| ctx | plain (t/s) | `--mtp` (t/s) | delta (t/s) |
|------:|------:|------:|------:|
| 2048 | 14.24 | 16.13 | +1.89 |
| 10240 | 14.04 | 15.56 | +1.52 |
| 18432 | 13.88 | 14.82 | +0.94 |
| 26624 | 13.59 | 14.77 | +1.18 |
| 34816 | 12.95 | 14.91 | +1.96 |
| 43008 | 12.79 | 13.79 | +1.00 |
| 51200 | 12.57 | 12.95 | +0.38 |
| 59392 | 12.32 | 12.91 | +0.59 |

MTP is faster at every context (+0.4 to +2.0 t/s on this prompt). The margin tracks the speculative accept rate, which depends on how predictable the continuation is; prose like `promessi_sposi` sits on the lower end.

### Prompt-dependence (chat-style prompts accept more drafts)

Measured with `ds4 -p ... -n 256 --temp 0` (plain vs default `--mtp`), not in the CSV:

[Table: plain vs `--mtp` throughput by prompt, including "knight" (short); values not recoverable]

These are higher-accept-rate cases; the CSV table above is the conservative standard-prompt number.

## What's in the PR

### 1. Small-N matmul kernel polish

- Pair-fuse Q_A + KV_A in `qkv_rms_fused` decode (one weight load, two outputs).
- Fuse `head_rms_norm` + `rope_tail` on Q.

### 2. Q8 share-weight batched matmul

- `matmul_q8_0_preq_batch_share_warp_kernel<N_TOK>`: one Q8 weight row per warp, N_TOK F32 dot-products against N_TOK token activations.
- Bit-equal to N=1 `matmul_q8_0_preq_warp8_kernel` at any block count: same per-lane stride, same `warp_sum_f32`, explicit `__fmaf_rn` locks the FMA contraction to the N=1 SASS. Dispatched at `n_tok=2..4` under `!DS4_MTP_STRICT`.

### 3. Combined-forward verifier

- `metal_graph_eval_mtp_draft_n_from_hc`: batched MTP-draft primitive.
- `ds4_session_eval_speculative_argmax_combined`: single batched verifier forward over `[first_token, drafts[0]]`, accept drafts matching target argmax, strict-mode fallback to `decode2_exact`.
- `combined_prev_hc` bootstrap from canonical eval for the first iter.

### 4. Dispatch + default

- K4 share-warp dispatch gate: replace the cuBLAS-cache-availability check with a `!DS4_MTP_STRICT` gate, so the share-warp kernel fires at decode time.
- `mtp_draft_tokens` default 1 → 2 (combined-forward needs `==2` to fire; the old default of 1 routed `--mtp` through canonical `decode2` with no win).

### 5. Bench tooling

- `ds4-bench` gains `--mtp FILE` / `--mtp-draft N`, mirroring the CLI decode loop so the standard harness can measure real `--mtp` throughput. Logs the chosen decode path. `speed-bench/gb10_mtp.csv` generated with it.

## Correctness

`DS4_MTP_STRICT=1` / `--quality`: combined-forward declines, path is canonical `decode2_exact`. Plain decode byte-identical to `upstream/main` (chat-formatted test prompts). The `ds4-bench` plain sweep matches the no-MTP baseline at every context.

## Tested

- `make clean && make cuda-spark`: clean
- `make cpu`: clean
- `./ds4_test --long-context`, `--tool-call-quality`, `--server`, `--metal-kernels`: OK
- `ds4-bench` plain + `--mtp` full 2048→65536 sweeps (both CSVs included)
- Two `ds4_test` checks fail identically on stock `upstream/main` (`f91c12b`), not introduced here:
  - `--logprob-vectors short_code_completion`: the only divergence across all 4 steps is the case of the markdown code-fence language tag. The full-precision official API emits a lowercase `c` tag, the IQ2XXS (2-bit) local model emits an uppercase `C` tag; the generated code (`return snprintf`) is byte-identical. A near-tie the aggressive quant resolves to uppercase; reproduces on stock `upstream/main`.
  - `--metal-tensor-equivalence`: long-context only, and upstream's, not ours. Upstream's MoE routed-expert down-projection accumulates via float `atomicAdd` when `n_tokens >= 128` (`use_atomic_down` / `moe_down_expert_tile*`); `atomicAdd` order is scheduling-dependent, so two runs of the same config drift at ulp scale and occasionally flip a greedy argmax at long ctx. `DS4_CUDA_MOE_NO_ATOMIC_DOWN=1` (an upstream flag) makes the long cases bit-exact (rms=0), confirming the cause. Stock upstream and the no-MTP PR1 branch flake identically (~2/5 runs); combined-forward's decode path never uses the atomic kernel, so it doesn't worsen it.

Hardware: NVIDIA DGX Spark (GB10 / sm_121), driver 580.142, CUDA 13.0

Model: `DeepSeek-V4-Flash-IQ2XXS-w2Q2K-AProjQ8-SExpQ8-OutQ8-chat-v2-imatrix.gguf`

MTP: `DeepSeek-V4-Flash-MTP-Q4K-Q8_0-F32.gguf`

## AGENT.md compliance

- `--mtp-draft` is pre-existing (only its default changes), `DS4_MTP_STRICT` is pre-existing.
- `DS4_CUDA_NO_Q8_SHARE_BATCH=1` opt-out preserved as the kill switch.
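
How the flags and environment variables named in this PR combine can be summarized as two predicates. This is a hedged sketch of the gating as described in the text, not the project's code: `env_flag`, `use_share_warp`, and `use_combined_forward` are hypothetical names, and treating only the value `1` as "set" is an assumption.

```c
#include <stdlib.h>
#include <string.h>

/* True when the environment variable is set to exactly "1" (assumed rule). */
static int env_flag(const char *name) {
    const char *v = getenv(name);
    return v != NULL && strcmp(v, "1") == 0;
}

/*
 * Share-warp Q8 kernel: described as dispatched at n_tok = 2..4 outside
 * strict mode, with DS4_CUDA_NO_Q8_SHARE_BATCH=1 as the opt-out.
 */
static int use_share_warp(int n_tok, int strict, int opt_out) {
    return n_tok >= 2 && n_tok <= 4 && !strict && !opt_out;
}

/*
 * Combined-forward verifier: described as firing only with --mtp,
 * mtp_draft_tokens == 2, and strict mode off.
 */
static int use_combined_forward(int mtp_enabled, int draft_tokens, int strict) {
    return mtp_enabled && draft_tokens == 2 && !strict;
}
```

A caller might compute `strict` as `quality_flag || env_flag("DS4_MTP_STRICT")` and `opt_out` as `env_flag("DS4_CUDA_NO_Q8_SHARE_BATCH")`, then pass both into the predicates, which keeps the decision logic free of environment lookups and easy to test.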