From ddc4df1d394631eebbb36ba7e0d5bf37155cebc3 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Sat, 18 Oct 2025 22:42:30 +0000 Subject: [PATCH 01/15] update --- tests/models/test_glm4_moe.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/tests/models/test_glm4_moe.py b/tests/models/test_glm4_moe.py index fbd76fc71..b30736155 100644 --- a/tests/models/test_glm4_moe.py +++ b/tests/models/test_glm4_moe.py @@ -4,12 +4,19 @@ # Contact: qubitium@modelcloud.ai, x.com/qubitium from model_test import ModelTest - +from gptqmodel.utils.eval import EVAL class TestGlm4Moe(ModelTest): NATIVE_MODEL_ID = "/monster/data/_ci_/GLM-4.5-Air/" - NATIVE_ARC_CHALLENGE_ACC = 0.3805 - NATIVE_ARC_CHALLENGE_ACC_NORM = 0.4078 + EVAL_TASKS = { + EVAL.LM_EVAL.ARC_CHALLENGE: { + "acc": {"value": 0.5026, "floor_pct": 0.04}, + "acc_norm": {"value": 0.5171, "floor_pct": 0.04}, + }, + EVAL.LM_EVAL.MMLU: { + "acc": {"value": 0.6362, "floor_pct": 0.04}, + }, + } def test_glm4moe(self): self.quant_lm_eval() From a5d8d7091db6e1ccc354414b1f34b48c706f8c9d Mon Sep 17 00:00:00 2001 From: Qubitium Date: Sun, 19 Oct 2025 03:36:22 +0000 Subject: [PATCH 02/15] reuse streams --- gptqmodel/utils/stream.py | 28 +++++++++++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/gptqmodel/utils/stream.py b/gptqmodel/utils/stream.py index 0cba0a7bc..16a0327d8 100644 --- a/gptqmodel/utils/stream.py +++ b/gptqmodel/utils/stream.py @@ -65,6 +65,26 @@ def _build_stream_pool() -> Tuple[DeviceThreadPool, Dict[str, str]]: STREAM_DEVICE_POOL, _ACCELERATOR_ALIAS_MAP = _build_stream_pool() +_STREAM_CACHE_LOCK = threading.RLock() +_CUDA_COPY_STREAMS: Dict[int, torch.cuda.Stream] = {} + + +def _resolve_device_index(device: torch.device) -> int: + index = device.index + if index is not None: + return index + return torch.cuda.current_device() + +# reuse streams instead of creating tons of new streams +def _get_cached_copy_stream(device: torch.device) -> torch.cuda.Stream: + idx = _resolve_device_index(device) + with _STREAM_CACHE_LOCK: + stream = _CUDA_COPY_STREAMS.get(idx) + if stream is None: + stream = torch.cuda.Stream(device=torch.device("cuda", idx)) + _CUDA_COPY_STREAMS[idx] = stream + return stream + def _queue_key_for_device(device: Optional[torch.device]) -> str: if device is None: @@ -168,6 +188,12 @@ def stream_tensor_dict_to_cpu( if not filtered: return {} + # sync copy + # host_map = {name: tensor.detach().to("cpu") for name, tensor in filtered.items()} + # with state_lock: + # store_callback(host_map) + # return host_map + first = next(iter(filtered.values())) if first.device.type != "cuda" or not torch.cuda.is_available(): @@ -180,7 +206,7 @@ def stream_tensor_dict_to_cpu( copy_device = first.device compute_stream = torch.cuda.current_stream(device=copy_device) - copy_stream = torch.cuda.Stream(device=copy_device) + copy_stream = _get_cached_copy_stream(copy_device) done_event = torch.cuda.Event(enable_timing=False, blocking=False) pending_sources: List[torch.Tensor] = [] From 6637fa7bd15a5f1e95c3ae122566dd1e5314f36c Mon Sep 17 00:00:00 2001 From: Qubitium Date: Sun, 19 Oct 2025 03:45:57 +0000 Subject: [PATCH 03/15] reduce memory --- gptqmodel/quantization/gptq.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/gptqmodel/quantization/gptq.py b/gptqmodel/quantization/gptq.py index 9583f0664..59b341a15 100644 --- a/gptqmodel/quantization/gptq.py +++ b/gptqmodel/quantization/gptq.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 # Contact: qubitium@modelcloud.ai, x.com/qubitium -# adapted from @qwopqwop200 's [GPTQ-for-LLaMa](https://github.com/qwopqwop200/GPTQ-for-LLaMa/tree/cuda), which itself is based on [gptq](https://github.com/IST-DASLab/gptq) +# Based on original gptq algorithm and code from https://github.com/IST-DASLab/gptq import contextlib import math @@ -559,33 +559,32 @@ def hf_quantize( @torch.inference_mode() def hessian_inverse(self, H: torch.Tensor): - damp = self.qcfg.damp_percent - diag = torch.arange(self.columns, device=H.device) mean = torch.mean(torch.diag(H)) + + orig_diag = H.diag().clone() while 0 < damp < 1: try: - H2 = H.clone() - H2[diag, diag] += damp * mean - # TODO call to torch.linalg is not threadsafe? Porque no? Esta muy mal. - H2 = torch.linalg.cholesky(H2) + H.diagonal().add_(damp * mean) + H2 = torch.linalg.cholesky(H) Hinv = torch.linalg.cholesky(torch.cholesky_inverse(H2), upper=True) - del H, H2 + H.diagonal().copy_(orig_diag) + del H2 break except torch._C._LinAlgError as e: + H.diagonal().copy_(orig_diag) if self.qcfg.damp_auto_increment != 0: log.warn( f"Quantization: Module `{self.name}` -> Current `damp_percent = {damp:.5f}` is too low, auto-incrementing by `{self.qcfg.damp_auto_increment:.5f}`") damp += self.qcfg.damp_auto_increment else: log.warn( - "Quantization: Module `{self.name}` -> Please increase damp or nsamples for calibration data to avoid the following quant error: current damp_percent=`{damp_percent:.5f}`") + "Quantization: Module `{self.name}` -> Please increase damp or nsamples for calibration data to avoid the following quant error: current damp_percent=`{damp:.5f}`") raise e if not (0 < damp < 1): log.error( f"Quantization: Module `{self.name}` -> `damp_percent` must between 0 and 1. current is {damp}. Module cannot be correctly processed.") - # raise ValueError(f"Quantization: `damp_percent` must between 0 and 1. current is {damp}") return None, 1.0 return Hinv, damp From 6074b3e7c3d6850e1667b2f098c9bb9e1fc91e0f Mon Sep 17 00:00:00 2001 From: Qubitium Date: Sun, 19 Oct 2025 10:44:07 +0000 Subject: [PATCH 04/15] correctly tag optional glm4 moe q/k_norm modules --- gptqmodel/models/definitions/glm4_moe.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/gptqmodel/models/definitions/glm4_moe.py b/gptqmodel/models/definitions/glm4_moe.py index 59c958367..ac64d111a 100644 --- a/gptqmodel/models/definitions/glm4_moe.py +++ b/gptqmodel/models/definitions/glm4_moe.py @@ -28,7 +28,7 @@ class GLM4MoEGPTQ(BaseQModel): "#", { "input_layernorm": ("input_layernorm:!",), - "self_attn": ("q_proj:0", "k_proj:0", "v_proj:0", "o_proj:1"), + "self_attn": ("q_proj:0", "q_norm:0:!","k_proj:0", "k_norm:0:!", "v_proj:0", "o_proj:1"), "post_attention_layernorm": ("post_attention_layernorm:!",), "mlp": { "shared_experts": { @@ -36,7 +36,7 @@ class GLM4MoEGPTQ(BaseQModel): "up_proj": ("up_proj:0",), "down_proj": ("down_proj:1",), }, - "gate": ("gate:!",), + "gate": ("gate:!",), # Glm4MoeTopKRouter, ~1.6MB float32 per layer. We really do not quant to quantize this. "experts": { "#": ("gate_proj:0", "up_proj:0", "down_proj:1"), }, From 62eb84e20ec6e62a1ee26436c410f7af1b2bae47 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Sun, 19 Oct 2025 10:44:54 +0000 Subject: [PATCH 05/15] gc min interval to reject too closed grouped gc requests --- gptqmodel/utils/threadx.py | 34 +++++++++++++++++++++++++++++++++- tests/test_threadx_janitor.py | 1 + 2 files changed, 34 insertions(+), 1 deletion(-) diff --git a/gptqmodel/utils/threadx.py b/gptqmodel/utils/threadx.py index a3a64fc47..9ac0e588c 100644 --- a/gptqmodel/utils/threadx.py +++ b/gptqmodel/utils/threadx.py @@ -559,6 +559,7 @@ def __init__( empty_cache_every_n: int = 50, # <=0 disables janitor workers: Optional[Dict[str, int]] = None, # e.g. {'cpu':4, 'cuda:per':1, 'cuda:0':3} gc_debounce_seconds: float = 0.02, # absorb bursty triggers before GC + gc_min_interval_seconds: float = 1.0, # throttle janitor passes pin_cpu_workers: bool = False, pin_accelerator_workers: bool = False, ): @@ -579,6 +580,7 @@ def __init__( warmups: optional mapping from device family (e.g. 'cuda') to a callable run once after the worker activates its device. A special key 'default' applies when no family-specific warmup is found. + gc_min_interval_seconds: minimum interval between GC passes. Values <= 0 disable throttling. pin_cpu_workers: bind CPU device workers to individual CPU cores when affinity APIs are available. Defaults to False so CPU tasks may float across cores unless explicitly opt-in. @@ -633,6 +635,7 @@ def __init__( # GC dedupe/coalesce: debounce window to absorb bursty triggers; # per-device "done" watermark to skip redundant GC passes. self._gc_debounce_s = float(gc_debounce_seconds) + self._gc_min_interval_s = max(0.0, float(gc_min_interval_seconds)) self._last_gc_done_per_device: Dict[str, int] = {} self._inference_mode = bool(inference_mode) @@ -756,7 +759,10 @@ def __init__( ) self._janitor.start() if DEBUG_ON: - log.debug(f"DP-Janitor thread started (debounce={self._gc_debounce_s:.3f}s, threshold={self._empty_cache_every_n})") + log.debug( + f"DP-Janitor thread started (debounce={self._gc_debounce_s:.3f}s, " + f"min_interval={self._gc_min_interval_s:.3f}s, threshold={self._empty_cache_every_n})" + ) else: if DEBUG_ON: log.debug("DP-Janitor disabled (no accelerators or threshold <= 0)") @@ -1655,12 +1661,38 @@ def _janitor_loop(self): with self._stats_lock: current_generation = self._gc_generation last_generation = self._last_consumed_gc_generation + last_gc_ts = self._last_gc_ts if current_generation == last_generation: if DEBUG_ON: log.debug("DP-Janitor: trigger generation already consumed; skipping") continue + min_interval = self._gc_min_interval_s + if min_interval > 0.0 and last_gc_ts is not None: + elapsed = time.time() - last_gc_ts + if elapsed < min_interval: + wait_for = min_interval - elapsed + if DEBUG_ON: + log.debug( + f"DP-Janitor: last pass {elapsed * 1000:.1f}ms ago; waiting {wait_for * 1000:.1f}ms to honor min interval" + ) + if self._stop_event.wait(timeout=wait_for): + if DEBUG_ON: + log.debug("DP-Janitor: stop event set during min-interval wait; exiting") + break + if self._stop_event.is_set(): + if DEBUG_ON: + log.debug("DP-Janitor: stop event observed after min-interval wait; exiting") + break + with self._stats_lock: + current_generation = self._gc_generation + last_generation = self._last_consumed_gc_generation + if current_generation == last_generation: + if DEBUG_ON: + log.debug("DP-Janitor: no pending GC generation after min-interval wait; skipping") + continue + # Snapshot & decision try: pre = self._collect_state_snapshot() diff --git a/tests/test_threadx_janitor.py b/tests/test_threadx_janitor.py index b85708f4e..77cbd4d94 100644 --- a/tests/test_threadx_janitor.py +++ b/tests/test_threadx_janitor.py @@ -24,6 +24,7 @@ def _make_pool(): pool._auto_gc_disable_cv = threading.Condition() pool._auto_gc_disable_count = 0 pool._gc_debounce_s = 0.0 + pool._gc_min_interval_s = 0.0 pool._stats_lock = threading.Lock() pool._per_device_done = {} pool._total_done = 0 From 8166c0d2dfab3c2fb3b505f3dfde97b4f165b488 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Sun, 19 Oct 2025 12:07:37 +0000 Subject: [PATCH 06/15] fix ci model test should switch to torch if marlin fails validation for config --- tests/models/model_test.py | 61 +++++++++++++++++++++++++++++++++----- 1 file changed, 53 insertions(+), 8 deletions(-) diff --git a/tests/models/model_test.py b/tests/models/model_test.py index 3658a1212..ce1a8f067 100644 --- a/tests/models/model_test.py +++ b/tests/models/model_test.py @@ -356,13 +356,55 @@ def run_eval_tasks(self, model, backend, trust_remote_code=False): self.LOAD_BACKEND = previous_backend return task_results + def _current_load_backend(self): + effective = getattr(self, "_effective_load_backend", None) + if effective is not None and self.LOAD_BACKEND == BACKEND.MARLIN: + return effective + return self.LOAD_BACKEND + def perform_post_quant_validation(self, model_path, trust_remote_code=False): inference_records = {} eval_records = {} reuse_candidates = {} compare_backends = (BACKEND.MARLIN,) if self.FORMAT is FORMAT.GPTQ else (BACKEND.MARLIN, BACKEND.GEMM) - target_backend = self.LOAD_BACKEND + fallback_backend = None + if BACKEND.MARLIN in compare_backends: + try: + from gptqmodel.nn_modules.qlinear.marlin import MarlinQuantLinear # type: ignore + except Exception: # pragma: no cover - fallback if module unavailable + marlin_group_sizes = () + marlin_sym = () + else: + marlin_group_sizes = tuple(getattr(MarlinQuantLinear, "SUPPORTS_GROUP_SIZE", ())) + marlin_sym = tuple(getattr(MarlinQuantLinear, "SUPPORTS_SYM", ())) + + requested_group_size = getattr(self, "GROUP_SIZE", None) + requested_sym = getattr(self, "SYM", None) + + marlin_supported = True + if marlin_group_sizes and requested_group_size not in marlin_group_sizes: + marlin_supported = False + if marlin_sym and requested_sym not in marlin_sym: + marlin_supported = False + + if not marlin_supported: + fallback_backend = BACKEND.TORCH + compare_backends = tuple( + BACKEND.TORCH if backend == BACKEND.MARLIN else backend + for backend in compare_backends + ) + log.info( + f"Marlin backend unsupported for current quant config (group_size={requested_group_size}, sym={requested_sym}); " + "falling back to BACKEND.TORCH for validation." + ) + + if fallback_backend is not None and self.LOAD_BACKEND == BACKEND.MARLIN: + self._effective_load_backend = fallback_backend + else: + self._effective_load_backend = None + + target_backend = self._current_load_backend() can_reuse = target_backend not in (BACKEND.AUTO, BACKEND.AUTO_TRAINABLE) for backend in compare_backends: @@ -702,6 +744,7 @@ def quantModel(self, model_id_or_path, trust_remote_code=False, dtype="auto", ne tokenizer = model.tokenizer self._post_quant_eval_records = {} + self._effective_load_backend = None is_image_to_text_model = MODALITY.IMAGE_TO_TEXT in model.modality calibration_dataset = get_calib_dataset(model) if is_image_to_text_model else self.load_dataset(tokenizer, self.DATASET_SIZE) @@ -732,22 +775,23 @@ def quantModel(self, model_id_or_path, trust_remote_code=False, dtype="auto", ne self._print_post_quant_artifacts(path) log.info(f"Quantized Model saved to tmp dir: {path}") - target_backend = self.LOAD_BACKEND reuse_candidates, eval_records = self.perform_post_quant_validation(path, trust_remote_code=trust_remote_code) self._post_quant_eval_records = eval_records + target_backend = self._current_load_backend() q_model = reuse_candidates.pop(target_backend, None) if q_model is None: # Ensure the post-quant reload stays on a single CUDA device when available. - use_cuda_map = torch.cuda.is_available() and self.LOAD_BACKEND != BACKEND.TORCH_FUSED + use_cuda_map = torch.cuda.is_available() and target_backend != BACKEND.TORCH_FUSED if use_cuda_map: q_model = self.loadQuantModel( path, trust_remote_code=trust_remote_code, + backend=target_backend, device_map={"": "cuda:0"}, ) else: - q_model = self.loadQuantModel(path, trust_remote_code=trust_remote_code) + q_model = self.loadQuantModel(path, trust_remote_code=trust_remote_code, backend=target_backend) else: log.info(f"Reusing post-quant validation model for backend `{target_backend.name}`") @@ -781,7 +825,7 @@ def loadQuantModel(self, model_id_or_path, trust_remote_code=False, tokenizer_pa else: log.warn("flash-attn requested but not available; falling back to framework defaults") - active_backend = backend if backend is not None else self.LOAD_BACKEND + active_backend = backend if backend is not None else self._current_load_backend() default_device_map = {"": "cpu"} if active_backend == BACKEND.TORCH_FUSED else "auto" explicit_device = "device" in load_kwargs @@ -847,7 +891,8 @@ def lm_eval(self, model, apply_chat_template=False, trust_remote_code=False, del task_groups = EVAL.get_task_groups_from_tasks(task_names) for framework, tasks in task_groups.items(): - log.info(f"TEST: EVAL starting: backend = {self.LOAD_BACKEND}") + active_backend = self._current_load_backend() + log.info(f"TEST: EVAL starting: backend = {active_backend.name}") if model_path: log.info(f"Inference from model path: {model_path}") @@ -875,7 +920,7 @@ def lm_eval(self, model, apply_chat_template=False, trust_remote_code=False, del llm_backend="vllm" if self.USE_VLLM else "gptqmodel", model_args=model_args, output_path=tmp_dir, - backend=self.LOAD_BACKEND, + backend=active_backend, framework=framework, tasks=eval_tasks, apply_chat_template=apply_chat_template, @@ -953,7 +998,7 @@ def quant_lm_eval(self): self.check_kernel(self.model, self.KERNEL_INFERENCE) eval_records = getattr(self, "_post_quant_eval_records", {}) - target_backend = self.LOAD_BACKEND + target_backend = self._current_load_backend() if eval_records and len(eval_records) == 1 and target_backend in eval_records: log.info("Reusing evaluation results for backend `%s`; skipping duplicate lm_eval run", target_backend.name) task_results = eval_records[target_backend] From 5779163d045c380474895f8036c137904ed90036 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Sun, 19 Oct 2025 12:18:39 +0000 Subject: [PATCH 07/15] use glm 4.6 for test --- tests/models/test_glm4_moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/test_glm4_moe.py b/tests/models/test_glm4_moe.py index b30736155..74540ef43 100644 --- a/tests/models/test_glm4_moe.py +++ b/tests/models/test_glm4_moe.py @@ -7,7 +7,7 @@ from gptqmodel.utils.eval import EVAL class TestGlm4Moe(ModelTest): - NATIVE_MODEL_ID = "/monster/data/_ci_/GLM-4.5-Air/" + NATIVE_MODEL_ID = "/monster/data/model/GLM-4.6/" EVAL_TASKS = { EVAL.LM_EVAL.ARC_CHALLENGE: { "acc": {"value": 0.5026, "floor_pct": 0.04}, From 2130ae2ee38742fc034614066b755da8ff5b3f2e Mon Sep 17 00:00:00 2001 From: Qubitium Date: Sun, 19 Oct 2025 13:15:04 +0000 Subject: [PATCH 08/15] update scores for qwen3 next --- tests/models/test_qwen3_next.py | 39 +++++++++++++++++++-------------- 1 file changed, 23 insertions(+), 16 deletions(-) diff --git a/tests/models/test_qwen3_next.py b/tests/models/test_qwen3_next.py index 923834440..882473af6 100644 --- a/tests/models/test_qwen3_next.py +++ b/tests/models/test_qwen3_next.py @@ -6,28 +6,35 @@ from model_test import ModelTest from gptqmodel.utils.eval import EVAL - +# | Metric | MARLIN | +# |--------------------------------|----------| +# | arc_challenge :: acc,none | 0.6271 | +# | arc_challenge :: acc_norm,none | 0.6613 | +# | mmlu :: acc,none | 0.8403 | class TestQwen3Next(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/Qwen3-Next-80B-A3B-Instruct" EVAL_TASKS = { EVAL.LM_EVAL.ARC_CHALLENGE: { - "acc": {"value": 0.3900, "floor_pct": 0.04}, - "acc_norm": {"value": 0.3900, "floor_pct": 0.04}, + "acc": {"value": 0.6271, "floor_pct": 0.04}, + "acc_norm": {"value": 0.6613, "floor_pct": 0.04}, + }, + EVAL.LM_EVAL.MMLU: { + "acc": {"value": 0.8403, "floor_pct": 0.04}, }, } - TRUST_REMOTE_CODE = True - APPLY_CHAT_TEMPLATE = True - EVAL_BATCH_SIZE = 4 - V2 = False - DEBUG = True - ACT_GROUP_AWARE = True - DESC_ACT = False - DATASET_SIZE = 1024 - DATASET_SORT = "desc" - QUANT_BATCH_SIZE = 4 - CALIB_NOISE_MODE = "unseen" - CALIB_NOISE_PERCENT = 0.025 - USE_FLASH_ATTN = True + # TRUST_REMOTE_CODE = True + # APPLY_CHAT_TEMPLATE = True + # EVAL_BATCH_SIZE = 4 + # V2 = False + # DEBUG = True + # ACT_GROUP_AWARE = True + # DESC_ACT = False + # DATASET_SIZE = 1024 + # DATASET_SORT = "desc" + # QUANT_BATCH_SIZE = 4 + # CALIB_NOISE_MODE = "unseen" + # CALIB_NOISE_PERCENT = 0.025 + # USE_FLASH_ATTN = True def test_mimo(self): self.quant_lm_eval() From d6051cc4db2334c7e07c58d677f7fc6633a5802a Mon Sep 17 00:00:00 2001 From: Qubitium Date: Sun, 19 Oct 2025 13:57:11 +0000 Subject: [PATCH 09/15] granular gc per device:index --- gptqmodel/utils/threadx.py | 448 ++++++++++++++++++++++++++++++------- 1 file changed, 369 insertions(+), 79 deletions(-) diff --git a/gptqmodel/utils/threadx.py b/gptqmodel/utils/threadx.py index 9ac0e588c..c55098688 100644 --- a/gptqmodel/utils/threadx.py +++ b/gptqmodel/utils/threadx.py @@ -6,6 +6,7 @@ from __future__ import annotations import contextlib +import inspect import os import queue import sys @@ -18,6 +19,11 @@ import torch +try: + from device_smi import Device # type: ignore +except Exception: # pragma: no cover - defensive: optional dependency may be unavailable + Device = None + from .. import DEBUG_ON from ..utils.ctx import ctx from ..utils.logger import setup_logger @@ -79,6 +85,42 @@ def _mps_available() -> bool: # --------------------------- Device coercion & context helpers --------------------------- +_EMPTY_CACHE_SIGNATURE_CACHE: Dict[int, Tuple[bool, bool]] = {} + + +def _analyze_empty_cache_callable(fn: Callable[..., Any]) -> Tuple[bool, bool]: + """ + Inspect an empty_cache callable and determine whether it accepts a `device` + keyword argument or at least one positional argument. Results are memoized. + """ + cache_key = id(fn) + cached = _EMPTY_CACHE_SIGNATURE_CACHE.get(cache_key) + if cached is not None: + return cached + + supports_kw = False + supports_pos = False + try: + sig = inspect.signature(fn) + except (TypeError, ValueError): + _EMPTY_CACHE_SIGNATURE_CACHE[cache_key] = (supports_kw, supports_pos) + return supports_kw, supports_pos + + for param in sig.parameters.values(): + if param.kind == inspect.Parameter.VAR_KEYWORD: + supports_kw = True + elif param.kind == inspect.Parameter.VAR_POSITIONAL: + supports_pos = True + elif param.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD): + supports_pos = True + if param.name == "device": + supports_kw = True + elif param.kind == inspect.Parameter.KEYWORD_ONLY and param.name == "device": + supports_kw = True + + _EMPTY_CACHE_SIGNATURE_CACHE[cache_key] = (supports_kw, supports_pos) + return supports_kw, supports_pos + def _coerce_device(d: DeviceLike) -> torch.device: """ Convert a DeviceLike into a concrete torch.device. For integers, we @@ -637,6 +679,16 @@ def __init__( self._gc_debounce_s = float(gc_debounce_seconds) self._gc_min_interval_s = max(0.0, float(gc_min_interval_seconds)) self._last_gc_done_per_device: Dict[str, int] = {} + # Physical-device GC bookkeeping (per accelerator index). + self._gc_done_physical: Dict[str, int] = {} + self._last_gc_done_physical: Dict[str, int] = {} + self._gc_pending_physical: Set[str] = set() + self._physical_children: Dict[str, Set[str]] = {} + + # Device-SMI handles are created lazily for GC logging. + self._device_smi_lock = threading.Lock() + self._device_smi_handles: Dict[str, Any] = {} + self._device_smi_failures: Set[str] = set() self._inference_mode = bool(inference_mode) self._worker_warmups = ( @@ -695,6 +747,10 @@ def __init__( self._inflight[key] = 0 self._inflight_cv[key] = threading.Condition() self._last_gc_done_per_device[key] = 0 + self._physical_children[key] = {key} + if dev.type in ("cuda", "xpu", "mps"): + self._gc_done_physical[key] = 0 + self._last_gc_done_physical[key] = 0 n_workers = self._resolve_workers_for_device(dev, base_workers) group: List[_DeviceWorker] = [] @@ -730,6 +786,7 @@ def __init__( self._inflight[v_key] = 0 self._inflight_cv[v_key] = threading.Condition() self._last_gc_done_per_device[v_key] = 0 + self._physical_children.setdefault(parent_key, set()).add(v_key) alias_group: List[_DeviceWorker] = [] for wid in range(limit): @@ -1075,6 +1132,15 @@ def shutdown(self, wait: bool = True): for w in snapshot: w.join() + if hasattr(self, "_device_smi_handles"): + with self._device_smi_lock: + for handle in list(self._device_smi_handles.values()): + try: + handle.close() + except Exception: + pass + self._device_smi_handles.clear() + if DEBUG_ON: log.debug("DeviceThreadPool shutdown complete") @contextlib.contextmanager @@ -1380,6 +1446,153 @@ def _resolve_scope_to_keys(self, scope: Optional[Union[str, DeviceLike, Iterable return self._normalize_scope_to_keys([scope]) return self._normalize_scope_to_keys(scope) + def _physical_key(self, key: str) -> str: + """ + Map a logical worker/alias key back to its physical device key. + """ + return getattr(self, "_virtual_to_parent", {}).get(key, key) + + def _invoke_empty_cache(self, fn: Callable[..., Any], dev: torch.device) -> None: + """ + Call an empty_cache-like callable, preferring a `device` argument when + supported and falling back to positional or zero-arg variants. + """ + supports_kw, supports_pos = _analyze_empty_cache_callable(fn) + if supports_kw: + try: + fn(device=dev) + return + except TypeError: + if DEBUG_ON: + log.debug("empty_cache callable rejected keyword arg; retrying positional (%s)", fn) + if supports_pos: + try: + fn(dev) + return + except TypeError: + if DEBUG_ON: + log.debug("empty_cache callable rejected positional arg; retrying no-arg (%s)", fn) + fn() + + def _run_empty_cache_for_device(self, key: str, dev: torch.device) -> Optional[float]: + """ + Execute an empty_cache call for the given device. Returns execution time in seconds. + """ + start = time.time() + if dev.type == "cuda": + live = getattr(torch.cuda, "empty_cache", None) if hasattr(torch, "cuda") else None + use_fn = live if callable(live) else TORCH_CUDA_EMPTY_CACHE + if use_fn is None: + if DEBUG_ON: + log.debug("DP-Janitor: no empty_cache callable available for %s", key) + return None + target = dev if dev.index is not None else "cuda" + with torch.cuda.device(target): + self._invoke_empty_cache(use_fn, dev) + return time.time() - start + + if dev.type == "xpu" and hasattr(torch, "xpu"): + live = getattr(torch.xpu, "empty_cache", None) + use_fn = live if callable(live) else TORCH_XPU_EMPTY_CACHE + if use_fn is None: + if DEBUG_ON: + log.debug("DP-Janitor: no empty_cache callable available for %s", key) + return None + target = dev if dev.index is not None else "xpu" + with torch.xpu.device(target): + self._invoke_empty_cache(use_fn, dev) + return time.time() - start + + if dev.type == "mps" and hasattr(torch, "mps"): + live = getattr(torch.mps, "empty_cache", None) + use_fn = live if callable(live) else TORCH_MPS_EMPTY_CACHE + if use_fn is None: + if DEBUG_ON: + log.debug("DP-Janitor: no empty_cache callable available for %s", key) + return None + self._invoke_empty_cache(use_fn, dev) + return time.time() - start + + if DEBUG_ON: + log.debug("DP-Janitor: unsupported device type '%s' for key %s", dev.type, key) + return None + + @staticmethod + def _format_gib_value(value: float) -> str: + text = f"{value:.1f}" + if text.endswith(".0"): + text = text[:-2] + return f"{text}G" + + def _device_smi_identifier(self, dev: torch.device) -> Optional[str]: + if Device is None: + return None + if dev.type == "cuda": + idx = dev.index + if idx is None: + return None + prefix = "rocm" if getattr(torch.version, "hip", None) else "cuda" + return f"{prefix}:{idx}" + if dev.type == "xpu": + idx = dev.index + if idx is None: + return None + return f"xpu:{idx}" + return None + + def _query_device_vram_gib(self, key: str) -> Optional[float]: + if Device is None: + return None + if not hasattr(self, "_device_smi_lock"): + self._device_smi_lock = threading.Lock() + self._device_smi_handles = {} + self._device_smi_failures = set() + dev = self._devices_by_key.get(key) + if dev is None: + return None + identifier = self._device_smi_identifier(dev) + if identifier is None: + return None + + with self._device_smi_lock: + if identifier in self._device_smi_failures: + return None + handle = self._device_smi_handles.get(identifier) + if handle is None: + try: + handle = Device(identifier) + except Exception: + self._device_smi_failures.add(identifier) + return None + self._device_smi_handles[identifier] = handle + + try: + metrics = handle.metrics(fast=True) + except Exception: + with self._device_smi_lock: + self._device_smi_failures.add(identifier) + stored = self._device_smi_handles.pop(identifier, None) + if stored is not None: + try: + stored.close() + except Exception: + pass + return None + + memory_used = getattr(metrics, "memory_used", None) + if memory_used is None: + return None + return float(memory_used) / (1024 ** 3) + + def _format_vram_summary(self, physical_keys: Iterable[str]) -> str: + readings: List[str] = [] + for key in physical_keys: + value = self._query_device_vram_gib(key) + if value is None: + continue + readings.append(f"{key}={self._format_gib_value(value)}") + return ", ".join(readings) if readings else "n/a" + # ---- inflight & completion accounting ---- def _mark_scheduled(self, key: str) -> None: @@ -1412,20 +1625,45 @@ def _on_task_finished(self, key: str) -> None: Called at the end of every task (success or failure). Updates counters and signals the janitor if the per-device threshold is reached. """ + if not hasattr(self, "_gc_done_physical"): + self._gc_done_physical = {} + if not hasattr(self, "_gc_pending_physical"): + self._gc_pending_physical = set() + if not hasattr(self, "_last_gc_done_physical"): + self._last_gc_done_physical = {} + if not hasattr(self, "_physical_children"): + self._physical_children = {} + self._mark_finished(key) trigger_gc = False with self._stats_lock: self._per_device_done[key] = self._per_device_done.get(key, 0) + 1 self._total_done += 1 - dev_type = self._devices_by_key[key].type - if self._empty_cache_every_n > 0 and dev_type in ("cuda", "xpu", "mps"): - n = self._per_device_done[key] - if n % self._empty_cache_every_n == 0: + dev = self._devices_by_key.get(key) + if ( + dev is not None + and self._empty_cache_every_n > 0 + and dev.type in ("cuda", "xpu", "mps") + ): + physical_key = self._physical_key(key) + current = self._gc_done_physical.get(physical_key, 0) + 1 + self._gc_done_physical[physical_key] = current + if current % self._empty_cache_every_n == 0: + if physical_key not in self._gc_pending_physical: + self._gc_pending_physical.add(physical_key) + self._gc_generation += 1 trigger_gc = True - self._gc_generation += 1 if DEBUG_ON: - log.debug(f"GC trigger set by {key}: per_device_done={n} threshold={self._empty_cache_every_n} total_done={self._total_done}") + log.debug( + "GC trigger set by %s (physical=%s): per_physical_done=%d threshold=%d total_done=%d pending=%s", + key, + physical_key, + current, + self._empty_cache_every_n, + self._total_done, + sorted(self._gc_pending_physical), + ) if trigger_gc: self._gc_event.set() @@ -1489,9 +1727,17 @@ def _collect_state_snapshot(self) -> Dict[str, Any]: idx = "" if dev.index is None else str(dev.index) meta[k] = {"type": dev.type, "index": idx} + physical_children = getattr(self, "_physical_children", {}) + per_done_physical: Dict[str, int] = {} + for phys_key, members in physical_children.items(): + per_done_physical[phys_key] = sum(per_done.get(member, 0) for member in members) + + pending_gc = sorted(getattr(self, "_gc_pending_physical", set())) + snap: Dict[str, Any] = { "devices": sorted(self._devices_by_key.keys()), "per_done": per_done, + "per_done_physical": per_done_physical, "total_done": total_done, "threshold": threshold, "inflight": inflight, @@ -1504,6 +1750,7 @@ def _collect_state_snapshot(self) -> Dict[str, Any]: "gc_generation_consumed": int(self._last_consumed_gc_generation), "last_gc_ts": self._last_gc_ts, "now": time.time(), + "pending_gc": pending_gc, } return snap @@ -1584,12 +1831,13 @@ def _should_run_gc_from_snapshot(self, snap: Dict[str, Any]) -> bool: thr = snap["threshold"] if thr <= 0: return False - for k in snap["devices"]: - dev_type = snap["meta"][k]["type"] - if dev_type not in ("cuda", "xpu", "mps"): - continue - done_now = snap["per_done"].get(k, 0) - done_prev = self._last_gc_done_per_device.get(k, 0) + pending = snap.get("pending_gc") or [] + if pending: + return True + per_done_physical = snap.get("per_done_physical") or {} + last_done_physical = getattr(self, "_last_gc_done_physical", {}) + for phys_key, done_now in per_done_physical.items(): + done_prev = last_done_physical.get(phys_key, 0) if done_now - done_prev >= thr: return True return False @@ -1600,20 +1848,34 @@ def _update_gc_watermarks(self, snap_after: Dict[str, Any]) -> None: before a subsequent pass is allowed. """ threshold = int(self._empty_cache_every_n) - for k in snap_after["devices"]: - done = snap_after["per_done"].get(k, 0) - if threshold <= 0: - self._last_gc_done_per_device[k] = done - continue + per_done_physical = snap_after.get("per_done_physical") or {} + per_done = snap_after.get("per_done") or {} + meta = snap_after.get("meta") or {} + processed = snap_after.get("_gc_processed_devices") + if processed is None: + processed_iter = per_done_physical.keys() + else: + processed_iter = processed - meta = snap_after.get("meta", {}).get(k, {}) - dev_type = meta.get("type") - if dev_type in ("cuda", "xpu", "mps"): - # Reset to the last completed threshold bucket so GC triggers - # again after another `threshold` tasks on this accelerator. - self._last_gc_done_per_device[k] = done - (done % threshold) + for phys_key in processed_iter: + done_phys = per_done_physical.get(phys_key) + if done_phys is None: + continue + if threshold <= 0: + self._last_gc_done_physical[phys_key] = done_phys else: - self._last_gc_done_per_device[k] = done + self._last_gc_done_physical[phys_key] = done_phys - (done_phys % threshold) + + members = self._physical_children.get(phys_key, {phys_key}) + for member in members: + done_member = per_done.get(member) + if done_member is None: + continue + dev_type = meta.get(member, {}).get("type") + if threshold <= 0 or dev_type not in ("cuda", "xpu", "mps"): + self._last_gc_done_per_device[member] = done_member + else: + self._last_gc_done_per_device[member] = done_member - (done_member % threshold) def _janitor_loop(self): """ @@ -1727,53 +1989,64 @@ def _janitor_loop(self): continue t0 = time.time() - # Optionally synchronize devices; often too slow to be worthwhile: - # self._synchronize_all() - # Per-device exclusive: acquire write lock, then call empty_cache(). - for key in sorted(self._ordered_keys): - dev = self._devices_by_key[key] - if dev.type not in ("cuda", "xpu", "mps"): - continue + try: + pre = self._collect_state_snapshot() + if DEBUG_ON: + log.debug( + "DP-Janitor: pre-snapshot taken: total_done=%s threshold=%s inflight=%s pending=%s", + pre["total_done"], + pre["threshold"], + pre["inflight"], + pre.get("pending_gc"), + ) + log.debug("GC trigger received; evaluating whether to run…") + pending_targets = [k for k in pre.get("pending_gc", []) if k in self._locks] + except Exception as e: + try: + log.warn(f"Failed to render GC pre-snapshot: {e!r}") + except Exception: + pass + pending_targets = sorted(getattr(self, "_gc_pending_physical", set())) - lk = self._locks[key] - if DEBUG_ON: log.debug(f"DP-Janitor: attempting writer lock for {key}") + if not pending_targets: + if DEBUG_ON: + log.debug("DP-Janitor: no pending devices after snapshot; marking generation %d consumed", current_generation) + with self._stats_lock: + self._last_consumed_gc_generation = max(self._last_consumed_gc_generation, current_generation) + continue + + processed_devices: List[str] = [] + skipped_devices: List[str] = [] + per_device_durations: Dict[str, float] = {} + + for key in pending_targets: + dev = self._devices_by_key.get(key) + if dev is None or dev.type not in ("cuda", "xpu", "mps"): + skipped_devices.append(key) + continue + lk = self._locks.get(key) + if lk is None: + skipped_devices.append(key) + continue + if DEBUG_ON: + log.debug("DP-Janitor: attempting writer lock for %s", key) with lk.writer(): - if DEBUG_ON: log.debug(f"DP-Janitor: acquired writer lock for {key}") + if DEBUG_ON: + log.debug("DP-Janitor: acquired writer lock for %s", key) + duration = self._run_empty_cache_for_device(key, dev) + if duration is not None: + per_device_durations[key] = duration + processed_devices.append(key) - if dev.type == "cuda": - live = getattr(torch.cuda, "empty_cache", None) if hasattr(torch, "cuda") else None - use_fn = live if callable(live) else TORCH_CUDA_EMPTY_CACHE - if DEBUG_ON: - src = "live" if use_fn is live else "hardcopy" - log.debug(f"DP-Janitor: empty_cache(cuda) using {src} on {key}") - if use_fn is not None: - with torch.cuda.device(dev.index): - use_fn() - - elif dev.type == "xpu": - live = getattr(torch.xpu, "empty_cache", None) if hasattr(torch, "xpu") else None - use_fn = live if callable(live) else TORCH_XPU_EMPTY_CACHE - if DEBUG_ON: - src = "live" if use_fn is live else "hardcopy" - log.debug(f"DP-Janitor: empty_cache(xpu) using {src} on {key}") - if use_fn is not None: - with torch.xpu.device(dev.index): - use_fn() - - elif dev.type == "mps": - live = getattr(torch.mps, "empty_cache", None) if hasattr(torch, "mps") else None - use_fn = live if callable(live) else TORCH_MPS_EMPTY_CACHE - if DEBUG_ON: - src = "live" if use_fn is live else "hardcopy" - log.debug(f"DP-Janitor: empty_cache(mps) using {src}") - if use_fn is not None: - use_fn() + if not processed_devices and DEBUG_ON: + log.debug("DP-Janitor: no eligible accelerator devices found in pending=%s", pending_targets) t1 = time.time() prev_gc_ts = self._last_gc_ts - self._gc_passes += 1 - self._last_gc_ts = t1 + if processed_devices: + self._gc_passes += 1 + self._last_gc_ts = t1 gc_timestamp = datetime.fromtimestamp(t1, tz=timezone.utc).isoformat() if prev_gc_ts is None: since_last_gc = "since last GC: n/a" @@ -1781,22 +2054,39 @@ def _janitor_loop(self): delta_s = t1 - prev_gc_ts since_last_gc = f"since last GC: {delta_s:.3f}s ({delta_s * 1000:.1f}ms)" - # Post-pass accounting & watermarks. - try: - post = self._collect_state_snapshot() - self._update_gc_watermarks(post) - log.info( - f"GC completed in {t1 - t0:.3f}s (pass #{self._gc_passes}) at {gc_timestamp}; {since_last_gc}." - ) - if DEBUG_ON: log.debug(f"DP-Janitor: post-snapshot: inflight={post['inflight']} per_done={post['per_done']}") - except Exception as e: + if processed_devices: + vram_summary = self._format_vram_summary(processed_devices) try: - log.warn(f"Failed to render GC post-snapshot: {e!r}") - except Exception: - pass - finally: - with self._stats_lock: - self._last_consumed_gc_generation = self._gc_generation + post = self._collect_state_snapshot() + post["_gc_processed_devices"] = processed_devices + self._update_gc_watermarks(post) + devices_clause = ", ".join(processed_devices) + log.info( + f"GC completed in {t1 - t0:.3f}s (pass #{self._gc_passes}) at {gc_timestamp}; devices={devices_clause}; VRAM {vram_summary}; {since_last_gc}." + ) + if DEBUG_ON: + log.debug( + "DP-Janitor: post-snapshot inflight=%s per_done=%s per_done_physical=%s durations=%s", + post["inflight"], + post["per_done"], + post.get("per_done_physical"), + per_device_durations, + ) + except Exception as e: + try: + log.warn(f"Failed to render GC post-snapshot: {e!r}") + except Exception: + pass + + with self._stats_lock: + for key in processed_devices: + self._gc_pending_physical.discard(key) + self._last_gc_done_physical[key] = self._gc_done_physical.get(key, 0) + for key in skipped_devices: + self._gc_pending_physical.discard(key) + self._last_consumed_gc_generation = max(self._last_consumed_gc_generation, current_generation) + if self._gc_pending_physical: + self._gc_event.set() # Legacy helper (not used by janitor). Kept for compatibility with any # external callers that previously expected a "clear everything" helper. From af39a3df097b3c161361e4427e8db31990cfbdf6 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Sun, 19 Oct 2025 15:26:00 +0000 Subject: [PATCH 10/15] log slow replicate --- README.md | 6 --- gptqmodel/looper/module_looper.py | 77 +++++++++++++++++++++++++++---- gptqmodel/utils/looper_helpers.py | 28 +++++++++-- 3 files changed, 93 insertions(+), 18 deletions(-) diff --git a/README.md b/README.md index 5611545b1..30be33b79 100644 --- a/README.md +++ b/README.md @@ -172,12 +172,6 @@ Native support support some of the most popular multi-modal models: -## Experimental GPTQ v2 quantization: Users have reported this mode of quantization may or may not match original GPTQ v1 implementation in terms of quality recovery. - -
- -
- ## Model Support | Model | | | | | | | | | | |-------------------|---|-------------------|---|----------------|---|----------------|---|---------------------|---| diff --git a/gptqmodel/looper/module_looper.py b/gptqmodel/looper/module_looper.py index 24467dda9..9481d6a55 100644 --- a/gptqmodel/looper/module_looper.py +++ b/gptqmodel/looper/module_looper.py @@ -579,14 +579,7 @@ def _run_forward_batches_parallel( progress_total_rows: Optional[int] = None, ) -> List[List[torch.Tensor]]: """Fan batches across device clones and preserve result ordering.""" - module_replicas = clone_module_for_devices(module, devices) - - # Ensure any async replication/memcpy ops are complete before threads start fanning out. - torch_sync() - - prev_kv = shared_kv_cache_dict.get(layer_index - 1) if reuse_kv else None - - results: Dict[int, torch.Tensor | tuple | None] = {} + effective_title = progress_title or (progress_stage or "Forward") total_batches = self._resolve_batch_total(processor.num_batches, layer_inputs) batch_row_counts = progress_rows_per_batch or self._collect_row_counts(layer_inputs) @@ -599,8 +592,74 @@ def _run_forward_batches_parallel( if total_rows <= 0 and total_batches > 0: total_rows = total_batches total_rows = max(total_rows, 1) - processed_rows = 0 stage_label = progress_stage or "Forward" + + replica_pb: "ProgressBar" | None = None + replica_title = "" + replica_completed = 0 + + if progress_pb is not None: + progress_pb.title(effective_title) + if len(devices) > 1: + replica_title = f"{stage_label}: replicate to {len(devices)} devices" + replica_pb = ( + log.pb(range(len(devices))) + .manual() + .set(show_left_steps=False) + ) + replica_pb.title(replica_title).subtitle("Staging module...").draw() + else: + device_label = str(devices[0]) if devices else "" + progress_pb.subtitle(f"{stage_label}: staging on {device_label}").draw() + + def _replica_progress(idx: int, total: int, device: torch.device, step: str) -> None: + nonlocal replica_completed + device_label = str(device) + if replica_pb is not None: + if step == "stage": + replica_pb.title(replica_title).subtitle(f"Stage {device_label}").draw() + return + if idx > replica_completed: + replica_completed = idx + replica_pb.title(replica_title).subtitle( + f"{device_label} {idx}/{total}" + ).next().draw() + else: + replica_pb.title(replica_title).subtitle( + f"{device_label} {idx}/{total}" + ).draw() + elif progress_pb is not None: + stage_msg = ( + f"{stage_label}: staging on {device_label}" + if step == "stage" + else f"{stage_label}: {step} {idx}/{total} on {device_label}" + ) + progress_pb.title(effective_title).subtitle(stage_msg).draw() + + progress_cb = _replica_progress if progress_pb is not None else None + + # Ensure any async replication/memcpy ops are complete before threads start fanning out. + torch_sync() + + try: + module_replicas = clone_module_for_devices( + module, + devices, + progress_callback=progress_cb, + ) + finally: + if replica_pb is not None: + replica_pb.close() + if progress_pb is not None: + progress_pb.title(effective_title).subtitle( + f"{stage_label} rows 0/{total_rows}" + ).draw() + + prev_kv = shared_kv_cache_dict.get(layer_index - 1) if reuse_kv else None + + results: Dict[int, torch.Tensor | tuple | None] = {} + + processed_rows = 0 device_segments: Dict[torch.device, List[int]] = {} segment_start = 0 num_devices = len(devices) diff --git a/gptqmodel/utils/looper_helpers.py b/gptqmodel/utils/looper_helpers.py index 2f9b77236..a2050b6f0 100644 --- a/gptqmodel/utils/looper_helpers.py +++ b/gptqmodel/utils/looper_helpers.py @@ -8,7 +8,7 @@ import threading import time from contextlib import contextmanager -from typing import Dict, List, Optional, Sequence, Tuple +from typing import Callable, Dict, List, Optional, Sequence, Tuple import torch from torch.nn import parallel as torch_parallel @@ -225,6 +225,7 @@ def clone_module_for_devices( devices: List[torch.device], *, clear_state_fn=clear_non_picklable_state, + progress_callback: Optional[Callable[[int, int, torch.device, str], None]] = None, ) -> Dict[torch.device, torch.nn.Module]: clones: Dict[torch.device, torch.nn.Module] = {} if not devices: @@ -234,6 +235,21 @@ def clone_module_for_devices( clone_timings: List[Tuple[str, float]] = [] overall_start = time.perf_counter() + total_targets = len(devices) + + def _notify(idx: int, device: torch.device, step: str) -> None: + if progress_callback is None: + return + try: + progress_callback(idx, total_targets, device, step) + except Exception: + if DEBUG_ON: + log.debug( + "clone_module_for_devices: progress callback failed (device=%s, step=%s)", + device, + step, + ) + def _record(name: str, start_ts: Optional[float]) -> None: if not DEBUG_ON or start_ts is None: return @@ -283,17 +299,19 @@ def _prepare_module(target_device: torch.device, step_name: str) -> None: if use_replicate: try: _prepare_module(base_device, f"stage_{base_device}") + _notify(0, base_device, "stage") replicate_start = time.perf_counter() replicas = torch_replicate(module, devices) _record("replicate", replicate_start) - for dev, replica in zip(devices, replicas): + for idx, (dev, replica) in enumerate(zip(devices, replicas), start=1): replica.eval() rehome_module_to_device(replica, dev, move_parameters=True, move_buffers=True) clear_state_fn(replica) setattr(replica, "_gptqmodule_device_hint", dev) clones[dev] = replica + _notify(idx, dev, "replica") _emit_clone_log("replicate") return clones @@ -305,14 +323,17 @@ def _prepare_module(target_device: torch.device, step_name: str) -> None: if len(devices) == 1 and devices[0].type == "cpu": _prepare_module(CPU, "stage_cpu") + _notify(0, CPU, "stage") clones[devices[0]] = module + _notify(1, devices[0], "reuse") _emit_clone_log("reuse") return clones if not use_replicate: _prepare_module(stage_device, f"stage_{stage_device}") + _notify(0, stage_device, "stage") - for dev in devices: + for idx, dev in enumerate(devices, start=1): start_ts = time.perf_counter() with _DEEPCOPY_LOCK: replica = copy.deepcopy(module) @@ -322,6 +343,7 @@ def _prepare_module(target_device: torch.device, step_name: str) -> None: setattr(replica, "_gptqmodule_device_hint", dev) clones[dev] = replica _record(str(dev), start_ts) + _notify(idx, dev, "clone") _emit_clone_log("deepcopy") return clones From e916b4be1d4eeaa1622858e2eccac2137d297d1e Mon Sep 17 00:00:00 2001 From: Qubitium Date: Sun, 19 Oct 2025 16:25:51 +0000 Subject: [PATCH 11/15] update scores --- tests/models/test_qwen2_5.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/tests/models/test_qwen2_5.py b/tests/models/test_qwen2_5.py index d4af6eb32..672103c1a 100644 --- a/tests/models/test_qwen2_5.py +++ b/tests/models/test_qwen2_5.py @@ -6,20 +6,25 @@ from model_test import ModelTest from gptqmodel.utils.eval import EVAL - +# | Metric | MARLIN | +# |--------------------------------|----------| +# | arc_challenge :: acc,none | 0.2884 | +# | arc_challenge :: acc_norm,none | 0.3208 | +# | mmlu :: acc,none | 0.442 | class TestQwen2_5(ModelTest): + GROUP_SIZE = 32 NATIVE_MODEL_ID = "/monster/data/model/Qwen2.5-0.5B-Instruct" EVAL_TASKS = { EVAL.LM_EVAL.ARC_CHALLENGE: { - "acc": {"value": 0.2722, "floor_pct": 0.04}, - "acc_norm": {"value": 0.3072, "floor_pct": 0.04}, + "acc": {"value": 0.2884, "floor_pct": 0.04}, + "acc_norm": {"value": 0.3208, "floor_pct": 0.04}, }, EVAL.LM_EVAL.MMLU: { - "acc": {"value": 0.4029, "floor_pct": 0.04}, + "acc": {"value": 0.4420, "floor_pct": 0.04}, }, } - TRUST_REMOTE_CODE = False - APPLY_CHAT_TEMPLATE = True + #TRUST_REMOTE_CODE = False + #APPLY_CHAT_TEMPLATE = True #EVAL_BATCH_SIZE = 6 def test_qwen2_5(self): From b5f918fa2b12f64eaf9fdf819641fac6cd101ad1 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Sun, 19 Oct 2025 16:41:30 +0000 Subject: [PATCH 12/15] update logbar to fix log flicker --- pyproject.toml | 2 +- requirements.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index b6e484ea4..6d2f009ca 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,7 +46,7 @@ dependencies = [ "huggingface_hub>=0.34.4", "random_word>=1.0.13", "tokenicer>=0.0.5", - "logbar>=0.1.4", + "logbar>=0.1.5", "maturin>=1.9.4", # required by safetensors and hf_transfer "datasets>=3.6.0", "pyarrow>=21.0", diff --git a/requirements.txt b/requirements.txt index a9697fcb5..e9ba019ef 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,7 +13,7 @@ hf_transfer>=0.1.9 huggingface_hub>=0.34.4 random_word>=1.0.13 tokenicer>=0.0.5 -logbar>=0.1.4 +logbar>=0.1.5 maturin>=1.9.4 datasets>=3.6.0 pyarrow>=21.0 From 86510d5839b20bd0f21fc852dcf31353021e3230 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Sun, 19 Oct 2025 16:56:35 +0000 Subject: [PATCH 13/15] update scores --- tests/models/test_glm.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/tests/models/test_glm.py b/tests/models/test_glm.py index 295f234f1..c315a0301 100644 --- a/tests/models/test_glm.py +++ b/tests/models/test_glm.py @@ -8,19 +8,20 @@ # | Metric | MARLIN | # |--------------------------------|----------| -# | arc_challenge :: acc,none | 0.5026 | -# | arc_challenge :: acc_norm,none | 0.5171 | -# | mmlu :: acc,none | 0.6362 | +# | arc_challenge :: acc,none | 0.5154 | +# | arc_challenge :: acc_norm,none | 0.535 | +# | mmlu :: acc,none | 0.6325 | class TestGlm(ModelTest): + GROUP_SIZE = 32 # real: THUDM/glm-4-9b-chat-hf NATIVE_MODEL_ID = "/monster/data/model/glm-4-9b-chat-hf" EVAL_TASKS = { EVAL.LM_EVAL.ARC_CHALLENGE: { - "acc": {"value": 0.5026, "floor_pct": 0.04}, - "acc_norm": {"value": 0.5171, "floor_pct": 0.04}, + "acc": {"value": 0.5154, "floor_pct": 0.04}, + "acc_norm": {"value": 0.5350, "floor_pct": 0.04}, }, EVAL.LM_EVAL.MMLU: { - "acc": {"value": 0.6362, "floor_pct": 0.04}, + "acc": {"value": 0.6325, "floor_pct": 0.04}, }, } From 8f189c4268f2c731cdfbca1f86cb518d32174bff Mon Sep 17 00:00:00 2001 From: Qubitium Date: Sun, 19 Oct 2025 17:20:05 +0000 Subject: [PATCH 14/15] update news --- README.md | 32 ++++++++++++++++++-------------- 1 file changed, 18 insertions(+), 14 deletions(-) diff --git a/README.md b/README.md index 30be33b79..c3115b122 100644 --- a/README.md +++ b/README.md @@ -17,25 +17,28 @@

