diff --git a/gptqmodel/looper/awq_processor.py b/gptqmodel/looper/awq_processor.py
index f3f10d16a..0f856996f 100644
--- a/gptqmodel/looper/awq_processor.py
+++ b/gptqmodel/looper/awq_processor.py
@@ -13,11 +13,11 @@
 from torch import nn
 from torch.nn import Module

-from ..looper.loop_processor import LoopProcessor, get_max_memory
+from ..looper.loop_processor import LoopProcessor
 from ..looper.named_module import NamedModule
 from ..models import BaseQModel
 from ..models.writer import (PROCESS_LOG_LAYER, PROCESS_LOG_MODULE, PROCESS_LOG_NAME,
-                             PROCESS_LOG_TIME, PROCESS_MAX_MEMORY, QUANT_LOG_LOSS, QUANT_LOG_NSAMPLES)
+                             PROCESS_LOG_TIME, PROCESS_USED_MEMORY, QUANT_LOG_LOSS, QUANT_LOG_NSAMPLES)
 from ..nn_modules.qlinear.awq_gemm import AwqGEMMQuantLinear
 from ..nn_modules.qlinear.awq_gemv import AwqGEMVQuantLinear
 from ..nn_modules.qlinear.awq_gemv_fast import AwqGEMVFastQuantLinear
@@ -718,7 +718,7 @@ def _apply_quant(self, module, named_linears: Dict[str, NamedModule], start_time
             # QUANT_LOG_DAMP: f"{damp_percent:.5f}",
             PROCESS_LOG_TIME: f"{duration:.3f}",
             # PROCESS_LOG_FWD_TIME: f"{self.fwd_time:.3f}",
-            PROCESS_MAX_MEMORY: get_max_memory(),
+            PROCESS_USED_MEMORY: self.device_memory_report(),
         }

         self.module_names.append(f"layer-{named_module.layer_index}-{named_module.name}")
diff --git a/gptqmodel/looper/eora_processor.py b/gptqmodel/looper/eora_processor.py
index b002b1a07..5d62a5ba0 100644
--- a/gptqmodel/looper/eora_processor.py
+++ b/gptqmodel/looper/eora_processor.py
@@ -17,7 +17,7 @@
 from ..looper.named_module import NamedModule
 from ..models import BaseQModel
 from ..models.writer import (PROCESS_LOG_FWD_TIME, PROCESS_LOG_LAYER, PROCESS_LOG_MODULE,
-                             PROCESS_LOG_NAME, PROCESS_LOG_TIME, PROCESS_MAX_MEMORY)
+                             PROCESS_LOG_NAME, PROCESS_LOG_TIME, PROCESS_USED_MEMORY)
 from ..quantization.config import QuantizeConfig
 from ..utils.logger import setup_logger
 from ..utils.model import move_to
@@ -178,7 +178,7 @@ def process(self, module: NamedModule):
             PROCESS_LOG_MODULE: module.name,
             PROCESS_LOG_TIME: f"{duration:.3f}",
             PROCESS_LOG_FWD_TIME: self.formatted_fwd_time(),
-            PROCESS_MAX_MEMORY: max_memory,
+            PROCESS_USED_MEMORY: max_memory,
         }

         if self.qcfg.dynamic is not None:
diff --git a/gptqmodel/looper/gptq_processor.py b/gptqmodel/looper/gptq_processor.py
index 37a5a483d..cd5604471 100644
--- a/gptqmodel/looper/gptq_processor.py
+++ b/gptqmodel/looper/gptq_processor.py
@@ -11,12 +11,12 @@
 import torch
 from torch.nn import Module

-from ..looper.loop_processor import LoopProcessor, get_max_memory
+from ..looper.loop_processor import LoopProcessor
 from ..looper.named_module import NamedModule
 from ..models import BaseQModel
 from ..models._const import CPU
 from ..models.writer import (PROCESS_LOG_FWD_TIME, PROCESS_LOG_LAYER, PROCESS_LOG_MODULE, PROCESS_LOG_NAME,
-                             PROCESS_LOG_TIME, PROCESS_MAX_MEMORY, QUANT_LOG_DAMP, QUANT_LOG_LOSS, QUANT_LOG_NSAMPLES)
+                             PROCESS_LOG_TIME, PROCESS_USED_MEMORY, QUANT_LOG_DAMP, QUANT_LOG_LOSS, QUANT_LOG_NSAMPLES)
 from ..quantization import GPTQ, GPTQv2
 from ..quantization.config import METHOD, QuantizeConfig
 from ..utils.importer import select_quant_linear
@@ -166,7 +166,7 @@ def process(self, module: NamedModule):
             QUANT_LOG_DAMP: f"{damp_percent:.5f}",
             PROCESS_LOG_TIME: f"{duration:.3f}",
             PROCESS_LOG_FWD_TIME: self.formatted_fwd_time(),
-            PROCESS_MAX_MEMORY: get_max_memory(),
+            PROCESS_USED_MEMORY: self.device_memory_report(),
         }

         if self.qcfg.dynamic is not None:
diff --git a/gptqmodel/looper/loop_processor.py b/gptqmodel/looper/loop_processor.py
index c0821b4bb..e177434d4 100644
--- a/gptqmodel/looper/loop_processor.py
+++ b/gptqmodel/looper/loop_processor.py
@@ -6,8 +6,9 @@
 import queue
 import threading
 from datetime import datetime
-from typing import Any, Callable, Dict, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Optional, Set, Tuple

+from device_smi import Device
 import torch
 from random_word import RandomWords
 from torch import Tensor
@@ -17,7 +18,6 @@
 from ..looper.named_module import NamedModule
 from ..models import BaseQModel
 from ..quantization.config import QuantizeConfig
-from ..utils.device import get_cpu_usage_memory, get_gpu_usage_memory
 from ..utils.logger import setup_logger
 from ..utils.torch import DEVICE_0, DEVICE_1

@@ -93,6 +93,9 @@ def __init__(
         self.log_tmp_log_file_name = f"{self.name()}_log_{RandomWords().get_random_word()}_time_{current_time}.log"
         self.log_worker_queue = queue.Queue()
         self.log_worker: threading.Thread = None
+        self._device_smi_handles = self._init_device_smi_handles()
+        self._cpu_device_smi = self._init_cpu_device_handle()
+        self._device_metric_failures: Set[str] = set()

         if self.logger_board == "clearml":
             try:
@@ -224,6 +227,95 @@ def log_new_row(self, stat):
         log.info(formatted_row)
         log.info(len(formatted_row) * "-")

+    def _init_device_smi_handles(self) -> Dict[str, Device]:
+        handles: Dict[str, Device] = {}
+
+        for device_id in self._discover_accelerator_devices():
+            try:
+                handles[device_id] = Device(device_id)
+            except Exception as exc:  # pragma: no cover - defensive, external tool
+                log.debug(f"Device-SMI initialisation failed for `{device_id}`: {exc}")
+
+        return handles
+
+    def _init_cpu_device_handle(self) -> Optional[Device]:
+        try:
+            return Device("cpu")
+        except Exception as exc:  # pragma: no cover - defensive, external tool
+            log.debug(f"Device-SMI CPU initialisation failed: {exc}")
+            return None
+
+    def _discover_accelerator_devices(self) -> List[str]:
+        devices: List[str] = []
+
+        if hasattr(torch, "cuda"):
+            try:
+                if torch.cuda.is_available():
+                    device_type = "rocm" if getattr(torch.version, "hip", None) else "cuda"
+                    for idx in range(torch.cuda.device_count()):
+                        devices.append(f"{device_type}:{idx}")
+            except Exception:  # pragma: no cover - defensive, CUDA runtime differences
+                pass
+
+        xpu = getattr(torch, "xpu", None)
+        if xpu is not None:
+            try:
+                if torch.xpu.is_available():
+                    for idx in range(torch.xpu.device_count()):
+                        devices.append(f"xpu:{idx}")
+            except Exception:  # pragma: no cover - defensive, XPU runtime differences
+                pass
+
+        return devices
+
+    def _safe_query_metric(self, device_key: str, handle: Device):
+        try:
+            return handle.metrics(fast=True)
+        except Exception as exc:  # pragma: no cover - defensive, external tool
+            if device_key not in self._device_metric_failures:
+                log.debug(f"Device-SMI metrics failed for `{device_key}`: {exc}")
+                self._device_metric_failures.add(device_key)
+            return None
+
+    def _snapshot_device_memory_gib(self) -> Dict[str, float]:
+        snapshot: Dict[str, float] = {}
+        for device_id, handle in self._device_smi_handles.items():
+            metrics = self._safe_query_metric(device_id, handle)
+            if metrics is None:
+                continue
+            snapshot[device_id] = metrics.memory_used / (1024 ** 3)
+        return snapshot
+
+    def _snapshot_cpu_memory_gib(self) -> Optional[float]:
+        if self._cpu_device_smi is None:
+            return None
+        metrics = self._safe_query_metric("cpu", self._cpu_device_smi)
+        if metrics is None:
+            return None
+        return metrics.memory_used / (1024 ** 3)
+
+    def device_memory_report(self) -> str:
+        snapshot = self._snapshot_device_memory_gib()
+        if not snapshot:
+            return "n/a"
+        parts = [f"{device_id}={value:.1f}GB" for device_id, value in snapshot.items()]
+        return ", ".join(parts)
+
+    def _close_device_smi_handles(self) -> None:
+        for handle in self._device_smi_handles.values():
+            try:
+                handle.close()
+            except Exception:
+                pass
+        self._device_smi_handles.clear()
+
+        if self._cpu_device_smi is not None:
+            try:
+                self._cpu_device_smi.close()
+            except Exception:
+                pass
+            self._cpu_device_smi = None
+
     # Loop Procssor level scoped state data
     def result_save(self, key: str, value: Any):
         with self._results_lock:
@@ -250,22 +342,25 @@ def results(self):

     def collect_memory_info(self, layer_index: int):
         if self.logger_task is not None:
-            gpu_memory = get_gpu_usage_memory()
-            cpu_memory = get_cpu_usage_memory()
+            device_snapshot = self._snapshot_device_memory_gib()
+            total_gpu_memory = sum(device_snapshot.values()) if device_snapshot else 0.0
+
             self.logger_task.get_logger().report_scalar(
                 title='GPU Memory',
                 series='GPU Memory',
-                value=gpu_memory,
+                value=total_gpu_memory,
                 iteration=layer_index,
             )
+
+            cpu_memory = self._snapshot_cpu_memory_gib() or 0.0
+
             self.logger_task.get_logger().report_scalar(
                 title='CPU Memory',
                 series='CPU Memory',
                 value=cpu_memory,
                 iteration=layer_index,
             )
-            self.gpu_memorys.append(gpu_memory)
+            self.gpu_memorys.append(total_gpu_memory)
             self.cpu_memorys.append(cpu_memory)

     def log_plotly(self):
@@ -322,6 +417,7 @@ def submodule_finalize(self, module: NamedModule, model: BaseQModel, **kwargs):
     # last step, after all loop processor is called
     # finalize is called in reverse after all next sequential processes are called
     def finalize(self, model: BaseQModel, **kwargs):
+        self._close_device_smi_handles()
         del self.inputs_cache
         del self._results

diff --git a/gptqmodel/models/writer.py b/gptqmodel/models/writer.py
index 597a34d24..393103c64 100644
--- a/gptqmodel/models/writer.py
+++ b/gptqmodel/models/writer.py
@@ -67,7 +67,7 @@
 QUANT_LOG_DAMP = "damp"
 PROCESS_LOG_TIME = "time"
 PROCESS_LOG_FWD_TIME = "fwd_time"
-PROCESS_MAX_MEMORY = "max_vram"
+PROCESS_USED_MEMORY = "(v)ram"

 EORA_DEFAULT_FILE = "eora.safetensors"

diff --git a/pyproject.toml b/pyproject.toml
index c12194e83..bb98f6fc8 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,7 @@
 [build-system]
 requires = [
-    "setuptools >= 64",
+    "setuptools>=80.9.0",
+    "ninja>=1.13.0"
 ]
 build-backend = "setuptools.build_meta:__legacy__"

@@ -37,17 +38,15 @@ dependencies = [
     "transformers>=4.56.0",
     "threadpoolctl>=3.6.0",
     "packaging>=24.2",
-    "device-smi==0.4.1",
+    "device-smi>=0.5.0",
     "protobuf>=6.32.0",
     "pillow>=11.3.0",
     "hf_transfer>=0.1.9",
     "huggingface_hub>=0.34.4",
-    "random_word==1.0.13",
+    "random_word>=1.0.13",
     "tokenicer>=0.0.5",
     "logbar>=0.0.5",
-    # "soundfile==0.13.1",  # Qwen-Omni dependent pkg
-    "wheel>=0.45.1",
-    "maturin>=1.9.3",
+    # "flash-attn>=2.8.3", <-- install for lower vram usage
 ]

 [project.urls]
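
Note: the device_memory_report() path added to loop_processor.py above can be exercised on its own. The snippet below is an illustrative sketch, not part of the patch; it assumes device-smi>=0.5.0 is installed and a CUDA device is visible, and uses only the device-smi calls the patch itself relies on (Device(...), metrics(fast=True), memory_used):

    import torch
    from device_smi import Device

    def report() -> str:
        # Mirrors _snapshot_device_memory_gib() + device_memory_report():
        # one "device=used" entry per visible accelerator, in GiB.
        parts = []
        if torch.cuda.is_available():
            for idx in range(torch.cuda.device_count()):
                try:
                    metrics = Device(f"cuda:{idx}").metrics(fast=True)
                except Exception:
                    continue  # same defensive posture as the patch: skip unreadable devices
                parts.append(f"cuda:{idx}={metrics.memory_used / (1024 ** 3):.1f}GB")
        return ", ".join(parts) or "n/a"

    print(report())  # e.g. "cuda:0=3.2GB, cuda:1=1.8GB" -- roughly what PROCESS_USED_MEMORY now logs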