[None][feat] Support Qwen3 next #7892
Conversation
/bot run |
📝 Walkthrough

Adds Qwen3-Next model integration with linear/full attention, Triton-based FLA kernels, Gemma RMSNorm variants, attention output gating, a new HF weight mapper, and executor support. Extends custom ops/public exports, adjusts the Q/KV kernel to support head_dim=256, updates docs, and tweaks weight-loading and KV cache creation paths.

Changes
Sequence Diagram(s)

sequenceDiagram
autonumber
participant U as User
participant LM as Qwen3NextForCausalLM
participant M as Qwen3NextModel
participant L as DecoderLayer (linear/full)
participant G as Gated Delta Net / Attention
participant MOE as Sparse MoE Block
participant N as RMSNorm/O-Proj
U->>LM: forward(input_ids/embeds, attn_metadata, spec_metadata)
LM->>M: forward(...)
M->>M: embeddings + layer selection
loop num_hidden_layers
M->>L: forward(hidden_states, attn_metadata, spec_metadata)
alt linear_attention
L->>G: forward(hidden_states, attn_metadata, mamba_metadata)
G-->>L: attn_out
else full_attention
L->>G: self-attention (QKV, gating/norm)
G-->>L: attn_out
end
L->>MOE: route/process (if enabled)
MOE-->>L: moe_out
L->>N: post-attn norm/proj
N-->>M: layer_out
end
M-->>LM: logits or hidden_states
LM-->>U: output
sequenceDiagram
autonumber
participant HF as HF Weights
participant WM as Qwen3NextHfWeightMapper
participant MU as modeling_utils
participant MD as Model/Modules
HF-->>WM: provide state_dict + config
WM->>WM: init_model_and_config (num_kv_heads)
WM->>WM: preprocess_weights (A_log, dt_bias, conv1d split/concat)
WM->>WM: duplicate_kv (per TP, quantization-aware)
WM-->>MU: mapped weights
MU->>MD: traverse modules
alt module has load_weights and name includes "linear_attn.conv1d"
MU->>MD: squeeze(dim=1) then load_weights
else
MU->>MD: load_weights
end
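For illustration, a minimal sketch of the conv1d branch in the diagram above (the loader entry point and the `module_weights` dict are assumptions for illustration, not the mapper's actual API):

```python
def load_module_weights(name: str, module, module_weights: dict) -> None:
    # HF checkpoints typically store the linear-attention conv1d weight as
    # [channels, 1, kernel_size]; drop the singleton dim before loading.
    if "linear_attn.conv1d" in name and "weight" in module_weights:
        w = module_weights["weight"]
        if w.dim() == 3 and w.shape[1] == 1:
            module_weights["weight"] = w.squeeze(dim=1)  # -> [channels, kernel_size]
    module.load_weights(weights=[module_weights])
```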
sequenceDiagram
autonumber
participant C as Config
participant CU as config_utils.is_qwen3_next
participant EX as _pyexecutor._util
participant CM as CacheManager
C-->>CU: query linear_key_head_dim
CU-->>EX: True/False
alt True
EX->>EX: derive mamba/layer masks
EX->>CM: create MambaHybridCacheManager (validate beam/connector)
else False
EX->>CM: other cache manager paths
end
CM-->>EX: manager instance
sequenceDiagram
autonumber
participant C as Config
participant A as Attention
participant O as Output Gate
C-->>A: attn_output_gate flag
A->>A: compute attn_out
opt attn_output_gate
A->>O: sigmoid(gate) * attn_out
O-->>A: gated_out
end
A-->>A: o_proj(gated_or_plain)
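As a rough sketch of the gating step shown above (tensor and module names are illustrative; in the actual attention module the gate comes out of the fused QKV projection):

```python
import torch

def gated_attn_output(attn_out: torch.Tensor, gate_logits: torch.Tensor,
                      o_proj: torch.nn.Linear, use_gate: bool = True) -> torch.Tensor:
    # Optionally scale the attention output by a sigmoid gate before the output projection.
    if use_gate:
        attn_out = torch.sigmoid(gate_logits) * attn_out
    return o_proj(attn_out)
```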
Estimated code review effort: 🎯 5 (Critical) | ⏱️ ~120–180 minutes

Pre-merge checks and finishing touches

❌ Failed checks (2 warnings)
✅ Passed checks (1 passed)
✨ Finishing touches
🧪 Generate unit tests
📜 Recent review details

Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📒 Files selected for processing (1)
🚧 Files skipped from review as they are similar to previous changes (1)
Actionable comments posted: 43
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tensorrt_llm/_torch/models/checkpoints/__init__.py (1)
1-1
: Missing NVIDIA Apache-2.0 header (2025) at file top. This repo requires the NVIDIA Apache-2.0 header on all source files. Please add it.
Apply:
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from .base_checkpoint_loader import BaseCheckpointLoader
 from .hf.checkpoint_loader import HfCheckpointLoader
 from .hf.config_loader import HfConfigLoader
 from .hf.gemma3_weight_mapper import Gemma3HfWeightMapper
 from .hf.llama4_weight_mapper import Llama4HfWeightMapper
 from .hf.mixtral_weight_mapper import MixtralHfWeightMapper
 from .hf.nemotron_h_weight_mapper import NemotronHHfWeightMapper
 from .hf.qwen2_moe_weight_mapper import Qwen2MoeHfWeightMapper
 from .hf.qwen3_moe_weight_mapper import Qwen3MoeHfWeightMapper
 from .hf.qwen3_next_weight_mapper import Qwen3NextHfWeightMapper
🧹 Nitpick comments (57)
tensorrt_llm/_torch/modules/fla/l2norm.py (5)
24-44
: Fix misleading comment and align store-cast with the other kernel.
- The comment says “mean and variance” but the code computes squared L2 norm.
- Optional: mirror the explicit cast used in l2norm_fwd_kernel to make the intended destination dtype clear (Triton will auto-cast, but being explicit improves readability).
-    # Compute mean and variance
+    # Compute squared L2 norm (sum of squares) and normalize
@@
-    b_y = b_x * b_rstd
-    tl.store(y + cols, b_y, mask=mask)
+    b_y = b_x * b_rstd
+    # Match the explicit cast pattern used in the 2D kernel for clarity
+    tl.store(y + cols, b_y.to(tl.float32), mask=mask)

If you prefer exact destination dtype, pass it as a constexpr/meta or derive from a block pointer like in the 2D kernel. Otherwise, keep auto-cast and only change the comment.
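For reference, an eager-mode equivalent of what the corrected comment describes (a sketch only, not the fused Triton path):

```python
import torch

def l2norm_ref(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Squared L2 norm over the last dim, then scale by its reciprocal square root.
    return x * torch.rsqrt(x.pow(2).sum(dim=-1, keepdim=True) + eps)
```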
55-71
: Remove unused NB meta-parameter and its call-site.NB is computed/passed but never used in the kernel, tripping linters and confusing readers.
-def l2norm_fwd_kernel( +def l2norm_fwd_kernel( x, y, eps, - NB: tl.constexpr, T: tl.constexpr, D: tl.constexpr, BT: tl.constexpr, BD: tl.constexpr, ): @@ - if D <= 512: - NB = triton.cdiv(T, 2048) + if D <= 512: @@ l2norm_fwd_kernel[grid]( x, y, eps, - NB=NB, T=T, D=D, BD=BD, BT=16, num_warps=8, num_stages=3, )Also applies to: 93-111
74-93
: Guard max feature size check with a clearer error and type info.The TRY003 lint likely flags the long message; also, including the resolved byte budget and dtype helps debugging.
-    if D > BD:
-        raise RuntimeError("This layer doesn't support feature dim >= 64KB.")
+    if D > BD:
+        raise RuntimeError(
+            f"L2Norm fused kernel supports at most ~64KB per feature; "
+            f"D={D}, dtype={x.dtype}, element_size={x.element_size()}B."
+        )
127-131
: Silence linter for unused ctx or use it.Ruff flags ctx unused in forward. Either delete it or bind to underscore.
-    def forward(ctx, x, eps=1e-6, output_dtype=None):
+    def forward(ctx, x, eps=1e-6, output_dtype=None):
+        # ctx is unused because we don't implement backward for the Triton path
+        del ctx
         return l2norm_fwd(x, eps, output_dtype)
142-152
: Add minimal docstrings for public API.Document behavior, shapes, and dtype handling for l2norm() and L2Norm to align with repo guidelines.
-def l2norm(x: torch.Tensor, - eps: float = 1e-6, - output_dtype: Optional[torch.dtype] = None) -> torch.Tensor: +def l2norm(x: torch.Tensor, + eps: float = 1e-6, + output_dtype: Optional[torch.dtype] = None) -> torch.Tensor: + """ + L2-normalize the last dimension of x. + + Args: + x: [..., D] tensor. If requires_grad is True, uses a Torch eager path + for autograd; otherwise uses a fused Triton kernel. + eps: Small epsilon added inside the sqrt for numerical stability. + output_dtype: Optional output dtype; defaults to x.dtype. + Returns: + Tensor of same shape as x with last-dim L2 normalization applied. + """ return L2NormFunction.apply(x, eps, output_dtype) @@ class L2Norm(nn.Module): def __init__(self, eps: float = 1e-6, output_dtype: Optional[torch.dtype] = None): super().__init__() + """ + Module wrapper for l2norm(). + + Args: + eps: Epsilon for numerical stability. + output_dtype: Optional output dtype override. + """ self.eps = eps self.output_dtype = output_dtypetensorrt_llm/_torch/modules/fla/layernorm_gated.py (5)
28-31
: Remove stray no-op statement.
x.shape[-1]
does nothing and can confuse readers.- x.shape[-1]
118-140
: Confirm device context handling; ensure compatibility across PyTorch backends.
torch.get_device_module(x.device).device(x.device.index)
may not exist or vary across PT versions/backends. Since wrappers now route CPU to refs, Triton path should only run on CUDA. Please either:
- switch to
with torch.cuda.device(x.device.index):
, or- keep current call but verify on supported PT versions and backends (CUDA/HIP) in CI.
Also applies to: 161-180
251-292
: Public API docstrings (Google-style) for LayerNorm.Add a brief Google-style docstring with args/returns to meet repo guidelines.
294-332
: Class name collision risk with existing RMSNorm.There is already
tensorrt_llm._torch.modules.rms_norm.RMSNorm
. This newRMSNorm
in a different module can cause import ambiguity. Prefer explicit import sites (fully-qualified) or rename this class (e.g.,FLARMSNorm
) to avoid collisions.
10-15
: Import style consistency forrearrange
.Repo has a local
tensorrt_llm.functional.rearrange
. Using einops is fine, but mixing both can be confusing. Consider importing with a qualified alias (from einops import rearrange as einops_rearrange
) or standardizing on one.tensorrt_llm/_torch/modules/fla/fused_sigmoid_gating_recurrent.py (3)
95-114
: Minor: numerics for softplus and sigmoidSoftplus branch uses a threshold on
beta_x
which is fine. Consider precomputinginv_beta = 1.0/softplus_beta
and using it to avoid repeated divisions; also clampsoftplus_beta > 0
. Sigmoid can overflow for very negativeb_b
—optional: use1.0 / (1.0 + tl.exp(-tl.clip(b_b, -20, 20)))
.Apply this micro‑diff:
-    beta_x = softplus_beta * x
+    beta_x = softplus_beta * x
+    inv_beta = 1.0 / softplus_beta
@@
-        (1.0 / softplus_beta) * tl.log(1.0 + tl.exp(beta_x)),
+        inv_beta * tl.log(1.0 + tl.exp(beta_x)),
         x,
     )
-    b_g = -tl.exp(b_A_log) * softplus_x
+    b_g = -tl.exp(b_A_log) * softplus_x
@@
-    b_beta = 1.0 / (1.0 + tl.exp(-b_b))
+    b_beta = 1.0 / (1.0 + tl.exp(-tl.maximum(tl.minimum(b_b, 20.0), -20.0)))
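The same thresholded softplus and clamped sigmoid, written as an eager PyTorch reference (a sketch of the numerics only, not the Triton kernel; torch.nn.functional.softplus provides the same thresholding natively):

```python
import torch

def stable_softplus(x: torch.Tensor, beta: float = 1.0, threshold: float = 20.0) -> torch.Tensor:
    # For beta*x above the threshold, softplus(x) ~= x, which avoids exp() overflow.
    beta_x = beta * x
    return torch.where(beta_x <= threshold,
                       (1.0 / beta) * torch.log1p(torch.exp(beta_x)), x)

def stable_sigmoid(x: torch.Tensor, clip: float = 20.0) -> torch.Tensor:
    # Clamping keeps exp(-x) finite for very negative inputs.
    return torch.sigmoid(x.clamp(-clip, clip))
```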
116-121
: Minor: avoid redundant sqrt for L2 norm across masked lanesWhen K is not a power of two, masked lanes are zeroed which is fine. If you want stricter numerics, sum only valid lanes using
mask_k
in the reduction.Example:
-    b_q = b_q / (tl.sqrt(tl.sum(b_q * b_q)) + 1e-6)
-    b_k = b_k / (tl.sqrt(tl.sum(b_k * b_k)) + 1e-6)
+    q_norm = tl.sqrt(tl.sum((b_q * b_q))) + 1e-6
+    k_norm = tl.sqrt(tl.sum((b_k * b_k))) + 1e-6
+    b_q = b_q / q_norm
+    b_k = b_k / k_norm
180-183
: Kernel tile config may underutilize the GPU
num_warps=1, num_stages=3
andBV<=8
are conservative. Consider autotuning (or exposing meta-params) for K=256, V>=16 which is common in Qwen3Next.Happy to sketch a simple Triton autotune decorator here if you want.
examples/models/core/qwen/README.md (2)
930-939
: Clarify acronyms and tighten language.Spell out IFB and BS to avoid confusion; e.g., “In‑Flight Batching (IFB) is not supported yet; use batch_size=1.”
860-875
: YAML filename mismatch breaks disaggregated serve command.You write
disagg-config.yml
but run with-c disagg-config.yaml
. Align names:-trtllm-serve disaggregated -c disagg-config.yaml +trtllm-serve disaggregated -c disagg-config.ymltensorrt_llm/_torch/models/checkpoints/hf/qwen3_next_weight_mapper.py (2)
25-29
: Make draft model skip robust to qualified names.
startswith("draft_model")
misses nested names (e.g.,model.draft_model.*
). Match by suffix or substring.- if module_name.startswith("draft_model"): + last = module_name.split(".")[-1] + if last == "draft_model" or ".draft_model" in module_name: return True
64-66
: Assume config fields exist; add fallback or assert.Accessing
linear_*
fields will fail if absent. Add asserts or defaults.tensorrt_llm/_torch/models/checkpoints/hf/qwen2_moe_weight_mapper.py (1)
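A small sketch of the suggested guard (helper name is hypothetical; only linear_key_head_dim is taken from this review, other required fields could be appended to the tuple):

```python
def _require_config_fields(config, names=("linear_key_head_dim",)):
    # Fail fast with a clear message instead of an AttributeError deep inside weight mapping.
    missing = [n for n in names if getattr(config, n, None) is None]
    if missing:
        raise ValueError(f"Config is missing {missing}; expected a Qwen3-Next checkpoint.")
```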
27-33
: Remove or ticket the commented TODO.Prefer an issue or a clear TODO with owner/context; avoid dead commented code.
tensorrt_llm/_torch/models/modeling_utils.py (1)
958-961
: Harden conv1d squeeze special‑case.Guard presence/shape before squeezing to avoid key/shape errors on already‑squeezed tensors.
- if "linear_attn.conv1d" in name: - module_weights['weight'] = module_weights[ - 'weight'].squeeze(dim=1) + if "linear_attn.conv1d" in name and "weight" in module_weights: + w = module_weights["weight"] + if w.dim() >= 3 and w.shape[1] == 1: + module_weights["weight"] = w.squeeze(dim=1)tensorrt_llm/_torch/modules/attention.py (1)
515-525
: Avoid extra split/concat; compute gate without rebuilding qkv.Minor perf nit: we can split once and avoid rebuilding
qkv
. Optional:- if self.attn_output_gate: - q_gate, k, v = qkv.split( - [self.q_size * 2, self.kv_size, self.kv_size], dim=-1) + if self.attn_output_gate: + q_gate, k, v = qkv.split([self.q_size * 2, self.kv_size, self.kv_size], dim=-1) orig_shape = q_gate.shape[:-1] q_gate = q_gate.view(*orig_shape, self.num_heads, -1) q, gate = torch.chunk(q_gate, 2, dim=-1) q = q.reshape(*orig_shape, -1) gate = gate.reshape(*orig_shape, -1) - ### TODO: avoid the redundant split and concat - qkv = torch.concat([q, k, v], dim=-1) + qkv = torch.concat([q, k, v], dim=-1) # keep fused formattensorrt_llm/_torch/modules/rms_norm.py (2)
35-44
: Weight init differs for Gemma only when has_weights=True; consider “no-weights” parityIf
has_weights=False
, the buffer is always initialized to ones. In Gemma mode you then scale by(weight + 1)
, yielding a factor of 2 instead of 1. If this path is ever used, results will be biased. Align the no‑weights path with Gemma by initializing the buffer to zeros whenuse_gemma_rms_norm=True
.You can adjust the
else:
branch accordingly:

# inside __init__, has_weights == False
init = torch.zeros if use_gemma_rms_norm else torch.ones
self.register_buffer(
    'weight',
    init(hidden_size, dtype=dtype, device=device),
    persistent=False,
)
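For context, a minimal eager sketch of the Gemma-style scaling this refers to (reference only, not the flashinfer path):

```python
import torch

def gemma_rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Gemma stores weights as an offset from 1, so the effective scale is (weight + 1).
    # With weight initialized to zeros this is the identity scale; ones would double it.
    rstd = torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + eps)
    return (x.float() * rstd * (weight.float() + 1.0)).to(x.dtype)
```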
91-96
: Python 3.8+ compatibility and typing hygiene in this module
- Elsewhere in this file (
group_rms_norm
), built‑in generics andmatch
are used; those require Python 3.9+/3.10+. Our target is Python 3.8+. Please replacelist[...]
/tuple[...]
withList[...]
/Tuple[...]
andmatch
withif/elif
.I can provide a concrete diff for that block if desired.
Also, consider replacing mutable default args likeweights: Optional[list[Tensor]] = []
withNone
.tensorrt_llm/_torch/modules/qk_norm_attention.py (1)
29-33
: Typing: use typing.Tuple for Python 3.8Return annotation uses
tuple[...]
which isn’t valid on 3.8. Switch toTuple[...]
and import it.Apply this diff:
-from typing import Optional
+from typing import Optional, Tuple
@@
-def compute_yarn_parameters(
-    config: PretrainedConfig, ) -> tuple[float, float, float, float]:
+def compute_yarn_parameters(
+    config: PretrainedConfig, ) -> Tuple[float, float, float, float]:
37-40
: Silence Ruff ARG001 in fake registration by prefixing unused argsRename unused args in fakes to avoid noise in CI.
Apply this diff:
-    def _(input: torch.Tensor, weight: torch.Tensor,
-          eps: float) -> torch.Tensor:
+    def _(input: torch.Tensor, _weight: torch.Tensor,
+          _eps: float) -> torch.Tensor:
         return torch.empty_like(input)
51-60
: Same Ruff fix for fused Gemma fake opKeep consistency with the above change.
Apply this diff:
- def flashinfer_gemma_fused_add_rmsnorm(input: torch.Tensor, + def flashinfer_gemma_fused_add_rmsnorm(input: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float) -> None: gemma_fused_add_rmsnorm(input, residual, weight, eps, enable_pdl=ENABLE_PDL)And for the fake (if present elsewhere), prefix unused args similarly.
tensorrt_llm/_torch/models/modeling_qwen3_next.py (3)
197-199
: Prefer torch.sigmoid over deprecated F.sigmoidModern PyTorch deprecates
F.sigmoid
.Apply this diff:
-        shared_expert_output = F.sigmoid(
-            self.shared_expert_gate(hidden_states)) * shared_expert_output
+        shared_expert_output = torch.sigmoid(
+            self.shared_expert_gate(hidden_states)) * shared_expert_output
760-778
: Redundant compute:fix_query_key_value_ordering
called twice on the prefill pathYou call
fix_query_key_value_ordering
before theif
and again in theelse
. Drop the first call or guard it.Apply this diff to compute lazily:
- query, key, value, z, b, a = self.fix_query_key_value_ordering( - projected_states_qkvz, projected_states_ba) - if self.num_v_heads // self.num_k_heads in [1, 2, 4]: # and is_cuda_graph: mixed_qkv, z, b, a = fused_qkvzba_split_reshape_cat( projected_states_qkvz, projected_states_ba, triton.cdiv(self.num_k_heads, self.attn_tp_size), triton.cdiv(self.num_v_heads, self.attn_tp_size), self.head_k_dim, self.head_v_dim, ) else: query, key, value, z, b, a = self.fix_query_key_value_ordering( projected_states_qkvz, projected_states_ba)
1192-1204
: XOR check is fine; message could be clearer. The XOR condition is correct. Optionally refine the error message to “Specify exactly one of input_ids or inputs_embeds.”
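For completeness, the validation pattern under discussion as a standalone sketch:

```python
import torch
from typing import Optional

def check_inputs(input_ids: Optional[torch.Tensor],
                 inputs_embeds: Optional[torch.Tensor]) -> None:
    # Exactly one of the two inputs must be provided (logical XOR).
    if (input_ids is None) == (inputs_embeds is None):
        raise ValueError("Specify exactly one of input_ids or inputs_embeds.")
```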
tensorrt_llm/_torch/modules/fla/solve_tril.py (9)
15-24
: Consider re‑enabling Triton autotune behind a flag.Autotune blocks are commented out. Add an env‑gated autotune to retain performance without affecting reproducibility in CI.
Example:
-# @triton.autotune( +@triton.autotune( # configs=[...], # key=[...], -# ) +)Gate with:
if int(os.getenv("TRITON_AUTOTUNE", "0")): ...
49-69
: 16×16 kernel math/readability LGTM; add a brief comment on offset wrap.
offset = (i_t * 16) % BT
guaranteesoffset+o_i < BT
for BT∈{16,32,64}; add a one‑liner comment to avoid future regressions if BT set changes.
72-92
:BT
is unused in this kernel signature.Ruff ARG001; remove
BT
or use it in pointer shapes. It’s currently constant (32) in all block ptrs.Apply:
-def merge_16x16_to_32x32_inverse_kernel( +def merge_16x16_to_32x32_inverse_kernel( A, Ad, Ai, cu_seqlens, chunk_indices, T, H: tl.constexpr, - BT: tl.constexpr, IS_VARLEN: tl.constexpr, ):And drop
BT=
at the call site.
144-164
:BT
is unused in the 64×64 kernel too.Same ARG001; remove it like the 32×32 kernel, and drop the argument from the launcher.
- H: tl.constexpr, - BT: tl.constexpr, + H: tl.constexpr, IS_VARLEN: tl.constexpr,
314-356
: Zeroing strict upper blocks is fine; consider early exit when already zero‑initialized.If
Ai
is created withzeros_like
you can skip stores. Minor micro‑opt.
365-380
: Docstring mismatch and clarify API.You compute
(I + A)^-1
but the summary says “inverse of the lower triangular matrix.” Align text, and specify expected shapes for varlen.Apply:
-    Compute the inverse of the lower triangular matrix
-    A should be strictly lower triangular, i.e., A.triu() == 0.
+    Compute (I + A)^-1 for a strictly lower-triangular A (A.triu() == 0).
+    A has shape [B, T, H, BT], BT ∈ {16, 32, 64}. With cu_seqlens, sequences
+    are processed in 16-wide chunks (and merged to BT), optionally varlen.
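A small eager-mode reference for the property described in the docstring, handy when validating the Triton kernel (a sketch for a single [BT, BT] block; batch/head dims omitted):

```python
import torch

def inv_I_plus_A_ref(A: torch.Tensor) -> torch.Tensor:
    # A is strictly lower triangular (A.triu() == 0), so I + A is unit lower
    # triangular and always invertible; solve against the identity.
    BT = A.shape[-1]
    eye = torch.eye(BT, dtype=A.dtype, device=A.device)
    return torch.linalg.solve_triangular(eye + A, eye, upper=False)
```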
383-407
: Intermediate dtype choice is good; add an explicittorch.float32
.Use
torch.float32
overtorch.float
for clarity.- Ad = torch.empty(B, T, H, 16, device=A.device, - dtype=torch.float if BT != 16 else output_dtype) + Ad = torch.empty(B, T, H, 16, device=A.device, + dtype=torch.float32 if BT != 16 else output_dtype)
408-426
: Re-enable autotune or make warps/stages configurable.Hardcoding
num_warps/num_stages
is okay for now, but expose env overrides to simplify perf tuning without code changes.- num_warps=4, - num_stages=3, + num_warps=int(os.getenv("FLSOLVE_MERGE_WARPS", "4")), + num_stages=int(os.getenv("FLSOLVE_MERGE_STAGES", "3")),(Remember to
import os
.)
72-80
: Autotune configs: consider smaller num_warps on low‑SM GPUs.If you restore autotune, include configs with
num_warps=1,2
for 32×32/64×64 merges; helps older/lower‑end parts.Also applies to: 146-152
tensorrt_llm/_torch/modules/fla/utils.py (3)
107-108
: Fix incorrect type annotation for cache_entries.The type annotation suggests a tuple but the actual type is a list. This mismatch can confuse type checkers and developers.
- cache_entries: Tuple[Optional[Tuple], Optional[Dict], Any] = [] + cache_entries: list = []
85-85
: Add explicit stacklevel to warning call.The static analysis correctly identifies that the warning should specify a stacklevel for better debugging context.
- warnings.warn(msg) + warnings.warn(msg, stacklevel=2)
305-305
: Avoid catching bare Exception.Catching bare
Exception
is too broad and can mask programming errors.- except Exception: + except (KeyError, IndexError, RuntimeError):tensorrt_llm/_torch/modules/fla/wy_fast.py (1)
117-117
: Potential issue with unpacking syntax.The unpacking syntax
*k.shape
looks unusual and might cause issues. Consider using explicit tuple unpacking.- B, T, Hg, K, V = *k.shape, v.shape[-1] + B, T, Hg, K = k.shape + V = v.shape[-1]tensorrt_llm/_torch/modules/fla/index.py (2)
11-13
: Add boundary check for empty cu_seqlens.The function should validate that cu_seqlens has at least 2 elements to avoid index errors.
 @tensor_cache
 def prepare_lens(cu_seqlens: torch.LongTensor) -> torch.LongTensor:
+    if cu_seqlens.numel() < 2:
+        raise ValueError("cu_seqlens must have at least 2 elements")
     return cu_seqlens[1:] - cu_seqlens[:-1]
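For readers unfamiliar with the varlen convention, a tiny usage example of the function above (values are illustrative):

```python
import torch

# cu_seqlens holds cumulative sequence boundaries for a packed batch:
# three sequences of lengths 3, 4, and 5 packed back to back.
cu_seqlens = torch.tensor([0, 3, 7, 12], dtype=torch.long)
lens = cu_seqlens[1:] - cu_seqlens[:-1]   # tensor([3, 4, 5])
```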
19-22
: Potential performance issue with list comprehension and tolist().Converting to Python list and back to tensor can be inefficient for large sequences. Consider keeping operations in tensor space.
- indices = torch.cat([ - torch.arange(n) - for n in triton.cdiv(prepare_lens(cu_seqlens), chunk_size).tolist() - ]) + # More efficient: avoid Python list conversion + chunk_counts = triton.cdiv(prepare_lens(cu_seqlens), chunk_size) + indices = torch.cat([ + torch.arange(n.item()) + for n in chunk_counts + ])tensorrt_llm/_torch/modules/fla/chunk_o.py (1)
134-134
: Unpacking syntax issue.Similar to the previous file, the unpacking syntax could be clearer.
- B, T, Hg, K, V = *q.shape, v.shape[-1] + B, T, Hg, K = q.shape + V = v.shape[-1]tensorrt_llm/_torch/modules/fla/chunk.py (4)
139-141
: Docstring type: scale should be Optional[float], not Optional[int].- scale (Optional[int]): + scale (Optional[float]):
120-126
: PEP 484 Optional types for public API.Annotate optionals explicitly.
-def chunk_gated_delta_rule( +def chunk_gated_delta_rule( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, g: torch.Tensor, beta: torch.Tensor, - scale: float = None, - initial_state: torch.Tensor = None, + scale: Optional[float] = None, + initial_state: Optional[torch.Tensor] = None, output_final_state: bool = False, - cu_seqlens: Optional[torch.LongTensor] = None, + cu_seqlens: Optional[torch.LongTensor] = None, head_first: bool = False, use_qk_l2norm_in_kernel: bool = False, ):
81-93
: Remove stray pass and silence unused ctx per Ruff.- def forward( - ctx, + def forward( + _ctx, q: torch.Tensor, @@ - ): - pass + ):
99-110
: Silence unused unpacked values to satisfy linters.- g, o, A, final_state, w, h, v_new = chunk_gated_delta_rule_fwd( + _g, o, _A, final_state, _w, _h, _v_new = chunk_gated_delta_rule_fwd( q=q, k=k, v=v, g=g, beta=beta, scale=scale, initial_state=initial_state, output_final_state=output_final_state, cu_seqlens=cu_seqlens, )tensorrt_llm/_torch/modules/fla/chunk_delta_h.py (2)
236-246
: Return type annotation incorrect; function returns three values.-def chunk_gated_delta_rule_fwd_h( +def chunk_gated_delta_rule_fwd_h( k: torch.Tensor, @@ - cu_seqlens: Optional[torch.LongTensor] = None, -) -> Tuple[torch.Tensor, torch.Tensor]: + cu_seqlens: Optional[torch.LongTensor] = None, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
243-244
: Remove author-internal note or convert to TODO.- chunk_size: int = 64, # SY: remove this argument and force chunk size 64? + chunk_size: int = 64, # TODO: consider enforcing chunk size = 64tensorrt_llm/_torch/modules/fla/cumsum.py (3)
161-169
: PEP 484 Optional types for public API.-def chunk_local_cumsum_scalar( +def chunk_local_cumsum_scalar( g: torch.Tensor, chunk_size: int, reverse: bool = False, - scale: float = None, - cu_seqlens: Optional[torch.Tensor] = None, + scale: Optional[float] = None, + cu_seqlens: Optional[torch.Tensor] = None, head_first: bool = False, - output_dtype: Optional[torch.dtype] = torch.float, + output_dtype: Optional[torch.dtype] = torch.float, ) -> torch.Tensor:
200-208
: PEP 484 Optional types for public API.-def chunk_local_cumsum_vector( +def chunk_local_cumsum_vector( g: torch.Tensor, chunk_size: int, reverse: bool = False, - scale: float = None, - cu_seqlens: Optional[torch.Tensor] = None, + scale: Optional[float] = None, + cu_seqlens: Optional[torch.Tensor] = None, head_first: bool = False, - output_dtype: Optional[torch.dtype] = torch.float, + output_dtype: Optional[torch.dtype] = torch.float, ) -> torch.Tensor:
245-255
: Drop unused kwargs; make Optional explicit.-@input_guard -def chunk_local_cumsum( +@input_guard +def chunk_local_cumsum( g: torch.Tensor, chunk_size: int, reverse: bool = False, - scale: float = None, - cu_seqlens: Optional[torch.Tensor] = None, + scale: Optional[float] = None, + cu_seqlens: Optional[torch.Tensor] = None, head_first: bool = False, - output_dtype: Optional[torch.dtype] = torch.float, - **kwargs, + output_dtype: Optional[torch.dtype] = torch.float, ) -> torch.Tensor:tensorrt_llm/_torch/modules/fla/fused_recurrent.py (4)
214-225
: PEP 484 Optional types for public API.-def fused_recurrent_gated_delta_rule( +def fused_recurrent_gated_delta_rule( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, g: torch.Tensor, - beta: torch.Tensor = None, - scale: float = None, - initial_state: torch.Tensor = None, + beta: Optional[torch.Tensor] = None, + scale: Optional[float] = None, + initial_state: Optional[torch.Tensor] = None, output_final_state: bool = False, - cu_seqlens: Optional[torch.LongTensor] = None, + cu_seqlens: Optional[torch.LongTensor] = None, use_qk_l2norm_in_kernel: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor]:
573-588
: PEP 484 Optional types for update API.-def fused_recurrent_gated_delta_rule_update( +def fused_recurrent_gated_delta_rule_update( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, g: torch.Tensor, - beta: torch.Tensor = None, - scale: float = None, - initial_state_source: torch.Tensor = None, - initial_state_indices: torch.Tensor = None, - cu_seqlens: Optional[torch.LongTensor] = None, + beta: Optional[torch.Tensor] = None, + scale: Optional[float] = None, + initial_state_source: Optional[torch.Tensor] = None, + initial_state_indices: Optional[torch.Tensor] = None, + cu_seqlens: Optional[torch.LongTensor] = None, use_qk_l2norm_in_kernel: bool = False, disable_state_update: bool = False, disable_output_calculation: bool = False, - intermediate_states_buffer: Optional[torch.Tensor] = None, - cache_steps: Optional[int] = None, + intermediate_states_buffer: Optional[torch.Tensor] = None, + cache_steps: Optional[int] = None, ) -> torch.Tensor:
175-190
: Silence unused ctx per Ruff.- def forward( - ctx, + def forward( + _ctx,
564-570
: Silence unused ctx in backward; message is long (TRY003).Optional: shorten message or reference docs; rename ctx to _ctx.
- def backward(ctx, do, dht): - raise NotImplementedError( - "Backward pass is not implemented yet and we do not have plans to implement it " - "because we haven't figured out how to compute dg without materializing the full " - "hidden states for all time steps.") + def backward(_ctx, _do, _dht): + raise NotImplementedError("Backward is not implemented for this op.")
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (30)
cpp/tensorrt_llm/kernels/fusedQKNormRopeKernel.cu
(1 hunks)examples/models/core/qwen/README.md
(2 hunks)tensorrt_llm/_torch/custom_ops/__init__.py
(1 hunks)tensorrt_llm/_torch/custom_ops/flashinfer_custom_ops.py
(2 hunks)tensorrt_llm/_torch/models/__init__.py
(2 hunks)tensorrt_llm/_torch/models/checkpoints/__init__.py
(1 hunks)tensorrt_llm/_torch/models/checkpoints/hf/qwen2_moe_weight_mapper.py
(1 hunks)tensorrt_llm/_torch/models/checkpoints/hf/qwen3_next_weight_mapper.py
(1 hunks)tensorrt_llm/_torch/models/modeling_qwen3.py
(2 hunks)tensorrt_llm/_torch/models/modeling_qwen3_next.py
(1 hunks)tensorrt_llm/_torch/models/modeling_utils.py
(1 hunks)tensorrt_llm/_torch/modules/attention.py
(5 hunks)tensorrt_llm/_torch/modules/fla/chunk.py
(1 hunks)tensorrt_llm/_torch/modules/fla/chunk_delta_h.py
(1 hunks)tensorrt_llm/_torch/modules/fla/chunk_o.py
(1 hunks)tensorrt_llm/_torch/modules/fla/chunk_scaled_dot_kkt.py
(1 hunks)tensorrt_llm/_torch/modules/fla/cumsum.py
(1 hunks)tensorrt_llm/_torch/modules/fla/fused_recurrent.py
(1 hunks)tensorrt_llm/_torch/modules/fla/fused_sigmoid_gating_recurrent.py
(1 hunks)tensorrt_llm/_torch/modules/fla/index.py
(1 hunks)tensorrt_llm/_torch/modules/fla/l2norm.py
(1 hunks)tensorrt_llm/_torch/modules/fla/layernorm_gated.py
(1 hunks)tensorrt_llm/_torch/modules/fla/op.py
(1 hunks)tensorrt_llm/_torch/modules/fla/solve_tril.py
(1 hunks)tensorrt_llm/_torch/modules/fla/utils.py
(1 hunks)tensorrt_llm/_torch/modules/fla/wy_fast.py
(1 hunks)tensorrt_llm/_torch/modules/qk_norm_attention.py
(2 hunks)tensorrt_llm/_torch/modules/rms_norm.py
(3 hunks)tensorrt_llm/_torch/pyexecutor/_util.py
(3 hunks)tensorrt_llm/_torch/pyexecutor/config_utils.py
(1 hunks)
🧰 Additional context used
📓 Path-based instructions (5)
**/*.{h,hpp,hh,hxx,cpp,cxx,cc,cu,cuh,py}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Use only spaces, no tabs; indent with 4 spaces.
Files:
tensorrt_llm/_torch/models/checkpoints/hf/qwen2_moe_weight_mapper.py
cpp/tensorrt_llm/kernels/fusedQKNormRopeKernel.cu
tensorrt_llm/_torch/pyexecutor/config_utils.py
tensorrt_llm/_torch/custom_ops/__init__.py
tensorrt_llm/_torch/models/__init__.py
tensorrt_llm/_torch/modules/attention.py
tensorrt_llm/_torch/pyexecutor/_util.py
tensorrt_llm/_torch/models/checkpoints/__init__.py
tensorrt_llm/_torch/modules/fla/cumsum.py
tensorrt_llm/_torch/modules/fla/chunk_o.py
tensorrt_llm/_torch/models/checkpoints/hf/qwen3_next_weight_mapper.py
tensorrt_llm/_torch/modules/rms_norm.py
tensorrt_llm/_torch/modules/fla/wy_fast.py
tensorrt_llm/_torch/modules/qk_norm_attention.py
tensorrt_llm/_torch/custom_ops/flashinfer_custom_ops.py
tensorrt_llm/_torch/modules/fla/chunk.py
tensorrt_llm/_torch/modules/fla/op.py
tensorrt_llm/_torch/models/modeling_utils.py
tensorrt_llm/_torch/modules/fla/chunk_delta_h.py
tensorrt_llm/_torch/modules/fla/index.py
tensorrt_llm/_torch/modules/fla/layernorm_gated.py
tensorrt_llm/_torch/modules/fla/chunk_scaled_dot_kkt.py
tensorrt_llm/_torch/models/modeling_qwen3.py
tensorrt_llm/_torch/modules/fla/solve_tril.py
tensorrt_llm/_torch/modules/fla/fused_recurrent.py
tensorrt_llm/_torch/modules/fla/fused_sigmoid_gating_recurrent.py
tensorrt_llm/_torch/modules/fla/utils.py
tensorrt_llm/_torch/models/modeling_qwen3_next.py
tensorrt_llm/_torch/modules/fla/l2norm.py
**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.py
: Python code must target Python 3.8+.
Indent Python code with 4 spaces; do not use tabs.
Maintain module namespace when importing; prefer 'from package.subpackage import foo' then 'foo.SomeClass()' instead of importing the class directly.
Python filenames should be snake_case (e.g., some_file.py).
Python classes use PascalCase names.
Functions and methods use snake_case names.
Local variables use snake_case; prefix 'k' for variables that start with a number (e.g., k_99th_percentile).
Global variables use upper SNAKE_CASE prefixed with 'G' (e.g., G_MY_GLOBAL).
Constants use upper SNAKE_CASE (e.g., MY_CONSTANT).
Avoid shadowing variables from an outer scope.
Initialize all externally visible members of a class in the constructor.
Prefer docstrings for interfaces that may be used outside a file; comments for in-function or file-local interfaces.
Use Google-style docstrings for classes and functions (Sphinx-parsable).
Document attributes and variables inline so they render under the class/function docstring.
Avoid reflection when a simpler, explicit approach suffices (e.g., avoid dict(**locals()) patterns).
In try/except, catch the most specific exceptions possible.
For duck-typing try/except, keep the try body minimal and use else for the main logic.
Files:
tensorrt_llm/_torch/models/checkpoints/hf/qwen2_moe_weight_mapper.py
tensorrt_llm/_torch/pyexecutor/config_utils.py
tensorrt_llm/_torch/custom_ops/__init__.py
tensorrt_llm/_torch/models/__init__.py
tensorrt_llm/_torch/modules/attention.py
tensorrt_llm/_torch/pyexecutor/_util.py
tensorrt_llm/_torch/models/checkpoints/__init__.py
tensorrt_llm/_torch/modules/fla/cumsum.py
tensorrt_llm/_torch/modules/fla/chunk_o.py
tensorrt_llm/_torch/models/checkpoints/hf/qwen3_next_weight_mapper.py
tensorrt_llm/_torch/modules/rms_norm.py
tensorrt_llm/_torch/modules/fla/wy_fast.py
tensorrt_llm/_torch/modules/qk_norm_attention.py
tensorrt_llm/_torch/custom_ops/flashinfer_custom_ops.py
tensorrt_llm/_torch/modules/fla/chunk.py
tensorrt_llm/_torch/modules/fla/op.py
tensorrt_llm/_torch/models/modeling_utils.py
tensorrt_llm/_torch/modules/fla/chunk_delta_h.py
tensorrt_llm/_torch/modules/fla/index.py
tensorrt_llm/_torch/modules/fla/layernorm_gated.py
tensorrt_llm/_torch/modules/fla/chunk_scaled_dot_kkt.py
tensorrt_llm/_torch/models/modeling_qwen3.py
tensorrt_llm/_torch/modules/fla/solve_tril.py
tensorrt_llm/_torch/modules/fla/fused_recurrent.py
tensorrt_llm/_torch/modules/fla/fused_sigmoid_gating_recurrent.py
tensorrt_llm/_torch/modules/fla/utils.py
tensorrt_llm/_torch/models/modeling_qwen3_next.py
tensorrt_llm/_torch/modules/fla/l2norm.py
**/*.{cpp,cxx,cc,h,hpp,hh,hxx,cu,cuh,py}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Prepend the NVIDIA Apache-2.0 copyright header with current year to the top of all source files (e.g., .cpp, .h, .cu, .py).
Files:
tensorrt_llm/_torch/models/checkpoints/hf/qwen2_moe_weight_mapper.py
cpp/tensorrt_llm/kernels/fusedQKNormRopeKernel.cu
tensorrt_llm/_torch/pyexecutor/config_utils.py
tensorrt_llm/_torch/custom_ops/__init__.py
tensorrt_llm/_torch/models/__init__.py
tensorrt_llm/_torch/modules/attention.py
tensorrt_llm/_torch/pyexecutor/_util.py
tensorrt_llm/_torch/models/checkpoints/__init__.py
tensorrt_llm/_torch/modules/fla/cumsum.py
tensorrt_llm/_torch/modules/fla/chunk_o.py
tensorrt_llm/_torch/models/checkpoints/hf/qwen3_next_weight_mapper.py
tensorrt_llm/_torch/modules/rms_norm.py
tensorrt_llm/_torch/modules/fla/wy_fast.py
tensorrt_llm/_torch/modules/qk_norm_attention.py
tensorrt_llm/_torch/custom_ops/flashinfer_custom_ops.py
tensorrt_llm/_torch/modules/fla/chunk.py
tensorrt_llm/_torch/modules/fla/op.py
tensorrt_llm/_torch/models/modeling_utils.py
tensorrt_llm/_torch/modules/fla/chunk_delta_h.py
tensorrt_llm/_torch/modules/fla/index.py
tensorrt_llm/_torch/modules/fla/layernorm_gated.py
tensorrt_llm/_torch/modules/fla/chunk_scaled_dot_kkt.py
tensorrt_llm/_torch/models/modeling_qwen3.py
tensorrt_llm/_torch/modules/fla/solve_tril.py
tensorrt_llm/_torch/modules/fla/fused_recurrent.py
tensorrt_llm/_torch/modules/fla/fused_sigmoid_gating_recurrent.py
tensorrt_llm/_torch/modules/fla/utils.py
tensorrt_llm/_torch/models/modeling_qwen3_next.py
tensorrt_llm/_torch/modules/fla/l2norm.py
**/*.{h,hpp,hh,hxx,cpp,cxx,cc,cu,cuh}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.{h,hpp,hh,hxx,cpp,cxx,cc,cu,cuh}
: Namespace closing braces must include a trailing comment with the namespace name (e.g., '} // namespace foo').
Prefer const or constexpr variables over #define for constants.
Declare variables that are not modified after initialization as const.
Avoid magic literals in code; except for 0, nullptr, true, false. Use named constants for comparisons and logic.
Use Allman brace style for formatting.
Place the semicolon of an empty for/while loop on a new line.
Bodies of switch/while/do-while/for must be compound statements (brace-delimited), and if/else must always be followed by brace-delimited statements.
Type names (e.g., classes) must be CamelCase starting with an uppercase letter (e.g., FooBar).
Local variables, methods, and namespaces use lowerCamelCase (e.g., localFooBar).
Non-magic-number global variables that are non-static and not in an anonymous namespace must be lowerCamelCase prefixed with 'g' (e.g., gDontUseGlobalFoos).
Non-magic-number globals that are static or in an anonymous namespace use lowerCamelCase prefixed with 's' (e.g., sMutableStaticGlobal).
Locally visible static variables use lowerCamelCase with 's' prefix (e.g., static std::once_flag sFlag).
Private/protected member variables use 'm' prefix with CamelCase (e.g., mNbFooValues). Public members may omit, but 'm' is encouraged for clarity.
Constants (enums, global constants, static constants, and function-scope magic/literal constants) use uppercase SNAKE_CASE with 'k' prefix (e.g., kDIGIT_NUM).
Function-scope constants that are not magic numbers or literals are named like non-constant variables (e.g., bool const pass = a && b).
If macros are necessary, name them in UPPER_SNAKE_CASE (e.g., FOO_VERSION) and prefer constants over #define.
Use LLVM clang-format; wrap lines at a maximum of 120 columns; use '// clang-format off/on' sparingly with justification.
Use smart pointers for heap allocations; prefer unique_ptr for sole ownership, shared_ptr for shared...
Files:
cpp/tensorrt_llm/kernels/fusedQKNormRopeKernel.cu
**/*.{cpp,cxx,cc,cu,h,hpp,hh,hxx,cuh}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
C++ filenames should be lowerCamelCase (first letter lowercase) and must be case-insensitive unique within a compilation target.
Files:
cpp/tensorrt_llm/kernels/fusedQKNormRopeKernel.cu
🧠 Learnings (1)
📚 Learning: 2025-08-27T14:41:56.665Z
Learnt from: ixlmar
PR: NVIDIA/TensorRT-LLM#7294
File: tensorrt_llm/_torch/modules/rms_norm.py:96-99
Timestamp: 2025-08-27T14:41:56.665Z
Learning: In tensorrt_llm/_torch/modules/rms_norm.py, the RMSNorm class uses a custom sentinel (_ARGUMENT_NOT_SPECIFIED_SENTINEL) instead of Ellipsis (...) for detecting unspecified optional arguments. Other modules in the codebase may use Ellipsis as a sentinel but do not forward it to RMSNorm methods, so there's no need for backward compatibility with Ellipsis in RMSNorm.
Applied to files:
tensorrt_llm/_torch/modules/rms_norm.py
🧬 Code graph analysis (24)
tensorrt_llm/_torch/pyexecutor/config_utils.py (1)
tensorrt_llm/_torch/models/modeling_utils.py (1)
config
(519-520)
tensorrt_llm/_torch/custom_ops/__init__.py (1)
tensorrt_llm/_torch/custom_ops/flashinfer_custom_ops.py (5)
flashinfer_fused_add_rmsnorm
(44-47)flashinfer_gemma_fused_add_rmsnorm
(51-59)flashinfer_gemma_rmsnorm
(33-35)flashinfer_rmsnorm
(22-24)flashinfer_silu_and_mul
(13-14)
tensorrt_llm/_torch/models/__init__.py (1)
tensorrt_llm/_torch/models/modeling_qwen3_next.py (1)
Qwen3NextForCausalLM
(1231-1254)
tensorrt_llm/_torch/modules/attention.py (2)
tensorrt_llm/logger.py (1)
warning_once
(135-136)tensorrt_llm/_torch/modules/qk_norm_attention.py (1)
apply_rope
(240-257)
tensorrt_llm/_torch/pyexecutor/_util.py (3)
tensorrt_llm/_torch/pyexecutor/config_utils.py (1)
is_qwen3_next
(16-17)tensorrt_llm/_torch/models/modeling_utils.py (1)
config
(519-520)tensorrt_llm/_torch/pyexecutor/mamba_cache_manager.py (1)
MambaHybridCacheManager
(167-246)
tensorrt_llm/_torch/models/checkpoints/__init__.py (2)
tensorrt_llm/_torch/models/checkpoints/hf/qwen3_next_weight_mapper.py (1)
Qwen3NextHfWeightMapper
(15-105)tensorrt_llm/_torch/models/checkpoints/hf/weight_mapper.py (1)
HfWeightMapper
(10-99)
tensorrt_llm/_torch/modules/fla/cumsum.py (2)
tensorrt_llm/_torch/modules/fla/index.py (1)
prepare_chunk_indices
(17-23)tensorrt_llm/_torch/modules/fla/utils.py (2)
check_shared_mem
(300-306)input_guard
(133-166)
tensorrt_llm/_torch/modules/fla/chunk_o.py (3)
tensorrt_llm/_torch/modules/fla/index.py (1)
prepare_chunk_indices
(17-23)tensorrt_llm/_torch/modules/fla/op.py (1)
safe_exp
(26-27)tensorrt_llm/_torch/modules/fla/utils.py (1)
check_shared_mem
(300-306)
tensorrt_llm/_torch/models/checkpoints/hf/qwen3_next_weight_mapper.py (2)
tensorrt_llm/_torch/models/modeling_utils.py (3)
register_mapper
(655-666)DecoderModelForCausalLM
(342-603)config
(519-520)tensorrt_llm/_torch/models/checkpoints/hf/weight_mapper.py (1)
_duplicate_kv
(76-99)
tensorrt_llm/_torch/modules/rms_norm.py (1)
tensorrt_llm/_torch/custom_ops/flashinfer_custom_ops.py (4)
flashinfer_gemma_fused_add_rmsnorm
(51-59)flashinfer_gemma_rmsnorm
(33-35)flashinfer_rmsnorm
(22-24)flashinfer_fused_add_rmsnorm
(44-47)
tensorrt_llm/_torch/modules/fla/wy_fast.py (1)
tensorrt_llm/_torch/modules/fla/index.py (1)
prepare_chunk_indices
(17-23)
tensorrt_llm/_torch/modules/qk_norm_attention.py (1)
tensorrt_llm/_torch/modules/rms_norm.py (1)
RMSNorm
(25-110)
tensorrt_llm/_torch/custom_ops/flashinfer_custom_ops.py (1)
tests/unittest/_torch/misc/test_autotuner.py (2)
_
(129-130)_
(203-204)
tensorrt_llm/_torch/modules/fla/chunk.py (8)
tensorrt_llm/_torch/modules/fla/chunk_delta_h.py (1)
chunk_gated_delta_rule_fwd_h
(236-294)tensorrt_llm/_torch/modules/fla/chunk_o.py (1)
chunk_fwd_o
(124-169)tensorrt_llm/_torch/modules/fla/chunk_scaled_dot_kkt.py (1)
chunk_scaled_dot_kkt_fwd
(88-143)tensorrt_llm/_torch/modules/fla/cumsum.py (1)
chunk_local_cumsum
(246-282)tensorrt_llm/_torch/modules/fla/l2norm.py (2)
l2norm
(133-136)l2norm_fwd
(74-122)tensorrt_llm/_torch/modules/fla/solve_tril.py (1)
solve_tril
(360-426)tensorrt_llm/_torch/modules/fla/utils.py (1)
input_guard
(133-166)tensorrt_llm/_torch/modules/fla/wy_fast.py (1)
recompute_w_u_fwd
(109-149)
tensorrt_llm/_torch/modules/fla/chunk_delta_h.py (2)
tensorrt_llm/_torch/modules/fla/index.py (2)
prepare_chunk_indices
(17-23)prepare_chunk_offsets
(27-32)tensorrt_llm/_torch/modules/fla/op.py (1)
safe_exp
(26-27)
tensorrt_llm/_torch/modules/fla/index.py (1)
tensorrt_llm/_torch/modules/fla/utils.py (1)
tensor_cache
(93-130)
tensorrt_llm/_torch/modules/fla/layernorm_gated.py (2)
tensorrt_llm/functional.py (3)
rearrange
(6144-6287)sqrt
(450-454)sum
(3253-3275)tensorrt_llm/_torch/modules/rms_norm.py (1)
RMSNorm
(25-110)
tensorrt_llm/_torch/modules/fla/chunk_scaled_dot_kkt.py (2)
tensorrt_llm/_torch/modules/fla/index.py (1)
prepare_chunk_indices
(17-23)tensorrt_llm/_torch/modules/fla/op.py (1)
safe_exp
(26-27)
tensorrt_llm/_torch/models/modeling_qwen3.py (2)
tensorrt_llm/_torch/models/modeling_utils.py (1)
config
(519-520)tensorrt_llm/_torch/model_config.py (1)
torch_dtype
(199-204)
tensorrt_llm/_torch/modules/fla/solve_tril.py (2)
tensorrt_llm/_torch/modules/fla/index.py (1)
prepare_chunk_indices
(17-23)tensorrt_llm/_torch/modules/fla/utils.py (1)
input_guard
(133-166)
tensorrt_llm/_torch/modules/fla/fused_recurrent.py (1)
tensorrt_llm/_torch/modules/fla/utils.py (1)
input_guard
(133-166)
tensorrt_llm/_torch/modules/fla/fused_sigmoid_gating_recurrent.py (3)
tensorrt_llm/_torch/modules/fla/utils.py (1)
input_guard
(133-166)tensorrt_llm/_torch/modules/fla/chunk_delta_h.py (1)
grid
(270-271)tensorrt_llm/_torch/modules/fla/chunk_o.py (1)
grid
(145-146)
tensorrt_llm/_torch/models/modeling_qwen3_next.py (13)
tensorrt_llm/_torch/models/checkpoints/base_weight_mapper.py (1)
BaseWeightMapper
(10-165)tensorrt_llm/_torch/modules/fla/chunk.py (2)
chunk_gated_delta_rule
(114-238)forward
(80-110)tensorrt_llm/_torch/modules/fla/fused_sigmoid_gating_recurrent.py (1)
fused_sigmoid_gating_delta_rule_update
(156-222)tensorrt_llm/_torch/modules/mamba/mamba2_metadata.py (1)
Mamba2Metadata
(88-137)tensorrt_llm/mapping.py (1)
Mapping
(32-513)tensorrt_llm/_torch/attention_backend/interface.py (2)
AttentionMetadata
(40-336)num_ctx_tokens
(263-264)tensorrt_llm/_torch/distributed/ops.py (2)
AllReduce
(400-553)MoEAllReduce
(556-644)tensorrt_llm/_torch/modules/embedding.py (1)
Embedding
(162-244)tensorrt_llm/_torch/modules/fused_moe/create_moe.py (1)
create_moe
(61-211)tensorrt_llm/_torch/modules/mamba/causal_conv1d.py (2)
causal_conv1d_fn
(26-74)causal_conv1d_update
(77-118)tensorrt_llm/_torch/modules/rms_norm.py (2)
RMSNorm
(25-110)forward
(54-100)tensorrt_llm/_torch/models/modeling_speculative.py (1)
SpecDecOneEngineForCausalLM
(377-496)tensorrt_llm/_torch/models/modeling_utils.py (7)
DecoderModel
(228-321)EagerFusionConfig
(40-44)register_auto_model
(614-620)forward
(242-271)forward
(526-552)load_weights
(554-572)config
(519-520)
tensorrt_llm/_torch/modules/fla/l2norm.py (2)
tensorrt_llm/_torch/modules/fla/utils.py (1)
input_guard
(133-166)tensorrt_llm/_torch/modules/fla/chunk.py (1)
forward
(80-110)
🪛 Ruff (0.13.1)
tensorrt_llm/_torch/pyexecutor/_util.py
456-458: Avoid specifying long messages outside the exception class
(TRY003)
tensorrt_llm/_torch/modules/fla/cumsum.py
33-33: Unused function argument: B
(ARG001)
92-92: Unused function argument: B
(ARG001)
165-165: PEP 484 prohibits implicit Optional
Convert to Optional[T]
(RUF013)
204-204: PEP 484 prohibits implicit Optional
Convert to Optional[T]
(RUF013)
250-250: PEP 484 prohibits implicit Optional
Convert to Optional[T]
(RUF013)
254-254: Unused function argument: kwargs
(ARG001)
280-282: Avoid specifying long messages outside the exception class
(TRY003)
tensorrt_llm/_torch/custom_ops/flashinfer_custom_ops.py
38-38: Unused function argument: weight
(ARG001)
39-39: Unused function argument: eps
(ARG001)
tensorrt_llm/_torch/modules/fla/chunk.py
81-81: Unused static method argument: ctx
(ARG004)
99-99: Unpacked variable A
is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
99-99: Unpacked variable w
is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
99-99: Unpacked variable h
is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
99-99: Unpacked variable v_new
is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
120-120: PEP 484 prohibits implicit Optional
Convert to Optional[T]
(RUF013)
199-201: Avoid specifying long messages outside the exception class
(TRY003)
213-215: Avoid specifying long messages outside the exception class
(TRY003)
218-221: Avoid specifying long messages outside the exception class
(TRY003)
tensorrt_llm/_torch/modules/fla/op.py
33-33: Unused function argument: src
(ARG001)
33-33: Unused function argument: index
(ARG001)
33-33: Unused function argument: axis
(ARG001)
59-59: Unused function argument: base
(ARG001)
60-60: Unused function argument: shape
(ARG001)
61-61: Unused function argument: strides
(ARG001)
62-62: Unused function argument: block_shape
(ARG001)
tensorrt_llm/_torch/modules/fla/layernorm_gated.py
156-157: Avoid specifying long messages outside the exception class
(TRY003)
188-188: Unused static method argument: ctx
(ARG004)
213-213: Unpacked variable mean
is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
213-213: Unpacked variable rstd
is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
tensorrt_llm/_torch/modules/fla/solve_tril.py
90-90: Unused function argument: BT
(ARG001)
162-162: Unused function argument: BT
(ARG001)
tensorrt_llm/_torch/modules/fla/fused_recurrent.py
178-178: Unused static method argument: ctx
(ARG004)
220-220: PEP 484 prohibits implicit Optional
Convert to Optional[T]
(RUF013)
287-289: Avoid specifying long messages outside the exception class
(TRY003)
292-295: Avoid specifying long messages outside the exception class
(TRY003)
529-529: Unused static method argument: ctx
(ARG004)
579-579: PEP 484 prohibits implicit Optional
Convert to Optional[T]
(RUF013)
591-593: Avoid specifying long messages outside the exception class
(TRY003)
596-599: Avoid specifying long messages outside the exception class
(TRY003)
tensorrt_llm/_torch/modules/fla/utils.py
85-85: No explicit stacklevel
keyword argument found
Set stacklevel=2
(B028)
226-226: Do not catch blind exception: BaseException
(BLE001)
235-235: Do not catch blind exception: BaseException
(BLE001)
280-280: Do not catch blind exception: BaseException
(BLE001)
305-305: Do not catch blind exception: Exception
(BLE001)
tensorrt_llm/_torch/models/modeling_qwen3_next.py
66-66: Unused method argument: moe_backend
(ARG002)
97-98: Avoid specifying long messages outside the exception class
(TRY003)
175-175: Do not assert False
(python -O
removes these calls), raise AssertionError()
Replace assert False
(B011)
503-508: Consider iterable unpacking instead of concatenation
Replace with iterable unpacking
(RUF005)
509-512: Consider iterable unpacking instead of concatenation
Replace with iterable unpacking
(RUF005)
735-735: Unpacked variable seq_len
is never used
Prefix it with an underscore or any other dummy variable pattern
(RUF059)
736-736: Local variable conv_state
is assigned to but never used
Remove assignment to unused variable conv_state
(F841)
736-736: Local variable recurrent_state
is assigned to but never used
Remove assignment to unused variable recurrent_state
(F841)
878-878: Unused method argument: position_ids
(ARG002)
1199-1199: Unused method argument: kwargs
(ARG002)
1202-1204: Avoid specifying long messages outside the exception class
(TRY003)
tensorrt_llm/_torch/modules/fla/l2norm.py
59-59: Unused function argument: NB
(ARG001)
91-91: Avoid specifying long messages outside the exception class
(TRY003)
129-129: Unused static method argument: ctx
(ARG004)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Pre-commit Check
Review continued from previous batch...
PR_Github #19484 [ run ] triggered by Bot |
PR_Github #19484 [ run ] completed with state |
/bot run |
PR_Github #19646 [ run ] triggered by Bot |
/bot run |
PR_Github #19652 [ run ] triggered by Bot |
PR_Github #19646 [ run ] completed with state |
PR_Github #19652 [ run ] completed with state |
395626f to c740975
/bot run |
Signed-off-by: bhsueh <11360707+byshiue@users.noreply.github.com>
Signed-off-by: bhsueh <11360707+byshiue@users.noreply.github.com>
Signed-off-by: bhsueh <11360707+byshiue@users.noreply.github.com>
Signed-off-by: bhsueh <11360707+byshiue@users.noreply.github.com>
Signed-off-by: nv-guomingz <137257613+nv-guomingz@users.noreply.github.com>
Signed-off-by: nv-guomingz <137257613+nv-guomingz@users.noreply.github.com>
caed68d to 5f9e254
/bot run --disable-fail-fast |
PR_Github #20191 [ run ] triggered by Bot |
PR_Github #20191 [ run ] completed with state |
/bot run |
PR_Github #20214 [ run ] triggered by Bot |
Reviewed attention.py
PR_Github #20214 [ run ] completed with state |
/bot run |
PR_Github #20256 [ run ] triggered by Bot |
PR_Github #20256 [ run ] completed with state |
Summary by CodeRabbit
New Features
Performance
Documentation
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.