diff --git a/.github/workflows/unit_tests.yml b/.github/workflows/unit_tests.yml index 3899bceda..c5f56eed5 100644 --- a/.github/workflows/unit_tests.yml +++ b/.github/workflows/unit_tests.yml @@ -58,7 +58,7 @@ env: CUDA_DEVICE_ORDER: PCI_BUS_ID CUDA_VISIBLE_DEVICES: 0 TORCH_CUDA_ARCH_LIST: '8.6 8.9 9.0 12.0' - PYTORCH_CUDA_ALLOC_CONF: 'expandable_segments:True' + PYTORCH_ALLOC_CONF: 'expandable_segments:True' MAX_JOBS: 4 RUNNER: 10.0.13.31 XEON5: 10.0.14.248 diff --git a/MANIFEST.in b/MANIFEST.in index 4ef798768..b3715a358 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,7 +1,10 @@ -global-include gptqmodel_ext/**/*.cuh -global-include gptqmodel_ext/**/*.h -global-include gptqmodel_ext/**/*.cpp -global-include gptqmodel_ext/**/*.cu -global-include gptqmodel_ext/**/*.py +recursive-include gptqmodel_ext/awq *.h *.cuh *.cu *.cpp +recursive-include gptqmodel_ext/exllama *.h *.cuh *.cu *.cpp +recursive-include gptqmodel_ext/exllamav2 *.h *.cuh *.cu *.cpp +recursive-include gptqmodel_ext/exllama_eora/eora *.h *.cuh *.cu *.cpp *.py +recursive-include gptqmodel_ext/marlin *.h *.cuh *.cu *.cpp +recursive-include gptqmodel_ext/qqq *.h *.cuh *.cu *.cpp +include gptqmodel_ext/marlin/generate_kernels.py +recursive-exclude gptqmodel_ext __pycache__ *.pyc prune tests/ prune format/ diff --git a/examples/benchmark/perplexity.py b/examples/benchmark/perplexity.py index 36b4eb812..4c0978702 100644 --- a/examples/benchmark/perplexity.py +++ b/examples/benchmark/perplexity.py @@ -12,7 +12,7 @@ os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" -os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True" if __name__ == "__main__": """ diff --git a/examples/eora/eora_generation.py b/examples/eora/eora_generation.py index 310a81cee..eecf2df9f 100644 --- a/examples/eora/eora_generation.py +++ b/examples/eora/eora_generation.py @@ -18,7 +18,7 @@ os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" -os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True" # -- end do not touch diff --git a/examples/eora/eora_load_and_inference.py b/examples/eora/eora_load_and_inference.py index 7a3ffa107..f3c2d4e24 100644 --- a/examples/eora/eora_load_and_inference.py +++ b/examples/eora/eora_load_and_inference.py @@ -18,7 +18,7 @@ os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" -os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True" # -- end do not touch from gptqmodel import BACKEND, GPTQModel # noqa: E402 diff --git a/examples/eora/evaluation.py b/examples/eora/evaluation.py index 08713f089..ba90ac158 100644 --- a/examples/eora/evaluation.py +++ b/examples/eora/evaluation.py @@ -18,7 +18,7 @@ os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" -os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True" # -- end do not touch from typing import Optional # noqa: E402 diff --git a/examples/eora/post_quant_eora_generation.py b/examples/eora/post_quant_eora_generation.py index 25a1ce2b7..eaad35219 100644 --- a/examples/eora/post_quant_eora_generation.py +++ b/examples/eora/post_quant_eora_generation.py @@ -18,7 +18,7 @@ os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" -os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True" # -- end do not touch diff --git 
a/examples/inference/run_with_different_backends.py b/examples/inference/run_with_different_backends.py index 84be1b9a0..98cd1d14b 100644 --- a/examples/inference/run_with_different_backends.py +++ b/examples/inference/run_with_different_backends.py @@ -14,7 +14,7 @@ os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" -os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True" pretrained_model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" quantized_model_id = "./TinyLlama/TinyLlama-1.1B-Chat-v1.0-4bit-128g" diff --git a/examples/quantization/basic_usage.py b/examples/quantization/basic_usage.py index 5f9f1b387..a8c7e832c 100644 --- a/examples/quantization/basic_usage.py +++ b/examples/quantization/basic_usage.py @@ -11,7 +11,7 @@ os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" -os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True" +os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True" pretrained_model_id = "/monster/data/model/TinyLlama-1.1B-Chat-v1.0" # "TinyLlama/TinyLlama-1.1B-Chat-v1.0" quantized_model_id = "TinyLlama-1.1B-Chat-v1.0-4bit-128g" diff --git a/gptqmodel/looper/eora_processor.py b/gptqmodel/looper/eora_processor.py index 5d62a5ba0..4d8c5237b 100644 --- a/gptqmodel/looper/eora_processor.py +++ b/gptqmodel/looper/eora_processor.py @@ -22,6 +22,7 @@ from ..utils.logger import setup_logger from ..utils.model import move_to from ..utils.torch import CPU, DEVICE_0, DEVICE_1, torch_streamCtx, torch_sync +from ..utils.torch import HAS_CUDA, tf32_disable_guard, torch_streamCtx, torch_sync log = setup_logger() diff --git a/gptqmodel/models/auto.py b/gptqmodel/models/auto.py index 9669c7427..924a5cdf0 100644 --- a/gptqmodel/models/auto.py +++ b/gptqmodel/models/auto.py @@ -18,9 +18,9 @@ # os.environ["PYTHON_GIL"] = '0' # log.info("ENV: Auto disable GIL and use free-threading mode when applicable: Python 3.13t+. 
You must install the -t edition of Python.") -if not os.environ.get("PYTORCH_CUDA_ALLOC_CONF", None): - os.environ["PYTORCH_CUDA_ALLOC_CONF"] = 'expandable_segments:True' - log.info("ENV: Auto setting PYTORCH_CUDA_ALLOC_CONF='expandable_segments:True' for memory saving.") +if not os.environ.get("PYTORCH_ALLOC_CONF", None): + os.environ["PYTORCH_ALLOC_CONF"] = 'expandable_segments:True,max_split_size_mb:256,garbage_collection_threshold:0.7' + log.info("ENV: Auto setting PYTORCH_ALLOC_CONF='expandable_segments:True,max_split_size_mb:256,garbage_collection_threshold:0.7' for memory saving.") if not os.environ.get("CUDA_DEVICE_ORDER", None): os.environ["CUDA_DEVICE_ORDER"] = 'PCI_BUS_ID' diff --git a/gptqmodel/models/base.py b/gptqmodel/models/base.py index 3525134e1..bed6f57c2 100644 --- a/gptqmodel/models/base.py +++ b/gptqmodel/models/base.py @@ -10,6 +10,7 @@ import random import threading from collections import defaultdict +from collections.abc import Sequence from concurrent.futures import Future from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, Union @@ -26,6 +27,13 @@ modeling_utils, ) +try: # Optional dependency for huggingface datasets support + from datasets import Dataset as HFDataset + from datasets import IterableDataset as HFIterableDataset +except Exception: # pragma: no cover - datasets may not be installed + HFDataset = None + HFIterableDataset = None + from ..adapter.adapter import Adapter from ..nn_modules.qlinear import BaseQuantLinear from ..nn_modules.qlinear.torch import TorchQuantLinear @@ -56,6 +64,11 @@ if TYPE_CHECKING: from ..utils.threadx import DeviceThreadPool + try: + from datasets import Dataset as HFDatasetType + from datasets import IterableDataset as HFIterableDatasetType + except Exception: # pragma: no cover - optional dependency + HFDatasetType = HFIterableDatasetType = object class _ClassPropertyDescriptor: @@ -347,36 +360,179 @@ def full_layer_modules(cls, model_config=None, is_awq_quantize: bool = False): def prepare_dataset( self, - calibration_dataset: Union[List[Dict[str, Union[List[int], torch.LongTensor]]], List[str], List[List[int]]], + calibration_dataset: Union[ + List[Dict[str, Union[List[int], torch.LongTensor]]], + List[str], + List[List[int]], + "HFDatasetType", + "HFIterableDatasetType", + ], # Setting a fixed calibration_dataset_concat_size may improve the performance of the quantized model. 
calibration_dataset_concat_size: Optional[int] = None, calibration_dataset_sort: Optional[str] = None, batch_size: int = 1, calibration_data_min_length: int = 10, ): - if isinstance(calibration_dataset[0], (str, list)) or (isinstance(calibration_dataset[0], list) and all(isinstance(x, int) for x in calibration_dataset[0])): + hf_dataset_types: tuple = () + if HFDataset is not None: + hf_dataset_types += (HFDataset,) + if HFIterableDataset is not None: + hf_dataset_types += (HFIterableDataset,) + + if isinstance(calibration_dataset, str): + raise ValueError("Quantize: calibration dataset must be iterable, not a single string.") + + if hf_dataset_types and isinstance(calibration_dataset, hf_dataset_types): + raw_examples = list(calibration_dataset) + elif isinstance(calibration_dataset, list): + raw_examples = calibration_dataset + elif isinstance(calibration_dataset, Sequence) and not isinstance(calibration_dataset, (bytes, bytearray)): + raw_examples = list(calibration_dataset) + else: + raw_examples = list(calibration_dataset) + + if len(raw_examples) == 0: + raise ValueError("Quantize: calibration dataset is empty.") + + def _require_tokenizer(reason: str) -> None: if self.tokenizer is None: - raise ValueError(f"tokenizer must be provided when calibration_dataset is List[str] or List[int], type: {type(calibration_dataset[0])}") - - # Convert strings/ints to tokenized format - new_calibration_dataset = [] - for data in calibration_dataset: - # convert to tensor directly if already in token ids format (ints) - if isinstance(data, list) and all(isinstance(x, int) for x in data): - input_ids = torch.tensor([data], dtype=torch.long) - attention_mask = torch.ones_like(input_ids) - new_calibration_dataset.append({ - "input_ids": input_ids, - "attention_mask": attention_mask - }) - # call tokenizer if dataset still string format (str) - else: - tokenized = self.tokenizer(data, return_tensors="pt") - new_calibration_dataset.append({ - "input_ids": tokenized["input_ids"], - "attention_mask": tokenized["attention_mask"] - }) - calibration_dataset = new_calibration_dataset + raise ValueError(f"tokenizer must be provided when {reason}.") + + def _to_2d_long_tensor(value: Any, name: str, idx: int) -> torch.Tensor: + try: + tensor = torch.as_tensor(value, dtype=torch.long) + except Exception as exc: + raise ValueError(f"Quantize: failed to convert `{name}` to tensor for calibration item {idx}.") from exc + + if tensor.ndim == 0: + raise ValueError(f"Quantize: `{name}` for calibration item {idx} must be 1D or 2D, got scalar.") + if tensor.ndim == 1: + tensor = tensor.unsqueeze(0) + elif tensor.ndim != 2: + raise ValueError( + f"Quantize: `{name}` for calibration item {idx} must be rank 1 or 2, got rank {tensor.ndim}." + ) + return tensor + + def _pack_ids(ids_value: Any, mask_value: Any, idx: int) -> Dict[str, torch.Tensor]: + ids_tensor = _to_2d_long_tensor(ids_value, "input_ids", idx) + + if mask_value is None: + mask_tensor = torch.ones_like(ids_tensor, dtype=torch.long) + else: + mask_tensor = _to_2d_long_tensor(mask_value, "attention_mask", idx) + if mask_tensor.shape != ids_tensor.shape: + if mask_tensor.numel() == ids_tensor.numel(): + mask_tensor = mask_tensor.reshape(ids_tensor.shape) + else: + raise ValueError( + f"Quantize: attention_mask shape {tuple(mask_tensor.shape)} does not match input_ids shape " + f"{tuple(ids_tensor.shape)} for calibration item {idx}." 
+ ) + + return { + "input_ids": ids_tensor.detach(), + "attention_mask": mask_tensor.detach(), + } + + def _tokenize_text_value(text_value: Any, idx: int) -> Dict[str, torch.Tensor]: + _require_tokenizer("calibration data contains raw text") + tokenized = self.tokenizer( + text_value, + add_special_tokens=True, + return_tensors="pt", + ) + input_ids = tokenized["input_ids"] + attention_mask = tokenized.get("attention_mask") + return _pack_ids(input_ids, attention_mask, idx) + + def _tokenize_messages_value(messages_value: Any, idx: int) -> Dict[str, torch.Tensor]: + _require_tokenizer("calibration data uses the `messages` feature") + apply_fn = getattr(self.tokenizer, "apply_template", None) + if apply_fn is None: + raise ValueError("tokenizer must expose `apply_template` to handle `messages` calibration data.") + try: + templated = apply_fn(messages_value, tokenize=False) + except TypeError: + templated = apply_fn(messages_value) + + if templated is None: + raise ValueError(f"tokenizer.apply_template returned None for calibration item {idx}.") + + if hasattr(templated, "get"): + ids_value = templated.get("input_ids") + mask_value = templated.get("attention_mask") + text_value = templated.get("text") + if ids_value is not None: + return _pack_ids(ids_value, mask_value, idx) + if text_value is not None: + return _tokenize_text_value(text_value, idx) + + if isinstance(templated, (list, tuple)): + if len(templated) > 0 and isinstance(templated[0], int): + return _pack_ids(list(templated), None, idx) + raise ValueError( + f"tokenizer.apply_template returned an unsupported sequence type for calibration item {idx}." + ) + + if torch.is_tensor(templated): + return _pack_ids(templated, None, idx) + + if isinstance(templated, str): + return _tokenize_text_value(templated, idx) + + raise ValueError( + f"tokenizer.apply_template returned unsupported type {type(templated)} for calibration item {idx}." + ) + + processed_examples: List[Dict[str, torch.Tensor]] = [] + for idx, example in enumerate(raw_examples): + if isinstance(example, dict): + if "messages" in example: + apply_fn = getattr(self.tokenizer, "apply_template", None) if self.tokenizer else None + if apply_fn is None: + if "text" in example: + processed_examples.append(_tokenize_text_value(example["text"], idx)) + continue + raise ValueError( + "tokenizer must expose `apply_template` or calibration data must provide `text` when using `messages`." + ) + processed_examples.append(_tokenize_messages_value(example["messages"], idx)) + continue + if "text" in example: + processed_examples.append(_tokenize_text_value(example["text"], idx)) + continue + if "input_ids" in example: + processed_examples.append(_pack_ids(example["input_ids"], example.get("attention_mask"), idx)) + continue + raise ValueError( + f"Quantize: unsupported calibration example structure at index {idx}: keys={list(example.keys())}" + ) + + if isinstance(example, str): + processed_examples.append(_tokenize_text_value(example, idx)) + continue + + if isinstance(example, (list, tuple)): + if all(isinstance(x, int) for x in example): + processed_examples.append(_pack_ids(list(example), None, idx)) + continue + raise ValueError( + f"Quantize: list-based calibration example at index {idx} must contain only integers." 
+ ) + + if torch.is_tensor(example): + processed_examples.append(_pack_ids(example, None, idx)) + continue + + try: + processed_examples.append(_pack_ids(example, None, idx)) + except Exception as exc: + raise ValueError( + f"Quantize: unsupported calibration example type {type(example)} at index {idx}." + ) from exc + + calibration_dataset = processed_examples def _convert_tensor_to_list(tensor): if isinstance(tensor, torch.Tensor): @@ -409,6 +565,7 @@ def _convert_tensor_to_list(tensor): f"Use quantize(calibration_data_min_length={calibration_data_min_length}) to set a custom minimum length.") if calibration_dataset_concat_size: + _require_tokenizer("`calibration_dataset_concat_size` is specified") concatenated_data = [] input_ids_buff = [] attention_mask_buff = [] @@ -532,7 +689,7 @@ def quantize( calibration: Union[List[Dict[str, Union[List[int], torch.LongTensor]]], List[str], List[int]], # Setting a fixed calibration_dataset_concat_size may improve the performance of the quantized model. calibration_concat_size: Optional[int] = None, - calibration_sort: Optional[str] = None, # valid values are asc, desc, shuffle + calibration_sort: Optional[str] = "desc", # valid values are asc, desc, shuffle batch_size: int = 1, tokenizer: Optional[PreTrainedTokenizerBase] = None, logger_board: Optional[str] = None, @@ -1126,6 +1283,7 @@ def _reload_turtle_model_sync(self) -> Optional[PreTrainedModel]: **reload_kwargs, ) new_model._model_init_kwargs = reload_kwargs + new_model.eval() return new_model def _schedule_turtle_reload(self) -> None: @@ -1163,6 +1321,7 @@ def _reload_task(): **reload_kwargs, ) model._model_init_kwargs = reload_kwargs + model.eval() return model finally: self._turtle_ready.set() @@ -1170,7 +1329,9 @@ def _reload_task(): self._turtle_ready.clear() try: - future = pool.submit(CPU, _reload_task) + # Re-loading constructs new Parameter objects; run outside inference mode + # so autograd metadata stays intact even on inference-optimised workers. 
+ future = pool.submit(CPU, _reload_task, _threadx_inference_mode=False) except Exception as exc: log.warning("Turtle reload scheduling failed; falling back to sync reload: %s", exc) self._turtle_reload_future = None @@ -1182,6 +1343,7 @@ def _reload_task(): self._turtle_ready.clear() new_model = self._reload_turtle_model_sync() if new_model is not None: + new_model.eval() self.turtle_model = new_model self._turtle_ready.set() @@ -1209,6 +1371,8 @@ def _apply_completed_turtle_reload(self, *, wait: bool = False) -> None: except Exception as exc: log.warning("Background turtle reload failed; retrying synchronously: %s", exc) new_model = self._reload_turtle_model_sync() + if new_model is not None: + new_model.eval() with self._turtle_reload_lock: if self._turtle_reload_future is future: self._turtle_reload_future = None @@ -1219,7 +1383,9 @@ def _apply_completed_turtle_reload(self, *, wait: bool = False) -> None: with self._turtle_reload_lock: if self._turtle_reload_future is future: - self.turtle_model = new_model + if new_model is not None: + new_model.eval() + self.turtle_model = new_model self._turtle_reload_future = None self._turtle_ready.set() diff --git a/gptqmodel/utils/_extension_loader.py b/gptqmodel/utils/_extension_loader.py new file mode 100644 index 000000000..998f46eb2 --- /dev/null +++ b/gptqmodel/utils/_extension_loader.py @@ -0,0 +1,146 @@ +from __future__ import annotations + +import ctypes +import importlib +import importlib.machinery +import importlib.util +import sys +from pathlib import Path +from types import ModuleType +from typing import Iterable, Optional + + +_TORCH_SHARED_LIBS_PRELOADED = False + + +def _ensure_torch_shared_libraries_loaded() -> None: + """Load torch's shared libraries with RTLD_GLOBAL so extensions can resolve symbols.""" + + global _TORCH_SHARED_LIBS_PRELOADED + if _TORCH_SHARED_LIBS_PRELOADED: + return + + try: + import torch # local import to avoid hard dependency if torch is absent + except Exception: + return + + torch_lib_dir = Path(torch.__file__).resolve().parent / "lib" + if not torch_lib_dir.is_dir(): + return + + # Load core torch libraries first so subsequent extension loads can resolve symbols. + load_order = ( + "libtorch_python.so", + "libtorch_cuda.so", + "libtorch_cpu.so", + "libtorch.so", + "libc10_cuda.so", + "libc10.so", + "libtorch_global_deps.so", + ) + + for name in load_order: + candidate = torch_lib_dir / name + if not candidate.is_file(): + continue + try: + mode = getattr(ctypes, "RTLD_GLOBAL", None) + if mode is None: + ctypes.CDLL(str(candidate)) + else: + ctypes.CDLL(str(candidate), mode=mode) + except OSError: + # Silently ignore individual load failures; later loads may still succeed + continue + + _TORCH_SHARED_LIBS_PRELOADED = True + + +def load_extension_module(module_name: str, + package: Optional[str] = "gptqmodel") -> ModuleType: + """Import a compiled extension, with fallbacks for editable installs. + + Args: + module_name: The qualified module name to import. + package: Package hint used to derive search paths. + + Returns: + The loaded module. + + Raises: + ImportError: If the module cannot be located or loaded. 
+ """ + if module_name in sys.modules: + return sys.modules[module_name] + + try: + return importlib.import_module(module_name) + except ImportError as primary_error: + ext_path = _resolve_extension_path(module_name, package) + if ext_path is None: + raise primary_error + + loader = importlib.machinery.ExtensionFileLoader(module_name, + str(ext_path)) + spec = importlib.util.spec_from_loader(module_name, loader) + if spec is None: + raise primary_error + + module = importlib.util.module_from_spec(spec) + try: + _ensure_torch_shared_libraries_loaded() + loader.exec_module(module) + except Exception as load_error: # pragma: no cover - surface exact failure + raise ImportError( + f"Failed to load extension module {module_name} from {ext_path}: {load_error}" + ) from load_error + + sys.modules[module_name] = module + return module + + +def _resolve_extension_path(module_name: str, + package: Optional[str]) -> Optional[Path]: + for directory in _candidate_directories(package): + for suffix in importlib.machinery.EXTENSION_SUFFIXES: + candidate = directory / f"{module_name}{suffix}" + if candidate.is_file(): + return candidate + return None + + +def _candidate_directories(package: Optional[str]) -> Iterable[Path]: + seen = set() + + def _add(path: Path): + try: + resolved = path.resolve() + except (FileNotFoundError, RuntimeError): + resolved = path + if resolved not in seen: + seen.add(resolved) + yield resolved + + if package: + spec = importlib.util.find_spec(package) + if spec: + locations = spec.submodule_search_locations or [] + if not locations and spec.origin: + locations = [Path(spec.origin).parent] + for location in locations: + location_path = Path(location) + yield from _add(location_path) + yield from _add(location_path.parent) + + # Fallbacks cover source checkout and editable installs. + current = Path(__file__).resolve() + base = current.parent.parent # gptqmodel/ + for candidate in ( + base, + base.parent, + base / "lib", + base.parent / "lib", + base.parent / "build", + ): + yield from _add(candidate) diff --git a/gptqmodel/utils/marlin.py b/gptqmodel/utils/marlin.py index 841e7e325..97111ec22 100644 --- a/gptqmodel/utils/marlin.py +++ b/gptqmodel/utils/marlin.py @@ -8,6 +8,7 @@ import torch from ..utils.logger import setup_logger +from ._extension_loader import load_extension_module from .marlin_scalar_type import ScalarType from .rocm import IS_ROCM @@ -16,7 +17,7 @@ marlin_import_exception = None try: - import gptqmodel_marlin_kernels + gptqmodel_marlin_kernels = load_extension_module("gptqmodel_marlin_kernels") except ImportError as e: marlin_import_exception = str(e) diff --git a/gptqmodel/utils/threadx.py b/gptqmodel/utils/threadx.py index be0f1af08..a1b20440e 100644 --- a/gptqmodel/utils/threadx.py +++ b/gptqmodel/utils/threadx.py @@ -338,40 +338,47 @@ def _run(self): Main loop: pull tasks, set device context, execute, mark completion, and fulfill or fail the future. Completion is accounted BEFORE resolving the future to make stats() deterministic even under test interleavings. + + Workers default to inference mode for throughput but individual tasks + may override via `_threadx_inference_mode`. 
""" _activate_thread_device(self.device) - maybe_inference = torch.inference_mode() if self._inference_mode else contextlib.nullcontext() - with maybe_inference: - while not self._stop.is_set(): - is_task, fn, args, kwargs, fut = self._q.get() - try: - if not is_task: - if DEBUG_ON: log.debug(f"{self.name}: received sentinel; exiting") - break - if DEBUG_ON: log.debug(f"{self.name}: task begin; qsize={self._q.qsize()}") - # Tasks take a **read** lock so janitor's write lock can't interleave - with self.rwlock.reader(): - stream = kwargs.pop("_cuda_stream", None) - with _device_ctx(self.device): + while not self._stop.is_set(): + is_task, fn, args, kwargs, fut = self._q.get() + try: + if not is_task: + if DEBUG_ON: log.debug(f"{self.name}: received sentinel; exiting") + break + if DEBUG_ON: log.debug(f"{self.name}: task begin; qsize={self._q.qsize()}") + + stream = kwargs.pop("_cuda_stream", None) + override_inference = kwargs.pop("_threadx_inference_mode", None) + use_inference = self._inference_mode if override_inference is None else bool(override_inference) + + # Tasks take a **read** lock so janitor's write lock can't interleave + with self.rwlock.reader(): + with _device_ctx(self.device): + inference_ctx = torch.inference_mode() if use_inference else contextlib.nullcontext() + with inference_ctx: if stream is not None and self.device.type == "cuda": with torch.cuda.stream(stream): result = fn(*args, **kwargs) else: result = fn(*args, **kwargs) - # Counters must be updated before resolving futures to prevent - # tests reading stats mid-transition and seeing stale totals. - self._on_task_finished(self.key) - if not fut.cancelled(): - fut.set_result(result) - if DEBUG_ON: log.debug(f"{self.name}: task done") - except BaseException as exc: - # Even on exception we must decrement inflight and update totals. - self._on_task_finished(self.key) - if not fut.cancelled(): - fut.set_exception(exc) - if DEBUG_ON: log.debug(f"{self.name}: task exception: {exc!r}") - finally: - self._q.task_done() + # Counters must be updated before resolving futures to prevent + # tests reading stats mid-transition and seeing stale totals. + self._on_task_finished(self.key) + if not fut.cancelled(): + fut.set_result(result) + if DEBUG_ON: log.debug(f"{self.name}: task done") + except BaseException as exc: + # Even on exception we must decrement inflight and update totals. 
+ self._on_task_finished(self.key) + if not fut.cancelled(): + fut.set_exception(exc) + if DEBUG_ON: log.debug(f"{self.name}: task exception: {exc!r}") + finally: + self._q.task_done() try: self._on_worker_exit(self.key, self) finally: diff --git a/gptqmodel/utils/torch.py b/gptqmodel/utils/torch.py index 7ca315d36..2ae844ea3 100644 --- a/gptqmodel/utils/torch.py +++ b/gptqmodel/utils/torch.py @@ -244,16 +244,24 @@ def tf32_enable_guard(): yield return - prev_matmul = torch.backends.cuda.matmul.allow_tf32 - prev_cudnn = torch.backends.cudnn.allow_tf32 + if torch.backends.fp32_precision == "tf32": + yield + return + + torch.backends.fp32_precision = "tf32" + torch.backends.cuda.matmul.fp32_precision = "tf32" + torch.backends.cudnn.fp32_precision = "tf32" + torch.backends.cudnn.conv.fp32_precision = "tf32" + torch.backends.cudnn.rnn.fp32_precision = "tf32" - torch.backends.cuda.matmul.allow_tf32 = True - torch.backends.cudnn.allow_tf32 = True try: yield finally: - torch.backends.cuda.matmul.allow_tf32 = prev_matmul - torch.backends.cudnn.allow_tf32 = prev_cudnn + torch.backends.fp32_precision = "ieee" + torch.backends.cuda.matmul.fp32_precision = "ieee" + torch.backends.cudnn.fp32_precision = "ieee" + torch.backends.cudnn.conv.fp32_precision = "ieee" + torch.backends.cudnn.rnn.fp32_precision = "ieee" @contextmanager @@ -262,13 +270,21 @@ def tf32_disable_guard(): yield return - prev_matmul = torch.backends.cuda.matmul.allow_tf32 - prev_cudnn = torch.backends.cudnn.allow_tf32 + if torch.backends.fp32_precision == "ieee": + yield + return + + torch.backends.fp32_precision = "ieee" + torch.backends.cuda.matmul.fp32_precision = "ieee" + torch.backends.cudnn.fp32_precision = "ieee" + torch.backends.cudnn.conv.fp32_precision = "ieee" + torch.backends.cudnn.rnn.fp32_precision = "ieee" - torch.backends.cuda.matmul.allow_tf32 = False - torch.backends.cudnn.allow_tf32 = False try: yield finally: - torch.backends.cuda.matmul.allow_tf32 = prev_matmul - torch.backends.cudnn.allow_tf32 = prev_cudnn + torch.backends.fp32_precision = "tf32" + torch.backends.cuda.matmul.fp32_precision = "tf32" + torch.backends.cudnn.fp32_precision = "tf32" + torch.backends.cudnn.conv.fp32_precision = "tf32" + torch.backends.cudnn.rnn.fp32_precision = "tf32" \ No newline at end of file diff --git a/gptqmodel_ext/marlin/generate_kernels.py b/gptqmodel_ext/marlin/generate_kernels.py index 143b095ae..f69b38db9 100644 --- a/gptqmodel_ext/marlin/generate_kernels.py +++ b/gptqmodel_ext/marlin/generate_kernels.py @@ -1,14 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import glob import itertools -import os -import subprocess +from pathlib import Path import jinja2 FILE_HEAD = """ -// auto generated by generate.py +// auto generated by generate_kernels.py // clang-format off #include "kernel.h" @@ -17,6 +15,8 @@ namespace MARLIN_NAMESPACE_NAME { """.strip() +FILE_TAIL = "}\n" + TEMPLATE = ("template __global__ void Marlin<" "{{scalar_t}}, " "{{w_type_id}}, " @@ -49,93 +49,114 @@ DTYPES = ["fp16", "bf16"] -def remove_old_kernels(): - for filename in glob.glob(os.path.dirname(__file__) + "/kernel_*.cu"): - subprocess.call(["rm", "-f", filename]) +def remove_old_kernels() -> None: + root = Path(__file__).parent + for path in root.glob("kernel_*.cu"): + path.unlink(missing_ok=True) +def _write_kernel_file(scalar_type: str, dtype: str, templates: list[str]) -> Path: + root = Path(__file__).parent + scalar_suffix = scalar_type.split("::", 1)[1].lower() if "::" in 
scalar_type else scalar_type.lower() + output_path = root / f"kernel_{dtype}_{scalar_suffix}.cu" -def generate_new_kernels(): - for scalar_type, dtype in itertools.product(SCALAR_TYPES, DTYPES): - all_template_str_list = [] + lines = [FILE_HEAD, "", f"// Instantiations for dtype={dtype}, weight={scalar_type}", ""] + lines.append("\n".join(templates)) + lines.append("") + lines.append(FILE_TAIL) - for group_blocks, m_blocks, thread_configs in itertools.product( - GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS): + output_path.write_text("\n".join(lines), encoding="utf-8") + return output_path - # act order case only support gptq-int4 and gptq-int8 - if group_blocks == 0 and scalar_type not in [ - "vllm::kU4B8", "vllm::kU8B128" - ]: - continue - if thread_configs[2] == 256: - # for small batch (m_blocks == 1), we only need (128, 128, 256) - # for large batch (m_blocks > 1), we only need (64, 256, 256) - if m_blocks <= 1 and thread_configs[0] != 128: - continue - if m_blocks > 1 and thread_configs[0] != 64: - continue - - # we only support channelwise quantization and group_size == 128 - # for fp8 - if scalar_type == "vllm::kFE4M3fn" and group_blocks not in [-1, 8]: + +def render_templates_for_combo(scalar_type: str, dtype: str) -> list[str]: + results: list[str] = [] + for group_blocks, m_blocks, thread_configs in itertools.product( + GROUP_BLOCKS, THREAD_M_BLOCKS, THREAD_CONFIGS): + + # act order case only support gptq-int4 and gptq-int8 + if group_blocks == 0 and scalar_type not in [ + "vllm::kU4B8", "vllm::kU8B128" + ]: + continue + if thread_configs[2] == 256: + # for small batch (m_blocks == 1), we only need (128, 128, 256) + # for large batch (m_blocks > 1), we only need (64, 256, 256) + if m_blocks <= 1 and thread_configs[0] != 128: continue - # nvfp4 only supports group_size == 16 - # mxfp4 only supports group_size == 32 - if scalar_type == "vllm::kFE2M1f" and group_blocks not in [1, 2]: + if m_blocks > 1 and thread_configs[0] != 64: continue - # other quantization methods don't support group_size = 16 - if scalar_type != "vllm::kFE2M1f" and group_blocks == 1: + + # we only support channelwise quantization and group_size == 128 + # for fp8 + if scalar_type == "vllm::kFE4M3fn" and group_blocks not in [-1, 8]: + continue + # nvfp4 only supports group_size == 16 + # mxfp4 only supports group_size == 32 + if scalar_type == "vllm::kFE2M1f" and group_blocks not in [1, 2]: + continue + # other quantization methods don't support group_size = 16 + if scalar_type != "vllm::kFE2M1f" and group_blocks == 1: + continue + + k_blocks = thread_configs[0] // 16 + n_blocks = thread_configs[1] // 16 + threads = thread_configs[2] + + c_dtype = "half" if dtype == "fp16" else "nv_bfloat16" + + is_zp_float_list = [False] + if dtype == "fp16" and scalar_type == "vllm::kU4" and \ + group_blocks == 4: + # HQQ (is_zp_float = true) only supports + # 4bit quantization and fp16 + is_zp_float_list.append(True) + + if scalar_type == "vllm::kFE2M1f" and group_blocks == 1: + s_type = "vllm::kFE4M3fn" + elif scalar_type == "vllm::kFE2M1f" and group_blocks == 2: + s_type = "vllm::kFE8M0fnu" + if dtype == "fp16": + # we cannot safely dequantize e8m0 to fp16, so skip this continue + elif dtype == "fp16": + s_type = "vllm::kFloat16" + elif dtype == "bf16": + s_type = "vllm::kBFloat16" + + for is_zp_float in is_zp_float_list: + template_str = jinja2.Template(TEMPLATE).render( + scalar_t=c_dtype, + w_type_id=scalar_type + ".id()", + s_type_id=s_type + ".id()", + threads=threads, + thread_m_blocks=max(m_blocks, 1), + 
thread_n_blocks=n_blocks, + thread_k_blocks=k_blocks, + m_block_size_8=m_blocks == 0.5, + stages="pipe_stages", + group_blocks=group_blocks, + is_zp_float=is_zp_float, + ) + + results.append(template_str) + + return results + + +def generate_new_kernels() -> None: + emitted = False + for scalar_type, dtype in itertools.product(SCALAR_TYPES, DTYPES): + templates = render_templates_for_combo(scalar_type, dtype) + if not templates: + continue + + _write_kernel_file(scalar_type, dtype, templates) + emitted = True - k_blocks = thread_configs[0] // 16 - n_blocks = thread_configs[1] // 16 - threads = thread_configs[2] - - c_dtype = "half" if dtype == "fp16" else "nv_bfloat16" - - is_zp_float_list = [False] - if dtype == "fp16" and scalar_type == "vllm::kU4" and \ - group_blocks == 4: - # HQQ (is_zp_float = true) only supports - # 4bit quantization and fp16 - is_zp_float_list.append(True) - - if scalar_type == "vllm::kFE2M1f" and group_blocks == 1: - s_type = "vllm::kFE4M3fn" - elif scalar_type == "vllm::kFE2M1f" and group_blocks == 2: - s_type = "vllm::kFE8M0fnu" - if dtype == "fp16": - # we cannot safely dequantize e8m0 to fp16, so skip this - continue - elif dtype == "fp16": - s_type = "vllm::kFloat16" - elif dtype == "bf16": - s_type = "vllm::kBFloat16" - - for is_zp_float in is_zp_float_list: - template_str = jinja2.Template(TEMPLATE).render( - scalar_t=c_dtype, - w_type_id=scalar_type + ".id()", - s_type_id=s_type + ".id()", - threads=threads, - thread_m_blocks=max(m_blocks, 1), - thread_n_blocks=n_blocks, - thread_k_blocks=k_blocks, - m_block_size_8=m_blocks == 0.5, - stages="pipe_stages", - group_blocks=group_blocks, - is_zp_float=is_zp_float, - ) - - all_template_str_list.append(template_str) - - file_content = FILE_HEAD + "\n\n" - file_content += "\n\n".join(all_template_str_list) + "\n\n}\n" - filename = f"kernel_{dtype}_{scalar_type[6:].lower()}.cu" - - with open(os.path.join(os.path.dirname(__file__), filename), "w") as f: - f.write(file_content) + if not emitted: + raise RuntimeError("No marlin kernels were generated; check template configuration.") if __name__ == "__main__": remove_old_kernels() - generate_new_kernels() \ No newline at end of file + generate_new_kernels() diff --git a/gptqmodel_ext/marlin/kernel_bf16_kfe2m1f.cu b/gptqmodel_ext/marlin/kernel_bf16_kfe2m1f.cu deleted file mode 100644 index cdd8472c4..000000000 --- a/gptqmodel_ext/marlin/kernel_bf16_kfe2m1f.cu +++ /dev/null @@ -1,69 +0,0 @@ -// auto generated by generate.py -// clang-format off - -#include "kernel.h" -#include "marlin_template.h" - -namespace MARLIN_NAMESPACE_NAME { - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( 
MARLIN_KERNEL_PARAMS );
[The remainder of kernel_bf16_kfe2m1f.cu and the bodies of the deleted files listed below are omitted here: each consisted solely of auto-generated `template __global__ void Marlin<...>( MARLIN_KERNEL_PARAMS );` instantiations closed by the MARLIN_NAMESPACE_NAME brace; only the deleted-file headers are kept.]
diff --git a/gptqmodel_ext/marlin/kernel_bf16_kfe4m3fn.cu b/gptqmodel_ext/marlin/kernel_bf16_kfe4m3fn.cu deleted file mode 100644 index 8128dbb57..000000000 --- a/gptqmodel_ext/marlin/kernel_bf16_kfe4m3fn.cu +++ /dev/null @@ -1,69 +0,0 @@
diff --git a/gptqmodel_ext/marlin/kernel_bf16_ku4.cu b/gptqmodel_ext/marlin/kernel_bf16_ku4.cu deleted file mode 100644 index 87ca11766..000000000 --- a/gptqmodel_ext/marlin/kernel_bf16_ku4.cu +++ /dev/null @@ -1,129 +0,0 @@
diff --git a/gptqmodel_ext/marlin/kernel_bf16_ku4b8.cu b/gptqmodel_ext/marlin/kernel_bf16_ku4b8.cu deleted file mode 100644 index 9c0e8dacb..000000000 --- a/gptqmodel_ext/marlin/kernel_bf16_ku4b8.cu +++ /dev/null @@ -1,159 +0,0 @@
diff --git a/gptqmodel_ext/marlin/kernel_bf16_ku8b128.cu b/gptqmodel_ext/marlin/kernel_bf16_ku8b128.cu deleted file mode 100644 index ac13dd97a..000000000 --- a/gptqmodel_ext/marlin/kernel_bf16_ku8b128.cu +++ /dev/null @@ -1,159 +0,0 @@
diff --git a/gptqmodel_ext/marlin/kernel_fp16_kfe2m1f.cu b/gptqmodel_ext/marlin/kernel_fp16_kfe2m1f.cu deleted file mode 100644 index aba349be6..000000000 --- a/gptqmodel_ext/marlin/kernel_fp16_kfe2m1f.cu +++ /dev/null @@ -1,39 +0,0 @@
diff --git a/gptqmodel_ext/marlin/kernel_fp16_kfe4m3fn.cu b/gptqmodel_ext/marlin/kernel_fp16_kfe4m3fn.cu deleted file mode 100644 index 25934db9d..000000000 --- a/gptqmodel_ext/marlin/kernel_fp16_kfe4m3fn.cu +++ /dev/null @@ -1,69 +0,0 @@
diff --git a/gptqmodel_ext/marlin/kernel_fp16_ku4.cu b/gptqmodel_ext/marlin/kernel_fp16_ku4.cu deleted file mode 100644 index 3d81ae767..000000000 --- a/gptqmodel_ext/marlin/kernel_fp16_ku4.cu +++ /dev/null @@ -1,159 +0,0 @@
diff --git a/gptqmodel_ext/marlin/kernel_fp16_ku4b8.cu b/gptqmodel_ext/marlin/kernel_fp16_ku4b8.cu deleted file mode 100644 index 87b9b4a4c..000000000 --- a/gptqmodel_ext/marlin/kernel_fp16_ku4b8.cu +++ /dev/null @@ -1,159 +0,0 @@
Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -} diff --git a/gptqmodel_ext/marlin/kernel_fp16_ku8b128.cu b/gptqmodel_ext/marlin/kernel_fp16_ku8b128.cu deleted file mode 100644 index 7f1e3fab3..000000000 --- a/gptqmodel_ext/marlin/kernel_fp16_ku8b128.cu +++ /dev/null @@ -1,159 +0,0 @@ -// auto generated by generate.py -// clang-format off - -#include "kernel.h" -#include "marlin_template.h" - -namespace MARLIN_NAMESPACE_NAME { - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template __global__ void Marlin( MARLIN_KERNEL_PARAMS ); - -template 
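Note: every kernel_*.cu file deleted above is an auto-generated instantiation list for the Marlin template, so the repository no longer carries them; the setup.py change later in this diff regenerates them with gptqmodel_ext/marlin/generate_kernels.py when they are missing. Below is a minimal sketch of doing that regeneration step by hand before a source build. The script path and the kernel_*.cu glob come from this diff; running it from the repository root and the generator writing beside itself are assumptions.

# Hypothetical pre-build helper; not part of this PR.
import subprocess
import sys
from pathlib import Path


def regenerate_marlin_kernels(repo_root: str = "."):
    kernel_dir = Path(repo_root) / "gptqmodel_ext" / "marlin"
    generator = kernel_dir / "generate_kernels.py"
    if not generator.exists():
        raise FileNotFoundError(f"generator script missing: {generator}")
    # Assumption: the generator writes kernel_*.cu files next to itself,
    # matching the glob that setup.py performs after regenerating.
    subprocess.check_call([sys.executable, str(generator)])
    return sorted(kernel_dir.glob("kernel_*.cu"))


if __name__ == "__main__":
    print(f"{len(regenerate_marlin_kernels())} marlin kernel files present")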
diff --git a/pyproject.toml b/pyproject.toml
index a1b9ed236..dda50e769 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,12 +1,12 @@
 [build-system]
 requires = [
-    "setuptools>=80.9.0",
-    "ninja>=1.13.0"
+    "setuptools>=80.9",
+    "ninja>=1.13.0", # required for faster compilation
 ]
 build-backend = "setuptools.build_meta:__legacy__"
 
 [project]
-name = "GPT-QModel"
+name = "GPTQModel"
 dynamic = ["version"]
 description = "Production ready LLM model compression/quantization toolkit with hw accelerated inference support for both cpu/gpu via HF, vLLM, and SGLang."
 readme = "README.md"
@@ -32,7 +32,6 @@ classifiers = [
 ]
 dependencies = [
     "accelerate>=1.10.1",
-    # "datasets>=3.5.0",
     "numpy==2.2.6",
     "torch>=2.8.0",
     "safetensors>=0.6.2",
@@ -47,6 +46,11 @@ dependencies = [
     "random_word>=1.0.13",
     "tokenicer>=0.0.5",
     "logbar>=0.0.5",
+    "maturin>=1.9.4", # required by safetensors and hf_transfer
+    "datasets>=3.6.0",
+    "pyarrow>=21.0",
+    "dill>=0.3.8", # datasets requirements
+    # "cython>=3.1.4", # required by hf-xet/hf-transfer
     # "flash-attn>=2.8.3", <-- install for lower vram usage
 ]
 
@@ -60,7 +64,7 @@ test = [
 ]
 quality = [
     "ruff==0.13.0",
-    "isort==6.0.1",
+    # "isort==6.0.1",
 ]
 vllm = [
     "vllm>=0.10.2",
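The dependency hunks above promote datasets back to a hard requirement and add explicit floors for maturin, pyarrow and dill. Below is a small illustrative check, not part of this PR, that an existing environment already satisfies those floors; the package names and versions are copied from the hunk, while the helper itself and its use of packaging.version are assumptions.

# Hypothetical environment check; pins copied from the pyproject.toml hunk above.
from importlib.metadata import PackageNotFoundError, version

from packaging.version import Version

MINIMUMS = {
    "datasets": "3.6.0",
    "pyarrow": "21.0",
    "dill": "0.3.8",
    "maturin": "1.9.4",
}


def check_minimums(minimums=MINIMUMS) -> dict:
    problems = {}
    for name, floor in minimums.items():
        try:
            installed = version(name)
        except PackageNotFoundError:
            problems[name] = "not installed"
            continue
        if Version(installed) < Version(floor):
            problems[name] = f"{installed} < {floor}"
    return problems


if __name__ == "__main__":
    print(check_minimums() or "all minimum versions satisfied")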
diff --git a/setup.py b/setup.py
index b2d2af1b5..b7129226d 100644
--- a/setup.py
+++ b/setup.py
@@ -2,7 +2,6 @@
 # SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai
 # SPDX-License-Identifier: Apache-2.0
 # Contact: qubitium@modelcloud.ai, x.com/qubitium
-import glob
 import os
 import re
 import subprocess
@@ -191,27 +190,44 @@ def _major_minor(v: str) -> str:
     return v
 
 
+def _version_geq(version: str | None, major: int, minor: int = 0) -> bool:
+    if not version:
+        return False
+    try:
+        parts = re.split(r"[._-]", version)
+        ver_major = int(parts[0]) if parts else 0
+        ver_minor = int(parts[1]) if len(parts) > 1 else 0
+        return (ver_major, ver_minor) >= (major, minor)
+    except Exception:
+        return False
+
+
+def _nvcc_release_version() -> str | None:
+    out = _probe_cmd(["nvcc", "--version"])
+    if not out:
+        print(
+            "NVCC not found: For Ubuntu, run `sudo update-alternatives --config cuda` to fix path for already installed Cuda."
+        )
+        return None
+
+    match = re.search(r"release\s+(\d+)\.(\d+)", out)
+    if match:
+        return f"{match.group(1)}.{match.group(2)}"
+    return None
+
+
 def _detect_cuda_version() -> str | None:
     # Priority: env → nvidia-smi → nvcc
     v = os.environ.get("CUDA_VERSION")
     if v and v.strip():
         return v.strip()
 
-    # nvidia-smi (modern drivers expose cuda_version)
-    out = _probe_cmd(["nvidia-smi", "--query-gpu=cuda_version", "--format=csv,noheader"])
-    if out:
-        line = _first_token_line(out)
-        if line and re.match(r"^\d+\.\d+(\.\d+)?$", line):
-            return line
-
-    # nvcc --version (parse 'release X.Y')
-    out = _probe_cmd(["nvcc", "--version"])
-    if out:
-        m = re.search(r"release\s+(\d+)\.(\d+)", out)
-        if m:
-            return f"{m.group(1)}.{m.group(2)}"
+    return _nvcc_release_version()
 
-    return None
+
+def _detect_nvcc_version() -> str | None:
+    return _nvcc_release_version()
 
 
 def get_version_tag() -> str:
@@ -242,6 +258,7 @@ def get_version_tag() -> str:
 CUDA_VERSION = _read_env("CUDA_VERSION")
 ROCM_VERSION = _read_env("ROCM_VERSION")
 TORCH_CUDA_ARCH_LIST = _read_env("TORCH_CUDA_ARCH_LIST")
+NVCC_VERSION = _read_env("NVCC_VERSION")
 
 # respect user env then detect
 if not TORCH_VERSION:
@@ -250,6 +267,8 @@ def get_version_tag() -> str:
     CUDA_VERSION = _detect_cuda_version()
 if not ROCM_VERSION:
     ROCM_VERSION = _detect_rocm_version()
+if not NVCC_VERSION:
+    NVCC_VERSION = _detect_nvcc_version()
 
 SKIP_ROCM_VERSION_CHECK = _read_env("SKIP_ROCM_VERSION_CHECK")
 FORCE_BUILD = _bool_env("GPTQMODEL_FORCE_BUILD", False)
@@ -432,7 +451,12 @@ def _env_enabled_any(names, default="1") -> bool:
     extra_compile_args["nvcc"] += [f"-D_GLIBCXX_USE_CXX11_ABI={CXX11_ABI}"]
 
 if not ROCM_VERSION:
+    # if _version_geq(NVCC_VERSION, 13, 0):
+    #     extra_compile_args["nvcc"].append("--device-entity-has-hidden-visibility=false")
     extra_compile_args["nvcc"] += [
+        # Allow instantiations of __global__ templates to live in different
+        # translation units (we split marlin kernels for Ninja parallelism).
+        "-static-global-template-stub=false",
        "--threads", "8",  # NVCC parallelism
        "--optimize=3",  # alias for -O3
        # "-rdc=true", # enable relocatable device code, required for future cuda > 13.x <-- TODO FIX ME broken loading
@@ -472,7 +496,22 @@ def _hipify_compile_flags(flags):
 if sys.platform != "win32":
     if not ROCM_VERSION and HAS_CUDA_V8:
         if BUILD_MARLIN:
-            marlin_template_kernel_srcs = glob.glob("gptqmodel_ext/marlin/kernel_*.cu")
+            marlin_kernel_dir = Path("gptqmodel_ext/marlin")
+            marlin_kernel_files = sorted(marlin_kernel_dir.glob("kernel_*.cu"))
+
+            if not marlin_kernel_files:
+                generator_script = marlin_kernel_dir / "generate_kernels.py"
+                if generator_script.exists():
+                    print("Regenerating marlin template instantiations for parallel compilation...")
+                    subprocess.check_call([sys.executable, str(generator_script)])
+                    marlin_kernel_files = sorted(marlin_kernel_dir.glob("kernel_*.cu"))
+
+            if not marlin_kernel_files:
+                raise RuntimeError(
+                    "No generated marlin kernel templates detected. Run generate_kernels.py before building."
+                )
+
+            marlin_template_kernel_srcs = [str(path) for path in marlin_kernel_files]
             extensions += [
                 cpp_ext.CUDAExtension(
                     "gptqmodel_marlin_kernels",
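The setup.py changes above add NVCC detection (NVCC_VERSION, _nvcc_release_version) plus a _version_geq helper that is currently only referenced from the commented-out CUDA 13 visibility flag, and they regenerate the split marlin kernels when the checked-in copies are absent. A quick illustration of how the comparison helper behaves; the function body is copied from the hunk above and the sample inputs are made up.

# Behavior sketch for the _version_geq helper added in setup.py; example inputs only.
import re


def _version_geq(version, major, minor=0):
    if not version:
        return False
    try:
        parts = re.split(r"[._-]", version)
        ver_major = int(parts[0]) if parts else 0
        ver_minor = int(parts[1]) if len(parts) > 1 else 0
        return (ver_major, ver_minor) >= (major, minor)
    except Exception:
        return False


assert _version_geq("13.0", 13) is True    # a CUDA 13 toolchain would enable the gated flag
assert _version_geq("12.8", 13) is False   # older toolchains keep the default behavior
assert _version_geq(None, 13) is False     # nvcc missing is treated as "not new enough"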
diff --git a/tests/allocator_bench.py b/tests/allocator_bench.py
new file mode 100644
index 000000000..703bc00e7
--- /dev/null
+++ b/tests/allocator_bench.py
@@ -0,0 +1,167 @@
+#!/usr/bin/env python3
+"""Stress the CUDA caching allocator to compare PYTORCH_ALLOC_CONF tunings."""
+
+from __future__ import annotations
+
+import argparse
+import json
+import os
+import time
+from typing import Dict, Iterable, List
+
+import torch
+
+
+PROFILES: Dict[str, List[Dict[str, Iterable[int]]]] = {
+    "24gb": [
+        {"allocate_mb": [2048, 2048, 2048, 1536, 1536, 1280], "release": [1, 4], "sleep_cycles": 80_000},
+        {"allocate_mb": [2560, 1792, 1024, 1600], "release": [0, 3], "sleep_cycles": 80_000},
+        {"allocate_mb": [1280, 2048, 896, 2304], "release": [], "sleep_cycles": 80_000},
+    ],
+    "80gb": [
+        {"allocate_mb": [4096, 4096, 4096, 3584, 3584, 3072, 3072], "release": [1, 3, 5], "sleep_cycles": 120_000},
+        {"allocate_mb": [5120, 4096, 3584, 2560, 2304], "release": [0, 2], "sleep_cycles": 120_000},
+        {"allocate_mb": [3584, 4096, 4608, 5120], "release": [], "sleep_cycles": 120_000},
+    ],
+}
+
+
+def parse_args() -> argparse.Namespace:
+    parser = argparse.ArgumentParser(description=__doc__)
+    parser.add_argument("--profile", choices=PROFILES.keys(), default="24gb")
+    parser.add_argument("--iterations", type=int, default=5)
+    parser.add_argument("--warmup", type=int, default=1)
+    parser.add_argument("--device", type=int, default=0)
+    parser.add_argument("--dtype", type=str, default="float16")
+    parser.add_argument("--final-sleep-cycles", type=int, default=500_000, help="extra sleep cycles after each iteration")
+    parser.add_argument("--phase-sleep-scale", type=float, default=1.0, help="scale factor applied to per-phase sleep settings")
+    parser.add_argument("--fill", choices=("none", "uniform", "normal"), default="none")
+    parser.add_argument("--json", action="store_true", help="emit metrics as JSON")
+    return parser.parse_args()
+
+
+def _dtype_from_string(dtype_name: str) -> torch.dtype:
+    dtype = getattr(torch, dtype_name)
+    if not isinstance(dtype, torch.dtype):
+        raise TypeError(f"Unsupported dtype {dtype_name!r}")
+    return dtype
+
+
+def _allocate_tensor(
+    size_mb: int,
+    dtype: torch.dtype,
+    device: torch.device,
+    element_size: int,
+    fill_mode: str,
+) -> torch.Tensor:
+    numel = (size_mb * 1024 * 1024) // element_size
+    if numel == 0:
+        raise ValueError(f"Requested allocation too small: {size_mb} MB")
+    tensor = torch.empty((numel,), dtype=dtype, device=device)
+    if fill_mode == "uniform":
+        tensor.uniform_(-1.0, 1.0)
+    elif fill_mode == "normal":
+        tensor.normal_(mean=0.0, std=1.0)
+    return tensor
+
+
+def _run_iteration(
+    phases: List[Dict[str, Iterable[int]]],
+    dtype: torch.dtype,
+    device: torch.device,
+    element_size: int,
+    fill_mode: str,
+    phase_sleep_scale: float,
+    final_sleep_cycles: int,
+) -> None:
+    allocations: List[torch.Tensor] = []
+    for phase in phases:
+        for size_mb in phase["allocate_mb"]:
+            allocations.append(_allocate_tensor(size_mb, dtype, device, element_size, fill_mode))
+        sleep_cycles = int(phase.get("sleep_cycles", 0) * phase_sleep_scale)
+        if sleep_cycles:
+            torch.cuda._sleep(sleep_cycles)
+        release_indices = phase.get("release", [])
+        if release_indices:
+            for idx in sorted(release_indices, reverse=True):
+                if 0 <= idx < len(allocations):
+                    del allocations[idx]
+        post_sleep_cycles = int(phase.get("post_sleep_cycles", 0) * phase_sleep_scale)
+        if post_sleep_cycles:
+            torch.cuda._sleep(post_sleep_cycles)
+    if final_sleep_cycles:
+        torch.cuda._sleep(final_sleep_cycles)
+    allocations.clear()
+
+
+def main() -> None:
+    args = parse_args()
+    dtype = _dtype_from_string(args.dtype)
+    device = torch.device(f"cuda:{args.device}")
+    torch.cuda.set_device(device)
+
+    element_size = torch.tensor([], dtype=dtype).element_size()
+
+    phases = PROFILES[args.profile]
+
+    torch.cuda.empty_cache()
+    torch.cuda.reset_peak_memory_stats(device)
+    for _ in range(args.warmup):
+        _run_iteration(
+            phases,
+            dtype,
+            device,
+            element_size,
+            args.fill,
+            args.phase_sleep_scale,
+            args.final_sleep_cycles,
+        )
+    torch.cuda.synchronize(device)
+
+    torch.cuda.reset_peak_memory_stats(device)
+
+    durations: List[float] = []
+    for _ in range(args.iterations):
+        torch.cuda.synchronize(device)
+        start = time.perf_counter()
+        _run_iteration(
+            phases,
+            dtype,
+            device,
+            element_size,
+            args.fill,
+            args.phase_sleep_scale,
+            args.final_sleep_cycles,
+        )
+        torch.cuda.synchronize(device)
+        durations.append(time.perf_counter() - start)
+
+    peak_reserved = torch.cuda.max_memory_reserved(device)
+    peak_allocated = torch.cuda.max_memory_allocated(device)
+
+    metrics = {
+        "device": str(device),
+        "profile": args.profile,
+        "iterations": args.iterations,
+        "dtype": args.dtype,
+        "fill": args.fill,
+        "peak_reserved_bytes": peak_reserved,
+        "peak_reserved_gib": peak_reserved / (1024 ** 3),
+        "peak_allocated_bytes": peak_allocated,
+        "peak_allocated_gib": peak_allocated / (1024 ** 3),
+        "per_iter_seconds": durations,
+        "mean_iter_seconds": sum(durations) / len(durations),
+        "stdev_iter_seconds": float(torch.tensor(durations).std(unbiased=False)) if len(durations) > 1 else 0.0,
+        "pytorch_alloc_conf": os.environ.get("PYTORCH_ALLOC_CONF", ""),
+    }
+
+    if args.json:
+        print(json.dumps(metrics))
+    else:
+        print("Benchmark metrics:")
+        for key, value in metrics.items():
+            print(f"  {key}: {value}")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/tests/benchmark/benchmark_test.py b/tests/benchmark/benchmark_test.py
index 59e687bf6..0004310b9 100644
--- a/tests/benchmark/benchmark_test.py
+++ b/tests/benchmark/benchmark_test.py
@@ -8,7 +8,7 @@
 
 os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
-os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
+os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
 
 
 import unittest  # noqa: E402
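tests/allocator_bench.py above stresses the CUDA caching allocator through allocate/release phases so that different PYTORCH_ALLOC_CONF tunings can be compared. One way it might be driven for an A/B comparison on a machine with a CUDA device; the script path, the --profile/--json flags and the metric keys come from the new file, while the driver itself and the two example configurations are assumptions.

# Hypothetical A/B driver for tests/allocator_bench.py; not part of this PR.
import json
import os
import subprocess
import sys


def run_bench(alloc_conf: str, profile: str = "24gb") -> dict:
    # PYTORCH_ALLOC_CONF must be set before torch initializes CUDA, hence a subprocess.
    env = dict(os.environ, PYTORCH_ALLOC_CONF=alloc_conf)
    out = subprocess.check_output(
        [sys.executable, "tests/allocator_bench.py", "--profile", profile, "--json"],
        env=env,
    )
    return json.loads(out)


if __name__ == "__main__":
    for conf in ("expandable_segments:True", "max_split_size_mb:256"):
        metrics = run_bench(conf)
        print(conf, metrics["peak_reserved_gib"], metrics["mean_iter_seconds"])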
diff --git a/tests/models/model_test.py b/tests/models/model_test.py
index d85ce2ad2..01973ebeb 100644
--- a/tests/models/model_test.py
+++ b/tests/models/model_test.py
@@ -11,7 +11,7 @@
 os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
 os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
-os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
+os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True,max_split_size_mb:256,garbage_collection_threshold:0.7" #"expandable_segments:True"
 
 # Following makes test results more deterministic but much slower
 #
 # the CUBLAS env is required for use_deterministic_algorithms
@@ -33,6 +33,7 @@
 import shutil  # noqa: E402
 import tempfile  # noqa: E402
 import unittest  # noqa: E402
+from collections.abc import Iterable  # noqa: E402
 
 import torch.cuda  # noqa: E402
 from datasets import load_dataset  # noqa: E402
@@ -145,36 +146,77 @@ def load_tokenizer(self, model_id_or_path, trust_remote_code=False):
         return tokenizer
 
     @classmethod
-    def load_dataset(self, tokenizer=None, rows: int = 0):
-
-        #traindata = load_dataset("json", data_files="/monster/data/model/dataset/c4-train.00000-of-01024.json.gz", split="train")
-        traindata = load_dataset("neuralmagic/calibration", "LLM", split="train")
-        #neuralmagic / calibration
-        # Load data directly from gzipped JSON file
-        # with gzip.open("/monster/data/model/dataset/c4-train.00000-of-01024.json.gz", 'rt', encoding='utf-8') as f:
-        #     traindata = [json.loads(line) for line in f]
-        #
-        if not tokenizer:
-            return traindata.select(range(rows))
-            # return traindata[:rows]
-
-        # Count total rows
-        # print("Total rows:", len(traindata), "wanted rows=", rows)
-
-        # Select the first N rows (e.g., N=10)
-        subset = traindata.select(range(min(rows, len(traindata))))
-
-        return subset["text"]
-
-        datas = []
-        for index, sample in enumerate(traindata):
-            tokenized = tokenizer(sample['text'])
-            if len(tokenized.data['input_ids']) < self.INPUTS_MAX_LENGTH:
-                datas.append(tokenized)
-                if len(datas) >= rows:
-                    break
-
-        return datas
+    def load_dataset(cls, tokenizer=None, rows: int = 0):
+        try:
+            dataset = load_dataset(path="~/nm-calibration", name="LLM", split="train")
+        except Exception as exc:  # pragma: no cover - exercised in fallbacks
+            log.warning("load_dataset failed; falling back to local parquet: %s", exc)
+            dataset = cls._load_calibration_parquet()
+
+        if rows > 0:
+            return dataset.select(range(min(rows, len(dataset))))
+        return dataset
+
+    @staticmethod
+    def _load_calibration_parquet():
+        parquet_path = Path("~/nm-calibration/llm.parquet").expanduser()
+        if not parquet_path.exists():
+            raise FileNotFoundError(f"Calibration parquet not found at {parquet_path}")
+
+        try:
+            import pandas as pd
+        except ImportError:  # pragma: no cover - depends on test environment
+            pd = None
+
+        if pd is not None:
+            records = pd.read_parquet(parquet_path).to_dict(orient="records")
+            return ModelTest._LocalCalibrationDataset(records)
+
+        try:
+            import pyarrow.parquet as pq
+        except ImportError as err:
+            raise RuntimeError(
+                "Neither pandas nor pyarrow is available to load calibration parquet"
+            ) from err
+
+        table = pq.read_table(parquet_path)
+        records = table.to_pylist()
+        return ModelTest._LocalCalibrationDataset(records)
+
+    class _LocalCalibrationDataset:
+        __slots__ = ("_records",)
+
+        def __init__(self, records):
+            normalized = []
+            for record in records:
+                item = {}
+                for key, value in dict(record).items():
+                    if hasattr(value, "tolist") and not isinstance(value, (str, bytes)):
+                        value = value.tolist()
+                    item[key] = value
+                normalized.append(item)
+            self._records = normalized
+
+        def __len__(self):
+            return len(self._records)
+
+        def __iter__(self):
+            return iter(self._records)
+
+        def __getitem__(self, index):
+            return self._records[index]
+
+        def select(self, indices):
+            if isinstance(indices, slice):
+                selected = self._records[indices]
+            else:
+                if isinstance(indices, range):
+                    indices = list(indices)
+                elif not isinstance(indices, Iterable):
+                    raise TypeError("select `indices` must be a slice or iterable of integers")
+                selected = [self._records[i] for i in indices]
+            return self.__class__(selected)
+
     def check_kernel(self, model, expected_kernels):
         modules = {module.__class__ for _, module in model.named_modules() if isinstance(module, BaseQuantLinear)}
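The model_test.py hunk above makes load_dataset fall back to a local parquet file wrapped in _LocalCalibrationDataset, which only has to look enough like a datasets.Dataset for the tests: len(), iteration, indexing and a datasets-style select(). A standalone sketch of that contract with made-up records; the simplified class below mirrors, but is not, the one added in the diff.

# Illustration of the minimal dataset contract used by the fallback; sample data only.
records = [{"text": f"sample {i}"} for i in range(8)]


class LocalCalibrationDataset:
    def __init__(self, records):
        self._records = list(records)

    def __len__(self):
        return len(self._records)

    def __iter__(self):
        return iter(self._records)

    def __getitem__(self, index):
        return self._records[index]

    def select(self, indices):
        # Accepts a slice or any iterable of integer indices, like datasets.Dataset.select.
        if isinstance(indices, slice):
            return LocalCalibrationDataset(self._records[indices])
        return LocalCalibrationDataset([self._records[i] for i in indices])


subset = LocalCalibrationDataset(records).select(range(4))
assert len(subset) == 4 and subset[0]["text"] == "sample 0"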
diff --git a/tests/test_quant_and_eora.py b/tests/test_quant_and_eora.py
index 5f79e74ff..6e2b163ee 100644
--- a/tests/test_quant_and_eora.py
+++ b/tests/test_quant_and_eora.py
@@ -36,9 +36,7 @@
 
 
 class Test(ModelTest):
-    #NATIVE_MODEL_ID = "/monster/data/model/Qwen2.5-0.5B-Instruct/"
-    #NATIVE_MODEL_ID = "/monster/data/model/tinyllama-15M-stories"
-    NATIVE_MODEL_ID = "/monster/data/model/Llama-3.2-1B"
+    NATIVE_MODEL_ID = "/monster/data/model/Qwen2.5-0.5B-Instruct/"
 
     NATIVE_ARC_CHALLENGE_ACC = 0.3567
     NATIVE_ARC_CHALLENGE_ACC_NORM = 0.3805
@@ -50,9 +48,7 @@ def setUpClass(cls):
 
     @parameterized.expand(
         [
-            # (QUANT_METHOD.GPTQ, FORMAT.GPTQ),  # gptq v2
             (METHOD.GPTQ, FORMAT.GPTQ),  # gptq v1
-            #(QUANT_METHOD.QQQ, FORMAT.QQQ),
         ]
     )
     def test_quant_and_eora(self, quant_method: METHOD, format: FORMAT):
@@ -61,8 +57,8 @@ def test_quant_and_eora(self, quant_method: METHOD, format: FORMAT):
         desc_act = False
         act_group_aware = True
         rank = 128
-        batch_size = 1
-        calibration_dataset_rows = 512
+        batch_size = 4
+        calibration_dataset_rows = 1024
         calibration_dataset_concat_size = 0  # disable
         adapter_path = "eora"
         dataset_id = "allenai/c4"
@@ -99,10 +95,12 @@ def test_quant_and_eora(self, quant_method: METHOD, format: FORMAT):
         model = GPTQModel.load(
             model_id_or_path=self.NATIVE_MODEL_ID,
             quantize_config=quant_config,
+            # apply_chat_template=True,
         )
 
         model.quantize(
             calibration=calibration_dataset,
+            calibration_sort="desc",
             batch_size=batch_size,
             calibration_concat_size=calibration_dataset_concat_size,
         ) #
@@ -153,6 +151,7 @@ def bench(self, path: str, backend: BACKEND, adapter: Optional[Lora]):
             model_or_id_or_path=model,
             framework=EVAL.LM_EVAL,
             tasks=[EVAL.LM_EVAL.ARC_CHALLENGE],
+            apply_chat_template=True,
             # MMLU is too slow for ci test
             # EVAL.LM_EVAL.MMLU
         )
diff --git a/tests/test_quant_formats.py b/tests/test_quant_formats.py
index 2cf181571..5e7dadc2f 100644
--- a/tests/test_quant_formats.py
+++ b/tests/test_quant_formats.py
@@ -8,7 +8,7 @@
 
 os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
-os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
+os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
 # -- end do not touch
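The PYTORCH_CUDA_ALLOC_CONF to PYTORCH_ALLOC_CONF rename applied across the tests and examples lines up with the torch>=2.8.0 floor set in pyproject.toml, where the backend-agnostic variable name is recognized. For scripts that must also run on older PyTorch builds, which may only honor the CUDA-specific name, a defensive pattern is to set both before torch is imported; this is a suggestion for downstream users, not something this PR does.

# Optional compatibility shim; not part of this PR.
import os

_alloc_conf = "expandable_segments:True"
os.environ.setdefault("PYTORCH_ALLOC_CONF", _alloc_conf)        # read by newer PyTorch
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", _alloc_conf)   # legacy name for older builds

import torch  # noqa: E402  -- import only after the allocator env vars are set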