Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 45 additions & 12 deletions gptqmodel/models/definitions/base_qwen2_5_omni.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@
# Contact: qubitium@modelcloud.ai, x.com/qubitium

from typing import Dict, Optional

import os
import torch
from PIL import Image
from transformers import AutoModelForTextToWaveform, AutoProcessor, ProcessorMixin

from ...utils.calibration import batched
from ...utils.image import extract_vision_info, fetch_image
from ...utils.model import MODALITY
from ...utils.offload import offload_to_disk
from .._const import CPU
from ..base import BaseQModel

Expand Down Expand Up @@ -46,18 +46,50 @@ class BaseQwen2_5_OmniGPTQ(BaseQModel):
require_load_processor = True

def pre_quantize_generate_hook_start(self):

self.model.thinker.model.embed_tokens = self.model.thinker.model.embed_tokens.to(self.quantize_config.device)
self.model.thinker.visual = self.model.thinker.visual.to(self.quantize_config.device)
self.model.thinker.audio_tower = self.model.thinker.audio_tower.to(self.quantize_config.device)

self.model.thinker.visual.rotary_pos_emb = self.model.thinker.visual.rotary_pos_emb.to(self.quantize_config.device)
self.model.thinker.model.rotary_emb = self.model.thinker.model.rotary_emb.to(self.quantize_config.device)

# load speaker
spk_path = os.path.join(self.model_local_path, "spk_dict.pt")
self.model.load_speakers(spk_path)

self.shell_module_materialize(self.model.thinker.model.embed_tokens, self.quantize_config.device)
self.shell_module_materialize(self.model.thinker.visual, self.quantize_config.device)
self.shell_module_materialize(self.model.thinker.audio_tower, self.quantize_config.device)
self.shell_module_materialize(self.model.thinker.visual.rotary_pos_emb, self.quantize_config.device)
self.shell_module_materialize(self.model.thinker.model.rotary_emb, self.quantize_config.device)
for layer in self.model.thinker.model.layers:
layer.self_attn.rotary_emb = layer.self_attn.rotary_emb.to(self.quantize_config.device)
self.shell_module_materialize(layer.self_attn.rotary_emb, self.quantize_config.device)

def pre_quantize_generate_hook_end(self):
if self.quantize_config.offload_to_disk:
offload_to_disk(model=self.model.thinker.model,
module=self.model.thinker.model.embed_tokens,
disk_path=self.quantize_config.offload_to_disk_path,
)

offload_to_disk(model=self.model.thinker,
module=self.model.thinker.visual,
disk_path=self.quantize_config.offload_to_disk_path,
)

offload_to_disk(model=self.model.thinker,
module=self.model.thinker.audio_tower,
disk_path=self.quantize_config.offload_to_disk_path,
)

offload_to_disk(model=self.model.thinker.visual,
module=self.model.thinker.visual.rotary_pos_emb,
disk_path=self.quantize_config.offload_to_disk_path,
)

offload_to_disk(model=self.model.thinker.model,
module=self.model.thinker.model.rotary_emb,
disk_path=self.quantize_config.offload_to_disk_path,
)

for layer in self.model.thinker.model.layers:
layer.self_attn.rotary_emb = layer.self_attn.rotary_emb.to(CPU)

return

self.model.thinker.model.embed_tokens = self.model.thinker.model.embed_tokens.to(CPU)
self.model.thinker.visual = self.model.thinker.visual.to(CPU)
self.model.thinker.audio_tower = self.model.thinker.audio_tower.to(CPU)
Expand All @@ -67,6 +99,7 @@ def pre_quantize_generate_hook_end(self):

for layer in self.model.thinker.model.layers:
layer.self_attn.rotary_emb = layer.self_attn.rotary_emb.to(CPU)

@staticmethod
def process_vision_info(
conversations: list[dict] | list[list[dict]],
Expand All @@ -89,7 +122,7 @@ def preprocess_dataset(self, sample: Dict) -> Dict:
def load_processor(self) -> ProcessorMixin:
return AutoProcessor.from_pretrained(self.model_local_path)

def prepare_dataset(self, calibration_dataset, calibration_dataset_concat_size=None, batch_size: int = 1):
def prepare_dataset(self, calibration_dataset, calibration_dataset_concat_size=None, batch_size: int = 1, **kwargs):
processor = self.load_processor()
calib_data = []
for batch in batched(calibration_dataset, batch_size, process_func=self.preprocess_dataset):
Expand Down
3 changes: 2 additions & 1 deletion gptqmodel/models/definitions/ovis.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,8 @@ def prepare_dataset(
calibration_dataset,
calibration_dataset_concat_size,
batch_size: int = 1,
tokenizer=None, ):
tokenizer=None,
**kwargs):
calib_data = []
for batch in batched(calibration_dataset, batch_size, self.preprocess_dataset):
pixel_values, input_ids, labels = tuple([instance[key] for instance in batch]
Expand Down
10 changes: 9 additions & 1 deletion gptqmodel/models/definitions/qwen3_omni_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# Contact: qubitium@modelcloud.ai, x.com/qubitium

import torch
from transformers import AutoModelForTextToWaveform
from transformers import AutoModelForTextToWaveform, AutoProcessor

from ...utils.offload import offload_to_disk
from .._const import CPU
Expand Down Expand Up @@ -84,3 +84,11 @@ def pre_quantize_generate_hook_end(self):

self.model.thinker.visual.rotary_pos_emb = self.model.thinker.visual.rotary_pos_emb.to(CPU)
self.model.thinker.model.rotary_emb = self.model.thinker.model.rotary_emb.to(CPU)

def after_model_load(self, model, load_quantized_model=False):
# need to load processor for save processor_config and chat_template
if not load_quantized_model:
self.processor = AutoProcessor.from_pretrained(self.model_local_path)

return model

26 changes: 26 additions & 0 deletions tests/models/test_qwen3_omni.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai
# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai
# SPDX-License-Identifier: Apache-2.0
# Contact: qubitium@modelcloud.ai, x.com/qubitium

from model_test import ModelTest


class TestQwen3Omni(ModelTest):
NATIVE_MODEL_ID = "/monster/data/model/Qwen3-Omni-30B-A3B-Instruct/"
QUANT_ARC_MAX_DELTA_FLOOR_PERCENT = 0.2
NATIVE_ARC_CHALLENGE_ACC = 0.2739
NATIVE_ARC_CHALLENGE_ACC_NORM = 0.3055
# TRUST_REMOTE_CODE = False
APPLY_CHAT_TEMPLATE = True
# EVAL_BATCH_SIZE = 6
V2 = False
DEBUG = True
ACT_GROUP_AWARE = True
DESC_ACT = False
DATASET_SIZE = 1024
DATASET_SORT = "desc"
QUANT_BATCH_SIZE = 1

def test_omni(self):
self.quant_lm_eval()