Skip to content

Conversation

byshiue
Copy link
Collaborator

@byshiue byshiue commented Sep 22, 2025

Summary by CodeRabbit

  • New Features

    • Added Qwen3-Next causal LM support, including weight loading and runtime integration.
    • Introduced attention output gating and optional Gemma-style RMSNorm.
    • Enabled Gemma RMSNorm variants when supported.
  • Performance

    • Added high-performance Flash Linear Attention and fused recurrent gating paths for faster inference.
    • Expanded fused QK-Norm+RoPE to support head dimension 256.
    • Improved weight-loading handling for specific layers.
    • Added KV cache management tailored for Qwen3-Next.
  • Documentation

    • Updated Qwen docs navigation and added Qwen3-Next quick start and usage guidance.

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

@byshiue
Copy link
Collaborator Author

byshiue commented Sep 22, 2025

/bot run

Copy link
Contributor

coderabbitai bot commented Sep 22, 2025

📝 Walkthrough

Walkthrough

Adds Qwen3-Next model integration with linear/full attention, Triton-based FLA kernels, Gemma RMSNorm variants, attention output gating, new HF weight mapper, and executor support. Extends custom ops/public exports, adjusts Q/KV kernel to support head_dim=256, updates docs, and tweaks weight-loading and KV cache creation paths.

Changes

