diff --git a/gptqmodel/models/loader.py b/gptqmodel/models/loader.py index 5edef870d..4dbbd82b0 100644 --- a/gptqmodel/models/loader.py +++ b/gptqmodel/models/loader.py @@ -35,6 +35,7 @@ from ..quantization.config import FORMAT, METHOD, MIN_VERSION_WITH_V2 from ..utils.backend import BACKEND from ..utils.importer import auto_select_device, normalize_device_device_map, select_quant_linear +from ..utils.inspect import safe_kwargs_call from ..utils.logger import setup_logger from ..utils.machete import _validate_machete_device_support from ..utils.marlin import _validate_marlin_device_support @@ -56,6 +57,8 @@ log = setup_logger() ATTN_IMPLEMENTATION = "attn_implementation" + + def parse_version_string(version_str: str): try: return Version(version_str) @@ -105,12 +108,15 @@ def get_model_local_path(pretrained_model_id_or_path, **kwargs): is_local = os.path.isdir(pretrained_model_id_or_path) if is_local: return os.path.normpath(pretrained_model_id_or_path) - else: - # Clone kwargs before modifying - download_kwargs = kwargs.copy() - download_kwargs.pop("attn_implementation", None) - download_kwargs.pop("use_flash_attention_2", None) - return snapshot_download(pretrained_model_id_or_path, **download_kwargs) + def _log_removed(removed: list[str]): + log.debug("Loader: dropping unsupported snapshot_download kwargs: %s", ", ".join(removed)) + + return safe_kwargs_call( + snapshot_download, + pretrained_model_id_or_path, + kwargs=kwargs, + on_removed=_log_removed, + ) def ModelLoader(cls): diff --git a/gptqmodel/models/writer.py b/gptqmodel/models/writer.py index e9460a7b8..4f7d2f4ae 100644 --- a/gptqmodel/models/writer.py +++ b/gptqmodel/models/writer.py @@ -30,13 +30,13 @@ META_FIELD_ACT_GROUP_AWARE, META_FIELD_DAMP_AUTO_INCREMENT, META_FIELD_DAMP_PERCENT, + META_FIELD_GPTAQ_ALPHA, + META_FIELD_GPTAQ_ENABLED, META_FIELD_MSE, META_FIELD_QUANTIZER, META_FIELD_STATIC_GROUPS, META_FIELD_TRUE_SEQUENTIAL, META_FIELD_URI, - META_FIELD_GPTAQ_ALPHA, - META_FIELD_GPTAQ_ENABLED, 
# META_QUANTIZER_GPTQMODEL, META_VALUE_URI, MIN_VERSION_WITH_V2,   <- trailing context of the gptqmodel/models/writer.py import hunk
# diff --git a/gptqmodel/utils/inspect.py b/gptqmodel/utils/inspect.py
# new file mode 100644 / index 000000000..6c3bb3245
#
# SPDX-FileCopyrightText: 2024-2025 ModelCloud.ai
# SPDX-FileCopyrightText: 2024-2025 qubitium@modelcloud.ai
# SPDX-License-Identifier: Apache-2.0
# Contact: qubitium@modelcloud.ai, x.com/qubitium

from __future__ import annotations

import inspect
from functools import lru_cache
from typing import Any, Callable, FrozenSet, Optional, Tuple


# (accepts_var_kwargs, allowed_kwargs); allowed_kwargs is None when no
# filtering should be applied.
SupportedKwargInfo = Tuple[bool, Optional[FrozenSet[str]]]

# Parameter kinds that a caller is able to supply by keyword.
_KEYWORD_KINDS = (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)

# Bounded so the cache cannot pin an ever-growing set of callables (and, for
# bound methods, the instances they belong to) in memory.
_SIGNATURE_CACHE_SIZE = 256


def _inspect_supported_kwargs(callable_obj: Callable) -> SupportedKwargInfo:
    """Inspect ``callable_obj`` without caching; see ``get_supported_kwargs``."""
    try:
        signature = inspect.signature(callable_obj)
    except (TypeError, ValueError):
        # Signature unavailable: report "accepts anything" so that callers
        # forward their kwargs untouched.
        return True, None

    parameters = signature.parameters.values()
    if any(param.kind == inspect.Parameter.VAR_KEYWORD for param in parameters):
        return True, None

    return False, frozenset(param.name for param in parameters if param.kind in _KEYWORD_KINDS)


_cached_supported_kwargs = lru_cache(maxsize=_SIGNATURE_CACHE_SIZE)(_inspect_supported_kwargs)


def get_supported_kwargs(callable_obj: Callable) -> SupportedKwargInfo:
    """Return ``(accepts_var_kwargs, allowed_kwargs)`` for a callable.

    ``allowed_kwargs`` is ``None`` when the callable uses ``**kwargs`` or when
    inspection fails; in both cases ``accepts_var_kwargs`` is ``True``.
    Otherwise it is the frozenset of parameter names that may be passed by
    keyword (positional-or-keyword and keyword-only parameters).

    Results are memoized in a bounded LRU cache. Callables that cannot be
    hashed are inspected on every call instead of raising ``TypeError``.
    """
    try:
        return _cached_supported_kwargs(callable_obj)
    except TypeError:
        # The uncached helper already swallows TypeError from
        # inspect.signature, so reaching this branch means the cache could not
        # hash callable_obj.
        return _inspect_supported_kwargs(callable_obj)


# Keep the lru_cache management helpers reachable under the public name.
get_supported_kwargs.cache_info = _cached_supported_kwargs.cache_info
get_supported_kwargs.cache_clear = _cached_supported_kwargs.cache_clear


def safe_kwargs_call(
    callable_obj: Callable,
    *args: Any,
    kwargs: Optional[dict] = None,
    on_removed: Optional[Callable[[list[str]], None]] = None,
):
    """Invoke ``callable_obj`` with kwargs filtered against its signature.

    Many third-party helpers (e.g., hub download utilities) have a strict
    keyword signature. This helper allows callers to gather keyword arguments
    from multiple sources, filter out unsupported ones via inspection, and
    invoke the callable without tripping ``TypeError`` on an unknown keyword.
    When the callable accepts ``**kwargs`` or inspection fails, the original
    kwargs are forwarded unchanged.

    Args:
        callable_obj: The callable to invoke.
        *args: Positional arguments forwarded as-is.
        kwargs: Candidate keyword arguments; ``None`` is treated as empty.
            The mapping is copied and never mutated.
        on_removed: Optional callback that receives the sorted list of dropped
            keyword names. It is called only when at least one was dropped.

    Returns:
        Whatever ``callable_obj`` returns.
    """
    kwargs = dict(kwargs or {})
    accepts_var_kw, allowed_kwargs = get_supported_kwargs(callable_obj)
    if accepts_var_kw or allowed_kwargs is None:
        return callable_obj(*args, **kwargs)

    filtered = {key: value for key, value in kwargs.items() if key in allowed_kwargs}
    if on_removed is not None:
        removed = sorted(key for key in kwargs if key not in allowed_kwargs)
        if removed:
            on_removed(removed)
    return callable_obj(*args, **filtered)