Skip to content

[https://nvbugs/6095421][fix] fix PP>=3 executor shutdown hang in broadcast sample state loop#13267

Merged
yihwang-nv merged 1 commit into
NVIDIA:mainfrom
yihwang-nv:yihwang/fix_pp_hang
May 10, 2026
Merged

[https://nvbugs/6095421][fix] fix PP>=3 executor shutdown hang in broadcast sample state loop#13267
yihwang-nv merged 1 commit into
NVIDIA:mainfrom
yihwang-nv:yihwang/fix_pp_hang

Conversation

@yihwang-nv
Copy link
Copy Markdown
Collaborator

@yihwang-nv yihwang-nv commented Apr 21, 2026

Description

Background

TestLlama3_1_8BInstruct::test_bfloat16_4gpus[pp4-attn_backend=FLASHINFER-torch_compile=False] and other PP4 cases deadlock during KV cache size estimation, right after the single dummy request is processed. The hang consistently reproduces without logger.info calls on the wait path and disappears when any I/O is injected — a classic timing-sensitive MPI/UCX deadlock. All cases were previously waived under nvbug/6095421.
Root cause is three independent PP event-loop / MPI-transport issues that compound at shutdown. Fixing any one of them reduces but does not eliminate the hang; all three must be fixed.

Root cause & fix

1. MPI_Comm_dup() race at startup

The broadcast_sample_state_handler thread called the collective MPI_Comm_dup on itself, concurrently with the event-loop thread's point-to-point traffic on the original communicator. Thread-scheduling jitter could let the event loop's isend/recv_object grab the pkl5 intracomm send/recv locks before Dup() returned, producing an intermittent startup deadlock.
Fix: Move Dup() to the main thread inside start_worker(), before either the worker or broadcast thread starts. All ranks participate in the collective while no rank is yet issuing point-to-point calls.

2. Freeing the duplicated comm while peers still had pending pkl5 sub-requests

The PP ring is asymmetric:
3 → 0 → 1 → 2 is_second_last_pp_rank (rank 2): only recv, no send is_last_pp_rank (rank 3): only send, no recv

Rank 2's broadcast thread therefore finishes its end-of-loop flush as a no-op, receives the shutdown sentinel, enters finally, and calls broadcast_mpi_comm.Free() almost immediately. pkl5 large objects are delivered as multiple underlying MPI_Request objects (header + data chunks); once the receiving side has torn down its handle to the communicator, MPI_Test on the peer ranks' subsidiary requests stops making progress, leaving wait_on_pp_send_handles spinning forever.
Fix: Do not Free() the duplicated communicator in the broadcast thread's finally. Let it live until process teardown; MPI_Finalize reclaims it. Cost is at most a couple of leaked comms per process lifetime (estimation + real executor).

3. UCX worker progress starvation at shutdown

Even after (1) and (2), the remaining rank's broadcast thread was stuck in
MPI_Waitall(count=2) → opal_progress → ucp_worker_progress → pthread_spin_lock (busy-spin, wchan=0)

while peer ranks had moved on to MPI_Allreduce on the main communicator inside the next executor's calculate_max_num_blocks. UCX rendezvous sends need two-sided progress to finalize, but no peer was polling the broadcast communicator anymore; the stuck rank could not push the handshake through on its own.
Fix: Remove the end-of-iteration wait_on_pp_send_handles in _broadcast_sample_state_loop. Correctness is preserved because the drain at the top of the next _ring_broadcast_sample_state already waits on the previous isend for the same microbatch_id before posting a new one. During shutdown the shared UCX worker drives broadcast-comm progress implicitly via subsequent MPI activity on the main communicator; remaining requests are reclaimed at MPI_Finalize.

Side refactor (not load-bearing for the hang fix)

KV cache estimation previously had every rank wait on await_responses. Only rank 0 actually needs the response payload (to surface response.has_error()). Estimation now has:

  • rank 0 await_responses(req_ids) and error-checks
  • all ranks wait on a new per-request local-termination tracker (track_request_completions / await_request_completions)
  • memory / KV stats are read in the same place as before (pre-shutdown)
    This makes the estimation lifecycle easier to reason about and removes the non-rank-0 dependency on response delivery, which is unnecessary for estimation.

Files changed

  • tensorrt_llm/_torch/pyexecutor/py_executor.py
    • start_worker(): move MPI_Comm_dup() to main thread
    • _broadcast_sample_state_loop(): use pre-duplicated comm, drop end-of-loop flush, drop Free()
    • _do_terminate_request(): notify tracked req-id completion
    • New helpers: track_request_completions, await_request_completions
  • tensorrt_llm/_torch/pyexecutor/_util.py
    • configure_kv_cache_capacity(): rank-0 awaits response, all ranks await local termination
  • tests/integration/test_lists/waives.txt
    • Un-waive 6 PP4 cases covered by this fix (see below)

Verification

Locally on 4× B300 (inside tensorrt_llm-devel container), each case previously waived under nvbug/6095421 with pp4 in its id:

Verdict Case
PASS TestLlama3_1_8BInstruct::test_bfloat16_4gpus[pp4-attn_backend=FLASHINFER-torch_compile=False] (82 s)
PASS TestLlama3_1_8BInstruct::test_fp8_4gpus[pp4-fp8kv=True-attn_backend=TRTLLM-torch_compile=False]
PASS TestDeepSeekV3Lite::test_bfloat16_4gpus[pp4-mtp_nextn=0-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
PASS TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-pp4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-low_precision_combine=False-torch_compile=True]
PASS TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=0-pp4-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-low_precision_combine=False-torch_compile=True]
PASS TestDeepSeekV3Lite::test_fp8_block_scales_4gpus[pp4-mtp_nextn=0-fp8kv=False-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False-sampler_async_worker=False]
unrelated failure TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=2-pp4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-low_precision_combine=False-torch_compile=False] — no longer hangs; now fails with an unrelated MoE NVFP4 autotune assertion at moe_gemm_template_dispatch_tma_ws.h:108 ("No Smem epilogue schedule is not supported for block scaled types or finalize fusion"). Tracked separately.

