[None][feat] Add the invocation path for mamba2 mtp custom op#12787
[None][feat] Add the invocation path for mamba2 mtp custom op#12787nv-guomingz merged 12 commits intoNVIDIA:mainfrom
Conversation
Signed-off-by: jiant <107457950+JadoTu@users.noreply.github.com>
|
/bot run |
📝 WalkthroughWalkthroughThis PR adds optional use of a TRT-LLM CUDA custom op for MTP SSM cache updates in Mamba2 mixer through an environment-variable-controlled dispatch mechanism. Two integration tests validate this new code path on 4-GPU and 8-GPU configurations. Changes
Estimated code review effort🎯 3 (Moderate) | ⏱️ ~25 minutes 🚥 Pre-merge checks | ✅ 1 | ❌ 2❌ Failed checks (2 warnings)
✅ Passed checks (1 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Comment |
There was a problem hiding this comment.
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py (1)
1-1: 🛠️ Refactor suggestion | 🟠 MajorUpdate the copyright year to 2026.
The copyright header shows "2022-2024" but this file is being modified in 2026. As per coding guidelines, the year should be updated on modified files.
-# SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2022-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py` at line 1, Update the SPDX copyright header line (the SPDX-FileCopyrightText comment at the top of the file) to reflect the current modification year by changing "2022-2024" to "2022-2026" so the header is accurate for 2026.tests/integration/defs/accuracy/test_llm_api_pytorch.py (1)
1-1:⚠️ Potential issue | 🟠 MajorUpdate SPDX copyright year for this modified file.
Line 1 still uses
2025, but this file has meaningful changes in this PR and should be updated to the latest modification year.Proposed fix
-# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.As per coding guidelines: “Add NVIDIA copyright header to ALL new files; update year on modified files” and “All TensorRT-LLM Open Source Software code files should contain an NVIDIA copyright header with the year of latest meaningful modification”.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/integration/defs/accuracy/test_llm_api_pytorch.py` at line 1, Update the SPDX header year in the file's top-line comment: replace "Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES." in the SPDX-FileCopyrightText line with the latest modification year (2026) so the header reads 2026.
🧹 Nitpick comments (2)
tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py (1)
162-164: Consider using UPPER_SNAKE_CASE for the environment variable name.Environment variables conventionally use UPPER_SNAKE_CASE (e.g.,
MAMBA2_MTP_USE_CUSTOM_OP). This aligns with common conventions and coding guidelines that specify constants should use UPPER_SNAKE_CASE.- self._use_mtp_custom_op = os.environ.get("mamba2_mtp_use_custom_op", - "0") == "1" + self._use_mtp_custom_op = os.environ.get("MAMBA2_MTP_USE_CUSTOM_OP", + "0") == "1"🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py` around lines 162 - 164, The environment variable name mamba2_mtp_use_custom_op used to set self._use_mtp_custom_op should follow UPPER_SNAKE_CASE conventions; update the code that reads os.environ.get("mamba2_mtp_use_custom_op", "0") to use "MAMBA2_MTP_USE_CUSTOM_OP" instead and ensure any related docs/tests or other occurrences are updated to the new name so _use_mtp_custom_op continues to be initialized from the consistent UPPER_SNAKE_CASE variable.tests/integration/defs/accuracy/test_llm_api_pytorch.py (1)
6668-6727: Consider extracting shared MTP acceptance-rate test logic into a helper.This block is nearly identical to the existing non-custom-op AR test. Pulling prompt prep + acceptance-rate computation into a shared helper would reduce maintenance drift between the two paths.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/integration/defs/accuracy/test_llm_api_pytorch.py` around lines 6668 - 6727, This test duplicates prompt preparation and MTP acceptance-rate logic; extract a shared helper (e.g., create function compute_accept_rate_or_run_mtp_test) that accepts an LLM spec (LLM), raw_prompts, MTPDecodingConfig/sampling params (SamplingParams), and max_draft_len, performs tokenizer.apply_chat_template + encode, iterates generate_async to compute num_drafted/num_accepted and returns the acceptance rate; then call that helper from this MTP test and the existing non-custom-op AR test to replace the duplicated blocks (refer to symbols: MTPDecodingConfig, LLM, tokenizer.apply_chat_template, tokenizer.encode, generate_async, SamplingParams, max_draft_len).
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Outside diff comments:
In `@tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py`:
- Line 1: Update the SPDX copyright header line (the SPDX-FileCopyrightText
comment at the top of the file) to reflect the current modification year by
changing "2022-2024" to "2022-2026" so the header is accurate for 2026.
In `@tests/integration/defs/accuracy/test_llm_api_pytorch.py`:
- Line 1: Update the SPDX header year in the file's top-line comment: replace
"Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES." in the
SPDX-FileCopyrightText line with the latest modification year (2026) so the
header reads 2026.
---
Nitpick comments:
In `@tensorrt_llm/_torch/modules/mamba/mamba2_mixer.py`:
- Around line 162-164: The environment variable name mamba2_mtp_use_custom_op
used to set self._use_mtp_custom_op should follow UPPER_SNAKE_CASE conventions;
update the code that reads os.environ.get("mamba2_mtp_use_custom_op", "0") to
use "MAMBA2_MTP_USE_CUSTOM_OP" instead and ensure any related docs/tests or
other occurrences are updated to the new name so _use_mtp_custom_op continues to
be initialized from the consistent UPPER_SNAKE_CASE variable.
In `@tests/integration/defs/accuracy/test_llm_api_pytorch.py`:
- Around line 6668-6727: This test duplicates prompt preparation and MTP
acceptance-rate logic; extract a shared helper (e.g., create function
compute_accept_rate_or_run_mtp_test) that accepts an LLM spec (LLM),
raw_prompts, MTPDecodingConfig/sampling params (SamplingParams), and
max_draft_len, performs tokenizer.apply_chat_template + encode, iterates
generate_async to compute num_drafted/num_accepted and returns the acceptance
rate; then call that helper from this MTP test and the existing non-custom-op AR
test to replace the duplicated blocks (refer to symbols: MTPDecodingConfig, LLM,
tokenizer.apply_chat_template, tokenizer.encode, generate_async, SamplingParams,
max_draft_len).
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 81888ec5-e6e6-4c5a-b44b-75dc1433a394
📒 Files selected for processing (3)
tensorrt_llm/_torch/modules/mamba/mamba2_mixer.pytests/integration/defs/accuracy/test_llm_api_pytorch.pytests/integration/test_lists/test-db/l0_dgx_b200.yml
|
PR_Github #42032 [ run ] triggered by Bot. Commit: |
Signed-off-by: jiant <107457950+JadoTu@users.noreply.github.com>
|
/bot run |
|
PR_Github #42035 [ run ] triggered by Bot. Commit: |
|
PR_Github #42035 [ run ] completed with state
|
|
/bot run |
|
PR_Github #42063 [ run ] triggered by Bot. Commit: |
|
PR_Github #42063 [ run ] completed with state
|
|
/bot run |
|
PR_Github #42102 [ run ] triggered by Bot. Commit: |
|
PR_Github #42102 [ run ] completed with state
|
|
/bot run |
|
PR_Github #42118 [ run ] triggered by Bot. Commit: |
|
PR_Github #42118 [ run ] completed with state
|
|
/bot run |
|
PR_Github #43312 [ run ] triggered by Bot. Commit: |
|
PR_Github #43312 [ run ] completed with state
|
|
/bot run |
|
/bot kill |
|
PR_Github #43349 [ run ] triggered by Bot. Commit: |
|
PR_Github #43351 [ kill ] triggered by Bot. Commit: |
|
PR_Github #43349 [ run ] completed with state |
|
PR_Github #43351 [ kill ] completed with state |
|
/bot run |
|
PR_Github #43368 [ run ] triggered by Bot. Commit: |
|
PR_Github #43368 [ run ] completed with state
|
|
/bot run |
|
PR_Github #43414 [ run ] triggered by Bot. Commit: |
|
PR_Github #43414 [ run ] completed with state
|
|
/bot run |
|
PR_Github #43451 [ run ] triggered by Bot. Commit: |
|
PR_Github #43451 [ run ] completed with state
|
|
/bot run |
1 similar comment
|
/bot run |
|
PR_Github #43629 [ run ] triggered by Bot. Commit: |
|
PR_Github #43629 [ run ] completed with state
|
|
/bot run |
|
PR_Github #43741 [ run ] triggered by Bot. Commit: |
|
PR_Github #43741 [ run ] completed with state |
Summary by CodeRabbit
Release Notes
New Features
Tests
Test Coverage
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
To see a list of available CI bot commands, please comment
/bot help.