From 71d56c76d0d06735406da04d7509e8a1a1a740c8 Mon Sep 17 00:00:00 2001
From: qwopqwop200
Date: Thu, 31 Aug 2023 14:35:04 +0900
Subject: [PATCH 1/3] skip install qigen(windows)

---
 setup.py | 30 ++++++++++++++++++------------
 1 file changed, 18 insertions(+), 12 deletions(-)

diff --git a/setup.py b/setup.py
index 5f7f5735..aa76349c 100644
--- a/setup.py
+++ b/setup.py
@@ -4,6 +4,7 @@
 from setuptools import setup, Extension, find_packages
 import subprocess
 import math
+import platform
 
 os.environ["CC"] = "g++"
 os.environ["CXX"] = "g++"
@@ -94,10 +95,11 @@ additional_setup_kwargs = dict()
 
 if BUILD_CUDA_EXT:
     from torch.utils import cpp_extension
-
-    p = int(subprocess.run("cat /proc/cpuinfo | grep cores | head -1", shell=True, check=True, text=True, stdout=subprocess.PIPE).stdout.split(" ")[2])
-
-    subprocess.call(["python", "./autogptq_extension/qigen/generate.py", "--module", "--search", "--p", str(p)])
+    
+    if platform != 'Windows':
+        p = int(subprocess.run("cat /proc/cpuinfo | grep cores | head -1", shell=True, check=True, text=True, stdout=subprocess.PIPE).stdout.split(" ")[2])
+        subprocess.call(["python", "./autogptq_extension/qigen/generate.py", "--module", "--search", "--p", str(p)])
+    
     if not ROCM_VERSION:
         from distutils.sysconfig import get_python_lib
         conda_cuda_include_dir = os.path.join(get_python_lib(), "nvidia/cuda_runtime/include")
@@ -120,16 +122,20 @@
                 "autogptq_extension/cuda_256/autogptq_cuda_256.cpp",
                 "autogptq_extension/cuda_256/autogptq_cuda_kernel_256.cu"
             ]
-        ),
-        cpp_extension.CppExtension(
-            "cQIGen",
-            [
-                'autogptq_extension/qigen/backend.cpp'
-            ],
-            extra_compile_args = ["-O3", "-mavx", "-mavx2", "-mfma", "-march=native", "-ffast-math", "-ftree-vectorize", "-faligned-new", "-std=c++17", "-fopenmp", "-fno-signaling-nans", "-fno-trapping-math"]
         )
     ]
-
+    
+    if platform != 'Windows':
+        extensions.append(
+            cpp_extension.CppExtension(
+                "cQIGen",
+                [
+                    'autogptq_extension/qigen/backend.cpp'
+                ],
+                extra_compile_args = ["-O3", "-mavx", "-mavx2", "-mfma", "-march=native", "-ffast-math", "-ftree-vectorize", "-faligned-new", "-std=c++17", "-fopenmp", "-fno-signaling-nans", "-fno-trapping-math"]
+            )
+        )
+    
     if os.name == "nt":
         # On Windows, fix an error LNK2001: unresolved external symbol cublasHgemm bug in the compilation
         cuda_path = os.environ.get("CUDA_PATH", None)

From 45a1ee4d84b43a86133de8ea973ed067528e98bc Mon Sep 17 00:00:00 2001
From: qwopqwop200
Date: Thu, 31 Aug 2023 14:37:39 +0900
Subject: [PATCH 2/3] install check qigen

---
 auto_gptq/modeling/_base.py                   | 21 ++++----
 auto_gptq/modeling/_utils.py                  | 53 ++++++++++---------
 auto_gptq/nn_modules/qlinear/qlinear_qigen.py |  6 ++-
 auto_gptq/utils/import_utils.py               |  7 +++
 4 files changed, 53 insertions(+), 34 deletions(-)

diff --git a/auto_gptq/modeling/_base.py b/auto_gptq/modeling/_base.py
index e3c3c7c9..89b5ae7e 100644
--- a/auto_gptq/modeling/_base.py
+++ b/auto_gptq/modeling/_base.py
@@ -26,7 +26,7 @@
 from ..quantization import GPTQ
 from ..utils.data_utils import collate_data
 from ..utils.import_utils import (
-    dynamically_import_QuantLinear, TRITON_AVAILABLE, AUTOGPTQ_CUDA_AVAILABLE, EXLLAMA_KERNELS_AVAILABLE
+    dynamically_import_QuantLinear, TRITON_AVAILABLE, AUTOGPTQ_CUDA_AVAILABLE, EXLLAMA_KERNELS_AVAILABLE, QIGEN_AVAILABLE
 )
 
 logger = getLogger(__name__)
@@ -727,13 +727,9 @@
             "_raise_exceptions_for_missing_entries": False,
             "_commit_hash": commit_hash,
         }
-        if use_qigen:
-            logger.warning("QIgen is active. Ignores all settings related to cuda.")
-            inject_fused_attention = False
-            inject_fused_mlp = False
-            use_triton = False
-            disable_exllama = True
-
+        if use_qigen and not QIGEN_AVAILABLE:
+            logger.warning("Qigen is not installed, reset use_qigen to False.")
+            use_qigen = False
         if use_triton and not TRITON_AVAILABLE:
             logger.warning("Triton is not installed, reset use_triton to False.")
             use_triton = False
@@ -754,7 +750,14 @@
                 "2. You are using pytorch without CUDA support.\n"
                 "3. CUDA and nvcc are not installed in your device."
             )
-        
+
+        if use_qigen and QIGEN_AVAILABLE:
+            logger.warning("QIgen is active. Ignores all settings related to cuda.")
+            inject_fused_attention = False
+            inject_fused_mlp = False
+            use_triton = False
+            disable_exllama = True
+
         # == step1: prepare configs and file names ==
         config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=trust_remote_code, **cached_file_kwargs)

diff --git a/auto_gptq/modeling/_utils.py b/auto_gptq/modeling/_utils.py
index 6c2fd38e..a268f406 100644
--- a/auto_gptq/modeling/_utils.py
+++ b/auto_gptq/modeling/_utils.py
@@ -6,7 +6,6 @@
 import torch.nn as nn
 from transformers import AutoConfig
 import transformers
-import cQIGen as qinfer
 
 from ._const import SUPPORTED_MODELS, CPU, CUDA_0, EXLLAMA_DEFAULT_MAX_INPUT_LENGTH
 from ..utils.import_utils import dynamically_import_QuantLinear
@@ -105,28 +104,6 @@
         use_qigen=use_qigen
     )
 
-def process_zeros_scales(zeros, scales, bits, out_features):
-    if zeros.dtype != torch.float32:
-        new_zeros = torch.zeros_like(scales).float().contiguous()
-        if bits == 4:
-            qinfer.unpack_zeros4(zeros, new_zeros, new_zeros.shape[0], new_zeros.shape[1])
-        elif bits == 2:
-            qinfer.unpack_zeros2(zeros, new_zeros, new_zeros.shape[0], new_zeros.shape[1])
-        elif bits == 3:
-            logger.info("Unpacking zeros for 3 bits")
-        new_scales = scales.contiguous()
-    else:
-        if scales.shape[1] != out_features:
-            new_scales = scales.transpose(0,1).contiguous()
-        else:
-            new_scales = scales.contiguous()
-        if zeros.shape[1] != out_features:
-            new_zeros = zeros.transpose(0,1).contiguous()
-        else:
-            new_zeros = zeros.contiguous()
-
-    return new_zeros, new_scales
-
 def preprocess_checkpoint_qigen(
     module,
     names,
@@ -133,15 +110,43 @@
     bits,
     group_size,
     checkpoint,
     name='',
 ):
+    try:
+        import cQIGen as qinfer
+    except ImportError:
+        logger.error('cQIGen not installed.')
+        raise
+
     QuantLinear = dynamically_import_QuantLinear(use_triton=False, desc_act=False, group_size=group_size, bits=bits, disable_exllama=False, use_qigen=True)
 
     if isinstance(module, QuantLinear):
         in_features = module.infeatures
         out_features = module.outfeatures
+
+        zeros = checkpoint[name + '.qzeros']
+        scales = checkpoint[name + '.scales'].float()
+
+        if zeros.dtype != torch.float32:
+            new_zeros = torch.zeros_like(scales).float().contiguous()
+            if bits == 4:
+                qinfer.unpack_zeros4(zeros, new_zeros, new_zeros.shape[0], new_zeros.shape[1])
+            elif bits == 2:
+                qinfer.unpack_zeros2(zeros, new_zeros, new_zeros.shape[0], new_zeros.shape[1])
+            elif bits == 3:
+                logger.info("Unpacking zeros for 3 bits")
+            new_scales = scales.contiguous()
+        else:
+            if scales.shape[1] != out_features:
+                new_scales = scales.transpose(0,1).contiguous()
+            else:
+                new_scales = scales.contiguous()
+            if zeros.shape[1] != out_features:
+                new_zeros = zeros.transpose(0,1).contiguous()
+            else:
+                new_zeros = zeros.contiguous()
 
-        checkpoint[name + '.zeros'],checkpoint[name + '.scales'] = process_zeros_scales(checkpoint[name + '.qzeros'],checkpoint[name + '.scales'].float(), bits, out_features)
+        checkpoint[name + '.zeros'],checkpoint[name + '.scales'] = new_zeros, new_scales
         del checkpoint[name + '.qzeros']
         del checkpoint[name + '.g_idx']
         if name + '.bias' in checkpoint:

diff --git a/auto_gptq/nn_modules/qlinear/qlinear_qigen.py b/auto_gptq/nn_modules/qlinear/qlinear_qigen.py
index 3061083c..f771cbd9 100644
--- a/auto_gptq/nn_modules/qlinear/qlinear_qigen.py
+++ b/auto_gptq/nn_modules/qlinear/qlinear_qigen.py
@@ -4,7 +4,6 @@
 from tqdm import tqdm
 import gc
 
-import cQIGen as qinfer
 import math
 import numpy as np
 from gekko import GEKKO
@@ -12,6 +11,11 @@
 
 logger = getLogger(__name__)
 
+try:
+    import cQIGen as qinfer
+except ImportError:
+    logger.error('cQIGen not installed.')
+    raise
 
 def mem_model(N, M, T, mu, tu, bits, l1, p, gs):
     m = GEKKO() # create GEKKO model

diff --git a/auto_gptq/utils/import_utils.py b/auto_gptq/utils/import_utils.py
index 5cf1a053..cba52ba6 100644
--- a/auto_gptq/utils/import_utils.py
+++ b/auto_gptq/utils/import_utils.py
@@ -25,6 +25,13 @@
 except:
     EXLLAMA_KERNELS_AVAILABLE = False
 
+try:
+    import cQIGen as qinfer
+
+    QIGEN_AVAILABLE = True
+except:
+    QIGEN_AVAILABLE = False
+
 
 logger = getLogger(__name__)

From f97b77a64ec63081eb30e07188427dcd2f21ea5f Mon Sep 17 00:00:00 2001
From: qwopqwop200
Date: Thu, 31 Aug 2023 15:00:38 +0900
Subject: [PATCH 3/3] fix install bug

---
 setup.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/setup.py b/setup.py
index aa76349c..13877389 100644
--- a/setup.py
+++ b/setup.py
@@ -95,8 +95,8 @@ additional_setup_kwargs = dict()
 
 if BUILD_CUDA_EXT:
     from torch.utils import cpp_extension
-    
-    if platform != 'Windows':
+
+    if platform.system() != 'Windows':
         p = int(subprocess.run("cat /proc/cpuinfo | grep cores | head -1", shell=True, check=True, text=True, stdout=subprocess.PIPE).stdout.split(" ")[2])
         subprocess.call(["python", "./autogptq_extension/qigen/generate.py", "--module", "--search", "--p", str(p)])
 
@@ -125,7 +125,7 @@
         )
     ]
     
-    if platform != 'Windows':
+    if platform.system() != 'Windows':
         extensions.append(
             cpp_extension.CppExtension(
                 "cQIGen",
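
A note on the bug that PATCH 3/3 fixes: PATCH 1/3 writes `if platform != 'Windows':`, which compares the `platform` module object itself against a string. That inequality holds on every operating system, so the guard never skips anything and the QIGen build steps still run on Windows. A minimal sketch of the failure and the fix, in plain Python with nothing AutoGPTQ-specific:

    import platform

    # Buggy guard from PATCH 1/3: a module object never equals a str,
    # so this is True on Windows as well and skips nothing.
    print(platform != 'Windows')  # True everywhere

    # Corrected guard from PATCH 3/3: compare the OS name instead.
    print(platform.system())      # 'Windows', 'Linux', 'Darwin', ...
    if platform.system() != 'Windows':
        print("POSIX-only QIGen build steps would run here")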
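
The core pattern of PATCH 2/3 is a guarded import that degrades gracefully: auto_gptq/utils/import_utils.py probes once for the optional cQIGen extension and records the outcome in QIGEN_AVAILABLE, and from_quantized() then downgrades use_qigen with a warning instead of crashing. A condensed sketch of the same pattern; the probe is narrowed to ImportError here (the patch uses a bare except), and resolve_use_qigen is a hypothetical stand-in for the inline check in from_quantized():

    from logging import getLogger

    logger = getLogger(__name__)

    # Module-level probe, as in import_utils.py: attempt the import once
    # and remember whether the compiled extension is present.
    try:
        import cQIGen as qinfer

        QIGEN_AVAILABLE = True
    except ImportError:
        QIGEN_AVAILABLE = False

    def resolve_use_qigen(use_qigen):
        # Mirror the runtime fallback in from_quantized(): downgrade the
        # flag with a warning rather than raising on a missing extension.
        if use_qigen and not QIGEN_AVAILABLE:
            logger.warning("Qigen is not installed, reset use_qigen to False.")
            use_qigen = False
        return use_qigen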
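
The block that both setup.py guards wrap is also Linux-only for a second reason: it shells out to parse /proc/cpuinfo for the physical core count, and that file does not exist elsewhere. A rough sketch of a more portable probe, under the assumption that falling back to logical CPUs is acceptable (os.cpu_count() counts hardware threads, not the "cpu cores" field the grep extracts, so the two differ on SMT machines; qigen_core_count is a hypothetical helper, not part of the patch):

    import os
    import platform
    import subprocess

    def qigen_core_count():
        # On Linux, reuse the exact /proc/cpuinfo parsing from setup.py;
        # elsewhere fall back to the logical CPU count.
        if platform.system() == "Linux":
            out = subprocess.run(
                "cat /proc/cpuinfo | grep cores | head -1",
                shell=True, check=True, text=True, stdout=subprocess.PIPE,
            ).stdout
            return int(out.split(" ")[2])  # "cpu cores\t: N" -> N
        return os.cpu_count() or 1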