[#12634][feat] AutoDeploy: Support rank 256 MLA in flashinfer_mla#12519

Merged
suyoggupta merged 2 commits into NVIDIA:main from nv-auto-deploy:bala/flashinfer-mla-v2 on Apr 11, 2026
Conversation

@bmarimuthu-nv
Collaborator

@bmarimuthu-nv bmarimuthu-nv commented Mar 25, 2026

Summary by CodeRabbit

  • New Features

    • Added FlashInfer TRTLLM MLA operator with paged KV-cache support and a GPU append kernel for optimized inference on Blackwell-capable devices.
    • Implemented automatic MLA backend selection with runtime resolution and safe fallbacks based on hardware and model dims.
  • Documentation

    • Updated MLA operator catalog with the new operator entry.
  • Tests

    • Added comprehensive tests covering correctness, kernel dispatch, cache append, and end-to-end prefill/decode workflows.

Description

Test Coverage

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • Update the tava architecture diagram if there is a significant design change in the PR.

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

To see a list of available CI bot commands, please comment /bot help.

@bmarimuthu-nv bmarimuthu-nv marked this pull request as ready for review March 31, 2026 16:26
@bmarimuthu-nv bmarimuthu-nv requested a review from a team as a code owner March 31, 2026 16:26
@coderabbitai
Contributor

coderabbitai bot commented Mar 31, 2026

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: f0fa5946-0b13-4567-a164-14e17f15c6b6

📥 Commits

Reviewing files that changed from the base of the PR and between 6ac5c15 and 7e2c0ef.

📒 Files selected for processing (9)
  • examples/auto_deploy/model_registry/configs/mistral_small_4_119b.yaml
  • tensorrt_llm/_torch/auto_deploy/compile/piecewise_utils.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/README.md
  • tensorrt_llm/_torch/auto_deploy/custom_ops/mla/__init__.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/mla/flashinfer_trtllm_mla.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/mla/triton_mla_cache_append.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py
  • tests/unittest/auto_deploy/singlegpu/custom_ops/mla/test_flashinfer_trtllm_mla_op.py
  • tests/unittest/auto_deploy/singlegpu/transformations/library/test_kv_cache.py

📝 Walkthrough

Walkthrough

Adds a TRTLLM FlashInfer MLA custom operator with paged KV cache, integrates it into backend resolution and graph partitioning, provides a Triton cache-append kernel, and adds comprehensive unit tests and configuration changes.

Changes

| Cohort / File(s) | Summary |
|---|---|
| **MLA custom op implementation**<br>`tensorrt_llm/_torch/auto_deploy/custom_ops/mla/flashinfer_trtllm_mla.py` | New custom operator `auto_deploy::flashinfer_trtllm_mla_with_cache` (mutates the paged MLA cache), a fake meta implementation, and a `FlashInferTrtllmMLAAttention` descriptor registered in the AttentionRegistry. |
| **Module exports & docs**<br>`tensorrt_llm/_torch/auto_deploy/custom_ops/mla/__init__.py`, `tensorrt_llm/_torch/auto_deploy/custom_ops/README.md` | Exports `FlashInferTrtllmMLAAttention` and `flashinfer_trtllm_mla_with_cache`; documents the new operator in the MLA README. |
| **Paged cache append kernel**<br>`tensorrt_llm/_torch/auto_deploy/custom_ops/mla/triton_mla_cache_append.py` | Adds a Triton JIT kernel and Python frontend `append_paged_mla_cache` to append compressed KV + kpe into the combined paged MLA cache layout. |
| **Backend resolution & graph transform**<br>`tensorrt_llm/_torch/auto_deploy/transform/library/kvcache.py`, `tensorrt_llm/_torch/auto_deploy/compile/piecewise_utils.py` | Adds MLA-specific backend resolution: examines MLA dims and device compute capability to choose between `flashinfer_mla`, `flashinfer_trtllm_mla`, or `torch_mla`; updates the dynamic-op registry to include `auto_deploy::flashinfer_trtllm_mla_with_cache`. |
| **Tests — operator & transforms**<br>`tests/unittest/auto_deploy/singlegpu/custom_ops/mla/test_flashinfer_trtllm_mla_op.py`, `tests/unittest/auto_deploy/singlegpu/transformations/library/test_kv_cache.py` | New CUDA-heavy unit tests for the custom op (decode, prefill, mixed batches, page boundaries, end-to-end flow) and unit tests for the backend resolution logic with mocked device capabilities. |
| **Example config change**<br>`examples/auto_deploy/model_registry/configs/mistral_small_4_119b.yaml` | Changed `transforms.insert_cached_mla_attention.backend` from `torch_mla` to `flashinfer_mla`. |
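The backend-resolution step summarized above is not spelled out on this page, so the following is only a minimal sketch of the fallback ordering it implies. The function name `resolve_mla_backend`, the rank thresholds, and the compute-capability encoding are all assumptions for illustration, not the actual resolver in `kvcache.py`:

```python
# Hypothetical sketch of MLA backend resolution (all names/thresholds assumed):
# honor an explicit backend choice, upgrade flashinfer_mla to the TRTLLM kernel
# for rank-256 models on Blackwell, and otherwise fall back safely.

BLACKWELL_CC = (10, 0)  # assumed encoding of SM100 compute capability


def resolve_mla_backend(requested: str, kv_lora_rank: int, device_cc: tuple) -> str:
    """Resolve a requested MLA backend to one the hardware/model supports."""
    if requested != "flashinfer_mla":
        return requested  # honor explicit choices such as torch_mla
    if kv_lora_rank == 256 and device_cc >= BLACKWELL_CC:
        # Blackwell devices can use the FlashInfer TRTLLM MLA decode kernel.
        return "flashinfer_trtllm_mla"
    if kv_lora_rank == 512:
        return "flashinfer_mla"  # standard FlashInfer MLA path
    return "torch_mla"  # safe fallback for unsupported dims/hardware
```

The real resolver inspects the graph node's MLA dims and the active CUDA device at transform time; this sketch only captures the decision ordering.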

Sequence Diagram

```mermaid
sequenceDiagram
    participant Input as Query / Incoming KV
    participant Host as Host Metadata
    participant Cache as Paged MLA Cache
    participant Resolver as Backend Resolver
    participant FlashInfer as FlashInfer Kernel
    participant TorchRef as Torch Reference
    participant Output as Attention Output

    Input->>Host: Provide compressed_kv, kpe, input_pos, batch_info_host
    Host->>Cache: append_paged_mla_cache(compressed_kv, kpe, cu_seqlen, cu_num_pages, cache_loc, input_pos)
    Host->>Resolver: resolve_backend_for_node(requested_backend, mla_dims, device_cc)
    Resolver->>Cache: decide compute path (flashinfer_trtllm_mla or torch_mla)

    alt FlashInfer TRTLLM path
        Cache->>FlashInfer: build block_tables, run trtllm_batch_decode_with_kv_cache_mla
        FlashInfer->>Output: project latent to final output (w_v)
    else Torch fallback path
        Cache->>TorchRef: gather cached pages per-sequence
        TorchRef->>TorchRef: compute attention (prefill/mixed/decode) with causal masking as needed
        TorchRef->>Output: produce attention output
    end
```
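As a plain-Python reference for the `append_paged_mla_cache` step in the diagram, the sketch below captures the paged-write semantics only. The argument names and the combined per-slot `ckv + kpe` layout are assumptions inferred from the walkthrough, not the Triton kernel's actual signature:

```python
# Hedged reference sketch (pure Python, not the Triton kernel) of appending
# each sequence's new compressed-KV and k_pe tokens into its assigned pages,
# starting at that sequence's current cache position (input_pos).

def append_paged_mla_cache_ref(cache, compressed_kv, kpe, seq_lens,
                               page_table, input_pos, page_size):
    """cache: nested lists [num_pages][page_size] of combined (ckv + kpe) vectors.

    compressed_kv/kpe are flat token-major lists; seq_lens gives the number of
    new tokens per sequence; page_table maps each sequence to its pages.
    """
    token = 0
    for b, n in enumerate(seq_lens):
        for i in range(n):
            pos = input_pos[b] + i            # absolute position in the sequence
            page = page_table[b][pos // page_size]
            slot = pos % page_size
            # Combined layout: compressed KV first, then the k_pe tail.
            cache[page][slot] = list(compressed_kv[token]) + list(kpe[token])
            token += 1
    return cache
```

The production kernel does these writes in parallel on the GPU; the loop above only documents which page/slot each token lands in.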

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

🚥 Pre-merge checks | ✅ 1 | ❌ 2

❌ Failed checks (2 warnings)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 32.50%, which is below the required threshold of 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |
| Description check | ⚠️ Warning | The PR description is empty/template-only, with no actual implementation details, objectives, test coverage, or checklist items filled in. | Fill in the Description section explaining the feature and the Test Coverage section listing relevant tests, and complete the PR Checklist items. |
✅ Passed checks (1 passed)
| Check name | Status | Explanation |
|---|---|---|
| Title check | ✅ Passed | The title accurately describes the main change: adding Blackwell MLA backend (flashinfer_trtllm_mla) support with rank 256 constraints. |

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


Comment @coderabbitai help to get the list of available commands and usage tips.

@bmarimuthu-nv bmarimuthu-nv force-pushed the bala/flashinfer-mla-v2 branch from 9011c17 to c1d3f76 Compare March 31, 2026 20:08
@bmarimuthu-nv bmarimuthu-nv requested a review from a team as a code owner March 31, 2026 22:22
@bmarimuthu-nv bmarimuthu-nv requested a review from achartier March 31, 2026 22:22
@bmarimuthu-nv bmarimuthu-nv changed the title [None][feat] Add Blackwell MLA backend selection [None][feat] AutoDeploy: Support rank 256 MLA in flashinfer_mla Apr 1, 2026
@bmarimuthu-nv
Collaborator Author

@coderabbitai summary

@coderabbitai
Contributor

coderabbitai bot commented Apr 1, 2026

✅ Actions performed

Summary regeneration triggered.

@bmarimuthu-nv bmarimuthu-nv changed the title [None][feat] AutoDeploy: Support rank 256 MLA in flashinfer_mla [#12634][feat] AutoDeploy: Support rank 256 MLA in flashinfer_mla Apr 7, 2026
Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com>
@bmarimuthu-nv bmarimuthu-nv force-pushed the bala/flashinfer-mla-v2 branch from f9b28f9 to 96d6174 Compare April 8, 2026 19:22
@bmarimuthu-nv
Collaborator Author

/bot run --disable-fail-fast --extra-stage "DGX_B200-4_GPUs-AutoDeploy-1, DGX_H100-4_GPUs-AutoDeploy-1"

@tensorrt-cicd
Collaborator

PR_Github #42388 [ run ] triggered by Bot. Commit: 96d6174 Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #42388 [ run ] completed with state SUCCESS. Commit: 96d6174
/LLM/main/L0_MergeRequest_PR pipeline #33165 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

Mock _is_blackwell_decode_supported to False so the test always
exercises the reference decode path, avoiding FlashInfer rejecting
block_size=1 on Blackwell where the decode kernel path is taken.
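The fix described in the commit message above can be sketched as a standard `unittest.mock` patch. The module layout below is a self-contained stand-in (only the check's name, `_is_blackwell_decode_supported`, comes from the commit message; everything else is illustrative):

```python
# Illustrative sketch of forcing the torch reference decode path in a test by
# mocking the Blackwell capability check to False. FakeMlaModule and
# decode_path are stand-ins for the real op module and its dispatch.
from unittest import mock


class FakeMlaModule:
    @staticmethod
    def _is_blackwell_decode_supported() -> bool:
        return True  # what a real Blackwell device would report


def decode_path(mla_mod) -> str:
    # Mirrors the dispatch: FlashInfer kernel on Blackwell, else torch reference.
    return "flashinfer" if mla_mod._is_blackwell_decode_supported() else "torch_ref"


# In the test, patch the check to False so the reference path always runs,
# regardless of the hardware the test executes on:
with mock.patch.object(FakeMlaModule, "_is_blackwell_decode_supported",
                       return_value=False):
    assert decode_path(FakeMlaModule) == "torch_ref"
```

Patching at the module where the check is *looked up* (rather than where it is defined) is the usual pitfall with this pattern.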

Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com>
@bmarimuthu-nv
Collaborator Author

/bot run --disable-fail-fast --extra-stage "DGX_B200-4_GPUs-AutoDeploy-1, DGX_H100-4_GPUs-AutoDeploy-1"

@tensorrt-cicd
Collaborator

PR_Github #42721 [ run ] triggered by Bot. Commit: 131e7b0 Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #42721 [ run ] completed with state SUCCESS. Commit: 131e7b0
/LLM/main/L0_MergeRequest_PR pipeline #33408 completed with status: 'SUCCESS'
Pipeline passed with automatic retried tests. Check the rerun report for details.

CI Report

Link to invocation

@suyoggupta suyoggupta self-requested a review April 11, 2026 10:22
@suyoggupta suyoggupta merged commit 8a56f4b into NVIDIA:main Apr 11, 2026
5 checks passed