diff --git a/gptqmodel/looper/awq_processor.py b/gptqmodel/looper/awq_processor.py index 06ecdc64a..e3fc542d7 100644 --- a/gptqmodel/looper/awq_processor.py +++ b/gptqmodel/looper/awq_processor.py @@ -48,14 +48,34 @@ class _AWQLayerState: lock: threading.Lock = field(default_factory=threading.Lock) class AWQProcessor(LoopProcessor): - def __init__(self, tokenizer, qcfg: QuantizeConfig, calibration, prepare_dataset_func, - calibration_concat_size: Optional[int], calibration_sort: Optional[str], batch_size: int, gptq_model, model, - require_fwd: bool = True, calculate_w_wq_diff: bool = False): - - super().__init__(tokenizer=tokenizer, qcfg=qcfg, calibration=calibration, - calibration_concat_size=calibration_concat_size, calibration_sort=calibration_sort, - prepare_dataset_func=prepare_dataset_func, batch_size=batch_size, - require_fwd=require_fwd, fwd_after_process=False) + def __init__( + self, + tokenizer, + qcfg: QuantizeConfig, + calibration, + prepare_dataset_func, + calibration_concat_size: Optional[int], + calibration_sort: Optional[str], + batch_size: int, + gptq_model, + model, + require_fwd: bool = True, + calculate_w_wq_diff: bool = False, + calibration_concat_separator: Optional[str] = None, + ): + + super().__init__( + tokenizer=tokenizer, + qcfg=qcfg, + calibration=calibration, + calibration_concat_size=calibration_concat_size, + calibration_sort=calibration_sort, + calibration_concat_separator=calibration_concat_separator, + prepare_dataset_func=prepare_dataset_func, + batch_size=batch_size, + require_fwd=require_fwd, + fwd_after_process=False, + ) self.calculate_w_wq_diff = calculate_w_wq_diff self.avg_losses = [] diff --git a/gptqmodel/looper/eora_processor.py b/gptqmodel/looper/eora_processor.py index b31737cdf..46aafe616 100644 --- a/gptqmodel/looper/eora_processor.py +++ b/gptqmodel/looper/eora_processor.py @@ -28,15 +28,29 @@ class EoraProcessor(LoopProcessor): - def __init__(self, tokenizer, qcfg: QuantizeConfig, calibration, prepare_dataset_func, - calibration_concat_size: Optional[int], calibration_sort: Optional[str], batch_size: int, - require_fwd: bool = True - ): - super().__init__(tokenizer=tokenizer, qcfg=qcfg, calibration=calibration, - calibration_concat_size=calibration_concat_size, - calibration_sort=calibration_sort, - prepare_dataset_func=prepare_dataset_func, batch_size=batch_size, - require_fwd=require_fwd) + def __init__( + self, + tokenizer, + qcfg: QuantizeConfig, + calibration, + prepare_dataset_func, + calibration_concat_size: Optional[int], + calibration_sort: Optional[str], + batch_size: int, + require_fwd: bool = True, + calibration_concat_separator: Optional[str] = None, + ): + super().__init__( + tokenizer=tokenizer, + qcfg=qcfg, + calibration=calibration, + calibration_concat_size=calibration_concat_size, + calibration_sort=calibration_sort, + calibration_concat_separator=calibration_concat_separator, + prepare_dataset_func=prepare_dataset_func, + batch_size=batch_size, + require_fwd=require_fwd, + ) # Track per-module segment accumulators keyed by device so we can merge # contributions without repeatedly moving data through the CPU. 
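Because each processor now threads `calibration_concat_separator` straight into `prepare_dataset_func`, any custom dataset-preparation callback has to accept the keyword as well; the test doubles later in this diff simply add it with a `None` default. A minimal sketch of a compatible callback, using the keyword names passed by `loop_processor.py` (the function name and pass-through body are illustrative only):

def my_prepare_dataset(
    *,
    calibration_dataset,
    calibration_dataset_concat_size,
    calibration_dataset_sort,
    batch_size,
    calibration_concat_separator=None,  # new keyword; the default keeps older callers working
):
    # Real callbacks tokenize/concatenate here; returning the input unchanged is enough for tests.
    return calibration_dataset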
diff --git a/gptqmodel/looper/gptq_processor.py b/gptqmodel/looper/gptq_processor.py index ca843dc96..d0427288c 100644 --- a/gptqmodel/looper/gptq_processor.py +++ b/gptqmodel/looper/gptq_processor.py @@ -31,15 +31,31 @@ lock = threading.Lock() class GPTQProcessor(LoopProcessor): - def __init__(self, tokenizer, qcfg: QuantizeConfig, calibration, prepare_dataset_func, - calibration_concat_size: Optional[int], calibration_sort: Optional[str], batch_size: int, - require_fwd: bool = True, calculate_w_wq_diff: bool = False): - - super().__init__(tokenizer=tokenizer, qcfg=qcfg, calibration=calibration, - calibration_concat_size=calibration_concat_size, - calibration_sort=calibration_sort, - prepare_dataset_func=prepare_dataset_func, batch_size=batch_size, - require_fwd=require_fwd) + def __init__( + self, + tokenizer, + qcfg: QuantizeConfig, + calibration, + prepare_dataset_func, + calibration_concat_size: Optional[int], + calibration_sort: Optional[str], + batch_size: int, + require_fwd: bool = True, + calculate_w_wq_diff: bool = False, + calibration_concat_separator: Optional[str] = None, + ): + + super().__init__( + tokenizer=tokenizer, + qcfg=qcfg, + calibration=calibration, + calibration_concat_size=calibration_concat_size, + calibration_sort=calibration_sort, + calibration_concat_separator=calibration_concat_separator, + prepare_dataset_func=prepare_dataset_func, + batch_size=batch_size, + require_fwd=require_fwd, + ) self.calculate_w_wq_diff = calculate_w_wq_diff self.avg_losses = [] diff --git a/gptqmodel/looper/loop_processor.py b/gptqmodel/looper/loop_processor.py index c995208ca..87a552488 100644 --- a/gptqmodel/looper/loop_processor.py +++ b/gptqmodel/looper/loop_processor.py @@ -57,6 +57,7 @@ def __init__( prepare_dataset_func: Optional[Callable] = None, calibration_concat_size: Optional[int] = None, calibration_sort: Optional[str] = None, + calibration_concat_separator: Optional[str] = None, batch_size: int = 1, require_fwd: bool = True, fwd_after_process: bool = True, @@ -128,10 +129,13 @@ def __init__( if prepare_dataset_func is None: raise ValueError("prepare_dataset_func must be provided when calibration data is supplied.") - calibration = prepare_dataset_func(calibration_dataset=calibration, - calibration_dataset_concat_size=calibration_concat_size, - calibration_dataset_sort=calibration_sort, - batch_size=batch_size) + calibration = prepare_dataset_func( + calibration_dataset=calibration, + calibration_dataset_concat_size=calibration_concat_size, + calibration_dataset_sort=calibration_sort, + batch_size=batch_size, + calibration_concat_separator=calibration_concat_separator, + ) # Calculate the average length of the average input_ids total_input_ids_length = 0 diff --git a/gptqmodel/looper/native_processor.py b/gptqmodel/looper/native_processor.py index 9d5ad255c..44fc81f6e 100644 --- a/gptqmodel/looper/native_processor.py +++ b/gptqmodel/looper/native_processor.py @@ -21,16 +21,32 @@ # v2 requires that we also need to capture/store non-quantized inputs class NativeProcessor(LoopProcessor): - def __init__(self, tokenizer, qcfg: QuantizeConfig, calibration, prepare_dataset_func, - calibration_concat_size: Optional[int], calibration_sort: Optional[str], batch_size: int, - require_fwd: bool = True): - - super().__init__(tokenizer=tokenizer, qcfg=qcfg, calibration=calibration, - calibration_concat_size=calibration_concat_size, - calibration_sort=calibration_sort, - prepare_dataset_func=prepare_dataset_func, batch_size=batch_size, - require_fwd=require_fwd, 
fwd_after_process=False, - fwd_all_modules_in_single_pass=True) + def __init__( + self, + tokenizer, + qcfg: QuantizeConfig, + calibration, + prepare_dataset_func, + calibration_concat_size: Optional[int], + calibration_sort: Optional[str], + batch_size: int, + require_fwd: bool = True, + calibration_concat_separator: Optional[str] = None, + ): + + super().__init__( + tokenizer=tokenizer, + qcfg=qcfg, + calibration=calibration, + calibration_concat_size=calibration_concat_size, + calibration_sort=calibration_sort, + calibration_concat_separator=calibration_concat_separator, + prepare_dataset_func=prepare_dataset_func, + batch_size=batch_size, + require_fwd=require_fwd, + fwd_after_process=False, + fwd_all_modules_in_single_pass=True, + ) self.native_inp_caches = {} diff --git a/gptqmodel/looper/qqq_processor.py b/gptqmodel/looper/qqq_processor.py index 319e390b8..272f360be 100644 --- a/gptqmodel/looper/qqq_processor.py +++ b/gptqmodel/looper/qqq_processor.py @@ -26,14 +26,31 @@ log = setup_logger() class QQQProcessor(LoopProcessor): - def __init__(self, tokenizer, qcfg: QuantizeConfig, calibration, prepare_dataset_func, - calibration_concat_size: Optional[int], calibration_sort: Optional[str], batch_size: int, - require_fwd: bool = True, calculate_w_wq_diff: bool = False): - - super().__init__(tokenizer=tokenizer, qcfg=qcfg, calibration=calibration, - calibration_concat_size=calibration_concat_size, calibration_sort=calibration_sort, - prepare_dataset_func=prepare_dataset_func, batch_size=batch_size, - require_fwd=require_fwd) + def __init__( + self, + tokenizer, + qcfg: QuantizeConfig, + calibration, + prepare_dataset_func, + calibration_concat_size: Optional[int], + calibration_sort: Optional[str], + batch_size: int, + require_fwd: bool = True, + calculate_w_wq_diff: bool = False, + calibration_concat_separator: Optional[str] = None, + ): + + super().__init__( + tokenizer=tokenizer, + qcfg=qcfg, + calibration=calibration, + calibration_concat_size=calibration_concat_size, + calibration_sort=calibration_sort, + calibration_concat_separator=calibration_concat_separator, + prepare_dataset_func=prepare_dataset_func, + batch_size=batch_size, + require_fwd=require_fwd, + ) self.calculate_w_wq_diff = calculate_w_wq_diff self.avg_losses = [] diff --git a/gptqmodel/models/_const.py b/gptqmodel/models/_const.py index cdaad86e4..2085d44bb 100644 --- a/gptqmodel/models/_const.py +++ b/gptqmodel/models/_const.py @@ -131,5 +131,3 @@ def get_best_device(backend: BACKEND = BACKEND.AUTO) -> torch.device: EXLLAMA_DEFAULT_MAX_INPUT_LENGTH = 2048 EXPERT_INDEX_PLACEHOLDER = "{expert_index}" - -CALIBRATION_DATASET_CONCAT_CHAR = " " diff --git a/gptqmodel/models/auto.py b/gptqmodel/models/auto.py index 569005336..e4cd0af16 100644 --- a/gptqmodel/models/auto.py +++ b/gptqmodel/models/auto.py @@ -96,6 +96,7 @@ from .definitions.gpt_neox import GPTNeoXQModel # noqa: E402 from .definitions.gpt_oss import GPTOSSGPTQ # noqa: E402 from .definitions.gptj import GptJQModel # noqa: E402 +from .definitions.granitemoehybrid import GraniteMoeHybridQModel from .definitions.grinmoe import GrinMoeQModel # noqa: E402 from .definitions.hymba import HymbaQModel # noqa: E402 from .definitions.instella import InstellaQModel # noqa: E402 @@ -139,7 +140,6 @@ from .definitions.starcoder2 import Starcoder2QModel # noqa: E402 from .definitions.telechat2 import TeleChat2QModel from .definitions.xverse import XverseQModel # noqa: E402 -from .definitions.granitemoehybrid import GraniteMoeHybridQModel # make quants and inference more 
determinisitc @@ -692,6 +692,7 @@ def generate( calibration_dataset_sort: Optional[str] = None, batch_size: Optional[int] = 1, tokenizer: Optional[PreTrainedTokenizerBase] = None, + calibration_concat_separator: Optional[str] = None, # pass-through vars for load() trust_remote_code: bool = False, dtype: Optional[Union[str, torch.dtype]] = None, @@ -736,5 +737,6 @@ def generate( calibration_dataset_sort=calibration_dataset_sort, batch_size=batch_size, tokenizer=tokenizer, + calibration_concat_separator=calibration_concat_separator, ) return diff --git a/gptqmodel/models/base.py b/gptqmodel/models/base.py index 27508d769..870d8cae7 100644 --- a/gptqmodel/models/base.py +++ b/gptqmodel/models/base.py @@ -7,12 +7,10 @@ import copy import json import os -import random import re import threading import time from collections import defaultdict -from collections.abc import Sequence from contextlib import nullcontext from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Type, Union @@ -46,7 +44,7 @@ from ..quantization.config import FORMAT, METHOD, QUANTIZE_BLACK_LIST, VRAMStrategy, dynamic_get from ..quantization.rotation.rotation import fuse_layer_norms, rotate_model from ..utils.backend import BACKEND -from ..utils.data import collate_data +from ..utils.calibration import prepare_calibration_dataset from ..utils.device import get_device from ..utils.hf import autofix_hf_model_config from ..utils.importer import select_quant_linear @@ -56,7 +54,6 @@ from ..utils.structure import alias_from_turtle_for_submodule from ..utils.torch import TORCH_HAS_COMPILE, torch_compile from ._const import ( - CALIBRATION_DATASET_CONCAT_CHAR, CPU, DEFAULT_MAX_SHARD_SIZE, DEVICE, @@ -413,393 +410,22 @@ def prepare_dataset( "HFDatasetType", "HFIterableDatasetType", ], - # Setting a fixed calibration_dataset_concat_size may improve the performance of the quantized model. 
calibration_dataset_concat_size: Optional[int] = None, calibration_dataset_sort: Optional[str] = None, batch_size: int = 1, calibration_data_min_length: int = 10, + calibration_concat_separator: Optional[str] = None, ): - hf_dataset_types: tuple = () - if HFDataset is not None: - hf_dataset_types += (HFDataset,) - if HFIterableDataset is not None: - hf_dataset_types += (HFIterableDataset,) - - if isinstance(calibration_dataset, str): - raise ValueError("Quantize: calibration dataset must be iterable, not a single string.") - - if hf_dataset_types and isinstance(calibration_dataset, hf_dataset_types): - raw_examples = list(calibration_dataset) - elif isinstance(calibration_dataset, list): - raw_examples = calibration_dataset - elif isinstance(calibration_dataset, Sequence) and not isinstance(calibration_dataset, (bytes, bytearray)): - raw_examples = list(calibration_dataset) - else: - raw_examples = list(calibration_dataset) - - if len(raw_examples) == 0: - raise ValueError("Quantize: calibration dataset is empty.") - - def _require_tokenizer(reason: str) -> None: - if self.tokenizer is None: - raise ValueError(f"tokenizer must be provided when {reason}.") - - def _to_2d_long_tensor(value: Any, name: str, idx: int) -> torch.Tensor: - try: - tensor = torch.as_tensor(value, dtype=torch.long) - except Exception as exc: - raise ValueError(f"Quantize: failed to convert `{name}` to tensor for calibration item {idx}.") from exc - - if tensor.ndim == 0: - raise ValueError(f"Quantize: `{name}` for calibration item {idx} must be 1D or 2D, got scalar.") - if tensor.ndim == 1: - tensor = tensor.unsqueeze(0) - elif tensor.ndim != 2: - raise ValueError( - f"Quantize: `{name}` for calibration item {idx} must be rank 1 or 2, got rank {tensor.ndim}." - ) - return tensor - - def _pack_ids(ids_value: Any, mask_value: Any, idx: int) -> Dict[str, torch.Tensor]: - ids_tensor = _to_2d_long_tensor(ids_value, "input_ids", idx) - - if mask_value is None: - mask_tensor = torch.ones_like(ids_tensor, dtype=torch.long) - else: - mask_tensor = _to_2d_long_tensor(mask_value, "attention_mask", idx) - if mask_tensor.shape != ids_tensor.shape: - if mask_tensor.numel() == ids_tensor.numel(): - mask_tensor = mask_tensor.reshape(ids_tensor.shape) - else: - raise ValueError( - f"Quantize: attention_mask shape {tuple(mask_tensor.shape)} does not match input_ids shape " - f"{tuple(ids_tensor.shape)} for calibration item {idx}." 
- ) - - return { - "input_ids": ids_tensor.detach(), - "attention_mask": mask_tensor.detach(), - } - - def _tokenize_text_value(text_value: Any, idx: int) -> Dict[str, torch.Tensor]: - _require_tokenizer("calibration data contains raw text") - tokenized = self.tokenizer( - text_value, - add_special_tokens=True, - return_tensors="pt", - ) - input_ids = tokenized["input_ids"] - attention_mask = tokenized.get("attention_mask") - return _pack_ids(input_ids, attention_mask, idx) - - def _tokenize_messages_value(messages_value: Any, idx: int) -> Dict[str, torch.Tensor]: - _require_tokenizer("calibration data uses the `messages` feature") - apply_fn = getattr(self.tokenizer, "apply_template", None) - if apply_fn is None: - raise ValueError("tokenizer must expose `apply_template` to handle `messages` calibration data.") - try: - templated = apply_fn(messages_value, tokenize=False) - except TypeError: - templated = apply_fn(messages_value) - - if templated is None: - raise ValueError(f"tokenizer.apply_template returned None for calibration item {idx}.") - - if hasattr(templated, "get"): - ids_value = templated.get("input_ids") - mask_value = templated.get("attention_mask") - text_value = templated.get("text") - if ids_value is not None: - return _pack_ids(ids_value, mask_value, idx) - if text_value is not None: - return _tokenize_text_value(text_value, idx) - - if isinstance(templated, (list, tuple)): - if len(templated) > 0 and isinstance(templated[0], int): - return _pack_ids(list(templated), None, idx) - raise ValueError( - f"tokenizer.apply_template returned an unsupported sequence type for calibration item {idx}." - ) - - if torch.is_tensor(templated): - return _pack_ids(templated, None, idx) - - if isinstance(templated, str): - return _tokenize_text_value(templated, idx) - - raise ValueError( - f"tokenizer.apply_template returned unsupported type {type(templated)} for calibration item {idx}." - ) - - processed_examples: List[Dict[str, torch.Tensor]] = [] - for idx, example in enumerate(raw_examples): - if isinstance(example, dict): - if "messages" in example: - apply_fn = getattr(self.tokenizer, "apply_template", None) if self.tokenizer else None - if apply_fn is None: - if "text" in example: - processed_examples.append(_tokenize_text_value(example["text"], idx)) - continue - raise ValueError( - "tokenizer must expose `apply_template` or calibration data must provide `text` when using `messages`." - ) - processed_examples.append(_tokenize_messages_value(example["messages"], idx)) - continue - if "text" in example: - processed_examples.append(_tokenize_text_value(example["text"], idx)) - continue - if "input_ids" in example: - processed_examples.append(_pack_ids(example["input_ids"], example.get("attention_mask"), idx)) - continue - raise ValueError( - f"Quantize: unsupported calibration example structure at index {idx}: keys={list(example.keys())}" - ) - - if isinstance(example, str): - processed_examples.append(_tokenize_text_value(example, idx)) - continue - - if isinstance(example, (list, tuple)): - if all(isinstance(x, int) for x in example): - processed_examples.append(_pack_ids(list(example), None, idx)) - continue - raise ValueError( - f"Quantize: list-based calibration example at index {idx} must contain only integers." 
- ) - - if torch.is_tensor(example): - processed_examples.append(_pack_ids(example, None, idx)) - continue - - try: - processed_examples.append(_pack_ids(example, None, idx)) - except Exception as exc: - raise ValueError( - f"Quantize: unsupported calibration example type {type(example)} at index {idx}." - ) from exc - - calibration_dataset = processed_examples - - def _convert_tensor_to_list(tensor): - if isinstance(tensor, torch.Tensor): - if len(tensor.shape) == 1: - tensor = tensor.unsqueeze(0) - tensor = tensor.long() - return tensor.cpu().numpy().tolist() - return [tensor] - - new_calibration_dataset = [] - too_short_calibration_data_count = 0 - - max_positions = None - max_positions_source = None - trimmed_row_count = 0 - longest_trimmed_row = 0 - - def _maybe_resolve_length(value, source_name): - nonlocal max_positions, max_positions_source - try: - if value is None: - return False - limit = int(value) - except Exception: - return False - if limit <= 0: - return False - if max_positions is None or limit < max_positions: - max_positions = limit - max_positions_source = source_name - return True - - model_config = getattr(self.model, "config", None) - if model_config is not None: - primary_names = ("max_position_embeddings",) - fallback_names = ( - "max_sequence_length", - "max_seq_len", - "n_positions", - "seq_length", - ) - - for attr_name in primary_names: - if _maybe_resolve_length(getattr(model_config, attr_name, None), attr_name): - break - if max_positions is None: - for attr_name in fallback_names: - if _maybe_resolve_length(getattr(model_config, attr_name, None), attr_name): - break - - for example in calibration_dataset: - input_ids = _convert_tensor_to_list(example["input_ids"]) - attention_mask = _convert_tensor_to_list(example["attention_mask"]) - - if max_positions is not None: - trimmed = False - trimmed_input_ids = [] - trimmed_attention_mask = [] - - for row_ids, row_mask in zip(input_ids, attention_mask): - row_len = len(row_ids) - if row_len > max_positions: - trimmed = True - trimmed_row_count += 1 - longest_trimmed_row = max(longest_trimmed_row, row_len) - trimmed_input_ids.append(row_ids[:max_positions]) - trimmed_attention_mask.append(row_mask[:max_positions]) - else: - trimmed_input_ids.append(row_ids) - trimmed_attention_mask.append(row_mask) - - if trimmed: - input_ids = trimmed_input_ids - attention_mask = trimmed_attention_mask - - # filter if input_ids is too short - if len(input_ids[0]) <= calibration_data_min_length: - too_short_calibration_data_count += 1 - continue - - new_calibration_dataset.append( - { - "input_ids": input_ids, - "attention_mask": attention_mask, - } - ) - - if too_short_calibration_data_count > 0: - log.warn(f"Quantize: {too_short_calibration_data_count} input_ids with length <= {calibration_data_min_length} were removed. 
" - f"Use quantize(calibration_data_min_length={calibration_data_min_length}) to set a custom minimum length.") - - if trimmed_row_count > 0: - log.info( - "Quantize: trimmed %s calibration rows above %s=%s (longest original length=%s)", - trimmed_row_count, - max_positions_source, - max_positions, - longest_trimmed_row, - ) - - if calibration_dataset_concat_size: - _require_tokenizer("`calibration_dataset_concat_size` is specified") - concatenated_data = [] - input_ids_buff = [] - attention_mask_buff = [] - current_length = 0 - - new_line = self.tokenizer(CALIBRATION_DATASET_CONCAT_CHAR, return_tensors="pt") - new_line_input_ids = _convert_tensor_to_list(new_line["input_ids"])[0] - new_line_attention_mask = _convert_tensor_to_list(new_line["attention_mask"])[0] - new_line_input_ids_len = len(new_line_input_ids) - - for example in new_calibration_dataset: - input_ids = example["input_ids"][0] - attention_mask = example["attention_mask"][0] - - if current_length + len(input_ids) + new_line_input_ids_len >= calibration_dataset_concat_size: - if len(input_ids_buff) > 0: - remaining_space = calibration_dataset_concat_size - current_length - # if there is remaining space, add the remaining input to the current block - if remaining_space > 0: - input_ids_buff.extend(new_line_input_ids) - input_ids_buff.extend(input_ids[:remaining_space - new_line_input_ids_len]) - attention_mask_buff.extend(new_line_attention_mask) - attention_mask_buff.extend(attention_mask[:remaining_space - new_line_input_ids_len]) - - concatenated_data.append({ - "input_ids": [input_ids_buff], - "attention_mask": [attention_mask_buff] - }) - else: - # if there is no remaining space, add the current block to the concatenated data - concatenated_data.append({ - "input_ids": [input_ids_buff], - "attention_mask": [attention_mask_buff] - }) - - input_ids_buff = input_ids[:calibration_dataset_concat_size] - attention_mask_buff = attention_mask[:calibration_dataset_concat_size] - current_length = len(input_ids_buff) - else: - input_ids_buff = input_ids[:calibration_dataset_concat_size] - attention_mask_buff = attention_mask[:calibration_dataset_concat_size] - current_length = len(input_ids_buff) - else: - if len(input_ids_buff) > 0: - input_ids_buff.extend(new_line_input_ids) - attention_mask_buff.extend(new_line_attention_mask) - current_length += new_line_input_ids_len - - input_ids_buff.extend(input_ids) - attention_mask_buff.extend(attention_mask) - current_length += len(input_ids) - - - if input_ids_buff: - padding_length = calibration_dataset_concat_size - len(input_ids_buff) - if padding_length > 0: - input_ids_buff.extend([self.tokenizer.pad_token_id] * padding_length) - attention_mask_buff.extend([0] * padding_length) - concatenated_data.append({ - "input_ids": [input_ids_buff], - "attention_mask": [attention_mask_buff] - }) - - new_calibration_dataset = concatenated_data - - # Sort or shuffle calibration dataset - if calibration_dataset_sort == "asc": - log.info("Calibration: Sort in ascending order by length") - sorted_dataset = sorted( - new_calibration_dataset, - key=lambda item: len(item["input_ids"][0]) - ) - elif calibration_dataset_sort == "desc": - log.info("Calibration: Sort in descending order by length") - sorted_dataset = sorted( - new_calibration_dataset, - key=lambda item: len(item["input_ids"][0]), - reverse=True - ) - elif calibration_dataset_sort == "shuffle": - log.info("Calibration: Sort by random shuffle") - sorted_dataset = new_calibration_dataset[:] # shallow copy - random.shuffle(sorted_dataset) - 
else: - log.info("Calibration: Native order") - sorted_dataset = new_calibration_dataset # fallback: no sort - - if self.support_batch_quantize: - new_calibration_dataset_batched = [ - collate_data(sorted_dataset[start: start + batch_size], self.tokenizer.pad_token_id) - for start in range(0, len(sorted_dataset), batch_size) - ] - - # total tokens counters - total_padded = 0 - total_non_padded = 0 - - for batch in new_calibration_dataset_batched: - # attention_mask is shape [batch_size, seq_len] - mask = batch["attention_mask"] - - # count where mask == 0 (padded tokens) - total_padded += (mask == 0).sum().item() - - # count where mask == 1 (non-padded tokens) - total_non_padded += (mask == 1).sum().item() - - log.info(f"Calibration: Total padded tokens: {total_padded}") - log.info(f"Calibration: Total non-padded tokens: {total_non_padded}") - log.info(f"Calibration: Total tokens: {total_non_padded + total_padded}") - else: - new_calibration_dataset_batched = [ - { - "input_ids": torch.tensor(block["input_ids"], dtype=torch.long), - } - for block in sorted_dataset - ] - - return new_calibration_dataset_batched + return prepare_calibration_dataset( + self, + calibration_dataset=calibration_dataset, + calibration_dataset_concat_size=calibration_dataset_concat_size, + calibration_dataset_sort=calibration_dataset_sort, + batch_size=batch_size, + calibration_data_min_length=calibration_data_min_length, + calibration_concat_separator=calibration_concat_separator, + logger=log, + ) def quantize( self, @@ -815,6 +441,7 @@ def quantize( adapter_calibration_dataset: Union[List[Dict[str, Union[List[int], torch.LongTensor]]], List[str], List[int]] = None, # minimum length of calibration data, default is 10 calibration_data_min_length: int = 10, + calibration_concat_separator: Optional[str] = None, ) -> Dict[str, List[Dict[str, str]]]: if self.quantized: raise EnvironmentError("quantize() is called a model that is already quantized") @@ -909,6 +536,7 @@ def quantize( "prepare_dataset_func": self.prepare_dataset, "calibration_concat_size": calibration_concat_size, "calibration_sort": calibration_sort, + "calibration_concat_separator": calibration_concat_separator, "batch_size": batch_size, "calculate_w_wq_diff": needs_lora, # lora needs original w - wq delta } @@ -1013,6 +641,7 @@ def quantize( prepare_dataset_func=self.prepare_dataset, calibration_concat_size=calibration_concat_size, calibration_sort=calibration_sort, + calibration_concat_separator=calibration_concat_separator, batch_size=batch_size, ) ) @@ -1041,6 +670,7 @@ def _eora_generate( calibration_dataset_sort: Optional[str] = None, batch_size: int = 1, tokenizer: Optional[PreTrainedTokenizerBase] = None, + calibration_concat_separator: Optional[str] = None, ): if self.quantized: raise EnvironmentError("eora_generate() is called a model that is already quantized") @@ -1073,6 +703,7 @@ def _eora_generate( prepare_dataset_func=self.prepare_dataset, calibration_concat_size=calibration_dataset_concat_size, calibration_sort=calibration_dataset_sort, + calibration_concat_separator=calibration_concat_separator, batch_size=batch_size, ), DequantizeProcessor( @@ -1085,6 +716,7 @@ def _eora_generate( prepare_dataset_func=self.prepare_dataset, calibration_concat_size=calibration_dataset_concat_size, calibration_sort=calibration_dataset_sort, + calibration_concat_separator=calibration_concat_separator, batch_size=batch_size, ), ] diff --git a/gptqmodel/models/loader.py b/gptqmodel/models/loader.py index e1dd85b98..949489f68 100644 --- 
a/gptqmodel/models/loader.py +++ b/gptqmodel/models/loader.py @@ -8,11 +8,11 @@ import os import time from importlib.metadata import PackageNotFoundError, version +from itertools import chain from typing import Dict, List, Optional, Union import torch import transformers -from itertools import chain from ..utils.structure import print_module_tree diff --git a/gptqmodel/quantization/gptq.py b/gptqmodel/quantization/gptq.py index 59b341a15..7321d5c98 100644 --- a/gptqmodel/quantization/gptq.py +++ b/gptqmodel/quantization/gptq.py @@ -559,35 +559,69 @@ def hf_quantize( @torch.inference_mode() def hessian_inverse(self, H: torch.Tensor): - damp = self.qcfg.damp_percent - mean = torch.mean(torch.diag(H)) - - orig_diag = H.diag().clone() - while 0 < damp < 1: - try: - H.diagonal().add_(damp * mean) - H2 = torch.linalg.cholesky(H) - Hinv = torch.linalg.cholesky(torch.cholesky_inverse(H2), upper=True) - H.diagonal().copy_(orig_diag) - del H2 - break - except torch._C._LinAlgError as e: - H.diagonal().copy_(orig_diag) - if self.qcfg.damp_auto_increment != 0: + # Capture a writable view of the Hessian diagonal so we can restore it between attempts. + diag_view = H.diagonal() + orig_diag = diag_view.clone() + + # When a block is numerically singular, pure damping can stall at 1.0. + # Prepare a tiny diagonal floor (relative to the largest entry) that we + # only inject if the normal damping loop fails. Keeping the scale near 1e-6 + # of the dominant entry keeps the bias negligible for healthy layers while + # still rescuing pathological Hessian blocks. + base_abs_max = torch.max(orig_diag.abs()).item() + if not math.isfinite(base_abs_max) or base_abs_max == 0.0: + base_abs_max = 1.0 + floor_base = base_abs_max * 1e-6 + max_floor_attempts = 6 + used_damp = self.qcfg.damp_percent + last_error = None + + attempt = 0 + while attempt <= max_floor_attempts: + if attempt == 0: + current_diag = orig_diag + else: + floor_increment = floor_base * math.pow(10.0, attempt - 1) + current_diag = torch.clamp(orig_diag + floor_increment, min=floor_increment) + if attempt == 1: log.warn( - f"Quantization: Module `{self.name}` -> Current `damp_percent = {damp:.5f}` is too low, auto-incrementing by `{self.qcfg.damp_auto_increment:.5f}`") - damp += self.qcfg.damp_auto_increment + f"Quantization: Module `{self.name}` -> Applying Hessian diagonal floor (+{floor_increment:.2e}) to recover positive definiteness.") else: log.warn( - "Quantization: Module `{self.name}` -> Please increase damp or nsamples for calibration data to avoid the following quant error: current damp_percent=`{damp:.5f}`") - raise e + f"Quantization: Module `{self.name}` -> Increasing Hessian diagonal floor to +{floor_increment:.2e}.") + + diag_view.copy_(current_diag) + mean = torch.mean(current_diag) + damp = self.qcfg.damp_percent + + while 0 < damp < 1: + try: + diag_view.add_(damp * mean) + H2 = torch.linalg.cholesky(H) + Hinv_result = torch.linalg.cholesky(torch.cholesky_inverse(H2), upper=True) + diag_view.copy_(current_diag) + del H2 + used_damp = damp + return Hinv_result, used_damp + except torch._C._LinAlgError as e: + last_error = e + diag_view.copy_(current_diag) + if self.qcfg.damp_auto_increment != 0: + log.warn( + f"Quantization: Module `{self.name}` -> Current `damp_percent = {damp:.5f}` is too low, auto-incrementing by `{self.qcfg.damp_auto_increment:.5f}`") + damp += self.qcfg.damp_auto_increment + else: + log.warn( + f"Quantization: Module `{self.name}` -> Hessian Cholesky failed with `damp_percent={damp:.5f}` and no auto increment 
configured.") + break - if not (0 < damp < 1): - log.error( - f"Quantization: Module `{self.name}` -> `damp_percent` must between 0 and 1. current is {damp}. Module cannot be correctly processed.") - return None, 1.0 + attempt += 1 - return Hinv, damp + log.error( + f"Quantization: Module `{self.name}` -> Hessian remained non positive-definite after diagonal floor attempts. Last `damp_percent` tried = {damp:.5f}.") + if last_error is not None: + log.debug(f"Hessian failure detail: {last_error}") + return None, 1.0 @torch.inference_mode() def quantize( diff --git a/gptqmodel/utils/calibration.py b/gptqmodel/utils/calibration.py index ba2e4ad76..de7fcb4c7 100644 --- a/gptqmodel/utils/calibration.py +++ b/gptqmodel/utils/calibration.py @@ -3,15 +3,462 @@ # SPDX-License-Identifier: Apache-2.0 # Contact: qubitium@modelcloud.ai, x.com/qubitium -def batched(iterable, n: int, process_func): - # batched('ABCDEFG', 3) → ABC DEF G - assert n >= 1, "batch size must be at least one" - from itertools import islice +"""Utilities for preparing calibration datasets used during quantization.""" - iterator = iter(iterable) +from __future__ import annotations - while batch := tuple(islice(iterator, n)): - if process_func is None: +import random +from typing import Any, Dict, List, Optional, Sequence, Union + +import torch + +from .data import collate_data +from .logger import setup_logger + + +try: # pragma: no cover - optional dependency + from datasets import Dataset as HFDataset + from datasets import IterableDataset as HFIterableDataset +except Exception: # pragma: no cover - handled dynamically + HFDataset = HFIterableDataset = None + + +CalibrationInputType = Union[ + List[Dict[str, Union[List[int], torch.LongTensor]]], + List[str], + List[List[int]], + "HFDataset", # type: ignore[type-arg] + "HFIterableDataset", # type: ignore[type-arg] +] + + +def batched(iterable, batch_size: int, process_func=None): + """Yield fixed-size batches from ``iterable`` after optional processing.""" + + if batch_size <= 0: + raise ValueError("batch_size must be positive") + + batch = [] + for item in iterable: + processed = process_func(item) if process_func is not None else item + batch.append(processed) + if len(batch) == batch_size: yield batch + batch = [] + + if batch: + yield batch + + +def prepare_calibration_dataset( + qmodel, + calibration_dataset: CalibrationInputType, + calibration_dataset_concat_size: Optional[int] = None, + calibration_dataset_sort: Optional[str] = None, + batch_size: int = 1, + calibration_data_min_length: int = 10, + calibration_concat_separator: Optional[str] = None, + logger=None, +): + """Normalize, validate, and batch calibration samples for quantization. + + Parameters mirror ``BaseQModel.prepare_dataset`` so existing code paths can + delegate directly to this helper. 
+ """ + + log = logger or setup_logger() + + tokenizer = getattr(qmodel, "tokenizer", None) + support_batch_quantize = getattr(qmodel, "support_batch_quantize", True) + + hf_dataset_types: tuple = () + if HFDataset is not None: + hf_dataset_types += (HFDataset,) + if HFIterableDataset is not None: + hf_dataset_types += (HFIterableDataset,) + + if isinstance(calibration_dataset, str): + raise ValueError("Quantize: calibration dataset must be iterable, not a single string.") + + if hf_dataset_types and isinstance(calibration_dataset, hf_dataset_types): + raw_examples = list(calibration_dataset) + elif isinstance(calibration_dataset, list): + raw_examples = calibration_dataset + elif isinstance(calibration_dataset, Sequence) and not isinstance(calibration_dataset, (bytes, bytearray)): + raw_examples = list(calibration_dataset) + else: + raw_examples = list(calibration_dataset) + + if len(raw_examples) == 0: + raise ValueError("Quantize: calibration dataset is empty.") + + def _require_tokenizer(reason: str) -> None: + if tokenizer is None: + raise ValueError(f"tokenizer must be provided when {reason}.") + + def _to_2d_long_tensor(value: Any, name: str, idx: int) -> torch.Tensor: + try: + tensor = torch.as_tensor(value, dtype=torch.long) + except Exception as exc: # pragma: no cover - defensive + raise ValueError(f"Quantize: failed to convert `{name}` to tensor for calibration item {idx}.") from exc + + if tensor.ndim == 0: + raise ValueError(f"Quantize: `{name}` for calibration item {idx} must be 1D or 2D, got scalar.") + if tensor.ndim == 1: + tensor = tensor.unsqueeze(0) + elif tensor.ndim != 2: + raise ValueError( + f"Quantize: `{name}` for calibration item {idx} must be rank 1 or 2, got rank {tensor.ndim}." + ) + return tensor + + def _pack_ids(ids_value: Any, mask_value: Any, idx: int) -> Dict[str, torch.Tensor]: + ids_tensor = _to_2d_long_tensor(ids_value, "input_ids", idx) + + if mask_value is None: + mask_tensor = torch.ones_like(ids_tensor, dtype=torch.long) + else: + mask_tensor = _to_2d_long_tensor(mask_value, "attention_mask", idx) + if mask_tensor.shape != ids_tensor.shape: + if mask_tensor.numel() == ids_tensor.numel(): + mask_tensor = mask_tensor.reshape(ids_tensor.shape) + else: + raise ValueError( + f"Quantize: attention_mask shape {tuple(mask_tensor.shape)} does not match input_ids shape " + f"{tuple(ids_tensor.shape)} for calibration item {idx}." 
+ ) + + return { + "input_ids": ids_tensor.detach(), + "attention_mask": mask_tensor.detach(), + } + + def _tokenize_text_value(text_value: Any, idx: int) -> Dict[str, torch.Tensor]: + _require_tokenizer("calibration data contains raw text") + tokenized = tokenizer( # type: ignore[call-arg] + text_value, + add_special_tokens=True, + return_tensors="pt", + ) + input_ids = tokenized["input_ids"] + attention_mask = tokenized.get("attention_mask") + return _pack_ids(input_ids, attention_mask, idx) + + def _tokenize_messages_value(messages_value: Any, idx: int) -> Dict[str, torch.Tensor]: + _require_tokenizer("calibration data uses the `messages` feature") + apply_fn = getattr(tokenizer, "apply_template", None) + if apply_fn is None: + raise ValueError("tokenizer must expose `apply_template` to handle `messages` calibration data.") + try: + templated = apply_fn(messages_value, tokenize=False) + except TypeError: + templated = apply_fn(messages_value) + + if templated is None: + raise ValueError(f"tokenizer.apply_template returned None for calibration item {idx}.") + + if hasattr(templated, "get"): + ids_value = templated.get("input_ids") + mask_value = templated.get("attention_mask") + text_value = templated.get("text") + if ids_value is not None: + return _pack_ids(ids_value, mask_value, idx) + if text_value is not None: + return _tokenize_text_value(text_value, idx) + + if isinstance(templated, (list, tuple)): + if len(templated) > 0 and isinstance(templated[0], int): + return _pack_ids(list(templated), None, idx) + raise ValueError( + "tokenizer.apply_template returned an unsupported sequence type for calibration item {idx}." + ) + + if torch.is_tensor(templated): + return _pack_ids(templated, None, idx) + + if isinstance(templated, str): + return _tokenize_text_value(templated, idx) + + raise ValueError( + f"tokenizer.apply_template returned unsupported type {type(templated)} for calibration item {idx}." + ) + + processed_examples: List[Dict[str, torch.Tensor]] = [] + for idx, example in enumerate(raw_examples): + if isinstance(example, dict): + if "messages" in example: + apply_fn = getattr(tokenizer, "apply_template", None) if tokenizer else None + if apply_fn is None: + if "text" in example: + processed_examples.append(_tokenize_text_value(example["text"], idx)) + continue + raise ValueError( + "tokenizer must expose `apply_template` or calibration data must provide `text` when using `messages`." + ) + processed_examples.append(_tokenize_messages_value(example["messages"], idx)) + continue + if "text" in example: + processed_examples.append(_tokenize_text_value(example["text"], idx)) + continue + if "input_ids" in example: + processed_examples.append(_pack_ids(example["input_ids"], example.get("attention_mask"), idx)) + continue + raise ValueError( + f"Quantize: unsupported calibration example structure at index {idx}: keys={list(example.keys())}" + ) + + if isinstance(example, str): + processed_examples.append(_tokenize_text_value(example, idx)) + continue + + if isinstance(example, (list, tuple)): + if all(isinstance(x, int) for x in example): + processed_examples.append(_pack_ids(list(example), None, idx)) + continue + raise ValueError( + f"Quantize: list-based calibration example at index {idx} must contain only integers." 
+ ) + + if torch.is_tensor(example): + processed_examples.append(_pack_ids(example, None, idx)) + continue + + try: + processed_examples.append(_pack_ids(example, None, idx)) + except Exception as exc: # pragma: no cover - defensive + raise ValueError( + f"Quantize: unsupported calibration example type {type(example)} at index {idx}." + ) from exc + + calibration_dataset = processed_examples + + def _convert_tensor_to_list(tensor): + if isinstance(tensor, torch.Tensor): + if len(tensor.shape) == 1: + tensor = tensor.unsqueeze(0) + tensor = tensor.long() + return tensor.cpu().numpy().tolist() + return [tensor] + + new_calibration_dataset = [] + too_short_calibration_data_count = 0 + + max_positions = None + max_positions_source = None + trimmed_row_count = 0 + longest_trimmed_row = 0 + + def _maybe_resolve_length(value, source_name): + nonlocal max_positions, max_positions_source + try: + if value is None: + return False + limit = int(value) + except Exception: + return False + if limit <= 0: + return False + if max_positions is None or limit < max_positions: + max_positions = limit + max_positions_source = source_name + return True + + model_config = getattr(getattr(qmodel, "model", None), "config", None) + if model_config is not None: + primary_names = ("max_position_embeddings",) + fallback_names = ( + "max_sequence_length", + "max_seq_len", + "n_positions", + "seq_length", + ) + + for attr_name in primary_names: + if _maybe_resolve_length(getattr(model_config, attr_name, None), attr_name): + break + if max_positions is None: + for attr_name in fallback_names: + if _maybe_resolve_length(getattr(model_config, attr_name, None), attr_name): + break + + for example in calibration_dataset: + input_ids = _convert_tensor_to_list(example["input_ids"]) + attention_mask = _convert_tensor_to_list(example["attention_mask"]) + + if max_positions is not None: + trimmed = False + trimmed_input_ids = [] + trimmed_attention_mask = [] + + for row_ids, row_mask in zip(input_ids, attention_mask): + row_len = len(row_ids) + if row_len > max_positions: + trimmed = True + trimmed_row_count += 1 + longest_trimmed_row = max(longest_trimmed_row, row_len) + trimmed_input_ids.append(row_ids[:max_positions]) + trimmed_attention_mask.append(row_mask[:max_positions]) + else: + trimmed_input_ids.append(row_ids) + trimmed_attention_mask.append(row_mask) + + if trimmed: + input_ids = trimmed_input_ids + attention_mask = trimmed_attention_mask + + if len(input_ids[0]) <= calibration_data_min_length: + too_short_calibration_data_count += 1 + continue + + new_calibration_dataset.append( + { + "input_ids": input_ids, + "attention_mask": attention_mask, + } + ) + + if too_short_calibration_data_count > 0: + log.warn( + f"Quantize: {too_short_calibration_data_count} input_ids with length <= {calibration_data_min_length} were removed. " + f"Use quantize(calibration_data_min_length={calibration_data_min_length}) to set a custom minimum length." 
+ ) + + if trimmed_row_count > 0: + log.info( + "Quantize: trimmed %s calibration rows above %s=%s (longest original length=%s)", + trimmed_row_count, + max_positions_source, + max_positions, + longest_trimmed_row, + ) + + if calibration_dataset_concat_size: + _require_tokenizer("`calibration_dataset_concat_size` is specified") + concatenated_data = [] + input_ids_buff = [] + attention_mask_buff = [] + current_length = 0 + + separator = calibration_concat_separator if calibration_concat_separator is not None else "" + if separator: + new_line = tokenizer(separator, return_tensors="pt") # type: ignore[call-arg] + new_line_input_ids = _convert_tensor_to_list(new_line["input_ids"])[0] + new_line_attention_mask = _convert_tensor_to_list(new_line["attention_mask"])[0] else: - yield [process_func(item) for item in batch] + new_line_input_ids = [] + new_line_attention_mask = [] + new_line_input_ids_len = len(new_line_input_ids) + + def flush_buffer(): + nonlocal input_ids_buff, attention_mask_buff, current_length + concatenated_data.append( + { + "input_ids": [input_ids_buff], + "attention_mask": [attention_mask_buff], + } + ) + input_ids_buff = [] + attention_mask_buff = [] + current_length = 0 + + for example in new_calibration_dataset: + row_ids = example["input_ids"][0] + row_mask = example["attention_mask"][0] + position = 0 + row_length = len(row_ids) + + while position < row_length: + if input_ids_buff: + if new_line_input_ids_len: + if current_length + new_line_input_ids_len > calibration_dataset_concat_size: + flush_buffer() + continue + input_ids_buff.extend(new_line_input_ids) + attention_mask_buff.extend(new_line_attention_mask) + current_length += new_line_input_ids_len + + available = calibration_dataset_concat_size - current_length + if available == 0: + flush_buffer() + continue + + chunk_len = min(available, row_length - position) + if chunk_len == 0: + flush_buffer() + continue + + end = position + chunk_len + input_ids_buff.extend(row_ids[position:end]) + attention_mask_buff.extend(row_mask[position:end]) + current_length += chunk_len + position = end + + if current_length == calibration_dataset_concat_size: + flush_buffer() + + if input_ids_buff: + padding_length = calibration_dataset_concat_size - len(input_ids_buff) + if padding_length > 0: + pad_id = getattr(tokenizer, "pad_token_id", 0) + input_ids_buff.extend([pad_id] * padding_length) + attention_mask_buff.extend([0] * padding_length) + concatenated_data.append( + { + "input_ids": [input_ids_buff], + "attention_mask": [attention_mask_buff], + } + ) + + new_calibration_dataset = concatenated_data + + if calibration_dataset_sort == "asc": + log.info("Calibration: Sort in ascending order by length") + sorted_dataset = sorted( + new_calibration_dataset, + key=lambda item: len(item["input_ids"][0]), + ) + elif calibration_dataset_sort == "desc": + log.info("Calibration: Sort in descending order by length") + sorted_dataset = sorted( + new_calibration_dataset, + key=lambda item: len(item["input_ids"][0]), + reverse=True, + ) + elif calibration_dataset_sort == "shuffle": + log.info("Calibration: Sort by random shuffle") + sorted_dataset = new_calibration_dataset[:] + random.shuffle(sorted_dataset) + else: + log.info("Calibration: Native order") + sorted_dataset = new_calibration_dataset + + if support_batch_quantize: + pad_token_id = getattr(tokenizer, "pad_token_id", 0) if tokenizer is not None else 0 + new_calibration_dataset_batched = [ + collate_data(sorted_dataset[start : start + batch_size], pad_token_id) + for start in 
range(0, len(sorted_dataset), batch_size) + ] + + total_padded = 0 + total_non_padded = 0 + + for batch in new_calibration_dataset_batched: + mask = batch["attention_mask"] + total_padded += (mask == 0).sum().item() + total_non_padded += (mask == 1).sum().item() + + log.info(f"Calibration: Total padded tokens: {total_padded}") + log.info(f"Calibration: Total non-padded tokens: {total_non_padded}") + log.info(f"Calibration: Total tokens: {total_non_padded + total_padded}") + else: + new_calibration_dataset_batched = [ + { + "input_ids": torch.tensor(block["input_ids"], dtype=torch.long), + } + for block in sorted_dataset + ] + + return new_calibration_dataset_batched + + +__all__ = ["batched", "prepare_calibration_dataset"] diff --git a/tests/models/model_test.py b/tests/models/model_test.py index db50e7c38..d4ac7c655 100644 --- a/tests/models/model_test.py +++ b/tests/models/model_test.py @@ -90,6 +90,8 @@ class ModelTest(unittest.TestCase): INPUTS_MAX_LENGTH = 2048 MODEL_MAX_LEN = 4096 DATASET_SIZE = 512 + DATASET_CONCAT_SIZE = None + DATASET_CONCAT_SEPARATOR = None DATASET_SORT = "desc" DELETE_QUANTIZED_MODEL = True EVAL_TASKS = None @@ -109,6 +111,8 @@ class ModelTest(unittest.TestCase): FAIL_SAFE = True EORA = None DAMP_PERCENT = 0.05 + MSE = 0.0 + DYNAMIC = None SAVE_PATH = None # default is temp folder @@ -783,6 +787,8 @@ def quantModel(self, model_id_or_path, trust_remote_code=False, dtype="auto", ne pack_impl="cpu", vram_strategy=self.VRAM_STRATEGY, damp_percent=self.DAMP_PERCENT, + mse=self.MSE, + dynamic=self.DYNAMIC, ) log.info(f"Quant config: {quantize_config}") @@ -832,7 +838,14 @@ def quantModel(self, model_id_or_path, trust_remote_code=False, dtype="auto", ne try: save_context, planned_save_path, cleanup_callback = self._prepare_quant_save_destination(need_eval) log.info(f"Quantized model artifacts will be saved to: {planned_save_path}") - model.quantize(calibration_dataset, calibration_sort=self.DATASET_SORT, backend=self.QUANT_BACKEND, batch_size=batch_size) + model.quantize( + calibration_dataset, + calibration_concat_size=self.DATASET_CONCAT_SIZE, + calibration_concat_separator=self.DATASET_CONCAT_SEPARATOR, + calibration_sort=self.DATASET_SORT, + backend=self.QUANT_BACKEND, + batch_size=batch_size, + ) self.check_kernel(model, self.KERNEL_QUANT) diff --git a/tests/models/test_granite_4_0_h_1b.py b/tests/models/test_granite_4_0_h_1b.py index b01d4924c..6fac47574 100644 --- a/tests/models/test_granite_4_0_h_1b.py +++ b/tests/models/test_granite_4_0_h_1b.py @@ -2,9 +2,9 @@ # SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai # SPDX-License-Identifier: Apache-2.0 # Contact: qubitium@modelcloud.ai, x.com/qubitium -from gptqmodel import BACKEND from model_test import ModelTest +from gptqmodel import BACKEND from gptqmodel.utils.eval import EVAL diff --git a/tests/models/test_granite_4_0_h_350m.py b/tests/models/test_granite_4_0_h_350m.py index 75fb282fd..a77f905a7 100644 --- a/tests/models/test_granite_4_0_h_350m.py +++ b/tests/models/test_granite_4_0_h_350m.py @@ -2,9 +2,9 @@ # SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai # SPDX-License-Identifier: Apache-2.0 # Contact: qubitium@modelcloud.ai, x.com/qubitium -from gptqmodel import BACKEND from model_test import ModelTest +from gptqmodel import BACKEND from gptqmodel.utils.eval import EVAL diff --git a/tests/models/test_llama3_2.py b/tests/models/test_llama3_2.py index c4a315c42..06870ed27 100644 --- a/tests/models/test_llama3_2.py +++ b/tests/models/test_llama3_2.py @@ -8,40 +8,40 @@ from 
gptqmodel.utils.eval import EVAL -# a100:7, MARLIN kernel -# desc_act = False, act_group_aware = False 0.3200/0.3447 -# desc_act = False, act_group_aware = True 0.3181/0.3481 -# desc_act = True, REGRESSION 0.3191/0.3601 -# a100:6+7: MARLIN kernel -# desc_act = False, act_group_aware = True 0.3217/0.3643 # | Metric | MARLIN | # |--------------------------------|----------| -# | arc_challenge :: acc,none | 0.3174 | -# | arc_challenge :: acc_norm,none | 0.3601 | -# | mmlu_stem :: acc,none | 0.3186 | +# | arc_challenge :: acc,none | 0.3046 | +# | arc_challenge :: acc_norm,none | 0.3345 | +# | mmlu_stem :: acc,none | 0.3768 | +# | gsm8k_plat :: exact,flexible | 0.1944 | class TestLlama3_2(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/Llama-3.2-1B-Instruct" # "meta-llama/Llama-3.2-1B-Instruct" EVAL_BATCH_SIZE = 64 + DATASET_CONCAT_SIZE = 2048 EVAL_TASKS = { - EVAL.LM_EVAL.ARC_CHALLENGE: { + EVAL.LM_EVAL.GSM8K_PLATINUM_COT: { "chat_template": True, - "acc": { - "value": 0.3191, + "exact_match,flexible-extract": { + "value": 0.1944, "floor_pct": 0.04, - "ceil_pct": 0.10, - }, - "acc_norm": { - "value": 0.3507, - "floor_pct": 0.04, - "ceil_pct": 0.10, }, }, EVAL.LM_EVAL.MMLU_STEM: { "chat_template": False, "acc": { - "value": 0.2978, + "value": 0.3768, # 0.3099 4096, 0.3270 2048 + "floor_pct": 0.04, + }, + }, + EVAL.LM_EVAL.ARC_CHALLENGE: { + "chat_template": True, + "acc": { + "value": 0.3046, # 0.3294 4096, 0.3242 2048 + "floor_pct": 0.04, + }, + "acc_norm": { + "value": 0.3345, # 0.3558 4096, 0.3635 2048 "floor_pct": 0.04, - "ceil_pct": 0.10, }, }, } diff --git a/tests/models/test_llama3_2_awq.py b/tests/models/test_llama3_2_awq.py index 48c4a8c45..855736ddd 100644 --- a/tests/models/test_llama3_2_awq.py +++ b/tests/models/test_llama3_2_awq.py @@ -9,33 +9,40 @@ from gptqmodel.utils.eval import EVAL -# a100:0 -# desc_act = False, act_group_aware = False 0.2500/0.2841 -# desc_act = False, act_group_aware = True 0.3063/0.3456 -# desc_act = True, 0.3089/0.3328 +# | Metric | MARLIN | +# |--------------------------------|----------| +# | arc_challenge :: acc,none | 0.3106 | +# | arc_challenge :: acc_norm,none | 0.3532 | +# | mmlu_stem :: acc,none | 0.3527 | +# | gsm8k_plat :: exact,flexible | 0.2192 | class TestLlama3_2_awq(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/Llama-3.2-1B-Instruct" # "meta-llama/Llama-3.2-1B-Instruct" EVAL_BATCH_SIZE = 64 + DATASET_CONCAT_SIZE = 2048 # new EVAL_TASKS = { + EVAL.LM_EVAL.GSM8K_PLATINUM_COT: { + "chat_template": True, + "exact_match,flexible-extract": { + "value": 0.2192, + "floor_pct": 0.04, + }, + }, EVAL.LM_EVAL.ARC_CHALLENGE: { "chat_template": True, "acc": { - "value": 0.3200, + "value": 0.3106, "floor_pct": 0.04, - "ceil_pct": 0.10, }, "acc_norm": { - "value": 0.3362, + "value": 0.3532, "floor_pct": 0.04, - "ceil_pct": 0.10, }, }, EVAL.LM_EVAL.MMLU_STEM: { "chat_template": False, "acc": { - "value": 0.3657, + "value": 0.3527, "floor_pct": 0.04, - "ceil_pct": 0.10, }, }, } diff --git a/tests/models/test_qwen2_5.py b/tests/models/test_qwen2_5.py index ae194898c..4c5e2c072 100644 --- a/tests/models/test_qwen2_5.py +++ b/tests/models/test_qwen2_5.py @@ -10,20 +10,29 @@ # | Metric | MARLIN | # |--------------------------------|----------| -# | arc_challenge :: acc,none | 0.2892 | -# | arc_challenge :: acc_norm,none | 0.3302 | -# | mmlu_stem :: acc,none | 0.4351 | +# | arc_challenge :: acc,none | 0.2961 | +# | arc_challenge :: acc_norm,none | 0.3285 | +# | mmlu_stem :: acc,none | 0.3942 | +# | gsm8k_plat :: exact,flexible | 0.2963 | class 
TestQwen2_5(ModelTest): GROUP_SIZE = 32 NATIVE_MODEL_ID = "/monster/data/model/Qwen2.5-0.5B-Instruct" EVAL_BATCH_SIZE = 64 + DATASET_CONCAT_SIZE = 2048 EVAL_TASKS = { + EVAL.LM_EVAL.GSM8K_PLATINUM_COT: { + "chat_template": True, + "exact_match,flexible-extract": { + "value": 0.2963, + "floor_pct": 0.04, + }, + }, EVAL.LM_EVAL.ARC_CHALLENGE: { - "acc": {"value": 0.2910, "floor_pct": 0.04}, - "acc_norm": {"value": 0.3268, "floor_pct": 0.04}, + "acc": {"value": 0.2961, "floor_pct": 0.04}, + "acc_norm": {"value": 0.3285, "floor_pct": 0.04}, }, EVAL.LM_EVAL.MMLU_STEM: { - "acc": {"value": 0.3819, "floor_pct": 0.04}, + "acc": {"value": 0.3942, "floor_pct": 0.04}, }, } diff --git a/tests/test_benchmark_submodule_finalize.py b/tests/test_benchmark_submodule_finalize.py index b67048953..54a217eca 100644 --- a/tests/test_benchmark_submodule_finalize.py +++ b/tests/test_benchmark_submodule_finalize.py @@ -20,7 +20,14 @@ from gptqmodel.utils.threadx import DeviceThreadPool -def _dummy_prepare_dataset(*, calibration_dataset, calibration_dataset_concat_size, calibration_dataset_sort, batch_size): +def _dummy_prepare_dataset( + *, + calibration_dataset, + calibration_dataset_concat_size, + calibration_dataset_sort, + batch_size, + calibration_concat_separator=None, +): return calibration_dataset diff --git a/tests/test_gptq_device_ctx.py b/tests/test_gptq_device_ctx.py index 0ee915c47..8ce43be79 100644 --- a/tests/test_gptq_device_ctx.py +++ b/tests/test_gptq_device_ctx.py @@ -16,7 +16,14 @@ from gptqmodel.quantization.config import QuantizeConfig -def _dummy_prepare_dataset(*, calibration_dataset, calibration_dataset_concat_size, calibration_dataset_sort, batch_size): +def _dummy_prepare_dataset( + *, + calibration_dataset, + calibration_dataset_concat_size, + calibration_dataset_sort, + batch_size, + calibration_concat_separator=None, +): return calibration_dataset diff --git a/tests/test_hessian_inverse.py b/tests/test_hessian_inverse.py index 60962611e..d1c01ffaf 100644 --- a/tests/test_hessian_inverse.py +++ b/tests/test_hessian_inverse.py @@ -3,6 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 # Contact: qubitium@modelcloud.ai, x.com/qubitium +import pytest import torch import torch.nn as nn @@ -16,6 +17,15 @@ def _build_gptq(damp_percent: float, damp_auto_increment: float) -> GPTQ: return GPTQ(module, qcfg=qcfg) +def _damped_hessian(base: torch.Tensor, used_damp: float) -> torch.Tensor: + """Reconstruct the damped matrix the solver actually inverted.""" + damped = base.clone() + diag_view = damped.diagonal() + mean = torch.mean(diag_view) + diag_view.add_(used_damp * mean) + return damped + + def test_hessian_inverse_handles_rank_deficiency(): gptq = _build_gptq(damp_percent=0.05, damp_auto_increment=0.05) device = gptq.module.target_device @@ -27,14 +37,83 @@ def test_hessian_inverse_handles_rank_deficiency(): assert hessian_inv.shape == hessian.shape assert 0 < damp < 1 assert torch.allclose(hessian_inv, torch.triu(hessian_inv)) + # Accuracy sanity check: recovered triangular factor should match the inverse of the damped matrix. 
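+        # (Context: with H_d = H + damp * mean(diag(H)) * I, taking H after any diagonal flooring,
+        #  hessian_inverse returns U = cholesky(cholesky_inverse(cholesky(H_d)), upper=True),
+        #  i.e. the upper factor of inv(H_d), so U.T @ U should reproduce inv(H_d), the matrix
+        #  that _damped_hessian rebuilds. The assertions below check exactly that identity.)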
+ reconstructed = hessian_inv.transpose(-1, -2) @ hessian_inv + expected_inverse = torch.linalg.inv(_damped_hessian(hessian, damp)) + assert torch.allclose(reconstructed, expected_inverse, atol=1e-5, rtol=1e-4) def test_hessian_inverse_returns_none_for_indefinite_matrix(): gptq = _build_gptq(damp_percent=0.05, damp_auto_increment=0.25) device = gptq.module.target_device hessian = torch.tensor([[0.0, 1.0], [1.0, 0.0]], dtype=torch.float32, device=device) + original = hessian.clone() hessian_inv, damp = gptq.hessian_inverse(hessian) assert hessian_inv is None assert damp == 1.0 + # The diagonal should reflect the final floor attempt. + assert torch.allclose(hessian.diagonal(), torch.full((2,), 0.1, device=device)) + # Off-diagonals must remain untouched. + assert torch.allclose(hessian - torch.diag(hessian.diagonal()), original - torch.diag(original.diagonal())) + + +def test_hessian_inverse_matches_reference_for_positive_definite_matrix(): + gptq = _build_gptq(damp_percent=0.05, damp_auto_increment=0.05) + device = gptq.module.target_device + original = torch.tensor( + [[4.0, 1.0, 0.5], [1.0, 3.0, 0.2], [0.5, 0.2, 2.5]], + dtype=torch.float32, + device=device, + ) + hessian = original.clone() + + hessian_inv, used_damp = gptq.hessian_inverse(hessian) + + assert hessian_inv is not None + # Ensure the solver does not mutate a healthy block. + assert torch.allclose(hessian, original) + + damped = _damped_hessian(hessian, used_damp) + reconstructed = hessian_inv.transpose(-1, -2) @ hessian_inv + expected_inverse = torch.linalg.inv(damped) + assert torch.allclose(reconstructed, expected_inverse, atol=1e-6, rtol=1e-5) + + +def test_hessian_inverse_applies_diagonal_floor_for_semi_definite_input(): + gptq = _build_gptq(damp_percent=0.05, damp_auto_increment=0.0) + device = gptq.module.target_device + hessian = torch.tensor([[0.0, 0.01], [0.01, 0.0]], dtype=torch.float32, device=device) + + hessian_inv, used_damp = gptq.hessian_inverse(hessian) + + assert hessian_inv is not None + assert used_damp == pytest.approx(gptq.qcfg.damp_percent) + # Diagonal should be floored to a positive value so later steps see a PD matrix. + assert torch.all(hessian.diagonal() > 0) + assert torch.allclose(hessian.diagonal(), torch.full((2,), 0.01, device=device), atol=1e-7, rtol=0.0) + + damped = _damped_hessian(hessian, used_damp) + # Should be positive definite after flooring, so Cholesky succeeds. 
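+    # (Context: the diagonal starts at zero, so base_abs_max falls back to 1.0 and the floor
+    #  schedule tries increments of 1e-6 * 10**(attempt - 1); the first step that yields a
+    #  positive-definite damped matrix for this input is 1e-2, which is why the floored
+    #  diagonal asserted above reads 0.01.)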
+ torch.linalg.cholesky(damped) + reconstructed = hessian_inv.transpose(-1, -2) @ hessian_inv + expected_inverse = torch.linalg.inv(damped) + assert torch.allclose(reconstructed, expected_inverse, atol=1e-5, rtol=1e-4) + + +def test_hessian_inverse_handles_singleton_flooring(): + gptq = _build_gptq(damp_percent=0.05, damp_auto_increment=0.0) + device = gptq.module.target_device + hessian = torch.tensor([[0.0]], dtype=torch.float32, device=device) + + hessian_inv, used_damp = gptq.hessian_inverse(hessian) + + assert hessian_inv is not None + assert hessian_inv.shape == hessian.shape + assert torch.allclose(hessian.diagonal(), torch.tensor([1e-6], dtype=torch.float32, device=device)) + + damped = _damped_hessian(hessian, used_damp) + reconstructed = hessian_inv.transpose(-1, -2) @ hessian_inv + expected_inverse = torch.linalg.inv(damped) + assert torch.allclose(reconstructed, expected_inverse, atol=1e-6, rtol=1e-4) diff --git a/tests/test_prepare_dataset.py b/tests/test_prepare_dataset.py new file mode 100644 index 000000000..2522495cd --- /dev/null +++ b/tests/test_prepare_dataset.py @@ -0,0 +1,113 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +import copy + +import torch + +from gptqmodel.models.base import BaseQModel + + +class _StubTokenizer: + pad_token_id = 0 + + def __call__(self, text, return_tensors="pt", add_special_tokens=True): + if isinstance(text, list): + raise ValueError("_StubTokenizer only supports string inputs") + token_ids = [self._encode_char(ch) for ch in str(text)] + attention = [1] * len(token_ids) + input_tensor = torch.tensor([token_ids], dtype=torch.long) + mask_tensor = torch.tensor([attention], dtype=torch.long) + return {"input_ids": input_tensor, "attention_mask": mask_tensor} + + @staticmethod + def _encode_char(ch: str) -> int: + value = ord(ch) + return value if value > 0 else 1 + + +def _make_qmodel() -> BaseQModel: + model = BaseQModel.__new__(BaseQModel) + model.tokenizer = _StubTokenizer() + model.support_batch_quantize = True + dummy_config = type("_Cfg", (), {"max_position_embeddings": 128})() + dummy_model = type("_DummyModel", (), {"config": dummy_config})() + model.model = dummy_model + return model + + +def _sample_dataset(): + return [ + {"input_ids": [[1, 2]], "attention_mask": [[1, 1]]}, + {"input_ids": [[3]], "attention_mask": [[1]]}, + ] + + +def test_prepare_dataset_concat_without_separator(): + qmodel = _make_qmodel() + dataset = copy.deepcopy(_sample_dataset()) + + batches = qmodel.prepare_dataset( + calibration_dataset=dataset, + calibration_dataset_concat_size=5, + calibration_dataset_sort=None, + batch_size=1, + calibration_data_min_length=0, + calibration_concat_separator=None, + ) + + assert len(batches) == 1 + input_ids = batches[0]["input_ids"].tolist() + attention_mask = batches[0]["attention_mask"].int().tolist() + + assert input_ids == [[1, 2, 3, 0, 0]] + assert attention_mask == [[1, 1, 1, 0, 0]] + + +def test_prepare_dataset_concat_with_separator(): + qmodel = _make_qmodel() + dataset = copy.deepcopy(_sample_dataset()) + + batches = qmodel.prepare_dataset( + calibration_dataset=dataset, + calibration_dataset_concat_size=5, + calibration_dataset_sort=None, + batch_size=1, + calibration_data_min_length=0, + calibration_concat_separator="##", + ) + + assert len(batches) == 1 + input_ids = batches[0]["input_ids"].tolist() + attention_mask = 
batches[0]["attention_mask"].int().tolist() + + sep_tokens = [_StubTokenizer._encode_char("#"), _StubTokenizer._encode_char("#")] + assert input_ids == [[1, 2, *sep_tokens, 3]] + assert attention_mask == [[1, 1, 1, 1, 1]] + + +def test_prepare_dataset_splits_long_row_across_blocks(): + qmodel = _make_qmodel() + long_row = {"input_ids": [[1, 2, 3, 4, 5, 6]], "attention_mask": [[1, 1, 1, 1, 1, 1]]} + dataset = [long_row] + + batches = qmodel.prepare_dataset( + calibration_dataset=dataset, + calibration_dataset_concat_size=5, + calibration_dataset_sort=None, + batch_size=1, + calibration_data_min_length=0, + calibration_concat_separator=None, + ) + + assert len(batches) == 2 + first_ids = batches[0]["input_ids"].tolist() + second_ids = batches[1]["input_ids"].tolist() + first_mask = batches[0]["attention_mask"].int().tolist() + second_mask = batches[1]["attention_mask"].int().tolist() + assert first_ids == [[1, 2, 3, 4, 5]] + assert first_mask == [[1, 1, 1, 1, 1]] + assert second_ids == [[6, 0, 0, 0, 0]] + assert second_mask == [[1, 0, 0, 0, 0]]