Skip to content

Commit

Permalink
feat(core/libnvml): add compatibility layers for process info
Browse files Browse the repository at this point in the history
Signed-off-by: Xuehai Pan <XuehaiPan@pku.edu.cn>
  • Loading branch information
XuehaiPan committed Jul 25, 2022
1 parent 7cb90e5 commit 70dc9a7
Showing 1 changed file with 139 additions and 3 deletions.
142 changes: 139 additions & 3 deletions nvitop/core/libnvml.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,11 @@

# pylint: disable=invalid-name

import ctypes as _ctypes
import functools as _functools
import inspect as _inspect
import logging as _logging
import os as _os
import re as _re
import sys as _sys
import threading as _threading
Expand Down Expand Up @@ -49,7 +52,6 @@

# Load members from module `pynvml` and register them in `__all__` and globals.
_vars_pynvml = vars(_pynvml)
_vars = _OrderedDict()
_name = _attr = None
_errcode_to_name = {}
_const_names = []
Expand Down Expand Up @@ -158,6 +160,13 @@
# NOTE(review): the code that acquires this lock is outside this excerpt.
__lock = _threading.Lock()

# Module logger. Its level is taken from the `LOGLEVEL` environment variable
# (case-insensitive level name, `WARNING` when unset).
LOGGER = _logging.getLogger(__name__)
try:
    LOGGER.setLevel(_os.getenv('LOGLEVEL', default='WARNING').upper())
except (ValueError, TypeError):
    # Unrecognized level name in `LOGLEVEL`: deliberately keep the logger's current level.
    pass
# Only attach handlers when nothing up the logger hierarchy already handles records
# and the user asked for INFO (or more verbose) output.
if not LOGGER.hasHandlers() and LOGGER.isEnabledFor(_logging.INFO):
    LOGGER.addHandler(_logging.StreamHandler())
    try:
        # `FileHandler` opens the file immediately, relative to the current working directory.
        LOGGER.addHandler(_logging.FileHandler('debug-{}.log'.format(__name__)))
    except OSError as ex:
        # An unwritable working directory must not make importing this module fail;
        # the stream handler added above still reports the problem.
        LOGGER.warning('Cannot open debug log file for `%s`: %s', __name__, ex)
