[#4674][feat] optimize llama8B decode: trtllm silu_mul backend, quant+silu_mul, QKV passthrough to attention #12507
Conversation
7634c33 to ccfc45b
/bot run --extra-stage "DGX_B200-4_GPUs-AutoDeploy-1, DGX_H100-4_GPUs-AutoDeploy-1"
PR_Github #43527 [ run ] triggered by Bot. Commit:
PR_Github #43527 [ run ] completed with state
/bot run --extra-stage "DGX_B200-4_GPUs-AutoDeploy-1, DGX_H100-4_GPUs-AutoDeploy-1"
PR_Github #43568 [ run ] triggered by Bot. Commit:
PR_Github #43568 [ run ] completed with state
0f97953 to b108e41
Signed-off-by: Eran Geva <19514940+MrGeva@users.noreply.github.com>
Switch the trtllm silu+mul custom op from a private int-encoded dtype contract (_DTYPE_TO_INT/_INT_TO_DTYPE) to an Optional[str] like trtllm_quant_fp8_linear, resolved internally via getattr(torch, name). Signed-off-by: Eran Geva <19514940+MrGeva@users.noreply.github.com>
Replace the 4-key dispatch dict with the standard ``str(dtype).removeprefix("torch.")`` idiom that already appears elsewhere in auto_deploy.
Signed-off-by: Eran Geva <19514940+MrGeva@users.noreply.github.com>
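For illustration, a minimal sketch of the dtype-as-string contract described in these commits; the helper names below are placeholders, and only the `getattr(torch, name)` / `str(dtype).removeprefix("torch.")` idiom is taken from the commit messages.

```python
from typing import Optional

import torch


# Sketch only: standalone helpers mirroring the described contract. The real custom op
# in silu_mul.py resolves the dtype internally; these function names are illustrative.
def _dtype_to_str(dtype: torch.dtype) -> str:
    # torch.bfloat16 -> "bfloat16", matching the idiom used elsewhere in auto_deploy.
    return str(dtype).removeprefix("torch.")


def _resolve_dtype(name: Optional[str]) -> Optional[torch.dtype]:
    if name is None:
        return None
    dtype = getattr(torch, name, None)
    assert isinstance(dtype, torch.dtype), f"unknown torch dtype name: {name!r}"
    return dtype


assert _resolve_dtype(_dtype_to_str(torch.bfloat16)) is torch.bfloat16
```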
/bot run --extra-stage "DGX_B200-4_GPUs-AutoDeploy-1, DGX_H100-4_GPUs-AutoDeploy-1" --disable-fail-fast
PR_Github #46647 [ run ] triggered by Bot. Commit:
📝 Walkthrough
This pull request introduces backend-agnostic SiLU+mul fusion with FlashInfer and TRT-LLM implementations, extends RoPE fusion to support fused-QKV rewiring, and enhances the TRT-LLM attention custom op to accept fused-QKV mode hints. Configuration files add the new fuse_silu_mul opt-in.
Changes
Backend-Specific SiLU+Mul Fusion
Fused QKV Pipeline (RoPE Fusion + TRT-LLM Attention Integration)
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~75 minutes
🚥 Pre-merge checks: ✅ 3 | ❌ 2
❌ Failed checks (1 warning, 1 inconclusive)
✅ Passed checks (3 passed)
Actionable comments posted: 4
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tests/unittest/auto_deploy/singlegpu/transformations/library/test_fuse_silu_mul.py (1)
85-111: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win: Add coverage for the new TRT-LLM fusion path.
These tests still only assert the FlashInfer default path, so the new `backend="trtllm"` branch and `_try_fuse_fp8_quant()` rewrite can regress without failing this suite. Please add at least one case that fuses to `torch.ops.auto_deploy.trtllm_silu_and_mul.default` and verifies the downstream `trtllm_quant_fp8_linear` rewrite (scale folded into the fused op, `out_dtype` backfilled). QA list updates are unnecessary for this unit-only coverage.
As per coding guidelines: Coverage expectations — note missing parametrization where multiple backends or dtypes apply, and missing failure modes relevant to the feature.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/unittest/auto_deploy/singlegpu/transformations/library/test_fuse_silu_mul.py` around lines 85 - 111, Add a unit test variant that exercises the TRT-LLM fusion path: call the existing _build_narrow_silu_mul_graph() (or _build_getitem_silu_mul_graph()) and invoke _run_fuse_silu_mul(gm, backend="trtllm") to trigger trtllm fusion; assert info.num_matches == 1 and that the graph contains torch.ops.auto_deploy.trtllm_silu_and_mul.default exactly once and no aten.silu/aten.mul ops, then locate the downstream quant rewrite trtllm_quant_fp8_linear in the transformed graph and assert the scale has been folded into the fused op and out_dtype has been backfilled (verify scale removed from separate node and out_dtype attribute present on the fused op); add a complementary negative test using _run_fuse_silu_mul(..., enabled=False) or a non-trtllm backend to ensure the trtllm fusion is skipped.
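For illustration, a hedged sketch of such a test, assuming it lives in the same module where `_build_narrow_silu_mul_graph` and `_run_fuse_silu_mul` are defined; their exact signatures and return values are taken from the prompt above, not verified against the suite.

```python
import torch


def test_fuse_silu_mul_trtllm_backend():
    # Sketch only: relies on the module-level helpers named in the review prompt.
    gm = _build_narrow_silu_mul_graph()
    gm, info = _run_fuse_silu_mul(gm, backend="trtllm")

    assert info.num_matches == 1
    fused = [
        n
        for n in gm.graph.nodes
        if n.op == "call_function"
        and n.target == torch.ops.auto_deploy.trtllm_silu_and_mul.default
    ]
    assert len(fused) == 1, "expected exactly one fused trtllm silu_and_mul node"
    # No eager silu/mul pair should survive the rewrite.
    assert not any(
        n.op == "call_function"
        and n.target in (torch.ops.aten.silu.default, torch.ops.aten.mul.Tensor)
        for n in gm.graph.nodes
    )
```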
🧹 Nitpick comments (1)
examples/auto_deploy/model_registry/configs/llama3_1_8b.yaml (1)
22-24: Please confirm this opt-in is covered by perf CI/QA.
This YAML flips Llama 3.1 8B FP8 decode onto the new backend, but I don't see matching perf-list updates in the provided diff. Please verify that existing entries in `tests/integration/test_lists/test-db/l0_*.yml` and `tests/integration/test_lists/qa/llm_perf_*.yml` already exercise this exact AutoDeploy path; otherwise this optimization won't be regression-tracked.
As per coding guidelines: For performance-sensitive paths, check whether a perf test entry is present or updated in `tests/integration/test_lists/test-db/l0_*.yml` and `tests/integration/test_lists/qa/llm_perf_*.yml`.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@examples/auto_deploy/model_registry/configs/llama3_1_8b.yaml` around lines 22 - 24, You enabled the fuse_silu_mul opt-in for the Llama 3.1 8B config (fuse_silu_mul: enabled, backend: trtllm) but did not add or verify corresponding perf regression coverage; update the CI/QA perf test lists to cover this new AutoDeploy path by adding or confirming entries in the L0 test list (l0_*.yml) and the perf QA list (llm_perf_*.yml) that reference the llama3_1_8b configuration and the trtllm/fuse_silu_mul combination, or document why existing entries already cover it; ensure the test names or selectors explicitly exercise FP8 decode on the trtllm backend so the change is tracked by perf CI.
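For reference, the opt-in under discussion amounts to a small YAML snippet along these lines; the surrounding nesting inside llama3_1_8b.yaml is an assumption, only the `fuse_silu_mul` / `backend: trtllm` fields come from the comment above.

```yaml
# Sketch of the config opt-in (lines 22-24 of llama3_1_8b.yaml); the top-level
# "transforms" key is an assumption about the file layout, not copied from it.
transforms:
  fuse_silu_mul:
    enabled: true
    backend: trtllm
```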
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@tensorrt_llm/_torch/auto_deploy/custom_ops/attention/trtllm_attention.py`:
- Around line 844-855: When handling fused QKV mode in source_attn_node (check
meta "_trtllm_fused_qkv"), do not default kv_dtype to torch.bfloat16 when
source_attn_node.args[0].meta.get("val") is missing; instead detect q_meta ==
None and raise a clear exception (or assert) indicating missing fused-QKV meta
so tracing/propagation bugs fail fast; keep the existing logic to infer
num_kv_heads and head_dim from "_trtllm_num_kv_heads" and "_trtllm_head_dim" and
only set kv_dtype from q_meta.dtype when q_meta is present (refer to
source_attn_node.args[0], _trtllm_fused_qkv, kv_dtype).
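A minimal sketch of the suggested fail-fast behavior, written as a standalone helper; the real logic lives inside `get_cache_initializers` and may be shaped differently.

```python
import torch
from torch.fx import Node


def _infer_kv_dtype_from_fused_qkv(source_attn_node: Node) -> torch.dtype:
    # Sketch only: read the fused-QKV input's fake-tensor meta and fail fast instead
    # of silently defaulting to bfloat16, so propagation bugs surface during tracing.
    q_meta = source_attn_node.args[0].meta.get("val")
    if q_meta is None:
        raise RuntimeError(
            "fused-QKV attention input is missing meta['val']; cannot infer kv cache dtype"
        )
    return q_meta.dtype
```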
In
`@tensorrt_llm/_torch/auto_deploy/transform/library/fuse_rope_into_trtllm_attention.py`:
- Around line 361-373: _trace_split currently only unwraps nodes whose
call_function target name is "contiguous", so when the exporter emits the aten
overload (torch.ops.aten.contiguous.default) the function fails to recognize it
and returns None; update _trace_split to also detect and unwrap nodes that are
aten.contiguous (e.g., check is_op(current, torch.ops.aten.contiguous.default)
in addition to the getattr(current.target, "__name__") == "contiguous" check),
ensuring the same unwrap logic used for contiguous in _trace_narrow is applied
before testing view/reshape and returning the traced node.
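A hedged sketch of the unwrap being asked for; the helper name is a placeholder, and the method-form check mirrors the existing `__name__ == "contiguous"` test described above.

```python
import torch
from torch.fx import Node


def _unwrap_contiguous(node: Node) -> Node:
    # Sketch only: skip over both spellings of contiguous the exporter may emit,
    # the method-style target and the aten.contiguous.default overload.
    current = node
    while current.op == "call_function":
        target = current.target
        is_method_form = getattr(target, "__name__", "") == "contiguous"
        is_aten_form = target is torch.ops.aten.contiguous.default
        if not (is_method_form or is_aten_form):
            break
        current = current.args[0]  # follow the tensor input of the contiguous call
    return current
```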
In `@tensorrt_llm/_torch/auto_deploy/transform/library/fuse_silu_mul.py`:
- Around line 333-356: The current narrow-match returns gate_parent and
gate_size without ensuring the parent tensor's last-dimension equals exactly
2*gate_size; update the check in the fuse_silu_mul matcher (around
_get_narrow_info usage and the return) to verify the parent extent covers
exactly the fused projection by asserting the parent tensor's relevant dimension
== gate_size * 2 (for split-based matches also ensure split sizes are exactly
[gate_size, gate_size]); if this condition fails, return None so you only
rewrite when the parent extent is exactly 2*d.
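For illustration, a sketch of the exact-coverage guard; the function and variable names are placeholders for the matcher described above, and the `meta["val"]` access follows the comment.

```python
from typing import Optional, Tuple

from torch.fx import Node


def _check_covers_full_parent(gate_parent: Node, gate_size: int) -> Optional[Tuple[Node, int]]:
    # Sketch only: fuse exclusively when the two halves cover the parent's full last
    # dimension; the kernel always splits at half, so a trailing slice would be
    # silently misinterpreted.
    parent_val = gate_parent.meta.get("val")
    if parent_val is None or parent_val.shape[-1] != 2 * gate_size:
        return None
    return gate_parent, gate_size
```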
In
`@tests/unittest/auto_deploy/singlegpu/transformations/library/test_gemm_fusion.py`:
- Around line 842-897: The test never verifies the fused-QKV passthrough because
QKVAttentionModel uses torch_attention (no torch_rope_* nodes), so update the
test_gemm_fusion.py case (the function
test_fuse_qkv_with_trtllm_cache_insertion) to exercise a rope-backed path or to
assert fused-QKV hints: ensure the graph includes rope nodes (so
fuse_rope_into_trtllm_attention runs) or after running
InferenceOptimizer(insert_cached_attention) locate the inserted
trtllm_attention_mha_with_cache node
(torch.ops.auto_deploy.trtllm_attention_mha_with_cache.default) and assert its
metadata includes a non-zero _trtllm_fused_qkv / head-hints value; reference
functions/classes: QKVAttentionModel, fuse_rope_into_trtllm_attention,
_trtllm_fused_qkv, and insert_cached_attention when adding the check.
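A hedged sketch of the assertion the review asks for, assuming the hints are stored on the attention node's `meta` dict as described in the comment and in the commit message further below.

```python
import torch
from torch.fx import GraphModule


def _assert_fused_qkv_hints(gm: GraphModule) -> None:
    # Sketch only: locate the cached trtllm attention node and check the fused-QKV
    # passthrough metadata; key names are taken from the review comment.
    attn_nodes = [
        n
        for n in gm.graph.nodes
        if n.op == "call_function"
        and n.target == torch.ops.auto_deploy.trtllm_attention_mha_with_cache.default
    ]
    assert len(attn_nodes) == 1
    meta = attn_nodes[0].meta
    assert meta.get("_trtllm_fused_qkv"), "fused-QKV passthrough hint was not set"
    assert meta.get("_trtllm_num_kv_heads", 0) > 0
    assert meta.get("_trtllm_head_dim", 0) > 0
```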
---
Outside diff comments:
In
`@tests/unittest/auto_deploy/singlegpu/transformations/library/test_fuse_silu_mul.py`:
- Around line 85-111: Add a unit test variant that exercises the TRT-LLM fusion
path: call the existing _build_narrow_silu_mul_graph() (or
_build_getitem_silu_mul_graph()) and invoke _run_fuse_silu_mul(gm,
backend="trtllm") to trigger trtllm fusion; assert info.num_matches == 1 and
that the graph contains torch.ops.auto_deploy.trtllm_silu_and_mul.default
exactly once and no aten.silu/aten.mul ops, then locate the downstream quant
rewrite trtllm_quant_fp8_linear in the transformed graph and assert the scale
has been folded into the fused op and out_dtype has been backfilled (verify
scale removed from separate node and out_dtype attribute present on the fused
op); add a complementary negative test using _run_fuse_silu_mul(...,
enabled=False) or a non-trtllm backend to ensure the trtllm fusion is skipped.
---
Nitpick comments:
In `@examples/auto_deploy/model_registry/configs/llama3_1_8b.yaml`:
- Around line 22-24: You enabled the fuse_silu_mul opt-in for the Llama 3.1 8B
config (fuse_silu_mul: enabled, backend: trtllm) but did not add or verify
corresponding perf regression coverage; update the CI/QA perf test lists to
cover this new AutoDeploy path by adding or confirming entries in the L0 test
list (l0_*.yml) and the perf QA list (llm_perf_*.yml) that reference the
llama3_1_8b configuration and the trtllm/fuse_silu_mul combination, or document
why existing entries already cover it; ensure the test names or selectors
explicitly exercise FP8 decode on the trtllm backend so the change is tracked by
perf CI.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Enterprise
Run ID: b5e71f70-8702-4831-be1e-9a8c560f06ca
📒 Files selected for processing (8)
- examples/auto_deploy/model_registry/configs/llama3_1_8b.yaml
- tensorrt_llm/_torch/auto_deploy/config/default.yaml
- tensorrt_llm/_torch/auto_deploy/custom_ops/attention/trtllm_attention.py
- tensorrt_llm/_torch/auto_deploy/custom_ops/linear/silu_mul.py
- tensorrt_llm/_torch/auto_deploy/transform/library/fuse_rope_into_trtllm_attention.py
- tensorrt_llm/_torch/auto_deploy/transform/library/fuse_silu_mul.py
- tests/unittest/auto_deploy/singlegpu/transformations/library/test_fuse_silu_mul.py
- tests/unittest/auto_deploy/singlegpu/transformations/library/test_gemm_fusion.py
The helper is required when piecewise CUDA graphs are enabled: the dynamic-submodule split causes the cached-attention call_function to already carry ``out=None`` positionally (because get_constants returns args after the schema's ``out`` slot), so adding ``out`` as a kwarg on top would over-supply the argument list. Removing the helper silently broke PWCG with:
RuntimeError: auto_deploy::trtllm_attention_mha_with_cache() expected at most 20 argument(s) but received 21 argument(s).
Signed-off-by: Eran Geva <19514940+MrGeva@users.noreply.github.com>
Putting ``out`` last (after the rope and fused-QKV-passthrough hints) matches the FlashInfer attention op layout: ``get_constants`` no longer needs to fill ``out=None`` positionally, and ``_inject_out_param`` can just add it as a kwarg without colliding. Drops ``_find_positional_out_arg`` that previously worked around the conflict for the PWCG path. Signed-off-by: Eran Geva <19514940+MrGeva@users.noreply.github.com>
Storing the dtype on attn_node.meta when the rope transform fires means we still have a trustworthy source: at that point V (and the pre-RoPE K/Q) are still typed correctly. Reading from the fused-QKV node at cache_init time was unreliable because subsequent FP8 GEMM rewrites can drop meta['val'] on the node we point to. Drop the silent ``bfloat16`` default in TrtllmAttention.get_cache_initializers that masked this propagation issue; if neither the pre-captured dtype nor the fused-QKV node's meta is available, raise a clear assertion.
Signed-off-by: Eran Geva <19514940+MrGeva@users.noreply.github.com>
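For illustration, a minimal sketch of capturing the dtype while the rope transform fires; the meta key name is a placeholder, not necessarily what the transform actually writes.

```python
from torch.fx import Node


def _capture_kv_dtype(attn_node: Node, v_node: Node) -> None:
    # Sketch only: at rope-fusion time V (and pre-RoPE K/Q) still carry a trustworthy
    # meta['val']; later FP8 GEMM rewrites can drop it, so stash the dtype on the
    # attention node now. "_trtllm_kv_dtype" is a hypothetical key name.
    v_meta = v_node.meta.get("val")
    if v_meta is not None:
        attn_node.meta["_trtllm_kv_dtype"] = v_meta.dtype
```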
1. fuse_rope_into_trtllm_attention: _trace_split was only unwrapping nodes whose target.__name__ == 'contiguous' (the Tensor.contiguous method form), missing aten.contiguous.default emitted by some exporter paths. Factor the unwrap into a shared _unwrap_contiguous helper used by both _trace_split and _trace_narrow.
2. fuse_silu_mul: the narrow/getitem matcher was returning (gate_parent, gate_size) without checking that the parent's last dim is exactly 2*gate_size. Add the check (reading parent.meta.val) so we only fuse when the two narrows together cover the parent's full activation axis — the kernel always splits at half, so any trailing slice would be silently misinterpreted.
3. tests: add a dedicated test_fuse_qkv_passthrough_with_rope that actually exercises fuse_rope_into_trtllm_attention with a model carrying torch_rope_with_explicit_cos_sin and asserts the _trtllm_fused_qkv / _trtllm_num_heads / _trtllm_num_kv_heads / _trtllm_head_dim metadata is set on the attention node. Existing tests only ran fuse_gemms_mixed_children + cache insertion and never verified the passthrough fired.
Signed-off-by: Eran Geva <19514940+MrGeva@users.noreply.github.com>
/bot run --extra-stage "DGX_B200-4_GPUs-AutoDeploy-1, DGX_H100-4_GPUs-AutoDeploy-1" --disable-fail-fast
PR_Github #46682 [ run ] triggered by Bot. Commit:
PR_Github #46682 [ run ] completed with state
Three tests were either broken at authoring time or stopped matching the current codebase shape; align them with current main:
* test_fuse_qkv_with_trtllm_cache_insertion / _gqa_:
  - CachedSequenceInterface required max_num_tokens since NVIDIA#12708 (2026-04-13); tests were authored without it. Pass max_num_tokens=256.
  - The post-cache-insertion graph now contains 4 split-output getitems (3 from QKV split + 1 metadata-prep tuple). Relax the strict ==3 check to >=3 so cache-insertion plumbing additions don't break the assertion.
  - GEMM fusion no longer emits torch.narrow ops for asymmetric Q/KV children since NVIDIA#13091 (2026-04-17); it uses a split_output closure whose getitem nodes lack meta['val']. Read sizes from the downstream view consumer when narrow ops aren't present.
* test_insert_cached_attention_trtllm_materializes_out_scale_reciprocal: out_scale is now pre-folded to a static get_attr (_trtllm_recip_* buffer) when input_scale is a static buffer, instead of always being emitted as aten.reciprocal. Accept either form — the static-fold path saves a per-step kernel launch and is the optimized default.
Signed-off-by: Eran Geva <19514940+MrGeva@users.noreply.github.com>
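For illustration, a sketch of the "accept either form" check mentioned above; the ``_trtllm_recip_`` prefix comes from the commit message, while the helper name and exact node matching are assumptions.

```python
import torch
from torch.fx import GraphModule


def _assert_out_scale_reciprocal(gm: GraphModule) -> None:
    # Sketch only: out_scale may be pre-folded into a static get_attr buffer
    # (named "_trtllm_recip_*") or emitted as a runtime aten.reciprocal node.
    has_static_fold = any(
        n.op == "get_attr" and str(n.target).startswith("_trtllm_recip_")
        for n in gm.graph.nodes
    )
    has_runtime_recip = any(
        n.op == "call_function" and n.target == torch.ops.aten.reciprocal.default
        for n in gm.graph.nodes
    )
    assert has_static_fold or has_runtime_recip, "no reciprocal out_scale found in graph"
```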
/bot run --extra-stage "DGX_B200-4_GPUs-AutoDeploy-1, DGX_H100-4_GPUs-AutoDeploy-1" --disable-fail-fast
PR_Github #46772 [ run ] triggered by Bot. Commit:
PR_Github #46772 [ run ] completed with state
Commit 1297aed added three tests that depend on the TRT-LLM attention backend into ``test_gemm_fusion.py``:
* ``test_fuse_qkv_with_trtllm_cache_insertion``
* ``test_fuse_qkv_gqa_with_trtllm_cache_insertion``
* ``test_fuse_qkv_passthrough_with_rope`` (+ ``QKVRopeAttentionModel``)
That file is part of the standalone auto_deploy package, so when the package is built the tests run with ``TRTLLM_AVAILABLE=False`` and:
- import ``KvCacheConfig`` directly from ``tensorrt_llm.llmapi.llm_args`` (collection error: ``ModuleNotFoundError: No module named 'tensorrt_llm'``)
- call ``torch.ops.auto_deploy.trtllm_attention_mha_with_cache``, which is only registered when TRT-LLM is present.
Move the three tests + ``QKVRopeAttentionModel`` into ``test_gemm_fusion_trtllm.py``, which is already in ``EXCLUDE_TEST_FILES`` in ``examples/auto_deploy/llmc/create_standalone_package.py`` for exactly this reason. ``QKVAttentionModel`` and ``_get_narrow_nodes`` are imported from ``test_gemm_fusion`` to avoid duplication. Drop the unused ``KvCacheConfig`` and ``CachedSequenceInterface`` imports from ``test_gemm_fusion.py``.
Verified locally:
- standalone package suite: 5/5 pass (1372 inner tests, 0 errors)
- in-repo run: ``test_gemm_fusion.py`` + ``test_gemm_fusion_trtllm.py`` -> 30 passed, 2 skipped, 2 xfailed, 2 xpassed
Signed-off-by: Eran Geva <19514940+MrGeva@users.noreply.github.com>
/bot run --extra-stage "DGX_B200-4_GPUs-AutoDeploy-1, DGX_H100-4_GPUs-AutoDeploy-1" --disable-fail-fast
PR_Github #46963 [ run ] triggered by Bot. Commit:
PR_Github #46963 [ run ] completed with state
/bot run --extra-stage "DGX_B200-4_GPUs-AutoDeploy-1, DGX_H100-4_GPUs-AutoDeploy-1" --disable-fail-fast
PR_Github #47107 [ run ] triggered by Bot. Commit:
PR_Github #47107 [ run ] completed with state
Three optimizations to close the AD-vs-PT decode gap on Llama 3.1 8B FP8: the trtllm silu_mul backend, fused quant+silu_mul, and QKV passthrough to attention.
Summary by CodeRabbit
Refactor
Tests
Description
Test Coverage
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
To see a list of available CI bot commands, please comment
/bot help.