6 changes: 3 additions & 3 deletions gptqmodel/looper/awq_processor.py
@@ -13,11 +13,11 @@
from torch import nn
from torch.nn import Module

from ..looper.loop_processor import LoopProcessor, get_max_memory
from ..looper.loop_processor import LoopProcessor
from ..looper.named_module import NamedModule
from ..models import BaseQModel
from ..models.writer import (PROCESS_LOG_LAYER, PROCESS_LOG_MODULE, PROCESS_LOG_NAME,
PROCESS_LOG_TIME, PROCESS_MAX_MEMORY, QUANT_LOG_LOSS, QUANT_LOG_NSAMPLES)
PROCESS_LOG_TIME, PROCESS_USED_MEMORY, QUANT_LOG_LOSS, QUANT_LOG_NSAMPLES)
from ..nn_modules.qlinear.awq_gemm import AwqGEMMQuantLinear
from ..nn_modules.qlinear.awq_gemv import AwqGEMVQuantLinear
from ..nn_modules.qlinear.awq_gemv_fast import AwqGEMVFastQuantLinear
@@ -718,7 +718,7 @@ def _apply_quant(self, module, named_linears: Dict[str, NamedModule], start_time
# QUANT_LOG_DAMP: f"{damp_percent:.5f}",
PROCESS_LOG_TIME: f"{duration:.3f}",
# PROCESS_LOG_FWD_TIME: f"{self.fwd_time:.3f}",
PROCESS_MAX_MEMORY: get_max_memory(),
PROCESS_USED_MEMORY: self.device_memory_report(),
}
self.module_names.append(f"layer-{named_module.layer_index}-{named_module.name}")

4 changes: 2 additions & 2 deletions gptqmodel/looper/eora_processor.py
@@ -17,7 +17,7 @@
from ..looper.named_module import NamedModule
from ..models import BaseQModel
from ..models.writer import (PROCESS_LOG_FWD_TIME, PROCESS_LOG_LAYER, PROCESS_LOG_MODULE,
PROCESS_LOG_NAME, PROCESS_LOG_TIME, PROCESS_MAX_MEMORY)
PROCESS_LOG_NAME, PROCESS_LOG_TIME, PROCESS_USED_MEMORY)
from ..quantization.config import QuantizeConfig
from ..utils.logger import setup_logger
from ..utils.model import move_to
@@ -178,7 +178,7 @@ def process(self, module: NamedModule):
PROCESS_LOG_MODULE: module.name,
PROCESS_LOG_TIME: f"{duration:.3f}",
PROCESS_LOG_FWD_TIME: self.formatted_fwd_time(),
PROCESS_MAX_MEMORY: max_memory,
PROCESS_USED_MEMORY: max_memory,
}

if self.qcfg.dynamic is not None:
6 changes: 3 additions & 3 deletions gptqmodel/looper/gptq_processor.py
@@ -11,12 +11,12 @@
import torch
from torch.nn import Module

from ..looper.loop_processor import LoopProcessor, get_max_memory
from ..looper.loop_processor import LoopProcessor
from ..looper.named_module import NamedModule
from ..models import BaseQModel
from ..models._const import CPU
from ..models.writer import (PROCESS_LOG_FWD_TIME, PROCESS_LOG_LAYER, PROCESS_LOG_MODULE, PROCESS_LOG_NAME,
PROCESS_LOG_TIME, PROCESS_MAX_MEMORY, QUANT_LOG_DAMP, QUANT_LOG_LOSS, QUANT_LOG_NSAMPLES)
PROCESS_LOG_TIME, PROCESS_USED_MEMORY, QUANT_LOG_DAMP, QUANT_LOG_LOSS, QUANT_LOG_NSAMPLES)
from ..quantization import GPTQ, GPTQv2
from ..quantization.config import METHOD, QuantizeConfig
from ..utils.importer import select_quant_linear
@@ -166,7 +166,7 @@ def process(self, module: NamedModule):
QUANT_LOG_DAMP: f"{damp_percent:.5f}",
PROCESS_LOG_TIME: f"{duration:.3f}",
PROCESS_LOG_FWD_TIME: self.formatted_fwd_time(),
PROCESS_MAX_MEMORY: get_max_memory(),
PROCESS_USED_MEMORY: self.device_memory_report(),
}

if self.qcfg.dynamic is not None:
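For review context, a hedged sketch of the per-module stat row after the rename: the key constants are the ones imported from `..models.writer` in the hunk above, while every value shown is hypothetical. The point of the change is that the `(v)ram` column now carries the per-device string returned by `LoopProcessor.device_memory_report()` rather than a single peak-VRAM figure.

```python
# Hedged sketch of the stat row logged per quantized module (values are hypothetical;
# other fields such as PROCESS_LOG_NAME are omitted for brevity).
from gptqmodel.models.writer import (PROCESS_LOG_FWD_TIME, PROCESS_LOG_LAYER, PROCESS_LOG_MODULE,
                                     PROCESS_LOG_TIME, PROCESS_USED_MEMORY, QUANT_LOG_DAMP,
                                     QUANT_LOG_LOSS, QUANT_LOG_NSAMPLES)

stat = {
    PROCESS_LOG_LAYER: 12,                         # hypothetical layer index
    PROCESS_LOG_MODULE: "self_attn.q_proj",        # hypothetical module name
    QUANT_LOG_LOSS: "0.00213",
    QUANT_LOG_NSAMPLES: "1024",
    QUANT_LOG_DAMP: "0.01000",
    PROCESS_LOG_TIME: "1.532",
    PROCESS_LOG_FWD_TIME: "0.412",
    PROCESS_USED_MEMORY: "cuda:0=12.3GB, cuda:1=11.8GB",  # device_memory_report() output
}
print(stat)
```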
108 changes: 102 additions & 6 deletions gptqmodel/looper/loop_processor.py
@@ -6,8 +6,9 @@
import queue
import threading
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Any, Callable, Dict, List, Optional, Set, Tuple

from device_smi import Device
import torch
from random_word import RandomWords
from torch import Tensor
@@ -17,7 +18,6 @@
from ..looper.named_module import NamedModule
from ..models import BaseQModel
from ..quantization.config import QuantizeConfig
from ..utils.device import get_cpu_usage_memory, get_gpu_usage_memory
from ..utils.logger import setup_logger
from ..utils.torch import DEVICE_0, DEVICE_1

@@ -93,6 +93,9 @@ def __init__(
self.log_tmp_log_file_name = f"{self.name()}_log_{RandomWords().get_random_word()}_time_{current_time}.log"
self.log_worker_queue = queue.Queue()
self.log_worker: threading.Thread = None
self._device_smi_handles = self._init_device_smi_handles()
self._cpu_device_smi = self._init_cpu_device_handle()
self._device_metric_failures: Set[str] = set()

if self.logger_board == "clearml":
try:
@@ -224,6 +227,95 @@ def log_new_row(self, stat):
log.info(formatted_row)
log.info(len(formatted_row) * "-")

def _init_device_smi_handles(self) -> Dict[str, Device]:
handles: Dict[str, Device] = {}

for device_id in self._discover_accelerator_devices():
try:
handles[device_id] = Device(device_id)
except Exception as exc: # pragma: no cover - defensive, external tool
log.debug(f"Device-SMI initialisation failed for `{device_id}`: {exc}")

return handles

def _init_cpu_device_handle(self) -> Optional[Device]:
try:
return Device("cpu")
except Exception as exc: # pragma: no cover - defensive, external tool
log.debug(f"Device-SMI CPU initialisation failed: {exc}")
return None

def _discover_accelerator_devices(self) -> List[str]:
devices: List[str] = []

if hasattr(torch, "cuda"):
try:
if torch.cuda.is_available():
device_type = "rocm" if getattr(torch.version, "hip", None) else "cuda"
for idx in range(torch.cuda.device_count()):
devices.append(f"{device_type}:{idx}")
except Exception: # pragma: no cover - defensive, CUDA runtime differences
pass

xpu = getattr(torch, "xpu", None)
if xpu is not None:
try:
if torch.xpu.is_available():
for idx in range(torch.xpu.device_count()):
devices.append(f"xpu:{idx}")
except Exception: # pragma: no cover - defensive, XPU runtime differences
pass

return devices

def _safe_query_metric(self, device_key: str, handle: Device):
try:
return handle.metrics(fast=True)
except Exception as exc: # pragma: no cover - defensive, external tool
if device_key not in self._device_metric_failures:
log.debug(f"Device-SMI metrics failed for `{device_key}`: {exc}")
self._device_metric_failures.add(device_key)
return None

def _snapshot_device_memory_gib(self) -> Dict[str, float]:
snapshot: Dict[str, float] = {}
for device_id, handle in self._device_smi_handles.items():
metrics = self._safe_query_metric(device_id, handle)
if metrics is None:
continue
snapshot[device_id] = metrics.memory_used / (1024 ** 3)
return snapshot

def _snapshot_cpu_memory_gib(self) -> Optional[float]:
if self._cpu_device_smi is None:
return None
metrics = self._safe_query_metric("cpu", self._cpu_device_smi)
if metrics is None:
return None
return metrics.memory_used / (1024 ** 3)

def device_memory_report(self) -> str:
snapshot = self._snapshot_device_memory_gib()
if not snapshot:
return "n/a"
parts = [f"{device_id}={value:.1f}GB" for device_id, value in snapshot.items()]
return ", ".join(parts)

def _close_device_smi_handles(self) -> None:
for handle in self._device_smi_handles.values():
try:
handle.close()
except Exception:
pass
self._device_smi_handles.clear()

if self._cpu_device_smi is not None:
try:
self._cpu_device_smi.close()
except Exception:
pass
self._cpu_device_smi = None

# Loop Processor level scoped state data
def result_save(self, key: str, value: Any):
with self._results_lock:
@@ -250,22 +342,25 @@ def results(self):

def collect_memory_info(self, layer_index: int):
if self.logger_task is not None:
gpu_memory = get_gpu_usage_memory()
cpu_memory = get_cpu_usage_memory()
device_snapshot = self._snapshot_device_memory_gib()
total_gpu_memory = sum(device_snapshot.values()) if device_snapshot else 0.0

self.logger_task.get_logger().report_scalar(
title='GPU Memory',
series='GPU Memory',
value=gpu_memory,
value=total_gpu_memory,
iteration=layer_index,
)

cpu_memory = self._snapshot_cpu_memory_gib() or 0.0

self.logger_task.get_logger().report_scalar(
title='CPU Memory',
series='CPU Memory',
value=cpu_memory,
iteration=layer_index,
)
self.gpu_memorys.append(gpu_memory)
self.gpu_memorys.append(total_gpu_memory)
self.cpu_memorys.append(cpu_memory)

def log_plotly(self):
@@ -322,6 +417,7 @@ def submodule_finalize(self, module: NamedModule, model: BaseQModel, **kwargs):
# last step, after all loop processor is called
# finalize is called in reverse after all next sequential processes are called
def finalize(self, model: BaseQModel, **kwargs):
self._close_device_smi_handles()
del self.inputs_cache
del self._results

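To make the new reporting path easier to review in isolation, here is a minimal standalone sketch of the device-smi pattern the helpers above build on. It assumes device-smi >= 0.5.0 exposes `Device(...)`, `metrics(fast=True)` with a `memory_used` field in bytes, and `close()`, exactly as the diff uses them; the device ids and printed values are hypothetical.

```python
# Minimal sketch of the device-smi usage behind _snapshot_device_memory_gib() and
# device_memory_report(). Assumes device-smi >= 0.5.0 provides Device(...),
# metrics(fast=True), memory_used (bytes) and close() as used in the diff above.
from device_smi import Device

def report_used_memory(device_ids):
    parts = []
    for device_id in device_ids:
        try:
            handle = Device(device_id)   # e.g. "cuda:0", "rocm:0", "xpu:0", "cpu"
        except Exception:
            continue                     # device not visible to device-smi; skip it
        try:
            used_gib = handle.metrics(fast=True).memory_used / (1024 ** 3)
            parts.append(f"{device_id}={used_gib:.1f}GB")
        except Exception:
            pass                         # mirror the processor's fail-soft behaviour
        finally:
            handle.close()
    return ", ".join(parts) if parts else "n/a"

# Hypothetical output: "cuda:0=12.3GB, cuda:1=11.8GB"
print(report_used_memory(["cuda:0", "cuda:1"]))
```

Unlike this throwaway sketch, the processor opens its handles once in `__init__` and only closes them in `finalize()`, so it avoids re-probing every device for each logged module.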
2 changes: 1 addition & 1 deletion gptqmodel/models/writer.py
@@ -67,7 +67,7 @@
QUANT_LOG_DAMP = "damp"
PROCESS_LOG_TIME = "time"
PROCESS_LOG_FWD_TIME = "fwd_time"
PROCESS_MAX_MEMORY = "max_vram"
PROCESS_USED_MEMORY = "(v)ram"

EORA_DEFAULT_FILE = "eora.safetensors"

11 changes: 5 additions & 6 deletions pyproject.toml
@@ -1,6 +1,7 @@
[build-system]
requires = [
"setuptools >= 64",
"setuptools>=80.9.0",
"ninja>=1.13.0"
]
build-backend = "setuptools.build_meta:__legacy__"

@@ -37,17 +38,15 @@ dependencies = [
"transformers>=4.56.0",
"threadpoolctl>=3.6.0",
"packaging>=24.2",
"device-smi==0.4.1",
"device-smi>=0.5.0",
"protobuf>=6.32.0",
"pillow>=11.3.0",
"hf_transfer>=0.1.9",
"huggingface_hub>=0.34.4",
"random_word==1.0.13",
"random_word>=1.0.13",
"tokenicer>=0.0.5",
"logbar>=0.0.5",
# "soundfile==0.13.1", # Qwen-Omni dependent pkg
"wheel>=0.45.1",
"maturin>=1.9.3",
# "flash-attn>=2.8.3", <-- install for lower vram usage
]

[project.urls]