From 034c7ed6b6e89e317b739a888749002a9a028b90 Mon Sep 17 00:00:00 2001
From: Qubitium
Date: Fri, 31 Oct 2025 15:53:54 +0000
Subject: [PATCH 1/6] add brumby
---
gptqmodel/models/auto.py | 2 ++
gptqmodel/models/definitions/__init__.py | 1 +
gptqmodel/models/definitions/brumby.py | 37 +++++++++++++++++++
tests/models/test_brumby.py | 46 ++++++++++++++++++++++++
4 files changed, 86 insertions(+)
create mode 100644 gptqmodel/models/definitions/brumby.py
create mode 100644 tests/models/test_brumby.py
diff --git a/gptqmodel/models/auto.py b/gptqmodel/models/auto.py
index e4cd0af16..aeb2cefe3 100644
--- a/gptqmodel/models/auto.py
+++ b/gptqmodel/models/auto.py
@@ -74,6 +74,7 @@
from .definitions.baichuan import BaiChuanQModel # noqa: E402
from .definitions.bailing_moe import BailingMoeQModel # noqa: E402
from .definitions.bloom import BloomQModel # noqa: E402
+from .definitions.brumby import BrumbyQModel # noqa: E402
from .definitions.chatglm import ChatGLMQModel # noqa: E402
from .definitions.codegen import CodeGenQModel # noqa: E402
from .definitions.dbrx import DbrxQModel # noqa: E402
@@ -151,6 +152,7 @@
"apertus": ApertusQModel,
"dream": DreamQModel,
"bloom": BloomQModel,
+ "brumby": BrumbyQModel,
"gpt_neo": GptNeoQModel,
"kimi_k2": DeepSeekV3QModel, # 100% DeepSeekV3QModel clone
"klear": KlearQModel,
diff --git a/gptqmodel/models/definitions/__init__.py b/gptqmodel/models/definitions/__init__.py
index 5a60ef88f..ba7422a52 100644
--- a/gptqmodel/models/definitions/__init__.py
+++ b/gptqmodel/models/definitions/__init__.py
@@ -10,6 +10,7 @@
# other model
from .baichuan import BaiChuanQModel
from .bloom import BloomQModel
+from .brumby import BrumbyQModel
from .chatglm import ChatGLMQModel
from .codegen import CodeGenQModel
from .dbrx import DbrxQModel
diff --git a/gptqmodel/models/definitions/brumby.py b/gptqmodel/models/definitions/brumby.py
new file mode 100644
index 000000000..e399c0a46
--- /dev/null
+++ b/gptqmodel/models/definitions/brumby.py
@@ -0,0 +1,37 @@
+# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai
+# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai
+# SPDX-License-Identifier: Apache-2.0
+# Contact: qubitium@modelcloud.ai, x.com/qubitium
+
+from ..base import BaseQModel
+
+
+class BrumbyQModel(BaseQModel):
+ require_trust_remote_code = True
+ require_pkgs_version = ["retention>=1.0.7"]
+
+ pre_lm_head_norm_module = "model.norm"
+
+ module_tree = [
+ "model",
+ "layers",
+ "#",
+ {
+ "input_layernorm": ("input_layernorm:0:!",),
+ "self_attn": (
+ "q_proj:0",
+ "k_proj:0",
+ "v_proj:0",
+ "g_proj:0:!",
+ "o_proj:1",
+ "q_norm:0:!",
+ "k_norm:0:!",
+ ),
+ "post_attention_layernorm": ("post_attention_layernorm:1:!",),
+ "mlp": (
+ "gate_proj:0",
+ "up_proj:0",
+ "down_proj:1",
+ ),
+ },
+ ]
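
The `module_tree` leaf specs above read as a `name[:group][:!]` convention:
the numeric suffix looks like an intra-layer quantization ordering bucket
(input projections at `:0`, output projections at `:1`), and a trailing `!`
appears to mark modules that are walked but not quantized (the norms and
Brumby's retention gate `g_proj`). A minimal sketch of that reading, assuming
the convention rather than GPTQModel's actual parser:

    # Hedged sketch: split a module_tree leaf spec under the assumed
    # "name[:group][:!]" convention; not GPTQModel's real implementation.
    def parse_leaf_spec(spec: str) -> tuple[str, int | None, bool]:
        parts = spec.split(":")
        name = parts[0]
        skip_quant = parts[-1] == "!"  # "!" marks tracked-but-not-quantized
        group = int(parts[1]) if len(parts) > 1 and parts[1].isdigit() else None
        return name, group, skip_quant

    assert parse_leaf_spec("g_proj:0:!") == ("g_proj", 0, True)
    assert parse_leaf_spec("o_proj:1") == ("o_proj", 1, False)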
diff --git a/tests/models/test_brumby.py b/tests/models/test_brumby.py
new file mode 100644
index 000000000..60608c3f2
--- /dev/null
+++ b/tests/models/test_brumby.py
@@ -0,0 +1,46 @@
+# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai
+# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai
+# SPDX-License-Identifier: Apache-2.0
+# Contact: qubitium@modelcloud.ai, x.com/qubitium
+
+from model_test import ModelTest
+
+from gptqmodel.utils.eval import EVAL
+
+
+# | Metric | MARLIN |
+# |---------------------------|----------|
+# | arc_challenge :: acc | 0.8900 |
+# | gsm8k_cot :: exact | 0.8800 |
+# | gsm8k_platinum_cot :: exact_match,flexible-extract | 0.8700 |
+# | mmlu :: acc | 0.7100 |
+class TestBrumby(ModelTest):
+ GROUP_SIZE = 32
+ NATIVE_MODEL_ID = "/monster/data/model/Brumby-14B-Base"
+ # EVAL_BATCH_SIZE = 32
+ TRUST_REMOTE_CODE = True
+ DATASET_CONCAT_SIZE = 2048
+ EVAL_TASKS = {
+ EVAL.LM_EVAL.GSM8K_PLATINUM_COT: {
+ "chat_template": True,
+ "exact_match,flexible-extract": {
+ "value": 0.87,
+ "floor_pct": 6.05,
+ },
+ },
+ EVAL.LM_EVAL.ARC_CHALLENGE: {
+ "acc": {
+ "value": 0.89,
+ "floor_pct": 4.05,
+ },
+ },
+ EVAL.LM_EVAL.MMLU: {
+ "acc": {
+ "value": 0.71,
+ "floor_pct": 4.05,
+ },
+ },
+ }
+
+ def test_brumby(self):
+ self.quant_lm_eval()
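
For context, the `quant_lm_eval()` flow this test drives can be reproduced
standalone with the public `GPTQModel` API; a minimal sketch, assuming the
local checkpoint path above and substituting a toy calibration list for the
harness's real dataset:

    # Standalone sketch of quantizing Brumby; the one-line calibration set is
    # a placeholder, real runs need a proper calibration corpus.
    from gptqmodel import GPTQModel, QuantizeConfig

    quant_config = QuantizeConfig(bits=4, group_size=32)  # GROUP_SIZE above
    model = GPTQModel.load(
        "/monster/data/model/Brumby-14B-Base",  # NATIVE_MODEL_ID above
        quant_config,
        trust_remote_code=True,  # Brumby needs remote code and retention>=1.0.7
    )
    model.quantize(["gptqmodel is an easy-to-use llm quantization library."])
    model.save("./Brumby-14B-Base-gptq-4bit")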
From 42f9e847b8fe91d870d05418e8159f3c44a983ed Mon Sep 17 00:00:00 2001
From: Qubitium
Date: Fri, 31 Oct 2025 16:30:44 +0000
Subject: [PATCH 2/6] update Brumby eval config and model load hooks
---
gptqmodel/models/definitions/brumby.py | 8 +++++++
tests/models/model_test.py | 6 +++++-
tests/models/test_brumby.py | 29 +++++++++++++++++---------
3 files changed, 32 insertions(+), 11 deletions(-)
diff --git a/gptqmodel/models/definitions/brumby.py b/gptqmodel/models/definitions/brumby.py
index e399c0a46..24067685c 100644
--- a/gptqmodel/models/definitions/brumby.py
+++ b/gptqmodel/models/definitions/brumby.py
@@ -35,3 +35,11 @@ class BrumbyQModel(BaseQModel):
),
},
]
+
+ def after_model_load(self, model, load_quantized_model=False):
+ if hasattr(model, "config") and hasattr(model.config, "use_cache"):
+ model.config.use_cache = False
+ generation_config = getattr(model, "generation_config", None)
+ if generation_config is not None and hasattr(generation_config, "use_cache"):
+ generation_config.use_cache = False
+ return model
diff --git a/tests/models/model_test.py b/tests/models/model_test.py
index d4ac7c655..cae636155 100644
--- a/tests/models/model_test.py
+++ b/tests/models/model_test.py
@@ -24,7 +24,7 @@
from enum import Enum # noqa: E402
from pathlib import Path # noqa: E402
-from typing import Dict, List # noqa: E402
+from typing import Any, Dict, List # noqa: E402
from logbar import LogBar # noqa: E402
from tabulate import tabulate # noqa: E402
@@ -95,6 +95,7 @@ class ModelTest(unittest.TestCase):
DATASET_SORT = "desc"
DELETE_QUANTIZED_MODEL = True
EVAL_TASKS = None
+ LOAD_MODEL_EXTRA_ARGS: Dict[str, Any] = {}
KERNEL_QUANT = {} # kernel sets
KERNEL_INFERENCE = {} # kernel sets
@@ -909,6 +910,9 @@ def quantModel(self, model_id_or_path, trust_remote_code=False, dtype="auto", ne
def loadQuantModel(self, model_id_or_path, trust_remote_code=False, tokenizer_path=None, backend=None, **args):
load_kwargs = dict(args)
+ if self.LOAD_MODEL_EXTRA_ARGS:
+ for key, value in self.LOAD_MODEL_EXTRA_ARGS.items():
+ load_kwargs.setdefault(key, value)
if self.USE_FLASH_ATTN:
if is_flash_attn_2_available():
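
Because the merge uses `setdefault()`, `LOAD_MODEL_EXTRA_ARGS` supplies
defaults only: kwargs passed explicitly to `loadQuantModel()` still win. A
short usage sketch (the subclass name is illustrative):

    # Illustrative subclass: inject a default load kwarg without overriding
    # loadQuantModel(); an explicit use_cache=True at a call site would still
    # take precedence over this default.
    from model_test import ModelTest

    class CachelessModelTest(ModelTest):
        LOAD_MODEL_EXTRA_ARGS = {"use_cache": False}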
diff --git a/tests/models/test_brumby.py b/tests/models/test_brumby.py
index 60608c3f2..388df7ecd 100644
--- a/tests/models/test_brumby.py
+++ b/tests/models/test_brumby.py
@@ -6,6 +6,8 @@
from model_test import ModelTest
from gptqmodel.utils.eval import EVAL
+from gptqmodel.quantization import FORMAT, METHOD
+
# | Metric | MARLIN |
@@ -15,30 +17,37 @@
# | gsm8k_platinum_cot :: exact_match,flexible-extract | 0.8700 |
# | mmlu :: acc | 0.7100 |
class TestBrumby(ModelTest):
+ # FORMAT = FORMAT.GEMM
+ # METHOD = METHOD.AWQ
+
GROUP_SIZE = 32
+ DATASET_SIZE = 1024
NATIVE_MODEL_ID = "/monster/data/model/Brumby-14B-Base"
- # EVAL_BATCH_SIZE = 32
TRUST_REMOTE_CODE = True
+ LOAD_MODEL_EXTRA_ARGS = {"use_cache": False}
DATASET_CONCAT_SIZE = 2048
EVAL_TASKS = {
EVAL.LM_EVAL.GSM8K_PLATINUM_COT: {
"chat_template": True,
"exact_match,flexible-extract": {
"value": 0.87,
- "floor_pct": 6.05,
+ "floor_pct": 0.05,
+ "ceil_pct": 0.10,
},
},
- EVAL.LM_EVAL.ARC_CHALLENGE: {
- "acc": {
- "value": 0.89,
- "floor_pct": 4.05,
+ EVAL.LM_EVAL.GSM8K_COT: {
+ "chat_template": True,
+ "exact_match,flexible-extract": {
+ "value": 0.88,
+ "floor_pct": 0.05,
+ "ceil_pct": 0.10,
},
},
+ EVAL.LM_EVAL.ARC_CHALLENGE: {
+ "acc": {"value": 0.89, "floor_pct": 0.05, "ceil_pct": 0.10},
+ },
EVAL.LM_EVAL.MMLU: {
- "acc": {
- "value": 0.71,
- "floor_pct": 4.05,
- },
+ "acc": {"value": 0.71, "floor_pct": 0.05, "ceil_pct": 0.10},
},
}
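
The `floor_pct`/`ceil_pct` pairs read as relative tolerances around the
reference `value`; under that reading (an assumption from the field names,
not verified against the harness), a measured score passes when it lands in
`[value * (1 - floor_pct), value * (1 + ceil_pct)]`:

    # Assumed semantics of the tolerance fields; the harness may differ.
    def within_tolerance(measured: float, value: float,
                         floor_pct: float, ceil_pct: float) -> bool:
        return value * (1 - floor_pct) <= measured <= value * (1 + ceil_pct)

    # e.g. GSM8K-Platinum: reference 0.87 with a 5% floor and 10% ceiling
    assert within_tolerance(0.85, 0.87, floor_pct=0.05, ceil_pct=0.10)
    assert not within_tolerance(0.80, 0.87, floor_pct=0.05, ceil_pct=0.10)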
From fd5cd14853b97de49dbddd28eb98f271527b1b49 Mon Sep 17 00:00:00 2001
From: Qubitium
Date: Fri, 31 Oct 2025 16:41:57 +0000
Subject: [PATCH 3/6] fix eval script missing GPTQModel import
---
scripts/eval_model.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/scripts/eval_model.py b/scripts/eval_model.py
index f6c3d3a3f..45a574815 100644
--- a/scripts/eval_model.py
+++ b/scripts/eval_model.py
@@ -13,7 +13,7 @@
import gptqmodel
from tabulate import tabulate
-
+from gptqmodel import GPTQModel
from gptqmodel.models.base import BaseQModel
from gptqmodel.utils.eval import EVAL
From 743559c1d8214e2aa4e86034802da4b776c1b474 Mon Sep 17 00:00:00 2001
From: Qubitium
Date: Fri, 31 Oct 2025 21:53:25 +0000
Subject: [PATCH 4/6] default drop_sampling_fields to False
---
gptqmodel/utils/hf.py | 6 +++---
1 file changed, 3 insertions(+), 3 deletions(-)
diff --git a/gptqmodel/utils/hf.py b/gptqmodel/utils/hf.py
index b195106f2..2d949842e 100644
--- a/gptqmodel/utils/hf.py
+++ b/gptqmodel/utils/hf.py
@@ -54,7 +54,7 @@ def _load_sanitized_generation_config(path: str) -> Optional[GenerationConfig]:
cfg = GenerationConfig.from_dict(cleaned, **kwargs)
if removed:
log.info("Model: Removed unsupported sampling fields from `generation_config.json` during load.")
- _sanitize_generation_config(cfg, drop_sampling_fields=True)
+ _sanitize_generation_config(cfg, drop_sampling_fields=False)
return cfg
@@ -68,7 +68,7 @@ def autofix_hf_model_config(model: PreTrainedModel, path: str = None):
cfg = _load_sanitized_generation_config(path)
if cfg is None:
cfg = GenerationConfig.from_pretrained(pretrained_model_name=path, do_sample=True)
- _sanitize_generation_config(cfg, drop_sampling_fields=True)
+ _sanitize_generation_config(cfg, drop_sampling_fields=False)
if cfg != model.generation_config:
# migrated pad_token_id to config
if hasattr(model.generation_config, "pad_token_id"):
@@ -90,7 +90,7 @@ def autofix_hf_model_config(model: PreTrainedModel, path: str = None):
def autofix_hf_generation_config(cfg: GenerationConfig):
- _sanitize_generation_config(cfg, drop_sampling_fields=True)
+ _sanitize_generation_config(cfg, drop_sampling_fields=False)
# HF has recently started to perform very strict validation on model save, which causes
# warnings on load() to become exceptions on save().
if cfg.do_sample is False:
From ecbe0bdc6d7abbe94f5ba00577155066a1769371 Mon Sep 17 00:00:00 2001
From: Qubitium
Date: Fri, 31 Oct 2025 22:00:48 +0000
Subject: [PATCH 5/6] remove unused sampling-field drop logic
---
gptqmodel/utils/hf.py | 20 --------------------
tests/models/test_brumby.py | 10 ----------
2 files changed, 30 deletions(-)
diff --git a/gptqmodel/utils/hf.py b/gptqmodel/utils/hf.py
index 2d949842e..d757e4986 100644
--- a/gptqmodel/utils/hf.py
+++ b/gptqmodel/utils/hf.py
@@ -15,9 +15,6 @@
log = setup_logger()
-GENERATION_SAMPLING_FIELDS = ("temperature", "top_p")
-
-
def _sanitize_generation_config(cfg: GenerationConfig, *, drop_sampling_fields: bool = False) -> bool:
changed = False
if cfg is None:
@@ -27,12 +24,6 @@ def _sanitize_generation_config(cfg: GenerationConfig, *, drop_sampling_fields:
cfg.do_sample = True
changed = True
- if drop_sampling_fields:
- for field in GENERATION_SAMPLING_FIELDS:
- if hasattr(cfg, field):
- if getattr(cfg, field) is not None:
- changed = True
- setattr(cfg, field, None)
return changed
@@ -43,17 +34,10 @@ def _load_sanitized_generation_config(path: str) -> Optional[GenerationConfig]:
return None
cleaned = dict(config_dict)
- removed = False
- for field in GENERATION_SAMPLING_FIELDS:
- if field in cleaned:
- cleaned.pop(field, None)
- removed = True
if cleaned.get("do_sample") is not True:
cleaned["do_sample"] = True
cfg = GenerationConfig.from_dict(cleaned, **kwargs)
- if removed:
- log.info("Model: Removed unsupported sampling fields from `generation_config.json` during load.")
_sanitize_generation_config(cfg, drop_sampling_fields=False)
return cfg
@@ -125,10 +109,6 @@ def sanitize_generation_config_file(path: str) -> bool:
return False
changed = False
- for field in GENERATION_SAMPLING_FIELDS:
- if field in data:
- data.pop(field, None)
- changed = True
if data.get("do_sample") is not True:
data["do_sample"] = True
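
After this change `_sanitize_generation_config` only normalizes `do_sample`;
`temperature` and `top_p` now survive both load and save. A condensed model
of the resulting behavior (simplified from the diff, not a verbatim copy of
`hf.py`):

    # Simplified post-patch sanitizer: force do_sample=True so HF's strict
    # save-time validation passes, but leave sampling fields untouched.
    from transformers import GenerationConfig

    def sanitize(cfg: GenerationConfig) -> bool:
        changed = False
        if cfg is None:
            return changed
        if getattr(cfg, "do_sample", None) is not True:
            cfg.do_sample = True
            changed = True
        return changed

    cfg = GenerationConfig(do_sample=False, temperature=0.7, top_p=0.9)
    assert sanitize(cfg) and cfg.do_sample and cfg.temperature == 0.7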
diff --git a/tests/models/test_brumby.py b/tests/models/test_brumby.py
index 388df7ecd..2ceae22c1 100644
--- a/tests/models/test_brumby.py
+++ b/tests/models/test_brumby.py
@@ -9,17 +9,7 @@
from gptqmodel.quantization import FORMAT, METHOD
-
-# | Metric | MARLIN |
-# |---------------------------|----------|
-# | arc_challenge :: acc | 0.8900 |
-# | gsm8k_cot :: exact | 0.8800 |
-# | gsm8k_platinum_cot :: exact_match,flexible-extract | 0.8700 |
-# | mmlu :: acc | 0.7100 |
class TestBrumby(ModelTest):
- # FORMAT = FORMAT.GEMM
- # METHOD = METHOD.AWQ
-
GROUP_SIZE = 32
DATASET_SIZE = 1024
NATIVE_MODEL_ID = "/monster/data/model/Brumby-14B-Base"
From f90cc858fa5975597c6a1105adea95925d98080d Mon Sep 17 00:00:00 2001
From: Qubitium
Date: Fri, 31 Oct 2025 22:03:42 +0000
Subject: [PATCH 6/6] notes
---
README.md | 30 ++++++++++++++++--------------
1 file changed, 16 insertions(+), 14 deletions(-)
diff --git a/README.md b/README.md
index f86ccffbd..4b5f1b83f 100644
--- a/README.md
+++ b/README.md
@@ -17,6 +17,7 @@
## Latest News
+* 11/1/2025 5.1.0-dev: ✨Brumby (attention-free) model support.
* 10/31/2025 5.1.0-dev: ✨IBM Granite Nano support. New `calibration_concat_separator` config option.
* 10/30/2025 5.1.0-dev: 🎉AWQ support out of beta with full feature support, including multi-GPU quant and MoE VRAM saving.
* 10/30/2025 5.1.0-dev: ✨Marin model. New AWQ Torch reference kernel. Fix AWQ Marlin kernel for bf16. Fix GLM 4.5/4.6 MoE missing `mtp` layers on model save (HF bug). Modular refactor.
@@ -180,20 +181,21 @@ Native support covers some of the most popular multi-modal models:
## Model Support
-| Model             |   |             |   |                |   |                |   |                     |   |
-|-------------------|---|-------------|---|----------------|---|----------------|---|---------------------|---|
-| Apertus           | ✅ | EXAONE 3.0  | ✅ | InternLM 1/2.5 | ✅ | Mixtral        | ✅ | Qwen 2/3 (Next/MoE) | ✅ |
-| Baichuan          | ✅ | Falcon (H1) | ✅ | Kimi K2        | ✅ | MobileLLM      | ✅ | Qwen 2/2.5/3 VL     | ✅ |
-| Bloom             | ✅ | FastVLM     | ✅ | Klear          | ✅ | MOSS           | ✅ | Qwen 2.5/3 Omni     | ✅ |
-| ChatGLM           | ✅ | Gemma 1/2/3 | ✅ | LING/RING      | ✅ | MPT            | ✅ | RefinedWeb          | ✅ |
-| CodeGen           | ✅ | GPTBigCod   | ✅ | Llama 1-3.3    | ✅ | Nemotron H     | ✅ | StableLM            | ✅ |
-| Cohere 1-2        | ✅ | GPTQ-Neo(X) | ✅ | Llama 3.2 VL   | ✅ | Nemotron Ultra | ✅ | StarCoder2          | ✅ |
-| DBRX Converted    | ✅ | GPT-2       | ✅ | Llama 4        | ✅ | OPT            | ✅ | TeleChat2           | ✅ |
-| Deci              | ✅ | GPT-J       | ✅ | LongCatFlash   | ✅ | OLMo2          | ✅ | Yi                  | ✅ |
-| DeepSeek-V2/V3/R1 | ✅ | GPT-OSS     | ✅ | LongLLaMA      | ✅ | Ovis 1.6/2     | ✅ | Seed-OSS            | ✅ |
-| DeepSeek-V2-Lite  | ✅ | Granite     | ✅ | Instella       | ✅ | Phi 1-4        | ✅ | XVERSE              | ✅ |
-| Dream             | ✅ | GRIN-MoE    | ✅ | MiniCPM3       | ✅ | PanGu-α        | ✅ | Minimax M2          | ✅ |
-| ERNIE 4.5         | ✅ | Hymba       | ✅ | Mistral        | ✅ | Qwen 1/2/3     | ✅ | GLM 4.X             | ✅ |
+| Model             |   |             |   |                |   |                |   |                     |   |
+|-------------------|---|-------------|---|----------------|---|----------------|---|---------------------|---|
+| Apertus           | ✅ | EXAONE 3.0  | ✅ | InternLM 1/2.5 | ✅ | Mixtral        | ✅ | Qwen 2/3 (Next/MoE) | ✅ |
+| Baichuan          | ✅ | Falcon (H1) | ✅ | Kimi K2        | ✅ | MobileLLM      | ✅ | Qwen 2/2.5/3 VL     | ✅ |
+| Bloom             | ✅ | FastVLM     | ✅ | Klear          | ✅ | MOSS           | ✅ | Qwen 2.5/3 Omni     | ✅ |
+| ChatGLM           | ✅ | Gemma 1/2/3 | ✅ | LING/RING      | ✅ | MPT            | ✅ | RefinedWeb          | ✅ |
+| CodeGen           | ✅ | GPTBigCod   | ✅ | Llama 1-3.3    | ✅ | Nemotron H     | ✅ | StableLM            | ✅ |
+| Cohere 1-2        | ✅ | GPTQ-Neo(X) | ✅ | Llama 3.2 VL   | ✅ | Nemotron Ultra | ✅ | StarCoder2          | ✅ |
+| DBRX Converted    | ✅ | GPT-2       | ✅ | Llama 4        | ✅ | OPT            | ✅ | TeleChat2           | ✅ |
+| Deci              | ✅ | GPT-J       | ✅ | LongCatFlash   | ✅ | OLMo2          | ✅ | Yi                  | ✅ |
+| DeepSeek-V2/V3/R1 | ✅ | GPT-OSS     | ✅ | LongLLaMA      | ✅ | Ovis 1.6/2     | ✅ | Seed-OSS            | ✅ |
+| DeepSeek-V2-Lite  | ✅ | Granite     | ✅ | Instella       | ✅ | Phi 1-4        | ✅ | XVERSE              | ✅ |
+| Dream             | ✅ | GRIN-MoE    | ✅ | MiniCPM3       | ✅ | PanGu-α        | ✅ | Minimax M2          | ✅ |
+| ERNIE 4.5         | ✅ | Hymba       | ✅ | Mistral        | ✅ | Qwen 1/2/3     | ✅ | GLM 4.X             | ✅ |
+| Brumby            | ✅ |             |   |                |   |                |   |                     |   |
## Platform and HW Support