From ddc4df1d394631eebbb36ba7e0d5bf37155cebc3 Mon Sep 17 00:00:00 2001
From: Qubitium
Date: Sat, 18 Oct 2025 22:42:30 +0000
Subject: [PATCH 01/15] update
---
tests/models/test_glm4_moe.py | 13 ++++++++++---
1 file changed, 10 insertions(+), 3 deletions(-)
diff --git a/tests/models/test_glm4_moe.py b/tests/models/test_glm4_moe.py
index fbd76fc71..b30736155 100644
--- a/tests/models/test_glm4_moe.py
+++ b/tests/models/test_glm4_moe.py
@@ -4,12 +4,19 @@
# Contact: qubitium@modelcloud.ai, x.com/qubitium
from model_test import ModelTest
-
+from gptqmodel.utils.eval import EVAL
class TestGlm4Moe(ModelTest):
NATIVE_MODEL_ID = "/monster/data/_ci_/GLM-4.5-Air/"
- NATIVE_ARC_CHALLENGE_ACC = 0.3805
- NATIVE_ARC_CHALLENGE_ACC_NORM = 0.4078
+ EVAL_TASKS = {
+ EVAL.LM_EVAL.ARC_CHALLENGE: {
+ "acc": {"value": 0.5026, "floor_pct": 0.04},
+ "acc_norm": {"value": 0.5171, "floor_pct": 0.04},
+ },
+ EVAL.LM_EVAL.MMLU: {
+ "acc": {"value": 0.6362, "floor_pct": 0.04},
+ },
+ }
def test_glm4moe(self):
self.quant_lm_eval()
From a5d8d7091db6e1ccc354414b1f34b48c706f8c9d Mon Sep 17 00:00:00 2001
From: Qubitium
Date: Sun, 19 Oct 2025 03:36:22 +0000
Subject: [PATCH 02/15] reuse streams
---
gptqmodel/utils/stream.py | 28 +++++++++++++++++++++++++++-
1 file changed, 27 insertions(+), 1 deletion(-)
diff --git a/gptqmodel/utils/stream.py b/gptqmodel/utils/stream.py
index 0cba0a7bc..16a0327d8 100644
--- a/gptqmodel/utils/stream.py
+++ b/gptqmodel/utils/stream.py
@@ -65,6 +65,26 @@ def _build_stream_pool() -> Tuple[DeviceThreadPool, Dict[str, str]]:
STREAM_DEVICE_POOL, _ACCELERATOR_ALIAS_MAP = _build_stream_pool()
+_STREAM_CACHE_LOCK = threading.RLock()
+_CUDA_COPY_STREAMS: Dict[int, torch.cuda.Stream] = {}
+
+
+def _resolve_device_index(device: torch.device) -> int:
+ index = device.index
+ if index is not None:
+ return index
+ return torch.cuda.current_device()
+
+# reuse streams instead of creating tons of new streams
+def _get_cached_copy_stream(device: torch.device) -> torch.cuda.Stream:
+ idx = _resolve_device_index(device)
+ with _STREAM_CACHE_LOCK:
+ stream = _CUDA_COPY_STREAMS.get(idx)
+ if stream is None:
+ stream = torch.cuda.Stream(device=torch.device("cuda", idx))
+ _CUDA_COPY_STREAMS[idx] = stream
+ return stream
+
def _queue_key_for_device(device: Optional[torch.device]) -> str:
if device is None:
@@ -168,6 +188,12 @@ def stream_tensor_dict_to_cpu(
if not filtered:
return {}
+ # sync copy
+ # host_map = {name: tensor.detach().to("cpu") for name, tensor in filtered.items()}
+ # with state_lock:
+ # store_callback(host_map)
+ # return host_map
+
first = next(iter(filtered.values()))
if first.device.type != "cuda" or not torch.cuda.is_available():
@@ -180,7 +206,7 @@ def stream_tensor_dict_to_cpu(
copy_device = first.device
compute_stream = torch.cuda.current_stream(device=copy_device)
- copy_stream = torch.cuda.Stream(device=copy_device)
+ copy_stream = _get_cached_copy_stream(copy_device)
done_event = torch.cuda.Event(enable_timing=False, blocking=False)
pending_sources: List[torch.Tensor] = []
From 6637fa7bd15a5f1e95c3ae122566dd1e5314f36c Mon Sep 17 00:00:00 2001
From: Qubitium
Date: Sun, 19 Oct 2025 03:45:57 +0000
Subject: [PATCH 03/15] reduce memory
---
gptqmodel/quantization/gptq.py | 19 +++++++++----------
1 file changed, 9 insertions(+), 10 deletions(-)
diff --git a/gptqmodel/quantization/gptq.py b/gptqmodel/quantization/gptq.py
index 9583f0664..59b341a15 100644
--- a/gptqmodel/quantization/gptq.py
+++ b/gptqmodel/quantization/gptq.py
@@ -3,7 +3,7 @@
# SPDX-License-Identifier: Apache-2.0
# Contact: qubitium@modelcloud.ai, x.com/qubitium
-# adapted from @qwopqwop200 's [GPTQ-for-LLaMa](https://github.com/qwopqwop200/GPTQ-for-LLaMa/tree/cuda), which itself is based on [gptq](https://github.com/IST-DASLab/gptq)
+# Based on original gptq algorithm and code from https://github.com/IST-DASLab/gptq
import contextlib
import math
@@ -559,33 +559,32 @@ def hf_quantize(
@torch.inference_mode()
def hessian_inverse(self, H: torch.Tensor):
-
damp = self.qcfg.damp_percent
- diag = torch.arange(self.columns, device=H.device)
mean = torch.mean(torch.diag(H))
+
+ orig_diag = H.diag().clone()  # snapshot of the undamped diagonal, used to restore H below
while 0 < damp < 1:
try:
- H2 = H.clone()
- H2[diag, diag] += damp * mean
- # TODO call to torch.linalg is not threadsafe? Porque no? Esta muy mal.
- H2 = torch.linalg.cholesky(H2)
+ H.diagonal().add_(damp * mean)  # damp H's diagonal in place (no full-size copy of H)
+ H2 = torch.linalg.cholesky(H)
Hinv = torch.linalg.cholesky(torch.cholesky_inverse(H2), upper=True)
- del H, H2
+ H.diagonal().copy_(orig_diag)  # undo the in-place damping before returning
+ del H2
break
except torch._C._LinAlgError as e:
+ H.diagonal().copy_(orig_diag)  # undo the in-place damping before retrying or re-raising
if self.qcfg.damp_auto_increment != 0:
log.warn(
f"Quantization: Module `{self.name}` -> Current `damp_percent = {damp:.5f}` is too low, auto-incrementing by `{self.qcfg.damp_auto_increment:.5f}`")
damp += self.qcfg.damp_auto_increment
else:
log.warn(
- "Quantization: Module `{self.name}` -> Please increase damp or nsamples for calibration data to avoid the following quant error: current damp_percent=`{damp_percent:.5f}`")
+ f"Quantization: Module `{self.name}` -> Please increase damp or nsamples for calibration data to avoid the following quant error: current damp_percent=`{damp:.5f}`")
raise e
if not (0 < damp < 1):
log.error(
f"Quantization: Module `{self.name}` -> `damp_percent` must between 0 and 1. current is {damp}. Module cannot be correctly processed.")
- # raise ValueError(f"Quantization: `damp_percent` must between 0 and 1. current is {damp}")
return None, 1.0
return Hinv, damp
From 6074b3e7c3d6850e1667b2f098c9bb9e1fc91e0f Mon Sep 17 00:00:00 2001
From: Qubitium
Date: Sun, 19 Oct 2025 10:44:07 +0000
Subject: [PATCH 04/15] correctly tag optional glm4 moe q/k_norm modules
---
gptqmodel/models/definitions/glm4_moe.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/gptqmodel/models/definitions/glm4_moe.py b/gptqmodel/models/definitions/glm4_moe.py
index 59c958367..ac64d111a 100644
--- a/gptqmodel/models/definitions/glm4_moe.py
+++ b/gptqmodel/models/definitions/glm4_moe.py
@@ -28,7 +28,7 @@ class GLM4MoEGPTQ(BaseQModel):
"#",
{
"input_layernorm": ("input_layernorm:!",),
- "self_attn": ("q_proj:0", "k_proj:0", "v_proj:0", "o_proj:1"),
+ "self_attn": ("q_proj:0", "q_norm:0:!","k_proj:0", "k_norm:0:!", "v_proj:0", "o_proj:1"),
"post_attention_layernorm": ("post_attention_layernorm:!",),
"mlp": {
"shared_experts": {
@@ -36,7 +36,7 @@ class GLM4MoEGPTQ(BaseQModel):
"up_proj": ("up_proj:0",),
"down_proj": ("down_proj:1",),
},
- "gate": ("gate:!",),
+ "gate": ("gate:!",), # Glm4MoeTopKRouter, ~1.6MB float32 per layer. We really do not want to quantize this.
"experts": {
"#": ("gate_proj:0", "up_proj:0", "down_proj:1"),
},
From 62eb84e20ec6e62a1ee26436c410f7af1b2bae47 Mon Sep 17 00:00:00 2001
From: Qubitium
Date: Sun, 19 Oct 2025 10:44:54 +0000
Subject: [PATCH 05/15] gc min interval to reject too closely grouped gc
requests
---
gptqmodel/utils/threadx.py | 34 +++++++++++++++++++++++++++++++++-
tests/test_threadx_janitor.py | 1 +
2 files changed, 34 insertions(+), 1 deletion(-)
diff --git a/gptqmodel/utils/threadx.py b/gptqmodel/utils/threadx.py
index a3a64fc47..9ac0e588c 100644
--- a/gptqmodel/utils/threadx.py
+++ b/gptqmodel/utils/threadx.py
@@ -559,6 +559,7 @@ def __init__(
empty_cache_every_n: int = 50, # <=0 disables janitor
workers: Optional[Dict[str, int]] = None, # e.g. {'cpu':4, 'cuda:per':1, 'cuda:0':3}
gc_debounce_seconds: float = 0.02, # absorb bursty triggers before GC
+ gc_min_interval_seconds: float = 1.0, # throttle janitor passes
pin_cpu_workers: bool = False,
pin_accelerator_workers: bool = False,
):
@@ -579,6 +580,7 @@ def __init__(
warmups: optional mapping from device family (e.g. 'cuda') to a callable
run once after the worker activates its device. A special key
'default' applies when no family-specific warmup is found.
+ gc_min_interval_seconds: minimum interval between GC passes. Values <= 0 disable throttling.
pin_cpu_workers: bind CPU device workers to individual CPU cores when
affinity APIs are available. Defaults to False so CPU tasks may
float across cores unless explicitly opt-in.
@@ -633,6 +635,7 @@ def __init__(
# GC dedupe/coalesce: debounce window to absorb bursty triggers;
# per-device "done" watermark to skip redundant GC passes.
self._gc_debounce_s = float(gc_debounce_seconds)
+ self._gc_min_interval_s = max(0.0, float(gc_min_interval_seconds))
self._last_gc_done_per_device: Dict[str, int] = {}
self._inference_mode = bool(inference_mode)
@@ -756,7 +759,10 @@ def __init__(
)
self._janitor.start()
if DEBUG_ON:
- log.debug(f"DP-Janitor thread started (debounce={self._gc_debounce_s:.3f}s, threshold={self._empty_cache_every_n})")
+ log.debug(
+ f"DP-Janitor thread started (debounce={self._gc_debounce_s:.3f}s, "
+ f"min_interval={self._gc_min_interval_s:.3f}s, threshold={self._empty_cache_every_n})"
+ )
else:
if DEBUG_ON:
log.debug("DP-Janitor disabled (no accelerators or threshold <= 0)")
@@ -1655,12 +1661,38 @@ def _janitor_loop(self):
with self._stats_lock:
current_generation = self._gc_generation
last_generation = self._last_consumed_gc_generation
+ last_gc_ts = self._last_gc_ts
if current_generation == last_generation:
if DEBUG_ON:
log.debug("DP-Janitor: trigger generation already consumed; skipping")
continue
+ min_interval = self._gc_min_interval_s
+ if min_interval > 0.0 and last_gc_ts is not None:
+ elapsed = time.time() - last_gc_ts
+ if elapsed < min_interval:
+ wait_for = min_interval - elapsed
+ if DEBUG_ON:
+ log.debug(
+ f"DP-Janitor: last pass {elapsed * 1000:.1f}ms ago; waiting {wait_for * 1000:.1f}ms to honor min interval"
+ )
+ if self._stop_event.wait(timeout=wait_for):
+ if DEBUG_ON:
+ log.debug("DP-Janitor: stop event set during min-interval wait; exiting")
+ break
+ if self._stop_event.is_set():
+ if DEBUG_ON:
+ log.debug("DP-Janitor: stop event observed after min-interval wait; exiting")
+ break
+ with self._stats_lock:
+ current_generation = self._gc_generation
+ last_generation = self._last_consumed_gc_generation
+ if current_generation == last_generation:
+ if DEBUG_ON:
+ log.debug("DP-Janitor: no pending GC generation after min-interval wait; skipping")
+ continue
+
# Snapshot & decision
try:
pre = self._collect_state_snapshot()
diff --git a/tests/test_threadx_janitor.py b/tests/test_threadx_janitor.py
index b85708f4e..77cbd4d94 100644
--- a/tests/test_threadx_janitor.py
+++ b/tests/test_threadx_janitor.py
@@ -24,6 +24,7 @@ def _make_pool():
pool._auto_gc_disable_cv = threading.Condition()
pool._auto_gc_disable_count = 0
pool._gc_debounce_s = 0.0
+ pool._gc_min_interval_s = 0.0
pool._stats_lock = threading.Lock()
pool._per_device_done = {}
pool._total_done = 0
From 8166c0d2dfab3c2fb3b505f3dfde97b4f165b488 Mon Sep 17 00:00:00 2001
From: Qubitium
Date: Sun, 19 Oct 2025 12:07:37 +0000
Subject: [PATCH 06/15] fix ci model test should switch to torch if marlin
fails validation for config
---
tests/models/model_test.py | 61 +++++++++++++++++++++++++++++++++-----
1 file changed, 53 insertions(+), 8 deletions(-)
diff --git a/tests/models/model_test.py b/tests/models/model_test.py
index 3658a1212..ce1a8f067 100644
--- a/tests/models/model_test.py
+++ b/tests/models/model_test.py
@@ -356,13 +356,55 @@ def run_eval_tasks(self, model, backend, trust_remote_code=False):
self.LOAD_BACKEND = previous_backend
return task_results
+ def _current_load_backend(self):
+ effective = getattr(self, "_effective_load_backend", None)
+ if effective is not None and self.LOAD_BACKEND == BACKEND.MARLIN:
+ return effective
+ return self.LOAD_BACKEND
+
def perform_post_quant_validation(self, model_path, trust_remote_code=False):
inference_records = {}
eval_records = {}
reuse_candidates = {}
compare_backends = (BACKEND.MARLIN,) if self.FORMAT is FORMAT.GPTQ else (BACKEND.MARLIN, BACKEND.GEMM)
- target_backend = self.LOAD_BACKEND
+ fallback_backend = None
+ if BACKEND.MARLIN in compare_backends:
+ try:
+ from gptqmodel.nn_modules.qlinear.marlin import MarlinQuantLinear # type: ignore
+ except Exception: # pragma: no cover - fallback if module unavailable
+ marlin_group_sizes = ()
+ marlin_sym = ()
+ else:
+ marlin_group_sizes = tuple(getattr(MarlinQuantLinear, "SUPPORTS_GROUP_SIZE", ()))
+ marlin_sym = tuple(getattr(MarlinQuantLinear, "SUPPORTS_SYM", ()))
+
+ requested_group_size = getattr(self, "GROUP_SIZE", None)
+ requested_sym = getattr(self, "SYM", None)
+
+ marlin_supported = True
+ if marlin_group_sizes and requested_group_size not in marlin_group_sizes:
+ marlin_supported = False
+ if marlin_sym and requested_sym not in marlin_sym:
+ marlin_supported = False
+
+ if not marlin_supported:
+ fallback_backend = BACKEND.TORCH
+ compare_backends = tuple(
+ BACKEND.TORCH if backend == BACKEND.MARLIN else backend
+ for backend in compare_backends
+ )
+ log.info(
+ f"Marlin backend unsupported for current quant config (group_size={requested_group_size}, sym={requested_sym}); "
+ "falling back to BACKEND.TORCH for validation."
+ )
+
+ if fallback_backend is not None and self.LOAD_BACKEND == BACKEND.MARLIN:
+ self._effective_load_backend = fallback_backend
+ else:
+ self._effective_load_backend = None
+
+ target_backend = self._current_load_backend()
can_reuse = target_backend not in (BACKEND.AUTO, BACKEND.AUTO_TRAINABLE)
for backend in compare_backends:
@@ -702,6 +744,7 @@ def quantModel(self, model_id_or_path, trust_remote_code=False, dtype="auto", ne
tokenizer = model.tokenizer
self._post_quant_eval_records = {}
+ self._effective_load_backend = None
is_image_to_text_model = MODALITY.IMAGE_TO_TEXT in model.modality
calibration_dataset = get_calib_dataset(model) if is_image_to_text_model else self.load_dataset(tokenizer, self.DATASET_SIZE)
@@ -732,22 +775,23 @@ def quantModel(self, model_id_or_path, trust_remote_code=False, dtype="auto", ne
self._print_post_quant_artifacts(path)
log.info(f"Quantized Model saved to tmp dir: {path}")
- target_backend = self.LOAD_BACKEND
reuse_candidates, eval_records = self.perform_post_quant_validation(path, trust_remote_code=trust_remote_code)
self._post_quant_eval_records = eval_records
+ target_backend = self._current_load_backend()
q_model = reuse_candidates.pop(target_backend, None)
if q_model is None:
# Ensure the post-quant reload stays on a single CUDA device when available.
- use_cuda_map = torch.cuda.is_available() and self.LOAD_BACKEND != BACKEND.TORCH_FUSED
+ use_cuda_map = torch.cuda.is_available() and target_backend != BACKEND.TORCH_FUSED
if use_cuda_map:
q_model = self.loadQuantModel(
path,
trust_remote_code=trust_remote_code,
+ backend=target_backend,
device_map={"": "cuda:0"},
)
else:
- q_model = self.loadQuantModel(path, trust_remote_code=trust_remote_code)
+ q_model = self.loadQuantModel(path, trust_remote_code=trust_remote_code, backend=target_backend)
else:
log.info(f"Reusing post-quant validation model for backend `{target_backend.name}`")
@@ -781,7 +825,7 @@ def loadQuantModel(self, model_id_or_path, trust_remote_code=False, tokenizer_pa
else:
log.warn("flash-attn requested but not available; falling back to framework defaults")
- active_backend = backend if backend is not None else self.LOAD_BACKEND
+ active_backend = backend if backend is not None else self._current_load_backend()
default_device_map = {"": "cpu"} if active_backend == BACKEND.TORCH_FUSED else "auto"
explicit_device = "device" in load_kwargs
@@ -847,7 +891,8 @@ def lm_eval(self, model, apply_chat_template=False, trust_remote_code=False, del
task_groups = EVAL.get_task_groups_from_tasks(task_names)
for framework, tasks in task_groups.items():
- log.info(f"TEST: EVAL starting: backend = {self.LOAD_BACKEND}")
+ active_backend = self._current_load_backend()
+ log.info(f"TEST: EVAL starting: backend = {active_backend.name}")
if model_path:
log.info(f"Inference from model path: {model_path}")
@@ -875,7 +920,7 @@ def lm_eval(self, model, apply_chat_template=False, trust_remote_code=False, del
llm_backend="vllm" if self.USE_VLLM else "gptqmodel",
model_args=model_args,
output_path=tmp_dir,
- backend=self.LOAD_BACKEND,
+ backend=active_backend,
framework=framework,
tasks=eval_tasks,
apply_chat_template=apply_chat_template,
@@ -953,7 +998,7 @@ def quant_lm_eval(self):
self.check_kernel(self.model, self.KERNEL_INFERENCE)
eval_records = getattr(self, "_post_quant_eval_records", {})
- target_backend = self.LOAD_BACKEND
+ target_backend = self._current_load_backend()
if eval_records and len(eval_records) == 1 and target_backend in eval_records:
log.info("Reusing evaluation results for backend `%s`; skipping duplicate lm_eval run", target_backend.name)
task_results = eval_records[target_backend]
From 5779163d045c380474895f8036c137904ed90036 Mon Sep 17 00:00:00 2001
From: Qubitium
Date: Sun, 19 Oct 2025 12:18:39 +0000
Subject: [PATCH 07/15] use glm 4.6 for test
---
tests/models/test_glm4_moe.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/tests/models/test_glm4_moe.py b/tests/models/test_glm4_moe.py
index b30736155..74540ef43 100644
--- a/tests/models/test_glm4_moe.py
+++ b/tests/models/test_glm4_moe.py
@@ -7,7 +7,7 @@
from gptqmodel.utils.eval import EVAL
class TestGlm4Moe(ModelTest):
- NATIVE_MODEL_ID = "/monster/data/_ci_/GLM-4.5-Air/"
+ NATIVE_MODEL_ID = "/monster/data/model/GLM-4.6/"
EVAL_TASKS = {
EVAL.LM_EVAL.ARC_CHALLENGE: {
"acc": {"value": 0.5026, "floor_pct": 0.04},
From 2130ae2ee38742fc034614066b755da8ff5b3f2e Mon Sep 17 00:00:00 2001
From: Qubitium
Date: Sun, 19 Oct 2025 13:15:04 +0000
Subject: [PATCH 08/15] update scores for qwen3 next
---
tests/models/test_qwen3_next.py | 39 +++++++++++++++++++--------------
1 file changed, 23 insertions(+), 16 deletions(-)
diff --git a/tests/models/test_qwen3_next.py b/tests/models/test_qwen3_next.py
index 923834440..882473af6 100644
--- a/tests/models/test_qwen3_next.py
+++ b/tests/models/test_qwen3_next.py
@@ -6,28 +6,35 @@
from model_test import ModelTest
from gptqmodel.utils.eval import EVAL
-
+# | Metric | MARLIN |
+# |--------------------------------|----------|
+# | arc_challenge :: acc,none | 0.6271 |
+# | arc_challenge :: acc_norm,none | 0.6613 |
+# | mmlu :: acc,none | 0.8403 |
class TestQwen3Next(ModelTest):
NATIVE_MODEL_ID = "/monster/data/model/Qwen3-Next-80B-A3B-Instruct"
EVAL_TASKS = {
EVAL.LM_EVAL.ARC_CHALLENGE: {
- "acc": {"value": 0.3900, "floor_pct": 0.04},
- "acc_norm": {"value": 0.3900, "floor_pct": 0.04},
+ "acc": {"value": 0.6271, "floor_pct": 0.04},
+ "acc_norm": {"value": 0.6613, "floor_pct": 0.04},
+ },
+ EVAL.LM_EVAL.MMLU: {
+ "acc": {"value": 0.8403, "floor_pct": 0.04},
},
}
- TRUST_REMOTE_CODE = True
- APPLY_CHAT_TEMPLATE = True
- EVAL_BATCH_SIZE = 4
- V2 = False
- DEBUG = True
- ACT_GROUP_AWARE = True
- DESC_ACT = False
- DATASET_SIZE = 1024
- DATASET_SORT = "desc"
- QUANT_BATCH_SIZE = 4
- CALIB_NOISE_MODE = "unseen"
- CALIB_NOISE_PERCENT = 0.025
- USE_FLASH_ATTN = True
+ # TRUST_REMOTE_CODE = True
+ # APPLY_CHAT_TEMPLATE = True
+ # EVAL_BATCH_SIZE = 4
+ # V2 = False
+ # DEBUG = True
+ # ACT_GROUP_AWARE = True
+ # DESC_ACT = False
+ # DATASET_SIZE = 1024
+ # DATASET_SORT = "desc"
+ # QUANT_BATCH_SIZE = 4
+ # CALIB_NOISE_MODE = "unseen"
+ # CALIB_NOISE_PERCENT = 0.025
+ # USE_FLASH_ATTN = True
def test_mimo(self):
self.quant_lm_eval()
From d6051cc4db2334c7e07c58d677f7fc6633a5802a Mon Sep 17 00:00:00 2001
From: Qubitium
Date: Sun, 19 Oct 2025 13:57:11 +0000
Subject: [PATCH 09/15] granular gc per device:index
---
gptqmodel/utils/threadx.py | 448 ++++++++++++++++++++++++++++++-------
1 file changed, 369 insertions(+), 79 deletions(-)
diff --git a/gptqmodel/utils/threadx.py b/gptqmodel/utils/threadx.py
index 9ac0e588c..c55098688 100644
--- a/gptqmodel/utils/threadx.py
+++ b/gptqmodel/utils/threadx.py
@@ -6,6 +6,7 @@
from __future__ import annotations
import contextlib
+import inspect
import os
import queue
import sys
@@ -18,6 +19,11 @@
import torch
+try:
+ from device_smi import Device # type: ignore
+except Exception: # pragma: no cover - defensive: optional dependency may be unavailable
+ Device = None
+
from .. import DEBUG_ON
from ..utils.ctx import ctx
from ..utils.logger import setup_logger
@@ -79,6 +85,42 @@ def _mps_available() -> bool:
# --------------------------- Device coercion & context helpers ---------------------------
+_EMPTY_CACHE_SIGNATURE_CACHE: Dict[int, Tuple[bool, bool]] = {}
+
+
+def _analyze_empty_cache_callable(fn: Callable[..., Any]) -> Tuple[bool, bool]:
+ """
+ Inspect an empty_cache callable and determine whether it accepts a `device`
+ keyword argument or at least one positional argument. Results are memoized.
+ """
+ cache_key = id(fn)
+ cached = _EMPTY_CACHE_SIGNATURE_CACHE.get(cache_key)
+ if cached is not None:
+ return cached
+
+ supports_kw = False
+ supports_pos = False
+ try:
+ sig = inspect.signature(fn)
+ except (TypeError, ValueError):
+ _EMPTY_CACHE_SIGNATURE_CACHE[cache_key] = (supports_kw, supports_pos)
+ return supports_kw, supports_pos
+
+ for param in sig.parameters.values():
+ if param.kind == inspect.Parameter.VAR_KEYWORD:
+ supports_kw = True
+ elif param.kind == inspect.Parameter.VAR_POSITIONAL:
+ supports_pos = True
+ elif param.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD):
+ supports_pos = True
+ if param.name == "device":
+ supports_kw = True
+ elif param.kind == inspect.Parameter.KEYWORD_ONLY and param.name == "device":
+ supports_kw = True
+
+ _EMPTY_CACHE_SIGNATURE_CACHE[cache_key] = (supports_kw, supports_pos)
+ return supports_kw, supports_pos
+
def _coerce_device(d: DeviceLike) -> torch.device:
"""
Convert a DeviceLike into a concrete torch.device. For integers, we
@@ -637,6 +679,16 @@ def __init__(
self._gc_debounce_s = float(gc_debounce_seconds)
self._gc_min_interval_s = max(0.0, float(gc_min_interval_seconds))
self._last_gc_done_per_device: Dict[str, int] = {}
+ # Physical-device GC bookkeeping (per accelerator index).
+ self._gc_done_physical: Dict[str, int] = {}
+ self._last_gc_done_physical: Dict[str, int] = {}
+ self._gc_pending_physical: Set[str] = set()
+ self._physical_children: Dict[str, Set[str]] = {}
+
+ # Device-SMI handles are created lazily for GC logging.
+ self._device_smi_lock = threading.Lock()
+ self._device_smi_handles: Dict[str, Any] = {}
+ self._device_smi_failures: Set[str] = set()
self._inference_mode = bool(inference_mode)
self._worker_warmups = (
@@ -695,6 +747,10 @@ def __init__(
self._inflight[key] = 0
self._inflight_cv[key] = threading.Condition()
self._last_gc_done_per_device[key] = 0
+ self._physical_children[key] = {key}
+ if dev.type in ("cuda", "xpu", "mps"):
+ self._gc_done_physical[key] = 0
+ self._last_gc_done_physical[key] = 0
n_workers = self._resolve_workers_for_device(dev, base_workers)
group: List[_DeviceWorker] = []
@@ -730,6 +786,7 @@ def __init__(
self._inflight[v_key] = 0
self._inflight_cv[v_key] = threading.Condition()
self._last_gc_done_per_device[v_key] = 0
+ self._physical_children.setdefault(parent_key, set()).add(v_key)
alias_group: List[_DeviceWorker] = []
for wid in range(limit):
@@ -1075,6 +1132,15 @@ def shutdown(self, wait: bool = True):
for w in snapshot:
w.join()
+ if hasattr(self, "_device_smi_handles"):
+ with self._device_smi_lock:
+ for handle in list(self._device_smi_handles.values()):
+ try:
+ handle.close()
+ except Exception:
+ pass
+ self._device_smi_handles.clear()
+
if DEBUG_ON: log.debug("DeviceThreadPool shutdown complete")
@contextlib.contextmanager
@@ -1380,6 +1446,153 @@ def _resolve_scope_to_keys(self, scope: Optional[Union[str, DeviceLike, Iterable
return self._normalize_scope_to_keys([scope])
return self._normalize_scope_to_keys(scope)
+ def _physical_key(self, key: str) -> str:
+ """
+ Map a logical worker/alias key back to its physical device key.
+ """
+ return getattr(self, "_virtual_to_parent", {}).get(key, key)
+
+ def _invoke_empty_cache(self, fn: Callable[..., Any], dev: torch.device) -> None:
+ """
+ Call an empty_cache-like callable, preferring a `device` argument when
+ supported and falling back to positional or zero-arg variants.
+ """
+ supports_kw, supports_pos = _analyze_empty_cache_callable(fn)
+ if supports_kw:
+ try:
+ fn(device=dev)
+ return
+ except TypeError:
+ if DEBUG_ON:
+ log.debug("empty_cache callable rejected keyword arg; retrying positional (%s)", fn)
+ if supports_pos:
+ try:
+ fn(dev)
+ return
+ except TypeError:
+ if DEBUG_ON:
+ log.debug("empty_cache callable rejected positional arg; retrying no-arg (%s)", fn)
+ fn()
+
+ def _run_empty_cache_for_device(self, key: str, dev: torch.device) -> Optional[float]:
+ """
+ Execute an empty_cache call for the given device. Returns execution time in seconds.
+ """
+ start = time.time()
+ if dev.type == "cuda":
+ live = getattr(torch.cuda, "empty_cache", None) if hasattr(torch, "cuda") else None
+ use_fn = live if callable(live) else TORCH_CUDA_EMPTY_CACHE
+ if use_fn is None:
+ if DEBUG_ON:
+ log.debug("DP-Janitor: no empty_cache callable available for %s", key)
+ return None
+ target = dev if dev.index is not None else "cuda"
+ with torch.cuda.device(target):
+ self._invoke_empty_cache(use_fn, dev)
+ return time.time() - start
+
+ if dev.type == "xpu" and hasattr(torch, "xpu"):
+ live = getattr(torch.xpu, "empty_cache", None)
+ use_fn = live if callable(live) else TORCH_XPU_EMPTY_CACHE
+ if use_fn is None:
+ if DEBUG_ON:
+ log.debug("DP-Janitor: no empty_cache callable available for %s", key)
+ return None
+ target = dev if dev.index is not None else "xpu"
+ with torch.xpu.device(target):
+ self._invoke_empty_cache(use_fn, dev)
+ return time.time() - start
+
+ if dev.type == "mps" and hasattr(torch, "mps"):
+ live = getattr(torch.mps, "empty_cache", None)
+ use_fn = live if callable(live) else TORCH_MPS_EMPTY_CACHE
+ if use_fn is None:
+ if DEBUG_ON:
+ log.debug("DP-Janitor: no empty_cache callable available for %s", key)
+ return None
+ self._invoke_empty_cache(use_fn, dev)
+ return time.time() - start
+
+ if DEBUG_ON:
+ log.debug("DP-Janitor: unsupported device type '%s' for key %s", dev.type, key)
+ return None
+
+ @staticmethod
+ def _format_gib_value(value: float) -> str:
+ text = f"{value:.1f}"
+ if text.endswith(".0"):
+ text = text[:-2]
+ return f"{text}G"
+
+ def _device_smi_identifier(self, dev: torch.device) -> Optional[str]:
+ if Device is None:
+ return None
+ if dev.type == "cuda":
+ idx = dev.index
+ if idx is None:
+ return None
+ prefix = "rocm" if getattr(torch.version, "hip", None) else "cuda"
+ return f"{prefix}:{idx}"
+ if dev.type == "xpu":
+ idx = dev.index
+ if idx is None:
+ return None
+ return f"xpu:{idx}"
+ return None
+
+ def _query_device_vram_gib(self, key: str) -> Optional[float]:
+ if Device is None:
+ return None
+ if not hasattr(self, "_device_smi_lock"):
+ self._device_smi_lock = threading.Lock()
+ self._device_smi_handles = {}
+ self._device_smi_failures = set()
+ dev = self._devices_by_key.get(key)
+ if dev is None:
+ return None
+ identifier = self._device_smi_identifier(dev)
+ if identifier is None:
+ return None
+
+ with self._device_smi_lock:
+ if identifier in self._device_smi_failures:
+ return None
+ handle = self._device_smi_handles.get(identifier)
+ if handle is None:
+ try:
+ handle = Device(identifier)
+ except Exception:
+ self._device_smi_failures.add(identifier)
+ return None
+ self._device_smi_handles[identifier] = handle
+
+ try:
+ metrics = handle.metrics(fast=True)
+ except Exception:
+ with self._device_smi_lock:
+ self._device_smi_failures.add(identifier)
+ stored = self._device_smi_handles.pop(identifier, None)
+ if stored is not None:
+ try:
+ stored.close()
+ except Exception:
+ pass
+ return None
+
+ memory_used = getattr(metrics, "memory_used", None)
+ if memory_used is None:
+ return None
+ return float(memory_used) / (1024 ** 3)
+
+ def _format_vram_summary(self, physical_keys: Iterable[str]) -> str:
+ readings: List[str] = []
+ for key in physical_keys:
+ value = self._query_device_vram_gib(key)
+ if value is None:
+ continue
+ readings.append(f"{key}={self._format_gib_value(value)}")
+ return ", ".join(readings) if readings else "n/a"
+
# ---- inflight & completion accounting ----
def _mark_scheduled(self, key: str) -> None:
@@ -1412,20 +1625,45 @@ def _on_task_finished(self, key: str) -> None:
Called at the end of every task (success or failure). Updates counters
and signals the janitor if the per-device threshold is reached.
"""
+ if not hasattr(self, "_gc_done_physical"):
+ self._gc_done_physical = {}
+ if not hasattr(self, "_gc_pending_physical"):
+ self._gc_pending_physical = set()
+ if not hasattr(self, "_last_gc_done_physical"):
+ self._last_gc_done_physical = {}
+ if not hasattr(self, "_physical_children"):
+ self._physical_children = {}
+
self._mark_finished(key)
trigger_gc = False
with self._stats_lock:
self._per_device_done[key] = self._per_device_done.get(key, 0) + 1
self._total_done += 1
- dev_type = self._devices_by_key[key].type
- if self._empty_cache_every_n > 0 and dev_type in ("cuda", "xpu", "mps"):
- n = self._per_device_done[key]
- if n % self._empty_cache_every_n == 0:
+ dev = self._devices_by_key.get(key)
+ if (
+ dev is not None
+ and self._empty_cache_every_n > 0
+ and dev.type in ("cuda", "xpu", "mps")
+ ):
+ physical_key = self._physical_key(key)
+ current = self._gc_done_physical.get(physical_key, 0) + 1
+ self._gc_done_physical[physical_key] = current
+ if current % self._empty_cache_every_n == 0:
+ if physical_key not in self._gc_pending_physical:
+ self._gc_pending_physical.add(physical_key)
+ self._gc_generation += 1
trigger_gc = True
- self._gc_generation += 1
if DEBUG_ON:
- log.debug(f"GC trigger set by {key}: per_device_done={n} threshold={self._empty_cache_every_n} total_done={self._total_done}")
+ log.debug(
+ "GC trigger set by %s (physical=%s): per_physical_done=%d threshold=%d total_done=%d pending=%s",
+ key,
+ physical_key,
+ current,
+ self._empty_cache_every_n,
+ self._total_done,
+ sorted(self._gc_pending_physical),
+ )
if trigger_gc:
self._gc_event.set()
@@ -1489,9 +1727,17 @@ def _collect_state_snapshot(self) -> Dict[str, Any]:
idx = "" if dev.index is None else str(dev.index)
meta[k] = {"type": dev.type, "index": idx}
+ physical_children = getattr(self, "_physical_children", {})
+ per_done_physical: Dict[str, int] = {}
+ for phys_key, members in physical_children.items():
+ per_done_physical[phys_key] = sum(per_done.get(member, 0) for member in members)
+
+ pending_gc = sorted(getattr(self, "_gc_pending_physical", set()))
+
snap: Dict[str, Any] = {
"devices": sorted(self._devices_by_key.keys()),
"per_done": per_done,
+ "per_done_physical": per_done_physical,
"total_done": total_done,
"threshold": threshold,
"inflight": inflight,
@@ -1504,6 +1750,7 @@ def _collect_state_snapshot(self) -> Dict[str, Any]:
"gc_generation_consumed": int(self._last_consumed_gc_generation),
"last_gc_ts": self._last_gc_ts,
"now": time.time(),
+ "pending_gc": pending_gc,
}
return snap
@@ -1584,12 +1831,13 @@ def _should_run_gc_from_snapshot(self, snap: Dict[str, Any]) -> bool:
thr = snap["threshold"]
if thr <= 0:
return False
- for k in snap["devices"]:
- dev_type = snap["meta"][k]["type"]
- if dev_type not in ("cuda", "xpu", "mps"):
- continue
- done_now = snap["per_done"].get(k, 0)
- done_prev = self._last_gc_done_per_device.get(k, 0)
+ pending = snap.get("pending_gc") or []
+ if pending:
+ return True
+ per_done_physical = snap.get("per_done_physical") or {}
+ last_done_physical = getattr(self, "_last_gc_done_physical", {})
+ for phys_key, done_now in per_done_physical.items():
+ done_prev = last_done_physical.get(phys_key, 0)
if done_now - done_prev >= thr:
return True
return False
@@ -1600,20 +1848,34 @@ def _update_gc_watermarks(self, snap_after: Dict[str, Any]) -> None:
before a subsequent pass is allowed.
"""
threshold = int(self._empty_cache_every_n)
- for k in snap_after["devices"]:
- done = snap_after["per_done"].get(k, 0)
- if threshold <= 0:
- self._last_gc_done_per_device[k] = done
- continue
+ per_done_physical = snap_after.get("per_done_physical") or {}
+ per_done = snap_after.get("per_done") or {}
+ meta = snap_after.get("meta") or {}
+ processed = snap_after.get("_gc_processed_devices")
+ if processed is None:
+ processed_iter = per_done_physical.keys()
+ else:
+ processed_iter = processed
- meta = snap_after.get("meta", {}).get(k, {})
- dev_type = meta.get("type")
- if dev_type in ("cuda", "xpu", "mps"):
- # Reset to the last completed threshold bucket so GC triggers
- # again after another `threshold` tasks on this accelerator.
- self._last_gc_done_per_device[k] = done - (done % threshold)
+ for phys_key in processed_iter:
+ done_phys = per_done_physical.get(phys_key)
+ if done_phys is None:
+ continue
+ if threshold <= 0:
+ self._last_gc_done_physical[phys_key] = done_phys
else:
- self._last_gc_done_per_device[k] = done
+ self._last_gc_done_physical[phys_key] = done_phys - (done_phys % threshold)
+
+ members = self._physical_children.get(phys_key, {phys_key})
+ for member in members:
+ done_member = per_done.get(member)
+ if done_member is None:
+ continue
+ dev_type = meta.get(member, {}).get("type")
+ if threshold <= 0 or dev_type not in ("cuda", "xpu", "mps"):
+ self._last_gc_done_per_device[member] = done_member
+ else:
+ self._last_gc_done_per_device[member] = done_member - (done_member % threshold)
def _janitor_loop(self):
"""
@@ -1727,53 +1989,64 @@ def _janitor_loop(self):
continue
t0 = time.time()
- # Optionally synchronize devices; often too slow to be worthwhile:
- # self._synchronize_all()
- # Per-device exclusive: acquire write lock, then call empty_cache().
- for key in sorted(self._ordered_keys):
- dev = self._devices_by_key[key]
- if dev.type not in ("cuda", "xpu", "mps"):
- continue
+ try:
+ pre = self._collect_state_snapshot()
+ if DEBUG_ON:
+ log.debug(
+ "DP-Janitor: pre-snapshot taken: total_done=%s threshold=%s inflight=%s pending=%s",
+ pre["total_done"],
+ pre["threshold"],
+ pre["inflight"],
+ pre.get("pending_gc"),
+ )
+ log.debug("GC trigger received; evaluating whether to run…")
+ pending_targets = [k for k in pre.get("pending_gc", []) if k in self._locks]
+ except Exception as e:
+ try:
+ log.warn(f"Failed to render GC pre-snapshot: {e!r}")
+ except Exception:
+ pass
+ pending_targets = sorted(getattr(self, "_gc_pending_physical", set()))
- lk = self._locks[key]
- if DEBUG_ON: log.debug(f"DP-Janitor: attempting writer lock for {key}")
+ if not pending_targets:
+ if DEBUG_ON:
+ log.debug("DP-Janitor: no pending devices after snapshot; marking generation %d consumed", current_generation)
+ with self._stats_lock:
+ self._last_consumed_gc_generation = max(self._last_consumed_gc_generation, current_generation)
+ continue
+
+ processed_devices: List[str] = []
+ skipped_devices: List[str] = []
+ per_device_durations: Dict[str, float] = {}
+
+ for key in pending_targets:
+ dev = self._devices_by_key.get(key)
+ if dev is None or dev.type not in ("cuda", "xpu", "mps"):
+ skipped_devices.append(key)
+ continue
+ lk = self._locks.get(key)
+ if lk is None:
+ skipped_devices.append(key)
+ continue
+ if DEBUG_ON:
+ log.debug("DP-Janitor: attempting writer lock for %s", key)
with lk.writer():
- if DEBUG_ON: log.debug(f"DP-Janitor: acquired writer lock for {key}")
+ if DEBUG_ON:
+ log.debug("DP-Janitor: acquired writer lock for %s", key)
+ duration = self._run_empty_cache_for_device(key, dev)
+ if duration is not None:
+ per_device_durations[key] = duration
+ processed_devices.append(key)
- if dev.type == "cuda":
- live = getattr(torch.cuda, "empty_cache", None) if hasattr(torch, "cuda") else None
- use_fn = live if callable(live) else TORCH_CUDA_EMPTY_CACHE
- if DEBUG_ON:
- src = "live" if use_fn is live else "hardcopy"
- log.debug(f"DP-Janitor: empty_cache(cuda) using {src} on {key}")
- if use_fn is not None:
- with torch.cuda.device(dev.index):
- use_fn()
-
- elif dev.type == "xpu":
- live = getattr(torch.xpu, "empty_cache", None) if hasattr(torch, "xpu") else None
- use_fn = live if callable(live) else TORCH_XPU_EMPTY_CACHE
- if DEBUG_ON:
- src = "live" if use_fn is live else "hardcopy"
- log.debug(f"DP-Janitor: empty_cache(xpu) using {src} on {key}")
- if use_fn is not None:
- with torch.xpu.device(dev.index):
- use_fn()
-
- elif dev.type == "mps":
- live = getattr(torch.mps, "empty_cache", None) if hasattr(torch, "mps") else None
- use_fn = live if callable(live) else TORCH_MPS_EMPTY_CACHE
- if DEBUG_ON:
- src = "live" if use_fn is live else "hardcopy"
- log.debug(f"DP-Janitor: empty_cache(mps) using {src}")
- if use_fn is not None:
- use_fn()
+ if not processed_devices and DEBUG_ON:
+ log.debug("DP-Janitor: no eligible accelerator devices found in pending=%s", pending_targets)
t1 = time.time()
prev_gc_ts = self._last_gc_ts
- self._gc_passes += 1
- self._last_gc_ts = t1
+ if processed_devices:
+ self._gc_passes += 1
+ self._last_gc_ts = t1
gc_timestamp = datetime.fromtimestamp(t1, tz=timezone.utc).isoformat()
if prev_gc_ts is None:
since_last_gc = "since last GC: n/a"
@@ -1781,22 +2054,39 @@ def _janitor_loop(self):
delta_s = t1 - prev_gc_ts
since_last_gc = f"since last GC: {delta_s:.3f}s ({delta_s * 1000:.1f}ms)"
- # Post-pass accounting & watermarks.
- try:
- post = self._collect_state_snapshot()
- self._update_gc_watermarks(post)
- log.info(
- f"GC completed in {t1 - t0:.3f}s (pass #{self._gc_passes}) at {gc_timestamp}; {since_last_gc}."
- )
- if DEBUG_ON: log.debug(f"DP-Janitor: post-snapshot: inflight={post['inflight']} per_done={post['per_done']}")
- except Exception as e:
+ if processed_devices:
+ vram_summary = self._format_vram_summary(processed_devices)
try:
- log.warn(f"Failed to render GC post-snapshot: {e!r}")
- except Exception:
- pass
- finally:
- with self._stats_lock:
- self._last_consumed_gc_generation = self._gc_generation
+ post = self._collect_state_snapshot()
+ post["_gc_processed_devices"] = processed_devices
+ self._update_gc_watermarks(post)
+ devices_clause = ", ".join(processed_devices)
+ log.info(
+ f"GC completed in {t1 - t0:.3f}s (pass #{self._gc_passes}) at {gc_timestamp}; devices={devices_clause}; VRAM {vram_summary}; {since_last_gc}."
+ )
+ if DEBUG_ON:
+ log.debug(
+ "DP-Janitor: post-snapshot inflight=%s per_done=%s per_done_physical=%s durations=%s",
+ post["inflight"],
+ post["per_done"],
+ post.get("per_done_physical"),
+ per_device_durations,
+ )
+ except Exception as e:
+ try:
+ log.warn(f"Failed to render GC post-snapshot: {e!r}")
+ except Exception:
+ pass
+
+ with self._stats_lock:
+ for key in processed_devices:
+ self._gc_pending_physical.discard(key)
+ self._last_gc_done_physical[key] = self._gc_done_physical.get(key, 0)
+ for key in skipped_devices:
+ self._gc_pending_physical.discard(key)
+ self._last_consumed_gc_generation = max(self._last_consumed_gc_generation, current_generation)
+ if self._gc_pending_physical:
+ self._gc_event.set()
# Legacy helper (not used by janitor). Kept for compatibility with any
# external callers that previously expected a "clear everything" helper.
From af39a3df097b3c161361e4427e8db31990cfbdf6 Mon Sep 17 00:00:00 2001
From: Qubitium
Date: Sun, 19 Oct 2025 15:26:00 +0000
Subject: [PATCH 10/15] log slow replicate
---
README.md | 6 ---
gptqmodel/looper/module_looper.py | 77 +++++++++++++++++++++++++++----
gptqmodel/utils/looper_helpers.py | 28 +++++++++--
3 files changed, 93 insertions(+), 18 deletions(-)
diff --git a/README.md b/README.md
index 5611545b1..30be33b79 100644
--- a/README.md
+++ b/README.md
@@ -172,12 +172,6 @@ Native support support some of the most popular multi-modal models:
-## Experimental GPTQ v2 quantization: Users have reported this mode of quantization may or may not match original GPTQ v1 implementation in terms of quality recovery.
-
-
-

-
-
## Model Support
| Model | | | | | | | | | |
|-------------------|---|-------------------|---|----------------|---|----------------|---|---------------------|---|
diff --git a/gptqmodel/looper/module_looper.py b/gptqmodel/looper/module_looper.py
index 24467dda9..9481d6a55 100644
--- a/gptqmodel/looper/module_looper.py
+++ b/gptqmodel/looper/module_looper.py
@@ -579,14 +579,7 @@ def _run_forward_batches_parallel(
progress_total_rows: Optional[int] = None,
) -> List[List[torch.Tensor]]:
"""Fan batches across device clones and preserve result ordering."""
- module_replicas = clone_module_for_devices(module, devices)
-
- # Ensure any async replication/memcpy ops are complete before threads start fanning out.
- torch_sync()
-
- prev_kv = shared_kv_cache_dict.get(layer_index - 1) if reuse_kv else None
-
- results: Dict[int, torch.Tensor | tuple | None] = {}
+ effective_title = progress_title or (progress_stage or "Forward")
total_batches = self._resolve_batch_total(processor.num_batches, layer_inputs)
batch_row_counts = progress_rows_per_batch or self._collect_row_counts(layer_inputs)
@@ -599,8 +592,74 @@ def _run_forward_batches_parallel(
if total_rows <= 0 and total_batches > 0:
total_rows = total_batches
total_rows = max(total_rows, 1)
- processed_rows = 0
stage_label = progress_stage or "Forward"
+
+ replica_pb: "ProgressBar" | None = None
+ replica_title = ""
+ replica_completed = 0
+
+ if progress_pb is not None:
+ progress_pb.title(effective_title)
+ if len(devices) > 1:
+ replica_title = f"{stage_label}: replicate to {len(devices)} devices"
+ replica_pb = (
+ log.pb(range(len(devices)))
+ .manual()
+ .set(show_left_steps=False)
+ )
+ replica_pb.title(replica_title).subtitle("Staging module...").draw()
+ else:
+ device_label = str(devices[0]) if devices else ""
+ progress_pb.subtitle(f"{stage_label}: staging on {device_label}").draw()
+
+ def _replica_progress(idx: int, total: int, device: torch.device, step: str) -> None:
+ nonlocal replica_completed
+ device_label = str(device)
+ if replica_pb is not None:
+ if step == "stage":
+ replica_pb.title(replica_title).subtitle(f"Stage {device_label}").draw()
+ return
+ if idx > replica_completed:
+ replica_completed = idx
+ replica_pb.title(replica_title).subtitle(
+ f"{device_label} {idx}/{total}"
+ ).next().draw()
+ else:
+ replica_pb.title(replica_title).subtitle(
+ f"{device_label} {idx}/{total}"
+ ).draw()
+ elif progress_pb is not None:
+ stage_msg = (
+ f"{stage_label}: staging on {device_label}"
+ if step == "stage"
+ else f"{stage_label}: {step} {idx}/{total} on {device_label}"
+ )
+ progress_pb.title(effective_title).subtitle(stage_msg).draw()
+
+ progress_cb = _replica_progress if progress_pb is not None else None
+
+ # Ensure any async replication/memcpy ops are complete before threads start fanning out.
+ torch_sync()
+
+ try:
+ module_replicas = clone_module_for_devices(
+ module,
+ devices,
+ progress_callback=progress_cb,
+ )
+ finally:
+ if replica_pb is not None:
+ replica_pb.close()
+ if progress_pb is not None:
+ progress_pb.title(effective_title).subtitle(
+ f"{stage_label} rows 0/{total_rows}"
+ ).draw()
+
+ prev_kv = shared_kv_cache_dict.get(layer_index - 1) if reuse_kv else None
+
+ results: Dict[int, torch.Tensor | tuple | None] = {}
+
+ processed_rows = 0
device_segments: Dict[torch.device, List[int]] = {}
segment_start = 0
num_devices = len(devices)
diff --git a/gptqmodel/utils/looper_helpers.py b/gptqmodel/utils/looper_helpers.py
index 2f9b77236..a2050b6f0 100644
--- a/gptqmodel/utils/looper_helpers.py
+++ b/gptqmodel/utils/looper_helpers.py
@@ -8,7 +8,7 @@
import threading
import time
from contextlib import contextmanager
-from typing import Dict, List, Optional, Sequence, Tuple
+from typing import Callable, Dict, List, Optional, Sequence, Tuple
import torch
from torch.nn import parallel as torch_parallel
@@ -225,6 +225,7 @@ def clone_module_for_devices(
devices: List[torch.device],
*,
clear_state_fn=clear_non_picklable_state,
+ progress_callback: Optional[Callable[[int, int, torch.device, str], None]] = None,
) -> Dict[torch.device, torch.nn.Module]:
clones: Dict[torch.device, torch.nn.Module] = {}
if not devices:
@@ -234,6 +235,21 @@ def clone_module_for_devices(
clone_timings: List[Tuple[str, float]] = []
overall_start = time.perf_counter()
+ total_targets = len(devices)
+
+ def _notify(idx: int, device: torch.device, step: str) -> None:
+ if progress_callback is None:
+ return
+ try:
+ progress_callback(idx, total_targets, device, step)
+ except Exception:
+ if DEBUG_ON:
+ log.debug(
+ "clone_module_for_devices: progress callback failed (device=%s, step=%s)",
+ device,
+ step,
+ )
+
def _record(name: str, start_ts: Optional[float]) -> None:
if not DEBUG_ON or start_ts is None:
return
@@ -283,17 +299,19 @@ def _prepare_module(target_device: torch.device, step_name: str) -> None:
if use_replicate:
try:
_prepare_module(base_device, f"stage_{base_device}")
+ _notify(0, base_device, "stage")
replicate_start = time.perf_counter()
replicas = torch_replicate(module, devices)
_record("replicate", replicate_start)
- for dev, replica in zip(devices, replicas):
+ for idx, (dev, replica) in enumerate(zip(devices, replicas), start=1):
replica.eval()
rehome_module_to_device(replica, dev, move_parameters=True, move_buffers=True)
clear_state_fn(replica)
setattr(replica, "_gptqmodule_device_hint", dev)
clones[dev] = replica
+ _notify(idx, dev, "replica")
_emit_clone_log("replicate")
return clones
@@ -305,14 +323,17 @@ def _prepare_module(target_device: torch.device, step_name: str) -> None:
if len(devices) == 1 and devices[0].type == "cpu":
_prepare_module(CPU, "stage_cpu")
+ _notify(0, CPU, "stage")
clones[devices[0]] = module
+ _notify(1, devices[0], "reuse")
_emit_clone_log("reuse")
return clones
if not use_replicate:
_prepare_module(stage_device, f"stage_{stage_device}")
+ _notify(0, stage_device, "stage")
- for dev in devices:
+ for idx, dev in enumerate(devices, start=1):
start_ts = time.perf_counter()
with _DEEPCOPY_LOCK:
replica = copy.deepcopy(module)
@@ -322,6 +343,7 @@ def _prepare_module(target_device: torch.device, step_name: str) -> None:
setattr(replica, "_gptqmodule_device_hint", dev)
clones[dev] = replica
_record(str(dev), start_ts)
+ _notify(idx, dev, "clone")
_emit_clone_log("deepcopy")
return clones
From e916b4be1d4eeaa1622858e2eccac2137d297d1e Mon Sep 17 00:00:00 2001
From: Qubitium
Date: Sun, 19 Oct 2025 16:25:51 +0000
Subject: [PATCH 11/15] update scores
---
tests/models/test_qwen2_5.py | 17 +++++++++++------
1 file changed, 11 insertions(+), 6 deletions(-)
diff --git a/tests/models/test_qwen2_5.py b/tests/models/test_qwen2_5.py
index d4af6eb32..672103c1a 100644
--- a/tests/models/test_qwen2_5.py
+++ b/tests/models/test_qwen2_5.py
@@ -6,20 +6,25 @@
from model_test import ModelTest
from gptqmodel.utils.eval import EVAL
-
+# | Metric | MARLIN |
+# |--------------------------------|----------|
+# | arc_challenge :: acc,none | 0.2884 |
+# | arc_challenge :: acc_norm,none | 0.3208 |
+# | mmlu :: acc,none | 0.442 |
class TestQwen2_5(ModelTest):
+ GROUP_SIZE = 32
NATIVE_MODEL_ID = "/monster/data/model/Qwen2.5-0.5B-Instruct"
EVAL_TASKS = {
EVAL.LM_EVAL.ARC_CHALLENGE: {
- "acc": {"value": 0.2722, "floor_pct": 0.04},
- "acc_norm": {"value": 0.3072, "floor_pct": 0.04},
+ "acc": {"value": 0.2884, "floor_pct": 0.04},
+ "acc_norm": {"value": 0.3208, "floor_pct": 0.04},
},
EVAL.LM_EVAL.MMLU: {
- "acc": {"value": 0.4029, "floor_pct": 0.04},
+ "acc": {"value": 0.4420, "floor_pct": 0.04},
},
}
- TRUST_REMOTE_CODE = False
- APPLY_CHAT_TEMPLATE = True
+ #TRUST_REMOTE_CODE = False
+ #APPLY_CHAT_TEMPLATE = True
#EVAL_BATCH_SIZE = 6
def test_qwen2_5(self):
From b5f918fa2b12f64eaf9fdf819641fac6cd101ad1 Mon Sep 17 00:00:00 2001
From: Qubitium
Date: Sun, 19 Oct 2025 16:41:30 +0000
Subject: [PATCH 12/15] update logbar to fix log flicker
---
pyproject.toml | 2 +-
requirements.txt | 2 +-
2 files changed, 2 insertions(+), 2 deletions(-)
diff --git a/pyproject.toml b/pyproject.toml
index b6e484ea4..6d2f009ca 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -46,7 +46,7 @@ dependencies = [
"huggingface_hub>=0.34.4",
"random_word>=1.0.13",
"tokenicer>=0.0.5",
- "logbar>=0.1.4",
+ "logbar>=0.1.5",
"maturin>=1.9.4", # required by safetensors and hf_transfer
"datasets>=3.6.0",
"pyarrow>=21.0",
diff --git a/requirements.txt b/requirements.txt
index a9697fcb5..e9ba019ef 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -13,7 +13,7 @@ hf_transfer>=0.1.9
huggingface_hub>=0.34.4
random_word>=1.0.13
tokenicer>=0.0.5
-logbar>=0.1.4
+logbar>=0.1.5
maturin>=1.9.4
datasets>=3.6.0
pyarrow>=21.0
From 86510d5839b20bd0f21fc852dcf31353021e3230 Mon Sep 17 00:00:00 2001
From: Qubitium
Date: Sun, 19 Oct 2025 16:56:35 +0000
Subject: [PATCH 13/15] update scores
---
tests/models/test_glm.py | 13 +++++++------
1 file changed, 7 insertions(+), 6 deletions(-)
diff --git a/tests/models/test_glm.py b/tests/models/test_glm.py
index 295f234f1..c315a0301 100644
--- a/tests/models/test_glm.py
+++ b/tests/models/test_glm.py
@@ -8,19 +8,20 @@
# | Metric | MARLIN |
# |--------------------------------|----------|
-# | arc_challenge :: acc,none | 0.5026 |
-# | arc_challenge :: acc_norm,none | 0.5171 |
-# | mmlu :: acc,none | 0.6362 |
+# | arc_challenge :: acc,none | 0.5154 |
+# | arc_challenge :: acc_norm,none | 0.535 |
+# | mmlu :: acc,none | 0.6325 |
class TestGlm(ModelTest):
+ GROUP_SIZE = 32
# real: THUDM/glm-4-9b-chat-hf
NATIVE_MODEL_ID = "/monster/data/model/glm-4-9b-chat-hf"
EVAL_TASKS = {
EVAL.LM_EVAL.ARC_CHALLENGE: {
- "acc": {"value": 0.5026, "floor_pct": 0.04},
- "acc_norm": {"value": 0.5171, "floor_pct": 0.04},
+ "acc": {"value": 0.5154, "floor_pct": 0.04},
+ "acc_norm": {"value": 0.5350, "floor_pct": 0.04},
},
EVAL.LM_EVAL.MMLU: {
- "acc": {"value": 0.6362, "floor_pct": 0.04},
+ "acc": {"value": 0.6325, "floor_pct": 0.04},
},
}
From 8f189c4268f2c731cdfbca1f86cb518d32174bff Mon Sep 17 00:00:00 2001
From: Qubitium
Date: Sun, 19 Oct 2025 17:20:05 +0000
Subject: [PATCH 14/15] update news
---
README.md | 32 ++++++++++++++++++--------------
1 file changed, 18 insertions(+), 14 deletions(-)
diff --git a/README.md b/README.md
index 30be33b79..c3115b122 100644
--- a/README.md
+++ b/README.md
@@ -17,25 +17,28 @@
## Latest News
+* 10/20/2025 [5.0.0](https://github.com/ModelCloud/GPTQModel/releases/tag/v5.0.0): 🎉 Data-parallel quant support for `MoE` models on multi-gpu using `nogil` Python. `offload_to_disk` support enabled by
+default to massively reduce `cpu` ram usage. New `Intel` and `AMD` cpu hw accelerated `TorchFused` kernel. Packing stage is now 4x faster and now inlined with quantization. `Vram` pressure for large models reduced during quantization.
+`act_group_aware` is 16k+ times faster and now the default when `desc_act=False` for higher quality recovery without inference penalty of `desc_act=True`. New beta quality `AWQ` support with full `gemm`,
+`gemm_fast`, `marlin` kernel support. `LFM`, `Ling`, `Qwen3 Omni` model support. Quantization is now faster with reduced vram usage. Enhanced logging support with `LogBar`.
+* 09/16/2025 [4.2.5](https://github.com/ModelCloud/GPTQModel/releases/tag/v4.2.5): `hyb_act` renamed to `act_group_aware`. Removed finicky `torch` import within `setup.py`. Packing bug fix and prebuilt Pytorch 2.8 whls.
+* 09/12/2025 [4.2.0](https://github.com/ModelCloud/GPTQModel/releases/tag/v4.2.0): ✨ New Models Support: Qwen3-Next, Apertus, Kimi K2, Klear, FastLLM, Nemotron H. New `fail_safe` `boolean` toggle to `.quantize()` to patch-fix non-activated `MoE` modules due to highly uneven MoE model training. Fixed LavaQwen2 compat. Patch fix GIL=0 cuda error for multi-gpu. Fix compat with autoround + new transformers.
+* 09/04/2025 [4.1.0](https://github.com/ModelCloud/GPTQModel/releases/tag/v4.1.0): ✨ Meituan LongCat Flash Chat, Llama 4, GPT-OSS (BF16), and GLM-4.5-Air support. New experimental `mock_quantization` config to skip complex computational code paths during quantization to accelerate model quant testing.
+* 08/21/2025 [4.0.0](https://github.com/ModelCloud/GPTQModel/releases/tag/v4.0.0): 🎉 New Group Aware Reordering (GAR) support. New models support: Bytedance Seed-OSS, Baidu Ernie, Huawei PanGu, Gemma3, Xiaomi Mimo, Qwen 3/MoE, Falcon H1, GPT-Neo. Memory leak and multiple model compatibility fixes related to Transformers >= 4.54. Python >= 3.13t free-threading support added with near N x GPU linear scaling for quantization of MoE models and also linear N x Cpu Core scaling of packing stage. Early access Pytorch 2.8 fused-ops on Intel XPU for up to 50% speedup.
+
+
+
+Archived News
* 10/17/2025 5.0.0-dev `main`: 👀: EoRA now multi-gpu compatible. Fixed both quality stability of multi-gpu quanta and vram usage. New LFM and Ling models support.
* 09/30/2025 5.0.0-dev `main`: 👀: New Data Parallel + Multi-GPU + Python 3.13T (PYTHON_GIL=0) equals 80%+ overall quant time reduction of large MoE models vs v4.2.5.
* 09/29/2025 5.0.0-dev `main`: 🎉 New Qwen3 Omni model support. AWQ Marlin kernel integrated + many disk offload, threading, and memory usage fixes.
* 09/24/2025 5.0.0-dev `main`: 🎉 Up to 90% cpu mem saving for large MoE models with faster/inline packing! 26% quant time reduction for Qwen3 MoE! AWQ Marlin kernel added. AWQ Gemm loading bug fixes. `act_group_aware` now faster and auto enabled for GPTQ when `desc_act` is False for higher quality recovery.
* 09/19/2025 5.0.0-dev `main`: 👀 Cpu memory saving of ~73.5% during quantization stage with new `offload_to_disk` quantization config property default to `True`.
-* 09/18/2025 5.0.0-dev `main`: 🎉 AWQ quantization support! Complete refractor and simplification of model definitions in prepreation for future quantization formats.
-* 09/16/2025 [4.2.5](https://github.com/ModelCloud/GPTQModel/releases/tag/v4.2.5): `hyb_act` renamed to `act_group_aware`. Removed finicky `torch` import within `setup.py`. Packing bug fix and prebuilt Pytorch 2.8 whls.
-* 09/12/2025 [4.2.0](https://github.com/ModelCloud/GPTQModel/releases/tag/v4.2.0): ✨ New Models Support: Qwen3-Next, Apertus, Kimi K2, Klear, FastLLM, Nemotron H. New `fail_safe` `boolean` toggle to `.quantize()` to patch-fix non-activated `MoE` modules due to highly uneven MoE model training. Fixed LavaQwen2 compat. Patch fix GIL=0 cuda error for multi-gpu. Fix compat with autoround + new transformers.
-* 09/04/2025 [4.1.0](https://github.com/ModelCloud/GPTQModel/releases/tag/v4.1.0): ✨ Meituan LongCat Flash Chat, Llama 4, GPT-OSS (BF16), and GLM-4.5-Air support. New experiemental `mock_quantization` config to skip complex computational code paths during quantization to accelerate model quant testing.
-* 08/21/2025 [4.0.0](https://github.com/ModelCloud/GPTQModel/releases/tag/v4.0.0): 🎉 New Group Aware Reordering (GAR) support. New models support: Bytedance Seed-OSS, Baidu Ernie, Huawei PanGu, Gemma3, Xiaomi Mimo, Qwen 3/MoE, Falcon H1, GPT-Neo. Memory leak and multiple model compatibility fixes related to Transformers >= 4.54. Python >= 3.13t free-threading support added with near N x GPU linear scaling for quantization of MoE models and also linear N x Cpu Core scaling of packing stage. Early access Pytorch 2.8 fused-ops on Intel XPU for up to 50% speedup.
+* 09/18/2025 5.0.0-dev `main`: 🎉 AWQ quantization support! Complete refactor and simplification of model definitions in preparation for future quantization formats.
* 08/19/2025 4.0.0-dev `main`: Fix quantization memory usage due to some model's incorrect application of `config.use_cache` during inference. Fixed `Transformers` >= 4.54.0 compat which changed layer forward return signature for some models.
* 08/18/2025 4.0.0-dev `main`: GPT-Neo model support. Memory leak fix in error capture (stacktrace) and fixed `lm_head` quantization compatibility for many models.
* 07/31/2025 4.0.0-dev `main`: New Group Aware Reordering (GAR) support and prelim Pytorch 2.8 fused-ops for Intel XPU for up to 50% speedup.
* 07/03/2025 4.0.0-dev `main`: New Baidu Ernie and Huawei PanGu model support.
-
-
-
-Archived News
-
* 07/02/2025 4.0.0-dev `main`: Gemma3 4B model compat fix.
* 05/29/2025 4.0.0-dev `main`: Falcon H1 model support. Fixed Transformers `4.52+` compat with Qwen 2.5 VL models.
* 05/19/2025 4.0.0-dev `main`: Qwen 2.5 Omni model support.
@@ -281,12 +284,13 @@ model.quantize(calibration_dataset, batch_size=1)
model.save(quant_path)
```
-### Quantization using GPTQ V2
+### Quantization using GPTQ V2* (Experimental, not MoE compatible, and results may not be better than v1)
-Enable GPTQ v2 quantization by setting `v2 = True` for potentially higher post-quantization accuracy recovery.
+Enable GPTQ v2 quantization by setting `v2 = True`.
```py
-# note v2 is currently experiemental and requires 2-4x more vram to execute
-# if oom on 1 gpu, please set CUDA_VISIBLE_DEVICES=0,1 to 2 gpu and gptqmodel will auto use second gpu
+# Note v2 is currently experimental, not MoE compatible, and requires 2-4x more vram to execute
+# We have many reports of v2 not outperforming v1, so please use it for testing only
+# If oom on 1 gpu, please set CUDA_VISIBLE_DEVICES=0,1 to 2 gpu and gptqmodel will auto use second gpu
quant_config = QuantizeConfig(bits=4, group_size=128, v2=True)
```
`Llama 3.1 8B-Instruct` quantized using `test/models/test_llama3_2.py`
From 6084179c4b3fd0f33a50e5f3b6a7a877572e857c Mon Sep 17 00:00:00 2001
From: Qubitium
Date: Sun, 19 Oct 2025 17:21:02 +0000
Subject: [PATCH 15/15] update scores
---
tests/models/test_glm4_moe.py | 3 +++
1 file changed, 3 insertions(+)
diff --git a/tests/models/test_glm4_moe.py b/tests/models/test_glm4_moe.py
index 74540ef43..a53942fb4 100644
--- a/tests/models/test_glm4_moe.py
+++ b/tests/models/test_glm4_moe.py
@@ -8,6 +8,9 @@
class TestGlm4Moe(ModelTest):
NATIVE_MODEL_ID = "/monster/data/model/GLM-4.6/"
+ DELETE_QUANTIZED_MODEL = False
+ DATASET_SIZE = 512
+ GROUP_SIZE = 32
EVAL_TASKS = {
EVAL.LM_EVAL.ARC_CHALLENGE: {
"acc": {"value": 0.5026, "floor_pct": 0.04},