63 changes: 59 additions & 4 deletions modelopt/torch/export/plugins/vllm_fakequant_hf.py
@@ -47,14 +47,18 @@
"merge_amax_tensors_for_group",
]

# Matches ``…weight_quantizer``, ``…weight_quantizer.0``, ``…w13_weight_quantizer.0``, etc.
_WEIGHT_QUANTIZER_STATE_KEY = re.compile(r"(?:^|\.)(?:\w+_)?weight_quantizer(?:\.\d+)*$")
# Matches ``…weight_quantizer``, ``…weight_quantizer.0``, ``…w13_weight_quantizer.0``,
# and the plural fused-experts form ``…weight_quantizers.0`` (per-expert ModuleList).
_WEIGHT_QUANTIZER_STATE_KEY = re.compile(r"(?:^|\.)(?:\w+_)?weight_quantizers?(?:\.\d+)*$")


def is_weight_quantizer_state_key(key: str) -> bool:
"""Return True for weight-quantizer state keys, including SequentialQuantizer entries.
"""Return True for weight-quantizer state keys.

Matches ``weight_quantizer``, ``w13_weight_quantizer``, ``weight_quantizer.0``, etc.
Includes ``SequentialQuantizer`` entries and fused-experts ``ModuleList``
entries (``*_weight_quantizers.<idx>``). Matches ``weight_quantizer``,
``w13_weight_quantizer``, ``weight_quantizer.0``,
``gate_up_proj_weight_quantizers.0``, etc.
"""
return bool(_WEIGHT_QUANTIZER_STATE_KEY.search(key))
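
A quick sanity check of the updated pattern (not part of the diff; the keys below are invented for illustration):

    >>> is_weight_quantizer_state_key("model.layers.0.mlp.w13_weight_quantizer.0")
    True
    >>> is_weight_quantizer_state_key("experts.gate_up_proj_weight_quantizers.3")
    True
    >>> is_weight_quantizer_state_key("model.layers.0.mlp.input_quantizer")
    False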

@@ -142,6 +146,56 @@ def disable_rotate(quantizer: TensorQuantizer):
return False


def _fakequant_fused_experts_weights(
module: nn.Module,
module_name: str,
state_dict: dict | None,
fakequant_weights: set,
inplace: bool,
):
"""Apply per-expert fake-quant to a ``_QuantFusedExperts`` module's 3-D weights.

The base loop in :func:`_fakequant_module_weights` only handles singular
``*_weight_quantizer`` attrs (one TensorQuantizer per weight). Fused-experts
modules expose ``*_weight_quantizers`` (``nn.ModuleList`` with one entry per
expert) that the base loop skips, leaving the fused 3-D weight unquantized
in the export and breaking weight-fold round-trips.
"""
for w_attr, q_attr in (
("gate_up_proj", "gate_up_proj_weight_quantizers"),
("down_proj", "down_proj_weight_quantizers"),
):
quantizers = getattr(module, q_attr, None)
if not isinstance(quantizers, nn.ModuleList):
continue
if not any(
isinstance(q, TensorQuantizer) and q.fake_quant and q.is_enabled for q in quantizers
):
continue
sd_key = f"{module_name}.{w_attr}" if module_name else w_attr
if sd_key in fakequant_weights:
raise RuntimeError(f"Weight {sd_key} has already been fakequantized")

if inplace:
w = getattr(module, w_attr)
for idx, q in enumerate(quantizers):
if not (isinstance(q, TensorQuantizer) and q.fake_quant and q.is_enabled):
continue
slice_ = w.data[idx]
slice_.copy_(q(slice_.float()).to(w.dtype))
else:
if state_dict is None or sd_key not in state_dict:
continue
w_3d = state_dict[sd_key].clone()
for idx, q in enumerate(quantizers):
if not (isinstance(q, TensorQuantizer) and q.fake_quant and q.is_enabled):
continue
slice_ = w_3d[idx]
w_3d[idx] = q(slice_.float()).to(slice_.dtype)
state_dict[sd_key] = w_3d.cpu()
fakequant_weights.add(sd_key)
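
For orientation, a minimal sketch of the fused-experts layout this helper assumes (attribute names match the diff; the shapes and the bare TensorQuantizer construction are illustrative, not taken from the PR):

    import torch
    import torch.nn as nn
    from modelopt.torch.quantization.nn import TensorQuantizer

    num_experts, hidden, intermediate = 4, 64, 128
    experts = nn.Module()
    # 3-D weights: one expert per leading dimension.
    experts.gate_up_proj = nn.Parameter(torch.randn(num_experts, hidden, 2 * intermediate))
    experts.down_proj = nn.Parameter(torch.randn(num_experts, intermediate, hidden))
    # One quantizer per expert, registered as a plural ModuleList.
    experts.gate_up_proj_weight_quantizers = nn.ModuleList(
        [TensorQuantizer() for _ in range(num_experts)]
    )
    # The helper fake-quantizes each 2-D slice weight.data[idx] with quantizers[idx].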


def _fakequant_module_weights(
module: nn.Module,
module_name: str,
@@ -159,6 +213,7 @@ def _fakequant_module_weights(
"""
if not isinstance(module, QuantModule):
return
_fakequant_fused_experts_weights(module, module_name, state_dict, fakequant_weights, inplace)
for attr_name, quantizer in module.named_children():
if not (
attr_name.endswith("weight_quantizer")
9 changes: 7 additions & 2 deletions modelopt/torch/export/quant_utils.py
@@ -42,6 +42,7 @@
QuantizerAttrNames,
quantizer_attr_names,
reduce_block_amax,
representative_weight_quantizer,
weight_attr_names,
)
from modelopt.torch.utils import clear_cuda_cache
@@ -546,7 +547,7 @@ def _compute_kv_cache_dtype(

def get_weight_block_size(module: nn.Module, weight_name: str = "weight") -> int:
"""Returns the weight block size."""
weight_quantizer = getattr(module, quantizer_attr_names(weight_name).weight_quantizer, None)
weight_quantizer = representative_weight_quantizer(module, weight_name)

if weight_quantizer is None:
return 0
@@ -572,7 +573,11 @@ def get_quantization_format(module) -> str | None:
"""

def _get_quantization_from_layer(layer, quantizer_attr_names: QuantizerAttrNames):
weight_quantizer = getattr(layer, quantizer_attr_names.weight_quantizer, None)
# Singular form first, plural ModuleList fallback (fused-experts).
# Strip the "_weight_quantizer" suffix to recover the weight attr name.
weight_attr = quantizer_attr_names.weight_quantizer
weight_name = weight_attr[: -len("_weight_quantizer")].rstrip("_") or "weight"
weight_quantizer = representative_weight_quantizer(layer, weight_name)
input_quantizer = getattr(layer, quantizer_attr_names.input_quantizer, None)

if weight_quantizer is None or not weight_quantizer.is_enabled:
19 changes: 11 additions & 8 deletions modelopt/torch/export/unified_export_hf.py
@@ -88,6 +88,7 @@
QUANTIZATION_W4A8_NVFP4_FP8,
)
from .model_utils import get_language_model_from_vl, is_multimodal_model
from .moe_utils import _export_fused_experts
from .plugins import SpeculativeDecodingExporter, has_spec_opt
from .quant_utils import (
fuse_prequant_layernorm,
@@ -642,11 +643,20 @@ def _process_quantized_modules(
if is_modelopt_qlora and (hasattr(sub_module, "base_layer")):
continue

# Preprocessing: restore unpacked weight so the export path can read
# the live quantizer state. Falls through to the export branches below.
if hasattr(sub_module, "weight_packed") or (
"QuantFP8Linear" in type(sub_module).__name__ and sub_module.weight.element_size() <= 1
):
sub_module.unpack_weight()
if get_quantization_format(sub_module) != QUANTIZATION_NONE:

if hasattr(sub_module, "gate_up_proj_weight_quantizers"):
# _QuantFusedExperts uses plural `gate_up_proj_weight_quantizers` (ModuleList),
# which get_quantization_format's singular-weight_quantizer check misses. Handle
# it explicitly before the format gate so fused-experts get split + quantized.
with fsdp2_aware_weight_update(model, sub_module, reshard=False):
_export_fused_experts(sub_module, dtype)
elif get_quantization_format(sub_module) != QUANTIZATION_NONE:
# Skip QuantMoELinear - it's handled separately in _reconstruct_fused_moe_linear
if type(sub_module).__name__ == "QuantMoELinear":
continue
@@ -677,13 +687,6 @@
with fsdp2_aware_weight_update(model, sub_module, reshard=False):
for weight_name in ["gate_up_proj", "down_proj"]:
_export_quantized_weight(sub_module, dtype, weight_name)
elif hasattr(sub_module, "gate_up_proj_weight_quantizers"):
# Generic fused MoE experts (_QuantFusedExperts) with per-expert
# quantizer ModuleLists. Split into per-expert modules and export.
from modelopt.torch.export.moe_utils import _export_fused_experts

with fsdp2_aware_weight_update(model, sub_module, reshard=False):
_export_fused_experts(sub_module, dtype)


def _export_transformers_checkpoint(
34 changes: 33 additions & 1 deletion modelopt/torch/quantization/conversion.py
@@ -16,6 +16,7 @@
"""Quantization conversion/restore utilities."""

import fnmatch
import re
import warnings
from collections.abc import Callable
from contextlib import contextmanager
@@ -286,6 +287,33 @@ def set_quantizer_by_cfg(quant_model: nn.Module, quant_cfg: QuantizeQuantCfgType
set_quantizer_attributes_full(quant_model, quantizer_name, attributes, parent_class)


_FUSED_EXPERTS_QUANTIZER_LIST_RE = re.compile(
r"(weight_quantizers?|input_quantizers?)\.\d+(?=$|\.)"
)


def _normalize_fused_experts_quantizer_name(name: str) -> str:
"""Strip the per-expert index from per-expert quantizer ModuleList names.

Fused-experts modules register per-expert weight/input quantizers in a
``nn.ModuleList``; its children surface as dotted names like
``...gate_up_proj_weight_quantizers.0`` (plural) or — if a variant uses
singular naming — ``...gate_up_proj_weight_quantizer.0``. Neither matches
the singular-suffix wildcards (``*weight_quantizer``) used in the stock
configs, so the experts stay at their defaults.

Return a normalized name where either ``weight_quantizer[s]?.N`` or
``input_quantizer[s]?.N`` collapses to the singular form without the index
so the standard wildcards match.
"""

def _repl(m: re.Match) -> str:
base = m.group(1)
return base.removesuffix("s")

return _FUSED_EXPERTS_QUANTIZER_LIST_RE.sub(_repl, name)
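
To make the normalization concrete (the module path below is hypothetical, not from the diff):

    >>> _normalize_fused_experts_quantizer_name(
    ...     "model.layers.3.mlp.experts.gate_up_proj_weight_quantizers.7"
    ... )
    'model.layers.3.mlp.experts.gate_up_proj_weight_quantizer'
    >>> # ...which the stock "*weight_quantizer" wildcard now matches via fnmatch.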


def _match_quantizer(
wildcard_or_filter_func: str | Callable,
name: str,
@@ -296,7 +324,11 @@ def _match_quantizer(
if not isinstance(module, (TensorQuantizer, SequentialQuantizer)):
return False
if isinstance(wildcard_or_filter_func, str):
if not fnmatch.fnmatch(name, wildcard_or_filter_func):
normalized = _normalize_fused_experts_quantizer_name(name)
if not (
fnmatch.fnmatch(name, wildcard_or_filter_func)
or (normalized != name and fnmatch.fnmatch(normalized, wildcard_or_filter_func))
):
return False
elif callable(wildcard_or_filter_func):
if not wildcard_or_filter_func(name):
60 changes: 60 additions & 0 deletions modelopt/torch/quantization/plugins/huggingface.py
@@ -900,6 +900,33 @@ def forward(self, *args, **kwargs):
self._down_proj_linear = False
return super().forward(*args, **kwargs)

def fold_weight(self, keep_attrs: bool = False):
"""Fold per-expert weight quantizers into the fused 3-D weights.

The base ``fold_weight`` only handles singular ``*_weight_quantizer``
attributes. Fused experts use ``nn.ModuleList`` of per-expert quantizers
(``gate_up_proj_weight_quantizers``, ``down_proj_weight_quantizers``),
which would otherwise be skipped, leaving ``_amax`` on every quantizer.
"""
for weight_name, quantizers_name in (
("gate_up_proj", "gate_up_proj_weight_quantizers"),
("down_proj", "down_proj_weight_quantizers"),
):
weight = getattr(self, weight_name, None)
quantizers = getattr(self, quantizers_name, None)
if weight is None or quantizers is None:
continue
for idx, q in enumerate(quantizers):
if not (isinstance(q, TensorQuantizer) and q.fake_quant):
continue
slice_ = weight.data[idx]
slice_.copy_(q(slice_.float()).to(weight.dtype))
q.disable()
if not keep_attrs:
for attr_name in ("_pre_quant_scale", "_amax"):
if hasattr(q, attr_name):
delattr(q, attr_name)


class _QuantDbrxFFN(_QuantSparseSequentialMoe):
@property
@@ -1438,6 +1465,38 @@ def register_fused_experts_on_the_fly(model):
QuantModuleRegistry.register({mod_type: f"hf.{mod_type.__name__}"})(_QuantFusedExperts)


def force_eager_experts_impl_on_the_fly(model):
"""Force HF fused-experts modules onto the eager ``F.linear``-based forward.

HF transformers 5.0+ decorates fused-experts forwards with
``@use_experts_implementation``, which may dispatch to ``torch._grouped_mm``
or ``torch.bmm`` backends. Those backends bypass ``F.linear`` and so bypass
``_QuantFusedExperts``'s input/weight quantizer hooks — calibration silently
does nothing, no ``input_scale`` / ``amax`` is collected, and the exported
checkpoint produces garbage at inference.

Sets ``config._experts_implementation = "eager"`` on the model config (and
recursively on ``text_config`` / ``vision_config`` / ``audio_config`` /
``speech_config``) whenever a fused-experts module is present.
"""
if not any(_is_fused_experts_module(m) for m in model.modules()):
return

nested_cfg_attrs = ("text_config", "vision_config", "audio_config", "speech_config")

def _force(cfg):
if cfg is None:
return
if hasattr(cfg, "_experts_implementation"):
cfg._experts_implementation = "eager"
for sub in nested_cfg_attrs:
if hasattr(cfg, sub):
_force(getattr(cfg, sub))

if hasattr(model, "config"):
_force(model.config)


def _is_supported_hf_model(model):
"""Check if the model a valid model for transformers quantization specific support."""
supported_models = [transformers.PreTrainedModel]
@@ -1665,6 +1724,7 @@ def _reconstruct_fused_moe_linear(model: nn.Module) -> None:
register_dbrx_moe_on_the_fly,
register_step3p5_moe_on_the_fly,
register_fused_experts_on_the_fly,
force_eager_experts_impl_on_the_fly,
register_sparse_moe_on_the_fly,
register_hf_attentions_on_the_fly,
convert_hf_parallel_linears_on_the_fly,
1 change: 1 addition & 0 deletions modelopt/torch/quantization/utils/__init__.py
@@ -30,6 +30,7 @@
"reduce_amax",
"reduce_sum",
"replace_function",
"representative_weight_quantizer",
"update_quant_cfg_with_kv_cache_quant",
"weight_attr_names",
]
58 changes: 44 additions & 14 deletions modelopt/torch/quantization/utils/core_utils.py
@@ -202,27 +202,57 @@ def reduce_sum(input, axis=None, keepdims=True):
return output


def weight_attr_names(module: nn.Module) -> "Generator[str, None, None]":
"""Get the weight param attribute names in a converted module, non-recursive.
def representative_weight_quantizer(module: nn.Module, weight_name: str = "weight"):
"""Return the representative weight quantizer for ``weight_name`` on ``module``.

Handles two layouts:

- singular ``<name>_weight_quantizer`` — standard ``nn.Linear`` / ``_QuantLinear``.
- plural ``<name>_weight_quantizers`` (``nn.ModuleList``) — fused-experts modules
(``_QuantFusedExperts``) hold one ``TensorQuantizer`` per expert. Per-expert
formats are identical, so the first element is representative.

We consider the following two cases for each weight param attribute:
- The standard weight attribute (e.g. nn.Linear).
- The custom `weight_attr_name`. (e.g. Llama4TextExperts has weight attributes `gate_up_proj` and `down_proj`)
Returns ``None`` if no matching quantizer is found.
"""
from ..nn import SequentialQuantizer, TensorQuantizer

# the standard weight and quantizer case
weight = getattr(module, "weight", None)
weight_quantizer = getattr(module, "weight_quantizer", None)
if weight is not None and isinstance(weight_quantizer, (TensorQuantizer, SequentialQuantizer)):
yield "weight"
singular = quantizer_attr_names(weight_name).weight_quantizer
q = getattr(module, singular, None)
if isinstance(q, (TensorQuantizer, SequentialQuantizer)):
return q

# other weight and quantizer case
plural = getattr(module, singular + "s", None)
if isinstance(plural, nn.ModuleList) and len(plural) > 0:
first = plural[0]
if isinstance(first, (TensorQuantizer, SequentialQuantizer)):
return first
return None


def weight_attr_names(module: nn.Module) -> "Generator[str, None, None]":
"""Get the weight param attribute names in a converted module, non-recursive.

Covers three layouts:

- standard ``nn.Linear``: ``weight`` + ``weight_quantizer``.
- custom per-weight quantizer (e.g. ``Llama4TextExperts`` with ``gate_up_proj`` +
``gate_up_proj_weight_quantizer``).
- fused-experts ``nn.ModuleList`` quantizers (``_QuantFusedExperts`` with
``gate_up_proj`` + ``gate_up_proj_weight_quantizers`` plural list).
"""
# standard: "weight" + "weight_quantizer" (singular) or "weight_quantizers" (plural)
if getattr(module, "weight", None) is not None:
if representative_weight_quantizer(module, "weight") is not None:
yield "weight"

# per-parameter custom attr names
for name, _ in module.named_parameters(recurse=False):
if name == "weight":
continue
weight = getattr(module, name, None)
weight_quantizer = getattr(module, f"{name}_weight_quantizer", None)
if isinstance(weight, nn.Parameter) and isinstance(
weight_quantizer, (TensorQuantizer, SequentialQuantizer)
if (
isinstance(weight, nn.Parameter)
and representative_weight_quantizer(module, name) is not None
):
yield name

@@ -20,7 +20,9 @@ quantize:
algorithm:
method: max
# Max calibration is fast and does not typically need checkpointing.
layerwise: true
# layerwise=false required for VLMs where the decoder layers are nested under
# `model.language_model.layers` (layerwise_calibrate can't find them otherwise).
layerwise: false
quant_cfg:
- quantizer_name: '*'
enable: false