diff --git a/gptqmodel/quantization/gptq.py b/gptqmodel/quantization/gptq.py
index cca4416f0..ac6c3c79a 100644
--- a/gptqmodel/quantization/gptq.py
+++ b/gptqmodel/quantization/gptq.py
@@ -441,9 +441,18 @@ def process_batch(self, inp: torch.Tensor) -> Tuple[int, Optional[torch.Tensor],
         return batch_token_size, xtx, canonical_device
 
-    def _flush_pending_updates_locked(self) -> None:
-        while True:
+    def _flush_pending_updates_locked(self, *, allow_gaps: bool = False) -> None:
+        while self._pending_updates:
             update = self._pending_updates.pop(self._next_batch_index, None)
             if update is None:
-                break
+                if not allow_gaps:
+                    break
+
+                next_index = min(self._pending_updates.keys())
+                if next_index != self._next_batch_index:
+                    self._next_batch_index = next_index
+
+                update = self._pending_updates.pop(self._next_batch_index, None)
+                if update is None:
+                    break
 
             batch_token_size, xtx, device = update
@@ -547,7 +556,7 @@ def quantize(
         start = time.time()
 
         with self.lock:
-            self._flush_pending_updates_locked()
+            self._flush_pending_updates_locked(allow_gaps=True)
             if self._pending_updates:
                 raise RuntimeError(
                     f"Pending Hessian updates remain for module '{self.name}' before quantization."