[None][feat] Support MLA generation in TrtllmGen attention backend #12606

yihwang-nv merged 3 commits into NVIDIA:main from
Conversation
/bot run --disable-fail-fast

PR_Github #40838 [ run ] triggered by Bot. Commit:
📝 Walkthrough

Changes enable MLA (Multi-head Latent Attention) support in TensorRT-LLM's TRTLLM-Gen attention backend by forcing eligibility checks to pass, adding MLA-specific parameters and generation logic, modifying support validation, and conditionally routing execution between MLA and standard attention paths.
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~50 minutes

🚥 Pre-merge checks: ✅ 1 | ❌ 2
❌ Failed checks (2 warnings)
✅ Passed checks (1 passed)
Actionable comments posted: 2
🧹 Nitpick comments (2)
tensorrt_llm/_torch/attention_backend/trtllm_gen.py (2)
1366-1367: Add the missing `-> None` annotation.

This helper is side-effect-only, so the signature should declare that explicitly.
Suggested change:

```diff
-    def run_mla_generation(self, params: EnqueueGenerationParams):
+    def run_mla_generation(self, params: EnqueueGenerationParams) -> None:
```

As per coding guidelines, "Static type checking is opt-in by submodule PICs in Python. Always annotate functions with return types, and make the return type `None` if the function does not return anything."
Noneif the function does not return anything."🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/attention_backend/trtllm_gen.py` around lines 1366 - 1367, The method run_mla_generation currently lacks a return type; update its signature to explicitly declare it returns None (i.e., add -> None) to comply with the project's type-annotation guideline, keeping the function body unchanged and ensuring any callers or overrides remain compatible with the new annotated signature.
1906-1909: Either remove or implement the MLA context branch.

This branch calls `backend.run_mla_context()`, but `FlashInferTrtllmGenAttention` in this file only defines `run_context()`, `run_generation()`, and `run_mla_generation()`. Today `is_supported()` rejects `is_mla_enable and not is_fused_qkv`, so the branch is dead; if that guard is relaxed later, this becomes an immediate `AttributeError`.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/attention_backend/trtllm_gen.py` around lines 1906 - 1909, The conditional branch calling backend.run_mla_context(...) is dead now and will cause AttributeError if the MLA guard is relaxed; either remove the branch or implement the missing method. Fix by replacing backend.run_mla_context(ctx_params) with the existing backend.run_mla_generation(ctx_params) (or add a run_mla_context(self, ctx_params) implementation on FlashInferTrtllmGenAttention that delegates to run_mla_generation/run_context as appropriate), and ensure callers use the is_mla_enable and is_fused_qkv guards consistently.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tensorrt_llm/_torch/attention_backend/trtllm_gen.py`:
- Around line 1366-1430: run_mla_generation ignores windowed-attention params
and thus runs full attention; add a fast-fail guard at the start of
run_mla_generation that checks
EnqueueGenerationParams.cyclic_attention_window_size and
EnqueueGenerationParams.attention_chunk_size (or computed window_left) and
raises/returns an error if either indicates a finite/windowed attention
configuration so MLA path is not used incorrectly; place the check before
building block_tables or invoking
flashinfer.mla.trtllm_batch_decode_with_kv_cache_mla (use the run_mla_generation
function name and the params.* fields to locate where to add the guard).
In `@tensorrt_llm/_torch/attention_backend/trtllm.py`:
- Around line 33-35: Restore the env-var-based kill switch for
_TRTLLM_ENABLE_TRTLLM_GEN_ATTENTION instead of hard-coding True: use
os.environ.get to read "TRTLLM_ENABLE_TRTLLM_GEN_ATTENTION" and set the variable
based on that value so callers can opt-out via environment; if you want the
feature default-on, set the default string to "1" (or treat missing as enabled)
but keep the environment override logic so _TRTLLM_ENABLE_TRTLLM_GEN_ATTENTION
remains a runtime kill switch.
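The restored kill switch described above could be sketched as follows. The environment-variable name comes from the review comment; the exact set of falsy strings accepted is an assumption:

```python
import os

# Sketch of the restored kill switch: default-on (missing var is treated
# as enabled), but still overridable at runtime via the environment.
# The falsy-string set here is an assumed convention, not the project's.
_TRTLLM_ENABLE_TRTLLM_GEN_ATTENTION = os.environ.get(
    "TRTLLM_ENABLE_TRTLLM_GEN_ATTENTION", "1").lower() not in ("0", "false", "off")
```

With this shape, `TRTLLM_ENABLE_TRTLLM_GEN_ATTENTION=0` opts out while an unset variable keeps the feature enabled, preserving the runtime escape hatch the comment asks for.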
---
Nitpick comments:
In `@tensorrt_llm/_torch/attention_backend/trtllm_gen.py`:
- Around line 1366-1367: The method run_mla_generation currently lacks a return
type; update its signature to explicitly declare it returns None (i.e., add ->
None) to comply with the project's type-annotation guideline, keeping the
function body unchanged and ensuring any callers or overrides remain compatible
with the new annotated signature.
- Around line 1906-1909: The conditional branch calling
backend.run_mla_context(...) is dead now and will cause AttributeError if the
MLA guard is relaxed; either remove the branch or implement the missing method.
Fix by replacing backend.run_mla_context(ctx_params) with the existing
backend.run_mla_generation(ctx_params) (or add a run_mla_context(self,
ctx_params) implementation on FlashInferTrtllmGenAttention that delegates to
run_mla_generation/run_context as appropriate), and ensure callers use the
is_mla_enable and is_fused_qkv guards consistently.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: acec2432-16ad-441f-a37f-c36dc6d90737
📒 Files selected for processing (2)
tensorrt_llm/_torch/attention_backend/trtllm.py
tensorrt_llm/_torch/attention_backend/trtllm_gen.py
PR_Github #40838 [ run ] completed with state
/bot run --disable-fail-fast

PR_Github #41188 [ run ] triggered by Bot. Commit:

PR_Github #41188 [ run ] completed with state
/bot run --disable-fail-fast

PR_Github #41230 [ run ] triggered by Bot. Commit:

PR_Github #41230 [ run ] completed with state
Signed-off-by: Yihan Wang <yihwang@nvidia.com>
ea8d76e to c055a39
/bot run --disable-fail-fast

PR_Github #41316 [ run ] triggered by Bot. Commit:

PR_Github #41316 [ run ] completed with state

/bot run --disable-fail-fast

PR_Github #41444 [ run ] triggered by Bot. Commit:

PR_Github #41444 [ run ] completed with state
Signed-off-by: Yihan Wang <yihwang@nvidia.com>
/bot run --disable-fail-fast

PR_Github #41557 [ run ] triggered by Bot. Commit:

PR_Github #41557 [ run ] completed with state

/bot run --disable-fail-fast

PR_Github #41623 [ run ] triggered by Bot. Commit:

PR_Github #41623 [ run ] completed with state
…VIDIA#12606) Signed-off-by: Yihan Wang <yihwang@nvidia.com>
Summary by CodeRabbit
New Features
Bug Fixes
Description
Test Coverage
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
To see a list of available CI bot commands, please comment `/bot help`.