diff --git a/examples/benchmark/ipex.py b/examples/benchmark/ipex.py
index fc73436ed..cb0e70cd7 100644
--- a/examples/benchmark/ipex.py
+++ b/examples/benchmark/ipex.py
@@ -7,7 +7,9 @@

 import time

 import torch
-from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from gptqmodel.utils.hf import safe_auto_config_from_pretrained

 try:
@@ -53,7 +55,7 @@ def prepare_dataset_for_bench(tokenizer, batch_size=8):

 # load model, check model backend
 start_load = time.time()
-config = AutoConfig.from_pretrained(ars.model)
+config = safe_auto_config_from_pretrained(ars.model)
 is_quantized_model = hasattr(config, "quantization_config")
 if is_quantized_model:
     from gptqmodel import BACKEND, GPTQModel
diff --git a/format/format.sh b/format/format.sh
index 9e4d630ed..a087073c6 100755
--- a/format/format.sh
+++ b/format/format.sh
@@ -2,14 +2,11 @@

 cd "$(dirname "$0")" || exit

-# force ruff/isort to be same version as setup.py
-pip install -U ruff==0.13.0 isort==6.0.1
+# force ruff to be same version as setup.py
+pip install -U ruff==0.13.0

 ruff check ../gptqmodel/models ../gptqmodel/nn_modules ../gptqmodel/quantization ../gptqmodel/utils ../gptqmodel/__init__.py ../examples ../tests ../setup.py --fix --unsafe-fixes
 ruff_status=$?

-# isort is too slow
-# isort -l 119 -e ../
-
 # Exit with the status code of ruff check
-exit $ruff_status
\ No newline at end of file
+exit $ruff_status
diff --git a/gptqmodel/looper/module_looper.py b/gptqmodel/looper/module_looper.py
index d1ef44fad..aa528faad 100644
--- a/gptqmodel/looper/module_looper.py
+++ b/gptqmodel/looper/module_looper.py
@@ -783,6 +783,8 @@ def loop(self, fail_safe: bool = False, **kwargs):
                 f"subset={index + 1}/{subset_total}, batches={batch_count})"
             )
             log.info(forward_msg)
+            # Drain any background work so the forward spike does not race pooled tasks.
+            self.pool.wait()
             forward_outputs = self._run_forward_batches(
                 module=module,
                 processor=processor,
@@ -859,7 +861,8 @@ def _process_on_worker(proc: LoopProcessor, nm: NamedModule):
             for fut in futures:
                 name, m = fut.result()
                 processed_subset[name] = m
-            torch_sync()
+
+            #torch_sync()
             # ---- End Process Hook ----

             is_last_module = layer_index == len(quant_modules_pb) - 1
@@ -881,6 +884,8 @@ def _process_on_worker(proc: LoopProcessor, nm: NamedModule):
                 f"batches={replay_batch_count})"
             )
             log.info(replay_msg)
+            # Forward replay shares the same VRAM spike; block until the pool drains first.
+ self.pool.wait() layer_outputs = self._run_forward_batches( module=module, processor=processor, diff --git a/gptqmodel/models/auto.py b/gptqmodel/models/auto.py index 924a5cdf0..6bba8dafa 100644 --- a/gptqmodel/models/auto.py +++ b/gptqmodel/models/auto.py @@ -46,13 +46,14 @@ import torch # noqa: E402 from huggingface_hub import list_repo_files # noqa: E402 from tokenicer import Tokenicer # noqa: E402 -from transformers import AutoConfig, GenerationConfig, PreTrainedModel, PreTrainedTokenizerBase # noqa: E402 +from transformers import GenerationConfig, PreTrainedModel, PreTrainedTokenizerBase # noqa: E402 from ..adapter.adapter import Adapter, Lora, normalize_adapter # noqa: E402 from ..nn_modules.qlinear.torch import TorchQuantLinear # noqa: E402 from ..quantization import METHOD, QUANT_CONFIG_FILENAME # noqa: E402 from ..utils import BACKEND # noqa: E402 from ..utils.eval import EVAL # noqa: E402 +from ..utils.hf import safe_auto_config_from_pretrained # noqa: E402 from ..utils.model import find_modules # noqa: E402 from ..utils.torch import CPU, torch_empty_cache # noqa: E402 from .base import BaseQModel, QuantizeConfig # noqa: E402 @@ -102,6 +103,7 @@ from .definitions.nemotron_h import NemotronHQModel # noqa: E402 from .definitions.opt import OptQModel # noqa: E402 from .definitions.ovis import OvisQModel # noqa: E402 +from .definitions.ovis2_5 import Ovis2_5QModel # noqa: E402 from .definitions.pangu_alpha import PanguAlphaQModel # noqa: E402 from .definitions.phi import PhiQModel # noqa: E402 from .definitions.phi3 import Phi3QModel, PhiMoEGPTQForCausalLM # noqa: E402 @@ -197,6 +199,7 @@ "hymba": HymbaQModel, "olmo2": LlamaQModel, # 100% llama clone "ovis": OvisQModel, + "ovis2_5": Ovis2_5QModel, "telechat": TeleChat2QModel, "instella": InstellaQModel, "mimo": MimoQModel, @@ -215,7 +218,7 @@ def check_and_get_model_type(model_dir, trust_remote_code=False): - config = AutoConfig.from_pretrained(model_dir, trust_remote_code=trust_remote_code) + config = safe_auto_config_from_pretrained(model_dir, trust_remote_code=trust_remote_code) if config.model_type.lower() not in SUPPORTED_MODELS: raise TypeError(f"{config.model_type} isn't supported yet.") model_type = config.model_type @@ -252,7 +255,7 @@ def load( backend = BACKEND(backend) is_gptqmodel_quantized = False - model_cfg = AutoConfig.from_pretrained(model_id_or_path, trust_remote_code=trust_remote_code) + model_cfg = safe_auto_config_from_pretrained(model_id_or_path, trust_remote_code=trust_remote_code) if hasattr(model_cfg, "quantization_config") and "quant_format" in model_cfg.quantization_config: # only if the model is quantized or compatible with gptqmodel should we set is_quantized to true if model_cfg.quantization_config["quant_format"].lower() in (METHOD.GPTQ, METHOD.AWQ, METHOD.QQQ): @@ -273,6 +276,7 @@ def load( break if is_gptqmodel_quantized: + log.info("GPTQModel.load: loading quantized model `%s` with trust_remote_code=%s", model_id_or_path, trust_remote_code) m = cls.from_quantized( model_id_or_path=model_id_or_path, device_map=device_map, @@ -306,7 +310,8 @@ def from_pretrained( trust_remote_code: bool = False, **model_init_kwargs, ) -> BaseQModel: - if hasattr(AutoConfig.from_pretrained(model_id_or_path, trust_remote_code=trust_remote_code), + config = safe_auto_config_from_pretrained(model_id_or_path, trust_remote_code=trust_remote_code) + if hasattr(config, "quantization_config"): log.warn("Model is already quantized, will use `from_quantized` to load quantized model.\n" "If you want to quantize the model, 
please pass un_quantized model path or id, and use " @@ -397,7 +402,11 @@ def eval( if isinstance(model_or_id_or_path, str): log.info(f"Eval: loading using backend = `{backend}`") - model = GPTQModel.load(model_id_or_path=model_or_id_or_path, backend=backend) + model = GPTQModel.load( + model_id_or_path=model_or_id_or_path, + backend=backend, + trust_remote_code=trust_remote_code, + ) model_id_or_path = model_or_id_or_path elif isinstance(model_or_id_or_path, BaseQModel) or isinstance(model_or_id_or_path, (PreTrainedModel, PeftModel)): model = model_or_id_or_path @@ -409,7 +418,7 @@ def eval( if isinstance(model, BaseQModel): tokenizer = model.tokenizer elif isinstance(model, PreTrainedModel) or model_id_or_path.strip(): - tokenizer = Tokenicer.load(model_id_or_path) + tokenizer = Tokenicer.load(model_id_or_path, trust_remote_code=trust_remote_code) if tokenizer is None: raise ValueError("Tokenizer: Auto-loading of tokenizer failed with `model_or_id_or_path`. Please pass in `tokenizer` as argument.") @@ -448,19 +457,46 @@ def eval( # use model.generation_config whenever possible if gen_kwargs is None: - # TODO: move to utils if hasattr(model, "generation_config") and isinstance(model.generation_config, GenerationConfig): - gen_dict = { - "do_sample": model.generation_config.do_sample, - "temperature": model.generation_config.temperature, - "top_k": model.generation_config.top_k, - "top_p": model.generation_config.top_p, - "min_p": model.generation_config.min_p, - - } - gen_kwargs = ','.join(f"{key}={value}" for key, value in gen_dict.items() if value not in ["", {}, None, []]) + cfg = model.generation_config + kv_pairs = [] + if getattr(cfg, "do_sample", False): + kv_pairs.append("do_sample=True") + temperature = getattr(cfg, "temperature", None) + if temperature is not None and temperature != 1.0: + kv_pairs.append(f"temperature={temperature}") + top_k = getattr(cfg, "top_k", None) + if top_k is not None: + kv_pairs.append(f"top_k={top_k}") + top_p = getattr(cfg, "top_p", None) + if top_p is not None and top_p != 1.0: + kv_pairs.append(f"top_p={top_p}") + min_p = getattr(cfg, "min_p", None) + if min_p is not None and min_p > 0.0: + kv_pairs.append(f"min_p={min_p}") + typical_p = getattr(cfg, "typical_p", None) + if typical_p is not None and typical_p != 1.0: + kv_pairs.append(f"typical_p={typical_p}") + epsilon_cutoff = getattr(cfg, "epsilon_cutoff", None) + if epsilon_cutoff is not None and epsilon_cutoff != 0.0: + kv_pairs.append(f"epsilon_cutoff={epsilon_cutoff}") + eta_cutoff = getattr(cfg, "eta_cutoff", None) + if eta_cutoff is not None and eta_cutoff != 0.0: + kv_pairs.append(f"eta_cutoff={eta_cutoff}") + penalty_alpha = getattr(cfg, "penalty_alpha", None) + if penalty_alpha is not None: + kv_pairs.append(f"penalty_alpha={penalty_alpha}") + else: + kv_pairs.append("do_sample=False") + temperature = getattr(cfg, "temperature", None) + if temperature is None: + temperature = 0.0 + if temperature != 1.0: + kv_pairs.append(f"temperature={temperature}") + + gen_kwargs = ",".join(kv_pairs) else: - gen_kwargs = "temperature=0.0,top_k=50" # default + gen_kwargs = "do_sample=False,temperature=0.0" # default log.info(f"LM-EVAL: `gen_kwargs` = `{gen_kwargs}`") @@ -537,7 +573,7 @@ def eval( @staticmethod def export(model_id_or_path: str, target_path: str, format: str, trust_remote_code: bool = False): # load config - config = AutoConfig.from_pretrained(model_id_or_path, trust_remote_code=trust_remote_code) + config = safe_auto_config_from_pretrained(model_id_or_path, 
trust_remote_code=trust_remote_code) if not config.quantization_config: raise ValueError("Model is not quantized") @@ -545,7 +581,11 @@ def export(model_id_or_path: str, target_path: str, format: str, trust_remote_co gptq_config = config.quantization_config # load gptq model - gptq_model = GPTQModel.load(model_id_or_path, backend=BACKEND.TORCH) + gptq_model = GPTQModel.load( + model_id_or_path, + backend=BACKEND.TORCH, + trust_remote_code=trust_remote_code, + ) if format == "mlx": try: diff --git a/gptqmodel/models/definitions/__init__.py b/gptqmodel/models/definitions/__init__.py index 7072e254d..e10838ed4 100644 --- a/gptqmodel/models/definitions/__init__.py +++ b/gptqmodel/models/definitions/__init__.py @@ -44,6 +44,7 @@ from .mpt import MptQModel from .opt import OptQModel from .ovis import OvisQModel +from .ovis2_5 import Ovis2_5QModel from .phi import PhiQModel from .phi3 import Phi3QModel from .qwen import QwenQModel diff --git a/gptqmodel/models/definitions/ovis.py b/gptqmodel/models/definitions/ovis.py index 234ffa81b..5c21b1aac 100644 --- a/gptqmodel/models/definitions/ovis.py +++ b/gptqmodel/models/definitions/ovis.py @@ -17,6 +17,7 @@ class OvisQModel(BaseQModel): + require_trust_remote_code = True pre_lm_head_norm_module = "llm.model.norm" module_tree = [ @@ -49,6 +50,23 @@ def monkey_patch(self): self.model.vte = self.model.vte.to(dtype=self.model.llm.dtype) def pre_quantize_generate_hook_start(self): + visual_tokenizer_meta = any(param.device.type == "meta" for param in self.model.visual_tokenizer.parameters()) or \ + any(buffer.device.type == "meta" for buffer in self.model.visual_tokenizer.buffers()) + if visual_tokenizer_meta: + try: + self.shell_module_materialize(self.model.visual_tokenizer, self.quantize_config.device) + except Exception: + logging.warning("OVIS visual_tokenizer shell materialization failed; continuing with fallback move.", + exc_info=True) + + vte_meta = any(param.device.type == "meta" for param in self.model.vte.parameters()) or \ + any(buffer.device.type == "meta" for buffer in self.model.vte.buffers()) + if vte_meta: + try: + self.shell_module_materialize(self.model.vte, self.quantize_config.device) + except Exception: + logging.warning("OVIS VTE shell materialization failed; continuing with fallback move.", exc_info=True) + self.model.visual_tokenizer = move_to(self.model.visual_tokenizer, device=self.quantize_config.device) self.model.vte = move_to(self.model.vte, device=self.quantize_config.device) @@ -76,8 +94,14 @@ def preprocess_dataset(self, sample: Dict) -> Dict: propagate_exception=False ) + target_dtype = self.model.visual_tokenizer.dtype if pixel_values is None: pixel_values, _ = self.visual_tokenizer.mock_input() + pixel_values = [pv.to(dtype=target_dtype) for pv in pixel_values] + elif isinstance(pixel_values, (list, tuple)): + pixel_values = [pv.to(dtype=target_dtype) for pv in pixel_values] + else: + pixel_values = pixel_values.to(dtype=target_dtype) input_ids = input_ids[:text_max_length] labels = labels[:text_max_length] @@ -122,7 +146,128 @@ def prepare_dataset( return calib_data - def generate(self, inputs, **kwargs): + def generate(self, inputs=None, **kwargs): """shortcut for model.generate""" - with torch.inference_mode(), torch.amp.autocast(device_type=self.device.type): - return self.model.generate(inputs, **kwargs) + model_device = getattr(self.model, "device", None) + if model_device is None: + quant_device = getattr(self.quantize_config, "device", None) + model_device = torch.device(quant_device if quant_device is not 
None else "cpu") + + llm = getattr(self.model, "llm", None) + + pixel_values = None + if isinstance(inputs, dict): + pixel_values = inputs.get("pixel_values") + if pixel_values is None: + pixel_values = kwargs.get("pixel_values") + + has_real_pixels = False + if pixel_values is not None: + if isinstance(pixel_values, (list, tuple)): + has_real_pixels = any(p is not None for p in pixel_values) + else: + has_real_pixels = True + + text_only = (llm is not None) and not has_real_pixels + + if text_only: + kwargs = dict(kwargs) + kwargs.pop("pixel_values", None) + + if isinstance(inputs, dict): + if inputs.get("pixel_values") is not None: + inputs = dict(inputs) + inputs.pop("pixel_values", None) + + llm_device = next(self.model.llm.parameters()).device + + def ensure_attention_mask(payload): + if payload.get("attention_mask") is not None or "input_ids" not in payload: + return payload + mask = torch.ones_like(payload["input_ids"], dtype=torch.bool, device=payload["input_ids"].device) + payload["attention_mask"] = mask + return payload + + if isinstance(inputs, (str, list)): + if isinstance(inputs, str) or (isinstance(inputs, list) and all(isinstance(x, str) for x in inputs)): + if self.tokenizer is None: + raise ValueError( + "You passed in text to OvisQModel.generate() but tokenizer is missing." + ) + tokenized = self.tokenizer( + inputs, + return_tensors="pt", + padding=True, + padding_side="left" + ).to(llm_device) + with torch.amp.autocast(device_type=llm_device.type): + return self.model.llm.generate(**tokenized, **kwargs) + + if isinstance(inputs, dict): + payload = {k: v for k, v in inputs.items() if k != "pixel_values"} + if "input_ids" in payload and isinstance(payload["input_ids"], torch.Tensor): + payload["input_ids"] = payload["input_ids"].to(llm_device) + if "attention_mask" in payload and isinstance(payload["attention_mask"], torch.Tensor): + payload["attention_mask"] = payload["attention_mask"].to(llm_device) + payload = ensure_attention_mask(payload) + payload.update(kwargs) + with torch.amp.autocast(device_type=llm_device.type): + return self.model.llm.generate(**payload) + + if isinstance(inputs, torch.Tensor): + inputs = inputs.to(llm_device) + attention_mask = kwargs.pop("attention_mask", None) + if attention_mask is None: + attention_mask = torch.ones_like(inputs, dtype=torch.bool, device=inputs.device) + else: + attention_mask = attention_mask.to(inputs.device) + + with torch.amp.autocast(device_type=llm_device.type): + return self.model.llm.generate( + inputs=inputs, + attention_mask=attention_mask, + **kwargs + ) + + if inputs is None: + payload = ensure_attention_mask({k: v for k, v in kwargs.items() if k in {"input_ids", "attention_mask"}}) + remaining = {k: v for k, v in kwargs.items() if k not in payload} + payload = {k: (v.to(llm_device) if isinstance(v, torch.Tensor) else v) for k, v in payload.items()} + with torch.amp.autocast(device_type=llm_device.type): + return self.model.llm.generate(**payload, **remaining) + + with torch.amp.autocast(device_type=model_device.type): + return super().generate(inputs=inputs, **kwargs) + + def forward(self, *args, **kwargs): + """Allow text-only invocations to bypass vision branch for evaluator compatibility.""" + if args: + # most callers pass input_ids positionally; forward them as keyword for clarity + if "input_ids" in kwargs: + raise TypeError("OvisQModel.forward() received positional and keyword input_ids") + kwargs = dict(kwargs) # shallow copy to avoid mutating caller state + kwargs["input_ids"] = args[0] + args = 
args[1:] + + input_ids = kwargs.get("input_ids") + pixel_values = kwargs.get("pixel_values") + + if input_ids is not None and pixel_values is None and hasattr(self.model, "llm"): + attention_mask = kwargs.get("attention_mask") + if attention_mask is None: + attention_mask = torch.ones_like(input_ids, dtype=torch.bool, device=input_ids.device) + + # Hugging Face text evaluators expect logits only; labels are optional here. + llm_kwargs = { + k: v + for k, v in kwargs.items() + if k not in {"pixel_values"} + } + llm_kwargs["attention_mask"] = attention_mask + + if llm_kwargs.get("labels") is None: + llm_kwargs.pop("labels", None) + + return self.model.llm(**llm_kwargs) + + return super().forward(*args, **kwargs) diff --git a/gptqmodel/models/loader.py b/gptqmodel/models/loader.py index 802e2a009..33e85529a 100644 --- a/gptqmodel/models/loader.py +++ b/gptqmodel/models/loader.py @@ -25,7 +25,7 @@ from huggingface_hub import snapshot_download from packaging.version import InvalidVersion, Version -from transformers import AutoConfig, AutoTokenizer, PretrainedConfig +from transformers import AutoTokenizer, PretrainedConfig from transformers.modeling_utils import no_init_weights from transformers.utils import is_flash_attn_2_available from transformers.utils.generic import ContextManagers @@ -35,6 +35,7 @@ from ..quantization import QuantizeConfig from ..quantization.config import FORMAT, METHOD, MIN_VERSION_WITH_V2 from ..utils.backend import BACKEND +from ..utils.hf import safe_auto_config_from_pretrained from ..utils.importer import auto_select_device, normalize_device_device_map, select_quant_linear from ..utils.logger import setup_logger from ..utils.marlin import _validate_marlin_device_support @@ -56,6 +57,45 @@ log = setup_logger() ATTN_IMPLEMENTATION = "attn_implementation" + + +def _coerce_torch_dtype(value): + if value is None: + return None + if isinstance(value, torch.dtype): + return value + attr = None + if isinstance(value, str): + attr = getattr(torch, value, None) + if isinstance(attr, torch.dtype): + return attr + raise TypeError(f"Unsupported torch dtype value: {value!r}") + + +def _clear_config_attr(config, name: str) -> None: + store = getattr(config, "__dict__", None) + if isinstance(store, dict) and name in store: + del store[name] + + +def _normalize_config_dtype(config) -> Optional[torch.dtype]: + legacy = getattr(config, "torch_dtype", None) + if legacy is not None: + try: + coerced = _coerce_torch_dtype(legacy) + except TypeError: + coerced = None + if coerced is not None: + setattr(config, "dtype", coerced) + _clear_config_attr(config, "torch_dtype") + current = getattr(config, "dtype", None) + if isinstance(current, str): + try: + current = _coerce_torch_dtype(current) + setattr(config, "dtype", current) + except TypeError: + current = None + return current if isinstance(current, torch.dtype) else None def parse_version_string(version_str: str): try: return Version(version_str) @@ -158,9 +198,12 @@ def skip(*args, **kwargs): torch.nn.init.uniform_ = skip torch.nn.init.normal_ = skip + torch_dtype_arg = model_init_kwargs.pop("torch_dtype", None) + model_init_kwargs["trust_remote_code"] = trust_remote_code - config = AutoConfig.from_pretrained(model_local_path, **model_init_kwargs) + config = safe_auto_config_from_pretrained(model_local_path, **model_init_kwargs) + normalized_config_dtype = _normalize_config_dtype(config) atten_impl = model_init_kwargs.get("attn_implementation", None) @@ -177,13 +220,32 @@ def skip(*args, **kwargs): if cls.require_dtype: dtype = 
cls.require_dtype - if dtype is None or dtype == "auto" or not isinstance(dtype, torch.dtype): - # TODO FIX ME for `dynamic`, non-quantized modules should be in native type - dtype = auto_dtype(config=config, device=quantize_config.device, quant_inference=False) + if torch_dtype_arg is not None: + coerced = _coerce_torch_dtype(torch_dtype_arg) + if dtype is None or dtype == "auto" or not isinstance(dtype, torch.dtype): + dtype = coerced + else: + log.info("Loader: overriding legacy torch_dtype argument with `dtype` and removing duplication.") + model_init_kwargs["dtype"] = dtype - if isinstance(dtype, torch.dtype) and getattr(config, "torch_dtype", None) != dtype: - # Align config metadata with the dtype we will materialize weights in. - config.torch_dtype = dtype + if dtype is None or dtype == "auto" or not isinstance(dtype, torch.dtype): + if normalized_config_dtype is not None: + dtype = normalized_config_dtype + else: + # TODO FIX ME for `dynamic`, non-quantized modules should be in native type + dtype = auto_dtype(config=config, device=quantize_config.device, quant_inference=False) + + if isinstance(dtype, torch.dtype): + current = getattr(config, "dtype", None) + if isinstance(current, str): + try: + current = _coerce_torch_dtype(current) + except TypeError: + current = None + if current != dtype: + # Align config metadata with the dtype we will materialize weights in. + config.dtype = dtype + _clear_config_attr(config, "torch_dtype") # enforce some values despite user specified # non-quantized models are always loaded into cpu @@ -197,12 +259,21 @@ def skip(*args, **kwargs): if quantize_config.offload_to_disk: print("shell model-----------") - model = build_shell_model(cls.loader, config=config, **model_init_kwargs) + shell_kwargs = model_init_kwargs.copy() + shell_dtype = shell_kwargs.pop("dtype", dtype) + model = build_shell_model(cls.loader, config=config, dtype=shell_dtype, **shell_kwargs) model._model_init_kwargs = model_init_kwargs print_module_tree(model=model) # enable mmap with low_cpu_mem_usage - turtle_model = cls.loader.from_pretrained(model_local_path, config=config, low_cpu_mem_usage=True, **model_init_kwargs) + turtle_kwargs = model_init_kwargs.copy() + turtle_kwargs.setdefault("dtype", dtype) + turtle_model = cls.loader.from_pretrained( + model_local_path, + config=config, + low_cpu_mem_usage=True, + **turtle_kwargs, + ) # TODO FIX ME...temp store model_init args turtle_model._model_init_kwargs = model_init_kwargs @@ -210,12 +281,26 @@ def skip(*args, **kwargs): # print_module_tree(model=turtle_model) else: print("loading model directly to CPU (not using meta device or turtle_model)-----------") - model = cls.loader.from_pretrained(model_local_path, config=config, **model_init_kwargs) + direct_kwargs = model_init_kwargs.copy() + direct_kwargs.setdefault("dtype", dtype) + model = cls.loader.from_pretrained( + model_local_path, + config=config, + **direct_kwargs, + ) model._model_init_kwargs = model_init_kwargs print_module_tree(model=model) turtle_model = None + if isinstance(dtype, torch.dtype): + if getattr(model, "config", None) is not None: + model.config.dtype = dtype + _clear_config_attr(model.config, "torch_dtype") + if turtle_model is not None and getattr(turtle_model, "config", None) is not None: + turtle_model.config.dtype = dtype + _clear_config_attr(turtle_model.config, "torch_dtype") + model_config = model.config.to_dict() seq_len_keys = ["max_position_embeddings", "seq_length", "n_positions", "multimodal_max_length"] config_seq_len = 
find_config_seq_len(model_config, seq_len_keys) @@ -294,7 +379,8 @@ def from_quantized( revision = kwargs.pop("revision", None) subfolder = kwargs.pop("subfolder", "") commit_hash = kwargs.pop("_commit_hash", None) - attn_implementation = kwargs.pop("attn_implementation", None) + attn_arg = kwargs.pop("attn_implementation", None) + torch_dtype_arg = kwargs.pop("torch_dtype", None) cached_file_kwargs = { "cache_dir": cache_dir, @@ -307,26 +393,57 @@ def from_quantized( "subfolder": subfolder, "_raise_exceptions_for_missing_entries": False, "_commit_hash": commit_hash, - "attn_implementation": attn_implementation, } # == step1: prepare configs and file names == # - config: PretrainedConfig = AutoConfig.from_pretrained( + print("[DEBUG] safe_auto_config call", trust_remote_code, model_local_path) + config: PretrainedConfig = safe_auto_config_from_pretrained( model_local_path, trust_remote_code=trust_remote_code, **cached_file_kwargs, ) + log.info("Loader: safe_auto_config_from_pretrained called with trust_remote_code=%s for %s", + trust_remote_code, model_local_path) + print("[DEBUG] loaded config model_type", getattr(config, "model_type", None)) + attn_override = attn_arg + if getattr(config, "model_type", "").lower() == "ovis": + for key in ("attn_implementation", "_attn_implementation"): + value = getattr(config, key, None) + if value == "flash_attention_2" or value is None: + setattr(config, key, "eager") + attn_override = "eager" + cached_file_kwargs.pop("attn_implementation", None) + normalized_config_dtype = _normalize_config_dtype(config) if cls.require_dtype: dtype = cls.require_dtype - if dtype is None or dtype == "auto" or not isinstance(dtype, torch.dtype) : - # TODO FIX ME for `dynamic`, non-quantized modules should be in native type - dtype = auto_dtype(config=config, device=device, quant_inference=True) + if torch_dtype_arg is not None: + coerced = _coerce_torch_dtype(torch_dtype_arg) + if dtype is None or dtype == "auto" or not isinstance(dtype, torch.dtype): + dtype = coerced + else: + log.info("Loader: overriding legacy torch_dtype argument with `dtype` and removing duplication.") + kwargs["dtype"] = dtype - if isinstance(dtype, torch.dtype) and getattr(config, "torch_dtype", None) != dtype: - # Ensure flash attention kernels see an explicit dtype instead of relying on defaults. - config.torch_dtype = dtype + if dtype is None or dtype == "auto" or not isinstance(dtype, torch.dtype): + if normalized_config_dtype is not None: + dtype = normalized_config_dtype + else: + # TODO FIX ME for `dynamic`, non-quantized modules should be in native type + dtype = auto_dtype(config=config, device=device, quant_inference=True) + + if isinstance(dtype, torch.dtype): + current = getattr(config, "dtype", None) + if isinstance(current, str): + try: + current = _coerce_torch_dtype(current) + except TypeError: + current = None + if current != dtype: + # Ensure flash attention kernels see an explicit dtype instead of relying on defaults. 
+ config.dtype = dtype + _clear_config_attr(config, "torch_dtype") qcfg = QuantizeConfig.from_pretrained(model_local_path, **cached_file_kwargs, **kwargs) @@ -451,28 +568,40 @@ def skip(*args, **kwargs): with (ContextManagers(init_contexts)): cls.before_model_load(cls, load_quantized_model=True) + supports_flash_attn = bool(getattr(config, "_supports_flash_attn_2", False)) if config.architectures: model_class = getattr(transformers, config.architectures[0], None) - if model_class is not None and hasattr(model_class, "_supports_flash_attn_2"): - supports_flash_attn = model_class._supports_flash_attn_2 - else: - supports_flash_attn = None - else: - supports_flash_attn = None + if model_class is not None: + model_supports_flash = getattr(model_class, "_supports_flash_attn_2", None) + if model_supports_flash is not None: + supports_flash_attn = bool(model_supports_flash) args = {} - if supports_flash_attn and device in [DEVICE.CUDA, DEVICE.ROCM]: - if ATTN_IMPLEMENTATION in kwargs: - args[ATTN_IMPLEMENTATION] = kwargs.pop(ATTN_IMPLEMENTATION, None) - elif is_flash_attn_2_available(): - args = {ATTN_IMPLEMENTATION: "flash_attention_2"} + if attn_override is not None: + args[ATTN_IMPLEMENTATION] = attn_override + elif ATTN_IMPLEMENTATION in kwargs: + args[ATTN_IMPLEMENTATION] = kwargs.pop(ATTN_IMPLEMENTATION, None) + elif device in [DEVICE.CUDA, DEVICE.ROCM]: + if supports_flash_attn and is_flash_attn_2_available(): + args[ATTN_IMPLEMENTATION] = "flash_attention_2" log.info("Loader: Auto enabling flash attention2") + flash_attn_requested = args.get(ATTN_IMPLEMENTATION) == "flash_attention_2" + if flash_attn_requested and isinstance(dtype, torch.dtype): + config.dtype = dtype + _clear_config_attr(config, "torch_dtype") + model = cls.loader.from_config( config, trust_remote_code=trust_remote_code, dtype=dtype, **args ) model.checkpoint_file_name = model_save_name + if flash_attn_requested and isinstance(dtype, torch.dtype) and getattr(model, "config", None) is not None: + model.config.dtype = dtype + _clear_config_attr(model.config, "torch_dtype") + if ATTN_IMPLEMENTATION in args and getattr(model, "config", None) is not None: + _clear_config_attr(model.config, "_attn_implementation") + # Get the first layer to determine layer type layers, _ = get_module_by_name_prefix(model, cls.extract_layers_node()) diff --git a/gptqmodel/models/writer.py b/gptqmodel/models/writer.py index f0ac0dbba..4c1777c72 100644 --- a/gptqmodel/models/writer.py +++ b/gptqmodel/models/writer.py @@ -16,7 +16,7 @@ import torch import transformers from safetensors.torch import save_file -from transformers import AutoConfig, PreTrainedTokenizerFast, ProcessorMixin +from transformers import PreTrainedTokenizerFast, ProcessorMixin from transformers.modeling_utils import no_init_weights from transformers.models.auto.tokenization_auto import get_tokenizer_config from transformers.utils.generic import ContextManagers @@ -41,7 +41,7 @@ MIN_VERSION_WITH_V2, ) from ..utils.backend import BACKEND -from ..utils.hf import sanitize_generation_config_file +from ..utils.hf import safe_auto_config_from_pretrained, sanitize_generation_config_file from ..utils.logger import setup_logger from ..utils.model import ( convert_gptq_v2_to_v1_format, @@ -61,6 +61,32 @@ log = setup_logger() +_ATTN_IMPLEMENTATION_KEYS = ("attn_implementation", "_attn_implementation") + + +def _stash_and_clear_attention_attr(config_like): + """Temporarily remove transient attention implementation flags before serialization.""" + if config_like is None: + return {} + 
store = getattr(config_like, "__dict__", None) + if not isinstance(store, dict): + return {} + removed = {} + for key in _ATTN_IMPLEMENTATION_KEYS: + if key in store: + removed[key] = store.pop(key) + return removed + + +def _restore_attention_attr(config_like, removed_attrs): + if not removed_attrs: + return + store = getattr(config_like, "__dict__", None) + if not isinstance(store, dict): + return + for key, value in removed_attrs.items(): + store[key] = value + PROCESS_LOG_NAME = "process" PROCESS_LOG_LAYER = "layer" PROCESS_LOG_MODULE = "module" @@ -248,9 +274,17 @@ def save_quantized( config.quantization_config = quantize_config.to_dict() self.model.config = config + removed_config_attn = _stash_and_clear_attention_attr(config) + generation_config = getattr(self.model, "generation_config", None) + removed_generation_attn = _stash_and_clear_attention_attr(generation_config) + # Save model config, including generation_config # Use empty state_dict hack to bypass saving weights - self.model.save_pretrained(save_dir, state_dict={}, is_main_process=True) + try: + self.model.save_pretrained(save_dir, state_dict={}, is_main_process=True) + finally: + _restore_attention_attr(config, removed_config_attn) + _restore_attention_attr(generation_config, removed_generation_attn) gen_config_path = os.path.join(save_dir, "generation_config.json") if sanitize_generation_config_file(gen_config_path): @@ -422,7 +456,7 @@ def _normalize_metadata(meta: Optional[Dict[str, Any]]) -> Dict[str, str]: def get_model_with_quantize(self, qcfg, model_id_or_path): - config = AutoConfig.from_pretrained( + config = safe_auto_config_from_pretrained( model_id_or_path, trust_remote_code=True, ) diff --git a/gptqmodel/nn_modules/qlinear/bitblas.py b/gptqmodel/nn_modules/qlinear/bitblas.py index 31fde7924..88df78026 100644 --- a/gptqmodel/nn_modules/qlinear/bitblas.py +++ b/gptqmodel/nn_modules/qlinear/bitblas.py @@ -140,7 +140,7 @@ def __post_init__(self) -> None: raise ValueError("weight_bits must divide 32 for GPTQ packing") self.pack_factor = 32 // self.weight_bits self.torch_storage_dtype = getattr(torch, self.storage_dtype) - self.torch_dtype = torch.float16 + self.dtype = torch.float16 @property def with_zeros(self) -> bool: diff --git a/gptqmodel/utils/device.py b/gptqmodel/utils/device.py index 744ce3d08..97467de23 100644 --- a/gptqmodel/utils/device.py +++ b/gptqmodel/utils/device.py @@ -4,7 +4,7 @@ # Contact: qubitium@modelcloud.ai, x.com/qubitium from __future__ import annotations -from typing import Optional, Union +from typing import Any, Optional, Union import torch from device_smi import Device @@ -23,19 +23,66 @@ def get_cpu_usage_memory(): smi = Device(CPU) return smi.memory_used() / 1024 / 1024 / 1024 #GB -def get_device(obj: torch.Tensor | nn.Module) -> torch.device: +def _coerce_device(value: Any) -> Optional[torch.device]: + if isinstance(value, torch.device): + return value + if isinstance(value, str): + try: + return torch.device(value) + except (TypeError, ValueError): # pragma: no cover - defensive + return None + return None + + +def _extract_tensor_device(obj: Any) -> Optional[torch.device]: + # Prefer explicit attribute before reflective search. + direct = _coerce_device(getattr(obj, "device", None)) + if direct is not None: + return direct + + # Try common tensor holders (e.g. dataclasses) via __dict__. 
+    values = []
+    store = getattr(obj, "__dict__", None)
+    if isinstance(store, dict):
+        values.extend(store.values())
+    else:
+        # Fallback to dir() to support __slots__ or C++ backed objects without __dict__.
+        for attr in dir(obj):
+            if attr.startswith("_"):
+                continue
+            try:
+                values.append(getattr(obj, attr))
+            except AttributeError:  # pragma: no cover - dynamic attr access
+                continue
+            except Exception:  # pragma: no cover - defensive against @property side effects
+                continue
+
+    for value in values:
+        if isinstance(value, torch.Tensor):
+            return value.device
+
+    return None
+
+
+def get_device(obj: Any) -> torch.device:
     if isinstance(obj, torch.Tensor):
         return obj.device

-    params = list(obj.parameters())
-    buffers = list(obj.buffers())
-    if len(params) > 0:
-        return params[0].device
-    elif len(buffers) > 0:
-        return buffers[0].device
-    else:
+    if isinstance(obj, nn.Module):
+        params = list(obj.parameters())
+        buffers = list(obj.buffers())
+        if params:
+            return params[0].device
+        if buffers:
+            return buffers[0].device
         return CPU

+    extracted = _extract_tensor_device(obj)
+    if extracted is not None:
+        return extracted
+
+    return CPU
+
 def get_device_new(
     obj: torch.Tensor | nn.Module,
     recursive: bool = False,
diff --git a/gptqmodel/utils/hf.py b/gptqmodel/utils/hf.py
index 434b8fea9..9c087a39f 100644
--- a/gptqmodel/utils/hf.py
+++ b/gptqmodel/utils/hf.py
@@ -4,18 +4,161 @@
 # Contact: qubitium@modelcloud.ai, x.com/qubitium

 import json
+import os
+import re
 from typing import Any, Optional

 import torch
 from accelerate import init_empty_weights
-from transformers import GenerationConfig, PreTrainedModel
+from transformers import AutoConfig, GenerationConfig, PreTrainedModel
+from transformers.models.auto.configuration_auto import CONFIG_MAPPING

 from ..utils.logger import setup_logger

 log = setup_logger()

-GENERATION_SAMPLING_FIELDS = ("temperature", "top_p")
+# Parameters that enable sampling-style generation when non-default.
+GENERATION_SAMPLING_FIELDS = (
+    "temperature",
+    "top_p",
+    "top_k",
+    "min_p",
+    "typical_p",
+    "epsilon_cutoff",
+    "eta_cutoff",
+    "penalty_alpha",
+)
+
+# Default values enforced by Transformers when `do_sample` is disabled.
+_GREEDY_DEFAULTS = { + "temperature": 1.0, + "top_p": 1.0, + "min_p": None, + "typical_p": 1.0, + "top_k": 50, + "epsilon_cutoff": 0.0, + "eta_cutoff": 0.0, + # Contrastive search uses penalty_alpha; keep unset when not sampling + "penalty_alpha": None, +} + +_DUPLICATE_CONFIG_PATTERN = re.compile(r"'([^']+)' is already used by a Transformers config") + + +def _ensure_pretrained_model_defaults(): + fallback_attrs = { + "is_parallelizable": False, + "_no_split_modules": (), + "_keep_in_fp32_modules": (), + "_skip_keys_device_placement": (), + } + for name, value in fallback_attrs.items(): + if not hasattr(PreTrainedModel, name): + setattr(PreTrainedModel, name, value) + + +_ensure_pretrained_model_defaults() + + +def _drop_generation_field(cfg: GenerationConfig, field: str) -> bool: + """Remove a generation field from the config and return True if mutated.""" + changed = False + # GenerationConfig keeps its data in both __dict__ and _internal_dict + for attr in ("__dict__", "_internal_dict"): + container = getattr(cfg, attr, None) + if isinstance(container, dict) and field in container: + container.pop(field, None) + changed = True + + if hasattr(cfg, field): + try: + delattr(cfg, field) + except AttributeError: + pass + + return changed + + +def _has_sampling_params(cfg: GenerationConfig) -> bool: + for field in GENERATION_SAMPLING_FIELDS: + try: + value = getattr(cfg, field) + except AttributeError: + value = None + if value is None: + continue + if isinstance(value, (int, float)): + if field == "temperature" and value != 1.0: + return True + if field in {"top_p", "min_p", "typical_p"} and value < 1.0: + return True + if field == "top_k" and value not in (None, 50): + return True + if field in {"epsilon_cutoff", "eta_cutoff"} and value != 0.0: + return True + if field == "penalty_alpha" and value is not None: + return True + else: + return True + return False + + +def _set_generation_field(cfg: GenerationConfig, field: str, value: Any) -> bool: + """Ensure a generation field matches the requested value, syncing internal dict as well.""" + current = getattr(cfg, field, None) + changed = False + + if value is None: + changed = _drop_generation_field(cfg, field) or changed + else: + if current != value: + setattr(cfg, field, value) + changed = True + internal = getattr(cfg, "_internal_dict", None) + if isinstance(internal, dict) and internal.get(field) != value: + internal[field] = value + changed = True + + return changed + + +def _deregister_auto_config(model_type: str) -> bool: + removed = False + for attr in ("_mapping", "_extra_content", "_modules"): + container = getattr(CONFIG_MAPPING, attr, None) + if isinstance(container, dict) and model_type in container: + container.pop(model_type, None) + removed = True + return removed + + +def safe_auto_config_from_pretrained(*args, **kwargs): + trust_flag = kwargs.get("trust_remote_code") + prev_env = None + if trust_flag: + prev_env = os.environ.get("TRANSFORMERS_TRUST_REMOTE_CODE") + os.environ["TRANSFORMERS_TRUST_REMOTE_CODE"] = "1" + try: + return AutoConfig.from_pretrained(*args, **kwargs) + except ValueError as err: + message = str(err) + if "already used by a Transformers config" not in message: + raise + match = _DUPLICATE_CONFIG_PATTERN.search(message) + if not match: + raise + model_type = match.group(1) + if not _deregister_auto_config(model_type): + raise + log.info("Transformers: cleared cached config registration for `%s` and retrying load.", model_type) + return AutoConfig.from_pretrained(*args, **kwargs) + finally: + if 
trust_flag: + if prev_env is None: + os.environ.pop("TRANSFORMERS_TRUST_REMOTE_CODE", None) + else: + os.environ["TRANSFORMERS_TRUST_REMOTE_CODE"] = prev_env def _sanitize_generation_config(cfg: GenerationConfig, *, drop_sampling_fields: bool = False) -> bool: @@ -23,16 +166,35 @@ def _sanitize_generation_config(cfg: GenerationConfig, *, drop_sampling_fields: if cfg is None: return changed - if getattr(cfg, "do_sample", None) is not True: - cfg.do_sample = True + if not hasattr(cfg, "do_sample") or getattr(cfg, "do_sample") is None: + cfg.do_sample = False changed = True - if drop_sampling_fields: - for field in GENERATION_SAMPLING_FIELDS: - if hasattr(cfg, field): - if getattr(cfg, field) is not None: + if getattr(cfg, "do_sample", False) is not True and not drop_sampling_fields: + if getattr(cfg, "do_sample") is not False: + cfg.do_sample = False + changed = True + for field, default in _GREEDY_DEFAULTS.items(): + if _set_generation_field(cfg, field, default): + changed = True + elif getattr(cfg, "do_sample", False) is not True and drop_sampling_fields: + if getattr(cfg, "do_sample") is not False: + cfg.do_sample = False + changed = True + for field, default in _GREEDY_DEFAULTS.items(): + if default is None: + if _drop_generation_field(cfg, field): + changed = True + else: + if _set_generation_field(cfg, field, default): changed = True - setattr(cfg, field, None) + elif drop_sampling_fields: + # Sampling configuration is intentionally preserved when `do_sample` remains enabled. + pass + elif getattr(cfg, "do_sample", False) and not _has_sampling_params(cfg): + cfg.do_sample = False + changed = True + log.info("Model: Auto-Fixed `generation_config` by disabling sampling due to missing sampling parameters.") return changed @@ -48,8 +210,8 @@ def _load_sanitized_generation_config(path: str) -> Optional[GenerationConfig]: if field in cleaned: cleaned.pop(field, None) removed = True - if cleaned.get("do_sample") is not True: - cleaned["do_sample"] = True + if "do_sample" not in cleaned: + cleaned["do_sample"] = False cfg = GenerationConfig.from_dict(cleaned, **kwargs) if removed: @@ -125,15 +287,22 @@ def sanitize_generation_config_file(path: str) -> bool: return False changed = False - for field in GENERATION_SAMPLING_FIELDS: - if field in data: - data.pop(field, None) - changed = True - - if data.get("do_sample") is not True: - data["do_sample"] = True + do_sample_value = data.get("do_sample") + if do_sample_value is None: + data["do_sample"] = False + do_sample_value = False changed = True + if do_sample_value is False: + for field, default in _GREEDY_DEFAULTS.items(): + if field in data: + if default is None: + data.pop(field, None) + changed = True + elif data[field] != default: + data[field] = default + changed = True + if changed: with open(path, "w", encoding="utf-8") as fp: json.dump(data, fp, indent=2) @@ -159,11 +328,22 @@ def build_shell_model( """ init_kwargs = model_init_kwargs.copy() - del init_kwargs["device_map"] - del init_kwargs["_fast_init"] + configured_dtype = init_kwargs.pop("dtype", None) + if dtype is None and configured_dtype is not None: + dtype = configured_dtype + elif dtype is not None and configured_dtype is not None and configured_dtype != dtype: + log.info("Shell model: overriding duplicate dtype argument from kwargs with explicit `dtype` parameter.") + init_kwargs.pop("device_map", None) + init_kwargs.pop("_fast_init", None) # All nn.Parameters and buffers are created # All nn.Parameters and buffers are created on 'meta' and initializers are skipped. 
+    if dtype is not None:
+        setattr(config, "dtype", dtype)
+        store = getattr(config, "__dict__", None)
+        if isinstance(store, dict) and store.get("torch_dtype") != dtype:
+            store.pop("torch_dtype", None)
+
     with init_empty_weights(include_buffers=True):
         shell = loader.from_config(
             config,
diff --git a/gptqmodel/utils/model.py b/gptqmodel/utils/model.py
index 5a847c8cb..3ef446fd2 100644
--- a/gptqmodel/utils/model.py
+++ b/gptqmodel/utils/model.py
@@ -120,7 +120,7 @@ def _dtype_string_to_torch(dtype_str: Optional[str], fallback: torch.dtype) -> t
 @dataclass(frozen=True)
 class OffloadTensorRef:
     path: str
-    torch_dtype: torch.dtype
+    dtype: torch.dtype
     shape: Tuple[int, ...]
     format: str  # 'dat' or 'safetensors'
     weight_name: Optional[str] = None
@@ -128,19 +128,19 @@ class OffloadTensorRef:

     @property
     def num_bytes(self) -> int:
-        return _torch_dtype_num_bytes(self.torch_dtype) * math.prod(self.shape or (1,))
+        return _torch_dtype_num_bytes(self.dtype) * math.prod(self.shape or (1,))


 @dataclass
 class TensorSource:
     name: str
-    torch_dtype: torch.dtype
+    dtype: torch.dtype
     shape: Tuple[int, ...]
     source: Union[torch.Tensor, OffloadTensorRef]

     @property
     def num_bytes(self) -> int:
-        return _torch_dtype_num_bytes(self.torch_dtype) * math.prod(self.shape or (1,))
+        return _torch_dtype_num_bytes(self.dtype) * math.prod(self.shape or (1,))


 def recurse_getattr(obj, attr: str):
     """
@@ -193,7 +193,15 @@ def move_to(obj: torch.Tensor | nn.Module, device: torch.device, dtype: torch.dt
         # cpu to non-cpu or non-cpu to non-cpu uses normal .to() api
         obj = obj.to(device=device, non_blocking=True)
     else:
-        obj = obj.to(device=device, dtype=dtype, non_blocking=False)
+        try:
+            obj = obj.to(device=device, dtype=dtype, non_blocking=False)
+        except NotImplementedError as err:
+            if isinstance(obj, nn.Module) and "Cannot copy out of meta tensor" in str(err):
+                obj = obj.to_empty(device=device)
+                if dtype is not None:
+                    obj = obj.to(dtype=dtype)
+            else:
+                raise

     return obj
@@ -833,6 +841,9 @@ def simple_dispatch_model(model, device_map):

     for n, d in device_map.items():
         m = get_module_by_name_suffix(model, n)
+        if m is None:
+            log.warning("Device map entry `%s` could not be resolved to a module; skipping hook.", n)
+            continue
         if d != "cpu":
             d = torch.device(d)
             hook = AlignDevicesHook(d, io_same_device=True, place_submodules=True)
@@ -1243,7 +1254,7 @@ def _resolve_offload_entry(
         offsets = tuple(int(x) for x in offsets)
         return OffloadTensorRef(
             path=os.path.abspath(path),
-            torch_dtype=resolved_dtype,
+            dtype=resolved_dtype,
             shape=shape,
             format="safetensors",
             weight_name=entry.get("weight_name", leaf),
@@ -1256,7 +1267,7 @@ def _resolve_offload_entry(

     return OffloadTensorRef(
         path=os.path.abspath(data_path),
-        torch_dtype=resolved_dtype,
+        dtype=resolved_dtype,
         shape=shape,
         format="dat",
         weight_name=None,
@@ -1286,7 +1297,7 @@ def _collect_state_dict_with_offload(model: nn.Module, offload_root: str) -> Dic
             )
         else:
             source = param
-        state_dict[name] = TensorSource(name=name, torch_dtype=param.dtype, shape=tuple(param.shape), source=source)
+        state_dict[name] = TensorSource(name=name, dtype=param.dtype, shape=tuple(param.shape), source=source)

     for name, buf in model.named_buffers():
         if name in state_dict:
             continue
@@ -1307,7 +1318,7 @@ def _collect_state_dict_with_offload(model: nn.Module, offload_root: str) -> Dic
             )
         else:
             source = buf
-        state_dict[name] = TensorSource(name=name, torch_dtype=buf.dtype, shape=tuple(buf.shape), source=source)
+        state_dict[name] = TensorSource(name=name, dtype=buf.dtype, shape=tuple(buf.shape), source=source)

     return state_dict
@@ -1324,11 +1335,11 @@ def get_state_dict_for_save(model: nn.Module, offload_root: Optional[str] = None
     else:
         state_dict = collections.OrderedDict()
         for name, param in model.named_parameters():
-            state_dict[name] = TensorSource(name=name, torch_dtype=param.dtype, shape=tuple(param.shape), source=param)
+            state_dict[name] = TensorSource(name=name, dtype=param.dtype, shape=tuple(param.shape), source=param)
         for name, buf in model.named_buffers():
             if name in state_dict:
                 continue
-            state_dict[name] = TensorSource(name=name, torch_dtype=buf.dtype, shape=tuple(buf.shape), source=buf)
+            state_dict[name] = TensorSource(name=name, dtype=buf.dtype, shape=tuple(buf.shape), source=buf)

     ptrs = collections.defaultdict(list)
     for name, entry in state_dict.items():
@@ -1422,7 +1433,7 @@ def _write_shard_file(path: str, entries: List[TensorSource], metadata: Dict[str
     offset = 0
     for entry in entries:
         header[entry.name] = {
-            "dtype": _torch_dtype_to_safetensors(entry.torch_dtype),
+            "dtype": _torch_dtype_to_safetensors(entry.dtype),
             "shape": list(entry.shape),
             "data_offsets": [offset, offset + entry.num_bytes],
         }
@@ -1448,11 +1459,11 @@ def _write_shard_file(path: str, entries: List[TensorSource], metadata: Dict[str
                 # print("offload tensor slow tensor read")
                 with safe_open(source.path, framework="pt", device="cpu") as handler:
                     tensor = handler.get_tensor(source.weight_name or entry.name)
-                tensor = tensor.to(source.torch_dtype)
-                _write_tensor_bytes(out, tensor, source.torch_dtype)
+                tensor = tensor.to(source.dtype)
+                _write_tensor_bytes(out, tensor, source.dtype)
             else:
                 tensor = source.detach()
-                _write_tensor_bytes(out, tensor, entry.torch_dtype)
+                _write_tensor_bytes(out, tensor, entry.dtype)
             del tensor

     file_size = out.tell()
diff --git a/gptqmodel/utils/offload.py b/gptqmodel/utils/offload.py
index a117dcd6d..37f7c765e 100644
--- a/gptqmodel/utils/offload.py
+++ b/gptqmodel/utils/offload.py
@@ -6,7 +6,7 @@
 import contextlib
 import os
 import shutil
-import sys
+from threading import Lock
 from typing import Iterable, List, Optional, Set, Tuple

 import accelerate
@@ -21,7 +21,6 @@
 from ..looper.named_module import NamedModule
 from .device import get_device
 from .torch import CPU, META
-from .safe import ThreadSafe


 # Patch fix thread unsafe accelerate.utils.modeling.clear_device_cache
@@ -63,11 +62,18 @@ def is_meta_module(m: nn.Module) -> bool:
             return True
     return False

-def offload_to_disk(module: List[str] | nn.Module, model: nn.Module, disk_path: str = "." ):
+_OFFLOAD_LOCK = Lock()
+
+
+def offload_to_disk(module: List[str] | nn.Module, model: nn.Module, disk_path: str = "."):
+    with _OFFLOAD_LOCK:
+        _offload_to_disk_impl(module=module, model=model, disk_path=disk_path)
+
+
+def _offload_to_disk_impl(module: List[str] | nn.Module, model: nn.Module, disk_path: str = "."):
     assert module is not None
     assert model is not None

-    #with _lock:
     if isinstance(module, List):
         for name in module:
             m = get_submodule(model, name)
@@ -96,10 +102,6 @@ def offload_to_disk(module: List[str] | nn.Module, model: nn.Module, disk_path:

     # print_module_tree(module)

-# Serialize accelerate's disk hook mutations across threads.
-_OFFLOAD_SAFE = ThreadSafe(sys.modules[__name__]) -offload_to_disk = _OFFLOAD_SAFE.offload_to_disk - def _offload_disk(module: nn.Module, name: str, disk_path: str = "."): if is_meta_module(module): # print(f"[skip] '{name}' is on meta; leaving as-is") diff --git a/gptqmodel/utils/sglang.py b/gptqmodel/utils/sglang.py index c48074844..3aeeb12b4 100644 --- a/gptqmodel/utils/sglang.py +++ b/gptqmodel/utils/sglang.py @@ -6,7 +6,8 @@ import multiprocessing as mp import torch -from transformers import AutoConfig + +from .hf import safe_auto_config_from_pretrained try: @@ -31,7 +32,7 @@ def load_model_by_sglang( **kwargs, ) sgl.set_default_backend(runtime) - hf_config = AutoConfig.from_pretrained( + hf_config = safe_auto_config_from_pretrained( model, trust_remote_code=trust_remote_code ) return runtime, hf_config diff --git a/gptqmodel/utils/threadx.py b/gptqmodel/utils/threadx.py index c0bdb6e60..df391e4ed 100644 --- a/gptqmodel/utils/threadx.py +++ b/gptqmodel/utils/threadx.py @@ -6,6 +6,7 @@ from __future__ import annotations import contextlib +import os import queue import threading import time @@ -344,12 +345,19 @@ def _run(self): may override via `_threadx_inference_mode`. """ _activate_thread_device(self.device) - while not self._stop.is_set(): + while True: is_task, fn, args, kwargs, fut = self._q.get() try: if not is_task: if DEBUG_ON: log.debug(f"{self.name}: received sentinel; exiting") break + if self._stop.is_set(): + # Pool is stopping; skip executing queued work to allow fast shutdown. + if DEBUG_ON: + log.debug(f"{self.name}: dropping task during shutdown; qsize={self._q.qsize()}") + self._on_task_finished(self.key) + fut.cancel() + continue if DEBUG_ON: log.debug(f"{self.name}: task begin; qsize={self._q.qsize()}") stream = kwargs.pop("_cuda_stream", None) @@ -385,6 +393,57 @@ def _run(self): if DEBUG_ON: log.debug(f"{self.name}: exited") +class _SyncWorker: + """Fallback worker that executes tasks synchronously when threads are unsafe.""" + + def __init__( + self, + *, + key: str, + device: torch.device, + rwlock: _RWLock, + on_task_finished: Callable[[str], None], + on_worker_exit: Callable[[str, "_SyncWorker"], None], + inference_mode: bool = False, + ) -> None: + self.key = key + self.device = device + self.rwlock = rwlock + self._on_task_finished = on_task_finished + self._on_worker_exit = on_worker_exit + self._inference_mode = inference_mode + self.name = f"DPWorker-{key}#sync" + + def submit(self, fn: Callable[..., Any], /, *args, _cuda_stream=None, **kwargs) -> Future: + fut = Future() + try: + stream = _cuda_stream + override_inference = kwargs.pop("_threadx_inference_mode", None) + use_inference = self._inference_mode if override_inference is None else bool(override_inference) + with ctx(self.rwlock.reader(), _device_ctx(self.device)): + inference_ctx = torch.inference_mode() if use_inference else contextlib.nullcontext() + with inference_ctx: + if stream is not None and self.device.type == "cuda": + with torch.cuda.stream(stream): + result = fn(*args, **kwargs) + else: + result = fn(*args, **kwargs) + self._on_task_finished(self.key) + if not fut.cancelled(): + fut.set_result(result) + except BaseException as exc: + self._on_task_finished(self.key) + if not fut.cancelled(): + fut.set_exception(exc) + return fut + + def stop(self) -> None: + self._on_worker_exit(self.key, self) + + def join(self) -> None: + return + + # --------------------------- Public Pool --------------------------- # - Builds workers per device with per-device RWLocks # - Tracks inflight counts 
(with condition vars) and completion counters @@ -432,6 +491,10 @@ def __init__( Unspecified devices default to 1 worker each. gc_debounce_seconds: short wait to coalesce multiple triggers. """ + # Default to threaded workers; allow explicit opt-in to synchronous mode for + # environments where background threads are prohibited. + self._sync_mode = os.environ.get("THREADX_FORCE_SYNC", "0") == "1" + if devices is None: discovered: List[torch.device] = [] if include_cuda and torch.cuda.is_available(): @@ -534,19 +597,29 @@ def __init__( # --------------- Worker management --------------- - def _spawn_worker(self, dev: torch.device, name: Optional[str] = None) -> _DeviceWorker: + def _spawn_worker(self, dev: torch.device, name: Optional[str] = None): """ Create and start a worker bound to the provided device. """ key = self._key(dev) - w = _DeviceWorker( - device=dev, - rwlock=self._locks[key], - on_task_finished=self._on_task_finished, - on_worker_exit=self._on_worker_exit, - name=name, - inference_mode=self._inference_mode, - ) + if self._sync_mode: + w = _SyncWorker( + key=key, + device=dev, + rwlock=self._locks[key], + on_task_finished=self._on_task_finished, + on_worker_exit=self._on_worker_exit, + inference_mode=self._inference_mode, + ) + else: + w = _DeviceWorker( + device=dev, + rwlock=self._locks[key], + on_task_finished=self._on_task_finished, + on_worker_exit=self._on_worker_exit, + name=name, + inference_mode=self._inference_mode, + ) return w def _on_worker_exit(self, key: str, worker: _DeviceWorker) -> None: diff --git a/pyproject.toml b/pyproject.toml index 9593ed91e..5419d2ab4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -64,7 +64,6 @@ test = [ ] quality = [ "ruff==0.13.0", - # "isort==6.0.1", ] vllm = [ "vllm>=0.10.2", diff --git a/setup.py b/setup.py index b7129226d..e33041681 100644 --- a/setup.py +++ b/setup.py @@ -666,7 +666,7 @@ def run(self): packages=find_packages(), extras_require={ "test": ["pytest>=8.2.2", "parameterized"], - "quality": ["ruff==0.13.0", "isort==6.0.1"], + "quality": ["ruff==0.13.0"], "vllm": ["vllm>=0.8.5", "flashinfer-python>=0.2.1"], "sglang": ["sglang[srt]>=0.4.6", "flashinfer-python>=0.2.1"], "bitblas": ["bitblas==0.0.1-dev13"], diff --git a/tests/models/model_test.py b/tests/models/model_test.py index 0626214be..213303a20 100644 --- a/tests/models/model_test.py +++ b/tests/models/model_test.py @@ -32,29 +32,37 @@ sys.path.insert(0, f"{str(Path(__file__).resolve().parent.parent)}/models") # noqa: E402 import contextlib # noqa: E402 import json # noqa: E402 +import math # noqa: E402 +import random # noqa: E402 import shutil # noqa: E402 import tempfile # noqa: E402 import textwrap # noqa: E402 import unittest # noqa: E402 +from collections import Counter # noqa: E402 from collections.abc import Iterable # noqa: E402 +import torch # noqa: E402 import torch.cuda # noqa: E402 -from datasets import load_dataset # noqa: E402 +from datasets import Dataset, concatenate_datasets, load_dataset # noqa: E402 from ovis.image_to_test_dataset import get_calib_dataset # noqa: E402 from transformers import AutoProcessor, AutoTokenizer # noqa: E402 +from transformers.utils import is_flash_attn_2_available # noqa: E402 from gptqmodel import BACKEND, GPTQModel # noqa: E402 +from gptqmodel.models import Ovis2_5QModel, OvisQModel # noqa: E402 from gptqmodel.nn_modules.qlinear import BaseQuantLinear # noqa: E402 from gptqmodel.quantization import FORMAT, METHOD # noqa: E402 from gptqmodel.quantization.config import QuantizeConfig # noqa: E402 from 
gptqmodel.utils.eval import EVAL # noqa: E402 from gptqmodel.utils.model import MODALITY # noqa: E402 +from gptqmodel.utils.perplexity import Perplexity # noqa: E402 from gptqmodel.utils.torch import torch_empty_cache # noqa: E402 RAND_SEED = 898 log = LogBar.shared() +ATTN_IMPLEMENTATION_KEY = "attn_implementation" class ModelTest(unittest.TestCase): DEBUG = True # enable extra debug output @@ -76,6 +84,20 @@ class ModelTest(unittest.TestCase): DATASET_SIZE = 256 DATASET_SORT = "asc" DELETE_QUANTIZED_MODEL = True + MAX_QUANT_LAYERS = None + + # post-quant validation controls + POST_QUANT_VALIDATION_BACKENDS = None # default preserves legacy double-backend check + + # calibration noise controls + CALIB_NOISE_PERCENT = 0.0 # share of calibration samples to synthesize + CALIB_NOISE_MODE = "none" # "unseen" # supported: none|random|unseen + CALIB_NOISE_RANDOM_SEED = 1337 + CALIB_NOISE_MIN_SEQ_LEN = 512 + CALIB_NOISE_MAX_SEQ_LEN = 1024 + CALIB_NOISE_GUARD_MAX_FREQ_RATIO = 1.3 + CALIB_NOISE_GUARD_MIN_TTR_FACTOR = 0.95 + CALIB_NOISE_GUARD_MAX_FRACTION = 0.1 KERNEL_QUANT = {} # kernel sets KERNEL_INFERENCE = {} # kernel sets @@ -94,7 +116,19 @@ class ModelTest(unittest.TestCase): SAVE_PATH = None # default is temp folder - USE_FLASH_ATTN = True + MOCK_QUANTIZATION = False + ATTN_IMPLEMENTATION = None # allow forcing a specific attention backend when needed; use "flash_attention_2" + + COMPUTE_PPL = False + PPL_DATASET_PATH = "wikitext" + PPL_DATASET_NAME = "wikitext-2-raw-v1" + PPL_DATASET_SPLIT = "test" + PPL_DATASET_COLUMN = "text" + PPL_CTX = 512 + PPL_BATCH = 512 + PPL_MAX_SAMPLES = 32 + PPL_FALLBACK_CTX = 192 + PPL_FALLBACK_MAX_CHUNKS_PER_SAMPLE = 4 INFERENCE_PROMPT = "The capital city of France is named" INFERENCE_RESULT_KEYWORDS = ["paris"] @@ -165,6 +199,33 @@ def generate_with_limit(self, model, tokenizer, prompt, max_new_tokens=512): ) return tokenizer.decode(generated[0], skip_special_tokens=True) + def _response_matches_keywords(self, response: str, keywords): + if not response: + return False + + normalized = response.lower() + + for keyword in keywords: + if not keyword: + continue + + needle = keyword.lower() + + if needle.isalpha(): + def _strip_other_alpha(text): + return "".join(ch for ch in text if ch.isalpha()) + + if needle in normalized: + return True + + if _strip_other_alpha(needle) in _strip_other_alpha(normalized): + return True + else: + if needle in normalized: + return True + + return False + def run_generic_inference_checks(self, model, tokenizer, backend): model.eval() log.info(f"Post-quant inference checks for backend `{backend.name}`") @@ -174,8 +235,7 @@ def run_generic_inference_checks(self, model, tokenizer, backend): keywords = item["keywords"] try: response = self.generate_with_limit(model, tokenizer, prompt) - normalized = response.lower() - matched = any(keyword.lower() in normalized for keyword in keywords) + matched = self._response_matches_keywords(response, keywords) results.append( { "prompt": prompt, @@ -208,6 +268,7 @@ def run_generic_inference_checks(self, model, tokenizer, backend): def run_arc_challenge_eval(self, model, backend, trust_remote_code=False): previous_backend = self.LOAD_BACKEND self.LOAD_BACKEND = backend + self._ensure_model_attributes(model) try: task_results = self.lm_eval( model=model, @@ -216,14 +277,42 @@ def run_arc_challenge_eval(self, model, backend, trust_remote_code=False): delete_quantized_model=False, ) log.info(f"[{backend.name}] ARC summary: {task_results}") + except AttributeError as exc: + log.warning( + "Skipping 
ARC eval for backend %s due to attribute error: %s", + backend.name, + exc, + ) + task_results = {} finally: self.LOAD_BACKEND = previous_backend return task_results + @staticmethod + def _ensure_model_attributes(model): + inner_model = getattr(model, "model", None) + if not hasattr(model, "device") and inner_model is not None: + try: + model.device = next(inner_model.parameters()).device + except StopIteration: + model.device = torch.device("cpu") + if not hasattr(model, "config") and inner_model is not None: + setattr(model, "config", getattr(inner_model, "config", None)) + + def get_post_quant_validation_backends(self): + configured = getattr(self, "POST_QUANT_VALIDATION_BACKENDS", None) + if configured: + return tuple(configured) + + if self.FORMAT is FORMAT.GPTQ: + return (BACKEND.MARLIN, BACKEND.TORCH) + return (BACKEND.MARLIN, BACKEND.GEMM) + def perform_post_quant_validation(self, model_path, trust_remote_code=False): inference_records = {} arc_records = {} - compare_backends = (BACKEND.MARLIN, BACKEND.TORCH) if self.FORMAT is FORMAT.GPTQ else (BACKEND.MARLIN, BACKEND.GEMM) + compare_backends = self.get_post_quant_validation_backends() + executed_backends = [] for backend in compare_backends: log.info(f"Loading post-quant model with backend `{backend.name}`") model = self.loadQuantModel( @@ -238,8 +327,9 @@ def perform_post_quant_validation(self, model_path, trust_remote_code=False): finally: del model torch_empty_cache() - self.render_inference_summary(inference_records) - self.render_arc_summary(arc_records) + executed_backends.append(backend) + self.render_inference_summary(inference_records, executed_backends) + self.render_arc_summary(arc_records, executed_backends) @staticmethod def _human_size(num_bytes: int) -> str: @@ -321,10 +411,13 @@ def _colorize(text, matched): reset = "\033[0m" return f"{color}{text}{reset}" - def render_inference_summary(self, inference_records): + def render_inference_summary(self, inference_records, backends_order=None): if not inference_records: return - ordered_backends = [backend for backend in (BACKEND.MARLIN, BACKEND.TORCH) if backend in inference_records] + if backends_order: + ordered_backends = [backend for backend in backends_order if backend in inference_records] + else: + ordered_backends = list(inference_records.keys()) if not ordered_backends: return @@ -382,10 +475,13 @@ def _format_inference_entry(self, entry): cell = f"{status} | {snippet}" if snippet else status return self._colorize(cell, matched) - def render_arc_summary(self, arc_records): + def render_arc_summary(self, arc_records, backends_order=None): if not arc_records: return - ordered_backends = [backend for backend in (BACKEND.MARLIN, BACKEND.TORCH) if backend in arc_records] + if backends_order: + ordered_backends = [backend for backend in backends_order if backend in arc_records] + else: + ordered_backends = list(arc_records.keys()) if not ordered_backends: return @@ -406,10 +502,10 @@ def render_arc_summary(self, arc_records): if value is None: row.append(self._colorize("N/A", False)) continue - if backend == BACKEND.TORCH: + if backend == BACKEND.TORCH or reference_value is None: row.append(self._colorize(f"{value:.4f}", True)) else: - matched = reference_value is not None and abs(value - reference_value) <= tolerance + matched = abs(value - reference_value) <= tolerance row.append(self._colorize(f"{value:.4f}", matched)) table_rows.append(row) @@ -429,8 +525,9 @@ def load_dataset(cls, tokenizer=None, rows: int = 0): dataset = cls._load_calibration_parquet() if rows 
> 0: - return dataset.select(range(min(rows, len(dataset)))) - return dataset + dataset = dataset.select(range(min(rows, len(dataset)))) + + return cls._apply_calibration_noise(dataset, tokenizer) @staticmethod def _load_calibration_parquet(): @@ -493,6 +590,512 @@ def select(self, indices): return self.__class__(selected) + @classmethod + def _apply_calibration_noise(cls, dataset, tokenizer): + mode = (getattr(cls, "CALIB_NOISE_MODE", "none") or "none").lower() + share = float(getattr(cls, "CALIB_NOISE_PERCENT", 0.0) or 0.0) + + cls._noise_summary = { + "mode": mode, + "share": share, + "requested": 0, + "generated": 0, + "applied": False, + "reason": "disabled", + } + + if dataset is None or tokenizer is None or share <= 0.0 or mode == "none": + return dataset + + records = cls._materialize_records(dataset) + if not records: + cls._noise_summary["reason"] = "empty_dataset" + return dataset + + requested = max(1, int(len(records) * share)) + cls._noise_summary.update({ + "requested": requested, + "reason": "generated", + }) + + stats = cls._collect_token_stats(records, tokenizer) + if stats["total_tokens"] == 0: + cls._noise_summary["reason"] = "no_tokens" + return dataset + + noise_records, noise_token_sequences = cls._build_noise_records( + tokenizer=tokenizer, + stats=stats, + sample_count=requested, + mode=mode, + base_records=records, + ) + + if not noise_records: + cls._noise_summary["reason"] = "no_samples" + return dataset + + if not cls._passes_noise_guard(stats, noise_token_sequences): + cls._noise_summary["reason"] = "guard_block" + return dataset + + merged = cls._merge_noise(dataset, noise_records, base_records=records) + cls._noise_summary.update({ + "generated": len(noise_records), + "applied": True, + "reason": "ok", + }) + log.info( + "Injected %s synthetic calibration samples (mode=%s, avg_len=%s)", + len(noise_records), + mode, + stats["avg_length"], + ) + return merged + + @staticmethod + def _materialize_records(dataset): + if dataset is None: + return [] + if hasattr(dataset, "to_list"): + try: + return dataset.to_list() + except TypeError: + pass + try: + return [dataset[idx] for idx in range(len(dataset))] + except Exception: # pragma: no cover - defensive fallback + return list(dataset) + + @staticmethod + def _extract_text(record): + if isinstance(record, dict): + if record.get("text"): + return record["text"] + if record.get("messages"): + parts = [msg.get("content", "") for msg in record["messages"]] + return "\n".join(part for part in parts if part) + return "" + + @classmethod + def _collect_token_stats(cls, records, tokenizer): + counts = Counter() + total_tokens = 0 + for record in records: + text = cls._extract_text(record) + if not text: + continue + encoded = tokenizer(text, add_special_tokens=False).get("input_ids", []) + counts.update(encoded) + total_tokens += len(encoded) + + unique_tokens = len(counts) + avg_length = max(1, round(total_tokens / max(len(records), 1))) + type_token_ratio = (unique_tokens / total_tokens) if total_tokens else 0.0 + max_freq = max(counts.values()) if counts else 0 + + return { + "counts": counts, + "total_tokens": total_tokens, + "unique_tokens": unique_tokens, + "avg_length": avg_length, + "type_token_ratio": type_token_ratio, + "max_freq": max_freq, + } + + @staticmethod + def _get_special_token_id_set(tokenizer): + special_ids = set() + + def _add(value): + if value is None: + return + if isinstance(value, int): + special_ids.add(int(value)) + return + if isinstance(value, str): + converted = 
tokenizer.convert_tokens_to_ids(value) + if converted is None: + return + if isinstance(converted, list): + for item in converted: + if item is not None: + special_ids.add(int(item)) + else: + special_ids.add(int(converted)) + return + if isinstance(value, (list, tuple, set)): + for item in value: + _add(item) + return + if isinstance(value, dict): + for item in value.values(): + _add(item) + + attr_names = ( + "all_special_ids", + "additional_special_tokens_ids", + "special_tokens_map", + "special_tokens_map_extended", + "bos_token_id", + "eos_token_id", + "pad_token_id", + "sep_token_id", + "cls_token_id", + "mask_token_id", + "unk_token_id", + ) + + for name in attr_names: + _add(getattr(tokenizer, name, None)) + + return {int(token_id) for token_id in special_ids if token_id is not None and token_id >= 0} + + @classmethod + def _filter_special_token_ids(cls, token_ids, tokenizer): + if not token_ids: + return [] + special_ids = cls._get_special_token_id_set(tokenizer) + return [tok for tok in token_ids if tok not in special_ids and tok is not None] + + @classmethod + def _build_noise_records(cls, tokenizer, stats, sample_count, mode, base_records): + if sample_count <= 0: + return [], [] + + if mode == "structured": + return cls._build_structured_noise_records( + tokenizer=tokenizer, + base_records=base_records, + sample_count=sample_count, + stats=stats, + ) + + rng = random.Random(cls.CALIB_NOISE_RANDOM_SEED) + + try: + vocab_values = list(tokenizer.get_vocab().values()) + except Exception: # pragma: no cover - tokenizer fallback + vocab_values = list(range(getattr(tokenizer, "vocab_size", 0))) + + vocab_ids = cls._filter_special_token_ids(vocab_values, tokenizer) + if not vocab_ids: + return [], [] + + existing_ids = set(stats["counts"].keys()) + if mode == "unseen": + vocab_ids = [tok for tok in vocab_ids if tok not in existing_ids] + if not vocab_ids: + log.warning("No unseen tokens available for noise generation; skipping") + return [], [] + + seq_min = max(1, int(getattr(cls, "CALIB_NOISE_MIN_SEQ_LEN", 32))) + seq_max = max(seq_min, int(getattr(cls, "CALIB_NOISE_MAX_SEQ_LEN", 256))) + avg_target = min(seq_max, max(seq_min, stats["avg_length"])) + std_dev = max(1, int(avg_target * 0.2)) + + records = [] + token_sequences = [] + max_iterations = max(sample_count * 3, 32) + + for _ in range(max_iterations): + if len(records) >= sample_count: + break + + length = int(rng.gauss(avg_target, std_dev)) + length = max(seq_min, min(seq_max, length)) + if length <= 1: + continue + + if mode == "unseen": + length = min(length, len(vocab_ids)) + if length <= 1: + continue + token_ids = rng.sample(vocab_ids, k=length) + else: + token_ids = rng.choices(vocab_ids, k=length) + + split = max(1, min(length - 1, length // 2)) + prompt_tokens = token_ids[:split] + completion_tokens = token_ids[split:] + + prompt_text = tokenizer.decode(prompt_tokens, skip_special_tokens=True).strip() + completion_text = tokenizer.decode(completion_tokens, skip_special_tokens=True).strip() + + if not prompt_text: + prompt_text = " ".join(tokenizer.convert_ids_to_tokens(prompt_tokens)) + if not completion_text: + completion_text = " ".join(tokenizer.convert_ids_to_tokens(completion_tokens)) + + combined_text = ( + "### Instruction:\n" + + prompt_text + + "\n\n### Response:\n" + + completion_text + ).strip() + + encoded = tokenizer(combined_text, add_special_tokens=False) + seq_ids = encoded.get("input_ids", []) + if not seq_ids: + continue + + records.append( + { + "text": combined_text, + "messages": [ + {"role": 
"user", "content": prompt_text}, + {"role": "assistant", "content": completion_text}, + ], + } + ) + token_sequences.append(seq_ids) + + return records, token_sequences + + @classmethod + def _build_structured_noise_records(cls, tokenizer, base_records, sample_count, stats): + if not base_records: + log.warning("Structured noise requested but no base records; skipping") + return [], [] + + rng = random.Random(cls.CALIB_NOISE_RANDOM_SEED + 17) + replacement_pool = cls._filter_special_token_ids(list(stats["counts"].keys()), tokenizer) + if not replacement_pool: + try: + replacement_pool = list(tokenizer.get_vocab().values()) + except Exception: # pragma: no cover - tokenizer fallback + replacement_pool = list(range(getattr(tokenizer, "vocab_size", 0))) + replacement_pool = cls._filter_special_token_ids(replacement_pool, tokenizer) + + records = [] + token_sequences = [] + max_attempts = max(sample_count * 6, 48) + + for _ in range(max_attempts): + if len(records) >= sample_count: + break + + base_record = rng.choice(base_records) + user_text, assistant_text = cls._extract_message_pair(base_record) + if not user_text and not assistant_text: + base_text = cls._extract_text(base_record) + user_text, assistant_text = cls._split_instruction_response(base_text) + + if not user_text and not assistant_text: + continue + + pert_user = cls._perturb_text(user_text, tokenizer, rng, replacement_pool, stats) + pert_assistant = cls._perturb_text(assistant_text, tokenizer, rng, replacement_pool, stats) + + pert_user = pert_user or user_text + pert_assistant = pert_assistant or assistant_text or pert_user + + if not pert_user or not pert_assistant: + continue + + combined_text = ( + "### Instruction:\n" + + pert_user.strip() + + "\n\n### Response:\n" + + pert_assistant.strip() + ).strip() + + encoded = tokenizer(combined_text, add_special_tokens=False) + seq_ids = encoded.get("input_ids", []) + if len(seq_ids) < max(4, cls.CALIB_NOISE_MIN_SEQ_LEN // 2): + continue + + records.append( + { + "text": combined_text, + "messages": [ + {"role": "user", "content": pert_user.strip()}, + {"role": "assistant", "content": pert_assistant.strip()}, + ], + } + ) + token_sequences.append(seq_ids) + + return records, token_sequences + + @staticmethod + def _extract_message_pair(record): + if not isinstance(record, dict): + return "", "" + + messages = record.get("messages") + if not isinstance(messages, list): + return "", "" + + user_text = "" + assistant_text = "" + for message in messages: + if not isinstance(message, dict): + continue + role = message.get("role") + content = (message.get("content") or "").strip() + if not content: + continue + if role == "user" and not user_text: + user_text = content + elif role == "assistant" and not assistant_text: + assistant_text = content + if user_text and assistant_text: + break + + return user_text, assistant_text + + @staticmethod + def _split_instruction_response(text): + if not text: + return "", "" + + instruction = "" + response = "" + + if "### Response" in text: + parts = text.split("### Response", 1) + instruction = parts[0].replace("### Instruction:", "").strip() + response = parts[1].replace(":", "", 1).strip() + elif "Response:" in text: + parts = text.split("Response:", 1) + instruction = parts[0].replace("Instruction:", "").strip() + response = parts[1].strip() + else: + split_parts = text.split("\n\n", 1) + if len(split_parts) == 2: + instruction, response = split_parts[0].strip(), split_parts[1].strip() + else: + mid = len(text) // 2 + instruction, response = 
text[:mid].strip(), text[mid:].strip() + + return instruction, response + + @classmethod + def _perturb_text(cls, text, tokenizer, rng, replacement_pool, stats): + if not text: + return "" + + encoded = tokenizer(text, add_special_tokens=False) + token_ids = list(encoded.get("input_ids", [])) + if len(token_ids) <= 2: + return text.strip() + + operations = ["shuffle", "drop", "replace"] + operation = rng.choice(operations) + tokens = list(token_ids) + + if operation == "shuffle" and len(tokens) > 8: + chunk_size = max(1, len(tokens) // rng.randint(3, 6)) + segments = [tokens[i:i + chunk_size] for i in range(0, len(tokens), chunk_size)] + rng.shuffle(segments) + tokens = [tok for segment in segments for tok in segment] + elif operation == "drop" and len(tokens) > 6: + max_drop = max(1, min(len(tokens) // 4, 24)) + span = rng.randint(1, max_drop) + if len(tokens) > span: + start = rng.randint(0, len(tokens) - span) + del tokens[start:start + span] + else: # replace + span = max(1, min(len(tokens) // 5, 16)) + if len(tokens) > span: + start = rng.randint(0, len(tokens) - span) + replacement = cls._sample_replacement_tokens( + rng, + replacement_pool, + span, + stats, + tokenizer, + ) + if replacement: + tokens[start:start + span] = replacement + + if not tokens: + tokens = token_ids + + new_text = tokenizer.decode(tokens, skip_special_tokens=True).strip() + return new_text or text.strip() + + @classmethod + def _sample_replacement_tokens(cls, rng, candidates, length, stats, tokenizer): + if not candidates: + candidates = list(stats["counts"].keys()) + candidates = cls._filter_special_token_ids(candidates, tokenizer) + if not candidates: + candidates = list(range(1, max(512, length * 4))) + candidates = cls._filter_special_token_ids(candidates, tokenizer) + if not candidates: + return [] + return [rng.choice(candidates) for _ in range(length)] + + @classmethod + def _passes_noise_guard(cls, stats, noise_sequences): + if not noise_sequences: + return False + + base_counts = stats["counts"].copy() + base_total = stats["total_tokens"] + base_max = stats["max_freq"] + base_ttr = stats["type_token_ratio"] + + noise_counts = Counter() + noise_token_total = 0 + for seq in noise_sequences: + noise_counts.update(seq) + noise_token_total += len(seq) + + if noise_token_total == 0: + return False + + combined_counts = base_counts + noise_counts + combined_max = max(combined_counts.values()) if combined_counts else 0 + if base_max == 0: + base_max = combined_max or 1 + + max_freq_ratio = combined_max / max(base_max, 1) + if max_freq_ratio > getattr(cls, "CALIB_NOISE_GUARD_MAX_FREQ_RATIO", 1.3): + log.info( + "Noise guard triggered by max frequency ratio %.4f", max_freq_ratio + ) + return False + + combined_total = base_total + noise_token_total + combined_unique = len(combined_counts) + combined_ttr = (combined_unique / combined_total) if combined_total else 0.0 + min_ttr = base_ttr * getattr(cls, "CALIB_NOISE_GUARD_MIN_TTR_FACTOR", 0.95) + if base_ttr and combined_ttr < min_ttr: + log.info( + "Noise guard triggered by type-token ratio %.4f < %.4f", + combined_ttr, + min_ttr, + ) + return False + + noise_fraction = noise_token_total / combined_total if combined_total else 1.0 + if noise_fraction > getattr(cls, "CALIB_NOISE_GUARD_MAX_FRACTION", 0.1): + log.info( + "Noise guard triggered by noise fraction %.4f", noise_fraction + ) + return False + + return True + + @classmethod + def _merge_noise(cls, dataset, noise_records, base_records=None): + if dataset is None: + return 
cls._LocalCalibrationDataset(noise_records) + + try: + noise_dataset = Dataset.from_list(noise_records) + return concatenate_datasets([dataset, noise_dataset]) + except Exception as exc: # pragma: no cover - fall back to python dataset + log.warning("Falling back to in-memory dataset for noise merge: %s", exc) + if base_records is None: + base_records = cls._materialize_records(dataset) + combined = list(base_records) + list(noise_records) + return cls._LocalCalibrationDataset(combined) + + def check_kernel(self, model, expected_kernels): modules = {module.__class__ for _, module in model.named_modules() if isinstance(module, BaseQuantLinear)} print(f"modules in model: {modules}") @@ -511,16 +1114,19 @@ def quantModel(self, model_id_or_path, trust_remote_code=False, dtype="auto", ne sym=self.SYM, v2=self.V2, adapter=self.EORA, + mock_quantization=self.MOCK_QUANTIZATION, + offload_to_disk=getattr(self, "OFFLOAD_TO_DISK", True), ) log.info(f"Quant config: {quantize_config}") log.info(f"Quant batch_size: {batch_size}") args = kwargs if kwargs else {} - - if self.USE_FLASH_ATTN: - args["attn_implementation"] = "flash_attention_2" - + if ( + self.ATTN_IMPLEMENTATION is not None + and ATTN_IMPLEMENTATION_KEY not in args + ): + args[ATTN_IMPLEMENTATION_KEY] = self.ATTN_IMPLEMENTATION log.info(f"args: {args}") model = GPTQModel.load( @@ -538,22 +1144,51 @@ def quantModel(self, model_id_or_path, trust_remote_code=False, dtype="auto", ne is_image_to_text_model = MODALITY.IMAGE_TO_TEXT in model.modality calibration_dataset = get_calib_dataset(model) if is_image_to_text_model else self.load_dataset(tokenizer, self.DATASET_SIZE) + noise_summary = getattr(self, "_noise_summary", None) + if noise_summary and noise_summary.get("mode") != "none": + log.info(f"Calibration noise summary: {noise_summary}") + # mpt model need - if not model.config.pad_token_id: - model.config.pad_token_id = tokenizer.pad_token_id or 0 - if not model.config.eos_token_id: - model.config.eos_token_id = tokenizer.eos_token_id or 0 + model_cfg = getattr(model, "config", None) + if model_cfg is None: + inner = getattr(model, "model", None) + if inner is not None: + model_cfg = getattr(inner, "config", None) + if model_cfg is None: + model_cfg = getattr(model, "model_config", None) + if model_cfg is not None: + if not getattr(model_cfg, "pad_token_id", None): + model_cfg.pad_token_id = tokenizer.pad_token_id or 0 + if not getattr(model_cfg, "eos_token_id", None): + model_cfg.eos_token_id = tokenizer.eos_token_id or 0 is_quantized = model.quantized # ovis cannot load processor - is_ovis_model = model.__class__.__name__ == "OvisGPTQ" + is_ovis_model = isinstance(model, (OvisQModel, Ovis2_5QModel)) need_create_processor = is_image_to_text_model and not is_ovis_model if not is_quantized: - model.quantize(calibration_dataset, calibration_sort=self.DATASET_SORT, backend=self.QUANT_BACKEND, batch_size=batch_size) + prev_max_layers = os.environ.get("GPTQMODEL_MAX_QUANT_LAYERS") + max_layers_limit = getattr(self, "MAX_QUANT_LAYERS", None) + if max_layers_limit is not None: + os.environ["GPTQMODEL_MAX_QUANT_LAYERS"] = str(max_layers_limit) + try: + model.quantize(calibration_dataset, calibration_sort=self.DATASET_SORT, backend=self.QUANT_BACKEND, batch_size=batch_size) + finally: + if max_layers_limit is not None: + if prev_max_layers is None: + os.environ.pop("GPTQMODEL_MAX_QUANT_LAYERS", None) + else: + os.environ["GPTQMODEL_MAX_QUANT_LAYERS"] = prev_max_layers self.check_kernel(model, self.KERNEL_QUANT) + if self.MOCK_QUANTIZATION: + if 
need_create_processor:
+                processor = AutoProcessor.from_pretrained(model_id_or_path)
+                return model, tokenizer, processor
+            return model, tokenizer
+
         # TODO: make into shared method
         with (contextlib.nullcontext(self.SAVE_PATH) if self.SAVE_PATH else contextlib.nullcontext(tempfile.mkdtemp()) if need_eval else tempfile.TemporaryDirectory()) as path:
             os.makedirs(path, exist_ok=True)
@@ -588,12 +1223,21 @@ def quantModel(self, model_id_or_path, trust_remote_code=False, dtype="auto", ne
     def loadQuantModel(self, model_id_or_path, trust_remote_code=False, tokenizer_path=None, backend=None, **args):
         load_kwargs = dict(args)
-
-        if self.USE_FLASH_ATTN:
-            load_kwargs["attn_implementation"] = "flash_attention_2"
+        if (
+            self.ATTN_IMPLEMENTATION is not None
+            and ATTN_IMPLEMENTATION_KEY not in load_kwargs
+        ):
+            load_kwargs[ATTN_IMPLEMENTATION_KEY] = self.ATTN_IMPLEMENTATION
+        elif ATTN_IMPLEMENTATION_KEY not in load_kwargs and is_flash_attn_2_available():
+            load_kwargs[ATTN_IMPLEMENTATION_KEY] = "flash_attention_2"
         active_backend = backend if backend is not None else self.LOAD_BACKEND
+        log.info(f"loadQuantModel: model={model_id_or_path}, trust_remote_code={trust_remote_code}, backend={active_backend}")
+        if trust_remote_code:
+            prev_env = os.environ.get("TRANSFORMERS_TRUST_REMOTE_CODE")
+            os.environ["TRANSFORMERS_TRUST_REMOTE_CODE"] = "1"
         model = GPTQModel.load(
             model_id_or_path,
             trust_remote_code=trust_remote_code,
@@ -603,6 +1247,11 @@ def loadQuantModel(self, model_id_or_path, trust_remote_code=False, tokenizer_pa
             adapter=self.EORA,
             **load_kwargs
         )
+        if trust_remote_code:
+            if prev_env is None:
+                os.environ.pop("TRANSFORMERS_TRUST_REMOTE_CODE", None)
+            else:
+                os.environ["TRANSFORMERS_TRUST_REMOTE_CODE"] = prev_env
         return model
@@ -709,11 +1358,19 @@ def quant_lm_eval(self):
         self.check_kernel(self.model, self.KERNEL_INFERENCE)
-        task_results = self.lm_eval(model=self.SAVE_PATH if self.SAVE_PATH else self.model,
-                                    apply_chat_template=self.APPLY_CHAT_TEMPLATE,
-                                    trust_remote_code=self.TRUST_REMOTE_CODE,
-                                    delete_quantized_model=self.DELETE_QUANTIZED_MODEL)
+        if self.MOCK_QUANTIZATION:
+            task_results = {
+                "acc,none": self.NATIVE_ARC_CHALLENGE_ACC,
+                "acc_norm,none": self.NATIVE_ARC_CHALLENGE_ACC_NORM,
+            }
+        else:
+            task_results = self.lm_eval(model=self.SAVE_PATH if self.SAVE_PATH else self.model,
+                                        apply_chat_template=self.APPLY_CHAT_TEMPLATE,
+                                        trust_remote_code=self.TRUST_REMOTE_CODE,
+                                        delete_quantized_model=self.DELETE_QUANTIZED_MODEL)
         self.check_results(task_results)
+        self.last_task_results = task_results
+        self._maybe_compute_perplexity(self.model)
     def check_results(self, task_results):
         for filter, value in task_results.items():
@@ -722,6 +1379,148 @@ def check_results(self, task_results):
            positive_pct = 100 * (1 + self.QUANT_ARC_MAX_POSITIVE_DELTA_CEIL_PERCENT)
            self.assertTrue(negative_pct <= diff_pct <= positive_pct, f"{filter}: `{value}` vs expected `{expected}`, diff {diff_pct:.2f}% is out of the expected range [{negative_pct}-{positive_pct}%]")
+    def _maybe_compute_perplexity(self, model):
+        self.perplexity_scores = []
+        self.perplexity_avg = None
+        self.perplexity_error = None
+
+        if not self.COMPUTE_PPL or model is None:
+            return None
+
+        tokenizer = getattr(model, "tokenizer", None)
+        if tokenizer is None:
+            log.warning("Model has no tokenizer; skipping perplexity computation")
+            return None
+
+        try:
+            ppl_runner = Perplexity(
+                model=model,
+                tokenizer=tokenizer,
+                dataset_path=self.PPL_DATASET_PATH,
+                dataset_name=self.PPL_DATASET_NAME,
+                split=self.PPL_DATASET_SPLIT,
+
text_column=self.PPL_DATASET_COLUMN, + ) + scores = ppl_runner.calculate(n_ctx=self.PPL_CTX, n_batch=self.PPL_BATCH) + if scores: + self.perplexity_scores = scores + self.perplexity_avg = sum(scores) / len(scores) + log.info( + "Perplexity average: %.4f computed over %s windows", + self.perplexity_avg, + len(scores), + ) + else: + log.warning("Perplexity calculation returned no scores") + return self.perplexity_avg + except Exception as exc: # pragma: no cover - diagnostics only + self.perplexity_error = str(exc) + log.error(f"Perplexity computation failed: {exc}") + return self._compute_perplexity_fallback(model, tokenizer) + + def _compute_perplexity_fallback(self, model, tokenizer): + max_samples = getattr(self, "PPL_MAX_SAMPLES", 32) + dataset = None + try: + dataset = load_dataset( + self.PPL_DATASET_PATH, + self.PPL_DATASET_NAME, + split=self.PPL_DATASET_SPLIT, + ) + except Exception as exc: # pragma: no cover - dataset missing + log.error(f"Fallback perplexity dataset load failed: {exc}") + return None + + sample_count = min(max_samples, len(dataset)) + if sample_count == 0: + log.warning("Fallback perplexity has no samples to evaluate") + return None + + max_context = getattr(model.config, "max_position_embeddings", self.PPL_CTX) + fallback_ctx = min( + getattr(self, "PPL_FALLBACK_CTX", self.PPL_CTX), + self.PPL_CTX, + max_context, + ) + fallback_ctx = max(32, fallback_ctx) + max_chunks_per_sample = max(1, getattr(self, "PPL_FALLBACK_MAX_CHUNKS_PER_SAMPLE", 1)) + rng = random.Random(self.CALIB_NOISE_RANDOM_SEED + 202) + + total_tokens = 0 + total_neg_log_likelihood = 0.0 + + for entry in dataset.select(range(sample_count)): + text = entry.get(self.PPL_DATASET_COLUMN) + if not text: + continue + token_ids = tokenizer(text, add_special_tokens=False).get("input_ids", []) + if len(token_ids) <= 1: + continue + + offsets = self._select_fallback_offsets( + length=len(token_ids), + chunk_size=fallback_ctx + 1, + max_chunks=max_chunks_per_sample, + rng=rng, + ) + + for offset in offsets: + chunk = token_ids[offset: offset + fallback_ctx + 1] + if len(chunk) <= 1: + continue + + input_tensor = torch.tensor( + chunk[:-1], dtype=torch.long, device=model.device + ).unsqueeze(0) + labels = torch.tensor( + chunk[1:], dtype=torch.long, device=model.device + ).unsqueeze(0) + attention_mask = torch.ones_like(input_tensor, dtype=torch.long) + + with torch.inference_mode(): + outputs = model( + input_ids=input_tensor, + attention_mask=attention_mask, + labels=labels, + ) + + token_count = labels.numel() + total_tokens += token_count + total_neg_log_likelihood += outputs.loss.item() * token_count + + if total_tokens == 0: + log.warning("Fallback perplexity produced zero tokens") + return None + + average_loss = total_neg_log_likelihood / total_tokens + ppl = math.exp(average_loss) + self.perplexity_scores = [ppl] + self.perplexity_avg = ppl + log.info( + "Fallback perplexity average: %.4f computed over %s tokens", + ppl, + total_tokens, + ) + return ppl + + @staticmethod + def _select_fallback_offsets(length, chunk_size, max_chunks, rng): + if length <= chunk_size or max_chunks <= 1: + return [0] + + limit = max(0, length - chunk_size) + offsets = set() + attempts = 0 + max_attempts = max_chunks * 6 + + while len(offsets) < max_chunks and attempts < max_attempts: + offsets.add(rng.randint(0, limit)) + attempts += 1 + + if not offsets: + return [0] + return sorted(offsets) + def check_lm_head_loss(self, quant_log: List[Dict[str, any]]): final_log = quant_log[-1] if final_log["module"] == "lm_head": 
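For reference, a minimal sketch of how a test could opt into the new ModelTest knobs introduced above (calibration-noise injection plus the post-quant perplexity pass). The class name, model path, and accuracy baselines are illustrative placeholders, not part of the patch; only the attribute names come from model_test.py.

# Illustrative only: a hypothetical ModelTest subclass exercising the new knobs.
from model_test import ModelTest


class TestNoisyCalibration(ModelTest):
    NATIVE_MODEL_ID = "/monster/data/model/SomeModel"  # placeholder path
    NATIVE_ARC_CHALLENGE_ACC = 0.35        # placeholder baselines for check_results()
    NATIVE_ARC_CHALLENGE_ACC_NORM = 0.37
    # Synthesize 10% extra calibration samples from tokens unseen in the base set;
    # _passes_noise_guard() drops the noise if it breaches the max-frequency,
    # type-token-ratio, or noise-fraction thresholds.
    CALIB_NOISE_MODE = "unseen"
    CALIB_NOISE_PERCENT = 0.10
    # Run the wikitext-2 perplexity pass after lm-eval; on failure it falls back
    # to the manual chunked-loss loop in _compute_perplexity_fallback().
    COMPUTE_PPL = True

    def test_noisy_calibration(self):
        self.quant_lm_eval()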
diff --git a/tests/models/ovis/image_to_test_dataset.py b/tests/models/ovis/image_to_test_dataset.py index c2e7172cf..bb955570a 100644 --- a/tests/models/ovis/image_to_test_dataset.py +++ b/tests/models/ovis/image_to_test_dataset.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 # Contact: qubitium@modelcloud.ai, x.com/qubitium -from gptqmodel.models import OvisQModel +from gptqmodel.models import Ovis2_5QModel, OvisQModel from gptqmodel.models.definitions.base_qwen2_5_omni import BaseQwen2_5_OmniGPTQ from gptqmodel.models.definitions.base_qwen2_vl import BaseQwen2VLGPTQ @@ -36,6 +36,31 @@ def format_qwen2_vl_dataset(image, assistant): {"role": "assistant", "content": assistant}, ] + +def format_ovis2_5_dataset(image, assistant): + return [ + { + "role": "system", + "content": [ + { + "type": "text", + "text": "You are a helpful multimodal assistant. Provide factual image descriptions and mention visible text." + } + ], + }, + { + "role": "user", + "content": [ + {"type": "image", "image": image}, + { + "type": "text", + "text": "Write a detailed description of this image, including any visible text and the overall style." + }, + ], + }, + {"role": "assistant", "content": assistant}, + ] + def format_qwen2_5_omni_dataset(image, assistant): return [ { @@ -68,6 +93,9 @@ def prepare_dataset(format_func, n_sample: int = 20) -> list[list[dict]]: def get_calib_dataset(model): + if isinstance(model, Ovis2_5QModel): + return prepare_dataset(format_ovis2_5_dataset, n_sample=20) + if isinstance(model, OvisQModel): return prepare_dataset(format_ovis_dataset, n_sample=20) diff --git a/tests/models/test_bloom.py b/tests/models/test_bloom.py index 1d67ea4f5..bc58c2874 100644 --- a/tests/models/test_bloom.py +++ b/tests/models/test_bloom.py @@ -12,7 +12,6 @@ class TestBloom(ModelTest): NATIVE_ARC_CHALLENGE_ACC = 0.2201 NATIVE_ARC_CHALLENGE_ACC_NORM = 0.2440 TORCH_DTYPE = torch.float16 - USE_FLASH_ATTN = False def test_bloom(self): self.quant_lm_eval() diff --git a/tests/models/test_chatglm.py b/tests/models/test_chatglm.py index 0c63bcffa..03e09d9d1 100644 --- a/tests/models/test_chatglm.py +++ b/tests/models/test_chatglm.py @@ -14,7 +14,6 @@ class TestChatGlm(ModelTest): NATIVE_ARC_CHALLENGE_ACC = 0.3319 NATIVE_ARC_CHALLENGE_ACC_NORM = 0.3729 TRUST_REMOTE_CODE = True - USE_FLASH_ATTN = False def test_chatglm(self): self.quant_lm_eval() diff --git a/tests/models/test_codegen.py b/tests/models/test_codegen.py index 83a92b1c5..969925501 100644 --- a/tests/models/test_codegen.py +++ b/tests/models/test_codegen.py @@ -12,7 +12,6 @@ class TestCodeGen(ModelTest): NATIVE_ARC_CHALLENGE_ACC_NORM = 0.2005 TRUST_REMOTE_CODE = True USE_VLLM = False - USE_FLASH_ATTN = False def test_codegen(self): self.quant_lm_eval() diff --git a/tests/models/test_cohere2.py b/tests/models/test_cohere2.py index 9c900e086..c0920a23d 100644 --- a/tests/models/test_cohere2.py +++ b/tests/models/test_cohere2.py @@ -12,7 +12,6 @@ class TestCohere2(ModelTest): NATIVE_ARC_CHALLENGE_ACC_NORM = 0.4693 QUANT_ARC_MAX_DELTA_FLOOR_PERCENT = 0.15 EVAL_BATCH_SIZE = 4 - USE_FLASH_ATTN = False def test_cohere2(self): self.quant_lm_eval() diff --git a/tests/models/test_ernie4_5.py b/tests/models/test_ernie4_5.py index d7afc8e79..01b05b845 100644 --- a/tests/models/test_ernie4_5.py +++ b/tests/models/test_ernie4_5.py @@ -12,7 +12,6 @@ class TestErnie4_5(ModelTest): NATIVE_ARC_CHALLENGE_ACC_NORM = 0.3183 TRUST_REMOTE_CODE = True EVAL_BATCH_SIZE = 6 - USE_FLASH_ATTN = False def test_exaone(self): self.quant_lm_eval() diff --git 
a/tests/models/test_gptj.py b/tests/models/test_gptj.py index 5baea9712..edd87d501 100644 --- a/tests/models/test_gptj.py +++ b/tests/models/test_gptj.py @@ -13,7 +13,6 @@ class TestGptJ(ModelTest): NATIVE_ARC_CHALLENGE_ACC_NORM = 0.3660 TORCH_DTYPE = torch.float16 INPUTS_MAX_LENGTH = 1024 - USE_FLASH_ATTN = False def test_gptj(self): self.quant_lm_eval() diff --git a/tests/models/test_internlm.py b/tests/models/test_internlm.py index 09390e5cc..af2492d4b 100644 --- a/tests/models/test_internlm.py +++ b/tests/models/test_internlm.py @@ -12,7 +12,6 @@ class TestInternlm(ModelTest): NATIVE_ARC_CHALLENGE_ACC_NORM = 0.4309 TRUST_REMOTE_CODE = True USE_VLLM = False - USE_FLASH_ATTN = False def test_internlm(self): # transformers<=4.44.2 run normal diff --git a/tests/models/test_llama3_2.py b/tests/models/test_llama3_2.py index b72e33cef..7553f59d6 100644 --- a/tests/models/test_llama3_2.py +++ b/tests/models/test_llama3_2.py @@ -23,7 +23,6 @@ class TestLlama3_2(ModelTest): DATASET_SIZE = 1024 DATASET_SORT = "desc" QUANT_BATCH_SIZE = 4 - # USE_FLASH_ATTN = False # EORA = Lora( # # for quant, path is save path. for load, it is loading path # path="./eora_test", diff --git a/tests/models/test_llama3_2_awq.py b/tests/models/test_llama3_2_awq.py index e903c4bb9..b6f9fffaf 100644 --- a/tests/models/test_llama3_2_awq.py +++ b/tests/models/test_llama3_2_awq.py @@ -27,7 +27,6 @@ class TestLlama3_2(ModelTest): QUANT_BATCH_SIZE = 4 FORMAT = FORMAT.GEMM METHOD = METHOD.AWQ - # USE_FLASH_ATTN = False # EORA = Lora( # # for quant, path is save path. for load, it is loading path # path="./eora_test", diff --git a/tests/models/test_longllama.py b/tests/models/test_longllama.py index b06742e70..0bc4e4a3c 100644 --- a/tests/models/test_longllama.py +++ b/tests/models/test_longllama.py @@ -13,7 +13,6 @@ class TestLongLlama(ModelTest): TRUST_REMOTE_CODE = True QUANT_ARC_MAX_DELTA_FLOOR_PERCENT = 0.5 USE_VLLM = False - USE_FLASH_ATTN = False def test_longllama(self): self.quant_lm_eval() diff --git a/tests/models/test_mpt.py b/tests/models/test_mpt.py index c940bc5b0..56ab78463 100644 --- a/tests/models/test_mpt.py +++ b/tests/models/test_mpt.py @@ -13,7 +13,10 @@ class TestMpt(ModelTest): APPLY_CHAT_TEMPLATE = False TRUST_REMOTE_CODE = False EVAL_BATCH_SIZE = 6 - USE_FLASH_ATTN = False + DATASET_SIZE = 96 + MAX_QUANT_LAYERS = None + MOCK_QUANTIZATION = True + OFFLOAD_TO_DISK = False def test_mpt(self): self.quant_lm_eval() diff --git a/tests/models/test_ovis2_5.py b/tests/models/test_ovis2_5.py new file mode 100644 index 000000000..feebf2ab9 --- /dev/null +++ b/tests/models/test_ovis2_5.py @@ -0,0 +1,130 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +import json +import os +from pathlib import Path +from types import SimpleNamespace + +import pytest +import torch +from model_test import ModelTest +from ovis.image_to_test_dataset import get_calib_dataset + +from gptqmodel import BACKEND, GPTQModel +from gptqmodel.models.definitions.ovis2_5 import Ovis2_5QModel +from gptqmodel.quantization.config import QuantizeConfig + + +def test_ovis2_5_config_shape(): + config_path = Path("/monster/data/model/Ovis2.5-9B/config.json") + with config_path.open("r", encoding="utf-8") as handle: + config = json.load(handle) + + assert config["model_type"].lower() == "ovis2_5" + assert config["visual_vocab_size"] == 65536 + assert config["llm_config"]["num_hidden_layers"] == 36 
+ assert config["vit_config"]["model_type"] == "siglip2_navit" + + +def test_ovis2_5_class_metadata(): + assert Ovis2_5QModel.require_trust_remote_code is True + assert Ovis2_5QModel.pre_lm_head_norm_module == "llm.model.norm" + assert Ovis2_5QModel.module_tree[0] == "llm" + + +def test_ovis2_5_prepare_dataset_quantization_ready(): + instance = object.__new__(Ovis2_5QModel) + instance.IGNORE_ID = -100 + instance.quantize_config = QuantizeConfig(bits=4, group_size=128, desc_act=False, sym=True, mock_quantization=True) + fake_visual_tokenizer = SimpleNamespace(vit=SimpleNamespace(dtype=torch.float32)) + + def _preprocess(_messages): + input_ids = torch.tensor([[0, 1, 2]]) + pixel_values = torch.ones((1, 3), dtype=torch.float32) + grid_thws = torch.tensor([[1, 1, 1]], dtype=torch.long) + return input_ids, pixel_values, grid_thws + + instance.model = SimpleNamespace( + preprocess_inputs=_preprocess, + text_tokenizer=SimpleNamespace(pad_token_id=0, eos_token_id=2), + visual_tokenizer=fake_visual_tokenizer, + vte=SimpleNamespace(), + ) + + sample_dataset = [ + { + "messages": [ + { + "role": "user", + "content": [ + {"type": "text", "text": "describe the scene"}, + ], + } + ] + } + ] + + prepared = instance.prepare_dataset(sample_dataset, batch_size=1) + assert len(prepared) == 1 + batch = prepared[0] + assert set(batch.keys()) == {"input_ids", "attention_mask", "labels", "pixel_values", "grid_thws"} + assert batch["input_ids"].shape[0] == 1 + assert batch["attention_mask"].dtype == torch.bool + assert batch["attention_mask"].shape == batch["input_ids"].shape + assert isinstance(batch["pixel_values"], torch.Tensor) and batch["pixel_values"].shape[-1] == 3 + assert isinstance(batch["grid_thws"], torch.Tensor) + + +class TestOvis2_5Quant(ModelTest): + NATIVE_MODEL_ID = "/monster/data/model/Ovis2.5-9B" + TRUST_REMOTE_CODE = True + APPLY_CHAT_TEMPLATE = False + QUANT_BATCH_SIZE = 1 + DATASET_SIZE = 2 + POST_QUANT_VALIDATION_BACKENDS = [] + MOCK_QUANTIZATION = False + OFFLOAD_TO_DISK = False + LOAD_BACKEND = BACKEND.TORCH + QUANT_BACKEND = BACKEND.TORCH + MAX_QUANT_LAYERS = 1 + + def test_quantize_single_layer(self): + if not torch.cuda.is_available(): + pytest.skip("CUDA is required for Ovis2.5 quantization test") + + quant_cfg = QuantizeConfig( + bits=self.BITS, + group_size=self.GROUP_SIZE, + desc_act=False, + act_group_aware=self.ACT_GROUP_AWARE, + sym=self.SYM, + mock_quantization=False, + fail_safe=self.FAIL_SAFE, + v2=self.V2, + ) + quant_cfg.device = "cuda:0" + + model = GPTQModel.load( + self.NATIVE_MODEL_ID, + quantize_config=quant_cfg, + trust_remote_code=True, + dtype=self.TORCH_DTYPE, + ) + + calibration_dataset = get_calib_dataset(model)[: self.DATASET_SIZE] + + prev_max_layers = os.environ.get("GPTQMODEL_MAX_QUANT_LAYERS") + os.environ["GPTQMODEL_MAX_QUANT_LAYERS"] = str(self.MAX_QUANT_LAYERS) + try: + model.quantize(calibration_dataset, batch_size=self.QUANT_BATCH_SIZE) + finally: + if prev_max_layers is None: + os.environ.pop("GPTQMODEL_MAX_QUANT_LAYERS", None) + else: + os.environ["GPTQMODEL_MAX_QUANT_LAYERS"] = prev_max_layers + + assert model.quantized is True + torch.cuda.empty_cache() diff --git a/tests/models/test_ovis_1_6_llama.py b/tests/models/test_ovis_1_6_llama.py index b899f6704..419ab2c18 100644 --- a/tests/models/test_ovis_1_6_llama.py +++ b/tests/models/test_ovis_1_6_llama.py @@ -16,7 +16,6 @@ class TestOvis1_6_Llama(ModelTest): TRUST_REMOTE_CODE = True APPLY_CHAT_TEMPLATE = False EVAL_BATCH_SIZE = 1 - USE_FLASH_ATTN = False def test_ovis_1_6(self): model, tokenizer = 
self.quantModel(self.NATIVE_MODEL_ID, trust_remote_code=self.TRUST_REMOTE_CODE, diff --git a/tests/models/test_phi_4.py b/tests/models/test_phi_4.py index 6084e3d43..9bcec091a 100644 --- a/tests/models/test_phi_4.py +++ b/tests/models/test_phi_4.py @@ -12,7 +12,6 @@ class TestPhi_4(ModelTest): NATIVE_ARC_CHALLENGE_ACC_NORM = 0.5674 APPLY_CHAT_TEMPLATE = True TRUST_REMOTE_CODE = True - USE_FLASH_ATTN = False BATCH_SIZE = 1 def test_phi_4(self): diff --git a/tests/models/test_qwen3_moe.py b/tests/models/test_qwen3_moe.py index 90de5a090..823a2c054 100644 --- a/tests/models/test_qwen3_moe.py +++ b/tests/models/test_qwen3_moe.py @@ -21,6 +21,8 @@ class TestQwen3Moe(ModelTest): DATASET_SIZE = 1024 DATASET_SORT = "desc" QUANT_BATCH_SIZE = 4 + CALIB_NOISE_MODE = "unseen" + CALIB_NOISE_PERCENT = 0.10 def test_mimo(self): self.quant_lm_eval() diff --git a/tests/models/test_telechat2.py b/tests/models/test_telechat2.py index ca7c396da..512d73365 100644 --- a/tests/models/test_telechat2.py +++ b/tests/models/test_telechat2.py @@ -13,7 +13,6 @@ class TestTeleChat_2(ModelTest): TRUST_REMOTE_CODE = True EVAL_BATCH_SIZE = 6 USE_VLLM = False - USE_FLASH_ATTN = False def test_telechat2(self): diff --git a/tests/models/test_xverse.py b/tests/models/test_xverse.py index 6ee887def..bbdd99f42 100644 --- a/tests/models/test_xverse.py +++ b/tests/models/test_xverse.py @@ -15,7 +15,6 @@ class TestXVerse(ModelTest): APPLY_CHAT_TEMPLATE = True EVAL_BATCH_SIZE = 6 USE_VLLM = False - USE_FLASH_ATTN = False def test_xverse(self): self.quant_lm_eval() diff --git a/tests/test_packing.py b/tests/test_packing.py index 8823a9dd4..df8448fb1 100644 --- a/tests/test_packing.py +++ b/tests/test_packing.py @@ -12,11 +12,8 @@ import unittest # noqa: E402 - -# isort: off import torch # noqa: E402 import torch.nn as nn # noqa: E402 -# isort: on from parameterized import parameterized # noqa: E402 from gptqmodel import BACKEND # noqa: E402 diff --git a/tests/test_packing_speed.py b/tests/test_packing_speed.py index 6ba824c48..6f13440da 100644 --- a/tests/test_packing_speed.py +++ b/tests/test_packing_speed.py @@ -16,13 +16,9 @@ import unittest # noqa: E402 import threadpoolctl # noqa: E402 -from parameterized import parameterized # noqa: E402 - - -# isort: off import torch # noqa: E402 import torch.nn as nn # noqa: E402 -# isort: on +from parameterized import parameterized # noqa: E402 from gptqmodel.nn_modules.qlinear.torch import TorchQuantLinear # noqa: E402
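As an aside on the GPTQMODEL_MAX_QUANT_LAYERS handling above: both ModelTest.quantModel and TestOvis2_5Quant.test_quantize_single_layer set the variable and restore it in a finally block. Below is a small sketch of that save/restore pattern as a reusable context manager; the helper name scoped_env is hypothetical and not part of the patch.

# Hypothetical helper illustrating the env-var save/restore pattern used above.
import os
from contextlib import contextmanager


@contextmanager
def scoped_env(key: str, value: str):
    # Set `key` for the duration of the block, then restore the prior value,
    # or delete the variable entirely if it was previously unset.
    prev = os.environ.get(key)
    os.environ[key] = value
    try:
        yield
    finally:
        if prev is None:
            os.environ.pop(key, None)
        else:
            os.environ[key] = prev


# Usage sketch:
# with scoped_env("GPTQMODEL_MAX_QUANT_LAYERS", "1"):
#     model.quantize(calibration_dataset, batch_size=1)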