diff --git a/nvitop/core/libnvml.py b/nvitop/core/libnvml.py index f1895142..5d4c3fc3 100644 --- a/nvitop/core/libnvml.py +++ b/nvitop/core/libnvml.py @@ -5,8 +5,11 @@ # pylint: disable=invalid-name +import ctypes as _ctypes +import functools as _functools import inspect as _inspect import logging as _logging +import os as _os import re as _re import sys as _sys import threading as _threading @@ -49,7 +52,6 @@ # Load members from module `pynvml` and register them in `__all__` and globals. _vars_pynvml = vars(_pynvml) -_vars = _OrderedDict() _name = _attr = None _errcode_to_name = {} _const_names = [] @@ -158,6 +160,13 @@ __lock = _threading.Lock() LOGGER = _logging.getLogger(__name__) +try: + LOGGER.setLevel(_os.getenv('LOGLEVEL', default='WARNING').upper()) +except (ValueError, TypeError): + pass +if not LOGGER.hasHandlers() and LOGGER.isEnabledFor(_logging.INFO): + LOGGER.addHandler(_logging.StreamHandler()) + LOGGER.addHandler(_logging.FileHandler('debug-{}.log'.format(__name__))) UNKNOWN_FUNCTIONS = {} UNKNOWN_FUNCTIONS_CACHE_SIZE = 1024 VERSIONED_PATTERN = _re.compile(r'^(?P\w+)(?P_v(\d)+)$') @@ -368,7 +377,133 @@ def nvmlCheckReturn( return retval != NA and isinstance(retval, types) -# Add support for lookup fallback and context manager. 
# Patch layers for backward compatibility ##########################################################
def __patch_backward_compatibility_layers() -> None:
    """Install fallback layers over ``pynvml``'s private function-pointer lookup.

    Two kinds of wrappers are stacked on ``_pynvml._nvmlGetFunctionPointer``:

    1. A name-translation layer that redirects every requested name through a
       shared alias table before the real lookup happens.
    2. Failure handlers for the ``nvmlDeviceGet*RunningProcesses`` family. When a
       ``_v3`` (or ``_v2``) lookup raises ``NVMLError_FunctionNotFound``, the
       handler records aliases to the next older name, replaces
       ``c_nvmlProcessInfo_t`` with the matching struct layout, and retries.
    """
    alias_lock = _threading.Lock()
    alias_table = {}  # requested function name -> name that is actually looked up

    def function_mapping_update(mapping):
        """Merge ``mapping`` into the alias table and return the merged copy.

        Aliases already in the table whose target is a key of the new mapping are
        re-pointed at that key's value, so a chain such as v3 -> v2 followed by
        v2 -> v1 ends up as v3 -> v1.
        """
        merged = dict(mapping)  # never mutate the caller's dictionary
        with alias_lock:
            for alias, target in alias_table.items():
                if target in merged:
                    merged[alias] = merged[target]
            alias_table.update(merged)
        return merged

    def with_mapped_function_name():
        """Wrap the lookup so that requested names go through the alias table first."""
        lookup = _pynvml._nvmlGetFunctionPointer  # pylint: disable=protected-access

        @_functools.wraps(lookup)
        def translated(name):
            return lookup(alias_table.get(name, name))

        # Stored through the module `__dict__`: per the original note at this spot,
        # going through `module.__setattr__` will not work.
        _pynvml.__dict__.update(_nvmlGetFunctionPointer=translated)

    def patch_function_pointers_when_fail(names, callback):
        """Build a decorator that retries a failed lookup under a replacement name.

        The decorated lookup calls the wrapped one. If that raises
        ``NVMLError_FunctionNotFound`` for a name contained in ``names``, the
        replacement name returned by ``callback`` is looked up instead. Failures
        for any other name are re-raised unchanged.
        """

        def decorate(lookup):
            @_functools.wraps(lookup)
            def guarded(name):
                try:
                    return lookup(name)
                except NVMLError_FunctionNotFound as error:
                    if name not in names:
                        raise
                    # The retry deliberately stays inside the handler, exactly where
                    # the first failure is still the active exception.
                    replacement = callback(name, names, error, _pynvml, __modself)
                    return lookup(replacement)

            return guarded

        return decorate

    def patch_process_info():
        """Install the v3 -> v2 -> unversioned fallbacks for the running-process queries."""
        struct_base = _pynvml._PrintableStructure  # pylint: disable=protected-access

        # Layout paired with the unversioned fallbacks: pid and used memory only.
        # pylint: disable-next=missing-class-docstring,too-few-public-methods
        class c_nvmlProcessInfo_v1_t(struct_base):
            _fields_ = [
                ('pid', _ctypes.c_uint),
                ('usedGpuMemory', _ctypes.c_ulonglong),
            ]
            _fmt_ = {
                'usedGpuMemory': '%d B',
            }

        # Layout paired with the `_v2` fallbacks: adds the two instance-id fields.
        # pylint: disable-next=missing-class-docstring,too-few-public-methods
        class c_nvmlProcessInfo_v2_t(struct_base):
            _fields_ = [
                ('pid', _ctypes.c_uint),
                ('usedGpuMemory', _ctypes.c_ulonglong),
                ('gpuInstanceId', _ctypes.c_uint),
                ('computeInstanceId', _ctypes.c_uint),
            ]
            _fmt_ = {
                'usedGpuMemory': '%d B',
            }

        v3_to_v2 = {
            'nvmlDeviceGetComputeRunningProcesses_v3': 'nvmlDeviceGetComputeRunningProcesses_v2',
            'nvmlDeviceGetGraphicsRunningProcesses_v3': 'nvmlDeviceGetGraphicsRunningProcesses_v2',
            'nvmlDeviceGetMPSComputeRunningProcesses_v3': 'nvmlDeviceGetMPSComputeRunningProcesses_v2',
        }
        v2_to_v1 = {
            'nvmlDeviceGetComputeRunningProcesses_v2': 'nvmlDeviceGetComputeRunningProcesses',
            'nvmlDeviceGetGraphicsRunningProcesses_v2': 'nvmlDeviceGetGraphicsRunningProcesses',
            'nvmlDeviceGetMPSComputeRunningProcesses_v2': 'nvmlDeviceGetMPSComputeRunningProcesses',
        }
        # Checked in this order: each alias table comes with the struct layout that
        # is installed when one of its names has to fall back.
        fallback_tables = (
            (v3_to_v2, c_nvmlProcessInfo_v2_t),
            (v2_to_v1, c_nvmlProcessInfo_v1_t),
        )

        def callback(name, names, exception, pynvml, modself):  # pylint: disable=unused-argument
            """Register the fallback for ``name`` and return the name to retry with."""
            for table, struct_type in fallback_tables:
                if name in table:
                    break
            else:
                raise exception  # no fallbacks for v1 APIs

            LOGGER.debug('Patching NVML function pointer `%s`', name)
            mapping = function_mapping_update(table)
            # Both modules expose `c_nvmlProcessInfo_t`; keep them pointing at the same type.
            pynvml.__dict__.update(c_nvmlProcessInfo_t=struct_type)
            modself.__dict__.update(c_nvmlProcessInfo_t=struct_type)

            for old_name, mapped_name in mapping.items():
                LOGGER.debug(' Map NVML function `%s` to `%s`', old_name, mapped_name)
            LOGGER.debug(' Patch NVML struct `c_nvmlProcessInfo_t` to `%s`', struct_type.__name__)
            return mapping[name]

        # The patching ordering is important: the v2 handler wraps the current lookup
        # first and the v3 handler goes on top, so a v3 retry (which asks for a `_v2`
        # name) still passes through the v2 handler.
        lookup = _pynvml._nvmlGetFunctionPointer  # pylint: disable=protected-access
        lookup = patch_function_pointers_when_fail(names=set(v2_to_v1), callback=callback)(lookup)
        lookup = patch_function_pointers_when_fail(names=set(v3_to_v2), callback=callback)(lookup)
        # Stored through the module `__dict__`: per the original note at this spot,
        # going through `module.__setattr__` will not work.
        _pynvml.__dict__.update(_nvmlGetFunctionPointer=lookup)

    with_mapped_function_name()  # patch first and only for once
    patch_process_info()


__patch_backward_compatibility_layers()
del __patch_backward_compatibility_layers


# Add support for lookup fallback and context manager ##############################################