diff --git a/gptqmodel/looper/loop_processor.py b/gptqmodel/looper/loop_processor.py
index e177434d4..7c6fc4ceb 100644
--- a/gptqmodel/looper/loop_processor.py
+++ b/gptqmodel/looper/loop_processor.py
@@ -4,6 +4,7 @@
 # Contact: qubitium@modelcloud.ai, x.com/qubitium
 import json
 import queue
+import re
 import threading
 from datetime import datetime
 from typing import Any, Callable, Dict, List, Optional, Set, Tuple
@@ -26,6 +27,8 @@
 # global level lock
 PROCESSOR_GLOBAL_LOCK = threading.Lock()
 
+ANSI_ESCAPE_RE = re.compile(r"\x1B\[[0-?]*[ -/]*[@-~]")
+
 # LoopProcessor is a singleton(), not per module instance
 class LoopProcessor:
     def __init__(
@@ -191,41 +194,58 @@ def loss_color(self, loss_value):
         else:
             return "\033[91m"  # Red
 
+    def _strip_ansi(self, text: Any) -> str:
+        return ANSI_ESCAPE_RE.sub("", str(text))
+
+    def _visible_length(self, text: Any) -> int:
+        return len(self._strip_ansi(text))
+
+    def _ljust_visible(self, text: str, width: int) -> str:
+        padding = max(width - self._visible_length(text), 0)
+        if padding:
+            return f"{text}{' ' * padding}"
+        return text
+
     def log_new_row(self, stat):
         self.log_call_count += 1
 
         self.log_save_async(stat)  # write to temp log file
 
         # Update max_widths with the new row's column widths
         for key, value in stat.items():
-            current_width = max(len(str(key)), len(str(value))) + 4  # 4 is for padding
+            key_str = str(key)
+            value_str = str(value)
+            current_width = max(self._visible_length(key_str), self._visible_length(value_str)) + 4  # 4 is for padding
             if key not in self.log_max_widths or current_width > self.log_max_widths[key]:
                 self.log_max_widths[key] = current_width
 
        if self.log_call_count % 20 == 1:
             # Format the header row
-            header_row = "| " + " | ".join(
-                str(key).ljust(self.log_max_widths[key]) for key in self.log_max_widths.keys()) + " |"
+            header_cells = [
+                self._ljust_visible(str(key), self.log_max_widths[key]) for key in self.log_max_widths.keys()
+            ]
+            header_row = "| " + " | ".join(header_cells) + " |"
+            header_separator = "-" * self._visible_length(header_row)
 
             if self.log_call_count == 1:
-                log.info(len(header_row) * "-")
+                log.info(header_separator)
 
             log.info(header_row)
-            log.info(len(header_row) * "-")
+            log.info(header_separator)
 
-        formatted_row = "| "
+        row_cells = []
         for key in self.log_max_widths.keys():
             value = stat.get(key, "")
-            if key == "loss":
-                color_code = self.loss_color(float(value))
-                formatted_value = f"{color_code}{value}\033[0m"
+            value_str = str(value)
+            if key == "loss" and value_str:
+                color_code = self.loss_color(float(value_str))
+                formatted_value = f"{color_code}{value_str}\033[0m"
             else:
-                formatted_value = str(value)
-            formatted_row += formatted_value.ljust(self.log_max_widths[key]) + " | "
-
-        # formatted_row = "| " + " | ".join(
-        #     str(stat.get(key, "")).ljust(self.log_max_widths[key]) for key in self.log_max_widths.keys()) + " |"
+                formatted_value = value_str
+            row_cells.append(self._ljust_visible(formatted_value, self.log_max_widths[key]))
+        formatted_row = "| " + " | ".join(row_cells) + " |"
+        row_separator = "-" * self._visible_length(formatted_row)
 
         log.info(formatted_row)
-        log.info(len(formatted_row) * "-")
+        log.info(row_separator)
 
     def _init_device_smi_handles(self) -> Dict[str, Device]:
         handles: Dict[str, Device] = {}
diff --git a/gptqmodel/looper/module_looper.py b/gptqmodel/looper/module_looper.py
index 06f1bd0a2..2f733b322 100644
--- a/gptqmodel/looper/module_looper.py
+++ b/gptqmodel/looper/module_looper.py
@@ -925,37 +925,57 @@ def _process_on_worker(proc: LoopProcessor, nm: NamedModule):
                 torch_sync()
         # Gather finalize tasks (can offload to disk); run them via the pool
-        finalize_futures = []
+        finalize_tasks = []
         for reverse_p in reversed(self.processors):
-            for name in processed_subset:
-                @torch.inference_mode()
-                def finalize_module(process, module):
-                    process.submodule_finalize(module, self.gptq_model)
+            for module in processed_subset.values():
+                target_dev = get_device_new(module, recursive=True, assert_mode=True, expected="cpu")
+                module_label = getattr(module, "full_name", getattr(module, "name", ""))
+                finalize_tasks.append((reverse_p, module, module_label, target_dev))
 
-                    # Disk offload (lifecycle TODO note preserved)
-                    if isinstance(process, (GPTQProcessor, QQQProcessor, AWQProcessor)):
-                        offload_to_disk(
-                            model=self.gptq_model.model,
-                            module=self.gptq_model.model.get_submodule(module.full_name),
-                            disk_path=self.gptq_model.quantize_config.offload_to_disk_path,
-                        )
+        finalize_count = len(finalize_tasks)
+        if finalize_count:
+            quant_modules_pb.subtitle(
+                f"Finalizing submodules ({finalize_count})"
+            ).draw()
 
-                module = processed_subset[name]
+        finalize_futures = []
 
-                target_dev = get_device_new(module, recursive=True, assert_mode=True, expected="cpu")
+        @torch.inference_mode()
+        def _finalize_on_worker(process, module, idx, total, module_label):
+            quant_modules_pb.subtitle(
+                f"{process.name()}: finalizing {idx}/{total} ({module_label})"
+            ).draw()
+
+            process.submodule_finalize(module, self.gptq_model)
+
+            # Disk offload (lifecycle TODO note preserved)
+            if isinstance(process, (GPTQProcessor, QQQProcessor, AWQProcessor)):
+                offload_to_disk(
+                    model=self.gptq_model.model,
+                    module=self.gptq_model.model.get_submodule(module.full_name),
+                    disk_path=self.gptq_model.quantize_config.offload_to_disk_path,
+                )
 
-                # Submit on the module's device thread (safe & deterministic)
-                finalize_futures.append(
-                    DEVICE_THREAD_POOL.submit_serial(
-                        target_dev, finalize_module, reverse_p, module
-                    )
+        for index, (process, module, module_label, target_dev) in enumerate(finalize_tasks, start=1):
+            finalize_futures.append(
+                DEVICE_THREAD_POOL.submit_serial(
+                    target_dev,
+                    _finalize_on_worker,
+                    process,
+                    module,
+                    index,
+                    finalize_count,
+                    module_label,
                 )
+            )
 
-        # If any finalize tasks were queued, wait for them
         for fut in finalize_futures:
             fut.result()
 
+        if finalize_count:
+            quant_modules_pb.subtitle("").draw()
+
         # LifeCycle: All sub-modules have finalized meaning quantization work is complete
 
         # Ensure ANY remaining tasks the looper submitted have drained
         DEVICE_THREAD_POOL.wait()  # same as wait('all')
diff --git a/gptqmodel/utils/model.py b/gptqmodel/utils/model.py
index 3ef446fd2..f928c0605 100644
--- a/gptqmodel/utils/model.py
+++ b/gptqmodel/utils/model.py
@@ -1261,6 +1261,20 @@ def _resolve_offload_entry(
             data_offsets=offsets,
         )
 
+    filename = entry.get("filename")
+    if filename:
+        path = filename if os.path.isabs(filename) else os.path.join(module_dir, filename)
+        start = int(entry.get("offset", 0))
+        end = start + (_torch_dtype_num_bytes(resolved_dtype) * math.prod(shape or (1,)))
+        return OffloadTensorRef(
+            path=os.path.abspath(path),
+            dtype=resolved_dtype,
+            shape=shape,
+            format="dat",
+            weight_name=None,
+            data_offsets=(start, end),
+        )
+
     data_path = os.path.join(module_dir, f"{leaf}.dat")
     if not os.path.isfile(data_path):
         return None
@@ -1450,7 +1464,10 @@ def _write_shard_file(path: str, entries: List[TensorSource], metadata: Dict[str
             if isinstance(source, OffloadTensorRef):
                 if source.format == "dat":
                     # print("offload tensor io buffered transfer DAT")
-                    _copy_file_stream(source.path, out, entry.num_bytes)
+                    start = 0
+                    if source.data_offsets is not None:
+                        start = source.data_offsets[0]
+                    _copy_file_stream(source.path, out, entry.num_bytes, offset=start)
                 elif source.format == "safetensors" and source.data_offsets is not None:
                     # print("offload tensor io buffered transfer SAFETENSOR stream")
                     start, end = source.data_offsets
diff --git a/gptqmodel/utils/offload.py b/gptqmodel/utils/offload.py
index 37f7c765e..c84285d75 100644
--- a/gptqmodel/utils/offload.py
+++ b/gptqmodel/utils/offload.py
@@ -4,8 +4,10 @@
 # Contact: qubitium@modelcloud.ai, x.com/qubitium
 
 import contextlib
+import json
 import os
 import shutil
+import struct
 
 from threading import Lock
 from typing import Iterable, List, Optional, Set, Tuple
@@ -16,6 +18,7 @@
 from accelerate import disk_offload
 from accelerate.hooks import remove_hook_from_module, remove_hook_from_submodules
 from accelerate.utils import align_module_device, has_offloaded_params
+from safetensors.torch import save_file as safetensors_save_file
 from torch import nn
 
 from ..looper.named_module import NamedModule
@@ -65,6 +68,54 @@ def is_meta_module(m: nn.Module) -> bool:
 
 _OFFLOAD_LOCK = Lock()
 
+def _prepare_offload_directory(target_dir: str) -> None:
+    if os.path.isdir(target_dir):
+        shutil.rmtree(target_dir)
+    os.makedirs(target_dir, exist_ok=True)
+
+
+def _bundle_module_state_dict(module: nn.Module, offload_dir: str) -> dict:
+    bundle_path = os.path.join(offload_dir, "module.safetensors")
+    index: dict[str, dict] = {}
+    tensors: dict[str, torch.Tensor] = {}
+
+    with torch.inference_mode():
+        for key, tensor in module.state_dict().items():
+            cpu_tensor = tensor.detach().to("cpu")
+            tensors[key] = cpu_tensor.contiguous()
+            entry = {
+                "dtype": str(cpu_tensor.dtype).replace("torch.", ""),
+                "shape": list(cpu_tensor.shape),
+                "safetensors_file": os.path.abspath(bundle_path),
+                "weight_name": key,
+            }
+            index[key] = entry
+
+    safetensors_save_file(tensors, bundle_path)
+
+    with open(bundle_path, "rb") as fh:
+        header_len = struct.unpack("<Q", fh.read(8))[0]
[... truncated in source: remainder of this hunk, plus the header, imports, and _LinearWithBuffers helper of the new test file ...]
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.linear(x * self.mask_buffer.float()) * self.scale_buffer
+
+
+def _clone_state_dict(module: nn.Module) -> dict[str, torch.Tensor]:
+    return {k: v.detach().clone() for k, v in module.state_dict().items()}
+
+
+def test_offload_to_disk_writes_single_dat_file(tmp_path):
+    model = _LinearWithBuffers(in_features=128, out_features=96)
+    original_state = _clone_state_dict(model.linear)
+
+    offload_root = tmp_path / "offload_root"
+    offload_to_disk(module=model.linear, model=model, disk_path=str(offload_root))
+
+    module_dir = offload_root / "linear"
+    assert module_dir.is_dir(), "Expected per-module directory to exist"
+
+    files = sorted(module_dir.iterdir(), key=lambda p: p.name)
+    rows = [(path.name, path.stat().st_size) for path in files]
+    print(tabulate(rows, headers=["file", "bytes"], tablefmt="github"))
+
+    safetensor_files = [path for path in files if path.suffix == ".safetensors"]
+    assert len(safetensor_files) == 1, "offload_to_disk should produce exactly one safetensors file per module"
+    assert safetensor_files[0].name == "module.safetensors"
+
+    with open(module_dir / "index.json", encoding="utf-8") as fp:
+        index = json.load(fp)
+
+    expected_keys = set(model.linear.state_dict().keys())
+    assert set(index.keys()) == expected_keys
+    assert all(Path(entry.get("safetensors_file")).name == "module.safetensors" for entry in index.values())
+    assert all(entry.get("data_offsets") is not None for entry in index.values())
+
+    save_dir = tmp_path / "saved"
+    save_dir.mkdir()
+    state_dict = get_state_dict_for_save(model, offload_root=str(offload_root))
+    expected_files, tensor_to_filename, _ = streaming_state_dict_to_shards(
+        state_dict,
+        save_dir=str(save_dir),
+        model_base_name="model",
+        single_file_name="model.safetensors",
+        metadata={},
+        max_shard_size=None,
+    )
+
+    assert len(expected_files) == 1
+    shard_path = save_dir / expected_files[0]
+    with safe_open(str(shard_path), framework="pt", device="cpu") as handler:
+        for name, tensor in original_state.items():
+            saved = handler.get_tensor(f"linear.{name}")
+            torch.testing.assert_close(saved, tensor)
+
+    # Materialize the module back and ensure values match the snapshot captured before offload.
+    undo_offload_to_disk(model.linear, delete_offload_folders=False)
+    for name, tensor in model.linear.state_dict().items():
+        torch.testing.assert_close(tensor, original_state[name])
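
Note (illustrative sketch, not part of the patch): the truncated tail of _bundle_module_state_dict evidently derives the per-tensor data_offsets that the test asserts are present in index.json from the safetensors header it just wrote. The snippet below shows one way such absolute byte offsets can be recovered, assuming only the published safetensors layout (an 8-byte little-endian header length, then a JSON header whose per-tensor "data_offsets" are relative to the byte buffer that follows); the helper name read_bundle_offsets is made up for illustration and does not appear in the patch.

    import json
    import struct

    def read_bundle_offsets(bundle_path: str) -> dict:
        # Read the safetensors header: u64 little-endian length, then JSON metadata.
        with open(bundle_path, "rb") as fh:
            header_len = struct.unpack("<Q", fh.read(8))[0]
            header = json.loads(fh.read(header_len))
        data_start = 8 + header_len  # byte buffer begins right after the JSON header
        offsets = {}
        for name, meta in header.items():
            if name == "__metadata__":
                continue
            begin, end = meta["data_offsets"]  # offsets relative to the byte buffer
            offsets[name] = (data_start + begin, data_start + end)
        return offsets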