From 2a0feaf4aa28b1ff2f286aef66093ebfd04d0ef2 Mon Sep 17 00:00:00 2001 From: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com> Date: Mon, 16 Mar 2026 03:28:53 +0000 Subject: [PATCH 1/4] Add calib_include/exclude_modules to calibration algorithms Adds calib_include_modules and calib_exclude_modules fields to QuantizeAlgorithmConfig so users can restrict any calibration algorithm (max, mse, smoothquant, awq, ...) to a subset of the model's layers. Filtering is applied via the new filter_calib_modules context manager, which temporarily disables TensorQuantizer instances in non-matching modules while preserving their pre-existing _amax values. Also exposes --calib_include_modules / --calib_exclude_modules CLI args in the hf_ptq.py example and wires them through build_quant_cfg in example_utils.py. Co-Authored-By: Claude Sonnet 4.6 Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com> --- examples/llm_ptq/example_utils.py | 17 ++ examples/llm_ptq/hf_ptq.py | 32 ++++ modelopt/torch/quantization/config.py | 19 +++ modelopt/torch/quantization/mode.py | 36 +++-- modelopt/torch/quantization/model_calib.py | 54 +++++++ tests/unit/torch/quantization/test_calib.py | 169 +++++++++++++++++++- 6 files changed, 310 insertions(+), 17 deletions(-) diff --git a/examples/llm_ptq/example_utils.py b/examples/llm_ptq/example_utils.py index 50ac51aace..db377c51fc 100755 --- a/examples/llm_ptq/example_utils.py +++ b/examples/llm_ptq/example_utils.py @@ -202,6 +202,8 @@ def build_quant_cfg( quant_cfg_choices, kv_quant_cfg_choices, moe_calib_experts_ratio: float | None = None, + calib_exclude_modules: list[str] | None = None, + calib_include_modules: list[str] | None = None, ) -> dict[str, Any]: quant_cfg = {} assert qformat in quant_cfg_choices, ( @@ -247,6 +249,21 @@ def build_quant_cfg( f"Quantization algorithm: {quant_cfg['algorithm']} does not support setting moe_calib_experts_ratio" ) + if calib_exclude_modules or calib_include_modules: + if isinstance(quant_cfg["algorithm"], str): + quant_cfg["algorithm"] = {"method": quant_cfg["algorithm"]} + elif isinstance(quant_cfg["algorithm"], dict): + pass + else: + warnings.warn( + f"Quantization algorithm: {quant_cfg['algorithm']} does not support calib_exclude/include_modules" + ) + if isinstance(quant_cfg["algorithm"], dict): + if calib_exclude_modules: + quant_cfg["algorithm"]["calib_exclude_modules"] = calib_exclude_modules + if calib_include_modules: + quant_cfg["algorithm"]["calib_include_modules"] = calib_include_modules + # Gemma 7B has accuracy regression using alpha 1. We set 0.5 instead. if model_type == "gemma" and "int8_sq" in qformat: quant_cfg["algorithm"] = {"method": "smoothquant", "alpha": 0.5} diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index fd35a53f27..ef438374e6 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -925,6 +925,8 @@ def quantize_main( QUANT_CFG_CHOICES, KV_QUANT_CFG_CHOICES, args.moe_calib_experts_ratio, + calib_exclude_modules=args.calib_exclude_modules, + calib_include_modules=args.calib_include_modules, ) # Exclude MTP layers from quantization if detected (e.g., GLM-4.7's layer 92) @@ -1165,6 +1167,26 @@ def parse_args() -> argparse.Namespace: "Does not impact non-MOE models." ), ) + parser.add_argument( + "--calib_exclude_modules", + type=str, + default=None, + help=( + "Comma-separated list of fnmatch patterns for modules to exclude from calibration. " + "Matching modules retain their pre-existing calibration state. " + "Example: --calib_exclude_modules '*lm_head*,*vision*'" + ), + ) + parser.add_argument( + "--calib_include_modules", + type=str, + default=None, + help=( + "Comma-separated list of fnmatch patterns for modules to include in calibration. " + "Only matching modules are calibrated; all others are skipped. " + "Example: --calib_include_modules '*layers.0*,*layers.1*'" + ), + ) args = parser.parse_args() if not (0.0 < args.moe_calib_experts_ratio <= 1.0): @@ -1224,4 +1246,14 @@ def main(args: argparse.Namespace): args.dataset = args.dataset.split(",") if isinstance(args.dataset, str) else args.dataset args.calib_size = [int(num_sample) for num_sample in args.calib_size.split(",")] + args.calib_exclude_modules = ( + [p.strip() for p in args.calib_exclude_modules.split(",")] + if args.calib_exclude_modules + else None + ) + args.calib_include_modules = ( + [p.strip() for p in args.calib_include_modules.split(",")] + if args.calib_include_modules + else None + ) main(args) diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index 140d7c080a..ffc5ab28a4 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -1084,6 +1084,25 @@ class QuantizeAlgorithmConfig(ModeloptBaseConfig): ), ) + calib_include_modules: list[str] | None = ModeloptField( + default=None, + title="Patterns of modules to include in calibration.", + description=( + "If provided, only modules whose names match at least one of the fnmatch patterns are " + "calibrated. Modules that do not match any pattern are skipped and retain their " + "pre-existing calibration state." + ), + ) + + calib_exclude_modules: list[str] | None = ModeloptField( + default=None, + title="Patterns of modules to exclude from calibration.", + description=( + "If provided, modules whose names match at least one of the fnmatch patterns are " + "skipped during calibration and retain their pre-existing calibration state." + ), + ) + class MaxCalibConfig(QuantizeAlgorithmConfig): """The config for max calibration algorithm. diff --git a/modelopt/torch/quantization/mode.py b/modelopt/torch/quantization/mode.py index 1fbe654068..de7161e6a0 100644 --- a/modelopt/torch/quantization/mode.py +++ b/modelopt/torch/quantization/mode.py @@ -59,6 +59,7 @@ ) from .model_calib import ( awq, + filter_calib_modules, gptq_lite, local_hessian_calibrate, max_calibrate, @@ -223,6 +224,8 @@ def wrapped_calib_func( kwargs = config.model_dump() method = kwargs.pop("method") sequential = kwargs.pop("use_sequential", False) + calib_include_modules = kwargs.pop("calib_include_modules", None) + calib_exclude_modules = kwargs.pop("calib_exclude_modules", None) if method is not None and "awq" in method: # For backward compatibility kwargs["algorithm"] = method @@ -243,22 +246,23 @@ def wrapped_calib_func( module._moe_count_expert_calib_tokens = True if func is not None: - if sequential: - if forward_loop is None: - raise ValueError("forward_loop is required for calibration but got None.") - assert method in ["max"], ( - f"Sequential calibration currently only supports max calibration, got {method}" - ) - # Wrap with sequential processing - sequential_calibrate( - model, - forward_loop=forward_loop, - calib_func=func, - **kwargs, - ) - else: - # Direct calibration (existing behavior) - func(model, forward_loop=forward_loop, **kwargs) + with filter_calib_modules(model, calib_include_modules, calib_exclude_modules): + if sequential: + if forward_loop is None: + raise ValueError("forward_loop is required for calibration but got None.") + assert method in ["max"], ( + f"Sequential calibration currently only supports max calibration, got {method}" + ) + # Wrap with sequential processing + sequential_calibrate( + model, + forward_loop=forward_loop, + calib_func=func, + **kwargs, + ) + else: + # Direct calibration (existing behavior) + func(model, forward_loop=forward_loop, **kwargs) # Lets get the latest metadata for the quantizer states metadata = {} diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index 5618fa413f..caa6cf95bb 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -15,6 +15,8 @@ """Calibration utilities.""" +import contextlib +import fnmatch import math import os import warnings @@ -56,6 +58,7 @@ __all__ = [ "awq", + "filter_calib_modules", "local_hessian_calibrate", "max_calibrate", "sequential_calibrate", @@ -64,6 +67,57 @@ ] +@contextlib.contextmanager +def filter_calib_modules( + model: nn.Module, + include_modules: list[str] | None = None, + exclude_modules: list[str] | None = None, +): + """Context manager to restrict calibration to a subset of the model's modules. + + Temporarily disables quantizers in modules that do not pass the include/exclude filters. + Disabled quantizers retain their pre-existing ``_amax`` values because + :meth:`TensorQuantizer.disable` does not clear ``_amax``. + + Args: + model: The quantized model. + include_modules: If provided, only modules whose names match at least one fnmatch pattern + are calibrated. All others are skipped. + exclude_modules: If provided, modules whose names match at least one fnmatch pattern are + skipped. + + Example:: + + with filter_calib_modules(model, exclude_modules=["*lm_head*"]): + mse_calibrate(model, forward_loop) + """ + if include_modules is None and exclude_modules is None: + yield + return + + def _should_calibrate(name: str) -> bool: + if include_modules is not None: + if not any(fnmatch.fnmatch(name, p) for p in include_modules): + return False + if exclude_modules is not None: + if any(fnmatch.fnmatch(name, p) for p in exclude_modules): + return False + return True + + disabled = [] + for name, module in model.named_modules(): + if is_quantized_linear(module) and not _should_calibrate(name): + for _, child in module.named_modules(): + if isinstance(child, TensorQuantizer) and not child._disabled: + child.disable() + disabled.append(child) + try: + yield + finally: + for q in disabled: + q.enable() + + def weight_only_quantize(model: nn.Module): """Just quantize the weights of the model.""" name_to_module = dict(model.named_modules()) diff --git a/tests/unit/torch/quantization/test_calib.py b/tests/unit/torch/quantization/test_calib.py index 7bc78c40ec..a390e4030a 100644 --- a/tests/unit/torch/quantization/test_calib.py +++ b/tests/unit/torch/quantization/test_calib.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Unittests for AWQ and SVDQuant""" +"""Unittests for AWQ, SVDQuant, and calibration module filtering""" from functools import partial @@ -26,6 +26,8 @@ from modelopt.torch.quantization.model_calib import ( apply_pre_quant_scale_and_smooth, disable_pre_quant_scale_and_resmooth, + filter_calib_modules, + max_calibrate, ) from modelopt.torch.quantization.nn import TensorQuantizer @@ -375,3 +377,168 @@ def test_svdquant_lora_weights(): module.weight_quantizer.svdquant_lora_b @ module.weight_quantizer.svdquant_lora_a ) assert lora_residual.shape == module.weight.shape + + +# --------------------------------------------------------------------------- +# Tests for filter_calib_modules / include_modules / exclude_modules +# --------------------------------------------------------------------------- + +INT8_CFG = mtq.INT8_DEFAULT_CFG + + +def _make_quantized_mlp(): + """Return a freshly-created quantized _SimpleMLP with max calibration applied.""" + torch.manual_seed(42) + model = _SimpleMLP() + data = [torch.randn(4, 16)] + model = mtq.quantize(model, INT8_CFG, partial(forward_loop, dataloader=data)) + return model, data + + +def _get_weight_amax(model, layer_name: str) -> torch.Tensor: + """Return a copy of the weight_quantizer amax for the named linear layer.""" + module = dict(model.named_modules())[layer_name] + return module.weight_quantizer._amax.clone() + + +def test_mse_calibrate_exclude_modules(): + """MSE calibration with exclude_modules leaves excluded layer amax unchanged.""" + model, data = _make_quantized_mlp() + + # Record the amax of the excluded layer (net.4) before re-calibrating + amax_net4_before = _get_weight_amax(model, "net.4") + + mtq.calibrate( + model, + algorithm={"method": "mse", "calib_exclude_modules": ["*net.4*"]}, + forward_loop=partial(forward_loop, dataloader=data), + ) + + # net.4 should be untouched + assert torch.allclose(amax_net4_before, _get_weight_amax(model, "net.4")), ( + "Excluded module net.4 should have unchanged amax" + ) + + # net.0 and net.2 should still have valid amaxes (were re-calibrated) + assert dict(model.named_modules())["net.0"].weight_quantizer._amax is not None + assert dict(model.named_modules())["net.2"].weight_quantizer._amax is not None + + +def test_mse_calibrate_include_modules(): + """MSE calibration with calib_include_modules leaves non-included layer amaxes unchanged.""" + model, data = _make_quantized_mlp() + + # Record amaxes of non-included layers + amax_net2_before = _get_weight_amax(model, "net.2") + amax_net4_before = _get_weight_amax(model, "net.4") + + mtq.calibrate( + model, + algorithm={"method": "mse", "calib_include_modules": ["net.0"]}, + forward_loop=partial(forward_loop, dataloader=data), + ) + + # Non-included layers should be untouched + assert torch.allclose(amax_net2_before, _get_weight_amax(model, "net.2")), ( + "Non-included module net.2 should have unchanged amax" + ) + assert torch.allclose(amax_net4_before, _get_weight_amax(model, "net.4")), ( + "Non-included module net.4 should have unchanged amax" + ) + + # net.0 should have a valid amax (was calibrated) + assert dict(model.named_modules())["net.0"].weight_quantizer._amax is not None + + +def test_filter_no_op_when_none(): + """filter_calib_modules with both args None is a no-op context manager.""" + model, data = _make_quantized_mlp() + + # Record all amaxes + amaxes_before = {name: _get_weight_amax(model, name) for name in ["net.0", "net.2", "net.4"]} + + # Calling filter_calib_modules with None args and then re-running max_calibrate + # should behave identically to running max_calibrate directly. + with filter_calib_modules(model, include_modules=None, exclude_modules=None): + max_calibrate(model, partial(forward_loop, dataloader=data)) + + # All quantizers should still have valid amaxes + for name in ["net.0", "net.2", "net.4"]: + assert dict(model.named_modules())[name].weight_quantizer._amax is not None + + # Amaxes should be consistent with standard max calibration (not None) + for name in amaxes_before: + amax_after = _get_weight_amax(model, name) + assert amax_after is not None, f"{name} should have a valid amax after calibration" + + +def test_smoothquant_include_modules(): + """smoothquant with include_modules only applies to matching layers.""" + torch.manual_seed(42) + model = _SimpleMLP() + data = [torch.randn(4, 16)] + + # Use algorithm=None so quantizers are inserted but no calibration is run. + # This avoids stale calibrator state when smoothquant later changes axis to -1. + no_calib_cfg = {**INT8_CFG, "algorithm": None} + model = mtq.quantize(model, no_calib_cfg, forward_loop=None) + + mtq.calibrate( + model, + algorithm={"method": "smoothquant", "calib_include_modules": ["*net.0*"]}, + forward_loop=partial(forward_loop, dataloader=data), + ) + + # net.0 should have _pre_quant_scale (was smoothed) + net0 = dict(model.named_modules())["net.0"] + assert hasattr(net0.input_quantizer, "_pre_quant_scale"), "net.0 should have been smoothed" + + # net.2 and net.4 should NOT have _pre_quant_scale (were excluded by filter) + for name in ["net.2", "net.4"]: + mod = dict(model.named_modules())[name] + assert not hasattr(mod.input_quantizer, "_pre_quant_scale"), ( + f"{name} should not have been smoothed" + ) + + +def test_filter_via_config_api(): + """exclude_modules passed through the calibrate() config dict API works correctly.""" + model, data = _make_quantized_mlp() + + # Record amax of excluded layer + amax_net4_before = _get_weight_amax(model, "net.4") + + mtq.calibrate( + model, + algorithm={"method": "mse", "calib_exclude_modules": ["*net.4*"]}, + forward_loop=partial(forward_loop, dataloader=data), + ) + + # net.4 should be untouched + assert torch.allclose(amax_net4_before, _get_weight_amax(model, "net.4")), ( + "Excluded module net.4 should have unchanged amax via config API" + ) + + +def test_wildcard_pattern_matching(): + """fnmatch bracket patterns correctly include/exclude specific layers.""" + model, data = _make_quantized_mlp() + + # Record amax of non-matched layer + amax_net4_before = _get_weight_amax(model, "net.4") + + # "net.[02]" matches net.0 and net.2 but NOT net.4 + mtq.calibrate( + model, + algorithm={"method": "mse", "calib_include_modules": ["net.[02]"]}, + forward_loop=partial(forward_loop, dataloader=data), + ) + + # net.4 should be untouched + assert torch.allclose(amax_net4_before, _get_weight_amax(model, "net.4")), ( + "net.4 should not be affected by include_modules=['net.[02]']" + ) + + # net.0 and net.2 should have valid amaxes + assert dict(model.named_modules())["net.0"].weight_quantizer._amax is not None + assert dict(model.named_modules())["net.2"].weight_quantizer._amax is not None From 378d5248c7f6a5ab9bc4f3d9d87473ff0aec14b0 Mon Sep 17 00:00:00 2001 From: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com> Date: Mon, 16 Mar 2026 22:56:02 +0000 Subject: [PATCH 2/4] Address PR review comments for calib_include/exclude_modules - Fix shared preset mutation in build_quant_cfg by always deep-copying the preset dict before modification (previously only awq path did this) - Document linear-only filtering limitation in filter_calib_modules docstring and calib_include/exclude_modules field descriptions - Filter empty strings from CLI pattern parsing in hf_ptq.py to handle trailing commas gracefully - Strengthen test_filter_no_op_when_none to assert amax value equality rather than just non-None presence Co-Authored-By: Claude Sonnet 4.6 Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com> --- examples/llm_ptq/example_utils.py | 3 +-- examples/llm_ptq/hf_ptq.py | 4 ++-- modelopt/torch/quantization/config.py | 8 ++++++-- modelopt/torch/quantization/model_calib.py | 5 +++++ tests/unit/torch/quantization/test_calib.py | 8 +++++--- 5 files changed, 19 insertions(+), 9 deletions(-) diff --git a/examples/llm_ptq/example_utils.py b/examples/llm_ptq/example_utils.py index db377c51fc..1d6942bc73 100755 --- a/examples/llm_ptq/example_utils.py +++ b/examples/llm_ptq/example_utils.py @@ -210,10 +210,9 @@ def build_quant_cfg( f"Unsupported quantization format: {qformat} with {kv_cache_qformat} KV cache" ) - quant_cfg = quant_cfg_choices[qformat] + quant_cfg = copy.deepcopy(quant_cfg_choices[qformat]) if "awq" in qformat: - quant_cfg = copy.deepcopy(quant_cfg_choices[qformat]) weight_quantizer = quant_cfg["quant_cfg"]["*weight_quantizer"] if isinstance(weight_quantizer, list): weight_quantizer = weight_quantizer[0] diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index ef438374e6..736732d689 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -1247,12 +1247,12 @@ def main(args: argparse.Namespace): args.dataset = args.dataset.split(",") if isinstance(args.dataset, str) else args.dataset args.calib_size = [int(num_sample) for num_sample in args.calib_size.split(",")] args.calib_exclude_modules = ( - [p.strip() for p in args.calib_exclude_modules.split(",")] + [p.strip() for p in args.calib_exclude_modules.split(",") if p.strip()] if args.calib_exclude_modules else None ) args.calib_include_modules = ( - [p.strip() for p in args.calib_include_modules.split(",")] + [p.strip() for p in args.calib_include_modules.split(",") if p.strip()] if args.calib_include_modules else None ) diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py index ffc5ab28a4..844ddf7d68 100644 --- a/modelopt/torch/quantization/config.py +++ b/modelopt/torch/quantization/config.py @@ -1090,7 +1090,9 @@ class QuantizeAlgorithmConfig(ModeloptBaseConfig): description=( "If provided, only modules whose names match at least one of the fnmatch patterns are " "calibrated. Modules that do not match any pattern are skipped and retain their " - "pre-existing calibration state." + "pre-existing calibration state. " + "Note: filtering applies only to quantized linear modules; TensorQuantizers in " + "non-linear modules (e.g. layer norms, embeddings) are unaffected." ), ) @@ -1099,7 +1101,9 @@ class QuantizeAlgorithmConfig(ModeloptBaseConfig): title="Patterns of modules to exclude from calibration.", description=( "If provided, modules whose names match at least one of the fnmatch patterns are " - "skipped during calibration and retain their pre-existing calibration state." + "skipped during calibration and retain their pre-existing calibration state. " + "Note: filtering applies only to quantized linear modules; TensorQuantizers in " + "non-linear modules (e.g. layer norms, embeddings) are unaffected." ), ) diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py index caa6cf95bb..8cdb6366de 100644 --- a/modelopt/torch/quantization/model_calib.py +++ b/modelopt/torch/quantization/model_calib.py @@ -86,6 +86,11 @@ def filter_calib_modules( exclude_modules: If provided, modules whose names match at least one fnmatch pattern are skipped. + Note: + Only quantized linear modules (as identified by :func:`is_quantized_linear`) are filtered. + ``TensorQuantizer`` instances inside non-linear quantized modules (e.g. layer norms, + embeddings) are not disabled even if their module name matches a pattern. + Example:: with filter_calib_modules(model, exclude_modules=["*lm_head*"]): diff --git a/tests/unit/torch/quantization/test_calib.py b/tests/unit/torch/quantization/test_calib.py index a390e4030a..7b2eabb343 100644 --- a/tests/unit/torch/quantization/test_calib.py +++ b/tests/unit/torch/quantization/test_calib.py @@ -466,10 +466,12 @@ def test_filter_no_op_when_none(): for name in ["net.0", "net.2", "net.4"]: assert dict(model.named_modules())[name].weight_quantizer._amax is not None - # Amaxes should be consistent with standard max calibration (not None) - for name in amaxes_before: + # Amaxes should be identical to those computed without filter_calib_modules + for name, amax_before in amaxes_before.items(): amax_after = _get_weight_amax(model, name) - assert amax_after is not None, f"{name} should have a valid amax after calibration" + assert torch.allclose(amax_before, amax_after), ( + f"{name} amax changed unexpectedly when filter_calib_modules args are None" + ) def test_smoothquant_include_modules(): From 22bd94c96d17f222b162ea9332031ff97acf8f72 Mon Sep 17 00:00:00 2001 From: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com> Date: Mon, 16 Mar 2026 23:15:04 +0000 Subject: [PATCH 3/4] Remove CLI surface for calib_include/exclude_modules Users should set calib_include_modules / calib_exclude_modules directly in the algorithm dict of their quantization config rather than via dedicated CLI flags. Remove --calib_exclude_modules / --calib_include_modules from hf_ptq.py and the corresponding parameters from build_quant_cfg. Update test_filter_via_config_api to exercise the intended usage path: embedding both fields in the algorithm dict and calling mtq.quantize, covering exclude and include variants and asserting that uncalibrated module _amax buffers are absent. Co-Authored-By: Claude Sonnet 4.6 Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com> --- examples/llm_ptq/example_utils.py | 17 -------- examples/llm_ptq/hf_ptq.py | 33 --------------- tests/unit/torch/quantization/test_calib.py | 45 ++++++++++++++++----- 3 files changed, 34 insertions(+), 61 deletions(-) diff --git a/examples/llm_ptq/example_utils.py b/examples/llm_ptq/example_utils.py index 1d6942bc73..f42cd36bcc 100755 --- a/examples/llm_ptq/example_utils.py +++ b/examples/llm_ptq/example_utils.py @@ -202,8 +202,6 @@ def build_quant_cfg( quant_cfg_choices, kv_quant_cfg_choices, moe_calib_experts_ratio: float | None = None, - calib_exclude_modules: list[str] | None = None, - calib_include_modules: list[str] | None = None, ) -> dict[str, Any]: quant_cfg = {} assert qformat in quant_cfg_choices, ( @@ -248,21 +246,6 @@ def build_quant_cfg( f"Quantization algorithm: {quant_cfg['algorithm']} does not support setting moe_calib_experts_ratio" ) - if calib_exclude_modules or calib_include_modules: - if isinstance(quant_cfg["algorithm"], str): - quant_cfg["algorithm"] = {"method": quant_cfg["algorithm"]} - elif isinstance(quant_cfg["algorithm"], dict): - pass - else: - warnings.warn( - f"Quantization algorithm: {quant_cfg['algorithm']} does not support calib_exclude/include_modules" - ) - if isinstance(quant_cfg["algorithm"], dict): - if calib_exclude_modules: - quant_cfg["algorithm"]["calib_exclude_modules"] = calib_exclude_modules - if calib_include_modules: - quant_cfg["algorithm"]["calib_include_modules"] = calib_include_modules - # Gemma 7B has accuracy regression using alpha 1. We set 0.5 instead. if model_type == "gemma" and "int8_sq" in qformat: quant_cfg["algorithm"] = {"method": "smoothquant", "alpha": 0.5} diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index 736732d689..55bebf05a4 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -925,8 +925,6 @@ def quantize_main( QUANT_CFG_CHOICES, KV_QUANT_CFG_CHOICES, args.moe_calib_experts_ratio, - calib_exclude_modules=args.calib_exclude_modules, - calib_include_modules=args.calib_include_modules, ) # Exclude MTP layers from quantization if detected (e.g., GLM-4.7's layer 92) @@ -1167,27 +1165,6 @@ def parse_args() -> argparse.Namespace: "Does not impact non-MOE models." ), ) - parser.add_argument( - "--calib_exclude_modules", - type=str, - default=None, - help=( - "Comma-separated list of fnmatch patterns for modules to exclude from calibration. " - "Matching modules retain their pre-existing calibration state. " - "Example: --calib_exclude_modules '*lm_head*,*vision*'" - ), - ) - parser.add_argument( - "--calib_include_modules", - type=str, - default=None, - help=( - "Comma-separated list of fnmatch patterns for modules to include in calibration. " - "Only matching modules are calibrated; all others are skipped. " - "Example: --calib_include_modules '*layers.0*,*layers.1*'" - ), - ) - args = parser.parse_args() if not (0.0 < args.moe_calib_experts_ratio <= 1.0): parser.error("--moe_calib_experts_ratio must be in the range (0.0, 1.0].") @@ -1246,14 +1223,4 @@ def main(args: argparse.Namespace): args.dataset = args.dataset.split(",") if isinstance(args.dataset, str) else args.dataset args.calib_size = [int(num_sample) for num_sample in args.calib_size.split(",")] - args.calib_exclude_modules = ( - [p.strip() for p in args.calib_exclude_modules.split(",") if p.strip()] - if args.calib_exclude_modules - else None - ) - args.calib_include_modules = ( - [p.strip() for p in args.calib_include_modules.split(",") if p.strip()] - if args.calib_include_modules - else None - ) main(args) diff --git a/tests/unit/torch/quantization/test_calib.py b/tests/unit/torch/quantization/test_calib.py index 7b2eabb343..d13eea2319 100644 --- a/tests/unit/torch/quantization/test_calib.py +++ b/tests/unit/torch/quantization/test_calib.py @@ -15,6 +15,7 @@ """Unittests for AWQ, SVDQuant, and calibration module filtering""" +import copy from functools import partial import torch @@ -504,21 +505,43 @@ def test_smoothquant_include_modules(): def test_filter_via_config_api(): - """exclude_modules passed through the calibrate() config dict API works correctly.""" - model, data = _make_quantized_mlp() + """calib_exclude/include_modules embedded in the algorithm config dict work end-to-end. - # Record amax of excluded layer - amax_net4_before = _get_weight_amax(model, "net.4") + This is the intended usage: users set these fields directly in the algorithm dict of their + quantization config rather than passing them as separate CLI arguments. + """ + torch.manual_seed(42) + model = _SimpleMLP() + data = [torch.randn(4, 16)] - mtq.calibrate( - model, - algorithm={"method": "mse", "calib_exclude_modules": ["*net.4*"]}, - forward_loop=partial(forward_loop, dataloader=data), + # Intended usage: embed calib_exclude_modules in the algorithm dict of the quant config. + quant_cfg = copy.deepcopy(INT8_CFG) + quant_cfg["algorithm"] = {"method": "max", "calib_exclude_modules": ["*net.4*"]} + model = mtq.quantize(model, quant_cfg, partial(forward_loop, dataloader=data)) + + modules = dict(model.named_modules()) + # net.0 and net.2 were calibrated — _amax buffer should be registered. + assert hasattr(modules["net.0"].weight_quantizer, "_amax") + assert hasattr(modules["net.2"].weight_quantizer, "_amax") + # net.4 was excluded — calibrator never ran, so _amax buffer is absent. + assert not hasattr(modules["net.4"].weight_quantizer, "_amax"), ( + "net.4 was excluded from calibration so its _amax buffer should not be registered" ) - # net.4 should be untouched - assert torch.allclose(amax_net4_before, _get_weight_amax(model, "net.4")), ( - "Excluded module net.4 should have unchanged amax via config API" + # include_modules variant: only net.0 is calibrated. + torch.manual_seed(42) + model2 = _SimpleMLP() + quant_cfg2 = copy.deepcopy(INT8_CFG) + quant_cfg2["algorithm"] = {"method": "max", "calib_include_modules": ["net.0"]} + model2 = mtq.quantize(model2, quant_cfg2, partial(forward_loop, dataloader=data)) + + modules2 = dict(model2.named_modules()) + assert hasattr(modules2["net.0"].weight_quantizer, "_amax") + assert not hasattr(modules2["net.2"].weight_quantizer, "_amax"), ( + "net.2 was not included in calibration so its _amax buffer should not be registered" + ) + assert not hasattr(modules2["net.4"].weight_quantizer, "_amax"), ( + "net.4 was not included in calibration so its _amax buffer should not be registered" ) From 5a8c907d47d963c641b1cf51f7626ffb418cd57d Mon Sep 17 00:00:00 2001 From: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com> Date: Mon, 16 Mar 2026 23:17:59 +0000 Subject: [PATCH 4/4] Revert example_utils.py and hf_ptq.py to pre-PR state calib_include/exclude_modules is a core library feature accessed via the algorithm config dict; example scripts should not be modified. Co-Authored-By: Claude Sonnet 4.6 Signed-off-by: Fridah-nv <201670829+Fridah-nv@users.noreply.github.com> --- examples/llm_ptq/example_utils.py | 3 ++- examples/llm_ptq/hf_ptq.py | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/examples/llm_ptq/example_utils.py b/examples/llm_ptq/example_utils.py index f42cd36bcc..50ac51aace 100755 --- a/examples/llm_ptq/example_utils.py +++ b/examples/llm_ptq/example_utils.py @@ -208,9 +208,10 @@ def build_quant_cfg( f"Unsupported quantization format: {qformat} with {kv_cache_qformat} KV cache" ) - quant_cfg = copy.deepcopy(quant_cfg_choices[qformat]) + quant_cfg = quant_cfg_choices[qformat] if "awq" in qformat: + quant_cfg = copy.deepcopy(quant_cfg_choices[qformat]) weight_quantizer = quant_cfg["quant_cfg"]["*weight_quantizer"] if isinstance(weight_quantizer, list): weight_quantizer = weight_quantizer[0] diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index 55bebf05a4..fd35a53f27 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -1165,6 +1165,7 @@ def parse_args() -> argparse.Namespace: "Does not impact non-MOE models." ), ) + args = parser.parse_args() if not (0.0 < args.moe_calib_experts_ratio <= 1.0): parser.error("--moe_calib_experts_ratio must be in the range (0.0, 1.0].")