[None][perf] Cache FlashMLA tile-scheduler metadata across attention layers#12161

Merged
bobboli merged 7 commits into NVIDIA:main from bobboli:flashmla_cache_metadata
Mar 21, 2026
Conversation

@bobboli (Collaborator)

@bobboli bobboli commented Mar 12, 2026

Description

Pre-compute FlashMLA tile-scheduler metadata once per forward pass in Python and pass it into the C++ attention op, instead of rebuilding it inside every attention layer.
This also fixes the metadata layout for mixed prefill + decode batches by computing metadata from the compacted generation sub-batch, so the generated num_splits and tile-scheduler metadata match the block_ids_per_seq layout consumed by FlashMLA.

Why

The previous flow rebuilt FlashMLA metadata inside the C++ attention path for each layer. Besides the repeated overhead, the ported logic could use metadata derived from the wrong batch view when context and generation requests were mixed together.
That mismatch could leave FlashMLA reading metadata that did not correspond to the active generation sub-batch, which showed up as accuracy failures and, in some cases, illegal memory access.

What changed

  • Added a Python/C++ entry point to compute FlashMLA metadata into pre-allocated tensors.
  • Extended the attention op plumbing to accept pre-computed flash_mla_tile_scheduler_metadata and flash_mla_num_splits.
  • Updated the FlashMLA generation path to require and consume the pre-computed metadata instead of generating it locally.
  • Allocated stable metadata buffers in TrtllmAttentionMetadata so they can be reused across layers and CUDA graph captures.
  • Recomputed metadata from the generation slice only, and invalidated cached metadata when KV lengths change or speculative decoding updates the batch state.
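The caching pattern described above can be sketched in a few lines. This is a minimal illustration under assumed names, not the actual TensorRT-LLM code: compute_mla_metadata stands in for FlashMLA's metadata kernel, and AttentionMetadata for TrtllmAttentionMetadata and its preparation path.

```python
# Minimal sketch, assuming hypothetical names throughout.
from dataclasses import dataclass
from typing import List, Optional, Tuple


def compute_mla_metadata(kv_lens: List[int]) -> Tuple[list, list]:
    # Stand-in for get_mla_metadata: derive tile-scheduler metadata and
    # num_splits from the generation sub-batch KV lengths (dummy math).
    tile_metadata = [l // 64 + 1 for l in kv_lens]
    num_splits = [max(1, l // 128) for l in kv_lens]
    return tile_metadata, num_splits


@dataclass
class AttentionMetadata:
    # Stable buffers reused across layers and CUDA graph captures.
    _cached_kv_lens: Optional[tuple] = None
    tile_scheduler_metadata: Optional[list] = None
    num_splits: Optional[list] = None

    def prepare_flash_mla(self, gen_kv_lens: List[int]) -> None:
        # Recompute only when the generation sub-batch changed; every
        # subsequent layer in the same pass reuses the same buffers.
        key = tuple(gen_kv_lens)
        if key != self._cached_kv_lens:
            self.tile_scheduler_metadata, self.num_splits = (
                compute_mla_metadata(gen_kv_lens))
            self._cached_kv_lens = key


meta = AttentionMetadata()
meta.prepare_flash_mla([256, 512])      # first layer: computes
first = meta.tile_scheduler_metadata
meta.prepare_flash_mla([256, 512])      # later layers: cache hit
assert meta.tile_scheduler_metadata is first
meta.prepare_flash_mla([256, 640])      # KV lengths changed: recompute
```

The key design point is that the cache key is derived from the compacted generation sub-batch, so mixed prefill + decode batches cannot reuse metadata computed for the wrong batch view.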

Test Coverage

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • Update tava architecture diagram if there is a significant design change in the PR.

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

To see a list of available CI bot commands, please comment /bot help.

Summary by CodeRabbit

  • New Features
    • Introduced attention metadata caching to minimize redundant computations across layers within a single forward pass.
    • Added a new function to invalidate the metadata cache at forward-pass and speculative-decoding boundaries.

@bobboli bobboli requested a review from a team as a code owner March 12, 2026 11:30
@bobboli bobboli requested a review from yuxianq March 12, 2026 11:30

bobboli commented Mar 12, 2026

/bot run --disable-fail-fast

@bobboli bobboli changed the title from "[None][perf] Cache FlashMLA tile-scheduler metadata across attention ;aers/" to "[None][perf] Cache FlashMLA tile-scheduler metadata across attention layers" on Mar 12, 2026
@tensorrt-cicd (Collaborator)

PR_Github #38723 [ run ] triggered by Bot. Commit: e1e0ac1 Link to invocation


coderabbitai bot commented Mar 12, 2026

📝 Walkthrough

Walkthrough

This PR introduces a thread-local cache for MLA (Multi-Head Latent Attention) metadata to avoid redundant recomputation across layers within a single forward pass. A new incrementFlashMlaMetadataStep() function invalidates the cache; it is exposed through C++ headers and Python bindings and is called during MLA preparation and speculative decoding updates.

Changes

Cohort / File(s) Summary
MLA Metadata Caching Implementation
cpp/tensorrt_llm/common/attentionOp.cpp
Introduces thread-local FlashMlaMetadataCache with per-thread step identifier; reworks AttentionOp::mlaGeneration to check cache validity and reuse cached metadata (tileSchedulerMetadata, numSplits) across layers, invalidating when step id or batch parameters change.
C++ API and Bindings
cpp/tensorrt_llm/thop/attentionOp.h, cpp/tensorrt_llm/nanobind/thop/bindings.cpp
Declares new public function incrementFlashMlaMetadataStep() in TRTLLM namespace; adds corresponding Nanobind binding to expose it to Python with documentation.
Python Integration
tensorrt_llm/_torch/attention_backend/trtllm.py
Adds update_for_spec_dec() method to TrtllmAttentionWrapper and calls incrementFlashMlaMetadataStep() in two locations (prepare_flash_mla and new method) to invalidate cache during MLA preparation and speculative decoding transitions.

Sequence Diagram

sequenceDiagram
    participant Python as Python Code
    participant Binding as C++ Binding Layer
    participant Impl as AttentionOp Implementation
    participant Cache as Thread-Local Cache
    participant MLA as MLA Metadata Generator

    Python->>Binding: increment_flash_mla_metadata_step()
    Binding->>Impl: incrementFlashMlaMetadataStep()
    Impl->>Cache: Increment step counter
    
    Python->>Impl: mlaGeneration (Layer 1)
    Impl->>Cache: Check cache validity
    alt Cache Invalid or Step Changed
        Cache-->>Impl: Invalid
        Impl->>MLA: Compute metadata
        MLA-->>Impl: tileSchedulerMetadata, numSplits
        Impl->>Cache: Store derived pointers & step id
    else Cache Valid
        Cache-->>Impl: Valid cached pointers
    end
    
    Python->>Impl: mlaGeneration (Layer 2)
    Impl->>Cache: Check cache validity
    Cache-->>Impl: Still valid (same step)
    Impl-->>Python: Reuse cached metadata (skip MLA compute)

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~22 minutes

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage — ⚠️ Warning: docstring coverage is 18.18%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them.
✅ Passed checks (2 passed)
  • Title check — ✅ Passed: the title clearly summarizes the main change (caching FlashMLA tile-scheduler metadata across attention layers) and accurately reflects the core optimization in the changeset.
  • Description check — ✅ Passed: the PR description covers the issue and solution, and provides detailed context about why the change was needed and what was changed.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


Comment @coderabbitai help to get the list of available commands and usage tips.


@coderabbitai coderabbitai bot (Contributor) left a comment


Actionable comments posted: 1

🧹 Nitpick comments (1)
cpp/tensorrt_llm/thop/attentionOp.h (1)

78-84: Use //! docs for this new header API.

/** ... */ on a new function prototype goes against the repo's C++ comment rules. Please convert this block to //! lines so the header stays Doxygen-compatible.

As per coding guidelines, "C++ comments must be in C++ style (//)" and "Use //! for documenting new C++ class interfaces and function prototypes".

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@cpp/tensorrt_llm/thop/attentionOp.h` around lines 78 - 84, Replace the
C-style block comment above the header API (the block describing "Invalidate the
per-step FlashMLA tile-scheduler metadata cache" and mentioning get_mla_metadata
and MTP sub-steps) with C++-style Doxygen lines: change the /** ... */ to
multiple leading //! lines preserving the exact text and formatting so the
declaration in attentionOp.h remains Doxygen-compatible and follows the repo
rule of using // comments for new function prototypes and class interfaces.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tensorrt_llm/_torch/attention_backend/trtllm.py`:
- Around line 1038-1043: The invalidation in update_for_spec_dec is gated by the
Python property self.enable_flash_mla but the C++ path uses a different
predicate mUseGenFlashMLA, which can be true when enable_flash_mla is false and
cause stale metadata; fix by gating the Python invalidation on the actual C++
runtime predicate (expose mUseGenFlashMLA to Python and replace the check around
thop.increment_flash_mla_metadata_step with that predicate) or alternatively
ensure a call path that queries the C++ predicate before calling
thop.increment_flash_mla_metadata_step (also apply the same change for the
analogous check near the other spec-dec location where
thop.increment_flash_mla_metadata_step is invoked).

---

Nitpick comments:
In `@cpp/tensorrt_llm/thop/attentionOp.h`:
- Around line 78-84: Replace the C-style block comment above the header API (the
block describing "Invalidate the per-step FlashMLA tile-scheduler metadata
cache" and mentioning get_mla_metadata and MTP sub-steps) with C++-style Doxygen
lines: change the /** ... */ to multiple leading //! lines preserving the exact
text and formatting so the declaration in attentionOp.h remains
Doxygen-compatible and follows the repo rule of using // comments for new
function prototypes and class interfaces.
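The predicate mismatch flagged in the inline comment above can be sketched as follows. All names here are hypothetical stand-ins: enable_flash_mla for the Python property, use_gen_flash_mla for the C++ mUseGenFlashMLA flag, and Wrapper for TrtllmAttentionWrapper.

```python
# Sketch of gating invalidation on the C++ runtime predicate rather
# than the Python-side flag; names are illustrative assumptions.
class Wrapper:
    def __init__(self, enable_flash_mla: bool, use_gen_flash_mla: bool):
        self.enable_flash_mla = enable_flash_mla
        # Mirror of the C++ runtime predicate; in practice this would
        # be queried from the C++ op rather than stored separately.
        self._use_gen_flash_mla = use_gen_flash_mla
        self.invalidations = 0

    def update_for_spec_dec(self) -> None:
        # Gate on the predicate the C++ path actually uses, so the
        # cache is invalidated whenever the generation path runs
        # FlashMLA, even if the Python-side flag is false.
        if self._use_gen_flash_mla:
            # Stands in for thop.increment_flash_mla_metadata_step().
            self.invalidations += 1


w = Wrapper(enable_flash_mla=False, use_gen_flash_mla=True)
w.update_for_spec_dec()
assert w.invalidations == 1   # would be 0 if gated on enable_flash_mla
```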

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: aa738a74-f8a5-41f6-b3f6-1aeaf5b72af9

📥 Commits

Reviewing files that changed from the base of the PR and between adfc542 and e1e0ac1.

📒 Files selected for processing (4)
  • cpp/tensorrt_llm/common/attentionOp.cpp
  • cpp/tensorrt_llm/nanobind/thop/bindings.cpp
  • cpp/tensorrt_llm/thop/attentionOp.h
  • tensorrt_llm/_torch/attention_backend/trtllm.py

@tensorrt-cicd

PR_Github #38723 [ run ] completed with state SUCCESS. Commit: e1e0ac1
/LLM/main/L0_MergeRequest_PR pipeline #30041 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation


bobboli commented Mar 17, 2026

/bot run --disable-fail-fast

@tensorrt-cicd

PR_Github #39286 [ run ] triggered by Bot. Commit: 2ca0cf4 Link to invocation

@tensorrt-cicd

PR_Github #39286 [ run ] completed with state SUCCESS. Commit: 2ca0cf4
/LLM/main/L0_MergeRequest_PR pipeline #30535 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation


bobboli commented Mar 18, 2026

/bot run --disable-fail-fast --reuse-test

@tensorrt-cicd

PR_Github #39377 [ run ] triggered by Bot. Commit: 2ca0cf4 Link to invocation

@tensorrt-cicd

PR_Github #39377 [ run ] completed with state SUCCESS. Commit: 2ca0cf4
/LLM/main/L0_MergeRequest_PR pipeline #30618 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@bobboli bobboli force-pushed the flashmla_cache_metadata branch from 2f3ccc5 to ada60c8 Compare March 18, 2026 15:07
…y when necessary

Signed-off-by: Bo Li <22713281+bobboli@users.noreply.github.com>
@bobboli bobboli force-pushed the flashmla_cache_metadata branch from ada60c8 to 594a1f0 Compare March 18, 2026 17:08
@bobboli bobboli requested a review from a team as a code owner March 18, 2026 17:08
@bobboli bobboli requested a review from sunnyqgg March 18, 2026 17:08

bobboli commented Mar 18, 2026

/bot run --disable-fail-fast

@tensorrt-cicd

PR_Github #39490 [ run ] triggered by Bot. Commit: 594a1f0 Link to invocation

@tensorrt-cicd

PR_Github #39490 [ run ] completed with state FAILURE. Commit: 594a1f0
/LLM/main/L0_MergeRequest_PR pipeline #30711 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@bobboli bobboli force-pushed the flashmla_cache_metadata branch from 594a1f0 to 8c8d584 Compare March 19, 2026 04:01
@tensorrt-cicd

PR_Github #39605 [ kill ] triggered by Bot. Commit: 507e8ad Link to invocation

@tensorrt-cicd

PR_Github #39605 [ kill ] completed with state SUCCESS. Commit: 507e8ad
Successfully killed previous jobs for commit 507e8ad

Link to invocation


bobboli commented Mar 19, 2026

/bot run --disable-fail-fast

@tensorrt-cicd

PR_Github #39606 [ run ] triggered by Bot. Commit: 507e8ad Link to invocation

@tensorrt-cicd

PR_Github #39606 [ run ] completed with state SUCCESS. Commit: 507e8ad
/LLM/main/L0_MergeRequest_PR pipeline #30814 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation


bobboli commented Mar 20, 2026

/bot run --disable-fail-fast --reuse-test

@tensorrt-cicd

PR_Github #39675 [ run ] triggered by Bot. Commit: 507e8ad Link to invocation

@tensorrt-cicd

PR_Github #39675 [ run ] completed with state SUCCESS. Commit: 507e8ad
/LLM/main/L0_MergeRequest_PR pipeline #30876 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation


bobboli commented Mar 20, 2026

/bot run --disable-fail-fast --reuse-test

@tensorrt-cicd

PR_Github #39694 [ run ] triggered by Bot. Commit: 507e8ad Link to invocation

@tensorrt-cicd

PR_Github #39694 [ run ] completed with state SUCCESS. Commit: 507e8ad
/LLM/main/L0_MergeRequest_PR pipeline #30892 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@lfr-0531 lfr-0531 (Collaborator) left a comment


A few minor comments. The rest LGTM~


bobboli commented Mar 20, 2026

/bot run --disable-fail-fast --reuse-test

@tensorrt-cicd

PR_Github #39706 [ run ] triggered by Bot. Commit: 507e8ad Link to invocation

@tensorrt-cicd

PR_Github #39706 [ run ] completed with state FAILURE. Commit: 507e8ad
/LLM/main/L0_MergeRequest_PR pipeline #30903 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

bobboli added 3 commits March 20, 2026 08:25
- Fix cudaDeviceGetAttribute to query the current device instead of
  hardcoded device 0 in getFlashMlaNumSmPartsStatic
- Add TORCH_CHECK that flash_mla_num_splits is provided whenever
  flash_mla_tile_scheduler_metadata is set

Signed-off-by: Bo Li <22713281+bobboli@users.noreply.github.com>
Signed-off-by: Bo Li <22713281+bobboli@users.noreply.github.com>
Signed-off-by: Bo Li <22713281+bobboli@users.noreply.github.com>

bobboli commented Mar 20, 2026

/bot run --disable-fail-fast --reuse-test

@tensorrt-cicd

PR_Github #39721 [ run ] triggered by Bot. Commit: cc61c80 Link to invocation

@tensorrt-cicd

PR_Github #39721 [ run ] completed with state SUCCESS. Commit: cc61c80
/LLM/main/L0_MergeRequest_PR pipeline #30918 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation


bobboli commented Mar 20, 2026

/bot run --disable-fail-fast --reuse-test

@tensorrt-cicd

PR_Github #39750 [ run ] triggered by Bot. Commit: cc61c80 Link to invocation

@tensorrt-cicd

PR_Github #39750 [ run ] completed with state FAILURE. Commit: cc61c80
/LLM/main/L0_MergeRequest_PR pipeline #30945 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation


bobboli commented Mar 20, 2026

/bot run --disable-fail-fast --reuse-test

@tensorrt-cicd

PR_Github #39753 [ run ] triggered by Bot. Commit: cc61c80 Link to invocation

@tensorrt-cicd

PR_Github #39753 [ run ] completed with state SUCCESS. Commit: cc61c80
/LLM/main/L0_MergeRequest_PR pipeline #30948 completed with status: 'SUCCESS'

CI Report

Link to invocation

@sunnyqgg sunnyqgg (Collaborator) left a comment


LGTM

@bobboli bobboli merged commit aac66bd into NVIDIA:main Mar 21, 2026
4 of 5 checks passed