diff --git a/modelopt/torch/quantization/config.py b/modelopt/torch/quantization/config.py
index 140d7c080..844ddf7d6 100644
--- a/modelopt/torch/quantization/config.py
+++ b/modelopt/torch/quantization/config.py
@@ -1084,6 +1084,31 @@ class QuantizeAlgorithmConfig(ModeloptBaseConfig):
         ),
     )
 
+    calib_include_modules: list[str] | None = ModeloptField(
+        default=None,
+        title="Patterns of modules to include in calibration.",
+        description=(
+            "If provided, only modules whose names match at least one of the fnmatch patterns are "
+            "calibrated. Modules that do not match any pattern are skipped and retain their "
+            "pre-existing calibration state. "
+            "If calib_exclude_modules is also provided, a module that matches both is skipped. "
+            "Note: filtering applies only to quantized linear modules; TensorQuantizers in "
+            "non-linear modules (e.g. layer norms, embeddings) are unaffected."
+        ),
+    )
+
+    calib_exclude_modules: list[str] | None = ModeloptField(
+        default=None,
+        title="Patterns of modules to exclude from calibration.",
+        description=(
+            "If provided, modules whose names match at least one of the fnmatch patterns are "
+            "skipped during calibration and retain their pre-existing calibration state. "
+            "Exclusion takes precedence over calib_include_modules. "
+            "Note: filtering applies only to quantized linear modules; TensorQuantizers in "
+            "non-linear modules (e.g. layer norms, embeddings) are unaffected."
+        ),
+    )
+
 
 class MaxCalibConfig(QuantizeAlgorithmConfig):
     """The config for max calibration algorithm.
diff --git a/modelopt/torch/quantization/mode.py b/modelopt/torch/quantization/mode.py
index 1fbe65406..de7161e6a 100644
--- a/modelopt/torch/quantization/mode.py
+++ b/modelopt/torch/quantization/mode.py
@@ -59,6 +59,7 @@
 )
 from .model_calib import (
     awq,
+    filter_calib_modules,
     gptq_lite,
     local_hessian_calibrate,
     max_calibrate,
@@ -223,6 +224,8 @@ def wrapped_calib_func(
     kwargs = config.model_dump()
     method = kwargs.pop("method")
     sequential = kwargs.pop("use_sequential", False)
+    calib_include_modules = kwargs.pop("calib_include_modules", None)
+    calib_exclude_modules = kwargs.pop("calib_exclude_modules", None)
     if method is not None and "awq" in method:
         # For backward compatibility
         kwargs["algorithm"] = method
@@ -243,22 +246,23 @@ def wrapped_calib_func(
             module._moe_count_expert_calib_tokens = True
 
     if func is not None:
-        if sequential:
-            if forward_loop is None:
-                raise ValueError("forward_loop is required for calibration but got None.")
-            assert method in ["max"], (
-                f"Sequential calibration currently only supports max calibration, got {method}"
-            )
-            # Wrap with sequential processing
-            sequential_calibrate(
-                model,
-                forward_loop=forward_loop,
-                calib_func=func,
-                **kwargs,
-            )
-        else:
-            # Direct calibration (existing behavior)
-            func(model, forward_loop=forward_loop, **kwargs)
+        with filter_calib_modules(model, calib_include_modules, calib_exclude_modules):
+            if sequential:
+                if forward_loop is None:
+                    raise ValueError("forward_loop is required for calibration but got None.")
+                assert method in ["max"], (
+                    f"Sequential calibration currently only supports max calibration, got {method}"
+                )
+                # Wrap with sequential processing
+                sequential_calibrate(
+                    model,
+                    forward_loop=forward_loop,
+                    calib_func=func,
+                    **kwargs,
+                )
+            else:
+                # Direct calibration (existing behavior)
+                func(model, forward_loop=forward_loop, **kwargs)
 
     # Lets get the latest metadata for the quantizer states
     metadata = {}
diff --git a/modelopt/torch/quantization/model_calib.py b/modelopt/torch/quantization/model_calib.py
index 5618fa413..8cdb6366d 100644
--- a/modelopt/torch/quantization/model_calib.py
+++ b/modelopt/torch/quantization/model_calib.py
@@ -15,6 +15,8 @@
 
 """Calibration utilities."""
 
+import contextlib
+import fnmatch
 import math
 import os
 import warnings
@@ -56,6 +58,7 @@
 
 __all__ = [
     "awq",
+    "filter_calib_modules",
     "local_hessian_calibrate",
     "max_calibrate",
     "sequential_calibrate",
@@ -64,6 +67,74 @@
 ]
 
 
+@contextlib.contextmanager
+def filter_calib_modules(
+    model: nn.Module,
+    include_modules: list[str] | None = None,
+    exclude_modules: list[str] | None = None,
+):
+    """Context manager to restrict calibration to a subset of the model's modules.
+
+    Temporarily disables quantizers in modules that do not pass the include/exclude filters
+    and re-enables exactly those quantizers on exit, including when the body raises.
+    Quantizers that are already disabled on entry are left untouched.
+    Disabled quantizers retain their pre-existing ``_amax`` values because
+    :meth:`TensorQuantizer.disable` does not clear ``_amax``.
+
+    Args:
+        model: The quantized model. Patterns are matched against the names returned by
+            ``model.named_modules()``.
+        include_modules: If provided, only modules whose names match at least one fnmatch pattern
+            are calibrated. All others are skipped. A single pattern string is also accepted.
+        exclude_modules: If provided, modules whose names match at least one fnmatch pattern are
+            skipped, even if they also match ``include_modules``. A single pattern string is
+            also accepted.
+
+    Note:
+        Only quantized linear modules (as identified by :func:`is_quantized_linear`) are filtered.
+        ``TensorQuantizer`` instances inside non-linear quantized modules (e.g. layer norms,
+        embeddings) are not disabled even if their module name matches a pattern.
+
+    Example::
+
+        with filter_calib_modules(model, exclude_modules=["*lm_head*"]):
+            mse_calibrate(model, forward_loop)
+    """
+    if include_modules is None and exclude_modules is None:
+        yield
+        return
+
+    # A bare string would otherwise be iterated character by character by ``any`` below.
+    if isinstance(include_modules, str):
+        include_modules = [include_modules]
+    if isinstance(exclude_modules, str):
+        exclude_modules = [exclude_modules]
+
+    def _matches_any(name: str, patterns: list[str] | None) -> bool:
+        return patterns is not None and any(fnmatch.fnmatch(name, p) for p in patterns)
+
+    def _should_calibrate(name: str) -> bool:
+        if include_modules is not None and not _matches_any(name, include_modules):
+            return False
+        return not _matches_any(name, exclude_modules)
+
+    disabled = []
+    # Disable inside the ``try`` so that a failure part-way through still restores the
+    # quantizers that were already disabled.
+    try:
+        for name, module in model.named_modules():
+            if not is_quantized_linear(module) or _should_calibrate(name):
+                continue
+            for child in module.modules():
+                if isinstance(child, TensorQuantizer) and not child._disabled:
+                    child.disable()
+                    disabled.append(child)
+        yield
+    finally:
+        for quantizer in disabled:
+            quantizer.enable()
+
+
 def weight_only_quantize(model: nn.Module):
     """Just quantize the weights of the model."""
     name_to_module = dict(model.named_modules())
diff --git a/tests/unit/torch/quantization/test_calib.py b/tests/unit/torch/quantization/test_calib.py
index 7bc78c40e..d13eea231 100644
--- a/tests/unit/torch/quantization/test_calib.py
+++ b/tests/unit/torch/quantization/test_calib.py
@@ -13,8 +13,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-"""Unittests for AWQ and SVDQuant"""
+"""Unittests for AWQ, SVDQuant, and calibration module filtering"""
 
+import copy
 from functools import partial
 
 import torch
@@ -26,6 +27,8 @@
 from modelopt.torch.quantization.model_calib import (
     apply_pre_quant_scale_and_smooth,
     disable_pre_quant_scale_and_resmooth,
+    filter_calib_modules,
+    max_calibrate,
 )
 from modelopt.torch.quantization.nn import TensorQuantizer
 
@@ -375,3 +378,215 @@ def test_svdquant_lora_weights():
         module.weight_quantizer.svdquant_lora_b @ module.weight_quantizer.svdquant_lora_a
     )
     assert lora_residual.shape == module.weight.shape
+
+
+# ---------------------------------------------------------------------------
+# Tests for filter_calib_modules / include_modules / exclude_modules
+# ---------------------------------------------------------------------------
+
+INT8_CFG = mtq.INT8_DEFAULT_CFG
+
+
+def _make_quantized_mlp():
+    """Return a freshly-created quantized _SimpleMLP with max calibration applied."""
+    torch.manual_seed(42)
+    model = _SimpleMLP()
+    data = [torch.randn(4, 16)]
+    model = mtq.quantize(model, INT8_CFG, partial(forward_loop, dataloader=data))
+    return model, data
+
+
+def _get_weight_amax(model, layer_name: str) -> torch.Tensor:
+    """Return a copy of the weight_quantizer amax for the named linear layer."""
+    module = dict(model.named_modules())[layer_name]
+    return module.weight_quantizer._amax.clone()
+
+
+def test_mse_calibrate_exclude_modules():
+    """MSE calibration with exclude_modules leaves excluded layer amax unchanged."""
+    model, data = _make_quantized_mlp()
+
+    # Record the amax of the excluded layer (net.4) before re-calibrating
+    amax_net4_before = _get_weight_amax(model, "net.4")
+
+    mtq.calibrate(
+        model,
+        algorithm={"method": "mse", "calib_exclude_modules": ["*net.4*"]},
+        forward_loop=partial(forward_loop, dataloader=data),
+    )
+
+    # net.4 should be untouched
+    assert torch.allclose(amax_net4_before, _get_weight_amax(model, "net.4")), (
+        "Excluded module net.4 should have unchanged amax"
+    )
+
+    # net.0 and net.2 should still have valid amaxes (were re-calibrated)
+    assert dict(model.named_modules())["net.0"].weight_quantizer._amax is not None
+    assert dict(model.named_modules())["net.2"].weight_quantizer._amax is not None
+
+
+def test_mse_calibrate_include_modules():
+    """MSE calibration with calib_include_modules leaves non-included layer amaxes unchanged."""
+    model, data = _make_quantized_mlp()
+
+    # Record amaxes of non-included layers
+    amax_net2_before = _get_weight_amax(model, "net.2")
+    amax_net4_before = _get_weight_amax(model, "net.4")
+
+    mtq.calibrate(
+        model,
+        algorithm={"method": "mse", "calib_include_modules": ["net.0"]},
+        forward_loop=partial(forward_loop, dataloader=data),
+    )
+
+    # Non-included layers should be untouched
+    assert torch.allclose(amax_net2_before, _get_weight_amax(model, "net.2")), (
+        "Non-included module net.2 should have unchanged amax"
+    )
+    assert torch.allclose(amax_net4_before, _get_weight_amax(model, "net.4")), (
+        "Non-included module net.4 should have unchanged amax"
+    )
+
+    # net.0 should have a valid amax (was calibrated)
+    assert dict(model.named_modules())["net.0"].weight_quantizer._amax is not None
+
+
+def test_filter_no_op_when_none():
+    """filter_calib_modules with both args None is a no-op context manager."""
+    model, data = _make_quantized_mlp()
+
+    # Record all amaxes
+    amaxes_before = {name: _get_weight_amax(model, name) for name in ["net.0", "net.2", "net.4"]}
+
+    # Calling filter_calib_modules with None args and then re-running max_calibrate
+    # should behave identically to running max_calibrate directly.
+    with filter_calib_modules(model, include_modules=None, exclude_modules=None):
+        max_calibrate(model, partial(forward_loop, dataloader=data))
+
+    # All quantizers should still have valid amaxes
+    for name in ["net.0", "net.2", "net.4"]:
+        assert dict(model.named_modules())[name].weight_quantizer._amax is not None
+
+    # Amaxes should be identical to those computed without filter_calib_modules
+    for name, amax_before in amaxes_before.items():
+        amax_after = _get_weight_amax(model, name)
+        assert torch.allclose(amax_before, amax_after), (
+            f"{name} amax changed unexpectedly when filter_calib_modules args are None"
+        )
+
+
+def test_smoothquant_include_modules():
+    """smoothquant with include_modules only applies to matching layers."""
+    torch.manual_seed(42)
+    model = _SimpleMLP()
+    data = [torch.randn(4, 16)]
+
+    # Use algorithm=None so quantizers are inserted but no calibration is run.
+    # This avoids stale calibrator state when smoothquant later changes axis to -1.
+    no_calib_cfg = {**INT8_CFG, "algorithm": None}
+    model = mtq.quantize(model, no_calib_cfg, forward_loop=None)
+
+    mtq.calibrate(
+        model,
+        algorithm={"method": "smoothquant", "calib_include_modules": ["*net.0*"]},
+        forward_loop=partial(forward_loop, dataloader=data),
+    )
+
+    # net.0 should have _pre_quant_scale (was smoothed)
+    net0 = dict(model.named_modules())["net.0"]
+    assert hasattr(net0.input_quantizer, "_pre_quant_scale"), "net.0 should have been smoothed"
+
+    # net.2 and net.4 should NOT have _pre_quant_scale (were excluded by filter)
+    for name in ["net.2", "net.4"]:
+        mod = dict(model.named_modules())[name]
+        assert not hasattr(mod.input_quantizer, "_pre_quant_scale"), (
+            f"{name} should not have been smoothed"
+        )
+
+
+def test_filter_via_config_api():
+    """calib_exclude/include_modules embedded in the algorithm config dict work end-to-end.
+
+    This is the intended usage: users set these fields directly in the algorithm dict of their
+    quantization config rather than passing them as separate CLI arguments.
+    """
+    torch.manual_seed(42)
+    model = _SimpleMLP()
+    data = [torch.randn(4, 16)]
+
+    # Intended usage: embed calib_exclude_modules in the algorithm dict of the quant config.
+    quant_cfg = copy.deepcopy(INT8_CFG)
+    quant_cfg["algorithm"] = {"method": "max", "calib_exclude_modules": ["*net.4*"]}
+    model = mtq.quantize(model, quant_cfg, partial(forward_loop, dataloader=data))
+
+    modules = dict(model.named_modules())
+    # net.0 and net.2 were calibrated — _amax buffer should be registered.
+    assert hasattr(modules["net.0"].weight_quantizer, "_amax")
+    assert hasattr(modules["net.2"].weight_quantizer, "_amax")
+    # net.4 was excluded — calibrator never ran, so _amax buffer is absent.
+    assert not hasattr(modules["net.4"].weight_quantizer, "_amax"), (
+        "net.4 was excluded from calibration so its _amax buffer should not be registered"
+    )
+
+    # include_modules variant: only net.0 is calibrated.
+    torch.manual_seed(42)
+    model2 = _SimpleMLP()
+    quant_cfg2 = copy.deepcopy(INT8_CFG)
+    quant_cfg2["algorithm"] = {"method": "max", "calib_include_modules": ["net.0"]}
+    model2 = mtq.quantize(model2, quant_cfg2, partial(forward_loop, dataloader=data))
+
+    modules2 = dict(model2.named_modules())
+    assert hasattr(modules2["net.0"].weight_quantizer, "_amax")
+    assert not hasattr(modules2["net.2"].weight_quantizer, "_amax"), (
+        "net.2 was not included in calibration so its _amax buffer should not be registered"
+    )
+    assert not hasattr(modules2["net.4"].weight_quantizer, "_amax"), (
+        "net.4 was not included in calibration so its _amax buffer should not be registered"
+    )
+
+
+def test_wildcard_pattern_matching():
+    """fnmatch bracket patterns correctly include/exclude specific layers."""
+    model, data = _make_quantized_mlp()
+
+    # Record amax of non-matched layer
+    amax_net4_before = _get_weight_amax(model, "net.4")
+
+    # "net.[02]" matches net.0 and net.2 but NOT net.4
+    mtq.calibrate(
+        model,
+        algorithm={"method": "mse", "calib_include_modules": ["net.[02]"]},
+        forward_loop=partial(forward_loop, dataloader=data),
+    )
+
+    # net.4 should be untouched
+    assert torch.allclose(amax_net4_before, _get_weight_amax(model, "net.4")), (
+        "net.4 should not be affected by include_modules=['net.[02]']"
+    )
+
+    # net.0 and net.2 should have valid amaxes
+    assert dict(model.named_modules())["net.0"].weight_quantizer._amax is not None
+    assert dict(model.named_modules())["net.2"].weight_quantizer._amax is not None
+
+
+def test_filter_restores_quantizer_state():
+    """Quantizers disabled by filter_calib_modules are re-enabled on exit, even on error."""
+    model, _ = _make_quantized_mlp()
+    modules = dict(model.named_modules())
+    net0_quantizer = modules["net.0"].weight_quantizer
+    net4_quantizer = modules["net.4"].weight_quantizer
+    assert not net4_quantizer._disabled
+
+    with filter_calib_modules(model, exclude_modules=["net.4"]):
+        # Only the excluded layer is disabled while the context is active.
+        assert net4_quantizer._disabled
+        assert not net0_quantizer._disabled
+    assert not net4_quantizer._disabled
+
+    # The quantizer must also be restored when calibration fails inside the context.
+    try:
+        with filter_calib_modules(model, exclude_modules=["net.4"]):
+            raise RuntimeError("calibration failed")
+    except RuntimeError:
+        pass
+    assert not net4_quantizer._disabled