Skip to content
Merged

Cleanup #1997

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 36 additions & 6 deletions gptqmodel/looper/module_looper.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import threading
import time
from concurrent.futures import as_completed
from contextlib import nullcontext
from typing import Dict, List, Optional, TYPE_CHECKING

Expand Down Expand Up @@ -1241,25 +1242,54 @@ def _finalize_on_worker(process, module, idx, total, module_label, layer_idx):

finalize_futures_snapshot = list(finalize_futures)

def _drain_finalize_futures(futures, finalize_pb_local, finalize_count_local, progress_bar):
if finalize_futures_snapshot:
finalize_pb.title(
f"Submodule finalize 0/{finalize_count}"
).subtitle("Waiting for completions...").draw()

future_metadata = {
future: (module_label, process, layer_idx)
for future, _, module_label, process, layer_idx in finalize_futures_snapshot
}

def _drain_finalize_futures(
futures,
finalize_pb_local,
finalize_count_local,
future_metadata_local,
):
completed_local = 0
try:
for future, idx, module_label, process, layer_idx in futures:
for future in as_completed(futures):
module_label, process, layer_idx = future_metadata_local.get(
future, (None, None, None)
)

future.result()

layer_label = f"Layer {layer_idx}" if layer_idx is not None else "layer ?"
display_module = module_label or "<unnamed>"
subtitle = f"{process.name()}: {display_module}"
processor_name = process.name() if process is not None else "<processor>"
subtitle = f"{processor_name}: {display_module}"

completed_local += 1
finalize_pb_local.next()
finalize_pb_local.title(
f"{layer_label} Finalize {idx}/{finalize_count_local}"
).subtitle(subtitle).next().draw()
f"{layer_label} Finalize {completed_local}/{finalize_count_local}"
).subtitle(subtitle).draw()
finally:
finalize_pb_local.close()

if finalize_futures_snapshot:
# Drain finalize futures asynchronously so the main loop can continue scheduling work.
threading.Thread(
target=_drain_finalize_futures,
args=(finalize_futures_snapshot, finalize_pb, finalize_count, pb),
args=(
[future for future, *_ in finalize_futures_snapshot],
finalize_pb,
finalize_count,
future_metadata,
),
name="SubmoduleFinalizeWatcher",
daemon=True,
).start()
Expand Down
3 changes: 2 additions & 1 deletion gptqmodel/models/auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@
from .definitions.phi3 import Phi3QModel, PhiMoEGPTQForCausalLM # noqa: E402
from .definitions.phi4 import Phi4MMGPTQ # noqa: E402
from .definitions.qwen import QwenQModel # noqa: E402
from .definitions.qwen2 import Qwen2QModel # noqa: E402
from .definitions.qwen2_5_omni import Qwen2_5_OmniGPTQ
from .definitions.qwen2_5_vl import Qwen2_5_VLQModel # noqa: E402
from .definitions.qwen2_moe import Qwen2MoeQModel # noqa: E402
Expand Down Expand Up @@ -162,7 +163,7 @@
"stablelm": LlamaQModel, # 100% llama clone
"starcoder2": Starcoder2QModel,
"mixtral": MixtralQModel,
"qwen2": LlamaQModel, # 100% llama clone
"qwen2": Qwen2QModel,
"qwen3": Qwen3QModel,
"longllama": LlamaQModel, # 100% llama clone
"gemma": LlamaQModel, # 100% llama clone
Expand Down
1 change: 1 addition & 0 deletions gptqmodel/models/definitions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
from .phi import PhiQModel
from .phi3 import Phi3QModel
from .qwen import QwenQModel
from .qwen2 import Qwen2QModel
from .qwen2_5_vl import Qwen2_5_VLQModel
from .qwen2_moe import Qwen2MoeQModel
from .qwen2_vl import Qwen2VLQModel
Expand Down
12 changes: 12 additions & 0 deletions gptqmodel/models/definitions/qwen2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai
# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai
# SPDX-License-Identifier: Apache-2.0
# Contact: qubitium@modelcloud.ai, x.com/qubitium

from __future__ import annotations

from .llama import LlamaQModel


class Qwen2QModel(LlamaQModel):
pass
2 changes: 1 addition & 1 deletion gptqmodel/models/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def check_versions(model_class, requirements: List[str]):
def get_model_local_path(pretrained_model_id_or_path, **kwargs):
is_local = os.path.isdir(pretrained_model_id_or_path)
if is_local:
return pretrained_model_id_or_path
return os.path.normpath(pretrained_model_id_or_path)
else:
# Clone kwargs before modifying
download_kwargs = kwargs.copy()
Expand Down
2 changes: 1 addition & 1 deletion gptqmodel/nn_modules/qlinear/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -628,7 +628,7 @@ def _process_block(i0: int, i1: int):
# ---------- schedule blocks across a thread pool ----------
starts = list(range(0, in_features, block_in))
ranges = [(i0, min(i0 + block_in, in_features)) for i0 in starts]
total_blocks = len(ranges)
len(ranges)

