
[#11432][feat] AutoDeploy: Enable fp8 quantization fusion part 1#11910

Merged
galagam merged 5 commits into NVIDIA:main from nv-auto-deploy:gagam/fp8-quant-fuse on Mar 17, 2026

Conversation

@galagam
Collaborator

@galagam galagam commented Mar 4, 2026

Summary by CodeRabbit

  • New Features

    • Added FP8 quantization fusion for RMSNorm operations to improve model efficiency.
    • Enabled support for Llama-3.1-8B-Instruct-FP8 model with optimized FP8 configurations.
    • Enhanced FP8 output quantization throughout attention and linear operations.
  • Chores

    • Updated model registry configurations with FP8 optimization settings.

Description

  • Enable fp8 quantization as part of the attention output when the trtllm attention backend is used.
  • Add a new transformation to fuse rmsnorm + fp8 quantization. The fused kernel is implemented in Triton. The transformation is disabled by default.
  • Add a new utility function that detects the final consumer of a node and skips trivial passthrough ops (copy-ops), allowing quantized outputs to be passed directly to consumers even when separated by transpose/reshape/etc.
  • Enable the new transformation in the dashboard default config to assess perf on a variety of models.
  • Add a config for llama3.1 8b to the model registry, enabling both fuse_fp8_gemms and the new fuse_rmsnorm_quant_fp8 transformation.
  • Re-enable testing nvidia/Llama-3.1-8B-Instruct-FP8 in the dashboard.
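The passthrough-skipping utility described above can be sketched roughly as follows. This is an illustrative stand-in for the FX-graph helper in node_utils.py, not the actual TensorRT-LLM API: the toy `Node` class, op names, and function name are assumptions for the sketch.

```python
# Illustrative sketch of skipping trivial passthrough ops (transpose/reshape/
# view/contiguous) to find the terminal consumers of a node's output. The toy
# Node class stands in for torch.fx.Node; op names are hypothetical.
from dataclasses import dataclass, field
from typing import List

TRIVIAL_PASSTHROUGH_OPS = {"view", "reshape", "transpose", "contiguous"}


@dataclass
class Node:
    op: str                              # e.g. "rms_norm", "transpose", "fp8_linear"
    users: List["Node"] = field(default_factory=list)


def collect_terminal_users(node: Node) -> List[Node]:
    """BFS through trivial passthrough nodes, returning the real consumers."""
    terminal, queue = [], list(node.users)
    while queue:
        user = queue.pop(0)
        if user.op in TRIVIAL_PASSTHROUGH_OPS:
            queue.extend(user.users)     # skip through and inspect its users
        else:
            terminal.append(user)
    return terminal


# Example: norm -> transpose -> reshape -> fp8_linear; the linear is still found.
linear = Node("fp8_linear")
reshape = Node("reshape", users=[linear])
transpose = Node("transpose", users=[reshape])
norm = Node("rms_norm", users=[transpose])
print([n.op for n in collect_terminal_users(norm)])  # -> ['fp8_linear']
```

This is what lets the transform hand the attention op's quantized fp8 output straight to a downstream fp8 linear even when a transpose or reshape sits between them.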

Future work (to be addressed in a follow-up PR):

[Performance chart: llama3-1-fp8]

Test Coverage

  • pytest tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quant_fusion.py
    -k fuse_rmsnorm_quant_fp8_rewrites_graph
  • Autodeploy performance dashboard

PR Checklist

Please review the following before submitting your PR:

  • PR description clearly explains what and why. If using CodeRabbit's summary, please make sure it makes sense.

  • PR Follows TRT-LLM CODING GUIDELINES to the best of your knowledge.

  • Test cases are provided for new code paths (see test instructions)

  • Any new dependencies have been scanned for license and vulnerabilities

  • CODEOWNERS updated if ownership changes

  • Documentation updated as needed

  • Update tava architecture diagram if there is a significant design change in PR.

  • The reviewers assigned automatically/manually are appropriate for the PR.

  • Please check this after reviewing the above items as appropriate for this PR.

GitHub Bot Help

To see a list of available CI bot commands, please comment /bot help.

@galagam galagam requested review from a team as code owners March 4, 2026 12:39
@coderabbitai
Contributor

coderabbitai bot commented Mar 4, 2026

📝 Walkthrough

Walkthrough

This pull request introduces FP8 quantization support for RMSNorm fusion in TensorRT-LLM's auto-deploy system. Changes include new Triton-backed kernels for RMSNorm FP8 quantization, graph transformation logic to fuse normalization with quantization, FP8-aware graph utilities, and configuration updates enabling the new transform in model deployments.

Changes

Cohort / File(s) Summary
Configuration and Model Registry
examples/auto_deploy/model_registry/configs/dashboard_default.yaml, examples/auto_deploy/model_registry/configs/llama3_1_8b.yaml, examples/auto_deploy/model_registry/models.yaml
Added fuse_rmsnorm_quant_fp8 transform configuration with post_load_fusion stage. Created new Llama 3.1 8B model config file with FP8 kv_cache, 256 max_batch_size, and 16384 max_seq_len. Enabled nvidia/Llama-3.1-8B-Instruct-FP8 model entry with updated yaml_extra sources.
Default Configuration
tensorrt_llm/_torch/auto_deploy/config/default.yaml
Added shape propagation requirement to fuse_fp8_linear transform and introduced fuse_rmsnorm_quant_fp8 transform as disabled by default under post_load_fusion stage.
RMSNorm FP8 Quantization
tensorrt_llm/_torch/auto_deploy/custom_ops/normalization/triton_fused_add_rms_norm_quant_fp8.py, tensorrt_llm/_torch/auto_deploy/custom_ops/normalization/__init__.py
Introduced two new Triton custom ops: triton_rms_norm_quant_fp8 for RMSNorm with FP8 output, and triton_fused_add_rms_norm_quant_fp8 for fused add+RMSNorm+FP8. Each supports BF16 and FP8 outputs with corresponding fake implementations for testing.
FP8 Quantization Core
tensorrt_llm/_torch/auto_deploy/custom_ops/quantization/quant.py
Extended trtllm_quant_fp8_linear with optional out_dtype parameter. Refactored FP8 linear computation into _trtllm_fp8_prequant_linear_core. Added new public functions trtllm_fp8_prequant_linear and corresponding fake implementations with explicit dtype resolution via _resolve_out_dtype_or_raise.
Attention with FP8 Output
tensorrt_llm/_torch/auto_deploy/custom_ops/attention/trtllm_attention.py
Extended trtllm_mha_with_cache and fake variant to accept optional out_scale parameter for FP8 output tensors. Updated get_constants to return out_scale and dynamically detect FP8 paths through terminal consumer analysis. Added imports for FP8 graph analysis utilities.
Graph Transformation
tensorrt_llm/_torch/auto_deploy/transform/library/fuse_rmsnorm_quant_fp8.py
Implemented FuseRMSNormQuantFP8 transform class to fuse RMSNorm with FP8 quantization in Torch FX graphs. Detects direct and fused-add RMSNorm paths, rewrites consumers to use prequantized linear GEMM, and manages node replacement and erasure.
Graph Analysis Utilities
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py
Added is_trivial_passthrough_user for detecting view/reshape/transpose operations, collect_terminal_users_through_passthrough for traversing passthrough nodes, and get_shared_input_scale_for_fp8_linears for FP8 linear scale extraction and validation.
Test Coverage
tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quant_fusion.py
Added TinyRMSNormQuantFP8 test model and test_fuse_rmsnorm_quant_fp8_rewrites_graph test validating fusion of RMSNorm with FP8 quantization, verifying presence of triton_rms_norm_quant_fp8 and trtllm_fp8_prequant_linear ops post-fusion.
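The RMSNorm FP8 quantization cohort above boils down to one fused computation: normalize, apply the weight, then scale and clamp into the FP8 range. The numerics can be sketched in pure Python as a reference; this is a sketch only (the real op is a Triton kernel emitting actual float8 tensors), and the function name and rounding behavior here are illustrative assumptions.

```python
# Reference semantics of an RMSNorm + FP8-output op: normalize by the
# root-mean-square, apply the learned weight, then divide by a precomputed
# scale and clamp into the FP8 E4M3 representable range. Numerical sketch
# only; real FP8 also rounds to the nearest representable float8 value.
import math

FP8_E4M3_MAX = 448.0  # largest finite value representable in float8_e4m3fn


def rms_norm_quant_fp8_ref(x, weight, scale, eps=1e-6):
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    normed = [v / rms * w for v, w in zip(x, weight)]
    # Quantize: divide by the (precomputed) scale, clamp to the FP8 range.
    quant = [max(-FP8_E4M3_MAX, min(FP8_E4M3_MAX, v / scale)) for v in normed]
    return normed, quant


normed, quant = rms_norm_quant_fp8_ref([1.0, -2.0, 3.0], [1.0, 1.0, 1.0], 0.01)
```

The fused-add variant additionally sums a residual into `x` before normalizing; the payoff of fusion is that the high-precision normalized tensor never round-trips through memory before quantization.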

Sequence Diagram(s)

sequenceDiagram
    participant Graph as Torch FX Graph
    participant Transform as FuseRMSNormQuantFP8
    participant RMSNorm as RMSNorm Node
    participant Scale as Scale Node
    participant FP8Linear as FP8 Linear Nodes
    participant TritonOp as Triton Fused Op
    participant PreQuantLinear as PreQuant Linear Op

    Graph->>Transform: Apply transform on graph
    Transform->>Graph: Scan for RMSNorm patterns
    Graph-->>Transform: Found RMSNorm + FP8 consumers
    
    Transform->>RMSNorm: Identify RMSNorm source
    Transform->>FP8Linear: Find terminal FP8 linear consumers
    FP8Linear-->>Transform: Return consumer nodes & scale
    
    Transform->>Scale: Extract earliest FP8 consumer scale
    
    Transform->>TritonOp: Create fused triton_rms_norm_quant_fp8
    TritonOp->>Graph: Insert into graph (outputs: BF16, FP8)
    
    Transform->>FP8Linear: Rewrite to use prequant path
    FP8Linear->>PreQuantLinear: Replace with trtllm_fp8_prequant_linear
    PreQuantLinear->>Graph: Wire scale & dtype info
    
    Transform->>Graph: Erase obsolete norm nodes
    Transform->>Graph: Return transformation info

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed

❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning: docstring coverage is 30.30%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them.

✅ Passed checks (2 passed)
  • Title check ✅ Passed: the PR title accurately describes the main feature being introduced, enabling FP8 quantization fusion for the AutoDeploy module, marked as part 1 of the effort.
  • Description check ✅ Passed: the PR description comprehensively covers the main objectives, implementation details, test coverage, and includes a performance benchmark. All required template sections are present.



Comment @coderabbitai help to get the list of available commands and usage tips.

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 6

🧹 Nitpick comments (4)
tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quant_fusion.py (1)

10-13: Use module imports instead of direct class imports.

Please switch to module-level imports and qualify TransformConfig / FuseRMSNormQuantFP8 at use sites to match repo import policy.

♻️ Proposed refactor
-from tensorrt_llm._torch.auto_deploy.transform.interface import TransformConfig
-from tensorrt_llm._torch.auto_deploy.transform.library.fuse_rmsnorm_quant_fp8 import (
-    FuseRMSNormQuantFP8,
-)
+from tensorrt_llm._torch.auto_deploy.transform import interface as transform_interface
+from tensorrt_llm._torch.auto_deploy.transform.library import (
+    fuse_rmsnorm_quant_fp8 as fuse_rmsnorm_quant_fp8_lib,
+)
@@
-    transform = FuseRMSNormQuantFP8(TransformConfig(stage="post_load_fusion"))
+    transform = fuse_rmsnorm_quant_fp8_lib.FuseRMSNormQuantFP8(
+        transform_interface.TransformConfig(stage="post_load_fusion")
+    )

As per coding guidelines: Python imports must use form from package.subpackage import module (never from module import Class).

Also applies to: 224-224

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quant_fusion.py`
around lines 10 - 13, Replace direct class imports with module-level imports:
import the interface module (e.g., from
tensorrt_llm._torch.auto_deploy.transform import interface) and the
fuse_rmsnorm_quant_fp8 module (e.g., from
tensorrt_llm._torch.auto_deploy.transform.library import
fuse_rmsnorm_quant_fp8), then update all uses of TransformConfig and
FuseRMSNormQuantFP8 to be qualified (interface.TransformConfig and
fuse_rmsnorm_quant_fp8.FuseRMSNormQuantFP8). Make the same change where these
classes are imported elsewhere (the other occurrence noted) so all references
use the module-qualified names.
tensorrt_llm/_torch/auto_deploy/custom_ops/normalization/triton_fused_add_rms_norm_quant_fp8.py (1)

18-23: Align imports with the repo’s module-import convention.

Please avoid direct symbol imports (Tuple, Tensor) and keep namespace-qualified module usage in this file.

As per coding guidelines: Python imports must use form from package.subpackage import module (never from module import Class).

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In
`@tensorrt_llm/_torch/auto_deploy/custom_ops/normalization/triton_fused_add_rms_norm_quant_fp8.py`
around lines 18 - 23, Replace direct symbol imports: remove "from typing import
Tuple" and "from torch import Tensor" and instead use "import typing" and
reference typing.Tuple, and use torch.Tensor everywhere (you already have import
torch). Also change "import triton.language as tl" to the module-style import
"from triton import language as tl" so imports follow the from
package.subpackage import module convention; update all occurrences of Tuple and
Tensor in this file to typing.Tuple and torch.Tensor respectively.
tensorrt_llm/_torch/auto_deploy/transform/library/fuse_rmsnorm_quant_fp8.py (1)

24-39: Prefer module imports over direct symbol imports in this new transform module.

Line 24-Line 39 introduces direct function/class imports; switch to module imports to keep namespace visibility and consistency with repo standards.

As per coding guidelines: "When importing in Python, always maintain the namespace. Import the module, not individual classes or functions."

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/auto_deploy/transform/library/fuse_rmsnorm_quant_fp8.py`
around lines 24 - 39, The module currently imports many symbols directly
(flashinfer_fused_add_rms_norm, ModelFactory, CachedSequenceInterface,
extract_op_args, extract_output_tuple, get_shared_input_scale_for_fp8_linears,
is_op, BaseTransform, SharedConfig, TransformConfig, TransformInfo,
TransformRegistry); change these to module-level imports (e.g., import the
containing modules instead of direct symbols) and update all references in this
transform (fuse_rmsnorm_quant_fp8.py) to use the module-qualified names so
namespace visibility and repo import conventions are preserved.
tensorrt_llm/_torch/auto_deploy/custom_ops/attention/trtllm_attention.py (1)

43-48: Use module-level node_utils import to preserve namespace consistency.

The direct function imports at Line 43-Line 48 make this file diverge from the repository import rule; prefer node_utils.<helper>() call sites.

As per coding guidelines: "When importing in Python, always maintain the namespace. Import the module, not individual classes or functions."

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@tensorrt_llm/_torch/auto_deploy/custom_ops/attention/trtllm_attention.py`
around lines 43 - 48, The file currently imports helper functions directly
(collect_terminal_users_through_passthrough, extract_op_args,
get_shared_input_scale_for_fp8_linears, set_op_args) from ...utils.node_utils;
change this to import the module (e.g., import ...utils.node_utils as
node_utils) and update all call sites in trtllm_attention.py to use the module
namespace (node_utils.collect_terminal_users_through_passthrough,
node_utils.extract_op_args, node_utils.get_shared_input_scale_for_fp8_linears,
node_utils.set_op_args) to preserve the repository's namespace import
convention.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@examples/auto_deploy/model_registry/configs/llama3_1_8b.yaml`:
- Around line 11-14: The config currently enables only the transform
fuse_rmsnorm_quant_fp8 but omits fuse_fp8_gemms; update the transforms block to
add a fuse_fp8_gemms entry (similar to fuse_rmsnorm_quant_fp8) with stage set to
post_load_fusion and enabled: true so both transforms (fuse_fp8_gemms and
fuse_rmsnorm_quant_fp8) are enabled for the Llama 3.1 8B model.

In `@tensorrt_llm/_torch/auto_deploy/custom_ops/attention/trtllm_attention.py`:
- Around line 643-655: The code can enable FP8 output quantization without any
deterministic out_dtype when source_attn_node.meta["val"] has no dtype; fix by
assigning a deterministic fallback out_dtype_str before applying FP8 to
downstream linears: after retrieving val and attempting to set out_dtype_str
from val.dtype in the block around source_attn_node and val, add a fallback
assignment (e.g., out_dtype_str =
str(torch.get_default_dtype()).replace("torch.", "") or a literal like
"float32") when out_dtype_str is still None, then continue to call
set_op_args(user, out_dtype=out_dtype_str) for each user in fp8_users; keep the
existing handling of first_scale/out_scale unchanged but ensure out_dtype_str is
always set before enabling FP8 output quantization for fp8_users and before
using out_scale.

In
`@tensorrt_llm/_torch/auto_deploy/custom_ops/normalization/triton_fused_add_rms_norm_quant_fp8.py`:
- Around line 115-117: The fake handler functions are intentionally ignoring
some parameters; silence Ruff ARG001 by renaming unused parameters to start with
an underscore. Update the signature of _rms_norm_quant_fp8_fake (and the other
fake registration function around the second occurrence) so any unused args
(e.g., input, weight, eps, scale or similarly named params) are prefixed with an
underscore (for example _input, _weight, _eps, _scale) and keep the body
behavior identical.

In `@tensorrt_llm/_torch/auto_deploy/custom_ops/quantization/quant.py`:
- Around line 235-248: The fake op trtllm_fp8_prequant_linear_fake currently
ignores bias which breaks dtype propagation; update it to mirror the real
prequant path by resolving the base output dtype via
_resolve_out_dtype_or_raise(out_dtype) and, if bias is not None, compute the
promoted/result dtype between that base output dtype and bias.dtype (e.g., using
torch.result_type or torch.promote_types) and use that promoted dtype for the
returned torch.empty tensor so tracing sees the same dtype behavior as the real
operator.

In `@tensorrt_llm/_torch/auto_deploy/transform/library/fuse_rmsnorm_quant_fp8.py`:
- Around line 42-45: _find_fp8_linear_consumers currently only inspects direct
users of the RMSNorm node and therefore misses FP8 linear consumers reached via
trivial passthrough nodes; update _find_fp8_linear_consumers to perform a small
BFS/DFS from norm_node through allowed passthrough nodes (e.g.,
identity/reshape/alias-like nodes) collecting downstream FP8 linear nodes and
then call or reuse get_shared_input_scale_for_fp8_linears on that expanded
consumer set (refer to function _find_fp8_linear_consumers and helper
get_shared_input_scale_for_fp8_linears to locate where to change the traversal).

In `@tensorrt_llm/_torch/auto_deploy/utils/node_utils.py`:
- Around line 542-549: The current check in the fp8 fusion loop (iterating
fp8_linear_nodes and extracting scale via extract_op_args into variable scale
and comparing scale.target to first_scale.target) can falsely treat distinct
non-get_attr Nodes as shared; change the equivalence to require either the exact
same Node object (scale is first_scale) or both scale Nodes to be get_attr ops
that reference the same attribute target/name before treating them as shared; if
neither condition holds, return the fallback ([], None). Ensure you update the
comparison logic around first_scale and scale (and use Node/op identity or op
name "get_attr" plus matching target) in that loop to prevent incorrect FP8
fusion.
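The stricter equivalence that comment asks for (two scale nodes count as shared only when they are literally the same graph node, or both are get_attr nodes referencing the same attribute) can be sketched like this. The toy `Node` class stands in for torch.fx.Node; it is an illustration of the check, not the repository's code.

```python
# Sketch of the stricter shared-scale check: identity of the node object, or
# matching get_attr targets. Plain target equality on arbitrary non-get_attr
# nodes is NOT enough, which is exactly the false positive being fixed.
from dataclasses import dataclass


@dataclass
class Node:
    op: str       # "get_attr", "call_function", ...
    target: str   # attribute name or op name


def scales_are_shared(a: Node, b: Node) -> bool:
    if a is b:                            # literally the same graph node
        return True
    return (a.op == "get_attr" and b.op == "get_attr"
            and a.target == b.target)     # same attribute -> same tensor


s1 = Node("get_attr", "input_scale")
s2 = Node("get_attr", "input_scale")      # distinct node, same attribute
s3 = Node("call_function", "mul")
s4 = Node("call_function", "mul")         # same target string, still distinct
print(scales_are_shared(s1, s2), scales_are_shared(s3, s4))  # -> True False
```

When the check fails, the fusion falls back to `([], None)` rather than risking an incorrect shared-scale assumption across FP8 linears.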

---

Nitpick comments:
In `@tensorrt_llm/_torch/auto_deploy/custom_ops/attention/trtllm_attention.py`:
- Around line 43-48: The file currently imports helper functions directly
(collect_terminal_users_through_passthrough, extract_op_args,
get_shared_input_scale_for_fp8_linears, set_op_args) from ...utils.node_utils;
change this to import the module (e.g., import ...utils.node_utils as
node_utils) and update all call sites in trtllm_attention.py to use the module
namespace (node_utils.collect_terminal_users_through_passthrough,
node_utils.extract_op_args, node_utils.get_shared_input_scale_for_fp8_linears,
node_utils.set_op_args) to preserve the repository's namespace import
convention.

In
`@tensorrt_llm/_torch/auto_deploy/custom_ops/normalization/triton_fused_add_rms_norm_quant_fp8.py`:
- Around line 18-23: Replace direct symbol imports: remove "from typing import
Tuple" and "from torch import Tensor" and instead use "import typing" and
reference typing.Tuple, and use torch.Tensor everywhere (you already have import
torch). Also change "import triton.language as tl" to the module-style import
"from triton import language as tl" so imports follow the from
package.subpackage import module convention; update all occurrences of Tuple and
Tensor in this file to typing.Tuple and torch.Tensor respectively.

In `@tensorrt_llm/_torch/auto_deploy/transform/library/fuse_rmsnorm_quant_fp8.py`:
- Around line 24-39: The module currently imports many symbols directly
(flashinfer_fused_add_rms_norm, ModelFactory, CachedSequenceInterface,
extract_op_args, extract_output_tuple, get_shared_input_scale_for_fp8_linears,
is_op, BaseTransform, SharedConfig, TransformConfig, TransformInfo,
TransformRegistry); change these to module-level imports (e.g., import the
containing modules instead of direct symbols) and update all references in this
transform (fuse_rmsnorm_quant_fp8.py) to use the module-qualified names so
namespace visibility and repo import conventions are preserved.

In
`@tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quant_fusion.py`:
- Around line 10-13: Replace direct class imports with module-level imports:
import the interface module (e.g., from
tensorrt_llm._torch.auto_deploy.transform import interface) and the
fuse_rmsnorm_quant_fp8 module (e.g., from
tensorrt_llm._torch.auto_deploy.transform.library import
fuse_rmsnorm_quant_fp8), then update all uses of TransformConfig and
FuseRMSNormQuantFP8 to be qualified (interface.TransformConfig and
fuse_rmsnorm_quant_fp8.FuseRMSNormQuantFP8). Make the same change where these
classes are imported elsewhere (the other occurrence noted) so all references
use the module-qualified names.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

Run ID: 509558b9-7928-4df0-b183-39e997ee609b

📥 Commits

Reviewing files that changed from the base of the PR and between b15062e and 789d5a7.

📒 Files selected for processing (11)
  • examples/auto_deploy/model_registry/configs/dashboard_default.yaml
  • examples/auto_deploy/model_registry/configs/llama3_1_8b.yaml
  • examples/auto_deploy/model_registry/models.yaml
  • tensorrt_llm/_torch/auto_deploy/config/default.yaml
  • tensorrt_llm/_torch/auto_deploy/custom_ops/attention/trtllm_attention.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/normalization/__init__.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/normalization/triton_fused_add_rms_norm_quant_fp8.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/quantization/quant.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/fuse_rmsnorm_quant_fp8.py
  • tensorrt_llm/_torch/auto_deploy/utils/node_utils.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quant_fusion.py

@galagam galagam force-pushed the gagam/fp8-quant-fuse branch 3 times, most recently from a4d6517 to 6ba11cb on March 4, 2026 14:21
@galagam
Collaborator Author

galagam commented Mar 5, 2026

/bot run --stage-list "DGX_B200-4_GPUs-AutoDeploy-1, DGX_H100-4_GPUs-AutoDeploy-1"

@tensorrt-cicd
Collaborator

PR_Github #37837 [ run ] triggered by Bot. Commit: dc90e25 Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #37837 [ run ] completed with state ABORTED. Commit: dc90e25
/LLM/main/L0_MergeRequest_PR pipeline #29299 (Partly Tested) completed with status: 'ABORTED'

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@galagam
Collaborator Author

galagam commented Mar 5, 2026

/bot run --stage-list "DGX_B200-4_GPUs-AutoDeploy-1, DGX_H100-4_GPUs-AutoDeploy-1"

@tensorrt-cicd
Collaborator

PR_Github #37851 [ run ] triggered by Bot. Commit: dc90e25 Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #37851 [ run ] completed with state SUCCESS. Commit: dc90e25
/LLM/main/L0_MergeRequest_PR pipeline #29309 (Partly Tested) completed with status: 'FAILURE'

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@galagam galagam force-pushed the gagam/fp8-quant-fuse branch from dc90e25 to f37b5eb on March 5, 2026 16:51
@galagam galagam marked this pull request as draft March 5, 2026 16:51
@galagam
Collaborator Author

galagam commented Mar 5, 2026

/bot run --stage-list "DGX_B200-4_GPUs-AutoDeploy-1, DGX_H100-4_GPUs-AutoDeploy-1"

@tensorrt-cicd
Collaborator

PR_Github #39084 [ run ] completed with state SUCCESS. Commit: e5c74a4
/LLM/main/L0_MergeRequest_PR pipeline #30347 (Partly Tested) completed with status: 'SUCCESS'

CI Report

Link to invocation

@galagam
Collaborator Author

galagam commented Mar 16, 2026

/bot run

@galagam galagam enabled auto-merge (squash) March 16, 2026 14:48
@tensorrt-cicd
Collaborator

PR_Github #39099 [ run ] triggered by Bot. Commit: e5c74a4 Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #39099 [ run ] completed with state SUCCESS. Commit: e5c74a4
/LLM/main/L0_MergeRequest_PR pipeline #30361 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@galagam
Collaborator Author

galagam commented Mar 16, 2026

/bot run

@tensorrt-cicd
Collaborator

PR_Github #39114 [ run ] triggered by Bot. Commit: e5c74a4 Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #39114 [ run ] completed with state SUCCESS. Commit: e5c74a4
/LLM/main/L0_MergeRequest_PR pipeline #30373 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

@galagam
Collaborator Author

galagam commented Mar 17, 2026

/bot run

@tensorrt-cicd
Collaborator

PR_Github #39193 [ run ] triggered by Bot. Commit: e5c74a4 Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #39193 [ run ] completed with state FAILURE. Commit: e5c74a4
/LLM/main/L0_MergeRequest_PR pipeline #30445 completed with status: 'FAILURE'

CI Report

⚠️ Action Required:

  • Please check the failed tests and fix your PR
  • If you cannot view the failures, ask the CI triggerer to share details
  • Once fixed, request an NVIDIA team member to trigger CI again

Link to invocation

galagam added 5 commits March 17, 2026 16:47
- Keep attn output in fp8
- Fuse rmsnorm + quant

Signed-off-by: Gal Hubara Agam <96368689+galagam@users.noreply.github.com>
Signed-off-by: Gal Hubara Agam <96368689+galagam@users.noreply.github.com>
Signed-off-by: Gal Hubara Agam <96368689+galagam@users.noreply.github.com>
…/super

Signed-off-by: Gal Hubara Agam <96368689+galagam@users.noreply.github.com>
Existing logic does not support fake ops. Enabling fake ops and moving to the pattern-matching phase requires brittle cross-stage metadata transfer.

Signed-off-by: Gal Hubara Agam <96368689+galagam@users.noreply.github.com>
@galagam galagam force-pushed the gagam/fp8-quant-fuse branch from e5c74a4 to 403c21f on March 17, 2026 14:47
@galagam
Collaborator Author

galagam commented Mar 17, 2026

/bot run --stage-list "DGX_B200-4_GPUs-AutoDeploy-1, DGX_H100-4_GPUs-AutoDeploy-1"

@galagam galagam disabled auto-merge March 17, 2026 15:08
@tensorrt-cicd
Collaborator

PR_Github #39274 [ run ] triggered by Bot. Commit: 403c21f Link to invocation

@galagam
Collaborator Author

galagam commented Mar 17, 2026

/bot run --stage-list "DGX_B200-4_GPUs-AutoDeploy-1, DGX_H100-4_GPUs-AutoDeploy-1, DGX_H100-4_GPUs-AutoDeploy-Post-Merge-1"

@tensorrt-cicd
Collaborator

PR_Github #39278 [ run ] triggered by Bot. Commit: 403c21f Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #39278 [ run ] completed with state SUCCESS. Commit: 403c21f
/LLM/main/L0_MergeRequest_PR pipeline #30526 (Partly Tested) completed with status: 'SUCCESS'

CI Report

Link to invocation

@galagam
Collaborator Author

galagam commented Mar 17, 2026

/bot run

@galagam galagam enabled auto-merge (squash) March 17, 2026 17:21
@tensorrt-cicd
Collaborator

PR_Github #39297 [ run ] triggered by Bot. Commit: 403c21f Link to invocation

@tensorrt-cicd
Collaborator

PR_Github #39297 [ run ] completed with state SUCCESS. Commit: 403c21f
/LLM/main/L0_MergeRequest_PR pipeline #30546 completed with status: 'SUCCESS'

CI Report

Link to invocation

@galagam galagam merged commit 43d3ad8 into NVIDIA:main Mar 17, 2026
6 checks passed
limin2021 pushed a commit to limin2021/TensorRT-LLM that referenced this pull request Mar 19, 2026
NVIDIA#11910)

Signed-off-by: Gal Hubara Agam <96368689+galagam@users.noreply.github.com>