Skip to content
Merged

Marin #2139

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions gptqmodel/models/auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,7 @@
"exaone": ExaOneQModel,
"grinmoe": GrinMoeQModel,
"mllama": MLlamaQModel,
"marin": Qwen3QModel,
"granite": LlamaQModel, # 100% llama clone
"mobilellm": MobileLLMQModel,
"hymba": HymbaQModel,
Expand Down
25 changes: 24 additions & 1 deletion gptqmodel/models/definitions/qwen3.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,27 @@


class Qwen3QModel(LlamaQModel):
pass
"""
Qwen3 inherits the Llama-style layout but inserts Q/K RMS norm layers
ahead of the attention projections. We mark those helper modules as
non-quantized so the layer walker captures the complete structure.
"""

module_tree = [
"model",
"layers",
"#",
{
"input_layernorm": ("input_layernorm:!",),
"self_attn": (
"q_norm:!",
"k_norm:!",
"q_proj:0",
"k_proj:0",
"v_proj:0",
"o_proj:1",
),
"post_attention_layernorm": ("post_attention_layernorm:!",),
"mlp": ("gate_proj:0", "up_proj:0", "down_proj:1"),
},
]
34 changes: 34 additions & 0 deletions tests/models/test_marin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai
# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai
# SPDX-License-Identifier: Apache-2.0
# Contact: qubitium@modelcloud.ai, x.com/qubitium

from accelerate import init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM

from model_test import ModelTest

from gptqmodel.models.definitions.qwen3 import Qwen3QModel
from gptqmodel.quantization.config import VRAMStrategy


class TestMarin(ModelTest):
NATIVE_MODEL_ID = "/monster/data/model/marin-32b-base"
VRAM_STRATEGY = VRAMStrategy.BALANCED
# Marin inherits Qwen3's backbone with QK-Norm attention.

def test_marin_module_tree(self):
config = AutoConfig.from_pretrained(self.NATIVE_MODEL_ID, trust_remote_code=True)
with init_empty_weights(include_buffers=True):
shell = AutoModelForCausalLM.from_config(config, trust_remote_code=True)

decoder_layer = shell.model.layers[0]
self.assertTrue(hasattr(decoder_layer.self_attn, "q_norm"))
self.assertTrue(hasattr(decoder_layer.self_attn, "k_norm"))
self.assertTrue(hasattr(decoder_layer.self_attn, "q_proj"))
self.assertTrue(hasattr(decoder_layer.self_attn, "o_proj"))
self.assertIn("q_norm:!", Qwen3QModel.module_tree[3]["self_attn"])
self.assertIn("k_norm:!", Qwen3QModel.module_tree[3]["self_attn"])

def test_marin(self):
self.quant_lm_eval()