Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions gptqmodel/models/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from ..quantization.config import FORMAT, METHOD, MIN_VERSION_WITH_V2
from ..utils.backend import BACKEND
from ..utils.importer import auto_select_device, normalize_device_device_map, select_quant_linear
from ..utils.inspect import safe_kwargs_call
from ..utils.logger import setup_logger
from ..utils.machete import _validate_machete_device_support
from ..utils.marlin import _validate_marlin_device_support
Expand All @@ -56,6 +57,8 @@
log = setup_logger()

ATTN_IMPLEMENTATION = "attn_implementation"


def parse_version_string(version_str: str):
try:
return Version(version_str)
Expand Down Expand Up @@ -105,12 +108,15 @@ def get_model_local_path(pretrained_model_id_or_path, **kwargs):
is_local = os.path.isdir(pretrained_model_id_or_path)
if is_local:
return os.path.normpath(pretrained_model_id_or_path)
else:
# Clone kwargs before modifying
download_kwargs = kwargs.copy()
download_kwargs.pop("attn_implementation", None)
download_kwargs.pop("use_flash_attention_2", None)
return snapshot_download(pretrained_model_id_or_path, **download_kwargs)
def _log_removed(removed: list[str]):
log.debug("Loader: dropping unsupported snapshot_download kwargs: %s", ", ".join(removed))

return safe_kwargs_call(
snapshot_download,
pretrained_model_id_or_path,
kwargs=kwargs,
on_removed=_log_removed,
)


def ModelLoader(cls):
Expand Down
4 changes: 2 additions & 2 deletions gptqmodel/models/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,13 @@
META_FIELD_ACT_GROUP_AWARE,
META_FIELD_DAMP_AUTO_INCREMENT,
META_FIELD_DAMP_PERCENT,
META_FIELD_GPTAQ_ALPHA,
META_FIELD_GPTAQ_ENABLED,
META_FIELD_MSE,
META_FIELD_QUANTIZER,
META_FIELD_STATIC_GROUPS,
META_FIELD_TRUE_SEQUENTIAL,
META_FIELD_URI,
META_FIELD_GPTAQ_ALPHA,
META_FIELD_GPTAQ_ENABLED,
META_QUANTIZER_GPTQMODEL,
META_VALUE_URI,
MIN_VERSION_WITH_V2,
Expand Down
64 changes: 64 additions & 0 deletions gptqmodel/utils/inspect.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai
# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai
# SPDX-License-Identifier: Apache-2.0
# Contact: qubitium@modelcloud.ai, x.com/qubitium

from __future__ import annotations

import inspect
from functools import lru_cache
from typing import Any, Callable, FrozenSet, Optional, Tuple


SupportedKwargInfo = Tuple[bool, Optional[FrozenSet[str]]]


@lru_cache(maxsize=None)
def get_supported_kwargs(callable_obj: Callable) -> SupportedKwargInfo:
"""Return (accepts_var_kwargs, allowed_kwargs) for a callable.

allowed_kwargs is None when the callable uses ``**kwargs`` or when inspection fails.
"""
try:
signature = inspect.signature(callable_obj)
except (TypeError, ValueError):
return True, None

if any(param.kind == inspect.Parameter.VAR_KEYWORD for param in signature.parameters.values()):
return True, None

allowed = frozenset(
name
for name, param in signature.parameters.items()
if param.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
)
return False, allowed


def safe_kwargs_call(
callable_obj: Callable,
*args: Any,
kwargs: Optional[dict] = None,
on_removed: Optional[Callable[[list[str]], None]] = None,
):
"""Invoke ``callable_obj`` with kwargs filtered against its signature.

Many third-party helpers (e.g., hub download utilities) have a strict
keyword signature. This helper allows callers to gather keyword arguments
from multiple sources, filter out unsupported ones via inspection, and
invoke the callable safely without tripping ``TypeError``. When the
callable accepts ``**kwargs`` or inspection fails, the original kwargs are
forwarded unchanged.
"""

kwargs = dict(kwargs or {})
accepts_var_kw, allowed_kwargs = get_supported_kwargs(callable_obj)
if accepts_var_kw or allowed_kwargs is None:
return callable_obj(*args, **kwargs)

filtered = {key: value for key, value in kwargs.items() if key in allowed_kwargs}
if on_removed is not None:
removed = sorted(key for key in kwargs if key not in allowed_kwargs)
if removed:
on_removed(removed)
return callable_obj(*args, **filtered)