diff --git a/gptqmodel/__init__.py b/gptqmodel/__init__.py index b54284928..07c28adcb 100644 --- a/gptqmodel/__init__.py +++ b/gptqmodel/__init__.py @@ -10,16 +10,23 @@ DEBUG_ON = env_flag("DEBUG") +from .utils.linalg_warmup import run_torch_linalg_warmup from .utils.threadx import DeviceThreadPool DEVICE_THREAD_POOL = DeviceThreadPool( inference_mode=True, + warmups={ + "cuda": run_torch_linalg_warmup, + "xpu": run_torch_linalg_warmup, + "mps": run_torch_linalg_warmup, + "cpu": run_torch_linalg_warmup, + }, workers={ "cuda:per": 4, "xpu:per": 1, "mps": 8, - "cpu": 8, + "cpu": min(12, max(1, (os.cpu_count() or 1) // 2)), "model_loader:cpu": 2, }, empty_cache_every_n=512, diff --git a/gptqmodel/eora/eora.py b/gptqmodel/eora/eora.py index 6c037471f..eeff9df13 100644 --- a/gptqmodel/eora/eora.py +++ b/gptqmodel/eora/eora.py @@ -12,7 +12,6 @@ from ..utils.logger import setup_logger from ..utils.rocm import IS_ROCM -from ..utils.safe import TORCH_LINALG log = setup_logger() @@ -89,7 +88,7 @@ def eora_compute_lora( original_backend = torch.backends.cuda.preferred_linalg_library() torch.backends.cuda.preferred_linalg_library(backend="magma") - L, Q = TORCH_LINALG.eigh(raw_scaling_diag_matrix) + L, Q = torch.linalg.eigh(raw_scaling_diag_matrix) if (L < 0).any(): ## When expanding the calibration data size for EoRA, I suggest maintaining the balance by allocating 50% to general input (C4) and the remaining 50% to downstream task data. @@ -107,7 +106,7 @@ def eora_compute_lora( delta_scale = torch.matmul(w_wq_delta, scaling_diag_matrix) - U, S, V = TORCH_LINALG.svd(delta_scale, full_matrices=False) + U, S, V = torch.linalg.svd(delta_scale, full_matrices=False) lowrank_r = rank truc_s = S[:lowrank_r] truc_u = U[:, :lowrank_r] diff --git a/gptqmodel/quantization/gptq.py b/gptqmodel/quantization/gptq.py index 4ff916554..21b5c3fb6 100644 --- a/gptqmodel/quantization/gptq.py +++ b/gptqmodel/quantization/gptq.py @@ -23,7 +23,6 @@ from ..quantization import QuantizeConfig from ..utils.device import get_device from ..utils.logger import setup_logger -from ..utils.safe import TORCH_LINALG from .gar import compose_final_perm, compute_global_perm, compute_local_perms, invert_perm from .quantizer import HF_OPTIMUM, Quantizer @@ -567,8 +566,8 @@ def hessian_inverse(self, H: torch.Tensor): H2 = H.clone() H2[diag, diag] += damp * mean # TODO call to torch.linalg is not threadsafe? Porque no? Esta muy mal. 
- H2 = TORCH_LINALG.cholesky(H2) - Hinv = TORCH_LINALG.cholesky(torch.cholesky_inverse(H2), upper=True) + H2 = torch.linalg.cholesky(H2) + Hinv = torch.linalg.cholesky(torch.cholesky_inverse(H2), upper=True) del H, H2 break except torch._C._LinAlgError as e: diff --git a/gptqmodel/quantization/qqq.py b/gptqmodel/quantization/qqq.py index 4d495a93f..92d8426aa 100644 --- a/gptqmodel/quantization/qqq.py +++ b/gptqmodel/quantization/qqq.py @@ -15,7 +15,6 @@ from ..looper.named_module import NamedModule from ..quantization.quantizer import HF_OPTIMUM from ..utils import setup_logger -from ..utils.safe import TORCH_LINALG from .gptq import get_number_of_rows_and_cols @@ -355,9 +354,9 @@ def quantize( damp = percdamp * torch.mean(torch.diag(H)) diag = torch.arange(self.columns, device=self.dev) H[diag, diag] += damp - H = TORCH_LINALG.cholesky(H) + H = torch.linalg.cholesky(H) H = torch.cholesky_inverse(H) - H = TORCH_LINALG.cholesky(H, upper=True) + H = torch.linalg.cholesky(H, upper=True) Hinv = H for i1 in range(0, self.columns, blocksize): diff --git a/gptqmodel/quantization/rotation/rotation.py b/gptqmodel/quantization/rotation/rotation.py index 7b74b6f38..110e45110 100644 --- a/gptqmodel/quantization/rotation/rotation.py +++ b/gptqmodel/quantization/rotation/rotation.py @@ -11,7 +11,6 @@ from ...utils.logger import setup_logger from ...utils.model import get_module_by_name_prefix -from ...utils.safe import TORCH_LINALG from ...utils.torch import torch_empty_cache from .hadamard_utils import apply_exact_had_to_linear, random_hadamard_matrix @@ -91,7 +90,7 @@ def random_orthogonal_matrix(size, device): """ torch.cuda.empty_cache() random_matrix = torch.randn(size, size, dtype=torch.float64).to(device) - q, r = TORCH_LINALG.qr(random_matrix) + q, r = torch.linalg.qr(random_matrix) q *= torch.sign(torch.diag(r)).unsqueeze(0) return q diff --git a/gptqmodel/utils/cuda_activation_buffer.py b/gptqmodel/utils/cuda_activation_buffer.py new file mode 100644 index 000000000..0344f2b6b --- /dev/null +++ b/gptqmodel/utils/cuda_activation_buffer.py @@ -0,0 +1,195 @@ +# SPDX-FileCopyrightText: 2025 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import dataclasses +import queue +import threading +import time +from typing import Any, Callable, List, Optional + +import torch + + +__all__ = ["ActivationPacket", "CudaEventActivationBuffer"] + + +@dataclasses.dataclass(slots=True) +class ActivationPacket: + """ + Tracks a single async device->host transfer triggered from a forward hook. + + The event is recorded on the dedicated copy stream so the consumer can + decide when to block. The `host_tensor` already points at pinned memory. + """ + + event: torch.cuda.Event + host_tensor: torch.Tensor + meta: Optional[Any] = None + created_at: float = dataclasses.field(default_factory=time.perf_counter) + + +class CudaEventActivationBuffer: + """ + Schedules non-blocking GPU->CPU copies using a dedicated CUDA stream + event. 
+ + Typical usage inside a forward hook:: + + buffer = CudaEventActivationBuffer(device="cuda:6") + + def hook(module, inputs, output): + tensor = output[0] if isinstance(output, (tuple, list)) else output + buffer.capture_async(tensor, meta=module.__class__.__name__) + + # elsewhere in consumer thread + for packet in buffer.drain(): + packet.event.synchronize() + process(packet.host_tensor, packet.meta) + + The hook thread returns immediately after enqueuing the async copy which + allows the caller to release activation VRAM without waiting on D2H traffic. + """ + + def __init__( + self, + device: torch.device | str | int, + stream: Optional[torch.cuda.Stream] = None, + pin_memory: bool = True, + host_allocator: Optional[Callable[[torch.Size, torch.dtype, torch.layout], torch.Tensor]] = None, + host_reclaimer: Optional[Callable[[torch.Tensor], None]] = None, + ) -> None: + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is not available for CudaEventActivationBuffer.") + + dev = torch.device(device) + if dev.type != "cuda": + raise ValueError(f"CudaEventActivationBuffer requires a CUDA device, got {dev}.") + + if dev.index is None: + dev = torch.device("cuda", torch.cuda.current_device()) + + self._device = dev + self._pin_memory = pin_memory + self._host_allocator = host_allocator + self._host_reclaimer = host_reclaimer + + with torch.cuda.device(self._device): + self._copy_stream = stream or torch.cuda.Stream() + + self._pending: "queue.SimpleQueue[ActivationPacket]" = queue.SimpleQueue() + self._lock = threading.Lock() + self._approx_pending = 0 + + def capture_async( + self, + activation: torch.Tensor, + *, + meta: Any = None, + enqueue: bool = True, + ) -> ActivationPacket: + """ + Enqueue an async D2H copy of ``activation`` onto the buffer stream. + + Returns an ActivationPacket which is also available later via drain(). + """ + if activation.device != self._device: + raise ValueError( + f"Activation tensor is on {activation.device}, expected {self._device}." + ) + + activation = activation.detach() + if not activation.is_contiguous(): + activation = activation.contiguous() + + host = self._allocate_host(activation) + + event = torch.cuda.Event(blocking=False, interprocess=False) + + current = torch.cuda.current_stream(self._device) + copy_stream = self._copy_stream + copy_stream.wait_stream(current) + + with torch.cuda.stream(copy_stream): + host.copy_(activation, non_blocking=True) + event.record(copy_stream) + + packet = ActivationPacket(event=event, host_tensor=host, meta=meta) + if enqueue: + self._pending_put(packet) + return packet + + def drain(self, *, wait: bool = True, max_items: Optional[int] = None) -> List[ActivationPacket]: + """ + Collect all queued packets (or up to max_items) in FIFO order. + + When ``wait`` is True we synchronize each packet's event before returning. + """ + packets: List[ActivationPacket] = [] + pulled = 0 + + while True: + if max_items is not None and pulled >= max_items: + break + + try: + packet = self._pending_get() + except queue.Empty: + break + + pulled += 1 + if wait: + packet.event.synchronize() + packets.append(packet) + + return packets + + def recycle(self, packet: ActivationPacket) -> None: + """ + Return a packet's host buffer to the allocator pool (if provided). + """ + if self._host_reclaimer is not None: + self._host_reclaimer(packet.host_tensor) + + def pending_count(self) -> int: + """ + Non-blocking length check. 
The SimpleQueue does not expose qsize() + reliably on all platforms, so we track with a lock-protected counter. + """ + with self._lock: + count = getattr(self, "_approx_pending", 0) + return count + + def __len__(self) -> int: + return self.pending_count() + + def __enter__(self) -> "CudaEventActivationBuffer": + return self + + def __exit__(self, exc_type, exc, tb) -> None: + self.drain(wait=True) + + def _pending_put(self, packet: ActivationPacket) -> None: + with self._lock: + self._approx_pending = getattr(self, "_approx_pending", 0) + 1 + self._pending.put(packet) + + def _pending_get(self) -> ActivationPacket: + packet = self._pending.get_nowait() + with self._lock: + self._approx_pending = max(getattr(self, "_approx_pending", 0) - 1, 0) + return packet + + def _allocate_host(self, activation: torch.Tensor) -> torch.Tensor: + if self._host_allocator is not None: + host = self._host_allocator(activation.shape, activation.dtype, activation.layout) + if not host.is_pinned(): + raise ValueError("Custom host allocator must return pinned CPU tensors.") + return host + return torch.empty( + activation.shape, + dtype=activation.dtype, + layout=activation.layout, + device="cpu", + pin_memory=self._pin_memory, + ) diff --git a/gptqmodel/utils/linalg_warmup.py b/gptqmodel/utils/linalg_warmup.py new file mode 100644 index 000000000..cace622f8 --- /dev/null +++ b/gptqmodel/utils/linalg_warmup.py @@ -0,0 +1,69 @@ +# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 +# Contact: qubitium@modelcloud.ai, x.com/qubitium + +from __future__ import annotations + +import contextlib +import threading + +import torch + + +_GLOBAL_WARMUP_LOCK = threading.Lock() + + +def _make_spd(size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor: + """Generate a small symmetric positive definite matrix.""" + base = torch.randn((size, size), device=device, dtype=dtype) + identity = torch.eye(size, device=device, dtype=dtype) + return base @ base.transpose(-1, -2) + identity * 1e-3 + + +def _run_cholesky_and_eigh(device: torch.device, dtype: torch.dtype) -> None: + spd = _make_spd(4, device, dtype) + torch.linalg.cholesky(spd) + torch.linalg.eigh(spd) + + +def _run_svd(device: torch.device, dtype: torch.dtype) -> None: + mat = torch.randn((4, 3), device=device, dtype=dtype) + torch.linalg.svd(mat, full_matrices=False) + + +def _run_qr(device: torch.device, dtype: torch.dtype) -> None: + square = torch.randn((4, 4), device=device, dtype=dtype) + torch.linalg.qr(square) + + +def run_torch_linalg_warmup(device: torch.device) -> None: + """ + Execute the torch.linalg operators used across the project once on the worker thread. + + Serialized under a global lock to avoid races inside PyTorch's lazy wrappers. The warmup + still runs once per physical device so backend-specific handles are initialized where needed. + """ + with _GLOBAL_WARMUP_LOCK: + dtypes = (torch.float32, torch.float64) + for dtype in dtypes: + _run_cholesky_and_eigh(device, dtype) + _run_svd(device, dtype) + _run_qr(device, dtype) + + if device.type == "cuda" and hasattr(torch.backends, "cuda"): + preferred = getattr(torch.backends.cuda, "preferred_linalg_library", None) + if callable(preferred): + current = preferred() + # Core warmup already ran using the currently preferred backend above. + # Some installations fall back to MAGMA when the primary solver is unavailable, + # so we pre-initialize MAGMA as well when it differs from the preferred backend. 
+ if current and current != "magma": + with contextlib.suppress(Exception): + torch.backends.cuda.preferred_linalg_library(backend="magma") + _run_cholesky_and_eigh(device, torch.float32) + if current: + with contextlib.suppress(Exception): + torch.backends.cuda.preferred_linalg_library(backend=current) + + +__all__ = ["run_torch_linalg_warmup"] diff --git a/gptqmodel/utils/model.py b/gptqmodel/utils/model.py index e2bbbd090..ab590ba41 100644 --- a/gptqmodel/utils/model.py +++ b/gptqmodel/utils/model.py @@ -705,72 +705,76 @@ def pack_module( packer_label = None - with lock: + if lock is not None: + with lock: + layers[name] = layer + qModules[name] = module + else: layers[name] = layer qModules[name] = module - # TODO FIX ME..remove hard coded qqq pack - if quant_linear_cls.QUANT_TYPE == "qqq": - if q_scales_extra is not None: - q_scales_extra = q_scales_extra.to(CPU) - packer_label = "module.pack" - with log_time_block( - packer_label, - logger=log, - module_name=name, - ): - module.pack(linear=layer, scales=q_scales, s_extra=q_scales_extra) - else: - effective_impl = (pack_impl or "original").lower() - - if effective_impl in {"cpu", "block", "pack_block"}: - effective_impl = "block" - elif effective_impl in {"original", "pack_original"}: + # TODO FIX ME..remove hard coded qqq pack + if quant_linear_cls.QUANT_TYPE == "qqq": + if q_scales_extra is not None: + q_scales_extra = q_scales_extra.to(CPU) + packer_label = "module.pack" + with log_time_block( + packer_label, + logger=log, + module_name=name, + ): + module.pack(linear=layer, scales=q_scales, s_extra=q_scales_extra) + else: + effective_impl = (pack_impl or "original").lower() + + if effective_impl in {"cpu", "block", "pack_block"}: + effective_impl = "block" + elif effective_impl in {"original", "pack_original"}: + effective_impl = "original" + elif effective_impl == "gpu": + if not HAS_CUDA: + log.warning("pack_module: GPU packing requested but CUDA is unavailable; falling back to original pack.") effective_impl = "original" - elif effective_impl == "gpu": - if not HAS_CUDA: - log.warning("pack_module: GPU packing requested but CUDA is unavailable; falling back to original pack.") - effective_impl = "original" - elif not hasattr(module, "pack_gpu"): - log.warning("pack_module: GPU packing requested but module lacks pack_gpu; falling back to original pack.") - effective_impl = "original" - elif effective_impl != "original": - log.warning( - "pack_module: Unknown pack_impl `%s`; defaulting to original pack.", - pack_impl, - ) + elif not hasattr(module, "pack_gpu"): + log.warning("pack_module: GPU packing requested but module lacks pack_gpu; falling back to original pack.") effective_impl = "original" + elif effective_impl != "original": + log.warning( + "pack_module: Unknown pack_impl `%s`; defaulting to original pack.", + pack_impl, + ) + effective_impl = "original" - label_map = { - "gpu": "module.pack_gpu", - "block": "module.pack_block", - "original": "module.pack_original", - } + label_map = { + "gpu": "module.pack_gpu", + "block": "module.pack_block", + "original": "module.pack_original", + } - packer_label = label_map[effective_impl] + packer_label = label_map[effective_impl] - with log_time_block( - packer_label, - logger=log, - module_name=name, - ): - if effective_impl == "gpu": - module.pack_gpu( - linear=layer, - scales=q_scales, - zeros=q_zeros, - g_idx=q_g_idx, - device=target_device, - ) - elif effective_impl == "block": - module.pack_block( - linear=layer, - scales=q_scales, - zeros=q_zeros, - g_idx=q_g_idx, - ) - 
else: - module.pack_original(linear=layer, scales=q_scales, zeros=q_zeros, g_idx=q_g_idx) + with log_time_block( + packer_label, + logger=log, + module_name=name, + ): + if effective_impl == "gpu": + module.pack_gpu( + linear=layer, + scales=q_scales, + zeros=q_zeros, + g_idx=q_g_idx, + device=target_device, + ) + elif effective_impl == "block": + module.pack_block( + linear=layer, + scales=q_scales, + zeros=q_zeros, + g_idx=q_g_idx, + ) + else: + module.pack_original(linear=layer, scales=q_scales, zeros=q_zeros, g_idx=q_g_idx) if ( quantize_config is not None diff --git a/gptqmodel/utils/safe.py b/gptqmodel/utils/safe.py index dc448b297..c92d422c3 100644 --- a/gptqmodel/utils/safe.py +++ b/gptqmodel/utils/safe.py @@ -14,7 +14,6 @@ from types import ModuleType import threadpoolctl as _threadpoolctl -import torch class ThreadSafe(ModuleType): @@ -96,14 +95,10 @@ def __dir__(self): def __repr__(self): return repr(self._value) - - -TORCH_LINALG = ThreadSafe(torch.linalg) THREADPOOLCTL = ThreadSafe(_threadpoolctl) GC = ThreadSafe(gc) __all__ = [ "ThreadSafe", - "TORCH_LINALG", "THREADPOOLCTL", "GC", ] diff --git a/gptqmodel/utils/threadx.py b/gptqmodel/utils/threadx.py index e0bb73069..7bef93341 100644 --- a/gptqmodel/utils/threadx.py +++ b/gptqmodel/utils/threadx.py @@ -13,7 +13,7 @@ import time import traceback from concurrent.futures import Future -from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Union import torch @@ -309,6 +309,7 @@ def __init__( name: Optional[str] = None, inference_mode: bool = False, cpu_core: Optional[int] = None, + warmup_fn: Optional[Callable[[torch.device], None]] = None, *, key_override: Optional[str] = None, ): @@ -316,6 +317,7 @@ def __init__( self.rwlock = rwlock self._on_task_finished = on_task_finished self._on_worker_exit = on_worker_exit + self._warmup_fn = warmup_fn if key_override is not None: self.key = key_override @@ -389,6 +391,16 @@ def _apply_cpu_affinity(self) -> None: ) self._affinity_applied = True + def _run_warmup(self) -> None: + warmup_fn = self._warmup_fn + if warmup_fn is None: + return + try: + with ctx(self.rwlock.reader(), _device_ctx(self.device)): + warmup_fn(self.device) + finally: + self._warmup_fn = None + def _run(self): """ Main loop: pull tasks, set device context, execute, mark completion, and @@ -400,6 +412,11 @@ def _run(self): """ self._apply_cpu_affinity() _activate_thread_device(self.device) + try: + self._run_warmup() + except BaseException as exc: + self._abort_process(exc) + return while not self._stop.is_set(): is_task, fn, args, kwargs, fut = self._q.get() try: @@ -537,6 +554,7 @@ def __init__( include_mps: bool = True, include_cpu: bool = True, inference_mode: bool = False, + warmups: Optional[Dict[str, Callable[[torch.device], None]]] = None, empty_cache_every_n: int = 50, # <=0 disables janitor workers: Optional[Dict[str, int]] = None, # e.g. {'cpu':4, 'cuda:per':1, 'cuda:0':3} gc_debounce_seconds: float = 0.02, # absorb bursty triggers before GC @@ -557,6 +575,9 @@ def __init__( (CUDA parents must include an explicit index, e.g. 'alias:cuda:0') Unspecified devices default to 1 worker each. gc_debounce_seconds: short wait to coalesce multiple triggers. + warmups: optional mapping from device family (e.g. 'cuda') to a callable + run once after the worker activates its device. A special key + 'default' applies when no family-specific warmup is found. 
pin_cpu_workers: bind CPU device workers to individual CPU cores when affinity APIs are available. Defaults to False so CPU tasks may float across cores unless explicitly opt-in. @@ -614,6 +635,11 @@ def __init__( self._last_gc_done_per_device: Dict[str, int] = {} self._inference_mode = bool(inference_mode) + self._worker_warmups = ( + {str(k).lower(): fn for k, fn in warmups.items()} if warmups else None + ) + self._warmup_lock = threading.Lock() + self._warmup_ran_keys: Set[str] = set() workers_cfg = workers or {} base_workers: Dict[str, int] = {} @@ -857,6 +883,28 @@ def _priority(dev_type: str) -> int: return plan + def _resolve_worker_warmup(self, dev: torch.device, key: str) -> Optional[Callable[[torch.device], None]]: + mapping = self._worker_warmups + if not mapping: + return None + family = dev.type.lower() + warmup = mapping.get(family) + primary_key = key.split(":", 1)[0].lower() + if warmup is None and primary_key in mapping: + warmup = mapping[primary_key] + if warmup is None: + warmup = mapping.get("default") + if warmup is None: + return None + + # Map virtual workers back to their parent key so warmup runs once per physical device. + physical_key = self._virtual_to_parent.get(key, key) + with self._warmup_lock: + if physical_key in self._warmup_ran_keys: + return None + self._warmup_ran_keys.add(physical_key) + return warmup + def _spawn_worker( self, dev: torch.device, @@ -867,6 +915,7 @@ def _spawn_worker( """ Create and start a worker bound to the provided device. """ + warmup_fn = self._resolve_worker_warmup(dev, key) w = _DeviceWorker( device=dev, rwlock=self._locks[key], @@ -875,6 +924,7 @@ def _spawn_worker( name=name, inference_mode=self._inference_mode, cpu_core=cpu_core, + warmup_fn=warmup_fn, key_override=key, ) return w diff --git a/tests/models/test_qwen3_moe.py b/tests/models/test_qwen3_moe.py index a36f4e480..749235c1e 100644 --- a/tests/models/test_qwen3_moe.py +++ b/tests/models/test_qwen3_moe.py @@ -6,7 +6,6 @@ from model_test import ModelTest - class TestQwen3Moe(ModelTest): NATIVE_MODEL_ID = "/monster/data/model/Qwen3-30B-A3B" QUANT_ARC_MAX_DELTA_FLOOR_PERCENT = 0.04 diff --git a/tests/models/test_qwen3_next.py b/tests/models/test_qwen3_next.py index dcb960d14..d25b8b5bd 100644 --- a/tests/models/test_qwen3_next.py +++ b/tests/models/test_qwen3_next.py @@ -23,6 +23,7 @@ class TestQwen3Next(ModelTest): QUANT_BATCH_SIZE = 4 CALIB_NOISE_MODE = "unseen" CALIB_NOISE_PERCENT = 0.025 + USE_FLASH_ATTN = True def test_mimo(self): self.quant_lm_eval() diff --git a/tests/test_benchmark_submodule_finalize.py b/tests/test_benchmark_submodule_finalize.py new file mode 100644 index 000000000..b67048953 --- /dev/null +++ b/tests/test_benchmark_submodule_finalize.py @@ -0,0 +1,402 @@ +import os + + +# force pytest runs to target GPU 7 (becomes cuda:0 after masking) +os.environ.setdefault("CUDA_VISIBLE_DEVICES", "7") + +import threading +import time +from types import SimpleNamespace +from unittest import mock + +import pytest +import torch + +from gptqmodel.looper import gptq_processor as gptq_processor_module +from gptqmodel.looper.gptq_processor import GPTQProcessor +from gptqmodel.looper.named_module import NamedModule +from gptqmodel.nn_modules.qlinear.torch import TorchQuantLinear +from gptqmodel.quantization.config import QuantizeConfig +from gptqmodel.utils.threadx import DeviceThreadPool + + +def _dummy_prepare_dataset(*, calibration_dataset, calibration_dataset_concat_size, calibration_dataset_sort, batch_size): + return calibration_dataset + + +class 
_DummyProgressBar: + def title(self, _): + return self + + def draw(self): + return None + + +class _DummyTimer: + def __init__(self): + self.events = [] + + def record(self, name, duration, source=None): + self.events.append({"name": name, "duration": duration, "source": source}) + + +class _TinyLinearModel(torch.nn.Module): + def __init__(self, features: int = 8): + super().__init__() + self.linear = torch.nn.Linear(features, features, bias=False) + + +class _LockTracker: + def __init__(self): + self._state_lock = threading.Lock() + self._entries = [] + self._active = 0 + self.max_active = 0 + + def enter(self): + start = time.perf_counter() + with self._state_lock: + self._active += 1 + self.max_active = max(self.max_active, self._active) + return start + + def exit(self, start_time: float): + end = time.perf_counter() + with self._state_lock: + self._active -= 1 + self._entries.append((start_time, end - start_time)) + + def total_duration(self) -> float: + with self._state_lock: + return sum(duration for _, duration in self._entries) + + def window_duration(self) -> float: + with self._state_lock: + if not self._entries: + return 0.0 + start_min = min(start for start, _ in self._entries) + end_max = max(start + duration for start, duration in self._entries) + return end_max - start_min + + def entry_count(self) -> int: + with self._state_lock: + return len(self._entries) + + +class _CountingLock: + def __init__(self, base_lock, tracker: _LockTracker): + self._base_lock = base_lock + self._tracker = tracker + self._start_time = None + + def __enter__(self): + self._base_lock.acquire() + self._start_time = self._tracker.enter() + return self + + def __exit__(self, exc_type, exc, tb): + self._tracker.exit(self._start_time) + self._start_time = None + self._base_lock.release() + return False + + +@pytest.mark.cuda +def test_submodule_finalize_timing(): + if not torch.cuda.is_available(): + pytest.skip("CUDA required for GPTQ finalize benchmark") + + if torch.cuda.device_count() == 0: + pytest.skip("No CUDA devices visible after masking; cannot benchmark finalize") + + torch.cuda.set_device(0) + device = torch.device("cuda", 0) + + base_model = _TinyLinearModel().to(device=device, dtype=torch.float16) + named_module = NamedModule( + base_model.linear, + name="linear", + full_name="linear", + layer_index=0, + ) + named_module.target_device = device + named_module.module.target_device = device + + qcfg = QuantizeConfig( + group_size=8, + desc_act=False, + sym=True, + mock_quantization=True, + pack_impl="original", + offload_to_disk=False, + ) + + processor = GPTQProcessor( + tokenizer=None, + qcfg=qcfg, + calibration=None, + prepare_dataset_func=_dummy_prepare_dataset, + calibration_concat_size=None, + calibration_sort=None, + batch_size=1, + require_fwd=False, + calculate_w_wq_diff=False, + ) + processor.pb = _DummyProgressBar() + + processor.preprocess(named_module, fail_safe=False) + processor.process(named_module) + + # move weights to CPU to satisfy create_quant_module invariants + base_model.to("cpu") + named_module.module.to("cpu") + named_module.target_device = torch.device("cpu") + named_module.module.target_device = torch.device("cpu") + + timer = _DummyTimer() + quant_model = SimpleNamespace( + model=base_model, + quantize_config=qcfg, + qlinear_kernel=TorchQuantLinear, + lm_head="lm_head", + quant_region_timer=timer, + quantized=False, + ) + + events = [] + last_checkpoint = None + + def record_gap(label: str, t_now: float): + nonlocal last_checkpoint + if last_checkpoint is None: 
+ last_checkpoint = t_now + return + gap = t_now - last_checkpoint + if gap > 0: + events.append((label, gap)) + last_checkpoint = t_now + + original_create = gptq_processor_module.create_quant_module + original_pack = gptq_processor_module.pack_module + original_result_pop = GPTQProcessor.result_pop + original_unregister = NamedModule.unregister_parameter + + def wrapped_create_quant_module(*args, **kwargs): + nonlocal last_checkpoint + start = time.perf_counter() + record_gap("state_cleanup", start) + try: + return original_create(*args, **kwargs) + finally: + end = time.perf_counter() + events.append(("create_quant_module", end - start)) + last_checkpoint = end + + def wrapped_pack_module(*args, **kwargs): + nonlocal last_checkpoint + start = time.perf_counter() + record_gap("between_create_and_pack", start) + packer = None + try: + packer = original_pack(*args, **kwargs) + return packer + finally: + end = time.perf_counter() + label = packer or "None" + events.append((f"pack_module[{label}]", end - start)) + last_checkpoint = end + + def wrapped_result_pop(self, *args, **kwargs): + nonlocal last_checkpoint + start = time.perf_counter() + record_gap("between_pack_and_result_pop", start) + try: + return original_result_pop(self, *args, **kwargs) + finally: + end = time.perf_counter() + events.append(("result_pop", end - start)) + last_checkpoint = end + + def wrapped_unregister_parameter(self, *args, **kwargs): + nonlocal last_checkpoint + start = time.perf_counter() + record_gap("between_result_pop_and_unregister", start) + try: + return original_unregister(self, *args, **kwargs) + finally: + end = time.perf_counter() + events.append(("unregister_parameter", end - start)) + last_checkpoint = end + + start_time = time.perf_counter() + last_checkpoint = start_time + + with ( + mock.patch("gptqmodel.looper.gptq_processor.create_quant_module", new=wrapped_create_quant_module), + mock.patch("gptqmodel.looper.gptq_processor.pack_module", new=wrapped_pack_module), + mock.patch.object(GPTQProcessor, "result_pop", new=wrapped_result_pop), + mock.patch.object(NamedModule, "unregister_parameter", new=wrapped_unregister_parameter), + ): + processor.submodule_finalize(named_module, quant_model) + + end_time = time.perf_counter() + record_gap("post_unregister_tail", end_time) + total_elapsed = end_time - start_time + events.append(("total_elapsed", total_elapsed)) + + # sanity checks to ensure finalize replaced module and cleared state + assert "q_scales" not in named_module.state + assert "q_zeros" not in named_module.state + assert "q_g_idx" not in named_module.state + assert isinstance(quant_model.model.linear, TorchQuantLinear) + + create_time = sum(duration for label, duration in events if label == "create_quant_module") + pack_time = sum(duration for label, duration in events if label.startswith("pack_module")) + cleanup_labels = { + "state_cleanup", + "between_create_and_pack", + "between_pack_and_result_pop", + "between_result_pop_and_unregister", + "post_unregister_tail", + } + cleanup_time = sum(duration for label, duration in events if label in cleanup_labels) + other_time = total_elapsed - (create_time + pack_time + cleanup_time) + + print("\nsubmodule_finalize timing breakdown (ms):") + for label, duration in events: + print(f" {label:<32} {duration * 1000:.3f}") + + print("\nSummary:") + print(f" total_elapsed_ms = {total_elapsed * 1000:.3f}") + print(f" create_quant_module_ms = {create_time * 1000:.3f}") + print(f" pack_module_ms = {pack_time * 1000:.3f}") + print(f" cleanup_gaps_ms = 
{cleanup_time * 1000:.3f}") + print(f" other_ms = {other_time * 1000:.3f}") + + if timer.events: + print("\nquant_region_timer records:") + for entry in timer.events: + name = entry["name"] + duration_ms = entry["duration"] * 1000 + source = entry.get("source") + print(f" {name:<32} {duration_ms:.3f} (source={source})") + + +def _prepare_modules(processor, qcfg, device, module_count): + modules = [] + for idx in range(module_count): + base_model = _TinyLinearModel().to(device=device, dtype=torch.float16) + named_module = NamedModule( + base_model.linear, + name=f"linear_{idx}", + full_name="linear", + layer_index=idx, + ) + named_module.target_device = device + named_module.module.target_device = device + + processor.preprocess(named_module, fail_safe=False) + processor.process(named_module) + + base_model.to("cpu") + named_module.module.to("cpu") + named_module.target_device = torch.device("cpu") + named_module.module.target_device = torch.device("cpu") + + quant_model = SimpleNamespace( + model=base_model, + quantize_config=qcfg, + qlinear_kernel=TorchQuantLinear, + lm_head="lm_head", + quant_region_timer=_DummyTimer(), + quantized=False, + ) + modules.append((named_module, quant_model)) + return modules + + +def _finalize_worker(processor, module, quant_model): + start = time.perf_counter() + processor.submodule_finalize(module, quant_model) + end = time.perf_counter() + thread_name = threading.current_thread().name + return thread_name, start, end + + +@pytest.mark.cuda +@pytest.mark.parametrize("cpu_workers", [8, 32]) +def test_submodule_finalize_threadpool_serialization(cpu_workers): + if not torch.cuda.is_available(): + pytest.skip("CUDA required for GPTQ finalize pool benchmark") + + if torch.cuda.device_count() == 0: + pytest.skip("No CUDA devices visible after masking; cannot benchmark finalize") + + torch.cuda.set_device(0) + device = torch.device("cuda", 0) + + qcfg = QuantizeConfig( + group_size=8, + desc_act=False, + sym=True, + mock_quantization=True, + pack_impl="original", + offload_to_disk=False, + ) + + processor = GPTQProcessor( + tokenizer=None, + qcfg=qcfg, + calibration=None, + prepare_dataset_func=_dummy_prepare_dataset, + calibration_concat_size=None, + calibration_sort=None, + batch_size=1, + require_fwd=False, + calculate_w_wq_diff=False, + ) + processor.pb = _DummyProgressBar() + + module_count = min(cpu_workers * 2, 32) + modules = _prepare_modules(processor, qcfg, device, module_count) + + lock_tracker = _LockTracker() + original_pack_module = gptq_processor_module.pack_module + + def instrumented_pack_module(*args, **kwargs): + args_list = list(args) + + if "lock" in kwargs and kwargs["lock"] is not None: + kwargs = dict(kwargs) + kwargs["lock"] = _CountingLock(kwargs["lock"], lock_tracker) + elif len(args_list) > 7: + args_list[7] = _CountingLock(args_list[7], lock_tracker) + else: + raise AssertionError("Expected lock argument in pack_module") + + return original_pack_module(*tuple(args_list), **kwargs) + + pool = DeviceThreadPool(include_cuda=False, include_cpu=True, workers={"cpu": cpu_workers}, inference_mode=True) + + try: + with mock.patch.object(gptq_processor_module, "pack_module", new=instrumented_pack_module): + futures = [ + pool.submit(torch.device("cpu"), _finalize_worker, processor, module, quant_model) + for module, quant_model in modules + ] + results = [future.result() for future in futures] + finally: + pool.shutdown() + + thread_names = {name for name, _, _ in results} + lock_total = lock_tracker.total_duration() + lock_window = 
lock_tracker.window_duration() + + assert lock_tracker.entry_count() == len(modules) + assert lock_tracker.max_active == 1, f"Expected serialized pack_module, saw max_active={lock_tracker.max_active}" + + if lock_total > 0 and lock_window > 0: + ratio = lock_total / lock_window + assert ratio <= 1.05, f"Expected serialized execution; total/window ratio={ratio:.3f}" + + assert len(thread_names) >= min(cpu_workers, len(modules), 2), "Thread pool failed to utilize multiple workers" diff --git a/tests/test_cuda_event_stream_activation_buffer.py b/tests/test_cuda_event_stream_activation_buffer.py new file mode 100644 index 000000000..52749c20a --- /dev/null +++ b/tests/test_cuda_event_stream_activation_buffer.py @@ -0,0 +1,254 @@ +# SPDX-FileCopyrightText: 2025 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +import statistics +import time + +import pytest +import torch +import torch.nn as nn + +from gptqmodel.utils.cuda_activation_buffer import CudaEventActivationBuffer + + +class PinnedHostPool: + def __init__(self) -> None: + self._store = {} + self.hits = 0 + self.misses = 0 + + def acquire(self, shape: torch.Size, dtype: torch.dtype, layout: torch.layout) -> torch.Tensor: + key = (tuple(shape), dtype, layout) + bucket = self._store.get(key) + if bucket: + self.hits += 1 + return bucket.pop() + self.misses += 1 + return torch.empty(shape, dtype=dtype, layout=layout, device="cpu", pin_memory=True) + + def release(self, tensor: torch.Tensor) -> None: + key = (tuple(tensor.shape), tensor.dtype, tensor.layout) + self._store.setdefault(key, []).append(tensor) + + +pytestmark = pytest.mark.skipif( + (not torch.cuda.is_available()) or torch.cuda.device_count() <= 6, + reason="CUDA device 6 is required for this test", +) + + +class ActivationEmitter(nn.Module): + """ + Minimal module that mimics a transformer block stage emitting large activations. + + We keep computation intentionally light so that host transfer dominates timing, + highlighting the benefit of async CUDA stream copies. 
+ """ + + def __init__(self, hidden_dim: int): + super().__init__() + self.bias = nn.Parameter(torch.zeros(1, 1, hidden_dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x + self.bias + + +class ForwardBlock(nn.Module): + def __init__(self, hidden_dim: int): + super().__init__() + self.emitter = ActivationEmitter(hidden_dim) + self.norm = nn.LayerNorm(hidden_dim) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.norm(self.emitter(x)) + + +def _run_variant( + mode: str, + model: nn.Module, + batch_input: torch.Tensor, + *, + warmup: int = 2, + steps: int = 6, + buffer_kwargs: dict | None = None, + recycle_packets: bool = False, +): + assert mode in {"gpu", "sync", "async"} + + device = batch_input.device + forward_latencies = [] + drain_latencies = [] + captured_outputs = [] + tmp_store = [] + + buffer_kwargs = buffer_kwargs or {} + buffer = CudaEventActivationBuffer(device=device, **buffer_kwargs) if mode == "async" else None + + def _hook(_module, _inputs, output): + tensor = output[0] if isinstance(output, (tuple, list)) else output + tensor = tensor.detach() + if mode == "async": + assert buffer is not None + buffer.capture_async(tensor) + elif mode == "sync": + tmp_store.append(tensor.to("cpu")) + else: + tmp_store.append(tensor) + + handle = model.emitter.register_forward_hook(_hook) + + try: + current_stream = torch.cuda.current_stream(device) + with torch.inference_mode(): + total_steps = warmup + steps + for idx in range(total_steps): + tmp_store.clear() + current_stream.synchronize() + t0 = time.perf_counter() + _ = model(batch_input) + current_stream.synchronize() + elapsed = time.perf_counter() - t0 + if idx >= warmup: + forward_latencies.append(elapsed) + + if mode == "async": + assert buffer is not None + t1 = time.perf_counter() + drained = buffer.drain(wait=True) + drain_elapsed = time.perf_counter() - t1 + if idx >= warmup: + drain_latencies.append(drain_elapsed) + for pkt in drained: + captured_outputs.append(pkt.host_tensor.clone()) + if recycle_packets: + buffer.recycle(pkt) + else: + t1 = time.perf_counter() + drained = list(tmp_store) + drain_elapsed = time.perf_counter() - t1 + if idx >= warmup: + drain_latencies.append(drain_elapsed) + captured_outputs.extend(drained) + finally: + handle.remove() + if buffer is not None: + leftover = buffer.drain(wait=True) + for pkt in leftover: + captured_outputs.append(pkt.host_tensor.clone()) + if recycle_packets: + buffer.recycle(pkt) + + return forward_latencies, drain_latencies, captured_outputs + + +def test_cuda_event_stream_activation_buffer_benchmarks(): + """ + Benchmarks three capture strategies that mirror GPTQ forward hooks: + + - gpu: baseline, keep activations resident on the device. + - sync: copy to CPU immediately via blocking `.cpu()`. + - async: enqueue D2H on a dedicated CUDA stream and wait only once drained. 
+ """ + device = torch.device("cuda", 6) + torch.cuda.set_device(device) + + batch = 4 + seq = 2048 + hidden_dim = 4096 + dtype = torch.float16 + + torch.manual_seed(0) + torch.cuda.manual_seed_all(0) + + template_model = ForwardBlock(hidden_dim).to(device=device, dtype=dtype).eval() + state = template_model.state_dict() + del template_model + + model_gpu = ForwardBlock(hidden_dim).to(device=device, dtype=dtype).eval() + model_gpu.load_state_dict(state) + + model_sync = ForwardBlock(hidden_dim).to(device=device, dtype=dtype).eval() + model_sync.load_state_dict(state) + + model_async = ForwardBlock(hidden_dim).to(device=device, dtype=dtype).eval() + model_async.load_state_dict(state) + + batch_input = torch.randn(batch, seq, hidden_dim, device=device, dtype=dtype) + + # Warm everything so subsequent measurements reflect steady-state timings. + for _ in range(3): + _ = model_gpu(batch_input) + _ = model_sync(batch_input) + _ = model_async(batch_input) + + gpu_forward, gpu_drain, gpu_outputs = _run_variant("gpu", model_gpu, batch_input) + sync_forward, sync_drain, sync_outputs = _run_variant("sync", model_sync, batch_input) + async_forward, async_drain, async_outputs = _run_variant("async", model_async, batch_input) + + pool = PinnedHostPool() + pool_kwargs = { + "host_allocator": pool.acquire, + "host_reclaimer": pool.release, + } + async_pool_forward, async_pool_drain, async_pool_outputs = _run_variant( + "async", + model_async, + batch_input, + warmup=0, + steps=5, + buffer_kwargs=pool_kwargs, + recycle_packets=True, + ) + + gpu_outputs_cpu = [t.detach().cpu() for t in gpu_outputs] + + assert len(gpu_outputs) == len(sync_outputs) == len(async_outputs) > 0 + for baseline, candidate in zip(sync_outputs, async_outputs): + assert torch.allclose(baseline, candidate, atol=0, rtol=0) + for baseline, candidate in zip(sync_outputs, gpu_outputs_cpu): + assert torch.allclose(baseline, candidate, atol=0, rtol=0) + reference_cpu = sync_outputs[0] + for candidate in async_pool_outputs: + assert torch.allclose(reference_cpu, candidate, atol=0, rtol=0) + + mean_gpu_forward = statistics.mean(gpu_forward) + mean_sync_forward = statistics.mean(sync_forward) + mean_async_forward = statistics.mean(async_forward) + + mean_gpu_drain = statistics.mean(gpu_drain) + mean_sync_drain = statistics.mean(sync_drain) + mean_async_drain = statistics.mean(async_drain) + + async_combined = [f + d for f, d in zip(async_forward, async_drain)] + combined_mean = statistics.mean(async_combined) + + # Async capture should avoid additional forward blocking relative to sync copies. + assert mean_async_forward <= mean_sync_forward + + # Async totals should be bounded by the synchronous copy baseline. 
+ assert combined_mean <= mean_sync_forward * 1.1 + + miss_forward = async_pool_forward[0] + hit_forward_mean = statistics.mean(async_pool_forward[1:]) if len(async_pool_forward) > 1 else miss_forward + miss_drain = async_pool_drain[0] + hit_drain_mean = statistics.mean(async_pool_drain[1:]) if len(async_pool_drain) > 1 else miss_drain + + assert pool.misses >= 1 + assert pool.hits >= 1 + assert hit_forward_mean <= miss_forward * 0.75 + + print( + "[CUDA6 Activation Copy Benchmark]\n" + f" gpu forward mean: {mean_gpu_forward * 1e3:.2f} ms\n" + f" gpu drain mean: {mean_gpu_drain * 1e3:.2f} ms\n" + f" sync forward mean: {mean_sync_forward * 1e3:.2f} ms\n" + f" sync drain mean: {mean_sync_drain * 1e3:.2f} ms\n" + f" async forward mean:{mean_async_forward * 1e3:.2f} ms\n" + f" async drain mean: {mean_async_drain * 1e3:.2f} ms\n" + f" async combined: {combined_mean * 1e3:.2f} ms\n" + f" pool miss forward: {miss_forward * 1e3:.2f} ms\n" + f" pool hit forward: {hit_forward_mean * 1e3:.2f} ms\n" + f" pool miss drain: {miss_drain * 1e3:.2f} ms\n" + f" pool hit drain: {hit_drain_mean * 1e3:.2f} ms\n" + f" pool stats (hits/misses): {pool.hits}/{pool.misses}" + ) diff --git a/tests/test_gptq_add_batch_cpu.py b/tests/test_gptq_add_batch_cpu.py new file mode 100644 index 000000000..a52b8b7c1 --- /dev/null +++ b/tests/test_gptq_add_batch_cpu.py @@ -0,0 +1,137 @@ +# SPDX-FileCopyrightText: 2025 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +import time +from dataclasses import dataclass +from typing import List, Tuple + +import pytest +import torch +import torch.nn as nn + +from gptqmodel.quantization.gptq import GPTQ + + +pytestmark = pytest.mark.skipif( + (not torch.cuda.is_available()) or torch.cuda.device_count() <= 6, + reason="CUDA device 6 is required for this benchmark test", +) + + +@dataclass +class PathStats: + per_batch_seconds: float + total_seconds: float + peak_bytes: int + batches_measured: int + + +def _make_module(hidden_dim: int, device: torch.device) -> nn.Linear: + layer = nn.Linear(hidden_dim, hidden_dim, bias=False, dtype=torch.float16) + return layer.to(device).eval() + + +def _generate_input( + batch_size: int, + seq_len: int, + hidden_dim: int, + device: torch.device, +) -> torch.Tensor: + return torch.randn(batch_size, seq_len, hidden_dim, device=device, dtype=torch.float16) + + +def _benchmark_add_batch( + module: nn.Module, + device: torch.device, + hidden_dim: int, + *, + total_batches: int, + warmup_batches: int, + batch_size: int, + seq_len: int, + use_cpu_queue: bool, +) -> PathStats: + gptq = GPTQ(module) + dummy_outputs = torch.empty(0, device=device) + + def _run_batch(idx: int) -> None: + activations = _generate_input(batch_size, seq_len, hidden_dim, device=device) + if use_cpu_queue: + cpu_activations = activations.detach().to(device="cpu") + del activations + gptq.add_batch(cpu_activations, dummy_outputs, batch_index=idx) + else: + gptq.add_batch(activations, dummy_outputs, batch_index=idx) + + for idx in range(warmup_batches): + _run_batch(idx) + + torch.cuda.synchronize(device) + baseline_alloc = torch.cuda.memory_allocated(device) + torch.cuda.reset_peak_memory_stats(device) + + measured = 0 + start = time.perf_counter() + + for idx in range(warmup_batches, total_batches): + _run_batch(idx) + measured += 1 + + torch.cuda.synchronize(device) + total = time.perf_counter() - start + peak_alloc = torch.cuda.max_memory_allocated(device) + peak_bytes = max(0, peak_alloc - baseline_alloc) + per_batch = total / measured if measured else 0.0 + return 
PathStats(per_batch_seconds=per_batch, total_seconds=total, peak_bytes=peak_bytes, batches_measured=measured) + + +def test_gptq_add_batch_cpu_vs_gpu_queue(): + device = torch.device("cuda", 6) + torch.cuda.set_device(device) + + configs: List[Tuple[str, int]] = [ + ("llama3", 4096), + ("qwen3", 3584), + ] + + total_batches = 8 + warmup_batches = 2 + batch_size = 4 + seq_len = 512 + + for name, hidden_dim in configs: + module_gpu = _make_module(hidden_dim, device=device) + gpu_stats = _benchmark_add_batch( + module_gpu, + device, + hidden_dim, + total_batches=total_batches, + warmup_batches=warmup_batches, + batch_size=batch_size, + seq_len=seq_len, + use_cpu_queue=False, + ) + + module_cpu_queue = _make_module(hidden_dim, device=device) + cpu_stats = _benchmark_add_batch( + module_cpu_queue, + device, + hidden_dim, + total_batches=total_batches, + warmup_batches=warmup_batches, + batch_size=batch_size, + seq_len=seq_len, + use_cpu_queue=True, + ) + + assert gpu_stats.batches_measured == cpu_stats.batches_measured == total_batches - warmup_batches + + print( + f"[{name.upper()}] GPU queue: {gpu_stats.per_batch_seconds*1e3:.3f} ms/batch " + f"(total {gpu_stats.total_seconds:.3f} s, peak GPU alloc {gpu_stats.peak_bytes/1024/1024:.2f} MiB) | " + f"CPU queue: {cpu_stats.per_batch_seconds*1e3:.3f} ms/batch " + f"(total {cpu_stats.total_seconds:.3f} s, peak GPU alloc {cpu_stats.peak_bytes/1024/1024:.2f} MiB)" + ) + + assert cpu_stats.per_batch_seconds >= gpu_stats.per_batch_seconds + assert cpu_stats.peak_bytes <= gpu_stats.peak_bytes diff --git a/tests/test_gptq_hessian_chunking.py b/tests/test_gptq_hessian_chunking.py new file mode 100644 index 000000000..11ba9b409 --- /dev/null +++ b/tests/test_gptq_hessian_chunking.py @@ -0,0 +1,126 @@ +# SPDX-FileCopyrightText: 2025 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +import math +import time +from typing import Dict, List, Optional, Tuple + +import pytest +import torch +import torch.nn as nn + +from gptqmodel.quantization import gptq as gptq_mod +from gptqmodel.quantization.config import QuantizeConfig +from gptqmodel.quantization.gptq import GPTQ + + +pytestmark = pytest.mark.skipif( + (not torch.cuda.is_available()) or torch.cuda.device_count() <= 6, + reason="CUDA device 6 is required for this benchmark test", +) + + +def _make_module(hidden_dim: int, device: torch.device) -> nn.Linear: + layer = nn.Linear(hidden_dim, hidden_dim, bias=False, dtype=torch.float16) + return layer.to(device).eval() + + +def _run_add_batch( + hidden_dim: int, + *, + device: torch.device, + batch_size: int, + seq_len: int, + total_batches: int, + warmup_batches: int, + chunk_bytes: Optional[int], +) -> Dict[str, float]: + qcfg = QuantizeConfig() + qcfg.hessian_chunk_bytes = chunk_bytes + + module = _make_module(hidden_dim, device) + gptq = GPTQ(module, qcfg=qcfg) + dummy_outputs = torch.empty(0, device=device) + + def _one_batch(idx: int): + activations = torch.randn(batch_size, seq_len, hidden_dim, device=device, dtype=torch.float16) + gptq.add_batch(activations, dummy_outputs, batch_index=idx) + + for idx in range(warmup_batches): + _one_batch(idx) + + torch.cuda.synchronize(device) + baseline_alloc = torch.cuda.memory_allocated(device) + torch.cuda.reset_peak_memory_stats(device) + + measured = 0 + start = time.perf_counter() + for idx in range(warmup_batches, total_batches): + _one_batch(idx) + measured += 1 + torch.cuda.synchronize(device) + elapsed = time.perf_counter() - start + peak_alloc = torch.cuda.max_memory_allocated(device) + 
gptq_mod._WORKSPACE_CACHE.clear() + + per_batch = elapsed / measured if measured else 0.0 + activation_mb = (batch_size * seq_len * hidden_dim * 2) / (1024**2) + peak_delta_mb = max(0.0, (peak_alloc - baseline_alloc) / (1024**2)) + + chunk_rows = gptq._resolve_hessian_chunk_size(batch_size * seq_len, torch.float32) + + return { + "chunk_bytes": chunk_bytes, + "per_batch_sec": per_batch, + "total_sec": elapsed, + "peak_delta_mb": peak_delta_mb, + "activation_mb": activation_mb, + "chunk_rows": chunk_rows, + } + + +def test_hessian_chunking_vram_vs_latency(): + device = torch.device("cuda", 6) + torch.cuda.set_device(device) + + configs: List[Tuple[str, int]] = [ + ("llama3", 4096), + ("qwen3", 3584), + ] + chunk_options = [None, 64 << 20, 32 << 20, 16 << 20, 8 << 20, 4 << 20] + + total_batches = 6 + warmup_batches = 2 + batch_size = 4 + seq_len = 512 + + for name, hidden_dim in configs: + results: List[Dict[str, float]] = [] + for chunk_bytes in chunk_options: + stats = _run_add_batch( + hidden_dim, + device=device, + batch_size=batch_size, + seq_len=seq_len, + total_batches=total_batches, + warmup_batches=warmup_batches, + chunk_bytes=chunk_bytes, + ) + results.append(stats) + + baseline = results[0] + best = min(results, key=lambda x: x["peak_delta_mb"]) + + print(f"\n[{name.upper()}] activation ~{baseline['activation_mb']:.2f} MiB") + for stats in results: + chunk_label = "none" if stats["chunk_bytes"] is None else f"{stats['chunk_bytes'] // (1<<20)} MiB" + print( + f" chunk={chunk_label:<5} | chunk_rows={stats['chunk_rows']} | " + f"peak ΔVRAM {stats['peak_delta_mb']:.2f} MiB | per-batch {stats['per_batch_sec'] * 1e3:.2f} ms" + ) + + assert math.isclose(baseline["activation_mb"], best["activation_mb"], rel_tol=1e-6) + + smallest_chunk = results[-1] + assert smallest_chunk["peak_delta_mb"] >= baseline["peak_delta_mb"] + assert smallest_chunk["per_batch_sec"] <= baseline["per_batch_sec"] * 4.0 diff --git a/tests/test_hessian_accumulation_cpu.py b/tests/test_hessian_accumulation_cpu.py new file mode 100644 index 000000000..a605936ac --- /dev/null +++ b/tests/test_hessian_accumulation_cpu.py @@ -0,0 +1,124 @@ +# SPDX-FileCopyrightText: 2025 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +import time +from dataclasses import dataclass +from typing import List, Tuple + +import pytest +import torch +import torch.nn as nn + +from gptqmodel.quantization.gptq import GPTQ +from gptqmodel.utils.safe import THREADPOOLCTL + + +pytestmark = pytest.mark.skipif( + (not torch.cuda.is_available()) or torch.cuda.device_count() <= 6, + reason="CUDA device 6 is required for this benchmark test", +) + + +@dataclass +class BenchmarkResult: + per_batch_seconds: float + total_seconds: float + batches_measured: int + + +def _make_module(hidden_dim: int, device: torch.device) -> nn.Linear: + module = nn.Linear(hidden_dim, hidden_dim, bias=False, dtype=torch.float16).to(device) + module.eval() + return module + + +def _generate_samples( + batches: int, + batch_size: int, + seq_len: int, + hidden_dim: int, + device: torch.device, +) -> List[torch.Tensor]: + torch.manual_seed(0) + torch.cuda.manual_seed_all(0) + samples = [] + for _ in range(batches): + sample = torch.randn(batch_size, seq_len, hidden_dim, device=device, dtype=torch.float16) + samples.append(sample.contiguous()) + return samples + + +def _to_pinned_cpu(samples: List[torch.Tensor]) -> List[torch.Tensor]: + pinned = [] + for tensor in samples: + host = tensor.to(device="cpu", non_blocking=False).contiguous() + pinned.append(host.pin_memory()) + 
return pinned + + +def _benchmark_add_batch( + module: nn.Module, + samples: List[torch.Tensor], + warmup_batches: int, + device: torch.device, +) -> BenchmarkResult: + gptq = GPTQ(module) + dummy_outputs = torch.empty(0, device=samples[0].device) + + # Warmup to populate internal workspaces and caches + for idx in range(warmup_batches): + gptq.add_batch(samples[idx], dummy_outputs, batch_index=idx) + + if samples[0].device.type == "cuda": + torch.cuda.synchronize(device) + + measured = 0 + start = time.perf_counter() + + for idx in range(warmup_batches, len(samples)): + gptq.add_batch(samples[idx], dummy_outputs, batch_index=idx) + measured += 1 + + if samples[0].device.type == "cuda": + torch.cuda.synchronize(device) + + total = time.perf_counter() - start + per_batch = total / measured if measured else 0.0 + return BenchmarkResult(per_batch_seconds=per_batch, total_seconds=total, batches_measured=measured) + + +def test_hessian_accumulation_cpu_vs_gpu(): + device = torch.device("cuda", 6) + torch.cuda.set_device(device) + + configs: List[Tuple[str, int]] = [ + ("llama3", 4096), + ("qwen3", 3584), + ] + + total_batches = 6 + warmup_batches = 2 + batch_size = 4 + seq_len = 256 + + for name, hidden_dim in configs: + module_gpu = _make_module(hidden_dim, device=device) + gpu_samples = _generate_samples(total_batches, batch_size, seq_len, hidden_dim, device) + cpu_samples = _to_pinned_cpu(gpu_samples) + + gpu_result = _benchmark_add_batch(module_gpu, gpu_samples, warmup_batches, device) + + module_cpu = _make_module(hidden_dim, device=device) + with THREADPOOLCTL.threadpool_limits(limits=16): + cpu_result = _benchmark_add_batch(module_cpu, cpu_samples, warmup_batches, device) + + assert gpu_result.batches_measured == cpu_result.batches_measured == total_batches - warmup_batches + + print( + f"[{name.upper()}] GPU add_batch: {gpu_result.per_batch_seconds*1e3:.3f} ms/batch " + f"(total {gpu_result.total_seconds:.3f} s) | " + f"CPU add_batch (threads=16): {cpu_result.per_batch_seconds*1e3:.3f} ms/batch " + f"(total {cpu_result.total_seconds:.3f} s)" + ) + + assert cpu_result.per_batch_seconds >= gpu_result.per_batch_seconds diff --git a/tests/test_linalg.py b/tests/test_linalg.py index 10908c076..31fec6585 100644 --- a/tests/test_linalg.py +++ b/tests/test_linalg.py @@ -15,8 +15,6 @@ from logbar import LogBar # noqa: E402 from parameterized import parameterized # noqa: E402 -from gptqmodel.utils.safe import TORCH_LINALG # noqa: E402 - log = LogBar.shared() @@ -35,7 +33,7 @@ class Test(unittest.TestCase): ) def test_linalg_eigh(self, dtype: torch.dtype, size: int): matrix = torch.randn([size, size], device=ROCM, dtype=dtype) - TORCH_LINALG.eigh(matrix) + torch.linalg.eigh(matrix) @parameterized.expand( [ @@ -51,6 +49,6 @@ def test_linalg_eigh_magma(self, dtype: torch.dtype, size: int): torch.backends.cuda.preferred_linalg_library(backend="magma") matrix = torch.randn([size, size], device=ROCM, dtype=dtype) - TORCH_LINALG.eigh(matrix) + torch.linalg.eigh(matrix) torch.backends.cuda.preferred_linalg_library(backend=original_backend) diff --git a/tests/test_threadpoolctl.py b/tests/test_threadpoolctl.py new file mode 100644 index 000000000..da9aae4af --- /dev/null +++ b/tests/test_threadpoolctl.py @@ -0,0 +1,78 @@ +# SPDX-FileCopyrightText: 2025 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +import time +from typing import Dict, List + +import torch + +from gptqmodel.utils.safe import THREADPOOLCTL +from gptqmodel.utils.threadx import DeviceThreadPool + + +def _run_thread_limit(pool: 
DeviceThreadPool, limit: int) -> Dict[str, float]: + d_cpu = torch.device("cpu") + futures = [] + + def worker(): + with THREADPOOLCTL.threadpool_limits(limits=limit): + start = time.perf_counter() + info = THREADPOOLCTL.threadpool_info() + counts = [entry.get("num_threads", 0) for entry in info if entry.get("num_threads", 0) > 0] + # Exercise BLAS path + a = torch.randn(512, 256, device=d_cpu) + b = torch.randn(256, 512, device=d_cpu) + _ = a @ b + elapsed = time.perf_counter() - start + max_threads = max(counts) if counts else 0 + return elapsed, max_threads + + for _ in range(8): + futures.append(pool.submit(d_cpu, worker)) + + pool.wait(d_cpu) + + timings = [] + thread_counts = [] + for fut in futures: + elapsed, max_threads = fut.result(timeout=5) + timings.append(elapsed) + thread_counts.append(max_threads) + + mean_time = sum(timings) / len(timings) + return { + "mean_time": mean_time, + "thread_counts": thread_counts, + } + + +def test_threadpool_limits_inside_device_threadpool(): + d_cpu = torch.device("cpu") + pool = DeviceThreadPool( + devices=[d_cpu], + include_cuda=False, + include_xpu=False, + include_mps=False, + include_cpu=True, + workers={"cpu": 8}, + inference_mode=True, + ) + + try: + limits = [1, 2, 4, 8, 16, 32] + results: List[Dict[str, float]] = [] + + for limit in limits: + result = _run_thread_limit(pool, limit) + results.append(result) + for count in result["thread_counts"]: + if count: + assert count <= limit + for limit, result in zip(limits, results): + print( + f"[thread limit={limit}] mean worker time: {result['mean_time'] * 1e3:.3f} ms " + f"| thread counts: {result['thread_counts']}" + ) + finally: + pool.shutdown(wait=True) +
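
Note on the new `warmups` hook: the patch only exercises it implicitly through the `DEVICE_THREAD_POOL` wiring in `gptqmodel/__init__.py`, so the following is a minimal stand-alone sketch (not part of the patch) of how the mapping might be driven directly. It uses only constructor arguments and methods that appear in this diff (`devices`, `include_*`, `workers`, `inference_mode`, `warmups`, `submit`, `shutdown`); the counting wrapper `counting_warmup` and the exact timeout are illustrative assumptions, not project code.

    import threading

    import torch

    from gptqmodel.utils.linalg_warmup import run_torch_linalg_warmup
    from gptqmodel.utils.threadx import DeviceThreadPool

    warmup_calls: list[str] = []
    _record_lock = threading.Lock()


    def counting_warmup(device: torch.device) -> None:
        # Hypothetical wrapper: delegate to the shared linalg warmup and
        # record which device actually ran it.
        run_torch_linalg_warmup(device)
        with _record_lock:
            warmup_calls.append(str(device))


    pool = DeviceThreadPool(
        devices=[torch.device("cpu")],
        include_cuda=False,
        include_xpu=False,
        include_mps=False,
        include_cpu=True,
        workers={"cpu": 4},
        inference_mode=True,
        # Keys are device families ("cpu", "cuda", ...); per the docstring,
        # a "default" key would act as the fallback when no family matches.
        warmups={"cpu": counting_warmup},
    )

    try:
        # Any submitted work runs only after the worker's warmup has completed.
        fut = pool.submit(torch.device("cpu"), torch.linalg.qr, torch.randn(8, 8))
        fut.result(timeout=30)
    finally:
        pool.shutdown(wait=True)

    # Warmups are de-duplicated per physical device key, so even with four
    # CPU workers the callable should have executed exactly once.
    assert warmup_calls == ["cpu"]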