30 changes: 16 additions & 14 deletions README.md
@@ -17,6 +17,7 @@
</p>

## Latest News
* 11/1/2025 5.1.0-dev: ✨Brumby (attention-free) model support.
* 10/31/2025 5.1.0-dev: ✨IBM Granite Nano support. New `calibration_concat_separator` config option.
* 10/30/2025 5.1.0-dev: 🎉AWQ support out of beta with full feature support, including multi-GPU quant and MoE VRAM saving.
* 10/30/2025 5.1.0-dev: ✨Marin model support. New AWQ Torch reference kernel. Fix AWQ Marlin kernel for bf16. Fix GLM 4.5/4.6 MoE missing `mtp` layers on model save (HF bug). Modular refactor.
@@ -180,20 +181,21 @@ Native support for some of the most popular multi-modal models:
<img src=https://github.com/user-attachments/assets/c1b89394-f8f6-44e5-9949-bef15a124723 width="51%"> <img src=https://github.com/user-attachments/assets/23901236-10c5-4435-ac2f-06cf2e097f1e width="47%">

## Model Support
| Model | | | | | | | | | |
|-------------------|---|-------------|---|----------------|---|----------------|---|---------------------|---|
| Apertus | ✅ | EXAONE 3.0 | ✅ | InternLM 1/2.5 | ✅ | Mixtral | ✅ | Qwen 2/3 (Next/MoE) | ✅ |
| Baichuan | ✅ | Falcon (H1) | ✅ | Kimi K2 | ✅ | MobileLLM | ✅ | Qwen 2/2.5/3 VL | ✅ |
| Bloom | ✅ | FastVLM | ✅ | Klear | ✅ | MOSS | ✅ | Qwen 2.5/3 Omni | ✅ |
| ChatGLM | ✅ | Gemma 1/2/3 | ✅ | LING/RING | ✅ | MPT | ✅ | RefinedWeb | ✅ |
| CodeGen | ✅ | GPTBigCod | ✅ | Llama 1-3.3 | ✅ | Nemotron H | ✅ | StableLM | ✅ |
| Cohere 1-2 | ✅ | GPTQ-Neo(X) | ✅ | Llama 3.2 VL | ✅ | Nemotron Ultra | ✅ | StarCoder2 | ✅ |
| DBRX Converted | ✅ | GPT-2 | ✅ | Llama 4 | ✅ | OPT | ✅ | TeleChat2 | ✅ |
| Deci | ✅ | GPT-J | ✅ | LongCatFlash | ✅ | OLMo2 | ✅ | Yi | ✅ |
| DeepSeek-V2/V3/R1 | ✅ | GPT-OSS | ✅ | LongLLaMA | ✅ | Ovis 1.6/2 | ✅ | Seed-OSS | ✅ |
| DeepSeek-V2-Lite | ✅ | Granite | ✅ | Instella | ✅ | Phi 1-4 | ✅ | XVERSE | ✅ |
| Dream | ✅ | GRIN-MoE | ✅ | MiniCPM3 | ✅ | PanGu-α | ✅ | Minimax M2 | ✅ |
| ERNIE 4.5 | ✅ | Hymba | ✅ | Mistral | ✅ | Qwen 1/2/3 | ✅ | GLM 4.X | ✅ |
| Model | | | | | | | | | |
|-------------------|---|-------------|---|---------------|--|-----------|--|-----------------|--|
| Apertus | ✅ | EXAONE 3.0 | ✅ | InternLM 1/2.5 | ✅ | Mixtral | ✅ | Qwen 2/3 (Next/MoE) | ✅ |
| Baichuan | ✅ | Falcon (H1) | ✅ | Kimi K2 | ✅ | MobileLLM | ✅ | Qwen 2/2.5/3 VL | ✅ |
| Bloom | ✅ | FastVLM | ✅ | Klear | ✅ | MOSS | ✅ | Qwen 2.5/3 Omni | ✅ |
| ChatGLM | ✅ | Gemma 1/2/3 | ✅ | LING/RING | ✅ | MPT | ✅ | RefinedWeb | ✅ |
| CodeGen | ✅ | GPTBigCod | ✅ | Llama 1-3.3 | ✅ | Nemotron H | ✅ | StableLM | ✅ |
| Cohere 1-2 | ✅ | GPTQ-Neo(X) | ✅ | Llama 3.2 VL | ✅ | Nemotron Ultra | ✅ | StarCoder2 | ✅ |
| DBRX Converted | ✅ | GPT-2 | ✅ | Llama 4 | ✅ | OPT | ✅ | TeleChat2 | ✅ |
| Deci | ✅ | GPT-J | ✅ | LongCatFlash | ✅ | OLMo2 | ✅ | Yi | ✅ |
| DeepSeek-V2/V3/R1 | ✅ | GPT-OSS | ✅ | LongLLaMA | ✅ | Ovis 1.6/2 | ✅ | Seed-OSS | ✅ |
| DeepSeek-V2-Lite | ✅ | Granite | ✅ | Instella | ✅ | Phi 1-4 | ✅ | XVERSE | ✅ |
| Dream | ✅ | GRIN-MoE | ✅ | MiniCPM3 | ✅ | PanGu-α | ✅ | Minimax M2 | ✅ |
| ERNIE 4.5 | ✅ | Hymba | ✅ | Mistral | ✅ | Qwen 1/2/3 | ✅ | GLM 4.X | ✅ |
| Brumby | ✅ | | | | | | | | |


## Platform and HW Support
2 changes: 2 additions & 0 deletions gptqmodel/models/auto.py
@@ -74,6 +74,7 @@
from .definitions.baichuan import BaiChuanQModel # noqa: E402
from .definitions.bailing_moe import BailingMoeQModel # noqa: E402
from .definitions.bloom import BloomQModel # noqa: E402
from .definitions.brumby import BrumbyQModel # noqa: E402
from .definitions.chatglm import ChatGLMQModel # noqa: E402
from .definitions.codegen import CodeGenQModel # noqa: E402
from .definitions.dbrx import DbrxQModel # noqa: E402
@@ -151,6 +152,7 @@
"apertus": ApertusQModel,
"dream": DreamQModel,
"bloom": BloomQModel,
"brumby": BrumbyQModel,
"gpt_neo": GptNeoQModel,
"kimi_k2": DeepSeekV3QModel, # 100% DeepSeekV3QModel clone
"klear": KlearQModel,
1 change: 1 addition & 0 deletions gptqmodel/models/definitions/__init__.py
@@ -10,6 +10,7 @@
# other model
from .baichuan import BaiChuanQModel
from .bloom import BloomQModel
from .brumby import BrumbyQModel
from .chatglm import ChatGLMQModel
from .codegen import CodeGenQModel
from .dbrx import DbrxQModel
45 changes: 45 additions & 0 deletions gptqmodel/models/definitions/brumby.py
@@ -0,0 +1,45 @@
# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai
# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai
# SPDX-License-Identifier: Apache-2.0
# Contact: qubitium@modelcloud.ai, x.com/qubitium

from ..base import BaseQModel


class BrumbyQModel(BaseQModel):
    require_trust_remote_code = True
    require_pkgs_version = ["retention>=1.0.7"]

    pre_lm_head_norm_module = "model.norm"

    module_tree = [
        "model",
        "layers",
        "#",
        {
            "input_layernorm": ("input_layernorm:0:!",),
            "self_attn": (
                "q_proj:0",
                "k_proj:0",
                "v_proj:0",
                "g_proj:0:!",
                "o_proj:1",
                "q_norm:0:!",
                "k_norm:0:!",
            ),
            "post_attention_layernorm": ("post_attention_layernorm:1:!",),
            "mlp": (
                "gate_proj:0",
                "up_proj:0",
                "down_proj:1",
            ),
        },
    ]

    def after_model_load(self, model, load_quantized_model=False):
        if hasattr(model, "config") and hasattr(model.config, "use_cache"):
            model.config.use_cache = False
        generation_config = getattr(model, "generation_config", None)
        if generation_config is not None and hasattr(generation_config, "use_cache"):
            generation_config.use_cache = False
        return model
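
For context, a minimal quantization sketch using the new Brumby definition, assuming the `GPTQModel.load` → `quantize` → `save` flow from the project README; the checkpoint path and calibration rows below are placeholders, not values from this PR:

```python
# Hedged sketch: quantize a Brumby checkpoint via the new BrumbyQModel definition.
# The checkpoint path and calibration texts are placeholders, not taken from this PR.
from gptqmodel import GPTQModel, QuantizeConfig

quant_config = QuantizeConfig(bits=4, group_size=32)  # group_size mirrors the test's GROUP_SIZE

model = GPTQModel.load(
    "path/to/Brumby-14B-Base",  # Brumby requires trust_remote_code and the `retention` package
    quant_config,
    trust_remote_code=True,
)

# BrumbyQModel.after_model_load() disables use_cache before calibration/quantization runs.
calibration = ["GPTQModel quantizes attention-free Brumby layers module by module."] * 256
model.quantize(calibration)
model.save("./Brumby-14B-Base-gptq-4bit")
```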
26 changes: 3 additions & 23 deletions gptqmodel/utils/hf.py
@@ -15,9 +15,6 @@

log = setup_logger()

GENERATION_SAMPLING_FIELDS = ("temperature", "top_p")


def _sanitize_generation_config(cfg: GenerationConfig, *, drop_sampling_fields: bool = False) -> bool:
    changed = False
    if cfg is None:
@@ -27,12 +24,6 @@ def _sanitize_generation_config(cfg: GenerationConfig, *, drop_sampling_fields:
        cfg.do_sample = True
        changed = True

    if drop_sampling_fields:
        for field in GENERATION_SAMPLING_FIELDS:
            if hasattr(cfg, field):
                if getattr(cfg, field) is not None:
                    changed = True
                setattr(cfg, field, None)
    return changed


@@ -43,18 +34,11 @@ def _load_sanitized_generation_config(path: str) -> Optional[GenerationConfig]:
        return None

    cleaned = dict(config_dict)
    removed = False
    for field in GENERATION_SAMPLING_FIELDS:
        if field in cleaned:
            cleaned.pop(field, None)
            removed = True
    if cleaned.get("do_sample") is not True:
        cleaned["do_sample"] = True

    cfg = GenerationConfig.from_dict(cleaned, **kwargs)
    if removed:
        log.info("Model: Removed unsupported sampling fields from `generation_config.json` during load.")
    _sanitize_generation_config(cfg, drop_sampling_fields=True)
    _sanitize_generation_config(cfg, drop_sampling_fields=False)
    return cfg


@@ -68,7 +52,7 @@ def autofix_hf_model_config(model: PreTrainedModel, path: str = None):
        cfg = _load_sanitized_generation_config(path)
        if cfg is None:
            cfg = GenerationConfig.from_pretrained(pretrained_model_name=path, do_sample=True)
            _sanitize_generation_config(cfg, drop_sampling_fields=True)
            _sanitize_generation_config(cfg, drop_sampling_fields=False)
        if cfg != model.generation_config:
            # migrated pad_token_id to config
            if hasattr(model.generation_config, "pad_token_id"):
@@ -90,7 +74,7 @@


def autofix_hf_generation_config(cfg: GenerationConfig):
    _sanitize_generation_config(cfg, drop_sampling_fields=True)
    _sanitize_generation_config(cfg, drop_sampling_fields=False)
    # HF has recently started to perform very strict validation model save which results in warnings on load()
    # to become exceptions on save().
    if cfg.do_sample is False:
@@ -125,10 +109,6 @@ def sanitize_generation_config_file(path: str) -> bool:
        return False

    changed = False
    for field in GENERATION_SAMPLING_FIELDS:
        if field in data:
            data.pop(field, None)
            changed = True

    if data.get("do_sample") is not True:
        data["do_sample"] = True
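
Net effect of the `hf.py` changes: `temperature` and `top_p` are no longer stripped from generation configs; the sanitizer now only forces `do_sample=True`. A paraphrased sketch of the surviving behavior (not the verbatim function):

```python
# Paraphrased sketch of the post-change sanitizer behavior; not the verbatim source.
from transformers import GenerationConfig

def sanitize(cfg: GenerationConfig) -> bool:
    """Force do_sample=True so strict HF validation passes; leave sampling fields intact."""
    if cfg is None:
        return False
    if cfg.do_sample is not True:
        cfg.do_sample = True
        return True
    return False

cfg = GenerationConfig(temperature=0.7, top_p=0.9)
changed = sanitize(cfg)
print(changed, cfg.temperature, cfg.top_p)  # True 0.7 0.9 -- sampling fields survive
```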
2 changes: 1 addition & 1 deletion scripts/eval_model.py
@@ -13,7 +13,7 @@

import gptqmodel
from tabulate import tabulate

from gptqmodel import GPTQModel
from gptqmodel.models.base import BaseQModel
from gptqmodel.utils.eval import EVAL

6 changes: 5 additions & 1 deletion tests/models/model_test.py
@@ -24,7 +24,7 @@

from enum import Enum # noqa: E402
from pathlib import Path # noqa: E402
from typing import Dict, List # noqa: E402
from typing import Any, Dict, List # noqa: E402

from logbar import LogBar # noqa: E402
from tabulate import tabulate # noqa: E402
@@ -95,6 +95,7 @@ class ModelTest(unittest.TestCase):
    DATASET_SORT = "desc"
    DELETE_QUANTIZED_MODEL = True
    EVAL_TASKS = None
    LOAD_MODEL_EXTRA_ARGS: Dict[str, Any] = {}

    KERNEL_QUANT = {}  # kernel sets
    KERNEL_INFERENCE = {}  # kernel sets
@@ -909,6 +910,9 @@ def quantModel(self, model_id_or_path, trust_remote_code=False, dtype="auto", ne
    def loadQuantModel(self, model_id_or_path, trust_remote_code=False, tokenizer_path=None, backend=None, **args):

        load_kwargs = dict(args)
        if self.LOAD_MODEL_EXTRA_ARGS:
            for key, value in self.LOAD_MODEL_EXTRA_ARGS.items():
                load_kwargs.setdefault(key, value)

        if self.USE_FLASH_ATTN:
            if is_flash_attn_2_available():
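
Because the merge uses `setdefault`, explicit `loadQuantModel(**args)` kwargs still override class-level `LOAD_MODEL_EXTRA_ARGS` defaults. A standalone illustration (the keys below are placeholders, not part of the test harness):

```python
# Illustration of the setdefault-based merge; keys are placeholders, not from the PR.
class_defaults = {"use_cache": False, "attn_implementation": "eager"}  # e.g. LOAD_MODEL_EXTRA_ARGS
call_kwargs = {"attn_implementation": "flash_attention_2"}             # kwargs passed by the caller

load_kwargs = dict(call_kwargs)
for key, value in class_defaults.items():
    load_kwargs.setdefault(key, value)  # fills only keys the caller did not supply

print(load_kwargs)  # {'attn_implementation': 'flash_attention_2', 'use_cache': False}
```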
45 changes: 45 additions & 0 deletions tests/models/test_brumby.py
@@ -0,0 +1,45 @@
# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai
# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai
# SPDX-License-Identifier: Apache-2.0
# Contact: qubitium@modelcloud.ai, x.com/qubitium

from model_test import ModelTest

from gptqmodel.utils.eval import EVAL
from gptqmodel.quantization import FORMAT, METHOD


class TestBrumby(ModelTest):
    GROUP_SIZE = 32
    DATASET_SIZE = 1024
    NATIVE_MODEL_ID = "/monster/data/model/Brumby-14B-Base"
    TRUST_REMOTE_CODE = True
    LOAD_MODEL_EXTRA_ARGS = {"use_cache": False}
    DATASET_CONCAT_SIZE = 2048
    EVAL_TASKS = {
        EVAL.LM_EVAL.GSM8K_PLATINUM_COT: {
            "chat_template": True,
            "exact_match,flexible-extract": {
                "value": 0.87,
                "floor_pct": 0.05,
                "ceil_pct": 0.10,
            },
        },
        EVAL.LM_EVAL.GSM8K_COT: {
            "chat_template": True,
            "exact_match,flexible-extract": {
                "value": 0.88,
                "floor_pct": 0.05,
                "ceil_pct": 0.10,
            },
        },
        EVAL.LM_EVAL.ARC_CHALLENGE: {
            "acc": {"value": 0.89, "floor_pct": 0.05, "ceil_pct": 0.10},
        },
        EVAL.LM_EVAL.MMLU: {
            "acc": {"value": 0.71, "floor_pct": 0.05, "ceil_pct": 0.10},
        },
    }

    def test_brumby(self):
        self.quant_lm_eval()
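
One reading of the `EVAL_TASKS` thresholds, assuming `floor_pct` and `ceil_pct` are relative tolerances around `value` (my assumption; the diff does not show how the harness consumes them):

```python
# Assumed interpretation of value/floor_pct/ceil_pct; the actual harness logic may differ.
def acceptable_range(value: float, floor_pct: float, ceil_pct: float) -> tuple[float, float]:
    return value * (1 - floor_pct), value * (1 + ceil_pct)

low, high = acceptable_range(0.87, 0.05, 0.10)  # GSM8K Platinum COT target from this test
print(f"accept exact_match in [{low:.4f}, {high:.4f}]")  # accept exact_match in [0.8265, 0.9570]
```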