
Qwen3-Next-MTP Impl #239

Merged
valarLip merged 9 commits into main from ganyi/qwen3_next_mtp_impl
Feb 26, 2026

Conversation


@ganyi1996ppo ganyi1996ppo commented Feb 26, 2026

Motivation

This PR enables qwen3-next-mtp in atom.

Launching Script

```bash
python3 -m atom.entrypoints.openai_server \
  --model $MODEL \
  --gpu-memory-utilization 0.8 \
  --level 1 \
  -tp 4 \
  --server-port 8000 \
  --method "mtp" \
  --num-speculative-tokens 1
```
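
Once the server is up, a quick way to sanity-check it is to send a single request. This is a minimal sketch, assuming the server exposes an OpenAI-compatible `/v1/chat/completions` route on the port configured above; the route and payload fields are assumptions, not taken from this PR.

```python
# Minimal smoke test against the server launched above (route and payload
# follow the OpenAI chat-completions convention; adjust if atom's
# openai_server exposes different paths).
import requests

resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "Qwen3-Next",  # placeholder; use the value passed as $MODEL
        "messages": [{"role": "user", "content": "What is 17 * 23?"}],
        "max_tokens": 64,
    },
    timeout=120,
)
print(resp.json())
```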

Technical Details

Test Plan

gsm8k on LMEval
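
As a rough sketch of how such a run can be reproduced against the locally served endpoint, assuming lm-evaluation-harness with its `local-completions` backend (the backend name and `model_args` keys are assumptions, not commands taken from this PR):

```python
# Hedged sketch: score gsm8k through lm-evaluation-harness against a local
# OpenAI-compatible completions endpoint. Backend name and model_args keys
# may need adjusting for your lm_eval version.
import lm_eval

results = lm_eval.simple_evaluate(
    model="local-completions",
    model_args="model=Qwen3-Next,base_url=http://localhost:8000/v1/completions",
    tasks=["gsm8k"],
    num_fewshot=5,
)
print(results["results"]["gsm8k"])
```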

Test Result

cudagraph:

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.8415|±  |0.0101|
|     |       |strict-match    |     5|exact_match|↑  |0.8287|±  |0.0104|

Acceptance rate: 92.87%

eager:

|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|-----|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.8650|±  |0.0094|
|     |       |strict-match    |     5|exact_match|↑  |0.8415|±  |0.0101|

Acceptance rate: 93.16%

We noticed an accuracy gap between graph mode and eager mode and have identified that it is caused mainly by executing the drafter. Graph mode on qwen3-next-mtp also reaches 0.86 if we skip the drafter entirely during inference. We may fix this in the future.

Submission Checklist

Signed-off-by: ganyi <ygan@amd.com>
Copilot AI review requested due to automatic review settings February 26, 2026 07:11

Copilot AI left a comment


Pull request overview

This pull request adds support for Qwen3-Next-MTP (Multi-Token Prediction) in the atom inference framework, following a similar pattern to the existing DeepSeek-MTP implementation. The PR enables speculative decoding with the MTP method for Qwen3-Next models, allowing the model to predict multiple tokens ahead to improve inference throughput.

Changes:

  • Implements the Qwen3NextMTP model architecture with weight remapping logic
  • Adds GDN (Gated Delta Net) attention support for speculative decoding with MTP
  • Extends engine components to handle mamba-like state management required for Qwen3-Next architecture
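
To make the flow described above concrete, here is a toy, self-contained sketch of a single-draft MTP step: the target model commits one token, an MTP head drafts the next, and the draft is accepted only if the target agrees on the following forward pass. All modules, shapes, and names are stand-ins, not code from this PR.

```python
# Toy illustration of one MTP speculative-decoding step (num-speculative-tokens=1).
# Linear layers stand in for the target LM head and the MTP drafter head.
import torch

torch.manual_seed(0)
vocab, hidden = 32, 16
target_head = torch.nn.Linear(hidden, vocab)
mtp_head = torch.nn.Linear(hidden, vocab)

def greedy(logits: torch.Tensor) -> int:
    return int(logits.argmax(dim=-1))

h_t = torch.randn(hidden)              # hidden state after the current prefix
committed = greedy(target_head(h_t))   # token the target model commits to
draft = greedy(mtp_head(h_t))          # one extra token drafted by the MTP head

# Verification: the target runs once over [committed, draft]; the draft counts
# as accepted (a "bonus" token) only if the target produces the same token at
# that position, so output quality matches plain autoregressive decoding.
h_t1 = torch.randn(hidden)             # stand-in for the state after `committed`
accepted = greedy(target_head(h_t1)) == draft
print(f"committed={committed}, draft={draft}, accepted={accepted}")
```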

Reviewed changes

Copilot reviewed 15 out of 15 changed files in this pull request and generated 17 comments.

Summary per file:

| File | Description |
|------|-------------|
| atom/spec_decode/eagle.py | Registers Qwen3NextMTP model in the EAGLE speculative decode architecture dictionary |
| atom/models/qwen3_next_mtp.py | New file implementing the Qwen3NextMTP model with multi-token predictor and weight remapping |
| atom/models/qwen3_next.py | Moves mamba_v2_sharded_weight_loader locally, updates rope parameter handling, fixes attention return value |
| atom/model_ops/base_attention.py | Fixes return value in fake function for linear attention |
| atom/model_ops/attentions/gdn_attn.py | Implements full speculative decode metadata preparation and state index handling for MTP |
| atom/model_ops/attentions/aiter_attention.py | Adds MTP support with multi-query sequence handling in decode path |
| atom/model_ops/attention_mha.py | Updates paged attention to support multi-token queries with proper indexing |
| atom/model_ops/attention_gdn.py | Enables spec decode path and fixes return value |
| atom/model_loader/loader.py | Adds qwen3_next_mtp weight loading logic with conditional remapping |
| atom/model_engine/sequence.py | Adds mamba_enabled flag and mamba_block_table for state management |
| atom/model_engine/scheduler.py | Adds num_bonus tracking for accepted tokens in MTP decode |
| atom/model_engine/model_runner.py | Implements MTP status async tracking, fixes value head dimension bug, updates cache allocation |
| atom/model_engine/llm_engine.py | Adds mamba_enabled detection based on model type |
| atom/model_engine/block_manager.py | Implements separate mamba block allocation for state storage |
| atom/config.py | Adds qwen3_next to qwen3_next_mtp model type override when speculative config is present |


Signed-off-by: ganyi <ygan@amd.com>
Copilot AI review requested due to automatic review settings February 26, 2026 08:07

Copilot AI left a comment


Pull request overview

Copilot reviewed 15 out of 15 changed files in this pull request and generated 8 comments.



Comment on lines +133 to +146
```python
def prepare_state_indices(self, batch: ScheduledBatch, with_spec: bool = False):
    non_spec_state_indices = self.non_spec_state_indices_tensor.np
    spec_state_indices = self.spec_state_indices_tensor.np
    for idx, mamba_block_table in enumerate(batch.mamba_block_tables):
        non_spec_state_indices[idx] = 0
        spec_state_indices[idx] = 0

        if not with_spec:
            non_spec_state_indices[idx] = mamba_block_table[0]
        else:
            spec_state_indices[idx, : 1 + self.num_spec] = mamba_block_table[
                : 1 + self.num_spec
            ]
```


Copilot AI Feb 26, 2026


prepare_state_indices() iterates over batch.mamba_block_tables and writes state indices only for those entries. If batch.mamba_block_tables is shorter than the batch (or misaligned), the remaining rows keep stale values from previous steps. Consider iterating over the full batch (in req_ids order) and explicitly setting PAD/0 values for sequences without a mamba table so state indices are deterministic every step.

Comment on lines +170 to 171
```python
query_start_loc = attn_metadata.cu_seqlens_q
context_lens_tensor = torch.zeros((batch.total_seqs_num_prefill)).cuda()
```

Copilot AI Feb 26, 2026


In prepare_gdn_metadata(), context_lens_tensor is overwritten with torch.zeros(...).cuda(), which (1) hard-codes the default CUDA device and (2) discards the actual context lengths from attn_metadata.context_lens. This can break multi-GPU runs and makes has_initial_state always false. Use the builder's device (self.device) and preserve the real context lengths (or document/encode the intended behavior explicitly).

Suggested change
```diff
-query_start_loc = attn_metadata.cu_seqlens_q
-context_lens_tensor = torch.zeros((batch.total_seqs_num_prefill)).cuda()
+if context_lens_tensor is not None:
+    context_lens_tensor = context_lens_tensor.to(self.device)
+query_start_loc = attn_metadata.cu_seqlens_q
```

Signed-off-by: ganyi <ygan@amd.com>
Copilot AI review requested due to automatic review settings February 26, 2026 08:56

Copilot AI left a comment


Pull request overview

Copilot reviewed 15 out of 15 changed files in this pull request and generated 4 comments.



Comment on lines 221 to 225
```diff
 vars_used = [
-    ("slot_mapping", bs),  # TODO: MTP support
+    ("slot_mapping", bs * max_seqlen_q),
     ("context_lens", bs),
     ("cu_seqlens_q", bs + 1),
     ("block_tables", bs),
```

Copilot AI Feb 26, 2026


Decode now treats slot_mapping as having bs * max_seqlen_q entries (see the updated vars_used). However, build_for_cudagraph_capture() in this file still slices slot_mapping as [:bs], which will be too short when max_seqlen_q > 1 (MTP) and can break cudagraph capture/replay. Update the capture path to slice slot_mapping consistently with bs * max_qlen.

Comment on lines +348 to +365
```python
if self.use_spec_decode:
    gdn_metadata = GDNAttentionMetadata(
        num_prefills=0,
        num_prefill_tokens=0,
        num_decodes=0,
        num_decode_tokens=0,
        num_spec_decodes=bs,
        num_spec_decode_tokens=bs * (self.num_spec + 1),
        num_actual_tokens=bs * (self.num_spec + 1),
        has_initial_state=None,
        spec_query_start_loc=self.spec_query_start_loc[: bs + 1],
        non_spec_query_start_loc=None,
        spec_state_indices_tensor=self.spec_state_indices_tensor.gpu[:bs],
        non_spec_state_indices_tensor=None,
        spec_sequence_masks=self.spec_sequence_masks[:bs],
        spec_token_indx=self.spec_token_indx[: bs * (self.num_spec + 1)],
        non_spec_token_indx=self.non_spec_token_indx[:0],
        num_accepted_tokens=self.num_accepted_tokens[:bs],
```

Copilot AI Feb 26, 2026


In the spec-decode cudagraph capture path, the AttentionMetaData created just above still slices slot_mapping as var["slot_mapping"].gpu[:bs]. When spec decode uses max_qlen = num_spec+1, slot_mapping needs to cover bs * max_qlen tokens. Otherwise the captured graph will read an incorrectly-sized mapping. Adjust the capture metadata construction to use the same token-count sizing as the decode builder.

Comment on lines +48 to +58
```python
self.layers = torch.nn.ModuleList(
    Qwen3NextDecoderLayer(
        atom_config,
        layer_type="full_attention",
        prefix=f"{prefix}.layers.{idx}",
        layer_num=idx,
    )
    for idx in range(
        self.mtp_start_layer_idx, self.mtp_start_layer_idx + self.num_mtp_layers
    )
)
```

Copilot AI Feb 26, 2026


self.layers is a ModuleList, so the parameter names will be model.layers.0.*, model.layers.1.*, etc. However, the weight remapping in remap_mtp_weight_name()/loader expects layer indices to match the original HF layer ids (e.g. model.layers.<num_hidden_layers>.*). With the current ModuleList, load_model() will not find parameters like model.layers.28.* and weight loading will fail. Consider using a ModuleDict keyed by the exact layer index (similar to DeepSeekMultiTokenPredictor.layers) and indexing it via str(self.mtp_start_layer_idx + current_step_idx).
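
To illustrate the suggested pattern, here is a small self-contained sketch (layer modules and index values are placeholders, not code from this PR) showing how a ModuleDict keyed by the original HF layer index keeps parameter names like model.layers.28.* intact for weight loading:

```python
# Sketch of the ModuleDict pattern: keys are the original HF layer indices,
# so parameter names carry those indices and match checkpoint weight names.
import torch

mtp_start_layer_idx, num_mtp_layers = 28, 1  # placeholder values
layers = torch.nn.ModuleDict(
    {
        str(idx): torch.nn.Linear(8, 8)  # stand-in for Qwen3NextDecoderLayer
        for idx in range(mtp_start_layer_idx, mtp_start_layer_idx + num_mtp_layers)
    }
)
print([name for name, _ in layers.named_parameters()])  # ['28.weight', '28.bias']

current_step_idx = 0
layer = layers[str(mtp_start_layer_idx + current_step_idx)]  # lookup by string key
```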

Comment on lines +147 to +149
```python
self.mamba_block_tables = [
    seq.mamba_block_table for seq in seqs.values() if seq.mamba_block_table
]
```

Copilot AI Feb 26, 2026


mamba_block_tables filters out sequences with an empty mamba_block_table. Downstream code (e.g., GDN attention metadata builder) indexes per-sequence state buffers by the enumerate index of batch.mamba_block_tables, so filtering can desynchronize indices from req_ids and leave parts of the state-index buffers uninitialized/stale. Keep mamba_block_tables aligned 1:1 with seqs.values() (and handle empty tables explicitly) so per-sequence indexing remains correct.

Suggested change
```diff
-self.mamba_block_tables = [
-    seq.mamba_block_table for seq in seqs.values() if seq.mamba_block_table
-]
+# Keep mamba_block_tables aligned 1:1 with seqs.values() so per-sequence
+# indexing (e.g., in downstream state buffers) remains consistent, even
+# when some sequences have empty or missing mamba_block_table entries.
+self.mamba_block_tables = [seq.mamba_block_table for seq in seqs.values()]
```

@valarLip valarLip merged commit 634d706 into main Feb 26, 2026
16 checks passed
@valarLip valarLip deleted the ganyi/qwen3_next_mtp_impl branch February 26, 2026 14:04
Jasen2201 pushed a commit to Jasen2201/ATOM that referenced this pull request Apr 10, 2026
* maybe right acc for eager

Signed-off-by: ganyi <ygan@amd.com>

* pass bonus tokens to next round

Signed-off-by: ganyi <ygan@amd.com>

* format

Signed-off-by: ganyi <ygan@amd.com>

* add compile back

Signed-off-by: ganyi <ygan@amd.com>

* format

Signed-off-by: ganyi <ygan@amd.com>

* remove unnecessary comments and print

Signed-off-by: ganyi <ygan@amd.com>

* remove num draft token

Signed-off-by: ganyi <ygan@amd.com>

* resolve comments

Signed-off-by: ganyi <ygan@amd.com>

* black

Signed-off-by: ganyi <ygan@amd.com>

---------

Signed-off-by: ganyi <ygan@amd.com>
