diff --git a/gptqmodel/looper/eora_processor.py b/gptqmodel/looper/eora_processor.py
index 4c9ddeac2..7ee752a30 100644
--- a/gptqmodel/looper/eora_processor.py
+++ b/gptqmodel/looper/eora_processor.py
@@ -140,6 +140,8 @@ def process(self, module: NamedModule):
             device=module.target_device,
         )
 
+        del eigen_scaling_diag_matrix
+
         # wq with A/B applied
         computed_wq = wq + (B @ A)
 
@@ -194,31 +196,19 @@ def process(self, module: NamedModule):
         # log.info(stat)
         self.log_new_row(stat)
 
-        # logger.info(f"Quantizing module END: {name}, {gptq[name].shape()}")
-        with self.lock:
-            self.result_save(module.full_name, {
-                "rank": module.adapter_cfg.rank,
-                "lora_A.weight": move_to(A.to(dtype=module.module_dtype), device=CPU, stream=self.stream),
-                "lora_B.weight": move_to(B.to(dtype=module.module_dtype), device=CPU, stream=self.stream),
-            })
-
-        # eora = Lora(rank=module.adapter_cfg.rank, lora_A=A, lora_B=B)
-        #
-        # module.state.update({
-        #     "adapter": eora,
-        # })
+        eora = Lora(
+            rank=module.adapter_cfg.rank,
+            lora_A=move_to(A.to(dtype=module.module_dtype), device=CPU, stream=self.stream),
+            lora_B=move_to(B.to(dtype=module.module_dtype), device=CPU, stream=self.stream),
+        )
 
-    def submodule_finalize(self, module: NamedModule):
-        pass
-        # adapter: Lora = module.state.pop("adapter")
-        #
-        # # logger.info(f"Quantizing module END: {name}, {gptq[name].shape()}")
-        # self.result_save(module.full_name, {
-        #     "lora_A.weight": move_to(adapter.lora_A.to(dtype=torch.float16), device=CPU, stream=self.stream),
-        #     # A.to(dtype=torch.float16, device=CPU),
-        #     "lora_B.weight": move_to(adapter.lora_B.to(dtype=torch.float16), device=CPU, stream=self.stream),
-        #     # B.to(dtype=torch.float16, device=CPU),
-        # })
+        module.state.update({
+            "adapter": eora
+        })
+
+    def submodule_finalize(self, module: NamedModule, model: BaseQModel, **kwargs):
+        # logger.info(f"Quantizing module END: {name}, {gptq[name].shape()}")
+        self.result_save(module.full_name, module.state.pop("adapter"))
 
     def finalize(self, model: BaseQModel, **kwargs):
         # block for streams
diff --git a/gptqmodel/looper/loop_processor.py b/gptqmodel/looper/loop_processor.py
index 776be1aa5..ff654a425 100644
--- a/gptqmodel/looper/loop_processor.py
+++ b/gptqmodel/looper/loop_processor.py
@@ -224,22 +224,27 @@ def log_new_row(self, stat):
         log.info(formatted_row)
         log.info(len(formatted_row) * "-")
 
+    # LoopProcessor-scoped state data
     def result_save(self, key: str, value: Any):
         with self._results_lock:
             #assert self.result_get(key) is None, f"key: {key} already exists in `self.result`"
             self._results[key] = value
 
+    # LoopProcessor-scoped state data
     def result_get(self, key: str, default: Any = None) -> Any:
         with self._results_lock:
             return self._results.get(key, default)
 
+    # LoopProcessor-scoped state data
     def result_pop(self, key: str, default: Any = None):
         with self._results_lock:
             return self._results.pop(key, default)
 
+    # LoopProcessor-scoped state data
     def result_pop(self, key: str, default: Any = None) -> Any:
         return self._results.pop(key, default)
 
+    # LoopProcessor-scoped state data
     def results(self):
         return self._results
 
diff --git a/gptqmodel/looper/module_looper.py b/gptqmodel/looper/module_looper.py
index 891e4512b..517d065fb 100644
--- a/gptqmodel/looper/module_looper.py
+++ b/gptqmodel/looper/module_looper.py
@@ -22,7 +22,7 @@
 from ..models._const import CUDA, SUPPORTS_MODULE_TYPES
 from ..nn_modules.hooked_linear import (STOP_FORWARD_EXCEPTION, HookedLinear, StopForward,
                                         replace_module_with_hooked_legacy)
-from ..utils import ASYNC_WORKER
+from ..utils import ASYNC_BG_QUEUE, SERIAL_BG_QUEUE
 from ..utils.attn_mask import apply_keep_mask_bt, normalize_seq_mask
 from ..utils.device import get_device
 from ..utils.logger import setup_logger
@@ -32,6 +32,7 @@
 from ..utils.torch import (ALL_DEVICES, CPU, DEFAULT_BALANCE_STRATEGY, HAS_CUDA, META, BalanceStrategy,
                            device_next, device_next_reset, torch_empty_cache, torch_sync)
 from .awq_processor import AWQProcessor
+from .qqq_processor import QQQProcessor
 
 log = setup_logger()
 
@@ -222,7 +223,7 @@ def loop(self, fail_safe: bool = False, **kwargs):
                 prev_processor = self.processors[p_index - 1]
                 processor.set_calibration_dataset(prev_processor.calibration_dataset)
                 # If calibration_dataset is None or Empty, the input_cache of the previous processor is used.
-                processor.receive_input_cache(copy.copy(prev_processor.inputs_cache))
+                processor.receive_input_cache(prev_processor.inputs_cache)
             elif isinstance(processor, DequantizeProcessor):
                 # DequantizeProcessor does not perform any operations on dataset.
                 processor.set_calibration_dataset([])
@@ -543,31 +544,36 @@ def process_module(name, m):
         for reverse_p in reversed(self.processors):
             for name in processed_subset:
                 @torch.inference_mode()
-                def finalize_module(module):
+                def finalize_module(process, module):
                     # prevent cuda sync memory ctx bugs
                     m_device = get_device(module)
                     if HAS_CUDA and m_device is not None and m_device.type == "cuda":
                         torch.cuda.set_device(m_device)
 
-                    reverse_p.submodule_finalize(module, self.gptq_model)
+                    process.submodule_finalize(module, self.gptq_model)
 
-                    # checking for disk offloading
-                    offload_to_disk(
-                        model=self.gptq_model.model,
-                        module=self.gptq_model.model.get_submodule(module.full_name),
-                        disk_path=self.gptq_model.quantize_config.offload_to_disk_path,
-                    )
+                    # TODO FIXME: move offloading into the LoopProcessor lifecycle
+                    if isinstance(process, (GPTQProcessor, QQQProcessor, AWQProcessor)):
+                        # checking for disk offloading
+                        offload_to_disk(
+                            model=self.gptq_model.model,
+                            module=self.gptq_model.model.get_submodule(module.full_name),
+                            disk_path=self.gptq_model.quantize_config.offload_to_disk_path,
+                        )
 
                 module = processed_subset[name]
                 if self.gptq_model.quantize_config.offload_to_disk:
-                    ASYNC_WORKER.submit(partial(
+                    SERIAL_BG_QUEUE.submit(partial(
                         finalize_module,
+                        process=reverse_p,
                         module=module,
                     ))
+                else:
+                    reverse_p.submodule_finalize(module, self.gptq_model)
 
         # LifeCycle: All sub-modules have finalized meaning quantization work is complete
-        ASYNC_WORKER.join()
+        SERIAL_BG_QUEUE.join()
 
         # paranoid safety check
         torch_sync()
 
diff --git a/gptqmodel/looper/named_module.py b/gptqmodel/looper/named_module.py
index 195927ceb..e10120d30 100644
--- a/gptqmodel/looper/named_module.py
+++ b/gptqmodel/looper/named_module.py
@@ -31,7 +31,8 @@ def __init__(self, module: torch.nn.Module, name: str, full_name:str, layer_inde
         # self.target_device, self.target_device_stream = device_next()
         self.target_device, self.target_device_stream = None, None
 
-        # persistent work state forLoopProcessors
+
+        # persistent work state for named module (used by some LoopProcessors)
         # store all `processed()` work state/data/result here
         self.state = {}
 
diff --git a/gptqmodel/models/writer.py b/gptqmodel/models/writer.py
index 5261d0b81..9aae554f9 100644
--- a/gptqmodel/models/writer.py
+++ b/gptqmodel/models/writer.py
@@ -75,7 +75,8 @@ def _eora_save(self, save_dir: str, model_save_dir: str = None):
         weights = {}
         target_modules = set()
         # convert the dict into safetensors compatible dict
-        for key, d in self.lora_results.items():
+        for key, adapter in self.lora_results.items():
+            assert isinstance(adapter, Lora)
             key = key.lower()
             simple_module_name = key.split(".")[-1] # mlp.gate_proj => gate_proj
             target_modules.add(simple_module_name)
@@ -84,12 +85,11 @@ def _eora_save(self, save_dir: str, model_save_dir: str = None):
 
             # key = key.removeprefix('model.') # some HF models use model. or model.model.
             # must normalize key since HF can load weights as `model.` or not based on what AutoModel is used
-            key = f"{HF_ADAPTER_WEIGHT_KEY_PREFIX}{key}"
-            lora_rank = d.pop("rank")
-            for lora_key, lora_weight in d.items():
-                assert isinstance(lora_weight, torch.Tensor)
-                weights[f"{key}.{lora_key}"] = lora_weight
-                log.info(f"Adapter: EoRA weights found -> `{key}.{lora_key}`, rank = `{lora_rank}`")
+            weight_key = f"{HF_ADAPTER_WEIGHT_KEY_PREFIX}{key}"
+
+            weights[f"{weight_key}.lora_A.weight"] = adapter.lora_A
+            weights[f"{weight_key}.lora_B.weight"] = adapter.lora_B
+            log.info(f"Adapter: EoRA weights found -> `{weight_key}.lora_A/lora_B.weight`, rank = `{adapter.rank}`")
 
         weight_file_path = f"{save_dir.removesuffix('/')}/{HF_ADAPTER_FILE_NAME}"
 
diff --git a/gptqmodel/utils/__init__.py b/gptqmodel/utils/__init__.py
index 9f998d8cd..11973f0b5 100644
--- a/gptqmodel/utils/__init__.py
+++ b/gptqmodel/utils/__init__.py
@@ -6,12 +6,13 @@
 from .backend import BACKEND
 from .logger import setup_logger
 from .python import gte_python_3_13_3, has_gil_control, has_gil_disabled, log_gil_requirements_for
-from .threads import AsyncManager
+from .threads import AsyncManager, SerialWorker
 from .vram import get_vram
 
 log = setup_logger()
 
-ASYNC_WORKER = AsyncManager(threads=4)
+ASYNC_BG_QUEUE = AsyncManager(threads=4)
+SERIAL_BG_QUEUE = SerialWorker()
 
 # TODO: datasets is not compatible with free threading
 if has_gil_disabled():
diff --git a/gptqmodel/utils/threads.py b/gptqmodel/utils/threads.py
index b447a2ee4..cb5060bf5 100644
--- a/gptqmodel/utils/threads.py
+++ b/gptqmodel/utils/threads.py
@@ -4,6 +4,7 @@
 # Contact: qubitium@modelcloud.ai, x.com/qubitium
 
 import concurrent.futures as cf
+import queue
 import threading
 import time
 import traceback
@@ -84,3 +85,34 @@ def shutdown(self, wait=True, cancel_pending=False):
         if wait:
             self.join()  # wait for remaining tasks
         self._exec.shutdown(wait=wait)
+
+
+class SerialWorker:
+    def __init__(self, name="serial-worker"):
+        self._q = queue.Queue()
+        self._t = threading.Thread(target=self._loop, name=name, daemon=True)
+        self._t.start()
+
+    def _loop(self):
+        while True:
+            fn = self._q.get()
+            if fn is None:
+                break
+            try:
+                fn()
+            except Exception:
+                traceback.print_exc()
+            finally:
+                self._q.task_done()
+
+    def submit(self, fn):
+        if not callable(fn):
+            raise TypeError("submit expects a callable")
+        self._q.put(fn)
+
+    def join(self, timeout=None):
+        self._q.join()
+
+    def shutdown(self):
+        self._q.put(None)
+        self._t.join()
diff --git a/tests/test_cpu_gpu_memory_copy.py b/tests/test_cpu_gpu_memory_copy.py
new file mode 100644
index 000000000..d79168896
--- /dev/null
+++ b/tests/test_cpu_gpu_memory_copy.py
@@ -0,0 +1,131 @@
+# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai
+# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai
+# SPDX-License-Identifier: Apache-2.0
+# Contact: qubitium@modelcloud.ai, x.com/qubitium
+
+
+# tests/test_cpu_gpu_memory_copy.py
+# Measure HtoD and DtoH bandwidth with pageable vs pinned CPU memory.
+#
+# Usage examples:
+#   python tests/test_cpu_gpu_memory_copy.py                              # 40 GiB total, 1 GiB chunks, GPU:0
+#   python tests/test_cpu_gpu_memory_copy.py --total-gib 80 --chunk-gib 2 --gpu 1
+#
+# Notes:
+# - We stream in chunks to avoid allocating a single massive tensor.
+# - For HtoD, pinned CPU memory + non_blocking=True is required for true async copies.
+# - For DtoH, pinned CPU memory also enables non_blocking transfers.
+# - We synchronize before/after timing to get accurate results.
+# - dtype is fp16 (2 bytes/elem) to match the companion GPU-to-GPU copy test.
+
+import argparse
+import math
+import time
+import torch
+
+def gib_to_elems_fp16(gib: float) -> int:
+    # 1 GiB = 1024**3 bytes; fp16 = 2 bytes/elem
+    return int((gib * (1024**3)) // 2)
+
+def gibs_per_s(bytes_moved: int, seconds: float) -> float:
+    return (bytes_moved / (1024**3)) / seconds
+
+def make_cpu_tensor(num_elems: int, pin: bool) -> torch.Tensor:
+    # Pageable vs pinned CPU tensor
+    return torch.empty(num_elems, dtype=torch.float16, device="cpu", pin_memory=pin)
+
+def make_gpu_tensor(num_elems: int, gpu: int) -> torch.Tensor:
+    with torch.cuda.device(gpu):
+        return torch.empty(num_elems, dtype=torch.float16, device=f"cuda:{gpu}")
+
+def run_htod(gpu: int, total_gib: float, chunk_gib: float, pinned: bool) -> float:
+    n_chunks = math.ceil(total_gib / chunk_gib)
+    chunk_elems = gib_to_elems_fp16(chunk_gib)
+    bytes_per_chunk = chunk_elems * 2
+    total_bytes = n_chunks * bytes_per_chunk
+
+    # Buffers
+    src_cpu = make_cpu_tensor(chunk_elems, pin=pinned)
+    # Touch once to ensure physical allocation before timing
+    src_cpu.uniform_()
+
+    dst_gpu = make_gpu_tensor(chunk_elems, gpu)
+
+    # Warmup (not timed)
+    dst_gpu.copy_(src_cpu, non_blocking=True)
+    torch.cuda.synchronize(gpu)
+
+    # Timed loop
+    t0 = time.perf_counter()
+    for _ in range(n_chunks):
+        dst_gpu.copy_(src_cpu, non_blocking=True)  # non_blocking is effective only if pinned=True
+    torch.cuda.synchronize(gpu)
+    t1 = time.perf_counter()
+
+    secs = t1 - t0
+    bw = gibs_per_s(total_bytes, secs)
+    label = "Pinned" if pinned else "Pageable"
+    print(f"[CPU to GPU {label}] {total_bytes/(1024**3):.2f} GiB in {secs:.3f} s -> {bw:.2f} GiB/s")
+    return bw
+
+def run_dtoh(gpu: int, total_gib: float, chunk_gib: float, pinned: bool) -> float:
+    n_chunks = math.ceil(total_gib / chunk_gib)
+    chunk_elems = gib_to_elems_fp16(chunk_gib)
+    bytes_per_chunk = chunk_elems * 2
+    total_bytes = n_chunks * bytes_per_chunk
+
+    # Buffers
+    src_gpu = make_gpu_tensor(chunk_elems, gpu)
+    src_gpu.uniform_()
+
+    dst_cpu = make_cpu_tensor(chunk_elems, pin=pinned)
+
+    # Warmup (not timed)
+    dst_cpu.copy_(src_gpu, non_blocking=True)
+    torch.cuda.synchronize(gpu)
+
+    # Timed loop
+    t0 = time.perf_counter()
+    for _ in range(n_chunks):
+        dst_cpu.copy_(src_gpu, non_blocking=True)  # effective non_blocking only if pinned=True
+    torch.cuda.synchronize(gpu)
+    t1 = time.perf_counter()
+
+    secs = t1 - t0
+    bw = gibs_per_s(total_bytes, secs)
+    label = "Pinned" if pinned else "Pageable"
+    print(f"[GPU to CPU {label}] {total_bytes/(1024**3):.2f} GiB in {secs:.3f} s -> {bw:.2f} GiB/s")
+    return bw
+
+def main():
+    parser = argparse.ArgumentParser(description="CPU<->GPU bandwidth test with pinned vs pageable CPU memory")
+    parser.add_argument("--gpu", type=int, default=0, help="GPU id to test against")
+    parser.add_argument("--total-gib", type=float, default=40.0, help="Total GiB to stream per direction per mode")
+    parser.add_argument("--chunk-gib", type=float, default=1.0, help="Chunk size GiB per copy")
+    args = parser.parse_args()
+
+    if not torch.cuda.is_available():
+        raise SystemExit("CUDA not available.")
+    if args.chunk_gib <= 0 or args.total_gib <= 0 or args.total_gib < args.chunk_gib:
+        raise SystemExit("Invalid sizes: ensure total-gib >= chunk-gib > 0.")
+
+    print(f"CUDA devices: {torch.cuda.device_count()}, testing GPU {args.gpu}")
+    print(f"Total per run: {args.total_gib:.2f} GiB in {args.chunk_gib:.2f} GiB chunks (fp16).")
+    print("non_blocking=True is only truly async when CPU memory is pinned.\n")
+
+    # HtoD: pageable vs pinned
+    bw_htod_pageable = run_htod(args.gpu, args.total_gib, args.chunk_gib, pinned=False)
+    bw_htod_pinned = run_htod(args.gpu, args.total_gib, args.chunk_gib, pinned=True)
+
+    # DtoH: pageable vs pinned
+    bw_dtoh_pageable = run_dtoh(args.gpu, args.total_gib, args.chunk_gib, pinned=False)
+    bw_dtoh_pinned = run_dtoh(args.gpu, args.total_gib, args.chunk_gib, pinned=True)
+
+    print("\nSummary (GiB/s):")
+    print(f" CPU to GPU Pageable: {bw_htod_pageable:.2f}")
+    print(f" CPU to GPU Pinned  : {bw_htod_pinned:.2f}")
+    print(f" GPU to CPU Pageable: {bw_dtoh_pageable:.2f}")
+    print(f" GPU to CPU Pinned  : {bw_dtoh_pinned:.2f}")
+
+if __name__ == "__main__":
+    main()
diff --git a/tests/test_gpu_gpu_memory_copy.py b/tests/test_gpu_gpu_memory_copy.py
new file mode 100644
index 000000000..73df8cb7f
--- /dev/null
+++ b/tests/test_gpu_gpu_memory_copy.py
@@ -0,0 +1,97 @@
+# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai
+# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai
+# SPDX-License-Identifier: Apache-2.0
+# Contact: qubitium@modelcloud.ai, x.com/qubitium
+
+
+# tests/test_gpu_gpu_memory_copy.py
+# Measure inter-GPU copy bandwidth using chunked fp16 tensors.
+# Default: stream 40 GiB total (in 1 GiB chunks) from 0->1 and 1->0.
+#
+# Usage:
+#   python tests/test_gpu_gpu_memory_copy.py                              # 40 GiB total, 1 GiB chunks
+#   python tests/test_gpu_gpu_memory_copy.py --total-gib 80 --chunk-gib 2
+#
+# Notes:
+# - This avoids allocating a single 40 GiB tensor (which would OOM or be risky).
+# - If P2P is available, CUDA will use it; otherwise it falls back to host staging.
+# - For accurate timing we synchronize before/after and use perf_counter.
+
+import argparse
+import time
+import math
+import torch
+
+def gib_to_elems_fp16(gib: float) -> int:
+    # 1 GiB = 1024**3 bytes; fp16 = 2 bytes/elem
+    return int((gib * (1024**3)) // 2)
+
+def format_gibs_per_s(bytes_moved, seconds):
+    return (bytes_moved / (1024**3)) / seconds
+
+def run_direction(src_dev: int, dst_dev: int, total_gib: float, chunk_gib: float) -> float:
+    assert total_gib > 0 and chunk_gib > 0 and total_gib >= chunk_gib
+    n_chunks = math.ceil(total_gib / chunk_gib)
+    # Round chunk so that n_chunks * chunk_gib >= total_gib
+    chunk_elems = gib_to_elems_fp16(chunk_gib)
+
+    # Pre-allocate reusable src/dst chunk buffers
+    with torch.cuda.device(src_dev):
+        src = torch.empty(chunk_elems, dtype=torch.float16, device=f"cuda:{src_dev}")
+        # Fill once to avoid lazy allocations later
+        src.uniform_()
+
+    with torch.cuda.device(dst_dev):
+        dst = torch.empty(chunk_elems, dtype=torch.float16, device=f"cuda:{dst_dev}")
+
+    # Warmup: single copy (not timed)
+    dst.copy_(src, non_blocking=True)
+    torch.cuda.synchronize(src_dev)
+    torch.cuda.synchronize(dst_dev)
+
+    # Timed streaming of N chunks
+    bytes_per_chunk = chunk_elems * 2  # fp16 = 2 bytes
+    total_bytes = n_chunks * bytes_per_chunk
+
+    t0 = time.perf_counter()
+    for _ in range(n_chunks):
+        # reuse the same buffers; content doesn't matter for bandwidth
+        dst.copy_(src, non_blocking=True)
+    # Ensure all queued copies are complete
+    torch.cuda.synchronize(src_dev)
+    torch.cuda.synchronize(dst_dev)
+    t1 = time.perf_counter()
+
+    seconds = t1 - t0
+    gibs = total_bytes / (1024**3)
+    bw = format_gibs_per_s(total_bytes, seconds)
+    print(f"[cuda:{src_dev} -> cuda:{dst_dev}] Transferred {gibs:.2f} GiB in {seconds:.3f} s -> {bw:.2f} GiB/s")
+    return bw
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--src", type=int, default=0, help="source GPU id")
+    parser.add_argument("--dst", type=int, default=1, help="destination GPU id")
+    parser.add_argument("--total-gib", type=float, default=40.0, help="total GiB to stream per direction")
+    parser.add_argument("--chunk-gib", type=float, default=1.0, help="chunk size GiB per copy")
+    args = parser.parse_args()
+
+    if not torch.cuda.is_available() or torch.cuda.device_count() < 2:
+        raise SystemExit("Need at least 2 CUDA devices.")
+
+    # Basic info
+    print(f"Detected {torch.cuda.device_count()} CUDA devices.")
+    print(f"Testing {args.total_gib:.2f} GiB total per direction in {args.chunk_gib:.2f} GiB chunks.")
+    print(f"CUDA P2P (can_device_access_peer): "
+          f"{args.src}->{args.dst}={torch.cuda.can_device_access_peer(args.src, args.dst)}, "
+          f"{args.dst}->{args.src}={torch.cuda.can_device_access_peer(args.dst, args.src)}")
+
+    # Run both directions
+    bw_fwd = run_direction(args.src, args.dst, args.total_gib, args.chunk_gib)
+    bw_bwd = run_direction(args.dst, args.src, args.total_gib, args.chunk_gib)
+
+    # Summary
+    print(f"Average bandwidth: {(bw_fwd + bw_bwd)/2:.2f} GiB/s")
+
+if __name__ == "__main__":
+    main()
diff --git a/tests/test_p2p.py b/tests/test_p2p.py
new file mode 100644
index 000000000..1c1f318ac
--- /dev/null
+++ b/tests/test_p2p.py
@@ -0,0 +1,34 @@
+# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai
+# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai
+# SPDX-License-Identifier: Apache-2.0
+# Contact: qubitium@modelcloud.ai, x.com/qubitium
+
+import torch
+
+def main():
+    if not torch.cuda.is_available():
+        print("CUDA not available")
+        return
+
+    ndev = torch.cuda.device_count()
+    if ndev < 2:
+        print(f"Only {ndev} CUDA device(s) visible, need >= 2")
+        return
+
+    print("Devices:")
+    for d in range(ndev):
+        props = torch.cuda.get_device_properties(d)
+        print(f" cuda:{d} {props.name} {props.total_memory/1024**3:.1f} GiB CC {props.major}.{props.minor}")
+
+    print("\nP2P capability (rows=src, cols=dst):")
+    for i in range(ndev):
+        row = []
+        for j in range(ndev):
+            if i == j:
+                row.append(" - ")
+                continue
+            row.append("yes" if torch.cuda.can_device_access_peer(i, j) else " no")
+        print(f"{i:>2}: " + " ".join(f"{r:>3}" for r in row))
+
+if __name__ == "__main__":
+    main()
diff --git a/tests/test_quant_and_eora.py b/tests/test_quant_and_eora.py
index eb77d7217..a779880bd 100644
--- a/tests/test_quant_and_eora.py
+++ b/tests/test_quant_and_eora.py
@@ -48,12 +48,12 @@ def setUpClass(cls):
 
     @parameterized.expand(
         [
-            # (QUANT_METHOD.GPTQ, FORMAT.GPTQ, True), # gptq v2
-            (METHOD.GPTQ, FORMAT.GPTQ, False), # gptq v1
+            # (QUANT_METHOD.GPTQ, FORMAT.GPTQ), # gptq v2
+            (METHOD.GPTQ, FORMAT.GPTQ), # gptq v1
             #(QUANT_METHOD.QQQ, FORMAT.QQQ),
         ]
     )
-    def test_quant_and_eora(self, quant_method: METHOD, format: FORMAT, v2: bool):
+    def test_quant_and_eora(self, quant_method: METHOD, format: FORMAT):
         bits = 4
         group_size = 128
         desc_act = False
@@ -92,7 +92,6 @@ def test_quant_and_eora(self, quant_method: METHOD, format: FORMAT, v2: bool):
             adapter=eora,
             format=format,
             quant_method=quant_method,
-            v2=v2,
         )
 
         model = GPTQModel.load(
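Usage sketch (reviewer aid) for the new SerialWorker background queue added in gptqmodel/utils/threads.py: it mirrors how module_looper.py submits finalize_module callables via functools.partial. The finalize task and queue name below are illustrative assumptions, not code from the patch.

    from functools import partial

    from gptqmodel.utils.threads import SerialWorker

    def finalize(name: str) -> None:
        # stand-in for per-module finalize work (e.g. submodule_finalize + disk offload)
        print(f"finalized {name}")

    worker = SerialWorker(name="finalize-queue")
    for module_name in ("mlp.gate_proj", "mlp.up_proj"):
        # callables run one at a time, in submission order, on a daemon thread
        worker.submit(partial(finalize, module_name))

    worker.join()      # block until the queue drains (like SERIAL_BG_QUEUE.join() in loop())
    worker.shutdown()  # enqueue the None sentinel and join the worker thread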