diff --git a/gptqmodel/models/auto.py b/gptqmodel/models/auto.py
index 600d79451..d8d705064 100644
--- a/gptqmodel/models/auto.py
+++ b/gptqmodel/models/auto.py
@@ -215,6 +215,7 @@
     "exaone": ExaOneQModel,
     "grinmoe": GrinMoeQModel,
     "mllama": MLlamaQModel,
+    "marin": Qwen3QModel,
     "granite": LlamaQModel, # 100% llama clone
     "mobilellm": MobileLLMQModel,
     "hymba": HymbaQModel,
diff --git a/gptqmodel/models/definitions/qwen3.py b/gptqmodel/models/definitions/qwen3.py
index 0fd272c9c..c071dca9b 100644
--- a/gptqmodel/models/definitions/qwen3.py
+++ b/gptqmodel/models/definitions/qwen3.py
@@ -7,4 +7,27 @@
 
 
 class Qwen3QModel(LlamaQModel):
-    pass
+    """
+    Qwen3 inherits the Llama-style layout but inserts Q/K RMS norm layers
+    ahead of the attention projections. We mark those helper modules as
+    non-quantized so the layer walker captures the complete structure.
+    """
+
+    module_tree = [
+        "model",
+        "layers",
+        "#",
+        {
+            "input_layernorm": ("input_layernorm:!",),
+            "self_attn": (
+                "q_norm:!",
+                "k_norm:!",
+                "q_proj:0",
+                "k_proj:0",
+                "v_proj:0",
+                "o_proj:1",
+            ),
+            "post_attention_layernorm": ("post_attention_layernorm:!",),
+            "mlp": ("gate_proj:0", "up_proj:0", "down_proj:1"),
+        },
+    ]
diff --git a/tests/models/test_marin.py b/tests/models/test_marin.py
new file mode 100644
index 000000000..e701a4d5e
--- /dev/null
+++ b/tests/models/test_marin.py
@@ -0,0 +1,34 @@
+# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai
+# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai
+# SPDX-License-Identifier: Apache-2.0
+# Contact: qubitium@modelcloud.ai, x.com/qubitium
+
+from accelerate import init_empty_weights
+from transformers import AutoConfig, AutoModelForCausalLM
+
+from model_test import ModelTest
+
+from gptqmodel.models.definitions.qwen3 import Qwen3QModel
+from gptqmodel.quantization.config import VRAMStrategy
+
+
+class TestMarin(ModelTest):
+    NATIVE_MODEL_ID = "/monster/data/model/marin-32b-base"
+    VRAM_STRATEGY = VRAMStrategy.BALANCED
+    # Marin inherits Qwen3's backbone with QK-Norm attention.
+
+    def test_marin_module_tree(self):
+        config = AutoConfig.from_pretrained(self.NATIVE_MODEL_ID, trust_remote_code=True)
+        with init_empty_weights(include_buffers=True):
+            shell = AutoModelForCausalLM.from_config(config, trust_remote_code=True)
+
+        decoder_layer = shell.model.layers[0]
+        self.assertTrue(hasattr(decoder_layer.self_attn, "q_norm"))
+        self.assertTrue(hasattr(decoder_layer.self_attn, "k_norm"))
+        self.assertTrue(hasattr(decoder_layer.self_attn, "q_proj"))
+        self.assertTrue(hasattr(decoder_layer.self_attn, "o_proj"))
+        self.assertIn("q_norm:!", Qwen3QModel.module_tree[3]["self_attn"])
+        self.assertIn("k_norm:!", Qwen3QModel.module_tree[3]["self_attn"])
+
+    def test_marin(self):
+        self.quant_lm_eval()