[None][feat] optimize GDN prefill with indexed in-kernel state updates #12791
Conversation
Signed-off-by: nv-guomingz <137257613+nv-guomingz@users.noreply.github.com>
📝 Walkthrough

This pull request introduces indexed state update support to the chunked gated delta rule pipeline. New parameters (`initial_state_indices` and `inplace_indexed_state_update`) let the kernel select state slots by index and update them in place.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Caller
    participant API as Python API
    participant Kernel
    participant State as State Memory

    rect rgba(100, 150, 200, 0.5)
        Note over Caller,State: Previous Flow (Host-side State Update)
        Caller->>API: Call chunk_gated_delta_rule(initial_state)
        API->>Kernel: Execute with initial_state batch
        Kernel->>State: Compute output
        Kernel-->>API: Return final_state
        API->>State: index_assign updated state
        API-->>Caller: Return output
    end

    rect rgba(150, 200, 100, 0.5)
        Note over Caller,State: New Flow (Indexed In-kernel State Update)
        Caller->>API: Call chunk_gated_delta_rule(ssm_states, initial_state_indices, inplace_indexed_state_update=True)
        API->>Kernel: Execute with full state tensor + indices
        Kernel->>State: Indexed select & initialize state
        Kernel->>State: Update selected state slots in-kernel
        Kernel-->>API: Return output (no final_state)
        API-->>Caller: Return output
    end
```
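The new flow above amounts to a gather on entry and a scatter on exit over the same state slots. A toy sketch of that semantics (plain Python lists standing in for the real state tensor; `indexed_state_update` and its arguments are illustrative names, not the PR's actual kernel):

```python
def indexed_state_update(ssm_states, indices, deltas):
    # Toy model of the in-kernel flow: for each request, select its
    # state slot by index, apply the update, and write it back in
    # place. No separate final_state is returned.
    for req, slot in enumerate(indices):
        state = ssm_states[slot]                # indexed select (gather)
        ssm_states[slot] = state + deltas[req]  # in-place write (scatter)

states = [0.0, 10.0, 20.0, 30.0]
indexed_state_update(states, [2, 0], [1.0, 5.0])
print(states)  # [5.0, 10.0, 21.0, 30.0]
```

Because the writes happen against the full state tensor, the host-side `index_assign` step from the previous flow disappears.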
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes
🚥 Pre-merge checks: ✅ 1 | ❌ 2
❌ Failed checks (2 warnings)
✅ Passed checks (1 passed)
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
tensorrt_llm/_torch/modules/fla/chunk_delta_h.py (1)
1-4: ⚠️ Potential issue | 🟠 Major: Add the required NVIDIA header to this modified file.
This file is changed in the PR but still has no NVIDIA copyright block at the top.
As per coding guidelines, "All TensorRT-LLM Open Source Software code files should contain an NVIDIA copyright header with the year of latest meaningful modification."
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/modules/fla/chunk_delta_h.py` around lines 1 - 4: this file is missing the required NVIDIA copyright header; add the standard NVIDIA copyright/license block (including the year of latest meaningful modification) at the very top of chunk_delta_h.py before any existing comments or imports, replacing or preceding the current adapted-from comments so the file contains the mandated NVIDIA header text.

tensorrt_llm/_torch/modules/fla/chunk.py (1)
1-4: ⚠️ Potential issue | 🟠 Major: Add the required NVIDIA header to this modified file.
This file is changed in the PR but still has no NVIDIA copyright block at the top.
As per coding guidelines, "All TensorRT-LLM Open Source Software code files should contain an NVIDIA copyright header with the year of latest meaningful modification."
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/modules/fla/chunk.py` around lines 1 - 4, This file is missing the required NVIDIA copyright header; add the project's standard NVIDIA copyright header block at the very top of chunk.py (above the existing adaptation comments), include the year of the latest meaningful modification and the correct copyright owner phrasing used across the repo, and ensure the SPDX license identifier and any required short notice lines match the repo's standard header format.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tensorrt_llm/_torch/modules/fla/chunk_delta_h.py`:
- Around line 19-23: The predicate for indexed-mode is too permissive: change
the triton heuristic key "USE_INDEXED_STATE" to require both the initial state
and its indices (i.e., lambda args: args["h0"] is not None and args["h0_i"] is
not None) so it matches the Python guard in chunk_gated_delta_rule_fwd_h; update
every occurrence of "USE_INDEXED_STATE" in this module (including the other
heuristics blocks around chunk_gated_delta_rule_fwd_h, the epilogue/store logic
referring to h0/h0_i/ht, and any similar function specializations) to use the
combined check so indexed mode is only specialized when both initial_state and
initial_state_indices are present.
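A minimal sketch of the combined predicate the comment asks for. The argument names `h0` (initial state) and `h0_i` (initial state indices) are taken from the review text; the surrounding `@triton.heuristics` decorator is omitted so the predicate can be exercised without a GPU:

```python
def use_indexed_state(args):
    # Specialize indexed mode only when BOTH the initial state and its
    # slot indices are present, matching the Python-side guard in
    # chunk_gated_delta_rule_fwd_h.
    return args["h0"] is not None and args["h0_i"] is not None

print(use_indexed_state({"h0": object(), "h0_i": object()}))  # True
print(use_indexed_state({"h0": object(), "h0_i": None}))      # False
```

In the kernel source this lambda would replace the current, more permissive `USE_INDEXED_STATE` heuristic wherever it appears.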
In `@tensorrt_llm/_torch/modules/fla/chunk.py`:
- Around line 128-129: Before dispatch, validate initial_state_indices (when not
None) for both fixed-length and varlen paths: ensure it's a 1-D integer tensor,
contains no negative values, every entry is < initial_state.shape[0], and its
length equals the expected number of indices used by the call (the same count
currently enforced only in the varlen branch). If any check fails, raise a clear
ValueError. Apply this validation wherever initial_state_indices is consumed
(including the indexed path handling and the code paths referenced around
chunk_delta_h.py and the blocks corresponding to the earlier and later ranges
noted in the review).
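A framework-agnostic sketch of the validation the comment describes, using plain Python sequences in place of tensors (the helper name and signature are illustrative, not from the PR; the real code would check a 1-D integer `torch.Tensor` against `initial_state.shape[0]`):

```python
def validate_initial_state_indices(indices, num_state_slots, expected_count):
    # Mirror the review's checks: correct length, integer entries,
    # no negatives, and every entry within the state tensor's first dim.
    if indices is None:
        return  # indexed mode disabled; nothing to validate
    if len(indices) != expected_count:
        raise ValueError(
            f"expected {expected_count} state indices, got {len(indices)}")
    for i in indices:
        if not isinstance(i, int):
            raise ValueError("initial_state_indices must contain integers")
        if i < 0 or i >= num_state_slots:
            raise ValueError(
                f"index {i} out of range for a state tensor with "
                f"{num_state_slots} slots")
```

Running this once before dispatch covers both the fixed-length and varlen paths with a single, clear error message.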
---
Outside diff comments:
In `@tensorrt_llm/_torch/modules/fla/chunk_delta_h.py`:
- Around line 1-4: This file is missing the required NVIDIA copyright header;
add the standard NVIDIA copyright/license block (including the year of latest
meaningful modification) at the very top of chunk_delta_h.py before any existing
comments or imports, replacing or preceding the current adapted-from comments so
the file contains the mandated NVIDIA header text.
In `@tensorrt_llm/_torch/modules/fla/chunk.py`:
- Around line 1-4: This file is missing the required NVIDIA copyright header;
add the project's standard NVIDIA copyright header block at the very top of
chunk.py (above the existing adaptation comments), include the year of the
latest meaningful modification and the correct copyright owner phrasing used
across the repo, and ensure the SPDX license identifier and any required short
notice lines match the repo's standard header format.
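For reference, NVIDIA headers in this repo generally follow the SPDX style sketched below; the year and exact owner phrasing here are assumptions and should be copied from an existing file in the repo rather than from this sketch:

```python
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: Apache-2.0
```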
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 675fda29-1a49-449c-9641-d8cd936494f8
📒 Files selected for processing (3)
- tensorrt_llm/_torch/modules/fla/chunk.py
- tensorrt_llm/_torch/modules/fla/chunk_delta_h.py
- tensorrt_llm/_torch/modules/mamba/gdn_mixer.py
/bot run
PR_Github #42100 [ run ] triggered by Bot. Commit:
PR_Github #42100 [ run ] completed with state
/bot run --disable-fail-fast
PR_Github #42153 [ run ] triggered by Bot. Commit:
PR_Github #42153 [ run ] completed with state
/bot run
PR_Github #42209 [ run ] triggered by Bot. Commit:
PR_Github #42209 [ run ] completed with state
/bot run
PR_Github #42311 [ run ] triggered by Bot. Commit:
/bot run
PR_Github #42452 [ run ] triggered by Bot. Commit:
PR_Github #42452 [ run ] completed with state
/bot run --disable-fail-fast
PR_Github #42532 [ run ] triggered by Bot. Commit:
PR_Github #42532 [ run ] completed with state
/bot run --disable-fail-fast
PR_Github #42615 [ run ] triggered by Bot. Commit:
PR_Github #42615 [ run ] completed with state
Summary by CodeRabbit
Description
Test Coverage
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
To see a list of available CI bot commands, comment `/bot help`.