From 0485b0bb540d472a0e4c9ade22ff1d3fceaee48a Mon Sep 17 00:00:00 2001 From: Qubitium Date: Mon, 27 Oct 2025 09:40:52 +0000 Subject: [PATCH 1/2] model dequant --- gptqmodel/utils/model_dequant.py | 197 +++++++++++++++++++++++++++++++ scripts/dequantize_fp8_model.py | 71 +++++++++++ tests/test_quant_dtype.py | 80 +++++++++++++ 3 files changed, 348 insertions(+) create mode 100644 gptqmodel/utils/model_dequant.py create mode 100755 scripts/dequantize_fp8_model.py diff --git a/gptqmodel/utils/model_dequant.py b/gptqmodel/utils/model_dequant.py new file mode 100644 index 000000000..35a4be938 --- /dev/null +++ b/gptqmodel/utils/model_dequant.py @@ -0,0 +1,197 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +"""Utilities for converting FP8-quantized models to higher precision.""" + +from __future__ import annotations + +import json +import shutil +from pathlib import Path +from typing import Optional + +import torch +from safetensors import safe_open +from safetensors.torch import save_file + +from ..quantization.dtype import dequantize_f8_e4m3 +from ..utils.logger import setup_logger + + +def _normalize_device(device: Optional[str]) -> Optional[str]: + if device is None: + return None + device = device.strip() + if device.lower() == "cpu": + return None + + dev = torch.device(device) + if dev.type != "cuda": + raise ValueError(f"Unsupported device type: {device}") + + if dev.index is None: + return "cuda:0" + return f"cuda:{dev.index}" + + +def _load_json(path: Path) -> dict: + if not path.exists(): + return {} + return json.loads(path.read_text()) + + +def _is_fp8_format(config: dict) -> bool: + quant_cfg = config.get("quantization_config", {}) + fmt = quant_cfg.get("fmt") + return fmt == "float8_e4m3fn" + + +def _resolve_block_size(config: dict) -> tuple[int, int]: + quant_cfg = config.get("quantization_config", {}) + block_size = quant_cfg.get("weight_block_size") + if isinstance(block_size, (list, tuple)) and len(block_size) == 2: + return int(block_size[0]), int(block_size[1]) + return (128, 128) + + +def _dequantize_shard( + shard_path: Path, + output_path: Path, + *, + target_dtype: torch.dtype, + block_shape: tuple[int, int], + device: Optional[str], +) -> None: + tensors = {} + open_device = device or "cpu" + + with safe_open(shard_path, framework="pt", device=open_device) as reader: + for name in reader.keys(): + tensor = reader.get_tensor(name) + + if tensor.dtype is torch.float8_e4m3fn: + scale_inv_name = name + "_scale_inv" + if scale_inv_name not in reader.keys(): + tensors[name] = tensor.to(target_dtype).to("cpu") + continue + + rows, cols = tensor.shape + block_rows, block_cols = block_shape + if rows % block_rows != 0 or cols % block_cols != 0: + raise ValueError( + f"Tensor {name} shape {tensor.shape} incompatible with block size {block_shape}" + ) + + scale_inv = reader.get_tensor(scale_inv_name) + tensors[name] = dequantize_f8_e4m3( + tensor, + scale_inv=scale_inv, + axis=None, + target_dtype=target_dtype, + ).to("cpu") + elif tensor.dtype is torch.uint8 and name.endswith(".weight"): + tensors[name] = tensor.to(target_dtype).to("cpu") + else: + tensors[name] = tensor.to("cpu") + + save_file(tensors, str(output_path)) + + +def dequantize( + model_path: str | Path, + model_output_path: str | Path, + *, + target_dtype: torch.dtype = torch.bfloat16, + device: Optional[str] = None, +) -> None: + """Dequantize an FP8 E4M3 
model into the requested ``target_dtype``. + + Parameters + ---------- + model_path: + Directory containing ``model.safetensors`` shards and ``config.json``. + model_output_path: + Destination directory for the dequantized model. + target_dtype: + Desired floating point dtype (defaults to ``torch.bfloat16``). + """ + + model_path = Path(model_path) + model_output_path = Path(model_output_path) + + if not model_path.exists(): + raise FileNotFoundError(f"Model path {model_path} does not exist") + + if model_output_path.exists(): + raise FileExistsError(f"Output path {model_output_path} already exists") + + config = _load_json(model_path / "config.json") + if not _is_fp8_format(config): + raise ValueError("Model does not advertise float8_e4m3fn quantization") + + block_shape = _resolve_block_size(config) + + device_str = _normalize_device(device) + if device_str is not None: + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is not available for GPU dequantization") + torch.cuda.set_device(torch.device(device_str)) + + output_path = model_output_path + output_path.mkdir(parents=True, exist_ok=False) + + index_path = model_path / "model.safetensors.index.json" + index = _load_json(index_path) + weight_map: dict[str, str] = index.get("weight_map", {}) + + shard_names = sorted(set(weight_map.values())) + + new_weight_map = {} + log = setup_logger() + pb = ( + log.pb(range(len(shard_names))) + .manual() + .set(show_left_steps=False) + .title("Dequantizing FP8 shards") + ) + pb.draw() + + for shard in shard_names: + shard_path = model_path / shard + output_shard = output_path / shard + output_shard.parent.mkdir(parents=True, exist_ok=True) + + _dequantize_shard( + shard_path, + output_shard, + target_dtype=target_dtype, + block_shape=block_shape, + device=device_str, + ) + + new_weight_map.update({name: shard for name, shard in weight_map.items() if weight_map[name] == shard}) + pb.subtitle(shard).next().draw() + + pb.close() + + new_index = dict(index) + new_index["weight_map"] = new_weight_map + (output_path / "model.safetensors.index.json").write_text(json.dumps(new_index, indent=2)) + + new_config = dict(config) + new_config.pop("quantization_config", None) + new_config["torch_dtype"] = target_dtype.__repr__().split(".")[-1] + (output_path / "config.json").write_text(json.dumps(new_config, indent=2)) + + skip_files = {"config.json", "model.safetensors.index.json"}.union(shard_names) + + for entry in model_path.iterdir(): + if entry.name in skip_files: + continue + target = output_path / entry.name + if entry.is_dir(): + shutil.copytree(entry, target) + else: + shutil.copy2(entry, target) diff --git a/scripts/dequantize_fp8_model.py b/scripts/dequantize_fp8_model.py new file mode 100755 index 000000000..e0c2c8a5f --- /dev/null +++ b/scripts/dequantize_fp8_model.py @@ -0,0 +1,71 @@ +#!/usr/bin/env python +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +"""Dequantize an FP8 E4M3 model into BF16 using gptqmodel.utils.model_dequant.""" + +from __future__ import annotations + +import argparse +from pathlib import Path + +import torch + +from gptqmodel.utils.model_dequant import dequantize + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Dequantize FP8 model shards to BF16") + parser.add_argument("model_path", type=Path, nargs="?", help="Path to the FP8 model directory") + parser.add_argument("output_path", type=Path, nargs="?", help="Destination directory for the BF16 model") + 
parser.add_argument("--model_path", dest="model_path_opt", type=Path, help="Path to the FP8 model directory") + parser.add_argument("--output_path", dest="output_path_opt", type=Path, help="Destination directory for the BF16 model") + parser.add_argument( + "--dtype", + default="bfloat16", + choices=["bfloat16", "float16", "float32"], + help="Output dtype (default: bfloat16)", + ) + parser.add_argument( + "--device", + default="cpu", + help="Device for intermediate dequantization (e.g. cpu, cuda, cuda:0)", + ) + return parser.parse_args() + + +def resolve_dtype(name: str) -> torch.dtype: + if name == "bfloat16": + return torch.bfloat16 + if name == "float16": + return torch.float16 + if name == "float32": + return torch.float32 + raise ValueError(f"Unsupported dtype: {name}") + + +def main() -> None: + args = parse_args() + model_path = args.model_path if args.model_path is not None else args.model_path_opt + output_path = args.output_path if args.output_path is not None else args.output_path_opt + if model_path is None or output_path is None: + raise SystemExit("model_path and output_path must be provided either positionally or via flags") + + dtype = resolve_dtype(args.dtype) + device = None + if args.device is not None and args.device.lower() != "cpu": + device = args.device + print( + "[dequantize_fp8_model] args", + { + "model_path": str(model_path), + "output_path": str(output_path), + "dtype": dtype, + "device": device or "cpu", + }, + ) + dequantize(model_path, output_path, target_dtype=dtype, device=device) + + +if __name__ == "__main__": + main() diff --git a/tests/test_quant_dtype.py b/tests/test_quant_dtype.py index d271ab423..ba512ae9b 100644 --- a/tests/test_quant_dtype.py +++ b/tests/test_quant_dtype.py @@ -1,6 +1,9 @@ +import time + import torch import pytest +from tabulate import tabulate from gptqmodel.quantization.dtype import ( dequantize_f8_e4m3, @@ -8,6 +11,11 @@ ) +def _print_accuracy(title: str, rows, headers) -> None: + table = tabulate(rows, headers=headers, floatfmt=".6f") + print(f"\n{title}\n{table}\n") + + @pytest.mark.skipif(not hasattr(torch, "float8_e4m3fn"), reason="float8 dtype not available") def test_dequantize_f8_e4m3_basic_conversion(): values = torch.linspace(-1, 1, steps=8, dtype=torch.float32) @@ -28,6 +36,15 @@ def test_dequantize_f8_e4m3_with_scale_inv(): got = dequantize_f8_e4m3(fp8, scale_inv=scale_inv, axis=0) expected = (fp8.to(torch.bfloat16) / scale_inv.view(-1, 1).to(torch.bfloat16)).to(torch.bfloat16) + diff = torch.max(torch.abs(got - expected)).item() + _print_accuracy( + "dequantize_f8_e4m3_with_scale_inv", + [ + ["baseline", str(expected.dtype), float(expected.abs().max().item()), 0.0], + ["candidate", str(got.dtype), float(got.abs().max().item()), diff], + ], + ["variant", "dtype", "max|value|", "max|diff vs baseline|"], + ) assert torch.equal(got, expected) @@ -40,6 +57,15 @@ def test_dequantize_f8_e4m3_with_scale_axis_one(): got = dequantize_f8_e4m3(fp8, scale=scale, axis=1) expected = (fp8.to(torch.bfloat16) * scale.view(1, -1)).to(torch.bfloat16) + diff = torch.max(torch.abs(got - expected)).item() + _print_accuracy( + "dequantize_f8_e4m3_with_scale_axis_one", + [ + ["baseline", str(expected.dtype), float(expected.abs().max().item()), 0.0], + ["candidate", str(got.dtype), float(got.abs().max().item()), diff], + ], + ["variant", "dtype", "max|value|", "max|diff vs baseline|"], + ) assert torch.equal(got, expected) @@ -52,6 +78,15 @@ def test_dequantize_f8_e4m3_with_fractional_scale_inv(): got = dequantize_f8_e4m3(fp8, 
scale_inv=scale_inv, axis=0) expected = (fp8.to(torch.bfloat16) * scale_inv.view(-1, 1).to(torch.bfloat16)).to(torch.bfloat16) + diff = torch.max(torch.abs(got - expected)).item() + _print_accuracy( + "dequantize_f8_e4m3_with_fractional_scale_inv", + [ + ["baseline", str(expected.dtype), float(expected.abs().max().item()), 0.0], + ["candidate", str(got.dtype), float(got.abs().max().item()), diff], + ], + ["variant", "dtype", "max|value|", "max|diff vs baseline|"], + ) assert torch.equal(got, expected) @@ -69,3 +104,48 @@ def test_device_supports_native_fp8_reports_capability(monkeypatch): monkeypatch.setattr("torch.cuda.get_device_capability", lambda device=None: (8, 0)) assert device_supports_native_fp8(torch.device("cuda", 0)) is False + + +@pytest.mark.skipif( + not torch.cuda.is_available(), + reason="CUDA device required for GPU benchmark", +) +@pytest.mark.skipif(not hasattr(torch, "float8_e4m3fn"), reason="float8 dtype not available") +def test_dequantize_f8_e4m3_cpu_vs_gpu_benchmark(): + rows = 128 * 4 + cols = 128 * 3 + src = torch.randn(rows, cols, dtype=torch.float32) + fp8 = src.to(torch.float8_e4m3fn) + scale_inv = torch.rand(rows // 128, cols // 128, dtype=torch.float32) * 0.5 + + # Warmup + dequantize_f8_e4m3(fp8, scale_inv=scale_inv, axis=None, target_dtype=torch.bfloat16) + + start = time.perf_counter() + cpu_result = dequantize_f8_e4m3(fp8, scale_inv=scale_inv, axis=None, target_dtype=torch.bfloat16) + cpu_time = time.perf_counter() - start + + fp8_gpu = fp8.cuda() + scale_inv_gpu = scale_inv.cuda() + + dequantize_f8_e4m3(fp8_gpu, scale_inv=scale_inv_gpu, axis=None, target_dtype=torch.bfloat16) + torch.cuda.synchronize() + + start = time.perf_counter() + gpu_result = dequantize_f8_e4m3(fp8_gpu, scale_inv=scale_inv_gpu, axis=None, target_dtype=torch.bfloat16) + torch.cuda.synchronize() + gpu_time = time.perf_counter() - start + + diff = torch.max(torch.abs(cpu_result - gpu_result.cpu())).item() + _print_accuracy( + "dequantize_f8_e4m3_cpu_vs_gpu", + [ + ["CPU", cpu_time, float(cpu_result.abs().max().item()), 0.0], + ["GPU", gpu_time, float(gpu_result.abs().max().item()), diff], + ], + ["variant", "time (s)", "max|value|", "max|diff vs baseline|"], + ) + assert torch.allclose(cpu_result, gpu_result.cpu(), atol=1e-3, rtol=1e-3) + + # GPU should not be dramatically slower than CPU + assert gpu_time <= cpu_time * 2, f"GPU dequant slower than expected (cpu={cpu_time:.4f}s, gpu={gpu_time:.4f}s)" From 02740f5532c13f7535382dccb21bf9e6b4336268 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Mon, 27 Oct 2025 09:50:48 +0000 Subject: [PATCH 2/2] cleanup --- gptqmodel/utils/model_dequant.py | 19 +++++++++++++++---- scripts/dequantize_fp8_model.py | 5 +++++ 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/gptqmodel/utils/model_dequant.py b/gptqmodel/utils/model_dequant.py index 35a4be938..1d5c93400 100644 --- a/gptqmodel/utils/model_dequant.py +++ b/gptqmodel/utils/model_dequant.py @@ -85,16 +85,27 @@ def _dequantize_shard( ) scale_inv = reader.get_tensor(scale_inv_name) - tensors[name] = dequantize_f8_e4m3( + deq = dequantize_f8_e4m3( tensor, scale_inv=scale_inv, axis=None, target_dtype=target_dtype, - ).to("cpu") + ) + if deq.ndimension() >= 4: + tensors[name] = deq.to("cpu", memory_format=torch.channels_last) + else: + tensors[name] = deq.to("cpu") elif tensor.dtype is torch.uint8 and name.endswith(".weight"): - tensors[name] = tensor.to(target_dtype).to("cpu") + converted = tensor.to(target_dtype) + if converted.ndimension() >= 4: + tensors[name] = converted.to("cpu", 
memory_format=torch.channels_last) + else: + tensors[name] = converted.to("cpu") else: - tensors[name] = tensor.to("cpu") + if tensor.ndimension() >= 4 and (device is None or tensor.device.type == "cpu"): + tensors[name] = tensor.to("cpu", memory_format=torch.channels_last) + else: + tensors[name] = tensor.to("cpu") save_file(tensors, str(output_path)) diff --git a/scripts/dequantize_fp8_model.py b/scripts/dequantize_fp8_model.py index e0c2c8a5f..bad8d474d 100755 --- a/scripts/dequantize_fp8_model.py +++ b/scripts/dequantize_fp8_model.py @@ -6,6 +6,11 @@ from __future__ import annotations +import os +os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" +os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True,max_split_size_mb:256,garbage_collection_threshold:0.7" #"expandable_segments:True" + + import argparse from pathlib import Path
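
For reference, a minimal usage sketch of the new dequantize() helper and the bundled
scripts/dequantize_fp8_model.py CLI added in this series; the model directories below are
placeholders, not paths taken from the patch.

    import torch
    from gptqmodel.utils.model_dequant import dequantize

    # The input dir must hold config.json (advertising float8_e4m3fn quantization),
    # the model.safetensors shards and model.safetensors.index.json; the output dir
    # must not exist yet, otherwise dequantize() raises FileExistsError.
    dequantize(
        "/data/models/My-FP8-Model",    # placeholder input directory
        "/data/models/My-BF16-Model",   # placeholder output directory
        target_dtype=torch.bfloat16,    # the CLI also allows float16 / float32
        device="cuda:0",                # None or "cpu" keeps dequantization on CPU
    )

    # The CLI wraps the same call and accepts positional or flag-style paths, e.g.:
    #   python scripts/dequantize_fp8_model.py /data/models/My-FP8-Model \
    #       /data/models/My-BF16-Model --dtype bfloat16 --device cuda:0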
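
The default 128x128 weight_block_size maps one scale_inv entry to each 128x128 tile of an
FP8 weight, which is why _dequantize_shard rejects shapes that are not multiples of the
block size. The sketch below only illustrates that block geometry under the assumption of
a PyTorch build with torch.float8_e4m3fn; it is not the dequantize_f8_e4m3 kernel itself,
and whether the expanded factor is multiplied or divided depends on the checkpoint's
scale convention.

    import torch

    block_rows, block_cols = 128, 128             # default from _resolve_block_size()
    rows, cols = 512, 384                         # must be multiples of the block size
    weight_fp8 = torch.randn(rows, cols).to(torch.float8_e4m3fn)
    scale_grid = torch.rand(rows // block_rows, cols // block_cols)  # one entry per tile

    # Expand the per-tile grid so every weight element sees its tile's factor.
    per_element = (
        scale_grid.repeat_interleave(block_rows, dim=0)
                  .repeat_interleave(block_cols, dim=1)
    )
    assert per_element.shape == weight_fp8.shape

    # A naive reference dequant would combine weight_fp8.to(torch.float32) with
    # per_element (multiply or divide, depending on convention); the patch delegates
    # the real work to dequantize_f8_e4m3(tensor, scale_inv=..., axis=None, ...).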