
[Quantization] Saturate NVFP4 export FP8 scale cast to avoid NaN#1397

Merged
cjluo-nv merged 1 commit into main from chenjiel/fix-nvfp4-fp8-scale-cast-nan
May 6, 2026

Conversation


@cjluo-nv cjluo-nv commented May 6, 2026

Summary

  • Saturates per_block_scale * 448 / per_block_scale_max to ≤ 448 before the to(torch.float8_e4m3fn) cast in NVFP4QTensor.get_weights_scaling_factor_from_quantizer.
  • Adds a regression test that reproduces the NaN byte without the clamp.

Why

When _amax contains a zero entry (e.g. an all-zero weight block left untouched by max calibration), the existing per_block_scale[per_block_scale == 0] = 1.0 safety net drives the pre-cast value to 1.0 * 448 / (global_amax / 6). fp8_e4m3fn has no Inf — anything ≥ 480 rounds to NaN — so a 0x7F byte slips into the exported weight_scale.

This was observed in a saved Kimi-K2.6-NVFP4-MSE checkpoint at language_model.model.layers.1.mlp.experts.21.down_proj.weight_scale[4001, 18]. The MSE FP8 sweep itself never produces zero per-block amax (it always emits at least c[0] * global_amax), but any export path where _amax ends up zero — including pure max calibration — hits the bug. With the clamp the byte saturates to 0x7E (= 448, fp8 max finite) and dequantization is unaffected: the FP4 nibbles for an all-zero block are all 0, so 0 × 448 × weight_scale_2 = 0 regardless of the stored fp8 scale. For non-degenerate blocks the clamp is a no-op since per_block_amax ≤ global_amax already bounds the pre-cast value at 448.
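For illustration, a minimal standalone sketch of the failure mode (not the repository's code; the tensor values are made up, and the NaN-on-cast behavior is the one described above):

    import torch

    # Hypothetical values chosen to trigger the issue: an all-zero block whose
    # per-block amax was replaced by 1.0 by the safety net, plus a small global amax.
    global_amax = torch.tensor(0.5)
    per_block_scale = torch.tensor([1.0])       # zero-amax block after the ==0 -> 1.0 fixup
    per_block_scale_max = global_amax / 6.0     # the global_amax / 6 denominator from above

    pre_cast = per_block_scale * 448.0 / per_block_scale_max     # ~5376, far above 448
    bad = pre_cast.to(torch.float8_e4m3fn)                       # NaN byte (0x7F), per the bug report
    good = pre_cast.clamp(max=448.0).to(torch.float8_e4m3fn)     # saturates to 448 (0x7E)
    print(bad, good)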

Test plan

  • New regression test test_export_fp8_scale_no_nan_for_zero_amax_block fails on main's export code (reproduces the 0x7F NaN byte) and passes with the clamp.
  • Existing tests in tests/gpu/torch/quantization/test_nvfp4_static_quantizer_cuda.py still pass (10/10).

🤖 Generated with Claude Code

Summary by CodeRabbit

  • Bug Fixes
    • Improved numerical stability in FP8 quantization scaling by preventing overflow and NaN conditions
    • Enhanced handling of edge cases in quantization processing for zero-weight blocks

When ``NVFP4QTensor.get_weights_scaling_factor_from_quantizer`` runs the static
path on a weight that contains an all-zero block (per-block amax == 0), the
existing ``per_block_scale[per_block_scale == 0] = 1.0`` safety net drives the
pre-cast value to ``1.0 * 448 / (global_amax / 6)``. ``fp8_e4m3fn`` has no Inf,
so any value >= 480 rounds to NaN — silently writing a 0x7F byte into the
exported ``weight_scale``.

Saturating to 448 before the cast keeps the stored byte finite. The all-zero
block dequantizes to zero regardless of the stored fp8 scale (the FP4 nibbles
are all 0), so dequantization is unaffected. For non-degenerate blocks the
clamp is a no-op since ``per_block_amax <= global_amax`` already bounds the
pre-cast value at 448.
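A small sketch of why the stored fp8 scale is irrelevant for such a block (illustrative only, with made-up values; this is not the repository's dequantization code):

    import torch

    # All FP4 nibbles of an all-zero block decode to 0, so dequantization yields 0
    # for any finite stored scale; a NaN scale, by contrast, would propagate (0 * NaN = NaN).
    decoded_fp4 = torch.zeros(16)                 # decoded FP4 values of the all-zero block
    weight_scale = torch.tensor(448.0)            # saturated per-block fp8 scale (0x7E)
    weight_scale_2 = torch.tensor(0.0123)         # arbitrary global scale

    assert torch.all(decoded_fp4 * weight_scale * weight_scale_2 == 0)
    assert torch.isnan(decoded_fp4 * torch.tensor(float("nan"))).all()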

Includes a regression test that reproduces the NaN without the clamp.

Signed-off-by: Chenjie Luo <chenjiel@nvidia.com>
@cjluo-nv cjluo-nv requested a review from a team as a code owner May 6, 2026 06:58
@cjluo-nv cjluo-nv requested a review from shengliangxu May 6, 2026 06:58

coderabbitai Bot commented May 6, 2026

📝 Walkthrough


This PR adds overflow safety to FP8 scale factor computation in NVFP4 quantization by clamping per-block scaling factors to 448.0 before casting to FP8, and introduces a regression test verifying no NaN values are emitted when exporting scales for all-zero weight blocks.

Changes

FP8 Scaling Safety and Regression Test

  • Core Implementation (modelopt/torch/quantization/qtensor/nvfp4_tensor.py): per-block scaling factors are explicitly clamped to a maximum of 448.0 before casting to torch.float8_e4m3fn, preventing overflow; clarifying comments are added.
  • Regression Test (tests/gpu/torch/quantization/test_nvfp4_static_quantizer_cuda.py): new test verifies FP8 scale export behavior for a multi-block weight whose second block is all zeros, asserting no NaN values and correct saturation handling; imports NVFP4QTensor for test fixture construction.

Estimated code review effort

🎯 2 (Simple) | ⏱️ ~12 minutes

🚥 Pre-merge checks: 6 passed

  • Description Check ✅ Passed: check skipped because CodeRabbit's high-level summary is enabled.
  • Title Check ✅ Passed: the title clearly and specifically describes the main change, adding saturation to NVFP4's FP8 scale cast to prevent NaN values, which directly aligns with the core fix in the changeset.
  • Docstring Coverage ✅ Passed: docstring coverage is 100.00%, which meets the required threshold of 80.00%.
  • Linked Issues Check ✅ Passed: check skipped because no linked issues were found for this pull request.
  • Out of Scope Changes Check ✅ Passed: check skipped because no linked issues were found for this pull request.
  • Security Anti-Patterns ✅ Passed: no security anti-patterns found. The PR adds clamping logic and a regression test only, with no unsafe deserialization, remote code execution, eval/exec, nosec, or new dependencies.



@meenchen meenchen left a comment

Bot review — DM the bot to share feedback.

Clean, targeted bug fix for a real NaN issue in NVFP4 FP8 scale export when zero-amax blocks are present. The .clamp_(max=448.0) before the to(torch.float8_e4m3fn) cast is correct — fp8_e4m3fn has no Inf representation and values ≥ 480 become NaN. The clamp is a no-op for non-degenerate blocks (where per_block_amax ≤ global_amax bounds the value at 448), and for zero blocks the FP4 nibbles are all zero so the scale value is irrelevant to dequantization correctness. The regression test is thorough — it constructs the exact triggering condition, verifies no NaN bytes, and checks the saturated value is 0x7E (448). +55/-3 lines, well-scoped.
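Quick numerical check of that no-op claim (a sketch with random made-up data, assuming the per-block and global pre-cast scales are per_block_amax / 6 and global_amax / 6 as described in the PR body):

    import torch

    # For non-degenerate blocks the pre-cast value is 448 * per_block_amax / global_amax,
    # which stays at or below 448 (up to rounding) whenever per_block_amax <= global_amax.
    per_block_amax = torch.rand(64) * 3.0
    global_amax = per_block_amax.max()

    pre_cast = (per_block_amax / 6.0) * 448.0 / (global_amax / 6.0)
    assert pre_cast.max() <= 448.0 * (1 + 1e-6)   # bound holds, so clamp(max=448) changes nothing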

Testing: Test plan has 1 unchecked item(s) out of 3. Finish or remove them before approving.

@cjluo-nv cjluo-nv requested review from a team, Fridah-nv and realAsma May 6, 2026 07:01

github-actions Bot commented May 6, 2026

PR Preview Action v1.8.1
Preview removed because the pull request was closed.
2026-05-06 15:33 UTC

@cjluo-nv cjluo-nv added the cherry-pick-0.44.0 After code freeze, cherry-pick to release branch for next rc (bulk update). Only for bug fixes / doc label May 6, 2026

@meenchen meenchen left a comment

Bot review — DM the bot to share feedback.

Re-review: The previous bot comment flagged unchecked test plan items — those are now all checked (2/2). The fix is correct, minimal (+55/-3), well-commented, and the regression test thoroughly covers the exact triggering condition. No outstanding issues.


@coderabbitai coderabbitai Bot left a comment

Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@modelopt/torch/quantization/qtensor/nvfp4_tensor.py`:
- Around line 125-135: get_weights_scaling_factor() can produce NaN when both
per_block_amax and weights_scaling_factor_2 are zero (all-zero tensor), which
later breaks quantize(); update get_weights_scaling_factor (and/or add an early
all-zero fast path in quantize) to detect the zero-denominator case and
short-circuit: if per_block_amax == 0 or weights_scaling_factor_2 == 0, set
per_block_scale to a safe finite value (e.g., 0.0) or return an all-zero
quantized result immediately; ensure the same clamp/finite handling applied in
the non-static branch (the per_block_scale fixup used elsewhere) is applied here
so per_block_scale is never NaN before division in quantize().

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 688f0594-3925-40c3-8d6a-56dc7abe8833

📥 Commits

Reviewing files that changed from the base of the PR and between f34f488 and 5ecab66.

📒 Files selected for processing (2)
  • modelopt/torch/quantization/qtensor/nvfp4_tensor.py
  • tests/gpu/torch/quantization/test_nvfp4_static_quantizer_cuda.py

Comment on lines +125 to 135
+    # Quantize scales to FP8. Saturate to the fp8_e4m3fn max (448) before the
+    # cast: when the [==0]=1.0 safety net above fires (per_block_amax was zero
+    # for an all-zero weight block) and global_amax is small, the pre-cast value
+    # explodes to ``1.0 * 448 / (global_amax/6)``. fp8_e4m3fn has no Inf, so any
+    # value >= 480 casts to NaN — clamp first to keep the stored byte finite.
     if not keep_high_precision:
-        per_block_scale = (per_block_scale * 448.0 / per_block_scale_max).to(
-            torch.float8_e4m3fn
-        )
+        per_block_scale = (
+            (per_block_scale * 448.0 / per_block_scale_max)
+            .clamp_(max=448.0)
+            .to(torch.float8_e4m3fn)
+        )

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

This only closes the static branch; the dynamic all-zero path can still produce NaNs.

get_weights_scaling_factor() still computes per_block_amax / (6 * weights_scaling_factor_2), and for an all-zero tensor both terms are zero. That leaves NaN in per_block_scale because the later per_block_scale == 0 fixup does not catch it, and quantize() then divides by weights_scaling_factor * weights_scaling_factor_2 with a zero denominator. Please add the same zero-amax handling there, or an early all-zero fast path, so this fix covers both code paths.
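To make the concern concrete, a small sketch of the reported dynamic-path behavior (illustrative only; the variable names mirror the review comment, and torch.nan_to_num is just one possible guard, not the merged fix):

    import torch

    per_block_amax = torch.tensor([0.0])              # all-zero tensor
    weights_scaling_factor_2 = torch.tensor(0.0)      # global amax is also zero

    per_block_scale = per_block_amax / (6.0 * weights_scaling_factor_2)   # 0/0 -> NaN
    per_block_scale[per_block_scale == 0] = 1.0       # NaN != 0, so the existing fixup misses it
    print(per_block_scale)                            # tensor([nan])

    # One possible guard: replace non-finite scales with a safe value before the FP8 cast.
    per_block_scale = torch.nan_to_num(per_block_scale, nan=1.0)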



codecov Bot commented May 6, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 76.82%. Comparing base (f34f488) to head (5ecab66).

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1397      +/-   ##
==========================================
+ Coverage   76.73%   76.82%   +0.09%     
==========================================
  Files         476      476              
  Lines       51306    51306              
==========================================
+ Hits        39369    39418      +49     
+ Misses      11937    11888      -49     
Flag         Coverage Δ
examples     41.80% <0.00%> (+2.62%) ⬆️
gpu          59.84% <100.00%> (-0.59%) ⬇️
regression   15.20% <0.00%> (+0.07%) ⬆️
unit         52.53% <0.00%> (ø)

Flags with carried forward coverage won't be shown.


@cjluo-nv cjluo-nv merged commit 097293b into main May 6, 2026
48 checks passed
@cjluo-nv cjluo-nv deleted the chenjiel/fix-nvfp4-fp8-scale-cast-nan branch May 6, 2026 15:33