From 40dcaf62ef9a7695a52904308a0ca401167b3542 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Fri, 3 Oct 2025 07:39:37 +0000 Subject: [PATCH 1/5] improve testing Signed-off-by: Qubitium --- tests/models/model_test.py | 196 +++++++++++++++++++++++++++++++++++-- 1 file changed, 189 insertions(+), 7 deletions(-) diff --git a/tests/models/model_test.py b/tests/models/model_test.py index a7af96da0..42e65945f 100644 --- a/tests/models/model_test.py +++ b/tests/models/model_test.py @@ -26,6 +26,7 @@ from typing import Dict, List # noqa: E402 from logbar import LogBar # noqa: E402 +from tabulate import tabulate # noqa: E402 sys.path.insert(0, f"{str(Path(__file__).resolve().parent.parent)}/models") # noqa: E402 @@ -101,6 +102,14 @@ class ModelTest(unittest.TestCase): LM_HEAD_LOSS_MAX_DELTA_PERCENT = 0.1 # ±10% EXPECT_LM_HEAD_LOSS = None + GENERIC_TEST_PROMPTS = [ + {"prompt": "Which city is the capital city of France?", "keywords": ["paris"]}, + {"prompt": "What is the smallest habitable planet in the milky way?", "keywords": ["earth"]}, + {"prompt": "Who wrote the play Romeo and Juliet?", "keywords": ["shakespeare"]}, + {"prompt": "What gas do plants primarily absorb from the atmosphere during photosynthesis?", "keywords": ["carbon dioxide"]}, + {"prompt": "Name the largest ocean on Earth.", "keywords": ["pacific"]}, + ] + def assertInference(self, model, tokenizer=None, keywords=None, prompt=INFERENCE_PROMPT): # gptqmodel can auto init tokenizer internally @@ -141,6 +150,176 @@ def generateChat(self, model, tokenizer, prompt=None): print(f"Result is: \n{output}") return output + def generate_with_limit(self, model, tokenizer, prompt, max_new_tokens=512): + inputs = tokenizer(prompt, return_tensors="pt").to(model.device) + pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id + generated = model.generate( + **inputs, + max_new_tokens=max_new_tokens, + do_sample=False, + num_beams=1, + pad_token_id=pad_token_id, + eos_token_id=tokenizer.eos_token_id, + ) + return tokenizer.decode(generated[0], skip_special_tokens=True) + + def run_generic_inference_checks(self, model, tokenizer, backend): + model.eval() + log.info(f"Post-quant inference checks for backend `{backend.name}`") + results = [] + for idx, item in enumerate(self.GENERIC_TEST_PROMPTS, start=1): + prompt = item["prompt"] + keywords = item["keywords"] + try: + response = self.generate_with_limit(model, tokenizer, prompt) + normalized = response.lower() + matched = any(keyword.lower() in normalized for keyword in keywords) + results.append( + { + "prompt": prompt, + "keywords": keywords, + "response": response, + "matched": matched, + } + ) + if matched: + log.info(f"[{backend.name}] Prompt {idx} PASS: `{prompt}`") + else: + snippet = response.replace("\n", " ")[:200] + log.error(f"[{backend.name}] Prompt {idx} MISS: `{prompt}` -> `{snippet}`") + except Exception as exc: # pragma: no cover - informative logging for test harness + log.error(f"[{backend.name}] Prompt {idx} ERROR: `{prompt}` -> {exc}") + results.append( + { + "prompt": prompt, + "keywords": keywords, + "response": str(exc), + "matched": False, + } + ) + return results + + def run_arc_challenge_eval(self, model, backend, trust_remote_code=False): + previous_backend = self.LOAD_BACKEND + self.LOAD_BACKEND = backend + try: + task_results = self.lm_eval( + model=model, + apply_chat_template=self.APPLY_CHAT_TEMPLATE, + trust_remote_code=trust_remote_code, + delete_quantized_model=False, + ) + log.info(f"[{backend.name}] ARC summary: 
{task_results}") + finally: + self.LOAD_BACKEND = previous_backend + return task_results + + def perform_post_quant_validation(self, model_path, trust_remote_code=False): + inference_records = {} + arc_records = {} + for backend in (BACKEND.MARLIN, BACKEND.TORCH): + log.info(f"Loading post-quant model with backend `{backend.name}`") + model = self.loadQuantModel( + model_path, + trust_remote_code=trust_remote_code, + backend=backend, + ) + tokenizer = model.tokenizer or self.load_tokenizer(model_path, trust_remote_code=trust_remote_code) + inference_records[backend] = self.run_generic_inference_checks(model, tokenizer, backend) + try: + arc_records[backend] = self.run_arc_challenge_eval(model, backend, trust_remote_code=trust_remote_code) + finally: + del model + torch_empty_cache() + self.render_inference_summary(inference_records) + self.render_arc_summary(arc_records) + + @staticmethod + def _colorize(text, matched): + color = "\033[92m" if matched else "\033[91m" + reset = "\033[0m" + return f"{color}{text}{reset}" + + def render_inference_summary(self, inference_records): + if not inference_records: + return + ordered_backends = [backend for backend in (BACKEND.MARLIN, BACKEND.TORCH) if backend in inference_records] + if not ordered_backends: + return + + prompts = [item["prompt"] for item in self.GENERIC_TEST_PROMPTS] + table_rows = [] + sanity_scores = {} + for backend in ordered_backends: + entries = {entry["prompt"]: entry for entry in inference_records[backend]} + matched_count = sum(1 for entry in entries.values() if entry.get("matched")) + total_count = len(entries) if entries else 1 + sanity_scores[backend] = (matched_count, total_count) + + for prompt in prompts: + row = [prompt, ", ".join(self._normalize_keyword_case(k) for k in self._keywords_for_prompt(prompt))] + for backend in ordered_backends: + entry = next((item for item in inference_records[backend] if item["prompt"] == prompt), None) + if entry is None: + row.append(self._colorize("N/A", False)) + continue + matched = entry["matched"] + snippet = entry["response"].replace("\n", " ")[:80] + row.append(self._colorize(f"{'PASS' if matched else 'MISS'} | {snippet}", matched)) + table_rows.append(row) + + headers = ["Prompt", "Expected Keywords"] + [backend.name for backend in ordered_backends] + log.info("Sanity prompt comparison:\n%s", tabulate(table_rows, headers=headers, tablefmt="github")) + + for backend, (matched, total) in sanity_scores.items(): + score_pct = 100.0 * matched / max(total, 1) + result_text = f"{matched}/{total} ({score_pct:.1f}%)" + log.info("Sanity score [%s]: %s", backend.name, result_text) + + @staticmethod + def _normalize_keyword_case(keyword): + return keyword.lower() + + def _keywords_for_prompt(self, prompt): + for item in self.GENERIC_TEST_PROMPTS: + if item["prompt"] == prompt: + return item["keywords"] + return [] + + def render_arc_summary(self, arc_records): + if not arc_records: + return + ordered_backends = [backend for backend in (BACKEND.MARLIN, BACKEND.TORCH) if backend in arc_records] + if not ordered_backends: + return + + metrics = set() + for results in arc_records.values(): + metrics.update(results.keys()) + metrics = sorted(metrics) + + table_rows = [] + tolerance = 0.01 + torch_reference = arc_records.get(BACKEND.TORCH, {}) + + for metric in metrics: + row = [metric] + reference_value = torch_reference.get(metric) + for backend in ordered_backends: + value = arc_records[backend].get(metric) + if value is None: + row.append(self._colorize("N/A", False)) + continue + if 
backend == BACKEND.TORCH: + row.append(self._colorize(f"{value:.4f}", True)) + else: + matched = reference_value is not None and abs(value - reference_value) <= tolerance + row.append(self._colorize(f"{value:.4f}", matched)) + table_rows.append(row) + + headers = ["Metric"] + [backend.name for backend in ordered_backends] + log.info("ARC challenge comparison:\n%s", tabulate(table_rows, headers=headers, tablefmt="github")) + def load_tokenizer(self, model_id_or_path, trust_remote_code=False): tokenizer = AutoTokenizer.from_pretrained(model_id_or_path, trust_remote_code=trust_remote_code) return tokenizer @@ -159,7 +338,7 @@ def load_dataset(cls, tokenizer=None, rows: int = 0): @staticmethod def _load_calibration_parquet(): - parquet_path = Path("~/nm-calibration/llm.parquet").expanduser() + parquet_path = Path("/monster/data/_ci_/nm-calibration/llm.parquet").expanduser() if not parquet_path.exists(): raise FileNotFoundError(f"Calibration parquet not found at {parquet_path}") @@ -287,6 +466,7 @@ def quantModel(self, model_id_or_path, trust_remote_code=False, dtype="auto", ne model.save(path) tokenizer.save_pretrained(path) log.info(f"Quantized Model saved to tmp dir: {path}") + self.perform_post_quant_validation(path, trust_remote_code=trust_remote_code) q_model = self.loadQuantModel(path, trust_remote_code=trust_remote_code) q_tokenizer = q_model.tokenizer if need_create_processor: @@ -308,21 +488,23 @@ def quantModel(self, model_id_or_path, trust_remote_code=False, dtype="auto", ne else: return model, tokenizer - def loadQuantModel(self, model_id_or_path, trust_remote_code=False, tokenizer_path=None, **args): + def loadQuantModel(self, model_id_or_path, trust_remote_code=False, tokenizer_path=None, backend=None, **args): - kargs = args if args else {} + load_kwargs = dict(args) if self.USE_FLASH_ATTN: - args["attn_implementation"] = "flash_attention_2" + load_kwargs["attn_implementation"] = "flash_attention_2" + + active_backend = backend if backend is not None else self.LOAD_BACKEND model = GPTQModel.load( model_id_or_path, trust_remote_code=trust_remote_code, - backend=self.LOAD_BACKEND, - device_map={"": "cpu"} if self.LOAD_BACKEND == BACKEND.TORCH_FUSED else "auto", + backend=active_backend, + device_map={"": "cpu"} if active_backend == BACKEND.TORCH_FUSED else "auto", debug=self.DEBUG, adapter=self.EORA, - **kargs + **load_kwargs ) return model From 0f8aaddf786889c063bd5af37a9bbde3a27abbc4 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Fri, 3 Oct 2025 07:51:55 +0000 Subject: [PATCH 2/5] clean up output Signed-off-by: Qubitium --- tests/models/model_test.py | 47 +++++++++++++++++++++++++++----------- 1 file changed, 34 insertions(+), 13 deletions(-) diff --git a/tests/models/model_test.py b/tests/models/model_test.py index 42e65945f..68c9e15b7 100644 --- a/tests/models/model_test.py +++ b/tests/models/model_test.py @@ -34,6 +34,7 @@ import shutil # noqa: E402 import tempfile # noqa: E402 import unittest # noqa: E402 +import textwrap # noqa: E402 from collections.abc import Iterable # noqa: E402 import torch.cuda # noqa: E402 @@ -182,11 +183,15 @@ def run_generic_inference_checks(self, model, tokenizer, backend): "matched": matched, } ) + snippet = self._summarize_response(response, width=160) if matched: - log.info(f"[{backend.name}] Prompt {idx} PASS: `{prompt}`") + log.info( + f"[{backend.name}] Prompt {idx} PASS: `{prompt}` -> `{snippet}`" + ) else: - snippet = response.replace("\n", " ")[:200] - log.error(f"[{backend.name}] Prompt {idx} MISS: `{prompt}` -> `{snippet}`") + log.error( 
+ f"[{backend.name}] Prompt {idx} MISS: `{prompt}` -> `{snippet}`" + ) except Exception as exc: # pragma: no cover - informative logging for test harness log.error(f"[{backend.name}] Prompt {idx} ERROR: `{prompt}` -> {exc}") results.append( @@ -248,7 +253,6 @@ def render_inference_summary(self, inference_records): return prompts = [item["prompt"] for item in self.GENERIC_TEST_PROMPTS] - table_rows = [] sanity_scores = {} for backend in ordered_backends: entries = {entry["prompt"]: entry for entry in inference_records[backend]} @@ -256,20 +260,21 @@ def render_inference_summary(self, inference_records): total_count = len(entries) if entries else 1 sanity_scores[backend] = (matched_count, total_count) + log.info("Sanity prompt comparison:") for prompt in prompts: - row = [prompt, ", ".join(self._normalize_keyword_case(k) for k in self._keywords_for_prompt(prompt))] + expected = ", ".join( + self._normalize_keyword_case(k) for k in self._keywords_for_prompt(prompt) + ) + lines = [f"Prompt: {prompt}", f" Expected: {expected or 'None'}"] for backend in ordered_backends: entry = next((item for item in inference_records[backend] if item["prompt"] == prompt), None) if entry is None: - row.append(self._colorize("N/A", False)) + lines.append(f" {backend.name:<6}: {self._colorize('N/A', False)}") continue - matched = entry["matched"] - snippet = entry["response"].replace("\n", " ")[:80] - row.append(self._colorize(f"{'PASS' if matched else 'MISS'} | {snippet}", matched)) - table_rows.append(row) - - headers = ["Prompt", "Expected Keywords"] + [backend.name for backend in ordered_backends] - log.info("Sanity prompt comparison:\n%s", tabulate(table_rows, headers=headers, tablefmt="github")) + lines.append( + f" {backend.name:<6}: {self._format_inference_entry(entry)}" + ) + log.info("\n".join(lines)) for backend, (matched, total) in sanity_scores.items(): score_pct = 100.0 * matched / max(total, 1) @@ -286,6 +291,21 @@ def _keywords_for_prompt(self, prompt): return item["keywords"] return [] + @staticmethod + def _summarize_response(response, width=80): + clean = " ".join(response.split()) if response else "" + if not clean: + return "" + return textwrap.shorten(clean, width=width, placeholder="…") + + def _format_inference_entry(self, entry): + matched = entry.get("matched", False) + response = entry.get("response", "") + snippet = self._summarize_response(response) + status = "PASS" if matched else "MISS" + cell = f"{status} | {snippet}" if snippet else status + return self._colorize(cell, matched) + def render_arc_summary(self, arc_records): if not arc_records: return @@ -539,6 +559,7 @@ def lm_eval(self, model, apply_chat_template=False, trust_remote_code=False, del llm_backend="vllm" if self.USE_VLLM else "gptqmodel", model_args=model_args, output_path=tmp_dir, + backend=self.LOAD_BACKEND, framework=framework, tasks=tasks, apply_chat_template=apply_chat_template, From c6c6dda51ef2585880929cdaa6224d4700218b96 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Fri, 3 Oct 2025 08:01:15 +0000 Subject: [PATCH 3/5] AWQ needs v2 to v1 conversion? 
Signed-off-by: Qubitium --- gptqmodel/models/writer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gptqmodel/models/writer.py b/gptqmodel/models/writer.py index 393103c64..dd8168a1c 100644 --- a/gptqmodel/models/writer.py +++ b/gptqmodel/models/writer.py @@ -226,7 +226,7 @@ def save_quantized( if not self.load_quantized_model: model = self.model # # internal is always gptq v2 but allow users to pass gptq (v1) via config - if quantize_config.format == FORMAT.GPTQ or quantize_config.format == FORMAT.GEMM: + if quantize_config.format in (FORMAT.GPTQ): # or quantize_config.format == FORMAT.GEMM: # Model qzeros may be edited in place. model = convert_gptq_v2_to_v1_format( model, quantize_config=quantize_config, qlinear_kernel=self.qlinear_kernel From 73bc50ec83bbe183bb4ee454b933c65922f978f2 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Fri, 3 Oct 2025 08:18:09 +0000 Subject: [PATCH 4/5] FIX v2 to v1 conversion regression in latest refractors Signed-off-by: Qubitium --- gptqmodel/looper/gptq_processor.py | 1 + gptqmodel/looper/qqq_processor.py | 1 + gptqmodel/models/loader.py | 28 ++++---- gptqmodel/models/writer.py | 11 +++- gptqmodel/utils/model.py | 36 ++++++++++- tests/test_format_conversion_flow.py | 96 ++++++++++++++++++++++++++++ 6 files changed, 154 insertions(+), 19 deletions(-) create mode 100644 tests/test_format_conversion_flow.py diff --git a/gptqmodel/looper/gptq_processor.py b/gptqmodel/looper/gptq_processor.py index cd5604471..0ca1b5713 100644 --- a/gptqmodel/looper/gptq_processor.py +++ b/gptqmodel/looper/gptq_processor.py @@ -246,6 +246,7 @@ def submodule_finalize(self, module: NamedModule, model: BaseQModel, **kwargs): layers=layers, quant_linear_cls=model.qlinear_kernel, lock=self.lock, + quantize_config=self.qcfg, ) # TODO: store module quant results in module, not global processor result diff --git a/gptqmodel/looper/qqq_processor.py b/gptqmodel/looper/qqq_processor.py index f60eef4ba..05d14d99f 100644 --- a/gptqmodel/looper/qqq_processor.py +++ b/gptqmodel/looper/qqq_processor.py @@ -230,6 +230,7 @@ def submodule_finalize(self, module: NamedModule, model: BaseQModel, **kwargs): quant_linear_cls=QQQQuantLinear, lock=self.lock, q_scales_extra=q_scales_extra, + quantize_config=self.qcfg, ) # TODO: store module quant results in module, not global processor result diff --git a/gptqmodel/models/loader.py b/gptqmodel/models/loader.py index 271048e2b..c1a2a4f9b 100644 --- a/gptqmodel/models/loader.py +++ b/gptqmodel/models/loader.py @@ -546,22 +546,24 @@ def skip(*args, **kwargs): offload_state_dict=True, offload_buffers=True, ) - # validate sym=False v1 loading needs to be protected for models produced with new v2 format codebase - if not qcfg.sym and not qcfg.is_quantized_by_v2(): - raise ValueError( - f"Format: Loading of a sym=False model with format={FORMAT.GPTQ} is only supported if produced by gptqmodel version >= {MIN_VERSION_WITH_V2}" - ) - - model = convert_gptq_v1_to_v2_format( - model, - cfg=qcfg, - qlinear_kernel=preload_qlinear_kernel, - ) load_checkpoint_in_model = False - if preload_qlinear_kernel.REQUIRES_FORMAT_V2: - qcfg.runtime_format = FORMAT.GPTQ_V2 + if qcfg.format == FORMAT.GPTQ: + # validate sym=False v1 loading needs to be protected for models produced with new v2 format codebase + if not qcfg.sym and not qcfg.is_quantized_by_v2(): + raise ValueError( + f"Format: Loading of a sym=False model with format={FORMAT.GPTQ} is only supported if produced by gptqmodel version >= {MIN_VERSION_WITH_V2}" + ) + + if 
preload_qlinear_kernel.REQUIRES_FORMAT_V2: + model = convert_gptq_v1_to_v2_format( + model, + cfg=qcfg, + qlinear_kernel=preload_qlinear_kernel, + ) + + qcfg.runtime_format = FORMAT.GPTQ_V2 if backend in [BACKEND.MARLIN, BACKEND.MARLIN_FP16] and ( preload_qlinear_kernel == ExllamaV2QuantLinear or qcfg.format == FORMAT.MARLIN): diff --git a/gptqmodel/models/writer.py b/gptqmodel/models/writer.py index dd8168a1c..cad838658 100644 --- a/gptqmodel/models/writer.py +++ b/gptqmodel/models/writer.py @@ -25,6 +25,7 @@ from ..adapter.peft import LoraConfig from ..quantization.config import ( FORMAT, + METHOD, META_FIELD_ACT_GROUP_AWARE, META_FIELD_DAMP_AUTO_INCREMENT, META_FIELD_DAMP_PERCENT, @@ -225,9 +226,13 @@ def save_quantized( if not self.load_quantized_model: model = self.model - # # internal is always gptq v2 but allow users to pass gptq (v1) via config - if quantize_config.format in (FORMAT.GPTQ): # or quantize_config.format == FORMAT.GEMM: - # Model qzeros may be edited in place. + # internal is always gptq v2 but allow users to pass gptq (v1) via config + if ( + quantize_config.format == FORMAT.GPTQ + and quantize_config.quant_method == METHOD.GPTQ + and self.qlinear_kernel.REQUIRES_FORMAT_V2 + ): + # Model qzeros may be edited in place for export compatibility. model = convert_gptq_v2_to_v1_format( model, quantize_config=quantize_config, qlinear_kernel=self.qlinear_kernel ) diff --git a/gptqmodel/utils/model.py b/gptqmodel/utils/model.py index 0c644e6c2..3dd7275ea 100644 --- a/gptqmodel/utils/model.py +++ b/gptqmodel/utils/model.py @@ -673,7 +673,19 @@ def convert_gptq_v2_to_v1_format( return model -def pack_module(name, qModules, q_scales, q_zeros, q_g_idx, layers, quant_linear_cls, lock: threading.Lock, q_scales_extra = None): +def pack_module( + name, + qModules, + q_scales, + q_zeros, + q_g_idx, + layers, + quant_linear_cls, + lock: threading.Lock, + q_scales_extra=None, + quantize_config: Optional[QuantizeConfig] = None, + quant_result: Optional[Dict[str, Any]] = None, +): # Limit pack() thread usage to avoid auto-parallizataion regression with tctl.threadpool_limits(limits=1): with lock: @@ -702,6 +714,17 @@ def pack_module(name, qModules, q_scales, q_zeros, q_g_idx, layers, quant_linear else: module.pack(linear=layer, scales=q_scales, zeros=q_zeros, g_idx=q_g_idx) + if ( + quantize_config is not None + and quantize_config.quant_method == METHOD.GPTQ + and quantize_config.format == FORMAT.GPTQ + and getattr(quant_linear_cls, "REQUIRES_FORMAT_V2", False) + ): + convert_gptq_v2_to_v1_format_module( + module=module, + quantize_config=quantize_config, + ) + # TODO: why move it back to gpu? 
# start = time.time() # qModules[name].to(layer_device) @@ -767,8 +790,15 @@ def wrapper(name): # TODO FIX, thread pool executor does not advance iterator pb.next() pb.title(f"Packing {name}").draw() - pack_module(name=name, qModules=qModules, quant_result=quant_result, layers=modules, - quant_linear_cls=quant_linear_cls, lock=lock) + pack_module( + name=name, + qModules=qModules, + quant_result=quant_result, + layers=modules, + quant_linear_cls=quant_linear_cls, + lock=lock, + quantize_config=qcfg, + ) for _ in executor.map(wrapper, names): pass diff --git a/tests/test_format_conversion_flow.py b/tests/test_format_conversion_flow.py new file mode 100644 index 000000000..561887e8d --- /dev/null +++ b/tests/test_format_conversion_flow.py @@ -0,0 +1,96 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +import threading +from unittest import mock + +import torch + +from gptqmodel.quantization import FORMAT, METHOD, QuantizeConfig +from gptqmodel.utils.model import pack_module + + +class _DummyLayer: + def __init__(self): + self.weight = torch.nn.Parameter(torch.zeros(1, 1)) + + def to(self, *_args, **_kwargs): + return self + + +class _DummyQuantModule: + def __init__(self): + self.bits = 4 + self.pack_dtype = torch.int32 + + QUANT_TYPE = "gptq" + + def to(self, *_args, **_kwargs): + return self + + def pack(self, **_kwargs): + pass + + def qzero_format(self, format: int | None = None): + if format is not None: + self._fmt = format + return getattr(self, "_fmt", 2) + + +def _make_quant_linear_cls(requires_v2: bool): + return type( + "DummyQuantLinear", + (), + { + "QUANT_TYPE": "gptq", + "REQUIRES_FORMAT_V2": requires_v2, + }, + ) + + +def _run_pack(quant_cfg: QuantizeConfig, requires_v2: bool) -> int: + dummy_module = _DummyQuantModule() + qmodules = {"layer": dummy_module} + layers = {"layer": _DummyLayer()} + q_scales = torch.zeros(1, 1) + q_zeros = torch.zeros(1, 1, dtype=torch.int32) + q_g_idx = torch.zeros(1, dtype=torch.int32) + lock = threading.Lock() + + quant_linear_cls = _make_quant_linear_cls(requires_v2=requires_v2) + assert getattr(quant_linear_cls, "REQUIRES_FORMAT_V2") is requires_v2 + + with mock.patch("gptqmodel.utils.model.convert_gptq_v2_to_v1_format_module") as convert_mock: + pack_module( + name="layer", + qModules=qmodules, + q_scales=q_scales, + q_zeros=q_zeros, + q_g_idx=q_g_idx, + layers=layers, + quant_linear_cls=quant_linear_cls, + lock=lock, + quantize_config=quant_cfg, + ) + + return convert_mock.call_count + + +def test_pack_module_converts_for_gptq_requires_v2(): + cfg = QuantizeConfig(bits=4, quant_method=METHOD.GPTQ, format=FORMAT.GPTQ, offload_to_disk=False) + calls = _run_pack(cfg, requires_v2=True) + assert calls == 1 + + +def test_pack_module_skips_for_non_gptq_method(): + cfg = QuantizeConfig(bits=4, quant_method=METHOD.AWQ, format=FORMAT.GEMM, offload_to_disk=False) + calls = _run_pack(cfg, requires_v2=True) + assert calls == 0 + + +def test_pack_module_skips_when_kernel_uses_v1(): + cfg = QuantizeConfig(bits=4, quant_method=METHOD.GPTQ, format=FORMAT.GPTQ, offload_to_disk=False) + calls = _run_pack(cfg, requires_v2=False) + assert calls == 0 From 9bb5a06eaa5050fdd7799df68d3bc0a86bcc8386 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Fri, 3 Oct 2025 08:29:19 +0000 Subject: [PATCH 5/5] clean generation arg to avoid transformer warnings Signed-off-by: Qubitium --- gptqmodel/models/writer.py | 5 +++ 
gptqmodel/utils/hf.py | 75 +++++++++++++++++++++++++++++++++++++- 2 files changed, 79 insertions(+), 1 deletion(-) diff --git a/gptqmodel/models/writer.py b/gptqmodel/models/writer.py index cad838658..69b6d7bdc 100644 --- a/gptqmodel/models/writer.py +++ b/gptqmodel/models/writer.py @@ -41,6 +41,7 @@ MIN_VERSION_WITH_V2, ) from ..utils.backend import BACKEND +from ..utils.hf import sanitize_generation_config_file from ..utils.logger import setup_logger from ..utils.model import ( convert_gptq_v2_to_v1_format, @@ -251,6 +252,10 @@ def save_quantized( # Use empty state_dict hack to bypass saving weights self.model.save_pretrained(save_dir, state_dict={}, is_main_process=True) + gen_config_path = os.path.join(save_dir, "generation_config.json") + if sanitize_generation_config_file(gen_config_path): + log.info("Model: Sanitized `generation_config.json` before packaging.") + # Save `quantize_config.json` quantize_config.save_pretrained(save_dir) diff --git a/gptqmodel/utils/hf.py b/gptqmodel/utils/hf.py index c28ff1ea9..434b8fea9 100644 --- a/gptqmodel/utils/hf.py +++ b/gptqmodel/utils/hf.py @@ -3,6 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 # Contact: qubitium@modelcloud.ai, x.com/qubitium +import json from typing import Any, Optional import torch @@ -14,6 +15,49 @@ log = setup_logger() +GENERATION_SAMPLING_FIELDS = ("temperature", "top_p") + + +def _sanitize_generation_config(cfg: GenerationConfig, *, drop_sampling_fields: bool = False) -> bool: + changed = False + if cfg is None: + return changed + + if getattr(cfg, "do_sample", None) is not True: + cfg.do_sample = True + changed = True + + if drop_sampling_fields: + for field in GENERATION_SAMPLING_FIELDS: + if hasattr(cfg, field): + if getattr(cfg, field) is not None: + changed = True + setattr(cfg, field, None) + return changed + + +def _load_sanitized_generation_config(path: str) -> Optional[GenerationConfig]: + try: + config_dict, kwargs = GenerationConfig.get_config_dict(path) + except Exception: + return None + + cleaned = dict(config_dict) + removed = False + for field in GENERATION_SAMPLING_FIELDS: + if field in cleaned: + cleaned.pop(field, None) + removed = True + if cleaned.get("do_sample") is not True: + cleaned["do_sample"] = True + + cfg = GenerationConfig.from_dict(cleaned, **kwargs) + if removed: + log.info("Model: Removed unsupported sampling fields from `generation_config.json` during load.") + _sanitize_generation_config(cfg, drop_sampling_fields=True) + return cfg + + # TODO FIXME! 
Pre-quantized use AutoModelForCausalLM.from_pretrained() but post-quantized use AutoModelForCausalLM.from_config() def autofix_hf_model_config(model: PreTrainedModel, path: str = None): if model.can_generate(): @@ -21,7 +65,10 @@ def autofix_hf_model_config(model: PreTrainedModel, path: str = None): if path: log.info(f"Model: Loaded `generation_config`: {model.generation_config}") try: - cfg = GenerationConfig.from_pretrained(pretrained_model_name=path) + cfg = _load_sanitized_generation_config(path) + if cfg is None: + cfg = GenerationConfig.from_pretrained(pretrained_model_name=path, do_sample=True) + _sanitize_generation_config(cfg, drop_sampling_fields=True) if cfg != model.generation_config: # migrated pad_token_id to config if hasattr(model.generation_config, "pad_token_id"): @@ -41,7 +88,9 @@ def autofix_hf_model_config(model: PreTrainedModel, path: str = None): autofix_hf_generation_config(model.generation_config) # print(f"After autofix_hf_model_config: {model.generation_config}") + def autofix_hf_generation_config(cfg: GenerationConfig): + _sanitize_generation_config(cfg, drop_sampling_fields=True) # HF has recently started to perform very strict validation model save which results in warnings on load() # to become exceptions on save(). if cfg.do_sample is False: @@ -67,6 +116,30 @@ def autofix_hf_generation_config(cfg: GenerationConfig): cfg.do_sample = True log.info("Model: Auto-Fixed `generation_config` by setting `do_sample=True`.") + +def sanitize_generation_config_file(path: str) -> bool: + try: + with open(path, "r", encoding="utf-8") as fp: + data = json.load(fp) + except FileNotFoundError: + return False + + changed = False + for field in GENERATION_SAMPLING_FIELDS: + if field in data: + data.pop(field, None) + changed = True + + if data.get("do_sample") is not True: + data["do_sample"] = True + changed = True + + if changed: + with open(path, "w", encoding="utf-8") as fp: + json.dump(data, fp, indent=2) + + return changed + # load hf model with empty tensors on meta device (zero tensor memory usage) def build_shell_model( loader,
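For context on the v1/v2 conversion that PATCH 3 and PATCH 4 gate more carefully: the GPTQ v1 checkpoint layout (the AutoGPTQ-compatible FORMAT.GPTQ) is commonly described as storing zero-points with a -1 bias, while the internal v2 layout required by kernels flagged REQUIRES_FORMAT_V2 keeps the true zero-points. The snippet below is a minimal sketch of that offset on *unpacked* integer zero-points only; the -1/+1 direction is an assumption stated here for illustration, and the real convert_gptq_v1_to_v2_format / convert_gptq_v2_to_v1_format helpers in gptqmodel/utils/model.py operate on the packed qzeros tensors and handle bit packing, so treat this as a sketch of the convention rather than the actual implementation.

    import torch


    def v2_to_v1_unpacked(zeros_v2: torch.Tensor) -> torch.Tensor:
        # Assumed convention: v1 (AutoGPTQ-compatible) checkpoints store zero-points minus one.
        return zeros_v2 - 1


    def v1_to_v2_unpacked(zeros_v1: torch.Tensor) -> torch.Tensor:
        # Loading a v1 checkpoint for a REQUIRES_FORMAT_V2 kernel re-applies the offset.
        return zeros_v1 + 1


    zeros = torch.tensor([[8, 8, 7, 9]], dtype=torch.int32)  # true (v2) zero-points
    assert torch.equal(v1_to_v2_unpacked(v2_to_v1_unpacked(zeros)), zeros)  # round-trip is lossless

PATCH 4 makes exactly this round-trip conditional: pack and save apply the v2-to-v1 step only when the quant method and format are GPTQ and the active kernel reports REQUIRES_FORMAT_V2, which is what tests/test_format_conversion_flow.py asserts.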
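PATCH 5 adds sanitize_generation_config_file in gptqmodel/utils/hf.py, which rewrites a saved generation_config.json so that the sampling-only fields (temperature, top_p) are dropped and do_sample is forced to True before the quantized model is packaged. A small usage sketch follows; the temporary path and example config values are illustrative and not taken from the patches.

    import json
    import tempfile
    from pathlib import Path

    from gptqmodel.utils.hf import sanitize_generation_config_file

    with tempfile.TemporaryDirectory() as tmp:
        cfg_path = Path(tmp) / "generation_config.json"
        # Example config of the kind that trips transformers' strict save-time validation.
        cfg_path.write_text(json.dumps({"do_sample": False, "temperature": 0.6, "top_p": 0.9}))

        changed = sanitize_generation_config_file(str(cfg_path))

        print(changed)                           # True: the file needed fixing
        print(json.loads(cfg_path.read_text()))  # {'do_sample': True}; temperature/top_p removed

The same normalization is applied in-memory during load via _load_sanitized_generation_config, so a config cleaned on disk and a config cleaned at load time end up equivalent.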