Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions cuda_bindings/cuda/bindings/_internal/cufile_linux.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -111,12 +111,15 @@ cdef void* load_library() except* with gil:
return <void*>handle


cdef int __check_or_init_cufile() except -1 nogil:
cdef int _init_cufile() except -1 nogil:
global __py_cufile_init

cdef void* handle = NULL

with gil, __symbol_lock:
# Recheck the flag after obtaining the locks
if __py_cufile_init:
return 0
# Load function
global __cuFileHandleRegister
__cuFileHandleRegister = dlsym(RTLD_DEFAULT, 'cuFileHandleRegister')
Expand Down Expand Up @@ -427,7 +430,7 @@ cdef inline int _check_or_init_cufile() except -1 nogil:
if __py_cufile_init:
return 0

return __check_or_init_cufile()
return _init_cufile()


cdef dict func_ptrs = None
Expand Down
8 changes: 6 additions & 2 deletions cuda_bindings/cuda/bindings/_internal/nvjitlink_linux.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -80,12 +80,16 @@ cdef void* load_library() except* with gil:
return <void*>handle


cdef int __check_or_init_nvjitlink() except -1 nogil:
cdef int _init_nvjitlink() except -1 nogil:
global __py_nvjitlink_init

cdef void* handle = NULL

with gil, __symbol_lock:
# Recheck the flag after obtaining the locks
if __py_nvjitlink_init:
return 0

# Load function
global __nvJitLinkCreate
__nvJitLinkCreate = dlsym(RTLD_DEFAULT, 'nvJitLinkCreate')
Expand Down Expand Up @@ -193,7 +197,7 @@ cdef inline int _check_or_init_nvjitlink() except -1 nogil:
if __py_nvjitlink_init:
return 0

return __check_or_init_nvjitlink()
return _init_nvjitlink()

cdef dict func_ptrs = None

Expand Down
8 changes: 6 additions & 2 deletions cuda_bindings/cuda/bindings/_internal/nvjitlink_windows.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -93,10 +93,14 @@ cdef void* __nvJitLinkGetInfoLog = NULL
cdef void* __nvJitLinkVersion = NULL


cdef int __check_or_init_nvjitlink() except -1 nogil:
cdef int _init_nvjitlink() except -1 nogil:
global __py_nvjitlink_init

with gil, __symbol_lock:
# Recheck the flag after obtaining the locks
if __py_nvjitlink_init:
return 0

# Load library
handle = load_nvidia_dynamic_lib("nvJitLink")._handle_uint

Expand Down Expand Up @@ -151,7 +155,7 @@ cdef inline int _check_or_init_nvjitlink() except -1 nogil:
if __py_nvjitlink_init:
return 0

return __check_or_init_nvjitlink()
return _init_nvjitlink()


cdef dict func_ptrs = None
Expand Down
8 changes: 6 additions & 2 deletions cuda_bindings/cuda/bindings/_internal/nvvm_linux.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -79,12 +79,16 @@ cdef void* load_library() except* with gil:
return <void*>handle


cdef int __check_or_init_nvvm() except -1 nogil:
cdef int _init_nvvm() except -1 nogil:
global __py_nvvm_init

cdef void* handle = NULL

with gil, __symbol_lock:
# Recheck the flag after obtaining the locks
if __py_nvvm_init:
return 0

# Load function
global __nvvmGetErrorString
__nvvmGetErrorString = dlsym(RTLD_DEFAULT, 'nvvmGetErrorString')
Expand Down Expand Up @@ -185,7 +189,7 @@ cdef inline int _check_or_init_nvvm() except -1 nogil:
if __py_nvvm_init:
return 0

return __check_or_init_nvvm()
return _init_nvvm()


cdef dict func_ptrs = None
Expand Down
8 changes: 6 additions & 2 deletions cuda_bindings/cuda/bindings/_internal/nvvm_windows.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,14 @@ cdef void* __nvvmGetProgramLogSize = NULL
cdef void* __nvvmGetProgramLog = NULL


cdef int __check_or_init_nvvm() except -1 nogil:
cdef int _init_nvvm() except -1 nogil:
global __py_nvvm_init

with gil, __symbol_lock:
# Recheck the flag after obtaining the locks
if __py_nvvm_init:
return 0

# Load library
handle = load_nvidia_dynamic_lib("nvvm")._handle_uint

Expand Down Expand Up @@ -147,7 +151,7 @@ cdef inline int _check_or_init_nvvm() except -1 nogil:
if __py_nvvm_init:
return 0

return __check_or_init_nvvm()
return _init_nvvm()


cdef dict func_ptrs = None
Expand Down
47 changes: 46 additions & 1 deletion cuda_bindings/cuda/bindings/cufile.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import numpy as _numpy
from cpython cimport buffer as _buffer
from cpython.memoryview cimport PyMemoryView_FromMemory
from enum import IntEnum as _IntEnum
cimport cpython

import cython

Expand Down Expand Up @@ -54,6 +55,10 @@ cdef class _py_anon_pod1:
"""Get the pointer address to the data as Python :class:`int`."""
return self._data.ctypes.data

cdef intptr_t _get_ptr(self):
"""Get the pointer address to the data as Python :class:`int`."""
return self._data.ctypes.data

def __int__(self):
return self._data.ctypes.data

Expand Down Expand Up @@ -157,6 +162,10 @@ cdef class _py_anon_pod3:
"""Get the pointer address to the data as Python :class:`int`."""
return self._data.ctypes.data

cdef intptr_t _get_ptr(self):
"""Get the pointer address to the data as Python :class:`int`."""
return self._data.ctypes.data

def __int__(self):
return self._data.ctypes.data

