
[TRTLLM-11285][feat] Fuse indexer wk + weights_proj into single GEMM in TF32 for DS-V3.2#12055

Merged
longlee0622 merged 4 commits into NVIDIA:main from peihu-nv:feat/trtllm-10283-dsv32-tf32-indexer-fuse on Mar 18, 2026

Conversation

@peihu-nv (Collaborator) commented Mar 9, 2026

Prerequisite

Summary by CodeRabbit

  • Refactor

    • Optimized sparse attention computation through improved weight fusion mechanism for better performance
    • Simplified sparse attention initialization and forward computation flow
  • Tests

    • Updated sparse attention tests to align with refactored weight fusion approach

Description

  • Fuse the two DS-V3.2 indexer projections (wk and weights_proj) into a single FP32 cuBLAS GEMM (TF32 tensor cores on Ampere+), saving one kernel launch
  • Force wk to FP32 with no quantization to match weights_proj, enabling the fusion
  • FP32 (TF32) chosen over FP16 based on accuracy analysis in TRTLLM-10283
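
The fusion described above works because two linear projections that read the same input can be expressed as a single GEMM over concatenated weight matrices. A minimal PyTorch sketch of the idea (the dimensions are illustrative, not the DS-V3.2 config):

```python
import torch

torch.manual_seed(0)

# Illustrative sizes only; the real dimensions come from the DS-V3.2 config.
hidden, head_dim, n_heads, batch = 64, 32, 4, 8

x = torch.randn(batch, hidden)
wk = torch.randn(head_dim, hidden)           # indexer wk projection weight
weights_proj = torch.randn(n_heads, hidden)  # indexer weights_proj weight

# Unfused: two GEMMs, two kernel launches.
k_ref = x @ wk.T
w_ref = x @ weights_proj.T

# Fused: concatenate along the output dimension, run one GEMM, split the result.
fused = torch.cat([wk, weights_proj], dim=0)
k, w = (x @ fused.T).split([head_dim, n_heads], dim=-1)

assert torch.allclose(k, k_ref) and torch.allclose(w, w_ref)
```

The split outputs are mathematically identical to the unfused results, since each output element is the same dot product either way; only the kernel-launch count changes.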

Test Coverage

  • test_sparse_mla_forward.py
  • test_short_seq_mha.py
  • NVFP4 inference test (TP=4, DS-V3.2 checkpoint)

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • Update tava architecture diagram if there is a significant design change in PR.

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

To see a list of available CI bot commands, please comment /bot help.

peihu-nv added 2 commits March 6, 2026 13:10
…V3.2 NVFP4

Signed-off-by: peihengh <259410613+peihu-nv@users.noreply.github.com>
…cuBLAS GEMM for DS-V3.2

Fuse the two indexer projections (wk and weights_proj) into a single
FP32 cuBLAS GEMM for DS-V3.2 sparse attention. Uses TF32 tensor cores
on Ampere+, saving one kernel launch (~1% throughput gain per DLSim).

Signed-off-by: Pei-Hung Huang <peihengh@nvidia.com>
Signed-off-by: peihengh <259410613+peihu-nv@users.noreply.github.com>
@coderabbitai bot (Contributor) commented Mar 9, 2026

📝 Walkthrough

The changes refactor the Indexer weight fusion mechanism by introducing a post-load-weights method that fuses wk and weights_proj tensors. The explicit indexer_k parameter is removed from DSATrtllmAttention forward signature and indexer calls, with weight fusion now occurring during post-load initialization. Related model and test code is updated accordingly.
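
For context on the "TF32" in the title: TF32 is the mode in which an FP32 cuBLAS GEMM runs on tensor cores on Ampere and newer GPUs, trading mantissa precision for throughput. In plain PyTorch terms the opt-in looks like the snippet below (illustrative only; TensorRT-LLM configures its cuBLAS calls through its own backend, not necessarily this flag):

```python
import torch

# TF32 lets FP32 matmuls use tensor cores on Ampere+ GPUs, keeping FP32
# range but reducing mantissa precision to ~10 bits. The flag is global
# and only affects CUDA matmuls.
torch.backends.cuda.matmul.allow_tf32 = True
```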

Changes

| Cohort / File(s) | Summary |
| --- | --- |
| **Indexer Weight Fusion**<br>`tensorrt_llm/_torch/attention_backend/sparse/dsa.py` | Introduces a `post_load_weights()` method and a `_fused_wk_wp_weight` attribute for fusing the `wk` and `weights_proj` weights. Updates `Indexer.__init__` to keep `wk` and `weights_proj` in float32 with no quantization. Modifies the forward computation to use the fused cuBLAS GEMM, removing the explicit `indexer_k` parameter from the `DSATrtllmAttention` signature. |
| **Attention Module Updates**<br>`tensorrt_llm/_torch/models/modeling_deepseekv3.py`, `tensorrt_llm/_torch/modules/attention.py` | Removes the post-load-weights indexer fusion logic from `DeepseekV32Attention`. Updates the `DeepseekV32Attention.kv_a_proj_with_mqa` output-features calculation to exclude `indexer.head_dim`. Modifies `forward_impl_with_dsa` to split the `kv_a_proj_with_mqa` output into three components instead of four, removing `indexer_k` handling. |
| **Test Updates**<br>`tests/unittest/_torch/attention/sparse/test_short_seq_mha.py`, `tests/unittest/_torch/attention/sparse/test_sparse_mla_forward.py` | Adds `post_load_weights()` calls on indexer components after initialization. Removes the `indexer_k` argument from indexer method calls, relying on internal default/top-k computation instead. |
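
The pattern summarized above — fuse once after checkpoint loading, then run a single GEMM per forward pass — can be sketched as follows. The class, method, and attribute names mirror the summary, but this is an illustration, not the TensorRT-LLM code:

```python
import torch
from torch import nn


class IndexerSketch(nn.Module):
    """Illustrative sketch of the post-load weight-fusion pattern."""

    def __init__(self, hidden: int, head_dim: int, n_heads: int):
        super().__init__()
        # Both projections kept in float32 with no quantization, so their
        # weights can share one GEMM.
        self.wk = nn.Linear(hidden, head_dim, bias=False, dtype=torch.float32)
        self.weights_proj = nn.Linear(hidden, n_heads, bias=False,
                                      dtype=torch.float32)
        self._fused_wk_wp_weight = None

    def post_load_weights(self):
        # Called once after the checkpoint is loaded: concatenate the two
        # weight matrices so forward() needs only a single GEMM.
        self._fused_wk_wp_weight = torch.cat(
            [self.wk.weight, self.weights_proj.weight], dim=0)

    def forward(self, x):
        out = x @ self._fused_wk_wp_weight.T
        return out.split(
            [self.wk.out_features, self.weights_proj.out_features], dim=-1)


m = IndexerSketch(hidden=64, head_dim=32, n_heads=4)
m.post_load_weights()
x = torch.randn(8, 64)
k, w = m(x)
assert torch.allclose(k, m.wk(x)) and torch.allclose(w, m.weights_proj(x))
```

This also explains the test updates: because fusion happens in `post_load_weights()`, tests must call it explicitly after constructing the module and before running a forward pass.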

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~22 minutes

🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 42.86%, below the required threshold of 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (2 passed)

| Check name | Status | Explanation |
| --- | --- | --- |
| Title check | ✅ Passed | The PR title clearly and specifically describes the main change: fusing two indexer projections (wk + weights_proj) into a single GEMM in TF32 for DeepSeek-V3.2, matching the core objective. |
| Description check | ✅ Passed | The PR description clearly explains the objective (fuse indexer projections into an FP32 GEMM), provides test coverage examples, and references a related TRTLLM ticket for the accuracy analysis. |


…sv32-tf32-indexer-fuse

Signed-off-by: peihengh <259410613+peihu-nv@users.noreply.github.com>

# Conflicts:
#	tensorrt_llm/_torch/modules/attention.py
@peihu-nv force-pushed the feat/trtllm-10283-dsv32-tf32-indexer-fuse branch from bba4a8e to c7049ce on March 17, 2026
@peihu-nv (Collaborator, Author) commented:
/bot run

@tensorrt-cicd (Collaborator) commented:

PR_Github #39188 [ run ] triggered by Bot. Commit: c7049ce

@tensorrt-cicd (Collaborator) commented:
PR_Github #39188 [ run ] completed with state SUCCESS. Commit: c7049ce
/LLM/main/L0_MergeRequest_PR pipeline #30442 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again


@peihu-nv (Collaborator, Author) commented:

/bot run --disable-fail-fast

Signed-off-by: peihengh <259410613+peihu-nv@users.noreply.github.com>
@peihu-nv (Collaborator, Author) commented:

/bot run --disable-fail-fast

@tensorrt-cicd (Collaborator) commented:
PR_Github #39282 [ run ] triggered by Bot. Commit: 6f744c1

@tensorrt-cicd (Collaborator) commented:
PR_Github #39282 [ run ] completed with state SUCCESS. Commit: 6f744c1
/LLM/main/L0_MergeRequest_PR pipeline #30540 completed with status: 'SUCCESS'

CI Report


@longlee0622 longlee0622 enabled auto-merge (squash) March 18, 2026 03:07
@pengbowang-nv (Collaborator) left a comment:

LGTM. Leaving a minor question.

@longlee0622 longlee0622 merged commit 4b915d4 into NVIDIA:main Mar 18, 2026
5 checks passed
limin2021 pushed a commit to limin2021/TensorRT-LLM that referenced this pull request Mar 19, 2026
…in TF32 for DS-V3.2 (NVIDIA#12055)

Signed-off-by: peihengh <259410613+peihu-nv@users.noreply.github.com>
Signed-off-by: Pei-Hung Huang <peihengh@nvidia.com>