[TRTLLM-11375][feat] Add Kimi K2.5 multimodal vision support #12788

Merged
lancelly merged 5 commits into NVIDIA:main from tianyuxbear:kimi-k25
May 14, 2026

@tianyuxbear
Collaborator

@tianyuxbear tianyuxbear commented Apr 7, 2026

Summary by CodeRabbit

  • New Features

    • Native support for Kimi K2.5 multimodal model with image and video preprocessing, vision encoding, and fusion into text generation.
    • Multimodal input processor handling media placeholders and expanded tokenization.
  • Enhancements

    • Improved handling of “thinking” / reasoning excerpts in generation outputs (conditional stripping/logging).
  • Tests

    • Added unit, integration, and accuracy tests plus test-list entries and benchmarks for Kimi K2.5.

Description

Add full vision-language support for Kimi K2.5 (moonshotai/Kimi-K2.5) in the PyTorch backend.

Previously only the text backbone was supported (via a thin wrapper in modeling_deepseekv3.py). This PR adds the complete multimodal pipeline: native MoonViT3d vision encoder, PatchMergerMLP projector, image/video input processor, and the composite KimiK25ForConditionalGeneration model that wires vision encoding into the DeepSeek-V3 text backbone.

Model support

  • tensorrt_llm/_torch/models/modeling_kimi_k25.py (new): MoonViT3d encoder, PatchMergerMLP projector, KimiK25InputProcessor (image + video), KimiK25ForConditionalGeneration
  • tensorrt_llm/_torch/models/__init__.py: Register new model
  • tensorrt_llm/_torch/models/modeling_deepseekv3.py: Remove old text-only stub
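
As a rough illustration of what placeholder expansion in an input processor can look like, the sketch below replaces each media placeholder token with as many copies as the vision encoder is expected to emit for that item. The function name, the placeholder id `9999`, and the per-item token counts are hypothetical and are not taken from `modeling_kimi_k25.py`.

```python
from typing import List

MEDIA_PLACEHOLDER_ID = 9999  # hypothetical token id, for illustration only


def expand_placeholders(
    input_ids: List[int],
    tokens_per_item: List[int],
    placeholder_id: int = MEDIA_PLACEHOLDER_ID,
) -> List[int]:
    """Replace the i-th placeholder with tokens_per_item[i] copies of itself."""
    expanded: List[int] = []
    item = 0
    for tok in input_ids:
        if tok == placeholder_id:
            if item >= len(tokens_per_item):
                raise ValueError("more placeholders than media items")
            expanded.extend([placeholder_id] * tokens_per_item[item])
            item += 1
        else:
            expanded.append(tok)
    if item != len(tokens_per_item):
        raise ValueError("fewer placeholders than media items")
    return expanded


# One image that is assumed to produce three vision tokens.
print(expand_placeholders([1, MEDIA_PLACEHOLDER_ID, 2], [3]))
```

After expansion, the prompt length already accounts for the vision tokens, so the embeddings can later be written into those positions without shifting the text.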

Serving / evaluation

  • tensorrt_llm/llmapi/reasoning_parser.py: Register kimi_k25 reasoning parser
  • tensorrt_llm/evaluate/lm_eval.py: _strip_thinking() to strip <think> tags for accuracy benchmarks
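
For readers unfamiliar with thinking-tag stripping, here is a minimal sketch of the idea. It is not the PR's `_strip_thinking()` implementation; the function name `strip_thinking` and the handling of a lone closing tag are assumptions made for illustration.

```python
import re

# Non-greedy match so several separate <think> blocks are removed one by one.
_THINK_BLOCK = re.compile(r"<think>.*?</think>", flags=re.DOTALL)


def strip_thinking(text: str) -> str:
    """Drop closed <think>...</think> blocks and return the remaining text."""
    without_blocks = _THINK_BLOCK.sub("", text)
    # Assumption: if only a closing tag remains (the opening tag was part of
    # the prompt), keep what follows the last closing tag.
    if "</think>" in without_blocks:
        without_blocks = without_blocks.rsplit("</think>", 1)[1]
    return without_blocks.strip()


print(strip_thinking("<think>2 + 2 = 4</think>\nAnswer: B"))
```

Stripping before scoring matters for accuracy benchmarks because answer-extraction heuristics may otherwise pick up an option letter mentioned inside the reasoning text.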

Limitations

  • NVFP4 quantization only (BF16 not yet validated)
  • Multimodal disaggregated serving not yet supported (TODO)
  • Video input requires decord or cv2 for frame decoding

Test Coverage

  • Unit tests: tests/unittest/_torch/modeling/test_modeling_kimi_k25.py — structure, config parsing, input processor, E2E smoke test (4x B200)
  • MMMU accuracy: tests/integration/defs/accuracy/test_llm_api_pytorch_multimodal.py::TestKimi_K25 — NVFP4, tp8/ep8/attn_dp/dep8, reference score 82.00
  • CI entries: Added to l0_dgx_b200.yml (pre-merge) and llm_function_core.txt (QA regression)
  • Existing text-only K2.5 tests unaffected

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • Update tava architecture diagram if there is a significant design change in PR.

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

To see a list of available CI bot commands, please comment /bot help.

@tianyuxbear tianyuxbear requested review from a team as code owners April 7, 2026 02:49
@tianyuxbear tianyuxbear requested a review from syuoni April 7, 2026 02:49
@coderabbitai
Contributor

coderabbitai Bot commented Apr 7, 2026

Walkthrough

Adds a native PyTorch Kimi K2.5 multimodal implementation (vision encoder, input processor, fusion into an LLM), moves the model registration out of the DeepSeek-V3 file, updates the thinking-strip logic used in evaluation, registers a kimi_k25 reasoning parser variant, and adds unit/integration tests and accuracy reference entries.

Changes

| Cohort / File(s) | Summary |
| --- | --- |
| Model implementation & exports: tensorrt_llm/_torch/models/modeling_kimi_k25.py, tensorrt_llm/_torch/models/__init__.py, tensorrt_llm/_torch/models/modeling_deepseekv3.py | Added native Kimi K2.5 multimodal module (vision encoder, projector, input processor, conditional generation model) and exported KimiK25ForConditionalGeneration; removed old registration/class from DeepSeekV3 file. |
| Evaluation / parsing: tensorrt_llm/evaluate/lm_eval.py, tensorrt_llm/llmapi/reasoning_parser.py | Added _strip_thinking() and conditional thinking-string handling in generate paths; preserve larger existing sampling max_tokens; registered reasoning parser key kimi_k25 and adjusted parser behavior when thinking disabled. |
| Tooling & accuracy refs: pyproject.toml, tests/integration/defs/accuracy/references/mmmu.yaml | Appended medias to codespell ignore list; added MMMU accuracy reference for moonshotai/Kimi-K2.5 (NVFP4 / FP8 KV cache, accuracy 82.00). |
| Integration tests & test lists: tests/integration/defs/accuracy/test_llm_api_pytorch_multimodal.py, tests/integration/test_lists/qa/llm_function_core.txt, tests/integration/test_lists/test-db/l0_dgx_b200.yml | Added TestKimi_K25 multimodal accuracy harness and parameterized test_nvfp4 variants; appended corresponding entries to QA and DGX test lists (with increased timeout). |
| Unit tests: tests/unittest/_torch/modeling/test_modeling_kimi_k25.py | New comprehensive unittest module covering config nesting, model structure, input processor behavior (image/video/mixed), registration check, and guarded GPU end-to-end smoke tests. |

Sequence Diagram(s)

sequenceDiagram
    actor User
    participant InputProcessor as KimiK25InputProcessor
    participant VisionModel as KimiK25VisionModel
    participant LLM as DeepseekV3ForCausalLM
    participant Output as GenerateOutput

    User->>InputProcessor: submit text + media (image/video) + sampling params
    InputProcessor->>InputProcessor: tokenize text\npreprocess media (decode/split video → chunks)\ncompute grid_thw and expand media placeholders
    InputProcessor->>User: return expanded input_ids + multimodal_data
    User->>VisionModel: multimodal_data (multimodal_params)
    VisionModel->>VisionModel: patch embed\nstacked encoder layers\ntemporal pooling + spatial merge\nproject to LLM dim
    VisionModel->>LLM: vision embeddings
    LLM->>LLM: locate media_placeholder tokens\nfuse vision embeddings into input embeddings\ncausal LM forward
    LLM->>Output: generated tokens/text
    Output->>User: final text
    opt Thinking mode enabled
        Output->>Output: strip <think>...</think>\nextract final answer via heuristics
        Output->>User: post-processed answer
    end
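
The fusion step in the diagram above (locate placeholder tokens, then overwrite their embeddings with vision embeddings) can be sketched in plain Python. This is a toy version using lists instead of tensors; the function name and the placeholder id `7` are hypothetical, and the real model presumably does this with tensor indexing.

```python
from typing import List, Sequence

Vector = List[float]


def fuse_embeddings(
    input_ids: Sequence[int],
    text_embeds: Sequence[Vector],
    vision_embeds: Sequence[Vector],
    placeholder_id: int,
) -> List[Vector]:
    """Overwrite embeddings at placeholder positions with vision embeddings."""
    positions = [i for i, tok in enumerate(input_ids) if tok == placeholder_id]
    if len(positions) != len(vision_embeds):
        raise ValueError(
            f"{len(positions)} placeholder tokens but "
            f"{len(vision_embeds)} vision embeddings"
        )
    fused = [list(vec) for vec in text_embeds]  # copy, leave inputs untouched
    for pos, emb in zip(positions, vision_embeds):
        fused[pos] = list(emb)
    return fused


ids = [1, 7, 7, 2]  # 7 is a hypothetical placeholder id
text = [[0.0], [0.1], [0.2], [0.3]]
vision = [[9.0], [8.0]]
print(fuse_embeddings(ids, text, vision, placeholder_id=7))
```

The length check is the important part: a mismatch between expanded placeholders and encoder outputs usually points to a preprocessing bug rather than something to silently truncate.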

Estimated Code Review Effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Suggested Reviewers

  • chang-l
  • Shixiaowei02
  • Fridah-nv
  • suyoggupta
🚥 Pre-merge checks | ✅ 2 | ❌ 1

❌ Failed checks (1 warning)

| Check name | Status | Explanation | Resolution |
| --- | --- | --- | --- |
| Docstring Coverage | ⚠️ Warning | Docstring coverage is 57.32% which is insufficient. The required threshold is 80.00%. | Write docstrings for the functions missing them to satisfy the coverage threshold. |

✅ Passed checks (2 passed)

| Check name | Status | Explanation |
| --- | --- | --- |
| Title check | ✅ Passed | The title clearly and specifically describes the main change: adding Kimi K2.5 multimodal vision support, which aligns with the substantial new modeling_kimi_k25.py implementation and related infrastructure updates. |
| Description check | ✅ Passed | The description comprehensively covers what was changed (model support, serving/evaluation), test coverage details, and limitations, following the repository template structure with clear sections. |


@tianyuxbear
Collaborator Author

@coderabbitai review

@coderabbitai
Contributor

coderabbitai Bot commented Apr 7, 2026

✅ Actions performed

Review triggered.

Note: CodeRabbit is an incremental review system and does not re-review already reviewed commits. This command is applicable only when automatic reviews are paused.

Contributor

@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 3

🧹 Nitpick comments (6)
tensorrt_llm/_torch/models/modeling_kimi_k25.py (3)

1305-1306: Clarify the purpose of the early return guard.

This early return pattern is unusual. If __init__ is being called multiple times, that's likely a bug elsewhere. Consider adding a comment explaining when this can legitimately occur, or using a more explicit guard.

💡 Add explanatory comment
+        # Guard against double initialization from PreTrainedModel base class
         if hasattr(self, "llm"):
             return
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/models/modeling_kimi_k25.py` around lines 1305 - 1306,
The early-return "if hasattr(self, 'llm'): return" in the __init__ of the model
is ambiguous; replace it with a clearer guard or comment: either add a concise
explanatory comment above the line describing the legitimate scenario when
__init__ may run twice (e.g., reloading, subclass init chaining) or replace the
check with an explicit initialization flag (e.g., self._initialized) set at the
end of __init__ and tested at the start to make intent explicit; update
references to the guard in the class's __init__ and ensure the flag is set after
successful initialization (methods: __init__, attribute 'llm', and new
'_initialized' flag).
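
The explicit-flag variant suggested here could look like the following toy class. It is a plain-Python sketch, not the model's actual `__init__`; the class name and attributes are made up for illustration.

```python
class GuardedModel:
    """Toy class showing an explicit re-initialization guard."""

    def __init__(self, name: str) -> None:
        # getattr with a default avoids AttributeError on the first call.
        if getattr(self, "_initialized", False):
            return
        self.name = name
        self.init_count = getattr(self, "init_count", 0) + 1
        # Set the flag last so a failed __init__ can be retried.
        self._initialized = True


model = GuardedModel("first")
model.__init__("second")  # second call is a no-op
print(model.name, model.init_count)
```

Compared with `hasattr(self, "llm")`, the dedicated flag does not depend on which attribute happens to be assigned first, which makes the intent easier to read.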

1320-1320: Prefer explicit raise ValueError over assert for configuration validation.

Assertions can be disabled with -O flag. For production configuration validation, use explicit checks.

🔒 Use explicit validation
-        assert hasattr(config, "text_config"), "Kimi K2.5 config must have text_config"
+        if not hasattr(config, "text_config"):
+            raise ValueError("Kimi K2.5 config must have text_config")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/models/modeling_kimi_k25.py` at line 1320, Replace the
runtime assertion with an explicit validation that raises a ValueError: instead
of using "assert hasattr(config, 'text_config')" (in the Kimi K2.5 model config
validation, e.g., constructor or function where config is checked), check if not
hasattr(config, "text_config") and raise ValueError("Kimi K2.5 config must have
text_config") so the check cannot be disabled by Python optimizations.
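
A minimal sketch of the explicit check, using `types.SimpleNamespace` as a stand-in for the real config object (the helper name `require_text_config` is hypothetical):

```python
from types import SimpleNamespace


def require_text_config(config: object) -> object:
    """Return config.text_config, raising ValueError when it is missing."""
    # Unlike `assert`, this check still runs under `python -O`.
    if not hasattr(config, "text_config"):
        raise ValueError("Kimi K2.5 config must have text_config")
    return config.text_config


good = SimpleNamespace(text_config={"hidden_size": 8})
print(require_text_config(good))

try:
    require_text_config(SimpleNamespace())
except ValueError as err:
    print(f"rejected: {err}")
```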

116-130: Broad exception catch is acceptable for fallback logic.

The except Exception at line 130 catches any decord failure to enable cv2 fallback. This is a reasonable pattern here since decord can fail for various reasons (missing codecs, corrupt files, etc.).

Consider logging the exception at debug level to aid troubleshooting:

💡 Optional: Add debug logging
         frame_inds = np.linspace(0, total_frames - 1, sampled_n).round().astype(int).tolist()
         raw_frames = vr.get_batch(frame_inds).asnumpy()
         frames_pil = [Image.fromarray(f) for f in raw_frames]
-    except Exception:
+    except Exception as e:
+        logger.debug("decord failed, falling back to cv2: %s", e)
         # Fallback: cv2
         import cv2
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/models/modeling_kimi_k25.py` around lines 116 - 130, The
except Exception block that falls back from decord to cv2 should log the caught
exception at debug level to aid troubleshooting; inside the except handling for
the decord import/usage (around VideoReader/_VR, vr, get_batch, etc.) call a
module logger (e.g. logging.getLogger(__name__)) and emit a debug/exception log
with the caught exception and traceback before proceeding to the cv2 fallback so
you have context when decord fails.
tests/unittest/_torch/modeling/test_modeling_kimi_k25.py (3)

719-722: Add strict=True to zip() for safety.

Using strict=True (Python 3.10+) ensures results and colors have matching lengths, catching potential mismatches early.

🔒 Add strict parameter
-            for res, (_, name) in zip(results, colors):
+            for res, (_, name) in zip(results, colors, strict=True):
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/unittest/_torch/modeling/test_modeling_kimi_k25.py` around lines 719 -
722, The zip over results and colors should use strict=True to catch length
mismatches at runtime; update the loop that iterates "for res, (_, name) in
zip(results, colors):" to use "zip(results, colors, strict=True)" (maintain the
existing unpacking of res and name) so mismatched lengths raise an error
immediately (requires Python 3.10+).

297-301: Rename unused extra variable to _.

The extra variable is unpacked but never used. Prefix with underscore to indicate intentional discard.

🧹 Fix unused variable
-        token_ids, extra = processor(
+        token_ids, _ = processor(
             {"prompt": "Hello, world!"},
             sampling_params=None,
         )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/unittest/_torch/modeling/test_modeling_kimi_k25.py` around lines 297 -
301, The test unpacks an unused second value from the processor call (token_ids,
extra = processor(...)); rename the unused variable `extra` to `_` to indicate
intentional discard and silence linter warnings. Update the unpacking in the
test where `processor(...)` is called (the variables `token_ids` and `extra`) so
it becomes `token_ids, _ = processor(...)`, leaving the rest of the assertion
using `token_ids` unchanged.

159-192: Consider adding debug logging for video writer fallback failures.

The fallback chain pattern is acceptable for test utilities, but swallowing exceptions silently makes debugging harder when all writers fail unexpectedly. Consider adding debug logging to help diagnose environment issues.

💡 Optional: Add debug logging for failed attempts
     # Try cv2 (OpenCV) — most likely available in NVIDIA environments
     try:
         import cv2

         fourcc = cv2.VideoWriter_fourcc(*"mp4v")
         writer = cv2.VideoWriter(path, fourcc, fps, (w, h))
         for frame in frames:
             # cv2 expects BGR
             writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
         writer.release()
         # Verify the file is non-empty
         if os.path.getsize(path) > 0:
             return path
-    except Exception:
-        pass
+    except Exception as e:
+        import logging
+        logging.getLogger(__name__).debug("cv2 video writer failed: %s", e)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tests/unittest/_torch/modeling/test_modeling_kimi_k25.py` around lines 159 -
192, The three broad except blocks that silently swallow errors around the video
writer fallbacks (the cv2 block using VideoWriter and writer.release, the
imageio.v3 iio.imwrite(plugin="pyav") block, and the imageio
iio.imwrite(np.stack(frames)) block) should log the caught exception and which
backend failed before continuing; update each except Exception: to capture the
exception as e and call a logger (e.g.,
logging.getLogger(__name__).debug/exception) or print with the backend name and
exception details so failures are visible while preserving the fallback
behavior, and ensure the final cleanup (os.unlink(path)) still runs when all
backends fail.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tensorrt_llm/evaluate/lm_eval.py`:
- Around line 472-476: The current check uses the raw override dict
(self.chat_template_kwargs) so models whose chat template defaults to thinking
mode are skipped; replace that by resolving effective chat template kwargs via
get_chat_template_kwargs(...) and then set thinking_enabled =
bool(effective_kwargs and effective_kwargs.get("thinking")) before calling
_strip_thinking(); locate the block around thinking_enabled (and functions
apply_chat_template and _strip_thinking) and use get_chat_template_kwargs to
compute the actual kwargs instead of using self.chat_template_kwargs directly.

In `@tensorrt_llm/llmapi/reasoning_parser.py`:
- Around line 428-432: The code path that handles missing end markers currently
sets reasoning_content = "" and moves the entire text into content when
reasoning_at_start is True, causing partial/unfinished <think> outputs (e.g.,
kimi_k25 instant-mode) to be treated as final answer and leak reasoning into
content; instead preserve the unclosed reasoning in reasoning_content and remove
it from content: detect the leading reasoning segment when reasoning_at_start is
True (existing logic in the same function that sets reasoning_content/content),
assign that leading portion to reasoning_content and set content to the
remainder (or empty if none), so unfinished <think> fragments remain classified
as reasoning rather than as final content (update the branches that currently
overwrite reasoning_content and content to implement this split).

In `@tests/integration/test_lists/test-db/l0_dgx_b200.yml`:
- Around line 187-190: The YAML places the Kimi K2.5 multimodal jobs under
stage: post_merge so the new vision path isn't run pre-merge; pick one
representative test (e.g.
accuracy/test_llm_api_pytorch_multimodal.py::TestKimi_K25::test_nvfp4[tp8] or
another from that group) and move its entry from the post_merge block into the
existing pre-merge 8-GPU block so an 8-GPU pre-merge run covers the new vision
path; ensure you retain the same test identifier and any GPU/param tags when
copying it into the pre-merge section.
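
To make the reasoning-parser suggestion above concrete, here is a hedged sketch of the proposed split. It is not the `reasoning_parser.py` code; the function name, the tag constants, and the exact behavior are assumptions that follow the review comment's description, and the thread was later marked outdated, so the merged code may differ.

```python
from typing import Tuple

START, END = "<think>", "</think>"


def split_reasoning(text: str, reasoning_at_start: bool) -> Tuple[str, str]:
    """Return (reasoning_content, content) for one complete model output."""
    started = reasoning_at_start or text.startswith(START)
    if not started:
        return "", text.strip()
    body = text[len(START):] if text.startswith(START) else text
    if END in body:
        reasoning, content = body.split(END, 1)
        return reasoning.strip(), content.strip()
    # No end marker: treat the output as unfinished reasoning, not an answer.
    return body.strip(), ""


print(split_reasoning("<think>check units</think>42 m", reasoning_at_start=False))
print(split_reasoning("still thinking", reasoning_at_start=True))
```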

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 476ad443-d5c8-49ea-994a-5e010f506726

📥 Commits

Reviewing files that changed from the base of the PR and between 88bbb4d and da3b1df.

📒 Files selected for processing (11)
  • pyproject.toml
  • tensorrt_llm/_torch/models/__init__.py
  • tensorrt_llm/_torch/models/modeling_deepseekv3.py
  • tensorrt_llm/_torch/models/modeling_kimi_k25.py
  • tensorrt_llm/evaluate/lm_eval.py
  • tensorrt_llm/llmapi/reasoning_parser.py
  • tests/integration/defs/accuracy/references/mmmu.yaml
  • tests/integration/defs/accuracy/test_llm_api_pytorch_multimodal.py
  • tests/integration/test_lists/qa/llm_function_core.txt
  • tests/integration/test_lists/test-db/l0_dgx_b200.yml
  • tests/unittest/_torch/modeling/test_modeling_kimi_k25.py
💤 Files with no reviewable changes (1)
  • tensorrt_llm/_torch/models/modeling_deepseekv3.py

Comment thread tensorrt_llm/evaluate/lm_eval.py Outdated
Comment thread tensorrt_llm/llmapi/reasoning_parser.py Outdated
Comment thread tests/integration/test_lists/test-db/l0_dgx_b200.yml Outdated
@tianyuxbear
Collaborator Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #42067 [ run ] triggered by Bot. Commit: da3b1df Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #42067 [ run ] completed with state SUCCESS. Commit: da3b1df
/LLM/main/L0_MergeRequest_PR pipeline #32907 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@tianyuxbear
Collaborator Author

/bot run --disable-fail-fast

@tianyuxbear
Collaborator Author

/bot help

@github-actions

github-actions Bot commented Apr 7, 2026

GitHub Bot Help

/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provide a user friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.

Details

run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental) --high-priority]

Launch build/test pipelines. All previously running jobs will be killed.

--reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline or the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option is always ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.

--disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.

--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.

--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.

--stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.

--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.

--test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages which don't match the specified backends. Only support [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.

--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.

--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.

--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.

--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".

--detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.

--debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purposes. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

--high-priority (OPTIONAL) : Run the pipeline with high priority. This option is restricted to authorized users only and will route the job to a high-priority queue.

kill

kill

Kill all running builds associated with pull request.

skip

skip --comment COMMENT

Skip testing for latest commit on pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

@tensorrt-cicd
Collaborator

PR_Github #42124 [ run ] triggered by Bot. Commit: da3b1df Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #42124 [ run ] completed with state SUCCESS. Commit: da3b1df
/LLM/main/L0_MergeRequest_PR pipeline #32960 completed with status: 'FAILURE'

@tianyuxbear
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #42219 [ run ] triggered by Bot. Commit: c87ea71 Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #42219 [ run ] completed with state SUCCESS. Commit: c87ea71
/LLM/main/L0_MergeRequest_PR pipeline #33034 completed with status: 'FAILURE'

@tianyuxbear
Collaborator Author

/bot --disable-fail-fast

@github-actions

github-actions Bot commented Apr 9, 2026

GitHub Bot Help

/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provide a user friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.

Details

run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental) --high-priority]

Launch build/test pipelines. All previously running jobs will be killed.

--reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline or the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will be always ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.

--disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.

--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.

--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.

--stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.

--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.

--test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages which don't match the specified backends. Only support [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.

--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.

--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.

--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.

--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".

--detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.

--debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purpose. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

--high-priority (OPTIONAL) : Run the pipeline with high priority. This option is restricted to authorized users only and will route the job to a high-priority queue.

kill

kill

Kill all running builds associated with the pull request.

skip

skip --comment COMMENT

Skip testing for the latest commit on the pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous, because skipping tests without careful validation can break the top of tree.

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate the current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous, because reusing results without careful validation can break the top of tree.

@tianyuxbear
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #42476 [ run ] triggered by Bot. Commit: c87ea71 Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #42476 [ run ] completed with state SUCCESS. Commit: c87ea71
/LLM/main/L0_MergeRequest_PR pipeline #33230 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

Collaborator

@yechank-nvidia yechank-nvidia left a comment

Thanks for the work! LGTM

Collaborator

@Superjomn Superjomn left a comment

LGTM

Collaborator

@tburt-nv tburt-nv left a comment

No dependency changes, approved

@tensorrt-cicd
Collaborator

PR_Github #47670 [ run ] completed with state SUCCESS. Commit: d589f36
/LLM/main/L0_MergeRequest_PR pipeline #37570 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@tianyuxbear
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #47769 [ run ] triggered by Bot. Commit: da6599e Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #47769 [ run ] completed with state SUCCESS. Commit: da6599e
/LLM/main/L0_MergeRequest_PR pipeline #37661 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

CI Agent Failure Analysis

Link to invocation

@tianyuxbear tianyuxbear force-pushed the kimi-k25 branch 2 times, most recently from bac0383 to 0f96586 on May 12, 2026 04:32
@tianyuxbear
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #47902 [ run ] triggered by Bot. Commit: 0f96586 Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #47902 [ run ] completed with state SUCCESS. Commit: 0f96586
/LLM/main/L0_MergeRequest_PR pipeline #37751 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@tianyuxbear
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #48082 [ run ] triggered by Bot. Commit: d8c7b07 Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #48082 [ run ] completed with state SUCCESS. Commit: d8c7b07
/LLM/main/L0_MergeRequest_PR pipeline #37913 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

CI Agent Failure Analysis

Link to invocation

Add full vision-language support for moonshotai/Kimi-K2.5:

- Native MoonViT3d vision encoder + PatchMergerMLP projector
- KimiK25InputProcessor for image/video preprocessing
- KimiK25ForConditionalGeneration wiring vision into DeepSeek-V3 backbone
- Remove text-only stub from modeling_deepseekv3.py
- Register kimi_k25 reasoning parser
- Strip <think> tags in lm_eval for thinking-mode accuracy benchmarks
- Unit tests (structure + input processor) and E2E smoke test
- MMMU accuracy test (NVFP4, tp8/ep8/attn_dp/dep8) with reference score
- CI test list entries for pre-merge and QA regression
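
The `<think>`-stripping step listed above can be illustrated with a short Python sketch. It only assumes the general shape of such a helper: the name `strip_thinking`, the regex, and the handling of a dangling `</think>` are hypothetical and are not taken from the PR's `_strip_thinking()` in `tensorrt_llm/evaluate/lm_eval.py`.

```python
import re

# Hypothetical pattern: one complete reasoning block plus trailing whitespace.
# DOTALL lets "." span newlines; ".*?" keeps each match to a single block.
_THINK_BLOCK = re.compile(r"<think>.*?</think>\s*", flags=re.DOTALL)


def strip_thinking(text: str) -> str:
    """Return `text` without <think>...</think> reasoning (illustrative only)."""
    # Remove every complete <think>...</think> block.
    cleaned = _THINK_BLOCK.sub("", text)
    # Assumption: the opening tag may live in the prompt, so the output can
    # contain only a closing tag. Keep whatever follows the last one.
    if "</think>" in cleaned:
        cleaned = cleaned.rsplit("</think>", 1)[1]
    return cleaned.strip()
```

With a helper of this kind, an accuracy benchmark would score only the final answer, for example `strip_thinking("<think>reasoning</think>Answer: B")` yields `"Answer: B"`.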

Signed-off-by: Tianyu Xiong <117647511+tianyuxbear@users.noreply.github.com>
Signed-off-by: Tianyu Xiong <117647511+tianyuxbear@users.noreply.github.com>
…_video

Signed-off-by: Tianyu Xiong <117647511+tianyuxbear@users.noreply.github.com>
Signed-off-by: Tianyu Xiong <117647511+tianyuxbear@users.noreply.github.com>
…t-in

Signed-off-by: Tianyu Xiong <117647511+tianyuxbear@users.noreply.github.com>
@tianyuxbear
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #48196 [ run ] triggered by Bot. Commit: ea2d07f Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #48196 [ run ] completed with state SUCCESS. Commit: ea2d07f
/LLM/main/L0_MergeRequest_PR pipeline #38015 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

CI Agent Failure Analysis

Link to invocation

@tianyuxbear
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #48266 [ run ] triggered by Bot. Commit: ea2d07f Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #48266 [ run ] completed with state SUCCESS. Commit: ea2d07f
/LLM/main/L0_MergeRequest_PR pipeline #38080 completed with status: 'SUCCESS'

CI Report

Link to invocation

@lancelly lancelly merged commit a227373 into NVIDIA:main May 14, 2026
6 checks passed

9 participants