gptqmodel/looper/awq_processor.py (96 additions, 43 deletions)
@@ -20,18 +20,18 @@
from ..models._const import SUPPORTS_MODULE_TYPES
from ..models.writer import (PROCESS_LOG_LAYER, PROCESS_LOG_MODULE, PROCESS_LOG_NAME,
PROCESS_LOG_TIME, PROCESS_USED_MEMORY, QUANT_LOG_LOSS, QUANT_LOG_NSAMPLES)
-from ..nn_modules.qlinear.awq_gemm import AwqGEMMQuantLinear
-from ..nn_modules.qlinear.awq_gemv import AwqGEMVQuantLinear
-from ..nn_modules.qlinear.awq_gemv_fast import AwqGEMVFastQuantLinear
-from ..nn_modules.qlinear.awq_marlin import AwqMarlinQuantLinear
-from ..quantization.awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV, WQLinear_GEMVFast, WQLinear_Marlin
+from ..nn_modules.qlinear.gemm_awq import AwqGEMMQuantLinear
+from ..nn_modules.qlinear.gemv_awq import AwqGEMVQuantLinear
+from ..nn_modules.qlinear.gemv_fast_awq import AwqGEMVFastQuantLinear
+from ..nn_modules.qlinear.marlin_awq import AwqMarlinQuantLinear
from ..quantization.awq.quantize.scale import apply_clip, apply_scale
-from ..quantization.awq.utils.module import append_str_prefix, get_op_name, get_op_by_name, set_op_by_name
+from ..quantization.awq.utils.module import append_str_prefix, get_op_name, get_op_by_name
from ..quantization.awq.utils.utils import get_best_device
from ..quantization.config import FORMAT, METHOD, QuantizeConfig
-from ..utils.logger import setup_logger
+from ..utils.logger import setup_logger, log_time_block
from ..utils.ctx import ctx
-from ..utils.model import find_modules, get_module_by_name_prefix, move_to
+from ..utils.model import find_modules, get_module_by_name_prefix, move_to, create_quant_module, pack_module
+from ..utils.module_locks import parent_module_lock
from ..utils.torch import CPU

log = setup_logger()
@@ -90,6 +90,16 @@ def __init__(
self._layer_states_lock = threading.Lock()
self._scale_context = threading.local()
self.gptq_model = gptq_model

+        if qcfg.format == FORMAT.GEMM:
+            self.gptq_model.qlinear_kernel = AwqGEMMQuantLinear
+        elif qcfg.format == FORMAT.GEMV:
+            self.gptq_model.qlinear_kernel = AwqGEMVQuantLinear
+        elif qcfg.format == FORMAT.GEMV_FAST:
+            self.gptq_model.qlinear_kernel = AwqGEMVFastQuantLinear
+        else:
+            raise ValueError(f"METHOD.AWQ does not support this FORMAT: {qcfg.format}")

self.model = model
# Whether to apply clipping to the model during quantization. Some models may perform better with this set to False.
self.apply_clip = True
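
Note: the FORMAT dispatch added above hard-codes three branches, and AwqMarlinQuantLinear is imported in this file but never selected here, so a Marlin request falls through to the ValueError. If that is intentional, a lookup table would make the supported set explicit. A hedged sketch of the same logic (identifiers come from this diff; the dict form and its placement inside __init__ are illustrative, not what the PR ships):

    # Hypothetical table-driven equivalent of the if/elif chain in __init__.
    _AWQ_KERNELS = {
        FORMAT.GEMM: AwqGEMMQuantLinear,
        FORMAT.GEMV: AwqGEMVQuantLinear,
        FORMAT.GEMV_FAST: AwqGEMVFastQuantLinear,
    }

    kernel = _AWQ_KERNELS.get(qcfg.format)
    if kernel is None:
        raise ValueError(f"METHOD.AWQ does not support this FORMAT: {qcfg.format}")
    self.gptq_model.qlinear_kernel = kernel
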
@@ -101,7 +111,7 @@ def __init__(
# This argument avoids real quantization by only applying the scales without quantizing down to FP16.
self.export_compatible = False

-        self.version = qcfg.format
+        self.format = qcfg.format

# Whether to scale using both w/x or just x.
self.duo_scaling = True
@@ -334,10 +344,13 @@ def _quantize_layer(self, layer_index: int, state: _AWQLayerState) -> None:

with state.lock:
# Filtering MLP modules like Qwen3MoeSparseMoeBlock
+            def unwrap(m):
+                return m.module if isinstance(m, NamedModule) else m
+
            named_childs = {
                name: module
                for name, module in state.modules.items()
-                if isinstance(module, tuple(SUPPORTS_MODULE_TYPES))
+                if isinstance(unwrap(module), tuple(SUPPORTS_MODULE_TYPES))
            }

module_kwargs_global = dict(self._module_forward_kwargs)
@@ -543,7 +556,7 @@ def _quantize_layer(self, layer_index: int, state: _AWQLayerState) -> None:

if not self.export_compatible:
start = time.time()
-            self._apply_quant(layer_module_ref, named_childs, start, scales_list)
+            self.pack_module(named_childs, start, scales_list)

with state.lock:
state.quantized = True
@@ -1058,7 +1071,7 @@ def _slice_value(val, length):

return module_output

-    def _apply_quant(self, module, named_linears: Dict[str, NamedModule], start_time, scales_list):
+    def pack_module(self, named_linears: Dict[str, NamedModule], start_time, scales_list):
for name, named_module in named_linears.items():
self.pb.title(f"Quantizing {named_module.name} in layer ").draw()
linear_layer = named_module.module
Expand Down Expand Up @@ -1113,37 +1126,6 @@ def _apply_quant(self, module, named_linears: Dict[str, NamedModule], start_time

linear_layer.weight.data = wq

-            if self.version == "gemm":
-                scales = scales.t().contiguous()
-                if zeros is not None:
-                    zeros = zeros.t().contiguous()
-                q_linear_module = WQLinear_GEMM
-
-            elif self.version == "gemv":
-                q_linear_module = WQLinear_GEMV
-
-            elif self.version == "marlin":
-                q_linear_module = WQLinear_Marlin
-
-            elif self.version == "gemv_fast":
-                q_linear_module = WQLinear_GEMVFast
-
-            else:
-                raise ValueError(f"Unknown version {self.version}")
-
-            q_linear = q_linear_module.from_linear(
-                linear=linear_layer,
-                w_bit=self.qcfg.bits,
-                group_size=self.qcfg.group_size,
-                init_only=False,
-                scales=scales,
-                zeros=zeros,
-            )
-
-            linear_layer.cpu()
-            q_linear.to(next(module.parameters()).device)
-            set_op_by_name(module, name, q_linear)

# records
duration = time.time() - start_time

@@ -1191,6 +1173,77 @@ def _apply_quant(self, module, named_linears: Dict[str, NamedModule], start_time
f"{duration:.3f}",
)

+            linear_layer = linear_layer.cpu()
+            scales = scales.cpu()
+            zeros = zeros.cpu()
+
+            layers = find_modules(self.gptq_model.model)
+            module_label = getattr(named_module, "full_name", getattr(named_module, "name", ""))
+            parent_key = getattr(named_module, "full_name", getattr(named_module, "name", None))
+
+            # replace module with quantized module
+            timer = getattr(self.gptq_model, "quant_region_timer", None)
+
+            create_start = time.perf_counter() if timer is not None else None
+            with log_time_block(
+                "create_quant_module",
+                logger=log,
+                module_name=module_label,
+            ):
+                with parent_module_lock(parent_key):
+                    create_quant_module(
+                        name=named_module.full_name,
+                        linear_cls=self.gptq_model.qlinear_kernel,
+                        bits=self.qcfg.bits,
+                        desc_act=self.qcfg.desc_act,
+                        dynamic=self.qcfg.dynamic,
+                        group_size=self.qcfg.group_size,
+                        module=self.gptq_model.model,
+                        submodule=named_module,
+                        sym=self.qcfg.sym,
+                        device=self.qcfg.device,
+                        lm_head_name=self.gptq_model.lm_head,
+                        pack_dtype=self.qcfg.pack_dtype,
+                        register_buffers=False,
+                    )
+            if timer is not None and create_start is not None:
+                timer.record(
+                    "submodule_finalize_create",
+                    time.perf_counter() - create_start,
+                    source=module_label,
+                )
+
+            # pack module
+            qModules = {
+                name: submodule
+                for name, submodule in find_modules(self.gptq_model.model, [self.gptq_model.qlinear_kernel]).items()
+                if name == named_module.full_name
+            }
+            pack_start = time.perf_counter() if timer is not None else None
+            with log_time_block(
+                "pack",
+                logger=log,
+                module_name=module_label,
+            ):
+                with parent_module_lock(parent_key):
+                    packer_label = pack_module(
+                        name=named_module.full_name,
+                        qModules=qModules,
+                        q_scales=scales,
+                        q_zeros=zeros,
+                        q_g_idx=None,
+                        layers=layers,
+                        quant_linear_cls=self.gptq_model.qlinear_kernel,
+                        lock=self.lock,
+                        quantize_config=self.qcfg,
+                    )
+            if timer is not None and pack_start is not None:
+                timer.record(
+                    "submodule_finalize_pack",
+                    time.perf_counter() - pack_start,
+                    source=f"{module_label} [{packer_label or 'module.pack_original'}]",
+                )

def _sanitize_kwargs(self, inputs_kwargs, module):
"""
Remove the arguments that are not supported in the module's
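
Note: the finalize path added to pack_module always runs create-then-pack under parent_module_lock, with each phase wrapped in log_time_block and, when a timer is attached, recorded on gptq_model.quant_region_timer. Moving linear_layer, scales, and zeros to CPU first presumably keeps GPU memory free while the replacement module is built. A minimal runnable sketch of the timing shape, assuming only the .record(key, duration, source=...) method this diff exercises (the _PrintTimer recorder and the module name are hypothetical):

    import time
    from contextlib import contextmanager

    @contextmanager
    def timed(timer, key, source):
        # Stand-in for the log_time_block + quant_region_timer pairing:
        # time the region, and record it only when a timer is attached.
        start = time.perf_counter()
        try:
            yield
        finally:
            if timer is not None:
                timer.record(key, time.perf_counter() - start, source=source)

    class _PrintTimer:
        # Hypothetical recorder; the diff only shows .record(key, duration, source=...).
        def record(self, key, duration, source=None):
            print(f"{key}: {duration:.6f}s ({source})")

    with timed(_PrintTimer(), "submodule_finalize_create", "model.layers.0.self_attn.q_proj"):
        pass  # create_quant_module(...) runs here in the real code, under parent_module_lock
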
gptqmodel/models/base.py (6 additions, 6 deletions)
@@ -238,11 +238,11 @@ def __init__(

if type(self).module_tree is None:
type(self).module_tree = self._auto_detect_module_tree(model, quant_method)

# If module_tree is still None after auto-detection, raise an error indicating unsupported model type
if type(self).module_tree is None:
raise ValueError(f"Unsupport model_type {model.config.model_type}, and failed to auto-detect module tree for model {model}")


# record configuration early so model lifecycle hooks can rely on them
self.compiled = False # set to True while compile() is triggered successfully
@@ -1690,7 +1690,7 @@ def _get(path):
"blocks",
"model.blocks",
]

chosen = None
for c in candidates:
m = _get(c)
@@ -1700,7 +1700,7 @@ def _get(path):
break

if chosen is None:
log.warn("Module Tree AutoCompat: All candidate paths invalid, return None")
log.warn("Module Tree AutoCompat: All candidate paths invalid, return None")
return None

layer0 = _get(chosen)[0]
@@ -1715,7 +1715,7 @@ def _linear_names(module):
if len(all_linear)>0:
log.warn(f"Module Tree AutoCompat: found {len(all_linear)} Linear/Conv modules in {type(layer0).__name__}: {all_linear}")
else:
log.warn(f"Module Tree AutoCompat: No Linear/Conv names in layer0, return None")
log.warn("Module Tree AutoCompat: No Linear/Conv names in layer0, return None")
return None

mapping = {}
@@ -1732,7 +1732,7 @@ def _leaf_tokens(prefix):
return tuple(x.split(".")[-1] for x in all_linear if x.startswith(f"{prefix}."))

possible_parent = ["attn", "attention", "self_attn", "mlp", "ffn", "feed", "dense"]

found_parents = _find_parents(layer0, possible_parent)

for p in found_parents:
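
Note: for readers following the Module Tree AutoCompat hunks above, the candidate probing around _get(path) is plain dotted-attribute resolution over a fixed list of layer-container paths. A hedged sketch of the pattern (function names here are illustrative, not the ones in base.py; the candidate paths are taken from the diff):

    from typing import Any, Optional

    def resolve_path(root: Any, path: str) -> Optional[Any]:
        # Walk a dotted attribute path such as "model.blocks";
        # give up and return None as soon as any hop is missing.
        obj = root
        for attr in path.split("."):
            obj = getattr(obj, attr, None)
            if obj is None:
                return None
        return obj

    def first_layer_list(model: Any, candidates=("blocks", "model.blocks")) -> Optional[Any]:
        # Mirrors the candidate loop: the first path that resolves to a
        # non-empty container of layers wins; otherwise the caller warns
        # and returns None, as the log.warn branches above do.
        for path in candidates:
            found = resolve_path(model, path)
            if found is not None and len(found) > 0:
                return found
        return None
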
gptqmodel/nn_modules/qlinear/__init__.py (8 additions, 1 deletion)
@@ -227,6 +227,8 @@ def validate(
in_features:int=None,
out_features:int=None,
pack_dtype:t.dtype=None,
+        dtype: Optional[t.dtype]=None,
+        zero_point: Optional[bool]=None,
dynamic:Optional[dict]=None,
device:Optional[DEVICE]=None,
trainable:Optional[bool]=None,
@@ -235,6 +237,7 @@
bool, Optional[Exception]]:
return cls._validate(bits=bits, group_size=group_size, desc_act=desc_act, sym=sym,
in_features=in_features, out_features=out_features, pack_dtype=pack_dtype,
+                             dtype=dtype, zero_point=zero_point,
dynamic=dynamic, device=device, trainable=trainable, adapter=adapter)

@classmethod
@@ -274,7 +277,7 @@ def verify_supports_params(cls):
# raise ValueError(f"{cls.__name__}.{name} cannot be an empty list.")

@classmethod
-    def _validate(cls, bits: int=4, group_size: int=128, desc_act: bool=False, sym: bool=False, pack_dtype:t.dtype=None, dynamic:Optional[dict]=None, in_features:int=None,
+    def _validate(cls, bits: int=4, group_size: int=128, desc_act: bool=False, sym: bool=False, pack_dtype:t.dtype=None, dtype: Optional[t.dtype]=None, zero_point: Optional[bool]=None, dynamic:Optional[dict]=None, in_features:int=None,
out_features:int=None, device:Optional[DEVICE]=None, trainable:Optional[bool]=None, adapter:Optional[Adapter]=None) -> Tuple[bool, Optional[Exception]]:
cls.verify_supports_params()

@@ -286,6 +289,10 @@ def _validate(cls, bits: int=4, group_size: int=128, desc_act: bool=False, sym:
err = f"{cls} does not support `pack_dtype`: {pack_dtype}"
return False, NotImplementedError(err)

+        if dtype is not None and dtype not in cls.SUPPORTS_DTYPES:
+            err = f"{cls} only supports `{cls.SUPPORTS_DTYPES}` dtype: actual dtype = `{dtype}`"
+            return False, NotImplementedError(err)

if PLATFORM.ALL not in cls.SUPPORTS_PLATFORM and sys.platform not in cls.SUPPORTS_PLATFORM:
err = f"{cls} does not support platform: {sys.platform}"
return False, NotImplementedError(err)
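
Note: with dtype and zero_point threaded through validate() into _validate(), callers can pre-flight a kernel choice before building any modules. A hedged usage sketch (the import path follows this PR's rename; whether torch.float16 is in a given kernel's SUPPORTS_DTYPES is kernel-dependent, and the keyword-only call assumes the defaults visible in the signature):

    import torch as t
    from gptqmodel.nn_modules.qlinear.gemm_awq import AwqGEMMQuantLinear

    # Ask the kernel up front whether this configuration is supportable.
    ok, err = AwqGEMMQuantLinear.validate(
        bits=4,
        group_size=128,
        dtype=t.float16,   # new: checked against cls.SUPPORTS_DTYPES in _validate
        zero_point=True,   # new: forwarded to _validate
    )
    if not ok:
        raise err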