[#12634][feat] AutoDeploy: Support rank 256 MLA in flashinfer_mla (#12519)
suyoggupta merged 2 commits into NVIDIA:main
Conversation
📝 Walkthrough
Adds a TRTLLM FlashInfer MLA custom operator with a paged KV cache, integrates it into backend resolution and graph partitioning, provides a Triton cache-append kernel, and adds comprehensive unit tests and configuration changes.
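The walkthrough mentions a Triton cache-append kernel for the paged MLA cache. As a hedged illustration only — argument names are taken from the sequence diagram below, while shapes and cache layout are assumptions, not the PR's actual Triton code — the scatter logic can be sketched as a plain PyTorch reference loop:

```python
import torch

def append_paged_mla_cache(compressed_kv, kpe, cache, cu_seqlen, cu_num_pages,
                           cache_loc, input_pos):
    """Reference (loop) version of a paged MLA cache append.

    compressed_kv: [total_tokens, kv_lora_rank]   latent KV per token
    kpe:           [total_tokens, rope_dim]       rope part per token
    cache:         [num_pages, page_size, kv_lora_rank + rope_dim]
    cu_seqlen / cu_num_pages: cumulative token / page counts per sequence
    cache_loc:     flattened page table (physical page ids per sequence)
    input_pos:     first write position per sequence
    """
    kv_rank = compressed_kv.shape[1]
    page_size = cache.shape[1]
    for b in range(cu_seqlen.numel() - 1):
        start, end = int(cu_seqlen[b]), int(cu_seqlen[b + 1])
        pages = cache_loc[int(cu_num_pages[b]): int(cu_num_pages[b + 1])]
        for i in range(end - start):
            pos = int(input_pos[b]) + i          # absolute position in the sequence
            page = int(pages[pos // page_size])  # physical page id
            slot = pos % page_size               # slot within that page
            cache[page, slot, :kv_rank] = compressed_kv[start + i]
            cache[page, slot, kv_rank:] = kpe[start + i]

# tiny demo: 1 sequence, 3 tokens, page_size=2, page table [2, 0]
cache = torch.zeros(4, 2, 6)  # [num_pages, page_size, kv_lora_rank + rope_dim]
compressed_kv = torch.arange(12, dtype=torch.float32).reshape(3, 4)
kpe = torch.arange(6, dtype=torch.float32).reshape(3, 2)
append_paged_mla_cache(
    compressed_kv, kpe, cache,
    cu_seqlen=torch.tensor([0, 3]),
    cu_num_pages=torch.tensor([0, 2]),
    cache_loc=torch.tensor([2, 0]),
    input_pos=torch.tensor([0]),
)
```

The real kernel would parallelize this scatter over tokens in Triton; the loop above only fixes the indexing semantics.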
Sequence Diagram

```mermaid
sequenceDiagram
    participant Input as Query / Incoming KV
    participant Host as Host Metadata
    participant Cache as Paged MLA Cache
    participant Resolver as Backend Resolver
    participant FlashInfer as FlashInfer Kernel
    participant TorchRef as Torch Reference
    participant Output as Attention Output
    Input->>Host: Provide compressed_kv, kpe, input_pos, batch_info_host
    Host->>Cache: append_paged_mla_cache(compressed_kv, kpe, cu_seqlen, cu_num_pages, cache_loc, input_pos)
    Host->>Resolver: resolve_backend_for_node(requested_backend, mla_dims, device_cc)
    Resolver->>Cache: decide compute path (flashinfer_trtllm_mla or torch_mla)
    alt FlashInfer TRTLLM path
        Cache->>FlashInfer: build block_tables, run trtllm_batch_decode_with_kv_cache_mla
        FlashInfer->>Output: project latent to final output (w_v)
    else Torch fallback path
        Cache->>TorchRef: gather cached pages per-sequence
        TorchRef->>TorchRef: compute attention (prefill/mixed/decode) with causal masking as needed
        TorchRef->>Output: produce attention output
    end
```
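The resolver step in the diagram can be sketched in Python. This is a speculative sketch, not the PR's implementation: the function name comes from the diagram, while the supported rank set (the PR adds rank 256 alongside the previously supported 512) and the compute-capability threshold are assumptions:

```python
def resolve_backend_for_node(requested_backend: str, kv_lora_rank: int,
                             device_cc: tuple) -> str:
    """Hypothetical resolver mirroring the diagram: use the FlashInfer TRTLLM
    MLA kernel when the latent rank and GPU compute capability allow it,
    otherwise fall back to the torch reference op."""
    # Assumed constraints: rank 256 support is what this PR adds (512 existed
    # before), and the TRTLLM decode kernel is taken to need cc >= 9.0.
    supported_ranks = (256, 512)
    if (requested_backend == "flashinfer_mla"
            and kv_lora_rank in supported_ranks
            and device_cc >= (9, 0)):
        return "flashinfer_trtllm_mla"
    return "torch_mla"
```

Any rank or device outside the assumed envelope routes to the torch fallback, matching the `else` branch of the diagram.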
Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com>
/bot run --disable-fail-fast --extra-stage "DGX_B200-4_GPUs-AutoDeploy-1, DGX_H100-4_GPUs-AutoDeploy-1" |
PR_Github #42388 [ run ] triggered by Bot. Commit:
PR_Github #42388 [ run ] completed with state
Mock _is_blackwell_decode_supported to False so the test always exercises the reference decode path, avoiding FlashInfer rejecting block_size=1 on Blackwell where the decode kernel path is taken. Signed-off-by: Balamurugan Marimuthu <246387390+bmarimuthu-nv@users.noreply.github.com>
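The mocking approach described in this commit message can be illustrated with `unittest.mock`. This is a generic sketch against a hypothetical stand-in module, not the PR's actual test file:

```python
import types
from unittest.mock import patch

# Stand-in for the attention backend module; in the real test this would be
# the AutoDeploy module defining _is_blackwell_decode_supported.
backend = types.SimpleNamespace(_is_blackwell_decode_supported=lambda: True)

def choose_decode_path() -> str:
    # Blackwell's FlashInfer decode kernel rejects block_size=1, so the test
    # forces the torch reference path by patching the capability check.
    if backend._is_blackwell_decode_supported():
        return "flashinfer_decode"
    return "torch_reference"

with patch.object(backend, "_is_blackwell_decode_supported", return_value=False):
    patched_path = choose_decode_path()   # capability check mocked to False
unpatched_path = choose_decode_path()     # original check restored on exit
```

Patching at the call site (rather than constructing a block_size=1 cache on Blackwell hardware) keeps the test deterministic across GPU generations.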
/bot run --disable-fail-fast --extra-stage "DGX_B200-4_GPUs-AutoDeploy-1, DGX_H100-4_GPUs-AutoDeploy-1"
PR_Github #42721 [ run ] triggered by Bot. Commit:
PR_Github #42721 [ run ] completed with state
Description
Test Coverage
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
To see a list of available CI bot commands, please comment /bot help.