From 578351e0a4440d7d2f4d60a82f54cf7898a18f2a Mon Sep 17 00:00:00 2001 From: Qubitium Date: Tue, 30 Sep 2025 02:40:38 +0000 Subject: [PATCH 01/12] init Signed-off-by: Qubitium --- gptqmodel/looper/module_looper.py | 377 ++++++++++++++++++++---------- 1 file changed, 260 insertions(+), 117 deletions(-) diff --git a/gptqmodel/looper/module_looper.py b/gptqmodel/looper/module_looper.py index fee50bdb1..bbe916fb1 100644 --- a/gptqmodel/looper/module_looper.py +++ b/gptqmodel/looper/module_looper.py @@ -6,11 +6,9 @@ from __future__ import annotations import copy -import gc import threading import time from contextlib import contextmanager -from functools import partial from typing import Dict, List, Optional import torch @@ -22,7 +20,7 @@ from ..looper.loop_processor import LoopProcessor from ..looper.named_module import NamedModule from ..models import BaseQModel -from ..models._const import CUDA, SUPPORTS_MODULE_TYPES +from ..models._const import SUPPORTS_MODULE_TYPES from ..nn_modules.hooked_linear import (STOP_FORWARD_EXCEPTION, HookedLinear, StopForward, replace_module_with_hooked_legacy) from ..utils.attn_mask import apply_keep_mask_bt, normalize_seq_mask @@ -30,10 +28,8 @@ from ..utils.logger import setup_logger from ..utils.model import find_modules, get_module, get_module_by_name_prefix, move_to, nested_move_to from ..utils.offload import offload_to_disk -from ..utils.structure import print_module_tree from ..utils.threadx import DeviceThreadPool -from ..utils.torch import (ALL_DEVICES, CPU, DEFAULT_BALANCE_STRATEGY, HAS_CUDA, META, BalanceStrategy, - device_next, device_next_reset, torch_empty_cache, torch_sync) +from ..utils.torch import (ALL_DEVICES, CPU, META, device_next, device_next_reset, torch_sync) from .awq_processor import AWQProcessor from .qqq_processor import QQQProcessor @@ -144,12 +140,233 @@ def __init__(self, model: BaseQModel, processors: List[LoopProcessor]): empty_cache_every_n=14, # disable auto GC during quant loops; enable if you want ) + for processor in self.processors: + self._processor_mask_tls(processor) + # NEW: Wrap an existing hook so its inputs/outputs are pre-masked for GPTQ stats. # We *do not* alter the module's actual computation; only what the hook # passes down to the processor capture path is masked. 
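+    # Illustrative hand-off (descriptive comment only, mirroring the helpers
+    # below): the forward worker stores the per-batch keep-mask in the
+    # processor's thread-local slot before invoking the module, the wrapped
+    # hook reads it back via _get_processor_mask() on the same thread, and the
+    # worker clears the slot in its finally block once the forward call returns.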
+ def _processor_mask_tls(self, processor: LoopProcessor) -> threading.local: + tls = getattr(processor, "_mask_tls", None) + if tls is None: + tls = threading.local() + setattr(processor, "_mask_tls", tls) + return tls + + def _set_processor_mask(self, processor: LoopProcessor, mask): + tls = self._processor_mask_tls(processor) + tls.value = mask + + def _get_processor_mask(self, processor: LoopProcessor): + tls = getattr(processor, "_mask_tls", None) + return getattr(tls, "value", None) if tls else None + + def _select_forward_devices(self, base_device: Optional[torch.device]) -> List[torch.device]: + if base_device is None: + return [CPU] + + devices = [base_device] + base_type = getattr(base_device, "type", None) + if base_type in ("cuda", "xpu", "mps"): + for dev in ALL_DEVICES: + if getattr(dev, "type", None) == base_type and dev not in devices: + devices.append(dev) + return devices + + def _clone_module_for_devices(self, module: torch.nn.Module, devices: List[torch.device]) -> Dict[torch.device, torch.nn.Module]: + clones: Dict[torch.device, torch.nn.Module] = {} + base_device = get_device(module) + + cleared_attrs = self._clear_non_picklable_state(module) + try: + for dev in devices: + if base_device is not None and dev == base_device: + clones[dev] = module + _rehome_module_to_device(module, dev, move_parameters=False, move_buffers=True) + else: + replica = copy.deepcopy(module) + replica = replica.to(dev) + replica.eval() + _rehome_module_to_device(replica, dev, move_parameters=False, move_buffers=True) + self._clear_non_picklable_state(replica) + clones[dev] = replica + finally: + self._restore_non_picklable_state(cleared_attrs) + return clones + + def _clear_non_picklable_state(self, module: torch.nn.Module): + cleared = [] + seen = set() + + def maybe_clear(obj: torch.nn.Module): + if id(obj) in seen: + return + seen.add(id(obj)) + if hasattr(obj, "target_device_stream"): + cleared.append((obj, getattr(obj, "target_device_stream"))) + setattr(obj, "target_device_stream", None) + + if isinstance(module, torch.nn.Module): + for sub in module.modules(): + maybe_clear(sub) + else: + maybe_clear(module) + + return cleared + + @staticmethod + def _restore_non_picklable_state(cleared): + for obj, value in cleared: + setattr(obj, "target_device_stream", value) + + @staticmethod + def _forward_batch_worker( + module: torch.nn.Module, + processor: LoopProcessor, + batch_index: int, + layer_input: List[torch.Tensor], + layer_input_kwargs: Dict[str, torch.Tensor], + attention_mask: Optional[torch.Tensor], + position_ids: Optional[torch.Tensor], + *, + support_batch_quantize: bool, + is_lm_head_module: bool, + need_output: bool, + reuse_kv: bool, + prev_kv, + ): + module_device = get_device(module) + _rehome_module_to_device(module, module_device, move_parameters=False, move_buffers=True) + + inputs = [move_to(inp, device=module_device, stream=False) for inp in layer_input] + + attn_tensor = None + if attention_mask is not None: + attn_tensor = move_to(attention_mask, device=module_device, stream=False) + + additional_inputs: Dict[str, torch.Tensor] = {} + if support_batch_quantize and attn_tensor is not None: + additional_inputs["attention_mask"] = attn_tensor + + if position_ids is not None: + additional_inputs["position_ids"] = move_to(position_ids, device=module_device, stream=False) + + for key, value in layer_input_kwargs.items(): + additional_inputs[key] = nested_move_to(value, device=module_device, stream=False) + + keep_mask = None + if attn_tensor is not None: + seq_len = 
inputs[0].shape[1] if (len(inputs) > 0 and inputs[0].dim() >= 2) else None + keep_mask = normalize_seq_mask(attn_tensor, seq_len=seq_len) + + mask_tls = getattr(processor, "_mask_tls", None) + if mask_tls is not None: + mask_tls.value = keep_mask + + if reuse_kv and prev_kv is not None: + additional_inputs["kv_last_layer"] = nested_move_to(prev_kv, device=module_device, stream=False) + + module_output = None + kv_next = None + try: + if is_lm_head_module: + module_output = module(*inputs) + else: + module_output = module(*inputs, **additional_inputs) + except StopForward: + module_output = None + finally: + if mask_tls is not None: + mask_tls.value = None + + if reuse_kv and module_output is not None and isinstance(module_output, tuple) and len(module_output) > 0: + kv_next = module_output[-1] + + result_output = module_output if need_output else None + return batch_index, result_output, kv_next + + def _run_forward_batches( + self, + *, + module: torch.nn.Module, + processor: LoopProcessor, + layer_inputs: List[List[torch.Tensor]], + layer_input_kwargs: List[Dict[str, torch.Tensor]], + position_ids: List[torch.Tensor], + attention_masks: List[torch.Tensor], + cur_layer_device: torch.device, + is_lm_head_module: bool, + shared_kv_cache_dict: Dict[int, torch.Tensor], + layer_index: int, + need_outputs: bool, + reuse_kv: bool, + ) -> List[List[torch.Tensor]]: + devices = self._select_forward_devices(cur_layer_device) + module_replicas = self._clone_module_for_devices(module, devices) + + prev_kv = shared_kv_cache_dict.get(layer_index - 1) if reuse_kv else None + + results: Dict[int, torch.Tensor | tuple | None] = {} + + chunk = len(devices) + total_batches = processor.num_batches + for start in range(0, total_batches, chunk): + futures = [] + end = min(start + chunk, total_batches) + for offset, batch_idx in enumerate(range(start, end)): + device = devices[offset] + replica = module_replicas[device] + futures.append( + self.pool.submit( + device, + self._forward_batch_worker, + replica, + processor, + batch_idx, + layer_inputs[batch_idx], + layer_input_kwargs[batch_idx], + attention_masks[batch_idx], + position_ids[batch_idx] if position_ids else None, + support_batch_quantize=self.support_batch_quantize, + is_lm_head_module=is_lm_head_module, + need_output=need_outputs, + reuse_kv=reuse_kv, + prev_kv=prev_kv, + ) + ) + + for fut in futures: + batch_idx, module_output, kv_next = fut.result() + if need_outputs and module_output is not None: + results[batch_idx] = module_output + if reuse_kv and kv_next is not None and shared_kv_cache_dict.get(layer_index) is None: + shared_kv_cache_dict[layer_index] = nested_move_to(kv_next, device=cur_layer_device, stream=False) + + # ensure replicas that are clones release promptly + for dev in list(module_replicas.keys()): + if dev != cur_layer_device: + del module_replicas[dev] + + if not need_outputs: + return [] + + ordered_outputs: List[List[torch.Tensor]] = [] + for idx in range(total_batches): + module_output = results.get(idx) + if module_output is None: + raise RuntimeError("Forward batch returned no output; data-parallel execution produced empty result.") + if isinstance(module_output, tuple): + primary = module_output[0] + else: + primary = module_output + primary = move_to(primary, device=cur_layer_device, stream=False) + ordered_outputs.append([primary]) + + return ordered_outputs + def _masked_hook_wrapper(self, processor: LoopProcessor, inner_hook): def hook(module, inputs, output): - keep = getattr(processor, "current_attention_mask", None) 
+ keep = self._get_processor_mask(processor) # Mask first tensor-like input if it's [B, S, ...] new_inputs = inputs @@ -461,66 +678,27 @@ def loop(self, fail_safe: bool = False, **kwargs): # ---- Start Pre-Quantized Forward ---- fwd_start = time.time() - layer_outputs = [] - for j in range(processor.num_batches): - layer_input = [] - for k, layer_inp in enumerate(layer_inputs[j]): - layer_input.append(move_to(layer_inp, device=cur_layer_device, stream=False)) - - raw_mask = attention_masks[j] - layer_attention_mask = raw_mask if raw_mask is None else move_to(raw_mask, device=cur_layer_device, stream=False) - - # Compute and set keep-mask for this batch - if raw_mask is not None: - seq_len = layer_input[0].shape[1] if (len(layer_input) > 0 and layer_input[0].dim() >= 2) else None - keep_mask_bs = normalize_seq_mask(layer_attention_mask, seq_len=seq_len) - setattr(processor, "current_attention_mask", keep_mask_bs) - else: - setattr(processor, "current_attention_mask", None) - - additional_layer_inputs = {"attention_mask": layer_attention_mask} if self.support_batch_quantize else {} - layer_position_ids = ( - None if not position_ids else move_to(position_ids[j], device=cur_layer_device, stream=False) - ) - if layer_position_ids is not None: - additional_layer_inputs["position_ids"] = layer_position_ids - for k, v in layer_input_kwargs[j].items(): - additional_layer_inputs[k] = nested_move_to(v, device=cur_layer_device, stream=False) - - try: - # Ensure internal buffers (e.g., RoPE caches) are on the layer's device - # _rehome_module_to_device(module, cur_layer_device, move_parameters=False, move_buffers=True) - - # Acquire read lock so auto-GC cannot run while we forward - with self.pool.read_lock(cur_layer_device): - with _device_ctx(cur_layer_device): - # reuse_kv special-case - if hasattr(module, "reuse_kv") and module.reuse_kv: - additional_layer_inputs["kv_last_layer"] = shared_kv_cache_dict.get(layer_index - 1) - layer_output = module(*layer_input) if is_lm_head_module else module(*layer_input, **additional_layer_inputs) - if shared_kv_cache_dict.get(layer_index) is None: - shared_kv_cache_dict[layer_index] = layer_output[-1] - else: - layer_output = module(*layer_input) if is_lm_head_module else module(*layer_input, **additional_layer_inputs) - except StopForward: - pass - finally: - setattr(processor, "current_attention_mask", None) - del layer_input - del additional_layer_inputs - - if not processor.fwd_after_process: - if isinstance(layer_output, tuple): - layer_outputs.append([layer_output[0]]) - else: - layer_outputs.append([layer_output]) - - if not processor.fwd_after_process: - processor.receive_layer_inputs(layer_outputs) - del layer_outputs - - fwd_end = time.time() - fwd_time = fwd_end - fwd_start + need_outputs = not processor.fwd_after_process + reuse_kv = bool(getattr(module, "reuse_kv", False)) + forward_outputs = self._run_forward_batches( + module=module, + processor=processor, + layer_inputs=layer_inputs, + layer_input_kwargs=layer_input_kwargs, + position_ids=position_ids, + attention_masks=attention_masks, + cur_layer_device=cur_layer_device, + is_lm_head_module=is_lm_head_module, + shared_kv_cache_dict=shared_kv_cache_dict, + layer_index=layer_index, + need_outputs=need_outputs, + reuse_kv=reuse_kv, + ) + if need_outputs: + processor.receive_layer_inputs(forward_outputs) + del forward_outputs + + fwd_time = time.time() - fwd_start processor.set_fwd_time(fwd_time) for h in handle: @@ -565,58 +743,23 @@ def _process_on_worker(proc: LoopProcessor, nm: NamedModule): # 
---- End Process Hook ---- is_last_module = layer_index == len(quant_modules_pb) - 1 + layer_outputs: List[List[torch.Tensor]] = [] # second forward after process() if not is_last_module and processor.fwd_after_process: - layer_outputs = [] - for j in range(processor.num_batches): - layer_input = [] - for k, layer_inp in enumerate(layer_inputs[j]): - layer_input.append(move_to(layer_inp, device=cur_layer_device)) - - raw_mask = attention_masks[j] - layer_attention_mask = raw_mask if raw_mask is None else move_to(raw_mask, device=cur_layer_device) - - if raw_mask is not None: - seq_len = layer_input[0].shape[1] if (len(layer_input) > 0 and layer_input[0].dim() >= 2) else None - keep_mask_bs = normalize_seq_mask(layer_attention_mask, seq_len=seq_len) - setattr(processor, "current_attention_mask", keep_mask_bs) - else: - setattr(processor, "current_attention_mask", None) - - additional_layer_inputs = {"attention_mask": layer_attention_mask} if self.support_batch_quantize else {} - layer_position_ids = None if not position_ids else move_to(position_ids[j], device=cur_layer_device) - if layer_position_ids is not None: - additional_layer_inputs["position_ids"] = layer_position_ids - for k, v in layer_input_kwargs[j].items(): - additional_layer_inputs[k] = nested_move_to(v, device=cur_layer_device) - - # Rehome buffers again in case module ran on a different device previously - _rehome_module_to_device(module, cur_layer_device, move_parameters=False, move_buffers=True) - - # Guard forward with read lock to block auto-GC - with self.pool.read_lock(cur_layer_device): - with _device_ctx(cur_layer_device): - if is_lm_head_module: - module_output = module(*layer_input) - else: - module_output = module(*layer_input, **additional_layer_inputs) - - if isinstance(module_output, tuple): - layer_output = module_output[0] - else: - layer_output = module_output - - layer_output = move_to( - layer_output, - device=cur_layer_device, - ) - - layer_outputs.append([layer_output]) - - setattr(processor, "current_attention_mask", None) - - del layer_input - del additional_layer_inputs + layer_outputs = self._run_forward_batches( + module=module, + processor=processor, + layer_inputs=layer_inputs, + layer_input_kwargs=layer_input_kwargs, + position_ids=position_ids, + attention_masks=attention_masks, + cur_layer_device=cur_layer_device, + is_lm_head_module=is_lm_head_module, + shared_kv_cache_dict=shared_kv_cache_dict, + layer_index=layer_index, + need_outputs=True, + reuse_kv=False, + ) # Finalize module after last processor if p_index == len(self.processors) - 1: From 6c64d4e4e0bf76535a59d56ece683232eb8f0607 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Tue, 30 Sep 2025 02:53:57 +0000 Subject: [PATCH 02/12] fix clone and hidden_states Signed-off-by: Qubitium --- gptqmodel/looper/module_looper.py | 144 +++++++++++++++++++++++++++--- 1 file changed, 132 insertions(+), 12 deletions(-) diff --git a/gptqmodel/looper/module_looper.py b/gptqmodel/looper/module_looper.py index bbe916fb1..4392243f8 100644 --- a/gptqmodel/looper/module_looper.py +++ b/gptqmodel/looper/module_looper.py @@ -137,7 +137,7 @@ def __init__(self, model: BaseQModel, processors: List[LoopProcessor]): "mps": 8, # unified memory "cpu": 8, # unified memory }, - empty_cache_every_n=14, # disable auto GC during quant loops; enable if you want + empty_cache_every_n=28, # disable auto GC during quant loops; enable if you want ) for processor in self.processors: @@ -175,21 +175,16 @@ def _select_forward_devices(self, base_device: Optional[torch.device]) -> 
List[t def _clone_module_for_devices(self, module: torch.nn.Module, devices: List[torch.device]) -> Dict[torch.device, torch.nn.Module]: clones: Dict[torch.device, torch.nn.Module] = {} - base_device = get_device(module) cleared_attrs = self._clear_non_picklable_state(module) try: for dev in devices: - if base_device is not None and dev == base_device: - clones[dev] = module - _rehome_module_to_device(module, dev, move_parameters=False, move_buffers=True) - else: - replica = copy.deepcopy(module) - replica = replica.to(dev) - replica.eval() - _rehome_module_to_device(replica, dev, move_parameters=False, move_buffers=True) - self._clear_non_picklable_state(replica) - clones[dev] = replica + replica = copy.deepcopy(module) + replica = replica.to(dev) + replica.eval() + _rehome_module_to_device(replica, dev, move_parameters=False, move_buffers=True) + self._clear_non_picklable_state(replica) + clones[dev] = replica finally: self._restore_non_picklable_state(cleared_attrs) return clones @@ -302,6 +297,131 @@ def _run_forward_batches( reuse_kv: bool, ) -> List[List[torch.Tensor]]: devices = self._select_forward_devices(cur_layer_device) + + if len(devices) <= 1: + return self._run_forward_batches_single( + module=module, + processor=processor, + layer_inputs=layer_inputs, + layer_input_kwargs=layer_input_kwargs, + position_ids=position_ids, + attention_masks=attention_masks, + cur_layer_device=cur_layer_device, + is_lm_head_module=is_lm_head_module, + shared_kv_cache_dict=shared_kv_cache_dict, + layer_index=layer_index, + need_outputs=need_outputs, + reuse_kv=reuse_kv, + ) + + return self._run_forward_batches_parallel( + module=module, + processor=processor, + layer_inputs=layer_inputs, + layer_input_kwargs=layer_input_kwargs, + position_ids=position_ids, + attention_masks=attention_masks, + cur_layer_device=cur_layer_device, + is_lm_head_module=is_lm_head_module, + shared_kv_cache_dict=shared_kv_cache_dict, + layer_index=layer_index, + need_outputs=need_outputs, + reuse_kv=reuse_kv, + devices=devices, + ) + + def _run_forward_batches_single( + self, + *, + module: torch.nn.Module, + processor: LoopProcessor, + layer_inputs: List[List[torch.Tensor]], + layer_input_kwargs: List[Dict[str, torch.Tensor]], + position_ids: List[torch.Tensor], + attention_masks: List[torch.Tensor], + cur_layer_device: torch.device, + is_lm_head_module: bool, + shared_kv_cache_dict: Dict[int, torch.Tensor], + layer_index: int, + need_outputs: bool, + reuse_kv: bool, + ) -> List[List[torch.Tensor]]: + outputs: List[List[torch.Tensor]] = [] + prev_kv = shared_kv_cache_dict.get(layer_index - 1) if reuse_kv else None + + for batch_idx in range(processor.num_batches): + layer_input = [move_to(inp, device=cur_layer_device, stream=False) for inp in layer_inputs[batch_idx]] + + raw_mask = attention_masks[batch_idx] + attn_tensor = raw_mask if raw_mask is None else move_to(raw_mask, device=cur_layer_device, stream=False) + + keep_mask = None + if attn_tensor is not None: + seq_len = layer_input[0].shape[1] if (len(layer_input) > 0 and layer_input[0].dim() >= 2) else None + keep_mask = normalize_seq_mask(attn_tensor, seq_len=seq_len) + self._set_processor_mask(processor, keep_mask) + + additional_inputs: Dict[str, torch.Tensor] = {} + if self.support_batch_quantize and attn_tensor is not None: + additional_inputs["attention_mask"] = attn_tensor + + if position_ids: + pos = position_ids[batch_idx] + if pos is not None: + additional_inputs["position_ids"] = move_to(pos, device=cur_layer_device, stream=False) + + for key, value in 
layer_input_kwargs[batch_idx].items(): + additional_inputs[key] = nested_move_to(value, device=cur_layer_device, stream=False) + + if reuse_kv and prev_kv is not None: + additional_inputs["kv_last_layer"] = nested_move_to(prev_kv, device=cur_layer_device, stream=False) + + _rehome_module_to_device(module, cur_layer_device, move_parameters=False, move_buffers=True) + + module_output = None + try: + if is_lm_head_module: + module_output = module(*layer_input) + else: + module_output = module(*layer_input, **additional_inputs) + except StopForward: + module_output = None + finally: + self._set_processor_mask(processor, None) + + if ( + reuse_kv + and module_output is not None + and isinstance(module_output, tuple) + and len(module_output) > 0 + and shared_kv_cache_dict.get(layer_index) is None + ): + shared_kv_cache_dict[layer_index] = module_output[-1] + + if need_outputs and module_output is not None: + primary = module_output[0] if isinstance(module_output, tuple) else module_output + primary = move_to(primary, device=cur_layer_device, stream=False) + outputs.append([primary]) + + return outputs + + def _run_forward_batches_parallel( + self, + *, + module: torch.nn.Module, + processor: LoopProcessor, + layer_inputs: List[List[torch.Tensor]], + layer_input_kwargs: List[Dict[str, torch.Tensor]], + position_ids: List[torch.Tensor], + attention_masks: List[torch.Tensor], + cur_layer_device: torch.device, + is_lm_head_module: bool, + shared_kv_cache_dict: Dict[int, torch.Tensor], + layer_index: int, + need_outputs: bool, + reuse_kv: bool, + devices: List[torch.device], + ) -> List[List[torch.Tensor]]: module_replicas = self._clone_module_for_devices(module, devices) prev_kv = shared_kv_cache_dict.get(layer_index - 1) if reuse_kv else None From 309b7fa4fbff4345e01be5fb39b1216d69fbad6a Mon Sep 17 00:00:00 2001 From: Qubitium Date: Tue, 30 Sep 2025 03:07:15 +0000 Subject: [PATCH 03/12] allow blocks to disable auto gc Signed-off-by: Qubitium --- gptqmodel/utils/threadx.py | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/gptqmodel/utils/threadx.py b/gptqmodel/utils/threadx.py index 04bbdd8ac..bb2e768d9 100644 --- a/gptqmodel/utils/threadx.py +++ b/gptqmodel/utils/threadx.py @@ -17,6 +17,7 @@ from ..utils.logger import setup_logger + log = setup_logger() # Debug logging is very chatty and can alter timings subtly in tests. @@ -458,6 +459,10 @@ def __init__( self._stop_event = threading.Event() self._janitor: Optional[threading.Thread] = None + # Auto-GC disable tracking (allows latency-sensitive regions to pause janitor) + self._auto_gc_disable_count = 0 + self._auto_gc_disable_cv = threading.Condition() + # In-flight (scheduled but not finished) counters + per-device CVs. # Each device has a condition variable to let wait() callers block # until inflight hits zero for that device scope. @@ -630,6 +635,28 @@ def shutdown(self, wait: bool = True): if DEBUG_ON: log.debug("DeviceThreadPool shutdown complete") + @contextlib.contextmanager + def no_auto_gc(self): + """ + Temporarily disable automatic empty-cache passes. Useful for latency-sensitive + critical sections (e.g., forwarding) where janitor interference is undesirable. 
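+
+        Illustrative usage (assumes a pool instance named ``pool``):
+
+            with pool.no_auto_gc():
+                output = module(*inputs)  # janitor empty-cache passes stay paused here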
+ """ + with self._auto_gc_disable_cv: + self._auto_gc_disable_count += 1 + try: + yield + finally: + should_signal = False + with self._auto_gc_disable_cv: + if self._auto_gc_disable_count > 0: + self._auto_gc_disable_count -= 1 + if self._auto_gc_disable_count == 0: + should_signal = True + self._auto_gc_disable_cv.notify_all() + if should_signal: + # Wake janitor in case a trigger is pending. + self._gc_event.set() + # --------------- Public Lock API --------------- def device_lock(self, device: DeviceLike): @@ -1137,6 +1164,15 @@ def _janitor_loop(self): self._gc_event.clear() if DEBUG_ON: log.debug("DP-Janitor: debounce window end") + with self._auto_gc_disable_cv: + while self._auto_gc_disable_count > 0 and not self._stop_event.is_set(): + if DEBUG_ON: + log.debug("DP-Janitor: auto-GC disabled; waiting…") + self._auto_gc_disable_cv.wait(timeout=WAIT_TIMEOUT) + if self._stop_event.is_set(): + if DEBUG_ON: log.debug("DP-Janitor: stop event set during auto-GC wait; exiting") + break + # Snapshot & decision try: pre = self._collect_state_snapshot() From 60ccdc20753710cadc6685efb87c534747d61a72 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Tue, 30 Sep 2025 03:28:50 +0000 Subject: [PATCH 04/12] async turtle reload Signed-off-by: Qubitium --- gptqmodel/looper/module_looper.py | 6 +- gptqmodel/models/base.py | 132 +++++++++++++++++++++++++++--- 2 files changed, 127 insertions(+), 11 deletions(-) diff --git a/gptqmodel/looper/module_looper.py b/gptqmodel/looper/module_looper.py index 4392243f8..955e951b4 100644 --- a/gptqmodel/looper/module_looper.py +++ b/gptqmodel/looper/module_looper.py @@ -137,9 +137,11 @@ def __init__(self, model: BaseQModel, processors: List[LoopProcessor]): "mps": 8, # unified memory "cpu": 8, # unified memory }, - empty_cache_every_n=28, # disable auto GC during quant loops; enable if you want + empty_cache_every_n=0, # disable auto GC during quant loops; enable if you want ) + self.gptq_model.register_background_pool(self.pool) + for processor in self.processors: self._processor_mask_tls(processor) @@ -931,6 +933,8 @@ def finalize_module(process, module): # Ensure ANY remaining tasks the looper submitted have drained self.pool.wait() # same as wait('all') + self.gptq_model.wait_for_turtle_reload() + # paranoid safety check # torch_sync() # torch_sync(device=CPU) diff --git a/gptqmodel/models/base.py b/gptqmodel/models/base.py index 1e503b253..3af193620 100644 --- a/gptqmodel/models/base.py +++ b/gptqmodel/models/base.py @@ -8,8 +8,10 @@ import json import os import random +import threading from collections import defaultdict -from typing import Any, Dict, List, Optional, Type, Union +from concurrent.futures import Future +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, Union import torch import torch._dynamo @@ -51,6 +53,9 @@ from .loader import ModelLoader from .writer import ModelWriter +if TYPE_CHECKING: + from ..utils.threadx import DeviceThreadPool + class _ClassPropertyDescriptor: def __init__(self, fget, fset=None): @@ -198,6 +203,10 @@ def __init__( self.quantize_config = quantize_config + self._background_pool: Optional["DeviceThreadPool"] = None + self._turtle_reload_future: Optional[Future] = None + self._turtle_reload_lock = threading.Lock() + # compat: state to assist in checkpoint_format gptq(v1) to gptq_v2 conversion self.qlinear_kernel = qlinear_kernel self.trust_remote_code = trust_remote_code @@ -1073,6 +1082,114 @@ def format_nodes(nodes): # print("DEBUG AWQ NODES:", format_nodes(nodes)) return nodes + def 
register_background_pool(self, pool: Optional["DeviceThreadPool"]) -> None: + self._background_pool = pool + + def _clone_model_init_kwargs(self, source: PreTrainedModel) -> Dict[str, Any]: + kwargs = getattr(source, "_model_init_kwargs", {}) or {} + if isinstance(kwargs, dict): + return dict(kwargs) + return copy.deepcopy(kwargs) + + def _reload_turtle_model_sync(self) -> Optional[PreTrainedModel]: + if self.turtle_model is None or self.model_local_path is None: + return self.turtle_model + + reload_kwargs = self._clone_model_init_kwargs(self.turtle_model) + config = self.turtle_model.config + + new_model = self.loader.from_pretrained( + self.model_local_path, + config=config, + low_cpu_mem_usage=True, + **reload_kwargs, + ) + new_model._model_init_kwargs = reload_kwargs + return new_model + + def _schedule_turtle_reload(self) -> None: + self._apply_completed_turtle_reload() + + if self.turtle_model is None or self.model_local_path is None: + return + + pool = self._background_pool + if pool is None: + new_model = self._reload_turtle_model_sync() + if new_model is not None: + self.turtle_model = new_model + return + + with self._turtle_reload_lock: + future = self._turtle_reload_future + if future is not None and not future.done(): + return + + reload_kwargs = self._clone_model_init_kwargs(self.turtle_model) + config = self.turtle_model.config + model_local_path = self.model_local_path + loader = self.loader + + def _reload_task(): + model = loader.from_pretrained( + model_local_path, + config=config, + low_cpu_mem_usage=True, + **reload_kwargs, + ) + model._model_init_kwargs = reload_kwargs + return model + + try: + self._turtle_reload_future = pool.submit(CPU, _reload_task) + except Exception as exc: + log.warning("Turtle reload scheduling failed; falling back to sync reload: %s", exc) + self._turtle_reload_future = None + + if self._turtle_reload_future is None: + new_model = self._reload_turtle_model_sync() + if new_model is not None: + self.turtle_model = new_model + + def _apply_completed_turtle_reload(self, *, wait: bool = False) -> None: + """Adopt the result of any finished background turtle reload.""" + future: Optional[Future] + with self._turtle_reload_lock: + future = self._turtle_reload_future + + if future is None: + return + + if wait and not future.done(): + try: + future.result() + except Exception: + # result() already logs in the calling path + pass + + if not future.done(): + return + + try: + new_model = future.result() + except Exception as exc: + log.warning("Background turtle reload failed; retrying synchronously: %s", exc) + new_model = self._reload_turtle_model_sync() + with self._turtle_reload_lock: + if self._turtle_reload_future is future: + self._turtle_reload_future = None + if new_model is not None: + self.turtle_model = new_model + return + + with self._turtle_reload_lock: + if self._turtle_reload_future is future: + self.turtle_model = new_model + self._turtle_reload_future = None + + def wait_for_turtle_reload(self) -> None: + self._apply_completed_turtle_reload(wait=True) + # transfer actually materizlied module from turtle (real) to shell def shell_module_materialize( self, @@ -1080,6 +1197,8 @@ def shell_module_materialize( device: torch.device, non_blocking: bool = False, ) -> torch.nn.Module: + self._apply_completed_turtle_reload() + if self.turtle_model is None: if get_device(target_submodule) != device: target_submodule.to(device) @@ -1090,16 +1209,9 @@ def shell_module_materialize( target_model=self.model, turtle_model=self.turtle_model, 
target_submodule=target_submodule, - device=self.quantize_config.device, + device=device, ) - - # reload turle - # FIX ME..need trust remote true - model_init_kwargs = self.turtle_model._model_init_kwargs - self.turtle_model = self.loader.from_pretrained(self.model_local_path, config=self.turtle_model.config, low_cpu_mem_usage=True, **model_init_kwargs) - self.turtle_model._model_init_kwargs = model_init_kwargs - - # gc.collect() + self._schedule_turtle_reload() return module ## overrides nn.module.train() From 1c18b46ec63bcbf4c49a3fb86b5c58dac1337dee Mon Sep 17 00:00:00 2001 From: Qubitium Date: Tue, 30 Sep 2025 03:43:20 +0000 Subject: [PATCH 05/12] fix propagation of hidden states Signed-off-by: Qubitium --- gptqmodel/looper/module_looper.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/gptqmodel/looper/module_looper.py b/gptqmodel/looper/module_looper.py index 955e951b4..037d61c9e 100644 --- a/gptqmodel/looper/module_looper.py +++ b/gptqmodel/looper/module_looper.py @@ -818,6 +818,7 @@ def loop(self, fail_safe: bool = False, **kwargs): ) if need_outputs: processor.receive_layer_inputs(forward_outputs) + layer_inputs = processor.inputs_cache.layer_inputs del forward_outputs fwd_time = time.time() - fwd_start @@ -895,6 +896,7 @@ def _process_on_worker(proc: LoopProcessor, nm: NamedModule): if processor.fwd_after_process: processor.clear_cache_data() processor.receive_layer_inputs(layer_outputs) + layer_inputs = processor.inputs_cache.layer_inputs if p_index == len(self.processors) - 1: torch_sync() From 71ecddfc59e1cfad6f30f43b2078d924dfbefb2f Mon Sep 17 00:00:00 2001 From: Qubitium Date: Tue, 30 Sep 2025 03:55:57 +0000 Subject: [PATCH 06/12] fix forwarding cannot be sharded across same device:index Signed-off-by: Qubitium --- gptqmodel/looper/module_looper.py | 4 +++- gptqmodel/utils/threadx.py | 33 +++++++++++++++++++++++++++++++ 2 files changed, 36 insertions(+), 1 deletion(-) diff --git a/gptqmodel/looper/module_looper.py b/gptqmodel/looper/module_looper.py index 037d61c9e..9a970efd2 100644 --- a/gptqmodel/looper/module_looper.py +++ b/gptqmodel/looper/module_looper.py @@ -438,8 +438,10 @@ def _run_forward_batches_parallel( for offset, batch_idx in enumerate(range(start, end)): device = devices[offset] replica = module_replicas[device] + submitter = self.pool.submit_serial if device.type in ("cuda", "xpu", "mps") else self.pool.submit + futures.append( - self.pool.submit( + submitter( device, self._forward_batch_worker, replica, diff --git a/gptqmodel/utils/threadx.py b/gptqmodel/utils/threadx.py index bb2e768d9..2a7a1899f 100644 --- a/gptqmodel/utils/threadx.py +++ b/gptqmodel/utils/threadx.py @@ -448,6 +448,7 @@ def __init__( self._worker_groups: Dict[str, List[_DeviceWorker]] = {} self._dispatch_rr: Dict[str, int] = {} self._dispatch_lock = threading.Lock() + self._serial_workers: Dict[str, _DeviceWorker] = {} # Stats / GC / inflight control self._stats_lock = threading.Lock() @@ -501,6 +502,8 @@ def __init__( group.append(worker) self._worker_groups[key] = group self._dispatch_rr[key] = 0 + if group: + self._serial_workers[key] = group[0] # A canonical ordering for multi-device lock acquisitions. self._ordered_keys = sorted(self._locks.keys()) @@ -586,6 +589,36 @@ def submit( self._mark_finished(key) raise + def submit_serial( + self, + device: DeviceLike, + fn: Callable[..., Any], + /, + *args, + _cuda_stream: Optional[torch.cuda.Stream] = None, + **kwargs, + ) -> Future: + """ + Schedule work that must execute sequentially on a device. 
Tasks are + enqueued onto a dedicated worker so they run in submission order. + """ + dev = _coerce_device(device) + key = self._key(dev) + if _cuda_stream is not None and dev.type != "cuda": + raise ValueError("_cuda_stream is only valid for CUDA devices") + + worker = self._serial_workers.get(key) + if worker is None: + raise ValueError(f"No serial worker available for device '{key}'") + + if DEBUG_ON: log.debug(f"submit_serial: device={key} fn={getattr(fn, '__name__', repr(fn))}") + self._mark_scheduled(key) + try: + return worker.submit(fn, *args, _cuda_stream=_cuda_stream, **kwargs) + except BaseException: + self._mark_finished(key) + raise + def do( self, device: DeviceLike, From 1e826516805b7773eb6d49e9d1b128465a1870a7 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Tue, 30 Sep 2025 04:10:18 +0000 Subject: [PATCH 07/12] dedup Signed-off-by: Qubitium --- gptqmodel/looper/module_looper.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/gptqmodel/looper/module_looper.py b/gptqmodel/looper/module_looper.py index 9a970efd2..4e9507b4d 100644 --- a/gptqmodel/looper/module_looper.py +++ b/gptqmodel/looper/module_looper.py @@ -47,11 +47,10 @@ def _device_ctx(dev: Optional[torch.device]): if dev is None: yield else: - dtyp = getattr(dev, "type", None) - if dtyp == "cuda": + if dev.type == "cuda": with torch.cuda.device(dev.index): yield - elif dtyp == "xpu" and hasattr(torch, "xpu"): + elif dev.type == "xpu" and hasattr(torch, "xpu"): with torch.xpu.device(dev.index): # type: ignore[attr-defined] yield else: @@ -168,10 +167,10 @@ def _select_forward_devices(self, base_device: Optional[torch.device]) -> List[t return [CPU] devices = [base_device] - base_type = getattr(base_device, "type", None) + base_type = base_device.type if base_type in ("cuda", "xpu", "mps"): for dev in ALL_DEVICES: - if getattr(dev, "type", None) == base_type and dev not in devices: + if dev.type == base_type and dev not in devices: devices.append(dev) return devices From 18d30014d32e0581110f00592c432b7986dcde75 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Tue, 30 Sep 2025 04:17:55 +0000 Subject: [PATCH 08/12] add some docs Signed-off-by: Qubitium --- gptqmodel/looper/module_looper.py | 45 ++++++++++++++++++++++++++----- 1 file changed, 39 insertions(+), 6 deletions(-) diff --git a/gptqmodel/looper/module_looper.py b/gptqmodel/looper/module_looper.py index 4e9507b4d..116646c58 100644 --- a/gptqmodel/looper/module_looper.py +++ b/gptqmodel/looper/module_looper.py @@ -3,6 +3,16 @@ # SPDX-License-Identifier: Apache-2.0 # Contact: qubitium@modelcloud.ai, x.com/qubitium +"""Utilities for orchestrating the quantisation loop across multiple devices. + +ModuleLooper is the high-level coordinator that fans calibration batches across +the available accelerators, runs each processing stage, and keeps the shell and +turtle model state coherent. The implementation mixes synchronous orchestration +with asynchronous workers, so the helpers below focus on keeping device context +consistent and ensuring data dependencies survive the roundtrips through the +thread pool. +""" + from __future__ import annotations import copy @@ -119,15 +129,23 @@ def _rehome_module_to_device( class ModuleLooper(): + """Drive the per-layer quantisation workflow over one or more devices. + + The looper owns a :class:`DeviceThreadPool` that executes CPU and accelerator + work. 
Forward passes can be replicated across devices, processors can enqueue + asynchronous tasks, and the class handles the bookkeeping required to stitch + the results back into a sequential quantisation order. + """ def __init__(self, model: BaseQModel, processors: List[LoopProcessor]): self.processors = processors self.gptq_model = model self.support_batch_quantize = model.support_batch_quantize self.lock = threading.Lock() - # Create a single pool for the entire looper lifecycle. - # Eagerly discovers devices and pins worker threads per device. - # Tune worker counts here if desired (example policy shown). + # The looper shares one pool for its lifetime so tasks such as module + # reloading, forward passes and finalisation reuse the same worker + # threads. The first worker per device is treated as the serial lane for + # forward execution; additional workers handle background jobs. self.pool = DeviceThreadPool( inference_mode=True, workers={ @@ -144,9 +162,8 @@ def __init__(self, model: BaseQModel, processors: List[LoopProcessor]): for processor in self.processors: self._processor_mask_tls(processor) - # NEW: Wrap an existing hook so its inputs/outputs are pre-masked for GPTQ stats. - # We *do not* alter the module's actual computation; only what the hook - # passes down to the processor capture path is masked. + # Processors capture activations through hooks that need thread-local state + # so masks survive the roundtrip to worker threads. def _processor_mask_tls(self, processor: LoopProcessor) -> threading.local: tls = getattr(processor, "_mask_tls", None) if tls is None: @@ -231,6 +248,13 @@ def _forward_batch_worker( reuse_kv: bool, prev_kv, ): + """Run one forward micro-batch on a pool worker and return its output. + + The worker receives pre-moved inputs, executes the module on its bound + device and ships back both the model outputs and any next-layer KV cache + update. The thin signature keeps the function pickleable for the worker + queue. + """ module_device = get_device(module) _rehome_module_to_device(module, module_device, move_parameters=False, move_buffers=True) @@ -297,6 +321,13 @@ def _run_forward_batches( need_outputs: bool, reuse_kv: bool, ) -> List[List[torch.Tensor]]: + """Dispatch the captured layer inputs through the module. + + When multiple accelerators of the same type are available we clone the + module and execute batches in parallel, otherwise we fall back to a + single threaded path. The helper returns the ordered outputs that feed + the next processor stage when ``need_outputs`` is set. 
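+
+        Dispatch sketch (illustrative numbers): with two same-type accelerators
+        and five batches, work is issued in waves of ``len(devices)`` batches,
+        i.e. (0, 1), then (2, 3), then (4,), and per-batch outputs are
+        re-ordered by batch index before being handed to the next stage.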
+ """ devices = self._select_forward_devices(cur_layer_device) if len(devices) <= 1: @@ -347,6 +378,7 @@ def _run_forward_batches_single( need_outputs: bool, reuse_kv: bool, ) -> List[List[torch.Tensor]]: + """Sequential fallback when only one forward device is in use.""" outputs: List[List[torch.Tensor]] = [] prev_kv = shared_kv_cache_dict.get(layer_index - 1) if reuse_kv else None @@ -423,6 +455,7 @@ def _run_forward_batches_parallel( reuse_kv: bool, devices: List[torch.device], ) -> List[List[torch.Tensor]]: + """Fan batches across device clones and preserve result ordering.""" module_replicas = self._clone_module_for_devices(module, devices) prev_kv = shared_kv_cache_dict.get(layer_index - 1) if reuse_kv else None From c2f5b3c2623830c09331d153c10c5eff5cc4ed7e Mon Sep 17 00:00:00 2001 From: Qubitium Date: Tue, 30 Sep 2025 04:35:55 +0000 Subject: [PATCH 09/12] fix device override Signed-off-by: Qubitium --- gptqmodel/looper/module_looper.py | 6 +++--- gptqmodel/models/_const.py | 20 ++++++++++++++++++++ 2 files changed, 23 insertions(+), 3 deletions(-) diff --git a/gptqmodel/looper/module_looper.py b/gptqmodel/looper/module_looper.py index 116646c58..1f4b310f8 100644 --- a/gptqmodel/looper/module_looper.py +++ b/gptqmodel/looper/module_looper.py @@ -30,7 +30,7 @@ from ..looper.loop_processor import LoopProcessor from ..looper.named_module import NamedModule from ..models import BaseQModel -from ..models._const import SUPPORTS_MODULE_TYPES +from ..models._const import SUPPORTS_MODULE_TYPES, DEVICE from ..nn_modules.hooked_linear import (STOP_FORWARD_EXCEPTION, HookedLinear, StopForward, replace_module_with_hooked_legacy) from ..utils.attn_mask import apply_keep_mask_bt, normalize_seq_mask @@ -49,7 +49,7 @@ # -------------------- Device helpers (local) -------------------- @contextmanager -def _device_ctx(dev: Optional[torch.device]): +def _device_ctx(dev: Optional[torch.device|DEVICE]): """ Ensure the caller thread’s current device matches `dev` for the duration of the context (CUDA/XPU). Prevents cuBLAS/cuDNN handle/device mismatches in multi-GPU. 
@@ -154,7 +154,7 @@ def __init__(self, model: BaseQModel, processors: List[LoopProcessor]): "mps": 8, # unified memory "cpu": 8, # unified memory }, - empty_cache_every_n=0, # disable auto GC during quant loops; enable if you want + empty_cache_every_n=64, # disable auto GC during quant loops; enable if you want ) self.gptq_model.register_background_pool(self.pool) diff --git a/gptqmodel/models/_const.py b/gptqmodel/models/_const.py index c56342cf2..1e2318c91 100644 --- a/gptqmodel/models/_const.py +++ b/gptqmodel/models/_const.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai # SPDX-License-Identifier: Apache-2.0 # Contact: qubitium@modelcloud.ai, x.com/qubitium +from __future__ import annotations from enum import Enum @@ -43,6 +44,25 @@ def _missing_(cls, value): return cls.ROCM return super()._missing_(value) + @property + def type(self) -> str: + """Return the backend type compatible with torch.device semantics.""" + if self == DEVICE.ROCM: + return "cuda" + return str(self) + + @property + def index(self) -> int | None: + """Default index used when materialising a torch.device from this enum.""" + if self in (DEVICE.CUDA, DEVICE.ROCM, DEVICE.XPU): + return 0 + return None + + def to_torch_device(self) -> torch.device: + """Convert the enum to a concrete torch.device, defaulting to index 0.""" + idx = self.index + return torch.device(self.type if idx is None else f"{self.type}:{idx}") + def to_device_map(self): return {"": DEVICE.CUDA if self == DEVICE.ROCM else self} From 8f044a262e348785ab0f8bc154a33303df0f7a89 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Tue, 30 Sep 2025 04:40:14 +0000 Subject: [PATCH 10/12] increase gc interval Signed-off-by: Qubitium --- gptqmodel/looper/module_looper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gptqmodel/looper/module_looper.py b/gptqmodel/looper/module_looper.py index 1f4b310f8..3b8b08bb8 100644 --- a/gptqmodel/looper/module_looper.py +++ b/gptqmodel/looper/module_looper.py @@ -154,7 +154,7 @@ def __init__(self, model: BaseQModel, processors: List[LoopProcessor]): "mps": 8, # unified memory "cpu": 8, # unified memory }, - empty_cache_every_n=64, # disable auto GC during quant loops; enable if you want + empty_cache_every_n=1024, # disable auto gc based gpu work rate; enable if you want ) self.gptq_model.register_background_pool(self.pool) From 6c3f7bbd0921c60c3a4450cbd249871ab873ed1a Mon Sep 17 00:00:00 2001 From: Qubitium Date: Tue, 30 Sep 2025 04:41:57 +0000 Subject: [PATCH 11/12] format Signed-off-by: Qubitium --- gptqmodel/models/base.py | 1 + 1 file changed, 1 insertion(+) diff --git a/gptqmodel/models/base.py b/gptqmodel/models/base.py index 3af193620..617bddd64 100644 --- a/gptqmodel/models/base.py +++ b/gptqmodel/models/base.py @@ -53,6 +53,7 @@ from .loader import ModelLoader from .writer import ModelWriter + if TYPE_CHECKING: from ..utils.threadx import DeviceThreadPool From a68ea0a06faed9e0ec67520f0647edeb668f3963 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Tue, 30 Sep 2025 04:48:37 +0000 Subject: [PATCH 12/12] readme changes Signed-off-by: Qubitium --- README.md | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 0afaee0d7..7be7ae956 100644 --- a/README.md +++ b/README.md @@ -17,6 +17,7 @@

 ## Latest News
+* 09/30/2025 5.0.0-dev `main`: 👀 New Data Parallel + Multi-GPU + Python 3.13t (PYTHON_GIL=0) delivers 80%+ overall quant time reduction for large MoE models vs v4.2.5.
 * 09/29/2025 5.0.0-dev `main`: 🎉 New Qwen3 Omni model support. AWQ Marlin kernel integrated + many disk offload, threading, and memory usage fixes.
 * 09/24/2025 5.0.0-dev `main`: 🎉 Up to 90% cpu mem saving for large MoE models with faster/inline packing! 26% quant time reduction for Qwen3 MoE! AWQ Marlin kernel added. AWQ Gemm loading bug fixes. `act_group_aware` now faster and auto enabled for GPTQ when `desc_act` is False for higher quality recovery.
 * 09/19/2025 5.0.0-dev `main`: 👀 Cpu memory saving of ~73.5% during quantization stage with new `offload_to_disk` quantization config property default to `True`.
@@ -152,14 +153,17 @@ Native support support some of the most popular multi-modal models:
 ## Features
 * ✨ Native integration with HF [Transformers](https://github.com/huggingface/transformers), [Optimum](https://github.com/huggingface/optimum), and [Peft (main)](https://github.com/huggingface/peft)
 * 🚀 [vLLM](https://github.com/vllm-project/vllm) and [SGLang](https://github.com/sgl-project/sglang) inference integration for quantized model with format = `FORMAT.GPTQ`
+* ✨ GPTQ, AWQ, and QQQ quantization formats with hardware-accelerated inference kernels.
+* 🚀 Data Parallelism for 80%+ quantization time reduction with Multi-GPU.
+* 🚀 Optimized for Python >= 3.13t (free-threading) with lock-free threading.
 * ✨ Linux, MacOS, Windows platform quantization and accelerated inference support for CUDA (Nvidia), XPU (Intel), ROCm (AMD), MPS (Apple Silicon), CPU (Intel/AMD/Apple Silicon).
-* 💯 100% CI unit-test coverage for all supported models and kernels including post-quantization quality regression.
 * ✨ `Dynamic` mixed quantization control on a per-module basis. Each layer/module can have a unique quantization config or be excluded from quantization all together.
 * 🚀 Intel Torch 2.8 fused kernel support for XPU [`Arc` + `Datacenter Max`] and CPU [`avx`, `amx`, `xmx`].
 * 🚀 Python 3.13.3t (free-threading, GIL disabled) support for multi-gpu accelerated quantization for MoE models and multi-core cpu boost for post-quant packing.
 * ✨ Asymmetric `Sym=False` support. Model weights sharding support with optional hash check of model weights on load.
 * ✨ `lm_head` module quant inference support for further VRAM reduction.
 * 🚀 [Microsoft/BITBLAS](https://github.com/microsoft/BitBLAS) format + dynamically compiled inference.
+* 💯 100% CI unit-test coverage for all supported models and kernels including post-quantization quality regression.
 ## Quality: GPTQ 4bit (5.0 bpw) can match BF16: