diff --git a/gptqmodel/looper/module_looper.py b/gptqmodel/looper/module_looper.py
index 517d065fb..05a2ca9d2 100644
--- a/gptqmodel/looper/module_looper.py
+++ b/gptqmodel/looper/module_looper.py
@@ -108,8 +108,7 @@ def store_input_hook(module, args, kwargs):
             layer_inputs.append(layer_input)

             # Keyword arguments.
-            # TODO FIX ME..why is Qwen2_5OmniDecoderLayer harded here?
-            if kwargs.get("attention_mask") is not None and str(type(module)) != "":
+            if kwargs.get("attention_mask") is not None and self.gptq_model.ATTENTION_MASKS_REQUIRED_FOR_INPUT:
                 attention_masks.append(kwargs["attention_mask"].to(device=data_device))
             else:
                 attention_masks.append(None)
@@ -160,7 +159,7 @@ def store_input_hook(module, args, kwargs):

         for example in calibration_data:
             for k, v in example.items():
-                if str(type(layers[0])) == "":
+                if self.gptq_model.ATTENTION_MASKS_REQUIRED_FOR_INPUT:
                     data_device = self.gptq_model.quantize_config.device
                 else:
                     data_device = self.gptq_model.quantize_config.device if k == "pixel_values" else cur_layer_device
@@ -175,8 +174,11 @@ def store_input_hook(module, args, kwargs):
                     v = v.unsqueeze(0)
                 example[k] = move_to(v, device=data_device)
             try:
-                if str(type(layers[0])) == "":
-                    self.gptq_model.model.generate(**example, return_audio=False)
+                if self.gptq_model.ATTENTION_MASKS_DTYPE is torch.long:
+                    example["attention_mask"] = example["attention_mask"].long()
+
+                if self.gptq_model.INPUT_EMBEDDING_EXTRA_ARGS:
+                    self.gptq_model.model.generate(**example, **self.gptq_model.INPUT_EMBEDDING_EXTRA_ARGS)
                 else:
                     self.gptq_model.model(**example, use_cache=use_cache)
             except StopForward:
@@ -240,7 +242,7 @@ def loop(self, fail_safe: bool = False, **kwargs):
         for processor in self.processors:
             processor.release_calibration_dataset()

-        layer_modules = self.gptq_model.simple_layer_modules(model_config=self.gptq_model.model.config)
+        layer_modules = self.gptq_model.simple_layer_modules(model_config=self.gptq_model.model.config, quantize_config=self.gptq_model.quantize_config)

         if not self.gptq_model.quantize_config.true_sequential:
             layer_modules = [sum(layer_modules, [])]
diff --git a/gptqmodel/models/auto.py b/gptqmodel/models/auto.py
index da8c0729a..397f67caa 100644
--- a/gptqmodel/models/auto.py
+++ b/gptqmodel/models/auto.py
@@ -116,6 +116,7 @@
 from .definitions.starcoder2 import Starcoder2QModel  # noqa: E402
 from .definitions.telechat2 import TeleChat2QModel
 from .definitions.xverse import XverseQModel  # noqa: E402
+from .definitions.qwen3_omni_moe import Qwen3OmniMoeGPTQ

 # make quants and inference more determinisitc
 torch.manual_seed(787)
@@ -180,6 +181,7 @@
     "qwen2_vl": Qwen2VLQModel,
     "qwen2_5_vl": Qwen2_5_VLQModel,
     "qwen2_5_omni": Qwen2_5_OmniGPTQ,
+    "qwen3_omni_moe": Qwen3OmniMoeGPTQ,
     "dbrx": DbrxQModel,
     "dbrx_converted": DbrxConvertedQModel,
     "deepseek_v2": DeepSeekV2QModel,
diff --git a/gptqmodel/models/base.py b/gptqmodel/models/base.py
index 4ee513de5..654feb303 100644
--- a/gptqmodel/models/base.py
+++ b/gptqmodel/models/base.py
@@ -22,7 +22,7 @@
 from ..nn_modules.qlinear import BaseQuantLinear
 from ..nn_modules.qlinear.torch import TorchQuantLinear
 from ..quantization import QuantizeConfig
-from ..quantization.config import FORMAT, METHOD, QUANTIZE_BLACK_LIST
+from ..quantization.config import FORMAT, METHOD, QUANTIZE_BLACK_LIST, dynamic_get
 from ..quantization.rotation.rotation import fuse_layer_norms, rotate_model
 from ..utils.backend import BACKEND
 from ..utils.data import collate_data
@@ -56,14 +56,6 @@ def classproperty(func):
     return _ClassPropertyDescriptor(func)


-def filter_not_quantize_module(layer_modules):
-    return [
-        [name for name in block if NOT_QUANTIZE_FLAG not in name]
-        for block in layer_modules
-        if any(NOT_QUANTIZE_FLAG not in name for name in block)
-    ]
-
-
 def generate_node_for_awq_scaling(inp, prev_op, module_kwargs, nodes_size, subset, module2inspect):
     n = {
         "prev_op": prev_op,
@@ -149,6 +141,12 @@ class BaseQModel(nn.Module):

     support_batch_quantize = True

+    ATTENTION_MASKS_DTYPE = torch.bool  # default to bool
+
+    ATTENTION_MASKS_REQUIRED_FOR_INPUT: bool = False
+
+    INPUT_EMBEDDING_EXTRA_ARGS = None
+
     def __init__(
         self,
         model: PreTrainedModel,
@@ -275,21 +273,45 @@ def build_moe_modules_if_need(cls, model_config, layer_modules, is_awq_quantize:
     def get_num_experts(cls, model_config):
         if hasattr(model_config, "text_config"):
             num_experts = getattr(model_config.text_config, cls.dynamic_expert_index)
+        elif hasattr(model_config, "thinker_config"):
+            num_experts = getattr(model_config.thinker_config.text_config, cls.dynamic_expert_index)
         else:
             num_experts = getattr(model_config, cls.dynamic_expert_index)
         return num_experts

+    @classmethod
+    def filter_not_quantize_module(cls, layer_modules, quantize_config):
+        layer_modules = [
+            [name for name in block if NOT_QUANTIZE_FLAG not in name]
+            for block in layer_modules
+        ]
+        layer_modules = [block for block in layer_modules if block]  # drop empty blocks
+
+        if getattr(quantize_config, "dynamic", None):
+            new_layer_modules = []
+            for modules in layer_modules:
+                filtered = [
+                    m for m in modules
+                    if dynamic_get(quantize_config.dynamic, module_name=m) is not False
+                ]
+                if filtered:
+                    new_layer_modules.append(filtered)
+            layer_modules = new_layer_modules
+
+        return layer_modules
+
     # Inside each `LlamaDecoderLayer` layer are many internal modules
     # List them in the order executed in model forward() code
     # Many models have same execution order of: attention (q_k_v) projection, attention (output) projection, mlp (n) projections
     @classmethod
-    def simple_layer_modules(cls, model_config, is_awq_quantize: bool = False):
+    def simple_layer_modules(cls, model_config, quantize_config, is_awq_quantize: bool = False):
         layer_modules = cls.build_layer_modules(cls.module_tree)

         layer_modules = cls.build_moe_modules_if_need(model_config, layer_modules, is_awq_quantize)
-        layer_modules = filter_not_quantize_module(layer_modules)
-        # print(f"simple_layer_modules layer_modules: {layer_modules}")
+        layer_modules = cls.filter_not_quantize_module(layer_modules, quantize_config)
+
+        print(f"simple_layer_modules layer_modules: {layer_modules}")

         return layer_modules

     @classmethod
@@ -1046,6 +1068,12 @@ def shell_module_materialize(
         device: torch.device,
         non_blocking: bool = False,
     ) -> torch.nn.Module:
+        if self.turtle_model is None:
+            if get_device(target_submodule) != device:
+                target_submodule.to(device)
+
+            return target_submodule
+
         module = alias_from_turtle_for_submodule(
             target_model=self.model,
             turtle_model=self.turtle_model,
diff --git a/gptqmodel/models/definitions/__init__.py b/gptqmodel/models/definitions/__init__.py
index 64d608f67..2a9459aa4 100644
--- a/gptqmodel/models/definitions/__init__.py
+++ b/gptqmodel/models/definitions/__init__.py
@@ -63,3 +63,4 @@
 from .klear import KlearQModel
 from .llava_qwen2 import LlavaQwen2QModel
 from .nemotron_h import NemotronHQModel
+from .qwen3_omni_moe import Qwen3OmniMoeGPTQ
\ No newline at end of file
diff --git a/gptqmodel/models/definitions/base_qwen2_5_omni.py b/gptqmodel/models/definitions/base_qwen2_5_omni.py
index 16347813c..4fca6b67f 100644
--- a/gptqmodel/models/definitions/base_qwen2_5_omni.py
+++ b/gptqmodel/models/definitions/base_qwen2_5_omni.py
@@ -13,9 +13,17 @@
 from ...utils.model import MODALITY
 from .._const import CPU
 from ..base import BaseQModel
+import torch


 class BaseQwen2_5_OmniGPTQ(BaseQModel):
+    ATTENTION_MASKS_REQUIRED_FOR_INPUT = True
+    ATTENTION_MASKS_DTYPE = torch.long
+
+    INPUT_EMBEDDING_EXTRA_ARGS = {
+        "return_audio": False,
+    }
+
     loader = AutoModelForTextToWaveform

     pre_lm_head_norm_module = "thinker.model.norm"
diff --git a/gptqmodel/models/definitions/qwen3_omni_moe.py b/gptqmodel/models/definitions/qwen3_omni_moe.py
new file mode 100644
index 000000000..4487136cc
--- /dev/null
+++ b/gptqmodel/models/definitions/qwen3_omni_moe.py
@@ -0,0 +1,84 @@
+# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai
+# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai
+# SPDX-License-Identifier: Apache-2.0
+# Contact: qubitium@modelcloud.ai, x.com/qubitium
+
+from transformers import AutoModelForTextToWaveform
+from ..base import BaseQModel
+from .._const import CPU
+from ...utils.offload import offload_to_disk
+import torch
+
+class Qwen3OmniMoeGPTQ(BaseQModel):
+    ATTENTION_MASKS_REQUIRED_FOR_INPUT = True
+    ATTENTION_MASKS_DTYPE = torch.long
+
+    INPUT_EMBEDDING_EXTRA_ARGS = {
+        "return_audio": False,
+    }
+
+    loader = AutoModelForTextToWaveform
+
+    dynamic_expert_index = "num_experts"
+
+    pre_lm_head_norm_module = "thinker.model.norm"
+
+    module_tree = [
+        "thinker",
+        "model",
+        "layers",
+        "#",
+        {
+            "input_layernorm": ("input_layernorm:!",),
+            "self_attn": ("q_proj:0", "k_proj:0", "v_proj:0", "o_proj:1"),
+            "post_attention_layernorm": ("post_attention_layernorm:!",),
+            "mlp": {
+                "gate": ("gate",),
+                "experts": {
+                    "#": ("gate_proj:0", "up_proj:0", "down_proj:1"),
+                },
+            },
+        }
+    ]
+
+    def pre_quantize_generate_hook_start(self):
+        self.shell_module_materialize(self.model.thinker.model.embed_tokens, self.quantize_config.device)
+        self.shell_module_materialize(self.model.thinker.visual, self.quantize_config.device)
+        self.shell_module_materialize(self.model.thinker.audio_tower, self.quantize_config.device)
+        self.shell_module_materialize(self.model.thinker.visual.rotary_pos_emb, self.quantize_config.device)
+        self.shell_module_materialize(self.model.thinker.model.rotary_emb, self.quantize_config.device)
+
+    def pre_quantize_generate_hook_end(self):
+        if self.quantize_config.offload_to_disk:
+            offload_to_disk(model=self.model.thinker.model,
+                            module=self.model.thinker.model.embed_tokens,
+                            disk_path=self.quantize_config.offload_to_disk_path,
+                            )
+
+            offload_to_disk(model=self.model.thinker,
+                            module=self.model.thinker.visual,
+                            disk_path=self.quantize_config.offload_to_disk_path,
+                            )
+
+            offload_to_disk(model=self.model.thinker,
+                            module=self.model.thinker.audio_tower,
+                            disk_path=self.quantize_config.offload_to_disk_path,
+                            )
+
+            offload_to_disk(model=self.model.thinker.visual,
+                            module=self.model.thinker.visual.rotary_pos_emb,
+                            disk_path=self.quantize_config.offload_to_disk_path,
+                            )
+
+            offload_to_disk(model=self.model.thinker.model,
+                            module=self.model.thinker.model.rotary_emb,
+                            disk_path=self.quantize_config.offload_to_disk_path,
+                            )
+            return
+
+        self.model.thinker.model.embed_tokens = self.model.thinker.model.embed_tokens.to(CPU)
+        self.model.thinker.visual = self.model.thinker.visual.to(CPU)
+        self.model.thinker.audio_tower = self.model.thinker.audio_tower.to(CPU)
+
+        self.model.thinker.visual.rotary_pos_emb = self.model.thinker.visual.rotary_pos_emb.to(CPU)
+        self.model.thinker.model.rotary_emb = self.model.thinker.model.rotary_emb.to(CPU)
diff --git a/gptqmodel/models/loader.py b/gptqmodel/models/loader.py
index 90993d7c1..2576583bb 100644
--- a/gptqmodel/models/loader.py
+++ b/gptqmodel/models/loader.py
@@ -180,19 +180,26 @@ def skip(*args, **kwargs):
         cls.before_model_load(cls, load_quantized_model=False)
         from ..utils.hf import build_shell_model

-        #model = cls.loader.from_pretrained(model_local_path, config=config, **model_init_kwargs)
-        print("shell model-----------")
-        model = build_shell_model(cls.loader, config=config, **model_init_kwargs)
-        model._model_init_kwargs = model_init_kwargs
-
-        print_module_tree(model=model)
-        # enable mmap with low_cpu_mem_usage
-        turtle_model = cls.loader.from_pretrained(model_local_path, config=config, low_cpu_mem_usage=True, **model_init_kwargs)
-
-        # TODO FIX ME...temp store model_init args
-        turtle_model._model_init_kwargs = model_init_kwargs
-        # print("actual turtle model-----------")
-        # print_module_tree(model=turtle_model)
+        if quantize_config.offload_to_disk:
+            print("shell model-----------")
+            model = build_shell_model(cls.loader, config=config, **model_init_kwargs)
+            model._model_init_kwargs = model_init_kwargs
+            print_module_tree(model=model)
+
+            # enable mmap with low_cpu_mem_usage
+            turtle_model = cls.loader.from_pretrained(model_local_path, config=config, low_cpu_mem_usage=True, **model_init_kwargs)
+
+            # TODO FIX ME...temp store model_init args
+            turtle_model._model_init_kwargs = model_init_kwargs
+            # print("actual turtle model-----------")
+            # print_module_tree(model=turtle_model)
+        else:
+            print("loading model directly to CPU (not using meta device or turtle_model)-----------")
+            model = cls.loader.from_pretrained(model_local_path, config=config, **model_init_kwargs)
+            model._model_init_kwargs = model_init_kwargs
+            print_module_tree(model=model)
+
+            turtle_model = None

         model_config = model.config.to_dict()
         seq_len_keys = ["max_position_embeddings", "seq_length", "n_positions", "multimodal_max_length"]
@@ -204,7 +211,7 @@ def skip(*args, **kwargs):
             model.seqlen = 4096

         model.eval()
-        turtle_model.eval()
+        turtle_model.eval() if turtle_model is not None else None

         tokenizer = AutoTokenizer.from_pretrained(pretrained_model_id_or_path, trust_remote_code=trust_remote_code)
@@ -462,7 +469,7 @@ def skip(*args, **kwargs):
                 continue

             if not any(name.startswith(prefix) for prefix in cls.extract_layers_node()) or any(name.startswith(ignore_module) for ignore_module in ignore_modules) or all(
-                not name.endswith(ignore_module) for sublist in cls.simple_layer_modules(config) for ignore_module in sublist
+                not name.endswith(ignore_module) for sublist in cls.simple_layer_modules(config, qcfg) for ignore_module in sublist
             ):
                 # log non-lm-head quantized modules only
                 if name is not cls.lm_head:
diff --git a/gptqmodel/models/writer.py b/gptqmodel/models/writer.py
index 9aae554f9..eccdbb264 100644
--- a/gptqmodel/models/writer.py
+++ b/gptqmodel/models/writer.py
@@ -443,7 +443,7 @@ def skip(*args, **kwargs):
                 continue

             if any(name.startswith(ignore_module) for ignore_module in ignore_modules) or all(
-                not name.endswith(ignore_module) for sublist in self.simple_layer_modules(config) for ignore_module in sublist
+                not name.endswith(ignore_module) for sublist in self.simple_layer_modules(config, qcfg) for ignore_module in sublist
             ):
                 # log non-lm-head quantizerd modules only
                 if name is not self.lm_head:
diff --git a/gptqmodel/utils/structure.py b/gptqmodel/utils/structure.py
index f187ee096..44b57f1fa 100644
--- a/gptqmodel/utils/structure.py
+++ b/gptqmodel/utils/structure.py
@@ -608,6 +608,9 @@ def alias_all_from_turtle_if_meta(

     Logs each swap via log.info().
     """
+    if turtle_model is None:
+        return 0
+
     turtle_map = dict(turtle_model.named_modules())
     swapped = 0
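For reference, a minimal usage sketch of how the new `qwen3_omni_moe` mapping would be exercised, assuming the existing `GPTQModel.load()` / `quantize()` / `save()` flow; the model id and calibration texts below are placeholders, and `offload_to_disk` is the `QuantizeConfig` flag this diff branches on in `loader.py`:

```python
# Hedged usage sketch (not part of this diff). Placeholder model id and calibration data.
from gptqmodel import GPTQModel, QuantizeConfig

quant_config = QuantizeConfig(
    bits=4,
    group_size=128,
    offload_to_disk=True,  # takes the shell/turtle-model branch added in loader.py
)

# A few representative texts are enough for GPTQ calibration in this sketch.
calibration = [
    "Qwen3-Omni MoE routes tokens through a mixture of experts in the thinker decoder.",
    "GPTQ quantizes each decoder layer sequentially using cached layer inputs.",
]

# A config.json model_type of "qwen3_omni_moe" resolves to Qwen3OmniMoeGPTQ via the registry in auto.py.
model = GPTQModel.load("Qwen/Qwen3-Omni-30B-A3B-Instruct", quant_config)
model.quantize(calibration, batch_size=1)
model.save("Qwen3-Omni-30B-A3B-Instruct-GPTQ-4bit")
```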