From 58467f487f3777ed4cec09c318c486426caeebf6 Mon Sep 17 00:00:00 2001
From: Qubitium
Date: Sat, 4 Oct 2025 03:12:34 +0000
Subject: [PATCH 1/4] fix opt compat where there is a max_position_embeddings limit

Signed-off-by: Qubitium
---
 gptqmodel/models/base.py | 69 ++++++++++++++++++++++++++++++++++++++++
 1 file changed, 69 insertions(+)

diff --git a/gptqmodel/models/base.py b/gptqmodel/models/base.py
index d88587d33..d8fb1d49f 100644
--- a/gptqmodel/models/base.py
+++ b/gptqmodel/models/base.py
@@ -545,10 +545,70 @@ def _convert_tensor_to_list(tensor):
 
         new_calibration_dataset = []
         too_short_calibration_data_count = 0
+
+        max_positions = None
+        max_positions_source = None
+        trimmed_row_count = 0
+        longest_trimmed_row = 0
+
+        def _maybe_resolve_length(value, source_name):
+            nonlocal max_positions, max_positions_source
+            try:
+                if value is None:
+                    return False
+                limit = int(value)
+            except Exception:
+                return False
+            if limit <= 0:
+                return False
+            if max_positions is None or limit < max_positions:
+                max_positions = limit
+                max_positions_source = source_name
+            return True
+
+        model_config = getattr(self.model, "config", None)
+        if model_config is not None:
+            primary_names = ("max_position_embeddings",)
+            fallback_names = (
+                "max_sequence_length",
+                "max_seq_len",
+                "n_positions",
+                "seq_length",
+            )
+
+            for attr_name in primary_names:
+                if _maybe_resolve_length(getattr(model_config, attr_name, None), attr_name):
+                    break
+            if max_positions is None:
+                for attr_name in fallback_names:
+                    if _maybe_resolve_length(getattr(model_config, attr_name, None), attr_name):
+                        break
+
         for example in calibration_dataset:
             input_ids = _convert_tensor_to_list(example["input_ids"])
             attention_mask = _convert_tensor_to_list(example["attention_mask"])
 
+            if max_positions is not None:
+                trimmed = False
+                trimmed_input_ids = []
+                trimmed_attention_mask = []
+
+                for row_ids, row_mask in zip(input_ids, attention_mask):
+                    row_len = len(row_ids)
+                    if row_len > max_positions:
+                        trimmed = True
+                        trimmed_row_count += 1
+                        longest_trimmed_row = max(longest_trimmed_row, row_len)
+                        trimmed_input_ids.append(row_ids[:max_positions])
+                        trimmed_attention_mask.append(row_mask[:max_positions])
+                    else:
+                        trimmed_input_ids.append(row_ids)
+                        trimmed_attention_mask.append(row_mask)
+
+                if trimmed:
+                    input_ids = trimmed_input_ids
+                    attention_mask = trimmed_attention_mask
+
             # filter if input_ids is too short
             if len(input_ids[0]) <= calibration_data_min_length:
                 too_short_calibration_data_count += 1
@@ -565,6 +625,15 @@ def _convert_tensor_to_list(tensor):
 
             log.warn(f"Quantize: {too_short_calibration_data_count} input_ids with length <= {calibration_data_min_length} were removed. "
                      f"Use quantize(calibration_data_min_length={calibration_data_min_length}) to set a custom minimum length.")
+        if trimmed_row_count > 0:
+            log.info(
+                "Quantize: trimmed %s calibration rows above %s=%s (longest original length=%s)",
+                trimmed_row_count,
+                max_positions_source,
+                max_positions,
+                longest_trimmed_row,
+            )
+
         if calibration_dataset_concat_size:
             _require_tokenizer("`calibration_dataset_concat_size` is specified")
             concatenated_data = []

From 93a246b2e9e3d99ab30a45b718f2ea11d76cb5f0 Mon Sep 17 00:00:00 2001
From: Qubitium
Date: Sat, 4 Oct 2025 03:37:14 +0000
Subject: [PATCH 2/4] fix qwen2.5_omni compat with latest transformers

Signed-off-by: Qubitium
---
 .../models/definitions/base_qwen2_5_omni.py | 29 +++++++++++++++++++
 1 file changed, 29 insertions(+)

diff --git a/gptqmodel/models/definitions/base_qwen2_5_omni.py b/gptqmodel/models/definitions/base_qwen2_5_omni.py
index 23f43fc7f..4dcdf72ff 100644
--- a/gptqmodel/models/definitions/base_qwen2_5_omni.py
+++ b/gptqmodel/models/definitions/base_qwen2_5_omni.py
@@ -57,6 +57,10 @@ def pre_quantize_generate_hook_start(self):
         self.shell_module_materialize(self.model.thinker.audio_tower, self.quantize_config.device)
         self.shell_module_materialize(self.model.thinker.visual.rotary_pos_emb, self.quantize_config.device)
         self.shell_module_materialize(self.model.thinker.model.rotary_emb, self.quantize_config.device)
