diff --git a/gptqmodel/utils/nogil_patcher.py b/gptqmodel/utils/nogil_patcher.py index 88a3118e1..e2779f603 100644 --- a/gptqmodel/utils/nogil_patcher.py +++ b/gptqmodel/utils/nogil_patcher.py @@ -9,10 +9,14 @@ import time from .safe import ThreadSafe +from importlib.metadata import version +from packaging.version import Version, InvalidVersion _PATCHED_ATTR = "_gptqmodel_locked_save_file" +TRITON_MIN_VERSION_STR = "3.5.0" + def patch_safetensors_save_file() -> None: try: @@ -38,9 +42,16 @@ def patch_triton_autotuner() -> None: except ImportError: return - version = getattr(triton, "__version__", None) - if version is None or tuple(int(part) for part in version.split(".")[:3]) < (3, 5, 0): - return + triton_version_str = version("triton") + + try: + triton_ver = Version(triton_version_str) + triton_min_version = Version(TRITON_MIN_VERSION_STR) + if triton_ver < triton_min_version: + return + except InvalidVersion: + if triton_version_str < TRITON_MIN_VERSION_STR: + return autotuner_cls = module.Autotuner if getattr(autotuner_cls, "_gptqmodel_threadsafe", False):