[None][perf] Cache FlashMLA tile-scheduler metadata across attention layers #12161
bobboli merged 7 commits into NVIDIA:main from
Conversation
/bot run --disable-fail-fast
PR_Github #38723 [ run ] triggered by Bot. Commit:
📝 Walkthrough
This PR introduces a thread-local cache for MLA (Multi-Head Latent Attention) metadata to avoid redundant recomputation across layers within a single forward pass.
Sequence Diagram

```mermaid
sequenceDiagram
    participant Python as Python Code
    participant Binding as C++ Binding Layer
    participant Impl as AttentionOp Implementation
    participant Cache as Thread-Local Cache
    participant MLA as MLA Metadata Generator
    Python->>Binding: increment_flash_mla_metadata_step()
    Binding->>Impl: incrementFlashMlaMetadataStep()
    Impl->>Cache: Increment step counter
    Python->>Impl: mlaGeneration (Layer 1)
    Impl->>Cache: Check cache validity
    alt Cache Invalid or Step Changed
        Cache-->>Impl: Invalid
        Impl->>MLA: Compute metadata
        MLA-->>Impl: tileSchedulerMetadata, numSplits
        Impl->>Cache: Store derived pointers & step id
    else Cache Valid
        Cache-->>Impl: Valid cached pointers
    end
    Python->>Impl: mlaGeneration (Layer 2)
    Impl->>Cache: Check cache validity
    Cache-->>Impl: Still valid (same step)
    Impl-->>Python: Reuse cached metadata (skip MLA compute)
```
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~22 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)
Actionable comments posted: 1
🧹 Nitpick comments (1)
cpp/tensorrt_llm/thop/attentionOp.h (1)
78-84: Use `//!` docs for this new header API.
`/** ... */` on a new function prototype goes against the repo's C++ comment rules. Please convert this block to `//!` lines so the header stays Doxygen-compatible. As per the coding guidelines, "C++ comments must be in C++ style (`//`)" and "Use `//!` for documenting new C++ class interfaces and function prototypes".
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@cpp/tensorrt_llm/thop/attentionOp.h` around lines 78 - 84, Replace the C-style block comment above the header API (the block describing "Invalidate the per-step FlashMLA tile-scheduler metadata cache" and mentioning get_mla_metadata and MTP sub-steps) with C++-style Doxygen lines: change the /** ... */ to multiple leading //! lines preserving the exact text and formatting so the declaration in attentionOp.h remains Doxygen-compatible and follows the repo rule of using // comments for new function prototypes and class interfaces.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tensorrt_llm/_torch/attention_backend/trtllm.py`:
- Around line 1038-1043: The invalidation in update_for_spec_dec is gated by the
Python property self.enable_flash_mla but the C++ path uses a different
predicate mUseGenFlashMLA, which can be true when enable_flash_mla is false and
cause stale metadata; fix by gating the Python invalidation on the actual C++
runtime predicate (expose mUseGenFlashMLA to Python and replace the check around
thop.increment_flash_mla_metadata_step with that predicate) or alternatively
ensure a call path that queries the C++ predicate before calling
thop.increment_flash_mla_metadata_step (also apply the same change for the
analogous check near the other spec-dec location where
thop.increment_flash_mla_metadata_step is invoked).
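The suggested gating fix can be sketched as follows. This is a hypothetical stand-in: `uses_gen_flash_mla` plays the role of an exposed `mUseGenFlashMLA` predicate, and `_ThopStub` replaces the real thop extension module; neither name is from the actual codebase.

```python
class _ThopStub:
    """Illustrative stand-in for the thop extension module."""

    def __init__(self):
        self.step_increments = 0

    def increment_flash_mla_metadata_step(self):
        self.step_increments += 1


thop = _ThopStub()


def update_for_spec_dec(uses_gen_flash_mla: bool) -> None:
    # Gate on the C++ runtime predicate (mUseGenFlashMLA), not a looser
    # Python-side flag, so the cache is invalidated exactly when the
    # FlashMLA generation path that consumes the metadata is active.
    if uses_gen_flash_mla:
        thop.increment_flash_mla_metadata_step()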
---
Nitpick comments:
In `@cpp/tensorrt_llm/thop/attentionOp.h`:
- Around line 78-84: Replace the C-style block comment above the header API (the
block describing "Invalidate the per-step FlashMLA tile-scheduler metadata
cache" and mentioning get_mla_metadata and MTP sub-steps) with C++-style Doxygen
lines: change the /** ... */ to multiple leading //! lines preserving the exact
text and formatting so the declaration in attentionOp.h remains
Doxygen-compatible and follows the repo rule of using // comments for new
function prototypes and class interfaces.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: aa738a74-f8a5-41f6-b3f6-1aeaf5b72af9
📒 Files selected for processing (4)
cpp/tensorrt_llm/common/attentionOp.cpp
cpp/tensorrt_llm/nanobind/thop/bindings.cpp
cpp/tensorrt_llm/thop/attentionOp.h
tensorrt_llm/_torch/attention_backend/trtllm.py
PR_Github #38723 [ run ] completed with state
e1e0ac1 to b4445dc (force-push)
/bot run --disable-fail-fast
PR_Github #39286 [ run ] triggered by Bot. Commit:
PR_Github #39286 [ run ] completed with state
/bot run --disable-fail-fast --reuse-test
PR_Github #39377 [ run ] triggered by Bot. Commit:
PR_Github #39377 [ run ] completed with state
2f3ccc5 to ada60c8 (force-push)
…y when necessary Signed-off-by: Bo Li <22713281+bobboli@users.noreply.github.com>
ada60c8 to 594a1f0 (force-push)
/bot run --disable-fail-fast
PR_Github #39490 [ run ] triggered by Bot. Commit:
PR_Github #39490 [ run ] completed with state
594a1f0 to 8c8d584 (force-push)
PR_Github #39605 [ kill ] triggered by Bot. Commit:
PR_Github #39605 [ kill ] completed with state
/bot run --disable-fail-fast
PR_Github #39606 [ run ] triggered by Bot. Commit:
PR_Github #39606 [ run ] completed with state
/bot run --disable-fail-fast --reuse-test
PR_Github #39675 [ run ] triggered by Bot. Commit:
PR_Github #39675 [ run ] completed with state
/bot run --disable-fail-fast --reuse-test
PR_Github #39694 [ run ] triggered by Bot. Commit:
PR_Github #39694 [ run ] completed with state
lfr-0531 left a comment:
A few minor comments. The rest LGTM~
/bot run --disable-fail-fast --reuse-test
PR_Github #39706 [ run ] triggered by Bot. Commit:
PR_Github #39706 [ run ] completed with state
- Fix cudaDeviceGetAttribute to query the current device instead of hardcoded device 0 in getFlashMlaNumSmPartsStatic
- Add TORCH_CHECK that flash_mla_num_splits is provided whenever flash_mla_tile_scheduler_metadata is set
Signed-off-by: Bo Li <22713281+bobboli@users.noreply.github.com>
/bot run --disable-fail-fast --reuse-test
PR_Github #39721 [ run ] triggered by Bot. Commit:
PR_Github #39721 [ run ] completed with state
/bot run --disable-fail-fast --reuse-test
PR_Github #39750 [ run ] triggered by Bot. Commit:
PR_Github #39750 [ run ] completed with state
/bot run --disable-fail-fast --reuse-test
PR_Github #39753 [ run ] triggered by Bot. Commit:
PR_Github #39753 [ run ] completed with state
Description
Pre-compute FlashMLA tile-scheduler metadata once per forward pass in Python and pass it into the C++ attention op, instead of rebuilding it inside every attention layer.
This also fixes the metadata layout for mixed prefill + decode batches by computing metadata from the compacted generation sub-batch, so the generated num_splits and tile-scheduler metadata match the block_ids_per_seq layout consumed by FlashMLA.
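The compute-once-per-pass flow can be sketched as below. All names are illustrative; the real code computes metadata into pre-allocated tensors and passes them into the C++ attention op.

```python
# Sketch: compute FlashMLA tile-scheduler metadata once per forward pass and
# hand the same result to every layer's attention call. The metadata formula
# and class names are hypothetical stand-ins, not the actual API.

def compute_flash_mla_metadata(gen_kv_lens):
    """Stand-in for the metadata kernel; returns (tile_meta, num_splits)."""
    return ("tile_scheduler_metadata", [l // 64 + 1 for l in gen_kv_lens])


class AttentionLayer:
    def __call__(self, hidden, *, tile_meta, num_splits):
        # A real layer would launch FlashMLA with the precomputed metadata;
        # here we just record that it consumed the shared tensors.
        self.seen = (tile_meta, num_splits)
        return hidden


def forward(layers, hidden, gen_kv_lens):
    tile_meta, num_splits = compute_flash_mla_metadata(gen_kv_lens)  # once
    for layer in layers:  # every layer reuses the same metadata
        hidden = layer(hidden, tile_meta=tile_meta, num_splits=num_splits)
    return hidden
```

With N layers, the metadata computation runs once instead of N times per forward pass, which is the overhead this PR removes.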
Why
The previous flow rebuilt FlashMLA metadata inside the C++ attention path for each layer. Besides the repeated overhead, the ported logic could use metadata derived from the wrong batch view when context and generation requests were mixed together.
That mismatch could leave FlashMLA reading metadata that did not correspond to the active generation sub-batch, which showed up as accuracy failures and, in some cases, illegal memory access.
What changed
Added a Python/C++ entry point to compute FlashMLA metadata into pre-allocated tensors.
Extended the attention op plumbing to accept pre-computed flash_mla_tile_scheduler_metadata and flash_mla_num_splits.
Updated the FlashMLA generation path to require and consume the pre-computed metadata instead of generating it locally.
Allocated stable metadata buffers in TrtllmAttentionMetadata so they can be reused across layers and CUDA graph captures.
Recomputed metadata from the generation slice only, and invalidated cached metadata when KV lengths change or speculative decoding updates the batch state.
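The generation-slice compaction described above can be illustrated with a toy sketch. The num_splits formula here is a placeholder, not the real get_mla_metadata logic, and the function names are hypothetical.

```python
# Sketch: derive FlashMLA metadata inputs from the compacted generation
# sub-batch of a mixed prefill + decode batch, so metadata rows line up
# with the block_ids_per_seq layout consumed by FlashMLA.

def compact_generation_kv_lens(kv_lens, is_generation):
    """Keep generation requests only, preserving order."""
    return [l for l, gen in zip(kv_lens, is_generation) if gen]


def compute_num_splits(gen_kv_lens, num_sm_parts, block_n=64):
    """Toy stand-in for the split computation in get_mla_metadata."""
    splits = []
    for kv_len in gen_kv_lens:
        num_blocks = -(-kv_len // block_n)  # ceil division
        splits.append(min(num_sm_parts, max(1, num_blocks)))
    return splits


# Requests 0 and 2 are context (prefill); 1 and 3 are generation.
kv_lens = [512, 130, 1024, 70]
is_generation = [False, True, False, True]
gen_kv_lens = compact_generation_kv_lens(kv_lens, is_generation)
num_splits = compute_num_splits(gen_kv_lens, num_sm_parts=4)
```

Computing from `kv_lens` directly would produce four metadata rows for a two-request generation sub-batch, which is the layout mismatch the PR fixes.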
Test Coverage
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
To see a list of available CI bot commands, please comment /bot help.