[https://nvbugs/5983390][perf] Split MLA DSA custom op for piecewise CUDA graph capture#12503
Conversation
Force-pushed b02a780 to 8112487 (Compare)
📝 Walkthrough

The changes introduce DSA (Dynamic Sparse Attention) support for MLA by adding two new TRTLLM custom ops (`mla_dsa_proj` and `mla_dsa_attn_inplace`).
Sequence Diagram

```mermaid
sequenceDiagram
    participant Client as Forward Call
    participant Proj as mla_dsa_proj<br/>(Custom Op)
    participant Proj_Exec as Projection/Quantization<br/>Execution
    participant Attn as mla_dsa_attn_inplace<br/>(Custom Op)
    participant Indexer as Indexer<br/>(pre_indexer)
    participant Cache as K-Cache Update
    participant Attention as Sparse Attention<br/>Computation
    Client->>Proj: hidden_states, position_ids, layer_idx
    Proj->>Proj_Exec: forward_dsa_proj (graph-capturable)
    Proj_Exec-->>Proj: q_fp8, k_fp8, k_scale, weights
    Proj-->>Client: projection outputs
    Client->>Attn: q, compressed_kv, indexer_intermediates, output
    Attn->>Indexer: pre_indexer with intermediates
    Indexer->>Cache: _update_k_cache
    Cache-->>Indexer: cache updated
    Indexer-->>Attn: indexer state prepared
    Attn->>Attention: sparse_attn_indexer
    Attention-->>Attn: attention result
    Attn->>Attn: mutate output in-place
    Attn-->>Client: None (in-place mutation)
```
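The split in the diagram follows a common pattern for piecewise CUDA graph capture: a graph-capturable op that only returns fixed-shape tensors, and a follow-up op that mutates a preallocated output in place and returns None. A minimal sketch of that calling convention (plain functions standing in for the real `mla_dsa_proj`/`mla_dsa_attn_inplace` ops; shapes and the stand-in math are illustrative, not the actual kernels):

```python
import torch


def mla_dsa_proj_sketch(hidden_states: torch.Tensor) -> list[torch.Tensor]:
    # Op 1 (graph-capturable): token-wise work only -- fixed output shapes,
    # no batch-metadata access, no slicing by num_tokens.
    q = hidden_states * 2.0          # stand-in for projection + rope + FP8 quantize
    k = hidden_states + 1.0          # stand-in for the K projection
    k_scale = hidden_states.new_ones(hidden_states.shape[0], 1)
    weights = hidden_states.new_ones(hidden_states.shape[0], 1)
    # Intermediates are returned explicitly (not stashed on self), mirroring
    # the PR's "no stashing" rule that avoids CUDA graph address aliasing.
    return [q, k, k_scale, weights]


def mla_dsa_attn_sketch(q: torch.Tensor, intermediates: list[torch.Tensor],
                        output: torch.Tensor) -> None:
    # Op 2 (outside the graph): batch-dependent dispatch; mutates `output`
    # in place and returns None, matching the in-place op contract.
    k = intermediates[1]
    output.copy_(q + k)  # stand-in for the sparse attention computation
    return None


hidden = torch.randn(8, 16)
out = torch.empty(8, 16)
q, k, k_scale, w = mla_dsa_proj_sketch(hidden)
assert mla_dsa_attn_sketch(q, [q, k, k_scale, w], out) is None
assert torch.allclose(out, hidden * 2.0 + hidden + 1.0)
```

Because Op 1 is a pure function of its inputs with fixed output shapes, the piecewise optimizer can replay it from a captured graph; Op 2's in-place write keeps the output address stable across replays.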
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~65 minutes

🚥 Pre-merge checks: ❌ 1 failed check (1 warning) | ✅ 2 passed checks
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
tensorrt_llm/_torch/modules/attention.py (2)
**Lines 1682-1737:** ⚠️ Potential issue | 🟠 Major — Handle MLA-lite explicitly in `forward_dsa_proj()`.

This refactor always takes the non-lite path (`split([self.q_lora_rank, ...])` + `self.q_a_layernorm`), but `MLA.__init__()` still allows `self.is_lite` and never creates `self.q_a_layernorm` there. A DSA layer with `q_lora_rank=None` will now fail here instead of taking the lite projection path or rejecting the config up front.

🛠️ Minimal guard if DSA-lite is intentionally unsupported:

```diff
 def forward_dsa_proj(
     self,
     position_ids: Optional[torch.Tensor],
     hidden_states: torch.Tensor,
     attn_metadata: AttentionMetadata,
 ) -> List[torch.Tensor]:
     """Token-wise projections for DSA MLA (CUDA-graph-capturable Op 1)."""
     assert self.mqa is not None, "DSA is only supported in MQA mode"
+    if self.is_lite:
+        raise NotImplementedError(
+            "DSA MLA does not support q_lora_rank=None yet"
+        )
```

If DSA-lite should be supported, restore the existing `self.is_lite` branch from `forward_impl()` instead.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/modules/attention.py` around lines 1682 - 1737, forward_dsa_proj unconditionally assumes MLA full mode (calls kv_a_proj_with_mqa().split(...) and self.q_a_layernorm) but MLA-lite may set self.is_lite / omit self.q_a_layernorm, causing crashes; update forward_dsa_proj to explicitly handle MLA-lite by checking self.is_lite (or existence of self.q_a_layernorm) and either (A) take the lite projection branch used in forward_impl (use the lite split/projection path and skip q_a_layernorm), or (B) raise a clear configuration error if DSA-lite is unsupported; ensure you still call self.mqa.indexer.pre_indexer_proj(...) and preserve the existing short-MHA check via _should_use_short_mha and the return shapes.
**Lines 1739-1839:** ⚠️ Potential issue | 🟠 Major — Reapply `llama_4_scaling` before the DSA dispatch.

`forward_impl()` still scales Q in both context and generation branches when `self.llama_4_scaling` is enabled, but the new `forward_dsa_attn()` path never does. That changes MLA math for any DSA model carrying a Llama-4-style config.

🛠️ Suggested fix:

```diff
 if num_contexts > 0:
     q_ctx = q[:num_ctx_tokens, ...]
     compressed_kv_ctx = compressed_kv[:num_ctx_tokens, ...]
     k_pe_ctx = k_pe[:num_ctx_tokens, ...]
     latent_cache_ctx = latent_cache[:num_ctx_tokens, ...]
     if self.apply_rotary_emb:
         assert position_ids is not None
         k_pe_ctx = self.apply_rope(q_ctx, k_pe_ctx, position_ids)
+    if self.llama_4_scaling:
+        q_ctx = self._attention_scaling(
+            q_ctx, position_ids[..., :num_ctx_tokens]
+        )
     self.forward_context_dsa(
         q_ctx,
         compressed_kv_ctx,
         k_pe_ctx,
@@
 if num_generations > 0:
     q_gen = q[num_ctx_tokens:, ...]
     compressed_kv_gen = compressed_kv[num_ctx_tokens:, ...]
     k_pe_gen = k_pe[num_ctx_tokens:, ...]
     latent_cache_gen = latent_cache[num_ctx_tokens:, ...]
     if self.apply_rotary_emb:
         assert position_ids is not None
         k_pe_gen = self.apply_rope(q_gen, k_pe_gen, position_ids)
+    if self.llama_4_scaling:
+        q_gen = self._attention_scaling(
+            q_gen, position_ids[..., num_ctx_tokens:num_tokens]
+        )
     self.forward_generation_dsa(
         q_gen,
         compressed_kv_gen,
         k_pe_gen,
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/modules/attention.py` around lines 1739 - 1839, forward_dsa_attn() never reapplies the Llama-4 scale to Q (llama_4_scaling) before dispatching to DSA paths, changing MLA math versus forward_impl(); fix by applying the same scaling used in forward_impl() to the sliced query tensors before DSA calls: when self.llama_4_scaling is true, multiply q_ctx and q_gen (the local q slices) by the same scale factor or call the existing scaling helper used in forward_impl() immediately after slicing (and before apply_rope/forward_context_dsa/forward_generation_dsa) so both forward_context_dsa and forward_generation_dsa receive properly scaled Qs.
🧹 Nitpick comments (1)
tensorrt_llm/_torch/attention_backend/sparse/dsa.py (1)
**Lines 85-88:** Keep the block-index validation outside capture.

Commenting this out entirely drops the fast-fail check from the CPU/eager path too. Guard it on device/capture state instead of deleting it, so bad block tables still fail with a clear error before the scatter path sees them.

♻️ Possible tweak:

```diff
-# if not torch.cuda.is_current_stream_capturing():
-#     max_blocks = block_offsets.shape[1]
-#     assert (block_indices_in_seq < max_blocks).all(), \
-#         f"Block index out of bounds: max={max_blocks}, got indices up to {block_indices_in_seq.max().item()}"
+if block_offsets.device.type != "cuda" or not torch.cuda.is_current_stream_capturing():
+    max_blocks = block_offsets.shape[1]
+    assert (block_indices_in_seq < max_blocks).all(), (
+        f"Block index out of bounds: max={max_blocks}, "
+        f"got indices up to {block_indices_in_seq.max().item()}"
+    )
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/attention_backend/sparse/dsa.py` around lines 85 - 88, The removed bounds-check allows invalid block tables to slip through during CPU/eager execution; restore the validation but skip it only when inside a CUDA capture stream. Reintroduce the assert using the existing symbols (block_offsets, block_indices_in_seq) and guard it with a condition using torch.cuda.is_current_stream_capturing() so the check runs in non-capture (CPU/eager) paths and is skipped during capture before the scatter path is executed.
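The guarded check above can be exercised standalone. A small sketch (hypothetical helper name; CPU tensors) showing how the device-type test short-circuits so the assert stays active on eager/CPU paths, while `torch.cuda.is_current_stream_capturing()` would disable it during CUDA graph capture:

```python
import torch


def check_block_indices(block_offsets: torch.Tensor,
                        block_indices_in_seq: torch.Tensor) -> None:
    # Skip validation only while a CUDA graph is being captured; the
    # device-type check short-circuits first, so CPU tensors never reach
    # the CUDA capture query at all.
    if block_offsets.device.type != "cuda" or not torch.cuda.is_current_stream_capturing():
        max_blocks = block_offsets.shape[1]
        assert (block_indices_in_seq < max_blocks).all(), (
            f"Block index out of bounds: max={max_blocks}, "
            f"got indices up to {block_indices_in_seq.max().item()}")


offsets = torch.zeros(2, 4, dtype=torch.int64)       # 4 blocks per sequence
check_block_indices(offsets, torch.tensor([0, 3]))   # in bounds: passes
try:
    check_block_indices(offsets, torch.tensor([0, 4]))  # 4 >= max_blocks
except AssertionError as e:
    print("caught:", e)
```

During capture, data-dependent host-side checks like `.all()` plus an `assert` would force a device-to-host sync, which is exactly what graph capture forbids; gating on capture state keeps the check everywhere it is legal.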
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Outside diff comments:
In `@tensorrt_llm/_torch/modules/attention.py`:
- Around line 1682-1737: forward_dsa_proj unconditionally assumes MLA full mode
(calls kv_a_proj_with_mqa().split(...) and self.q_a_layernorm) but MLA-lite may
set self.is_lite / omit self.q_a_layernorm, causing crashes; update
forward_dsa_proj to explicitly handle MLA-lite by checking self.is_lite (or
existence of self.q_a_layernorm) and either (A) take the lite projection branch
used in forward_impl (use the lite split/projection path and skip
q_a_layernorm), or (B) raise a clear configuration error if DSA-lite is
unsupported; ensure you still call self.mqa.indexer.pre_indexer_proj(...) and
preserve the existing short-MHA check via _should_use_short_mha and the return
shapes.
- Around line 1739-1839: forward_dsa_attn() never reapplies the Llama-4 scale to
Q (llama_4_scaling) before dispatching to DSA paths, changing MLA math versus
forward_impl(); fix by applying the same scaling used in forward_impl() to the
sliced query tensors before DSA calls: when self.llama_4_scaling is true,
multiply q_ctx and q_gen (the local q slices) by the same scale factor or call
the existing scaling helper used in forward_impl() immediately after slicing
(and before apply_rope/forward_context_dsa/forward_generation_dsa) so both
forward_context_dsa and forward_generation_dsa receive properly scaled Qs.
---
Nitpick comments:
In `@tensorrt_llm/_torch/attention_backend/sparse/dsa.py`:
- Around line 85-88: The removed bounds-check allows invalid block tables to
slip through during CPU/eager execution; restore the validation but skip it only
when inside a CUDA capture stream. Reintroduce the assert using the existing
symbols (block_offsets, block_indices_in_seq) and guard it with a condition
using torch.cuda.is_current_stream_capturing() so the check runs in non-capture
(CPU/eager) paths and is skipped during capture before the scatter path is
executed.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: a8d030e1-3fb1-4d91-b965-193d19f64ef9
📒 Files selected for processing (4)
- tensorrt_llm/_torch/attention_backend/sparse/dsa.py
- tensorrt_llm/_torch/compilation/piecewise_optimizer.py
- tensorrt_llm/_torch/compilation/utils.py
- tensorrt_llm/_torch/modules/attention.py
/bot run

PR_Github #40315 [ run ] triggered by Bot. Commit:

PR_Github #40315 [ run ] completed with state

/bot run —disable-fail-fast

PR_Github #40392 Bot args parsing error: usage: /bot [-h]

/bot run --disable-fail-fast

PR_Github #40614 Bot args parsing error: usage: /bot [-h]

PR_Github #40615 [ run ] triggered by Bot. Commit:

PR_Github #40615 [ run ] completed with state

/bot run --disable-fail-fast

PR_Github #40686 [ run ] triggered by Bot. Commit:

/bot help
GitHub Bot Help

Provide a user-friendly way for developers to interact with a Jenkins server. See details below for each supported subcommand.

run
Launch build/test pipelines. All previously running jobs will be killed.

kill
Kill all running builds associated with pull request.

skip
Skip testing for latest commit on pull request.

reuse-pipeline
Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.
PR_Github #40686 [ run ] completed with state

/bot run --disable-fail-fast --add-multi-gpu-test

PR_Github #40725 [ run ] triggered by Bot. Commit:

PR_Github #40725 [ run ] completed with state

/bot run --disable-fail-fast --add-multi-gpu-test

PR_Github #40824 [ run ] triggered by Bot. Commit:
Split the monolithic mla_custom_op_inplace into two ops for DSA models:

- mla_dsa_proj (Op 1): Token-wise projections (cublas_mm, rope, FP8 quantize, weight scaling). CUDA-graph-capturable: no batch metadata access, no tensor slicing by num_tokens.
- mla_dsa_attn_inplace (Op 2): Batch-dependent k cache update, sparse_attn_indexer, and context/generation attention dispatch. Excluded from CUDA graph capture.

This enables the piecewise CUDA graph optimizer to capture the compute-heavy projection portion of DSA MLA while keeping the batch-structure-dependent attention dispatch outside the graph.

Key design decisions:

- Indexer split into pre_indexer_proj (graph-safe) and _update_k_cache (moved to Op 2) to avoid capturing metadata-dependent scatter ops.
- All num_tokens slicing deferred to Op 2 so graph capture sees fixed-shape tensors.
- Indexer intermediates (q_fp8, k_fp8, k_scale, weights) returned from Op 1 as List[Tensor] and passed explicitly to Op 2; no stashing on self to avoid CUDA graph address aliasing.
- _should_use_short_mha disabled under torch compile for straight-line control flow in Op 1.
- Non-DSA MLA unchanged (still uses mla_custom_op_inplace).

Signed-off-by: Jin Li <59594262+liji-nv@users.noreply.github.com>
…ocstring

- Remove no-op `lambda: weights` in pre_indexer's maybe_execute_in_parallel; _weight_scale already ran in pre_indexer_proj, so just call _update_k_cache directly.
- Fix mla_dsa_proj docstring: k cache update happens in Op 2 (mla_dsa_attn_inplace), not Op 1.

Signed-off-by: Jin Li <59594262+liji-nv@users.noreply.github.com>
Move _update_k_cache call to the top of sparse_attn_indexer so the k cache is populated right before prefill chunks gather from it. Remove pre_indexer (now redundant); forward() and forward_dsa_proj both call pre_indexer_proj directly.

Signed-off-by: Jin Li <59594262+liji-nv@users.noreply.github.com>
- Remove dead is_dsa branch from mla_custom_op_inplace since DSA is now exclusively handled by the split mla_dsa_proj/mla_dsa_attn_inplace ops
- Use literal 1 for k_scale shape to match C++ fusedCatFp8 kernel output
- Simplify proj_outputs unpacking

Signed-off-by: Jin Li <59594262+liji-nv@users.noreply.github.com>
Force-pushed 3f9e51d to 66627a1 (Compare)
/bot run

The assertion was dropped when the old Indexer.forward was split into pre_indexer_proj and sparse_attn_indexer. Restore it in sparse_attn_indexer, which has access to metadata.kv_cache_manager.

Signed-off-by: Jin Li <59594262+liji-nv@users.noreply.github.com>

/bot run

PR_Github #40869 [ run ] triggered by Bot. Commit:

PR_Github #40869 [ run ] completed with state
Split the monolithic mla_custom_op_inplace into two ops for DSA models:

- mla_dsa_proj (Op 1): Token-wise projections (cublas_mm, rope, FP8 quantize, weight scaling). CUDA-graph-capturable: no batch metadata access, no tensor slicing by num_tokens.
- mla_dsa_attn_inplace (Op 2): Batch-dependent k cache update, sparse_attn_indexer, and context/generation attention dispatch. Excluded from CUDA graph capture.

This enables the piecewise CUDA graph optimizer to capture the compute-heavy projection portion of DSA MLA while keeping the batch-structure-dependent attention dispatch outside the graph.

Key design decisions:

- Indexer split into pre_indexer_proj (graph-safe) and _update_k_cache (moved to Op 2) to avoid capturing metadata-dependent scatter ops.
- All num_tokens slicing deferred to Op 2 so graph capture sees fixed-shape tensors.
- Indexer intermediates (q_fp8, k_fp8, k_scale, weights) returned from Op 1 as List[Tensor] and passed explicitly to Op 2; no stashing on self to avoid CUDA graph address aliasing.
- _should_use_short_mha disabled under torch compile for straight-line control flow in Op 1.
Summary by CodeRabbit

- Performance Optimizations
- Refactor
Description
Test Coverage
PR Checklist
Please review the following before submitting your PR:

- PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
- PR follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
- Test cases are provided for new code paths (see test instructions).
- Any new dependencies have been scanned for license and vulnerabilities.
- CODEOWNERS updated if ownership changes.
- Documentation updated as needed.
- Update tava architecture diagram if there is a significant design change in PR.
- The reviewers assigned automatically/manually are appropriate for the PR.

Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
To see a list of available CI bot commands, please comment `/bot help`.