[None][refactor] VisualGen attention backend refactor #12663
chang-l merged 3 commits into NVIDIA:main from
Conversation
📝 Walkthrough
This PR introduces a standardized attention backend interface.
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (3)
tensorrt_llm/_torch/visual_gen/attention_backend/trtllm.py (2)
1-1: ⚠️ Potential issue | 🟡 Minor
Update copyright year to 2026.
Per coding guidelines, the copyright year should reflect the latest meaningful modification.
Proposed fix
```diff
-# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
```

As per coding guidelines: "All TensorRT-LLM Open Source Software code should contain an NVIDIA copyright header that includes the year of its latest meaningful modification."
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/visual_gen/attention_backend/trtllm.py` at line 1, Update the SPDX copyright header in the file by changing the year from 2025 to 2026: locate the line beginning with "# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved." (the SPDX header in trtllm.py) and revise the year to 2026 so it reads 2026.
242-245: ⚠️ Potential issue | 🟡 Minor
Handle mismatched k/v None states.
If only one of k or v is None (but not both), the code falls through to _concat_qkv, which will fail when accessing .view() on None. Consider adding validation or explicitly handling this edge case.
Proposed fix
```diff
+ if (k is None) != (v is None):
+     raise ValueError("k and v must both be None (fused QKV) or both provided")
+
  if k is None and v is None:
      qkv = q.reshape(batch_size * seq_len, -1)
  else:
      qkv = self._concat_qkv(q, k, v, batch_size, seq_len, kv_seq_len)
```

🤖 Prompt for AI Agents
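Stripped of the surrounding class, the suggested guard can be sketched as a runnable stand-alone function. This is a hypothetical simplification: `build_qkv` and the plain lists below stand in for the real `_concat_qkv` method and torch tensors.

```python
def build_qkv(q, k=None, v=None):
    # Fail fast on the mismatched case instead of crashing later inside
    # the concat path, where .view() would be called on None.
    if (k is None) != (v is None):
        raise ValueError("k and v must both be None (fused QKV) or both provided")
    if k is None and v is None:
        return ("fused", q)         # q already holds the fused QKV tensor
    return ("separate", q + k + v)  # stand-in for the real concatenation
```

With the guard in place, passing only one of k/v raises immediately with a clear message rather than an opaque `AttributeError` from `NoneType`.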
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/visual_gen/attention_backend/trtllm.py` around lines 242 - 245, The code currently only handles both k and v being None or both present; add a guard in the block using q, k, v (around the qkv assignment) to detect mismatched None states (when (k is None) != (v is None)) and handle it: either raise a clear ValueError mentioning _concat_qkv and the mismatched k/v state, or normalize by setting the missing tensor to the present one before calling self._concat_qkv; ensure the check references k, v, self._concat_qkv, and q.reshape so the intent and location are clear.
tensorrt_llm/_torch/visual_gen/attention_backend/vanilla.py (1)
1-1: ⚠️ Potential issue | 🟡 Minor
Update copyright year to 2026.
Per coding guidelines, the NVIDIA copyright header should include the year of the latest meaningful modification. Since this file is being modified in 2026, the copyright year should be updated.
Proposed fix
```diff
-# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
```

As per coding guidelines: "All TensorRT-LLM Open Source Software code should contain an NVIDIA copyright header that includes the year of its latest meaningful modification."
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/visual_gen/attention_backend/vanilla.py` at line 1, Update the copyright year in the SPDX header: locate the SPDX comment line that currently reads "# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved." and change the year from 2025 to 2026 so the header reflects the latest modification year.
🧹 Nitpick comments (3)
tensorrt_llm/_torch/visual_gen/modules/attention.py (1)
254-254: Call the backend module via __call__, not forward() directly.
Directly invoking self.attn.forward(...) bypasses nn.Module.__call__, so forward hooks and wrappers on the backend never run. Using self.attn(...) preserves the normal PyTorch dispatch path and still reaches the same implementation.
Suggested fix
```diff
- out = self.attn.forward(q=q, k=k, v=v, **kwargs)
+ out = self.attn(q=q, k=k, v=v, **kwargs)
```

🤖 Prompt for AI Agents
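The hook-skipping behavior this comment describes can be reproduced without PyTorch. The `MiniModule` class below is a simplified stand-in for `nn.Module.__call__` dispatch, not the real implementation: `__call__` runs registered forward hooks, while calling `.forward()` directly silently skips them.

```python
class MiniModule:
    """Simplified stand-in for nn.Module dispatch (illustrative only)."""

    def __init__(self):
        self._forward_hooks = []

    def register_forward_hook(self, fn):
        self._forward_hooks.append(fn)

    def __call__(self, *args, **kwargs):
        out = self.forward(*args, **kwargs)
        for hook in self._forward_hooks:
            hook(self, args, out)  # hooks only run on the __call__ path
        return out


class Backend(MiniModule):
    def forward(self, q):
        return q * 2


hook_outputs = []
backend = Backend()
backend.register_forward_hook(lambda mod, args, out: hook_outputs.append(out))

via_call = backend(3)             # dispatches through __call__, hook fires
via_forward = backend.forward(3)  # bypasses __call__, hook does not fire
```

Both calls return the same value, but only the first one reaches the hook, which is exactly why `self.attn(...)` is preferred over `self.attn.forward(...)`.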
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/visual_gen/modules/attention.py` at line 254, Replace the direct backend forward invocation with a normal module call so PyTorch hooks/wrappers run: change the call site that currently does out = self.attn.forward(q=q, k=k, v=v, **kwargs) to use the module __call__ (e.g., out = self.attn(q=q, k=k, v=v, **kwargs)) in the method where self.attn is used so that nn.Module.__call__ dispatch executes; ensure any keyword/positional arguments remain identical to preserve behavior.
tests/unittest/_torch/visual_gen/multi_gpu/test_ulysses_attention.py (1)
455-466: Keep the new backend test double typed like the interface it is validating.
_FusedVanillaAttention.forward() drops the tensor and return annotations right where this class starts modeling the shared AttentionBackend contract. Adding them here makes interface drift visible to type checkers instead of silently collapsing to Any.
Suggested fix
```diff
- def forward(self, q, k=None, v=None, **kwargs):
+ def forward(
+     self,
+     q: torch.Tensor,
+     k: torch.Tensor | None = None,
+     v: torch.Tensor | None = None,
+     **kwargs,
+ ) -> torch.Tensor:
```

As per coding guidelines, "Static type checking is opt-in by submodule PICs in Python. Always annotate functions with return types, and make the return type None if the function does not return anything."
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/unittest/_torch/visual_gen/multi_gpu/test_ulysses_attention.py` around lines 455 - 466, The new test backend _FusedVanillaAttention must preserve the same typed signature as the AttentionBackend interface: add explicit parameter and return type annotations on _FusedVanillaAttention.forward (e.g. q: torch.Tensor, k: Optional[torch.Tensor]=None, v: Optional[torch.Tensor]=None, **kwargs) -> torch.Tensor (or -> None if the interface forward returns None) so static type checkers catch drift; ensure you import Optional and torch.Tensor and match the exact return type declared on AttentionBackend.forward.
tensorrt_llm/_torch/visual_gen/attention_backend/parallel.py (1)
166-168: support_fused_qkv() is broader than this wrapper's public API.
UlyssesAttention.forward() still requires separate q, k, and v and only builds fused QKV internally. Returning True here makes the wrapper look interchangeable with backends that actually accept fused inputs, which is not true today. Either document this as an internal optimization hint or make the wrapper accept fused inputs as well.
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/visual_gen/attention_backend/parallel.py` around lines 166 - 168, The method support_fused_qkv() currently returns True but the wrapper’s public API (UlyssesAttention.forward) still requires separate q, k, v; change support_fused_qkv() to return False so the wrapper is not advertised as accepting fused inputs, or alternatively implement fused-input handling in UlyssesAttention.forward (accept a single fused qkv tensor, split it into q/k/v before existing logic) and update any input checks; prefer the first option (set support_fused_qkv() -> False) unless you also add fused-input parsing in UlyssesAttention.forward.
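To see why the flag matters, here is an illustrative dispatch sketch using stand-in classes (not the real TensorRT-LLM backends): callers branch on `support_fused_qkv()`, so a wrapper whose public `forward()` still needs separate q, k, v should advertise `False`.

```python
class FusedBackend:
    """Accepts a single pre-fused QKV tensor via its public forward()."""

    def support_fused_qkv(self):
        return True

    def forward(self, q, k=None, v=None):
        return ("fused", q) if k is None and v is None else ("separate", (q, k, v))


class UlyssesWrapper:
    """Builds fused QKV internally, but its public API still needs q, k, v."""

    def support_fused_qkv(self):
        # Advertising True here would let dispatch() call forward(q) alone,
        # which this wrapper's signature cannot handle.
        return False

    def forward(self, q, k, v):
        return ("separate", (q, k, v))


def dispatch(backend, q, k, v):
    # Callers pre-fuse inputs only when the backend says it can take them.
    if backend.support_fused_qkv():
        return backend.forward(q)
    return backend.forward(q, k, v)
```

If `UlyssesWrapper.support_fused_qkv()` returned `True`, `dispatch()` would call `forward(q)` with two required arguments missing and raise a `TypeError`, which is the interchangeability problem the comment points at.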
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tensorrt_llm/_torch/visual_gen/attention_backend/interface.py`:
- Around line 48-55: Update the abstract base class method signature for
AttentionBackend.forward to accept Optional[torch.Tensor] for k and v (matching
implementations like TRTLLMAttention and call sites such as
UlyssesAttention._forward_fused which pass k=None, v=None), and split the
one-line ellipsis stubs into their own lines (replace single-line "->
torch.Tensor: ..." with a proper signature line followed by a separate line
containing "..." for both forward and the preferred_layout property) so the ABC
matches implementations and fixes the E704 style violation.
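A sketch of the ABC shape this prompt describes, under stated assumptions: `Any` stands in for `torch.Tensor`, `preferred_layout` is assumed to return a string, and the layout name is a placeholder.

```python
from abc import ABC, abstractmethod
from typing import Any, Optional


class AttentionBackend(ABC):
    @abstractmethod
    def forward(
        self,
        q: Any,
        k: Optional[Any] = None,  # Optional, matching call sites that pass k=None
        v: Optional[Any] = None,
        **kwargs: Any,
    ) -> Any:
        # Stub body on its own line avoids the E704 one-line "def f(): ..." style.
        ...

    @property
    @abstractmethod
    def preferred_layout(self) -> str:
        ...


class VanillaBackend(AttentionBackend):
    """Minimal concrete subclass showing the signatures line up."""

    def forward(self, q: Any, k: Optional[Any] = None,
                v: Optional[Any] = None, **kwargs: Any) -> Any:
        return q

    @property
    def preferred_layout(self) -> str:
        return "bshd"  # placeholder layout name, not from the source
```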
In `@tensorrt_llm/_torch/visual_gen/attention_backend/trtllm.py`:
- Line 148: The MRO breaks because AttentionBackend.__init__ in the main backend
doesn't call super().__init__, so nn.Module.__init__ is never run when creating
TrtllmAttention (which inherits BaseTrtllmAttention, AttentionBackend); fix by
updating the main backend AttentionBackend.__init__ to call
super().__init__(**kwargs) (preserving existing logic) so the init chain reaches
visual_gen AttentionBackend -> nn.Module, or alternatively ensure
TrtllmAttention.__init__ explicitly invokes nn.Module.__init__(self) before
other inits; reference AttentionBackend.__init__, TrtllmAttention, and
BaseTrtllmAttention when making the change.
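The fix described above relies on cooperative `super().__init__` calls along the MRO. A minimal pure-Python reproduction of the *fixed* chain, with stand-in classes (`Module` plays the role of `nn.Module`; these are not the real TensorRT-LLM classes):

```python
class Module:
    """Stand-in for nn.Module; its __init__ must run for the instance to work."""

    def __init__(self):
        self.initialized = True


class AttentionBackend(Module):
    def __init__(self, **kwargs):
        # The crucial line: without this super().__init__() call,
        # Module.__init__ is skipped and the init chain is broken.
        super().__init__()
        self.backend_kwargs = kwargs


class BaseTrtllmAttention:
    def __init__(self):
        super().__init__()  # cooperative: continues to the next class in the MRO


class TrtllmAttention(BaseTrtllmAttention, AttentionBackend):
    def __init__(self):
        # MRO: TrtllmAttention -> BaseTrtllmAttention -> AttentionBackend -> Module
        super().__init__()
```

Note that `BaseTrtllmAttention` does not inherit from `Module`, yet its `super().__init__()` still reaches `AttentionBackend` when called on a `TrtllmAttention` instance, because `super()` follows the instance's MRO rather than the static base class.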
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 27942b7e-f67f-4d22-ba17-6501dcd38b6d
📒 Files selected for processing (11)
- tensorrt_llm/_torch/visual_gen/attention_backend/__init__.py
- tensorrt_llm/_torch/visual_gen/attention_backend/flash_attn4.py
- tensorrt_llm/_torch/visual_gen/attention_backend/interface.py
- tensorrt_llm/_torch/visual_gen/attention_backend/parallel.py
- tensorrt_llm/_torch/visual_gen/attention_backend/trtllm.py
- tensorrt_llm/_torch/visual_gen/attention_backend/utils.py
- tensorrt_llm/_torch/visual_gen/attention_backend/vanilla.py
- tensorrt_llm/_torch/visual_gen/models/flux/attention.py
- tensorrt_llm/_torch/visual_gen/models/wan/transformer_wan.py
- tensorrt_llm/_torch/visual_gen/modules/attention.py
- tests/unittest/_torch/visual_gen/multi_gpu/test_ulysses_attention.py
|
/bot run --disable-fail-fast |
|
PR_Github #41238 [ run ] triggered by Bot. Commit: |
|
/bot kill |
|
PR_Github #41239 [ kill ] triggered by Bot. Commit: |
|
PR_Github #41238 [ run ] completed with state |
|
PR_Github #41239 [ kill ] completed with state |
|
/bot run --disable-fail-fast |
|
PR_Github #41242 [ run ] triggered by Bot. Commit: |
|
PR_Github #41242 [ run ] completed with state
|
|
/bot run --disable-fail-fast |
|
PR_Github #41290 [ run ] triggered by Bot. Commit: |
|
PR_Github #41290 [ run ] completed with state
|
Signed-off-by: Shreyas Misra <shreyasm@nvidia.com>
Signed-off-by: Shreyas Misra <shreyasm@nvidia.com>
Force-pushed 4b48264 to 4ec41ad (Compare)
|
/bot run --disable-fail-fast |
|
PR_Github #41443 [ run ] triggered by Bot. Commit: |
chang-l left a comment
In general, LGTM with minor comments
Signed-off-by: Shreyas Misra <shreyasm@nvidia.com>
|
/bot kill |
|
PR_Github #41467 [ kill ] triggered by Bot. Commit: |
|
PR_Github #41467 [ kill ] completed with state |
|
/bot run --disable-fail-fast |
|
PR_Github #41475 [ run ] triggered by Bot. Commit: |
|
PR_Github #41475 [ run ] completed with state
|
|
/bot run --disable-fail-fast |
|
PR_Github #41512 [ run ] triggered by Bot. Commit: |
|
PR_Github #41512 [ run ] completed with state
|
|
/bot help |
GitHub Bot Help
Provide a user-friendly way for developers to interact with a Jenkins server. See details below for each supported subcommand.
run
Launch build/test pipelines. All previously running jobs will be killed.
kill
Kill all running builds associated with pull request.
skip
Skip testing for latest commit on pull request.
reuse-pipeline
Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.
|
/bot skip --comment "Failing autodeploy DS test is unrelated to this PRs changes" |
|
PR_Github #41515 [ skip ] triggered by Bot. Commit: |
|
PR_Github #41515 [ skip ] completed with state |
Signed-off-by: Shreyas Misra <shreyasm@nvidia.com>
Signed-off-by: Shreyas Misra <shreyasm@nvidia.com>
Summary by CodeRabbit
Release Notes
AttentionBackend interface to unify behavior across attention implementations.
Description
Test Coverage
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
To see a list of available CI bot commands, please comment
/bot help.