38 changes: 14 additions & 24 deletions gptqmodel/looper/eora_processor.py
@@ -140,6 +140,8 @@ def process(self, module: NamedModule):
device=module.target_device,
)

del eigen_scaling_diag_matrix

# wq with A/B applied
computed_wq = wq + (B @ A)

@@ -194,31 +196,19 @@ def process(self, module: NamedModule):
# log.info(stat)
self.log_new_row(stat)

# logger.info(f"Quantizing module END: {name}, {gptq[name].shape()}")
with self.lock:
self.result_save(module.full_name, {
"rank": module.adapter_cfg.rank,
"lora_A.weight": move_to(A.to(dtype=module.module_dtype), device=CPU, stream=self.stream),
"lora_B.weight": move_to(B.to(dtype=module.module_dtype), device=CPU, stream=self.stream),
})

# eora = Lora(rank=module.adapter_cfg.rank, lora_A=A, lora_B=B)
#
# module.state.update({
# "adapter": eora,
# })
eora = Lora(
rank=module.adapter_cfg.rank,
lora_A=move_to(A.to(dtype=module.module_dtype), device=CPU, stream=self.stream),
lora_B=move_to(B.to(dtype=module.module_dtype), device=CPU, stream=self.stream),
)

def submodule_finalize(self, module: NamedModule):
pass
# adapter: Lora = module.state.pop("adapter")
#
# # logger.info(f"Quantizing module END: {name}, {gptq[name].shape()}")
# self.result_save(module.full_name, {
# "lora_A.weight": move_to(adapter.lora_A.to(dtype=torch.float16), device=CPU, stream=self.stream),
# # A.to(dtype=torch.float16, device=CPU),
# "lora_B.weight": move_to(adapter.lora_B.to(dtype=torch.float16), device=CPU, stream=self.stream),
# # B.to(dtype=torch.float16, device=CPU),
# })
module.state.update({
"adapter": eora
})

def submodule_finalize(self, module: NamedModule, model: BaseQModel, **kwargs):
# logger.info(f"Quantizing module END: {name}, {gptq[name].shape()}")
self.result_save(module.full_name, module.state.pop("adapter"))

def finalize(self, model: BaseQModel, **kwargs):
# block for streams
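A minimal sketch of the adapter hand-off this hunk introduces, for readers skimming the diff: `process()` now parks the Lora pair in `module.state` and `submodule_finalize()` later pops it into the processor's result store. The class and attribute names follow the diff; the stub dataclasses and the toy `A`/`B` values are assumptions, not the real implementation.

```python
from dataclasses import dataclass, field

@dataclass
class Lora:                      # stub mirroring the adapter object in the diff
    rank: int
    lora_A: object
    lora_B: object

@dataclass
class NamedModule:               # stub: only the fields this sketch needs
    full_name: str
    state: dict = field(default_factory=dict)

class EoraProcessorSketch:
    def __init__(self):
        self._results = {}

    def process(self, module: NamedModule, A, B, rank: int):
        # process() stashes the adapter on the module instead of writing results
        # directly, so finalization can run later (possibly on another thread).
        module.state["adapter"] = Lora(rank=rank, lora_A=A, lora_B=B)

    def submodule_finalize(self, module: NamedModule):
        # submodule_finalize() pops the adapter exactly once and records it in
        # the processor-level result store, keyed by the module's full name.
        self._results[module.full_name] = module.state.pop("adapter")

# usage
m = NamedModule(full_name="model.layers.0.mlp.gate_proj")
p = EoraProcessorSketch()
p.process(m, A=[[0.0]], B=[[0.0]], rank=1)
p.submodule_finalize(m)
assert isinstance(p._results[m.full_name], Lora)
```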
5 changes: 5 additions & 0 deletions gptqmodel/looper/loop_processor.py
@@ -224,22 +224,27 @@ def log_new_row(self, stat):
log.info(formatted_row)
log.info(len(formatted_row) * "-")

# Loop Processor level scoped state data
def result_save(self, key: str, value: Any):
with self._results_lock:
#assert self.result_get(key) is None, f"key: {key} already exists in `self.result`"
self._results[key] = value

# Loop Processor level scoped state data
def result_get(self, key: str, default: Any = None) -> Any:
with self._results_lock:
return self._results.get(key, default)

# Loop Processor level scoped state data
def result_pop(self, key: str, default: Any = None):
with self._results_lock:
return self._results.pop(key, default)

# Loop Processor level scoped state data
def result_pop(self, key: str, default: Any = None) -> Any:
return self._results.pop(key, default)

# Loop Processor level scoped state data
def results(self):
return self._results

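The `result_*` helpers shown above are a lock-guarded dict scoped to the LoopProcessor. A standalone equivalent for reference; the class name and the example key are made up, while the method bodies mirror the diff:

```python
import threading
from typing import Any

class ResultStore:
    """Minimal stand-in for the LoopProcessor-scoped result state shown above."""
    def __init__(self):
        self._results: dict[str, Any] = {}
        self._results_lock = threading.Lock()

    def result_save(self, key: str, value: Any) -> None:
        with self._results_lock:
            self._results[key] = value

    def result_get(self, key: str, default: Any = None) -> Any:
        with self._results_lock:
            return self._results.get(key, default)

    def result_pop(self, key: str, default: Any = None) -> Any:
        with self._results_lock:
            return self._results.pop(key, default)

# usage
store = ResultStore()
store.result_save("model.layers.0.mlp.gate_proj", {"rank": 128})
assert store.result_pop("model.layers.0.mlp.gate_proj")["rank"] == 128
```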
30 changes: 18 additions & 12 deletions gptqmodel/looper/module_looper.py
@@ -22,7 +22,7 @@
from ..models._const import CUDA, SUPPORTS_MODULE_TYPES
from ..nn_modules.hooked_linear import (STOP_FORWARD_EXCEPTION, HookedLinear,
StopForward, replace_module_with_hooked_legacy)
from ..utils import ASYNC_WORKER
from ..utils import ASYNC_BG_QUEUE, SERIAL_BG_QUEUE
from ..utils.attn_mask import apply_keep_mask_bt, normalize_seq_mask
from ..utils.device import get_device
from ..utils.logger import setup_logger
@@ -32,6 +32,7 @@
from ..utils.torch import (ALL_DEVICES, CPU, DEFAULT_BALANCE_STRATEGY, HAS_CUDA, META, BalanceStrategy,
device_next, device_next_reset, torch_empty_cache, torch_sync)
from .awq_processor import AWQProcessor
from .qqq_processor import QQQProcessor

log = setup_logger()

@@ -222,7 +223,7 @@ def loop(self, fail_safe: bool = False, **kwargs):
prev_processor = self.processors[p_index - 1]
processor.set_calibration_dataset(prev_processor.calibration_dataset)
# If calibration_dataset is None or Empty, the input_cache of the previous processor is used.
processor.receive_input_cache(copy.copy(prev_processor.inputs_cache))
processor.receive_input_cache(prev_processor.inputs_cache)
elif isinstance(processor, DequantizeProcessor):
# DequantizeProcessor does not perform any operations on dataset.
processor.set_calibration_dataset([])
@@ -543,31 +544,36 @@ def process_module(name, m):
for reverse_p in reversed(self.processors):
for name in processed_subset:
@torch.inference_mode()
def finalize_module(module):
def finalize_module(process, module):
# prevent cuda sync memory ctx bugs
m_device = get_device(module)
if HAS_CUDA and m_device is not None and m_device.type == "cuda":
torch.cuda.set_device(m_device)

reverse_p.submodule_finalize(module, self.gptq_model)
process.submodule_finalize(module, self.gptq_model)

# checking for disk offloading
offload_to_disk(
model=self.gptq_model.model,
module=self.gptq_model.model.get_submodule(module.full_name),
disk_path=self.gptq_model.quantize_config.offload_to_disk_path,
)
# TODO FIX ME offloading to LoopProcessor lifecycle
if isinstance(process, (GPTQProcessor, QQQProcessor, AWQProcessor)):
# checking for disk offloading
offload_to_disk(
model=self.gptq_model.model,
module=self.gptq_model.model.get_submodule(module.full_name),
disk_path=self.gptq_model.quantize_config.offload_to_disk_path,
)

module = processed_subset[name]

if self.gptq_model.quantize_config.offload_to_disk:
ASYNC_WORKER.submit(partial(
SERIAL_BG_QUEUE.submit(partial(
finalize_module,
process=reverse_p,
module=module,
))
else:
reverse_p.submodule_finalize(module, self.gptq_model)

# LifeCycle: All sub-modules have finalized meaning quantization work is complete
ASYNC_WORKER.join()
SERIAL_BG_QUEUE.join()

# paranoid safety check
torch_sync()
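The finalize path above now binds the processor into the callable and pushes it onto `SERIAL_BG_QUEUE`, so submodule finalization and disk offload run one at a time in submission order. A rough sketch of that pattern, using a 1-worker `ThreadPoolExecutor` as a stand-in for the serial queue; the module names and the print body are illustrative:

```python
from concurrent.futures import ThreadPoolExecutor
from functools import partial

def finalize_module(process, module):
    # Placeholder body. In the real loop this calls process.submodule_finalize(module, model)
    # and, only for GPTQ/QQQ/AWQ processors, offloads the finalized submodule to disk.
    print(f"{process} finalizing {module}")

# Single worker => strict FIFO execution, standing in for SERIAL_BG_QUEUE.
with ThreadPoolExecutor(max_workers=1) as bg:
    for name in ("layers.0.mlp.gate_proj", "layers.0.mlp.up_proj"):
        bg.submit(partial(finalize_module, process="eora", module=name))
# Leaving the context manager waits for queued work, like SERIAL_BG_QUEUE.join().
```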
3 changes: 2 additions & 1 deletion gptqmodel/looper/named_module.py
@@ -31,7 +31,8 @@ def __init__(self, module: torch.nn.Module, name: str, full_name:str, layer_inde
# self.target_device, self.target_device_stream = device_next()
self.target_device, self.target_device_stream = None, None

# persistent work state forLoopProcessors

# persistent work state for named module (used by some LoopProcessors)
# store all `processed()` work state/data/result here
self.state = {}

14 changes: 7 additions & 7 deletions gptqmodel/models/writer.py
@@ -75,7 +75,8 @@ def _eora_save(self, save_dir: str, model_save_dir: str = None):
weights = {}
target_modules = set()
# convert the dict into safetensors compatible dict
for key, d in self.lora_results.items():
for key, adapter in self.lora_results.items():
assert isinstance(adapter, Lora)
key = key.lower()
simple_module_name = key.split(".")[-1] # mlp.gate_proj => gate_proj
target_modules.add(simple_module_name)
@@ -84,12 +85,11 @@
# key = key.removeprefix('model.') # some HF models use model. or model.model.

# must normalize key since HF can load weights as `model.` or not based on what AutoModel is used
key = f"{HF_ADAPTER_WEIGHT_KEY_PREFIX}{key}"
lora_rank = d.pop("rank")
for lora_key, lora_weight in d.items():
assert isinstance(lora_weight, torch.Tensor)
weights[f"{key}.{lora_key}"] = lora_weight
log.info(f"Adapter: EoRA weights found -> `{key}.{lora_key}`, rank = `{lora_rank}`")
weight_key = f"{HF_ADAPTER_WEIGHT_KEY_PREFIX}{key}"

weights[f"{weight_key}.lora_A.weight"] = adapter.lora_A
weights[f"{weight_key}.lora_B.weight"] = adapter.lora_B
log.info(f"Adapter: EoRA weights found -> `{weight_key}.lora_A/Lora_B.weight`, rank = `{adapter.rank}`")

weight_file_path = f"{save_dir.removesuffix('/')}/{HF_ADAPTER_FILE_NAME}"

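With `_eora_save` now iterating `Lora` objects instead of per-key dicts, the emitted safetensors keys take the shape `{HF_ADAPTER_WEIGHT_KEY_PREFIX}{module_key}.lora_A.weight` / `.lora_B.weight`. A small sketch of that key layout; the prefix value used here is an assumption (the real constant lives in gptqmodel) and the tensor shapes are placeholders:

```python
import torch

# Assumed prefix value for illustration only; the real constant is defined in gptqmodel.
HF_ADAPTER_WEIGHT_KEY_PREFIX = "base_model.model."

class Lora:  # stub mirroring the adapter object stored by the EoRA processor
    def __init__(self, rank, lora_A, lora_B):
        self.rank, self.lora_A, self.lora_B = rank, lora_A, lora_B

lora_results = {
    "model.layers.0.mlp.gate_proj": Lora(128, torch.zeros(128, 4096), torch.zeros(4096, 128)),
}

weights = {}
for key, adapter in lora_results.items():
    weight_key = f"{HF_ADAPTER_WEIGHT_KEY_PREFIX}{key.lower()}"
    weights[f"{weight_key}.lora_A.weight"] = adapter.lora_A
    weights[f"{weight_key}.lora_B.weight"] = adapter.lora_B

print(sorted(weights))
# ['base_model.model.model.layers.0.mlp.gate_proj.lora_A.weight',
#  'base_model.model.model.layers.0.mlp.gate_proj.lora_B.weight']
```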
5 changes: 3 additions & 2 deletions gptqmodel/utils/__init__.py
@@ -6,12 +6,13 @@
from .backend import BACKEND
from .logger import setup_logger
from .python import gte_python_3_13_3, has_gil_control, has_gil_disabled, log_gil_requirements_for
from .threads import AsyncManager
from .threads import AsyncManager, SerialWorker
from .vram import get_vram

log = setup_logger()

ASYNC_WORKER = AsyncManager(threads=4)
ASYNC_BG_QUEUE = AsyncManager(threads=4)
SERIAL_BG_QUEUE = SerialWorker()

# TODO: datasets is not compatible with free threading
if has_gil_disabled():
32 changes: 32 additions & 0 deletions gptqmodel/utils/threads.py
@@ -4,6 +4,7 @@
# Contact: qubitium@modelcloud.ai, x.com/qubitium

import concurrent.futures as cf
import queue
import threading
import time
import traceback
@@ -84,3 +85,34 @@ def shutdown(self, wait=True, cancel_pending=False):
if wait:
self.join() # wait for remaining tasks
self._exec.shutdown(wait=wait)


class SerialWorker:
def __init__(self, name="serial-worker"):
self._q = queue.Queue()
self._t = threading.Thread(target=self._loop, name=name, daemon=True)
self._t.start()

def _loop(self):
while True:
fn = self._q.get()
if fn is None:
break
try:
fn()
except Exception:
traceback.print_exc()
finally:
self._q.task_done()

def submit(self, fn):
if not callable(fn):
raise TypeError("submit expects a callable")
self._q.put(fn)

def join(self, timeout=None):
self._q.join()

def shutdown(self):
self._q.put(None)
self._t.join()
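A quick usage sketch for the `SerialWorker` class added above; `submit()`, `join()` and `shutdown()` are the methods defined in the diff, while the import path assumes this PR's module layout and the worker name and printed payloads are illustrative:

```python
# Tasks run one at a time, in submission order, on a single daemon thread;
# join() waits for the queue to drain, shutdown() stops the worker thread.
from gptqmodel.utils.threads import SerialWorker  # module path as added in this PR

worker = SerialWorker(name="finalize-offload")

for i in range(3):
    worker.submit(lambda i=i: print(f"task {i} done"))

worker.join()      # block until all submitted callables have finished
worker.shutdown()  # enqueue the None sentinel and join the worker thread
```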
131 changes: 131 additions & 0 deletions tests/test_cpu_gpu_memory_copy.py
@@ -0,0 +1,131 @@
# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai
# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai
# SPDX-License-Identifier: Apache-2.0
# Contact: qubitium@modelcloud.ai, x.com/qubitium


# cpu_gpu_bandwidth_test.py
# Measure HtoD and DtoH bandwidth with pageable vs pinned CPU memory.
#
# Usage examples:
# python cpu_gpu_bandwidth_test.py # 40 GiB total, 1 GiB chunks, GPU:0
# python cpu_gpu_bandwidth_test.py --total-gib 80 --chunk-gib 2 --gpu 1
#
# Notes:
# - We stream in chunks to avoid allocating a single massive tensor.
# - For HtoD, pinned CPU memory + non_blocking=True is required for true async copies.
# - For DtoH, pinned CPU memory also enables non_blocking transfers.
# - We synchronize before/after timing to get accurate results.
# - dtype is fp16 to match your earlier test (2 bytes/elem).

import argparse
import math
import time
import torch

def gib_to_elems_fp16(gib: float) -> int:
# 1 GiB = 1024**3 bytes; fp16 = 2 bytes/elem
return int((gib * (1024**3)) // 2)

def gibs_per_s(bytes_moved: int, seconds: float) -> float:
return (bytes_moved / (1024**3)) / seconds

def make_cpu_tensor(num_elems: int, pin: bool) -> torch.Tensor:
# Pageable vs pinned CPU tensor
return torch.empty(num_elems, dtype=torch.float16, device="cpu", pin_memory=pin)

def make_gpu_tensor(num_elems: int, gpu: int) -> torch.Tensor:
with torch.cuda.device(gpu):
return torch.empty(num_elems, dtype=torch.float16, device=f"cuda:{gpu}")

def run_htod(gpu: int, total_gib: float, chunk_gib: float, pinned: bool) -> float:
n_chunks = math.ceil(total_gib / chunk_gib)
chunk_elems = gib_to_elems_fp16(chunk_gib)
bytes_per_chunk = chunk_elems * 2
total_bytes = n_chunks * bytes_per_chunk

# Buffers
src_cpu = make_cpu_tensor(chunk_elems, pin=pinned)
# Touch once to ensure physical allocation before timing
src_cpu.uniform_()

dst_gpu = make_gpu_tensor(chunk_elems, gpu)

# Warmup (not timed)
dst_gpu.copy_(src_cpu, non_blocking=True)
torch.cuda.synchronize(gpu)

# Timed loop
t0 = time.perf_counter()
for _ in range(n_chunks):
dst_gpu.copy_(src_cpu, non_blocking=True) # non_blocking is effective only if pinned=True
torch.cuda.synchronize(gpu)
t1 = time.perf_counter()

secs = t1 - t0
bw = gibs_per_s(total_bytes, secs)
label = "Pinned" if pinned else "Pageable"
print(f"[CPU to GPU {label}] {total_bytes/(1024**3):.2f} GiB in {secs:.3f} s -> {bw:.2f} GiB/s")
return bw

def run_dtoh(gpu: int, total_gib: float, chunk_gib: float, pinned: bool) -> float:
n_chunks = math.ceil(total_gib / chunk_gib)
chunk_elems = gib_to_elems_fp16(chunk_gib)
bytes_per_chunk = chunk_elems * 2
total_bytes = n_chunks * bytes_per_chunk

# Buffers
src_gpu = make_gpu_tensor(chunk_elems, gpu)
src_gpu.uniform_()

dst_cpu = make_cpu_tensor(chunk_elems, pin=pinned)

# Warmup (not timed)
dst_cpu.copy_(src_gpu, non_blocking=True)
torch.cuda.synchronize(gpu)

# Timed loop
t0 = time.perf_counter()
for _ in range(n_chunks):
dst_cpu.copy_(src_gpu, non_blocking=True) # effective non_blocking only if pinned=True
torch.cuda.synchronize(gpu)
t1 = time.perf_counter()

secs = t1 - t0
bw = gibs_per_s(total_bytes, secs)
label = "Pinned" if pinned else "Pageable"
print(f"[GPU to CPU {label}] {total_bytes/(1024**3):.2f} GiB in {secs:.3f} s -> {bw:.2f} GiB/s")
return bw

def main():
parser = argparse.ArgumentParser(description="CPU<->GPU bandwidth test with pinned vs pageable CPU memory")
parser.add_argument("--gpu", type=int, default=0, help="GPU id to test against")
parser.add_argument("--total-gib", type=float, default=40.0, help="Total GiB to stream per direction per mode")
parser.add_argument("--chunk-gib", type=float, default=1.0, help="Chunk size GiB per copy")
args = parser.parse_args()

if not torch.cuda.is_available():
raise SystemExit("CUDA not available.")
if args.chunk_gib <= 0 or args.total_gib <= 0 or args.total_gib < args.chunk_gib:
raise SystemExit("Invalid sizes: ensure total-gib >= chunk-gib > 0.")

print(f"CUDA devices: {torch.cuda.device_count()}, testing GPU {args.gpu}")
print(f"Total per run: {args.total_gib:.2f} GiB in {args.chunk_gib:.2f} GiB chunks (fp16).")
print("non_blocking=True is only truly async when CPU memory is pinned.\n")

# HtoD: pageable vs pinned
bw_htod_pageable = run_htod(args.gpu, args.total_gib, args.chunk_gib, pinned=False)
bw_htod_pinned = run_htod(args.gpu, args.total_gib, args.chunk_gib, pinned=True)

# DtoH: pageable vs pinned
bw_dtoh_pageable = run_dtoh(args.gpu, args.total_gib, args.chunk_gib, pinned=False)
bw_dtoh_pinned = run_dtoh(args.gpu, args.total_gib, args.chunk_gib, pinned=True)

print("\nSummary (GiB/s):")
print(f" CPU to GPU Pageable: {bw_htod_pageable:.2f}")
print(f" CPU to GPU Pinned : {bw_htod_pinned:.2f}")
print(f" GPU to CPU Pageable: {bw_dtoh_pageable:.2f}")
print(f" GPU to CPU Pinned : {bw_dtoh_pinned:.2f}")

if __name__ == "__main__":
main()
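As the notes in the script say, `non_blocking` copies are only truly asynchronous with pinned host memory. A tiny sanity check of pinned vs pageable allocation, using standard PyTorch calls (assumes a CUDA-capable machine):

```python
import torch

# Pinned (page-locked) vs pageable host allocations; requires a CUDA build/device.
pageable = torch.empty(1024, dtype=torch.float16)                 # ordinary pageable host memory
pinned = torch.empty(1024, dtype=torch.float16, pin_memory=True)  # page-locked at creation
also_pinned = pageable.pin_memory()                               # copies into a pinned buffer

print(pageable.is_pinned(), pinned.is_pinned(), also_pinned.is_pinned())
# expected: False True True
```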