diff --git a/gptqmodel/looper/module_looper.py b/gptqmodel/looper/module_looper.py index d6830bdf4..577b5eaac 100644 --- a/gptqmodel/looper/module_looper.py +++ b/gptqmodel/looper/module_looper.py @@ -19,7 +19,7 @@ import time from concurrent.futures import as_completed from contextlib import nullcontext -from typing import Dict, List, Optional, TYPE_CHECKING +from typing import Dict, List, NamedTuple, Optional, TYPE_CHECKING import torch @@ -59,6 +59,12 @@ from logbar.progress import ProgressBar +class FinalizeProgressInfo(NamedTuple): + module_label: Optional[str] + process_name: str + layer_idx: Optional[int] + + class ModuleLooper(): """Drive the per-layer quantisation workflow over one or more devices. @@ -488,6 +494,9 @@ def _run_forward_batches_parallel( """Fan batches across device clones and preserve result ordering.""" module_replicas = clone_module_for_devices(module, devices) + # Ensure any async replication/memcpy ops are complete before threads start fanning out. + torch_sync() + prev_kv = shared_kv_cache_dict.get(layer_index - 1) if reuse_kv else None results: Dict[int, torch.Tensor | tuple | None] = {} @@ -1308,6 +1317,9 @@ def _finalize_on_worker(process, module, idx, total, module_label, layer_idx): source=resolved_label, ) + process_name = process.name() if process is not None else "" + return FinalizeProgressInfo(module_label, process_name, layer_idx) + # pb.subtitle( # f"{process.name()}: layer:{layer_idx} Finalized {idx}/{total} {module_label}" # ).draw() @@ -1328,34 +1340,57 @@ def _finalize_on_worker(process, module, idx, total, module_label, layer_idx): finalize_futures_snapshot = list(finalize_futures) if finalize_futures_snapshot: + known_layers = sorted( + { + layer_idx + for _, _, _, _, layer_idx in finalize_futures_snapshot + if layer_idx is not None + } + ) + includes_unknown = any( + layer_idx is None + for _, _, _, _, layer_idx in finalize_futures_snapshot + ) + + layer_heading = "Layer ?" + if known_layers: + sample_layers = ", ".join(str(idx) for idx in known_layers[:3]) + if len(known_layers) > 3: + sample_layers += ", …" + suffix = ", ?" if includes_unknown else "" + prefix = "Layer" if len(known_layers) == 1 else "Layers" + layer_heading = f"{prefix} {sample_layers}{suffix}" + elif includes_unknown: + layer_heading = "Layer ?" + finalize_pb.title( - f"Submodule finalize 0/{finalize_count}" + f"{layer_heading} Submodule finalize 0/{finalize_count}" ).subtitle("Waiting for completions...").draw() - future_metadata = { - future: (module_label, process, layer_idx) - for future, _, module_label, process, layer_idx in finalize_futures_snapshot - } - def _drain_finalize_futures( futures, finalize_pb_local, finalize_count_local, - future_metadata_local, ): completed_local = 0 try: for future in as_completed(futures): - module_label, process, layer_idx = future_metadata_local.get( - future, (None, None, None) - ) - - future.result() - - layer_label = f"Layer {layer_idx}" if layer_idx is not None else "layer ?" + result = future.result() + + if isinstance(result, FinalizeProgressInfo): + module_label = result.module_label + process_name = result.process_name + layer_idx = result.layer_idx + elif isinstance(result, tuple) and len(result) == 3: + module_label, process_name, layer_idx = result + else: + module_label = None + process_name = "" + layer_idx = None + + layer_label = f"Layer {layer_idx}" if layer_idx is not None else "Layer ?" 
display_module = module_label or "" - processor_name = process.name() if process is not None else "" - subtitle = f"{processor_name}: {display_module}" + subtitle = f"{process_name}: {display_module}" completed_local += 1 finalize_pb_local.next() @@ -1373,7 +1408,6 @@ def _drain_finalize_futures( [future for future, *_ in finalize_futures_snapshot], finalize_pb, finalize_count, - future_metadata, ), name="SubmoduleFinalizeWatcher", daemon=True, diff --git a/gptqmodel/models/auto.py b/gptqmodel/models/auto.py index 15298a6b0..74ae0c2d9 100644 --- a/gptqmodel/models/auto.py +++ b/gptqmodel/models/auto.py @@ -56,6 +56,7 @@ from .base import BaseQModel, QuantizeConfig # noqa: E402 from .definitions.apertus import ApertusQModel # noqa: E402 from .definitions.baichuan import BaiChuanQModel # noqa: E402 +from .definitions.bailing_moe import BailingMoeQModel # noqa: E402 from .definitions.bloom import BloomQModel # noqa: E402 from .definitions.chatglm import ChatGLMQModel # noqa: E402 from .definitions.codegen import CodeGenQModel # noqa: E402 @@ -85,6 +86,7 @@ from .definitions.internlm import InternLMQModel # noqa: E402 from .definitions.internlm2 import InternLM2QModel # noqa: E402 from .definitions.klear import KlearQModel # noqa: E402 +from .definitions.lfm2_moe import LFM2MoeQModel # noqa: E402 from .definitions.llama import LlamaQModel # noqa: E402 from .definitions.llama4 import Llama4QModel # noqa: E402 from .definitions.llava_qwen2 import LlavaQwen2QModel # noqa: E402 @@ -118,9 +120,6 @@ from .definitions.starcoder2 import Starcoder2QModel # noqa: E402 from .definitions.telechat2 import TeleChat2QModel from .definitions.xverse import XverseQModel # noqa: E402 -from .definitions.bailing_moe import BailingMoeQModel # noqa: E402 -from .definitions.lfm2_moe import LFM2MoeQModel # noqa: E402 - # make quants and inference more determinisitc diff --git a/gptqmodel/models/base.py b/gptqmodel/models/base.py index fdb02d9a2..57613e44c 100644 --- a/gptqmodel/models/base.py +++ b/gptqmodel/models/base.py @@ -37,14 +37,14 @@ HFDataset = None HFIterableDataset = None +from .. import DEVICE_THREAD_POOL from ..adapter.adapter import Adapter from ..nn_modules.qlinear import BaseQuantLinear -from ..nn_modules.qlinear.torch import TorchQuantLinear from ..nn_modules.qlinear.lookahead import configure_default_lookahead +from ..nn_modules.qlinear.torch import TorchQuantLinear from ..quantization import QuantizeConfig from ..quantization.config import FORMAT, METHOD, QUANTIZE_BLACK_LIST, dynamic_get from ..quantization.rotation.rotation import fuse_layer_norms, rotate_model -from .. 
import DEVICE_THREAD_POOL from ..utils.backend import BACKEND from ..utils.data import collate_data from ..utils.device import get_device diff --git a/gptqmodel/models/definitions/lfm2_moe.py b/gptqmodel/models/definitions/lfm2_moe.py index dd325f3d7..217bbb3b5 100644 --- a/gptqmodel/models/definitions/lfm2_moe.py +++ b/gptqmodel/models/definitions/lfm2_moe.py @@ -20,7 +20,7 @@ class LFM2MoeQModel(BaseQModel): "operator_norm": ("operator_norm:!",), "conv": ("in_proj", "out_proj"), "self_attn": ("q_proj:0", "k_proj:0", "v_proj:0", "o_proj:1"), - "ffn_norm": ("ffn_norm:!",), + "ffn_norm": ("ffn_norm:!",), "feed_forward": { "gate": ("gate:!",), "": ("w1:0", "w3:0", "w2:1"), diff --git a/gptqmodel/models/writer.py b/gptqmodel/models/writer.py index 6e78fac26..6f33b2d4a 100644 --- a/gptqmodel/models/writer.py +++ b/gptqmodel/models/writer.py @@ -9,10 +9,10 @@ import csv import json import os -import pcre as re from os.path import isfile, join from typing import Any, Dict, Optional, Union +import pcre as re import torch import transformers from safetensors.torch import save_file diff --git a/gptqmodel/nn_modules/qlinear/__init__.py b/gptqmodel/nn_modules/qlinear/__init__.py index 9bf49158e..aa971d500 100644 --- a/gptqmodel/nn_modules/qlinear/__init__.py +++ b/gptqmodel/nn_modules/qlinear/__init__.py @@ -10,7 +10,6 @@ from typing import List, Optional, Tuple import numpy as np -import threadpoolctl import torch as t # conflict with torch.py import torch.nn as nn import transformers @@ -20,6 +19,7 @@ from ...models._const import DEVICE, PLATFORM from ...utils.backend import BACKEND from ...utils.logger import setup_logger +from ...utils.safe import THREADPOOLCTL log = setup_logger() @@ -890,7 +890,7 @@ def _pack_rows_3(int32_blk_32xN: t.Tensor, dst: t.Tensor, dst_rows_base: int): del weight, scales_dev, zeros_dev, scale_zeros_dev, qweight_dev, qzeros_dev def pack_original(self, linear: nn.Module, scales: t.Tensor, zeros: t.Tensor, g_idx: t.Tensor=None): - with threadpoolctl.threadpool_limits(1): + with THREADPOOLCTL.threadpool_limits(1): # TODO why did we need to clone? at packing, the original weight is no longer used by other processors? 
# W = linear.weight.data.clone() W = linear.weight.data diff --git a/gptqmodel/nn_modules/qlinear/awq_gemm.py b/gptqmodel/nn_modules/qlinear/awq_gemm.py index f666de2c0..4f041e794 100644 --- a/gptqmodel/nn_modules/qlinear/awq_gemm.py +++ b/gptqmodel/nn_modules/qlinear/awq_gemm.py @@ -7,7 +7,7 @@ from ...adapter.adapter import Adapter, Lora from ...models._const import DEVICE, PLATFORM -from ...nn_modules.qlinear import AWQuantLinear, PackableQuantLinear +from ...nn_modules.qlinear import AWQuantLinear from ...quantization.awq.modules.linear.gemm import WQLinearMMFunction from ...utils.backend import BACKEND from ...utils.logger import setup_logger diff --git a/gptqmodel/nn_modules/qlinear/awq_gemv.py b/gptqmodel/nn_modules/qlinear/awq_gemv.py index 22978b872..eabd08733 100644 --- a/gptqmodel/nn_modules/qlinear/awq_gemv.py +++ b/gptqmodel/nn_modules/qlinear/awq_gemv.py @@ -7,7 +7,7 @@ from ...adapter.adapter import Adapter, Lora from ...models._const import DEVICE, PLATFORM -from ...nn_modules.qlinear import AWQuantLinear, PackableQuantLinear +from ...nn_modules.qlinear import AWQuantLinear from ...quantization.awq.utils.module import try_import from ...utils.backend import BACKEND from ...utils.gemv import calculate_zeros_width diff --git a/gptqmodel/nn_modules/qlinear/awq_gemv_fast.py b/gptqmodel/nn_modules/qlinear/awq_gemv_fast.py index 5ffae97e4..225f02749 100644 --- a/gptqmodel/nn_modules/qlinear/awq_gemv_fast.py +++ b/gptqmodel/nn_modules/qlinear/awq_gemv_fast.py @@ -7,7 +7,7 @@ from ...adapter.adapter import Adapter, Lora from ...models._const import DEVICE, PLATFORM -from ...nn_modules.qlinear import AWQuantLinear, PackableQuantLinear +from ...nn_modules.qlinear import AWQuantLinear from ...quantization.awq.utils.module import try_import from ...utils.backend import BACKEND from ...utils.gemv import calculate_zeros_width diff --git a/gptqmodel/nn_modules/qlinear/pack_block_ext.py b/gptqmodel/nn_modules/qlinear/pack_block_ext.py index 3edfd2f71..856afc4d7 100644 --- a/gptqmodel/nn_modules/qlinear/pack_block_ext.py +++ b/gptqmodel/nn_modules/qlinear/pack_block_ext.py @@ -13,6 +13,7 @@ from torch import Tensor from torch.utils.cpp_extension import load + log = logging.getLogger(__name__) _EXTENSION = None diff --git a/gptqmodel/nn_modules/qlinear/torch.py b/gptqmodel/nn_modules/qlinear/torch.py index f07270f65..86b366ce5 100644 --- a/gptqmodel/nn_modules/qlinear/torch.py +++ b/gptqmodel/nn_modules/qlinear/torch.py @@ -19,6 +19,7 @@ from ...utils.logger import setup_logger from ...utils.torch import torch_compile + try: from ..triton_utils.dequant import dequant as triton_dequant @@ -259,7 +260,7 @@ def schedule(tile_idx: int, buffer_idx: int): compute_stream.wait_stream(stream_dequant) width = widths[buffer_idx] start = tile_idx * tile - end = start + width + start + width out_slice = out.narrow(1, start, width) out_slice.zero_() diff --git a/gptqmodel/quantization/awq/modules/triton/gemm.py b/gptqmodel/quantization/awq/modules/triton/gemm.py index 8d5789c4d..c027e195b 100644 --- a/gptqmodel/quantization/awq/modules/triton/gemm.py +++ b/gptqmodel/quantization/awq/modules/triton/gemm.py @@ -18,6 +18,7 @@ import triton import triton.language as tl + AWQ_TRITON_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] def get_same_device_cm(t): @@ -282,10 +283,11 @@ def awq_dequantize_triton( Y = qweight.shape[0] # num rows X = qweight.shape[1] # num cols - grid = lambda META: ( - triton.cdiv(X, META["BLOCK_SIZE_X"]), - triton.cdiv(Y, META["BLOCK_SIZE_Y"]), - ) + def grid(META): + return ( + 
triton.cdiv(X, META["BLOCK_SIZE_X"]), + triton.cdiv(Y, META["BLOCK_SIZE_Y"]), + ) with get_same_device_cm(qweight): awq_dequantize_kernel[grid]( qweight, @@ -330,10 +332,11 @@ def awq_gemm_triton( assert group_size <= K assert group_size in AWQ_TRITON_SUPPORTED_GROUP_SIZES or group_size == K - grid = lambda META: ( - triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), - split_k_iters, - ) + def grid(META): + return ( + triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]), + split_k_iters, + ) result = torch.zeros((M, N), dtype=scales.dtype, device=input.device) @@ -356,4 +359,4 @@ def awq_gemm_triton( SPLIT_K=split_k_iters, ) - return result \ No newline at end of file + return result diff --git a/gptqmodel/quantization/config.py b/gptqmodel/quantization/config.py index d8dd7aab2..4ce618708 100644 --- a/gptqmodel/quantization/config.py +++ b/gptqmodel/quantization/config.py @@ -5,12 +5,12 @@ import json import os.path -import pcre as re from dataclasses import dataclass, field, fields from enum import Enum from os.path import join from typing import Any, Dict, List, Optional, Tuple, Union +import pcre as re import torch from packaging import version from random_word import random_word diff --git a/gptqmodel/utils/bitblas.py b/gptqmodel/utils/bitblas.py index 55588e627..71523dd1c 100644 --- a/gptqmodel/utils/bitblas.py +++ b/gptqmodel/utils/bitblas.py @@ -4,14 +4,15 @@ # Contact: qubitium@modelcloud.ai, x.com/qubitium import os +from contextlib import nullcontext -import threadpoolctl as tctl import torch from ..nn_modules.qlinear.bitblas import BitBLASQuantLinear from ..quantization import FORMAT, QuantizeConfig from ..utils.logger import setup_logger from .model import load_checkpoint_in_model_then_tie_weights +from .safe import THREADPOOLCTL from .torch import torch_empty_cache @@ -76,7 +77,13 @@ def convert_to_bitblas(model, model_quantlinear, qcfg: QuantizeConfig, sym: bool message = "Overriding QuantLinear layers to use BitBLAS's QuantLinear..." 
# TODO: need to benchmark to see multiple threads help with bitblas/tvm compilation and runtime - with tctl.threadpool_limits(limits=1): + threadpool_limits = ( + THREADPOOLCTL.threadpool_limits + if THREADPOOLCTL is not None + else (lambda *args, **kwargs: nullcontext()) + ) + + with threadpool_limits(limits=1): os.environ["NUMEXPR_MAX_THREADS"] = "1" # Note that due to tvm compilation of per layer modules shapes, the first layer loop is diff --git a/gptqmodel/utils/mmlupro.py b/gptqmodel/utils/mmlupro.py index a0f9a2e23..656616428 100644 --- a/gptqmodel/utils/mmlupro.py +++ b/gptqmodel/utils/mmlupro.py @@ -7,9 +7,9 @@ import json import os import random -import pcre as re import time +import pcre as re import torch from datasets import load_dataset from torch.utils.data import DataLoader diff --git a/gptqmodel/utils/model.py b/gptqmodel/utils/model.py index 7c39763f0..3d2642493 100644 --- a/gptqmodel/utils/model.py +++ b/gptqmodel/utils/model.py @@ -11,7 +11,6 @@ import math import operator import os -import pcre as re import shutil import struct import threading @@ -22,6 +21,7 @@ from typing import Any, Dict, List, Optional, Tuple, Type, Union import accelerate +import pcre as re import torch import torch.nn as nn import transformers diff --git a/gptqmodel/utils/safe.py b/gptqmodel/utils/safe.py index fac346ea6..05e9551f1 100644 --- a/gptqmodel/utils/safe.py +++ b/gptqmodel/utils/safe.py @@ -12,6 +12,7 @@ from functools import wraps from types import ModuleType +import threadpoolctl as _threadpoolctl import torch @@ -97,8 +98,10 @@ def __repr__(self): TORCH_LINALG = ThreadSafe(torch.linalg) +THREADPOOLCTL = ThreadSafe(_threadpoolctl) __all__ = [ "ThreadSafe", "TORCH_LINALG", + "THREADPOOLCTL", ] diff --git a/gptqmodel/utils/structure.py b/gptqmodel/utils/structure.py index 27d411add..1707d89b3 100644 --- a/gptqmodel/utils/structure.py +++ b/gptqmodel/utils/structure.py @@ -24,9 +24,9 @@ - Collapsing is generic: any numeric-indexed ModuleList whose qualified name matches `experts-regex`. """ -import pcre as re from typing import Dict, Iterable, Optional, Set, Tuple +import pcre as re import torch from torch import nn diff --git a/gptqmodel/utils/torch.py b/gptqmodel/utils/torch.py index d8c323c93..93f42a1e3 100644 --- a/gptqmodel/utils/torch.py +++ b/gptqmodel/utils/torch.py @@ -179,17 +179,42 @@ def torch_new_stream_ctx(): return contextlib.nullcontext() def torch_sync(device: torch.device = None): - # check all backends + """Synchronize accelerator queues. + + When no device is provided we synchronize every detected accelerator index so + replication work staged on multiple GPUs/NPUs completes before issuing more + kernels. 
+ """ + if device is None: + synchronized_any = False + if HAS_CUDA: - torch.cuda.synchronize() - elif HAS_XPU: - torch.xpu.synchronize() - elif HAS_MPS: + dev_count = torch.cuda.device_count() + if dev_count: + synchronized_any = True + for idx in range(dev_count): + torch.cuda.synchronize(idx) + + if HAS_XPU and hasattr(torch.xpu, "device_count"): + dev_count = torch.xpu.device_count() + if dev_count: + synchronized_any = True + for idx in range(dev_count): + torch.xpu.synchronize(idx) + + if HAS_MPS: + synchronized_any = True torch.mps.synchronize() - elif HAS_NPU: - torch.npu.synchronize() - else: + + if HAS_NPU and hasattr(torch.npu, "device_count"): + dev_count = torch.npu.device_count() + if dev_count: + synchronized_any = True + for idx in range(dev_count): + torch.npu.synchronize(idx) + + if not synchronized_any: torch.cpu.synchronize() return @@ -200,7 +225,7 @@ def torch_sync(device: torch.device = None): elif device.type == "mps": torch.mps.synchronize() elif device.type == "npu": - torch.npu.synchronize() + torch.npu.synchronize(device=device) elif device.type == "cpu": torch.cpu.synchronize() diff --git a/setup.py b/setup.py index b9f62cd64..820335b33 100644 --- a/setup.py +++ b/setup.py @@ -3,11 +3,11 @@ # SPDX-License-Identifier: Apache-2.0 # Contact: qubitium@modelcloud.ai, x.com/qubitium import os -import pcre as re import subprocess import sys from pathlib import Path +import pcre as re from setuptools import find_packages, setup from setuptools.command.bdist_wheel import bdist_wheel as _bdist_wheel diff --git a/tests/tasks/gpqa/cot_n_shot/utils.py b/tests/tasks/gpqa/cot_n_shot/utils.py index f960f95e4..01ee60695 100644 --- a/tests/tasks/gpqa/cot_n_shot/utils.py +++ b/tests/tasks/gpqa/cot_n_shot/utils.py @@ -4,9 +4,9 @@ # Contact: qubitium@modelcloud.ai, x.com/qubitium import random -import pcre as re import datasets +import pcre as re def preprocess(text): diff --git a/tests/tasks/gpqa/cot_zeroshot/utils.py b/tests/tasks/gpqa/cot_zeroshot/utils.py index f960f95e4..01ee60695 100644 --- a/tests/tasks/gpqa/cot_zeroshot/utils.py +++ b/tests/tasks/gpqa/cot_zeroshot/utils.py @@ -4,9 +4,9 @@ # Contact: qubitium@modelcloud.ai, x.com/qubitium import random -import pcre as re import datasets +import pcre as re def preprocess(text): diff --git a/tests/tasks/gpqa/generative/utils.py b/tests/tasks/gpqa/generative/utils.py index f960f95e4..01ee60695 100644 --- a/tests/tasks/gpqa/generative/utils.py +++ b/tests/tasks/gpqa/generative/utils.py @@ -4,9 +4,9 @@ # Contact: qubitium@modelcloud.ai, x.com/qubitium import random -import pcre as re import datasets +import pcre as re def preprocess(text): diff --git a/tests/tasks/gpqa/n_shot/utils.py b/tests/tasks/gpqa/n_shot/utils.py index edc6a2106..44dd386ba 100644 --- a/tests/tasks/gpqa/n_shot/utils.py +++ b/tests/tasks/gpqa/n_shot/utils.py @@ -4,9 +4,9 @@ # Contact: qubitium@modelcloud.ai, x.com/qubitium import random -import pcre as re import datasets +import pcre as re def preprocess(text): diff --git a/tests/tasks/gpqa/zeroshot/utils.py b/tests/tasks/gpqa/zeroshot/utils.py index 79772b3d4..5b8aed281 100644 --- a/tests/tasks/gpqa/zeroshot/utils.py +++ b/tests/tasks/gpqa/zeroshot/utils.py @@ -4,9 +4,9 @@ # Contact: qubitium@modelcloud.ai, x.com/qubitium import random -import pcre as re import datasets +import pcre as re def preprocess(text): diff --git a/tests/tasks/hellaswag/utils.py b/tests/tasks/hellaswag/utils.py index 7de3c7984..60e4c2817 100644 --- a/tests/tasks/hellaswag/utils.py +++ b/tests/tasks/hellaswag/utils.py @@ -3,9 +3,8 @@ 
# SPDX-License-Identifier: Apache-2.0 # Contact: qubitium@modelcloud.ai, x.com/qubitium -import pcre as re - import datasets +import pcre as re def preprocess(text): diff --git a/tests/tasks/mmlu/flan_cot_zeroshot/utils.py b/tests/tasks/mmlu/flan_cot_zeroshot/utils.py index e53e1db7e..163502ae8 100644 --- a/tests/tasks/mmlu/flan_cot_zeroshot/utils.py +++ b/tests/tasks/mmlu/flan_cot_zeroshot/utils.py @@ -3,10 +3,10 @@ # SPDX-License-Identifier: Apache-2.0 # Contact: qubitium@modelcloud.ai, x.com/qubitium -import pcre as re import sys import unicodedata +import pcre as re from lm_eval.filters.extraction import RegexFilter diff --git a/tests/tasks/mmlu/flan_n_shot/generative/utils.py b/tests/tasks/mmlu/flan_n_shot/generative/utils.py index e53e1db7e..163502ae8 100644 --- a/tests/tasks/mmlu/flan_n_shot/generative/utils.py +++ b/tests/tasks/mmlu/flan_n_shot/generative/utils.py @@ -3,10 +3,10 @@ # SPDX-License-Identifier: Apache-2.0 # Contact: qubitium@modelcloud.ai, x.com/qubitium -import pcre as re import sys import unicodedata +import pcre as re from lm_eval.filters.extraction import RegexFilter diff --git a/tests/test_awq.py b/tests/test_awq.py index 539db1d7d..943f76995 100644 --- a/tests/test_awq.py +++ b/tests/test_awq.py @@ -15,9 +15,9 @@ from transformers import AutoTokenizer from gptqmodel.nn_modules.qlinear.awq_gemm import AwqGEMMQuantLinear -from gptqmodel.nn_modules.qlinear.awq_marlin import AwqMarlinQuantLinear from gptqmodel.nn_modules.qlinear.awq_gemv import AwqGEMVQuantLinear from gptqmodel.nn_modules.qlinear.awq_gemv_fast import AwqGEMVFastQuantLinear +from gptqmodel.nn_modules.qlinear.awq_marlin import AwqMarlinQuantLinear from gptqmodel.quantization import FORMAT, METHOD, QUANT_CONFIG_FILENAME from gptqmodel.utils.torch import torch_empty_cache diff --git a/tests/test_torch_weight_cache.py b/tests/test_torch_weight_cache.py index 86a617b0e..c63ef6d3e 100644 --- a/tests/test_torch_weight_cache.py +++ b/tests/test_torch_weight_cache.py @@ -2,9 +2,9 @@ # SPDX-License-Identifier: Apache-2.0 # Contact: qubitium@modelcloud.ai, x.com/qubitium +import pytest import torch import torch.nn as nn -import pytest from gptqmodel.nn_modules.qlinear.lookahead import configure_default_lookahead from gptqmodel.nn_modules.qlinear.torch import TorchQuantLinear
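Note on the module_looper.py change above: finalize workers now return a FinalizeProgressInfo NamedTuple, so the drain thread no longer looks completions up in a separate future -> metadata dict. A minimal, self-contained sketch of that pattern (the worker body, labels, and pool size here are hypothetical, not the real gptqmodel finalize path):

    from concurrent.futures import ThreadPoolExecutor, as_completed
    from typing import NamedTuple, Optional

    class FinalizeProgressInfo(NamedTuple):
        module_label: Optional[str]
        process_name: str
        layer_idx: Optional[int]

    def finalize_worker(module_label: str, layer_idx: int) -> FinalizeProgressInfo:
        # ... real submodule finalization would run here (hypothetical stand-in) ...
        return FinalizeProgressInfo(module_label, "gptq", layer_idx)

    def drain(futures, total):
        completed = 0
        for future in as_completed(futures):
            info = future.result()  # progress metadata rides along with the result
            layer = f"Layer {info.layer_idx}" if info.layer_idx is not None else "Layer ?"
            completed += 1
            print(f"{layer} finalize {completed}/{total}: {info.process_name}: {info.module_label or ''}")

    with ThreadPoolExecutor(max_workers=4) as pool:
        futures = [pool.submit(finalize_worker, "mlp.down_proj", idx) for idx in range(8)]
        drain(futures, len(futures))

Because the metadata travels with each future's result, the watcher thread does not have to hold a parallel metadata mapping for the lifetime of the drain, and progress labels cannot fall out of sync with completions.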
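Note on the utils/torch.py change: when called without a device, torch_sync now walks every visible accelerator index instead of synchronizing only the current backend's default device. A rough standalone illustration of the same "sync every index, else CPU" shape, CUDA-only and simplified (not the project's multi-backend helper):

    import torch

    def sync_all_cuda_or_cpu() -> None:
        # Synchronize every CUDA device when any are visible; otherwise fall back to CPU.
        count = torch.cuda.device_count() if torch.cuda.is_available() else 0
        if count:
            for idx in range(count):
                torch.cuda.synchronize(idx)
        else:
            torch.cpu.synchronize()

This mirrors the new torch_sync() call added in _run_forward_batches_parallel after clone_module_for_devices: per the diff's own comment, it ensures any async replication/memcpy work has landed on every device before worker threads start fanning out forward passes.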