From 034c7ed6b6e89e317b739a888749002a9a028b90 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Fri, 31 Oct 2025 15:53:54 +0000 Subject: [PATCH 1/6] add brumby --- gptqmodel/models/auto.py | 2 ++ gptqmodel/models/definitions/__init__.py | 1 + gptqmodel/models/definitions/brumby.py | 37 +++++++++++++++++++ tests/models/test_brumby.py | 46 ++++++++++++++++++++++++ 4 files changed, 86 insertions(+) create mode 100644 gptqmodel/models/definitions/brumby.py create mode 100644 tests/models/test_brumby.py diff --git a/gptqmodel/models/auto.py b/gptqmodel/models/auto.py index e4cd0af16..aeb2cefe3 100644 --- a/gptqmodel/models/auto.py +++ b/gptqmodel/models/auto.py @@ -74,6 +74,7 @@ from .definitions.baichuan import BaiChuanQModel # noqa: E402 from .definitions.bailing_moe import BailingMoeQModel # noqa: E402 from .definitions.bloom import BloomQModel # noqa: E402 +from .definitions.brumby import BrumbyQModel # noqa: E402 from .definitions.chatglm import ChatGLMQModel # noqa: E402 from .definitions.codegen import CodeGenQModel # noqa: E402 from .definitions.dbrx import DbrxQModel # noqa: E402 @@ -151,6 +152,7 @@ "apertus": ApertusQModel, "dream": DreamQModel, "bloom": BloomQModel, + "brumby": BrumbyQModel, "gpt_neo": GptNeoQModel, "kimi_k2": DeepSeekV3QModel, # 100% DeepSeekV3QModel clone "klear": KlearQModel, diff --git a/gptqmodel/models/definitions/__init__.py b/gptqmodel/models/definitions/__init__.py index 5a60ef88f..ba7422a52 100644 --- a/gptqmodel/models/definitions/__init__.py +++ b/gptqmodel/models/definitions/__init__.py @@ -10,6 +10,7 @@ # other model from .baichuan import BaiChuanQModel from .bloom import BloomQModel +from .brumby import BrumbyQModel from .chatglm import ChatGLMQModel from .codegen import CodeGenQModel from .dbrx import DbrxQModel diff --git a/gptqmodel/models/definitions/brumby.py b/gptqmodel/models/definitions/brumby.py new file mode 100644 index 000000000..e399c0a46 --- /dev/null +++ b/gptqmodel/models/definitions/brumby.py @@ -0,0 +1,37 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +from ..base import BaseQModel + + +class BrumbyQModel(BaseQModel): + require_trust_remote_code = True + require_pkgs_version = ["retention>=1.0.7"] + + pre_lm_head_norm_module = "model.norm" + + module_tree = [ + "model", + "layers", + "#", + { + "input_layernorm": ("input_layernorm:0:!",), + "self_attn": ( + "q_proj:0", + "k_proj:0", + "v_proj:0", + "g_proj:0:!", + "o_proj:1", + "q_norm:0:!", + "k_norm:0:!", + ), + "post_attention_layernorm": ("post_attention_layernorm:1:!",), + "mlp": ( + "gate_proj:0", + "up_proj:0", + "down_proj:1", + ), + }, + ] diff --git a/tests/models/test_brumby.py b/tests/models/test_brumby.py new file mode 100644 index 000000000..60608c3f2 --- /dev/null +++ b/tests/models/test_brumby.py @@ -0,0 +1,46 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +from model_test import ModelTest + +from gptqmodel.utils.eval import EVAL + + +# | Metric | MARLIN | +# |---------------------------|----------| +# | arc_challenge :: acc | 0.8900 | +# | gsm8k_cot :: exact | 0.8800 | +# | gsm8k_platinum_cot :: exact_match,flexible-extract | 0.8700 | +# | mmlu :: acc | 0.7100 | +class TestBrumby(ModelTest): + GROUP_SIZE = 32 + NATIVE_MODEL_ID = 
"/monster/data/model/Brumby-14B-Base" + # EVAL_BATCH_SIZE = 32 + TRUST_REMOTE_CODE = True + DATASET_CONCAT_SIZE = 2048 + EVAL_TASKS = { + EVAL.LM_EVAL.GSM8K_PLATINUM_COT: { + "chat_template": True, + "exact_match,flexible-extract": { + "value": 0.87, + "floor_pct": 6.05, + }, + }, + EVAL.LM_EVAL.ARC_CHALLENGE: { + "acc": { + "value": 0.89, + "floor_pct": 4.05, + }, + }, + EVAL.LM_EVAL.MMLU: { + "acc": { + "value": 0.71, + "floor_pct": 4.05, + }, + }, + } + + def test_brumby(self): + self.quant_lm_eval() From 42f9e847b8fe91d870d05418e8159f3c44a983ed Mon Sep 17 00:00:00 2001 From: Qubitium Date: Fri, 31 Oct 2025 16:30:44 +0000 Subject: [PATCH 2/6] update --- gptqmodel/models/definitions/brumby.py | 8 +++++++ tests/models/model_test.py | 6 +++++- tests/models/test_brumby.py | 29 +++++++++++++++++--------- 3 files changed, 32 insertions(+), 11 deletions(-) diff --git a/gptqmodel/models/definitions/brumby.py b/gptqmodel/models/definitions/brumby.py index e399c0a46..24067685c 100644 --- a/gptqmodel/models/definitions/brumby.py +++ b/gptqmodel/models/definitions/brumby.py @@ -35,3 +35,11 @@ class BrumbyQModel(BaseQModel): ), }, ] + + def after_model_load(self, model, load_quantized_model=False): + if hasattr(model, "config") and hasattr(model.config, "use_cache"): + model.config.use_cache = False + generation_config = getattr(model, "generation_config", None) + if generation_config is not None and hasattr(generation_config, "use_cache"): + generation_config.use_cache = False + return model diff --git a/tests/models/model_test.py b/tests/models/model_test.py index d4ac7c655..cae636155 100644 --- a/tests/models/model_test.py +++ b/tests/models/model_test.py @@ -24,7 +24,7 @@ from enum import Enum # noqa: E402 from pathlib import Path # noqa: E402 -from typing import Dict, List # noqa: E402 +from typing import Any, Dict, List # noqa: E402 from logbar import LogBar # noqa: E402 from tabulate import tabulate # noqa: E402 @@ -95,6 +95,7 @@ class ModelTest(unittest.TestCase): DATASET_SORT = "desc" DELETE_QUANTIZED_MODEL = True EVAL_TASKS = None + LOAD_MODEL_EXTRA_ARGS: Dict[str, Any] = {} KERNEL_QUANT = {} # kernel sets KERNEL_INFERENCE = {} # kernel sets @@ -909,6 +910,9 @@ def quantModel(self, model_id_or_path, trust_remote_code=False, dtype="auto", ne def loadQuantModel(self, model_id_or_path, trust_remote_code=False, tokenizer_path=None, backend=None, **args): load_kwargs = dict(args) + if self.LOAD_MODEL_EXTRA_ARGS: + for key, value in self.LOAD_MODEL_EXTRA_ARGS.items(): + load_kwargs.setdefault(key, value) if self.USE_FLASH_ATTN: if is_flash_attn_2_available(): diff --git a/tests/models/test_brumby.py b/tests/models/test_brumby.py index 60608c3f2..388df7ecd 100644 --- a/tests/models/test_brumby.py +++ b/tests/models/test_brumby.py @@ -6,6 +6,8 @@ from model_test import ModelTest from gptqmodel.utils.eval import EVAL +from gptqmodel.quantization import FORMAT, METHOD + # | Metric | MARLIN | @@ -15,30 +17,37 @@ # | gsm8k_platinum_cot :: exact_match,flexible-extract | 0.8700 | # | mmlu :: acc | 0.7100 | class TestBrumby(ModelTest): + # FORMAT = FORMAT.GEMM + # METHOD = METHOD.AWQ + GROUP_SIZE = 32 + DATASET_SIZE = 1024 NATIVE_MODEL_ID = "/monster/data/model/Brumby-14B-Base" - # EVAL_BATCH_SIZE = 32 TRUST_REMOTE_CODE = True + LOAD_MODEL_EXTRA_ARGS = {"use_cache": False} DATASET_CONCAT_SIZE = 2048 EVAL_TASKS = { EVAL.LM_EVAL.GSM8K_PLATINUM_COT: { "chat_template": True, "exact_match,flexible-extract": { "value": 0.87, - "floor_pct": 6.05, + "floor_pct": 0.05, + "ceil_pct": 0.10, }, }, - 
EVAL.LM_EVAL.ARC_CHALLENGE: {
-            "acc": {
-                "value": 0.89,
-                "floor_pct": 4.05,
+        EVAL.LM_EVAL.GSM8K_COT: {
+            "chat_template": True,
+            "exact_match,flexible-extract": {
+                "value": 0.88,
+                "floor_pct": 0.05,
+                "ceil_pct": 0.10,
             },
         },
+        EVAL.LM_EVAL.ARC_CHALLENGE: {
+            "acc": {"value": 0.89, "floor_pct": 0.05, "ceil_pct": 0.10},
+        },
         EVAL.LM_EVAL.MMLU: {
-            "acc": {
-                "value": 0.71,
-                "floor_pct": 4.05,
-            },
+            "acc": {"value": 0.71, "floor_pct": 0.05, "ceil_pct": 0.10},
         },
     }

From fd5cd14853b97de49dbddd28eb98f271527b1b49 Mon Sep 17 00:00:00 2001
From: Qubitium
Date: Fri, 31 Oct 2025 16:41:57 +0000
Subject: [PATCH 3/6] fix eval

---
 scripts/eval_model.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/scripts/eval_model.py b/scripts/eval_model.py
index f6c3d3a3f..45a574815 100644
--- a/scripts/eval_model.py
+++ b/scripts/eval_model.py
@@ -13,7 +13,7 @@

 import gptqmodel
 from tabulate import tabulate
-
+from gptqmodel import GPTQModel
 from gptqmodel.models.base import BaseQModel
 from gptqmodel.utils.eval import EVAL

From 743559c1d8214e2aa4e86034802da4b776c1b474 Mon Sep 17 00:00:00 2001
From: Qubitium
Date: Fri, 31 Oct 2025 21:53:25 +0000
Subject: [PATCH 4/6] default drop_sampling_fields to False

---
 gptqmodel/utils/hf.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/gptqmodel/utils/hf.py b/gptqmodel/utils/hf.py
index b195106f2..2d949842e 100644
--- a/gptqmodel/utils/hf.py
+++ b/gptqmodel/utils/hf.py
@@ -54,7 +54,7 @@ def _load_sanitized_generation_config(path: str) -> Optional[GenerationConfig]:
     cfg = GenerationConfig.from_dict(cleaned, **kwargs)
     if removed:
         log.info("Model: Removed unsupported sampling fields from `generation_config.json` during load.")
-    _sanitize_generation_config(cfg, drop_sampling_fields=True)
+    _sanitize_generation_config(cfg, drop_sampling_fields=False)
     return cfg
@@ -68,7 +68,7 @@ def autofix_hf_model_config(model: PreTrainedModel, path: str = None):
         cfg = _load_sanitized_generation_config(path)
         if cfg is None:
             cfg = GenerationConfig.from_pretrained(pretrained_model_name=path, do_sample=True)
-            _sanitize_generation_config(cfg, drop_sampling_fields=True)
+            _sanitize_generation_config(cfg, drop_sampling_fields=False)
         if cfg != model.generation_config:
             # migrated pad_token_id to config
             if hasattr(model.generation_config, "pad_token_id"):
@@ -90,7 +90,7 @@ def autofix_hf_model_config(model: PreTrainedModel, path: str = None):

 def autofix_hf_generation_config(cfg: GenerationConfig):
-    _sanitize_generation_config(cfg, drop_sampling_fields=True)
+    _sanitize_generation_config(cfg, drop_sampling_fields=False)

     # HF has recently started to perform very strict validation model save which results in warnings on load()
     # to become exceptions on save().
if cfg.do_sample is False:

From ecbe0bdc94b4f4ee371be371be3710371e3710371 Mon Sep 17 00:00:00 2001
From: Qubitium
Date: Fri, 31 Oct 2025 22:00:48 +0000
Subject: [PATCH 5/6] remove generation sampling field stripping

---
 gptqmodel/utils/hf.py       | 20 --------------------
 tests/models/test_brumby.py | 10 ----------
 2 files changed, 30 deletions(-)

diff --git a/gptqmodel/utils/hf.py b/gptqmodel/utils/hf.py
index 2d949842e..d757e4986 100644
--- a/gptqmodel/utils/hf.py
+++ b/gptqmodel/utils/hf.py
@@ -15,9 +15,6 @@
 log = setup_logger()

-GENERATION_SAMPLING_FIELDS = ("temperature", "top_p")
-
-
 def _sanitize_generation_config(cfg: GenerationConfig, *, drop_sampling_fields: bool = False) -> bool:
     changed = False
     if cfg is None:
@@ -27,12 +24,6 @@ def _sanitize_generation_config(cfg: GenerationConfig, *, drop_sampling_fields:
         cfg.do_sample = True
         changed = True

-    if drop_sampling_fields:
-        for field in GENERATION_SAMPLING_FIELDS:
-            if hasattr(cfg, field):
-                if getattr(cfg, field) is not None:
-                    changed = True
-                setattr(cfg, field, None)
     return changed
@@ -43,17 +34,10 @@ def _load_sanitized_generation_config(path: str) -> Optional[GenerationConfig]:
         return None

     cleaned = dict(config_dict)
-    removed = False
-    for field in GENERATION_SAMPLING_FIELDS:
-        if field in cleaned:
-            cleaned.pop(field, None)
-            removed = True
     if cleaned.get("do_sample") is not True:
         cleaned["do_sample"] = True

     cfg = GenerationConfig.from_dict(cleaned, **kwargs)
-    if removed:
-        log.info("Model: Removed unsupported sampling fields from `generation_config.json` during load.")
     _sanitize_generation_config(cfg, drop_sampling_fields=False)
     return cfg
@@ -125,10 +109,6 @@ def sanitize_generation_config_file(path: str) -> bool:
         return False

     changed = False
-    for field in GENERATION_SAMPLING_FIELDS:
-        if field in data:
-            data.pop(field, None)
-            changed = True
     if data.get("do_sample") is not True:
         data["do_sample"] = True
diff --git a/tests/models/test_brumby.py b/tests/models/test_brumby.py
index 388df7ecd..2ceae22c1 100644
--- a/tests/models/test_brumby.py
+++ b/tests/models/test_brumby.py
@@ -9,17 +9,7 @@

 from gptqmodel.quantization import FORMAT, METHOD

-
-# | Metric | MARLIN |
-# |---------------------------|----------|
-# | arc_challenge :: acc | 0.8900 |
-# | gsm8k_cot :: exact | 0.8800 |
-# | gsm8k_platinum_cot :: exact_match,flexible-extract | 0.8700 |
-# | mmlu :: acc | 0.7100 |
 class TestBrumby(ModelTest):
-    # FORMAT = FORMAT.GEMM
-    # METHOD = METHOD.AWQ
-
     GROUP_SIZE = 32
     DATASET_SIZE = 1024
     NATIVE_MODEL_ID = "/monster/data/model/Brumby-14B-Base"
     TRUST_REMOTE_CODE = True

From f90cc858fa5975597c6a1105adea95925d98080d Mon Sep 17 00:00:00 2001
From: Qubitium
Date: Fri, 31 Oct 2025 22:03:42 +0000
Subject: [PATCH 6/6] notes

---
 README.md | 30 ++++++++++++++++-------------
 1 file changed, 16 insertions(+), 14 deletions(-)

diff --git a/README.md b/README.md
index f86ccffbd..4b5f1b83f 100644
--- a/README.md
+++ b/README.md
@@ -17,6 +17,7 @@


 ## Latest News
+* 11/1/2025 5.1.0-dev: ✨Brumby (attention-free) model support.
 * 10/31/2025 5.1.0-dev: ✨IBM Granite Nano support. New `calibration_concat_separator` config option.
 * 10/30/2025 5.1.0-dev: πŸŽ‰AWQ support out of beta with full feature support including multi-gpu quant and MoE vram saving.
 * 10/30/2025 5.1.0-dev: ✨Marin model. New AWQ Torch reference kernel. Fix AWQ Marlin kernel for bf16. Fix GLM 4.5/4.6 MoE missing `mtp` layers on model save (HF bug). Modular refactor.
@@ -180,20 +181,21 @@ Native support for some of the most popular multi-modal models:

 ## Model Support

-| Model | | | | | | | | | |
-|-------------------|---|-------------|---|----------------|---|----------------|---|---------------------|---|
-| Apertus | βœ… | EXAONE 3.0 | βœ… | InternLM 1/2.5 | βœ… | Mixtral | βœ… | Qwen 2/3 (Next/MoE) | βœ… |
-| Baichuan | βœ… | Falcon (H1) | βœ… | Kimi K2 | βœ… | MobileLLM | βœ… | Qwen 2/2.5/3 VL | βœ… |
-| Bloom | βœ… | FastVLM | βœ… | Klear | βœ… | MOSS | βœ… | Qwen 2.5/3 Omni | βœ… |
-| ChatGLM | βœ… | Gemma 1/2/3 | βœ… | LING/RING | βœ… | MPT | βœ… | RefinedWeb | βœ… |
-| CodeGen | βœ… | GPTBigCod | βœ… | Llama 1-3.3 | βœ… | Nemotron H | βœ… | StableLM | βœ… |
-| Cohere 1-2 | βœ… | GPTQ-Neo(X) | βœ… | Llama 3.2 VL | βœ… | Nemotron Ultra | βœ… | StarCoder2 | βœ… |
-| DBRX Converted | βœ… | GPT-2 | βœ… | Llama 4 | βœ… | OPT | βœ… | TeleChat2 | βœ… |
-| Deci | βœ… | GPT-J | βœ… | LongCatFlash | βœ… | OLMo2 | βœ… | Yi | βœ… |
-| DeepSeek-V2/V3/R1 | βœ… | GPT-OSS | βœ… | LongLLaMA | βœ… | Ovis 1.6/2 | βœ… | Seed-OSS | βœ… |
-| DeepSeek-V2-Lite | βœ… | Granite | βœ… | Instella | βœ… | Phi 1-4 | βœ… | XVERSE | βœ… |
-| Dream | βœ… | GRIN-MoE | βœ… | MiniCPM3 | βœ… | PanGu-Ξ± | βœ… | Minimax M2 | βœ… |
-| ERNIE 4.5 | βœ… | Hymba | βœ… | Mistral | βœ… | Qwen 1/2/3 | βœ… | GLM 4.X | βœ… |
+| Model | | | | | | | | | |
+|-------------------|---|-------------|---|----------------|---|----------------|---|---------------------|---|
+| Apertus | βœ… | EXAONE 3.0 | βœ… | InternLM 1/2.5 | βœ… | Mixtral | βœ… | Qwen 2/3 (Next/MoE) | βœ… |
+| Baichuan | βœ… | Falcon (H1) | βœ… | Kimi K2 | βœ… | MobileLLM | βœ… | Qwen 2/2.5/3 VL | βœ… |
+| Bloom | βœ… | FastVLM | βœ… | Klear | βœ… | MOSS | βœ… | Qwen 2.5/3 Omni | βœ… |
+| ChatGLM | βœ… | Gemma 1/2/3 | βœ… | LING/RING | βœ… | MPT | βœ… | RefinedWeb | βœ… |
+| CodeGen | βœ… | GPTBigCod | βœ… | Llama 1-3.3 | βœ… | Nemotron H | βœ… | StableLM | βœ… |
+| Cohere 1-2 | βœ… | GPTQ-Neo(X) | βœ… | Llama 3.2 VL | βœ… | Nemotron Ultra | βœ… | StarCoder2 | βœ… |
+| DBRX Converted | βœ… | GPT-2 | βœ… | Llama 4 | βœ… | OPT | βœ… | TeleChat2 | βœ… |
+| Deci | βœ… | GPT-J | βœ… | LongCatFlash | βœ… | OLMo2 | βœ… | Yi | βœ… |
+| DeepSeek-V2/V3/R1 | βœ… | GPT-OSS | βœ… | LongLLaMA | βœ… | Ovis 1.6/2 | βœ… | Seed-OSS | βœ… |
+| DeepSeek-V2-Lite | βœ… | Granite | βœ… | Instella | βœ… | Phi 1-4 | βœ… | XVERSE | βœ… |
+| Dream | βœ… | GRIN-MoE | βœ… | MiniCPM3 | βœ… | PanGu-Ξ± | βœ… | Minimax M2 | βœ… |
+| ERNIE 4.5 | βœ… | Hymba | βœ… | Mistral | βœ… | Qwen 1/2/3 | βœ… | GLM 4.X | βœ… |
+| Brumby | βœ… | | | | | | | | |

 ## Platform and HW Support
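For reviewers who want to exercise the new Brumby path end to end outside the test harness, here is a minimal sketch using the standard GPTQModel quantize flow. The model path and calibration slice below are illustrative, not part of this series; per the new definition, Brumby also requires `trust_remote_code=True` and the `retention>=1.0.7` package.

```python
from datasets import load_dataset

from gptqmodel import GPTQModel, QuantizeConfig

# Illustrative path: any local checkout or hub id of Brumby-14B-Base works.
model_id = "/path/to/Brumby-14B-Base"

# Small text-only calibration slice, as in the usual GPTQModel examples.
calibration = load_dataset(
    "allenai/c4",
    data_files="en/c4-train.00001-of-01024.json.gz",
    split="train",
).select(range(1024))["text"]

# group_size=32 mirrors GROUP_SIZE in tests/models/test_brumby.py.
quant_config = QuantizeConfig(bits=4, group_size=32)

# BrumbyQModel sets require_trust_remote_code = True, so pass it explicitly.
model = GPTQModel.load(model_id, quant_config, trust_remote_code=True)
model.quantize(calibration, batch_size=1)
model.save("Brumby-14B-Base-gptq-4bit")
```

The quantized output loads back with `GPTQModel.load("Brumby-14B-Base-gptq-4bit")` like any other supported architecture; the `after_model_load` hook added in PATCH 2 disables `use_cache` on both the quantize and quantized-load paths.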