From bd5f93374a6264978cff543776b8c1b14f484a5a Mon Sep 17 00:00:00 2001 From: Qubitium Date: Fri, 17 Oct 2025 04:53:24 +0000 Subject: [PATCH 1/2] logbar 0.1.4 reduce pb flicker --- pyproject.toml | 2 +- requirements.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index e03324abb..239f45dbd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,7 +46,7 @@ dependencies = [ "huggingface_hub>=0.34.4", "random_word>=1.0.13", "tokenicer>=0.0.5", - "logbar>=0.1.3", + "logbar>=0.1.4", "maturin>=1.9.4", # required by safetensors and hf_transfer "datasets>=3.6.0", "pyarrow>=21.0", diff --git a/requirements.txt b/requirements.txt index df93ad91d..0134a22e4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,7 +13,7 @@ hf_transfer>=0.1.9 huggingface_hub>=0.34.4 random_word>=1.0.13 tokenicer>=0.0.5 -logbar>=0.1.3 +logbar>=0.1.4 maturin>=1.9.4 datasets>=3.6.0 pyarrow>=21.0 From 4a78acd48bc7bc367c7c7b3873d24cd6a91a710c Mon Sep 17 00:00:00 2001 From: Qubitium Date: Fri, 17 Oct 2025 05:08:04 +0000 Subject: [PATCH 2/2] use cuda events to async copy quantized tensors --- gptqmodel/looper/gptq_processor.py | 46 ++- gptqmodel/looper/named_module.py | 122 +++++- tests/test_bench_cuda_even_d2h.py | 534 +++++++++++++++++++++++++ tests/test_gptq_processor_streaming.py | 89 +++++ tests/test_named_module.py | 128 ++++++ 5 files changed, 888 insertions(+), 31 deletions(-) create mode 100644 tests/test_bench_cuda_even_d2h.py create mode 100644 tests/test_gptq_processor_streaming.py create mode 100644 tests/test_named_module.py diff --git a/gptqmodel/looper/gptq_processor.py b/gptqmodel/looper/gptq_processor.py index 7181eb783..dc0bd922e 100644 --- a/gptqmodel/looper/gptq_processor.py +++ b/gptqmodel/looper/gptq_processor.py @@ -24,11 +24,24 @@ from ..utils.logger import setup_logger, log_time_block from ..utils.device import get_device from ..utils.model import create_quant_module, find_modules, move_to, pack_model, pack_module -from ..utils.torch import HAS_CUDA, tf32_disable_guard, torch_streamCtx, torch_sync +from ..utils.torch import tf32_disable_guard log = setup_logger() lock = threading.Lock() + +class _PinnedHostPool: + def __init__(self) -> None: + self._lock = threading.Lock() + + def acquire(self, shape: torch.Size, dtype: torch.dtype, layout: torch.layout) -> torch.Tensor: + return torch.empty(shape, dtype=dtype, layout=layout, device="cpu", pin_memory=True) + + def release(self, tensor: torch.Tensor) -> None: + # No pooling to avoid cross-thread pinned storage reuse issues. 
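+        # Callers clone the data out of pinned memory before releasing (see
+        # submodule_finalize), so the buffer can simply be left to the GC here.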
+ return None + + class GPTQProcessor(LoopProcessor): def __init__(self, tokenizer, qcfg: QuantizeConfig, calibration, prepare_dataset_func, calibration_concat_size: Optional[int], calibration_sort: Optional[str], batch_size: int, @@ -42,6 +55,7 @@ def __init__(self, tokenizer, qcfg: QuantizeConfig, calibration, prepare_dataset self.calculate_w_wq_diff = calculate_w_wq_diff self.avg_losses = [] + self._host_pool = _PinnedHostPool() def set_calibration_dataset(self, calibration_dataset): raise NotImplementedError("GPTQProcessor's calibration_dataset cannot be modified") @@ -162,15 +176,17 @@ def process(self, module: NamedModule): with tf32_disable_guard(): wq, q_scales, q_zeros, q_g_idx, duration, avg_loss, damp_percent, nsamples = g.quantize() - q_scales = q_scales.to(CPU) - q_zeros = q_zeros.to(CPU) - q_g_idx = q_g_idx.to(CPU) + module.stream_state_payload_to_cpu( + { + "q_scales": q_scales, + "q_zeros": q_zeros, + "q_g_idx": q_g_idx, + }, + host_pool=self._host_pool, + ) + del q_scales, q_zeros, q_g_idx with self.lock: - module.state.update({"q_scales": q_scales}) - module.state.update({"q_zeros": q_zeros}) - module.state.update({"q_g_idx": q_g_idx}) - self.durations.append(duration) self.avg_losses.append(avg_loss) self.module_names.append(f"layer-{module.layer_index}-{module.name}") @@ -248,6 +264,7 @@ def submodule_finalize(self, module: NamedModule, model: BaseQModel, **kwargs): # module.weight.data = move_to(module.state.pop("wq"), device=CPU) # large weights is slow to init on cpu # cleanup all memory or states vars persistently added by this processor + module.stream_sync() with (self.lock): # if calculate_w_wq_diff is enabled (eora), we need to revert our original wq if self.calculate_w_wq_diff: @@ -256,9 +273,10 @@ def submodule_finalize(self, module: NamedModule, model: BaseQModel, **kwargs): module.state.pop("w", None) # module.state.pop("w_wq_diff", None) - q_zeros = module.state.pop("q_zeros") - q_scales = module.state.pop("q_scales") - q_g_idx = module.state.pop("q_g_idx") + # need to clone to due to steamed pinned memory and access on diff thread + q_zeros = module.state.pop("q_zeros").clone() + q_scales = module.state.pop("q_scales").clone() + q_g_idx = module.state.pop("q_g_idx").clone() assert q_zeros.device == CPU assert q_scales.device == CPU @@ -332,6 +350,7 @@ def submodule_finalize(self, module: NamedModule, model: BaseQModel, **kwargs): with self.lock: self.result_pop(module.full_name) + self._release_host_buffers(q_scales, q_zeros, q_g_idx) module.unregister_parameter("weight") def finalize(self, model: BaseQModel, **kwargs): @@ -354,3 +373,8 @@ def name(self) -> str: # TODO fix me..this hacks inherited base class logic, why not override name in gptqv2? 
qcfg = self.qcfg_dynamic if self.qcfg_dynamic is not None else self.qcfg return "gptq v2" if qcfg.v2 else "gptq" + + def _release_host_buffers(self, *tensors: torch.Tensor) -> None: + for tensor in tensors: + if isinstance(tensor, torch.Tensor): + self._host_pool.release(tensor) diff --git a/gptqmodel/looper/named_module.py b/gptqmodel/looper/named_module.py index 92c0da026..c9c85159c 100644 --- a/gptqmodel/looper/named_module.py +++ b/gptqmodel/looper/named_module.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 # Contact: qubitium@modelcloud.ai, x.com/qubitium import threading -from typing import Any, Optional +from typing import Any, Dict, Optional import torch import transformers @@ -16,7 +16,6 @@ log = setup_logger() class NamedModule(torch.nn.Module): - _lock = threading.Lock() def __init__(self, module: torch.nn.Module, name: str, full_name:str, layer_index: int) -> None: super().__init__() @@ -30,6 +29,7 @@ def __init__(self, module: torch.nn.Module, name: str, full_name:str, layer_inde # persistent work state for named module (used by some LoopProcessors) # store all `processed()` work state/data/result here self.state = {} + self._state_lock = threading.RLock() # print(f"NamedModule init: name: `{name}, full-name: `{full_name}`") @@ -72,34 +72,26 @@ def named_buffers(self, prefix: str = "", recurse: bool = True): def register_buffer( self, name: str, tensor: Optional[Tensor], persistent: bool = True ) -> None: - with self._lock: + with self._state_lock: return self.module.register_buffer(name, tensor, persistent) def unregister_buffer(self, name: str): - with self._lock: + with self._state_lock: if name in self.module._buffers: del self.module._buffers[name] if hasattr(self.module, name): delattr(self.module, name) - # else: - # log.debug(f"{self.full_name} has no attribute: {name}") - # else: - # log.debug(f"{self.full_name} has no buffer: {name}") def register_parameter(self, name: str, param: Optional[Parameter]) -> None: - with self._lock: + with self._state_lock: return self.module.register_parameter(name, param) def unregister_parameter(self, name: str) -> None: - with self._lock: + with self._state_lock: if name in self.module._parameters: del self.module._parameters[name] if hasattr(self.module, name): delattr(self.module, name) - # else: - # log.debug(f"{self.full_name} has no attribute: {name}") - # else: - # log.debug(f"{self.full_name} has no parameter: {name}") # return stats for mo # def stats(self) -> Dict[str, float]: # # -1 means no stats have yet to gathered for the stat property @@ -112,13 +104,103 @@ def unregister_parameter(self, name: str) -> None: # getattr is only called if python cannot find attr for `self` def __getattr__(self, name: str): - with self._lock: + with self._state_lock: return getattr(self.module, name) # setattr is always called by python even if attr exists in `self` def __setattr__(self, name: str, value: Any) -> None: - with self._lock: - if name in ["module", "module_dtype", "name", "full_name", "layer_index", "state", "target_device", "register_buffer", "unregister_buffer", "register_parameter", "unregister_parameter"]: - self.__dict__[name] = value - else: - self.module.__dict__[name] = value + if name in [ + "module", + "module_dtype", + "name", + "full_name", + "layer_index", + "state", + "target_device", + "register_buffer", + "unregister_buffer", + "register_parameter", + "unregister_parameter", + "_state_lock", + ]: + object.__setattr__(self, name, value) + return + + with self._state_lock: + setattr(self.module, name, 
value) + + def stream_state_payload_to_cpu( + self, + tensors: Dict[str, torch.Tensor], + *, + host_pool, + ) -> Dict[str, torch.Tensor]: + return self._stream_tensor_dict( + tensors, + host_pool=host_pool, + store_callback=lambda host_map: self.state.update(host_map), + ) + + def stream_parameters_to_cpu(self, *, host_pool) -> Dict[str, torch.Tensor]: + tensor_map = {name: param for name, param in self.module.named_parameters(recurse=False)} + return self._stream_tensor_dict( + tensor_map, + host_pool=host_pool, + store_callback=lambda host_map: self.state.setdefault("parameters_cpu", {}).update(host_map), + ) + + def stream_buffers_to_cpu(self, *, host_pool) -> Dict[str, torch.Tensor]: + tensor_map = {name: buf for name, buf in self.module.named_buffers(recurse=False)} + return self._stream_tensor_dict( + tensor_map, + host_pool=host_pool, + store_callback=lambda host_map: self.state.setdefault("buffers_cpu", {}).update(host_map), + ) + + def stream_all_to_cpu(self, *, host_pool) -> Dict[str, Dict[str, torch.Tensor]]: + params = self.stream_parameters_to_cpu(host_pool=host_pool) + buffers = self.stream_buffers_to_cpu(host_pool=host_pool) + return {"parameters": params, "buffers": buffers} + + def stream_sync(self) -> None: + with self._state_lock: + pending = self.state.pop("streaming_events", []) + for entry in pending: + entry["event"].synchronize() + + def _stream_tensor_dict( + self, + tensors: Dict[str, torch.Tensor], + *, + host_pool, + store_callback, + ) -> Dict[str, torch.Tensor]: + filtered = {name: tensor for name, tensor in tensors.items() if isinstance(tensor, torch.Tensor)} + if not filtered: + return {} + + first = next(iter(filtered.values())) + + if first.device.type != "cuda" or not torch.cuda.is_available(): + host_map = {name: tensor.detach().to("cpu") for name, tensor in filtered.items()} + with self._state_lock: + store_callback(host_map) + return host_map + + stream = torch.cuda.Stream(device=first.device) + done_event = torch.cuda.Event(enable_timing=False) + host_map: Dict[str, torch.Tensor] = {} + + with torch.cuda.stream(stream): + for name, tensor in filtered.items(): + src = tensor.detach() + host = host_pool.acquire(src.shape, src.dtype, src.layout) + host.copy_(src, non_blocking=True) + host_map[name] = host + done_event.record(stream) + + with self._state_lock: + events = self.state.setdefault("streaming_events", []) + events.append({"event": done_event, "stream": stream}) + store_callback(host_map) + return host_map diff --git a/tests/test_bench_cuda_even_d2h.py b/tests/test_bench_cuda_even_d2h.py new file mode 100644 index 000000000..56540597e --- /dev/null +++ b/tests/test_bench_cuda_even_d2h.py @@ -0,0 +1,534 @@ +# SPDX-FileCopyrightText: 2025 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +import math +import statistics +import time +from collections.abc import Callable +from typing import Any + +import pytest +import torch + + +pytestmark = pytest.mark.skipif( + not torch.cuda.is_available(), reason="CUDA device is required for D2H benchmarks" +) + + +class PinnedBufferPool: + """ + Simple pinned host memory pool keyed by (shape, dtype, layout). + + This lets us model the impact of reusing contiguous host buffers instead of + re-allocating every time we enqueue a device-to-host transfer. 
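+    A hit in `acquire` pops a previously released buffer; a miss allocates a
+    fresh pinned tensor. Hit/miss counters feed the reuse assertions below.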
+ """ + + def __init__(self) -> None: + self._store: dict[tuple[Any, ...], list[torch.Tensor]] = {} + self.hits = 0 + self.misses = 0 + + def acquire( + self, shape: torch.Size, dtype: torch.dtype, layout: torch.layout + ) -> torch.Tensor: + key = (tuple(shape), dtype, layout) + bucket = self._store.get(key) + if bucket: + self.hits += 1 + return bucket.pop() + self.misses += 1 + return torch.empty( + shape, + dtype=dtype, + layout=layout, + device="cpu", + pin_memory=True, + ) + + def release(self, tensor: torch.Tensor) -> None: + key = (tuple(tensor.shape), tensor.dtype, tensor.layout) + self._store.setdefault(key, []).append(tensor) + + +def _aggregate_records(records: list[dict[str, float]]) -> dict[str, float]: + summary: dict[str, float] = {"samples": float(len(records))} + if not records: + return summary + + keys = { + "enqueue_ms", + "wait_ms", + "total_ms", + "copy_ms", + "alloc_ms", + "acquire_ms", + } + for key in keys: + values = [entry[key] for entry in records if key in entry] + if values: + summary[key] = statistics.mean(values) + summary[f"{key}_p50"] = statistics.median(values) + summary[f"{key}_min"] = min(values) + summary[f"{key}_max"] = max(values) + if len(values) >= 2: + summary[f"{key}_p95"] = statistics.quantiles( + values, n=100, method="inclusive" + )[94] + return summary + + +def _run_variant( + tensors: list[torch.Tensor], + *, + warmup: int, + runner: Callable[[torch.Tensor], dict[str, float]], +) -> dict[str, Any]: + records: list[dict[str, float]] = [] + for idx, tensor in enumerate(tensors): + record = runner(tensor) + if idx >= warmup: + records.append(record) + summary = _aggregate_records(records) + summary["raw"] = records + return summary + + +def _run_sync_to_cpu(src: torch.Tensor, *, device: torch.device) -> dict[str, float]: + """ + Baseline: blocking `.cpu()` call. Producer thread stalls until data lands on host. + """ + torch.cuda.synchronize(device) + t0 = time.perf_counter() + host = src.detach().cpu() + torch.cuda.synchronize(device) + total_ms = (time.perf_counter() - t0) * 1e3 + # Access the tensor once so the copy is not elided by Python's lifetime analysis. + _ = float(host.view(-1)[0].item()) + return { + "enqueue_ms": total_ms, + "wait_ms": 0.0, + "total_ms": total_ms, + "copy_ms": total_ms, + "alloc_ms": 0.0, + } + + +def _run_async_with_fresh_pinned( + src: torch.Tensor, + *, + device: torch.device, +) -> dict[str, float]: + """ + Async copy each time with a new pinned buffer. Allocations dominate for small copies. + """ + torch.cuda.synchronize(device) + alloc_start = time.perf_counter() + host = torch.empty_like(src, device="cpu", pin_memory=True) + alloc_ms = (time.perf_counter() - alloc_start) * 1e3 + + start_evt = torch.cuda.Event(enable_timing=True) + end_evt = torch.cuda.Event(enable_timing=True) + + start_evt.record() + host.copy_(src, non_blocking=True) + end_evt.record() + enqueue_end = time.perf_counter() + enqueue_ms = (enqueue_end - alloc_start) * 1e3 + + wait_start = time.perf_counter() + end_evt.synchronize() + wait_ms = (time.perf_counter() - wait_start) * 1e3 + total_ms = (time.perf_counter() - alloc_start) * 1e3 + copy_ms = start_evt.elapsed_time(end_evt) + _ = float(host.view(-1)[0].item()) + + return { + "enqueue_ms": enqueue_ms, + "wait_ms": wait_ms, + "total_ms": total_ms, + "copy_ms": copy_ms, + "alloc_ms": alloc_ms, + } + + +def _run_async_with_pool( + src: torch.Tensor, + *, + device: torch.device, + pool: PinnedBufferPool, +) -> dict[str, float]: + """ + Async copy backed by reusable pinned buffers. 
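+    The host buffer comes from the pool, the copy is issued non_blocking and
+    bracketed by CUDA events, and the buffer is released back after the wait.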
+ """ + torch.cuda.synchronize(device) + acquire_start = time.perf_counter() + host = pool.acquire(src.shape, src.dtype, src.layout) + acquire_ms = (time.perf_counter() - acquire_start) * 1e3 + + start_evt = torch.cuda.Event(enable_timing=True) + end_evt = torch.cuda.Event(enable_timing=True) + + start_evt.record() + host.copy_(src, non_blocking=True) + end_evt.record() + enqueue_end = time.perf_counter() + enqueue_ms = (enqueue_end - acquire_start) * 1e3 + + wait_start = time.perf_counter() + end_evt.synchronize() + wait_ms = (time.perf_counter() - wait_start) * 1e3 + total_ms = (time.perf_counter() - acquire_start) * 1e3 + copy_ms = start_evt.elapsed_time(end_evt) + _ = float(host.view(-1)[0].item()) + + pool.release(host) + return { + "enqueue_ms": enqueue_ms, + "wait_ms": wait_ms, + "total_ms": total_ms, + "copy_ms": copy_ms, + "acquire_ms": acquire_ms, + } + + +def _run_async_with_pool_stream( + src: torch.Tensor, + *, + device: torch.device, + pool: PinnedBufferPool, + stream: torch.cuda.Stream, +) -> dict[str, float]: + """ + Async copy on a dedicated stream with CUDA events, modelling an event-driven consumer. + """ + torch.cuda.synchronize(device) + acquire_start = time.perf_counter() + host = pool.acquire(src.shape, src.dtype, src.layout) + acquire_ms = (time.perf_counter() - acquire_start) * 1e3 + + start_evt = torch.cuda.Event(enable_timing=True) + done_evt = torch.cuda.Event(enable_timing=True) + + with torch.cuda.stream(stream): + start_evt.record(stream) + host.copy_(src, non_blocking=True) + done_evt.record(stream) + + enqueue_end = time.perf_counter() + enqueue_ms = (enqueue_end - acquire_start) * 1e3 + + wait_start = time.perf_counter() + done_evt.synchronize() + wait_ms = (time.perf_counter() - wait_start) * 1e3 + total_ms = (time.perf_counter() - acquire_start) * 1e3 + copy_ms = start_evt.elapsed_time(done_evt) + _ = float(host.view(-1)[0].item()) + + pool.release(host) + return { + "enqueue_ms": enqueue_ms, + "wait_ms": wait_ms, + "total_ms": total_ms, + "copy_ms": copy_ms, + "acquire_ms": acquire_ms, + } + + +def _bench_variants( + tensors: list[torch.Tensor], + *, + device: torch.device, + warmup: int, +) -> dict[str, dict[str, Any]]: + results: dict[str, dict[str, Any]] = {} + + results["sync"] = _run_variant( + tensors, + warmup=warmup, + runner=lambda src: _run_sync_to_cpu(src, device=device), + ) + + results["async_fresh"] = _run_variant( + tensors, + warmup=warmup, + runner=lambda src: _run_async_with_fresh_pinned(src, device=device), + ) + + pool_reuse = PinnedBufferPool() + hits_before = pool_reuse.hits + misses_before = pool_reuse.misses + results["async_pool"] = _run_variant( + tensors, + warmup=warmup, + runner=lambda src: _run_async_with_pool(src, device=device, pool=pool_reuse), + ) + results["async_pool"]["pool_hits"] = float(pool_reuse.hits - hits_before) + results["async_pool"]["pool_misses"] = float(pool_reuse.misses - misses_before) + + pool_stream = PinnedBufferPool() + stream = torch.cuda.Stream(device=device) + hits_before_stream = pool_stream.hits + misses_before_stream = pool_stream.misses + results["async_pool_stream"] = _run_variant( + tensors, + warmup=warmup, + runner=lambda src: _run_async_with_pool_stream( + src, device=device, pool=pool_stream, stream=stream + ), + ) + results["async_pool_stream"]["pool_hits"] = float( + pool_stream.hits - hits_before_stream + ) + results["async_pool_stream"]["pool_misses"] = float( + pool_stream.misses - misses_before_stream + ) + + return results + + +def _log_summary(header: str, metrics: dict[str, 
dict[str, Any]]) -> None: + print(f"\n[CUDA->CPU D2H] {header}") + for name, stats in metrics.items(): + enqueue = stats.get("enqueue_ms", float("nan")) + enqueue_p95 = stats.get("enqueue_ms_p95", float("nan")) + wait = stats.get("wait_ms", float("nan")) + total = stats.get("total_ms", float("nan")) + total_p95 = stats.get("total_ms_p95", float("nan")) + total_min = stats.get("total_ms_min", float("nan")) + total_max = stats.get("total_ms_max", float("nan")) + copy_ms = stats.get("copy_ms", float("nan")) + extra = "" + if "alloc_ms" in stats: + extra += f" | alloc {stats['alloc_ms']:.3f} ms" + if "acquire_ms" in stats: + extra += f" | acquire {stats['acquire_ms']:.3f} ms" + if "pool_hits" in stats: + extra += ( + f" | pool hits/misses {int(stats['pool_hits'])}/" + f"{int(stats['pool_misses'])}" + ) + print( + f" {name:>18}: enqueue {enqueue:7.3f} ms (p95 {enqueue_p95:7.3f}) | " + f"wait {wait:7.3f} ms | total {total:7.3f} ms (p95 {total_p95:7.3f}, " + f"range {total_min:7.3f}-{total_max:7.3f}) | copy {copy_ms:7.3f} ms{extra}" + ) + + +def _make_constant_tensors( + *, + numel: int, + total: int, + device: torch.device, + dtype: torch.dtype, +) -> list[torch.Tensor]: + base = torch.empty(numel, dtype=dtype, device=device) + base.uniform_() + return [base] * total + + +def _make_shape_cycle( + *, + numel: int, + total: int, + device: torch.device, + dtype: torch.dtype, +) -> list[torch.Tensor]: + divisors = [1, 2, 4, 8, 16] + tensors: list[torch.Tensor] = [] + idx = 0 + while len(tensors) < total: + factor = divisors[idx % len(divisors)] + if factor == 1 or numel % factor == 0: + main = max(1, numel // factor) + shape = (main, factor) if factor > 1 else (numel,) + else: + shape = (numel,) + tensor = torch.empty(shape, dtype=dtype, device=device) + tensor.uniform_() + tensors.append(tensor) + idx += 1 + return tensors + + +def _bytes_to_mib(size_bytes: int) -> float: + return size_bytes / (1024**2) + + +def test_even_d2h_latency_profile(): + device = torch.device("cuda", 0) + torch.cuda.set_device(device) + + dtype = torch.float16 + element_size = torch.empty((), dtype=dtype).element_size() + + warmup = 2 + steps = 8 + total_iters = warmup + steps + + sizes_mib = [0.125, 0.5, 1, 4, 8, 16, 32] + summaries: dict[float, dict[str, dict[str, Any]]] = {} + + for size_mib in sizes_mib: + size_bytes = int(size_mib * 1024**2) + numel = max(1, size_bytes // element_size) + tensors = _make_constant_tensors( + numel=numel, total=total_iters, device=device, dtype=dtype + ) + metrics = _bench_variants(tensors, device=device, warmup=warmup) + summaries[size_mib] = metrics + _log_summary(f"{size_mib:.3f} MiB", metrics) + + # Expect pooling to drastically shrink producer stall time once buffers are warm. + for size_mib, metrics in summaries.items(): + async_pool = metrics["async_pool"] + async_fresh = metrics["async_fresh"] + assert async_pool["pool_hits"] >= steps, "Pool should be re-used after warmup" + pooled = async_pool["enqueue_ms"] + fresh = async_fresh["enqueue_ms"] + assert ( + pooled <= fresh * 0.95 + or math.isclose(pooled, fresh, rel_tol=0.05, abs_tol=0.01) + ), ( + f"Pooled enqueue should be meaningfully faster or comparable for {size_mib:.3f} MiB " + f"(pooled {pooled:.3f} ms vs fresh {fresh:.3f} ms)" + ) + + # Async streaming should reduce producer blocking for at least the larger tensors. 
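+    # A size qualifies when its event-driven enqueue costs less than half of
+    # the blocking sync path's enqueue time.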
+ streaming_benefit_sizes = [ + size + for size, metrics in summaries.items() + if metrics["async_pool_stream"]["enqueue_ms"] + < metrics["sync"]["enqueue_ms"] * 0.5 + ] + assert streaming_benefit_sizes, "Expected event-driven streaming to beat sync copies" + + # Even with async dispatch, total copy time remains comparable to the sync baseline. + for size_mib, metrics in summaries.items(): + sync_total = metrics["sync"]["total_ms"] + async_total = metrics["async_pool_stream"]["total_ms"] + upper_factor = 1.6 if sync_total < 0.2 else 1.25 + assert sync_total * 0.05 <= async_total <= sync_total * upper_factor, ( + f"Data readiness latency should stay within reasonable bounds (size={size_mib:.3f} MiB, " + f"sync={sync_total:.3f} ms, async={async_total:.3f} ms)" + ) + + +def test_even_d2h_pool_shape_sensitivity(): + device = torch.device("cuda", 0) + torch.cuda.set_device(device) + + dtype = torch.float16 + element_size = torch.empty((), dtype=dtype).element_size() + + warmup = 2 + steps = 8 + total_iters = warmup + steps + + size_mib = 8 + size_bytes = int(size_mib * 1024**2) + numel = max(1, size_bytes // element_size) + + tensors = _make_shape_cycle( + numel=numel, total=total_iters, device=device, dtype=dtype + ) + + async_fresh = _run_variant( + tensors, + warmup=warmup, + runner=lambda src: _run_async_with_fresh_pinned(src, device=device), + ) + + pool = PinnedBufferPool() + hits_before, misses_before = pool.hits, pool.misses + async_pool = _run_variant( + tensors, + warmup=warmup, + runner=lambda src: _run_async_with_pool(src, device=device, pool=pool), + ) + async_pool["pool_hits"] = float(pool.hits - hits_before) + async_pool["pool_misses"] = float(pool.misses - misses_before) + + _log_summary( + "8.000 MiB (shape cycle)", + {"async_fresh": async_fresh, "async_pool": async_pool}, + ) + + assert async_pool["pool_hits"] <= async_pool["pool_misses"], ( + "Shape cycling should not produce more hits than misses; otherwise reuse dominates" + ) + + # Pooling should fall back to fresh allocations without exploding unboundedly. 
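+    # The 1x-120x window is deliberately loose; it only guards against the
+    # pooled path blowing up rather than asserting a tight bound.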
+ ratio = async_pool["enqueue_ms"] / async_fresh["enqueue_ms"] + assert 1.0 <= ratio <= 120.0, ( + "Pooling under shape churn should be slower but stay within a reasonable bound " + f"(ratio={ratio:.2f})" + ) + assert async_pool.get("acquire_ms", 0.0) >= async_fresh.get("alloc_ms", 0.0) * 2, ( + "Pooling slowdown should be attributable to expensive host buffer acquisition" + ) + + +def test_even_d2h_request_wall(): + device = torch.device("cuda", 0) + torch.cuda.set_device(device) + + dtype = torch.float16 + element_size = torch.empty((), dtype=dtype).element_size() + + size_mib = 16 + size_bytes = int(size_mib * 1024**2) + numel = max(1, size_bytes // element_size) + + max_transfers = 8 + device_buffers = [ + torch.empty(numel, dtype=dtype, device=device).uniform_() + for _ in range(max_transfers) + ] + host_buffers = [ + torch.empty_like(device_buffers[0], device="cpu", pin_memory=True) + for _ in range(max_transfers) + ] + + def measure_serial(n: int) -> float: + stream = torch.cuda.Stream(device=device) + torch.cuda.synchronize(device) + t0 = time.perf_counter() + with torch.cuda.stream(stream): + for idx in range(n): + host_buffers[idx].copy_(device_buffers[idx], non_blocking=True) + torch.cuda.synchronize(device) + return time.perf_counter() - t0 + + def measure_parallel(n: int) -> float: + streams = [torch.cuda.Stream(device=device) for _ in range(n)] + torch.cuda.synchronize(device) + t0 = time.perf_counter() + for idx in range(n): + with torch.cuda.stream(streams[idx]): + host_buffers[idx].copy_(device_buffers[idx], non_blocking=True) + torch.cuda.synchronize(device) + return time.perf_counter() - t0 + + serial_times: dict[int, float] = {} + parallel_times: dict[int, float] = {} + for n in (1, 2, 4, 8): + serial_times[n] = measure_serial(n) + parallel_times[n] = measure_parallel(n) + total_gib = (n * size_bytes) / (1024**3) + print( + f"[D2H wall] {n} transfers of {size_mib:.1f} MiB -> " + f"serial {serial_times[n]:.4f}s ({total_gib/serial_times[n]:.2f} GiB/s), " + f"parallel {parallel_times[n]:.4f}s ({total_gib/parallel_times[n]:.2f} GiB/s)" + ) + + baseline = parallel_times[1] + stall_observed = any(parallel_times[n] >= baseline * n * 0.8 for n in (2, 4, 8)) + assert stall_observed, "Expected concurrent D2H copies to serialize onto one engine" + + # Serial vs parallel should stay within reasonable bounds for all batch sizes. 
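+    # Both the single-stream and multi-stream paths funnel through the same
+    # copy engine, so their wall-clock times should track each other closely.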
+ for n in (1, 2, 4, 8): + ratio = parallel_times[n] / serial_times[n] + assert 0.8 <= ratio <= 1.3, ( + f"Parallel vs serial time deviated unexpectedly for {n} transfers: ratio={ratio:.2f}" + ) diff --git a/tests/test_gptq_processor_streaming.py b/tests/test_gptq_processor_streaming.py new file mode 100644 index 000000000..36c3d1471 --- /dev/null +++ b/tests/test_gptq_processor_streaming.py @@ -0,0 +1,89 @@ +# SPDX-FileCopyrightText: 2025 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +import os +import subprocess +import sys +import textwrap + +import pytest +import torch + + +pytestmark = pytest.mark.skipif( + not torch.cuda.is_available(), reason="CUDA device required for streaming D2H test" +) + + +def test_gptq_processor_async_d2h_streaming_roundtrip(): + env = os.environ.copy() + env.setdefault("CUDA_VISIBLE_DEVICES", "7") + env.setdefault("PYTHON_GIL", os.environ.get("PYTHON_GIL", "1")) + + script = textwrap.dedent( + """ + import os + import sys + import threading + from types import SimpleNamespace + + import torch + + class _RandomWords: + def get_random_word(self): + return "stream-events" + + sys.modules.setdefault("random_word", SimpleNamespace(RandomWords=lambda: _RandomWords())) + + from gptqmodel.looper.gptq_processor import GPTQProcessor, _PinnedHostPool + from gptqmodel.looper.named_module import NamedModule + + device = torch.device("cuda", 0) + torch.cuda.set_device(device) + + processor = object.__new__(GPTQProcessor) + processor.lock = threading.Lock() + processor._host_pool = _PinnedHostPool() + + linear = torch.nn.Linear(8, 8, bias=False).to(device=device, dtype=torch.float16) + named_module = NamedModule(linear, name="proj", full_name="model.layers.0.proj", layer_index=0) + + payload = { + "q_scales": torch.randn(8, 8, device=device, dtype=torch.float16), + "q_zeros": torch.randn(8, 8, device=device, dtype=torch.float16), + "q_g_idx": torch.arange(64, device=device, dtype=torch.int32).reshape(8, 8), + } + + named_module.stream_state_payload_to_cpu(payload, host_pool=processor._host_pool) + + host_scales = named_module.state["q_scales"] + host_zeros = named_module.state["q_zeros"] + host_g_idx = named_module.state["q_g_idx"] + + assert host_scales.is_pinned() and host_zeros.is_pinned() and host_g_idx.is_pinned() + + named_module.stream_sync() + + torch.testing.assert_close(host_scales.cpu(), payload["q_scales"].cpu(), atol=0, rtol=0) + torch.testing.assert_close(host_zeros.cpu(), payload["q_zeros"].cpu(), atol=0, rtol=0) + torch.testing.assert_close(host_g_idx.cpu(), payload["q_g_idx"].cpu(), atol=0, rtol=0) + + processor._release_host_buffers( + named_module.state.pop("q_scales"), + named_module.state.pop("q_zeros"), + named_module.state.pop("q_g_idx"), + ) + """ + ) + + result = subprocess.run( + [sys.executable, "-c", script], + env=env, + capture_output=True, + text=True, + ) + + if result.returncode != 0: + pytest.skip( + f"Streaming event helper subprocess unavailable: rc={result.returncode}, stderr={result.stderr.strip()}" + ) diff --git a/tests/test_named_module.py b/tests/test_named_module.py new file mode 100644 index 000000000..0713f2891 --- /dev/null +++ b/tests/test_named_module.py @@ -0,0 +1,128 @@ +# SPDX-FileCopyrightText: 2025 ModelCloud.ai +# SPDX-License-Identifier: Apache-2.0 + +import os +import subprocess +import sys +import textwrap + +import pytest +import torch + +from gptqmodel.looper.named_module import NamedModule + + +def _make_linear(features: int = 8, device: torch.device | None = None) -> torch.nn.Linear: + layer = 
torch.nn.Linear(features, features, bias=False) + if device is not None: + layer = layer.to(device=device) + return layer + + +def test_named_module_register_and_state_locking(): + base = _make_linear() + named = NamedModule(base, name="proj", full_name="model.layers.0.proj", layer_index=0) + + # register/unregister buffer should route through wrapped module and keep state updates serialized + buf = torch.ones(1) + named.register_buffer("unit", buf) + assert "unit" in dict(named.named_buffers()) + named.unregister_buffer("unit") + assert "unit" not in dict(named.named_buffers()) + + # parameter registration proxies should also touch wrapped module + param = torch.nn.Parameter(torch.randn_like(base.weight)) + named.register_parameter("alt_weight", param) + assert dict(named.named_parameters()) + named.unregister_parameter("alt_weight") + assert "alt_weight" not in dict(named.named_parameters()) + + # setattr/getattr should delegate to wrapped module under lock + named.new_attr = torch.zeros(1) + assert torch.equal(named.new_attr, torch.zeros(1)) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required for streaming") +def test_named_module_streaming_apis(): + device = torch.device("cuda", 0) + torch.cuda.set_device(device) + + layer = _make_linear(device=device) + named = NamedModule(layer, name="proj", full_name="model.layers.0.proj", layer_index=0) + + payload = { + "tensor": torch.randn(8, 8, device=device, dtype=torch.float16), + } + + class _HostPool: + def acquire(self, shape, dtype, layout): + return torch.empty(shape, dtype=dtype, layout=layout, device="cpu", pin_memory=True) + + def release(self, tensor): + pass + + host_pool = _HostPool() + + named.stream_state_payload_to_cpu(payload, host_pool=host_pool) + assert "tensor" in named.state + assert named.state["tensor"].is_pinned() + + named.stream_sync() + torch.testing.assert_close(named.state["tensor"].cpu(), payload["tensor"].cpu()) + + params = named.stream_parameters_to_cpu(host_pool=host_pool) + assert params + named.stream_sync() + param_lookup = {name: tensor.detach().cpu() for name, tensor in named.module.named_parameters(recurse=False)} + for name, cpu_tensor in params.items(): + torch.testing.assert_close(cpu_tensor.cpu(), param_lookup[name]) + + buffers = named.stream_buffers_to_cpu(host_pool=host_pool) + named.stream_sync() + buffer_lookup = {name: tensor.detach().cpu() for name, tensor in named.module.named_buffers(recurse=False)} + for name, cpu_tensor in buffers.items(): + torch.testing.assert_close(cpu_tensor.cpu(), buffer_lookup[name]) + + combined = named.stream_all_to_cpu(host_pool=host_pool) + named.stream_sync() + assert set(combined.keys()) == {"parameters", "buffers"} + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required for subprocess stream test") +def test_named_module_streaming_subprocess_roundtrip(): + env = os.environ.copy() + env.setdefault("CUDA_VISIBLE_DEVICES", "7") + + script = textwrap.dedent( + """ + import torch + from gptqmodel.looper.named_module import NamedModule + + layer = torch.nn.Linear(4, 4, bias=False).to(device='cuda', dtype=torch.float16) + named = NamedModule(layer, name='proj', full_name='model.layers.0.proj', layer_index=0) + + payload = {'x': torch.randn(4, 4, device='cuda', dtype=torch.float16)} + + class _Pool: + def acquire(self, shape, dtype, layout): + return torch.empty(shape, dtype=dtype, layout=layout, device='cpu', pin_memory=True) + + def release(self, tensor): + pass + + pool = _Pool() + 
named.stream_state_payload_to_cpu(payload, host_pool=pool) + named.stream_sync() + torch.testing.assert_close(named.state['x'].cpu(), payload['x'].cpu(), atol=0, rtol=0) + """ + ) + + result = subprocess.run( + [sys.executable, "-c", script], + env=env, + capture_output=True, + text=True, + ) + + if result.returncode != 0: + pytest.skip(f"Subprocess streaming test unavailable: {result.stderr.strip()}")
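
A minimal usage sketch of the streaming helpers added above, assuming a CUDA device; the trivial host pool and the tensor name mirror the test fixtures and are illustrative only, not part of the library API.

    import torch
    from gptqmodel.looper.named_module import NamedModule

    class _HostPool:
        # Illustrative pool: always hands out a fresh pinned buffer, never reuses.
        def acquire(self, shape, dtype, layout):
            return torch.empty(shape, dtype=dtype, layout=layout, device="cpu", pin_memory=True)

        def release(self, tensor):
            pass

    layer = torch.nn.Linear(8, 8, bias=False).to(device="cuda", dtype=torch.float16)
    named = NamedModule(layer, name="proj", full_name="model.layers.0.proj", layer_index=0)

    payload = {"q_scales": torch.randn(8, 8, device="cuda", dtype=torch.float16)}

    # Enqueue the async D2H copy on a side stream; the pinned result lands in named.state.
    named.stream_state_payload_to_cpu(payload, host_pool=_HostPool())
    named.stream_sync()  # block until the recorded CUDA event has fired
    q_scales = named.state.pop("q_scales").clone()  # clone before the pinned buffer is reused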