
Add layerwise calibration for large models #1251

Merged
realAsma merged 5 commits into main from asma/ptq-large-models
Apr 18, 2026

Conversation

@realAsma
Contributor

@realAsma realAsma commented Apr 13, 2026

Summary

Adds performant layerwise calibration for quantizing large models (e.g. DeepSeek-R1 671B) that don't fit entirely on GPU. (Example commands)

  1. Performant calibration for large models — Each decoder layer is moved from CPU/disk to GPU (accelerate) or unsharded (FSDP2) only once and kept on GPU for the entire calibration step. Previously, every calibration batch triggered weight transfer for every layer — O(num_batches) weight movements per layer. Now it is O(1) per layer. This also means you can increase batch size since only one layer's weights occupy GPU at a time — e.g. DeepSeek-R1 on a single node (8×80GB) with batch_size=16 and gpu_max_mem_percentage=0.5.
  2. Checkpoint save/resume — Saves progress after each layer, so jobs that exceed cluster time limits (e.g. 4-hour Slurm windows for 100+ layer MoE models) can resume from the last completed layer.
  3. Rename sequential_calibrate → layerwise_calibrate for clarity.
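The O(1)-per-layer movement pattern from item 1 can be sketched as follows. This is an illustrative reconstruction, not the PR's implementation (the real logic lives in layerwise_calib.py and works through accelerate/FSDP2 hooks rather than bare `.to()` calls):

```python
import torch


def layerwise_calibrate(layers, batches, device="cpu"):
    """Move each layer to the compute device once, replay every batch, then offload."""
    moves = 0
    captured = batches  # activations feeding the first layer
    for layer in layers:
        layer.to(device)  # one weight transfer per layer, not per batch
        moves += 1
        with torch.no_grad():
            captured = [layer(x) for x in captured]  # all batches run on-device
        layer.to("cpu")  # offload before the next layer materializes
    return moves


layers = [torch.nn.Linear(4, 4) for _ in range(3)]
batches = [torch.randn(2, 4) for _ in range(8)]
assert layerwise_calibrate(layers, batches) == 3  # O(1) transfers per layer
```

With the previous scheme the transfer count scaled as `len(layers) * len(batches)`; here it is `len(layers)` regardless of batch count, which is why larger batch sizes become affordable when only one layer's weights occupy the GPU at a time.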

Design details

The existing layerwise state machine (skip/run/capture) already processes one layer at a time, but skip-mode layers still kept their parameters in the ModuleList — so frameworks transferred all weights every forward pass. This PR adds:

  • _SkipLayer: replaces fully-calibrated layers with a parameter-free dummy in the ModuleList, so framework hooks have nothing to transfer
  • persistent_materialization: keeps the active layer on GPU for the entire calibration step, avoiding repeated offload/reload cycles
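A minimal illustration of the _SkipLayer idea (a hypothetical reconstruction, not the PR's actual class): a module that registers no parameters, so framework offloading hooks find nothing to transfer when it replaces a fully calibrated layer in the ModuleList:

```python
import torch
import torch.nn as nn


class SkipLayer(nn.Module):
    """Parameter-free stand-in that replays a cached output for a calibrated layer."""

    def __init__(self, cached_output: torch.Tensor):
        super().__init__()
        # A plain tensor attribute (not a Parameter or buffer), so the module
        # reports zero parameters to accelerate/FSDP2 hooks.
        self._cached_output = cached_output

    def forward(self, *args, **kwargs):
        return self._cached_output


original = nn.Linear(8, 8)
x = torch.randn(2, 8)
skip = SkipLayer(original(x).detach())
assert sum(p.numel() for p in skip.parameters()) == 0  # nothing to offload
assert torch.equal(skip(x), original(x).detach())
```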

Checkpoint save is per-layer; restore is bulk — quantizer state and weights for layers 0..K-1 are restored once at the end of calibration, keeping the hot path fast.
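The per-layer-save / bulk-restore split might look roughly like this sketch. The file layout, manifest format, and helper names are assumptions for illustration; the sketch only stores tensor state dicts, so the restricted `weights_only=True` loader suffices here:

```python
import json
import os
import tempfile

import torch
import torch.nn as nn


def save_layer(ckpt_dir, idx, layer):
    """Persist one layer's weights and advance the resume manifest."""
    layer_dir = os.path.join(ckpt_dir, f"layer_{idx}")
    os.makedirs(layer_dir, exist_ok=True)
    torch.save(layer.state_dict(), os.path.join(layer_dir, "weights.pt"))
    # The manifest records the last completed layer for resume detection.
    with open(os.path.join(ckpt_dir, "manifest.json"), "w") as f:
        json.dump({"last_completed_layer": idx}, f)


def bulk_restore(ckpt_dir, layers):
    """Restore layers 0..K once, off the per-batch hot path."""
    with open(os.path.join(ckpt_dir, "manifest.json")) as f:
        last = json.load(f)["last_completed_layer"]
    for idx in range(last + 1):
        path = os.path.join(ckpt_dir, f"layer_{idx}", "weights.pt")
        layers[idx].load_state_dict(torch.load(path, weights_only=True))
    return last


with tempfile.TemporaryDirectory() as d:
    layers = [nn.Linear(4, 4) for _ in range(3)]
    for i, layer in enumerate(layers[:2]):
        save_layer(d, i, layer)
    fresh = [nn.Linear(4, 4) for _ in range(3)]
    assert bulk_restore(d, fresh) == 1
    assert torch.equal(fresh[0].weight, layers[0].weight)
```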

Example commands

Qwen3-8B (NVFP4+GPTQ, single GPU):

python hf_ptq.py \
    --pyt_ckpt_path Qwen/Qwen3-8B \
    --recipe nvfp4_gptq_sequential.yaml \
    --calib_size 64 \
    --batch_size 16 \
    --dataset cnn_dailymail \
    --export_path outputs/qwen3_8b_nvfp4_gptq_seq \
    --gpu_max_mem_percentage 0.5 \
    --use_seq_device_map \
    --vllm_fakequant_export

DeepSeek-R1 (NVFP4 experts-only + FP8 KV, 8×80GB):

python hf_ptq.py \
    --model unsloth/DeepSeek-R1-0528-BF16 \
    --recipe ../../modelopt_recipes/general/ptq/nvfp4_experts_only-fp8_kv.yaml \
    --dataset cnn_dailymail \
    --batch_size 16 \
    --calib_size 64 \
    --calib_seq 512 \
    --gpu_max_mem_percentage 0.5 \
    --use_seq_device_map \
    --trust_remote_code \
    --export_path output/DeepSeek-R1-BF16-nvfp4-experts-only-fp8-kv \
    --vllm_fakequant_export

Example: NVFP4+GPTQ layerwise calibration on Qwen3-8B (36 layers, single GPU — 20 GB peak)

Initial run (killed after layer 11):

Layerwise calibration: Found 36 transformer layers
Calibrating layer 1/36 | capture: [1]
Computing Hessians for 7 linear layers...
GPTQ time: 51.39s
Calibrating layer 2/36 | run: [1] | capture: [2]
Checkpoint: saved layer 0
GPTQ time: 50.06s
Calibrating layer 3/36 | skip: 1 | run: [2] | capture: [3]
Checkpoint: saved layer 1
...
Calibrating layer 12/36 | skip: 10 | run: [11] | capture: [12]
Checkpoint: saved layer 10
<killed>

Resumed run (picks up from layer 11, finishes all 36):

Layerwise calibration: Found 36 transformer layers
Checkpoint: resuming layerwise calibration from layer 11/36
Calibrating layer 12 (resumed)
GPTQ time: 51.45s
Calibrating layer 13/36 | skip: 11 | run: [12] | capture: [13]
Checkpoint: saved layer 11
...
Calibrating layer 36/36 | skip: 34 | run: [35] | capture: [36]
Checkpoint: saved layer 34
GPTQ time: 50.33s
Checkpoint: saved layer 35 (final)
Checkpoint: restored 11 previously calibrated layers
Layerwise calibration completed
Quantized model exported to: outputs/qwen3_8b_nvfp4_gptq_seq
GPU 0: Peak memory usage = 20.42 GB

TODO

  • Update CHANGELOG

Test plan

  • tests/unit/torch/quantization/test_layerwise_calibrate.py — unit tests for skip/swap/restore
  • tests/unit/torch/quantization/test_sequential_checkpoint.py — checkpoint save/resume correctness
  • tests/gpu/torch/quantization/plugins/test_accelerate_gpu.py — CPU-offloaded layerwise + GPTQ + checkpoint resume
  • tests/gpu/torch/quantization/test_fsdp2.py — FSDP2 layerwise calibration

Verified

  • Qwen3-8B: layerwise calibration + checkpoint save/restore + fakequantized checkpoint export + vLLM serve
  • DeepSeek-R1: checkpoint resume tested
  • DeepSeek-R1: fakequantized checkpoint export verified

@copy-pr-bot

copy-pr-bot bot commented Apr 13, 2026

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.


@github-actions
Contributor

github-actions bot commented Apr 13, 2026

PR Preview Action v1.8.1
Preview removed because the pull request was closed.
2026-04-18 00:33 UTC

@coderabbitai
Contributor

coderabbitai bot commented Apr 13, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

📝 Walkthrough

Rename calibration mode flag use_sequential → use_layerwise, add optional checkpoint_dir, replace sequential calibration with a new layerwise calibrator (with per-layer checkpoints/resume), introduce a new layerwise activation collector, update accelerate/FSDP/device helpers, and add extensive tests and example helpers.

Changes

Cohort / File(s) Summary
Config
modelopt/torch/quantization/config.py
Renamed use_sequential → use_layerwise on QuantizeAlgorithmConfig; added an optional `checkpoint_dir` field.
Calibration entrypoint & mode
modelopt/torch/quantization/model_calib.py, modelopt/torch/quantization/mode.py
Replaced sequential_calibrate with layerwise_calibrate; wrapper now respects use_layerwise and checkpoint_dir, routes to layerwise_calibrate, and forwards checkpointing kwargs.
Layerwise implementation (new)
modelopt/torch/quantization/utils/layerwise_calib.py
New module implementing LayerActivationCollector, per-layer modes (capture/run/skip/original), persistent-materialization helpers, checkpoint manifest and per-layer snapshot/save/restore, and resume detection.
Removed legacy collector
modelopt/torch/quantization/utils/activation_collector.py
Deleted the old sequential LayerActivationCollector implementation.
Utils exports & imports
modelopt/torch/quantization/utils/__init__.py, modelopt/torch/quantization/plugins/huggingface.py, tests...
Switched imports to layerwise_calib collector across utils, plugins, and tests.
Accelerate CPU-offload integration
modelopt/torch/quantization/plugins/accelerate.py
Relaxed weights_map validation, added _writeback_params_to_weights_map, and reworked weight_access_and_writeback_context to handle single-module and child-hook layouts with multi-param writeback and correct pre/post hooks.
FSDP2 / core utils
modelopt/torch/quantization/utils/core_utils.py, modelopt/torch/utils/network.py
Added _set_parameter, persistent_materialization, _disable_fsdp_unshard_reshard; generalized FSDP2 parameter access/writeback across all named parameters; get_module_device now considers accelerate hook execution_device.
Hessian/device tweak
modelopt/torch/quantization/utils/calib_utils.py
Force Hessian allocation on CPU when module weight device is meta.
Examples & CLI helpers
examples/llm_ptq/hf_ptq.py, examples/llm_ptq/example_utils.py
Add needs_checkpoint_path_update and resolve_checkpoint_dir; normalize KV cfg; auto-resolve and print checkpoint dir before quantization when applicable.
Dataset loop tweak
modelopt/torch/utils/dataset_utils.py
Temporarily disable model.config.use_cache during _forward_loop and restore it afterwards.
Tests
tests/...
Extensive test additions and updates: replace sequential→layerwise in tests, add many layerwise/checkpoint/resume/FSDP/accelerate integration tests, and get_module_device unit tests.

Sequence Diagram(s)

sequenceDiagram
  participant Entrypoint as Calibration Entrypoint
  participant Model as Model
  participant Collector as LayerActivationCollector
  participant Forward as ForwardLoop
  participant Checkpoint as CheckpointStore
  participant GPTQ as GPTQ Updater

  Entrypoint->>Collector: attach/discover layers
  Entrypoint->>Checkpoint: detect_resume_point(checkpoint_dir)
  alt resume available
    Checkpoint-->>Collector: restore output_meta + next_inputs
  end
  loop for layer in start_layer..N
    Entrypoint->>Collector: set mode -> capture(layer)
    Entrypoint->>Forward: run forward (captures inputs / EarlyStop)
    Collector-->>Entrypoint: captured inputs
    Entrypoint->>Collector: set mode -> run(layer)
    Entrypoint->>Forward: replay captured inputs -> outputs
    Entrypoint->>Checkpoint: save(layer_weights, quantizer_state, output_meta, next_inputs)
    alt GPTQ enabled
      Entrypoint->>GPTQ: update_weights_for_layer(...)
    end
  end
  Entrypoint->>Checkpoint: full_restore(all_layers)
  Entrypoint->>Collector: unpatch and cleanup
Loading

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~75 minutes


Important

Pre-merge checks failed

Please resolve all errors before merging. Addressing warnings is optional.

❌ Failed checks (1 error, 1 warning)

  • Security Anti-Patterns (❌ Error): Four torch.load() calls in layerwise_calib.py lack the required inline security justification comments per SECURITY.md. Resolution: add inline security comments to the torch.load() calls at lines 578, 589, 613, and 620 explaining that these files are internally generated and trusted.
  • Docstring Coverage (⚠️ Warning): Docstring coverage is 60.65%, below the required 80.00% threshold. Resolution: write docstrings for the functions missing them.
✅ Passed checks (2 passed)
  • Title check (✅ Passed): The title 'Add layerwise calibration for large models' clearly summarizes the main change, referring to the rename of sequential_calibrate to layerwise_calibrate and the addition of checkpoint save/resume support for large-model calibration.
  • Description check (✅ Passed): Check skipped because CodeRabbit's high-level summary is enabled.


@realAsma realAsma force-pushed the asma/ptq-large-models branch 2 times, most recently from 8eabe76 to 6ec3721 on April 14, 2026 16:49
@codecov

codecov bot commented Apr 14, 2026

Codecov Report

❌ Patch coverage is 94.73684% with 29 lines in your changes missing coverage. Please review.
✅ Project coverage is 76.67%. Comparing base (dc7ad66) to head (a33c01f).
⚠️ Report is 1 commits behind head on main.

Files with missing lines Patch % Lines
...delopt/torch/quantization/utils/layerwise_calib.py 94.02% 20 Missing ⚠️
modelopt/torch/export/plugins/vllm_fakequant_hf.py 92.18% 5 Missing ⚠️
modelopt/torch/quantization/plugins/accelerate.py 90.90% 3 Missing ⚠️
modelopt/torch/quantization/utils/core_utils.py 97.43% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1251      +/-   ##
==========================================
+ Coverage   72.52%   76.67%   +4.15%     
==========================================
  Files         459      459              
  Lines       48664    48975     +311     
==========================================
+ Hits        35292    37552    +2260     
+ Misses      13372    11423    -1949     
Flag Coverage Δ
examples 41.18% <21.05%> (+1.81%) ⬆️
gpu 60.18% <92.37%> (+7.87%) ⬆️
unit 52.29% <70.23%> (+0.11%) ⬆️

Flags with carried forward coverage won't be shown. Click here to find out more.


@realAsma realAsma force-pushed the asma/ptq-large-models branch from 6ec3721 to 8af3655 on April 14, 2026 18:48
@realAsma realAsma force-pushed the asma/ptq-large-models branch from 8af3655 to 6280846 on April 14, 2026 19:21
@realAsma realAsma marked this pull request as ready for review April 14, 2026 19:25
@realAsma realAsma requested review from a team as code owners April 14, 2026 19:25
Collaborator

@cjluo-nv cjluo-nv left a comment


This is a substantial PR (~1500 lines) that adds checkpoint save/resume for sequential calibration, extends support to FSDP2 and accelerate-offloaded models, and renames activation_collector.py → layerwise_calib.py. The changes are cohesive and well-tested (unit + GPU tests for checkpoint, resume, offload, FSDP2 scenarios).

Key issues found:

  1. Removed guard on sequential calibration methods — The assertion restricting sequential calibration to max and gptq was removed without replacement. Methods like awq, smoothquant, and svdquant operate on the full model (not per-layer) and will break silently or produce incorrect results when used with use_sequential=True.

  2. weights_only=False security concern: torch.load(..., weights_only=False) is used for loading checkpoints, which can execute arbitrary code. While the checkpoints are locally generated, this is flagged by security scanners and should use weights_only=True where possible.
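The distinction the reviewer draws can be illustrated as follows: prefer the restricted unpickler for tensor-only files, and keep an inline justification wherever full pickle deserialization is genuinely needed. The payloads and paths below are invented for the example:

```python
import tempfile

import torch

# Tensor-only payload: the restricted unpickler (weights_only=True) is safe.
with tempfile.NamedTemporaryFile(suffix=".pt") as f:
    torch.save({"weight": torch.ones(2, 2)}, f.name)
    state = torch.load(f.name, weights_only=True)
    assert torch.equal(state["weight"], torch.ones(2, 2))

# Mixed payload with arbitrary Python objects: requires full pickle.
with tempfile.NamedTemporaryFile(suffix=".pt") as f:
    torch.save({"meta": ("other", [1, 2])}, f.name)
    # This file is generated internally by the calibration process and is not
    # user-supplied, so full pickle deserialization is acceptable here.
    meta = torch.load(f.name, weights_only=False)
    assert meta["meta"] == ("other", [1, 2])
```

The inline comment pattern in the second block is exactly what the later review rounds asked for on the real torch.load calls in layerwise_calib.py.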

Minor observations:

  • PR size is above ~1000 lines but the changes are cohesive and hard to split
  • Good test coverage for the new functionality
  • The temporarily_remove_accelerate_hook rewrite is a nice improvement avoiding the init_hook pitfall
  • _writeback_params_to_weights_map properly handles all parameters (not just weight)
  • FSDP2 context manager correctly generalized to handle all DTensor parameters

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
modelopt/torch/quantization/model_calib.py (1)

1566-1632: ⚠️ Potential issue | 🔴 Critical

Add inline comments to torch.load(..., weights_only=False) calls in layerwise_calib.py.

Per SECURITY.md and the coding guidelines, torch.load(..., weights_only=False) must include an inline comment documenting why the file is internally-generated/trusted and safe to deserialize. Lines 545 and 555 in modelopt/torch/quantization/utils/layerwise_calib.py need this justification:

  • Line 545: Loading output_meta.pt
  • Line 555: Loading next_inputs.pt

Add a comment before each call explaining these checkpoint files are generated and managed internally by the sequential calibration process, confirming they are trusted sources.

🧹 Nitpick comments (1)
modelopt/torch/quantization/utils/layerwise_calib.py (1)

591-630: LGTM!

The save method correctly:

  1. Uses enable_weight_access_and_writeback context for managed-weight frameworks
  2. Moves all data to CPU before storage
  3. Has a defensive fallback for missing output_meta (line 617-618)

The fallback creates dummy metadata if output_meta is None, which could mask state-machine bugs. Consider logging a warning in this case.

Optional: Add warning for missing output_meta
         output_meta = getattr(layer._seq_calib, "output_meta", None)
         if output_meta is None:
+            print_rank_0(
+                f"Warning: layer {layer_idx} has no output_meta; using fallback. "
+                "This may indicate the layer was not run in 'run' mode."
+            )
             output_meta = LayerActivationCollector._extract_output_meta(torch.zeros(1))
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Inline comments:
In `@modelopt/torch/quantization/utils/layerwise_calib.py`:
- Line 555: Add an inline comment next to the torch.load call that sets
weights_only=False (the line loading next_inputs from next_inputs_path)
explaining that this file is produced internally by _save_layer and therefore
may contain non-tensor objects from the model's forward pass which require
pickle; explicitly state that the file source is trusted and why using
weights_only=False is safe in this context to satisfy the security guideline.
- Around line 544-546: Add an inline comment immediately above the
torch.load(...) call that sets weights_only=False (the line assigning meta =
torch.load(...)) explaining that this is safe because output_meta.pt is produced
internally by this module's _save_layer function (so it is not user-supplied and
controlled), that the file may contain arbitrary Python objects under the
("other", output) metadata path and therefore requires pickle deserialization,
and that this trusted-origin justification satisfies the SECURITY.md requirement
for using weights_only=False.


ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro Plus

Run ID: 6e6e083b-89b1-4ebe-9b11-a051411fcf87

📥 Commits

Reviewing files that changed from the base of the PR and between b6c6ec3 and 6280846.

📒 Files selected for processing (19)
  • modelopt/torch/quantization/config.py
  • modelopt/torch/quantization/mode.py
  • modelopt/torch/quantization/model_calib.py
  • modelopt/torch/quantization/plugins/accelerate.py
  • modelopt/torch/quantization/plugins/huggingface.py
  • modelopt/torch/quantization/utils/__init__.py
  • modelopt/torch/quantization/utils/activation_collector.py
  • modelopt/torch/quantization/utils/calib_utils.py
  • modelopt/torch/quantization/utils/core_utils.py
  • modelopt/torch/quantization/utils/layerwise_calib.py
  • modelopt/torch/utils/network.py
  • tests/gpu/torch/quantization/plugins/test_accelerate_gpu.py
  • tests/gpu/torch/quantization/test_fsdp2.py
  • tests/gpu/torch/quantization/test_sequential_calibrate.py
  • tests/unit/torch/quantization/plugins/test_huggingface.py
  • tests/unit/torch/quantization/test_calib.py
  • tests/unit/torch/quantization/test_sequential_calibrate.py
  • tests/unit/torch/quantization/test_sequential_checkpoint.py
  • tests/unit/torch/quantization/test_utils.py
💤 Files with no reviewable changes (1)
  • modelopt/torch/quantization/utils/activation_collector.py

Contributor

@coderabbitai coderabbitai bot left a comment


🧹 Nitpick comments (1)
tests/unit/torch/quantization/test_sequential_calibrate.py (1)

585-590: Optional: guarantee cleanup with try/finally in restore test.

Use the same cleanup pattern as other tests so unpatch always runs if collection fails mid-test.

♻️ Suggested change
     collector = LayerActivationCollector(model)
     collector._patch_all_layers()
-    for layer in originals:
-        collector.get_input_activations(layer, forward_loop)
-    collector._unpatch_all_layers()
+    try:
+        for layer in originals:
+            collector.get_input_activations(layer, forward_loop)
+    finally:
+        collector._unpatch_all_layers()

ℹ️ Review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro Plus

Run ID: b3aad426-8516-47cb-93d8-486cadc7717d

📥 Commits

Reviewing files that changed from the base of the PR and between 6280846 and 6515d4d.

📒 Files selected for processing (2)
  • tests/gpu/torch/quantization/plugins/test_accelerate_gpu.py
  • tests/unit/torch/quantization/test_sequential_calibrate.py
🚧 Files skipped from review as they are similar to previous changes (1)
  • tests/gpu/torch/quantization/plugins/test_accelerate_gpu.py

@realAsma realAsma force-pushed the asma/ptq-large-models branch from 6515d4d to 6a25fc2 on April 15, 2026 14:18
@realAsma realAsma changed the title from "Add checkpoint save/resume for sequential calibration" to "Add layerwise calibration for large models" on Apr 15, 2026
@realAsma realAsma force-pushed the asma/ptq-large-models branch from 43d1888 to e0cda1b on April 15, 2026 19:58
Contributor Author

@realAsma realAsma left a comment


@sugunav14 Note on layerwise_calib.py: Most of this file is moved from modelopt/torch/quantization/utils/activation_collector.py (deleted in this PR). Git does not detect the rename because the file nearly doubled in size with new checkpoint/resume logic. To see the actual diff against the original, use:

git diff origin/main...HEAD -M10% -- modelopt/torch/quantization/utils/

This shows it as a rename with ~393 insertions and ~47 deletions, rather than a full 681-line new file + 335-line deletion.

Contributor

@meenchen meenchen left a comment


1. Design — Method Guard Removed

[QUESTION] Calibration method guard removed with TODO
The old code asserted that only max and gptq methods could be used with sequential calibration. This PR removes that assertion with a TODO: "add a method guard here." This means any calibration method (AWQ, SmoothQuant, etc.) can now be silently used with layerwise, even if it doesn't support per-layer invocation. What's the plan for re-adding validation?

2. Correctness — Checkpoint Corruption Recovery

[SUGGESTION] Manifest corruption restarts silently from layer 0
_read_manifest() returns None on corrupt/missing JSON, causing from_folder() to treat it as a fresh run. If a user's 50-layer calibration checkpoint is partially corrupted, they'd restart from scratch without knowing why. Add a warning when the manifest exists but can't be parsed — "Checkpoint manifest found but unreadable; starting from layer 0."
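The suggested behavior could be sketched like this (function names follow the comment; the implementation is an assumption): warn when a manifest file exists but cannot be parsed, rather than silently treating the run as fresh.

```python
import json
import os
import tempfile
import warnings


def read_manifest(ckpt_dir):
    """Return the parsed manifest, or None for a fresh run; warn on corruption."""
    path = os.path.join(ckpt_dir, "manifest.json")
    if not os.path.exists(path):
        return None  # genuinely fresh run: nothing to warn about
    try:
        with open(path) as f:
            return json.load(f)
    except (json.JSONDecodeError, OSError):
        warnings.warn(
            "Checkpoint manifest found but unreadable; starting from layer 0."
        )
        return None


with tempfile.TemporaryDirectory() as d:
    with open(os.path.join(d, "manifest.json"), "w") as f:
        f.write("{not valid json")
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        assert read_manifest(d) is None
        assert len(caught) == 1  # the user now learns why the restart happened
```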

3. Correctness — _SkipLayer Proxy Masks AttributeErrors

[SUGGESTION] _SkipLayer __getattr__ catches all AttributeErrors

def __getattr__(self, name):
    try:
        return super().__getattr__(name)
    except AttributeError:
        return getattr(object.__getattribute__(self, "_original"), name)

If the original layer also doesn't have the attribute, the error message will reference _original instead of the skip layer — confusing for debugging. Consider catching only the first super().__getattr__ and letting the second propagate naturally (which it does), but add context to the error.
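One hypothetical variant that keeps the proxy behavior while naming the skip layer in the final error (again an illustration, not the PR's code):

```python
import torch.nn as nn


class SkipProxy(nn.Module):
    """Parameter-free proxy that forwards unknown attributes to the wrapped layer."""

    def __init__(self, original):
        super().__init__()
        # Bypass nn.Module.__setattr__ so the original is NOT registered as a
        # submodule; the proxy stays parameter-free for offloading hooks.
        object.__setattr__(self, "_original", original)

    def __getattr__(self, name):
        try:
            return super().__getattr__(name)
        except AttributeError:
            pass
        original = object.__getattribute__(self, "_original")
        try:
            return getattr(original, name)
        except AttributeError as e:
            # Name the proxy in the message so debugging points at the right object.
            raise AttributeError(
                f"{type(self).__name__} (proxying {type(original).__name__}) "
                f"has no attribute {name!r}"
            ) from e


proxy = SkipProxy(nn.Linear(2, 2))
assert proxy.in_features == 2  # forwarded to the wrapped layer
assert list(proxy.parameters()) == []  # nothing registered on the proxy
try:
    proxy.not_there
except AttributeError as e:
    assert "SkipProxy" in str(e)
```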

4. Design — Distributed Checkpoint Blocked

[SUGGESTION] Multi-rank checkpointing blocked at runtime, should fail at config time
_CheckpointState.__init__ raises RuntimeError if dist.size() > 1. Users discover this only after model loading and initial setup. Validate this in the config validator alongside layerwise_checkpoint_dir — if running distributed, reject the config early.
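A sketch of what failing at config time could look like (the field and function names are assumptions, not the library's actual validator API):

```python
def validate_layerwise_config(cfg: dict, world_size: int) -> dict:
    """Reject multi-rank layerwise checkpointing before any model loading happens."""
    if cfg.get("layerwise_checkpoint_dir") and world_size > 1:
        raise ValueError(
            "layerwise_checkpoint_dir requires single-rank calibration "
            f"(got world_size={world_size}); drop the checkpoint dir or run on one rank"
        )
    return cfg


# Accepted: single rank with a checkpoint dir.
cfg = validate_layerwise_config({"layerwise_checkpoint_dir": "ckpt"}, 1)
assert cfg["layerwise_checkpoint_dir"] == "ckpt"

# Rejected early: distributed run with a checkpoint dir.
try:
    validate_layerwise_config({"layerwise_checkpoint_dir": "ckpt"}, 8)
except ValueError:
    pass
else:
    raise AssertionError("expected ValueError for world_size > 1")
```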

Collaborator

@cjluo-nv cjluo-nv left a comment


All critical comments from previous reviews have been addressed:

Critical issues — all resolved:

  1. weights_only=False security: All 6 torch.load calls now have inline justification comments explaining the files are internally generated.
  2. _writeback_params_to_weights_map now uses state_dict(keep_vars=True) to write back both parameters and buffers.
  3. from_folder() validates num_layers mismatch between checkpoint manifest and current model.
  4. ✅ Config validator validate_layerwise_checkpoint_dir rejects layerwise_checkpoint_dir when layerwise=False.
  5. _get_execution_device_from_hook handles integer GPU ordinals with isinstance(dev, int).
  6. needs_checkpoint_path_update guards against non-dict algorithm values.
  7. ✅ Tests now clean up stale layer directories above last_completed_layer during crash simulation.
  8. weight_access_and_writeback_context now raises RuntimeError (instead of bare assert) for unsupported dual-hook layouts.
  9. ✅ Last layer's dummy output_meta has a clear comment explaining it's a placeholder.
  10. _layer_forward_loop uses default argument capture (_inputs=layer_inputs) for explicit binding.

Design decisions accepted by reviewers:

  • Author pushed back on restoring previous-layer state on exception (no retry mechanism exists), and CodeRabbit agreed.
  • Method guard removal: The layerwise field is on the base QuantizeAlgorithmConfig, so technically any method could set layerwise=True. However, the TODO is clearly marked, and in practice the typed config dispatch system means standard recipes won't hit this path for unsupported methods. This is acceptable risk with the TODO in place.

Code quality:

  • Excellent test coverage: unit tests for skip/swap/restore, checkpoint save/resume, GPU integration tests for CPU-offloaded models, FSDP2, GPTQ combinations.
  • The _SkipLayer pattern for replacing calibrated layers is clean and avoids framework hook overhead.
  • The persistent_materialization context manager is well-designed for both FSDP2 and accelerate.
  • The _forward_loop change to disable use_cache during calibration is a good correctness fix.
  • The calib_utils.py change to handle meta device in GPTQHelper.__init__ correctly supports CPU-offloaded models.
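The `_SkipLayer` swap can be sketched as follows — a hedged, simplified version of the idea (the real class likely mirrors the decoder layer's output signature; only the parameter-free property matters here):

```python
import torch
from torch import nn

class SkipLayer(nn.Module):
    """Parameter-free stand-in for a fully calibrated decoder layer.

    Once a layer finishes calibration it is swapped out of the ModuleList
    for this dummy, so offloading hooks (accelerate/FSDP2) find no
    parameters to transfer on subsequent forward passes.
    """

    def forward(self, hidden_states, *args, **kwargs):
        # In skip mode the real computation is irrelevant; the state
        # machine replays captured inputs for the active layer instead.
        return hidden_states

layers = nn.ModuleList([nn.Linear(8, 8) for _ in range(3)])
layers[0] = SkipLayer()  # layer 0 finished calibration

n_params = sum(p.numel() for p in layers[0].parameters())
print(n_params)  # 0
```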

PR size: ~1900 lines is above the soft 1000-line guideline, but the changes are cohesive (rename + layerwise calibration + checkpoint + framework integration) and the author notes most of layerwise_calib.py is moved from the deleted activation_collector.py.

@realAsma realAsma requested a review from a team as a code owner April 16, 2026 13:39
Comment thread modelopt/torch/quantization/utils/layerwise_calib.py
Comment thread modelopt/torch/quantization/model_calib.py
Comment thread modelopt/torch/quantization/mode.py Outdated
Comment thread modelopt/torch/quantization/config.py Outdated
@realAsma realAsma force-pushed the asma/ptq-large-models branch from 482c883 to 5658381 Compare April 17, 2026 16:46
Contributor Author

@shengliangxu I am overwriting the layerwise save from the YAML recipe this way. Does this look good to you?

Comment thread modelopt/torch/export/plugins/vllm_fakequant_hf.py
@realAsma
Contributor Author

@meenchen @sugunav14 — added _supports_layerwise (per-mode opt-out) on BaseCalibrateModeDescriptor in modelopt/torch/quantization/mode.py. Defaults to True; SVDQuantModeDescriptor opts out. wrapped_calib_func now raises a clear ValueError when layerwise=True is requested on an unsupported mode. See commit baaf80f4.
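The opt-out mechanism described above can be sketched like this — class and attribute names follow the comment, but the bodies are simplified assumptions, not the actual mode-descriptor implementation:

```python
class BaseCalibrateModeDescriptor:
    @property
    def _supports_layerwise(self) -> bool:
        return True  # default: calibration modes work layer by layer

class SVDQuantModeDescriptor(BaseCalibrateModeDescriptor):
    @property
    def _supports_layerwise(self) -> bool:
        return False  # opts out: needs the whole model at once

def wrapped_calib_func(descriptor, layerwise: bool):
    if layerwise and not descriptor._supports_layerwise:
        raise ValueError(
            f"layerwise=True is not supported by {type(descriptor).__name__}"
        )
    # ... dispatch to the actual calibration routine ...

wrapped_calib_func(BaseCalibrateModeDescriptor(), layerwise=True)  # ok
try:
    wrapped_calib_func(SVDQuantModeDescriptor(), layerwise=True)
    raised = False
except ValueError:
    raised = True
print(raised)  # True
```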

@sychen52 sychen52 self-requested a review April 17, 2026 17:41
Comment thread modelopt/torch/export/plugins/vllm_fakequant_hf.py
@realAsma realAsma enabled auto-merge (squash) April 17, 2026 18:33
@realAsma realAsma disabled auto-merge April 17, 2026 18:43
@realAsma realAsma enabled auto-merge (squash) April 17, 2026 19:16
Contributor

@Edwardf0t1 Edwardf0t1 left a comment


LGTM in general. @realAsma It would be great to test GLM5.1 as well. cc @Fridah-nv

There's a draft PR for GLM5/5.1:
#985

Introduces layerwise calibration to enable PTQ on models that do not
fit in GPU memory, plus supporting infrastructure:

- New modelopt/torch/quantization/utils/layerwise_calib.py with layer-by-layer
  calibration and per-mode opt-out
- Disk offloading support in enable_weight_access_and_writeback
- Memory-efficient inplace fakequant export with disk offload
- Meta device detection in layerwise restore
- Fix meta tensor crash when exporting offloaded vLLM fakequant checkpoints
- Fix json.dumps sort_keys error with mixed int/str keys in quant_cfg
- Rename test_sequential_calibrate -> test_layerwise_calibrate (unit + gpu)
- Remove obsolete activation_collector.py

Signed-off-by: realAsma <akuriparambi@nvidia.com>
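The `json.dumps` fix in the commit message above addresses a genuine stdlib pitfall: with `sort_keys=True`, Python sorts the original dict keys before serializing, and comparing `int` with `str` raises `TypeError`. A minimal reproduction (the str-key normalization shown is one possible fix, an assumption, not necessarily what the PR does):

```python
import json

quant_cfg = {0: "fp8", "default": "int8"}  # mixed int/str keys

try:
    json.dumps(quant_cfg, sort_keys=True)  # sorting compares int with str
    failed = False
except TypeError:
    failed = True
print(failed)  # True

# Possible fix: normalize keys to strings before serializing, so the
# sort is well-defined (json stringifies int keys during dumping anyway).
serialized = json.dumps({str(k): v for k, v in quant_cfg.items()}, sort_keys=True)
print(serialized)  # {"0": "fp8", "default": "int8"}
```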
Max calibration is fast enough that checkpointing each layer adds
unnecessary I/O and disk usage. Comment explains why it is omitted.

Signed-off-by: realAsma <akuriparambi@nvidia.com>
Adds test_hf_vllm_export_offload covering the inplace_mem_efficient=True
path of export_hf_vllm_fq_checkpoint on a CPU-offloaded tiny LLaMA. The
test asserts the inplace path actually mutates offloaded layer weights
(falsifying a silent fall-through to the copy path), that the reloaded
HF model matches a deepcopy+fold_weight reference built inside
enable_weight_access_and_writeback (materializes meta tensors before
folding), and that the saved quantizer state preserves input amaxes.

Also adds a CHANGELOG.rst bullet under 0.44 New Features describing the
layerwise calibration feature and linking to the experts-only recipe.

Signed-off-by: realAsma <akuriparambi@nvidia.com>
Show the two recipes separately: first the plain layerwise recipe for
the base feature, then the intermediate-progress-saving detail with the
GPTQ recipe that demonstrates it.

Signed-off-by: realAsma <akuriparambi@nvidia.com>
Monkey-patch save_pretrained to a no-op so the test exercises only the
PR's new inplace_mem_efficient=True contribution (per-layer
enable_weight_access_and_writeback dispatch + inplace fake-quant
writeback) without tripping transformers load_offloaded_parameter on
SequentialHook — a pre-existing upstream limitation unrelated to this
PR's new code.

Broaden the folded-weights assertion to cover all decoder layers (not
just the offloaded layer 0) so regressions in the on-GPU inplace path
are also caught. The vllm_fq_modelopt_state.pth contents are still
asserted since torch.save happens before save_pretrained.

Signed-off-by: realAsma <akuriparambi@nvidia.com>
@realAsma realAsma force-pushed the asma/ptq-large-models branch from 73980e5 to a33c01f Compare April 17, 2026 23:30
Comment on lines 219 to 230
modelopt_state = mto.modelopt_state(model)
# ``modelopt_state`` may be stale if another mode (e.g. calibrate) ran last. Rebuild
# ``quantizer_state`` and drop disabled weight quantizer entries (weights already folded).
qstate = quantizer_state(model)
for key in list(qstate):
    if key.endswith("weight_quantizer") and qstate[key].get("_disabled"):
        qstate.pop(key)

for mode_str, m_state in modelopt_state.get("modelopt_state_dict", []):
    if mode_str == "quantize" and "metadata" in m_state:
        m_state["metadata"]["quantizer_state"] = qstate
        break
Contributor Author

@kinjalpatel27 why are we doing this specially for quantize mode?

Contributor

We need to remove disabled weight_quantizer entries from the metadata; otherwise the reload creates an issue.

@realAsma realAsma merged commit 2d868d3 into main Apr 18, 2026
45 checks passed
@realAsma realAsma deleted the asma/ptq-large-models branch April 18, 2026 00:32