# TODO FIX ME...threads safety issue with threaded block work
workers_eff = 1 # max(1, min(workers, total_blocks))
Expand Down
2 changes: 1 addition & 1 deletion gptqmodel/nn_modules/qlinear/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

class TorchQuantLinear(PackableQuantLinear):
SUPPORTS_BITS = [2, 3, 4, 8]
SUPPORTS_GROUP_SIZE = [-1, 16, 32, 64, 128]
SUPPORTS_GROUP_SIZE = [-1, 16, 32, 64, 128, 256, 512, 1024]
SUPPORTS_DESC_ACT = [True, False]
SUPPORTS_SYM = [True, False]
SUPPORTS_SHARDS = True
Expand Down
3 changes: 1 addition & 2 deletions gptqmodel/nn_modules/qlinear/tritonv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class TritonModuleMixin:

class TritonV2QuantLinear(TorchQuantLinear, TritonModuleMixin):
SUPPORTS_BITS = [2, 4, 8]
SUPPORTS_GROUP_SIZE = [-1, 16, 32, 64, 128]
SUPPORTS_GROUP_SIZE = [-1, 16, 32, 64, 128, 256, 512, 1024]
SUPPORTS_DESC_ACT = [True, False]
SUPPORTS_SYM = [True, False]
SUPPORTS_SHARDS = True
Expand Down Expand Up @@ -207,4 +207,3 @@ def triton_xpu_available():
except Exception:
return False


4 changes: 2 additions & 2 deletions gptqmodel/quantization/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,10 +293,10 @@ def __post_init__(self):
if key == "bits" and value not in fields_info[0].metadata["choices"]:
raise ValueError(f"QuantizeConfig: Layer `{layer}` only support quantization of `{fields_info[0].metadata['choices']}` bits.")
elif key == "group_size" and value != -1 and value <= 0:
raise ValueError("QuantizeConfig: `group_size` must in the value set of `[-1, 16, 32, 64, 128]`.")
raise ValueError("QuantizeConfig: `group_size` must be one of `[-1, 16, 32, 64, 128, 256, 512, 1024]`.")

if self.group_size != -1 and self.group_size <= 0:
raise ValueError("QuantizeConfig: `group_size` must in the value set of `[-1, 16, 32, 64, 128]`.")
raise ValueError("QuantizeConfig: `group_size` must be one of `[-1, 16, 32, 64, 128, 256, 512, 1024]`.")

