diff --git a/gptqmodel/models/definitions/base_qwen2_5_omni.py b/gptqmodel/models/definitions/base_qwen2_5_omni.py
index 227a7bd56..8827b69dc 100644
--- a/gptqmodel/models/definitions/base_qwen2_5_omni.py
+++ b/gptqmodel/models/definitions/base_qwen2_5_omni.py
@@ -4,14 +4,14 @@
 # Contact: qubitium@modelcloud.ai, x.com/qubitium
 
 from typing import Dict, Optional
-
+import os
 import torch
 from PIL import Image
 from transformers import AutoModelForTextToWaveform, AutoProcessor, ProcessorMixin
-
 from ...utils.calibration import batched
 from ...utils.image import extract_vision_info, fetch_image
 from ...utils.model import MODALITY
+from ...utils.offload import offload_to_disk
 from .._const import CPU
 from ..base import BaseQModel
 
@@ -46,18 +46,50 @@ class BaseQwen2_5_OmniGPTQ(BaseQModel):
     require_load_processor = True
 
     def pre_quantize_generate_hook_start(self):
-
-        self.model.thinker.model.embed_tokens = self.model.thinker.model.embed_tokens.to(self.quantize_config.device)
-        self.model.thinker.visual = self.model.thinker.visual.to(self.quantize_config.device)
-        self.model.thinker.audio_tower = self.model.thinker.audio_tower.to(self.quantize_config.device)
-
-        self.model.thinker.visual.rotary_pos_emb = self.model.thinker.visual.rotary_pos_emb.to(self.quantize_config.device)
-        self.model.thinker.model.rotary_emb = self.model.thinker.model.rotary_emb.to(self.quantize_config.device)
-
+        # load the speaker dictionary (spk_dict.pt)
+        spk_path = os.path.join(self.model_local_path, "spk_dict.pt")
+        self.model.load_speakers(spk_path)
+
+        self.shell_module_materialize(self.model.thinker.model.embed_tokens, self.quantize_config.device)
+        self.shell_module_materialize(self.model.thinker.visual, self.quantize_config.device)
+        self.shell_module_materialize(self.model.thinker.audio_tower, self.quantize_config.device)
+        self.shell_module_materialize(self.model.thinker.visual.rotary_pos_emb, self.quantize_config.device)
+        self.shell_module_materialize(self.model.thinker.model.rotary_emb, self.quantize_config.device)
         for layer in self.model.thinker.model.layers:
-            layer.self_attn.rotary_emb = layer.self_attn.rotary_emb.to(self.quantize_config.device)
+            self.shell_module_materialize(layer.self_attn.rotary_emb, self.quantize_config.device)
 
     def pre_quantize_generate_hook_end(self):
+        if self.quantize_config.offload_to_disk:
+            offload_to_disk(model=self.model.thinker.model,
+                            module=self.model.thinker.model.embed_tokens,
+                            disk_path=self.quantize_config.offload_to_disk_path,
+                            )
+
+            offload_to_disk(model=self.model.thinker,
+                            module=self.model.thinker.visual,
+                            disk_path=self.quantize_config.offload_to_disk_path,
+                            )
+
+            offload_to_disk(model=self.model.thinker,
+                            module=self.model.thinker.audio_tower,
+                            disk_path=self.quantize_config.offload_to_disk_path,
+                            )
+
+            offload_to_disk(model=self.model.thinker.visual,
+                            module=self.model.thinker.visual.rotary_pos_emb,
+                            disk_path=self.quantize_config.offload_to_disk_path,
+                            )
+
+            offload_to_disk(model=self.model.thinker.model,
+                            module=self.model.thinker.model.rotary_emb,
+                            disk_path=self.quantize_config.offload_to_disk_path,
+                            )
+
+            for layer in self.model.thinker.model.layers:
+                layer.self_attn.rotary_emb = layer.self_attn.rotary_emb.to(CPU)
+
+            return
+
         self.model.thinker.model.embed_tokens = self.model.thinker.model.embed_tokens.to(CPU)
         self.model.thinker.visual = self.model.thinker.visual.to(CPU)
         self.model.thinker.audio_tower = self.model.thinker.audio_tower.to(CPU)
@@ -67,6 +99,7 @@ def pre_quantize_generate_hook_end(self):
         for layer in self.model.thinker.model.layers:
             layer.self_attn.rotary_emb = layer.self_attn.rotary_emb.to(CPU)
 
+    @staticmethod
     def process_vision_info(
         conversations: list[dict] | list[list[dict]],
@@ -89,7 +122,7 @@ def preprocess_dataset(self, sample: Dict) -> Dict:
     def load_processor(self) -> ProcessorMixin:
         return AutoProcessor.from_pretrained(self.model_local_path)
 
-    def prepare_dataset(self, calibration_dataset, calibration_dataset_concat_size=None, batch_size: int = 1):
+    def prepare_dataset(self, calibration_dataset, calibration_dataset_concat_size=None, batch_size: int = 1, **kwargs):
         processor = self.load_processor()
         calib_data = []
         for batch in batched(calibration_dataset, batch_size, process_func=self.preprocess_dataset):
diff --git a/gptqmodel/models/definitions/ovis.py b/gptqmodel/models/definitions/ovis.py
index 0675a2ee6..8485975c1 100644
--- a/gptqmodel/models/definitions/ovis.py
+++ b/gptqmodel/models/definitions/ovis.py
@@ -93,7 +93,8 @@ def prepare_dataset(
         calibration_dataset,
         calibration_dataset_concat_size,
         batch_size: int = 1,
-        tokenizer=None, ):
+        tokenizer=None,
+        **kwargs):
         calib_data = []
         for batch in batched(calibration_dataset, batch_size, self.preprocess_dataset):
             pixel_values, input_ids, labels = tuple([instance[key] for instance in batch]
diff --git a/gptqmodel/models/definitions/qwen3_omni_moe.py b/gptqmodel/models/definitions/qwen3_omni_moe.py
index 3b071bcd0..9d7da24c0 100644
--- a/gptqmodel/models/definitions/qwen3_omni_moe.py
+++ b/gptqmodel/models/definitions/qwen3_omni_moe.py
@@ -4,7 +4,7 @@
 # Contact: qubitium@modelcloud.ai, x.com/qubitium
 
 import torch
-from transformers import AutoModelForTextToWaveform
+from transformers import AutoModelForTextToWaveform, AutoProcessor
 
 from ...utils.offload import offload_to_disk
 from .._const import CPU
@@ -84,3 +84,11 @@ def pre_quantize_generate_hook_end(self):
         self.model.thinker.visual.rotary_pos_emb = self.model.thinker.visual.rotary_pos_emb.to(CPU)
         self.model.thinker.model.rotary_emb = self.model.thinker.model.rotary_emb.to(CPU)
+
+    def after_model_load(self, model, load_quantized_model=False):
+        # the processor must be loaded so that processor_config and chat_template can be saved
+        if not load_quantized_model:
+            self.processor = AutoProcessor.from_pretrained(self.model_local_path)
+
+        return model
+
diff --git a/tests/models/test_qwen3_omni.py b/tests/models/test_qwen3_omni.py
new file mode 100644
index 000000000..d386a424a
--- /dev/null
+++ b/tests/models/test_qwen3_omni.py
@@ -0,0 +1,26 @@
+# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai
+# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai
+# SPDX-License-Identifier: Apache-2.0
+# Contact: qubitium@modelcloud.ai, x.com/qubitium
+
+from model_test import ModelTest
+
+
+class TestQwen3Omni(ModelTest):
+    NATIVE_MODEL_ID = "/monster/data/model/Qwen3-Omni-30B-A3B-Instruct/"
+    QUANT_ARC_MAX_DELTA_FLOOR_PERCENT = 0.2
+    NATIVE_ARC_CHALLENGE_ACC = 0.2739
+    NATIVE_ARC_CHALLENGE_ACC_NORM = 0.3055
+    # TRUST_REMOTE_CODE = False
+    APPLY_CHAT_TEMPLATE = True
+    # EVAL_BATCH_SIZE = 6
+    V2 = False
+    DEBUG = True
+    ACT_GROUP_AWARE = True
+    DESC_ACT = False
+    DATASET_SIZE = 1024
+    DATASET_SORT = "desc"
+    QUANT_BATCH_SIZE = 1
+
+    def test_omni(self):
+        self.quant_lm_eval()
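
Note: a minimal, hypothetical usage sketch of the disk-offload path added by this patch, not part of the diff itself. GPTQModel.load, quantize, and save are the library's public entry points; the offload_to_disk and offload_to_disk_path fields on QuantizeConfig are assumed to correspond to the self.quantize_config attributes read in pre_quantize_generate_hook_end above, and all concrete values are examples.

    from gptqmodel import GPTQModel, QuantizeConfig

    # toy text-only calibration samples; a real Qwen3-Omni run would feed
    # multimodal chat samples through prepare_dataset/preprocess_dataset
    calibration_dataset = ["gptqmodel quantizes Qwen3-Omni."] * 256

    # offload_to_disk / offload_to_disk_path are the config fields consulted by
    # the new branch in pre_quantize_generate_hook_end (values here are examples)
    qcfg = QuantizeConfig(
        bits=4,
        group_size=128,
        offload_to_disk=True,
        offload_to_disk_path="./gptq_offload",
    )

    model = GPTQModel.load("Qwen/Qwen3-Omni-30B-A3B-Instruct", qcfg)
    model.quantize(calibration_dataset, batch_size=1)
    model.save("Qwen3-Omni-30B-A3B-Instruct-GPTQ-4bit")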