Summary by CodeRabbit

Release Notes

  • New Features

    • Added request completion tracking for improved synchronization in distributed inference scenarios.
  • Bug Fixes

    • Fixed MPI communicator handling in distributed broadcast operations.
    • Improved KV cache warmup phase synchronization across distributed ranks.
  • Tests

    • Re-enabled previously skipped test cases for LLM API PyTorch configurations.

@yihwang-nv yihwang-nv requested a review from a team as a code owner April 21, 2026 08:58
@yihwang-nv yihwang-nv requested a review from joyang-nv April 21, 2026 08:58
@yihwang-nv
Copy link
Copy Markdown
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #44686 [ run ] triggered by Bot. Commit: 63ecd5e Link to invocation

@coderabbitai
Copy link
Copy Markdown
Contributor

coderabbitai Bot commented Apr 21, 2026

📝 Walkthrough

Walkthrough

The changes enhance request completion tracking in the PyExecutor by introducing per-request synchronization primitives and methods, while refactoring KV cache capacity configuration to synchronize across distributed ranks. MPI communicator handling for the broadcast loop is also optimized to pre-duplicate and reuse the communicator instead of creating/freeing it per iteration.

Changes

Cohort / File(s) Summary
Request Completion Tracking
tensorrt_llm/_torch/pyexecutor/py_executor.py
Added track_request_completions() and await_request_completions() methods with condition variable-based synchronization. Updated _do_terminate_request() to notify completion waiters when tracked requests finish.
KV Cache Configuration Synchronization
tensorrt_llm/_torch/pyexecutor/_util.py
Modified configure_kv_cache_capacity() to call new request tracking methods. Restricted response payload processing to rank 0 and added explicit completion synchronization across all ranks before reading CUDA statistics.
MPI Communicator Optimization
tensorrt_llm/_torch/pyexecutor/py_executor.py
Changed PP sample-state broadcast thread to use pre-duplicated MPI communicator stored in _broadcast_mpi_comm instead of duplicating/freeing per iteration. Removed queue flush behavior waiting on send_handles.
Test Skip List Cleanup
tests/integration/test_lists/waives.txt
Removed SKIP entries for accuracy/test_llm_api_pytorch.py test cases covering DeepSeekV3Lite and Llama3.1-8B with various quantization and backend configurations.

Sequence Diagram(s)

