3 changes: 0 additions & 3 deletions gptqmodel/looper/awq_processor.py
@@ -29,7 +29,6 @@
from ..quantization.config import FORMAT, METHOD, QuantizeConfig
from ..utils.logger import setup_logger
from ..utils.model import get_module_by_name_prefix, move_to
from ..utils.offload import undo_offload_to_disk
from ..utils.torch import CPU, torch_sync

log = setup_logger()
@@ -779,8 +778,6 @@ def finalize(self, model: BaseQModel, **kwargs):
if self.stream:
torch_sync()

model.model = undo_offload_to_disk(module=model.model, include_buffers=True, delete_offload_folders=True)

if model.quantize_config.format == FORMAT.GEMM:
model.qlinear_kernel = AwqGEMMQuantLinear
elif model.quantize_config.format == FORMAT.GEMV:
3 changes: 0 additions & 3 deletions gptqmodel/looper/gptq_processor.py
@@ -21,7 +21,6 @@
from ..utils.importer import select_quant_linear
from ..utils.logger import setup_logger
from ..utils.model import create_quant_module, find_modules, move_to, pack_model, pack_module
from ..utils.offload import undo_offload_to_disk
from ..utils.torch import HAS_CUDA, torch_streamCtx, torch_sync

log = setup_logger()
@@ -265,8 +264,6 @@ def finalize(self, model: BaseQModel, **kwargs):
if self.stream:
torch_sync()

model.model = undo_offload_to_disk(module=model.model, include_buffers=True, delete_offload_folders=True)

# print("finalize")
# print_module_tree(model.model)

5 changes: 0 additions & 5 deletions gptqmodel/looper/qqq_processor.py
@@ -20,7 +20,6 @@
from ..quantization.qqq import QQQ
from ..utils.logger import setup_logger
from ..utils.model import create_quant_module, find_modules, move_to, pack_model, pack_module
from ..utils.offload import undo_offload_to_disk
from ..utils.torch import CPU, DEVICE_0, torch_streamCtx, torch_sync

log = setup_logger()
@@ -249,10 +248,6 @@ def finalize(self, model: BaseQModel, **kwargs):
if self.stream:
torch_sync()

model.model = undo_offload_to_disk(module=model.model, include_buffers=True, delete_offload_folders=True)
# print("finalize")
# print_module_tree(model.model)

# set quantized state
model.quantized = True

194 changes: 85 additions & 109 deletions gptqmodel/models/writer.py
@@ -11,14 +11,11 @@
import os
import re
from os.path import isfile, join
from typing import Dict, Optional, Union
from typing import Any, Dict, Optional, Union

import torch
import transformers
from huggingface_hub import split_torch_state_dict_into_shards
from huggingface_hub.constants import SAFETENSORS_WEIGHTS_FILE_PATTERN
from safetensors.torch import save_file
from safetensors.torch import save_file as safe_save
from transformers import AutoConfig, PreTrainedTokenizerFast, ProcessorMixin
from transformers.modeling_utils import no_init_weights
from transformers.models.auto.tokenization_auto import get_tokenizer_config
@@ -52,6 +49,7 @@
get_state_dict_for_save,
load_checkpoint_in_model_then_tie_weights,
make_quant,
streaming_state_dict_to_shards,
)
from ..utils.structure import alias_all_from_turtle_if_meta
from ..utils.torch import torch_empty_cache
@@ -280,124 +278,102 @@ def debug_saved_config(path):
if not self.load_quantized_model:
alias_all_from_turtle_if_meta(shell_model=model, turtle_model=self.turtle_model)

state_dict = get_state_dict_for_save(model)
offload_root = self.quantize_config.offload_to_disk_path if getattr(self.quantize_config, "offload_to_disk", False) else None
state_dict = get_state_dict_for_save(model, offload_root=offload_root)

model_base_name = "model"

state_dict = {k: v.clone().contiguous() for k, v in state_dict.items()}
model_save_name = model_base_name + ".safetensors"

if not self.qlinear_kernel.SUPPORTS_SHARDS and max_shard_size is not None:
log.warn("Sharding is not supported for this quant. Disabling sharding.")
max_shard_size = None

if max_shard_size is None:
if safetensors_metadata is None:
safetensors_metadata = {}
elif not isinstance(safetensors_metadata, dict):
def _parse_max_shard_size(value: Optional[Union[int, str]]) -> Optional[int]:
if value is None:
return None
if isinstance(value, int):
return value
match = re.fullmatch(r"\s*(\d+)([KMGTP]?B?)\s*", value, re.IGNORECASE)
if not match:
raise ValueError(f"Invalid max_shard_size value: {value}")
base = int(match.group(1))
suffix = match.group(2).upper()
multiplier = 1
if suffix.startswith("K"):
multiplier = 1024
elif suffix.startswith("M"):
multiplier = 1024 ** 2
elif suffix.startswith("G"):
multiplier = 1024 ** 3
elif suffix.startswith("T"):
multiplier = 1024 ** 4
elif suffix.startswith("P"):
multiplier = 1024 ** 5
return base * multiplier
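
Note (editor, not part of the diff): a couple of illustrative calls showing how this helper resolves human-readable shard sizes; the inputs below are examples, not values used by the PR.

# Illustrative only -- not part of this diff.
assert _parse_max_shard_size(None) is None
assert _parse_max_shard_size("500MB") == 500 * 1024 ** 2
assert _parse_max_shard_size("4GB") == 4 * 1024 ** 3
assert _parse_max_shard_size(2_000_000) == 2_000_000  # ints pass through unchanged
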

def _normalize_metadata(meta: Optional[Dict[str, Any]]) -> Dict[str, str]:
if meta is None:
return {}
if not isinstance(meta, dict):
raise TypeError("safetensors_metadata must be a dictionary.")
else:
log.debug(f"Received safetensors_metadata: {safetensors_metadata}")
new_safetensors_metadata = {}
converted_keys = False
for key, value in safetensors_metadata.items():
if not isinstance(key, str) or not isinstance(value, str):
converted_keys = True
try:
new_key = str(key)
new_value = str(value)
except Exception as e:
raise TypeError(
f"safetensors_metadata: both keys and values must be strings and an error occured when trying to convert them: {e}"
)
if new_key in new_safetensors_metadata:
log.warn(
f"After converting safetensors_metadata keys to strings, the key '{new_key}' is duplicated. Ensure that all your metadata keys are strings to avoid overwriting."
)
new_safetensors_metadata[new_key] = new_value
safetensors_metadata = new_safetensors_metadata
if converted_keys:
log.debug(
f"One or more safetensors_metadata keys or values had to be converted to str(). Final safetensors_metadata: {safetensors_metadata}"
normalized: Dict[str, str] = {}
for key, value in meta.items():
try:
new_key = str(key)
new_value = str(value)
except Exception as exc:
raise TypeError(
f"safetensors_metadata: both keys and values must be strings and conversion failed for ({key}, {value}): {exc}"
)
if new_key in normalized:
log.warn(
f"Duplicate metadata key '{new_key}' after conversion to string; overwriting previous value."
)
normalized[new_key] = new_value
return normalized
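
Note (editor, not part of the diff): behaviorally, non-string keys and values are coerced with str() and duplicates after coercion trigger a warning; an illustrative example:

# Illustrative only.
assert _normalize_metadata(None) == {}
assert _normalize_metadata({"quant": "gptq", 4: True}) == {"quant": "gptq", "4": "True"}
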

max_shard_size_bytes = _parse_max_shard_size(max_shard_size)
metadata_dict = _normalize_metadata(safetensors_metadata)
metadata_dict["format"] = "pt"

expected_files, tensor_to_filename, total_size_bytes = streaming_state_dict_to_shards(
state_dict,
save_dir=save_dir,
model_base_name=model_base_name,
single_file_name=model_save_name,
metadata=metadata_dict,
max_shard_size=max_shard_size_bytes,
)
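
Note (editor, not part of the diff): the heavy lifting moves into streaming_state_dict_to_shards, imported above from ..utils.model but not shown in this diff. The following is a minimal sketch of the behavior implied by the call site and the stale-file cleanup below -- grouping tensors by a byte budget, writing safetensors shards, and returning (expected_files, tensor_to_filename, total_size_bytes) -- not the actual helper.

# Minimal sketch only; the real helper in gptqmodel/utils/model.py may differ.
import os
from typing import Dict, List, Optional, Tuple

import torch
from safetensors.torch import save_file


def streaming_state_dict_to_shards_sketch(
    state_dict: Dict[str, torch.Tensor],
    save_dir: str,
    model_base_name: str,
    single_file_name: str,
    metadata: Dict[str, str],
    max_shard_size: Optional[int],
) -> Tuple[List[str], Dict[str, str], int]:
    # Group tensors into shards that stay under max_shard_size bytes (None means one shard).
    shards: List[Dict[str, torch.Tensor]] = []
    current: Dict[str, torch.Tensor] = {}
    current_bytes = 0
    for name, tensor in state_dict.items():
        nbytes = tensor.numel() * tensor.element_size()
        if max_shard_size is not None and current and current_bytes + nbytes > max_shard_size:
            shards.append(current)
            current, current_bytes = {}, 0
        current[name] = tensor
        current_bytes += nbytes
    if current:
        shards.append(current)

    expected_files: List[str] = []
    tensor_to_filename: Dict[str, str] = {}
    total_size_bytes = 0
    for i, shard in enumerate(shards, start=1):
        # A single shard keeps the plain file name; multiple shards use the
        # -00001-of-0000N pattern that the stale-file cleanup below expects.
        filename = (
            single_file_name
            if len(shards) == 1
            else f"{model_base_name}-{i:05d}-of-{len(shards):05d}.safetensors"
        )
        path = os.path.join(save_dir, filename)
        save_file({k: v.contiguous() for k, v in shard.items()}, path, metadata=metadata)
        total_size_bytes += os.path.getsize(path)
        expected_files.append(filename)
        for name in shard:
            tensor_to_filename[name] = filename
    return expected_files, tensor_to_filename, total_size_bytes
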

# Format is required to enable Accelerate to load the metadata
# otherwise it raises an OSError
safetensors_metadata["format"] = "pt"
safe_save(state_dict, join(save_dir, model_save_name), safetensors_metadata)
total_size_mb = os.path.getsize(join(save_dir, model_save_name)) / (1024 * 1024)
pattern = re.compile(rf"{re.escape(model_base_name)}-\d{{5}}-of-\d{{5}}\.safetensors")
for filename in os.listdir(save_dir):
full_filename = join(save_dir, filename)
if not isfile(full_filename):
continue
if filename == model_save_name and filename not in expected_files:
os.remove(full_filename)
continue
if pattern.fullmatch(filename) and filename not in expected_files:
os.remove(full_filename)

total_size_mb = total_size_bytes / (1024 * 1024)

if len(expected_files) > 1:
index = {
"metadata": {"total_size": total_size_bytes},
"weight_map": tensor_to_filename,
}
index_save_name = model_save_name + ".index.json"
index_save_path = join(save_dir, index_save_name)
with open(index_save_path, "w", encoding="utf-8") as f:
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
f.write(content)
else:
file_name_pattern = SAFETENSORS_WEIGHTS_FILE_PATTERN

# Shard checkpoint
state_dict_split= split_torch_state_dict_into_shards(state_dict, max_shard_size=max_shard_size, filename_pattern=file_name_pattern)
index_save_path = join(save_dir, model_save_name + ".index.json")
if os.path.exists(index_save_path):
os.remove(index_save_path)

# Clean the folder from a previous save
for filename in os.listdir(save_dir):
full_filename = join(save_dir, filename)

# make sure that file to be deleted matches format of sharded file, e.g. pytorch_model-00001-of-00005
filename_no_suffix = filename.replace(".bin", "").replace(".safetensors", "")
reg = re.compile(r"(.*?)-\d{5}-of-\d{5}")

if (
filename.startswith(model_base_name)
and isfile(full_filename)
and filename not in state_dict_split.filename_to_tensors.keys()
and reg.fullmatch(filename_no_suffix) is not None
):
os.remove(full_filename)

total_size_mb = 0
# Save the model
for filename, tensors in state_dict_split.filename_to_tensors.items():
shard = {tensor: state_dict[tensor] for tensor in tensors}
if safetensors_metadata is None:
safetensors_metadata = {}
elif not isinstance(safetensors_metadata, dict):
raise TypeError("safetensors_metadata must be a dictionary.")
else:
log.debug(f"Received safetensors_metadata: {safetensors_metadata}")
new_safetensors_metadata = {}
converted_keys = False
for key, value in safetensors_metadata.items():
if not isinstance(key, str) or not isinstance(value, str):
converted_keys = True
try:
new_key = str(key)
new_value = str(value)
except Exception as e:
raise TypeError(
f"safetensors_metadata: both keys and values must be strings and an error occured when trying to convert them: {e}")
if new_key in new_safetensors_metadata:
log.warn(
f"After converting safetensors_metadata keys to strings, the key '{new_key}' is duplicated. Ensure that all your metadata keys are strings to avoid overwriting.")
new_safetensors_metadata[new_key] = new_value
safetensors_metadata = new_safetensors_metadata
if converted_keys:
log.debug(
f"One or more safetensors_metadata keys or values had to be converted to str(). Final safetensors_metadata: {safetensors_metadata}")

# Format is required to enable Accelerate to load the metadata
# otherwise it raises an OSError
safetensors_metadata["format"] = "pt"

safe_save(shard, join(save_dir, filename), safetensors_metadata)
shard_size_mb = os.path.getsize(join(save_dir, filename)) / (1024 * 1024)
total_size_mb += shard_size_mb

if state_dict_split.is_sharded:
index = {
"metadata": state_dict_split.metadata,
"weight_map": state_dict_split.tensor_to_filename,
}

index_save_name = model_save_name + ".index.json"
index_save_path = join(save_dir, index_save_name)
# Save the index as well
with open(index_save_path, "w", encoding="utf-8") as f:
content = json.dumps(index, indent=2, sort_keys=True) + "\n"
f.write(content)
state_dict.clear()

# save lora
if self.quantize_config.adapter: