
[None][feat] Add llm.encode() fast path for encoder-only models #12801

Merged
pcastonguay merged 9 commits into NVIDIA:main from tingyangk:tingyangk/encoder-llmapi-optimize
Apr 27, 2026

Conversation

Collaborator

@tingyangk tingyangk commented Apr 7, 2026

Summary by CodeRabbit

New Features

  • Added encoder-only execution mode via new encode_only parameter for models like BERT
  • Introduced encode() API method for efficient encode-only inference
  • Added EncoderOutput dataclass containing logits and tokenized prompts
  • Generation APIs now raise errors when encode-only mode is active

Tests

  • Added test coverage for encode-only inference, batch processing, and API validation

Summary

Adds a dedicated llm.encode() API for encode-only workloads that bypasses the decoder-oriented PyExecutor loop entirely. It works for encoder models and for decoder models running a "single-prefill" path.

Problem

The current LLM API routes encoder models through the same PyExecutor that was designed for autoregressive decoders, introducing significant per-batch CPU overhead from the scheduler, KV cache management, sampling, and the request state machine, none of which apply to encoders. Encoder models need a simple, direct path to the model's forward call, with batch inference executed in a single pass.

Solution

A new execution path (encode_only=True) that creates a lightweight EncoderExecutor instead of the full PyExecutor. The encode() method tokenizes, packs, and runs a single forward pass directly through ModelEngine.encoder_forward(), returning EncoderOutput with logits. This new API demonstrates a 3.92× speedup for the BERT 110M model (textattack/bert-base-uncased-yelp-polarity) in eager mode with batch size 10.

encode()   mean: 5.17ms  (p50=5.16ms)
generate() mean: 20.25ms (p50=20.02ms)
Speedup: 3.92x
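
These numbers come from the author's benchmark, which is not included in the PR. As a rough illustration only, a minimal wall-clock harness along these lines could run a similar comparison; the prompts, warmup/iteration counts, and timing details below are assumptions, not the script that produced the figures above.

import time
from tensorrt_llm import LLM, SamplingParams

PROMPTS = ["Hello world"] * 10  # batch size 10, illustrative

def mean_ms(fn, warmup=3, iters=20):
    # Warm up, then report the mean wall-clock latency per call in milliseconds.
    for _ in range(warmup):
        fn()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    return (time.perf_counter() - start) / iters * 1e3

# encode() fast path: encoder-only executor.
enc_llm = LLM(model="textattack/bert-base-uncased-yelp-polarity", encode_only=True)
encode_ms = mean_ms(lambda: enc_llm.encode(PROMPTS))

# generate() baseline: full PyExecutor path returning context logits.
gen_llm = LLM(model="textattack/bert-base-uncased-yelp-polarity",
              disable_overlap_scheduler=True)
params = SamplingParams(return_context_logits=True)
generate_ms = mean_ms(lambda: gen_llm.generate(PROMPTS, params))

print(f"encode(): {encode_ms:.2f} ms, generate(): {generate_ms:.2f} ms, "
      f"speedup: {generate_ms / encode_ms:.2f}x")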

Usage

from tensorrt_llm import LLM, SamplingParams

# New dedicated path
llm = LLM(model="bert-base-uncased-yelp-polarity", encode_only=True)
outputs = llm.encode(["Hello world", "Test sentence"])
print(outputs[0].logits)  # [num_classes] tensor

# Old path still works unchanged (no encode_only flag)
llm = LLM(model="bert-base-uncased-yelp-polarity", disable_overlap_scheduler=True)
prompts = ["Hello world", "Test sentence"]
outputs = llm.generate(prompts, SamplingParams(return_context_logits=True))
  • encode_only=True must be set explicitly. The default (None) uses the old generate() path.
  • encode_only=True creates only an EncoderExecutor, while False/None creates only a PyExecutor; the two modes are mutually exclusive.
  • generate()/generate_async() raise RuntimeError when encode_only=True; encode() is the only available API in this mode (see the sketch after this list).
  • Since llm.encode() reuses PyTorchModelEngine and its _forward_step() path, features like TorchCompileConfig remain compatible.
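
For illustration, a minimal sketch of the returned object and of the guard behavior described above, continuing from the usage snippet; the field names follow this PR description and the sequence diagram below, but the exact dataclass layout and error message in the merged code may differ.

from dataclasses import dataclass
from typing import List, Optional
import torch

@dataclass
class EncoderOutput:
    # Fields as described in this PR: logits plus the tokenized prompt.
    logits: torch.Tensor
    prompt_token_ids: List[int]
    prompt: Optional[str] = None

# Generation entry points refuse to run when the LLM was built with encode_only=True.
llm = LLM(model="bert-base-uncased-yelp-polarity", encode_only=True)
try:
    llm.generate(["Hello world"])
except RuntimeError as err:
    print(f"generate() is unavailable in encode-only mode: {err}")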

Future Work

  • Encoder CUDA graph integration: capture the encoder model into a single CUDA graph
  • Triton backend update: add an encoder model example
  • Parallelism support (e.g., TP > 1): extend the EncoderExecutor
  • Other minor optimizations (e.g., batched tokenization, caching AttentionMetadata)

Test Coverage

  • tests/unittest/llmapi/test_llm_encode.py — 11 new tests:
    • Basic: single string, batch, token IDs, mixed input types
    • Correctness: logits compared against HuggingFace BertForSequenceClassification
    • Health check: _check_health() returns True for encoder-only LLM
  • Existing tests unaffected (no encode_only=True → old path):
    • tests/integration/defs/test_e2e.py::test_ptp_quickstart_bert
    • tests/unittest/llmapi/test_llm_pytorch.py::test_llm_reward_model

CC: @symphonylyh @amukkara @nvrohanv @schetlur-nv @juney-nvidia

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • Update tava architecture diagram if there is a significant design change in PR.

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

To see a list of available CI bot commands, please comment /bot help.

@svc-trtllm-gh-bot svc-trtllm-gh-bot added the Community want to contribute PRs initiated from Community label Apr 7, 2026
@tingyangk tingyangk force-pushed the tingyangk/encoder-llmapi-optimize branch from 04bcd64 to 83bc6b9 Compare April 7, 2026 08:45
Collaborator

@nvrohanv nvrohanv left a comment


Some comments on tokenization piece and handling of empty batch but overall looks good!

@schetlur-nv schetlur-nv requested a review from brb-nv April 7, 2026 22:12
@pcastonguay pcastonguay requested a review from Superjomn April 8, 2026 13:19
@pcastonguay
Collaborator

@Superjomn could you review since it adds a new method to LLM API? Thx.

@tingyangk tingyangk marked this pull request as ready for review April 13, 2026 17:16
@tingyangk tingyangk requested review from a team as code owners April 13, 2026 17:16
Contributor

coderabbitai Bot commented Apr 13, 2026

📝 Walkthrough

Introduces encoder-only inference support to TensorRT-LLM, bypassing generation components. Adds an EncoderExecutor class, a public encode() API, an encoder_only configuration flag on TorchLlmArgs, and infrastructure for lightweight encoder-only execution with control-flow validation and comprehensive unit tests.

Changes

  • Executor Infrastructure (tensorrt_llm/_torch/pyexecutor/encoder_executor.py, tensorrt_llm/_torch/pyexecutor/model_engine.py, tensorrt_llm/_torch/pyexecutor/py_executor_creator.py):
    Added lightweight EncoderExecutor class with batch_forward() for synchronous inference. Extended PyTorchModelEngine with _prepare_encoder_inputs() and encoder_forward() to handle encoder-only forward passes. Added factory function create_encoder_executor() to instantiate encoder-only executors without scheduler, sampler, or KV-cache components.
  • API & Configuration (tensorrt_llm/llmapi/llm.py, tensorrt_llm/llmapi/llm_args.py):
    Added EncoderOutput dataclass. Introduced encoder_only flag and _encoder_executor handle to BaseLLM/_TorchLLM. New public encode() method validates inputs, tokenizes, packs token IDs, calls executor, and returns encoder results. Updated generate_async(), get_stats(), get_kv_cache_events(), and _collective_rpc() to raise RuntimeError in encoder-only mode. Updated shutdown() and _check_health() for encoder executor lifecycle.
  • Tests (tests/unittest/llmapi/test_llm_encode.py):
    Added comprehensive test module validating single and batched text/token inputs, encoder output structure, correctness against HuggingFace reference, rejection of incompatible generation APIs in encoder-only mode, and health checks.
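
As a rough illustration of how these pieces fit together, here is a self-contained toy version of the packed single-pass flow. The class and method names follow the walkthrough above, but the toy engine, its signatures, and the pooling logic are invented for the sketch and do not reflect the merged implementation.

import torch
import torch.nn as nn

class _ToyEncoderEngine:
    """Stand-in for the engine's encoder path (illustrative only)."""

    def __init__(self, vocab_size=30522, hidden=64, num_classes=2):
        self.embed = nn.Embedding(vocab_size, hidden)
        self.head = nn.Linear(hidden, num_classes)

    def encoder_forward(self, inputs):
        # inputs: {"input_ids": LongTensor [num_tokens], "seq_lens": [int, ...]}
        hidden = self.embed(inputs["input_ids"])
        pooled, offset = [], 0
        for n in inputs["seq_lens"]:             # mean-pool each packed sequence
            pooled.append(hidden[offset:offset + n].mean(dim=0))
            offset += n
        return self.head(torch.stack(pooled))    # [batch, num_classes]

class EncoderExecutor:
    """Lightweight executor: no scheduler, sampler, or KV cache."""

    def __init__(self, model_engine):
        self.model_engine = model_engine

    def batch_forward(self, batched_inputs):
        # Single synchronous pass through the engine's encoder path.
        with torch.inference_mode():
            return self.model_engine.encoder_forward(batched_inputs)

# Usage: pack two sequences into one flat token tensor and run a single pass.
executor = EncoderExecutor(_ToyEncoderEngine())
packed = {"input_ids": torch.tensor([101, 7592, 102, 101, 3231, 102]),
          "seq_lens": [3, 3]}
print(executor.batch_forward(packed).shape)  # torch.Size([2, 2])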

Sequence Diagram

sequenceDiagram
    participant Client
    participant LLM
    participant InputProcessor
    participant EncoderExecutor
    participant ModelEngine
    participant CUDA

    Client->>LLM: encode(inputs)
    activate LLM
    LLM->>InputProcessor: tokenize(inputs)
    InputProcessor-->>LLM: token_ids
    LLM->>EncoderExecutor: batch_forward(packed_inputs)
    activate EncoderExecutor
    EncoderExecutor->>ModelEngine: encoder_forward(inputs)
    activate ModelEngine
    ModelEngine->>ModelEngine: _prepare_encoder_inputs(inputs)
    Note over ModelEngine: Validate tokens, copy to GPU, build attn_metadata
    ModelEngine->>ModelEngine: _forward_step(model_inputs, None, False)
    ModelEngine->>CUDA: Forward pass
    CUDA-->>ModelEngine: logits
    ModelEngine-->>EncoderExecutor: encoder_output
    deactivate ModelEngine
    EncoderExecutor-->>LLM: encoder_output
    deactivate EncoderExecutor
    LLM-->>Client: EncoderOutput(logits, prompt_token_ids, prompt)
    deactivate LLM

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

🚥 Pre-merge checks | ✅ 2 passed | ❌ 1 failed

❌ Failed checks (1 warning)

  • Docstring Coverage ⚠️ Warning: Docstring coverage is 75.00%, which is below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (2 passed)

  • Title check ✅ Passed: The title clearly and concisely summarizes the main change: adding a fast-path llm.encode() API specifically for encoder-only models.
  • Description check ✅ Passed: The PR description is comprehensive and follows the template structure with clear sections for problem, solution, usage, test coverage, and PR checklist completion.


Comment @coderabbitai help to get the list of available commands and usage tips.

Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 5

♻️ Duplicate comments (1)
tensorrt_llm/llmapi/llm.py (1)

753-758: ⚠️ Potential issue | 🟠 Major

Normalize batched inputs before indexing the first element.

The signature accepts Sequence[PromptInputs], but this branch only recognizes list. encode([]) throws an IndexError at Line 755, and encode(("a", "b")) gets treated as one prompt instead of a batch. Please handle empty batches and non-list sequences before peeking at inputs[0].

Suggested fix
-        unbatched = not isinstance(inputs, list)
+        unbatched = not isinstance(inputs, (list, tuple))
         if not unbatched:
+            if len(inputs) == 0:
+                raise ValueError("encode() requires at least one input.")
             if isinstance(inputs[0], int):
                 unbatched = True
         if unbatched:
             inputs = [inputs]
+        else:
+            inputs = list(inputs)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/llmapi/llm.py` around lines 753 - 758, In the encode path
normalize inputs without indexing inputs[0] prematurely: treat any non-sequence
(or str/bytes) as a single prompt and wrap it in a list, treat empty sequences
as an empty batch (do not access inputs[0]), and only peek the first element
when len(inputs) > 0 to detect an unbatched numeric input; update the logic in
the function/method handling the inputs variable (the block that currently
checks isinstance(inputs, list) and inputs[0]) to use collections.abc.Sequence
checks, guard against str/bytes, and check len(inputs) before accessing index 0.
🧹 Nitpick comments (3)
tensorrt_llm/_torch/pyexecutor/py_executor_creator.py (1)

228-231: Add return type annotation.

The function signature is missing a return type annotation. Add one for consistency with the codebase and the coding guidelines, which require static type annotations for all functions.

♻️ Suggested fix
+from .encoder_executor import EncoderExecutor
+
 def create_encoder_executor(
     llm_args: TorchLlmArgs,
     checkpoint_dir: Optional[str] = None,
-):
+) -> "EncoderExecutor":

Note: Use string annotation to avoid circular import, or move the import to the top of the file.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/pyexecutor/py_executor_creator.py` around lines 228 -
231, The create_encoder_executor function is missing a return type annotation;
update its signature to include the correct return type (e.g., the executor
class/type returned by create_encoder_executor) using a string literal
annotation to avoid circular imports or alternatively import the return type at
module top if safe; locate the function create_encoder_executor in
py_executor_creator.py and add the string-based return type annotation matching
the executor class name used elsewhere in this module.
tensorrt_llm/_torch/pyexecutor/encoder_executor.py (2)

57-59: Consider explicit resource cleanup in shutdown().

The current implementation relies on Python's garbage collector to release CUDA resources. While this works, explicit cleanup (e.g., calling any model engine cleanup methods or clearing CUDA cache) would be more deterministic, especially for repeated instantiation scenarios.

♻️ Suggested improvement
     def shutdown(self):
         """No background thread to stop — just release model engine resources."""
+        if hasattr(self.model_engine, 'model') and self.model_engine.model is not None:
+            del self.model_engine.model
         del self.model_engine
+        import gc
+        gc.collect()
+        torch.cuda.empty_cache()
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/pyexecutor/encoder_executor.py` around lines 57 - 59,
shutdown currently just deletes self.model_engine; update it to perform explicit
deterministic cleanup by calling any available cleanup/close method on the model
engine (e.g., self.model_engine.cleanup() or self.model_engine.close() if
present), then delete the attribute and clear CUDA memory with
torch.cuda.empty_cache() (ensure torch is imported), and handle missing cleanup
methods gracefully (use getattr or try/except) so repeated instantiation
releases GPU resources reliably.
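
For illustration, a self-contained version of the defensive cleanup this comment describes; the stand-in engine object and its close() hook are invented for the sketch, and the real PyTorchModelEngine API may differ.

import gc
import torch

class _ToyEngine:
    """Stand-in for a model engine holding GPU state (illustrative only)."""
    def close(self):
        print("engine resources released")

class EncoderExecutorSketch:
    def __init__(self):
        self.model_engine = _ToyEngine()

    def shutdown(self):
        # Best-effort deterministic cleanup: call whatever hook the engine exposes,
        # drop the reference, then reclaim CUDA memory if a GPU is present.
        engine = getattr(self, "model_engine", None)
        if engine is not None:
            hook = getattr(engine, "close", None) or getattr(engine, "cleanup", None)
            if callable(hook):
                hook()
            del self.model_engine
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

EncoderExecutorSketch().shutdown()  # prints "engine resources released"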

33-35: Add type hints for __init__ parameters.

The parameters model_engine and dist lack type annotations. Per coding guidelines, all function parameters should have static type annotations.

♻️ Suggested fix
+from .model_engine import PyTorchModelEngine
+from ..distributed import Distributed
+
 class EncoderExecutor:
     ...
-    def __init__(self, model_engine, dist):
+    def __init__(self, model_engine: PyTorchModelEngine, dist: Distributed):
         self.model_engine = model_engine
         self.dist = dist
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/pyexecutor/encoder_executor.py` around lines 33 - 35, Add
static type annotations to the __init__ parameters: import Any from typing and
change the signature to def __init__(self, model_engine: Any, dist: Any) -> None
so both model_engine and dist are typed and the constructor return type is
explicit; update any relevant imports if needed and keep the parameter names
(model_engine, dist) as-is to match the existing usage in the class.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tensorrt_llm/_torch/pyexecutor/model_engine.py`:
- Around line 3648-3649: The method signature for _prepare_encoder_inputs
violates Flake8 E128 due to under-indented continuation lines; fix by aligning
the continued parameter lines consistently—either place all parameters on the
same line or use a hanging indent where continuation lines are indented to align
under the opening parenthesis (e.g., 8 spaces from the left margin for
continuation) so the def _prepare_encoder_inputs(self, inputs: Dict[str, Any])
-> Dict[str, Any]: signature and its wrapped parameters follow PEP8/4-space
indentation rules and resolve E128.
- Around line 3663-3701: The model_inputs dict currently spreads **inputs
(including internal keys like 'seq_lens') which can cause TypeError for models
whose forward() doesn't accept these kwargs and also lacks input validation;
update the block that builds model_inputs in model_engine.py (around
model_inputs, input_ids_cuda, position_ids_cuda, and attn_metadata set-up) to:
1) explicitly remove/filter internal keys such as 'seq_lens', 'position_ids',
'token_type_ids', 'attention_mask' (and any other internal-only keys) from
inputs before merging so only valid model args are forwarded; 2) add validation
checks using seq_lens and batch_size to assert batch_size <= self.batch_size,
sum(seq_lens) == num_tokens, and position_ids.shape[0] == num_tokens (raise
clear ValueError messages); and 3) ensure position_ids is the correct
dtype/shape before copying to self.position_ids_cuda so the subsequent
attn_metadata and model_inputs construction use the sanitized inputs.

In `@tensorrt_llm/llmapi/llm.py`:
- Around line 738-740: The docstring advertises support for inputs_embeds but
the current flow strips it via _RESERVED_KEYS and always materializes input_ids
before calling batch_forward (see methods batch_forward and the forward
wrapper), so inputs_embeds is effectively unusable; either implement an
inputs_embeds execution branch that checks for "inputs_embeds" in model_kwargs,
skips materializing input_ids, preserves inputs_embeds through _RESERVED_KEYS to
the model, and passes it into batch_forward/forward path (including handling
attention_mask/position ids), or remove "inputs_embeds" from the public
docstrings and any places it’s removed by _RESERVED_KEYS; update both the
forward wrapper around batch_forward and the analogous code block referenced at
lines ~819-826 to match the chosen approach.
- Around line 748-760: In encode(), before dereferencing
self._encoder_executor.model_engine, check whether the encoder executor has been
shut down (e.g., if self._encoder_executor is None or has a closed flag) and
raise the same RuntimeError used for shutdown instead of letting an
AttributeError bubble up; update the beginning of the encode() method (which
already checks self._encoder_only) to verify the executor's presence and raise a
clear RuntimeError like "LLM has been shut down" if missing, so encode()
consistently errors on closed encoder-only LLMs.

In `@tests/unittest/llmapi/test_llm_encode.py`:
- Around line 119-131: The HF model is being moved to tllm_logits.device which
ended up being CPU; ensure the HuggingFace inference runs on the same CUDA
device as the TRT-LLM: compute a device variable (e.g., device =
tllm_logits.device if tllm_logits.device.type == "cuda" else
torch.device("cuda") if torch.cuda.is_available() else tllm_logits.device), then
call hf_model = hf_model.half().to(device) and move tokenizer outputs with
inputs = tokenizer(...).to(device) so AutoModelForSequenceClassification
(hf_model) and its inputs run on CUDA for a fair comparison with tllm_logits and
PROMPTS.
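
As an illustration of the device-consistent comparison this test comment asks for, a sketch follows; the prompt list, dtype choice, and tolerances are assumptions rather than the test as merged.

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

PROMPTS = ["Hello world", "Test sentence"]  # illustrative
MODEL = "textattack/bert-base-uncased-yelp-polarity"

def hf_reference_logits(tllm_logits: torch.Tensor) -> torch.Tensor:
    # Run the HF reference on the same CUDA device as the TRT-LLM output,
    # falling back to CUDA if available, else to wherever the logits live.
    device = tllm_logits.device if tllm_logits.device.type == "cuda" else (
        torch.device("cuda") if torch.cuda.is_available() else tllm_logits.device)
    tokenizer = AutoTokenizer.from_pretrained(MODEL)
    hf_model = AutoModelForSequenceClassification.from_pretrained(MODEL).half().to(device).eval()
    inputs = tokenizer(PROMPTS, return_tensors="pt", padding=True).to(device)
    with torch.no_grad():
        return hf_model(**inputs).logits

# Example assertion against logits gathered from llm.encode(PROMPTS):
# torch.testing.assert_close(tllm_logits, hf_reference_logits(tllm_logits),
#                            rtol=1e-2, atol=1e-2)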


ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro Plus

Run ID: 2fc3228c-1dc3-47f5-9f7b-8f3302f7fca1

📥 Commits

Reviewing files that changed from the base of the PR and between 4e69c14 and 83bc6b9.

📒 Files selected for processing (6)
  • tensorrt_llm/_torch/pyexecutor/encoder_executor.py
  • tensorrt_llm/_torch/pyexecutor/model_engine.py
  • tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
  • tensorrt_llm/llmapi/llm.py
  • tensorrt_llm/llmapi/llm_args.py
  • tests/unittest/llmapi/test_llm_encode.py

Collaborator

@Superjomn Superjomn left a comment


LGTM on the llmapi changes.

@schetlur-nv
Collaborator

/bot run

@github-actions

👎 Promotion blocked, new vulnerability found

Vulnerability report

  • Component: encode/uvicorn. Vulnerability: CVE-2020-7694. Severity: HIGH. Description: This affects all versions of package uvicorn. The request logger provided by the package is vulnerable to ANSI escape sequence injection. Whenever any HTTP request is received, the default behaviour of uvicorn is to log its details to either the console or a log file. When attackers request crafted URLs with percent-encoded escape sequences, the logging component will log the URL after it's been processed with urllib.parse.unquote, therefore converting any percent-encoded characters into their single-character equivalent, which can have special meaning in terminal emulators. By requesting URLs with crafted paths, attackers can pollute uvicorn's access logs, jeopardising the integrity of such files, and can use ANSI sequence codes to attempt to interact with the terminal emulator that's displaying the logs (either in real time or from a file).
  • Component: encode/uvicorn. Vulnerability: CVE-2020-7695. Severity: MEDIUM. Description: Uvicorn before 0.11.7 is vulnerable to HTTP response splitting. CRLF sequences are not escaped in the value of HTTP headers. Attackers can exploit this to add arbitrary headers to HTTP responses, or even return an arbitrary response body, whenever crafted input is used to construct HTTP headers.

@Funatiq
Collaborator

Funatiq commented Apr 17, 2026

/bot run

@tingyangk tingyangk requested a review from a team as a code owner April 17, 2026 08:29
@tingyangk
Collaborator Author

/bot run

@mikeiovine
Collaborator

/bot run

@tensorrt-cicd
Collaborator

PR_Github #44047 [ run ] triggered by Bot. Commit: 8700110 Link to invocation

@nvrohanv nvrohanv requested a review from litaotju April 18, 2026 01:36
@tensorrt-cicd
Collaborator

PR_Github #45260 [ run ] completed with state SUCCESS. Commit: 8b74a57
/LLM/main/L0_MergeRequest_PR pipeline #35518 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@tingyangk
Collaborator Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #45417 [ run ] triggered by Bot. Commit: 8b74a57 Link to invocation

@tingyangk tingyangk force-pushed the tingyangk/encoder-llmapi-optimize branch from 8b74a57 to 285ba46 Compare April 24, 2026 17:26
@tingyangk
Collaborator Author

/bot kill

@tingyangk
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #45427 [ kill ] triggered by Bot. Commit: 285ba46 Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #45428 [ run ] triggered by Bot. Commit: 285ba46 Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #45427 [ kill ] completed with state ABORTED. Commit: 285ba46

Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #45428 [ run ] completed with state SUCCESS. Commit: 285ba46
/LLM/main/L0_MergeRequest_PR pipeline #35661 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

Signed-off-by: tingyangk <tingyangk@nvidia.com>
Signed-off-by: tingyangk <tingyangk@nvidia.com>
Signed-off-by: tingyangk <tingyangk@nvidia.com>
Signed-off-by: tingyangk <tingyangk@nvidia.com>
Signed-off-by: tingyangk <tingyangk@nvidia.com>
Signed-off-by: tingyangk <tingyangk@nvidia.com>
Signed-off-by: tingyangk <tingyangk@nvidia.com>
Signed-off-by: tingyangk <tingyangk@nvidia.com>
Signed-off-by: tingyangk <tingyangk@nvidia.com>
@tingyangk tingyangk force-pushed the tingyangk/encoder-llmapi-optimize branch from 285ba46 to 38dac93 Compare April 26, 2026 06:52
@tingyangk
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #45556 [ run ] triggered by Bot. Commit: 38dac93 Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #45556 [ run ] completed with state SUCCESS. Commit: 38dac93
/LLM/main/L0_MergeRequest_PR pipeline #35774 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@tingyangk
Collaborator Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #45597 [ run ] triggered by Bot. Commit: 38dac93 Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #45597 [ run ] completed with state FAILURE. Commit: 38dac93
/LLM/main/L0_MergeRequest_PR pipeline #35813 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@tingyangk
Collaborator Author

/bot run

@tensorrt-cicd
Collaborator

PR_Github #45652 [ run ] triggered by Bot. Commit: 38dac93 Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #45652 [ run ] completed with state SUCCESS. Commit: 38dac93
/LLM/main/L0_MergeRequest_PR pipeline #35865 completed with status: 'SUCCESS'

CI Report

Link to invocation

@pcastonguay pcastonguay merged commit 00218e5 into NVIDIA:main Apr 27, 2026
5 checks passed

Labels

Community want to contribute PRs initiated from Community
