[None][feat] Support MLA generation in TrtllmGen attention backend#12606

Merged
yihwang-nv merged 3 commits into NVIDIA:main from yihwang-nv:yihwang/trtllm_gen_attn_mla_decode
Apr 3, 2026

Conversation

@yihwang-nv
Collaborator

@yihwang-nv yihwang-nv commented Mar 31, 2026

Summary by CodeRabbit

  • New Features

    • Added Multi-head Latent Attention (MLA) generation support with dedicated decode functionality for enhanced model performance.
  • Bug Fixes

    • Improved attention backend validation and configuration handling for MLA-enabled models.

Description

Test Coverage

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions).

  • Any new dependencies have been scanned for license and vulnerabilities.

  • CODEOWNERS updated if ownership changes.

  • Documentation updated as needed.

  • Tava architecture diagram updated if there is a significant design change in the PR.

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

To see a list of available CI bot commands, please comment /bot help.

@yihwang-nv yihwang-nv requested a review from a team as a code owner March 31, 2026 03:16
@yihwang-nv yihwang-nv requested a review from yuxianq March 31, 2026 03:16
@yihwang-nv
Collaborator Author

/bot run --disable-fail-fast

@yihwang-nv yihwang-nv changed the title [None][feat] Support MLA in TrtllmGen attention backend [None][feat] Support MLA generation in TrtllmGen attention backend Mar 31, 2026
@tensorrt-cicd
Collaborator

PR_Github #40838 [ run ] triggered by Bot. Commit: 3c168c6 Link to invocation

@coderabbitai
Contributor

coderabbitai bot commented Mar 31, 2026

📝 Walkthrough

Walkthrough

Changes enable MLA (Multi-head Latent Attention) support in TensorRT-LLM's TRTLLM-Gen attention backend by forcing eligibility checks to pass, adding MLA-specific parameters and generation logic, modifying support validation, and conditionally routing execution between MLA and standard attention paths.

Changes

  • Backend Eligibility Control — tensorrt_llm/_torch/attention_backend/trtllm.py
    Replaced environment-variable-controlled gating with a hard-coded True, removing the prior conditional logic and delegating the TRTLLM-Gen backend eligibility decision to is_supported() and helix_active checks.

  • MLA Support Validation & Configuration — tensorrt_llm/_torch/attention_backend/trtllm_gen.py
    Updated TrtllmGenSupportChecker.is_supported() to allow MLA only when is_fused_qkv is true; removed head-ratio validation (MAX_HEADS_RATIO_GENERATION); changed ContextWorkspaceBuffers and GenerationWorkspaceBuffers to use @dataclass(slots=True).

  • MLA Parameters & Dataclass Extensions — tensorrt_llm/_torch/attention_backend/trtllm_gen.py
    Added 7 MLA-specific fields to EnqueueParams (kv_lora_rank, qk_nope_head_dim, qk_rope_head_dim, v_head_dim, q_scaling, latent_cache, num_layers) and 3 context-side MLA fields to EnqueueContextParams (k_input, v_input, absorption_mode).

  • MLA Generation Implementation & Control Flow — tensorrt_llm/_torch/attention_backend/trtllm_gen.py
    Added a run_mla_generation() method implementing MLA decode via FlashInfer; reworked token counting for AttentionInputType variants; adjusted output reshaping based on MLA state; introduced conditional routing to MLA kernels during the context and generation phases.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~50 minutes

🚥 Pre-merge checks | ✅ 1 | ❌ 2

❌ Failed checks (2 warnings)

  • Description check — ⚠️ Warning. The PR description contains only the template with instructions and an unchecked checklist; it lacks actual implementation details, rationale, test coverage, and issue context. Resolution: add a Description section explaining what MLA generation support entails and why it's needed, add a Test Coverage section listing the relevant tests, and complete and verify the PR Checklist items.

  • Docstring Coverage — ⚠️ Warning. Docstring coverage is 60.00%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (1 passed)

  • Title check — ✅ Passed. The title clearly summarizes the main feature, adding MLA generation support to the TrtllmGen attention backend, and is directly aligned with the changeset.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


Comment @coderabbitai help to get the list of available commands and usage tips.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

🧹 Nitpick comments (2)
tensorrt_llm/_torch/attention_backend/trtllm_gen.py (2)

1366-1367: Add the missing -> None annotation.

This helper is side-effect-only, so the signature should declare that explicitly.

Suggested change
-    def run_mla_generation(self, params: EnqueueGenerationParams):
+    def run_mla_generation(self, params: EnqueueGenerationParams) -> None:

As per coding guidelines, "Static type checking is opt-in by submodule PICs in Python. Always annotate functions with return types, and make the return type None if the function does not return anything."

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/attention_backend/trtllm_gen.py` around lines 1366 -
1367, The method run_mla_generation currently lacks a return type; update its
signature to explicitly declare it returns None (i.e., add -> None) to comply
with the project's type-annotation guideline, keeping the function body
unchanged and ensuring any callers or overrides remain compatible with the new
annotated signature.

1906-1909: Either remove or implement the MLA context branch.

This branch calls backend.run_mla_context(), but FlashInferTrtllmGenAttention in this file only defines run_context(), run_generation(), and run_mla_generation(). Today is_supported() rejects the `is_mla_enable and not is_fused_qkv` case, so the branch is dead code; if that guard is relaxed later, it becomes an immediate AttributeError.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/attention_backend/trtllm_gen.py` around lines 1906 -
1909, The conditional branch calling backend.run_mla_context(...) is dead now
and will cause AttributeError if the MLA guard is relaxed; either remove the
branch or implement the missing method. Fix by replacing
backend.run_mla_context(ctx_params) with the existing
backend.run_mla_generation(ctx_params) (or add a run_mla_context(self,
ctx_params) implementation on FlashInferTrtllmGenAttention that delegates to
run_mla_generation/run_context as appropriate), and ensure callers use the
is_mla_enable and is_fused_qkv guards consistently.
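One defensive option from the review above is to keep the branch but make the missing method fail loudly until MLA context is actually implemented. A minimal sketch, assuming the class name from the review comment (the real FlashInferTrtllmGenAttention has other methods and state not shown here):

```python
# Hedged sketch: replace the would-be AttributeError with an explicit,
# actionable error. Class/method names mirror the review comment only.
class FlashInferTrtllmGenAttention:
    def run_context(self, ctx_params):
        """Standard (non-MLA) context-phase attention; body elided."""
        return "context"

    def run_mla_context(self, ctx_params):
        # MLA context attention is not implemented yet; raise a clear
        # error instead of letting callers hit AttributeError if the
        # is_supported() guard is ever relaxed.
        raise NotImplementedError(
            "MLA context phase is not supported by the TRTLLM-Gen backend; "
            "relax is_supported() only after implementing this path.")
```

The alternative — deleting the branch outright — is equally valid; the key property is that relaxing the is_supported() guard can no longer reach an undefined attribute silently.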
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tensorrt_llm/_torch/attention_backend/trtllm_gen.py`:
- Around line 1366-1430: run_mla_generation ignores windowed-attention params
and thus runs full attention; add a fast-fail guard at the start of
run_mla_generation that checks
EnqueueGenerationParams.cyclic_attention_window_size and
EnqueueGenerationParams.attention_chunk_size (or computed window_left) and
raises/returns an error if either indicates a finite/windowed attention
configuration so MLA path is not used incorrectly; place the check before
building block_tables or invoking
flashinfer.mla.trtllm_batch_decode_with_kv_cache_mla (use the run_mla_generation
function name and the params.* fields to locate where to add the guard).
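The fast-fail guard requested above can be sketched as follows. The field names (cyclic_attention_window_size, attention_chunk_size, max_seq_len) follow the review prompt, but the dataclass here is a stand-in for the real EnqueueGenerationParams, and the exact "finite window" condition is an assumption:

```python
# Illustrative guard: refuse the MLA decode path when a windowed or
# chunked attention configuration is requested, since that path runs
# full attention. Field names follow the review prompt; the dataclass
# is a simplified stand-in.
from dataclasses import dataclass
from typing import Optional


@dataclass
class EnqueueGenerationParams:
    max_seq_len: int
    cyclic_attention_window_size: Optional[int] = None
    attention_chunk_size: Optional[int] = None


def check_mla_window_support(params: EnqueueGenerationParams) -> None:
    """Raise before building block tables or calling the FlashInfer MLA
    decode kernel if a finite attention window is configured."""
    window = params.cyclic_attention_window_size
    if window is not None and window < params.max_seq_len:
        raise ValueError(
            "MLA generation runs full attention and does not support "
            f"cyclic_attention_window_size={window}")
    if params.attention_chunk_size is not None:
        raise ValueError(
            "MLA generation does not support chunked attention "
            f"(attention_chunk_size={params.attention_chunk_size})")
```

Placing this check at the top of run_mla_generation() turns a silent accuracy bug (windowed config silently running full attention) into an immediate, debuggable error.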

In `@tensorrt_llm/_torch/attention_backend/trtllm.py`:
- Around line 33-35: Restore the env-var-based kill switch for
_TRTLLM_ENABLE_TRTLLM_GEN_ATTENTION instead of hard-coding True: use
os.environ.get to read "TRTLLM_ENABLE_TRTLLM_GEN_ATTENTION" and set the variable
based on that value so callers can opt-out via environment; if you want the
feature default-on, set the default string to "1" (or treat missing as enabled)
but keep the environment override logic so _TRTLLM_ENABLE_TRTLLM_GEN_ATTENTION
remains a runtime kill switch.
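The kill switch the reviewer asks to restore could look roughly like this, default-on: unset or "1" enables the backend, "0" disables it. The variable and environment names mirror the comment; the exact parsing convention is an assumption:

```python
# Sketch of a default-on, env-var-overridable kill switch for the
# TRTLLM-Gen attention backend; parsing convention is an assumption.
import os


def trtllm_gen_attention_enabled() -> bool:
    """Return True unless TRTLLM_ENABLE_TRTLLM_GEN_ATTENTION is set to "0"."""
    return os.environ.get("TRTLLM_ENABLE_TRTLLM_GEN_ATTENTION", "1") != "0"


_TRTLLM_ENABLE_TRTLLM_GEN_ATTENTION = trtllm_gen_attention_enabled()
```

This keeps the feature on by default while preserving a runtime escape hatch, so a regression in the new backend can be worked around without a code change.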

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: acec2432-16ad-441f-a37f-c36dc6d90737

📥 Commits

Reviewing files that changed from the base of the PR and between 6803a91 and 3c168c6.

📒 Files selected for processing (2)
  • tensorrt_llm/_torch/attention_backend/trtllm.py
  • tensorrt_llm/_torch/attention_backend/trtllm_gen.py

@tensorrt-cicd
Collaborator

PR_Github #40838 [ run ] completed with state SUCCESS. Commit: 3c168c6
/LLM/main/L0_MergeRequest_PR pipeline #31846 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@yihwang-nv
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #41188 [ run ] triggered by Bot. Commit: ea8d76e Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #41188 [ run ] completed with state SUCCESS. Commit: ea8d76e
/LLM/main/L0_MergeRequest_PR pipeline #32151 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@yihwang-nv
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #41230 [ run ] triggered by Bot. Commit: ea8d76e Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #41230 [ run ] completed with state SUCCESS. Commit: ea8d76e
/LLM/main/L0_MergeRequest_PR pipeline #32191 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

Signed-off-by: Yihan Wang <yihwang@nvidia.com>
Signed-off-by: Yihan Wang <yihwang@nvidia.com>
@yihwang-nv yihwang-nv force-pushed the yihwang/trtllm_gen_attn_mla_decode branch from ea8d76e to c055a39 on April 2, 2026 02:50
@yihwang-nv
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #41316 [ run ] triggered by Bot. Commit: c055a39 Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #41316 [ run ] completed with state SUCCESS. Commit: c055a39
/LLM/main/L0_MergeRequest_PR pipeline #32269 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@yihwang-nv
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #41444 [ run ] triggered by Bot. Commit: c055a39 Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #41444 [ run ] completed with state SUCCESS. Commit: c055a39
/LLM/main/L0_MergeRequest_PR pipeline #32375 completed with status: 'SUCCESS'

CI Report

Link to invocation

Signed-off-by: Yihan Wang <yihwang@nvidia.com>
@yihwang-nv
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #41557 [ run ] triggered by Bot. Commit: c5541d1 Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #41557 [ run ] completed with state SUCCESS. Commit: c5541d1
/LLM/main/L0_MergeRequest_PR pipeline #32468 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@yihwang-nv
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #41623 [ run ] triggered by Bot. Commit: c5541d1 Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #41623 [ run ] completed with state SUCCESS. Commit: c5541d1
/LLM/main/L0_MergeRequest_PR pipeline #32531 completed with status: 'SUCCESS'

CI Report

Link to invocation

@yihwang-nv yihwang-nv merged commit c00c982 into NVIDIA:main Apr 3, 2026
5 checks passed
govind-ramnarayan pushed a commit to nv-auto-deploy/TensorRT-LLM that referenced this pull request Apr 6, 2026
yufeiwu-nv pushed a commit to yufeiwu-nv/TensorRT-LLM that referenced this pull request Apr 7, 2026
karen-sy pushed a commit to karen-sy/TensorRT-LLM that referenced this pull request Apr 7, 2026

Labels

None yet

Projects

None yet


3 participants