[TRTLLM-11794][feat] Optimize ViT Attention kernel on Nemotron#12911

Merged
2ez4bz merged 3 commits into NVIDIA:main from yechank-nvidia:vit_kernel
Apr 17, 2026

Conversation

@yechank-nvidia
Collaborator

@yechank-nvidia yechank-nvidia commented Apr 10, 2026

Description

This PR introduces a new FlashInfer backend wrapper, `BatchPrefillWithRaggedKVCacheWrapper`, for the case where `kv_cache_manager` is `None` (e.g., vision attention).

For seq_len > 1024, FlashInfer's cuDNN backend shows promising results. Since image inputs average more than 1024 tokens per sequence, switching vision attention to FlashInfer's cuDNN backend is well motivated.
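The dispatch described above can be sketched as follows. This is a hypothetical, simplified rendering: the real backend selection lives in `tensorrt_llm/_torch/attention_backend/flashinfer.py` and plans a `BatchPrefillWithRaggedKVCacheWrapper` rather than returning a string; the function and return values here are illustrative only.

```python
def select_prefill_path(kv_cache_manager):
    """Illustrative sketch of the attention-path choice for a prefill call.

    When there is no KV cache manager (the vision-attention case), the PR
    routes through FlashInfer's ragged prefill wrapper planned with the
    cuDNN backend, which the microbenchmark below shows winning for
    seq_len > 1024. Names here are hypothetical, not the real API.
    """
    if kv_cache_manager is None:
        # Vision attention: no paged KV cache, use ragged prefill (cuDNN).
        return "flashinfer_ragged_cudnn"
    # Normal LLM path: paged KV-cache prefill/decode wrappers.
    return "paged_kv_cache"
```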

Below is data from a microbenchmark:

batch=1 seq_lens=[128, 256, 512, 1024, 2048, 4096, 8192] dim=1280 heads=16 dtype=bfloat16 warmup=10 iters=50

| seq_len | FLASHINFER_ms | TRTLLM_ms | FLASHINFER_µs/tok | TRTLLM_µs/tok | ratio_FLASHINFER/TRTLLM |
|--------:|--------------:|----------:|------------------:|--------------:|------------------------:|
|     128 |         0.106 |     0.091 |             0.829 |         0.713 |                   1.162 |
|     256 |         0.111 |     0.089 |             0.433 |         0.349 |                   1.240 |
|     512 |         0.110 |     0.089 |             0.214 |         0.173 |                   1.238 |
|    1024 |         0.110 |     0.090 |             0.107 |         0.088 |                   1.223 |
|    2048 |         0.110 |     0.157 |             0.054 |         0.077 |                   0.699 |
|    4096 |         0.156 |     0.480 |             0.038 |         0.117 |                   0.325 |
|    8192 |         0.420 |     1.743 |             0.051 |         0.213 |                   0.241 |
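For reference, the per-token and ratio columns follow directly from the millisecond timings. A quick sanity check of the seq_len=4096 row (computed values may differ from the table in the last digit due to rounding):

```python
def per_token_us(total_ms, seq_len):
    # µs/token = (ms * 1000 µs/ms) / number of tokens in the sequence
    return total_ms * 1000.0 / seq_len

# seq_len=4096 row of the table above
fi_us = per_token_us(0.156, 4096)   # FLASHINFER_µs/tok ≈ 0.038
trt_us = per_token_us(0.480, 4096)  # TRTLLM_µs/tok     ≈ 0.117
ratio = 0.156 / 0.480               # ratio_FLASHINFER/TRTLLM ≈ 0.325
```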

Summary by CodeRabbit

  • New Features

    • Added support for ragged KV cache prefill without explicit cache management, enabling more efficient context batch processing
    • Added configurable vision attention backend parameter for multimodal models
  • Tests

    • Added test coverage for ragged prefill mode validation

@coderabbitai
Contributor

coderabbitai Bot commented Apr 10, 2026

📝 Walkthrough

The changes extend the FlashInfer attention backend to support ragged prefill operations without requiring a KV cache manager. This involves making the KV cache manager optional in the interface, adding ragged prefill wrapper support, implementing a cuDNN-based planning path, and updating model configuration handling.

Changes

| Cohort / File(s) | Summary |
|---|---|
| **Interface Update**<br>`tensorrt_llm/_torch/attention_backend/interface.py` | Loosened `kv_cache_manager` and `draft_kv_cache_manager` types to include `None`, making KV cache management optional. |
| **FlashInfer Ragged Prefill Support**<br>`tensorrt_llm/_torch/attention_backend/flashinfer.py` | Extended `FlashInferWrappers` with an optional `ragged_prefill_wrapper`; made `decode_wrapper` and `prefill_wrapper` optional. Added a `_plan_ragged_cudnn_no_kv()` planning path, a `get_ragged_prefill_wrapper()` method, and conditional logic in `prepare()` and `forward_impl()` to handle no-KV-cache scenarios via ragged prefill. |
| **Model Configuration Handling**<br>`tensorrt_llm/_torch/models/modeling_nemotron_nano.py`, `tensorrt_llm/_torch/models/modeling_radio.py` | Updated the Nemotron model to use separate deep-copied configs for the vision and language submodules. Added `DEFAULT_VISION_ATTN_BACKEND` and a `vision_attn_backend` parameter to the RADIO vision model with dataclass config replacement. |
| **Test Coverage**<br>`tests/unittest/_torch/attention/test_flashinfer_attention.py` | Added `test_ragged_prefill_no_kv_cache_uses_cudnn_plan` to verify ragged prefill path activation and cuDNN backend selection when `kv_cache_manager` is `None`. |
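The optional-wrapper shape described in the walkthrough can be sketched as a dataclass with optional members. This is a simplified, hypothetical rendering: field and method names follow the walkthrough, but the real class holds FlashInfer wrapper objects rather than plain `object`s.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class FlashInferWrappers:
    # All wrappers are optional: with a KV cache manager the decode/prefill
    # wrappers are planned; without one (vision attention) only the ragged
    # prefill wrapper is.
    decode_wrapper: Optional[object] = None
    prefill_wrapper: Optional[object] = None
    ragged_prefill_wrapper: Optional[object] = None

    def get_ragged_prefill_wrapper(self):
        # Fail loudly if the no-KV-cache planning path was never taken.
        if self.ragged_prefill_wrapper is None:
            raise RuntimeError(
                "ragged prefill wrapper was not planned; expected the "
                "kv_cache_manager is None path")
        return self.ragged_prefill_wrapper
```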

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

🚥 Pre-merge checks | ✅ 1 | ❌ 2

❌ Failed checks (1 warning, 1 inconclusive)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 23.08%, below the required threshold of 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |
| Description check | ❓ Inconclusive | The description provides context with performance data but misses key template sections: Test Coverage details and PR Checklist confirmation are not addressed. | Add an explicit Test Coverage section listing test cases (e.g., `test_ragged_prefill_no_kv_cache_uses_cudnn_plan`) and confirm PR Checklist items, especially CODEOWNERS and documentation updates. |

✅ Passed checks (1 passed)

| Check name | Status | Explanation |
|---|---|---|
| Title check | ✅ Passed | The title accurately summarizes the main feature: optimizing the ViT attention kernel on Nemotron using FlashInfer. |


Contributor

@coderabbitai coderabbitai Bot left a comment


🧹 Nitpick comments (1)
tensorrt_llm/_torch/models/modeling_nemotron_nano.py (1)

1237-1243: Consider removing redundant deep copy of vision_model_config.

Line 1238 creates a deep copy of model_config for the vision encoder, but this copy is never modified before being passed to NanoV2VLVisionEncoder at line 1243. Additionally, NanoV2VLVisionEncoder.__init__ creates its own deep copy at lines 296-298 when instantiating RADIOVisionModel.

Since vision_model_config is not modified and the encoder creates its own copy, you could pass model_config directly to NanoV2VLVisionEncoder to avoid the redundant deep copy operation.

♻️ Suggested optimization
 llm_model_config = copy.deepcopy(model_config)
-vision_model_config = copy.deepcopy(model_config)
 if hasattr(self, "llm"):
     return

 if not _is_disagg():
-    self.vision_encoder = NanoV2VLVisionEncoder(vision_model_config).eval()
+    self.vision_encoder = NanoV2VLVisionEncoder(model_config).eval()
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/models/modeling_nemotron_nano.py` around lines 1237 -
1243, Remove the unnecessary deep copy for vision_model_config: instead of
creating vision_model_config = copy.deepcopy(model_config) and passing it to
NanoV2VLVisionEncoder, pass model_config directly to NanoV2VLVisionEncoder in
the block where self.vision_encoder is created (the code that checks
hasattr(self, "llm") and if not _is_disagg()). Leave llm_model_config as-is;
note that NanoV2VLVisionEncoder.__init__ already deep-copies when instantiating
RADIOVisionModel (see its __init__), so the external deep copy is redundant.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 564bc019-e5e8-4c0e-bfed-aa106130cc2b

📥 Commits

Reviewing files that changed from the base of the PR and between 7876dc7 and 6ef5658.

📒 Files selected for processing (5)
  • tensorrt_llm/_torch/attention_backend/flashinfer.py
  • tensorrt_llm/_torch/attention_backend/interface.py
  • tensorrt_llm/_torch/models/modeling_nemotron_nano.py
  • tensorrt_llm/_torch/models/modeling_radio.py
  • tests/unittest/_torch/attention/test_flashinfer_attention.py

@yechank-nvidia
Collaborator Author

/bot run

Collaborator

@2ez4bz 2ez4bz left a comment


Thanks for this!

@tensorrt-cicd
Collaborator

PR_Github #42676 [ run ] triggered by Bot. Commit: 6ef5658 Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #42676 [ run ] completed with state SUCCESS. Commit: 6ef5658
/LLM/main/L0_MergeRequest_PR pipeline #33380 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@yechank-nvidia
Collaborator Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #42997 [ run ] triggered by Bot. Commit: 9a3dda0 Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #42997 [ run ] completed with state SUCCESS. Commit: 9a3dda0
/LLM/main/L0_MergeRequest_PR pipeline #33649 completed with status: 'SUCCESS'

CI Report

Link to invocation

Collaborator

@2ez4bz 2ez4bz left a comment


Approving to unblock.

@2ez4bz 2ez4bz enabled auto-merge (squash) April 13, 2026 16:30
@yechank-nvidia yechank-nvidia added the Multimodal Label for issues & PRs regarding Multimodal related objects label Apr 14, 2026
@yechank-nvidia
Collaborator Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #43225 [ run ] triggered by Bot. Commit: 0e80eb3 Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #43225 [ run ] completed with state DISABLED
Freeze main and open the PR merge only after CI is back to healthy https://nvidia.slack.com/archives/C059LSY62BT/p1776141760843319?thread_ts=1775985925.442509&cid=C059LSY62BT

Link to invocation

@yechank-nvidia
Collaborator Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #43348 [ run ] triggered by Bot. Commit: 97748f9 Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #43348 [ run ] completed with state FAILURE. Commit: 97748f9
/LLM/main/L0_MergeRequest_PR pipeline #33887 completed with status: 'FAILURE'

CI Report


Link to invocation

@yechank-nvidia
Collaborator Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #43518 [ run ] triggered by Bot. Commit: 23a5be1 Link to invocation

Collaborator

@brb-nv brb-nv left a comment


Just an informational question. LGTM.

@tensorrt-cicd
Collaborator

PR_Github #43518 [ run ] completed with state SUCCESS. Commit: 23a5be1
/LLM/main/L0_MergeRequest_PR pipeline #34032 completed with status: 'FAILURE'

CI Report


Link to invocation

…Nemotron ViT Attention to Flashinfer

Signed-off-by: yechank <161688079+yechank-nvidia@users.noreply.github.com>
Signed-off-by: yechank <161688079+yechank-nvidia@users.noreply.github.com>
Signed-off-by: yechank <161688079+yechank-nvidia@users.noreply.github.com>
@2ez4bz
Collaborator

2ez4bz commented Apr 15, 2026

/bot run --disable-fail-fast

1 similar comment
@yechank-nvidia
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #43614 [ run ] triggered by Bot. Commit: 9ef200d Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #43614 [ run ] completed with state SUCCESS. Commit: 9ef200d
/LLM/main/L0_MergeRequest_PR pipeline #34106 completed with status: 'FAILURE'

CI Report


Link to invocation

@yechank-nvidia
Collaborator Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #43722 [ run ] triggered by Bot. Commit: 9ef200d Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #43722 [ run ] completed with state FAILURE. Commit: 9ef200d
/LLM/main/L0_MergeRequest_PR pipeline #34205 completed with status: 'FAILURE'

CI Report


Link to invocation

@yechank-nvidia
Collaborator Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #43758 [ run ] triggered by Bot. Commit: 9ef200d Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #43758 [ run ] completed with state SUCCESS. Commit: 9ef200d
/LLM/main/L0_MergeRequest_PR pipeline #34240 completed with status: 'SUCCESS'

CI Report

Link to invocation

@2ez4bz
Collaborator

2ez4bz commented Apr 16, 2026

/bot run

@tensorrt-cicd
Collaborator

PR_Github #43794 [ run ] triggered by Bot. Commit: 9ef200d Link to invocation

@2ez4bz
Collaborator

2ez4bz commented Apr 16, 2026

/bot run

@tensorrt-cicd
Collaborator

PR_Github #43824 [ run ] triggered by Bot. Commit: 9ef200d Link to invocation

@2ez4bz
Collaborator

2ez4bz commented Apr 16, 2026

/bot run

@tensorrt-cicd
Collaborator

PR_Github #43844 [ run ] triggered by Bot. Commit: 9ef200d Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #43824 [ run ] completed with state ABORTED. Commit: 9ef200d

Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #43844 [ run ] completed with state FAILURE. Commit: 9ef200d
/LLM/main/L0_MergeRequest_PR pipeline #34305 completed with status: 'FAILURE'

CI Report


Link to invocation

@yechank-nvidia
Collaborator Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #44004 [ run ] triggered by Bot. Commit: 9ef200d Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #44004 [ run ] completed with state SUCCESS. Commit: 9ef200d
/LLM/main/L0_MergeRequest_PR pipeline #34443 completed with status: 'SUCCESS'

CI Report

Link to invocation

@2ez4bz 2ez4bz merged commit 2a0bcb1 into NVIDIA:main Apr 17, 2026
5 checks passed

Labels

Multimodal Label for issues & PRs regarding Multimodal related objects


4 participants