diff --git a/examples/benchmark/generation_speed.py b/examples/benchmark/generation_speed.py index b1788bfd3..a2b224c6f 100644 --- a/examples/benchmark/generation_speed.py +++ b/examples/benchmark/generation_speed.py @@ -12,11 +12,13 @@ import torch from datasets import Dataset, load_dataset -from gptqmodel import BACKEND, GPTQModel, QuantizeConfig from logbar import LogBar from transformers import AutoTokenizer, GenerationConfig from transformers.generation.logits_process import LogitsProcessor +from gptqmodel import BACKEND, GPTQModel, QuantizeConfig + + logger = LogBar.shared() random.seed(0) diff --git a/examples/benchmark/ipex.py b/examples/benchmark/ipex.py index c9cfee7c8..fc73436ed 100644 --- a/examples/benchmark/ipex.py +++ b/examples/benchmark/ipex.py @@ -9,6 +9,7 @@ import torch from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer + try: from optimum.intel.utils.modeling_utils import bind_cores_for_best_perf bind_cores_for_best_perf() @@ -18,6 +19,7 @@ import argparse + parser = argparse.ArgumentParser(description="Benchmark IPEX vs HF on a pre-trained model.") parser.add_argument("--model", type=str, required=True, help="Path or name of the pre-trained model.") parser.add_argument("--cores", type=int, default=8, help="Number of CPU cores to use.") diff --git a/examples/benchmark/perplexity.py b/examples/benchmark/perplexity.py index 42651174e..36b4eb812 100644 --- a/examples/benchmark/perplexity.py +++ b/examples/benchmark/perplexity.py @@ -6,9 +6,11 @@ import argparse import os -from gptqmodel.utils.perplexity import Perplexity from transformers import AutoTokenizer +from gptqmodel.utils.perplexity import Perplexity + + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" diff --git a/examples/eora/eora_generation.py b/examples/eora/eora_generation.py index 636113806..310a81cee 100644 --- a/examples/eora/eora_generation.py +++ b/examples/eora/eora_generation.py @@ -16,6 +16,7 @@ # -- do not touch import os + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" # -- end do not touch @@ -23,10 +24,12 @@ # from models.model_test import ModelTest # noqa: E402 from eora_calibration_data_construction import construct_c4, construct_mmlu + from gptqmodel import GPTQModel, QuantizeConfig # noqa: E402 from gptqmodel.adapter.adapter import Lora from gptqmodel.utils.torch import torch_empty_cache # noqa: E402 + ## meta-llama/Llama-3.2-1B ## meta-llama/Llama-3.2-3B ## meta-llama/Meta-Llama-3-8B diff --git a/examples/eora/eora_load_and_inference.py b/examples/eora/eora_load_and_inference.py index 2c4ad9e46..7a3ffa107 100644 --- a/examples/eora/eora_load_and_inference.py +++ b/examples/eora/eora_load_and_inference.py @@ -16,6 +16,7 @@ # -- do not touch import os + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" # -- end do not touch @@ -23,6 +24,7 @@ from gptqmodel import BACKEND, GPTQModel # noqa: E402 from gptqmodel.adapter.adapter import Lora # noqa: E402 + if __name__ == '__main__': import argparse diff --git a/examples/eora/evaluation.py b/examples/eora/evaluation.py index 3ec272a76..08713f089 100644 --- a/examples/eora/evaluation.py +++ b/examples/eora/evaluation.py @@ -16,17 +16,19 @@ # -- do not touch import os + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" # -- end do not touch from typing import Optional # noqa: E402 +from lm_eval.utils import make_table # noqa: E402 + from gptqmodel import BACKEND, GPTQModel # noqa: E402 from gptqmodel.adapter.adapter import Lora # noqa: E402 from gptqmodel.utils.eval import EVAL # noqa: E402 from gptqmodel.utils.torch import torch_empty_cache # noqa: E402 -from lm_eval.utils import make_table # noqa: E402 def bench(path: str, backend: BACKEND, adapter: Optional[Lora], task): diff --git a/examples/eora/post_quant_eora_generation.py b/examples/eora/post_quant_eora_generation.py index 6689ab5d9..25a1ce2b7 100644 --- a/examples/eora/post_quant_eora_generation.py +++ b/examples/eora/post_quant_eora_generation.py @@ -16,15 +16,18 @@ # -- do not touch import os + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" # -- end do not touch from eora_calibration_data_construction import construct_ARC, construct_c4, construct_mmlu + from gptqmodel import GPTQModel # noqa: E402 from gptqmodel.adapter.adapter import Lora # noqa: E402 + if __name__ == '__main__': import argparse diff --git a/examples/evaluation/run_language_modeling_task.py b/examples/evaluation/run_language_modeling_task.py index 60790dd0f..32371f4a8 100644 --- a/examples/evaluation/run_language_modeling_task.py +++ b/examples/evaluation/run_language_modeling_task.py @@ -7,10 +7,12 @@ import datasets import torch +from transformers import AutoTokenizer + from gptqmodel import BACKEND, GPTQModel, QuantizeConfig from gptqmodel.eval_tasks import LanguageModelingTask from gptqmodel.utils.torch import torch_empty_cache -from transformers import AutoTokenizer + DATASET = "tatsu-lab/alpaca" WITH_INPUT_TEMPLATE = "Instruction:\n{instruction}\n\nInput:\n{input}\n\nOutput:\n" diff --git a/examples/evaluation/run_sequence_classification_task.py b/examples/evaluation/run_sequence_classification_task.py index 8c8f589c1..1c1a6840b 100644 --- a/examples/evaluation/run_sequence_classification_task.py +++ b/examples/evaluation/run_sequence_classification_task.py @@ -8,10 +8,12 @@ import datasets import torch +from transformers import AutoTokenizer + from gptqmodel import BACKEND, GPTQModel, QuantizeConfig from gptqmodel.eval_tasks import SequenceClassificationTask from gptqmodel.utils.torch import torch_empty_cache -from transformers import AutoTokenizer + DATASET = "cardiffnlp/tweet_sentiment_multilingual" TEMPLATE = "Question:What's the sentiment of the given text? Choices are {labels}.\nText: {text}\nAnswer:" diff --git a/examples/evaluation/run_text_summarization_task.py b/examples/evaluation/run_text_summarization_task.py index e7d268fc0..5c56f9fb6 100644 --- a/examples/evaluation/run_text_summarization_task.py +++ b/examples/evaluation/run_text_summarization_task.py @@ -8,10 +8,12 @@ import datasets import torch +from transformers import AutoTokenizer, GenerationConfig + from gptqmodel import BACKEND, GPTQModel, QuantizeConfig from gptqmodel.eval_tasks import TextSummarizationTask from gptqmodel.utils.torch import torch_empty_cache -from transformers import AutoTokenizer, GenerationConfig + os.system("pip install py7zr") diff --git a/examples/inference/run_transformers.py b/examples/inference/run_transformers.py index 863ab68f1..eb46afa13 100644 --- a/examples/inference/run_transformers.py +++ b/examples/inference/run_transformers.py @@ -5,6 +5,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained("TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ") quantized_model = AutoModelForCausalLM.from_pretrained("TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ") print(tokenizer.decode(quantized_model.generate(**tokenizer("gptqmodel is", return_tensors="pt").to(quantized_model.device))[0])) diff --git a/examples/inference/run_with_different_backends.py b/examples/inference/run_with_different_backends.py index 10ac9830b..84be1b9a0 100644 --- a/examples/inference/run_with_different_backends.py +++ b/examples/inference/run_with_different_backends.py @@ -8,9 +8,11 @@ import sys from argparse import ArgumentParser -from gptqmodel import BACKEND, GPTQModel, QuantizeConfig, get_best_device from transformers import AutoTokenizer +from gptqmodel import BACKEND, GPTQModel, QuantizeConfig, get_best_device + + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" diff --git a/examples/quantization/basic_usage.py b/examples/quantization/basic_usage.py index af1bdb234..5f9f1b387 100644 --- a/examples/quantization/basic_usage.py +++ b/examples/quantization/basic_usage.py @@ -5,9 +5,11 @@ import os -from gptqmodel import GPTQModel, QuantizeConfig, get_best_device from transformers import AutoTokenizer +from gptqmodel import GPTQModel, QuantizeConfig, get_best_device + + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" diff --git a/examples/quantization/basic_usage_wikitext2.py b/examples/quantization/basic_usage_wikitext2.py index bf94e6dc0..02f600465 100644 --- a/examples/quantization/basic_usage_wikitext2.py +++ b/examples/quantization/basic_usage_wikitext2.py @@ -5,9 +5,11 @@ import torch from datasets import load_dataset -from gptqmodel import GPTQModel, QuantizeConfig from transformers import AutoTokenizer +from gptqmodel import GPTQModel, QuantizeConfig + + pretrained_model_id = "/monster/data/model/TinyLlama-1.1B-Chat-v1.0" # "TinyLlama/TinyLlama-1.1B-Chat-v1.0" quantized_model_id = "TinyLlama-1.1B-Chat-v1.0-4bit-128g" diff --git a/examples/quantization/transformers_usage.py b/examples/quantization/transformers_usage.py index 6846df950..671ba030f 100755 --- a/examples/quantization/transformers_usage.py +++ b/examples/quantization/transformers_usage.py @@ -5,6 +5,7 @@ from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig + model_id = "facebook/opt-125m" tokenizer = AutoTokenizer.from_pretrained(model_id) dataset = ["gptqmodel is an easy-to-use model quantization library with user-friendly apis, based on GPTQ algorithm."] diff --git a/format/format.sh b/format/format.sh index 93f7cfae4..9e4d630ed 100755 --- a/format/format.sh +++ b/format/format.sh @@ -8,7 +8,8 @@ pip install -U ruff==0.13.0 isort==6.0.1 ruff check ../gptqmodel/models ../gptqmodel/nn_modules ../gptqmodel/quantization ../gptqmodel/utils ../gptqmodel/__init__.py ../examples ../tests ../setup.py --fix --unsafe-fixes ruff_status=$? -isort -l 119 -e ../ +# isort is too slow +# isort -l 119 -e ../ # Exit with the status code of ruff check exit $ruff_status \ No newline at end of file diff --git a/gptqmodel/__init__.py b/gptqmodel/__init__.py index 615ae00c1..258080baa 100644 --- a/gptqmodel/__init__.py +++ b/gptqmodel/__init__.py @@ -11,6 +11,7 @@ from .utils.exllama import exllama_set_max_input_length from .version import __version__ + if os.getenv('GPTQMODEL_USE_MODELSCOPE', 'False').lower() in ['true', '1']: try: from modelscope.utils.hf_util.patcher import patch_hub diff --git a/gptqmodel/looper/gptq_processor.py b/gptqmodel/looper/gptq_processor.py index 58fc5b40f..c45f3e1c6 100644 --- a/gptqmodel/looper/gptq_processor.py +++ b/gptqmodel/looper/gptq_processor.py @@ -13,13 +13,13 @@ from ..looper.loop_processor import LoopProcessor, get_max_memory from ..looper.named_module import NamedModule from ..models import BaseQModel +from ..models._const import CPU from ..models.writer import (PROCESS_LOG_FWD_TIME, PROCESS_LOG_LAYER, PROCESS_LOG_MODULE, PROCESS_LOG_NAME, PROCESS_LOG_TIME, PROCESS_MAX_MEMORY, QUANT_LOG_DAMP, QUANT_LOG_LOSS, QUANT_LOG_NSAMPLES) from ..quantization import GPTQ, GPTQv2 from ..quantization.config import METHOD, QuantizeConfig from ..utils.importer import select_quant_linear from ..utils.logger import setup_logger -from ..utils.memory import MEM_LORD from ..utils.model import create_quant_module, find_modules, move_to, pack_model, pack_module from ..utils.offload import undo_offload_to_disk from ..utils.torch import HAS_CUDA, torch_streamCtx, torch_sync @@ -127,7 +127,10 @@ def process(self, module: NamedModule): g = self.tasks[module.name] wq, q_scales, q_zeros, q_g_idx, duration, avg_loss, damp_percent, nsamples = g.quantize() - MEM_LORD.free((q_scales, q_zeros, q_g_idx)) + + q_scales = q_scales.to(CPU) + q_zeros = q_zeros.to(CPU) + q_g_idx = q_g_idx.to(CPU) with self.lock: module.state.update({"q_scales": q_scales}) @@ -198,7 +201,7 @@ def process(self, module: NamedModule): "wq": wq, # fp16, quantized weight but not int4 (packed qweight) }) - MEM_LORD.free(module.weight) + # single largest deallocation of vram happens here module.weight.data = wq # submodule_finalized is called in reverse after all next sequential processes are called @@ -215,6 +218,10 @@ def submodule_finalize(self, module: NamedModule, model: BaseQModel, **kwargs): q_scales = module.state.pop("q_scales") q_g_idx = module.state.pop("q_g_idx") + assert q_zeros.device == CPU + assert q_scales.device == CPU + assert q_g_idx.device == CPU + layers = find_modules(model.model) # replace module with quantized module @@ -251,7 +258,6 @@ def submodule_finalize(self, module: NamedModule, model: BaseQModel, **kwargs): with self.lock: self.result_pop(module.full_name) - # MEM_LORD.free(module.weight) module.unregister_parameter("weight") def finalize(self, model: BaseQModel, **kwargs): @@ -260,14 +266,12 @@ def finalize(self, model: BaseQModel, **kwargs): torch_sync() model.model = undo_offload_to_disk(module=model.model, include_buffers=True, delete_offload_folders=True) - MEM_LORD.free(model.model) # print("finalize") # print_module_tree(model.model) # set quantized state model.quantized = True - model.quantize_config.quant_method = METHOD.GPTQ super().finalize(model=model, **kwargs) diff --git a/gptqmodel/looper/module_looper.py b/gptqmodel/looper/module_looper.py index 81f9c6ca7..fee50bdb1 100644 --- a/gptqmodel/looper/module_looper.py +++ b/gptqmodel/looper/module_looper.py @@ -2,13 +2,16 @@ # SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai # SPDX-License-Identifier: Apache-2.0 # Contact: qubitium@modelcloud.ai, x.com/qubitium + +from __future__ import annotations + import copy import gc import threading import time -from concurrent.futures import ThreadPoolExecutor +from contextlib import contextmanager from functools import partial -from typing import Dict, List +from typing import Dict, List, Optional import torch @@ -22,13 +25,13 @@ from ..models._const import CUDA, SUPPORTS_MODULE_TYPES from ..nn_modules.hooked_linear import (STOP_FORWARD_EXCEPTION, HookedLinear, StopForward, replace_module_with_hooked_legacy) -from ..utils import ASYNC_BG_QUEUE, SERIAL_BG_QUEUE from ..utils.attn_mask import apply_keep_mask_bt, normalize_seq_mask -from ..utils.device import get_device +from ..utils.device import get_device, get_device_new from ..utils.logger import setup_logger from ..utils.model import find_modules, get_module, get_module_by_name_prefix, move_to, nested_move_to from ..utils.offload import offload_to_disk from ..utils.structure import print_module_tree +from ..utils.threadx import DeviceThreadPool from ..utils.torch import (ALL_DEVICES, CPU, DEFAULT_BALANCE_STRATEGY, HAS_CUDA, META, BalanceStrategy, device_next, device_next_reset, torch_empty_cache, torch_sync) from .awq_processor import AWQProcessor @@ -37,6 +40,89 @@ log = setup_logger() +# -------------------- Device helpers (local) -------------------- + +@contextmanager +def _device_ctx(dev: Optional[torch.device]): + """ + Ensure the caller thread’s current device matches `dev` for the duration of the + context (CUDA/XPU). Prevents cuBLAS/cuDNN handle/device mismatches in multi-GPU. + """ + if dev is None: + yield + else: + dtyp = getattr(dev, "type", None) + if dtyp == "cuda": + with torch.cuda.device(dev.index): + yield + elif dtyp == "xpu" and hasattr(torch, "xpu"): + with torch.xpu.device(dev.index): # type: ignore[attr-defined] + yield + else: + # cpu/mps/meta -> nothing special needed + yield + + +@torch.inference_mode() +def _rehome_module_to_device( + module: torch.nn.Module, + device: torch.device, + *, + move_parameters: bool = False, + move_buffers: bool = True, + include_non_persistent_buffers: bool = True, + only_mismatched: bool = True, +): + """ + Move a module's **registered** tensors to `device`. + Defaults to buffers-only (fast; fixes RoPE cos/sin caches and other internal state). + Parameters can be moved too, but that risks breaking weight tying and increases VRAM churn. + """ + for sub in module.modules(): + # Buffers (covers most cached internal tensors if properly registered) + if move_buffers: + np_set = getattr(sub, "_non_persistent_buffers_set", set()) + for name, buf in list(getattr(sub, "_buffers", {}).items()): + if buf is None or not isinstance(buf, torch.Tensor): + continue + if not include_non_persistent_buffers and name in np_set: + continue + if only_mismatched and buf.device == device: + continue + try: + sub._buffers[name] = buf.to(device, non_blocking=True) + except Exception: + try: + sub._buffers[name] = buf.to(device) + except Exception: + pass + + # Parameters (rarely needed; default False) + if move_parameters: + for pname, p in list(getattr(sub, "_parameters", {}).items()): + if p is None or not isinstance(p, torch.nn.Parameter): + continue + if only_mismatched and p.device == device: + continue + try: + with torch.no_grad(): + new_p = torch.nn.Parameter( + p.data.to(device, non_blocking=True), + requires_grad=p.requires_grad + ) + sub._parameters[pname] = new_p + except Exception: + try: + with torch.no_grad(): + new_p = torch.nn.Parameter( + p.data.to(device), + requires_grad=p.requires_grad + ) + sub._parameters[pname] = new_p + except Exception: + pass + + class ModuleLooper(): def __init__(self, model: BaseQModel, processors: List[LoopProcessor]): self.processors = processors @@ -44,6 +130,20 @@ def __init__(self, model: BaseQModel, processors: List[LoopProcessor]): self.support_batch_quantize = model.support_batch_quantize self.lock = threading.Lock() + # Create a single pool for the entire looper lifecycle. + # Eagerly discovers devices and pins worker threads per device. + # Tune worker counts here if desired (example policy shown). + self.pool = DeviceThreadPool( + inference_mode=True, + workers={ + "cuda:per": 4, # unique memory per instance + "xpu:per": 1, # unique memory per instance + "mps": 8, # unified memory + "cpu": 8, # unified memory + }, + empty_cache_every_n=14, # disable auto GC during quant loops; enable if you want + ) + # NEW: Wrap an existing hook so its inputs/outputs are pre-masked for GPTQ stats. # We *do not* alter the module's actual computation; only what the hook # passes down to the processor capture path is masked. @@ -76,7 +176,7 @@ def hook(module, inputs, output): if isinstance(output, tuple): new_output = (yk,) + tuple(output[1:]) else: - new_output = [yk] + list(output[1:]) + new_output = [yk] + list(output[1:] ) elif torch.is_tensor(output) and keep is not None and output.dim() >= 3: new_output = apply_keep_mask_bt(output, keep) except Exception: @@ -177,10 +277,13 @@ def store_input_hook(module, args, kwargs): if self.gptq_model.ATTENTION_MASKS_DTYPE is torch.long: example["attention_mask"] = example["attention_mask"].long() - if self.gptq_model.INPUT_EMBEDDING_EXTRA_ARGS: - self.gptq_model.model.generate(**example, **self.gptq_model.INPUT_EMBEDDING_EXTRA_ARGS) - else: - self.gptq_model.model(**example, use_cache=use_cache) + # Ensure initial caches (like RoPE) are created on the quant device + with self.pool.read_lock(self.gptq_model.quantize_config.device): + with _device_ctx(self.gptq_model.quantize_config.device): + if self.gptq_model.INPUT_EMBEDDING_EXTRA_ARGS: + self.gptq_model.model.generate(**example, **self.gptq_model.INPUT_EMBEDDING_EXTRA_ARGS) + else: + self.gptq_model.model(**example, use_cache=use_cache) except StopForward: pass @@ -191,7 +294,7 @@ def store_input_hook(module, args, kwargs): return InputCache(layer_inputs=layer_inputs, layer_input_kwargs=layer_input_kwargs, position_ids=position_ids, attention_masks=attention_masks) - @torch.inference_mode + @torch.inference_mode() def loop(self, fail_safe: bool = False, **kwargs): if self.gptq_model.quantize_config.lm_head: if self.gptq_model.model.config.tie_word_embeddings and hasattr(self.gptq_model.model.model, "_tied_weights_keys"): @@ -300,9 +403,7 @@ def loop(self, fail_safe: bool = False, **kwargs): # merge all subsets into one modules = [sum(modules, [])] - # TODO: integrated AWQ module forward/hooks within module lopper so everything is unified - # AWQ does it's own per layer module hooks and calculations. Logic has not been fully integrated into - # the module_looper so we wil let awq handle per layer operations for now + # AWQ does per-layer itself; skip here if isinstance(processor, AWQProcessor): named_childs = dict() for index, names in enumerate(modules): @@ -313,10 +414,7 @@ def loop(self, fail_safe: bool = False, **kwargs): processor=processor, fail_safe=fail_safe) named_childs.update(named_modules) - - # awq uses model.layers[0] for quantization instead of model.layers.0.self_attn.q_proj processor.layer_quantize(module, cur_layer_device, named_childs) - # skip module_looper processing for awq continue layer_inputs = processor.inputs_cache.layer_inputs @@ -372,12 +470,10 @@ def loop(self, fail_safe: bool = False, **kwargs): raw_mask = attention_masks[j] layer_attention_mask = raw_mask if raw_mask is None else move_to(raw_mask, device=cur_layer_device, stream=False) - # Compute and set keep-mask for this batch, for hook wrappers to consume + # Compute and set keep-mask for this batch if raw_mask is not None: - # Assume hidden_states is first arg with shape [B, S, H] seq_len = layer_input[0].shape[1] if (len(layer_input) > 0 and layer_input[0].dim() >= 2) else None keep_mask_bs = normalize_seq_mask(layer_attention_mask, seq_len=seq_len) - # We don't require LoopProcessor to declare this attribute; set dynamically. setattr(processor, "current_attention_mask", keep_mask_bs) else: setattr(processor, "current_attention_mask", None) @@ -392,18 +488,23 @@ def loop(self, fail_safe: bool = False, **kwargs): additional_layer_inputs[k] = nested_move_to(v, device=cur_layer_device, stream=False) try: - # reuse_kv special-case - if hasattr(module, "reuse_kv") and module.reuse_kv: - additional_layer_inputs["kv_last_layer"] = shared_kv_cache_dict.get(layer_index - 1) - layer_output = module(*layer_input) if is_lm_head_module else module(*layer_input, **additional_layer_inputs) - if shared_kv_cache_dict.get(layer_index) is None: - shared_kv_cache_dict[layer_index] = layer_output[-1] - else: - layer_output = module(*layer_input) if is_lm_head_module else module(*layer_input, **additional_layer_inputs) + # Ensure internal buffers (e.g., RoPE caches) are on the layer's device + # _rehome_module_to_device(module, cur_layer_device, move_parameters=False, move_buffers=True) + + # Acquire read lock so auto-GC cannot run while we forward + with self.pool.read_lock(cur_layer_device): + with _device_ctx(cur_layer_device): + # reuse_kv special-case + if hasattr(module, "reuse_kv") and module.reuse_kv: + additional_layer_inputs["kv_last_layer"] = shared_kv_cache_dict.get(layer_index - 1) + layer_output = module(*layer_input) if is_lm_head_module else module(*layer_input, **additional_layer_inputs) + if shared_kv_cache_dict.get(layer_index) is None: + shared_kv_cache_dict[layer_index] = layer_output[-1] + else: + layer_output = module(*layer_input) if is_lm_head_module else module(*layer_input, **additional_layer_inputs) except StopForward: pass finally: - # Clear the per-batch mask no matter what setattr(processor, "current_attention_mask", None) del layer_input del additional_layer_inputs @@ -442,36 +543,25 @@ def loop(self, fail_safe: bool = False, **kwargs): for name in moe_skip_modules: subset.pop(name) - # ---- Start Process Hook ---- - if len(ALL_DEVICES) <= 1: - for name_index, name in enumerate(subset): - m = subset[name] - processor.process(module=m) - processed_subset[name] = m - else: - max_workers = len(ALL_DEVICES) if DEFAULT_BALANCE_STRATEGY == BalanceStrategy.GPU else len(ALL_DEVICES) - 1 - with ThreadPoolExecutor(max_workers=max_workers) as executor: - futures = [] - - @torch.inference_mode() - def process_module(name, m): - # prevent cuda sync memory ctx bugs - m_device = get_device(m) - if HAS_CUDA and m_device is not None and m_device.type == "cuda": - torch.cuda.set_device(m_device) - - processor.process(module=m) - return name, m - - for name in subset: - m = subset[name] - futures.append(executor.submit(process_module, name, m)) + # ---- Start Process Hook (via DeviceThreadPool) ---- + futures = [] - for future in futures: - name, m = future.result() - processed_subset[name] = m + @torch.inference_mode() + def _process_on_worker(proc: LoopProcessor, nm: NamedModule): + # Run processor.process for this NamedModule + proc.process(module=nm) + return nm.name, nm - torch_sync() + for name in subset: + m = subset[name] + # Prefer the planned target_device; fallback to module's own device + tgt_dev = getattr(m.module, "target_device", None) or get_device(m) or CPU + futures.append(self.pool.submit(tgt_dev, _process_on_worker, processor, m)) + + for fut in futures: + name, m = fut.result() + processed_subset[name] = m + torch_sync() # ---- End Process Hook ---- is_last_module = layer_index == len(quant_modules_pb) - 1 @@ -486,7 +576,6 @@ def process_module(name, m): raw_mask = attention_masks[j] layer_attention_mask = raw_mask if raw_mask is None else move_to(raw_mask, device=cur_layer_device) - # Keep-mask for this replay, for completeness (in case hooks capture again) if raw_mask is not None: seq_len = layer_input[0].shape[1] if (len(layer_input) > 0 and layer_input[0].dim() >= 2) else None keep_mask_bs = normalize_seq_mask(layer_attention_mask, seq_len=seq_len) @@ -501,14 +590,16 @@ def process_module(name, m): for k, v in layer_input_kwargs[j].items(): additional_layer_inputs[k] = nested_move_to(v, device=cur_layer_device) - if hasattr(module, "reuse_kv") and module.reuse_kv: - additional_layer_inputs["kv_last_layer"] = shared_kv_cache_dict.get(layer_index - 1) + # Rehome buffers again in case module ran on a different device previously + _rehome_module_to_device(module, cur_layer_device, move_parameters=False, move_buffers=True) - module_output = None - if is_lm_head_module: - module_output = module(*layer_input) - else: - module_output = module(*layer_input, **additional_layer_inputs) + # Guard forward with read lock to block auto-GC + with self.pool.read_lock(cur_layer_device): + with _device_ctx(cur_layer_device): + if is_lm_head_module: + module_output = module(*layer_input) + else: + module_output = module(*layer_input, **additional_layer_inputs) if isinstance(module_output, tuple): layer_output = module_output[0] @@ -522,7 +613,6 @@ def process_module(name, m): layer_outputs.append([layer_output]) - # Clear per-batch mask setattr(processor, "current_attention_mask", None) del layer_input @@ -543,20 +633,18 @@ def process_module(name, m): if p_index == len(self.processors) - 1: torch_sync() + + # Gather finalize tasks (can offload to disk); run them via the pool + finalize_futures = [] + for reverse_p in reversed(self.processors): for name in processed_subset: @torch.inference_mode() def finalize_module(process, module): - # prevent cuda sync memory ctx bugs - m_device = get_device(module) - if HAS_CUDA and m_device is not None and m_device.type == "cuda": - torch.cuda.set_device(m_device) - process.submodule_finalize(module, self.gptq_model) - # TODO FIX ME offloading to LoopProcessor lifecycle + # Disk offload (lifecycle TODO note preserved) if isinstance(process, (GPTQProcessor, QQQProcessor, AWQProcessor)): - # checking for disk offloading offload_to_disk( model=self.gptq_model.model, module=self.gptq_model.model.get_submodule(module.full_name), @@ -565,21 +653,24 @@ def finalize_module(process, module): module = processed_subset[name] - if self.gptq_model.quantize_config.offload_to_disk: - SERIAL_BG_QUEUE.submit(partial( - finalize_module, - process=reverse_p, - module=module, - )) - else: - reverse_p.submodule_finalize(module, self.gptq_model) + target_dev = get_device_new(module, recursive=True, assert_mode=True, expected="cpu") + + # Submit on the module's device thread (safe & deterministic) + finalize_futures.append( + self.pool.submit(target_dev, finalize_module, reverse_p, module) + ) + + # If any finalize tasks were queued, wait for them + for fut in finalize_futures: + fut.result() # LifeCycle: All sub-modules have finalized meaning quantization work is complete - SERIAL_BG_QUEUE.join() + # Ensure ANY remaining tasks the looper submitted have drained + self.pool.wait() # same as wait('all') # paranoid safety check - torch_sync() - torch_sync(device=CPU) + # torch_sync() + # torch_sync(device=CPU) total_log = {} diff --git a/gptqmodel/models/_const.py b/gptqmodel/models/_const.py index ed7e67931..c56342cf2 100644 --- a/gptqmodel/models/_const.py +++ b/gptqmodel/models/_const.py @@ -14,6 +14,7 @@ from ..utils.rocm import IS_ROCM from ..utils.torch import HAS_CUDA, HAS_MPS, HAS_XPU + CPU = device("cpu") META = device("meta") CUDA = device("cuda") diff --git a/gptqmodel/models/auto.py b/gptqmodel/models/auto.py index 397f67caa..9669c7427 100644 --- a/gptqmodel/models/auto.py +++ b/gptqmodel/models/auto.py @@ -11,6 +11,7 @@ from ..utils.logger import setup_logger + log = setup_logger() # if not os.environ.get("PYTHON_GIL", None): @@ -31,6 +32,7 @@ import sys # noqa: E402 + # TODO: waiting for pytorch implementgation of aten ops for MPS if sys.platform == "darwin": os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" @@ -112,11 +114,12 @@ from .definitions.qwen3 import Qwen3QModel # noqa: E402 from .definitions.qwen3_moe import Qwen3MoeQModel # noqa: E402 from .definitions.qwen3_next import Qwen3NextGPTQ # noqa: E402 +from .definitions.qwen3_omni_moe import Qwen3OmniMoeGPTQ from .definitions.rw import RwgQModel # noqa: E402 from .definitions.starcoder2 import Starcoder2QModel # noqa: E402 from .definitions.telechat2 import TeleChat2QModel from .definitions.xverse import XverseQModel # noqa: E402 -from .definitions.qwen3_omni_moe import Qwen3OmniMoeGPTQ + # make quants and inference more determinisitc torch.manual_seed(787) diff --git a/gptqmodel/models/base.py b/gptqmodel/models/base.py index cebd26e98..1e503b253 100644 --- a/gptqmodel/models/base.py +++ b/gptqmodel/models/base.py @@ -15,8 +15,14 @@ import torch._dynamo import torch.nn as nn from tokenicer import Tokenicer -from transformers import (AutoModelForCausalLM, AutoProcessor, PreTrainedModel, - PreTrainedTokenizerBase, ProcessorMixin, modeling_utils) +from transformers import ( + AutoModelForCausalLM, + AutoProcessor, + PreTrainedModel, + PreTrainedTokenizerBase, + ProcessorMixin, + modeling_utils, +) from ..adapter.adapter import Adapter from ..nn_modules.qlinear import BaseQuantLinear @@ -34,8 +40,14 @@ from ..utils.offload import offload_to_disk from ..utils.structure import alias_from_turtle_for_submodule from ..utils.torch import TORCH_HAS_COMPILE, torch_compile -from ._const import (CALIBRATION_DATASET_CONCAT_CHAR, CPU, DEFAULT_MAX_SHARD_SIZE, - DEVICE, EXPERT_INDEX_PLACEHOLDER, META) +from ._const import ( + CALIBRATION_DATASET_CONCAT_CHAR, + CPU, + DEFAULT_MAX_SHARD_SIZE, + DEVICE, + EXPERT_INDEX_PLACEHOLDER, + META, +) from .loader import ModelLoader from .writer import ModelWriter @@ -310,8 +322,8 @@ def simple_layer_modules(cls, model_config, quantize_config, is_awq_quantize: bo layer_modules = cls.build_moe_modules_if_need(model_config, layer_modules, is_awq_quantize) layer_modules = cls.filter_not_quantize_module(layer_modules, quantize_config) - - print(f"simple_layer_modules layer_modules: {layer_modules}") + + # print(f"simple_layer_modules layer_modules: {layer_modules}") return layer_modules @classmethod @@ -1071,9 +1083,9 @@ def shell_module_materialize( if self.turtle_model is None: if get_device(target_submodule) != device: target_submodule.to(device) - + return target_submodule - + module = alias_from_turtle_for_submodule( target_model=self.model, turtle_model=self.turtle_model, diff --git a/gptqmodel/models/definitions/__init__.py b/gptqmodel/models/definitions/__init__.py index 2a9459aa4..7072e254d 100644 --- a/gptqmodel/models/definitions/__init__.py +++ b/gptqmodel/models/definitions/__init__.py @@ -63,4 +63,4 @@ from .klear import KlearQModel from .llava_qwen2 import LlavaQwen2QModel from .nemotron_h import NemotronHQModel -from .qwen3_omni_moe import Qwen3OmniMoeGPTQ \ No newline at end of file +from .qwen3_omni_moe import Qwen3OmniMoeGPTQ diff --git a/gptqmodel/models/definitions/base_qwen2_5_omni.py b/gptqmodel/models/definitions/base_qwen2_5_omni.py index 4fca6b67f..227a7bd56 100644 --- a/gptqmodel/models/definitions/base_qwen2_5_omni.py +++ b/gptqmodel/models/definitions/base_qwen2_5_omni.py @@ -5,6 +5,7 @@ from typing import Dict, Optional +import torch from PIL import Image from transformers import AutoModelForTextToWaveform, AutoProcessor, ProcessorMixin @@ -13,7 +14,6 @@ from ...utils.model import MODALITY from .._const import CPU from ..base import BaseQModel -import torch class BaseQwen2_5_OmniGPTQ(BaseQModel): @@ -23,7 +23,7 @@ class BaseQwen2_5_OmniGPTQ(BaseQModel): INPUT_EMBEDDING_EXTRA_ARGS = { "return_audio": False, } - + loader = AutoModelForTextToWaveform pre_lm_head_norm_module = "thinker.model.norm" diff --git a/gptqmodel/models/definitions/gemma2.py b/gptqmodel/models/definitions/gemma2.py index ea7b4883d..519a9296b 100644 --- a/gptqmodel/models/definitions/gemma2.py +++ b/gptqmodel/models/definitions/gemma2.py @@ -7,6 +7,7 @@ from ...utils.logger import setup_logger from . import LlamaQModel + log = setup_logger() SUPPORT_ERR = "Currently, only vLLM/SGLang with flashinfer enabled can correctly inference a quantized Gemma2-27B model. Pre-quantized model with sample vLLM code: https://huggingface.co/ModelCloud/gemma-2-27b-it-gptq-4bit ." diff --git a/gptqmodel/models/definitions/nemotron_h.py b/gptqmodel/models/definitions/nemotron_h.py index 4944fe004..40535e348 100644 --- a/gptqmodel/models/definitions/nemotron_h.py +++ b/gptqmodel/models/definitions/nemotron_h.py @@ -26,8 +26,11 @@ def monkey_patch(self): if not self.load_quantized_model: return - from transformers.utils.import_utils import (is_causal_conv1d_available, - is_flash_attn_2_available, is_mamba_2_ssm_available) + from transformers.utils.import_utils import ( + is_causal_conv1d_available, + is_flash_attn_2_available, + is_mamba_2_ssm_available, + ) if is_mamba_2_ssm_available(): from mamba_ssm.ops.triton.selective_state_update import selective_state_update diff --git a/gptqmodel/models/definitions/qwen3_omni_moe.py b/gptqmodel/models/definitions/qwen3_omni_moe.py index 4487136cc..3b071bcd0 100644 --- a/gptqmodel/models/definitions/qwen3_omni_moe.py +++ b/gptqmodel/models/definitions/qwen3_omni_moe.py @@ -3,11 +3,13 @@ # SPDX-License-Identifier: Apache-2.0 # Contact: qubitium@modelcloud.ai, x.com/qubitium +import torch from transformers import AutoModelForTextToWaveform -from ..base import BaseQModel -from .._const import CPU + from ...utils.offload import offload_to_disk -import torch +from .._const import CPU +from ..base import BaseQModel + class Qwen3OmniMoeGPTQ(BaseQModel): ATTENTION_MASKS_REQUIRED_FOR_INPUT = True @@ -16,7 +18,7 @@ class Qwen3OmniMoeGPTQ(BaseQModel): INPUT_EMBEDDING_EXTRA_ARGS = { "return_audio": False, } - + loader = AutoModelForTextToWaveform dynamic_expert_index = "num_experts" @@ -59,7 +61,7 @@ def pre_quantize_generate_hook_end(self): module=self.model.thinker.visual, disk_path=self.quantize_config.offload_to_disk_path, ) - + offload_to_disk(model=self.model.thinker, module=self.model.thinker.audio_tower, disk_path=self.quantize_config.offload_to_disk_path, @@ -69,7 +71,7 @@ def pre_quantize_generate_hook_end(self): module=self.model.thinker.visual.rotary_pos_emb, disk_path=self.quantize_config.offload_to_disk_path, ) - + offload_to_disk(model=self.model.thinker.model, module=self.model.thinker.model.rotary_emb, disk_path=self.quantize_config.offload_to_disk_path, diff --git a/gptqmodel/models/loader.py b/gptqmodel/models/loader.py index 2576583bb..271048e2b 100644 --- a/gptqmodel/models/loader.py +++ b/gptqmodel/models/loader.py @@ -15,6 +15,7 @@ from ..utils.structure import print_module_tree + if os.getenv('GPTQMODEL_USE_MODELSCOPE', 'False').lower() in ['true', '1']: try: from modelscope import snapshot_download @@ -37,11 +38,21 @@ from ..utils.importer import auto_select_device, normalize_device_device_map, select_quant_linear from ..utils.logger import setup_logger from ..utils.marlin import _validate_marlin_device_support -from ..utils.model import (auto_dtype, convert_gptq_v1_to_v2_format, find_config_seq_len, find_modules, - get_checkpoints, get_module_by_name_prefix, gptqmodel_post_init, - load_checkpoint_in_model_then_tie_weights, make_quant, simple_dispatch_model) +from ..utils.model import ( + auto_dtype, + convert_gptq_v1_to_v2_format, + find_config_seq_len, + find_modules, + get_checkpoints, + get_module_by_name_prefix, + gptqmodel_post_init, + load_checkpoint_in_model_then_tie_weights, + make_quant, + simple_dispatch_model, +) from ._const import DEVICE, normalize_device + log = setup_logger() ATTN_IMPLEMENTATION = "attn_implementation" @@ -185,7 +196,7 @@ def skip(*args, **kwargs): model = build_shell_model(cls.loader, config=config, **model_init_kwargs) model._model_init_kwargs = model_init_kwargs print_module_tree(model=model) - + # enable mmap with low_cpu_mem_usage turtle_model = cls.loader.from_pretrained(model_local_path, config=config, low_cpu_mem_usage=True, **model_init_kwargs) @@ -198,7 +209,7 @@ def skip(*args, **kwargs): model = cls.loader.from_pretrained(model_local_path, config=config, **model_init_kwargs) model._model_init_kwargs = model_init_kwargs print_module_tree(model=model) - + turtle_model = None model_config = model.config.to_dict() diff --git a/gptqmodel/models/writer.py b/gptqmodel/models/writer.py index 51831a172..d7ccb9e7e 100644 --- a/gptqmodel/models/writer.py +++ b/gptqmodel/models/writer.py @@ -26,20 +26,39 @@ from ..adapter.adapter import HF_ADAPTER_FILE_NAME, HF_ADAPTER_WEIGHT_KEY_PREFIX, Lora from ..adapter.peft import LoraConfig -from ..quantization.config import (FORMAT, META_FIELD_ACT_GROUP_AWARE, META_FIELD_DAMP_AUTO_INCREMENT, - META_FIELD_DAMP_PERCENT, META_FIELD_MSE, META_FIELD_QUANTIZER, - META_FIELD_STATIC_GROUPS, META_FIELD_TRUE_SEQUENTIAL, META_FIELD_URI, - META_FIELD_V2_ALPHA, META_FIELD_V2_ENABLED, META_QUANTIZER_GPTQMODEL, - META_VALUE_URI, MIN_VERSION_WITH_V2) +from ..quantization.config import ( + FORMAT, + META_FIELD_ACT_GROUP_AWARE, + META_FIELD_DAMP_AUTO_INCREMENT, + META_FIELD_DAMP_PERCENT, + META_FIELD_MSE, + META_FIELD_QUANTIZER, + META_FIELD_STATIC_GROUPS, + META_FIELD_TRUE_SEQUENTIAL, + META_FIELD_URI, + META_FIELD_V2_ALPHA, + META_FIELD_V2_ENABLED, + META_QUANTIZER_GPTQMODEL, + META_VALUE_URI, + MIN_VERSION_WITH_V2, +) from ..utils.backend import BACKEND from ..utils.logger import setup_logger -from ..utils.model import (convert_gptq_v2_to_v1_format, copy_py_files, find_modules, get_model_files_size, - get_state_dict_for_save, load_checkpoint_in_model_then_tie_weights, make_quant) +from ..utils.model import ( + convert_gptq_v2_to_v1_format, + copy_py_files, + find_modules, + get_model_files_size, + get_state_dict_for_save, + load_checkpoint_in_model_then_tie_weights, + make_quant, +) from ..utils.structure import alias_all_from_turtle_if_meta from ..utils.torch import torch_empty_cache from ..version import __version__ from ._const import DEFAULT_MAX_SHARD_SIZE + log = setup_logger() PROCESS_LOG_NAME = "process" diff --git a/gptqmodel/nn_modules/hooked_linear.py b/gptqmodel/nn_modules/hooked_linear.py index 38ec21283..a7a3cb0ef 100644 --- a/gptqmodel/nn_modules/hooked_linear.py +++ b/gptqmodel/nn_modules/hooked_linear.py @@ -11,6 +11,7 @@ from ..utils.logger import setup_logger + log = setup_logger() class StopForward(Exception): diff --git a/gptqmodel/nn_modules/qlinear/__init__.py b/gptqmodel/nn_modules/qlinear/__init__.py index 6a7dfa975..69a5dbb8f 100644 --- a/gptqmodel/nn_modules/qlinear/__init__.py +++ b/gptqmodel/nn_modules/qlinear/__init__.py @@ -19,6 +19,7 @@ from ...utils.backend import BACKEND from ...utils.logger import setup_logger + log = setup_logger() class BaseQuantLinear(nn.Module): diff --git a/gptqmodel/nn_modules/qlinear/awq_exllama.py b/gptqmodel/nn_modules/qlinear/awq_exllama.py index ad8aaf6bc..202cf3de1 100644 --- a/gptqmodel/nn_modules/qlinear/awq_exllama.py +++ b/gptqmodel/nn_modules/qlinear/awq_exllama.py @@ -13,6 +13,7 @@ from ...utils.backend import BACKEND from ...utils.logger import setup_logger + log = setup_logger() exl_ext, msg = try_import("gptqmodel_exllama_kernels") diff --git a/gptqmodel/nn_modules/qlinear/awq_exllamav2.py b/gptqmodel/nn_modules/qlinear/awq_exllamav2.py index 71f65d222..5233d7e79 100644 --- a/gptqmodel/nn_modules/qlinear/awq_exllamav2.py +++ b/gptqmodel/nn_modules/qlinear/awq_exllamav2.py @@ -14,6 +14,7 @@ from ...utils.exllamav2 import ScratchSpace from ...utils.logger import setup_logger + log = setup_logger() exlv2_ext, msg = try_import("gptqmodel_exllamav2_kernels") diff --git a/gptqmodel/nn_modules/qlinear/awq_gemm.py b/gptqmodel/nn_modules/qlinear/awq_gemm.py index 032cfae5e..58f7943d4 100644 --- a/gptqmodel/nn_modules/qlinear/awq_gemm.py +++ b/gptqmodel/nn_modules/qlinear/awq_gemm.py @@ -12,6 +12,7 @@ from ...utils.backend import BACKEND from ...utils.logger import setup_logger + log = setup_logger() class AwqGEMMQuantLinear(AWQuantLinear, PackableQuantLinear): diff --git a/gptqmodel/nn_modules/qlinear/awq_gemm_ipex.py b/gptqmodel/nn_modules/qlinear/awq_gemm_ipex.py index d3a5c75ea..0138ffd46 100644 --- a/gptqmodel/nn_modules/qlinear/awq_gemm_ipex.py +++ b/gptqmodel/nn_modules/qlinear/awq_gemm_ipex.py @@ -12,6 +12,7 @@ from ...utils.logger import setup_logger from .awq_gemm import AwqGEMMQuantLinear + log = setup_logger() try: diff --git a/gptqmodel/nn_modules/qlinear/awq_gemv.py b/gptqmodel/nn_modules/qlinear/awq_gemv.py index ebd428df8..46fbbe41f 100644 --- a/gptqmodel/nn_modules/qlinear/awq_gemv.py +++ b/gptqmodel/nn_modules/qlinear/awq_gemv.py @@ -13,6 +13,7 @@ from ...utils.gemv import calculate_zeros_width from ...utils.logger import setup_logger + log = setup_logger() awq_ext, msg = try_import("gptqmodel_awq_kernels") diff --git a/gptqmodel/nn_modules/qlinear/awq_gemv_fast.py b/gptqmodel/nn_modules/qlinear/awq_gemv_fast.py index 354aafab0..3649b6900 100644 --- a/gptqmodel/nn_modules/qlinear/awq_gemv_fast.py +++ b/gptqmodel/nn_modules/qlinear/awq_gemv_fast.py @@ -13,6 +13,7 @@ from ...utils.gemv import calculate_zeros_width from ...utils.logger import setup_logger + log = setup_logger() awq_v2_ext, msg = try_import("gptqmodel_awq_v2_kernels") diff --git a/gptqmodel/nn_modules/qlinear/awq_marlin.py b/gptqmodel/nn_modules/qlinear/awq_marlin.py index e7f3dc5e8..586f98436 100644 --- a/gptqmodel/nn_modules/qlinear/awq_marlin.py +++ b/gptqmodel/nn_modules/qlinear/awq_marlin.py @@ -15,11 +15,19 @@ from ...nn_modules.qlinear import AWQuantLinear from ...utils.backend import BACKEND from ...utils.logger import setup_logger -from ...utils.marlin import (apply_awq_marlin_linear, awq_to_marlin_zero_points, marlin_make_empty_g_idx, - marlin_make_workspace_new, marlin_permute_bias, marlin_permute_scales, replace_parameter) +from ...utils.marlin import ( + apply_awq_marlin_linear, + awq_to_marlin_zero_points, + marlin_make_empty_g_idx, + marlin_make_workspace_new, + marlin_permute_bias, + marlin_permute_scales, + replace_parameter, +) from ...utils.marlin_scalar_type import scalar_types from ...utils.rocm import IS_ROCM + marlin_import_exception = None try: import gptqmodel_marlin_kernels diff --git a/gptqmodel/nn_modules/qlinear/bitblas.py b/gptqmodel/nn_modules/qlinear/bitblas.py index 6092b6ddf..c8307e078 100644 --- a/gptqmodel/nn_modules/qlinear/bitblas.py +++ b/gptqmodel/nn_modules/qlinear/bitblas.py @@ -19,6 +19,7 @@ from ...utils import BACKEND from ...utils.logger import setup_logger + log = setup_logger() BITBLAS_TARGET = None diff --git a/gptqmodel/nn_modules/qlinear/bitblas_target_detector.py b/gptqmodel/nn_modules/qlinear/bitblas_target_detector.py index eb295e5ca..04a8b6c83 100644 --- a/gptqmodel/nn_modules/qlinear/bitblas_target_detector.py +++ b/gptqmodel/nn_modules/qlinear/bitblas_target_detector.py @@ -12,6 +12,7 @@ from ...utils.logger import setup_logger + log = setup_logger() TARGET_MISSING_ERROR = ( diff --git a/gptqmodel/nn_modules/qlinear/exllama.py b/gptqmodel/nn_modules/qlinear/exllama.py index 408278284..5d7f2e49e 100644 --- a/gptqmodel/nn_modules/qlinear/exllama.py +++ b/gptqmodel/nn_modules/qlinear/exllama.py @@ -15,6 +15,7 @@ from ...utils.logger import setup_logger from . import BaseQuantLinear + exllama_import_exception = None try: from gptqmodel_exllama_kernels import make_q4, q4_matmul diff --git a/gptqmodel/nn_modules/qlinear/exllama_eora.py b/gptqmodel/nn_modules/qlinear/exllama_eora.py index 792d1da3d..1e1c12608 100644 --- a/gptqmodel/nn_modules/qlinear/exllama_eora.py +++ b/gptqmodel/nn_modules/qlinear/exllama_eora.py @@ -25,6 +25,7 @@ from ...nn_modules.qlinear import BaseQuantLinear from ...utils.logger import setup_logger + exllama_eora_import_exception = None try: diff --git a/gptqmodel/nn_modules/qlinear/exllamav2.py b/gptqmodel/nn_modules/qlinear/exllamav2.py index 90d2bfed6..60e4ad0c1 100644 --- a/gptqmodel/nn_modules/qlinear/exllamav2.py +++ b/gptqmodel/nn_modules/qlinear/exllamav2.py @@ -16,6 +16,7 @@ from ...utils.exllamav2 import ScratchSpace from ...utils.logger import setup_logger + exllama_v2_import_exception = None try: from gptqmodel_exllamav2_kernels import gemm_half_q_half, make_q_matrix diff --git a/gptqmodel/nn_modules/qlinear/marlin.py b/gptqmodel/nn_modules/qlinear/marlin.py index ced343537..fce5aa02d 100644 --- a/gptqmodel/nn_modules/qlinear/marlin.py +++ b/gptqmodel/nn_modules/qlinear/marlin.py @@ -16,7 +16,6 @@ # Adapted from vllm at https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/gptq_marlin.py -import os from typing import List, Optional, Tuple import numpy as np @@ -27,12 +26,23 @@ from ...nn_modules.qlinear import BaseQuantLinear from ...utils.backend import BACKEND from ...utils.logger import setup_logger -from ...utils.marlin import (_transform_param, apply_gptq_marlin_linear, gptq_marlin_repack, marlin_import_exception, - marlin_is_k_full, marlin_make_empty_g_idx, marlin_make_workspace_new, marlin_permute_bias, - marlin_permute_scales, marlin_repeat_scales_on_all_ranks, marlin_sort_g_idx) +from ...utils.marlin import ( + _transform_param, + apply_gptq_marlin_linear, + gptq_marlin_repack, + marlin_import_exception, + marlin_is_k_full, + marlin_make_empty_g_idx, + marlin_make_workspace_new, + marlin_permute_bias, + marlin_permute_scales, + marlin_repeat_scales_on_all_ranks, + marlin_sort_g_idx, +) from ...utils.marlin_scalar_type import scalar_types from ...utils.rocm import IS_ROCM + log = setup_logger() diff --git a/gptqmodel/nn_modules/qlinear/qqq.py b/gptqmodel/nn_modules/qlinear/qqq.py index a565e9ced..673a1b3d9 100644 --- a/gptqmodel/nn_modules/qlinear/qqq.py +++ b/gptqmodel/nn_modules/qlinear/qqq.py @@ -18,6 +18,7 @@ from ...utils.logger import setup_logger from ...utils.rocm import IS_ROCM + qqq_import_exception = None try: import gptqmodel_qqq_kernels diff --git a/gptqmodel/nn_modules/qlinear/torch.py b/gptqmodel/nn_modules/qlinear/torch.py index 88ab62846..3e9f2af77 100644 --- a/gptqmodel/nn_modules/qlinear/torch.py +++ b/gptqmodel/nn_modules/qlinear/torch.py @@ -15,6 +15,7 @@ from ...utils.logger import setup_logger from ...utils.torch import torch_compile + log = setup_logger() class TorchQuantLinear(PackableQuantLinear): diff --git a/gptqmodel/nn_modules/qlinear/torch_fused.py b/gptqmodel/nn_modules/qlinear/torch_fused.py index db7017bac..b2932fcfd 100644 --- a/gptqmodel/nn_modules/qlinear/torch_fused.py +++ b/gptqmodel/nn_modules/qlinear/torch_fused.py @@ -15,6 +15,7 @@ from ...utils.logger import setup_logger from ...utils.torch import TORCH_HAS_XPU_FUSED_OPS + log = setup_logger() # TODO: not yet working for cuda/cpu fused int4 ops diff --git a/gptqmodel/nn_modules/qlinear/tritonv2.py b/gptqmodel/nn_modules/qlinear/tritonv2.py index 7ee1957fa..0e2dbdb76 100644 --- a/gptqmodel/nn_modules/qlinear/tritonv2.py +++ b/gptqmodel/nn_modules/qlinear/tritonv2.py @@ -15,6 +15,7 @@ from ...utils.python import has_gil_disabled from .torch import TorchQuantLinear + try: import triton import triton.language as tl diff --git a/gptqmodel/nn_modules/triton_utils/custom_autotune.py b/gptqmodel/nn_modules/triton_utils/custom_autotune.py index c4c1b11fa..3642a3660 100644 --- a/gptqmodel/nn_modules/triton_utils/custom_autotune.py +++ b/gptqmodel/nn_modules/triton_utils/custom_autotune.py @@ -10,6 +10,7 @@ import triton + # code based https://github.com/fpgaminer/GPTQ-triton """ Mostly the same as the autotuner in Triton, but with a few changes like using 40 runs instead of 100. diff --git a/gptqmodel/nn_modules/triton_utils/kernels.py b/gptqmodel/nn_modules/triton_utils/kernels.py index 6cc51d3c2..d5e7b4d66 100644 --- a/gptqmodel/nn_modules/triton_utils/kernels.py +++ b/gptqmodel/nn_modules/triton_utils/kernels.py @@ -12,6 +12,7 @@ from ...utils.torch import HAS_XPU from . import custom_autotune + log = setup_logger() # code based https://github.com/fpgaminer/GPTQ-triton diff --git a/gptqmodel/quantization/__init__.py b/gptqmodel/quantization/__init__.py index 945d7d6c3..b2a617dc2 100644 --- a/gptqmodel/quantization/__init__.py +++ b/gptqmodel/quantization/__init__.py @@ -3,8 +3,16 @@ # SPDX-License-Identifier: Apache-2.0 # Contact: qubitium@modelcloud.ai, x.com/qubitium -from .config import (FORMAT, FORMAT_FIELD_CHECKPOINT, FORMAT_FIELD_CODE, METHOD, - QUANT_CONFIG_FILENAME, QUANT_METHOD_FIELD, BaseQuantizeConfig, QuantizeConfig) +from .config import ( + FORMAT, + FORMAT_FIELD_CHECKPOINT, + FORMAT_FIELD_CODE, + METHOD, + QUANT_CONFIG_FILENAME, + QUANT_METHOD_FIELD, + BaseQuantizeConfig, + QuantizeConfig, +) from .gptq import GPTQ from .gptqv2 import GPTQv2 from .quantizer import Quantizer, quantize diff --git a/gptqmodel/quantization/awq/modules/linear/exllama.py b/gptqmodel/quantization/awq/modules/linear/exllama.py index 328f3e07a..4d50d18a8 100644 --- a/gptqmodel/quantization/awq/modules/linear/exllama.py +++ b/gptqmodel/quantization/awq/modules/linear/exllama.py @@ -1,8 +1,10 @@ import torch import torch.nn as nn + from gptqmodel.quantization.awq.utils.module import try_import from gptqmodel.quantization.awq.utils.packing_utils import unpack_reorder_pack + exl_ext, msg = try_import("gptqmodel_exl_kernels") # Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension diff --git a/gptqmodel/quantization/awq/modules/linear/exllamav2.py b/gptqmodel/quantization/awq/modules/linear/exllamav2.py index 21fdf37dc..804d43132 100644 --- a/gptqmodel/quantization/awq/modules/linear/exllamav2.py +++ b/gptqmodel/quantization/awq/modules/linear/exllamav2.py @@ -2,9 +2,11 @@ import torch import torch.nn as nn + from gptqmodel.quantization.awq.utils.module import try_import from gptqmodel.quantization.awq.utils.packing_utils import unpack_reorder_pack + exlv2_ext, msg = try_import("gptqmodel_exlv2_kernels") # Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension diff --git a/gptqmodel/quantization/awq/modules/linear/gemm.py b/gptqmodel/quantization/awq/modules/linear/gemm.py index bd28d5553..d3f12c4f9 100644 --- a/gptqmodel/quantization/awq/modules/linear/gemm.py +++ b/gptqmodel/quantization/awq/modules/linear/gemm.py @@ -2,10 +2,12 @@ import torch import torch.nn as nn +from torch.autograd import Function + from gptqmodel.quantization.awq.utils.module import try_import from gptqmodel.quantization.awq.utils.packing_utils import dequantize_gemm from gptqmodel.quantization.awq.utils.utils import get_best_device -from torch.autograd import Function + # NOTE: We check if awq_ext or triton is available. awq_ext will be preferred if both are installed. diff --git a/gptqmodel/quantization/awq/modules/linear/gemm_ipex.py b/gptqmodel/quantization/awq/modules/linear/gemm_ipex.py index a36a1ce76..144b1ecda 100644 --- a/gptqmodel/quantization/awq/modules/linear/gemm_ipex.py +++ b/gptqmodel/quantization/awq/modules/linear/gemm_ipex.py @@ -1,9 +1,11 @@ import torch import torch.nn as nn + from gptqmodel.quantization.awq.utils.packing_utils import dequantize_gemm from .gemm import WQLinear_GEMM + try: from intel_extension_for_pytorch.llm.quantization import IPEXWeightOnlyQuantizedLinear assert hasattr(IPEXWeightOnlyQuantizedLinear, "from_weight"), "The minimum version for ipex is at least 2.4" diff --git a/gptqmodel/quantization/awq/modules/linear/gemv.py b/gptqmodel/quantization/awq/modules/linear/gemv.py index 68f0dcc82..ccef8b916 100644 --- a/gptqmodel/quantization/awq/modules/linear/gemv.py +++ b/gptqmodel/quantization/awq/modules/linear/gemv.py @@ -1,8 +1,10 @@ import torch import torch.nn as nn + from gptqmodel.quantization.awq.utils.module import try_import + awq_ext, msg = try_import("gptqmodel_awq_kernels") def make_divisible(c, divisor): diff --git a/gptqmodel/quantization/awq/modules/linear/gemv_fast.py b/gptqmodel/quantization/awq/modules/linear/gemv_fast.py index 8a5bca338..470d29584 100644 --- a/gptqmodel/quantization/awq/modules/linear/gemv_fast.py +++ b/gptqmodel/quantization/awq/modules/linear/gemv_fast.py @@ -1,7 +1,9 @@ import torch + from gptqmodel.quantization.awq.utils.module import try_import + awq_v2_ext, msg = try_import("gptqmodel_awq_v2_kernels") def make_divisible(c, divisor): diff --git a/gptqmodel/quantization/awq/modules/linear/marlin.py b/gptqmodel/quantization/awq/modules/linear/marlin.py index 9cc921a13..f2704f602 100644 --- a/gptqmodel/quantization/awq/modules/linear/marlin.py +++ b/gptqmodel/quantization/awq/modules/linear/marlin.py @@ -1,8 +1,10 @@ import numpy as np import torch import torch.nn as nn + from gptqmodel.quantization.awq.utils.module import try_import + marlin_cuda, msg = try_import("marlin_cuda") def _get_perms(): diff --git a/gptqmodel/quantization/awq/quantize/scale.py b/gptqmodel/quantization/awq/quantize/scale.py index b34b85070..a66e8d235 100644 --- a/gptqmodel/quantization/awq/quantize/scale.py +++ b/gptqmodel/quantization/awq/quantize/scale.py @@ -2,16 +2,18 @@ import torch import torch.nn as nn -from gptqmodel.quantization.awq.modules.act import ScaledActivation -from gptqmodel.quantization.awq.utils.module import get_op_by_name, set_op_by_name -from gptqmodel.quantization.awq.utils.utils import get_best_device from transformers.activations import GELUActivation, NewGELUActivation, PytorchGELUTanh from transformers.models.bloom.modeling_bloom import BloomGelu from transformers.models.cohere.modeling_cohere import CohereLayerNorm -from transformers.models.gemma2.modeling_gemma2 import Gemma2RMSNorm from transformers.models.gemma.modeling_gemma import GemmaRMSNorm +from transformers.models.gemma2.modeling_gemma2 import Gemma2RMSNorm from transformers.models.llama.modeling_llama import LlamaRMSNorm +from gptqmodel.quantization.awq.modules.act import ScaledActivation +from gptqmodel.quantization.awq.utils.module import get_op_by_name, set_op_by_name +from gptqmodel.quantization.awq.utils.utils import get_best_device + + allowed_norms = [nn.LayerNorm, LlamaRMSNorm, GemmaRMSNorm, Gemma2RMSNorm, CohereLayerNorm] allowed_act_fns = [ nn.GELU, diff --git a/gptqmodel/quantization/awq/utils/packing_utils.py b/gptqmodel/quantization/awq/utils/packing_utils.py index e2c72896b..e108a781f 100644 --- a/gptqmodel/quantization/awq/utils/packing_utils.py +++ b/gptqmodel/quantization/awq/utils/packing_utils.py @@ -1,5 +1,6 @@ import torch + AWQ_ORDER = [0, 2, 4, 6, 1, 3, 5, 7] AWQ_REVERSE_ORDER = [0, 4, 1, 5, 2, 6, 3, 7] diff --git a/gptqmodel/quantization/awq/utils/utils.py b/gptqmodel/quantization/awq/utils/utils.py index 90d3b6f96..3bb3608ea 100644 --- a/gptqmodel/quantization/awq/utils/utils.py +++ b/gptqmodel/quantization/awq/utils/utils.py @@ -3,6 +3,7 @@ import accelerate import torch + ipex_available = importlib.util.find_spec("intel_extension_for_pytorch") is not None try: triton_available = True diff --git a/gptqmodel/quantization/config.py b/gptqmodel/quantization/config.py index 90c40f4ea..b3a63f629 100644 --- a/gptqmodel/quantization/config.py +++ b/gptqmodel/quantization/config.py @@ -18,6 +18,7 @@ from ..adapter.adapter import Lora, normalize_adapter from ..utils.logger import setup_logger + log = setup_logger() BITS_FIELD_CODE = "bits" diff --git a/gptqmodel/quantization/gptq.py b/gptqmodel/quantization/gptq.py index 61cda7533..84fdb76b1 100644 --- a/gptqmodel/quantization/gptq.py +++ b/gptqmodel/quantization/gptq.py @@ -25,7 +25,7 @@ from ..utils.torch import HAS_CUDA, HAS_XPU, device_next from .gar import compose_final_perm, compute_global_perm, compute_local_perms, invert_perm from .quantizer import HF_OPTIMUM, Quantizer -from ..utils.memory import MEM_LORD + log = setup_logger() @@ -56,7 +56,7 @@ def get_number_of_rows_and_cols(layer: nn.Module): class GPTQ: def __init__(self, module: nn.Module, qcfg: Optional[QuantizeConfig] = None): - # self.lock = threading.Lock() + self.lock = threading.Lock() # self.num_tied_handles = 0 # if qcfg.tied_gptq_handle is not None: @@ -136,11 +136,12 @@ def _clone_module(self, copy=True, device: torch.device = None): return clone.float() def add_batch(self, inp: torch.Tensor, out: torch.Tensor): - self.fwd_counter += 1 + with self.lock: + self.fwd_counter += 1 - # print(f"self.module.target_device = {self.module.target_device}") + # print(f"self.module.target_device = {self.module.target_device}") - self.process_batch(inp) + self.process_batch(inp) def process_batch(self, inp: torch.Tensor): # print(f"inp = {inp}") @@ -523,7 +524,6 @@ def quantize( avg_loss = 999999999 del Losses - MEM_LORD.free(self.H) del self.H group_size = self.qcfg.group_size if self.qcfg.group_size != -1 else self.columns diff --git a/gptqmodel/quantization/qqq.py b/gptqmodel/quantization/qqq.py index 6e6bb8cad..6d56fadd7 100644 --- a/gptqmodel/quantization/qqq.py +++ b/gptqmodel/quantization/qqq.py @@ -6,12 +6,13 @@ import transformers from torch import nn -from .gptq import get_number_of_rows_and_cols from .. import QuantizeConfig from ..looper.named_module import NamedModule from ..quantization.quantizer import HF_OPTIMUM from ..utils import setup_logger from ..utils.torch import device_next +from .gptq import get_number_of_rows_and_cols + DEBUG = False diff --git a/gptqmodel/quantization/quantizer.py b/gptqmodel/quantization/quantizer.py index 7b7462790..7614fcc6d 100644 --- a/gptqmodel/quantization/quantizer.py +++ b/gptqmodel/quantization/quantizer.py @@ -11,6 +11,7 @@ from ..quantization import QuantizeConfig from ..utils.logger import setup_logger + log = setup_logger() HF_OPTIMUM = "hf_optimum" diff --git a/gptqmodel/quantization/rotation/hadamard_utils.py b/gptqmodel/quantization/rotation/hadamard_utils.py index fe7a7e881..0abdbab54 100644 --- a/gptqmodel/quantization/rotation/hadamard_utils.py +++ b/gptqmodel/quantization/rotation/hadamard_utils.py @@ -1,8 +1,10 @@ import math import torch + from gptqmodel.utils.logger import setup_logger + # Adapted from https://github.com/Cornell-RelaxML/quip-sharp/blob/main/lib/utils/matmul_had.py log = setup_logger() diff --git a/gptqmodel/quantization/rotation/rotation.py b/gptqmodel/quantization/rotation/rotation.py index da839f598..10bf6dcbf 100644 --- a/gptqmodel/quantization/rotation/rotation.py +++ b/gptqmodel/quantization/rotation/rotation.py @@ -9,6 +9,7 @@ from ...utils.torch import torch_empty_cache from .hadamard_utils import apply_exact_had_to_linear, random_hadamard_matrix + log = setup_logger() diff --git a/gptqmodel/utils/__init__.py b/gptqmodel/utils/__init__.py index 11973f0b5..5f35a7ba6 100644 --- a/gptqmodel/utils/__init__.py +++ b/gptqmodel/utils/__init__.py @@ -9,6 +9,7 @@ from .threads import AsyncManager, SerialWorker from .vram import get_vram + log = setup_logger() ASYNC_BG_QUEUE = AsyncManager(threads=4) diff --git a/gptqmodel/utils/bitblas.py b/gptqmodel/utils/bitblas.py index a37a67146..55588e627 100644 --- a/gptqmodel/utils/bitblas.py +++ b/gptqmodel/utils/bitblas.py @@ -14,6 +14,7 @@ from .model import load_checkpoint_in_model_then_tie_weights from .torch import torch_empty_cache + log = setup_logger() def prepare_model_for_bitblas_load( diff --git a/gptqmodel/utils/device.py b/gptqmodel/utils/device.py index 953af281b..744ce3d08 100644 --- a/gptqmodel/utils/device.py +++ b/gptqmodel/utils/device.py @@ -4,6 +4,8 @@ # Contact: qubitium@modelcloud.ai, x.com/qubitium from __future__ import annotations +from typing import Optional, Union + import torch from device_smi import Device from torch import nn as nn @@ -21,7 +23,6 @@ def get_cpu_usage_memory(): smi = Device(CPU) return smi.memory_used() / 1024 / 1024 / 1024 #GB - def get_device(obj: torch.Tensor | nn.Module) -> torch.device: if isinstance(obj, torch.Tensor): return obj.device @@ -34,3 +35,95 @@ def get_device(obj: torch.Tensor | nn.Module) -> torch.device: return buffers[0].device else: return CPU + +def get_device_new( + obj: torch.Tensor | nn.Module, + recursive: bool = False, + assert_mode: bool = False, + expected: Optional[Union[str, torch.device]] = None, + check_index: bool = False, +) -> torch.device: + """ + Return a representative device for a Tensor/Module and optionally assert uniformity. + + Args: + obj: Tensor or nn.Module. + recursive: If obj is an nn.Module, traverse submodules (parameters/buffers) + recursively (like module.parameters(recurse=True)). + assert_mode: If True, perform assertions about device placement: + - If `expected` is provided: assert that ALL params/buffers live on a device + whose .type matches `expected`'s .type (and, if check_index, the same index). + - If `expected` is None: assert that ALL params/buffers share a single uniform + device type (and, if check_index, the same index). + expected: A target device or device string (e.g., "cpu", "cuda", "cuda:1"). + check_index: If True, also require the same device index (e.g., all on cuda:0). + + Returns: + torch.device: A representative device. Priority order: + - Tensor: its own device + - Module: the first parameter device, else first buffer device, else CPU + """ + # --- Helper to normalize an "expected" device to (type, index) --- + def _normalize_expected(exp: Optional[Union[str, torch.device]]): + if exp is None: + return None, None + dev = torch.device(exp) if isinstance(exp, str) else exp + return dev.type, dev.index + + # --- Collect devices present on the object --- + if isinstance(obj, torch.Tensor): + devices = [obj.device] + elif isinstance(obj, nn.Module): + # Pull parameters/buffers; recurse if requested + params = list(obj.parameters(recurse=recursive)) + buffs = list(obj.buffers(recurse=recursive)) + devices = [] + if params: + devices.extend(p.device for p in params) + if buffs: + devices.extend(b.device for b in buffs) + if not devices: + devices = [CPU] + else: + raise TypeError(f"get_device() expects Tensor or nn.Module, got {type(obj)}") + + # Representative device (keep legacy behavior) + rep_device = devices[0] + + # --- Assertions (if requested) --- + if assert_mode: + exp_type, exp_index = _normalize_expected(expected) + + def _key(d: torch.device): + return (d.type, d.index if check_index else None) + + if exp_type is not None: + # Check against expected device TYPE (and optionally INDEX) + mismatches = [ + d for d in devices + if d.type != exp_type or (check_index and d.index != exp_index) + ] + if mismatches: + # Build a concise error message with a few examples + sample = ", ".join({f"{d.type}:{d.index}" for d in mismatches[:5]}) + target = f"{exp_type}" + (f":{exp_index}" if check_index else "") + raise AssertionError( + f"Device assertion failed: expected all tensors on {target}, " + f"but found mismatches (e.g., {sample}). Total tensors checked: {len(devices)}." + ) + else: + # Ensure uniformity across all devices (by type, and optionally index) + unique = { _key(d) for d in devices } + if len(unique) > 1: + # Summarize what we actually found + summary = ", ".join(sorted(f"{t}:{i}" for (t, i) in unique)) + detail = ", ".join({f"{d.type}:{d.index}" for d in devices[:8]}) + msg = ( + "Device assertion failed: tensors are on multiple devices. " + f"Found {{{summary}}}. Examples: {detail}." + ) + if not check_index: + msg += " (Tip: set check_index=True to also require same device index.)" + raise AssertionError(msg) + + return rep_device diff --git a/gptqmodel/utils/eval.py b/gptqmodel/utils/eval.py index 2fa44f405..2a6fa3325 100644 --- a/gptqmodel/utils/eval.py +++ b/gptqmodel/utils/eval.py @@ -7,6 +7,7 @@ import os from enum import Enum + try: from enum import EnumType except ImportError: diff --git a/gptqmodel/utils/hf.py b/gptqmodel/utils/hf.py index 04eb1c3c4..de7285c01 100644 --- a/gptqmodel/utils/hf.py +++ b/gptqmodel/utils/hf.py @@ -6,6 +6,7 @@ from ..utils.logger import setup_logger + log = setup_logger() # TODO FIXME! Pre-quantized use AutoModelForCausalLM.from_pretrained() but post-quantized use AutoModelForCausalLM.from_config() diff --git a/gptqmodel/utils/importer.py b/gptqmodel/utils/importer.py index 5bd1e5a57..1b278c06b 100644 --- a/gptqmodel/utils/importer.py +++ b/gptqmodel/utils/importer.py @@ -8,6 +8,7 @@ from typing import Dict, List, Optional, Type, Union import torch + from gptqmodel.adapter.adapter import Adapter from ..models._const import DEVICE, normalize_device @@ -33,6 +34,7 @@ from .rocm import IS_ROCM from .torch import HAS_CUDA, HAS_MPS, HAS_XPU + message_logged = False log = setup_logger() diff --git a/gptqmodel/utils/marlin.py b/gptqmodel/utils/marlin.py index 2f7b82707..d069c8b87 100644 --- a/gptqmodel/utils/marlin.py +++ b/gptqmodel/utils/marlin.py @@ -11,6 +11,7 @@ from .marlin_scalar_type import ScalarType from .rocm import IS_ROCM + log = setup_logger() marlin_import_exception = None diff --git a/gptqmodel/utils/marlin_scalar_type.py b/gptqmodel/utils/marlin_scalar_type.py index 055f28914..8dda14fd2 100644 --- a/gptqmodel/utils/marlin_scalar_type.py +++ b/gptqmodel/utils/marlin_scalar_type.py @@ -7,6 +7,7 @@ from enum import Enum from typing import Optional, Union + _SCALAR_TYPES_ID_MAP = {} diff --git a/gptqmodel/utils/memory.py b/gptqmodel/utils/memory.py index 67ad93ec3..185704d8d 100644 --- a/gptqmodel/utils/memory.py +++ b/gptqmodel/utils/memory.py @@ -2,13 +2,15 @@ # SPDX-License-Identifier: Apache-2.0 from __future__ import annotations + import os import threading -from typing import Dict, Iterable, Tuple, Generator +from typing import Dict, Generator, Iterable, Tuple import torch import torch.nn as nn + # ---------- ANSI COLORS ---------- RESET = "\033[0m" RED = "\033[91m" @@ -325,5 +327,3 @@ def format_bytes(n: int) -> str: return f"{x:.2f} {u}" x /= 1024.0 -# default to auto gc interval for every 8GB of freed memory -MEM_LORD = MemTracker(auto_gc_bytes="auto") \ No newline at end of file diff --git a/gptqmodel/utils/mlx.py b/gptqmodel/utils/mlx.py index 45b6679d3..3a056ff4e 100644 --- a/gptqmodel/utils/mlx.py +++ b/gptqmodel/utils/mlx.py @@ -9,6 +9,7 @@ from .log import setup_logger from .torch import torch_empty_cache + try: import mlx.core as mx from mlx_lm import generate diff --git a/gptqmodel/utils/mmlupro.py b/gptqmodel/utils/mmlupro.py index 7f939b754..7bf4044a8 100644 --- a/gptqmodel/utils/mmlupro.py +++ b/gptqmodel/utils/mmlupro.py @@ -12,6 +12,7 @@ from ..utils.logger import setup_logger + choices = ["A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P"] max_model_length = 4096 max_new_tokens = 2048 diff --git a/gptqmodel/utils/model.py b/gptqmodel/utils/model.py index 1fe3c3747..a2fb302a8 100644 --- a/gptqmodel/utils/model.py +++ b/gptqmodel/utils/model.py @@ -23,8 +23,6 @@ import torch import torch.nn as nn import transformers -from gptqmodel.nn_modules.qlinear.exllama_eora import ExllamaEoraQuantLinear -from gptqmodel.nn_modules.qlinear.marlin import MarlinQuantLinear from huggingface_hub import HfApi, hf_hub_download from packaging import version from torch.nn.modules.conv import _ConvNd @@ -32,10 +30,18 @@ from transformers.pytorch_utils import id_tensor_storage from transformers.utils.hub import cached_file +from gptqmodel.nn_modules.qlinear.exllama_eora import ExllamaEoraQuantLinear +from gptqmodel.nn_modules.qlinear.marlin import MarlinQuantLinear + from ..adapter.adapter import Adapter from ..looper.named_module import NamedModule -from ..models._const import (CPU, DEVICE, EXLLAMA_DEFAULT_MAX_INPUT_LENGTH, - EXPERT_INDEX_PLACEHOLDER, SUPPORTS_MODULE_TYPES) +from ..models._const import ( + CPU, + DEVICE, + EXLLAMA_DEFAULT_MAX_INPUT_LENGTH, + EXPERT_INDEX_PLACEHOLDER, + SUPPORTS_MODULE_TYPES, +) from ..nn_modules.qlinear import BaseQuantLinear from ..nn_modules.qlinear.awq_exllamav2 import AwqExllamaV2QuantLinear from ..nn_modules.qlinear.exllama import ExllamaQuantLinear @@ -48,7 +54,7 @@ from .importer import select_quant_linear from .logger import setup_logger from .torch import torch_empty_cache, torch_new_stream_ctx -from ..utils.memory import MEM_LORD + log = setup_logger() @@ -80,7 +86,6 @@ def recurse_setattr(module, name, value): def move_to(obj: torch.Tensor | nn.Module, device: torch.device, dtype: torch.dtype = None, stream: bool = False): if get_device(obj) != device: - MEM_LORD.free(obj) if stream: # we cannot support changing dtype and stream at the same time assert dtype is None, f"streaming does not support changing dtype: actual = `{dtype}" diff --git a/gptqmodel/utils/offload.py b/gptqmodel/utils/offload.py index c52eaebda..85a1f045f 100644 --- a/gptqmodel/utils/offload.py +++ b/gptqmodel/utils/offload.py @@ -10,17 +10,18 @@ from typing import Iterable, List, Optional, Set, Tuple import torch + # move base_module tensors to disk from accelerate import disk_offload from accelerate.hooks import remove_hook_from_module, remove_hook_from_submodules from accelerate.utils import align_module_device, has_offloaded_params from torch import nn -from .memory import MEM_LORD from ..looper.named_module import NamedModule from .device import get_device from .torch import CPU, HAS_CUDA, META + _lock = threading.Lock() def get_module_fullname(model: torch.nn.Module, module: torch.nn.Module) -> str: @@ -58,30 +59,30 @@ def offload_to_disk(module: List[str] | nn.Module, model: nn.Module, disk_path: assert module is not None assert model is not None - with _lock: - if isinstance(module, List): - for name in module: - m = get_submodule(model, name) - # unwrap named module - if isinstance(m, NamedModule): - # print(f"offloading named module: {module.full_name}") - m = m.module - - full_name = get_module_fullname(model=model, module=m) - _offload_disk(module=m, name=full_name, disk_path=disk_path) - else: + #with _lock: + if isinstance(module, List): + for name in module: + m = get_submodule(model, name) # unwrap named module - if isinstance(module, NamedModule): + if isinstance(m, NamedModule): # print(f"offloading named module: {module.full_name}") - module = module.module + m = m.module + + full_name = get_module_fullname(model=model, module=m) + _offload_disk(module=m, name=full_name, disk_path=disk_path) + else: + # unwrap named module + if isinstance(module, NamedModule): + # print(f"offloading named module: {module.full_name}") + module = module.module - full_name = get_module_fullname(model=model, module=module) + full_name = get_module_fullname(model=model, module=module) - _offload_disk(module=module, name=full_name, disk_path=disk_path) + _offload_disk(module=module, name=full_name, disk_path=disk_path) - if hasattr(module, "config") and getattr(module.config, - "tie_word_embeddings", False): - module.tie_weights() # makes lm_head.weight point to embed_tokens.weight again after offload + if hasattr(module, "config") and getattr(module.config, + "tie_word_embeddings", False): + module.tie_weights() # makes lm_head.weight point to embed_tokens.weight again after offload # print("offload_disk: list item tree") # print_module_tree(module) @@ -91,7 +92,6 @@ def _offload_disk(module: nn.Module, name: str, disk_path: str = "."): # print(f"[skip] '{name}' is on meta; leaving as-is") return - MEM_LORD.free(module) m_device = get_device(module) if m_device.type == "cuda": torch.cuda.set_device(m_device) @@ -262,43 +262,43 @@ def undo_offload_to_disk( Returns: The same `module`, now “de-offloaded”. """ - with _lock: - # Track candidate offload dirs if user asks to delete them later. - offload_dirs: Set[str] = set() - - # 1) Materialize all offloaded leaves as real tensors on the target device/dtype. - with torch.inference_mode(): - for sub in module.modules(): - if not has_offloaded_params(sub): - continue - - # Discover offload folders opportunistically (optional cleanup) - offload_dirs |= _possible_offload_dirs_from_hook(sub) - - # Prefer a fast path reading directly from the weights_map if exposed by this Accelerate version. - handled = _restore_leaves_from_weights_map(sub, device=device, dtype=dtype) - if handled: - continue - - # Fallback path: ask Accelerate to align this submodule to the execution device, - # then clone+rebind leaves so they become regular, hook-free tensors. - with _maybe_align(sub, device=device): - for name, tensor, is_param in list(_iter_leaf_tensors(sub, include_buffers=include_buffers)): - is_meta = (getattr(tensor, "is_meta", False) or tensor.device is META) - if not is_meta: - # Still clone if the hook attached a tensor view that would be re-offloaded later. - # Safer to always break links to hook-managed storages. - src = tensor - else: - # After align, meta leaves should be backed by real memory on `device`. - src = tensor - - if is_param: - new_p = _clone_into_parameter(src, device=device, dtype=dtype, requires_grad=tensor.requires_grad) - setattr(sub, name, new_p) - else: - new_b = _clone_into_buffer(src, device=device, dtype=dtype) - setattr(sub, name, new_b) + #with _lock: + # Track candidate offload dirs if user asks to delete them later. + offload_dirs: Set[str] = set() + + # 1) Materialize all offloaded leaves as real tensors on the target device/dtype. + with torch.inference_mode(): + for sub in module.modules(): + if not has_offloaded_params(sub): + continue + + # Discover offload folders opportunistically (optional cleanup) + offload_dirs |= _possible_offload_dirs_from_hook(sub) + + # Prefer a fast path reading directly from the weights_map if exposed by this Accelerate version. + handled = _restore_leaves_from_weights_map(sub, device=device, dtype=dtype) + if handled: + continue + + # Fallback path: ask Accelerate to align this submodule to the execution device, + # then clone+rebind leaves so they become regular, hook-free tensors. + with _maybe_align(sub, device=device): + for name, tensor, is_param in list(_iter_leaf_tensors(sub, include_buffers=include_buffers)): + is_meta = (getattr(tensor, "is_meta", False) or tensor.device is META) + if not is_meta: + # Still clone if the hook attached a tensor view that would be re-offloaded later. + # Safer to always break links to hook-managed storages. + src = tensor + else: + # After align, meta leaves should be backed by real memory on `device`. + src = tensor + + if is_param: + new_p = _clone_into_parameter(src, device=device, dtype=dtype, requires_grad=tensor.requires_grad) + setattr(sub, name, new_p) + else: + new_b = _clone_into_buffer(src, device=device, dtype=dtype) + setattr(sub, name, new_b) # 2) Remove all Accelerate hooks so future forwards won't offload again. remove_hook_from_submodules(module) # public API diff --git a/gptqmodel/utils/openai_server.py b/gptqmodel/utils/openai_server.py index e84653ecb..f2a52e4ee 100644 --- a/gptqmodel/utils/openai_server.py +++ b/gptqmodel/utils/openai_server.py @@ -9,6 +9,7 @@ import torch + try: import uvicorn from fastapi import FastAPI, HTTPException diff --git a/gptqmodel/utils/perplexity.py b/gptqmodel/utils/perplexity.py index 51b31456a..806ce66ac 100644 --- a/gptqmodel/utils/perplexity.py +++ b/gptqmodel/utils/perplexity.py @@ -10,6 +10,7 @@ from datasets import load_dataset, load_from_disk from logbar import LogBar + logger = LogBar.shared() class Perplexity: diff --git a/gptqmodel/utils/python.py b/gptqmodel/utils/python.py index 2dfac7e22..3f3f5cd0c 100644 --- a/gptqmodel/utils/python.py +++ b/gptqmodel/utils/python.py @@ -1,9 +1,11 @@ import platform import sys -from gptqmodel.utils.logger import setup_logger from packaging.version import Version +from gptqmodel.utils.logger import setup_logger + + log = setup_logger() # Check if GIL (global interpreter lock) is controllable in this Python build. diff --git a/gptqmodel/utils/rocm.py b/gptqmodel/utils/rocm.py index abf4f6cb8..c8623fdd6 100644 --- a/gptqmodel/utils/rocm.py +++ b/gptqmodel/utils/rocm.py @@ -5,4 +5,5 @@ import torch + IS_ROCM = torch.version.hip is not None diff --git a/gptqmodel/utils/sglang.py b/gptqmodel/utils/sglang.py index 764006473..c48074844 100644 --- a/gptqmodel/utils/sglang.py +++ b/gptqmodel/utils/sglang.py @@ -8,6 +8,7 @@ import torch from transformers import AutoConfig + try: import sglang as sgl SGLANG_AVAILABLE = True diff --git a/gptqmodel/utils/structure.py b/gptqmodel/utils/structure.py index 44b57f1fa..ecfd943f9 100644 --- a/gptqmodel/utils/structure.py +++ b/gptqmodel/utils/structure.py @@ -32,6 +32,7 @@ from ..utils.logger import setup_logger + # ========================= # ANSI color helpers # ========================= @@ -610,7 +611,7 @@ def alias_all_from_turtle_if_meta( """ if turtle_model is None: return 0 - + turtle_map = dict(turtle_model.named_modules()) swapped = 0 diff --git a/gptqmodel/utils/threadx.py b/gptqmodel/utils/threadx.py new file mode 100644 index 000000000..a827ad48c --- /dev/null +++ b/gptqmodel/utils/threadx.py @@ -0,0 +1,1010 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +from __future__ import annotations + +import contextlib +import queue +import threading +import time +from concurrent.futures import Future +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union + +import torch + +from ..utils.logger import setup_logger + + +log = setup_logger() + + +DeviceLike = Union[str, int, torch.device] + + +# --------------------------- Backend availability helpers --------------------------- + +def _mps_available() -> bool: + return ( + hasattr(torch, "backends") + and hasattr(torch.backends, "mps") + and torch.backends.mps.is_available() + ) + + +# --- HARD COPIES of original empty_cache callables (never auto-switched) --- +TORCH_CUDA_EMPTY_CACHE: Optional[Callable[[], None]] = None +TORCH_XPU_EMPTY_CACHE: Optional[Callable[[], None]] = None +TORCH_MPS_EMPTY_CACHE: Optional[Callable[[], None]] = None + +try: + TORCH_CUDA_EMPTY_CACHE = getattr(torch.cuda, "empty_cache", None) if hasattr(torch, "cuda") else None + if TORCH_CUDA_EMPTY_CACHE is not None and not callable(TORCH_CUDA_EMPTY_CACHE): + TORCH_CUDA_EMPTY_CACHE = None +except Exception: + TORCH_CUDA_EMPTY_CACHE = None + +try: + TORCH_XPU_EMPTY_CACHE = getattr(torch.xpu, "empty_cache", None) if hasattr(torch, "xpu") else None + if TORCH_XPU_EMPTY_CACHE is not None and not callable(TORCH_XPU_EMPTY_CACHE): + TORCH_XPU_EMPTY_CACHE = None +except Exception: + TORCH_XPU_EMPTY_CACHE = None + +try: + TORCH_MPS_EMPTY_CACHE = getattr(torch.mps, "empty_cache", None) if hasattr(torch, "mps") else None + if TORCH_MPS_EMPTY_CACHE is not None and not callable(TORCH_MPS_EMPTY_CACHE): + TORCH_MPS_EMPTY_CACHE = None +except Exception: + TORCH_MPS_EMPTY_CACHE = None +# ------------------------------------------------------------------------------- + + +def _coerce_device(d: DeviceLike) -> torch.device: + if isinstance(d, torch.device): + return d + if isinstance(d, int): + if torch.cuda.is_available(): + return torch.device("cuda", d) + if hasattr(torch, "xpu") and torch.xpu.is_available(): + return torch.device("xpu", d) + if _mps_available(): + return torch.device("mps") + return torch.device("cpu") + return torch.device(d) + + +@contextlib.contextmanager +def _device_ctx(dev: torch.device): + """Set the caller thread’s current device for CUDA/XPU so library handles match.""" + if dev.type == "cuda": + with torch.cuda.device(dev.index): + yield + elif dev.type == "xpu" and hasattr(torch, "xpu"): + with torch.xpu.device(dev.index): + yield + else: + yield + + +def _activate_thread_device(dev: torch.device): + """Pin the worker thread to the device.""" + if dev.type == "cuda": + torch.cuda.set_device(dev.index) + elif dev.type == "xpu" and hasattr(torch, "xpu"): + torch.xpu.set_device(dev.index) + # mps/cpu: nothing to pin + + +# --------------------------- Read-Write Lock (writer-preference) --------------------------- + +class _RWLock: + """ + Reader-writer lock with writer preference. + + - Multiple readers may hold the lock simultaneously. + - A single writer holds exclusivity. + - When a writer is waiting, new readers will block. + - Writer is re-entrant for its owning thread. + """ + def __init__(self): + self._cond = threading.Condition() + self._readers = 0 + self._writer: Optional[int] = None # thread id that owns write + self._writer_depth = 0 + self._writers_waiting = 0 + + # --- Write (exclusive) --- + def acquire_write(self): + me = threading.get_ident() + with self._cond: + if self._writer == me: # re-entrant + self._writer_depth += 1 + return + self._writers_waiting += 1 + try: + while self._writer is not None or self._readers > 0: + self._cond.wait() + self._writer = me + self._writer_depth = 1 + finally: + self._writers_waiting -= 1 + + def release_write(self): + me = threading.get_ident() + with self._cond: + if self._writer != me: + raise RuntimeError("release_write called by non-owner") + self._writer_depth -= 1 + if self._writer_depth == 0: + self._writer = None + self._cond.notify_all() + + @contextlib.contextmanager + def writer(self): + self.acquire_write() + try: + yield + finally: + self.release_write() + + # --- Read (shared) --- + def acquire_read(self): + me = threading.get_ident() + with self._cond: + # writer can re-enter as reader + if self._writer == me: + self._readers += 1 + return + while self._writer is not None or self._writers_waiting > 0: + self._cond.wait() + self._readers += 1 + + def release_read(self): + with self._cond: + if self._readers <= 0: + raise RuntimeError("release_read without acquire_read") + self._readers -= 1 + if self._readers == 0: + self._cond.notify_all() + + @contextlib.contextmanager + def reader(self): + self.acquire_read() + try: + yield + finally: + self.release_read() + + +class _LockGroup(contextlib.AbstractContextManager): + """Acquire multiple device write locks in deterministic order to avoid deadlocks.""" + def __init__(self, ordered_pairs: List[tuple[str, _RWLock]]): + self._pairs = ordered_pairs + + def __enter__(self): + for _, lk in self._pairs: + lk.acquire_write() + return self + + def __exit__(self, exc_type, exc, tb): + for _, lk in reversed(self._pairs): + lk.release_write() + return False + + +class _ReadLockGroup(contextlib.AbstractContextManager): + """Acquire multiple device read locks in deterministic order.""" + def __init__(self, ordered_pairs: List[tuple[str, _RWLock]]): + self._pairs = ordered_pairs + + def __enter__(self): + for _, lk in self._pairs: + lk.acquire_read() + return self + + def __exit__(self, exc_type, exc, tb): + for _, lk in reversed(self._pairs): + lk.release_read() + return False + + +class _WaitAndLock(contextlib.AbstractContextManager): + """ + Context manager returned by pool.wait(scope, lock=True). + On enter: acquires writer locks over the scope in canonical order, + which inherently waits for in-flight readers (tasks) to drain. + On exit: releases locks. + """ + def __init__(self, pairs: List[tuple[str, _RWLock]]): + self._group = _LockGroup(pairs) + + def __enter__(self): + return self._group.__enter__() + + def __exit__(self, exc_type, exc, tb): + return self._group.__exit__(exc_type, exc, tb) + + +# --------------------------- Worker Thread --------------------------- + +class _DeviceWorker: + """ + Single worker thread bound to one device. + Queue entries: (is_task: bool, fn, args, kwargs, future) + Supports configurable lifecycle: after N tasks, stop accepting new work, + drain its queue, and exit; the pool will spawn a replacement. + """ + def __init__( + self, + device: torch.device, + rwlock: _RWLock, + on_task_finished: Callable[[str], None], + on_retire_request: Callable[[str, _DeviceWorker], None], + on_worker_exit: Callable[[str, _DeviceWorker], None], + name: Optional[str] = None, + inference_mode: bool = False, + lifecycle_calls: int = 50, + ): + self.device = device + self.rwlock = rwlock + self._on_task_finished = on_task_finished + self._on_retire_request = on_retire_request + self._on_worker_exit = on_worker_exit + self._lifecycle_limit = max(0, int(lifecycle_calls)) # 0 disables rotation + self._tasks_since_spawn = 0 + + self.key = f"{device.type}:{device.index}" if device.index is not None else device.type + self.name = name or f"DPWorker-{self.key}" + self._q: "queue.Queue[Tuple[bool, Callable[..., Any], tuple, dict, Future]]" = queue.Queue() + self._stop = threading.Event() + self._retire_requested = False + self._accepting = True + + self._inference_mode = inference_mode + self._thread = threading.Thread(target=self._run, name=self.name, daemon=True) + self._thread.start() + + # --- lifecycle / accept state --- + def is_accepting(self) -> bool: + return self._accepting and not self._stop.is_set() + + def request_stop(self): + self._stop.set() + self._q.put((False, lambda: None, (), {}, Future())) + + # --- public API for pool --- + def submit(self, fn: Callable[..., Any], /, *args, **kwargs) -> Future: + fut = Future() + self._q.put((True, fn, args, kwargs, fut)) + return fut + + def stop(self): + self.request_stop() + + def join(self): + self._thread.join() + + # --- internal main loop --- + def _run(self): + _activate_thread_device(self.device) + maybe_inference = torch.inference_mode() if self._inference_mode else contextlib.nullcontext() + with maybe_inference: + while not self._stop.is_set(): + # If we're retiring and nothing is queued, exit gracefully + if self._retire_requested and self._q.empty(): + break + try: + is_task, fn, args, kwargs, fut = self._q.get(timeout=0.05) + except queue.Empty: + continue + try: + if not is_task: + break # sentinel -> exit + # Tasks take a read lock so GC's writer lock can't interleave + with self.rwlock.reader(): + stream = kwargs.pop("_cuda_stream", None) + with _device_ctx(self.device): + if stream is not None and self.device.type == "cuda": + with torch.cuda.stream(stream): + result = fn(*args, **kwargs) + else: + result = fn(*args, **kwargs) + if not fut.cancelled(): + fut.set_result(result) + except BaseException as exc: + if not fut.cancelled(): + fut.set_exception(exc) + finally: + if is_task: + self._tasks_since_spawn += 1 + self._on_task_finished(self.key) + # Lifecycle check: once we hit the limit, mark retiring (stop accepting) + if self._lifecycle_limit > 0 and self._tasks_since_spawn >= self._lifecycle_limit: + if not self._retire_requested: + self._retire_requested = True + self._accepting = False + # Notify pool to spawn a replacement now + self._on_retire_request(self.key, self) + self._q.task_done() + + # Thread is exiting; notify pool for cleanup + try: + self._on_worker_exit(self.key, self) + finally: + pass + + +# --------------------------- Public Pool --------------------------- + +class DeviceThreadPool: + """ + Multi-device thread pool with: + - Eager discovery/creation of workers and locks for CUDA/XPU/MPS/CPU. + - Configurable worker counts per device (default 1). + - Correct per-thread device context. + - submit()/do() for async/sync, with optional `_cuda_stream` (CUDA only). + - Per-device RWLocks + global lock and family/all read-locks. + - wait(scope, lock=False/True) to drain tasks (optionally with exclusive locks). + - Per-device/global completed counters and in-flight counters. + - Janitor: triggers empty-cache after N completions on accelerator devices, under a global lock. + - GC diagnostics helpers. + - Worker lifecycle rotation: after N tasks (default 50), workers retire and are replaced. + """ + + def __init__( + self, + devices: Optional[Iterable[DeviceLike]] = None, + *, + include_cuda: bool = True, + include_xpu: bool = True, + include_mps: bool = True, + include_cpu: bool = True, + inference_mode: bool = False, + empty_cache_every_n: int = 50, # <=0 disables janitor + workers: Optional[Dict[str, int]] = None, # e.g. {'cpu':4, 'cuda:per':1, 'cuda:0':3} + gc_debounce_seconds: float = 0.02, # absorb bursty triggers before GC + worker_lifecycle_calls: int = 50, # <=0 disables lifecycle rotation + ): + """ + Args: + devices: explicit list of devices. If None, auto-discover per include_* flags. + workers: dict mapping worker-count policy: + - 'cpu': N -> N workers total for CPU + - 'mps': N -> N workers for MPS (single device) + - 'cuda:per': N -> N workers per CUDA index + - 'xpu:per': N -> N workers per XPU index + - 'cuda:': N -> override for specific CUDA index + - 'xpu:': N -> override for specific XPU index + Unspecified devices default to 1 worker each. + gc_debounce_seconds: short wait to coalesce multiple triggers. + worker_lifecycle_calls: number of tasks a worker handles before retiring (0=disabled). + """ + if devices is None: + discovered: List[torch.device] = [] + if include_cuda and torch.cuda.is_available(): + for i in range(torch.cuda.device_count()): + discovered.append(torch.device("cuda", i)) + if include_xpu and hasattr(torch, "xpu") and torch.xpu.is_available(): + for i in range(torch.xpu.device_count()): + discovered.append(torch.device("xpu", i)) + if include_mps and _mps_available(): + discovered.append(torch.device("mps")) + if include_cpu: + discovered.append(torch.device("cpu")) + devices = discovered + + self._locks: Dict[str, _RWLock] = {} + self._devices_by_key: Dict[str, torch.device] = {} + + # Worker groups: key -> List[_DeviceWorker] + self._worker_groups: Dict[str, List[_DeviceWorker]] = {} + self._dispatch_rr: Dict[str, int] = {} # round-robin index per key + self._dispatch_lock = threading.Lock() + + # Stats / GC / inflight control + self._stats_lock = threading.Lock() + self._per_device_done: Dict[str, int] = {} + self._total_done: int = 0 + + self._empty_cache_every_n = int(empty_cache_every_n) + self._gc_event = threading.Event() + self._stop_event = threading.Event() + self._janitor: Optional[threading.Thread] = None + + # in-flight (scheduled but not finished) counters + per-device CVs + self._inflight: Dict[str, int] = {} + self._inflight_cv: Dict[str, threading.Condition] = {} + + # GC dedupe/coalesce + self._gc_debounce_s = float(gc_debounce_seconds) + # per-device watermark of "done" as of last GC that actually ran + self._last_gc_done_per_device: Dict[str, int] = {} + + # Worker lifecycle rotation + self._worker_lifecycle_calls = int(worker_lifecycle_calls) + # Store inference mode for worker spawns + self._inference_mode = bool(inference_mode) + + workers = workers or {} + + # Build locks, inflight structs, and workers eagerly + for d in devices: + dev = _coerce_device(d) + if dev.type not in ("cuda", "xpu", "mps", "cpu"): + continue + key = self._key(dev) + if key in self._devices_by_key: + continue + + self._devices_by_key[key] = dev + self._locks[key] = _RWLock() + self._per_device_done[key] = 0 + self._inflight[key] = 0 + self._inflight_cv[key] = threading.Condition() + self._last_gc_done_per_device[key] = 0 + + n_workers = self._resolve_workers_for_device(dev, workers) + group: List[_DeviceWorker] = [] + for wid in range(int(max(1, n_workers))): + worker = self._spawn_worker(dev, name=f"DPWorker-{key}#{wid}") + group.append(worker) + self._worker_groups[key] = group + self._dispatch_rr[key] = 0 + + # Canonical lock order + self._ordered_keys = sorted(self._locks.keys()) + + # GC diagnostics counters + self._gc_passes = 0 + self._last_gc_ts: Optional[float] = None + + # Start janitor if enabled and accelerators exist + if self._empty_cache_every_n > 0 and any( + self._devices_by_key[k].type in ("cuda", "xpu", "mps") for k in self._ordered_keys + ): + self._janitor = threading.Thread( + target=self._janitor_loop, name="DP-Janitor", daemon=True + ) + self._janitor.start() + + # --------------- Worker management (spawn/retire/cleanup) --------------- + + def _spawn_worker(self, dev: torch.device, name: Optional[str] = None) -> _DeviceWorker: + key = self._key(dev) + return _DeviceWorker( + device=dev, + rwlock=self._locks[key], + on_task_finished=self._on_task_finished, + on_retire_request=self._on_worker_retire_request, + on_worker_exit=self._on_worker_exit, + name=name, + inference_mode=self._inference_mode, + lifecycle_calls=self._worker_lifecycle_calls, + ) + + def _on_worker_retire_request(self, key: str, worker: _DeviceWorker) -> None: + """ + A worker hit its lifecycle limit. Mark it non-accepting (it already is), + and immediately spawn a replacement to maintain capacity. + """ + dev = self._devices_by_key[key] + with self._dispatch_lock: + group = self._worker_groups.get(key, []) + if worker in group: + replacement = self._spawn_worker(dev, name=f"{worker.name}.r{int(time.time()*1000)}") + group.append(replacement) + self._worker_groups[key] = group + + def _on_worker_exit(self, key: str, worker: _DeviceWorker) -> None: + """Cleanup finished workers from the group.""" + with self._dispatch_lock: + group = self._worker_groups.get(key, []) + if worker in group: + group.remove(worker) + self._worker_groups[key] = group + if group: + self._dispatch_rr[key] %= len(group) + else: + self._dispatch_rr[key] = 0 + + # --------------- Public Work API --------------- + + def submit( + self, + device: DeviceLike, + fn: Callable[..., Any], + /, + *args, + _cuda_stream: Optional[torch.cuda.Stream] = None, + **kwargs, + ) -> Future: + """ + Asynchronously schedule work on the given device; returns a Future. + Optional (CUDA): pass `_cuda_stream=` to launch into a specific stream. + """ + dev = _coerce_device(device) + key = self._key(dev) + worker = self._pick_worker(key) + if _cuda_stream is not None and dev.type != "cuda": + raise ValueError("_cuda_stream is only valid for CUDA devices") + + # mark in-flight before enqueue to avoid races with wait() + self._mark_scheduled(key) + try: + return worker.submit(fn, *args, _cuda_stream=_cuda_stream, **kwargs) + except BaseException: + # roll back inflight if enqueue fails (rare) + self._mark_finished(key) + raise + + def do( + self, + device: DeviceLike, + fn: Callable[..., Any], + /, + *args, + _cuda_stream: Optional[torch.cuda.Stream] = None, + **kwargs, + ) -> Any: + """Synchronously schedule work and block for the result.""" + fut = self.submit(device, fn, *args, _cuda_stream=_cuda_stream, **kwargs) + return fut.result() + + def shutdown(self, wait: bool = True): + """Gracefully stop all workers and janitor.""" + self._stop_event.set() + self._gc_event.set() # wake janitor + if self._janitor is not None and wait: + self._janitor.join() + + for group in self._worker_groups.values(): + for w in group: + w.stop() + if wait: + for group in self._worker_groups.values(): + for w in group: + w.join() + + # --------------- Public Lock API --------------- + + def device_lock(self, device: DeviceLike): + """Exclusive lock for a single device (blocks all its workers).""" + dev = _coerce_device(device) + key = self._key(dev) + lk = self._locks.get(key) + if lk is None: + raise ValueError(f"Unknown device for pool: {dev}") + return lk.writer() + + def read_lock(self, device: DeviceLike | str): + """ + Shared/read lock. Accepts: + - concrete device: torch.device('cuda:0'), 'cuda:1' + - family device: torch.device('cuda'), 'cuda', 'xpu', 'mps', 'cpu' + - 'all' for every device in the pool + Returns a context manager. + """ + if isinstance(device, str): + if device == "all": + pairs = [(k, self._locks[k]) for k in self._ordered_keys] + return _ReadLockGroup(pairs) + if device in ("cuda", "xpu", "mps", "cpu"): + keys = [k for k in self._ordered_keys if k.startswith(device)] + if not keys: + raise ValueError(f"No devices of type '{device}' in pool") + pairs = [(k, self._locks[k]) for k in keys] + return _ReadLockGroup(pairs) + + dev = _coerce_device(device) + key = self._key(dev) + + if dev.index is None: + fam = dev.type + keys = [k for k in self._ordered_keys if k.startswith(fam)] + if not keys: + raise ValueError(f"No devices of type '{fam}' in pool") + pairs = [(k, self._locks[k]) for k in keys] + return _ReadLockGroup(pairs) + + lk = self._locks.get(key) + if lk is None: + raise ValueError(f"Unknown device for pool: {dev}") + return lk.reader() + + def lock(self, devices: Optional[Iterable[DeviceLike]] = None): + """ + Exclusive lock across multiple devices (default: all pool devices). + Acquires each device's write lock in canonical order to avoid deadlocks. + """ + if devices is None: + pairs = [(k, self._locks[k]) for k in self._ordered_keys] + else: + keys = sorted(self._normalize_scope_to_keys(devices)) + pairs = [(k, self._locks[k]) for k in keys] + return _LockGroup(pairs) + + # --------------- Public Wait API --------------- + + def wait(self, scope: Optional[Union[str, DeviceLike, Iterable[DeviceLike]]]) -> None | _WaitAndLock: + """ + Wait until in-flight tasks for `scope` drain to zero. + """ + keys = self._resolve_scope_to_keys(scope) + for k in keys: + cv = self._inflight_cv[k] + with cv: + while self._inflight[k] > 0: + cv.wait() + return None + + # --------------- Public Stats API --------------- + + def stats(self) -> Dict[str, Any]: + """Return counters snapshot: per-device and global.""" + with self._stats_lock: + return { + "per_device": dict(self._per_device_done), + "total": int(self._total_done), + "threshold": int(self._empty_cache_every_n), + } + + def device_completed(self, device: DeviceLike) -> int: + key = self._key(_coerce_device(device)) + with self._stats_lock: + return int(self._per_device_done.get(key, 0)) + + def total_completed(self) -> int: + with self._stats_lock: + return int(self._total_done) + + # --------------- Internals --------------- + + def _key(self, dev: torch.device) -> str: + idx = "" if dev.index is None else f":{dev.index}" + return f"{dev.type}{idx}" + + def _pick_worker(self, key: str) -> _DeviceWorker: + group = self._worker_groups.get(key) + if not group: + raise ValueError(f"Device {key} not part of this pool.") + + with self._dispatch_lock: + n = len(group) + if n == 0: + raise ValueError(f"No workers available for device {key}") + start = self._dispatch_rr[key] % n + idx = start + # Find the next accepting worker + for _ in range(n): + w = group[idx] + if w.is_accepting(): + self._dispatch_rr[key] = (idx + 1) % n + return w + idx = (idx + 1) % n + # If none are accepting, spawn a fresh one and use it + dev = self._devices_by_key[key] + neww = self._spawn_worker(dev, name=f"DPWorker-{key}#hot") + group.append(neww) + self._worker_groups[key] = group + self._dispatch_rr[key] = (len(group) - 1 + 1) % len(group) + return neww + + def _resolve_workers_for_device(self, dev: torch.device, table: Dict[str, int]) -> int: + key = self._key(dev) + if key in table: + return int(table[key]) + fam_key = f"{dev.type}:per" + if fam_key in table: + return int(table[fam_key]) + if dev.type in ("cpu", "mps") and dev.type in table: + return int(table[dev.type]) + return 1 + + def _normalize_scope_to_keys(self, scope: Iterable[DeviceLike]) -> List[str]: + keys: List[str] = [] + for s in scope: + if isinstance(s, str): + if s in ("all",): + keys.extend(self._ordered_keys) + elif ":" in s: + if s not in self._locks: + raise ValueError(f"Unknown device key in scope: {s}") + keys.append(s) + else: + fam = s + fam_keys = [k for k in self._ordered_keys if k.startswith(fam)] + if not fam_keys: + raise ValueError(f"No devices of type '{fam}' in pool") + keys.extend(fam_keys) + else: + dev = _coerce_device(s) + k = self._key(dev) + if k not in self._locks: + raise ValueError(f"Device not in pool: {dev}") + keys.append(k) + return keys + + def _resolve_scope_to_keys(self, scope: Optional[Union[str, DeviceLike, Iterable[DeviceLike]]]) -> List[str]: + if scope is None or (isinstance(scope, str) and scope == "all"): + return list(self._ordered_keys) + if isinstance(scope, (str, torch.device, int)): + return self._normalize_scope_to_keys([scope]) + return self._normalize_scope_to_keys(scope) + + # ---- inflight & completion accounting ---- + + def _mark_scheduled(self, key: str) -> None: + cv = self._inflight_cv[key] + with cv: + self._inflight[key] += 1 + + def _mark_finished(self, key: str) -> None: + cv = self._inflight_cv[key] + with cv: + self._inflight[key] -= 1 + if self._inflight[key] == 0: + cv.notify_all() + + def _on_task_finished(self, key: str) -> None: + self._mark_finished(key) + + trigger_gc = False + with self._stats_lock: + self._per_device_done[key] += 1 + self._total_done += 1 + dev_type = self._devices_by_key[key].type + if self._empty_cache_every_n > 0 and dev_type in ("cuda", "xpu", "mps"): + n = self._per_device_done[key] + if n % self._empty_cache_every_n == 0: + trigger_gc = True + if trigger_gc: + self._gc_event.set() + + # ---- ANSI table rendering for GC diagnostics ---- + + def _ansi_table(self, headers: List[str], rows: List[List[str]]) -> str: + widths = [len(h) for h in headers] + for r in rows: + for i, cell in enumerate(r): + widths[i] = max(widths[i], len(cell)) + + def hrule(sep_left="+", sep_mid="+", sep_right="+", h="-"): + parts = [sep_left] + for i, w in enumerate(widths): + parts.append(h * (w + 2)) + parts.append(sep_mid if i < len(widths) - 1 else sep_right) + return "".join(parts) + + def format_row(cols: List[str]): + out = ["|"] + for i, cell in enumerate(cols): + out.append(" " + cell.ljust(widths[i]) + " ") + out.append("|") + return "".join(out) + + BOLD = "\x1b[1m" + RESET = "\x1b[0m" + + top = hrule() + mid = hrule(h="=") + bot = hrule() + + lines = [top, format_row([BOLD + h + RESET for h in headers]), mid] + for r in rows: + lines.append(format_row(r)) + lines.append(bot) + return "\n".join(lines) + + def _collect_state_snapshot(self) -> Dict[str, Any]: + with self._stats_lock: + per_done = dict(self._per_device_done) + total_done = int(self._total_done) + threshold = int(self._empty_cache_every_n) + + inflight: Dict[str, int] = {} + for k, cv in self._inflight_cv.items(): + with cv: + inflight[k] = int(self._inflight[k]) + + workers = {k: len(self._worker_groups.get(k, [])) for k in self._devices_by_key.keys()} + + meta: Dict[str, Dict[str, str]] = {} + for k, dev in self._devices_by_key.items(): + idx = "" if dev.index is None else str(dev.index) + meta[k] = {"type": dev.type, "index": idx} + + snap: Dict[str, Any] = { + "devices": sorted(self._devices_by_key.keys()), + "per_done": per_done, + "total_done": total_done, + "threshold": threshold, + "inflight": inflight, + "workers": workers, + "meta": meta, + "total_inflight": sum(inflight.values()), + "total_workers": sum(workers.values()), + "gc_passes": int(self._gc_passes), + "last_gc_ts": self._last_gc_ts, + "now": time.time(), + } + return snap + + def _render_gc_table(self, snap: Dict[str, Any]) -> str: + headers = [ + "Device", "Type", "Index", "Workers", "Inflight", + "Done", "Threshold", "NextGC", "Accel" + ] + rows: List[List[str]] = [] + thr = snap["threshold"] + for k in snap["devices"]: + t = snap["meta"][k]["type"] + idx = snap["meta"][k]["index"] + w = snap["workers"].get(k, 0) + infl = snap["inflight"].get(k, 0) + done = snap["per_done"].get(k, 0) + accel = "Y" if t in ("cuda", "xpu", "mps") else "N" + if thr > 0 and t in ("cuda", "xpu", "mps"): + rem = thr - (done % thr) if (done % thr) != 0 else 0 + nextgc = "now" if rem == 0 and done > 0 else str(rem) + else: + nextgc = "-" + rows.append([k, t, idx, str(w), str(infl), str(done), str(thr) if thr > 0 else "-", nextgc, accel]) + + table_main = self._ansi_table(headers, rows) + + totals_headers = ["Total Workers", "Total Inflight", "Total Done", "GC Passes", "Since Last GC (s)"] + since = "-" if self._last_gc_ts is None else f"{time.time() - self._last_gc_ts:.3f}" + totals_rows = [[ + str(sum(len(v) for v in self._worker_groups.values())), + str(snap["total_inflight"]), + str(snap["total_done"]), + str(snap["gc_passes"]), + since, + ]] + table_totals = self._ansi_table(totals_headers, totals_rows) + return table_main + "\n" + table_totals + + # ---- janitor (global empty-cache under lock) ---- + + def _synchronize_all(self): + """ + Ensure devices are idle before empty_cache() to avoid races with outstanding kernels. + Iterate discovered devices and only guard on attribute presence for backend sync. + """ + # CUDA + for key in self._ordered_keys: + dev = self._devices_by_key[key] + if dev.type != "cuda": + continue + with torch.cuda.device(dev.index): + torch.cuda.synchronize() + + # XPU + for key in self._ordered_keys: + dev = self._devices_by_key[key] + if dev.type != "xpu": + continue + if hasattr(torch, "xpu") and hasattr(torch.xpu, "synchronize"): + with torch.xpu.device(dev.index): + torch.xpu.synchronize() + + # MPS + has_mps_device = any(self._devices_by_key[k].type == "mps" for k in self._ordered_keys) + if has_mps_device and hasattr(torch, "mps") and hasattr(torch.mps, "synchronize"): + torch.mps.synchronize() + + def _should_run_gc_from_snapshot(self, snap: Dict[str, Any]) -> bool: + """ + Decide whether to run GC by comparing per-device progress since last GC. + This deduplicates bursty triggers that occurred before the previous GC ran. + """ + thr = snap["threshold"] + if thr <= 0: + return False + for k in snap["devices"]: + dev_type = snap["meta"][k]["type"] + if dev_type not in ("cuda", "xpu", "mps"): + continue + done_now = snap["per_done"].get(k, 0) + done_prev = self._last_gc_done_per_device.get(k, 0) + if done_now - done_prev >= thr: + return True + return False + + def _update_gc_watermarks(self, snap_after: Dict[str, Any]) -> None: + """Record 'done' counters as of a GC pass.""" + for k in snap_after["devices"]: + self._last_gc_done_per_device[k] = snap_after["per_done"].get(k, 0) + + def _janitor_loop(self): + while True: + self._gc_event.wait() + if self._stop_event.is_set(): + break + + if self._gc_debounce_s > 0: + t_end = time.time() + self._gc_debounce_s + while time.time() < t_end: + self._gc_event.clear() + self._gc_event.wait(timeout=max(0.0, t_end - time.time())) + self._gc_event.clear() + else: + self._gc_event.clear() + + try: + pre = self._collect_state_snapshot() + log.debug("GC trigger received; acquiring global exclusive lock…") + except Exception as e: + try: + log.warn(f"Failed to render GC pre-snapshot: {e!r}") + except Exception: + pass + pre = { + "devices": list(self._devices_by_key.keys()), + "per_done": {k: self._per_device_done.get(k, 0) for k in self._devices_by_key.keys()}, + "threshold": self._empty_cache_every_n, + "meta": {k: {"type": self._devices_by_key[k].type} for k in self._devices_by_key.keys()}, + "inflight": dict.fromkeys(self._devices_by_key.keys(), 0), + "workers": {k: len(self._worker_groups.get(k, [])) for k in self._devices_by_key.keys()}, + "total_inflight": 0, + "total_workers": sum(len(v) for v in self._worker_groups.values()), + "gc_passes": self._gc_passes, + "last_gc_ts": self._last_gc_ts, + "now": time.time(), + "total_done": self._total_done, + } + + if not self._should_run_gc_from_snapshot(pre): + continue + + with self.lock(): # writer lock across ALL devices + t0 = time.time() + # Optional but often expensive: + # self._synchronize_all() + self._empty_all_caches() + t1 = time.time() + + self._gc_passes += 1 + self._last_gc_ts = t1 + + try: + post = self._collect_state_snapshot() + self._update_gc_watermarks(post) + log.info(f"GC completed in {t1 - t0:.3f}s (pass #{self._gc_passes}).") + except Exception as e: + try: + log.warn(f"Failed to render GC post-snapshot: {e!r}") + except Exception: + pass + + def _empty_all_caches(self): + """ + Call the captured originals if available; no redundant availability checks + and no try/except around empty_cache (fail loud if backend misbehaves). + """ + # CUDA + if TORCH_CUDA_EMPTY_CACHE is not None: + for key in self._ordered_keys: + dev = self._devices_by_key[key] + if dev.type != "cuda": + continue + with torch.cuda.device(dev.index): + TORCH_CUDA_EMPTY_CACHE() + log.debug(f"cuda empty cache called on {dev.index}") + + # XPU + if TORCH_XPU_EMPTY_CACHE is not None: + for key in self._ordered_keys: + dev = self._devices_by_key[key] + if dev.type != "xpu": + continue + with torch.xpu.device(dev.index): + TORCH_XPU_EMPTY_CACHE() + + # MPS (only if this pool actually has an MPS device) + if TORCH_MPS_EMPTY_CACHE is not None: + has_mps_device = any(self._devices_by_key[k].type == "mps" for k in self._ordered_keys) + if has_mps_device: + TORCH_MPS_EMPTY_CACHE() diff --git a/gptqmodel/utils/torch.py b/gptqmodel/utils/torch.py index 9dbd46761..580ff10f0 100644 --- a/gptqmodel/utils/torch.py +++ b/gptqmodel/utils/torch.py @@ -15,6 +15,7 @@ from ..utils.logger import setup_logger from . import gte_python_3_13_3, has_gil_disabled, log_gil_requirements_for + # pytorch 2.6.0 fixes many compilation errors TORCH_HAS_COMPILE = version.parse(torch.__version__).release >= version.Version('2.6').release TORCH_GTE_28 = version.parse(torch.__version__).release >= version.Version('2.8').release diff --git a/gptqmodel/utils/vllm.py b/gptqmodel/utils/vllm.py index 707db6a12..e7f6c49a2 100644 --- a/gptqmodel/utils/vllm.py +++ b/gptqmodel/utils/vllm.py @@ -7,6 +7,7 @@ import torch + try: from vllm import LLM, SamplingParams, TokensPrompt diff --git a/setup.py b/setup.py index ff77dd8b9..b2d2af1b5 100644 --- a/setup.py +++ b/setup.py @@ -12,6 +12,7 @@ from setuptools import find_packages, setup from setuptools.command.bdist_wheel import bdist_wheel as _bdist_wheel + # --------------------------- # Helpers (no torch required) # --------------------------- diff --git a/tests/benchmark/benchmark.py b/tests/benchmark/benchmark.py index 648e1723a..ff33989c5 100644 --- a/tests/benchmark/benchmark.py +++ b/tests/benchmark/benchmark.py @@ -4,9 +4,10 @@ # Contact: qubitium@modelcloud.ai, x.com/qubitium from benchmark_test import BenchmarkTest -from gptqmodel import BACKEND from parameterized import parameterized # noqa: E402 +from gptqmodel import BACKEND + class TestInference(BenchmarkTest): @parameterized.expand( diff --git a/tests/benchmark/benchmark_test.py b/tests/benchmark/benchmark_test.py index 20d4a5eb6..59e687bf6 100644 --- a/tests/benchmark/benchmark_test.py +++ b/tests/benchmark/benchmark_test.py @@ -6,14 +6,17 @@ import os import time + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" import unittest # noqa: E402 -from gptqmodel import GPTQModel # noqa: E402 from logbar import LogBar +from gptqmodel import GPTQModel # noqa: E402 + + logger = LogBar.shared() class BenchmarkTest(unittest.TestCase): diff --git a/tests/inference_speed.py b/tests/inference_speed.py index 0d37a6789..c971fba7c 100644 --- a/tests/inference_speed.py +++ b/tests/inference_speed.py @@ -8,15 +8,18 @@ from gptqmodel.utils.torch import torch_empty_cache + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" import unittest -from gptqmodel import GPTQModel from logbar import LogBar from transformers import AutoTokenizer +from gptqmodel import GPTQModel + + logger = LogBar.shared() class InferenceSpeed(unittest.TestCase): diff --git a/tests/models/model_test.py b/tests/models/model_test.py index 585252afe..62f5602d7 100644 --- a/tests/models/model_test.py +++ b/tests/models/model_test.py @@ -6,6 +6,7 @@ import os import sys + if sys.platform == "darwin": os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" @@ -19,6 +20,7 @@ from logbar import LogBar # noqa: E402 + sys.path.insert(0, f"{str(Path(__file__).resolve().parent.parent)}/models") # noqa: E402 import contextlib # noqa: E402 import shutil # noqa: E402 @@ -27,6 +29,9 @@ import torch.cuda # noqa: E402 from datasets import load_dataset # noqa: E402 +from ovis.image_to_test_dataset import get_calib_dataset # noqa: E402 +from transformers import AutoProcessor, AutoTokenizer # noqa: E402 + from gptqmodel import BACKEND, GPTQModel # noqa: E402 from gptqmodel.nn_modules.qlinear import BaseQuantLinear # noqa: E402 from gptqmodel.quantization import FORMAT, METHOD # noqa: E402 @@ -34,8 +39,7 @@ from gptqmodel.utils.eval import EVAL # noqa: E402 from gptqmodel.utils.model import MODALITY # noqa: E402 from gptqmodel.utils.torch import torch_empty_cache # noqa: E402 -from ovis.image_to_test_dataset import get_calib_dataset # noqa: E402 -from transformers import AutoProcessor, AutoTokenizer # noqa: E402 + RAND_SEED = 898 diff --git a/tests/models/test_apertus.py b/tests/models/test_apertus.py index 19f13dd0a..fa675fb77 100644 --- a/tests/models/test_apertus.py +++ b/tests/models/test_apertus.py @@ -3,9 +3,10 @@ # SPDX-License-Identifier: Apache-2.0 # Contact: qubitium@modelcloud.ai, x.com/qubitium -from gptqmodel import BACKEND from model_test import ModelTest +from gptqmodel import BACKEND + class TestApertus(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/Apertus-8B-Instruct-2509/" diff --git a/tests/models/test_gptbigcode.py b/tests/models/test_gptbigcode.py index acd4cd00f..d59f0f03a 100644 --- a/tests/models/test_gptbigcode.py +++ b/tests/models/test_gptbigcode.py @@ -6,6 +6,7 @@ import importlib.util import os + # TODO: find how ipex registered it jit interpreter # if intel_extension_for_pytorch was installed, @torch.jit.script in transformers/models/gpt_bigcode/modeling_gpt_bigcode.py will try to use ipex as torchScript interpreter. # However, in quantization, tensor were on gpu, which will throw RuntimeError: itensor_view_from_dense expects CPU tensor input diff --git a/tests/models/test_llama3_2.py b/tests/models/test_llama3_2.py index 2b3adcd8b..b72e33cef 100644 --- a/tests/models/test_llama3_2.py +++ b/tests/models/test_llama3_2.py @@ -4,7 +4,6 @@ # Contact: qubitium@modelcloud.ai, x.com/qubitium from model_test import ModelTest -from gptqmodel.adapter.adapter import Lora # a100:0 @@ -24,6 +23,7 @@ class TestLlama3_2(ModelTest): DATASET_SIZE = 1024 DATASET_SORT = "desc" QUANT_BATCH_SIZE = 4 + # USE_FLASH_ATTN = False # EORA = Lora( # # for quant, path is save path. for load, it is loading path # path="./eora_test", diff --git a/tests/models/test_qwen2_5_omni.py b/tests/models/test_qwen2_5_omni.py index 44c5f3cf8..e3d3a2844 100644 --- a/tests/models/test_qwen2_5_omni.py +++ b/tests/models/test_qwen2_5_omni.py @@ -5,9 +5,10 @@ import os import soundfile as sf -from gptqmodel.models.definitions.qwen2_5_omni import Qwen2_5_OmniGPTQ from model_test import ModelTest +from gptqmodel.models.definitions.qwen2_5_omni import Qwen2_5_OmniGPTQ + class TestQwen2_5_Omni(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/Qwen2.5-Omni-3B" diff --git a/tests/models/test_qwen2_5_vl.py b/tests/models/test_qwen2_5_vl.py index 0415f25c4..035a3f8ac 100644 --- a/tests/models/test_qwen2_5_vl.py +++ b/tests/models/test_qwen2_5_vl.py @@ -3,9 +3,10 @@ # SPDX-License-Identifier: Apache-2.0 # Contact: qubitium@modelcloud.ai, x.com/qubitium -from gptqmodel.models.definitions.qwen2_vl import Qwen2VLQModel from model_test import ModelTest +from gptqmodel.models.definitions.qwen2_vl import Qwen2VLQModel + class TestQwen2_VL(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/Qwen2.5-VL-3B-Instruct" diff --git a/tests/models/test_qwen2_moe.py b/tests/models/test_qwen2_moe.py index 0ff90edff..cf9903ab3 100644 --- a/tests/models/test_qwen2_moe.py +++ b/tests/models/test_qwen2_moe.py @@ -1,6 +1,7 @@ import unittest import torch + from gptqmodel import BACKEND, GPTQModel diff --git a/tests/models/test_qwen2_vl.py b/tests/models/test_qwen2_vl.py index 9afb5384e..119955df9 100644 --- a/tests/models/test_qwen2_vl.py +++ b/tests/models/test_qwen2_vl.py @@ -3,9 +3,10 @@ # SPDX-License-Identifier: Apache-2.0 # Contact: qubitium@modelcloud.ai, x.com/qubitium -from gptqmodel.models.definitions.qwen2_vl import Qwen2VLQModel from model_test import ModelTest +from gptqmodel.models.definitions.qwen2_vl import Qwen2VLQModel + class TestQwen2_VL(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/Qwen2-VL-2B-Instruct" diff --git a/tests/models/test_qwen3_moe.py b/tests/models/test_qwen3_moe.py index 071635414..90de5a090 100644 --- a/tests/models/test_qwen3_moe.py +++ b/tests/models/test_qwen3_moe.py @@ -11,9 +11,16 @@ class TestQwen3Moe(ModelTest): QUANT_ARC_MAX_DELTA_FLOOR_PERCENT = 0.2 NATIVE_ARC_CHALLENGE_ACC = 0.2739 NATIVE_ARC_CHALLENGE_ACC_NORM = 0.3055 - TRUST_REMOTE_CODE = True + # TRUST_REMOTE_CODE = False APPLY_CHAT_TEMPLATE = True - EVAL_BATCH_SIZE = 6 + # EVAL_BATCH_SIZE = 6 + V2 = False + DEBUG = True + ACT_GROUP_AWARE = True + DESC_ACT = False + DATASET_SIZE = 1024 + DATASET_SORT = "desc" + QUANT_BATCH_SIZE = 4 def test_mimo(self): self.quant_lm_eval() diff --git a/tests/tasks/mmlu/_generate_configs.py b/tests/tasks/mmlu/_generate_configs.py index f613f7cd4..28b94616d 100644 --- a/tests/tasks/mmlu/_generate_configs.py +++ b/tests/tasks/mmlu/_generate_configs.py @@ -9,6 +9,7 @@ import yaml from tqdm import tqdm + eval_logger = logging.getLogger("lm-eval") diff --git a/tests/test_adapter_config.py b/tests/test_adapter_config.py index 6c09017e4..dc635087a 100644 --- a/tests/test_adapter_config.py +++ b/tests/test_adapter_config.py @@ -19,11 +19,13 @@ from gptqmodel import QuantizeConfig from gptqmodel.adapter.adapter import Lora, normalize_adapter + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # -- end do not touch import unittest # noqa: E402 + lora = "lora" class TestExtensionConfig(unittest.TestCase): diff --git a/tests/test_asym_gptq_v1.py b/tests/test_asym_gptq_v1.py index fe80ac6c5..f58a045fd 100644 --- a/tests/test_asym_gptq_v1.py +++ b/tests/test_asym_gptq_v1.py @@ -6,11 +6,13 @@ # -- do not touch import os + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" -from gptqmodel.quantization import FORMAT # noqa: E402 # -- end do not touch from models.model_test import ModelTest # noqa: E402 +from gptqmodel.quantization import FORMAT # noqa: E402 + class Test(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/Llama-3.2-1B-Instruct" # "meta-llama/Llama-3.2-1B-Instruct" diff --git a/tests/test_awq.py b/tests/test_awq.py index 9e200ab1b..be0e1d1b1 100644 --- a/tests/test_awq.py +++ b/tests/test_awq.py @@ -6,17 +6,21 @@ import unittest from datasets import load_dataset +from parameterized import parameterized +from transformers import AutoTokenizer + from gptqmodel.nn_modules.qlinear.awq_gemm import AwqGEMMQuantLinear from gptqmodel.quantization import FORMAT, METHOD, QUANT_CONFIG_FILENAME from gptqmodel.utils.torch import torch_empty_cache -from parameterized import parameterized -from transformers import AutoTokenizer + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" -from gptqmodel import BACKEND, GPTQModel, QuantizeConfig # noqa: E402 # -- end do not touch from logbar import LogBar +from gptqmodel import BACKEND, GPTQModel, QuantizeConfig # noqa: E402 + + log = LogBar.shared() diff --git a/tests/test_awq_moe.py b/tests/test_awq_moe.py index 015dbca57..2d5a316a7 100644 --- a/tests/test_awq_moe.py +++ b/tests/test_awq_moe.py @@ -6,17 +6,21 @@ import unittest from datasets import load_dataset +from parameterized import parameterized +from transformers import AutoTokenizer + from gptqmodel.nn_modules.qlinear.awq_gemm import AwqGEMMQuantLinear from gptqmodel.quantization import FORMAT, METHOD, QUANT_CONFIG_FILENAME from gptqmodel.utils.torch import torch_empty_cache -from parameterized import parameterized -from transformers import AutoTokenizer + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" -from gptqmodel import BACKEND, GPTQModel, QuantizeConfig # noqa: E402 # -- end do not touch from logbar import LogBar +from gptqmodel import BACKEND, GPTQModel, QuantizeConfig # noqa: E402 + + log = LogBar.shared() class TestGroupSize(unittest.TestCase): diff --git a/tests/test_bits.py b/tests/test_bits.py index e2aaea6d5..5297fa637 100644 --- a/tests/test_bits.py +++ b/tests/test_bits.py @@ -8,12 +8,16 @@ from gptqmodel.nn_modules.qlinear.exllama_eora import ExllamaEoraQuantLinear + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # -- end do not touch import logging # noqa: E402 import tempfile # noqa: E402 import unittest # noqa: E402 +from lm_eval.utils import make_table # noqa: E402 +from transformers import AutoTokenizer # noqa: E402 + from gptqmodel import BACKEND, GPTQModel, QuantizeConfig # noqa: E402 from gptqmodel.nn_modules.qlinear.bitblas import BitBLASQuantLinear # noqa: E402 from gptqmodel.nn_modules.qlinear.exllama import ExllamaQuantLinear # noqa: E402 @@ -22,8 +26,7 @@ from gptqmodel.nn_modules.qlinear.torch import TorchQuantLinear # noqa: E402 from gptqmodel.nn_modules.qlinear.tritonv2 import TritonV2QuantLinear # noqa: E402 from gptqmodel.utils.eval import EVAL # noqa: E402 -from lm_eval.utils import make_table # noqa: E402 -from transformers import AutoTokenizer # noqa: E402 + logger = logging.getLogger(__name__) diff --git a/tests/test_bits_new.py b/tests/test_bits_new.py index 3f9e34b5b..12b530660 100644 --- a/tests/test_bits_new.py +++ b/tests/test_bits_new.py @@ -16,6 +16,7 @@ # -- do not touch import os + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # -- end do not touch @@ -23,13 +24,14 @@ from typing import Optional # noqa: E402 from datasets import load_dataset # noqa: E402 +from lm_eval.utils import make_table # noqa: E402 +from models.model_test import ModelTest # noqa: E402 +from tabulate import tabulate # noqa: E402 + from gptqmodel import BACKEND, GPTQModel, QuantizeConfig # noqa: E402 from gptqmodel.adapter.adapter import Lora # noqa: E402 from gptqmodel.utils.eval import EVAL # noqa: E402 from gptqmodel.utils.torch import torch_empty_cache # noqa: E402 -from lm_eval.utils import make_table # noqa: E402 -from models.model_test import ModelTest # noqa: E402 -from tabulate import tabulate # noqa: E402 def bench(path: str, backend: BACKEND, adapter: Optional[Lora]): diff --git a/tests/test_cpu_gpu_memory_copy.py b/tests/test_cpu_gpu_memory_copy.py index d79168896..acb58f535 100644 --- a/tests/test_cpu_gpu_memory_copy.py +++ b/tests/test_cpu_gpu_memory_copy.py @@ -21,8 +21,10 @@ import argparse import math import time + import torch + def gib_to_elems_fp16(gib: float) -> int: # 1 GiB = 1024**3 bytes; fp16 = 2 bytes/elem return int((gib * (1024**3)) // 2) diff --git a/tests/test_cuda_stream.py b/tests/test_cuda_stream.py new file mode 100644 index 000000000..ebffa8b9a --- /dev/null +++ b/tests/test_cuda_stream.py @@ -0,0 +1,138 @@ +# test_d2h_concurrency.py +# pytest -q -s test_d2h_concurrency.py +import math +import time + +import pytest +import torch + + +pytestmark = pytest.mark.skipif( + not torch.cuda.is_available(), reason="CUDA required for this test" +) + +def _mb(nbytes): return nbytes / (1024**2) + +def _banner(msg): + print("\n" + "=" * 80) + print(msg) + print("=" * 80) + +def test_three_d2h_transfers_concurrency_vs_serial(): + dev = torch.device("cuda", 0) + props = torch.cuda.get_device_properties(dev) + _banner( + f"GPU: {props.name} | asyncEngineCount={getattr(props, 'asyncEngineCount', 'n/a')} | " + f"PCIe/Link: unknown (PyTorch doesn't expose)\n" + "Expectation: multiple D2H on a single GPU serialize onto one D2H engine." + ) + + torch.cuda.set_device(dev) + + # Use a size large enough to dominate overhead but not stress CI. + # ~256 MiB each => 3 * 256 MiB = 768 MiB total device RAM + pinned host buffers. + elements = (256 * 1024 * 1024) // 2 # fp16 => 2 bytes/elt + dtype = torch.float16 + + # Device tensors + d0 = torch.empty(elements, dtype=dtype, device=dev) + d1 = torch.empty_like(d0) + d2 = torch.empty_like(d0) + + # Pinned host buffers (required for async copies) + h0 = torch.empty_like(d0, device="cpu", pin_memory=True) + h1 = torch.empty_like(d1, device="cpu", pin_memory=True) + h2 = torch.empty_like(d2, device="cpu", pin_memory=True) + + # Warmup: one D2H copy to touch paths + h0.copy_(d0, non_blocking=True) + torch.cuda.synchronize() + + # --- Serialized on a single stream --- + s_serial = torch.cuda.Stream() + torch.cuda.synchronize() + t0 = time.perf_counter() + with torch.cuda.stream(s_serial): + h0.copy_(d0, non_blocking=True) + h1.copy_(d1, non_blocking=True) + h2.copy_(d2, non_blocking=True) + torch.cuda.synchronize() + serial_time = time.perf_counter() - t0 + + # --- Launched concurrently on three streams --- + s0, s1, s2 = torch.cuda.Stream(), torch.cuda.Stream(), torch.cuda.Stream() + torch.cuda.synchronize() + t1 = time.perf_counter() + with torch.cuda.stream(s0): + h0.copy_(d0, non_blocking=True) + with torch.cuda.stream(s1): + h1.copy_(d1, non_blocking=True) + with torch.cuda.stream(s2): + h2.copy_(d2, non_blocking=True) + torch.cuda.synchronize() + concurrent_time = time.perf_counter() - t1 + + total_mb = 3 * _mb(d0.numel() * d0.element_size()) + print(f"\nTransferred total ~{total_mb:.1f} MiB (3 x ~{total_mb/3:.1f} MiB) D2H") + print(f"[SERIAL] {serial_time:.4f} s | ~{total_mb/serial_time:.1f} MiB/s effective") + print(f"[CONCURRENT] {concurrent_time:.4f} s | ~{total_mb/concurrent_time:.1f} MiB/s effective") + + # We expect little to no speedup when "concurrent" (same-direction copies share the D2H engine). + # Allow some tolerance either way depending on driver/runtime details. + assert concurrent_time >= 0.8 * serial_time, ( + "Unexpected large speedup from concurrent D2H; " + "this would contradict single-engine D2H behavior." + ) + assert concurrent_time <= 1.3 * serial_time, ( + "Concurrent D2H took much longer than serialized; " + "this suggests overheads far above expectation." + ) + +def test_h2d_d2h_bidirectional_overlap_possible(): + """Optional: demonstrate one H2D can overlap one D2H if GPU has ≥2 copy engines.""" + dev = torch.device("cuda", 0) + props = torch.cuda.get_device_properties(dev) + if getattr(props, "asyncEngineCount", 0) < 2: + pytest.skip("GPU reports <2 copy engines; bidirectional overlap unlikely.") + + torch.cuda.set_device(dev) + + elements = (128 * 1024 * 1024) // 1 # 128 MiB in bytes (uint8) + dtype = torch.uint8 + + # Host buffers (pinned) and device tensors + h_src = torch.empty(elements, dtype=dtype, device="cpu", pin_memory=True) + h_dst = torch.empty(elements, dtype=dtype, device="cpu", pin_memory=True) + d_buf = torch.empty(elements, dtype=dtype, device=dev) + + # Warmup + d_buf.copy_(h_src, non_blocking=True) + h_dst.copy_(d_buf, non_blocking=True) + torch.cuda.synchronize() + + # Baseline: serialize H2D then D2H on one stream + s = torch.cuda.Stream() + torch.cuda.synchronize() + t0 = time.perf_counter() + with torch.cuda.stream(s): + d_buf.copy_(h_src, non_blocking=True) # H2D + h_dst.copy_(d_buf, non_blocking=True) # D2H + torch.cuda.synchronize() + serial = time.perf_counter() - t0 + + # Overlap: H2D on one stream, D2H on another (should overlap on separate engines) + sh2d, sd2h = torch.cuda.Stream(), torch.cuda.Stream() + torch.cuda.synchronize() + t1 = time.perf_counter() + with torch.cuda.stream(sh2d): + d_buf.copy_(h_src, non_blocking=True) # H2D + with torch.cuda.stream(sd2h): + h_dst.copy_(d_buf, non_blocking=True) # D2H + torch.cuda.synchronize() + overlapped = time.perf_counter() - t1 + + print(f"\n[H2D->D2H] SERIAL {serial:.4f} s") + print(f"[H2D||D2H] OVERLAP {overlapped:.4f} s (expect <= ~serial)") + + # Expect some overlap benefit (not necessarily 2x). + assert overlapped <= 0.9 * serial or math.isclose(overlapped, serial, rel_tol=0.05) diff --git a/tests/test_dynamic.py b/tests/test_dynamic.py index 0bc638908..0ea3acabe 100644 --- a/tests/test_dynamic.py +++ b/tests/test_dynamic.py @@ -8,11 +8,15 @@ from gptqmodel.nn_modules.qlinear.torch import TorchQuantLinear + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # -- end do not touch import json # noqa: E402 import tempfile # noqa: E402 +from models.model_test import ModelTest # noqa: E402 +from parameterized import parameterized # noqa: E402 + from gptqmodel import BACKEND, GPTQModel # noqa: E402 from gptqmodel.nn_modules.qlinear import BaseQuantLinear # noqa: E402 from gptqmodel.nn_modules.qlinear.marlin import MarlinQuantLinear # noqa: E402 @@ -20,8 +24,6 @@ from gptqmodel.quantization import QuantizeConfig # noqa: E402 from gptqmodel.utils import safetensor # noqa: E402 from gptqmodel.utils.perplexity import Perplexity # noqa: E402 -from models.model_test import ModelTest # noqa: E402 -from parameterized import parameterized # noqa: E402 class TestDynamic(ModelTest): diff --git a/tests/test_estimate_vram.py b/tests/test_estimate_vram.py index 8b6153b6e..918d5a544 100644 --- a/tests/test_estimate_vram.py +++ b/tests/test_estimate_vram.py @@ -6,6 +6,7 @@ # -- do not touch import os + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # -- end do not touch import unittest # noqa: E402 diff --git a/tests/test_eval.py b/tests/test_eval.py index 34a93c201..478aa36ae 100644 --- a/tests/test_eval.py +++ b/tests/test_eval.py @@ -5,18 +5,22 @@ import os + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" import tempfile # noqa: E402 -from typing import Type # noqa: E402 -from typing import Union # noqa: E402 +from typing import ( + Type, # noqa: E402 + Union, # noqa: E402 +) -from gptqmodel import GPTQModel # noqa: E402 -from gptqmodel.utils.eval import EVAL # noqa: E402 from lm_eval.tasks import TaskManager # noqa: E402 from models.model_test import ModelTest # noqa: E402 from parameterized import parameterized # noqa: E402 +from gptqmodel import GPTQModel # noqa: E402 +from gptqmodel.utils.eval import EVAL # noqa: E402 + class TestEval(ModelTest): @classmethod diff --git a/tests/test_evalplus.py b/tests/test_evalplus.py index cc1c53da4..62181595b 100644 --- a/tests/test_evalplus.py +++ b/tests/test_evalplus.py @@ -6,6 +6,7 @@ # -- do not touch import os + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # -- end do not touch diff --git a/tests/test_flash_attention.py b/tests/test_flash_attention.py index 19ada8aab..69cb39078 100644 --- a/tests/test_flash_attention.py +++ b/tests/test_flash_attention.py @@ -6,14 +6,16 @@ # -- do not touch import os + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # -- end do not touch -from gptqmodel import GPTQModel # noqa: E402 from models.model_test import ModelTest # noqa: E402 from transformers import AutoTokenizer # noqa: E402 +from gptqmodel import GPTQModel # noqa: E402 + class Test(ModelTest): diff --git a/tests/test_gpu_gpu_memory_copy.py b/tests/test_gpu_gpu_memory_copy.py index 73df8cb7f..1901b2a9c 100644 --- a/tests/test_gpu_gpu_memory_copy.py +++ b/tests/test_gpu_gpu_memory_copy.py @@ -18,10 +18,12 @@ # - For accurate timing we synchronize before/after and use perf_counter. import argparse -import time import math +import time + import torch + def gib_to_elems_fp16(gib: float) -> int: # 1 GiB = 1024**3 bytes; fp16 = 2 bytes/elem return int((gib * (1024**3)) // 2) diff --git a/tests/test_group_size.py b/tests/test_group_size.py index 23e76bb92..32380ece0 100644 --- a/tests/test_group_size.py +++ b/tests/test_group_size.py @@ -6,6 +6,7 @@ # -- do not touch import os + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # -- end do not touch import logging # noqa: E402 @@ -13,6 +14,9 @@ import traceback # noqa: E402 import unittest # noqa: E402 +from lm_eval.utils import make_table # noqa: E402 +from transformers import AutoTokenizer # noqa: E402 + from gptqmodel import BACKEND, GPTQModel, QuantizeConfig # noqa: E402 from gptqmodel.nn_modules.qlinear.bitblas import BitBLASQuantLinear # noqa: E402 from gptqmodel.nn_modules.qlinear.exllama import ExllamaQuantLinear # noqa: E402 @@ -21,8 +25,7 @@ from gptqmodel.nn_modules.qlinear.torch import TorchQuantLinear # noqa: E402 from gptqmodel.nn_modules.qlinear.tritonv2 import TritonV2QuantLinear # noqa: E402 from gptqmodel.utils.eval import EVAL # noqa: E402 -from lm_eval.utils import make_table # noqa: E402 -from transformers import AutoTokenizer # noqa: E402 + logger = logging.getLogger(__name__) diff --git a/tests/test_inference_result_xpu.py b/tests/test_inference_result_xpu.py index 870897023..3498ea45c 100644 --- a/tests/test_inference_result_xpu.py +++ b/tests/test_inference_result_xpu.py @@ -5,14 +5,16 @@ import os + os.environ["CUDA_VISIBLE_DEVICES"] = "" import tempfile -from gptqmodel import BACKEND, GPTQModel, QuantizeConfig -from gptqmodel.models._const import DEVICE from models.model_test import ModelTest from parameterized import parameterized +from gptqmodel import BACKEND, GPTQModel, QuantizeConfig +from gptqmodel.models._const import DEVICE + class TestInferenceResultXPU(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/Llama-3.2-1B-Instruct" diff --git a/tests/test_inference_speed.py b/tests/test_inference_speed.py index 64bca6f6e..fe76c2f47 100644 --- a/tests/test_inference_speed.py +++ b/tests/test_inference_speed.py @@ -6,12 +6,15 @@ # -- do not touch import os + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" -from gptqmodel.utils import BACKEND # noqa: E402 # -- end do not touch from inference_speed import InferenceSpeed # noqa: E402 from parameterized import parameterized # noqa: E402 +from gptqmodel.utils import BACKEND # noqa: E402 + + ''' NATIVE_MODEL_ID = /monster/data/model/Llama-3.2-1B-Instruct-gptqmodel-4bit-vortext-v1 BITBLAS_NATIVE_MODEL_ID = /monster/data/model/opt-125M-autoround-lm_head-false-symTrue diff --git a/tests/test_inference_speed_torch_fused.py b/tests/test_inference_speed_torch_fused.py index 609c612ff..a038553c6 100644 --- a/tests/test_inference_speed_torch_fused.py +++ b/tests/test_inference_speed_torch_fused.py @@ -6,13 +6,15 @@ # -- do not touch import os + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # -- end do not touch -from gptqmodel.utils import BACKEND # noqa: E402 from inference_speed import InferenceSpeed # noqa: E402 from parameterized import parameterized # noqa: E402 +from gptqmodel.utils import BACKEND # noqa: E402 + class TestInferenceSpeedTorchFused(InferenceSpeed): @parameterized.expand( diff --git a/tests/test_integration.py b/tests/test_integration.py index f17036b3f..03fba31ac 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -4,13 +4,15 @@ # Contact: qubitium@modelcloud.ai, x.com/qubitium import os + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" import tempfile # noqa: E402 import unittest # noqa: E402 -from gptqmodel.utils.torch import torch_empty_cache # noqa: E402 from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig # noqa: E402 +from gptqmodel.utils.torch import torch_empty_cache # noqa: E402 + class TestIntegration(unittest.TestCase): INFERENCE_PROMPT = "Which city is the capital of France? The city name is " diff --git a/tests/test_ipex_xpu.py b/tests/test_ipex_xpu.py index e21c56175..10993de49 100644 --- a/tests/test_ipex_xpu.py +++ b/tests/test_ipex_xpu.py @@ -6,14 +6,16 @@ # -- do not touch import os + os.environ["CUDA_VISIBLE_DEVICES"] = "" # -- end do not touch import tempfile # noqa: E402 +from models.model_test import ModelTest # noqa: E402 + from gptqmodel import BACKEND, GPTQModel, QuantizeConfig # noqa: E402 from gptqmodel.models._const import DEVICE # noqa: E402 -from models.model_test import ModelTest # noqa: E402 class TestsTorchFused(ModelTest): diff --git a/tests/test_kernel_output.py b/tests/test_kernel_output.py index e4e4ceb9c..c4b4652c1 100644 --- a/tests/test_kernel_output.py +++ b/tests/test_kernel_output.py @@ -1,6 +1,10 @@ import unittest import torch +from logbar import LogBar +from parameterized import parameterized +from torch import Tensor + from gptqmodel import BACKEND, GPTQModel from gptqmodel.adapter.adapter import Adapter, AdapterCache, Lora from gptqmodel.nn_modules.qlinear.exllama_eora import ExllamaEoraQuantLinear @@ -9,9 +13,7 @@ from gptqmodel.nn_modules.qlinear.torch import TorchQuantLinear from gptqmodel.nn_modules.qlinear.tritonv2 import TritonV2QuantLinear from gptqmodel.utils.model import find_modules -from logbar import LogBar -from parameterized import parameterized -from torch import Tensor + log = LogBar.shared() diff --git a/tests/test_kernel_output_torch_fused.py b/tests/test_kernel_output_torch_fused.py index c4a0e460b..bd38e4c37 100644 --- a/tests/test_kernel_output_torch_fused.py +++ b/tests/test_kernel_output_torch_fused.py @@ -1,13 +1,15 @@ import unittest import torch +from logbar import LogBar +from parameterized import parameterized +from torch import Tensor + from gptqmodel import BACKEND, GPTQModel from gptqmodel.nn_modules.qlinear.torch import TorchQuantLinear from gptqmodel.nn_modules.qlinear.torch_fused import TorchFusedQuantLinear from gptqmodel.utils.model import find_modules -from logbar import LogBar -from parameterized import parameterized -from torch import Tensor + log = LogBar.shared() diff --git a/tests/test_linalg.py b/tests/test_linalg.py index 2a639c772..31fec6585 100644 --- a/tests/test_linalg.py +++ b/tests/test_linalg.py @@ -6,6 +6,7 @@ # -- do not touch import os + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # -- end do not touch import unittest # noqa: E402 @@ -14,6 +15,7 @@ from logbar import LogBar # noqa: E402 from parameterized import parameterized # noqa: E402 + log = LogBar.shared() ROCM = torch.device("cuda:0") # fake cuda diff --git a/tests/test_lm_eval.py b/tests/test_lm_eval.py index c0b810ced..2e0653621 100644 --- a/tests/test_lm_eval.py +++ b/tests/test_lm_eval.py @@ -6,14 +6,16 @@ # -- do not touch import os + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # -- end do not touch import tempfile # noqa: E402 import unittest # noqa: E402 +from lm_eval.utils import make_table # noqa: E402 + from gptqmodel import BACKEND, GPTQModel from gptqmodel.utils.eval import EVAL # noqa: E402 -from lm_eval.utils import make_table # noqa: E402 class TestLmEval(unittest.TestCase): diff --git a/tests/test_lm_head.py b/tests/test_lm_head.py index db60b3c51..613a9f4ac 100644 --- a/tests/test_lm_head.py +++ b/tests/test_lm_head.py @@ -9,12 +9,14 @@ from datasets import load_dataset + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" -from gptqmodel import GPTQModel, QuantizeConfig # noqa: E402 -from gptqmodel.nn_modules.qlinear import BaseQuantLinear # noqa: E402 # -- end do not touch from models.model_test import ModelTest # noqa: E402 +from gptqmodel import GPTQModel, QuantizeConfig # noqa: E402 +from gptqmodel.nn_modules.qlinear import BaseQuantLinear # noqa: E402 + class TestLmHeadLoad(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/TinyLlama-1.1B-intermediate-step-1341k-3T-autoround-lm_head-symFalse" # "LnL-AI/TinyLlama-1.1B-intermediate-step-1341k-3T-autoround-lm_head-symFalse" diff --git a/tests/test_lora.py b/tests/test_lora.py index 6811fcfba..36dd8c07a 100644 --- a/tests/test_lora.py +++ b/tests/test_lora.py @@ -16,14 +16,16 @@ # -- do not touch import os + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # -- end do not touch -from gptqmodel import BACKEND, GPTQModel # noqa: E402 -from gptqmodel.adapter.adapter import Lora # noqa: E402 from models.model_test import ModelTest # noqa: E402 from parameterized import parameterized # noqa: E402 +from gptqmodel import BACKEND, GPTQModel # noqa: E402 +from gptqmodel.adapter.adapter import Lora # noqa: E402 + class Test(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/sliuau-llama3.2-1b-4bit-group128" diff --git a/tests/test_mlx.py b/tests/test_mlx.py index f854aa67d..5dfe1deb2 100644 --- a/tests/test_mlx.py +++ b/tests/test_mlx.py @@ -1,6 +1,7 @@ import os import sys + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" if sys.platform == "darwin": @@ -8,11 +9,12 @@ import tempfile # noqa: E402 -from gptqmodel import GPTQModel # noqa: E402 from mlx_lm import generate, load # noqa: E402 from models.model_test import ModelTest # noqa: E402 from transformers import AutoTokenizer # noqa: E402 +from gptqmodel import GPTQModel # noqa: E402 + class TestExport(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/Qwen2.5-0.5B-Instruct/gptq_4bits_01-07_14-18-11_maxlen1024_ns1024_descFalse_damp0.1/" diff --git a/tests/test_mlx_generate.py b/tests/test_mlx_generate.py index f3484bfe1..f8581101b 100644 --- a/tests/test_mlx_generate.py +++ b/tests/test_mlx_generate.py @@ -1,14 +1,17 @@ import os + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" import sys # noqa: E402 + if sys.platform == "darwin": os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1" -from gptqmodel import BACKEND, GPTQModel # noqa: E402 from models.model_test import ModelTest # noqa: E402 +from gptqmodel import BACKEND, GPTQModel # noqa: E402 + class TestMlxGenerate(ModelTest): @classmethod diff --git a/tests/test_mmlupro.py b/tests/test_mmlupro.py index adacf66d3..9c0aaaf7d 100644 --- a/tests/test_mmlupro.py +++ b/tests/test_mmlupro.py @@ -9,6 +9,7 @@ from gptqmodel import GPTQModel from gptqmodel.utils.eval import EVAL + # os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # os.environ["CUDA_VISIBLE_DEVICES"] = "7" diff --git a/tests/test_modelscope.py b/tests/test_modelscope.py index 22fcf2663..7214a86b4 100644 --- a/tests/test_modelscope.py +++ b/tests/test_modelscope.py @@ -1,9 +1,11 @@ import os + os.environ["GPTQMODEL_USE_MODELSCOPE"] = "True" -from gptqmodel import GPTQModel # noqa: E402 from models.model_test import ModelTest # noqa: E402 +from gptqmodel import GPTQModel # noqa: E402 + class TestLoadModelscope(ModelTest): diff --git a/tests/test_multi_gpu_inference.py b/tests/test_multi_gpu_inference.py index a779ceae7..c29cef520 100644 --- a/tests/test_multi_gpu_inference.py +++ b/tests/test_multi_gpu_inference.py @@ -9,6 +9,7 @@ import torch from transformers import AutoTokenizer + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # -- end do not touch diff --git a/tests/test_olora_finetuning_xpu.py b/tests/test_olora_finetuning_xpu.py index 6423c2db5..755fdebf5 100644 --- a/tests/test_olora_finetuning_xpu.py +++ b/tests/test_olora_finetuning_xpu.py @@ -17,6 +17,7 @@ from logbar import LogBar + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" import tempfile # noqa: E402 @@ -26,11 +27,19 @@ import torch # noqa: E402 import transformers # noqa: E402 from datasets import load_dataset # noqa: E402 -from gptqmodel import BACKEND # noqa: E402 from peft import AdaLoraConfig, get_peft_model # noqa: E402 from tokenicer import Tokenicer # noqa: E402 -from transformers import (AutoModelForCausalLM, GPTQConfig, TrainerCallback, # noqa: E402 - TrainerControl, TrainerState, TrainingArguments) +from transformers import ( # noqa: E402 + AutoModelForCausalLM, + GPTQConfig, + TrainerCallback, + TrainerControl, + TrainerState, + TrainingArguments, +) + +from gptqmodel import BACKEND # noqa: E402 + DEVICE = torch.device("cuda:0") diff --git a/tests/test_openai_server.py b/tests/test_openai_server.py index a7ada7d77..03a731c3f 100644 --- a/tests/test_openai_server.py +++ b/tests/test_openai_server.py @@ -7,8 +7,10 @@ import unittest import openai + from gptqmodel import GPTQModel + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" class TestOpeniServer(unittest.TestCase): diff --git a/tests/test_p2p.py b/tests/test_p2p.py index 1c1f318ac..83d402090 100644 --- a/tests/test_p2p.py +++ b/tests/test_p2p.py @@ -5,6 +5,7 @@ import torch + def main(): if not torch.cuda.is_available(): print("CUDA not available") diff --git a/tests/test_packable.py b/tests/test_packable.py index ef68a8849..c70aac53a 100644 --- a/tests/test_packable.py +++ b/tests/test_packable.py @@ -3,6 +3,10 @@ from typing import Dict import torch +from logbar import LogBar +from parameterized import parameterized +from safetensors.torch import load_file + from gptqmodel import BACKEND, GPTQModel from gptqmodel.nn_modules.qlinear.exllama import ExllamaQuantLinear # noqa: E402 from gptqmodel.nn_modules.qlinear.exllama_eora import ExllamaEoraQuantLinear @@ -12,9 +16,7 @@ from gptqmodel.nn_modules.qlinear.torch_fused import TorchFusedQuantLinear from gptqmodel.nn_modules.qlinear.tritonv2 import TritonV2QuantLinear # noqa: E402 from gptqmodel.utils.model import convert_gptq_v2_to_v1_format, find_modules -from logbar import LogBar -from parameterized import parameterized -from safetensors.torch import load_file + log = LogBar.shared() diff --git a/tests/test_packing.py b/tests/test_packing.py index 913e4b94b..8823a9dd4 100644 --- a/tests/test_packing.py +++ b/tests/test_packing.py @@ -6,18 +6,21 @@ # -- do not touch import os + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # -- end do not touch import unittest # noqa: E402 + # isort: off import torch # noqa: E402 import torch.nn as nn # noqa: E402 # isort: on +from parameterized import parameterized # noqa: E402 + from gptqmodel import BACKEND # noqa: E402 from gptqmodel.nn_modules.qlinear.torch import TorchQuantLinear # noqa: E402 -from parameterized import parameterized # noqa: E402 def gen_quant(k: int, n: int, groupsize: int, bits: int): diff --git a/tests/test_packing_speed.py b/tests/test_packing_speed.py index c4cb061b6..6ba824c48 100644 --- a/tests/test_packing_speed.py +++ b/tests/test_packing_speed.py @@ -8,6 +8,7 @@ from gptqmodel import BACKEND + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # -- end do not touch @@ -17,6 +18,7 @@ import threadpoolctl # noqa: E402 from parameterized import parameterized # noqa: E402 + # isort: off import torch # noqa: E402 import torch.nn as nn # noqa: E402 diff --git a/tests/test_parameter_count.py b/tests/test_parameter_count.py index f74d1996f..2e9bbe742 100644 --- a/tests/test_parameter_count.py +++ b/tests/test_parameter_count.py @@ -2,11 +2,12 @@ import tempfile import torch.cuda -from gptqmodel import GPTQModel, QuantizeConfig -from gptqmodel.utils.tensor import tensor_parameters from models.model_test import ModelTest from safetensors.torch import load_file +from gptqmodel import GPTQModel, QuantizeConfig +from gptqmodel.utils.tensor import tensor_parameters + class TestsParameterCount(ModelTest): LLAMA_3_2_1B_PARAMETER_COUNT = 1235814400 @@ -19,11 +20,12 @@ class TestsParameterCount(ModelTest): def test_parameter_count(self): import os.path - from gptqmodel import QuantizeConfig - from gptqmodel.utils.tensor import tensor_parameters from huggingface_hub import hf_hub_download from safetensors.torch import load_file + from gptqmodel import QuantizeConfig + from gptqmodel.utils.tensor import tensor_parameters + model_id = "/monster/data/model/Llama-3.2-1B-Instruct-gptqmodel-4bit-vortex-v1" if os.path.isdir(model_id): file_path = os.path.join(model_id, "model.safetensors") diff --git a/tests/test_perplexity.py b/tests/test_perplexity.py index 11ff0f7b9..66acd7968 100644 --- a/tests/test_perplexity.py +++ b/tests/test_perplexity.py @@ -7,6 +7,7 @@ import os import time + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # -- end do not touch @@ -14,13 +15,14 @@ import unittest # noqa: E402 from datasets import load_dataset # noqa: E402 +from parameterized import parameterized # noqa: E402 +from transformers import AutoModelForCausalLM, AutoTokenizer # noqa: E402 + from gptqmodel import BACKEND, GPTQModel # noqa: E402 from gptqmodel.quantization.config import FORMAT, METHOD, QuantizeConfig # noqa: E402 from gptqmodel.utils.perplexity import Perplexity # noqa: E402 from gptqmodel.utils.rocm import IS_ROCM # noqa: E402 from gptqmodel.utils.torch import torch_empty_cache # noqa: E402 -from parameterized import parameterized # noqa: E402 -from transformers import AutoModelForCausalLM, AutoTokenizer # noqa: E402 class TestPerplexity(unittest.TestCase): diff --git a/tests/test_post_quant_eora.py b/tests/test_post_quant_eora.py index 422f19d95..6c18300ab 100644 --- a/tests/test_post_quant_eora.py +++ b/tests/test_post_quant_eora.py @@ -16,6 +16,7 @@ # -- do not touch import os + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # -- end do not touch @@ -23,11 +24,12 @@ from typing import Optional # noqa: E402 from datasets import load_dataset +from models.model_test import ModelTest # noqa: E402 + from gptqmodel import BACKEND, GPTQModel # noqa: E402 from gptqmodel.adapter.adapter import Lora # noqa: E402 from gptqmodel.utils.eval import EVAL # noqa: E402 from gptqmodel.utils.torch import torch_empty_cache # noqa: E402 -from models.model_test import ModelTest # noqa: E402 def bench(path: str, backend: BACKEND, adapter: Optional[Lora]): diff --git a/tests/test_q4_bitblas.py b/tests/test_q4_bitblas.py index d75d02408..329ebcc81 100644 --- a/tests/test_q4_bitblas.py +++ b/tests/test_q4_bitblas.py @@ -6,15 +6,17 @@ # -- do not touch import os + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # -- end do not touch import unittest # noqa: E402 import torch # noqa: E402 +from transformers import AutoTokenizer # noqa: E402 + from gptqmodel import BACKEND, GPTQModel # noqa: E402 from gptqmodel.nn_modules.qlinear.bitblas import BitBLASQuantLinear # noqa: E402 -from transformers import AutoTokenizer # noqa: E402 class TestQ4BitBLAS(unittest.TestCase): diff --git a/tests/test_q4_exllama_v1.py b/tests/test_q4_exllama_v1.py index f61866d50..de265b67f 100644 --- a/tests/test_q4_exllama_v1.py +++ b/tests/test_q4_exllama_v1.py @@ -6,20 +6,23 @@ # -- do not touch import os + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # -- end do not touch import torch # noqa: E402 +from gptqmodel_exllama_kernels import prepare_buffers, set_tuning_params # noqa: E402 +from models.model_test import ModelTest # noqa: E402 +from transformers import AutoTokenizer # noqa: E402 + from gptqmodel import BACKEND, GPTQModel, exllama_set_max_input_length # noqa: E402 from gptqmodel.models._const import EXLLAMA_DEFAULT_MAX_INPUT_LENGTH # noqa: E402 from gptqmodel.nn_modules.qlinear.exllama import ExllamaQuantLinear # noqa: E402 from gptqmodel.quantization import FORMAT, METHOD # noqa: E402 from gptqmodel.utils.importer import select_quant_linear # noqa: E402 from gptqmodel.utils.model import gptqmodel_post_init # noqa: E402 -from gptqmodel_exllama_kernels import prepare_buffers, set_tuning_params # noqa: E402 -from models.model_test import ModelTest # noqa: E402 -from transformers import AutoTokenizer # noqa: E402 + REFERENCE = torch.Tensor( [ diff --git a/tests/test_q4_exllama_v2.py b/tests/test_q4_exllama_v2.py index 458b7419d..4218e787b 100644 --- a/tests/test_q4_exllama_v2.py +++ b/tests/test_q4_exllama_v2.py @@ -6,19 +6,22 @@ # -- do not touch import os + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # -- end do not touch import unittest # noqa: E402 import torch # noqa: E402 +from test_q4_exllama_v1 import REFERENCE, get_diff # noqa: E402 +from transformers import AutoTokenizer # noqa: E402 + from gptqmodel import BACKEND, GPTQModel # noqa: E402 from gptqmodel.nn_modules.qlinear.exllamav2 import ExllamaV2QuantLinear # noqa: E402 from gptqmodel.quantization import FORMAT, METHOD # noqa: E402 from gptqmodel.utils.importer import select_quant_linear # noqa: E402 from gptqmodel.utils.model import gptqmodel_post_init # noqa: E402 -from test_q4_exllama_v1 import REFERENCE, get_diff # noqa: E402 -from transformers import AutoTokenizer # noqa: E402 + GENERATE_EVAL_SIZE = 100 diff --git a/tests/test_q4_ipex.py b/tests/test_q4_ipex.py index 01eb6b27e..e8b0422c5 100644 --- a/tests/test_q4_ipex.py +++ b/tests/test_q4_ipex.py @@ -6,13 +6,15 @@ # -- do not touch import os + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # -- end do not touch import torch # noqa: E402 -from gptqmodel import BACKEND # noqa: E402 from models.model_test import ModelTest # noqa: E402 +from gptqmodel import BACKEND # noqa: E402 + class TestsTorchFused(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit" # "bigscience/bloom-560m" diff --git a/tests/test_q4_marlin.py b/tests/test_q4_marlin.py index 69950ae77..f35efd7c8 100644 --- a/tests/test_q4_marlin.py +++ b/tests/test_q4_marlin.py @@ -6,16 +6,18 @@ # -- do not touch import os + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # -- end do not touch import torch # noqa: E402 -from gptqmodel import BACKEND, GPTQModel # noqa: E402 -from gptqmodel.nn_modules.qlinear.marlin import MarlinQuantLinear # noqa: E402 from models.model_test import ModelTest # noqa: E402 from parameterized import parameterized # noqa: E402 from transformers import AutoTokenizer # noqa: E402 +from gptqmodel import BACKEND, GPTQModel # noqa: E402 +from gptqmodel.nn_modules.qlinear.marlin import MarlinQuantLinear # noqa: E402 + class TestQ4Marlin(ModelTest): diff --git a/tests/test_q4_torch.py b/tests/test_q4_torch.py index d2a59c3cc..f67e51a9f 100644 --- a/tests/test_q4_torch.py +++ b/tests/test_q4_torch.py @@ -6,14 +6,16 @@ # -- do not touch import os + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # -- end do not touch import torch # noqa: E402 -from gptqmodel import BACKEND, GPTQModel # noqa: E402 from models.model_test import ModelTest # noqa: E402 from parameterized import parameterized # noqa: E402 +from gptqmodel import BACKEND, GPTQModel # noqa: E402 + class TestsQ4Torch(ModelTest): GENERATE_EVAL_SIZE_MIN = 20 diff --git a/tests/test_q4_torch_apple.py b/tests/test_q4_torch_apple.py index 621a8d310..70798b715 100644 --- a/tests/test_q4_torch_apple.py +++ b/tests/test_q4_torch_apple.py @@ -6,11 +6,12 @@ import sys # noqa: E402 import torch # noqa: E402 -from gptqmodel import BACKEND, GPTQModel # noqa: E402 from models.model_test import ModelTest # noqa: E402 from parameterized import parameterized # noqa: E402 from transformers import AutoTokenizer # noqa: E402 +from gptqmodel import BACKEND, GPTQModel # noqa: E402 + class TestsQ4Torch(ModelTest): GENERATE_EVAL_SIZE_MIN = 5 diff --git a/tests/test_q4_triton.py b/tests/test_q4_triton.py index b0da451dd..d19e141ab 100644 --- a/tests/test_q4_triton.py +++ b/tests/test_q4_triton.py @@ -6,15 +6,17 @@ # -- do not touch import os + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # -- end do not touch import torch # noqa: E402 -from gptqmodel import BACKEND, GPTQModel # noqa: E402 -from gptqmodel.nn_modules.qlinear.tritonv2 import TritonV2QuantLinear # noqa: E402 from models.model_test import ModelTest # noqa: E402 from transformers import AutoTokenizer # noqa: E402 +from gptqmodel import BACKEND, GPTQModel # noqa: E402 +from gptqmodel.nn_modules.qlinear.tritonv2 import TritonV2QuantLinear # noqa: E402 + class TestsQ4Triton(ModelTest): model_id = "/monster/data/model/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit" diff --git a/tests/test_qqq.py b/tests/test_qqq.py index 660b794e8..e5ce63448 100644 --- a/tests/test_qqq.py +++ b/tests/test_qqq.py @@ -6,17 +6,21 @@ import unittest from datasets import load_dataset +from parameterized import parameterized +from transformers import AutoTokenizer + from gptqmodel.nn_modules.qlinear.qqq import QQQQuantLinear from gptqmodel.quantization import FORMAT, METHOD, QUANT_CONFIG_FILENAME from gptqmodel.utils.torch import torch_empty_cache -from parameterized import parameterized -from transformers import AutoTokenizer + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" -from gptqmodel import BACKEND, GPTQModel, QuantizeConfig # noqa: E402 # -- end do not touch from logbar import LogBar +from gptqmodel import BACKEND, GPTQModel, QuantizeConfig # noqa: E402 + + log = LogBar.shared() class TestGroupSize(unittest.TestCase): diff --git a/tests/test_qqq_inference.py b/tests/test_qqq_inference.py index 4bf5e1870..013aa055b 100644 --- a/tests/test_qqq_inference.py +++ b/tests/test_qqq_inference.py @@ -1,6 +1,7 @@ from gptqmodel import GPTQModel from gptqmodel.utils.eval import EVAL + eval_results = GPTQModel.eval("HandH1998/QQQ-Llama-3-8b-g128", framework=EVAL.LM_EVAL, tasks=[EVAL.LM_EVAL.ARC_CHALLENGE]) diff --git a/tests/test_quant_and_eora.py b/tests/test_quant_and_eora.py index a779880bd..5f79e74ff 100644 --- a/tests/test_quant_and_eora.py +++ b/tests/test_quant_and_eora.py @@ -16,6 +16,7 @@ # -- do not touch import os + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # -- end do not touch @@ -23,14 +24,15 @@ from typing import Optional # noqa: E402 from datasets import load_dataset # noqa: E402 +from lm_eval.utils import make_table # noqa: E402 +from models.model_test import ModelTest # noqa: E402 +from parameterized import parameterized # noqa: E402 + from gptqmodel import BACKEND, GPTQModel, QuantizeConfig # noqa: E402 from gptqmodel.adapter.adapter import Lora # noqa: E402 from gptqmodel.quantization import FORMAT, METHOD # noqa: E402 from gptqmodel.utils.eval import EVAL # noqa: E402 from gptqmodel.utils.torch import torch_empty_cache # noqa: E402 -from lm_eval.utils import make_table # noqa: E402 -from models.model_test import ModelTest # noqa: E402 -from parameterized import parameterized # noqa: E402 class Test(ModelTest): diff --git a/tests/test_quant_and_eora_transformers.py b/tests/test_quant_and_eora_transformers.py index a21ec57ba..a35a341b9 100644 --- a/tests/test_quant_and_eora_transformers.py +++ b/tests/test_quant_and_eora_transformers.py @@ -21,6 +21,7 @@ from safetensors.torch import load_file from transformers import AutoModelForCausalLM, AutoTokenizer + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # -- end do not touch @@ -28,15 +29,17 @@ from typing import Optional # noqa: E402 from datasets import load_dataset # noqa: E402 -from gptqmodel import BACKEND, GPTQModel, QuantizeConfig # noqa: E402 -from gptqmodel.adapter.adapter import HF_ADAPTER_FILE_NAME, HF_ADAPTER_WEIGHT_KEY_PREFIX, Lora # noqa: E402 -from gptqmodel.utils.eval import EVAL # noqa: E402 -from gptqmodel.utils.torch import torch_empty_cache # noqa: E402 from lm_eval.utils import make_table # noqa: E402 from logbar import LogBar from models.model_test import ModelTest # noqa: E402 from tabulate import tabulate # noqa: E402 +from gptqmodel import BACKEND, GPTQModel, QuantizeConfig # noqa: E402 +from gptqmodel.adapter.adapter import HF_ADAPTER_FILE_NAME, HF_ADAPTER_WEIGHT_KEY_PREFIX, Lora # noqa: E402 +from gptqmodel.utils.eval import EVAL # noqa: E402 +from gptqmodel.utils.torch import torch_empty_cache # noqa: E402 + + log = LogBar.shared() diff --git a/tests/test_quant_batch.py b/tests/test_quant_batch.py index 92440c1b1..699b3fb17 100644 --- a/tests/test_quant_batch.py +++ b/tests/test_quant_batch.py @@ -6,16 +6,18 @@ # -- do not touch import os + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # -- end do not touch import tempfile # noqa: E402 +from models.model_test import ModelTest # noqa: E402 +from transformers import AutoTokenizer # noqa: E402 + from gptqmodel import GPTQModel # noqa: E402 from gptqmodel.quantization import QuantizeConfig # noqa: E402 from gptqmodel.utils.perplexity import Perplexity # noqa: E402 -from models.model_test import ModelTest # noqa: E402 -from transformers import AutoTokenizer # noqa: E402 class TestQuantBatch(ModelTest): diff --git a/tests/test_quant_formats.py b/tests/test_quant_formats.py index 82526f84a..2cf181571 100644 --- a/tests/test_quant_formats.py +++ b/tests/test_quant_formats.py @@ -6,6 +6,7 @@ # -- do not touch import os + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" @@ -16,13 +17,14 @@ import tempfile # noqa: E402 from datasets import load_dataset # noqa: E402 +from models.model_test import ModelTest # noqa: E402 +from parameterized import parameterized # noqa: E402 +from transformers import AutoTokenizer # noqa: E402 + from gptqmodel import BACKEND, GPTQModel, __version__, get_best_device # noqa: E402 from gptqmodel.quantization import FORMAT, METHOD, QUANT_CONFIG_FILENAME # noqa: E402 from gptqmodel.quantization.config import META_FIELD_QUANTIZER, META_QUANTIZER_GPTQMODEL, QuantizeConfig # noqa: E402 from gptqmodel.utils.torch import torch_empty_cache # noqa: E402 -from models.model_test import ModelTest # noqa: E402 -from parameterized import parameterized # noqa: E402 -from transformers import AutoTokenizer # noqa: E402 class TestQuantization(ModelTest): diff --git a/tests/test_quant_time.py b/tests/test_quant_time.py index 9ec0316c5..e484e27f5 100644 --- a/tests/test_quant_time.py +++ b/tests/test_quant_time.py @@ -5,13 +5,15 @@ import os + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" import time # noqa: E402 +from models.model_test import ModelTest # noqa: E402 + from gptqmodel import GPTQModel # noqa: E402 from gptqmodel.quantization.config import QuantizeConfig # noqa: E402 -from models.model_test import ModelTest # noqa: E402 class TestQuantTime(ModelTest): diff --git a/tests/test_quant_trust_remote.py b/tests/test_quant_trust_remote.py index f6e65d76f..899eeba79 100644 --- a/tests/test_quant_trust_remote.py +++ b/tests/test_quant_trust_remote.py @@ -6,16 +6,18 @@ # -- do not touch import os + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # -- end do not touch import tempfile # noqa: E402 -from gptqmodel import GPTQModel # noqa: E402 -from gptqmodel.quantization import FORMAT, QuantizeConfig # noqa: E402 from models.model_test import ModelTest # noqa: E402 from transformers import AutoTokenizer # noqa: E402 +from gptqmodel import GPTQModel # noqa: E402 +from gptqmodel.quantization import FORMAT, QuantizeConfig # noqa: E402 + class TestQuantWithTrustRemoteTrue(ModelTest): @classmethod diff --git a/tests/test_save_loaded_quantized_model.py b/tests/test_save_loaded_quantized_model.py index ff2811399..7a6075df9 100644 --- a/tests/test_save_loaded_quantized_model.py +++ b/tests/test_save_loaded_quantized_model.py @@ -6,15 +6,18 @@ # -- do not touch import os + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # -- end do not touch import tempfile # noqa: E402 import unittest # noqa: E402 -from gptqmodel import BACKEND, GPTQModel, get_best_device # noqa: E402 from parameterized import parameterized # noqa: E402 from transformers import AutoTokenizer # noqa: E402 +from gptqmodel import BACKEND, GPTQModel, get_best_device # noqa: E402 + + MODEL_ID = "/monster/data/model/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit" class TestSave(unittest.TestCase): diff --git a/tests/test_save_loaded_quantized_model_ipex.py b/tests/test_save_loaded_quantized_model_ipex.py index 4c24ed620..ea4871af2 100644 --- a/tests/test_save_loaded_quantized_model_ipex.py +++ b/tests/test_save_loaded_quantized_model_ipex.py @@ -6,15 +6,18 @@ # -- do not touch import os + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # -- end do not touch import tempfile # noqa: E402 import unittest # noqa: E402 -from gptqmodel import BACKEND, GPTQModel, get_best_device # noqa: E402 from parameterized import parameterized # noqa: E402 from transformers import AutoTokenizer # noqa: E402 +from gptqmodel import BACKEND, GPTQModel, get_best_device # noqa: E402 + + MODEL_ID = "/monster/data/model/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit" class TestSave(unittest.TestCase): diff --git a/tests/test_serialization.py b/tests/test_serialization.py index 77ffc52bd..f7a8d25e9 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -6,6 +6,7 @@ # -- do not touch import os + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # -- end do not touch diff --git a/tests/test_sglang.py b/tests/test_sglang.py index b36acba84..169972468 100644 --- a/tests/test_sglang.py +++ b/tests/test_sglang.py @@ -5,14 +5,16 @@ import os + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # -- end do not touch import importlib.util # noqa: E402 -from gptqmodel import BACKEND, GPTQModel # noqa: E402 from models.model_test import ModelTest # noqa: E402 +from gptqmodel import BACKEND, GPTQModel # noqa: E402 + class TestLoadSglang(ModelTest): diff --git a/tests/test_sharded.py b/tests/test_sharded.py index f6c19185b..9b2a00bd3 100644 --- a/tests/test_sharded.py +++ b/tests/test_sharded.py @@ -6,6 +6,7 @@ # -- do not touch import os + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # -- end do not touch @@ -14,9 +15,10 @@ import unittest # noqa: E402 import torch # noqa: E402 -from gptqmodel import GPTQModel # noqa: E402 from transformers import AutoTokenizer # noqa: E402 +from gptqmodel import GPTQModel # noqa: E402 + class TestSharded(unittest.TestCase): MODEL_ID = "/monster/data/model/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit" diff --git a/tests/test_simple_quant.py b/tests/test_simple_quant.py index e783ebcd6..202e9048b 100644 --- a/tests/test_simple_quant.py +++ b/tests/test_simple_quant.py @@ -1,10 +1,12 @@ import tempfile from datasets import load_dataset +from logbar import LogBar + from gptqmodel import GPTQModel, QuantizeConfig from gptqmodel.quantization import FORMAT from gptqmodel.utils.eval import EVAL -from logbar import LogBar + log = LogBar.shared() @@ -70,6 +72,7 @@ def get_calib_data(tokenizer, rows: int): # eval from lm_eval.utils import make_table + with tempfile.TemporaryDirectory() as tmp_dir: results = GPTQModel.eval( QUANT_SAVE_PATH, diff --git a/tests/test_tgi.py b/tests/test_tgi.py index 0e5099325..c20a23a38 100644 --- a/tests/test_tgi.py +++ b/tests/test_tgi.py @@ -6,6 +6,7 @@ # -- do not touch import os + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # -- end do not touch import json # noqa: E402 diff --git a/tests/test_threadx.py b/tests/test_threadx.py new file mode 100644 index 000000000..c2bfa2b27 --- /dev/null +++ b/tests/test_threadx.py @@ -0,0 +1,296 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +import threading +import time +from concurrent.futures import ThreadPoolExecutor + +import pytest +import torch +import torch.nn as nn + +from gptqmodel.utils.threadx import DeviceThreadPool + + +pytestmark = [ + pytest.mark.cuda, + pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required"), +] + + +def require_n_gpus(n): + if torch.cuda.device_count() < n: + pytest.skip(f"requires >= {n} CUDA devices") + + +# ----------------------------- Fixtures ----------------------------- + +@pytest.fixture(scope="module") +def devices_two(): + require_n_gpus(2) + return [torch.device("cuda", 0), torch.device("cuda", 1)] + + +@pytest.fixture() +def pool(devices_two): + # small threshold for fast tests + p = DeviceThreadPool(devices=devices_two, inference_mode=True, empty_cache_every_n=3) + yield p + p.shutdown(wait=True) + + +# ----------------------------- Helpers ------------------------------ + +def _sleep_kernel_ms(ms: int): + """Lightweight delay on device via cuda._sleep; fall back to host sleep.""" + cycles = int(ms * 1_000_000) + if hasattr(torch.cuda, "_sleep"): + torch.cuda._sleep(cycles) + else: + time.sleep(ms / 1000.0) + + +# ----------------------------- Core Tests --------------------------- + +def test_basic_submit_and_do(pool, devices_two): + d0 = devices_two[0] + + def add(a, b): + return a + b + + a = torch.randn(512, 512, device=d0) + b = torch.randn(512, 512, device=d0) + + out = pool.do(d0, add, a, b) + assert out.device == d0 + torch.testing.assert_close(out, a + b) + + fut = pool.submit(d0, add, a, b) + out2 = fut.result(timeout=5) + torch.testing.assert_close(out2, a + b) + + +def test_linear_forward(pool, devices_two): + d0 = devices_two[0] + m = nn.Linear(128, 64).to(d0) + x = torch.randn(32, 128, device=d0) + + def forward(module, inp): + return module(inp) + + y = pool.do(d0, forward, m, x) + assert y.shape == (32, 64) and y.device == d0 + + +def test_tensor_manipulation_and_minmax(pool, devices_two): + d0 = devices_two[0] + t = torch.randn(2048, device=d0) + + def stats(u): + return u.min().item(), u.max().item(), u.abs().mean().item() + + mn, mx, am = pool.do(d0, stats, t) + assert isinstance(mn, float) and isinstance(mx, float) and isinstance(am, float) + assert mn <= mx + + +def test_d2h_and_h2d_with_pinned_memory(pool, devices_two): + d0 = devices_two[0] + n = 1 << 19 + src = torch.randn(n, device=d0, dtype=torch.float32) + host = torch.empty(n, dtype=torch.float32, pin_memory=True) + + def d2h(a, h): + h.copy_(a, non_blocking=True) + torch.cuda.current_stream().synchronize() + return float(h[:10].sum().item()) + + s = pool.do(d0, d2h, src, host) + assert isinstance(s, float) + + def h2d(h): + b = torch.empty_like(src, device=d0) + b.copy_(h, non_blocking=True) + torch.cuda.current_stream().synchronize() + return b + + dst = pool.do(d0, h2d, host) + torch.testing.assert_close(dst, host.to(device=d0)) + + +def test_p2p_copy_between_devices(pool, devices_two): + d0, d1 = devices_two + n = 1 << 18 + a = torch.randn(n, device=d0, dtype=torch.float16) + + def p2p(x, target_dev): + return x.to(target_dev, non_blocking=True) + + b = pool.do(d0, p2p, a, d1) + assert b.device == d1 + torch.testing.assert_close(b.to(d0), a, atol=1e-3, rtol=1e-3) + + +def test_stream_context_is_honored(pool, devices_two): + d0 = devices_two[0] + stream = torch.cuda.Stream(device=d0) + expected_ptr = stream.cuda_stream + + def check_stream(expected): + cur = torch.cuda.current_stream() + return int(cur.cuda_stream == expected) + + ok = pool.do(d0, check_stream, expected_ptr, _cuda_stream=stream) + assert ok == 1 + + +def test_parallel_submissions_from_many_threads(pool, devices_two): + d0, d1 = devices_two + + def work_scale_add(scale, x, y): + return (x * scale) + y + + xs = [torch.randn(1024, device=d0) for _ in range(8)] + ys = [torch.randn(1024, device=d0) for _ in range(8)] + zs = [torch.randn(1024, device=d1) for _ in range(8)] + + def submit_task(i): + if i % 2 == 0: + return pool.do(d0, work_scale_add, 1.5, xs[i // 2], ys[i // 2]) + else: + return pool.do(d1, work_scale_add, 0.5, zs[i // 2], zs[i // 2]) + + with ThreadPoolExecutor(max_workers=8) as ex: + futs = [ex.submit(submit_task, i) for i in range(8)] + outs = [f.result(timeout=10) for f in futs] + + cnt0 = sum(1 for i in range(8) if i % 2 == 0) + cnt1 = 8 - cnt0 + assert len(outs) == 8 + assert sum(o.device == d0 for o in outs) == cnt0 + assert sum(o.device == d1 for o in outs) == cnt1 + + +def test_device_lock_blocks_only_that_device(pool, devices_two): + d0, d1 = devices_two + + started0 = threading.Event() + finished0 = threading.Event() + started1 = threading.Event() + finished1 = threading.Event() + + def long_op(mark_start: threading.Event, mark_done: threading.Event, ms=150): + mark_start.set() + _sleep_kernel_ms(ms) + mark_done.set() + return ms + + with pool.device_lock(d0): + fut0 = pool.submit(d0, long_op, started0, finished0, 100) + fut1 = pool.submit(d1, long_op, started1, finished1, 50) + + started1.wait(timeout=2) + assert started1.is_set() + finished1.wait(timeout=2) + assert finished1.is_set() + assert not started0.is_set() + + assert fut0.result(timeout=2) == 100 + assert started0.is_set() and finished0.is_set() + assert fut1.result(timeout=0.5) == 50 + + +def test_global_lock_blocks_all_devices(pool, devices_two): + d0, d1 = devices_two + + started = [threading.Event(), threading.Event()] + + def long_op(i, ms=100): + started[i].set() + _sleep_kernel_ms(ms) + return ms + + with pool.lock(): + f0 = pool.submit(d0, long_op, 0, 50) + f1 = pool.submit(d1, long_op, 1, 50) + time.sleep(0.05) + assert not started[0].is_set() + assert not started[1].is_set() + + assert f0.result(timeout=2) == 50 + assert f1.result(timeout=2) == 50 + assert started[0].is_set() and started[1].is_set() + + +# ---------------------- Counters + Janitor Tests --------------------- + +def test_counters_increment_and_global_sum(pool, devices_two): + d0, d1 = devices_two + + def noop(): + return 1 + + # run several tasks across both devices + for _ in range(4): + pool.do(d0, noop) + for _ in range(5): + pool.do(d1, noop) + + stats = pool.stats() + per = stats["per_device"] + total = stats["total"] + assert per[f"cuda:{d0.index}"] == 4 + assert per[f"cuda:{d1.index}"] == 5 + assert total == 9 + + +def test_janitor_triggers_empty_cache_every_n(pool, devices_two, monkeypatch): + """ + Set threshold small (3). After each device completes 3 tasks, we expect the janitor + to acquire the global lock and call empty_cache once per device. + We monkeypatch torch.cuda.empty_cache to record invocations and delay slightly so + we can observe the janitor pass. + """ + d0, d1 = devices_two + + calls = [] + in_gc = threading.Event() + + orig_empty = torch.cuda.empty_cache + + def spy_empty_cache(): + in_gc.set() + # record current device id + cur = torch.cuda.current_device() + calls.append(cur) + # small delay so we can see blocking effects if any + time.sleep(0.03) + return + + monkeypatch.setattr(torch.cuda, "empty_cache", spy_empty_cache) + + def noop(): + return 1 + + # Threshold is 3 => after 3 completions on d0, GC pass runs (both devices visited) + for _ in range(3): + pool.do(d0, noop) + + # Wait for janitor to start + in_gc.wait(timeout=2) + assert in_gc.is_set() + + # Give janitor time to run both device empties + t0 = time.time() + while len(calls) < 2 and time.time() - t0 < 2.0: + time.sleep(0.01) + + # We should see two calls, one per device (order not guaranteed) + assert len(calls) >= 2 + assert sorted(set(calls)) == sorted({d0.index, d1.index}) + + # Restore + monkeypatch.setattr(torch.cuda, "empty_cache", orig_empty) + diff --git a/tests/test_threadx_cpu.py b/tests/test_threadx_cpu.py new file mode 100644 index 000000000..95d242df5 --- /dev/null +++ b/tests/test_threadx_cpu.py @@ -0,0 +1,160 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +import threading +import time + +import pytest +import torch +import torch.nn as nn + +from gptqmodel.utils.threadx import DeviceThreadPool + + +pytestmark = [ + pytest.mark.cuda, + pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required"), +] + +def require_n_gpus(n): + if torch.cuda.device_count() < n: + pytest.skip(f"requires >= {n} CUDA devices") + +# ----------------------------- Existing GPU fixtures/tests omitted for brevity ----------------------------- +# (Keep your previously provided tests as-is.) + +# ----------------------------- New CPU-inclusive fixtures ----------------------------- + +@pytest.fixture(scope="module") +def devices_two_plus_cpu(): + require_n_gpus(2) + return [torch.device("cuda", 0), torch.device("cuda", 1), torch.device("cpu")] + +@pytest.fixture() +def pool_cpu(devices_two_plus_cpu): + # small threshold so we can easily trip GC on GPUs; CPU won't trigger GC itself + p = DeviceThreadPool(devices=devices_two_plus_cpu, inference_mode=True, empty_cache_every_n=3) + yield p + p.shutdown(wait=True) + +# ----------------------------- CPU tests ----------------------------- + +def test_cpu_worker_basic(pool_cpu): + d_cpu = torch.device("cpu") + + def add(a, b): + return a + b + + a = torch.randn(256, 256, device=d_cpu) + b = torch.randn(256, 256, device=d_cpu) + out = pool_cpu.do(d_cpu, add, a, b) + assert out.device.type == "cpu" + torch.testing.assert_close(out, a + b) + +def test_cpu_linear_forward(pool_cpu): + d_cpu = torch.device("cpu") + m = nn.Linear(128, 64) + x = torch.randn(32, 128, device=d_cpu) + + def forward(module, inp): + return module(inp) + + y = pool_cpu.do(d_cpu, forward, m, x) + assert y.shape == (32, 64) and y.device.type == "cpu" + +def test_cpu_device_lock_blocks_only_cpu(pool_cpu): + d_cpu = torch.device("cpu") + d0 = torch.device("cuda", 0) + + started_cpu = threading.Event() + finished_cpu = threading.Event() + started_gpu = threading.Event() + finished_gpu = threading.Event() + + def long_cpu(mark_start, mark_done, ms=150): + mark_start.set() + time.sleep(ms / 1000.0) + mark_done.set() + return ms + + def long_gpu(mark_start, mark_done, ms=100): + mark_start.set() + # use a tiny CUDA sleep to do “real” device work + if hasattr(torch.cuda, "_sleep"): + torch.cuda._sleep(ms * 1000_000) + else: + time.sleep(ms / 1000.0) + mark_done.set() + return ms + + with pool_cpu.device_lock(d_cpu): + f_cpu = pool_cpu.submit(d_cpu, long_cpu, started_cpu, finished_cpu, 100) + f_gpu = pool_cpu.submit(d0, long_gpu, started_gpu, finished_gpu, 50) + + # GPU should run while CPU is blocked + started_gpu.wait(timeout=2) + assert started_gpu.is_set() + finished_gpu.wait(timeout=2) + assert finished_gpu.is_set() + + # CPU must not have started + assert not started_cpu.is_set() + + # After release, CPU proceeds + assert f_cpu.result(timeout=2) == 100 + assert started_cpu.is_set() and finished_cpu.is_set() + assert f_gpu.result(timeout=0.5) == 50 + +def test_global_lock_includes_cpu(pool_cpu): + d_cpu = torch.device("cpu") + d0 = torch.device("cuda", 0) + + started = [threading.Event(), threading.Event()] # [cpu, gpu] + + def long_cpu(): + started[0].set() + time.sleep(0.05) + return 1 + + def long_gpu(): + started[1].set() + if hasattr(torch.cuda, "_sleep"): + torch.cuda._sleep(50 * 1000_000) + else: + time.sleep(0.05) + return 1 + + with pool_cpu.lock(): + f_cpu = pool_cpu.submit(d_cpu, long_cpu) + f_gpu = pool_cpu.submit(d0, long_gpu) + time.sleep(0.02) + # Neither should start under global lock + assert not started[0].is_set() + assert not started[1].is_set() + + assert f_cpu.result(timeout=2) == 1 + assert f_gpu.result(timeout=2) == 1 + assert started[0].is_set() and started[1].is_set() + +def test_counters_include_cpu(pool_cpu): + d_cpu = torch.device("cpu") + d0 = torch.device("cuda", 0) + + def noop(): + return 1 + + # CPU tasks + for _ in range(4): + pool_cpu.do(d_cpu, noop) + + # GPU tasks + for _ in range(2): + pool_cpu.do(d0, noop) + + stats = pool_cpu.stats() + per = stats["per_device"] + assert per["cpu"] == 4 + assert per[f"cuda:{d0.index}"] >= 2 # at least the two we just ran + assert stats["total"] >= 6 diff --git a/tests/test_threadx_mps.py b/tests/test_threadx_mps.py new file mode 100644 index 000000000..a5d92aaa3 --- /dev/null +++ b/tests/test_threadx_mps.py @@ -0,0 +1,71 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +import time + +import pytest +import torch +import torch.nn as nn + +from gptqmodel.utils.threadx import DeviceThreadPool + + +mps_available = hasattr(torch, "backends") and hasattr(torch.backends, "mps") and torch.backends.mps.is_available() + +pytestmark = [ + pytest.mark.mps, + pytest.mark.skipif(not mps_available, reason="MPS not available"), +] + +def test_mps_worker_basic(): + d_mps = torch.device("mps") + p = DeviceThreadPool(devices=[d_mps], inference_mode=True, empty_cache_every_n=3) + try: + def add(a, b): return a + b + a = torch.randn(256, 256, device=d_mps) + b = torch.randn(256, 256, device=d_mps) + out = p.do(d_mps, add, a, b) + assert out.device.type == "mps" + torch.testing.assert_close(out, a + b) + finally: + p.shutdown() + +def test_mps_linear_forward_and_counters(monkeypatch): + d_mps = torch.device("mps") + calls = [] + + # Spy on torch.mps.empty_cache to confirm janitor invocation + if hasattr(torch, "mps"): + orig = torch.mps.empty_cache + def spy(): + calls.append("ec") + # tiny delay to ensure pass runs + time.sleep(0.01) + monkeypatch.setattr(torch.mps, "empty_cache", spy) + + p = DeviceThreadPool(devices=[d_mps], inference_mode=True, empty_cache_every_n=2) + try: + m = nn.Linear(64, 32).to(d_mps) + x = torch.randn(16, 64, device=d_mps) + + def fwd(mod, inp): return mod(inp) + + # two tasks -> threshold 2 -> janitor should run once + p.do(d_mps, fwd, m, x) + y = p.do(d_mps, fwd, m, x) + assert y.shape == (16, 32) + + # allow janitor to run + time.sleep(0.1) + + st = p.stats() + assert st["per_device"]["mps"] >= 2 + assert st["total"] >= 2 + if hasattr(torch, "mps"): + assert len(calls) >= 1 + finally: + p.shutdown() + if hasattr(torch, "mps"): + monkeypatch.setattr(torch.mps, "empty_cache", orig) diff --git a/tests/test_threadx_wait.py b/tests/test_threadx_wait.py new file mode 100644 index 000000000..92b01877c --- /dev/null +++ b/tests/test_threadx_wait.py @@ -0,0 +1,261 @@ +# SPDX-License-Identifier: Apache-2.0 +# Author: ModelCloud.ai / qubitium +import threading +import time + +import pytest +import torch + +from gptqmodel.utils.threadx import DeviceThreadPool + + +pytestmark = [ + pytest.mark.cuda, + pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required"), +] + + +def require_n_gpus(n: int): + if torch.cuda.device_count() < n: + pytest.skip(f"requires >= {n} CUDA devices") + + +# ---------- Helpers ---------- + +def _host_long(ms=150): + # Host-side delay (keeps worker thread busy, independent of GPU concurrency) + time.sleep(ms / 1000.0) + return ms + + +def _start_then_sleep(start_evt: threading.Event, ms=200): + start_evt.set() + time.sleep(ms / 1000.0) + return ms + + +# ---------- Fixtures ---------- + +@pytest.fixture() +def pool_default_two_cuda(): + """Default pool with 2 CUDA devices, 1 worker per device.""" + require_n_gpus(2) + devices = [torch.device("cuda", 0), torch.device("cuda", 1)] + p = DeviceThreadPool(devices=devices, inference_mode=True, empty_cache_every_n=0) + try: + yield p + finally: + p.shutdown(wait=True) + + +@pytest.fixture() +def pool_workers_override(): + """ + Pool validating worker-count overrides: + - cuda:0 -> 3 workers (override) + - cuda:1 -> 1 worker (via 'cuda:per') + - cpu -> 4 workers + """ + require_n_gpus(2) + devices = [torch.device("cuda", 0), torch.device("cuda", 1), torch.device("cpu")] + p = DeviceThreadPool( + devices=devices, + inference_mode=True, + empty_cache_every_n=0, + workers={"cuda:per": 1, "cuda:0": 3, "cpu": 4}, + ) + try: + yield p + finally: + p.shutdown(wait=True) + + +# ---------- wait() API tests ---------- + +def test_wait_cuda_without_lock(pool_default_two_cuda): + d0 = torch.device("cuda", 0) + + # Submit several long host tasks to d0 + futs = [pool_default_two_cuda.submit(d0, _host_long, 150) for _ in range(3)] + + # Wait for CUDA scope to drain + pool_default_two_cuda.wait("cuda") + + # All should be done now + for f in futs: + assert f.done() + assert f.result() == 150 + + +def test_wait_cuda_with_lock_blocks_new_tasks(pool_default_two_cuda): + d0 = torch.device("cuda", 0) + + started_after_lock = threading.Event() + + # Submit one long task so there's something to drain + f0 = pool_default_two_cuda.submit(d0, _host_long, 120) + + # Acquire wait+lock over all CUDA devices + with pool_default_two_cuda.wait("cuda", lock=True): + # At this point, the initial task finished and we hold the writer lock. + # Submit a new task: it should enqueue but **not** start until we exit. + fut = pool_default_two_cuda.submit(d0, _start_then_sleep, started_after_lock, 200) + + # Give the worker a moment; it must NOT start while lock is held. + time.sleep(0.05) + assert not started_after_lock.is_set() + assert not fut.done() + + # After releasing the lock, the task can start and complete + assert fut.result(timeout=2) == 200 + assert started_after_lock.is_set() + assert f0.result(timeout=0.5) == 120 + + +def test_wait_specific_device_vs_family(pool_default_two_cuda): + d0, d1 = torch.device("cuda", 0), torch.device("cuda", 1) + + started_d0 = threading.Event() + started_d1 = threading.Event() + + fut0 = pool_default_two_cuda.submit(d0, _start_then_sleep, started_d0, 200) + fut1 = pool_default_two_cuda.submit(d1, _start_then_sleep, started_d1, 200) + + # Wait only for cuda:0; cuda:1 may still be running. + pool_default_two_cuda.wait("cuda:0") + + # d0 must be done; d1 may or may not be done depending on timing, but no deadlocks. + assert fut0.done() + assert started_d0.is_set() + assert fut0.result(timeout=0.5) == 200 + + # To be deterministic, drain all CUDA before leaving the test + pool_default_two_cuda.wait("cuda") + assert fut1.result(timeout=0.5) == 200 + assert started_d1.is_set() + + +def test_wait_returns_context_manager_with_lock(pool_default_two_cuda): + d0 = torch.device("cuda", 0) + + with pool_default_two_cuda.wait("cuda", lock=True): + # While holding the exclusive lock, submitting a task will park at the lock + started = threading.Event() + fut = pool_default_two_cuda.submit(d0, _start_then_sleep, started, 120) + time.sleep(0.05) + assert not started.is_set() + assert not fut.done() + + # After lock release, it proceeds + assert fut.result(timeout=2) == 120 + assert started.is_set() + + +# ---------- Worker-count policy tests ---------- + +def test_worker_count_override_cuda0_vs_cuda1(pool_workers_override): + """ + With workers={"cuda:per":1, "cuda:0":3}: + - cuda:0 has 3 workers -> up to 3 tasks can start almost immediately. + - cuda:1 has 1 worker -> tasks start one-by-one. + We measure host-level concurrency via start events. + """ + d0, d1 = torch.device("cuda", 0), torch.device("cuda", 1) + + # cuda:0 — expect ~3 tasks to start quickly + start_events_0 = [threading.Event() for _ in range(6)] + futs0 = [pool_workers_override.submit(d0, _start_then_sleep, ev, 200) for ev in start_events_0] + time.sleep(0.10) # give workers time to start tasks + started0 = sum(ev.is_set() for ev in start_events_0) + assert started0 >= 3 # at least the configured worker count + # Drain + for f in futs0: + assert f.result(timeout=3) == 200 + + # cuda:1 — only 1 worker; after a short delay, only ~1 task should have started + start_events_1 = [threading.Event() for _ in range(4)] + futs1 = [pool_workers_override.submit(d1, _start_then_sleep, ev, 150) for ev in start_events_1] + time.sleep(0.08) + started1 = sum(ev.is_set() for ev in start_events_1) + assert started1 <= 2 # typically 1; allow 2 to avoid flakiness + # Drain + for f in futs1: + assert f.result(timeout=3) == 150 + + +def test_worker_count_override_cpu(pool_workers_override): + """ + With workers={"cpu":4}, we expect ~4 CPU tasks to start quickly in parallel. + """ + d_cpu = torch.device("cpu") + + starts = [threading.Event() for _ in range(8)] + futs = [pool_workers_override.submit(d_cpu, _start_then_sleep, ev, 200) for ev in starts] + + time.sleep(0.10) + started = sum(ev.is_set() for ev in starts) + assert started >= 4 # at least configured worker count should be active + + for f in futs: + assert f.result(timeout=3) == 200 + + +def test_wait_all_scope_with_mixed_devices(pool_workers_override): + """ + Ensure wait(None)/wait('all') drains all devices (CPU + both CUDA). + """ + d_cpu = torch.device("cpu") + d0, d1 = torch.device("cuda", 0), torch.device("cuda", 1) + + futs = [] + for _ in range(3): + futs.append(pool_workers_override.submit(d_cpu, _host_long, 120)) + futs.append(pool_workers_override.submit(d0, _host_long, 120)) + futs.append(pool_workers_override.submit(d1, _host_long, 120)) + + # Wait for everything + pool_workers_override.wait() # same as wait('all') + pool_workers_override.wait('all') # idempotent + + for f in futs: + assert f.done() + assert f.result(timeout=0.5) == 120 + + +def test_wait_cuda_lock_allows_other_families(pool_workers_override): + """ + Holding wait('cuda', lock=True) should not block CPU tasks. + """ + d_cpu = torch.device("cpu") + d0 = torch.device("cuda", 0) + + # Fill CUDA with a couple tasks + fut0 = pool_workers_override.submit(d0, _host_long, 120) + fut1 = pool_workers_override.submit(d0, _host_long, 120) + + with pool_workers_override.wait("cuda", lock=True): + # CUDA is drained & locked exclusively here + # CPU task should still start+finish + cpu_done = threading.Event() + + def cpu_task(): + cpu_done.set() + return _host_long(50) + + f_cpu = pool_workers_override.submit(d_cpu, cpu_task) + # Give it a moment; must run even while CUDA is locked + time.sleep(0.02) + assert cpu_done.is_set() + assert f_cpu.result(timeout=2) == 50 + + # Submitting a CUDA task now will queue but not start + started = threading.Event() + f_blocked = pool_workers_override.submit(d0, _start_then_sleep, started, 80) + time.sleep(0.03) + assert not started.is_set() + assert not f_blocked.done() + + # After release, the blocked CUDA task proceeds + assert fut0.result(timeout=2) == 120 + assert fut1.result(timeout=2) == 120 + assert f_blocked.result(timeout=2) == 80 diff --git a/tests/test_tokenicer.py b/tests/test_tokenicer.py index 2cb39b118..26b7940dd 100644 --- a/tests/test_tokenicer.py +++ b/tests/test_tokenicer.py @@ -5,13 +5,15 @@ import os + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" import unittest # noqa: E402 -from gptqmodel import GPTQModel, QuantizeConfig # noqa: E402 from parameterized import parameterized # noqa: E402 +from gptqmodel import GPTQModel, QuantizeConfig # noqa: E402 + class TestTokenicer(unittest.TestCase): diff --git a/tests/test_triton.py b/tests/test_triton.py index 790b6847e..1475d6fcf 100644 --- a/tests/test_triton.py +++ b/tests/test_triton.py @@ -6,6 +6,7 @@ # -- do not touch import os + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # -- end do not touch @@ -14,11 +15,13 @@ import torch # noqa: E402 import torch.utils.benchmark as benchmark # noqa: E402 -from gptqmodel import BACKEND, GPTQModel # noqa: E402 from logbar import LogBar # noqa: E402 from parameterized import parameterized # noqa: E402 from transformers import AutoTokenizer # noqa: E402 +from gptqmodel import BACKEND, GPTQModel # noqa: E402 + + log = LogBar.shared() MODEL_ID = "/monster/data/model/Llama-7B-GPTQ" diff --git a/tests/test_triton_xpu.py b/tests/test_triton_xpu.py index 625c367a6..bb761ab62 100644 --- a/tests/test_triton_xpu.py +++ b/tests/test_triton_xpu.py @@ -6,14 +6,16 @@ # -- do not touch import os + os.environ["CUDA_VISIBLE_DEVICES"] = "" # -- end do not touch import tempfile # noqa: E402 +from models.model_test import ModelTest # noqa: E402 + from gptqmodel import BACKEND, GPTQModel, QuantizeConfig # noqa: E402 from gptqmodel.models._const import DEVICE # noqa: E402 -from models.model_test import ModelTest # noqa: E402 class TestTritonXPU(ModelTest): diff --git a/tests/test_vllm.py b/tests/test_vllm.py index 4e7f85528..2dca623c6 100644 --- a/tests/test_vllm.py +++ b/tests/test_vllm.py @@ -6,17 +6,19 @@ # -- do not touch import os + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" # -- end do not touch import importlib.util # noqa: E402 import tempfile # noqa: E402 +from models.model_test import ModelTest # noqa: E402 +from transformers import AutoTokenizer # noqa: E402 + from gptqmodel import BACKEND, GPTQModel, QuantizeConfig # noqa: E402 from gptqmodel.nn_modules.qlinear import BaseQuantLinear # noqa: E402 from gptqmodel.utils.torch import torch_empty_cache # noqa: E402 -from models.model_test import ModelTest # noqa: E402 -from transformers import AutoTokenizer # noqa: E402 class TestLoadVLLM(ModelTest):