Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 50 additions & 31 deletions cuda_core/cuda/core/experimental/_device.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ from cuda.core.experimental._utils.cuda_utils import (
)


# TODO: I prefer to type these as "cdef object" and avoid accessing them from within Python,
# but it seems it is very convenient to expose them for testing purposes...
_tls = threading.local()
_lock = threading.Lock()
cdef bint _is_cuInit = False
Expand Down Expand Up @@ -55,7 +57,8 @@ cdef class DeviceProperties:
cdef inline _get_attribute(self, cydriver.CUdevice_attribute attr):
"""Retrieve the attribute value directly from the driver."""
cdef int val
HANDLE_RETURN(cydriver.cuDeviceGetAttribute(&val, attr, self._handle))
with nogil:
HANDLE_RETURN(cydriver.cuDeviceGetAttribute(&val, attr, self._handle))
return val

cdef _get_cached_attribute(self, attr):
Expand Down Expand Up @@ -912,7 +915,8 @@ cdef cydriver.CUcontext _get_primary_context(int dev_id) except?NULL:
primary_ctxs = _tls.primary_ctxs = [0] * total
cdef cydriver.CUcontext ctx = <cydriver.CUcontext><uintptr_t>(primary_ctxs[dev_id])
if ctx == NULL:
HANDLE_RETURN(cydriver.cuDevicePrimaryCtxRetain(&ctx, dev_id))
with nogil:
HANDLE_RETURN(cydriver.cuDevicePrimaryCtxRetain(&ctx, dev_id))
primary_ctxs[dev_id] = <uintptr_t>(ctx)
return ctx

Expand Down Expand Up @@ -948,19 +952,21 @@ class Device:
def __new__(cls, device_id: Optional[int] = None):
global _is_cuInit
if _is_cuInit is False:
with _lock:
with _lock, nogil:
HANDLE_RETURN(cydriver.cuInit(0))
_is_cuInit = True

# important: creating a Device instance does not initialize the GPU!
cdef cydriver.CUdevice dev
cdef cydriver.CUcontext ctx
if device_id is None:
err = cydriver.cuCtxGetDevice(&dev)
with nogil:
err = cydriver.cuCtxGetDevice(&dev)
if err == cydriver.CUresult.CUDA_SUCCESS:
device_id = int(dev)
elif err == cydriver.CUresult.CUDA_ERROR_INVALID_CONTEXT:
HANDLE_RETURN(cydriver.cuCtxGetCurrent(&ctx))
with nogil:
HANDLE_RETURN(cydriver.cuCtxGetCurrent(&ctx))
assert <void*>(ctx) == NULL
device_id = 0 # cudart behavior
else:
Expand All @@ -973,18 +979,20 @@ class Device:
try:
devices = _tls.devices
except AttributeError:
HANDLE_RETURN(cydriver.cuDeviceGetCount(&total))
with nogil:
HANDLE_RETURN(cydriver.cuDeviceGetCount(&total))
devices = _tls.devices = []
for dev_id in range(total):
device = super().__new__(cls)
device._id = dev_id
# If the device is in TCC mode, or does not support memory pools for some other reason,
# use the SynchronousMemoryResource which does not use memory pools.
HANDLE_RETURN(
cydriver.cuDeviceGetAttribute(
&attr, cydriver.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED, dev_id
with nogil:
HANDLE_RETURN(
cydriver.cuDeviceGetAttribute(
&attr, cydriver.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED, dev_id
)
)
)
if attr == 1:
device._mr = DeviceMemoryResource(dev_id)
else:
Expand All @@ -1005,16 +1013,18 @@ class Device:
f"Device {self._id} is not yet initialized, perhaps you forgot to call .set_current() first?"
)

def _get_current_context(self, check_consistency=False) -> driver.CUcontext:
def _get_current_context(self, bint check_consistency=False) -> driver.CUcontext:
cdef cydriver.CUcontext ctx
HANDLE_RETURN(cydriver.cuCtxGetCurrent(&ctx))
if ctx == NULL:
raise CUDAError("No context is bound to the calling CPU thread.")
cdef cydriver.CUdevice dev
if check_consistency:
HANDLE_RETURN(cydriver.cuCtxGetDevice(&dev))
if <int>(dev) != self._id:
raise CUDAError("Internal error (current device is not equal to Device.device_id)")
cdef cydriver.CUdevice this_dev = self._id
with nogil:
HANDLE_RETURN(cydriver.cuCtxGetCurrent(&ctx))
if ctx == NULL:
raise CUDAError("No context is bound to the calling CPU thread.")
if check_consistency:
HANDLE_RETURN(cydriver.cuCtxGetDevice(&dev))
if dev != this_dev:
raise CUDAError("Internal error (current device is not equal to Device.device_id)")
return driver.CUcontext(<uintptr_t>ctx)

@property
Expand Down Expand Up @@ -1043,10 +1053,12 @@ class Device:

