60 changes: 56 additions & 4 deletions gptqmodel/looper/module_looper.py
@@ -22,6 +22,7 @@
from typing import Dict, List, Optional

import torch
from torch.nn.parallel import replicate as torch_replicate

from ..looper.dequantize_processor import DequantizeProcessor
from ..looper.eora_processor import EoraProcessor
@@ -197,14 +198,63 @@ def _clone_module_for_devices(self, module: torch.nn.Module, devices: List[torch
module_label = getattr(module, "full_name", module.__class__.__name__)
clone_timings = [] if DEBUG_ON else None

if len(devices) == 1:
target_device = devices[0]
module = module.to(target_device)
module.eval()
_rehome_module_to_device(module, target_device, move_parameters=True, move_buffers=True)
self._clear_non_picklable_state(module)
setattr(module, "_gptqmodule_device_hint", target_device)
clones[target_device] = module
return clones

self._clear_non_picklable_state(module)
primary_device = devices[0]
replicate_devices = [dev for dev in devices if dev.type == primary_device.type]
use_replicate = (
len(replicate_devices) == len(devices)
and primary_device.type == "cuda"
and torch.cuda.is_available()
)

if use_replicate:
try:
module = module.to(primary_device)
module.eval()
_rehome_module_to_device(module, primary_device, move_parameters=True, move_buffers=True)
setattr(module, "_gptqmodule_device_hint", primary_device)
replicate_start = time.perf_counter() if DEBUG_ON else None
replicas = torch_replicate(module, devices)
if clone_timings is not None and replicate_start is not None:
clone_timings.append(("replicate", time.perf_counter() - replicate_start))

for dev, replica in zip(devices, replicas):
replica.eval()
_rehome_module_to_device(replica, dev, move_parameters=True, move_buffers=True)
self._clear_non_picklable_state(replica)
setattr(replica, "_gptqmodule_device_hint", dev)
clones[dev] = replica

if clone_timings:
timing_str = ", ".join(
f"{str(dev)}={duration * 1000:.2f}ms" for dev, duration in clone_timings
)
log.debug(f"ModuleLooper: replicate {module_label} -> {timing_str}")
return clones
except Exception as exc:
log.info(f"Clone: fast clone via torch replicate failed ({exc}); falling back to deepcopy")
# Fall back to deepcopy path if replicate is unsupported for this module.
if clone_timings is not None:
clone_timings.append(("replicate_failed", 0.0))

for dev in devices:
start_ts = time.perf_counter() if DEBUG_ON else None
replica = copy.deepcopy(module)
replica = replica.to(dev)
replica.eval()
_rehome_module_to_device(replica, dev, move_parameters=False, move_buffers=True)
_rehome_module_to_device(replica, dev, move_parameters=True, move_buffers=True)
self._clear_non_picklable_state(replica)
setattr(replica, "_gptqmodule_device_hint", dev)
clones[dev] = replica
if clone_timings is not None and start_ts is not None:
clone_timings.append((dev, time.perf_counter() - start_ts))
@@ -254,8 +304,8 @@ def _forward_batch_worker(
update. The thin signature keeps the function pickleable for the worker
queue.
"""
module_device = get_device(module)
_rehome_module_to_device(module, module_device, move_parameters=False, move_buffers=True)
module_device = getattr(module, "_gptqmodule_device_hint", None) or get_device(module)
_rehome_module_to_device(module, module_device, move_parameters=True, move_buffers=True)

inputs = [move_to(inp, device=module_device, stream=False) for inp in layer_input]

@@ -408,7 +458,7 @@ def _run_forward_batches_single(
if reuse_kv and prev_kv is not None:
additional_inputs["kv_last_layer"] = nested_move_to(prev_kv, device=cur_layer_device, stream=False)

_rehome_module_to_device(module, cur_layer_device, move_parameters=False, move_buffers=True)
_rehome_module_to_device(module, cur_layer_device, move_parameters=True, move_buffers=True)

module_output = None
try:
@@ -758,6 +808,8 @@ def loop(self, fail_safe: bool = False, **kwargs):
quant_modules_pb.title(f"Quantizing layer {layer_index} of {layer_count - 1}").draw()
module = layers[layer_index]

self.gptq_model.wait_for_turtle_reload()

if module.__class__.__name__.lower() == "MllamaCrossAttentionDecoderLayer".lower():
# TODO FIXME: currently we do not support quantizing the cross attention layer (pixel_values)
continue
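
For readers unfamiliar with torch.nn.parallel.replicate, the following is a minimal sketch (illustrative, not code from this PR) of the clone strategy the hunk above implements: replicate a module once across homogeneous CUDA devices and fall back to per-device copy.deepcopy when replication fails or does not apply. The helper name clone_to_devices and its exact checks are assumptions for the example.

# Minimal sketch of replicate-first cloning with a deepcopy fallback.
import copy
from typing import Dict, List

import torch
from torch.nn.parallel import replicate as torch_replicate


def clone_to_devices(module: torch.nn.Module, devices: List[torch.device]) -> Dict[torch.device, torch.nn.Module]:
    clones: Dict[torch.device, torch.nn.Module] = {}
    all_cuda = all(dev.type == "cuda" for dev in devices)
    if len(devices) > 1 and all_cuda and torch.cuda.is_available():
        try:
            module = module.to(devices[0]).eval()
            # replicate() broadcasts parameters/buffers once instead of running N deepcopies.
            for dev, replica in zip(devices, torch_replicate(module, devices)):
                clones[dev] = replica.eval()
            return clones
        except Exception:
            clones.clear()  # replicate unsupported for this module; use the slow path below
    for dev in devices:
        clones[dev] = copy.deepcopy(module).to(dev).eval()
    return clones
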
65 changes: 50 additions & 15 deletions gptqmodel/models/base.py
@@ -207,6 +207,8 @@ def __init__(
self._background_pool: Optional["DeviceThreadPool"] = None
self._turtle_reload_future: Optional[Future] = None
self._turtle_reload_lock = threading.Lock()
self._turtle_ready = threading.Event()
self._turtle_ready.set()

# compat: state to assist in checkpoint_format gptq(v1) to gptq_v2 conversion
self.qlinear_kernel = qlinear_kernel
@@ -1086,6 +1088,24 @@ def format_nodes(nodes):
def register_background_pool(self, pool: Optional["DeviceThreadPool"]) -> None:
self._background_pool = pool

def _wait_for_turtle_ready(self, timeout: Optional[float] = None) -> None:
if self.turtle_model is None:
return

self._turtle_ready.wait(timeout=timeout)

def _ensure_turtle_ready(self) -> None:
if self.turtle_model is None:
return

self._apply_completed_turtle_reload()

if self._turtle_ready.is_set():
return

self._wait_for_turtle_ready()
self._apply_completed_turtle_reload()

def _clone_model_init_kwargs(self, source: PreTrainedModel) -> Dict[str, Any]:
kwargs = getattr(source, "_model_init_kwargs", {}) or {}
if isinstance(kwargs, dict):
@@ -1112,13 +1132,16 @@ def _schedule_turtle_reload(self) -> None:
self._apply_completed_turtle_reload()

if self.turtle_model is None or self.model_local_path is None:
self._turtle_ready.set()
return

pool = self._background_pool
if pool is None:
self._turtle_ready.clear()
new_model = self._reload_turtle_model_sync()
if new_model is not None:
self.turtle_model = new_model
self._turtle_ready.set()
return

with self._turtle_reload_lock:
@@ -1132,25 +1155,35 @@ def _schedule_turtle_reload(self) -> None:
loader = self.loader

def _reload_task():
model = loader.from_pretrained(
model_local_path,
config=config,
low_cpu_mem_usage=True,
**reload_kwargs,
)
model._model_init_kwargs = reload_kwargs
return model
try:
model = loader.from_pretrained(
model_local_path,
config=config,
low_cpu_mem_usage=True,
**reload_kwargs,
)
model._model_init_kwargs = reload_kwargs
return model
finally:
self._turtle_ready.set()

self._turtle_ready.clear()

try:
self._turtle_reload_future = pool.submit(CPU, _reload_task)
future = pool.submit(CPU, _reload_task)
except Exception as exc:
log.warning("Turtle reload scheduling failed; falling back to sync reload: %s", exc)
self._turtle_reload_future = None
else:
self._turtle_reload_future = future
return

if self._turtle_reload_future is None:
new_model = self._reload_turtle_model_sync()
if new_model is not None:
self.turtle_model = new_model
# Fallback: synchronous reload (either pool missing or submit failed)
self._turtle_ready.clear()
new_model = self._reload_turtle_model_sync()
if new_model is not None:
self.turtle_model = new_model
self._turtle_ready.set()

def _apply_completed_turtle_reload(self, *, wait: bool = False) -> None:
"""Adopt the result of any finished background turtle reload."""
@@ -1181,15 +1214,17 @@ def _apply_completed_turtle_reload(self, *, wait: bool = False) -> None:
self._turtle_reload_future = None
if new_model is not None:
self.turtle_model = new_model
self._turtle_ready.set()
return

with self._turtle_reload_lock:
if self._turtle_reload_future is future:
self.turtle_model = new_model
self._turtle_reload_future = None
self._turtle_ready.set()

def wait_for_turtle_reload(self) -> None:
self._apply_completed_turtle_reload(wait=True)
self._ensure_turtle_ready()

# transfer the actually materialized module from turtle (real) to shell
def shell_module_materialize(
Expand All @@ -1198,7 +1233,7 @@ def shell_module_materialize(
device: torch.device,
non_blocking: bool = False,
) -> torch.nn.Module:
self._apply_completed_turtle_reload()
self._ensure_turtle_ready()

if self.turtle_model is None:
if get_device(target_submodule) != device:
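
The _turtle_ready event added above follows a common gate pattern: clear the event when a background reload is scheduled, set it in a finally block once the reload finishes (or fails), and have consumers wait on it before touching the model. Below is a minimal, self-contained sketch of that pattern; the class and method names are illustrative, not GPTQModel's API.

# Illustrative sketch of event-gated background reloading.
import threading
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, Optional


class ReloadGate:
    def __init__(self) -> None:
        self._ready = threading.Event()
        self._ready.set()  # nothing in flight initially
        self._pool = ThreadPoolExecutor(max_workers=1)
        self.model = None

    def schedule_reload(self, load_fn: Callable[[], object]) -> None:
        self._ready.clear()  # block consumers until the reload completes

        def _task() -> None:
            try:
                self.model = load_fn()
            finally:
                self._ready.set()  # always release waiters, even if loading failed

        self._pool.submit(_task)

    def wait_until_ready(self, timeout: Optional[float] = None) -> None:
        self._ready.wait(timeout=timeout)
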
154 changes: 154 additions & 0 deletions tests/test_torch_replicate.py
@@ -0,0 +1,154 @@
import copy
from typing import Callable

import pytest
import torch
from tabulate import tabulate
from torch.nn.parallel import replicate as torch_replicate


TIMED_TRIALS = 5
WARMUP_TRIALS = 1


def _build_template_module() -> torch.nn.Module:
torch.manual_seed(0)
return torch.nn.Sequential(
torch.nn.Linear(4096, 4096, bias=False),
torch.nn.GELU(),
torch.nn.Linear(4096, 4096, bias=False),
)


def _replicate_strategy(module: torch.nn.Module, devices: list[torch.device]) -> list[torch.nn.Module]:
return torch_replicate(module, devices)


def _deepcopy_strategy(module: torch.nn.Module, devices: list[torch.device]) -> list[torch.nn.Module]:
clones = []
for dev in devices:
replica = copy.deepcopy(module)
clones.append(replica.to(dev))
return clones


def _benchmark(
strategy: Callable[[torch.nn.Module, list[torch.device]], list[torch.nn.Module]],
devices: list[torch.device],
template: torch.nn.Module,
*,
trials: int = TIMED_TRIALS,
warmup: int = WARMUP_TRIALS,
) -> tuple[list[float], list[int]]:
times: list[float] = []
mems: list[int] = []

def _run(record: bool) -> None:
module = copy.deepcopy(template).to(devices[0]).eval()
torch.cuda.synchronize()

baselines = {}
for dev in devices:
baselines[dev] = torch.cuda.memory_allocated(dev)
torch.cuda.reset_peak_memory_stats(dev)

start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

start_event.record()
clones = strategy(module, devices)
end_event.record()
torch.cuda.synchronize()

if record:
duration = start_event.elapsed_time(end_event) / 1000.0
extra_mem = 0
for dev in devices:
peak = torch.cuda.max_memory_allocated(dev)
extra_mem += max(0, peak - baselines[dev])

times.append(duration)
mems.append(extra_mem)

del clones
del module
torch.cuda.empty_cache()
torch.cuda.synchronize()

for _ in range(warmup):
_run(record=False)
for _ in range(trials):
_run(record=True)

return times, mems


def _summarise_metrics(times: list[float], mems: list[int]):
avg_time = sum(times) / len(times)
avg_mem = sum(mems) / len(mems)
return {
"time_avg": avg_time,
"time_min": min(times),
"time_max": max(times),
"mem_avg": avg_mem,
"mem_min": min(mems),
"mem_max": max(mems),
}


@pytest.mark.cuda
def test_torch_replicate():
if not torch.cuda.is_available() or torch.cuda.device_count() < 2:
pytest.skip("torch.nn.parallel.replicate comparison requires at least two CUDA devices")

devices = [torch.device(f"cuda:{idx}") for idx in range(2)]
template = _build_template_module()

replicate_times, replicate_mems = _benchmark(_replicate_strategy, devices, template)
deepcopy_times, deepcopy_mems = _benchmark(_deepcopy_strategy, devices, template)

replicate_summary = _summarise_metrics(replicate_times, replicate_mems)
deepcopy_summary = _summarise_metrics(deepcopy_times, deepcopy_mems)

table = [
[
"replicate",
replicate_summary["time_avg"],
replicate_summary["time_min"],
replicate_summary["time_max"],
replicate_summary["mem_avg"] / (1024 ** 2),
replicate_summary["mem_min"] / (1024 ** 2),
replicate_summary["mem_max"] / (1024 ** 2),
],
[
"deepcopy",
deepcopy_summary["time_avg"],
deepcopy_summary["time_min"],
deepcopy_summary["time_max"],
deepcopy_summary["mem_avg"] / (1024 ** 2),
deepcopy_summary["mem_min"] / (1024 ** 2),
deepcopy_summary["mem_max"] / (1024 ** 2),
],
]

headers = [
"strategy",
"time_avg_s",
"time_min_s",
"time_max_s",
"mem_avg_MB",
"mem_min_MB",
"mem_max_MB",
]

print(tabulate(table, headers=headers, floatfmt=".4f"))

assert replicate_summary["time_avg"] <= deepcopy_summary["time_avg"], (
"replicate slower than deepcopy: "
f"replicate={replicate_summary['time_avg']:.4f}s, deepcopy={deepcopy_summary['time_avg']:.4f}s"
)
assert replicate_summary["mem_avg"] <= deepcopy_summary["mem_avg"], (
"replicate used more memory: "
f"replicate={replicate_summary['mem_avg'] / (1024 ** 2):.1f}MB, "
f"deepcopy={deepcopy_summary['mem_avg'] / (1024 ** 2):.1f}MB"
)