diff --git a/gptqmodel/models/base.py b/gptqmodel/models/base.py index e9d70bd70..c93f96a32 100644 --- a/gptqmodel/models/base.py +++ b/gptqmodel/models/base.py @@ -37,6 +37,7 @@ from ..adapter.adapter import Adapter from ..nn_modules.qlinear import BaseQuantLinear from ..nn_modules.qlinear.torch import TorchQuantLinear +from ..nn_modules.qlinear.lookahead import configure_default_lookahead from ..quantization import QuantizeConfig from ..quantization.config import FORMAT, METHOD, QUANTIZE_BLACK_LIST, dynamic_get from ..quantization.rotation.rotation import fuse_layer_norms, rotate_model @@ -251,6 +252,8 @@ def __init__( # print kernel info: log.info(f"Kernel: loaded -> `[{', '.join(cls.__name__ for cls in self.kernels())}]`") + self._auto_configure_lookahead() + @classmethod def extract_layers_node(cls): """ @@ -1096,6 +1099,19 @@ def kernels(self) -> List[Type[BaseQuantLinear]]: return list(loaded_kernels) + def _auto_configure_lookahead(self) -> None: + if not isinstance(self.model, nn.Module): + return + + quant_modules = [module for module in self.model.modules() if isinstance(module, TorchQuantLinear)] + if not quant_modules: + return + + if not any(getattr(module, "_lookahead_enabled", False) for module in quant_modules): + return + + configure_default_lookahead(self.model) + def compile(self, backend: str = "inductor", mode: str = None, fullgraph: bool = False): log.warn("Deprecation: `model.compile()` is deprecated. Please use `model.optimize()` instead.") return self.optimize(backend=backend, mode=mode, fullgraph=fullgraph) diff --git a/gptqmodel/nn_modules/qlinear/lookahead.py b/gptqmodel/nn_modules/qlinear/lookahead.py new file mode 100644 index 000000000..1cdb222da --- /dev/null +++ b/gptqmodel/nn_modules/qlinear/lookahead.py @@ -0,0 +1,90 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +from collections import defaultdict +from typing import Iterable, List, Tuple + +from .torch import TorchQuantLinear + + +def configure_lookahead_chain(modules: Iterable[TorchQuantLinear]): + """Wire a sequence of TorchQuantLinear modules for one-step lookahead. + + Each module in *modules* (except the last) will prefetch the next module's + dequantized weights the moment it finishes its own forward call. The last + module's ``lookahead_next`` pointer is cleared. + """ + + last = None + for module in modules: + if last is not None: + last.enable_lookahead(True).set_lookahead_next(module) + module.enable_lookahead(True) + last = module + if last is not None: + last.set_lookahead_next(None) + + +def _clear_existing_links(modules: Iterable[TorchQuantLinear]): + for module in modules: + module.set_lookahead_next(None) + + +def configure_default_lookahead(model) -> None: + """Eagerly decode the MLP projection trio when attention ``q_proj`` runs. + + For each transformer block this disables lookahead between + ``self_attn.{q,k,v,o}_proj`` and instead wires ``q_proj`` to prefetch the + block's ``mlp.{gate,up,down}_proj`` modules concurrently. Missing modules + are skipped. + """ + + ordered_modules: List[Tuple[str, TorchQuantLinear]] = [] + for name, module in model.named_modules(): + if isinstance(module, TorchQuantLinear): + ordered_modules.append((name, module)) + + if not ordered_modules: + return + + _clear_existing_links(module for _, module in ordered_modules) + + attn_order = ("q_proj", "k_proj", "v_proj", "o_proj") + mlp_order = ("gate_proj", "up_proj", "down_proj") + + attn_blocks = defaultdict(dict) + mlp_blocks = defaultdict(dict) + + for name, module in ordered_modules: + if ".self_attn." in name: + prefix, leaf = name.split(".self_attn.", maxsplit=1) + leaf = leaf.split(".")[0] + attn_blocks[prefix][leaf] = module + continue + if ".mlp." in name: + prefix, leaf = name.split(".mlp.", maxsplit=1) + leaf = leaf.split(".")[0] + mlp_blocks[prefix][leaf] = module + + for block in set(list(attn_blocks.keys()) + list(mlp_blocks.keys())): + attn = attn_blocks.get(block, {}) + mlp = mlp_blocks.get(block, {}) + + q_module = attn.get("q_proj") + attn_modules = [attn.get(key) for key in attn_order if attn.get(key) is not None] + mlp_targets = [mlp.get(key) for key in mlp_order if mlp.get(key) is not None] + + # reset lookahead state on all participating modules within this block + for module in attn_modules: + module.set_lookahead_next(None) + module.enable_lookahead(False) + for module in mlp_targets: + module.set_lookahead_next(None) + module.enable_lookahead(False) + + if q_module is not None and mlp_targets: + q_module.enable_lookahead(True).set_lookahead_next(tuple(mlp_targets)) + for target in mlp_targets: + target.enable_lookahead(True) diff --git a/gptqmodel/nn_modules/qlinear/torch.py b/gptqmodel/nn_modules/qlinear/torch.py index bb2de2755..f07270f65 100644 --- a/gptqmodel/nn_modules/qlinear/torch.py +++ b/gptqmodel/nn_modules/qlinear/torch.py @@ -4,6 +4,10 @@ # Contact: qubitium@modelcloud.ai, x.com/qubitium +import math +import os +from collections.abc import Iterable + import torch import torch.nn as nn from transformers import PreTrainedModel @@ -15,6 +19,14 @@ from ...utils.logger import setup_logger from ...utils.torch import torch_compile +try: + from ..triton_utils.dequant import dequant as triton_dequant + + _TRITON_DEQUANT_AVAILABLE = True +except Exception: # pragma: no cover - optional dependency + triton_dequant = None + _TRITON_DEQUANT_AVAILABLE = False + log = setup_logger() @@ -70,6 +82,25 @@ def __init__( **kwargs) self.dequant_dtype = torch.int16 if self.bits == 8 else torch.int8 + self._streaming_enabled = bool(int(os.environ.get("GPTQ_TORCH_STREAMING", "0"))) + self._stream_tile_cols = int(os.environ.get("GPTQ_TORCH_STREAM_TILE", "512")) + self._stream_double_buffers = 2 + self._g_idx_long_cache = None + self._zeros_cache = None + self._cache_enabled = bool(int(os.environ.get("GPTQ_TORCH_CACHE_WEIGHTS", "0"))) + triton_flag = os.environ.get("GPTQ_TORCH_TRITON_DEQUANT") + if triton_flag is None: + self._triton_dequant_enabled = _TRITON_DEQUANT_AVAILABLE + else: + self._triton_dequant_enabled = ( + triton_flag not in {"0", "false", "False"} + ) and _TRITON_DEQUANT_AVAILABLE + self._cached_weights = {} + self._lookahead_enabled = bool(int(os.environ.get("GPTQ_TORCH_LOOKAHEAD", "0"))) + self._lookahead_next = None + self._prefetch_stream = None + self._prefetched_weights = {} + self._prefetch_events = {} # if self.group_size != self.in_features: # self.padded_infeatures = self.in_features + (-self.in_features % self.group_size) @@ -91,6 +122,19 @@ def post_init(self): # torch benefits the most from torch.compile, enable it by default self.optimize() + self._stream_reset_cache() + self.clear_weight_cache() + self._reset_prefetch_state() + + def dequantize_weight(self, num_itr: int = 1): + if ( + num_itr == 1 + and self._triton_dequant_enabled + and self._can_use_triton_dequant() + ): + return self._dequantize_weight_triton() + + return super().dequantize_weight(num_itr=num_itr) def optimize(self, backend: str = None, mode: str = None, fullgraph: bool = False): if self.optimized: @@ -151,10 +195,23 @@ def forward(self, x: torch.Tensor): return out def _forward(self, x, out_shape): - num_itr = self.g_idx.shape[0] // x.shape[-1] - # make sure dequant dtype matches input x - weights = self.dequantize_weight(num_itr=num_itr).to(x.dtype) + cached = self._maybe_get_cached_weights(x) + if cached is not None: + out = torch.matmul(x, cached).reshape(out_shape) + elif self._should_use_streaming(x): + out = self._forward_streaming(x, out_shape) + else: + out = self._forward_eager(x, out_shape) + + self._maybe_schedule_lookahead(x.dtype) + return out + def _forward_eager(self, x: torch.Tensor, out_shape): + num_itr = self.g_idx.shape[0] // x.shape[-1] + weights = self._consume_prefetched_weights(x.dtype) + if weights is None: + weights = self.dequantize_weight(num_itr=num_itr).to(x.dtype) + self._update_cached_weights(weights) out = torch.matmul(x, weights).reshape(out_shape) if self.bias is not None: @@ -165,6 +222,243 @@ def _forward(self, x, out_shape): return out + def _forward_streaming(self, x: torch.Tensor, out_shape): + tile = max(64, min(self._stream_tile_cols, self.out_features)) + total_tiles = math.ceil(self.out_features / tile) + device = x.device + + out = torch.empty((x.shape[0], self.out_features), dtype=x.dtype, device=device) + buffers = [ + torch.empty((self.in_features, tile), dtype=x.dtype, device=device) + for _ in range(self._stream_double_buffers) + ] + widths = [0 for _ in range(self._stream_double_buffers)] + + stream_dequant = torch.cuda.Stream(device=device) + zeros = self._stream_decode_qzeros() + g_idx = self._stream_g_idx_long() + + def schedule(tile_idx: int, buffer_idx: int): + start = tile_idx * tile + end = min(start + tile, self.out_features) + with torch.cuda.stream(stream_dequant): + widths[buffer_idx] = self._stream_dequantize_tile( + buffer=buffers[buffer_idx], + zeros=zeros, + g_idx=g_idx, + start=start, + end=end, + dtype=x.dtype, + ) + + schedule(0, 0) + compute_stream = torch.cuda.current_stream() + + for tile_idx in range(total_tiles): + buffer_idx = tile_idx % self._stream_double_buffers + compute_stream.wait_stream(stream_dequant) + width = widths[buffer_idx] + start = tile_idx * tile + end = start + width + + out_slice = out.narrow(1, start, width) + out_slice.zero_() + torch.addmm( + out_slice, + x, + buffers[buffer_idx].narrow(1, 0, width), + beta=0.0, + alpha=1.0, + out=out_slice, + ) + + next_tile = tile_idx + 1 + if next_tile < total_tiles: + next_buffer_idx = next_tile % self._stream_double_buffers + schedule(next_tile, next_buffer_idx) + + out = out.reshape(out_shape) + + if self.bias is not None: + out.add_(self.bias) + + if self.adapter: + out = self.adapter.apply(x=x, out=out) + + return out + + def _maybe_get_cached_weights(self, x: torch.Tensor): + if not self._cache_enabled or self.training: + return None + cached = self._cached_weights.get(x.dtype) + if cached is not None: + if cached.device != x.device: + self._cached_weights.pop(x.dtype, None) + else: + return cached + return None + + def _update_cached_weights(self, weights: torch.Tensor): + if not self._cache_enabled or self.training: + return + self._cached_weights[weights.dtype] = weights.detach() + + def _consume_prefetched_weights(self, dtype: torch.dtype): + if not self._lookahead_enabled or self.training: + return None + tensor = self._prefetched_weights.pop(dtype, None) + if tensor is None: + return None + event = self._prefetch_events.pop(dtype, None) + if event is not None and torch.cuda.is_available(): + torch.cuda.current_stream(device=tensor.device).wait_event(event) + return tensor + + def _stream_dequantize_tile( + self, + buffer: torch.Tensor, + zeros: torch.Tensor, + g_idx: torch.Tensor, + start: int, + end: int, + dtype: torch.dtype, + ) -> int: + width = end - start + qweight_tile = self.qweight.narrow(1, start, width) + weight = torch.bitwise_right_shift( + qweight_tile.unsqueeze(1).expand(-1, self.pack_factor, -1), + self.wf_unsqueeze_neg_one, + ).to(self.dequant_dtype) + weight = torch.bitwise_and(weight, self.maxq) + weight = weight.reshape(weight.shape[0] * weight.shape[1], weight.shape[2]) + + zeros_tile = zeros.narrow(1, start, width) + scales_tile = self.scales.narrow(1, start, width) + + dequant = scales_tile[g_idx] * (weight - zeros_tile[g_idx]) + buffer.narrow(1, 0, width).copy_(dequant.to(dtype)) + return width + + def _stream_decode_qzeros(self): + if self._zeros_cache is not None and self._zeros_cache.device == self.qzeros.device: + return self._zeros_cache + + zeros = torch.bitwise_right_shift( + self.qzeros.unsqueeze(2).expand(-1, -1, self.pack_factor), + self.wf_unsqueeze_zero, + ).to(self.dequant_dtype) + zeros = torch.bitwise_and(zeros, self.maxq).reshape(self.scales.shape) + self._zeros_cache = zeros + return zeros + + def _stream_g_idx_long(self): + if self._g_idx_long_cache is None or self._g_idx_long_cache.device != self.g_idx.device: + self._g_idx_long_cache = self.g_idx.long() + return self._g_idx_long_cache + + def _stream_reset_cache(self): + self._zeros_cache = None + self._g_idx_long_cache = None + + def _should_use_streaming(self, x: torch.Tensor) -> bool: + if not self._streaming_enabled: + return False + if x.device.type != "cuda": + return False + if not torch.cuda.is_available(): + return False + # Torch kernels with num_itr > 1 already tiled differently. + if self.g_idx.shape[0] // x.shape[-1] != 1: + return False + return True + + def enable_streaming(self, enabled: bool = True, tile_cols: int = None): + self._streaming_enabled = enabled + if tile_cols is not None: + self._stream_tile_cols = tile_cols + self._stream_reset_cache() + return self + + def enable_weight_cache(self, enabled: bool = True): + self._cache_enabled = enabled + if not enabled: + self.clear_weight_cache() + return self + + def clear_weight_cache(self): + self._cached_weights.clear() + + def enable_lookahead(self, enabled: bool = True): + self._lookahead_enabled = enabled + if not enabled: + self._reset_prefetch_state() + return self + + def set_lookahead_next(self, module: "TorchQuantLinear"): + if module is None: + self._lookahead_next = None + self._reset_prefetch_state() + return self + + if isinstance(module, TorchQuantLinear): + self._lookahead_next = module + return self + + if isinstance(module, Iterable): + targets = tuple(m for m in module if m is not None) + if not targets: + self._lookahead_next = None + self._reset_prefetch_state() + return self + for target in targets: + if not isinstance(target, TorchQuantLinear): + raise TypeError("lookahead targets must be TorchQuantLinear modules or None") + self._lookahead_next = targets + return self + + raise TypeError("lookahead target must be TorchQuantLinear, iterable of TorchQuantLinear, or None") + + def _reset_prefetch_state(self): + for event in self._prefetch_events.values(): + if hasattr(event, "destroy"): + event.destroy() + self._prefetch_events.clear() + self._prefetched_weights.clear() + self._prefetch_stream = None + + def _maybe_schedule_lookahead(self, dtype: torch.dtype): + if not self._lookahead_enabled or self.training: + return + next_module = self._lookahead_next + if next_module is None: + return + if self.qweight.device.type != "cuda": + return + if isinstance(next_module, tuple): + for module in next_module: + module._prefetch(dtype) + else: + next_module._prefetch(dtype) + + def _prefetch(self, dtype: torch.dtype): + if not self._lookahead_enabled or self.training: + return + if dtype in self._prefetched_weights: + return + device = self.list_buffers()[0].device + if device.type != "cuda": + return + if self._prefetch_stream is None: + self._prefetch_stream = torch.cuda.Stream(device=device) + stream = self._prefetch_stream + with torch.cuda.stream(stream): + num_itr = max(1, self.g_idx.shape[0] // self.in_features) + weights = self.dequantize_weight(num_itr=num_itr).to(dtype) + event = torch.cuda.Event(enable_timing=False) + event.record(stream) + self._prefetched_weights[dtype] = weights + self._prefetch_events[dtype] = event + # clear gptq only weights: useful in de-quantization def _empty_gptq_only_weights(self): self.qzeros = None @@ -172,6 +466,39 @@ def _empty_gptq_only_weights(self): self.g_idx = None self.scales = None + def _can_use_triton_dequant(self) -> bool: + if not _TRITON_DEQUANT_AVAILABLE: + return False + if self.training: + return False + if self.qweight is None or self.qzeros is None or self.scales is None or self.g_idx is None: + return False + if self.qweight.device.type != "cuda": + return False + if self.bits not in (2, 3, 4, 8): + return False + if not (self.qweight.is_contiguous() and self.qzeros.is_contiguous() and self.scales.is_contiguous()): + return False + # g_idx is stored as int32 tensor; ensure it resides on the same device. + if self.g_idx.device != self.qweight.device: + return False + return True + + def _dequantize_weight_triton(self) -> torch.Tensor: + # Use the Triton helper to decode weights directly on device. + dtype = self.scales.dtype + weights = triton_dequant( + dtype, + self.qweight, + self.scales, + self.qzeros, + self.g_idx, + self.bits, + self.pack_dtype_bits, + self.maxq, + ) + return weights + def dequantize_model(model: PreTrainedModel): for name, module in model.named_modules(): if isinstance(module, BaseQuantLinear) and not isinstance(module, TorchQuantLinear): diff --git a/tests/models/test_llama3_2.py b/tests/models/test_llama3_2.py index b72e33cef..838f6af81 100644 --- a/tests/models/test_llama3_2.py +++ b/tests/models/test_llama3_2.py @@ -6,20 +6,20 @@ from model_test import ModelTest -# a100:0 -# desc_act = False, act_group_aware = False 0.2500/0.2841 -# desc_act = False, act_group_aware = True 0.3063/0.3456 -# desc_act = True, 0.3089/0.3328 +# a100:7 +# desc_act = False, act_group_aware = False 0.2918/0.3422 +# desc_act = False, act_group_aware = True 0.3311/0.3549 +# desc_act = True, 0.3191/0.3567 class TestLlama3_2(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/Llama-3.2-1B-Instruct" # "meta-llama/Llama-3.2-1B-Instruct" - NATIVE_ARC_CHALLENGE_ACC = 0.3234 - NATIVE_ARC_CHALLENGE_ACC_NORM = 0.3524 - QUANT_ARC_MAX_DELTA_FLOOR_PERCENT = 0.36 + NATIVE_ARC_CHALLENGE_ACC = 0.3311 + NATIVE_ARC_CHALLENGE_ACC_NORM = 0.3549 + QUANT_ARC_MAX_DELTA_FLOOR_PERCENT = 0.05 APPLY_CHAT_TEMPLATE = True V2 = False DEBUG = True - ACT_GROUP_AWARE = True - DESC_ACT = False + ACT_GROUP_AWARE = False + DESC_ACT = True DATASET_SIZE = 1024 DATASET_SORT = "desc" QUANT_BATCH_SIZE = 4 diff --git a/tests/test_attn_mask.py b/tests/test_attn_mask.py new file mode 100644 index 000000000..0b1a7fef1 --- /dev/null +++ b/tests/test_attn_mask.py @@ -0,0 +1,29 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +import torch + +from gptqmodel.utils.attn_mask import apply_keep_mask_bt, normalize_seq_mask + + +def test_normalize_seq_mask_binary_mask(): + mask = torch.tensor([[1, 0, 1]]) + + keep = normalize_seq_mask(mask) + + assert keep.dtype is torch.bool + assert keep.tolist() == [[True, False, True]] + + +def test_normalize_seq_mask_additive_zero_keep(): + mask = torch.tensor([[[[0.0, -10000.0, 0.0]]]]) + + keep = normalize_seq_mask(mask) + + assert keep.dtype is torch.bool + assert keep.tolist() == [[True, False, True]] + + values = torch.arange(6, dtype=torch.float32).view(1, 3, 2) + filtered = apply_keep_mask_bt(values, keep) + + assert torch.equal(filtered, torch.tensor([[0.0, 1.0], [4.0, 5.0]])) diff --git a/tests/test_dataset_loading.py b/tests/test_dataset_loading.py new file mode 100644 index 000000000..4f224b031 --- /dev/null +++ b/tests/test_dataset_loading.py @@ -0,0 +1,21 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +from datasets import load_dataset + + +def test_dataset_loader(): + # Load a small split of a dataset from Hugging Face + dataset = load_dataset("imdb", split="train[:1%]") # load only 1% to keep it small + + # Print dataset info + print(dataset) + + # Print the first row + first_row = dataset[0] + print("First row:", first_row) + +if __name__ == "__main__": + test_dataset_loader() diff --git a/tests/test_hessian_chunk.py b/tests/test_hessian_chunk.py new file mode 100644 index 000000000..4b4a29d5b --- /dev/null +++ b/tests/test_hessian_chunk.py @@ -0,0 +1,264 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +import math +import statistics +import time +import tracemalloc +import types +from typing import Callable, Dict, Iterable, List, Tuple + +import pytest +import torch +from tabulate import tabulate + +from gptqmodel.quantization import gptq as gptq_impl +from gptqmodel.quantization.config import QuantizeConfig +from gptqmodel.quantization.gptq import GPTQ + + +@pytest.fixture(autouse=True) +def reset_workspace_caches(): + gptq_impl._WORKSPACE_CACHE.clear() + gptq_impl._WORKSPACE_LOCKS.clear() + gptq_impl._BF16_SUPPORT_CACHE.clear() + yield + gptq_impl._WORKSPACE_CACHE.clear() + gptq_impl._WORKSPACE_LOCKS.clear() + gptq_impl._BF16_SUPPORT_CACHE.clear() + + +def _clone_module(module: torch.nn.Module) -> torch.nn.Module: + replica = type(module)(module.in_features, module.out_features, bias=False) + replica.load_state_dict(module.state_dict()) + replica.eval() + return replica + + +def _instrument_chunks(gptq: GPTQ) -> None: + original = gptq._borrow_materialized_chunk_fp32 + + def wrapped(self, chunk, rows): + self._chunk_invocations += 1 + return original(chunk, rows) + + gptq._chunk_invocations = 0 + gptq._borrow_materialized_chunk_fp32 = types.MethodType(wrapped, gptq) + + +def test_hessian_chunk_consistency_matches_full_precision(): + torch.manual_seed(0) + + base = torch.nn.Linear(32, 16, bias=False).eval() + module_full = _clone_module(base) + module_chunked = _clone_module(base) + + qcfg_full = QuantizeConfig( + hessian_chunk_size=None, + hessian_chunk_bytes=1_000_000_000, + hessian_use_bfloat16_staging=False, + ) + qcfg_chunked = QuantizeConfig( + hessian_chunk_size=16, + hessian_use_bfloat16_staging=False, + ) + + gptq_full = GPTQ(module_full, qcfg_full) + gptq_chunked = GPTQ(module_chunked, qcfg_chunked) + + calib = torch.randn(128, 32, dtype=torch.float16) + + gptq_full.process_batch(calib.clone()) + gptq_chunked.process_batch(calib.clone()) + + assert torch.allclose(gptq_full.H, gptq_chunked.H, atol=1e-5, rtol=1e-5) + + +def test_hessian_chunk_invocations_and_workspace_shape(): + torch.manual_seed(1) + + base = torch.nn.Linear(64, 32, bias=False).eval() + + large_cfg = QuantizeConfig(hessian_chunk_size=256) + large_gptq = GPTQ(_clone_module(base), large_cfg) + _instrument_chunks(large_gptq) + + small_cfg = QuantizeConfig(hessian_chunk_size=16) + small_gptq = GPTQ(_clone_module(base), small_cfg) + _instrument_chunks(small_gptq) + + calib = torch.randn(120, 64, dtype=torch.float16) + + large_gptq.process_batch(calib.clone()) + assert large_gptq._chunk_invocations == 1 + + small_gptq.process_batch(calib.clone()) + expected_chunks = math.ceil(calib.shape[0] / small_cfg.hessian_chunk_size) + assert small_gptq._chunk_invocations == expected_chunks + + device = torch.device(base.weight.device) + cols = base.in_features + fp32_key = gptq_impl._workspace_cache_key(device, torch.float32, cols) + + assert fp32_key in gptq_impl._WORKSPACE_CACHE + large_workspace = gptq_impl._WORKSPACE_CACHE[fp32_key] + assert large_workspace.shape[0] >= calib.shape[0] + assert large_workspace.shape[1] == large_gptq.columns + + small_workspace = gptq_impl._WORKSPACE_CACHE[fp32_key] + assert small_workspace is large_workspace + + staging_dtype = small_gptq._preferred_staging_dtype(calib.dtype, device) + if staging_dtype == torch.bfloat16: + staging_key = gptq_impl._workspace_cache_key(device, staging_dtype, cols) + assert staging_key in gptq_impl._WORKSPACE_CACHE + + +def test_hessian_chunk_bytes_budget(): + torch.manual_seed(2) + + base = torch.nn.Linear(48, 24, bias=False).eval() + module = _clone_module(base) + + bytes_budget = 16 * 48 * 4 + qcfg = QuantizeConfig(hessian_chunk_size=None, hessian_chunk_bytes=bytes_budget) + gptq = GPTQ(module, qcfg) + _instrument_chunks(gptq) + + calib = torch.randn(64, 48, dtype=torch.float16) + gptq.process_batch(calib) + + assert gptq._chunk_invocations == math.ceil(calib.shape[0] / 16) + + device = torch.device(module.weight.device) + cols = base.in_features + fp32_key = gptq_impl._workspace_cache_key(device, torch.float32, cols) + workspace = gptq_impl._WORKSPACE_CACHE[fp32_key] + assert workspace.shape[0] == 16 + assert workspace.shape[1] == gptq.columns + + +def _benchmark_case( + base_module: torch.nn.Module, + cfg_factory: Callable[[], QuantizeConfig], + calib: torch.Tensor, + num_iterations: int = 10, +) -> Dict[str, float]: + device = base_module.weight.device + use_cuda = device.type == "cuda" and torch.cuda.is_available() + + warmup_module = _clone_module(base_module) + warmup = GPTQ(warmup_module, cfg_factory()) + warmup.process_batch(calib.clone()) + + def measure_once() -> Tuple[float, float]: + module = _clone_module(base_module) + gptq = GPTQ(module, cfg_factory()) + + if use_cuda: + torch.cuda.reset_peak_memory_stats(device) + else: + tracemalloc.start() + + start = time.perf_counter() + gptq.process_batch(calib.clone()) + if use_cuda: + torch.cuda.synchronize(device) + elapsed = (time.perf_counter() - start) * 1000.0 + + if use_cuda: + peak_mem = torch.cuda.max_memory_allocated(device) + else: + _, peak_mem = tracemalloc.get_traced_memory() + tracemalloc.stop() + + return elapsed, peak_mem / (1024 * 1024) + + timings: List[float] = [] + memories: List[float] = [] + for _ in range(num_iterations): + elapsed, mem_mb = measure_once() + timings.append(elapsed) + memories.append(mem_mb) + + mean_ms = statistics.fmean(timings) + stdev_ms = statistics.pstdev(timings) + mean_mem = statistics.fmean(memories) + + config_sample = cfg_factory() + + return { + "chunk_size": config_sample.hessian_chunk_size, + "chunk_bytes": config_sample.hessian_chunk_bytes, + "bf16": config_sample.hessian_use_bfloat16_staging, + "mean_ms": mean_ms, + "stdev_ms": stdev_ms, + "mean_mem_mb": mean_mem, + } + + +def _print_benchmark_table(rows: Iterable[Dict[str, float]]) -> None: + table_rows = [] + for row in rows: + table_rows.append( + [ + "None" if row["chunk_size"] is None else row["chunk_size"], + "None" if row["chunk_bytes"] is None else row["chunk_bytes"], + row["bf16"], + f"{row['mean_ms']:.3f}", + f"{row['stdev_ms']:.3f}", + f"{row['mean_mem_mb']:.3f}", + ] + ) + + headers = [ + "chunk_size", + "chunk_bytes", + "bf16", + "mean_ms", + "stdev_ms", + "mean_mem_mb", + ] + + print(tabulate(table_rows, headers=headers, tablefmt="github")) + + +def test_hessian_chunk_benchmark_table(): + torch.manual_seed(3) + + base = torch.nn.Linear(96, 48, bias=False).eval() + calib = torch.randn(256, 96, dtype=torch.float16) + + configs: List[Callable[[], QuantizeConfig]] = [ + lambda: QuantizeConfig( + hessian_chunk_size=None, + hessian_chunk_bytes=512 * 1024 * 1024, + hessian_use_bfloat16_staging=False, + ), + lambda: QuantizeConfig( + hessian_chunk_size=64, + hessian_use_bfloat16_staging=False, + ), + lambda: QuantizeConfig( + hessian_chunk_size=32, + hessian_use_bfloat16_staging=True, + ), + lambda: QuantizeConfig( + hessian_chunk_size=None, + hessian_chunk_bytes=64 * 1024 * 1024, + hessian_use_bfloat16_staging=True, + ), + ] + + results = [] + for cfg_factory in configs: + result = _benchmark_case(base, cfg_factory, calib, num_iterations=10) + results.append(result) + + _print_benchmark_table(results) + + assert len(results) == len(configs) + assert all(result["mean_ms"] > 0.0 for result in results) + assert all(result["mean_mem_mb"] > 0.0 for result in results) diff --git a/tests/test_hessian_inverse.py b/tests/test_hessian_inverse.py new file mode 100644 index 000000000..60962611e --- /dev/null +++ b/tests/test_hessian_inverse.py @@ -0,0 +1,40 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +import torch +import torch.nn as nn + +from gptqmodel.quantization import QuantizeConfig +from gptqmodel.quantization.gptq import GPTQ + + +def _build_gptq(damp_percent: float, damp_auto_increment: float) -> GPTQ: + module = nn.Linear(2, 2, bias=False) + qcfg = QuantizeConfig(damp_percent=damp_percent, damp_auto_increment=damp_auto_increment) + return GPTQ(module, qcfg=qcfg) + + +def test_hessian_inverse_handles_rank_deficiency(): + gptq = _build_gptq(damp_percent=0.05, damp_auto_increment=0.05) + device = gptq.module.target_device + hessian = torch.tensor([[1.0, 1.0], [1.0, 1.0]], dtype=torch.float32, device=device) + + hessian_inv, damp = gptq.hessian_inverse(hessian) + + assert hessian_inv is not None + assert hessian_inv.shape == hessian.shape + assert 0 < damp < 1 + assert torch.allclose(hessian_inv, torch.triu(hessian_inv)) + + +def test_hessian_inverse_returns_none_for_indefinite_matrix(): + gptq = _build_gptq(damp_percent=0.05, damp_auto_increment=0.25) + device = gptq.module.target_device + hessian = torch.tensor([[0.0, 1.0], [1.0, 0.0]], dtype=torch.float32, device=device) + + hessian_inv, damp = gptq.hessian_inverse(hessian) + + assert hessian_inv is None + assert damp == 1.0 diff --git a/tests/test_pack.py b/tests/test_pack.py new file mode 100644 index 000000000..4e35a894b --- /dev/null +++ b/tests/test_pack.py @@ -0,0 +1,145 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +import math +import unittest + +import torch +import torch.nn as nn +from parameterized import parameterized +from tabulate import tabulate + +from gptqmodel import BACKEND +from gptqmodel.nn_modules.qlinear.torch import TorchQuantLinear + + +class TestPackAccuracy(unittest.TestCase): + in_features = 1024 + out_features = 768 + + @staticmethod + def _build_inputs(bits: int, group_size: int): + torch.manual_seed(0) + + linear = nn.Linear(TestPackAccuracy.in_features, TestPackAccuracy.out_features, bias=False) + + if group_size == -1: + groups = 1 + g_idx = torch.zeros(TestPackAccuracy.in_features, dtype=torch.int64) + else: + groups = math.ceil(TestPackAccuracy.in_features / group_size) + g_idx = torch.arange(TestPackAccuracy.in_features, dtype=torch.int64) // group_size + + max_q = 2 ** bits - 1 + scales = torch.rand(groups, TestPackAccuracy.out_features, dtype=torch.float32) * 0.05 + 1e-3 + zeros = torch.randint(0, max_q + 1, (groups, TestPackAccuracy.out_features), dtype=torch.int32) + + q_int = torch.randint(0, max_q + 1, (TestPackAccuracy.in_features, TestPackAccuracy.out_features), dtype=torch.int32) + scales_expanded = scales[g_idx].to(torch.float32) + zeros_expanded = zeros[g_idx].to(torch.float32) + weight = scales_expanded * (q_int.to(torch.float32) - zeros_expanded) + linear.weight.data = weight.T.to(linear.weight.dtype) + + return linear, scales, zeros, g_idx + + def _quant_linear(self): + qlinear = TorchQuantLinear( + bits=self.current_bits, + group_size=self.current_group_size, + sym=True, + desc_act=True, + in_features=self.in_features, + out_features=self.out_features, + pack_dtype=torch.int32, + backend=BACKEND.TORCH, + bias=False, + ) + return qlinear + + def _run_impl(self, impl: str, linear, scales, zeros, g_idx): + qlinear = self._quant_linear() + scales_T = scales.t().contiguous() + zeros_T = zeros.t().contiguous() + + if impl == "original": + qlinear.pack_original(linear, scales_T, zeros_T, g_idx=g_idx) + elif impl == "pack_block": + qlinear.pack_block( + linear, + scales_T, + zeros_T, + g_idx=g_idx.to(dtype=torch.int32), + ) + elif impl == "gpu": + if not torch.cuda.is_available(): + self.skipTest("CUDA device required for GPU pack comparison") + qlinear.pack_gpu( + linear, + scales_T, + zeros_T, + g_idx=g_idx.to(dtype=torch.int32), + ) + torch.cuda.synchronize() + else: + raise ValueError(f"Unknown impl `{impl}`") + + # Move buffers to CPU for comparisons + result = { + "qweight": qlinear.qweight.detach().cpu(), + "qzeros": qlinear.qzeros.detach().cpu(), + "scales": qlinear.scales.detach().cpu(), + "g_idx": qlinear.g_idx.detach().cpu(), + } + if hasattr(qlinear, "bias") and qlinear.bias is not None: + result["bias"] = qlinear.bias.detach().cpu() + return result + + @parameterized.expand( + [ + (2, -1), (2, 32), (2, 64), (2, 128), + (3, -1), (3, 32), (3, 64), (3, 128), + (4, -1), (4, 32), (4, 64), (4, 128), + (8, -1), (8, 32), (8, 64), (8, 128), + ] + ) + def test_pack_consistency(self, bits, group_size): + self.current_bits = bits + self.current_group_size = group_size + + linear, scales, zeros, g_idx = self._build_inputs(bits, group_size) + + baseline = self._run_impl("original", linear, scales, zeros, g_idx) + pack_cpu = self._run_impl("pack_block", linear, scales, zeros, g_idx) + results = {"pack_block": pack_cpu} + + if torch.cuda.is_available(): + results["pack_gpu"] = self._run_impl("gpu", linear, scales, zeros, g_idx) + + rows = [] + rows.append([f"pack_original (bits={bits}, g={group_size})", 0.0, 0.0, 0.0, 0.0]) + for name, tensors in results.items(): + diff_qweight = (tensors["qweight"].to(dtype=baseline["qweight"].dtype) - baseline["qweight"]).abs().max().item() + diff_qzeros = (tensors["qzeros"].to(dtype=baseline["qzeros"].dtype) - baseline["qzeros"]).abs().max().item() + diff_scales = (tensors["scales"].to(dtype=baseline["scales"].dtype) - baseline["scales"]).abs().max().item() + diff_gidx = (tensors["g_idx"].to(dtype=baseline["g_idx"].dtype) - baseline["g_idx"]).abs().max().item() + rows.append([ + f"{name} (bits={bits}, g={group_size})", + diff_qweight, + diff_qzeros, + diff_scales, + diff_gidx, + ]) + + self.assertTrue(torch.equal(tensors["qweight"], baseline["qweight"])) + self.assertTrue(torch.equal(tensors["qzeros"], baseline["qzeros"])) + self.assertTrue(torch.equal(tensors["g_idx"].to(dtype=baseline["g_idx"].dtype), baseline["g_idx"])) + self.assertTrue(torch.equal(tensors["scales"], baseline["scales"])) + + print( + tabulate( + rows, + headers=["impl", "max|Δ qweight|", "max|Δ qzeros|", "max|Δ scales|", "max|Δ g_idx|"], + floatfmt=".3e", + ) + ) diff --git a/tests/test_qzero_offsets.py b/tests/test_qzero_offsets.py new file mode 100644 index 000000000..795ee9041 --- /dev/null +++ b/tests/test_qzero_offsets.py @@ -0,0 +1,100 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +from types import SimpleNamespace + +import pytest +import torch +import torch.nn as nn + +from gptqmodel.nn_modules.qlinear import BaseQuantLinear +from gptqmodel.quantization import FORMAT, METHOD +from gptqmodel.utils.model import ( + convert_gptq_v1_to_v2_format_module, + convert_gptq_v2_to_v1_format_module, +) + + +class _TestQuantLinear(BaseQuantLinear): + REQUIRES_FORMAT_V2 = True + + def __init__(self, *, bits: int, pack_dtype: torch.dtype, columns: int) -> None: + nn.Module.__init__(self) + self.bits = bits + self.pack_dtype = pack_dtype + self.qzeros = nn.Parameter(torch.zeros((1, columns), dtype=pack_dtype), requires_grad=False) + self._qzeros_format = 1 + + def qzero_format(self, format: int | None = None) -> int: + if format is None: + return self._qzeros_format + self._qzeros_format = format + return self._qzeros_format + + +def _make_module(bits: int, pack_dtype: torch.dtype) -> _TestQuantLinear: + columns = 3 if bits == 3 else 1 + return _TestQuantLinear(bits=bits, pack_dtype=pack_dtype, columns=columns) + + +@pytest.mark.parametrize( + "bits, pack_dtype", + ( + (2, torch.int8), + (3, torch.int32), + (4, torch.int16), + (8, torch.int32), + ), +) +@torch.inference_mode() +def test_qzero_offsets_roundtrip(bits: int, pack_dtype: torch.dtype) -> None: + module = _make_module(bits=bits, pack_dtype=pack_dtype) + original = module.qzeros.data.clone() + + convert_gptq_v1_to_v2_format_module(module=module, bits=bits, pack_dtype=pack_dtype) + + convert_gptq_v2_to_v1_format_module( + module=module, + quantize_config=SimpleNamespace( + bits=bits, + pack_dtype=pack_dtype, + quant_method=METHOD.GPTQ, + format=FORMAT.GPTQ, + ), + ) + + assert torch.equal(module.qzeros.data, original) + + +@torch.inference_mode() +def test_qzero_offsets_scalar_patterns(): + cases = [ + (2, torch.int8, torch.tensor([[0x55]], dtype=torch.int8)), + (2, torch.int32, torch.tensor([[0x5555_5555]], dtype=torch.int32)), + (4, torch.int16, torch.tensor([[0x1111]], dtype=torch.int16)), + (8, torch.int32, torch.tensor([[0x0101_0101]], dtype=torch.int32)), + ] + + for bits, pack_dtype, expected in cases: + module = _make_module(bits=bits, pack_dtype=pack_dtype) + convert_gptq_v1_to_v2_format_module(module=module, bits=bits, pack_dtype=pack_dtype) + assert torch.equal(module.qzeros.data, expected) + + +@torch.inference_mode() +def test_qzero_offsets_triangular_patterns(): + cases = [ + ( + torch.int8, + torch.tensor([[0x24, 0x92, 0x49]], dtype=torch.int8), + ), + ( + torch.int32, + torch.tensor([[0x2492_4924, 0x9249_2492, 0x4924_9249]], dtype=torch.int32), + ), + ] + + for pack_dtype, expected in cases: + module = _make_module(bits=3, pack_dtype=pack_dtype) + convert_gptq_v1_to_v2_format_module(module=module, bits=3, pack_dtype=pack_dtype) + assert torch.equal(module.qzeros.data, expected) diff --git a/tests/test_torch_triton_group_sizes.py b/tests/test_torch_triton_group_sizes.py new file mode 100644 index 000000000..5de88d97d --- /dev/null +++ b/tests/test_torch_triton_group_sizes.py @@ -0,0 +1,118 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +from __future__ import annotations + +import pytest +import torch +import torch.nn as nn + +from gptqmodel.nn_modules.qlinear.torch import TorchQuantLinear +from gptqmodel.nn_modules.qlinear.tritonv2 import TritonV2QuantLinear + + +def _mock_gptq_linear(bits: int, group_size: int, in_features: int, out_features: int) -> tuple[nn.Linear, torch.Tensor, torch.Tensor, torch.Tensor]: + maxq = (1 << (bits - 1)) - 1 + weight = torch.randn((in_features, out_features), dtype=torch.float32) + + if group_size != -1: + reshaped = weight.view(in_features // group_size, group_size, out_features) + w_g = reshaped.permute(1, 0, 2).reshape(group_size, -1) + else: + w_g = weight + + scales = torch.maximum( + w_g.abs().max(dim=0, keepdim=True).values, + torch.full((1, w_g.shape[1]), 1e-6, device=w_g.device), + ) + scales = scales / maxq + + q = torch.round(w_g / scales).clamp_(-maxq, maxq) + ref = (q * scales).to(dtype=torch.float16) + + if group_size != -1: + ref = ref.reshape(group_size, -1, out_features) + ref = ref.permute(1, 0, 2).reshape(in_features, out_features) + + q = q.reshape(group_size, -1, out_features) + q = q.permute(1, 0, 2).reshape(in_features, out_features) + + linear = nn.Linear(in_features, out_features, bias=False) + linear.weight.data = ref.t().contiguous() + + scales = scales.reshape(-1, out_features).contiguous() + zeros = torch.zeros_like(scales, dtype=torch.int32) + g_idx = torch.arange(in_features, dtype=torch.int32) // ( + group_size if group_size != -1 else in_features + ) + + return linear, scales, zeros, g_idx + + +@pytest.mark.cuda +@pytest.mark.parametrize("group_size", [256, 512, 1024]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_torch_triton_large_group_sizes(group_size: int, dtype: torch.dtype) -> None: + if not torch.cuda.is_available(): + pytest.skip("CUDA device required") + + if dtype is torch.bfloat16 and not torch.cuda.is_bf16_supported(): + pytest.skip("CUDA bfloat16 not supported on this device") + + torch.cuda.set_device(0) + + bits = 4 + in_features = 4096 + out_features = 4096 + + torch.manual_seed(0) + + linear, scales, zeros, g_idx = _mock_gptq_linear(bits, group_size, in_features, out_features) + + torch_module = TorchQuantLinear( + bits=bits, + group_size=group_size, + sym=True, + desc_act=False, + in_features=in_features, + out_features=out_features, + pack_dtype=torch.int32, + bias=False, + ) + torch_module.pack_block(linear, scales.T, zeros.T, g_idx=g_idx) + torch_module.post_init() + + try: + triton_module = TritonV2QuantLinear( + bits=bits, + group_size=group_size, + desc_act=False, + sym=True, + in_features=in_features, + out_features=out_features, + pack_dtype=torch.int32, + bias=False, + ) + except ValueError as err: + pytest.skip(f"Triton backend unavailable: {err}") + + triton_module.pack_block(linear, scales.T, zeros.T, g_idx=g_idx) + triton_module.post_init() + + device = torch.device("cuda:0") + torch_module = torch_module.to(device=device, dtype=dtype).eval() + triton_module = triton_module.to(device=device, dtype=dtype).eval() + + batch = 8 + x = torch.randn((batch, in_features), device=device, dtype=dtype) + + with torch.inference_mode(): + torch_out = torch_module(x) + triton_out = triton_module(x) + + torch_out = torch_out.to(torch.float32) + triton_out = triton_out.to(torch.float32) + + assert torch.allclose(triton_out, torch_out, rtol=1e-2, atol=1e-2) + assert torch_out.abs().max() > 0 diff --git a/tests/test_torch_weight_cache.py b/tests/test_torch_weight_cache.py new file mode 100644 index 000000000..86a617b0e --- /dev/null +++ b/tests/test_torch_weight_cache.py @@ -0,0 +1,129 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +import torch +import torch.nn as nn +import pytest + +from gptqmodel.nn_modules.qlinear.lookahead import configure_default_lookahead +from gptqmodel.nn_modules.qlinear.torch import TorchQuantLinear + + +def _make_module(device: torch.device): + module = TorchQuantLinear( + bits=4, + group_size=32, + sym=True, + desc_act=True, + in_features=64, + out_features=64, + bias=True, + pack_dtype=torch.int32, + adapter=None, + register_buffers=True, + ).to(device) + + with torch.no_grad(): + module.qweight.zero_() + module.qzeros.zero_() + module.scales.fill_(1.0) + module.bias.uniform_(-0.1, 0.1) + + module.qzero_format(format=2) + module.post_init() + module.eval() + return module + + +def test_cached_forward_matches_baseline(): + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + module = _make_module(device) + + x = torch.randn(8, module.in_features, device=device, dtype=torch.float16) + + module.enable_weight_cache(False) + ref = module(x) + + module.enable_weight_cache(True) + module.clear_weight_cache() + cached = module(x) + + torch.testing.assert_close(ref, cached) + assert x.dtype in module._cached_weights + assert module._cached_weights[x.dtype].device.type == device.type + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda required for lookahead prefetch test") +def test_lookahead_prefetch_single_step(): + device = torch.device("cuda") + producer = _make_module(device) + consumer = _make_module(device) + + producer.enable_lookahead(True).set_lookahead_next(consumer) + consumer.enable_lookahead(True) + + x = torch.randn(4, producer.in_features, device=device, dtype=torch.float16) + + producer(x) + assert torch.float16 in consumer._prefetched_weights + + consumer(x) + assert torch.float16 not in consumer._prefetched_weights + + +def test_configure_default_lookahead_chain(): + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + + class DummyAttn(nn.Module): + def __init__(self): + super().__init__() + self.q_proj = _make_module(device) + self.k_proj = _make_module(device) + self.v_proj = _make_module(device) + self.o_proj = _make_module(device) + + class DummyMLP(nn.Module): + def __init__(self): + super().__init__() + self.gate_proj = _make_module(device) + self.up_proj = _make_module(device) + self.down_proj = _make_module(device) + + class DummyLayer(nn.Module): + def __init__(self): + super().__init__() + self.self_attn = DummyAttn() + self.mlp = DummyMLP() + + class DummyModel(nn.Module): + def __init__(self): + super().__init__() + self.layers = nn.ModuleList([DummyLayer()]) + + model = DummyModel() + for module in model.modules(): + if isinstance(module, TorchQuantLinear): + module.enable_lookahead(True) + + configure_default_lookahead(model) + + layer = model.layers[0] + q_proj = layer.self_attn.q_proj + k_proj = layer.self_attn.k_proj + v_proj = layer.self_attn.v_proj + o_proj = layer.self_attn.o_proj + gate_proj = layer.mlp.gate_proj + up_proj = layer.mlp.up_proj + down_proj = layer.mlp.down_proj + + assert q_proj._lookahead_next == (gate_proj, up_proj, down_proj) + assert q_proj._lookahead_enabled + + for module in (k_proj, v_proj, o_proj): + assert module._lookahead_next is None + assert not module._lookahead_enabled + + for module in (gate_proj, up_proj, down_proj): + assert module._lookahead_next is None + assert module._lookahead_enabled diff --git a/tests/test_writer_attention.py b/tests/test_writer_attention.py new file mode 100644 index 000000000..bc3110388 --- /dev/null +++ b/tests/test_writer_attention.py @@ -0,0 +1,127 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +import copy +from types import SimpleNamespace + +import pytest + +from gptqmodel.models.writer import ModelWriter +from gptqmodel.quantization.config import FORMAT, METHOD + + +class _DummyKernel: + REQUIRES_FORMAT_V2 = False + SUPPORTS_SHARDS = True + + +class _DummyQuantizeConfig: + format = FORMAT.GPTQ + quant_method = METHOD.GPTQ + damp_percent = 0.0 + damp_auto_increment = 0.0 + static_groups = False + true_sequential = False + mse = False + v2 = False + v2_alpha = 0.0 + act_group_aware = False + adapter = None + dynamic = False + offload_to_disk = False + offload_to_disk_path = None + lm_head = False + + def __init__(self): + self._meta = {} + + def __deepcopy__(self, memo): + clone = type(self)() + memo[id(self)] = clone + clone._meta = copy.deepcopy(self._meta, memo) + return clone + + def meta_set_versionable(self, key, value): + self._meta[key] = value + + def meta_set(self, key, value): + self._meta[key] = value + + def to_dict(self): + return {"meta": dict(self._meta)} + + def save_pretrained(self, _): # pragma: no cover - not exercised in this test + return None + + def extract_adapter_rank_patterns(self): # pragma: no cover - not exercised here + return {} + + +class _DummyConfig: + def __init__(self): + self.attn_implementation = "flash_attention_2" + self._attn_implementation = "flash_attention_2" + + def __deepcopy__(self, memo): + clone = type(self)() + memo[id(self)] = clone + clone.__dict__ = copy.deepcopy(self.__dict__, memo) + return clone + + +class _DummyGenerationConfig(_DummyConfig): + pass + + +class _DummyModel: + def __init__(self, tracker): + self.config = _DummyConfig() + self.generation_config = _DummyGenerationConfig() + self._tracker = tracker + + def save_pretrained(self, *_args, **_kwargs): + self._tracker["config_snapshot"] = dict(self.config.__dict__) + self._tracker["generation_snapshot"] = dict(self.generation_config.__dict__) + raise RuntimeError("stop after checks") + + +def _build_dummy_model_writer(): + class _Base: + pass + + DummyWriter = ModelWriter(_Base) + instance = DummyWriter() + instance.quantized = True + instance.quantize_config = _DummyQuantizeConfig() + instance.quant_log = [] + instance.load_quantized_model = False + instance.qlinear_kernel = _DummyKernel() + instance.model_local_path = "/tmp/nonexistent" + instance.trust_remote_code = False + instance.tokenizer = None + instance.processor = None + instance.turtle_model = SimpleNamespace() + instance.lm_head = "lm_head" + return instance + + +def test_save_quantized_strips_attention_before_serialization(tmp_path, monkeypatch): + tracker = {} + writer = _build_dummy_model_writer() + writer.model = _DummyModel(tracker) + + monkeypatch.setattr("gptqmodel.models.writer.get_model_files_size", lambda _: 1) + + with pytest.raises(RuntimeError, match="stop after checks"): + writer.save_quantized(save_dir=str(tmp_path)) + + config_snapshot = tracker["config_snapshot"] + generation_snapshot = tracker["generation_snapshot"] + + assert "attn_implementation" not in config_snapshot + assert "_attn_implementation" not in config_snapshot + assert "attn_implementation" not in generation_snapshot + assert "_attn_implementation" not in generation_snapshot + + assert writer.model.config.attn_implementation == "flash_attention_2" + assert writer.model.config._attn_implementation == "flash_attention_2"