diff --git a/gptqmodel/looper/qqq_processor.py b/gptqmodel/looper/qqq_processor.py
index 1d125b8d8..9f7220a49 100644
--- a/gptqmodel/looper/qqq_processor.py
+++ b/gptqmodel/looper/qqq_processor.py
@@ -75,9 +75,21 @@ def preprocess(self, module: NamedModule):
         tmp = QQQ(module=module, qcfg=qcfg_clone)
+        if self.qcfg.mse > 0.0:
+            mse = True
+            norm = self.qcfg.mse
+        else:
+            mse = False
+            norm = 100
         tmp.quantizer.configure(
+            self.qcfg.bits,
             perchannel=True,
+            sym=self.qcfg.sym,
+            mse=mse,
+            norm=norm,
+            groupsize=self.qcfg.group_size,
         )
+
         self.tasks[module.name] = tmp
 
     def is_skipped(self, module: NamedModule) -> bool:
diff --git a/gptqmodel/quantization/qqq.py b/gptqmodel/quantization/qqq.py
index 1c5c33af7..6e6bb8cad 100644
--- a/gptqmodel/quantization/qqq.py
+++ b/gptqmodel/quantization/qqq.py
@@ -1,35 +1,457 @@
-from copy import deepcopy
+import math
+import time
+from typing import Optional
 
 import torch
+import transformers
+from torch import nn
 
-from ..quantization import GPTQ, Quantizer
-from ..quantization.quantizer import QQQQuantizer
+from .gptq import get_number_of_rows_and_cols
+from .. import QuantizeConfig
+from ..looper.named_module import NamedModule
+from ..quantization.quantizer import HF_OPTIMUM
+from ..utils import setup_logger
+from ..utils.torch import device_next
 
+DEBUG = False
 
-class QQQ(GPTQ):
-    def create_quantizer(self, name: str) -> Quantizer:
-        return QQQQuantizer(self.qcfg, name=name)
+torch.backends.cuda.matmul.allow_tf32 = False
+torch.backends.cudnn.allow_tf32 = False
+
+log = setup_logger()
+
+
+def quantize(x, scale, zero, maxq, sym, groupsize):
+    if maxq < 0:
+        return (x > scale / 2).float() * scale + (x < zero / 2).float() * zero
+    if groupsize != -1 or not sym:
+        q = torch.clamp(torch.round(x / scale) + zero, 0, maxq)
+        return scale * (q - zero)
+    else:
+        q = torch.clamp(torch.round(x / scale), -maxq, maxq)
+        return scale * q
+
+
+class Quantizer(nn.Module):
+    def __init__(self, shape=1):
+        super(Quantizer, self).__init__()
+        self.register_buffer("maxq", torch.tensor(0))
+        self.register_buffer("scale", torch.zeros(shape))
+        self.register_buffer("zero", torch.zeros(shape))
+
+    def configure(
+        self,
+        bits,
+        perchannel=False,
+        sym=True,
+        groupsize=-1,
+        mse=False,
+        norm=2.4,
+        grid=100,
+        maxshrink=0.8,
+        trits=False,
+    ):
+        if groupsize != -1 or not sym:
+            self.maxq = torch.tensor(2 ** bits - 1)
+        else:
+            self.maxq = torch.tensor(2 ** (bits - 1) - 1)
+        self.perchannel = perchannel
+        self.groupsize = groupsize
+        self.sym = sym
+        self.mse = mse
+        self.norm = norm
+        self.grid = grid
+        self.maxshrink = maxshrink
+        if trits:
+            self.maxq = torch.tensor(-1)
+
+    def find_params(self, x, weight=False):
+        dev = x.device
+        self.maxq = self.maxq.to(dev)
+
+        shape = x.shape
+        if self.perchannel:
+            if weight:
+                x = x.flatten(1)
+            else:
+                if len(shape) == 4:
+                    x = x.permute([1, 0, 2, 3])
+                    x = x.flatten(1)
+                if len(shape) == 3:
+                    x = x.reshape((-1, shape[-1])).t()
+                if len(shape) == 2:
+                    x = x.t()
+        else:
+            x = x.flatten().unsqueeze(0)
+
+        tmp = torch.zeros(x.shape[0], device=dev)
+        xmin = torch.minimum(x.min(1)[0], tmp)
+        xmax = torch.maximum(x.max(1)[0], tmp)
+
+        if self.sym:
+            xmax = torch.maximum(torch.abs(xmin), xmax)
+            tmp = xmin < 0
+            if torch.any(tmp):
+                xmin[tmp] = -xmax[tmp]
+        tmp = (xmin == 0) & (xmax == 0)
+        xmin[tmp] = -1
+        xmax[tmp] = +1
+
+        if self.maxq < 0:
+            self.scale = xmax
+            self.zero = xmin
+        else:
+            if self.groupsize != -1 or not self.sym:
+                self.scale = (xmax - xmin) / self.maxq
+                if self.sym:
+                    self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2)
+                else:
+                    self.zero = torch.round(-xmin / self.scale)
+            else:
+                self.scale = xmax / self.maxq
+                self.zero = torch.zeros_like(self.scale)
+
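+        # optional MSE search: shrink the clipping range step by step and keep,
+        # per channel, the (scale, zero) pair with the lowest p-norm error
+        # (p = self.norm; grid and maxshrink bound the search)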
+        if self.mse:
+            best = torch.full([x.shape[0]], float("inf"), device=dev)
+            for i in range(int(self.maxshrink * self.grid)):
+                p = 1 - i / self.grid
+                xmin1 = p * xmin
+                xmax1 = p * xmax
+                scale1 = (
+                    (xmax1 - xmin1) / self.maxq
+                    if (self.groupsize != -1 or not self.sym)
+                    else xmax1 / self.maxq
+                )
+                zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero
+                q = quantize(
+                    x,
+                    scale1.unsqueeze(1),
+                    zero1.unsqueeze(1),
+                    self.maxq,
+                    self.sym,
+                    self.groupsize,
+                )
+                q -= x
+                q.abs_()
+                q.pow_(self.norm)
+                err = torch.sum(q, 1)
+                tmp = err < best
+                if torch.any(tmp):
+                    best[tmp] = err[tmp]
+                    self.scale[tmp] = scale1[tmp]
+                    self.zero[tmp] = zero1[tmp]
+        if not self.perchannel:
+            if weight:
+                tmp = shape[0]
+            else:
+                tmp = shape[1] if len(shape) != 3 else shape[2]
+            self.scale = self.scale.repeat(tmp)
+            self.zero = self.zero.repeat(tmp)
+
+        if weight:
+            shape = [-1] + [1] * (len(shape) - 1)
+            self.scale = self.scale.reshape(shape)
+            self.zero = self.zero.reshape(shape)
+            return
+        if len(shape) == 4:
+            self.scale = self.scale.reshape((1, -1, 1, 1))
+            self.zero = self.zero.reshape((1, -1, 1, 1))
+        if len(shape) == 3:
+            self.scale = self.scale.reshape((1, 1, -1))
+            self.zero = self.zero.reshape((1, 1, -1))
+        if len(shape) == 2:
+            self.scale = self.scale.unsqueeze(0)
+            self.zero = self.zero.unsqueeze(0)
+
+    def quantize(self, x):
+        if self.ready():
+            return quantize(
+                x, self.scale, self.zero, self.maxq, self.sym, self.groupsize
+            )
+        return x
+
+    def enabled(self):
+        return self.maxq > 0
+
+    def ready(self):
+        return torch.all(self.scale != 0)
+
+
+class QQQ:
+    def __init__(self, module: nn.Module, qcfg: Optional[QuantizeConfig] = None):
+        # self.lock = threading.Lock()
+
+        # self.num_tied_handles = 0
+        # if qcfg.tied_gptq_handle is not None:
+        #     qcfg.tied_gptq_handle.num_tied_handles += 1
+
+        # Flags indicating issues
+        # self.issue_zero_samples = False
+        # self.issue_nan_hessian = False
+        # self.issue_non_invertible = False
+
+        # self.W = module.weight
+        self.rows, self.columns = get_number_of_rows_and_cols(module)
+        if isinstance(module, NamedModule):
+            self.layer = module.module
+            self.name = module.name
+        else:
+            self.name = HF_OPTIMUM
+            self.layer = module
+            # emulate NamedModule properties
+            self.layer.target_device, self.layer.target_device_stream = device_next()
+        self.dev = self.layer.weight.device
+
+        self._validate_module(self.layer)
+
+        self.qcfg = qcfg if qcfg else QuantizeConfig()  # HF compat will not pass qcfg
+
+        self.module_copy = None
+
+        self.nsamples = 0
+
+        self.quantizer = Quantizer()
+
+        # fwd counter
+        self.fwd_counter = 0
+
+        self.fail_safe = False
+
+        self.H = torch.zeros((self.columns, self.columns),
+                             dtype=torch.float32,
+                             device=self.dev)
+
+    @staticmethod
+    def _validate_module(module):
+        assert isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d,
+                                   transformers.Conv1D)), f"We support only linear and convolutional layers. actual = `{module}`"
+
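+    # folds each calibration batch into the running Hessian
+    # H = (2 / nsamples) * sum(x x^T): the old H is rescaled by n / (n + batch)
+    # and the inputs are pre-scaled by sqrt(2 / nsamples), so the constant is
+    # absorbed into the matmul below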
+    def add_batch(self, inp, out):
+        if DEBUG:
+            self.inp1 = inp
+            self.out1 = out
+        if len(inp.shape) == 2:
+            inp = inp.unsqueeze(0)
+        tmp = inp.shape[0]
+        if isinstance(self.layer, nn.Linear) or isinstance(
+            self.layer, transformers.Conv1D
+        ):
+            if len(inp.shape) == 3:
+                inp = inp.reshape((-1, inp.shape[-1]))
+            inp = inp.t()
+        if isinstance(self.layer, nn.Conv2d):
+            unfold = nn.Unfold(
+                self.layer.kernel_size,
+                dilation=self.layer.dilation,
+                padding=self.layer.padding,
+                stride=self.layer.stride,
+            )
+            inp = unfold(inp)
+            inp = inp.permute([1, 0, 2])
+            inp = inp.flatten(1)
+        self.H *= self.nsamples / (self.nsamples + tmp)
+        self.nsamples += tmp
+        # inp = inp.float()
+        inp = math.sqrt(2 / self.nsamples) * inp.float()
+        # self.H += 2 / self.nsamples * inp.matmul(inp.t())
+        self.H += inp.matmul(inp.t())
 
     @torch.inference_mode()
     def quantize(
-        self,
-        blocksize=128,
+        self,
+        blocksize=128,
     ):
-        wq, scale, zero, g_idx, duration, avg_loss, damp_percent, nsamples = super().quantize(blocksize=blocksize)
+        start = time.time()
+
+        percdamp = self.qcfg.damp_percent
+        groupsize = self.qcfg.group_size
+        actorder = self.qcfg.desc_act
+        static_groups = self.qcfg.static_groups
+
+        W = self.layer.weight.data.clone()
+        if isinstance(self.layer, nn.Conv2d):
+            W = W.flatten(1)
+        if isinstance(self.layer, transformers.Conv1D):
+            W = W.t()
+        W = W.float()
+
+        tick = time.time()
+
+        if not self.quantizer.ready():
+            self.quantizer.find_params(W, weight=True)
+
+        H = self.H
+        del self.H
+        dead = torch.diag(H) == 0
+        H[dead, dead] = 1
+        W[:, dead] = 0
+
+        g_idx = []
+        scale = []
+        zero = []
+        now_idx = 1
+        if static_groups:
+            import copy
+
+            groups = []
+            for i in range(0, self.columns, groupsize):
+                quantizer = copy.deepcopy(self.quantizer)
+                quantizer.find_params(W[:, i: (i + groupsize)], weight=True)
+                scale.append(quantizer.scale)
+                zero.append(quantizer.zero)
+                groups.append(quantizer)
+
+        if actorder:
+            perm = torch.argsort(torch.diag(H), descending=True)
+            W = W[:, perm]
+            H = H[perm][:, perm]
+            invperm = torch.argsort(perm)
+
+        Losses = torch.zeros_like(W)
+        Q = torch.zeros_like(W)
+
+        damp = percdamp * torch.mean(torch.diag(H))
+        diag = torch.arange(self.columns, device=self.dev)
+        H[diag, diag] += damp
+        H = torch.linalg.cholesky(H)
+        H = torch.cholesky_inverse(H)
+        H = torch.linalg.cholesky(H, upper=True)
+        Hinv = H
+
+        for i1 in range(0, self.columns, blocksize):
+            i2 = min(i1 + blocksize, self.columns)
+            count = i2 - i1
+
+            W1 = W[:, i1:i2].clone()
+            Q1 = torch.zeros_like(W1)
+            Err1 = torch.zeros_like(W1)
+            Losses1 = torch.zeros_like(W1)
+            Hinv1 = Hinv[i1:i2, i1:i2]
+
+            for i in range(count):
+                w = W1[:, i]
+                d = Hinv1[i, i]
+
+                if groupsize != -1:
+                    if not static_groups:
+                        if (i1 + i) % groupsize == 0:
+                            self.quantizer.find_params(
+                                W[:, (i1 + i): (i1 + i + groupsize)], weight=True
+                            )
+
+                        if ((i1 + i) // groupsize) - now_idx == -1:
+                            scale.append(self.quantizer.scale)
+                            zero.append(self.quantizer.zero)
+                            now_idx += 1
+                    else:
+                        idx = i1 + i
+                        if actorder:
+                            idx = perm[idx]
+                        self.quantizer = groups[idx // groupsize]
+
+                q = quantize(
+                    w.unsqueeze(1),
+                    self.quantizer.scale,
+                    self.quantizer.zero,
+                    self.quantizer.maxq,
+                    self.quantizer.sym,
+                    self.quantizer.groupsize,
+                ).flatten()
+                Q1[:, i] = q
+                Losses1[:, i] = (w - q) ** 2 / d ** 2
+
+                err1 = (w - q) / d
+                W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))
+                Err1[:, i] = err1
+
+            Q[:, i1:i2] = Q1
+            Losses[:, i1:i2] = Losses1 / 2
+
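+            # GPTQ block update: push this block's accumulated quantization
+            # error onto the not-yet-quantized columns to the right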
+            W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])
+
+            if DEBUG:
+                self.layer.weight.data[:, :i2] = Q[:, :i2]
+                self.layer.weight.data[:, i2:] = W[:, i2:]
+                print(torch.sum((self.layer(self.inp1) - self.out1) ** 2))
+                print(torch.sum(Losses))
+
+        torch.cuda.synchronize()
+        print("time %.2f" % (time.time() - tick))
+        print("error", torch.sum(Losses).item())
+
+        if Hinv is not None:
+            del Hinv
+        if self.nsamples != 0:
+            avg_loss = torch.sum(Losses).item() / self.nsamples
+
+            if math.isnan(avg_loss):
+                print("Losses sum item:", torch.sum(Losses).item())
+                if self.fail_safe:
+                    log.info(f"Quantization: Failed due to `NaN` loss for `{self.name}`, use mock quantization retry for `{self.name}`")
+                    self.qcfg.mock_quantization = True
+                    return self.quantize(blocksize=blocksize)
+                else:
+                    raise ValueError(f"Quantization: Failed due to `NaN` loss for `{self.name}`, please try increasing calibration data samples or enable fail_safe=True")
+        else:
+            if self.fail_safe:
+                log.warn(f"Quantization: Module `{self.name}` -> using fail safe mode. Please check if calibration data is sufficient.")
+            else:
+                log.warn(f"Quantization: `{self.name}` is not activated due to model inference logic (MoE)")
+            avg_loss = 999999999
+
+        del Losses
+
+        groupsize = groupsize if groupsize != -1 else self.columns
+        if static_groups and actorder:
+            g_idx = [perm[i] // groupsize for i in range(self.columns)]
+        else:
+            g_idx = [i // groupsize for i in range(self.columns)]
+        g_idx = torch.tensor(g_idx, dtype=torch.int32, device=Q.device)
+        if actorder:
+            Q = Q[:, invperm]
+            g_idx = g_idx[invperm]
+
+        if isinstance(self.layer, transformers.Conv1D):
+            Q = Q.t()
+        if Q.shape != self.layer.weight.shape:
+            Q = Q.reshape(self.layer.weight.shape).to(self.layer.weight.dtype)
+        else:
+            Q = Q.to(self.layer.weight.dtype)
+        if DEBUG:
+            print(torch.sum((self.layer(self.inp1) - self.out1) ** 2))
+
+        if scale == []:
+            scale.append(self.quantizer.scale)
+            zero.append(self.quantizer.zero)
+        scale = torch.cat(scale, dim=1)
+        zero = torch.cat(zero, dim=1)
+
+        Q = Q.to(device=self.layer.weight.data.device, non_blocking=False)
 
         # post int8 quant scale
         scale_extra = None
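+        # group-wise int4 only: derive an additional per-channel symmetric
+        # int8 scale over the full weight; pack_module() forwards it to the
+        # QQQ pack() as `s_extra`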
-        if self.qcfg.group_size != self.columns:
-            qcfg = deepcopy(self.qcfg)
-            qcfg.group_size = -1
-            qcfg.bits = 8
-            qcfg.sym = True
-            qcfg.mse = 0.0
-
-            quantizer_extra = QQQQuantizer(qcfg, name=self.quantizer.name)
+        if groupsize != self.columns:
+            quantizer_extra = Quantizer()
             quantizer_extra.configure(
+                bits=8,
                 perchannel=True,
+                groupsize=-1,
+                sym=True,
+                mse=False,
             )
-            quantizer_extra.find_params(self.module.weight.data.clone(), weight=True)
+            quantizer_extra.find_params(self.layer.weight.data.clone(), weight=True)
             scale_extra = quantizer_extra.scale
 
-        return wq, scale, zero, g_idx, duration, avg_loss, damp_percent, scale_extra, nsamples
+        duration = time.time() - start
+
+        return Q, scale, zero, g_idx, duration, avg_loss, damp, scale_extra, self.nsamples
+
+    def free(self):
+        if DEBUG:
+            self.inp1 = None
+            self.out1 = None
+        self.H = None
+        self.Losses = None
+        self.Trace = None
+        torch.cuda.empty_cache()
diff --git a/gptqmodel/utils/model.py b/gptqmodel/utils/model.py
index 90f887a6f..1fe3c3747 100644
--- a/gptqmodel/utils/model.py
+++ b/gptqmodel/utils/model.py
@@ -602,7 +602,8 @@ def pack_module(name, qModules, q_scales, q_zeros, q_g_idx, layers, quant_linear
 
     # TODO FIX ME..remove hard coded qqq pack
     if quant_linear_cls.QUANT_TYPE == "qqq":
-        q_scales_extra = q_scales_extra.to(CPU)
+        if q_scales_extra is not None:
+            q_scales_extra = q_scales_extra.to(CPU)
         module.pack(linear=layer, scales=q_scales, s_extra=q_scales_extra)
     else:
         module.pack(linear=layer, scales=q_scales, zeros=q_zeros, g_idx=q_g_idx)