
Add speculative decoding support with MTP layers #3594

Merged
santhnm2 merged 114 commits into NVIDIA:main from santhnm2:spec_mamba
Mar 11, 2026

Conversation

@santhnm2 (Contributor) commented Feb 25, 2026

What does this PR do ?

This PR adds speculative decoding support for inference, using multi-token prediction (MTP) layers as the draft model.

How it works

Each generation step proceeds in phases:

  1. Input construction — For each active decode request, 1 + K tokens are fed into the model: the previously sampled token plus K speculative tokens from the MTP heads of the previous step. These are interleaved across requests with matching position IDs.
    (dynamic_context.py: update_requests)

  2. Forward pass — The model processes all tokens in one pass. The base decoder produces logits at every position (note: materialize_only_last_token_logits must be off). The MTP heads produce K additional sets of logits from their lightweight transformer/Mamba layers, cached in model._mtp_logits_cache. These are concatenated to produce a [1+K, seq_len, vocab_size] logit tensor.
    (text_generation_controller.py: _dynamic_step_forward_logits)

  3. Sampling & verification — Both base and MTP logits are sampled (grouped into temperature/top_k/top_p buckets for efficiency). Then a greedy token-matching verification determines how many speculative tokens to accept: a speculative token at position t+k is accepted iff the base model's output at position t+k-1 equals it. Acceptance is consecutive — once a mismatch occurs, all subsequent speculative tokens for that request are rejected (enforced via cummin).
    (text_generation_controller.py: _dynamic_step_sample_logits_and_verify_tokens)

  4. KV cache rewind — For rejected tokens, the KV cache is rolled back: block offsets are decremented, and if the rewind crosses a block boundary, the block is released back to the allocator. For Mamba/hybrid models, SSM recurrent state is restored from intermediate snapshots captured during the Triton kernel execution.
    (text_generation_controller.py: _rewind_kv_cache)

  5. Bookkeeping — Sequence lengths advance by accepted_count + 1 (not just 1). Finish conditions (EOS, max length) are checked, and the accepted + sampled tokens are appended to each request's output. New MTP-sampled tokens are staged for the next step.
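The interleaved input construction in step 1 can be sketched in plain Python. The request fields `last_token`, `spec_tokens`, and `seq_len` are hypothetical stand-ins for the real dynamic_context bookkeeping, not its actual API:

```python
def build_decode_inputs(requests, K):
    """Flatten each request's (1 + K)-token group into one interleaved
    batch with matching position IDs."""
    input_ids, position_ids = [], []
    for req in requests:
        # Previously sampled token followed by the K MTP draft tokens.
        input_ids.extend([req["last_token"]] + req["spec_tokens"])
        position_ids.extend(range(req["seq_len"], req["seq_len"] + 1 + K))
    return input_ids, position_ids
```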
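The consecutive-acceptance rule in step 3 can be sketched for a single request; this scalar loop stands in for the batched, cummin-based verification, which zeroes the acceptance mask from the first mismatch onward:

```python
def count_accepted(base_samples, spec_tokens):
    """Greedy consecutive verification for one request: speculative
    token k is accepted iff it equals the base model's sample at the
    preceding position AND every earlier speculative token was
    already accepted."""
    accepted = 0
    for k, tok in enumerate(spec_tokens):
        if base_samples[k] != tok:
            break  # first mismatch rejects this and all later tokens
        accepted += 1
    return accepted
```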

MTP head architecture

Each MultiTokenPredictionLayer (multi_token_prediction.py) takes the hidden states from the previous depth, concatenates them with the shifted input embeddings, maps the result back down through a learned projection (eh_proj), and runs it through a transformer block (or Mamba stack). This is repeated K times to produce predictions at positions t+1 through t+K. An mtp_use_repeated_layer option shares weights across all K depths.
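The depth recurrence above can be sketched abstractly. `eh_proj` and `block` are hypothetical callables standing in for the real MultiTokenPredictionLayer internals; reusing the same pair at every depth, as here, corresponds to the mtp_use_repeated_layer behavior:

```python
def mtp_predict(hidden, shifted_emb, eh_proj, block, K):
    """Unroll the K MTP depths: each depth combines the previous
    depth's hidden states with the shifted input embeddings via the
    learned projection, then applies its transformer/Mamba block."""
    per_depth_hidden = []
    for _ in range(K):
        hidden = block(eh_proj(hidden, shifted_emb))
        per_depth_hidden.append(hidden)  # feeds the depth-k output head
    return per_depth_hidden
```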

Mamba/hybrid SSM support

SSM models require special handling because they carry recurrent state that must support rollback:

  • Conv state: The circular buffer is enlarged by K slots, so no explicit save/restore is needed.
  • SSM state: The Triton SSM kernel (mamba_ssm.py) dumps intermediate state snapshots at every speculative step into a pre-allocated [num_layers, max_requests, K+1, ...] buffer — zero extra GPU-CPU sync. On rewind, the correct snapshot is indexed by accepted token count.
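The rewind lookup can be sketched with nested lists as a stand-in for the pre-allocated [num_layers, max_requests, K+1, ...] buffer; the snapshot layout assumed here (index 0 = state before any speculative token, index k = state after k speculative tokens) is an illustration, not the exact kernel contract:

```python
def rewind_ssm_state(snapshots, request_idx, accepted):
    """Select, for every layer, the SSM state snapshot that matches
    the number of accepted speculative tokens for one request."""
    return [layer[request_idx][accepted] for layer in snapshots]
```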

Contribution process

flowchart LR
    A[Pre-checks] --> B[PR Tests]
    subgraph Code Review/Approval
        C1[Expert Review] --> C2[Final Review]
    end
    B --> C1
    C2 --> D[Merge]

Pre-checks

  • I want this PR in a versioned release and have added the appropriate Milestone (e.g., Core 0.8)
  • I have added relevant unit tests
  • I have added relevant functional tests
  • I have added proper typing to my code Typing guidelines
  • I have added relevant documentation
  • I have run the autoformatter.sh on my PR

Code review

The following process is enforced via the CODEOWNERS file for changes into megatron/core. For changes outside of megatron/core, it is up to the PR author whether or not to tag the Final Reviewer team.

For MRs into `main` branch

Feel free to message or comment the @mcore-oncall to help accelerate your merge into main. The less complex your PR is, the faster it will be approved and merged!

(Step 1): Add PR label Expert Review

(Step 2): Collect the expert reviewers' reviews

  1. Attach the Expert Review label when your PR is ready for review.
  2. GitHub auto-assigns expert reviewers based on your changes. They will get notified and pick up your PR soon.

⚠️ Only proceed to the next step once all reviewers have approved, merge conflicts are resolved, and the CI is passing.
Final Review might get declined if these requirements are not fulfilled.

(Step 3): Final Review

  1. Add Final Review label
  2. GitHub auto-assigns final reviewers based on your changes. They will get notified and pick up your PR soon.

(Optional Step 4): Cherry-pick into release branch

If this PR also needs to be merged into core_r* release branches, after this PR has been merged, select Cherry-pick to open a new PR into the release branch.

For MRs into `dev` branch

The proposed review process for the `dev` branch is under active discussion.

MRs are mergeable after one approval by either eharper@nvidia.com or zijiey@nvidia.com.

Merging your PR

Any member of core-adlr and core-nemo will be able to merge your PR.

Shanmugam Ramasamy and others added 16 commits February 2, 2026 12:55
Signed-off-by: Keshav Santhanam <ksanthanam@nvidia.com>

copy-pr-bot bot commented Feb 25, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

Signed-off-by: Keshav Santhanam <ksanthanam@nvidia.com>

santhnm2 commented Mar 2, 2026

/ok to test 8e3710f

@svcnvidia-nemo-ci svcnvidia-nemo-ci added this to the Core 0.16 milestone Mar 2, 2026
santhnm2 added 2 commits March 1, 2026 22:33
Signed-off-by: Keshav Santhanam <ksanthanam@nvidia.com>

santhnm2 commented Mar 2, 2026

/ok to test 9727533

@santhnm2 santhnm2 changed the title Speculative decoding MTP Add speculative decoding support with MTP layers Mar 2, 2026
Signed-off-by: Keshav Santhanam <ksanthanam@nvidia.com>
@santhnm2

/ok to test 6a08b28

Signed-off-by: Keshav Santhanam <ksanthanam@nvidia.com>
@santhnm2

/ok to test 3296f19

Signed-off-by: Keshav Santhanam <ksanthanam@nvidia.com>
@santhnm2

/ok to test fade26c

Signed-off-by: Keshav Santhanam <ksanthanam@nvidia.com>
@santhnm2

/ok to test b578a6a

Signed-off-by: Keshav Santhanam <ksanthanam@nvidia.com>
@santhnm2

/ok to test 277dfba


copy-pr-bot bot commented Mar 11, 2026

/ok to test 277dfba

@santhnm2, there was an error processing your request: E2

See the following link for more information: https://docs.gha-runners.nvidia.com/cpr/e/2/

@santhnm2

/ok to test 2169c01

Signed-off-by: Keshav Santhanam <ksanthanam@nvidia.com>
@santhnm2

/ok to test 5ee472a

@santhnm2 santhnm2 added this pull request to the merge queue Mar 11, 2026
@svcnvidia-nemo-ci

🔄 Merge queue validation started!

You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/22976150513

Merged via the queue into NVIDIA:main with commit 8f539df Mar 11, 2026
120 checks passed
@santhnm2 santhnm2 deleted the spec_mamba branch March 11, 2026 22:31

Labels

Approved (all necessary approvals have been made), Run functional tests


8 participants