diff --git a/examples/benchmark/ipex.py b/examples/benchmark/ipex.py
index cb0e70cd7..fc73436ed 100644
--- a/examples/benchmark/ipex.py
+++ b/examples/benchmark/ipex.py
@@ -7,9 +7,7 @@
 import time
 
 import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer
-
-from gptqmodel.utils.hf import safe_auto_config_from_pretrained
+from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
 
 
 try:
@@ -55,7 +53,7 @@ def prepare_dataset_for_bench(tokenizer, batch_size=8):
 
 # load model, check model backend
 start_load = time.time()
-config = safe_auto_config_from_pretrained(ars.model)
+config = AutoConfig.from_pretrained(ars.model)
 is_quantized_model = hasattr(config, "quantization_config")
 if is_quantized_model:
     from gptqmodel import BACKEND, GPTQModel
diff --git a/format/format.sh b/format/format.sh
index a087073c6..9e4d630ed 100755
--- a/format/format.sh
+++ b/format/format.sh
@@ -2,11 +2,14 @@
 
 cd "$(dirname "$0")" || exit
 
-# force ruff to be same version as setup.py
-pip install -U ruff==0.13.0
+# force ruff/isort to be same version as setup.py
+pip install -U ruff==0.13.0 isort==6.0.1
 
 ruff check ../gptqmodel/models ../gptqmodel/nn_modules ../gptqmodel/quantization ../gptqmodel/utils ../gptqmodel/__init__.py ../examples ../tests ../setup.py --fix --unsafe-fixes
 ruff_status=$?
 
+# isort is too slow
+# isort -l 119 -e ../
+
 # Exit with the status code of ruff check
-exit $ruff_status
+exit $ruff_status
\ No newline at end of file
diff --git a/gptqmodel/looper/module_looper.py b/gptqmodel/looper/module_looper.py
index 6c6f87e1e..91cc6119c 100644
--- a/gptqmodel/looper/module_looper.py
+++ b/gptqmodel/looper/module_looper.py
@@ -869,8 +869,7 @@ def _process_on_worker(proc: LoopProcessor, nm: NamedModule):
                 for fut in futures:
                     name, m = fut.result()
                     processed_subset[name] = m
-
-                #torch_sync()
+                torch_sync()
 
                 # ---- End Process Hook ----
                 is_last_module = layer_index == len(pb) - 1
diff --git a/gptqmodel/models/auto.py b/gptqmodel/models/auto.py
index 1355dcd5e..25b114647 100644
--- a/gptqmodel/models/auto.py
+++ b/gptqmodel/models/auto.py
@@ -44,14 +44,13 @@
 import torch # noqa: E402
 from huggingface_hub import list_repo_files # noqa: E402
 from tokenicer import Tokenicer # noqa: E402
-from transformers import GenerationConfig, PreTrainedModel, PreTrainedTokenizerBase # noqa: E402
+from transformers import AutoConfig, GenerationConfig, PreTrainedModel, PreTrainedTokenizerBase # noqa: E402
 
 from ..adapter.adapter import Adapter, Lora, normalize_adapter # noqa: E402
 from ..nn_modules.qlinear.torch import TorchQuantLinear # noqa: E402
 from ..quantization import METHOD, QUANT_CONFIG_FILENAME # noqa: E402
 from ..utils import BACKEND # noqa: E402
 from ..utils.eval import EVAL # noqa: E402
-from ..utils.hf import safe_auto_config_from_pretrained # noqa: E402
 from ..utils.model import find_modules # noqa: E402
 from ..utils.torch import CPU, torch_empty_cache # noqa: E402
 from .base import BaseQModel, QuantizeConfig # noqa: E402
@@ -101,7 +100,6 @@
 from .definitions.nemotron_h import NemotronHQModel # noqa: E402
 from .definitions.opt import OptQModel # noqa: E402
 from .definitions.ovis import OvisQModel # noqa: E402
-from .definitions.ovis2_5 import Ovis2_5QModel # noqa: E402
 from .definitions.pangu_alpha import PanguAlphaQModel # noqa: E402
 from .definitions.phi import PhiQModel # noqa: E402
 from .definitions.phi3 import Phi3QModel, PhiMoEGPTQForCausalLM # noqa: E402
@@ -197,7 +195,6 @@
     "hymba": HymbaQModel,
     "olmo2": LlamaQModel, # 100% llama clone
     "ovis": OvisQModel,
-    "ovis2_5": 
Ovis2_5QModel, "telechat": TeleChat2QModel, "instella": InstellaQModel, "mimo": MimoQModel, @@ -216,7 +213,7 @@ def check_and_get_model_type(model_dir, trust_remote_code=False): - config = safe_auto_config_from_pretrained(model_dir, trust_remote_code=trust_remote_code) + config = AutoConfig.from_pretrained(model_dir, trust_remote_code=trust_remote_code) if config.model_type.lower() not in SUPPORTED_MODELS: raise TypeError(f"{config.model_type} isn't supported yet.") model_type = config.model_type @@ -253,7 +250,7 @@ def load( backend = BACKEND(backend) is_gptqmodel_quantized = False - model_cfg = safe_auto_config_from_pretrained(model_id_or_path, trust_remote_code=trust_remote_code) + model_cfg = AutoConfig.from_pretrained(model_id_or_path, trust_remote_code=trust_remote_code) if hasattr(model_cfg, "quantization_config") and "quant_format" in model_cfg.quantization_config: # only if the model is quantized or compatible with gptqmodel should we set is_quantized to true if model_cfg.quantization_config["quant_format"].lower() in (METHOD.GPTQ, METHOD.AWQ, METHOD.QQQ): @@ -274,7 +271,6 @@ def load( break if is_gptqmodel_quantized: - log.info("GPTQModel.load: loading quantized model `%s` with trust_remote_code=%s", model_id_or_path, trust_remote_code) m = cls.from_quantized( model_id_or_path=model_id_or_path, device_map=device_map, @@ -308,8 +304,7 @@ def from_pretrained( trust_remote_code: bool = False, **model_init_kwargs, ) -> BaseQModel: - config = safe_auto_config_from_pretrained(model_id_or_path, trust_remote_code=trust_remote_code) - if hasattr(config, + if hasattr(AutoConfig.from_pretrained(model_id_or_path, trust_remote_code=trust_remote_code), "quantization_config"): log.warn("Model is already quantized, will use `from_quantized` to load quantized model.\n" "If you want to quantize the model, please pass un_quantized model path or id, and use " @@ -400,11 +395,7 @@ def eval( if isinstance(model_or_id_or_path, str): log.info(f"Eval: loading using backend = `{backend}`") - model = GPTQModel.load( - model_id_or_path=model_or_id_or_path, - backend=backend, - trust_remote_code=trust_remote_code, - ) + model = GPTQModel.load(model_id_or_path=model_or_id_or_path, backend=backend) model_id_or_path = model_or_id_or_path elif isinstance(model_or_id_or_path, BaseQModel) or isinstance(model_or_id_or_path, (PreTrainedModel, PeftModel)): model = model_or_id_or_path @@ -416,7 +407,7 @@ def eval( if isinstance(model, BaseQModel): tokenizer = model.tokenizer elif isinstance(model, PreTrainedModel) or model_id_or_path.strip(): - tokenizer = Tokenicer.load(model_id_or_path, trust_remote_code=trust_remote_code) + tokenizer = Tokenicer.load(model_id_or_path) if tokenizer is None: raise ValueError("Tokenizer: Auto-loading of tokenizer failed with `model_or_id_or_path`. 
Please pass in `tokenizer` as argument.") @@ -455,46 +446,19 @@ def eval( # use model.generation_config whenever possible if gen_kwargs is None: + # TODO: move to utils if hasattr(model, "generation_config") and isinstance(model.generation_config, GenerationConfig): - cfg = model.generation_config - kv_pairs = [] - if getattr(cfg, "do_sample", False): - kv_pairs.append("do_sample=True") - temperature = getattr(cfg, "temperature", None) - if temperature is not None and temperature != 1.0: - kv_pairs.append(f"temperature={temperature}") - top_k = getattr(cfg, "top_k", None) - if top_k is not None: - kv_pairs.append(f"top_k={top_k}") - top_p = getattr(cfg, "top_p", None) - if top_p is not None and top_p != 1.0: - kv_pairs.append(f"top_p={top_p}") - min_p = getattr(cfg, "min_p", None) - if min_p is not None and min_p > 0.0: - kv_pairs.append(f"min_p={min_p}") - typical_p = getattr(cfg, "typical_p", None) - if typical_p is not None and typical_p != 1.0: - kv_pairs.append(f"typical_p={typical_p}") - epsilon_cutoff = getattr(cfg, "epsilon_cutoff", None) - if epsilon_cutoff is not None and epsilon_cutoff != 0.0: - kv_pairs.append(f"epsilon_cutoff={epsilon_cutoff}") - eta_cutoff = getattr(cfg, "eta_cutoff", None) - if eta_cutoff is not None and eta_cutoff != 0.0: - kv_pairs.append(f"eta_cutoff={eta_cutoff}") - penalty_alpha = getattr(cfg, "penalty_alpha", None) - if penalty_alpha is not None: - kv_pairs.append(f"penalty_alpha={penalty_alpha}") - else: - kv_pairs.append("do_sample=False") - temperature = getattr(cfg, "temperature", None) - if temperature is None: - temperature = 0.0 - if temperature != 1.0: - kv_pairs.append(f"temperature={temperature}") - - gen_kwargs = ",".join(kv_pairs) + gen_dict = { + "do_sample": model.generation_config.do_sample, + "temperature": model.generation_config.temperature, + "top_k": model.generation_config.top_k, + "top_p": model.generation_config.top_p, + "min_p": model.generation_config.min_p, + + } + gen_kwargs = ','.join(f"{key}={value}" for key, value in gen_dict.items() if value not in ["", {}, None, []]) else: - gen_kwargs = "do_sample=False,temperature=0.0" # default + gen_kwargs = "temperature=0.0,top_k=50" # default log.info(f"LM-EVAL: `gen_kwargs` = `{gen_kwargs}`") @@ -571,7 +535,7 @@ def eval( @staticmethod def export(model_id_or_path: str, target_path: str, format: str, trust_remote_code: bool = False): # load config - config = safe_auto_config_from_pretrained(model_id_or_path, trust_remote_code=trust_remote_code) + config = AutoConfig.from_pretrained(model_id_or_path, trust_remote_code=trust_remote_code) if not config.quantization_config: raise ValueError("Model is not quantized") @@ -579,11 +543,7 @@ def export(model_id_or_path: str, target_path: str, format: str, trust_remote_co gptq_config = config.quantization_config # load gptq model - gptq_model = GPTQModel.load( - model_id_or_path, - backend=BACKEND.TORCH, - trust_remote_code=trust_remote_code, - ) + gptq_model = GPTQModel.load(model_id_or_path, backend=BACKEND.TORCH) if format == "mlx": try: diff --git a/gptqmodel/models/definitions/__init__.py b/gptqmodel/models/definitions/__init__.py index e10838ed4..7072e254d 100644 --- a/gptqmodel/models/definitions/__init__.py +++ b/gptqmodel/models/definitions/__init__.py @@ -44,7 +44,6 @@ from .mpt import MptQModel from .opt import OptQModel from .ovis import OvisQModel -from .ovis2_5 import Ovis2_5QModel from .phi import PhiQModel from .phi3 import Phi3QModel from .qwen import QwenQModel diff --git a/gptqmodel/models/definitions/ovis.py 
b/gptqmodel/models/definitions/ovis.py index 5c21b1aac..234ffa81b 100644 --- a/gptqmodel/models/definitions/ovis.py +++ b/gptqmodel/models/definitions/ovis.py @@ -17,7 +17,6 @@ class OvisQModel(BaseQModel): - require_trust_remote_code = True pre_lm_head_norm_module = "llm.model.norm" module_tree = [ @@ -50,23 +49,6 @@ def monkey_patch(self): self.model.vte = self.model.vte.to(dtype=self.model.llm.dtype) def pre_quantize_generate_hook_start(self): - visual_tokenizer_meta = any(param.device.type == "meta" for param in self.model.visual_tokenizer.parameters()) or \ - any(buffer.device.type == "meta" for buffer in self.model.visual_tokenizer.buffers()) - if visual_tokenizer_meta: - try: - self.shell_module_materialize(self.model.visual_tokenizer, self.quantize_config.device) - except Exception: - logging.warning("OVIS visual_tokenizer shell materialization failed; continuing with fallback move.", - exc_info=True) - - vte_meta = any(param.device.type == "meta" for param in self.model.vte.parameters()) or \ - any(buffer.device.type == "meta" for buffer in self.model.vte.buffers()) - if vte_meta: - try: - self.shell_module_materialize(self.model.vte, self.quantize_config.device) - except Exception: - logging.warning("OVIS VTE shell materialization failed; continuing with fallback move.", exc_info=True) - self.model.visual_tokenizer = move_to(self.model.visual_tokenizer, device=self.quantize_config.device) self.model.vte = move_to(self.model.vte, device=self.quantize_config.device) @@ -94,14 +76,8 @@ def preprocess_dataset(self, sample: Dict) -> Dict: propagate_exception=False ) - target_dtype = self.model.visual_tokenizer.dtype if pixel_values is None: pixel_values, _ = self.visual_tokenizer.mock_input() - pixel_values = [pv.to(dtype=target_dtype) for pv in pixel_values] - elif isinstance(pixel_values, (list, tuple)): - pixel_values = [pv.to(dtype=target_dtype) for pv in pixel_values] - else: - pixel_values = pixel_values.to(dtype=target_dtype) input_ids = input_ids[:text_max_length] labels = labels[:text_max_length] @@ -146,128 +122,7 @@ def prepare_dataset( return calib_data - def generate(self, inputs=None, **kwargs): + def generate(self, inputs, **kwargs): """shortcut for model.generate""" - model_device = getattr(self.model, "device", None) - if model_device is None: - quant_device = getattr(self.quantize_config, "device", None) - model_device = torch.device(quant_device if quant_device is not None else "cpu") - - llm = getattr(self.model, "llm", None) - - pixel_values = None - if isinstance(inputs, dict): - pixel_values = inputs.get("pixel_values") - if pixel_values is None: - pixel_values = kwargs.get("pixel_values") - - has_real_pixels = False - if pixel_values is not None: - if isinstance(pixel_values, (list, tuple)): - has_real_pixels = any(p is not None for p in pixel_values) - else: - has_real_pixels = True - - text_only = (llm is not None) and not has_real_pixels - - if text_only: - kwargs = dict(kwargs) - kwargs.pop("pixel_values", None) - - if isinstance(inputs, dict): - if inputs.get("pixel_values") is not None: - inputs = dict(inputs) - inputs.pop("pixel_values", None) - - llm_device = next(self.model.llm.parameters()).device - - def ensure_attention_mask(payload): - if payload.get("attention_mask") is not None or "input_ids" not in payload: - return payload - mask = torch.ones_like(payload["input_ids"], dtype=torch.bool, device=payload["input_ids"].device) - payload["attention_mask"] = mask - return payload - - if isinstance(inputs, (str, list)): - if isinstance(inputs, str) or 
(isinstance(inputs, list) and all(isinstance(x, str) for x in inputs)): - if self.tokenizer is None: - raise ValueError( - "You passed in text to OvisQModel.generate() but tokenizer is missing." - ) - tokenized = self.tokenizer( - inputs, - return_tensors="pt", - padding=True, - padding_side="left" - ).to(llm_device) - with torch.amp.autocast(device_type=llm_device.type): - return self.model.llm.generate(**tokenized, **kwargs) - - if isinstance(inputs, dict): - payload = {k: v for k, v in inputs.items() if k != "pixel_values"} - if "input_ids" in payload and isinstance(payload["input_ids"], torch.Tensor): - payload["input_ids"] = payload["input_ids"].to(llm_device) - if "attention_mask" in payload and isinstance(payload["attention_mask"], torch.Tensor): - payload["attention_mask"] = payload["attention_mask"].to(llm_device) - payload = ensure_attention_mask(payload) - payload.update(kwargs) - with torch.amp.autocast(device_type=llm_device.type): - return self.model.llm.generate(**payload) - - if isinstance(inputs, torch.Tensor): - inputs = inputs.to(llm_device) - attention_mask = kwargs.pop("attention_mask", None) - if attention_mask is None: - attention_mask = torch.ones_like(inputs, dtype=torch.bool, device=inputs.device) - else: - attention_mask = attention_mask.to(inputs.device) - - with torch.amp.autocast(device_type=llm_device.type): - return self.model.llm.generate( - inputs=inputs, - attention_mask=attention_mask, - **kwargs - ) - - if inputs is None: - payload = ensure_attention_mask({k: v for k, v in kwargs.items() if k in {"input_ids", "attention_mask"}}) - remaining = {k: v for k, v in kwargs.items() if k not in payload} - payload = {k: (v.to(llm_device) if isinstance(v, torch.Tensor) else v) for k, v in payload.items()} - with torch.amp.autocast(device_type=llm_device.type): - return self.model.llm.generate(**payload, **remaining) - - with torch.amp.autocast(device_type=model_device.type): - return super().generate(inputs=inputs, **kwargs) - - def forward(self, *args, **kwargs): - """Allow text-only invocations to bypass vision branch for evaluator compatibility.""" - if args: - # most callers pass input_ids positionally; forward them as keyword for clarity - if "input_ids" in kwargs: - raise TypeError("OvisQModel.forward() received positional and keyword input_ids") - kwargs = dict(kwargs) # shallow copy to avoid mutating caller state - kwargs["input_ids"] = args[0] - args = args[1:] - - input_ids = kwargs.get("input_ids") - pixel_values = kwargs.get("pixel_values") - - if input_ids is not None and pixel_values is None and hasattr(self.model, "llm"): - attention_mask = kwargs.get("attention_mask") - if attention_mask is None: - attention_mask = torch.ones_like(input_ids, dtype=torch.bool, device=input_ids.device) - - # Hugging Face text evaluators expect logits only; labels are optional here. 
- llm_kwargs = { - k: v - for k, v in kwargs.items() - if k not in {"pixel_values"} - } - llm_kwargs["attention_mask"] = attention_mask - - if llm_kwargs.get("labels") is None: - llm_kwargs.pop("labels", None) - - return self.model.llm(**llm_kwargs) - - return super().forward(*args, **kwargs) + with torch.inference_mode(), torch.amp.autocast(device_type=self.device.type): + return self.model.generate(inputs, **kwargs) diff --git a/gptqmodel/models/definitions/ovis2_5.py b/gptqmodel/models/definitions/ovis2_5.py deleted file mode 100644 index 80a00beac..000000000 --- a/gptqmodel/models/definitions/ovis2_5.py +++ /dev/null @@ -1,551 +0,0 @@ -# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai -# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai -# SPDX-License-Identifier: Apache-2.0 -# Contact: qubitium@modelcloud.ai, x.com/qubitium - -import copy -import logging -from typing import Dict, List, Sequence - -import torch -from transformers import AutoProcessor - -from ...utils.calibration import batched -from ...utils.image import fetch_image -from ...utils.model import MODALITY, move_to, nested_move_to -from .._const import CPU -from ..base import BaseQModel - - -log = logging.getLogger(__name__) - - -class Ovis2_5QModel(BaseQModel): - require_trust_remote_code = True - pre_lm_head_norm_module = "llm.model.norm" - - module_tree = [ - "llm", - "model", - "layers", - "#", - { - "input_layernorm": ("input_layernorm:!",), - "self_attn": ("q_proj:0", "k_proj:0", "v_proj:0", "o_proj:1"), - "post_attention_layernorm": ("post_attention_layernorm:!",), - "mlp": ("gate_proj:0", "up_proj:0", "down_proj:1"), - }, - ] - - layer_modules_strict = False - - require_monkeypatch = True - require_load_processor = True - - modality = [MODALITY.IMAGE_TO_TEXT] - - IGNORE_ID = -100 - - def monkey_patch(self): - # keep the vision tower dtype aligned with the text tower for stable quantization/inference - dtype = getattr(self.model.llm, "dtype", None) - if dtype is None or dtype == torch.float32: - dtype = torch.bfloat16 - try: - self.model.llm = self.model.llm.to(dtype=dtype) - if hasattr(self.model.llm, "config"): - self.model.llm.config.dtype = dtype - except Exception: - log.warning("Failed to cast llm to %s", dtype, exc_info=True) - - try: - self.model.visual_tokenizer = self.model.visual_tokenizer.to(dtype=dtype) - except Exception: - log.warning("Failed to cast visual_tokenizer to %s", dtype, exc_info=True) - try: - self.model.vte = self.model.vte.to(dtype=dtype) - except Exception: - log.warning("Failed to cast visual embedding to %s", dtype, exc_info=True) - - attn_impl = getattr(self.model.llm.config, "_attn_implementation", None) - if attn_impl == "flash_attention_2": - log.info("Ovis2.5 monkey_patch: downgrading attention implementation to eager for compatibility.") - self.model.llm.config._attn_implementation = "eager" - - @property - def text_tokenizer(self): - return getattr(self.model, "text_tokenizer", self.tokenizer) - - @property - def visual_tokenizer(self): - return getattr(self.model, "visual_tokenizer", None) - - def load_processor(self): - return AutoProcessor.from_pretrained(self.model_local_path, trust_remote_code=True) - - def get_text_tokenizer(self): - return getattr(self.model, "text_tokenizer", None) - - def get_visual_tokenizer(self): - return getattr(self.model, "visual_tokenizer", None) - - def _ensure_image_payload(self, messages: List[Dict]) -> List[Dict]: - normalized_messages = copy.deepcopy(messages) - for message in normalized_messages: - content = message.get("content") - 
if not isinstance(content, list): - continue - - updated_content = [] - for item in content: - if not isinstance(item, dict): - updated_content.append(item) - continue - - if item.get("type") != "image": - updated_content.append(item) - continue - - if item.get("image") is None and item.get("image_url") is None: - log.warning("Skipping image item without image or image_url: message=%s", message) - continue - - if hasattr(item.get("image"), "mode"): - updated_content.append(item) - continue - - source = {"image": item.get("image")} if item.get("image") is not None else {"image_url": item.get("image_url")} - try: - image = fetch_image(source) - except Exception as exc: - log.warning("Failed to load image for message; skipping image token", exc_info=exc) - continue - - new_item = dict(item) - new_item["image"] = image - new_item.pop("image_url", None) - updated_content.append(new_item) - - message["content"] = updated_content - - return normalized_messages - - def _build_messages_from_conversations( - self, - conversations: Sequence[Dict], - sample: Dict, - ) -> List[Dict]: - images = sample.get("image") - if images is None: - images_sequence: List = [] - elif isinstance(images, (list, tuple)): - images_sequence = list(images) - else: - images_sequence = [images] - - image_objects: List = [] - for image_entry in images_sequence: - try: - image_objects.append(fetch_image({"image": image_entry})) - except Exception as exc: - log.warning("Failed to load image `%s` referenced by conversations", image_entry, exc_info=exc) - - image_iter = iter(image_objects) - normalized_messages: List[Dict] = [] - for turn in conversations: - speaker = turn.get("from") or turn.get("role") or "user" - value = turn.get("value") or turn.get("content") or "" - if speaker == "human": - role = "user" - elif speaker == "gpt": - role = "assistant" - else: - role = speaker - - if role == "assistant": - normalized_messages.append({"role": "assistant", "content": value}) - continue - - segments = value.split("") - content: List[Dict] = [] - for index, segment in enumerate(segments): - if segment: - content.append({"type": "text", "text": segment}) - if index < len(segments) - 1: - image_obj = next(image_iter, None) - if image_obj is None: - log.warning("Conversation refers to more tokens than provided images.") - break - content.append({"type": "image", "image": image_obj}) - - if not content: - content.append({"type": "text", "text": ""}) - - normalized_messages.append({"role": role, "content": content}) - - return normalized_messages - - def _normalize_messages(self, sample: Dict) -> List[Dict]: - if isinstance(sample, list): - return self._ensure_image_payload(sample) - - if sample.get("messages"): - return self._ensure_image_payload(sample["messages"]) - - conversations = sample.get("conversations") - if not conversations: - raise ValueError("Ovis2_5 calibration sample must provide `messages` or `conversations`.") - - return self._ensure_image_payload(self._build_messages_from_conversations(conversations, sample)) - - def _coerce_pixel_values(self, pixel_values): - if pixel_values is None: - return None - - target_dtype = getattr(getattr(self.visual_tokenizer, "vit", self.visual_tokenizer), "dtype", None) - target_dtype = target_dtype or getattr(self.visual_tokenizer, "dtype", None) - if isinstance(pixel_values, torch.Tensor): - if target_dtype is not None: - pixel_values = pixel_values.to(dtype=target_dtype) - return pixel_values - - if isinstance(pixel_values, (list, tuple)): - coerced = [] - for pv in pixel_values: - 
tensor = pv if isinstance(pv, torch.Tensor) else torch.as_tensor(pv) - if target_dtype is not None: - tensor = tensor.to(dtype=target_dtype) - coerced.append(tensor) - return coerced - - tensor = torch.as_tensor(pixel_values) - if target_dtype is not None: - tensor = tensor.to(dtype=target_dtype) - return tensor - - def preprocess_dataset(self, sample: Dict) -> Dict: - messages = self._normalize_messages(sample) - input_ids, pixel_values, grid_thws = self.model.preprocess_inputs(messages) - - pixel_values = self._coerce_pixel_values(pixel_values) - if pixel_values is None: - pixel_values = [] - - if isinstance(grid_thws, torch.Tensor): - grid_thws = grid_thws.to(dtype=torch.long) - elif grid_thws is None: - grid_thws = [] - - input_ids = input_ids.squeeze(0) - pad_token_id = self.text_tokenizer.pad_token_id - if pad_token_id is None: - pad_token_id = self.text_tokenizer.eos_token_id - - attention_mask = torch.ne(input_ids, pad_token_id) - - labels = input_ids.clone() - labels.masked_fill_(labels < 0, self.IGNORE_ID) - labels.masked_fill_(~attention_mask, self.IGNORE_ID) - - return { - "input_ids": input_ids, - "attention_mask": attention_mask, - "labels": labels, - "pixel_values": pixel_values, - "grid_thws": grid_thws, - } - - def prepare_dataset( - self, - calibration_dataset, - calibration_dataset_concat_size=None, - batch_size: int = 1, - tokenizer=None, - **kwargs, - ): - pad_token_id = self.text_tokenizer.pad_token_id - if pad_token_id is None: - pad_token_id = self.text_tokenizer.eos_token_id - - calib_data = [] - for batch in batched(calibration_dataset, batch_size, self.preprocess_dataset): - input_ids_list = [instance["input_ids"] for instance in batch] - attention_masks_list = [instance["attention_mask"].to(torch.bool) for instance in batch] - labels_list = [instance["labels"] for instance in batch] - pixel_values_list = [instance["pixel_values"] for instance in batch] - grid_thws_list = [instance["grid_thws"] for instance in batch] - - def _collate_vision(items): - if not items: - return None - if len(items) == 1: - value = items[0] - if value in (None, []): - return None - return value - return items - - pixel_values = _collate_vision(pixel_values_list) - grid_thws = _collate_vision(grid_thws_list) - - input_ids = torch.nn.utils.rnn.pad_sequence( - input_ids_list, - batch_first=True, - padding_value=pad_token_id, - ) - attention_mask = torch.nn.utils.rnn.pad_sequence( - attention_masks_list, - batch_first=True, - padding_value=0, - ).to(torch.bool) - labels = torch.nn.utils.rnn.pad_sequence( - labels_list, - batch_first=True, - padding_value=self.IGNORE_ID, - ) - - calib_data.append( - { - "input_ids": input_ids, - "attention_mask": attention_mask, - "labels": labels, - "pixel_values": pixel_values, - "grid_thws": grid_thws, - } - ) - - return calib_data - - def pre_quantize_generate_hook_start(self): - visual_tokenizer = getattr(self.model, "visual_tokenizer", None) - vision_modules = [] - if visual_tokenizer is not None: - vision_modules.append(visual_tokenizer) - vit = getattr(visual_tokenizer, "vit", None) - if vit is not None: - vision_modules.append(vit) - vte = getattr(self.model, "vte", None) - if vte is not None: - vision_modules.append(vte) - - for module in vision_modules: - if module is None: - continue - try: - has_meta_params = any(param.device.type == "meta" for param in module.parameters()) - except Exception: - has_meta_params = False - try: - has_meta_buffers = any(buffer.device.type == "meta" for buffer in module.buffers()) - except Exception: - 
has_meta_buffers = False - - if has_meta_params or has_meta_buffers: - try: - self.shell_module_materialize(module, self.quantize_config.device) - except Exception: # pragma: no cover - defensive downgrade - log.warning("OVIS2.5 module shell materialization failed; continuing with fallback move.", exc_info=True) - - if visual_tokenizer is not None: - self.model.visual_tokenizer = move_to(visual_tokenizer, device=self.quantize_config.device) - vit = getattr(self.model.visual_tokenizer, "vit", None) - if vit is not None: - self.model.visual_tokenizer.vit = move_to(vit, device=self.quantize_config.device) - - if vte is not None: - self.model.vte = move_to(vte, device=self.quantize_config.device) - - indicator_buffer = getattr(self.model, "indicator_token_indices", None) - if isinstance(indicator_buffer, torch.Tensor) and indicator_buffer.device.type == "meta": - count = indicator_buffer.shape[0] - vocab_size = getattr(getattr(self.model, "config", None), "visual_vocab_size", count) - start_index = vocab_size - count - device = self.quantize_config.device - if not isinstance(device, torch.device): - device = torch.device(device) - materialized = torch.arange(start_index, vocab_size, dtype=torch.long, device=device) - self.model.register_buffer("indicator_token_indices", materialized, persistent=False) - - def pre_quantize_generate_hook_end(self): - if self.quantize_config.offload_to_disk: - from ...utils.offload import offload_to_disk - - visual_tokenizer = getattr(self.model, "visual_tokenizer", None) - if visual_tokenizer is not None: - offload_to_disk(model=self.model, module=visual_tokenizer, disk_path=self.quantize_config.offload_to_disk_path) - vit = getattr(visual_tokenizer, "vit", None) - if vit is not None: - offload_to_disk(model=visual_tokenizer, module=vit, disk_path=self.quantize_config.offload_to_disk_path) - - vte = getattr(self.model, "vte", None) - if vte is not None: - offload_to_disk(model=self.model, module=vte, disk_path=self.quantize_config.offload_to_disk_path) - return - - visual_tokenizer = getattr(self.model, "visual_tokenizer", None) - if visual_tokenizer is not None: - self.model.visual_tokenizer = move_to(visual_tokenizer, device=CPU) - vit = getattr(self.model.visual_tokenizer, "vit", None) - if vit is not None: - self.model.visual_tokenizer.vit = move_to(vit, device=CPU) - - vte = getattr(self.model, "vte", None) - if vte is not None: - self.model.vte = move_to(vte, device=CPU) - - indicator_buffer = getattr(self.model, "indicator_token_indices", None) - if isinstance(indicator_buffer, torch.Tensor): - self.model.register_buffer("indicator_token_indices", indicator_buffer.to(CPU), persistent=False) - - def generate(self, inputs=None, **kwargs): - model_device = getattr(self.model, "device", None) - if model_device is None: - quant_device = getattr(self.quantize_config, "device", None) - model_device = torch.device(quant_device if quant_device is not None else "cpu") - - llm = getattr(self.model, "llm", None) - - pixel_values = None - grid_thws = None - if isinstance(inputs, dict): - pixel_values = inputs.get("pixel_values") - grid_thws = inputs.get("grid_thws") - if pixel_values is None: - pixel_values = kwargs.get("pixel_values") - if grid_thws is None: - grid_thws = kwargs.get("grid_thws") - - def _has_pixels(payload): - if payload is None: - return False - if isinstance(payload, torch.Tensor): - return payload.numel() > 0 - if isinstance(payload, (str, bytes)): - return False - if isinstance(payload, Sequence): - return any(_has_pixels(item) for item in 
payload) - return True - - has_real_pixels = _has_pixels(pixel_values) - - if llm is not None and not has_real_pixels: - kwargs = dict(kwargs) - kwargs.pop("pixel_values", None) - kwargs.pop("grid_thws", None) - - if isinstance(inputs, dict): - payload = {k: v for k, v in inputs.items() if k not in {"pixel_values", "grid_thws"}} - else: - payload = inputs - - def ensure_attention_mask(payload_dict): - if payload_dict is None: - return None - if payload_dict.get("attention_mask") is not None or "input_ids" not in payload_dict: - return payload_dict - mask = torch.ones_like(payload_dict["input_ids"], dtype=torch.bool, device=payload_dict["input_ids"].device) - payload_dict["attention_mask"] = mask - return payload_dict - - llm_device = next(self.model.llm.parameters()).device - - if isinstance(payload, (str, list)): - if isinstance(payload, str) or (isinstance(payload, list) and all(isinstance(x, str) for x in payload)): - if self.tokenizer is None: - raise ValueError("You passed in text to Ovis2_5QModel.generate() but tokenizer is missing.") - tokenized = self.tokenizer( - payload, - return_tensors="pt", - padding=True, - padding_side="left", - ).to(llm_device) - with torch.amp.autocast(device_type=llm_device.type): - return self.model.llm.generate(**tokenized, **kwargs) - - if isinstance(payload, dict): - payload = ensure_attention_mask({k: (v.to(llm_device) if isinstance(v, torch.Tensor) else v) for k, v in payload.items()}) - with torch.amp.autocast(device_type=llm_device.type): - return self.model.llm.generate(**payload, **kwargs) - - if isinstance(payload, torch.Tensor): - payload = payload.to(llm_device) - attention_mask = kwargs.pop("attention_mask", None) - if attention_mask is None: - attention_mask = torch.ones_like(payload, dtype=torch.bool, device=payload.device) - else: - attention_mask = attention_mask.to(payload.device) - with torch.amp.autocast(device_type=llm_device.type): - return self.model.llm.generate(inputs=payload, attention_mask=attention_mask, **kwargs) - - if payload is None: - payload_inputs = {k: v for k, v in kwargs.items() if k in {"input_ids", "attention_mask"}} - payload_inputs = ensure_attention_mask(payload_inputs) - if payload_inputs: - payload_inputs = {k: (v.to(llm_device) if isinstance(v, torch.Tensor) else v) for k, v in payload_inputs.items()} - with torch.amp.autocast(device_type=llm_device.type): - return self.model.llm.generate(**payload_inputs, **{k: v for k, v in kwargs.items() if k not in payload_inputs}) - - kwargs = dict(kwargs) - if isinstance(inputs, dict): - inputs_dict = dict(inputs) - base_inputs = inputs_dict.pop("input_ids", None) - if base_inputs is None and "inputs_embeds" in inputs_dict: - kwargs.setdefault("inputs_embeds", inputs_dict.pop("inputs_embeds")) - for key in list(inputs_dict.keys()): - if key in {"pixel_values", "grid_thws"}: - inputs_dict.pop(key) - for key, value in inputs_dict.items(): - kwargs.setdefault(key, value) - inputs = base_inputs - - if inputs is None and isinstance(kwargs.get("input_ids"), torch.Tensor): - inputs = kwargs.get("input_ids") - - target_device = next(self.model.llm.parameters()).device - model_device = target_device - if isinstance(inputs, torch.Tensor) and inputs.device != target_device: - inputs = inputs.to(target_device) - if isinstance(kwargs.get("input_ids"), torch.Tensor) and kwargs["input_ids"].device != target_device: - kwargs["input_ids"] = kwargs["input_ids"].to(target_device) - if "input_ids" in kwargs: - kwargs.pop("input_ids") - if isinstance(kwargs.get("attention_mask"), 
torch.Tensor) and kwargs["attention_mask"].device != target_device: - kwargs["attention_mask"] = kwargs["attention_mask"].to(target_device) - - if pixel_values is not None: - kwargs["pixel_values"] = nested_move_to(pixel_values, device=model_device) - if grid_thws is not None: - kwargs["grid_thws"] = nested_move_to(grid_thws, device=model_device) - - with torch.amp.autocast(device_type=model_device.type): - return super().generate(inputs=inputs, **kwargs) - - def forward(self, *args, **kwargs): - if args: - if "input_ids" in kwargs: - raise TypeError("Ovis2_5QModel.forward() received positional and keyword input_ids") - kwargs = dict(kwargs) - kwargs["input_ids"] = args[0] - args = args[1:] - - input_ids = kwargs.get("input_ids") - pixel_values = kwargs.get("pixel_values") - - if input_ids is not None and pixel_values is None and hasattr(self.model, "llm"): - attention_mask = kwargs.get("attention_mask") - if attention_mask is None: - attention_mask = torch.ones_like(input_ids, dtype=torch.bool, device=input_ids.device) - - llm_kwargs = { - k: v - for k, v in kwargs.items() - if k not in {"pixel_values", "grid_thws"} - } - llm_kwargs["attention_mask"] = attention_mask - - if llm_kwargs.get("labels") is None: - llm_kwargs.pop("labels", None) - - return self.model.llm(**llm_kwargs) - - return super().forward(*args, **kwargs) diff --git a/gptqmodel/models/loader.py b/gptqmodel/models/loader.py index bdc8589b4..aae472190 100644 --- a/gptqmodel/models/loader.py +++ b/gptqmodel/models/loader.py @@ -25,7 +25,7 @@ from huggingface_hub import snapshot_download from packaging.version import InvalidVersion, Version -from transformers import AutoTokenizer, PretrainedConfig +from transformers import AutoConfig, AutoTokenizer, PretrainedConfig from transformers.modeling_utils import no_init_weights from transformers.utils import is_flash_attn_2_available from transformers.utils.generic import ContextManagers @@ -35,7 +35,6 @@ from ..quantization import QuantizeConfig from ..quantization.config import FORMAT, METHOD, MIN_VERSION_WITH_V2 from ..utils.backend import BACKEND -from ..utils.hf import safe_auto_config_from_pretrained from ..utils.importer import auto_select_device, normalize_device_device_map, select_quant_linear from ..utils.logger import setup_logger from ..utils.marlin import _validate_marlin_device_support @@ -57,45 +56,6 @@ log = setup_logger() ATTN_IMPLEMENTATION = "attn_implementation" - - -def _coerce_torch_dtype(value): - if value is None: - return None - if isinstance(value, torch.dtype): - return value - attr = None - if isinstance(value, str): - attr = getattr(torch, value, None) - if isinstance(attr, torch.dtype): - return attr - raise TypeError(f"Unsupported torch dtype value: {value!r}") - - -def _clear_config_attr(config, name: str) -> None: - store = getattr(config, "__dict__", None) - if isinstance(store, dict) and name in store: - del store[name] - - -def _normalize_config_dtype(config) -> Optional[torch.dtype]: - legacy = getattr(config, "torch_dtype", None) - if legacy is not None: - try: - coerced = _coerce_torch_dtype(legacy) - except TypeError: - coerced = None - if coerced is not None: - setattr(config, "dtype", coerced) - _clear_config_attr(config, "torch_dtype") - current = getattr(config, "dtype", None) - if isinstance(current, str): - try: - current = _coerce_torch_dtype(current) - setattr(config, "dtype", current) - except TypeError: - current = None - return current if isinstance(current, torch.dtype) else None def parse_version_string(version_str: str): try: 
return Version(version_str) @@ -202,12 +162,9 @@ def skip(*args, **kwargs): torch.nn.init.uniform_ = skip torch.nn.init.normal_ = skip - torch_dtype_arg = model_init_kwargs.pop("torch_dtype", None) - model_init_kwargs["trust_remote_code"] = trust_remote_code - config = safe_auto_config_from_pretrained(model_local_path, **model_init_kwargs) - normalized_config_dtype = _normalize_config_dtype(config) + config = AutoConfig.from_pretrained(model_local_path, **model_init_kwargs) atten_impl = model_init_kwargs.get("attn_implementation", None) @@ -224,32 +181,13 @@ def skip(*args, **kwargs): if cls.require_dtype: dtype = cls.require_dtype - if torch_dtype_arg is not None: - coerced = _coerce_torch_dtype(torch_dtype_arg) - if dtype is None or dtype == "auto" or not isinstance(dtype, torch.dtype): - dtype = coerced - else: - log.info("Loader: overriding legacy torch_dtype argument with `dtype` and removing duplication.") - model_init_kwargs["dtype"] = dtype - if dtype is None or dtype == "auto" or not isinstance(dtype, torch.dtype): - if normalized_config_dtype is not None: - dtype = normalized_config_dtype - else: - # TODO FIX ME for `dynamic`, non-quantized modules should be in native type - dtype = auto_dtype(config=config, device=quantize_config.device, quant_inference=False) - - if isinstance(dtype, torch.dtype): - current = getattr(config, "dtype", None) - if isinstance(current, str): - try: - current = _coerce_torch_dtype(current) - except TypeError: - current = None - if current != dtype: - # Align config metadata with the dtype we will materialize weights in. - config.dtype = dtype - _clear_config_attr(config, "torch_dtype") + # TODO FIX ME for `dynamic`, non-quantized modules should be in native type + dtype = auto_dtype(config=config, device=quantize_config.device, quant_inference=False) + + if isinstance(dtype, torch.dtype) and getattr(config, "torch_dtype", None) != dtype: + # Align config metadata with the dtype we will materialize weights in. 
+ config.torch_dtype = dtype # enforce some values despite user specified # non-quantized models are always loaded into cpu @@ -263,21 +201,12 @@ def skip(*args, **kwargs): if quantize_config.offload_to_disk: print("shell model-----------") - shell_kwargs = model_init_kwargs.copy() - shell_dtype = shell_kwargs.pop("dtype", dtype) - model = build_shell_model(cls.loader, config=config, dtype=shell_dtype, **shell_kwargs) + model = build_shell_model(cls.loader, config=config, **model_init_kwargs) model._model_init_kwargs = model_init_kwargs print_module_tree(model=model) # enable mmap with low_cpu_mem_usage - turtle_kwargs = model_init_kwargs.copy() - turtle_kwargs.setdefault("dtype", dtype) - turtle_model = cls.loader.from_pretrained( - model_local_path, - config=config, - low_cpu_mem_usage=True, - **turtle_kwargs, - ) + turtle_model = cls.loader.from_pretrained(model_local_path, config=config, low_cpu_mem_usage=True, **model_init_kwargs) # TODO FIX ME...temp store model_init args turtle_model._model_init_kwargs = model_init_kwargs @@ -285,26 +214,12 @@ def skip(*args, **kwargs): # print_module_tree(model=turtle_model) else: print("loading model directly to CPU (not using meta device or turtle_model)-----------") - direct_kwargs = model_init_kwargs.copy() - direct_kwargs.setdefault("dtype", dtype) - model = cls.loader.from_pretrained( - model_local_path, - config=config, - **direct_kwargs, - ) + model = cls.loader.from_pretrained(model_local_path, config=config, **model_init_kwargs) model._model_init_kwargs = model_init_kwargs print_module_tree(model=model) turtle_model = None - if isinstance(dtype, torch.dtype): - if getattr(model, "config", None) is not None: - model.config.dtype = dtype - _clear_config_attr(model.config, "torch_dtype") - if turtle_model is not None and getattr(turtle_model, "config", None) is not None: - turtle_model.config.dtype = dtype - _clear_config_attr(turtle_model.config, "torch_dtype") - model_config = model.config.to_dict() seq_len_keys = ["max_position_embeddings", "seq_length", "n_positions", "multimodal_max_length"] config_seq_len = find_config_seq_len(model_config, seq_len_keys) @@ -389,8 +304,7 @@ def from_quantized( revision = kwargs.pop("revision", None) subfolder = kwargs.pop("subfolder", "") commit_hash = kwargs.pop("_commit_hash", None) - attn_arg = kwargs.pop("attn_implementation", None) - torch_dtype_arg = kwargs.pop("torch_dtype", None) + attn_implementation = kwargs.pop("attn_implementation", None) cached_file_kwargs = { "cache_dir": cache_dir, @@ -403,57 +317,26 @@ def from_quantized( "subfolder": subfolder, "_raise_exceptions_for_missing_entries": False, "_commit_hash": commit_hash, + "attn_implementation": attn_implementation, } # == step1: prepare configs and file names == # - print("[DEBUG] safe_auto_config call", trust_remote_code, model_local_path) - config: PretrainedConfig = safe_auto_config_from_pretrained( + config: PretrainedConfig = AutoConfig.from_pretrained( model_local_path, trust_remote_code=trust_remote_code, **cached_file_kwargs, ) - log.info("Loader: safe_auto_config_from_pretrained called with trust_remote_code=%s for %s", - trust_remote_code, model_local_path) - print("[DEBUG] loaded config model_type", getattr(config, "model_type", None)) - attn_override = attn_arg - if getattr(config, "model_type", "").lower() == "ovis": - for key in ("attn_implementation", "_attn_implementation"): - value = getattr(config, key, None) - if value == "flash_attention_2" or value is None: - setattr(config, key, "eager") - attn_override = "eager" 
- cached_file_kwargs.pop("attn_implementation", None) - normalized_config_dtype = _normalize_config_dtype(config) if cls.require_dtype: dtype = cls.require_dtype - if torch_dtype_arg is not None: - coerced = _coerce_torch_dtype(torch_dtype_arg) - if dtype is None or dtype == "auto" or not isinstance(dtype, torch.dtype): - dtype = coerced - else: - log.info("Loader: overriding legacy torch_dtype argument with `dtype` and removing duplication.") - kwargs["dtype"] = dtype + if dtype is None or dtype == "auto" or not isinstance(dtype, torch.dtype) : + # TODO FIX ME for `dynamic`, non-quantized modules should be in native type + dtype = auto_dtype(config=config, device=device, quant_inference=True) - if dtype is None or dtype == "auto" or not isinstance(dtype, torch.dtype): - if normalized_config_dtype is not None: - dtype = normalized_config_dtype - else: - # TODO FIX ME for `dynamic`, non-quantized modules should be in native type - dtype = auto_dtype(config=config, device=device, quant_inference=True) - - if isinstance(dtype, torch.dtype): - current = getattr(config, "dtype", None) - if isinstance(current, str): - try: - current = _coerce_torch_dtype(current) - except TypeError: - current = None - if current != dtype: - # Ensure flash attention kernels see an explicit dtype instead of relying on defaults. - config.dtype = dtype - _clear_config_attr(config, "torch_dtype") + if isinstance(dtype, torch.dtype) and getattr(config, "torch_dtype", None) != dtype: + # Ensure flash attention kernels see an explicit dtype instead of relying on defaults. + config.torch_dtype = dtype qcfg = QuantizeConfig.from_pretrained(model_local_path, **cached_file_kwargs, **kwargs) @@ -578,40 +461,28 @@ def skip(*args, **kwargs): with (ContextManagers(init_contexts)): cls.before_model_load(cls, load_quantized_model=True) - supports_flash_attn = bool(getattr(config, "_supports_flash_attn_2", False)) if config.architectures: model_class = getattr(transformers, config.architectures[0], None) - if model_class is not None: - model_supports_flash = getattr(model_class, "_supports_flash_attn_2", None) - if model_supports_flash is not None: - supports_flash_attn = bool(model_supports_flash) + if model_class is not None and hasattr(model_class, "_supports_flash_attn_2"): + supports_flash_attn = model_class._supports_flash_attn_2 + else: + supports_flash_attn = None + else: + supports_flash_attn = None args = {} - if attn_override is not None: - args[ATTN_IMPLEMENTATION] = attn_override - elif ATTN_IMPLEMENTATION in kwargs: - args[ATTN_IMPLEMENTATION] = kwargs.pop(ATTN_IMPLEMENTATION, None) - elif device in [DEVICE.CUDA, DEVICE.ROCM]: - if supports_flash_attn and is_flash_attn_2_available(): - args[ATTN_IMPLEMENTATION] = "flash_attention_2" + if supports_flash_attn and device in [DEVICE.CUDA, DEVICE.ROCM]: + if ATTN_IMPLEMENTATION in kwargs: + args[ATTN_IMPLEMENTATION] = kwargs.pop(ATTN_IMPLEMENTATION, None) + elif is_flash_attn_2_available(): + args = {ATTN_IMPLEMENTATION: "flash_attention_2"} log.info("Loader: Auto enabling flash attention2") - flash_attn_requested = args.get(ATTN_IMPLEMENTATION) == "flash_attention_2" - if flash_attn_requested and isinstance(dtype, torch.dtype): - config.dtype = dtype - _clear_config_attr(config, "torch_dtype") - model = cls.loader.from_config( config, trust_remote_code=trust_remote_code, dtype=dtype, **args ) model.checkpoint_file_name = model_save_name - if flash_attn_requested and isinstance(dtype, torch.dtype) and getattr(model, "config", None) is not None: - model.config.dtype = 
dtype - _clear_config_attr(model.config, "torch_dtype") - if ATTN_IMPLEMENTATION in args and getattr(model, "config", None) is not None: - _clear_config_attr(model.config, "_attn_implementation") - # Get the first layer to determine layer type layers, _ = get_module_by_name_prefix(model, cls.extract_layers_node()) diff --git a/gptqmodel/models/writer.py b/gptqmodel/models/writer.py index 4c1777c72..f0ac0dbba 100644 --- a/gptqmodel/models/writer.py +++ b/gptqmodel/models/writer.py @@ -16,7 +16,7 @@ import torch import transformers from safetensors.torch import save_file -from transformers import PreTrainedTokenizerFast, ProcessorMixin +from transformers import AutoConfig, PreTrainedTokenizerFast, ProcessorMixin from transformers.modeling_utils import no_init_weights from transformers.models.auto.tokenization_auto import get_tokenizer_config from transformers.utils.generic import ContextManagers @@ -41,7 +41,7 @@ MIN_VERSION_WITH_V2, ) from ..utils.backend import BACKEND -from ..utils.hf import safe_auto_config_from_pretrained, sanitize_generation_config_file +from ..utils.hf import sanitize_generation_config_file from ..utils.logger import setup_logger from ..utils.model import ( convert_gptq_v2_to_v1_format, @@ -61,32 +61,6 @@ log = setup_logger() -_ATTN_IMPLEMENTATION_KEYS = ("attn_implementation", "_attn_implementation") - - -def _stash_and_clear_attention_attr(config_like): - """Temporarily remove transient attention implementation flags before serialization.""" - if config_like is None: - return {} - store = getattr(config_like, "__dict__", None) - if not isinstance(store, dict): - return {} - removed = {} - for key in _ATTN_IMPLEMENTATION_KEYS: - if key in store: - removed[key] = store.pop(key) - return removed - - -def _restore_attention_attr(config_like, removed_attrs): - if not removed_attrs: - return - store = getattr(config_like, "__dict__", None) - if not isinstance(store, dict): - return - for key, value in removed_attrs.items(): - store[key] = value - PROCESS_LOG_NAME = "process" PROCESS_LOG_LAYER = "layer" PROCESS_LOG_MODULE = "module" @@ -274,17 +248,9 @@ def save_quantized( config.quantization_config = quantize_config.to_dict() self.model.config = config - removed_config_attn = _stash_and_clear_attention_attr(config) - generation_config = getattr(self.model, "generation_config", None) - removed_generation_attn = _stash_and_clear_attention_attr(generation_config) - # Save model config, including generation_config # Use empty state_dict hack to bypass saving weights - try: - self.model.save_pretrained(save_dir, state_dict={}, is_main_process=True) - finally: - _restore_attention_attr(config, removed_config_attn) - _restore_attention_attr(generation_config, removed_generation_attn) + self.model.save_pretrained(save_dir, state_dict={}, is_main_process=True) gen_config_path = os.path.join(save_dir, "generation_config.json") if sanitize_generation_config_file(gen_config_path): @@ -456,7 +422,7 @@ def _normalize_metadata(meta: Optional[Dict[str, Any]]) -> Dict[str, str]: def get_model_with_quantize(self, qcfg, model_id_or_path): - config = safe_auto_config_from_pretrained( + config = AutoConfig.from_pretrained( model_id_or_path, trust_remote_code=True, ) diff --git a/gptqmodel/nn_modules/qlinear/bitblas.py b/gptqmodel/nn_modules/qlinear/bitblas.py index 88df78026..31fde7924 100644 --- a/gptqmodel/nn_modules/qlinear/bitblas.py +++ b/gptqmodel/nn_modules/qlinear/bitblas.py @@ -140,7 +140,7 @@ def __post_init__(self) -> None: raise ValueError("weight_bits must divide 32 for GPTQ 
packing") self.pack_factor = 32 // self.weight_bits self.torch_storage_dtype = getattr(torch, self.storage_dtype) - self.dtype = torch.float16 + self.torch_dtype = torch.float16 @property def with_zeros(self) -> bool: diff --git a/gptqmodel/utils/device.py b/gptqmodel/utils/device.py index 97467de23..744ce3d08 100644 --- a/gptqmodel/utils/device.py +++ b/gptqmodel/utils/device.py @@ -4,7 +4,7 @@ # Contact: qubitium@modelcloud.ai, x.com/qubitium from __future__ import annotations -from typing import Any, Optional, Union +from typing import Optional, Union import torch from device_smi import Device @@ -23,66 +23,19 @@ def get_cpu_usage_memory(): smi = Device(CPU) return smi.memory_used() / 1024 / 1024 / 1024 #GB -def _coerce_device(value: Any) -> Optional[torch.device]: - if isinstance(value, torch.device): - return value - if isinstance(value, str): - try: - return torch.device(value) - except (TypeError, ValueError): # pragma: no cover - defensive - return None - return None - - -def _extract_tensor_device(obj: Any) -> Optional[torch.device]: - # Prefer explicit attribute before reflective search. - direct = _coerce_device(getattr(obj, "device", None)) - if direct is not None: - return direct - - # Try common tensor holders (e.g. dataclasses) via __dict__. - values = [] - store = getattr(obj, "__dict__", None) - if isinstance(store, dict): - values.extend(store.values()) - else: - # Fallback to dir() to support __slots__ or C++ backed objects without __dict__. - for attr in dir(obj): - if attr.startswith("_"): - continue - try: - values.append(getattr(obj, attr)) - except AttributeError: # pragma: no cover - dynamic attr access - continue - except Exception: # pragma: no cover - defensive against @property side effects - continue - - for value in values: - if isinstance(value, torch.Tensor): - return value.device - - return None - - -def get_device(obj: Any) -> torch.device: +def get_device(obj: torch.Tensor | nn.Module) -> torch.device: if isinstance(obj, torch.Tensor): return obj.device - if isinstance(obj, nn.Module): - params = list(obj.parameters()) - buffers = list(obj.buffers()) - if params: - return params[0].device - if buffers: - return buffers[0].device + params = list(obj.parameters()) + buffers = list(obj.buffers()) + if len(params) > 0: + return params[0].device + elif len(buffers) > 0: + return buffers[0].device + else: return CPU - extracted = _extract_tensor_device(obj) - if extracted is not None: - return extracted - - return CPU - def get_device_new( obj: torch.Tensor | nn.Module, recursive: bool = False, diff --git a/gptqmodel/utils/hf.py b/gptqmodel/utils/hf.py index 9c087a39f..434b8fea9 100644 --- a/gptqmodel/utils/hf.py +++ b/gptqmodel/utils/hf.py @@ -4,161 +4,18 @@ # Contact: qubitium@modelcloud.ai, x.com/qubitium import json -import os -import re from typing import Any, Optional import torch from accelerate import init_empty_weights -from transformers import AutoConfig, GenerationConfig, PreTrainedModel -from transformers.models.auto.configuration_auto import CONFIG_MAPPING +from transformers import GenerationConfig, PreTrainedModel from ..utils.logger import setup_logger log = setup_logger() -# Parameters that enable sampling-style generation when non-default. -GENERATION_SAMPLING_FIELDS = ( - "temperature", - "top_p", - "top_k", - "min_p", - "typical_p", - "epsilon_cutoff", - "eta_cutoff", - "penalty_alpha", -) - -# Default values enforced by Transformers when `do_sample` is disabled. 
-_GREEDY_DEFAULTS = { - "temperature": 1.0, - "top_p": 1.0, - "min_p": None, - "typical_p": 1.0, - "top_k": 50, - "epsilon_cutoff": 0.0, - "eta_cutoff": 0.0, - # Contrastive search uses penalty_alpha; keep unset when not sampling - "penalty_alpha": None, -} - -_DUPLICATE_CONFIG_PATTERN = re.compile(r"'([^']+)' is already used by a Transformers config") - - -def _ensure_pretrained_model_defaults(): - fallback_attrs = { - "is_parallelizable": False, - "_no_split_modules": (), - "_keep_in_fp32_modules": (), - "_skip_keys_device_placement": (), - } - for name, value in fallback_attrs.items(): - if not hasattr(PreTrainedModel, name): - setattr(PreTrainedModel, name, value) - - -_ensure_pretrained_model_defaults() - - -def _drop_generation_field(cfg: GenerationConfig, field: str) -> bool: - """Remove a generation field from the config and return True if mutated.""" - changed = False - # GenerationConfig keeps its data in both __dict__ and _internal_dict - for attr in ("__dict__", "_internal_dict"): - container = getattr(cfg, attr, None) - if isinstance(container, dict) and field in container: - container.pop(field, None) - changed = True - - if hasattr(cfg, field): - try: - delattr(cfg, field) - except AttributeError: - pass - - return changed - - -def _has_sampling_params(cfg: GenerationConfig) -> bool: - for field in GENERATION_SAMPLING_FIELDS: - try: - value = getattr(cfg, field) - except AttributeError: - value = None - if value is None: - continue - if isinstance(value, (int, float)): - if field == "temperature" and value != 1.0: - return True - if field in {"top_p", "min_p", "typical_p"} and value < 1.0: - return True - if field == "top_k" and value not in (None, 50): - return True - if field in {"epsilon_cutoff", "eta_cutoff"} and value != 0.0: - return True - if field == "penalty_alpha" and value is not None: - return True - else: - return True - return False - - -def _set_generation_field(cfg: GenerationConfig, field: str, value: Any) -> bool: - """Ensure a generation field matches the requested value, syncing internal dict as well.""" - current = getattr(cfg, field, None) - changed = False - - if value is None: - changed = _drop_generation_field(cfg, field) or changed - else: - if current != value: - setattr(cfg, field, value) - changed = True - internal = getattr(cfg, "_internal_dict", None) - if isinstance(internal, dict) and internal.get(field) != value: - internal[field] = value - changed = True - - return changed - - -def _deregister_auto_config(model_type: str) -> bool: - removed = False - for attr in ("_mapping", "_extra_content", "_modules"): - container = getattr(CONFIG_MAPPING, attr, None) - if isinstance(container, dict) and model_type in container: - container.pop(model_type, None) - removed = True - return removed - - -def safe_auto_config_from_pretrained(*args, **kwargs): - trust_flag = kwargs.get("trust_remote_code") - prev_env = None - if trust_flag: - prev_env = os.environ.get("TRANSFORMERS_TRUST_REMOTE_CODE") - os.environ["TRANSFORMERS_TRUST_REMOTE_CODE"] = "1" - try: - return AutoConfig.from_pretrained(*args, **kwargs) - except ValueError as err: - message = str(err) - if "already used by a Transformers config" not in message: - raise - match = _DUPLICATE_CONFIG_PATTERN.search(message) - if not match: - raise - model_type = match.group(1) - if not _deregister_auto_config(model_type): - raise - log.info("Transformers: cleared cached config registration for `%s` and retrying load.", model_type) - return AutoConfig.from_pretrained(*args, **kwargs) - finally: - if 
trust_flag: - if prev_env is None: - os.environ.pop("TRANSFORMERS_TRUST_REMOTE_CODE", None) - else: - os.environ["TRANSFORMERS_TRUST_REMOTE_CODE"] = prev_env +GENERATION_SAMPLING_FIELDS = ("temperature", "top_p") def _sanitize_generation_config(cfg: GenerationConfig, *, drop_sampling_fields: bool = False) -> bool: @@ -166,35 +23,16 @@ def _sanitize_generation_config(cfg: GenerationConfig, *, drop_sampling_fields: if cfg is None: return changed - if not hasattr(cfg, "do_sample") or getattr(cfg, "do_sample") is None: - cfg.do_sample = False + if getattr(cfg, "do_sample", None) is not True: + cfg.do_sample = True changed = True - if getattr(cfg, "do_sample", False) is not True and not drop_sampling_fields: - if getattr(cfg, "do_sample") is not False: - cfg.do_sample = False - changed = True - for field, default in _GREEDY_DEFAULTS.items(): - if _set_generation_field(cfg, field, default): - changed = True - elif getattr(cfg, "do_sample", False) is not True and drop_sampling_fields: - if getattr(cfg, "do_sample") is not False: - cfg.do_sample = False - changed = True - for field, default in _GREEDY_DEFAULTS.items(): - if default is None: - if _drop_generation_field(cfg, field): - changed = True - else: - if _set_generation_field(cfg, field, default): + if drop_sampling_fields: + for field in GENERATION_SAMPLING_FIELDS: + if hasattr(cfg, field): + if getattr(cfg, field) is not None: changed = True - elif drop_sampling_fields: - # Sampling configuration is intentionally preserved when `do_sample` remains enabled. - pass - elif getattr(cfg, "do_sample", False) and not _has_sampling_params(cfg): - cfg.do_sample = False - changed = True - log.info("Model: Auto-Fixed `generation_config` by disabling sampling due to missing sampling parameters.") + setattr(cfg, field, None) return changed @@ -210,8 +48,8 @@ def _load_sanitized_generation_config(path: str) -> Optional[GenerationConfig]: if field in cleaned: cleaned.pop(field, None) removed = True - if "do_sample" not in cleaned: - cleaned["do_sample"] = False + if cleaned.get("do_sample") is not True: + cleaned["do_sample"] = True cfg = GenerationConfig.from_dict(cleaned, **kwargs) if removed: @@ -287,21 +125,14 @@ def sanitize_generation_config_file(path: str) -> bool: return False changed = False - do_sample_value = data.get("do_sample") - if do_sample_value is None: - data["do_sample"] = False - do_sample_value = False - changed = True + for field in GENERATION_SAMPLING_FIELDS: + if field in data: + data.pop(field, None) + changed = True - if do_sample_value is False: - for field, default in _GREEDY_DEFAULTS.items(): - if field in data: - if default is None: - data.pop(field, None) - changed = True - elif data[field] != default: - data[field] = default - changed = True + if data.get("do_sample") is not True: + data["do_sample"] = True + changed = True if changed: with open(path, "w", encoding="utf-8") as fp: @@ -328,22 +159,11 @@ def build_shell_model( """ init_kwargs = model_init_kwargs.copy() - configured_dtype = init_kwargs.pop("dtype", None) - if dtype is None and configured_dtype is not None: - dtype = configured_dtype - elif dtype is not None and configured_dtype is not None and configured_dtype != dtype: - log.info("Shell model: overriding duplicate dtype argument from kwargs with explicit `dtype` parameter.") - init_kwargs.pop("device_map", None) - init_kwargs.pop("_fast_init", None) + del init_kwargs["device_map"] + del init_kwargs["_fast_init"] # All nn.Parameters and buffers are created # All nn.Parameters and buffers are created on 
'meta' and initializers are skipped. - if dtype is not None: - setattr(config, "dtype", dtype) - store = getattr(config, "__dict__", None) - if isinstance(store, dict) and store.get("torch_dtype") != dtype: - store.pop("torch_dtype", None) - with init_empty_weights(include_buffers=True): shell = loader.from_config( config, diff --git a/gptqmodel/utils/model.py b/gptqmodel/utils/model.py index ec88025b5..9e58da9a2 100644 --- a/gptqmodel/utils/model.py +++ b/gptqmodel/utils/model.py @@ -120,7 +120,7 @@ def _dtype_string_to_torch(dtype_str: Optional[str], fallback: torch.dtype) -> t @dataclass(frozen=True) class OffloadTensorRef: path: str - dtype: torch.dtype + torch_dtype: torch.dtype shape: Tuple[int, ...] format: str # 'dat' or 'safetensors' weight_name: Optional[str] = None @@ -128,19 +128,19 @@ class OffloadTensorRef: @property def num_bytes(self) -> int: - return _torch_dtype_num_bytes(self.dtype) * math.prod(self.shape or (1,)) + return _torch_dtype_num_bytes(self.torch_dtype) * math.prod(self.shape or (1,)) @dataclass class TensorSource: name: str - dtype: torch.dtype + torch_dtype: torch.dtype shape: Tuple[int, ...] source: Union[torch.Tensor, OffloadTensorRef] @property def num_bytes(self) -> int: - return _torch_dtype_num_bytes(self.dtype) * math.prod(self.shape or (1,)) + return _torch_dtype_num_bytes(self.torch_dtype) * math.prod(self.shape or (1,)) def recurse_getattr(obj, attr: str): """ @@ -193,15 +193,7 @@ def move_to(obj: torch.Tensor | nn.Module, device: torch.device, dtype: torch.dt # cpu to non-cpu or non-cpu to non-cpu uses normal .to() api obj = obj.to(device=device, non_blocking=True) else: - try: - obj = obj.to(device=device, dtype=dtype, non_blocking=False) - except NotImplementedError as err: - if isinstance(obj, nn.Module) and "Cannot copy out of meta tensor" in str(err): - obj = obj.to_empty(device=device) - if dtype is not None: - obj = obj.to(dtype=dtype) - else: - raise + obj = obj.to(device=device, dtype=dtype, non_blocking=False) return obj @@ -598,20 +590,20 @@ def convert_gptq_v1_to_v2_format( return model # Limit thread usage to avoid auto-parallizataion regression - with tctl.threadpool_limits(limits=1): - t = time.time() - log.info( - f"Format: Converting `{FORMAT_FIELD_CHECKPOINT}` from `{FORMAT.GPTQ}` to internal `{FORMAT.GPTQ_V2}`.") + # with tctl.threadpool_limits(limits=1): + t = time.time() + log.info( + f"Format: Converting `{FORMAT_FIELD_CHECKPOINT}` from `{FORMAT.GPTQ}` to internal `{FORMAT.GPTQ_V2}`.") - for _, submodule in model.named_modules(): - # v1 checkpoint format used to do `qzeros = qzeros -= 1` before serialization, thus the - # additions here do not overflow. - # v1 checkpoint format with sym=False saved via convert_gptq_v2_to_v1_format() will - # overflow ~<=13% based on testing - if isinstance(submodule, qlinear_kernel): - convert_gptq_v1_to_v2_format_module(module=submodule, bits=cfg.bits, pack_dtype=cfg.pack_dtype) + for _, submodule in model.named_modules(): + # v1 checkpoint format used to do `qzeros = qzeros -= 1` before serialization, thus the + # additions here do not overflow. 
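Aside on the v1/v2 zero-point comment above: a "+1 per packed field" adjustment can be applied directly on a packed word as long as no field overflows its bit range, which is consistent with the comment's note that the sym=False path can overflow while sym=True cannot. A small self-contained illustration of that arithmetic follows; it is not the kernel's actual packing code, and the helper name is hypothetical.

# Pack eight 4-bit zero-points into one 32-bit word and apply a per-nibble +1 by
# adding 0x11111111 -- valid only while every nibble stays at or below 0xF afterwards.
def pack4(vals):
    word = 0
    for i, v in enumerate(vals):
        word |= (v & 0xF) << (4 * i)
    return word

v2_zeros = [7, 3, 1, 15, 8, 1, 2, 4]        # "true" zero-points (v2 layout)
v1_zeros = [z - 1 for z in v2_zeros]        # v1 checkpoints stored zeros - 1
packed_v1 = pack4(v1_zeros)
packed_v2 = packed_v1 + 0x11111111          # +1 in every 4-bit field, no carries here
assert packed_v2 == pack4(v2_zeros)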
+ # v1 checkpoint format with sym=False saved via convert_gptq_v2_to_v1_format() will + # overflow ~<=13% based on testing + if isinstance(submodule, qlinear_kernel): + convert_gptq_v1_to_v2_format_module(module=submodule, bits=cfg.bits, pack_dtype=cfg.pack_dtype) - log.info(f"Format: Conversion complete: {time.time() - t}s") + #log.info(f"Format: Conversion complete: {time.time() - t}s") return model @@ -673,11 +665,11 @@ def convert_gptq_v2_to_v1_format( return model # Limit thread usage to avoid auto-parallizataion regression - with tctl.threadpool_limits(limits=1): - for _, submodule in model.named_modules(): - # sym=False has underflow probability of ~<=13% during testing. No underflow possible for sym=True. - if isinstance(submodule, qlinear_kernel): - convert_gptq_v2_to_v1_format_module(module=submodule, quantize_config=quantize_config) + # with tctl.threadpool_limits(limits=1): + for _, submodule in model.named_modules(): + # sym=False has underflow probability of ~<=13% during testing. No underflow possible for sym=True. + if isinstance(submodule, qlinear_kernel): + convert_gptq_v2_to_v1_format_module(module=submodule, quantize_config=quantize_config) return model @@ -696,9 +688,9 @@ def pack_module( quant_result: Optional[Dict[str, Any]] = None, ): # Limit pack() thread usage to avoid auto-parallizataion regression - with ctx(tctl.threadpool_limits(limits=1), lock): - layer = layers[name] - module = qModules[name] + # with ctx(tctl.threadpool_limits(limits=1), lock): + layer = layers[name] + module = qModules[name] assert get_device(module) == CPU assert get_device(layer) == CPU @@ -846,9 +838,6 @@ def simple_dispatch_model(model, device_map): for n, d in device_map.items(): m = get_module_by_name_suffix(model, n) - if m is None: - log.warning("Device map entry `%s` could not be resolved to a module; skipping hook.", n) - continue if d != "cpu": d = torch.device(d) hook = AlignDevicesHook(d, io_same_device=True, place_submodules=True) @@ -1259,7 +1248,7 @@ def _resolve_offload_entry( offsets = tuple(int(x) for x in offsets) return OffloadTensorRef( path=os.path.abspath(path), - dtype=resolved_dtype, + torch_dtype=resolved_dtype, shape=shape, format="safetensors", weight_name=entry.get("weight_name", leaf), @@ -1286,7 +1275,7 @@ def _resolve_offload_entry( return OffloadTensorRef( path=os.path.abspath(data_path), - dtype=resolved_dtype, + torch_dtype=resolved_dtype, shape=shape, format="dat", weight_name=None, @@ -1316,7 +1305,7 @@ def _collect_state_dict_with_offload(model: nn.Module, offload_root: str) -> Dic ) else: source = param - state_dict[name] = TensorSource(name=name, dtype=param.dtype, shape=tuple(param.shape), source=source) + state_dict[name] = TensorSource(name=name, torch_dtype=param.dtype, shape=tuple(param.shape), source=source) for name, buf in model.named_buffers(): if name in state_dict: @@ -1337,7 +1326,7 @@ def _collect_state_dict_with_offload(model: nn.Module, offload_root: str) -> Dic ) else: source = buf - state_dict[name] = TensorSource(name=name, dtype=buf.dtype, shape=tuple(buf.shape), source=source) + state_dict[name] = TensorSource(name=name, torch_dtype=buf.dtype, shape=tuple(buf.shape), source=source) return state_dict @@ -1354,11 +1343,11 @@ def get_state_dict_for_save(model: nn.Module, offload_root: Optional[str] = None else: state_dict = collections.OrderedDict() for name, param in model.named_parameters(): - state_dict[name] = TensorSource(name=name, dtype=param.dtype, shape=tuple(param.shape), source=param) + state_dict[name] = 
TensorSource(name=name, torch_dtype=param.dtype, shape=tuple(param.shape), source=param) for name, buf in model.named_buffers(): if name in state_dict: continue - state_dict[name] = TensorSource(name=name, dtype=buf.dtype, shape=tuple(buf.shape), source=buf) + state_dict[name] = TensorSource(name=name, torch_dtype=buf.dtype, shape=tuple(buf.shape), source=buf) ptrs = collections.defaultdict(list) for name, entry in state_dict.items(): @@ -1452,7 +1441,7 @@ def _write_shard_file(path: str, entries: List[TensorSource], metadata: Dict[str offset = 0 for entry in entries: header[entry.name] = { - "dtype": _torch_dtype_to_safetensors(entry.dtype), + "dtype": _torch_dtype_to_safetensors(entry.torch_dtype), "shape": list(entry.shape), "data_offsets": [offset, offset + entry.num_bytes], } @@ -1481,11 +1470,11 @@ def _write_shard_file(path: str, entries: List[TensorSource], metadata: Dict[str # print("offload tensor slow tensor read") with safe_open(source.path, framework="pt", device="cpu") as handler: tensor = handler.get_tensor(source.weight_name or entry.name) - tensor = tensor.to(source.dtype) - _write_tensor_bytes(out, tensor, source.dtype) + tensor = tensor.to(source.torch_dtype) + _write_tensor_bytes(out, tensor, source.torch_dtype) else: tensor = source.detach() - _write_tensor_bytes(out, tensor, entry.dtype) + _write_tensor_bytes(out, tensor, entry.torch_dtype) del tensor file_size = out.tell() diff --git a/gptqmodel/utils/offload.py b/gptqmodel/utils/offload.py index 26d69c485..a3c544bd5 100644 --- a/gptqmodel/utils/offload.py +++ b/gptqmodel/utils/offload.py @@ -125,6 +125,7 @@ def _offload_to_disk_impl(module: List[str] | nn.Module, model: nn.Module, disk_ assert module is not None assert model is not None + #with _lock: if isinstance(module, List): for name in module: m = get_submodule(model, name) @@ -153,6 +154,10 @@ def _offload_to_disk_impl(module: List[str] | nn.Module, model: nn.Module, disk_ # print_module_tree(module) +# Serialize accelerate's disk hook mutations across threads. +#_OFFLOAD_SAFE = ThreadSafe(sys.modules[__name__]) +#offload_to_disk = _OFFLOAD_SAFE.offload_to_disk + def _offload_disk(module: nn.Module, name: str, disk_path: str = "."): if is_meta_module(module): # print(f"[skip] '{name}' is on meta; leaving as-is") diff --git a/gptqmodel/utils/sglang.py b/gptqmodel/utils/sglang.py index 3aeeb12b4..c48074844 100644 --- a/gptqmodel/utils/sglang.py +++ b/gptqmodel/utils/sglang.py @@ -6,8 +6,7 @@ import multiprocessing as mp import torch - -from .hf import safe_auto_config_from_pretrained +from transformers import AutoConfig try: @@ -32,7 +31,7 @@ def load_model_by_sglang( **kwargs, ) sgl.set_default_backend(runtime) - hf_config = safe_auto_config_from_pretrained( + hf_config = AutoConfig.from_pretrained( model, trust_remote_code=trust_remote_code ) return runtime, hf_config diff --git a/gptqmodel/utils/threadx.py b/gptqmodel/utils/threadx.py index a83dbb065..a59acde5e 100644 --- a/gptqmodel/utils/threadx.py +++ b/gptqmodel/utils/threadx.py @@ -6,14 +6,12 @@ from __future__ import annotations import contextlib -import os import queue import threading import time from concurrent.futures import Future from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union -import threadpoolctl as tctl import torch from .. import DEBUG_ON @@ -354,19 +352,12 @@ def _run(self): may override via `inference_mode`. 
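For readers skimming the threadx.py worker-loop change that follows: the revised `_run` keeps the sentinel break and the `_stop` loop condition, but no longer drops already-queued tasks during shutdown. A minimal sketch of the underlying sentinel-driven queue pattern, using only the standard library; the names are illustrative, not the threadx.py implementation.

import queue
import threading

SENTINEL = (False, None)  # non-task entry: tells the worker to exit its loop

def worker(q: queue.Queue, results: list):
    while True:
        is_task, fn = q.get()
        try:
            if not is_task:   # sentinel received; leave the loop
                break
            results.append(fn())
        finally:
            q.task_done()

q, results = queue.Queue(), []
t = threading.Thread(target=worker, args=(q, results), daemon=True)
t.start()
q.put((True, lambda: 21 * 2))
q.put(SENTINEL)
t.join()
print(results)  # [42]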
""" _activate_thread_device(self.device) - while True: + while not self._stop.is_set(): is_task, fn, args, kwargs, fut = self._q.get() try: if not is_task: if DEBUG_ON: log.debug(f"{self.name}: received sentinel; exiting") break - if self._stop.is_set(): - # Pool is stopping; skip executing queued work to allow fast shutdown. - if DEBUG_ON: - log.debug(f"{self.name}: dropping task during shutdown; qsize={self._q.qsize()}") - self._on_task_finished(self.key) - fut.cancel() - continue if DEBUG_ON: log.debug(f"{self.name}: task begin; qsize={self._q.qsize()}") stream = kwargs.pop("cuda_stream", None) @@ -434,14 +425,14 @@ def submit(self, fn: Callable[..., Any], /, *args, cuda_stream=None, **kwargs) - ) use_inference = self._inference_mode if override_inference is None else bool(override_inference) with ctx(self.rwlock.reader(), _device_ctx(self.device)): - with tctl.threadpool_limits(limits=1): - inference_ctx = torch.inference_mode() if use_inference else contextlib.nullcontext() - with inference_ctx: - if stream is not None and self.device.type == "cuda": - with torch.cuda.stream(stream): - result = fn(*args, **kwargs) - else: + # with tctl.threadpool_limits(limits=1): + inference_ctx = torch.inference_mode() if use_inference else contextlib.nullcontext() + with inference_ctx: + if stream is not None and self.device.type == "cuda": + with torch.cuda.stream(stream): result = fn(*args, **kwargs) + else: + result = fn(*args, **kwargs) self._on_task_finished(self.key) if not fut.cancelled(): fut.set_result(result) @@ -505,10 +496,6 @@ def __init__( Unspecified devices default to 1 worker each. gc_debounce_seconds: short wait to coalesce multiple triggers. """ - # Default to threaded workers; allow explicit opt-in to synchronous mode for - # environments where background threads are prohibited. - self._sync_mode = os.environ.get("THREADX_FORCE_SYNC", "0") == "1" - if devices is None: discovered: List[torch.device] = [] if include_cuda and torch.cuda.is_available(): @@ -611,29 +598,19 @@ def __init__( # --------------- Worker management --------------- - def _spawn_worker(self, dev: torch.device, name: Optional[str] = None): + def _spawn_worker(self, dev: torch.device, name: Optional[str] = None) -> _DeviceWorker: """ Create and start a worker bound to the provided device. 
""" key = self._key(dev) - if self._sync_mode: - w = _SyncWorker( - key=key, - device=dev, - rwlock=self._locks[key], - on_task_finished=self._on_task_finished, - on_worker_exit=self._on_worker_exit, - inference_mode=self._inference_mode, - ) - else: - w = _DeviceWorker( - device=dev, - rwlock=self._locks[key], - on_task_finished=self._on_task_finished, - on_worker_exit=self._on_worker_exit, - name=name, - inference_mode=self._inference_mode, - ) + w = _DeviceWorker( + device=dev, + rwlock=self._locks[key], + on_task_finished=self._on_task_finished, + on_worker_exit=self._on_worker_exit, + name=name, + inference_mode=self._inference_mode, + ) return w def _on_worker_exit(self, key: str, worker: _DeviceWorker) -> None: diff --git a/pyproject.toml b/pyproject.toml index d91dff05e..cdc2de4cf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -64,6 +64,7 @@ test = [ ] quality = [ "ruff==0.13.0", + # "isort==6.0.1", ] vllm = [ "vllm>=0.10.2", diff --git a/setup.py b/setup.py index e33041681..b7129226d 100644 --- a/setup.py +++ b/setup.py @@ -666,7 +666,7 @@ def run(self): packages=find_packages(), extras_require={ "test": ["pytest>=8.2.2", "parameterized"], - "quality": ["ruff==0.13.0"], + "quality": ["ruff==0.13.0", "isort==6.0.1"], "vllm": ["vllm>=0.8.5", "flashinfer-python>=0.2.1"], "sglang": ["sglang[srt]>=0.4.6", "flashinfer-python>=0.2.1"], "bitblas": ["bitblas==0.0.1-dev13"], diff --git a/tests/models/model_test.py b/tests/models/model_test.py index 213303a20..0626214be 100644 --- a/tests/models/model_test.py +++ b/tests/models/model_test.py @@ -32,37 +32,29 @@ sys.path.insert(0, f"{str(Path(__file__).resolve().parent.parent)}/models") # noqa: E402 import contextlib # noqa: E402 import json # noqa: E402 -import math # noqa: E402 -import random # noqa: E402 import shutil # noqa: E402 import tempfile # noqa: E402 import textwrap # noqa: E402 import unittest # noqa: E402 -from collections import Counter # noqa: E402 from collections.abc import Iterable # noqa: E402 -import torch # noqa: E402 import torch.cuda # noqa: E402 -from datasets import Dataset, concatenate_datasets, load_dataset # noqa: E402 +from datasets import load_dataset # noqa: E402 from ovis.image_to_test_dataset import get_calib_dataset # noqa: E402 from transformers import AutoProcessor, AutoTokenizer # noqa: E402 -from transformers.utils import is_flash_attn_2_available # noqa: E402 from gptqmodel import BACKEND, GPTQModel # noqa: E402 -from gptqmodel.models import Ovis2_5QModel, OvisQModel # noqa: E402 from gptqmodel.nn_modules.qlinear import BaseQuantLinear # noqa: E402 from gptqmodel.quantization import FORMAT, METHOD # noqa: E402 from gptqmodel.quantization.config import QuantizeConfig # noqa: E402 from gptqmodel.utils.eval import EVAL # noqa: E402 from gptqmodel.utils.model import MODALITY # noqa: E402 -from gptqmodel.utils.perplexity import Perplexity # noqa: E402 from gptqmodel.utils.torch import torch_empty_cache # noqa: E402 RAND_SEED = 898 log = LogBar.shared() -ATTN_IMPLEMENTATION_KEY = "attn_implementation" class ModelTest(unittest.TestCase): DEBUG = True # enable extra debug output @@ -84,20 +76,6 @@ class ModelTest(unittest.TestCase): DATASET_SIZE = 256 DATASET_SORT = "asc" DELETE_QUANTIZED_MODEL = True - MAX_QUANT_LAYERS = None - - # post-quant validation controls - POST_QUANT_VALIDATION_BACKENDS = None # default preserves legacy double-backend check - - # calibration noise controls - CALIB_NOISE_PERCENT = 0.0 # share of calibration samples to synthesize - CALIB_NOISE_MODE = "none" # "unseen" # supported: 
none|random|unseen - CALIB_NOISE_RANDOM_SEED = 1337 - CALIB_NOISE_MIN_SEQ_LEN = 512 - CALIB_NOISE_MAX_SEQ_LEN = 1024 - CALIB_NOISE_GUARD_MAX_FREQ_RATIO = 1.3 - CALIB_NOISE_GUARD_MIN_TTR_FACTOR = 0.95 - CALIB_NOISE_GUARD_MAX_FRACTION = 0.1 KERNEL_QUANT = {} # kernel sets KERNEL_INFERENCE = {} # kernel sets @@ -116,19 +94,7 @@ class ModelTest(unittest.TestCase): SAVE_PATH = None # default is temp folder - MOCK_QUANTIZATION = False - ATTN_IMPLEMENTATION = None # allow forcing a specific attention backend when needed; use "flash_attention_2" - - COMPUTE_PPL = False - PPL_DATASET_PATH = "wikitext" - PPL_DATASET_NAME = "wikitext-2-raw-v1" - PPL_DATASET_SPLIT = "test" - PPL_DATASET_COLUMN = "text" - PPL_CTX = 512 - PPL_BATCH = 512 - PPL_MAX_SAMPLES = 32 - PPL_FALLBACK_CTX = 192 - PPL_FALLBACK_MAX_CHUNKS_PER_SAMPLE = 4 + USE_FLASH_ATTN = True INFERENCE_PROMPT = "The capital city of France is named" INFERENCE_RESULT_KEYWORDS = ["paris"] @@ -199,33 +165,6 @@ def generate_with_limit(self, model, tokenizer, prompt, max_new_tokens=512): ) return tokenizer.decode(generated[0], skip_special_tokens=True) - def _response_matches_keywords(self, response: str, keywords): - if not response: - return False - - normalized = response.lower() - - for keyword in keywords: - if not keyword: - continue - - needle = keyword.lower() - - if needle.isalpha(): - def _strip_other_alpha(text): - return "".join(ch for ch in text if ch.isalpha()) - - if needle in normalized: - return True - - if _strip_other_alpha(needle) in _strip_other_alpha(normalized): - return True - else: - if needle in normalized: - return True - - return False - def run_generic_inference_checks(self, model, tokenizer, backend): model.eval() log.info(f"Post-quant inference checks for backend `{backend.name}`") @@ -235,7 +174,8 @@ def run_generic_inference_checks(self, model, tokenizer, backend): keywords = item["keywords"] try: response = self.generate_with_limit(model, tokenizer, prompt) - matched = self._response_matches_keywords(response, keywords) + normalized = response.lower() + matched = any(keyword.lower() in normalized for keyword in keywords) results.append( { "prompt": prompt, @@ -268,7 +208,6 @@ def run_generic_inference_checks(self, model, tokenizer, backend): def run_arc_challenge_eval(self, model, backend, trust_remote_code=False): previous_backend = self.LOAD_BACKEND self.LOAD_BACKEND = backend - self._ensure_model_attributes(model) try: task_results = self.lm_eval( model=model, @@ -277,42 +216,14 @@ def run_arc_challenge_eval(self, model, backend, trust_remote_code=False): delete_quantized_model=False, ) log.info(f"[{backend.name}] ARC summary: {task_results}") - except AttributeError as exc: - log.warning( - "Skipping ARC eval for backend %s due to attribute error: %s", - backend.name, - exc, - ) - task_results = {} finally: self.LOAD_BACKEND = previous_backend return task_results - @staticmethod - def _ensure_model_attributes(model): - inner_model = getattr(model, "model", None) - if not hasattr(model, "device") and inner_model is not None: - try: - model.device = next(inner_model.parameters()).device - except StopIteration: - model.device = torch.device("cpu") - if not hasattr(model, "config") and inner_model is not None: - setattr(model, "config", getattr(inner_model, "config", None)) - - def get_post_quant_validation_backends(self): - configured = getattr(self, "POST_QUANT_VALIDATION_BACKENDS", None) - if configured: - return tuple(configured) - - if self.FORMAT is FORMAT.GPTQ: - return (BACKEND.MARLIN, BACKEND.TORCH) - return 
(BACKEND.MARLIN, BACKEND.GEMM) - def perform_post_quant_validation(self, model_path, trust_remote_code=False): inference_records = {} arc_records = {} - compare_backends = self.get_post_quant_validation_backends() - executed_backends = [] + compare_backends = (BACKEND.MARLIN, BACKEND.TORCH) if self.FORMAT is FORMAT.GPTQ else (BACKEND.MARLIN, BACKEND.GEMM) for backend in compare_backends: log.info(f"Loading post-quant model with backend `{backend.name}`") model = self.loadQuantModel( @@ -327,9 +238,8 @@ def perform_post_quant_validation(self, model_path, trust_remote_code=False): finally: del model torch_empty_cache() - executed_backends.append(backend) - self.render_inference_summary(inference_records, executed_backends) - self.render_arc_summary(arc_records, executed_backends) + self.render_inference_summary(inference_records) + self.render_arc_summary(arc_records) @staticmethod def _human_size(num_bytes: int) -> str: @@ -411,13 +321,10 @@ def _colorize(text, matched): reset = "\033[0m" return f"{color}{text}{reset}" - def render_inference_summary(self, inference_records, backends_order=None): + def render_inference_summary(self, inference_records): if not inference_records: return - if backends_order: - ordered_backends = [backend for backend in backends_order if backend in inference_records] - else: - ordered_backends = list(inference_records.keys()) + ordered_backends = [backend for backend in (BACKEND.MARLIN, BACKEND.TORCH) if backend in inference_records] if not ordered_backends: return @@ -475,13 +382,10 @@ def _format_inference_entry(self, entry): cell = f"{status} | {snippet}" if snippet else status return self._colorize(cell, matched) - def render_arc_summary(self, arc_records, backends_order=None): + def render_arc_summary(self, arc_records): if not arc_records: return - if backends_order: - ordered_backends = [backend for backend in backends_order if backend in arc_records] - else: - ordered_backends = list(arc_records.keys()) + ordered_backends = [backend for backend in (BACKEND.MARLIN, BACKEND.TORCH) if backend in arc_records] if not ordered_backends: return @@ -502,10 +406,10 @@ def render_arc_summary(self, arc_records, backends_order=None): if value is None: row.append(self._colorize("N/A", False)) continue - if backend == BACKEND.TORCH or reference_value is None: + if backend == BACKEND.TORCH: row.append(self._colorize(f"{value:.4f}", True)) else: - matched = abs(value - reference_value) <= tolerance + matched = reference_value is not None and abs(value - reference_value) <= tolerance row.append(self._colorize(f"{value:.4f}", matched)) table_rows.append(row) @@ -525,9 +429,8 @@ def load_dataset(cls, tokenizer=None, rows: int = 0): dataset = cls._load_calibration_parquet() if rows > 0: - dataset = dataset.select(range(min(rows, len(dataset)))) - - return cls._apply_calibration_noise(dataset, tokenizer) + return dataset.select(range(min(rows, len(dataset)))) + return dataset @staticmethod def _load_calibration_parquet(): @@ -590,512 +493,6 @@ def select(self, indices): return self.__class__(selected) - @classmethod - def _apply_calibration_noise(cls, dataset, tokenizer): - mode = (getattr(cls, "CALIB_NOISE_MODE", "none") or "none").lower() - share = float(getattr(cls, "CALIB_NOISE_PERCENT", 0.0) or 0.0) - - cls._noise_summary = { - "mode": mode, - "share": share, - "requested": 0, - "generated": 0, - "applied": False, - "reason": "disabled", - } - - if dataset is None or tokenizer is None or share <= 0.0 or mode == "none": - return dataset - - records = 
cls._materialize_records(dataset) - if not records: - cls._noise_summary["reason"] = "empty_dataset" - return dataset - - requested = max(1, int(len(records) * share)) - cls._noise_summary.update({ - "requested": requested, - "reason": "generated", - }) - - stats = cls._collect_token_stats(records, tokenizer) - if stats["total_tokens"] == 0: - cls._noise_summary["reason"] = "no_tokens" - return dataset - - noise_records, noise_token_sequences = cls._build_noise_records( - tokenizer=tokenizer, - stats=stats, - sample_count=requested, - mode=mode, - base_records=records, - ) - - if not noise_records: - cls._noise_summary["reason"] = "no_samples" - return dataset - - if not cls._passes_noise_guard(stats, noise_token_sequences): - cls._noise_summary["reason"] = "guard_block" - return dataset - - merged = cls._merge_noise(dataset, noise_records, base_records=records) - cls._noise_summary.update({ - "generated": len(noise_records), - "applied": True, - "reason": "ok", - }) - log.info( - "Injected %s synthetic calibration samples (mode=%s, avg_len=%s)", - len(noise_records), - mode, - stats["avg_length"], - ) - return merged - - @staticmethod - def _materialize_records(dataset): - if dataset is None: - return [] - if hasattr(dataset, "to_list"): - try: - return dataset.to_list() - except TypeError: - pass - try: - return [dataset[idx] for idx in range(len(dataset))] - except Exception: # pragma: no cover - defensive fallback - return list(dataset) - - @staticmethod - def _extract_text(record): - if isinstance(record, dict): - if record.get("text"): - return record["text"] - if record.get("messages"): - parts = [msg.get("content", "") for msg in record["messages"]] - return "\n".join(part for part in parts if part) - return "" - - @classmethod - def _collect_token_stats(cls, records, tokenizer): - counts = Counter() - total_tokens = 0 - for record in records: - text = cls._extract_text(record) - if not text: - continue - encoded = tokenizer(text, add_special_tokens=False).get("input_ids", []) - counts.update(encoded) - total_tokens += len(encoded) - - unique_tokens = len(counts) - avg_length = max(1, round(total_tokens / max(len(records), 1))) - type_token_ratio = (unique_tokens / total_tokens) if total_tokens else 0.0 - max_freq = max(counts.values()) if counts else 0 - - return { - "counts": counts, - "total_tokens": total_tokens, - "unique_tokens": unique_tokens, - "avg_length": avg_length, - "type_token_ratio": type_token_ratio, - "max_freq": max_freq, - } - - @staticmethod - def _get_special_token_id_set(tokenizer): - special_ids = set() - - def _add(value): - if value is None: - return - if isinstance(value, int): - special_ids.add(int(value)) - return - if isinstance(value, str): - converted = tokenizer.convert_tokens_to_ids(value) - if converted is None: - return - if isinstance(converted, list): - for item in converted: - if item is not None: - special_ids.add(int(item)) - else: - special_ids.add(int(converted)) - return - if isinstance(value, (list, tuple, set)): - for item in value: - _add(item) - return - if isinstance(value, dict): - for item in value.values(): - _add(item) - - attr_names = ( - "all_special_ids", - "additional_special_tokens_ids", - "special_tokens_map", - "special_tokens_map_extended", - "bos_token_id", - "eos_token_id", - "pad_token_id", - "sep_token_id", - "cls_token_id", - "mask_token_id", - "unk_token_id", - ) - - for name in attr_names: - _add(getattr(tokenizer, name, None)) - - return {int(token_id) for token_id in special_ids if token_id is not None and 
token_id >= 0} - - @classmethod - def _filter_special_token_ids(cls, token_ids, tokenizer): - if not token_ids: - return [] - special_ids = cls._get_special_token_id_set(tokenizer) - return [tok for tok in token_ids if tok not in special_ids and tok is not None] - - @classmethod - def _build_noise_records(cls, tokenizer, stats, sample_count, mode, base_records): - if sample_count <= 0: - return [], [] - - if mode == "structured": - return cls._build_structured_noise_records( - tokenizer=tokenizer, - base_records=base_records, - sample_count=sample_count, - stats=stats, - ) - - rng = random.Random(cls.CALIB_NOISE_RANDOM_SEED) - - try: - vocab_values = list(tokenizer.get_vocab().values()) - except Exception: # pragma: no cover - tokenizer fallback - vocab_values = list(range(getattr(tokenizer, "vocab_size", 0))) - - vocab_ids = cls._filter_special_token_ids(vocab_values, tokenizer) - if not vocab_ids: - return [], [] - - existing_ids = set(stats["counts"].keys()) - if mode == "unseen": - vocab_ids = [tok for tok in vocab_ids if tok not in existing_ids] - if not vocab_ids: - log.warning("No unseen tokens available for noise generation; skipping") - return [], [] - - seq_min = max(1, int(getattr(cls, "CALIB_NOISE_MIN_SEQ_LEN", 32))) - seq_max = max(seq_min, int(getattr(cls, "CALIB_NOISE_MAX_SEQ_LEN", 256))) - avg_target = min(seq_max, max(seq_min, stats["avg_length"])) - std_dev = max(1, int(avg_target * 0.2)) - - records = [] - token_sequences = [] - max_iterations = max(sample_count * 3, 32) - - for _ in range(max_iterations): - if len(records) >= sample_count: - break - - length = int(rng.gauss(avg_target, std_dev)) - length = max(seq_min, min(seq_max, length)) - if length <= 1: - continue - - if mode == "unseen": - length = min(length, len(vocab_ids)) - if length <= 1: - continue - token_ids = rng.sample(vocab_ids, k=length) - else: - token_ids = rng.choices(vocab_ids, k=length) - - split = max(1, min(length - 1, length // 2)) - prompt_tokens = token_ids[:split] - completion_tokens = token_ids[split:] - - prompt_text = tokenizer.decode(prompt_tokens, skip_special_tokens=True).strip() - completion_text = tokenizer.decode(completion_tokens, skip_special_tokens=True).strip() - - if not prompt_text: - prompt_text = " ".join(tokenizer.convert_ids_to_tokens(prompt_tokens)) - if not completion_text: - completion_text = " ".join(tokenizer.convert_ids_to_tokens(completion_tokens)) - - combined_text = ( - "### Instruction:\n" - + prompt_text - + "\n\n### Response:\n" - + completion_text - ).strip() - - encoded = tokenizer(combined_text, add_special_tokens=False) - seq_ids = encoded.get("input_ids", []) - if not seq_ids: - continue - - records.append( - { - "text": combined_text, - "messages": [ - {"role": "user", "content": prompt_text}, - {"role": "assistant", "content": completion_text}, - ], - } - ) - token_sequences.append(seq_ids) - - return records, token_sequences - - @classmethod - def _build_structured_noise_records(cls, tokenizer, base_records, sample_count, stats): - if not base_records: - log.warning("Structured noise requested but no base records; skipping") - return [], [] - - rng = random.Random(cls.CALIB_NOISE_RANDOM_SEED + 17) - replacement_pool = cls._filter_special_token_ids(list(stats["counts"].keys()), tokenizer) - if not replacement_pool: - try: - replacement_pool = list(tokenizer.get_vocab().values()) - except Exception: # pragma: no cover - tokenizer fallback - replacement_pool = list(range(getattr(tokenizer, "vocab_size", 0))) - replacement_pool = 
cls._filter_special_token_ids(replacement_pool, tokenizer) - - records = [] - token_sequences = [] - max_attempts = max(sample_count * 6, 48) - - for _ in range(max_attempts): - if len(records) >= sample_count: - break - - base_record = rng.choice(base_records) - user_text, assistant_text = cls._extract_message_pair(base_record) - if not user_text and not assistant_text: - base_text = cls._extract_text(base_record) - user_text, assistant_text = cls._split_instruction_response(base_text) - - if not user_text and not assistant_text: - continue - - pert_user = cls._perturb_text(user_text, tokenizer, rng, replacement_pool, stats) - pert_assistant = cls._perturb_text(assistant_text, tokenizer, rng, replacement_pool, stats) - - pert_user = pert_user or user_text - pert_assistant = pert_assistant or assistant_text or pert_user - - if not pert_user or not pert_assistant: - continue - - combined_text = ( - "### Instruction:\n" - + pert_user.strip() - + "\n\n### Response:\n" - + pert_assistant.strip() - ).strip() - - encoded = tokenizer(combined_text, add_special_tokens=False) - seq_ids = encoded.get("input_ids", []) - if len(seq_ids) < max(4, cls.CALIB_NOISE_MIN_SEQ_LEN // 2): - continue - - records.append( - { - "text": combined_text, - "messages": [ - {"role": "user", "content": pert_user.strip()}, - {"role": "assistant", "content": pert_assistant.strip()}, - ], - } - ) - token_sequences.append(seq_ids) - - return records, token_sequences - - @staticmethod - def _extract_message_pair(record): - if not isinstance(record, dict): - return "", "" - - messages = record.get("messages") - if not isinstance(messages, list): - return "", "" - - user_text = "" - assistant_text = "" - for message in messages: - if not isinstance(message, dict): - continue - role = message.get("role") - content = (message.get("content") or "").strip() - if not content: - continue - if role == "user" and not user_text: - user_text = content - elif role == "assistant" and not assistant_text: - assistant_text = content - if user_text and assistant_text: - break - - return user_text, assistant_text - - @staticmethod - def _split_instruction_response(text): - if not text: - return "", "" - - instruction = "" - response = "" - - if "### Response" in text: - parts = text.split("### Response", 1) - instruction = parts[0].replace("### Instruction:", "").strip() - response = parts[1].replace(":", "", 1).strip() - elif "Response:" in text: - parts = text.split("Response:", 1) - instruction = parts[0].replace("Instruction:", "").strip() - response = parts[1].strip() - else: - split_parts = text.split("\n\n", 1) - if len(split_parts) == 2: - instruction, response = split_parts[0].strip(), split_parts[1].strip() - else: - mid = len(text) // 2 - instruction, response = text[:mid].strip(), text[mid:].strip() - - return instruction, response - - @classmethod - def _perturb_text(cls, text, tokenizer, rng, replacement_pool, stats): - if not text: - return "" - - encoded = tokenizer(text, add_special_tokens=False) - token_ids = list(encoded.get("input_ids", [])) - if len(token_ids) <= 2: - return text.strip() - - operations = ["shuffle", "drop", "replace"] - operation = rng.choice(operations) - tokens = list(token_ids) - - if operation == "shuffle" and len(tokens) > 8: - chunk_size = max(1, len(tokens) // rng.randint(3, 6)) - segments = [tokens[i:i + chunk_size] for i in range(0, len(tokens), chunk_size)] - rng.shuffle(segments) - tokens = [tok for segment in segments for tok in segment] - elif operation == "drop" and len(tokens) > 6: - 
max_drop = max(1, min(len(tokens) // 4, 24)) - span = rng.randint(1, max_drop) - if len(tokens) > span: - start = rng.randint(0, len(tokens) - span) - del tokens[start:start + span] - else: # replace - span = max(1, min(len(tokens) // 5, 16)) - if len(tokens) > span: - start = rng.randint(0, len(tokens) - span) - replacement = cls._sample_replacement_tokens( - rng, - replacement_pool, - span, - stats, - tokenizer, - ) - if replacement: - tokens[start:start + span] = replacement - - if not tokens: - tokens = token_ids - - new_text = tokenizer.decode(tokens, skip_special_tokens=True).strip() - return new_text or text.strip() - - @classmethod - def _sample_replacement_tokens(cls, rng, candidates, length, stats, tokenizer): - if not candidates: - candidates = list(stats["counts"].keys()) - candidates = cls._filter_special_token_ids(candidates, tokenizer) - if not candidates: - candidates = list(range(1, max(512, length * 4))) - candidates = cls._filter_special_token_ids(candidates, tokenizer) - if not candidates: - return [] - return [rng.choice(candidates) for _ in range(length)] - - @classmethod - def _passes_noise_guard(cls, stats, noise_sequences): - if not noise_sequences: - return False - - base_counts = stats["counts"].copy() - base_total = stats["total_tokens"] - base_max = stats["max_freq"] - base_ttr = stats["type_token_ratio"] - - noise_counts = Counter() - noise_token_total = 0 - for seq in noise_sequences: - noise_counts.update(seq) - noise_token_total += len(seq) - - if noise_token_total == 0: - return False - - combined_counts = base_counts + noise_counts - combined_max = max(combined_counts.values()) if combined_counts else 0 - if base_max == 0: - base_max = combined_max or 1 - - max_freq_ratio = combined_max / max(base_max, 1) - if max_freq_ratio > getattr(cls, "CALIB_NOISE_GUARD_MAX_FREQ_RATIO", 1.3): - log.info( - "Noise guard triggered by max frequency ratio %.4f", max_freq_ratio - ) - return False - - combined_total = base_total + noise_token_total - combined_unique = len(combined_counts) - combined_ttr = (combined_unique / combined_total) if combined_total else 0.0 - min_ttr = base_ttr * getattr(cls, "CALIB_NOISE_GUARD_MIN_TTR_FACTOR", 0.95) - if base_ttr and combined_ttr < min_ttr: - log.info( - "Noise guard triggered by type-token ratio %.4f < %.4f", - combined_ttr, - min_ttr, - ) - return False - - noise_fraction = noise_token_total / combined_total if combined_total else 1.0 - if noise_fraction > getattr(cls, "CALIB_NOISE_GUARD_MAX_FRACTION", 0.1): - log.info( - "Noise guard triggered by noise fraction %.4f", noise_fraction - ) - return False - - return True - - @classmethod - def _merge_noise(cls, dataset, noise_records, base_records=None): - if dataset is None: - return cls._LocalCalibrationDataset(noise_records) - - try: - noise_dataset = Dataset.from_list(noise_records) - return concatenate_datasets([dataset, noise_dataset]) - except Exception as exc: # pragma: no cover - fall back to python dataset - log.warning("Falling back to in-memory dataset for noise merge: %s", exc) - if base_records is None: - base_records = cls._materialize_records(dataset) - combined = list(base_records) + list(noise_records) - return cls._LocalCalibrationDataset(combined) - - def check_kernel(self, model, expected_kernels): modules = {module.__class__ for _, module in model.named_modules() if isinstance(module, BaseQuantLinear)} print(f"modules in model: {modules}") @@ -1114,19 +511,16 @@ def quantModel(self, model_id_or_path, trust_remote_code=False, dtype="auto", ne sym=self.SYM, 
v2=self.V2, adapter=self.EORA, - mock_quantization=self.MOCK_QUANTIZATION, - offload_to_disk=getattr(self, "OFFLOAD_TO_DISK", True), ) log.info(f"Quant config: {quantize_config}") log.info(f"Quant batch_size: {batch_size}") args = kwargs if kwargs else {} - if ( - self.ATTN_IMPLEMENTATION is not None - and ATTN_IMPLEMENTATION_KEY not in args - ): - args[ATTN_IMPLEMENTATION_KEY] = self.ATTN_IMPLEMENTATION + + if self.USE_FLASH_ATTN: + args["attn_implementation"] = "flash_attention_2" + log.info(f"args: {args}") model = GPTQModel.load( @@ -1144,51 +538,22 @@ def quantModel(self, model_id_or_path, trust_remote_code=False, dtype="auto", ne is_image_to_text_model = MODALITY.IMAGE_TO_TEXT in model.modality calibration_dataset = get_calib_dataset(model) if is_image_to_text_model else self.load_dataset(tokenizer, self.DATASET_SIZE) - noise_summary = getattr(self, "_noise_summary", None) - if noise_summary and noise_summary.get("mode") != "none": - log.info(f"Calibration noise summary: {noise_summary}") - # mpt model need - model_cfg = getattr(model, "config", None) - if model_cfg is None: - inner = getattr(model, "model", None) - if inner is not None: - model_cfg = getattr(inner, "config", None) - if model_cfg is None: - model_cfg = getattr(model, "model_config", None) - if model_cfg is not None: - if not getattr(model_cfg, "pad_token_id", None): - model_cfg.pad_token_id = tokenizer.pad_token_id or 0 - if not getattr(model_cfg, "eos_token_id", None): - model_cfg.eos_token_id = tokenizer.eos_token_id or 0 + if not model.config.pad_token_id: + model.config.pad_token_id = tokenizer.pad_token_id or 0 + if not model.config.eos_token_id: + model.config.eos_token_id = tokenizer.eos_token_id or 0 is_quantized = model.quantized # ovis cannot load processor - is_ovis_model = isinstance(model, (OvisQModel, Ovis2_5QModel)) + is_ovis_model = model.__class__.__name__ == "OvisGPTQ" need_create_processor = is_image_to_text_model and not is_ovis_model if not is_quantized: - prev_max_layers = os.environ.get("GPTQMODEL_MAX_QUANT_LAYERS") - max_layers_limit = getattr(self, "MAX_QUANT_LAYERS", None) - if max_layers_limit is not None: - os.environ["GPTQMODEL_MAX_QUANT_LAYERS"] = str(max_layers_limit) - try: - model.quantize(calibration_dataset, calibration_sort=self.DATASET_SORT, backend=self.QUANT_BACKEND, batch_size=batch_size) - finally: - if max_layers_limit is not None: - if prev_max_layers is None: - os.environ.pop("GPTQMODEL_MAX_QUANT_LAYERS", None) - else: - os.environ["GPTQMODEL_MAX_QUANT_LAYERS"] = prev_max_layers + model.quantize(calibration_dataset, calibration_sort=self.DATASET_SORT, backend=self.QUANT_BACKEND, batch_size=batch_size) self.check_kernel(model, self.KERNEL_QUANT) - if self.MOCK_QUANTIZATION: - if need_create_processor: - processor = AutoProcessor.from_pretrained(model_id_or_path) - return model, tokenizer, processor - return model, tokenizer - # TODO: make into shared method with (contextlib.nullcontext(self.SAVE_PATH) if self.SAVE_PATH else contextlib.nullcontext(tempfile.mkdtemp()) if need_eval else tempfile.TemporaryDirectory()) as path: os.makedirs(path, exist_ok=True) @@ -1223,21 +588,12 @@ def quantModel(self, model_id_or_path, trust_remote_code=False, dtype="auto", ne def loadQuantModel(self, model_id_or_path, trust_remote_code=False, tokenizer_path=None, backend=None, **args): load_kwargs = dict(args) - if ( - self.ATTN_IMPLEMENTATION is not None - and ATTN_IMPLEMENTATION_KEY not in load_kwargs - ): - load_kwargs[ATTN_IMPLEMENTATION_KEY] = self.ATTN_IMPLEMENTATION - elif 
ATTN_IMPLEMENTATION_KEY not in load_kwargs and is_flash_attn_2_available(): - load_kwargs[ATTN_IMPLEMENTATION_KEY] = "flash_attention_2" + + if self.USE_FLASH_ATTN: + load_kwargs["attn_implementation"] = "flash_attention_2" active_backend = backend if backend is not None else self.LOAD_BACKEND - import os - print("[DEBUG] loadQuantModel", model_id_or_path, trust_remote_code, active_backend) - if trust_remote_code: - prev_env = os.environ.get("TRANSFORMERS_TRUST_REMOTE_CODE") - os.environ["TRANSFORMERS_TRUST_REMOTE_CODE"] = "1" model = GPTQModel.load( model_id_or_path, trust_remote_code=trust_remote_code, @@ -1247,11 +603,6 @@ def loadQuantModel(self, model_id_or_path, trust_remote_code=False, tokenizer_pa adapter=self.EORA, **load_kwargs ) - if trust_remote_code: - if prev_env is None: - os.environ.pop("TRANSFORMERS_TRUST_REMOTE_CODE", None) - else: - os.environ["TRANSFORMERS_TRUST_REMOTE_CODE"] = prev_env return model @@ -1358,19 +709,11 @@ def quant_lm_eval(self): self.check_kernel(self.model, self.KERNEL_INFERENCE) - if self.MOCK_QUANTIZATION: - task_results = { - "acc,none": self.NATIVE_ARC_CHALLENGE_ACC, - "acc_norm,none": self.NATIVE_ARC_CHALLENGE_ACC_NORM, - } - else: - task_results = self.lm_eval(model=self.SAVE_PATH if self.SAVE_PATH else self.model, - apply_chat_template=self.APPLY_CHAT_TEMPLATE, - trust_remote_code=self.TRUST_REMOTE_CODE, - delete_quantized_model=self.DELETE_QUANTIZED_MODEL) + task_results = self.lm_eval(model=self.SAVE_PATH if self.SAVE_PATH else self.model, + apply_chat_template=self.APPLY_CHAT_TEMPLATE, + trust_remote_code=self.TRUST_REMOTE_CODE, + delete_quantized_model=self.DELETE_QUANTIZED_MODEL) self.check_results(task_results) - self.last_task_results = task_results - self._maybe_compute_perplexity(self.model) def check_results(self, task_results): for filter, value in task_results.items(): @@ -1379,148 +722,6 @@ def check_results(self, task_results): positive_pct = 100 * (1 + self.QUANT_ARC_MAX_POSITIVE_DELTA_CEIL_PERCENT) self.assertTrue(negative_pct <= diff_pct <= positive_pct, f"{filter}: `{value}` vs expected `{expected}`, diff {diff_pct:.2f}% is out of the expected range [{negative_pct}-{positive_pct}%]") - def _maybe_compute_perplexity(self, model): - self.perplexity_scores = [] - self.perplexity_avg = None - self.perplexity_error = None - - if not self.COMPUTE_PPL or model is None: - return None - - tokenizer = getattr(model, "tokenizer", None) - if tokenizer is None: - log.warning("Model has no tokenizer; skipping perplexity computation") - return None - - try: - ppl_runner = Perplexity( - model=model, - tokenizer=tokenizer, - dataset_path=self.PPL_DATASET_PATH, - dataset_name=self.PPL_DATASET_NAME, - split=self.PPL_DATASET_SPLIT, - text_column=self.PPL_DATASET_COLUMN, - ) - scores = ppl_runner.calculate(n_ctx=self.PPL_CTX, n_batch=self.PPL_BATCH) - if scores: - self.perplexity_scores = scores - self.perplexity_avg = sum(scores) / len(scores) - log.info( - "Perplexity average: %.4f computed over %s windows", - self.perplexity_avg, - len(scores), - ) - else: - log.warning("Perplexity calculation returned no scores") - return self.perplexity_avg - except Exception as exc: # pragma: no cover - diagnostics only - self.perplexity_error = str(exc) - log.error(f"Perplexity computation failed: {exc}") - return self._compute_perplexity_fallback(model, tokenizer) - - def _compute_perplexity_fallback(self, model, tokenizer): - max_samples = getattr(self, "PPL_MAX_SAMPLES", 32) - dataset = None - try: - dataset = load_dataset( - self.PPL_DATASET_PATH, - 
self.PPL_DATASET_NAME, - split=self.PPL_DATASET_SPLIT, - ) - except Exception as exc: # pragma: no cover - dataset missing - log.error(f"Fallback perplexity dataset load failed: {exc}") - return None - - sample_count = min(max_samples, len(dataset)) - if sample_count == 0: - log.warning("Fallback perplexity has no samples to evaluate") - return None - - max_context = getattr(model.config, "max_position_embeddings", self.PPL_CTX) - fallback_ctx = min( - getattr(self, "PPL_FALLBACK_CTX", self.PPL_CTX), - self.PPL_CTX, - max_context, - ) - fallback_ctx = max(32, fallback_ctx) - max_chunks_per_sample = max(1, getattr(self, "PPL_FALLBACK_MAX_CHUNKS_PER_SAMPLE", 1)) - rng = random.Random(self.CALIB_NOISE_RANDOM_SEED + 202) - - total_tokens = 0 - total_neg_log_likelihood = 0.0 - - for entry in dataset.select(range(sample_count)): - text = entry.get(self.PPL_DATASET_COLUMN) - if not text: - continue - token_ids = tokenizer(text, add_special_tokens=False).get("input_ids", []) - if len(token_ids) <= 1: - continue - - offsets = self._select_fallback_offsets( - length=len(token_ids), - chunk_size=fallback_ctx + 1, - max_chunks=max_chunks_per_sample, - rng=rng, - ) - - for offset in offsets: - chunk = token_ids[offset: offset + fallback_ctx + 1] - if len(chunk) <= 1: - continue - - input_tensor = torch.tensor( - chunk[:-1], dtype=torch.long, device=model.device - ).unsqueeze(0) - labels = torch.tensor( - chunk[1:], dtype=torch.long, device=model.device - ).unsqueeze(0) - attention_mask = torch.ones_like(input_tensor, dtype=torch.long) - - with torch.inference_mode(): - outputs = model( - input_ids=input_tensor, - attention_mask=attention_mask, - labels=labels, - ) - - token_count = labels.numel() - total_tokens += token_count - total_neg_log_likelihood += outputs.loss.item() * token_count - - if total_tokens == 0: - log.warning("Fallback perplexity produced zero tokens") - return None - - average_loss = total_neg_log_likelihood / total_tokens - ppl = math.exp(average_loss) - self.perplexity_scores = [ppl] - self.perplexity_avg = ppl - log.info( - "Fallback perplexity average: %.4f computed over %s tokens", - ppl, - total_tokens, - ) - return ppl - - @staticmethod - def _select_fallback_offsets(length, chunk_size, max_chunks, rng): - if length <= chunk_size or max_chunks <= 1: - return [0] - - limit = max(0, length - chunk_size) - offsets = set() - attempts = 0 - max_attempts = max_chunks * 6 - - while len(offsets) < max_chunks and attempts < max_attempts: - offsets.add(rng.randint(0, limit)) - attempts += 1 - - if not offsets: - return [0] - return sorted(offsets) - def check_lm_head_loss(self, quant_log: List[Dict[str, any]]): final_log = quant_log[-1] if final_log["module"] == "lm_head": diff --git a/tests/models/ovis/image_to_test_dataset.py b/tests/models/ovis/image_to_test_dataset.py index bb955570a..c2e7172cf 100644 --- a/tests/models/ovis/image_to_test_dataset.py +++ b/tests/models/ovis/image_to_test_dataset.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 # Contact: qubitium@modelcloud.ai, x.com/qubitium -from gptqmodel.models import Ovis2_5QModel, OvisQModel +from gptqmodel.models import OvisQModel from gptqmodel.models.definitions.base_qwen2_5_omni import BaseQwen2_5_OmniGPTQ from gptqmodel.models.definitions.base_qwen2_vl import BaseQwen2VLGPTQ @@ -36,31 +36,6 @@ def format_qwen2_vl_dataset(image, assistant): {"role": "assistant", "content": assistant}, ] - -def format_ovis2_5_dataset(image, assistant): - return [ - { - "role": "system", - "content": [ - { - "type": "text", - "text": 
"You are a helpful multimodal assistant. Provide factual image descriptions and mention visible text." - } - ], - }, - { - "role": "user", - "content": [ - {"type": "image", "image": image}, - { - "type": "text", - "text": "Write a detailed description of this image, including any visible text and the overall style." - }, - ], - }, - {"role": "assistant", "content": assistant}, - ] - def format_qwen2_5_omni_dataset(image, assistant): return [ { @@ -93,9 +68,6 @@ def prepare_dataset(format_func, n_sample: int = 20) -> list[list[dict]]: def get_calib_dataset(model): - if isinstance(model, Ovis2_5QModel): - return prepare_dataset(format_ovis2_5_dataset, n_sample=20) - if isinstance(model, OvisQModel): return prepare_dataset(format_ovis_dataset, n_sample=20) diff --git a/tests/models/test_bloom.py b/tests/models/test_bloom.py index bc58c2874..1d67ea4f5 100644 --- a/tests/models/test_bloom.py +++ b/tests/models/test_bloom.py @@ -12,6 +12,7 @@ class TestBloom(ModelTest): NATIVE_ARC_CHALLENGE_ACC = 0.2201 NATIVE_ARC_CHALLENGE_ACC_NORM = 0.2440 TORCH_DTYPE = torch.float16 + USE_FLASH_ATTN = False def test_bloom(self): self.quant_lm_eval() diff --git a/tests/models/test_chatglm.py b/tests/models/test_chatglm.py index 03e09d9d1..0c63bcffa 100644 --- a/tests/models/test_chatglm.py +++ b/tests/models/test_chatglm.py @@ -14,6 +14,7 @@ class TestChatGlm(ModelTest): NATIVE_ARC_CHALLENGE_ACC = 0.3319 NATIVE_ARC_CHALLENGE_ACC_NORM = 0.3729 TRUST_REMOTE_CODE = True + USE_FLASH_ATTN = False def test_chatglm(self): self.quant_lm_eval() diff --git a/tests/models/test_codegen.py b/tests/models/test_codegen.py index 969925501..83a92b1c5 100644 --- a/tests/models/test_codegen.py +++ b/tests/models/test_codegen.py @@ -12,6 +12,7 @@ class TestCodeGen(ModelTest): NATIVE_ARC_CHALLENGE_ACC_NORM = 0.2005 TRUST_REMOTE_CODE = True USE_VLLM = False + USE_FLASH_ATTN = False def test_codegen(self): self.quant_lm_eval() diff --git a/tests/models/test_cohere2.py b/tests/models/test_cohere2.py index c0920a23d..9c900e086 100644 --- a/tests/models/test_cohere2.py +++ b/tests/models/test_cohere2.py @@ -12,6 +12,7 @@ class TestCohere2(ModelTest): NATIVE_ARC_CHALLENGE_ACC_NORM = 0.4693 QUANT_ARC_MAX_DELTA_FLOOR_PERCENT = 0.15 EVAL_BATCH_SIZE = 4 + USE_FLASH_ATTN = False def test_cohere2(self): self.quant_lm_eval() diff --git a/tests/models/test_ernie4_5.py b/tests/models/test_ernie4_5.py index 01b05b845..d7afc8e79 100644 --- a/tests/models/test_ernie4_5.py +++ b/tests/models/test_ernie4_5.py @@ -12,6 +12,7 @@ class TestErnie4_5(ModelTest): NATIVE_ARC_CHALLENGE_ACC_NORM = 0.3183 TRUST_REMOTE_CODE = True EVAL_BATCH_SIZE = 6 + USE_FLASH_ATTN = False def test_exaone(self): self.quant_lm_eval() diff --git a/tests/models/test_gptj.py b/tests/models/test_gptj.py index edd87d501..5baea9712 100644 --- a/tests/models/test_gptj.py +++ b/tests/models/test_gptj.py @@ -13,6 +13,7 @@ class TestGptJ(ModelTest): NATIVE_ARC_CHALLENGE_ACC_NORM = 0.3660 TORCH_DTYPE = torch.float16 INPUTS_MAX_LENGTH = 1024 + USE_FLASH_ATTN = False def test_gptj(self): self.quant_lm_eval() diff --git a/tests/models/test_internlm.py b/tests/models/test_internlm.py index af2492d4b..09390e5cc 100644 --- a/tests/models/test_internlm.py +++ b/tests/models/test_internlm.py @@ -12,6 +12,7 @@ class TestInternlm(ModelTest): NATIVE_ARC_CHALLENGE_ACC_NORM = 0.4309 TRUST_REMOTE_CODE = True USE_VLLM = False + USE_FLASH_ATTN = False def test_internlm(self): # transformers<=4.44.2 run normal diff --git a/tests/models/test_llama3_2.py b/tests/models/test_llama3_2.py index 
7553f59d6..b72e33cef 100644 --- a/tests/models/test_llama3_2.py +++ b/tests/models/test_llama3_2.py @@ -23,6 +23,7 @@ class TestLlama3_2(ModelTest): DATASET_SIZE = 1024 DATASET_SORT = "desc" QUANT_BATCH_SIZE = 4 + # USE_FLASH_ATTN = False # EORA = Lora( # # for quant, path is save path. for load, it is loading path # path="./eora_test", diff --git a/tests/models/test_llama3_2_awq.py b/tests/models/test_llama3_2_awq.py index b6f9fffaf..e903c4bb9 100644 --- a/tests/models/test_llama3_2_awq.py +++ b/tests/models/test_llama3_2_awq.py @@ -27,6 +27,7 @@ class TestLlama3_2(ModelTest): QUANT_BATCH_SIZE = 4 FORMAT = FORMAT.GEMM METHOD = METHOD.AWQ + # USE_FLASH_ATTN = False # EORA = Lora( # # for quant, path is save path. for load, it is loading path # path="./eora_test", diff --git a/tests/models/test_longllama.py b/tests/models/test_longllama.py index 0bc4e4a3c..b06742e70 100644 --- a/tests/models/test_longllama.py +++ b/tests/models/test_longllama.py @@ -13,6 +13,7 @@ class TestLongLlama(ModelTest): TRUST_REMOTE_CODE = True QUANT_ARC_MAX_DELTA_FLOOR_PERCENT = 0.5 USE_VLLM = False + USE_FLASH_ATTN = False def test_longllama(self): self.quant_lm_eval() diff --git a/tests/models/test_mpt.py b/tests/models/test_mpt.py index 56ab78463..c940bc5b0 100644 --- a/tests/models/test_mpt.py +++ b/tests/models/test_mpt.py @@ -13,10 +13,7 @@ class TestMpt(ModelTest): APPLY_CHAT_TEMPLATE = False TRUST_REMOTE_CODE = False EVAL_BATCH_SIZE = 6 - DATASET_SIZE = 96 - MAX_QUANT_LAYERS = None - MOCK_QUANTIZATION = True - OFFLOAD_TO_DISK = False + USE_FLASH_ATTN = False def test_mpt(self): self.quant_lm_eval() diff --git a/tests/models/test_ovis2_5.py b/tests/models/test_ovis2_5.py deleted file mode 100644 index feebf2ab9..000000000 --- a/tests/models/test_ovis2_5.py +++ /dev/null @@ -1,130 +0,0 @@ -# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai -# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai -# SPDX-License-Identifier: Apache-2.0 -# Contact: qubitium@modelcloud.ai, x.com/qubitium - -import json -import os -from pathlib import Path -from types import SimpleNamespace - -import pytest -import torch -from model_test import ModelTest -from ovis.image_to_test_dataset import get_calib_dataset - -from gptqmodel import BACKEND, GPTQModel -from gptqmodel.models.definitions.ovis2_5 import Ovis2_5QModel -from gptqmodel.quantization.config import QuantizeConfig - - -def test_ovis2_5_config_shape(): - config_path = Path("/monster/data/model/Ovis2.5-9B/config.json") - with config_path.open("r", encoding="utf-8") as handle: - config = json.load(handle) - - assert config["model_type"].lower() == "ovis2_5" - assert config["visual_vocab_size"] == 65536 - assert config["llm_config"]["num_hidden_layers"] == 36 - assert config["vit_config"]["model_type"] == "siglip2_navit" - - -def test_ovis2_5_class_metadata(): - assert Ovis2_5QModel.require_trust_remote_code is True - assert Ovis2_5QModel.pre_lm_head_norm_module == "llm.model.norm" - assert Ovis2_5QModel.module_tree[0] == "llm" - - -def test_ovis2_5_prepare_dataset_quantization_ready(): - instance = object.__new__(Ovis2_5QModel) - instance.IGNORE_ID = -100 - instance.quantize_config = QuantizeConfig(bits=4, group_size=128, desc_act=False, sym=True, mock_quantization=True) - fake_visual_tokenizer = SimpleNamespace(vit=SimpleNamespace(dtype=torch.float32)) - - def _preprocess(_messages): - input_ids = torch.tensor([[0, 1, 2]]) - pixel_values = torch.ones((1, 3), dtype=torch.float32) - grid_thws = torch.tensor([[1, 1, 1]], dtype=torch.long) - return input_ids, pixel_values, 
grid_thws - - instance.model = SimpleNamespace( - preprocess_inputs=_preprocess, - text_tokenizer=SimpleNamespace(pad_token_id=0, eos_token_id=2), - visual_tokenizer=fake_visual_tokenizer, - vte=SimpleNamespace(), - ) - - sample_dataset = [ - { - "messages": [ - { - "role": "user", - "content": [ - {"type": "text", "text": "describe the scene"}, - ], - } - ] - } - ] - - prepared = instance.prepare_dataset(sample_dataset, batch_size=1) - assert len(prepared) == 1 - batch = prepared[0] - assert set(batch.keys()) == {"input_ids", "attention_mask", "labels", "pixel_values", "grid_thws"} - assert batch["input_ids"].shape[0] == 1 - assert batch["attention_mask"].dtype == torch.bool - assert batch["attention_mask"].shape == batch["input_ids"].shape - assert isinstance(batch["pixel_values"], torch.Tensor) and batch["pixel_values"].shape[-1] == 3 - assert isinstance(batch["grid_thws"], torch.Tensor) - - -class TestOvis2_5Quant(ModelTest): - NATIVE_MODEL_ID = "/monster/data/model/Ovis2.5-9B" - TRUST_REMOTE_CODE = True - APPLY_CHAT_TEMPLATE = False - QUANT_BATCH_SIZE = 1 - DATASET_SIZE = 2 - POST_QUANT_VALIDATION_BACKENDS = [] - MOCK_QUANTIZATION = False - OFFLOAD_TO_DISK = False - LOAD_BACKEND = BACKEND.TORCH - QUANT_BACKEND = BACKEND.TORCH - MAX_QUANT_LAYERS = 1 - - def test_quantize_single_layer(self): - if not torch.cuda.is_available(): - pytest.skip("CUDA is required for Ovis2.5 quantization test") - - quant_cfg = QuantizeConfig( - bits=self.BITS, - group_size=self.GROUP_SIZE, - desc_act=False, - act_group_aware=self.ACT_GROUP_AWARE, - sym=self.SYM, - mock_quantization=False, - fail_safe=self.FAIL_SAFE, - v2=self.V2, - ) - quant_cfg.device = "cuda:0" - - model = GPTQModel.load( - self.NATIVE_MODEL_ID, - quantize_config=quant_cfg, - trust_remote_code=True, - dtype=self.TORCH_DTYPE, - ) - - calibration_dataset = get_calib_dataset(model)[: self.DATASET_SIZE] - - prev_max_layers = os.environ.get("GPTQMODEL_MAX_QUANT_LAYERS") - os.environ["GPTQMODEL_MAX_QUANT_LAYERS"] = str(self.MAX_QUANT_LAYERS) - try: - model.quantize(calibration_dataset, batch_size=self.QUANT_BATCH_SIZE) - finally: - if prev_max_layers is None: - os.environ.pop("GPTQMODEL_MAX_QUANT_LAYERS", None) - else: - os.environ["GPTQMODEL_MAX_QUANT_LAYERS"] = prev_max_layers - - assert model.quantized is True - torch.cuda.empty_cache() diff --git a/tests/models/test_ovis_1_6_llama.py b/tests/models/test_ovis_1_6_llama.py index 419ab2c18..b899f6704 100644 --- a/tests/models/test_ovis_1_6_llama.py +++ b/tests/models/test_ovis_1_6_llama.py @@ -16,6 +16,7 @@ class TestOvis1_6_Llama(ModelTest): TRUST_REMOTE_CODE = True APPLY_CHAT_TEMPLATE = False EVAL_BATCH_SIZE = 1 + USE_FLASH_ATTN = False def test_ovis_1_6(self): model, tokenizer = self.quantModel(self.NATIVE_MODEL_ID, trust_remote_code=self.TRUST_REMOTE_CODE, diff --git a/tests/models/test_phi_4.py b/tests/models/test_phi_4.py index 9bcec091a..6084e3d43 100644 --- a/tests/models/test_phi_4.py +++ b/tests/models/test_phi_4.py @@ -12,6 +12,7 @@ class TestPhi_4(ModelTest): NATIVE_ARC_CHALLENGE_ACC_NORM = 0.5674 APPLY_CHAT_TEMPLATE = True TRUST_REMOTE_CODE = True + USE_FLASH_ATTN = False BATCH_SIZE = 1 def test_phi_4(self): diff --git a/tests/models/test_qwen3_moe.py b/tests/models/test_qwen3_moe.py index b1f101b79..86411ec02 100644 --- a/tests/models/test_qwen3_moe.py +++ b/tests/models/test_qwen3_moe.py @@ -9,8 +9,8 @@ class TestQwen3Moe(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/Qwen3-30B-A3B" QUANT_ARC_MAX_DELTA_FLOOR_PERCENT = 0.2 - NATIVE_ARC_CHALLENGE_ACC = 0.39 - 
NATIVE_ARC_CHALLENGE_ACC_NORM = 0.39 + NATIVE_ARC_CHALLENGE_ACC = 0.2739 + NATIVE_ARC_CHALLENGE_ACC_NORM = 0.3055 # TRUST_REMOTE_CODE = False APPLY_CHAT_TEMPLATE = True # EVAL_BATCH_SIZE = 6 diff --git a/tests/models/test_qwen3_omni.py b/tests/models/test_qwen3_omni.py index 495b96663..d386a424a 100644 --- a/tests/models/test_qwen3_omni.py +++ b/tests/models/test_qwen3_omni.py @@ -20,9 +20,7 @@ class TestQwen3Omni(ModelTest): DESC_ACT = False DATASET_SIZE = 1024 DATASET_SORT = "desc" - QUANT_BATCH_SIZE = 4 - CALIB_NOISE_MODE = "unseen" - CALIB_NOISE_PERCENT = 0.10 + QUANT_BATCH_SIZE = 1 def test_omni(self): self.quant_lm_eval() diff --git a/tests/models/test_telechat2.py b/tests/models/test_telechat2.py index 512d73365..ca7c396da 100644 --- a/tests/models/test_telechat2.py +++ b/tests/models/test_telechat2.py @@ -13,6 +13,7 @@ class TestTeleChat_2(ModelTest): TRUST_REMOTE_CODE = True EVAL_BATCH_SIZE = 6 USE_VLLM = False + USE_FLASH_ATTN = False def test_telechat2(self): diff --git a/tests/models/test_xverse.py b/tests/models/test_xverse.py index bbdd99f42..6ee887def 100644 --- a/tests/models/test_xverse.py +++ b/tests/models/test_xverse.py @@ -15,6 +15,7 @@ class TestXVerse(ModelTest): APPLY_CHAT_TEMPLATE = True EVAL_BATCH_SIZE = 6 USE_VLLM = False + USE_FLASH_ATTN = False def test_xverse(self): self.quant_lm_eval() diff --git a/tests/test_packing.py b/tests/test_packing.py index df8448fb1..8823a9dd4 100644 --- a/tests/test_packing.py +++ b/tests/test_packing.py @@ -12,8 +12,11 @@ import unittest # noqa: E402 + +# isort: off import torch # noqa: E402 import torch.nn as nn # noqa: E402 +# isort: on from parameterized import parameterized # noqa: E402 from gptqmodel import BACKEND # noqa: E402 diff --git a/tests/test_packing_speed.py b/tests/test_packing_speed.py index 6f13440da..6ba824c48 100644 --- a/tests/test_packing_speed.py +++ b/tests/test_packing_speed.py @@ -16,9 +16,13 @@ import unittest # noqa: E402 import threadpoolctl # noqa: E402 +from parameterized import parameterized # noqa: E402 + + +# isort: off import torch # noqa: E402 import torch.nn as nn # noqa: E402 -from parameterized import parameterized # noqa: E402 +# isort: on from gptqmodel.nn_modules.qlinear.torch import TorchQuantLinear # noqa: E402