10 changes: 1 addition & 9 deletions gptqmodel/looper/gptq_processor.py
@@ -30,7 +30,7 @@
from ..quantization.config import QUANT_METHOD, QuantizeConfig
from ..utils.logger import setup_logger
from ..utils.model import move_to, pack_model
from ..utils.torch import CPU, DEVICE_0, DEVICE_0_STREAM, DEVICE_1, torch_streamCtx, torch_sync
from ..utils.torch import CPU, DEVICE_0, DEVICE_0_STREAM, DEVICE_1, torch_empty_cache, torch_streamCtx, torch_sync

log = setup_logger()
lock = threading.Lock()
@@ -222,14 +222,6 @@ def process(self, module: NamedModule, auto_gc: bool = True):
with self.lock:
    self.tasks[module.name].free()

# prepare for module.forward post generate
# module.weight.data = torch.empty(1,1) # hack to remove weight.data
# if auto_gc:
#     torch_empty_cache()
# with torch_streamCtx(DEVICE_0_STREAM):
#     wq = wq.to(device=DEVICE_0, non_blocking=True) # move to d0 for post quant inference
wq = wq.to(device=DEVICE_0, non_blocking=False)

# logger.info(f"Quantizing module END: {name}, {gptq[name].shape()}")

module.state.update({
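Note: the deleted comments sketched an asynchronous move of `wq` through `DEVICE_0_STREAM`; the surviving line keeps the copy blocking. A minimal sketch of the trade-off, assuming a CUDA build; `move_blocking`/`move_async` are illustrative names, not repo helpers:

```python
import torch

def move_blocking(t: torch.Tensor, device: torch.device) -> torch.Tensor:
    # What the kept line does: returns only once the copy has finished,
    # so the result is immediately safe to read on `device`.
    return t.to(device=device, non_blocking=False)

def move_async(t: torch.Tensor, device: torch.device, stream: torch.cuda.Stream) -> torch.Tensor:
    # What the commented-out torch_streamCtx path did: enqueue the copy on a
    # side stream so it can overlap with compute, at the cost of an explicit
    # sync before anyone reads the result.
    with torch.cuda.stream(stream):
        out = t.to(device=device, non_blocking=True)
    torch.cuda.current_stream().wait_stream(stream)
    return out
```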
32 changes: 18 additions & 14 deletions gptqmodel/quantization/gptq.py
@@ -288,11 +288,10 @@ def quantize(

# Temporarily disable torch.compile due to compatibility issues with torch 2.8
# Will re-enable once the issue is fixed
if not TORCH_GTE_28:
if not TORCH_GTE_28 and not self.qcfg.mock_quantization:
    self.hessian_inverse = torch_compile(self.hessian_inverse)

# Mock heavy computations
if hasattr(self.qcfg, 'mock_quantization') and self.qcfg.mock_quantization:
if self.qcfg.mock_quantization:
    # Use simplified hessian inverse (identity matrix)
    self.hessian_inverse = self._mock_hessian_inverse

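`_mock_hessian_inverse` itself is outside this diff; per the comment above, it stands in an identity matrix for the Cholesky-based inverse. A hypothetical sketch, assuming it matches the `(Hinv, damp)` return contract used at the `Hinv, damp = self.hessian_inverse(H)` call site below:

```python
import torch

def _mock_hessian_inverse(self, H: torch.Tensor):
    # Hypothetical stand-in (the real method is not shown in this diff):
    # skip damping and Cholesky entirely and return an identity matrix of
    # H's shape plus a zero damp factor, so the quantization loop runs
    # with no cross-column error propagation.
    Hinv = torch.eye(H.shape[0], dtype=H.dtype, device=H.device)
    return Hinv, 0.0
```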
@@ -367,13 +366,12 @@ def quantize(
Hinv, damp = self.hessian_inverse(H)

# Use simplified loop when mock_quantization is active
if hasattr(self.qcfg, 'mock_quantization') and self.qcfg.mock_quantization:
if self.qcfg.mock_quantization:
    for i1 in range(0, self.columns, blocksize):
        i2 = min(i1 + blocksize, self.columns)
        count = i2 - i1

        # Clone the weights like the original code to maintain device/dtype consistency
        W1 = W[:, i1:i2].clone()
        W1 = W[:, i1:i2]
        Q1 = torch.zeros_like(W1)

        # Handle group quantization parameters efficiently (similar to original)
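Dropping `.clone()` saves one block-sized allocation per iteration, but `W1` becomes a view that aliases `W`'s storage, which is only safe while the mock loop reads `W1` without writing through it. A minimal sketch of the aliasing difference:

```python
import torch

W = torch.arange(8, dtype=torch.float32).reshape(2, 4)

W1_view = W[:, 0:2]          # the new code: no copy, shares W's storage
W1_copy = W[:, 0:2].clone()  # the old code: private copy, extra allocation

W1_view[0, 0] = -1.0  # lands in W, because W1_view aliases it
W1_copy[0, 1] = -2.0  # does not touch W

print(W[0, 0].item())  # -1.0 -- the view write is visible in W
print(W[0, 1].item())  # 1.0  -- the clone write is not
```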
@@ -559,16 +557,10 @@
if isinstance(self.module, transformers.Conv1D):
    Q = Q.t()

# Ensure Q is on the same device as the original module weight before type conversion
if Q.device != self.module.weight.data.device:
    Q = Q.to(device=self.module.weight.data.device)

if Q.shape != self.module.weight.shape:
    Q = Q.reshape(self.module.weight.shape).type_as(self.module.weight.data)
    Q = Q.reshape(self.module.weight.shape).to(self.module.weight.dtype)
else:
    Q = Q.type_as(self.module.weight.data)

    # Q = Q.to(device=use_device)
    Q = Q.to(self.module.weight.dtype)

if scale == []:
    scale.append(self.quantizer.scale)
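With the eager device check removed, this hunk fixes shape and dtype first and leaves the device move to the retry block added below. A minimal sketch of the resulting order of operations, using a stand-in `Linear` for `self.module`:

```python
import torch

module = torch.nn.Linear(4, 4)            # stand-in for self.module
Q = torch.randn(16, dtype=torch.float32)  # quantized weights, possibly flattened

# 1) shape, 2) dtype -- matching the kept lines above
if Q.shape != module.weight.shape:
    Q = Q.reshape(module.weight.shape).to(module.weight.dtype)
else:
    Q = Q.to(module.weight.dtype)

# 3) device is deliberately deferred to the retry block further down
assert Q.shape == module.weight.shape and Q.dtype == module.weight.dtype
```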
@@ -577,6 +569,18 @@
scale = torch.cat(scale, dim=1)
zero = torch.cat(zero, dim=1)

# prepare for module.forward post-generate: move Q to the weight's device, retrying once on failure
if Q.device != self.module.weight.data.device:
    try:
        Q = Q.to(device=self.module.weight.data.device, non_blocking=False)
    except Exception as e:
        log.warn(f'Failed to move Q from {Q.device} to {self.module.weight.data.device}, retrying after torch_empty_cache, {e}')
        # free cached blocks first so the retry can allocate (torch_empty_cache from ..utils.torch)
        torch_empty_cache()
        try:
            Q = Q.to(device=self.module.weight.data.device, non_blocking=False)
        except Exception as e2:
            log.error(f'Failed to move Q from {Q.device} to {self.module.weight.data.device}, {e2}')
            raise

duration = time.time() - start

return Q, scale, zero, g_idx, duration, avg_loss, damp, self.nsamples
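The retry only helps if something is freed between the two attempts, which is why the body above calls `torch_empty_cache()` before trying again. The same pattern in isolation, as a hypothetical helper built only on public torch APIs:

```python
import torch

def to_device_with_retry(t: torch.Tensor, device: torch.device) -> torch.Tensor:
    # A cross-device copy can fail when the target allocator is fragmented;
    # empty_cache() releases cached-but-unused blocks so one retry has a
    # realistic chance of succeeding.
    try:
        return t.to(device=device, non_blocking=False)
    except torch.cuda.OutOfMemoryError:
        torch.cuda.empty_cache()
        return t.to(device=device, non_blocking=False)  # re-raises if it still fails
```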