+        if hasattr(self.model, "talker"):
+            self.shell_module_materialize(self.model.talker, self.quantize_config.device)
+        if hasattr(self.model, "token2wav"):
+            self.shell_module_materialize(self.model.token2wav, self.quantize_config.device)
 
         for layer in self.model.thinker.model.layers:
             self.shell_module_materialize(layer.self_attn.rotary_emb, self.quantize_config.device)
@@ -87,6 +91,17 @@ def pre_quantize_generate_hook_end(self):
                             disk_path=self.quantize_config.offload_to_disk_path,
                             )
 
+            if hasattr(self.model, "talker"):
+                offload_to_disk(model=self.model,
+                                module=self.model.talker,
+                                disk_path=self.quantize_config.offload_to_disk_path,
+                                )
+            if hasattr(self.model, "token2wav"):
+                offload_to_disk(model=self.model,
+                                module=self.model.token2wav,
+                                disk_path=self.quantize_config.offload_to_disk_path,
+                                )
+
             for layer in self.model.thinker.model.layers:
                 layer.self_attn.rotary_emb = layer.self_attn.rotary_emb.to(CPU)
 
@@ -95,6 +110,10 @@ def pre_quantize_generate_hook_end(self):
         self.model.thinker.model.embed_tokens = self.model.thinker.model.embed_tokens.to(CPU)
         self.model.thinker.visual = self.model.thinker.visual.to(CPU)
         self.model.thinker.audio_tower = self.model.thinker.audio_tower.to(CPU)
+        if hasattr(self.model, "talker"):
+            self.model.talker = self.model.talker.to(CPU)
+        if hasattr(self.model, "token2wav"):
+            self.model.token2wav = self.model.token2wav.to(CPU)
         self.model.thinker.visual.rotary_pos_emb = self.model.thinker.visual.rotary_pos_emb.to(CPU)
         self.model.thinker.model.rotary_emb = self.model.thinker.model.rotary_emb.to(CPU)
 
@@ -121,6 +140,16 @@ def process_vision_info(
     def preprocess_dataset(self, sample: Dict) -> Dict:
         return sample
 
+    def forward(self, *args, **kwargs):
+        """Delegate textual forward passes to the thinker submodule.
+
+        The top-level Hugging Face wrapper leaves ``forward`` unimplemented when
+        ``trust_remote_code`` is disabled, so we expose the thinker equivalent to
+        keep tooling such as lm-eval operational in quantized environments.
+        """
+
+        return self.model.thinker(*args, **kwargs)
+
     def load_processor(self) -> ProcessorMixin:
         return AutoProcessor.from_pretrained(self.model_local_path)
 

From 725c018312fadb724a10c4166ea61c4218a16885 Mon Sep 17 00:00:00 2001
From: Qubitium
Date: Sat, 4 Oct 2025 04:10:33 +0000
Subject: [PATCH 3/4] fix qwen3_omni compat with latest transformers

Signed-off-by: Qubitium
---
 .../models/definitions/qwen3_omni_moe.py | 20 ++++++++++++++++++-
 1 file changed, 19 insertions(+), 1 deletion(-)

diff --git a/gptqmodel/models/definitions/qwen3_omni_moe.py b/gptqmodel/models/definitions/qwen3_omni_moe.py
index 30eceae24..414c2c160 100644
--- a/gptqmodel/models/definitions/qwen3_omni_moe.py
+++ b/gptqmodel/models/definitions/qwen3_omni_moe.py
@@ -49,6 +49,10 @@ def pre_quantize_generate_hook_start(self):
         self.shell_module_materialize(self.model.thinker.audio_tower, self.quantize_config.device)
         self.shell_module_materialize(self.model.thinker.visual.rotary_pos_emb, self.quantize_config.device)
         self.shell_module_materialize(self.model.thinker.model.rotary_emb, self.quantize_config.device)
+        if hasattr(self.model, "talker"):
+            self.shell_module_materialize(self.model.talker, self.quantize_config.device)
+        if hasattr(self.model, "code2wav"):
+            self.shell_module_materialize(self.model.code2wav, self.quantize_config.device)
 
     def pre_quantize_generate_hook_end(self):
         if self.quantize_config.offload_to_disk:
@@ -76,11 +80,26 @@ def pre_quantize_generate_hook_end(self):
                             module=self.model.thinker.model.rotary_emb,
                             disk_path=self.quantize_config.offload_to_disk_path,
                             )