Expand Down Expand Up @@ -286,6 +295,10 @@ cdef class IOEvents:
"""Get the pointer address to the data as Python :class:`int`."""
return self._data.ctypes.data

cdef intptr_t _get_ptr(self):
"""Get the pointer address to the data as Python :class:`int`."""
return self._data.ctypes.data

def __int__(self):
if self._data.size > 1:
raise TypeError("int() argument must be a bytes-like object of size 1. "
Expand Down Expand Up @@ -422,6 +435,10 @@ cdef class OpCounter:
"""Get the pointer address to the data as Python :class:`int`."""
return self._data.ctypes.data

cdef intptr_t _get_ptr(self):
"""Get the pointer address to the data as Python :class:`int`."""
return self._data.ctypes.data

def __int__(self):
return self._data.ctypes.data

Expand Down Expand Up @@ -551,6 +568,10 @@ cdef class PerGpuStats:
"""Get the pointer address to the data as Python :class:`int`."""
return self._data.ctypes.data

cdef intptr_t _get_ptr(self):
"""Get the pointer address to the data as Python :class:`int`."""
return self._data.ctypes.data

def __int__(self):
return self._data.ctypes.data

Expand Down Expand Up @@ -914,6 +935,10 @@ cdef class Descr:
"""Get the pointer address to the data as Python :class:`int`."""
return self._data.ctypes.data

cdef intptr_t _get_ptr(self):
"""Get the pointer address to the data as Python :class:`int`."""
return self._data.ctypes.data

def __int__(self):
if self._data.size > 1:
raise TypeError("int() argument must be a bytes-like object of size 1. "
Expand Down Expand Up @@ -1052,6 +1077,10 @@ cdef class _py_anon_pod2:
"""Get the pointer address to the data as Python :class:`int`."""
return self._data.ctypes.data

cdef intptr_t _get_ptr(self):
"""Get the pointer address to the data as Python :class:`int`."""
return self._data.ctypes.data

def __int__(self):
return self._data.ctypes.data

Expand Down Expand Up @@ -1185,6 +1214,10 @@ cdef class StatsLevel1:
"""Get the pointer address to the data as Python :class:`int`."""
return self._data.ctypes.data

cdef intptr_t _get_ptr(self):
"""Get the pointer address to the data as Python :class:`int`."""
return self._data.ctypes.data

def __int__(self):
return self._data.ctypes.data

Expand Down Expand Up @@ -1667,6 +1700,10 @@ cdef class IOParams:
"""Get the pointer address to the data as Python :class:`int`."""
return self._data.ctypes.data

cdef intptr_t _get_ptr(self):
"""Get the pointer address to the data as Python :class:`int`."""
return self._data.ctypes.data

def __int__(self):
if self._data.size > 1:
raise TypeError("int() argument must be a bytes-like object of size 1. "
Expand Down Expand Up @@ -1824,6 +1861,10 @@ cdef class StatsLevel2:
"""Get the pointer address to the data as Python :class:`int`."""
return self._data.ctypes.data

cdef intptr_t _get_ptr(self):
"""Get the pointer address to the data as Python :class:`int`."""
return self._data.ctypes.data

def __int__(self):
return self._data.ctypes.data

Expand Down Expand Up @@ -1935,6 +1976,10 @@ cdef class StatsLevel3:
"""Get the pointer address to the data as Python :class:`int`."""
return self._data.ctypes.data

cdef intptr_t _get_ptr(self):
"""Get the pointer address to the data as Python :class:`int`."""
return self._data.ctypes.data

def __int__(self):
return self._data.ctypes.data

Expand Down Expand Up @@ -2458,7 +2503,7 @@ cpdef str get_parameter_string(int param, int len):
with nogil:
__status__ = cuFileGetParameterString(<_StringConfigParameter>param, desc_str, len)
check_status(__status__)
return _desc_str_.decode()
return cpython.PyUnicode_FromString(desc_str)


cpdef set_parameter_size_t(int param, size_t value):
Expand Down
17 changes: 0 additions & 17 deletions cuda_bindings/tests/test_cufile.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,20 +121,6 @@ def isSupportedFilesystem():
pytestmark = pytest.mark.skipif(not cufileLibraryAvailable(), reason="cuFile library not available on this system")


def safe_decode_string(raw_value):
"""Safely decode a string value from ctypes buffer."""
# Find null terminator if present
null_pos = raw_value.find(b"\x00")
if null_pos != -1:
raw_value = raw_value[:null_pos]
# Decode with error handling
try:
return raw_value.decode("utf-8", errors="ignore")
except UnicodeDecodeError:
# If UTF-8 fails, try to decode as bytes
return str(raw_value)


def test_cufile_success_defined():
"""Check if CUFILE_SUCCESS is defined in OpError enum."""
assert hasattr(cufile.OpError, "SUCCESS")
Expand Down Expand Up @@ -1774,8 +1760,6 @@ def test_set_get_parameter_string(tmp_path):

def test_param(param, val, default_val):
orig_val = cufile.get_parameter_string(param, 256)
# Use safe_decode_string to handle null terminators and padding
orig_val = safe_decode_string(orig_val.encode("utf-8"))

val_b = val.encode("utf-8")
val_buf = ctypes.create_string_buffer(val_b)
Expand All @@ -1787,7 +1771,6 @@ def test_param(param, val, default_val):
# Round-trip test
cufile.set_parameter_string(param, int(ctypes.addressof(val_buf)))
retrieved_val = cufile.get_parameter_string(param, 256)
retrieved_val = safe_decode_string(retrieved_val.encode("utf-8"))
assert retrieved_val == val

# Restore
Expand Down
Loading