diff --git a/gptqmodel/looper/awq_processor.py b/gptqmodel/looper/awq_processor.py
index e3fc542d7..8b39f45bc 100644
--- a/gptqmodel/looper/awq_processor.py
+++ b/gptqmodel/looper/awq_processor.py
@@ -180,7 +180,8 @@ def _capture_previous_subset_scale(self, previous_subset: Optional[Dict[str, Nam
     def _layer_input_features(self, state: _AWQLayerState) -> Dict[str, torch.Tensor]:
         features: Dict[str, torch.Tensor] = {}
         root_buckets: Dict[str, List[torch.Tensor]] = {}
-        for name in state.modules:
+        # Iterate over a snapshot since quantization may mutate state.modules concurrently
+        for name in list(state.modules):
             entry = self.tasks.get(name) or {}
             tensors: List[torch.Tensor] = entry.get("inputs", [])  # type: ignore[arg-type]
             if not tensors:
@@ -188,6 +189,7 @@ def _layer_input_features(self, state: _AWQLayerState) -> Dict[str, torch.Tensor
                 continue
             try:
                 features[name] = torch.cat(tensors, dim=0)
+                entry["inputs"] = [features[name]]
             except RuntimeError:
                 features[name] = tensors[0]
             root = name.split(".", 1)[0]
@@ -576,38 +578,57 @@ def _search_best_scale(
         inp = inp.to(next(module2inspect.parameters()).device)
 
         # [STEP 1]: Compute per-channel mean of normalised weights
-        # All layer weights are concatted together
-        weight = torch.cat([_m.weight for _m in layers], dim=0)
-        org_shape = weight.shape
-        # The weights are reshaped to be organised by quantization group
-        weight = weight.view(-1, self.qcfg.group_size)
-        # Calculates the relative magnitude of the weights within each of the quantization groups,
-        # and rescales each group individually so that each group has weights on a 0-1 scale.
-        w_scale = weight.abs() / (weight.abs().amax(dim=1, keepdim=True) + 1e-6)
-        # Resizes the rescaled weight matrix back up to its original dimensions
-        w_scale = w_scale.view(org_shape)
-        # Gets the average rescaled magnitude for each output channel
-        w_mean = w_scale.mean(0)
-        del weight
+        # Accumulate statistics per-layer to avoid concatenating large tensors
+        # (original implementation materialized a giant cat() that doubled VRAM usage)
+        first_weight = layers[0].weight
+        weight_dtype = first_weight.dtype
+        weight_device = first_weight.device
+        num_channels = first_weight.shape[1]
+        w_sum = torch.zeros(num_channels, dtype=torch.float32, device=weight_device)
+        row_count = 0
+
+        for layer in layers:
+            weight = layer.weight
+            if weight.shape[1] != num_channels:
+                raise ValueError(
+                    f"Expected consistent in_features across layers ({num_channels}), "
+                    f"got {weight.shape[1]} for layer {layer}."
+                )
+            org_shape = weight.shape
+            weight_abs = weight.abs()
+            weight_group = weight_abs.view(-1, self.qcfg.group_size)
+            group_scale = weight_group.amax(dim=1, keepdim=True) + 1e-6
+            normalized = weight_group / group_scale
+            normalized = normalized.view(org_shape)
+            w_sum += normalized.sum(dim=0, dtype=torch.float32)
+            row_count += org_shape[0]
+
+        if row_count == 0:
+            w_mean = torch.zeros(num_channels, dtype=weight_dtype, device=weight_device)
+        else:
+            w_mean = (w_sum / row_count).to(weight_dtype)
 
         # [STEP 2]: Compute per-channel mean of the input activation with chunking
-        # move inp to cpu to avoid memory leak
-        inp_flat = inp.cpu().abs().view(-1, inp.shape[-1])
+        # Stream directly on the source device to avoid creating full CPU copies
+        inp_flat = inp.abs().view(-1, inp.shape[-1])
         num_elements = inp_flat.size(0)
         num_channels = inp_flat.size(1)
-        element_size_bytes = inp_flat.element_size() * 2 # multiplied by 2 for FP32
+        float32_size = torch.tensor([], dtype=torch.float32).element_size()
+        element_size_bytes = float32_size  # accumulation happens in FP32
 
         # Calculate chunk size dynamically based on max_chunk_memory
         chunk_size = int(self.max_chunk_memory // (element_size_bytes * num_channels))
         chunk_size = min(chunk_size, num_elements)
+        chunk_size = max(chunk_size, 1)
 
         # Use float32 for sum calculation
         x_sum = torch.zeros(num_channels, dtype=torch.float32, device=inp.device)
 
         for i in range(0, num_elements, chunk_size):
             end = min(i + chunk_size, num_elements)
-            chunk_sum = inp_flat[i:end].to(torch.float32).sum(dim=0)
-            x_sum += chunk_sum.to(inp.device)
+            chunk = inp_flat[i:end]
+            chunk_sum = chunk.to(torch.float32).sum(dim=0)
+            x_sum += chunk_sum
 
         x_mean = (x_sum / num_elements).to(inp.dtype)
         del x_sum
@@ -683,6 +704,11 @@ def _compute_best_clip(
         assert org_w_shape[0] % oc_batch_size == 0
         w_all = w
         best_max_val_all = []
+        device = w_all.device
+        # Pre-allocate scratch buffers so the inner loop never allocates large temporaries
+        scratch_clamp = torch.empty_like(w_all[:oc_batch_size])
+        scratch_quant = torch.empty_like(scratch_clamp)
+        input_feat = input_feat.to(device)
 
         for i_b in range(org_w_shape[0] // oc_batch_size):
             w = w_all[i_b * oc_batch_size: (i_b + 1) * oc_batch_size]
@@ -691,20 +717,19 @@ def _compute_best_clip(
             best_max_val = org_max_val.clone()
             min_errs = torch.ones_like(org_max_val) * 1e9
 
-            input_feat = input_feat.to(w.device)
-            org_out = (input_feat * w).sum(dim=-1)  # co, n_token, n_group
+            clamp_slice = scratch_clamp[: w.shape[0]]
+            quant_slice = scratch_quant[: w.shape[0]]
+
+            org_out = (input_feat * w).sum(dim=-1)
 
             for i_s in range(int(max_shrink * n_grid)):
                 max_val = org_max_val * (1 - i_s / n_grid)
                 min_val = -max_val
-                cur_w = torch.clamp(w, min_val, max_val)
-                q_w = self.pseudo_quantize_tensor(cur_w)[0]
-                cur_out = (input_feat * q_w).sum(dim=-1)
+                torch.clamp(w, min_val, max_val, out=clamp_slice)
+                self._pseudo_quantize_tensor_into(clamp_slice, quant_slice)
+                cur_out = (input_feat * quant_slice).sum(dim=-1)
 
-                # co, 1, n_group, 1
                 err = (cur_out - org_out).pow(2).mean(dim=1).view(min_errs.shape)
-                del cur_w
-                del cur_out
                 cur_best_idx = err < min_errs
                 min_errs[cur_best_idx] = err[cur_best_idx]
                 best_max_val[cur_best_idx] = max_val[cur_best_idx]
@@ -753,6 +778,45 @@ def pseudo_quantize_tensor(self, w: torch.Tensor):
 
         return w, scales, zeros
 
+    @torch.inference_mode()
+    def _pseudo_quantize_tensor_into(self, src: torch.Tensor, dst: torch.Tensor) -> None:
+        # Quantize `src` into `dst` without allocating a new tensor (mirrors pseudo_quantize_tensor)
+        org_shape = src.shape
+        if self.qcfg.group_size > 0:
+            src_view = src.view(-1, self.qcfg.group_size)
+            dst_view = dst.view(-1, self.qcfg.group_size)
+        else:
+            src_view = src.reshape(org_shape[0], -1)
+            dst_view = dst.reshape_as(src_view)
+
+        if self.qcfg.zero_point:
+            max_val = src_view.amax(dim=1, keepdim=True)
+            min_val = src_view.amin(dim=1, keepdim=True)
+            max_int = 2 ** self.qcfg.bits - 1
+            min_int = 0
+            scales = (max_val - min_val).clamp(min=1e-5) / max_int
+            zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int)
+
+            dst_view.copy_(src_view)
+            dst_view.div_(scales)
+            torch.round(dst_view, out=dst_view)
+            dst_view.add_(zeros)
+            dst_view.clamp_(min_int, max_int)
+            dst_view.sub_(zeros)
+            dst_view.mul_(scales)
+        else:
+            max_val = src_view.abs().amax(dim=1, keepdim=True).clamp(min=1e-5)
+            max_int = 2 ** (self.qcfg.bits - 1) - 1
+            min_int = -(2 ** (self.qcfg.bits - 1))
+            scales = max_val / max_int
+
+            dst_view.copy_(src_view)
+            dst_view.div_(scales)
+            torch.round(dst_view, out=dst_view)
+            dst_view.clamp_(min_int, max_int)
+            dst_view.mul_(scales)
+
+
     def _compute_best_scale(
         self,
         x: torch.Tensor,
@@ -778,7 +842,12 @@ def _compute_best_scale(
         best_scales = None
         best_error = float("inf")
 
-        org_sd = {k: v.cpu() for k, v in module2inspect.state_dict().items()}
+        # Clone the original FP weights to CPU once so we can mutate/restore without load_state_dict overhead
+        orig_weights_cpu: Dict[nn.Linear, torch.Tensor] = {
+            # stash a contiguous FP32 master copy on CPU; avoids tying up GPU memory between ratios
+            fc: fc.weight.detach().to(torch.float32).cpu().contiguous()
+            for fc in linears2scale
+        }
 
         device = x.device
         x_mean = x_mean.view(-1).to(device)
@@ -807,9 +876,8 @@ def _compute_best_scale(
             # Q(W * s)
             for fc in linears2scale:
                 fc.weight.mul_(scales_view)
-                fc.weight.data = (
-                    self.pseudo_quantize_tensor(fc.weight.data)[0] / scales_view
-                )
+                self._pseudo_quantize_tensor_into(fc.weight, fc.weight)
+                fc.weight.div_(scales_view)
 
             # W * X
             int_w_output = self._module_forward(x, module2inspect, kwargs)
@@ -823,7 +891,12 @@ def _compute_best_scale(
                 best_error = loss
                 best_ratio = ratio
                 best_scales = scales.clone()
-            module2inspect.load_state_dict(org_sd)
+            for fc in linears2scale:
+                fc.weight.copy_(orig_weights_cpu[fc].to(device=fc.weight.device, dtype=fc.weight.dtype))
+
+        for fc in linears2scale:
+            fc.weight.copy_(orig_weights_cpu[fc].to(device=fc.weight.device, dtype=fc.weight.dtype))
+        orig_weights_cpu.clear()
 
         if best_ratio == -1:
             log.debug(history)
diff --git a/tests/test_awq.py b/tests/test_awq.py
index 38e6d46ea..6230a3de5 100644
--- a/tests/test_awq.py
+++ b/tests/test_awq.py
@@ -57,12 +57,13 @@ def setUpClass(cls):
         except Exception:
             total_mem_gb = 0
 
-        if total_mem_gb >= 80:
-            sample_count = 1024
-        elif total_mem_gb >= 48:
-            sample_count = 512
-        else:
-            sample_count = 192
+        # if total_mem_gb >= 80:
+        #     sample_count = 1024
+        # elif total_mem_gb >= 48:
+        #     sample_count = 512
+        # else:
+        #     sample_count = 192
+        sample_count = 512
 
         traindata = load_dataset("json", data_files="/monster/data/model/dataset/c4-train.00000-of-01024.json.gz", split="train")
 
diff --git a/tests/test_awq_clip_consistency.py b/tests/test_awq_clip_consistency.py
new file mode 100644
index 000000000..7b45abe0f
--- /dev/null
+++ b/tests/test_awq_clip_consistency.py
@@ -0,0 +1,84 @@
+import torch
+from parameterized import parameterized
+
+from gptqmodel.looper.awq_processor import AWQProcessor
+from gptqmodel.quantization.config import QuantizeConfig
+
+
+class _ClipTestAWQProcessor(AWQProcessor):
+    def __init__(self, qcfg: QuantizeConfig) -> None:
+        super().__init__(
+            tokenizer=None,
+            qcfg=qcfg,
+            calibration=None,
+            prepare_dataset_func=None,
+            calibration_concat_size=None,
+            calibration_sort=None,
+            batch_size=1,
+            gptq_model=None,
+            model=None,
+            require_fwd=True,
+            calculate_w_wq_diff=False,
+            calibration_concat_separator=None,
+        )
+
+    def _module_forward(self, x, module, module_kwargs):
+        return module(x)
+
+
+def _legacy_clip(processor: AWQProcessor, w: torch.Tensor, input_feat: torch.Tensor):
+    group_size = processor.qcfg.group_size if processor.qcfg.group_size > 0 else w.shape[1]
+    input_feat = input_feat.view(-1, input_feat.shape[-1])
+    input_feat = input_feat.reshape(1, input_feat.shape[0], -1, group_size)
+    step_size = max(1, input_feat.shape[1] // 512)
+    input_feat = input_feat[:, ::step_size]
+
+    w = w.reshape(w.shape[0], 1, -1, group_size)
+    oc_batch_size = 256 if w.shape[0] % 256 == 0 else 64
+    assert w.shape[0] % oc_batch_size == 0
+    best_max_val_all = []
+    for i_b in range(w.shape[0] // oc_batch_size):
+        w_chunk = w[i_b * oc_batch_size: (i_b + 1) * oc_batch_size]
+        org_max_val = w_chunk.abs().amax(dim=-1, keepdim=True)
+        best_max_val = org_max_val.clone()
+        min_errs = torch.ones_like(org_max_val) * 1e9
+        input_feat = input_feat.to(w_chunk.device)
+        org_out = (input_feat * w_chunk).sum(dim=-1)
+        for i_s in range(int(0.5 * 20)):
+            max_val = org_max_val * (1 - i_s / 20)
+            min_val = -max_val
+            cur_w = torch.clamp(w_chunk, min_val, max_val)
+            q_w = processor.pseudo_quantize_tensor(cur_w)[0]
+            cur_out = (input_feat * q_w).sum(dim=-1)
+            err = (cur_out - org_out).pow(2).mean(dim=1).view(min_errs.shape)
+            cur_best_idx = err < min_errs
+            min_errs[cur_best_idx] = err[cur_best_idx]
+            best_max_val[cur_best_idx] = max_val[cur_best_idx]
+        best_max_val_all.append(best_max_val)
+    return torch.cat(best_max_val_all, dim=0).squeeze(1)
+
+
+@parameterized.expand([
+    ("cpu", "cpu"),
+    ("cuda", "cuda:0"),
+])
+def test_awq_clip_consistency(device_name: str, device_str: str):
+    if device_name == "cuda" and not torch.cuda.is_available():
+        raise AssertionError("CUDA is not available for clip consistency test")
+
+    dtype = torch.float32 if device_name == "cpu" else torch.float16
+    processor = _ClipTestAWQProcessor(QuantizeConfig(group_size=128))
+
+    out_features = 256
+    in_features = 3584
+    w = torch.randn(out_features, in_features, dtype=dtype, device=device_str)
+    tokens = 1024
+    input_feat = torch.randn(tokens, in_features, dtype=dtype, device=device_str)
+
+    # Compare the streaming implementation against the legacy tensor-per-iter path
+    expected = _legacy_clip(processor, w.clone(), input_feat.clone())
+    actual = processor._compute_best_clip(w, input_feat)
+
+    tol = 1e-6 if dtype == torch.float32 else 1e-4
+    assert torch.allclose(actual.cpu(), expected.cpu(), atol=tol, rtol=tol), \
+        f"Inconsistent clip: max diff {(actual - expected).abs().max().item():.3e}"
diff --git a/tests/test_awq_moe.py b/tests/test_awq_moe.py
index 717e2fbe5..9f4371758 100644
--- a/tests/test_awq_moe.py
+++ b/tests/test_awq_moe.py
@@ -37,7 +37,7 @@ def setUpClass(self):
         self.tokenizer = AutoTokenizer.from_pretrained(self.pretrained_model_id, use_fast=True)
 
         traindata = load_dataset("json", data_files="/monster/data/model/dataset/c4-train.00000-of-01024.json.gz", split="train")
-        self.calibration_dataset = traindata.select(range(4096))
+        self.calibration_dataset = traindata.select(range(512))
 
     # def test_load_group_128(self):
    #     model = GPTQModel.load(
diff --git a/tests/test_awq_weight_mean.py b/tests/test_awq_weight_mean.py
new file mode 100644
index 000000000..c8b541750
--- /dev/null
+++ b/tests/test_awq_weight_mean.py
@@ -0,0 +1,335 @@
+import os
+
+os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
+os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True,max_split_size_mb:256,garbage_collection_threshold:0.7" #"expandable_segments:True"
+
+import time
+import torch
+import pytest
+
+from parameterized import parameterized
+from pytest import MonkeyPatch
+from torch import nn
+
+from gptqmodel.looper.awq_processor import AWQProcessor
+from gptqmodel.quantization.config import QuantizeConfig
+
+
+QWEN3_HIDDEN_SIZE = 3584
+
+
+def _compute_legacy_w_mean(layers, group_size):
+    weights = [layer.weight.detach().to(torch.float32).cpu() for layer in layers]
+    weight = torch.cat(weights, dim=0)
+    org_shape = weight.shape
+    weight = weight.view(-1, group_size)
+    w_scale = weight.abs() / (weight.abs().amax(dim=1, keepdim=True) + 1e-6)
+    w_scale = w_scale.view(org_shape)
+    return w_scale.mean(0)
+
+
+def _compute_fast_w_mean(layers, group_size):
+    first_weight = layers[0].weight
+    num_channels = first_weight.shape[1]
+    device = first_weight.device
+    dtype = first_weight.dtype
+    w_sum = torch.zeros(num_channels, dtype=torch.float32, device=device)
+    row_count = 0
+
+    for layer in layers:
+        weight = layer.weight
+        org_shape = weight.shape
+        weight_abs = weight.abs()
+        weight_group = weight_abs.view(-1, group_size)
+        group_scale = weight_group.amax(dim=1, keepdim=True) + 1e-6
+        normalized = (weight_group / group_scale).view(org_shape)
+        w_sum += normalized.sum(dim=0, dtype=torch.float32)
+        row_count += org_shape[0]
+
+    if row_count == 0:
+        return torch.zeros(num_channels, dtype=dtype, device=device)
+    return (w_sum / row_count).to(dtype)
+
+
+def _compute_fast_w_mean_multi(layer_groups, group_size):
+    total_sum = None
+    total_rows = 0
+    for layers in layer_groups:
+        first_weight = layers[0].weight
+        device = first_weight.device
+        num_channels = first_weight.shape[1]
+        w_sum = torch.zeros(num_channels, dtype=torch.float32, device=device)
+        rows = 0
+        for layer in layers:
+            weight = layer.weight
+            org_shape = weight.shape
+            weight_abs = weight.abs()
+            weight_group = weight_abs.view(-1, group_size)
+            group_scale = weight_group.amax(dim=1, keepdim=True) + 1e-6
+            normalized = (weight_group / group_scale).view(org_shape)
+            w_sum += normalized.sum(dim=0, dtype=torch.float32)
+            rows += org_shape[0]
+        if total_sum is None:
+            total_sum = w_sum.cpu()
+        else:
+            total_sum += w_sum.cpu()
+        total_rows += rows
+    return (total_sum / total_rows)
+
+
+class _DummyQwen3SelfAttention(nn.Module):
+    def __init__(self, hidden_size: int, device: str, dtype: torch.dtype) -> None:
+        super().__init__()
+        self.q_proj = nn.Linear(hidden_size, hidden_size, bias=False, device=device, dtype=dtype)
+        self.k_proj = nn.Linear(hidden_size, hidden_size, bias=False, device=device, dtype=dtype)
+        self.v_proj = nn.Linear(hidden_size, hidden_size, bias=False, device=device, dtype=dtype)
+
+
+class _TestAWQProcessor(AWQProcessor):
+    def __init__(self, qcfg: QuantizeConfig):
+        super().__init__(
+            tokenizer=None,
+            qcfg=qcfg,
+            calibration=None,
+            prepare_dataset_func=None,
+            calibration_concat_size=None,
+            calibration_sort=None,
+            batch_size=1,
+            gptq_model=None,
+            model=None,
+            require_fwd=True,
+            calculate_w_wq_diff=False,
+            calibration_concat_separator=None,
+        )
+
+    def _module_forward(self, x: torch.Tensor, module: torch.nn.Module, module_kwargs):
+        return module(x)
+
+
+@parameterized.expand([
+    ("cpu_gs32", "cpu", 32),
+    ("cpu_gs64", "cpu", 64),
+    ("cpu_gs128", "cpu", 128),
+    ("cuda4_gs32", "cuda:4", 32),
+    ("cuda4_gs64", "cuda:4", 64),
+    ("cuda4_gs128", "cuda:4", 128),
+    ("cuda4_cuda5_gs128", ("cuda:4", "cuda:5"), 128),
+])
+def test_awq_weight_mean_matches_legacy_impl(param_name, device, group_size):
+    if isinstance(device, (list, tuple)):
+        devices = list(device)
+        for dev in devices:
+            if not torch.cuda.is_available() or torch.device(dev).index >= torch.cuda.device_count():
+                pytest.skip(f"{dev} is not available")
+    elif isinstance(device, str) and device.startswith("cuda"):
+        if not torch.cuda.is_available():
+            pytest.skip("CUDA is not available for this test run.")
+
+    torch.manual_seed(0)
+    if isinstance(device, (list, tuple)):
+        dtype = torch.float16
+        layer_groups = []
+        for dev in device:
+            layer_groups.append([
+                nn.Linear(QWEN3_HIDDEN_SIZE, QWEN3_HIDDEN_SIZE, bias=False, device=dev, dtype=dtype)
+                for _ in range(3)
+            ])
+
+        baseline_layers = [layer for group in layer_groups for layer in group]
+        baseline = _compute_legacy_w_mean(baseline_layers, group_size)
+        fast = _compute_fast_w_mean_multi(layer_groups, group_size)
+        fast = fast.to(baseline.dtype)
+
+        # Accuracy table
+        abs_diff = (fast - baseline).abs()
+        with torch.no_grad():
+            safe_baseline = torch.where(baseline == 0, torch.ones_like(baseline), baseline)
+            rel_diff = abs_diff / safe_baseline.abs()
+        max_abs_diff = abs_diff.max().item()
+        max_rel_diff = rel_diff.max().item()
+
+        header = f"{'Metric':<20}{'Measured':<20}{'Tolerance':<20}"
+        separator = "-" * len(header)
+        print(f"AWQ weight mean comparison (fast vs baseline) [{param_name}]")
+        print(separator)
+        print(header)
+        print(separator)
+        atol = 5e-4
+        rtol = 1e-3
+        print(f"{'max_abs_diff':<20}{max_abs_diff:<20.6e}{atol:<20.6e}")
+        print(f"{'max_rel_diff':<20}{max_rel_diff:<20.6e}{rtol:<20.6e}")
+        print(separator)
+        assert torch.allclose(fast, baseline, rtol=rtol, atol=atol)
+
+        # Timing comparison
+        def _time_it(fn, runs=5, warmup=2):
+            for _ in range(warmup):
+                fn()
+            torch.cuda.synchronize(torch.device(device[0]).index)
+            start = time.perf_counter()
+            for _ in range(runs):
+                fn()
+            torch.cuda.synchronize(torch.device(device[0]).index)
+            return (time.perf_counter() - start) / runs
+
+        def fast_fn():
+            _ = _compute_fast_w_mean_multi(layer_groups, group_size)
+
+        def legacy_fn():
+            _ = _compute_legacy_w_mean(baseline_layers, group_size)
+
+        fast_time = _time_it(fast_fn)
+        legacy_time = _time_it(legacy_fn)
+
+        GREEN = "\033[32m"
+        RED = "\033[31m"
+        YELLOW = "\033[33m"
+        RESET = "\033[0m"
+
+        delta_ms = (fast_time - legacy_time) * 1e3
+        rel = (fast_time / legacy_time) if legacy_time > 0 else float("inf")
+        if rel <= 1.0:
+            color = GREEN
+            verdict = "faster"
+        elif rel <= 1.05:
+            color = YELLOW
+            verdict = "≈ parity"
+        else:
+            color = RED
+            verdict = "slower"
+
+        print(f"AWQ weight mean timing [{param_name}]")
+        print("+----------------+------------+--------------+-------------+---------------+")
+        print("| Metric | Fast (ms) | Legacy (ms) | Delta (ms) | Relative |")
+        print("+----------------+------------+--------------+-------------+---------------+")
+        print(
+            f"| runtime | {fast_time*1e3:10.3f} | {legacy_time*1e3:12.3f} | "
+            f"{delta_ms:11.3f} | {color}{rel:>11.3%} {verdict:<7}{RESET}|"
+        )
+        print("+----------------+------------+--------------+-------------+---------------+")
+
+        assert fast_time <= legacy_time * 1.05, (
+            f"Streaming mean slower than legacy for {param_name}: "
+            f"{fast_time*1e3:.3f} ms vs {legacy_time*1e3:.3f} ms"
+        )
+        return
+
+    device_str = device
+    dtype = torch.float16 if device_str.startswith("cuda") else torch.float32
+
+    attn = _DummyQwen3SelfAttention(QWEN3_HIDDEN_SIZE, device_str, dtype)
+    layers = [attn.q_proj, attn.k_proj, attn.v_proj]
+
+    batch_size = 4
+    inp = torch.randn(batch_size, QWEN3_HIDDEN_SIZE, device=device_str, dtype=dtype)
+
+    processor = _TestAWQProcessor(QuantizeConfig(group_size=group_size))
+
+    captured = {}
+
+    def fake_compute_best_scale(
+        self,
+        _inp,
+        w_mean,
+        x_mean,
+        module2inspect,
+        layers_arg,
+        fp16_output,
+        module_kwargs,
+    ):
+        captured["fast"] = w_mean.detach().to(torch.float32).cpu()
+        captured["baseline"] = (
+            _compute_legacy_w_mean(layers_arg, self.qcfg.group_size).detach().to(torch.float32).cpu()
+        )
+        return torch.ones_like(w_mean, dtype=w_mean.dtype).detach().cpu(), 0.0
+
+    monkey_patcher = MonkeyPatch()
+    monkey_patcher.setattr(AWQProcessor, "_compute_best_scale", fake_compute_best_scale)
+
+    try:
+        processor._search_best_scale(
+            attn,
+            layers[0],
+            layers,
+            inp,
+            module2inspect=layers[0],
+            kwargs={},
+        )
+    finally:
+        monkey_patcher.undo()
+
+    assert "fast" in captured and "baseline" in captured
+    if dtype == torch.float32:
+        atol = 2e-7
+        rtol = 2e-7
+    else:
+        atol = 5e-4
+        rtol = 1e-3
+    fast = captured["fast"]
+    baseline = captured["baseline"]
+
+    abs_diff = (fast - baseline).abs()
+    with torch.no_grad():
+        safe_baseline = torch.where(baseline == 0, torch.ones_like(baseline), baseline)
+        rel_diff = abs_diff / safe_baseline.abs()
+
+    max_abs_diff = abs_diff.max().item()
+    max_rel_diff = rel_diff.max().item()
+
+    header = f"{'Metric':<20}{'Measured':<20}{'Tolerance':<20}"
+    separator = "-" * len(header)
+    print(f"AWQ weight mean comparison (fast vs baseline) [{param_name}]")
+    print(separator)
+    print(header)
+    print(separator)
+    print(f"{'max_abs_diff':<20}{max_abs_diff:<20.6e}{atol:<20.6e}")
+    print(f"{'max_rel_diff':<20}{max_rel_diff:<20.6e}{rtol:<20.6e}")
+    print(separator)
+
+    assert torch.allclose(fast, baseline, rtol=rtol, atol=atol)
+
+    def _time_it(fn, runs=5, warmup=2):
+        for _ in range(warmup):
+            fn()
+        if device == "cuda":
+            torch.cuda.synchronize(device_str)
+        start = time.perf_counter()
+        for _ in range(runs):
+            fn()
+        if device == "cuda":
+            torch.cuda.synchronize(device_str)
+        return (time.perf_counter() - start) / runs
+
+    fast_time = _time_it(lambda: _compute_fast_w_mean(layers, group_size))
+    legacy_time = _time_it(lambda: _compute_legacy_w_mean(layers, group_size))
+
+    GREEN = "\033[32m"
+    RED = "\033[31m"
+    YELLOW = "\033[33m"
+    RESET = "\033[0m"
+
+    delta_ms = (fast_time - legacy_time) * 1e3
+    rel = (fast_time / legacy_time) if legacy_time > 0 else float("inf")
+    if rel <= 1.0:
+        color = GREEN
+        verdict = "faster"
+    elif rel <= 1.05:
+        color = YELLOW
+        verdict = "≈ parity"
+    else:
+        color = RED
+        verdict = "slower"
+
+    print(f"AWQ weight mean timing [{param_name}]")
+    print("+----------------+------------+--------------+-------------+---------------+")
+    print("| Metric | Fast (ms) | Legacy (ms) | Delta (ms) | Relative |")
+    print("+----------------+------------+--------------+-------------+---------------+")
+    print(
+        f"| runtime | {fast_time*1e3:10.3f} | {legacy_time*1e3:12.3f} | "
+        f"{delta_ms:11.3f} | {color}{rel:>11.3%} {verdict:<7}{RESET}|"
+    )
+    print("+----------------+------------+--------------+-------------+---------------+")
+
+    assert fast_time <= legacy_time * 1.05, (
+        f"Streaming mean slower than legacy for {param_name}: "
+        f"{fast_time*1e3:.3f} ms vs {legacy_time*1e3:.3f} ms"
+    )