+
+            if hasattr(self.model, "talker"):
+                offload_to_disk(model=self.model,
+                                module=self.model.talker,
+                                disk_path=self.quantize_config.offload_to_disk_path,
+                                )
+            if hasattr(self.model, "code2wav"):
+                offload_to_disk(model=self.model,
+                                module=self.model.code2wav,
+                                disk_path=self.quantize_config.offload_to_disk_path,
+                                )
             return
 
         self.model.thinker.model.embed_tokens = self.model.thinker.model.embed_tokens.to(CPU)
         self.model.thinker.visual = self.model.thinker.visual.to(CPU)
         self.model.thinker.audio_tower = self.model.thinker.audio_tower.to(CPU)
+        if hasattr(self.model, "talker"):
+            self.model.talker = self.model.talker.to(CPU)
+        if hasattr(self.model, "code2wav"):
+            self.model.code2wav = self.model.code2wav.to(CPU)
         self.model.thinker.visual.rotary_pos_emb = self.model.thinker.visual.rotary_pos_emb.to(CPU)
         self.model.thinker.model.rotary_emb = self.model.thinker.model.rotary_emb.to(CPU)
 
@@ -91,4 +110,3 @@ def after_model_load(self, model, load_quantized_model=False):
         self.processor = AutoProcessor.from_pretrained(self.model_local_path)
 
         return model
-

From 522a1c10df5dc2b50ffadcf1e24cafd50302b56e Mon Sep 17 00:00:00 2001
From: Qubitium
Date: Sat, 4 Oct 2025 04:49:05 +0000
Subject: [PATCH 4/4] offload is not thread safe

Signed-off-by: Qubitium
---
 gptqmodel/models/base.py | 29 ++++++++++++-------
 .../models/definitions/qwen3_omni_moe.py | 6 ++++
 gptqmodel/models/loader.py | 8 +++++
 gptqmodel/utils/offload.py | 10 +++++--
 tests/models/test_qwen3_moe.py | 2 +-
 5 files changed, 40 insertions(+), 15 deletions(-)

diff --git a/gptqmodel/models/base.py b/gptqmodel/models/base.py
index d8fb1d49f..861ec171e 100644
--- a/gptqmodel/models/base.py
+++ b/gptqmodel/models/base.py
@@ -195,12 +195,20 @@
     ):
         super().__init__()
 
-        self.model = self.after_model_load(model, load_quantized_model=load_quantized_model)
-        self.turtle_model = turtle_model
-
+        # record configuration early so model lifecycle hooks can rely on them
         self.compiled = False  # set to True while compile() is triggered successfully
         self.quantized = quantized
         self.load_quantized_model = load_quantized_model
+        self.qlinear_kernel = qlinear_kernel
+        self.trust_remote_code = trust_remote_code
+        self.model_local_path = model_local_path
+        self.quantize_config = quantize_config
+
+        self.processor: ProcessorMixin = None
+
+        self.model = self.after_model_load(model, load_quantized_model=load_quantized_model)
+        self.turtle_model = turtle_model
+
         if tokenizer is not None:
             if isinstance(tokenizer, PreTrainedTokenizerBase):
                 self.tokenizer = Tokenicer.load(tokenizer, trust_remote_code=trust_remote_code)