sequenceDiagram
    participant Root as Root Rank (0)
    participant Other as Other Ranks
    participant PyExec as PyExecutor
    participant CUDA as CUDA Memory

    Root->>PyExec: enqueue_warmup_requests(req_ids)
    PyExec->>PyExec: distribute to workers
    
    Root->>PyExec: track_request_completions(req_ids)
    PyExec->>PyExec: store in tracked_request_completion_ids
    
    Root->>PyExec: start_worker()
    par
        Root->>PyExec: await_responses(req_ids)<br/>(rank==0 only)
        Other->>PyExec: process locally
    end
    
    PyExec->>PyExec: _do_terminate_request()<br/>(all ranks)
    PyExec->>PyExec: notify_all() on<br/>request_completion_cv
    
    Root->>PyExec: await_request_completions(req_ids)<br/>(barrier across ranks)
    Other->>PyExec: await_request_completions(req_ids)<br/>(barrier across ranks)
    
    par
        Root->>CUDA: read memory statistics
        Other->>CUDA: ready
    end
Loading

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 33.33% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Title check ✅ Passed Title is concise and clearly describes the main fix: addressing a PP>=3 executor shutdown hang in the broadcast sample state loop, with proper ticket reference and type tag.
Description check ✅ Passed Description is comprehensive, covering background, three root causes with fixes, side refactor, files changed, and local verification results against the repository template.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.

✨ Finishing Touches
🧪 Generate unit tests (beta)
  • Create PR with unit tests

Warning

Review ran into problems

🔥 Problems

Timed out fetching pipeline failures after 30000ms


Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Copy Markdown
Contributor

