From 0bbb7945309c5272d9334813dce5a63c89e7d0eb Mon Sep 17 00:00:00 2001 From: Qubitium Date: Thu, 16 Oct 2025 05:54:04 +0000 Subject: [PATCH 1/5] move to cpu if order is not expected --- gptqmodel/quantization/gptq.py | 66 ++++++++++++++++++++++++++++------ 1 file changed, 56 insertions(+), 10 deletions(-) diff --git a/gptqmodel/quantization/gptq.py b/gptqmodel/quantization/gptq.py index 0051084a0..6702fb1d2 100644 --- a/gptqmodel/quantization/gptq.py +++ b/gptqmodel/quantization/gptq.py @@ -174,6 +174,7 @@ def __init__(self, module: nn.Module, qcfg: Optional[QuantizeConfig] = None): self._default_hessian_device = torch.device(module_device) self._hessian_device: Optional[torch.device] = None + self._hessian_streams: Dict[Tuple[str, Optional[int]], torch.cuda.Stream] = {} self._validate_module(self.module) @@ -262,7 +263,21 @@ def add_batch(self, inp: torch.Tensor, out: torch.Tensor, batch_index: Optional[ batch_token_size, xtx, device = self.process_batch(inp) - pending_index = batch_index if batch_index is not None else self._next_batch_index + expected_index = self._next_batch_index + pending_index = batch_index if batch_index is not None else expected_index + + if ( + xtx is not None + and hasattr(xtx, "device") + and xtx.device.type != "cpu" + and pending_index != expected_index + ): + xtx = xtx.to(device="cpu") + device = torch.device("cpu") + + if xtx is not None and hasattr(xtx, "device") and xtx.device.type == "cpu": + device = torch.device("cpu") + heapq.heappush(self._pending_updates, (pending_index, batch_token_size, xtx, device)) self._flush_pending_updates_locked() @@ -332,6 +347,17 @@ def _borrow_materialized_chunk_fp32( if device.type == "cuda": torch.cuda.current_stream(device).synchronize() + def _get_hessian_stream(self, device: torch.device) -> Optional[torch.cuda.Stream]: + dev = torch.device(device) + if dev.type != "cuda" or not torch.cuda.is_available(): + return None + + key = _device_cache_key(dev) + stream = self._hessian_streams.get(key) + if stream is None: + stream = torch.cuda.Stream(device=dev.index) + self._hessian_streams[key] = stream + return stream def _compute_hessian_xtx(self, matrix: torch.Tensor) -> torch.Tensor: rows = matrix.shape[0] if rows == 0: @@ -496,16 +522,36 @@ def _flush_pending_updates_locked(self, *, allow_gaps: bool = False) -> None: if target_device is None: target_device = self.H.device - self.H = self.H.to(device=target_device) - if xtx.device != target_device: - xtx = xtx.to(device=target_device) + torch_device = torch.device(target_device) + + stream = self._get_hessian_stream(torch_device) + + if stream is not None: + with torch.cuda.stream(stream): + self.H = self.H.to(device=torch_device, non_blocking=True) + if xtx.device != torch_device: + xtx = xtx.to(device=torch_device, non_blocking=True) - total = self.nsamples + batch_token_size - beta = self.nsamples / total - alpha = 2.0 / total - self.H.mul_(beta) - self.H.add_(xtx, alpha=alpha) - self.nsamples = total + total = self.nsamples + batch_token_size + beta = self.nsamples / total + alpha = 2.0 / total + self.H.mul_(beta) + self.H.add_(xtx, alpha=alpha) + self.nsamples = total + + self.H.record_stream(stream) + xtx.record_stream(stream) + else: + self.H = self.H.to(device=torch_device) + if xtx.device != torch_device: + xtx = xtx.to(device=torch_device) + + total = self.nsamples + batch_token_size + beta = self.nsamples / total + alpha = 2.0 / total + self.H.mul_(beta) + self.H.add_(xtx, alpha=alpha) + self.nsamples = total del xtx From 
5bedb64462c5fdd666f7feaa8e52c373d30e2f0c Mon Sep 17 00:00:00 2001 From: Qubitium Date: Thu, 16 Oct 2025 06:16:11 +0000 Subject: [PATCH 2/5] Revert "move to cpu if order is not expected" This reverts commit 0bbb7945309c5272d9334813dce5a63c89e7d0eb. --- gptqmodel/quantization/gptq.py | 66 ++++++---------------------------- 1 file changed, 10 insertions(+), 56 deletions(-) diff --git a/gptqmodel/quantization/gptq.py b/gptqmodel/quantization/gptq.py index 6702fb1d2..0051084a0 100644 --- a/gptqmodel/quantization/gptq.py +++ b/gptqmodel/quantization/gptq.py @@ -174,7 +174,6 @@ def __init__(self, module: nn.Module, qcfg: Optional[QuantizeConfig] = None): self._default_hessian_device = torch.device(module_device) self._hessian_device: Optional[torch.device] = None - self._hessian_streams: Dict[Tuple[str, Optional[int]], torch.cuda.Stream] = {} self._validate_module(self.module) @@ -263,21 +262,7 @@ def add_batch(self, inp: torch.Tensor, out: torch.Tensor, batch_index: Optional[ batch_token_size, xtx, device = self.process_batch(inp) - expected_index = self._next_batch_index - pending_index = batch_index if batch_index is not None else expected_index - - if ( - xtx is not None - and hasattr(xtx, "device") - and xtx.device.type != "cpu" - and pending_index != expected_index - ): - xtx = xtx.to(device="cpu") - device = torch.device("cpu") - - if xtx is not None and hasattr(xtx, "device") and xtx.device.type == "cpu": - device = torch.device("cpu") - + pending_index = batch_index if batch_index is not None else self._next_batch_index heapq.heappush(self._pending_updates, (pending_index, batch_token_size, xtx, device)) self._flush_pending_updates_locked() @@ -347,17 +332,6 @@ def _borrow_materialized_chunk_fp32( if device.type == "cuda": torch.cuda.current_stream(device).synchronize() - def _get_hessian_stream(self, device: torch.device) -> Optional[torch.cuda.Stream]: - dev = torch.device(device) - if dev.type != "cuda" or not torch.cuda.is_available(): - return None - - key = _device_cache_key(dev) - stream = self._hessian_streams.get(key) - if stream is None: - stream = torch.cuda.Stream(device=dev.index) - self._hessian_streams[key] = stream - return stream def _compute_hessian_xtx(self, matrix: torch.Tensor) -> torch.Tensor: rows = matrix.shape[0] if rows == 0: @@ -522,36 +496,16 @@ def _flush_pending_updates_locked(self, *, allow_gaps: bool = False) -> None: if target_device is None: target_device = self.H.device - torch_device = torch.device(target_device) - - stream = self._get_hessian_stream(torch_device) - - if stream is not None: - with torch.cuda.stream(stream): - self.H = self.H.to(device=torch_device, non_blocking=True) - if xtx.device != torch_device: - xtx = xtx.to(device=torch_device, non_blocking=True) + self.H = self.H.to(device=target_device) + if xtx.device != target_device: + xtx = xtx.to(device=target_device) - total = self.nsamples + batch_token_size - beta = self.nsamples / total - alpha = 2.0 / total - self.H.mul_(beta) - self.H.add_(xtx, alpha=alpha) - self.nsamples = total - - self.H.record_stream(stream) - xtx.record_stream(stream) - else: - self.H = self.H.to(device=torch_device) - if xtx.device != torch_device: - xtx = xtx.to(device=torch_device) - - total = self.nsamples + batch_token_size - beta = self.nsamples / total - alpha = 2.0 / total - self.H.mul_(beta) - self.H.add_(xtx, alpha=alpha) - self.nsamples = total + total = self.nsamples + batch_token_size + beta = self.nsamples / total + alpha = 2.0 / total + self.H.mul_(beta) + self.H.add_(xtx, alpha=alpha) + 
self.nsamples = total del xtx From a582c992e2ea603f673d9c56b4bdff45b889f334 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Thu, 16 Oct 2025 06:36:48 +0000 Subject: [PATCH 3/5] fix memory usage --- gptqmodel/looper/module_looper.py | 36 ++++++- gptqmodel/quantization/gptq.py | 162 ++++++++++++++---------------- pyproject.toml | 2 +- requirements.txt | 2 +- tests/test_gptq_queue.py | 44 +++++++- tests/test_hessian_chunk.py | 12 ++- tests/test_hessian_merge.py | 56 +++++++++++ 7 files changed, 211 insertions(+), 103 deletions(-) create mode 100644 tests/test_hessian_merge.py diff --git a/gptqmodel/looper/module_looper.py b/gptqmodel/looper/module_looper.py index 04a3a88b9..88fda7dfc 100644 --- a/gptqmodel/looper/module_looper.py +++ b/gptqmodel/looper/module_looper.py @@ -588,7 +588,6 @@ def _run_forward_batches_parallel( results: Dict[int, torch.Tensor | tuple | None] = {} - chunk = len(devices) total_batches = self._resolve_batch_total(processor.num_batches, layer_inputs) batch_row_counts = progress_rows_per_batch or self._collect_row_counts(layer_inputs) batch_row_counts = list(batch_row_counts) @@ -602,11 +601,38 @@ def _run_forward_batches_parallel( total_rows = max(total_rows, 1) processed_rows = 0 stage_label = progress_stage or "Forward" - for start in range(0, total_batches, chunk): + device_segments: Dict[torch.device, List[int]] = {} + segment_start = 0 + num_devices = len(devices) + + for index, device in enumerate(devices): + remaining_batches = max(total_batches - segment_start, 0) + remaining_devices = max(num_devices - index, 1) + segment_length = remaining_batches // remaining_devices + remainder = remaining_batches % remaining_devices + if remainder > 0: + segment_length += 1 + + if segment_length <= 0: + device_segments[device] = [] + continue + + segment_end = min(segment_start + segment_length, total_batches) + device_segments[device] = list(range(segment_start, segment_end)) + segment_start = segment_end + + max_segment_length = 0 + for indices in device_segments.values(): + if len(indices) > max_segment_length: + max_segment_length = len(indices) + + for position in range(max_segment_length): futures = [] - end = min(start + chunk, total_batches) - for offset, batch_idx in enumerate(range(start, end)): - device = devices[offset] + for device in devices: + segment_indices = device_segments.get(device, []) + if position >= len(segment_indices): + continue + batch_idx = segment_indices[position] replica = module_replicas[device] submitter = ( DEVICE_THREAD_POOL.submit_serial diff --git a/gptqmodel/quantization/gptq.py b/gptqmodel/quantization/gptq.py index 0051084a0..18428e6bd 100644 --- a/gptqmodel/quantization/gptq.py +++ b/gptqmodel/quantization/gptq.py @@ -6,7 +6,6 @@ # adapted from @qwopqwop200 's [GPTQ-for-LLaMa](https://github.com/qwopqwop200/GPTQ-for-LLaMa/tree/cuda), which itself is based on [gptq](https://github.com/IST-DASLab/gptq) import contextlib -import heapq import math import os import sys @@ -169,11 +168,9 @@ def __init__(self, module: nn.Module, qcfg: Optional[QuantizeConfig] = None): setattr(self.module, "target_device", module_device) if module_device.type == "meta": - self._default_hessian_device = torch.device("cpu") + self._final_hessian_device_hint = torch.device("cpu") else: - self._default_hessian_device = torch.device(module_device) - - self._hessian_device: Optional[torch.device] = None + self._final_hessian_device_hint = torch.device(module_device) self._validate_module(self.module) @@ -191,13 +188,13 @@ def __init__(self, module: nn.Module, 
qcfg: Optional[QuantizeConfig] = None): self.fail_safe = False - self.H = torch.zeros((self.columns, self.columns), - dtype=torch.float32) + self.H: Optional[torch.Tensor] = None - # Track per-batch Hessian contributions so they can be applied in a - # deterministic order even when forwards execute in parallel. - self._pending_updates: List[Tuple[int, int, Optional[torch.Tensor], Optional[torch.device]]] = [] - self._next_batch_index: int = 0 + # Store per-device Hessian contributions so multi-GPU calibration can + # keep local accumulators and merge only once when quantization begins. + self._device_hessian_partials: Dict[torch.device, torch.Tensor] = {} + self._device_sample_counts: Dict[torch.device, int] = {} + self._hessian_dirty: bool = False @staticmethod def _validate_module(module): @@ -257,14 +254,25 @@ def _truncate_last_dim(tensor: torch.Tensor, length: int) -> torch.Tensor: return tensor.narrow(tensor.dim() - 1, 0, trim).contiguous() def add_batch(self, inp: torch.Tensor, out: torch.Tensor, batch_index: Optional[int] = None): + batch_token_size, xtx, device = self.process_batch(inp) + if batch_token_size == 0 or xtx is None: + return + + dev = torch.device(device) + with self.lock: self.fwd_counter += 1 - batch_token_size, xtx, device = self.process_batch(inp) + existing = self._device_hessian_partials.get(dev) + if existing is None: + self._device_hessian_partials[dev] = xtx + else: + existing.add_(xtx) + del xtx - pending_index = batch_index if batch_index is not None else self._next_batch_index - heapq.heappush(self._pending_updates, (pending_index, batch_token_size, xtx, device)) - self._flush_pending_updates_locked() + self._device_sample_counts[dev] = self._device_sample_counts.get(dev, 0) + batch_token_size + self.nsamples += batch_token_size + self._hessian_dirty = True def _preferred_staging_dtype(self, input_dtype: torch.dtype, device: torch.device) -> torch.dtype: device = torch.device(device) @@ -355,39 +363,6 @@ def _compute_hessian_xtx(self, matrix: torch.Tensor) -> torch.Tensor: return xtx_accum - def _resolve_hessian_device(self, batch_device: torch.device) -> torch.device: - """Select a stable device for Hessian accumulation. - - The first non-meta device we observe (module target, default hint, or - batch input) becomes the canonical Hessian device for the lifetime of - this GPTQ instance. Subsequent batches keep using the same target to - avoid bouncing tensors across GPUs when calibration runs on multiple - devices concurrently. 
- """ - - if self._hessian_device is not None: - return self._hessian_device - - module_target = getattr(self.module, "target_device", None) - canonical = None - - if module_target is not None: - canonical = torch.device(module_target) - if canonical.type == "meta": - canonical = None - - if canonical is None and hasattr(self, "_default_hessian_device"): - canonical = self._default_hessian_device - - if canonical is None or canonical.type == "meta": - canonical = batch_device - - if canonical.type == "meta": - canonical = torch.device("cpu") - - self._hessian_device = canonical - return canonical - def process_batch(self, inp: torch.Tensor) -> Tuple[int, Optional[torch.Tensor], torch.device]: # print(f"inp = {inp}") # print(f"self.module = {self.module} device = {self.module.target_device}") @@ -436,7 +411,7 @@ def process_batch(self, inp: torch.Tensor) -> Tuple[int, Optional[torch.Tensor], if self._tp_pad_cols: pad = reshaped_inp.new_zeros((reshaped_inp.shape[0], self._tp_pad_cols)) reshaped_inp = torch.cat((reshaped_inp, pad), dim=1) - canonical_device = self._resolve_hessian_device(inp_device) + canonical_device = torch.device(inp_device) batch_token_size = reshaped_inp.shape[0] @@ -460,7 +435,6 @@ def process_batch(self, inp: torch.Tensor) -> Tuple[int, Optional[torch.Tensor], if torch.cuda.is_available(): torch.cuda.empty_cache() canonical_device = torch.device("cpu") - self._hessian_device = canonical_device xtx = self._compute_hessian_xtx(reshaped_inp_cpu).to(dtype=torch.float32) xtx = xtx.detach() del reshaped_inp_cpu @@ -473,45 +447,63 @@ def process_batch(self, inp: torch.Tensor) -> Tuple[int, Optional[torch.Tensor], return batch_token_size, xtx, canonical_device - def _flush_pending_updates_locked(self, *, allow_gaps: bool = False) -> None: - expected = self._next_batch_index - - while self._pending_updates: - index, batch_token_size, xtx, device = self._pending_updates[0] - - if index < expected: - heapq.heappop(self._pending_updates) - continue - - if not allow_gaps and index != expected: - break + def _select_hessian_target_device(self, requested: Optional[torch.device]) -> torch.device: + if requested is not None: + return torch.device(requested) - heapq.heappop(self._pending_updates) + hint = getattr(self, "_final_hessian_device_hint", None) + if hint is not None: + return torch.device(hint) - if allow_gaps and index > expected: - expected = index + if self._device_hessian_partials: + partial_device = next(iter(self._device_hessian_partials.keys())) + return torch.device(partial_device) - if batch_token_size > 0 and xtx is not None: - target_device = device if device is not None else self.H.device - if target_device is None: - target_device = self.H.device + return torch.device("cpu") - self.H = self.H.to(device=target_device) - if xtx.device != target_device: - xtx = xtx.to(device=target_device) + def _materialize_global_hessian(self, target_device: Optional[torch.device] = None) -> None: + device = self._select_hessian_target_device(target_device) - total = self.nsamples + batch_token_size - beta = self.nsamples / total - alpha = 2.0 / total - self.H.mul_(beta) - self.H.add_(xtx, alpha=alpha) - self.nsamples = total + with self.lock: + if not self._hessian_dirty and self.H is not None: + if self.H.device != device: + self.H = self.H.to(device=device) + return + + total_samples = sum(self._device_sample_counts.values()) + result = torch.zeros((self.columns, self.columns), dtype=torch.float32, device=device) + + if total_samples == 0: + self.H = result + self.nsamples = 0 + 
self._hessian_dirty = False + self._final_hessian_device_hint = device + self._device_hessian_partials.clear() + self._device_sample_counts.clear() + return + + for partial_device, partial in self._device_hessian_partials.items(): + if partial.device != result.device: + tmp = partial.to(result.device) + result.add_(tmp) + del tmp + else: + result.add_(partial) - del xtx + result.mul_(2.0 / float(total_samples)) - expected = index + 1 + self.H = result + self.nsamples = total_samples + self._hessian_dirty = False + self._final_hessian_device_hint = result.device + self._device_hessian_partials.clear() + self._device_sample_counts.clear() - self._next_batch_index = expected + def finalize_hessian(self, target_device: Optional[torch.device] = None) -> torch.Tensor: + self._materialize_global_hessian(target_device=target_device) + if self.H is None: + self.H = torch.zeros((self.columns, self.columns), dtype=torch.float32, device=self._select_hessian_target_device(target_device)) + return self.H # FIXME, optimum needs fasterquant, we need to remove it def fasterquant( @@ -590,12 +582,8 @@ def quantize( # log.info(f"Quantization `{self.name}` using samples: `{self.nsamples}`") start = time.time() - with self.lock: - self._flush_pending_updates_locked(allow_gaps=True) - if self._pending_updates: - raise RuntimeError( - f"Pending Hessian updates remain for module '{self.name}' before quantization." - ) + target_device = getattr(self.module, "target_device", None) + self.finalize_hessian(target_device=target_device) # Temporarily disable torch.compile due to compatibility issues with torch 2.8 # Will re-enable once the issue is fixed diff --git a/pyproject.toml b/pyproject.toml index 1604b237f..e03324abb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,7 +46,7 @@ dependencies = [ "huggingface_hub>=0.34.4", "random_word>=1.0.13", "tokenicer>=0.0.5", - "logbar>=0.1.2", + "logbar>=0.1.3", "maturin>=1.9.4", # required by safetensors and hf_transfer "datasets>=3.6.0", "pyarrow>=21.0", diff --git a/requirements.txt b/requirements.txt index 4d110bd82..df93ad91d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,7 +13,7 @@ hf_transfer>=0.1.9 huggingface_hub>=0.34.4 random_word>=1.0.13 tokenicer>=0.0.5 -logbar>=0.1.2 +logbar>=0.1.3 maturin>=1.9.4 datasets>=3.6.0 pyarrow>=21.0 diff --git a/tests/test_gptq_queue.py b/tests/test_gptq_queue.py index fe6d8601c..0a8faf77e 100644 --- a/tests/test_gptq_queue.py +++ b/tests/test_gptq_queue.py @@ -1,13 +1,14 @@ import copy import torch +import pytest from gptqmodel.quantization.config import QuantizeConfig from gptqmodel.quantization.gptq import GPTQ @torch.no_grad() -def test_out_of_order_batches_flush_in_sequence(): +def test_out_of_order_batches_finalize_matches_reference(): torch.manual_seed(0) module = torch.nn.Linear(4, 4) @@ -23,14 +24,49 @@ def test_out_of_order_batches_flush_in_sequence(): y0 = module(x0) y1 = module(x1) + # Add batches out of order to ensure accumulation is order agnostic. 
gptq.add_batch(x1, y1, batch_index=1) - assert gptq.nsamples == 0 - gptq.add_batch(x0, y0, batch_index=0) + gptq.finalize_hessian() + reference.add_batch(x0, y0, batch_index=0) reference.add_batch(x1, y1, batch_index=1) + reference.finalize_hessian() + assert gptq.H is not None torch.testing.assert_close(gptq.H, reference.H) assert gptq.nsamples == reference.nsamples - assert not gptq._pending_updates + assert not gptq._device_hessian_partials + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA") +def test_finalize_hessian_preserves_device(monkeypatch): + module = torch.nn.Linear(4, 4).cuda() + cfg = QuantizeConfig() + gptq = GPTQ(module, cfg) + + module_device = module.weight.device + + def fake_process_batch(self, inp): + xtx = torch.eye(self.columns, dtype=torch.float32, device=module_device) + return 1, xtx.clone(), module_device + + monkeypatch.setattr(GPTQ, "process_batch", fake_process_batch, raising=False) + + inp = torch.zeros(1, device=module_device) + + gptq.add_batch(inp, inp, batch_index=1) + gptq.add_batch(inp, inp, batch_index=0) + + # No Hessian materialized until finalize is invoked. + assert gptq.H is None + assert module_device in gptq._device_hessian_partials + + gptq.finalize_hessian() + + assert gptq.H is not None + assert gptq.H.device == module_device + assert not gptq._device_hessian_partials + + torch.cuda.synchronize() diff --git a/tests/test_hessian_chunk.py b/tests/test_hessian_chunk.py index 1d45f41d6..f416fd866 100644 --- a/tests/test_hessian_chunk.py +++ b/tests/test_hessian_chunk.py @@ -74,10 +74,12 @@ def test_hessian_chunk_consistency_matches_full_precision(): calib = torch.randn(128, 32, dtype=torch.float16) - gptq_full.process_batch(calib.clone()) - gptq_chunked.process_batch(calib.clone()) + _, full_xtx, full_device = gptq_full.process_batch(calib.clone()) + _, chunked_xtx, chunked_device = gptq_chunked.process_batch(calib.clone()) - assert torch.allclose(gptq_full.H, gptq_chunked.H, atol=1e-5, rtol=1e-5) + assert full_device == chunked_device + assert full_xtx is not None and chunked_xtx is not None + assert torch.allclose(full_xtx, chunked_xtx, atol=5e-6, rtol=5e-6) def test_hessian_chunk_invocations_and_workspace_shape(): @@ -168,7 +170,7 @@ def worker(task_id: int) -> None: batch_size, xtx, canonical_device = gptq.process_batch(calib) assert batch_size == rows assert xtx is not None - assert canonical_device == gptq._hessian_device + assert canonical_device == device with ThreadPoolExecutor(max_workers=8) as pool: futures = [pool.submit(worker, idx) for idx in range(16)] @@ -176,7 +178,7 @@ def worker(task_id: int) -> None: fut.result() for gptq in gptq_workers: - assert gptq._hessian_device == device + assert getattr(gptq, "_final_hessian_device_hint", None) == device cols = base.in_features cache_key = gptq_impl._workspace_cache_key(device) diff --git a/tests/test_hessian_merge.py b/tests/test_hessian_merge.py new file mode 100644 index 000000000..51d76e31f --- /dev/null +++ b/tests/test_hessian_merge.py @@ -0,0 +1,56 @@ +import copy + +import pytest +import torch + +from gptqmodel.quantization.config import QuantizeConfig +from gptqmodel.quantization.gptq import GPTQ + + +@pytest.mark.skipif( + not torch.cuda.is_available() or torch.cuda.device_count() < 4, + reason="requires at least 4 CUDA devices", +) +@torch.no_grad() +def test_hessian_merge_multi_gpu_matches_serial(): + torch.manual_seed(0) + + in_features = 16 + out_features = 8 + batch_count = 100 + per_device = batch_count // 4 + devices = 
[torch.device(f"cuda:{idx}") for idx in range(4)] + + base = torch.nn.Linear(in_features, out_features, bias=False).eval() + cfg_serial = QuantizeConfig() + cfg_multi = copy.deepcopy(cfg_serial) + + serial_module = copy.deepcopy(base) + multi_module = copy.deepcopy(base).to(devices[0]) + + gptq_serial = GPTQ(serial_module, cfg_serial) + gptq_multi = GPTQ(multi_module, cfg_multi) + + samples = [torch.randn(1, 1, in_features) for _ in range(batch_count)] + + for idx, sample in enumerate(samples): + gptq_serial.add_batch(sample, torch.empty(0), batch_index=idx) + + gptq_serial.finalize_hessian() + serial_hessian = gptq_serial.H.detach().cpu() + assert gptq_serial.nsamples == batch_count + + for device_idx, device in enumerate(devices): + start = device_idx * per_device + end = start + per_device + for idx in range(start, end): + sample_gpu = samples[idx].to(device) + gptq_multi.add_batch(sample_gpu, torch.empty(0, device=device), batch_index=idx) + del sample_gpu + torch.cuda.synchronize(device=device) + + gptq_multi.finalize_hessian() + merged_hessian = gptq_multi.H.detach().cpu() + assert gptq_multi.nsamples == batch_count + + torch.testing.assert_close(merged_hessian, serial_hessian, atol=1e-6, rtol=1e-6) From 6fac2087bdd87ee780961fbe01b92efd4c417d05 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Thu, 16 Oct 2025 06:43:38 +0000 Subject: [PATCH 4/5] revert batch --- tests/models/test_qwen3_moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/test_qwen3_moe.py b/tests/models/test_qwen3_moe.py index 7541ca609..4916361f6 100644 --- a/tests/models/test_qwen3_moe.py +++ b/tests/models/test_qwen3_moe.py @@ -20,7 +20,7 @@ class TestQwen3Moe(ModelTest): DESC_ACT = False DATASET_SIZE = 1024 DATASET_SORT = "desc" - QUANT_BATCH_SIZE = 1 + QUANT_BATCH_SIZE = 4 CALIB_NOISE_MODE = "unseen" CALIB_NOISE_PERCENT = 0.025 From d4bdee322ed222190a6d1409edee239e68b38a4b Mon Sep 17 00:00:00 2001 From: Qubitium Date: Thu, 16 Oct 2025 06:52:58 +0000 Subject: [PATCH 5/5] try to use same H buffer --- gptqmodel/quantization/gptq.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/gptqmodel/quantization/gptq.py b/gptqmodel/quantization/gptq.py index 18428e6bd..5da3ed302 100644 --- a/gptqmodel/quantization/gptq.py +++ b/gptqmodel/quantization/gptq.py @@ -471,7 +471,23 @@ def _materialize_global_hessian(self, target_device: Optional[torch.device] = No return total_samples = sum(self._device_sample_counts.values()) - result = torch.zeros((self.columns, self.columns), dtype=torch.float32, device=device) + + reuse_buffer = ( + self.H is not None + and self.H.shape == (self.columns, self.columns) + and self.H.dtype == torch.float32 + and self.H.device == device + ) + + if reuse_buffer: + result = self.H + result.zero_() + else: + result = torch.zeros( + (self.columns, self.columns), + dtype=torch.float32, + device=device, + ) if total_samples == 0: self.H = result