6 changes: 4 additions & 2 deletions examples/benchmark/ipex.py
@@ -7,7 +7,9 @@
import time

import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from transformers import AutoModelForCausalLM, AutoTokenizer

from gptqmodel.utils.hf import safe_auto_config_from_pretrained


try:
@@ -53,7 +55,7 @@ def prepare_dataset_for_bench(tokenizer, batch_size=8):

# load model, check model backend
start_load = time.time()
config = AutoConfig.from_pretrained(ars.model)
config = safe_auto_config_from_pretrained(ars.model)
is_quantized_model = hasattr(config, "quantization_config")
if is_quantized_model:
from gptqmodel import BACKEND, GPTQModel
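This file (and several call sites in `auto.py` below) swaps direct `AutoConfig.from_pretrained` calls for `safe_auto_config_from_pretrained` from `gptqmodel.utils.hf`. The helper's body is not shown in this diff; the following is only a rough sketch of the wrapper shape the call sites imply — the retry behavior is an assumption, not the project's actual implementation.

```python
# Hypothetical sketch only: the real helper lives in gptqmodel/utils/hf.py
# and may behave differently. It mirrors the call sites shown in the diff.
from transformers import AutoConfig, PretrainedConfig


def safe_auto_config_from_pretrained(model_id_or_path: str, **kwargs) -> PretrainedConfig:
    """Resolve a model config, retrying with trust_remote_code for custom
    architectures that transformers cannot load natively."""
    try:
        return AutoConfig.from_pretrained(model_id_or_path, **kwargs)
    except ValueError:
        # AutoConfig raises ValueError for unregistered model types when
        # trust_remote_code is not enabled; retry with it turned on.
        retry_kwargs = dict(kwargs, trust_remote_code=True)
        return AutoConfig.from_pretrained(model_id_or_path, **retry_kwargs)
```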
9 changes: 3 additions & 6 deletions format/format.sh
@@ -2,14 +2,11 @@

cd "$(dirname "$0")" || exit

# force ruff/isort to be same version as setup.py
pip install -U ruff==0.13.0 isort==6.0.1
# force ruff to be same version as setup.py
pip install -U ruff==0.13.0

ruff check ../gptqmodel/models ../gptqmodel/nn_modules ../gptqmodel/quantization ../gptqmodel/utils ../gptqmodel/__init__.py ../examples ../tests ../setup.py --fix --unsafe-fixes
ruff_status=$?

# isort is too slow
# isort -l 119 -e ../

# Exit with the status code of ruff check
exit $ruff_status
exit $ruff_status
7 changes: 6 additions & 1 deletion gptqmodel/looper/module_looper.py
@@ -783,6 +783,8 @@ def loop(self, fail_safe: bool = False, **kwargs):
f"subset={index + 1}/{subset_total}, batches={batch_count})"
)
log.info(forward_msg)
# Drain any background work so the forward spike does not race pooled tasks.
self.pool.wait()
forward_outputs = self._run_forward_batches(
module=module,
processor=processor,
@@ -859,7 +861,8 @@ def _process_on_worker(proc: LoopProcessor, nm: NamedModule):
for fut in futures:
name, m = fut.result()
processed_subset[name] = m
torch_sync()

#torch_sync()
# ---- End Process Hook ----

is_last_module = layer_index == len(quant_modules_pb) - 1
@@ -881,6 +884,8 @@ def _process_on_worker(proc: LoopProcessor, nm: NamedModule):
f"batches={replay_batch_count})"
)
log.info(replay_msg)
# Forward replay shares the same VRAM spike; block until the pool drains first.
self.pool.wait()
layer_outputs = self._run_forward_batches(
module=module,
processor=processor,
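Both new `self.pool.wait()` calls drain background tasks before a VRAM-heavy forward pass so pooled work does not overlap with the memory spike. As a minimal illustration of that pattern only (not the looper's actual pool class), a drainable pool can be built on `concurrent.futures`:

```python
# Illustrative sketch of a pool with a wait() barrier, in the spirit of the
# self.pool.wait() calls added above; gptqmodel's real pool may differ.
from concurrent.futures import ThreadPoolExecutor, wait as futures_wait


class DrainablePool:
    def __init__(self, max_workers: int = 4):
        self._executor = ThreadPoolExecutor(max_workers=max_workers)
        self._futures = []

    def submit(self, fn, *args, **kwargs):
        fut = self._executor.submit(fn, *args, **kwargs)
        self._futures.append(fut)
        return fut

    def wait(self):
        # Block until every queued task completes so a VRAM-heavy forward
        # pass does not race pooled background work.
        futures_wait(self._futures)
        self._futures.clear()
```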
78 changes: 59 additions & 19 deletions gptqmodel/models/auto.py
@@ -46,13 +46,14 @@
import torch # noqa: E402
from huggingface_hub import list_repo_files # noqa: E402
from tokenicer import Tokenicer # noqa: E402
from transformers import AutoConfig, GenerationConfig, PreTrainedModel, PreTrainedTokenizerBase # noqa: E402
from transformers import GenerationConfig, PreTrainedModel, PreTrainedTokenizerBase # noqa: E402

from ..adapter.adapter import Adapter, Lora, normalize_adapter # noqa: E402
from ..nn_modules.qlinear.torch import TorchQuantLinear # noqa: E402
from ..quantization import METHOD, QUANT_CONFIG_FILENAME # noqa: E402
from ..utils import BACKEND # noqa: E402
from ..utils.eval import EVAL # noqa: E402
from ..utils.hf import safe_auto_config_from_pretrained # noqa: E402
from ..utils.model import find_modules # noqa: E402
from ..utils.torch import CPU, torch_empty_cache # noqa: E402
from .base import BaseQModel, QuantizeConfig # noqa: E402
@@ -102,6 +103,7 @@
from .definitions.nemotron_h import NemotronHQModel # noqa: E402
from .definitions.opt import OptQModel # noqa: E402
from .definitions.ovis import OvisQModel # noqa: E402
from .definitions.ovis2_5 import Ovis2_5QModel # noqa: E402
from .definitions.pangu_alpha import PanguAlphaQModel # noqa: E402
from .definitions.phi import PhiQModel # noqa: E402
from .definitions.phi3 import Phi3QModel, PhiMoEGPTQForCausalLM # noqa: E402
@@ -197,6 +199,7 @@
"hymba": HymbaQModel,
"olmo2": LlamaQModel, # 100% llama clone
"ovis": OvisQModel,
"ovis2_5": Ovis2_5QModel,
"telechat": TeleChat2QModel,
"instella": InstellaQModel,
"mimo": MimoQModel,
@@ -215,7 +218,7 @@


def check_and_get_model_type(model_dir, trust_remote_code=False):
config = AutoConfig.from_pretrained(model_dir, trust_remote_code=trust_remote_code)
config = safe_auto_config_from_pretrained(model_dir, trust_remote_code=trust_remote_code)
if config.model_type.lower() not in SUPPORTED_MODELS:
raise TypeError(f"{config.model_type} isn't supported yet.")
model_type = config.model_type
@@ -252,7 +255,7 @@ def load(
backend = BACKEND(backend)

is_gptqmodel_quantized = False
model_cfg = AutoConfig.from_pretrained(model_id_or_path, trust_remote_code=trust_remote_code)
model_cfg = safe_auto_config_from_pretrained(model_id_or_path, trust_remote_code=trust_remote_code)
if hasattr(model_cfg, "quantization_config") and "quant_format" in model_cfg.quantization_config:
# only if the model is quantized or compatible with gptqmodel should we set is_quantized to true
if model_cfg.quantization_config["quant_format"].lower() in (METHOD.GPTQ, METHOD.AWQ, METHOD.QQQ):
@@ -273,6 +276,7 @@ def load(
break

if is_gptqmodel_quantized:
log.info("GPTQModel.load: loading quantized model `%s` with trust_remote_code=%s", model_id_or_path, trust_remote_code)
m = cls.from_quantized(
model_id_or_path=model_id_or_path,
device_map=device_map,
@@ -306,7 +310,8 @@ def from_pretrained(
trust_remote_code: bool = False,
**model_init_kwargs,
) -> BaseQModel:
if hasattr(AutoConfig.from_pretrained(model_id_or_path, trust_remote_code=trust_remote_code),
config = safe_auto_config_from_pretrained(model_id_or_path, trust_remote_code=trust_remote_code)
if hasattr(config,
"quantization_config"):
log.warn("Model is already quantized, will use `from_quantized` to load quantized model.\n"
"If you want to quantize the model, please pass un_quantized model path or id, and use "
@@ -397,7 +402,11 @@ def eval(

if isinstance(model_or_id_or_path, str):
log.info(f"Eval: loading using backend = `{backend}`")
model = GPTQModel.load(model_id_or_path=model_or_id_or_path, backend=backend)
model = GPTQModel.load(
model_id_or_path=model_or_id_or_path,
backend=backend,
trust_remote_code=trust_remote_code,
)
model_id_or_path = model_or_id_or_path
elif isinstance(model_or_id_or_path, BaseQModel) or isinstance(model_or_id_or_path, (PreTrainedModel, PeftModel)):
model = model_or_id_or_path
@@ -409,7 +418,7 @@ def eval(
if isinstance(model, BaseQModel):
tokenizer = model.tokenizer
elif isinstance(model, PreTrainedModel) or model_id_or_path.strip():
tokenizer = Tokenicer.load(model_id_or_path)
tokenizer = Tokenicer.load(model_id_or_path, trust_remote_code=trust_remote_code)

if tokenizer is None:
raise ValueError("Tokenizer: Auto-loading of tokenizer failed with `model_or_id_or_path`. Please pass in `tokenizer` as argument.")
@@ -448,19 +457,46 @@ def eval(

# use model.generation_config whenever possible
if gen_kwargs is None:
# TODO: move to utils
if hasattr(model, "generation_config") and isinstance(model.generation_config, GenerationConfig):
gen_dict = {
"do_sample": model.generation_config.do_sample,
"temperature": model.generation_config.temperature,
"top_k": model.generation_config.top_k,
"top_p": model.generation_config.top_p,
"min_p": model.generation_config.min_p,

}
gen_kwargs = ','.join(f"{key}={value}" for key, value in gen_dict.items() if value not in ["", {}, None, []])
cfg = model.generation_config
kv_pairs = []
if getattr(cfg, "do_sample", False):
kv_pairs.append("do_sample=True")
temperature = getattr(cfg, "temperature", None)
if temperature is not None and temperature != 1.0:
kv_pairs.append(f"temperature={temperature}")
top_k = getattr(cfg, "top_k", None)
if top_k is not None:
kv_pairs.append(f"top_k={top_k}")
top_p = getattr(cfg, "top_p", None)
if top_p is not None and top_p != 1.0:
kv_pairs.append(f"top_p={top_p}")
min_p = getattr(cfg, "min_p", None)
if min_p is not None and min_p > 0.0:
kv_pairs.append(f"min_p={min_p}")
typical_p = getattr(cfg, "typical_p", None)
if typical_p is not None and typical_p != 1.0:
kv_pairs.append(f"typical_p={typical_p}")
epsilon_cutoff = getattr(cfg, "epsilon_cutoff", None)
if epsilon_cutoff is not None and epsilon_cutoff != 0.0:
kv_pairs.append(f"epsilon_cutoff={epsilon_cutoff}")
eta_cutoff = getattr(cfg, "eta_cutoff", None)
if eta_cutoff is not None and eta_cutoff != 0.0:
kv_pairs.append(f"eta_cutoff={eta_cutoff}")
penalty_alpha = getattr(cfg, "penalty_alpha", None)
if penalty_alpha is not None:
kv_pairs.append(f"penalty_alpha={penalty_alpha}")
else:
kv_pairs.append("do_sample=False")
temperature = getattr(cfg, "temperature", None)
if temperature is None:
temperature = 0.0
if temperature != 1.0:
kv_pairs.append(f"temperature={temperature}")

gen_kwargs = ",".join(kv_pairs)
else:
gen_kwargs = "temperature=0.0,top_k=50" # default
gen_kwargs = "do_sample=False,temperature=0.0" # default

log.info(f"LM-EVAL: `gen_kwargs` = `{gen_kwargs}`")

@@ -537,15 +573,19 @@ def eval(
@staticmethod
def export(model_id_or_path: str, target_path: str, format: str, trust_remote_code: bool = False):
# load config
config = AutoConfig.from_pretrained(model_id_or_path, trust_remote_code=trust_remote_code)
config = safe_auto_config_from_pretrained(model_id_or_path, trust_remote_code=trust_remote_code)

if not config.quantization_config:
raise ValueError("Model is not quantized")

gptq_config = config.quantization_config

# load gptq model
gptq_model = GPTQModel.load(model_id_or_path, backend=BACKEND.TORCH)
gptq_model = GPTQModel.load(
model_id_or_path,
backend=BACKEND.TORCH,
trust_remote_code=trust_remote_code,
)

if format == "mlx":
try:
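With `trust_remote_code` now threaded through `GPTQModel.load`, `Tokenicer.load`, and the internal load in `export()`, callers working with custom-architecture checkpoints no longer need to pre-load the model themselves. A usage sketch under that assumption (the repo id is a placeholder):

```python
# Usage sketch; the repo id is a placeholder for any checkpoint that ships
# custom modeling code and therefore needs trust_remote_code=True.
from gptqmodel import BACKEND, GPTQModel

model = GPTQModel.load(
    "my-org/custom-arch-gptq-4bit",
    backend=BACKEND.TORCH,
    trust_remote_code=True,  # also forwarded when eval()/export() load internally
)
print(type(model).__name__)
```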
1 change: 1 addition & 0 deletions gptqmodel/models/definitions/__init__.py
@@ -44,6 +44,7 @@
from .mpt import MptQModel
from .opt import OptQModel
from .ovis import OvisQModel
from .ovis2_5 import Ovis2_5QModel
from .phi import PhiQModel
from .phi3 import Phi3QModel
from .qwen import QwenQModel
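Registering `ovis2_5` means checkpoints whose `config.json` declares `model_type: "ovis2_5"` now resolve to `Ovis2_5QModel` via `check_and_get_model_type`. A rough sketch of that lookup, assuming the registry dict edited in `auto.py` above is named `MODEL_MAP` (the name is not shown in this diff) and using an illustrative repo id:

```python
# Sketch of the resolution path; MODEL_MAP is an assumed name for the
# model-type registry edited in gptqmodel/models/auto.py above.
from gptqmodel.models.auto import MODEL_MAP, check_and_get_model_type

model_type = check_and_get_model_type("AIDC-AI/Ovis2.5-9B", trust_remote_code=True)
model_cls = MODEL_MAP[model_type]  # expected: Ovis2_5QModel for "ovis2_5"
print(model_cls.__name__)
```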