diff --git a/cuda_bindings/cuda/bindings/_internal/cufile_linux.pyx b/cuda_bindings/cuda/bindings/_internal/cufile_linux.pyx index 33b638464..2f4580d79 100644 --- a/cuda_bindings/cuda/bindings/_internal/cufile_linux.pyx +++ b/cuda_bindings/cuda/bindings/_internal/cufile_linux.pyx @@ -111,12 +111,15 @@ cdef void* load_library() except* with gil: return handle -cdef int __check_or_init_cufile() except -1 nogil: +cdef int _init_cufile() except -1 nogil: global __py_cufile_init cdef void* handle = NULL with gil, __symbol_lock: + # Recheck the flag after obtaining the locks + if __py_cufile_init: + return 0 # Load function global __cuFileHandleRegister __cuFileHandleRegister = dlsym(RTLD_DEFAULT, 'cuFileHandleRegister') @@ -427,7 +430,7 @@ cdef inline int _check_or_init_cufile() except -1 nogil: if __py_cufile_init: return 0 - return __check_or_init_cufile() + return _init_cufile() cdef dict func_ptrs = None diff --git a/cuda_bindings/cuda/bindings/_internal/nvjitlink_linux.pyx b/cuda_bindings/cuda/bindings/_internal/nvjitlink_linux.pyx index a8e2b4e56..ccc412b0f 100644 --- a/cuda_bindings/cuda/bindings/_internal/nvjitlink_linux.pyx +++ b/cuda_bindings/cuda/bindings/_internal/nvjitlink_linux.pyx @@ -80,12 +80,16 @@ cdef void* load_library() except* with gil: return handle -cdef int __check_or_init_nvjitlink() except -1 nogil: +cdef int _init_nvjitlink() except -1 nogil: global __py_nvjitlink_init cdef void* handle = NULL with gil, __symbol_lock: + # Recheck the flag after obtaining the locks + if __py_nvjitlink_init: + return 0 + # Load function global __nvJitLinkCreate __nvJitLinkCreate = dlsym(RTLD_DEFAULT, 'nvJitLinkCreate') @@ -193,7 +197,7 @@ cdef inline int _check_or_init_nvjitlink() except -1 nogil: if __py_nvjitlink_init: return 0 - return __check_or_init_nvjitlink() + return _init_nvjitlink() cdef dict func_ptrs = None diff --git a/cuda_bindings/cuda/bindings/_internal/nvjitlink_windows.pyx b/cuda_bindings/cuda/bindings/_internal/nvjitlink_windows.pyx index 5b26cba4a..1b88b9989 100644 --- a/cuda_bindings/cuda/bindings/_internal/nvjitlink_windows.pyx +++ b/cuda_bindings/cuda/bindings/_internal/nvjitlink_windows.pyx @@ -93,10 +93,14 @@ cdef void* __nvJitLinkGetInfoLog = NULL cdef void* __nvJitLinkVersion = NULL -cdef int __check_or_init_nvjitlink() except -1 nogil: +cdef int _init_nvjitlink() except -1 nogil: global __py_nvjitlink_init with gil, __symbol_lock: + # Recheck the flag after obtaining the locks + if __py_nvjitlink_init: + return 0 + # Load library handle = load_nvidia_dynamic_lib("nvJitLink")._handle_uint @@ -151,7 +155,7 @@ cdef inline int _check_or_init_nvjitlink() except -1 nogil: if __py_nvjitlink_init: return 0 - return __check_or_init_nvjitlink() + return _init_nvjitlink() cdef dict func_ptrs = None diff --git a/cuda_bindings/cuda/bindings/_internal/nvvm_linux.pyx b/cuda_bindings/cuda/bindings/_internal/nvvm_linux.pyx index 64d97334b..e1addcc9e 100644 --- a/cuda_bindings/cuda/bindings/_internal/nvvm_linux.pyx +++ b/cuda_bindings/cuda/bindings/_internal/nvvm_linux.pyx @@ -79,12 +79,16 @@ cdef void* load_library() except* with gil: return handle -cdef int __check_or_init_nvvm() except -1 nogil: +cdef int _init_nvvm() except -1 nogil: global __py_nvvm_init cdef void* handle = NULL with gil, __symbol_lock: + # Recheck the flag after obtaining the locks + if __py_nvvm_init: + return 0 + # Load function global __nvvmGetErrorString __nvvmGetErrorString = dlsym(RTLD_DEFAULT, 'nvvmGetErrorString') @@ -185,7 +189,7 @@ cdef inline int _check_or_init_nvvm() except -1 nogil: if __py_nvvm_init: return 0 - return __check_or_init_nvvm() + return _init_nvvm() cdef dict func_ptrs = None diff --git a/cuda_bindings/cuda/bindings/_internal/nvvm_windows.pyx b/cuda_bindings/cuda/bindings/_internal/nvvm_windows.pyx index a1d7dfbd1..de3e789a4 100644 --- a/cuda_bindings/cuda/bindings/_internal/nvvm_windows.pyx +++ b/cuda_bindings/cuda/bindings/_internal/nvvm_windows.pyx @@ -92,10 +92,14 @@ cdef void* __nvvmGetProgramLogSize = NULL cdef void* __nvvmGetProgramLog = NULL -cdef int __check_or_init_nvvm() except -1 nogil: +cdef int _init_nvvm() except -1 nogil: global __py_nvvm_init with gil, __symbol_lock: + # Recheck the flag after obtaining the locks + if __py_nvvm_init: + return 0 + # Load library handle = load_nvidia_dynamic_lib("nvvm")._handle_uint @@ -147,7 +151,7 @@ cdef inline int _check_or_init_nvvm() except -1 nogil: if __py_nvvm_init: return 0 - return __check_or_init_nvvm() + return _init_nvvm() cdef dict func_ptrs = None diff --git a/cuda_bindings/cuda/bindings/cufile.pyx b/cuda_bindings/cuda/bindings/cufile.pyx index f8b3e360a..a4e5c2399 100644 --- a/cuda_bindings/cuda/bindings/cufile.pyx +++ b/cuda_bindings/cuda/bindings/cufile.pyx @@ -12,6 +12,7 @@ import numpy as _numpy from cpython cimport buffer as _buffer from cpython.memoryview cimport PyMemoryView_FromMemory from enum import IntEnum as _IntEnum +cimport cpython import cython @@ -54,6 +55,10 @@ cdef class _py_anon_pod1: """Get the pointer address to the data as Python :class:`int`.""" return self._data.ctypes.data + cdef intptr_t _get_ptr(self): + """Get the pointer address to the data as Python :class:`int`.""" + return self._data.ctypes.data + def __int__(self): return self._data.ctypes.data @@ -157,6 +162,10 @@ cdef class _py_anon_pod3: """Get the pointer address to the data as Python :class:`int`.""" return self._data.ctypes.data + cdef intptr_t _get_ptr(self): + """Get the pointer address to the data as Python :class:`int`.""" + return self._data.ctypes.data + def __int__(self): return self._data.ctypes.data @@ -286,6 +295,10 @@ cdef class IOEvents: """Get the pointer address to the data as Python :class:`int`.""" return self._data.ctypes.data + cdef intptr_t _get_ptr(self): + """Get the pointer address to the data as Python :class:`int`.""" + return self._data.ctypes.data + def __int__(self): if self._data.size > 1: raise TypeError("int() argument must be a bytes-like object of size 1. " @@ -422,6 +435,10 @@ cdef class OpCounter: """Get the pointer address to the data as Python :class:`int`.""" return self._data.ctypes.data + cdef intptr_t _get_ptr(self): + """Get the pointer address to the data as Python :class:`int`.""" + return self._data.ctypes.data + def __int__(self): return self._data.ctypes.data @@ -551,6 +568,10 @@ cdef class PerGpuStats: """Get the pointer address to the data as Python :class:`int`.""" return self._data.ctypes.data + cdef intptr_t _get_ptr(self): + """Get the pointer address to the data as Python :class:`int`.""" + return self._data.ctypes.data + def __int__(self): return self._data.ctypes.data @@ -914,6 +935,10 @@ cdef class Descr: """Get the pointer address to the data as Python :class:`int`.""" return self._data.ctypes.data + cdef intptr_t _get_ptr(self): + """Get the pointer address to the data as Python :class:`int`.""" + return self._data.ctypes.data + def __int__(self): if self._data.size > 1: raise TypeError("int() argument must be a bytes-like object of size 1. " @@ -1052,6 +1077,10 @@ cdef class _py_anon_pod2: """Get the pointer address to the data as Python :class:`int`.""" return self._data.ctypes.data + cdef intptr_t _get_ptr(self): + """Get the pointer address to the data as Python :class:`int`.""" + return self._data.ctypes.data + def __int__(self): return self._data.ctypes.data @@ -1185,6 +1214,10 @@ cdef class StatsLevel1: """Get the pointer address to the data as Python :class:`int`.""" return self._data.ctypes.data + cdef intptr_t _get_ptr(self): + """Get the pointer address to the data as Python :class:`int`.""" + return self._data.ctypes.data + def __int__(self): return self._data.ctypes.data @@ -1667,6 +1700,10 @@ cdef class IOParams: """Get the pointer address to the data as Python :class:`int`.""" return self._data.ctypes.data + cdef intptr_t _get_ptr(self): + """Get the pointer address to the data as Python :class:`int`.""" + return self._data.ctypes.data + def __int__(self): if self._data.size > 1: raise TypeError("int() argument must be a bytes-like object of size 1. " @@ -1824,6 +1861,10 @@ cdef class StatsLevel2: """Get the pointer address to the data as Python :class:`int`.""" return self._data.ctypes.data + cdef intptr_t _get_ptr(self): + """Get the pointer address to the data as Python :class:`int`.""" + return self._data.ctypes.data + def __int__(self): return self._data.ctypes.data @@ -1935,6 +1976,10 @@ cdef class StatsLevel3: """Get the pointer address to the data as Python :class:`int`.""" return self._data.ctypes.data + cdef intptr_t _get_ptr(self): + """Get the pointer address to the data as Python :class:`int`.""" + return self._data.ctypes.data + def __int__(self): return self._data.ctypes.data @@ -2458,7 +2503,7 @@ cpdef str get_parameter_string(int param, int len): with nogil: __status__ = cuFileGetParameterString(<_StringConfigParameter>param, desc_str, len) check_status(__status__) - return _desc_str_.decode() + return cpython.PyUnicode_FromString(desc_str) cpdef set_parameter_size_t(int param, size_t value): diff --git a/cuda_bindings/tests/test_cufile.py b/cuda_bindings/tests/test_cufile.py index 3716e2bec..8ac12dfc7 100644 --- a/cuda_bindings/tests/test_cufile.py +++ b/cuda_bindings/tests/test_cufile.py @@ -121,20 +121,6 @@ def isSupportedFilesystem(): pytestmark = pytest.mark.skipif(not cufileLibraryAvailable(), reason="cuFile library not available on this system") -def safe_decode_string(raw_value): - """Safely decode a string value from ctypes buffer.""" - # Find null terminator if present - null_pos = raw_value.find(b"\x00") - if null_pos != -1: - raw_value = raw_value[:null_pos] - # Decode with error handling - try: - return raw_value.decode("utf-8", errors="ignore") - except UnicodeDecodeError: - # If UTF-8 fails, try to decode as bytes - return str(raw_value) - - def test_cufile_success_defined(): """Check if CUFILE_SUCCESS is defined in OpError enum.""" assert hasattr(cufile.OpError, "SUCCESS") @@ -1774,8 +1760,6 @@ def test_set_get_parameter_string(tmp_path): def test_param(param, val, default_val): orig_val = cufile.get_parameter_string(param, 256) - # Use safe_decode_string to handle null terminators and padding - orig_val = safe_decode_string(orig_val.encode("utf-8")) val_b = val.encode("utf-8") val_buf = ctypes.create_string_buffer(val_b) @@ -1787,7 +1771,6 @@ def test_param(param, val, default_val): # Round-trip test cufile.set_parameter_string(param, int(ctypes.addressof(val_buf))) retrieved_val = cufile.get_parameter_string(param, 256) - retrieved_val = safe_decode_string(retrieved_val.encode("utf-8")) assert retrieved_val == val # Restore