[TRTLLM-11285][feat] Fuse indexer wk + weights_proj into single GEMM in TF32 for DS-V3.2#12055
Conversation
…V3.2 NVFP4 Signed-off-by: peihengh <259410613+peihu-nv@users.noreply.github.com>
…cuBLAS GEMM for DS-V3.2 Fuse the two indexer projections (wk and weights_proj) into a single FP32 cuBLAS GEMM for DS-V3.2 sparse attention. Uses TF32 tensor cores on Ampere+, saving one kernel launch (~1% throughput gain per DLSim). Signed-off-by: Pei-Hung Huang <peihengh@nvidia.com> Signed-off-by: peihengh <259410613+peihu-nv@users.noreply.github.com>
📝 WalkthroughWalkthroughThe changes refactor the Indexer weight fusion mechanism by introducing a post-load-weights method that fuses wk and weights_proj tensors. The explicit indexer_k parameter is removed from DSATrtllmAttention forward signature and indexer calls, with weight fusion now occurring during post-load initialization. Related model and test code is updated accordingly. Changes
Estimated code review effort🎯 3 (Moderate) | ⏱️ ~22 minutes 🚥 Pre-merge checks | ✅ 2 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
📝 Coding Plan
Comment Tip CodeRabbit can use your project's `pylint` configuration to improve the quality of Python code reviews.Add a pylint configuration file to your project to customize how CodeRabbit runs |
…sv32-tf32-indexer-fuse Signed-off-by: peihengh <259410613+peihu-nv@users.noreply.github.com> # Conflicts: # tensorrt_llm/_torch/modules/attention.py
bba4a8e to
c7049ce
Compare
|
/bot run |
|
PR_Github #39188 [ run ] triggered by Bot. Commit: |
|
PR_Github #39188 [ run ] completed with state
|
|
/bot run --disable-fail-fast |
Signed-off-by: peihengh <259410613+peihu-nv@users.noreply.github.com>
|
/bot run --disable-fail-fast |
|
PR_Github #39282 [ run ] triggered by Bot. Commit: |
|
PR_Github #39282 [ run ] completed with state |
pengbowang-nv
left a comment
There was a problem hiding this comment.
LGTM. Leaving a minor question.
…in TF32 for DS-V3.2 (NVIDIA#12055) Signed-off-by: peihengh <259410613+peihu-nv@users.noreply.github.com> Signed-off-by: Pei-Hung Huang <peihengh@nvidia.com>
Prerequisite
Summary by CodeRabbit
Refactor
Tests
Description
Test Coverage
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
To see a list of available CI bot commands, please comment
/bot help.