1 change: 1 addition & 0 deletions gptqmodel/looper/gptq_processor.py
@@ -246,6 +246,7 @@ def submodule_finalize(self, module: NamedModule, model: BaseQModel, **kwargs):
layers=layers,
quant_linear_cls=model.qlinear_kernel,
lock=self.lock,
quantize_config=self.qcfg,
)

# TODO: store module quant results in module, not global processor result
1 change: 1 addition & 0 deletions gptqmodel/looper/qqq_processor.py
@@ -230,6 +230,7 @@ def submodule_finalize(self, module: NamedModule, model: BaseQModel, **kwargs):
quant_linear_cls=QQQQuantLinear,
lock=self.lock,
q_scales_extra=q_scales_extra,
quantize_config=self.qcfg,
)

# TODO: store module quant results in module, not global processor result
28 changes: 15 additions & 13 deletions gptqmodel/models/loader.py
@@ -546,22 +546,24 @@ def skip(*args, **kwargs):
offload_state_dict=True,
offload_buffers=True,
)
# validate sym=False v1 loading needs to be protected for models produced with new v2 format codebase
if not qcfg.sym and not qcfg.is_quantized_by_v2():
raise ValueError(
f"Format: Loading of a sym=False model with format={FORMAT.GPTQ} is only supported if produced by gptqmodel version >= {MIN_VERSION_WITH_V2}"
)

model = convert_gptq_v1_to_v2_format(
model,
cfg=qcfg,
qlinear_kernel=preload_qlinear_kernel,
)

load_checkpoint_in_model = False

if preload_qlinear_kernel.REQUIRES_FORMAT_V2:
qcfg.runtime_format = FORMAT.GPTQ_V2
if qcfg.format == FORMAT.GPTQ:
# validate sym=False v1 loading needs to be protected for models produced with new v2 format codebase
if not qcfg.sym and not qcfg.is_quantized_by_v2():
raise ValueError(
f"Format: Loading of a sym=False model with format={FORMAT.GPTQ} is only supported if produced by gptqmodel version >= {MIN_VERSION_WITH_V2}"
)

if preload_qlinear_kernel.REQUIRES_FORMAT_V2:
model = convert_gptq_v1_to_v2_format(
model,
cfg=qcfg,
qlinear_kernel=preload_qlinear_kernel,
)

qcfg.runtime_format = FORMAT.GPTQ_V2

if backend in [BACKEND.MARLIN, BACKEND.MARLIN_FP16] and (
preload_qlinear_kernel == ExllamaV2QuantLinear or qcfg.format == FORMAT.MARLIN):
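To summarize the new load path: v1-to-v2 promotion of `qzeros` now only happens for `FORMAT.GPTQ` checkpoints, and only when the preloaded kernel declares `REQUIRES_FORMAT_V2`. Below is a minimal sketch of that gating as a standalone helper, assuming both conversion helpers live in `gptqmodel.utils.model` (as `convert_gptq_v2_to_v1_format` does); the helper name `maybe_promote_to_v2` and the shortened error message are illustrative, not part of the codebase.

```python
# Sketch only: the loader's new gating logic pulled out into a hypothetical helper.
# FORMAT, convert_gptq_v1_to_v2_format, and REQUIRES_FORMAT_V2 come from the diff;
# the function name and error text are illustrative.
from gptqmodel.quantization.config import FORMAT
from gptqmodel.utils.model import convert_gptq_v1_to_v2_format


def maybe_promote_to_v2(model, qcfg, preload_qlinear_kernel):
    if qcfg.format != FORMAT.GPTQ:
        return model  # non-GPTQ formats are left untouched

    # sym=False v1 checkpoints are only loadable if produced by a v2-era gptqmodel
    if not qcfg.sym and not qcfg.is_quantized_by_v2():
        raise ValueError("sym=False GPTQ checkpoints require a v2-format producer")

    # convert qzeros to the internal v2 layout only when the selected kernel needs it
    if preload_qlinear_kernel.REQUIRES_FORMAT_V2:
        model = convert_gptq_v1_to_v2_format(
            model, cfg=qcfg, qlinear_kernel=preload_qlinear_kernel
        )
        qcfg.runtime_format = FORMAT.GPTQ_V2

    return model
```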
16 changes: 13 additions & 3 deletions gptqmodel/models/writer.py
@@ -25,6 +25,7 @@
from ..adapter.peft import LoraConfig
from ..quantization.config import (
FORMAT,
METHOD,
META_FIELD_ACT_GROUP_AWARE,
META_FIELD_DAMP_AUTO_INCREMENT,
META_FIELD_DAMP_PERCENT,
@@ -40,6 +41,7 @@
MIN_VERSION_WITH_V2,
)
from ..utils.backend import BACKEND
from ..utils.hf import sanitize_generation_config_file
from ..utils.logger import setup_logger
from ..utils.model import (
convert_gptq_v2_to_v1_format,
@@ -225,9 +227,13 @@ def save_quantized(

if not self.load_quantized_model:
model = self.model
# # internal is always gptq v2 but allow users to pass gptq (v1) via config
if quantize_config.format == FORMAT.GPTQ or quantize_config.format == FORMAT.GEMM:
# Model qzeros may be edited in place.
# internal is always gptq v2 but allow users to pass gptq (v1) via config
if (
quantize_config.format == FORMAT.GPTQ
and quantize_config.quant_method == METHOD.GPTQ
and self.qlinear_kernel.REQUIRES_FORMAT_V2
):
# Model qzeros may be edited in place for export compatibility.
model = convert_gptq_v2_to_v1_format(
model, quantize_config=quantize_config, qlinear_kernel=self.qlinear_kernel
)
@@ -246,6 +252,10 @@ def save_quantized(
# Use empty state_dict hack to bypass saving weights
self.model.save_pretrained(save_dir, state_dict={}, is_main_process=True)

gen_config_path = os.path.join(save_dir, "generation_config.json")
if sanitize_generation_config_file(gen_config_path):
log.info("Model: Sanitized `generation_config.json` before packaging.")

# Save `quantize_config.json`
quantize_config.save_pretrained(save_dir)

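Net effect of the writer change: right after the empty-`state_dict` `save_pretrained()` call, the on-disk `generation_config.json` is sanitized in place. The same helper can also be pointed at an existing output directory; a short usage sketch follows (the directory path is a placeholder).

```python
# Re-run the sanitizer added in gptqmodel/utils/hf.py against a saved model directory.
# The save_dir path is hypothetical; sanitize_generation_config_file() returns True
# only if it actually rewrote the file.
import os

from gptqmodel.utils.hf import sanitize_generation_config_file

save_dir = "./my-quantized-model"  # placeholder output directory
path = os.path.join(save_dir, "generation_config.json")

if sanitize_generation_config_file(path):
    print("generation_config.json rewritten: do_sample=True, temperature/top_p dropped")
else:
    print("generation_config.json already clean (or not present)")
```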
75 changes: 74 additions & 1 deletion gptqmodel/utils/hf.py
@@ -3,6 +3,7 @@
# SPDX-License-Identifier: Apache-2.0
# Contact: qubitium@modelcloud.ai, x.com/qubitium

import json
from typing import Any, Optional

import torch
@@ -14,14 +15,60 @@

log = setup_logger()

GENERATION_SAMPLING_FIELDS = ("temperature", "top_p")


def _sanitize_generation_config(cfg: GenerationConfig, *, drop_sampling_fields: bool = False) -> bool:
changed = False
if cfg is None:
return changed

if getattr(cfg, "do_sample", None) is not True:
cfg.do_sample = True
changed = True

if drop_sampling_fields:
for field in GENERATION_SAMPLING_FIELDS:
if hasattr(cfg, field):
if getattr(cfg, field) is not None:
changed = True
setattr(cfg, field, None)
return changed


def _load_sanitized_generation_config(path: str) -> Optional[GenerationConfig]:
try:
config_dict, kwargs = GenerationConfig.get_config_dict(path)
except Exception:
return None

cleaned = dict(config_dict)
removed = False
for field in GENERATION_SAMPLING_FIELDS:
if field in cleaned:
cleaned.pop(field, None)
removed = True
if cleaned.get("do_sample") is not True:
cleaned["do_sample"] = True

cfg = GenerationConfig.from_dict(cleaned, **kwargs)
if removed:
log.info("Model: Removed unsupported sampling fields from `generation_config.json` during load.")
_sanitize_generation_config(cfg, drop_sampling_fields=True)
return cfg


# TODO FIXME! Pre-quantized use AutoModelForCausalLM.from_pretrained() but post-quantized use AutoModelForCausalLM.from_config()
def autofix_hf_model_config(model: PreTrainedModel, path: str = None):
if model.can_generate():
# sync config first
if path:
log.info(f"Model: Loaded `generation_config`: {model.generation_config}")
try:
cfg = GenerationConfig.from_pretrained(pretrained_model_name=path)
cfg = _load_sanitized_generation_config(path)
if cfg is None:
cfg = GenerationConfig.from_pretrained(pretrained_model_name=path, do_sample=True)
_sanitize_generation_config(cfg, drop_sampling_fields=True)
if cfg != model.generation_config:
# migrated pad_token_id to config
if hasattr(model.generation_config, "pad_token_id"):
@@ -41,7 +88,9 @@ def autofix_hf_model_config(model: PreTrainedModel, path: str = None):
autofix_hf_generation_config(model.generation_config)
# print(f"After autofix_hf_model_config: {model.generation_config}")


def autofix_hf_generation_config(cfg: GenerationConfig):
_sanitize_generation_config(cfg, drop_sampling_fields=True)
# HF has recently started to perform very strict validation model save which results in warnings on load()
# to become exceptions on save().
if cfg.do_sample is False:
@@ -67,6 +116,30 @@ def autofix_hf_generation_config(cfg: GenerationConfig):
cfg.do_sample = True
log.info("Model: Auto-Fixed `generation_config` by setting `do_sample=True`.")


def sanitize_generation_config_file(path: str) -> bool:
try:
with open(path, "r", encoding="utf-8") as fp:
data = json.load(fp)
except FileNotFoundError:
return False

changed = False
for field in GENERATION_SAMPLING_FIELDS:
if field in data:
data.pop(field, None)
changed = True

if data.get("do_sample") is not True:
data["do_sample"] = True
changed = True

if changed:
with open(path, "w", encoding="utf-8") as fp:
json.dump(data, fp, indent=2)

return changed

# load hf model with empty tensors on meta device (zero tensor memory usage)
def build_shell_model(
loader,
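As a sanity check on the intended behavior, here is a hedged example of what the public `autofix_hf_generation_config()` entry point is now expected to do to a config that carries sampling fields; the field values are invented and the expected output is inferred from `_sanitize_generation_config()` above.

```python
# Inferred behavior only: do_sample is forced to True and temperature/top_p are
# cleared to None by _sanitize_generation_config(..., drop_sampling_fields=True).
from transformers import GenerationConfig

from gptqmodel.utils.hf import autofix_hf_generation_config

cfg = GenerationConfig(do_sample=False, temperature=0.7, top_p=0.8, max_new_tokens=64)
autofix_hf_generation_config(cfg)

print(cfg.do_sample, cfg.temperature, cfg.top_p)  # expected: True None None
```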
36 changes: 33 additions & 3 deletions gptqmodel/utils/model.py
@@ -673,7 +673,19 @@ def convert_gptq_v2_to_v1_format(
return model


def pack_module(name, qModules, q_scales, q_zeros, q_g_idx, layers, quant_linear_cls, lock: threading.Lock, q_scales_extra = None):
def pack_module(
name,
qModules,
q_scales,
q_zeros,
q_g_idx,
layers,
quant_linear_cls,
lock: threading.Lock,
q_scales_extra=None,
quantize_config: Optional[QuantizeConfig] = None,
quant_result: Optional[Dict[str, Any]] = None,
):
# Limit pack() thread usage to avoid auto-parallizataion regression
with tctl.threadpool_limits(limits=1):
with lock:
@@ -702,6 +714,17 @@ def pack_module(name, qModules, q_scales, q_zeros, q_g_idx, layers, quant_linear
else:
module.pack(linear=layer, scales=q_scales, zeros=q_zeros, g_idx=q_g_idx)

if (
quantize_config is not None
and quantize_config.quant_method == METHOD.GPTQ
and quantize_config.format == FORMAT.GPTQ
and getattr(quant_linear_cls, "REQUIRES_FORMAT_V2", False)
):
convert_gptq_v2_to_v1_format_module(
module=module,
quantize_config=quantize_config,
)

# TODO: why move it back to gpu?
# start = time.time()
# qModules[name].to(layer_device)
@@ -767,8 +790,15 @@ def wrapper(name):
# TODO FIX, thread pool executor does not advance iterator
pb.next()
pb.title(f"Packing {name}").draw()
pack_module(name=name, qModules=qModules, quant_result=quant_result, layers=modules,
quant_linear_cls=quant_linear_cls, lock=lock)
pack_module(
name=name,
qModules=qModules,
quant_result=quant_result,
layers=modules,
quant_linear_cls=quant_linear_cls,
lock=lock,
quantize_config=qcfg,
)

for _ in executor.map(wrapper, names):
pass
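The pack-time down-conversion is gated on three conditions. Below is a small illustrative predicate that mirrors the check inside `pack_module()`, assuming `QuantizeConfig`, `FORMAT`, and `METHOD` are importable from `gptqmodel.quantization.config`; the helper itself is not part of gptqmodel.

```python
# Illustrative predicate mirroring the new check in pack_module(): a freshly packed
# module is converted back from the internal GPTQ v2 layout to v1 only when the
# config is native GPTQ and the selected kernel packs in v2.
from typing import Optional

from gptqmodel.quantization.config import FORMAT, METHOD, QuantizeConfig


def should_downconvert_to_v1(
    quantize_config: Optional[QuantizeConfig], quant_linear_cls: type
) -> bool:
    return (
        quantize_config is not None
        and quantize_config.quant_method == METHOD.GPTQ
        and quantize_config.format == FORMAT.GPTQ
        and getattr(quant_linear_cls, "REQUIRES_FORMAT_V2", False)
    )
```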