Add layerwise calibration for large models #1251
📝 Walkthrough: Rename calibration mode flag
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Entrypoint as Calibration Entrypoint
    participant Model as Model
    participant Collector as LayerActivationCollector
    participant Forward as ForwardLoop
    participant Checkpoint as CheckpointStore
    participant GPTQ as GPTQ Updater
    Entrypoint->>Collector: attach/discover layers
    Entrypoint->>Checkpoint: detect_resume_point(checkpoint_dir)
    alt resume available
        Checkpoint-->>Collector: restore output_meta + next_inputs
    end
    loop for layer in start_layer..N
        Entrypoint->>Collector: set mode -> capture(layer)
        Entrypoint->>Forward: run forward (captures inputs / EarlyStop)
        Collector-->>Entrypoint: captured inputs
        Entrypoint->>Collector: set mode -> run(layer)
        Entrypoint->>Forward: replay captured inputs -> outputs
        Entrypoint->>Checkpoint: save(layer_weights, quantizer_state, output_meta, next_inputs)
        alt GPTQ enabled
            Entrypoint->>GPTQ: update_weights_for_layer(...)
        end
    end
    Entrypoint->>Checkpoint: full_restore(all_layers)
    Entrypoint->>Collector: unpatch and cleanup
```
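The capture/run loop in the sequence diagram can be sketched in miniature. This is a hypothetical illustration only: the real `LayerActivationCollector` operates on torch modules with patched forwards, while here each "layer" is just a callable and the class name is illustrative.

```python
# Minimal sketch of the capture -> run state machine from the diagram above.
# Names (LayerCollector, calibrate) are illustrative, not the modelopt API.
class LayerCollector:
    def __init__(self, layers):
        self.layers = layers
        self.mode = None      # "capture" or "run", per the state machine
        self.captured = {}    # layer index -> inputs captured for that layer

    def calibrate(self, batches, start_layer=0):
        inputs = list(batches)
        for idx in range(start_layer, len(self.layers)):
            # Capture phase: record the inputs this layer would receive.
            self.mode = "capture"
            self.captured[idx] = list(inputs)
            # Run phase: replay the captured inputs; the outputs become the
            # next layer's inputs (this is what gets checkpointed per layer).
            self.mode = "run"
            inputs = [self.layers[idx](x) for x in self.captured[idx]]
        return inputs
```

Resume support then amounts to persisting the captured inputs for the next layer and restarting the loop at `start_layer`.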
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~75 minutes

Pre-merge checks: ❌ 1 error, ⚠️ 1 warning, ✅ 2 passed. Please resolve all errors before merging; addressing warnings is optional.
Force-pushed 8eabe76 to 6ec3721
Codecov Report

```
@@            Coverage Diff            @@
##             main     #1251    +/-  ##
========================================
+ Coverage   72.52%    76.67%   +4.15%
========================================
  Files         459       459
  Lines       48664     48975     +311
========================================
+ Hits        35292     37552    +2260
+ Misses      13372     11423    -1949
```

☔ View full report in Codecov by Sentry.
Force-pushed 6ec3721 to 8af3655, then 8af3655 to 6280846
cjluo-nv left a comment:
This is a substantial PR (~1500 lines) that adds checkpoint save/resume for sequential calibration, extends support to FSDP2 and accelerate-offloaded models, and renames activation_collector.py → layerwise_calib.py. The changes are cohesive and well-tested (unit + GPU tests for checkpoint, resume, offload, FSDP2 scenarios).
Key issues found:
- Removed guard on sequential calibration methods — The assertion restricting sequential calibration to `max` and `gptq` was removed without replacement. Methods like `awq`, `smoothquant`, and `svdquant` operate on the full model (not per-layer) and will break silently or produce incorrect results when used with `use_sequential=True`.
- `weights_only=False` security concern — `torch.load(..., weights_only=False)` is used for loading checkpoints, which can execute arbitrary code. While the checkpoints are locally generated, this is flagged by security scanners and should use `weights_only=True` where possible.
Minor observations:
- PR size is above ~1000 lines but the changes are cohesive and hard to split
- Good test coverage for the new functionality
- The `temporarily_remove_accelerate_hook` rewrite is a nice improvement avoiding the `init_hook` pitfall
- `_writeback_params_to_weights_map` properly handles all parameters (not just `weight`)
- FSDP2 context manager correctly generalized to handle all DTensor parameters
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
modelopt/torch/quantization/model_calib.py (1)
1566-1632: ⚠️ Potential issue | 🔴 Critical

Add inline comments to `torch.load(..., weights_only=False)` calls in `layerwise_calib.py`. Per SECURITY.md and the coding guidelines, `torch.load(..., weights_only=False)` must include an inline comment documenting why the file is internally-generated/trusted and safe to deserialize. Lines 545 and 555 in `modelopt/torch/quantization/utils/layerwise_calib.py` need this justification:
- Line 545: Loading `output_meta.pt`
- Line 555: Loading `next_inputs.pt`

Add a comment before each call explaining these checkpoint files are generated and managed internally by the sequential calibration process, confirming they are trusted sources.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/quantization/model_calib.py` around lines 1566 - 1632, Add inline comments immediately before the two torch.load(..., weights_only=False) calls in modelopt.torch.quantization.utils.layerwise_calib (look around the _CheckpointState usage and the methods that load "output_meta.pt" and "next_inputs.pt") stating that these checkpoint files ("output_meta.pt" and "next_inputs.pt") are generated and managed internally by the sequential calibration process, are not user-supplied, and therefore are trusted for safe deserialization; locate the calls near methods that restore checkpoint state (e.g., _CheckpointState.setup_resume / _CheckpointState.from_folder or any load calls inside setup_resume/save) and add the short justification comment directly above each torch.load call.
🧹 Nitpick comments (1)
modelopt/torch/quantization/utils/layerwise_calib.py (1)
591-630: LGTM!

The `save` method correctly:

- Uses `enable_weight_access_and_writeback` context for managed-weight frameworks
- Moves all data to CPU before storage
- Has a defensive fallback for missing `output_meta` (lines 617-618)

The fallback creates dummy metadata if `output_meta` is None, which could mask state-machine bugs. Consider logging a warning in this case.

Optional: Add warning for missing `output_meta`:
```diff
  output_meta = getattr(layer._seq_calib, "output_meta", None)
  if output_meta is None:
+     print_rank_0(
+         f"Warning: layer {layer_idx} has no output_meta; using fallback. "
+         "This may indicate the layer was not run in 'run' mode."
+     )
      output_meta = LayerActivationCollector._extract_output_meta(torch.zeros(1))
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelopt/torch/quantization/utils/layerwise_calib.py` around lines 591 - 630, In save (method save in layerwise_calib.py) add a warning when output_meta is missing before calling LayerActivationCollector._extract_output_meta: detect if getattr(layer._seq_calib, "output_meta", None) is None, log a warning (e.g., logger = logging.getLogger(__name__); logger.warning(...)) that includes layer_idx and the layer identifier and states that dummy metadata is being created, then proceed to call LayerActivationCollector._extract_output_meta; this keeps behavior unchanged but surfaces the unexpected state-machine issue.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@modelopt/torch/quantization/utils/layerwise_calib.py`:
- Line 555: Add an inline comment next to the torch.load call that sets
weights_only=False (the line loading next_inputs from next_inputs_path)
explaining that this file is produced internally by _save_layer and therefore
may contain non-tensor objects from the model's forward pass which require
pickle; explicitly state that the file source is trusted and why using
weights_only=False is safe in this context to satisfy the security guideline.
- Around line 544-546: Add an inline comment immediately above the
torch.load(...) call that sets weights_only=False (the line assigning meta =
torch.load(...)) explaining that this is safe because output_meta.pt is produced
internally by this module's _save_layer function (so it is not user-supplied and
controlled), that the file may contain arbitrary Python objects under the
("other", output) metadata path and therefore requires pickle deserialization,
and that this trusted-origin justification satisfies the SECURITY.md requirement
for using weights_only=False.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro Plus
Run ID: 6e6e083b-89b1-4ebe-9b11-a051411fcf87
📒 Files selected for processing (19)
- modelopt/torch/quantization/config.py
- modelopt/torch/quantization/mode.py
- modelopt/torch/quantization/model_calib.py
- modelopt/torch/quantization/plugins/accelerate.py
- modelopt/torch/quantization/plugins/huggingface.py
- modelopt/torch/quantization/utils/__init__.py
- modelopt/torch/quantization/utils/activation_collector.py
- modelopt/torch/quantization/utils/calib_utils.py
- modelopt/torch/quantization/utils/core_utils.py
- modelopt/torch/quantization/utils/layerwise_calib.py
- modelopt/torch/utils/network.py
- tests/gpu/torch/quantization/plugins/test_accelerate_gpu.py
- tests/gpu/torch/quantization/test_fsdp2.py
- tests/gpu/torch/quantization/test_sequential_calibrate.py
- tests/unit/torch/quantization/plugins/test_huggingface.py
- tests/unit/torch/quantization/test_calib.py
- tests/unit/torch/quantization/test_sequential_calibrate.py
- tests/unit/torch/quantization/test_sequential_checkpoint.py
- tests/unit/torch/quantization/test_utils.py
💤 Files with no reviewable changes (1)
- modelopt/torch/quantization/utils/activation_collector.py
🧹 Nitpick comments (1)
tests/unit/torch/quantization/test_sequential_calibrate.py (1)
585-590: Optional: guarantee cleanup with `try/finally` in restore test.

Use the same cleanup pattern as other tests so unpatch always runs if collection fails mid-test.
♻️ Suggested change
```diff
  collector = LayerActivationCollector(model)
  collector._patch_all_layers()
- for layer in originals:
-     collector.get_input_activations(layer, forward_loop)
- collector._unpatch_all_layers()
+ try:
+     for layer in originals:
+         collector.get_input_activations(layer, forward_loop)
+ finally:
+     collector._unpatch_all_layers()
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@tests/unit/torch/quantization/test_sequential_calibrate.py` around lines 585 - 590, The test currently calls collector._patch_all_layers(), runs collection, and then calls collector._unpatch_all_layers() but does not guarantee cleanup if collection fails; wrap the collection calls in a try/finally so that _unpatch_all_layers() is always invoked even on exceptions. Specifically, after calling LayerActivationCollector(model) and collector._patch_all_layers(), perform the loop that calls collector.get_input_activations(layer, forward_loop) over originals inside a try block and call collector._unpatch_all_layers() in the finally block to ensure restoration.
📒 Files selected for processing (2)
- tests/gpu/torch/quantization/plugins/test_accelerate_gpu.py
- tests/unit/torch/quantization/test_sequential_calibrate.py
🚧 Files skipped from review as they are similar to previous changes (1)
- tests/gpu/torch/quantization/plugins/test_accelerate_gpu.py
Force-pushed 6515d4d to 6a25fc2, then 43d1888 to e0cda1b
realAsma left a comment:
@sugunav14 Note on layerwise_calib.py: Most of this file is moved from modelopt/torch/quantization/utils/activation_collector.py (deleted in this PR). Git does not detect the rename because the file nearly doubled in size with new checkpoint/resume logic. To see the actual diff against the original, use:
```shell
git diff origin/main...HEAD -M10% -- modelopt/torch/quantization/utils/
```
This shows it as a rename with ~393 insertions and ~47 deletions, rather than a full 681-line new file + 335-line deletion.
meenchen left a comment:
1. Design — Method Guard Removed
[QUESTION] Calibration method guard removed with TODO
The old code asserted that only max and gptq methods could be used with sequential calibration. This PR removes that assertion with a TODO: "add a method guard here." This means any calibration method (AWQ, SmoothQuant, etc.) can now be silently used with layerwise, even if it doesn't support per-layer invocation. What's the plan for re-adding validation?
2. Correctness — Checkpoint Corruption Recovery
[SUGGESTION] Manifest corruption restarts silently from layer 0
_read_manifest() returns None on corrupt/missing JSON, causing from_folder() to treat it as a fresh run. If a user's 50-layer calibration checkpoint is partially corrupted, they'd restart from scratch without knowing why. Add a warning when the manifest exists but can't be parsed — "Checkpoint manifest found but unreadable; starting from layer 0."
3. Correctness — _SkipLayer Proxy Masks AttributeErrors
[SUGGESTION] `_SkipLayer.__getattr__` catches all AttributeErrors

```python
def __getattr__(self, name):
    try:
        return super().__getattr__(name)
    except AttributeError:
        return getattr(object.__getattribute__(self, "_original"), name)
```

If the original layer also doesn't have the attribute, the error message will reference `_original` instead of the skip layer — confusing for debugging. Consider catching only the first `super().__getattr__` and letting the second propagate naturally (which it does), but add context to the error.
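A hedged sketch of the suggested "add context" variant: delegate to the wrapped layer, but name the proxy in the error so a failed lookup doesn't point at `_original`. The `nn.Module` lookup (`super().__getattr__`) is elided; a plain class stands in for the real `_SkipLayer`.

```python
class SkipLayer:
    """Illustrative stand-in for _SkipLayer; not the modelopt class."""

    def __init__(self, original):
        object.__setattr__(self, "_original", original)

    def __getattr__(self, name):
        # Only reached when normal attribute lookup on the proxy fails.
        original = object.__getattribute__(self, "_original")
        try:
            return getattr(original, name)
        except AttributeError as exc:
            # Re-raise with both parties named so the traceback points at
            # the proxy, not at the internal `_original` reference.
            raise AttributeError(
                f"Neither SkipLayer nor the wrapped "
                f"{type(original).__name__} has attribute {name!r}"
            ) from exc
```

Chaining with `from exc` keeps the original AttributeError in the traceback for anyone who needs the low-level detail.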
4. Design — Distributed Checkpoint Blocked
[SUGGESTION] Multi-rank checkpointing blocked at runtime, should fail at config time
_CheckpointState.__init__ raises RuntimeError if dist.size() > 1. Users discover this only after model loading and initial setup. Validate this in the config validator alongside layerwise_checkpoint_dir — if running distributed, reject the config early.
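The config-time check being asked for is small. A hedged sketch, with an illustrative function name and signature rather than the actual modelopt validator:

```python
def validate_layerwise_config(layerwise_checkpoint_dir, world_size):
    """Fail fast during config validation instead of deep inside
    _CheckpointState.__init__ after the model is already loaded.
    (Sketch only; the real validator lives alongside the config schema.)"""
    if layerwise_checkpoint_dir is not None and world_size > 1:
        raise ValueError(
            "layerwise_checkpoint_dir is not supported for distributed runs "
            f"(world_size={world_size}); run calibration on a single rank."
        )
```

Raising `ValueError` from the validator surfaces the problem before any weights are touched, which is the whole point of the suggestion.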
cjluo-nv left a comment:
All critical comments from previous reviews have been addressed:
Critical issues — all resolved:
- ✅ `weights_only=False` security: All 6 `torch.load` calls now have inline justification comments explaining the files are internally generated.
- ✅ `_writeback_params_to_weights_map` now uses `state_dict(keep_vars=True)` to write back both parameters and buffers.
- ✅ `from_folder()` validates `num_layers` mismatch between checkpoint manifest and current model.
- ✅ Config validator `validate_layerwise_checkpoint_dir` rejects `layerwise_checkpoint_dir` when `layerwise=False`.
- ✅ `_get_execution_device_from_hook` handles integer GPU ordinals with `isinstance(dev, int)`.
- ✅ `needs_checkpoint_path_update` guards against non-dict algorithm values.
- ✅ Tests now clean up stale layer directories above `last_completed_layer` during crash simulation.
- ✅ `weight_access_and_writeback_context` now raises `RuntimeError` (instead of a bare assert) for unsupported dual-hook layouts.
- ✅ Last layer's dummy `output_meta` has a clear comment explaining it's a placeholder.
- ✅ `_layer_forward_loop` uses default argument capture (`_inputs=layer_inputs`) for explicit binding.
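The last resolved item above is the classic late-binding pitfall: closures created in a loop share the loop variable, so a default argument is used to capture the value at definition time. A minimal, self-contained illustration (no relation to the modelopt code beyond the pattern):

```python
# Closures created in a loop share `layer_inputs`; a default argument
# (`_inputs=layer_inputs`) snapshots the value at definition time.
late, early = [], []
for layer_inputs in ([1], [2], [3]):
    late.append(lambda: layer_inputs)                    # late binding: all share the variable
    early.append(lambda _inputs=layer_inputs: _inputs)   # captured per iteration

assert [f() for f in late] == [[3], [3], [3]]   # every closure sees the last value
assert [f() for f in early] == [[1], [2], [3]]  # each closure keeps its own value
```

Without the default argument, every replayed forward loop would run on the final layer's inputs.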
Design decisions accepted by reviewers:
- Author pushed back on restoring previous-layer state on exception (no retry mechanism exists), and CodeRabbit agreed.
- Method guard removal: The `layerwise` field is on the base `QuantizeAlgorithmConfig`, so technically any method could set `layerwise=True`. However, the TODO is clearly marked, and in practice the typed config dispatch system means standard recipes won't hit this path for unsupported methods. This is acceptable risk with the TODO in place.
Code quality:
- Excellent test coverage: unit tests for skip/swap/restore, checkpoint save/resume, GPU integration tests for CPU-offloaded models, FSDP2, GPTQ combinations.
- The `_SkipLayer` pattern for replacing calibrated layers is clean and avoids framework hook overhead.
- The `persistent_materialization` context manager is well-designed for both FSDP2 and accelerate.
- The `_forward_loop` change to disable `use_cache` during calibration is a good correctness fix.
- The `calib_utils.py` change to handle `meta` device in `GPTQHelper.__init__` correctly supports CPU-offloaded models.
PR size: ~1900 lines is above the soft 1000-line guideline, but the changes are cohesive (rename + layerwise calibration + checkpoint + framework integration) and the author notes most of layerwise_calib.py is moved from the deleted activation_collector.py.
Force-pushed 482c883 to 5658381
@shengliangxu I am overwriting the layerwise save from the yaml recipe this way. Does this look good to you?
@meenchen @sugunav14 — added
Edwardf0t1 left a comment:
LGTM in general. @realAsma It would be great to test GLM5.1 as well. cc @Fridah-nv
There's a draft PR for GLM5/5.1:
#985
Introduces layerwise calibration to enable PTQ on models that do not fit in GPU memory, plus supporting infrastructure:
- New modelopt/torch/quantization/utils/layerwise_calib.py with layer-by-layer calibration and per-mode opt-out
- Disk offloading support in enable_weight_access_and_writeback
- Memory-efficient inplace fakequant export with disk offload
- Meta device detection in layerwise restore
- Fix meta tensor crash when exporting offloaded vLLM fakequant checkpoints
- Fix json.dumps sort_keys error with mixed int/str keys in quant_cfg
- Rename test_sequential_calibrate -> test_layerwise_calibrate (unit + gpu)
- Remove obsolete activation_collector.py

Signed-off-by: realAsma <akuriparambi@nvidia.com>
Max calibration is fast enough that checkpointing each layer adds unnecessary I/O and disk usage. Comment explains why it is omitted. Signed-off-by: realAsma <akuriparambi@nvidia.com>
Adds test_hf_vllm_export_offload covering the inplace_mem_efficient=True path of export_hf_vllm_fq_checkpoint on a CPU-offloaded tiny LLaMA. The test asserts the inplace path actually mutates offloaded layer weights (falsifying a silent fall-through to the copy path), that the reloaded HF model matches a deepcopy+fold_weight reference built inside enable_weight_access_and_writeback (materializes meta tensors before folding), and that the saved quantizer state preserves input amaxes. Also adds a CHANGELOG.rst bullet under 0.44 New Features describing the layerwise calibration feature and linking to the experts-only recipe. Signed-off-by: realAsma <akuriparambi@nvidia.com>
Show the two recipes separately: first the plain layerwise recipe for the base feature, then the intermediate-progress-saving detail with the GPTQ recipe that demonstrates it. Signed-off-by: realAsma <akuriparambi@nvidia.com>
Monkey-patch save_pretrained to a no-op so the test exercises only the PR's new inplace_mem_efficient=True contribution (per-layer enable_weight_access_and_writeback dispatch + inplace fake-quant writeback) without tripping transformers load_offloaded_parameter on SequentialHook — a pre-existing upstream limitation unrelated to this PR's new code. Broaden the folded-weights assertion to cover all decoder layers (not just the offloaded layer 0) so regressions in the on-GPU inplace path are also caught. The vllm_fq_modelopt_state.pth contents are still asserted since torch.save happens before save_pretrained. Signed-off-by: realAsma <akuriparambi@nvidia.com>
Force-pushed 73980e5 to a33c01f
```python
modelopt_state = mto.modelopt_state(model)
# ``modelopt_state`` may be stale if another mode (e.g. calibrate) ran last. Rebuild
# ``quantizer_state`` and drop disabled weight quantizer entries (weights already folded).
qstate = quantizer_state(model)
for key in list(qstate):
    if key.endswith("weight_quantizer") and qstate[key].get("_disabled"):
        qstate.pop(key)

for mode_str, m_state in modelopt_state.get("modelopt_state_dict", []):
    if mode_str == "quantize" and "metadata" in m_state:
        m_state["metadata"]["quantizer_state"] = qstate
        break
```
@kinjalpatel27 why are we doing this specially for quantize mode?

We need to remove the disabled weight_quantizer entries from metadata, otherwise the reload creates an issue.
Summary
Adds performant layerwise calibration for quantizing large models (e.g. DeepSeek-R1 671B) that don't fit entirely on GPU. (Example commands)
- `batch_size=16` and `gpu_max_mem_percentage=0.5`.
- Renamed `sequential_calibrate` → `layerwise_calibrate` for clarity.

Design details
The existing layerwise state machine (skip/run/capture) already processes one layer at a time, but skip-mode layers still kept their parameters in the ModuleList — so frameworks transferred all weights every forward pass. This PR adds:
- `_SkipLayer`: replaces fully-calibrated layers with a parameter-free dummy in the ModuleList, so framework hooks have nothing to transfer
- `persistent_materialization`: keeps the active layer on GPU for the entire calibration step, avoiding repeated offload/reload cycles

Checkpoint save is per-layer; restore is bulk — quantizer state and weights for layers 0..K-1 are restored once at the end of calibration, keeping the hot path fast.
Example commands
Qwen3-8B (NVFP4+GPTQ, single GPU):

```shell
python hf_ptq.py \
    --pyt_ckpt_path Qwen/Qwen3-8B \
    --recipe nvfp4_gptq_sequential.yaml \
    --calib_size 64 \
    --batch_size 16 \
    --dataset cnn_dailymail \
    --export_path outputs/qwen3_8b_nvfp4_gptq_seq \
    --gpu_max_mem_percentage 0.5 \
    --use_seq_device_map \
    --vllm_fakequant_export
```

DeepSeek-R1 (NVFP4 experts-only + FP8 KV, 8×80GB):

```shell
python hf_ptq.py \
    --model unsloth/DeepSeek-R1-0528-BF16 \
    --recipe ../../modelopt_recipes/general/ptq/nvfp4_experts_only-fp8_kv.yaml \
    --dataset cnn_dailymail \
    --batch_size 16 \
    --calib_size 64 \
    --calib_seq 512 \
    --gpu_max_mem_percentage 0.5 \
    --use_seq_device_map \
    --trust_remote_code \
    --export_path output/DeepSeek-R1-BF16-nvfp4-experts-only-fp8-kv \
    --vllm_fakequant_export
```

Example: NVFP4+GPTQ layerwise calibration on Qwen3-8B (36 layers, single GPU — 20 GB peak)
Initial run (killed after layer 11):
Resumed run (picks up from layer 11, finishes all 36):
TODO
Test plan
- `tests/unit/torch/quantization/test_layerwise_calibrate.py` — unit tests for skip/swap/restore
- `tests/unit/torch/quantization/test_sequential_checkpoint.py` — checkpoint save/resume correctness
- `tests/gpu/torch/quantization/plugins/test_accelerate_gpu.py` — CPU-offloaded layerwise + GPTQ + checkpoint resume
- `tests/gpu/torch/quantization/test_fsdp2.py` — FSDP2 layerwise calibration