46 changes: 35 additions & 11 deletions gptqmodel/looper/gptq_processor.py
@@ -24,11 +24,24 @@
from ..utils.logger import setup_logger, log_time_block
from ..utils.device import get_device
from ..utils.model import create_quant_module, find_modules, move_to, pack_model, pack_module
from ..utils.torch import HAS_CUDA, tf32_disable_guard, torch_streamCtx, torch_sync
from ..utils.torch import tf32_disable_guard

log = setup_logger()
lock = threading.Lock()


class _PinnedHostPool:
def __init__(self) -> None:
self._lock = threading.Lock()

def acquire(self, shape: torch.Size, dtype: torch.dtype, layout: torch.layout) -> torch.Tensor:
return torch.empty(shape, dtype=dtype, layout=layout, device="cpu", pin_memory=True)

def release(self, tensor: torch.Tensor) -> None:
# No pooling to avoid cross-thread pinned storage reuse issues.
return None


class GPTQProcessor(LoopProcessor):
def __init__(self, tokenizer, qcfg: QuantizeConfig, calibration, prepare_dataset_func,
calibration_concat_size: Optional[int], calibration_sort: Optional[str], batch_size: int,
@@ -42,6 +55,7 @@ def __init__(self, tokenizer, qcfg: QuantizeConfig, calibration, prepare_dataset

self.calculate_w_wq_diff = calculate_w_wq_diff
self.avg_losses = []
self._host_pool = _PinnedHostPool()

def set_calibration_dataset(self, calibration_dataset):
raise NotImplementedError("GPTQProcessor's calibration_dataset cannot be modified")
@@ -162,15 +176,17 @@ def process(self, module: NamedModule):
with tf32_disable_guard():
wq, q_scales, q_zeros, q_g_idx, duration, avg_loss, damp_percent, nsamples = g.quantize()

q_scales = q_scales.to(CPU)
q_zeros = q_zeros.to(CPU)
q_g_idx = q_g_idx.to(CPU)
module.stream_state_payload_to_cpu(
{
"q_scales": q_scales,
"q_zeros": q_zeros,
"q_g_idx": q_g_idx,
},
host_pool=self._host_pool,
)
del q_scales, q_zeros, q_g_idx

with self.lock:
module.state.update({"q_scales": q_scales})
module.state.update({"q_zeros": q_zeros})
module.state.update({"q_g_idx": q_g_idx})

self.durations.append(duration)
self.avg_losses.append(avg_loss)
self.module_names.append(f"layer-{module.layer_index}-{module.name}")
@@ -248,6 +264,7 @@ def submodule_finalize(self, module: NamedModule, model: BaseQModel, **kwargs):
# module.weight.data = move_to(module.state.pop("wq"), device=CPU) # large weights is slow to init on cpu

# cleanup all memory or states vars persistently added by this processor
module.stream_sync()
with (self.lock):
# if calculate_w_wq_diff is enabled (eora), we need to revert our original wq
if self.calculate_w_wq_diff:
@@ -256,9 +273,10 @@ def submodule_finalize(self, module: NamedModule, model: BaseQModel, **kwargs):
module.state.pop("w", None) #
module.state.pop("w_wq_diff", None)

q_zeros = module.state.pop("q_zeros")
q_scales = module.state.pop("q_scales")
q_g_idx = module.state.pop("q_g_idx")
# need to clone due to streamed pinned memory being accessed on a different thread
q_zeros = module.state.pop("q_zeros").clone()
q_scales = module.state.pop("q_scales").clone()
q_g_idx = module.state.pop("q_g_idx").clone()

assert q_zeros.device == CPU
assert q_scales.device == CPU
@@ -332,6 +350,7 @@ def submodule_finalize(self, module: NamedModule, model: BaseQModel, **kwargs):
with self.lock:
self.result_pop(module.full_name)

self._release_host_buffers(q_scales, q_zeros, q_g_idx)
module.unregister_parameter("weight")

def finalize(self, model: BaseQModel, **kwargs):
@@ -354,3 +373,8 @@ def name(self) -> str:
# TODO fix me..this hacks inherited base class logic, why not override name in gptqv2?
qcfg = self.qcfg_dynamic if self.qcfg_dynamic is not None else self.qcfg
return "gptq v2" if qcfg.v2 else "gptq"

def _release_host_buffers(self, *tensors: torch.Tensor) -> None:
for tensor in tensors:
if isinstance(tensor, torch.Tensor):
self._host_pool.release(tensor)
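
The processor-side change above swaps the blocking `.to(CPU)` copies for `module.stream_state_payload_to_cpu(...)`, and `submodule_finalize` now calls `module.stream_sync()` and then `.clone()`s the popped tensors because the pinned host buffers are filled on a side stream and read from a different thread. A minimal, self-contained sketch of that offload-then-clone pattern (independent of the GPTQModel classes; assumes a CUDA device is present):

```python
import torch


def offload_to_pinned_cpu(src: torch.Tensor) -> tuple[torch.Tensor, torch.cuda.Event]:
    """Copy a CUDA tensor into pinned CPU memory on a side stream without blocking."""
    host = torch.empty(src.shape, dtype=src.dtype, device="cpu", pin_memory=True)
    stream = torch.cuda.Stream(device=src.device)
    done = torch.cuda.Event(enable_timing=False)
    stream.wait_stream(torch.cuda.current_stream(src.device))  # order after producing work
    with torch.cuda.stream(stream):
        host.copy_(src.detach(), non_blocking=True)
        done.record(stream)
    return host, done


if torch.cuda.is_available():
    q_scales = torch.rand(64, 128, device="cuda")
    host, done = offload_to_pinned_cpu(q_scales)
    # consumer (possibly another thread): wait for the copy, then clone out of the
    # pinned buffer before it is released -- the analogue of stream_sync() + .clone()
    done.synchronize()
    q_scales_cpu = host.clone()
```

Note that `_PinnedHostPool.release` is deliberately a no-op, so every acquire is a fresh pinned allocation; the extra allocation cost is traded for not reusing pinned storage across threads.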
122 changes: 102 additions & 20 deletions gptqmodel/looper/named_module.py
@@ -3,7 +3,7 @@
# SPDX-License-Identifier: Apache-2.0
# Contact: qubitium@modelcloud.ai, x.com/qubitium
import threading
from typing import Any, Optional
from typing import Any, Dict, Optional

import torch
import transformers
@@ -16,7 +16,6 @@
log = setup_logger()

class NamedModule(torch.nn.Module):
_lock = threading.Lock()

def __init__(self, module: torch.nn.Module, name: str, full_name:str, layer_index: int) -> None:
super().__init__()
@@ -30,6 +29,7 @@ def __init__(self, module: torch.nn.Module, name: str, full_name:str, layer_inde
# persistent work state for named module (used by some LoopProcessors)
# store all `processed()` work state/data/result here
self.state = {}
self._state_lock = threading.RLock()

# print(f"NamedModule init: name: `{name}, full-name: `{full_name}`")

@@ -72,34 +72,26 @@ def named_buffers(self, prefix: str = "", recurse: bool = True):
def register_buffer(
self, name: str, tensor: Optional[Tensor], persistent: bool = True
) -> None:
with self._lock:
with self._state_lock:
return self.module.register_buffer(name, tensor, persistent)

def unregister_buffer(self, name: str):
with self._lock:
with self._state_lock:
if name in self.module._buffers:
del self.module._buffers[name]
if hasattr(self.module, name):
delattr(self.module, name)
# else:
# log.debug(f"{self.full_name} has no attribute: {name}")
# else:
# log.debug(f"{self.full_name} has no buffer: {name}")

def register_parameter(self, name: str, param: Optional[Parameter]) -> None:
with self._lock:
with self._state_lock:
return self.module.register_parameter(name, param)

def unregister_parameter(self, name: str) -> None:
with self._lock:
with self._state_lock:
if name in self.module._parameters:
del self.module._parameters[name]
if hasattr(self.module, name):
delattr(self.module, name)
# else:
# log.debug(f"{self.full_name} has no attribute: {name}")
# else:
# log.debug(f"{self.full_name} has no parameter: {name}")
# return stats for mo
# def stats(self) -> Dict[str, float]:
# # -1 means no stats have yet to gathered for the stat property
@@ -112,13 +104,103 @@ def unregister_parameter(self, name: str) -> None:

# getattr is only called if python cannot find attr for `self`
def __getattr__(self, name: str):
with self._lock:
with self._state_lock:
return getattr(self.module, name)

# setattr is always called by python even if attr exists in `self`
def __setattr__(self, name: str, value: Any) -> None:
with self._lock:
if name in ["module", "module_dtype", "name", "full_name", "layer_index", "state", "target_device", "register_buffer", "unregister_buffer", "register_parameter", "unregister_parameter"]:
self.__dict__[name] = value
else:
self.module.__dict__[name] = value
if name in [
"module",
"module_dtype",
"name",
"full_name",
"layer_index",
"state",
"target_device",
"register_buffer",
"unregister_buffer",
"register_parameter",
"unregister_parameter",
"_state_lock",
]:
object.__setattr__(self, name, value)
return

with self._state_lock:
setattr(self.module, name, value)

def stream_state_payload_to_cpu(
self,
tensors: Dict[str, torch.Tensor],
*,
host_pool,
) -> Dict[str, torch.Tensor]:
return self._stream_tensor_dict(
tensors,
host_pool=host_pool,
store_callback=lambda host_map: self.state.update(host_map),
)

def stream_parameters_to_cpu(self, *, host_pool) -> Dict[str, torch.Tensor]:
tensor_map = {name: param for name, param in self.module.named_parameters(recurse=False)}
return self._stream_tensor_dict(
tensor_map,
host_pool=host_pool,
store_callback=lambda host_map: self.state.setdefault("parameters_cpu", {}).update(host_map),
)

def stream_buffers_to_cpu(self, *, host_pool) -> Dict[str, torch.Tensor]:
tensor_map = {name: buf for name, buf in self.module.named_buffers(recurse=False)}
return self._stream_tensor_dict(
tensor_map,
host_pool=host_pool,
store_callback=lambda host_map: self.state.setdefault("buffers_cpu", {}).update(host_map),
)

def stream_all_to_cpu(self, *, host_pool) -> Dict[str, Dict[str, torch.Tensor]]:
params = self.stream_parameters_to_cpu(host_pool=host_pool)
buffers = self.stream_buffers_to_cpu(host_pool=host_pool)
return {"parameters": params, "buffers": buffers}

def stream_sync(self) -> None:
with self._state_lock:
pending = self.state.pop("streaming_events", [])
for entry in pending:
entry["event"].synchronize()

def _stream_tensor_dict(
self,
tensors: Dict[str, torch.Tensor],
*,
host_pool,
store_callback,
) -> Dict[str, torch.Tensor]:
filtered = {name: tensor for name, tensor in tensors.items() if isinstance(tensor, torch.Tensor)}
if not filtered:
return {}

first = next(iter(filtered.values()))

if first.device.type != "cuda" or not torch.cuda.is_available():
host_map = {name: tensor.detach().to("cpu") for name, tensor in filtered.items()}
with self._state_lock:
store_callback(host_map)
return host_map

stream = torch.cuda.Stream(device=first.device)
done_event = torch.cuda.Event(enable_timing=False)
host_map: Dict[str, torch.Tensor] = {}

with torch.cuda.stream(stream):
for name, tensor in filtered.items():
src = tensor.detach()
host = host_pool.acquire(src.shape, src.dtype, src.layout)
host.copy_(src, non_blocking=True)
host_map[name] = host
done_event.record(stream)

with self._state_lock:
events = self.state.setdefault("streaming_events", [])
events.append({"event": done_event, "stream": stream})
store_callback(host_map)
return host_map
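
On the NamedModule side, `_stream_tensor_dict` batches the same idea: pinned buffers come from `host_pool.acquire(...)`, every copy is issued with `non_blocking=True` on one dedicated stream, a single event is recorded for the batch, and that event is stashed in `state["streaming_events"]` so `stream_sync()` can wait on it later. A rough standalone equivalent of that flow (the pinned allocation mirrors `_PinnedHostPool.acquire`; the other names are illustrative only):

```python
from typing import Dict, List

import torch


def pinned_like(t: torch.Tensor) -> torch.Tensor:
    # stand-in for host_pool.acquire(shape, dtype, layout)
    return torch.empty(t.shape, dtype=t.dtype, layout=t.layout, device="cpu", pin_memory=True)


def stream_dict_to_cpu(
    tensors: Dict[str, torch.Tensor],
    pending_events: List[torch.cuda.Event],
) -> Dict[str, torch.Tensor]:
    """Queue async device->host copies for a dict of CUDA tensors; record one event."""
    device = next(iter(tensors.values())).device
    stream = torch.cuda.Stream(device=device)
    done = torch.cuda.Event(enable_timing=False)
    stream.wait_stream(torch.cuda.current_stream(device))  # order after producing work
    host_map: Dict[str, torch.Tensor] = {}
    with torch.cuda.stream(stream):
        for name, t in tensors.items():
            host = pinned_like(t)
            host.copy_(t.detach(), non_blocking=True)
            host_map[name] = host
        done.record(stream)
    pending_events.append(done)  # what state["streaming_events"] holds per batch
    return host_map


if torch.cuda.is_available():
    events: List[torch.cuda.Event] = []
    payload = {"q_scales": torch.rand(8, 16, device="cuda"),
               "q_zeros": torch.rand(8, 16, device="cuda")}
    host_map = stream_dict_to_cpu(payload, events)
    for ev in events:  # the analogue of stream_sync()
        ev.synchronize()
    safe = {k: v.clone() for k, v in host_map.items()}  # clone before reuse/release
```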
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -46,7 +46,7 @@ dependencies = [
"huggingface_hub>=0.34.4",
"random_word>=1.0.13",
"tokenicer>=0.0.5",
"logbar>=0.1.3",
"logbar>=0.1.4",
"maturin>=1.9.4", # required by safetensors and hf_transfer
"datasets>=3.6.0",
"pyarrow>=21.0",
2 changes: 1 addition & 1 deletion requirements.txt
@@ -13,7 +13,7 @@ hf_transfer>=0.1.9
huggingface_hub>=0.34.4
random_word>=1.0.13
tokenicer>=0.0.5
logbar>=0.1.3
logbar>=0.1.4
maturin>=1.9.4
datasets>=3.6.0
pyarrow>=21.0