Cohort / File(s) Summary
CUDA kernel update
cpp/tensorrt_llm/kernels/fusedQKNormRopeKernel.cu
Adds head_dim=256 dispatch in launchFusedQKNormRope; other dims unchanged.
Docs: Qwen
examples/models/core/qwen/README.md
Renames nav entry; adds Qwen3-Next quick start section (duplicated block present).
Torch custom ops (FlashInfer Gemma)
tensorrt_llm/_torch/custom_ops/__init__.py, tensorrt_llm/_torch/custom_ops/flashinfer_custom_ops.py
Exposes flashinfer Gemma RMSNorm and fused add+RMSNorm ops; registers Torch custom ops and fakes.
Model exports and weight mapper
tensorrt_llm/_torch/models/__init__.py, tensorrt_llm/_torch/models/checkpoints/__init__.py, tensorrt_llm/_torch/models/checkpoints/hf/qwen3_next_weight_mapper.py
Exports Qwen3NextForCausalLM and Qwen3NextHfWeightMapper; adds HF mapper with KV duplication, conv1d handling, TP preprocessing.
Qwen3 model updates
tensorrt_llm/_torch/models/modeling_qwen3.py
Adds attn_output_gate and use_gemma_rms_norm flags; safer attention_bias access; forwards flags to base.
Qwen3-Next model integration
tensorrt_llm/_torch/models/modeling_qwen3_next.py
Introduces Qwen3-Next stack: gated delta-net linear attention, MoE blocks, Triton fused QKVZBA and gating, decoder layers, model and CausalLM wrapper with weight loading.
Weight-loading tweak
tensorrt_llm/_torch/models/modeling_utils.py
Special-cases modules named "linear_attn.conv1d": squeeze dim 1 on weight before load_weights.
Attention and norms
tensorrt_llm/_torch/modules/attention.py, tensorrt_llm/_torch/modules/qk_norm_attention.py, tensorrt_llm/_torch/modules/rms_norm.py
Adds attention output gating; QK-Norm-RoPE supports Gemma RMSNorm and gating with fusion checks; RMSNorm gains Gemma mode (weight init and FlashInfer call paths).
FLA Triton ops (new)
tensorrt_llm/_torch/modules/fla/*
Adds utilities and kernels: cumsum, scaled dot-KKT, solve_tril, WY-fast recompute, chunked ops (o, delta h), fused recurrent (and sigmoid-gating) variants, L2 norm, gated LayerNorm/RMSNorm, op helpers, index utilities, and high-level wrappers.
Executor config and cache
tensorrt_llm/_torch/pyexecutor/_util.py, tensorrt_llm/_torch/pyexecutor/config_utils.py
Adds is_qwen3_next(config). Creates MambaHybridCacheManager branch for Qwen3-Next with layer masks and cache dtype/params validation.

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant U as User
  participant LM as Qwen3NextForCausalLM
  participant M as Qwen3NextModel
  participant L as DecoderLayer (linear/full)
  participant G as Gated Delta Net / Attention
  participant MOE as Sparse MoE Block
  participant N as RMSNorm/O-Proj

  U->>LM: forward(input_ids/embeds, attn_metadata, spec_metadata)
  LM->>M: forward(...)
  M->>M: embeddings + layer selection
  loop num_hidden_layers
    M->>L: forward(hidden_states, attn_metadata, spec_metadata)
    alt linear_attention
      L->>G: forward(hidden_states, attn_metadata, mamba_metadata)
      G-->>L: attn_out
    else full_attention
      L->>G: self-attention (QKV, gating/norm)
      G-->>L: attn_out
    end
    L->>MOE: route/process (if enabled)
    MOE-->>L: moe_out
    L->>N: post-attn norm/proj
    N-->>M: layer_out
  end
  M-->>LM: logits or hidden_states
  LM-->>U: output
Loading
sequenceDiagram
  autonumber
  participant HF as HF Weights
  participant WM as Qwen3NextHfWeightMapper
  participant MU as modeling_utils
  participant MD as Model/Modules

  HF-->>WM: provide state_dict + config
  WM->>WM: init_model_and_config (num_kv_heads)
  WM->>WM: preprocess_weights (A_log, dt_bias, conv1d split/concat)
  WM->>WM: duplicate_kv (per TP, quantization-aware)
  WM-->>MU: mapped weights
  MU->>MD: traverse modules
  alt module has load_weights and name includes "linear_attn.conv1d"
    MU->>MD: squeeze(dim=1) then load_weights
  else
    MU->>MD: load_weights
  end
Loading
sequenceDiagram
  autonumber
  participant C as Config
  participant CU as config_utils.is_qwen3_next
  participant EX as _pyexecutor._util
  participant CM as CacheManager

  C-->>CU: query linear_key_head_dim
  CU-->>EX: True/False
  alt True
    EX->>EX: derive mamba/layer masks
    EX->>CM: create MambaHybridCacheManager (validate beam/connector)
  else False
    EX->>CM: other cache manager paths
  end
  CM-->>EX: manager instance
Loading
sequenceDiagram
  autonumber
  participant C as Config
  participant A as Attention
  participant O as Output Gate

  C-->>A: attn_output_gate flag
  A->>A: compute attn_out
  opt attn_output_gate
    A->>O: sigmoid(gate) * attn_out
    O-->>A: gated_out
  end
  A-->>A: o_proj(gated_or_plain)
Loading

Estimated code review effort

🎯 5 (Critical) | ⏱️ ~120–180 minutes

Pre-merge checks and finishing touches

❌ Failed checks (2 warnings)
Check name Status Explanation Resolution
Description Check ⚠️ Warning The PR description only includes the checklist and omits the required “Description” and “Test Coverage” sections, so it does not meet the repository’s template requirements. Please add a “Description” section that explains the issue and the implemented solution, and include a “Test Coverage” section listing relevant tests to satisfy the repository’s PR template.
Docstring Coverage ⚠️ Warning Docstring coverage is 0.00% which is insufficient. The required threshold is 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (1 passed)
Check name Status Explanation
Title Check ✅ Passed The title “[None][feat] Support Qwen3 next” correctly follows the repository’s naming convention and succinctly conveys the primary change of adding support for the Qwen3Next model.
✨ Finishing touches
  • 📝 Generate Docstrings
🧪 Generate unit tests
  • Create PR with unit tests
  • Post copyable unit tests in a comment

📜 Recent review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 7939a3e and 473b474.

📒 Files selected for processing (1)
  • cpp/tensorrt_llm/kernels/fusedQKNormRopeKernel.cu (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • cpp/tensorrt_llm/kernels/fusedQKNormRopeKernel.cu

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 43

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tensorrt_llm/_torch/models/checkpoints/__init__.py (1)

1-1: Missing NVIDIA Apache-2.0 header (2025) at file top

This repo requires the NVIDIA Apache-2.0 header on all source files. Please add it.

Apply:

+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from .base_checkpoint_loader import BaseCheckpointLoader
 from .hf.checkpoint_loader import HfCheckpointLoader
 from .hf.config_loader import HfConfigLoader
 from .hf.gemma3_weight_mapper import Gemma3HfWeightMapper
 from .hf.llama4_weight_mapper import Llama4HfWeightMapper
 from .hf.mixtral_weight_mapper import MixtralHfWeightMapper
 from .hf.nemotron_h_weight_mapper import NemotronHHfWeightMapper
 from .hf.qwen2_moe_weight_mapper import Qwen2MoeHfWeightMapper
 from .hf.qwen3_moe_weight_mapper import Qwen3MoeHfWeightMapper
 from .hf.qwen3_next_weight_mapper import Qwen3NextHfWeightMapper
🧹 Nitpick comments (57)
tensorrt_llm/_torch/modules/fla/l2norm.py (5)

24-44: Fix misleading comment and align store-cast with the other kernel.

  • The comment says “mean and variance” but the code computes squared L2 norm.
  • Optional: mirror the explicit cast used in l2norm_fwd_kernel to make the intended destination dtype clear (Triton will auto-cast, but being explicit improves readability).
-    # Compute mean and variance
+    # Compute squared L2 norm (sum of squares) and normalize
@@
-    b_y = b_x * b_rstd
-    tl.store(y + cols, b_y, mask=mask)
+    b_y = b_x * b_rstd
+    # Match the explicit cast pattern used in the 2D kernel for clarity
+    tl.store(y + cols, b_y.to(tl.float32), mask=mask)

If you prefer exact destination dtype, pass it as a constexpr/meta or derive from a block pointer like in the 2D kernel. Otherwise, keep auto-cast and only change the comment.


55-71: Remove unused NB meta-parameter and its call-site.

NB is computed/passed but never used in the kernel, tripping linters and confusing readers.

-def l2norm_fwd_kernel(
+def l2norm_fwd_kernel(
     x,
     y,
     eps,
-    NB: tl.constexpr,
     T: tl.constexpr,
     D: tl.constexpr,
     BT: tl.constexpr,
     BD: tl.constexpr,
 ):
@@
-    if D <= 512:
-        NB = triton.cdiv(T, 2048)
+    if D <= 512:
@@
         l2norm_fwd_kernel[grid](
             x,
             y,
             eps,
-            NB=NB,
             T=T,
             D=D,
             BD=BD,
             BT=16,
             num_warps=8,
             num_stages=3,
         )

Also applies to: 93-111


74-93: Guard max feature size check with a clearer error and type info.

The TRY003 lint likely flags the long message; also, including the resolved byte budget and dtype helps debugging.

-    if D > BD:
-        raise RuntimeError("This layer doesn't support feature dim >= 64KB.")
+    if D > BD:
+        raise RuntimeError(
+            f"L2Norm fused kernel supports at most ~64KB per feature; "
+            f"D={D}, dtype={x.dtype}, element_size={x.element_size()}B."
+        )

127-131: Silence linter for unused ctx or use it.

Ruff flags ctx unused in forward. Either delete it or bind to underscore.

-    def forward(ctx, x, eps=1e-6, output_dtype=None):
+    def forward(ctx, x, eps=1e-6, output_dtype=None):
+        # ctx is unused because we don't implement backward for the Triton path
+        del ctx
         return l2norm_fwd(x, eps, output_dtype)

142-152: Add minimal docstrings for public API.

Document behavior, shapes, and dtype handling for l2norm() and L2Norm to align with repo guidelines.

-def l2norm(x: torch.Tensor,
-           eps: float = 1e-6,
-           output_dtype: Optional[torch.dtype] = None) -> torch.Tensor:
+def l2norm(x: torch.Tensor,
+           eps: float = 1e-6,
+           output_dtype: Optional[torch.dtype] = None) -> torch.Tensor:
+    """
+    L2-normalize the last dimension of x.
+
+    Args:
+        x: [..., D] tensor. If requires_grad is True, uses a Torch eager path
+           for autograd; otherwise uses a fused Triton kernel.
+        eps: Small epsilon added inside the sqrt for numerical stability.
+        output_dtype: Optional output dtype; defaults to x.dtype.
+    Returns:
+        Tensor of same shape as x with last-dim L2 normalization applied.
+    """
     return L2NormFunction.apply(x, eps, output_dtype)
@@
 class L2Norm(nn.Module):
 
     def __init__(self,
                  eps: float = 1e-6,
                  output_dtype: Optional[torch.dtype] = None):
         super().__init__()
+        """
+        Module wrapper for l2norm().
+
+        Args:
+            eps: Epsilon for numerical stability.
+            output_dtype: Optional output dtype override.
+        """
         self.eps = eps
         self.output_dtype = output_dtype
tensorrt_llm/_torch/modules/fla/layernorm_gated.py (5)

28-31: Remove stray no-op statement.

x.shape[-1] does nothing and can confuse readers.

-    x.shape[-1]

118-140: Confirm device context handling; ensure compatibility across PyTorch backends.

torch.get_device_module(x.device).device(x.device.index) may not exist or vary across PT versions/backends. Since wrappers now route CPU to refs, Triton path should only run on CUDA. Please either:

  • switch to with torch.cuda.device(x.device.index):, or
  • keep current call but verify on supported PT versions and backends (CUDA/HIP) in CI.

Also applies to: 161-180


251-292: Public API docstrings (Google-style) for LayerNorm.

Add a brief Google-style docstring with args/returns to meet repo guidelines.


294-332: Class name collision risk with existing RMSNorm.

There is already tensorrt_llm._torch.modules.rms_norm.RMSNorm. This new RMSNorm in a different module can cause import ambiguity. Prefer explicit import sites (fully-qualified) or rename this class (e.g., FLARMSNorm) to avoid collisions.


10-15: Import style consistency for rearrange.

Repo has a local tensorrt_llm.functional.rearrange. Using einops is fine, but mixing both can be confusing. Consider importing with a qualified alias (from einops import rearrange as einops_rearrange) or standardizing on one.

tensorrt_llm/_torch/modules/fla/fused_sigmoid_gating_recurrent.py (3)

95-114: Minor: numerics for softplus and sigmoid

Softplus branch uses a threshold on beta_x which is fine. Consider precomputing inv_beta = 1.0/softplus_beta and using it to avoid repeated divisions; also clamp softplus_beta > 0. Sigmoid can overflow for very negative b_b—optional: use 1.0 / (1.0 + tl.exp(-tl.clip(b_b, -20, 20))).

Apply this micro‑diff:

-        beta_x = softplus_beta * x
+        beta_x = softplus_beta * x
+        inv_beta = 1.0 / softplus_beta
@@
-            (1.0 / softplus_beta) * tl.log(1.0 + tl.exp(beta_x)),
+            inv_beta * tl.log(1.0 + tl.exp(beta_x)),
             x,
         )
-        b_g = -tl.exp(b_A_log) * softplus_x
+        b_g = -tl.exp(b_A_log) * softplus_x
@@
-        b_beta = 1.0 / (1.0 + tl.exp(-b_b))
+        b_beta = 1.0 / (1.0 + tl.exp(-tl.maximum(tl.minimum(b_b, 20.0), -20.0)))

116-121: Minor: avoid redundant sqrt for L2 norm across masked lanes

When K is not a power of two, masked lanes are zeroed which is fine. If you want stricter numerics, sum only valid lanes using mask_k in the reduction.

Example:

-            b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q)) + 1e-6)
-            b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k)) + 1e-6)
+            q_norm = tl.sqrt(tl.sum((b_q * b_q))) + 1e-6
+            k_norm = tl.sqrt(tl.sum((b_k * b_k))) + 1e-6
+            b_q = b_q / q_norm
+            b_k = b_k / k_norm

180-183: Kernel tile config may underutilize the GPU

num_warps=1, num_stages=3 and BV<=8 are conservative. Consider autotuning (or exposing meta-params) for K=256, V>=16 which is common in Qwen3Next.

Happy to sketch a simple Triton autotune decorator here if you want.

examples/models/core/qwen/README.md (2)

930-939: Clarify acronyms and tighten language.

Spell out IFB and BS to avoid confusion; e.g., “In‑Flight Batching (IFB) is not supported yet; use batch_size=1.”


860-875: YAML filename mismatch breaks disaggregated serve command.

You write disagg-config.yml but run with -c disagg-config.yaml. Align names:

-trtllm-serve disaggregated -c disagg-config.yaml
+trtllm-serve disaggregated -c disagg-config.yml
tensorrt_llm/_torch/models/checkpoints/hf/qwen3_next_weight_mapper.py (2)

25-29: Make draft model skip robust to qualified names.

startswith("draft_model") misses nested names (e.g., model.draft_model.*). Match by suffix or substring.

-        if module_name.startswith("draft_model"):
+        last = module_name.split(".")[-1]
+        if last == "draft_model" or ".draft_model" in module_name:
             return True

64-66: Assume config fields exist; add fallback or assert.

Accessing linear_* fields will fail if absent. Add asserts or defaults.

tensorrt_llm/_torch/models/checkpoints/hf/qwen2_moe_weight_mapper.py (1)

27-33: Remove or ticket the commented TODO.

Prefer an issue or a clear TODO with owner/context; avoid dead commented code.

tensorrt_llm/_torch/models/modeling_utils.py (1)

958-961: Harden conv1d squeeze special‑case.

Guard presence/shape before squeezing to avoid key/shape errors on already‑squeezed tensors.

-                    if "linear_attn.conv1d" in name:
-                        module_weights['weight'] = module_weights[
-                            'weight'].squeeze(dim=1)
+                    if "linear_attn.conv1d" in name and "weight" in module_weights:
+                        w = module_weights["weight"]
+                        if w.dim() >= 3 and w.shape[1] == 1:
+                            module_weights["weight"] = w.squeeze(dim=1)
tensorrt_llm/_torch/modules/attention.py (1)

515-525: Avoid extra split/concat; compute gate without rebuilding qkv.

Minor perf nit: we can split once and avoid rebuilding qkv. Optional:

-        if self.attn_output_gate:
-            q_gate, k, v = qkv.split(
-                [self.q_size * 2, self.kv_size, self.kv_size], dim=-1)
+        if self.attn_output_gate:
+            q_gate, k, v = qkv.split([self.q_size * 2, self.kv_size, self.kv_size], dim=-1)
             orig_shape = q_gate.shape[:-1]
             q_gate = q_gate.view(*orig_shape, self.num_heads, -1)
             q, gate = torch.chunk(q_gate, 2, dim=-1)
             q = q.reshape(*orig_shape, -1)
             gate = gate.reshape(*orig_shape, -1)
-            ### TODO: avoid the redundant split and concat
-            qkv = torch.concat([q, k, v], dim=-1)
+            qkv = torch.concat([q, k, v], dim=-1)  # keep fused format
tensorrt_llm/_torch/modules/rms_norm.py (2)

35-44: Weight init differs for Gemma only when has_weights=True; consider “no-weights” parity

If has_weights=False, the buffer is always initialized to ones. In Gemma mode you then scale by (weight + 1), yielding a factor of 2 instead of 1. If this path is ever used, results will be biased. Align the no‑weights path with Gemma by initializing the buffer to zeros when use_gemma_rms_norm=True.

You can adjust the else: branch accordingly:

# inside __init__, has_weights == False
init = torch.zeros if use_gemma_rms_norm else torch.ones
self.register_buffer(
    'weight',
    init(hidden_size, dtype=dtype, device=device),
    persistent=False,
)

91-96: Python 3.8+ compatibility and typing hygiene in this module

  • Elsewhere in this file (group_rms_norm), built‑in generics and match are used; those require Python 3.9+/3.10+. Our target is Python 3.8+. Please replace list[...]/tuple[...] with List[...]/Tuple[...] and match with if/elif.

I can provide a concrete diff for that block if desired.
Also, consider replacing mutable default args like weights: Optional[list[Tensor]] = [] with None.

tensorrt_llm/_torch/modules/qk_norm_attention.py (1)

29-33: Typing: use typing.Tuple for Python 3.8

Return annotation uses tuple[...] which isn’t valid on 3.8. Switch to Tuple[...] and import it.

Apply this diff:

-from typing import Optional
+from typing import Optional, Tuple
@@
-def compute_yarn_parameters(
-    config: PretrainedConfig, ) -> tuple[float, float, float, float]:
+def compute_yarn_parameters(
+    config: PretrainedConfig, ) -> Tuple[float, float, float, float]:
tensorrt_llm/_torch/custom_ops/flashinfer_custom_ops.py (2)

37-40: Silence Ruff ARG001 in fake registration by prefixing unused args

Rename unused args in fakes to avoid noise in CI.

Apply this diff:

-    def _(input: torch.Tensor, weight: torch.Tensor,
-          eps: float) -> torch.Tensor:
+    def _(input: torch.Tensor, _weight: torch.Tensor,
+          _eps: float) -> torch.Tensor:
         return torch.empty_like(input)

51-60: Same Ruff fix for fused Gemma fake op

Keep consistency with the above change.

Apply this diff:

-    def flashinfer_gemma_fused_add_rmsnorm(input: torch.Tensor,
+    def flashinfer_gemma_fused_add_rmsnorm(input: torch.Tensor,
                                            residual: torch.Tensor,
                                            weight: torch.Tensor,
                                            eps: float) -> None:
         gemma_fused_add_rmsnorm(input,
                                 residual,
                                 weight,
                                 eps,
                                 enable_pdl=ENABLE_PDL)

And for the fake (if present elsewhere), prefix unused args similarly.

tensorrt_llm/_torch/models/modeling_qwen3_next.py (3)

197-199: Prefer torch.sigmoid over deprecated F.sigmoid

Modern PyTorch deprecates F.sigmoid.

Apply this diff:

-        shared_expert_output = F.sigmoid(
-            self.shared_expert_gate(hidden_states)) * shared_expert_output
+        shared_expert_output = torch.sigmoid(
+            self.shared_expert_gate(hidden_states)) * shared_expert_output

760-778: Redundant compute: fix_query_key_value_ordering called twice on the prefill path

You call fix_query_key_value_ordering before the if and again in the else. Drop the first call or guard it.

Apply this diff to compute lazily:

-        query, key, value, z, b, a = self.fix_query_key_value_ordering(
-            projected_states_qkvz, projected_states_ba)
-
         if self.num_v_heads // self.num_k_heads in [1, 2,
                                                     4]:  # and is_cuda_graph:
             mixed_qkv, z, b, a = fused_qkvzba_split_reshape_cat(
                 projected_states_qkvz,
                 projected_states_ba,
                 triton.cdiv(self.num_k_heads, self.attn_tp_size),
                 triton.cdiv(self.num_v_heads, self.attn_tp_size),
                 self.head_k_dim,
                 self.head_v_dim,
             )
         else:
             query, key, value, z, b, a = self.fix_query_key_value_ordering(
                 projected_states_qkvz, projected_states_ba)

1192-1204: XOR check is fine; message could be clearer

The XOR condition is correct. Optionally refine the error message to “Specify exactly one of input_ids or inputs_embeds.”

tensorrt_llm/_torch/modules/fla/solve_tril.py (9)

15-24: Consider re‑enabling Triton autotune behind a flag.

Autotune blocks are commented out. Add an env‑gated autotune to retain performance without affecting reproducibility in CI.

Example:

-# @triton.autotune(
+@triton.autotune(
     # configs=[...],
     # key=[...],
-# )
+)

Gate with: if int(os.getenv("TRITON_AUTOTUNE", "0")): ...


49-69: 16×16 kernel math/readability LGTM; add a brief comment on offset wrap.

offset = (i_t * 16) % BT guarantees offset+o_i < BT for BT∈{16,32,64}; add a one‑liner comment to avoid future regressions if BT set changes.


72-92: BT is unused in this kernel signature.

Ruff ARG001; remove BT or use it in pointer shapes. It’s currently constant (32) in all block ptrs.

Apply:

-def merge_16x16_to_32x32_inverse_kernel(
+def merge_16x16_to_32x32_inverse_kernel(
     A,
     Ad,
     Ai,
     cu_seqlens,
     chunk_indices,
     T,
     H: tl.constexpr,
-    BT: tl.constexpr,
     IS_VARLEN: tl.constexpr,
 ):

And drop BT= at the call site.


144-164: BT is unused in the 64×64 kernel too.

Same ARG001; remove it like the 32×32 kernel, and drop the argument from the launcher.

-    H: tl.constexpr,
-    BT: tl.constexpr,
+    H: tl.constexpr,
     IS_VARLEN: tl.constexpr,

314-356: Zeroing strict upper blocks is fine; consider early exit when already zero‑initialized.

If Ai is created with zeros_like you can skip stores. Minor micro‑opt.


365-380: Docstring mismatch and clarify API.

You compute (I + A)^-1 but the summary says “inverse of the lower triangular matrix.” Align text, and specify expected shapes for varlen.

Apply:

-    Compute the inverse of the lower triangular matrix
-    A should be strictly lower triangular, i.e., A.triu() == 0.
+    Compute (I + A)^-1 for a strictly lower-triangular A (A.triu() == 0).
+    A has shape [B, T, H, BT], BT ∈ {16, 32, 64}. With cu_seqlens, sequences
+    are processed in 16-wide chunks (and merged to BT), optionally varlen.

383-407: Intermediate dtype choice is good; add an explicit torch.float32.

Use torch.float32 over torch.float for clarity.

-    Ad = torch.empty(B, T, H, 16, device=A.device,
-                     dtype=torch.float if BT != 16 else output_dtype)
+    Ad = torch.empty(B, T, H, 16, device=A.device,
+                     dtype=torch.float32 if BT != 16 else output_dtype)

408-426: Re-enable autotune or make warps/stages configurable.

Hardcoding num_warps/num_stages is okay for now, but expose env overrides to simplify perf tuning without code changes.

-        num_warps=4,
-        num_stages=3,
+        num_warps=int(os.getenv("FLSOLVE_MERGE_WARPS", "4")),
+        num_stages=int(os.getenv("FLSOLVE_MERGE_STAGES", "3")),

(Remember to import os.)


72-80: Autotune configs: consider smaller num_warps on low‑SM GPUs.

If you restore autotune, include configs with num_warps=1,2 for 32×32/64×64 merges; helps older/lower‑end parts.

Also applies to: 146-152

tensorrt_llm/_torch/modules/fla/utils.py (3)

107-108: Fix incorrect type annotation for cache_entries.

The type annotation suggests a tuple but the actual type is a list. This mismatch can confuse type checkers and developers.

-    cache_entries: Tuple[Optional[Tuple], Optional[Dict], Any] = []
+    cache_entries: list = []

85-85: Add explicit stacklevel to warning call.

The static analysis correctly identifies that the warning should specify a stacklevel for better debugging context.

-            warnings.warn(msg)
+            warnings.warn(msg, stacklevel=2)

305-305: Avoid catching bare Exception.

Catching bare Exception is too broad and can mask programming errors.

-    except Exception:
+    except (KeyError, IndexError, RuntimeError):
tensorrt_llm/_torch/modules/fla/wy_fast.py (1)

117-117: Potential issue with unpacking syntax.

The unpacking syntax *k.shape looks unusual and might cause issues. Consider using explicit tuple unpacking.

-    B, T, Hg, K, V = *k.shape, v.shape[-1]
+    B, T, Hg, K = k.shape
+    V = v.shape[-1]
tensorrt_llm/_torch/modules/fla/index.py (2)

11-13: Add boundary check for empty cu_seqlens.

The function should validate that cu_seqlens has at least 2 elements to avoid index errors.

 @tensor_cache
 def prepare_lens(cu_seqlens: torch.LongTensor) -> torch.LongTensor:
+    if cu_seqlens.numel() < 2:
+        raise ValueError("cu_seqlens must have at least 2 elements")
     return cu_seqlens[1:] - cu_seqlens[:-1]

19-22: Potential performance issue with list comprehension and tolist().

Converting to Python list and back to tensor can be inefficient for large sequences. Consider keeping operations in tensor space.

-    indices = torch.cat([
-        torch.arange(n)
-        for n in triton.cdiv(prepare_lens(cu_seqlens), chunk_size).tolist()
-    ])
+    # More efficient: avoid Python list conversion
+    chunk_counts = triton.cdiv(prepare_lens(cu_seqlens), chunk_size)
+    indices = torch.cat([
+        torch.arange(n.item())
+        for n in chunk_counts
+    ])
tensorrt_llm/_torch/modules/fla/chunk_o.py (1)

134-134: Unpacking syntax issue.

Similar to the previous file, the unpacking syntax could be clearer.

-    B, T, Hg, K, V = *q.shape, v.shape[-1]
+    B, T, Hg, K = q.shape
+    V = v.shape[-1]
tensorrt_llm/_torch/modules/fla/chunk.py (4)

139-141: Docstring type: scale should be Optional[float], not Optional[int].

-        scale (Optional[int]):
+        scale (Optional[float]):

120-126: PEP 484 Optional types for public API.

Annotate optionals explicitly.

-def chunk_gated_delta_rule(
+def chunk_gated_delta_rule(
     q: torch.Tensor,
     k: torch.Tensor,
     v: torch.Tensor,
     g: torch.Tensor,
     beta: torch.Tensor,
-    scale: float = None,
-    initial_state: torch.Tensor = None,
+    scale: Optional[float] = None,
+    initial_state: Optional[torch.Tensor] = None,
     output_final_state: bool = False,
-    cu_seqlens: Optional[torch.LongTensor] = None,
+    cu_seqlens: Optional[torch.LongTensor] = None,
     head_first: bool = False,
     use_qk_l2norm_in_kernel: bool = False,
 ):

81-93: Remove stray pass and silence unused ctx per Ruff.

-    def forward(
-        ctx,
+    def forward(
+        _ctx,
         q: torch.Tensor,
@@
-    ):
-        pass
+    ):

99-110: Silence unused unpacked values to satisfy linters.

-        g, o, A, final_state, w, h, v_new = chunk_gated_delta_rule_fwd(
+        _g, o, _A, final_state, _w, _h, _v_new = chunk_gated_delta_rule_fwd(
             q=q,
             k=k,
             v=v,
             g=g,
             beta=beta,
             scale=scale,
             initial_state=initial_state,
             output_final_state=output_final_state,
             cu_seqlens=cu_seqlens,
         )
tensorrt_llm/_torch/modules/fla/chunk_delta_h.py (2)

236-246: Return type annotation incorrect; function returns three values.

-def chunk_gated_delta_rule_fwd_h(
+def chunk_gated_delta_rule_fwd_h(
     k: torch.Tensor,
@@
-    cu_seqlens: Optional[torch.LongTensor] = None,
-) -> Tuple[torch.Tensor, torch.Tensor]:
+    cu_seqlens: Optional[torch.LongTensor] = None,
+) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:

243-244: Remove author-internal note or convert to TODO.

-    chunk_size: int = 64,  # SY: remove this argument and force chunk size 64?
+    chunk_size: int = 64,  # TODO: consider enforcing chunk size = 64
tensorrt_llm/_torch/modules/fla/cumsum.py (3)

161-169: PEP 484 Optional types for public API.

-def chunk_local_cumsum_scalar(
+def chunk_local_cumsum_scalar(
     g: torch.Tensor,
     chunk_size: int,
     reverse: bool = False,
-    scale: float = None,
-    cu_seqlens: Optional[torch.Tensor] = None,
+    scale: Optional[float] = None,
+    cu_seqlens: Optional[torch.Tensor] = None,
     head_first: bool = False,
-    output_dtype: Optional[torch.dtype] = torch.float,
+    output_dtype: Optional[torch.dtype] = torch.float,
 ) -> torch.Tensor:

200-208: PEP 484 Optional types for public API.

-def chunk_local_cumsum_vector(
+def chunk_local_cumsum_vector(
     g: torch.Tensor,
     chunk_size: int,
     reverse: bool = False,
-    scale: float = None,
-    cu_seqlens: Optional[torch.Tensor] = None,
+    scale: Optional[float] = None,
+    cu_seqlens: Optional[torch.Tensor] = None,
     head_first: bool = False,
-    output_dtype: Optional[torch.dtype] = torch.float,
+    output_dtype: Optional[torch.dtype] = torch.float,
 ) -> torch.Tensor:

245-255: Drop unused kwargs; make Optional explicit.

-@input_guard
-def chunk_local_cumsum(
+@input_guard
+def chunk_local_cumsum(
     g: torch.Tensor,
     chunk_size: int,
     reverse: bool = False,
-    scale: float = None,
-    cu_seqlens: Optional[torch.Tensor] = None,
+    scale: Optional[float] = None,
+    cu_seqlens: Optional[torch.Tensor] = None,
     head_first: bool = False,
-    output_dtype: Optional[torch.dtype] = torch.float,
-    **kwargs,
+    output_dtype: Optional[torch.dtype] = torch.float,
 ) -> torch.Tensor:
tensorrt_llm/_torch/modules/fla/fused_recurrent.py (4)

214-225: PEP 484 Optional types for public API.

-def fused_recurrent_gated_delta_rule(
+def fused_recurrent_gated_delta_rule(
     q: torch.Tensor,
     k: torch.Tensor,
     v: torch.Tensor,
     g: torch.Tensor,
-    beta: torch.Tensor = None,
-    scale: float = None,
-    initial_state: torch.Tensor = None,
+    beta: Optional[torch.Tensor] = None,
+    scale: Optional[float] = None,
+    initial_state: Optional[torch.Tensor] = None,
     output_final_state: bool = False,
-    cu_seqlens: Optional[torch.LongTensor] = None,
+    cu_seqlens: Optional[torch.LongTensor] = None,
     use_qk_l2norm_in_kernel: bool = False,
 ) -> Tuple[torch.Tensor, torch.Tensor]:

573-588: PEP 484 Optional types for update API.

-def fused_recurrent_gated_delta_rule_update(
+def fused_recurrent_gated_delta_rule_update(
     q: torch.Tensor,
     k: torch.Tensor,
     v: torch.Tensor,
     g: torch.Tensor,
-    beta: torch.Tensor = None,
-    scale: float = None,
-    initial_state_source: torch.Tensor = None,
-    initial_state_indices: torch.Tensor = None,
-    cu_seqlens: Optional[torch.LongTensor] = None,
+    beta: Optional[torch.Tensor] = None,
+    scale: Optional[float] = None,
+    initial_state_source: Optional[torch.Tensor] = None,
+    initial_state_indices: Optional[torch.Tensor] = None,
+    cu_seqlens: Optional[torch.LongTensor] = None,
     use_qk_l2norm_in_kernel: bool = False,
     disable_state_update: bool = False,
     disable_output_calculation: bool = False,
-    intermediate_states_buffer: Optional[torch.Tensor] = None,
-    cache_steps: Optional[int] = None,
+    intermediate_states_buffer: Optional[torch.Tensor] = None,
+    cache_steps: Optional[int] = None,
 ) -> torch.Tensor:

175-190: Silence unused ctx per Ruff.

-    def forward(
-        ctx,
+    def forward(
+        _ctx,

564-570: Silence unused ctx in backward; message is long (TRY003).

Optional: shorten message or reference docs; rename ctx to _ctx.

-    def backward(ctx, do, dht):
-        raise NotImplementedError(
-            "Backward pass is not implemented yet and we do not have plans to implement it "
-            "because we haven't figured out how to compute dg without materializing the full "
-            "hidden states for all time steps.")
+    def backward(_ctx, _do, _dht):
+        raise NotImplementedError("Backward is not implemented for this op.")
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 9dc7316 and 7939a3e.

📒 Files selected for processing (30)
  • cpp/tensorrt_llm/kernels/fusedQKNormRopeKernel.cu (1 hunks)
  • examples/models/core/qwen/README.md (2 hunks)
  • tensorrt_llm/_torch/custom_ops/__init__.py (1 hunks)
  • tensorrt_llm/_torch/custom_ops/flashinfer_custom_ops.py (2 hunks)
  • tensorrt_llm/_torch/models/__init__.py (2 hunks)
  • tensorrt_llm/_torch/models/checkpoints/__init__.py (1 hunks)
  • tensorrt_llm/_torch/models/checkpoints/hf/qwen2_moe_weight_mapper.py (1 hunks)
  • tensorrt_llm/_torch/models/checkpoints/hf/qwen3_next_weight_mapper.py (1 hunks)
  • tensorrt_llm/_torch/models/modeling_qwen3.py (2 hunks)
  • tensorrt_llm/_torch/models/modeling_qwen3_next.py (1 hunks)
  • tensorrt_llm/_torch/models/modeling_utils.py (1 hunks)
  • tensorrt_llm/_torch/modules/attention.py (5 hunks)
  • tensorrt_llm/_torch/modules/fla/chunk.py (1 hunks)
  • tensorrt_llm/_torch/modules/fla/chunk_delta_h.py (1 hunks)
  • tensorrt_llm/_torch/modules/fla/chunk_o.py (1 hunks)
  • tensorrt_llm/_torch/modules/fla/chunk_scaled_dot_kkt.py (1 hunks)
  • tensorrt_llm/_torch/modules/fla/cumsum.py (1 hunks)
  • tensorrt_llm/_torch/modules/fla/fused_recurrent.py (1 hunks)
  • tensorrt_llm/_torch/modules/fla/fused_sigmoid_gating_recurrent.py (1 hunks)
  • tensorrt_llm/_torch/modules/fla/index.py (1 hunks)
  • tensorrt_llm/_torch/modules/fla/l2norm.py (1 hunks)
  • tensorrt_llm/_torch/modules/fla/layernorm_gated.py (1 hunks)
  • tensorrt_llm/_torch/modules/fla/op.py (1 hunks)
  • tensorrt_llm/_torch/modules/fla/solve_tril.py (1 hunks)
  • tensorrt_llm/_torch/modules/fla/utils.py (1 hunks)
  • tensorrt_llm/_torch/modules/fla/wy_fast.py (1 hunks)
  • tensorrt_llm/_torch/modules/qk_norm_attention.py (2 hunks)
  • tensorrt_llm/_torch/modules/rms_norm.py (3 hunks)
  • tensorrt_llm/_torch/pyexecutor/_util.py (3 hunks)
  • tensorrt_llm/_torch/pyexecutor/config_utils.py (1 hunks)
🧰 Additional context used
📓 Path-based instructions (5)
**/*.{h,hpp,hh,hxx,cpp,cxx,cc,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Use only spaces, no tabs; indent with 4 spaces.

Files:

  • tensorrt_llm/_torch/models/checkpoints/hf/qwen2_moe_weight_mapper.py
  • cpp/tensorrt_llm/kernels/fusedQKNormRopeKernel.cu
  • tensorrt_llm/_torch/pyexecutor/config_utils.py
  • tensorrt_llm/_torch/custom_ops/__init__.py
  • tensorrt_llm/_torch/models/__init__.py
  • tensorrt_llm/_torch/modules/attention.py
  • tensorrt_llm/_torch/pyexecutor/_util.py
  • tensorrt_llm/_torch/models/checkpoints/__init__.py
  • tensorrt_llm/_torch/modules/fla/cumsum.py
  • tensorrt_llm/_torch/modules/fla/chunk_o.py
  • tensorrt_llm/_torch/models/checkpoints/hf/qwen3_next_weight_mapper.py
  • tensorrt_llm/_torch/modules/rms_norm.py
  • tensorrt_llm/_torch/modules/fla/wy_fast.py
  • tensorrt_llm/_torch/modules/qk_norm_attention.py
  • tensorrt_llm/_torch/custom_ops/flashinfer_custom_ops.py
  • tensorrt_llm/_torch/modules/fla/chunk.py
  • tensorrt_llm/_torch/modules/fla/op.py
  • tensorrt_llm/_torch/models/modeling_utils.py
  • tensorrt_llm/_torch/modules/fla/chunk_delta_h.py
  • tensorrt_llm/_torch/modules/fla/index.py
  • tensorrt_llm/_torch/modules/fla/layernorm_gated.py
  • tensorrt_llm/_torch/modules/fla/chunk_scaled_dot_kkt.py
  • tensorrt_llm/_torch/models/modeling_qwen3.py
  • tensorrt_llm/_torch/modules/fla/solve_tril.py
  • tensorrt_llm/_torch/modules/fla/fused_recurrent.py
  • tensorrt_llm/_torch/modules/fla/fused_sigmoid_gating_recurrent.py
  • tensorrt_llm/_torch/modules/fla/utils.py
  • tensorrt_llm/_torch/models/modeling_qwen3_next.py
  • tensorrt_llm/_torch/modules/fla/l2norm.py
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: Python code must target Python 3.8+.
Indent Python code with 4 spaces; do not use tabs.
Maintain module namespace when importing; prefer 'from package.subpackage import foo' then 'foo.SomeClass()' instead of importing the class directly.
Python filenames should be snake_case (e.g., some_file.py).
Python classes use PascalCase names.
Functions and methods use snake_case names.
Local variables use snake_case; prefix 'k' for variables that start with a number (e.g., k_99th_percentile).
Global variables use upper SNAKE_CASE prefixed with 'G' (e.g., G_MY_GLOBAL).
Constants use upper SNAKE_CASE (e.g., MY_CONSTANT).
Avoid shadowing variables from an outer scope.
Initialize all externally visible members of a class in the constructor.
Prefer docstrings for interfaces that may be used outside a file; comments for in-function or file-local interfaces.
Use Google-style docstrings for classes and functions (Sphinx-parsable).
Document attributes and variables inline so they render under the class/function docstring.
Avoid reflection when a simpler, explicit approach suffices (e.g., avoid dict(**locals()) patterns).
In try/except, catch the most specific exceptions possible.
For duck-typing try/except, keep the try body minimal and use else for the main logic.

Files:

  • tensorrt_llm/_torch/models/checkpoints/hf/qwen2_moe_weight_mapper.py
  • tensorrt_llm/_torch/pyexecutor/config_utils.py
  • tensorrt_llm/_torch/custom_ops/__init__.py
  • tensorrt_llm/_torch/models/__init__.py
  • tensorrt_llm/_torch/modules/attention.py
  • tensorrt_llm/_torch/pyexecutor/_util.py
  • tensorrt_llm/_torch/models/checkpoints/__init__.py
  • tensorrt_llm/_torch/modules/fla/cumsum.py
  • tensorrt_llm/_torch/modules/fla/chunk_o.py
  • tensorrt_llm/_torch/models/checkpoints/hf/qwen3_next_weight_mapper.py
  • tensorrt_llm/_torch/modules/rms_norm.py
  • tensorrt_llm/_torch/modules/fla/wy_fast.py
  • tensorrt_llm/_torch/modules/qk_norm_attention.py
  • tensorrt_llm/_torch/custom_ops/flashinfer_custom_ops.py
  • tensorrt_llm/_torch/modules/fla/chunk.py
  • tensorrt_llm/_torch/modules/fla/op.py
  • tensorrt_llm/_torch/models/modeling_utils.py
  • tensorrt_llm/_torch/modules/fla/chunk_delta_h.py
  • tensorrt_llm/_torch/modules/fla/index.py
  • tensorrt_llm/_torch/modules/fla/layernorm_gated.py
  • tensorrt_llm/_torch/modules/fla/chunk_scaled_dot_kkt.py
  • tensorrt_llm/_torch/models/modeling_qwen3.py
  • tensorrt_llm/_torch/modules/fla/solve_tril.py
  • tensorrt_llm/_torch/modules/fla/fused_recurrent.py
  • tensorrt_llm/_torch/modules/fla/fused_sigmoid_gating_recurrent.py
  • tensorrt_llm/_torch/modules/fla/utils.py
  • tensorrt_llm/_torch/models/modeling_qwen3_next.py
  • tensorrt_llm/_torch/modules/fla/l2norm.py
**/*.{cpp,cxx,cc,h,hpp,hh,hxx,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Prepend the NVIDIA Apache-2.0 copyright header with current year to the top of all source files (e.g., .cpp, .h, .cu, .py).

Files:

  • tensorrt_llm/_torch/models/checkpoints/hf/qwen2_moe_weight_mapper.py
  • cpp/tensorrt_llm/kernels/fusedQKNormRopeKernel.cu
  • tensorrt_llm/_torch/pyexecutor/config_utils.py
  • tensorrt_llm/_torch/custom_ops/__init__.py
  • tensorrt_llm/_torch/models/__init__.py
  • tensorrt_llm/_torch/modules/attention.py
  • tensorrt_llm/_torch/pyexecutor/_util.py
  • tensorrt_llm/_torch/models/checkpoints/__init__.py
  • tensorrt_llm/_torch/modules/fla/cumsum.py
  • tensorrt_llm/_torch/modules/fla/chunk_o.py
  • tensorrt_llm/_torch/models/checkpoints/hf/qwen3_next_weight_mapper.py
  • tensorrt_llm/_torch/modules/rms_norm.py
  • tensorrt_llm/_torch/modules/fla/wy_fast.py
  • tensorrt_llm/_torch/modules/qk_norm_attention.py
  • tensorrt_llm/_torch/custom_ops/flashinfer_custom_ops.py
  • tensorrt_llm/_torch/modules/fla/chunk.py
  • tensorrt_llm/_torch/modules/fla/op.py
  • tensorrt_llm/_torch/models/modeling_utils.py
  • tensorrt_llm/_torch/modules/fla/chunk_delta_h.py
  • tensorrt_llm/_torch/modules/fla/index.py
  • tensorrt_llm/_torch/modules/fla/layernorm_gated.py
  • tensorrt_llm/_torch/modules/fla/chunk_scaled_dot_kkt.py
  • tensorrt_llm/_torch/models/modeling_qwen3.py
  • tensorrt_llm/_torch/modules/fla/solve_tril.py
  • tensorrt_llm/_torch/modules/fla/fused_recurrent.py
  • tensorrt_llm/_torch/modules/fla/fused_sigmoid_gating_recurrent.py
  • tensorrt_llm/_torch/modules/fla/utils.py
  • tensorrt_llm/_torch/models/modeling_qwen3_next.py
  • tensorrt_llm/_torch/modules/fla/l2norm.py
**/*.{h,hpp,hh,hxx,cpp,cxx,cc,cu,cuh}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.{h,hpp,hh,hxx,cpp,cxx,cc,cu,cuh}: Namespace closing braces must include a trailing comment with the namespace name (e.g., '} // namespace foo').
Prefer const or constexpr variables over #define for constants.
Declare variables that are not modified after initialization as const.
Avoid magic literals in code; except for 0, nullptr, true, false. Use named constants for comparisons and logic.
Use Allman brace style for formatting.
Place the semicolon of an empty for/while loop on a new line.
Bodies of switch/while/do-while/for must be compound statements (brace-delimited), and if/else must always be followed by brace-delimited statements.
Type names (e.g., classes) must be CamelCase starting with an uppercase letter (e.g., FooBar).
Local variables, methods, and namespaces use lowerCamelCase (e.g., localFooBar).
Non-magic-number global variables that are non-static and not in an anonymous namespace must be lowerCamelCase prefixed with 'g' (e.g., gDontUseGlobalFoos).
Non-magic-number globals that are static or in an anonymous namespace use lowerCamelCase prefixed with 's' (e.g., sMutableStaticGlobal).
Locally visible static variables use lowerCamelCase with 's' prefix (e.g., static std::once_flag sFlag).
Private/protected member variables use 'm' prefix with CamelCase (e.g., mNbFooValues). Public members may omit, but 'm' is encouraged for clarity.
Constants (enums, global constants, static constants, and function-scope magic/literal constants) use uppercase SNAKE_CASE with 'k' prefix (e.g., kDIGIT_NUM).
Function-scope constants that are not magic numbers or literals are named like non-constant variables (e.g., bool const pass = a && b).
If macros are necessary, name them in UPPER_SNAKE_CASE (e.g., FOO_VERSION) and prefer constants over #define.
Use LLVM clang-format; wrap lines at a maximum of 120 columns; use '// clang-format off/on' sparingly with justification.
Use smart pointers for heap allocations; prefer unique_ptr for sole ownership, shared_ptr for shared...

Files:

  • cpp/tensorrt_llm/kernels/fusedQKNormRopeKernel.cu
**/*.{cpp,cxx,cc,cu,h,hpp,hh,hxx,cuh}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

C++ filenames should be lowerCamelCase (first letter lowercase) and must be case-insensitive unique within a compilation target.

Files:

  • cpp/tensorrt_llm/kernels/fusedQKNormRopeKernel.cu
🧠 Learnings (1)
📚 Learning: 2025-08-27T14:41:56.665Z
Learnt from: ixlmar
PR: NVIDIA/TensorRT-LLM#7294
File: tensorrt_llm/_torch/modules/rms_norm.py:96-99
Timestamp: 2025-08-27T14:41:56.665Z
Learning: In tensorrt_llm/_torch/modules/rms_norm.py, the RMSNorm class uses a custom sentinel (_ARGUMENT_NOT_SPECIFIED_SENTINEL) instead of Ellipsis (...) for detecting unspecified optional arguments. Other modules in the codebase may use Ellipsis as a sentinel but do not forward it to RMSNorm methods, so there's no need for backward compatibility with Ellipsis in RMSNorm.

Applied to files:

  • tensorrt_llm/_torch/modules/rms_norm.py
🧬 Code graph analysis (24)
tensorrt_llm/_torch/pyexecutor/config_utils.py (1)
tensorrt_llm/_torch/models/modeling_utils.py (1)
  • config (519-520)
tensorrt_llm/_torch/custom_ops/__init__.py (1)
tensorrt_llm/_torch/custom_ops/flashinfer_custom_ops.py (5)
  • flashinfer_fused_add_rmsnorm (44-47)
  • flashinfer_gemma_fused_add_rmsnorm (51-59)
  • flashinfer_gemma_rmsnorm (33-35)
  • flashinfer_rmsnorm (22-24)
  • flashinfer_silu_and_mul (13-14)
tensorrt_llm/_torch/models/__init__.py (1)
tensorrt_llm/_torch/models/modeling_qwen3_next.py (1)
  • Qwen3NextForCausalLM (1231-1254)
tensorrt_llm/_torch/modules/attention.py (2)
tensorrt_llm/logger.py (1)
  • warning_once (135-136)
tensorrt_llm/_torch/modules/qk_norm_attention.py (1)
  • apply_rope (240-257)
tensorrt_llm/_torch/pyexecutor/_util.py (3)
tensorrt_llm/_torch/pyexecutor/config_utils.py (1)
  • is_qwen3_next (16-17)
tensorrt_llm/_torch/models/modeling_utils.py (1)
  • config (519-520)
tensorrt_llm/_torch/pyexecutor/mamba_cache_manager.py (1)
  • MambaHybridCacheManager (167-246)
tensorrt_llm/_torch/models/checkpoints/__init__.py (2)
tensorrt_llm/_torch/models/checkpoints/hf/qwen3_next_weight_mapper.py (1)
  • Qwen3NextHfWeightMapper (15-105)
tensorrt_llm/_torch/models/checkpoints/hf/weight_mapper.py (1)
  • HfWeightMapper (10-99)
tensorrt_llm/_torch/modules/fla/cumsum.py (2)
tensorrt_llm/_torch/modules/fla/index.py (1)
  • prepare_chunk_indices (17-23)
tensorrt_llm/_torch/modules/fla/utils.py (2)
  • check_shared_mem (300-306)
  • input_guard (133-166)
tensorrt_llm/_torch/modules/fla/chunk_o.py (3)
tensorrt_llm/_torch/modules/fla/index.py (1)
  • prepare_chunk_indices (17-23)
tensorrt_llm/_torch/modules/fla/op.py (1)
  • safe_exp (26-27)
tensorrt_llm/_torch/modules/fla/utils.py (1)
  • check_shared_mem (300-306)
tensorrt_llm/_torch/models/checkpoints/hf/qwen3_next_weight_mapper.py (2)
tensorrt_llm/_torch/models/modeling_utils.py (3)
  • register_mapper (655-666)
  • DecoderModelForCausalLM (342-603)
  • config (519-520)
tensorrt_llm/_torch/models/checkpoints/hf/weight_mapper.py (1)
  • _duplicate_kv (76-99)
tensorrt_llm/_torch/modules/rms_norm.py (1)
tensorrt_llm/_torch/custom_ops/flashinfer_custom_ops.py (4)
  • flashinfer_gemma_fused_add_rmsnorm (51-59)
  • flashinfer_gemma_rmsnorm (33-35)
  • flashinfer_rmsnorm (22-24)
  • flashinfer_fused_add_rmsnorm (44-47)
tensorrt_llm/_torch/modules/fla/wy_fast.py (1)
tensorrt_llm/_torch/modules/fla/index.py (1)
  • prepare_chunk_indices (17-23)
tensorrt_llm/_torch/modules/qk_norm_attention.py (1)
tensorrt_llm/_torch/modules/rms_norm.py (1)
  • RMSNorm (25-110)
tensorrt_llm/_torch/custom_ops/flashinfer_custom_ops.py (1)
tests/unittest/_torch/misc/test_autotuner.py (2)
  • _ (129-130)
  • _ (203-204)
tensorrt_llm/_torch/modules/fla/chunk.py (8)
tensorrt_llm/_torch/modules/fla/chunk_delta_h.py (1)
  • chunk_gated_delta_rule_fwd_h (236-294)
tensorrt_llm/_torch/modules/fla/chunk_o.py (1)
  • chunk_fwd_o (124-169)
tensorrt_llm/_torch/modules/fla/chunk_scaled_dot_kkt.py (1)
  • chunk_scaled_dot_kkt_fwd (88-143)
tensorrt_llm/_torch/modules/fla/cumsum.py (1)
  • chunk_local_cumsum (246-282)
tensorrt_llm/_torch/modules/fla/l2norm.py (2)
  • l2norm (133-136)
  • l2norm_fwd (74-122)
tensorrt_llm/_torch/modules/fla/solve_tril.py (1)
  • solve_tril (360-426)
tensorrt_llm/_torch/modules/fla/utils.py (1)
  • input_guard (133-166)
tensorrt_llm/_torch/modules/fla/wy_fast.py (1)
  • recompute_w_u_fwd (109-149)
tensorrt_llm/_torch/modules/fla/chunk_delta_h.py (2)
tensorrt_llm/_torch/modules/fla/index.py (2)
  • prepare_chunk_indices (17-23)
  • prepare_chunk_offsets (27-32)
tensorrt_llm/_torch/modules/fla/op.py (1)
  • safe_exp (26-27)
tensorrt_llm/_torch/modules/fla/index.py (1)
tensorrt_llm/_torch/modules/fla/utils.py (1)
  • tensor_cache (93-130)
tensorrt_llm/_torch/modules/fla/layernorm_gated.py (2)
tensorrt_llm/functional.py (3)
  • rearrange (6144-6287)
  • sqrt (450-454)
  • sum (3253-3275)
tensorrt_llm/_torch/modules/rms_norm.py (1)
  • RMSNorm (25-110)
tensorrt_llm/_torch/modules/fla/chunk_scaled_dot_kkt.py (2)
tensorrt_llm/_torch/modules/fla/index.py (1)
  • prepare_chunk_indices (17-23)
tensorrt_llm/_torch/modules/fla/op.py (1)
  • safe_exp (26-27)
tensorrt_llm/_torch/models/modeling_qwen3.py (2)
tensorrt_llm/_torch/models/modeling_utils.py (1)
  • config (519-520)
tensorrt_llm/_torch/model_config.py (1)
  • torch_dtype (199-204)
tensorrt_llm/_torch/modules/fla/solve_tril.py (2)
tensorrt_llm/_torch/modules/fla/index.py (1)
  • prepare_chunk_indices (17-23)
tensorrt_llm/_torch/modules/fla/utils.py (1)
  • input_guard (133-166)
tensorrt_llm/_torch/modules/fla/fused_recurrent.py (1)
tensorrt_llm/_torch/modules/fla/utils.py (1)
  • input_guard (133-166)
tensorrt_llm/_torch/modules/fla/fused_sigmoid_gating_recurrent.py (3)
tensorrt_llm/_torch/modules/fla/utils.py (1)
  • input_guard (133-166)
tensorrt_llm/_torch/modules/fla/chunk_delta_h.py (1)
  • grid (270-271)
tensorrt_llm/_torch/modules/fla/chunk_o.py (1)
  • grid (145-146)
tensorrt_llm/_torch/models/modeling_qwen3_next.py (13)
tensorrt_llm/_torch/models/checkpoints/base_weight_mapper.py (1)
  • BaseWeightMapper (10-165)
tensorrt_llm/_torch/modules/fla/chunk.py (2)
  • chunk_gated_delta_rule (114-238)
  • forward (80-110)
tensorrt_llm/_torch/modules/fla/fused_sigmoid_gating_recurrent.py (1)
  • fused_sigmoid_gating_delta_rule_update (156-222)
tensorrt_llm/_torch/modules/mamba/mamba2_metadata.py (1)
  • Mamba2Metadata (88-137)
tensorrt_llm/mapping.py (1)
  • Mapping (32-513)
tensorrt_llm/_torch/attention_backend/interface.py (2)
  • AttentionMetadata (40-336)
  • num_ctx_tokens (263-264)
tensorrt_llm/_torch/distributed/ops.py (2)
  • AllReduce (400-553)
  • MoEAllReduce (556-644)
tensorrt_llm/_torch/modules/embedding.py (1)
  • Embedding (162-244)
tensorrt_llm/_torch/modules/fused_moe/create_moe.py (1)
  • create_moe (61-211)
tensorrt_llm/_torch/modules/mamba/causal_conv1d.py (2)
  • causal_conv1d_fn (26-74)
  • causal_conv1d_update (77-118)
tensorrt_llm/_torch/modules/rms_norm.py (2)
  • RMSNorm (25-110)
  • forward (54-100)
tensorrt_llm/_torch/models/modeling_speculative.py (1)
  • SpecDecOneEngineForCausalLM (377-496)
tensorrt_llm/_torch/models/modeling_utils.py (7)
  • DecoderModel (228-321)
  • EagerFusionConfig (40-44)
  • register_auto_model (614-620)
  • forward (242-271)
  • forward (526-552)
  • load_weights (554-572)
  • config (519-520)
tensorrt_llm/_torch/modules/fla/l2norm.py (2)
tensorrt_llm/_torch/modules/fla/utils.py (1)
  • input_guard (133-166)
tensorrt_llm/_torch/modules/fla/chunk.py (1)
  • forward (80-110)
🪛 Ruff (0.13.1)
tensorrt_llm/_torch/pyexecutor/_util.py

456-458: Avoid specifying long messages outside the exception class

(TRY003)

tensorrt_llm/_torch/modules/fla/cumsum.py

33-33: Unused function argument: B

(ARG001)


92-92: Unused function argument: B

(ARG001)


165-165: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)


204-204: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)


250-250: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)


254-254: Unused function argument: kwargs

(ARG001)


280-282: Avoid specifying long messages outside the exception class

(TRY003)

tensorrt_llm/_torch/custom_ops/flashinfer_custom_ops.py

38-38: Unused function argument: weight

(ARG001)


39-39: Unused function argument: eps

(ARG001)

tensorrt_llm/_torch/modules/fla/chunk.py

81-81: Unused static method argument: ctx

(ARG004)


99-99: Unpacked variable A is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


99-99: Unpacked variable w is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


99-99: Unpacked variable h is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


99-99: Unpacked variable v_new is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


120-120: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)


199-201: Avoid specifying long messages outside the exception class

(TRY003)


213-215: Avoid specifying long messages outside the exception class

(TRY003)


218-221: Avoid specifying long messages outside the exception class

(TRY003)

tensorrt_llm/_torch/modules/fla/op.py

33-33: Unused function argument: src

(ARG001)


33-33: Unused function argument: index

(ARG001)


33-33: Unused function argument: axis

(ARG001)


59-59: Unused function argument: base

(ARG001)


60-60: Unused function argument: shape

(ARG001)


61-61: Unused function argument: strides

(ARG001)


62-62: Unused function argument: block_shape

(ARG001)

tensorrt_llm/_torch/modules/fla/layernorm_gated.py

156-157: Avoid specifying long messages outside the exception class

(TRY003)


188-188: Unused static method argument: ctx

(ARG004)


213-213: Unpacked variable mean is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


213-213: Unpacked variable rstd is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

tensorrt_llm/_torch/modules/fla/solve_tril.py

90-90: Unused function argument: BT

(ARG001)


162-162: Unused function argument: BT

(ARG001)

tensorrt_llm/_torch/modules/fla/fused_recurrent.py

178-178: Unused static method argument: ctx

(ARG004)


220-220: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)


287-289: Avoid specifying long messages outside the exception class

(TRY003)


292-295: Avoid specifying long messages outside the exception class

(TRY003)


529-529: Unused static method argument: ctx

(ARG004)


579-579: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)


591-593: Avoid specifying long messages outside the exception class

(TRY003)


596-599: Avoid specifying long messages outside the exception class

(TRY003)

tensorrt_llm/_torch/modules/fla/utils.py

85-85: No explicit stacklevel keyword argument found

Set stacklevel=2

(B028)


226-226: Do not catch blind exception: BaseException

(BLE001)


235-235: Do not catch blind exception: BaseException

(BLE001)


280-280: Do not catch blind exception: BaseException

(BLE001)


305-305: Do not catch blind exception: Exception

(BLE001)

tensorrt_llm/_torch/models/modeling_qwen3_next.py

66-66: Unused method argument: moe_backend

(ARG002)


97-98: Avoid specifying long messages outside the exception class

(TRY003)


175-175: Do not assert False (python -O removes these calls), raise AssertionError()

Replace assert False

(B011)


503-508: Consider iterable unpacking instead of concatenation

Replace with iterable unpacking

(RUF005)


509-512: Consider iterable unpacking instead of concatenation

Replace with iterable unpacking

(RUF005)


735-735: Unpacked variable seq_len is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


736-736: Local variable conv_state is assigned to but never used

Remove assignment to unused variable conv_state

(F841)


736-736: Local variable recurrent_state is assigned to but never used

Remove assignment to unused variable recurrent_state

(F841)


878-878: Unused method argument: position_ids

(ARG002)


1199-1199: Unused method argument: kwargs

(ARG002)


1202-1204: Avoid specifying long messages outside the exception class

(TRY003)

tensorrt_llm/_torch/modules/fla/l2norm.py

59-59: Unused function argument: NB

(ARG001)


91-91: Avoid specifying long messages outside the exception class

(TRY003)


129-129: Unused static method argument: ctx

(ARG004)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Review continued from previous batch...

@tensorrt-cicd
Copy link
Collaborator

PR_Github #19484 [ run ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #19484 [ run ] completed with state FAILURE
/LLM/main/L0_MergeRequest_PR pipeline #14647 completed with status: 'FAILURE'

@byshiue
Copy link
Collaborator Author

byshiue commented Sep 23, 2025

/bot run

@tensorrt-cicd
Copy link
Collaborator

PR_Github #19646 [ run ] triggered by Bot

@byshiue
Copy link
Collaborator Author

byshiue commented Sep 23, 2025

/bot run

@tensorrt-cicd
Copy link
Collaborator

PR_Github #19652 [ run ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #19646 [ run ] completed with state ABORTED
LLM/main/L0_MergeRequest_PR #14780 (Blue Ocean) completed with status: ABORTED

@tensorrt-cicd
Copy link
Collaborator

PR_Github #19652 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #14787 completed with status: 'FAILURE'

@nv-guomingz
Copy link
Collaborator

/bot run

byshiue and others added 7 commits September 29, 2025 08:49
Signed-off-by: bhsueh <11360707+byshiue@users.noreply.github.com>
Signed-off-by: bhsueh <11360707+byshiue@users.noreply.github.com>
Signed-off-by: bhsueh <11360707+byshiue@users.noreply.github.com>
Signed-off-by: bhsueh <11360707+byshiue@users.noreply.github.com>
Signed-off-by: nv-guomingz <137257613+nv-guomingz@users.noreply.github.com>
Signed-off-by: nv-guomingz <137257613+nv-guomingz@users.noreply.github.com>
Signed-off-by: nv-guomingz <137257613+nv-guomingz@users.noreply.github.com>
@nv-guomingz
Copy link
Collaborator

/bot run --disable-fail-fast

@tensorrt-cicd
Copy link
Collaborator

PR_Github #20191 [ run ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #20191 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #15227 completed with status: 'FAILURE'

@nv-guomingz
Copy link
Collaborator

/bot run

@tensorrt-cicd
Copy link
Collaborator

PR_Github #20214 [ run ] triggered by Bot

Copy link
Collaborator

@yuxianq yuxianq left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reviewed attention.py

@tensorrt-cicd
Copy link
Collaborator

PR_Github #20214 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #15246 completed with status: 'FAILURE'

@nv-guomingz
Copy link
Collaborator

/bot run

@tensorrt-cicd
Copy link
Collaborator

PR_Github #20256 [ run ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #20256 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #15273 completed with status: 'SUCCESS'
Pipeline passed with automatic retried tests. Check the rerun report for details.

@nv-guomingz nv-guomingz merged commit 38d6e4e into NVIDIA:main Sep 29, 2025
5 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

9 participants