From d53e549143e30081d595a5a35c1bc0242b5cf237 Mon Sep 17 00:00:00 2001 From: Qubitium Date: Tue, 14 Oct 2025 06:34:26 +0000 Subject: [PATCH] dedup env flag parsing via env_flag; gate replicate on GPTQMODEL_USE_TORCH_REPLICATE Signed-off-by: Qubitium --- gptqmodel/__init__.py | 4 +++- gptqmodel/nn_modules/qlinear/__init__.py | 5 +++-- gptqmodel/utils/looper_helpers.py | 6 +++++- 3 files changed, 11 insertions(+), 4 deletions(-) diff --git a/gptqmodel/__init__.py b/gptqmodel/__init__.py index f4034a874..5ae0666d9 100644 --- a/gptqmodel/__init__.py +++ b/gptqmodel/__init__.py @@ -5,8 +5,10 @@ import os +from .utils.env import env_flag -DEBUG_ON = str(os.environ.get("DEBUG", "")).lower() in ("1", "true", "yes", "on") + +DEBUG_ON = env_flag("DEBUG") from .utils.threadx import DeviceThreadPool diff --git a/gptqmodel/nn_modules/qlinear/__init__.py b/gptqmodel/nn_modules/qlinear/__init__.py index aa971d500..89b0625bc 100644 --- a/gptqmodel/nn_modules/qlinear/__init__.py +++ b/gptqmodel/nn_modules/qlinear/__init__.py @@ -19,6 +19,7 @@ from ...models._const import DEVICE, PLATFORM from ...utils.backend import BACKEND from ...utils.logger import setup_logger +from ...utils.env import env_flag from ...utils.safe import THREADPOOLCTL @@ -541,8 +542,8 @@ def pack_block( if (in_features % word_bits) != 0: raise ValueError("in_features must be divisible by 32") - disable_ext = os.getenv("GPTQMODEL_DISABLE_PACK_EXT", "").lower() in {"1", "true", "yes"} - force_ext = os.getenv("GPTQMODEL_FORCE_PACK_EXT", "").lower() in {"1", "true", "yes"} + disable_ext = env_flag("GPTQMODEL_DISABLE_PACK_EXT") + force_ext = env_flag("GPTQMODEL_FORCE_PACK_EXT") pack_block_threads = workers if workers and workers > 0 else 1 env_threads = os.getenv("GPTQMODEL_PACK_THREADS") if env_threads: diff --git a/gptqmodel/utils/looper_helpers.py b/gptqmodel/utils/looper_helpers.py index a0d665463..80b93ce12 100644 --- a/gptqmodel/utils/looper_helpers.py +++ b/gptqmodel/utils/looper_helpers.py @@ -15,12 +15,15 @@ from .. 
import DEBUG_ON, DEVICE_THREAD_POOL from ..nn_modules.hooked_linear import StopForward from ..utils.attn_mask import normalize_seq_mask +from ..utils.env import env_flag from ..utils.device import get_device from ..utils.logger import setup_logger from ..utils.model import move_to, nested_move_to from ..utils.safe import ThreadSafe from ..utils.torch import ALL_DEVICES, CPU, torch_sync +USE_TORCH_REPLICATE = env_flag("GPTQMODEL_USE_TORCH_REPLICATE") + _THREAD_SAFE_PARALLEL = ThreadSafe(torch_parallel) @@ -261,7 +264,8 @@ def _prepare_module(target_device: torch.device, step_name: str) -> None: _record(step_name, start_ts) use_replicate = ( - homogeneous_type + USE_TORCH_REPLICATE + and homogeneous_type and backend_available(device_type) and device_type != "cpu" )