[None][feat] Support Qwen3 next #7892
Status: Merged

Commits (12):
- `bee8ef9` feat: Support Qwen3-next (byshiue)
- `f3df5d8` remove debug codes (byshiue)
- `3875be3` split the prefill and decode attention for qwen3_next (byshiue)
- `d786bfa` add sglang attention implementation (byshiue)
- `04e9810` remove debug message (byshiue)
- `adb269f` fix bug of ifb case on qwen3-next (byshiue)
- `b2978ad` refine codes of qwen3-next (byshiue)
- `ff5ca33` run pre-run commit -a to organize codes (byshiue)
- `af2f974` resolve the dependency of latest transformers (byshiue)
- `dd4e80a` fix ci (nv-guomingz)
- `3a9b875` update code per coderabbitai suggestions (nv-guomingz)
- `5f9e254` address comments (nv-guomingz)
`tensorrt_llm/_torch/models/checkpoints/hf/qwen3_next_weight_mapper.py` (105 additions, 0 deletions)
```python
from typing import Union

import torch
from torch import nn

from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.models.checkpoints.hf.qwen2_moe_weight_mapper import \
    Qwen2MoeHfWeightMapper
from tensorrt_llm._torch.models.modeling_nemotron_h import split
from tensorrt_llm._torch.models.modeling_utils import register_mapper
from tensorrt_llm.models.modeling_utils import DecoderModelForCausalLM


@register_mapper("HF", "Qwen3NextForCausalLM")
class Qwen3NextHfWeightMapper(Qwen2MoeHfWeightMapper):

    def init_model_and_config(self, model: Union[nn.Module,
                                                 DecoderModelForCausalLM],
                              config: ModelConfig):
        super().init_model_and_config(model, config)
        self._num_kv_heads = model.config.num_key_value_heads if hasattr(
            model.config, 'num_key_value_heads'
        ) and model.config.num_key_value_heads is not None else model.config.num_attention_heads

    def should_skip_module(self, module_name: str) -> bool:
        if module_name.startswith("draft_model"):
            return True
        return super().should_skip_module(module_name)

    def _duplicate_kv_weights(self, module: nn.Module, new_name: str,
                              weights: dict):
        tensors_to_duplicate = ["weight", "bias"]
        if module.quant_config.quant_mode.has_nvfp4():
            tensors_to_duplicate.append("weight_scale")
        if module.quant_config.quant_mode.has_fp8_block_scales():
            tensors_to_duplicate.append("weight_scale_inv")

        if new_name in ['k_proj', 'v_proj']:
            num_kv_heads_list = [self._num_kv_heads] * len(weights) if isinstance(
                self._num_kv_heads, int) else self._num_kv_heads
            processed_weights = {
                k:
                self._duplicate_kv(weight=v[:],
                                   num_kv_heads=num_kv_heads_list[i],
                                   tensor_parallel_size=self._tp_size)
                if k in tensors_to_duplicate else v
                for i, (k, v) in enumerate(weights.items())
            }
            return processed_weights

        return weights
```
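The `_duplicate_kv_weights` path repeats each KV head so that, when the tensor-parallel size exceeds the number of KV heads, every rank still owns a copy of a head. A minimal pure-Python sketch of that duplication on a flat list of rows (the `duplicate_kv_heads` helper below is hypothetical, standing in for TRT-LLM's tensor-based `_duplicate_kv`):

```python
def duplicate_kv_heads(rows, num_kv_heads, tp_size):
    """Repeat each KV head's rows so tp_size ranks each own a head copy.

    `rows` stands in for dim 0 of a k_proj/v_proj weight, laid out as
    num_kv_heads consecutive blocks of head_dim rows each.
    """
    if tp_size <= num_kv_heads:
        # Enough heads already; each rank gets one or more distinct heads.
        return list(rows)
    reps = tp_size // num_kv_heads  # copies needed per head
    head = len(rows) // num_kv_heads  # rows per head
    out = []
    for h in range(num_kv_heads):
        # Place the copies of head h consecutively, so contiguous
        # per-rank slices each land on a full head.
        out.extend(rows[h * head:(h + 1) * head] * reps)
    return out
```

For example, with 2 KV heads of 2 rows each and `tp_size=4`, rows `[0, 1, 2, 3]` become `[0, 1, 0, 1, 2, 3, 2, 3]`: two adjacent copies of each head.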
```python
    def preprocess_weights(self, weights: dict) -> dict:
        config = self.config.pretrained_config
        tp_size = self.config.mapping.tp_size
        tp_rank = self.config.mapping.tp_rank

        linear_key_dim = config.linear_key_head_dim * config.linear_num_key_heads  # 16 * 128
        linear_value_dim = config.linear_value_head_dim * config.linear_num_value_heads  # 32 * 128

        new_weights = {}
        for name, _ in weights.items():
            key = name

            if "A_log" in key:
                w = split(weights[name], tp_size, tp_rank)
                w = w.to(torch.float32)
                new_weights[key] = w
            elif "dt_bias" in key:
                w = split(weights[name], tp_size, tp_rank)
                w = w.to(torch.float32)
                new_weights[key] = w
            elif "in_proj" in key:
                # in_proj is intentionally not split here, following the
                # reference implementation; the exact reason is still to
                # be confirmed.
                new_weights[key] = weights[name]
            elif "conv1d" in key:
                w = weights[name]
                # Drop dim 1 because conv1d weights are stored in a Linear.
                if "weight" in key:
                    w = w.squeeze(1)

                conv_q, conv_k, conv_v = torch.split(
                    w, [linear_key_dim, linear_key_dim, linear_value_dim],
                    dim=0)

                # Reorder the fused [q|k|v] layout into per-rank
                # [q_rank|k_rank|v_rank] blocks so each TP rank reads a
                # contiguous slice.
                w = []
                for rank in range(tp_size):
                    conv_q_rank = split(conv_q, tp_size, rank)
                    conv_k_rank = split(conv_k, tp_size, rank)
                    conv_v_rank = split(conv_v, tp_size, rank)
                    y = torch.concat([conv_q_rank, conv_k_rank, conv_v_rank])
                    w.append(y)
                w = torch.concat(w).contiguous()
                new_weights[key] = w
            else:
                new_weights[key] = weights[name]

        return new_weights
```
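The conv1d branch above reorders the fused `[q|k|v]` rows into per-rank `[q_r|k_r|v_r]` blocks. A pure-Python sketch of that repacking on a plain list (hypothetical `repack_qkv_for_tp` helper, standing in for the `torch.split`/`torch.concat` version):

```python
def repack_qkv_for_tp(rows, key_dim, value_dim, tp_size):
    """Reorder fused [q|k|v] rows into per-rank [q_r|k_r|v_r] blocks.

    `rows` stands in for dim 0 of the conv1d weight: key_dim rows of q,
    then key_dim rows of k, then value_dim rows of v.
    """
    q = rows[:key_dim]
    k = rows[key_dim:2 * key_dim]
    v = rows[2 * key_dim:2 * key_dim + value_dim]

    def shard(x, rank):
        # Even contiguous shard of x for the given rank.
        n = len(x) // tp_size
        return x[rank * n:(rank + 1) * n]

    out = []
    for rank in range(tp_size):
        # Each rank's slice of q, k, and v becomes one contiguous block.
        out.extend(shard(q, rank) + shard(k, rank) + shard(v, rank))
    return out
```

With `key_dim=4`, `value_dim=4`, `tp_size=2`, rows `0..11` become `[0, 1, 4, 5, 8, 9, 2, 3, 6, 7, 10, 11]`: rank 0's q/k/v slices first, then rank 1's, so a plain contiguous split by rank recovers each rank's fused weights.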