[https://nvbugs/6299530][fix] Capture Qwen3.5 GDN for piecewise CUDA …#15594
liji-nv wants to merge 2 commits into
/bot run --disable-fail-fast
PR_Github #55499 [ run ] triggered by Bot. Commit:
📝 Walkthrough
This PR enables piecewise CUDA graph compilation for Qwen3Next GatedDeltaNet (GDN) layers. It propagates preallocated …
Changes
GDN Piecewise CUDA Graph Support
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
🚥 Pre-merge checks: ✅ 5 passed
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (3)
tensorrt_llm/_torch/modules/fla/flashinfer_chunk.py (1)
118-123: 🩺 Stability & Availability | 🟠 Major | ⚡ Quick win
Check `output` before reusing it as FlashInfer’s output buffer.
Line 122 assumes `output.squeeze(0)` has shape `[T, num_o_heads, head_size]` and a compatible contiguous layout. Validate this before passing it as `output=` so bad caller buffers fail deterministically.
Suggested guard
```diff
     total_seq_len = q3.shape[0]
     num_o_heads = max(q3.shape[1], v3.shape[1])
     head_size = q3.shape[2]
     need_state = inplace_indexed_state_update or output_final_state
-    output_buf = output.squeeze(0) if output is not None else q3.new_empty(
-        total_seq_len, num_o_heads, head_size)
+    if output is not None:
+        expected_shape = (1, total_seq_len, num_o_heads, head_size)
+        if output.shape != expected_shape or output.dtype != q3.dtype or output.device != q3.device:
+            raise ValueError(
+                "`output` must match FlashInfer output shape/dtype/device; "
+                f"got {tuple(output.shape)}/{output.dtype}/{output.device}, "
+                f"expected {expected_shape}/{q3.dtype}/{q3.device}"
+            )
+        if not output.is_contiguous():
+            raise ValueError("`output` must be contiguous")
+        output_buf = output.squeeze(0)
+    else:
+        output_buf = q3.new_empty(total_seq_len, num_o_heads, head_size)
```
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tensorrt_llm/_torch/modules/fla/flashinfer_chunk.py` around lines 118 - 123, In flashinfer_chunk.py, the output buffer handling in the FlashInfer path currently reuses output via FlashInfer’s output buffer without validating its shape or layout. Update the logic around the output.squeeze(0) reuse in the chunk/forward flow to first check that a caller-provided output matches [T, num_o_heads, head_size] and is contiguous/compatible before passing it as output=; otherwise fall back to allocating a fresh buffer or raise a clear error from the same code path.
tensorrt_llm/_torch/modules/fla/fused_recurrent.py (1)
130-141: 🩺 Stability & Availability | 🟠 Major | ⚡ Quick win
Validate `output` before unsqueezing it into the Triton destination.
Both paths pass caller-owned `output` into kernels that assume the same dense layout as `v`. Add shape/dtype/device/contiguity checks before `output.unsqueeze(0)` to avoid bad writes.
Suggested helper
```diff
+def _validate_recurrent_output(output: torch.Tensor, v: torch.Tensor) -> None:
+    if output.shape != v.shape or output.dtype != v.dtype or output.device != v.device:
+        raise ValueError(
+            "`output` must match `v` in shape, dtype, and device; "
+            f"got output={tuple(output.shape)}/{output.dtype}/{output.device}, "
+            f"v={tuple(v.shape)}/{v.dtype}/{v.device}"
+        )
+    if not output.is_contiguous():
+        raise ValueError("`output` must be contiguous for fused recurrent kernels")
```
Then call it before each `output.unsqueeze(0)` branch.
Also applies to: 483-498
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tensorrt_llm/_torch/modules/fla/fused_recurrent.py` around lines 130 - 141, The `output` tensor is passed into Triton kernels via `fused_recurrent` and related call sites, but it is not validated before `output.unsqueeze(0)` is used as the destination. Add a helper in this module to check `output`’s shape, dtype, device, and contiguity against `v`, and invoke it before each `output.unsqueeze(0)` branch so only compatible dense tensors reach the kernel.
tensorrt_llm/_torch/modules/fla/chunk_o.py (1)
133-144: 🩺 Stability & Availability | 🟠 Major | ⚡ Quick win
Validate the caller-provided output buffer before launching the Triton kernel.
Line 144 now lets callers supply `o`, but `chunk_fwd_kernel_o` writes with raw contiguous pointer arithmetic based on `v.shape`. A wrong dtype/device/shape/stride can silently write incorrect memory. Add a local contract check before assigning `o`.
Suggested guard
```diff
-    o = output if output is not None else torch.empty_like(v)
+    if output is not None:
+        if output.shape != v.shape or output.dtype != v.dtype or output.device != v.device:
+            raise ValueError(
+                "`output` must match `v` in shape, dtype, and device; "
+                f"got output={tuple(output.shape)}/{output.dtype}/{output.device}, "
+                f"v={tuple(v.shape)}/{v.dtype}/{v.device}"
+            )
+        if not output.is_contiguous():
+            raise ValueError("`output` must be contiguous for chunk_fwd_kernel_o")
+        o = output
+    else:
+        o = torch.empty_like(v)
```
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tensorrt_llm/_torch/modules/fla/chunk_o.py` around lines 133 - 144, Validate the caller-supplied output buffer in chunk_fwd_kernel_o before using it for the Triton launch: ensure output/o matches the expected tensor dtype, device, shape, and contiguity/stride layout derived from v.shape and q.shape. Add a local assertion or explicit contract check right before assigning o so invalid buffers fail fast instead of allowing raw pointer writes to corrupt memory.
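The three guards suggested above are nearly identical, so one option is a single shared helper. The following is only a sketch of that idea: `validate_output_buffer` and the `FakeTensor` stand-in are hypothetical names, not part of the PR. The stand-in exposes only the attributes the check reads (`shape`, `dtype`, `device`, `is_contiguous()`), which `torch.Tensor` also provides, so the example runs without PyTorch installed.

```python
from dataclasses import dataclass
from typing import Any, Tuple


def validate_output_buffer(output: Any, expected_shape: Tuple[int, ...],
                           expected_dtype: Any, expected_device: Any,
                           name: str = "output") -> None:
    """Raise ValueError unless `output` matches the expected dense layout."""
    if (tuple(output.shape) != tuple(expected_shape)
            or output.dtype != expected_dtype
            or output.device != expected_device):
        raise ValueError(
            f"`{name}` must match the expected shape/dtype/device; "
            f"got {tuple(output.shape)}/{output.dtype}/{output.device}, "
            f"expected {tuple(expected_shape)}/{expected_dtype}/{expected_device}")
    if not output.is_contiguous():
        raise ValueError(f"`{name}` must be contiguous")


@dataclass
class FakeTensor:
    """Hypothetical stand-in exposing only what the check reads."""
    shape: Tuple[int, ...]
    dtype: str = "bfloat16"
    device: str = "cuda:0"
    contiguous: bool = True

    def is_contiguous(self) -> bool:
        return self.contiguous


v = FakeTensor((1, 8, 4, 16))
# A matching buffer passes silently.
validate_output_buffer(FakeTensor((1, 8, 4, 16)), v.shape, v.dtype, v.device)
```

Each call site would pass the shape it expects (for example `v.shape` in `chunk_o.py`), which keeps the error message format the same across the three files.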
🧹 Nitpick comments (2)
tensorrt_llm/_torch/compilation/piecewise_optimizer.py (1)
22-32: 📐 Maintainability & Code Quality | 🔵 Trivial | ⚡ Quick win
Add a return type annotation to `_piecewise_boundary_ops`.
This new function is missing a return annotation.
As per coding guidelines, "Always annotate functions with return types (use `None` if no return)."
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tensorrt_llm/_torch/compilation/piecewise_optimizer.py` around lines 22 - 32, The helper _piecewise_boundary_ops currently lacks the required return type annotation. Update the function signature for _piecewise_boundary_ops to explicitly declare its return type based on the list of ops it builds, keeping the implementation unchanged and ensuring it follows the project’s function annotation guidelines.
Source: Coding guidelines
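For illustration, an annotated helper of this kind might look like the sketch below. The diff for `_piecewise_boundary_ops` is not shown in this thread, so the body is an assumption: it collects whichever ops exist on a namespace and skips the ones that are not registered. The namespace object and the `attn_custom_op_inplace` name are hypothetical; `gdn_custom_op_inplace` is the op named in the inline comment.

```python
from types import SimpleNamespace
from typing import Any, List, Sequence


def piecewise_boundary_ops(namespace: Any, op_names: Sequence[str]) -> List[Any]:
    """Return the ops found on `namespace`, skipping names that are not registered."""
    ops: List[Any] = []
    for op_name in op_names:
        # getattr with a default keeps the lookup optional: a model-specific
        # op that was never imported is simply left out of the result.
        op = getattr(namespace, op_name, None)
        if op is not None:
            ops.append(op)
    return ops


# Hypothetical registry in which only one of the two ops has been registered.
registry = SimpleNamespace(gdn_custom_op_inplace="gdn-op")
found = piecewise_boundary_ops(
    registry, ["attn_custom_op_inplace", "gdn_custom_op_inplace"])
```

The `-> List[Any]` annotation is the only part the review comment asks for; a narrower element type would be preferable if the real function returns a known op type.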
tensorrt_llm/_torch/modules/fla/chunk.py (1)
135-180: 📐 Maintainability & Code Quality | 🔵 Trivial | ⚡ Quick win
Document the new `output` tensor contract.
The public `chunk_gated_delta_rule` docstring now omits `output`, but callers need to know it must match the returned `o` layout/shape/dtype. As per coding guidelines, public Tensor-like arguments should document expected dimensions and dtype options.
Suggested docstring addition
```diff
     cu_seqlens (torch.LongTensor):
         Cumulative sequence lengths of shape `[N+1]` used for variable-length training,
         consistent with the FlashAttention API.
+    output (Optional[torch.Tensor]):
+        Optional preallocated output buffer with the same shape, dtype, device, and contiguous
+        layout as the returned `o`.
```
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tensorrt_llm/_torch/modules/fla/chunk.py` around lines 135 - 180, The public docstring for chunk_gated_delta_rule is missing the new output argument contract. Update the docstring near the existing parameter docs in chunk.py to describe output as an optional preallocated tensor that must match the returned o layout, shape, and dtype (including head_first-dependent dimensions). Keep the description aligned with the other tensor arguments so callers can safely pass a correctly sized buffer.
Source: Coding guidelines
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@tensorrt_llm/_torch/modules/mamba/gdn_mixer.py`:
- Around line 82-90: The custom-op signature for gdn_custom_op_inplace is out of
sync with the inplace metadata, causing mutation tracking to point at the wrong
argument. Update the gdn_custom_op_inplace parameter order so output matches the
position expected by inplace_info() (or change the inplace_info() mapping to the
current output position), and keep mutates_args consistent with the actual
mutable tensor name.
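A small consistency check can catch this kind of drift early. The sketch below assumes the inplace metadata maps a positional index to an argument name; that format, the helper name `check_inplace_metadata`, and both example signatures are assumptions for illustration, not the actual `gdn_custom_op_inplace` definition.

```python
import inspect
from typing import Callable, Dict, Iterable


def check_inplace_metadata(fn: Callable, inplace_map: Dict[int, str],
                           mutates_args: Iterable[str]) -> None:
    """Raise ValueError if the index map, the signature, and mutates_args disagree."""
    params = list(inspect.signature(fn).parameters)
    mutated = set(mutates_args)
    for index, name in inplace_map.items():
        actual = params[index] if 0 <= index < len(params) else "<out of range>"
        if actual != name:
            raise ValueError(
                f"inplace map expects `{name}` at position {index}, "
                f"but the signature has `{actual}` there")
        if name not in mutated:
            raise ValueError(
                f"`{name}` is mapped as inplace but missing from mutates_args")


# Hypothetical signatures, for illustration only.
def gdn_op_consistent(mixed_qkv, a, b, output, layer_idx): ...
def gdn_op_reordered(mixed_qkv, output, a, b, layer_idx): ...


# Passes: position 3 of the first signature is `output`.
check_inplace_metadata(gdn_op_consistent, {3: "output"}, ["output"])
```

Running the same check against `gdn_op_reordered` raises, because position 3 there is `b`; that is the mismatch the comment describes.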
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: 81ac5010-965d-4e82-8f6d-76fd40bf0aaf
📒 Files selected for processing (11)
tensorrt_llm/_torch/compilation/piecewise_optimizer.py
tensorrt_llm/_torch/compilation/utils.py
tensorrt_llm/_torch/modules/fla/chunk.py
tensorrt_llm/_torch/modules/fla/chunk_o.py
tensorrt_llm/_torch/modules/fla/flashinfer_chunk.py
tensorrt_llm/_torch/modules/fla/fused_recurrent.py
tensorrt_llm/_torch/modules/mamba/gdn_mixer.py
tensorrt_llm/_torch/pyexecutor/model_engine.py
tests/integration/defs/accuracy/test_llm_api_pytorch.py
tests/integration/test_lists/qa/llm_function_core.txt
tests/integration/test_lists/test-db/l0_h100.yml
PR_Github #55499 [ run ] completed with state
3ca4c97 to 30793be
/bot run --disable-fail-fast
890832c to e94bdc8
/bot run --disable-fail-fast
PR_Github #55740 [ run ] triggered by Bot. Commit:
PR_Github #55740 [ run ] completed with state
e94bdc8 to 6f1d4c7
/bot run --disable-fail-fast
PR_Github #56284 [ run ] triggered by Bot. Commit:
PR_Github #56284 [ run ] completed with state
/bot run --disable-fail-fast
PR_Github #56331 [ run ] triggered by Bot. Commit:
PR_Github #56331 [ run ] completed with state
02addfd to 0892a72
/bot run --disable-fail-fast
/bot run --disable-fail-fast
PR_Github #56522 [ run ] triggered by Bot. Commit:
PR_Github #56522 [ run ] completed with state
/bot run --disable-fail-fast
PR_Github #56797 [ run ] triggered by Bot. Commit:
PR_Github #56797 [ run ] completed with state
/bot run
PR_Github #57079 [ run ] triggered by Bot. Commit:
PR_Github #57079 [ run ] completed with state
…graph
Keep eager and torch-compile GDN execution on the same forward_core path by passing the original mixed QKV and gating projection tensors into the custom op. The custom op only provides a compile boundary and an inplace output buffer.
Restore the standard decode path to fused_sigmoid_gating_delta_rule_update so FlashInfer GDN decode receives the original a/b tensors and preserves the eager accuracy behavior. Thread the optional output buffer through the FlashInfer and Triton decode paths to avoid an extra copy.
Tests:
- python -m py_compile tensorrt_llm/_torch/modules/mamba/gdn_mixer.py tensorrt_llm/_torch/modules/fla/fused_sigmoid_gating_recurrent.py
- git diff --check
- PDX sqsh build job 112288: COMPLETED
- PDX accuracy job 112301: TestQwen3_5_4B test_fp8 and test_fp8_piecewise_cuda_graph passed
- PDX accuracy job 112326: TestQwen3_5_35B_A3B test_bf16[tp2-TRTLLM] passed
Signed-off-by: Jin Li <59594262+liji-nv@users.noreply.github.com>
Signed-off-by: Jin Li <59594262+liji-nv@users.noreply.github.com>
0892a72 to 4492490
/bot run --disable-fail-fast
PR_Github #57461 [ run ] triggered by Bot. Commit:
Wrap the MiniMax M3 metadata- and cache-dependent attention core in an inplace custom op so torch.compile can split it out of piecewise CUDA graphs. Keep QKV/index projections, QK normalization, RoPE, and the output projection visible to the compiled graph.
Write dense and sparse attention results into the custom-op output buffer. Preserve FP32 sparse GQA accumulation until the final copy/cast, and expose the output buffer through MiniMaxM3SparseRuntimeBackend.forward.
Register attention boundaries and mutation metadata through optional TRT-LLM op lookup, matching the latest GDN registration pattern from PR NVIDIA#15594. This avoids depending on model-specific custom ops being imported when compilation utilities initialize.
Track piecewise runners owned by the compile backend and reset their CUDA graphs, captured addresses, outputs, and warmup state when phase-1 KV-cache estimation is released. Phase 2 then recaptures against the final allocations instead of replaying stale graph pointers.
Add an 8-GPU MiniMax-M3-MXFP8 torch.compile E2E variant covering TP8/EP8, attention DP, TRTLLM MoE, padding CUDA graphs, multi-stream piecewise capture, and phase-2 recapture.
Signed-off-by: Jin Li <59594262+liji-nv@users.noreply.github.com>
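The runner-reset idea in the commit message above can be sketched roughly as follows. This is a simplified model under stated assumptions, not the TRT-LLM implementation: `PiecewiseRunner`, `CompileBackend`, and their methods are hypothetical names, a plain `object()` stands in for a captured CUDA graph, and integer addresses stand in for device pointers.

```python
import weakref
from typing import Any, Dict


class PiecewiseRunner:
    """Hypothetical runner: warms up once, then captures and replays per batch size."""

    def __init__(self) -> None:
        self.graphs: Dict[int, Any] = {}
        self.captured_addresses: Dict[int, int] = {}
        self.warmed_up = False

    def run(self, num_tokens: int, buffer_address: int) -> str:
        if not self.warmed_up:
            self.warmed_up = True
            return "warmup"
        if num_tokens not in self.graphs:
            self.graphs[num_tokens] = object()  # placeholder for a captured graph
            self.captured_addresses[num_tokens] = buffer_address
            return "capture"
        if self.captured_addresses[num_tokens] != buffer_address:
            # A real graph replay would read or write the old allocation here.
            raise RuntimeError("replay would use a stale pointer")
        return "replay"

    def reset(self) -> None:
        self.graphs.clear()
        self.captured_addresses.clear()
        self.warmed_up = False


class CompileBackend:
    """Hypothetical backend that tracks the runners it creates."""

    def __init__(self) -> None:
        # Weak references let runners be garbage collected normally.
        self._runners: "weakref.WeakSet[PiecewiseRunner]" = weakref.WeakSet()

    def make_runner(self) -> PiecewiseRunner:
        runner = PiecewiseRunner()
        self._runners.add(runner)
        return runner

    def reset_piecewise_runners(self) -> None:
        for runner in list(self._runners):
            runner.reset()


backend = CompileBackend()
runner = backend.make_runner()
phase1 = [runner.run(8, 0x1000) for _ in range(3)]
backend.reset_piecewise_runners()  # phase-1 allocations released
phase2 = [runner.run(8, 0x2000) for _ in range(3)]
```

Without the reset, the first phase-2 call would hit the stale-pointer branch, because the graph captured in phase 1 still refers to the old address.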
PR_Github #57461 [ run ] completed with state
…graph
Add an inplace custom-op boundary for Qwen3.5 GDN so that torch.compile piecewise CUDA graph capture can keep token-wise projections outside the custom op while hiding FLA state updates from FX capture.
Update the FLA GDN helpers to write into caller-provided output tensors, register inplace custom-op metadata, and exclude the GDN custom op from piecewise CUDA graph capture. Add a Qwen3.5 FP8 piecewise CUDA graph smoke test.
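As a rough illustration of what excluding an op from piecewise capture means, the sketch below splits a flat op sequence at boundary ops: the runs between boundaries are candidates for CUDA graph capture, and each boundary op runs eagerly. It is a toy model on op names, not the actual FX graph splitting; `split_at_boundaries` and the projection and norm op names are hypothetical.

```python
from typing import List, Sequence, Set, Tuple


def split_at_boundaries(ops: Sequence[str],
                        boundary_ops: Set[str]) -> List[Tuple[str, List[str]]]:
    """Group ops into ("graph", [...]) pieces separated by ("eager", [op]) boundaries."""
    pieces: List[Tuple[str, List[str]]] = []
    current: List[str] = []
    for op in ops:
        if op in boundary_ops:
            if current:
                pieces.append(("graph", current))
                current = []
            pieces.append(("eager", [op]))
        else:
            current.append(op)
    if current:
        pieces.append(("graph", current))
    return pieces


# Hypothetical layer: projections stay capturable, the GDN core runs eagerly.
layer_ops = ["in_proj", "gdn_custom_op_inplace", "norm", "out_proj"]
pieces = split_at_boundaries(layer_ops, {"gdn_custom_op_inplace"})
```

Because the boundary op writes into a caller-provided output tensor, the captured piece after it can read that buffer at a fixed address.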
Bug: https://nvbugs/6299530
Tested:
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
If PR introduces API changes, an appropriate PR label is added - either `api-compatible` or `api-breaking`. For `api-breaking`, include `BREAKING` in the PR title.
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
To see a list of available CI bot commands, please comment `/bot help`.