mtp: combined-forward speculative decode beats plain on GB10 (+2.4 t/s) (stacked on #13) #14

Open: TrevorS wants to merge 1 commit into gb10-hbm-resident-model from gb10-mtp-combined-forward

TrevorS commented May 24, 2026

mtp: combined-forward speculative decode beats plain on GB10

Stacked on #13 (gb10-hbm-resident-model).

Summary

Makes --mtp faster than plain decode on DGX Spark / GB10 by replacing the canonical MTP draft+verify sequence (eval first_token, then run the verifier separately) with a single batched-N=2 forward over [first_token, drafts[0]]. The verifier reads both tokens' logits in one graph eval, so the cost amortizes.

Strict mode (--quality or DS4_MTP_STRICT=1) falls back to canonical decode2_exact for byte-equality with plain decode.
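
As a rough illustration of the accept logic described above, here is a minimal Python model of greedy speculative verification with a single batched forward over [first_token, draft]. It is a sketch under stated assumptions, not the code in this PR: `target_next`, `forward_batch`, `plain_decode` and `speculative_decode` are hypothetical names, the "model" is a toy deterministic function, and details such as rolling back the rejected draft's KV state are left out.

```python
from typing import Callable, List, Sequence, Tuple


def target_next(ctx: Sequence[int]) -> int:
    """Toy stand-in for the target model's greedy argmax (hypothetical)."""
    return (7 * sum(ctx) + 3 * len(ctx) + 1) % 11


def forward_batch(ctx: Sequence[int], tokens: Sequence[int]) -> List[int]:
    """One batched forward: returns the target argmax after each appended token."""
    work = list(ctx)
    out = []
    for t in tokens:
        work.append(t)
        out.append(target_next(work))
    return out


def plain_decode(prompt: Sequence[int], n: int) -> List[int]:
    """Reference: one target eval per generated token."""
    ctx = list(prompt)
    for _ in range(n):
        ctx.append(target_next(ctx))
    return ctx[len(prompt):]


def speculative_decode(
    prompt: Sequence[int], n: int, draft_fn: Callable[[List[int]], int]
) -> Tuple[List[int], int]:
    """Emit n tokens; returns (tokens, number of batched verifier forwards)."""
    ctx = list(prompt)
    first = target_next(ctx)  # bootstrap with one canonical eval
    out: List[int] = []
    forwards = 0
    while len(out) < n:
        draft = draft_fn(ctx + [first])  # drafter guesses the token after `first`
        after_first, after_draft = forward_batch(ctx, [first, draft])
        forwards += 1
        ctx.append(first)
        out.append(first)
        if draft == after_first:
            # Draft matches the target argmax: keep it, and reuse the second
            # logit row as the next committed token.
            ctx.append(draft)
            out.append(draft)
            first = after_draft
        else:
            # Draft rejected: only the first logit row is valid.
            first = after_first
    return out[:n], forwards


if __name__ == "__main__":
    prompt = [1, 2, 3]
    print(plain_decode(prompt, 8))
    print(speculative_decode(prompt, 8, target_next))      # perfect drafter
    print(speculative_decode(prompt, 8, lambda ctx: -1))   # drafter always wrong
```

In this model the emitted tokens are the same as plain greedy decode whatever the drafter does; only the number of verifier forwards changes, which is why the speedup depends on the accept rate.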

Speed — ds4-bench standard sweep (promessi_sposi.txt, gen=128, GB10)

ds4-bench now takes --mtp (see below), so this is the same harness CONTRIBUTING.md specifies, just with speculative decode enabled. Full CSVs: speed-bench/gb10.csv (plain), speed-bench/gb10_mtp.csv (--mtp).

ctx      plain (t/s)   --mtp (t/s)   Δ (t/s)
2048     14.24         16.13         +1.89
10240    14.04         15.56         +1.52
18432    13.88         14.82         +0.94
26624    13.59         14.77         +1.18
34816    12.95         14.91         +1.96
43008    12.79         13.79         +1.00
51200    12.57         12.95         +0.38
59392    12.32         12.91         +0.59

MTP is faster at every context (+0.4 to +2.0 t/s on this prompt). The margin tracks the speculative accept rate, which depends on how predictable the continuation is — prose like promessi_sposi sits on the lower end.
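
One way to read "the margin tracks the accept rate" is with a simple throughput model. The sketch below is an assumption-laden back-of-envelope, not something measured in this PR: it assumes at most one draft is accepted per verifier step, and `step_cost_ratio` (cost of one draft plus combined-forward step relative to one plain decode step) is a hypothetical parameter. The two measured rows are copied from the table above.

```python
def tokens_per_step(accept_rate: float) -> float:
    """Expected tokens emitted per verifier forward: 1 committed + accepted draft."""
    return 1.0 + accept_rate


def mtp_speedup(accept_rate: float, step_cost_ratio: float) -> float:
    """Throughput of --mtp relative to plain decode under this simple model."""
    return tokens_per_step(accept_rate) / step_cost_ratio


def breakeven_accept_rate(step_cost_ratio: float) -> float:
    """Accept rate at which --mtp merely matches plain decode."""
    return step_cost_ratio - 1.0


# Measured (plain, --mtp) t/s pairs from the sweep above.
measured = {2048: (14.24, 16.13), 59392: (12.32, 12.91)}

if __name__ == "__main__":
    for ctx, (plain, mtp) in measured.items():
        print(f"ctx={ctx}: measured speedup {mtp / plain:.3f}x")
    # Example with made-up inputs: 40% accept rate, step 1.25x as costly.
    print(f"model speedup: {mtp_speedup(0.40, 1.25):.3f}x")
```

Under this model the measured ratio alone cannot separate accept rate from step cost; it only shows that the product stays above break-even at every context in the sweep.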

Prompt-dependence (chat-style prompts accept more drafts)

Measured with ds4 -p ... -n 256 --temp 0 (plain vs default --mtp), not in the CSV:

Prompt                         plain (t/s)   --mtp (t/s)   Δ (t/s)
"knight" (short)               16.12         18.54         +2.42
long story prompt (~30k ctx)   13.46         17.54-19.91   +3.9 to +6.4

These are higher-accept-rate cases; the CSV table above is the conservative standard-prompt number.

What's in the PR

1. Small-N matmul kernel polish

  • Pair-fuse Q_A + KV_A in qkv_rms_fused decode (one weight load, two outputs).
  • Fuse head_rms_norm + rope_tail on Q.

2. Q8 share-weight batched matmul

  • matmul_q8_0_preq_batch_share_warp_kernel<N_TOK>: one Q8 weight row per warp, N_TOK F32 dot-products against N_TOK token activations.
  • Bit-equal to N=1 matmul_q8_0_preq_warp8_kernel at any block count: same per-lane stride, same warp_sum_f32, explicit __fmaf_rn locks the FMA contraction to the N=1 SASS. Dispatched at n_tok=2..4 under !DS4_MTP_STRICT.
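
The bit-equality argument can be illustrated with a CPU reference model. This is a hedged Python sketch of the arithmetic only, not the CUDA kernel: `dot_single`, `dot_shared` and `warp_sum` are hypothetical names, the 32-lane stride and tree reduction are assumptions about the layout, and Q8 dequantization is reduced to "integer weight times scale". The point it shows is that reading each weight element once and feeding N_TOK accumulators leaves every token's own operation order unchanged, so each result matches the single-token path exactly.

```python
import random
from typing import List, Sequence

LANES = 32  # assumed warp width


def warp_sum(lane_values: Sequence[float]) -> float:
    """Fixed-order tree reduction across lanes (same order on every call)."""
    vals = list(lane_values)
    step = len(vals) // 2
    while step >= 1:
        for i in range(step):
            vals[i] = vals[i] + vals[i + step]
        step //= 2
    return vals[0]


def dot_single(weights: Sequence[float], x: Sequence[float]) -> float:
    """N=1 path: each lane accumulates its strided slice, then warp-reduce."""
    lane_acc = [0.0] * LANES
    for i, w in enumerate(weights):
        lane_acc[i % LANES] += w * x[i]
    return warp_sum(lane_acc)


def dot_shared(weights: Sequence[float], xs: Sequence[Sequence[float]]) -> List[float]:
    """Share-weight path: one pass over the weight row, one accumulator set per token."""
    acc = [[0.0] * LANES for _ in xs]
    for i, w in enumerate(weights):  # weight element is read once
        lane = i % LANES
        for t, x in enumerate(xs):   # and reused for every token
            acc[t][lane] += w * x[i]
    return [warp_sum(a) for a in acc]


if __name__ == "__main__":
    rng = random.Random(0)
    scale = 0.0123  # stand-in for a Q8 block scale
    row = [rng.randint(-127, 127) * scale for _ in range(4 * LANES)]
    tokens = [[rng.uniform(-1.0, 1.0) for _ in row] for _ in range(2)]
    print(dot_shared(row, tokens) == [dot_single(row, x) for x in tokens])
```

Python never contracts a multiply and an add into an FMA, so the model is trivially stable on that front. In compiled CUDA the compiler may contract differently in two different kernels, which is presumably why the PR pins the operation with an explicit `__fmaf_rn`.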

3. Combined-forward verifier

  • metal_graph_eval_mtp_draft_n_from_hc: batched MTP-draft primitive.
  • ds4_session_eval_speculative_argmax_combined: single batched verifier forward over [first_token, drafts[0]], accept drafts matching target argmax, strict-mode fallback to decode2_exact.
  • combined_prev_hc bootstrap from canonical eval for the first iter.

4. Dispatch + default

  • K4 share-warp dispatch gate: replace the cuBLAS-cache-availability check (always-false on Spark, where every Q8 weight has a cached F16 copy) with a !DS4_MTP_STRICT gate, so the share-warp kernel fires at decode time.
  • mtp_draft_tokens default 1 → 2 (combined-forward needs ==2 to fire; the old default of 1 routed --mtp through canonical decode2 with no win).
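
The gate described above can be summarized as a small decision function. This is a sketch, not the dispatch code in the PR: `q8_matmul_path` and its return labels are hypothetical, and several branches are assumptions. The description only states that the share-warp kernel is dispatched at n_tok=2..4 when DS4_MTP_STRICT is not set and that DS4_CUDA_NO_Q8_SHARE_BATCH=1 is the opt-out; routing n_tok=1 to the N=1 warp kernel, and sending every other case to the cuBLAS path, are guesses made for illustration.

```python
from typing import Mapping


def q8_matmul_path(n_tok: int, env: Mapping[str, str]) -> str:
    """Pick a Q8 matmul path label for a batch of n_tok tokens (illustrative)."""
    strict = env.get("DS4_MTP_STRICT") == "1"
    opt_out = env.get("DS4_CUDA_NO_Q8_SHARE_BATCH") == "1"
    if n_tok == 1:
        return "warp8"        # assumption: plain decode keeps the N=1 kernel
    if 2 <= n_tok <= 4 and not strict and not opt_out:
        return "share_warp"   # stated in the description above
    return "cublas"           # strict per the commit message; other cases assumed


if __name__ == "__main__":
    for n in (1, 2, 4, 5):
        print(n, q8_matmul_path(n, {}))
    print(2, q8_matmul_path(2, {"DS4_MTP_STRICT": "1"}))
```

Taking the environment as a mapping argument keeps the function testable; a caller would pass `os.environ`.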

5. Bench tooling

  • ds4-bench gains --mtp FILE / --mtp-draft N, mirroring the CLI decode loop so the standard harness can measure real --mtp throughput. Logs the chosen decode path. speed-bench/gb10_mtp.csv generated with it.

Correctness

With DS4_MTP_STRICT=1 or --quality, combined-forward declines and decode takes the canonical decode2_exact path. Plain decode is byte-identical to upstream/main on the chat-formatted test prompts, and the ds4-bench plain sweep matches the no-MTP baseline at every context.

Tested

  • make clean && make cuda-spark — clean
  • make cpu — clean
  • ./ds4_test --long-context, --tool-call-quality, --server, --metal-kernels — OK
  • ds4-bench plain + --mtp full 2048→65536 sweeps (both CSVs included)
  • Two ds4_test checks fail identically on stock upstream/main (f91c12b) — not introduced here:
    • --logprob-vectors short_code_completion — the only divergence across all 4 steps is the case of the markdown code-fence language tag: the full-precision official API emits ```c, the IQ2XXS (2-bit) local model emits ```C; the generated code (return snprintf) is byte-identical. A near-tie the aggressive quant resolves to uppercase; reproduces on stock upstream/main.
    • --metal-tensor-equivalence — long-context only, and upstream's, not ours. Upstream's MoE routed-expert down-projection accumulates via float atomicAdd when n_tokens >= 128 (use_atomic_down / moe_down_expert_tile*); atomicAdd order is scheduling-dependent, so two runs of the same config drift at ulp scale and occasionally flip a greedy argmax at long ctx. DS4_CUDA_MOE_NO_ATOMIC_DOWN=1 (an upstream flag) makes the long cases bit-exact (rms=0), confirming the cause. Stock upstream and the no-MTP PR1 branch flake identically (~2/5 runs); combined-forward's decode path never uses the atomic kernel, so it doesn't worsen it.
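
The order-dependence behind the --metal-tensor-equivalence flake can be reproduced on the CPU. The Python sketch below is an illustration only, with deliberately exaggerated values: it emulates float32 accumulation by rounding through `struct`, sums the same four numbers in two orders, and shows how the different totals can flip an argmax against a fixed rival logit. In the real kernel the drift is at ulp scale and the ordering comes from GPU scheduling; `f32`, `sum_f32` and `argmax` are hypothetical helper names.

```python
import struct
from typing import List, Sequence


def f32(x: float) -> float:
    """Round a Python float to the nearest IEEE-754 single-precision value."""
    return struct.unpack("f", struct.pack("f", x))[0]


def sum_f32(values: Sequence[float]) -> float:
    """Accumulate in float32, in exactly the order given."""
    acc = 0.0
    for v in values:
        acc = f32(acc + f32(v))
    return acc


def argmax(logits: Sequence[float]) -> int:
    """Index of the largest logit (first one wins ties)."""
    return max(range(len(logits)), key=lambda i: logits[i])


# Same multiset of contributions, two arrival orders.
order_a: List[float] = [1e8, 1.0, -1e8, 1.0]
order_b: List[float] = [1e8, -1e8, 1.0, 1.0]

if __name__ == "__main__":
    rival_logit = 1.5
    for name, order in (("a", order_a), ("b", order_b)):
        total = sum_f32(order)
        print(name, total, "argmax =", argmax([total, rival_logit]))
```

In order_a the first 1.0 is absorbed because float32 spacing near 1e8 is 8, so the totals come out as 1.0 and 2.0, on opposite sides of the rival logit.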

Hardware: NVIDIA DGX Spark (GB10 / sm_121), driver 580.142, CUDA 13.0
Model: DeepSeek-V4-Flash-IQ2XXS-w2Q2K-AProjQ8-SExpQ8-OutQ8-chat-v2-imatrix.gguf
MTP: DeepSeek-V4-Flash-MTP-Q4K-Q8_0-F32.gguf

AGENT.md compliance

  • "Preserve correctness before speed": plain decode is byte-identical; strict mode preserves the byte-equal canonical path.
  • "Do not add permanent semantic variants behind flags": no new flag; --mtp-draft is pre-existing (only its default changes), and DS4_MTP_STRICT is pre-existing.
  • "Diagnostic switches are fine when they validate the one release path": the DS4_CUDA_NO_Q8_SHARE_BATCH=1 opt-out is preserved as the kill switch.

Makes `--mtp` faster than plain decode on DGX Spark / GB10 by replacing
the canonical MTP draft+verify sequence (eval first_token, then run the
verifier separately) with a single batched-N=2 forward over
[first_token, drafts[0]].  The verifier reads both tokens' logits in one
graph eval, so the cost amortizes.

Strict mode (--quality or DS4_MTP_STRICT=1) falls back to canonical
decode2_exact for byte-equality with plain decode.

Speed (ds4-bench standard sweep, promessi_sposi.txt, gen=128, GB10):

  ctx     plain    --mtp    delta
  2048    14.24    16.13    +1.89
  10240   14.04    15.56    +1.52
  18432   13.88    14.82    +0.94
  26624   13.59    14.77    +1.18
  34816   12.95    14.91    +1.96
  43008   12.79    13.79    +1.00
  51200   12.57    12.95    +0.38
  59392   12.32    12.91    +0.59

MTP is faster at every context; the margin tracks the speculative
accept rate, which varies with how predictable the continuation is
(prose like promessi_sposi sits on the lower end; chat-style prompts
accept more drafts and see +2 to +6 t/s).  Full sweeps for both paths
are in speed-bench/gb10.csv (plain) and speed-bench/gb10_mtp.csv (--mtp).

What's in this commit:

1. Small-N matmul kernel polish (cuda):
   - Pair-fuse Q_A + KV_A in qkv_rms_fused decode (one weight load, two
     outputs).
   - Fuse head_rms_norm + rope_tail on Q (decode + batched paths).

2. Q8 share-weight batched matmul (cuda):
   - matmul_q8_0_preq_batch_share_warp_kernel<N_TOK>: one Q8 weight row
     per warp, N_TOK F32 dot-products against N_TOK token activations.
     Bit-equal to N=1 matmul_q8_0_preq_warp8_kernel at any block count
     (same per-lane stride, same warp_sum_f32, explicit __fmaf_rn locks
     the FMA contraction to match the N=1 SASS).  Dispatched at
     n_tok=2..4 under !DS4_MTP_STRICT; cuBLAS Q8 path under strict.

3. Combined-forward verifier (mtp):
   - metal_graph_eval_mtp_draft_n_from_hc: batched MTP-draft primitive.
   - ds4_session_eval_speculative_argmax_combined: single batched
     verifier forward over [first_token, drafts[0]], accept drafts that
     match target argmax, strict-mode fallback to decode2_exact.
   - combined_prev_hc bootstrap from canonical eval for the first iter.

4. Dispatch + default (cli):
   - K4 share-warp dispatch gate: replace the cuBLAS-cache-availability
     check (always-false on Spark, where every Q8 weight has a cached
     F16 copy) with a !DS4_MTP_STRICT gate, so the share-warp kernel
     actually fires at decode time.
   - mtp_draft_tokens default 1 -> 2.  Combined-forward needs
     mtp_draft_tokens==2 to fire; the previous default of 1 routed
     --mtp through canonical decode2 with no measurable win.

5. Bench tooling (ds4-bench):
   - Add --mtp FILE / --mtp-draft N so ds4-bench can drive the
     speculative decode path (mirrors the CLI decode loop) and report
     real --mtp throughput.  Logs the chosen decode path.  Default
     --mtp-draft 2.  speed-bench/gb10_mtp.csv generated with it.

Plain decode is unchanged by this commit (byte-identical to upstream
on the chat-formatted test prompts; the ds4-bench plain sweep matches
the no-MTP baseline at every context).

Tested:
  - make clean && make cuda-spark      clean
  - make cpu                           clean
  - ./ds4_test --long-context, --tool-call-quality, --server,
    --metal-kernels                    OK
  - ds4-bench plain + --mtp full 2048..65536 sweeps (CSVs included)
  - Two ds4_test checks fail identically on stock upstream/main
    (f91c12b), not introduced here -- root-caused, both upstream:
      --logprob-vectors short_code_completion: the only divergence
        across the 4 steps is the code-fence tag case -- the
        full-precision official API emits ```c, the IQ2XXS (2-bit)
        local model emits ```C; the code (return snprintf) is
        byte-identical.  A quant near-tie, not a bug.
      --metal-tensor-equivalence: long-context only.  Upstream's MoE
        routed-expert down-projection accumulates via float atomicAdd
        when n_tokens >= 128 (use_atomic_down, moe_down_expert_tile*);
        atomicAdd order is scheduling-dependent, so two runs of the
        same config drift at ulp scale and occasionally flip a greedy
        argmax at long ctx.  Setting DS4_CUDA_MOE_NO_ATOMIC_DOWN=1
        (an upstream flag) makes it bit-exact, confirming the cause.
        Stock upstream and the no-MTP PR1 branch flake identically.

Hardware: NVIDIA DGX Spark (GB10 / sm_121), driver 580.142, CUDA 13.0
Model: DeepSeek-V4-Flash-IQ2XXS-w2Q2K-AProjQ8-SExpQ8-OutQ8-chat-v2-imatrix.gguf
MTP:   DeepSeek-V4-Flash-MTP-Q4K-Q8_0-F32.gguf