38 changes: 18 additions & 20 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,25 +17,28 @@
</p>

## Latest News
* 10/20/2025 [5.0.0](https://github.com/ModelCloud/GPTQModel/releases/tag/v5.0.0): 🎉 Data-parallel quant support for `MoE` models on multi-gpu using `nogil` Python. `offload_to_disk` support enabled by
default to massively reduce `cpu` ram usage. New `Intel` and `AMD` cpu hw accelerated `TorchFused` kernel. Packing stage is now 4x faster and now inlined with quantization. `Vram` pressure for large models reduced during quantization.
`act_group_aware` is 16k+ times faster and now the default when `desc_act=False` for higher quality recovery without inference penalty of `desc_act=True`. New beta quality `AWQ` support with full `gemm`,
`gemm_fast`, `marlin` kernel support. `LFM`, `Ling`, `Qwen3 Omni` model support. Quantization is now faster with reduced vram usage. Enhanced logging support with `LogBar`.
* 09/16/2025 [4.2.5](https://github.com/ModelCloud/GPTQModel/releases/tag/v4.2.5): `hyb_act` renamed to `act_group_aware`. Removed finicky `torch` import within `setup.py`. Packing bug fix and prebuilt Pytorch 2.8 whls.
* 09/12/2025 [4.2.0](https://github.com/ModelCloud/GPTQModel/releases/tag/v4.2.0): ✨ New Models Support: Qwen3-Next, Apertus, Kimi K2, Klear, FastLLM, Nemotron H. New `fail_safe` `boolean` toggle to `.quantize()` to patch-fix non-activated `MoE` modules due to highly uneven MoE model training. Fixed LavaQwen2 compat. Patch fix GIL=0 cuda error for multi-gpu. Fix compat with autoround + new transformers.
* 09/04/2025 [4.1.0](https://github.com/ModelCloud/GPTQModel/releases/tag/v4.1.0): ✨ Meituan LongCat Flash Chat, Llama 4, GPT-OSS (BF16), and GLM-4.5-Air support. New experimental `mock_quantization` config to skip complex computational code paths during quantization to accelerate model quant testing.
* 08/21/2025 [4.0.0](https://github.com/ModelCloud/GPTQModel/releases/tag/v4.0.0): 🎉 New Group Aware Reordering (GAR) support. New models support: Bytedance Seed-OSS, Baidu Ernie, Huawei PanGu, Gemma3, Xiaomi Mimo, Qwen 3/MoE, Falcon H1, GPT-Neo. Memory leak and multiple model compatibility fixes related to Transformers >= 4.54. Python >= 3.13t free-threading support added with near N x GPU linear scaling for quantization of MoE models and also linear N x Cpu Core scaling of packing stage. Early access Pytorch 2.8 fused-ops on Intel XPU for up to 50% speedup.

<details>

<summary>Archived News</summary>
* 10/17/2025 5.0.0-dev `main`: 👀: EoRA now multi-gpu compatible. Fixed both quality stability of multi-gpu quanta and vram usage. New LFM and Ling models support.
* 09/30/2025 5.0.0-dev `main`: 👀: New Data Parallel + Multi-GPU + Python 3.13T (PYTHON_GIL=0) equals 80%+ overall quant time reduction of large MoE models vs v4.2.5.
* 09/29/2025 5.0.0-dev `main`: 🎉 New Qwen3 Omni model support. AWQ Marlin kernel integrated + many disk offload, threading, and memory usage fixes.
* 09/24/2025 5.0.0-dev `main`: 🎉 Up to 90% cpu mem saving for large MoE models with faster/inline packing! 26% quant time reduction for Qwen3 MoE! AWQ Marlin kernel added. AWQ Gemm loading bug fixes. `act_group_aware` now faster and auto enabled for GPTQ when `desc_act` is False for higher quality recovery.
* 09/19/2025 5.0.0-dev `main`: 👀 Cpu memory saving of ~73.5% during quantization stage with new `offload_to_disk` quantization config property default to `True`.
* 09/18/2025 5.0.0-dev `main`: 🎉 AWQ quantization support! Complete refactor and simplification of model definitions in preparation for future quantization formats.
* 09/16/2025 [4.2.5](https://github.com/ModelCloud/GPTQModel/releases/tag/v4.2.5): `hyb_act` renamed to `act_group_aware`. Removed finicky `torch` import within `setup.py`. Packing bug fix and prebuilt Pytorch 2.8 whls.
* 09/12/2025 [4.2.0](https://github.com/ModelCloud/GPTQModel/releases/tag/v4.2.0): ✨ New Models Support: Qwen3-Next, Apertus, Kimi K2, Klear, FastLLM, Nemotron H. New `fail_safe` `boolean` toggle to `.quantize()` to patch-fix non-activated `MoE` modules due to highly uneven MoE model training. Fixed LavaQwen2 compat. Patch fix GIL=0 cuda error for multi-gpu. Fix compat with autoround + new transformers.
* 09/04/2025 [4.1.0](https://github.com/ModelCloud/GPTQModel/releases/tag/v4.1.0): ✨ Meituan LongCat Flash Chat, Llama 4, GPT-OSS (BF16), and GLM-4.5-Air support. New experimental `mock_quantization` config to skip complex computational code paths during quantization to accelerate model quant testing.
* 08/21/2025 [4.0.0](https://github.com/ModelCloud/GPTQModel/releases/tag/v4.0.0): 🎉 New Group Aware Reordering (GAR) support. New models support: Bytedance Seed-OSS, Baidu Ernie, Huawei PanGu, Gemma3, Xiaomi Mimo, Qwen 3/MoE, Falcon H1, GPT-Neo. Memory leak and multiple model compatibility fixes related to Transformers >= 4.54. Python >= 3.13t free-threading support added with near N x GPU linear scaling for quantization of MoE models and also linear N x Cpu Core scaling of packing stage. Early access Pytorch 2.8 fused-ops on Intel XPU for up to 50% speedup.
* 09/18/2025 5.0.0-dev `main`: 🎉 AWQ quantization support! Complete refactor and simplification of model definitions in preparation for future quantization formats.
* 08/19/2025 4.0.0-dev `main`: Fix quantization memory usage due to some model's incorrect application of `config.use_cache` during inference. Fixed `Transformers` >= 4.54.0 compat which changed layer forward return signature for some models.
* 08/18/2025 4.0.0-dev `main`: GPT-Neo model support. Memory leak fix in error capture (stacktrace) and fixed `lm_head` quantization compatibility for many models.
* 07/31/2025 4.0.0-dev `main`: New Group Aware Reordering (GAR) support and prelim Pytorch 2.8 fused-ops for Intel XPU for up to 50% speedup.
* 07/03/2025 4.0.0-dev `main`: New Baidu Ernie and Huawei PanGu model support.

<details>

<summary>Archived News</summary>

* 07/02/2025 4.0.0-dev `main`: Gemma3 4B model compat fix.
* 05/29/2025 4.0.0-dev `main`: Falcon H1 model support. Fixed Transformers `4.52+` compat with Qwen 2.5 VL models.
* 05/19/2025 4.0.0-dev `main`: Qwen 2.5 Omni model support.
Expand Down Expand Up @@ -172,12 +175,6 @@ Native support support some of the most popular multi-modal models:

<img src=https://github.com/user-attachments/assets/c1b89394-f8f6-44e5-9949-bef15a124723 width="51%"> <img src=https://github.com/user-attachments/assets/23901236-10c5-4435-ac2f-06cf2e097f1e width="47%">

## Experimental GPTQ v2 quantization: Users have reported this mode of quantization may or may not match original GPTQ v1 implementation in terms of quality recovery.

<div align=center>
<img src=https://github.com/user-attachments/assets/8e627922-0b73-4e44-b3e2-c01def5301f9 width=300>
</div>

## Model Support
| Model | | | | | | | | | |
|-------------------|---|-------------------|---|----------------|---|----------------|---|---------------------|---|
Expand Down Expand Up @@ -287,12 +284,13 @@ model.quantize(calibration_dataset, batch_size=1)
model.save(quant_path)
```

### Quantization using GPTQ V2
### Quantization using GPTQ V2* (Experimental, not MoE compatible, and results may not be better than v1)

Enable GPTQ v2 quantization by setting `v2 = True` for potentially higher post-quantization accuracy recovery.
Enable GPTQ v2 quantization by setting `v2 = True`.
```py
# note v2 is currently experimental and requires 2-4x more vram to execute
# if oom on 1 gpu, please set CUDA_VISIBLE_DEVICES=0,1 to 2 gpu and gptqmodel will auto use second gpu
# Note v2 is currently experimental, not MoE compatible, and requires 2-4x more vram to execute
# We have many reports of v2 not working better or exceeding v1 so please use for testing only
# If oom on 1 gpu, please set CUDA_VISIBLE_DEVICES=0,1 to 2 gpu and gptqmodel will auto use second gpu
quant_config = QuantizeConfig(bits=4, group_size=128, v2=True)
```
`Llama 3.1 8B-Instruct` quantized using `test/models/test_llama3_2.py`
Expand Down
77 changes: 68 additions & 9 deletions gptqmodel/looper/module_looper.py
Original file line number Diff line number Diff line change
Expand Up @@ -579,14 +579,7 @@ def _run_forward_batches_parallel(
progress_total_rows: Optional[int] = None,
) -> List[List[torch.Tensor]]:
"""Fan batches across device clones and preserve result ordering."""
module_replicas = clone_module_for_devices(module, devices)

# Ensure any async replication/memcpy ops are complete before threads start fanning out.
torch_sync()

prev_kv = shared_kv_cache_dict.get(layer_index - 1) if reuse_kv else None

results: Dict[int, torch.Tensor | tuple | None] = {}
effective_title = progress_title or (progress_stage or "Forward")

total_batches = self._resolve_batch_total(processor.num_batches, layer_inputs)
batch_row_counts = progress_rows_per_batch or self._collect_row_counts(layer_inputs)
Expand All @@ -599,8 +592,74 @@ def _run_forward_batches_parallel(
if total_rows <= 0 and total_batches > 0:
total_rows = total_batches
total_rows = max(total_rows, 1)
processed_rows = 0
stage_label = progress_stage or "Forward"

replica_pb: "ProgressBar" | None = None
replica_title = ""
replica_completed = 0

if progress_pb is not None:
progress_pb.title(effective_title)
if len(devices) > 1:
replica_title = f"{stage_label}: replicate to {len(devices)} devices"
replica_pb = (
log.pb(range(len(devices)))
.manual()
.set(show_left_steps=False)
)
replica_pb.title(replica_title).subtitle("Staging module...").draw()
else:
device_label = str(devices[0]) if devices else "<device>"
progress_pb.subtitle(f"{stage_label}: staging on {device_label}").draw()

def _replica_progress(idx: int, total: int, device: torch.device, step: str) -> None:
nonlocal replica_completed
device_label = str(device)
if replica_pb is not None:
if step == "stage":
replica_pb.title(replica_title).subtitle(f"Stage {device_label}").draw()
return
if idx > replica_completed:
replica_completed = idx
replica_pb.title(replica_title).subtitle(
f"{device_label} {idx}/{total}"
).next().draw()
else:
replica_pb.title(replica_title).subtitle(
f"{device_label} {idx}/{total}"
).draw()
elif progress_pb is not None:
stage_msg = (
f"{stage_label}: staging on {device_label}"
if step == "stage"
else f"{stage_label}: {step} {idx}/{total} on {device_label}"
)
progress_pb.title(effective_title).subtitle(stage_msg).draw()

progress_cb = _replica_progress if progress_pb is not None else None

# Ensure any async replication/memcpy ops are complete before threads start fanning out.
torch_sync()

try:
module_replicas = clone_module_for_devices(
module,
devices,
progress_callback=progress_cb,
)
finally:
if replica_pb is not None:
replica_pb.close()
if progress_pb is not None:
progress_pb.title(effective_title).subtitle(
f"{stage_label} rows 0/{total_rows}"
).draw()

prev_kv = shared_kv_cache_dict.get(layer_index - 1) if reuse_kv else None

results: Dict[int, torch.Tensor | tuple | None] = {}

processed_rows = 0
device_segments: Dict[torch.device, List[int]] = {}
segment_start = 0
num_devices = len(devices)
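
The replica progress handler in this hunk only advances the bar when a new, higher replica index is reported. A minimal sketch of that consumer-side pattern follows, assuming plain strings stand in for `torch.device` and a list of subtitles stands in for the real progress bar; `ReplicaProgress` and its attributes are hypothetical names for illustration, not part of the project.

```python
from typing import List


class ReplicaProgress:
    """Hypothetical stand-in for a progress-bar callback (not the project's class)."""

    def __init__(self, title: str) -> None:
        self.title = title
        self.completed = 0        # highest replica index seen so far
        self.steps_advanced = 0   # how many times the bar moved forward
        self.subtitles: List[str] = []

    def __call__(self, idx: int, total: int, device: str, step: str) -> None:
        if step == "stage":
            # Staging only updates the label; it never advances the bar.
            self.subtitles.append(f"Stage {device}")
            return
        self.subtitles.append(f"{device} {idx}/{total}")
        if idx > self.completed:
            # Advance once per newly completed replica; repeated or stale
            # notifications only redraw the label.
            self.completed = idx
            self.steps_advanced += 1


progress = ReplicaProgress("Forward: replicate to 2 devices")
progress(0, 2, "cuda:0", "stage")
progress(1, 2, "cuda:0", "replica")
progress(1, 2, "cuda:0", "replica")  # duplicate report, bar does not move
progress(2, 2, "cuda:1", "replica")
```

Keeping the "highest index seen" on the consumer side means the producer can notify freely without risking the bar running past its total.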
Expand Down
4 changes: 2 additions & 2 deletions gptqmodel/models/definitions/glm4_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,15 @@ class GLM4MoEGPTQ(BaseQModel):
"#",
{
"input_layernorm": ("input_layernorm:!",),
"self_attn": ("q_proj:0", "k_proj:0", "v_proj:0", "o_proj:1"),
"self_attn": ("q_proj:0", "q_norm:0:!","k_proj:0", "k_norm:0:!", "v_proj:0", "o_proj:1"),
"post_attention_layernorm": ("post_attention_layernorm:!",),
"mlp": {
"shared_experts": {
"gate_proj": ("gate_proj:0",),
"up_proj": ("up_proj:0",),
"down_proj": ("down_proj:1",),
},
"gate": ("gate:!",),
"gate": ("gate:!",), # Glm4MoeTopKRouter, ~1.6MB float32 per layer. We really do not want to quantize this.
"experts": {
"#": ("gate_proj:0", "up_proj:0", "down_proj:1"),
},
Expand Down
19 changes: 9 additions & 10 deletions gptqmodel/quantization/gptq.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# SPDX-License-Identifier: Apache-2.0
# Contact: qubitium@modelcloud.ai, x.com/qubitium

# adapted from @qwopqwop200 's [GPTQ-for-LLaMa](https://github.com/qwopqwop200/GPTQ-for-LLaMa/tree/cuda), which itself is based on [gptq](https://github.com/IST-DASLab/gptq)
# Based on original gptq algorithm and code from https://github.com/IST-DASLab/gptq

import contextlib
import math
Expand Down Expand Up @@ -559,33 +559,32 @@ def hf_quantize(

@torch.inference_mode()
def hessian_inverse(self, H: torch.Tensor):

damp = self.qcfg.damp_percent
diag = torch.arange(self.columns, device=H.device)
mean = torch.mean(torch.diag(H))

orig_diag = H.diag().clone()
while 0 < damp < 1:
try:
H2 = H.clone()
H2[diag, diag] += damp * mean
# TODO call to torch.linalg is not threadsafe? Why not? That is very bad.
H2 = torch.linalg.cholesky(H2)
H.diagonal().add_(damp * mean)
H2 = torch.linalg.cholesky(H)
Hinv = torch.linalg.cholesky(torch.cholesky_inverse(H2), upper=True)
del H, H2
H.diagonal().copy_(orig_diag)
del H2
break
except torch._C._LinAlgError as e:
H.diagonal().copy_(orig_diag)
if self.qcfg.damp_auto_increment != 0:
log.warn(
f"Quantization: Module `{self.name}` -> Current `damp_percent = {damp:.5f}` is too low, auto-incrementing by `{self.qcfg.damp_auto_increment:.5f}`")
damp += self.qcfg.damp_auto_increment
else:
log.warn(
"Quantization: Module `{self.name}` -> Please increase damp or nsamples for calibration data to avoid the following quant error: current damp_percent=`{damp_percent:.5f}`")
f"Quantization: Module `{self.name}` -> Please increase damp or nsamples for calibration data to avoid the following quant error: current damp_percent=`{damp:.5f}`")
raise e

if not (0 < damp < 1):
log.error(
f"Quantization: Module `{self.name}` -> `damp_percent` must be between 0 and 1. current is {damp}. Module cannot be correctly processed.")
# raise ValueError(f"Quantization: `damp_percent` must between 0 and 1. current is {damp}")
return None, 1.0

return Hinv, damp
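
The change above swaps a full `H.clone()` per retry for an in-place diagonal update that is undone afterwards. The pattern can be sketched without torch. This is a minimal pure-Python illustration, assuming a small symmetric matrix stored as a list of lists; `cholesky_lower`, `damped_cholesky`, and `NotPositiveDefinite` are hypothetical names, and the sketch stops at the Cholesky factor rather than computing the inverse as `hessian_inverse` does.

```python
import math
from typing import List, Optional, Tuple

Matrix = List[List[float]]


class NotPositiveDefinite(ValueError):
    """Raised when the factorization meets a non-positive pivot."""


def cholesky_lower(a: Matrix) -> Matrix:
    """Return lower-triangular L with L * L^T == a, or raise NotPositiveDefinite."""
    n = len(a)
    low = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            s = sum(low[i][k] * low[j][k] for k in range(j))
            if i == j:
                pivot = a[i][i] - s
                if pivot <= 0.0:
                    raise NotPositiveDefinite(f"pivot {pivot} at row {i}")
                low[i][j] = math.sqrt(pivot)
            else:
                low[i][j] = (a[i][j] - s) / low[j][j]
    return low


def damped_cholesky(
    h: Matrix, damp: float, auto_increment: float
) -> Tuple[Optional[Matrix], float]:
    """Factor h + damp * mean(diag(h)) * I, raising damp on failure.

    The diagonal of h is modified in place and always restored, so no
    full copy of h is made per attempt.
    """
    n = len(h)
    orig_diag = [h[i][i] for i in range(n)]
    mean = sum(orig_diag) / n
    while 0 < damp < 1:
        for i in range(n):
            h[i][i] += damp * mean
        try:
            return cholesky_lower(h), damp
        except NotPositiveDefinite:
            if auto_increment == 0:
                raise
            damp += auto_increment
        finally:
            # Runs on success, on retry, and on re-raise alike.
            for i in range(n):
                h[i][i] = orig_diag[i]
    return None, 1.0  # damp left the (0, 1) range without success


h = [[1.0, 1.3], [1.3, 1.0]]  # indefinite until enough damping is added
factor, used_damp = damped_cholesky(h, damp=0.25, auto_increment=0.25)
```

Restoring the diagonal in a `finally` clause covers every exit path in one place, where the diff restores it explicitly on both the success and the exception branch.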
Expand Down
28 changes: 25 additions & 3 deletions gptqmodel/utils/looper_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import threading
import time
from contextlib import contextmanager
from typing import Dict, List, Optional, Sequence, Tuple
from typing import Callable, Dict, List, Optional, Sequence, Tuple

import torch
from torch.nn import parallel as torch_parallel
Expand Down Expand Up @@ -225,6 +225,7 @@ def clone_module_for_devices(
devices: List[torch.device],
*,
clear_state_fn=clear_non_picklable_state,
progress_callback: Optional[Callable[[int, int, torch.device, str], None]] = None,
) -> Dict[torch.device, torch.nn.Module]:
clones: Dict[torch.device, torch.nn.Module] = {}
if not devices:
Expand All @@ -234,6 +235,21 @@ def clone_module_for_devices(
clone_timings: List[Tuple[str, float]] = []
overall_start = time.perf_counter()

total_targets = len(devices)

def _notify(idx: int, device: torch.device, step: str) -> None:
if progress_callback is None:
return
try:
progress_callback(idx, total_targets, device, step)
except Exception:
if DEBUG_ON:
log.debug(
"clone_module_for_devices: progress callback failed (device=%s, step=%s)",
device,
step,
)

def _record(name: str, start_ts: Optional[float]) -> None:
if not DEBUG_ON or start_ts is None:
return
Expand Down Expand Up @@ -283,17 +299,19 @@ def _prepare_module(target_device: torch.device, step_name: str) -> None:
if use_replicate:
try:
_prepare_module(base_device, f"stage_{base_device}")
_notify(0, base_device, "stage")

replicate_start = time.perf_counter()
replicas = torch_replicate(module, devices)
_record("replicate", replicate_start)

for dev, replica in zip(devices, replicas):
for idx, (dev, replica) in enumerate(zip(devices, replicas), start=1):
replica.eval()
rehome_module_to_device(replica, dev, move_parameters=True, move_buffers=True)
clear_state_fn(replica)
setattr(replica, "_gptqmodule_device_hint", dev)
clones[dev] = replica
_notify(idx, dev, "replica")

_emit_clone_log("replicate")
return clones
Expand All @@ -305,14 +323,17 @@ def _prepare_module(target_device: torch.device, step_name: str) -> None:

if len(devices) == 1 and devices[0].type == "cpu":
_prepare_module(CPU, "stage_cpu")
_notify(0, CPU, "stage")
clones[devices[0]] = module
_notify(1, devices[0], "reuse")
_emit_clone_log("reuse")
return clones

if not use_replicate:
_prepare_module(stage_device, f"stage_{stage_device}")
_notify(0, stage_device, "stage")

for dev in devices:
for idx, dev in enumerate(devices, start=1):
start_ts = time.perf_counter()
with _DEEPCOPY_LOCK:
replica = copy.deepcopy(module)
Expand All @@ -322,6 +343,7 @@ def _prepare_module(target_device: torch.device, step_name: str) -> None:
setattr(replica, "_gptqmodule_device_hint", dev)
clones[dev] = replica
_record(str(dev), start_ts)
_notify(idx, dev, "clone")

_emit_clone_log("deepcopy")
return clones
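
The `_notify` wrapper added in this file makes progress reporting best-effort: a failing callback must not abort cloning. Below is a minimal sketch of that producer-side pattern, assuming a plain dict in place of a `torch.nn.Module` and strings in place of `torch.device`; `clone_for_targets` and `flaky_callback` are hypothetical names, and the sketch silently drops callback errors where the real code logs them in debug mode.

```python
import copy
from typing import Callable, Dict, List, Optional, Tuple

# (index, total, target, step) mirrors the callback shape used in the diff.
ProgressCallback = Callable[[int, int, str, str], None]


def clone_for_targets(
    payload: dict,
    targets: List[str],
    progress_callback: Optional[ProgressCallback] = None,
) -> Dict[str, dict]:
    """Deep-copy payload once per target, reporting progress if asked to."""
    total = len(targets)

    def notify(idx: int, target: str, step: str) -> None:
        if progress_callback is None:
            return
        try:
            progress_callback(idx, total, target, step)
        except Exception:
            # Reporting is best-effort: never let a UI error stop the cloning.
            pass

    clones: Dict[str, dict] = {}
    if not targets:
        return clones

    notify(0, targets[0], "stage")  # index 0 marks the staging step
    for idx, target in enumerate(targets, start=1):
        clones[target] = copy.deepcopy(payload)
        notify(idx, target, "clone")
    return clones


events: List[Tuple[int, int, str, str]] = []


def flaky_callback(idx: int, total: int, target: str, step: str) -> None:
    events.append((idx, total, target, step))
    if idx == 1:
        raise RuntimeError("simulated progress-bar failure")


source = {"weights": [1, 2, 3]}
result = clone_for_targets(source, ["cuda:0", "cuda:1"], flaky_callback)
```

Even though `flaky_callback` raises on the first clone, both targets still receive an independent copy.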
Expand Down