Qwen3-Next-MTP Impl #239
Conversation
Pull request overview
This pull request adds support for Qwen3-Next-MTP (Multi-Token Prediction) in the atom inference framework, following a similar pattern to the existing DeepSeek-MTP implementation. The PR enables speculative decoding with the MTP method for Qwen3-Next models, allowing the model to predict multiple tokens ahead to improve inference throughput.
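As a rough orientation, the sketch below illustrates the general draft-and-verify loop that MTP-style speculative decoding relies on. It is purely conceptual; the helper names `draft_model.predict_next` and `target_model.verify` are hypothetical and not part of this PR's API.

```python
# Conceptual sketch of MTP-style speculative decoding (illustrative only).
# The MTP head drafts a few tokens ahead, the target model verifies them in a
# single pass, and the accepted prefix plus one "bonus" token is committed.
def speculative_step(target_model, draft_model, tokens, num_spec):
    # 1) Draft num_spec candidate tokens with the lightweight MTP head.
    drafted, ctx = [], list(tokens)
    for _ in range(num_spec):
        nxt = draft_model.predict_next(ctx)   # hypothetical helper
        drafted.append(nxt)
        ctx.append(nxt)

    # 2) Verify all drafted positions with the target model in one forward pass.
    verified = target_model.verify(tokens, drafted)  # hypothetical helper

    # 3) Keep drafted tokens while they match; the first mismatch contributes
    #    the target model's own (bonus) token and stops acceptance.
    accepted = []
    for d, v in zip(drafted, verified):
        if d != v:
            accepted.append(v)
            break
        accepted.append(d)
    return tokens + accepted
```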
Changes:
- Implements the Qwen3NextMTP model architecture with weight remapping logic
- Adds GDN (Gated Delta Net) attention support for speculative decoding with MTP
- Extends engine components to handle mamba-like state management required for Qwen3-Next architecture
Reviewed changes
Copilot reviewed 15 out of 15 changed files in this pull request and generated 17 comments.
| File | Description |
|---|---|
| atom/spec_decode/eagle.py | Registers Qwen3NextMTP model in the EAGLE speculative decode architecture dictionary |
| atom/models/qwen3_next_mtp.py | New file implementing the Qwen3NextMTP model with multi-token predictor and weight remapping |
| atom/models/qwen3_next.py | Moves mamba_v2_sharded_weight_loader locally, updates rope parameter handling, fixes attention return value |
| atom/model_ops/base_attention.py | Fixes return value in fake function for linear attention |
| atom/model_ops/attentions/gdn_attn.py | Implements full speculative decode metadata preparation and state index handling for MTP |
| atom/model_ops/attentions/aiter_attention.py | Adds MTP support with multi-query sequence handling in decode path |
| atom/model_ops/attention_mha.py | Updates paged attention to support multi-token queries with proper indexing |
| atom/model_ops/attention_gdn.py | Enables spec decode path and fixes return value |
| atom/model_loader/loader.py | Adds qwen3_next_mtp weight loading logic with conditional remapping |
| atom/model_engine/sequence.py | Adds mamba_enabled flag and mamba_block_table for state management |
| atom/model_engine/scheduler.py | Adds num_bonus tracking for accepted tokens in MTP decode |
| atom/model_engine/model_runner.py | Implements MTP status async tracking, fixes value head dimension bug, updates cache allocation |
| atom/model_engine/llm_engine.py | Adds mamba_enabled detection based on model type |
| atom/model_engine/block_manager.py | Implements separate mamba block allocation for state storage |
| atom/config.py | Adds qwen3_next to qwen3_next_mtp model type override when speculative config is present |
Pull request overview
Copilot reviewed 15 out of 15 changed files in this pull request and generated 8 comments.
```python
def prepare_state_indices(self, batch: ScheduledBatch, with_spec: bool = False):
    non_spec_state_indices = self.non_spec_state_indices_tensor.np
    spec_state_indices = self.spec_state_indices_tensor.np
    for idx, mamba_block_table in enumerate(batch.mamba_block_tables):
        non_spec_state_indices[idx] = 0
        spec_state_indices[idx] = 0

        if not with_spec:
            non_spec_state_indices[idx] = mamba_block_table[0]
        else:
            spec_state_indices[idx, : 1 + self.num_spec] = mamba_block_table[
                : 1 + self.num_spec
            ]
```
prepare_state_indices() iterates over batch.mamba_block_tables and writes state indices only for those entries. If batch.mamba_block_tables is shorter than the batch (or misaligned), the remaining rows keep stale values from previous steps. Consider iterating over the full batch (in req_ids order) and explicitly setting PAD/0 values for sequences without a mamba table so state indices are deterministic every step.
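A minimal sketch of the deterministic variant the comment suggests, assuming the companion fix that keeps `batch.mamba_block_tables` aligned 1:1 with the batch; every row is rewritten each step, and sequences without a mamba table get an explicit pad value. The `PAD_SLOT_ID` constant and the exact attribute names are assumptions, not the PR's code.

```python
PAD_SLOT_ID = 0  # assumed pad value; the real code may use a dedicated sentinel

def prepare_state_indices(self, batch, with_spec: bool = False):
    non_spec_state_indices = self.non_spec_state_indices_tensor.np
    spec_state_indices = self.spec_state_indices_tensor.np
    # Walk the full batch (aligned with req_ids) so every row is rewritten on
    # every step and never inherits stale indices from a previous batch.
    for idx, mamba_block_table in enumerate(batch.mamba_block_tables):
        non_spec_state_indices[idx] = PAD_SLOT_ID
        spec_state_indices[idx] = PAD_SLOT_ID
        if not mamba_block_table:
            continue  # sequence without a mamba table keeps the explicit pad
        if not with_spec:
            non_spec_state_indices[idx] = mamba_block_table[0]
        else:
            spec_state_indices[idx, : 1 + self.num_spec] = mamba_block_table[
                : 1 + self.num_spec
            ]
```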
```python
query_start_loc = attn_metadata.cu_seqlens_q
context_lens_tensor = torch.zeros((batch.total_seqs_num_prefill)).cuda()
```
In prepare_gdn_metadata(), context_lens_tensor is overwritten with torch.zeros(...).cuda(), which (1) hard-codes the default CUDA device and (2) discards the actual context lengths from attn_metadata.context_lens. This can break multi-GPU runs and makes has_initial_state always false. Use the builder's device (self.device) and preserve the real context lengths (or document/encode the intended behavior explicitly).
Suggested change:

```diff
-query_start_loc = attn_metadata.cu_seqlens_q
-context_lens_tensor = torch.zeros((batch.total_seqs_num_prefill)).cuda()
+if context_lens_tensor is not None:
+    context_lens_tensor = context_lens_tensor.to(self.device)
+query_start_loc = attn_metadata.cu_seqlens_q
```
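If the intent is to preserve the real per-sequence context lengths rather than zeroing them, a sketch along these lines would address both points in the comment. It assumes `attn_metadata.context_lens` carries the prefill context lengths, as the comment describes; the attribute names may not match the PR exactly.

```python
# Sketch only: reuse the real context lengths and place them on the builder's
# device instead of the default CUDA device.
query_start_loc = attn_metadata.cu_seqlens_q
if attn_metadata.context_lens is not None:
    context_lens_tensor = attn_metadata.context_lens[
        : batch.total_seqs_num_prefill
    ].to(self.device)
else:
    context_lens_tensor = torch.zeros(
        batch.total_seqs_num_prefill, dtype=torch.int32, device=self.device
    )
```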
Pull request overview
Copilot reviewed 15 out of 15 changed files in this pull request and generated 4 comments.
```diff
 vars_used = [
-    ("slot_mapping", bs),  # TODO: MTP support
+    ("slot_mapping", bs * max_seqlen_q),
     ("context_lens", bs),
     ("cu_seqlens_q", bs + 1),
     ("block_tables", bs),
```
Decode now treats slot_mapping as having bs * max_seqlen_q entries (see the updated vars_used). However, build_for_cudagraph_capture() in this file still slices slot_mapping as [:bs], which will be too short when max_seqlen_q > 1 (MTP) and can break cudagraph capture/replay. Update the capture path to slice slot_mapping consistently with bs * max_qlen.
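A small sketch of the consistent slicing the comment asks for in `build_for_cudagraph_capture()`; `use_spec_decode` and `num_spec` are taken from the spec-decode snippet below and should be treated as assumptions about the builder's attributes.

```python
# Sketch: slice slot_mapping by token count rather than sequence count so the
# captured graph also covers multi-token (MTP) queries.
max_qlen = (self.num_spec + 1) if self.use_spec_decode else 1
slot_mapping = var["slot_mapping"].gpu[: bs * max_qlen]
```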
```python
if self.use_spec_decode:
    gdn_metadata = GDNAttentionMetadata(
        num_prefills=0,
        num_prefill_tokens=0,
        num_decodes=0,
        num_decode_tokens=0,
        num_spec_decodes=bs,
        num_spec_decode_tokens=bs * (self.num_spec + 1),
        num_actual_tokens=bs * (self.num_spec + 1),
        has_initial_state=None,
        spec_query_start_loc=self.spec_query_start_loc[: bs + 1],
        non_spec_query_start_loc=None,
        spec_state_indices_tensor=self.spec_state_indices_tensor.gpu[:bs],
        non_spec_state_indices_tensor=None,
        spec_sequence_masks=self.spec_sequence_masks[:bs],
        spec_token_indx=self.spec_token_indx[: bs * (self.num_spec + 1)],
        non_spec_token_indx=self.non_spec_token_indx[:0],
        num_accepted_tokens=self.num_accepted_tokens[:bs],
```
In the spec-decode cudagraph capture path, the AttentionMetaData created just above still slices slot_mapping as var["slot_mapping"].gpu[:bs]. When spec decode uses max_qlen = num_spec+1, slot_mapping needs to cover bs * max_qlen tokens. Otherwise the captured graph will read an incorrectly-sized mapping. Adjust the capture metadata construction to use the same token-count sizing as the decode builder.
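Applied to the capture metadata construction above, the token-count sizing would look roughly like this; it is a sketch reusing names from the snippet, and `var` and the field name are assumptions about the surrounding code.

```python
# Sketch: size slot_mapping by the number of spec-decode tokens so it matches
# num_spec_decode_tokens / num_actual_tokens in the metadata above.
slot_mapping = var["slot_mapping"].gpu[: bs * (self.num_spec + 1)]
```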
```python
self.layers = torch.nn.ModuleList(
    Qwen3NextDecoderLayer(
        atom_config,
        layer_type="full_attention",
        prefix=f"{prefix}.layers.{idx}",
        layer_num=idx,
    )
    for idx in range(
        self.mtp_start_layer_idx, self.mtp_start_layer_idx + self.num_mtp_layers
    )
)
```
self.layers is a ModuleList, so the parameter names will be model.layers.0.*, model.layers.1.*, etc. However, the weight remapping in remap_mtp_weight_name()/loader expects layer indices to match the original HF layer ids (e.g. model.layers.<num_hidden_layers>.*). With the current ModuleList, load_model() will not find parameters like model.layers.28.* and weight loading will fail. Consider using a ModuleDict keyed by the exact layer index (similar to DeepSeekMultiTokenPredictor.layers) and indexing it via str(self.mtp_start_layer_idx + current_step_idx).
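A minimal sketch of the `ModuleDict` layout the comment suggests, mirroring the DeepSeek MTP pattern. The constructor arguments are copied from the snippet above; the runtime lookup at the end is an assumption about how the layer would be indexed.

```python
# Sketch: key MTP layers by their original HF layer index so parameter names
# line up with checkpoint weights such as model.layers.28.* .
self.layers = torch.nn.ModuleDict(
    {
        str(idx): Qwen3NextDecoderLayer(
            atom_config,
            layer_type="full_attention",
            prefix=f"{prefix}.layers.{idx}",
            layer_num=idx,
        )
        for idx in range(
            self.mtp_start_layer_idx,
            self.mtp_start_layer_idx + self.num_mtp_layers,
        )
    }
)

# Runtime lookup then uses the absolute layer index, e.g.:
# layer = self.layers[str(self.mtp_start_layer_idx + current_step_idx)]
```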
```python
self.mamba_block_tables = [
    seq.mamba_block_table for seq in seqs.values() if seq.mamba_block_table
]
```
mamba_block_tables filters out sequences with an empty mamba_block_table. Downstream code (e.g., GDN attention metadata builder) indexes per-sequence state buffers by the enumerate index of batch.mamba_block_tables, so filtering can desynchronize indices from req_ids and leave parts of the state-index buffers uninitialized/stale. Keep mamba_block_tables aligned 1:1 with seqs.values() (and handle empty tables explicitly) so per-sequence indexing remains correct.
Suggested change:

```diff
-self.mamba_block_tables = [
-    seq.mamba_block_table for seq in seqs.values() if seq.mamba_block_table
-]
+# Keep mamba_block_tables aligned 1:1 with seqs.values() so per-sequence
+# indexing (e.g., in downstream state buffers) remains consistent, even
+# when some sequences have empty or missing mamba_block_table entries.
+self.mamba_block_tables = [seq.mamba_block_table for seq in seqs.values()]
```
* maybe right acc for eager
* pass bonus tokens to next round
* format
* add compile back
* format
* remove unnecessary comments and print
* remove num draft token
* resolve comments
* black

Signed-off-by: ganyi <ygan@amd.com>
Motivation
This PR enables Qwen3-Next-MTP in atom.
Launching Script
Technical Details
Test Plan
gsm8k on LMEval
Test Result
cudagraph:
Acceptance rate: 92.87%
eager:
Acceptance rate: 93.16%
We noticed an accuracy gap between graph mode and eager mode and have identified that it is caused mainly by executing the drafter. Graph mode on qwen3-next-mtp can also reach 0.86 if we skip the drafter entirely during inference. We might fix this in the future.
Submission Checklist