# NOTE(review): presumably a cache of looked-up names that NVML does not provide, bounded by
# the size below — the code that fills it is outside this excerpt; confirm.
UNKNOWN_FUNCTIONS = {}
UNKNOWN_FUNCTIONS_CACHE_SIZE = 1024
# Matches identifiers that end in `_v<digits>`, split into the groups `name` and `suffix`.
VERSIONED_PATTERN = _re.compile(r'^(?P<name>\w+)(?P<suffix>_v(\d)+)$')
Expand Down Expand Up @@ -368,7 +377,133 @@ def nvmlCheckReturn(
return retval != NA and isinstance(retval, types)


# Add support for lookup fallback and context manager.
# Patch layers for backward compatibility ##########################################################
def __patch_backward_compatibility_layers() -> None:
    """Install fallback wrappers around ``pynvml``'s function-pointer lookup.

    ``_pynvml._nvmlGetFunctionPointer`` is replaced by a chain of wrappers. When a lookup of
    one of the versioned ``nvmlDeviceGet*RunningProcesses`` functions raises
    ``NVMLError_FunctionNotFound``, the lookup is retried under the next older name and
    ``c_nvmlProcessInfo_t`` (in both ``pynvml`` and this module) is rebound to the struct
    layout that goes with that older function.
    """
    # State shared by the nested helpers: requested function name -> name to look up instead.
    # Updates happen under the lock; see `function_mapping_update`.
    function_name_mapping_lock = _threading.Lock()
    function_name_mapping = {}

    def function_mapping_update(mapping):
        """Merge ``mapping`` into the shared name mapping and return the merged copy.

        Entries already recorded whose target is itself remapped by ``mapping`` are re-pointed
        at the new target, so a chain of fallbacks collapses into a single hop. The caller's
        ``mapping`` object is copied first and is not mutated.
        """
        with function_name_mapping_lock:
            mapping = dict(mapping)
            for name, mapped_name in function_name_mapping.items():
                if mapped_name in mapping:
                    mapping[name] = mapping[mapped_name]
            function_name_mapping.update(mapping)
        return mapping

    def with_mapped_function_name():
        """Wrap ``_pynvml._nvmlGetFunctionPointer`` to translate names before the lookup.

        Names without an entry in ``function_name_mapping`` are passed through unchanged.
        """

        def wrapper(nvmlGetFunctionPointer):
            @_functools.wraps(nvmlGetFunctionPointer)
            def wrapped(name):
                mapped_name = function_name_mapping.get(name, name)
                return nvmlGetFunctionPointer(mapped_name)

            return wrapped

        _pynvml.__dict__.update(  # need to use module.__dict__.__setitem__ because module.__setattr__ will not work
            _nvmlGetFunctionPointer=wrapper(
                _pynvml._nvmlGetFunctionPointer  # pylint: disable=protected-access
            )
        )

    def patch_function_pointers_when_fail(names, callback):
        """Build a decorator that retries a failed function-pointer lookup under a fallback name.

        Args:
            names: Function names that are eligible for a fallback.
            callback: Called as ``callback(name, names, ex, _pynvml, __modself)`` when the lookup
                of an eligible name raises ``NVMLError_FunctionNotFound``; must return the name
                to look up instead.

        Returns:
            A decorator for a ``nvmlGetFunctionPointer``-like callable. Lookups of names not in
            ``names`` re-raise the original ``NVMLError_FunctionNotFound``.
        """

        def wrapper(nvmlGetFunctionPointer):
            @_functools.wraps(nvmlGetFunctionPointer)
            def wrapped(name):
                try:
                    return nvmlGetFunctionPointer(name)
                except NVMLError_FunctionNotFound as ex:
                    if name in names:
                        new_name = callback(name, names, ex, _pynvml, __modself)
                        # Retry goes through the wrapped (inner) lookup, not through `wrapped`.
                        return nvmlGetFunctionPointer(new_name)
                    raise

            return wrapped

        return wrapper

    def patch_process_info():
        """Add the v3 -> v2 -> unversioned fallbacks for the running-processes functions."""
        PrintableStructure = _pynvml._PrintableStructure  # pylint: disable=protected-access

        # pylint: disable-next=missing-class-docstring,too-few-public-methods
        class c_nvmlProcessInfo_v1_t(PrintableStructure):
            # Struct selected by `callback` when falling back to the unversioned functions.
            _fields_ = [
                ('pid', _ctypes.c_uint),
                ('usedGpuMemory', _ctypes.c_ulonglong),
            ]
            _fmt_ = {
                'usedGpuMemory': '%d B',
            }

        # pylint: disable-next=missing-class-docstring,too-few-public-methods
        class c_nvmlProcessInfo_v2_t(PrintableStructure):
            # Struct selected by `callback` when falling back to the `_v2` functions.
            _fields_ = [
                ('pid', _ctypes.c_uint),
                ('usedGpuMemory', _ctypes.c_ulonglong),
                ('gpuInstanceId', _ctypes.c_uint),
                ('computeInstanceId', _ctypes.c_uint),
            ]
            _fmt_ = {
                'usedGpuMemory': '%d B',
            }

        # Fallback tables: missing function name -> older function name to try instead.
        nvmlDeviceGetRunningProcesses_v3_v2 = {
            'nvmlDeviceGetComputeRunningProcesses_v3': 'nvmlDeviceGetComputeRunningProcesses_v2',
            'nvmlDeviceGetGraphicsRunningProcesses_v3': 'nvmlDeviceGetGraphicsRunningProcesses_v2',
            'nvmlDeviceGetMPSComputeRunningProcesses_v3': 'nvmlDeviceGetMPSComputeRunningProcesses_v2',
        }
        nvmlDeviceGetRunningProcesses_v2_v1 = {
            'nvmlDeviceGetComputeRunningProcesses_v2': 'nvmlDeviceGetComputeRunningProcesses',
            'nvmlDeviceGetGraphicsRunningProcesses_v2': 'nvmlDeviceGetGraphicsRunningProcesses',
            'nvmlDeviceGetMPSComputeRunningProcesses_v2': 'nvmlDeviceGetMPSComputeRunningProcesses',
        }

        def callback(name, names, exception, pynvml, modself):  # pylint: disable=unused-argument
            """Record the fallback for ``name`` and return the replacement function name.

            The whole fallback table containing ``name`` is registered at once (not just the
            single entry), and ``c_nvmlProcessInfo_t`` is rebound in both ``pynvml`` and
            ``modself``. Re-raises ``exception`` when ``name`` is in neither table.
            """
            if name in nvmlDeviceGetRunningProcesses_v3_v2:
                mapping = nvmlDeviceGetRunningProcesses_v3_v2
                struct_type = c_nvmlProcessInfo_v2_t
            elif name in nvmlDeviceGetRunningProcesses_v2_v1:
                mapping = nvmlDeviceGetRunningProcesses_v2_v1
                struct_type = c_nvmlProcessInfo_v1_t
            else:
                raise exception  # no fallbacks for v1 APIs

            LOGGER.debug('Patching NVML function pointer `%s`', name)
            # The returned copy also contains earlier entries that were re-pointed by this update.
            mapping = function_mapping_update(mapping)
            pynvml.__dict__.update(c_nvmlProcessInfo_t=struct_type)
            modself.__dict__.update(c_nvmlProcessInfo_t=struct_type)

            for old_name, mapped_name in mapping.items():
                LOGGER.debug(' Map NVML function `%s` to `%s`', old_name, mapped_name)
            LOGGER.debug(
                ' Patch NVML struct `c_nvmlProcessInfo_t` to `%s`', struct_type.__name__
            )
            return mapping[name]

        _pynvml.__dict__.update(  # need to use module.__dict__.__setitem__ because module.__setattr__ will not work
            # The patching ordering is important
            # (the v3 -> v2 wrapper is outermost, so its retry under a `_v2` name passes through
            # the v2 -> v1 wrapper, which can fall back once more)
            _nvmlGetFunctionPointer=patch_function_pointers_when_fail(
                names=set(nvmlDeviceGetRunningProcesses_v3_v2), callback=callback
            )(
                patch_function_pointers_when_fail(
                    names=set(nvmlDeviceGetRunningProcesses_v2_v1), callback=callback
                )(
                    _pynvml._nvmlGetFunctionPointer  # pylint: disable=protected-access
                )
            )
        )

    # The name-mapping wrapper is installed first, so the fallback wrappers added by
    # `patch_process_info` sit outside it.
    with_mapped_function_name()  # patch first and only for once
    patch_process_info()


# Apply the patches once at import time, then drop the helper from the module namespace.
__patch_backward_compatibility_layers()
del __patch_backward_compatibility_layers


# Add support for lookup fallback and context manager ##############################################
class _CustomModule(_ModuleType):
"""Modified module type to support lookup fallback and context manager.
Expand Down Expand Up @@ -416,6 +551,7 @@ def __del__(self) -> None:
__modself.__class__ = _CustomModule
del _CustomModule

del _inspect, _logging, _re, _sys, _threading
# Delete imported references
del _inspect, _logging, _os, _re, _sys, _threading
del _OrderedDict, _FunctionType, _ModuleType
del _Tuple, _Callable, _Type, _Union, _Optional, _Any

0 comments on commit 70dc9a7

Please sign in to comment.