Merged

27 commits
9aef890
add Qwen3OmniMoe
LRL2-ModelCloud Sep 27, 2025
dd1d400
Fixing now: if offload_to_disk=False, we should not load the model in…
LRL2-ModelCloud Sep 27, 2025
964a92f
cleanup
LRL2-ModelCloud Sep 27, 2025
f60ae0c
fix nonetype
LRL2-ModelCloud Sep 27, 2025
5864c49
fix none
LRL2-ModelCloud Sep 27, 2025
7d54a92
add qwen3 omni moe support
LRL2-ModelCloud Sep 27, 2025
e42a21d
require attention_mask to be int type if model decoder layer type is …
LRL2-ModelCloud Sep 27, 2025
37165ba
Revert "require attention_mask to be int type if model decoder layer …
LRL2-ModelCloud Sep 27, 2025
f20f37d
for compatibility, attention_mask should be of type long.
LRL2-ModelCloud Sep 27, 2025
f7cd5ae
add support_offload_to_disk
LRL2-ModelCloud Sep 27, 2025
f3ff5c7
cleanup
LRL2-ModelCloud Sep 27, 2025
0765e28
cleanup
LRL2-ModelCloud Sep 27, 2025
7fda54a
typo
LRL2-ModelCloud Sep 27, 2025
0560d81
check none
LRL2-ModelCloud Sep 27, 2025
1c9c91a
cleanup
LRL2-ModelCloud Sep 27, 2025
9ef8679
offload to disk
LRL2-ModelCloud Sep 27, 2025
2f5b277
cleanup
LRL2-ModelCloud Sep 27, 2025
78f6758
update
LRL2-ModelCloud Sep 28, 2025
4ade4d0
mod filter_not_quantize_module
LRL2-ModelCloud Sep 28, 2025
59cf8c2
if offload_to_disk=False, we should not load the model into meta first.
LRL2-ModelCloud Sep 28, 2025
41fb9f4
mod base.py
LRL2-ModelCloud Sep 28, 2025
e3fc3ca
mod pre_quantize_generate_hook_end
LRL2-ModelCloud Sep 28, 2025
58e6d79
fix has no attr quantize_config
LRL2-ModelCloud Sep 28, 2025
acb5f80
fix filter
LRL2-ModelCloud Sep 28, 2025
5be44b7
check none
LRL2-ModelCloud Sep 28, 2025
c20d9d2
need config
LRL2-ModelCloud Sep 28, 2025
65eb981
fix filter_not_quantize_module
LRL2-ModelCloud Sep 28, 2025
14 changes: 8 additions & 6 deletions gptqmodel/looper/module_looper.py
@@ -108,8 +108,7 @@ def store_input_hook(module, args, kwargs):
layer_inputs.append(layer_input)

# Keyword arguments.
# TODO FIX ME..why is Qwen2_5OmniDecoderLayer harded here?
if kwargs.get("attention_mask") is not None and str(type(module)) != "<class 'transformers.models.qwen2_5_omni.modeling_qwen2_5_omni.Qwen2_5OmniDecoderLayer'>":
if kwargs.get("attention_mask") is not None and self.gptq_model.ATTENTION_MASKS_REQUIRED_FOR_INPUT:
attention_masks.append(kwargs["attention_mask"].to(device=data_device))
else:
attention_masks.append(None)
@@ -160,7 +159,7 @@ def store_input_hook(module, args, kwargs):

for example in calibration_data:
for k, v in example.items():
if str(type(layers[0])) == "<class 'transformers.models.qwen2_5_omni.modeling_qwen2_5_omni.Qwen2_5OmniDecoderLayer'>":
if self.gptq_model.ATTENTION_MASKS_REQUIRED_FOR_INPUT:
data_device = self.gptq_model.quantize_config.device
else:
data_device = self.gptq_model.quantize_config.device if k == "pixel_values" else cur_layer_device
@@ -175,8 +174,11 @@
v = v.unsqueeze(0)
example[k] = move_to(v, device=data_device)
try:
if str(type(layers[0])) == "<class 'transformers.models.qwen2_5_omni.modeling_qwen2_5_omni.Qwen2_5OmniDecoderLayer'>":
self.gptq_model.model.generate(**example, return_audio=False)
if self.gptq_model.ATTENTION_MASKS_DTYPE is torch.long:
example["attention_mask"] = example["attention_mask"].long()

if self.gptq_model.INPUT_EMBEDDING_EXTRA_ARGS:
self.gptq_model.model.generate(**example, **self.gptq_model.INPUT_EMBEDDING_EXTRA_ARGS)
else:
self.gptq_model.model(**example, use_cache=use_cache)
except StopForward:
@@ -240,7 +242,7 @@ def loop(self, fail_safe: bool = False, **kwargs):
for processor in self.processors:
processor.release_calibration_dataset()

layer_modules = self.gptq_model.simple_layer_modules(model_config=self.gptq_model.model.config)
layer_modules = self.gptq_model.simple_layer_modules(model_config=self.gptq_model.model.config, quantize_config=self.gptq_model.quantize_config)

if not self.gptq_model.quantize_config.true_sequential:
layer_modules = [sum(layer_modules, [])]
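
The hunks above replace the hard-coded Qwen2_5OmniDecoderLayer type check with three class-level flags that each model definition declares. A minimal sketch of how the looper is meant to consume them, with the capture loop reduced to a single helper; the flag names come from this PR, everything else is illustrative:

    import torch

    def feed_calibration_example(gptq_model, example, use_cache=False):
        # Models like Qwen2.5/Qwen3 Omni set ATTENTION_MASKS_DTYPE = torch.long, so the
        # mask is cast before the call; ATTENTION_MASKS_REQUIRED_FOR_INPUT likewise tells
        # store_input_hook to keep the captured masks instead of discarding them.
        if gptq_model.ATTENTION_MASKS_DTYPE is torch.long and example.get("attention_mask") is not None:
            example["attention_mask"] = example["attention_mask"].long()

        # Definitions that declare INPUT_EMBEDDING_EXTRA_ARGS (e.g. {"return_audio": False})
        # are driven through generate(); every other model takes a plain forward pass.
        if gptq_model.INPUT_EMBEDDING_EXTRA_ARGS:
            gptq_model.model.generate(**example, **gptq_model.INPUT_EMBEDDING_EXTRA_ARGS)
        else:
            gptq_model.model(**example, use_cache=use_cache)
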
2 changes: 2 additions & 0 deletions gptqmodel/models/auto.py
@@ -116,6 +116,7 @@
from .definitions.starcoder2 import Starcoder2QModel # noqa: E402
from .definitions.telechat2 import TeleChat2QModel
from .definitions.xverse import XverseQModel # noqa: E402
from .definitions.qwen3_omni_moe import Qwen3OmniMoeGPTQ

# make quants and inference more determinisitc
torch.manual_seed(787)
@@ -180,6 +181,7 @@
"qwen2_vl": Qwen2VLQModel,
"qwen2_5_vl": Qwen2_5_VLQModel,
"qwen2_5_omni": Qwen2_5_OmniGPTQ,
"qwen3_omni_moe": Qwen3OmniMoeGPTQ,
"dbrx": DbrxQModel,
"dbrx_converted": DbrxConvertedQModel,
"deepseek_v2": DeepSeekV2QModel,
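
Registration is keyed on the checkpoint's config.model_type. A hedged dispatch sketch, assuming the dict extended above is gptqmodel's MODEL_MAP registry and that Qwen3-Omni MoE checkpoints report model_type == "qwen3_omni_moe"; the model id is a placeholder:

    from transformers import AutoConfig
    from gptqmodel.models.auto import MODEL_MAP  # assumed name/location of the registry dict

    config = AutoConfig.from_pretrained("Qwen/Qwen3-Omni-30B-A3B-Instruct")  # placeholder model id
    qmodel_cls = MODEL_MAP[config.model_type]  # "qwen3_omni_moe" -> Qwen3OmniMoeGPTQ
    print(qmodel_cls.__name__)
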
52 changes: 40 additions & 12 deletions gptqmodel/models/base.py
@@ -22,7 +22,7 @@
from ..nn_modules.qlinear import BaseQuantLinear
from ..nn_modules.qlinear.torch import TorchQuantLinear
from ..quantization import QuantizeConfig
from ..quantization.config import FORMAT, METHOD, QUANTIZE_BLACK_LIST
from ..quantization.config import FORMAT, METHOD, QUANTIZE_BLACK_LIST, dynamic_get
from ..quantization.rotation.rotation import fuse_layer_norms, rotate_model
from ..utils.backend import BACKEND
from ..utils.data import collate_data
@@ -56,14 +56,6 @@ def classproperty(func):
return _ClassPropertyDescriptor(func)


def filter_not_quantize_module(layer_modules):
return [
[name for name in block if NOT_QUANTIZE_FLAG not in name]
for block in layer_modules
if any(NOT_QUANTIZE_FLAG not in name for name in block)
]


def generate_node_for_awq_scaling(inp, prev_op, module_kwargs, nodes_size, subset, module2inspect):
n = {
"prev_op": prev_op,
@@ -149,6 +141,12 @@ class BaseQModel(nn.Module):

support_batch_quantize = True

ATTENTION_MASKS_DTYPE = torch.bool # default to bool

ATTENTION_MASKS_REQUIRED_FOR_INPUT: bool = False

INPUT_EMBEDDING_EXTRA_ARGS = None

def __init__(
self,
model: PreTrainedModel,
@@ -275,21 +273,45 @@ def build_moe_modules_if_need(cls, model_config, layer_modules, is_awq_quantize:
def get_num_experts(cls, model_config):
if hasattr(model_config, "text_config"):
num_experts = getattr(model_config.text_config, cls.dynamic_expert_index)
elif hasattr(model_config, "thinker_config"):
num_experts = getattr(model_config.thinker_config.text_config, cls.dynamic_expert_index)
else:
num_experts = getattr(model_config, cls.dynamic_expert_index)
return num_experts

@classmethod
def filter_not_quantize_module(cls, layer_modules, quantize_config):
layer_modules = [
[name for name in block if NOT_QUANTIZE_FLAG not in name]
for block in layer_modules
]
layer_modules = [block for block in layer_modules if block]  # drop empty blocks

if getattr(quantize_config, "dynamic", None):
new_layer_modules = []
for modules in layer_modules:
filtered = [
m for m in modules
if dynamic_get(quantize_config.dynamic, module_name=m) is not False
]
if filtered:
new_layer_modules.append(filtered)
layer_modules = new_layer_modules

return layer_modules

# Inside each `LlamaDecoderLayer` layer are many internal modules
# List them in the order executed in model forward() code
# Many models have same execution order of: attention (q_k_v) projection, attention (output) projection, mlp (n) projections
@classmethod
def simple_layer_modules(cls, model_config, is_awq_quantize: bool = False):
def simple_layer_modules(cls, model_config, quantize_config, is_awq_quantize: bool = False):
layer_modules = cls.build_layer_modules(cls.module_tree)

layer_modules = cls.build_moe_modules_if_need(model_config, layer_modules, is_awq_quantize)

layer_modules = filter_not_quantize_module(layer_modules)
# print(f"simple_layer_modules layer_modules: {layer_modules}")
layer_modules = cls.filter_not_quantize_module(layer_modules, quantize_config)

print(f"simple_layer_modules layer_modules: {layer_modules}")
return layer_modules

@classmethod
@@ -1046,6 +1068,12 @@ def shell_module_materialize(
device: torch.device,
non_blocking: bool = False,
) -> torch.nn.Module:
if self.turtle_model is None:
if get_device(target_submodule) != device:
target_submodule.to(device)

return target_submodule

module = alias_from_turtle_for_submodule(
target_model=self.model,
turtle_model=self.turtle_model,
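
filter_not_quantize_module is now a classmethod and additionally honors per-module dynamic overrides: after stripping entries carrying the not-quantize flag (the ":!" suffix used in module_tree), any module that dynamic_get(...) resolves to False is dropped, and blocks left empty disappear entirely. A schematic illustration of the intended input/output, not an exact reproduction of gptqmodel's dynamic rule syntax:

    # Blocks as emitted by build_layer_modules(); ":!" marks not-to-quantize modules.
    layer_modules = [
        ["input_layernorm:!"],  # only flagged entries -> whole block is dropped
        ["self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj", "self_attn.o_proj"],
        ["mlp.experts.#.gate_proj", "mlp.experts.#.up_proj", "mlp.experts.#.down_proj"],
    ]

    # With no dynamic config, the result keeps the attention and expert blocks:
    #   [[q_proj, k_proj, v_proj, o_proj], [gate_proj, up_proj, down_proj]]
    #
    # With a dynamic rule that disables the expert projections (dynamic_get(...)
    # returns False for them), the expert block is filtered out as well and only
    # the attention block remains.
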
1 change: 1 addition & 0 deletions gptqmodel/models/definitions/__init__.py
@@ -63,3 +63,4 @@
from .klear import KlearQModel
from .llava_qwen2 import LlavaQwen2QModel
from .nemotron_h import NemotronHQModel
from .qwen3_omni_moe import Qwen3OmniMoeGPTQ
8 changes: 8 additions & 0 deletions gptqmodel/models/definitions/base_qwen2_5_omni.py
@@ -13,9 +13,17 @@
from ...utils.model import MODALITY
from .._const import CPU
from ..base import BaseQModel
import torch


class BaseQwen2_5_OmniGPTQ(BaseQModel):
ATTENTION_MASKS_REQUIRED_FOR_INPUT = True
ATTENTION_MASKS_DTYPE = torch.long

INPUT_EMBEDDING_EXTRA_ARGS = {
"return_audio": False,
}

loader = AutoModelForTextToWaveform

pre_lm_head_norm_module = "thinker.model.norm"
84 changes: 84 additions & 0 deletions gptqmodel/models/definitions/qwen3_omni_moe.py
@@ -0,0 +1,84 @@
# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai
# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai
# SPDX-License-Identifier: Apache-2.0
# Contact: qubitium@modelcloud.ai, x.com/qubitium

from transformers import AutoModelForTextToWaveform
from ..base import BaseQModel
from .._const import CPU
from ...utils.offload import offload_to_disk
import torch

class Qwen3OmniMoeGPTQ(BaseQModel):
ATTENTION_MASKS_REQUIRED_FOR_INPUT = True
ATTENTION_MASKS_DTYPE = torch.long

INPUT_EMBEDDING_EXTRA_ARGS = {
"return_audio": False,
}

loader = AutoModelForTextToWaveform

dynamic_expert_index = "num_experts"

pre_lm_head_norm_module = "thinker.model.norm"

module_tree = [
"thinker",
"model",
"layers",
"#",
{
"input_layernorm": ("input_layernorm:!",),
"self_attn": ("q_proj:0", "k_proj:0", "v_proj:0", "o_proj:1"),
"post_attention_layernorm": ("post_attention_layernorm:!",),
"mlp": {
"gate": ("gate",),
"experts": {
"#": ("gate_proj:0", "up_proj:0", "down_proj:1"),
},
},
}
]

def pre_quantize_generate_hook_start(self):
self.shell_module_materialize(self.model.thinker.model.embed_tokens, self.quantize_config.device)
self.shell_module_materialize(self.model.thinker.visual, self.quantize_config.device)
self.shell_module_materialize(self.model.thinker.audio_tower, self.quantize_config.device)
self.shell_module_materialize(self.model.thinker.visual.rotary_pos_emb, self.quantize_config.device)
self.shell_module_materialize(self.model.thinker.model.rotary_emb, self.quantize_config.device)

def pre_quantize_generate_hook_end(self):
if self.quantize_config.offload_to_disk:
offload_to_disk(model=self.model.thinker.model,
module=self.model.thinker.model.embed_tokens,
disk_path=self.quantize_config.offload_to_disk_path,
)

offload_to_disk(model=self.model.thinker,
module=self.model.thinker.visual,
disk_path=self.quantize_config.offload_to_disk_path,
)

offload_to_disk(model=self.model.thinker,
module=self.model.thinker.audio_tower,
disk_path=self.quantize_config.offload_to_disk_path,
)

offload_to_disk(model=self.model.thinker.visual,
module=self.model.thinker.visual.rotary_pos_emb,
disk_path=self.quantize_config.offload_to_disk_path,
)

offload_to_disk(model=self.model.thinker.model,
module=self.model.thinker.model.rotary_emb,
disk_path=self.quantize_config.offload_to_disk_path,
)
return

self.model.thinker.model.embed_tokens = self.model.thinker.model.embed_tokens.to(CPU)
self.model.thinker.visual = self.model.thinker.visual.to(CPU)
self.model.thinker.audio_tower = self.model.thinker.audio_tower.to(CPU)

self.model.thinker.visual.rotary_pos_emb = self.model.thinker.visual.rotary_pos_emb.to(CPU)
self.model.thinker.model.rotary_emb = self.model.thinker.model.rotary_emb.to(CPU)
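
With the definition above registered under "qwen3_omni_moe", quantization follows the standard GPTQModel flow. A hedged end-to-end sketch: the model id, calibration text, and output path are placeholders, and offload_to_disk=True simply exercises the disk-offload hooks implemented in pre_quantize_generate_hook_end:

    from gptqmodel import GPTQModel, QuantizeConfig

    quant_config = QuantizeConfig(
        bits=4,
        group_size=128,
        offload_to_disk=True,  # route non-quantized thinker submodules to disk between layers
    )

    model = GPTQModel.load("Qwen/Qwen3-Omni-30B-A3B-Instruct", quant_config)  # placeholder id
    model.quantize(["A short calibration sample."], batch_size=1)
    model.save("./qwen3-omni-moe-gptq-4bit")
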
37 changes: 22 additions & 15 deletions gptqmodel/models/loader.py
@@ -180,19 +180,26 @@ def skip(*args, **kwargs):
cls.before_model_load(cls, load_quantized_model=False)
from ..utils.hf import build_shell_model

#model = cls.loader.from_pretrained(model_local_path, config=config, **model_init_kwargs)
print("shell model-----------")
model = build_shell_model(cls.loader, config=config, **model_init_kwargs)
model._model_init_kwargs = model_init_kwargs

print_module_tree(model=model)
# enable mmap with low_cpu_mem_usage
turtle_model = cls.loader.from_pretrained(model_local_path, config=config, low_cpu_mem_usage=True, **model_init_kwargs)

# TODO FIX ME...temp store model_init args
turtle_model._model_init_kwargs = model_init_kwargs
# print("actual turtle model-----------")
# print_module_tree(model=turtle_model)
if quantize_config.offload_to_disk:
print("shell model-----------")
model = build_shell_model(cls.loader, config=config, **model_init_kwargs)
model._model_init_kwargs = model_init_kwargs
print_module_tree(model=model)

# enable mmap with low_cpu_mem_usage
turtle_model = cls.loader.from_pretrained(model_local_path, config=config, low_cpu_mem_usage=True, **model_init_kwargs)

# TODO FIX ME...temp store model_init args
turtle_model._model_init_kwargs = model_init_kwargs
# print("actual turtle model-----------")
# print_module_tree(model=turtle_model)
else:
print("loading model directly to CPU (not using meta device or turtle_model)-----------")
model = cls.loader.from_pretrained(model_local_path, config=config, **model_init_kwargs)
model._model_init_kwargs = model_init_kwargs
print_module_tree(model=model)

turtle_model = None

model_config = model.config.to_dict()
seq_len_keys = ["max_position_embeddings", "seq_length", "n_positions", "multimodal_max_length"]
@@ -204,7 +211,7 @@ def skip(*args, **kwargs):
model.seqlen = 4096

model.eval()
turtle_model.eval()
turtle_model.eval() if turtle_model is not None else None

tokenizer = AutoTokenizer.from_pretrained(pretrained_model_id_or_path, trust_remote_code=trust_remote_code)

@@ -462,7 +469,7 @@ def skip(*args, **kwargs):
continue

if not any(name.startswith(prefix) for prefix in cls.extract_layers_node()) or any(name.startswith(ignore_module) for ignore_module in ignore_modules) or all(
not name.endswith(ignore_module) for sublist in cls.simple_layer_modules(config) for ignore_module in sublist
not name.endswith(ignore_module) for sublist in cls.simple_layer_modules(config, qcfg) for ignore_module in sublist
):
# log non-lm-head quantized modules only
if name is not cls.lm_head:
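
The loader now builds the meta "shell" model plus a fully materialized turtle model only when offload_to_disk is enabled; otherwise it loads the weights directly and leaves turtle_model as None, which is why shell_module_materialize and alias_all_from_turtle_if_meta gained None guards. A condensed, standalone re-expression of that guard pattern (not the repository function itself):

    import torch

    def ensure_on_device(module: torch.nn.Module, device: torch.device, turtle_model=None):
        # Direct-load path introduced by this PR: with no turtle model, the module
        # already owns real weights and only needs a device move.
        if turtle_model is None:
            param = next(module.parameters(), None)
            if param is not None and param.device != device:
                module.to(device)
            return module
        # Offload path: real weights still live in the turtle model and must be
        # aliased/copied into the meta shell first (alias_from_turtle_for_submodule).
        raise NotImplementedError("turtle-model path shown in the diff above")
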
2 changes: 1 addition & 1 deletion gptqmodel/models/writer.py
@@ -443,7 +443,7 @@ def skip(*args, **kwargs):
continue

if any(name.startswith(ignore_module) for ignore_module in ignore_modules) or all(
not name.endswith(ignore_module) for sublist in self.simple_layer_modules(config) for ignore_module in sublist
not name.endswith(ignore_module) for sublist in self.simple_layer_modules(config, qcfg) for ignore_module in sublist
):
# log non-lm-head quantizerd modules only
if name is not self.lm_head:
3 changes: 3 additions & 0 deletions gptqmodel/utils/structure.py
@@ -608,6 +608,9 @@ def alias_all_from_turtle_if_meta(

Logs each swap via log.info().
"""
if turtle_model is None:
return 0

turtle_map = dict(turtle_model.named_modules())
swapped = 0