if not (0 < self.damp_percent < 1):
raise ValueError("QuantizeConfig: `damp_percent` must between 0 and 1.")
Expand Down
3 changes: 1 addition & 2 deletions gptqmodel/utils/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,9 @@ def log_time_block(
if logger is None:
logger = setup_logger()

label = block_name if not module_name else f"{module_name}: {block_name}"
start = time.perf_counter()
try:
yield
finally:
duration = time.perf_counter() - start
time.perf_counter() - start
#logger.info(f"[time] {label} took {duration:.3f}s")
40 changes: 31 additions & 9 deletions gptqmodel/utils/looper_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,27 +164,49 @@ def maybe_clear(obj: torch.nn.Module):
return cleared


def _canonical_device(device: torch.device) -> torch.device:
"""Return a canonical form so indexless accelerators collapse to device:0."""
if device.type in {"cuda", "xpu", "npu"}:
index = device.index if device.index is not None else 0
return torch.device(f"{device.type}:{index}")
return device


def select_forward_devices(base_device: Optional[torch.device]) -> List[torch.device]:
if base_device is None:
return [CPU]

devices = [base_device]
base_type = base_device.type
if base_type in ("cuda", "xpu", "mps"):
devices: List[torch.device] = []
seen: set[tuple[str, int | None]] = set()

def _add(device: torch.device) -> None:
canonical = _canonical_device(device)
key = (canonical.type, canonical.index)
if key in seen:
return
seen.add(key)
devices.append(canonical)

_add(base_device)
base_type = devices[0].type
if base_type in {"cuda", "xpu", "mps", "npu"}:
for dev in ALL_DEVICES:
if dev.type == base_type and dev not in devices:
devices.append(dev)
if dev.type == base_type:
_add(dev)
return devices


def normalize_device_like(device_like) -> Optional[torch.device]:
if device_like is None:
return None
if isinstance(device_like, torch.device):
return device_like
if hasattr(device_like, "to_torch_device"):
return device_like.to_torch_device()
return torch.device(str(device_like))
device = device_like
elif hasattr(device_like, "to_torch_device"):
device = device_like.to_torch_device()
else:
device = torch.device(str(device_like))

return _canonical_device(device)


def clone_module_for_devices(
Expand Down
5 changes: 2 additions & 3 deletions gptqmodel/utils/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
from typing import Any, Dict, List, Optional, Tuple, Type, Union

import accelerate
import threadpoolctl as tctl
import torch
import torch.nn as nn
import transformers
Expand Down Expand Up @@ -57,7 +56,7 @@
from .ctx import ctx
from .device import get_device
from .importer import select_quant_linear
from .logger import setup_logger, log_time_block
from .logger import log_time_block, setup_logger
from .torch import HAS_CUDA, torch_empty_cache, torch_new_stream_ctx


Expand Down Expand Up @@ -597,7 +596,7 @@ def convert_gptq_v1_to_v2_format(

# Limit thread usage to avoid auto-parallizataion regression
# with tctl.threadpool_limits(limits=1):
t = time.time()
time.time()
log.info(
f"Format: Converting `{FORMAT_FIELD_CHECKPOINT}` from `{FORMAT.GPTQ}` to internal `{FORMAT.GPTQ_V2}`.")

Expand Down
23 changes: 23 additions & 0 deletions gptqmodel/utils/offload.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@
from .torch import CPU, META


_SMALL_MODULE_OFFLOAD_BYTES = 4 * 1024 # Skip disk writes for <4KB payloads


# Patch fix thread unsafe accelerate.utils.modeling.clear_device_cache
def _fake_clear_device_cache(garbage_collection=False):
pass
Expand Down Expand Up @@ -74,6 +77,14 @@ def _prepare_offload_directory(target_dir: str) -> None:
os.makedirs(target_dir, exist_ok=True)


def _tensor_nbytes(tensor: torch.Tensor) -> int:
try:
itemsize = tensor.element_size()
except RuntimeError:
itemsize = torch.empty((), dtype=tensor.dtype).element_size()
return tensor.numel() * itemsize


def _bundle_module_state_dict(module: nn.Module, offload_dir: str) -> dict:
bundle_path = os.path.join(offload_dir, "module.safetensors")
index: dict[str, dict] = {}
Expand Down Expand Up @@ -177,6 +188,18 @@ def _offload_disk(module: nn.Module, name: str, disk_path: str = "."):

module_offload_dir = os.path.join(disk_path, name)

total_bytes = 0
try:
state_items = module.state_dict().values()
except Exception:
state_items = []

for tensor in state_items:
total_bytes += _tensor_nbytes(tensor)

if total_bytes <= _SMALL_MODULE_OFFLOAD_BYTES:
return

_prepare_offload_directory(module_offload_dir)
_bundle_module_state_dict(module, module_offload_dir)

Expand Down
8 changes: 4 additions & 4 deletions tests/models/test_qwen2_5.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@

class TestQwen2_5(ModelTest):
NATIVE_MODEL_ID = "/monster/data/model/Qwen2.5-0.5B-Instruct"
QUANT_ARC_MAX_DELTA_FLOOR_PERCENT = 0.2
NATIVE_ARC_CHALLENGE_ACC = 0.2739
NATIVE_ARC_CHALLENGE_ACC_NORM = 0.3055
QUANT_ARC_MAX_DELTA_FLOOR_PERCENT = 0.05
NATIVE_ARC_CHALLENGE_ACC = 0.2705
NATIVE_ARC_CHALLENGE_ACC_NORM = 0.3063
TRUST_REMOTE_CODE = False
APPLY_CHAT_TEMPLATE = True
EVAL_BATCH_SIZE = 6
#EVAL_BATCH_SIZE = 6

def test_qwen2_5(self):
self.quant_lm_eval()
4 changes: 2 additions & 2 deletions tests/models/test_qwen3_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,9 @@ class TestQwen3Moe(ModelTest):
DEBUG = True
ACT_GROUP_AWARE = True
DESC_ACT = False
DATASET_SIZE = 1024
DATASET_SIZE = 2048
DATASET_SORT = "desc"
QUANT_BATCH_SIZE = 1
QUANT_BATCH_SIZE = 8
CALIB_NOISE_MODE = "unseen"
CALIB_NOISE_PERCENT = 0.025

Expand Down
1 change: 1 addition & 0 deletions tests/pytest.ini
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,4 @@ log_cli=true
norecursedirs = tasks evalplus_results
markers =
ci: CPU-only CI regression coverage for DeviceThreadPool affinity behaviour
cuda: Requires CUDA device