27 changes: 25 additions & 2 deletions gptqmodel/looper/named_module.py
@@ -167,6 +167,16 @@ def stream_sync(self) -> None:
        pending = self.state.pop("streaming_events", [])
        for entry in pending:
            entry["event"].synchronize()
            keys = entry.get("keys")
            if keys:
                with self._state_lock:
                    event_map = self.state.get("streaming_event_map")
                    if event_map is not None:
                        for key in keys:
                            event_map.pop(key, None)
            sources = entry.get("sources")
            if sources is not None:
                sources.clear()

    def _stream_tensor_dict(
        self,
@@ -194,17 +204,30 @@ def _stream_tensor_dict(
        copy_stream = torch.cuda.Stream(device=copy_device)
        done_event = torch.cuda.Event(enable_timing=False, blocking=False)

        pending_sources = []
        with torch.cuda.stream(copy_stream):
            copy_stream.wait_stream(compute_stream)
            for name, tensor in filtered.items():
                src = tensor.detach()
                src.record_stream(copy_stream)
                pending_sources.append(src)
                host = host_pool.acquire(src.shape, src.dtype, src.layout)
                host.copy_(src, non_blocking=True)
                host_map[name] = host
            done_event.record(copy_stream)

        with self._state_lock:
            events = self.state.setdefault("streaming_events", [])
            events.append({"event": done_event, "stream": copy_stream})
            store_callback(host_map)
            event_map = self.state.setdefault("streaming_event_map", {})
            for key in host_map.keys():
                event_map[key] = done_event
            events = self.state.setdefault("streaming_events", [])
            events.append(
                {
                    "event": done_event,
                    "stream": copy_stream,
                    "sources": pending_sources,
                    "keys": tuple(host_map.keys()),
                }
            )
        return host_map
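The change above couples every streamed payload to the CUDA event that guards it: _stream_tensor_dict records one event per batch of keys and keeps the source tensors alive, and stream_sync waits on that event before dropping the per-key entries from streaming_event_map and releasing the sources. Below is a minimal standalone sketch of the same pattern; stream_to_host, sync_pending, and the bare state dict / lock arguments are illustrative stand-ins, not the repository's API.

import torch

def stream_to_host(tensors, state, lock):
    # Copy CUDA tensors to pinned host buffers on a side stream and record a
    # single event for the batch so a later sync can clear per-key bookkeeping.
    device = next(iter(tensors.values())).device
    copy_stream = torch.cuda.Stream(device=device)
    done_event = torch.cuda.Event(enable_timing=False)
    host_map, sources = {}, []
    with torch.cuda.stream(copy_stream):
        copy_stream.wait_stream(torch.cuda.current_stream(device))
        for name, tensor in tensors.items():
            src = tensor.detach()
            src.record_stream(copy_stream)  # tell the caching allocator the side stream uses src
            sources.append(src)             # hold a reference until the event completes
            host = torch.empty(src.shape, dtype=src.dtype, device="cpu", pin_memory=True)
            host.copy_(src, non_blocking=True)
            host_map[name] = host
        done_event.record(copy_stream)
    with lock:
        event_map = state.setdefault("streaming_event_map", {})
        for key in host_map:
            event_map[key] = done_event
        state.setdefault("streaming_events", []).append(
            {"event": done_event, "stream": copy_stream, "sources": sources, "keys": tuple(host_map)}
        )
    return host_map

def sync_pending(state, lock):
    # Mirrors the stream_sync change: wait on each event, then drop the per-key
    # event-map entries and the source references the event was protecting.
    for entry in state.pop("streaming_events", []):
        entry["event"].synchronize()
        with lock:
            event_map = state.get("streaming_event_map")
            if event_map is not None:
                for key in entry.get("keys", ()):
                    event_map.pop(key, None)
        sources = entry.get("sources")
        if sources is not None:
            sources.clear()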
291 changes: 291 additions & 0 deletions tests/test_named_module.py
@@ -1,10 +1,17 @@
# SPDX-FileCopyrightText: 2025 ModelCloud.ai
# SPDX-License-Identifier: Apache-2.0

import gc
import math
import os
import queue
import random
import subprocess
import sys
import textwrap
import threading
import time
from dataclasses import dataclass

import pytest
import torch
@@ -126,3 +133,287 @@ def release(self, tensor):

    if result.returncode != 0:
        pytest.skip(f"Subprocess streaming test unavailable: {result.stderr.strip()}")


@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required for multi-thread streaming stress test")
@pytest.mark.skipif(torch.cuda.device_count() < 4, reason="At least 4 CUDA devices required (0-3)")
def test_named_module_multithreaded_streaming_free_thread_stress():
    if not hasattr(sys, "_is_gil_enabled"):
        pytest.skip("Python runtime does not expose _is_gil_enabled; free-threading build required")
    if sys._is_gil_enabled():
        pytest.skip("GIL is enabled - run with PYTHON_GIL=0 to exercise free-threaded streaming stress")

    thread_count = 12
    duration_s = 30.0
    devices = [torch.device("cuda", idx) for idx in range(4)]
    bf16_bytes = torch.tensor([], dtype=torch.bfloat16).element_size()

    class _PinnedHostPool:
        def __init__(self):
            self._lock = threading.Lock()
            self._active = 0
            self._max_active = 0

        def acquire(self, shape, dtype, layout):
            with self._lock:
                self._active += 1
                if self._active > self._max_active:
                    self._max_active = self._active
            return torch.empty(shape, dtype=dtype, layout=layout, device="cpu", pin_memory=True)

        def release(self, tensor):
            del tensor
            with self._lock:
                self._active -= 1

        @property
        def max_active(self):
            with self._lock:
                return self._max_active

    @dataclass(frozen=True)
    class _ModuleContext:
        named: NamedModule
        device: torch.device
        host_pool: _PinnedHostPool

    @dataclass(frozen=True)
    class _ExpectedTensor:
        fingerprint: float
        checksum: float

    module_contexts: list[_ModuleContext] = []
    for idx, device in enumerate(devices):
        layer = _make_linear(2048).to(device=device, dtype=torch.bfloat16)
        named = NamedModule(layer, name=f"stress_proj_{idx}", full_name=f"stress.layers.{idx}.proj", layer_index=idx)
        module_contexts.append(_ModuleContext(named=named, device=device, host_pool=_PinnedHostPool()))

    pending_jobs: queue.Queue = queue.Queue()
    stop_event = threading.Event()
    error_queue: queue.Queue = queue.Queue()
    stats_lock = threading.Lock()
    stats = {
        "payloads_issued": 0,
        "pending_enqueues": 0,
        "verified_same_thread": 0,
        "verified_cross_thread": 0,
        "empty_cache_calls": 0,
        "gc_collect_calls": 0,
        "largest_tensor_mb": 0,
    }

    def _record_stat(name: str, delta: int = 1) -> None:
        with stats_lock:
            stats[name] = stats.get(name, 0) + delta

    def _update_largest_tensor(val_mb: int) -> None:
        with stats_lock:
            stats["largest_tensor_mb"] = max(stats["largest_tensor_mb"], val_mb)

    def _fingerprint_last_value(tensor: torch.Tensor) -> float:
        flat = tensor.reshape(-1)
        last = flat[-1]
        if last.device.type != "cpu":
            last = last.to(dtype=torch.float32, device="cpu")
        else:
            last = last.to(dtype=torch.float32)
        return float(last.item())

    def _verify_expected(ctx: _ModuleContext, expected_items: tuple[tuple[str, _ExpectedTensor], ...]) -> bool:
        named = ctx.named
        for key, expected in expected_items:
            with named._state_lock:
                host_tensor = named.state.get(key)
                event_map = named.state.get("streaming_event_map", {})
                pending_event = event_map.get(key)
            if host_tensor is None:
                ctx.named.stream_sync()
                with named._state_lock:
                    host_tensor = named.state.get(key)
            if host_tensor is None:
                with named._state_lock:
                    available = sorted(str(k) for k in named.state.keys())
                error_queue.put(f"Missing host tensor for key {key}; available={available}")
                stop_event.set()
                return False
            if pending_event is not None:
                pending_event.synchronize()
            actual_val = _fingerprint_last_value(host_tensor)
            actual_sum = float(host_tensor.to(dtype=torch.float32, device="cpu").sum().item())
            if (
                not math.isclose(actual_val, expected.fingerprint, rel_tol=0.0, abs_tol=1e-3)
                or not math.isclose(actual_sum, expected.checksum, rel_tol=0.0, abs_tol=1e-2)
            ):
                ctx.named.stream_sync()
                with named._state_lock:
                    retry_tensor = named.state.get(key)
                if retry_tensor is not None:
                    retry_val = _fingerprint_last_value(retry_tensor)
                    retry_sum = float(retry_tensor.to(dtype=torch.float32, device="cpu").sum().item())
                else:
                    retry_val = None
                    retry_sum = None
                ctx.host_pool.release(host_tensor)
                del host_tensor
                with named._state_lock:
                    named.state.pop(key, None)
                error_queue.put(
                    "Mismatch for "
                    f"{key}: expected(last={expected.fingerprint}, sum={expected.checksum}), "
                    f"got(last={actual_val}, sum={actual_sum}), "
                    f"retry(last={retry_val}, sum={retry_sum})"
                )
                stop_event.set()
                return False
            ctx.host_pool.release(host_tensor)
            del host_tensor
            with named._state_lock:
                named.state.pop(key, None)
                named.state.get("streaming_event_map", {}).pop(key, None)
        return True

    def _maybe_empty_cache(device: torch.device, rng: random.Random, probability: float = 0.25) -> None:
        if rng.random() < probability:
            with torch.cuda.device(device):
                torch.cuda.empty_cache()
            _record_stat("empty_cache_calls")

    def _try_consume(thread_id: int, rng: random.Random) -> bool:
        try:
            ctx, expected_items = pending_jobs.get_nowait()
        except queue.Empty:
            return False
        try:
            device = ctx.device
            torch.cuda.set_device(device)
            _maybe_empty_cache(device, rng, probability=0.3)
            if rng.random() < 0.3:
                gc.collect()
                _record_stat("gc_collect_calls")
            ctx.named.stream_sync()
            if _verify_expected(ctx, expected_items):
                _record_stat("verified_cross_thread")
            return True
        finally:
            pending_jobs.task_done()

    def _issue_payload(thread_id: int, rng: random.Random, seq_id: int) -> int:
        ctx = rng.choice(module_contexts)
        device = ctx.device
        torch.cuda.set_device(device)
        prefix = f"thread{thread_id}-seq{seq_id}"
        next_seq = seq_id + 1
        tensor_count = rng.randint(1, 3)
        tensor_sizes: list[int] = []
        payload: dict[str, torch.Tensor] = {}
        expected_pairs: list[tuple[str, _ExpectedTensor]] = []
        for idx in range(tensor_count):
            size_mb = rng.randint(3, 32)
            tensor_sizes.append(size_mb)
            numel = max(1, (size_mb * 1024 * 1024) // bf16_bytes)
            if numel >= 1024:
                cols = 256
                rows = max(1, numel // cols)
                shape = (rows, cols)
            else:
                shape = (numel,)
            tensor = torch.randn(shape, device=device, dtype=torch.bfloat16)
            key = f"{prefix}/tensor{idx}"
            payload[key] = tensor
            expected_pairs.append(
                (
                    key,
                    _ExpectedTensor(
                        fingerprint=_fingerprint_last_value(tensor),
                        checksum=float(tensor.to(dtype=torch.float32).sum().item()),
                    ),
                )
            )
        _update_largest_tensor(max(tensor_sizes))
        _record_stat("payloads_issued")
        _maybe_empty_cache(device, rng, probability=0.35)
        ctx.named.stream_state_payload_to_cpu(payload, host_pool=ctx.host_pool)
        if rng.random() < 0.35:
            gc.collect()
            _record_stat("gc_collect_calls")
        if rng.random() < 0.5:
            ctx.named.stream_sync()
            if _verify_expected(ctx, tuple(expected_pairs)):
                _record_stat("verified_same_thread")
        else:
            pending_jobs.put((ctx, tuple(expected_pairs)))
            _record_stat("pending_enqueues")
        time.sleep(rng.uniform(0.0, 0.003))
        return next_seq

    barrier = threading.Barrier(parties=thread_count + 1)
    deadline = time.monotonic() + duration_s

    def _worker(thread_id: int) -> None:
        rng = random.Random(0x9E3779B97F4A7C15 ^ thread_id)
        seq = 0
        try:
            barrier.wait()
            while time.monotonic() < deadline and not stop_event.is_set():
                processed = False
                if rng.random() < 0.6:
                    processed = _try_consume(thread_id, rng)
                if not processed:
                    seq = _issue_payload(thread_id, rng, seq)
        except Exception as exc:
            stop_event.set()
            error_queue.put(f"Thread {thread_id} error: {exc}")

    threads = [threading.Thread(target=_worker, args=(idx,), name=f"named-module-stress-{idx}") for idx in range(thread_count)]
    for thread in threads:
        thread.start()
    barrier.wait()

    for thread in threads:
        thread.join()

    while not pending_jobs.empty():
        ctx, expected_items = pending_jobs.get()
        torch.cuda.set_device(ctx.device)
        ctx.named.stream_sync()
        if not _verify_expected(ctx, expected_items):
            break
        pending_jobs.task_done()

    for device in devices:
        torch.cuda.synchronize(device=device)
    for ctx in module_contexts:
        with ctx.named._state_lock:
            keys_to_remove = [key for key in ctx.named.state.keys() if key.startswith("thread")]
            for key in keys_to_remove:
                ctx.named.state.pop(key, None)

    if not error_queue.empty():
        errors = []
        while not error_queue.empty():
            errors.append(error_queue.get())
        pytest.fail(" ; ".join(errors))

    with stats_lock:
        summary = {
            "payloads_issued": stats["payloads_issued"],
            "pending_enqueues": stats["pending_enqueues"],
            "verified_same_thread": stats["verified_same_thread"],
            "verified_cross_thread": stats["verified_cross_thread"],
            "empty_cache_calls": stats["empty_cache_calls"],
            "gc_collect_calls": stats["gc_collect_calls"],
            "largest_tensor_mb": stats["largest_tensor_mb"],
        }

    pool_usage = ", ".join(
        f"gpu{ctx.device.index}:max_pinned={ctx.host_pool.max_active}" for ctx in module_contexts
    )
    print(
        f"NamedModule multi-thread stress summary: "
        f"payloads={summary['payloads_issued']}, pending={summary['pending_enqueues']}, "
        f"verified_same_thread={summary['verified_same_thread']}, "
        f"verified_cross_thread={summary['verified_cross_thread']}, "
        f"empty_cache_calls={summary['empty_cache_calls']}, "
        f"gc_collect_calls={summary['gc_collect_calls']}, "
        f"largest_tensor_mb={summary['largest_tensor_mb']}; pool_usage={pool_usage}"
    )
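
Both tests lean on the same host-pool contract that _stream_tensor_dict calls into: acquire(shape, dtype, layout) returns a pinned CPU buffer and release(tensor) hands it back once the copy has been verified. A minimal sketch under that assumption follows (SimplePinnedHostPool is an illustrative name, not part of the repository); note that the stress test above additionally requires a free-threaded CPython build started with PYTHON_GIL=0 and at least four CUDA devices.

import threading
import torch

class SimplePinnedHostPool:
    # Hypothetical example of the pool protocol used above: acquire() hands out a
    # pinned CPU buffer matching shape/dtype/layout, release() returns it. This
    # sketch allocates a fresh buffer per call instead of recycling buffers.
    def __init__(self):
        self._lock = threading.Lock()
        self._outstanding = 0

    def acquire(self, shape, dtype, layout):
        with self._lock:
            self._outstanding += 1
        return torch.empty(shape, dtype=dtype, layout=layout, device="cpu", pin_memory=True)

    def release(self, tensor):
        del tensor  # a recycling pool would keep the buffer for reuse instead
        with self._lock:
            self._outstanding -= 1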