gptqmodel/__init__.py — 4 changes: 3 additions & 1 deletion

@@ -5,8 +5,10 @@
 
 import os
 
+from .utils.env import env_flag
+
-DEBUG_ON = str(os.environ.get("DEBUG", "")).lower() in ("1", "true", "yes", "on")
+DEBUG_ON = env_flag("DEBUG")
 
 from .utils.threadx import DeviceThreadPool
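For reference, a minimal sketch of the env_flag helper this PR assumes lives in gptqmodel/utils/env.py. The PR does not show its body, so the signature and the default parameter are inferred from the call sites; the truthy strings mirror the old inline DEBUG check:

import os

def env_flag(name: str, default: bool = False) -> bool:
    # Inferred sketch, not the PR's actual implementation: return True when
    # the variable is set to a truthy string, matching the replaced inline
    # check str(os.environ.get("DEBUG", "")).lower() in ("1","true","yes","on").
    value = os.environ.get(name)
    if value is None or value.strip() == "":
        return default  # unset/empty behaves like the old "" fallback
    return value.strip().lower() in ("1", "true", "yes", "on")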
gptqmodel/nn_modules/qlinear/__init__.py — 5 changes: 3 additions & 2 deletions

@@ -19,6 +19,7 @@
 from ...models._const import DEVICE, PLATFORM
 from ...utils.backend import BACKEND
 from ...utils.logger import setup_logger
+from ...utils.env import env_flag
 from ...utils.safe import THREADPOOLCTL
 

@@ -541,8 +542,8 @@ def pack_block(
     if (in_features % word_bits) != 0:
         raise ValueError("in_features must be divisible by 32")
 
-    disable_ext = os.getenv("GPTQMODEL_DISABLE_PACK_EXT", "").lower() in {"1", "true", "yes"}
-    force_ext = os.getenv("GPTQMODEL_FORCE_PACK_EXT", "").lower() in {"1", "true", "yes"}
+    disable_ext = env_flag("GPTQMODEL_DISABLE_PACK_EXT")
+    force_ext = env_flag("GPTQMODEL_FORCE_PACK_EXT")
     pack_block_threads = workers if workers and workers > 0 else 1
     env_threads = os.getenv("GPTQMODEL_PACK_THREADS")
     if env_threads:
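One behavioral nuance worth noting: the old inline checks for the pack flags accepted only "1"/"true"/"yes", while a shared env_flag that mirrors the DEBUG check (as sketched above) would presumably also treat "on" as truthy. A quick illustration under that assumption:

import os
from gptqmodel.utils.env import env_flag  # helper added by this PR

os.environ["GPTQMODEL_DISABLE_PACK_EXT"] = "on"
# Old inline check: "on" not in {"1", "true", "yes"}  -> False
# New helper, if it matches the DEBUG semantics:      -> True
print(env_flag("GPTQMODEL_DISABLE_PACK_EXT"))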
gptqmodel/utils/looper_helpers.py — 6 changes: 5 additions & 1 deletion

@@ -15,12 +15,15 @@
 from .. import DEBUG_ON, DEVICE_THREAD_POOL
 from ..nn_modules.hooked_linear import StopForward
 from ..utils.attn_mask import normalize_seq_mask
+from ..utils.env import env_flag
 from ..utils.device import get_device
 from ..utils.logger import setup_logger
 from ..utils.model import move_to, nested_move_to
 from ..utils.safe import ThreadSafe
 from ..utils.torch import ALL_DEVICES, CPU, torch_sync
 
+USE_TORCH_REPLICATE = env_flag("GPTQMODEL_USE_TORCH_REPLICATE")
+
 
 _THREAD_SAFE_PARALLEL = ThreadSafe(torch_parallel)

@@ -261,7 +264,8 @@ def _prepare_module(target_device: torch.device, step_name: str) -> None:
     _record(step_name, start_ts)
 
     use_replicate = (
-        homogeneous_type
+        USE_TORCH_REPLICATE
+        and homogeneous_type
         and backend_available(device_type)
         and device_type != "cpu"
     )
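Because USE_TORCH_REPLICATE is evaluated once at module import, the flag must be set before gptqmodel is imported. A small usage sketch — the environment variable name comes from this diff, the surrounding code is illustrative:

import os

# Opt in to torch-replicate dispatch before the first gptqmodel import:
# looper_helpers reads GPTQMODEL_USE_TORCH_REPLICATE at import time.
os.environ["GPTQMODEL_USE_TORCH_REPLICATE"] = "1"

import gptqmodel  # USE_TORCH_REPLICATE is now True; per the gate above,
                  # replicate is still skipped on CPU or when the target
                  # devices are not of a homogeneous type.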