
[OMNIML-3349] Add FP8 MHA quantization support for HuggingFace ViT#1289

Merged
ajrasane merged 4 commits into main from ajrasane/mha_quantization
Apr 23, 2026

Conversation

Contributor

@ajrasane ajrasane commented Apr 17, 2026

Summary

Enables TensorRT attention-v2 fusion for vision transformers when exported to ONNX with FP8 Q/DQ. The core library changes are architecture-agnostic (drop-in for any FP8 ONNX export); coverage is exercised by the existing examples/torch_onnx/torch_quant_to_onnx.py pipeline.

  • modelopt/onnx/export/fp8_exporter.py — new post-processing passes: move attention-scaling Mul and K Transpose to the Q-side so DQ feeds MatMul directly, pre-transpose constant weights, and insert FP8 Q/DQ on Softmax outputs (fixed 1/448 scale, data-independent) for MHA-v2 fusion. Rewrites only fire when every downstream consumer is a MatMul so non-attention branches are never perturbed.
  • modelopt/onnx/utils.py — new fold_dq_fp32_to_fp16_casts / fold_q_fp16_to_fp32_casts helpers remove the Cast nodes that convert_float_to_float16 inserts around Q/DQ and rewrite scale initializers to FP16 so TRT fuses DQ into the downstream GEMM. Guarded behind opset >= 19 (the FP16 Q/DQ scale requirement). Warns on FP16 overflow/underflow.
  • modelopt/torch/_deploy/utils/torch_onnx.py — calls the fold helpers for FP8-quantized models after convert_float_to_float16.
  • modelopt/torch/quantization/export_onnx.py — keeps FP8 Q/DQ scale in the native input dtype so no Cast is emitted between graph and Q/DQ. Removes the now-unused trt_high_precision_dtype parameter from _fp8_quantize/_fp8_dequantize.
  • modelopt/torch/quantization/nn/modules/quant_layernorm.py (new) — registers nn.LayerNorm in QuantModuleRegistry so LayerNorm output quantizers are honored.
  • modelopt/torch/quantization/plugins/huggingface.py — skips *Attention wrappers whose children are also *Attention per-instance (not per-class) to avoid double-patching eager_attention_forward (e.g. ViTAttention vs ViTSelfAttention).
  • examples/torch_onnx/torch_quant_to_onnx.py — adds a _FP8_MHA_OVERRIDE config block to FP8 mode that enables LayerNorm output quantizer + disables its input quantizer for TRT attention fusion.
  • Unit tests (12 CPU tests, ~1.2s total) — fp8_exporter rewrites + fanout safety, fold-cast helpers + opset guard, LayerNorm quant-wrapper identity, per-instance nested-attention detection.
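The fixed Softmax Q scale is data-independent because softmax outputs lie in [0, 1]: with scale 1/448, the value 1.0 maps exactly to the FP8 E4M3 maximum, so no clipping can ever occur. A minimal sketch of the idea (quantize_fp8/dequantize_fp8 are illustrative stand-ins, not library APIs; real QuantizeLinear casts x/scale to float8_e4m3 rather than rounding to an integer grid):

```python
# Illustrative sketch only: rounding here is a stand-in for the lossy
# cast to float8_e4m3 that a real FP8 QuantizeLinear performs.
FP8_E4M3_MAX = 448.0
SOFTMAX_SCALE = 1.0 / FP8_E4M3_MAX  # fixed, data-independent

def quantize_fp8(x: float, scale: float) -> float:
    # Scale into the E4M3 range and clamp; softmax outputs in [0, 1]
    # land in [0, 448], so the clamp never actually fires.
    return max(-FP8_E4M3_MAX, min(FP8_E4M3_MAX, round(x / scale)))

def dequantize_fp8(q: float, scale: float) -> float:
    return q * scale

# 1.0 (the softmax maximum) saturates exactly at the E4M3 max and
# round-trips through Q/DQ essentially unchanged.
assert quantize_fp8(1.0, SOFTMAX_SCALE) == FP8_E4M3_MAX
assert abs(dequantize_fp8(quantize_fp8(1.0, SOFTMAX_SCALE), SOFTMAX_SCALE) - 1.0) < 1e-12
```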

Benchmarks

ViT-base-patch16-224, RTX 6000 Ada, strongly-typed FP8 via trtexec. Accuracy on 2,000 ImageNet-1k validation samples (streaming).

Batch = 1 (latency-bound)

| Model | Top-1 | Top-5 | TRT latency | Speedup |
|---|---|---|---|---|
| FP16 baseline | 80.96% | 95.80% | 0.722 ms | 1.00x |
| Torch FP8 MHA | 80.66% | 95.75% | 0.657 ms | 1.10x |
| ONNX PTQ FP8 | | | 0.589 ms | 1.23x |

Batch = 64 (throughput-bound, realistic inference)

| Model | TRT latency | Speedup | Images/s |
|---|---|---|---|
| FP16 baseline | 23.40 ms | 1.00x | 1152 |
| Torch FP8 MHA | 15.89 ms | 1.47x | 1152 |
| ONNX PTQ FP8 | 15.89 ms | 1.47x | 1216 |

Top-1 accuracy stays within 0.30 pp of FP16; at batch=64 the Torch FP8 MHA path matches ONNX PTQ wall-time — attention is the bottleneck there and both paths achieve full FP8 attention fusion (36/36 attention MatMuls with QDQ in ViT-base).
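The fanout-safety guard from the summary ("rewrites only fire when every downstream consumer is a MatMul") can be sketched with a toy node type; Node here is a minimal stand-in, not the onnx-graphsurgeon class:

```python
from dataclasses import dataclass, field

@dataclass
class Node:
    # Minimal stand-in for a graph node: an op type plus the nodes that
    # consume its (single) output.
    op: str
    consumers: list = field(default_factory=list)

def safe_to_rewrite(node: Node) -> bool:
    # Rewrite only when the output feeds MatMuls exclusively; an `any`
    # check would also fire on mixed fanout and break other branches.
    return bool(node.consumers) and all(c.op == "MatMul" for c in node.consumers)

attention_branch = Node("Softmax", [Node("MatMul"), Node("MatMul")])
mixed_fanout = Node("Softmax", [Node("MatMul"), Node("ReduceMax")])
assert safe_to_rewrite(attention_branch)
assert not safe_to_rewrite(mixed_fanout)
```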

Test plan

  • CPU unit tests (new): `python -m pytest tests/unit/onnx/quantization/test_fp8_mha_exporter.py tests/unit/onnx/test_fold_casts.py tests/unit/torch/quantization/test_quant_layernorm.py tests/unit/torch/quantization/plugins/test_nested_attention_skip.py`
  • Existing ONNX / quantization unit suites unaffected: `python -m pytest tests/unit/onnx tests/unit/torch/quantization`
  • End-to-end ViT FP8 export: `python examples/torch_onnx/torch_quant_to_onnx.py --timm_model_name vit_base_patch16_224 --quantize_mode fp8 --onnx_save_path vit_base_fp8.onnx` — expect log lines `Folded 48 weight Transpose nodes`, `Inserted FP8 weight DequantizeLinear for 1 Conv nodes`, and `Attention QDQ rewrites: ... inserted QDQ on 12 Softmax outputs`
  • trtexec FP8 strongly-typed build: `trtexec --onnx=vit_base_fp8.onnx --fp8 --stronglyTyped`
  • Accuracy within ~0.3 pp of FP16 baseline on ImageNet-1k subset

@ajrasane ajrasane requested review from a team as code owners April 17, 2026 20:14
Contributor

coderabbitai Bot commented Apr 17, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

Walkthrough

Adds FP8-focused ONNX export and post-processing: constant/transpose folding into FP8 weights, two attention-aware Q/DQ graph rewrites, FP16/FP32 scale-cast folding utilities, simplified FP8 export helpers, LayerNorm/Softmax quant-module registrations, and HF nested-attention detection tweaks.

Changes

Cohort / File(s) Summary
FP8 ONNX exporter & weight folding
modelopt/onnx/export/fp8_exporter.py
Added module-level _FP8_E4M3_MAX constant; changed Conv weight FP8 scale computation. compress_weights now raises RuntimeError when expected Q→DQ pairing is missing, can pre-transpose/fold constant weight Transpose (and optional Cast) into FP8 encoding, updates DQ output shapes, rewires downstream inputs, and tracks n_t_folded. Added two post-process graph-rewrite helpers (_move_mul_before_qdq, _move_transpose_before_qdq) and integrated them into post_process with consolidated logging.
ONNX utils: Q/DQ cast folding & helpers
modelopt/onnx/utils.py
Introduced _DQ_OPS/_Q_OPS classification sets and _scale_fp32_to_fp16 (mutates scale initializers to FLOAT16 with safety warnings). Added fold_q_fp16_to_fp32_casts to remove CastQuantizeLinear patterns by converting scale tensors; updated existing fp32↔fp16 fold functions to skip transforms for opset < BASE_MIN_OPSET and use module-level NumPy.
FP8 export helpers (Torch → ONNX)
modelopt/torch/quantization/export_onnx.py
Removed trt_high_precision_dtype from _fp8_quantize/_fp8_dequantize signatures and logic; helpers now preserve native scalar types and derive FP8 ops/output element types from input scalarType() without intermediate casting. Public export_fp8 keeps parameter but does not use it.
Deployment ONNX pipeline integration
modelopt/torch/_deploy/utils/torch_onnx.py
When is_fp8_quantized(model) is true, get_onnx_bytes_and_metadata() now runs fold_q_fp16_to_fp32_casts then fold_dq_fp32_to_fp16_casts after remove_redundant_casts to adjust Cast placement around Q/DQ patterns.
Quantization module API surface
modelopt/torch/quantization/nn/__init__.py, modelopt/torch/quantization/nn/modules/quant_layernorm.py, modelopt/torch/quantization/nn/modules/quant_softmax.py
Added new modules registering torch.nn.LayerNorm and torch.nn.Softmax with QuantModuleRegistry (handler QuantInputBase); package __init__ now wildcard re-exports these modules so their public symbols are exported.
HuggingFace plugin: attention detection
modelopt/torch/quantization/plugins/huggingface.py
register_hf_attentions_on_the_fly() adds per-module helper to detect and skip registration for modules that wrap nested attention children (skip when a nested submodule class name ends with Attention), avoiding double-registration.
Examples & docs
examples/torch_onnx/torch_quant_to_onnx.py, CHANGELOG.rst
Example script extended with FP8 MHA config overrides, softmax injection/patching helpers for ViT attention, and runtime patching when --quantize_mode fp8. Changelog entry documenting FP8 MHA quantization support, LayerNorm registration, and HF nested-attention behavior.
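The scale-cast folding above hinges on converting FP32 Q/DQ scale initializers to FP16 without silently corrupting extreme values. A standalone sketch of that guard (scale_fp32_to_fp16 is a hypothetical mirror of the in-place helper, which mutates ONNX initializers rather than returning arrays):

```python
import warnings

import numpy as np

def scale_fp32_to_fp16(scale: np.ndarray) -> np.ndarray:
    # Downcast and warn when a finite, non-zero FP32 scale leaves the
    # FP16 range (max ~65504, smallest subnormal ~5.96e-8).
    fp16 = scale.astype(np.float16)
    if np.any(np.isinf(fp16) & np.isfinite(scale)):
        warnings.warn("Q/DQ scale overflowed to inf in FP16")
    if np.any((fp16 == 0) & (scale != 0)):
        warnings.warn("Q/DQ scale underflowed to 0 in FP16")
    return fp16
```

Typical FP8 scales sit well inside the FP16 range, so in practice the warnings should only fire for degenerate calibration results.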

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~50 minutes

🚥 Pre-merge checks | ✅ 3 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 72.34% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (3 passed)
Check name Status Explanation
Security Anti-Patterns ✅ Passed Files do not introduce unsafe torch.load/numpy.load patterns, hardcoded trust_remote_code=True in executable code, eval/exec on external input, nosec comments, or non-permissively licensed dependencies.
Title check ✅ Passed The title accurately summarizes the main change: adding FP8 MHA (multi-head attention) quantization support for HuggingFace ViT models, which is the core objective across multiple files in the changeset.
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.


@github-actions
Contributor

github-actions Bot commented Apr 17, 2026

PR Preview Action v1.8.1
Preview removed because the pull request was closed.
2026-04-23 15:38 UTC

Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 5

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@examples/torch_onnx/vit_mha_quantization.py`:
- Around line 267-298: The current loop counts any MatMul consuming a
DequantizeLinear output (matmul_with_qdq) which falsely includes projection/MLP
matmuls; replace this with an attention-specific check: implement a helper
(e.g., is_attention_matmul(node, output_to_node, graph)) and only increment
matmul_with_qdq when that returns true. Make is_attention_matmul examine the
MatMul node name and upstream pattern (check parent ops via output_to_node for
Transpose/Reshape, Softmax, or names containing tokens like "q", "k", "v",
"attn", "score", "softmax") or detect the Q@K^T pattern by verifying one input
path comes from a Transpose of a Q-like tensor and the other from K-like tensor;
for attn@V detect the MatMul consuming Softmax output and a V-like source.
Update the loop that currently inspects node.op_type == "MatMul" and uses
inputs_from_dq to call this helper and only count/print when both QDQ and
attention pattern match.
- Around line 225-230: The export is mutating the live model because
model.float() is in-place, which alters base_model/quantized_model used later;
fix by exporting from a detached copy instead (e.g., create a deep copy of model
with copy.deepcopy(model) and call .float() or .to(torch.float16) on that copy)
so get_onnx_bytes_and_metadata receives a non-mutated model; ensure you import
copy and use the copied instance when calling get_onnx_bytes_and_metadata to
avoid changing base_model/quantized_model before accuracy evaluation.

In `@modelopt/onnx/export/fp8_exporter.py`:
- Around line 100-108: The code currently uses any(c.op == "MatMul" for c in
candidate.outputs[0].outputs) and then rewires/clears all consumers which breaks
non-MatMul branches; change the logic to require all(c.op == "MatMul" for c in
candidate.outputs[0].outputs) before performing the global rewrite OR,
preferably, only rewrite the specific MatMul edges: iterate
candidate.outputs[0].outputs, for each consumer c with c.op == "MatMul" rewire
that consumer's input to use the transposed/scaled/quantized tensor and leave
other consumers untouched, and do not clear original outputs (update
transpose_to_remove only when all downstream edges have been safely redirected).
Apply the same fix pattern to the other rewrite sites that manipulate
torch_weights, perm, transpose_to_remove, and similar MatMul-aware transforms.

In `@modelopt/onnx/utils.py`:
- Around line 1422-1505: The fold helpers unconditionally convert Q/DQ scale
initializers to FLOAT16 which is invalid for opsets < BASE_MIN_OPSET; update
_scale_fp32_to_fp16, fold_dq_fp32_to_fp16_casts and fold_q_fp16_to_fp32_casts to
guard the mutation by checking get_opset_version(onnx_model) (or the model
passed in) and only perform the FP32→FP16 rewrite when
get_opset_version(onnx_model) >= BASE_MIN_OPSET; if the check fails, skip
mutating initializers and skip folding the cast nodes (i.e., return the model
unchanged or continue without calling _scale_fp32_to_fp16/_bypass_cast_node),
using the existing function names (_scale_fp32_to_fp16,
fold_dq_fp32_to_fp16_casts, fold_q_fp16_to_fp32_casts) and constants
(BASE_MIN_OPSET) to locate where to add the guard.
ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro Plus

Run ID: 005378aa-8fac-4f2d-98a1-55297415cbe3

📥 Commits

Reviewing files that changed from the base of the PR and between e4b054b and d6533ac.

📒 Files selected for processing (8)
  • examples/torch_onnx/vit_mha_quantization.py
  • modelopt/onnx/export/fp8_exporter.py
  • modelopt/onnx/utils.py
  • modelopt/torch/_deploy/utils/torch_onnx.py
  • modelopt/torch/quantization/export_onnx.py
  • modelopt/torch/quantization/nn/__init__.py
  • modelopt/torch/quantization/nn/modules/quant_layernorm.py
  • modelopt/torch/quantization/plugins/huggingface.py

Collaborator

@cjluo-nv cjluo-nv left a comment


This PR adds FP8 MHA quantization support for HuggingFace ViT models with ONNX export optimizations. The implementation is well-structured and addresses a real gap (NVBug 6078291). However, there are several issues to address:

Critical issues:

  1. No unit tests — This is ~933 lines of new/changed library code across core export paths (fp8_exporter.py, utils.py, export_onnx.py, huggingface.py) with zero unit tests. The only "test" is the example script which requires GPU, ImageNet data, and TRT. The graph rewrite functions in fp8_exporter.py, the cast folding helpers in utils.py, the attention skipping logic in huggingface.py, and the LayerNorm quantization registration all need unit tests.

  2. Bare assert for runtime validation in fp8_exporter.py — the existing assert on QDQ pair validation will be stripped with -O.

  3. Silent contextlib.suppress(Exception) in the example — can mask real failures during benchmark parsing.

Minor issues:

  4. The _scale_fp32_to_fp16 helper doesn't handle the case where the scale value overflows or underflows to inf/0 in FP16 — this could silently produce bad quantization results for extreme scales.

  5. The _move_mul_before_qdq rewrite assumes a single scalar const Mul for attention scaling; if the model architecture changes, these pattern-matching rewrites could silently become no-ops without any warning.

  6. The _insert_qdq_after_softmax hardcodes scale=1/448, which is correct for E4M3 but should at minimum document why this specific value was chosen and that it is tied to the FP8 E4M3 max representable value.

Positive aspects:

  • Clean separation of graph rewrites as static methods
  • Good docstrings on the new functions
  • The parent_attention_types detection for avoiding double-patching is well done
  • The LayerNorm registration follows existing patterns exactly
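The per-instance detection praised above (skipping a wrapper whose submodules are themselves *Attention) can be sketched without torch; Module, ViTAttention, and ViTSelfAttention below are toy stand-ins for the real nn.Module hierarchy:

```python
class Module:
    # Toy stand-in for torch.nn.Module with just enough structure to
    # walk strict submodules (the module itself is excluded).
    def __init__(self, *children):
        self._children = list(children)

    def submodules(self):
        for child in self._children:
            yield child
            yield from child.submodules()

class ViTSelfAttention(Module):
    pass

class ViTAttention(Module):
    pass

def wraps_nested_attention(module: Module) -> bool:
    # Per-instance check: skip registering this wrapper if any strict
    # submodule's class name ends with "Attention".
    return any(type(m).__name__.endswith("Attention") for m in module.submodules())

assert wraps_nested_attention(ViTAttention(ViTSelfAttention()))  # wrapper: skip it
assert not wraps_nested_attention(ViTSelfAttention())            # leaf: register it
```

Checking instances rather than classes matters because ViTAttention and ViTSelfAttention both end with "Attention"; only the instance's children reveal which one is the outer wrapper.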

@codecov

codecov Bot commented Apr 17, 2026

Codecov Report

❌ Patch coverage is 90.36697% with 21 lines in your changes missing coverage. Please review.
✅ Project coverage is 75.73%. Comparing base (c796611) to head (a9f87bf).
⚠️ Report is 3 commits behind head on main.

Files with missing lines Patch % Lines
modelopt/onnx/export/fp8_exporter.py 89.61% 16 Missing ⚠️
modelopt/onnx/utils.py 89.58% 5 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1289      +/-   ##
==========================================
+ Coverage   74.60%   75.73%   +1.12%     
==========================================
  Files         467      468       +1     
  Lines       50176    50374     +198     
==========================================
+ Hits        37435    38151     +716     
+ Misses      12741    12223     -518     
Flag Coverage Δ
examples 41.57% <72.01%> (+5.97%) ⬆️
gpu 58.32% <11.92%> (-0.71%) ⬇️
regression 14.78% <2.29%> (+0.02%) ⬆️
unit 52.53% <68.80%> (+0.17%) ⬆️

@ajrasane ajrasane force-pushed the ajrasane/mha_quantization branch from d6533ac to c436553 Compare April 20, 2026 14:24
Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (1)
modelopt/onnx/export/fp8_exporter.py (1)

79-81: Replace assert with explicit exception for runtime validation.

Per codebase conventions, assert statements are stripped with -O flag. Use raise RuntimeError(...) for runtime validation that must always execute.

Suggested fix
-                assert dq_op.op == "TRT_FP8DequantizeLinear", (
-                    f"QDQ does not occur in pairs. You reached {dq_op.op}"
-                )
+                if dq_op.op != "TRT_FP8DequantizeLinear":
+                    raise RuntimeError(f"QDQ does not occur in pairs. You reached {dq_op.op}")

📥 Commits

Reviewing files that changed from the base of the PR and between d6533ac and c436553.

📒 Files selected for processing (8)
  • examples/torch_onnx/vit_mha_quantization.py
  • modelopt/onnx/export/fp8_exporter.py
  • modelopt/onnx/utils.py
  • modelopt/torch/_deploy/utils/torch_onnx.py
  • modelopt/torch/quantization/export_onnx.py
  • modelopt/torch/quantization/nn/__init__.py
  • modelopt/torch/quantization/nn/modules/quant_layernorm.py
  • modelopt/torch/quantization/plugins/huggingface.py
✅ Files skipped from review due to trivial changes (2)
  • modelopt/torch/quantization/nn/modules/quant_layernorm.py
  • modelopt/torch/quantization/plugins/huggingface.py
🚧 Files skipped from review as they are similar to previous changes (4)
  • modelopt/torch/quantization/nn/__init__.py
  • modelopt/torch/quantization/export_onnx.py
  • modelopt/torch/_deploy/utils/torch_onnx.py
  • modelopt/onnx/utils.py

Comment on lines +428 to +429
processor = ViTImageProcessor.from_pretrained(args.model_name)
base_model = ViTForImageClassification.from_pretrained(args.model_name).eval().to(device)
Contributor


⚠️ Potential issue | 🟡 Minor

Expose trust_remote_code as a CLI parameter defaulting to False.

The --model_name argument allows users to specify any HuggingFace model. Some models require trust_remote_code=True, which enables execution of arbitrary Python shipped with the checkpoint. Per coding guidelines, this should be a caller-configurable parameter defaulting to False.

Suggested fix
     parser.add_argument("--skip_onnx_ptq", action="store_true", help="Skip ONNX PTQ path")
+    parser.add_argument(
+        "--trust_remote_code",
+        action="store_true",
+        help="Trust remote code when loading HuggingFace models (security risk)",
+    )
     args = parser.parse_args()

Then update the loading calls:

-    processor = ViTImageProcessor.from_pretrained(args.model_name)
-    base_model = ViTForImageClassification.from_pretrained(args.model_name).eval().to(device)
+    processor = ViTImageProcessor.from_pretrained(
+        args.model_name, trust_remote_code=args.trust_remote_code
+    )
+    base_model = ViTForImageClassification.from_pretrained(
+        args.model_name, trust_remote_code=args.trust_remote_code
+    ).eval().to(device)

As per coding guidelines: "Do not hardcode trust_remote_code=True when loading Hugging Face Transformers models. Let the caller decide via a parameter; default to False."


@ajrasane ajrasane force-pushed the ajrasane/mha_quantization branch from c436553 to 9bfcb72 Compare April 20, 2026 14:54
Contributor

@coderabbitai coderabbitai Bot left a comment


🧹 Nitpick comments (2)
modelopt/onnx/export/fp8_exporter.py (2)

428-434: Unreachable defensive check.

The check if consumer is q_node: continue at line 429 is unreachable because q_node is created at line 413-419 after consumers was captured at line 388. The q_node cannot be in the consumers list.

This is harmless but adds dead code that could confuse future readers.

♻️ Suggested removal
             for consumer in consumers:
-                if consumer is q_node:
-                    continue
                 for i, inp in enumerate(consumer.inputs):
                     if inp is softmax_output:
                         consumer.inputs[i] = dq_output

30-33: Consider deriving _FP8_E4M3_MAX from torch.finfo for consistency.

The hardcoded value 448.0 is correct, but modelopt/torch/quantization/qtensor/mxfp8_tensor.py uses torch.finfo(torch.float8_e4m3fn).max to obtain this value programmatically. Using the same pattern here would be more robust and self-documenting.

♻️ Suggested change
-# FP8 E4M3 max representable magnitude; softmax output in [0, 1] saturates exactly at 1.0
-# when using 1/448 as the Q scale.
-_FP8_E4M3_MAX = 448.0
+# FP8 E4M3 max representable magnitude; softmax output in [0, 1] saturates exactly at 1.0
+# when using 1/448 as the Q scale.
+_FP8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0
 _FP8_E4M3_SOFTMAX_SCALE = 1.0 / _FP8_E4M3_MAX

📥 Commits

Reviewing files that changed from the base of the PR and between c436553 and 9bfcb72.

📒 Files selected for processing (9)
  • CHANGELOG.rst
  • examples/torch_onnx/vit_mha_quantization.py
  • modelopt/onnx/export/fp8_exporter.py
  • modelopt/onnx/utils.py
  • modelopt/torch/_deploy/utils/torch_onnx.py
  • modelopt/torch/quantization/export_onnx.py
  • modelopt/torch/quantization/nn/__init__.py
  • modelopt/torch/quantization/nn/modules/quant_layernorm.py
  • modelopt/torch/quantization/plugins/huggingface.py
✅ Files skipped from review due to trivial changes (2)
  • CHANGELOG.rst
  • modelopt/torch/quantization/nn/modules/quant_layernorm.py
🚧 Files skipped from review as they are similar to previous changes (6)
  • modelopt/torch/quantization/nn/__init__.py
  • modelopt/torch/_deploy/utils/torch_onnx.py
  • modelopt/torch/quantization/plugins/huggingface.py
  • modelopt/torch/quantization/export_onnx.py
  • examples/torch_onnx/vit_mha_quantization.py
  • modelopt/onnx/utils.py

@ajrasane ajrasane force-pushed the ajrasane/mha_quantization branch from 9bfcb72 to ef8c769 Compare April 20, 2026 17:15
Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 2

🧹 Nitpick comments (1)
modelopt/onnx/export/fp8_exporter.py (1)

391-392: Remove unreachable QuantizeLinear check in softmax rewrite.

After Line 389 (all(c.op == "MatMul" for c in consumers)), Line 391 cannot be true in this code path.

🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/onnx/export/fp8_exporter.py`:
- Around lines 99-104: the code assumes the Cast node (cast_to_remove) can be fully cleared once a Transpose child is found, but clearing its outputs also disconnects any other live consumers. After finding the Transpose candidate, check cast_to_remove.outputs for consumers besides that Transpose and detach only the Transpose edge (or skip removing the Cast entirely) instead of blanket-clearing cast_to_remove.outputs, so other consumers stay wired.
- Around lines 271-279: the rewrite moves Mul/Transpose across Quantize/Dequantize pairs without guarding against upstream fanout. Before mutating q_node.inputs[0], verify that the upstream value being rewritten has exactly one consumer (the DQ/Transpose path) and that the DQ/Transpose node itself has no other consumers; apply the same single-consumer guard to the other similar blocks so unrelated branches are not modified.
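The surgical fix for the first comment can be sketched on a minimal hand-rolled graph model (hypothetical Var/Node classes, not the real graph library's API):

```python
class Var:
    """A value in the graph, tracking which nodes consume it."""
    def __init__(self):
        self.consumers = []

class Node:
    def __init__(self, op, inputs=()):
        self.op = op
        self.inputs = list(inputs)
        self.out = Var()
        for v in self.inputs:
            v.consumers.append(self)

def detach_transpose_from_cast(cast_node, transpose_node):
    """Reroute only the Transpose past the Cast; leave other consumers wired.

    Returns True when the Cast's output is now dead and the Cast can be removed.
    """
    src = cast_node.inputs[0]                       # value feeding the Cast
    idx = transpose_node.inputs.index(cast_node.out)
    transpose_node.inputs[idx] = src                # bypass the Cast on this edge
    cast_node.out.consumers.remove(transpose_node)
    src.consumers.append(transpose_node)
    return len(cast_node.out.consumers) == 0
```

Unlike a blanket clear of cast_node's outputs, this touches exactly one edge, so any sibling consumer of the Cast keeps its input.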

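The single-consumer precondition from the second comment can be sketched as follows (illustrative dataclasses, not modelopt's node types):

```python
from dataclasses import dataclass, field

@dataclass
class Var:
    consumers: list = field(default_factory=list)

@dataclass
class Node:
    op: str
    inputs: list
    out: Var = field(default_factory=Var)

def fanout_safe(q_node, moved_node):
    """True only when every value on the rewrite path feeds exactly one node.

    Hoisting moved_node across the Q/DQ pair mutates q_node.inputs[0]; if any
    value on that path has a second consumer, a sibling branch would change too.
    """
    return all(
        len(v.consumers) == 1
        for v in (q_node.inputs[0], q_node.out, moved_node.out)
    )
```

Calling this before each q_node.inputs[0] mutation makes the rewrite a no-op whenever an unrelated branch shares a value with the attention path.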

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro Plus

Run ID: 8107c4b4-ca4e-4406-b813-c8988638709b

📥 Commits

Reviewing files that changed from the base of the PR and between 9bfcb72 and ef8c769.

📒 Files selected for processing (8)
  • CHANGELOG.rst
  • modelopt/onnx/export/fp8_exporter.py
  • modelopt/onnx/utils.py
  • modelopt/torch/_deploy/utils/torch_onnx.py
  • modelopt/torch/quantization/export_onnx.py
  • modelopt/torch/quantization/nn/__init__.py
  • modelopt/torch/quantization/nn/modules/quant_layernorm.py
  • modelopt/torch/quantization/plugins/huggingface.py
✅ Files skipped from review due to trivial changes (1)
  • modelopt/torch/quantization/nn/modules/quant_layernorm.py
🚧 Files skipped from review as they are similar to previous changes (6)
  • modelopt/torch/quantization/nn/__init__.py
  • modelopt/torch/_deploy/utils/torch_onnx.py
  • modelopt/torch/quantization/plugins/huggingface.py
  • modelopt/onnx/utils.py
  • modelopt/torch/quantization/export_onnx.py
  • CHANGELOG.rst

Comment thread modelopt/onnx/export/fp8_exporter.py
Comment thread modelopt/onnx/export/fp8_exporter.py
@ajrasane ajrasane force-pushed the ajrasane/mha_quantization branch from ef8c769 to 928d417 Compare April 20, 2026 17:31
Contributor

@coderabbitai coderabbitai Bot left a comment


🧹 Nitpick comments (1)
examples/torch_onnx/torch_quant_to_onnx.py (1)

186-215: Consider adding a fallback for self.attn_dim to support older timm versions.

Line 210 uses self.attn_dim, which is a relatively recent addition to timm's Attention class (circa late 2025). Since pyproject.toml does not pin a specific timm version, older releases may lack this attribute. If broad timm compatibility is intended, use: getattr(self, 'attn_dim', self.num_heads * self.head_dim) to gracefully handle versions that compute this value dynamically instead of exposing it as an attribute.

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@examples/torch_onnx/torch_quant_to_onnx.py` around lines 186-215: in _vit_attention_forward, replace direct use of self.attn_dim with a safe fallback so older timm versions that lack the attribute don't break: compute attn_dim = getattr(self, "attn_dim", self.num_heads * self.head_dim) and use that variable when reshaping (and anywhere else self.attn_dim is referenced), so the method works whether the attribute exists or must be derived.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro Plus

Run ID: 08cceb9c-1496-4b86-8b2f-60f5dbdbd74c

📥 Commits

Reviewing files that changed from the base of the PR and between ef8c769 and 928d417.

📒 Files selected for processing (10)
  • CHANGELOG.rst
  • examples/torch_onnx/torch_quant_to_onnx.py
  • modelopt/onnx/export/fp8_exporter.py
  • modelopt/onnx/utils.py
  • modelopt/torch/_deploy/utils/torch_onnx.py
  • modelopt/torch/quantization/export_onnx.py
  • modelopt/torch/quantization/nn/__init__.py
  • modelopt/torch/quantization/nn/modules/quant_layernorm.py
  • modelopt/torch/quantization/nn/modules/quant_softmax.py
  • modelopt/torch/quantization/plugins/huggingface.py
✅ Files skipped from review due to trivial changes (2)
  • CHANGELOG.rst
  • modelopt/torch/quantization/export_onnx.py
🚧 Files skipped from review as they are similar to previous changes (4)
  • modelopt/torch/_deploy/utils/torch_onnx.py
  • modelopt/torch/quantization/nn/modules/quant_layernorm.py
  • modelopt/torch/quantization/plugins/huggingface.py
  • modelopt/torch/quantization/nn/__init__.py

@ajrasane ajrasane force-pushed the ajrasane/mha_quantization branch 3 times, most recently from 48d8486 to ce9165e Compare April 20, 2026 18:44
Comment thread CHANGELOG.rst Outdated
@ajrasane ajrasane force-pushed the ajrasane/mha_quantization branch 2 times, most recently from f6f62a3 to 9af7e18 Compare April 20, 2026 19:22
@ajrasane ajrasane changed the title Add FP8 MHA quantization support for HuggingFace ViT [OMNIML-3349] Add FP8 MHA quantization support for HuggingFace ViT Apr 21, 2026
Collaborator

@cjluo-nv cjluo-nv left a comment


All critical issues from previous reviews have been addressed: unit tests added (4 test files), bare assert replaced with RuntimeError, all graph rewrites guarded with all(MatMul) fanout checks, opset guards added to FP16 scale folding, FP16 overflow warning added, per-instance nested attention detection implemented, and unused parameter documented. The remaining minor items (unreachable dead code in _insert_qdq_after_softmax, copyright year 2024 on new files) are non-blocking.

Enables TensorRT attention-v2 fusion for HuggingFace ViT (and similar
transformer vision models) when exported to ONNX with FP8 Q/DQ.

- fp8_exporter: rewrite attention-scaling Mul and K Transpose to the
  Q-side so DQ feeds MatMul directly, pre-transpose weight constants,
  insert FP8 Q/DQ on Softmax outputs for MHA-v2 fusion. Scale dtype
  now matches the graph's float dtype to keep strongly-typed builds
  consistent.
- onnx/utils: fold Cast(FP16<->FP32) nodes that convert_float_to_float16
  inserts around Q/DQ by rewriting scale initializers to FP16, so TRT
  fuses DQ into the downstream GEMM/MatMul kernel.
- torch/quantization/export_onnx: keep FP8 Q/DQ scale in the native
  input dtype so no Cast is injected between graph and Q/DQ.
- torch/quantization/nn: register nn.LayerNorm in QuantModuleRegistry
  so LayerNorm output quantizers are honored.
- torch/quantization/plugins/huggingface: skip attention wrappers whose
  children are also "*Attention" to avoid double-patching
  eager_attention_forward (e.g. ViTAttention vs ViTSelfAttention).

Example: examples/torch_onnx/vit_mha_quantization.py shows a ViT-FP8
config (extends FP8_DEFAULT_CFG with LayerNorm output quantizer,
disabled input quantizers on LayerNorm-followed layers, and
*_bmm_quantizer entries) plus accuracy + TRT-latency comparison
against an FP16 baseline.

Measured on ViT-base-patch16-224 (RTX 6000 Ada, batch=1):
- Top-1 / top-5 on 5k ImageNet-val: 81.16% / 95.50% (FP16) vs
  80.96% / 95.44% (torch FP8) — -0.20% / -0.06%
- TRT latency: 0.721 ms (FP16) vs 0.646 ms (torch FP8) — 1.12x speedup

Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com>
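The Cast-folding step in the commit above rewrites FP32 Q/DQ scale initializers to FP16 and warns on range loss; the per-value rounding can be sketched with stdlib half-precision packing (the clamp-on-overflow here is an assumption made for the sketch, not necessarily what the library does):

```python
import math
import struct
import warnings

FP16_MAX = 65504.0  # largest finite float16 value

def fold_scale_to_fp16(scale_fp32):
    """Round one FP32 scale to the nearest FP16 value, warning on range loss."""
    if abs(scale_fp32) > FP16_MAX:
        # struct raises OverflowError past the FP16 range, so clamp first
        # (assumed behavior for this sketch).
        warnings.warn("FP16 overflow while folding Q/DQ scale; clamping")
        scale_fp32 = math.copysign(FP16_MAX, scale_fp32)
    # struct's "e" format packs/unpacks IEEE-754 half precision.
    folded = struct.unpack("e", struct.pack("e", scale_fp32))[0]
    if scale_fp32 != 0.0 and folded == 0.0:
        warnings.warn("FP16 underflow while folding Q/DQ scale")
    return folded
```

Once every scale survives this rounding, the surrounding FP16<->FP32 Cast nodes carry no information and can be removed, which is what lets TRT fuse DQ into the downstream GEMM.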
Two independent bugs surfaced by the parametrized matrix in
tests/examples/torch_onnx/test_torch_quant_to_onnx.py:

- MXFP8/NVFP4 lower input quantizers to TRT DynamicQuantize, which only
  supports 2D/3D input. Swin/SwinV2 keep the 4D (B, H, W, C) layout on
  per-block norm1, downsample.norm, and the top-level norm, causing
  trtexec (MXFP8) and the NVFP4 autocast TRT-shape-inference pre-pass
  to reject the graph. Added _disable_high_rank_input_quantizers which
  runs a forward-pass rank probe and disables quantizers on 4D+ inputs;
  gated on mxfp8 / nvfp4 / auto so FP8 and INT8 still quantize those
  layers (their Q/DQ has no rank constraint). Name-based alternatives
  would false-positive on ViT, whose same-named top-level norm is 3D.

- swinv2_tiny-fp8 hit ZeroDivisionError in export_fp8 (448 / amax):
  timm's res-post-norm scheme zero-inits each SwinV2 block's norm1 /
  norm2 weight and bias, so under --no_pretrained those LayerNorm
  outputs are exactly zero, and the FP8 MHA override's output_quantizer
  calibrates to amax == 0. Added _disable_dead_quantizers to drop any
  quantizer whose calibrated amax is NaN or <= 0 before export.

Full matrix (4 models x 5 modes) now passes: 20/20 in ~33 min.

Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com>
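The dead-quantizer screen described in the commit can be sketched as follows (function names are illustrative, not the real modelopt helpers):

```python
import math

def is_dead_amax(amax):
    """A calibrated amax of NaN or <= 0 means the quantizer saw no signal."""
    return amax is None or math.isnan(amax) or amax <= 0

def fp8_scale(amax):
    """FP8 E4M3 scale; returns None when the quantizer should be disabled.

    Without the screen, amax == 0 makes 448 / amax raise ZeroDivisionError,
    which is exactly the swinv2 zero-init LayerNorm failure described above.
    """
    if is_dead_amax(amax):
        return None
    return 448.0 / amax
```

Running this check over all calibrated quantizers before export drops the dead ones instead of crashing mid-export.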
The test module imports from modelopt.torch.quantization.plugins.huggingface,
which imports transformers at module scope. Under the partial-install (torch)
CI job — which installs only torch, without transformers/onnx/diffusers —
collection failed with ModuleNotFoundError, taking the whole unit-torch
partial-install step down.

Add pytest.importorskip("transformers") before the plugin import, matching
the pattern used by the sibling test_fused_experts.py.

Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com>
CI (Blackwell, compute capability 12.0) fails TRT engine build for resnet50
under fp8 / mxfp8 / nvfp4 / auto:

  Error Code 10: Could not find any implementation for node
  /conv1/input_quantizer/TRT_FP8QuantizeLinear ... [ElementWise]

The node is ResNet50's top-level conv1 (7x7 stride-2, in_channels=3). TRT's
Blackwell tactics for FP8 Q -> Conv fusion don't cover the raw-RGB (3-channel)
first-layer pattern. Ada (compute capability 8.9, the local dev GPU) happens
to have a tactic, which is why the matrix passed locally.

Swin/ViT avoid this because their first conv (patch_embed.proj, also 3-channel)
is already excluded by filter_func's patch_embed pattern. ResNet50's conv1
wasn't on any list.

Add _disable_low_channel_conv_input_quantizers to disable the input_quantizer
on any Conv2d with in_channels <= 3 for FP8-family modes. Weight quantization
is preserved. This also aligns with standard quantization practice (leave
first/last layers in higher precision).

INT8 is unchanged - INT8 Q/DQ has broader TRT kernel coverage on Blackwell
and built successfully in CI.

Signed-off-by: ajrasane <131806219+ajrasane@users.noreply.github.com>
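The gating logic from this commit can be sketched as a simple predicate (names are illustrative, not the real modelopt API):

```python
# Modes whose Q/DQ lowering hits the missing Blackwell tactic for raw-RGB convs.
FP8_FAMILY_MODES = {"fp8", "mxfp8", "nvfp4", "auto"}

def should_disable_conv_input_quantizer(mode, module_type, in_channels):
    """Disable input quantization on low-channel convs for FP8-family modes.

    Weight quantizers stay enabled; INT8 is untouched because its Q/DQ has
    broader TRT kernel coverage on Blackwell.
    """
    return (
        mode in FP8_FAMILY_MODES
        and module_type == "Conv2d"
        and in_channels <= 3
    )
```

Applied to ResNet50, this catches conv1 (in_channels=3) while leaving every deeper conv quantized, matching the common practice of keeping the first layer in higher precision.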
@ajrasane ajrasane force-pushed the ajrasane/mha_quantization branch from 6537f37 to a9f87bf Compare April 23, 2026 14:06
@ajrasane ajrasane merged commit e4e3508 into main Apr 23, 2026
47 checks passed
@ajrasane ajrasane deleted the ajrasane/mha_quantization branch April 23, 2026 15:38