[None][perf] AutoDeploy: reduce C++ dispatch overhead in decode scheduling loop #13012

Merged: nvchenghaoz merged 6 commits into NVIDIA:main from nv-auto-deploy:chenghao/ad-cpu-scheduling-0413 on May 5, 2026

Conversation

@nvchenghaoz
Collaborator

@nvchenghaoz nvchenghaoz commented Apr 13, 2026

Summary

Four targeted changes to reduce Python/C++ API call overhead per decode step in `ADEngine._prepare_inputs()`.

  • `get_last_tokens` instead of `get_token` + `get_num_tokens`: replace the two-call pattern `get_token(0, get_num_tokens(0) - 1)` with `get_last_tokens(0)` — retrieves the last committed token in one C++ round-trip, saving one call per non-overlap decode request per step. Also aligns `ad_executor.py` with the pattern already used in `pyexecutor/model_engine.py`.

  • Inline `get_num_kv_blocks`: replace `kv_cache_manager.get_num_kv_blocks(end_compute_i)` with the equivalent pure-Python ceildiv expression, saving one Python method call per request per step.

  • Hoist mamba `hasattr` check: move `hasattr(kv_cache_manager, "mamba_cache_index")` outside the per-request loop into `_use_mamba` to avoid repeated attribute lookup.

  • Remove redundant `position_ids` computation in `SequenceInfo.nest_sequences()`: `position_ids` were computed here and then recomputed downstream. Removing the duplicate saves one NumPy array construction and one `torch.from_numpy()` call on every prefill step.
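
The three `ad_executor.py` changes above can be sketched in isolation. This is a minimal illustration, not the PR's code: `FakeRequest`, `FakeMambaManager`, `ceil_div`, and `plan_step` are hypothetical stand-ins, and the exact ceildiv expression used in the PR is assumed rather than quoted.

```python
class FakeRequest:
    """Hypothetical stand-in for a decode request; counts accessor calls."""

    def __init__(self, tokens):
        self._tokens = list(tokens)
        self.calls = 0  # each call models one Python -> C++ round-trip

    def get_num_tokens(self, beam):
        self.calls += 1
        return len(self._tokens)

    def get_token(self, beam, pos):
        self.calls += 1
        return self._tokens[pos]

    def get_last_tokens(self, beam):
        self.calls += 1
        return self._tokens[-1]


class FakeMambaManager:
    """Hypothetical cache manager exposing the mamba-specific attribute."""

    mamba_cache_index = {}


def ceil_div(n, d):
    """Ceiling division for non-negative ints (assumed form of the inlined expression)."""
    return (n + d - 1) // d


def plan_step(requests, kv_cache_manager, tokens_per_block):
    # Loop-invariant attribute check is evaluated once, outside the loop.
    use_mamba = hasattr(kv_cache_manager, "mamba_cache_index")
    last_tokens, num_blocks = [], []
    for req, end_compute in requests:
        # One accessor call instead of get_token(0, get_num_tokens(0) - 1).
        last_tokens.append(req.get_last_tokens(0))
        # Arithmetic done inline instead of calling into the cache manager.
        num_blocks.append(ceil_div(end_compute, tokens_per_block))
    return last_tokens, num_blocks, use_mamba
```

Calling `plan_step([(FakeRequest([11, 22, 33]), 65)], FakeMambaManager(), 32)` touches the request once, whereas the two-call pattern increments `calls` twice for the same token.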

Perf

Benchmarked on Gemma4 26B MoE (single H100, ISL=1000/OSL=1000, no MLIR):

| Concurrency | Baseline ITL | Optimized ITL | Delta |
|---|---|---|---|
| 1 | 6.67 ms | 6.61 ms | −0.06 ms (−0.9%) |
| 4 | 9.23 ms | 9.24 ms | within noise (std = 0.03 ms) |
| 16 | 13.82 ms | 13.70 ms | within noise (std = 0.14 ms) |

The ~60 µs absolute saving shows up consistently at concurrency 1, where the GPU step is short enough that CPU scheduling is on the critical path. At higher concurrencies the GPU step dominates and scheduling overhead is fully pipelined.
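
The deltas quoted above can be re-derived from the ITL figures. The sketch below uses only the numbers in the table; reading "within noise" as `|delta| < std` is an assumption, and `summarize` is a hypothetical helper.

```python
# (concurrency, baseline ITL ms, optimized ITL ms, reported std ms or None)
rows = [
    (1, 6.67, 6.61, None),
    (4, 9.23, 9.24, 0.03),
    (16, 13.82, 13.70, 0.14),
]


def summarize(baseline_ms, opt_ms, std_ms):
    """Return (delta in microseconds, delta in percent, within-noise flag)."""
    delta_ms = opt_ms - baseline_ms
    pct = 100.0 * delta_ms / baseline_ms
    # Assumption: a change counts as noise when it is smaller than the reported std.
    within_noise = std_ms is not None and abs(delta_ms) < std_ms
    return round(delta_ms * 1000), round(pct, 1), within_noise


for conc, base, opt, std in rows:
    print(conc, summarize(base, opt, std))
```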

Test plan

  • `pytest tests/unittest/auto_deploy/` — no regressions
  • Gemma4 26B MoE decode benchmark — neutral to slightly positive

🤖 Generated with Claude Code

@nvchenghaoz nvchenghaoz requested a review from a team as a code owner April 13, 2026 18:34
@nvchenghaoz nvchenghaoz marked this pull request as draft April 13, 2026 18:38
@coderabbitai
Contributor

coderabbitai Bot commented Apr 13, 2026

Caution

Review failed

The head commit changed during the review from 23a78a7 to abb593f.

📝 Walkthrough

Walkthrough

These changes optimize position ID handling and KV cache computation in the auto-deployment system by deferring position ID staging and introducing caching for input positions alongside batched KV cache index lookups.

Changes

| Cohort / File(s) | Summary |
|---|---|
| Position ID Deferred Staging: `tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py` | Removed eager computation and staging of `position_ids` during `nest_sequences()`; now only produced later when explicitly required via `self._is_required("position_ids")`. |
| Input Position Caching & KV Cache Optimization: `tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py` | Added per-instance `_gen_input_pos_cache` memoization for generation requests. Updated `_prepare_inputs` to derive `num_tokens_seen` from cached input position for non-overlap, non-draft requests. Introduced batched KV cache index lookup via `get_batch_cache_indices()` and replaced `get_num_kv_blocks()` call with inline Python computation. Refactored KV-cache handling control flow with `_use_mamba` flag and simplified token selection using `get_last_tokens()`. |

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

🚥 Pre-merge checks | ✅ 3
✅ Passed checks (3 passed)
| Check name | Status | Explanation |
|---|---|---|
| Title check | ✅ Passed | The title clearly and specifically describes the main change: reducing C++ dispatch overhead in AutoDeploy's decode scheduling loop. |
| Docstring Coverage | ✅ Passed | Docstring coverage is 100.00%, which is sufficient. The required threshold is 80.00%. |
| Description check | ✅ Passed | The PR description clearly explains the changes, includes detailed performance benchmarks, and references test coverage with specific pytest commands and benchmark results. |


Contributor

@coderabbitai coderabbitai Bot left a comment

🧹 Nitpick comments (1)
tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py (1)

837-855: Use isinstance(kv_cache_manager, MambaHybridCacheManager) instead of hasattr() for mamba-specific guards.

The hasattr(kv_cache_manager, "mamba_cache_index") check at line 837 works today but is fragile. The attribute is unique to MambaHybridCacheManager and not exposed on the base KVCacheManager class. However, a direct type check with isinstance() is more explicit, safer if the attribute ever becomes a placeholder elsewhere, and consistent with how other call sites in the codebase guard mamba-specific access (e.g., disaggregation/transceiver.py, kv_extractor.py, _util.py).

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py` around lines 837 - 855,
Replace the fragile hasattr guard with an explicit type check: instead of
_use_mamba = hasattr(kv_cache_manager, "mamba_cache_index") use _use_mamba =
isinstance(kv_cache_manager, MambaHybridCacheManager); update the import or
reference so MambaHybridCacheManager is available in this module and keep the
rest of the logic (uses of kv_cache_manager.mamba_cache_index and fallbacks)
unchanged to match other call sites like disaggregation/transceiver.py and
kv_extractor.py.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro Plus

Run ID: 9bf5195f-cffa-4566-a54a-1f813fc17edb

📥 Commits

Reviewing files that changed from the base of the PR and between fc2ba32 and 74c3bf7.

📒 Files selected for processing (2)
  • tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py
  • tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py
💤 Files with no reviewable changes (1)
  • tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py

@nvchenghaoz nvchenghaoz force-pushed the chenghao/ad-cpu-scheduling-0413 branch from 74c3bf7 to e368317 Compare April 13, 2026 20:58
Four low-overhead changes that reduce Python/C++ dispatch cost per decode
step in the ADEngine scheduling loop:

1. get_last_tokens instead of get_token + get_num_tokens (ad_executor.py):
   Replace the two-call pattern get_token(0, get_num_tokens(0) - 1) with a
   single get_last_tokens(0) call. Saves one C++ round-trip per non-overlap
   decode request per step. Aligns with the pattern in model_engine.py.

2. Inline get_num_kv_blocks (ad_executor.py):
   Replace kv_cache_manager.get_num_kv_blocks(end_compute_i) with the
   equivalent pure-Python ceildiv expression, saving a function-call overhead
   per request per step.

3. Hoist mamba hasattr check (ad_executor.py):
   Move hasattr(kv_cache_manager, "mamba_cache_index") outside the per-request
   loop into _use_mamba to avoid repeated attribute lookup.

4. Remove redundant position_ids computation (attention_interface.py):
   position_ids were computed in SequenceInfo.nest_sequences() and then
   recomputed again downstream. Removing the duplicate saves one NumPy array
   construction and torch.from_numpy() call on every prefill step.

Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
@nvchenghaoz nvchenghaoz force-pushed the chenghao/ad-cpu-scheduling-0413 branch from e368317 to 23a78a7 Compare April 14, 2026 01:06
Comment thread tensorrt_llm/_torch/auto_deploy/shim/ad_executor.py Outdated
@nvchenghaoz
Collaborator Author

@coderabbitai review the changes

@coderabbitai
Contributor

coderabbitai Bot commented Apr 14, 2026

@nvchenghaoz: I'll review the changes in this PR right away!

✅ Actions performed

Review triggered.

Note: CodeRabbit is an incremental review system and does not re-review already reviewed commits. This command is applicable only when automatic reviews are paused.

- Remove explanatory comment on get_last_tokens (self-evident)
- Use isinstance(kv_cache_manager, MambaHybridCacheManager) instead of
  hasattr() for the mamba guard — more explicit and consistent with
  other call sites in ad_executor.py (line 295) and the codebase

Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
@nvchenghaoz
Collaborator Author

Addressed CodeRabbit's `hasattr` → `isinstance` suggestion: `MambaHybridCacheManager` is already imported in this file and used with `isinstance` at line 295, so the change is a one-liner and consistent with the rest of the file.

@nvchenghaoz nvchenghaoz marked this pull request as ready for review April 14, 2026 02:17
@nvchenghaoz
Collaborator Author

/bot run --extra-stage "DGX_B200-4_GPUs-AutoDeploy-1, DGX_H100-4_GPUs-AutoDeploy-1"

@tensorrt-cicd
Collaborator

PR_Github #43144 [ run ] triggered by Bot. Commit: abb593f Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #43144 [ run ] completed with state FAILURE. Commit: abb593f
/LLM/main/L0_MergeRequest_PR pipeline #33776 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

…guard

isinstance(kv_cache_manager, MambaHybridCacheManager) breaks unit tests
that use _DummyHybridKVCacheManager, which cannot subclass the real class
due to its complex constructor (~15 required params). hasattr() is the
correct duck-typing approach here.

Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
@nvchenghaoz
Collaborator Author

Tried switching to isinstance(kv_cache_manager, MambaHybridCacheManager) per CodeRabbit's suggestion, but it breaks test_ad_engine_prepare_inputs_with_hybrid_cache_manager — the test uses _DummyHybridKVCacheManager which cannot subclass MambaHybridCacheManager due to its ~15 required constructor params. Reverted to hasattr(kv_cache_manager, "mamba_cache_index"). The attribute name is specific enough to serve as a safe duck-typing guard.
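
The failure mode described here can be reproduced with toy classes. In this sketch every class is a hypothetical stand-in (none is the real `MambaHybridCacheManager` or the test's `_DummyHybridKVCacheManager`); it only illustrates why a type check rejects a duck-typed test double that an attribute check accepts.

```python
class RealHybridManager:
    """Hypothetical stand-in for a manager with a heavy constructor."""

    def __init__(self, num_layers, num_heads, head_dim, max_batch_size):
        self.mamba_cache_index = {}


class DummyHybridManager:
    """Test double: exposes the attribute without inheriting from the real class."""

    def __init__(self):
        self.mamba_cache_index = {0: 7}


class PlainManager:
    """Manager with no mamba state at all."""


def uses_mamba_by_type(mgr):
    # Nominal check: only true for RealHybridManager and its subclasses.
    return isinstance(mgr, RealHybridManager)


def uses_mamba_by_attr(mgr):
    # Duck-typed check: true for any object that carries the attribute.
    return hasattr(mgr, "mamba_cache_index")


dummy = DummyHybridManager()
print(uses_mamba_by_type(dummy), uses_mamba_by_attr(dummy))
```

The trade-off is the one raised in the review: the attribute check keeps lightweight test doubles working, while the type check would be stricter if the attribute name ever appeared on an unrelated manager.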

@nvchenghaoz
Collaborator Author

/bot run --extra-stage "DGX_B200-4_GPUs-AutoDeploy-1, DGX_H100-4_GPUs-AutoDeploy-1"

@tensorrt-cicd
Collaborator

PR_Github #43251 [ run ] triggered by Bot. Commit: d83ce44 Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #43251 [ run ] completed with state DISABLED
Freeze main and open the PR merge only after CI is back to healthy https://nvidia.slack.com/archives/C059LSY62BT/p1776141760843319?thread_ts=1775985925.442509&cid=C059LSY62BT

Link to invocation

@nvchenghaoz
Collaborator Author

/bot run --extra-stage "DGX_B200-4_GPUs-AutoDeploy-1, DGX_H100-4_GPUs-AutoDeploy-1"

@tensorrt-cicd
Collaborator

PR_Github #43541 [ run ] triggered by Bot. Commit: d83ce44 Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #43541 [ run ] completed with state FAILURE. Commit: d83ce44
/LLM/main/L0_MergeRequest_PR pipeline #34046 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@nvchenghaoz
Collaborator Author

/bot run --extra-stage "DGX_B200-4_GPUs-AutoDeploy-1, DGX_H100-4_GPUs-AutoDeploy-1"

@tensorrt-cicd
Collaborator

PR_Github #43817 [ run ] triggered by Bot. Commit: da730ae Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #43817 [ run ] completed with state SUCCESS. Commit: da730ae
/LLM/main/L0_MergeRequest_PR pipeline #34293 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@nvchenghaoz
Collaborator Author

/bot run --stage-list "A10-Build_Docs, A10-PackageSanityCheck-PY310-UB2204, A100X-PackageSanityCheck-PY312-UB2404, A30-AutoDeploy-1, H100_PCIe-AutoDeploy-1, DGX_B200-AutoDeploy-1, A100X-PyTorch-1, DGX_H100-4_GPUs-AutoDeploy-1, DGX_B200-4_GPUs-AutoDeploy-1, DGX_H100-4_GPUs-AutoDeploy-Post-Merge-1, DGX_B200-8_GPUs-AutoDeploy-Post-Merge-1" --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #46689 [ run ] triggered by Bot. Commit: 0ffb308 Link to invocation

Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
@nvchenghaoz
Collaborator Author

/bot run --stage-list "A10-Build_Docs, A10-PackageSanityCheck-PY310-UB2204, A100X-PackageSanityCheck-PY312-UB2404, A30-AutoDeploy-1, H100_PCIe-AutoDeploy-1, DGX_B200-AutoDeploy-1, A100X-PyTorch-1, DGX_H100-4_GPUs-AutoDeploy-1, DGX_B200-4_GPUs-AutoDeploy-1, DGX_H100-4_GPUs-AutoDeploy-Post-Merge-1, DGX_B200-8_GPUs-AutoDeploy-Post-Merge-1" --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #46706 [ run ] triggered by Bot. Commit: b1f6fec Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #46706 [ run ] completed with state SUCCESS. Commit: b1f6fec
/LLM/main/L0_MergeRequest_PR pipeline #36741 (Partly Tested) completed with status: 'SUCCESS'

CI Report

Link to invocation

@nvchenghaoz
Collaborator Author

/bot skip --comment "AD pipeline passed"

@tensorrt-cicd
Collaborator

PR_Github #46837 [ skip ] triggered by Bot. Commit: b1f6fec Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #46837 [ skip ] completed with state SUCCESS. Commit: b1f6fec
Skipping testing for commit b1f6fec

Link to invocation

@nvchenghaoz nvchenghaoz merged commit 92dde9e into NVIDIA:main May 5, 2026
6 checks passed
eopXD added a commit to eopXD/TensorRT-LLM that referenced this pull request May 6, 2026
Resolve conflict in ad_executor.py: combine the dual-pool VSWA prelude
(kv_group_windows/is_vswa + per-pool dicts) from sg/swa-take2 with the
hot-loop locals (_tokens_per_block, _use_mamba) introduced by NVIDIA#13012,
and apply NVIDIA#13012's inlined num_active_blocks_i formula in the non-VSWA
branch so the perf win still applies on single-pool models.

Signed-off-by: Yueh-Ting Chen <yueh.ting.chen@gmail.com>
yufeiwu-nv pushed a commit to yufeiwu-nv/TensorRT-LLM that referenced this pull request May 19, 2026
…uling loop (NVIDIA#13012)

Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>