[None][feat] optimize GDN prefill with indexed in-kernel state updates #12791

Merged

nv-guomingz merged 1 commit into NVIDIA:main from nv-guomingz:user/guomingz/fla-prefill-indexed-state
Apr 13, 2026

Conversation


@nv-guomingz nv-guomingz commented Apr 7, 2026

Summary by CodeRabbit

  • Refactor
    • Enhanced internal state handling pipeline to support indexed initial states with optimized in-kernel updates.
    • Improved variable-length sequence processing with updated validation logic and more flexible state initialization patterns.

Description

Test Coverage

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • Update tava architecture diagram if there is a significant design change in PR.

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

To see a list of available CI bot commands, please comment /bot help.

Signed-off-by: nv-guomingz <137257613+nv-guomingz@users.noreply.github.com>

coderabbitai bot commented Apr 7, 2026

📝 Walkthrough

Walkthrough

This pull request introduces indexed state update support to the chunked gated delta rule pipeline. New parameters (initial_state_indices and inplace_indexed_state_update) are added and threaded through the Python API and Triton kernel. When enabled, state updates occur in-kernel via indexed selection rather than on the host, and callers are updated to leverage this mechanism.
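The walkthrough above can be illustrated with a minimal sketch. This is plain Python with lists standing in for GPU tensors, and the function names are hypothetical, not the actual TRT-LLM API; it only shows the difference between host-side writeback and the new indexed in-place update.

```python
# Illustrative sketch (plain Python, lists standing in for GPU tensors).
# Function names are hypothetical, not the actual TRT-LLM API.

def update_states_host_side(ssm_states, cache_indices, new_states):
    """Old flow: the kernel returns the final states for the batch, and the
    host scatters them back into the full state pool afterwards."""
    for batch_pos, slot in enumerate(cache_indices):
        ssm_states[slot] = new_states[batch_pos]
    return ssm_states

def update_states_indexed(ssm_states, cache_indices, compute_new_state):
    """New flow: the kernel receives the full state pool plus per-sequence
    indices, reads each initial state from its slot, and writes the updated
    state back to the same slot in place (no host-side scatter)."""
    for slot in cache_indices:
        ssm_states[slot] = compute_new_state(ssm_states[slot])
    return ssm_states

pool = [0.0, 1.0, 2.0, 3.0]            # full state pool, one scalar per slot
update_states_indexed(pool, [2, 0], lambda s: s + 10.0)
print(pool)  # slots 2 and 0 updated in place -> [10.0, 1.0, 12.0, 3.0]
```

The point of the indexed flow is that the gather, compute, and scatter all happen against the same pool, so no contiguous per-batch copy of the states ever exists on the host.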

Changes

• Core Pipeline API — tensorrt_llm/_torch/modules/fla/chunk.py
  Added initial_state_indices and inplace_indexed_state_update parameters to chunk_gated_delta_rule(), ChunkGatedDeltaRuleFunction.forward(), and chunk_gated_delta_rule_fwd(). Updated validation logic for variable-length inputs to accept either indexed state or a batch-level initial state matching the sequence count.
• Kernel Implementation — tensorrt_llm/_torch/modules/fla/chunk_delta_h.py
  Introduced a Triton kernel execution path, USE_INDEXED_STATE, for indexed state initialization. Updated the kernel to load per-sequence state indices, remap h0 to the selected slots, and perform in-kernel state writes when indexed mode is active. Extended the Python wrapper to validate indexed-mode requirements and construct the final state via index_select when needed.
• Consumer Integration — tensorrt_llm/_torch/modules/mamba/gdn_mixer.py
  Refactored forward_extend() to pass the full ssm_states tensor with initial_state_indices=cache_indices and inplace_indexed_state_update=True, enabling in-kernel state updates instead of host-side conversion and manual writeback.
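The wrapper-side index_select fallback mentioned for chunk_delta_h.py can be sketched in plain Python. The helper below mirrors the semantics of torch.index_select over a list; it is an illustration, not the actual wrapper code: when a caller still wants a contiguous final_state in indexed mode, it can be gathered from the updated pool after the kernel runs.

```python
# Sketch (plain Python; names hypothetical) of gathering per-sequence final
# states out of the full state pool, in the batch order given by indices.
# Mirrors the semantics of torch.index_select along dim 0.

def index_select(pool, indices):
    """Return the pool entries at the given indices, in index order."""
    return [pool[i] for i in indices]

pool = ["s0", "s1", "s2", "s3"]        # updated full state pool
final_state = index_select(pool, [3, 1])
print(final_state)  # ['s3', 's1']
```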

Sequence Diagram(s)

sequenceDiagram
    participant Caller
    participant Python API
    participant Kernel
    participant State Memory
    
    rect rgba(100, 150, 200, 0.5)
    Note over Caller,State Memory: Previous Flow (Host-side State Update)
    Caller->>Python API: Call chunk_gated_delta_rule(initial_state)
    Python API->>Kernel: Execute with initial_state batch
    Kernel->>State Memory: Compute output
    Kernel-->>Python API: Return final_state
    Python API->>State Memory: index_assign updated state
    Python API-->>Caller: Return output
    end
    
    rect rgba(150, 200, 100, 0.5)
    Note over Caller,State Memory: New Flow (Indexed In-kernel State Update)
    Caller->>Python API: Call chunk_gated_delta_rule(ssm_states, initial_state_indices, inplace_indexed_state_update=True)
    Python API->>Kernel: Execute with full state tensor + indices
    Kernel->>State Memory: Indexed select & initialize state
    Kernel->>State Memory: Update selected state slots in-kernel
    Kernel-->>Python API: Return output (no final_state)
    Python API-->>Caller: Return output
    end

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

🚥 Pre-merge checks | ✅ 1 | ❌ 2

❌ Failed checks (2 warnings)

• Docstring Coverage — ⚠️ Warning: docstring coverage is 16.67%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them.
• Description check — ⚠️ Warning: the PR description is empty, with only the template structure visible; no content explaining the changes, rationale, or test coverage is provided. Resolution: fill in the Description section with a clear explanation of what changed and why, add specific details in the Test Coverage section, and complete the PR Checklist items as needed.

✅ Passed checks (1 passed)

• Title check — ✅ Passed: the title clearly describes the main change (optimizing GDN prefill with indexed in-kernel state updates), which is directly reflected in the changeset.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


Comment @coderabbitai help to get the list of available commands and usage tips.


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
tensorrt_llm/_torch/modules/fla/chunk_delta_h.py (1)

1-4: ⚠️ Potential issue | 🟠 Major

Add the required NVIDIA header to this modified file.

This file is changed in the PR but still has no NVIDIA copyright block at the top.

As per coding guidelines, "All TensorRT-LLM Open Source Software code files should contain an NVIDIA copyright header with the year of latest meaningful modification."

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/modules/fla/chunk_delta_h.py` around lines 1 - 4, This
file is missing the required NVIDIA copyright header; add the standard NVIDIA
copyright/license block (including the year of latest meaningful modification)
at the very top of chunk_delta_h.py before any existing comments or imports,
replacing or preceding the current adapted-from comments so the file contains
the mandated NVIDIA header text.
tensorrt_llm/_torch/modules/fla/chunk.py (1)

1-4: ⚠️ Potential issue | 🟠 Major

Add the required NVIDIA header to this modified file.

This file is changed in the PR but still has no NVIDIA copyright block at the top.

As per coding guidelines, "All TensorRT-LLM Open Source Software code files should contain an NVIDIA copyright header with the year of latest meaningful modification."

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/modules/fla/chunk.py` around lines 1 - 4, This file is
missing the required NVIDIA copyright header; add the project's standard NVIDIA
copyright header block at the very top of chunk.py (above the existing
adaptation comments), include the year of the latest meaningful modification and
the correct copyright owner phrasing used across the repo, and ensure the SPDX
license identifier and any required short notice lines match the repo's standard
header format.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tensorrt_llm/_torch/modules/fla/chunk_delta_h.py`:
- Around line 19-23: The predicate for indexed-mode is too permissive: change
the triton heuristic key "USE_INDEXED_STATE" to require both the initial state
and its indices (i.e., lambda args: args["h0"] is not None and args["h0_i"] is
not None) so it matches the Python guard in chunk_gated_delta_rule_fwd_h; update
every occurrence of "USE_INDEXED_STATE" in this module (including the other
heuristics blocks around chunk_gated_delta_rule_fwd_h, the epilogue/store logic
referring to h0/h0_i/ht, and any similar function specializations) to use the
combined check so indexed mode is only specialized when both initial_state and
initial_state_indices are present.
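The predicate fix the comment asks for can be sketched as follows. This is plain Python with the surrounding triton.heuristics machinery omitted; the key and argument names (USE_INDEXED_STATE, h0, h0_i) follow the review comment, and the dict-of-lambdas shape is only an approximation of the real heuristics block.

```python
# Sketch of the suggested predicate change (plain Python; the real code wraps
# this in a triton.heuristics decorator): specialize indexed mode only when
# BOTH the initial state (h0) and its indices (h0_i) are present.

heuristics = {
    # Too permissive (old): true whenever h0 is present.
    # "USE_INDEXED_STATE": lambda args: args["h0"] is not None,
    # Combined check (suggested): require h0 AND h0_i, matching the Python
    # guard in chunk_gated_delta_rule_fwd_h.
    "USE_INDEXED_STATE": lambda args: args["h0"] is not None
    and args["h0_i"] is not None,
}

use_indexed = heuristics["USE_INDEXED_STATE"]
print(use_indexed({"h0": object(), "h0_i": object()}))  # True
print(use_indexed({"h0": object(), "h0_i": None}))      # False: no indices
```

With the combined check, a caller that passes an initial state but no indices falls back to the non-indexed specialization instead of compiling the indexed path with a null index pointer.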

In `@tensorrt_llm/_torch/modules/fla/chunk.py`:
- Around line 128-129: Before dispatch, validate initial_state_indices (when not
None) for both fixed-length and varlen paths: ensure it's a 1-D integer tensor,
contains no negative values, every entry is < initial_state.shape[0], and its
length equals the expected number of indices used by the call (the same count
currently enforced only in the varlen branch). If any check fails, raise a clear
ValueError. Apply this validation wherever initial_state_indices is consumed
(including the indexed path handling and the code paths referenced around
chunk_delta_h.py and the blocks corresponding to the earlier and later ranges
noted in the review).
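The validation the comment describes can be sketched in plain Python over lists; the real code would perform the equivalent checks on a 1-D integer torch tensor before dispatch, and the function name here is hypothetical.

```python
# Sketch (plain Python over lists; names hypothetical) of the pre-dispatch
# validation the review asks for: indices must be flat, integer-typed,
# non-negative, in range, and one-per-sequence.

def validate_initial_state_indices(indices, num_state_slots, expected_count):
    """Raise ValueError unless indices is a flat list of in-range,
    non-negative integers with exactly one entry per sequence."""
    if len(indices) != expected_count:
        raise ValueError(
            f"expected {expected_count} state indices, got {len(indices)}")
    for i in indices:
        if not isinstance(i, int):
            raise ValueError(f"state index {i!r} is not an integer")
        if i < 0 or i >= num_state_slots:
            raise ValueError(
                f"state index {i} out of range [0, {num_state_slots})")

validate_initial_state_indices([0, 3, 1], num_state_slots=4, expected_count=3)
try:
    validate_initial_state_indices([0, 4], num_state_slots=4, expected_count=2)
except ValueError as e:
    print(e)  # state index 4 out of range [0, 4)
```

Failing fast with a clear ValueError here is cheaper than letting an out-of-range index reach the kernel, where it would read or write an arbitrary slot of the state pool.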

---

Outside diff comments:
In `@tensorrt_llm/_torch/modules/fla/chunk_delta_h.py`:
- Around line 1-4: This file is missing the required NVIDIA copyright header;
add the standard NVIDIA copyright/license block (including the year of latest
meaningful modification) at the very top of chunk_delta_h.py before any existing
comments or imports, replacing or preceding the current adapted-from comments so
the file contains the mandated NVIDIA header text.

In `@tensorrt_llm/_torch/modules/fla/chunk.py`:
- Around line 1-4: This file is missing the required NVIDIA copyright header;
add the project's standard NVIDIA copyright header block at the very top of
chunk.py (above the existing adaptation comments), include the year of the
latest meaningful modification and the correct copyright owner phrasing used
across the repo, and ensure the SPDX license identifier and any required short
notice lines match the repo's standard header format.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 675fda29-1a49-449c-9641-d8cd936494f8

📥 Commits

Reviewing files that changed from the base of the PR and between a03aeb1 and 907e871.

📒 Files selected for processing (3)
  • tensorrt_llm/_torch/modules/fla/chunk.py
  • tensorrt_llm/_torch/modules/fla/chunk_delta_h.py
  • tensorrt_llm/_torch/modules/mamba/gdn_mixer.py

Comment thread tensorrt_llm/_torch/modules/fla/chunk_delta_h.py
Comment thread tensorrt_llm/_torch/modules/fla/chunk.py
@nv-guomingz

/bot run

@tensorrt-cicd

PR_Github #42100 [ run ] triggered by Bot. Commit: 907e871 Link to invocation

@tensorrt-cicd

PR_Github #42100 [ run ] completed with state FAILURE. Commit: 907e871
/LLM/main/L0_MergeRequest_PR pipeline #32938 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@nv-guomingz

/bot run --disable-fail-fast

@tensorrt-cicd

PR_Github #42153 [ run ] triggered by Bot. Commit: 907e871 Link to invocation

@tensorrt-cicd

PR_Github #42153 [ run ] completed with state SUCCESS. Commit: 907e871
/LLM/main/L0_MergeRequest_PR pipeline #32985 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@nv-guomingz

/bot run

@tensorrt-cicd

PR_Github #42209 [ run ] triggered by Bot. Commit: 907e871 Link to invocation

@tensorrt-cicd

PR_Github #42209 [ run ] completed with state FAILURE. Commit: 907e871
/LLM/main/L0_MergeRequest_PR pipeline #33027 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@nv-guomingz

/bot run

@tensorrt-cicd

PR_Github #42311 [ run ] triggered by Bot. Commit: 907e871 Link to invocation

@nv-guomingz

/bot run

@tensorrt-cicd

PR_Github #42452 [ run ] triggered by Bot. Commit: 907e871 Link to invocation

@tensorrt-cicd

PR_Github #42452 [ run ] completed with state FAILURE. Commit: 907e871
/LLM/main/L0_MergeRequest_PR pipeline #33216 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@nv-guomingz

/bot run --disable-fail-fast

@tensorrt-cicd

PR_Github #42532 [ run ] triggered by Bot. Commit: 907e871 Link to invocation

@tensorrt-cicd

PR_Github #42532 [ run ] completed with state SUCCESS. Commit: 907e871
/LLM/main/L0_MergeRequest_PR pipeline #33272 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@nv-guomingz

/bot run --disable-fail-fast

@tensorrt-cicd

PR_Github #42615 [ run ] triggered by Bot. Commit: 907e871 Link to invocation

@tensorrt-cicd

PR_Github #42615 [ run ] completed with state SUCCESS. Commit: 907e871
/LLM/main/L0_MergeRequest_PR pipeline #33335 completed with status: 'SUCCESS'

CI Report

Link to invocation

Comment thread tensorrt_llm/_torch/modules/mamba/gdn_mixer.py

@HuiGao-NV HuiGao-NV left a comment


LGTM

@nv-guomingz nv-guomingz enabled auto-merge (squash) April 13, 2026 09:43
@nv-guomingz nv-guomingz merged commit 23c870b into NVIDIA:main Apr 13, 2026
8 of 10 checks passed
