From 267b268337fb9de1e878e1d96a2842754579dc0d Mon Sep 17 00:00:00 2001 From: Andy Jost Date: Fri, 14 Nov 2025 14:33:49 -0800 Subject: [PATCH] Update Device constructor to accept Device or device ordinal. Update public APIs taking a device to accept either. --- cuda_core/cuda/core/experimental/_device.pyx | 53 +++++++++++-------- cuda_core/cuda/core/experimental/_event.pyx | 1 - .../_memory/_device_memory_resource.pyx | 11 ++-- .../cuda/core/experimental/_memory/_ipc.pyx | 3 +- .../cuda/core/experimental/_memory/_legacy.py | 4 +- .../_memory/_virtual_memory_resource.py | 9 ++-- cuda_core/cuda/core/experimental/_module.py | 36 +++++++------ cuda_core/tests/test_memory.py | 6 ++- cuda_core/tests/test_module.py | 3 +- 9 files changed, 73 insertions(+), 53 deletions(-) diff --git a/cuda_core/cuda/core/experimental/_device.pyx b/cuda_core/cuda/core/experimental/_device.pyx index bc6167793..6c3c0c6db 100644 --- a/cuda_core/cuda/core/experimental/_device.pyx +++ b/cuda_core/cuda/core/experimental/_device.pyx @@ -948,9 +948,16 @@ class Device: Default value of `None` return the currently used device. """ - __slots__ = ("_id", "_mr", "_has_inited", "_properties", "_uuid") + __slots__ = ("_id", "_memory_resource", "_has_inited", "_properties", "_uuid") - def __new__(cls, device_id: int | None = None): + def __new__(cls, device_id: Device | int | None = None): + # Handle device_id argument. + if isinstance(device_id, Device): + return device_id + else: + device_id = getattr(device_id, 'device_id', device_id) + + # Initialize CUDA. global _is_cuInit if _is_cuInit is False: with _lock, nogil: @@ -976,7 +983,7 @@ class Device: raise ValueError(f"device_id must be >= 0, got {device_id}") # ensure Device is singleton - cdef int total, attr + cdef int total try: devices = _tls.devices except AttributeError: @@ -986,21 +993,7 @@ class Device: for dev_id in range(total): device = super().__new__(cls) device._id = dev_id - # If the device is in TCC mode, or does not support memory pools for some other reason, - # use the SynchronousMemoryResource which does not use memory pools. - with nogil: - HANDLE_RETURN( - cydriver.cuDeviceGetAttribute( - &attr, cydriver.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED, dev_id - ) - ) - if attr == 1: - from cuda.core.experimental._memory import DeviceMemoryResource - device._mr = DeviceMemoryResource(dev_id) - else: - from cuda.core.experimental._memory import _SynchronousMemoryResource - device._mr = _SynchronousMemoryResource(dev_id) - + device._memory_resource = None device._has_inited = False device._properties = None device._uuid = None @@ -1128,13 +1121,31 @@ class Device: @property def memory_resource(self) -> MemoryResource: """Return :obj:`~_memory.MemoryResource` associated with this device.""" - return self._mr + cdef int attr, device_id + if self._memory_resource is None: + # If the device is in TCC mode, or does not support memory pools for some other reason, + # use the SynchronousMemoryResource which does not use memory pools. 
+ device_id = self._id + with nogil: + HANDLE_RETURN( + cydriver.cuDeviceGetAttribute( + &attr, cydriver.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED, device_id + ) + ) + if attr == 1: + from cuda.core.experimental._memory import DeviceMemoryResource + self._memory_resource = DeviceMemoryResource(self._id) + else: + from cuda.core.experimental._memory import _SynchronousMemoryResource + self._memory_resource = _SynchronousMemoryResource(self._id) + + return self._memory_resource @memory_resource.setter def memory_resource(self, mr): from cuda.core.experimental._memory import MemoryResource assert_type(mr, MemoryResource) - self._mr = mr + self._memory_resource = mr @property def default_stream(self) -> Stream: @@ -1324,7 +1335,7 @@ class Device: self._check_context_initialized() if stream is None: stream = default_stream() - return self._mr.allocate(size, stream) + return self.memory_resource.allocate(size, stream) def sync(self): """Synchronize the device. diff --git a/cuda_core/cuda/core/experimental/_event.pyx b/cuda_core/cuda/core/experimental/_event.pyx index 92899f967..98a45d004 100644 --- a/cuda_core/cuda/core/experimental/_event.pyx +++ b/cuda_core/cuda/core/experimental/_event.pyx @@ -25,7 +25,6 @@ from cuda.core.experimental._utils.cuda_utils import ( ) if TYPE_CHECKING: import cuda.bindings - from cuda.core.experimental._device import Device @dataclass diff --git a/cuda_core/cuda/core/experimental/_memory/_device_memory_resource.pyx b/cuda_core/cuda/core/experimental/_memory/_device_memory_resource.pyx index 47b6fd114..03e9941fb 100644 --- a/cuda_core/cuda/core/experimental/_memory/_device_memory_resource.pyx +++ b/cuda_core/cuda/core/experimental/_memory/_device_memory_resource.pyx @@ -18,9 +18,9 @@ from cuda.core.experimental._utils.cuda_utils cimport ( HANDLE_RETURN, ) -import cython from dataclasses import dataclass from typing import Optional, TYPE_CHECKING +import cython import platform # no-cython-lint import uuid import weakref @@ -131,7 +131,7 @@ cdef class DeviceMemoryResource(MemoryResource): Parameters ---------- - device_id : int | Device + device_id : Device | int Device or Device ordinal for which a memory resource is constructed. options : DeviceMemoryResourceOptions @@ -211,8 +211,9 @@ cdef class DeviceMemoryResource(MemoryResource): self._ipc_data = None self._attributes = None - def __init__(self, device_id: int | Device, options=None): - cdef int dev_id = getattr(device_id, 'device_id', device_id) + def __init__(self, device_id: Device | int, options=None): + from .._device import Device + cdef int dev_id = Device(device_id).device_id opts = check_or_create_options( DeviceMemoryResourceOptions, options, "DeviceMemoryResource options", keep_none=True @@ -261,7 +262,7 @@ cdef class DeviceMemoryResource(MemoryResource): @classmethod def from_allocation_handle( - cls, device_id: int | Device, alloc_handle: int | IPCAllocationHandle + cls, device_id: Device | int, alloc_handle: int | IPCAllocationHandle ) -> DeviceMemoryResource: """Create a device memory resource from an allocation handle. diff --git a/cuda_core/cuda/core/experimental/_memory/_ipc.pyx b/cuda_core/cuda/core/experimental/_memory/_ipc.pyx index 5aa13af8f..d9384bf2b 100644 --- a/cuda_core/cuda/core/experimental/_memory/_ipc.pyx +++ b/cuda_core/cuda/core/experimental/_memory/_ipc.pyx @@ -197,7 +197,8 @@ cdef DeviceMemoryResource DMR_from_allocation_handle(cls, device_id, alloc_handl # Construct a new DMR. 
cdef DeviceMemoryResource self = DeviceMemoryResource.__new__(cls) - self._dev_id = getattr(device_id, 'device_id', device_id) + from .._device import Device + self._dev_id = Device(device_id).device_id self._mempool_owned = True self._ipc_data = IPCData(alloc_handle, mapped=True) diff --git a/cuda_core/cuda/core/experimental/_memory/_legacy.py b/cuda_core/cuda/core/experimental/_memory/_legacy.py index 523835a79..9bddf697a 100644 --- a/cuda_core/cuda/core/experimental/_memory/_legacy.py +++ b/cuda_core/cuda/core/experimental/_memory/_legacy.py @@ -86,7 +86,9 @@ class _SynchronousMemoryResource(MemoryResource): __slots__ = ("_dev_id",) def __init__(self, device_id): - self._dev_id = getattr(device_id, "device_id", device_id) + from .._device import Device + + self._dev_id = Device(device_id).device_id def allocate(self, size, stream=None) -> Buffer: if stream is None: diff --git a/cuda_core/cuda/core/experimental/_memory/_virtual_memory_resource.py b/cuda_core/cuda/core/experimental/_memory/_virtual_memory_resource.py index 5379f0b8f..c17c30bc9 100644 --- a/cuda_core/cuda/core/experimental/_memory/_virtual_memory_resource.py +++ b/cuda_core/cuda/core/experimental/_memory/_virtual_memory_resource.py @@ -5,6 +5,7 @@ from dataclasses import dataclass, field from typing import Iterable, Literal, Union +from cuda.core.experimental._device import Device from cuda.core.experimental._memory._buffer import Buffer, MemoryResource from cuda.core.experimental._stream import Stream from cuda.core.experimental._utils.cuda_utils import ( @@ -140,15 +141,15 @@ class VirtualMemoryResource(MemoryResource): Parameters ---------- - device_id : int - Device ordinal for which a memory resource is constructed. + device_id : Device | int + Device for which a memory resource is constructed. 
config : VirtualMemoryResourceOptions A configuration object for the VirtualMemoryResource """ - def __init__(self, device, config: VirtualMemoryResourceOptions = None): - self.device = device + def __init__(self, device_id: Device | int, config: VirtualMemoryResourceOptions = None): + self.device = Device(device_id) self.config = check_or_create_options( VirtualMemoryResourceOptions, config, "VirtualMemoryResource options", keep_none=False ) diff --git a/cuda_core/cuda/core/experimental/_module.py b/cuda_core/cuda/core/experimental/_module.py index f8ce8f95d..9af722465 100644 --- a/cuda_core/cuda/core/experimental/_module.py +++ b/cuda_core/cuda/core/experimental/_module.py @@ -7,6 +7,7 @@ from typing import Union from warnings import warn +from cuda.core.experimental._device import Device from cuda.core.experimental._launch_config import LaunchConfig, _to_native_launch_config from cuda.core.experimental._stream import Stream from cuda.core.experimental._utils.clear_error_support import ( @@ -73,8 +74,9 @@ def _init(cls, kernel): self._loader = _backend[self._backend_version] return self - def _get_cached_attribute(self, device_id: int, attribute: driver.CUfunction_attribute) -> int: + def _get_cached_attribute(self, device_id: Device | int, attribute: driver.CUfunction_attribute) -> int: """Helper function to get a cached attribute or fetch and cache it if not present.""" + device_id = Device(device_id).device_id cache_key = device_id, attribute result = self._cache.get(cache_key, cache_key) if result is not cache_key: @@ -94,62 +96,62 @@ def _get_cached_attribute(self, device_id: int, attribute: driver.CUfunction_att self._cache[cache_key] = result return result - def max_threads_per_block(self, device_id: int = None) -> int: + def max_threads_per_block(self, device_id: Device | int = None) -> int: """int : The maximum number of threads per block. This attribute is read-only.""" return self._get_cached_attribute( device_id, driver.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK ) - def shared_size_bytes(self, device_id: int = None) -> int: + def shared_size_bytes(self, device_id: Device | int = None) -> int: """int : The size in bytes of statically-allocated shared memory required by this function. This attribute is read-only.""" return self._get_cached_attribute(device_id, driver.CUfunction_attribute.CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES) - def const_size_bytes(self, device_id: int = None) -> int: + def const_size_bytes(self, device_id: Device | int = None) -> int: """int : The size in bytes of user-allocated constant memory required by this function. This attribute is read-only.""" return self._get_cached_attribute(device_id, driver.CUfunction_attribute.CU_FUNC_ATTRIBUTE_CONST_SIZE_BYTES) - def local_size_bytes(self, device_id: int = None) -> int: + def local_size_bytes(self, device_id: Device | int = None) -> int: """int : The size in bytes of local memory used by each thread of this function. This attribute is read-only.""" return self._get_cached_attribute(device_id, driver.CUfunction_attribute.CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES) - def num_regs(self, device_id: int = None) -> int: + def num_regs(self, device_id: Device | int = None) -> int: """int : The number of registers used by each thread of this function. 
This attribute is read-only.""" return self._get_cached_attribute(device_id, driver.CUfunction_attribute.CU_FUNC_ATTRIBUTE_NUM_REGS) - def ptx_version(self, device_id: int = None) -> int: + def ptx_version(self, device_id: Device | int = None) -> int: """int : The PTX virtual architecture version for which the function was compiled. This attribute is read-only.""" return self._get_cached_attribute(device_id, driver.CUfunction_attribute.CU_FUNC_ATTRIBUTE_PTX_VERSION) - def binary_version(self, device_id: int = None) -> int: + def binary_version(self, device_id: Device | int = None) -> int: """int : The binary architecture version for which the function was compiled. This attribute is read-only.""" return self._get_cached_attribute(device_id, driver.CUfunction_attribute.CU_FUNC_ATTRIBUTE_BINARY_VERSION) - def cache_mode_ca(self, device_id: int = None) -> bool: + def cache_mode_ca(self, device_id: Device | int = None) -> bool: """bool : Whether the function has been compiled with user specified option "-Xptxas --dlcm=ca" set. This attribute is read-only.""" return bool(self._get_cached_attribute(device_id, driver.CUfunction_attribute.CU_FUNC_ATTRIBUTE_CACHE_MODE_CA)) - def max_dynamic_shared_size_bytes(self, device_id: int = None) -> int: + def max_dynamic_shared_size_bytes(self, device_id: Device | int = None) -> int: """int : The maximum size in bytes of dynamically-allocated shared memory that can be used by this function.""" return self._get_cached_attribute( device_id, driver.CUfunction_attribute.CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES ) - def preferred_shared_memory_carveout(self, device_id: int = None) -> int: + def preferred_shared_memory_carveout(self, device_id: Device | int = None) -> int: """int : The shared memory carveout preference, in percent of the total shared memory.""" return self._get_cached_attribute( device_id, driver.CUfunction_attribute.CU_FUNC_ATTRIBUTE_PREFERRED_SHARED_MEMORY_CARVEOUT ) - def cluster_size_must_be_set(self, device_id: int = None) -> bool: + def cluster_size_must_be_set(self, device_id: Device | int = None) -> bool: """bool : The kernel must launch with a valid cluster size specified. 
This attribute is read-only.""" return bool( @@ -158,25 +160,25 @@ def cluster_size_must_be_set(self, device_id: int = None) -> bool: ) ) - def required_cluster_width(self, device_id: int = None) -> int: + def required_cluster_width(self, device_id: Device | int = None) -> int: """int : The required cluster width in blocks.""" return self._get_cached_attribute( device_id, driver.CUfunction_attribute.CU_FUNC_ATTRIBUTE_REQUIRED_CLUSTER_WIDTH ) - def required_cluster_height(self, device_id: int = None) -> int: + def required_cluster_height(self, device_id: Device | int = None) -> int: """int : The required cluster height in blocks.""" return self._get_cached_attribute( device_id, driver.CUfunction_attribute.CU_FUNC_ATTRIBUTE_REQUIRED_CLUSTER_HEIGHT ) - def required_cluster_depth(self, device_id: int = None) -> int: + def required_cluster_depth(self, device_id: Device | int = None) -> int: """int : The required cluster depth in blocks.""" return self._get_cached_attribute( device_id, driver.CUfunction_attribute.CU_FUNC_ATTRIBUTE_REQUIRED_CLUSTER_DEPTH ) - def non_portable_cluster_size_allowed(self, device_id: int = None) -> bool: + def non_portable_cluster_size_allowed(self, device_id: Device | int = None) -> bool: """bool : Whether the function can be launched with non-portable cluster size.""" return bool( self._get_cached_attribute( @@ -184,7 +186,7 @@ def non_portable_cluster_size_allowed(self, device_id: int = None) -> bool: ) ) - def cluster_scheduling_policy_preference(self, device_id: int = None) -> int: + def cluster_scheduling_policy_preference(self, device_id: Device | int = None) -> int: """int : The block scheduling policy of a function.""" return self._get_cached_attribute( device_id, driver.CUfunction_attribute.CU_FUNC_ATTRIBUTE_CLUSTER_SCHEDULING_POLICY_PREFERENCE diff --git a/cuda_core/tests/test_memory.py b/cuda_core/tests/test_memory.py index a261ec7a3..d960e6ee1 100644 --- a/cuda_core/tests/test_memory.py +++ b/cuda_core/tests/test_memory.py @@ -312,7 +312,8 @@ def test_device_memory_resource_initialization(mempool_device, use_device_object buffer.close() -def test_vmm_allocator_basic_allocation(): +@pytest.mark.parametrize("use_device_object", [True, False]) +def test_vmm_allocator_basic_allocation(use_device_object): """Test basic VMM allocation functionality. This test verifies that VirtualMemoryResource can allocate memory @@ -327,7 +328,8 @@ def test_vmm_allocator_basic_allocation(): options = VirtualMemoryResourceOptions() # Create VMM allocator with default config - vmm_mr = VirtualMemoryResource(device, config=options) + device_arg = device if use_device_object else device.device_id + vmm_mr = VirtualMemoryResource(device_arg, config=options) # Test basic allocation buffer = vmm_mr.allocate(4096) diff --git a/cuda_core/tests/test_module.py b/cuda_core/tests/test_module.py index 49df966c0..901a57f7a 100644 --- a/cuda_core/tests/test_module.py +++ b/cuda_core/tests/test_module.py @@ -134,8 +134,9 @@ def test_read_only_kernel_attributes(get_saxpy_kernel_cubin, attr, expected_type value = method() assert value is not None - # get the value for each device on the system + # get the value for each device on the system, using either the device object or ordinal for device in system.devices: + value = method(device) value = method(device.device_id) assert isinstance(value, expected_type), f"Expected {attr} to be of type {expected_type}, but got {type(value)}"
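
Taken together, the patch makes Device objects and plain device ordinals interchangeable across the public API, and it defers construction of a device's default memory resource until memory_resource is first read instead of building one for every device in Device.__new__. A minimal usage sketch of the new constructor behavior, assuming a working CUDA driver and at least one visible device:

    from cuda.core.experimental import Device

    dev = Device(0)      # construct from an ordinal, as before
    same = Device(dev)   # passing a Device returns that same singleton
    assert same is dev

    # The default memory resource is now created lazily, on first access:
    # a DeviceMemoryResource when the device supports memory pools,
    # otherwise the synchronous fallback.
    dev.set_current()
    mr = dev.memory_resource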
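
The memory resource constructors updated here take the same widened argument and normalize it through Device(device_id).device_id. A sketch under the assumption that DeviceMemoryResource, VirtualMemoryResource, and VirtualMemoryResourceOptions are exported from cuda.core.experimental (as the tests use them) and that the device supports memory pools and virtual memory management:

    from cuda.core.experimental import (
        Device,
        DeviceMemoryResource,
        VirtualMemoryResource,
        VirtualMemoryResourceOptions,
    )

    dev = Device()        # current (or default) device
    dev.set_current()

    # Either a Device or an ordinal is accepted.
    mr_from_device = DeviceMemoryResource(dev)
    mr_from_ordinal = DeviceMemoryResource(dev.device_id)

    options = VirtualMemoryResourceOptions()
    vmm_from_device = VirtualMemoryResource(dev, config=options)
    vmm_from_ordinal = VirtualMemoryResource(dev.device_id, config=options)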
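
The kernel attribute accessors in _module.py follow the same pattern: the device argument is normalized with Device(device_id).device_id before the per-device cache lookup. A hypothetical sketch; `kernel` is a stand-in for a cuda.core Kernel compiled elsewhere, and the kernel.attributes access path is assumed rather than shown in this patch:

    def query_max_threads(kernel, device_or_ordinal):
        # `kernel` is a placeholder for a Kernel obtained from compiled code
        # (hypothetical here). After this patch the accessor accepts a Device,
        # an int ordinal, or None for the current device.
        return kernel.attributes.max_threads_per_block(device_or_ordinal)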