## Latest News +* 10/20/2025 [5.0.0](https://github.com/ModelCloud/GPTQModel/releases/tag/v5.0.0): 🎉 Data-parallel quant support for `MoE` models on multi-gpu using `nogil` Python. `offload_to_disk` support enabled by +default to massively reduce `cpu` ram usage. New `Intel` and `AMD` cpu hw accelerated `TorchFused` kernel. Packing stage is now 4x faster and now inlined with quantization. `Vram` pressure for large models reduced during quantization. +`act_group_aware` is 16k+ times faster and now the default when `desc_act=False` for higher quality recovery without inference penalty of `desc_act=True`. New beta quality `AWQ` support with full `gemm`, +`gemm_fast`, `marlin` kernel support. `LFM`, `Ling`, `Qwen3 Omni` model support. Quantization is now faster with reduced vram usage. Enhanced logging support with `LogBar`. +* 09/16/2025 [4.2.5](https://github.com/ModelCloud/GPTQModel/releases/tag/v4.2.5): `hyb_act` renamed to `act_group_aware`. Removed finicky `torch` import within `setup.py`. Packing bug fix and prebuilt Pytorch 2.8 whls. +* 09/12/2025 [4.2.0](https://github.com/ModelCloud/GPTQModel/releases/tag/v4.2.0): ✨ New Models Support: Qwen3-Next, Apertus, Kimi K2, Klear, FastLLM, Nemotron H. New `fail_safe` `boolean` toggle to `.quantize()` to patch-fix non-activated `MoE` modules due to highly uneven MoE model training. Fixed LavaQwen2 compat. Patch fix GIL=0 cuda error for multi-gpu. Fix compat with autoround + new transformers. +* 09/04/2025 [4.1.0](https://github.com/ModelCloud/GPTQModel/releases/tag/v4.1.0): ✨ Meituan LongCat Flash Chat, Llama 4, GPT-OSS (BF16), and GLM-4.5-Air support. New experiemental `mock_quantization` config to skip complex computational code paths during quantization to accelerate model quant testing. +* 08/21/2025 [4.0.0](https://github.com/ModelCloud/GPTQModel/releases/tag/v4.0.0): 🎉 New Group Aware Reordering (GAR) support. New models support: Bytedance Seed-OSS, Baidu Ernie, Huawei PanGu, Gemma3, Xiaomi Mimo, Qwen 3/MoE, Falcon H1, GPT-Neo. Memory leak and multiple model compatibility fixes related to Transformers >= 4.54. Python >= 3.13t free-threading support added with near N x GPU linear scaling for quantization of MoE models and also linear N x Cpu Core scaling of packing stage. Early access Pytorch 2.8 fused-ops on Intel XPU for up to 50% speedup. + +
+ +Archived News * 10/17/2025 5.0.0-dev `main`: 👀: EoRA now multi-gpu compatible. Fixed both quality stability of multi-gpu quanta and vram usage. New LFM and Ling models support. * 09/30/2025 5.0.0-dev `main`: 👀: New Data Parallel + Multi-GPU + Python 3.13T (PYTHON_GIL=0) equals 80%+ overall quant time reduction of large MoE models vs v4.2.5. * 09/29/2025 5.0.0-dev `main`: 🎉 New Qwen3 Omni model support. AWQ Marlin kernel integrated + many disk offload, threading, and memory usage fixes. * 09/24/2025 5.0.0-dev `main`: 🎉 Up to 90% cpu mem saving for large MoE models with faster/inline packing! 26% quant time reduction for Qwen3 MoE! AWQ Marlin kernel added. AWQ Gemm loading bug fixes. `act_group_aware` now faster and auto enabled for GPTQ when `desc_act` is False for higher quality recovery. * 09/19/2025 5.0.0-dev `main`: 👀 Cpu memory saving of ~73.5% during quantization stage with new `offload_to_disk` quantization config property default to `True`. -* 09/18/2025 5.0.0-dev `main`: 🎉 AWQ quantization support! Complete refractor and simplification of model definitions in prepreation for future quantization formats. -* 09/16/2025 [4.2.5](https://github.com/ModelCloud/GPTQModel/releases/tag/v4.2.5): `hyb_act` renamed to `act_group_aware`. Removed finicky `torch` import within `setup.py`. Packing bug fix and prebuilt Pytorch 2.8 whls. -* 09/12/2025 [4.2.0](https://github.com/ModelCloud/GPTQModel/releases/tag/v4.2.0): ✨ New Models Support: Qwen3-Next, Apertus, Kimi K2, Klear, FastLLM, Nemotron H. New `fail_safe` `boolean` toggle to `.quantize()` to patch-fix non-activated `MoE` modules due to highly uneven MoE model training. Fixed LavaQwen2 compat. Patch fix GIL=0 cuda error for multi-gpu. Fix compat with autoround + new transformers. -* 09/04/2025 [4.1.0](https://github.com/ModelCloud/GPTQModel/releases/tag/v4.1.0): ✨ Meituan LongCat Flash Chat, Llama 4, GPT-OSS (BF16), and GLM-4.5-Air support. New experiemental `mock_quantization` config to skip complex computational code paths during quantization to accelerate model quant testing. -* 08/21/2025 [4.0.0](https://github.com/ModelCloud/GPTQModel/releases/tag/v4.0.0): 🎉 New Group Aware Reordering (GAR) support. New models support: Bytedance Seed-OSS, Baidu Ernie, Huawei PanGu, Gemma3, Xiaomi Mimo, Qwen 3/MoE, Falcon H1, GPT-Neo. Memory leak and multiple model compatibility fixes related to Transformers >= 4.54. Python >= 3.13t free-threading support added with near N x GPU linear scaling for quantization of MoE models and also linear N x Cpu Core scaling of packing stage. Early access Pytorch 2.8 fused-ops on Intel XPU for up to 50% speedup. +* 09/18/2025 5.0.0-dev `main`: 🎉 AWQ quantization support! Complete refractor and simplification of model definitions in prepreation for future quantization formats. * 08/19/2025 4.0.0-dev `main`: Fix quantization memory usage due to some model's incorrect application of `config.use_cache` during inference. Fixed `Transformers` >= 4.54.0 compat which changed layer forward return signature for some models. * 08/18/2025 4.0.0-dev `main`: GPT-Neo model support. Memory leak fix in error capture (stacktrace) and fixed `lm_head` quantization compatibility for many models. * 07/31/2025 4.0.0-dev `main`: New Group Aware Reordering (GAR) support and prelim Pytorch 2.8 fused-ops for Intel XPU for up to 50% speedup. * 07/03/2025 4.0.0-dev `main`: New Baidu Ernie and Huawei PanGu model support. - -
- -Archived News - * 07/02/2025 4.0.0-dev `main`: Gemma3 4B model compat fix. * 05/29/2025 4.0.0-dev `main`: Falcon H1 model support. Fixed Transformers `4.52+` compat with Qwen 2.5 VL models. * 05/19/2025 4.0.0-dev `main`: Qwen 2.5 Omni model support. @@ -281,12 +284,13 @@ model.quantize(calibration_dataset, batch_size=1) model.save(quant_path) ``` -### Quantization using GPTQ V2 +### Quantization using GPTQ V2* (Experimental, not MoE compatible, and results may not be better than v1) -Enable GPTQ v2 quantization by setting `v2 = True` for potentially higher post-quantization accuracy recovery. +Enable GPTQ v2 quantization by setting `v2 = True`. ```py -# note v2 is currently experiemental and requires 2-4x more vram to execute -# if oom on 1 gpu, please set CUDA_VISIBLE_DEVICES=0,1 to 2 gpu and gptqmodel will auto use second gpu +# Note v2 is currently experimental, not MoE compatible, and requires 2-4x more vram to execute +# We have many reports of v2 not working better or exceeding v1 so please use for testing only +# If oom on 1 gpu, please set CUDA_VISIBLE_DEVICES=0,1 to 2 gpu and gptqmodel will auto use second gpu quant_config = QuantizeConfig(bits=4, group_size=128, v2=True) ``` `Llama 3.1 8B-Instruct` quantized using `test/models/test_llama3_2.py` From 6084179c4b3fd0f33a50e5f3b6a7a877572e857c Mon Sep 17 00:00:00 2001 From: Qubitium Date: Sun, 19 Oct 2025 17:21:02 +0000 Subject: [PATCH 15/15] update scores --- tests/models/test_glm4_moe.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/models/test_glm4_moe.py b/tests/models/test_glm4_moe.py index 74540ef43..a53942fb4 100644 --- a/tests/models/test_glm4_moe.py +++ b/tests/models/test_glm4_moe.py @@ -8,6 +8,9 @@ class TestGlm4Moe(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/GLM-4.6/" + DELETE_QUANTIZED_MODEL = False + DATASET_SIZE = 512 + GROUP_SIZE = 32 EVAL_TASKS = { EVAL.LM_EVAL.ARC_CHALLENGE: { "acc": {"value": 0.5026, "floor_pct": 0.04},