Refine static NVFP4 MSE calibration #1536

Open
realAsma wants to merge 12 commits into main from asma/mse_cleanups
@realAsma
Contributor

@realAsma realAsma commented May 22, 2026

What does this PR do?

Type of change: Bug fix

Refines static NVFP4 MSE calibration so FP8-scale sweep calibration, max-calibration finalization, and export-facing static quantizer state stay consistent.

Main changes:

  • Consolidate the MSE calibration path around one user-facing switch: fp8_scale_sweep. Registered custom backends can provide an FP8 sweep calibrator, ModelOpt static NVFP4 weights use NVFP4MSECalibrator, and other weight quantizers keep the regular multiplier MSE search.
  • Keep static NVFP4 per-block amax in fp32. NVFP4MSECalibrator now caches the final per-block amax as torch.float32 for both the Triton fast path and reference sweep, and TensorQuantizer.load_calib_amax() preserves the loaded amax dtype after validating shape.
  • Make NVFP4MSECalibrator one-shot between resets. collect() computes and caches the final per-block amax immediately, clears the large reference loss accumulator, and compute_amax() returns the cached result.
  • Centralize max-stat collection/loading in model_calib.py, and move static NVFP4 finalization into max_calibrate(). Finalization bootstraps missing weight amax values for skipped/dead experts, promotes eligible quantizers to NVFP4StaticQuantizer, and synchronizes grouped Q/K/V and gate/up global amax values.
  • Keep promote_nvfp4_static_quantizers() as the public promotion helper, with the model-calibration path delegating to it before grouped global-amax sync.
  • Update tests for fp32 amax preservation, one-shot collect/reset behavior, static NVFP4 promotion and global-amax sync, dead-expert bootstrap coverage, registered FP8 sweep calibrator dispatch, and removed duplicate GPU coverage.
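The one-shot behavior described for NVFP4MSECalibrator can be pictured with a toy stand-in. This is only a sketch, not the ModelOpt class: `OneShotCalibrator` is a hypothetical name, it uses a plain per-block max instead of the FP8-scale MSE sweep, and raising on a second `collect()` without `reset()` is an assumption about how the single-cycle rule is enforced.

```python
class OneShotCalibrator:
    """Toy stand-in (hypothetical) for a calibrator that is one-shot between resets."""

    def __init__(self):
        # Cached final per-block amax; None until collect() has run.
        self._best_amax = None

    def collect(self, blocks):
        # Assumption: collecting twice without reset() is treated as an error.
        if self._best_amax is not None:
            raise RuntimeError("collect() already ran; call reset() first")
        # Finalize immediately. Placeholder logic: plain max |x| per block,
        # where the real calibrator would pick the MSE-optimal candidate.
        self._best_amax = [max(abs(v) for v in block) for block in blocks]

    def compute_amax(self):
        # Returns the cached result, or None if collect() has not run yet.
        return self._best_amax

    def reset(self):
        # Clears the cache so a new collect() cycle is allowed.
        self._best_amax = None


cal = OneShotCalibrator()
before = cal.compute_amax()
cal.collect([[0.5, -2.0], [1.5, 0.25]])
after = cal.compute_amax()
cal.reset()
after_reset = cal.compute_amax()
```

The point of the shape is that `compute_amax()` does no work of its own: everything expensive happens once in `collect()`, and any large intermediate state can be dropped right there.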

Usage

For static NVFP4 weight-MSE calibration, use the NVFP4 W4A4 FP8-sweep preset or set the MSE algorithm field directly:

algorithm:
  method: mse
  fp8_scale_sweep: true

fp8_scale_sweep applies to ModelOpt static NVFP4 weight quantizers and registered custom backends with FP8 sweep support. Other weight quantizers use the regular multiplier search.
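As a rough illustration of the selection order described above, here is a pure-Python sketch. The function name, its parameters, and the string return values are hypothetical; the actual logic lives in `_make_weight_mse_calibrator()` and may apply additional eligibility checks that are not modeled here.

```python
def pick_weight_calibrator(fp8_scale_sweep, backend=None, is_static_nvfp4=False,
                           sweep_registry=None):
    """Return the name of the calibrator a weight quantizer would get (sketch only)."""
    sweep_registry = sweep_registry or {}
    if fp8_scale_sweep:
        # 1. A registered custom backend with FP8 sweep support takes priority.
        if backend in sweep_registry:
            return sweep_registry[backend]
        # 2. ModelOpt static NVFP4 weights use the dedicated calibrator.
        if is_static_nvfp4:
            return "NVFP4MSECalibrator"
    # 3. Everything else keeps the regular multiplier MSE search.
    return "MseCalibrator"


registry = {"my_backend": "MyBackendSweepCalibrator"}  # hypothetical registration
custom = pick_weight_calibrator(True, backend="my_backend", sweep_registry=registry)
nvfp4 = pick_weight_calibrator(True, is_static_nvfp4=True)
other = pick_weight_calibrator(True)
sweep_off = pick_weight_calibrator(False, is_static_nvfp4=True)
```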

Testing

Focused validation used while developing this PR included:

python_pwd -m ruff check modelopt/torch/quantization/calib/mse.py modelopt/torch/quantization/model_calib.py modelopt/torch/quantization/nn/modules/tensor_quantizer.py modelopt/torch/quantization/utils/core_utils.py tests/unit/torch/quantization/test_mse_calibrator.py tests/gpu/torch/quantization/test_nvfp4_fp8_sweep_kernel.py tests/gpu/torch/quantization/test_nvfp4_static_quantizer_cuda.py
pytest_pwd tests/unit/torch/quantization/test_mse_calibrator.py::TestRegisterFP8SweepCalibrator tests/unit/torch/quantization/test_mse_calibrator.py::TestStaticNVFP4Promotion -q
pytest_pwd tests/unit/torch/quantization/test_calib.py::test_awq_lite tests/unit/torch/quantization/test_calib.py::test_awq_full tests/unit/torch/quantization/test_calib.py::test_awq_clip -q
CUDA_VISIBLE_DEVICES=2 pytest_pwd tests/gpu/torch/quantization/test_real_quantize_cuda.py::test_real_quantize_linear -q
pytest_pwd tests/gpu/torch/quantization/test_nvfp4_static_quantizer_cuda.py::TestNVFP4MSECalibrator::test_basic_initialization tests/gpu/torch/quantization/test_nvfp4_static_quantizer_cuda.py::TestNVFP4MSECalibrator::test_compute_amax_before_collect_returns_none tests/gpu/torch/quantization/test_nvfp4_static_quantizer_cuda.py::TestNVFP4MSECalibrator::test_collect_and_compute_amax tests/gpu/torch/quantization/test_nvfp4_static_quantizer_cuda.py::TestNVFP4MSECalibrator::test_reference_path_reset_allows_recollect -q
pytest_pwd tests/gpu/torch/quantization/test_nvfp4_static_quantizer_cuda.py::TestNVFP4MSECalibrator::test_collect_and_compute_amax tests/gpu/torch/quantization/test_nvfp4_fp8_sweep_kernel.py::test_dispatch_fast_path_default tests/gpu/torch/quantization/test_nvfp4_fp8_sweep_kernel.py::test_dispatch_custom_error_func_falls_back -q
git diff --check

GitHub CI is also running the standard unit, regression, GPU, example, docs, DCO, and code-quality workflows for the PR.

Before your PR is "Ready for review"

Make sure you read and follow Contributor guidelines and your commits are signed (git commit -s -S).

Make sure you read and follow the Security Best Practices (e.g. avoiding hardcoded trust_remote_code=True, torch.load(..., weights_only=False), pickle, etc.).

  • Is this change backward compatible?: Yes
  • If you copied code from any other sources or added a new PIP dependency, did you follow guidance in CONTRIBUTING.md: N/A
  • Did you write any new necessary tests?: Yes
  • Did you update Changelog?: N/A
  • Did you get Claude approval on this PR?: Yes

Additional Information

N/A

Signed-off-by: realAsma <akuriparambi@nvidia.com>

minor

Signed-off-by: realAsma <akuriparambi@nvidia.com>

Cache NVFP4 MSE amax in fp32

Signed-off-by: realAsma <akuriparambi@nvidia.com>
@realAsma realAsma requested review from a team as code owners May 22, 2026 17:25
@coderabbitai
Contributor

coderabbitai Bot commented May 22, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

Walkthrough

This PR refactors NVFP4 MSE calibration to cache final per-block amax immediately as float32 in a one-shot cycle, centralizes max-stat collection and per-weight MSE calibrator dispatch, tightens NVFP4-static promotion and global-amax sync, updates TensorQuantizer amax buffer handling, and expands tests for dtype and one-shot semantics.

Changes

NVFP4 MSE Calibration and Static Promotion Refactor

| Layer | File(s) | Summary |
|---|---|---|
| NVFP4MSECalibrator one-shot fp32 caching and validation | modelopt/torch/quantization/calib/mse.py | NVFP4MSECalibrator documents one-shot Triton sweep; _compute_candidate_amax normalizes candidates/global_amax to float32; collect() enforces single-cycle semantics via _best_amax_fast and routes to Triton fast-path or reference helper; compute_amax() returns cached fp32 result or None. |
| MSE config, calibrator dispatch, and per-weight factory | modelopt/torch/quantization/config.py, modelopt/torch/quantization/model_calib.py | MseCalibConfig docs corrected; new gating helper and _make_weight_mse_calibrator() centralize per-weight eligibility and calibrator selection (registered FP8 sweep, NVFP4MSECalibrator, or MseCalibrator); mse_calibrate() refactored to run max_calibrate then per-quantizer MSE refinement. |
| Max calibration refactor with centralized stats collection | modelopt/torch/quantization/model_calib.py | Adds _run_and_load_max_stats() to unify enable-stats → run → finish-load; tightens _is_calibrated_nvfp4_static; bootstraps eligible uncalibrated static weight quantizers via weight-only stat re-runs; max_calibrate reworked to use centralized lifecycle. |
| Max calibrate ordering, bootstrap, and gptq removal | modelopt/torch/quantization/model_calib.py | Reorders max_calibrate to run core lifecycle, performs local MoE expert amax sync, replaces unconditional promotion with wrapper that runs promote+group preprocessing, and removes initial promote_nvfp4_static_quantizers call from gptq and post-max hook. |
| TensorQuantizer load_calib_amax buffer handling | modelopt/torch/quantization/nn/modules/tensor_quantizer.py | load_calib_amax always clones/detaches and replaces internal _amax buffer aligned to target device; validates shape when existing amax exists; removes in-place .data.copy and register_buffer path. |
| NVFP4 static promotion and core utils update | modelopt/torch/quantization/utils/core_utils.py | promote_nvfp4_static_quantizers now promotes enabled TensorQuantizer instances marked is_nvfp4_static with non-None amax, computes global_amax from module.amax (clone/detach + reduce_amax), and constructs NVFP4StaticQuantizer.from_tensor_quantizer while counting only new conversions. |
| NVFP4StaticQuantizer and NVFP4MSECalibrator CUDA tests | tests/gpu/torch/quantization/test_nvfp4_static_quantizer_cuda.py | Adds CUDA test verifying load_calib_amax preserves fp32 dtype when existing amax buffer lower-precision; updates NVFP4MSECalibrator tests to assert compute_amax() returns None before collect finalization, expects reference-path to cache _best_amax_fast and leave _losses_sum None, and replaces multi-collection test with reset-based recollect scenario. |
| NVFP4MSECalibrator fp32 amax dtype and fast-path caching tests | tests/gpu/torch/quantization/test_nvfp4_fp8_sweep_kernel.py | Adds Triton-gated test to verify per-block amax stored as float32 and output dtype preserved; strengthens reset/recollect and dispatch tests to assert _best_amax_fast set and _losses_sum None; end-to-end mse test extended to assert NVFP4 static quantizer amax dtypes are torch.float32 across runs. |
| Test infra, fused-experts dead-expert validation, and mse dispatch tests | tests/gpu/torch/quantization/test_gptq.py, tests/unit/torch/quantization/plugins/test_fused_experts.py, tests/unit/torch/quantization/test_mse_calibrator.py | test_gptq.py imports updated to use calib_utils and no longer calls promote utility in test setup; fused-experts dead-expert test refactored to validate max_calibrate populates dead static NVFP4 quantizers and weight config updated to static NVFP4 style; test_mse_calibrator imports/dispatch tests extended and new TestStaticNVFP4Promotion class added to validate promotion and grouped global_amax sync. |
| Example and docstring updates | examples/llm_ptq/cast_mxfp4_to_nvfp4.py | Minor docstring wording change to reference static NVFP4 finalization as the pickup mechanism for forced block_sizes['type']='static' entries. |

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Suggested reviewers

  • meenchen
  • cjluo-nv
  • Edwardf0t1
🚥 Pre-merge checks | ✅ 6
✅ Passed checks (6 passed)
| Check name | Status | Explanation |
|---|---|---|
| Title check | ✅ Passed | The title "Refine static NVFP4 MSE calibration" directly and specifically describes the primary purpose of the pull request: refining the MSE calibration behavior for static NVFP4 quantizers. |
| Docstring Coverage | ✅ Passed | Docstring coverage is 84.91% which is sufficient. The required threshold is 80.00%. |
| Linked Issues check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Out of Scope Changes check | ✅ Passed | Check skipped because no linked issues were found for this pull request. |
| Security Anti-Patterns | ✅ Passed | All 11 modified files passed security review. No unsafe torch.load, numpy.load, trust_remote_code, eval/exec, nosec comments, or restricted-license dependencies found. |
| Description Check | ✅ Passed | Check skipped - CodeRabbit’s high-level summary is enabled. |

@github-actions
Contributor

github-actions Bot commented May 22, 2026

PR Preview Action v1.8.1

🚀 View preview at
https://NVIDIA.github.io/Model-Optimizer/pr-preview/pr-1536/

Built to branch gh-pages at 2026-05-27 22:29 UTC.
Preview will be ready when the GitHub Pages deployment is complete.

Comment thread modelopt/torch/quantization/nn/modules/tensor_quantizer.py Outdated
Comment thread modelopt/torch/quantization/nn/modules/tensor_quantizer.py
Contributor

@coderabbitai coderabbitai Bot left a comment

Warning

CodeRabbit couldn't request changes on this pull request because it doesn't have sufficient GitHub permissions.

Please grant CodeRabbit Pull requests: Read and write permission and re-run the review.

👉 Steps to fix this

Actionable comments posted: 3

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@modelopt/torch/quantization/model_calib.py`:
- Around line 214-221: The stats lifecycle in _run_and_load_max_stats is not
guarded: call enable_stats_collection(model) then run the forward path (either
weight_only_quantize(model) or forward_loop(model)) inside a try block and call
finish_stats_collection(model) in a finally block so finish_stats_collection
always executes even if the forward path raises; re-raise any caught exception
after the finally to preserve behavior. Reference functions:
_run_and_load_max_stats, enable_stats_collection, weight_only_quantize,
forward_loop, finish_stats_collection.

In `@tests/gpu/torch/quantization/test_nvfp4_fp8_sweep_kernel.py`:
- Line 295: The local import of TensorQuantizer inside the
test_mse_calibrate_end_to_end function should be moved to module scope: remove
the in-function import and add "from modelopt.torch.quantization.nn import
TensorQuantizer" to the top of the test file with the other imports so import
failures surface at collection time; update any references in
test_mse_calibrate_end_to_end to use the now-module-level TensorQuantizer and
ensure there is no justification comment left for an inside-function import.

In `@tests/unit/torch/quantization/test_mse_calibrator.py`:
- Around line 686-700: Move the in-test imports of
_promote_nvfp4_static_quantizers_with_global_amax_sync out of the individual
test methods and place them in the module-level import block (i.e., import
_promote_nvfp4_static_quantizers_with_global_amax_sync from
modelopt.torch.quantization.model_calib at the top of the test file) so tests
follow the guideline that imports belong at file scope; only keep them inside a
test if there is a documented circular/optional dependency reason.
ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 39296e29-e4e1-4049-9048-513405b3ee9d

📥 Commits

Reviewing files that changed from the base of the PR and between 3ff15cc and a263e63.

📒 Files selected for processing (12)
  • examples/llm_ptq/cast_mxfp4_to_nvfp4.py
  • modelopt/torch/quantization/calib/mse.py
  • modelopt/torch/quantization/config.py
  • modelopt/torch/quantization/model_calib.py
  • modelopt/torch/quantization/nn/modules/tensor_quantizer.py
  • modelopt/torch/quantization/utils/core_utils.py
  • modelopt_recipes/configs/ptq/presets/model/nvfp4_w4a4_weight_mse_fp8_sweep.yaml
  • tests/gpu/torch/quantization/test_gptq.py
  • tests/gpu/torch/quantization/test_nvfp4_fp8_sweep_kernel.py
  • tests/gpu/torch/quantization/test_nvfp4_static_quantizer_cuda.py
  • tests/unit/torch/quantization/plugins/test_fused_experts.py
  • tests/unit/torch/quantization/test_mse_calibrator.py
💤 Files with no reviewable changes (2)
  • tests/gpu/torch/quantization/test_gptq.py
  • modelopt/torch/quantization/utils/core_utils.py

Comment thread modelopt/torch/quantization/model_calib.py
Comment thread tests/gpu/torch/quantization/test_nvfp4_fp8_sweep_kernel.py Outdated
Comment thread tests/unit/torch/quantization/test_mse_calibrator.py Outdated
Comment thread modelopt/torch/quantization/model_calib.py Outdated
@codecov

codecov Bot commented May 22, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 69.13%. Comparing base (c9098b6) to head (4bc83af).
⚠️ Report is 15 commits behind head on main.

❗ There is a different number of reports uploaded between BASE (c9098b6) and HEAD (4bc83af). Click for more details.

HEAD has 4 uploads less than BASE
| Flag | BASE (c9098b6) | HEAD (4bc83af) |
|---|---|---|
| unit | 2 | 1 |
| gpu | 3 | 2 |
| examples | 11 | 9 |
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1536      +/-   ##
==========================================
- Coverage   76.75%   69.13%   -7.62%     
==========================================
  Files         476      477       +1     
  Lines       51811    53007    +1196     
==========================================
- Hits        39767    36649    -3118     
- Misses      12044    16358    +4314     
| Flag | Coverage Δ | |
|---|---|---|
| examples | 33.74% <42.85%> (-6.99%) | ⬇️ |
| gpu | 51.03% <95.71%> (-9.07%) | ⬇️ |
| regression | 15.23% <11.42%> (+0.09%) | ⬆️ |
| unit | 52.84% <81.53%> (+0.21%) | ⬆️ |

Flags with carried forward coverage won't be shown. Click here to find out more.


Signed-off-by: realAsma <akuriparambi@nvidia.com>
@realAsma
Contributor Author

🤖 Bot comment.

Regarding CodeRabbit’s stats-lifecycle suggestion at #1536 (comment): I am going to leave this as-is. The helper currently has a simple, linear stats lifecycle, and adding a try/finally here makes the common path harder to read without addressing a demonstrated failure in this PR. We can revisit this if we see stale stats after a failed calibration run.
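For reference, the guarded lifecycle under discussion has roughly the following shape. This is a self-contained sketch with toy stand-ins: `ToyModel`, its `collecting` flag, and the simplified function bodies are hypothetical and only mimic the enable → run → finish order of `_run_and_load_max_stats`; the real helpers operate on quantizer modules.

```python
class ToyModel:
    """Hypothetical stand-in for a model whose quantizers collect stats."""

    def __init__(self):
        self.collecting = False


def enable_stats_collection(model):
    model.collecting = True


def finish_stats_collection(model):
    model.collecting = False


def run_and_load_max_stats(model, forward_loop):
    enable_stats_collection(model)
    try:
        forward_loop(model)
    finally:
        # Runs whether or not forward_loop raised; the exception, if any,
        # continues to propagate after this block.
        finish_stats_collection(model)


def failing_forward(model):
    raise ValueError("forward failed")


model = ToyModel()
try:
    run_and_load_max_stats(model, failing_forward)
except ValueError:
    propagated = True
else:
    propagated = False
```

The trade-off is the one stated above: the `try/finally` guarantees the model does not stay in stats-collection mode after a failed forward pass, at the cost of one more level of nesting on the common path.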

Contributor

@coderabbitai coderabbitai Bot left a comment

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
modelopt/torch/quantization/model_calib.py (1)

523-536: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Restore the original calibrator after the temporary MSE pass.

This loop leaves weight_quantizer._calibrator pointing at cal, then immediately resets cal. For the base MseCalibrator, reset() clears _initial_amax, so a later calibration pass on the same model can hit a broken calibrator state instead of the original one.

Suggested fix
                 if cal is None:
                     continue
-                weight_quantizer._calibrator = cal
-                _run_and_load_max_stats(weight_quantizer, lambda q: q(weight))
-                if hasattr(cal, "reset"):
-                    cal.reset()
+                original_calibrator = weight_quantizer._calibrator
+                weight_quantizer._calibrator = cal
+                try:
+                    _run_and_load_max_stats(weight_quantizer, lambda q: q(weight))
+                finally:
+                    weight_quantizer._calibrator = original_calibrator
+                    if hasattr(cal, "reset"):
+                        cal.reset()
 
                 pbar.update(1)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@modelopt/torch/quantization/model_calib.py` around lines 523 - 536, The loop
temporarily replaces weight_quantizer._calibrator with a new MSE calibrator from
_make_weight_mse_calibrator, runs _run_and_load_max_stats, then calls
cal.reset(), but never restores the original calibrator; save the original (orig
= weight_quantizer._calibrator) before assigning the temporary cal, run
_run_and_load_max_stats with the temp calibrator, call cal.reset() if needed,
and finally restore weight_quantizer._calibrator = orig so the original
calibrator state is preserved for subsequent passes (references:
weight_quantizer._calibrator, _make_weight_mse_calibrator,
_run_and_load_max_stats, cal.reset()).

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: 38c07283-0685-4897-9033-c82711e6c56c

📥 Commits

Reviewing files that changed from the base of the PR and between a263e63 and e980773.

📒 Files selected for processing (5)
  • modelopt/torch/quantization/calib/mse.py
  • modelopt/torch/quantization/model_calib.py
  • modelopt/torch/quantization/nn/modules/tensor_quantizer.py
  • tests/gpu/torch/quantization/test_nvfp4_fp8_sweep_kernel.py
  • tests/unit/torch/quantization/test_mse_calibrator.py

@realAsma
Contributor Author

🤖 Bot comment.

Auto-replying to realAsma review comments:

@realAsma
Contributor Author

🤖 Bot comment.

/claude review

Collaborator

@cjluo-nv cjluo-nv left a comment

Bot review — DM the bot to share feedback.

The actual diff goes well beyond the stated "sync on CUDA + remove _wait_for_cuda_amax" scope: it adds a new public-ish config field apply_mse_nvfp_static_only, removes promote_nvfp4_static_quantizers from utils/core_utils.py, restructures mse_calibrate, narrows _bootstrap_uncalibrated_weight_quantizers to NVFP4-static only (renamed accordingly), and changes load_calib_amax so that an existing _amax buffer is always replaced (inheriting calib_amax's dtype) instead of being copied into in place with copy_ (preserving the buffer's dtype).

That last change is model-wide, not NVFP4-specific, and the PR body doesn't call it out. It is worth a human pass to confirm no downstream caller relies on the buffer dtype being preserved (e.g. an fp16 _amax registered earlier silently becomes fp32).

Tests look reasonable: new fp32-amax preservation test, fast-path-cache one-shot test, and a static-only-skip test for the registry. A couple of smaller items are also worth eyeballing:

  • _promote_nvfp4_static_quantizers_with_global_amax_sync is now called inside max_calibrate, so the bootstrap (formerly only triggered from mse_calibrate) now runs for every max_calibrate consumer (gptq, awq postprocess, smoothquant, etc.). Intended? The _check_moe_calibration_complete path runs after promotion — fine since the check looks at _amax on input quantizers too, but worth confirming nothing now warns where it didn't before.
  • local_hessian_calibrate lost its explicit _sync_grouped_weight_global_amax(model) call. It's now subsumed by max_calibrate → promotion path, but local_hessian also has its own per-module promotion later. Double-promotion is a no-op (covered by from_tensor_quantizer), just noting the redundancy.
  • _compute_candidate_amax now multiplies by torch.ones_like(self._initial_amax, dtype=torch.float32) — the ones_like is purely for broadcasting and a plain (self._global_amax.to(torch.float32) * candidates).expand_as(self._initial_amax) (or simpler, just relying on broadcast) would be cheaper; minor.

Comment thread modelopt_recipes/configs/ptq/presets/model/nvfp4_w4a4_weight_mse_fp8_sweep.yaml Outdated

@claude claude Bot left a comment

Claude review passed — no blocking issues found. LGTM

Findings: 0 CRITICAL, 0 IMPORTANT, 2 SUGGESTIONS

The refactor is internally consistent:

  • NVFP4MSECalibrator becoming one-shot is fine — it's only constructed for weights in _make_weight_mse_calibrator, never for activations.
  • Bootstrap scope narrowing (_bootstrap_uncalibrated_static_weight_quantizers) is intentional and reflected in the test rename; all in-tree MSE recipes are NVFP4-only.
  • load_calib_amax change to preserve calibrator dtype on the buffer is covered by the new test_load_calib_amax_preserves_fp32_result_dtype regression test.
  • _promote_nvfp4_static_quantizers_with_global_amax_sync is idempotent (from_tensor_quantizer is a no-op for already-promoted modules) and preprocess_linear_fusion correctly unifies grouped global_amax after promotion.

Two non-blocking suggestions left as inline comments:

  1. Likely-unnecessary torch.cuda.synchronize in _run_reference_collect.
  2. load_calib_amax could keep the new-buffer branch going through register_buffer for symmetry with amax.setter.

Comment thread modelopt/torch/quantization/calib/mse.py Outdated
Comment thread modelopt/torch/quantization/nn/modules/tensor_quantizer.py Outdated
@realAsma
Contributor Author

🤖 Bot comment.

Regarding #1536 (review): leaving this unchanged per branch-owner review. The current mse_calibrate flow intentionally installs the selected MSE calibrator on the weight quantizer before loading the chosen stats, and the extra restore/reset change is not needed here. I reverted the local experimental change.

@realAsma
Contributor Author

🤖 Bot comment.

Regarding the NVFP4 reference MSE sync comment: updated the source comment to clarify that the CPU-GPU sync is there to prevent reference MSE calibration for another weight from running in parallel. The comment now also documents the retained _losses_sum memory: one fp32 reduced loss per candidate per block, about 126 * num_blocks * 4 bytes for NVFP4.

@realAsma
Contributor Author

🤖 Bot comment.

Correction to my previous note about the NVFP4 reference MSE sync comment: I updated the memory wording to express _losses_sum relative to the calibrated weight. With 16-element NVFP4 blocks and bf16 weights, _losses_sum is roughly 128 / 16 * (4 / 2) = 16x the calibrated weight size, and the sync prevents reference MSE calibration for another weight from overlapping that allocation.
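The ratio can be checked with a few lines of arithmetic. This sketch assumes the figures given in this thread (one fp32 loss per candidate per block, 16-element blocks, 2-byte bf16 weights); the helper name is hypothetical.

```python
def losses_sum_ratio(num_candidates=128, block_size=16, loss_bytes=4, weight_bytes=2):
    """Size of _losses_sum relative to the calibrated weight, per block."""
    losses_per_block = num_candidates * loss_bytes   # one reduced loss per candidate
    weight_per_block = block_size * weight_bytes     # the block's share of the weight
    return losses_per_block / weight_per_block


ratio_rounded = losses_sum_ratio()                   # 128 candidates: 512 / 32
ratio_126 = losses_sum_ratio(num_candidates=126)     # 126 candidates: 504 / 32
```

With the rounded count of 128 candidates the ratio is 16x; with the 126 candidates mentioned in the earlier note it is 15.75x, so "roughly 16x" holds either way.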

Signed-off-by: realAsma <akuriparambi@nvidia.com>
@realAsma
Contributor Author

🤖 Bot comment.

Regarding #1536 (review): Thanks for the careful pass. The broader scope is intentional: the fp32 amax preservation, static-NVFP4-only MSE gate, max-calibration promotion, and grouped global-amax sync are tied together so static NVFP4 weights are finalized consistently before downstream calibrators/export. I agree the PR body should make that scope clearer, especially the load_calib_amax dtype behavior and the fact that max_calibrate now performs static NVFP4 promotion. The local-hessian promotion path is redundant but no-op after max_calibrate. I left the ones_like expression as-is because it keeps the shape/dtype broadcasting explicit. I also pushed fc8a95d0 to clarify the NVFP4 reference MSE sync comment and its _losses_sum memory cost.

@realAsma
Contributor Author

Is fp32 MSE scale preserved after save/restore?

@realAsma
Contributor Author

/claude review

Contributor

@coderabbitai coderabbitai Bot left a comment

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tests/unit/torch/quantization/test_mse_calibrator.py (1)

560-563: ⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Move helper imports to module scope in tests.

These imports are inside a test helper without a circular/optional/heavy-import justification, which violates the test conventions and can defer import failures to runtime.

Proposed fix
 import torch
 
+import modelopt.torch.quantization as mtq
 from modelopt.torch.quantization import calib
 from modelopt.torch.quantization.config import QuantizerAttributeConfig
 from modelopt.torch.quantization.model_calib import (
+    mse_calibrate,
     _make_weight_mse_calibrator,
     _promote_nvfp4_static_quantizers_with_global_amax_sync,
 )
 from modelopt.torch.quantization.nn import NVFP4StaticQuantizer, TensorQuantizer
+from modelopt.torch.quantization.nn.modules.tensor_quantizer import register_quant_backend
 from modelopt.torch.quantization.utils import enable_fake_quant
@@
     def _quantize_and_calibrate(self, backend_name, fp8_scale_sweep=True):
         """Quantize a small Linear with the given backend and run mse_calibrate."""
-        import modelopt.torch.quantization as mtq
-        from modelopt.torch.quantization.model_calib import mse_calibrate
-        from modelopt.torch.quantization.nn.modules.tensor_quantizer import register_quant_backend
-
         register_quant_backend(backend_name, lambda x, tq: x)

As per coding guidelines, “tests/**/*.py: Imports inside functions or test methods without explicit justification... Imports belong at the top of the file...”.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/unit/torch/quantization/test_mse_calibrator.py` around lines 560 - 563,
The test currently performs local imports of modelopt.torch.quantization as mtq,
mse_calibrate (from modelopt.torch.quantization.model_calib) and
register_quant_backend (from
modelopt.torch.quantization.nn.modules.tensor_quantizer) inside a helper; move
these import statements to module scope at the top of the test file and remove
the in-function imports so import errors surface at collection time and follow
the test import convention.

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Enterprise

Run ID: d9b4f24a-c6cb-42a3-84e5-d1c8e44159db

📥 Commits

Reviewing files that changed from the base of the PR and between fc8a95d and 419657d.

📒 Files selected for processing (3)
  • modelopt/torch/quantization/config.py
  • modelopt/torch/quantization/model_calib.py
  • tests/unit/torch/quantization/test_mse_calibrator.py


@claude claude Bot left a comment

Claude review summary

Findings: CRITICAL: 0, IMPORTANT: 1, SUGGESTION: 5

Most impactful

  • [IMPORTANT Compatibility] _bootstrap_uncalibrated_static_weight_quantizers was narrowed to NVFP4-static only and the unconditional bootstrap step at the top of mse_calibrate was removed. Non-NVFP4 MoE workloads (e.g. INT8/FP8 weight quantization) with dead experts now leave those quantizers with _amax = None after _make_weight_mse_calibrator (called from mse_calibrate) silently skips them, and the renamed test no longer covers the non-NVFP4 case. This is a real regression for INT8/FP8 MoE MSE flows.

Suggestions

  • load_calib_amax shape check uses assert (stripped under -O) and replaces the buffer instead of in-place copy; both worth a comment / raise.
  • promote_nvfp4_static_quantizers was importable from modelopt.torch.quantization.utils; consider a deprecation shim or a CHANGELOG note since the PR claims backward compatible.
  • Stale bootstrap warning text (no longer mentions "static NVFP4").
  • lambda q: q(weight) captures the loop variable; safe today but cheap to harden as lambda q, w=weight: q(w).
  • PR description advertises an apply_mse_nvfp_static_only field on MseCalibConfig that isn't present in the diff — please reconcile.
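
The first suggestion above rests on the fact that Python drops `assert` statements when run with `-O`. A minimal sketch of the difference, using plain tuples as stand-ins for tensor shapes; both function names are hypothetical and are not ModelOpt APIs:

```python
def check_shape_with_assert(current_shape, new_shape):
    # Removed entirely under `python -O`, so a mismatch would pass silently there.
    assert current_shape == new_shape, "shape changed"


def check_shape_with_raise(current_shape, new_shape):
    # An explicit raise is kept regardless of interpreter optimization flags.
    if current_shape != new_shape:
        raise ValueError("Changing shape when loading calibration amax is not allowed.")


def mismatch_is_rejected(check):
    """Return True if `check` rejects a pair of mismatched shapes."""
    try:
        check((4, 1), (8, 1))
    except (AssertionError, ValueError):
        return True
    return False
```

Under a default interpreter both checks reject the mismatch; only the `raise` variant is guaranteed to do so under `-O`.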

Risk

The refactor itself (one-shot caching, fp32 amax preservation, centralized max-stat collection, NVFP4 promotion inside max_calibrate) is a clean cleanup with good test coverage for the NVFP4 path. The main risk is the silent loss of dead-expert bootstrap for non-NVFP4 MoE configurations going through mse_calibrate; addressing that (or confirming it's intended and out of scope) would clear the blocker.

Comment thread modelopt/torch/quantization/nn/modules/tensor_quantizer.py Outdated
Comment thread modelopt/torch/quantization/utils/core_utils.py
Comment thread modelopt/torch/quantization/model_calib.py Outdated
Collaborator

@cjluo-nv cjluo-nv left a comment

Bot review — DM the bot to share feedback.

A lot of progress on this re-review — the fp32-amax preservation, test-import scope, FP8-sweep config simplification (collapsing apply_mse_nvfp_static_only into fp8_scale_sweep), and the load_calib_amax simplification have all landed. Below are the points still worth a human pass before approval; ones the author has responded to are marked with 💬.

  • 💬 Non-NVFP4 dead-expert MoE MSE regression: _bootstrap_uncalibrated_weight_quantizers was narrowed to _bootstrap_uncalibrated_static_weight_quantizers (NVFP4-only) and the unconditional bootstrap step at the top of mse_calibrate was removed. Author's response (to cjluo) confirms scope expansion is intentional and tied to NVFP4 finalization, but doesn't directly address what happens to INT8/FP8 MoE configs with dead experts under mse: those weight quantizers will now stay at _amax=None, _make_weight_mse_calibrator silently returns None, and the renamed test_max_calibrate_populates_dead_static_nvfp4_expert_quantizers no longer exercises the non-NVFP4 path. Worth confirming this is intended (and not an unrelated MoE recipe regression).

  • 💬 load_calib_amax buffer identity change is model-wide — cjluo flagged that replacing self._buffers["_amax"] instead of self._amax.data.copy_() changes the identity of the buffer for every quantizer, not just NVFP4. Author kept the new behavior to preserve fp32 dtype and added a regression test for the dtype, but the identity change for non-NVFP4 callers (anyone holding a stale _amax reference across load_calib_amax) hasn't been audited — the PR body still doesn't call this out as a behavior change. A second pair of eyes on downstream callers would be reassuring.

  • promote_nvfp4_static_quantizers removed from utils/core_utils.py with no deprecation shim — the PR self-marks "backward compatible", but from modelopt.torch.quantization.utils import promote_nvfp4_static_quantizers (used by at least tests/gpu/torch/quantization/test_gptq.py before this PR) will now raise ImportError for any external caller. Either keep a thin deprecation shim that calls _promote_nvfp4_static_quantizers_with_global_amax_sync, or note the removal in CHANGELOG and uncheck the back-compat box.

  • In-function imports in test helper: tests/unit/torch/quantization/test_mse_calibrator.py::TestRegisterFP8SweepCalibrator._quantize_and_calibrate still does import modelopt.torch.quantization as mtq / mse_calibrate / register_quant_backend inside the method. Earlier in-function imports in the same file and in test_nvfp4_fp8_sweep_kernel.py were moved to module scope; this one was missed. Style-only, not a blocker.

  • Minor: the lambda q: q(weight) in mse_calibrate captures the loop weight; safe today because _run_and_load_max_stats invokes synchronously, but lambda q, w=weight: q(w) would harden it against any future deferral. The getattr(self, "_best_amax_fast", None) in NVFP4MSECalibrator.compute_amax is also redundant given __init__ sets the attribute — a plain if self._best_amax_fast is None would do.
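
The lambda-capture point in the last bullet can be illustrated outside ModelOpt. This is a sketch only: `collect_weight_stats` and `build_callbacks` are hypothetical names, and strings stand in for weight tensors. It shows what would happen if the callbacks were ever invoked after the loop finished rather than synchronously:

```python
from functools import partial


def collect_weight_stats(quantizer, weight):
    # Hypothetical helper: apply the quantizer callable to one weight.
    return quantizer(weight)


def build_callbacks(weights):
    late, bound, partials = [], [], []
    for weight in weights:
        # Late binding: `weight` is looked up when the lambda is called.
        late.append(lambda q: q(weight))
        # Default argument: the current value is bound at definition time.
        bound.append(lambda q, w=weight: q(w))
        # functools.partial also binds the current value eagerly.
        partials.append(partial(collect_weight_stats, weight=weight))
    return late, bound, partials
```

If the callbacks run only after the loop has ended, every entry in `late` sees the final weight, while `bound` and `partials` each keep the weight they were created with.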


Additional comments (outside the PR diff):

  • tests/unit/torch/quantization/test_mse_calibrator.py:562 — > Bot comment.

Style nit: mtq, mse_calibrate, and register_quant_backend are still imported inside this method. Earlier in-function imports in this file and in test_nvfp4_fp8_sweep_kernel.py were moved to module scope in e980773b; this helper looks like it was missed. Not a blocker.

Comment thread modelopt/torch/quantization/model_calib.py
Comment thread modelopt/torch/quantization/nn/modules/tensor_quantizer.py Outdated
Comment thread modelopt/torch/quantization/utils/core_utils.py
Comment thread modelopt/torch/quantization/model_calib.py
Signed-off-by: realAsma <akuriparambi@nvidia.com>
@realAsma
Contributor Author

🤖 Bot comment.

Pushed e3da774a with the review-fix batch.

Addressed the bot-tagged feedback from the latest review:

Validation before push:

  • git diff --check
  • pre-commit run --files modelopt/torch/quantization/model_calib.py modelopt/torch/quantization/nn/modules/tensor_quantizer.py modelopt/torch/quantization/utils/core_utils.py tests/unit/torch/quantization/test_mse_calibrator.py
  • pytest_pwd tests/unit/torch/quantization/test_mse_calibrator.py::TestRegisterFP8SweepCalibrator tests/unit/torch/quantization/test_mse_calibrator.py::TestStaticNVFP4Promotion -q

@realAsma
Contributor Author

/claude review


@claude claude Bot left a comment

Claude review passed — no blocking issues found. LGTM

Signed-off-by: realAsma <akuriparambi@nvidia.com>
@realAsma
Contributor Author

/claude review


@claude claude Bot left a comment

Claude review passed — no blocking issues found. LGTM

Signed-off-by: realAsma <akuriparambi@nvidia.com>
@realAsma
Contributor Author

/claude review


@claude claude Bot left a comment

Claude review passed — no blocking issues found. LGTM

Reviewed the NVFP4 MSE calibration refactor end-to-end (one-shot collect/cache, max-stat helper centralization, static NVFP4 promotion moved into max_calibrate, fp32-preserving load_calib_amax). Algorithm flow is correct: NVFP4MSECalibrator caches the final per-block amax in fp32 after either the Triton fast path or the reference 126-step sweep; multi-collect-after-final is now an explicit RuntimeError. The mse_calibrate dispatch (registry → ModelOpt static NVFP4 → MseCalibrator fallback) lines up with the docstring, and the local_hessian_calibrate path still calls collect once per weight so the new one-shot semantics don't regress it.
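
The one-shot semantics described above (cache on `collect()`, error on a second `collect()`, clear on `reset()`) can be sketched with a small stand-in class. This is not the `NVFP4MSECalibrator` implementation: the class name, the `search` callable, and the error message are assumptions made for illustration.

```python
class OneShotCalibrator:
    """Hypothetical stand-in showing one-shot collect/cache/reset behavior."""

    def __init__(self, search):
        self._search = search  # callable mapping collected data to a final amax
        self._final_amax = None

    def collect(self, data):
        if self._final_amax is not None:
            raise RuntimeError("collect() already produced a result; call reset() first.")
        # Compute and cache immediately so no large accumulator outlives collect().
        self._final_amax = self._search(data)

    def compute_amax(self):
        # Returns the cached result, or None if collect() has not run yet.
        return self._final_amax

    def reset(self):
        self._final_amax = None
```

A caller that collects once per weight, as the review notes for `local_hessian_calibrate`, never hits the `RuntimeError` branch.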

Posted three SUGGESTION-level comments only (none blocking):

  • load_calib_amax docstring still says "updates" but now replaces the buffer; direct _buffers["_amax"] access is inconsistent with the delattr+register_buffer pattern in export/quant_utils.py.
  • _compute_candidate_amax has a redundant ones_like multiplication that can be trimmed.
  • _uses_modelopt_fp8_weight_scales relies on an implicit "max_calibrate already promoted" precondition before reading weight_quantizer.global_amax — making that explicit would harden the helper.

Risk assessment: low. Core changes are well-tested (fp32 dtype preservation, one-shot reset semantics, dead-expert static-NVFP4 bootstrap, grouped global-amax sync) and the behavior change is intentional and documented in the description.

Comment on lines +635 to +640
if hasattr(self, "_amax"):
    if self._amax.shape != calib_amax.shape:
        raise ValueError("Changing shape when loading calibration amax is not allowed.")
device = self._amax.device if hasattr(self, "_amax") else calib_amax.device
# Replace instead of copy_ so the stored amax keeps the calibrator result dtype.
self._buffers["_amax"] = calib_amax.clone().detach().to(device=device)
[SUGGESTION] Two small notes on this dtype-preserving rewrite:

  1. The docstring above (line 611) still says "Updates the amax buffer", but the new behavior replaces the buffer (so any external code holding q._amax directly will keep a stale tensor reference). Worth a one-line clarification.
  2. Direct self._buffers["_amax"] = ... bypasses register_buffer, which is inconsistent with the delattr + register_buffer pattern in modelopt/torch/export/quant_utils.py:_set_amax_from_tensor. Functionally equivalent in PyTorch today, but if you want to keep the codebase using a single replacement idiom, calling delattr(self, "_amax") (when it exists) and then self.register_buffer("_amax", calib_amax.clone().detach().to(device=device)) would match the existing pattern.
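
The stale-reference concern in point 1 comes down to rebinding versus in-place mutation. A framework-free sketch, assuming Python lists as stand-ins for tensors and a hypothetical `Holder` class in place of a quantizer module (it illustrates identity only, not dtype handling):

```python
class Holder:
    """Hypothetical stand-in for a module that keeps named buffers in a dict."""

    def __init__(self, values):
        self._buffers = {"_amax": list(values)}

    def load_in_place(self, new_values):
        # Mutates the existing object, so outside references observe the update.
        self._buffers["_amax"][:] = new_values

    def load_by_replacement(self, new_values):
        # Rebinds the dict entry; outside references keep pointing at the old object.
        self._buffers["_amax"] = list(new_values)


holder = Holder([1.0])
external_ref = holder._buffers["_amax"]

holder.load_in_place([2.0])
seen_after_in_place = list(external_ref)

holder.load_by_replacement([3.0])
seen_after_replacement = list(external_ref)
current_value = list(holder._buffers["_amax"])
```

After the replacement call, `external_ref` still holds the value written by the in-place load, which is the silent staleness the review asks to audit for.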

Comment thread modelopt/torch/quantization/calib/mse.py
Comment on lines +449 to +466
initial_amax = weight_quantizer._amax.clone().detach()
axis = weight_quantizer._calibrator._axis
quant_func = partial(_mse_quant_func, quantizer=weight_quantizer)

if fp8_scale_sweep:
    backend: str | None = getattr(weight_quantizer, "backend", None)
    backend_factory = (
        _FP8_SWEEP_CALIBRATOR_REGISTRY.get(backend) if backend is not None else None
    )
    if backend is not None and backend_factory is not None:
        return backend_factory(initial_amax, axis, quant_func)
    elif _uses_modelopt_fp8_weight_scales(weight_quantizer):
        return NVFP4MSECalibrator(
            amax=initial_amax,
            axis=axis,
            global_amax=weight_quantizer.global_amax,
            quant_func=quant_func,
        )
[SUGGESTION] _uses_modelopt_fp8_weight_scales only checks is_nvfp4_static — it doesn't require the quantizer to already be promoted to NVFP4StaticQuantizer. Reading weight_quantizer.global_amax here will raise AttributeError for a base TensorQuantizer that happens to be is_nvfp4_static but hasn't been promoted yet.

In the current call graph this is safe because mse_calibrate runs max_calibrate first and max_calibrate now promotes via _promote_nvfp4_static_quantizers_with_global_amax_sync. But the precondition is implicit. Either an assert isinstance(weight_quantizer, NVFP4StaticQuantizer) here, or having _uses_modelopt_fp8_weight_scales also check isinstance(weight_quantizer, NVFP4StaticQuantizer), would make the contract explicit and produce a friendlier error if the helper is reused elsewhere.
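
A sketch of the second option, where the predicate itself requires promotion. `BaseQuantizer` and `PromotedQuantizer` are hypothetical stand-ins for `TensorQuantizer` and `NVFP4StaticQuantizer`; the real classes carry far more state, and the real helper may check other conditions too.

```python
class BaseQuantizer:
    """Hypothetical stand-in for the base quantizer; it has no global_amax."""

    is_nvfp4_static = True


class PromotedQuantizer(BaseQuantizer):
    """Hypothetical stand-in for the promoted quantizer, which owns global_amax."""

    def __init__(self, global_amax):
        self.global_amax = global_amax


def uses_fp8_weight_scales(quantizer):
    # Require promotion as well as the format flag, so that reading
    # quantizer.global_amax afterwards cannot raise AttributeError.
    return isinstance(quantizer, PromotedQuantizer) and quantizer.is_nvfp4_static
```

With this shape, an unpromoted quantizer falls through to the regular fallback path instead of failing on the attribute read.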

@realAsma realAsma requested a review from cjluo-nv May 27, 2026 17:58
realAsma added 2 commits May 27, 2026 18:12
Signed-off-by: realAsma <akuriparambi@nvidia.com>
Signed-off-by: realAsma <akuriparambi@nvidia.com>
Collaborator

@cjluo-nv cjluo-nv left a comment

Bot review — DM the bot to share feedback.

Re-review summary: most prior concerns are addressed, but one model-wide behavior change still merits human sign-off.

Addressed since the previous round:

  • 💬 Non-NVFP4 dead-expert MoE bootstrap (cjluo) — e3da774a widens _bootstrap_uncalibrated_weight_quantizers back to all quantizers and runs it inside max_calibrate (default bootstrap_weight_quantizers=True), with a new test_max_calibrate_bootstraps_non_nvfp4_dead_weight_quantizer test. Looks resolved.
  • 💬 promote_nvfp4_static_quantizers removal (cjluo) — e3da774a restores it in utils/core_utils.py as the public helper; from modelopt.torch.quantization.utils import promote_nvfp4_static_quantizers works again via the star re-export.
  • 💬 load_calib_amax assert → raise ValueError and new-buffer comment (claude) — applied.
  • 💬 In-method test imports (CodeRabbit) — moved to module scope in test_mse_calibrator.py and test_nvfp4_fp8_sweep_kernel.py.
  • 💬 fp8_scale_sweep + apply_mse_nvfp_static_only two-flag confusion (Fridah-nv) — 419657d2 collapses to a single fp8_scale_sweep switch, with dispatch handled in _make_weight_mse_calibrator via _uses_modelopt_fp8_weight_scales and the registered backend factory.
  • Lambda capture of loop weight in mse_calibrate (claude) — implicitly fixed by switching to partial(_collect_weight_stats, weight=weight).

Still warrants a human pass before approval:

  • 💬 load_calib_amax now replaces self._buffers["_amax"] instead of copy_'ing in-place (cjluo). Author kept the new behavior to preserve fp32 dtype and added an inline comment + a regression test, but this is a model-wide identity change for every quantizer (not just NVFP4). Any caller holding a stale q._amax reference across load_calib_amax() will now silently see the pre-load tensor. The PR body still doesn't call this out as a behavior change. Quick audit of in-tree callers of load_calib_amax (export paths, calibration restore, amax.setter) would be reassuring before sign-off.
  • _run_and_load_max_stats has no try/finally around enable_stats_collection / finish_stats_collection (CodeRabbit). Author explicitly declined; reasonable given the linear lifecycle, but worth a final sanity check that no new caller (e.g. max_calibrate consumers) can leave quantizers stuck in _if_calib=True after a forward exception.
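
For reference, the try/finally guard raised in the second bullet could look like the sketch below. The author declined this change, so it is not what the PR does. `FakeQuantizer`, its `collecting` flag, and `stats_collection` are hypothetical stand-ins for the real quantizer state and the enable/finish helpers.

```python
from contextlib import contextmanager


class FakeQuantizer:
    """Hypothetical stand-in exposing only a calibration-mode flag."""

    def __init__(self):
        self.collecting = False


@contextmanager
def stats_collection(quantizers):
    for q in quantizers:
        q.collecting = True
    try:
        yield quantizers
    finally:
        # Runs even if the forward pass raises, so no quantizer stays in calibration mode.
        for q in quantizers:
            q.collecting = False
```

A forward pass wrapped in `with stats_collection(quantizers):` leaves every flag cleared whether or not the body raises.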

No design-review concerns: this PR refactors existing calibration code rather than introducing a new subsystem. Test coverage looks solid (fp32-amax preservation, one-shot collect/reset, dead-expert bootstrap for NVFP4 and non-NVFP4, registered FP8-sweep dispatch, grouped global-amax sync).

Signed-off-by: realAsma <akuriparambi@nvidia.com>
@realAsma
Contributor Author

🤖 Bot comment.

Addressing the latest modelopt review note about the model-wide load_calib_amax buffer identity change: pushed 9e36ea34.

TensorQuantizer.load_calib_amax() is now restored to the main-branch register-or-copy_ behavior, so non-NVFP4 quantizers no longer replace _amax and retain the prior buffer identity/dtype semantics.

The fp32 preservation requirement is now isolated to NVFP4StaticQuantizer.load_calib_amax(): it upcasts an existing lower-precision _amax to fp32 before delegating to the base implementation. That means buffer replacement is scoped to NVFP4 static quantizers, and only when dtype migration is required for the NVFP4 MSE fp32-amax path.

Focused validation passed:
pytest_pwd tests/gpu/torch/quantization/test_nvfp4_static_quantizer_cuda.py::TestNVFP4StaticQuantizer::test_load_calib_amax_preserves_fp32_result_dtype

Signed-off-by: realAsma <akuriparambi@nvidia.com>
Comment thread modelopt/torch/quantization/calib/mse.py
Comment thread modelopt/torch/quantization/calib/mse.py
Comment thread modelopt/torch/quantization/model_calib.py
Comment thread modelopt/torch/quantization/model_calib.py
Signed-off-by: realAsma <akuriparambi@nvidia.com>