@@ -216,8 +224,6 @@ def __init__(
         if isinstance(self.model, PreTrainedModel):
             autofix_hf_model_config(self.model, path=model_local_path)
 
-        self.quantize_config = quantize_config
-
         self._background_pool: Optional["DeviceThreadPool"] = None
         self._turtle_reload_future: Optional[Future] = None
         self._turtle_reload_lock = threading.Lock()
@@ -225,13 +231,9 @@ def __init__(
         self._turtle_ready.set()
 
         # compat: state to assist in checkpoint_format gptq(v1) to gptq_v2 conversion
-        self.qlinear_kernel = qlinear_kernel
-        self.trust_remote_code = trust_remote_code
-        self.model_local_path = model_local_path
 
         # stores all per-layer quant stats such as avg loss and processing time
         self.quant_log = []
-        self.processor: ProcessorMixin = None
 
         if self.require_load_processor:
             self.processor = AutoProcessor.from_pretrained(model_local_path)
@@ -1698,8 +1700,13 @@ def tied_word_embedding(self) -> bool:
     def __getattr__(self, item):
         try:
             return super().__getattr__(item)
-        except Exception:
-            return getattr(self.model, item)
+        except Exception as exc:  # torch Modules raise AttributeError here
+            model = self.__dict__.get("model")
+            if model is None:
+                model = self._modules.get("model") if hasattr(self, "_modules") else None
+            if model is not None and item != "model":
+                return getattr(model, item)
+            raise exc
 
 
 __all__ = ["BaseQModel"]
diff --git a/gptqmodel/models/definitions/qwen3_omni_moe.py b/gptqmodel/models/definitions/qwen3_omni_moe.py
index 414c2c160..a9f1ce696 100644
--- a/gptqmodel/models/definitions/qwen3_omni_moe.py
+++ b/gptqmodel/models/definitions/qwen3_omni_moe.py
@@ -3,6 +3,8 @@
 # SPDX-License-Identifier: Apache-2.0
 # Contact: qubitium@modelcloud.ai, x.com/qubitium
 
+import os
+
 import torch
 from transformers import AutoModelForTextToWaveform, AutoProcessor
 
@@ -44,6 +46,10 @@ class Qwen3OmniMoeGPTQ(BaseQModel):
     ]
 
     def pre_quantize_generate_hook_start(self):
+        spk_path = os.path.join(self.model_local_path, "spk_dict.pt")
+        if os.path.isfile(spk_path):
+            self.model.load_speakers(spk_path)
+
         self.shell_module_materialize(self.model.thinker.model.embed_tokens, self.quantize_config.device)
         self.shell_module_materialize(self.model.thinker.visual, self.quantize_config.device)
         self.shell_module_materialize(self.model.thinker.audio_tower, self.quantize_config.device)
diff --git a/gptqmodel/models/loader.py b/gptqmodel/models/loader.py
index c1a2a4f9b..802e2a009 100644
--- a/gptqmodel/models/loader.py
+++ b/gptqmodel/models/loader.py
@@ -181,6 +181,10 @@ def skip(*args, **kwargs):
     # TODO FIX ME for `dynamic`, non-quantized modules should be in native type
     dtype = auto_dtype(config=config, device=quantize_config.device, quant_inference=False)
 
+    if isinstance(dtype, torch.dtype) and getattr(config, "torch_dtype", None) != dtype:
+        # Align config metadata with the dtype we will materialize weights in.
+        config.torch_dtype = dtype
+
     # enforce some values despite user specified
     # non-quantized models are always loaded into cpu
     model_init_kwargs["device_map"] = cpu_device_map
@@ -320,6 +324,10 @@ def from_quantized(
     # TODO FIX ME for `dynamic`, non-quantized modules should be in native type
     dtype = auto_dtype(config=config, device=device, quant_inference=True)
 
+    if isinstance(dtype, torch.dtype) and getattr(config, "torch_dtype", None) != dtype:
+        # Ensure flash attention kernels see an explicit dtype instead of relying on defaults.
+        config.torch_dtype = dtype
+
     qcfg = QuantizeConfig.from_pretrained(model_local_path, **cached_file_kwargs, **kwargs)
 
     if qcfg.quant_method == METHOD.AWQ and qcfg.format in [FORMAT.GEMV_FAST]:
diff --git a/gptqmodel/utils/offload.py b/gptqmodel/utils/offload.py
index 56864f9c2..a117dcd6d 100644
--- a/gptqmodel/utils/offload.py
+++ b/gptqmodel/utils/offload.py
@@ -6,7 +6,7 @@
 import contextlib
 import os
 import shutil
-import threading
+import sys
 from typing import Iterable, List, Optional, Set, Tuple
 
 import accelerate
@@ -21,10 +21,9 @@
 from ..looper.named_module import NamedModule
 from .device import get_device
 from .torch import CPU, META
+from .safe import ThreadSafe
 
 
-_lock = threading.Lock()
-
 # Patch fix thread unsafe accelerate.utils.modeling.clear_device_cache
 def _fake_clear_device_cache(garbage_collection=False):
     pass
@@ -96,6 +95,11 @@ def offload_to_disk(module: List[str] | nn.Module, model: nn.Module, disk_path:
     #     print("offload_disk: list item tree")
     #     print_module_tree(module)
 
+
+# Serialize accelerate's disk hook mutations across threads.
+_OFFLOAD_SAFE = ThreadSafe(sys.modules[__name__])
+offload_to_disk = _OFFLOAD_SAFE.offload_to_disk
+
 def _offload_disk(module: nn.Module, name: str, disk_path: str = "."):
     if is_meta_module(module):
         # print(f"[skip] '{name}' is on meta; leaving as-is")
diff --git a/tests/models/test_qwen3_moe.py b/tests/models/test_qwen3_moe.py
index 1f3b00f92..90de5a090 100644
--- a/tests/models/test_qwen3_moe.py
+++ b/tests/models/test_qwen3_moe.py
@@ -18,7 +18,7 @@ class TestQwen3Moe(ModelTest):
     DEBUG = True
     ACT_GROUP_AWARE = True
     DESC_ACT = False
-    DATASET_SIZE = 1024 * 8
+    DATASET_SIZE = 1024
     DATASET_SORT = "desc"
     QUANT_BATCH_SIZE = 4