From f18e6ea4ad2abf0e1333b738bc3903c25000166a Mon Sep 17 00:00:00 2001 From: Qubitium Date: Mon, 27 Oct 2025 13:01:38 +0000 Subject: [PATCH 01/10] f4_e4m1 --- gptqmodel/quantization/dtype.py | 90 +++++++++++++++++++++++++++++++++ tests/test_fp4_llama3_fp4.py | 68 +++++++++++++++++++++++++ tests/test_quant_dtype.py | 59 +++++++++++++++++++++ 3 files changed, 217 insertions(+) create mode 100644 tests/test_fp4_llama3_fp4.py diff --git a/gptqmodel/quantization/dtype.py b/gptqmodel/quantization/dtype.py index 05430e849..195f4490f 100644 --- a/gptqmodel/quantization/dtype.py +++ b/gptqmodel/quantization/dtype.py @@ -11,9 +11,16 @@ import torch +try: + from torchao.prototype.mx_formats.kernels import unpack_uint4, f4_unpacked_to_f32 +except Exception: + unpack_uint4 = None + f4_unpacked_to_f32 = None + __all__ = [ "device_supports_native_fp8", "dequantize_f8_e4m3", + "dequantize_f4_e2m1", ] @@ -162,3 +169,86 @@ def _expand_scale(scale_tensor: torch.Tensor, *, axis_hint: Optional[int]) -> to result = result / scale_tensor return result + + +def dequantize_f4_e2m1( + tensor: torch.Tensor, + *, + scale: Optional[torch.Tensor] = None, + scale_inv: Optional[torch.Tensor] = None, + axis: Optional[int] = 0, + target_dtype: torch.dtype = torch.bfloat16, +) -> torch.Tensor: + """Dequantize FP4 (E2M1) values packed as two nibbles per byte.""" + + if unpack_uint4 is None or f4_unpacked_to_f32 is None: + raise RuntimeError("torchao with nvfp4 support is required for FP4 dequantization") + + if scale is not None and scale_inv is not None: + raise ValueError("Provide either scale or scale_inv, not both") + + if tensor.dtype is not torch.uint8: + raise ValueError("FP4 packed tensors must use torch.uint8 storage") + + orig_shape = list(tensor.shape) + if not orig_shape: + raise ValueError("Tensor must have at least one dimension") + + unpacked = unpack_uint4(tensor.reshape(-1)) + expanded_shape = orig_shape[:-1] + [orig_shape[-1] * 2] + unpacked = unpacked.view(*expanded_shape) + + result = f4_unpacked_to_f32(unpacked).to(target_dtype) + + def _expand_scale_fp4(scale_tensor: torch.Tensor, *, axis_hint: Optional[int]) -> torch.Tensor: + if scale_tensor.ndim == 0: + return scale_tensor + + target_shape = result.shape + + if scale_tensor.shape == target_shape: + return scale_tensor + + if scale_tensor.ndim == 2 and len(target_shape) == 2: + blocks_r, blocks_c = scale_tensor.shape + rows, cols = target_shape + if rows % blocks_r == 0 and cols % blocks_c == 0: + repeat_r = rows // blocks_r + repeat_c = cols // blocks_c + expanded = scale_tensor.repeat_interleave(repeat_r, dim=0) + expanded = expanded.repeat_interleave(repeat_c, dim=1) + return expanded + + if scale_tensor.ndim == result.ndim: + expanded = scale_tensor + for dim, (target_size, current_size) in enumerate(zip(result.shape, expanded.shape)): + if target_size == current_size: + continue + if current_size == 1: + expanded = expanded.expand(*[ + target_size if i == dim else expanded.shape[i] + for i in range(expanded.ndim) + ]) + continue + if target_size % current_size != 0: + raise ValueError( + f"Cannot broadcast scale dimension {current_size} to target {target_size}" + ) + repeat = target_size // current_size + expanded = expanded.repeat_interleave(repeat, dim=dim) + return expanded + + reshaped = _reshape_for_axis(scale_tensor, axis_hint, result.ndim) + return reshaped.expand(result.shape) + + if scale is not None: + scale_tensor = _expand_scale_fp4(scale.to(result.dtype), axis_hint=axis) + result = result * scale_tensor + elif scale_inv is not None: 
+ scale_tensor = _expand_scale_fp4(scale_inv.to(result.dtype), axis_hint=axis) + if torch.max(torch.abs(scale_tensor)) <= 1: + result = result * scale_tensor + else: + result = result / scale_tensor + + return result diff --git a/tests/test_fp4_llama3_fp4.py b/tests/test_fp4_llama3_fp4.py new file mode 100644 index 000000000..1131f2ee9 --- /dev/null +++ b/tests/test_fp4_llama3_fp4.py @@ -0,0 +1,68 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +import json +from pathlib import Path + +import pytest +import torch +from safetensors import safe_open + +from gptqmodel.quantization.dtype import dequantize_f4_e2m1 + +try: + from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor +except Exception: + NVFP4Tensor = None + + +MODEL_DIR = Path("/monster/data/model/Llama-3.3-70B-Instruct-FP4") + + +@pytest.mark.skipif(NVFP4Tensor is None, reason="torchao NVFP4 support required") +@pytest.mark.skipif(not MODEL_DIR.exists(), reason="Llama-3.3 FP4 model not available") +def test_fp4_llama3_module_dequant_matches_nvfp4_tensor(): + index = json.loads((MODEL_DIR / "model.safetensors.index.json").read_text()) + shard = sorted(set(index["weight_map"].values()))[0] + + with safe_open(MODEL_DIR / shard, framework="pt", device="cpu") as f: + weight = f.get_tensor("model.layers.0.mlp.down_proj.weight") + scales = f.get_tensor("model.layers.0.mlp.down_proj.weight_scale") + + dequant = dequantize_f4_e2m1(weight, scale=scales, axis=None, target_dtype=torch.bfloat16) + + nv_tensor = NVFP4Tensor(weight, scales, block_size=16, orig_dtype=torch.bfloat16) + expected = nv_tensor.to_dtype(torch.bfloat16) + + diff = torch.max(torch.abs(dequant - expected)).item() + assert torch.allclose(dequant, expected, atol=1e-3, rtol=1e-3), diff + + +@pytest.mark.skipif(NVFP4Tensor is None, reason="torchao NVFP4 support required") +@pytest.mark.skipif(not MODEL_DIR.exists(), reason="Llama-3.3 FP4 model not available") +@pytest.mark.parametrize("device", ["cuda:7", "cuda:8"], ids=["A100", "RTX5090"]) +def test_fp4_llama3_module_gpu_consistency(device): + if not torch.cuda.is_available(): + pytest.skip("CUDA not available") + + dev = torch.device(device) + if dev.index is not None and dev.index >= torch.cuda.device_count(): + pytest.skip(f"CUDA device {device} not accessible") + + index = json.loads((MODEL_DIR / "model.safetensors.index.json").read_text()) + shard = sorted(set(index["weight_map"].values()))[0] + + with safe_open(MODEL_DIR / shard, framework="pt", device="cpu") as f: + weight = f.get_tensor("model.layers.0.mlp.down_proj.weight") + scales = f.get_tensor("model.layers.0.mlp.down_proj.weight_scale") + + cpu = dequantize_f4_e2m1(weight, scale=scales, axis=None, target_dtype=torch.bfloat16) + + torch.cuda.set_device(dev) + gpu_weight = weight.to(dev) + gpu_scales = scales.to(dev) + gpu = dequantize_f4_e2m1(gpu_weight, scale=gpu_scales, axis=None, target_dtype=torch.bfloat16) + + assert torch.allclose(cpu, gpu.cpu(), atol=1e-3, rtol=1e-3) diff --git a/tests/test_quant_dtype.py b/tests/test_quant_dtype.py index ba512ae9b..23c7e8ccf 100644 --- a/tests/test_quant_dtype.py +++ b/tests/test_quant_dtype.py @@ -7,6 +7,7 @@ from gptqmodel.quantization.dtype import ( dequantize_f8_e4m3, + dequantize_f4_e2m1, device_supports_native_fp8, ) @@ -16,6 +17,13 @@ def _print_accuracy(title: str, rows, headers) -> None: print(f"\n{title}\n{table}\n") +try: # pragma: no cover - 
optional dependency + from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor, nvfp4_quantize +except Exception: # pragma: no cover + NVFP4Tensor = None + nvfp4_quantize = None + + @pytest.mark.skipif(not hasattr(torch, "float8_e4m3fn"), reason="float8 dtype not available") def test_dequantize_f8_e4m3_basic_conversion(): values = torch.linspace(-1, 1, steps=8, dtype=torch.float32) @@ -106,6 +114,57 @@ def test_device_supports_native_fp8_reports_capability(monkeypatch): assert device_supports_native_fp8(torch.device("cuda", 0)) is False +@pytest.mark.skipif(NVFP4Tensor is None, reason="torchao NVFP4 support required") +def test_dequantize_f4_e2m1_matches_nvfp4tensor(): + torch.manual_seed(0) + data = torch.randn(128, 256, dtype=torch.float32) + scales, packed = nvfp4_quantize(data, block_size=16) + + dequant = dequantize_f4_e2m1(packed, scale=scales, axis=None, target_dtype=torch.bfloat16) + nv_tensor = NVFP4Tensor(packed, scales, block_size=16, orig_dtype=torch.bfloat16) + expected = nv_tensor.to_dtype(torch.bfloat16) + + diff = torch.max(torch.abs(dequant - expected)).item() + _print_accuracy( + "dequantize_f4_e2m1_matches_nvfp4tensor", + [ + ["NVFP4Tensor", str(expected.dtype), float(expected.abs().max().item()), 0.0], + ["dtype_impl", str(dequant.dtype), float(dequant.abs().max().item()), diff], + ], + ["variant", "dtype", "max|value|", "max|diff vs baseline|"], + ) + assert torch.allclose(dequant, expected, atol=1e-3, rtol=1e-3) + + +@pytest.mark.skipif(NVFP4Tensor is None, reason="torchao NVFP4 support required") +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA device required") +def test_dequantize_f4_e2m1_cpu_vs_gpu(): + torch.manual_seed(1) + data = torch.randn(128, 512, dtype=torch.float32) + scales, packed = nvfp4_quantize(data, block_size=16) + + cpu = dequantize_f4_e2m1(packed, scale=scales, axis=None, target_dtype=torch.bfloat16) + + packed_gpu = packed.cuda() + scales_gpu = scales.cuda() + + start = time.perf_counter() + gpu = dequantize_f4_e2m1(packed_gpu, scale=scales_gpu, axis=None, target_dtype=torch.bfloat16) + torch.cuda.synchronize() + gpu_time = time.perf_counter() - start + + diff = torch.max(torch.abs(cpu - gpu.cpu())).item() + _print_accuracy( + "dequantize_f4_e2m1_cpu_vs_gpu", + [ + ["CPU", 0.0, float(cpu.abs().max().item()), 0.0], + ["GPU", gpu_time, float(gpu.abs().max().item()), diff], + ], + ["variant", "time (s)", "max|value|", "max|diff vs baseline|"], + ) + assert torch.allclose(cpu, gpu.cpu(), atol=1e-3, rtol=1e-3) + + @pytest.mark.skipif( not torch.cuda.is_available(), reason="CUDA device required for GPU benchmark", From bafb8cf08a255d6512e071093714ffdac56ee98b Mon Sep 17 00:00:00 2001 From: Qubitium Date: Mon, 27 Oct 2025 13:01:50 +0000 Subject: [PATCH 02/10] ruff --- gptqmodel/quantization/dtype.py | 3 ++- tests/test_fp4_llama3_fp4.py | 1 + tests/test_fp8_minimax2_test.py | 2 -- tests/test_quant_dtype.py | 5 ++--- 4 files changed, 5 insertions(+), 6 deletions(-) diff --git a/gptqmodel/quantization/dtype.py b/gptqmodel/quantization/dtype.py index 195f4490f..f24fe5818 100644 --- a/gptqmodel/quantization/dtype.py +++ b/gptqmodel/quantization/dtype.py @@ -11,8 +11,9 @@ import torch + try: - from torchao.prototype.mx_formats.kernels import unpack_uint4, f4_unpacked_to_f32 + from torchao.prototype.mx_formats.kernels import f4_unpacked_to_f32, unpack_uint4 except Exception: unpack_uint4 = None f4_unpacked_to_f32 = None diff --git a/tests/test_fp4_llama3_fp4.py b/tests/test_fp4_llama3_fp4.py index 1131f2ee9..38cfd1e3b 100644 --- 
a/tests/test_fp4_llama3_fp4.py +++ b/tests/test_fp4_llama3_fp4.py @@ -12,6 +12,7 @@ from gptqmodel.quantization.dtype import dequantize_f4_e2m1 + try: from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4Tensor except Exception: diff --git a/tests/test_fp8_minimax2_test.py b/tests/test_fp8_minimax2_test.py index 768510817..ceb8f5ca9 100644 --- a/tests/test_fp8_minimax2_test.py +++ b/tests/test_fp8_minimax2_test.py @@ -23,7 +23,6 @@ def test_fp8_weight_dequant_matches_scaled_matmul(): weight_block_size = config.get("quantization_config", {}).get("weight_block_size", [128, 128]) block_rows, block_cols = weight_block_size - weight_name = None scale_inv_name = None weight_tensor = None scale_inv_tensor = None @@ -42,7 +41,6 @@ def test_fp8_weight_dequant_matches_scaled_matmul(): candidate = key + "_scale_inv" if candidate not in f.keys(): continue - weight_name = key scale_inv_name = candidate weight_tensor = tensor scale_inv_tensor = f.get_tensor(scale_inv_name) diff --git a/tests/test_quant_dtype.py b/tests/test_quant_dtype.py index 23c7e8ccf..c21f3564d 100644 --- a/tests/test_quant_dtype.py +++ b/tests/test_quant_dtype.py @@ -1,13 +1,12 @@ import time -import torch - import pytest +import torch from tabulate import tabulate from gptqmodel.quantization.dtype import ( - dequantize_f8_e4m3, dequantize_f4_e2m1, + dequantize_f8_e4m3, device_supports_native_fp8, ) From 638d7c3f928971a2612886ef8b78085de1e0aa5d Mon Sep 17 00:00:00 2001 From: Qubitium Date: Mon, 27 Oct 2025 14:39:29 +0000 Subject: [PATCH 03/10] update --- gptqmodel/utils/model_dequant.py | 452 ++++++++++++++++++++++--------- scripts/dequantize_fp8_model.py | 76 ------ scripts/dequantize_model.py | 113 ++++++++ tests/test_model_dequant.py | 147 ++++++++++ 4 files changed, 579 insertions(+), 209 deletions(-) delete mode 100755 scripts/dequantize_fp8_model.py create mode 100755 scripts/dequantize_model.py create mode 100644 tests/test_model_dequant.py diff --git a/gptqmodel/utils/model_dequant.py b/gptqmodel/utils/model_dequant.py index 1d5c93400..81fa68738 100644 --- a/gptqmodel/utils/model_dequant.py +++ b/gptqmodel/utils/model_dequant.py @@ -3,23 +3,59 @@ # SPDX-License-Identifier: Apache-2.0 # Contact: qubitium@modelcloud.ai, x.com/qubitium -"""Utilities for converting FP8-quantized models to higher precision.""" +"""Safetensor-level dequantization helpers for common quant formats.""" from __future__ import annotations import json import shutil +from collections import defaultdict from pathlib import Path -from typing import Optional +from typing import Dict, Iterable, Optional, Tuple import torch from safetensors import safe_open from safetensors.torch import save_file -from ..quantization.dtype import dequantize_f8_e4m3 +from ..quantization.dtype import dequantize_f4_e2m1, dequantize_f8_e4m3 from ..utils.logger import setup_logger +def _load_json(path: Path) -> dict: + if not path.exists(): + return {} + with path.open("r", encoding="utf-8") as fh: + return json.load(fh) + + +def _write_json(path: Path, payload: dict) -> None: + with path.open("w", encoding="utf-8") as fh: + json.dump(payload, fh, indent=2) + + +def _list_safetensor_files(model_path: Path) -> Tuple[list, Optional[dict]]: + index_path = model_path / "model.safetensors.index.json" + if index_path.exists(): + index = _load_json(index_path) + files = sorted(set(index.get("weight_map", {}).values())) + return files, index + + files = sorted([p.name for p in model_path.glob("*.safetensors")]) + return files, None + + +def _finalize_for_save(tensor: torch.Tensor, 
target_dtype: torch.dtype) -> torch.Tensor: + """Cast to ``target_dtype`` when floating point and move to CPU with optimal layout.""" + + if torch.is_floating_point(tensor): + tensor = tensor.to(target_dtype) + + tensor_cpu = tensor.to("cpu") + if tensor_cpu.ndim >= 4: + tensor_cpu = tensor_cpu.contiguous(memory_format=torch.channels_last) + return tensor_cpu + + def _normalize_device(device: Optional[str]) -> Optional[str]: if device is None: return None @@ -36,173 +72,323 @@ def _normalize_device(device: Optional[str]) -> Optional[str]: return f"cuda:{dev.index}" -def _load_json(path: Path) -> dict: - if not path.exists(): - return {} - return json.loads(path.read_text()) - - -def _is_fp8_format(config: dict) -> bool: - quant_cfg = config.get("quantization_config", {}) - fmt = quant_cfg.get("fmt") - return fmt == "float8_e4m3fn" - - def _resolve_block_size(config: dict) -> tuple[int, int]: - quant_cfg = config.get("quantization_config", {}) + quant_cfg = config.get("quantization_config", {}) or {} block_size = quant_cfg.get("weight_block_size") if isinstance(block_size, (list, tuple)) and len(block_size) == 2: return int(block_size[0]), int(block_size[1]) return (128, 128) -def _dequantize_shard( - shard_path: Path, - output_path: Path, - *, +def _detect_format(model_path: Path, config: dict) -> str: + quant_cfg = config.get("quantization_config", {}) or {} + method = (quant_cfg.get("quant_method") or "").lower() + fmt = (quant_cfg.get("fmt") or "").lower() + + files, _ = _list_safetensor_files(model_path) + if not files: + raise FileNotFoundError("No .safetensors files found in model directory") + + with safe_open(model_path / files[0], framework="pt", device="cpu") as reader: + keys = list(reader.keys()) + # Prefer dtype-based detection + for key in keys: + if key.endswith(".weight"): + tensor = reader.get_tensor(key) + if tensor.dtype == torch.float8_e4m3fn: + return "fp8" + if tensor.dtype == torch.uint8 and (key + "_scale") in keys: + return "nvfp4" + if any(k.endswith(".weight_scale") for k in keys): + return "nvfp4" + if any(k.endswith(".weight_scale_inv") for k in keys): + return "fp8" + if any(k.endswith(".qweight") for k in keys): + has_g = any(k.endswith(".g_idx") for k in keys) + return "gptq" if has_g else "awq" + + if fmt == "float8_e4m3fn": + return "fp8" + if method in ("gptq", "gptqmodel"): + return "gptq" + if method == "awq": + return "awq" + + raise ValueError("Unable to detect quantization format for model") + + +def _unpack_cols(packed: torch.Tensor, bits: int) -> torch.Tensor: + pack_bits = packed.element_size() * 8 + pack_factor = pack_bits // bits + mask = (1 << bits) - 1 + packed_uint = packed.to(torch.int64) & ((1 << pack_bits) - 1) + rows, cols = packed.shape + result = torch.empty(rows, cols * pack_factor, dtype=torch.int32) + for i in range(pack_factor): + result[:, i::pack_factor] = ((packed_uint >> (i * bits)) & mask).to(torch.int32) + return result + + +def _unpack_rows(packed: torch.Tensor, bits: int) -> torch.Tensor: + pack_bits = packed.element_size() * 8 + pack_factor = pack_bits // bits + mask = (1 << bits) - 1 + packed_uint = packed.to(torch.int64) & ((1 << pack_bits) - 1) + rows, cols = packed.shape + result = torch.empty(rows * pack_factor, cols, dtype=torch.int32) + for i in range(pack_factor): + result[i::pack_factor, :] = ((packed_uint >> (i * bits)) & mask).to(torch.int32) + return result + + +def _convert_fp8_shard( + reader, target_dtype: torch.dtype, + *, block_shape: tuple[int, int], - device: Optional[str], -) -> None: - tensors = {} - 
open_device = device or "cpu" - - with safe_open(shard_path, framework="pt", device=open_device) as reader: - for name in reader.keys(): - tensor = reader.get_tensor(name) - - if tensor.dtype is torch.float8_e4m3fn: - scale_inv_name = name + "_scale_inv" - if scale_inv_name not in reader.keys(): - tensors[name] = tensor.to(target_dtype).to("cpu") - continue - - rows, cols = tensor.shape - block_rows, block_cols = block_shape - if rows % block_rows != 0 or cols % block_cols != 0: - raise ValueError( - f"Tensor {name} shape {tensor.shape} incompatible with block size {block_shape}" - ) - - scale_inv = reader.get_tensor(scale_inv_name) - deq = dequantize_f8_e4m3( - tensor, - scale_inv=scale_inv, - axis=None, - target_dtype=target_dtype, +) -> Dict[str, torch.Tensor]: + tensors: Dict[str, torch.Tensor] = {} + block_rows, block_cols = block_shape + + for key in reader.keys(): + tensor = reader.get_tensor(key) + if key.endswith(".weight") and tensor.dtype == torch.float8_e4m3fn: + scale_key = key + "_scale_inv" + if scale_key not in reader.keys(): + raise KeyError(f"Missing scale inverse tensor for {key}") + scale_inv = reader.get_tensor(scale_key) + + rows, cols = tensor.shape + if rows % block_rows != 0 or cols % block_cols != 0: + raise ValueError( + f"Tensor {key} shape {tensor.shape} incompatible with block size {block_shape}" ) - if deq.ndimension() >= 4: - tensors[name] = deq.to("cpu", memory_format=torch.channels_last) - else: - tensors[name] = deq.to("cpu") - elif tensor.dtype is torch.uint8 and name.endswith(".weight"): - converted = tensor.to(target_dtype) - if converted.ndimension() >= 4: - tensors[name] = converted.to("cpu", memory_format=torch.channels_last) - else: - tensors[name] = converted.to("cpu") + + deq = dequantize_f8_e4m3( + tensor, + scale_inv=scale_inv, + axis=None, + target_dtype=target_dtype, + ) + tensors[key] = _finalize_for_save(deq, target_dtype) + elif key.endswith("_scale_inv"): + continue + else: + tensors[key] = _finalize_for_save(tensor, target_dtype) + return tensors + + +def _convert_nvfp4_shard(reader, target_dtype: torch.dtype) -> Dict[str, torch.Tensor]: + tensors: Dict[str, torch.Tensor] = {} + for key in reader.keys(): + tensor = reader.get_tensor(key) + if key.endswith(".weight") and tensor.dtype == torch.uint8: + scale_key = key + "_scale" + if scale_key not in reader.keys(): + raise KeyError(f"Missing scale tensor for {key}") + scale = reader.get_tensor(scale_key) + deq = dequantize_f4_e2m1( + tensor, + scale=scale, + axis=None, + target_dtype=target_dtype, + ) + tensors[key] = _finalize_for_save(deq, target_dtype) + elif key.endswith("_weight_scale"): + continue + else: + tensors[key] = _finalize_for_save(tensor, target_dtype) + return tensors + + +def _convert_awq_file(path: Path, target_dtype: torch.dtype, device: str) -> Dict[str, torch.Tensor]: + tensors: Dict[str, torch.Tensor] = {} + module_buffers: Dict[str, Dict[str, torch.Tensor]] = defaultdict(dict) + with safe_open(path, framework="pt", device=device) as reader: + for key in reader.keys(): + tensor = reader.get_tensor(key) + if key.endswith(".qweight"): + prefix = key[:-len(".qweight")] + module_buffers[prefix]["qweight"] = tensor + elif key.endswith(".qzeros"): + prefix = key[:-len(".qzeros")] + module_buffers[prefix]["qzeros"] = tensor + elif key.endswith(".scales"): + prefix = key[:-len(".scales")] + module_buffers[prefix]["scales"] = tensor + else: + tensors[key] = _finalize_for_save(tensor, target_dtype) + + for prefix, buf in module_buffers.items(): + missing = {k for k in 
("qweight", "qzeros", "scales") if k not in buf} + if missing: + raise KeyError(f"Incomplete AWQ buffers for module {prefix}: missing {missing}") + + qweight = buf["qweight"] + qzeros = buf["qzeros"] + scales = buf["scales"] + bits = 4 + + unpacked_weight = _unpack_cols(qweight, bits).to(torch.float32) + unpacked_zeros = _unpack_cols(qzeros, bits).to(torch.float32) + + num_groups = scales.shape[0] + group_size = unpacked_weight.shape[0] // num_groups + scales_full = scales.to(torch.float32).repeat_interleave(group_size, dim=0) + zeros_full = unpacked_zeros.repeat_interleave(group_size, dim=0) + + weight = (unpacked_weight - zeros_full) * scales_full + tensors[prefix + ".weight"] = _finalize_for_save( + weight.to(target_dtype).t().contiguous(), + target_dtype, + ) + if prefix + ".bias" in tensors: + tensors[prefix + ".bias"] = _finalize_for_save(tensors[prefix + ".bias"], target_dtype) + + return tensors + + +def _convert_gptq_file(path: Path, target_dtype: torch.dtype, config: dict, device: str) -> Dict[str, torch.Tensor]: + tensors: Dict[str, torch.Tensor] = {} + module_buffers: Dict[str, Dict[str, torch.Tensor]] = defaultdict(dict) + with safe_open(path, framework="pt", device=device) as reader: + for key in reader.keys(): + tensor = reader.get_tensor(key) + if key.endswith(".qweight"): + prefix = key[:-len(".qweight")] + module_buffers[prefix]["qweight"] = tensor + elif key.endswith(".qzeros"): + prefix = key[:-len(".qzeros")] + module_buffers[prefix]["qzeros"] = tensor + elif key.endswith(".scales"): + prefix = key[:-len(".scales")] + module_buffers[prefix]["scales"] = tensor + elif key.endswith(".g_idx"): + prefix = key[:-len(".g_idx")] + module_buffers[prefix]["g_idx"] = tensor else: - if tensor.ndimension() >= 4 and (device is None or tensor.device.type == "cpu"): - tensors[name] = tensor.to("cpu", memory_format=torch.channels_last) - else: - tensors[name] = tensor.to("cpu") + tensors[key] = _finalize_for_save(tensor, target_dtype) + + for prefix, buf in module_buffers.items(): + missing = {k for k in ("qweight", "qzeros", "scales", "g_idx") if k not in buf} + if missing: + raise KeyError(f"Incomplete GPTQ buffers for module {prefix}: missing {missing}") + + qweight = buf["qweight"] + qzeros = buf["qzeros"] + scales = buf["scales"] + g_idx = buf["g_idx"].to(torch.long) + + bits = config.get("bits", 4) + weight_int = _unpack_rows(qweight, bits) + zeros = _unpack_cols(qzeros, bits) + + scales_full = scales.to(torch.float32)[g_idx] + zeros_full = zeros.to(torch.float32)[g_idx] + weight = (weight_int.to(torch.float32) - zeros_full) * scales_full + tensors[prefix + ".weight"] = _finalize_for_save( + weight.to(target_dtype).t().contiguous(), + target_dtype, + ) + if prefix + ".bias" in tensors: + tensors[prefix + ".bias"] = _finalize_for_save(tensors[prefix + ".bias"], target_dtype) + + return tensors - save_file(tensors, str(output_path)) + +def _copy_aux_files(model_path: Path, output_path: Path, skip: Iterable[str]) -> None: + for item in model_path.iterdir(): + if item.name in skip: + continue + target = output_path / item.name + if item.is_dir(): + shutil.copytree(item, target) + else: + shutil.copy2(item, target) -def dequantize( - model_path: str | Path, - model_output_path: str | Path, +def dequantize_model( + model_path: Path | str, + output_path: Path | str, *, target_dtype: torch.dtype = torch.bfloat16, device: Optional[str] = None, ) -> None: - """Dequantize an FP8 E4M3 model into the requested ``target_dtype``. 
- - Parameters - ---------- - model_path: - Directory containing ``model.safetensors`` shards and ``config.json``. - model_output_path: - Destination directory for the dequantized model. - target_dtype: - Desired floating point dtype (defaults to ``torch.bfloat16``). - """ - model_path = Path(model_path) - model_output_path = Path(model_output_path) + output_path = Path(output_path) - if not model_path.exists(): - raise FileNotFoundError(f"Model path {model_path} does not exist") + if output_path.exists(): + raise FileExistsError(f"Output path {output_path} already exists") - if model_output_path.exists(): - raise FileExistsError(f"Output path {model_output_path} already exists") + output_path.mkdir(parents=True) config = _load_json(model_path / "config.json") - if not _is_fp8_format(config): - raise ValueError("Model does not advertise float8_e4m3fn quantization") + quant_cfg = config.get("quantization_config", {}) or {} + fmt = _detect_format(model_path, config) - block_shape = _resolve_block_size(config) + files, index = _list_safetensor_files(model_path) + if not files: + raise RuntimeError("No safetensor files to convert") device_str = _normalize_device(device) if device_str is not None: if not torch.cuda.is_available(): raise RuntimeError("CUDA is not available for GPU dequantization") torch.cuda.set_device(torch.device(device_str)) + open_device = device_str or "cpu" - output_path = model_output_path - output_path.mkdir(parents=True, exist_ok=False) - - index_path = model_path / "model.safetensors.index.json" - index = _load_json(index_path) - weight_map: dict[str, str] = index.get("weight_map", {}) - - shard_names = sorted(set(weight_map.values())) + block_shape = _resolve_block_size(config) if fmt == "fp8" else None - new_weight_map = {} log = setup_logger() - pb = ( - log.pb(range(len(shard_names))) - .manual() - .set(show_left_steps=False) - .title("Dequantizing FP8 shards") - ) + pb = log.pb(range(len(files))).manual().set(show_left_steps=False).title(f"Dequantizing ({fmt})") pb.draw() - for shard in shard_names: - shard_path = model_path / shard - output_shard = output_path / shard - output_shard.parent.mkdir(parents=True, exist_ok=True) - - _dequantize_shard( - shard_path, - output_shard, - target_dtype=target_dtype, - block_shape=block_shape, - device=device_str, - ) - - new_weight_map.update({name: shard for name, shard in weight_map.items() if weight_map[name] == shard}) - pb.subtitle(shard).next().draw() - - pb.close() - - new_index = dict(index) - new_index["weight_map"] = new_weight_map - (output_path / "model.safetensors.index.json").write_text(json.dumps(new_index, indent=2)) + weight_map: Dict[str, str] = {} + total_size = 0 + + try: + for idx, filename in enumerate(files): + path = model_path / filename + if fmt == "fp8": + with safe_open(path, framework="pt", device=open_device) as reader: + if block_shape is None: + raise RuntimeError("FP8 conversion requires block_shape metadata") + tensors = _convert_fp8_shard(reader, target_dtype, block_shape=block_shape) + elif fmt == "nvfp4": + with safe_open(path, framework="pt", device=open_device) as reader: + tensors = _convert_nvfp4_shard(reader, target_dtype) + elif fmt == "awq": + tensors = _convert_awq_file(path, target_dtype, open_device) + elif fmt == "gptq": + tensors = _convert_gptq_file(path, target_dtype, quant_cfg, open_device) + else: + raise ValueError(f"Unsupported format {fmt}") + + save_file(tensors, str(output_path / filename)) + weight_map.update({str(name): filename for name in tensors}) + total_size += 
sum(t.element_size() * t.numel() for t in tensors.values()) + pb.subtitle(filename).next().draw() + finally: + pb.close() + + if index is not None: + new_index = dict(index) + else: + new_index = {} + metadata = dict(new_index.get("metadata", {})) + metadata["total_size"] = total_size + new_index["metadata"] = metadata + new_index["weight_map"] = weight_map + _write_json(output_path / "model.safetensors.index.json", new_index) new_config = dict(config) new_config.pop("quantization_config", None) - new_config["torch_dtype"] = target_dtype.__repr__().split(".")[-1] - (output_path / "config.json").write_text(json.dumps(new_config, indent=2)) + new_config["torch_dtype"] = str(target_dtype).split(".")[-1] + _write_json(output_path / "config.json", new_config) - skip_files = {"config.json", "model.safetensors.index.json"}.union(shard_names) + skip_files = set(files) | {"config.json", "model.safetensors.index.json"} + _copy_aux_files(model_path, output_path, skip_files) - for entry in model_path.iterdir(): - if entry.name in skip_files: - continue - target = output_path / entry.name - if entry.is_dir(): - shutil.copytree(entry, target) - else: - shutil.copy2(entry, target) + +# Backwards compatibility with older imports. +dequantize = dequantize_model diff --git a/scripts/dequantize_fp8_model.py b/scripts/dequantize_fp8_model.py deleted file mode 100755 index bad8d474d..000000000 --- a/scripts/dequantize_fp8_model.py +++ /dev/null @@ -1,76 +0,0 @@ -#!/usr/bin/env python -# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai -# SPDX-License-Identifier: Apache-2.0 - -"""Dequantize an FP8 E4M3 model into BF16 using gptqmodel.utils.model_dequant.""" - -from __future__ import annotations - -import os -os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" -os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True,max_split_size_mb:256,garbage_collection_threshold:0.7" #"expandable_segments:True" - - -import argparse -from pathlib import Path - -import torch - -from gptqmodel.utils.model_dequant import dequantize - - -def parse_args() -> argparse.Namespace: - parser = argparse.ArgumentParser(description="Dequantize FP8 model shards to BF16") - parser.add_argument("model_path", type=Path, nargs="?", help="Path to the FP8 model directory") - parser.add_argument("output_path", type=Path, nargs="?", help="Destination directory for the BF16 model") - parser.add_argument("--model_path", dest="model_path_opt", type=Path, help="Path to the FP8 model directory") - parser.add_argument("--output_path", dest="output_path_opt", type=Path, help="Destination directory for the BF16 model") - parser.add_argument( - "--dtype", - default="bfloat16", - choices=["bfloat16", "float16", "float32"], - help="Output dtype (default: bfloat16)", - ) - parser.add_argument( - "--device", - default="cpu", - help="Device for intermediate dequantization (e.g. 
cpu, cuda, cuda:0)", - ) - return parser.parse_args() - - -def resolve_dtype(name: str) -> torch.dtype: - if name == "bfloat16": - return torch.bfloat16 - if name == "float16": - return torch.float16 - if name == "float32": - return torch.float32 - raise ValueError(f"Unsupported dtype: {name}") - - -def main() -> None: - args = parse_args() - model_path = args.model_path if args.model_path is not None else args.model_path_opt - output_path = args.output_path if args.output_path is not None else args.output_path_opt - if model_path is None or output_path is None: - raise SystemExit("model_path and output_path must be provided either positionally or via flags") - - dtype = resolve_dtype(args.dtype) - device = None - if args.device is not None and args.device.lower() != "cpu": - device = args.device - print( - "[dequantize_fp8_model] args", - { - "model_path": str(model_path), - "output_path": str(output_path), - "dtype": dtype, - "device": device or "cpu", - }, - ) - dequantize(model_path, output_path, target_dtype=dtype, device=device) - - -if __name__ == "__main__": - main() diff --git a/scripts/dequantize_model.py b/scripts/dequantize_model.py new file mode 100755 index 000000000..337149ba7 --- /dev/null +++ b/scripts/dequantize_model.py @@ -0,0 +1,113 @@ +#!/usr/bin/env python +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 + +"""CLI entry point for dequantizing GPTQModel safetensor shards.""" + +from __future__ import annotations + +import argparse +import os +from pathlib import Path +from typing import Optional + +import torch + +from gptqmodel.utils.model_dequant import dequantize_model + + +def _resolve_dtype(name: str) -> torch.dtype: + mapping = { + "bfloat16": torch.bfloat16, + "float16": torch.float16, + "float32": torch.float32, + } + try: + return mapping[name] + except KeyError as exc: # pragma: no cover - argparse ensures this + raise ValueError(f"Unsupported dtype {name}") from exc + + +def _parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "model_path", + type=Path, + nargs="?", + help="Path to the quantized model directory", + ) + parser.add_argument( + "output_path", + type=Path, + nargs="?", + help="Path where the dequantized model will be written", + ) + parser.add_argument( + "--model-path", + dest="model_path_flag", + type=Path, + help="Explicit model path (overrides positional argument)", + ) + parser.add_argument( + "--output-path", + dest="output_path_flag", + type=Path, + help="Explicit output path (overrides positional argument)", + ) + parser.add_argument( + "--dtype", + choices=("bfloat16", "float16", "float32"), + default="bfloat16", + help="Target floating point dtype (default: bfloat16)", + ) + parser.add_argument( + "--device", + default="cpu", + help="Device to stage tensors during dequantization (cpu, cuda, cuda:7, ...)", + ) + parser.add_argument( + "--env", + action="append", + metavar="KEY=VALUE", + help="Optional environment variable overrides applied before execution", + ) + return parser.parse_args() + + +def _apply_env(overrides: Optional[list[str]]) -> None: + if not overrides: + return + for item in overrides: + if "=" not in item: + raise SystemExit(f"Invalid --env entry {item!r}; expected KEY=VALUE") + key, value = item.split("=", 1) + os.environ[key] = value + + +def main() -> None: + args = _parse_args() + _apply_env(args.env) + + model_path = args.model_path_flag or 
args.model_path + output_path = args.output_path_flag or args.output_path + + if model_path is None or output_path is None: + raise SystemExit("model_path and output_path must be provided (positionally or via flags)") + + dtype = _resolve_dtype(args.dtype) + device = args.device if args.device is not None else "cpu" + + debug_payload = { + "model_path": str(Path(model_path)), + "output_path": str(Path(output_path)), + "dtype": dtype, + "device": device, + } + print(f"[dequantize_model] parsed args: {debug_payload}") + + dequantize_model(model_path, output_path, target_dtype=dtype, device=device) + + +if __name__ == "__main__": + main() diff --git a/tests/test_model_dequant.py b/tests/test_model_dequant.py new file mode 100644 index 000000000..7a9aab1ab --- /dev/null +++ b/tests/test_model_dequant.py @@ -0,0 +1,147 @@ +import json +from pathlib import Path + +import pytest +import torch +from safetensors import safe_open +from safetensors.torch import save_file + +from gptqmodel.quantization.dtype import dequantize_f8_e4m3 +from gptqmodel.utils.model_dequant import dequantize_model + + +def _pack_cols(values: torch.Tensor, bits: int = 4) -> torch.Tensor: + """Pack per-column low-bit values into int32 words.""" + + if values.dtype != torch.int32: + values = values.to(torch.int32) + + rows, cols = values.shape + pack_factor = 32 // bits + if cols % pack_factor != 0: + raise ValueError("columns must be divisible by pack factor") + + packed_cols = cols // pack_factor + packed = torch.zeros(rows, packed_cols, dtype=torch.int32) + mask = (1 << bits) - 1 + for col in range(cols): + group = col // pack_factor + shift = (col % pack_factor) * bits + packed[:, group] |= (values[:, col] & mask) << shift + return packed + + +def _write_index(path: Path, shard: str, keys: list[str]) -> None: + weight_map = {key: shard for key in keys} + payload = {"weight_map": weight_map} + (path / "model.safetensors.index.json").write_text(json.dumps(payload)) + + +@pytest.mark.skipif(not hasattr(torch, "float8_e4m3fn"), reason="float8 dtype not available") +def test_dequantize_model_fp8(tmp_path): + model_dir = tmp_path / "fp8_model" + output_dir = tmp_path / "fp8_output" + model_dir.mkdir() + + config = { + "architectures": ["TestModel"], + "quantization_config": { + "fmt": "float8_e4m3fn", + "quant_method": "fp8", + "weight_block_size": [2, 4], + }, + } + (model_dir / "config.json").write_text(json.dumps(config)) + + weight = torch.randn(2, 4, dtype=torch.float32).to(torch.float8_e4m3fn) + scale_inv = torch.ones(1, 1, dtype=torch.float32) + shard_name = "model.safetensors" + save_file( + { + "linear.weight": weight, + "linear.weight_scale_inv": scale_inv, + "linear.bias": torch.randn(4, dtype=torch.float32), + }, + str(model_dir / shard_name), + ) + _write_index(model_dir, shard_name, ["linear.weight", "linear.weight_scale_inv", "linear.bias"]) + + dequantize_model(model_dir, output_dir, target_dtype=torch.bfloat16, device="cpu") + + with safe_open(output_dir / shard_name, framework="pt", device="cpu") as reader: + assert "linear.weight" in reader.keys() + assert "linear.weight_scale_inv" not in reader.keys() + weight_out = reader.get_tensor("linear.weight") + bias_out = reader.get_tensor("linear.bias") + + expected = dequantize_f8_e4m3(weight, scale_inv=scale_inv, axis=None, target_dtype=torch.bfloat16) + assert torch.equal(weight_out, expected) + assert bias_out.dtype is torch.bfloat16 + + updated_config = json.loads((output_dir / "config.json").read_text()) + assert "quantization_config" not in updated_config + 
assert updated_config.get("torch_dtype") == "bfloat16" + + new_index = json.loads((output_dir / "model.safetensors.index.json").read_text()) + assert "linear.weight" in new_index["weight_map"] + assert "linear.weight_scale_inv" not in new_index["weight_map"] + + +def test_dequantize_model_awq(tmp_path): + model_dir = tmp_path / "awq_model" + output_dir = tmp_path / "awq_output" + model_dir.mkdir() + + config = { + "architectures": ["TestModel"], + "quantization_config": { + "quant_method": "awq", + }, + } + (model_dir / "config.json").write_text(json.dumps(config)) + + rows, cols = 8, 16 + weight_values = torch.randint(0, 16, (rows, cols), dtype=torch.int32) + zero_values = torch.randint(0, 16, (rows, cols), dtype=torch.int32) + scales = torch.rand(rows, cols, dtype=torch.float32) * 0.5 + 0.5 + bias = torch.randn(cols, dtype=torch.float32) + + packed_weight = _pack_cols(weight_values) + packed_zero = _pack_cols(zero_values) + + shard_name = "awq.safetensors" + save_file( + { + "layer.qweight": packed_weight, + "layer.qzeros": packed_zero, + "layer.scales": scales, + "layer.bias": bias, + }, + str(model_dir / shard_name), + ) + _write_index( + model_dir, + shard_name, + ["layer.qweight", "layer.qzeros", "layer.scales", "layer.bias"], + ) + + dequantize_model(model_dir, output_dir, target_dtype=torch.bfloat16, device="cpu") + + with safe_open(output_dir / shard_name, framework="pt", device="cpu") as reader: + keys = list(reader.keys()) + assert "layer.weight" in keys + assert "layer.qweight" not in keys + assert "layer.qzeros" not in keys + weight_out = reader.get_tensor("layer.weight") + bias_out = reader.get_tensor("layer.bias") + + expected = ((weight_values.float() - zero_values.float()) * scales).t().contiguous().to(torch.bfloat16) + assert torch.equal(weight_out, expected) + assert bias_out.dtype is torch.bfloat16 + + updated_config = json.loads((output_dir / "config.json").read_text()) + assert "quantization_config" not in updated_config + + new_index = json.loads((output_dir / "model.safetensors.index.json").read_text()) + assert "layer.weight" in new_index["weight_map"] + assert "layer.qweight" not in new_index["weight_map"] From feec9dfb04fa8e9833b78303e7f9ec8aac34699a Mon Sep 17 00:00:00 2001 From: Qubitium Date: Mon, 27 Oct 2025 14:56:27 +0000 Subject: [PATCH 04/10] cleanup --- gptqmodel/utils/model_dequant.py | 171 +++++++++++++++++++++++++++++-- tests/test_model_dequant.py | 37 +++++++ 2 files changed, 200 insertions(+), 8 deletions(-) diff --git a/gptqmodel/utils/model_dequant.py b/gptqmodel/utils/model_dequant.py index 81fa68738..fe088e564 100644 --- a/gptqmodel/utils/model_dequant.py +++ b/gptqmodel/utils/model_dequant.py @@ -8,6 +8,7 @@ from __future__ import annotations import json +import logging import shutil from collections import defaultdict from pathlib import Path @@ -21,6 +22,9 @@ from ..utils.logger import setup_logger +LOG = logging.getLogger(__name__) + + def _load_json(path: Path) -> dict: if not path.exists(): return {} @@ -72,12 +76,106 @@ def _normalize_device(device: Optional[str]) -> Optional[str]: return f"cuda:{dev.index}" -def _resolve_block_size(config: dict) -> tuple[int, int]: +def _resolve_block_size(config: dict) -> Optional[Tuple[int, int]]: quant_cfg = config.get("quantization_config", {}) or {} block_size = quant_cfg.get("weight_block_size") if isinstance(block_size, (list, tuple)) and len(block_size) == 2: return int(block_size[0]), int(block_size[1]) - return (128, 128) + return None + + +def _infer_block_shape(weight_shape: Tuple[int, int], 
scale_tensor: torch.Tensor) -> Tuple[int, int]: + rows, cols = weight_shape + shape = tuple(scale_tensor.shape) + + if scale_tensor.ndim == 0: + LOG.debug( + "Inferred block size (%d, %d) from scalar scale tensor for weight shape %s", + rows, + cols, + weight_shape, + ) + return rows, cols + + if shape == weight_shape: + LOG.debug( + "Inferred element-wise scaling (block size 1x1) for weight shape %s", + weight_shape, + ) + return 1, 1 + + if scale_tensor.ndim == 2: + block_rows = shape[0] + block_cols = shape[1] + if block_rows == 0 or block_cols == 0: + raise ValueError("scale tensor has zero-sized dimension") + if rows % block_rows != 0 or cols % block_cols != 0: + raise ValueError("scale tensor shape incompatible with weight dimensions") + inferred = (rows // block_rows, cols // block_cols) + LOG.debug( + "Inferred block size %s from 2D scale tensor shape %s and weight shape %s", + inferred, + shape, + weight_shape, + ) + return inferred + + if scale_tensor.ndim == 1: + count = shape[0] + if count == 0: + raise ValueError("scale tensor is empty") + + candidates: list[Tuple[int, int]] = [] + for row_blocks in range(1, count + 1): + if count % row_blocks != 0: + continue + if rows % row_blocks != 0: + continue + col_blocks = count // row_blocks + if cols % col_blocks != 0: + continue + block_rows = rows // row_blocks + block_cols = cols // col_blocks + candidates.append((block_rows, block_cols)) + + if candidates: + candidates.sort( + key=lambda bc: ( + abs(bc[0] - bc[1]), + -min(bc), + -max(bc), + bc[0], + ) + ) + inferred = candidates[0] + LOG.debug( + "Inferred block size %s from 1D scale tensor (count=%d) and weight shape %s", + inferred, + count, + weight_shape, + ) + return inferred + + if rows % count == 0: + inferred = (rows // count, cols) + LOG.debug( + "Inferred row-only scaling block size %s from 1D scale tensor (count=%d)", + inferred, + count, + ) + return inferred + if cols % count == 0: + inferred = (rows, cols // count) + LOG.debug( + "Inferred column-only scaling block size %s from 1D scale tensor (count=%d)", + inferred, + count, + ) + return inferred + + raise ValueError("unable to infer block size from 1D scale tensor") + + raise ValueError("unsupported scale tensor rank for block size inference") def _detect_format(model_path: Path, config: dict) -> str: @@ -96,22 +194,34 @@ def _detect_format(model_path: Path, config: dict) -> str: if key.endswith(".weight"): tensor = reader.get_tensor(key) if tensor.dtype == torch.float8_e4m3fn: + LOG.debug("Detected FP8 weights via dtype on tensor '%s'", key) return "fp8" if tensor.dtype == torch.uint8 and (key + "_scale") in keys: + LOG.debug("Detected NVFP4 weights via dtype on tensor '%s'", key) return "nvfp4" if any(k.endswith(".weight_scale") for k in keys): + LOG.debug("Detected NVFP4 format via '.weight_scale' metadata in shard '%s'", files[0]) return "nvfp4" if any(k.endswith(".weight_scale_inv") for k in keys): + LOG.debug("Detected FP8 format via '.weight_scale_inv' metadata in shard '%s'", files[0]) return "fp8" if any(k.endswith(".qweight") for k in keys): has_g = any(k.endswith(".g_idx") for k in keys) + LOG.debug( + "Detected %s format via qweight tensors in shard '%s'", + "gptq" if has_g else "awq", + files[0], + ) return "gptq" if has_g else "awq" if fmt == "float8_e4m3fn": + LOG.debug("Detected FP8 format via config fmt=%s", fmt) return "fp8" if method in ("gptq", "gptqmodel"): + LOG.debug("Detected GPTQ format via quant_method=%s", method) return "gptq" if method == "awq": + LOG.debug("Detected AWQ format via 
quant_method=%s", method) return "awq" raise ValueError("Unable to detect quantization format for model") @@ -145,11 +255,9 @@ def _convert_fp8_shard( reader, target_dtype: torch.dtype, *, - block_shape: tuple[int, int], + block_shape: Optional[Tuple[int, int]], ) -> Dict[str, torch.Tensor]: tensors: Dict[str, torch.Tensor] = {} - block_rows, block_cols = block_shape - for key in reader.keys(): tensor = reader.get_tensor(key) if key.endswith(".weight") and tensor.dtype == torch.float8_e4m3fn: @@ -157,11 +265,30 @@ def _convert_fp8_shard( if scale_key not in reader.keys(): raise KeyError(f"Missing scale inverse tensor for {key}") scale_inv = reader.get_tensor(scale_key) + LOG.debug("Using scale_inv tensor '%s' for FP8 weight '%s'", scale_key, key) rows, cols = tensor.shape + effective_block = block_shape + if effective_block is None: + try: + effective_block = _infer_block_shape((rows, cols), scale_inv) + LOG.debug("Inferred block size %s for weight '%s'", effective_block, key) + except ValueError as exc: + LOG.debug( + "Falling back to full-tensor block size for weight '%s' (%s)", + key, + exc, + ) + effective_block = (rows, cols) + else: + LOG.debug("Using configured block size %s for weight '%s'", effective_block, key) + + block_rows, block_cols = effective_block + if block_rows <= 0 or block_cols <= 0: + raise ValueError(f"Inferred invalid block size {effective_block} for {key}") if rows % block_rows != 0 or cols % block_cols != 0: raise ValueError( - f"Tensor {key} shape {tensor.shape} incompatible with block size {block_shape}" + f"Tensor {key} shape {tensor.shape} incompatible with block size {effective_block}" ) deq = dequantize_f8_e4m3( @@ -172,6 +299,7 @@ def _convert_fp8_shard( ) tensors[key] = _finalize_for_save(deq, target_dtype) elif key.endswith("_scale_inv"): + LOG.debug("Dropping auxiliary FP8 tensor '%s' after dequantization", key) continue else: tensors[key] = _finalize_for_save(tensor, target_dtype) @@ -187,6 +315,7 @@ def _convert_nvfp4_shard(reader, target_dtype: torch.dtype) -> Dict[str, torch.T if scale_key not in reader.keys(): raise KeyError(f"Missing scale tensor for {key}") scale = reader.get_tensor(scale_key) + LOG.debug("Using scale tensor '%s' for NVFP4 weight '%s'", scale_key, key) deq = dequantize_f4_e2m1( tensor, scale=scale, @@ -195,6 +324,7 @@ def _convert_nvfp4_shard(reader, target_dtype: torch.dtype) -> Dict[str, torch.T ) tensors[key] = _finalize_for_save(deq, target_dtype) elif key.endswith("_weight_scale"): + LOG.debug("Dropping auxiliary NVFP4 tensor '%s' after dequantization", key) continue else: tensors[key] = _finalize_for_save(tensor, target_dtype) @@ -210,12 +340,15 @@ def _convert_awq_file(path: Path, target_dtype: torch.dtype, device: str) -> Dic if key.endswith(".qweight"): prefix = key[:-len(".qweight")] module_buffers[prefix]["qweight"] = tensor + LOG.debug("Collected AWQ qweight tensor '%s'", key) elif key.endswith(".qzeros"): prefix = key[:-len(".qzeros")] module_buffers[prefix]["qzeros"] = tensor + LOG.debug("Collected AWQ qzeros tensor '%s'", key) elif key.endswith(".scales"): prefix = key[:-len(".scales")] module_buffers[prefix]["scales"] = tensor + LOG.debug("Collected AWQ scale tensor '%s'", key) else: tensors[key] = _finalize_for_save(tensor, target_dtype) @@ -242,6 +375,7 @@ def _convert_awq_file(path: Path, target_dtype: torch.dtype, device: str) -> Dic weight.to(target_dtype).t().contiguous(), target_dtype, ) + LOG.debug("Dequantized AWQ module '%s' to dtype %s", prefix, target_dtype) if prefix + ".bias" in tensors: 
tensors[prefix + ".bias"] = _finalize_for_save(tensors[prefix + ".bias"], target_dtype) @@ -257,15 +391,19 @@ def _convert_gptq_file(path: Path, target_dtype: torch.dtype, config: dict, devi if key.endswith(".qweight"): prefix = key[:-len(".qweight")] module_buffers[prefix]["qweight"] = tensor + LOG.debug("Collected GPTQ qweight tensor '%s'", key) elif key.endswith(".qzeros"): prefix = key[:-len(".qzeros")] module_buffers[prefix]["qzeros"] = tensor + LOG.debug("Collected GPTQ qzeros tensor '%s'", key) elif key.endswith(".scales"): prefix = key[:-len(".scales")] module_buffers[prefix]["scales"] = tensor + LOG.debug("Collected GPTQ scale tensor '%s'", key) elif key.endswith(".g_idx"): prefix = key[:-len(".g_idx")] module_buffers[prefix]["g_idx"] = tensor + LOG.debug("Collected GPTQ g_idx tensor '%s'", key) else: tensors[key] = _finalize_for_save(tensor, target_dtype) @@ -290,6 +428,12 @@ def _convert_gptq_file(path: Path, target_dtype: torch.dtype, config: dict, devi weight.to(target_dtype).t().contiguous(), target_dtype, ) + LOG.debug( + "Dequantized GPTQ module '%s' with %d-bit groups to dtype %s", + prefix, + bits, + target_dtype, + ) if prefix + ".bias" in tensors: tensors[prefix + ".bias"] = _finalize_for_save(tensors[prefix + ".bias"], target_dtype) @@ -339,7 +483,19 @@ def dequantize_model( block_shape = _resolve_block_size(config) if fmt == "fp8" else None + if block_shape is not None: + LOG.debug("Configured FP8 block size %s found in quantization_config", block_shape) + else: + LOG.debug("No explicit FP8 block size found; will infer from scale tensors if needed") + log = setup_logger() + LOG.debug( + "Starting dequantization for model '%s' (format=%s, target_dtype=%s, device=%s)", + model_path, + fmt, + target_dtype, + open_device, + ) pb = log.pb(range(len(files))).manual().set(show_left_steps=False).title(f"Dequantizing ({fmt})") pb.draw() @@ -349,10 +505,9 @@ def dequantize_model( try: for idx, filename in enumerate(files): path = model_path / filename + LOG.debug("Processing shard '%s' for format %s on device %s", filename, fmt, open_device) if fmt == "fp8": with safe_open(path, framework="pt", device=open_device) as reader: - if block_shape is None: - raise RuntimeError("FP8 conversion requires block_shape metadata") tensors = _convert_fp8_shard(reader, target_dtype, block_shape=block_shape) elif fmt == "nvfp4": with safe_open(path, framework="pt", device=open_device) as reader: diff --git a/tests/test_model_dequant.py b/tests/test_model_dequant.py index 7a9aab1ab..2e96967d6 100644 --- a/tests/test_model_dequant.py +++ b/tests/test_model_dequant.py @@ -37,6 +37,43 @@ def _write_index(path: Path, shard: str, keys: list[str]) -> None: (path / "model.safetensors.index.json").write_text(json.dumps(payload)) +@pytest.mark.skipif(not hasattr(torch, "float8_e4m3fn"), reason="float8 dtype not available") +def test_dequantize_model_fp8_infers_block_size(tmp_path): + model_dir = tmp_path / "fp8_model_infer" + output_dir = tmp_path / "fp8_output_infer" + model_dir.mkdir() + + config = { + "architectures": ["TestModel"], + "quantization_config": { + "fmt": "float8_e4m3fn", + "quant_method": "fp8", + }, + } + (model_dir / "config.json").write_text(json.dumps(config)) + + weight = torch.randn(4, 8, dtype=torch.float32).to(torch.float8_e4m3fn) + scale_inv = torch.ones(2, 2, dtype=torch.float32) + shard_name = "model.safetensors" + save_file( + { + "linear.weight": weight, + "linear.weight_scale_inv": scale_inv, + }, + str(model_dir / shard_name), + ) + _write_index(model_dir, shard_name, 
["linear.weight", "linear.weight_scale_inv"]) + + dequantize_model(model_dir, output_dir, target_dtype=torch.bfloat16, device="cpu") + + with safe_open(output_dir / shard_name, framework="pt", device="cpu") as reader: + weight_out = reader.get_tensor("linear.weight") + assert weight_out.dtype is torch.bfloat16 + + expected = dequantize_f8_e4m3(weight, scale_inv=scale_inv, axis=None, target_dtype=torch.bfloat16) + assert torch.equal(weight_out, expected) + + @pytest.mark.skipif(not hasattr(torch, "float8_e4m3fn"), reason="float8 dtype not available") def test_dequantize_model_fp8(tmp_path): model_dir = tmp_path / "fp8_model" From 63576c35b697dc05cb45a2dd762a03751a270eda Mon Sep 17 00:00:00 2001 From: Qubitium Date: Mon, 27 Oct 2025 14:59:59 +0000 Subject: [PATCH 05/10] cleanup --- gptqmodel/utils/model_dequant.py | 90 ++++++++++++++++---------------- 1 file changed, 45 insertions(+), 45 deletions(-) diff --git a/gptqmodel/utils/model_dequant.py b/gptqmodel/utils/model_dequant.py index fe088e564..7dc1ba5ef 100644 --- a/gptqmodel/utils/model_dequant.py +++ b/gptqmodel/utils/model_dequant.py @@ -25,22 +25,22 @@ LOG = logging.getLogger(__name__) -def _load_json(path: Path) -> dict: +def load_json(path: Path) -> dict: if not path.exists(): return {} with path.open("r", encoding="utf-8") as fh: return json.load(fh) -def _write_json(path: Path, payload: dict) -> None: +def write_json(path: Path, payload: dict) -> None: with path.open("w", encoding="utf-8") as fh: json.dump(payload, fh, indent=2) -def _list_safetensor_files(model_path: Path) -> Tuple[list, Optional[dict]]: +def list_safetensor_files(model_path: Path) -> Tuple[list, Optional[dict]]: index_path = model_path / "model.safetensors.index.json" if index_path.exists(): - index = _load_json(index_path) + index = load_json(index_path) files = sorted(set(index.get("weight_map", {}).values())) return files, index @@ -48,7 +48,7 @@ def _list_safetensor_files(model_path: Path) -> Tuple[list, Optional[dict]]: return files, None -def _finalize_for_save(tensor: torch.Tensor, target_dtype: torch.dtype) -> torch.Tensor: +def finalize_for_save(tensor: torch.Tensor, target_dtype: torch.dtype) -> torch.Tensor: """Cast to ``target_dtype`` when floating point and move to CPU with optimal layout.""" if torch.is_floating_point(tensor): @@ -60,7 +60,7 @@ def _finalize_for_save(tensor: torch.Tensor, target_dtype: torch.dtype) -> torch return tensor_cpu -def _normalize_device(device: Optional[str]) -> Optional[str]: +def normalize_device(device: Optional[str]) -> Optional[str]: if device is None: return None device = device.strip() @@ -76,7 +76,7 @@ def _normalize_device(device: Optional[str]) -> Optional[str]: return f"cuda:{dev.index}" -def _resolve_block_size(config: dict) -> Optional[Tuple[int, int]]: +def resolve_block_size(config: dict) -> Optional[Tuple[int, int]]: quant_cfg = config.get("quantization_config", {}) or {} block_size = quant_cfg.get("weight_block_size") if isinstance(block_size, (list, tuple)) and len(block_size) == 2: @@ -84,7 +84,7 @@ def _resolve_block_size(config: dict) -> Optional[Tuple[int, int]]: return None -def _infer_block_shape(weight_shape: Tuple[int, int], scale_tensor: torch.Tensor) -> Tuple[int, int]: +def infer_block_shape(weight_shape: Tuple[int, int], scale_tensor: torch.Tensor) -> Tuple[int, int]: rows, cols = weight_shape shape = tuple(scale_tensor.shape) @@ -175,15 +175,15 @@ def _infer_block_shape(weight_shape: Tuple[int, int], scale_tensor: torch.Tensor raise ValueError("unable to infer block size from 1D scale 
tensor") - raise ValueError("unsupported scale tensor rank for block size inference") + raise ValueError("unsupported scale tensor rank for block size inference") -def _detect_format(model_path: Path, config: dict) -> str: +def detect_format(model_path: Path, config: dict) -> str: quant_cfg = config.get("quantization_config", {}) or {} method = (quant_cfg.get("quant_method") or "").lower() fmt = (quant_cfg.get("fmt") or "").lower() - files, _ = _list_safetensor_files(model_path) + files, _ = list_safetensor_files(model_path) if not files: raise FileNotFoundError("No .safetensors files found in model directory") @@ -227,7 +227,7 @@ def _detect_format(model_path: Path, config: dict) -> str: raise ValueError("Unable to detect quantization format for model") -def _unpack_cols(packed: torch.Tensor, bits: int) -> torch.Tensor: +def unpack_cols(packed: torch.Tensor, bits: int) -> torch.Tensor: pack_bits = packed.element_size() * 8 pack_factor = pack_bits // bits mask = (1 << bits) - 1 @@ -239,7 +239,7 @@ def _unpack_cols(packed: torch.Tensor, bits: int) -> torch.Tensor: return result -def _unpack_rows(packed: torch.Tensor, bits: int) -> torch.Tensor: +def unpack_rows(packed: torch.Tensor, bits: int) -> torch.Tensor: pack_bits = packed.element_size() * 8 pack_factor = pack_bits // bits mask = (1 << bits) - 1 @@ -251,7 +251,7 @@ def _unpack_rows(packed: torch.Tensor, bits: int) -> torch.Tensor: return result -def _convert_fp8_shard( +def convert_fp8_shard( reader, target_dtype: torch.dtype, *, @@ -271,7 +271,7 @@ def _convert_fp8_shard( effective_block = block_shape if effective_block is None: try: - effective_block = _infer_block_shape((rows, cols), scale_inv) + effective_block = infer_block_shape((rows, cols), scale_inv) LOG.debug("Inferred block size %s for weight '%s'", effective_block, key) except ValueError as exc: LOG.debug( @@ -297,16 +297,16 @@ def _convert_fp8_shard( axis=None, target_dtype=target_dtype, ) - tensors[key] = _finalize_for_save(deq, target_dtype) + tensors[key] = finalize_for_save(deq, target_dtype) elif key.endswith("_scale_inv"): LOG.debug("Dropping auxiliary FP8 tensor '%s' after dequantization", key) continue else: - tensors[key] = _finalize_for_save(tensor, target_dtype) + tensors[key] = finalize_for_save(tensor, target_dtype) return tensors -def _convert_nvfp4_shard(reader, target_dtype: torch.dtype) -> Dict[str, torch.Tensor]: +def convert_nvfp4_shard(reader, target_dtype: torch.dtype) -> Dict[str, torch.Tensor]: tensors: Dict[str, torch.Tensor] = {} for key in reader.keys(): tensor = reader.get_tensor(key) @@ -322,16 +322,16 @@ def _convert_nvfp4_shard(reader, target_dtype: torch.dtype) -> Dict[str, torch.T axis=None, target_dtype=target_dtype, ) - tensors[key] = _finalize_for_save(deq, target_dtype) + tensors[key] = finalize_for_save(deq, target_dtype) elif key.endswith("_weight_scale"): LOG.debug("Dropping auxiliary NVFP4 tensor '%s' after dequantization", key) continue else: - tensors[key] = _finalize_for_save(tensor, target_dtype) + tensors[key] = finalize_for_save(tensor, target_dtype) return tensors -def _convert_awq_file(path: Path, target_dtype: torch.dtype, device: str) -> Dict[str, torch.Tensor]: +def convert_awq_file(path: Path, target_dtype: torch.dtype, device: str) -> Dict[str, torch.Tensor]: tensors: Dict[str, torch.Tensor] = {} module_buffers: Dict[str, Dict[str, torch.Tensor]] = defaultdict(dict) with safe_open(path, framework="pt", device=device) as reader: @@ -350,7 +350,7 @@ def _convert_awq_file(path: Path, target_dtype: torch.dtype, device: str) 
-> Dic module_buffers[prefix]["scales"] = tensor LOG.debug("Collected AWQ scale tensor '%s'", key) else: - tensors[key] = _finalize_for_save(tensor, target_dtype) + tensors[key] = finalize_for_save(tensor, target_dtype) for prefix, buf in module_buffers.items(): missing = {k for k in ("qweight", "qzeros", "scales") if k not in buf} @@ -362,8 +362,8 @@ def _convert_awq_file(path: Path, target_dtype: torch.dtype, device: str) -> Dic scales = buf["scales"] bits = 4 - unpacked_weight = _unpack_cols(qweight, bits).to(torch.float32) - unpacked_zeros = _unpack_cols(qzeros, bits).to(torch.float32) + unpacked_weight = unpack_cols(qweight, bits).to(torch.float32) + unpacked_zeros = unpack_cols(qzeros, bits).to(torch.float32) num_groups = scales.shape[0] group_size = unpacked_weight.shape[0] // num_groups @@ -371,18 +371,18 @@ def _convert_awq_file(path: Path, target_dtype: torch.dtype, device: str) -> Dic zeros_full = unpacked_zeros.repeat_interleave(group_size, dim=0) weight = (unpacked_weight - zeros_full) * scales_full - tensors[prefix + ".weight"] = _finalize_for_save( + tensors[prefix + ".weight"] = finalize_for_save( weight.to(target_dtype).t().contiguous(), target_dtype, ) LOG.debug("Dequantized AWQ module '%s' to dtype %s", prefix, target_dtype) if prefix + ".bias" in tensors: - tensors[prefix + ".bias"] = _finalize_for_save(tensors[prefix + ".bias"], target_dtype) + tensors[prefix + ".bias"] = finalize_for_save(tensors[prefix + ".bias"], target_dtype) return tensors -def _convert_gptq_file(path: Path, target_dtype: torch.dtype, config: dict, device: str) -> Dict[str, torch.Tensor]: +def convert_gptq_file(path: Path, target_dtype: torch.dtype, config: dict, device: str) -> Dict[str, torch.Tensor]: tensors: Dict[str, torch.Tensor] = {} module_buffers: Dict[str, Dict[str, torch.Tensor]] = defaultdict(dict) with safe_open(path, framework="pt", device=device) as reader: @@ -405,7 +405,7 @@ def _convert_gptq_file(path: Path, target_dtype: torch.dtype, config: dict, devi module_buffers[prefix]["g_idx"] = tensor LOG.debug("Collected GPTQ g_idx tensor '%s'", key) else: - tensors[key] = _finalize_for_save(tensor, target_dtype) + tensors[key] = finalize_for_save(tensor, target_dtype) for prefix, buf in module_buffers.items(): missing = {k for k in ("qweight", "qzeros", "scales", "g_idx") if k not in buf} @@ -418,13 +418,13 @@ def _convert_gptq_file(path: Path, target_dtype: torch.dtype, config: dict, devi g_idx = buf["g_idx"].to(torch.long) bits = config.get("bits", 4) - weight_int = _unpack_rows(qweight, bits) - zeros = _unpack_cols(qzeros, bits) + weight_int = unpack_rows(qweight, bits) + zeros = unpack_cols(qzeros, bits) scales_full = scales.to(torch.float32)[g_idx] zeros_full = zeros.to(torch.float32)[g_idx] weight = (weight_int.to(torch.float32) - zeros_full) * scales_full - tensors[prefix + ".weight"] = _finalize_for_save( + tensors[prefix + ".weight"] = finalize_for_save( weight.to(target_dtype).t().contiguous(), target_dtype, ) @@ -435,12 +435,12 @@ def _convert_gptq_file(path: Path, target_dtype: torch.dtype, config: dict, devi target_dtype, ) if prefix + ".bias" in tensors: - tensors[prefix + ".bias"] = _finalize_for_save(tensors[prefix + ".bias"], target_dtype) + tensors[prefix + ".bias"] = finalize_for_save(tensors[prefix + ".bias"], target_dtype) return tensors -def _copy_aux_files(model_path: Path, output_path: Path, skip: Iterable[str]) -> None: +def copy_aux_files(model_path: Path, output_path: Path, skip: Iterable[str]) -> None: for item in model_path.iterdir(): if item.name in skip: 
continue @@ -466,22 +466,22 @@ def dequantize_model( output_path.mkdir(parents=True) - config = _load_json(model_path / "config.json") + config = load_json(model_path / "config.json") quant_cfg = config.get("quantization_config", {}) or {} - fmt = _detect_format(model_path, config) + fmt = detect_format(model_path, config) - files, index = _list_safetensor_files(model_path) + files, index = list_safetensor_files(model_path) if not files: raise RuntimeError("No safetensor files to convert") - device_str = _normalize_device(device) + device_str = normalize_device(device) if device_str is not None: if not torch.cuda.is_available(): raise RuntimeError("CUDA is not available for GPU dequantization") torch.cuda.set_device(torch.device(device_str)) open_device = device_str or "cpu" - block_shape = _resolve_block_size(config) if fmt == "fp8" else None + block_shape = resolve_block_size(config) if fmt == "fp8" else None if block_shape is not None: LOG.debug("Configured FP8 block size %s found in quantization_config", block_shape) @@ -508,14 +508,14 @@ def dequantize_model( LOG.debug("Processing shard '%s' for format %s on device %s", filename, fmt, open_device) if fmt == "fp8": with safe_open(path, framework="pt", device=open_device) as reader: - tensors = _convert_fp8_shard(reader, target_dtype, block_shape=block_shape) + tensors = convert_fp8_shard(reader, target_dtype, block_shape=block_shape) elif fmt == "nvfp4": with safe_open(path, framework="pt", device=open_device) as reader: - tensors = _convert_nvfp4_shard(reader, target_dtype) + tensors = convert_nvfp4_shard(reader, target_dtype) elif fmt == "awq": - tensors = _convert_awq_file(path, target_dtype, open_device) + tensors = convert_awq_file(path, target_dtype, open_device) elif fmt == "gptq": - tensors = _convert_gptq_file(path, target_dtype, quant_cfg, open_device) + tensors = convert_gptq_file(path, target_dtype, quant_cfg, open_device) else: raise ValueError(f"Unsupported format {fmt}") @@ -534,15 +534,15 @@ def dequantize_model( metadata["total_size"] = total_size new_index["metadata"] = metadata new_index["weight_map"] = weight_map - _write_json(output_path / "model.safetensors.index.json", new_index) + write_json(output_path / "model.safetensors.index.json", new_index) new_config = dict(config) new_config.pop("quantization_config", None) new_config["torch_dtype"] = str(target_dtype).split(".")[-1] - _write_json(output_path / "config.json", new_config) + write_json(output_path / "config.json", new_config) skip_files = set(files) | {"config.json", "model.safetensors.index.json"} - _copy_aux_files(model_path, output_path, skip_files) + copy_aux_files(model_path, output_path, skip_files) # Backwards compatibility with older imports. 
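
An aside, before the next patch in the series, for readers following the FP8 path above: when quantization_config carries no block size, convert_fp8_shard falls back to infer_block_shape and derives one from the weight and weight_scale_inv shapes. The Python sketch below is illustrative only and not part of any patch; infer_block_shape_sketch, the 256x512 weight, and the 2x4 scale grid are made-up stand-ins, and how the expanded factor is finally applied is left to the checkpoint's convention.

    import torch

    def infer_block_shape_sketch(weight_shape, scale_inv):
        # Each scale_inv entry covers one (block_r, block_c) tile of the weight,
        # so the block size is simply the ratio of the two shapes.
        rows, cols = weight_shape
        blocks_r, blocks_c = scale_inv.shape
        if rows % blocks_r or cols % blocks_c:
            raise ValueError("scale_inv shape does not evenly tile the weight")
        return rows // blocks_r, cols // blocks_c

    weight = torch.randn(256, 512)
    scale_inv = torch.rand(2, 4)  # hypothetical per-block factors
    block_r, block_c = infer_block_shape_sketch(weight.shape, scale_inv)  # (128, 128)

    # Expand one factor per block to one factor per element before applying it.
    expanded = scale_inv.repeat_interleave(block_r, dim=0).repeat_interleave(block_c, dim=1)
    assert expanded.shape == weight.shape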
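A second sketch, likewise outside the patches, for the AWQ and GPTQ hunks above: both paths rebuild a float weight as (unpacked - zeros) * scales and differ only in how per-group zeros and scales are expanded to per-row values. AWQ's groups are contiguous, so repeat_interleave is enough; GPTQ carries an explicit g_idx mapping, so a gather also handles reordered layouts. The shapes, the 4-bit value range, and the shuffled g_idx below are assumptions chosen for the demo.

    import torch

    rows, cols, group_size = 8, 4, 4
    groups = rows // group_size

    # Already-unpacked 4-bit integer weights plus per-group zeros and scales.
    weight_int = torch.randint(0, 16, (rows, cols), dtype=torch.int32).float()
    zeros = torch.randint(0, 16, (groups, cols), dtype=torch.int32).float()
    scales = torch.rand(groups, cols)

    # AWQ-style: groups are contiguous row blocks, so repeat_interleave expands them.
    awq_weight = (
        weight_int - zeros.repeat_interleave(group_size, dim=0)
    ) * scales.repeat_interleave(group_size, dim=0)

    # GPTQ-style: g_idx names each row's group, so an index gather also covers
    # shuffled (act-order) layouts that repeat_interleave cannot express.
    g_idx = torch.randperm(rows) // group_size  # hypothetical shuffled mapping
    gptq_weight = (weight_int - zeros[g_idx]) * scales[g_idx]

    # The converters above additionally call .t().contiguous() before saving.
    assert awq_weight.shape == gptq_weight.shape == (rows, cols)
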
From 8ce5e33ac0af553b2035ceb70c4aeaead53daca1 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Mon, 27 Oct 2025 15:00:28 +0000 Subject: [PATCH 06/10] cleanup --- tests/test_model_dequant.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/test_model_dequant.py b/tests/test_model_dequant.py index 2e96967d6..10df9af67 100644 --- a/tests/test_model_dequant.py +++ b/tests/test_model_dequant.py @@ -1,3 +1,8 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + import json from pathlib import Path @@ -32,7 +37,7 @@ def _pack_cols(values: torch.Tensor, bits: int = 4) -> torch.Tensor: def _write_index(path: Path, shard: str, keys: list[str]) -> None: - weight_map = {key: shard for key in keys} + weight_map = dict.fromkeys(keys, shard) payload = {"weight_map": weight_map} (path / "model.safetensors.index.json").write_text(json.dumps(payload)) From 240d226b311ade637d69682fd1d86dfc7fd020d8 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Mon, 27 Oct 2025 15:01:58 +0000 Subject: [PATCH 07/10] cleanup --- tests/test_model_dequant.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/test_model_dequant.py b/tests/test_model_dequant.py index 10df9af67..eb2f4a1c1 100644 --- a/tests/test_model_dequant.py +++ b/tests/test_model_dequant.py @@ -15,7 +15,7 @@ from gptqmodel.utils.model_dequant import dequantize_model -def _pack_cols(values: torch.Tensor, bits: int = 4) -> torch.Tensor: +def pack_cols(values: torch.Tensor, bits: int = 4) -> torch.Tensor: """Pack per-column low-bit values into int32 words.""" if values.dtype != torch.int32: @@ -36,7 +36,7 @@ def _pack_cols(values: torch.Tensor, bits: int = 4) -> torch.Tensor: return packed -def _write_index(path: Path, shard: str, keys: list[str]) -> None: +def write_index(path: Path, shard: str, keys: list[str]) -> None: weight_map = dict.fromkeys(keys, shard) payload = {"weight_map": weight_map} (path / "model.safetensors.index.json").write_text(json.dumps(payload)) @@ -67,7 +67,7 @@ def test_dequantize_model_fp8_infers_block_size(tmp_path): }, str(model_dir / shard_name), ) - _write_index(model_dir, shard_name, ["linear.weight", "linear.weight_scale_inv"]) + write_index(model_dir, shard_name, ["linear.weight", "linear.weight_scale_inv"]) dequantize_model(model_dir, output_dir, target_dtype=torch.bfloat16, device="cpu") @@ -106,7 +106,7 @@ def test_dequantize_model_fp8(tmp_path): }, str(model_dir / shard_name), ) - _write_index(model_dir, shard_name, ["linear.weight", "linear.weight_scale_inv", "linear.bias"]) + write_index(model_dir, shard_name, ["linear.weight", "linear.weight_scale_inv", "linear.bias"]) dequantize_model(model_dir, output_dir, target_dtype=torch.bfloat16, device="cpu") @@ -148,8 +148,8 @@ def test_dequantize_model_awq(tmp_path): scales = torch.rand(rows, cols, dtype=torch.float32) * 0.5 + 0.5 bias = torch.randn(cols, dtype=torch.float32) - packed_weight = _pack_cols(weight_values) - packed_zero = _pack_cols(zero_values) + packed_weight = pack_cols(weight_values) + packed_zero = pack_cols(zero_values) shard_name = "awq.safetensors" save_file( @@ -161,7 +161,7 @@ def test_dequantize_model_awq(tmp_path): }, str(model_dir / shard_name), ) - _write_index( + write_index( model_dir, shard_name, ["layer.qweight", "layer.qzeros", "layer.scales", "layer.bias"], From 5fae8cadfcac319de5269f25ebdedd4f16cbcedb Mon Sep 17 00:00:00 2001 From: 
Qubitium Date: Mon, 27 Oct 2025 15:03:20 +0000 Subject: [PATCH 08/10] cleanup --- gptqmodel/utils/model_dequant.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/gptqmodel/utils/model_dequant.py b/gptqmodel/utils/model_dequant.py index 7dc1ba5ef..5d2bba0a8 100644 --- a/gptqmodel/utils/model_dequant.py +++ b/gptqmodel/utils/model_dequant.py @@ -543,7 +543,3 @@ def dequantize_model( skip_files = set(files) | {"config.json", "model.safetensors.index.json"} copy_aux_files(model_path, output_path, skip_files) - - -# Backwards compatibility with older imports. -dequantize = dequantize_model From 797092895b87db9356c9ba9569c359b61f8adf1d Mon Sep 17 00:00:00 2001 From: Qubitium Date: Mon, 27 Oct 2025 15:04:44 +0000 Subject: [PATCH 09/10] ruff 0.14.2 --- format/format.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/format/format.sh b/format/format.sh index c31a01e59..5f9e35118 100755 --- a/format/format.sh +++ b/format/format.sh @@ -3,7 +3,7 @@ cd "$(dirname "$0")" || exit # force ruff/isort to be same version as setup.py -pip install -U ruff==0.13.0 +pip install -U ruff==0.14.2 #isort==6.0.1 ruff check ../gptqmodel/models ../gptqmodel/nn_modules ../gptqmodel/quantization ../gptqmodel/utils ../gptqmodel/__init__.py ../examples ../tests ../setup.py --fix --unsafe-fixes From 974650cf6f635a7a40406163b44312c9e3e930c7 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Mon, 27 Oct 2025 16:49:22 +0000 Subject: [PATCH 10/10] cleanup --- tests/models/test_falcon.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/test_falcon.py b/tests/models/test_falcon.py index 3745ebf15..159ed0e47 100644 --- a/tests/models/test_falcon.py +++ b/tests/models/test_falcon.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 # Contact: qubitium@modelcloud.ai, x.com/qubitium -import torch # noqa: E402from tests.model_test import ModelTest +import torch # noqa: E402 from model_test import ModelTest from gptqmodel.utils.eval import EVAL
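
A closing note on the test fixtures renamed in the cleanup patches above: pack_cols in tests/test_model_dequant.py packs per-column low-bit values into int32 words so the AWQ path has something to unpack. The round-trip sketch below is illustrative only and not part of any patch; it assumes a lowest-nibble-first layout with eight 4-bit values per word, and the real helpers may order nibbles differently.

    import torch

    BITS = 4
    PACK_FACTOR = 32 // BITS  # eight 4-bit values per int32 word
    MASK = (1 << BITS) - 1

    def pack_cols_sketch(values: torch.Tensor) -> torch.Tensor:
        # Fold each run of eight column values into one 32-bit word, lowest
        # nibble first.  The eighth nibble lands in the sign bit, so packed
        # words may print as negative int32 values.
        rows, cols = values.shape
        grouped = values.to(torch.int64).reshape(rows, cols // PACK_FACTOR, PACK_FACTOR)
        shifts = torch.arange(PACK_FACTOR, dtype=torch.int64) * BITS
        return (grouped << shifts).sum(dim=-1).to(torch.int32)

    def unpack_cols_sketch(packed: torch.Tensor) -> torch.Tensor:
        # Inverse: arithmetic right shift plus masking recovers each nibble.
        shifts = torch.arange(PACK_FACTOR, dtype=torch.int64) * BITS
        nibbles = (packed.unsqueeze(-1).to(torch.int64) >> shifts) & MASK
        return nibbles.reshape(packed.shape[0], -1).to(torch.int32)

    values = torch.randint(0, 1 << BITS, (4, 16), dtype=torch.int32)
    packed = pack_cols_sketch(values)  # shape (4, 2)
    assert torch.equal(unpack_cols_sketch(packed), values)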