From 6562e66f3914bb831fb3c708cd1360b5717a14cf Mon Sep 17 00:00:00 2001 From: Qubitium Date: Sat, 22 Nov 2025 07:10:47 +0000 Subject: [PATCH 01/24] remove unused --- gptqmodel/quantization/awq/_config.py | 108 -------------------------- 1 file changed, 108 deletions(-) delete mode 100644 gptqmodel/quantization/awq/_config.py diff --git a/gptqmodel/quantization/awq/_config.py b/gptqmodel/quantization/awq/_config.py deleted file mode 100644 index 433fe1267..000000000 --- a/gptqmodel/quantization/awq/_config.py +++ /dev/null @@ -1,108 +0,0 @@ -# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai -# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai -# SPDX-License-Identifier: Apache-2.0 -# Contact: qubitium@modelcloud.ai, x.com/qubitium - -import json -import os -from dataclasses import dataclass, field -from typing import Dict, List, Optional - -from transformers.utils.hub import PushToHubMixin, cached_file - - -@dataclass -class AwqConfig(PushToHubMixin): - quant_method: str = field(default="awq") - zero_point: bool = field(default=True) - q_group_size: int = field(default=128) - w_bit: int = field(default=4) - version: str = field(default="gemm") - config_file_name = "config.json" - modules_to_not_convert: Optional[List] = None - - @classmethod - def from_dict(cls, quant_config: Dict = {}): - if not quant_config: - quant_config = cls() - else: - quant_config = cls(**quant_config) - quant_config.version = quant_config.version.lower() - - return quant_config - - @classmethod - def from_pretrained(cls, save_dir: str, **kwargs): - cache_dir = kwargs.pop("cache_dir", None) - force_download = kwargs.pop("force_download", False) - resume_download = kwargs.pop("resume_download", False) - proxies = kwargs.pop("proxies", None) - local_files_only = kwargs.pop("local_files_only", False) - use_auth_token = kwargs.pop("use_auth_token", None) - revision = kwargs.pop("revision", None) - subfolder = kwargs.pop("subfolder", None) - commit_hash = kwargs.pop("_commit_hash", None) - - if os.path.isdir(save_dir): # Local - resolved_config_file = os.path.join(save_dir, cls.config_file_name) - else: # Remote - resolved_config_file = cached_file( - save_dir, - cls.config_file_name, - cache_dir=cache_dir, - force_download=force_download, - resume_download=resume_download, - proxies=proxies, - use_auth_token=use_auth_token, - revision=revision, - local_files_only=local_files_only, - subfolder=subfolder, - _raise_exceptions_for_missing_entries=False, - _raise_exceptions_for_connection_errors=False, - _commit_hash=commit_hash, - ) - - quant_config = None - if os.path.exists(resolved_config_file): - with open(resolved_config_file, "r", encoding="utf-8") as file: - loaded_config = json.loads(file.read()) - - quant_config = loaded_config.get("quantization_config") - - if quant_config is not None: - awq_config = cls.from_transformers_dict(cls, quant_config) - quant_config = cls(**awq_config) - - if quant_config is None: - quant_config = cls() - - return quant_config - - def to_dict(self): - return { - "zero_point": self.zero_point, - "q_group_size": self.q_group_size, - "w_bit": self.w_bit, - "version": self.version, - "modules_to_not_convert": self.modules_to_not_convert, - } - - def to_transformers_dict(self): - return { - "quant_method": self.quant_method, - "zero_point": self.zero_point, - "group_size": self.q_group_size, - "bits": self.w_bit, - "version": self.version.lower(), - "modules_to_not_convert": self.modules_to_not_convert, - } - - def from_transformers_dict(self, transformers_dict: Dict): - return { - 
"quant_method": transformers_dict.get("quant_method"), - "zero_point": transformers_dict.get("zero_point"), - "q_group_size": transformers_dict.get("group_size"), - "w_bit": transformers_dict.get("bits"), - "version": transformers_dict.get("version"), - "modules_to_not_convert": transformers_dict.get("modules_to_not_convert"), - } From e3031363876d014d9264506f26cadef440e626f9 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Sat, 22 Nov 2025 07:49:29 +0000 Subject: [PATCH 02/24] add hf_select_quant_linear_v2 for transformer compat --- gptqmodel/nn_modules/qlinear/__init__.py | 9 +- gptqmodel/utils/importer.py | 103 ++++++++++++++++++++++- 2 files changed, 110 insertions(+), 2 deletions(-) diff --git a/gptqmodel/nn_modules/qlinear/__init__.py b/gptqmodel/nn_modules/qlinear/__init__.py index 442e925be..881ae8ab6 100644 --- a/gptqmodel/nn_modules/qlinear/__init__.py +++ b/gptqmodel/nn_modules/qlinear/__init__.py @@ -227,6 +227,8 @@ def validate( in_features:int=None, out_features:int=None, pack_dtype:t.dtype=None, + dtype: Optional[t.dtype]=None, + zero_point: Optional[bool]=None, dynamic:Optional[dict]=None, device:Optional[DEVICE]=None, trainable:Optional[bool]=None, @@ -235,6 +237,7 @@ def validate( bool, Optional[Exception]]: return cls._validate(bits=bits, group_size=group_size, desc_act=desc_act, sym=sym, in_features=in_features, out_features=out_features, pack_dtype=pack_dtype, + dtype=dtype, zero_point=zero_point, dynamic=dynamic, device=device, trainable=trainable, adapter=adapter) @classmethod @@ -274,7 +277,7 @@ def verify_supports_params(cls): # raise ValueError(f"{cls.__name__}.{name} cannot be an empty list.") @classmethod - def _validate(cls, bits: int=4, group_size: int=128, desc_act: bool=False, sym: bool=False, pack_dtype:t.dtype=None, dynamic:Optional[dict]=None, in_features:int=None, + def _validate(cls, bits: int=4, group_size: int=128, desc_act: bool=False, sym: bool=False, pack_dtype:t.dtype=None, dtype: Optional[t.dtype]=None, zero_point: Optional[bool]=None, dynamic:Optional[dict]=None, in_features:int=None, out_features:int=None, device:Optional[DEVICE]=None, trainable:Optional[bool]=None, adapter:Optional[Adapter]=None) -> Tuple[bool, Optional[Exception]]: cls.verify_supports_params() @@ -286,6 +289,10 @@ def _validate(cls, bits: int=4, group_size: int=128, desc_act: bool=False, sym: err = f"{cls} does not support `pack_dtype`: {pack_dtype}" return False, NotImplementedError(err) + if dtype is not None and dtype not in cls.SUPPORTS_DTYPES: + err = f"{cls} only supports `{cls.SUPPORTS_DTYPES}` dtype: actual dtype = `{dtype}`" + return False, NotImplementedError(err) + if PLATFORM.ALL not in cls.SUPPORTS_PLATFORM and sys.platform not in cls.SUPPORTS_PLATFORM: err = f"{cls} does not support platform: {sys.platform}" return False, NotImplementedError(err) diff --git a/gptqmodel/utils/importer.py b/gptqmodel/utils/importer.py index 03805504f..5cab8eca9 100644 --- a/gptqmodel/utils/importer.py +++ b/gptqmodel/utils/importer.py @@ -210,6 +210,81 @@ def hf_select_quant_linear( adapter=None, ) +# public/stable api exposed to transformer/optimum +def hf_select_quant_linear_v2( + bits: int, + group_size: int, + desc_act: bool, + sym: bool, + format: Union[str, FORMAT], # awq `version` should be pre-mapped to format + quant_method: Union[str, METHOD], # awq llm-awq `version` should be pre-mapped to method + zero_point: Optional[bool] = True, # awq only + dtype: Optional[Union[str, torch.dtype]] = None, + meta: Optional[Dict[str, any]] = None, + pack: Optional[bool] = True, + 
device_map: Optional[Union[str, dict]] = None, + backend: Optional[Union[str, BACKEND]] = None, +) -> Type[BaseQuantLinear]: + # convert hf string backend to backend.enum + if isinstance(backend, str): + backend = BACKEND(backend.lower()) + + def _normalize_enum(value, enum_cls, field: str): + if isinstance(value, enum_cls): + return value + if isinstance(value, str): + try: + return enum_cls(value.lower()) + except ValueError as exc: + raise ValueError(f"Unsupported {field}: `{value}`") from exc + raise ValueError(f"{field} must be a string or `{enum_cls.__name__}`, got `{type(value)}`") + + def _normalize_dtype(value: Optional[Union[str, torch.dtype]], field: str) -> Optional[torch.dtype]: + if value is None: + return None + if isinstance(value, torch.dtype): + return value + if isinstance(value, str): + normalized = value.replace("torch.", "").lower() + candidate = getattr(torch, normalized, None) + if isinstance(candidate, torch.dtype): + return candidate + raise ValueError(f"Unsupported {field}: `{value}`") + + method = _normalize_enum(quant_method, METHOD, "quant_method") + fmt = _normalize_enum(format, FORMAT, "format") + normalized_dtype = _normalize_dtype(dtype, "dtype") + + pack_dtype_override = None + if meta is not None: + pack_dtype_override = meta.get("pack_dtype", None) + # GEMV_FAST checkpoints are packed as int16; default to int32 otherwise. + default_pack_dtype = torch.int16 if method == METHOD.AWQ and fmt == FORMAT.GEMV_FAST else torch.int32 + pack_dtype = _normalize_dtype(pack_dtype_override, "pack_dtype") if pack_dtype_override is not None else default_pack_dtype + + if device_map is not None: + device = normalize_device_device_map(None, device_map) + else: + device = DEVICE.CPU + + return select_quant_linear( + bits=bits, + group_size=group_size, + desc_act=desc_act, + sym=sym, + backend=backend, + device=device, + format=fmt, + quant_method=method, + pack=pack, + allow_marlin=True, # TODO: remove this after marlin padding is fixed + dynamic=None, + pack_dtype=pack_dtype, + dtype=normalized_dtype, + zero_point=zero_point, + adapter=None, + ) + # auto select the correct/optimal QuantLinear class def select_quant_linear( @@ -225,6 +300,8 @@ def select_quant_linear( allow_marlin: bool = True, # TODO: remove this after marlin padding is fixed dynamic=None, pack_dtype: torch.dtype = None, + dtype: Optional[torch.dtype] = None, + zero_point: Optional[bool] = None, multi_select: bool = False, # return all valid kernels adapter: Optional[Adapter] = None, ) -> Union[Type[BaseQuantLinear], List[Type[BaseQuantLinear]]]: @@ -232,6 +309,17 @@ def select_quant_linear( if device is None: device = DEVICE.CUDA + if isinstance(format, str): + format = FORMAT(format.lower()) + if isinstance(quant_method, str): + quant_method = METHOD(quant_method.lower()) + + supported_formats = SUPPORTS_BACKEND_MAP.get(quant_method) + if supported_formats is None: + raise ValueError(f"Unsupported quantization method: `{quant_method}`") + if format not in supported_formats: + raise ValueError(f"Unsupported format: `{format}` for quantization method `{quant_method}`") + backend = BACKEND.AUTO if backend is None else backend trainable = backend == BACKEND.AUTO_TRAINABLE @@ -250,6 +338,8 @@ def select_quant_linear( desc_act=desc_act, sym=sym, pack_dtype=pack_dtype, + dtype=dtype, + zero_point=zero_point, dynamic=dynamic, device=device, trainable=trainable, @@ -335,7 +425,18 @@ def select_quant_linear( else: qlinear = TorchQuantLinear - validate, err = qlinear.validate(bits=bits, group_size=group_size, 
desc_act=desc_act, sym=sym, pack_dtype=pack_dtype, dynamic=dynamic, device=device, trainable=trainable) + validate, err = qlinear.validate( + bits=bits, + group_size=group_size, + desc_act=desc_act, + sym=sym, + pack_dtype=pack_dtype, + dtype=dtype, + zero_point=zero_point, + dynamic=dynamic, + device=device, + trainable=trainable, + ) log.info(f"{'Packing' if pack else ''} Kernel: selected: `{qlinear.__name__}`") if not validate: From 77ff529c116466d6c6a4bf397d0e4652d8830fae Mon Sep 17 00:00:00 2001 From: Qubitium Date: Sat, 22 Nov 2025 07:53:23 +0000 Subject: [PATCH 03/24] machete is not ready --- tests/test_awq.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_awq.py b/tests/test_awq.py index d1d3cc7f2..083bee402 100644 --- a/tests/test_awq.py +++ b/tests/test_awq.py @@ -126,7 +126,7 @@ def tearDownClass(cls): @parameterized.expand([ (FORMAT.GEMM, BACKEND.GEMM, 128), - (FORMAT.GEMM, BACKEND.MACHETE, 128), + #(FORMAT.GEMM, BACKEND.MACHETE, 128), (FORMAT.GEMM, BACKEND.MARLIN, 128), (FORMAT.GEMV, BACKEND.GEMV, 128), (FORMAT.GEMV_FAST, BACKEND.GEMV_FAST, 128), From f45b57cb538ead48c5d6a18e05761c7118f0d2c3 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Sat, 22 Nov 2025 09:32:12 +0000 Subject: [PATCH 04/24] remove unused post_init --- .../{awq_exllama.py => exllama_awq.py} | 0 .../{awq_exllamav2.py => exllamav2_awq.py} | 1 - .../awq/modules/linear/__init__.py | 4 +- .../awq/modules/linear/exllama.py | 7 --- .../awq/modules/linear/exllamav2.py | 45 ------------------- 5 files changed, 1 insertion(+), 56 deletions(-) rename gptqmodel/nn_modules/qlinear/{awq_exllama.py => exllama_awq.py} (100%) rename gptqmodel/nn_modules/qlinear/{awq_exllamav2.py => exllamav2_awq.py} (98%) diff --git a/gptqmodel/nn_modules/qlinear/awq_exllama.py b/gptqmodel/nn_modules/qlinear/exllama_awq.py similarity index 100% rename from gptqmodel/nn_modules/qlinear/awq_exllama.py rename to gptqmodel/nn_modules/qlinear/exllama_awq.py diff --git a/gptqmodel/nn_modules/qlinear/awq_exllamav2.py b/gptqmodel/nn_modules/qlinear/exllamav2_awq.py similarity index 98% rename from gptqmodel/nn_modules/qlinear/awq_exllamav2.py rename to gptqmodel/nn_modules/qlinear/exllamav2_awq.py index 5233d7e79..21add1a29 100644 --- a/gptqmodel/nn_modules/qlinear/awq_exllamav2.py +++ b/gptqmodel/nn_modules/qlinear/exllamav2_awq.py @@ -118,7 +118,6 @@ def post_init(self, scratch_space: ScratchSpace): def forward(self, x: torch.Tensor): assert self.q_handle is not None, ( "module.post_init() must be called before module.forward(). " - "Use exllamav2_post_init() on the whole model." ) if exlv2_ext is None: raise ModuleNotFoundError("External ExLlamaV2 kernels are not properly installed." 
+ msg) diff --git a/gptqmodel/quantization/awq/modules/linear/__init__.py b/gptqmodel/quantization/awq/modules/linear/__init__.py index 162de045c..4e6637ad7 100644 --- a/gptqmodel/quantization/awq/modules/linear/__init__.py +++ b/gptqmodel/quantization/awq/modules/linear/__init__.py @@ -3,9 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 # Contact: qubitium@modelcloud.ai, x.com/qubitium -from .exllama import WQLinear_Exllama, exllama_post_init -from .exllamav2 import WQLinear_ExllamaV2, exllamav2_post_init from .gemm import WQLinear_GEMM from .gemv import WQLinear_GEMV from .gemv_fast import WQLinear_GEMVFast -from .marlin import WQLinear_Marlin, marlin_post_init +from .marlin import WQLinear_Marlin diff --git a/gptqmodel/quantization/awq/modules/linear/exllama.py b/gptqmodel/quantization/awq/modules/linear/exllama.py index bc9657fbb..97c6f14da 100644 --- a/gptqmodel/quantization/awq/modules/linear/exllama.py +++ b/gptqmodel/quantization/awq/modules/linear/exllama.py @@ -133,10 +133,3 @@ def forward(self, x): return out.view(out_shape) - -def exllama_post_init(model): - for _, submodule in model.named_modules(): - if isinstance(submodule, WQLinear_Exllama): - submodule.post_init() - - return model diff --git a/gptqmodel/quantization/awq/modules/linear/exllamav2.py b/gptqmodel/quantization/awq/modules/linear/exllamav2.py index 295d77350..e62631993 100644 --- a/gptqmodel/quantization/awq/modules/linear/exllamav2.py +++ b/gptqmodel/quantization/awq/modules/linear/exllamav2.py @@ -133,7 +133,6 @@ def scratch_space_fixed(self, max_input_len=2048, max_batch_size=8): def forward(self, x): assert self.q_handle is not None, ( "module.post_init() must be called before module.forward(). " - "Use exllamav2_post_init() on the whole model." ) if exlv2_ext is None: raise ModuleNotFoundError("External ExLlamaV2 kernels are not properly installed." 
+ msg) @@ -160,47 +159,3 @@ def forward(self, x): out.add_(self.bias) return out.view(out_shape) - - -class ScratchSpace: - def __init__(self, scratch_bytes, dev): - self.scratch_bytes = scratch_bytes - self.scratch = torch.empty( - self.scratch_bytes // 2, - dtype=torch.float16, - device=dev, - ) - - def get_slice(self, size_bytes): - size_halfs = next_multiple(size_bytes, 128) // 2 - scratch_slice = self.scratch.narrow(0, 0, size_halfs) - - return scratch_slice - - -def exllamav2_post_init(model, max_input_len: int = 2048, max_batch_size: int = 8): - # we search for the maximum number of bytes required for each device's scratch space - fixed_bytes: Dict[torch.device, int] = {} - for _, submodule in model.named_modules(): - if isinstance(submodule, AwqExllamaV2QuantLinear): - device = submodule.qweight.device - scratch_fixed = submodule.scratch_space_fixed( - max_input_len=max_input_len, max_batch_size=max_batch_size - ) - fixed_bytes[device] = max(fixed_bytes.get(device, 0), scratch_fixed) - - # we allocate a model-persistent scratch space for each device - model.scratch_spaces: Dict[torch.device, ScratchSpace] = {} - for device, scratch_bytes in fixed_bytes.items(): - model.scratch_spaces[device] = ScratchSpace(scratch_bytes, device) - - for _, submodule in model.named_modules(): - if isinstance(submodule, AwqExllamaV2QuantLinear): - device = submodule.qweight.device - submodule.post_init(scratch_space=model.scratch_spaces[device]) - - return model - - -def next_multiple(x, multiple): - return ((x + multiple - 1) // multiple) * multiple From a7f92857d244b746389b77fbe87af04b59629722 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Sat, 22 Nov 2025 09:32:33 +0000 Subject: [PATCH 05/24] rename kernel files --- gptqmodel/looper/awq_processor.py | 8 ++++---- .../qlinear/{awq_gemm.py => gemm_awq.py} | 0 .../qlinear/{awq_gemv.py => gemv_awq.py} | 0 .../{awq_gemv_fast.py => gemv_fast_awq.py} | 0 .../qlinear/{awq_machete.py => machete_awq.py} | 0 .../qlinear/{awq_marlin.py => marlin_awq.py} | 0 .../qlinear/{awq_torch.py => torch_awq.py} | 0 gptqmodel/utils/importer.py | 16 ++++++++-------- gptqmodel/utils/model.py | 2 +- tests/test_awq.py | 10 +++++----- tests/test_awq_moe.py | 2 +- tests/test_awq_torch_kernel.py | 2 +- tests/test_kernel_output_awq.py | 6 +++--- tests/test_torch_fused_awq.py | 2 +- 14 files changed, 24 insertions(+), 24 deletions(-) rename gptqmodel/nn_modules/qlinear/{awq_gemm.py => gemm_awq.py} (100%) rename gptqmodel/nn_modules/qlinear/{awq_gemv.py => gemv_awq.py} (100%) rename gptqmodel/nn_modules/qlinear/{awq_gemv_fast.py => gemv_fast_awq.py} (100%) rename gptqmodel/nn_modules/qlinear/{awq_machete.py => machete_awq.py} (100%) rename gptqmodel/nn_modules/qlinear/{awq_marlin.py => marlin_awq.py} (100%) rename gptqmodel/nn_modules/qlinear/{awq_torch.py => torch_awq.py} (100%) diff --git a/gptqmodel/looper/awq_processor.py b/gptqmodel/looper/awq_processor.py index f365dfdac..0c356ca3d 100644 --- a/gptqmodel/looper/awq_processor.py +++ b/gptqmodel/looper/awq_processor.py @@ -20,10 +20,10 @@ from ..models._const import SUPPORTS_MODULE_TYPES from ..models.writer import (PROCESS_LOG_LAYER, PROCESS_LOG_MODULE, PROCESS_LOG_NAME, PROCESS_LOG_TIME, PROCESS_USED_MEMORY, QUANT_LOG_LOSS, QUANT_LOG_NSAMPLES) -from ..nn_modules.qlinear.awq_gemm import AwqGEMMQuantLinear -from ..nn_modules.qlinear.awq_gemv import AwqGEMVQuantLinear -from ..nn_modules.qlinear.awq_gemv_fast import AwqGEMVFastQuantLinear -from ..nn_modules.qlinear.awq_marlin import AwqMarlinQuantLinear +from 
..nn_modules.qlinear.gemm_awq import AwqGEMMQuantLinear +from ..nn_modules.qlinear.gemv_awq import AwqGEMVQuantLinear +from ..nn_modules.qlinear.gemv_fast_awq import AwqGEMVFastQuantLinear +from ..nn_modules.qlinear.marlin_awq import AwqMarlinQuantLinear from ..quantization.awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV, WQLinear_GEMVFast, WQLinear_Marlin from ..quantization.awq.quantize.scale import apply_clip, apply_scale from ..quantization.awq.utils.module import append_str_prefix, get_op_name, get_op_by_name, set_op_by_name diff --git a/gptqmodel/nn_modules/qlinear/awq_gemm.py b/gptqmodel/nn_modules/qlinear/gemm_awq.py similarity index 100% rename from gptqmodel/nn_modules/qlinear/awq_gemm.py rename to gptqmodel/nn_modules/qlinear/gemm_awq.py diff --git a/gptqmodel/nn_modules/qlinear/awq_gemv.py b/gptqmodel/nn_modules/qlinear/gemv_awq.py similarity index 100% rename from gptqmodel/nn_modules/qlinear/awq_gemv.py rename to gptqmodel/nn_modules/qlinear/gemv_awq.py diff --git a/gptqmodel/nn_modules/qlinear/awq_gemv_fast.py b/gptqmodel/nn_modules/qlinear/gemv_fast_awq.py similarity index 100% rename from gptqmodel/nn_modules/qlinear/awq_gemv_fast.py rename to gptqmodel/nn_modules/qlinear/gemv_fast_awq.py diff --git a/gptqmodel/nn_modules/qlinear/awq_machete.py b/gptqmodel/nn_modules/qlinear/machete_awq.py similarity index 100% rename from gptqmodel/nn_modules/qlinear/awq_machete.py rename to gptqmodel/nn_modules/qlinear/machete_awq.py diff --git a/gptqmodel/nn_modules/qlinear/awq_marlin.py b/gptqmodel/nn_modules/qlinear/marlin_awq.py similarity index 100% rename from gptqmodel/nn_modules/qlinear/awq_marlin.py rename to gptqmodel/nn_modules/qlinear/marlin_awq.py diff --git a/gptqmodel/nn_modules/qlinear/awq_torch.py b/gptqmodel/nn_modules/qlinear/torch_awq.py similarity index 100% rename from gptqmodel/nn_modules/qlinear/awq_torch.py rename to gptqmodel/nn_modules/qlinear/torch_awq.py diff --git a/gptqmodel/utils/importer.py b/gptqmodel/utils/importer.py index 5cab8eca9..cff945c19 100644 --- a/gptqmodel/utils/importer.py +++ b/gptqmodel/utils/importer.py @@ -13,14 +13,14 @@ from ..models._const import DEVICE, normalize_device from ..nn_modules.qlinear import BaseQuantLinear, PackableQuantLinear -from ..nn_modules.qlinear.awq_exllama import AwqExllamaQuantLinear -from ..nn_modules.qlinear.awq_exllamav2 import AwqExllamaV2QuantLinear -from ..nn_modules.qlinear.awq_gemm import AwqGEMMQuantLinear -from ..nn_modules.qlinear.awq_gemv import AwqGEMVQuantLinear -from ..nn_modules.qlinear.awq_gemv_fast import AwqGEMVFastQuantLinear -from ..nn_modules.qlinear.awq_machete import AwqMacheteQuantLinear -from ..nn_modules.qlinear.awq_marlin import AwqMarlinQuantLinear -from ..nn_modules.qlinear.awq_torch import AwqTorchQuantLinear +from ..nn_modules.qlinear.exllama_awq import AwqExllamaQuantLinear +from ..nn_modules.qlinear.exllamav2_awq import AwqExllamaV2QuantLinear +from ..nn_modules.qlinear.gemm_awq import AwqGEMMQuantLinear +from ..nn_modules.qlinear.gemv_awq import AwqGEMVQuantLinear +from ..nn_modules.qlinear.gemv_fast_awq import AwqGEMVFastQuantLinear +from ..nn_modules.qlinear.machete_awq import AwqMacheteQuantLinear +from ..nn_modules.qlinear.marlin_awq import AwqMarlinQuantLinear +from ..nn_modules.qlinear.torch_awq import AwqTorchQuantLinear from ..nn_modules.qlinear.bitblas import BitBLASQuantLinear from ..nn_modules.qlinear.exllama import ExllamaQuantLinear from ..nn_modules.qlinear.exllama_eora import ExllamaEoraQuantLinear diff --git a/gptqmodel/utils/model.py 
b/gptqmodel/utils/model.py index 35213c974..b8fd5608f 100644 --- a/gptqmodel/utils/model.py +++ b/gptqmodel/utils/model.py @@ -46,7 +46,7 @@ SUPPORTS_MODULE_TYPES, ) from ..nn_modules.qlinear import BaseQuantLinear -from ..nn_modules.qlinear.awq_exllamav2 import AwqExllamaV2QuantLinear +from ..nn_modules.qlinear.exllamav2_awq import AwqExllamaV2QuantLinear from ..nn_modules.qlinear.exllama import ExllamaQuantLinear from ..nn_modules.qlinear.exllamav2 import ExllamaV2QuantLinear from ..quantization import FORMAT, QuantizeConfig diff --git a/tests/test_awq.py b/tests/test_awq.py index 083bee402..8edbb1fd9 100644 --- a/tests/test_awq.py +++ b/tests/test_awq.py @@ -15,11 +15,11 @@ from parameterized import parameterized from transformers import AutoTokenizer -from gptqmodel.nn_modules.qlinear.awq_gemm import AwqGEMMQuantLinear -from gptqmodel.nn_modules.qlinear.awq_gemv import AwqGEMVQuantLinear -from gptqmodel.nn_modules.qlinear.awq_gemv_fast import AwqGEMVFastQuantLinear -from gptqmodel.nn_modules.qlinear.awq_machete import AwqMacheteQuantLinear -from gptqmodel.nn_modules.qlinear.awq_marlin import AwqMarlinQuantLinear +from gptqmodel.nn_modules.qlinear.gemm_awq import AwqGEMMQuantLinear +from gptqmodel.nn_modules.qlinear.gemv_awq import AwqGEMVQuantLinear +from gptqmodel.nn_modules.qlinear.gemv_fast_awq import AwqGEMVFastQuantLinear +from gptqmodel.nn_modules.qlinear.machete_awq import AwqMacheteQuantLinear +from gptqmodel.nn_modules.qlinear.marlin_awq import AwqMarlinQuantLinear from gptqmodel.quantization import FORMAT, METHOD, QUANT_CONFIG_FILENAME from gptqmodel.utils.machete import _validate_machete_device_support, machete_import_exception diff --git a/tests/test_awq_moe.py b/tests/test_awq_moe.py index 9f4371758..b11998c5c 100644 --- a/tests/test_awq_moe.py +++ b/tests/test_awq_moe.py @@ -14,7 +14,7 @@ from parameterized import parameterized from transformers import AutoTokenizer -from gptqmodel.nn_modules.qlinear.awq_gemm import AwqGEMMQuantLinear +from gptqmodel.nn_modules.qlinear.gemm_awq import AwqGEMMQuantLinear from gptqmodel.quantization import FORMAT, METHOD, QUANT_CONFIG_FILENAME from gptqmodel.utils.torch import torch_empty_cache diff --git a/tests/test_awq_torch_kernel.py b/tests/test_awq_torch_kernel.py index a91dd88ef..ba387612c 100644 --- a/tests/test_awq_torch_kernel.py +++ b/tests/test_awq_torch_kernel.py @@ -6,7 +6,7 @@ import pytest import torch -from gptqmodel.nn_modules.qlinear.awq_torch import AwqTorchQuantLinear +from gptqmodel.nn_modules.qlinear.torch_awq import AwqTorchQuantLinear from gptqmodel.quantization import FORMAT, METHOD from gptqmodel.quantization.awq.utils.packing_utils import dequantize_gemm from gptqmodel.utils.backend import BACKEND diff --git a/tests/test_kernel_output_awq.py b/tests/test_kernel_output_awq.py index 22186f6d4..c1c4d154b 100644 --- a/tests/test_kernel_output_awq.py +++ b/tests/test_kernel_output_awq.py @@ -16,12 +16,12 @@ from tabulate import tabulate from gptqmodel import BACKEND -from gptqmodel.nn_modules.qlinear.awq_gemm import AwqGEMMQuantLinear -from gptqmodel.nn_modules.qlinear.awq_marlin import ( +from gptqmodel.nn_modules.qlinear.gemm_awq import AwqGEMMQuantLinear +from gptqmodel.nn_modules.qlinear.marlin_awq import ( AwqMarlinQuantLinear, marlin_import_exception, ) -from gptqmodel.nn_modules.qlinear.awq_torch import AwqTorchQuantLinear +from gptqmodel.nn_modules.qlinear.torch_awq import AwqTorchQuantLinear from gptqmodel.nn_modules.qlinear.torch_fused_awq import TorchFusedAwqQuantLinear from gptqmodel.utils.marlin import 
marlin_make_workspace_new diff --git a/tests/test_torch_fused_awq.py b/tests/test_torch_fused_awq.py index 35f102838..d3644e314 100644 --- a/tests/test_torch_fused_awq.py +++ b/tests/test_torch_fused_awq.py @@ -13,7 +13,7 @@ from safetensors import safe_open from tabulate import tabulate -from gptqmodel.nn_modules.qlinear.awq_torch import AwqTorchQuantLinear +from gptqmodel.nn_modules.qlinear.torch_awq import AwqTorchQuantLinear from gptqmodel.nn_modules.qlinear.torch_fused_awq import TorchFusedAwqQuantLinear from gptqmodel.utils.torch import TORCH_HAS_FUSED_OPS From 27617447260b5884f6a4ec70133b2bd89e425da4 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Sun, 23 Nov 2025 01:30:28 +0000 Subject: [PATCH 06/24] use FORMAT enum --- gptqmodel/looper/awq_processor.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/gptqmodel/looper/awq_processor.py b/gptqmodel/looper/awq_processor.py index 0c356ca3d..7fc58a046 100644 --- a/gptqmodel/looper/awq_processor.py +++ b/gptqmodel/looper/awq_processor.py @@ -101,7 +101,7 @@ def __init__( # This argument avoids real quantization by only applying the scales without quantizing down to FP16. self.export_compatible = False - self.version = qcfg.format + self.format = qcfg.format # Whether to scale using both w/x or just x. self.duo_scaling = True @@ -1113,23 +1113,23 @@ def _apply_quant(self, module, named_linears: Dict[str, NamedModule], start_time linear_layer.weight.data = wq - if self.version == "gemm": + if self.format == FORMAT.GEMM: scales = scales.t().contiguous() if zeros is not None: zeros = zeros.t().contiguous() q_linear_module = WQLinear_GEMM - elif self.version == "gemv": + elif self.format == FORMAT.GEMV: q_linear_module = WQLinear_GEMV - elif self.version == "marlin": + elif self.format == FORMAT.MARLIN: q_linear_module = WQLinear_Marlin - elif self.version == "gemv_fast": + elif self.format == FORMAT.GEMV_FAST: q_linear_module = WQLinear_GEMVFast else: - raise ValueError(f"Unknown version {self.version}") + raise ValueError(f"Unknown version {self.format}") q_linear = q_linear_module.from_linear( linear=linear_layer, From 942332b8bdbdad67bc13465166ee8ca2e8d43282 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Sun, 23 Nov 2025 03:53:55 +0000 Subject: [PATCH 07/24] cleanup awq gemm/merge code --- gptqmodel/looper/awq_processor.py | 2 +- gptqmodel/nn_modules/qlinear/gemm_awq.py | 101 +++++- .../awq/modules/linear/__init__.py | 1 - .../quantization/awq/modules/linear/gemm.py | 306 ------------------ 4 files changed, 100 insertions(+), 310 deletions(-) delete mode 100644 gptqmodel/quantization/awq/modules/linear/gemm.py diff --git a/gptqmodel/looper/awq_processor.py b/gptqmodel/looper/awq_processor.py index 7fc58a046..0aabeb91e 100644 --- a/gptqmodel/looper/awq_processor.py +++ b/gptqmodel/looper/awq_processor.py @@ -24,7 +24,7 @@ from ..nn_modules.qlinear.gemv_awq import AwqGEMVQuantLinear from ..nn_modules.qlinear.gemv_fast_awq import AwqGEMVFastQuantLinear from ..nn_modules.qlinear.marlin_awq import AwqMarlinQuantLinear -from ..quantization.awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV, WQLinear_GEMVFast, WQLinear_Marlin +from ..quantization.awq.modules.linear import WQLinear_GEMV, WQLinear_GEMVFast, WQLinear_Marlin from ..quantization.awq.quantize.scale import apply_clip, apply_scale from ..quantization.awq.utils.module import append_str_prefix, get_op_name, get_op_by_name, set_op_by_name from ..quantization.awq.utils.utils import get_best_device diff --git a/gptqmodel/nn_modules/qlinear/gemm_awq.py 
b/gptqmodel/nn_modules/qlinear/gemm_awq.py index e2b95100e..645277282 100644 --- a/gptqmodel/nn_modules/qlinear/gemm_awq.py +++ b/gptqmodel/nn_modules/qlinear/gemm_awq.py @@ -8,13 +8,17 @@ from ...adapter.adapter import Adapter, Lora from ...models._const import DEVICE, PLATFORM from ...nn_modules.qlinear import AWQuantLinear -from ...quantization.awq.modules.linear.gemm import WQLinearMMFunction from ...utils.backend import BACKEND from ...utils.logger import setup_logger +from ...quantization.awq.utils.module import try_import + log = setup_logger() +awq_ext, msg = try_import("gptqmodel_awq_kernels") + + class AwqGEMMQuantLinear(AWQuantLinear): SUPPORTS_BITS = [4] SUPPORTS_GROUP_SIZE = [-1, 16, 32, 64, 128] @@ -122,4 +126,97 @@ def forward(self, x: torch.Tensor): return out.reshape(out_shape) -__all__ = ["AwqGEMMQuantLinear"] +# Adapted from https://github.com/compressa-ai/AutoAWQ/tree/dev +class WQLinearMMFunction(torch.autograd.Function): + @staticmethod + # ctx is the first argument to forward + def forward( + ctx, + x, + qweight, + qzeros, + scales, + w_bit=4, + group_size=128, + bias=None, + out_features=0, + ): + # The forward pass can use ctx. + ctx.save_for_backward(x, qweight, qzeros, scales, bias) + ctx.out_features = out_features + + out_shape = x.shape[:-1] + (out_features,) + x = x.to(torch.float16) + if x.shape[0] == 0: + return torch.zeros(out_shape, dtype=x.dtype, device=x.device) + + if awq_ext is not None: + FP16_MATMUL_HEURISTIC_CONDITION = x.shape[0] * x.shape[1] >= 1024 + + if FP16_MATMUL_HEURISTIC_CONDITION: + out = awq_ext.dequantize_weights_cuda( + qweight, scales, qzeros, 0, 0, 0, False + ) + out = torch.matmul(x, out) + else: + out = awq_ext.gemm_forward_cuda( + x.reshape(-1, x.shape[-1]), qweight, scales, qzeros, 8 + ) + + elif TRITON_AVAILABLE: + FP16_MATMUL_HEURISTIC_CONDITION = x.shape[0] * x.shape[1] >= 1024 + + if FP16_MATMUL_HEURISTIC_CONDITION: + out = awq_dequantize_triton(qweight, scales, qzeros) + out = torch.matmul(x, out.to(x.dtype)) + else: + out = awq_gemm_triton( + x.reshape(-1, x.shape[-1]), qweight, scales, qzeros, split_k_iters=8, + ) + + else: + global user_has_been_warned + if not user_has_been_warned: + warnings.warn("Using naive (slow) implementation." + msg) + user_has_been_warned = True + out = dequantize_gemm(qweight, qzeros, scales, w_bit, group_size) + out = torch.matmul(x, out) + + out = out + bias if bias is not None else out + out = out.reshape(out_shape) + + # always want 3D tensor if tensor is 2D + if len(out.shape) == 2: + out = out.unsqueeze(0) + + return out + + @staticmethod + def backward(ctx, grad_output): + input, qweight, qzeros, scales, bias = ctx.saved_tensors + + if awq_ext is None and not TRITON_AVAILABLE: + raise ValueError( + "either triton or autoawq-kernels is needed to be installed to use `.backward()`. Make sure to install the auto-awq kernels" + " by following the installation guides in https://github.com/casper-hansen/AutoAWQ_kernels" + ) + + # Cast to correct dtype for mixed precision training + if awq_ext is not None: + weights = awq_ext.dequantize_weights_cuda( + qweight, scales, qzeros, 1, 0, 0, False + ).to(grad_output.dtype) + else: + weights = awq_dequantize_triton( + qweight, scales, qzeros + ).to(grad_output.dtype) + + if ctx.needs_input_grad[0]: + # 3D matmul using torch.bmm: https://pytorch.org/docs/stable/generated/torch.bmm.html#torch.bmm + # to propagate gradient across all batch sizes. 
+ batch_size = grad_output.shape[0] + grad_input = grad_output.bmm(weights.transpose(0, 1).unsqueeze(0).repeat(batch_size, 1, 1)) + + return grad_input, None, None, None, None, None, None, None + +__all__ = ["AwqGEMMQuantLinear", "WQLinearMMFunction"] diff --git a/gptqmodel/quantization/awq/modules/linear/__init__.py b/gptqmodel/quantization/awq/modules/linear/__init__.py index 4e6637ad7..8f114b9f0 100644 --- a/gptqmodel/quantization/awq/modules/linear/__init__.py +++ b/gptqmodel/quantization/awq/modules/linear/__init__.py @@ -3,7 +3,6 @@ # SPDX-License-Identifier: Apache-2.0 # Contact: qubitium@modelcloud.ai, x.com/qubitium -from .gemm import WQLinear_GEMM from .gemv import WQLinear_GEMV from .gemv_fast import WQLinear_GEMVFast from .marlin import WQLinear_Marlin diff --git a/gptqmodel/quantization/awq/modules/linear/gemm.py b/gptqmodel/quantization/awq/modules/linear/gemm.py deleted file mode 100644 index ad8e87825..000000000 --- a/gptqmodel/quantization/awq/modules/linear/gemm.py +++ /dev/null @@ -1,306 +0,0 @@ -# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai -# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai -# SPDX-License-Identifier: Apache-2.0 -# Contact: qubitium@modelcloud.ai, x.com/qubitium - -import warnings - -import torch -import torch.nn as nn -from torch.autograd import Function - -from gptqmodel.quantization.awq.utils.module import try_import -from gptqmodel.quantization.awq.utils.packing_utils import dequantize_gemm -from gptqmodel.quantization.awq.utils.utils import get_best_device - - -# NOTE: We check if awq_ext or triton is available. awq_ext will be preferred if both are installed. - -awq_ext, msg = try_import("gptqmodel_awq_kernels") -user_has_been_warned = False - -try: - from gptqmodel.quantization.awq.modules.triton.gemm import awq_dequantize_triton, awq_gemm_triton - - # covers CUDA, ROCm and XPU. If we can import triton, then we can use it. - TRITON_AVAILABLE = True - -except ImportError: - TRITON_AVAILABLE = False - -# Adapted from https://github.com/compressa-ai/AutoAWQ/tree/dev -class WQLinearMMFunction(Function): - @staticmethod - # ctx is the first argument to forward - def forward( - ctx, - x, - qweight, - qzeros, - scales, - w_bit=4, - group_size=128, - bias=None, - out_features=0, - ): - # The forward pass can use ctx. - ctx.save_for_backward(x, qweight, qzeros, scales, bias) - ctx.out_features = out_features - - out_shape = x.shape[:-1] + (out_features,) - x = x.to(torch.float16) - if x.shape[0] == 0: - return torch.zeros(out_shape, dtype=x.dtype, device=x.device) - - if awq_ext is not None: - FP16_MATMUL_HEURISTIC_CONDITION = x.shape[0] * x.shape[1] >= 1024 - - if FP16_MATMUL_HEURISTIC_CONDITION: - out = awq_ext.dequantize_weights_cuda( - qweight, scales, qzeros, 0, 0, 0, False - ) - out = torch.matmul(x, out) - else: - out = awq_ext.gemm_forward_cuda( - x.reshape(-1, x.shape[-1]), qweight, scales, qzeros, 8 - ) - - elif TRITON_AVAILABLE: - FP16_MATMUL_HEURISTIC_CONDITION = x.shape[0] * x.shape[1] >= 1024 - - if FP16_MATMUL_HEURISTIC_CONDITION: - out = awq_dequantize_triton(qweight, scales, qzeros) - out = torch.matmul(x, out.to(x.dtype)) - else: - out = awq_gemm_triton( - x.reshape(-1, x.shape[-1]), qweight, scales, qzeros, split_k_iters=8, - ) - - else: - global user_has_been_warned - if not user_has_been_warned: - warnings.warn("Using naive (slow) implementation." 
+ msg) - user_has_been_warned = True - out = dequantize_gemm(qweight, qzeros, scales, w_bit, group_size) - out = torch.matmul(x, out) - - out = out + bias if bias is not None else out - out = out.reshape(out_shape) - - # always want 3D tensor if tensor is 2D - if len(out.shape) == 2: - out = out.unsqueeze(0) - - return out - - @staticmethod - def backward(ctx, grad_output): - input, qweight, qzeros, scales, bias = ctx.saved_tensors - - if awq_ext is None and not TRITON_AVAILABLE: - raise ValueError( - "either triton or autoawq-kernels is needed to be installed to use `.backward()`. Make sure to install the auto-awq kernels" - " by following the installation guides in https://github.com/casper-hansen/AutoAWQ_kernels" - ) - - # Cast to correct dtype for mixed precision training - if awq_ext is not None: - weights = awq_ext.dequantize_weights_cuda( - qweight, scales, qzeros, 1, 0, 0, False - ).to(grad_output.dtype) - else: - weights = awq_dequantize_triton( - qweight, scales, qzeros - ).to(grad_output.dtype) - - if ctx.needs_input_grad[0]: - # 3D matmul using torch.bmm: https://pytorch.org/docs/stable/generated/torch.bmm.html#torch.bmm - # to propagate gradient across all batch sizes. - batch_size = grad_output.shape[0] - grad_input = grad_output.bmm(weights.transpose(0, 1).unsqueeze(0).repeat(batch_size, 1, 1)) - - return grad_input, None, None, None, None, None, None, None - -class WQLinear_GEMM(nn.Module): - def __init__( - self, w_bit, group_size, in_features, out_features, bias, dev, training=False - ): - super().__init__() - - if w_bit not in [4]: - raise NotImplementedError("Only 4-bit are supported for now.") - - self.in_features = in_features - self.out_features = out_features - self.w_bit = w_bit - self.group_size = group_size if group_size != -1 else in_features - self.training = training - - # quick sanity check (make sure aligment) - assert self.in_features % self.group_size == 0 - assert out_features % (32 // self.w_bit) == 0 - - self.register_buffer( - "qweight", - torch.zeros( - (in_features, out_features // (32 // self.w_bit)), - dtype=torch.int32, - device=dev, - ), - ) - self.register_buffer( - "qzeros", - torch.zeros( - (in_features // self.group_size, out_features // (32 // self.w_bit)), - dtype=torch.int32, - device=dev, - ), - ) - self.register_buffer( - "scales", - torch.zeros( - (in_features // self.group_size, out_features), - dtype=torch.float16, - device=dev, - ), - ) - if bias: - self.register_buffer( - "bias", - torch.zeros( - (out_features), - dtype=torch.float16, - device=dev, - ), - ) - else: - self.bias = None - - @classmethod - def from_linear( - cls, linear, w_bit, group_size, init_only=False, scales=None, zeros=None - ): - awq_linear = cls( - w_bit, - group_size, - linear.in_features, - linear.out_features, - linear.bias is not None, - linear.weight.device, - ) - if init_only: # just prepare for loading sd - return awq_linear - - # need scales and zeros info for real quantization - assert scales is not None and zeros is not None - scale_zeros = zeros * scales - - awq_linear.scales = scales.clone().half() - if linear.bias is not None: - awq_linear.bias = linear.bias.clone().half() - - pack_num = 32 // awq_linear.w_bit - - intweight = [] - for idx in range(awq_linear.in_features): - intweight.append( - torch.round( - (linear.weight.data[:, idx] + scale_zeros[idx // group_size]) - / awq_linear.scales[idx // group_size] - ).to(torch.int)[:, None] - ) - intweight = torch.cat(intweight, dim=1) - intweight = intweight.t().contiguous() - intweight = 
intweight.to(dtype=torch.int32) - - best_device = get_best_device() - - # Avoid: The operator 'aten::__lshift__.Scalar' is not currently implemented for the MPS device - if "mps" in best_device: - intweight = intweight.to("cpu") - - qweight = torch.zeros( - (intweight.shape[0], intweight.shape[1] // 32 * awq_linear.w_bit), - dtype=torch.int32, - device=intweight.device, - ) - - for col in range(intweight.shape[1] // pack_num): - if awq_linear.w_bit == 4: - order_map = [0, 2, 4, 6, 1, 3, 5, 7] - else: - raise NotImplementedError("Only 4-bit are supported for now.") - for i in range(pack_num): - qweight_col = intweight[:, col * pack_num + order_map[i]] - qweight[:, col] |= qweight_col << (i * awq_linear.w_bit) - awq_linear.qweight = qweight - - zeros = zeros.to(dtype=torch.int32, device=best_device) - - if "mps" in best_device: - zeros = zeros.to("cpu") - - qzeros = torch.zeros( - (zeros.shape[0], zeros.shape[1] // 32 * awq_linear.w_bit), - dtype=torch.int32, - device=zeros.device, - ) - - for col in range(zeros.shape[1] // pack_num): - if awq_linear.w_bit == 4: - order_map = [0, 2, 4, 6, 1, 3, 5, 7] - else: - raise NotImplementedError("Only 4-bit are supported for now.") - for i in range(pack_num): - qzero_col = zeros[:, col * pack_num + order_map[i]] - qzeros[:, col] |= qzero_col << (i * awq_linear.w_bit) - awq_linear.qzeros = qzeros - - return awq_linear - - def forward(self, x): - out_shape = x.shape[:-1] + (self.out_features,) - - input_dtype = x.dtype - if input_dtype != torch.float16: - x = x.half() - - if self.training: - out = WQLinearMMFunction.apply( - x, - self.qweight, - self.qzeros, - self.scales, - self.w_bit, - self.group_size, - self.bias, - self.out_features, - ) - else: - with torch.inference_mode(): - out = WQLinearMMFunction.apply( - x, - self.qweight, - self.qzeros, - self.scales, - self.w_bit, - self.group_size, - self.bias, - self.out_features, - ) - - if input_dtype != torch.float16: - out = out.to(dtype=input_dtype) - - return out.reshape(out_shape) - - def extra_repr(self) -> str: - return ( - "in_features={}, out_features={}, bias={}, w_bit={}, group_size={}".format( - self.in_features, - self.out_features, - self.bias is not None, - self.w_bit, - self.group_size, - ) - ) From 1657ba47a95c919abb722f694192eed931b17878 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Sun, 23 Nov 2025 03:54:40 +0000 Subject: [PATCH 08/24] format --- gptqmodel/models/base.py | 12 ++++++------ gptqmodel/nn_modules/qlinear/gemm_awq.py | 3 +-- .../quantization/awq/modules/linear/exllamav2.py | 3 +-- gptqmodel/utils/importer.py | 14 +++++++------- gptqmodel/utils/model.py | 2 +- tests/models/test_qwen2_5_vl.py | 1 + tests/test_auto_detect_module_tree.py | 3 ++- 7 files changed, 19 insertions(+), 19 deletions(-) diff --git a/gptqmodel/models/base.py b/gptqmodel/models/base.py index f6f242c5f..a0d0dfde4 100644 --- a/gptqmodel/models/base.py +++ b/gptqmodel/models/base.py @@ -238,11 +238,11 @@ def __init__( if type(self).module_tree is None: type(self).module_tree = self._auto_detect_module_tree(model, quant_method) - + # If module_tree is still None after auto-detection, raise an error indicating unsupported model type if type(self).module_tree is None: raise ValueError(f"Unsupport model_type {model.config.model_type}, and failed to auto-detect module tree for model {model}") - + # record configuration early so model lifecycle hooks can rely on them self.compiled = False # set to True while compile() is triggered successfully @@ -1690,7 +1690,7 @@ def _get(path): "blocks", "model.blocks", ] - 
+ chosen = None for c in candidates: m = _get(c) @@ -1700,7 +1700,7 @@ def _get(path): break if chosen is None: - log.warn("Module Tree AutoCompat: All candidate paths invalid, return None") + log.warn("Module Tree AutoCompat: All candidate paths invalid, return None") return None layer0 = _get(chosen)[0] @@ -1715,7 +1715,7 @@ def _linear_names(module): if len(all_linear)>0: log.warn(f"Module Tree AutoCompat: found {len(all_linear)} Linear/Conv modules in {type(layer0).__name__}: {all_linear}") else: - log.warn(f"Module Tree AutoCompat: No Linear/Conv names in layer0, return None") + log.warn("Module Tree AutoCompat: No Linear/Conv names in layer0, return None") return None mapping = {} @@ -1732,7 +1732,7 @@ def _leaf_tokens(prefix): return tuple(x.split(".")[-1] for x in all_linear if x.startswith(f"{prefix}.")) possible_parent = ["attn", "attention", "self_attn", "mlp", "ffn", "feed", "dense"] - + found_parents = _find_parents(layer0, possible_parent) for p in found_parents: diff --git a/gptqmodel/nn_modules/qlinear/gemm_awq.py b/gptqmodel/nn_modules/qlinear/gemm_awq.py index 645277282..67c7a76e6 100644 --- a/gptqmodel/nn_modules/qlinear/gemm_awq.py +++ b/gptqmodel/nn_modules/qlinear/gemm_awq.py @@ -8,10 +8,9 @@ from ...adapter.adapter import Adapter, Lora from ...models._const import DEVICE, PLATFORM from ...nn_modules.qlinear import AWQuantLinear +from ...quantization.awq.utils.module import try_import from ...utils.backend import BACKEND from ...utils.logger import setup_logger -from ...quantization.awq.utils.module import try_import - log = setup_logger() diff --git a/gptqmodel/quantization/awq/modules/linear/exllamav2.py b/gptqmodel/quantization/awq/modules/linear/exllamav2.py index e62631993..a529d3876 100644 --- a/gptqmodel/quantization/awq/modules/linear/exllamav2.py +++ b/gptqmodel/quantization/awq/modules/linear/exllamav2.py @@ -3,14 +3,13 @@ # SPDX-License-Identifier: Apache-2.0 # Contact: qubitium@modelcloud.ai, x.com/qubitium -from typing import Dict import torch import torch.nn as nn from gptqmodel.quantization.awq.utils.module import try_import from gptqmodel.quantization.awq.utils.packing_utils import unpack_reorder_pack -from gptqmodel.nn_modules.qlinear.awq_exllamav2 import AwqExllamaV2QuantLinear + exlv2_ext, msg = try_import("gptqmodel_exlv2_kernels") diff --git a/gptqmodel/utils/importer.py b/gptqmodel/utils/importer.py index cff945c19..7b372eb7e 100644 --- a/gptqmodel/utils/importer.py +++ b/gptqmodel/utils/importer.py @@ -13,22 +13,22 @@ from ..models._const import DEVICE, normalize_device from ..nn_modules.qlinear import BaseQuantLinear, PackableQuantLinear +from ..nn_modules.qlinear.bitblas import BitBLASQuantLinear +from ..nn_modules.qlinear.exllama import ExllamaQuantLinear from ..nn_modules.qlinear.exllama_awq import AwqExllamaQuantLinear +from ..nn_modules.qlinear.exllama_eora import ExllamaEoraQuantLinear +from ..nn_modules.qlinear.exllamav2 import ExllamaV2QuantLinear from ..nn_modules.qlinear.exllamav2_awq import AwqExllamaV2QuantLinear from ..nn_modules.qlinear.gemm_awq import AwqGEMMQuantLinear from ..nn_modules.qlinear.gemv_awq import AwqGEMVQuantLinear from ..nn_modules.qlinear.gemv_fast_awq import AwqGEMVFastQuantLinear -from ..nn_modules.qlinear.machete_awq import AwqMacheteQuantLinear -from ..nn_modules.qlinear.marlin_awq import AwqMarlinQuantLinear -from ..nn_modules.qlinear.torch_awq import AwqTorchQuantLinear -from ..nn_modules.qlinear.bitblas import BitBLASQuantLinear -from ..nn_modules.qlinear.exllama import ExllamaQuantLinear -from 
..nn_modules.qlinear.exllama_eora import ExllamaEoraQuantLinear -from ..nn_modules.qlinear.exllamav2 import ExllamaV2QuantLinear from ..nn_modules.qlinear.machete import MacheteQuantLinear +from ..nn_modules.qlinear.machete_awq import AwqMacheteQuantLinear from ..nn_modules.qlinear.marlin import MarlinQuantLinear +from ..nn_modules.qlinear.marlin_awq import AwqMarlinQuantLinear from ..nn_modules.qlinear.qqq import QQQQuantLinear from ..nn_modules.qlinear.torch import TorchQuantLinear +from ..nn_modules.qlinear.torch_awq import AwqTorchQuantLinear from ..nn_modules.qlinear.torch_fused import TorchFusedQuantLinear from ..nn_modules.qlinear.torch_fused_awq import TorchFusedAwqQuantLinear from ..nn_modules.qlinear.tritonv2 import TRITON_AVAILABLE, TRITON_INSTALL_HINT, TritonV2QuantLinear diff --git a/gptqmodel/utils/model.py b/gptqmodel/utils/model.py index b8fd5608f..3d8343409 100644 --- a/gptqmodel/utils/model.py +++ b/gptqmodel/utils/model.py @@ -46,9 +46,9 @@ SUPPORTS_MODULE_TYPES, ) from ..nn_modules.qlinear import BaseQuantLinear -from ..nn_modules.qlinear.exllamav2_awq import AwqExllamaV2QuantLinear from ..nn_modules.qlinear.exllama import ExllamaQuantLinear from ..nn_modules.qlinear.exllamav2 import ExllamaV2QuantLinear +from ..nn_modules.qlinear.exllamav2_awq import AwqExllamaV2QuantLinear from ..quantization import FORMAT, QuantizeConfig from ..quantization.config import FORMAT_FIELD_CHECKPOINT, METHOD, dynamic_get from . import has_gil_disabled diff --git a/tests/models/test_qwen2_5_vl.py b/tests/models/test_qwen2_5_vl.py index a6e89a317..09609c796 100644 --- a/tests/models/test_qwen2_5_vl.py +++ b/tests/models/test_qwen2_5_vl.py @@ -8,6 +8,7 @@ from gptqmodel.models.definitions.qwen2_5_vl import Qwen2_5_VLQModel from gptqmodel.utils.eval import EVAL + class TestQwen2_5_VL(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/Qwen2.5-VL-3B-Instruct" EVAL_TASKS = { diff --git a/tests/test_auto_detect_module_tree.py b/tests/test_auto_detect_module_tree.py index 6d9b2a629..c7dd04403 100644 --- a/tests/test_auto_detect_module_tree.py +++ b/tests/test_auto_detect_module_tree.py @@ -1,4 +1,5 @@ import unittest + import torch.nn as nn from gptqmodel.models.base import BaseQModel @@ -41,4 +42,4 @@ def test_layers_with_parents(self): self.assertIn("self_attn", mapping) self.assertIn("mlp", mapping) self.assertSetEqual(set(mapping["self_attn"]), {"q_proj", "k_proj"}) - self.assertSetEqual(set(mapping["mlp"]), {"fc1", "fc2"}) \ No newline at end of file + self.assertSetEqual(set(mapping["mlp"]), {"fc1", "fc2"}) From 47822d6fb5d00127254f2dadc823e32df0127ccd Mon Sep 17 00:00:00 2001 From: Qubitium Date: Sun, 23 Nov 2025 04:43:46 +0000 Subject: [PATCH 09/24] cleanup --- gptqmodel/nn_modules/qlinear/gemm_awq.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/gptqmodel/nn_modules/qlinear/gemm_awq.py b/gptqmodel/nn_modules/qlinear/gemm_awq.py index 67c7a76e6..371c51573 100644 --- a/gptqmodel/nn_modules/qlinear/gemm_awq.py +++ b/gptqmodel/nn_modules/qlinear/gemm_awq.py @@ -7,7 +7,7 @@ from ...adapter.adapter import Adapter, Lora from ...models._const import DEVICE, PLATFORM -from ...nn_modules.qlinear import AWQuantLinear +from ...nn_modules.qlinear import AWQuantLinear, tritonv2 from ...quantization.awq.utils.module import try_import from ...utils.backend import BACKEND from ...utils.logger import setup_logger @@ -16,6 +16,7 @@ log = setup_logger() awq_ext, msg = try_import("gptqmodel_awq_kernels") +user_has_been_warned = False class AwqGEMMQuantLinear(AWQuantLinear): @@ 
-162,7 +163,7 @@ def forward( x.reshape(-1, x.shape[-1]), qweight, scales, qzeros, 8 ) - elif TRITON_AVAILABLE: + elif tritonv2.TRITON_AVAILABLE: FP16_MATMUL_HEURISTIC_CONDITION = x.shape[0] * x.shape[1] >= 1024 if FP16_MATMUL_HEURISTIC_CONDITION: @@ -176,7 +177,7 @@ def forward( else: global user_has_been_warned if not user_has_been_warned: - warnings.warn("Using naive (slow) implementation." + msg) + log.warn("Using naive (slow) implementation." + msg) user_has_been_warned = True out = dequantize_gemm(qweight, qzeros, scales, w_bit, group_size) out = torch.matmul(x, out) @@ -194,10 +195,9 @@ def forward( def backward(ctx, grad_output): input, qweight, qzeros, scales, bias = ctx.saved_tensors - if awq_ext is None and not TRITON_AVAILABLE: + if awq_ext is None and not tritonv2.TRITON_AVAILABLE: raise ValueError( - "either triton or autoawq-kernels is needed to be installed to use `.backward()`. Make sure to install the auto-awq kernels" - " by following the installation guides in https://github.com/casper-hansen/AutoAWQ_kernels" + "Please install required triton via `pip install -U triton`" ) # Cast to correct dtype for mixed precision training From 95d7b8b0a52e1cc5f1403926713972aeedbe21e5 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Sun, 23 Nov 2025 05:01:16 +0000 Subject: [PATCH 10/24] simplify --- gptqmodel/nn_modules/qlinear/gemm_awq.py | 17 ++++------------- 1 file changed, 4 insertions(+), 13 deletions(-) diff --git a/gptqmodel/nn_modules/qlinear/gemm_awq.py b/gptqmodel/nn_modules/qlinear/gemm_awq.py index 371c51573..6829dd03d 100644 --- a/gptqmodel/nn_modules/qlinear/gemm_awq.py +++ b/gptqmodel/nn_modules/qlinear/gemm_awq.py @@ -3,6 +3,8 @@ # SPDX-License-Identifier: Apache-2.0 # Contact: qubitium@modelcloud.ai, x.com/qubitium +from contextlib import nullcontext + import torch from ...adapter.adapter import Adapter, Lora @@ -94,7 +96,8 @@ def forward(self, x: torch.Tensor): if input_dtype != torch.float16: x = x.half() - if self.training: + ctx = nullcontext() if self.training else torch.inference_mode() + with ctx: out = WQLinearMMFunction.apply( x, self.qweight, @@ -105,18 +108,6 @@ def forward(self, x: torch.Tensor): self.bias, self.out_features, ) - else: - with torch.inference_mode(): - out = WQLinearMMFunction.apply( - x, - self.qweight, - self.qzeros, - self.scales, - self.bits, - self.group_size, - self.bias, - self.out_features, - ) if input_dtype != torch.float16: out = out.to(dtype=input_dtype) From f7d2ad6a00ed589c0009db8911bc6f83e2f2a1a2 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Sun, 23 Nov 2025 09:43:49 +0000 Subject: [PATCH 11/24] separate triton kernel --- gptqmodel/nn_modules/qlinear/gemm_awq.py | 87 ++++---- .../nn_modules/qlinear/gemm_awq_triton.py | 186 ++++++++++++++++++ tests/test_awq_fp16_matmul_heuristic.py | 163 +++++++++++++++ 3 files changed, 391 insertions(+), 45 deletions(-) create mode 100644 gptqmodel/nn_modules/qlinear/gemm_awq_triton.py create mode 100644 tests/test_awq_fp16_matmul_heuristic.py diff --git a/gptqmodel/nn_modules/qlinear/gemm_awq.py b/gptqmodel/nn_modules/qlinear/gemm_awq.py index 6829dd03d..6fec7e11d 100644 --- a/gptqmodel/nn_modules/qlinear/gemm_awq.py +++ b/gptqmodel/nn_modules/qlinear/gemm_awq.py @@ -9,7 +9,7 @@ from ...adapter.adapter import Adapter, Lora from ...models._const import DEVICE, PLATFORM -from ...nn_modules.qlinear import AWQuantLinear, tritonv2 +from ...nn_modules.qlinear import AWQuantLinear from ...quantization.awq.utils.module import try_import from ...utils.backend import BACKEND from ...utils.logger import 
setup_logger @@ -21,6 +21,32 @@ user_has_been_warned = False +def cuda_backend_available() -> bool: + return awq_ext is not None + + +def cuda_forward(x: torch.Tensor, qweight: torch.Tensor, qzeros: torch.Tensor, scales: torch.Tensor, + group_size: int, bias: torch.Tensor, out_features: int) -> torch.Tensor: + FP16_MATMUL_HEURISTIC_CONDITION = x.shape[0] * x.shape[1] >= 1024 + + if FP16_MATMUL_HEURISTIC_CONDITION: + out = awq_ext.dequantize_weights_cuda(qweight, scales, qzeros, 0, 0, 0, False) + out = torch.matmul(x, out) + else: + out = awq_ext.gemm_forward_cuda( + x.reshape(-1, x.shape[-1]), qweight, scales, qzeros, 8 + ) + + return out + + +def cuda_dequantize_weights(qweight: torch.Tensor, qzeros: torch.Tensor, scales: torch.Tensor, + dtype: torch.dtype) -> torch.Tensor: + return awq_ext.dequantize_weights_cuda( + qweight, scales, qzeros, 1, 0, 0, False + ).to(dtype) + + class AwqGEMMQuantLinear(AWQuantLinear): SUPPORTS_BITS = [4] SUPPORTS_GROUP_SIZE = [-1, 16, 32, 64, 128] @@ -58,6 +84,9 @@ def __init__( register_buffers: bool = False, **kwargs, ): + if not cuda_backend_available(): + raise ValueError("CUDA AWQ extension not available; cannot build AwqGEMMQuantLinear") + super().__init__( bits=bits, group_size=group_size, @@ -107,6 +136,7 @@ def forward(self, x: torch.Tensor): self.group_size, self.bias, self.out_features, + "cuda", ) if input_dtype != torch.float16: @@ -131,6 +161,7 @@ def forward( group_size=128, bias=None, out_features=0, + prefer_backend=None, ): # The forward pass can use ctx. ctx.save_for_backward(x, qweight, qzeros, scales, bias) @@ -141,37 +172,10 @@ def forward( if x.shape[0] == 0: return torch.zeros(out_shape, dtype=x.dtype, device=x.device) - if awq_ext is not None: - FP16_MATMUL_HEURISTIC_CONDITION = x.shape[0] * x.shape[1] >= 1024 - - if FP16_MATMUL_HEURISTIC_CONDITION: - out = awq_ext.dequantize_weights_cuda( - qweight, scales, qzeros, 0, 0, 0, False - ) - out = torch.matmul(x, out) - else: - out = awq_ext.gemm_forward_cuda( - x.reshape(-1, x.shape[-1]), qweight, scales, qzeros, 8 - ) - - elif tritonv2.TRITON_AVAILABLE: - FP16_MATMUL_HEURISTIC_CONDITION = x.shape[0] * x.shape[1] >= 1024 - - if FP16_MATMUL_HEURISTIC_CONDITION: - out = awq_dequantize_triton(qweight, scales, qzeros) - out = torch.matmul(x, out.to(x.dtype)) - else: - out = awq_gemm_triton( - x.reshape(-1, x.shape[-1]), qweight, scales, qzeros, split_k_iters=8, - ) - - else: - global user_has_been_warned - if not user_has_been_warned: - log.warn("Using naive (slow) implementation." 
+ msg) - user_has_been_warned = True - out = dequantize_gemm(qweight, qzeros, scales, w_bit, group_size) - out = torch.matmul(x, out) + if not cuda_backend_available(): + raise ValueError("CUDA AWQ extension not available for WQLinearMMFunction") + + out = cuda_forward(x, qweight, qzeros, scales, group_size, bias, out_features) out = out + bias if bias is not None else out out = out.reshape(out_shape) @@ -186,20 +190,11 @@ def forward( def backward(ctx, grad_output): input, qweight, qzeros, scales, bias = ctx.saved_tensors - if awq_ext is None and not tritonv2.TRITON_AVAILABLE: - raise ValueError( - "Please install required triton via `pip install -U triton`" - ) + if not cuda_backend_available(): + raise ValueError("CUDA AWQ extension not available for WQLinearMMFunction") # Cast to correct dtype for mixed precision training - if awq_ext is not None: - weights = awq_ext.dequantize_weights_cuda( - qweight, scales, qzeros, 1, 0, 0, False - ).to(grad_output.dtype) - else: - weights = awq_dequantize_triton( - qweight, scales, qzeros - ).to(grad_output.dtype) + weights = cuda_dequantize_weights(qweight, qzeros, scales, grad_output.dtype) if ctx.needs_input_grad[0]: # 3D matmul using torch.bmm: https://pytorch.org/docs/stable/generated/torch.bmm.html#torch.bmm @@ -207,6 +202,8 @@ def backward(ctx, grad_output): batch_size = grad_output.shape[0] grad_input = grad_output.bmm(weights.transpose(0, 1).unsqueeze(0).repeat(batch_size, 1, 1)) - return grad_input, None, None, None, None, None, None, None + return grad_input, None, None, None, None, None, None, None, None + + __all__ = ["AwqGEMMQuantLinear", "WQLinearMMFunction"] diff --git a/gptqmodel/nn_modules/qlinear/gemm_awq_triton.py b/gptqmodel/nn_modules/qlinear/gemm_awq_triton.py new file mode 100644 index 000000000..82f115115 --- /dev/null +++ b/gptqmodel/nn_modules/qlinear/gemm_awq_triton.py @@ -0,0 +1,186 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +from contextlib import nullcontext + +import torch + +from ...adapter.adapter import Adapter, Lora +from ...models._const import DEVICE, PLATFORM +from ...nn_modules.qlinear import AWQuantLinear +from ...utils.backend import BACKEND +from . 
import tritonv2 +from ...quantization.awq.modules.triton.gemm import awq_dequantize_triton, awq_gemm_triton + + +def triton_backend_available() -> bool: + return tritonv2.TRITON_AVAILABLE + + +def triton_forward(x: torch.Tensor, qweight: torch.Tensor, qzeros: torch.Tensor, scales: torch.Tensor, + group_size: int, bias: torch.Tensor, out_features: int) -> torch.Tensor: + FP16_MATMUL_HEURISTIC_CONDITION = x.shape[0] * x.shape[1] >= 1024 + + if FP16_MATMUL_HEURISTIC_CONDITION: + out = awq_dequantize_triton(qweight, scales, qzeros) + out = torch.matmul(x, out.to(x.dtype)) + else: + out = awq_gemm_triton( + x.reshape(-1, x.shape[-1]), qweight, scales, qzeros, split_k_iters=8, + ) + + return out + + +def triton_dequantize_weights(qweight: torch.Tensor, qzeros: torch.Tensor, scales: torch.Tensor, + dtype: torch.dtype) -> torch.Tensor: + return awq_dequantize_triton(qweight, scales, qzeros).to(dtype) + + +class WQLinearMMTritonFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x, + qweight, + qzeros, + scales, + w_bit=4, + group_size=128, + bias=None, + out_features=0, + prefer_backend=None, + ): + if not triton_backend_available(): + raise ValueError(tritonv2.TRITON_INSTALL_HINT) + + ctx.save_for_backward(x, qweight, qzeros, scales, bias) + ctx.out_features = out_features + + out_shape = x.shape[:-1] + (out_features,) + x = x.to(torch.float16) + if x.shape[0] == 0: + return torch.zeros(out_shape, dtype=x.dtype, device=x.device) + + out = triton_forward(x, qweight, qzeros, scales, group_size, bias, out_features) + out = out + bias if bias is not None else out + out = out.reshape(out_shape) + if len(out.shape) == 2: + out = out.unsqueeze(0) + return out + + @staticmethod + def backward(ctx, grad_output): + input, qweight, qzeros, scales, bias = ctx.saved_tensors + if not triton_backend_available(): + raise ValueError(tritonv2.TRITON_INSTALL_HINT) + + weights = triton_dequantize_weights(qweight, qzeros, scales, grad_output.dtype) + + grad_input = None + if ctx.needs_input_grad[0]: + batch_size = grad_output.shape[0] + grad_input = grad_output.bmm(weights.transpose(0, 1).unsqueeze(0).repeat(batch_size, 1, 1)) + + return grad_input, None, None, None, None, None, None, None, None + + +class AwqGEMMTritonQuantLinear(AWQuantLinear): + SUPPORTS_BITS = [4] + SUPPORTS_GROUP_SIZE = [-1, 16, 32, 64, 128] + SUPPORTS_DESC_ACT = [True, False] + SUPPORTS_SYM = [True, False] + SUPPORTS_SHARDS = True + SUPPORTS_TRAINING = True + SUPPORTS_AUTO_PADDING = False + SUPPORTS_IN_FEATURES_DIVISIBLE_BY = [1] + SUPPORTS_OUT_FEATURES_DIVISIBLE_BY = [1] + + SUPPORTS_DEVICES = [DEVICE.ALL] + SUPPORTS_PLATFORM = [PLATFORM.ALL] + SUPPORTS_PACK_DTYPES = [torch.int32] + SUPPORTS_ADAPTERS = [Lora] + + SUPPORTS_DTYPES = [torch.float16] + + REQUIRES_FORMAT_V2 = False + + QUANT_TYPE = "awq_gemm_triton" + + def __init__( + self, + bits: int, + group_size: int, + sym: bool, + desc_act: bool, + in_features: int, + out_features: int, + bias: bool = False, + pack_dtype: torch.dtype = torch.int32, + adapter: Adapter = None, + register_buffers: bool = False, + **kwargs, + ): + if not triton_backend_available(): + raise ValueError(tritonv2.TRITON_INSTALL_HINT) + + super().__init__( + bits=bits, + group_size=group_size, + sym=sym, + desc_act=desc_act, + in_features=in_features, + out_features=out_features, + bias=bias, + pack_dtype=pack_dtype, + backend=kwargs.pop("backend", BACKEND.TRITON), + adapter=adapter, + register_buffers=register_buffers, + **kwargs) + + def post_init(self): + if self.scales is not None: + 
self.scales = self.scales.to(dtype=torch.float16) + super().post_init() + + def forward(self, x: torch.Tensor): + out_shape = x.shape[:-1] + (self.out_features,) + + input_dtype = x.dtype + if input_dtype != torch.float16: + x = x.half() + + ctx = nullcontext() if self.training else torch.inference_mode() + with ctx: + out = WQLinearMMTritonFunction.apply( + x, + self.qweight, + self.qzeros, + self.scales, + self.bits, + self.group_size, + self.bias, + self.out_features, + "triton", + ) + + if input_dtype != torch.float16: + out = out.to(dtype=input_dtype) + + if self.adapter: + out = self.adapter.apply(x=x, out=out) + + return out.reshape(out_shape) + + +__all__ = [ + "awq_dequantize_triton", + "awq_gemm_triton", + "triton_backend_available", + "triton_forward", + "triton_dequantize_weights", + "AwqGEMMTritonQuantLinear", + "WQLinearMMTritonFunction", +] diff --git a/tests/test_awq_fp16_matmul_heuristic.py b/tests/test_awq_fp16_matmul_heuristic.py new file mode 100644 index 000000000..bf1533d81 --- /dev/null +++ b/tests/test_awq_fp16_matmul_heuristic.py @@ -0,0 +1,163 @@ +import os +import time + +import pytest +import torch + +import gptqmodel.nn_modules.qlinear.gemm_awq as gemm_awq +import gptqmodel.nn_modules.qlinear.gemm_awq_triton as gemm_awq_triton + + +def _fake_quant_tensors(in_features: int = 32, out_features: int = 8, group_size: int = 32): + qweight = torch.ones((in_features, out_features // 8), dtype=torch.int32) + scales = torch.ones((in_features // group_size, out_features), dtype=torch.float16) + qzeros = torch.zeros((in_features // group_size, out_features // 8), dtype=torch.int32) + return qweight, scales, qzeros + + +def _patch_for_triton(monkeypatch, calls): + monkeypatch.setattr(gemm_awq, "awq_ext", None) + monkeypatch.setattr(gemm_awq, "cuda_backend_available", lambda: False) + + monkeypatch.setattr(gemm_awq_triton.tritonv2, "TRITON_AVAILABLE", True) + monkeypatch.setattr(gemm_awq_triton, "triton_backend_available", lambda: True) + monkeypatch.setattr(gemm_awq, "triton_backend_available", lambda: True, raising=False) + + def fake_dequant(qweight, scales, qzeros): + calls["dequant"] += 1 + return torch.ones(qweight.shape[0], qweight.shape[1] * 8, dtype=torch.float16) + + def fake_gemm(input, qweight, scales, qzeros, split_k_iters, **_): + calls["gemm"] += 1 + out_features = qweight.shape[1] * 8 + return torch.ones(input.shape[0], out_features, device=input.device, dtype=input.dtype) + + monkeypatch.setattr(gemm_awq_triton, "awq_dequantize_triton", fake_dequant, raising=False) + monkeypatch.setattr(gemm_awq_triton, "awq_gemm_triton", fake_gemm, raising=False) + + +def test_fp16_matmul_heuristic_prefers_dequant_for_large_matrices(monkeypatch): + calls = {"dequant": 0, "gemm": 0} + _patch_for_triton(monkeypatch, calls) + + group_size = 32 + out_features = 8 + qweight, scales, qzeros = _fake_quant_tensors(in_features=32, out_features=out_features, group_size=group_size) + + # Large batch x sequence activates the dequantize-then-matmul path. 
+ x = torch.ones((33, 32, qweight.shape[0]), dtype=torch.float16) + + out = gemm_awq_triton.WQLinearMMTritonFunction.apply( + x, qweight, qzeros, scales, 4, group_size, None, out_features, + ) + + assert calls == {"dequant": 1, "gemm": 0} + assert out.shape == (33, 32, out_features) + + +def test_fp16_matmul_heuristic_prefers_fused_gemm_for_small_matrices(monkeypatch): + calls = {"dequant": 0, "gemm": 0} + _patch_for_triton(monkeypatch, calls) + + group_size = 32 + out_features = 8 + qweight, scales, qzeros = _fake_quant_tensors(in_features=32, out_features=out_features, group_size=group_size) + + # Small batch x sequence stays on the fused GEMM kernel. + x = torch.ones((1, 1, qweight.shape[0]), dtype=torch.float16) + + out = gemm_awq_triton.WQLinearMMTritonFunction.apply( + x, qweight, qzeros, scales, 4, group_size, None, out_features, + ) + + assert calls == {"dequant": 0, "gemm": 1} + assert out.shape == (1, 1, out_features) + + +SEQ_LENS = [128, 256, 512, 1024, 1280, 1536, 2048, 4096, 8192] + +# Each entry: (case_name, batch, in_features, out_features) +BENCH_CASES = [ + # Llama 3.2-style shapes (hidden=4096) + ("llama3.2_qkv", 1, 4096, 4096 * 3), + ("llama3.2_up_proj", 1, 4096, 11008), + ("llama3.2_down_proj", 1, 11008, 4096), + # Qwen3-style shapes (hidden≈3584) + ("qwen3_qkv", 1, 3584, 3584 * 3), + ("qwen3_up_proj", 1, 3584, 14336), # 4x hidden for MLP expansion + ("qwen3_down_proj", 1, 14336, 3584), +] + + +@pytest.mark.parametrize( + ("case_name", "batch", "seq", "in_features", "out_features"), + [(case, batch, seq, inf, outf) for (case, batch, inf, outf) in BENCH_CASES for seq in SEQ_LENS], + ids=[f"{case}_s{seq}" for (case, _, _, _) in BENCH_CASES for seq in SEQ_LENS], +) +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA backend required for AWQ benchmark") +def test_fp16_matmul_heuristic_benchmark(case_name, batch, seq, in_features, out_features): + if os.getenv("RUN_AWQ_FP16_HEURISTIC_BENCH") != "1": + pytest.skip("Set RUN_AWQ_FP16_HEURISTIC_BENCH=1 to enable this benchmark") + + tabulate = pytest.importorskip("tabulate").tabulate + + backend = None + if gemm_awq.awq_ext is not None: + backend = "awq_ext" + elif gemm_awq.tritonv2.TRITON_AVAILABLE: + backend = "triton" + else: + pytest.skip("No AWQ backend available for benchmark") + + device = torch.device("cuda") + torch.manual_seed(0) + + group_size = 32 + + x = torch.randn((batch, seq, in_features), device=device, dtype=torch.float16) + qweight = torch.randint(0, 16, (in_features, out_features // 8), device=device, dtype=torch.int32) + scales = torch.randn((in_features // group_size, out_features), device=device, dtype=torch.float16) + qzeros = torch.zeros((in_features // group_size, out_features // 8), device=device, dtype=torch.int32) + + if backend == "triton": + from gptqmodel.quantization.awq.modules.triton.gemm import awq_dequantize_triton, awq_gemm_triton + + def run_dequant_matmul(): + with torch.inference_mode(): + if backend == "awq_ext": + weight = gemm_awq.awq_ext.dequantize_weights_cuda(qweight, scales, qzeros, 0, 0, 0, False) + else: + try: + weight = awq_dequantize_triton(qweight, scales, qzeros) + except AttributeError as err: + pytest.skip(f"Triton backend is incompatible: {err}") + return torch.matmul(x, weight.to(x.dtype)) + + def run_fused_gemm(): + with torch.inference_mode(): + x2d = x.reshape(-1, x.shape[-1]) + if backend == "awq_ext": + return gemm_awq.awq_ext.gemm_forward_cuda(x2d, qweight, scales, qzeros, 8) + try: + return awq_gemm_triton(x2d, qweight, scales, qzeros, 
split_k_iters=8) + except AttributeError as err: + pytest.skip(f"Triton backend is incompatible: {err}") + + def benchmark(fn, iters=3): + fn() + torch.cuda.synchronize() + start = time.perf_counter() + for _ in range(iters): + fn() + torch.cuda.synchronize() + return (time.perf_counter() - start) / iters * 1e3 + + dequant_ms = benchmark(run_dequant_matmul) + fused_ms = benchmark(run_fused_gemm) + + meets_condition = batch * seq >= 1024 + rows = [ + [case_name, batch, seq, meets_condition, f"{in_features}->{out_features}", "condition=True (dequant+matmul)", f"{dequant_ms:.3f} ms"], + [case_name, batch, seq, meets_condition, f"{in_features}->{out_features}", "condition=False (fused gemm)", f"{fused_ms:.3f} ms"], + ] + print(tabulate(rows, headers=["case", "batch", "seq", "meets >=1024", "matmul (in->out)", "path", "avg latency"])) From 3946e1f7b312d41bf57052a7906be01ba2919b54 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Sun, 23 Nov 2025 09:55:35 +0000 Subject: [PATCH 12/24] cleanup --- .../nn_modules/qlinear/gemm_awq_triton.py | 113 +++++++++--------- 1 file changed, 55 insertions(+), 58 deletions(-) diff --git a/gptqmodel/nn_modules/qlinear/gemm_awq_triton.py b/gptqmodel/nn_modules/qlinear/gemm_awq_triton.py index 82f115115..ecaba2394 100644 --- a/gptqmodel/nn_modules/qlinear/gemm_awq_triton.py +++ b/gptqmodel/nn_modules/qlinear/gemm_awq_triton.py @@ -15,10 +15,6 @@ from ...quantization.awq.modules.triton.gemm import awq_dequantize_triton, awq_gemm_triton -def triton_backend_available() -> bool: - return tritonv2.TRITON_AVAILABLE - - def triton_forward(x: torch.Tensor, qweight: torch.Tensor, qzeros: torch.Tensor, scales: torch.Tensor, group_size: int, bias: torch.Tensor, out_features: int) -> torch.Tensor: FP16_MATMUL_HEURISTIC_CONDITION = x.shape[0] * x.shape[1] >= 1024 @@ -39,54 +35,6 @@ def triton_dequantize_weights(qweight: torch.Tensor, qzeros: torch.Tensor, scale return awq_dequantize_triton(qweight, scales, qzeros).to(dtype) -class WQLinearMMTritonFunction(torch.autograd.Function): - @staticmethod - def forward( - ctx, - x, - qweight, - qzeros, - scales, - w_bit=4, - group_size=128, - bias=None, - out_features=0, - prefer_backend=None, - ): - if not triton_backend_available(): - raise ValueError(tritonv2.TRITON_INSTALL_HINT) - - ctx.save_for_backward(x, qweight, qzeros, scales, bias) - ctx.out_features = out_features - - out_shape = x.shape[:-1] + (out_features,) - x = x.to(torch.float16) - if x.shape[0] == 0: - return torch.zeros(out_shape, dtype=x.dtype, device=x.device) - - out = triton_forward(x, qweight, qzeros, scales, group_size, bias, out_features) - out = out + bias if bias is not None else out - out = out.reshape(out_shape) - if len(out.shape) == 2: - out = out.unsqueeze(0) - return out - - @staticmethod - def backward(ctx, grad_output): - input, qweight, qzeros, scales, bias = ctx.saved_tensors - if not triton_backend_available(): - raise ValueError(tritonv2.TRITON_INSTALL_HINT) - - weights = triton_dequantize_weights(qweight, qzeros, scales, grad_output.dtype) - - grad_input = None - if ctx.needs_input_grad[0]: - batch_size = grad_output.shape[0] - grad_input = grad_output.bmm(weights.transpose(0, 1).unsqueeze(0).repeat(batch_size, 1, 1)) - - return grad_input, None, None, None, None, None, None, None, None - - class AwqGEMMTritonQuantLinear(AWQuantLinear): SUPPORTS_BITS = [4] SUPPORTS_GROUP_SIZE = [-1, 16, 32, 64, 128] @@ -109,6 +57,60 @@ class AwqGEMMTritonQuantLinear(AWQuantLinear): QUANT_TYPE = "awq_gemm_triton" + class _Function(torch.autograd.Function): + 
@staticmethod + def forward( + ctx, + x, + qweight, + qzeros, + scales, + w_bit=4, + group_size=128, + bias=None, + out_features=0, + prefer_backend=None, + ): + if not tritonv2.TRITON_AVAILABLE: + raise ValueError(tritonv2.TRITON_INSTALL_HINT) + + ctx.save_for_backward(x, qweight, qzeros, scales, bias) + ctx.out_features = out_features + + out_shape = x.shape[:-1] + (out_features,) + x = x.to(torch.float16) + if x.shape[0] == 0: + return torch.zeros(out_shape, dtype=x.dtype, device=x.device) + + out = triton_forward(x, qweight, qzeros, scales, group_size, bias, out_features) + out = out + bias if bias is not None else out + out = out.reshape(out_shape) + if len(out.shape) == 2: + out = out.unsqueeze(0) + return out + + @staticmethod + def backward(ctx, grad_output): + input, qweight, qzeros, scales, bias = ctx.saved_tensors + if not tritonv2.TRITON_AVAILABLE: + raise ValueError(tritonv2.TRITON_INSTALL_HINT) + + weights = triton_dequantize_weights(qweight, qzeros, scales, grad_output.dtype) + + grad_input = None + if ctx.needs_input_grad[0]: + batch_size = grad_output.shape[0] + grad_input = grad_output.bmm(weights.transpose(0, 1).unsqueeze(0).repeat(batch_size, 1, 1)) + + return grad_input, None, None, None, None, None, None, None, None + + @classmethod + def validate(cls, **args): + if not tritonv2.TRITON_AVAILABLE: + return False, ValueError(tritonv2.TRITON_INSTALL_HINT) + + return cls._validate(**args) + def __init__( self, bits: int, @@ -123,9 +125,6 @@ def __init__( register_buffers: bool = False, **kwargs, ): - if not triton_backend_available(): - raise ValueError(tritonv2.TRITON_INSTALL_HINT) - super().__init__( bits=bits, group_size=group_size, @@ -154,7 +153,7 @@ def forward(self, x: torch.Tensor): ctx = nullcontext() if self.training else torch.inference_mode() with ctx: - out = WQLinearMMTritonFunction.apply( + out = self._Function.apply( x, self.qweight, self.qzeros, @@ -178,9 +177,7 @@ def forward(self, x: torch.Tensor): __all__ = [ "awq_dequantize_triton", "awq_gemm_triton", - "triton_backend_available", "triton_forward", "triton_dequantize_weights", "AwqGEMMTritonQuantLinear", - "WQLinearMMTritonFunction", ] From d1cbe25f9c3545d90969c21b213de006cba5b6b7 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Sun, 23 Nov 2025 09:58:22 +0000 Subject: [PATCH 13/24] cleanup --- gptqmodel/nn_modules/qlinear/gemm_awq.py | 156 ++++++++++------------- tests/test_awq_fp16_matmul_heuristic.py | 7 +- 2 files changed, 71 insertions(+), 92 deletions(-) diff --git a/gptqmodel/nn_modules/qlinear/gemm_awq.py b/gptqmodel/nn_modules/qlinear/gemm_awq.py index 6fec7e11d..69b7cf72b 100644 --- a/gptqmodel/nn_modules/qlinear/gemm_awq.py +++ b/gptqmodel/nn_modules/qlinear/gemm_awq.py @@ -21,32 +21,6 @@ user_has_been_warned = False -def cuda_backend_available() -> bool: - return awq_ext is not None - - -def cuda_forward(x: torch.Tensor, qweight: torch.Tensor, qzeros: torch.Tensor, scales: torch.Tensor, - group_size: int, bias: torch.Tensor, out_features: int) -> torch.Tensor: - FP16_MATMUL_HEURISTIC_CONDITION = x.shape[0] * x.shape[1] >= 1024 - - if FP16_MATMUL_HEURISTIC_CONDITION: - out = awq_ext.dequantize_weights_cuda(qweight, scales, qzeros, 0, 0, 0, False) - out = torch.matmul(x, out) - else: - out = awq_ext.gemm_forward_cuda( - x.reshape(-1, x.shape[-1]), qweight, scales, qzeros, 8 - ) - - return out - - -def cuda_dequantize_weights(qweight: torch.Tensor, qzeros: torch.Tensor, scales: torch.Tensor, - dtype: torch.dtype) -> torch.Tensor: - return awq_ext.dequantize_weights_cuda( - qweight, scales, 
qzeros, 1, 0, 0, False - ).to(dtype) - - class AwqGEMMQuantLinear(AWQuantLinear): SUPPORTS_BITS = [4] SUPPORTS_GROUP_SIZE = [-1, 16, 32, 64, 128] @@ -70,6 +44,73 @@ class AwqGEMMQuantLinear(AWQuantLinear): # for transformers/optimum tests compat QUANT_TYPE = "awq_gemm" + class _Function(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x, + qweight, + qzeros, + scales, + w_bit=4, + group_size=128, + bias=None, + out_features=0, + prefer_backend=None, + ): + if awq_ext is None: + raise ValueError(msg or "CUDA AWQ extension not available for AwqGEMMQuantLinear") + + ctx.save_for_backward(x, qweight, qzeros, scales, bias) + ctx.out_features = out_features + + out_shape = x.shape[:-1] + (out_features,) + x = x.to(torch.float16) + if x.shape[0] == 0: + return torch.zeros(out_shape, dtype=x.dtype, device=x.device) + + FP16_MATMUL_HEURISTIC_CONDITION = x.shape[0] * x.shape[1] >= 1024 + if FP16_MATMUL_HEURISTIC_CONDITION: + out = awq_ext.dequantize_weights_cuda(qweight, scales, qzeros, 0, 0, 0, False) + out = torch.matmul(x, out) + else: + out = awq_ext.gemm_forward_cuda( + x.reshape(-1, x.shape[-1]), qweight, scales, qzeros, 8 + ) + + out = out + bias if bias is not None else out + out = out.reshape(out_shape) + + if len(out.shape) == 2: + out = out.unsqueeze(0) + + return out + + @staticmethod + def backward(ctx, grad_output): + input, qweight, qzeros, scales, bias = ctx.saved_tensors + + if awq_ext is None: + raise ValueError(msg or "CUDA AWQ extension not available for AwqGEMMQuantLinear") + + weights = awq_ext.dequantize_weights_cuda( + qweight, scales, qzeros, 1, 0, 0, False + ).to(grad_output.dtype) + + grad_input = None + if ctx.needs_input_grad[0]: + batch_size = grad_output.shape[0] + grad_input = grad_output.bmm(weights.transpose(0, 1).unsqueeze(0).repeat(batch_size, 1, 1)) + + return grad_input, None, None, None, None, None, None, None, None + + @classmethod + def validate(cls, **args): + if awq_ext is None: + return False, ValueError(msg or "CUDA AWQ extension not available; cannot select AwqGEMMQuantLinear") + + return cls._validate(**args) + def __init__( self, bits: int, @@ -84,8 +125,6 @@ def __init__( register_buffers: bool = False, **kwargs, ): - if not cuda_backend_available(): - raise ValueError("CUDA AWQ extension not available; cannot build AwqGEMMQuantLinear") super().__init__( bits=bits, @@ -127,7 +166,7 @@ def forward(self, x: torch.Tensor): ctx = nullcontext() if self.training else torch.inference_mode() with ctx: - out = WQLinearMMFunction.apply( + out = self._Function.apply( x, self.qweight, self.qzeros, @@ -147,63 +186,6 @@ def forward(self, x: torch.Tensor): return out.reshape(out_shape) -# Adapted from https://github.com/compressa-ai/AutoAWQ/tree/dev -class WQLinearMMFunction(torch.autograd.Function): - @staticmethod - # ctx is the first argument to forward - def forward( - ctx, - x, - qweight, - qzeros, - scales, - w_bit=4, - group_size=128, - bias=None, - out_features=0, - prefer_backend=None, - ): - # The forward pass can use ctx. 
- ctx.save_for_backward(x, qweight, qzeros, scales, bias) - ctx.out_features = out_features - - out_shape = x.shape[:-1] + (out_features,) - x = x.to(torch.float16) - if x.shape[0] == 0: - return torch.zeros(out_shape, dtype=x.dtype, device=x.device) - - if not cuda_backend_available(): - raise ValueError("CUDA AWQ extension not available for WQLinearMMFunction") - - out = cuda_forward(x, qweight, qzeros, scales, group_size, bias, out_features) - - out = out + bias if bias is not None else out - out = out.reshape(out_shape) - - # always want 3D tensor if tensor is 2D - if len(out.shape) == 2: - out = out.unsqueeze(0) - - return out - - @staticmethod - def backward(ctx, grad_output): - input, qweight, qzeros, scales, bias = ctx.saved_tensors - - if not cuda_backend_available(): - raise ValueError("CUDA AWQ extension not available for WQLinearMMFunction") - - # Cast to correct dtype for mixed precision training - weights = cuda_dequantize_weights(qweight, qzeros, scales, grad_output.dtype) - - if ctx.needs_input_grad[0]: - # 3D matmul using torch.bmm: https://pytorch.org/docs/stable/generated/torch.bmm.html#torch.bmm - # to propagate gradient across all batch sizes. - batch_size = grad_output.shape[0] - grad_input = grad_output.bmm(weights.transpose(0, 1).unsqueeze(0).repeat(batch_size, 1, 1)) - - return grad_input, None, None, None, None, None, None, None, None - -__all__ = ["AwqGEMMQuantLinear", "WQLinearMMFunction"] +__all__ = ["AwqGEMMQuantLinear"] diff --git a/tests/test_awq_fp16_matmul_heuristic.py b/tests/test_awq_fp16_matmul_heuristic.py index bf1533d81..89a6b0e14 100644 --- a/tests/test_awq_fp16_matmul_heuristic.py +++ b/tests/test_awq_fp16_matmul_heuristic.py @@ -17,11 +17,8 @@ def _fake_quant_tensors(in_features: int = 32, out_features: int = 8, group_size def _patch_for_triton(monkeypatch, calls): monkeypatch.setattr(gemm_awq, "awq_ext", None) - monkeypatch.setattr(gemm_awq, "cuda_backend_available", lambda: False) monkeypatch.setattr(gemm_awq_triton.tritonv2, "TRITON_AVAILABLE", True) - monkeypatch.setattr(gemm_awq_triton, "triton_backend_available", lambda: True) - monkeypatch.setattr(gemm_awq, "triton_backend_available", lambda: True, raising=False) def fake_dequant(qweight, scales, qzeros): calls["dequant"] += 1 @@ -47,7 +44,7 @@ def test_fp16_matmul_heuristic_prefers_dequant_for_large_matrices(monkeypatch): # Large batch x sequence activates the dequantize-then-matmul path. x = torch.ones((33, 32, qweight.shape[0]), dtype=torch.float16) - out = gemm_awq_triton.WQLinearMMTritonFunction.apply( + out = gemm_awq_triton.AwqGEMMTritonQuantLinear._Function.apply( x, qweight, qzeros, scales, 4, group_size, None, out_features, ) @@ -66,7 +63,7 @@ def test_fp16_matmul_heuristic_prefers_fused_gemm_for_small_matrices(monkeypatch # Small batch x sequence stays on the fused GEMM kernel. 
x = torch.ones((1, 1, qweight.shape[0]), dtype=torch.float16) - out = gemm_awq_triton.WQLinearMMTritonFunction.apply( + out = gemm_awq_triton.AwqGEMMTritonQuantLinear._Function.apply( x, qweight, qzeros, scales, 4, group_size, None, out_features, ) From f5d3c59f1cb83f2f3a71d0608b1f398fb41c25d7 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Sun, 23 Nov 2025 12:57:53 +0000 Subject: [PATCH 14/24] refractor --- gptqmodel/nn_modules/qlinear/gemm_awq.py | 128 +++++++++--------- .../nn_modules/qlinear/gemm_awq_triton.py | 118 +++++++--------- 2 files changed, 118 insertions(+), 128 deletions(-) diff --git a/gptqmodel/nn_modules/qlinear/gemm_awq.py b/gptqmodel/nn_modules/qlinear/gemm_awq.py index 69b7cf72b..5d5e1fef9 100644 --- a/gptqmodel/nn_modules/qlinear/gemm_awq.py +++ b/gptqmodel/nn_modules/qlinear/gemm_awq.py @@ -21,6 +21,67 @@ user_has_been_warned = False +class AwqGemmFn(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x, + qweight, + qzeros, + scales, + w_bit=4, + group_size=128, + bias=None, + out_features=0, + prefer_backend=None, + ): + if awq_ext is None: + raise ValueError(msg or "CUDA AWQ extension not available for AwqGEMMQuantLinear") + + ctx.save_for_backward(x, qweight, qzeros, scales, bias) + ctx.out_features = out_features + + out_shape = x.shape[:-1] + (out_features,) + x = x.to(torch.float16) + if x.shape[0] == 0: + return torch.zeros(out_shape, dtype=x.dtype, device=x.device) + + FP16_MATMUL_HEURISTIC_CONDITION = x.shape[0] * x.shape[1] >= 1024 + if FP16_MATMUL_HEURISTIC_CONDITION: + out = awq_ext.dequantize_weights_cuda(qweight, scales, qzeros, 0, 0, 0, False) + out = torch.matmul(x, out) + else: + out = awq_ext.gemm_forward_cuda( + x.reshape(-1, x.shape[-1]), qweight, scales, qzeros, 8 + ) + + out = out + bias if bias is not None else out + out = out.reshape(out_shape) + + if len(out.shape) == 2: + out = out.unsqueeze(0) + + return out + + @staticmethod + def backward(ctx, grad_output): + input, qweight, qzeros, scales, bias = ctx.saved_tensors + + if awq_ext is None: + raise ValueError(msg or "CUDA AWQ extension not available for AwqGEMMQuantLinear") + + weights = awq_ext.dequantize_weights_cuda( + qweight, scales, qzeros, 1, 0, 0, False + ).to(grad_output.dtype) + + grad_input = None + if ctx.needs_input_grad[0]: + batch_size = grad_output.shape[0] + grad_input = grad_output.bmm(weights.transpose(0, 1).unsqueeze(0).repeat(batch_size, 1, 1)) + + return grad_input, None, None, None, None, None, None, None, None + + class AwqGEMMQuantLinear(AWQuantLinear): SUPPORTS_BITS = [4] SUPPORTS_GROUP_SIZE = [-1, 16, 32, 64, 128] @@ -44,66 +105,6 @@ class AwqGEMMQuantLinear(AWQuantLinear): # for transformers/optimum tests compat QUANT_TYPE = "awq_gemm" - class _Function(torch.autograd.Function): - @staticmethod - def forward( - ctx, - x, - qweight, - qzeros, - scales, - w_bit=4, - group_size=128, - bias=None, - out_features=0, - prefer_backend=None, - ): - if awq_ext is None: - raise ValueError(msg or "CUDA AWQ extension not available for AwqGEMMQuantLinear") - - ctx.save_for_backward(x, qweight, qzeros, scales, bias) - ctx.out_features = out_features - - out_shape = x.shape[:-1] + (out_features,) - x = x.to(torch.float16) - if x.shape[0] == 0: - return torch.zeros(out_shape, dtype=x.dtype, device=x.device) - - FP16_MATMUL_HEURISTIC_CONDITION = x.shape[0] * x.shape[1] >= 1024 - if FP16_MATMUL_HEURISTIC_CONDITION: - out = awq_ext.dequantize_weights_cuda(qweight, scales, qzeros, 0, 0, 0, False) - out = torch.matmul(x, out) - else: - out = awq_ext.gemm_forward_cuda( - 
x.reshape(-1, x.shape[-1]), qweight, scales, qzeros, 8 - ) - - out = out + bias if bias is not None else out - out = out.reshape(out_shape) - - if len(out.shape) == 2: - out = out.unsqueeze(0) - - return out - - @staticmethod - def backward(ctx, grad_output): - input, qweight, qzeros, scales, bias = ctx.saved_tensors - - if awq_ext is None: - raise ValueError(msg or "CUDA AWQ extension not available for AwqGEMMQuantLinear") - - weights = awq_ext.dequantize_weights_cuda( - qweight, scales, qzeros, 1, 0, 0, False - ).to(grad_output.dtype) - - grad_input = None - if ctx.needs_input_grad[0]: - batch_size = grad_output.shape[0] - grad_input = grad_output.bmm(weights.transpose(0, 1).unsqueeze(0).repeat(batch_size, 1, 1)) - - return grad_input, None, None, None, None, None, None, None, None - @classmethod def validate(cls, **args): if awq_ext is None: @@ -166,7 +167,7 @@ def forward(self, x: torch.Tensor): ctx = nullcontext() if self.training else torch.inference_mode() with ctx: - out = self._Function.apply( + out = AwqGemmFn.apply( x, self.qweight, self.qzeros, @@ -188,4 +189,7 @@ def forward(self, x: torch.Tensor): -__all__ = ["AwqGEMMQuantLinear"] +__all__ = [ + "AwqGemmFn", + "AwqGEMMQuantLinear", +] diff --git a/gptqmodel/nn_modules/qlinear/gemm_awq_triton.py b/gptqmodel/nn_modules/qlinear/gemm_awq_triton.py index ecaba2394..1af38896f 100644 --- a/gptqmodel/nn_modules/qlinear/gemm_awq_triton.py +++ b/gptqmodel/nn_modules/qlinear/gemm_awq_triton.py @@ -15,24 +15,60 @@ from ...quantization.awq.modules.triton.gemm import awq_dequantize_triton, awq_gemm_triton -def triton_forward(x: torch.Tensor, qweight: torch.Tensor, qzeros: torch.Tensor, scales: torch.Tensor, - group_size: int, bias: torch.Tensor, out_features: int) -> torch.Tensor: - FP16_MATMUL_HEURISTIC_CONDITION = x.shape[0] * x.shape[1] >= 1024 +class AwqGemmTritonFn(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x, + qweight, + qzeros, + scales, + w_bit=4, + group_size=128, + bias=None, + out_features=0, + prefer_backend=None, + ): + if not tritonv2.TRITON_AVAILABLE: + raise ValueError(tritonv2.TRITON_INSTALL_HINT) + + ctx.save_for_backward(x, qweight, qzeros, scales, bias) + ctx.out_features = out_features + + out_shape = x.shape[:-1] + (out_features,) + x = x.to(torch.float16) + if x.shape[0] == 0: + return torch.zeros(out_shape, dtype=x.dtype, device=x.device) + + FP16_MATMUL_HEURISTIC_CONDITION = x.shape[0] * x.shape[1] >= 1024 + if FP16_MATMUL_HEURISTIC_CONDITION: + out = awq_dequantize_triton(qweight, scales, qzeros) + out = torch.matmul(x, out.to(x.dtype)) + else: + out = awq_gemm_triton( + x.reshape(-1, x.shape[-1]), qweight, scales, qzeros, split_k_iters=8, + ) - if FP16_MATMUL_HEURISTIC_CONDITION: - out = awq_dequantize_triton(qweight, scales, qzeros) - out = torch.matmul(x, out.to(x.dtype)) - else: - out = awq_gemm_triton( - x.reshape(-1, x.shape[-1]), qweight, scales, qzeros, split_k_iters=8, - ) + out = out + bias if bias is not None else out + out = out.reshape(out_shape) + if len(out.shape) == 2: + out = out.unsqueeze(0) + return out - return out + @staticmethod + def backward(ctx, grad_output): + input, qweight, qzeros, scales, bias = ctx.saved_tensors + if not tritonv2.TRITON_AVAILABLE: + raise ValueError(tritonv2.TRITON_INSTALL_HINT) + + weights = awq_dequantize_triton(qweight, scales, qzeros).to(grad_output.dtype) + grad_input = None + if ctx.needs_input_grad[0]: + batch_size = grad_output.shape[0] + grad_input = grad_output.bmm(weights.transpose(0, 1).unsqueeze(0).repeat(batch_size, 1, 1)) -def 
triton_dequantize_weights(qweight: torch.Tensor, qzeros: torch.Tensor, scales: torch.Tensor, - dtype: torch.dtype) -> torch.Tensor: - return awq_dequantize_triton(qweight, scales, qzeros).to(dtype) + return grad_input, None, None, None, None, None, None, None, None class AwqGEMMTritonQuantLinear(AWQuantLinear): @@ -57,53 +93,6 @@ class AwqGEMMTritonQuantLinear(AWQuantLinear): QUANT_TYPE = "awq_gemm_triton" - class _Function(torch.autograd.Function): - @staticmethod - def forward( - ctx, - x, - qweight, - qzeros, - scales, - w_bit=4, - group_size=128, - bias=None, - out_features=0, - prefer_backend=None, - ): - if not tritonv2.TRITON_AVAILABLE: - raise ValueError(tritonv2.TRITON_INSTALL_HINT) - - ctx.save_for_backward(x, qweight, qzeros, scales, bias) - ctx.out_features = out_features - - out_shape = x.shape[:-1] + (out_features,) - x = x.to(torch.float16) - if x.shape[0] == 0: - return torch.zeros(out_shape, dtype=x.dtype, device=x.device) - - out = triton_forward(x, qweight, qzeros, scales, group_size, bias, out_features) - out = out + bias if bias is not None else out - out = out.reshape(out_shape) - if len(out.shape) == 2: - out = out.unsqueeze(0) - return out - - @staticmethod - def backward(ctx, grad_output): - input, qweight, qzeros, scales, bias = ctx.saved_tensors - if not tritonv2.TRITON_AVAILABLE: - raise ValueError(tritonv2.TRITON_INSTALL_HINT) - - weights = triton_dequantize_weights(qweight, qzeros, scales, grad_output.dtype) - - grad_input = None - if ctx.needs_input_grad[0]: - batch_size = grad_output.shape[0] - grad_input = grad_output.bmm(weights.transpose(0, 1).unsqueeze(0).repeat(batch_size, 1, 1)) - - return grad_input, None, None, None, None, None, None, None, None - @classmethod def validate(cls, **args): if not tritonv2.TRITON_AVAILABLE: @@ -153,7 +142,7 @@ def forward(self, x: torch.Tensor): ctx = nullcontext() if self.training else torch.inference_mode() with ctx: - out = self._Function.apply( + out = AwqGemmTritonFn.apply( x, self.qweight, self.qzeros, @@ -175,9 +164,6 @@ def forward(self, x: torch.Tensor): __all__ = [ - "awq_dequantize_triton", - "awq_gemm_triton", - "triton_forward", - "triton_dequantize_weights", + "AwqGemmTritonFn", "AwqGEMMTritonQuantLinear", ] From ca3a8b8d54346d5fe094f21b7d9101783e780bf1 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Sun, 23 Nov 2025 13:01:34 +0000 Subject: [PATCH 15/24] refractor --- gptqmodel/nn_modules/qlinear/gemm_awq.py | 5 +++-- gptqmodel/nn_modules/qlinear/gemm_awq_triton.py | 5 +++-- tests/test_awq_fp16_matmul_heuristic.py | 4 ++-- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/gptqmodel/nn_modules/qlinear/gemm_awq.py b/gptqmodel/nn_modules/qlinear/gemm_awq.py index 5d5e1fef9..06e2e8e1b 100644 --- a/gptqmodel/nn_modules/qlinear/gemm_awq.py +++ b/gptqmodel/nn_modules/qlinear/gemm_awq.py @@ -46,8 +46,9 @@ def forward( if x.shape[0] == 0: return torch.zeros(out_shape, dtype=x.dtype, device=x.device) - FP16_MATMUL_HEURISTIC_CONDITION = x.shape[0] * x.shape[1] >= 1024 - if FP16_MATMUL_HEURISTIC_CONDITION: + # Above compute density threshold it is faster to just dequantize the whole thing and do simple matmul + FULL_DEQUANT_MATMUL_THRESHOLD = x.shape[0] * x.shape[1] > 1024 + if FULL_DEQUANT_MATMUL_THRESHOLD: out = awq_ext.dequantize_weights_cuda(qweight, scales, qzeros, 0, 0, 0, False) out = torch.matmul(x, out) else: diff --git a/gptqmodel/nn_modules/qlinear/gemm_awq_triton.py b/gptqmodel/nn_modules/qlinear/gemm_awq_triton.py index 1af38896f..a09409b23 100644 --- 
a/gptqmodel/nn_modules/qlinear/gemm_awq_triton.py +++ b/gptqmodel/nn_modules/qlinear/gemm_awq_triton.py @@ -40,8 +40,9 @@ def forward( if x.shape[0] == 0: return torch.zeros(out_shape, dtype=x.dtype, device=x.device) - FP16_MATMUL_HEURISTIC_CONDITION = x.shape[0] * x.shape[1] >= 1024 - if FP16_MATMUL_HEURISTIC_CONDITION: + # Above compute density threshold it is faster to just dequantize the whole thing and do simple matmul + FULL_DEQUANT_MATMUL_THRESHOLD = x.shape[0] * x.shape[1] > 1024 + if FULL_DEQUANT_MATMUL_THRESHOLD: out = awq_dequantize_triton(qweight, scales, qzeros) out = torch.matmul(x, out.to(x.dtype)) else: diff --git a/tests/test_awq_fp16_matmul_heuristic.py b/tests/test_awq_fp16_matmul_heuristic.py index 89a6b0e14..a8b402b7b 100644 --- a/tests/test_awq_fp16_matmul_heuristic.py +++ b/tests/test_awq_fp16_matmul_heuristic.py @@ -44,7 +44,7 @@ def test_fp16_matmul_heuristic_prefers_dequant_for_large_matrices(monkeypatch): # Large batch x sequence activates the dequantize-then-matmul path. x = torch.ones((33, 32, qweight.shape[0]), dtype=torch.float16) - out = gemm_awq_triton.AwqGEMMTritonQuantLinear._Function.apply( + out = gemm_awq_triton.AwqGemmTritonFn.apply( x, qweight, qzeros, scales, 4, group_size, None, out_features, ) @@ -63,7 +63,7 @@ def test_fp16_matmul_heuristic_prefers_fused_gemm_for_small_matrices(monkeypatch # Small batch x sequence stays on the fused GEMM kernel. x = torch.ones((1, 1, qweight.shape[0]), dtype=torch.float16) - out = gemm_awq_triton.AwqGEMMTritonQuantLinear._Function.apply( + out = gemm_awq_triton.AwqGemmTritonFn.apply( x, qweight, qzeros, scales, 4, group_size, None, out_features, ) From 5aacbe6b1a58cf4597d9838783aea87a17eb5a4b Mon Sep 17 00:00:00 2001 From: Qubitium Date: Sun, 23 Nov 2025 13:22:29 +0000 Subject: [PATCH 16/24] update test --- tests/test_awq_fp16_matmul_heuristic.py | 90 +++++++++++++++++-------- 1 file changed, 62 insertions(+), 28 deletions(-) diff --git a/tests/test_awq_fp16_matmul_heuristic.py b/tests/test_awq_fp16_matmul_heuristic.py index a8b402b7b..3ef2d6e74 100644 --- a/tests/test_awq_fp16_matmul_heuristic.py +++ b/tests/test_awq_fp16_matmul_heuristic.py @@ -15,27 +15,46 @@ def _fake_quant_tensors(in_features: int = 32, out_features: int = 8, group_size return qweight, scales, qzeros -def _patch_for_triton(monkeypatch, calls): - monkeypatch.setattr(gemm_awq, "awq_ext", None) +def _patch_backend(monkeypatch, backend: str, calls): + if backend == "triton": + monkeypatch.setattr(gemm_awq, "awq_ext", None) + + monkeypatch.setattr(gemm_awq_triton.tritonv2, "TRITON_AVAILABLE", True) + + def fake_dequant(qweight, scales, qzeros): + calls["dequant"] += 1 + return torch.ones(qweight.shape[0], qweight.shape[1] * 8, dtype=torch.float16) + + def fake_gemm(input, qweight, scales, qzeros, split_k_iters, **_): + calls["gemm"] += 1 + out_features = qweight.shape[1] * 8 + return torch.ones(input.shape[0], out_features, device=input.device, dtype=input.dtype) + + monkeypatch.setattr(gemm_awq_triton, "awq_dequantize_triton", fake_dequant, raising=False) + monkeypatch.setattr(gemm_awq_triton, "awq_gemm_triton", fake_gemm, raising=False) - monkeypatch.setattr(gemm_awq_triton.tritonv2, "TRITON_AVAILABLE", True) + return gemm_awq_triton.AwqGemmTritonFn - def fake_dequant(qweight, scales, qzeros): - calls["dequant"] += 1 - return torch.ones(qweight.shape[0], qweight.shape[1] * 8, dtype=torch.float16) + # Stub the compiled AWQ extension so we can count which path is taken. 
+ class FakeAwqExt: + def dequantize_weights_cuda(self, qweight, scales, qzeros, *_args): + calls["dequant"] += 1 + return torch.ones(qweight.shape[0], qweight.shape[1] * 8, dtype=torch.float16) - def fake_gemm(input, qweight, scales, qzeros, split_k_iters, **_): - calls["gemm"] += 1 - out_features = qweight.shape[1] * 8 - return torch.ones(input.shape[0], out_features, device=input.device, dtype=input.dtype) + def gemm_forward_cuda(self, input, qweight, scales, qzeros, _split_k_iters): + calls["gemm"] += 1 + out_features = qweight.shape[1] * 8 + return torch.ones(input.shape[0], out_features, device=input.device, dtype=input.dtype) - monkeypatch.setattr(gemm_awq_triton, "awq_dequantize_triton", fake_dequant, raising=False) - monkeypatch.setattr(gemm_awq_triton, "awq_gemm_triton", fake_gemm, raising=False) + monkeypatch.setattr(gemm_awq, "awq_ext", FakeAwqExt()) + monkeypatch.setattr(gemm_awq_triton.tritonv2, "TRITON_AVAILABLE", False) + return gemm_awq.AwqGemmFn -def test_fp16_matmul_heuristic_prefers_dequant_for_large_matrices(monkeypatch): +@pytest.mark.parametrize("backend", ["triton", "ext"], ids=["triton", "awq_ext"]) +def test_fp16_matmul_heuristic_prefers_dequant_for_large_matrices(monkeypatch, backend): calls = {"dequant": 0, "gemm": 0} - _patch_for_triton(monkeypatch, calls) + fn = _patch_backend(monkeypatch, backend, calls) group_size = 32 out_features = 8 @@ -44,7 +63,7 @@ def test_fp16_matmul_heuristic_prefers_dequant_for_large_matrices(monkeypatch): # Large batch x sequence activates the dequantize-then-matmul path. x = torch.ones((33, 32, qweight.shape[0]), dtype=torch.float16) - out = gemm_awq_triton.AwqGemmTritonFn.apply( + out = fn.apply( x, qweight, qzeros, scales, 4, group_size, None, out_features, ) @@ -52,9 +71,10 @@ def test_fp16_matmul_heuristic_prefers_dequant_for_large_matrices(monkeypatch): assert out.shape == (33, 32, out_features) -def test_fp16_matmul_heuristic_prefers_fused_gemm_for_small_matrices(monkeypatch): +@pytest.mark.parametrize("backend", ["triton", "ext"], ids=["triton", "awq_ext"]) +def test_fp16_matmul_heuristic_prefers_fused_gemm_for_small_matrices(monkeypatch, backend): calls = {"dequant": 0, "gemm": 0} - _patch_for_triton(monkeypatch, calls) + fn = _patch_backend(monkeypatch, backend, calls) group_size = 32 out_features = 8 @@ -63,7 +83,7 @@ def test_fp16_matmul_heuristic_prefers_fused_gemm_for_small_matrices(monkeypatch # Small batch x sequence stays on the fused GEMM kernel. 
x = torch.ones((1, 1, qweight.shape[0]), dtype=torch.float16) - out = gemm_awq_triton.AwqGemmTritonFn.apply( + out = fn.apply( x, qweight, qzeros, scales, 4, group_size, None, out_features, ) @@ -71,6 +91,24 @@ def test_fp16_matmul_heuristic_prefers_fused_gemm_for_small_matrices(monkeypatch assert out.shape == (1, 1, out_features) +def _available_bench_backends(): + backends = [] + if gemm_awq.awq_ext is not None: + backends.append("awq_ext") + if gemm_awq_triton.tritonv2.TRITON_AVAILABLE: + backends.append("triton") + return backends + + +_BACKEND_PARAMS = _available_bench_backends() +if _BACKEND_PARAMS: + BACKEND_PARAMS = [pytest.param(backend, id=backend) for backend in _BACKEND_PARAMS] +else: + BACKEND_PARAMS = [ + pytest.param("missing", id="no_backend", marks=pytest.mark.skip(reason="No AWQ backend available for benchmark")) + ] + + SEQ_LENS = [128, 256, 512, 1024, 1280, 1536, 2048, 4096, 8192] # Each entry: (case_name, batch, in_features, out_features) @@ -92,18 +130,14 @@ def test_fp16_matmul_heuristic_prefers_fused_gemm_for_small_matrices(monkeypatch ids=[f"{case}_s{seq}" for (case, _, _, _) in BENCH_CASES for seq in SEQ_LENS], ) @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA backend required for AWQ benchmark") -def test_fp16_matmul_heuristic_benchmark(case_name, batch, seq, in_features, out_features): +@pytest.mark.parametrize("backend", BACKEND_PARAMS) +def test_fp16_matmul_heuristic_benchmark(case_name, batch, seq, in_features, out_features, backend): if os.getenv("RUN_AWQ_FP16_HEURISTIC_BENCH") != "1": pytest.skip("Set RUN_AWQ_FP16_HEURISTIC_BENCH=1 to enable this benchmark") tabulate = pytest.importorskip("tabulate").tabulate - backend = None - if gemm_awq.awq_ext is not None: - backend = "awq_ext" - elif gemm_awq.tritonv2.TRITON_AVAILABLE: - backend = "triton" - else: + if backend not in {"awq_ext", "triton"}: pytest.skip("No AWQ backend available for benchmark") device = torch.device("cuda") @@ -154,7 +188,7 @@ def benchmark(fn, iters=3): meets_condition = batch * seq >= 1024 rows = [ - [case_name, batch, seq, meets_condition, f"{in_features}->{out_features}", "condition=True (dequant+matmul)", f"{dequant_ms:.3f} ms"], - [case_name, batch, seq, meets_condition, f"{in_features}->{out_features}", "condition=False (fused gemm)", f"{fused_ms:.3f} ms"], + [case_name, backend, batch, seq, meets_condition, f"{in_features}->{out_features}", "condition=True (dequant+matmul)", f"{dequant_ms:.3f} ms"], + [case_name, backend, batch, seq, meets_condition, f"{in_features}->{out_features}", "condition=False (fused gemm)", f"{fused_ms:.3f} ms"], ] - print(tabulate(rows, headers=["case", "batch", "seq", "meets >=1024", "matmul (in->out)", "path", "avg latency"])) + print(tabulate(rows, headers=["case", "backend", "batch", "seq", "meets >=1024", "matmul (in->out)", "path", "avg latency"])) From 78f1471ebf445c5d1f6c19de02f50b957718151b Mon Sep 17 00:00:00 2001 From: Qubitium Date: Sun, 23 Nov 2025 23:57:28 +0000 Subject: [PATCH 17/24] add gemm_triton --- gptqmodel/utils/backend.py | 1 + gptqmodel/utils/importer.py | 5 +++ tests/test_kernel_output_awq.py | 54 +++++++++++++++++++++++++++++++++ 3 files changed, 60 insertions(+) diff --git a/gptqmodel/utils/backend.py b/gptqmodel/utils/backend.py index 2d2746bee..c54d2449f 100644 --- a/gptqmodel/utils/backend.py +++ b/gptqmodel/utils/backend.py @@ -28,6 +28,7 @@ class BACKEND(str, Enum): # awq GEMM = "gemm" + GEMM_TRITON = "gemm_triton" GEMV = "gemv" GEMV_FAST = "gemv_fast" TORCH_AWQ = "torch_awq" diff --git 
a/gptqmodel/utils/importer.py b/gptqmodel/utils/importer.py index 7b372eb7e..817c8998f 100644 --- a/gptqmodel/utils/importer.py +++ b/gptqmodel/utils/importer.py @@ -20,6 +20,7 @@ from ..nn_modules.qlinear.exllamav2 import ExllamaV2QuantLinear from ..nn_modules.qlinear.exllamav2_awq import AwqExllamaV2QuantLinear from ..nn_modules.qlinear.gemm_awq import AwqGEMMQuantLinear +from ..nn_modules.qlinear.gemm_awq_triton import AwqGEMMTritonQuantLinear from ..nn_modules.qlinear.gemv_awq import AwqGEMVQuantLinear from ..nn_modules.qlinear.gemv_fast_awq import AwqGEMVFastQuantLinear from ..nn_modules.qlinear.machete import MacheteQuantLinear @@ -69,6 +70,7 @@ BACKEND.EXLLAMA_V2: AwqExllamaV2QuantLinear, BACKEND.EXLLAMA_V1: AwqExllamaQuantLinear, BACKEND.GEMM: AwqGEMMQuantLinear, + BACKEND.GEMM_TRITON: AwqGEMMTritonQuantLinear, BACKEND.GEMV: AwqGEMVQuantLinear, BACKEND.GEMV_FAST: AwqGEMVFastQuantLinear, BACKEND.TORCH_FUSED_AWQ: TorchFusedAwqQuantLinear, @@ -93,6 +95,7 @@ BACKEND.EXLLAMA_V2, BACKEND.EXLLAMA_V1, BACKEND.GEMM, + BACKEND.GEMM_TRITON, BACKEND.TORCH_FUSED_AWQ, BACKEND.TORCH_AWQ, ], @@ -412,6 +415,8 @@ def select_quant_linear( qlinear = QQQQuantLinear elif backend == BACKEND.GEMM: qlinear = AwqGEMMQuantLinear + elif backend == BACKEND.GEMM_TRITON: + qlinear = AwqGEMMTritonQuantLinear elif backend == BACKEND.GEMV: qlinear = AwqGEMVQuantLinear elif backend == BACKEND.GEMV_FAST: diff --git a/tests/test_kernel_output_awq.py b/tests/test_kernel_output_awq.py index c1c4d154b..0c88bdbcc 100644 --- a/tests/test_kernel_output_awq.py +++ b/tests/test_kernel_output_awq.py @@ -24,6 +24,13 @@ from gptqmodel.nn_modules.qlinear.torch_awq import AwqTorchQuantLinear from gptqmodel.nn_modules.qlinear.torch_fused_awq import TorchFusedAwqQuantLinear from gptqmodel.utils.marlin import marlin_make_workspace_new +try: + from gptqmodel.nn_modules.qlinear.gemm_awq_triton import AwqGEMMTritonQuantLinear + + awq_triton_import_exception: Optional[Exception] = None +except Exception as exc: # pragma: no cover - triton import may fail in CI + AwqGEMMTritonQuantLinear = None # type: ignore[assignment] + awq_triton_import_exception = exc os.environ.setdefault("CUDA_DEVICE_ORDER", "PCI_BUS_ID") @@ -55,6 +62,7 @@ class TestAwqKernelOutput(unittest.TestCase): # (baseline_backend, torch.bfloat16, 0.0), (BACKEND.GEMM, torch.float16, 0.004), # (BACKEND.GEMM, torch.bfloat16, 0.05), + (BACKEND.TRITON, torch.float16, 0.004), (BACKEND.MARLIN, torch.float16, 0.006), (BACKEND.TORCH_FUSED_AWQ, torch.float16, 0.004), # (BACKEND.MARLIN, torch.bfloat16, 0.05), @@ -69,7 +77,12 @@ def setUpClass(cls) -> None: cls.backend_skip_reason: Dict[BACKEND, str] = {} if not cls.cuda_available: cls.backend_skip_reason[BACKEND.GEMM] = "CUDA is required for GEMM backend." + cls.backend_skip_reason[BACKEND.TRITON] = "CUDA is required for AWQ Triton backend." cls.backend_skip_reason[BACKEND.MARLIN] = "CUDA is required for AWQ Marlin kernel." 
+ if awq_triton_import_exception is not None: + cls.backend_skip_reason[BACKEND.TRITON] = ( + f"AWQ Triton kernel unavailable: {awq_triton_import_exception}" + ) try: tensors = cls._load_awq_tensors(cls.TARGET) @@ -102,6 +115,16 @@ def setUpClass(cls) -> None: else None ) + try: + cls.modules[BACKEND.TRITON] = ( + cls._build_gemm_triton_module(qweight_cpu, qzeros_cpu, scales_cpu, bias_cpu) + if cls.cuda_available + else None + ) + except Exception as exc: + cls.backend_skip_reason[BACKEND.TRITON] = f"AWQ Triton kernel unavailable: {exc}" + cls.modules[BACKEND.TRITON] = None + cls.modules[BACKEND.MARLIN] = ( cls._build_marlin_module(qweight_cpu, qzeros_cpu, scales_cpu, bias_cpu) if cls.cuda_available @@ -204,6 +227,37 @@ def _build_gemm_module( module.post_init() return module + @classmethod + def _build_gemm_triton_module( + cls, + qweight_cpu: torch.Tensor, + qzeros_cpu: torch.Tensor, + scales_cpu: torch.Tensor, + bias_cpu: torch.Tensor, + ) -> AwqGEMMTritonQuantLinear: + if AwqGEMMTritonQuantLinear is None: + raise RuntimeError("AWQ Triton kernel not available.") + module = AwqGEMMTritonQuantLinear( + bits=cls.BITS, + group_size=cls.GROUP_SIZE, + sym=True, + desc_act=False, + in_features=cls.in_features, + out_features=cls.out_features, + bias=True, + adapter=None, + register_buffers=True, + ).to(cls.device) + + module.qweight.copy_(qweight_cpu.to(cls.device)) + module.qzeros.copy_(qzeros_cpu.to(cls.device)) + module.scales.copy_(scales_cpu.to(torch.float16).to(cls.device)) + module.bias.copy_(bias_cpu.to(torch.float16).to(cls.device)) + + module.eval() + module.post_init() + return module + @classmethod def _build_marlin_module( cls, From 289fd931ea0ad790c022f0e57f1531da2a7fb6ef Mon Sep 17 00:00:00 2001 From: ZX-ModelCloud Date: Mon, 24 Nov 2025 13:11:45 +0800 Subject: [PATCH 18/24] fix empty named_childs Signed-off-by: ZX-ModelCloud --- gptqmodel/looper/awq_processor.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/gptqmodel/looper/awq_processor.py b/gptqmodel/looper/awq_processor.py index 0aabeb91e..7f497e5a9 100644 --- a/gptqmodel/looper/awq_processor.py +++ b/gptqmodel/looper/awq_processor.py @@ -334,10 +334,13 @@ def _quantize_layer(self, layer_index: int, state: _AWQLayerState) -> None: with state.lock: # Filtering MLP modules like Qwen3MoeSparseMoeBlock + def unwrap(m): + return m.module if isinstance(m, NamedModule) else m + named_childs = { name: module for name, module in state.modules.items() - if isinstance(module, tuple(SUPPORTS_MODULE_TYPES)) + if isinstance(unwrap(module), tuple(SUPPORTS_MODULE_TYPES)) } module_kwargs_global = dict(self._module_forward_kwargs) From d11099b3dc2bef6239a0a0bb2b38999d0c75edd5 Mon Sep 17 00:00:00 2001 From: ZX-ModelCloud Date: Mon, 24 Nov 2025 17:08:35 +0800 Subject: [PATCH 19/24] add gemm_awq pack() Signed-off-by: ZX-ModelCloud --- gptqmodel/looper/awq_processor.py | 124 ++++++++++++++++------- gptqmodel/nn_modules/qlinear/gemm_awq.py | 72 +++++++++++++ gptqmodel/utils/model.py | 13 +++ 3 files changed, 172 insertions(+), 37 deletions(-) diff --git a/gptqmodel/looper/awq_processor.py b/gptqmodel/looper/awq_processor.py index 7f497e5a9..723fdc8a5 100644 --- a/gptqmodel/looper/awq_processor.py +++ b/gptqmodel/looper/awq_processor.py @@ -24,14 +24,14 @@ from ..nn_modules.qlinear.gemv_awq import AwqGEMVQuantLinear from ..nn_modules.qlinear.gemv_fast_awq import AwqGEMVFastQuantLinear from ..nn_modules.qlinear.marlin_awq import AwqMarlinQuantLinear -from ..quantization.awq.modules.linear import WQLinear_GEMV, 
WQLinear_GEMVFast, WQLinear_Marlin from ..quantization.awq.quantize.scale import apply_clip, apply_scale -from ..quantization.awq.utils.module import append_str_prefix, get_op_name, get_op_by_name, set_op_by_name +from ..quantization.awq.utils.module import append_str_prefix, get_op_name, get_op_by_name from ..quantization.awq.utils.utils import get_best_device from ..quantization.config import FORMAT, METHOD, QuantizeConfig -from ..utils.logger import setup_logger +from ..utils.logger import setup_logger, log_time_block from ..utils.ctx import ctx -from ..utils.model import find_modules, get_module_by_name_prefix, move_to +from ..utils.model import find_modules, get_module_by_name_prefix, move_to, create_quant_module, pack_module +from ..utils.module_locks import parent_module_lock from ..utils.torch import CPU log = setup_logger() @@ -90,6 +90,16 @@ def __init__( self._layer_states_lock = threading.Lock() self._scale_context = threading.local() self.gptq_model = gptq_model + + if qcfg.format == FORMAT.GEMM: + self.gptq_model.qlinear_kernel = AwqGEMMQuantLinear + elif qcfg.format == FORMAT.GEMV: + self.gptq_model.qlinear_kernel = AwqGEMVQuantLinear + elif qcfg.format == FORMAT.GEMV_FAST: + self.gptq_model.qlinear_kernel = AwqGEMVFastQuantLinear + else: + raise ValueError(f"METHOD.AWQ does not support this FORMAT: {qcfg.format}") + self.model = model # Whether to apply clipping to the model during quantization. Some models may perform better with this set to False. self.apply_clip = True @@ -546,7 +556,7 @@ def unwrap(m): if not self.export_compatible: start = time.time() - self._apply_quant(layer_module_ref, named_childs, start, scales_list) + self._apply_quant(named_childs, start, scales_list) with state.lock: state.quantized = True @@ -1061,7 +1071,7 @@ def _slice_value(val, length): return module_output - def _apply_quant(self, module, named_linears: Dict[str, NamedModule], start_time, scales_list): + def _apply_quant(self, named_linears: Dict[str, NamedModule], start_time, scales_list): for name, named_module in named_linears.items(): self.pb.title(f"Quantizing {named_module.name} in layer ").draw() linear_layer = named_module.module @@ -1116,37 +1126,6 @@ def _apply_quant(self, module, named_linears: Dict[str, NamedModule], start_time linear_layer.weight.data = wq - if self.format == FORMAT.GEMM: - scales = scales.t().contiguous() - if zeros is not None: - zeros = zeros.t().contiguous() - q_linear_module = WQLinear_GEMM - - elif self.format == FORMAT.GEMV: - q_linear_module = WQLinear_GEMV - - elif self.format == FORMAT.MARLIN: - q_linear_module = WQLinear_Marlin - - elif self.format == FORMAT.GEMV_FAST: - q_linear_module = WQLinear_GEMVFast - - else: - raise ValueError(f"Unknown version {self.format}") - - q_linear = q_linear_module.from_linear( - linear=linear_layer, - w_bit=self.qcfg.bits, - group_size=self.qcfg.group_size, - init_only=False, - scales=scales, - zeros=zeros, - ) - - linear_layer.cpu() - q_linear.to(next(module.parameters()).device) - set_op_by_name(module, name, q_linear) - # records duration = time.time() - start_time @@ -1194,6 +1173,77 @@ def _apply_quant(self, module, named_linears: Dict[str, NamedModule], start_time f"{duration:.3f}", ) + linear_layer = linear_layer.cpu() + scales = scales.cpu() + zeros = zeros.cpu() + + layers = find_modules(self.gptq_model.model) + module_label = getattr(named_module, "full_name", getattr(named_module, "name", "")) + parent_key = getattr(named_module, "full_name", getattr(named_module, "name", None)) + + # replace module 
with quantized module + timer = getattr(self.gptq_model, "quant_region_timer", None) + + create_start = time.perf_counter() if timer is not None else None + with log_time_block( + "create_quant_module", + logger=log, + module_name=module_label, + ): + with parent_module_lock(parent_key): + create_quant_module( + name=named_module.full_name, + linear_cls=self.gptq_model.qlinear_kernel, + bits=self.qcfg.bits, + desc_act=self.qcfg.desc_act, + dynamic=self.qcfg.dynamic, + group_size=self.qcfg.group_size, + module=self.gptq_model.model, + submodule=named_module, + sym=self.qcfg.sym, + device=self.qcfg.device, + lm_head_name=self.gptq_model.lm_head, + pack_dtype=self.qcfg.pack_dtype, + register_buffers=False, + ) + if timer is not None and create_start is not None: + timer.record( + "submodule_finalize_create", + time.perf_counter() - create_start, + source=module_label, + ) + + # pack module + qModules = { + name: submodule + for name, submodule in find_modules(self.gptq_model.model, [self.gptq_model.qlinear_kernel]).items() + if name == named_module.full_name + } + pack_start = time.perf_counter() if timer is not None else None + with log_time_block( + "pack", + logger=log, + module_name=module_label, + ): + with parent_module_lock(parent_key): + packer_label = pack_module( + name=named_module.full_name, + qModules=qModules, + q_scales=scales, + q_zeros=zeros, + q_g_idx=None, + layers=layers, + quant_linear_cls=self.gptq_model.qlinear_kernel, + lock=self.lock, + quantize_config=self.qcfg, + ) + if timer is not None and pack_start is not None: + timer.record( + "submodule_finalize_pack", + time.perf_counter() - pack_start, + source=f"{module_label} [{packer_label or 'module.pack_original'}]", + ) + def _sanitize_kwargs(self, inputs_kwargs, module): """ Remove the arguments that are not supported in the module's diff --git a/gptqmodel/nn_modules/qlinear/gemm_awq.py b/gptqmodel/nn_modules/qlinear/gemm_awq.py index 06e2e8e1b..2da68d764 100644 --- a/gptqmodel/nn_modules/qlinear/gemm_awq.py +++ b/gptqmodel/nn_modules/qlinear/gemm_awq.py @@ -6,11 +6,13 @@ from contextlib import nullcontext import torch +from torch import nn from ...adapter.adapter import Adapter, Lora from ...models._const import DEVICE, PLATFORM from ...nn_modules.qlinear import AWQuantLinear from ...quantization.awq.utils.module import try_import +from ...quantization.awq.utils.utils import get_best_device from ...utils.backend import BACKEND from ...utils.logger import setup_logger @@ -188,6 +190,76 @@ def forward(self, x: torch.Tensor): return out.reshape(out_shape) + def pack(self, linear: nn.Module, scales: torch.Tensor, zeros: torch.Tensor, g_idx: torch.Tensor=None): + # need scales and zeros info for real quantization + assert scales is not None and zeros is not None + scales = scales.t().contiguous() + zeros = zeros.t().contiguous() + scale_zeros = zeros * scales + + self.register_buffer("scales", scales.clone().half()) + if linear.bias is not None: + self.register_buffer("bias", linear.bias.clone().half()) + else: + self.bias = None + + pack_num = 32 // self.bits + + intweight = [] + for idx in range(self.in_features): + intweight.append( + torch.round( + (linear.weight.data[:, idx] + scale_zeros[idx // self.group_size]) + / self.scales[idx // self.group_size] + ).to(torch.int)[:, None] + ) + intweight = torch.cat(intweight, dim=1) + intweight = intweight.t().contiguous() + intweight = intweight.to(dtype=torch.int32) + + best_device = get_best_device() + + # Avoid: The operator 'aten::__lshift__.Scalar' is not currently 
implemented for the MPS device + if "mps" in best_device: + intweight = intweight.to("cpu") + + qweight = torch.zeros( + (intweight.shape[0], intweight.shape[1] // 32 * self.bits), + dtype=torch.int32, + device=intweight.device, + ) + + for col in range(intweight.shape[1] // pack_num): + if self.bits == 4: + order_map = [0, 2, 4, 6, 1, 3, 5, 7] + else: + raise NotImplementedError("Only 4-bit are supported for now.") + for i in range(pack_num): + qweight_col = intweight[:, col * pack_num + order_map[i]] + qweight[:, col] |= qweight_col << (i * self.bits) + self.register_buffer("qweight", qweight) + + zeros = zeros.to(dtype=torch.int32, device=best_device) + + if "mps" in best_device: + zeros = zeros.to("cpu") + + qzeros = torch.zeros( + (zeros.shape[0], zeros.shape[1] // 32 * self.bits), + dtype=torch.int32, + device=zeros.device, + ) + + for col in range(zeros.shape[1] // pack_num): + if self.bits == 4: + order_map = [0, 2, 4, 6, 1, 3, 5, 7] + else: + raise NotImplementedError("Only 4-bit are supported for now.") + for i in range(pack_num): + qzero_col = zeros[:, col * pack_num + order_map[i]] + qzeros[:, col] |= qzero_col << (i * self.bits) + self.register_buffer("qzeros", qzeros) + __all__ = [ diff --git a/gptqmodel/utils/model.py b/gptqmodel/utils/model.py index 3d8343409..b3fa099e8 100644 --- a/gptqmodel/utils/model.py +++ b/gptqmodel/utils/model.py @@ -724,6 +724,19 @@ def pack_module( module_name=name, ): module.pack(linear=layer, scales=q_scales, s_extra=q_scales_extra) + if quant_linear_cls.QUANT_TYPE.startswith("awq_"): + packer_label = "module.pack" + with log_time_block( + packer_label, + logger=log, + module_name=name, + ): + module.pack( + linear=layer, + scales=q_scales, + zeros=q_zeros, + g_idx=q_g_idx, + ) else: effective_impl = (pack_impl or "original").lower() From 969c490ef5908af2deff8e25a76f8a630db3fc30 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Tue, 25 Nov 2025 02:39:28 +0000 Subject: [PATCH 20/24] for triton the dense matmul threshold is 128 --- gptqmodel/nn_modules/qlinear/gemm_awq_triton.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gptqmodel/nn_modules/qlinear/gemm_awq_triton.py b/gptqmodel/nn_modules/qlinear/gemm_awq_triton.py index a09409b23..be2fec5d9 100644 --- a/gptqmodel/nn_modules/qlinear/gemm_awq_triton.py +++ b/gptqmodel/nn_modules/qlinear/gemm_awq_triton.py @@ -41,7 +41,7 @@ def forward( return torch.zeros(out_shape, dtype=x.dtype, device=x.device) # Above compute density threshold it is faster to just dequantize the whole thing and do simple matmul - FULL_DEQUANT_MATMUL_THRESHOLD = x.shape[0] * x.shape[1] > 1024 + FULL_DEQUANT_MATMUL_THRESHOLD = x.shape[0] * x.shape[1] > 128 if FULL_DEQUANT_MATMUL_THRESHOLD: out = awq_dequantize_triton(qweight, scales, qzeros) out = torch.matmul(x, out.to(x.dtype)) From ac64c64218742cb273f62ded8ec878d44201336a Mon Sep 17 00:00:00 2001 From: ZX-ModelCloud Date: Tue, 25 Nov 2025 10:50:38 +0800 Subject: [PATCH 21/24] add gemv_awq pack() Signed-off-by: ZX-ModelCloud --- gptqmodel/looper/awq_processor.py | 4 +- gptqmodel/nn_modules/qlinear/gemv_awq.py | 67 ++++++ .../quantization/awq/modules/linear/gemv.py | 204 ------------------ 3 files changed, 69 insertions(+), 206 deletions(-) delete mode 100644 gptqmodel/quantization/awq/modules/linear/gemv.py diff --git a/gptqmodel/looper/awq_processor.py b/gptqmodel/looper/awq_processor.py index 723fdc8a5..f7be2a7a1 100644 --- a/gptqmodel/looper/awq_processor.py +++ b/gptqmodel/looper/awq_processor.py @@ -556,7 +556,7 @@ def unwrap(m): if not 
self.export_compatible: start = time.time() - self._apply_quant(named_childs, start, scales_list) + self.pack_module(named_childs, start, scales_list) with state.lock: state.quantized = True @@ -1071,7 +1071,7 @@ def _slice_value(val, length): return module_output - def _apply_quant(self, named_linears: Dict[str, NamedModule], start_time, scales_list): + def pack_module(self, named_linears: Dict[str, NamedModule], start_time, scales_list): for name, named_module in named_linears.items(): self.pb.title(f"Quantizing {named_module.name} in layer ").draw() linear_layer = named_module.module diff --git a/gptqmodel/nn_modules/qlinear/gemv_awq.py b/gptqmodel/nn_modules/qlinear/gemv_awq.py index 150ad3b69..499e646a0 100644 --- a/gptqmodel/nn_modules/qlinear/gemv_awq.py +++ b/gptqmodel/nn_modules/qlinear/gemv_awq.py @@ -4,6 +4,7 @@ # Contact: qubitium@modelcloud.ai, x.com/qubitium import torch +from torch import nn from ...adapter.adapter import Adapter, Lora from ...models._const import DEVICE, PLATFORM @@ -148,6 +149,72 @@ def forward(self, x: torch.Tensor): return out.reshape(out_shape) + def pack(self, linear: nn.Module, scales: torch.Tensor, zeros: torch.Tensor, g_idx: torch.Tensor=None): + # need scales and zeros info for real quantization + assert scales is not None and zeros is not None + scale_zeros = zeros * scales + + pack_num = 32 // self.bits + qscales = torch.zeros( + ( + scales.shape[0], + calculate_zeros_width(linear.in_features, self.group_size) * pack_num, + ), + dtype=torch.float16, + device=scales.device, + ) + qscales[:, : scales.shape[1]] = scales + self.register_buffer("scales", qscales) + if linear.bias is not None: + self.register_buffer("bias", linear.bias.clone().half()) + else: + self.bias = None + + intweight = [] + for idx in range(self.in_features): + intweight.append( + torch.round( + (linear.weight.data[:, idx] + scale_zeros[:, idx // self.group_size]) + / self.scales[:, idx // self.group_size] + ).to(torch.int)[:, None] + ) + intweight = torch.cat(intweight, dim=1) + intweight = intweight.to(dtype=torch.int32) + qweight = torch.zeros( + (intweight.shape[0], intweight.shape[1] // 32 * self.bits), + dtype=torch.int32, + device=intweight.device, + ) + + for col in range(intweight.shape[1] // pack_num): + if self.bits == 4: + order_map = [0, 1, 2, 3, 4, 5, 6, 7] + else: + raise NotImplementedError("Only 4-bit are supported for now.") + for i in range(pack_num): + qweight_col = intweight[:, col * pack_num + order_map[i]] + qweight[:, col] |= qweight_col << (i * self.bits) + self.register_buffer("qweight", qweight) + + zeros = zeros.to(dtype=torch.int32) + qzeros = torch.zeros( + (zeros.shape[0], calculate_zeros_width(linear.in_features, self.group_size)), + dtype=torch.int32, + device=zeros.device, + ) + + for col in range((zeros.shape[1] + pack_num - 1) // pack_num): + if self.bits == 4: + order_map = [0, 1, 2, 3, 4, 5, 6, 7] + else: + raise NotImplementedError("Only 4-bit are supported for now.") + for i in range(pack_num): + if col * pack_num + order_map[i] >= zeros.shape[1]: + continue + qzero_col = zeros[:, col * pack_num + order_map[i]] + qzeros[:, col] |= qzero_col << (i * self.bits) + self.register_buffer("qzeros", qzeros) + def extra_repr(self) -> str: return ( "in_features={}, out_features={}, bias={}, bits={}, group_size={}".format( diff --git a/gptqmodel/quantization/awq/modules/linear/gemv.py b/gptqmodel/quantization/awq/modules/linear/gemv.py deleted file mode 100644 index c62863289..000000000 --- a/gptqmodel/quantization/awq/modules/linear/gemv.py +++ 
/dev/null @@ -1,204 +0,0 @@ -# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai -# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai -# SPDX-License-Identifier: Apache-2.0 -# Contact: qubitium@modelcloud.ai, x.com/qubitium - - -import torch -import torch.nn as nn - -from gptqmodel.quantization.awq.utils.module import try_import - - -awq_ext, msg = try_import("gptqmodel_awq_kernels") - -def make_divisible(c, divisor): - return (c + divisor - 1) // divisor - - -def calculate_zeros_width(in_features, group_size=128, pack_num=8): - if group_size >= 128: - size_multiplier = 1 - elif group_size == 64: - size_multiplier = 2 - elif group_size == 32: - size_multiplier = 4 - else: - raise NotImplementedError - - base_width = make_divisible(in_features // group_size, pack_num) - base_width = make_divisible(base_width, size_multiplier) * size_multiplier - return base_width - - -class WQLinear_GEMV(nn.Module): - def __init__(self, w_bit, group_size, in_features, out_features, bias, dev): - super().__init__() - - if w_bit not in [4]: - raise NotImplementedError("Only 4-bit are supported for now.") - - self.in_features = in_features - self.out_features = out_features - self.w_bit = w_bit - self.group_size = group_size if group_size != -1 else in_features - self.split_k_iters = 8 - - # quick sanity check (make sure aligment) - assert self.in_features % self.group_size == 0 - assert out_features % (32 // self.w_bit) == 0 - pack_num = 32 // self.w_bit - - self.register_buffer( - "qweight", - torch.zeros( - (out_features, in_features // pack_num), dtype=torch.int32, device=dev - ), - ) - self.register_buffer( - "qzeros", - torch.zeros( - (out_features, calculate_zeros_width(in_features, self.group_size)), - dtype=torch.int32, - device=dev, - ), - ) - self.register_buffer( - "scales", - torch.zeros( - ( - out_features, - calculate_zeros_width(in_features, self.group_size) * pack_num, - ), - dtype=torch.float16, - device=dev, - ), - ) - if bias: - self.register_buffer( - "bias", torch.zeros((out_features), dtype=torch.float16, device=dev) - ) - else: - self.bias = None - - @classmethod - def from_linear( - cls, linear, w_bit, group_size, init_only=False, scales=None, zeros=None - ): - awq_linear = cls( - w_bit, - group_size, - linear.in_features, - linear.out_features, - linear.bias is not None, - linear.weight.device, - ) - if init_only: # just prepare for loading sd - return awq_linear - - # need scales and zeros info for real quantization - assert scales is not None and zeros is not None - scale_zeros = zeros * scales - - pack_num = 32 // awq_linear.w_bit - qscales = torch.zeros( - ( - scales.shape[0], - calculate_zeros_width(linear.in_features, group_size) * pack_num, - ), - dtype=torch.float16, - device=scales.device, - ) - qscales[:, : scales.shape[1]] = scales - awq_linear.scales = qscales - if linear.bias is not None: - awq_linear.bias = linear.bias.clone().half() - - intweight = [] - for idx in range(awq_linear.in_features): - intweight.append( - torch.round( - (linear.weight.data[:, idx] + scale_zeros[:, idx // group_size]) - / awq_linear.scales[:, idx // group_size] - ).to(torch.int)[:, None] - ) - intweight = torch.cat(intweight, dim=1) - intweight = intweight.to(dtype=torch.int32) - qweight = torch.zeros( - (intweight.shape[0], intweight.shape[1] // 32 * awq_linear.w_bit), - dtype=torch.int32, - device=intweight.device, - ) - - for col in range(intweight.shape[1] // pack_num): - if awq_linear.w_bit == 4: - order_map = [0, 1, 2, 3, 4, 5, 6, 7] - else: - raise NotImplementedError("Only 4-bit 
are supported for now.") - for i in range(pack_num): - qweight_col = intweight[:, col * pack_num + order_map[i]] - qweight[:, col] |= qweight_col << (i * awq_linear.w_bit) - awq_linear.qweight = qweight - - zeros = zeros.to(dtype=torch.int32) - qzeros = torch.zeros( - (zeros.shape[0], calculate_zeros_width(linear.in_features, group_size)), - dtype=torch.int32, - device=zeros.device, - ) - - for col in range((zeros.shape[1] + pack_num - 1) // pack_num): - if awq_linear.w_bit == 4: - order_map = [0, 1, 2, 3, 4, 5, 6, 7] - else: - raise NotImplementedError("Only 4-bit are supported for now.") - for i in range(pack_num): - if col * pack_num + order_map[i] >= zeros.shape[1]: - continue - qzero_col = zeros[:, col * pack_num + order_map[i]] - qzeros[:, col] |= qzero_col << (i * awq_linear.w_bit) - awq_linear.qzeros = qzeros - return awq_linear - - @torch.inference_mode() - def forward(self, x): - if awq_ext is None: - raise ModuleNotFoundError("External AWQ kernels are not properly installed." + msg) - - out_shape = x.shape[:-1] + (self.out_features,) - inputs = x.reshape(-1, x.shape[-1]) - - input_dtype = inputs.dtype - if input_dtype != torch.float16: - inputs = inputs.half() - - if inputs.shape[0] > 8: - out = awq_ext.gemmv2_forward_cuda( - inputs, - self.qweight, - self.scales, - self.qzeros, - self.group_size, - self.split_k_iters, - ) - else: - out = awq_ext.gemv_forward_cuda( - inputs, self.qweight, self.scales, self.qzeros, self.group_size - ) - - if input_dtype != torch.float16: - out = out.to(dtype=input_dtype) - - out = out + self.bias if self.bias is not None else out - return out.reshape(out_shape) - - def extra_repr(self) -> str: - return ( - "in_features={}, out_features={}, bias={}, w_bit={}, group_size={}".format( - self.in_features, - self.out_features, - self.bias is not None, - self.w_bit, - self.group_size, - ) - ) From 49c4bc519648f28a5cc083e24aa63efbe7586343 Mon Sep 17 00:00:00 2001 From: ZX-ModelCloud Date: Tue, 25 Nov 2025 14:11:28 +0800 Subject: [PATCH 22/24] add gemv_fast pack() Signed-off-by: ZX-ModelCloud --- gptqmodel/nn_modules/qlinear/gemv_fast_awq.py | 95 ++++++++ .../awq/modules/linear/gemv_fast.py | 215 ------------------ 2 files changed, 95 insertions(+), 215 deletions(-) delete mode 100644 gptqmodel/quantization/awq/modules/linear/gemv_fast.py diff --git a/gptqmodel/nn_modules/qlinear/gemv_fast_awq.py b/gptqmodel/nn_modules/qlinear/gemv_fast_awq.py index e12337046..922caf013 100644 --- a/gptqmodel/nn_modules/qlinear/gemv_fast_awq.py +++ b/gptqmodel/nn_modules/qlinear/gemv_fast_awq.py @@ -4,6 +4,7 @@ # Contact: qubitium@modelcloud.ai, x.com/qubitium import torch +from torch import nn from ...adapter.adapter import Adapter, Lora from ...models._const import DEVICE, PLATFORM @@ -18,6 +19,49 @@ awq_v2_ext, msg = try_import("gptqmodel_awq_v2_kernels") + +def pack_intweight(unpacked_qweight, interleave, kstride): + # unpacked_qweight: [N, K] + N = unpacked_qweight.shape[0] + K = unpacked_qweight.shape[1] + + Packed_Kernel = unpacked_qweight.cpu().numpy().reshape(N, K // 32, 32) + # np.arange(32).reshape(4, 4, 2).transpose(1, 0, 2) => [0, 1, 8, 9, 16, 17, 24, 25, ...] 
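+    # i.e. within each 32-weight block the two 4-wide axes are swapped, so the
+    # flattened order follows the pattern above: consecutive pairs, stride of 8.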
+ Packed_Kernel = Packed_Kernel.reshape(N, K // 32, 4, 4, 2).transpose(0, 1, 3, 2, 4) + Packed_Kernel = Packed_Kernel.reshape(N, K // 32, 32) + + # reorder each 8 weights for fast dequantization + # [0, 1, 2, 3, 4, 5, 6, 7] => [0, 2, 4, 6, 1, 3, 5, 7] + Packed_Kernel = Packed_Kernel.reshape(N, K // 32, 4, 8) + Packed_Kernel = Packed_Kernel.reshape(N, K // 32, 4, 4, 2).transpose(0, 1, 2, 4, 3) + Packed_Kernel = Packed_Kernel.reshape(N, K) + + # interleaving every four rows + Packed_Kernel = Packed_Kernel.reshape( + N // interleave, interleave, K // kstride, kstride + ) + # N // 4, K // 64, 4, 64 + Packed_Kernel = Packed_Kernel.transpose(0, 2, 1, 3) + Packed_Kernel = Packed_Kernel.reshape( + N // interleave, K // kstride, kstride, interleave + ) + # Packing -> (N // 4, K // 64, 64) + Packed_Kernel = ( + Packed_Kernel[..., 0] + | (Packed_Kernel[..., 1] << 4) + | (Packed_Kernel[..., 2] << 8) + | (Packed_Kernel[..., 3] << 12) + ) + # reshape to (N // 4, K), FP16 format + Packed_Kernel = Packed_Kernel.reshape(N // interleave, K) + qweight = ( + torch.tensor(Packed_Kernel.astype("int16")) + .to(unpacked_qweight.device) + .contiguous() + ) + return qweight + + class AwqGEMVFastQuantLinear(AWQuantLinear): SUPPORTS_BITS = [4] SUPPORTS_GROUP_SIZE = [-1, 16, 32, 64, 128] @@ -120,6 +164,10 @@ def forward(self, x: torch.Tensor): inputs = x batch_size, n_tokens, _ = inputs.shape + input_dtype = inputs.dtype + if input_dtype != torch.float16: + inputs = inputs.half() + if batch_size < 8 and n_tokens == 1: out = awq_v2_ext.gemv_forward_cuda_decode( inputs, @@ -135,6 +183,10 @@ def forward(self, x: torch.Tensor): out = awq_v2_ext.gemm_forward_cuda_prefill( inputs, self.qweight, self.scales, self.qzeros ) + + if input_dtype != torch.float16: + out = out.to(dtype=input_dtype) + out = out + self.bias if self.bias is not None else out if self.adapter: @@ -142,6 +194,49 @@ def forward(self, x: torch.Tensor): return out + def pack(self, linear: nn.Module, scales: torch.Tensor, zeros: torch.Tensor, g_idx: torch.Tensor=None): + # need scales and zeros info for real quantization + assert scales is not None and zeros is not None + scale_zeros = zeros * scales + + pack_num = 32 // self.bits + qscales = torch.zeros( + ( + scales.shape[0], + calculate_zeros_width(linear.in_features, self.group_size) * pack_num, + ), + dtype=torch.float16, + device=scales.device, + ) + qscales[:, : scales.shape[1]] = scales + self.register_buffer("scales", qscales.transpose(1, 0).contiguous()) + if linear.bias is not None: + self.register_buffer("bias", linear.bias.clone().half()) + else: + self.bias = None + + intweight = [] + for idx in range(self.in_features): + intweight.append( + torch.round( + (linear.weight.data[:, idx] + scale_zeros[:, idx // self.group_size]) + / qscales[:, idx // self.group_size] + ).to(torch.int)[:, None] + ) + intweight = torch.cat(intweight, dim=1) + intweight = intweight.to(dtype=torch.int32) + self.register_buffer("qweight", pack_intweight( + intweight.contiguous(), interleave=4, kstride=64 + )) + + zeros = zeros.to(dtype=torch.int32) + qzeros = torch.zeros_like(qscales) + + qzeros[:, : scales.shape[1]] = -( + qscales[:, : scales.shape[1]] * (zeros.to(torch.float32)) + ).to(torch.float16) + self.register_buffer("qzeros", qzeros.transpose(1, 0).contiguous()) + def extra_repr(self) -> str: return ( "in_features={}, out_features={}, bias={}, bits={}, group_size={}".format( diff --git a/gptqmodel/quantization/awq/modules/linear/gemv_fast.py b/gptqmodel/quantization/awq/modules/linear/gemv_fast.py deleted file 
mode 100644 index 756e725ac..000000000 --- a/gptqmodel/quantization/awq/modules/linear/gemv_fast.py +++ /dev/null @@ -1,215 +0,0 @@ -# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai -# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai -# SPDX-License-Identifier: Apache-2.0 -# Contact: qubitium@modelcloud.ai, x.com/qubitium - - -import torch - -from gptqmodel.quantization.awq.utils.module import try_import - - -awq_v2_ext, msg = try_import("gptqmodel_awq_v2_kernels") - -def make_divisible(c, divisor): - return (c + divisor - 1) // divisor - - -def calculate_zeros_width(in_features, group_size=128, pack_num=8): - if group_size >= 128: - size_multiplier = 1 - elif group_size == 64: - size_multiplier = 2 - elif group_size == 32: - size_multiplier = 4 - else: - raise NotImplementedError - - base_width = make_divisible(in_features // group_size, pack_num) - base_width = make_divisible(base_width, size_multiplier) * size_multiplier - return base_width - - -def pack_intweight(unpacked_qweight, interleave, kstride): - # unpacked_qweight: [N, K] - N = unpacked_qweight.shape[0] - K = unpacked_qweight.shape[1] - - Packed_Kernel = unpacked_qweight.cpu().numpy().reshape(N, K // 32, 32) - # np.arange(32).reshape(4, 4, 2).transpose(1, 0, 2) => [0, 1, 8, 9, 16, 17, 24, 25, ...] - Packed_Kernel = Packed_Kernel.reshape(N, K // 32, 4, 4, 2).transpose(0, 1, 3, 2, 4) - Packed_Kernel = Packed_Kernel.reshape(N, K // 32, 32) - - # reorder each 8 weights for fast dequantization - # [0, 1, 2, 3, 4, 5, 6, 7] => [0, 2, 4, 6, 1, 3, 5, 7] - Packed_Kernel = Packed_Kernel.reshape(N, K // 32, 4, 8) - Packed_Kernel = Packed_Kernel.reshape(N, K // 32, 4, 4, 2).transpose(0, 1, 2, 4, 3) - Packed_Kernel = Packed_Kernel.reshape(N, K) - - # interleaving every four rows - Packed_Kernel = Packed_Kernel.reshape( - N // interleave, interleave, K // kstride, kstride - ) - # N // 4, K // 64, 4, 64 - Packed_Kernel = Packed_Kernel.transpose(0, 2, 1, 3) - Packed_Kernel = Packed_Kernel.reshape( - N // interleave, K // kstride, kstride, interleave - ) - # Packing -> (N // 4, K // 64, 64) - Packed_Kernel = ( - Packed_Kernel[..., 0] - | (Packed_Kernel[..., 1] << 4) - | (Packed_Kernel[..., 2] << 8) - | (Packed_Kernel[..., 3] << 12) - ) - # reshape to (N // 4, K), FP16 format - Packed_Kernel = Packed_Kernel.reshape(N // interleave, K) - qweight = ( - torch.tensor(Packed_Kernel.astype("int16")) - .to(unpacked_qweight.device) - .contiguous() - ) - return qweight - - -class WQLinear_GEMVFast(torch.nn.Module): - def __init__(self, w_bit, group_size, in_features, out_features, bias, dev): - super().__init__() - - self.in_features = in_features - self.out_features = out_features - self.w_bit = w_bit - self.group_size = group_size if group_size != -1 else in_features - self.split_k_iters = 8 - self.interleave = 4 - - # quick sanity check (make sure aligment) - assert self.in_features % self.group_size == 0 - assert out_features % (32 // self.w_bit) == 0 - pack_num = 32 // self.w_bit - int16_pack_num = 16 // self.w_bit - - assert out_features % (self.interleave) == 0 - self.register_buffer( - "qweight", - torch.zeros( - ( - out_features // self.interleave, - in_features // int16_pack_num * self.interleave, - ), - dtype=torch.int16, - device=dev, - ), - ) - self.register_buffer( - "scales", - torch.zeros( - ( - calculate_zeros_width(in_features, self.group_size) * pack_num, - out_features, - ), - dtype=torch.float16, - device=dev, - ), - ) - self.register_buffer( - "qzeros", - torch.zeros( - ( - calculate_zeros_width(in_features, 
self.group_size) * pack_num, - out_features, - ), - dtype=torch.float16, - device=dev, - ), - ) - - if bias: - self.register_buffer( - "bias", torch.zeros((out_features), dtype=torch.float16, device=dev) - ) - else: - self.bias = None - - @classmethod - def from_linear( - cls, linear, w_bit, group_size, init_only=False, scales=None, zeros=None - ): - awq_linear = cls( - w_bit, - group_size, - linear.in_features, - linear.out_features, - linear.bias is not None, - linear.weight.device, - ) - if init_only: - return awq_linear - - # need scales and zeros info for real quantization - assert scales is not None and zeros is not None - scale_zeros = zeros * scales - - pack_num = 32 // awq_linear.w_bit - qscales = torch.zeros( - ( - scales.shape[0], - calculate_zeros_width(linear.in_features, group_size) * pack_num, - ), - dtype=torch.float16, - device=scales.device, - ) - qscales[:, : scales.shape[1]] = scales - # awq_linear.scales = scales.clone().half() - awq_linear.scales = qscales.transpose(1, 0).contiguous() - if linear.bias is not None: - awq_linear.bias = linear.bias.clone().half() - - intweight = [] - for idx in range(awq_linear.in_features): - intweight.append( - torch.round( - (linear.weight.data[:, idx] + scale_zeros[:, idx // group_size]) - / qscales[:, idx // group_size] - ).to(torch.int)[:, None] - ) - intweight = torch.cat(intweight, dim=1) - intweight = intweight.to(dtype=torch.int32) - awq_linear.qweight = pack_intweight( - intweight.contiguous(), interleave=4, kstride=64 - ) - - zeros = zeros.to(dtype=torch.int32) - qzeros = torch.zeros_like(qscales) - - qzeros[:, : scales.shape[1]] = -( - qscales[:, : scales.shape[1]] * (zeros.to(torch.float32)) - ).to(torch.float16) - awq_linear.qzeros = qzeros.transpose(1, 0).contiguous() - - return awq_linear - - @torch.inference_mode() - def forward(self, x): - if awq_v2_ext is None: - raise ModuleNotFoundError("External AWQ V2 kernels are not properly installed." 
+ msg) - inputs = x - batch_size, n_tokens, _ = inputs.shape - if batch_size < 8 and n_tokens == 1: - out = awq_v2_ext.gemv_forward_cuda_decode( - inputs, - self.qweight, - self.scales, - self.qzeros, - inputs.numel() // inputs.shape[-1], - self.out_features, - self.in_features, - self.group_size, - ) - else: - out = awq_v2_ext.gemm_forward_cuda_prefill( - inputs, self.qweight, self.scales, self.qzeros - ) - out = out + self.bias if self.bias is not None else out - - return out From 65bb7d45dd1506b5bac4df66c9f519a818c5a658 Mon Sep 17 00:00:00 2001 From: ZX-ModelCloud Date: Tue, 25 Nov 2025 14:15:19 +0800 Subject: [PATCH 23/24] remove gptqmodel/quantization/awq/modules/linear/ Signed-off-by: ZX-ModelCloud --- .../awq/modules/linear/__init__.py | 8 - .../awq/modules/linear/exllama.py | 135 ---------- .../awq/modules/linear/exllamav2.py | 160 ------------ .../quantization/awq/modules/linear/marlin.py | 237 ------------------ 4 files changed, 540 deletions(-) delete mode 100644 gptqmodel/quantization/awq/modules/linear/__init__.py delete mode 100644 gptqmodel/quantization/awq/modules/linear/exllama.py delete mode 100644 gptqmodel/quantization/awq/modules/linear/exllamav2.py delete mode 100644 gptqmodel/quantization/awq/modules/linear/marlin.py diff --git a/gptqmodel/quantization/awq/modules/linear/__init__.py b/gptqmodel/quantization/awq/modules/linear/__init__.py deleted file mode 100644 index 8f114b9f0..000000000 --- a/gptqmodel/quantization/awq/modules/linear/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai -# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai -# SPDX-License-Identifier: Apache-2.0 -# Contact: qubitium@modelcloud.ai, x.com/qubitium - -from .gemv import WQLinear_GEMV -from .gemv_fast import WQLinear_GEMVFast -from .marlin import WQLinear_Marlin diff --git a/gptqmodel/quantization/awq/modules/linear/exllama.py b/gptqmodel/quantization/awq/modules/linear/exllama.py deleted file mode 100644 index 97c6f14da..000000000 --- a/gptqmodel/quantization/awq/modules/linear/exllama.py +++ /dev/null @@ -1,135 +0,0 @@ -# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai -# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai -# SPDX-License-Identifier: Apache-2.0 -# Contact: qubitium@modelcloud.ai, x.com/qubitium - -import torch -import torch.nn as nn - -from gptqmodel.quantization.awq.utils.module import try_import -from gptqmodel.quantization.awq.utils.packing_utils import unpack_reorder_pack - - -exl_ext, msg = try_import("gptqmodel_exl_kernels") - -# Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension -none_tensor = torch.empty((1, 1), device="meta") - - -class WQLinear_Exllama(nn.Module): - def __init__(self, w_bit, group_size, in_features, out_features, bias, dev): - super().__init__() - - if w_bit not in [4]: - raise NotImplementedError("Only 4-bit are supported for Exllama kernels") - - self.q4 = None - - self.w_bit = w_bit - self.in_features = in_features - self.out_features = out_features - self.group_size = group_size if group_size != -1 else in_features - - ################################################################################## - ## These shapes are only for compatibility with the state_dict of WQLinear_GEMM ## - self.register_buffer( - "qweight", - torch.zeros( - (in_features, out_features // (32 // self.w_bit)), - dtype=torch.int32, - device=dev, - ), - ) - self.register_buffer( - "qzeros", - torch.zeros( - (in_features // self.group_size, out_features // (32 // 
self.w_bit)), - dtype=torch.int32, - device=dev, - ), - ) - ################################################################################## - - self.register_buffer( - "scales", - torch.zeros( - (in_features // self.group_size, out_features), - dtype=torch.float16, - device=dev, - ), - ) - if bias: - self.register_buffer( - "bias", - torch.zeros( - (out_features), - dtype=torch.float16, - device=dev, - ), - ) - else: - self.bias = None - - def post_init(self): - assert self.qweight.device.type == "cuda" - assert self.qweight.device.index is not None - - self.qweight, self.qzeros = unpack_reorder_pack( - self.qweight, self.qzeros, self.w_bit - ) - self.q4 = exl_ext.make_q4( - self.qweight, - self.qzeros, - self.scales, - none_tensor, # g_idx - self.qweight.device.index, # device index - ) - - @classmethod - def from_linear( - cls, linear, w_bit, group_size, init_only=False, scales=None, zeros=None - ): - awq_linear = cls( - w_bit, - group_size, - linear.in_features, - linear.out_features, - linear.bias is not None, - linear.weight.device, - ) - if init_only: # just prepare for loading sd - return awq_linear - - raise NotImplementedError("Only inference is supported for Exllama kernels") - - def forward(self, x): - assert self.q4 is not None, ( - "module.post_init() must be called before module.forward(). " - "Use exllama_post_init() on the whole model." - ) - if exl_ext is None: - raise ModuleNotFoundError("External ExLlama kernels are not properly installed." + msg) - - input_dtype = x.dtype - out_shape = x.shape[:-1] + (self.out_features,) - - if input_dtype != torch.float16: - x = x.to(dtype=torch.float16) - - x = x.view(-1, x.shape[-1]) - - out = torch.empty( - (x.shape[0], self.out_features), - dtype=torch.float16, - device=x.device, - ) - exl_ext.q4_matmul(x, self.q4, out) - - if input_dtype != torch.float16: - out = out.to(dtype=input_dtype) - - if self.bias is not None: - out.add_(self.bias) - - return out.view(out_shape) - diff --git a/gptqmodel/quantization/awq/modules/linear/exllamav2.py b/gptqmodel/quantization/awq/modules/linear/exllamav2.py deleted file mode 100644 index a529d3876..000000000 --- a/gptqmodel/quantization/awq/modules/linear/exllamav2.py +++ /dev/null @@ -1,160 +0,0 @@ -# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai -# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai -# SPDX-License-Identifier: Apache-2.0 -# Contact: qubitium@modelcloud.ai, x.com/qubitium - - -import torch -import torch.nn as nn - -from gptqmodel.quantization.awq.utils.module import try_import -from gptqmodel.quantization.awq.utils.packing_utils import unpack_reorder_pack - - -exlv2_ext, msg = try_import("gptqmodel_exlv2_kernels") - -# Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension -none_tensor = torch.empty((1, 1), device="meta") - - -class WQLinear_ExllamaV2(nn.Module): - def __init__(self, w_bit, group_size, in_features, out_features, bias, dev): - super().__init__() - - if w_bit not in [4]: - raise NotImplementedError("Only 4-bit are supported for now.") - - self.q_handle = None - - self.w_bit = w_bit - self.in_features = in_features - self.out_features = out_features - self.group_size = group_size if group_size != -1 else in_features - - ################################################################################## - ## These shapes are only for compatibility with the state_dict of WQLinear_GEMM ## - self.register_buffer( - "qweight", - torch.zeros( - (in_features, out_features // (32 // self.w_bit)), - dtype=torch.int32, - 
device=dev, - ), - ) - self.register_buffer( - "qzeros", - torch.zeros( - (in_features // self.group_size, out_features // (32 // self.w_bit)), - dtype=torch.int32, - device=dev, - ), - ) - ################################################################################## - - self.register_buffer( - "scales", - torch.zeros( - (in_features // self.group_size, out_features), - dtype=torch.float16, - device=dev, - ), - ) - if bias: - self.register_buffer( - "bias", - torch.zeros( - (out_features), - dtype=torch.float16, - device=dev, - ), - ) - else: - self.bias = None - - def post_init(self, scratch_space: "ScratchSpace"): - assert self.qweight.device.type == "cuda" - assert self.qweight.device.index is not None - - self.qweight, self.qzeros = unpack_reorder_pack( - self.qweight, self.qzeros, self.w_bit - ) - - temp_dq_size = self.temp_dq_size() - temp_dq = scratch_space.get_slice(temp_dq_size) - self.q_handle = exlv2_ext.make_q_matrix( - self.qweight, - none_tensor, - none_tensor, - none_tensor, - none_tensor, - none_tensor, - self.qzeros, - self.scales, - none_tensor, - temp_dq, - ) - - @classmethod - def from_linear( - cls, linear, w_bit, group_size, init_only=False, scales=None, zeros=None - ): - awq_linear = cls( - w_bit, - group_size, - linear.in_features, - linear.out_features, - linear.bias is not None, - linear.weight.device, - ) - if init_only: # just prepare for loading sd - return awq_linear - - raise NotImplementedError("Only inference is supported for ExllamaV2 kernels") - - def temp_dq_size(self): - """ - Returns the size of the temporary buffer required for the dq kernel. - """ - return self.in_features * self.out_features * 2 + 128 - - def temp_fwd_size(self, max_input_len, max_batch_size): - """ - Returns the size of the temporary buffer required for the fwd kernel. - """ - return self.out_features * max_input_len * max_batch_size * 4 + 128 - - def scratch_space_fixed(self, max_input_len=2048, max_batch_size=8): - """ - Returns the size of the fixed scratch space required for the kernel. - """ - return self.temp_dq_size() + self.temp_fwd_size(max_input_len, max_batch_size) - - def forward(self, x): - assert self.q_handle is not None, ( - "module.post_init() must be called before module.forward(). " - ) - if exlv2_ext is None: - raise ModuleNotFoundError("External ExLlamaV2 kernels are not properly installed." 
+ msg) - - input_dtype = x.dtype - out_shape = x.shape[:-1] + (self.out_features,) - - if input_dtype != torch.float16: - x = x.to(dtype=torch.float16) - - x = x.view(-1, x.shape[-1]) - - out = torch.empty( - (x.shape[0], self.out_features), - dtype=torch.float16, - device=x.device, - ) - exlv2_ext.gemm_half_q_half(x, self.q_handle, out, False) - - if input_dtype != torch.float16: - out = out.to(dtype=input_dtype) - - if self.bias is not None: - out.add_(self.bias) - - return out.view(out_shape) diff --git a/gptqmodel/quantization/awq/modules/linear/marlin.py b/gptqmodel/quantization/awq/modules/linear/marlin.py deleted file mode 100644 index a852c426d..000000000 --- a/gptqmodel/quantization/awq/modules/linear/marlin.py +++ /dev/null @@ -1,237 +0,0 @@ -# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai -# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai -# SPDX-License-Identifier: Apache-2.0 -# Contact: qubitium@modelcloud.ai, x.com/qubitium - -import numpy as np -import torch -import torch.nn as nn - -from gptqmodel.quantization.awq.utils.module import try_import - - -marlin_cuda, msg = try_import("marlin_cuda") - -def _get_perms(): - perm = [] - for i in range(32): - perm1 = [] - col = i // 4 - for block in [0, 1]: - for row in [ - 2 * (i % 4), - 2 * (i % 4) + 1, - 2 * (i % 4 + 4), - 2 * (i % 4 + 4) + 1, - ]: - perm1.append(16 * row + col + 8 * block) - - for j in range(4): - perm.extend([p + 256 * j for p in perm1]) - - perm = np.array(perm) - interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7]) - perm = perm.reshape((-1, 8))[:, interleave].ravel() - perm = torch.from_numpy(perm) - scale_perm = [] - for i in range(8): - scale_perm.extend([i + 8 * j for j in range(8)]) - scale_perm_single = [] - for i in range(4): - scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) - return perm, scale_perm, scale_perm_single - - -_perm, _scale_perm, _scale_perm_single = _get_perms() - - -class WQLinear_Marlin(nn.Module): - def __init__(self, w_bit, group_size, in_features, out_features, bias, dev): - super().__init__() - - if w_bit not in [4]: - raise NotImplementedError("Only 4-bit are supported for now.") - - self.w_bit = w_bit - self.in_features = in_features - self.out_features = out_features - self.group_size = group_size if group_size != -1 else in_features - self.max_par = 8 # partitioning for large inputs - - # quick sanity check (make sure aligment) - assert self.in_features % self.group_size == 0 - assert out_features % (32 // self.w_bit) == 0 - - ###################################################### - ## These shapes are only specific for Marlin models ## - self.register_buffer( - "qweight", - torch.zeros( - (in_features // 16, out_features * 16 // 8), - dtype=torch.int32, - device=dev, - ), - ) - self.register_buffer( - "scales", - torch.zeros( - (in_features // group_size, out_features), - dtype=torch.float16, - device=dev, - ), - ) - ###################################################### - - if bias: - self.register_buffer( - "bias", - torch.zeros( - (out_features), - dtype=torch.float16, - device=dev, - ), - ) - else: - self.bias = None - - @classmethod - def from_linear( - cls, - linear, - w_bit, - group_size, - init_only=False, - scales=None, - zeros=None, - ): - awq_linear = cls( - w_bit, - group_size, - linear.in_features, - linear.out_features, - linear.bias is not None, - linear.weight.device, - ) - if init_only: # just prepare for loading sd - return awq_linear - - assert zeros is None and scales is not None - - tile = 16 - maxq = 2**4 - 1 - s = 
scales.t() - w = linear.weight.data.t() - if awq_linear.group_size != awq_linear.in_features: - w = w.reshape((-1, awq_linear.group_size, awq_linear.out_features)) - w = w.permute(1, 0, 2) - w = w.reshape((awq_linear.group_size, -1)) - s = s.reshape((1, -1)) - w = torch.round(w / s).int() - w += (maxq + 1) // 2 - w = torch.clamp(w, 0, maxq) - if awq_linear.group_size != awq_linear.in_features: - w = w.reshape((awq_linear.group_size, -1, awq_linear.out_features)) - w = w.permute(1, 0, 2) - w = w.reshape( - (awq_linear.in_features, awq_linear.out_features) - ).contiguous() - s = s.reshape((-1, len(_scale_perm)))[:, _scale_perm] - else: - s = s.reshape((-1, len(_scale_perm_single)))[:, _scale_perm_single] - s = s.reshape((-1, awq_linear.out_features)).contiguous() - w = w.reshape( - ( - awq_linear.in_features // tile, - tile, - awq_linear.out_features // tile, - tile, - ) - ) - w = w.permute((0, 2, 1, 3)) - w = w.reshape((awq_linear.in_features // tile, awq_linear.out_features * tile)) - res = w - res = res.reshape((-1, _perm.numel()))[:, _perm].reshape(res.shape) - q = np.zeros((res.shape[0], res.shape[1] // 8), dtype=np.uint32) - res = res.cpu().numpy().astype(np.uint32) - for i in range(8): - q |= res[:, i::8] << 4 * i - q = torch.from_numpy(q.astype(np.int32)).to(w.device) - awq_linear.qweight[:] = q.to(awq_linear.qweight.device) - awq_linear.scales[:] = s.to(awq_linear.qweight.device) - - if awq_linear.bias is not None: - awq_linear.bias[:] = linear.bias.data.to(awq_linear.bias.device) - - return awq_linear - - def post_init(self): - self.register_buffer( - "workspace", - torch.zeros( - self.out_features // 128 * self.max_par, - dtype=torch.int32, - device=self.qweight.device, - ), - persistent=False, - ) - - @torch.inference_mode() - def forward(self, x): - assert hasattr(self, "workspace"), ( - "module.post_init() must be called before module.forward(). " - "Use marlin_post_init() on the whole model." - ) - if marlin_cuda is None: - raise ModuleNotFoundError("External Marlin kernels are not properly installed." 
+ msg) - - out_shape = x.shape[:-1] + (self.out_features,) - - input_dtype = x.dtype - if input_dtype != torch.float16: - x = x.half() - - x = x.view(-1, x.shape[-1]) - - out = torch.empty( - (x.shape[0], self.out_features), - dtype=torch.float16, - device=x.device, - ) - marlin_cuda.mul( - x, - self.qweight, - out, - self.scales, - self.workspace, - -1, # thread_k - -1, # thread_n - -1, # sms - self.max_par, - ) - - if input_dtype != torch.float16: - out = out.to(dtype=input_dtype) - - if self.bias is not None: - out.add_(self.bias) - - return out.view(out_shape) - - def extra_repr(self) -> str: - return ( - "in_features={}, out_features={}, bias={}, w_bit={}, group_size={}".format( - self.in_features, - self.out_features, - self.bias is not None, - self.w_bit, - self.group_size, - ) - ) - - -def marlin_post_init(model): - for _, submodule in model.named_modules(): - if isinstance(submodule, WQLinear_Marlin): - submodule.post_init() - - return model From 64ec4cd1d41c3a4b82dc339f56602f402778050d Mon Sep 17 00:00:00 2001 From: ZX-ModelCloud Date: Tue, 25 Nov 2025 14:32:23 +0800 Subject: [PATCH 24/24] format Signed-off-by: ZX-ModelCloud --- gptqmodel/nn_modules/qlinear/gemm_awq_triton.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gptqmodel/nn_modules/qlinear/gemm_awq_triton.py b/gptqmodel/nn_modules/qlinear/gemm_awq_triton.py index be2fec5d9..afee97421 100644 --- a/gptqmodel/nn_modules/qlinear/gemm_awq_triton.py +++ b/gptqmodel/nn_modules/qlinear/gemm_awq_triton.py @@ -10,9 +10,9 @@ from ...adapter.adapter import Adapter, Lora from ...models._const import DEVICE, PLATFORM from ...nn_modules.qlinear import AWQuantLinear +from ...quantization.awq.modules.triton.gemm import awq_dequantize_triton, awq_gemm_triton from ...utils.backend import BACKEND from . import tritonv2 -from ...quantization.awq.modules.triton.gemm import awq_dequantize_triton, awq_gemm_triton class AwqGemmTritonFn(torch.autograd.Function):
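
For reference, a minimal standalone sketch (illustrative only, not part of these patches) of how the 4-bit GEMM pack() above folds eight quantized integers into a single int32 column using the [0, 2, 4, 6, 1, 3, 5, 7] order map; the toy values and variable names here are assumptions for demonstration, with bits=4 and pack_num=8.

import torch

bits = 4
pack_num = 32 // bits                      # eight 4-bit values per int32
order_map = [0, 2, 4, 6, 1, 3, 5, 7]       # GEMM nibble ordering used above

# two rows of already-quantized integer weights, pack_num values per row
intweight = torch.arange(16, dtype=torch.int32).reshape(2, pack_num)
qweight = torch.zeros((2, intweight.shape[1] // pack_num), dtype=torch.int32)

for col in range(intweight.shape[1] // pack_num):
    for i in range(pack_num):
        qweight[:, col] |= intweight[:, col * pack_num + order_map[i]] << (i * bits)

# row 0 holds 0..7; nibbles land as 0,2,4,6,1,3,5,7 from the LSB -> 0x75316420
print(hex(int(qweight[0, 0])))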