137 changes: 105 additions & 32 deletions gptqmodel/looper/awq_processor.py
@@ -180,14 +180,16 @@ def _capture_previous_subset_scale(self, previous_subset: Optional[Dict[str, Nam
def _layer_input_features(self, state: _AWQLayerState) -> Dict[str, torch.Tensor]:
features: Dict[str, torch.Tensor] = {}
root_buckets: Dict[str, List[torch.Tensor]] = {}
for name in state.modules:
# Iterate over a snapshot since quantization may mutate state.modules concurrently
for name in list(state.modules):
entry = self.tasks.get(name) or {}
tensors: List[torch.Tensor] = entry.get("inputs", []) # type: ignore[arg-type]
if not tensors:
features[name] = torch.empty(0)
continue
try:
features[name] = torch.cat(tensors, dim=0)
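# cache the concatenated tensor back into the task entry so later calls reuse it instead of re-concatenating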
entry["inputs"] = [features[name]]
except RuntimeError:
features[name] = tensors[0]
root = name.split(".", 1)[0]
@@ -576,38 +578,57 @@ def _search_best_scale(
inp = inp.to(next(module2inspect.parameters()).device)

# [STEP 1]: Compute per-channel mean of normalised weights
# All layer weights are concatted together
weight = torch.cat([_m.weight for _m in layers], dim=0)
org_shape = weight.shape
# The weights are reshaped to be organised by quantization group
weight = weight.view(-1, self.qcfg.group_size)
# Calculates the relative magnitude of the weights within each of the quantization groups,
# and rescales each group individually so that each group has weights on a 0-1 scale.
w_scale = weight.abs() / (weight.abs().amax(dim=1, keepdim=True) + 1e-6)
# Resizes the rescaled weight matrix back up to its original dimensions
w_scale = w_scale.view(org_shape)
# Gets the average rescaled magnitude for each output channel
w_mean = w_scale.mean(0)
del weight
# Accumulate statistics per-layer to avoid concatenating large tensors
# (original implementation materialized a giant cat() that doubled VRAM usage)
first_weight = layers[0].weight
weight_dtype = first_weight.dtype
weight_device = first_weight.device
num_channels = first_weight.shape[1]
w_sum = torch.zeros(num_channels, dtype=torch.float32, device=weight_device)
row_count = 0
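# w_sum accumulates the per-input-channel sum of group-normalised |w| in FP32; dividing by row_count afterwards reproduces the mean the old cat()-based path computed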

for layer in layers:
weight = layer.weight
if weight.shape[1] != num_channels:
raise ValueError(
f"Expected consistent in_features across layers ({num_channels}), "
f"got {weight.shape[1]} for layer {layer}."
)
org_shape = weight.shape
weight_abs = weight.abs()
weight_group = weight_abs.view(-1, self.qcfg.group_size)
group_scale = weight_group.amax(dim=1, keepdim=True) + 1e-6
normalized = weight_group / group_scale
normalized = normalized.view(org_shape)
w_sum += normalized.sum(dim=0, dtype=torch.float32)
row_count += org_shape[0]

if row_count == 0:
w_mean = torch.zeros(num_channels, dtype=weight_dtype, device=weight_device)
else:
w_mean = (w_sum / row_count).to(weight_dtype)

# [STEP 2]: Compute per-channel mean of the input activation with chunking
# move inp to cpu to avoid memory leak
inp_flat = inp.cpu().abs().view(-1, inp.shape[-1])
# Stream directly on the source device to avoid creating full CPU copies
inp_flat = inp.abs().view(-1, inp.shape[-1])
num_elements = inp_flat.size(0)
num_channels = inp_flat.size(1)
element_size_bytes = inp_flat.element_size() * 2 # multiplied by 2 for FP32
float32_size = torch.tensor([], dtype=torch.float32).element_size()
element_size_bytes = float32_size # accumulation happens in FP32

# Calculate chunk size dynamically based on max_chunk_memory
chunk_size = int(self.max_chunk_memory // (element_size_bytes * num_channels))
chunk_size = min(chunk_size, num_elements)
chunk_size = max(chunk_size, 1)
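# each chunk is sized so that chunk_size * num_channels FP32 values stay within the max_chunk_memory budget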

# Use float32 for sum calculation
x_sum = torch.zeros(num_channels, dtype=torch.float32, device=inp.device)

for i in range(0, num_elements, chunk_size):
end = min(i + chunk_size, num_elements)
chunk_sum = inp_flat[i:end].to(torch.float32).sum(dim=0)
x_sum += chunk_sum.to(inp.device)
chunk = inp_flat[i:end]
chunk_sum = chunk.to(torch.float32).sum(dim=0)
x_sum += chunk_sum

x_mean = (x_sum / num_elements).to(inp.dtype)
del x_sum
@@ -683,6 +704,11 @@ def _compute_best_clip(
assert org_w_shape[0] % oc_batch_size == 0
w_all = w
best_max_val_all = []
device = w_all.device
# Pre-allocate scratch buffers so the inner loop never allocates large temporaries
scratch_clamp = torch.empty_like(w_all[:oc_batch_size])
scratch_quant = torch.empty_like(scratch_clamp)
input_feat = input_feat.to(device)
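# input_feat is moved to the weight device once here instead of inside the per-batch loop; the scratch slices below are reused for every output-channel batch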

for i_b in range(org_w_shape[0] // oc_batch_size):
w = w_all[i_b * oc_batch_size: (i_b + 1) * oc_batch_size]
@@ -691,20 +717,19 @@

best_max_val = org_max_val.clone()
min_errs = torch.ones_like(org_max_val) * 1e9
input_feat = input_feat.to(w.device)
org_out = (input_feat * w).sum(dim=-1) # co, n_token, n_group
clamp_slice = scratch_clamp[: w.shape[0]]
quant_slice = scratch_quant[: w.shape[0]]

org_out = (input_feat * w).sum(dim=-1)

for i_s in range(int(max_shrink * n_grid)):
max_val = org_max_val * (1 - i_s / n_grid)
min_val = -max_val
cur_w = torch.clamp(w, min_val, max_val)
q_w = self.pseudo_quantize_tensor(cur_w)[0]
cur_out = (input_feat * q_w).sum(dim=-1)
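# clamp and fake-quantise into the pre-allocated slices so the grid search allocates no new large tensors per step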
torch.clamp(w, min_val, max_val, out=clamp_slice)
self._pseudo_quantize_tensor_into(clamp_slice, quant_slice)
cur_out = (input_feat * quant_slice).sum(dim=-1)

# co, 1, n_group, 1
err = (cur_out - org_out).pow(2).mean(dim=1).view(min_errs.shape)
del cur_w
del cur_out
cur_best_idx = err < min_errs
min_errs[cur_best_idx] = err[cur_best_idx]
best_max_val[cur_best_idx] = max_val[cur_best_idx]
@@ -753,6 +778,45 @@ def pseudo_quantize_tensor(self, w: torch.Tensor):

return w, scales, zeros

@torch.inference_mode()
def _pseudo_quantize_tensor_into(self, src: torch.Tensor, dst: torch.Tensor) -> None:
# Quantize `src` into `dst` without allocating a new tensor (mirrors pseudo_quantize_tensor)
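# assumes `src` and `dst` are contiguous so the views/reshapes below alias `dst` and the in-place ops write through to the caller's tensor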
org_shape = src.shape
if self.qcfg.group_size > 0:
src_view = src.view(-1, self.qcfg.group_size)
dst_view = dst.view(-1, self.qcfg.group_size)
else:
src_view = src.reshape(org_shape[0], -1)
dst_view = dst.reshape_as(src_view)

if self.qcfg.zero_point:
max_val = src_view.amax(dim=1, keepdim=True)
min_val = src_view.amin(dim=1, keepdim=True)
max_int = 2 ** self.qcfg.bits - 1
min_int = 0
scales = (max_val - min_val).clamp(min=1e-5) / max_int
zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int)

dst_view.copy_(src_view)
dst_view.div_(scales)
torch.round(dst_view, out=dst_view)
dst_view.add_(zeros)
dst_view.clamp_(min_int, max_int)
dst_view.sub_(zeros)
dst_view.mul_(scales)
else:
max_val = src_view.abs().amax(dim=1, keepdim=True).clamp(min=1e-5)
max_int = 2 ** (self.qcfg.bits - 1) - 1
min_int = -(2 ** (self.qcfg.bits - 1))
scales = max_val / max_int

dst_view.copy_(src_view)
dst_view.div_(scales)
torch.round(dst_view, out=dst_view)
dst_view.clamp_(min_int, max_int)
dst_view.mul_(scales)


def _compute_best_scale(
self,
x: torch.Tensor,
@@ -778,7 +842,12 @@ def _compute_best_scale(
best_scales = None
best_error = float("inf")

org_sd = {k: v.cpu() for k, v in module2inspect.state_dict().items()}
# Clone the original FP weights to CPU once so we can mutate/restore without load_state_dict overhead
orig_weights_cpu: Dict[nn.Linear, torch.Tensor] = {
# stash a contiguous FP32 master copy on CPU; avoids tying up GPU memory between ratios
fc: fc.weight.detach().to(torch.float32).cpu().contiguous()
for fc in linears2scale
}

device = x.device
x_mean = x_mean.view(-1).to(device)
@@ -807,9 +876,8 @@
# Q(W * s)
for fc in linears2scale:
fc.weight.mul_(scales_view)
fc.weight.data = (
self.pseudo_quantize_tensor(fc.weight.data)[0] / scales_view
)
self._pseudo_quantize_tensor_into(fc.weight, fc.weight)
fc.weight.div_(scales_view)
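# fc.weight now holds Q(W * s) / s, so the forward pass below measures the error introduced by this candidate scale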

# W * X
int_w_output = self._module_forward(x, module2inspect, kwargs)
@@ -823,7 +891,12 @@
best_error = loss
best_ratio = ratio
best_scales = scales.clone()
module2inspect.load_state_dict(org_sd)
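# restore the original FP weights before evaluating the next ratio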
for fc in linears2scale:
fc.weight.copy_(orig_weights_cpu[fc].to(device=fc.weight.device, dtype=fc.weight.dtype))

for fc in linears2scale:
fc.weight.copy_(orig_weights_cpu[fc].to(device=fc.weight.device, dtype=fc.weight.dtype))
orig_weights_cpu.clear()

if best_ratio == -1:
log.debug(history)
13 changes: 7 additions & 6 deletions tests/test_awq.py
@@ -57,12 +57,13 @@ def setUpClass(cls):
except Exception:
total_mem_gb = 0

if total_mem_gb >= 80:
sample_count = 1024
elif total_mem_gb >= 48:
sample_count = 512
else:
sample_count = 192
# if total_mem_gb >= 80:
# sample_count = 1024
# elif total_mem_gb >= 48:
# sample_count = 512
# else:
# sample_count = 192
sample_count = 512

traindata = load_dataset("json", data_files="/monster/data/model/dataset/c4-train.00000-of-01024.json.gz",
split="train")
84 changes: 84 additions & 0 deletions tests/test_awq_clip_consistency.py
@@ -0,0 +1,84 @@
import torch
from parameterized import parameterized

from gptqmodel.looper.awq_processor import AWQProcessor
from gptqmodel.quantization.config import QuantizeConfig


class _ClipTestAWQProcessor(AWQProcessor):
def __init__(self, qcfg: QuantizeConfig) -> None:
super().__init__(
tokenizer=None,
qcfg=qcfg,
calibration=None,
prepare_dataset_func=None,
calibration_concat_size=None,
calibration_sort=None,
batch_size=1,
gptq_model=None,
model=None,
require_fwd=True,
calculate_w_wq_diff=False,
calibration_concat_separator=None,
)

def _module_forward(self, x, module, module_kwargs):
return module(x)


def _legacy_clip(processor: AWQProcessor, w: torch.Tensor, input_feat: torch.Tensor):
group_size = processor.qcfg.group_size if processor.qcfg.group_size > 0 else w.shape[1]
input_feat = input_feat.view(-1, input_feat.shape[-1])
input_feat = input_feat.reshape(1, input_feat.shape[0], -1, group_size)
step_size = max(1, input_feat.shape[1] // 512)
input_feat = input_feat[:, ::step_size]
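# mirror the production clip search: subsample the token dimension to roughly 512 rows before the grid search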

w = w.reshape(w.shape[0], 1, -1, group_size)
oc_batch_size = 256 if w.shape[0] % 256 == 0 else 64
assert w.shape[0] % oc_batch_size == 0
best_max_val_all = []
for i_b in range(w.shape[0] // oc_batch_size):
w_chunk = w[i_b * oc_batch_size: (i_b + 1) * oc_batch_size]
org_max_val = w_chunk.abs().amax(dim=-1, keepdim=True)
best_max_val = org_max_val.clone()
min_errs = torch.ones_like(org_max_val) * 1e9
input_feat = input_feat.to(w_chunk.device)
org_out = (input_feat * w_chunk).sum(dim=-1)
for i_s in range(int(0.5 * 20)):
max_val = org_max_val * (1 - i_s / 20)
min_val = -max_val
cur_w = torch.clamp(w_chunk, min_val, max_val)
q_w = processor.pseudo_quantize_tensor(cur_w)[0]
cur_out = (input_feat * q_w).sum(dim=-1)
err = (cur_out - org_out).pow(2).mean(dim=1).view(min_errs.shape)
cur_best_idx = err < min_errs
min_errs[cur_best_idx] = err[cur_best_idx]
best_max_val[cur_best_idx] = max_val[cur_best_idx]
best_max_val_all.append(best_max_val)
return torch.cat(best_max_val_all, dim=0).squeeze(1)


@parameterized.expand([
("cpu", "cpu"),
("cuda", "cuda:0"),
])
def test_awq_clip_consistency(device_name: str, device_str: str):
if device_name == "cuda" and not torch.cuda.is_available():
raise AssertionError("CUDA is not available for clip consistency test")

dtype = torch.float32 if device_name == "cpu" else torch.float16
processor = _ClipTestAWQProcessor(QuantizeConfig(group_size=128))

out_features = 256
in_features = 3584
w = torch.randn(out_features, in_features, dtype=dtype, device=device_str)
tokens = 1024
input_feat = torch.randn(tokens, in_features, dtype=dtype, device=device_str)

# Compare the streaming implementation against the legacy tensor-per-iter path
expected = _legacy_clip(processor, w.clone(), input_feat.clone())
actual = processor._compute_best_clip(w, input_feat)

tol = 1e-6 if dtype == torch.float32 else 1e-4
assert torch.allclose(actual.cpu(), expected.cpu(), atol=tol, rtol=tol), \
f"Inconsistent clip: max diff {(actual - expected).abs().max().item():.3e}"
2 changes: 1 addition & 1 deletion tests/test_awq_moe.py
@@ -37,7 +37,7 @@ def setUpClass(self):
self.tokenizer = AutoTokenizer.from_pretrained(self.pretrained_model_id, use_fast=True)

traindata = load_dataset("json", data_files="/monster/data/model/dataset/c4-train.00000-of-01024.json.gz", split="train")
self.calibration_dataset = traindata.select(range(4096))
self.calibration_dataset = traindata.select(range(512))

# def test_load_group_128(self):
# model = GPTQModel.load(