[https://nvbugs/5983390][perf] Split MLA DSA custom op for piecewise CUDA graph capture#12503

Merged
liji-nv merged 5 commits into NVIDIA:main from liji-nv:dev-liji-dsa-piecewise
Mar 31, 2026

Conversation

@liji-nv
Collaborator

@liji-nv liji-nv commented Mar 24, 2026

Split the monolithic mla_custom_op_inplace into two ops for DSA models:

  • mla_dsa_proj (Op 1): Token-wise projections (cublas_mm, rope, FP8
    quantize, weight scaling). CUDA-graph-capturable — no batch metadata
    access, no tensor slicing by num_tokens.
  • mla_dsa_attn_inplace (Op 2): Batch-dependent k cache update,
    sparse_attn_indexer, and context/generation attention dispatch.
    Excluded from CUDA graph capture.

This enables the piecewise CUDA graph optimizer to capture the
compute-heavy projection portion of DSA MLA while keeping the
batch-structure-dependent attention dispatch outside the graph.

Key design decisions:

  • Indexer split into pre_indexer_proj (graph-safe) and _update_k_cache
    (moved to Op 2) to avoid capturing metadata-dependent scatter ops.
  • All num_tokens slicing deferred to Op 2 so graph capture sees
    fixed-shape tensors.
  • Indexer intermediates (q_fp8, k_fp8, k_scale, weights) returned from
    Op 1 as List[Tensor] and passed explicitly to Op 2 — no stashing on
    self to avoid CUDA graph address aliasing.
  • _should_use_short_mha disabled under torch compile for straight-line
    control flow in Op 1.
  • Non-DSA MLA unchanged (still uses mla_custom_op_inplace).
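The two-op split above can be sketched in plain Python. This is an illustrative stand-in only: the function names echo the new ops, but the shapes, arithmetic, and signatures are hypothetical, not the real TensorRT-LLM kernels.

```python
# Hypothetical stand-in for the Op 1 / Op 2 split described above.
# All names, shapes, and math here are illustrative, not the real kernels.

def mla_dsa_proj_sketch(hidden_states):
    """Op 1: token-wise work over the full fixed-shape buffer.

    No batch metadata access and no slicing by num_tokens: every output
    shape depends only on the (padded) buffer size, so a CUDA graph can
    replay it with stable tensor addresses.
    """
    q = [h * 2.0 for h in hidden_states]      # stand-in for cublas_mm + rope
    q_fp8 = [round(x, 1) for x in q]          # stand-in for FP8 quantize
    k_scale = [1.0] * len(hidden_states)      # stand-in for weight scaling
    # Intermediates are RETURNED as a list, never stashed on self, so the
    # graph sees their addresses explicitly (no hidden aliasing).
    return q, [q_fp8, k_scale]

def mla_dsa_attn_inplace_sketch(q, intermediates, num_tokens, output):
    """Op 2: batch-dependent work, excluded from graph capture.

    The num_tokens slicing and cache update happen only here; the result
    is written into `output` in place and the op returns None.
    """
    q_fp8, k_scale = intermediates
    for i in range(num_tokens):               # batch-dependent slice
        output[i] = q[i] * k_scale[i] + q_fp8[i]   # stand-in for attention
    return None
```

Note how Op 1 never looks at num_tokens: every shape it produces depends only on the padded buffer size, which is what makes it safe to capture and replay.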

Summary by CodeRabbit

  • Performance Optimizations

    • Improved Dynamic Sparse Attention (DSA) execution efficiency with streamlined indexing and cache operations.
    • Enhanced CUDA graph capture compatibility to reduce runtime overhead for DSA workloads.
  • Refactor

    • Restructured DSA attention computation pipeline for better compiler integration and modularity.
    • Extended compilation optimizer support for advanced attention operations.

Description

Test Coverage

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • Update tava architecture diagram if there is a significant design change in PR.

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

To see a list of available CI bot commands, please comment /bot help.

@liji-nv liji-nv requested review from a team as code owners March 24, 2026 14:03
@liji-nv liji-nv requested review from hyukn and yuxianq March 24, 2026 14:03
@liji-nv liji-nv force-pushed the dev-liji-dsa-piecewise branch from b02a780 to 8112487 on March 24, 2026 14:04
@coderabbitai
Contributor

coderabbitai bot commented Mar 24, 2026

📝 Walkthrough


The changes introduce DSA (Dynamic Sparse Attention) support for MLA by adding two new TRTLLM custom ops (mla_dsa_proj and mla_dsa_attn_inplace), refactoring the Indexer class to separate projection/quantization from cache updates, and extending compilation utilities to recognize these new ops during graph optimization and piecewise partitioning.

Changes

  • DSA Indexer Refactoring (tensorrt_llm/_torch/attention_backend/sparse/dsa.py): Split Indexer.forward() logic into separate pre_indexer() (handles projection, quantization, weight scaling, and k-cache update) and pre_indexer_proj() (projection/quantization only, without the k-cache update). Removed the CUDA-graph-safety bounds-check assertion and the quantization block-size constraint assertion.
  • DSA Custom Ops (tensorrt_llm/_torch/modules/attention.py): Added two TRTLLM CUDA-graph-oriented custom ops: mla_dsa_proj (token-wise projection/quantization stage) and mla_dsa_attn_inplace (attention stage with in-place output mutation). Refactored the DSA MLA forward path into separate forward_dsa_proj() and forward_dsa_attn() stages. Modified _should_use_short_mha() to always return False during torch compilation to ensure straight-line control flow for CUDA graphs.
  • Compilation Support (tensorrt_llm/_torch/compilation/piecewise_optimizer.py, tensorrt_llm/_torch/compilation/utils.py): Extended the piecewise partitioning logic to recognize mla_dsa_attn_inplace.default as an attention-related custom op alongside the existing attention ops. Added in-place operation metadata for the new op, with argument index 1 mapped to "output".
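As a rough illustration of the compilation-support change, the partitioner bookkeeping might look like the following sketch (hypothetical names and structures, not the actual piecewise_optimizer/utils API):

```python
# Hypothetical sketch of the partitioner metadata extended by this PR.
# Op-name strings mirror the PR description; the tables and helper are
# illustrative, not the real TensorRT-LLM compilation internals.

# Ops at which the piecewise optimizer breaks the graph instead of
# capturing; the new DSA attention op joins the existing set.
ATTENTION_OPS = {
    "trtllm::mla_custom_op_inplace.default",   # pre-existing (illustrative)
    "trtllm::mla_dsa_attn_inplace.default",    # added by this PR
}

# Maps op name -> {argument index: role}. Argument index 1 of the new op
# is the tensor it mutates, so the partitioner must treat that input as
# an output when splitting the graph around the op.
INPLACE_ARG_INFO = {
    "trtllm::mla_dsa_attn_inplace.default": {1: "output"},
}

def is_attention_op(op_name: str) -> bool:
    """Nodes matching this predicate become piecewise graph-break points."""
    return op_name in ATTENTION_OPS
```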

Sequence Diagram

sequenceDiagram
    participant Client as Forward Call
    participant Proj as mla_dsa_proj<br/>(Custom Op)
    participant Proj_Exec as Projection/Quantization<br/>Execution
    participant Attn as mla_dsa_attn_inplace<br/>(Custom Op)
    participant Indexer as Indexer<br/>(pre_indexer)
    participant Cache as K-Cache Update
    participant Attention as Sparse Attention<br/>Computation
    
    Client->>Proj: hidden_states, position_ids, layer_idx
    Proj->>Proj_Exec: forward_dsa_proj (graph-capturable)
    Proj_Exec-->>Proj: q_fp8, k_fp8, k_scale, weights
    Proj-->>Client: projection outputs
    
    Client->>Attn: q, compressed_kv, indexer_intermediates, output
    Attn->>Indexer: pre_indexer with intermediates
    Indexer->>Cache: _update_k_cache
    Cache-->>Indexer: cache updated
    Indexer-->>Attn: indexer state prepared
    Attn->>Attention: sparse_attn_indexer
    Attention-->>Attn: attention result
    Attn->>Attn: mutate output in-place
    Attn-->>Client: None (in-place mutation)

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~65 minutes

🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage ⚠️ Warning: Docstring coverage is 58.82%, which is below the required threshold of 80.00%. Resolution: write docstrings for the functions that are missing them.
✅ Passed checks (2 passed)
  • Title check ✅ Passed: The PR title accurately reflects the main change, splitting the MLA DSA custom op for piecewise CUDA graph capture, which is the core objective described in the commit messages.
  • Description check ✅ Passed: The PR description is mostly complete, with a clear objective, design decisions, and scope, though the title and some checklist items are incomplete.



Contributor

@coderabbitai coderabbitai bot left a comment


Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (2)
tensorrt_llm/_torch/modules/attention.py (2)

1682-1737: ⚠️ Potential issue | 🟠 Major

Handle MLA-lite explicitly in forward_dsa_proj().

This refactor always takes the non-lite path (split([self.q_lora_rank, ...]) + self.q_a_layernorm), but MLA.__init__() still allows self.is_lite and never creates self.q_a_layernorm there. A DSA layer with q_lora_rank=None will now fail here instead of taking the lite projection path or rejecting the config up front.

🛠️ Minimal guard if DSA-lite is intentionally unsupported
     def forward_dsa_proj(
         self,
         position_ids: Optional[torch.Tensor],
         hidden_states: torch.Tensor,
         attn_metadata: AttentionMetadata,
     ) -> List[torch.Tensor]:
         """Token-wise projections for DSA MLA (CUDA-graph-capturable Op 1)."""
         assert self.mqa is not None, "DSA is only supported in MQA mode"
+        if self.is_lite:
+            raise NotImplementedError(
+                "DSA MLA does not support q_lora_rank=None yet"
+            )

If DSA-lite should be supported, restore the existing self.is_lite branch from forward_impl() instead.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/modules/attention.py` around lines 1682 - 1737,
forward_dsa_proj unconditionally assumes MLA full mode (calls
kv_a_proj_with_mqa().split(...) and self.q_a_layernorm) but MLA-lite may set
self.is_lite / omit self.q_a_layernorm, causing crashes; update forward_dsa_proj
to explicitly handle MLA-lite by checking self.is_lite (or existence of
self.q_a_layernorm) and either (A) take the lite projection branch used in
forward_impl (use the lite split/projection path and skip q_a_layernorm), or (B)
raise a clear configuration error if DSA-lite is unsupported; ensure you still
call self.mqa.indexer.pre_indexer_proj(...) and preserve the existing short-MHA
check via _should_use_short_mha and the return shapes.

1739-1839: ⚠️ Potential issue | 🟠 Major

Reapply llama_4_scaling before the DSA dispatch.

forward_impl() still scales Q in both context and generation branches when self.llama_4_scaling is enabled, but the new forward_dsa_attn() path never does. That changes MLA math for any DSA model carrying a Llama-4-style config.

🛠️ Suggested fix
         if num_contexts > 0:
             q_ctx = q[:num_ctx_tokens, ...]
             compressed_kv_ctx = compressed_kv[:num_ctx_tokens, ...]
             k_pe_ctx = k_pe[:num_ctx_tokens, ...]
             latent_cache_ctx = latent_cache[:num_ctx_tokens, ...]
             if self.apply_rotary_emb:
                 assert position_ids is not None
                 k_pe_ctx = self.apply_rope(q_ctx, k_pe_ctx, position_ids)
+            if self.llama_4_scaling:
+                q_ctx = self._attention_scaling(
+                    q_ctx, position_ids[..., :num_ctx_tokens]
+                )
 
             self.forward_context_dsa(
                 q_ctx,
                 compressed_kv_ctx,
                 k_pe_ctx,
@@
         if num_generations > 0:
             q_gen = q[num_ctx_tokens:, ...]
             compressed_kv_gen = compressed_kv[num_ctx_tokens:, ...]
             k_pe_gen = k_pe[num_ctx_tokens:, ...]
             latent_cache_gen = latent_cache[num_ctx_tokens:, ...]
             if self.apply_rotary_emb:
                 assert position_ids is not None
                 k_pe_gen = self.apply_rope(q_gen, k_pe_gen, position_ids)
+            if self.llama_4_scaling:
+                q_gen = self._attention_scaling(
+                    q_gen, position_ids[..., num_ctx_tokens:num_tokens]
+                )
 
             self.forward_generation_dsa(
                 q_gen,
                 compressed_kv_gen,
                 k_pe_gen,
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/modules/attention.py` around lines 1739 - 1839,
forward_dsa_attn() never reapplies the Llama-4 scale to Q (llama_4_scaling)
before dispatching to DSA paths, changing MLA math versus forward_impl(); fix by
applying the same scaling used in forward_impl() to the sliced query tensors
before DSA calls: when self.llama_4_scaling is true, multiply q_ctx and q_gen
(the local q slices) by the same scale factor or call the existing scaling
helper used in forward_impl() immediately after slicing (and before
apply_rope/forward_context_dsa/forward_generation_dsa) so both
forward_context_dsa and forward_generation_dsa receive properly scaled Qs.
🧹 Nitpick comments (1)
tensorrt_llm/_torch/attention_backend/sparse/dsa.py (1)

85-88: Keep the block-index validation outside capture.

Commenting this out entirely drops the fast-fail check from the CPU/eager path too. Guard it on device/capture state instead of deleting it, so bad block tables still fail with a clear error before the scatter path sees them.

♻️ Possible tweak
-    # if not torch.cuda.is_current_stream_capturing():
-    #     max_blocks = block_offsets.shape[1]
-    #     assert (block_indices_in_seq < max_blocks).all(), \
-    #         f"Block index out of bounds: max={max_blocks}, got indices up to {block_indices_in_seq.max().item()}"
+    if block_offsets.device.type != "cuda" or not torch.cuda.is_current_stream_capturing():
+        max_blocks = block_offsets.shape[1]
+        assert (block_indices_in_seq < max_blocks).all(), (
+            f"Block index out of bounds: max={max_blocks}, "
+            f"got indices up to {block_indices_in_seq.max().item()}"
+        )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/attention_backend/sparse/dsa.py` around lines 85 - 88,
The removed bounds-check allows invalid block tables to slip through during
CPU/eager execution; restore the validation but skip it only when inside a CUDA
capture stream. Reintroduce the assert using the existing symbols
(block_offsets, block_indices_in_seq) and guard it with a condition using
torch.cuda.is_current_stream_capturing() so the check runs in non-capture
(CPU/eager) paths and is skipped during capture before the scatter path is
executed.
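The suggested guard pattern can be sketched without a GPU by stubbing the capture check (hypothetical helper names; the real code would call torch.cuda.is_current_stream_capturing()):

```python
# Sketch of the reviewer's suggested pattern: run cheap validation in
# eager mode, but skip it during CUDA graph capture, where a
# data-dependent assert would break the capture. `is_capturing` is a
# stub standing in for torch.cuda.is_current_stream_capturing() so this
# sketch runs without a GPU.

def is_capturing() -> bool:
    # Hypothetical stub; always "eager" here.
    return False

def validate_block_indices(block_indices, max_blocks):
    """Fast-fail on out-of-range block indices, in eager mode only."""
    if not is_capturing():
        bad = [i for i in block_indices if i >= max_blocks]
        assert not bad, (
            f"Block index out of bounds: max={max_blocks}, "
            f"got indices up to {max(bad)}")
```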


📥 Commits

Reviewing files that changed from the base of the PR and between b636dcc and 8112487.

📒 Files selected for processing (4)
  • tensorrt_llm/_torch/attention_backend/sparse/dsa.py
  • tensorrt_llm/_torch/compilation/piecewise_optimizer.py
  • tensorrt_llm/_torch/compilation/utils.py
  • tensorrt_llm/_torch/modules/attention.py

@liji-nv liji-nv changed the title from "Dev liji dsa piecewise" to "[https://nvbugs/5983390][perf] Split MLA DSA custom op for piecewise CUDA graph capture" on Mar 24, 2026
@liji-nv
Collaborator Author

liji-nv commented Mar 25, 2026

/bot run

@tensorrt-cicd
Collaborator

PR_Github #40315 [ run ] triggered by Bot. Commit: 55980ee Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #40315 [ run ] completed with state SUCCESS. Commit: 55980ee
/LLM/main/L0_MergeRequest_PR pipeline #31425 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@liji-nv
Collaborator Author

liji-nv commented Mar 26, 2026

/bot run —disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #40392 Bot args parsing error: usage: /bot [-h]
{run,kill,skip,submit,reviewers,reuse-pipeline,reuse-review} ...
/bot: error: unrecognized arguments: —disable-fail-fast

Link to invocation

@longlee0622
Collaborator

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #40614 Bot args parsing error: usage: /bot [-h]
{run,kill,skip,submit,reviewers,reuse-pipeline,reuse-review} ...
/bot: error: unrecognized arguments: -disable-fail-fast

Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #40615 [ run ] triggered by Bot. Commit: 3f9e51d Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #40615 [ run ] completed with state SUCCESS. Commit: 3f9e51d
/LLM/main/L0_MergeRequest_PR pipeline #31655 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@liji-nv
Collaborator Author

liji-nv commented Mar 30, 2026

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #40686 [ run ] triggered by Bot. Commit: 3f9e51d Link to invocation

@longlee0622 longlee0622 enabled auto-merge (squash) March 30, 2026 08:34
@liji-nv liji-nv disabled auto-merge March 30, 2026 11:15
@liji-nv
Collaborator Author

liji-nv commented Mar 30, 2026

/bot help

@github-actions

GitHub Bot Help

/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provide a user friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.

Details

run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental) --high-priority]

Launch build/test pipelines. All previously running jobs will be killed.

--reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline or the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will be always ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.

--disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.

--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.

--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.

--stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.

--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.

--test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages which don't match the specified backends. Only support [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.

--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.

--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.

--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.

--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".

--detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.

--debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purpose. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

--high-priority (OPTIONAL) : Run the pipeline with high priority. This option is restricted to authorized users only and will route the job to a high-priority queue.

kill

kill

Kill all running builds associated with pull request.

skip

skip --comment COMMENT

Skip testing for latest commit on pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

@tensorrt-cicd
Collaborator

PR_Github #40686 [ run ] completed with state SUCCESS. Commit: 3f9e51d
/LLM/main/L0_MergeRequest_PR pipeline #31715 completed with status: 'SUCCESS'

CI Report

Link to invocation

@liji-nv
Collaborator Author

liji-nv commented Mar 30, 2026

/bot run --disable-fail-fast --add-multi-gpu-test

@tensorrt-cicd
Collaborator

PR_Github #40725 [ run ] triggered by Bot. Commit: 3f9e51d Link to invocation

@liji-nv liji-nv enabled auto-merge (squash) March 30, 2026 11:54
@tensorrt-cicd
Collaborator

PR_Github #40725 [ run ] completed with state SUCCESS. Commit: 3f9e51d
/LLM/main/L0_MergeRequest_PR pipeline #31749 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@liji-nv
Collaborator Author

liji-nv commented Mar 31, 2026

/bot run --disable-fail-fast --add-multi-gpu-test

@tensorrt-cicd
Collaborator

PR_Github #40824 [ run ] triggered by Bot. Commit: 3f9e51d Link to invocation

liji-nv added 4 commits March 30, 2026 22:41
Split the monolithic mla_custom_op_inplace into two ops for DSA models (commit message matches the PR description above).

Signed-off-by: Jin Li <59594262+liji-nv@users.noreply.github.com>
…ocstring

- Remove no-op `lambda: weights` in pre_indexer's maybe_execute_in_parallel;
  _weight_scale already ran in pre_indexer_proj, so just call _update_k_cache
  directly.
- Fix mla_dsa_proj docstring: k cache update happens in Op 2
  (mla_dsa_attn_inplace), not Op 1.

Signed-off-by: Jin Li <59594262+liji-nv@users.noreply.github.com>
Move _update_k_cache call to the top of sparse_attn_indexer so
the k cache is populated right before prefill chunks gather from it.
Remove pre_indexer (now redundant); forward() and forward_dsa_proj
both call pre_indexer_proj directly.

Signed-off-by: Jin Li <59594262+liji-nv@users.noreply.github.com>
- Remove dead is_dsa branch from mla_custom_op_inplace since DSA is now
  exclusively handled by the split mla_dsa_proj/mla_dsa_attn_inplace ops
- Use literal 1 for k_scale shape to match C++ fusedCatFp8 kernel output
- Simplify proj_outputs unpacking

Signed-off-by: Jin Li <59594262+liji-nv@users.noreply.github.com>
@liji-nv liji-nv force-pushed the dev-liji-dsa-piecewise branch from 3f9e51d to 66627a1 on March 31, 2026 05:41
@liji-nv
Collaborator Author

liji-nv commented Mar 31, 2026

/bot run

The assertion was dropped when the old Indexer.forward was split into
pre_indexer_proj and sparse_attn_indexer. Restore it in
sparse_attn_indexer which has access to metadata.kv_cache_manager.

Signed-off-by: Jin Li <59594262+liji-nv@users.noreply.github.com>
@liji-nv
Collaborator Author

liji-nv commented Mar 31, 2026

/bot run

@tensorrt-cicd
Collaborator

PR_Github #40869 [ run ] triggered by Bot. Commit: c21c41c Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #40869 [ run ] completed with state SUCCESS. Commit: c21c41c
/LLM/main/L0_MergeRequest_PR pipeline #31876 completed with status: 'SUCCESS'

CI Report

Link to invocation

@liji-nv liji-nv merged commit 7e477ba into NVIDIA:main Mar 31, 2026
5 checks passed
