[#11432][feat] AutoDeploy: Enable fp8 quantization fusion part 1 #11910

galagam merged 5 commits into NVIDIA:main from
Conversation
📝 Walkthrough

This pull request introduces FP8 quantization support for RMSNorm fusion in TensorRT-LLM's auto-deploy system. Changes include new Triton-backed kernels for RMSNorm FP8 quantization, graph-transformation logic to fuse normalization with quantization, FP8-aware graph utilities, and configuration updates enabling the new transform in model deployments.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Graph as Torch FX Graph
    participant Transform as FuseRMSNormQuantFP8
    participant RMSNorm as RMSNorm Node
    participant Scale as Scale Node
    participant FP8Linear as FP8 Linear Nodes
    participant TritonOp as Triton Fused Op
    participant PreQuantLinear as PreQuant Linear Op
    Graph->>Transform: Apply transform on graph
    Transform->>Graph: Scan for RMSNorm patterns
    Graph-->>Transform: Found RMSNorm + FP8 consumers
    Transform->>RMSNorm: Identify RMSNorm source
    Transform->>FP8Linear: Find terminal FP8 linear consumers
    FP8Linear-->>Transform: Return consumer nodes & scale
    Transform->>Scale: Extract earliest FP8 consumer scale
    Transform->>TritonOp: Create fused triton_rms_norm_quant_fp8
    TritonOp->>Graph: Insert into graph (outputs: BF16, FP8)
    Transform->>FP8Linear: Rewrite to use prequant path
    FP8Linear->>PreQuantLinear: Replace with trtllm_fp8_prequant_linear
    PreQuantLinear->>Graph: Wire scale & dtype info
    Transform->>Graph: Erase obsolete norm nodes
    Transform->>Graph: Return transformation info
```
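The math behind the fused op in the diagram above can be sketched in plain Python. This is a simplified numerical reference, not the Triton kernel: the per-tensor scaling scheme and the E4M3 clipping range of 448.0 are assumptions, and FP8 rounding is omitted.

```python
import math

FP8_E4M3_MAX = 448.0  # assumed clipping range for FP8 E4M3


def rms_norm(x, weight, eps=1e-6):
    """Reference RMSNorm: x / rms(x) * weight."""
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms * w for v, w in zip(x, weight)]


def quant_fp8(y, scale):
    """Per-tensor quantization: divide by scale, clamp to the FP8 range.
    Real FP8 also rounds to the nearest representable value; omitted here."""
    return [max(-FP8_E4M3_MAX, min(FP8_E4M3_MAX, v / scale)) for v in y]


def rms_norm_quant_fp8(x, weight, scale, eps=1e-6):
    """Fused op returns both the high-precision and the quantized output,
    mirroring the (BF16, FP8) output pair in the diagram above."""
    y = rms_norm(x, weight, eps)
    return y, quant_fp8(y, scale)


y, q = rms_norm_quant_fp8([1.0, 2.0, 3.0], [1.0, 1.0, 1.0], scale=0.5)
```

With unit weights the normalized output has (approximately) unit RMS, and the quantized output is simply the normalized output divided by the scale, clamped.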
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes

🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (warning)
Actionable comments posted: 6
🧹 Nitpick comments (4)
tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quant_fusion.py (1)
10-13: Use module imports instead of direct class imports.

Please switch to module-level imports and qualify TransformConfig/FuseRMSNormQuantFP8 at use sites to match the repo import policy.

♻️ Proposed refactor

```diff
-from tensorrt_llm._torch.auto_deploy.transform.interface import TransformConfig
-from tensorrt_llm._torch.auto_deploy.transform.library.fuse_rmsnorm_quant_fp8 import (
-    FuseRMSNormQuantFP8,
-)
+from tensorrt_llm._torch.auto_deploy.transform import interface as transform_interface
+from tensorrt_llm._torch.auto_deploy.transform.library import (
+    fuse_rmsnorm_quant_fp8 as fuse_rmsnorm_quant_fp8_lib,
+)
@@
-    transform = FuseRMSNormQuantFP8(TransformConfig(stage="post_load_fusion"))
+    transform = fuse_rmsnorm_quant_fp8_lib.FuseRMSNormQuantFP8(
+        transform_interface.TransformConfig(stage="post_load_fusion")
+    )
```

As per coding guidelines: Python imports must use the form `from package.subpackage import module` (never `from module import Class`).

Also applies to: 224-224
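The convention the reviewer cites can be illustrated with a stdlib stand-in (`collections` here substitutes for the actual TensorRT-LLM modules):

```python
# Discouraged (direct symbol import): the class's home module is
# invisible at use sites.
from collections import OrderedDict

d1 = OrderedDict(a=1)

# Preferred (module import): qualified use sites keep the namespace
# readable and let tests monkeypatch `collections.OrderedDict` in one place.
import collections

d2 = collections.OrderedDict(a=1)

assert type(d1) is collections.OrderedDict
```

Both forms produce the same object; the difference is purely in how readable and patchable the use sites are.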
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quant_fusion.py` around lines 10-13: Replace direct class imports with module-level imports: import the interface module (e.g., from tensorrt_llm._torch.auto_deploy.transform import interface) and the fuse_rmsnorm_quant_fp8 module (e.g., from tensorrt_llm._torch.auto_deploy.transform.library import fuse_rmsnorm_quant_fp8), then update all uses of TransformConfig and FuseRMSNormQuantFP8 to be qualified (interface.TransformConfig and fuse_rmsnorm_quant_fp8.FuseRMSNormQuantFP8). Make the same change where these classes are imported elsewhere (the other occurrence noted) so all references use the module-qualified names.

tensorrt_llm/_torch/auto_deploy/custom_ops/normalization/triton_fused_add_rms_norm_quant_fp8.py (1)
18-23: Align imports with the repo's module-import convention.

Please avoid direct symbol imports (Tuple, Tensor) and keep namespace-qualified module usage in this file.

As per coding guidelines: Python imports must use the form `from package.subpackage import module` (never `from module import Class`).

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/auto_deploy/custom_ops/normalization/triton_fused_add_rms_norm_quant_fp8.py` around lines 18-23: Replace direct symbol imports: remove "from typing import Tuple" and "from torch import Tensor" and instead use "import typing" and reference typing.Tuple, and use torch.Tensor everywhere (you already have import torch). Also change "import triton.language as tl" to the module-style import "from triton import language as tl" so imports follow the from package.subpackage import module convention; update all occurrences of Tuple and Tensor in this file to typing.Tuple and torch.Tensor respectively.

tensorrt_llm/_torch/auto_deploy/transform/library/fuse_rmsnorm_quant_fp8.py (1)
24-39: Prefer module imports over direct symbol imports in this new transform module.

Lines 24-39 introduce direct function/class imports; switch to module imports to keep namespace visibility and consistency with repo standards.
As per coding guidelines: "When importing in Python, always maintain the namespace. Import the module, not individual classes or functions."
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/auto_deploy/transform/library/fuse_rmsnorm_quant_fp8.py` around lines 24-39: The module currently imports many symbols directly (flashinfer_fused_add_rms_norm, ModelFactory, CachedSequenceInterface, extract_op_args, extract_output_tuple, get_shared_input_scale_for_fp8_linears, is_op, BaseTransform, SharedConfig, TransformConfig, TransformInfo, TransformRegistry); change these to module-level imports (e.g., import the containing modules instead of direct symbols) and update all references in this transform (fuse_rmsnorm_quant_fp8.py) to use the module-qualified names so namespace visibility and repo import conventions are preserved.

tensorrt_llm/_torch/auto_deploy/custom_ops/attention/trtllm_attention.py (1)
43-48: Use a module-level node_utils import to preserve namespace consistency.

The direct function imports at lines 43-48 make this file diverge from the repository import rule; prefer node_utils.<helper>() call sites.

As per coding guidelines: "When importing in Python, always maintain the namespace. Import the module, not individual classes or functions."
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tensorrt_llm/_torch/auto_deploy/custom_ops/attention/trtllm_attention.py` around lines 43 - 48, The file currently imports helper functions directly (collect_terminal_users_through_passthrough, extract_op_args, get_shared_input_scale_for_fp8_linears, set_op_args) from ...utils.node_utils; change this to import the module (e.g., import ...utils.node_utils as node_utils) and update all call sites in trtllm_attention.py to use the module namespace (node_utils.collect_terminal_users_through_passthrough, node_utils.extract_op_args, node_utils.get_shared_input_scale_for_fp8_linears, node_utils.set_op_args) to preserve the repository's namespace import convention.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@examples/auto_deploy/model_registry/configs/llama3_1_8b.yaml`:
- Around line 11-14: The config currently enables only the transform
fuse_rmsnorm_quant_fp8 but omits fuse_fp8_gemms; update the transforms block to
add a fuse_fp8_gemms entry (similar to fuse_rmsnorm_quant_fp8) with stage set to
post_load_fusion and enabled: true so both transforms (fuse_fp8_gemms and
fuse_rmsnorm_quant_fp8) are enabled for the Llama 3.1 8B model.
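The suggested config change could look like the following sketch (exact keys follow the stage/enabled fields named in the comment; surrounding structure of the YAML is assumed):

```yaml
transforms:
  fuse_fp8_gemms:
    stage: post_load_fusion
    enabled: true
  fuse_rmsnorm_quant_fp8:
    stage: post_load_fusion
    enabled: true
```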
In `@tensorrt_llm/_torch/auto_deploy/custom_ops/attention/trtllm_attention.py`:
- Around line 643-655: The code can enable FP8 output quantization without any
deterministic out_dtype when source_attn_node.meta["val"] has no dtype; fix by
assigning a deterministic fallback out_dtype_str before applying FP8 to
downstream linears: after retrieving val and attempting to set out_dtype_str
from val.dtype in the block around source_attn_node and val, add a fallback
assignment (e.g., out_dtype_str =
str(torch.get_default_dtype()).replace("torch.", "") or a literal like
"float32") when out_dtype_str is still None, then continue to call
set_op_args(user, out_dtype=out_dtype_str) for each user in fp8_users; keep the
existing handling of first_scale/out_scale unchanged but ensure out_dtype_str is
always set before enabling FP8 output quantization for fp8_users and before
using out_scale.
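The fallback logic described above can be sketched without the FX machinery. Names like `resolve_out_dtype` and `meta_val_dtype` are illustrative, and `"float32"` mirrors the literal suggested in the comment:

```python
def resolve_out_dtype(meta_val_dtype, default="float32"):
    """Pick a deterministic out_dtype string: use the traced dtype when the
    node's meta["val"] carries one, otherwise fall back to a fixed default so
    FP8 output quantization is never enabled with out_dtype unset."""
    if meta_val_dtype is not None:
        # torch dtypes stringify as "torch.bfloat16"; strip the prefix.
        return str(meta_val_dtype).replace("torch.", "")
    return default


assert resolve_out_dtype("torch.bfloat16") == "bfloat16"
assert resolve_out_dtype(None) == "float32"
```

The point of the review is the second branch: without it, a node whose meta lacks a dtype would propagate `None` into the FP8 rewrite.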
In
`@tensorrt_llm/_torch/auto_deploy/custom_ops/normalization/triton_fused_add_rms_norm_quant_fp8.py`:
- Around line 115-117: The fake handler functions are intentionally ignoring
some parameters; silence Ruff ARG001 by renaming unused parameters to start with
an underscore. Update the signature of _rms_norm_quant_fp8_fake (and the other
fake registration function around the second occurrence) so any unused args
(e.g., input, weight, eps, scale or similarly named params) are prefixed with an
underscore (for example _input, _weight, _eps, _scale) and keep the body
behavior identical.
In `@tensorrt_llm/_torch/auto_deploy/custom_ops/quantization/quant.py`:
- Around line 235-248: The fake op trtllm_fp8_prequant_linear_fake currently
ignores bias which breaks dtype propagation; update it to mirror the real
prequant path by resolving the base output dtype via
_resolve_out_dtype_or_raise(out_dtype) and, if bias is not None, compute the
promoted/result dtype between that base output dtype and bias.dtype (e.g., using
torch.result_type or torch.promote_types) and use that promoted dtype for the
returned torch.empty tensor so tracing sees the same dtype behavior as the real
operator.
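The promotion rule the prompt asks for can be illustrated with a tiny stand-in table; the real fix would call torch.promote_types / torch.result_type on actual dtypes, and this subset of the promotion lattice is only illustrative:

```python
# Minimal stand-in for torch's type-promotion lattice (illustrative subset).
_RANK = {"float8_e4m3fn": 0, "float16": 1, "bfloat16": 1, "float32": 2}


def promote(a, b):
    """Wider dtype wins, like torch.promote_types for floating types."""
    return a if _RANK[a] >= _RANK[b] else b


def fake_prequant_linear_out_dtype(out_dtype, bias_dtype=None):
    """Fake-op dtype rule: start from the resolved out_dtype; if a bias is
    present, promote with the bias dtype so tracing sees the same dtype
    behavior as the real operator."""
    if bias_dtype is None:
        return out_dtype
    return promote(out_dtype, bias_dtype)


assert fake_prequant_linear_out_dtype("bfloat16", "float32") == "float32"
```

Ignoring the bias (the bug being flagged) would return `bfloat16` here even though the real op produces a `float32` result.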
In `@tensorrt_llm/_torch/auto_deploy/transform/library/fuse_rmsnorm_quant_fp8.py`:
- Around line 42-45: _find_fp8_linear_consumers currently only inspects direct
users of the RMSNorm node and therefore misses FP8 linear consumers reached via
trivial passthrough nodes; update _find_fp8_linear_consumers to perform a small
BFS/DFS from norm_node through allowed passthrough nodes (e.g.,
identity/reshape/alias-like nodes) collecting downstream FP8 linear nodes and
then call or reuse get_shared_input_scale_for_fp8_linears on that expanded
consumer set (refer to function _find_fp8_linear_consumers and helper
get_shared_input_scale_for_fp8_linears to locate where to change the traversal).
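The traversal described above can be sketched with plain objects standing in for torch.fx nodes. The op names in the passthrough set ("identity", "reshape", "alias") are assumptions about which nodes count as trivial:

```python
from collections import deque

PASSTHROUGH_OPS = {"identity", "reshape", "alias"}  # assumed trivial ops


class Node:
    def __init__(self, op, users=None):
        self.op = op
        self.users = users or []


def find_fp8_linear_consumers(norm_node):
    """BFS from the norm node: look through passthrough nodes and collect
    terminal fp8_linear consumers; any other op ends that branch."""
    found, queue, seen = [], deque(norm_node.users), set()
    while queue:
        n = queue.popleft()
        if id(n) in seen:
            continue
        seen.add(id(n))
        if n.op == "fp8_linear":
            found.append(n)
        elif n.op in PASSTHROUGH_OPS:
            queue.extend(n.users)  # keep walking through the trivial node
    return found


lin = Node("fp8_linear")
norm = Node("rms_norm", users=[Node("reshape", users=[lin]), Node("add")])
assert find_fp8_linear_consumers(norm) == [lin]
```

A direct-users-only check would miss `lin` here, since it sits behind the reshape.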
In `@tensorrt_llm/_torch/auto_deploy/utils/node_utils.py`:
- Around line 542-549: The current check in the fp8 fusion loop (iterating
fp8_linear_nodes and extracting scale via extract_op_args into variable scale
and comparing scale.target to first_scale.target) can falsely treat distinct
non-get_attr Nodes as shared; change the equivalence to require either the exact
same Node object (scale is first_scale) or both scale Nodes to be get_attr ops
that reference the same attribute target/name before treating them as shared; if
neither condition holds, return the fallback ([], None). Ensure you update the
comparison logic around first_scale and scale (and use Node/op identity or op
name "get_attr" plus matching target) in that loop to prevent incorrect FP8
fusion.
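The stricter equivalence check can be sketched the same way (a simplified model of fx nodes; matching `get_attr` nodes by target mirrors the prompt's suggestion):

```python
class ScaleNode:
    def __init__(self, op, target):
        self.op = op
        self.target = target


def scales_are_shared(first, other):
    """Scales are shared only if they are the same Node object, or both are
    get_attr nodes referencing the same attribute target. Comparing .target
    alone would falsely match distinct call nodes with equal targets."""
    if first is other:
        return True
    return first.op == "get_attr" == other.op and first.target == other.target


a = ScaleNode("get_attr", "layers.0.input_scale")
b = ScaleNode("get_attr", "layers.0.input_scale")
c = ScaleNode("call_function", "mul")
d = ScaleNode("call_function", "mul")

assert scales_are_shared(a, b)      # same attribute -> shared
assert not scales_are_shared(c, d)  # distinct call nodes, not shared
```

When the check fails, the transform would fall back to `([], None)` and skip the fusion, as the prompt describes.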
---
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 509558b9-7928-4df0-b183-39e997ee609b
📒 Files selected for processing (11)
- examples/auto_deploy/model_registry/configs/dashboard_default.yaml
- examples/auto_deploy/model_registry/configs/llama3_1_8b.yaml
- examples/auto_deploy/model_registry/models.yaml
- tensorrt_llm/_torch/auto_deploy/config/default.yaml
- tensorrt_llm/_torch/auto_deploy/custom_ops/attention/trtllm_attention.py
- tensorrt_llm/_torch/auto_deploy/custom_ops/normalization/__init__.py
- tensorrt_llm/_torch/auto_deploy/custom_ops/normalization/triton_fused_add_rms_norm_quant_fp8.py
- tensorrt_llm/_torch/auto_deploy/custom_ops/quantization/quant.py
- tensorrt_llm/_torch/auto_deploy/transform/library/fuse_rmsnorm_quant_fp8.py
- tensorrt_llm/_torch/auto_deploy/utils/node_utils.py
- tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quant_fusion.py
Force-pushed a4d6517 to 6ba11cb
/bot run --stage-list "DGX_B200-4_GPUs-AutoDeploy-1, DGX_H100-4_GPUs-AutoDeploy-1"

PR_Github #37837 [ run ] triggered by Bot. Commit:

PR_Github #37837 [ run ] completed with state

/bot run --stage-list "DGX_B200-4_GPUs-AutoDeploy-1, DGX_H100-4_GPUs-AutoDeploy-1"

PR_Github #37851 [ run ] triggered by Bot. Commit:

PR_Github #37851 [ run ] completed with state
Force-pushed dc90e25 to f37b5eb
/bot run --stage-list "DGX_B200-4_GPUs-AutoDeploy-1, DGX_H100-4_GPUs-AutoDeploy-1"

PR_Github #39084 [ run ] completed with state

/bot run

PR_Github #39099 [ run ] triggered by Bot. Commit:

PR_Github #39099 [ run ] completed with state

/bot run

PR_Github #39114 [ run ] triggered by Bot. Commit:

PR_Github #39114 [ run ] completed with state

/bot run

PR_Github #39193 [ run ] triggered by Bot. Commit:

PR_Github #39193 [ run ] completed with state
- Keep attn output in fp8
- Fuse rmsnorm + quant

Signed-off-by: Gal Hubara Agam <96368689+galagam@users.noreply.github.com>
Signed-off-by: Gal Hubara Agam <96368689+galagam@users.noreply.github.com>
Signed-off-by: Gal Hubara Agam <96368689+galagam@users.noreply.github.com>
…/super Signed-off-by: Gal Hubara Agam <96368689+galagam@users.noreply.github.com>
Existing logic does not support fake ops. Enabling fake ops and moving to the pattern-matching phase would require brittle cross-stage metadata transfer.

Signed-off-by: Gal Hubara Agam <96368689+galagam@users.noreply.github.com>
Force-pushed e5c74a4 to 403c21f
/bot run --stage-list "DGX_B200-4_GPUs-AutoDeploy-1, DGX_H100-4_GPUs-AutoDeploy-1"

PR_Github #39274 [ run ] triggered by Bot. Commit:

/bot run --stage-list "DGX_B200-4_GPUs-AutoDeploy-1, DGX_H100-4_GPUs-AutoDeploy-1, DGX_H100-4_GPUs-AutoDeploy-Post-Merge-1"

PR_Github #39278 [ run ] triggered by Bot. Commit:

PR_Github #39278 [ run ] completed with state

/bot run

PR_Github #39297 [ run ] triggered by Bot. Commit:

PR_Github #39297 [ run ] completed with state
(NVIDIA#11910)

Signed-off-by: Gal Hubara Agam <96368689+galagam@users.noreply.github.com>
Summary by CodeRabbit
New Features
Chores
Description
Future work (to be addressed in a follow-up PR):
Test Coverage
-k fuse_rmsnorm_quant_fp8_rewrites_graph
PR Checklist
Please review the following before submitting your PR:
PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.
PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.
Test cases are provided for new code paths (see test instructions)
Any new dependencies have been scanned for license and vulnerabilities
CODEOWNERS updated if ownership changes
Documentation updated as needed
Update tava architecture diagram if there is a significant design change in PR.
The reviewers assigned automatically/manually are appropriate for the PR.
Please check this after reviewing the above items as appropriate for this PR.
GitHub Bot Help
To see a list of available CI bot commands, please comment
/bot help.