[#4674][feat] optimize llama8B decode: trtllm silu_mul backend, quant+silu_mul, QKV passthrough to attention #12507

Merged

MrGeva merged 9 commits into NVIDIA:main from nv-auto-deploy:opt_qkv on May 7, 2026

Conversation

MrGeva (Collaborator) commented Mar 24, 2026

Three optimizations to close the AD-vs-PT decode gap on Llama 3.1 8B FP8:

1. Add a trtllm backend for the silu_and_mul custom op that delegates to
   torch.ops.trtllm.silu_and_mul (a Triton kernel), which is faster than
   FlashInfer and natively supports fused FP8 output quantization via
   scale/dtype params. Selectable via config (backend: trtllm); a sketch
   follows this list.

2. Fuse FP8 input quantization into silu_and_mul: when the sole consumer
   is trtllm_quant_fp8_linear, extract input_scale and fold it into the
   silu_and_mul call so the Triton kernel quantizes in-kernel, eliminating
   the separate scaleMatrixPerTensorVec kernel.

3. Fused QKV passthrough: trace pre-RoPE Q/K/V back through the
   narrow->view->contiguous chain to the flat fused QKV GEMM output and
   rewire all three attention args to use it directly. This eliminates
   2 bf16 copy kernels + 1 CatArrayBatchedCopy per layer (the zero-copy
   check was failing due to non-contiguous split views). A tracing sketch
   follows the benchmark below.
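
A minimal sketch of how optimizations 1 and 2 combine in the new backend op (the real registration lives in tensorrt_llm/_torch/auto_deploy/custom_ops/linear/silu_mul.py; the kwarg names forwarded to the underlying kernel are assumptions based on this description):

```python
from typing import Optional

import torch


def trtllm_silu_and_mul(
    x: torch.Tensor,
    scale: Optional[torch.Tensor] = None,
    out_dtype: Optional[str] = None,
) -> torch.Tensor:
    """Sketch only: models the op's Optional[str] dtype contract; the
    scale/dtype kwargs passed to the kernel below are assumptions."""
    # Resolve the config-supplied dtype name, e.g. "float8_e4m3fn".
    dtype = getattr(torch, out_dtype) if out_dtype is not None else None
    # Delegate to the TRT-LLM Triton kernel; when scale/dtype are set, the
    # kernel quantizes to FP8 in-kernel, which is what eliminates the
    # separate scaleMatrixPerTensorVec launch.
    return torch.ops.trtllm.silu_and_mul(x, scale=scale, dtype=dtype)
```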

Benchmark (Llama 3.1 8B FP8, 1k/2k/64, H100):
  AD baseline:  7,527 output tps, 17,007 ms latency
  AD optimized: 7,966 output tps, 16,069 ms latency (+5.8%)
  PT baseline:  8,477 output tps
  Gap reduced from 11.2% to 6.0% (46% of gap closed)
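
For optimization 3, the QKV walk-back amounts to a short reverse traversal of the FX graph. A minimal sketch under stated assumptions (the op set and function name here are illustrative; the PR's actual helper is _try_trace_to_fused_qkv()):

```python
import torch
from torch.fx import Node

# Ops that may sit between an attention input and the fused QKV GEMM output.
_PASSTHROUGH_OPS = {
    torch.ops.aten.narrow.default,
    torch.ops.aten.view.default,
    torch.ops.aten.reshape.default,
    torch.ops.aten.contiguous.default,
}


def trace_to_fused_qkv(node: Node) -> Node:
    """Walk one pre-RoPE Q/K/V arg back through its narrow->view->contiguous
    chain toward the flat fused QKV GEMM output."""
    while node.op == "call_function" and node.target in _PASSTHROUGH_OPS:
        node = node.args[0]
    # Callers then verify all three args resolve to the same GEMM node
    # before rewiring the attention op to consume it directly.
    return node
```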

Summary by CodeRabbit

  • Refactor

    • Enhanced attention fusion with fused QKV tensor support for improved performance
    • Refactored SiLU multiplication fusion to support multiple backends (FlashInfer and TRT-LLM) with optimized quantization handling
    • Updated transform pipeline configuration system with backend-specific settings
  • Tests

    • Reorganized unit tests to validate FX-graph fusion transformations and cache insertion compatibility

Description

Test Coverage

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • Update tava architecture diagram if there is a significant design change in PR.

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

To see a list of available CI bot commands, please comment /bot help.

MrGeva (Collaborator, Author) commented Apr 15, 2026

/bot run --extra-stage "DGX_B200-4_GPUs-AutoDeploy-1, DGX_H100-4_GPUs-AutoDeploy-1"

tensorrt-cicd (Collaborator): PR_Github #43527 [ run ] triggered by Bot. Commit: d7bb3c2

tensorrt-cicd (Collaborator): PR_Github #43527 [ run ] completed with state FAILURE. Commit: d7bb3c2

MrGeva (Collaborator, Author) commented Apr 15, 2026

/bot run --extra-stage "DGX_B200-4_GPUs-AutoDeploy-1, DGX_H100-4_GPUs-AutoDeploy-1"

tensorrt-cicd (Collaborator): PR_Github #43568 [ run ] triggered by Bot. Commit: d7bb3c2

tensorrt-cicd (Collaborator): PR_Github #43568 [ run ] completed with state SUCCESS. Commit: d7bb3c2
/LLM/main/L0_MergeRequest_PR pipeline #34067 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again


@MrGeva MrGeva force-pushed the opt_qkv branch 2 times, most recently from 0f97953 to b108e41 on May 4, 2026 13:42
MrGeva added 2 commits May 4, 2026 06:59
Switch the trtllm silu+mul custom op from a private int-encoded dtype
contract (_DTYPE_TO_INT/_INT_TO_DTYPE) to an Optional[str] like
trtllm_quant_fp8_linear, resolved internally via getattr(torch, name).

Signed-off-by: Eran Geva <19514940+MrGeva@users.noreply.github.com>
Replace the 4-key dispatch dict with the standard
str(dtype).removeprefix("torch.") idiom that already appears
elsewhere in auto_deploy.

Signed-off-by: Eran Geva <19514940+MrGeva@users.noreply.github.com>
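
Taken together, the two commits settle on a simple round-trip; a quick illustration (not code from the PR):

```python
import torch

# dtype -> config string: the idiom the second commit adopts
name = str(torch.float8_e4m3fn).removeprefix("torch.")  # "float8_e4m3fn"

# config string -> dtype: how the op resolves it internally
assert getattr(torch, name) is torch.float8_e4m3fn
```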
@MrGeva MrGeva changed the title optimize AD decode: trtllm silu backend, FP8 quant fusion, fused QKV passthrough [#4674][feat] optimize llama8B decode: trtllm silu_mul backend, quant+silu_mul, QKV passthrough to attention May 4, 2026
@MrGeva MrGeva marked this pull request as ready for review May 4, 2026 14:21
@MrGeva MrGeva requested a review from a team as a code owner May 4, 2026 14:21
@MrGeva MrGeva requested a review from greg-kwasniewski1 May 4, 2026 14:21
MrGeva (Collaborator, Author) commented May 4, 2026

/bot run --extra-stage "DGX_B200-4_GPUs-AutoDeploy-1, DGX_H100-4_GPUs-AutoDeploy-1" --disable-fail-fast

tensorrt-cicd (Collaborator): PR_Github #46647 [ run ] triggered by Bot. Commit: 9a96e88

coderabbitai (Contributor, Bot) commented May 4, 2026

📝 Walkthrough

This pull request introduces backend-agnostic SiLU+mul fusion with FlashInfer and TRT-LLM implementations, extends RoPE fusion to support fused-QKV rewiring, and enhances the TRT-LLM attention custom op to accept fused-QKV mode hints. Configuration files add the new fuse_silu_mul backend selector and introduce an optional FP8 GEMM fusion transform. Integration tests verify TRT-LLM cache insertion with the updated fused-QKV pipeline.

Changes

Backend-Specific SiLU+Mul Fusion

  • Custom Ops (tensorrt_llm/_torch/auto_deploy/custom_ops/linear/silu_mul.py): Replaces the single silu_and_mul custom op with two backend-specific variants: flashinfer_silu_and_mul (FlashInfer kernel or manual SiLU×up) and trtllm_silu_and_mul (TRT-LLM kernel with optional FP8 quantization via scale and out_dtype parameters). Adds a _resolve_out_dtype() helper.
  • Transform (tensorrt_llm/_torch/auto_deploy/transform/library/fuse_silu_mul.py): Consolidates the narrow and getitem fusion patterns into a single aten.mul.Tensor traversal; adds FuseSiluMulConfig with a backend selector; implements FP8 quantization fusion for the TRT-LLM backend via _try_fuse_fp8_quant().
  • Configuration (tensorrt_llm/_torch/auto_deploy/config/default.yaml): Updates fuse_silu_mul to explicitly specify backend: flashinfer; adds the fuse_fp8_gemms transform as a disabled placeholder.
  • Model Config Example (examples/auto_deploy/model_registry/configs/llama3_1_8b.yaml): Adds a fuse_silu_mul transform entry with backend: trtllm and enabled: true.
  • Tests (tests/unittest/auto_deploy/singlegpu/transformations/library/test_fuse_silu_mul.py): Simplifies to unit-level FX-graph fusion tests; removes the end-to-end export harness; updates the narrow and getitem variants to expect the flashinfer_silu_and_mul op; adds a backend parameter to _run_fuse_silu_mul().

Fused QKV Pipeline (RoPE Fusion + TRT-LLM Attention Integration)

  • RoPE Fusion Config & Transform (tensorrt_llm/_torch/auto_deploy/transform/library/fuse_rope_into_trtllm_attention.py): Adds FuseRopeIntoTrtllmAttentionConfig with a fuse_qkv_passthrough flag (default True); extends _try_fuse_one() to conditionally trace pre-RoPE Q/K/V back to a single fused QKV tensor via the new _try_trace_to_fused_qkv() helper; stores head/dimension hints in attention-node metadata when the fused-QKV path is enabled.
  • TRT-LLM Attention Custom Op (tensorrt_llm/_torch/auto_deploy/custom_ops/attention/trtllm_attention.py): Extends the trtllm_mha_with_cache signature with num_heads_hint, num_kv_heads_hint, and head_dim_hint parameters to enable fused-QKV mode; adds _materialize_out_scale() for static FP8 out_scale buffers; updates cache initialization and constant extraction to propagate fused-QKV metadata; adjusts the fake implementation to match the fused-QKV output shape.
  • Integration Tests (tests/unittest/auto_deploy/singlegpu/transformations/library/test_gemm_fusion.py): Adds two TRT-LLM cache-insertion tests (test_fuse_qkv_with_trtllm_cache_insertion() and test_fuse_qkv_gqa_with_trtllm_cache_insertion()) verifying post-GEMM fusion, cache insertion, and expected node/output counts with the fused-QKV structure.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~75 minutes

🚥 Pre-merge checks | ✅ 3 | ❌ 2

❌ Failed checks (1 warning, 1 inconclusive)

  • Docstring Coverage: ⚠️ Warning. Docstring coverage is 70.27%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.
  • Description check: ❓ Inconclusive. The PR description includes detailed explanations of all three optimizations with benchmark results, but lacks the explicit Test Coverage and Checklist sections required by the template. Resolution: add a dedicated Test Coverage section listing the specific test cases (test_fuse_silu_mul_narrow_variant, test_fuse_silu_mul_getitem_variant, test_fuse_qkv_with_trtllm_cache_insertion, etc.) that validate the changes.

✅ Passed checks (3 passed)

  • Title check: ✅ Passed. The PR title clearly identifies the three main optimizations (trtllm silu_mul backend, quantization fusion, and QKV passthrough), all of which are reflected in the changeset.
  • Linked Issues check: ✅ Passed. Check skipped because no linked issues were found for this pull request.
  • Out of Scope Changes check: ✅ Passed. Check skipped because no linked issues were found for this pull request.


coderabbitai (Contributor, Bot) left a comment

Actionable comments posted: 4

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tests/unittest/auto_deploy/singlegpu/transformations/library/test_fuse_silu_mul.py (1)

85-111: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Add coverage for the new TRT-LLM fusion path.

These tests still only assert the FlashInfer default path, so the new backend="trtllm" branch and _try_fuse_fp8_quant() rewrite can regress without failing this suite. Please add at least one case that fuses to torch.ops.auto_deploy.trtllm_silu_and_mul.default and verifies the downstream trtllm_quant_fp8_linear rewrite (scale folded into the fused op, out_dtype backfilled). QA list updates are unnecessary for this unit-only coverage.

As per coding guidelines: Coverage expectations — note missing parametrization where multiple backends or dtypes apply, and missing failure modes relevant to the feature.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@tests/unittest/auto_deploy/singlegpu/transformations/library/test_fuse_silu_mul.py`
around lines 85 - 111, Add a unit test variant that exercises the TRT-LLM fusion
path: call the existing _build_narrow_silu_mul_graph() (or
_build_getitem_silu_mul_graph()) and invoke _run_fuse_silu_mul(gm,
backend="trtllm") to trigger trtllm fusion; assert info.num_matches == 1 and
that the graph contains torch.ops.auto_deploy.trtllm_silu_and_mul.default
exactly once and no aten.silu/aten.mul ops, then locate the downstream quant
rewrite trtllm_quant_fp8_linear in the transformed graph and assert the scale
has been folded into the fused op and out_dtype has been backfilled (verify
scale removed from separate node and out_dtype attribute present on the fused
op); add a complementary negative test using _run_fuse_silu_mul(...,
enabled=False) or a non-trtllm backend to ensure the trtllm fusion is skipped.
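
A sketch of what such a test could look like, reusing the suite's helpers named in the prompt above (the (gm, info) return shape of _run_fuse_silu_mul is an assumption):

```python
import torch


def test_fuse_silu_mul_trtllm_backend():
    # Sketch only: _build_narrow_silu_mul_graph and _run_fuse_silu_mul are
    # the suite's existing helpers referenced above.
    gm = _build_narrow_silu_mul_graph()
    gm, info = _run_fuse_silu_mul(gm, backend="trtllm")
    assert info.num_matches == 1
    # The fused op should appear exactly once...
    fused = [
        n for n in gm.graph.nodes
        if n.target == torch.ops.auto_deploy.trtllm_silu_and_mul.default
    ]
    assert len(fused) == 1
    # ...and the original silu/mul pattern should be gone.
    assert not any(
        n.target in (torch.ops.aten.silu.default, torch.ops.aten.mul.Tensor)
        for n in gm.graph.nodes
    )
```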
🧹 Nitpick comments (1)
examples/auto_deploy/model_registry/configs/llama3_1_8b.yaml (1)

22-24: Please confirm this opt-in is covered by perf CI/QA.

This YAML flips Llama 3.1 8B FP8 decode onto the new backend, but I don’t see matching perf-list updates in the provided diff. Please verify that existing entries in tests/integration/test_lists/test-db/l0_*.yml and tests/integration/test_lists/qa/llm_perf_*.yml already exercise this exact AutoDeploy path; otherwise this optimization won’t be regression-tracked.

As per coding guidelines: For performance-sensitive paths, check whether a perf test entry is present or updated in tests/integration/test_lists/test-db/l0_*.yml and tests/integration/test_lists/qa/llm_perf_*.yml.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/auto_deploy/model_registry/configs/llama3_1_8b.yaml` around lines 22
- 24, You enabled the fuse_silu_mul opt-in for the Llama 3.1 8B config
(fuse_silu_mul: enabled, backend: trtllm) but did not add or verify
corresponding perf regression coverage; update the CI/QA perf test lists to
cover this new AutoDeploy path by adding or confirming entries in the L0 test
list (l0_*.yml) and the perf QA list (llm_perf_*.yml) that reference the
llama3_1_8b configuration and the trtllm/fuse_silu_mul combination, or document
why existing entries already cover it; ensure the test names or selectors
explicitly exercise FP8 decode on the trtllm backend so the change is tracked by
perf CI.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@tensorrt_llm/_torch/auto_deploy/custom_ops/attention/trtllm_attention.py`:
- Around line 844-855: When handling fused QKV mode in source_attn_node (check
meta "_trtllm_fused_qkv"), do not default kv_dtype to torch.bfloat16 when
source_attn_node.args[0].meta.get("val") is missing; instead detect q_meta ==
None and raise a clear exception (or assert) indicating missing fused-QKV meta
so tracing/propagation bugs fail fast; keep the existing logic to infer
num_kv_heads and head_dim from "_trtllm_num_kv_heads" and "_trtllm_head_dim" and
only set kv_dtype from q_meta.dtype when q_meta is present (refer to
source_attn_node.args[0], _trtllm_fused_qkv, kv_dtype).

In
`@tensorrt_llm/_torch/auto_deploy/transform/library/fuse_rope_into_trtllm_attention.py`:
- Around line 361-373: _trace_split currently only unwraps nodes whose
call_function target name is "contiguous", so when the exporter emits the aten
overload (torch.ops.aten.contiguous.default) the function fails to recognize it
and returns None; update _trace_split to also detect and unwrap nodes that are
aten.contiguous (e.g., check is_op(current, torch.ops.aten.contiguous.default)
in addition to the getattr(current.target, "__name__") == "contiguous" check),
ensuring the same unwrap logic used for contiguous in _trace_narrow is applied
before testing view/reshape and returning the traced node.

In `@tensorrt_llm/_torch/auto_deploy/transform/library/fuse_silu_mul.py`:
- Around line 333-356: The current narrow-match returns gate_parent and
gate_size without ensuring the parent tensor's last-dimension equals exactly
2*gate_size; update the check in the fuse_silu_mul matcher (around
_get_narrow_info usage and the return) to verify the parent extent covers
exactly the fused projection by asserting the parent tensor's relevant dimension
== gate_size * 2 (for split-based matches also ensure split sizes are exactly
[gate_size, gate_size]); if this condition fails, return None so you only
rewrite when the parent extent is exactly 2*d.
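
The suggested guard reduces to a few lines; a sketch, with the helper name invented for illustration:

```python
from torch.fx import Node


def covers_full_axis(gate_parent: Node, gate_size: int) -> bool:
    """Return True only when the two narrows together span the parent's
    whole activation axis, i.e. its last dim is exactly 2 * gate_size."""
    parent_val = gate_parent.meta.get("val")
    return parent_val is not None and parent_val.shape[-1] == 2 * gate_size
```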

In
`@tests/unittest/auto_deploy/singlegpu/transformations/library/test_gemm_fusion.py`:
- Around line 842-897: The test never verifies the fused-QKV passthrough because
QKVAttentionModel uses torch_attention (no torch_rope_* nodes), so update the
test_gemm_fusion.py case (the function
test_fuse_qkv_with_trtllm_cache_insertion) to exercise a rope-backed path or to
assert fused-QKV hints: ensure the graph includes rope nodes (so
fuse_rope_into_trtllm_attention runs) or after running
InferenceOptimizer(insert_cached_attention) locate the inserted
trtllm_attention_mha_with_cache node
(torch.ops.auto_deploy.trtllm_attention_mha_with_cache.default) and assert its
metadata includes a non-zero _trtllm_fused_qkv / head-hints value; reference
functions/classes: QKVAttentionModel, fuse_rope_into_trtllm_attention,
_trtllm_fused_qkv, and insert_cached_attention when adding the check.


ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: b5e71f70-8702-4831-be1e-9a8c560f06ca

📥 Commits

Reviewing files that changed from the base of the PR and between abe5570 and 9a96e88.

📒 Files selected for processing (8)
  • examples/auto_deploy/model_registry/configs/llama3_1_8b.yaml
  • tensorrt_llm/_torch/auto_deploy/config/default.yaml
  • tensorrt_llm/_torch/auto_deploy/custom_ops/attention/trtllm_attention.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/linear/silu_mul.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/fuse_rope_into_trtllm_attention.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/fuse_silu_mul.py
  • tests/unittest/auto_deploy/singlegpu/transformations/library/test_fuse_silu_mul.py
  • tests/unittest/auto_deploy/singlegpu/transformations/library/test_gemm_fusion.py

MrGeva added 4 commits May 4, 2026 07:33
The helper is required when piecewise CUDA graphs are enabled: the
dynamic-submodule split causes the cached-attention call_function to
already carry ``out=None`` positionally (because get_constants returns
args after the schema's ``out`` slot), so adding ``out`` as a kwarg
on top would over-supply the argument list.  Removing the helper
silently broke PWCG with:
  RuntimeError: auto_deploy::trtllm_attention_mha_with_cache()
  expected at most 20 argument(s) but received 21 argument(s).

Signed-off-by: Eran Geva <19514940+MrGeva@users.noreply.github.com>
Putting ``out`` last (after the rope and fused-QKV-passthrough hints)
matches the FlashInfer attention op layout: ``get_constants`` no longer
needs to fill ``out=None`` positionally, and ``_inject_out_param`` can
just add it as a kwarg without colliding.  Drops ``_find_positional_out_arg``
that previously worked around the conflict for the PWCG path.

Signed-off-by: Eran Geva <19514940+MrGeva@users.noreply.github.com>
Storing the dtype on attn_node.meta when the rope transform fires
means we still have a trustworthy source: at that point V (and the
pre-RoPE K/Q) are still typed correctly.  Reading from the fused-QKV
node at cache_init time was unreliable because subsequent FP8 GEMM
rewrites can drop meta['val'] on the node we point to.

Drop the silent ``bfloat16`` default in TrtllmAttention.get_cache_initializers
that masked this propagation issue; if neither the pre-captured dtype
nor the fused-QKV node's meta is available, raise a clear assertion.

Signed-off-by: Eran Geva <19514940+MrGeva@users.noreply.github.com>
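
Conceptually, the capture at rope-fusion time looks like the following sketch (the function and meta-key names are hypothetical):

```python
from torch.fx import Node


def capture_kv_dtype(attn_node: Node, v_node: Node) -> None:
    """Record V's dtype while meta['val'] is still trustworthy, so
    get_cache_initializers can read it after later FP8 GEMM rewrites drop
    meta['val'] on the fused-QKV node."""
    v_val = v_node.meta.get("val")
    assert v_val is not None, "pre-RoPE V lost meta['val'] before rope fusion"
    attn_node.meta["_trtllm_kv_dtype"] = v_val.dtype  # hypothetical key name
```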
1. fuse_rope_into_trtllm_attention: _trace_split was only unwrapping
   nodes whose target.__name__ == 'contiguous' (the Tensor.contiguous
   method form), missing aten.contiguous.default emitted by some
   exporter paths.  Factor the unwrap into a shared _unwrap_contiguous
   helper used by both _trace_split and _trace_narrow.

2. fuse_silu_mul: the narrow/getitem matcher was returning
   (gate_parent, gate_size) without checking that the parent's last
   dim is exactly 2*gate_size.  Add the check (reading parent.meta.val)
   so we only fuse when the two narrows together cover the parent's
   full activation axis — the kernel always splits at half, so any
   trailing slice would be silently misinterpreted.

3. tests: add a dedicated test_fuse_qkv_passthrough_with_rope that
   actually exercises fuse_rope_into_trtllm_attention with a model
   carrying torch_rope_with_explicit_cos_sin and asserts the
   _trtllm_fused_qkv / _trtllm_num_heads / _trtllm_num_kv_heads /
   _trtllm_head_dim metadata is set on the attention node.  Existing
   tests only ran fuse_gemms_mixed_children + cache insertion and
   never verified the passthrough fired.

Signed-off-by: Eran Geva <19514940+MrGeva@users.noreply.github.com>
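
The shared helper from fix 1 plausibly reduces to a short loop; a sketch based only on the two target forms named above:

```python
import torch
from torch.fx import Node


def _unwrap_contiguous(node: Node) -> Node:
    """Skip contiguous nodes in either form the exporter may emit: the
    Tensor.contiguous method wrapper or aten.contiguous.default."""
    while node.op == "call_function" and (
        node.target is torch.ops.aten.contiguous.default
        or getattr(node.target, "__name__", None) == "contiguous"
    ):
        node = node.args[0]
    return node
```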
MrGeva (Collaborator, Author) commented May 4, 2026

/bot run --extra-stage "DGX_B200-4_GPUs-AutoDeploy-1, DGX_H100-4_GPUs-AutoDeploy-1" --disable-fail-fast

@MrGeva MrGeva enabled auto-merge (squash) May 4, 2026 19:58
tensorrt-cicd (Collaborator): PR_Github #46682 [ run ] triggered by Bot. Commit: afece81

tensorrt-cicd (Collaborator): PR_Github #46682 [ run ] completed with state SUCCESS. Commit: afece81
/LLM/main/L0_MergeRequest_PR pipeline #36721 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again


Three tests were either broken at authoring time or stopped matching
the current codebase shape; align them with current main:

* test_fuse_qkv_with_trtllm_cache_insertion / _gqa_:
  - CachedSequenceInterface required max_num_tokens since NVIDIA#12708
    (2026-04-13); tests were authored without it.  Pass max_num_tokens=256.
  - The post-cache-insertion graph now contains 4 split-output getitems
    (3 from QKV split + 1 metadata-prep tuple).  Relax the strict
    ==3 check to >=3 so cache-insertion plumbing additions don't break
    the assertion.
  - GEMM fusion no longer emits torch.narrow ops for asymmetric Q/KV
    children since NVIDIA#13091 (2026-04-17); it uses a split_output closure
    whose getitem nodes lack meta['val'].  Read sizes from the
    downstream view consumer when narrow ops aren't present.

* test_insert_cached_attention_trtllm_materializes_out_scale_reciprocal:
  out_scale is now pre-folded to a static get_attr (_trtllm_recip_*
  buffer) when input_scale is a static buffer, instead of always being
  emitted as aten.reciprocal.  Accept either form — the static-fold
  path saves a per-step kernel launch and is the optimized default.

Signed-off-by: Eran Geva <19514940+MrGeva@users.noreply.github.com>
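
The relaxed out_scale assertion from the last bullet can be sketched as follows (the helper name is illustrative; the _trtllm_recip_ prefix comes from the commit message above):

```python
import torch


def find_out_scale_nodes(gm: torch.fx.GraphModule) -> list:
    """Match either form the transform may emit for out_scale: the
    statically folded _trtllm_recip_* buffer or a dynamic reciprocal op."""
    return [
        n
        for n in gm.graph.nodes
        if (n.op == "get_attr" and str(n.target).startswith("_trtllm_recip_"))
        or n.target is torch.ops.aten.reciprocal.default
    ]
```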
MrGeva (Collaborator, Author) commented May 5, 2026

/bot run --extra-stage "DGX_B200-4_GPUs-AutoDeploy-1, DGX_H100-4_GPUs-AutoDeploy-1" --disable-fail-fast

tensorrt-cicd (Collaborator): PR_Github #46772 [ run ] triggered by Bot. Commit: 33d0bff

@MrGeva MrGeva disabled auto-merge May 5, 2026 17:41
tensorrt-cicd (Collaborator): PR_Github #46772 [ run ] completed with state SUCCESS. Commit: 33d0bff
/LLM/main/L0_MergeRequest_PR pipeline #36795 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again


Commit 1297aed added three tests that depend on the TRT-LLM attention
backend into ``test_gemm_fusion.py``:

  * ``test_fuse_qkv_with_trtllm_cache_insertion``
  * ``test_fuse_qkv_gqa_with_trtllm_cache_insertion``
  * ``test_fuse_qkv_passthrough_with_rope`` (+ ``QKVRopeAttentionModel``)

That file is part of the standalone auto_deploy package, so when the
package is built the tests run with ``TRTLLM_AVAILABLE=False`` and:

  - import ``KvCacheConfig`` directly from ``tensorrt_llm.llmapi.llm_args``
    (collection error: ``ModuleNotFoundError: No module named 'tensorrt_llm'``)
  - call ``torch.ops.auto_deploy.trtllm_attention_mha_with_cache``, which
    is only registered when TRT-LLM is present.

Move the three tests + ``QKVRopeAttentionModel`` into
``test_gemm_fusion_trtllm.py``, which is already in ``EXCLUDE_TEST_FILES``
in ``examples/auto_deploy/llmc/create_standalone_package.py`` for exactly
this reason.  ``QKVAttentionModel`` and ``_get_narrow_nodes`` are imported
from ``test_gemm_fusion`` to avoid duplication.

Drop the unused ``KvCacheConfig`` and ``CachedSequenceInterface`` imports
from ``test_gemm_fusion.py``.

Verified locally:
  - standalone package suite: 5/5 pass (1372 inner tests, 0 errors)
  - in-repo run: ``test_gemm_fusion.py`` + ``test_gemm_fusion_trtllm.py``
    -> 30 passed, 2 skipped, 2 xfailed, 2 xpassed

Signed-off-by: Eran Geva <19514940+MrGeva@users.noreply.github.com>
MrGeva (Collaborator, Author) commented May 6, 2026

/bot run --extra-stage "DGX_B200-4_GPUs-AutoDeploy-1, DGX_H100-4_GPUs-AutoDeploy-1" --disable-fail-fast

tensorrt-cicd (Collaborator): PR_Github #46963 [ run ] triggered by Bot. Commit: 826f4eb

tensorrt-cicd (Collaborator): PR_Github #46963 [ run ] completed with state SUCCESS. Commit: 826f4eb
/LLM/main/L0_MergeRequest_PR pipeline #36953 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again


MrGeva (Collaborator, Author) commented May 7, 2026

/bot run --extra-stage "DGX_B200-4_GPUs-AutoDeploy-1, DGX_H100-4_GPUs-AutoDeploy-1" --disable-fail-fast

tensorrt-cicd (Collaborator): PR_Github #47107 [ run ] triggered by Bot. Commit: 826f4eb

tensorrt-cicd (Collaborator): PR_Github #47107 [ run ] completed with state SUCCESS. Commit: 826f4eb
/LLM/main/L0_MergeRequest_PR pipeline #37076 completed with status: 'SUCCESS'


@MrGeva MrGeva merged commit f45f524 into NVIDIA:main May 7, 2026
6 checks passed