Add checkpoint_format to unify all current/future saved weight formats (#603)

* Add checkpoint_format to unify all current/future saved weight formats

* some nit

---------

Co-authored-by: fxmarty <9808326+fxmarty@users.noreply.github.com>
Qubitium and fxmarty committed Mar 28, 2024
1 parent ff3dcc4 commit 866b4c8
Showing 9 changed files with 417 additions and 252 deletions.
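Editor's illustration (not part of the commit): a minimal sketch of how the unified field is expected to appear in quantize_config.json, next to the legacy is_marlin_format flag it supersedes. The field names come from auto_gptq/quantization/config.py below; the concrete values are hypothetical.

# Hypothetical config contents; only the field names come from this commit.
legacy_cfg = {"bits": 4, "group_size": 128, "is_marlin_format": True}   # old boolean flag
unified_cfg = {
    "bits": 4,
    "group_size": 128,
    "quant_method": "gptq",         # QUANT_METHOD_FIELD
    "checkpoint_format": "marlin",  # CHECKPOINT_FORMAT_FIELD, replaces is_marlin_format
}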
231 changes: 46 additions & 185 deletions auto_gptq/modeling/_base.py

Large diffs are not rendered by default.

12 changes: 7 additions & 5 deletions auto_gptq/modeling/_utils.py
@@ -93,7 +93,7 @@ def make_quant(
desc_act=desc_act,
group_size=group_size,
bits=bits,
- disable_marlin=not use_marlin,
+ use_marlin=use_marlin,
disable_exllama=disable_exllama,
disable_exllamav2=disable_exllamav2,
use_qigen=use_qigen,
@@ -263,7 +263,8 @@ def pack_model(
desc_act=False,
warmup_triton: bool = False,
force_layer_back_to_cpu: bool = False,
- is_marlin_format: bool = False,
+ use_marlin: bool = False,
+ use_tritonv2: bool = False,
):
QuantLinear = dynamically_import_QuantLinear(
use_triton=use_triton,
@@ -272,7 +273,8 @@
bits=bits,
disable_exllama=False,
disable_exllamav2=True,
- disable_marlin=not is_marlin_format,
+ use_marlin=use_marlin,
+ use_tritonv2=use_tritonv2,
)

if force_layer_back_to_cpu:
@@ -291,7 +293,7 @@
desc_act=desc_act,
disable_exllama=False,
disable_exllamav2=True,
- use_marlin=is_marlin_format,
+ use_marlin=use_marlin,
)
qlayers = find_layers(model, [QuantLinear])

@@ -506,7 +508,7 @@ def autogptq_post_init(model, use_act_order: bool, max_input_length: Optional[in
def make_sure_no_tensor_in_meta_device(
model, use_triton: bool, desc_act: bool, group_size: int, bits: int, disable_exllama: bool, disable_exllamav2: bool, use_marlin: bool = False, use_tritonv2: bool = False,
):
- QuantLinear = dynamically_import_QuantLinear(use_triton, desc_act, group_size, bits=bits, disable_exllama=disable_exllama, disable_exllamav2=disable_exllamav2, disable_marlin=not use_marlin, use_tritonv2=use_tritonv2)
+ QuantLinear = dynamically_import_QuantLinear(use_triton, desc_act, group_size, bits=bits, disable_exllama=disable_exllama, disable_exllamav2=disable_exllamav2, use_marlin=use_marlin, use_tritonv2=use_tritonv2)
for n, m in model.named_modules():
if isinstance(m, QuantLinear) and m.bias.device == torch.device("meta"):
m.register_buffer("bias", torch.zeros((m.outfeatures), dtype=torch.float16, device="cpu"))
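Editor's note: the signature change above implies the following call-site migration (a sketch only; the other pack_model arguments are elided here):

# before this commit:
#   pack_model(..., is_marlin_format=True)
# after this commit:
#   pack_model(..., use_marlin=True)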
9 changes: 9 additions & 0 deletions auto_gptq/quantization/__init__.py
@@ -1,2 +1,11 @@
from .config import (
CHECKPOINT_FORMAT,
CHECKPOINT_FORMAT_FIELD,
CHECKPOINT_FORMAT_FIELD_COMPAT_MARLIN,
QUANT_CONFIG_FILENAME,
QUANT_METHOD,
QUANT_METHOD_FIELD,
BaseQuantizeConfig,
)
from .gptq import GPTQ
from .quantizer import Quantizer, quantize
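A minimal usage sketch (editor's illustration, assuming auto_gptq and its dependencies are installed) of the names this __init__ now re-exports:

from auto_gptq.quantization import CHECKPOINT_FORMAT, QUANT_METHOD, BaseQuantizeConfig

cfg = BaseQuantizeConfig(bits=4, group_size=128)
assert cfg.quant_method == QUANT_METHOD.GPTQ            # defaults added in config.py below
assert cfg.checkpoint_format == CHECKPOINT_FORMAT.GPTQ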
258 changes: 258 additions & 0 deletions auto_gptq/quantization/config.py
@@ -0,0 +1,258 @@
import json
import logging
import os
from dataclasses import dataclass, field, fields
from os.path import isdir, join
from typing import Optional

import huggingface_hub
from transformers.utils.hub import PushToHubMixin, cached_file


logger = logging.getLogger(__name__)
handler = logging.StreamHandler()
formatter = logging.Formatter("%(levelname)s - %(message)s")
handler.setFormatter(formatter)
logger.propagate = False
logger.addHandler(handler)
logger.setLevel(logging.INFO)

CHECKPOINT_FORMAT_FIELD = "checkpoint_format"
CHECKPOINT_FORMAT_FIELD_COMPAT_MARLIN = "is_marlin_format"
QUANT_METHOD_FIELD = "quant_method"
QUANT_CONFIG_FILENAME = "quantize_config.json"


# checkpoint formats
class CHECKPOINT_FORMAT:
GPTQ = "gptq"
MARLIN = "marlin"
AWQ_GEMM = "gemm"


# quant methods
class QUANT_METHOD:
GPTQ = "gptq"
AWQ = "awq"


QUANT_METHOD_FORMAT_MAPPING = {
QUANT_METHOD.GPTQ: {
CHECKPOINT_FORMAT.GPTQ,
CHECKPOINT_FORMAT.MARLIN,
},
QUANT_METHOD.AWQ: {
CHECKPOINT_FORMAT.AWQ_GEMM
}
}

# awq is inference only
QUANTIZE_BLACK_LIST = {QUANT_METHOD.AWQ}

# compat
QUANT_CONFIG_ARG_SYNONYMS = {
"w_bit": "bits",
"q_group_size": "group_size",
}


@dataclass
class BaseQuantizeConfig(PushToHubMixin):
bits: int = field(default=4, metadata={"choices": [2, 3, 4, 8]})
group_size: int = field(default=-1)
damp_percent: float = field(default=0.01)
desc_act: bool = field(default=True)
static_groups: bool = field(default=False)
sym: bool = field(default=True)
true_sequential: bool = field(default=True)
quant_method: str = field(default=QUANT_METHOD.GPTQ)
checkpoint_format: str = field(default=CHECKPOINT_FORMAT.GPTQ)
model_name_or_path: Optional[str] = field(default=None)
model_file_base_name: Optional[str] = field(default=None)

def __post_init__(self):
fields_info = fields(self)

# validate quant method and format is matched
valid_checkpoint_formats = QUANT_METHOD_FORMAT_MAPPING.get(self.quant_method, None)
if valid_checkpoint_formats is None:
raise ValueError(f"Unsupported quantization method: {self.quant_method}")

if self.checkpoint_format not in valid_checkpoint_formats:
raise ValueError(
f"The checkpoint format used is {self.checkpoint_format}, and the quantization method is {self.quant_method}. "
f"This is not supported, please open an issue at https://github.com/AutoGPTQ/AutoGPTQ/issues.")

if self.bits not in fields_info[0].metadata["choices"]:
raise ValueError(f"only support quantizing to {fields_info[0].metadata['choices']} bits.")

if self.group_size != -1 and self.group_size <= 0:
raise ValueError("unless equal to -1, group_size must be greater than 0.")

if not (0 < self.damp_percent < 1):
raise ValueError("damp_percent must be between 0 and 1.")

def save_pretrained(self, save_dir: str, **kwargs):
with open(join(save_dir, QUANT_CONFIG_FILENAME), "w", encoding="utf-8") as f:
json.dump(self.to_dict(), f, indent=2)

@classmethod
# normalize quant config for compat and also performs validation
def from_quant_config(cls, quantize_cfg, checkpoint_format: str = None):
valid_formats = {CHECKPOINT_FORMAT.GPTQ, CHECKPOINT_FORMAT.MARLIN, CHECKPOINT_FORMAT.AWQ_GEMM}

checkpoint_format_auto_inferred = False
# compat: checkpoint_format can be passed in via from_quantized() if field missing from json
if checkpoint_format:
if checkpoint_format not in valid_formats:
raise ValueError(f"Unknown quantization checkpoint format: {checkpoint_format}.")
if quantize_cfg.get(CHECKPOINT_FORMAT_FIELD):
raise ValueError("Conflict: quantization checkpoint_format is passed in and also exists in model config.")
# compat: warn if checkpoint_format is missing
elif quantize_cfg.get(CHECKPOINT_FORMAT_FIELD) is None:
checkpoint_format_auto_inferred = True

field_names = [field.name for field in fields(cls)]

normalized = {QUANT_METHOD_FIELD: QUANT_METHOD.GPTQ, CHECKPOINT_FORMAT_FIELD: checkpoint_format if checkpoint_format else CHECKPOINT_FORMAT.GPTQ}
for key, val in quantize_cfg.items():
key = key.lower()

# remap keys according to compat map
if key in QUANT_CONFIG_ARG_SYNONYMS and QUANT_CONFIG_ARG_SYNONYMS[key] in field_names:
key = QUANT_CONFIG_ARG_SYNONYMS[key]

if key == CHECKPOINT_FORMAT_FIELD:
val = val.lower()

if val in {CHECKPOINT_FORMAT.GPTQ, CHECKPOINT_FORMAT.MARLIN, CHECKPOINT_FORMAT.AWQ_GEMM}:
normalized[key] = val
else:
raise ValueError(f"Unknown quantization format: {val}.")
elif key == QUANT_METHOD_FIELD:
val = val.lower()
# compat: some hf models use quant_method=marlin
if val == CHECKPOINT_FORMAT.MARLIN:
normalized[CHECKPOINT_FORMAT_FIELD] = CHECKPOINT_FORMAT.MARLIN
elif val not in {QUANT_METHOD.GPTQ, QUANT_METHOD.AWQ}:
raise ValueError(f"Unknown quantization method: {val}.")
else:
normalized[QUANT_METHOD_FIELD] = val
elif key == CHECKPOINT_FORMAT_FIELD_COMPAT_MARLIN and val:
normalized[CHECKPOINT_FORMAT_FIELD] = CHECKPOINT_FORMAT.MARLIN
elif key == "version" and val.lower() == CHECKPOINT_FORMAT.AWQ_GEMM:
normalized[QUANT_METHOD_FIELD] = QUANT_METHOD.AWQ
normalized[CHECKPOINT_FORMAT_FIELD] = CHECKPOINT_FORMAT.AWQ_GEMM
elif key in field_names:
normalized[key] = val
else:
logger.info(f"Ignoring unknown parameter in the quantization configuration: {key}.")

if checkpoint_format_auto_inferred:
logger.info(f"`checkpoint_format` is missing from the quantization configuration and is automatically inferred to {normalized[CHECKPOINT_FORMAT_FIELD]}.")

if normalized[CHECKPOINT_FORMAT_FIELD] in {CHECKPOINT_FORMAT.AWQ_GEMM, CHECKPOINT_FORMAT.MARLIN}:
# AWQ and Marlin do not reorder the rows.
normalized["desc_act"] = False

if "sym" not in normalized:
logger.warning(
"The quantization configuration does not contain an entry `sym` (symmetric quantization). "
"This may result in silent errors. Defaulting to `sym=True`."
)

return cls(**normalized)

@classmethod
def from_pretrained(cls, save_dir: str, **kwargs):
# Parameters related to loading from Hugging Face Hub
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", False)
use_auth_token = kwargs.pop("use_auth_token", None)
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None)
commit_hash = kwargs.pop("_commit_hash", None)
checkpoint_format = kwargs.pop("checkpoint_format", None)

transformers_config = False
for quantize_config_filename in [
QUANT_CONFIG_FILENAME,
"quant_config.json",
"config.json",
]:
if isdir(save_dir): # Local
resolved_config_file = join(save_dir, quantize_config_filename)
else: # Remote
resolved_config_file = cached_file(
save_dir,
quantize_config_filename,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
use_auth_token=use_auth_token,
revision=revision,
local_files_only=local_files_only,
subfolder=subfolder,
_raise_exceptions_for_missing_entries=False,
_raise_exceptions_for_connection_errors=False,
_commit_hash=commit_hash,
)
if resolved_config_file is not None:
if quantize_config_filename == "config.json":
transformers_config = True
break

if resolved_config_file is None:
raise ValueError(
"No quantize_config.json, quant_config.json or config.json file was found in the model repository."
)

with open(resolved_config_file, "r", encoding="utf-8") as f:
args_from_json = json.load(f)

if transformers_config:
args_from_json = args_from_json["quantization_config"]

return cls.from_quant_config(args_from_json, checkpoint_format)

def get_cache_file_path(self, quant_method: QUANT_METHOD = None, checkpoint_format: CHECKPOINT_FORMAT = None):
"""
Gets The Cached Weight Path.
If remote: $HF_HOME/assets/autogptq/{model_name_or_path}/_{quant-method}_{checkpoint_format}.safetensors
If local: {model_name_or_path}/autogptq_model_{quant-method}_{checkpoint_format}.safetensors
"""

use_quant_method = quant_method if quant_method else self.quant_method
use_checkpoint_format = checkpoint_format if checkpoint_format else self.checkpoint_format

cache_file_name = f"autogptq_model_{use_quant_method}_{use_checkpoint_format}.safetensors"

if os.path.isdir(self.model_name_or_path):
cache_file_name = os.path.join(self.model_name_or_path, cache_file_name)
else:
namespace, subfolder = self.model_name_or_path.split("/")
assets_path = huggingface_hub.cached_assets_path(
library_name="auto_gptq", namespace=namespace, subfolder=subfolder
)
cache_file_name = os.path.join(assets_path, cache_file_name)

return cache_file_name, os.path.isfile(cache_file_name)

def to_dict(self):
return {
"bits": self.bits,
"group_size": self.group_size,
"damp_percent": self.damp_percent,
"desc_act": self.desc_act,
"static_groups": self.static_groups,
"sym": self.sym,
"true_sequential": self.true_sequential,
"model_name_or_path": self.model_name_or_path,
"model_file_base_name": self.model_file_base_name,
QUANT_METHOD_FIELD: self.quant_method,
CHECKPOINT_FORMAT_FIELD: self.checkpoint_format,
}
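Editor's sketch of the compat path added above: a legacy config dict that still carries the old is_marlin_format flag is normalized by from_quant_config into the new checkpoint_format field, and desc_act is forced off for the Marlin format.

from auto_gptq.quantization import CHECKPOINT_FORMAT, BaseQuantizeConfig

legacy = {"bits": 4, "group_size": 128, "sym": True, "is_marlin_format": True}
cfg = BaseQuantizeConfig.from_quant_config(legacy)
assert cfg.checkpoint_format == CHECKPOINT_FORMAT.MARLIN
assert cfg.desc_act is False  # Marlin/AWQ checkpoints do not reorder rows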
2 changes: 1 addition & 1 deletion auto_gptq/utils/accelerate_utils.py
@@ -232,7 +232,7 @@ def load_checkpoint_in_model(
del loaded_checkpoint
gc.collect()

- if not strict:
+ if not strict and len(unexpected_keys) > 0:
logger.warning(
f"Some weights of the model checkpoint at {checkpoint} were not used when"
f" initializing {model.__class__.__name__}: {unexpected_keys}. This may or may not be an issue - make sure that the checkpoint does not have unnecessary parameters, or that the model definition correctly corresponds to the checkpoint."
4 changes: 2 additions & 2 deletions auto_gptq/utils/import_utils.py
@@ -64,7 +64,7 @@ def dynamically_import_QuantLinear(
disable_exllama: Optional[bool] = None,
disable_exllamav2: bool = False,
use_qigen: bool = False,
- disable_marlin: bool = True,
+ use_marlin: bool = False,
use_tritonv2: bool = False,
):
if use_qigen:
@@ -91,7 +91,7 @@
disable_exllama = False
else:
disable_exllama = True
- if bits == 4 and not disable_marlin:
+ if bits == 4 and use_marlin:
from ..nn_modules.qlinear.qlinear_marlin import QuantLinear
elif bits == 4 and not disable_exllamav2 and EXLLAMAV2_KERNELS_AVAILABLE:
from ..nn_modules.qlinear.qlinear_exllamav2 import QuantLinear
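A hedged usage sketch of the renamed flag (editor's illustration; selecting the Marlin kernel still requires bits == 4 and the Marlin extension to be installed):

from auto_gptq.utils.import_utils import dynamically_import_QuantLinear

QuantLinear = dynamically_import_QuantLinear(
    use_triton=False,
    desc_act=False,
    group_size=128,
    bits=4,
    disable_exllama=True,
    disable_exllamav2=True,
    use_marlin=True,  # previously: disable_marlin=False
)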
The remaining changed files are not rendered.
