
[#11694][feat] AutoDeploy: Improve the piecewise CG memory usage#11993

Merged
nvchenghaoz merged 15 commits intoNVIDIA:mainfrom
nv-auto-deploy:chenghao/piecewise_memory_0306
Mar 18, 2026

Conversation


@nvchenghaoz nvchenghaoz commented Mar 6, 2026

This PR improves memory usage during piecewise CUDA graph capture.

Several main updates:

  1. Pass the output buffer as a parameter to the attention operators, so the piecewise CUDA graph performs no eager attention-output allocation. A further advantage is that the output tensor can then be managed by the memory pool, which substantially improves memory usage.
  2. Update the piecewise CUDA graph control code for the output-buffer implementation, and wire inputs and outputs between dynamic regions (ops that cannot be CUDA-graph-captured) and static regions (ops that can).
  3. Miscellaneous bug fixes and updates.

With these changes, the memory usage of the piecewise CUDA graph is under control, and the feature can be used in most cases.

Performance also improves slightly with these changes. For Qwen 3.5, ISL 1k / OSL 1k, on H100:
(Screenshot: benchmark results, 2026-03-09.)

The next step for piecewise CG: the final GEMM for prefill is very large, and AutoDeploy should enable the gather_logits transform by default. That region is currently marked dynamic because the GEMM's shape is unknown; a follow-up task will address this. My plan is to pad the final GEMM to max_batch_size, which needs further investigation. #12049 tracks the work.

Summary by CodeRabbit

Release Notes

  • New Features

    • Added optional output buffer support across custom operations for more efficient memory usage.
    • Implemented dynamic submodule classification utilities for improved graph compilation.
    • Added batch capacity-based pruning for piecewise compilation buckets.
  • Bug Fixes

    • Reduced default free GPU memory fraction to prevent OOM errors.
  • Performance Improvements

    • Enabled piecewise compilation by default for faster inference.
    • Implemented three-phase warmup/capture/replay pipeline with weak reference memory recycling.


coderabbitai bot commented Mar 6, 2026

📝 Walkthrough

This PR implements dynamic output buffering and enhanced memory-aware pipelining for piecewise CUDA graph capture. It introduces wrapper classes for metadata and dynamic operations, adds optional output buffer parameters across custom ops, refactors runner semantics into explicit warmup/capture/replay phases, and adds batch capacity-aware bucket filtering to the compilation pipeline.

Changes

| Cohort / File(s) | Summary |
|---|---|
| **Configuration Updates**<br>`examples/auto_deploy/model_registry/qwen3.5_moe_35b.yaml`, `tensorrt_llm/_torch/auto_deploy/config/default.yaml` | Reduced `free_gpu_memory_fraction` from 0.95 to 0.7; enabled piecewise mode in the default compilation config by changing `piecewise_enabled` to true. |
| **Piecewise Runner Core**<br>`tensorrt_llm/_torch/auto_deploy/compile/piecewise_runner.py` | Replaced legacy `ADPiecewiseRunner` semantics with a three-phase flow (warmup, capture, replay). Added an `OutputInfo` dataclass and `MetadataWrapper`/`DynamicOpWrapper` classes. Refactored `SegmentEntry` with `dynamic_out_buf` and `input_addresses` fields. Introduced `set_current_num_tokens`/`set_current_phase` class methods and weak-reference support for output buffers. |
| **Piecewise CUDA Graph Capture**<br>`tensorrt_llm/_torch/auto_deploy/compile/backends/torch_cudagraph.py` | Reworked piecewise capture to support dynamic output buffering and memory-aware logging. Added out-buffer injection for dynamic ops, shape-discovery hooks, and a `capture_lm_head` flag to optionally exclude the LM head partition. Updated `PiecewiseCapturedGraph.__init__` with the `capture_lm_head` parameter. |
| **Piecewise Utilities**<br>`tensorrt_llm/_torch/auto_deploy/compile/piecewise_utils.py` | Added dynamic submodule classification utilities: `submod_has_cuda_ops`, `needs_out_buffer`, `is_metadata_prep`. Introduced `_INPLACE_DYNAMIC_OPS` tracking and refined dynamic-vs-static partition detection using submodule name suffixes. |
| **Attention Custom Ops with Out Buffer**<br>`tensorrt_llm/_torch/auto_deploy/custom_ops/attention/flashinfer_attention.py`, `torch_backend_attention.py`, `triton_attention.py`, `trtllm_attention.py` | Added an optional `out: Optional[torch.Tensor]` parameter to all attention MHA operations (`flashinfer_mha_with_cache`, `torch_backend_mha_with_cache`, `flattened_mha_with_cache`, `trtllm_mha_with_cache`). Updated `custom_op` decorators to `mutates_args=("out",)`. Ops now write to pre-allocated buffers when `out` is provided and return empty tensors. |
| **FLA Custom Ops with Out Buffer**<br>`tensorrt_llm/_torch/auto_deploy/custom_ops/fla/fla_backend_delta.py`, `fla_backend_gated_delta.py`, `torch_backend_gated_delta.py` | Added an optional `out` parameter to `fla_cached_delta_rule`, `fla_cached_gated_delta_rule`, and `torch_cached_gated_delta_rule`. Updated `mutates_args` decorators and implemented in-place writing with padding zeroing when `out` buffers are provided. |
| **Mamba Custom Ops with Out Buffer**<br>`tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/flashinfer_backend_mamba.py`, `torch_backend_mamba.py`, `triton_backend_mamba.py` | Added an optional `out` parameter to the SSM operations (`_flashinfer_cached_ssm`, `_torch_cached_ssm`, `_triton_cached_ssm`). Updated decorators to mutate the `out` argument and modified output handling to reuse pre-allocated buffers with padding management. |
| **MLA Custom Ops with Out Buffer**<br>`tensorrt_llm/_torch/auto_deploy/custom_ops/mla/flashinfer_mla.py`, `torch_backend_mla.py` | Added an optional `out` parameter to `flashinfer_mla_with_cache` and `torch_backend_mla_with_cache`. Updated `mutates_args` and adjusted dtype handling. Ops now support external buffer reuse with conditional casting and early returns when `out` is provided. |
| **Gather Utility Op**<br>`tensorrt_llm/_torch/auto_deploy/custom_ops/utils/torch_gather_logits.py` | Added an optional `out` parameter to `gather_tokens` and `gather_tokens_fake`. Updated the `custom_op` decorator to `mutates_args=("out",)`. Implemented conditional logic to write results into provided buffers and return empty tensors for in-place reuse. |
| **Compilation Pipeline**<br>`tensorrt_llm/_torch/auto_deploy/transform/library/compile_model.py` | Added `batch_capacity` computation from `max_seq_len` and `max_batch_size` when piecewise mode is enabled. Filters `piecewise_num_tokens` buckets to drop those exceeding capacity, with warning messages. |
| **Test Updates**<br>`tests/unittest/_torch/auto_deploy/unit/singlegpu/compile/test_captured_graph.py`, `test_piecewise_runner.py` | Updated test imports to use the public `submod_has_cuda_ops` from `piecewise_utils`. Refactored the `piecewise_runner` tests to cover the new `OutputInfo`, `MetadataWrapper`/`DynamicOpWrapper`, and `set_dynamic_out_info`/`get_dynamic_out_buf` APIs, and adjusted expectations for the new warmup/capture/replay phase semantics. |

Sequence Diagram

```mermaid
sequenceDiagram
    participant Client
    participant PiecewiseCapturedGraph
    participant ADPiecewiseRunner
    participant DynamicOpWrapper
    participant MetadataWrapper
    participant CUDAGraph

    Client->>PiecewiseCapturedGraph: __init__(model, capture_lm_head=False)
    PiecewiseCapturedGraph->>ADPiecewiseRunner: Create runner instances

    Client->>ADPiecewiseRunner: set_current_phase("warmup")
    ADPiecewiseRunner->>DynamicOpWrapper: Execute dynamic ops (eager)
    ADPiecewiseRunner->>MetadataWrapper: Execute metadata ops (clone outputs)
    Note over ADPiecewiseRunner: Discover output shapes/dtypes

    Client->>ADPiecewiseRunner: set_dynamic_out_info(OutputInfo)
    ADPiecewiseRunner->>ADPiecewiseRunner: Store output shape/dtype info

    Client->>ADPiecewiseRunner: set_current_phase("capture")
    ADPiecewiseRunner->>ADPiecewiseRunner: Allocate dynamic output buffers
    ADPiecewiseRunner->>DynamicOpWrapper: Execute with out= parameter
    ADPiecewiseRunner->>CUDAGraph: Capture static segment
    CUDAGraph-->>ADPiecewiseRunner: Captured graph with buffer references

    Client->>ADPiecewiseRunner: set_current_phase("replay")
    ADPiecewiseRunner->>CUDAGraph: Replay with pre-allocated buffers
    ADPiecewiseRunner->>DynamicOpWrapper: Execute dynamic ops (reuse buffers)
    CUDAGraph-->>ADPiecewiseRunner: Results written to buffers
    ADPiecewiseRunner-->>Client: Output via buffer
```
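The phase transitions in the diagram can be condensed into a minimal sketch. This is a hypothetical class, not the real `ADPiecewiseRunner` (which also manages per-segment buffers, input addresses, and weak references): warmup runs eagerly for shape/dtype discovery, capture records a CUDA graph into fixed buffers, and replay copies fresh inputs in and re-launches the graph.

```python
import torch


class PhasedRunner:
    """Hypothetical warmup/capture/replay runner sketch (not the PR's code)."""

    phase = "warmup"  # process-wide phase, akin to set_current_phase

    def __init__(self, fn):
        self.fn = fn
        self.graphs = {}  # num_tokens -> (graph, static_in, static_out)

    @classmethod
    def set_current_phase(cls, phase):
        cls.phase = phase

    def __call__(self, x):
        n = x.shape[0]
        if PhasedRunner.phase == "warmup" or not torch.cuda.is_available():
            return self.fn(x)  # eager run: shapes/dtypes get discovered here
        if PhasedRunner.phase == "capture":
            static_in = x.clone()  # fixed-address input buffer
            g = torch.cuda.CUDAGraph()
            with torch.cuda.graph(g):
                static_out = self.fn(static_in)
            self.graphs[n] = (g, static_in, static_out)
            return static_out
        g, static_in, static_out = self.graphs[n]  # replay
        static_in.copy_(x)
        g.replay()
        return static_out
```

On a machine without CUDA the sketch simply falls back to eager execution, which is also how the warmup phase behaves.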

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

🚥 Pre-merge checks | ✅ 1 | ❌ 2

❌ Failed checks (2 warnings)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 45.26%, below the required threshold of 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |
| Description check | ⚠️ Warning | The PR description lacks required template sections, including a detailed technical explanation, a test coverage list, and completion of all checklist items. | Complete the PR description by adding: (1) a detailed technical explanation of the implementation changes in the 'Description' section, (2) a comprehensive list of test cases in 'Test Coverage', and (3) explicit confirmation of all PR Checklist items. |
✅ Passed checks (1 passed)
| Check name | Status | Explanation |
|---|---|---|
| Title check | ✅ Passed | The PR title clearly summarizes the main change: improving memory usage in the piecewise CUDA graph path within AutoDeploy. It is concise and directly related to the changeset. |


@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 9

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (7)
tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py (1)

1-1: ⚠️ Potential issue | 🟡 Minor

Update the NVIDIA header year to 2026.

This file has a 2026 modification, but the copyright header still ends at 2025.

As per coding guidelines, "All TensorRT-LLM source files should contain an NVIDIA copyright header with the year of the latest meaningful modification."

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py` at
line 1, Update the NVIDIA SPDX copyright header at the top of
triton_backend_mamba.py to reflect the latest modification year 2026 (change
"2022-2025" to "2022-2026") so the file header matches the required policy.
tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_mamba.py (1)

127-150: ⚠️ Potential issue | 🔴 Critical

Declare ssm_state_cache as mutated in the custom_op decorator.

_torch_cached_ssm writes to ssm_state_cache at lines 190 and 259 via index_copy_, but the decorator declares only out in mutates_args. PyTorch's custom_op uses this metadata to establish the correct alias/mutation model during tracing and CUDA graph capture.

Minimal fix
-@torch.library.custom_op("auto_deploy::torch_cached_ssm", mutates_args=("out",))
+@torch.library.custom_op(
+    "auto_deploy::torch_cached_ssm", mutates_args=("ssm_state_cache", "out")
+)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_mamba.py`
around lines 127 - 150, The custom op decorator for _torch_cached_ssm
incorrectly only lists "out" in mutates_args while the function mutates
ssm_state_cache (via index_copy_); update the `@torch.library.custom_op`
declaration for function _torch_cached_ssm to include ssm_state_cache (or its
tensor argument name) in mutates_args so PyTorch knows it is mutated during
tracing/CUDA graph capture; locate the decorator above def _torch_cached_ssm and
add the ssm_state_cache identifier to the mutates_args tuple alongside "out".
tensorrt_llm/_torch/auto_deploy/custom_ops/attention/flashinfer_attention.py (1)

318-343: ⚠️ Potential issue | 🔴 Critical

Add kv_cache to mutates_args.

The function mutates kv_cache in-place at line 370 via flashinfer.page.append_paged_kv_cache(), but the decorator declares only mutates_args=("out",). This omission causes PyTorch's tracing and CUDA graph capture to receive incorrect side-effect information, leading to potential aliasing and correctness issues.

Minimal fix
-@torch.library.custom_op("auto_deploy::flashinfer_attention_mha_with_cache", mutates_args=("out",))
+@torch.library.custom_op(
+    "auto_deploy::flashinfer_attention_mha_with_cache",
+    mutates_args=("kv_cache", "out"),
+)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/auto_deploy/custom_ops/attention/flashinfer_attention.py`
around lines 318 - 343, The decorator on flashinfer_mha_with_cache currently
declares mutates_args=("out",) but the function mutates kv_cache in-place via
flashinfer.page.append_paged_kv_cache(), so update the torch.library.custom_op
decorator to include "kv_cache" in mutates_args (e.g.,
mutates_args=("out","kv_cache")) to accurately reflect side effects; ensure the
decorator signature around flashinfer_mha_with_cache is the only change so
tracing and CUDA graph capture see the correct aliasing for kv_cache.
tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/flashinfer_backend_mamba.py (1)

31-55: ⚠️ Potential issue | 🔴 Critical

Declare the cache mutation in the custom-op schema.

This op mutates ssm_state_cache both in _run_ssm_prefill (via index_copy_ at line 88) and via flashinfer.mamba.selective_state_update at line 129. PyTorch requires all mutated tensors to be declared in mutates_args, otherwise behavior is undefined and opcheck will fail.

Minimal fix
-@torch.library.custom_op("auto_deploy::flashinfer_cached_ssm", mutates_args=("out",))
+@torch.library.custom_op(
+    "auto_deploy::flashinfer_cached_ssm", mutates_args=("ssm_state_cache", "out")
+)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/flashinfer_backend_mamba.py`
around lines 31 - 55, The custom op decorator for _flashinfer_cached_ssm
currently only lists "out" in mutates_args but the op also mutates
ssm_state_cache (seen in _run_ssm_prefill via index_copy_ and in
flashinfer.mamba.selective_state_update); update the
`@torch.library.custom_op`("auto_deploy::flashinfer_cached_ssm",
mutates_args=(...)) declaration to include "ssm_state_cache" (alongside "out")
so PyTorch knows this tensor is mutated.
tensorrt_llm/_torch/auto_deploy/custom_ops/attention/torch_backend_attention.py (1)

345-392: ⚠️ Potential issue | 🟠 Major

Write attention output into out directly.

y is allocated unconditionally on Line 346 and then copied into out on Lines 387-392. That preserves the same peak allocation as before for one of the largest tensors in piecewise capture, so this path can still OOM even with a preallocated buffer. Use out.view(*bs_view, num_heads, v_head_dim) as y when out is provided and zero the padded tail in place.

♻️ Proposed change
-    # Preallocate output tensor
-    y = q.new_empty(*bs_view, num_heads, v_head_dim).contiguous()
+    # Reuse caller-owned storage when available.
+    if out is not None:
+        y = out.view(*bs_view, num_heads, v_head_dim)
+    else:
+        y = q.new_empty(*bs_view, num_heads, v_head_dim).contiguous()
@@
-    if out is not None:
-        out_flat = out.view(*bs_view, num_heads, v_head_dim)
-        out_flat[:num_total_tokens].copy_(y[:num_total_tokens])
-        if num_total_tokens < bs:
-            out_flat[num_total_tokens:].zero_()
-        return out.new_empty(0)
-
-    # Zero padding positions so downstream ops don't see garbage (piecewise CG)
     if num_total_tokens < bs:
         y[num_total_tokens:].zero_()
+    if out is not None:
+        return out.new_empty(0)
 
     return y.view(*output_shape)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@tensorrt_llm/_torch/auto_deploy/custom_ops/attention/torch_backend_attention.py`
around lines 345 - 392, The code unconditionally allocates y then copies into
out, causing peak memory that can OOM; change allocation so that if out is
provided you set y = out.view(*bs_view, num_heads, v_head_dim) (ensuring correct
dtype/contiguity) and otherwise allocate q.new_empty(...) as before, then pass y
into _torch_generate_mha/_torch_context_mha and after attention zero the padded
tail in-place (out_flat[num_total_tokens:].zero_()) instead of copying; update
usages of y, out_flat, num_total_tokens, bs_view, num_heads, v_head_dim,
_torch_generate_mha and _torch_context_mha accordingly.
tensorrt_llm/_torch/auto_deploy/custom_ops/fla/fla_backend_delta.py (1)

72-79: ⚠️ Potential issue | 🟠 Major

The delta-rule out path still pays for a full temporary tensor.

y = torch.empty_like(v) is always allocated and fully populated before its contents are copied into out. That means the new out parameter doesn't actually remove the extra outer output buffer for this op. Reuse out as the backing storage for y_flat when it's provided, and zero the padded tail in place.

♻️ Proposed change
-    # pre-allocate output
-    y = torch.empty_like(v, memory_format=torch.contiguous_format)
+    # Reuse caller-owned storage when available.
+    y = out if out is not None else torch.empty_like(v, memory_format=torch.contiguous_format)
     y_flat = y.view(b * s, num_heads, -1)
@@
-    if out is not None:
-        out_flat = out.view(b * s, num_heads, -1)
-        out_flat[:num_total_tokens].copy_(y_flat[:num_total_tokens])
-        if num_total_tokens < b * s:
-            out_flat[num_total_tokens:].zero_()
-        return out.new_empty(0)
+    if num_total_tokens < b * s:
+        y_flat[num_total_tokens:].zero_()
+    if out is not None:
+        return out.new_empty(0)
 
     return y

Also applies to: 128-133

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/auto_deploy/custom_ops/fla/fla_backend_delta.py` around
lines 72 - 79, The delta-rule path currently allocates y = torch.empty_like(v)
and copies into out, negating the purpose of the out parameter; modify the logic
in fla_backend_delta.py (the block creating y and y_flat from v and using
batch_info_host -> num_prefill/num_decode) to reuse the provided out buffer as
backing storage: when out is not None, reshape/view out into y_flat (matching
b*s, num_heads, -1) instead of allocating torch.empty_like(v), and write results
directly into that view; ensure any padded tail elements (from the last row if
the flattened size isn’t a multiple) are zeroed in-place on the out buffer
rather than creating temporaries; apply the same change to the analogous section
around lines 128-133 that also allocates y.
tensorrt_llm/_torch/auto_deploy/compile/piecewise_utils.py (1)

1-18: ⚠️ Potential issue | 🟠 Major

Restore the required NVIDIA license header.

This modified Python source file now begins with the module docstring, but the repository rule requires the NVIDIA Apache-2.0 copyright block on all source files.

As per coding guidelines, **/*.{cpp,h,cu,cuh,hpp,py}: All TensorRT-LLM source files should contain an NVIDIA copyright header with the year of the latest meaningful modification. The header should be an Apache 2.0 license block as specified.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/auto_deploy/compile/piecewise_utils.py` around lines 1 -
18, Add the required NVIDIA Apache-2.0 license header at the very top of this
source file (above the module docstring) using the year of the latest meaningful
modification; ensure the header matches the repository's canonical NVIDIA
copyright/license block and remains in place before any imports or docstrings so
tools and scanners detect it. Locate this file by the module-level docstring and
symbols such as GraphModule, Node, split_module, and ad_logger to confirm you're
editing tensorrt_llm/_torch/auto_deploy/compile/piecewise_utils.py, then commit
the file with the header prepended.
🧹 Nitpick comments (3)
tensorrt_llm/_torch/auto_deploy/custom_ops/mla/torch_backend_mla.py (1)

389-405: Use the external buffer as the helper output.

Both helper paths already accept an output tensor, but this branch still allocates y and then copies it into out. That keeps a full-sized temporary alive and adds an extra device copy on the new buffer-reuse path, which leaves part of the intended memory win on the table.

♻️ Suggested simplification
+    out_flat = out.view(bs, num_heads, v_head_dim) if out is not None else None
+
     if s == 1:
         # =====================================================================
         # Generate phase: Use weight absorption
         # =====================================================================
-        y = q_nope.new_empty(b, num_heads, v_head_dim).contiguous()
+        y = out_flat if out_flat is not None else q_nope.new_empty(
+            b, num_heads, v_head_dim
+        ).contiguous()

         _torch_mla_generate_with_absorption(
             q_nope,
             q_pe,
@@
             y,
         )
     else:
@@
-        y = q_nope.new_empty(*bs_view, num_heads, v_head_dim).contiguous()
+        y = out_flat if out_flat is not None else q_nope.new_empty(
+            *bs_view, num_heads, v_head_dim
+        ).contiguous()

         _torch_mla_context_with_expansion(
             q_nope_flat,
             q_pe_flat,
@@
             y,
         )

     if out is not None:
-        out_flat = out.view(bs, num_heads, v_head_dim)
-        out_flat[:num_total_tokens].copy_(y[:num_total_tokens])
         if num_total_tokens < bs:
-            out_flat[num_total_tokens:].zero_()
+            y[num_total_tokens:].zero_()
         return out.new_empty(0)

Also applies to: 418-444

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/auto_deploy/custom_ops/mla/torch_backend_mla.py` around
lines 389 - 405, The branch currently allocates a temporary tensor y and later
copies it into out; instead reuse the external buffer out as the helper output
to avoid the full-sized temp and extra device copy: replace the y allocation
with a view/slice of out shaped (b, num_heads, v_head_dim) (ensuring contiguity
if required) and pass that buffer into _torch_mla_generate_with_absorption
instead of y, then remove the final copy into out; do the analogous change in
the other helper branch (the block around lines 418-444 that calls the other MLA
helper) so both helper calls directly write into out (refer to variables y, out,
_torch_mla_generate_with_absorption and the other MLA helper name) and ensure
shapes/contiguity match before calling.
tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py (1)

31-55: Document the new out buffer contract on this custom op.

out now changes both buffer ownership and the observable return behavior, but there is still no docstring describing the expected shape/view-compatibility for out or that callers should ignore the return value when out is provided. That is easy to misuse from torch.ops.auto_deploy.triton_cached_ssm.default.

As per coding guidelines, "For Python interfaces that may be used outside a file, prefer docstrings over comments" and "Use Google style docstrings for Python classes and functions, which can be parsed by Sphinx."

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py`
around lines 31 - 55, Add a Google-style docstring to the custom op function
_triton_cached_ssm (decorated via torch.library.custom_op) that documents the
new out buffer contract: explicitly state the expected shape and
view-compatibility requirements for out relative to the computed output
(including any allowed broadcasting or required contiguous/strided layouts),
describe that providing out transfers buffer ownership/observers semantics so
the op will write into out in-place, and state callers MUST ignore the Python
return value when supplying out (and instead read from the out buffer); also
mention interaction with torch.ops.auto_deploy.triton_cached_ssm.default and any
error/validation behavior when out has an incompatible shape or dtype.
tensorrt_llm/_torch/auto_deploy/compile/backends/torch_cudagraph.py (1)

405-439: Remove or gate the temporary memory probes before merge.

The TODO on Line 405 is still in the capture path, and _install_mem_hooks() adds per-segment get_mem_info calls and log spam for every bucket. Please either drop this from the PR or guard it behind a debug-only switch.

If you want, I can help split the memory-probe code into a follow-up debug-only change.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/auto_deploy/compile/backends/torch_cudagraph.py` around
lines 405 - 439, The temporary memory-probe logic (_fmt_mem and
_install_mem_hooks) is currently active and spams logs; either remove these
methods and their calls or guard them behind a debug-only switch (e.g., a
class/static flag MEMORY_PROBES_ENABLED or check
ad_logger.isEnabledFor(logging.DEBUG) / an env var) so probes run only when
explicitly enabled. Locate _install_mem_hooks (and any callers that attach its
returned handles) and wrap the call with the guard or early-return when the flag
is false; likewise make _fmt_mem and the get_mem_info calls no-ops unless the
same flag is enabled (or remove them entirely). Ensure checks reference split_gm
and ADPiecewiseRunner as before and keep ad_logger usage only under the debug
gate so normal runs are not impacted.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tensorrt_llm/_torch/auto_deploy/compile/piecewise_runner.py`:
- Around line 200-225: The bug is that dynamic output info and buffers are
stored per-runner (fields _next_dynamic_out_info and
SegmentEntry.dynamic_out_buf) so when one ADPiecewiseRunner precedes multiple
DynamicOpWrapper instances the later shape-discovery results overwrite earlier
ones; change the storage to key by the dynamic submodule id: make
_next_dynamic_out_info a dict keyed by dynamic submodule identifier and change
SegmentEntry.dynamic_out_buf to a mapping of num_tokens->(submodule_id->buffer)
or similar, update get_dynamic_out_buf to accept/lookup the submodule id and
return the buffer for that id/num_tokens, and adjust prepare() in
PiecewiseCapturedGraph and any code paths that set/read these fields (including
the logic around lines referenced 248-264 and 311-338) to pass and use the
dynamic submodule id when storing and retrieving OutputInfo and buffers.
- Around line 234-241: The bug is that SegmentEntry() is created unconditionally
and inserted into self.entries[num_tokens] before checking phase, so a warmup
call (phase == "warmup") wrongly populates entries; move the allocation and
assignment of entry = SegmentEntry() and self.entries[num_tokens] = entry so
they occur only in the capture/replay branches (i.e., after the phase check),
leaving the warmup branch to just return self.submodule(*args, **kwargs) without
touching self.entries; update any code paths that reference entry to first fetch
it from self.entries when in capture or replay.

In `@tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_attention.py`:
- Around line 330-334: The fake output-shape logic is wrong: instead of using
v.shape[-1] directly, compute v_head_dim = v_cache.shape[-1] and num_heads =
q.shape[2] // qk_head_dim (matching the real op's qk head division) then build
output_shape = (*q.shape[:-1], num_heads * v_head_dim) and return
q.new_empty(*output_shape).contiguous() (still early-returning out if provided);
update the fake function to use these symbols (q, v_cache, qk_head_dim,
num_heads, v_head_dim, output_shape) so its output matches the real op's
y.view(*output_shape) behavior.
- Around line 293-298: The fake-registration path must return the same tensor
shape as the real Triton op when an output buffer is provided: change the fake
implementation to detect when the `out` argument is non-None and return an empty
tensor with the same type/device as `out` (i.e., equivalent to
`out.new_empty(0)`) instead of returning `out` itself; ensure the logic that
writes into `out_flat` (variables `out`, `out_flat`, `bs`, `num_heads`,
`v_head_dim`, `num_total_tokens`, `y`) remains unchanged so the fake propagation
matches the real op's return shape.

In `@tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py`:
- Around line 140-145: The fake-tensor branch currently returns the full buffer
`out` while the eager branch returns an empty tensor, causing shape/aliasing
mismatch; modify the fake implementation so that when `out` is not None you
mirror the eager behavior—apply the same padding-zeroing of
`preallocated_ssm_out` when `num_total_tokens < bs` (as done in the eager path)
and return an empty tensor created with `out.new_empty(0)` (preserving `out`'s
dtype/device) instead of returning `out`.

In `@tensorrt_llm/_torch/auto_deploy/custom_ops/mla/flashinfer_mla.py`:
- Around line 530-539: The current code rebinds compressed_kv_flat and kpe_flat
to FP8 which also alters the tensors used later for attention; instead create
new variables (e.g., compressed_kv_flat_for_cache and kpe_flat_for_cache),
perform the torch.float8_e4m3fn cast only on those when ckv_cache.dtype ==
torch.float8_e4m3fn, and pass those dedicated cast tensors into
flashinfer.page.append_paged_mla_kv_cache(compressed_kv_flat_for_cache[:num_total_tokens],
kpe_flat_for_cache[:num_total_tokens]) so the original compressed_kv_flat and
kpe_flat remain at activation precision for prefill paths.
- Around line 754-759: The fake implementation in flashinfer_mla currently
returns the full out buffer while the real CUDA path returns a 0-length sentinel
(out.new_empty(0)) when out is provided, causing a contract mismatch; change the
fake branch that handles the `out` argument to return the same 0-length sentinel
(use out.new_empty(0)) after performing any in-place modifications (e.g.,
y[num_total_tokens:].zero_()), ensuring variables referenced (out, y,
num_total_tokens, bs) and the function flashinfer_mla keep the same
side-effect-only behavior and return shape as the real implementation.

In `@tensorrt_llm/_torch/auto_deploy/custom_ops/utils/torch_gather_logits.py`:
- Around line 50-60: The code currently always clones hidden_states into result
even when an output buffer (out) is provided, negating buffer reuse; change the
logic so cloning (hidden_states.clone(memory_format=torch.contiguous_format))
only happens when out is None, otherwise create the reshaped view directly from
hidden_states (use the same contiguous/reshape logic used for result.view(...))
and then copy that view into out via out.copy_(); specifically modify the paths
around result, hidden_states, result.view(...), and out.copy_ so that clone is
skipped when out is present and cloning is only performed in the no-out branch.

In `@tensorrt_llm/_torch/auto_deploy/transform/library/compile_model.py`:
- Around line 99-101: The code computes batch_capacity = (max_batch - 1) *
max_seq + 1 and applies that ceiling to every piecewise_num_tokens bucket,
which incorrectly clips pure-prefill buckets. Keep piecewise_enabled (and any
loop that builds piecewise_num_tokens) able to generate buckets up to
max_num_tokens for prefill-only use, and apply batch_capacity only when
constructing or validating buckets used for mixed-batch inference (prefill +
decode): at those sites use min(bucket, batch_capacity) rather than mutating
or clamping the original bucket. Relevant variables: max_seq, max_batch,
batch_capacity, piecewise_enabled, piecewise_num_tokens, and max_num_tokens.
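
A toy sketch of the suggested split, with hypothetical names (`plan_piecewise_buckets`, `step`) that do not appear in the real `compile_model.py`: buckets still reach `max_num_tokens`, and the `batch_capacity` ceiling is applied per-bucket only for the mixed-batch limits.

```python
def plan_piecewise_buckets(max_batch, max_seq, max_num_tokens, step):
    """Buckets reach max_num_tokens for prefill-only captures; batch_capacity
    caps only the mixed-batch token limits, leaving the buckets untouched."""
    batch_capacity = (max_batch - 1) * max_seq + 1
    piecewise_num_tokens = list(range(step, max_num_tokens + 1, step))
    # min(bucket, batch_capacity) at the mixed-batch site, per the review,
    # instead of clamping piecewise_num_tokens itself.
    mixed_batch_limits = [min(b, batch_capacity) for b in piecewise_num_tokens]
    return piecewise_num_tokens, mixed_batch_limits
```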

---

Outside diff comments:
In `@tensorrt_llm/_torch/auto_deploy/compile/piecewise_utils.py`:
- Around line 1-18: Add the required NVIDIA Apache-2.0 license header at the
very top of this source file (above the module docstring) using the year of the
latest meaningful modification; ensure the header matches the repository's
canonical NVIDIA copyright/license block and remains in place before any imports
or docstrings so tools and scanners detect it. Locate this file by the
module-level docstring and symbols such as GraphModule, Node, split_module, and
ad_logger to confirm you're editing
tensorrt_llm/_torch/auto_deploy/compile/piecewise_utils.py, then commit the file
with the header prepended.

In
`@tensorrt_llm/_torch/auto_deploy/custom_ops/attention/flashinfer_attention.py`:
- Around line 318-343: The decorator on flashinfer_mha_with_cache currently
declares mutates_args=("out",) but the function mutates kv_cache in-place via
flashinfer.page.append_paged_kv_cache(), so update the torch.library.custom_op
decorator to include "kv_cache" in mutates_args (e.g.,
mutates_args=("out","kv_cache")) to accurately reflect side effects; ensure the
decorator signature around flashinfer_mha_with_cache is the only change so
tracing and CUDA graph capture see the correct aliasing for kv_cache.

In
`@tensorrt_llm/_torch/auto_deploy/custom_ops/attention/torch_backend_attention.py`:
- Around line 345-392: The code unconditionally allocates y then copies into
out, causing peak memory that can OOM; change allocation so that if out is
provided you set y = out.view(*bs_view, num_heads, v_head_dim) (ensuring correct
dtype/contiguity) and otherwise allocate q.new_empty(...) as before, then pass y
into _torch_generate_mha/_torch_context_mha and after attention zero the padded
tail in-place (out_flat[num_total_tokens:].zero_()) instead of copying; update
usages of y, out_flat, num_total_tokens, bs_view, num_heads, v_head_dim,
_torch_generate_mha and _torch_context_mha accordingly.
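
A minimal sketch of the suggested allocation strategy, assuming the review's variable names (the real attention kernels are elided): back `y` with the caller's buffer when available, then zero the padded tail in-place instead of copying.

```python
import torch


def alloc_attention_output(q, out, bs_view, num_heads, v_head_dim, num_total_tokens):
    """View the caller's buffer as y when provided; otherwise allocate."""
    if out is not None:
        y = out.view(*bs_view, num_heads, v_head_dim)
    else:
        y = q.new_empty(*bs_view, num_heads, v_head_dim)
    # ... _torch_context_mha / _torch_generate_mha would write into y here ...
    if out is not None:
        # Zero the padded tail in-place rather than copying a temporary.
        out_flat = out.view(-1, num_heads * v_head_dim)
        out_flat[num_total_tokens:].zero_()
    return y
```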

In `@tensorrt_llm/_torch/auto_deploy/custom_ops/fla/fla_backend_delta.py`:
- Around line 72-79: The delta-rule path currently allocates y =
torch.empty_like(v) and copies into out, negating the purpose of the out
parameter; modify the logic in fla_backend_delta.py (the block creating y and
y_flat from v and using batch_info_host -> num_prefill/num_decode) to reuse the
provided out buffer as backing storage: when out is not None, reshape/view out
into y_flat (matching b*s, num_heads, -1) instead of allocating
torch.empty_like(v), and write results directly into that view; ensure any
padded tail elements (from the last row if the flattened size isn’t a multiple)
are zeroed in-place on the out buffer rather than creating temporaries; apply
the same change to the analogous section around lines 128-133 that also
allocates y.

In
`@tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/flashinfer_backend_mamba.py`:
- Around line 31-55: The custom op decorator for _flashinfer_cached_ssm
currently only lists "out" in mutates_args but the op also mutates
ssm_state_cache (seen in _run_ssm_prefill via index_copy_ and in
flashinfer.mamba.selective_state_update); update the
`@torch.library.custom_op`("auto_deploy::flashinfer_cached_ssm",
mutates_args=(...)) declaration to include "ssm_state_cache" (alongside "out")
so PyTorch knows this tensor is mutated.

In `@tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_mamba.py`:
- Around line 127-150: The custom op decorator for _torch_cached_ssm incorrectly
only lists "out" in mutates_args while the function mutates ssm_state_cache (via
index_copy_); update the `@torch.library.custom_op` declaration for function
_torch_cached_ssm to include ssm_state_cache (or its tensor argument name) in
mutates_args so PyTorch knows it is mutated during tracing/CUDA graph capture;
locate the decorator above def _torch_cached_ssm and add the ssm_state_cache
identifier to the mutates_args tuple alongside "out".

In `@tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py`:
- Line 1: Update the NVIDIA SPDX copyright header at the top of
triton_backend_mamba.py to reflect the latest modification year 2026 (change
"2022-2025" to "2022-2026") so the file header matches the required policy.

---

Nitpick comments:
In `@tensorrt_llm/_torch/auto_deploy/compile/backends/torch_cudagraph.py`:
- Around line 405-439: The temporary memory-probe logic (_fmt_mem and
_install_mem_hooks) is currently active and spams logs; either remove these
methods and their calls or guard them behind a debug-only switch (e.g., a
class/static flag MEMORY_PROBES_ENABLED or check
ad_logger.isEnabledFor(logging.DEBUG) / an env var) so probes run only when
explicitly enabled. Locate _install_mem_hooks (and any callers that attach its
returned handles) and wrap the call with the guard or early-return when the flag
is false; likewise make _fmt_mem and the get_mem_info calls no-ops unless the
same flag is enabled (or remove them entirely). Ensure checks reference split_gm
and ADPiecewiseRunner as before and keep ad_logger usage only under the debug
gate so normal runs are not impacted.
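
One way to implement the guard, sketched with a hypothetical env var (`AD_MEM_PROBES`) and a stand-in logger; the real code would use its existing `ad_logger`:

```python
import logging
import os

ad_logger = logging.getLogger("auto_deploy")  # stand-in for the real ad_logger
# Hypothetical switch: probes run only when explicitly requested.
MEM_PROBES_ENABLED = os.environ.get("AD_MEM_PROBES", "0") == "1"


def install_mem_hooks(split_gm):
    """No-op unless memory probes are explicitly enabled."""
    if not (MEM_PROBES_ENABLED or ad_logger.isEnabledFor(logging.DEBUG)):
        return []  # early return: normal runs pay no cost and emit no logs
    handles = []
    # ... attach forward hooks here that log formatted get_mem_info() output ...
    return handles
```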

In `@tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py`:
- Around line 31-55: Add a Google-style docstring to the custom op function
_triton_cached_ssm (decorated via torch.library.custom_op) that documents the
new out buffer contract: explicitly state the expected shape and
view-compatibility requirements for out relative to the computed output
(including any allowed broadcasting or required contiguous/strided layouts),
describe that providing out transfers buffer ownership/observers semantics so
the op will write into out in-place, and state callers MUST ignore the Python
return value when supplying out (and instead read from the out buffer); also
mention interaction with torch.ops.auto_deploy.triton_cached_ssm.default and any
error/validation behavior when out has an incompatible shape or dtype.

In `@tensorrt_llm/_torch/auto_deploy/custom_ops/mla/torch_backend_mla.py`:
- Around line 389-405: The branch currently allocates a temporary tensor y and
later copies it into out; instead reuse the external buffer out as the helper
output to avoid the full-sized temp and extra device copy: replace the y
allocation with a view/slice of out shaped (b, num_heads, v_head_dim) (ensuring
contiguity if required) and pass that buffer into
_torch_mla_generate_with_absorption instead of y, then remove the final copy
into out; do the analogous change in the other helper branch (the block around
lines 418-444 that calls the other MLA helper) so both helper calls directly
write into out (refer to variables y, out, _torch_mla_generate_with_absorption
and the other MLA helper name) and ensure shapes/contiguity match before
calling.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: b08a4a2f-ddf6-4622-86c5-48f2f568cc31

📥 Commits

Reviewing files that changed from the base of the PR and between d1ba3b8 and d15c20d.

📒 Files selected for processing (21)
  • examples/auto_deploy/model_registry/configs/qwen3.5_moe_35b.yaml
  • tensorrt_llm/_torch/auto_deploy/compile/backends/torch_cudagraph.py
  • tensorrt_llm/_torch/auto_deploy/compile/piecewise_runner.py
  • tensorrt_llm/_torch/auto_deploy/compile/piecewise_utils.py
  • tensorrt_llm/_torch/auto_deploy/config/default.yaml
  • tensorrt_llm/_torch/auto_deploy/custom_ops/attention/flashinfer_attention.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/attention/torch_backend_attention.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_attention.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/attention/trtllm_attention.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/fla/fla_backend_delta.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/fla/fla_backend_gated_delta.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/fla/torch_backend_gated_delta.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/flashinfer_backend_mamba.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_mamba.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/mla/flashinfer_mla.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/mla/torch_backend_mla.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/utils/torch_gather_logits.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/compile_model.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/compile/test_captured_graph.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/compile/test_piecewise_runner.py

Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
Made-with: Cursor
@nvchenghaoz nvchenghaoz force-pushed the chenghao/piecewise_memory_0306 branch from d149b0c to d808289 Compare March 9, 2026 18:57
Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
@nvchenghaoz
Collaborator Author

/bot run --extra-stage "DGX_B200-4_GPUs-AutoDeploy-1, DGX_H100-4_GPUs-AutoDeploy-1" --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #38924 [ run ] triggered by Bot. Commit: cf4793a Link to invocation

Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
@nvchenghaoz
Collaborator Author

/bot run --extra-stage "DGX_B200-4_GPUs-AutoDeploy-1, DGX_H100-4_GPUs-AutoDeploy-1" --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #38932 [ run ] triggered by Bot. Commit: b467750 Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #38932 [ run ] completed with state DISABLED
CI server is currently disabled for scheduled maintenance. Estimated completion time: 9 PM PST on 3/14.

Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #38924 [ run ] completed with state FAILURE. Commit: cf4793a
/LLM/main/L0_MergeRequest_PR pipeline #30229 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@nvchenghaoz
Collaborator Author

/bot run --extra-stage "DGX_B200-4_GPUs-AutoDeploy-1, DGX_H100-4_GPUs-AutoDeploy-1" --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #39292 [ run ] triggered by Bot. Commit: b467750 Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #39292 [ run ] completed with state SUCCESS. Commit: b467750
/LLM/main/L0_MergeRequest_PR pipeline #30543 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

…mory_0306

Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
Made-with: Cursor

# Conflicts:
#	tensorrt_llm/_torch/auto_deploy/custom_ops/attention/trtllm_attention.py
@nvchenghaoz
Collaborator Author

/bot run --extra-stage "DGX_B200-4_GPUs-AutoDeploy-1, DGX_H100-4_GPUs-AutoDeploy-1" --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #39323 [ run ] triggered by Bot. Commit: 68b9565 Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #39323 [ run ] completed with state SUCCESS. Commit: 68b9565
/LLM/main/L0_MergeRequest_PR pipeline #30570 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
…mory_0306

Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
Made-with: Cursor

# Conflicts:
#	tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_mamba.py
#	tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py
@nvchenghaoz
Collaborator Author

/bot run --extra-stage "DGX_B200-4_GPUs-AutoDeploy-1, DGX_H100-4_GPUs-AutoDeploy-1" --disable-fail-fast

@nvchenghaoz
Collaborator Author

nvchenghaoz commented Mar 18, 2026

Some numbers from local test -

GLM 4.7 Flash 1k/1k

| Concurrency | TTFT avg (ms), no_pcg | TTFT, pcg | Δ | Prefill TPS (tok/s), no_pcg | TPS, pcg | Δ | TPOT avg (ms), no_pcg | TPOT, pcg | Δ |
|---:|---:|---:|---:|---:|---:|---:|---:|---:|---:|
| 1 | 58.40 | 48.62 | -16.7% | 17,123 | 20,566 | +20.1% | 7.226 | 7.218 | -0.1% |
| 4 | 136.87 | 113.21 | -17.3% | 7,531 | 9,374 | +24.5% | 10.127 | 9.998 | -1.3% |
| 16 | 309.36 | 305.52 | -1.2% | 3,464 | 3,902 | +12.7% | 15.689 | 15.690 | +0.0% |
| 64 | 485.93 | 482.26 | -0.8% | 2,791 | 2,891 | +3.6% | 23.516 | 23.540 | +0.1% |

Nano v3 fp8 1k/1k

| Concurrency | TTFT avg (ms), no_pcg | TTFT, pcg | Δ | Prefill TPS (tok/s), no_pcg | TPS, pcg | Δ | TPOT avg (ms), no_pcg | TPOT, pcg | Δ |
|---:|---:|---:|---:|---:|---:|---:|---:|---:|---:|
| 1 | 50.58 | 45.70 | -9.6% | 19,781 | 21,896 | +10.7% | 3.436 | 3.397 | -1.1% |
| 4 | 100.72 | 78.29 | -22.3% | 10,006 | 13,142 | +31.4% | 4.779 | 4.652 | -2.7% |
| 16 | 3,798.15 | 1,210.56 | -68.1% | 4,066 | 4,949 | +21.7% | 7.882 | 8.065 | +2.3% |
| 64 | 550.60 | 405.54 | -26.3% | 2,537 | 2,938 | +15.8% | 13.020 | 13.128 | +0.8% |
| 256 | 1,251.85 | 1,482.86 | +18.5% | 1,464 | 1,109 | -24.3% | 22.032 | 21.827 | -0.9% |

@tensorrt-cicd
Collaborator

PR_Github #39481 [ run ] triggered by Bot. Commit: 5e9b731 Link to invocation

@nvchenghaoz
Collaborator Author

image

@tensorrt-cicd
Collaborator

PR_Github #39481 [ run ] completed with state SUCCESS. Commit: 5e9b731
/LLM/main/L0_MergeRequest_PR pipeline #30709 completed with status: 'SUCCESS'
Pipeline passed with automatic retried tests. Check the rerun report for details.

CI Report

Link to invocation

@nvchenghaoz nvchenghaoz merged commit 3e0ae7f into NVIDIA:main Mar 18, 2026
5 checks passed
limin2021 pushed a commit to limin2021/TensorRT-LLM that referenced this pull request Mar 19, 2026
NVIDIA#11993)

Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>