@coderabbitai coderabbitai Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tensorrt_llm/_torch/pyexecutor/_util.py`:
- Around line 433-446: The current change removes rank-local error propagation:
instead of only calling py_executor.await_responses(req_ids) and checking
ExecutorResponse.has_error()/error_msg on rank 0, restore per-rank error
checking (call py_executor.await_responses(req_ids) and inspect each
ExecutorResponse.has_error() on every rank) before reading memory stats, or
implement an explicit distributed error reduction (use py_executor.dist.rank and
a collective to fail-fast) that aggregates any per-rank error flags/messages and
raises a RuntimeError with the combined message; ensure checks reference
py_executor.await_responses, ExecutorResponse.has_error, and response.error_msg
and run on non-zero ranks prior to
py_executor.await_request_completions(req_ids).
🪄 Autofix (Beta)

Fix all unresolved CodeRabbit comments on this PR:

  • Push a commit to this branch (recommended)
  • Create a new PR with the fixes

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro Plus

Run ID: 32f5b885-213e-4e83-8e0c-210281e417cb

📥 Commits

Reviewing files that changed from the base of the PR and between c665127 and 63ecd5e.

📒 Files selected for processing (3)
  • tensorrt_llm/_torch/pyexecutor/_util.py
  • tensorrt_llm/_torch/pyexecutor/py_executor.py
  • tests/integration/test_lists/waives.txt
💤 Files with no reviewable changes (1)
  • tests/integration/test_lists/waives.txt

Comment thread tensorrt_llm/_torch/pyexecutor/_util.py Outdated
@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #44686 [ run ] completed with state SUCCESS. Commit: 63ecd5e
/LLM/main/L0_MergeRequest_PR pipeline #35053 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@yihwang-nv
Copy link
Copy Markdown
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #45209 [ run ] triggered by Bot. Commit: 63ecd5e Link to invocation

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #45209 [ run ] completed with state SUCCESS. Commit: 63ecd5e
/LLM/main/L0_MergeRequest_PR pipeline #35476 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@yuxianq yuxianq self-requested a review April 24, 2026 09:39
@yihwang-nv yihwang-nv force-pushed the yihwang/fix_pp_hang branch from 63ecd5e to 999e25a Compare April 28, 2026 05:10
@yihwang-nv
Copy link
Copy Markdown
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #45858 [ run ] triggered by Bot. Commit: 999e25a Link to invocation

@yihwang-nv yihwang-nv enabled auto-merge (squash) April 28, 2026 07:58
@yihwang-nv
Copy link
Copy Markdown
Collaborator Author

/bot help

@github-actions
Copy link
Copy Markdown

GitHub Bot Help

/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provide a user friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.

Details

run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental) --high-priority]

Launch build/test pipelines. All previously running jobs will be killed.

--reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline or the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will be always ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.

--disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.

--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.

--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.

--stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Supports wildcard * for pattern matching (e.g., "*PerfSanity*" matches all stages containing PerfSanity). Examples: "A10-PyTorch-1, xxx", "PerfSanity". Note: Does NOT update GitHub check status.

--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.

--test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages which don't match the specified backends. Only support [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.

--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.

--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.

--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.

--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Supports wildcard * for pattern matching. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx", --extra-stage "Post-Merge".

--detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.

--debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purpose. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

--high-priority (OPTIONAL) : Run the pipeline with high priority. This option is restricted to authorized users only and will route the job to a high-priority queue.

kill

kill

Kill all running builds associated with pull request.

skip

skip --comment COMMENT

Skip testing for latest commit on pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

@yihwang-nv
Copy link
Copy Markdown
Collaborator Author

/bot kill

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #45946 [ kill ] triggered by Bot. Commit: 999e25a Link to invocation

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #45946 [ kill ] completed with state SUCCESS. Commit: 999e25a
Successfully killed previous jobs for commit 999e25a

Link to invocation

@yihwang-nv
Copy link
Copy Markdown
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #46153 [ run ] triggered by Bot. Commit: 999e25a Link to invocation

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #46153 [ run ] completed with state SUCCESS. Commit: 999e25a
/LLM/main/L0_MergeRequest_PR pipeline #36279 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@yihwang-nv yihwang-nv force-pushed the yihwang/fix_pp_hang branch from 999e25a to feb9fe0 Compare April 30, 2026 02:15
@yihwang-nv
Copy link
Copy Markdown
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #46270 [ run ] triggered by Bot. Commit: feb9fe0 Link to invocation

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #46270 [ run ] completed with state SUCCESS. Commit: feb9fe0
/LLM/main/L0_MergeRequest_PR pipeline #36375 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

CI Agent Failure Analysis

Link to invocation

@yihwang-nv yihwang-nv force-pushed the yihwang/fix_pp_hang branch from feb9fe0 to 9980b92 Compare April 30, 2026 16:40
@yihwang-nv
Copy link
Copy Markdown
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #46408 [ run ] triggered by Bot. Commit: 9980b92 Link to invocation

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #46408 [ run ] completed with state SUCCESS. Commit: 9980b92
/LLM/main/L0_MergeRequest_PR pipeline #36484 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@yihwang-nv
Copy link
Copy Markdown
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #46547 [ run ] triggered by Bot. Commit: 9980b92 Link to invocation

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #46547 [ run ] completed with state FAILURE. Commit: 9980b92
/LLM/main/L0_MergeRequest_PR pipeline #36604 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

CI Agent Failure Analysis

Link to invocation

@yihwang-nv yihwang-nv force-pushed the yihwang/fix_pp_hang branch from 9980b92 to 1a5b4eb Compare May 6, 2026 07:03
@yihwang-nv
Copy link
Copy Markdown
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #46944 [ run ] triggered by Bot. Commit: 1a5b4eb Link to invocation

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #46944 [ run ] completed with state SUCCESS. Commit: 1a5b4eb
/LLM/main/L0_MergeRequest_PR pipeline #36944 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

…adcast sample state loop

Signed-off-by: Yihan Wang <yihwang@nvidia.com>
@yihwang-nv yihwang-nv force-pushed the yihwang/fix_pp_hang branch from 1a5b4eb to 0049242 Compare May 8, 2026 02:35
@yihwang-nv
Copy link
Copy Markdown
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #47291 [ run ] triggered by Bot. Commit: 0049242 Link to invocation

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #47291 [ run ] completed with state SUCCESS. Commit: 0049242
/LLM/main/L0_MergeRequest_PR pipeline #37233 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

CI Agent Failure Analysis

Link to invocation

@yihwang-nv
Copy link
Copy Markdown
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #47398 [ run ] triggered by Bot. Commit: 0049242 Link to invocation

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #47398 [ run ] completed with state SUCCESS. Commit: 0049242
/LLM/main/L0_MergeRequest_PR pipeline #37328 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@yihwang-nv
Copy link
Copy Markdown
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #47575 [ run ] triggered by Bot. Commit: 0049242 Link to invocation

@tensorrt-cicd
Copy link
Copy Markdown
Collaborator

PR_Github #47575 [ run ] completed with state SUCCESS. Commit: 0049242
/LLM/main/L0_MergeRequest_PR pipeline #37486 completed with status: 'SUCCESS'

CI Report

Link to invocation

@yihwang-nv yihwang-nv merged commit b535204 into NVIDIA:main May 10, 2026
6 checks passed
yufeiwu-nv pushed a commit to yufeiwu-nv/TensorRT-LLM that referenced this pull request May 19, 2026
…adcast sample state loop (NVIDIA#13267)

Signed-off-by: Yihan Wang <yihwang@nvidia.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants