18 changes: 9 additions & 9 deletions gptqmodel/looper/awq_processor.py
@@ -28,6 +28,7 @@
from ..quantization.awq.utils.utils import get_best_device
from ..quantization.config import FORMAT, METHOD, QuantizeConfig
from ..utils.logger import setup_logger
from ..utils.ctx import ctx
from ..utils.model import get_module_by_name_prefix, move_to
from ..utils.torch import CPU, tf32_disable_guard, tf32_enable_guard, torch_sync

@@ -295,17 +296,16 @@ def _search_best_scale(

# [STEP 3]: Compute output of module
module_kwargs = self._sanitize_kwargs(kwargs, module2inspect)
with torch.inference_mode():
with tf32_enable_guard():
fp16_output = self._module_forward(inp, module2inspect, module_kwargs)
with ctx(torch.inference_mode(), tf32_enable_guard()):
fp16_output = self._module_forward(inp, module2inspect, module_kwargs)

with tf32_disable_guard():
fp16_output = fp16_output.clip(torch.finfo(fp16_output.dtype).min, torch.finfo(fp16_output.dtype).max)
with tf32_disable_guard():
fp16_output = fp16_output.clip(torch.finfo(fp16_output.dtype).min, torch.finfo(fp16_output.dtype).max)

# [STEP 4]: Compute loss
best_scales, loss = self._compute_best_scale(
inp, w_mean, x_mean, module2inspect, layers, fp16_output, module_kwargs
)
# [STEP 4]: Compute loss
best_scales, loss = self._compute_best_scale(
inp, w_mean, x_mean, module2inspect, layers, fp16_output, module_kwargs
)

return (
get_op_name(module, prev_op),
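Note: the hunk above collapses the nested torch.inference_mode() / tf32_enable_guard() blocks into a single ctx(...) call imported from ..utils.ctx. That helper's implementation is not part of this diff; a minimal sketch of what it could look like, assuming it simply enters the given context managers in order (the real gptqmodel/utils/ctx.py may differ):

from contextlib import ExitStack, contextmanager

@contextmanager
def ctx(*managers):
    # Enter the supplied context managers left to right; ExitStack unwinds them in reverse on exit.
    with ExitStack() as stack:
        for manager in managers:
            stack.enter_context(manager)
        yield

Under that assumption, with ctx(torch.inference_mode(), tf32_enable_guard()): behaves exactly like the two nested with statements it replaces.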
24 changes: 17 additions & 7 deletions gptqmodel/looper/module_looper.py
@@ -17,6 +17,7 @@

import threading
import time
from contextlib import nullcontext
from typing import Dict, List, Optional

import torch
@@ -32,6 +33,7 @@
from ..nn_modules.hooked_linear import (STOP_FORWARD_EXCEPTION, HookedLinear,
StopForward, replace_module_with_hooked_legacy)
from ..utils.attn_mask import apply_keep_mask_bt, normalize_seq_mask
from ..utils.ctx import ctx
from ..utils.device import get_device, get_device_new
from ..utils.logger import setup_logger
from ..utils.looper_helpers import (
@@ -564,12 +566,14 @@ def store_input_hook(module, args, kwargs):
example["attention_mask"] = example["attention_mask"].long()

# Ensure initial caches (like RoPE) are created on the quant device
with self.pool.read_lock(self.gptq_model.quantize_config.device):
with device_ctx(self.gptq_model.quantize_config.device):
if self.gptq_model.INPUT_EMBEDDING_EXTRA_ARGS:
self.gptq_model.model.generate(**example, **self.gptq_model.INPUT_EMBEDDING_EXTRA_ARGS)
else:
self.gptq_model.model(**example, use_cache=use_cache)
with ctx(
self.pool.read_lock(self.gptq_model.quantize_config.device),
device_ctx(self.gptq_model.quantize_config.device),
):
if self.gptq_model.INPUT_EMBEDDING_EXTRA_ARGS:
self.gptq_model.model.generate(**example, **self.gptq_model.INPUT_EMBEDDING_EXTRA_ARGS)
else:
self.gptq_model.model(**example, use_cache=use_cache)
except StopForward:
pass

@@ -711,7 +715,13 @@ def loop(self, fail_safe: bool = False, **kwargs):
processor=processor,
fail_safe=fail_safe)
named_childs.update(named_modules)
processor.layer_quantize(module, cur_layer_device, named_childs)

lock_ctx = nullcontext()
device_for_ctx = cur_layer_device if getattr(cur_layer_device, 'type', None) != 'meta' else None
if device_for_ctx is not None:
lock_ctx = self.pool.read_lock(cur_layer_device)
with ctx(lock_ctx, device_ctx(device_for_ctx)):
processor.layer_quantize(module, cur_layer_device, named_childs)
continue

layer_inputs = processor.inputs_cache.layer_inputs
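Note: in the loop() hunk above, the pool read lock is only taken when the layer sits on a real device; a layer on the meta device gets a nullcontext() placeholder, so the same with ctx(...) statement covers both cases (device_ctx is presumably also a no-op when handed None, though that helper is not shown here). A self-contained illustration of the optional-guard pattern, with a plain threading lock standing in for the looper's device pool:

import threading
from contextlib import nullcontext

lock = threading.Lock()

def run(task, needs_lock: bool):
    # Take the real lock only when required; nullcontext() is a do-nothing stand-in otherwise.
    guard = lock if needs_lock else nullcontext()
    with guard:
        return task()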
3 changes: 1 addition & 2 deletions gptqmodel/nn_modules/qlinear/__init__.py
@@ -789,7 +789,6 @@ class AWQuantLinear(BaseQuantLinear):
def __init__(self,
bias: bool = False,
use_bf16: bool = False,
register_awq_buffers: bool = True,
register_buffers: bool = False,
**kwargs):
super().__init__(bias=bias, register_buffers=False, **kwargs)
@@ -799,7 +798,7 @@ def __init__(self,
in_features = self.in_features
out_features = self.out_features

if register_awq_buffers:
if register_buffers:
self.register_buffer(
"qweight",
t.zeros((in_features, out_features // (self.pack_dtype_bits // self.bits)), dtype=self.pack_dtype),
5 changes: 4 additions & 1 deletion gptqmodel/nn_modules/qlinear/awq_gemm.py
@@ -49,6 +49,7 @@ def __init__(
bias: bool = False,
pack_dtype: torch.dtype = torch.int32,
adapter: Adapter = None,
register_buffers: bool = False,
**kwargs,
):
super().__init__(
@@ -62,6 +63,7 @@
pack_dtype=pack_dtype,
backend=kwargs.pop("backend", BACKEND.GEMM),
adapter=adapter,
register_buffers=register_buffers,
**kwargs)

def post_init(self):
@@ -76,7 +78,8 @@ def post_init(self):
# device=self.g_idx.device)

# awq only accepts float16
self.scales = self.scales.to(dtype=torch.float16)
if self.scales is not None:
self.scales = self.scales.to(dtype=torch.float16)

super().post_init()

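Note: with buffer registration now opt-in, self.scales may still be None when post_init() runs (for instance on a skeleton module whose packed tensors have not been attached yet), so the float16 cast above is guarded. A standalone sketch of that guard, illustrative only:

import torch

def coerce_scales(scales):
    # scales can legitimately be None when the module was built without registered buffers.
    return scales.to(dtype=torch.float16) if scales is not None else None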
55 changes: 29 additions & 26 deletions gptqmodel/nn_modules/qlinear/awq_gemv.py
@@ -50,8 +50,10 @@ def __init__(
bias: bool = False,
pack_dtype: torch.dtype = torch.int32,
adapter: Adapter = None,
register_buffers: bool = False,
**kwargs,
):
backend = kwargs.pop("backend", BACKEND.GEMV)
super().__init__(
bits=bits,
group_size=group_size,
@@ -61,38 +63,39 @@
out_features=out_features,
bias=bias,
pack_dtype=pack_dtype,
backend=kwargs.pop("backend", BACKEND.GEMV),
backend=backend,
adapter=adapter,
register_awq_buffers=False,
register_buffers=False,
**kwargs)

self.split_k_iters = 8

self.register_buffer(
"qweight",
torch.zeros((out_features, in_features // self.pack_factor), dtype=self.pack_dtype),
)
self.register_buffer(
"qzeros",
torch.zeros(
out_features,
calculate_zeros_width(in_features, self.group_size),
dtype=self.pack_dtype,
),
)
self.register_buffer(
"scales",
torch.zeros(
out_features,
calculate_zeros_width(in_features, self.group_size) * self.pack_factor,
dtype=torch.float16,
),
)
self.bias = None

if bias:
self.register_buffer("bias", torch.zeros(out_features, dtype=torch.float16))
else:
self.bias = None
if register_buffers:
self.register_buffer(
"qweight",
torch.zeros((out_features, in_features // self.pack_factor), dtype=self.pack_dtype),
)
self.register_buffer(
"qzeros",
torch.zeros(
out_features,
calculate_zeros_width(in_features, self.group_size),
dtype=self.pack_dtype,
),
)
self.register_buffer(
"scales",
torch.zeros(
out_features,
calculate_zeros_width(in_features, self.group_size) * self.pack_factor,
dtype=torch.float16,
),
)

if bias:
self.register_buffer("bias", torch.zeros(out_features, dtype=torch.float16))

def post_init(self):
# if self.padded_infeatures != self.in_features:
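Note: together with the AWQuantLinear change above, the GEMV, GEMV-fast and Marlin kernels now pass register_buffers=False to the base class and allocate their own, layout-specific tensors only when the caller asks for them. A hypothetical, stripped-down illustration of that pattern (class names and shapes are invented for the example, not the real GPTQModel modules):

import torch

class BaseQLinear(torch.nn.Module):
    def __init__(self, in_features, out_features, register_buffers=False):
        super().__init__()
        self.in_features, self.out_features = in_features, out_features
        if register_buffers:
            # Default packed layout: one int32 packs eight 4-bit columns.
            self.register_buffer(
                "qweight",
                torch.zeros(in_features, out_features // 8, dtype=torch.int32),
            )

class GemvQLinear(BaseQLinear):
    def __init__(self, in_features, out_features, register_buffers=False):
        # Always suppress the base layout; this kernel packs along the other axis.
        super().__init__(in_features, out_features, register_buffers=False)
        if register_buffers:
            self.register_buffer(
                "qweight",
                torch.zeros(out_features, in_features // 8, dtype=torch.int32),
            )

Deferring registration this way presumably keeps skeleton modules cheap to construct when the packed tensors will be attached later during packing or weight loading.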
55 changes: 29 additions & 26 deletions gptqmodel/nn_modules/qlinear/awq_gemv_fast.py
@@ -50,8 +50,10 @@ def __init__(
bias: bool = False,
pack_dtype: torch.dtype = torch.int32,
adapter: Adapter = None,
register_buffers: bool = False,
**kwargs,
):
backend = kwargs.pop("backend", BACKEND.GEMV_FAST)
super().__init__(
bits=bits,
group_size=group_size,
@@ -61,41 +63,42 @@
out_features=out_features,
bias=bias,
pack_dtype=pack_dtype,
backend=kwargs.pop("backend", BACKEND.GEMV_FAST),
backend=backend,
adapter=adapter,
register_awq_buffers=False,
register_buffers=False,
**kwargs)

self.split_k_iters = 8
self.interleave = 4

int32_pack_factor = 32 // self.bits

self.register_buffer(
"qweight",
torch.zeros((out_features // self.interleave, in_features // self.pack_factor * self.interleave), dtype=self.pack_dtype),
)
self.register_buffer(
"qzeros",
torch.zeros(
calculate_zeros_width(in_features, self.group_size) * int32_pack_factor,
out_features,
dtype=torch.float16,
),
)
self.register_buffer(
"scales",
torch.zeros(
calculate_zeros_width(in_features, self.group_size) * int32_pack_factor,
out_features,
dtype=torch.float16,
),
)
self.bias = None

if bias:
self.register_buffer("bias", torch.zeros(out_features, dtype=torch.float16))
else:
self.bias = None
if register_buffers:
self.register_buffer(
"qweight",
torch.zeros((out_features // self.interleave, in_features // self.pack_factor * self.interleave), dtype=self.pack_dtype),
)
self.register_buffer(
"qzeros",
torch.zeros(
calculate_zeros_width(in_features, self.group_size) * int32_pack_factor,
out_features,
dtype=torch.float16,
),
)
self.register_buffer(
"scales",
torch.zeros(
calculate_zeros_width(in_features, self.group_size) * int32_pack_factor,
out_features,
dtype=torch.float16,
),
)

if bias:
self.register_buffer("bias", torch.zeros(out_features, dtype=torch.float16))

def post_init(self):
# if self.padded_infeatures != self.in_features:
84 changes: 43 additions & 41 deletions gptqmodel/nn_modules/qlinear/awq_marlin.py
@@ -76,6 +76,7 @@ def __init__(
bias: bool = False,
pack_dtype: torch.dtype = torch.int32,
adapter: Adapter = None,
register_buffers=False,
**kwargs):
if marlin_import_exception is not None:
raise ValueError(
@@ -95,54 +96,55 @@
pack_dtype=pack_dtype,
backend=kwargs.pop("backend", BACKEND.MARLIN),
adapter=adapter,
register_awq_buffers=False,
register_buffers=False,
**kwargs)

self.register_parameter(
"qweight",
torch.nn.Parameter(
torch.empty(
self.in_features,
self.out_features // self.pack_factor,
dtype=torch.int32,
if register_buffers:
self.register_parameter(
"qweight",
torch.nn.Parameter(
torch.empty(
self.in_features,
self.out_features // self.pack_factor,
dtype=torch.int32,
),
requires_grad=False
),
requires_grad=False
),
)
self.register_parameter(
"qzeros",
torch.nn.Parameter(
torch.empty(
self.in_features // self.group_size,
self.out_features // self.pack_factor,
dtype=torch.int32,
),
requires_grad=False
)
)

self.register_parameter(
"scales",
torch.nn.Parameter(
torch.empty(
self.in_features // self.group_size,
self.out_features,
dtype=torch.float16,
),
requires_grad=False
self.register_parameter(
"qzeros",
torch.nn.Parameter(
torch.empty(
self.in_features // self.group_size,
self.out_features // self.pack_factor,
dtype=torch.int32,
),
requires_grad=False
)
)
)

if bias:
self.register_buffer(
"bias",
torch.zeros(
(out_features),
dtype=torch.float16,
),
self.register_parameter(
"scales",
torch.nn.Parameter(
torch.empty(
self.in_features // self.group_size,
self.out_features,
dtype=torch.float16,
),
requires_grad=False
)
)
else:
self.bias = None

if bias:
self.register_buffer(
"bias",
torch.zeros(
(out_features),
dtype=torch.float16,
),
)
else:
self.bias = None

self.is_lm_head = False
if kwargs.get("name") is not None and kwargs.get("lm_head_name") is not None: