Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions modelopt/torch/quantization/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1084,6 +1084,29 @@ class QuantizeAlgorithmConfig(ModeloptBaseConfig):
),
)

calib_include_modules: list[str] | None = ModeloptField(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need the calib prefix? Isn't it obvious that this is the calib include_modules?

Suggested change
calib_include_modules: list[str] | None = ModeloptField(
include_modules: list[str] | None = ModeloptField(

default=None,
title="Patterns of modules to include in calibration.",
description=(
"If provided, only modules whose names match at least one of the fnmatch patterns are "
"calibrated. Modules that do not match any pattern are skipped and retain their "
"pre-existing calibration state. "
"Note: filtering applies only to quantized linear modules; TensorQuantizers in "
"non-linear modules (e.g. layer norms, embeddings) are unaffected."
),
)

calib_exclude_modules: list[str] | None = ModeloptField(
default=None,
title="Patterns of modules to exclude from calibration.",
description=(
"If provided, modules whose names match at least one of the fnmatch patterns are "
"skipped during calibration and retain their pre-existing calibration state. "
"Note: filtering applies only to quantized linear modules; TensorQuantizers in "
"non-linear modules (e.g. layer norms, embeddings) are unaffected."
),
)
Comment on lines +1087 to +1108
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If a user passes both calib_include_modules and calib_exclude_modules, the behavior is implicitly "include first, then exclude"? Do you think we need to either:

  • Documented explicitly (what happens if a module matches both?), or
  • Validated to raise an error if both are set simultaneously.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the ordering does not actually matter, but we do need to document a clear semantic. These 2 actually are 2 exclude module lists:

  1. exclude those that are not in the include_modules
  2. exclude those that are in the exclude_modules



class MaxCalibConfig(QuantizeAlgorithmConfig):
"""The config for max calibration algorithm.
Expand Down
36 changes: 20 additions & 16 deletions modelopt/torch/quantization/mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
)
from .model_calib import (
awq,
filter_calib_modules,
gptq_lite,
local_hessian_calibrate,
max_calibrate,
Expand Down Expand Up @@ -223,6 +224,8 @@ def wrapped_calib_func(
kwargs = config.model_dump()
method = kwargs.pop("method")
sequential = kwargs.pop("use_sequential", False)
calib_include_modules = kwargs.pop("calib_include_modules", None)
calib_exclude_modules = kwargs.pop("calib_exclude_modules", None)
if method is not None and "awq" in method:
# For backward compatibility
kwargs["algorithm"] = method
Expand All @@ -243,22 +246,23 @@ def wrapped_calib_func(
module._moe_count_expert_calib_tokens = True

if func is not None:
if sequential:
if forward_loop is None:
raise ValueError("forward_loop is required for calibration but got None.")
assert method in ["max"], (
f"Sequential calibration currently only supports max calibration, got {method}"
)
# Wrap with sequential processing
sequential_calibrate(
model,
forward_loop=forward_loop,
calib_func=func,
**kwargs,
)
else:
# Direct calibration (existing behavior)
func(model, forward_loop=forward_loop, **kwargs)
with filter_calib_modules(model, calib_include_modules, calib_exclude_modules):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to double check. If in the future we want to run multiple algorithms in the sequential flow, for example local_hessian followed by gptq will this work?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Another follow up question. Do we plan on running different calibration algorithms for different layers in the future? Can this context manager be helpful in that case?

if sequential:
if forward_loop is None:
raise ValueError("forward_loop is required for calibration but got None.")
assert method in ["max"], (
f"Sequential calibration currently only supports max calibration, got {method}"
)
# Wrap with sequential processing
sequential_calibrate(
model,
forward_loop=forward_loop,
calib_func=func,
**kwargs,
)
else:
# Direct calibration (existing behavior)
func(model, forward_loop=forward_loop, **kwargs)

# Lets get the latest metadata for the quantizer states
metadata = {}
Expand Down
59 changes: 59 additions & 0 deletions modelopt/torch/quantization/model_calib.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@

"""Calibration utilities."""

import contextlib
import fnmatch
import math
import os
import warnings
Expand Down Expand Up @@ -56,6 +58,7 @@

__all__ = [
"awq",
"filter_calib_modules",
"local_hessian_calibrate",
"max_calibrate",
"sequential_calibrate",
Expand All @@ -64,6 +67,62 @@
]


@contextlib.contextmanager
def filter_calib_modules(
    model: nn.Module,
    include_modules: list[str] | None = None,
    exclude_modules: list[str] | None = None,
):
    """Restrict calibration to a subset of the model's quantized linear modules.

    While the context is active, every ``TensorQuantizer`` found under a quantized
    linear module that fails the include/exclude filters is disabled; on exit the
    exact set of quantizers disabled here is re-enabled. Disabling does not clear
    pre-existing ``_amax`` values (:meth:`TensorQuantizer.disable` keeps them), so
    filtered-out modules retain their prior calibration state.

    Args:
        model: The quantized model.
        include_modules: fnmatch patterns; when given, only modules whose names
            match at least one pattern are calibrated. All others are skipped.
        exclude_modules: fnmatch patterns; modules whose names match at least one
            pattern are skipped.

    Note:
        Filtering only touches modules recognized by :func:`is_quantized_linear`.
        ``TensorQuantizer`` instances inside other quantized modules (e.g. layer
        norms, embeddings) are left enabled even if their names match a pattern.

    Example::

        with filter_calib_modules(model, exclude_modules=["*lm_head*"]):
            mse_calibrate(model, forward_loop)
    """
    # Fast path: no filters supplied means there is nothing to disable or restore.
    if include_modules is None and exclude_modules is None:
        yield
        return

    def _matches_any(name: str, patterns: list[str]) -> bool:
        return any(fnmatch.fnmatch(name, pattern) for pattern in patterns)

    def _is_filtered_out(name: str) -> bool:
        # Missing from the include list (when one is given) -> skip calibration.
        if include_modules is not None and not _matches_any(name, include_modules):
            return True
        # Present on the exclude list (when one is given) -> skip calibration.
        return exclude_modules is not None and _matches_any(name, exclude_modules)

    paused: list[TensorQuantizer] = []
    for name, module in model.named_modules():
        if not (is_quantized_linear(module) and _is_filtered_out(name)):
            continue
        for _, sub in module.named_modules():
            # Only flip quantizers that are currently enabled, so that exiting the
            # context restores exactly the prior enabled/disabled state.
            if isinstance(sub, TensorQuantizer) and not sub._disabled:
                sub.disable()
                paused.append(sub)
    try:
        yield
    finally:
        for quantizer in paused:
            quantizer.enable()


def weight_only_quantize(model: nn.Module):
"""Just quantize the weights of the model."""
name_to_module = dict(model.named_modules())
Expand Down
194 changes: 193 additions & 1 deletion tests/unit/torch/quantization/test_calib.py
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unittests for AWQ and SVDQuant"""
"""Unittests for AWQ, SVDQuant, and calibration module filtering"""

import copy
from functools import partial

import torch
Expand All @@ -26,6 +27,8 @@
from modelopt.torch.quantization.model_calib import (
apply_pre_quant_scale_and_smooth,
disable_pre_quant_scale_and_resmooth,
filter_calib_modules,
max_calibrate,
)
from modelopt.torch.quantization.nn import TensorQuantizer

Expand Down Expand Up @@ -375,3 +378,192 @@ def test_svdquant_lora_weights():
module.weight_quantizer.svdquant_lora_b @ module.weight_quantizer.svdquant_lora_a
)
assert lora_residual.shape == module.weight.shape


# ---------------------------------------------------------------------------
# Tests for filter_calib_modules / include_modules / exclude_modules
# ---------------------------------------------------------------------------

INT8_CFG = mtq.INT8_DEFAULT_CFG


def _make_quantized_mlp():
    """Build a seeded _SimpleMLP, INT8-quantize it with max calibration, return (model, data)."""
    torch.manual_seed(42)
    mlp = _SimpleMLP()
    calib_data = [torch.randn(4, 16)]
    quantized = mtq.quantize(mlp, INT8_CFG, partial(forward_loop, dataloader=calib_data))
    return quantized, calib_data


def _get_weight_amax(model, layer_name: str) -> torch.Tensor:
"""Return a copy of the weight_quantizer amax for the named linear layer."""
module = dict(model.named_modules())[layer_name]
return module.weight_quantizer._amax.clone()


def test_mse_calibrate_exclude_modules():
    """MSE calibration with exclude_modules leaves excluded layer amax unchanged."""
    model, data = _make_quantized_mlp()

    # Snapshot the excluded layer's amax so we can prove it survives re-calibration.
    excluded_amax = _get_weight_amax(model, "net.4")

    mtq.calibrate(
        model,
        algorithm={"method": "mse", "calib_exclude_modules": ["*net.4*"]},
        forward_loop=partial(forward_loop, dataloader=data),
    )

    # The excluded layer must keep its original calibration state.
    assert torch.allclose(excluded_amax, _get_weight_amax(model, "net.4")), (
        "Excluded module net.4 should have unchanged amax"
    )

    # The remaining linear layers were re-calibrated and still carry valid amaxes.
    named = dict(model.named_modules())
    for layer in ("net.0", "net.2"):
        assert named[layer].weight_quantizer._amax is not None


def test_mse_calibrate_include_modules():
    """MSE calibration with calib_include_modules leaves non-included layer amaxes unchanged."""
    model, data = _make_quantized_mlp()

    # Snapshot the layers that are NOT on the include list.
    untouched_before = {name: _get_weight_amax(model, name) for name in ("net.2", "net.4")}

    mtq.calibrate(
        model,
        algorithm={"method": "mse", "calib_include_modules": ["net.0"]},
        forward_loop=partial(forward_loop, dataloader=data),
    )

    # Layers outside the include list must retain their original amaxes.
    assert torch.allclose(untouched_before["net.2"], _get_weight_amax(model, "net.2")), (
        "Non-included module net.2 should have unchanged amax"
    )
    assert torch.allclose(untouched_before["net.4"], _get_weight_amax(model, "net.4")), (
        "Non-included module net.4 should have unchanged amax"
    )

    # The single included layer was calibrated and carries a valid amax.
    assert dict(model.named_modules())["net.0"].weight_quantizer._amax is not None


def test_filter_no_op_when_none():
    """filter_calib_modules with both args None is a no-op context manager."""
    model, data = _make_quantized_mlp()

    layers = ["net.0", "net.2", "net.4"]
    # Snapshot every weight amax before re-calibrating under the no-op filter.
    snapshot = {name: _get_weight_amax(model, name) for name in layers}

    # With both filter args None, re-running max_calibrate inside the context
    # must behave exactly like running max_calibrate directly.
    with filter_calib_modules(model, include_modules=None, exclude_modules=None):
        max_calibrate(model, partial(forward_loop, dataloader=data))

    named = dict(model.named_modules())
    for name in layers:
        # Every quantizer still holds a valid amax ...
        assert named[name].weight_quantizer._amax is not None
        # ... and it is identical to the value computed without the filter.
        assert torch.allclose(snapshot[name], _get_weight_amax(model, name)), (
            f"{name} amax changed unexpectedly when filter_calib_modules args are None"
        )


def test_smoothquant_include_modules():
    """smoothquant with include_modules only applies to matching layers."""
    torch.manual_seed(42)
    mlp = _SimpleMLP()
    batches = [torch.randn(4, 16)]

    # algorithm=None inserts quantizers without running calibration, avoiding
    # stale calibrator state when smoothquant later switches the axis to -1.
    uncalibrated_cfg = {**INT8_CFG, "algorithm": None}
    mlp = mtq.quantize(mlp, uncalibrated_cfg, forward_loop=None)

    mtq.calibrate(
        mlp,
        algorithm={"method": "smoothquant", "calib_include_modules": ["*net.0*"]},
        forward_loop=partial(forward_loop, dataloader=batches),
    )

    named = dict(mlp.named_modules())

    # The included layer was smoothed and gained a _pre_quant_scale.
    assert hasattr(named["net.0"].input_quantizer, "_pre_quant_scale"), (
        "net.0 should have been smoothed"
    )

    # The filtered-out layers must not have been smoothed.
    for name in ("net.2", "net.4"):
        assert not hasattr(named[name].input_quantizer, "_pre_quant_scale"), (
            f"{name} should not have been smoothed"
        )


def test_filter_via_config_api():
    """calib_exclude/include_modules embedded in the algorithm config dict work end-to-end.

    This is the intended usage: users set these fields directly in the algorithm dict of their
    quantization config rather than passing them as separate CLI arguments.
    """
    torch.manual_seed(42)
    mlp = _SimpleMLP()
    batches = [torch.randn(4, 16)]
    loop = partial(forward_loop, dataloader=batches)

    # Scenario 1: exclude net.4 via the algorithm dict of the quant config.
    exclude_cfg = copy.deepcopy(INT8_CFG)
    exclude_cfg["algorithm"] = {"method": "max", "calib_exclude_modules": ["*net.4*"]}
    mlp = mtq.quantize(mlp, exclude_cfg, loop)

    named = dict(mlp.named_modules())
    # Calibrated layers have the _amax buffer registered.
    for layer in ("net.0", "net.2"):
        assert hasattr(named[layer].weight_quantizer, "_amax")
    # The excluded layer's calibrator never ran, so no _amax buffer exists.
    assert not hasattr(named["net.4"].weight_quantizer, "_amax"), (
        "net.4 was excluded from calibration so its _amax buffer should not be registered"
    )

    # Scenario 2: include only net.0.
    torch.manual_seed(42)
    mlp_inc = _SimpleMLP()
    include_cfg = copy.deepcopy(INT8_CFG)
    include_cfg["algorithm"] = {"method": "max", "calib_include_modules": ["net.0"]}
    mlp_inc = mtq.quantize(mlp_inc, include_cfg, loop)

    named_inc = dict(mlp_inc.named_modules())
    assert hasattr(named_inc["net.0"].weight_quantizer, "_amax")
    for layer in ("net.2", "net.4"):
        assert not hasattr(named_inc[layer].weight_quantizer, "_amax"), (
            f"{layer} was not included in calibration so its _amax buffer should not be registered"
        )


def test_wildcard_pattern_matching():
    """fnmatch bracket patterns correctly include/exclude specific layers."""
    model, data = _make_quantized_mlp()

    # Snapshot the layer that the bracket pattern does NOT match.
    unmatched_amax = _get_weight_amax(model, "net.4")

    # "net.[02]" matches net.0 and net.2 but NOT net.4.
    mtq.calibrate(
        model,
        algorithm={"method": "mse", "calib_include_modules": ["net.[02]"]},
        forward_loop=partial(forward_loop, dataloader=data),
    )

    # The unmatched layer keeps its original amax.
    assert torch.allclose(unmatched_amax, _get_weight_amax(model, "net.4")), (
        "net.4 should not be affected by include_modules=['net.[02]']"
    )

    # Both matched layers were calibrated and carry valid amaxes.
    named = dict(model.named_modules())
    for layer in ("net.0", "net.2"):
        assert named[layer].weight_quantizer._amax is not None
Loading