12 changes: 2 additions & 10 deletions gptqmodel/looper/eora_processor.py
@@ -3,6 +3,7 @@
# SPDX-License-Identifier: Apache-2.0
# Contact: qubitium@modelcloud.ai, x.com/qubitium

import contextlib
import copy
import time
from typing import Callable, Dict, Optional, Tuple
@@ -106,15 +107,6 @@ def tmp(module, input: Tuple[torch.Tensor, ...], output: torch.Tensor):
)
return tmp

def pre_process_streaming(self, module: NamedModule):
eigen_matrix = self.eigen_scaling_diag_matrix[module.name]
with torch_streamCtx(module.target_device_stream):
if eigen_matrix is not None:
self.eigen_scaling_diag_matrix[module.name] = eigen_matrix.to(device=module.target_device, non_blocking=True)

module.state["w_wq_diff"] = module.state["w_wq_diff"].to(device=module.target_device, non_blocking=True)
module.state["wq"] = module.state["wq"].to(device=module.target_device, non_blocking=True)

def process(self, module: NamedModule):
assert isinstance(module.adapter_cfg, Lora)

@@ -185,7 +177,7 @@ def process(self, module: NamedModule):
PROCESS_LOG_LAYER: module.layer_index,
PROCESS_LOG_MODULE: module.name,
PROCESS_LOG_TIME: f"{duration:.3f}",
PROCESS_LOG_FWD_TIME: f"{self.fwd_time:.3f}",
PROCESS_LOG_FWD_TIME: self.formatted_fwd_time(),
PROCESS_MAX_MEMORY: max_memory,
}

11 changes: 2 additions & 9 deletions gptqmodel/looper/gptq_processor.py
@@ -3,6 +3,7 @@
# SPDX-License-Identifier: Apache-2.0
# Contact: qubitium@modelcloud.ai, x.com/qubitium

import contextlib
import copy
import threading
from typing import Callable, Optional, Tuple
@@ -108,14 +109,6 @@ def tmp(module, inp: Tuple[torch.Tensor, ...], out: torch.Tensor):
del inp, out
return tmp

def pre_process_streaming(self, module: NamedModule):
g = self.tasks[module.name]
with torch_streamCtx(module.target_device_stream):
# log.debug(f"streaming module `{g.name}` to device = `{module.target_device}`")
if g.H is not None:
g.H = g.H.to(device=module.target_device, non_blocking=True)
g.module.weight.data = g.module.weight.data.to(device=module.target_device, non_blocking=True)

def process(self, module: NamedModule):
# Reset peak memory stats
#torch.cuda.reset_peak_memory_stats()
@@ -172,7 +165,7 @@ def process(self, module: NamedModule):
QUANT_LOG_NSAMPLES: f"{nsamples}",
QUANT_LOG_DAMP: f"{damp_percent:.5f}",
PROCESS_LOG_TIME: f"{duration:.3f}",
PROCESS_LOG_FWD_TIME: f"{self.fwd_time:.3f}",
PROCESS_LOG_FWD_TIME: self.formatted_fwd_time(),
PROCESS_MAX_MEMORY: get_max_memory(),
}

6 changes: 5 additions & 1 deletion gptqmodel/looper/loop_processor.py
@@ -277,6 +277,10 @@ def set_calibration_dataset(self, calibration_dataset):
def set_fwd_time(self, fwd_time: float):
self.fwd_time = fwd_time

def formatted_fwd_time(self) -> str:
fwd_time = self.fwd_time if self.fwd_time is not None else 0.0
return f"{fwd_time:.3f}"

# called first
def preprocess(self, module: NamedModule, **kwargs):
pass
@@ -353,4 +357,4 @@ def get_max_memory() -> str:
max_memory = f"{active_0:.2f}MB, {active_1:.2f}MB"
else:
max_memory = f"{active_0:.2f}MB"
return max_memory
return max_memory
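
A side note on the new `formatted_fwd_time()` helper: it guards against `self.fwd_time` still being `None` at log time (e.g. when no forward pass has run, as in the new `require_fwd=False` test further down), where the old `f"{self.fwd_time:.3f}"` would raise a TypeError. A minimal stand-alone sketch of that behavior, illustrative only and not the actual LoopProcessor class:

    from typing import Optional

    class _Sketch:
        """Illustrative stand-in for a processor that may never record a forward time."""
        def __init__(self) -> None:
            self.fwd_time: Optional[float] = None  # stays None when no forward pass runs

        def formatted_fwd_time(self) -> str:
            # Fall back to 0.0 so log formatting never fails on a None value
            fwd_time = self.fwd_time if self.fwd_time is not None else 0.0
            return f"{fwd_time:.3f}"

    s = _Sketch()
    print(s.formatted_fwd_time())  # "0.000" instead of a TypeError from f"{None:.3f}"
    s.fwd_time = 1.5
    print(s.formatted_fwd_time())  # "1.500"
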
36 changes: 13 additions & 23 deletions gptqmodel/looper/module_looper.py
@@ -197,20 +197,16 @@ def _clone_module_for_devices(self, module: torch.nn.Module, devices: List[torch
module_label = getattr(module, "full_name", module.__class__.__name__)
clone_timings = [] if DEBUG_ON else None

cleared_attrs = self._clear_non_picklable_state(module)
try:
for dev in devices:
start_ts = time.perf_counter() if DEBUG_ON else None
replica = copy.deepcopy(module)
replica = replica.to(dev)
replica.eval()
_rehome_module_to_device(replica, dev, move_parameters=False, move_buffers=True)
self._clear_non_picklable_state(replica)
clones[dev] = replica
if clone_timings is not None and start_ts is not None:
clone_timings.append((dev, time.perf_counter() - start_ts))
finally:
self._restore_non_picklable_state(cleared_attrs)
for dev in devices:
start_ts = time.perf_counter() if DEBUG_ON else None
replica = copy.deepcopy(module)
replica = replica.to(dev)
replica.eval()
_rehome_module_to_device(replica, dev, move_parameters=False, move_buffers=True)
clones[dev] = replica
if clone_timings is not None and start_ts is not None:
clone_timings.append((dev, time.perf_counter() - start_ts))

if clone_timings:
timing_str = ", ".join(f"{str(dev)}={duration * 1000:.2f}ms" for dev, duration in clone_timings)
log.debug(f"ModuleLooper: deepcopy {module_label} -> {timing_str}")
@@ -224,9 +220,6 @@ def maybe_clear(obj: torch.nn.Module):
if id(obj) in seen:
return
seen.add(id(obj))
if hasattr(obj, "target_device_stream"):
cleared.append((obj, getattr(obj, "target_device_stream")))
setattr(obj, "target_device_stream", None)

if isinstance(module, torch.nn.Module):
for sub in module.modules():
@@ -236,11 +229,6 @@ def maybe_clear(obj: torch.nn.Module):

return cleared

@staticmethod
def _restore_non_picklable_state(cleared):
for obj, value in cleared:
setattr(obj, "target_device_stream", value)

@staticmethod
def _forward_batch_worker(
module: torch.nn.Module,
@@ -828,7 +816,9 @@ def loop(self, fail_safe: bool = False, **kwargs):
for idx, (name, m) in enumerate(subset.items()):
is_last = (idx == subset_size - 1)

m.module.target_device, m.module.target_device_stream = device_next()
target_device = device_next()
m.target_device = target_device
m.module.target_device = target_device

# Wrap the processor hook with masking
if hasattr(subset[name], 'forward_hook'):
4 changes: 2 additions & 2 deletions gptqmodel/looper/named_module.py
@@ -28,8 +28,8 @@ def __init__(self, module: torch.nn.Module, name: str, full_name:str, layer_inde
self.layer_index = layer_index # layerid in a repeating layer, if in outside layer, this info may be fake

# some processing will move this module to target_device gptq, eora, etc
# self.target_device, self.target_device_stream = device_next()
self.target_device, self.target_device_stream = None, None
# self.target_device = device_next()
self.target_device = None


# persistent work state for named module (used by some LoopProcessors)
11 changes: 2 additions & 9 deletions gptqmodel/looper/qqq_processor.py
@@ -3,6 +3,7 @@
# SPDX-License-Identifier: Apache-2.0
# Contact: qubitium@modelcloud.ai, x.com/qubitium

import contextlib
import copy
from typing import Callable, Optional, Tuple

@@ -107,14 +108,6 @@ def tmp(_, inp: Tuple[torch.Tensor, ...], out: torch.Tensor):
q.add_batch(inp[0].data, out.data) # noqa: F821
return tmp

def pre_process_streaming(self, module: NamedModule):
q = self.tasks[module.name]
with torch_streamCtx(module.target_device_stream):
if q.H is not None:
q.H = q.H.to(device=module.target_device, non_blocking=True)
module.weight.data = module.weight.data.to(device=module.target_device, non_blocking=True)


def process(self, module: NamedModule):
self.pb.title(f"Quantizing {module.name} in layer ").draw()
qqq = self.tasks
@@ -156,7 +149,7 @@ def process(self, module: NamedModule):
QUANT_LOG_NSAMPLES: f"{nsamples}",
QUANT_LOG_DAMP: f"{damp_percent:.5f}",
PROCESS_LOG_TIME: f"{duration:.3f}",
PROCESS_LOG_FWD_TIME: f"{self.fwd_time:.3f}",
PROCESS_LOG_FWD_TIME: self.formatted_fwd_time(),
}

if self.qcfg.dynamic is not None:
2 changes: 1 addition & 1 deletion gptqmodel/quantization/gptq.py
@@ -74,7 +74,7 @@ def __init__(self, module: nn.Module, qcfg: Optional[QuantizeConfig] = None):
self.name = HF_OPTIMUM
self.module = module
# emulate NamedModule properties
self.module.target_device, self.module.target_device_stream = device_next()
self.module.target_device = device_next()

self._validate_module(self.module)

2 changes: 1 addition & 1 deletion gptqmodel/quantization/qqq.py
@@ -196,7 +196,7 @@ def __init__(self, module: nn.Module, qcfg: Optional[QuantizeConfig] = None):
self.name = HF_OPTIMUM
self.layer = module
# emulate NamedModule properties
self.layer.target_device, self.layer.target_device_stream = device_next()
self.layer.target_device = device_next()
self.dev = self.layer.weight.device

self._validate_module(self.layer)
7 changes: 3 additions & 4 deletions gptqmodel/utils/torch.py
@@ -217,14 +217,13 @@ def device_next_reset():
global NEXT_DEVICE_INDEX
NEXT_DEVICE_INDEX = 0

def device_next(balance_strategy: BalanceStrategy = DEFAULT_BALANCE_STRATEGY) -> (torch.device, Union[torch.cuda.Stream, torch.xpu.Stream]):
def device_next(balance_strategy: BalanceStrategy = DEFAULT_BALANCE_STRATEGY) -> torch.device:
global NEXT_DEVICE_INDEX

if len(ALL_DEVICES) <= 1:
return ALL_DEVICES[0], ALL_STREAMS[0]
return ALL_DEVICES[0]

device = ALL_DEVICES[NEXT_DEVICE_INDEX]
device_stream = ALL_STREAMS[NEXT_DEVICE_INDEX]
if NEXT_DEVICE_INDEX < len(ALL_DEVICES) - 1:
NEXT_DEVICE_INDEX += 1
else:
@@ -233,7 +232,7 @@ def device_next(balance_strategy: BalanceStrategy = DEFAULT_BALANCE_STRATEGY) ->
else:
NEXT_DEVICE_INDEX = 0

return (device, device_stream)
return device

def torch_streamCtx(stream: Union[torch.cuda.Stream, torch.xpu.Stream]) -> StreamContext:
return torch.cuda.stream(stream) if HAS_CUDA else torch.xpu.stream(stream)
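
For readers skimming the diff, a hedged stand-alone sketch of what `device_next()` now returns: a single `torch.device` chosen round-robin, with no paired stream. The snippet below is illustrative only; it omits the `BalanceStrategy` branch and builds `ALL_DEVICES` ad hoc, so it is not the library's actual `gptqmodel/utils/torch.py`.

    import torch

    # Illustrative device pool; the real module builds ALL_DEVICES elsewhere.
    ALL_DEVICES = (
        [torch.device(f"cuda:{i}") for i in range(torch.cuda.device_count())]
        if torch.cuda.is_available()
        else [torch.device("cpu")]
    )
    NEXT_DEVICE_INDEX = 0

    def device_next() -> torch.device:
        # Single-device setups always get device 0; otherwise rotate through the pool.
        global NEXT_DEVICE_INDEX
        if len(ALL_DEVICES) <= 1:
            return ALL_DEVICES[0]
        device = ALL_DEVICES[NEXT_DEVICE_INDEX]
        NEXT_DEVICE_INDEX = (NEXT_DEVICE_INDEX + 1) % len(ALL_DEVICES)
        return device

    # Callers now unpack a single value instead of a (device, stream) tuple:
    target_device = device_next()
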
84 changes: 84 additions & 0 deletions tests/test_gptq_device_ctx.py
@@ -0,0 +1,84 @@
import concurrent.futures
import os
import sys
from typing import Dict, List

import pytest
import torch

from gptqmodel.looper.gptq_processor import GPTQProcessor
from gptqmodel.looper.named_module import NamedModule
from gptqmodel.quantization.config import QuantizeConfig


def _dummy_prepare_dataset(*, calibration_dataset, calibration_dataset_concat_size, calibration_dataset_sort, batch_size):
return calibration_dataset


class _DummyProgressBar:
def title(self, _):
return self

def draw(self):
return None


def _is_free_threaded() -> bool:
gil_check = getattr(sys, "_is_gil_enabled", None)
if callable(gil_check):
return not gil_check()
env_value = os.environ.get("PYTHON_GIL", "1").lower()
return env_value in {"0", "false", "off"}


def _run_quant_on_device(device_index: int) -> torch.device:
torch.cuda.set_device(device_index)
target = torch.device(f"cuda:{device_index}")
module = torch.nn.Linear(8, 8, bias=False).to(target)
named = NamedModule(module, name=f"linear_{device_index}", full_name=f"model.layers.{device_index}.linear", layer_index=device_index)

qcfg = QuantizeConfig(mock_quantization=True, group_size=-1, desc_act=False)
processor = GPTQProcessor(
tokenizer=None,
qcfg=qcfg,
calibration=None,
prepare_dataset_func=_dummy_prepare_dataset,
calibration_concat_size=None,
calibration_sort=None,
batch_size=1,
logger_board="",
require_fwd=False,
calculate_w_wq_diff=False,
)
processor.pb = _DummyProgressBar()

processor.preprocess(named, fail_safe=False)
named.module.target_device = target

processor.process(named)

return named.weight.data.device


#@pytest.mark.cuda
def test_gptq_quantize_keeps_weight_on_assigned_device_multigpu_free_thread():
if not torch.cuda.is_available():
pytest.skip("CUDA required for multi-GPU device context test")

if torch.cuda.device_count() < 8:
pytest.skip("Requires at least 8 CUDA devices")

if sys.version_info < (3, 13):
pytest.skip("Requires Python 3.13 free-threading build")

if not _is_free_threaded():
pytest.skip("Requires PYTHON_GIL=0 (free-threading)")

device_indices: List[int] = list(range(8))

with concurrent.futures.ThreadPoolExecutor(max_workers=len(device_indices)) as pool:
futures = [pool.submit(_run_quant_on_device, idx) for idx in device_indices]
results: Dict[int, torch.device] = {idx: future.result() for idx, future in zip(device_indices, futures)}

for idx, device in results.items():
assert device.type == "cuda" and device.index == idx
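
The exact invocation is not part of this diff, but one hedged way to exercise the new test locally would be `PYTHON_GIL=0 python -m pytest tests/test_gptq_device_ctx.py`; every precondition check skips unless the interpreter is a free-threaded CPython 3.13 build and at least 8 CUDA devices are visible.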
1 change: 1 addition & 0 deletions tests/test_tf32_performance.py
@@ -3,6 +3,7 @@

from gptqmodel.utils.torch import tf32_disable_guard, tf32_enable_guard


try:
from tabulate import tabulate
except ImportError: # pragma: no cover