[#11694][feat] AutoDeploy: Improve the piecewise CG memory usage #11993
nvchenghaoz merged 15 commits into NVIDIA:main from
Conversation
📝 Walkthrough
This PR implements dynamic output buffering and enhanced memory-aware pipelining for piecewise CUDA graph capture. It introduces wrapper classes for metadata and dynamic operations, adds optional output buffer parameters across custom ops, refactors runner semantics into explicit warmup/capture/replay phases, and adds batch capacity-aware bucket filtering to the compilation pipeline.
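The optional output-buffer parameter follows a contract that recurs throughout the review below: when the caller passes a preallocated buffer, the op writes valid rows into it, zeros the padded tail in place, and returns an empty sentinel so the caller reads results from the buffer rather than the return value. A plain-Python sketch of that contract (lists stand in for tensors; the function name and shapes are illustrative, not the real ops):

```python
def run_op(x, num_total_tokens, out=None):
    """Toy model of the reviewed `out=` contract (names illustrative).

    When `out` is given: valid rows are written into it, the padded
    tail is zeroed in place, and an empty sentinel is returned so the
    caller reads results from `out`, not from the return value.
    """
    # Reuse caller-owned storage instead of allocating a temporary.
    y = out if out is not None else [0] * len(x)
    for i in range(num_total_tokens):       # only valid tokens computed
        y[i] = x[i] * 2
    for i in range(num_total_tokens, len(y)):
        y[i] = 0                            # zero the padded tail in place
    if out is not None:
        return []                           # sentinel, like out.new_empty(0)
    return y

buf = [9, 9, 9, 9]
ret = run_op([1, 2, 3, 4], num_total_tokens=2, out=buf)
```

Writing into `out` directly, rather than computing into a temporary and copying, is what several comments below ask for, since the copy path keeps the full-size temporary alive.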
Sequence Diagram
sequenceDiagram
participant Client
participant PiecewiseCapturedGraph
participant ADPiecewiseRunner
participant DynamicOpWrapper
participant MetadataWrapper
participant CUDAGraph
Client->>PiecewiseCapturedGraph: __init__(model, capture_lm_head=False)
PiecewiseCapturedGraph->>ADPiecewiseRunner: Create runner instances
Client->>ADPiecewiseRunner: set_current_phase("warmup")
ADPiecewiseRunner->>DynamicOpWrapper: Execute dynamic ops (eager)
ADPiecewiseRunner->>MetadataWrapper: Execute metadata ops (clone outputs)
Note over ADPiecewiseRunner: Discover output shapes/dtypes
Client->>ADPiecewiseRunner: set_dynamic_out_info(OutputInfo)
ADPiecewiseRunner->>ADPiecewiseRunner: Store output shape/dtype info
Client->>ADPiecewiseRunner: set_current_phase("capture")
ADPiecewiseRunner->>ADPiecewiseRunner: Allocate dynamic output buffers
ADPiecewiseRunner->>DynamicOpWrapper: Execute with out= parameter
ADPiecewiseRunner->>CUDAGraph: Capture static segment
CUDAGraph-->>ADPiecewiseRunner: Captured graph with buffer references
Client->>ADPiecewiseRunner: set_current_phase("replay")
ADPiecewiseRunner->>CUDAGraph: Replay with pre-allocated buffers
ADPiecewiseRunner->>DynamicOpWrapper: Execute dynamic ops (reuse buffers)
CUDAGraph-->>ADPiecewiseRunner: Results written to buffers
ADPiecewiseRunner-->>Client: Output via buffer
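The warmup/capture/replay flow above can be modeled as a small state machine. This is an illustrative sketch, not the real ADPiecewiseRunner: plain Python lists stand in for device buffers, and an in-place recompute stands in for CUDA graph capture and replay.

```python
# Hypothetical reduction of the phases in the diagram above.
class PhasedRunner:
    def __init__(self, fn):
        self.fn = fn                  # the graph segment to run
        self.phase = "warmup"
        self.out_info = None          # discovered during warmup
        self.buffers = {}             # num_tokens -> preallocated buffer

    def set_current_phase(self, phase):
        assert phase in ("warmup", "capture", "replay")
        self.phase = phase

    def __call__(self, x):
        n = len(x)
        if self.phase == "warmup":
            # Eager execution; discover output shape/dtype.
            y = self.fn(x)
            self.out_info = ("len", len(y))
            return y
        if self.phase == "capture":
            # Allocate the bucket's buffer once and "capture" into it.
            buf = [0] * self.out_info[1]
            self.buffers[n] = buf
            buf[:] = self.fn(x)       # stands in for graph capture
            return buf
        # replay: reuse the pre-allocated buffer, no new allocation.
        buf = self.buffers[n]
        buf[:] = self.fn(x)           # stands in for graph.replay()
        return buf

runner = PhasedRunner(lambda xs: [v * 2 for v in xs])
runner([1, 2, 3])                     # warmup discovers output size
runner.set_current_phase("capture")
captured = runner([1, 2, 3])
runner.set_current_phase("replay")
replayed = runner([4, 5, 6])
```

Note that `replayed is captured`: replay writes into the buffer allocated at capture time, which is the memory-usage point of the PR.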
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
🚥 Pre-merge checks: ✅ 1 | ❌ 2
❌ Failed checks (2 warnings)
✅ Passed checks (1 passed)
Actionable comments posted: 9
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (7)
tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py (1)
1-1: ⚠️ Potential issue | 🟡 Minor — Update the NVIDIA header year to 2026.
This file has a 2026 modification, but the copyright header still ends at 2025.
As per coding guidelines, "All TensorRT-LLM source files should contain an NVIDIA copyright header with the year of the latest meaningful modification."
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py` at line 1, update the NVIDIA SPDX copyright header to reflect the latest modification year 2026 (change "2022-2025" to "2022-2026") so the file header matches the required policy.
tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_mamba.py (1)
127-150: ⚠️ Potential issue | 🔴 Critical — Declare `ssm_state_cache` as mutated in the custom_op decorator.
`_torch_cached_ssm` writes to `ssm_state_cache` at lines 190 and 259 via `index_copy_`, but the decorator declares only `out` in `mutates_args`. PyTorch's custom_op uses this metadata to establish the correct alias/mutation model during tracing and CUDA graph capture.
Minimal fix:

```diff
-@torch.library.custom_op("auto_deploy::torch_cached_ssm", mutates_args=("out",))
+@torch.library.custom_op(
+    "auto_deploy::torch_cached_ssm", mutates_args=("ssm_state_cache", "out")
+)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_mamba.py` around lines 127-150, the custom op decorator for _torch_cached_ssm only lists "out" in mutates_args while the function mutates ssm_state_cache (via index_copy_); update the `@torch.library.custom_op` declaration to include ssm_state_cache in mutates_args alongside "out" so PyTorch knows it is mutated during tracing/CUDA graph capture.
tensorrt_llm/_torch/auto_deploy/custom_ops/attention/flashinfer_attention.py (1)
318-343: ⚠️ Potential issue | 🔴 Critical — Add `kv_cache` to `mutates_args`.
The function mutates `kv_cache` in-place at line 370 via `flashinfer.page.append_paged_kv_cache()`, but the decorator declares only `mutates_args=("out",)`. This omission gives PyTorch's tracing and CUDA graph capture incorrect side-effect information, leading to potential aliasing and correctness issues.
Minimal fix:

```diff
-@torch.library.custom_op("auto_deploy::flashinfer_attention_mha_with_cache", mutates_args=("out",))
+@torch.library.custom_op(
+    "auto_deploy::flashinfer_attention_mha_with_cache",
+    mutates_args=("kv_cache", "out"),
+)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/auto_deploy/custom_ops/attention/flashinfer_attention.py` around lines 318-343, the decorator on flashinfer_mha_with_cache declares mutates_args=("out",) but the function mutates kv_cache in-place via flashinfer.page.append_paged_kv_cache(); update the torch.library.custom_op decorator to include "kv_cache" in mutates_args so tracing and CUDA graph capture see the correct aliasing for kv_cache.
tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/flashinfer_backend_mamba.py (1)
31-55: ⚠️ Potential issue | 🔴 Critical — Declare the cache mutation in the custom-op schema.
This op mutates `ssm_state_cache` both in `_run_ssm_prefill` (via `index_copy_` at line 88) and via `flashinfer.mamba.selective_state_update` at line 129. PyTorch requires all mutated tensors to be declared in `mutates_args`; otherwise behavior is undefined and `opcheck` will fail.
Minimal fix:

```diff
-@torch.library.custom_op("auto_deploy::flashinfer_cached_ssm", mutates_args=("out",))
+@torch.library.custom_op(
+    "auto_deploy::flashinfer_cached_ssm", mutates_args=("ssm_state_cache", "out")
+)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/flashinfer_backend_mamba.py` around lines 31-55, the custom op decorator for _flashinfer_cached_ssm only lists "out" in mutates_args but the op also mutates ssm_state_cache (in _run_ssm_prefill via index_copy_ and in flashinfer.mamba.selective_state_update); update the `@torch.library.custom_op` declaration to include "ssm_state_cache" alongside "out".
tensorrt_llm/_torch/auto_deploy/custom_ops/attention/torch_backend_attention.py (1)
345-392: ⚠️ Potential issue | 🟠 Major — Write the attention output into `out` directly.
`y` is allocated unconditionally on line 346 and then copied into `out` on lines 387-392. That preserves the same peak allocation as before for one of the largest tensors in piecewise capture, so this path can still OOM even with a preallocated buffer. Use `out.view(*bs_view, num_heads, v_head_dim)` as `y` when `out` is provided, and zero the padded tail in place.
♻️ Proposed change:

```diff
-    # Preallocate output tensor
-    y = q.new_empty(*bs_view, num_heads, v_head_dim).contiguous()
+    # Reuse caller-owned storage when available.
+    if out is not None:
+        y = out.view(*bs_view, num_heads, v_head_dim)
+    else:
+        y = q.new_empty(*bs_view, num_heads, v_head_dim).contiguous()
@@
-    if out is not None:
-        out_flat = out.view(*bs_view, num_heads, v_head_dim)
-        out_flat[:num_total_tokens].copy_(y[:num_total_tokens])
-        if num_total_tokens < bs:
-            out_flat[num_total_tokens:].zero_()
-        return out.new_empty(0)
-
-    # Zero padding positions so downstream ops don't see garbage (piecewise CG)
     if num_total_tokens < bs:
         y[num_total_tokens:].zero_()
+    if out is not None:
+        return out.new_empty(0)
     return y.view(*output_shape)
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/auto_deploy/custom_ops/attention/torch_backend_attention.py` around lines 345-392, the code unconditionally allocates y then copies into out, causing peak memory that can OOM; if out is provided, set y = out.view(*bs_view, num_heads, v_head_dim) (ensuring correct dtype/contiguity) and otherwise allocate with q.new_empty(...) as before, pass y into _torch_generate_mha/_torch_context_mha, and after attention zero the padded tail in place instead of copying.
tensorrt_llm/_torch/auto_deploy/custom_ops/fla/fla_backend_delta.py (1)
72-79: ⚠️ Potential issue | 🟠 Major — The delta-rule `out` path still pays for a full temporary tensor.
`y = torch.empty_like(v)` is always allocated and fully populated before its contents are copied into `out`. That means the new `out` parameter doesn't actually remove the extra outer output buffer for this op. Reuse `out` as the backing storage for `y_flat` when it's provided, and zero the padded tail in place.
♻️ Proposed change:

```diff
-    # pre-allocate output
-    y = torch.empty_like(v, memory_format=torch.contiguous_format)
+    # Reuse caller-owned storage when available.
+    y = out if out is not None else torch.empty_like(v, memory_format=torch.contiguous_format)
     y_flat = y.view(b * s, num_heads, -1)
@@
-    if out is not None:
-        out_flat = out.view(b * s, num_heads, -1)
-        out_flat[:num_total_tokens].copy_(y_flat[:num_total_tokens])
-        if num_total_tokens < b * s:
-            out_flat[num_total_tokens:].zero_()
-        return out.new_empty(0)
+    if num_total_tokens < b * s:
+        y_flat[num_total_tokens:].zero_()
+    if out is not None:
+        return out.new_empty(0)
     return y
```

Also applies to: 128-133
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/auto_deploy/custom_ops/fla/fla_backend_delta.py` around lines 72-79, the delta-rule path currently allocates y = torch.empty_like(v) and copies into out, negating the purpose of the out parameter; when out is not None, view out as y_flat (b * s, num_heads, -1) instead of allocating, write results directly into that view, zero any padded tail in place on the out buffer rather than creating temporaries, and apply the same change to the analogous section around lines 128-133.
tensorrt_llm/_torch/auto_deploy/compile/piecewise_utils.py (1)
1-18: ⚠️ Potential issue | 🟠 Major — Restore the required NVIDIA license header.
This modified Python source file now begins with the module docstring, but the repository rule requires the NVIDIA Apache-2.0 copyright block on all source files.
As per coding guidelines,
**/*.{cpp,h,cu,cuh,hpp,py}: All TensorRT-LLM source files should contain an NVIDIA copyright header with the year of the latest meaningful modification. The header should be an Apache 2.0 license block as specified.🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/auto_deploy/compile/piecewise_utils.py` around lines 1 - 18, Add the required NVIDIA Apache-2.0 license header at the very top of this source file (above the module docstring) using the year of the latest meaningful modification; ensure the header matches the repository's canonical NVIDIA copyright/license block and remains in place before any imports or docstrings so tools and scanners detect it. Locate this file by the module-level docstring and symbols such as GraphModule, Node, split_module, and ad_logger to confirm you're editing tensorrt_llm/_torch/auto_deploy/compile/piecewise_utils.py, then commit the file with the header prepended.
🧹 Nitpick comments (3)
tensorrt_llm/_torch/auto_deploy/custom_ops/mla/torch_backend_mla.py (1)
389-405: Use the external buffer as the helper output.
Both helper paths already accept an output tensor, but this branch still allocates `y` and then copies it into `out`. That keeps a full-sized temporary alive and adds an extra device copy on the new buffer-reuse path, which leaves part of the intended memory win on the table.
♻️ Suggested simplification:

```diff
+    out_flat = out.view(bs, num_heads, v_head_dim) if out is not None else None
+
     if s == 1:
         # =====================================================================
         # Generate phase: Use weight absorption
         # =====================================================================
-        y = q_nope.new_empty(b, num_heads, v_head_dim).contiguous()
+        y = out_flat if out_flat is not None else q_nope.new_empty(
+            b, num_heads, v_head_dim
+        ).contiguous()
         _torch_mla_generate_with_absorption(
             q_nope,
             q_pe,
@@
             y,
         )
     else:
@@
-        y = q_nope.new_empty(*bs_view, num_heads, v_head_dim).contiguous()
+        y = out_flat if out_flat is not None else q_nope.new_empty(
+            *bs_view, num_heads, v_head_dim
+        ).contiguous()
         _torch_mla_context_with_expansion(
             q_nope_flat,
             q_pe_flat,
@@
             y,
         )

     if out is not None:
-        out_flat = out.view(bs, num_heads, v_head_dim)
-        out_flat[:num_total_tokens].copy_(y[:num_total_tokens])
         if num_total_tokens < bs:
-            out_flat[num_total_tokens:].zero_()
+            y[num_total_tokens:].zero_()
         return out.new_empty(0)
```

Also applies to: 418-444
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/auto_deploy/custom_ops/mla/torch_backend_mla.py` around lines 389 - 405, The branch currently allocates a temporary tensor y and later copies it into out; instead reuse the external buffer out as the helper output to avoid the full-sized temp and extra device copy: replace the y allocation with a view/slice of out shaped (b, num_heads, v_head_dim) (ensuring contiguity if required) and pass that buffer into _torch_mla_generate_with_absorption instead of y, then remove the final copy into out; do the analogous change in the other helper branch (the block around lines 418-444 that calls the other MLA helper) so both helper calls directly write into out (refer to variables y, out, _torch_mla_generate_with_absorption and the other MLA helper name) and ensure shapes/contiguity match before calling.tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py (1)
31-55: Document the new `out` buffer contract on this custom op.
`out` now changes both buffer ownership and the observable return behavior, but there is still no docstring describing the expected shape/view-compatibility for `out`, or that callers should ignore the return value when `out` is provided. That is easy to misuse from `torch.ops.auto_deploy.triton_cached_ssm.default`.
As per coding guidelines, "For Python interfaces that may be used outside a file, prefer docstrings over comments" and "Use Google style docstrings for Python classes and functions, which can be parsed by Sphinx."
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py` around lines 31 - 55, Add a Google-style docstring to the custom op function _triton_cached_ssm (decorated via torch.library.custom_op) that documents the new out buffer contract: explicitly state the expected shape and view-compatibility requirements for out relative to the computed output (including any allowed broadcasting or required contiguous/strided layouts), describe that providing out transfers buffer ownership/observers semantics so the op will write into out in-place, and state callers MUST ignore the Python return value when supplying out (and instead read from the out buffer); also mention interaction with torch.ops.auto_deploy.triton_cached_ssm.default and any error/validation behavior when out has an incompatible shape or dtype.tensorrt_llm/_torch/auto_deploy/compile/backends/torch_cudagraph.py (1)
405-439: Remove or gate the temporary memory probes before merge.
The `TODO` on line 405 is still in the capture path, and `_install_mem_hooks()` adds per-segment `get_mem_info` calls and log spam for every bucket. Please either drop this from the PR or guard it behind a debug-only switch.
If you want, I can help split the memory-probe code into a follow-up debug-only change.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/auto_deploy/compile/backends/torch_cudagraph.py` around lines 405 - 439, The temporary memory-probe logic (_fmt_mem and _install_mem_hooks) is currently active and spams logs; either remove these methods and their calls or guard them behind a debug-only switch (e.g., a class/static flag MEMORY_PROBES_ENABLED or check ad_logger.isEnabledFor(logging.DEBUG) / an env var) so probes run only when explicitly enabled. Locate _install_mem_hooks (and any callers that attach its returned handles) and wrap the call with the guard or early-return when the flag is false; likewise make _fmt_mem and the get_mem_info calls no-ops unless the same flag is enabled (or remove them entirely). Ensure checks reference split_gm and ADPiecewiseRunner as before and keep ad_logger usage only under the debug gate so normal runs are not impacted.
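One way to gate the probes, as the comment above suggests, is an environment-variable switch checked at call time. This is a sketch only: `AD_MEM_PROBES`, `probe_mem`, and the injected `get_mem_info` callable are invented for illustration, not the real hooks in `torch_cudagraph.py`.

```python
import os

def mem_probes_enabled() -> bool:
    # Hypothetical debug switch; off by default so normal runs pay nothing.
    return os.environ.get("AD_MEM_PROBES", "0") == "1"

def probe_mem(tag, get_mem_info, log=print):
    """Query and log memory info for `tag` only when probing is enabled."""
    if not mem_probes_enabled():
        return None
    free, total = get_mem_info()
    log(f"[mem] {tag}: free={free} total={total}")
    return free, total

# With the switch off, the memory query is never even invoked.
os.environ.pop("AD_MEM_PROBES", None)
calls = []
assert probe_mem("capture", lambda: calls.append(1)) is None
assert calls == []
```

Gating at call time (rather than import time) also lets tests flip the switch without reloading the module.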
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tensorrt_llm/_torch/auto_deploy/compile/piecewise_runner.py`:
- Around line 200-225: The bug is that dynamic output info and buffers are
stored per-runner (fields _next_dynamic_out_info and
SegmentEntry.dynamic_out_buf) so when one ADPiecewiseRunner precedes multiple
DynamicOpWrapper instances the later shape-discovery results overwrite earlier
ones; change the storage to key by the dynamic submodule id: make
_next_dynamic_out_info a dict keyed by dynamic submodule identifier and change
SegmentEntry.dynamic_out_buf to a mapping of num_tokens->(submodule_id->buffer)
or similar, update get_dynamic_out_buf to accept/lookup the submodule id and
return the buffer for that id/num_tokens, and adjust prepare() in
PiecewiseCapturedGraph and any code paths that set/read these fields (including
the logic around lines referenced 248-264 and 311-338) to pass and use the
dynamic submodule id when storing and retrieving OutputInfo and buffers.
- Around line 234-241: The bug is that SegmentEntry() is created unconditionally
and inserted into self.entries[num_tokens] before checking phase, so a warmup
call (phase == "warmup") wrongly populates entries; move the allocation and
assignment of entry = SegmentEntry() and self.entries[num_tokens] = entry so
they occur only in the capture/replay branches (i.e., after the phase check),
leaving the warmup branch to just return self.submodule(*args, **kwargs) without
touching self.entries; update any code paths that reference entry to first fetch
it from self.entries when in capture or replay.
In `@tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_attention.py`:
- Around line 330-334: The fake output-shape logic is wrong: instead of using
v.shape[-1] directly, compute v_head_dim = v_cache.shape[-1] and num_heads =
q.shape[2] // qk_head_dim (matching the real op's qk head division) then build
output_shape = (*q.shape[:-1], num_heads * v_head_dim) and return
q.new_empty(*output_shape).contiguous() (still early-returning out if provided);
update the fake function to use these symbols (q, v_cache, qk_head_dim,
num_heads, v_head_dim, output_shape) so its output matches the real op's
y.view(*output_shape) behavior.
- Around line 293-298: The fake-registration path must return the same tensor
shape as the real Triton op when an output buffer is provided: change the fake
implementation to detect when the `out` argument is non-None and return an empty
tensor with the same type/device as `out` (i.e., equivalent to
`out.new_empty(0)`) instead of returning `out` itself; ensure the logic that
writes into `out_flat` (variables `out`, `out_flat`, `bs`, `num_heads`,
`v_head_dim`, `num_total_tokens`, `y`) remains unchanged so the fake propagation
matches the real op's return shape.
In `@tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py`:
- Around line 140-145: The fake-tensor branch currently returns the full buffer
`out` while the eager branch returns an empty tensor, causing shape/aliasing
mismatch; modify the fake implementation so that when `out` is not None you
mirror the eager behavior—apply the same padding-zeroing of
`preallocated_ssm_out` when `num_total_tokens < bs` (as done in the eager path)
and return an empty tensor created with `out.new_empty(0)` (preserving `out`'s
dtype/device) instead of returning `out`.
In `@tensorrt_llm/_torch/auto_deploy/custom_ops/mla/flashinfer_mla.py`:
- Around line 530-539: The current code rebinds compressed_kv_flat and kpe_flat
to FP8 which also alters the tensors used later for attention; instead create
new variables (e.g., compressed_kv_flat_for_cache and kpe_flat_for_cache),
perform the torch.float8_e4m3fn cast only on those when ckv_cache.dtype ==
torch.float8_e4m3fn, and pass those dedicated cast tensors into
flashinfer.page.append_paged_mla_kv_cache(compressed_kv_flat_for_cache[:num_total_tokens],
kpe_flat_for_cache[:num_total_tokens]) so the original compressed_kv_flat and
kpe_flat remain at activation precision for prefill paths.
- Around line 754-759: The fake implementation in flashinfer_mla currently
returns the full out buffer while the real CUDA path returns a 0-length sentinel
(out.new_empty(0)) when out is provided, causing a contract mismatch; change the
fake branch that handles the `out` argument to return the same 0-length sentinel
(use out.new_empty(0)) after performing any in-place modifications (e.g.,
y[num_total_tokens:].zero_()), ensuring variables referenced (out, y,
num_total_tokens, bs) and the function flashinfer_mla keep the same
side-effect-only behavior and return shape as the real implementation.
In `@tensorrt_llm/_torch/auto_deploy/custom_ops/utils/torch_gather_logits.py`:
- Around line 50-60: The code currently always clones hidden_states into result
even when an output buffer (out) is provided, negating buffer reuse; change the
logic so cloning (hidden_states.clone(memory_format=torch.contiguous_format))
only happens when out is None, otherwise create the reshaped view directly from
hidden_states (use the same contiguous/reshape logic used for result.view(...))
and then copy that view into out via out.copy_(); specifically modify the paths
around result, hidden_states, result.view(...), and out.copy_ so that clone is
skipped when out is present and cloning is only performed in the no-out branch.
In `@tensorrt_llm/_torch/auto_deploy/transform/library/compile_model.py`:
- Around line 99-101: The current code computes batch_capacity = (max_batch - 1)
* max_seq + 1 and then applies that ceiling to every piecewise_num_tokens
bucket, which incorrectly clips pure-prefill buckets; change the logic so
piecewise_enabled still generates buckets up to max_num_tokens for prefill-only
use, and only apply batch_capacity when constructing or validating buckets that
will be used for mixed-batch inference. Concretely, keep piecewise_num_tokens
(and any loop that builds it) able to reach max_num_tokens, and when you need a
mixed-batch token limit (e.g., when mixing prefill + decode or validating
mixed-batch requests) use min(bucket, batch_capacity) rather than mutating or
clamping the original bucket; reference variables: max_seq, max_batch,
batch_capacity, piecewise_enabled, piecewise_num_tokens, and max_num_tokens.
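The bucket fix just described can be reduced to a few lines. This is a hypothetical sketch (function name, step size, and numbers are invented): prefill-only buckets still reach `max_num_tokens`, and the mixed-batch ceiling is applied per use via `min()` rather than by clamping the bucket list.

```python
def piecewise_buckets(max_batch, max_seq, max_num_tokens, step):
    """Illustrative reduction of the suggested fix (names invented).

    Prefill-only buckets still reach max_num_tokens; the mixed-batch
    ceiling is applied per use via min(), not by mutating the list.
    """
    buckets = list(range(step, max_num_tokens + 1, step))
    # One full-length prefill plus (max_batch - 1) decode tokens.
    batch_capacity = (max_batch - 1) * max_seq + 1
    mixed = [min(b, batch_capacity) for b in buckets]
    return buckets, mixed

buckets, mixed = piecewise_buckets(max_batch=2, max_seq=8, max_num_tokens=32, step=8)
```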
---
Outside diff comments:
In `@tensorrt_llm/_torch/auto_deploy/compile/piecewise_utils.py`:
- Around line 1-18: Add the required NVIDIA Apache-2.0 license header at the
very top of this source file (above the module docstring) using the year of the
latest meaningful modification; ensure the header matches the repository's
canonical NVIDIA copyright/license block and remains in place before any imports
or docstrings so tools and scanners detect it. Locate this file by the
module-level docstring and symbols such as GraphModule, Node, split_module, and
ad_logger to confirm you're editing
tensorrt_llm/_torch/auto_deploy/compile/piecewise_utils.py, then commit the file
with the header prepended.
In
`@tensorrt_llm/_torch/auto_deploy/custom_ops/attention/flashinfer_attention.py`:
- Around line 318-343: The decorator on flashinfer_mha_with_cache currently
declares mutates_args=("out",) but the function mutates kv_cache in-place via
flashinfer.page.append_paged_kv_cache(), so update the torch.library.custom_op
decorator to include "kv_cache" in mutates_args (e.g.,
mutates_args=("out","kv_cache")) to accurately reflect side effects; ensure the
decorator signature around flashinfer_mha_with_cache is the only change so
tracing and CUDA graph capture see the correct aliasing for kv_cache.
In
`@tensorrt_llm/_torch/auto_deploy/custom_ops/attention/torch_backend_attention.py`:
- Around line 345-392: The code unconditionally allocates y then copies into
out, causing peak memory that can OOM; change allocation so that if out is
provided you set y = out.view(*bs_view, num_heads, v_head_dim) (ensuring correct
dtype/contiguity) and otherwise allocate q.new_empty(...) as before, then pass y
into _torch_generate_mha/_torch_context_mha and after attention zero the padded
tail in-place (out_flat[num_total_tokens:].zero_()) instead of copying; update
usages of y, out_flat, num_total_tokens, bs_view, num_heads, v_head_dim,
_torch_generate_mha and _torch_context_mha accordingly.
In `@tensorrt_llm/_torch/auto_deploy/custom_ops/fla/fla_backend_delta.py`:
- Around line 72-79: The delta-rule path currently allocates y =
torch.empty_like(v) and copies into out, negating the purpose of the out
parameter; modify the logic in fla_backend_delta.py (the block creating y and
y_flat from v and using batch_info_host -> num_prefill/num_decode) to reuse the
provided out buffer as backing storage: when out is not None, reshape/view out
into y_flat (matching b*s, num_heads, -1) instead of allocating
torch.empty_like(v), and write results directly into that view; ensure any
padded tail elements (from the last row if the flattened size isn’t a multiple)
are zeroed in-place on the out buffer rather than creating temporaries; apply
the same change to the analogous section around lines 128-133 that also
allocates y.
In
`@tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/flashinfer_backend_mamba.py`:
- Around line 31-55: The custom op decorator for _flashinfer_cached_ssm
currently only lists "out" in mutates_args but the op also mutates
ssm_state_cache (seen in _run_ssm_prefill via index_copy_ and in
flashinfer.mamba.selective_state_update); update the
`@torch.library.custom_op`("auto_deploy::flashinfer_cached_ssm",
mutates_args=(...)) declaration to include "ssm_state_cache" (alongside "out")
so PyTorch knows this tensor is mutated.
In `@tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_mamba.py`:
- Around line 127-150: The custom op decorator for _torch_cached_ssm incorrectly
only lists "out" in mutates_args while the function mutates ssm_state_cache (via
index_copy_); update the `@torch.library.custom_op` declaration for function
_torch_cached_ssm to include ssm_state_cache (or its tensor argument name) in
mutates_args so PyTorch knows it is mutated during tracing/CUDA graph capture;
locate the decorator above def _torch_cached_ssm and add the ssm_state_cache
identifier to the mutates_args tuple alongside "out".
In `@tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py`:
- Line 1: Update the NVIDIA SPDX copyright header at the top of
triton_backend_mamba.py to reflect the latest modification year 2026 (change
"2022-2025" to "2022-2026") so the file header matches the required policy.
---
Nitpick comments:
In `@tensorrt_llm/_torch/auto_deploy/compile/backends/torch_cudagraph.py`:
- Around line 405-439: The temporary memory-probe logic (_fmt_mem and
_install_mem_hooks) is currently active and spams logs; either remove these
methods and their calls or guard them behind a debug-only switch (e.g., a
class/static flag MEMORY_PROBES_ENABLED or check
ad_logger.isEnabledFor(logging.DEBUG) / an env var) so probes run only when
explicitly enabled. Locate _install_mem_hooks (and any callers that attach its
returned handles) and wrap the call with the guard or early-return when the flag
is false; likewise make _fmt_mem and the get_mem_info calls no-ops unless the
same flag is enabled (or remove them entirely). Ensure checks reference split_gm
and ADPiecewiseRunner as before and keep ad_logger usage only under the debug
gate so normal runs are not impacted.
In `@tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py`:
- Around line 31-55: Add a Google-style docstring to the custom op function
_triton_cached_ssm (decorated via torch.library.custom_op) that documents the
new out buffer contract: explicitly state the expected shape and
view-compatibility requirements for out relative to the computed output
(including any allowed broadcasting or required contiguous/strided layouts),
describe that providing out transfers buffer ownership/observers semantics so
the op will write into out in-place, and state callers MUST ignore the Python
return value when supplying out (and instead read from the out buffer); also
mention interaction with torch.ops.auto_deploy.triton_cached_ssm.default and any
error/validation behavior when out has an incompatible shape or dtype.
In `@tensorrt_llm/_torch/auto_deploy/custom_ops/mla/torch_backend_mla.py`:
- Around line 389-405: The branch currently allocates a temporary tensor y and
later copies it into out; instead reuse the external buffer out as the helper
output to avoid the full-sized temp and extra device copy: replace the y
allocation with a view/slice of out shaped (b, num_heads, v_head_dim) (ensuring
contiguity if required) and pass that buffer into
_torch_mla_generate_with_absorption instead of y, then remove the final copy
into out; do the analogous change in the other helper branch (the block around
lines 418-444 that calls the other MLA helper) so both helper calls directly
write into out (refer to variables y, out, _torch_mla_generate_with_absorption
and the other MLA helper name) and ensure shapes/contiguity match before
calling.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: b08a4a2f-ddf6-4622-86c5-48f2f568cc31
📒 Files selected for processing (21)
examples/auto_deploy/model_registry/configs/qwen3.5_moe_35b.yaml
tensorrt_llm/_torch/auto_deploy/compile/backends/torch_cudagraph.py
tensorrt_llm/_torch/auto_deploy/compile/piecewise_runner.py
tensorrt_llm/_torch/auto_deploy/compile/piecewise_utils.py
tensorrt_llm/_torch/auto_deploy/config/default.yaml
tensorrt_llm/_torch/auto_deploy/custom_ops/attention/flashinfer_attention.py
tensorrt_llm/_torch/auto_deploy/custom_ops/attention/torch_backend_attention.py
tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_attention.py
tensorrt_llm/_torch/auto_deploy/custom_ops/attention/trtllm_attention.py
tensorrt_llm/_torch/auto_deploy/custom_ops/fla/fla_backend_delta.py
tensorrt_llm/_torch/auto_deploy/custom_ops/fla/fla_backend_gated_delta.py
tensorrt_llm/_torch/auto_deploy/custom_ops/fla/torch_backend_gated_delta.py
tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/flashinfer_backend_mamba.py
tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_mamba.py
tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py
tensorrt_llm/_torch/auto_deploy/custom_ops/mla/flashinfer_mla.py
tensorrt_llm/_torch/auto_deploy/custom_ops/mla/torch_backend_mla.py
tensorrt_llm/_torch/auto_deploy/custom_ops/utils/torch_gather_logits.py
tensorrt_llm/_torch/auto_deploy/transform/library/compile_model.py
tests/unittest/_torch/auto_deploy/unit/singlegpu/compile/test_captured_graph.py
tests/unittest/_torch/auto_deploy/unit/singlegpu/compile/test_piecewise_runner.py
Resolved review threads:
- tensorrt_llm/_torch/auto_deploy/custom_ops/attention/triton_attention.py (2 threads)
- tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py
- tensorrt_llm/_torch/auto_deploy/custom_ops/mla/flashinfer_mla.py (outdated)
- tensorrt_llm/_torch/auto_deploy/custom_ops/utils/torch_gather_logits.py (outdated)
Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
Force-pushed from d149b0c to d808289
/bot run --extra-stage "DGX_B200-4_GPUs-AutoDeploy-1, DGX_H100-4_GPUs-AutoDeploy-1" --disable-fail-fast
PR_Github #38924 [ run ] triggered by Bot. Commit:
Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
/bot run --extra-stage "DGX_B200-4_GPUs-AutoDeploy-1, DGX_H100-4_GPUs-AutoDeploy-1" --disable-fail-fast
PR_Github #38932 [ run ] triggered by Bot. Commit:
PR_Github #38932 [ run ] completed with state
PR_Github #38924 [ run ] completed with state
/bot run --extra-stage "DGX_B200-4_GPUs-AutoDeploy-1, DGX_H100-4_GPUs-AutoDeploy-1" --disable-fail-fast
PR_Github #39292 [ run ] triggered by Bot. Commit:
PR_Github #39292 [ run ] completed with state
Merge into …mory_0306 — Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>; Made-with: Cursor. Conflicts: tensorrt_llm/_torch/auto_deploy/custom_ops/attention/trtllm_attention.py
/bot run --extra-stage "DGX_B200-4_GPUs-AutoDeploy-1, DGX_H100-4_GPUs-AutoDeploy-1" --disable-fail-fast
PR_Github #39323 [ run ] triggered by Bot. Commit:
PR_Github #39323 [ run ] completed with state
Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
Merge into …mory_0306 — Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>; Made-with: Cursor. Conflicts: tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/torch_backend_mamba.py, tensorrt_llm/_torch/auto_deploy/custom_ops/mamba/triton_backend_mamba.py
/bot run --extra-stage "DGX_B200-4_GPUs-AutoDeploy-1, DGX_H100-4_GPUs-AutoDeploy-1" --disable-fail-fast
Some numbers from local test — GLM 4.7 Flash 1k/1k
Nano v3 fp8 1k/1k
PR_Github #39481 [ run ] triggered by Bot. Commit:
PR_Github #39481 [ run ] completed with state
NVIDIA#11993) Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>

This PR improves memory usage during piecewise CUDA graph capture.
Several main updates:
With these changes, the memory usage of the piecewise CUDA graph is under control, and piecewise CG can be used for most cases.
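The warmup/capture/replay flow described in the walkthrough can be sketched as follows. This is a hypothetical, simplified sketch (class and method names such as `PiecewiseRunnerSketch` are illustrative, not the real `ADPiecewiseRunner` API; NumPy stands in for CUDA tensors): warmup runs the op eagerly to discover output shape/dtype, capture allocates the output buffer once, and replay reuses that buffer so no per-step allocation happens.

```python
import numpy as np

class PiecewiseRunnerSketch:
    """Illustrative sketch of warmup -> capture -> replay output buffering."""

    def __init__(self, op):
        self.op = op
        self.out_info = None   # (shape, dtype) discovered during warmup
        self.out_buf = None    # allocated once at capture time
        self.phase = "warmup"

    def set_current_phase(self, phase):
        self.phase = phase
        if phase == "capture" and self.out_buf is None:
            shape, dtype = self.out_info
            self.out_buf = np.empty(shape, dtype=dtype)  # one-time allocation

    def run(self, *args):
        if self.phase == "warmup":
            y = self.op(*args)                 # eager run to learn shape/dtype
            self.out_info = (y.shape, y.dtype)
            return y
        np.copyto(self.out_buf, self.op(*args))  # capture/replay reuse buffer
        return self.out_buf

runner = PiecewiseRunnerSketch(lambda a, b: a + b)
a, b = np.ones((2, 3)), np.ones((2, 3))
runner.run(a, b)                       # warmup discovers (2, 3) float64
runner.set_current_phase("capture")
buf = runner.run(a, b)
runner.set_current_phase("replay")
assert runner.run(a, b) is buf         # same buffer reused on every replay
```

Because the buffer is allocated exactly once per shape, repeated replays add no allocation pressure, which is the core of the memory improvement.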
Performance also improves slightly as a result of the changes -

For Qwen 3.5, ISL 1k / OSL 1k, on H100 -
The next step for piecewise CG: the final GEMM for prefill is very large, and AutoDeploy should enable the gather_logits transform by default. That region is currently marked as dynamic because the GEMM's shape is unknown; this will be addressed in a follow-up task. The plan is to pad the final GEMM to max_batch_size, which needs some further investigation. #12049 tracks the work.
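The padding idea for the follow-up can be sketched as below. This is a hypothetical illustration, not the planned implementation: `MAX_BATCH_SIZE` and `padded_lm_head` are invented names, and NumPy stands in for the real GEMM. Padding the hidden states to a fixed row count makes the final GEMM's shape static, so the region no longer has to be marked dynamic.

```python
import numpy as np

MAX_BATCH_SIZE = 8  # assumed config value; the real limit comes from the runtime

def padded_lm_head(hidden, weight, num_tokens):
    """Pad hidden states to MAX_BATCH_SIZE rows so the final GEMM always runs
    at a fixed, capturable shape, then slice the padding back off."""
    padded = np.zeros((MAX_BATCH_SIZE, hidden.shape[1]), dtype=hidden.dtype)
    padded[:num_tokens] = hidden[:num_tokens]
    logits = padded @ weight            # fixed-shape GEMM, graph-friendly
    return logits[:num_tokens]          # discard the padded rows

hidden = np.random.rand(3, 4)           # only 3 valid tokens this step
weight = np.random.rand(4, 5)
out = padded_lm_head(hidden, weight, 3)
assert out.shape == (3, 5)
```

The trade-off is wasted FLOPs on the padded rows, which is why the description notes this still needs investigation.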