From de4a0e91d30f5a7aab8056f6c94019fd77370659 Mon Sep 17 00:00:00 2001 From: LRL2-ModelCloud Date: Thu, 30 Oct 2025 10:36:38 +0800 Subject: [PATCH 1/3] fix qwen2 omni --- gptqmodel/models/definitions/base_qwen2_5_omni.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/gptqmodel/models/definitions/base_qwen2_5_omni.py b/gptqmodel/models/definitions/base_qwen2_5_omni.py index 4dcdf72ff..6fb7cd0bf 100644 --- a/gptqmodel/models/definitions/base_qwen2_5_omni.py +++ b/gptqmodel/models/definitions/base_qwen2_5_omni.py @@ -62,7 +62,8 @@ def pre_quantize_generate_hook_start(self): if hasattr(self.model, "token2wav"): self.shell_module_materialize(self.model.token2wav, self.quantize_config.device) for layer in self.model.thinker.model.layers: - self.shell_module_materialize(layer.self_attn.rotary_emb, self.quantize_config.device) + if hasattr(layer.self_attn, "rotary_emb"): + self.shell_module_materialize(layer.self_attn.rotary_emb, self.quantize_config.device) def pre_quantize_generate_hook_end(self): if self.quantize_config.offload_to_disk: @@ -103,7 +104,11 @@ def pre_quantize_generate_hook_end(self): ) for layer in self.model.thinker.model.layers: - layer.self_attn.rotary_emb = layer.self_attn.rotary_emb.to(CPU) + if hasattr(layer.self_attn, "rotary_emb"): + offload_to_disk(model=self.model.thinker.model, + module=layer.self_attn.rotary_emb, + disk_path=self.quantize_config.offload_to_disk_path, + ) return From c81b95bc744ae84f2846b8ecc5a5a1a0b5d346aa Mon Sep 17 00:00:00 2001 From: LRL2-ModelCloud Date: Thu, 30 Oct 2025 11:26:03 +0800 Subject: [PATCH 2/3] llama4 not support flash_attn_2 --- tests/models/test_llama4.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/models/test_llama4.py b/tests/models/test_llama4.py index 8a7a8b1e5..940112dcd 100644 --- a/tests/models/test_llama4.py +++ b/tests/models/test_llama4.py @@ -18,6 +18,7 @@ class TestLlama4(ModelTest): }, } TRUST_REMOTE_CODE = False + USE_FLASH_ATTN = False def 
test_llama4(self): self.quant_lm_eval() From 81e57b2ba3e1d2304a8a1cc5a6d51ca4358542e Mon Sep 17 00:00:00 2001 From: LRL2-ModelCloud Date: Thu, 30 Oct 2025 11:27:46 +0800 Subject: [PATCH 3/3] catch NotImplementedError --- gptqmodel/models/loader.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/gptqmodel/models/loader.py b/gptqmodel/models/loader.py index 28baf7fe2..017d0719f 100644 --- a/gptqmodel/models/loader.py +++ b/gptqmodel/models/loader.py @@ -617,7 +617,11 @@ def assign(mod, device_id): # 1–3. Assign input embeddings, layers, and ignored modules # ------------------------------------------------------------- # Input embeddings → GPU 0 - in_emb = model.get_input_embeddings() if hasattr(model, "get_input_embeddings") else None + try: + in_emb = model.get_input_embeddings() + except NotImplementedError: + log.warning("Model does not implement get_input_embeddings. Skipping input embeddings assignment.") + in_emb = None assign(in_emb, device_ids[0]) # Alternating layers