98 changes: 87 additions & 11 deletions gptqmodel/models/base.py
@@ -195,12 +195,20 @@ def __init__(
):
super().__init__()

self.model = self.after_model_load(model, load_quantized_model=load_quantized_model)
self.turtle_model = turtle_model

# record configuration values early so model lifecycle hooks can rely on them
self.compiled = False # set to True once compile() completes successfully
self.quantized = quantized
self.load_quantized_model = load_quantized_model
self.qlinear_kernel = qlinear_kernel
self.trust_remote_code = trust_remote_code
self.model_local_path = model_local_path
self.quantize_config = quantize_config

self.processor: ProcessorMixin = None

self.model = self.after_model_load(model, load_quantized_model=load_quantized_model)
self.turtle_model = turtle_model

if tokenizer is not None:
if isinstance(tokenizer, PreTrainedTokenizerBase):
self.tokenizer = Tokenicer.load(tokenizer, trust_remote_code=trust_remote_code)
@@ -216,22 +224,16 @@ def __init__(
if isinstance(self.model, PreTrainedModel):
autofix_hf_model_config(self.model, path=model_local_path)

self.quantize_config = quantize_config

self._background_pool: Optional["DeviceThreadPool"] = None
self._turtle_reload_future: Optional[Future] = None
self._turtle_reload_lock = threading.Lock()
self._turtle_ready = threading.Event()
self._turtle_ready.set()

# compat: state to assist in checkpoint_format gptq(v1) to gptq_v2 conversion
self.qlinear_kernel = qlinear_kernel
self.trust_remote_code = trust_remote_code
self.model_local_path = model_local_path
# stores all per-layer quant stats such as avg loss and processing time
self.quant_log = []

self.processor: ProcessorMixin = None
if self.require_load_processor:
self.processor = AutoProcessor.from_pretrained(model_local_path)

@@ -545,10 +547,70 @@ def _convert_tensor_to_list(tensor):

new_calibration_dataset = []
too_short_calibration_data_count = 0

max_positions = None
max_positions_source = None
trimmed_row_count = 0
longest_trimmed_row = 0

def _maybe_resolve_length(value, source_name):
nonlocal max_positions, max_positions_source
try:
if value is None:
return False
limit = int(value)
except Exception:
return False
if limit <= 0:
return False
if max_positions is None or limit < max_positions:
max_positions = limit
max_positions_source = source_name
return True

model_config = getattr(self.model, "config", None)
if model_config is not None:
primary_names = ("max_position_embeddings",)
fallback_names = (
"max_sequence_length",
"max_seq_len",
"n_positions",
"seq_length",
)

for attr_name in primary_names:
if _maybe_resolve_length(getattr(model_config, attr_name, None), attr_name):
break
if max_positions is None:
for attr_name in fallback_names:
if _maybe_resolve_length(getattr(model_config, attr_name, None), attr_name):
break

for example in calibration_dataset:
input_ids = _convert_tensor_to_list(example["input_ids"])
attention_mask = _convert_tensor_to_list(example["attention_mask"])

if max_positions is not None:
trimmed = False
trimmed_input_ids = []
trimmed_attention_mask = []

for row_ids, row_mask in zip(input_ids, attention_mask):
row_len = len(row_ids)
if row_len > max_positions:
trimmed = True
trimmed_row_count += 1
longest_trimmed_row = max(longest_trimmed_row, row_len)
trimmed_input_ids.append(row_ids[:max_positions])
trimmed_attention_mask.append(row_mask[:max_positions])
else:
trimmed_input_ids.append(row_ids)
trimmed_attention_mask.append(row_mask)

if trimmed:
input_ids = trimmed_input_ids
attention_mask = trimmed_attention_mask

# filter if input_ids is too short
if len(input_ids[0]) <= calibration_data_min_length:
too_short_calibration_data_count += 1
@@ -565,6 +627,15 @@ def _convert_tensor_to_list(tensor):
log.warn(f"Quantize: {too_short_calibration_data_count} input_ids with length <= {calibration_data_min_length} were removed. "
f"Use quantize(calibration_data_min_length={calibration_data_min_length}) to set a custom minimum length.")

if trimmed_row_count > 0:
log.info(
"Quantize: trimmed %s calibration rows above %s=%s (longest original length=%s)",
trimmed_row_count,
max_positions_source,
max_positions,
longest_trimmed_row,
)

if calibration_dataset_concat_size:
_require_tokenizer("`calibration_dataset_concat_size` is specified")
concatenated_data = []
@@ -1629,8 +1700,13 @@ def tied_word_embedding(self) -> bool:
def __getattr__(self, item):
try:
return super().__getattr__(item)
except Exception:
return getattr(self.model, item)
except Exception as exc: # torch Modules raise AttributeError here
model = self.__dict__.get("model")
if model is None:
model = self._modules.get("model") if hasattr(self, "_modules") else None
if model is not None and item != "model":
return getattr(model, item)
raise exc

__all__ = ["BaseQModel"]
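The calibration hunk above resolves a per-model length cap and truncates over-long rows before quantization. A minimal standalone sketch of that behavior (helper names here are illustrative, not part of the library API): the first positive integer limit found on the config wins, with max_position_embeddings preferred over the fallback names, and attention masks are trimmed in lockstep with the input ids.

from typing import Optional, Tuple

def resolve_max_positions(config) -> Tuple[Optional[str], Optional[int]]:
    # First positive integer limit wins; max_position_embeddings is preferred,
    # the remaining names are fallbacks for configs that spell it differently.
    for name in ("max_position_embeddings", "max_sequence_length",
                 "max_seq_len", "n_positions", "seq_length"):
        try:
            limit = int(getattr(config, name, None))
        except (TypeError, ValueError):
            continue
        if limit > 0:
            return name, limit
    return None, None

def trim_calibration_rows(input_ids, attention_mask, max_positions):
    # Slicing is a no-op for rows already within the cap, so short rows pass
    # through unchanged and the min-length filter still sees them afterwards.
    trimmed_ids, trimmed_mask = [], []
    for row_ids, row_mask in zip(input_ids, attention_mask):
        trimmed_ids.append(row_ids[:max_positions])
        trimmed_mask.append(row_mask[:max_positions])
    return trimmed_ids, trimmed_mask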

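The widened __getattr__ above also changes how missing attributes fall through to the wrapped model: the inner module is fetched via __dict__/_modules so the lookup cannot recurse while "model" itself is still unset. A compact sketch of the same pattern (an illustrative class, not the real BaseQModel):

import torch.nn as nn

class Wrapper(nn.Module):
    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

    def __getattr__(self, item):
        try:
            # nn.Module resolves parameters, buffers and registered submodules here.
            return super().__getattr__(item)
        except AttributeError as exc:
            # Look the inner model up without re-entering this __getattr__.
            inner = self.__dict__.get("model")
            if inner is None:
                inner = self._modules.get("model")
            if inner is not None and item != "model":
                return getattr(inner, item)
            raise exc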
29 changes: 29 additions & 0 deletions gptqmodel/models/definitions/base_qwen2_5_omni.py
@@ -57,6 +57,10 @@ def pre_quantize_generate_hook_start(self):
self.shell_module_materialize(self.model.thinker.audio_tower, self.quantize_config.device)
self.shell_module_materialize(self.model.thinker.visual.rotary_pos_emb, self.quantize_config.device)
self.shell_module_materialize(self.model.thinker.model.rotary_emb, self.quantize_config.device)
if hasattr(self.model, "talker"):
self.shell_module_materialize(self.model.talker, self.quantize_config.device)
if hasattr(self.model, "token2wav"):
self.shell_module_materialize(self.model.token2wav, self.quantize_config.device)
for layer in self.model.thinker.model.layers:
self.shell_module_materialize(layer.self_attn.rotary_emb, self.quantize_config.device)

@@ -87,6 +91,17 @@ def pre_quantize_generate_hook_end(self):
disk_path=self.quantize_config.offload_to_disk_path,
)

if hasattr(self.model, "talker"):
offload_to_disk(model=self.model,
module=self.model.talker,
disk_path=self.quantize_config.offload_to_disk_path,
)
if hasattr(self.model, "token2wav"):
offload_to_disk(model=self.model,
module=self.model.token2wav,
disk_path=self.quantize_config.offload_to_disk_path,
)

for layer in self.model.thinker.model.layers:
layer.self_attn.rotary_emb = layer.self_attn.rotary_emb.to(CPU)

@@ -95,6 +110,10 @@ def pre_quantize_generate_hook_end(self):
self.model.thinker.model.embed_tokens = self.model.thinker.model.embed_tokens.to(CPU)
self.model.thinker.visual = self.model.thinker.visual.to(CPU)
self.model.thinker.audio_tower = self.model.thinker.audio_tower.to(CPU)
if hasattr(self.model, "talker"):
self.model.talker = self.model.talker.to(CPU)
if hasattr(self.model, "token2wav"):
self.model.token2wav = self.model.token2wav.to(CPU)

self.model.thinker.visual.rotary_pos_emb = self.model.thinker.visual.rotary_pos_emb.to(CPU)
self.model.thinker.model.rotary_emb = self.model.thinker.model.rotary_emb.to(CPU)
@@ -121,6 +140,16 @@ def process_vision_info(
def preprocess_dataset(self, sample: Dict) -> Dict:
return sample

def forward(self, *args, **kwargs):
"""Delegate textual forward passes to the thinker submodule.

The top-level Hugging Face wrapper leaves ``forward`` unimplemented when
``trust_remote_code`` is disabled, so we expose the thinker equivalent to
keep tooling such as lm-eval operational in quantized environments.
"""

return self.model.thinker(*args, **kwargs)

def load_processor(self) -> ProcessorMixin:
return AutoProcessor.from_pretrained(self.model_local_path)

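The new forward override simply routes calls to the thinker submodule; a rough sketch of the delegation pattern (the class here is a placeholder standing in for the real Qwen2.5-Omni wrapper):

import torch.nn as nn

class OmniTextForwarder(nn.Module):
    # Illustrative stand-in: the outer omni wrapper has no usable top-level
    # forward, so text-only callers (lm-eval style harnesses) go to thinker.
    def __init__(self, thinker: nn.Module):
        super().__init__()
        self.thinker = thinker

    def forward(self, *args, **kwargs):
        # Whatever the harness passes (input_ids, attention_mask, ...) is
        # handed straight to the thinker's text decoder stack.
        return self.thinker(*args, **kwargs)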
26 changes: 25 additions & 1 deletion gptqmodel/models/definitions/qwen3_omni_moe.py
@@ -3,6 +3,8 @@
# SPDX-License-Identifier: Apache-2.0
# Contact: qubitium@modelcloud.ai, x.com/qubitium

import os

import torch
from transformers import AutoModelForTextToWaveform, AutoProcessor

@@ -44,11 +46,19 @@ class Qwen3OmniMoeGPTQ(BaseQModel):
]

def pre_quantize_generate_hook_start(self):
spk_path = os.path.join(self.model_local_path, "spk_dict.pt")
if os.path.isfile(spk_path):
self.model.load_speakers(spk_path)

self.shell_module_materialize(self.model.thinker.model.embed_tokens, self.quantize_config.device)
self.shell_module_materialize(self.model.thinker.visual, self.quantize_config.device)
self.shell_module_materialize(self.model.thinker.audio_tower, self.quantize_config.device)
self.shell_module_materialize(self.model.thinker.visual.rotary_pos_emb, self.quantize_config.device)
self.shell_module_materialize(self.model.thinker.model.rotary_emb, self.quantize_config.device)
if hasattr(self.model, "talker"):
self.shell_module_materialize(self.model.talker, self.quantize_config.device)
if hasattr(self.model, "code2wav"):
self.shell_module_materialize(self.model.code2wav, self.quantize_config.device)

def pre_quantize_generate_hook_end(self):
if self.quantize_config.offload_to_disk:
Expand Down Expand Up @@ -76,11 +86,26 @@ def pre_quantize_generate_hook_end(self):
module=self.model.thinker.model.rotary_emb,
disk_path=self.quantize_config.offload_to_disk_path,
)

if hasattr(self.model, "talker"):
offload_to_disk(model=self.model,
module=self.model.talker,
disk_path=self.quantize_config.offload_to_disk_path,
)
if hasattr(self.model, "code2wav"):
offload_to_disk(model=self.model,
module=self.model.code2wav,
disk_path=self.quantize_config.offload_to_disk_path,
)
return

self.model.thinker.model.embed_tokens = self.model.thinker.model.embed_tokens.to(CPU)
self.model.thinker.visual = self.model.thinker.visual.to(CPU)
self.model.thinker.audio_tower = self.model.thinker.audio_tower.to(CPU)
if hasattr(self.model, "talker"):
self.model.talker = self.model.talker.to(CPU)
if hasattr(self.model, "code2wav"):
self.model.code2wav = self.model.code2wav.to(CPU)

self.model.thinker.visual.rotary_pos_emb = self.model.thinker.visual.rotary_pos_emb.to(CPU)
self.model.thinker.model.rotary_emb = self.model.thinker.model.rotary_emb.to(CPU)
@@ -91,4 +116,3 @@ def after_model_load(self, model, load_quantized_model=False):
self.processor = AutoProcessor.from_pretrained(self.model_local_path)

return model
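Both omni definitions guard the optional speech stack the same way; a condensed sketch of the hasattr pattern used above for talker/token2wav/code2wav (device handling simplified, function name illustrative):

import torch.nn as nn

def move_optional_speech_modules(model: nn.Module, device, names=("talker", "code2wav")):
    # Thinker-only exports omit these submodules, so every move is hasattr-guarded.
    for name in names:
        if hasattr(model, name):
            setattr(model, name, getattr(model, name).to(device))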

8 changes: 8 additions & 0 deletions gptqmodel/models/loader.py
@@ -181,6 +181,10 @@ def skip(*args, **kwargs):
# TODO FIX ME for `dynamic`, non-quantized modules should be in native type
dtype = auto_dtype(config=config, device=quantize_config.device, quant_inference=False)

if isinstance(dtype, torch.dtype) and getattr(config, "torch_dtype", None) != dtype:
# Align config metadata with the dtype we will materialize weights in.
config.torch_dtype = dtype

# enforce some values despite user specified
# non-quantized models are always loaded into cpu
model_init_kwargs["device_map"] = cpu_device_map
@@ -320,6 +324,10 @@ def from_quantized(
# TODO FIX ME for `dynamic`, non-quantized modules should be in native type
dtype = auto_dtype(config=config, device=device, quant_inference=True)

if isinstance(dtype, torch.dtype) and getattr(config, "torch_dtype", None) != dtype:
# Ensure flash attention kernels see an explicit dtype instead of relying on defaults.
config.torch_dtype = dtype

qcfg = QuantizeConfig.from_pretrained(model_local_path, **cached_file_kwargs, **kwargs)

if qcfg.quant_method == METHOD.AWQ and qcfg.format in [FORMAT.GEMV_FAST]:
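Both loader paths apply the same dtype alignment; in isolation the check looks roughly like this (a sketch, not the loader's exact helper):

import torch

def sync_config_dtype(config, dtype) -> None:
    # Only touch the config when auto_dtype resolved a concrete torch.dtype
    # that differs from (or is missing in) the config metadata, so attention
    # backends reading the config see an explicit dtype.
    if isinstance(dtype, torch.dtype) and getattr(config, "torch_dtype", None) != dtype:
        config.torch_dtype = dtype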
10 changes: 7 additions & 3 deletions gptqmodel/utils/offload.py
@@ -6,7 +6,7 @@
import contextlib
import os
import shutil
import threading
import sys
from typing import Iterable, List, Optional, Set, Tuple

import accelerate
@@ -21,10 +21,9 @@
from ..looper.named_module import NamedModule
from .device import get_device
from .torch import CPU, META
from .safe import ThreadSafe


_lock = threading.Lock()

# Patch to fix thread-unsafe accelerate.utils.modeling.clear_device_cache
def _fake_clear_device_cache(garbage_collection=False):
pass
@@ -96,6 +95,11 @@ def offload_to_disk(module: List[str] | nn.Module, model: nn.Module, disk_path:
# print("offload_disk: list item tree")
# print_module_tree(module)


# Serialize accelerate's disk hook mutations across threads.
_OFFLOAD_SAFE = ThreadSafe(sys.modules[__name__])
offload_to_disk = _OFFLOAD_SAFE.offload_to_disk

def _offload_disk(module: nn.Module, name: str, disk_path: str = "."):
if is_meta_module(module):
# print(f"[skip] '{name}' is on meta; leaving as-is")
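Here the module-level lock is replaced by a ThreadSafe proxy over the whole offload module, and the exported offload_to_disk is rebound to the serialized version. A minimal lock-proxy sketch, assuming ThreadSafe does little more than guard callable attributes with one shared lock (the real helper lives in gptqmodel.utils.safe and may do more):

import threading

class LockedModuleProxy:
    # Hypothetical stand-in for ThreadSafe: route attribute access to the
    # wrapped object and serialize every callable behind a single lock.
    def __init__(self, target):
        self._target = target
        self._lock = threading.RLock()

    def __getattr__(self, name):
        attr = getattr(self._target, name)
        if not callable(attr):
            return attr

        def serialized(*args, **kwargs):
            with self._lock:
                return attr(*args, **kwargs)

        return serialized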
2 changes: 1 addition & 1 deletion tests/models/test_qwen3_moe.py
@@ -18,7 +18,7 @@ class TestQwen3Moe(ModelTest):
DEBUG = True
ACT_GROUP_AWARE = True
DESC_ACT = False
DATASET_SIZE = 1024 * 8
DATASET_SIZE = 1024
DATASET_SORT = "desc"
QUANT_BATCH_SIZE = 4
