50 changes: 35 additions & 15 deletions gptqmodel/looper/loop_processor.py
@@ -4,6 +4,7 @@
# Contact: qubitium@modelcloud.ai, x.com/qubitium
import json
import queue
import re
import threading
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional, Set, Tuple
@@ -26,6 +27,8 @@
# global level lock
PROCESSOR_GLOBAL_LOCK = threading.Lock()

ANSI_ESCAPE_RE = re.compile(r"\x1B\[[0-?]*[ -/]*[@-~]")

# LoopProcessor is a singleton(), not per module instance
class LoopProcessor:
def __init__(
@@ -191,41 +194,58 @@ def loss_color(self, loss_value):
else:
return "\033[91m" # Red

def _strip_ansi(self, text: Any) -> str:
return ANSI_ESCAPE_RE.sub("", str(text))

def _visible_length(self, text: Any) -> int:
return len(self._strip_ansi(text))

def _ljust_visible(self, text: str, width: int) -> str:
padding = max(width - self._visible_length(text), 0)
if padding:
return f"{text}{' ' * padding}"
return text

def log_new_row(self, stat):
self.log_call_count += 1
self.log_save_async(stat) # write to temp log file

# Update max_widths with the new row's column widths
for key, value in stat.items():
current_width = max(len(str(key)), len(str(value))) + 4 # 4 is for padding
key_str = str(key)
value_str = str(value)
current_width = max(self._visible_length(key_str), self._visible_length(value_str)) + 4 # 4 is for padding
if key not in self.log_max_widths or current_width > self.log_max_widths[key]:
self.log_max_widths[key] = current_width

if self.log_call_count % 20 == 1:
# Format the header row
header_row = "| " + " | ".join(
str(key).ljust(self.log_max_widths[key]) for key in self.log_max_widths.keys()) + " |"
header_cells = [
self._ljust_visible(str(key), self.log_max_widths[key]) for key in self.log_max_widths.keys()
]
header_row = "| " + " | ".join(header_cells) + " |"
header_separator = "-" * self._visible_length(header_row)

if self.log_call_count == 1:
log.info(len(header_row) * "-")
log.info(header_separator)
log.info(header_row)
log.info(len(header_row) * "-")
log.info(header_separator)

formatted_row = "| "
row_cells = []
for key in self.log_max_widths.keys():
value = stat.get(key, "")
if key == "loss":
color_code = self.loss_color(float(value))
formatted_value = f"{color_code}{value}\033[0m"
value_str = str(value)
if key == "loss" and value_str:
color_code = self.loss_color(float(value_str))
formatted_value = f"{color_code}{value_str}\033[0m"
else:
formatted_value = str(value)
formatted_row += formatted_value.ljust(self.log_max_widths[key]) + " | "

# formatted_row = "| " + " | ".join(
# str(stat.get(key, "")).ljust(self.log_max_widths[key]) for key in self.log_max_widths.keys()) + " |"
formatted_value = value_str
row_cells.append(self._ljust_visible(formatted_value, self.log_max_widths[key]))

formatted_row = "| " + " | ".join(row_cells) + " |"
row_separator = "-" * self._visible_length(formatted_row)
log.info(formatted_row)
log.info(len(formatted_row) * "-")
log.info(row_separator)

def _init_device_smi_handles(self) -> Dict[str, Device]:
handles: Dict[str, Device] = {}
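For reference, a minimal standalone sketch (not part of the diff) of why the new ANSI-aware helpers are needed: str.ljust counts escape codes toward the string length, so colorized cells get under-padded and table columns drift, while padding by visible length keeps them aligned.

import re

# Same pattern as ANSI_ESCAPE_RE above: strips CSI sequences such as "\033[92m".
ANSI_RE = re.compile(r"\x1B\[[0-?]*[ -/]*[@-~]")

def visible_length(text: str) -> int:
    return len(ANSI_RE.sub("", text))

def ljust_visible(text: str, width: int) -> str:
    # Pad to the on-screen width, not the raw (escape-inflated) length.
    return text + " " * max(width - visible_length(text), 0)

plain = "0.0123"
green = f"\033[92m{plain}\033[0m"

print(len(green))                                 # 15 -> escape codes inflate len()
print(len(green.ljust(10)))                       # 15 -> str.ljust adds no padding at all
print(visible_length(ljust_visible(green, 10)))   # 10 -> matches uncolored cells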
60 changes: 40 additions & 20 deletions gptqmodel/looper/module_looper.py
@@ -925,37 +925,57 @@ def _process_on_worker(proc: LoopProcessor, nm: NamedModule):
torch_sync()

# Gather finalize tasks (can offload to disk); run them via the pool
finalize_futures = []
finalize_tasks = []

for reverse_p in reversed(self.processors):
for name in processed_subset:
@torch.inference_mode()
def finalize_module(process, module):
process.submodule_finalize(module, self.gptq_model)
for module in processed_subset.values():
target_dev = get_device_new(module, recursive=True, assert_mode=True, expected="cpu")
module_label = getattr(module, "full_name", getattr(module, "name", ""))
finalize_tasks.append((reverse_p, module, module_label, target_dev))

# Disk offload (lifecycle TODO note preserved)
if isinstance(process, (GPTQProcessor, QQQProcessor, AWQProcessor)):
offload_to_disk(
model=self.gptq_model.model,
module=self.gptq_model.model.get_submodule(module.full_name),
disk_path=self.gptq_model.quantize_config.offload_to_disk_path,
)
finalize_count = len(finalize_tasks)
if finalize_count:
quant_modules_pb.subtitle(
f"Finalizing submodules ({finalize_count})"
).draw()

module = processed_subset[name]
finalize_futures = []

target_dev = get_device_new(module, recursive=True, assert_mode=True, expected="cpu")
@torch.inference_mode()
def _finalize_on_worker(process, module, idx, total, module_label):
quant_modules_pb.subtitle(
f"{process.name()}: finalizing {idx}/{total} ({module_label})"
).draw()

process.submodule_finalize(module, self.gptq_model)

# Disk offload (lifecycle TODO note preserved)
if isinstance(process, (GPTQProcessor, QQQProcessor, AWQProcessor)):
offload_to_disk(
model=self.gptq_model.model,
module=self.gptq_model.model.get_submodule(module.full_name),
disk_path=self.gptq_model.quantize_config.offload_to_disk_path,
)

# Submit on the module's device thread (safe & deterministic)
finalize_futures.append(
DEVICE_THREAD_POOL.submit_serial(
target_dev, finalize_module, reverse_p, module
)
for index, (process, module, module_label, target_dev) in enumerate(finalize_tasks, start=1):
finalize_futures.append(
DEVICE_THREAD_POOL.submit_serial(
target_dev,
_finalize_on_worker,
process,
module,
index,
finalize_count,
module_label,
)
)

# If any finalize tasks were queued, wait for them
for fut in finalize_futures:
fut.result()

if finalize_count:
quant_modules_pb.subtitle("").draw()

# LifeCycle: All sub-modules have finalized meaning quantization work is complete
# Ensure ANY remaining tasks the looper submitted have drained
DEVICE_THREAD_POOL.wait() # same as wait('all')
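The finalize tasks above are queued with DEVICE_THREAD_POOL.submit_serial so that work bound to one device runs in submission order on that device's worker. A rough sketch of that pattern, assuming "one serial queue per device key" semantics (an assumption for illustration, not the pool's actual implementation):

from concurrent.futures import Future, ThreadPoolExecutor

class SerialDevicePool:
    # Illustrative only: one single-worker executor per device key, so tasks
    # submitted for the same device execute in submission order.
    def __init__(self):
        self._executors: dict[str, ThreadPoolExecutor] = {}

    def submit_serial(self, device: str, fn, *args, **kwargs) -> Future:
        ex = self._executors.setdefault(device, ThreadPoolExecutor(max_workers=1))
        return ex.submit(fn, *args, **kwargs)

    def wait(self) -> None:
        for ex in self._executors.values():
            ex.shutdown(wait=True)
        self._executors.clear()

pool = SerialDevicePool()
futures = [pool.submit_serial("cpu", print, f"finalize {i}/3") for i in range(1, 4)]
for fut in futures:
    fut.result()
pool.wait()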
19 changes: 18 additions & 1 deletion gptqmodel/utils/model.py
@@ -1261,6 +1261,20 @@ def _resolve_offload_entry(
data_offsets=offsets,
)

filename = entry.get("filename")
if filename:
path = filename if os.path.isabs(filename) else os.path.join(module_dir, filename)
start = int(entry.get("offset", 0))
end = start + (_torch_dtype_num_bytes(resolved_dtype) * math.prod(shape or (1,)))
return OffloadTensorRef(
path=os.path.abspath(path),
dtype=resolved_dtype,
shape=shape,
format="dat",
weight_name=None,
data_offsets=(start, end),
)

data_path = os.path.join(module_dir, f"{leaf}.dat")
if not os.path.isfile(data_path):
return None
@@ -1450,7 +1464,10 @@ def _write_shard_file(path: str, entries: List[TensorSource], metadata: Dict[str
if isinstance(source, OffloadTensorRef):
if source.format == "dat":
# print("offload tensor io buffered transfer DAT")
_copy_file_stream(source.path, out, entry.num_bytes)
start = 0
if source.data_offsets is not None:
start = source.data_offsets[0]
_copy_file_stream(source.path, out, entry.num_bytes, offset=start)
elif source.format == "safetensors" and source.data_offsets is not None:
# print("offload tensor io buffered transfer SAFETENSOR stream")
start, end = source.data_offsets
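The offset= argument added above implies _copy_file_stream can start reading mid-file; a minimal sketch of such a helper (the actual implementation in model.py may differ):

def copy_file_stream(src_path: str, out, num_bytes: int, offset: int = 0,
                     chunk_size: int = 8 * 1024 * 1024) -> None:
    # Stream num_bytes starting at byte `offset` from src_path into the already-open
    # writable file object `out`, in bounded chunks so memory use stays flat.
    remaining = num_bytes
    with open(src_path, "rb") as src:
        src.seek(offset)
        while remaining > 0:
            chunk = src.read(min(chunk_size, remaining))
            if not chunk:
                raise IOError(f"unexpected EOF while copying from {src_path}")
            out.write(chunk)
            remaining -= len(chunk)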
61 changes: 58 additions & 3 deletions gptqmodel/utils/offload.py
@@ -4,8 +4,10 @@
# Contact: qubitium@modelcloud.ai, x.com/qubitium

import contextlib
import json
import os
import shutil
import struct
from threading import Lock
from typing import Iterable, List, Optional, Set, Tuple

@@ -16,6 +18,7 @@
from accelerate import disk_offload
from accelerate.hooks import remove_hook_from_module, remove_hook_from_submodules
from accelerate.utils import align_module_device, has_offloaded_params
from safetensors.torch import save_file as safetensors_save_file
from torch import nn

from ..looper.named_module import NamedModule
@@ -65,6 +68,54 @@ def is_meta_module(m: nn.Module) -> bool:
_OFFLOAD_LOCK = Lock()


def _prepare_offload_directory(target_dir: str) -> None:
if os.path.isdir(target_dir):
shutil.rmtree(target_dir)
os.makedirs(target_dir, exist_ok=True)


def _bundle_module_state_dict(module: nn.Module, offload_dir: str) -> dict:
bundle_path = os.path.join(offload_dir, "module.safetensors")
index: dict[str, dict] = {}
tensors: dict[str, torch.Tensor] = {}

with torch.inference_mode():
for key, tensor in module.state_dict().items():
cpu_tensor = tensor.detach().to("cpu")
tensors[key] = cpu_tensor.contiguous()
entry = {
"dtype": str(cpu_tensor.dtype).replace("torch.", ""),
"shape": list(cpu_tensor.shape),
"safetensors_file": os.path.abspath(bundle_path),
"weight_name": key,
}
index[key] = entry

safetensors_save_file(tensors, bundle_path)

with open(bundle_path, "rb") as fh:
header_len = struct.unpack("<Q", fh.read(8))[0]
header = json.loads(fh.read(header_len).decode("utf-8"))
data_offset_base = fh.tell()

for key, tensor_meta in header.items():
if key == "__metadata__":
continue
entry = index.get(key)
if entry is None:
continue
offsets = tensor_meta.get("data_offsets")
if offsets is not None:
start, end = (int(offsets[0]), int(offsets[1]))
entry["data_offsets"] = [data_offset_base + start, data_offset_base + end]

index_path = os.path.join(offload_dir, "index.json")
with open(index_path, "w", encoding="utf-8") as fp:
json.dump(index, fp, indent=2)

return index


def offload_to_disk(module: List[str] | nn.Module, model: nn.Module, disk_path: str = "."):
with _OFFLOAD_LOCK:
_offload_to_disk_impl(module=module, model=model, disk_path=disk_path)
@@ -119,11 +170,15 @@ def _offload_disk(module: nn.Module, name: str, disk_path: str = "."):
if not has_params and not has_buffers:
return

module_offload_dir = os.path.join(disk_path, name)

_prepare_offload_directory(module_offload_dir)
_bundle_module_state_dict(module, module_offload_dir)

_ = disk_offload(
module,
# device_map={ "" : "disk" }, # only touch this subtree
offload_dir=f"{disk_path}/{name}",
offload_buffers=True, # needed for buffers
offload_dir=module_offload_dir,
offload_buffers=True,
execution_device=m_device,
)

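Because _bundle_module_state_dict stores absolute byte ranges (the 8-byte length prefix plus the JSON header size are added to each tensor's data_offsets), a consumer can slice tensors straight out of module.safetensors. A small sketch of reading one entry back, assuming numpy-compatible dtypes such as float32 (bfloat16 would need torch-side decoding):

import json

import numpy as np
import torch

def load_tensor_from_bundle(offload_dir: str, key: str) -> torch.Tensor:
    # index.json records absolute offsets into module.safetensors, header included,
    # so the raw bytes can be sliced without going through a safetensors reader.
    with open(f"{offload_dir}/index.json", encoding="utf-8") as fp:
        entry = json.load(fp)[key]
    start, end = entry["data_offsets"]
    with open(entry["safetensors_file"], "rb") as fh:
        fh.seek(start)
        raw = fh.read(end - start)
    array = np.frombuffer(raw, dtype=np.dtype(entry["dtype"])).reshape(entry["shape"])
    return torch.from_numpy(array.copy())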
81 changes: 81 additions & 0 deletions tests/test_offload_files.py
@@ -0,0 +1,81 @@
# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai
# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai
# SPDX-License-Identifier: Apache-2.0
# Contact: qubitium@modelcloud.ai, x.com/qubitium

import json
from pathlib import Path

import torch
from safetensors import safe_open
from tabulate import tabulate
from torch import nn

from gptqmodel.utils.model import get_state_dict_for_save, streaming_state_dict_to_shards
from gptqmodel.utils.offload import offload_to_disk, undo_offload_to_disk


class _LinearWithBuffers(nn.Module):
def __init__(self, in_features: int, out_features: int):
super().__init__()
self.linear = nn.Linear(in_features, out_features, bias=True)
self.register_buffer("scale_buffer", torch.linspace(0.0, 1.0, out_features))
self.register_buffer("mask_buffer", torch.randint(0, 2, (out_features, in_features)).bool())

def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.linear(x * self.mask_buffer.float()) * self.scale_buffer


def _clone_state_dict(module: nn.Module) -> dict[str, torch.Tensor]:
return {k: v.detach().clone() for k, v in module.state_dict().items()}


def test_offload_to_disk_writes_single_dat_file(tmp_path):
model = _LinearWithBuffers(in_features=128, out_features=96)
original_state = _clone_state_dict(model.linear)

offload_root = tmp_path / "offload_root"
offload_to_disk(module=model.linear, model=model, disk_path=str(offload_root))

module_dir = offload_root / "linear"
assert module_dir.is_dir(), "Expected per-module directory to exist"

files = sorted(module_dir.iterdir(), key=lambda p: p.name)
rows = [(path.name, path.stat().st_size) for path in files]
print(tabulate(rows, headers=["file", "bytes"], tablefmt="github"))

safetensor_files = [path for path in files if path.suffix == ".safetensors"]
assert len(safetensor_files) == 1, "offload_to_disk should produce exactly one safetensors file per module"
assert safetensor_files[0].name == "module.safetensors"

with open(module_dir / "index.json", encoding="utf-8") as fp:
index = json.load(fp)

expected_keys = set(model.linear.state_dict().keys())
assert set(index.keys()) == expected_keys
assert all(Path(entry.get("safetensors_file")).name == "module.safetensors" for entry in index.values())
assert all(entry.get("data_offsets") is not None for entry in index.values())

save_dir = tmp_path / "saved"
save_dir.mkdir()
state_dict = get_state_dict_for_save(model, offload_root=str(offload_root))
expected_files, tensor_to_filename, _ = streaming_state_dict_to_shards(
state_dict,
save_dir=str(save_dir),
model_base_name="model",
single_file_name="model.safetensors",
metadata={},
max_shard_size=None,
)

assert len(expected_files) == 1
shard_path = save_dir / expected_files[0]
with safe_open(str(shard_path), framework="pt", device="cpu") as handler:
for name, tensor in original_state.items():
saved = handler.get_tensor(f"linear.{name}")
torch.testing.assert_close(saved, tensor)

# Materialize the module back and ensure values match the snapshot captured before offload.
undo_offload_to_disk(model.linear, delete_offload_folders=False)
for name, tensor in model.linear.state_dict().items():
torch.testing.assert_close(tensor, original_state[name])