Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 15 additions & 7 deletions gptqmodel/looper/awq_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -578,8 +578,10 @@ def _search_best_scale(
inp = inp.to(next(module2inspect.parameters()).device)

# [STEP 1]: Compute per-channel mean of normalised weights
# Accumulate statistics per-layer to avoid concatenating large tensors
# (original implementation materialized a giant cat() that doubled VRAM usage)
# Stream across each Linear instead of concatenating all weights at once. This mirrors the
# previous cat()+view pipeline while keeping peak memory low: for every group we normalise
# |w| by its per-group max so the values land on a [0, 1] scale, then accumulate the totals
# per output channel so the mean can be computed without allocating the combined tensor.
first_weight = layers[0].weight
weight_dtype = first_weight.dtype
weight_device = first_weight.device
Expand Down Expand Up @@ -609,14 +611,15 @@ def _search_best_scale(
w_mean = (w_sum / row_count).to(weight_dtype)

# [STEP 2]: Compute per-channel mean of the input activation with chunking
# Stream directly on the source device to avoid creating full CPU copies
# Stream directly on the source device to avoid creating full CPU copies while still enforcing
# a predictable memory bound derived from max_chunk_memory.
inp_flat = inp.abs().view(-1, inp.shape[-1])
num_elements = inp_flat.size(0)
num_channels = inp_flat.size(1)
float32_size = torch.tensor([], dtype=torch.float32).element_size()
element_size_bytes = float32_size # accumulation happens in FP32

# Calculate chunk size dynamically based on max_chunk_memory
# Calculate chunk size dynamically based on the available memory budget (default 1 GiB).
chunk_size = int(self.max_chunk_memory // (element_size_bytes * num_channels))
chunk_size = min(chunk_size, num_elements)
chunk_size = max(chunk_size, 1)
Expand All @@ -627,6 +630,7 @@ def _search_best_scale(
for i in range(0, num_elements, chunk_size):
end = min(i + chunk_size, num_elements)
chunk = inp_flat[i:end]
# Accumulate each chunk in FP32 to balance precision and memory usage.
chunk_sum = chunk.to(torch.float32).sum(dim=0)
x_sum += chunk_sum

Expand Down Expand Up @@ -705,22 +709,22 @@ def _compute_best_clip(
w_all = w
best_max_val_all = []
device = w_all.device
# Pre-allocate scratch buffers so the inner loop never allocates large temporaries
# Pre-allocate scratch buffers so the inner clamp loop never allocates large temporaries.
scratch_clamp = torch.empty_like(w_all[:oc_batch_size])
scratch_quant = torch.empty_like(scratch_clamp)
input_feat = input_feat.to(device)

for i_b in range(org_w_shape[0] // oc_batch_size):
w = w_all[i_b * oc_batch_size: (i_b + 1) * oc_batch_size]

org_max_val = w.abs().amax(dim=-1, keepdim=True) # co, 1, n_group, 1
org_max_val = w.abs().amax(dim=-1, keepdim=True) # [co_batch, 1, n_group, 1]

best_max_val = org_max_val.clone()
min_errs = torch.ones_like(org_max_val) * 1e9
clamp_slice = scratch_clamp[: w.shape[0]]
quant_slice = scratch_quant[: w.shape[0]]

org_out = (input_feat * w).sum(dim=-1)
org_out = (input_feat * w).sum(dim=-1) # [co_batch, n_token, n_group]

for i_s in range(int(max_shrink * n_grid)):
max_val = org_max_val * (1 - i_s / n_grid)
Expand All @@ -729,6 +733,7 @@ def _compute_best_clip(
self._pseudo_quantize_tensor_into(clamp_slice, quant_slice)
cur_out = (input_feat * quant_slice).sum(dim=-1)

# Evaluate the reconstruction error for the current clamp ratio and keep the best one.
err = (cur_out - org_out).pow(2).mean(dim=1).view(min_errs.shape)
cur_best_idx = err < min_errs
min_errs[cur_best_idx] = err[cur_best_idx]
Expand Down Expand Up @@ -874,6 +879,8 @@ def _compute_best_scale(
scales[torch.isnan(scales)] = 1

# Q(W * s)
# Temporarily apply the candidate scale, quantize the in-flight weights without allocating,
# and rely on the CPU master copy to restore the original FP values after evaluation.
for fc in linears2scale:
fc.weight.mul_(scales_view)
self._pseudo_quantize_tensor_into(fc.weight, fc.weight)
Expand All @@ -894,6 +901,7 @@ def _compute_best_scale(
for fc in linears2scale:
fc.weight.copy_(orig_weights_cpu[fc].to(device=fc.weight.device, dtype=fc.weight.dtype))

# Reset weights one final time so callers always see the pristine FP copy.
for fc in linears2scale:
fc.weight.copy_(orig_weights_cpu[fc].to(device=fc.weight.device, dtype=fc.weight.dtype))
orig_weights_cpu.clear()
Expand Down
41 changes: 21 additions & 20 deletions gptqmodel/looper/stage_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

import torch

from .. import DEVICE_THREAD_POOL
from .. import DEBUG_ON, DEVICE_THREAD_POOL
from ..looper.awq_processor import AWQProcessor
from ..looper.gptq_processor import GPTQProcessor
from ..looper.named_module import NamedModule
Expand Down Expand Up @@ -107,25 +107,26 @@ def run_layer_stage(
for index, names in enumerate(modules):
# Process the layer in smaller subsets so attention groups or
# MoE experts can be quantized independently within a layer.
if isinstance(processor, AWQProcessor):
log.info(
"StageLayer[awq]: layer=%s subset=%s/%s size=%s names=%s",
layer_index,
index + 1,
subset_total,
len(names),
names[:5],
)
elif log.isEnabledFor(logging.DEBUG):
log.debug(
"StageLayer: layer=%s subset=%s/%s processor=%s size=%s names=%s",
layer_index,
index + 1,
subset_total,
processor.name(),
len(names),
names[:8],
)
if DEBUG_ON and log.isEnabledFor(logging.DEBUG):
if isinstance(processor, AWQProcessor):
log.debug(
"StageLayer[awq]: layer=%s subset=%s/%s size=%s names=%s",
layer_index,
index + 1,
subset_total,
len(names),
names[:5],
)
else:
log.debug(
"StageLayer: layer=%s subset=%s/%s processor=%s size=%s names=%s",
layer_index,
index + 1,
subset_total,
processor.name(),
len(names),
names[:8],
)
subset_result = run_subset_stage(
looper,
processor=processor,
Expand Down
117 changes: 60 additions & 57 deletions gptqmodel/looper/stage_subset.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

import torch

from .. import DEVICE_THREAD_POOL
from .. import DEBUG_ON, DEVICE_THREAD_POOL
from ..looper.gptq_processor import GPTQProcessor
from ..looper.loop_processor import LoopProcessor
from ..looper.named_module import NamedModule
Expand Down Expand Up @@ -111,25 +111,26 @@ def run_subset_stage(
# )
# return SubsetStageResult(processed_subset={}, layer_inputs=layer_inputs, forward_context=None)

if is_awq_processor:
logger.info(
"StageSubset[awq]: layer=%s subset=%s/%s modules=%s sample=%s",
layer_index,
subset_index + 1,
subset_total,
len(subset),
list(subset.keys())[:8],
)
elif logger.isEnabledFor(logging.DEBUG):
logger.debug(
"StageSubset: layer=%s subset=%s/%s processor=%s created %s modules (sample=%s)",
layer_index,
subset_index + 1,
subset_total,
processor_name,
len(subset),
list(subset.keys())[:8],
)
if DEBUG_ON and logger.isEnabledFor(logging.DEBUG):
if is_awq_processor:
logger.debug(
"StageSubset[awq]: layer=%s subset=%s/%s modules=%s sample=%s",
layer_index,
subset_index + 1,
subset_total,
len(subset),
list(subset.keys())[:8],
)
else:
logger.debug(
"StageSubset: layer=%s subset=%s/%s processor=%s created %s modules (sample=%s)",
layer_index,
subset_index + 1,
subset_total,
processor_name,
len(subset),
list(subset.keys())[:8],
)

moe_group_keys_all: List[str] = []
forward_device_map: Dict[str, torch.device] = {}
Expand Down Expand Up @@ -242,23 +243,24 @@ def run_subset_stage(
looper._masked_hook_wrapper(processor, original_hook, hook_source)
))

if is_awq_processor:
logger.info(
"StageSubset[awq]: layer=%s subset=%s/%s registering hooks for %s modules",
layer_index,
subset_index + 1,
subset_total,
len(subset),
)
elif logger.isEnabledFor(logging.DEBUG):
logger.debug(
"StageSubset: layer=%s subset=%s/%s processor=%s registering hooks for %s modules",
layer_index,
subset_index + 1,
subset_total,
processor_name,
len(subset),
)
if DEBUG_ON and logger.isEnabledFor(logging.DEBUG):
if is_awq_processor:
logger.debug(
"StageSubset[awq]: layer=%s subset=%s/%s registering hooks for %s modules",
layer_index,
subset_index + 1,
subset_total,
len(subset),
)
else:
logger.debug(
"StageSubset: layer=%s subset=%s/%s processor=%s registering hooks for %s modules",
layer_index,
subset_index + 1,
subset_total,
processor_name,
len(subset),
)

fwd_start = time.perf_counter()
forward_source = f"{layer_descriptor}:subset{subset_index + 1}/{subset_total}"
Expand Down Expand Up @@ -410,26 +412,27 @@ def _process_on_worker(
timer = getattr(looper.gptq_model, "quant_region_timer", None)
start = time.perf_counter() if timer else None
try:
if is_awq_processor:
logger.info(
"StageSubsetWorker[awq]: layer=%s subset=%s/%s module=%s previous_subset=%s",
getattr(nm, "layer_index", None),
subset_idx + 1,
subset_total_count,
module_label,
bool(previous_subset_ref),
)
elif logger.isEnabledFor(logging.DEBUG):
logger.debug(
"StageSubsetWorker: processor=%s layer=%s subset=%s/%s module=%s running on %s (previous_subset=%s)",
proc_name,
getattr(nm, "layer_index", None),
subset_idx + 1,
subset_total_count,
module_label,
expected_device,
bool(previous_subset_ref),
)
if DEBUG_ON and logger.isEnabledFor(logging.DEBUG):
if is_awq_processor:
logger.debug(
"StageSubsetWorker[awq]: layer=%s subset=%s/%s module=%s previous_subset=%s",
getattr(nm, "layer_index", None),
subset_idx + 1,
subset_total_count,
module_label,
bool(previous_subset_ref),
)
else:
logger.debug(
"StageSubsetWorker: processor=%s layer=%s subset=%s/%s module=%s running on %s (previous_subset=%s)",
proc_name,
getattr(nm, "layer_index", None),
subset_idx + 1,
subset_total_count,
module_label,
expected_device,
bool(previous_subset_ref),
)
proc.process(
module=nm,
subset=subset_ref,
Expand Down
2 changes: 1 addition & 1 deletion gptqmodel/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@
# even minor versions are releases; odd minor versions are development builds
# 5.2.0 => release, 5.1.0 => devel
# micro version (5.2.x) denotes a patch fix, e.g. 5.2.1 is a patch-fix release
__version__ = "5.2.0"
__version__ = "5.3.0"
Loading