"""
cdef cydriver.CUuuid uuid
IF CUDA_CORE_BUILD_MAJOR == "12":
HANDLE_RETURN(cydriver.cuDeviceGetUuid_v2(&uuid, self._id))
ELSE: # 13.0+
HANDLE_RETURN(cydriver.cuDeviceGetUuid(&uuid, self._id))
cdef cydriver.CUdevice this_dev = self._id
with nogil:
IF CUDA_CORE_BUILD_MAJOR == "12":
HANDLE_RETURN(cydriver.cuDeviceGetUuid_v2(&uuid, this_dev))
ELSE: # 13.0+
HANDLE_RETURN(cydriver.cuDeviceGetUuid(&uuid, this_dev))
cdef bytes uuid_b = cpython.PyBytes_FromStringAndSize(uuid.bytes, sizeof(uuid.bytes))
cdef str uuid_hex = uuid_b.hex()
# 8-4-4-4-12
Expand All @@ -1058,7 +1070,10 @@ class Device:
# Use 256 characters to be consistent with CUDA Runtime
cdef int LENGTH = 256
cdef bytes name = bytes(LENGTH)
HANDLE_RETURN(cydriver.cuDeviceGetName(<char*>name, LENGTH, self._id))
cdef char* name_ptr = name
cdef cydriver.CUdevice this_dev = self._id
with nogil:
HANDLE_RETURN(cydriver.cuDeviceGetName(name_ptr, LENGTH, this_dev))
name = name.split(b"\0")[0]
return name.decode()

Expand Down Expand Up @@ -1161,7 +1176,8 @@ class Device:
>>> # ... do work on device 0 ...

"""
cdef cydriver.CUcontext _ctx
cdef cydriver.CUcontext prev_ctx
cdef cydriver.CUcontext curr_ctx
if ctx is not None:
# TODO: revisit once Context is cythonized
assert_type(ctx, Context)
Expand All @@ -1170,16 +1186,19 @@ class Device:
"the provided context was created on the device with"
f" id={ctx._id}, which is different from the target id={self._id}"
)
# _ctx is the previous context
HANDLE_RETURN(cydriver.cuCtxPopCurrent(&_ctx))
HANDLE_RETURN(cydriver.cuCtxPushCurrent(<cydriver.CUcontext>(ctx._handle)))
# prev_ctx is the previous context
curr_ctx = <cydriver.CUcontext>(ctx._handle)
with nogil:
HANDLE_RETURN(cydriver.cuCtxPopCurrent(&prev_ctx))
HANDLE_RETURN(cydriver.cuCtxPushCurrent(curr_ctx))
self._has_inited = True
if _ctx != NULL:
return Context._from_ctx(<uintptr_t>(_ctx), self._id)
if prev_ctx != NULL:
return Context._from_ctx(<uintptr_t>(prev_ctx), self._id)
else:
# use primary ctx
_ctx = _get_primary_context(self._id)
HANDLE_RETURN(cydriver.cuCtxSetCurrent(_ctx))
curr_ctx = _get_primary_context(self._id)
with nogil:
HANDLE_RETURN(cydriver.cuCtxSetCurrent(curr_ctx))
self._has_inited = True

def create_context(self, options: ContextOptions = None) -> Context:
Expand Down
17 changes: 17 additions & 0 deletions cuda_core/cuda/core/experimental/_event.pxd
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0

from cuda.bindings cimport cydriver


cdef class Event:

cdef:
cydriver.CUevent _handle
bint _timing_disabled
bint _busy_waited
int _device_id
object _ctx_handle

cpdef close(self)
37 changes: 14 additions & 23 deletions cuda_core/cuda/core/experimental/_event.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,7 @@ from cuda.core.experimental._context import Context
from cuda.core.experimental._utils.cuda_utils import (
CUDAError,
driver,
handle_return,
)
import sys
if TYPE_CHECKING:
import cuda.bindings
from cuda.core.experimental._device import Device
Expand Down Expand Up @@ -81,13 +79,6 @@ cdef class Event:
and they should instead be created through a :obj:`~_stream.Stream` object.

"""
cdef:
cydriver.CUevent _handle
bint _timing_disabled
bint _busy_waited
int _device_id
object _ctx_handle

def __cinit__(self):
self._handle = <cydriver.CUevent>(NULL)

Expand All @@ -109,24 +100,21 @@ cdef class Event:
self._busy_waited = True
if opts.support_ipc:
raise NotImplementedError("WIP: https://github.com/NVIDIA/cuda-python/issues/103")
HANDLE_RETURN(cydriver.cuEventCreate(&self._handle, flags))
with nogil:
HANDLE_RETURN(cydriver.cuEventCreate(&self._handle, flags))
self._device_id = device_id
self._ctx_handle = ctx_handle
return self

cdef _shutdown_safe_close(self, is_shutting_down=sys.is_finalizing):
if is_shutting_down and is_shutting_down():
return
if self._handle != NULL:
HANDLE_RETURN(cydriver.cuEventDestroy(self._handle))
self._handle = <cydriver.CUevent>(NULL)

cpdef close(self):
"""Destroy the event."""
self._shutdown_safe_close(is_shutting_down=None)
if self._handle != NULL:
with nogil:
HANDLE_RETURN(cydriver.cuEventDestroy(self._handle))
self._handle = <cydriver.CUevent>(NULL)

def __del__(self):
self._shutdown_safe_close()
def __dealloc__(self):
self.close()

def __isub__(self, other):
return NotImplemented
Expand All @@ -137,7 +125,8 @@ cdef class Event:
def __sub__(self, other: Event):
# return self - other (in milliseconds)
cdef float timing
err = cydriver.cuEventElapsedTime(&timing, other._handle, self._handle)
with nogil:
err = cydriver.cuEventElapsedTime(&timing, other._handle, self._handle)
if err == 0:
return timing
else:
Expand Down Expand Up @@ -187,12 +176,14 @@ cdef class Event:
has been completed.

"""
HANDLE_RETURN(cydriver.cuEventSynchronize(self._handle))
with nogil:
HANDLE_RETURN(cydriver.cuEventSynchronize(self._handle))

@property
def is_done(self) -> bool:
"""Return True if all captured works have been completed, otherwise False."""
result = cydriver.cuEventQuery(self._handle)
with nogil:
result = cydriver.cuEventQuery(self._handle)
if result == cydriver.CUresult.CUDA_SUCCESS:
return True
if result == cydriver.CUresult.CUDA_ERROR_NOT_READY:
Expand Down
Loading
Loading