gptqmodel/models/definitions/ovis.py (1 addition, 1 deletion)
@@ -12,9 +12,9 @@
from ...utils.calibration import batched
from ...utils.image import fetch_image
from ...utils.model import MODALITY, move_to
from ...utils.offload import offload_to_disk
from .._const import CPU
from ..base import BaseQModel
from ...utils.offload import offload_to_disk


class OvisQModel(BaseQModel):
gptqmodel/models/definitions/qwen3_moe.py (0 additions, 1 deletion)
@@ -3,7 +3,6 @@
# SPDX-License-Identifier: Apache-2.0
# Contact: qubitium@modelcloud.ai, x.com/qubitium

from ...quantization import METHOD
from ..base import BaseQModel


gptqmodel/nn_modules/qlinear/__init__.py (2 additions, 4 deletions)
@@ -1032,12 +1032,10 @@ def pack_original(self, linear: nn.Module, scales: t.Tensor, zeros: t.Tensor, g_
class AWQuantLinear(BaseQuantLinear):
def __init__(self,
bias: bool = False,
use_bf16: bool = False,
register_buffers: bool = False,
**kwargs):
super().__init__(bias=bias, register_buffers=False, **kwargs)

self.use_bf16 = use_bf16

in_features = self.in_features
out_features = self.out_features
@@ -1058,12 +1056,12 @@ def __init__(self,
"scales",
t.zeros(
(in_features // self.group_size, out_features),
dtype=t.bfloat16 if self.use_bf16 else t.float32,
dtype=t.float16,
),
)

if bias:
self.register_buffer("bias", t.zeros(out_features, dtype=t.bfloat16 if self.use_bf16 else t.float32,))
self.register_buffer("bias", t.zeros(out_features, dtype=t.float16))
else:
self.bias = None

gptqmodel/nn_modules/qlinear/awq_gemm.py (1 addition, 1 deletion)
@@ -31,7 +31,7 @@ class AwqGEMMQuantLinear(AWQuantLinear):
SUPPORTS_PACK_DTYPES = [torch.int32]
SUPPORTS_ADAPTERS = [Lora]

SUPPORTS_DTYPES = [torch.float16, torch.bfloat16]
SUPPORTS_DTYPES = [torch.float16]

REQUIRES_FORMAT_V2 = False

gptqmodel/nn_modules/qlinear/awq_gemv.py (1 addition, 1 deletion)
@@ -34,7 +34,7 @@ class AwqGEMVQuantLinear(AWQuantLinear):
SUPPORTS_PACK_DTYPES = [torch.int32]
SUPPORTS_ADAPTERS = [Lora]

SUPPORTS_DTYPES = [torch.float16, torch.bfloat16]
SUPPORTS_DTYPES = [torch.float16]

# for transformers/optimum tests compat
QUANT_TYPE = "awq_gemv"
gptqmodel/nn_modules/qlinear/awq_marlin.py (1 addition, 1 deletion)
@@ -53,7 +53,7 @@ class AwqMarlinQuantLinear(AWQuantLinear):
SUPPORTS_PACK_DTYPES = [torch.int32]
SUPPORTS_ADAPTERS = [Lora]

SUPPORTS_DTYPES = [torch.float16, torch.bfloat16]
SUPPORTS_DTYPES = [torch.float16]

REQUIRES_FORMAT_V2 = False

gptqmodel/nn_modules/qlinear/awq_torch.py (new file, 101 additions)
@@ -0,0 +1,101 @@
# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai
# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai
# SPDX-License-Identifier: Apache-2.0
# Contact: qubitium@modelcloud.ai, x.com/qubitium

import torch

from ...adapter.adapter import Adapter, Lora
from ...models._const import DEVICE, PLATFORM
from ...quantization.awq.utils.packing_utils import dequantize_gemm
from ...utils.backend import BACKEND
from ...utils.logger import setup_logger
from . import AWQuantLinear


log = setup_logger()


class AwqTorchQuantLinear(AWQuantLinear):
SUPPORTS_BITS = [4]
SUPPORTS_GROUP_SIZE = [-1, 16, 32, 64, 128]
SUPPORTS_DESC_ACT = [True, False]
SUPPORTS_SYM = [True, False]
SUPPORTS_SHARDS = True
SUPPORTS_TRAINING = True
SUPPORTS_AUTO_PADDING = False
SUPPORTS_IN_FEATURES_DIVISIBLE_BY = [1]
SUPPORTS_OUT_FEATURES_DIVISIBLE_BY = [1]

SUPPORTS_DEVICES = [DEVICE.ALL]
SUPPORTS_PLATFORM = [PLATFORM.ALL]
SUPPORTS_PACK_DTYPES = [torch.int32]
SUPPORTS_ADAPTERS = [Lora]

SUPPORTS_DTYPES = [torch.float16]

REQUIRES_FORMAT_V2 = False

QUANT_TYPE = "awq_torch"

def __init__(
self,
bits: int,
group_size: int,
sym: bool,
desc_act: bool,
in_features: int,
out_features: int,
bias: bool = False,
pack_dtype: torch.dtype = torch.int32,
adapter: Adapter = None,
register_buffers: bool = False,
**kwargs,
):
super().__init__(
bits=bits,
group_size=group_size,
sym=sym,
desc_act=desc_act,
in_features=in_features,
out_features=out_features,
bias=bias,
pack_dtype=pack_dtype,
backend=kwargs.pop("backend", BACKEND.TORCH_AWQ),
adapter=adapter,
register_buffers=register_buffers,
**kwargs,
)

def post_init(self):
super().post_init()

def extra_repr(self) -> str:
return (
f"in_features={self.in_features}, out_features={self.out_features}, "
f"bias={self.bias is not None}, bits={self.bits}, group_size={self.group_size}"
)

def forward(self, x: torch.Tensor):
original_shape = x.shape[:-1] + (self.out_features,)
device = x.device
x_flat = x.reshape(-1, x.shape[-1])

weight = dequantize_gemm(self.qweight, self.qzeros, self.scales, self.bits, self.group_size)
assert weight.dtype == torch.float16, f"weight {weight.dtype} is not float16"
if weight.dtype != x_flat.dtype or weight.device != device:
weight = weight.to(device=device, dtype=x_flat.dtype)

output = torch.matmul(x_flat, weight)

if self.bias is not None:
output = output + self.bias

if self.adapter:
output = self.adapter.apply(x=x_flat, out=output)

output = output.reshape(original_shape)

return output

__all__ = ["AwqTorchQuantLinear"]
gptqmodel/utils/backend.py (1 addition, 0 deletions)
@@ -29,6 +29,7 @@ class BACKEND(str, Enum):
GEMM = "gemm"
GEMV = "gemv"
GEMV_FAST = "gemv_fast"
TORCH_AWQ = "torch_awq"

# external
VLLM = "vllm" # External inference engine: CUDA + ROCm + IPEX
gptqmodel/utils/importer.py (5 additions, 1 deletion)
@@ -20,6 +20,7 @@
from ..nn_modules.qlinear.awq_gemv_fast import AwqGEMVFastQuantLinear
from ..nn_modules.qlinear.awq_machete import AwqMacheteQuantLinear
from ..nn_modules.qlinear.awq_marlin import AwqMarlinQuantLinear
from ..nn_modules.qlinear.awq_torch import AwqTorchQuantLinear
from ..nn_modules.qlinear.bitblas import BitBLASQuantLinear
from ..nn_modules.qlinear.exllama import ExllamaQuantLinear
from ..nn_modules.qlinear.exllama_eora import ExllamaEoraQuantLinear
@@ -69,6 +70,7 @@
BACKEND.GEMM: AwqGEMMQuantLinear,
BACKEND.GEMV: AwqGEMVQuantLinear,
BACKEND.GEMV_FAST: AwqGEMVFastQuantLinear,
BACKEND.TORCH_AWQ: AwqTorchQuantLinear,
}),
}

@@ -83,7 +85,7 @@
FORMAT.QQQ: [BACKEND.QQQ],
},
METHOD.AWQ: {
FORMAT.GEMM: [BACKEND.MACHETE, BACKEND.MARLIN, BACKEND.EXLLAMA_V2, BACKEND.EXLLAMA_V1, BACKEND.GEMM],
FORMAT.GEMM: [BACKEND.MACHETE, BACKEND.MARLIN, BACKEND.EXLLAMA_V2, BACKEND.EXLLAMA_V1, BACKEND.GEMM, BACKEND.TORCH_AWQ],
FORMAT.GEMV: [BACKEND.GEMV],
FORMAT.GEMV_FAST: [BACKEND.GEMV_FAST],
FORMAT.MARLIN: [BACKEND.MACHETE, BACKEND.MARLIN],
@@ -314,6 +316,8 @@ def select_quant_linear(
qlinear = AwqGEMVQuantLinear
elif backend == BACKEND.GEMV_FAST:
qlinear = AwqGEMVFastQuantLinear
elif backend == BACKEND.TORCH_AWQ:
qlinear = AwqTorchQuantLinear
elif backend == BACKEND.TORCH:
qlinear = TorchQuantLinear
elif backend == BACKEND.TORCH_FUSED:
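
With this mapping, BACKEND.TORCH_AWQ resolves to AwqTorchQuantLinear and sits at the end of the FORMAT.GEMM candidate list, after the faster MACHETE/MARLIN/EXLLAMA/GEMM kernels, so it acts as a portable fallback. A hedged sketch of forcing it at load time; that GPTQModel.load accepts a backend keyword is an assumption here, and the checkpoint path is hypothetical:

```python
from gptqmodel import GPTQModel
from gptqmodel.utils.backend import BACKEND

# Hypothetical AWQ gemm-format checkpoint path, for illustration only.
model = GPTQModel.load(
    "path/to/awq-gemm-quantized-model",
    backend=BACKEND.TORCH_AWQ,  # assumed keyword: force the pure-torch AWQ fallback
)
```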
tests/test_awq_torch_kernel.py (new file, 113 additions)
@@ -0,0 +1,113 @@
# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai
# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai
# SPDX-License-Identifier: Apache-2.0
# Contact: qubitium@modelcloud.ai, x.com/qubitium

import pytest
import torch

from gptqmodel.nn_modules.qlinear.awq_torch import AwqTorchQuantLinear
from gptqmodel.quantization import FORMAT, METHOD
from gptqmodel.quantization.awq.utils.packing_utils import dequantize_gemm
from gptqmodel.utils.backend import BACKEND
from gptqmodel.utils.importer import select_quant_linear


def _pack_awq_tensor(unpacked: torch.Tensor, bits: int) -> torch.Tensor:
pack_factor = 32 // bits
order_map = [0, 2, 4, 6, 1, 3, 5, 7]

assert unpacked.shape[1] % pack_factor == 0
packed = torch.zeros(
(unpacked.shape[0], unpacked.shape[1] // pack_factor),
dtype=torch.int32,
)
for col in range(unpacked.shape[1] // pack_factor):
for i, order in enumerate(order_map):
value = unpacked[:, col * pack_factor + order].to(torch.int32)
packed[:, col] |= value << (i * bits)
return packed


@pytest.mark.parametrize("dtype", [torch.float16])
def test_awq_torch_matches_manual_dequant(dtype):
if dtype not in AwqTorchQuantLinear.SUPPORTS_DTYPES:
pytest.skip(f"dtype {dtype} not supported by AwqTorchQuantLinear")
torch.manual_seed(0)

bits = 4
in_features = 32
out_features = 64
group_size = 16

assert out_features % (32 // bits) == 0
assert in_features % group_size == 0

groups = in_features // group_size
pack_cols = out_features

int_weight = torch.randint(0, 2**bits, size=(in_features, out_features), dtype=torch.int32)
zero_points = torch.randint(0, 2**bits, size=(groups, pack_cols), dtype=torch.int32)
scales = (torch.rand(groups, pack_cols, dtype=torch.float16) * 2.0) + 0.25
bias = torch.randn(out_features, dtype=torch.float16)

qweight = _pack_awq_tensor(int_weight, bits)
qzeros = _pack_awq_tensor(zero_points, bits)

module = AwqTorchQuantLinear(
bits=bits,
group_size=group_size,
sym=True,
desc_act=False,
in_features=in_features,
out_features=out_features,
bias=True,
register_buffers=True,
)

module.qweight.copy_(qweight)
module.qzeros.copy_(qzeros)
module.scales = module.scales.to(dtype=torch.float16)
module.scales.copy_(scales.to(torch.float16))
module.bias.copy_(bias)
module.post_init()
module.eval()

batch = 4
x = torch.randn(batch, in_features, dtype=dtype)

bias_expected = module.bias

dequant_weight = dequantize_gemm(
qweight=module.qweight,
qzeros=module.qzeros,
scales=module.scales,
bits=bits,
group_size=group_size,
).to(dtype=dtype)

expected = torch.matmul(x.to(dtype), dequant_weight)
expected = expected + bias_expected

output_first = module(x)
output_second = module(x)

atol = 1e-4 if dtype == torch.float32 else 5e-3
rtol = 1e-4 if dtype == torch.float32 else 5e-3
torch.testing.assert_close(output_first, expected, atol=atol, rtol=rtol)
torch.testing.assert_close(output_second, expected, atol=atol, rtol=rtol)


def test_awq_torch_backend_selection():
qlinear_cls = select_quant_linear(
bits=4,
group_size=128,
desc_act=False,
sym=True,
device=None,
backend=BACKEND.TORCH_AWQ,
format=FORMAT.GEMM,
quant_method=METHOD.AWQ,
pack_dtype=torch.int32,
)
assert qlinear_cls is AwqTorchQuantLinear
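
As a side note, the interleaved nibble order used by the test's `_pack_awq_tensor` helper can be checked in isolation. A small round-trip sketch follows; the unpacking loop is an illustrative inverse, not code from this PR.

```python
import torch

# Round-trip check of the AWQ 4-bit interleave used by _pack_awq_tensor:
# eight nibbles share one int32, stored in the order [0, 2, 4, 6, 1, 3, 5, 7].
bits = 4
order_map = [0, 2, 4, 6, 1, 3, 5, 7]

vals = torch.randint(0, 2**bits, (3, 8), dtype=torch.int32)  # 3 rows, one packed column each

packed = torch.zeros((3, 1), dtype=torch.int32)
for i, order in enumerate(order_map):
    packed[:, 0] |= vals[:, order] << (i * bits)

# Undo the interleave: nibble i of each int32 holds logical position order_map[i].
unpacked = torch.zeros_like(vals)
for i, order in enumerate(order_map):
    unpacked[:, order] = (packed[:, 0] >> (i * bits)) & (2**bits - 1)

assert torch.equal(unpacked, vals)
```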