9 changes: 8 additions & 1 deletion gptqmodel/__init__.py
@@ -10,16 +10,23 @@

 DEBUG_ON = env_flag("DEBUG")

+from .utils.linalg_warmup import run_torch_linalg_warmup
 from .utils.threadx import DeviceThreadPool


 DEVICE_THREAD_POOL = DeviceThreadPool(
     inference_mode=True,
+    warmups={
+        "cuda": run_torch_linalg_warmup,
+        "xpu": run_torch_linalg_warmup,
+        "mps": run_torch_linalg_warmup,
+        "cpu": run_torch_linalg_warmup,
+    },
     workers={
         "cuda:per": 4,
         "xpu:per": 1,
         "mps": 8,
-        "cpu": 8,
+        "cpu": min(12, max(1, (os.cpu_count() or 1) // 2)),
         "model_loader:cpu": 2,
     },
     empty_cache_every_n=512,
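
For context, every warmup callable registered above shares the `run_torch_linalg_warmup(device)` signature introduced later in this PR. Below is a minimal sketch of driving those warmups outside `DeviceThreadPool`; the `warm_devices` helper is hypothetical and only approximates what the pool's per-device workers presumably do at startup, since the pool's dispatch code is not shown in this diff.

import threading

import torch

from gptqmodel.utils.linalg_warmup import run_torch_linalg_warmup


def warm_devices() -> None:
    # Hypothetical helper: run the warmup once per visible device, one thread each.
    devices = [torch.device("cpu")]
    if torch.cuda.is_available():
        devices += [torch.device("cuda", i) for i in range(torch.cuda.device_count())]

    threads = [
        threading.Thread(target=run_torch_linalg_warmup, args=(dev,), daemon=True)
        for dev in devices
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()


if __name__ == "__main__":
    warm_devices()
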
5 changes: 2 additions & 3 deletions gptqmodel/eora/eora.py
@@ -12,7 +12,6 @@

 from ..utils.logger import setup_logger
 from ..utils.rocm import IS_ROCM
-from ..utils.safe import TORCH_LINALG
 
 log = setup_logger()

@@ -89,7 +88,7 @@ def eora_compute_lora(
     original_backend = torch.backends.cuda.preferred_linalg_library()
     torch.backends.cuda.preferred_linalg_library(backend="magma")

-    L, Q = TORCH_LINALG.eigh(raw_scaling_diag_matrix)
+    L, Q = torch.linalg.eigh(raw_scaling_diag_matrix)

     if (L < 0).any():
         ## When expanding the calibration data size for EoRA, I suggest maintaining the balance by allocating 50% to general input (C4) and the remaining 50% to downstream task data.
@@ -107,7 +106,7 @@ def eora_compute_lora(

     delta_scale = torch.matmul(w_wq_delta, scaling_diag_matrix)

-    U, S, V = TORCH_LINALG.svd(delta_scale, full_matrices=False)
+    U, S, V = torch.linalg.svd(delta_scale, full_matrices=False)
     lowrank_r = rank
     truc_s = S[:lowrank_r]
     truc_u = U[:, :lowrank_r]
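
The `truc_*` slices above implement a standard truncated SVD. The following is a generic illustration of that step only, not the EoRA implementation itself; the `lowrank_factors` name and the shapes are made up.

import torch


def lowrank_factors(delta_scale: torch.Tensor, rank: int):
    # torch.linalg.svd returns V already transposed (V^H), matching the usage in the hunk above.
    U, S, V = torch.linalg.svd(delta_scale, full_matrices=False)
    truc_s = S[:rank]
    truc_u = U[:, :rank]
    truc_v = V[:rank, :]
    B = truc_u * truc_s  # fold the singular values into the left factor
    A = truc_v
    return A, B          # B @ A is the best rank-`rank` approximation of delta_scale


if __name__ == "__main__":
    M = torch.randn(64, 32, dtype=torch.float64)
    A, B = lowrank_factors(M, rank=8)
    print(torch.linalg.matrix_norm(M - B @ A))  # reconstruction error of the rank-8 approximation
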
5 changes: 2 additions & 3 deletions gptqmodel/quantization/gptq.py
@@ -23,7 +23,6 @@
 from ..quantization import QuantizeConfig
 from ..utils.device import get_device
 from ..utils.logger import setup_logger
-from ..utils.safe import TORCH_LINALG
 from .gar import compose_final_perm, compute_global_perm, compute_local_perms, invert_perm
 from .quantizer import HF_OPTIMUM, Quantizer

@@ -567,8 +566,8 @@ def hessian_inverse(self, H: torch.Tensor):
                 H2 = H.clone()
                 H2[diag, diag] += damp * mean
                 # TODO: call to torch.linalg is not thread-safe? Why not? That's very bad.
-                H2 = TORCH_LINALG.cholesky(H2)
-                Hinv = TORCH_LINALG.cholesky(torch.cholesky_inverse(H2), upper=True)
+                H2 = torch.linalg.cholesky(H2)
+                Hinv = torch.linalg.cholesky(torch.cholesky_inverse(H2), upper=True)
                 del H, H2
                 break
             except torch._C._LinAlgError as e:
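
The Cholesky sequence above (damp the diagonal, factor, invert, re-factor as an upper triangle) is the standard GPTQ recipe for a numerically stable inverse Hessian. A self-contained sketch of the same pattern with plain `torch.linalg` calls; the `damped_cholesky_inverse` helper and its defaults are illustrative, not repo code.

import torch


def damped_cholesky_inverse(H: torch.Tensor, percdamp: float = 0.01) -> torch.Tensor:
    columns = H.shape[0]
    diag = torch.arange(columns, device=H.device)
    mean = torch.mean(torch.diag(H))
    H2 = H.clone()
    H2[diag, diag] += percdamp * mean                 # damping keeps the matrix positive definite
    L = torch.linalg.cholesky(H2)                     # lower-triangular factor of the damped Hessian
    Hinv = torch.cholesky_inverse(L)                  # inverse recovered from the factor
    return torch.linalg.cholesky(Hinv, upper=True)    # upper factor of the inverse, as the solver consumes it


if __name__ == "__main__":
    X = torch.randn(128, 16, dtype=torch.float64)
    H = X.T @ X                                       # symmetric positive definite stand-in for a Hessian
    U = damped_cholesky_inverse(H)
    print(U.shape, bool(torch.equal(U, U.triu())))    # result is upper triangular
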
5 changes: 2 additions & 3 deletions gptqmodel/quantization/qqq.py
@@ -15,7 +15,6 @@
 from ..looper.named_module import NamedModule
 from ..quantization.quantizer import HF_OPTIMUM
 from ..utils import setup_logger
-from ..utils.safe import TORCH_LINALG
 from .gptq import get_number_of_rows_and_cols


@@ -355,9 +354,9 @@ def quantize(
         damp = percdamp * torch.mean(torch.diag(H))
         diag = torch.arange(self.columns, device=self.dev)
         H[diag, diag] += damp
-        H = TORCH_LINALG.cholesky(H)
+        H = torch.linalg.cholesky(H)
         H = torch.cholesky_inverse(H)
-        H = TORCH_LINALG.cholesky(H, upper=True)
+        H = torch.linalg.cholesky(H, upper=True)
         Hinv = H

         for i1 in range(0, self.columns, blocksize):
3 changes: 1 addition & 2 deletions gptqmodel/quantization/rotation/rotation.py
@@ -11,7 +11,6 @@

 from ...utils.logger import setup_logger
 from ...utils.model import get_module_by_name_prefix
-from ...utils.safe import TORCH_LINALG
 from ...utils.torch import torch_empty_cache
 from .hadamard_utils import apply_exact_had_to_linear, random_hadamard_matrix

@@ -91,7 +90,7 @@ def random_orthogonal_matrix(size, device):
"""
torch.cuda.empty_cache()
random_matrix = torch.randn(size, size, dtype=torch.float64).to(device)
q, r = TORCH_LINALG.qr(random_matrix)
q, r = torch.linalg.qr(random_matrix)
q *= torch.sign(torch.diag(r)).unsqueeze(0)
return q

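
A quick sanity check of the QR recipe above: flipping each column of Q by the sign of R's diagonal removes the sign ambiguity of the factorization without breaking orthogonality. The snippet below restates the function on CPU and verifies `Q.T @ Q ≈ I`; it is illustrative and not part of the PR.

import torch


def random_orthogonal_matrix(size: int, device: str = "cpu") -> torch.Tensor:
    random_matrix = torch.randn(size, size, dtype=torch.float64, device=device)
    q, r = torch.linalg.qr(random_matrix)
    # Sign-correcting the columns makes the resulting Q uniformly (Haar) distributed
    # instead of depending on the QR implementation's sign conventions.
    q *= torch.sign(torch.diag(r)).unsqueeze(0)
    return q


if __name__ == "__main__":
    q = random_orthogonal_matrix(8)
    eye = torch.eye(8, dtype=torch.float64)
    print("orthogonal:", bool(torch.allclose(q.T @ q, eye, atol=1e-10)))
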
195 changes: 195 additions & 0 deletions gptqmodel/utils/cuda_activation_buffer.py
@@ -0,0 +1,195 @@
# SPDX-FileCopyrightText: 2025 ModelCloud.ai
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

import dataclasses
import queue
import threading
import time
from typing import Any, Callable, List, Optional

import torch


__all__ = ["ActivationPacket", "CudaEventActivationBuffer"]


@dataclasses.dataclass(slots=True)
class ActivationPacket:
    """
    Tracks a single async device->host transfer triggered from a forward hook.

    The event is recorded on the dedicated copy stream so the consumer can
    decide when to block. The `host_tensor` already points at pinned memory.
    """

    event: torch.cuda.Event
    host_tensor: torch.Tensor
    meta: Optional[Any] = None
    created_at: float = dataclasses.field(default_factory=time.perf_counter)


class CudaEventActivationBuffer:
    """
    Schedules non-blocking GPU->CPU copies using a dedicated CUDA stream + event.

    Typical usage inside a forward hook::

        buffer = CudaEventActivationBuffer(device="cuda:6")

        def hook(module, inputs, output):
            tensor = output[0] if isinstance(output, (tuple, list)) else output
            buffer.capture_async(tensor, meta=module.__class__.__name__)

        # elsewhere in consumer thread
        for packet in buffer.drain():
            packet.event.synchronize()
            process(packet.host_tensor, packet.meta)

    The hook thread returns immediately after enqueuing the async copy which
    allows the caller to release activation VRAM without waiting on D2H traffic.
    """

    def __init__(
        self,
        device: torch.device | str | int,
        stream: Optional[torch.cuda.Stream] = None,
        pin_memory: bool = True,
        host_allocator: Optional[Callable[[torch.Size, torch.dtype, torch.layout], torch.Tensor]] = None,
        host_reclaimer: Optional[Callable[[torch.Tensor], None]] = None,
    ) -> None:
        if not torch.cuda.is_available():
            raise RuntimeError("CUDA is not available for CudaEventActivationBuffer.")

        dev = torch.device(device)
        if dev.type != "cuda":
            raise ValueError(f"CudaEventActivationBuffer requires a CUDA device, got {dev}.")

        if dev.index is None:
            dev = torch.device("cuda", torch.cuda.current_device())

        self._device = dev
        self._pin_memory = pin_memory
        self._host_allocator = host_allocator
        self._host_reclaimer = host_reclaimer

        with torch.cuda.device(self._device):
            self._copy_stream = stream or torch.cuda.Stream()

        self._pending: "queue.SimpleQueue[ActivationPacket]" = queue.SimpleQueue()
        self._lock = threading.Lock()
        self._approx_pending = 0

    def capture_async(
        self,
        activation: torch.Tensor,
        *,
        meta: Any = None,
        enqueue: bool = True,
    ) -> ActivationPacket:
        """
        Enqueue an async D2H copy of ``activation`` onto the buffer stream.

        Returns an ActivationPacket which is also available later via drain().
        """
        if activation.device != self._device:
            raise ValueError(
                f"Activation tensor is on {activation.device}, expected {self._device}."
            )

        activation = activation.detach()
        if not activation.is_contiguous():
            activation = activation.contiguous()

        host = self._allocate_host(activation)

        event = torch.cuda.Event(blocking=False, interprocess=False)

        current = torch.cuda.current_stream(self._device)
        copy_stream = self._copy_stream
        copy_stream.wait_stream(current)

        with torch.cuda.stream(copy_stream):
            host.copy_(activation, non_blocking=True)
            event.record(copy_stream)

        packet = ActivationPacket(event=event, host_tensor=host, meta=meta)
        if enqueue:
            self._pending_put(packet)
        return packet

    def drain(self, *, wait: bool = True, max_items: Optional[int] = None) -> List[ActivationPacket]:
        """
        Collect all queued packets (or up to max_items) in FIFO order.

        When ``wait`` is True we synchronize each packet's event before returning.
        """
        packets: List[ActivationPacket] = []
        pulled = 0

        while True:
            if max_items is not None and pulled >= max_items:
                break

            try:
                packet = self._pending_get()
            except queue.Empty:
                break

            pulled += 1
            if wait:
                packet.event.synchronize()
            packets.append(packet)

        return packets

    def recycle(self, packet: ActivationPacket) -> None:
        """
        Return a packet's host buffer to the allocator pool (if provided).
        """
        if self._host_reclaimer is not None:
            self._host_reclaimer(packet.host_tensor)

    def pending_count(self) -> int:
        """
        Non-blocking length check. The SimpleQueue does not expose qsize()
        reliably on all platforms, so we track with a lock-protected counter.
        """
        with self._lock:
            count = getattr(self, "_approx_pending", 0)
        return count

    def __len__(self) -> int:
        return self.pending_count()

    def __enter__(self) -> "CudaEventActivationBuffer":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        self.drain(wait=True)

    def _pending_put(self, packet: ActivationPacket) -> None:
        with self._lock:
            self._approx_pending = getattr(self, "_approx_pending", 0) + 1
        self._pending.put(packet)

    def _pending_get(self) -> ActivationPacket:
        packet = self._pending.get_nowait()
        with self._lock:
            self._approx_pending = max(getattr(self, "_approx_pending", 0) - 1, 0)
        return packet

    def _allocate_host(self, activation: torch.Tensor) -> torch.Tensor:
        if self._host_allocator is not None:
            host = self._host_allocator(activation.shape, activation.dtype, activation.layout)
            if not host.is_pinned():
                raise ValueError("Custom host allocator must return pinned CPU tensors.")
            return host
        return torch.empty(
            activation.shape,
            dtype=activation.dtype,
            layout=activation.layout,
            device="cpu",
            pin_memory=self._pin_memory,
        )
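
A hedged end-to-end sketch of how the buffer is meant to be wired up, following the class docstring above; the `torch.nn.Linear` model, batch shapes, and `cuda:0` device below are illustrative assumptions, not part of the PR.

import torch

from gptqmodel.utils.cuda_activation_buffer import CudaEventActivationBuffer

device = torch.device("cuda", 0)
model = torch.nn.Linear(512, 512).to(device)
buffer = CudaEventActivationBuffer(device=device)


def hook(module, inputs, output):
    tensor = output[0] if isinstance(output, (tuple, list)) else output
    buffer.capture_async(tensor, meta=module.__class__.__name__)


handle = model.register_forward_hook(hook)
with torch.no_grad():
    for _ in range(4):
        model(torch.randn(8, 512, device=device))
handle.remove()

for packet in buffer.drain(wait=True):      # each packet's event is synchronized before use
    assert packet.host_tensor.device.type == "cpu"
    print(packet.meta, tuple(packet.host_tensor.shape))
    buffer.recycle(packet)                  # no-op here since no host_reclaimer was supplied
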
69 changes: 69 additions & 0 deletions gptqmodel/utils/linalg_warmup.py
@@ -0,0 +1,69 @@
# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai
# SPDX-License-Identifier: Apache-2.0
# Contact: qubitium@modelcloud.ai, x.com/qubitium

from __future__ import annotations

import contextlib
import threading

import torch


_GLOBAL_WARMUP_LOCK = threading.Lock()


def _make_spd(size: int, device: torch.device, dtype: torch.dtype) -> torch.Tensor:
    """Generate a small symmetric positive definite matrix."""
    base = torch.randn((size, size), device=device, dtype=dtype)
    identity = torch.eye(size, device=device, dtype=dtype)
    return base @ base.transpose(-1, -2) + identity * 1e-3


def _run_cholesky_and_eigh(device: torch.device, dtype: torch.dtype) -> None:
    spd = _make_spd(4, device, dtype)
    torch.linalg.cholesky(spd)
    torch.linalg.eigh(spd)


def _run_svd(device: torch.device, dtype: torch.dtype) -> None:
    mat = torch.randn((4, 3), device=device, dtype=dtype)
    torch.linalg.svd(mat, full_matrices=False)


def _run_qr(device: torch.device, dtype: torch.dtype) -> None:
    square = torch.randn((4, 4), device=device, dtype=dtype)
    torch.linalg.qr(square)


def run_torch_linalg_warmup(device: torch.device) -> None:
    """
    Execute the torch.linalg operators used across the project once on the worker thread.

    Serialized under a global lock to avoid races inside PyTorch's lazy wrappers. The warmup
    still runs once per physical device so backend-specific handles are initialized where needed.
    """
    with _GLOBAL_WARMUP_LOCK:
        dtypes = (torch.float32, torch.float64)
        for dtype in dtypes:
            _run_cholesky_and_eigh(device, dtype)
            _run_svd(device, dtype)
            _run_qr(device, dtype)

        if device.type == "cuda" and hasattr(torch.backends, "cuda"):
            preferred = getattr(torch.backends.cuda, "preferred_linalg_library", None)
            if callable(preferred):
                current = preferred()
                # Core warmup already ran using the currently preferred backend above.
                # Some installations fall back to MAGMA when the primary solver is unavailable,
                # so we pre-initialize MAGMA as well when it differs from the preferred backend.
                if current and current != "magma":
                    with contextlib.suppress(Exception):
                        torch.backends.cuda.preferred_linalg_library(backend="magma")
                        _run_cholesky_and_eigh(device, torch.float32)
                    if current:
                        with contextlib.suppress(Exception):
                            torch.backends.cuda.preferred_linalg_library(backend=current)


__all__ = ["run_torch_linalg_warmup"]
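
To make the motivation concrete, here is a rough, standalone illustration (not from the PR) of the one-time cost the warmup absorbs: the first `torch.linalg` call in a process lazily initializes the backend (cuSOLVER/MAGMA on CUDA, LAPACK on CPU), so it is noticeably slower than steady-state calls.

import time

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
spd = torch.eye(64, device=device) * 2.0          # trivially positive definite


def timed_cholesky() -> float:
    if device.type == "cuda":
        torch.cuda.synchronize(device)
    start = time.perf_counter()
    torch.linalg.cholesky(spd)
    if device.type == "cuda":
        torch.cuda.synchronize(device)
    return time.perf_counter() - start


cold = timed_cholesky()   # pays the lazy backend/handle initialization the warmup front-loads
warm = timed_cholesky()   # steady-state cost once the handles exist
print(f"cold: {cold * 1e3:.2f} ms, warm: {warm * 1e3:.2f} ms")
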