From 511d744fb961bf285d0ac6da99bc54da4d62c841 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Tue, 30 Sep 2025 09:10:52 +0000 Subject: [PATCH 1/4] save meta files directly without load to cpu Signed-off-by: Qubitium --- gptqmodel/looper/awq_processor.py | 3 - gptqmodel/looper/gptq_processor.py | 3 - gptqmodel/looper/qqq_processor.py | 5 - gptqmodel/models/writer.py | 197 +++++++-------- gptqmodel/utils/model.py | 376 ++++++++++++++++++++++++++++- 5 files changed, 449 insertions(+), 135 deletions(-) diff --git a/gptqmodel/looper/awq_processor.py b/gptqmodel/looper/awq_processor.py index eae540407..3fe3465b4 100644 --- a/gptqmodel/looper/awq_processor.py +++ b/gptqmodel/looper/awq_processor.py @@ -29,7 +29,6 @@ from ..quantization.config import FORMAT, METHOD, QuantizeConfig from ..utils.logger import setup_logger from ..utils.model import get_module_by_name_prefix, move_to -from ..utils.offload import undo_offload_to_disk from ..utils.torch import CPU, torch_sync log = setup_logger() @@ -779,8 +778,6 @@ def finalize(self, model: BaseQModel, **kwargs): if self.stream: torch_sync() - model.model = undo_offload_to_disk(module=model.model, include_buffers=True, delete_offload_folders=True) - if model.quantize_config.format == FORMAT.GEMM: model.qlinear_kernel = AwqGEMMQuantLinear elif model.quantize_config.format == FORMAT.GEMV: diff --git a/gptqmodel/looper/gptq_processor.py b/gptqmodel/looper/gptq_processor.py index c45f3e1c6..e8084a0f0 100644 --- a/gptqmodel/looper/gptq_processor.py +++ b/gptqmodel/looper/gptq_processor.py @@ -21,7 +21,6 @@ from ..utils.importer import select_quant_linear from ..utils.logger import setup_logger from ..utils.model import create_quant_module, find_modules, move_to, pack_model, pack_module -from ..utils.offload import undo_offload_to_disk from ..utils.torch import HAS_CUDA, torch_streamCtx, torch_sync log = setup_logger() @@ -265,8 +264,6 @@ def finalize(self, model: BaseQModel, **kwargs): if self.stream: torch_sync() - model.model = undo_offload_to_disk(module=model.model, include_buffers=True, delete_offload_folders=True) - # print("finalize") # print_module_tree(model.model) diff --git a/gptqmodel/looper/qqq_processor.py b/gptqmodel/looper/qqq_processor.py index 9f7220a49..f2fe748de 100644 --- a/gptqmodel/looper/qqq_processor.py +++ b/gptqmodel/looper/qqq_processor.py @@ -20,7 +20,6 @@ from ..quantization.qqq import QQQ from ..utils.logger import setup_logger from ..utils.model import create_quant_module, find_modules, move_to, pack_model, pack_module -from ..utils.offload import undo_offload_to_disk from ..utils.torch import CPU, DEVICE_0, torch_streamCtx, torch_sync log = setup_logger() @@ -249,10 +248,6 @@ def finalize(self, model: BaseQModel, **kwargs): if self.stream: torch_sync() - model.model = undo_offload_to_disk(module=model.model, include_buffers=True, delete_offload_folders=True) - # print("finalize") - # print_module_tree(model.model) - # set quantized state model.quantized = True diff --git a/gptqmodel/models/writer.py b/gptqmodel/models/writer.py index d7ccb9e7e..8f4dff152 100644 --- a/gptqmodel/models/writer.py +++ b/gptqmodel/models/writer.py @@ -6,19 +6,15 @@ from __future__ import annotations import copy -import csv import json import os import re from os.path import isfile, join -from typing import Dict, Optional, Union +import csv +from typing import Any, Dict, Optional, Union import torch import transformers -from huggingface_hub import split_torch_state_dict_into_shards -from huggingface_hub.constants import 
SAFETENSORS_WEIGHTS_FILE_PATTERN -from safetensors.torch import save_file -from safetensors.torch import save_file as safe_save from transformers import AutoConfig, PreTrainedTokenizerFast, ProcessorMixin from transformers.modeling_utils import no_init_weights from transformers.models.auto.tokenization_auto import get_tokenizer_config @@ -52,6 +48,7 @@ get_state_dict_for_save, load_checkpoint_in_model_then_tie_weights, make_quant, + streaming_state_dict_to_shards, ) from ..utils.structure import alias_all_from_turtle_if_meta from ..utils.torch import torch_empty_cache @@ -280,124 +277,102 @@ def debug_saved_config(path): if not self.load_quantized_model: alias_all_from_turtle_if_meta(shell_model=model, turtle_model=self.turtle_model) - state_dict = get_state_dict_for_save(model) + offload_root = self.quantize_config.offload_to_disk_path if getattr(self.quantize_config, "offload_to_disk", False) else None + state_dict = get_state_dict_for_save(model, offload_root=offload_root) model_base_name = "model" - - state_dict = {k: v.clone().contiguous() for k, v in state_dict.items()} model_save_name = model_base_name + ".safetensors" if not self.qlinear_kernel.SUPPORTS_SHARDS and max_shard_size is not None: log.warn("Sharding is not supported for this quant. Disabling sharding.") max_shard_size = None - if max_shard_size is None: - if safetensors_metadata is None: - safetensors_metadata = {} - elif not isinstance(safetensors_metadata, dict): + def _parse_max_shard_size(value: Optional[Union[int, str]]) -> Optional[int]: + if value is None: + return None + if isinstance(value, int): + return value + match = re.fullmatch(r"\s*(\d+)([KMGTP]?B?)\s*", value, re.IGNORECASE) + if not match: + raise ValueError(f"Invalid max_shard_size value: {value}") + base = int(match.group(1)) + suffix = match.group(2).upper() + multiplier = 1 + if suffix.startswith("K"): + multiplier = 1024 + elif suffix.startswith("M"): + multiplier = 1024 ** 2 + elif suffix.startswith("G"): + multiplier = 1024 ** 3 + elif suffix.startswith("T"): + multiplier = 1024 ** 4 + elif suffix.startswith("P"): + multiplier = 1024 ** 5 + return base * multiplier + + def _normalize_metadata(meta: Optional[Dict[str, Any]]) -> Dict[str, str]: + if meta is None: + return {} + if not isinstance(meta, dict): raise TypeError("safetensors_metadata must be a dictionary.") - else: - log.debug(f"Received safetensors_metadata: {safetensors_metadata}") - new_safetensors_metadata = {} - converted_keys = False - for key, value in safetensors_metadata.items(): - if not isinstance(key, str) or not isinstance(value, str): - converted_keys = True - try: - new_key = str(key) - new_value = str(value) - except Exception as e: - raise TypeError( - f"safetensors_metadata: both keys and values must be strings and an error occured when trying to convert them: {e}" - ) - if new_key in new_safetensors_metadata: - log.warn( - f"After converting safetensors_metadata keys to strings, the key '{new_key}' is duplicated. Ensure that all your metadata keys are strings to avoid overwriting." - ) - new_safetensors_metadata[new_key] = new_value - safetensors_metadata = new_safetensors_metadata - if converted_keys: - log.debug( - f"One or more safetensors_metadata keys or values had to be converted to str(). 
Final safetensors_metadata: {safetensors_metadata}" + normalized: Dict[str, str] = {} + for key, value in meta.items(): + try: + new_key = str(key) + new_value = str(value) + except Exception as exc: + raise TypeError( + f"safetensors_metadata: both keys and values must be strings and conversion failed for ({key}, {value}): {exc}" + ) + if new_key in normalized: + log.warn( + f"Duplicate metadata key '{new_key}' after conversion to string; overwriting previous value." ) + normalized[new_key] = new_value + return normalized + + max_shard_size_bytes = _parse_max_shard_size(max_shard_size) + metadata_dict = _normalize_metadata(safetensors_metadata) + metadata_dict["format"] = "pt" + + expected_files, tensor_to_filename, total_size_bytes = streaming_state_dict_to_shards( + state_dict, + save_dir=save_dir, + model_base_name=model_base_name, + single_file_name=model_save_name, + metadata=metadata_dict, + max_shard_size=max_shard_size_bytes, + ) - # Format is required to enable Accelerate to load the metadata - # otherwise it raises an OSError - safetensors_metadata["format"] = "pt" - safe_save(state_dict, join(save_dir, model_save_name), safetensors_metadata) - total_size_mb = os.path.getsize(join(save_dir, model_save_name)) / (1024 * 1024) + pattern = re.compile(rf"{re.escape(model_base_name)}-\d{{5}}-of-\d{{5}}\.safetensors") + for filename in os.listdir(save_dir): + full_filename = join(save_dir, filename) + if not isfile(full_filename): + continue + if filename == model_save_name and filename not in expected_files: + os.remove(full_filename) + continue + if pattern.fullmatch(filename) and filename not in expected_files: + os.remove(full_filename) + + total_size_mb = total_size_bytes / (1024 * 1024) + + if len(expected_files) > 1: + index = { + "metadata": {"total_size": total_size_bytes}, + "weight_map": tensor_to_filename, + } + index_save_name = model_save_name + ".index.json" + index_save_path = join(save_dir, index_save_name) + with open(index_save_path, "w", encoding="utf-8") as f: + content = json.dumps(index, indent=2, sort_keys=True) + "\n" + f.write(content) else: - file_name_pattern = SAFETENSORS_WEIGHTS_FILE_PATTERN + index_save_path = join(save_dir, model_save_name + ".index.json") + if os.path.exists(index_save_path): + os.remove(index_save_path) - # Shard checkpoint - state_dict_split= split_torch_state_dict_into_shards(state_dict, max_shard_size=max_shard_size, filename_pattern=file_name_pattern) - - # Clean the folder from a previous save - for filename in os.listdir(save_dir): - full_filename = join(save_dir, filename) - - # make sure that file to be deleted matches format of sharded file, e.g. 
pytorch_model-00001-of-00005 - filename_no_suffix = filename.replace(".bin", "").replace(".safetensors", "") - reg = re.compile(r"(.*?)-\d{5}-of-\d{5}") - - if ( - filename.startswith(model_base_name) - and isfile(full_filename) - and filename not in state_dict_split.filename_to_tensors.keys() - and reg.fullmatch(filename_no_suffix) is not None - ): - os.remove(full_filename) - - total_size_mb = 0 - # Save the model - for filename, tensors in state_dict_split.filename_to_tensors.items(): - shard = {tensor: state_dict[tensor] for tensor in tensors} - if safetensors_metadata is None: - safetensors_metadata = {} - elif not isinstance(safetensors_metadata, dict): - raise TypeError("safetensors_metadata must be a dictionary.") - else: - log.debug(f"Received safetensors_metadata: {safetensors_metadata}") - new_safetensors_metadata = {} - converted_keys = False - for key, value in safetensors_metadata.items(): - if not isinstance(key, str) or not isinstance(value, str): - converted_keys = True - try: - new_key = str(key) - new_value = str(value) - except Exception as e: - raise TypeError( - f"safetensors_metadata: both keys and values must be strings and an error occured when trying to convert them: {e}") - if new_key in new_safetensors_metadata: - log.warn( - f"After converting safetensors_metadata keys to strings, the key '{new_key}' is duplicated. Ensure that all your metadata keys are strings to avoid overwriting.") - new_safetensors_metadata[new_key] = new_value - safetensors_metadata = new_safetensors_metadata - if converted_keys: - log.debug( - f"One or more safetensors_metadata keys or values had to be converted to str(). Final safetensors_metadata: {safetensors_metadata}") - - # Format is required to enable Accelerate to load the metadata - # otherwise it raises an OSError - safetensors_metadata["format"] = "pt" - - safe_save(shard, join(save_dir, filename), safetensors_metadata) - shard_size_mb = os.path.getsize(join(save_dir, filename)) / (1024 * 1024) - total_size_mb += shard_size_mb - - if state_dict_split.is_sharded: - index = { - "metadata": state_dict_split.metadata, - "weight_map": state_dict_split.tensor_to_filename, - } - - index_save_name = model_save_name + ".index.json" - index_save_path = join(save_dir, index_save_name) - # Save the index as well - with open(index_save_path, "w", encoding="utf-8") as f: - content = json.dumps(index, indent=2, sort_keys=True) + "\n" - f.write(content) + state_dict.clear() # save lora if self.quantize_config.adapter: diff --git a/gptqmodel/utils/model.py b/gptqmodel/utils/model.py index a2fb302a8..8ea149804 100644 --- a/gptqmodel/utils/model.py +++ b/gptqmodel/utils/model.py @@ -8,13 +8,16 @@ import collections import functools import json +import math import operator import os import re import shutil +import struct import threading import time from concurrent.futures import ThreadPoolExecutor +from dataclasses import dataclass from enum import Enum from typing import Any, Dict, List, Optional, Tuple, Type, Union @@ -23,6 +26,7 @@ import torch import torch.nn as nn import transformers +from safetensors import safe_open from huggingface_hub import HfApi, hf_hub_download from packaging import version from torch.nn.modules.conv import _ConvNd @@ -58,6 +62,85 @@ log = setup_logger() + +_DTYPE_SAFE_MAP = { + torch.float32: ("F32", 4), + torch.float16: ("F16", 2), + torch.float64: ("F64", 8), + torch.bfloat16: ("BF16", 2), + torch.int64: ("I64", 8), + torch.int32: ("I32", 4), + torch.int16: ("I16", 2), + torch.int8: ("I8", 1), + torch.uint8: ("U8", 
1), + torch.bool: ("BOOL", 1), +} + + +_DTYPE_STR_MAP = { + "float32": torch.float32, + "float": torch.float32, + "float16": torch.float16, + "half": torch.float16, + "float64": torch.float64, + "double": torch.float64, + "bfloat16": torch.bfloat16, + "int64": torch.int64, + "long": torch.int64, + "int32": torch.int32, + "int": torch.int32, + "int16": torch.int16, + "short": torch.int16, + "int8": torch.int8, + "uint8": torch.uint8, + "bool": torch.bool, +} + + +def _torch_dtype_num_bytes(dtype: torch.dtype) -> int: + if dtype not in _DTYPE_SAFE_MAP: + raise NotImplementedError(f"Unsupported dtype for safetensors export: {dtype}") + return _DTYPE_SAFE_MAP[dtype][1] + + +def _torch_dtype_to_safetensors(dtype: torch.dtype) -> str: + if dtype not in _DTYPE_SAFE_MAP: + raise NotImplementedError(f"Unsupported dtype for safetensors export: {dtype}") + return _DTYPE_SAFE_MAP[dtype][0] + + +def _dtype_string_to_torch(dtype_str: Optional[str], fallback: torch.dtype) -> torch.dtype: + if dtype_str is None: + return fallback + key = dtype_str.lower() + return _DTYPE_STR_MAP.get(key, fallback) + + +@dataclass(frozen=True) +class OffloadTensorRef: + path: str + torch_dtype: torch.dtype + shape: Tuple[int, ...] + format: str # 'dat' or 'safetensors' + weight_name: Optional[str] = None + data_offsets: Optional[Tuple[int, int]] = None + + @property + def num_bytes(self) -> int: + return _torch_dtype_num_bytes(self.torch_dtype) * math.prod(self.shape or (1,)) + + +@dataclass +class TensorSource: + name: str + torch_dtype: torch.dtype + shape: Tuple[int, ...] + source: Union[torch.Tensor, OffloadTensorRef] + + @property + def num_bytes(self) -> int: + return _torch_dtype_num_bytes(self.torch_dtype) * math.prod(self.shape or (1,)) + def recurse_getattr(obj, attr: str): """ Recursive `getattr`. @@ -1073,29 +1156,164 @@ class MODALITY(str, Enum): # TEXT_TO_IMAGE = "text_to_image" -def get_state_dict_for_save(model: nn.Module) -> Dict: +def _split_parameter_path(full_name: str) -> Tuple[str, str]: + if "." 
in full_name: + module_path, leaf = full_name.rsplit(".", 1) + else: + module_path, leaf = "", full_name + return module_path, leaf + + +def _resolve_offload_entry( + offload_root: str, + module_path: str, + leaf: str, + dtype: torch.dtype, + shape_hint: Tuple[int, ...], + index_cache: Dict[str, Optional[Dict]], +) -> Optional[OffloadTensorRef]: + if not offload_root: + return None + + module_dir = os.path.join(offload_root, module_path) if module_path else offload_root + index = index_cache.get(module_dir) + if index is None: + index_path = os.path.join(module_dir, "index.json") + if not os.path.isfile(index_path): + index_cache[module_dir] = None + return None + with open(index_path, "r", encoding="utf-8") as fh: + index = json.load(fh) + index_cache[module_dir] = index + + if not index: + return None + + entry = index.get(leaf) or index.get(f"{module_path}.{leaf}") + if entry is None: + return None + + resolved_dtype = _dtype_string_to_torch(entry.get("dtype"), dtype) + if "shape" in entry: + shape = tuple(entry["shape"]) + else: + shape = shape_hint + + safetensors_file = entry.get("safetensors_file") + if safetensors_file: + path = safetensors_file + if not os.path.isabs(path): + path = os.path.join(module_dir, path) + offsets = entry.get("data_offsets") + if offsets is not None: + offsets = tuple(int(x) for x in offsets) + return OffloadTensorRef( + path=os.path.abspath(path), + torch_dtype=resolved_dtype, + shape=shape, + format="safetensors", + weight_name=entry.get("weight_name", leaf), + data_offsets=offsets, + ) + + data_path = os.path.join(module_dir, f"{leaf}.dat") + if not os.path.isfile(data_path): + return None + + return OffloadTensorRef( + path=os.path.abspath(data_path), + torch_dtype=resolved_dtype, + shape=shape, + format="dat", + weight_name=None, + data_offsets=None, + ) + + +def _collect_state_dict_with_offload(model: nn.Module, offload_root: str) -> Dict[str, TensorSource]: + state_dict: Dict[str, TensorSource] = collections.OrderedDict() + index_cache: Dict[str, Optional[Dict]] = {} + + for name, param in model.named_parameters(): + module_path, leaf = _split_parameter_path(name) + source = None + if getattr(param, "is_meta", False) or param.device.type == "meta": + source = _resolve_offload_entry( + offload_root, + module_path, + leaf, + param.dtype, + tuple(param.shape), + index_cache, + ) + if source is None: + raise FileNotFoundError( + f"Offloaded tensor '{name}' not found in offload directory '{offload_root}'." + ) + else: + source = param + state_dict[name] = TensorSource(name=name, torch_dtype=param.dtype, shape=tuple(param.shape), source=source) + + for name, buf in model.named_buffers(): + if name in state_dict: + continue + module_path, leaf = _split_parameter_path(name) + if getattr(buf, "is_meta", False) or buf.device.type == "meta": + source = _resolve_offload_entry( + offload_root, + module_path, + leaf, + buf.dtype, + tuple(buf.shape), + index_cache, + ) + if source is None: + raise FileNotFoundError( + f"Offloaded buffer '{name}' not found in offload directory '{offload_root}'." + ) + else: + source = buf + state_dict[name] = TensorSource(name=name, torch_dtype=buf.dtype, shape=tuple(buf.shape), source=source) + + return state_dict + + +def get_state_dict_for_save(model: nn.Module, offload_root: Optional[str] = None) -> Dict[str, TensorSource]: """ Filter weight-sharing tensors. Referenced from transformers.modeling_utils.PreTrainedModel.save_pretrained. 
See https://github.com/huggingface/transformers/blob/v4.38.2/src/transformers/modeling_utils.py#L2369 """ + if offload_root: + state_dict = _collect_state_dict_with_offload(model, offload_root) + else: + state_dict = collections.OrderedDict() + for name, param in model.named_parameters(): + state_dict[name] = TensorSource(name=name, torch_dtype=param.dtype, shape=tuple(param.shape), source=param) + for name, buf in model.named_buffers(): + if name in state_dict: + continue + state_dict[name] = TensorSource(name=name, torch_dtype=buf.dtype, shape=tuple(buf.shape), source=buf) - state_dict = model.state_dict() - - # Safetensors does not allow tensor aliasing. - # We're going to remove aliases before saving ptrs = collections.defaultdict(list) - for name, tensor in state_dict.items(): - # Sometimes in the state_dict we have non-tensor objects. - # e.g. in bitsandbytes we have some `str` objects in the state_dict - if isinstance(tensor, torch.Tensor): - ptrs[id_tensor_storage(tensor)].append(name) + for name, entry in state_dict.items(): + source = entry.source + if isinstance(source, OffloadTensorRef): + key = ("offload", source.path, source.weight_name or name, source.data_offsets) + elif isinstance(source, torch.Tensor): + tensor = source + if getattr(tensor, "is_meta", False) or tensor.device.type == "meta": + key = ("meta", id(tensor)) + else: + try: + key = ("storage", id_tensor_storage(tensor)) + except Exception: + key = ("tensor", id(tensor)) else: - # In the non-tensor case, fall back to the pointer of the object itself - ptrs[id(tensor)].append(name) + key = ("other", id(source)) + ptrs[key].append(name) - # These are all the pointers of shared tensors. shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1} warn_names = set() for names in shared_ptrs.values(): @@ -1134,6 +1352,138 @@ def load_checkpoint_in_model_then_tie_weights(model, *args, **kwargs): model.tie_weights() +_STREAM_BUFFER_BYTES = 128 * 1024 * 1024 + + +def _copy_file_stream(src_path: str, dst_fh, length: int, *, offset: int = 0, buffer_size: int = _STREAM_BUFFER_BYTES) -> None: + with open(src_path, "rb") as src: + if offset: + src.seek(offset) + remaining = length + while remaining > 0: + chunk = src.read(min(buffer_size, remaining)) + if not chunk: + raise IOError(f"Unexpected EOF while copying from {src_path}") + dst_fh.write(chunk) + remaining -= len(chunk) + + +def _write_tensor_bytes(out, tensor: torch.Tensor, dtype: torch.dtype) -> None: + if tensor.device.type != "cpu": + tensor = tensor.to("cpu") + if not tensor.is_contiguous(): + tensor = tensor.contiguous() + if dtype is torch.bfloat16: + view = tensor.view(torch.int16) + out.write(view.numpy().tobytes()) + else: + out.write(tensor.numpy().tobytes()) + + +def _write_shard_file(path: str, entries: List[TensorSource], metadata: Dict[str, str]) -> int: + header: Dict[str, Any] = {} + if metadata: + header["__metadata__"] = metadata + + offset = 0 + for entry in entries: + header[entry.name] = { + "dtype": _torch_dtype_to_safetensors(entry.torch_dtype), + "shape": list(entry.shape), + "data_offsets": [offset, offset + entry.num_bytes], + } + offset += entry.num_bytes + + header_bytes = json.dumps(header, separators=(",", ":")).encode("utf-8") + + with open(path, "wb") as out: + out.write(struct.pack(" List[List[TensorSource]]: + if not max_shard_size or max_shard_size <= 0: + return [entries] + + shards: List[List[TensorSource]] = [] + current: List[TensorSource] = [] + current_size = 0 + + for entry in entries: + size = entry.num_bytes + 
if size > max_shard_size: + if current: + shards.append(current) + current = [] + current_size = 0 + shards.append([entry]) + continue + if current_size + size > max_shard_size and current: + shards.append(current) + current = [] + current_size = 0 + current.append(entry) + current_size += size + + if current: + shards.append(current) + + return shards + + +def streaming_state_dict_to_shards( + state_dict: Dict[str, TensorSource], + save_dir: str, + model_base_name: str, + single_file_name: str, + metadata: Dict[str, str], + max_shard_size: Optional[int], +) -> Tuple[List[str], Dict[str, str], int]: + entries = list(state_dict.values()) + shards = _plan_shards(entries, max_shard_size) + num_shards = len(shards) + filenames: List[str] = [] + tensor_to_filename: Dict[str, str] = {} + total_size = 0 + + for idx, shard_entries in enumerate(shards, start=1): + if num_shards == 1: + filename = single_file_name + else: + filename = f"{model_base_name}-{idx:05d}-of-{num_shards:05d}.safetensors" + + path = os.path.join(save_dir, filename) + size = _write_shard_file(path, shard_entries, metadata) + total_size += size + filenames.append(filename) + for entry in shard_entries: + tensor_to_filename[entry.name] = filename + + return filenames, tensor_to_filename, total_size + + def find_config_seq_len(config_dict, target_keys): for k, v in config_dict.items(): if k in target_keys: From 74becd23a79920cc93b068b6e46229ce45532e06 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Tue, 30 Sep 2025 09:13:51 +0000 Subject: [PATCH 2/4] format Signed-off-by: Qubitium --- gptqmodel/models/writer.py | 3 ++- gptqmodel/utils/model.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/gptqmodel/models/writer.py b/gptqmodel/models/writer.py index 8f4dff152..597a34d24 100644 --- a/gptqmodel/models/writer.py +++ b/gptqmodel/models/writer.py @@ -6,15 +6,16 @@ from __future__ import annotations import copy +import csv import json import os import re from os.path import isfile, join -import csv from typing import Any, Dict, Optional, Union import torch import transformers +from safetensors.torch import save_file from transformers import AutoConfig, PreTrainedTokenizerFast, ProcessorMixin from transformers.modeling_utils import no_init_weights from transformers.models.auto.tokenization_auto import get_tokenizer_config diff --git a/gptqmodel/utils/model.py b/gptqmodel/utils/model.py index 8ea149804..d8c0dc5af 100644 --- a/gptqmodel/utils/model.py +++ b/gptqmodel/utils/model.py @@ -26,9 +26,9 @@ import torch import torch.nn as nn import transformers -from safetensors import safe_open from huggingface_hub import HfApi, hf_hub_download from packaging import version +from safetensors import safe_open from torch.nn.modules.conv import _ConvNd from transformers import PretrainedConfig from transformers.pytorch_utils import id_tensor_storage From 991879d7e477453618577cae82e2de0a5567c1da Mon Sep 17 00:00:00 2001 From: Qubitium Date: Tue, 30 Sep 2025 09:27:50 +0000 Subject: [PATCH 3/4] clean Signed-off-by: Qubitium --- gptqmodel/utils/model.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/gptqmodel/utils/model.py b/gptqmodel/utils/model.py index d8c0dc5af..44ca2e671 100644 --- a/gptqmodel/utils/model.py +++ b/gptqmodel/utils/model.py @@ -1352,7 +1352,8 @@ def load_checkpoint_in_model_then_tie_weights(model, *args, **kwargs): model.tie_weights() -_STREAM_BUFFER_BYTES = 128 * 1024 * 1024 +# 64MB for io transfer buffer +_STREAM_BUFFER_BYTES = 64 * 1024 * 1024 def _copy_file_stream(src_path: str, 
dst_fh, length: int, *, offset: int = 0, buffer_size: int = _STREAM_BUFFER_BYTES) -> None:
@@ -1369,10 +1370,7 @@ def _copy_file_stream(src_path: str, dst_fh, length: int, *, offset: int = 0, bu
 
 
 def _write_tensor_bytes(out, tensor: torch.Tensor, dtype: torch.dtype) -> None:
-    if tensor.device.type != "cpu":
-        tensor = tensor.to("cpu")
-    if not tensor.is_contiguous():
-        tensor = tensor.contiguous()
+    tensor = tensor.detach().to("cpu").contiguous()
     if dtype is torch.bfloat16:
         view = tensor.view(torch.int16)
         out.write(view.numpy().tobytes())
@@ -1403,6 +1401,7 @@ def _write_shard_file(path: str, entries: List[TensorSource], metadata: Dict[str
     for entry in entries:
         source = entry.source
         if isinstance(source, OffloadTensorRef):
+            print("offload tensor io buffered transfer")
             if source.format == "dat":
                 _copy_file_stream(source.path, out, entry.num_bytes)
             elif source.format == "safetensors" and source.data_offsets is not None:
From ccb79a70601846ceccaa1b06994613403b42927e Mon Sep 17 00:00:00 2001
From: Qubitium
Date: Tue, 30 Sep 2025 09:51:52 +0000
Subject: [PATCH 4/4] reuse fully fixed/static buffer

Signed-off-by: Qubitium
---
 gptqmodel/utils/model.py | 37 +++++++++++++++++++++----------------
 1 file changed, 21 insertions(+), 16 deletions(-)

diff --git a/gptqmodel/utils/model.py b/gptqmodel/utils/model.py
index 44ca2e671..c8b10ea4d 100644
--- a/gptqmodel/utils/model.py
+++ b/gptqmodel/utils/model.py
@@ -1352,21 +1352,24 @@ def load_checkpoint_in_model_then_tie_weights(model, *args, **kwargs):
     model.tie_weights()
 
 
-# 64MB for io transfer buffer
-_STREAM_BUFFER_BYTES = 64 * 1024 * 1024
-
-
-def _copy_file_stream(src_path: str, dst_fh, length: int, *, offset: int = 0, buffer_size: int = _STREAM_BUFFER_BYTES) -> None:
-    with open(src_path, "rb") as src:
-        if offset:
-            src.seek(offset)
-        remaining = length
-        while remaining > 0:
-            chunk = src.read(min(buffer_size, remaining))
-            if not chunk:
-                raise IOError(f"Unexpected EOF while copying from {src_path}")
-            dst_fh.write(chunk)
-            remaining -= len(chunk)
+# 32MB read/write i/o buffer
+_STREAM_BUFFER_SIZE = 32 * 1024 * 1024
+_STREAM_BUFFER = memoryview(bytearray(_STREAM_BUFFER_SIZE))
+_STREAM_BUFFER_LOCK = threading.Lock()
+
+def _copy_file_stream(src_path: str, dst_fh, length: int, *, offset: int = 0) -> None:
+    with open(src_path, "rb", buffering=0) as src:
+        with _STREAM_BUFFER_LOCK:
+            if offset:
+                src.seek(offset)
+            remaining = length
+            while remaining > 0:
+                chunk_size = min(_STREAM_BUFFER_SIZE, remaining)
+                read = src.readinto(_STREAM_BUFFER[:chunk_size])
+                if not read:
+                    raise IOError(f"Unexpected EOF while copying from {src_path}")
+                dst_fh.write(_STREAM_BUFFER[:read])
+                remaining -= read
 
 
 def _write_tensor_bytes(out, tensor: torch.Tensor, dtype: torch.dtype) -> None:
@@ -1401,13 +1404,15 @@ def _write_shard_file(path: str, entries: List[TensorSource], metadata: Dict[str
     for entry in entries:
         source = entry.source
         if isinstance(source, OffloadTensorRef):
-            print("offload tensor io buffered transfer")
             if source.format == "dat":
+                # print("offload tensor io buffered transfer DAT")
                 _copy_file_stream(source.path, out, entry.num_bytes)
             elif source.format == "safetensors" and source.data_offsets is not None:
+                # print("offload tensor io buffered transfer SAFETENSOR stream")
                 start, end = source.data_offsets
                 _copy_file_stream(source.path, out, end - start, offset=start)
             else:
+                # print("offload tensor slow tensor read")
                 with safe_open(source.path, framework="pt", device="cpu") as handler:
                     tensor = handler.get_tensor(source.weight_name or entry.name)
tensor = tensor.to(source.torch_dtype)
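
A minimal sketch for spot-checking a shard written by this series, assuming the standard safetensors framing that _write_shard_file emits: an 8-byte little-endian header length, then a JSON header mapping each tensor name to its dtype, shape and data_offsets, plus the __metadata__ block the writer fills with {"format": "pt"}. The helper name read_safetensors_header is illustrative only; model.safetensors is the single-file name used in writer.py above.

import json
import struct

def read_safetensors_header(path: str) -> dict:
    # Illustrative helper: read the 8-byte little-endian header length,
    # then parse the JSON header that _write_shard_file serialized.
    with open(path, "rb") as fh:
        (header_len,) = struct.unpack("<Q", fh.read(8))
        return json.loads(fh.read(header_len).decode("utf-8"))

header = read_safetensors_header("model.safetensors")
print(header.get("__metadata__"))                      # e.g. {"format": "pt"}
print([k for k in header if k != "__metadata__"][:3])  # first few tensor names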