[https://nvbugs/6094224][fix] Fix mamba disagg issues when conc > mbs #13274

Merged
bo-nv merged 8 commits into NVIDIA:main from bo-nv:main-6094224-fix
May 6, 2026

Conversation

@bo-nv
Collaborator

@bo-nv bo-nv commented Apr 21, 2026

Summary by CodeRabbit

  • Bug Fixes
    • Improved Mamba cache management to prevent context request overflows by filtering requests against available cache capacity during scheduling.

Description

Test Coverage

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • Update tava architecture diagram if there is a significant design change in PR.

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

To see a list of available CI bot commands, please comment /bot help.

bo-nv added 4 commits April 20, 2026 14:35
Signed-off-by: Bo Deng <deemod@nvidia.com>
Signed-off-by: Bo Deng <deemod@nvidia.com>
Signed-off-by: Bo Deng <deemod@nvidia.com>
@bo-nv bo-nv requested review from a team as code owners April 21, 2026 10:53
@bo-nv bo-nv self-assigned this Apr 21, 2026
@bo-nv bo-nv requested a review from lfr-0531 April 21, 2026 10:53
@bo-nv bo-nv marked this pull request as draft April 21, 2026 10:53
@bo-nv
Collaborator Author

bo-nv commented Apr 21, 2026

/bot run --disable-fail-fast

@coderabbitai
Contributor

coderabbitai Bot commented Apr 21, 2026

Walkthrough

Added filter_ctx_requests_by_capacity() method to MambaCacheManager and PythonMambaCacheManager to filter context requests based on available Mamba cache blocks. Modified _schedule() in PyExecutor to filter scheduled context requests when using MambaHybridCacheManager with kv_cache_transceiver enabled, adjusting the returned fitting request count accordingly.

Changes

| Cohort / File(s) | Summary |
|---|---|
| Mamba Cache Filtering: tensorrt_llm/_torch/pyexecutor/mamba_cache_manager.py | Added filter_ctx_requests_by_capacity() method to both PythonMambaCacheManager and MambaCacheManager wrapper. The Python implementation iterates through context requests and returns a capacity-constrained prefix, excluding blocks already in the cache index. The wrapper delegates to the Python implementation when not using C++ backend. |
| Scheduler Integration: tensorrt_llm/_torch/pyexecutor/py_executor.py | Added import of MambaHybridCacheManager. Modified _schedule() to filter scheduled context requests via the new filter_ctx_requests_by_capacity() method when conditions are met (using MambaHybridCacheManager with kv_cache_transceiver enabled), adjusting the returned num_fitting_reqs accordingly. |
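
As a rough illustration of the prefix filtering described in the walkthrough, here is a minimal sketch. It is not the PR's code: `ToyRequest`, `ToyMambaCache`, `free_blocks`, `cache_index`, and the one-block-per-request rule are all hypothetical stand-ins chosen for the example. Only the method name `filter_ctx_requests_by_capacity` comes from the summary above.

```python
from dataclasses import dataclass, field


@dataclass
class ToyRequest:
    """Hypothetical stand-in for a context request."""
    request_id: int


@dataclass
class ToyMambaCache:
    """Toy model of a Mamba state cache (assumption: one block per request)."""
    free_blocks: int
    # Hypothetical index of requests that already own a block.
    cache_index: dict = field(default_factory=dict)

    def filter_ctx_requests_by_capacity(self, context_requests):
        """Return the longest prefix of requests that fits in the free blocks."""
        remaining = self.free_blocks
        fitting = []
        for req in context_requests:
            if req.request_id in self.cache_index:
                # Already holds a block, so it consumes no new capacity.
                fitting.append(req)
                continue
            if remaining == 0:
                # Stop at the first request that does not fit: keep a prefix.
                break
            remaining -= 1
            fitting.append(req)
        return fitting


cache = ToyMambaCache(free_blocks=2, cache_index={7: 0})
requests = [ToyRequest(7), ToyRequest(8), ToyRequest(9), ToyRequest(10)]
kept = cache.filter_ctx_requests_by_capacity(requests)
print([r.request_id for r in kept])
```

With two free blocks and request 7 already indexed, the sketch prints `[7, 8, 9]`: request 10 is dropped because no block is left for it. Stopping at the first non-fitting request, rather than skipping it, preserves the scheduler's ordering in this toy model.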

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~10 minutes

🚥 Pre-merge checks | ✅ 3 | ❌ 2

❌ Failed checks (2 warnings)

| Check name | Status | Explanation | Resolution |
|---|---|---|---|
| Description check | ⚠️ Warning | The PR description is incomplete, containing only the template structure with empty sections for Description and Test Coverage, and lacks explanation of the issue and solution. | Fill in the Description section explaining the mamba disagg issue and how the cache filtering fixes it, and complete the Test Coverage section listing relevant tests. |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 37.50% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (3 passed)

| Check name | Status | Explanation |
|---|---|---|
| Title check | ✅ Passed | The title clearly references the NVBugs ID and describes the main fix: addressing mamba disagg issues when concurrency exceeds max batch size. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |

Contributor

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 2

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tensorrt_llm/_torch/pyexecutor/mamba_cache_manager.py`:
- Around line 579-582: The C++ path in filter_ctx_requests_by_capacity
short-circuits when self._use_cpp is true, bypassing the new capacity guard;
either implement the same capacity-filtering logic in the C++ binding class
(e.g., add a filter_ctx_requests_by_capacity method to CppMambaCacheManager and
expose it via the bindings) so PyExecutor._schedule will call into the C++
check, or detect the unsupported combination and fail fast by raising a clear
error when _use_cpp is true and disaggregated scheduling is requested (check
TRTLLM_USE_CPP_MAMBA environment flag or the manager type) so you don't return
unfiltered context_requests. Ensure you reference and update
filter_ctx_requests_by_capacity, CppMambaCacheManager (and its bindings), and
the call site in PyExecutor._schedule accordingly.

In `@tensorrt_llm/_torch/pyexecutor/py_executor.py`:
- Around line 2947-2959: The code currently overwrites num_fitting (set from
scheduler_output.num_fitting_requests) with the filtered context count after
calling kv_cache_manager.filter_ctx_requests_by_capacity, which loses the
original "did anything fit?" signal; revert that overwrite so num_fitting
remains scheduler_output.num_fitting_requests, still apply the filtered
scheduled_context_requests to ScheduledRequests.reset_context_requests, and if
you need the filtered count keep it in a separate local (e.g.,
filtered_ctx_count) rather than assigning it back to num_fitting; refer to the
variables/methods num_fitting, scheduler_output.num_fitting_requests,
scheduled_context_requests, kv_cache_manager.filter_ctx_requests_by_capacity,
and ScheduledRequests.reset_context_requests to locate and fix the code.
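
The two review findings above can be sketched together in a toy model. This is only an illustration under stated assumptions: `ToySchedulerOutput`, `ToyCacheManager`, `schedule_step`, `free_slots`, and `use_cpp` are hypothetical names, and the real `PyExecutor._schedule` and cache manager classes are not reproduced here. The conversation does not state whether the merged change adopted either suggestion.

```python
from dataclasses import dataclass


@dataclass
class ToySchedulerOutput:
    """Hypothetical stand-in for the scheduler's result."""
    context_requests: list
    num_fitting_requests: int


class ToyCacheManager:
    """Hypothetical manager; `use_cpp` mimics a C++-backed path."""

    def __init__(self, free_slots, use_cpp=False):
        self.free_slots = free_slots
        self.use_cpp = use_cpp

    def filter_ctx_requests_by_capacity(self, context_requests):
        if self.use_cpp:
            # Fail fast instead of silently returning unfiltered requests.
            raise NotImplementedError(
                "capacity filtering is not available on the C++ path "
                "in this sketch")
        return context_requests[:self.free_slots]


def schedule_step(scheduler_output, cache_manager):
    """Filter context requests while leaving num_fitting untouched."""
    num_fitting = scheduler_output.num_fitting_requests
    filtered = cache_manager.filter_ctx_requests_by_capacity(
        scheduler_output.context_requests)
    # The filtered count lives in its own local, separate from num_fitting.
    filtered_ctx_count = len(filtered)
    return filtered, num_fitting, filtered_ctx_count


output = ToySchedulerOutput(context_requests=["a", "b", "c"],
                            num_fitting_requests=3)
filtered, num_fitting, filtered_ctx_count = schedule_step(
    output, ToyCacheManager(free_slots=2))
print(filtered, num_fitting, filtered_ctx_count)
```

In this toy run the scheduler's original count (3) survives alongside the filtered count (2), so a caller can still tell whether anything fit before the cache filter was applied.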

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro Plus

Run ID: aecbc7f5-b3da-4393-9ada-71446a7b2ab1

📥 Commits

Reviewing files that changed from the base of the PR and between ab315dd and 56079ea.

📒 Files selected for processing (2)
  • tensorrt_llm/_torch/pyexecutor/mamba_cache_manager.py
  • tensorrt_llm/_torch/pyexecutor/py_executor.py

Comment thread tensorrt_llm/_torch/pyexecutor/mamba_cache_manager.py
Comment thread tensorrt_llm/_torch/pyexecutor/py_executor.py
@tensorrt-cicd
Collaborator

PR_Github #44710 [ run ] triggered by Bot. Commit: 56079ea Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #44710 [ run ] completed with state SUCCESS. Commit: 56079ea
/LLM/main/L0_MergeRequest_PR pipeline #35074 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@bo-nv
Collaborator Author

bo-nv commented Apr 21, 2026

/bot run

@tensorrt-cicd
Collaborator

PR_Github #44805 [ run ] triggered by Bot. Commit: 56079ea Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #44805 [ run ] completed with state SUCCESS. Commit: 56079ea
/LLM/main/L0_MergeRequest_PR pipeline #35155 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@bo-nv
Collaborator Author

bo-nv commented Apr 21, 2026

/bot run

@tensorrt-cicd
Collaborator

PR_Github #44812 [ run ] triggered by Bot. Commit: 56079ea Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #44812 [ run ] completed with state SUCCESS. Commit: 56079ea
/LLM/main/L0_MergeRequest_PR pipeline #35161 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@bo-nv
Collaborator Author

bo-nv commented Apr 22, 2026

/bot run

@tensorrt-cicd
Collaborator

PR_Github #44828 [ run ] triggered by Bot. Commit: 56079ea Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #44828 [ run ] completed with state SUCCESS. Commit: 56079ea
/LLM/main/L0_MergeRequest_PR pipeline #35174 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@bo-nv
Collaborator Author

bo-nv commented Apr 22, 2026

/bot run

@bo-nv
Collaborator Author

bo-nv commented Apr 22, 2026

/bot run

1 similar comment
@bo-nv
Collaborator Author

bo-nv commented Apr 22, 2026

/bot run

@tensorrt-cicd
Collaborator

PR_Github #45257 [ run ] completed with state SUCCESS. Commit: 1c8bbcd
/LLM/main/L0_MergeRequest_PR pipeline #35516 completed with status: 'SUCCESS'

CI Report

Link to invocation

@bo-nv bo-nv marked this pull request as ready for review April 24, 2026 03:31
@bo-nv bo-nv requested review from Tabrizian and Wanli-Jiang April 27, 2026 04:53
@bo-nv
Collaborator Author

bo-nv commented Apr 27, 2026

@Tabrizian @Wanli-Jiang Please help review, thanks!

@Wanli-Jiang
Collaborator

It would be better to add a test covering the corner case (concurrency > max batch size).

Signed-off-by: Bo Deng <deemod@nvidia.com>
@bo-nv bo-nv requested a review from a team as a code owner April 27, 2026 10:18
@bo-nv
Collaborator Author

bo-nv commented Apr 27, 2026

/bot skip --comment "The new commit just adds a new test, and the test passed on nsc cluster"

@tensorrt-cicd
Collaborator

PR_Github #45715 [ skip ] triggered by Bot. Commit: b5c1e6a Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #45715 [ skip ] completed with state SUCCESS. Commit: b5c1e6a
Skipping testing for commit b5c1e6a

Link to invocation

@Tabrizian
Member

@bo-nv Is it possible to understand why the scheduler is not handling this case correctly? Perhaps CapacityScheduler needs to be updated to make sure we have enough space in MambaCacheManager.

Comment thread tensorrt_llm/_torch/pyexecutor/mamba_cache_manager.py Outdated
xinhe-nv and others added 2 commits April 28, 2026 17:55
Signed-off-by: Bo Deng <deemod@nvidia.com>
@bo-nv
Collaborator Author

bo-nv commented Apr 29, 2026

/bot skip --comment "Just add a comment"

@bo-nv
Collaborator Author

bo-nv commented Apr 29, 2026

/bot run

@tensorrt-cicd
Collaborator

PR_Github #46042 [ run ] triggered by Bot. Commit: d676205 Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #46042 [ run ] completed with state FAILURE. Commit: d676205
/LLM/main/L0_MergeRequest_PR pipeline #36189 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@bo-nv
Collaborator Author

bo-nv commented Apr 29, 2026

/bot run

@tensorrt-cicd
Collaborator

PR_Github #46080 [ run ] triggered by Bot. Commit: d676205 Link to invocation

@bo-nv
Collaborator Author

bo-nv commented Apr 29, 2026

/bot skip --comment "Just add a comment"

@bo-nv bo-nv enabled auto-merge (squash) April 29, 2026 08:27
@tensorrt-cicd
Collaborator

PR_Github #46106 [ skip ] triggered by Bot. Commit: d676205 Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #46106 [ skip ] completed with state SUCCESS. Commit: d676205
Skipping testing for commit d676205

Link to invocation

@bo-nv bo-nv merged commit a753934 into NVIDIA:main May 6, 2026
7 checks passed
yufeiwu-nv pushed a commit to yufeiwu-nv/TensorRT-LLM that referenced this pull request May 19, 2026
…NVIDIA#13274)

Signed-off-by: Bo Deng <deemod@nvidia.com>
Co-authored-by: xinhe-nv <200704525+xinhe-nv@users.noreply.github.com>

5 participants