36 changes: 31 additions & 5 deletions gptqmodel/looper/module_looper.py
@@ -588,7 +588,6 @@ def _run_forward_batches_parallel(

results: Dict[int, torch.Tensor | tuple | None] = {}

chunk = len(devices)
total_batches = self._resolve_batch_total(processor.num_batches, layer_inputs)
batch_row_counts = progress_rows_per_batch or self._collect_row_counts(layer_inputs)
batch_row_counts = list(batch_row_counts)
@@ -602,11 +601,38 @@ def _run_forward_batches_parallel(
total_rows = max(total_rows, 1)
processed_rows = 0
stage_label = progress_stage or "Forward"
for start in range(0, total_batches, chunk):
device_segments: Dict[torch.device, List[int]] = {}
segment_start = 0
num_devices = len(devices)

for index, device in enumerate(devices):
remaining_batches = max(total_batches - segment_start, 0)
remaining_devices = max(num_devices - index, 1)
segment_length = remaining_batches // remaining_devices
remainder = remaining_batches % remaining_devices
if remainder > 0:
segment_length += 1

if segment_length <= 0:
device_segments[device] = []
continue

segment_end = min(segment_start + segment_length, total_batches)
device_segments[device] = list(range(segment_start, segment_end))
segment_start = segment_end

max_segment_length = 0
for indices in device_segments.values():
if len(indices) > max_segment_length:
max_segment_length = len(indices)

for position in range(max_segment_length):
futures = []
end = min(start + chunk, total_batches)
for offset, batch_idx in enumerate(range(start, end)):
device = devices[offset]
for device in devices:
segment_indices = device_segments.get(device, [])
if position >= len(segment_indices):
continue
batch_idx = segment_indices[position]
replica = module_replicas[device]
submitter = (
DEVICE_THREAD_POOL.submit_serial
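Side note on the scheduling change above: rather than walking batches in `len(devices)`-sized strides, each device now receives one contiguous slice of the batch index range, with earlier devices absorbing any remainder. A minimal standalone sketch of that partitioning (the `split_batches` helper is hypothetical; only the arithmetic mirrors the diff):

```python
# Hypothetical helper illustrating the contiguous partitioning used in
# _run_forward_batches_parallel; not code from this PR.
from typing import Dict, List

import torch


def split_batches(total_batches: int, devices: List[torch.device]) -> Dict[torch.device, List[int]]:
    segments: Dict[torch.device, List[int]] = {}
    start = 0
    for index, device in enumerate(devices):
        remaining_batches = max(total_batches - start, 0)
        remaining_devices = max(len(devices) - index, 1)
        length = remaining_batches // remaining_devices
        if remaining_batches % remaining_devices:
            length += 1  # earlier devices take one extra batch when the split is uneven
        end = min(start + length, total_batches)
        segments[device] = list(range(start, end))
        start = end
    return segments


# 10 batches over 3 devices -> [0-3], [4-6], [7-9]; position i of every slice
# can then be dispatched in the same wave across devices.
devices = [torch.device(f"cuda:{i}") for i in range(3)]
print(split_batches(10, devices))
```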
174 changes: 89 additions & 85 deletions gptqmodel/quantization/gptq.py
@@ -6,7 +6,6 @@
# adapted from @qwopqwop200 's [GPTQ-for-LLaMa](https://github.com/qwopqwop200/GPTQ-for-LLaMa/tree/cuda), which itself is based on [gptq](https://github.com/IST-DASLab/gptq)

import contextlib
import heapq
import math
import os
import sys
@@ -169,11 +168,9 @@ def __init__(self, module: nn.Module, qcfg: Optional[QuantizeConfig] = None):
setattr(self.module, "target_device", module_device)

if module_device.type == "meta":
self._default_hessian_device = torch.device("cpu")
self._final_hessian_device_hint = torch.device("cpu")
else:
self._default_hessian_device = torch.device(module_device)

self._hessian_device: Optional[torch.device] = None
self._final_hessian_device_hint = torch.device(module_device)

self._validate_module(self.module)

@@ -191,13 +188,13 @@ def __init__(self, module: nn.Module, qcfg: Optional[QuantizeConfig] = None):

self.fail_safe = False

self.H = torch.zeros((self.columns, self.columns),
dtype=torch.float32)
self.H: Optional[torch.Tensor] = None

# Track per-batch Hessian contributions so they can be applied in a
# deterministic order even when forwards execute in parallel.
self._pending_updates: List[Tuple[int, int, Optional[torch.Tensor], Optional[torch.device]]] = []
self._next_batch_index: int = 0
# Store per-device Hessian contributions so multi-GPU calibration can
# keep local accumulators and merge only once when quantization begins.
self._device_hessian_partials: Dict[torch.device, torch.Tensor] = {}
self._device_sample_counts: Dict[torch.device, int] = {}
self._hessian_dirty: bool = False

@staticmethod
def _validate_module(module):
@@ -257,14 +254,25 @@ def _truncate_last_dim(tensor: torch.Tensor, length: int) -> torch.Tensor:
return tensor.narrow(tensor.dim() - 1, 0, trim).contiguous()

def add_batch(self, inp: torch.Tensor, out: torch.Tensor, batch_index: Optional[int] = None):
batch_token_size, xtx, device = self.process_batch(inp)
if batch_token_size == 0 or xtx is None:
return

dev = torch.device(device)

with self.lock:
self.fwd_counter += 1

batch_token_size, xtx, device = self.process_batch(inp)
existing = self._device_hessian_partials.get(dev)
if existing is None:
self._device_hessian_partials[dev] = xtx
else:
existing.add_(xtx)
del xtx

pending_index = batch_index if batch_index is not None else self._next_batch_index
heapq.heappush(self._pending_updates, (pending_index, batch_token_size, xtx, device))
self._flush_pending_updates_locked()
self._device_sample_counts[dev] = self._device_sample_counts.get(dev, 0) + batch_token_size
self.nsamples += batch_token_size
self._hessian_dirty = True

def _preferred_staging_dtype(self, input_dtype: torch.dtype, device: torch.device) -> torch.dtype:
device = torch.device(device)
@@ -355,39 +363,6 @@ def _compute_hessian_xtx(self, matrix: torch.Tensor) -> torch.Tensor:

return xtx_accum

def _resolve_hessian_device(self, batch_device: torch.device) -> torch.device:
"""Select a stable device for Hessian accumulation.

The first non-meta device we observe (module target, default hint, or
batch input) becomes the canonical Hessian device for the lifetime of
this GPTQ instance. Subsequent batches keep using the same target to
avoid bouncing tensors across GPUs when calibration runs on multiple
devices concurrently.
"""

if self._hessian_device is not None:
return self._hessian_device

module_target = getattr(self.module, "target_device", None)
canonical = None

if module_target is not None:
canonical = torch.device(module_target)
if canonical.type == "meta":
canonical = None

if canonical is None and hasattr(self, "_default_hessian_device"):
canonical = self._default_hessian_device

if canonical is None or canonical.type == "meta":
canonical = batch_device

if canonical.type == "meta":
canonical = torch.device("cpu")

self._hessian_device = canonical
return canonical

def process_batch(self, inp: torch.Tensor) -> Tuple[int, Optional[torch.Tensor], torch.device]:
# print(f"inp = {inp}")
# print(f"self.module = {self.module} device = {self.module.target_device}")
@@ -436,7 +411,7 @@ def process_batch(self, inp: torch.Tensor) -> Tuple[int, Optional[torch.Tensor],
if self._tp_pad_cols:
pad = reshaped_inp.new_zeros((reshaped_inp.shape[0], self._tp_pad_cols))
reshaped_inp = torch.cat((reshaped_inp, pad), dim=1)
canonical_device = self._resolve_hessian_device(inp_device)
canonical_device = torch.device(inp_device)

batch_token_size = reshaped_inp.shape[0]

@@ -460,7 +435,6 @@ def process_batch(self, inp: torch.Tensor) -> Tuple[int, Optional[torch.Tensor],
if torch.cuda.is_available():
torch.cuda.empty_cache()
canonical_device = torch.device("cpu")
self._hessian_device = canonical_device
xtx = self._compute_hessian_xtx(reshaped_inp_cpu).to(dtype=torch.float32)
xtx = xtx.detach()
del reshaped_inp_cpu
@@ -473,45 +447,79 @@ def process_batch(self, inp: torch.Tensor) -> Tuple[int, Optional[torch.Tensor],

return batch_token_size, xtx, canonical_device

def _flush_pending_updates_locked(self, *, allow_gaps: bool = False) -> None:
expected = self._next_batch_index

while self._pending_updates:
index, batch_token_size, xtx, device = self._pending_updates[0]
def _select_hessian_target_device(self, requested: Optional[torch.device]) -> torch.device:
if requested is not None:
return torch.device(requested)

if index < expected:
heapq.heappop(self._pending_updates)
continue
hint = getattr(self, "_final_hessian_device_hint", None)
if hint is not None:
return torch.device(hint)

if not allow_gaps and index != expected:
break
if self._device_hessian_partials:
partial_device = next(iter(self._device_hessian_partials.keys()))
return torch.device(partial_device)

heapq.heappop(self._pending_updates)
return torch.device("cpu")

if allow_gaps and index > expected:
expected = index
def _materialize_global_hessian(self, target_device: Optional[torch.device] = None) -> None:
device = self._select_hessian_target_device(target_device)

if batch_token_size > 0 and xtx is not None:
target_device = device if device is not None else self.H.device
if target_device is None:
target_device = self.H.device
with self.lock:
if not self._hessian_dirty and self.H is not None:
if self.H.device != device:
self.H = self.H.to(device=device)
return

total_samples = sum(self._device_sample_counts.values())

reuse_buffer = (
self.H is not None
and self.H.shape == (self.columns, self.columns)
and self.H.dtype == torch.float32
and self.H.device == device
)

self.H = self.H.to(device=target_device)
if xtx.device != target_device:
xtx = xtx.to(device=target_device)
if reuse_buffer:
result = self.H
result.zero_()
else:
result = torch.zeros(
(self.columns, self.columns),
dtype=torch.float32,
device=device,
)

total = self.nsamples + batch_token_size
beta = self.nsamples / total
alpha = 2.0 / total
self.H.mul_(beta)
self.H.add_(xtx, alpha=alpha)
self.nsamples = total
if total_samples == 0:
self.H = result
self.nsamples = 0
self._hessian_dirty = False
self._final_hessian_device_hint = device
self._device_hessian_partials.clear()
self._device_sample_counts.clear()
return

for partial_device, partial in self._device_hessian_partials.items():
if partial.device != result.device:
tmp = partial.to(result.device)
result.add_(tmp)
del tmp
else:
result.add_(partial)

del xtx
result.mul_(2.0 / float(total_samples))

expected = index + 1
self.H = result
self.nsamples = total_samples
self._hessian_dirty = False
self._final_hessian_device_hint = result.device
self._device_hessian_partials.clear()
self._device_sample_counts.clear()

self._next_batch_index = expected
def finalize_hessian(self, target_device: Optional[torch.device] = None) -> torch.Tensor:
self._materialize_global_hessian(target_device=target_device)
if self.H is None:
self.H = torch.zeros((self.columns, self.columns), dtype=torch.float32, device=self._select_hessian_target_device(target_device))
return self.H

# FIXME, optimum needs fasterquant, we need to remove it
def fasterquant(
@@ -590,12 +598,8 @@ def quantize(
# log.info(f"Quantization `{self.name}` using samples: `{self.nsamples}`")
start = time.time()

with self.lock:
self._flush_pending_updates_locked(allow_gaps=True)
if self._pending_updates:
raise RuntimeError(
f"Pending Hessian updates remain for module '{self.name}' before quantization."
)
target_device = getattr(self.module, "target_device", None)
self.finalize_hessian(target_device=target_device)

# Temporarily disable torch.compile due to compatibility issues with torch 2.8
# Will re-enable once the issue is fixed
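Side note on the Hessian bookkeeping above: the per-device partials accumulate raw `X^T X` sums plus token counts, and `finalize_hessian` applies the `2 / nsamples` scaling once, which reproduces the matrix the old streaming `H.mul_(beta); H.add_(xtx, alpha=alpha)` update built batch by batch. A small self-contained sketch of that equivalence (illustrative helper names, not code from this PR):

```python
# Equivalence check between the old streaming Hessian update and the new
# deferred merge; toy tensors, illustrative function names.
import torch


def streaming_hessian(batches):
    # Old behaviour: rescale the running Hessian on every batch.
    H, nsamples = None, 0
    for xtx, tokens in batches:
        if H is None:
            H = torch.zeros_like(xtx)
        total = nsamples + tokens
        H.mul_(nsamples / total)         # beta = nsamples / total
        H.add_(xtx, alpha=2.0 / total)   # alpha = 2.0 / total
        nsamples = total
    return H


def deferred_hessian(batches):
    # New behaviour: sum raw partials, scale once at finalize time.
    partial, nsamples = None, 0
    for xtx, tokens in batches:
        partial = xtx.clone() if partial is None else partial.add_(xtx)
        nsamples += tokens
    return partial.mul_(2.0 / float(nsamples))


batches = []
for _ in range(3):
    x = torch.randn(5, 8)                # 5 calibration rows, 8 columns
    batches.append((x.T @ x, x.shape[0]))

torch.testing.assert_close(streaming_hessian(batches), deferred_hessian(batches))
```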
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -46,7 +46,7 @@ dependencies = [
"huggingface_hub>=0.34.4",
"random_word>=1.0.13",
"tokenicer>=0.0.5",
"logbar>=0.1.2",
"logbar>=0.1.3",
"maturin>=1.9.4", # required by safetensors and hf_transfer
"datasets>=3.6.0",
"pyarrow>=21.0",
2 changes: 1 addition & 1 deletion requirements.txt
@@ -13,7 +13,7 @@ hf_transfer>=0.1.9
huggingface_hub>=0.34.4
random_word>=1.0.13
tokenicer>=0.0.5
logbar>=0.1.2
logbar>=0.1.3
maturin>=1.9.4
datasets>=3.6.0
pyarrow>=21.0
2 changes: 1 addition & 1 deletion tests/models/test_qwen3_moe.py
@@ -20,7 +20,7 @@ class TestQwen3Moe(ModelTest):
DESC_ACT = False
DATASET_SIZE = 1024
DATASET_SORT = "desc"
QUANT_BATCH_SIZE = 1
QUANT_BATCH_SIZE = 4
CALIB_NOISE_MODE = "unseen"
CALIB_NOISE_PERCENT = 0.025

44 changes: 40 additions & 4 deletions tests/test_gptq_queue.py
@@ -1,13 +1,14 @@
import copy

import torch
import pytest

from gptqmodel.quantization.config import QuantizeConfig
from gptqmodel.quantization.gptq import GPTQ


@torch.no_grad()
def test_out_of_order_batches_flush_in_sequence():
def test_out_of_order_batches_finalize_matches_reference():
torch.manual_seed(0)

module = torch.nn.Linear(4, 4)
@@ -23,14 +24,49 @@ def test_out_of_order_batches_flush_in_sequence():
y0 = module(x0)
y1 = module(x1)

# Add batches out of order to ensure accumulation is order agnostic.
gptq.add_batch(x1, y1, batch_index=1)
assert gptq.nsamples == 0

gptq.add_batch(x0, y0, batch_index=0)

gptq.finalize_hessian()

reference.add_batch(x0, y0, batch_index=0)
reference.add_batch(x1, y1, batch_index=1)
reference.finalize_hessian()

assert gptq.H is not None
torch.testing.assert_close(gptq.H, reference.H)
assert gptq.nsamples == reference.nsamples
assert not gptq._pending_updates
assert not gptq._device_hessian_partials


@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")
def test_finalize_hessian_preserves_device(monkeypatch):
module = torch.nn.Linear(4, 4).cuda()
cfg = QuantizeConfig()
gptq = GPTQ(module, cfg)

module_device = module.weight.device

def fake_process_batch(self, inp):
xtx = torch.eye(self.columns, dtype=torch.float32, device=module_device)
return 1, xtx.clone(), module_device

monkeypatch.setattr(GPTQ, "process_batch", fake_process_batch, raising=False)

inp = torch.zeros(1, device=module_device)

gptq.add_batch(inp, inp, batch_index=1)
gptq.add_batch(inp, inp, batch_index=0)

# No Hessian materialized until finalize is invoked.
assert gptq.H is None
assert module_device in gptq._device_hessian_partials

gptq.finalize_hessian()

assert gptq.H is not None
assert gptq.H.device == module_device
assert not gptq._device_hessian_partials

torch.cuda.synchronize()
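Tying the pieces together, a rough sketch of the calibration flow the new tests exercise (assumed shapes and batch count; batches may arrive in any order, and the merged Hessian only exists once `finalize_hessian()` or `quantize()` runs):

```python
# Simplified calibration flow, assuming a plain CPU Linear module as in the
# first test above; shapes and the 3-batch loop are illustrative only.
import torch

from gptqmodel.quantization.config import QuantizeConfig
from gptqmodel.quantization.gptq import GPTQ

module = torch.nn.Linear(4, 4)
gptq = GPTQ(module, QuantizeConfig())

with torch.no_grad():
    for batch_index in range(3):
        x = torch.randn(2, 4)
        gptq.add_batch(x, module(x), batch_index=batch_index)  # only updates a per-device partial

assert gptq.H is None                  # nothing merged yet
H = gptq.finalize_hessian()            # merge the partials once, on the hinted device
print(H.shape, gptq.nsamples)          # columns x columns matrix, total calibration rows
```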