From f698890eb0f1812e47f59385f582f11bf5780205 Mon Sep 17 00:00:00 2001 From: ZX-ModelCloud Date: Thu, 30 Oct 2025 15:13:53 +0800 Subject: [PATCH 1/5] supports granitemoehybrid Signed-off-by: ZX-ModelCloud --- gptqmodel/models/auto.py | 2 + .../models/definitions/granitemoehybrid.py | 27 ++++++++ gptqmodel/utils/model.py | 4 +- tests/models/model_test.py | 8 ++- tests/models/test_granite_4_0_h_350m.py | 65 +++++++++++++++++++ 5 files changed, 104 insertions(+), 2 deletions(-) create mode 100644 gptqmodel/models/definitions/granitemoehybrid.py create mode 100644 tests/models/test_granite_4_0_h_350m.py diff --git a/gptqmodel/models/auto.py b/gptqmodel/models/auto.py index d8d705064..988ecdf8f 100644 --- a/gptqmodel/models/auto.py +++ b/gptqmodel/models/auto.py @@ -138,6 +138,7 @@ from .definitions.starcoder2 import Starcoder2QModel # noqa: E402 from .definitions.telechat2 import TeleChat2QModel from .definitions.xverse import XverseQModel # noqa: E402 +from .definitions.granitemoehybrid import GraniteMoeHybridQModel # make quants and inference more determinisitc @@ -217,6 +218,7 @@ "mllama": MLlamaQModel, "marin": Qwen3QModel, "granite": LlamaQModel, # 100% llama clone + "granitemoehybrid": GraniteMoeHybridQModel, "mobilellm": MobileLLMQModel, "hymba": HymbaQModel, "olmo2": LlamaQModel, # 100% llama clone diff --git a/gptqmodel/models/definitions/granitemoehybrid.py b/gptqmodel/models/definitions/granitemoehybrid.py new file mode 100644 index 000000000..e4bcd5ead --- /dev/null +++ b/gptqmodel/models/definitions/granitemoehybrid.py @@ -0,0 +1,27 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +from ..base import BaseQModel + + +class GraniteMoeHybridQModel(BaseQModel): + dynamic_expert_index = "num_local_experts" + + pre_lm_head_norm_module = "model.norm" + + layer_modules_strict = False + + module_tree = [ + "model", + "layers", + "#", + { + "input_layernorm": ("input_layernorm:!",), + "mamba": ("in_proj:0", "out_proj:1"), + "self_attn": ("q_proj:0", "k_proj:0", "v_proj:0", "o_proj:1"), + "shared_mlp": ("input_linear:0", "output_linear:1"), + "post_attention_layernorm": ("post_attention_layernorm:!",), + } + ] diff --git a/gptqmodel/utils/model.py b/gptqmodel/utils/model.py index 9e51e6fc3..9e7be273d 100644 --- a/gptqmodel/utils/model.py +++ b/gptqmodel/utils/model.py @@ -292,7 +292,7 @@ def make_quant( log.info(f"Kernel: selected -> `{linear_cls.__name__}`.") return linear_cls except NotImplementedError as e: - log.info(f"Kernel: skipped -> `{cls}`.") + log.info(f"Kernel: skipped -> `{cls}`. {str(e)}") # only fallback to other quant linears when backend is auto. if backend not in [BACKEND.AUTO, BACKEND.AUTO_TRAINABLE]: @@ -751,6 +751,8 @@ def pack_module( "original": "module.pack_original", } + effective_impl = "original" + packer_label = label_map[effective_impl] with log_time_block( diff --git a/tests/models/model_test.py b/tests/models/model_test.py index 6ee99e2ad..db50e7c38 100644 --- a/tests/models/model_test.py +++ b/tests/models/model_test.py @@ -382,7 +382,13 @@ def perform_post_quant_validation(self, model_path, trust_remote_code=False): eval_records = {} reuse_candidates = {} - compare_backends = (BACKEND.MARLIN,) if self.FORMAT is FORMAT.GPTQ else (BACKEND.MARLIN, BACKEND.GEMM) + if self.FORMAT is FORMAT.GPTQ: + if self.LOAD_BACKEND == BACKEND.MARLIN: + compare_backends = (BACKEND.MARLIN,) + else: + compare_backends = (self.LOAD_BACKEND,) + else: + compare_backends = (BACKEND.MARLIN, BACKEND.GEMM) fallback_backend = None if BACKEND.MARLIN in compare_backends: try: diff --git a/tests/models/test_granite_4_0_h_350m.py b/tests/models/test_granite_4_0_h_350m.py new file mode 100644 index 000000000..30e915f9e --- /dev/null +++ b/tests/models/test_granite_4_0_h_350m.py @@ -0,0 +1,65 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium +from gptqmodel import BACKEND +from model_test import ModelTest + +from gptqmodel.utils.eval import EVAL + + +# a100:7, MARLIN kernel +# desc_act = False, act_group_aware = False 0.3200/0.3447 +# desc_act = False, act_group_aware = True 0.3181/0.3481 +# desc_act = True, REGRESSION 0.3191/0.3601 +# a100:6+7: MARLIN kernel +# desc_act = False, act_group_aware = True 0.3217/0.3643 +# | Metric | MARLIN | +# |--------------------------------|----------| +# | arc_challenge :: acc,none | 0.3174 | +# | arc_challenge :: acc_norm,none | 0.3601 | +# | mmlu_stem :: acc,none | 0.3186 | +class Test_Granite_4_0_H_1B(ModelTest): + NATIVE_MODEL_ID = "/monster/data/model/granite-4.0-h-350m" # "ibm-granite/granite-4.0-h-350m" + GROUP_SIZE = 32 + EVAL_BATCH_SIZE = 64 + LOAD_BACKEND = BACKEND.TORCH + SAVE_PATH = "granite-4-h-350m-g128" + EVAL_TASKS = { + EVAL.LM_EVAL.ARC_CHALLENGE: { + "chat_template": True, + "acc": { + "value": 0.3191, + "floor_pct": 0.04, + "ceil_pct": 0.10, + }, + "acc_norm": { + "value": 0.3507, + "floor_pct": 0.04, + "ceil_pct": 0.10, + }, + }, + # EVAL.LM_EVAL.MMLU_STEM: { + # "chat_template": False, + # "acc": { + # "value": 0.3054, + # "floor_pct": 0.04, + # "ceil_pct": 0.10, + # }, + # }, + } + + # llama 3.2 Instruct requires chat = true to have normal ARC scores + # mmlu requires chat = false + # APPLY_CHAT_TEMPLATE = True + # QUANT_BATCH_SIZE = 4 + + # EORA = Lora( + # # for quant, path is save path. for load, it is loading path + # path="./eora_test", + # rank=128, + # ) + # b1 = 0.315, b4 = 0.3106, b8 = 0.3148, b32 = 0.3148, b16 = 0.3234 + + def test_granite(self): + self.quant_lm_eval() From 5f00148e92c13d5979ea5fd6062c2456897169e0 Mon Sep 17 00:00:00 2001 From: ZX-ModelCloud Date: Thu, 30 Oct 2025 18:57:49 +0800 Subject: [PATCH 2/5] add test_granite_4_0_h_1b.py Signed-off-by: ZX-ModelCloud --- tests/models/test_granite_4_0_h_1b.py | 52 +++++++++++++++++++++++++ tests/models/test_granite_4_0_h_350m.py | 37 ++++++------------ 2 files changed, 64 insertions(+), 25 deletions(-) create mode 100644 tests/models/test_granite_4_0_h_1b.py diff --git a/tests/models/test_granite_4_0_h_1b.py b/tests/models/test_granite_4_0_h_1b.py new file mode 100644 index 000000000..254408649 --- /dev/null +++ b/tests/models/test_granite_4_0_h_1b.py @@ -0,0 +1,52 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium +from gptqmodel import BACKEND +from model_test import ModelTest + +from gptqmodel.utils.eval import EVAL + + +# a100:7, MARLIN kernel +# desc_act = False, act_group_aware = False 0.3200/0.3447 +# desc_act = False, act_group_aware = True 0.3181/0.3481 +# desc_act = True, REGRESSION 0.3191/0.3601 +# a100:6+7: MARLIN kernel +# desc_act = False, act_group_aware = True 0.3217/0.3643 +# | Metric | MARLIN | +# |--------------------------------|----------| +# | arc_challenge :: acc,none | 0.3174 | +# | arc_challenge :: acc_norm,none | 0.3601 | +# | mmlu_stem :: acc,none | 0.3186 | +class Test_Granite_4_0_H_1B(ModelTest): + NATIVE_MODEL_ID = "/monster/data/model/granite-4.0-h-1b" # "ibm-granite/granite-4.0-h-1b" + GROUP_SIZE = 32 + EVAL_BATCH_SIZE = 1 + LOAD_BACKEND = BACKEND.TORCH + EVAL_TASKS = { + EVAL.LM_EVAL.ARC_CHALLENGE: { + "chat_template": True, + "acc": { + "value": 0.3968, + "floor_pct": 0.04, + "ceil_pct": 0.10, + }, + "acc_norm": { + "value": 0.4138, + "floor_pct": 0.04, + "ceil_pct": 0.10, + }, + }, + EVAL.LM_EVAL.MMLU_STEM: { + "chat_template": False, + "acc": { + "value": 0.4015, + "floor_pct": 0.1, + "ceil_pct": 0.20, + }, + }, + } + + def test_granite(self): + self.quant_lm_eval() diff --git a/tests/models/test_granite_4_0_h_350m.py b/tests/models/test_granite_4_0_h_350m.py index 30e915f9e..9637d2260 100644 --- a/tests/models/test_granite_4_0_h_350m.py +++ b/tests/models/test_granite_4_0_h_350m.py @@ -19,47 +19,34 @@ # | arc_challenge :: acc,none | 0.3174 | # | arc_challenge :: acc_norm,none | 0.3601 | # | mmlu_stem :: acc,none | 0.3186 | -class Test_Granite_4_0_H_1B(ModelTest): +class Test_Granite_4_0_H_350M(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/granite-4.0-h-350m" # "ibm-granite/granite-4.0-h-350m" GROUP_SIZE = 32 - EVAL_BATCH_SIZE = 64 + EVAL_BATCH_SIZE = 16 LOAD_BACKEND = BACKEND.TORCH - SAVE_PATH = "granite-4-h-350m-g128" EVAL_TASKS = { EVAL.LM_EVAL.ARC_CHALLENGE: { "chat_template": True, "acc": { - "value": 0.3191, + "value": 0.3046, "floor_pct": 0.04, "ceil_pct": 0.10, }, "acc_norm": { - "value": 0.3507, + "value": 0.3157, "floor_pct": 0.04, "ceil_pct": 0.10, }, }, - # EVAL.LM_EVAL.MMLU_STEM: { - # "chat_template": False, - # "acc": { - # "value": 0.3054, - # "floor_pct": 0.04, - # "ceil_pct": 0.10, - # }, - # }, + EVAL.LM_EVAL.MMLU_STEM: { + "chat_template": False, + "acc": { + "value": 0.2915, + "floor_pct": 0.1, + "ceil_pct": 0.20, + }, + }, } - # llama 3.2 Instruct requires chat = true to have normal ARC scores - # mmlu requires chat = false - # APPLY_CHAT_TEMPLATE = True - # QUANT_BATCH_SIZE = 4 - - # EORA = Lora( - # # for quant, path is save path. for load, it is loading path - # path="./eora_test", - # rank=128, - # ) - # b1 = 0.315, b4 = 0.3106, b8 = 0.3148, b32 = 0.3148, b16 = 0.3234 - def test_granite(self): self.quant_lm_eval() From d764efab47ceaddcd3af8283995a2407d3cb55a5 Mon Sep 17 00:00:00 2001 From: ZX-ModelCloud Date: Thu, 30 Oct 2025 18:58:14 +0800 Subject: [PATCH 3/5] cleanup Signed-off-by: ZX-ModelCloud --- gptqmodel/models/definitions/granitemoehybrid.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/gptqmodel/models/definitions/granitemoehybrid.py b/gptqmodel/models/definitions/granitemoehybrid.py index e4bcd5ead..94440c4b4 100644 --- a/gptqmodel/models/definitions/granitemoehybrid.py +++ b/gptqmodel/models/definitions/granitemoehybrid.py @@ -7,8 +7,6 @@ class GraniteMoeHybridQModel(BaseQModel): - dynamic_expert_index = "num_local_experts" - pre_lm_head_norm_module = "model.norm" layer_modules_strict = False From 9ecb6a5b71f154b3eb67ebbfd258e7e280345ffa Mon Sep 17 00:00:00 2001 From: ZX-ModelCloud Date: Thu, 30 Oct 2025 19:07:54 +0800 Subject: [PATCH 4/5] If a ValueError occurs in pack_gpu() or pack_block(), fallback to pack_original() Signed-off-by: ZX-ModelCloud --- gptqmodel/utils/model.py | 36 ++++++++++++++++++++---------------- 1 file changed, 20 insertions(+), 16 deletions(-) diff --git a/gptqmodel/utils/model.py b/gptqmodel/utils/model.py index 9e7be273d..dd4d25a80 100644 --- a/gptqmodel/utils/model.py +++ b/gptqmodel/utils/model.py @@ -292,7 +292,7 @@ def make_quant( log.info(f"Kernel: selected -> `{linear_cls.__name__}`.") return linear_cls except NotImplementedError as e: - log.info(f"Kernel: skipped -> `{cls}`. {str(e)}") + log.info(f"Kernel: skipped -> `{cls}`.") # only fallback to other quant linears when backend is auto. if backend not in [BACKEND.AUTO, BACKEND.AUTO_TRAINABLE]: @@ -751,8 +751,6 @@ def pack_module( "original": "module.pack_original", } - effective_impl = "original" - packer_label = label_map[effective_impl] with log_time_block( @@ -761,20 +759,26 @@ def pack_module( module_name=name, ): if effective_impl == "gpu": - module.pack_gpu( - linear=layer, - scales=q_scales, - zeros=q_zeros, - g_idx=q_g_idx, - device=target_device, - ) + try: + module.pack_gpu( + linear=layer, + scales=q_scales, + zeros=q_zeros, + g_idx=q_g_idx, + device=target_device, + ) + except ValueError: + module.pack_original(linear=layer, scales=q_scales, zeros=q_zeros, g_idx=q_g_idx) elif effective_impl == "block": - module.pack_block( - linear=layer, - scales=q_scales, - zeros=q_zeros, - g_idx=q_g_idx, - ) + try: + module.pack_block( + linear=layer, + scales=q_scales, + zeros=q_zeros, + g_idx=q_g_idx, + ) + except ValueError: + module.pack_original(linear=layer, scales=q_scales, zeros=q_zeros, g_idx=q_g_idx) else: module.pack_original(linear=layer, scales=q_scales, zeros=q_zeros, g_idx=q_g_idx) From 3b57c0d4dda1d803529957f91f35faa062586e20 Mon Sep 17 00:00:00 2001 From: ZX-ModelCloud Date: Thu, 30 Oct 2025 19:21:23 +0800 Subject: [PATCH 5/5] update comments Signed-off-by: ZX-ModelCloud --- tests/models/test_granite_4_0_h_1b.py | 14 +++++--------- tests/models/test_granite_4_0_h_350m.py | 14 +++++--------- 2 files changed, 10 insertions(+), 18 deletions(-) diff --git a/tests/models/test_granite_4_0_h_1b.py b/tests/models/test_granite_4_0_h_1b.py index 254408649..b01d4924c 100644 --- a/tests/models/test_granite_4_0_h_1b.py +++ b/tests/models/test_granite_4_0_h_1b.py @@ -8,17 +8,13 @@ from gptqmodel.utils.eval import EVAL -# a100:7, MARLIN kernel -# desc_act = False, act_group_aware = False 0.3200/0.3447 -# desc_act = False, act_group_aware = True 0.3181/0.3481 -# desc_act = True, REGRESSION 0.3191/0.3601 -# a100:6+7: MARLIN kernel -# desc_act = False, act_group_aware = True 0.3217/0.3643 +# a100:0, TORCH kernel +# desc_act = False, act_group_aware = True # | Metric | MARLIN | # |--------------------------------|----------| -# | arc_challenge :: acc,none | 0.3174 | -# | arc_challenge :: acc_norm,none | 0.3601 | -# | mmlu_stem :: acc,none | 0.3186 | +# | arc_challenge :: acc,none | 0.3968 | +# | arc_challenge :: acc_norm,none | 0.4138 | +# | mmlu_stem :: acc,none | 0.4015 | class Test_Granite_4_0_H_1B(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/granite-4.0-h-1b" # "ibm-granite/granite-4.0-h-1b" GROUP_SIZE = 32 diff --git a/tests/models/test_granite_4_0_h_350m.py b/tests/models/test_granite_4_0_h_350m.py index 9637d2260..75fb282fd 100644 --- a/tests/models/test_granite_4_0_h_350m.py +++ b/tests/models/test_granite_4_0_h_350m.py @@ -8,17 +8,13 @@ from gptqmodel.utils.eval import EVAL -# a100:7, MARLIN kernel -# desc_act = False, act_group_aware = False 0.3200/0.3447 -# desc_act = False, act_group_aware = True 0.3181/0.3481 -# desc_act = True, REGRESSION 0.3191/0.3601 -# a100:6+7: MARLIN kernel -# desc_act = False, act_group_aware = True 0.3217/0.3643 +# a100:0, TORCH kernel +# desc_act = False, act_group_aware = True # | Metric | MARLIN | # |--------------------------------|----------| -# | arc_challenge :: acc,none | 0.3174 | -# | arc_challenge :: acc_norm,none | 0.3601 | -# | mmlu_stem :: acc,none | 0.3186 | +# | arc_challenge :: acc,none | 0.3046 | +# | arc_challenge :: acc_norm,none | 0.3157 | +# | mmlu_stem :: acc,none | 0.2915 | class Test_Granite_4_0_H_350M(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/granite-4.0-h-350m" # "ibm-granite/granite-4.0-h-350m" GROUP_SIZE = 32