diff --git a/gptqmodel/utils/env.py b/gptqmodel/utils/env.py
index e2839083c..4f1907a1d 100644
--- a/gptqmodel/utils/env.py
+++ b/gptqmodel/utils/env.py
@@ -13,13 +13,17 @@
 _TRUTHY = {"1", "true", "yes", "on", "y"}
 
 
-def env_flag(name: str, default: str | None = "0") -> bool:
+def env_flag(name: str, default: str | bool | None = "0") -> bool:
     """Return ``True`` when an env var is set to a truthy value."""
-    value = os.getenv(name, default)
+    value = os.getenv(name)
     if value is None:
-        return False
-    return value.strip().lower() in _TRUTHY
+        if default is None:
+            return False
+        if isinstance(default, bool):
+            return default
+        value = default
+    return str(value).strip().lower() in _TRUTHY
 
 
 __all__ = ["env_flag"]
diff --git a/gptqmodel/utils/looper_helpers.py b/gptqmodel/utils/looper_helpers.py
index 5e79876ae..b77e4e8bd 100644
--- a/gptqmodel/utils/looper_helpers.py
+++ b/gptqmodel/utils/looper_helpers.py
@@ -24,7 +24,7 @@
 from ..utils.torch import ALL_DEVICES, CPU, torch_sync
 
-USE_TORCH_REPLICATE = env_flag("GPTQMODEL_USE_TORCH_REPLICATE")
+USE_TORCH_REPLICATE = env_flag("GPTQMODEL_USE_TORCH_REPLICATE", True)
 
 _THREAD_SAFE_PARALLEL = ThreadSafe(torch_parallel)
diff --git a/tests/pytest.ini b/tests/pytest.ini
index dfd34a073..f8fe22115 100644
--- a/tests/pytest.ini
+++ b/tests/pytest.ini
@@ -5,3 +5,6 @@ norecursedirs = tasks evalplus_results
 markers =
     ci: CPU-only CI regression coverage for DeviceThreadPool affinity behaviour
     cuda: Requires CUDA device
+    inference: Inference workloads that replicate models across devices
+filterwarnings =
+    ignore:Warning only once for all operators.*:UserWarning
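
A minimal usage sketch, not part of the patch, illustrating the behaviour the env.py change enables: a bool default is returned as-is when the variable is unset, while an explicit environment value still wins. It assumes the patched env_flag is importable as gptqmodel.utils.env.env_flag (the path shown in the diff above).

import os

from gptqmodel.utils.env import env_flag

# Unset variable: the bool default is returned directly, letting a module
# enable a feature by default (as looper_helpers.py now does).
os.environ.pop("GPTQMODEL_USE_TORCH_REPLICATE", None)
assert env_flag("GPTQMODEL_USE_TORCH_REPLICATE", True) is True

# An explicit value in the environment still overrides the default,
# so users keep an opt-out.
os.environ["GPTQMODEL_USE_TORCH_REPLICATE"] = "0"
assert env_flag("GPTQMODEL_USE_TORCH_REPLICATE", True) is False

# Truthy strings continue to work exactly as before.
os.environ["GPTQMODEL_USE_TORCH_REPLICATE"] = "yes"
assert env_flag("GPTQMODEL_USE_TORCH_REPLICATE", True) is True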