8 changes: 8 additions & 0 deletions cpp/tensorrt_llm/kernels/fusedQKNormRopeKernel.cu
@@ -312,6 +312,14 @@ void launchFusedQKNormRope(void* qkv, int const num_tokens, int const num_heads_
base, position_ids, num_tokens, factor, low, high, attention_factor);
});
break;
case 256:
DISPATCH_INTERLEAVE(interleave, INTERLEAVE, {
fusedQKNormRopeKernel<256, INTERLEAVE><<<gridDim, blockDim, 0, stream>>>(
reinterpret_cast<__nv_bfloat16*>(qkv), num_heads_q, num_heads_k, num_heads_v, eps,
reinterpret_cast<__nv_bfloat16 const*>(q_weight), reinterpret_cast<__nv_bfloat16 const*>(k_weight),
base, position_ids, num_tokens, factor, low, high, attention_factor);
});
break;
default: TLLM_THROW("Unsupported head dimension for fusedQKNormRope: %d", head_dim);
}
}
14 changes: 12 additions & 2 deletions examples/models/core/qwen/README.md
@@ -21,13 +21,14 @@ This document shows how to build and run a [Qwen](https://huggingface.co/Qwen) m
- [Quick start](#quick-start)
- [Run a single inference](#run-a-single-inference)
- [Evaluation](#evaluation)
- [Model Quantization to FP4](#model-quantization-to-fp4)
- [Model Quantization](#model-quantization)
- [Benchmark](#benchmark)
- [Serving](#serving)
- [trtllm-serve](#trtllm-serve)
- [Disaggregated Serving](#disaggregated-serving)
- [Eagle3](#eagle3)
- [Dynamo](#dynamo)
- [Qwen3-Next](#qwen3-next)
- [Notes and Troubleshooting](#notes-and-troubleshooting)
- [Credits](#credits)

@@ -926,6 +927,15 @@ For further details, please refer to [speculative-decoding.md](../../../../docs/
NVIDIA Dynamo is a high-throughput, low-latency inference framework designed for serving generative AI and reasoning models in multi-node distributed environments.
Dynamo supports TensorRT LLM as one of its inference engines. For details on how to use TensorRT LLM with Dynamo, please refer to [LLM Deployment Examples using TensorRT-LLM](https://github.com/ai-dynamo/dynamo/blob/main/examples/tensorrt_llm/README.md).

## Qwen3-Next

Below is an example command to run the Qwen3-Next model:

```bash
mpirun -n 1 --allow-run-as-root --oversubscribe python3 examples/llm-api/quickstart_advanced.py --model_dir /Qwen3-Next-80B-A3B-Thinking --kv_cache_fraction 0.6 --disable_kv_cache_reuse --max_batch_size 1 --tp_size 4
```
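
If you prefer the Python LLM API over `quickstart_advanced.py`, the following is a minimal sketch along the same lines. It assumes the usual `tensorrt_llm` LLM-API surface (`LLM`, `SamplingParams`, `KvCacheConfig`); adjust the model path and parallelism to your setup.

```python
from tensorrt_llm import LLM, SamplingParams
from tensorrt_llm.llmapi import KvCacheConfig

# Mirrors the flags used above: 60% free-GPU-memory fraction for the KV cache,
# KV-cache block reuse disabled, tensor parallelism of 4.
llm = LLM(
    model="/Qwen3-Next-80B-A3B-Thinking",  # path to the downloaded HF checkpoint
    tensor_parallel_size=4,
    kv_cache_config=KvCacheConfig(free_gpu_memory_fraction=0.6,
                                  enable_block_reuse=False),
)

outputs = llm.generate(["Explain speculative decoding in one paragraph."],
                       SamplingParams(max_tokens=128))
print(outputs[0].outputs[0].text)
```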

## Notes and Troubleshooting

- **Model Directory:** Update `<YOUR_MODEL_DIR>` with the actual path where the model weights reside.
6 changes: 4 additions & 2 deletions tensorrt_llm/_torch/custom_ops/__init__.py
@@ -22,13 +22,15 @@
if IS_FLASHINFER_AVAILABLE:
from .flashinfer_custom_ops import (
flashinfer_apply_rope_with_cos_sin_cache_inplace,
flashinfer_fused_add_rmsnorm, flashinfer_rmsnorm,
flashinfer_silu_and_mul)
flashinfer_fused_add_rmsnorm, flashinfer_gemma_fused_add_rmsnorm,
flashinfer_gemma_rmsnorm, flashinfer_rmsnorm, flashinfer_silu_and_mul)
__all__ += [
'flashinfer_silu_and_mul',
'flashinfer_rmsnorm',
'flashinfer_fused_add_rmsnorm',
'flashinfer_apply_rope_with_cos_sin_cache_inplace',
'flashinfer_gemma_fused_add_rmsnorm',
'flashinfer_gemma_rmsnorm',
]

if IS_CUTLASS_DSL_AVAILABLE:
26 changes: 25 additions & 1 deletion tensorrt_llm/_torch/custom_ops/flashinfer_custom_ops.py
@@ -4,7 +4,8 @@

if IS_FLASHINFER_AVAILABLE:
from flashinfer.activation import silu_and_mul
from flashinfer.norm import fused_add_rmsnorm, rmsnorm
from flashinfer.norm import (fused_add_rmsnorm, gemma_fused_add_rmsnorm,
gemma_rmsnorm, rmsnorm)
from flashinfer.rope import apply_rope_with_cos_sin_cache_inplace

    # Wrap this in a custom op since flashinfer didn't wrap it properly and we want to avoid graph breaks between MLP layers for the user-buffer optimization
@@ -27,13 +28,36 @@ def _(input: torch.Tensor, weight: torch.Tensor,
eps: float) -> torch.Tensor:
return torch.empty_like(input)

@torch.library.custom_op("trtllm::flashinfer_gemma_rmsnorm",
mutates_args=())
def flashinfer_gemma_rmsnorm(input: torch.Tensor, weight: torch.Tensor,
eps: float) -> torch.Tensor:
return gemma_rmsnorm(input, weight, eps, enable_pdl=ENABLE_PDL)

@flashinfer_gemma_rmsnorm.register_fake
def _(input: torch.Tensor, weight: torch.Tensor,
eps: float) -> torch.Tensor:
return torch.empty_like(input)

@torch.library.custom_op("trtllm::flashinfer_fused_add_rmsnorm",
mutates_args=("input", "residual"))
def flashinfer_fused_add_rmsnorm(input: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor, eps: float) -> None:
fused_add_rmsnorm(input, residual, weight, eps, enable_pdl=ENABLE_PDL)

@torch.library.custom_op("trtllm::flashinfer_gemma_fused_add_rmsnorm",
mutates_args=("input", "residual"))
def flashinfer_gemma_fused_add_rmsnorm(input: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
eps: float) -> None:
gemma_fused_add_rmsnorm(input,
residual,
weight,
eps,
enable_pdl=ENABLE_PDL)

@torch.library.custom_op(
"trtllm::flashinfer_apply_rope_with_cos_sin_cache_inplace",
mutates_args=("query", "key"))
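
For reference, the gemma variants registered above follow Gemma-style RMSNorm, which scales by `1 + weight` rather than `weight`. Below is a small pure-PyTorch sketch of that behavior, an illustration rather than the flashinfer kernel itself; the `gemma_rmsnorm_ref` name is made up for this example.

```python
import torch


def gemma_rmsnorm_ref(x: torch.Tensor, weight: torch.Tensor,
                      eps: float = 1e-6) -> torch.Tensor:
    """Gemma-style RMSNorm: normalize in fp32, then scale by (1 + weight)."""
    x_fp32 = x.float()
    variance = x_fp32.pow(2).mean(-1, keepdim=True)
    x_normed = x_fp32 * torch.rsqrt(variance + eps)
    return (x_normed * (1.0 + weight.float())).to(x.dtype)


# The custom ops registered above are reachable through the torch.ops registry, e.g.:
# out = torch.ops.trtllm.flashinfer_gemma_rmsnorm(hidden_states, norm_weight, 1e-6)
```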
2 changes: 2 additions & 0 deletions tensorrt_llm/_torch/models/__init__.py
@@ -26,6 +26,7 @@
from .modeling_qwen2vl import Qwen2_5_VLModel, Qwen2VLModel
from .modeling_qwen3 import Qwen3ForCausalLM
from .modeling_qwen3_moe import Qwen3MoeForCausalLM
from .modeling_qwen3_next import Qwen3NextForCausalLM
from .modeling_qwen_moe import Qwen2MoeForCausalLM
from .modeling_seedoss import SeedOssForCausalLM
from .modeling_siglip import SiglipVisionModel
@@ -66,6 +67,7 @@
"Qwen2_5_VLModel",
"Qwen3ForCausalLM",
"Qwen3MoeForCausalLM",
"Qwen3NextForCausalLM",
"GptOssForCausalLM",
"SeedOssForCausalLM",
]
4 changes: 3 additions & 1 deletion tensorrt_llm/_torch/models/checkpoints/__init__.py
@@ -8,12 +8,14 @@
from .hf.qwen2_moe_weight_mapper import Qwen2MoeHfWeightMapper
from .hf.qwen2vl_weight_mapper import Qwen2VLHfWeightMapper
from .hf.qwen3_moe_weight_mapper import Qwen3MoeHfWeightMapper
from .hf.qwen3_next_weight_mapper import Qwen3NextHfWeightMapper
from .hf.weight_loader import HfWeightLoader
from .hf.weight_mapper import HfWeightMapper

__all__ = [
"HfConfigLoader", "HfWeightLoader", "HfWeightMapper",
"BaseCheckpointLoader", "HfCheckpointLoader", "NemotronHHfWeightMapper",
"Gemma3HfWeightMapper", "MixtralHfWeightMapper", "Llama4HfWeightMapper",
"Qwen2MoeHfWeightMapper", "Qwen3MoeHfWeightMapper", "Qwen2VLHfWeightMapper"
"Qwen2MoeHfWeightMapper", "Qwen3MoeHfWeightMapper", "Qwen2VLHfWeightMapper",
"Qwen3NextHfWeightMapper"
]
105 changes: 105 additions & 0 deletions tensorrt_llm/_torch/models/checkpoints/hf/qwen3_next_weight_mapper.py
@@ -0,0 +1,105 @@
from typing import Union

import torch
from torch import nn

from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.models.checkpoints.hf.qwen2_moe_weight_mapper import \
Qwen2MoeHfWeightMapper
from tensorrt_llm._torch.models.modeling_nemotron_h import split
from tensorrt_llm._torch.models.modeling_utils import register_mapper
from tensorrt_llm.models.modeling_utils import DecoderModelForCausalLM


@register_mapper("HF", "Qwen3NextForCausalLM")
class Qwen3NextHfWeightMapper(Qwen2MoeHfWeightMapper):

def init_model_and_config(self, model: Union[nn.Module,
DecoderModelForCausalLM],
config: ModelConfig):
super().init_model_and_config(model, config)
self._num_kv_heads = model.config.num_key_value_heads if hasattr(
model.config, 'num_key_value_heads'
) and model.config.num_key_value_heads is not None else model.config.num_attention_heads

def should_skip_module(self, module_name: str) -> bool:
if module_name.startswith("draft_model"):
return True
return super().should_skip_module(module_name)

def _duplicate_kv_weights(self, module: nn.Module, new_name: str,
weights: dict):
tensors_to_duplicate = ["weight", "bias"]
if module.quant_config.quant_mode.has_nvfp4():
tensors_to_duplicate.append("weight_scale")
if module.quant_config.quant_mode.has_fp8_block_scales():
tensors_to_duplicate.append("weight_scale_inv")

if new_name in ['k_proj', 'v_proj']:
num_kv_heads_list = [self._num_kv_heads
] * len(weights) if isinstance(
self._num_kv_heads,
int) else self._num_kv_heads
processed_weights = {
k:
self._duplicate_kv(weight=v[:],
num_kv_heads=num_kv_heads_list[i],
tensor_parallel_size=self._tp_size)
if k in tensors_to_duplicate else v
for i, (k, v) in enumerate(weights.items())
}
return processed_weights

return weights

def preprocess_weights(self, weights: dict) -> dict:
config = self.config.pretrained_config
tp_size = self.config.mapping.tp_size
tp_rank = self.config.mapping.tp_rank

        # Flattened channel counts of the linear-attention projections
        # (e.g. 16 key heads * 128 dims and 32 value heads * 128 dims).
        linear_key_dim = config.linear_key_head_dim * config.linear_num_key_heads
        linear_value_dim = config.linear_value_head_dim * config.linear_num_value_heads

new_weights = {}
for name, _ in weights.items():
key = name

if "A_log" in key:
w = split(weights[name], tp_size, tp_rank)
w = w.to(torch.float32)
new_weights[key] = w
elif "dt_bias" in key:
w = split(weights[name], tp_size, tp_rank)
w = w.to(torch.float32)
new_weights[key] = w
elif "in_proj" in key:
                # The reference implementation keeps in_proj unsplit across TP ranks;
                # the exact reason still needs to be confirmed.
new_weights[key] = weights[name]
elif "conv1d" in key:
w = weights[name]
# removing dim(1) because we are using Linear to store conv1d weights
if "weight" in key:
w = w.squeeze(1)

conv_q, conv_k, conv_v = torch.split(
w, [linear_key_dim, linear_key_dim, linear_value_dim],
dim=0)

w = []
for rank in range(tp_size):
conv_q_rank = split(conv_q, tp_size, rank)
conv_k_rank = split(conv_k, tp_size, rank)
conv_v_rank = split(conv_v, tp_size, rank)
y = torch.concat([conv_q_rank, conv_k_rank, conv_v_rank])
w.append(y)
w = torch.concat(w).contiguous()
new_weights[key] = w
else:
new_weights[key] = weights[name]

return new_weights
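
For intuition, here is a toy, self-contained illustration of the conv1d channel reshuffle performed in the `conv1d` branch above. The dimensions are made up for readability, and `torch.chunk` stands in for the repo's `split` helper; the real dims come from the HF config.

```python
import torch

tp_size = 2
linear_key_dim, linear_value_dim = 4, 8  # stand-ins for key_head_dim * num_key_heads, etc.

# Flattened conv1d channels laid out as [q | k | v].
w = torch.arange(2 * linear_key_dim + linear_value_dim).float()
conv_q, conv_k, conv_v = torch.split(
    w, [linear_key_dim, linear_key_dim, linear_value_dim], dim=0)

# Re-pack channels so each TP rank receives a contiguous [q_shard, k_shard, v_shard] block.
shards = []
for rank in range(tp_size):
    q_shard = conv_q.chunk(tp_size, dim=0)[rank]
    k_shard = conv_k.chunk(tp_size, dim=0)[rank]
    v_shard = conv_v.chunk(tp_size, dim=0)[rank]
    shards.append(torch.cat([q_shard, k_shard, v_shard]))
repacked = torch.cat(shards).contiguous()
print(repacked)  # channels grouped per rank instead of per projection
```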
10 changes: 8 additions & 2 deletions tensorrt_llm/_torch/models/modeling_qwen3.py
@@ -32,8 +32,12 @@ def __init__(
model_config: ModelConfig[Qwen3Config],
layer_idx: Optional[int] = None,
fuse_qk_norm_rope: bool = True,
attn_output_gate: bool = False,
use_gemma_rms_norm: bool = False,
):
config = model_config.pretrained_config
self.pretrained_config = config
self.attn_output_gate = attn_output_gate

if getattr(config, "rope_scaling", None) is not None:
if "type" in config.rope_scaling:
@@ -58,13 +62,15 @@
num_attention_heads=config.num_attention_heads,
num_key_value_heads=config.num_key_value_heads,
max_position_embeddings=config.max_position_embeddings,
bias=config.attention_bias,
bias=getattr(config, "attention_bias", None),
pos_embd_params=pos_embd_params,
fuse_qk_norm_rope=fuse_qk_norm_rope,
layer_idx=layer_idx,
dtype=config.torch_dtype,
dense_bias=config.attention_bias,
dense_bias=getattr(config, "attention_bias", None),
config=model_config,
attn_output_gate=self.attn_output_gate,
use_gemma_rms_norm=use_gemma_rms_norm,
)

