From 93a92197868488f6f0da619ef0c19d2e24629fa3 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Tue, 7 Oct 2025 13:44:37 +0000 Subject: [PATCH 1/6] update Signed-off-by: Qubitium --- tests/models/test_qwen3_moe.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/models/test_qwen3_moe.py b/tests/models/test_qwen3_moe.py index 7541ca609..531b93995 100644 --- a/tests/models/test_qwen3_moe.py +++ b/tests/models/test_qwen3_moe.py @@ -18,9 +18,9 @@ class TestQwen3Moe(ModelTest): DEBUG = True ACT_GROUP_AWARE = True DESC_ACT = False - DATASET_SIZE = 1024 + DATASET_SIZE = 2048 DATASET_SORT = "desc" - QUANT_BATCH_SIZE = 1 + QUANT_BATCH_SIZE = 8 CALIB_NOISE_MODE = "unseen" CALIB_NOISE_PERCENT = 0.025 From ec29796065f15b08a87fa5e01ba5d3f13679dffb Mon Sep 17 00:00:00 2001 From: Qubitium Date: Wed, 8 Oct 2025 00:52:07 +0000 Subject: [PATCH 2/6] normalize "cuda" to "cuda:0" so we don't have potential device mismatch and duplicate modules for single gpu env Signed-off-by: Qubitium --- gptqmodel/utils/looper_helpers.py | 40 ++++++++++++++++++++++++------- 1 file changed, 31 insertions(+), 9 deletions(-) diff --git a/gptqmodel/utils/looper_helpers.py b/gptqmodel/utils/looper_helpers.py index 47262c067..a0d665463 100644 --- a/gptqmodel/utils/looper_helpers.py +++ b/gptqmodel/utils/looper_helpers.py @@ -164,16 +164,35 @@ def maybe_clear(obj: torch.nn.Module): return cleared +def _canonical_device(device: torch.device) -> torch.device: + """Return a canonical form so indexless accelerators collapse to device:0.""" + if device.type in {"cuda", "xpu", "npu"}: + index = device.index if device.index is not None else 0 + return torch.device(f"{device.type}:{index}") + return device + + def select_forward_devices(base_device: Optional[torch.device]) -> List[torch.device]: if base_device is None: return [CPU] - devices = [base_device] - base_type = base_device.type - if base_type in ("cuda", "xpu", "mps"): + devices: List[torch.device] = [] + seen: set[tuple[str, int | None]] = set() + + def _add(device: torch.device) -> None: + canonical = _canonical_device(device) + key = (canonical.type, canonical.index) + if key in seen: + return + seen.add(key) + devices.append(canonical) + + _add(base_device) + base_type = devices[0].type + if base_type in {"cuda", "xpu", "mps", "npu"}: for dev in ALL_DEVICES: - if dev.type == base_type and dev not in devices: - devices.append(dev) + if dev.type == base_type: + _add(dev) return devices @@ -181,10 +200,13 @@ def normalize_device_like(device_like) -> Optional[torch.device]: if device_like is None: return None if isinstance(device_like, torch.device): - return device_like - if hasattr(device_like, "to_torch_device"): - return device_like.to_torch_device() - return torch.device(str(device_like)) + device = device_like + elif hasattr(device_like, "to_torch_device"): + device = device_like.to_torch_device() + else: + device = torch.device(str(device_like)) + + return _canonical_device(device) def clone_module_for_devices( From af9b6cc50780509e230315897f5719567c14cdce Mon Sep 17 00:00:00 2001 From: Qubitium Date: Wed, 8 Oct 2025 09:30:42 +0000 Subject: [PATCH 3/6] fix missing module finalizer logs Signed-off-by: Qubitium --- gptqmodel/looper/module_looper.py | 41 ++++++++++++++++++++++++++----- 1 file changed, 35 insertions(+), 6 deletions(-) diff --git a/gptqmodel/looper/module_looper.py b/gptqmodel/looper/module_looper.py index 2d5b7e36c..c42eb4276 100644 --- a/gptqmodel/looper/module_looper.py +++ b/gptqmodel/looper/module_looper.py @@ -17,6 +17,7 @@ import threading import time +from concurrent.futures import as_completed from contextlib import nullcontext from typing import Dict, List, Optional, TYPE_CHECKING @@ -1241,17 +1242,40 @@ def _finalize_on_worker(process, module, idx, total, module_label, layer_idx): finalize_futures_snapshot = list(finalize_futures) - def _drain_finalize_futures(futures, finalize_pb_local, finalize_count_local, progress_bar): + if finalize_futures_snapshot: + finalize_pb.title( + f"Submodule finalize 0/{finalize_count}" + ).subtitle("Waiting for completions...").draw() + + future_metadata = { + future: (module_label, process, layer_idx) + for future, _, module_label, process, layer_idx in finalize_futures_snapshot + } + + def _drain_finalize_futures( + futures, + finalize_pb_local, + finalize_count_local, + future_metadata_local, + ): try: - for future, idx, module_label, process, layer_idx in futures: + for future in as_completed(futures): + module_label, process, layer_idx = future_metadata_local.get( + future, (None, None, None) + ) + future.result() layer_label = f"Layer {layer_idx}" if layer_idx is not None else "layer ?" display_module = module_label or "" - subtitle = f"{process.name()}: {display_module}" + processor_name = process.name() if process is not None else "" + subtitle = f"{processor_name}: {display_module}" + + finalize_pb_local.next() + completed = finalize_pb_local.step() finalize_pb_local.title( - f"{layer_label} Finalize {idx}/{finalize_count_local}" - ).subtitle(subtitle).next().draw() + f"{layer_label} Finalize {completed}/{finalize_count_local}" + ).subtitle(subtitle).draw() finally: finalize_pb_local.close() @@ -1259,7 +1283,12 @@ def _drain_finalize_futures(futures, finalize_pb_local, finalize_count_local, pr # Drain finalize futures asynchronously so the main loop can continue scheduling work. threading.Thread( target=_drain_finalize_futures, - args=(finalize_futures_snapshot, finalize_pb, finalize_count, pb), + args=( + [future for future, *_ in finalize_futures_snapshot], + finalize_pb, + finalize_count, + future_metadata, + ), name="SubmoduleFinalizeWatcher", daemon=True, ).start() From 7dbb8bff9364640e65a37046ee015a9029855119 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Wed, 8 Oct 2025 22:08:15 +0000 Subject: [PATCH 4/6] add experimental group size 256, 512, 1024 Signed-off-by: Qubitium --- gptqmodel/looper/module_looper.py | 5 +++-- gptqmodel/nn_modules/qlinear/torch.py | 2 +- gptqmodel/nn_modules/qlinear/tritonv2.py | 3 +-- gptqmodel/quantization/config.py | 4 ++-- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/gptqmodel/looper/module_looper.py b/gptqmodel/looper/module_looper.py index c42eb4276..9af8ab114 100644 --- a/gptqmodel/looper/module_looper.py +++ b/gptqmodel/looper/module_looper.py @@ -1258,6 +1258,7 @@ def _drain_finalize_futures( finalize_count_local, future_metadata_local, ): + completed_local = 0 try: for future in as_completed(futures): module_label, process, layer_idx = future_metadata_local.get( @@ -1271,10 +1272,10 @@ def _drain_finalize_futures( processor_name = process.name() if process is not None else "" subtitle = f"{processor_name}: {display_module}" + completed_local += 1 finalize_pb_local.next() - completed = finalize_pb_local.step() finalize_pb_local.title( - f"{layer_label} Finalize {completed}/{finalize_count_local}" + f"{layer_label} Finalize {completed_local}/{finalize_count_local}" ).subtitle(subtitle).draw() finally: finalize_pb_local.close() diff --git a/gptqmodel/nn_modules/qlinear/torch.py b/gptqmodel/nn_modules/qlinear/torch.py index 3e9f2af77..bb2de2755 100644 --- a/gptqmodel/nn_modules/qlinear/torch.py +++ b/gptqmodel/nn_modules/qlinear/torch.py @@ -20,7 +20,7 @@ class TorchQuantLinear(PackableQuantLinear): SUPPORTS_BITS = [2, 3, 4, 8] - SUPPORTS_GROUP_SIZE = [-1, 16, 32, 64, 128] + SUPPORTS_GROUP_SIZE = [-1, 16, 32, 64, 128, 256, 512, 1024] SUPPORTS_DESC_ACT = [True, False] SUPPORTS_SYM = [True, False] SUPPORTS_SHARDS = True diff --git a/gptqmodel/nn_modules/qlinear/tritonv2.py b/gptqmodel/nn_modules/qlinear/tritonv2.py index 0e2dbdb76..e92846de5 100644 --- a/gptqmodel/nn_modules/qlinear/tritonv2.py +++ b/gptqmodel/nn_modules/qlinear/tritonv2.py @@ -47,7 +47,7 @@ class TritonModuleMixin: class TritonV2QuantLinear(TorchQuantLinear, TritonModuleMixin): SUPPORTS_BITS = [2, 4, 8] - SUPPORTS_GROUP_SIZE = [-1, 16, 32, 64, 128] + SUPPORTS_GROUP_SIZE = [-1, 16, 32, 64, 128, 256, 512, 1024] SUPPORTS_DESC_ACT = [True, False] SUPPORTS_SYM = [True, False] SUPPORTS_SHARDS = True @@ -207,4 +207,3 @@ def triton_xpu_available(): except Exception: return False - diff --git a/gptqmodel/quantization/config.py b/gptqmodel/quantization/config.py index 1f62de7e1..0b4bb4dfc 100644 --- a/gptqmodel/quantization/config.py +++ b/gptqmodel/quantization/config.py @@ -293,10 +293,10 @@ def __post_init__(self): if key == "bits" and value not in fields_info[0].metadata["choices"]: raise ValueError(f"QuantizeConfig: Layer `{layer}` only support quantization of `{fields_info[0].metadata['choices']}` bits.") elif key == "group_size" and value != -1 and value <= 0: - raise ValueError("QuantizeConfig: `group_size` must in the value set of `[-1, 16, 32, 64, 128]`.") + raise ValueError("QuantizeConfig: `group_size` must be one of `[-1, 16, 32, 64, 128, 256, 512, 1024]`.") if self.group_size != -1 and self.group_size <= 0: - raise ValueError("QuantizeConfig: `group_size` must in the value set of `[-1, 16, 32, 64, 128]`.") + raise ValueError("QuantizeConfig: `group_size` must be one of `[-1, 16, 32, 64, 128, 256, 512, 1024]`.") if not (0 < self.damp_percent < 1): raise ValueError("QuantizeConfig: `damp_percent` must between 0 and 1.") From 681a6c1dafc4d1733853bad6f441bf253ef42864 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Wed, 8 Oct 2025 22:27:58 +0000 Subject: [PATCH 5/6] fix compat with transformers and small non-weight, buffer only modules like rotary embedd Signed-off-by: Qubitium --- gptqmodel/utils/offload.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/gptqmodel/utils/offload.py b/gptqmodel/utils/offload.py index a3c544bd5..a81271547 100644 --- a/gptqmodel/utils/offload.py +++ b/gptqmodel/utils/offload.py @@ -26,6 +26,9 @@ from .torch import CPU, META +_SMALL_MODULE_OFFLOAD_BYTES = 4 * 1024 # Skip disk writes for <4KB payloads + + # Patch fix thread unsafe accelerate.utils.modeling.clear_device_cache def _fake_clear_device_cache(garbage_collection=False): pass @@ -74,6 +77,14 @@ def _prepare_offload_directory(target_dir: str) -> None: os.makedirs(target_dir, exist_ok=True) +def _tensor_nbytes(tensor: torch.Tensor) -> int: + try: + itemsize = tensor.element_size() + except RuntimeError: + itemsize = torch.empty((), dtype=tensor.dtype).element_size() + return tensor.numel() * itemsize + + def _bundle_module_state_dict(module: nn.Module, offload_dir: str) -> dict: bundle_path = os.path.join(offload_dir, "module.safetensors") index: dict[str, dict] = {} @@ -177,6 +188,18 @@ def _offload_disk(module: nn.Module, name: str, disk_path: str = "."): module_offload_dir = os.path.join(disk_path, name) + total_bytes = 0 + try: + state_items = module.state_dict().values() + except Exception: + state_items = [] + + for tensor in state_items: + total_bytes += _tensor_nbytes(tensor) + + if total_bytes <= _SMALL_MODULE_OFFLOAD_BYTES: + return + _prepare_offload_directory(module_offload_dir) _bundle_module_state_dict(module, module_offload_dir) From 1312eb806fd19f8716d8e68e726cf5d97aa26715 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Wed, 8 Oct 2025 23:07:04 +0000 Subject: [PATCH 6/6] cleanup Signed-off-by: Qubitium --- gptqmodel/models/auto.py | 3 ++- gptqmodel/models/definitions/__init__.py | 1 + gptqmodel/models/definitions/qwen2.py | 12 ++++++++++++ gptqmodel/models/loader.py | 2 +- gptqmodel/nn_modules/qlinear/__init__.py | 2 +- gptqmodel/utils/logger.py | 3 +-- gptqmodel/utils/model.py | 5 ++--- tests/models/test_qwen2_5.py | 8 ++++---- tests/pytest.ini | 1 + 9 files changed, 25 insertions(+), 12 deletions(-) create mode 100644 gptqmodel/models/definitions/qwen2.py diff --git a/gptqmodel/models/auto.py b/gptqmodel/models/auto.py index 25b114647..eef9d165e 100644 --- a/gptqmodel/models/auto.py +++ b/gptqmodel/models/auto.py @@ -105,6 +105,7 @@ from .definitions.phi3 import Phi3QModel, PhiMoEGPTQForCausalLM # noqa: E402 from .definitions.phi4 import Phi4MMGPTQ # noqa: E402 from .definitions.qwen import QwenQModel # noqa: E402 +from .definitions.qwen2 import Qwen2QModel # noqa: E402 from .definitions.qwen2_5_omni import Qwen2_5_OmniGPTQ from .definitions.qwen2_5_vl import Qwen2_5_VLQModel # noqa: E402 from .definitions.qwen2_moe import Qwen2MoeQModel # noqa: E402 @@ -162,7 +163,7 @@ "stablelm": LlamaQModel, # 100% llama clone "starcoder2": Starcoder2QModel, "mixtral": MixtralQModel, - "qwen2": LlamaQModel, # 100% llama clone + "qwen2": Qwen2QModel, "qwen3": Qwen3QModel, "longllama": LlamaQModel, # 100% llama clone "gemma": LlamaQModel, # 100% llama clone diff --git a/gptqmodel/models/definitions/__init__.py b/gptqmodel/models/definitions/__init__.py index 7072e254d..f5f325cc1 100644 --- a/gptqmodel/models/definitions/__init__.py +++ b/gptqmodel/models/definitions/__init__.py @@ -47,6 +47,7 @@ from .phi import PhiQModel from .phi3 import Phi3QModel from .qwen import QwenQModel +from .qwen2 import Qwen2QModel from .qwen2_5_vl import Qwen2_5_VLQModel from .qwen2_moe import Qwen2MoeQModel from .qwen2_vl import Qwen2VLQModel diff --git a/gptqmodel/models/definitions/qwen2.py b/gptqmodel/models/definitions/qwen2.py new file mode 100644 index 000000000..14539768e --- /dev/null +++ b/gptqmodel/models/definitions/qwen2.py @@ -0,0 +1,12 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +from __future__ import annotations + +from .llama import LlamaQModel + + +class Qwen2QModel(LlamaQModel): + pass diff --git a/gptqmodel/models/loader.py b/gptqmodel/models/loader.py index 0a7c8becf..896e7ee02 100644 --- a/gptqmodel/models/loader.py +++ b/gptqmodel/models/loader.py @@ -104,7 +104,7 @@ def check_versions(model_class, requirements: List[str]): def get_model_local_path(pretrained_model_id_or_path, **kwargs): is_local = os.path.isdir(pretrained_model_id_or_path) if is_local: - return pretrained_model_id_or_path + return os.path.normpath(pretrained_model_id_or_path) else: # Clone kwargs before modifying download_kwargs = kwargs.copy() diff --git a/gptqmodel/nn_modules/qlinear/__init__.py b/gptqmodel/nn_modules/qlinear/__init__.py index 636362f10..5cedf9e21 100644 --- a/gptqmodel/nn_modules/qlinear/__init__.py +++ b/gptqmodel/nn_modules/qlinear/__init__.py @@ -628,7 +628,7 @@ def _process_block(i0: int, i1: int): # ---------- schedule blocks across a thread pool ---------- starts = list(range(0, in_features, block_in)) ranges = [(i0, min(i0 + block_in, in_features)) for i0 in starts] - total_blocks = len(ranges) + len(ranges) # TODO FIX ME...threads safety issue with threaded block work workers_eff = 1 # max(1, min(workers, total_blocks)) diff --git a/gptqmodel/utils/logger.py b/gptqmodel/utils/logger.py index 4d7cec3a0..d87815545 100644 --- a/gptqmodel/utils/logger.py +++ b/gptqmodel/utils/logger.py @@ -26,10 +26,9 @@ def log_time_block( if logger is None: logger = setup_logger() - label = block_name if not module_name else f"{module_name}: {block_name}" start = time.perf_counter() try: yield finally: - duration = time.perf_counter() - start + time.perf_counter() - start #logger.info(f"[time] {label} took {duration:.3f}s") diff --git a/gptqmodel/utils/model.py b/gptqmodel/utils/model.py index 94696d0c5..7eab07662 100644 --- a/gptqmodel/utils/model.py +++ b/gptqmodel/utils/model.py @@ -22,7 +22,6 @@ from typing import Any, Dict, List, Optional, Tuple, Type, Union import accelerate -import threadpoolctl as tctl import torch import torch.nn as nn import transformers @@ -57,7 +56,7 @@ from .ctx import ctx from .device import get_device from .importer import select_quant_linear -from .logger import setup_logger, log_time_block +from .logger import log_time_block, setup_logger from .torch import HAS_CUDA, torch_empty_cache, torch_new_stream_ctx @@ -597,7 +596,7 @@ def convert_gptq_v1_to_v2_format( # Limit thread usage to avoid auto-parallizataion regression # with tctl.threadpool_limits(limits=1): - t = time.time() + time.time() log.info( f"Format: Converting `{FORMAT_FIELD_CHECKPOINT}` from `{FORMAT.GPTQ}` to internal `{FORMAT.GPTQ_V2}`.") diff --git a/tests/models/test_qwen2_5.py b/tests/models/test_qwen2_5.py index 876a68fe6..6dc107995 100644 --- a/tests/models/test_qwen2_5.py +++ b/tests/models/test_qwen2_5.py @@ -8,12 +8,12 @@ class TestQwen2_5(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/Qwen2.5-0.5B-Instruct" - QUANT_ARC_MAX_DELTA_FLOOR_PERCENT = 0.2 - NATIVE_ARC_CHALLENGE_ACC = 0.2739 - NATIVE_ARC_CHALLENGE_ACC_NORM = 0.3055 + QUANT_ARC_MAX_DELTA_FLOOR_PERCENT = 0.05 + NATIVE_ARC_CHALLENGE_ACC = 0.2705 + NATIVE_ARC_CHALLENGE_ACC_NORM = 0.3063 TRUST_REMOTE_CODE = False APPLY_CHAT_TEMPLATE = True - EVAL_BATCH_SIZE = 6 + #EVAL_BATCH_SIZE = 6 def test_qwen2_5(self): self.quant_lm_eval() diff --git a/tests/pytest.ini b/tests/pytest.ini index b7d9de822..dfd34a073 100644 --- a/tests/pytest.ini +++ b/tests/pytest.ini @@ -4,3 +4,4 @@ log_cli=true norecursedirs = tasks evalplus_results markers = ci: CPU-only CI regression coverage for DeviceThreadPool affinity behaviour + cuda: Requires CUDA device