diff --git a/cuda_core/cuda/core/experimental/__init__.py b/cuda_core/cuda/core/experimental/__init__.py index a06119321..a01134373 100644 --- a/cuda_core/cuda/core/experimental/__init__.py +++ b/cuda_core/cuda/core/experimental/__init__.py @@ -17,7 +17,7 @@ from cuda.core.experimental._memory import ( Buffer, DeviceMemoryResource, - IPCChannel, + DeviceMemoryResourceOptions, LegacyPinnedMemoryResource, MemoryResource, ) diff --git a/cuda_core/cuda/core/experimental/_device.py b/cuda_core/cuda/core/experimental/_device.py index 0499baa58..be8c5170a 100644 --- a/cuda_core/cuda/core/experimental/_device.py +++ b/cuda_core/cuda/core/experimental/_device.py @@ -1160,6 +1160,9 @@ def __int__(self): def __repr__(self): return f"" + def __reduce__(self): + return Device, (self.device_id,) + def set_current(self, ctx: Context = None) -> Union[Context, None]: """Set device to be used for GPU executions. diff --git a/cuda_core/cuda/core/experimental/_memory.pyx b/cuda_core/cuda/core/experimental/_memory.pyx index ace146bdf..3fdc1410f 100644 --- a/cuda_core/cuda/core/experimental/_memory.pyx +++ b/cuda_core/cuda/core/experimental/_memory.pyx @@ -12,12 +12,15 @@ from cuda.core.experimental._utils.cuda_utils cimport ( import sys from dataclasses import dataclass -from typing import TypeVar, Union, TYPE_CHECKING +from typing import Optional, TypeVar, Union, TYPE_CHECKING import abc import array +import contextlib import cython +import multiprocessing import os import platform +import sys import weakref from cuda.core.experimental._dlpack import DLDeviceType, make_py_capsule from cuda.core.experimental._stream import Stream, default_stream @@ -27,8 +30,8 @@ if platform.system() == "Linux": import socket if TYPE_CHECKING: - import cuda.bindings.driver - from cuda.core.experimental._device import Device + from ._device import Device + import uuid # TODO: define a memory property mixin class and make Buffer and # MemoryResource both inherit from it @@ -81,6 +84,9 @@ cdef class Buffer: self._mr = None self._ptr_obj = None + def __reduce__(self): + return Buffer.from_ipc_descriptor, (self.memory_resource, self.get_ipc_descriptor()) + cpdef close(self, stream: Stream = None): """Deallocate this buffer asynchronously on the given stream. @@ -137,7 +143,7 @@ cdef class Buffer: return self._mr.device_id raise NotImplementedError("WIP: Currently this property only supports buffers with associated MemoryResource") - def export(self) -> IPCBufferDescriptor: + def get_ipc_descriptor(self) -> IPCBufferDescriptor: """Export a buffer allocated for sharing between processes.""" if not self._mr.is_ipc_enabled: raise RuntimeError("Memory resource is not IPC-enabled") @@ -146,7 +152,7 @@ cdef class Buffer: return IPCBufferDescriptor._init(ptr.reserved, self.size) @classmethod - def import_(cls, mr: MemoryResource, ipc_buffer: IPCBufferDescriptor) -> Buffer: + def from_ipc_descriptor(cls, mr: MemoryResource, ipc_buffer: IPCBufferDescriptor) -> Buffer: """Import a buffer that was exported from another process.""" if not mr.is_ipc_enabled: raise RuntimeError("Memory resource is not IPC-enabled") @@ -384,33 +390,29 @@ cdef class IPCBufferDescriptor: return self def __reduce__(self): - # This is subject to change if the CUmemPoolPtrExportData struct/object changes. - return (self._reconstruct, (self._reserved, self._size)) + return self._init, (self._reserved, self._size) @property def size(self): return self._size - @classmethod - def _reconstruct(cls, reserved, size): - instance = cls._init(reserved, size) - return instance - cdef class IPCAllocationHandle: """Shareable handle to an IPC-enabled device memory pool.""" cdef: int _handle + object _uuid def __init__(self, *arg, **kwargs): raise RuntimeError("IPCAllocationHandle objects cannot be instantiated directly. Please use MemoryResource APIs.") @classmethod - def _init(cls, handle: int): + def _init(cls, handle: int, uuid: uuid.UUID): cdef IPCAllocationHandle self = IPCAllocationHandle.__new__(cls) assert handle >= 0 self._handle = handle + self._uuid = uuid return self cpdef close(self): @@ -420,6 +422,7 @@ cdef class IPCAllocationHandle: os.close(self._handle) finally: self._handle = -1 + self._uuid = None def __del__(self): """Close the handle.""" @@ -436,54 +439,20 @@ cdef class IPCAllocationHandle: def handle(self) -> int: return self._handle + @property + def uuid(self) -> uuid.UUID: + return self._uuid -cdef class IPCChannel: - """Communication channel for sharing IPC-enabled memory pools.""" - - cdef: - object _proxy - - def __init__(self): - if platform.system() == "Linux": - self._proxy = IPCChannelUnixSocket._init() - else: - raise RuntimeError("IPC is not available on {platform.system()}") - - -cdef class IPCChannelUnixSocket: - """Unix-specific channel for sharing memory pools over sockets.""" - - cdef: - object _sock_out - object _sock_in - def __init__(self, *arg, **kwargs): - raise RuntimeError("IPCChannelUnixSocket objects cannot be instantiated directly. Please use MemoryResource APIs.") +def _reduce_allocation_handle(alloc_handle): + df = multiprocessing.reduction.DupFd(alloc_handle.handle) + return _reconstruct_allocation_handle, (type(alloc_handle), df, alloc_handle.uuid) - @classmethod - def _init(cls): - cdef IPCChannelUnixSocket self = IPCChannelUnixSocket.__new__(cls) - self._sock_out, self._sock_in = socket.socketpair(socket.AF_UNIX, socket.SOCK_SEQPACKET) - return self +def _reconstruct_allocation_handle(cls, df, uuid): + return cls._init(df.detach(), uuid) - cpdef _send_allocation_handle(self, alloc_handle: IPCAllocationHandle): - """Sends over this channel an allocation handle for exporting a - shared memory pool.""" - self._sock_out.sendmsg( - [], - [(socket.SOL_SOCKET, socket.SCM_RIGHTS, array.array("i", [int(alloc_handle)]))] - ) - cpdef IPCAllocationHandle _receive_allocation_handle(self): - """Receives over this channel an allocation handle for importing a - shared memory pool.""" - fds = array.array("i") - _, ancillary_data, _, _ = self._sock_in.recvmsg(0, socket.CMSG_LEN(fds.itemsize)) - assert len(ancillary_data) == 1 - cmsg_level, cmsg_type, cmsg_data = ancillary_data[0] - assert cmsg_level == socket.SOL_SOCKET and cmsg_type == socket.SCM_RIGHTS - fds.frombytes(cmsg_data[: len(cmsg_data) - (len(cmsg_data) % fds.itemsize)]) - return IPCAllocationHandle._init(int(fds[0])) +multiprocessing.reduction.register(IPCAllocationHandle, _reduce_allocation_handle) @dataclass @@ -564,8 +533,15 @@ class DeviceMemoryResourceAttributes: del mempool_property +# Holds DeviceMemoryResource objects imported by this process. +# This enables buffer serialization, as buffers can reduce to a pair +# of comprising the memory resource UUID (the key into this registry) +# and the serialized buffer descriptor. +_ipc_registry = {} + class DeviceMemoryResource(MemoryResource): - """Create a device memory resource managing a stream-ordered memory pool. + """ + Create a device memory resource managing a stream-ordered memory pool. Parameters ---------- @@ -586,8 +562,63 @@ class DeviceMemoryResource(MemoryResource): When using an existing (current or default) memory pool, the returned device memory resource does not own the pool (`is_handle_owned` is `False`), and closing the resource has no effect. + + Notes + ----- + To create an IPC-Enabled memory resource (MR) that is capable of sharing + allocations between processes, specify ``ipc_enabled=True`` in the initializer + option. Sharing an allocation is a two-step procedure that involves + mapping a memory resource and then mapping buffers owned by that resource. + These steps can be accomplished in several ways. + + An IPC-enabled memory resource can allocate memory buffers but cannot + receive shared buffers. Mapping an MR to another process creates a "mapped + memory resource" (MMR). An MMR cannot allocate memory buffers and can only + receive shared buffers. MRs and MMRs are both of type + :class:`DeviceMemoryResource` and can be distinguished via + :attr:`DeviceMemoryResource.is_mapped`. + + An MR is shared via an allocation handle obtained by calling + :meth:`DeviceMemoryResource.get_allocation_handle`. The allocation handle + has a platform-specific interpretation; however, memory IPC is currently + only supported for Linux, and in that case allocation handles are file + descriptors. After sending an allocation handle to another process, it can + be used to create an MMR by invoking + :meth:`DeviceMemoryResource.from_allocation_handle`. + + Buffers can be shared as serializable descriptors obtained by calling + :meth:`Buffer.get_ipc_descriptor`. In a receiving process, a shared buffer is + created by invoking :meth:`Buffer.from_ipc_descriptor` with an MMR and + buffer descriptor, where the MMR corresponds to the MR that created the + described buffer. + + To help manage the association between memory resources and buffers, a + registry is provided. Every MR has a unique identifier (UUID). MMRs can be + registered by calling :meth:`DeviceMemoryResource.register` with the UUID + of the corresponding MR. Registered MMRs can be looked up via + :meth:`DeviceMemoryResource.from_registry`. When registering MMRs in this + way, the use of buffer descriptors can be avoided. Instead, buffer objects + can themselves be serialized and transferred directly. Serialization embeds + the UUID, which is used to locate the correct MMR during reconstruction. + + IPC-enabled memory resources interoperate with the :mod:`multiprocessing` + module to provide a simplified interface. This approach can avoid direct + use of allocation handles, buffer descriptors, MMRs, and the registry. When + using :mod:`multiprocessing` to spawn processes or send objects through + communication channels such as :class:`multiprocessing.Queue`, + :class:`multiprocessing.Pipe`, or :class:`multiprocessing.Connection`, + :class:`Buffer` objects may be sent directly, and in such cases the process + for creating MMRs and mapping buffers will be handled automatically. + + For greater efficiency when transferring many buffers, one may also send + MRs and buffers separately. When an MR is sent via :mod:`multiprocessing`, + an MMR is created and registered in the receiving process. Subsequently, + buffers may be serialized and transferred using ordinary :mod:`pickle` + methods. The reconstruction procedure uses the registry to find the + associated MMR. """ - __slots__ = "_dev_id", "_mempool_handle", "_attributes", "_ipc_handle_type", "_mempool_owned", "_is_imported" + __slots__ = ("_dev_id", "_mempool_handle", "_attributes", "_ipc_handle_type", + "_mempool_owned", "_is_mapped", "_uuid", "_alloc_handle") def __init__(self, device_id: int | Device, options=None): device_id = getattr(device_id, 'device_id', device_id) @@ -602,7 +633,9 @@ class DeviceMemoryResource(MemoryResource): self._attributes = None self._ipc_handle_type = _NOIPC_HANDLE_TYPE self._mempool_owned = False - self._is_imported = False + self._is_mapped = False + self._uuid = None + self._alloc_handle = None err, self._mempool_handle = driver.cuDeviceGetMemPool(self.device_id) raise_if_driver_error(err) @@ -643,36 +676,92 @@ class DeviceMemoryResource(MemoryResource): self._attributes = None self._ipc_handle_type = properties.handleTypes self._mempool_owned = True - self._is_imported = False + self._is_mapped = False + self._uuid = None + self._alloc_handle = None err, self._mempool_handle = driver.cuMemPoolCreate(properties) raise_if_driver_error(err) + if opts.ipc_enabled: + self.get_allocation_handle() # enables Buffer.get_ipc_descriptor, sets uuid + def __del__(self): self.close() def close(self): """Close the device memory resource and destroy the associated memory pool if owned.""" - if self._mempool_handle is not None and self._mempool_owned: - err, = driver.cuMemPoolDestroy(self._mempool_handle) - raise_if_driver_error(err) + if self._mempool_handle is not None: + try: + if self._mempool_owned: + err, = driver.cuMemPoolDestroy(self._mempool_handle) + raise_if_driver_error(err) + finally: + if self.is_mapped: + self.unregister() + self._dev_id = None + self._mempool_handle = None + self._attributes = None + self._ipc_handle_type = _NOIPC_HANDLE_TYPE + self._mempool_owned = False + self._is_mapped = False + self._uuid = None + self._alloc_handle = None - self._dev_id = None - self._mempool_handle = None - self._attributes = None - self._ipc_handle_type = _NOIPC_HANDLE_TYPE - self._mempool_owned = False - self._is_imported = False - @classmethod - def from_shared_channel(cls, device_id: int | Device, channel: IPCChannel) -> DeviceMemoryResource: - """Create a device memory resource from a memory pool shared over an IPC channel.""" - device_id = getattr(device_id, 'device_id', device_id) - alloc_handle = channel._proxy._receive_allocation_handle() - return cls._from_allocation_handle(device_id, alloc_handle) + def __reduce__(self): + return DeviceMemoryResource.from_registry, (self.uuid,) + + @staticmethod + def from_registry(uuid: uuid.UUID) -> DeviceMemoryResource: + """ + Obtain a registered mapped memory resource. + + Raises + ------ + RuntimeError + If no mapped memory resource is found in the registry. + """ + + try: + return _ipc_registry[uuid] + except KeyError: + raise RuntimeError(f"Memory resource {uuid} was not found") from None + + def register(self, uuid: uuid.UUID) -> DeviceMemoryResource: + """ + Register a mapped memory resource. + + Returns + ------- + The registered mapped memory resource. If one was previously registered + with the given key, it is returned. + """ + existing = _ipc_registry.get(uuid) + if existing is not None: + return existing + assert self._uuid is None or self._uuid == uuid + _ipc_registry[uuid] = self + self._uuid = uuid + return self + + def unregister(self): + """Unregister this mapped memory resource.""" + assert self.is_mapped + if _ipc_registry is not None: # can occur during shutdown catastrophe + with contextlib.suppress(KeyError): + del _ipc_registry[self.uuid] + + @property + def uuid(self) -> Optional[uuid.UUID]: + """ + A universally unique identifier for this memory resource. Meaningful + only for IPC-enabled memory resources. + """ + return self._uuid @classmethod - def _from_allocation_handle(cls, device_id: int | Device, alloc_handle: IPCAllocationHandle) -> DeviceMemoryResource: + def from_allocation_handle(cls, device_id: int | Device, alloc_handle: int | IPCAllocationHandle) -> DeviceMemoryResource: """Create a device memory resource from an allocation handle. Construct a new `DeviceMemoryResource` instance that imports a memory @@ -685,13 +774,19 @@ class DeviceMemoryResource(MemoryResource): The ID of the device or a Device object for which the memory resource is created. - alloc_handle : int + alloc_handle : int | IPCAllocationHandle The shareable handle of the device memory resource to import. Returns ------- A new device memory resource instance with the imported handle. """ + # Quick exit for registry hits. + uuid = getattr(alloc_handle, 'uuid', None) + self = _ipc_registry.get(uuid) + if self is not None: + return self + device_id = getattr(device_id, 'device_id', device_id) self = cls.__new__(cls) @@ -700,19 +795,18 @@ class DeviceMemoryResource(MemoryResource): self._attributes = None self._ipc_handle_type = _IPC_HANDLE_TYPE self._mempool_owned = True - self._is_imported = True + self._is_mapped = True + self._uuid = None + self._alloc_handle = None # only used for non-imported err, self._mempool_handle = driver.cuMemPoolImportFromShareableHandle(int(alloc_handle), _IPC_HANDLE_TYPE, 0) raise_if_driver_error(err) - + if uuid is not None: + registered = self.register(uuid) + assert registered is self return self - def share_to_channel(self, channel : IPCChannel): - if not self.is_ipc_enabled: - raise RuntimeError("Memory resource is not IPC-enabled") - channel._proxy._send_allocation_handle(self._get_allocation_handle()) - - def _get_allocation_handle(self) -> IPCAllocationHandle: + def get_allocation_handle(self) -> IPCAllocationHandle: """Export the memory pool handle to be shared (requires IPC). The handle can be used to share the memory pool with other processes. @@ -722,11 +816,22 @@ class DeviceMemoryResource(MemoryResource): ------- The shareable handle for the memory pool. """ - if not self.is_ipc_enabled: - raise RuntimeError("Memory resource is not IPC-enabled") - err, alloc_handle = driver.cuMemPoolExportToShareableHandle(self._mempool_handle, _IPC_HANDLE_TYPE, 0) - raise_if_driver_error(err) - return IPCAllocationHandle._init(alloc_handle) + if self._alloc_handle is None: + if not self.is_ipc_enabled: + raise RuntimeError("Memory resource is not IPC-enabled") + if self._is_mapped: + raise RuntimeError("Imported memory resource cannot be exported") + err, alloc_handle = driver.cuMemPoolExportToShareableHandle(self._mempool_handle, _IPC_HANDLE_TYPE, 0) + raise_if_driver_error(err) + try: + assert self._uuid is None + import uuid + self._uuid = uuid.uuid4() + self._alloc_handle = IPCAllocationHandle._init(alloc_handle, self._uuid) + except: + os.close(alloc_handle) + raise + return self._alloc_handle def allocate(self, size_t size, stream: Stream = None) -> Buffer: """Allocate a buffer of the requested size. @@ -745,8 +850,8 @@ class DeviceMemoryResource(MemoryResource): The allocated buffer object, which is accessible on the device that this memory resource was created for. """ - if self._is_imported: - raise TypeError("Cannot allocate from shared memory pool imported via IPC") + if self._is_mapped: + raise TypeError("Cannot allocate from a mapped IPC-enabled memory resource") if stream is None: stream = default_stream() err, ptr = driver.cuMemAllocFromPoolAsync(size, self._mempool_handle, stream.handle) @@ -794,9 +899,12 @@ class DeviceMemoryResource(MemoryResource): return self._mempool_owned @property - def is_imported(self) -> bool: - """Whether the memory resource was imported from another process. If True, allocation is not permitted.""" - return self._is_imported + def is_mapped(self) -> bool: + """ + Whether this is a mapping of an IPC-enabled memory resource from + another process. If True, allocation is not permitted. + """ + return self._is_mapped @property def is_device_accessible(self) -> bool: @@ -814,6 +922,16 @@ class DeviceMemoryResource(MemoryResource): return self._ipc_handle_type != _NOIPC_HANDLE_TYPE +def _deep_reduce_device_memory_resource(mr): + from . import Device + device = Device(mr.device_id) + alloc_handle = mr.get_allocation_handle() + return mr.from_allocation_handle, (device, alloc_handle) + + +multiprocessing.reduction.register(DeviceMemoryResource, _deep_reduce_device_memory_resource) + + class LegacyPinnedMemoryResource(MemoryResource): """Create a pinned memory resource that uses legacy cuMemAllocHost/cudaMallocHost APIs. diff --git a/cuda_core/docs/source/api.rst b/cuda_core/docs/source/api.rst index 9c93d0f75..f239c69cd 100644 --- a/cuda_core/docs/source/api.rst +++ b/cuda_core/docs/source/api.rst @@ -30,6 +30,7 @@ CUDA runtime :template: dataclass.rst + DeviceMemoryResourceOptions EventOptions GraphCompleteOptions GraphDebugPrintOptions diff --git a/cuda_core/docs/source/api_private.rst b/cuda_core/docs/source/api_private.rst index fb36e0a30..917b7101d 100644 --- a/cuda_core/docs/source/api_private.rst +++ b/cuda_core/docs/source/api_private.rst @@ -4,9 +4,9 @@ :orphan: .. This page is to generate documentation for private classes exposed to users, - i.e., users cannot instantiate it by themselves but may use it's properties - or methods via returned values from public APIs. These classes must be referred - in public APIs returning their instances. + i.e., users cannot instantiate them but may use their properties or methods + via returned values from public APIs. These classes must be referred in + public APIs returning their instances. .. currentmodule:: cuda.core.experimental @@ -18,8 +18,9 @@ CUDA runtime _memory.PyCapsule _memory.DevicePointerT - _memory.IPCBufferDescriptor _device.DeviceProperties + _memory.IPCAllocationHandle + _memory.IPCBufferDescriptor _module.KernelAttributes _module.KernelOccupancy _module.ParamInfo diff --git a/cuda_core/tests/conftest.py b/cuda_core/tests/conftest.py index c56c0a972..db9761a3c 100644 --- a/cuda_core/tests/conftest.py +++ b/cuda_core/tests/conftest.py @@ -10,7 +10,7 @@ import multiprocessing import pytest -from cuda.core.experimental import Device, _device +from cuda.core.experimental import Device, DeviceMemoryResource, DeviceMemoryResourceOptions, _device from cuda.core.experimental._utils.cuda_utils import handle_return @@ -70,4 +70,31 @@ def pop_all_contexts(): return pop_all_contexts +@pytest.fixture +def ipc_device(): + """Obtains a device suitable for IPC-enabled mempool tests, or skips.""" + # Check if IPC is supported on this platform/device + device = Device() + device.set_current() + + if not device.properties.memory_pools_supported: + pytest.skip("Device does not support mempool operations") + + # Note: Linux specific. Once Windows support for IPC is implemented, this + # test should be updated. + if not device.properties.handle_type_posix_file_descriptor_supported: + pytest.skip("Device does not support IPC") + + return device + + +@pytest.fixture +def ipc_memory_resource(ipc_device): + POOL_SIZE = 2097152 + options = DeviceMemoryResourceOptions(max_size=POOL_SIZE, ipc_enabled=True) + mr = DeviceMemoryResource(ipc_device, options=options) + assert mr.is_ipc_enabled + return mr + + skipif_need_cuda_headers = pytest.mark.skipif(helpers.CUDA_INCLUDE_PATH is None, reason="need CUDA header") diff --git a/cuda_core/tests/memory_ipc/test_errors.py b/cuda_core/tests/memory_ipc/test_errors.py new file mode 100644 index 000000000..3e8265b39 --- /dev/null +++ b/cuda_core/tests/memory_ipc/test_errors.py @@ -0,0 +1,143 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import multiprocessing +import pickle +import re + +from cuda.core.experimental import Buffer, Device, DeviceMemoryResource, DeviceMemoryResourceOptions +from cuda.core.experimental._utils.cuda_utils import CUDAError + +CHILD_TIMEOUT_SEC = 20 +NBYTES = 64 +POOL_SIZE = 2097152 + + +class ChildErrorHarness: + """Test harness for checking errors in child processes. Subclasses override + PARENT_ACTION, CHILD_ACTION, and ASSERT (see below for examples).""" + + def test_main(self, ipc_device, ipc_memory_resource): + """Parent process that checks child errors.""" + # Attach fixtures to this object for convenience. These can be accessed + # from PARENT_ACTION. + self.device = ipc_device + self.mr = ipc_memory_resource + + # Start a child process to generate error info. + pipe = [multiprocessing.Queue() for _ in range(2)] + process = multiprocessing.Process(target=self.child_main, args=(pipe, self.device, self.mr)) + process.start() + + # Interact. + self.PARENT_ACTION(pipe[0]) + + # Check the error. + exc_type, exc_msg = pipe[1].get(timeout=CHILD_TIMEOUT_SEC) + self.ASSERT(exc_type, exc_msg) + + # Wait for the child process. + process.join(timeout=CHILD_TIMEOUT_SEC) + assert process.exitcode == 0 + + def child_main(self, pipe, device, mr): + """Child process that pushes IPC errors to a shared pipe for testing.""" + self.device = device + self.device.set_current() + self.mr = mr + try: + self.CHILD_ACTION(pipe[0]) + except Exception as e: + exc_info = type(e), str(e) + else: + exc_info = None, None + pipe[1].put(exc_info) + + +class TestAllocFromImportedMr(ChildErrorHarness): + """Error when attempting to allocate from an import memory resource.""" + + def PARENT_ACTION(self, queue): + queue.put(self.mr) + + def CHILD_ACTION(self, queue): + mr = queue.get(timeout=CHILD_TIMEOUT_SEC) + mr.allocate(NBYTES) + + def ASSERT(self, exc_type, exc_msg): + assert exc_type is TypeError + assert exc_msg == "Cannot allocate from a mapped IPC-enabled memory resource" + + +class TestImportWrongMR(ChildErrorHarness): + """Error when importing a buffer from the wrong memory resource.""" + + def PARENT_ACTION(self, queue): + options = DeviceMemoryResourceOptions(max_size=POOL_SIZE, ipc_enabled=True) + mr2 = DeviceMemoryResource(self.device, options=options) + buffer = mr2.allocate(NBYTES) + queue.put([self.mr, buffer.get_ipc_descriptor()]) # Note: mr does not own this buffer + + def CHILD_ACTION(self, queue): + mr, buffer_desc = queue.get(timeout=CHILD_TIMEOUT_SEC) + Buffer.from_ipc_descriptor(mr, buffer_desc) + + def ASSERT(self, exc_type, exc_msg): + assert exc_type is CUDAError + assert "CUDA_ERROR_INVALID_VALUE" in exc_msg + + +class TestExportImportedMR(ChildErrorHarness): + """Error when exporting a memory resource that was imported.""" + + def PARENT_ACTION(self, queue): + queue.put(self.mr) + + def CHILD_ACTION(self, queue): + mr = queue.get(timeout=CHILD_TIMEOUT_SEC) + mr.get_allocation_handle() + + def ASSERT(self, exc_type, exc_msg): + assert exc_type is RuntimeError + assert exc_msg == "Imported memory resource cannot be exported" + + +class TestImportBuffer(ChildErrorHarness): + """Error when using a buffer as a buffer descriptor.""" + + def PARENT_ACTION(self, queue): + # Note: if the buffer is not attached to something to prolong its life, + # CUDA_ERROR_INVALID_CONTEXT is raised from Buffer.__del__ + self.buffer = self.mr.allocate(NBYTES) + queue.put(self.buffer) + + def CHILD_ACTION(self, queue): + buffer = queue.get(timeout=CHILD_TIMEOUT_SEC) + Buffer.from_ipc_descriptor(self.mr, buffer) + + def ASSERT(self, exc_type, exc_msg): + assert exc_type is TypeError + assert exc_msg.startswith("Argument 'ipc_buffer' has incorrect type") + + +class TestDanglingBuffer(ChildErrorHarness): + """ + Error when importing a buffer object without registering its memory + resource. + """ + + def PARENT_ACTION(self, queue): + options = DeviceMemoryResourceOptions(max_size=POOL_SIZE, ipc_enabled=True) + mr2 = DeviceMemoryResource(self.device, options=options) + self.buffer = mr2.allocate(NBYTES) + buffer_s = pickle.dumps(self.buffer) # noqa: S301 + queue.put(buffer_s) # Note: mr2 not sent + + def CHILD_ACTION(self, queue): + Device().set_current() + buffer_s = queue.get(timeout=CHILD_TIMEOUT_SEC) + pickle.loads(buffer_s) # noqa: S301 + + def ASSERT(self, exc_type, exc_msg): + assert exc_type is RuntimeError + assert re.match(r"Memory resource [a-z0-9-]+ was not found", exc_msg) diff --git a/cuda_core/tests/memory_ipc/test_leaks.py b/cuda_core/tests/memory_ipc/test_leaks.py new file mode 100644 index 000000000..bfead7dd3 --- /dev/null +++ b/cuda_core/tests/memory_ipc/test_leaks.py @@ -0,0 +1,142 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import contextlib +import gc +import multiprocessing as mp + +try: + import psutil +except ImportError: + HAVE_PSUTIL = False +else: + HAVE_PSUTIL = True + +import pytest +from cuda.core.experimental import _memory +from cuda.core.experimental._utils.cuda_utils import driver + +CHILD_TIMEOUT_SEC = 20 +NBYTES = 64 + +USING_FDS = _memory._IPC_HANDLE_TYPE == driver.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR +skip_if_unrunnable = pytest.mark.skipif( + not USING_FDS or not HAVE_PSUTIL, reason="mempool allocation handle is not using fds or psutil is unavailable" +) + + +@skip_if_unrunnable +def test_alloc_handle(ipc_memory_resource): + """Check for fd leaks in get_allocation_handle.""" + mr = ipc_memory_resource + with CheckFDLeaks(): + [mr.get_allocation_handle() for _ in range(10)] + + +def exec_success(obj, number=1): + """Succesfully run a child process.""" + for _ in range(number): + process = mp.Process(target=child_main, args=(obj,)) + process.start() + process.join() + assert process.exitcode == 0 + + +def child_main(obj, *args): + pass + + +def exec_launch_failure(obj, number=1): + """ + Unsuccesfully try to launch a child process. This fails when + after the child starts. + """ + for _ in range(number): + process = mp.Process(target=child_main_bad, args=(obj,)) + process.start() + process.join() + assert process.exitcode != 0 + + +def child_main_bad(): + """Fails when passed arguments.""" + pass + + +def exec_reduce_failure(obj, number=1): + """ + Unsuccesfully try to launch a child process. This fails before + the child starts but after the resource-owning object is serialized. + """ + for _ in range(number): + fails_to_reduce = Irreducible() + with contextlib.suppress(RuntimeError): + mp.Process(target=child_main, args=(obj, fails_to_reduce)).start() + + +class Irreducible: + """A class that cannot be serialized.""" + + def __reduce__(self): + raise RuntimeError("Irreducible") + + +@skip_if_unrunnable +@pytest.mark.parametrize( + "getobject", + [ + lambda mr: mr.get_allocation_handle(), + lambda mr: mr, + lambda mr: mr.allocate(NBYTES), + lambda mr: mr.allocate(NBYTES).get_ipc_descriptor(), + ], + ids=["alloc_handle", "mr", "buffer", "buffer_desc"], +) +@pytest.mark.parametrize("launcher", [exec_success, exec_launch_failure, exec_reduce_failure]) +def test_pass_object(ipc_memory_resource, launcher, getobject): + """Check for fd leaks when an object is sent as a subprocess argument.""" + mr = ipc_memory_resource + with CheckFDLeaks(): + obj = getobject(mr) + try: + launcher(obj, number=2) + finally: + del obj + + +class CheckFDLeaks: + """ + Context manager to check for file descriptor leaks. + Ensures the number of open file descriptors is the same before and after the block. + """ + + def __init__(self): + self.process = psutil.Process() + + def __enter__(self): + prime() + gc.collect() + self.initial_fds = self.process.num_fds() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if exc_type is not None: + gc.collect() + final_fds = self.process.num_fds() + assert final_fds == self.initial_fds + return False + + +prime_was_run = False + + +def prime(): + """Multiprocessing consumes a file descriptor on first launch.""" + assert mp.get_start_method() == "spawn" + global prime_was_run + if not prime_was_run: + process = mp.Process() + process.start() + process.join() + assert process.exitcode == 0 + prime_was_run = True diff --git a/cuda_core/tests/memory_ipc/test_memory_ipc.py b/cuda_core/tests/memory_ipc/test_memory_ipc.py new file mode 100644 index 000000000..9ed24792b --- /dev/null +++ b/cuda_core/tests/memory_ipc/test_memory_ipc.py @@ -0,0 +1,181 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import multiprocessing as mp + +from cuda.core.experimental import Buffer, DeviceMemoryResource +from utility import IPCBufferTestHelper + +CHILD_TIMEOUT_SEC = 20 +NBYTES = 64 +NWORKERS = 2 +NTASKS = 2 + + +class TestIpcMempool: + def test_main(self, ipc_device, ipc_memory_resource): + """Test IPC with memory pools.""" + # Set up the IPC-enabled memory pool and share it. + device = ipc_device + mr = ipc_memory_resource + + # Start the child process. + queue = mp.Queue() + process = mp.Process(target=self.child_main, args=(device, mr, queue)) + process.start() + + # Allocate and fill memory. + buffer = mr.allocate(NBYTES) + helper = IPCBufferTestHelper(device, buffer) + helper.fill_buffer(flipped=False) + + # Export the buffer via IPC. + queue.put(buffer) + + # Wait for the child process. + process.join(timeout=CHILD_TIMEOUT_SEC) + assert process.exitcode == 0 + + # Verify that the buffer was modified. + helper.verify_buffer(flipped=True) + + def child_main(self, device, mr, queue): + device.set_current() + buffer = queue.get(timeout=CHILD_TIMEOUT_SEC) + helper = IPCBufferTestHelper(device, buffer) + helper.verify_buffer(flipped=False) + helper.fill_buffer(flipped=True) + + +class TestIPCMempoolMultiple: + def test_main(self, ipc_device, ipc_memory_resource): + """Test IPC with memory pools using multiple processes.""" + # Construct an IPC-enabled memory resource and share it with two children. + device = ipc_device + mr = ipc_memory_resource + q1, q2 = (mp.Queue() for _ in range(2)) + + # Allocate memory buffers and export them to each child. + buffer1 = mr.allocate(NBYTES) + q1.put(buffer1) + q2.put(buffer1) + buffer2 = mr.allocate(NBYTES) + q1.put(buffer2) + q2.put(buffer2) + + # Start the child processes. + p1 = mp.Process(target=self.child_main, args=(device, mr, 1, q1)) + p2 = mp.Process(target=self.child_main, args=(device, mr, 2, q2)) + p1.start() + p2.start() + + # Wait for the child processes. + p1.join(timeout=CHILD_TIMEOUT_SEC) + p2.join(timeout=CHILD_TIMEOUT_SEC) + assert p1.exitcode == 0 + assert p2.exitcode == 0 + + # Verify that the buffers were modified. + IPCBufferTestHelper(device, buffer1).verify_buffer(flipped=False) + IPCBufferTestHelper(device, buffer2).verify_buffer(flipped=True) + + def child_main(self, device, mr, idx, queue): + # Note: passing the mr registers it so that buffers can be passed + # directly. + device.set_current() + buffer1 = queue.get(timeout=CHILD_TIMEOUT_SEC) + buffer2 = queue.get(timeout=CHILD_TIMEOUT_SEC) + if idx == 1: + IPCBufferTestHelper(device, buffer1).fill_buffer(flipped=False) + elif idx == 2: + IPCBufferTestHelper(device, buffer2).fill_buffer(flipped=True) + + +class TestIPCSharedAllocationHandleAndBufferDescriptors: + def test_main(self, ipc_device, ipc_memory_resource): + """ + Demonstrate that a memory pool allocation handle can be reused for IPC + with multiple processes. Uses buffer descriptors. + """ + # Set up the IPC-enabled memory pool and share it using one handle. + device = ipc_device + mr = ipc_memory_resource + alloc_handle = mr.get_allocation_handle() + + # Start children. + q1, q2 = (mp.Queue() for _ in range(2)) + p1 = mp.Process(target=self.child_main, args=(device, alloc_handle, 1, q1)) + p2 = mp.Process(target=self.child_main, args=(device, alloc_handle, 2, q2)) + p1.start() + p2.start() + + # Allocate and share memory. + buf1 = mr.allocate(NBYTES) + buf2 = mr.allocate(NBYTES) + q1.put(buf1.get_ipc_descriptor()) + q2.put(buf2.get_ipc_descriptor()) + + # Wait for children. + p1.join(timeout=CHILD_TIMEOUT_SEC) + p2.join(timeout=CHILD_TIMEOUT_SEC) + assert p1.exitcode == 0 + assert p2.exitcode == 0 + + # Verify results. + IPCBufferTestHelper(device, buf1).verify_buffer(starting_from=1) + IPCBufferTestHelper(device, buf2).verify_buffer(starting_from=2) + + def child_main(self, device, alloc_handle, idx, queue): + """Fills a shared memory buffer.""" + # In this case, the device needs to be set up (passing the mr does it + # implicitly in other tests). + device.set_current() + mr = DeviceMemoryResource.from_allocation_handle(device, alloc_handle) + buffer_descriptor = queue.get(timeout=CHILD_TIMEOUT_SEC) + buffer = Buffer.from_ipc_descriptor(mr, buffer_descriptor) + IPCBufferTestHelper(device, buffer).fill_buffer(starting_from=idx) + + +class TestIPCSharedAllocationHandleAndBufferObjects: + def test_main(self, ipc_device, ipc_memory_resource): + """ + Demonstrate that a memory pool allocation handle can be reused for IPC + with multiple processes. Uses buffer objects (not descriptors). + """ + device = ipc_device + mr = ipc_memory_resource + alloc_handle = mr.get_allocation_handle() + + # Start children. + q1, q2 = (mp.Queue() for _ in range(2)) + p1 = mp.Process(target=self.child_main, args=(device, alloc_handle, 1, q1)) + p2 = mp.Process(target=self.child_main, args=(device, alloc_handle, 2, q2)) + p1.start() + p2.start() + + # Allocate and share memory. + buf1 = mr.allocate(NBYTES) + buf2 = mr.allocate(NBYTES) + q1.put(buf1) + q2.put(buf2) + + # Wait for children. + p1.join(timeout=CHILD_TIMEOUT_SEC) + p2.join(timeout=CHILD_TIMEOUT_SEC) + assert p1.exitcode == 0 + assert p2.exitcode == 0 + + # Verify results. + IPCBufferTestHelper(device, buf1).verify_buffer(starting_from=1) + IPCBufferTestHelper(device, buf2).verify_buffer(starting_from=2) + + def child_main(self, device, alloc_handle, idx, queue): + """Fills a shared memory buffer.""" + device.set_current() + + # Register the memory resource. + DeviceMemoryResource.from_allocation_handle(device, alloc_handle) + + # Now get buffers. + buffer = queue.get(timeout=CHILD_TIMEOUT_SEC) + IPCBufferTestHelper(device, buffer).fill_buffer(starting_from=idx) diff --git a/cuda_core/tests/memory_ipc/test_send_buffers.py b/cuda_core/tests/memory_ipc/test_send_buffers.py new file mode 100644 index 000000000..966b6eafc --- /dev/null +++ b/cuda_core/tests/memory_ipc/test_send_buffers.py @@ -0,0 +1,57 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import multiprocessing as mp +from itertools import cycle + +import pytest +from cuda.core.experimental import DeviceMemoryResource, DeviceMemoryResourceOptions +from utility import IPCBufferTestHelper + +CHILD_TIMEOUT_SEC = 20 +NBYTES = 64 +NMRS = 3 +NTASKS = 7 +POOL_SIZE = 2097152 + + +@pytest.mark.parametrize("nmrs", (1, NMRS)) +def test_ipc_send_buffers(ipc_device, nmrs): + """Test passing buffers sourced from multiple memory resources.""" + # Set up several IPC-enabled memory pools. + device = ipc_device + options = DeviceMemoryResourceOptions(max_size=POOL_SIZE, ipc_enabled=True) + mrs = [DeviceMemoryResource(device, options=options) for _ in range(NMRS)] + + # Allocate and fill memory. + buffers = [mr.allocate(NBYTES) for mr, _ in zip(cycle(mrs), range(NTASKS))] + for buffer in buffers: + helper = IPCBufferTestHelper(device, buffer) + helper.fill_buffer(flipped=False) + + # Start the child process. + process = mp.Process( + target=child_main, + args=( + device, + buffers, + ), + ) + process.start() + + # Wait for the child process. + process.join(timeout=CHILD_TIMEOUT_SEC) + assert process.exitcode == 0 + + # Verify that the buffers were modified. + for buffer in buffers: + helper = IPCBufferTestHelper(device, buffer) + helper.verify_buffer(flipped=True) + + +def child_main(device, buffers): + device.set_current() + for buffer in buffers: + helper = IPCBufferTestHelper(device, buffer) + helper.verify_buffer(flipped=False) + helper.fill_buffer(flipped=True) diff --git a/cuda_core/tests/memory_ipc/test_serialize.py b/cuda_core/tests/memory_ipc/test_serialize.py new file mode 100644 index 000000000..2d88bcd03 --- /dev/null +++ b/cuda_core/tests/memory_ipc/test_serialize.py @@ -0,0 +1,180 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import multiprocessing as mp +import multiprocessing.reduction +import os + +from cuda.core.experimental import Buffer, Device, DeviceMemoryResource +from utility import IPCBufferTestHelper + +CHILD_TIMEOUT_SEC = 20 +NBYTES = 64 +POOL_SIZE = 2097152 + + +class TestObjectSerializationDirect: + """ + Test the low-level interface for sharing memory resources. + + Send a memory resource over a connection via Python's `send_handle`. Reconstruct + it on the other end and demonstrate buffer sharing. + """ + + def test_main(self, ipc_device, ipc_memory_resource): + device = ipc_device + mr = ipc_memory_resource + + # Start the child process. + parent_conn, child_conn = mp.Pipe() + process = mp.Process(target=self.child_main, args=(child_conn,)) + process.start() + + # Send a memory resource by allocation handle. + alloc_handle = mr.get_allocation_handle() + mp.reduction.send_handle(parent_conn, alloc_handle.handle, process.pid) + + # Send a buffer. + buffer1 = mr.allocate(NBYTES) + parent_conn.send(buffer1) # directly + + buffer2 = mr.allocate(NBYTES) + parent_conn.send(buffer2.get_ipc_descriptor()) # by descriptor + + # Wait for the child process. + process.join(timeout=CHILD_TIMEOUT_SEC) + assert process.exitcode == 0 + + # Confirm buffers were modified. + IPCBufferTestHelper(device, buffer1).verify_buffer(flipped=True) + IPCBufferTestHelper(device, buffer2).verify_buffer(flipped=True) + + def child_main(self, conn): + # Set up the device. + device = Device() + device.set_current() + + # Receive the memory resource. + handle = mp.reduction.recv_handle(conn) + mr = DeviceMemoryResource.from_allocation_handle(device, handle) + os.close(handle) + + # Receive the buffers. + buffer1 = conn.recv() # directly + buffer_desc = conn.recv() + buffer2 = Buffer.from_ipc_descriptor(mr, buffer_desc) # by descriptor + + # Modify the buffers. + IPCBufferTestHelper(device, buffer1).fill_buffer(flipped=True) + IPCBufferTestHelper(device, buffer2).fill_buffer(flipped=True) + + +class TestObjectSerializationWithMR: + def test_main(self, ipc_device, ipc_memory_resource): + """Test sending IPC memory objects to a child through a queue.""" + device = ipc_device + mr = ipc_memory_resource + + # Start the child process. Sending the memory resource registers it so + # that buffers can be handled automatically. + pipe = [mp.Queue() for _ in range(2)] + process = mp.Process(target=self.child_main, args=(pipe, mr)) + process.start() + + # Send a memory resource directly. This relies on the mr already + # being passed when spawning the child. + pipe[0].put(mr) + uuid = pipe[1].get(timeout=CHILD_TIMEOUT_SEC) + assert uuid == mr.uuid + + # Send a buffer. + buffer = mr.allocate(NBYTES) + pipe[0].put(buffer) + + # Wait for the child process. + process.join(timeout=CHILD_TIMEOUT_SEC) + assert process.exitcode == 0 + + # Confirm buffer was modified. + IPCBufferTestHelper(device, buffer).verify_buffer(flipped=True) + + def child_main(self, pipe, _): + device = Device() + device.set_current() + + # Memory resource. + mr = pipe[0].get(timeout=CHILD_TIMEOUT_SEC) + pipe[1].put(mr.uuid) + + # Buffer. + buffer = pipe[0].get(timeout=CHILD_TIMEOUT_SEC) + assert buffer.memory_resource.handle == mr.handle + IPCBufferTestHelper(device, buffer).fill_buffer(flipped=True) + + +def test_object_passing(ipc_device, ipc_memory_resource): + """ + Test sending objects as arguments when starting a process. + + True pickling of allocation handles and memory resources is enabled only + when spawning a process. This is similar to the way sockets and various objects + in multiprocessing (e.g., Queue) work. + """ + + # Define the objects. + device = ipc_device + mr = ipc_memory_resource + alloc_handle = mr.get_allocation_handle() + buffer = mr.allocate(NBYTES) + buffer_desc = buffer.get_ipc_descriptor() + + helper = IPCBufferTestHelper(device, buffer) + helper.fill_buffer(flipped=False) + + # Start the child process. + process = mp.Process(target=child_main, args=(alloc_handle, mr, buffer_desc, buffer)) + process.start() + process.join(timeout=CHILD_TIMEOUT_SEC) + assert process.exitcode == 0 + + helper.verify_buffer(flipped=True) + + +def child_main(alloc_handle, mr1, buffer_desc, buffer1): + device = Device() + device.set_current() + mr2 = DeviceMemoryResource.from_allocation_handle(device, alloc_handle) + + # OK to build the buffer from either mr and the descriptor. + # All buffer* objects point to the same memory. + buffer2 = Buffer.from_ipc_descriptor(mr1, buffer_desc) + buffer3 = Buffer.from_ipc_descriptor(mr2, buffer_desc) + + helper1 = IPCBufferTestHelper(device, buffer1) + helper2 = IPCBufferTestHelper(device, buffer2) + helper3 = IPCBufferTestHelper(device, buffer3) + + helper1.verify_buffer(flipped=False) + helper2.verify_buffer(flipped=False) + helper3.verify_buffer(flipped=False) + + # Modify 1. + helper1.fill_buffer(flipped=True) + + helper1.verify_buffer(flipped=True) + helper2.verify_buffer(flipped=True) + helper3.verify_buffer(flipped=True) + + # Modify 2. + helper2.fill_buffer(flipped=False) + + helper1.verify_buffer(flipped=False) + helper2.verify_buffer(flipped=False) + helper3.verify_buffer(flipped=False) + + # Modify 3. + helper3.fill_buffer(flipped=True) + + helper1.verify_buffer(flipped=True) + helper2.verify_buffer(flipped=True) + helper3.verify_buffer(flipped=True) diff --git a/cuda_core/tests/memory_ipc/test_workerpool.py b/cuda_core/tests/memory_ipc/test_workerpool.py new file mode 100644 index 000000000..401324e05 --- /dev/null +++ b/cuda_core/tests/memory_ipc/test_workerpool.py @@ -0,0 +1,117 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import multiprocessing as mp +import pickle +from itertools import cycle + +import pytest +from cuda.core.experimental import Buffer, Device, DeviceMemoryResource, DeviceMemoryResourceOptions +from utility import IPCBufferTestHelper + +CHILD_TIMEOUT_SEC = 20 +NBYTES = 64 +NWORKERS = 2 +NMRS = 3 +NTASKS = 20 +POOL_SIZE = 2097152 + + +class TestIpcWorkerPool: + """ + Map a function over shared buffers using a worker pool to distribute work. + + This demonstrates the simplest interface, though not the most efficient + one. Each buffer transfer involes a deep transfer of the associated memory + resource (duplicates are ignored on the receiving end). + """ + + @pytest.mark.parametrize("nmrs", (1, NMRS)) + def test_main(self, ipc_device, nmrs): + device = ipc_device + options = DeviceMemoryResourceOptions(max_size=POOL_SIZE, ipc_enabled=True) + mrs = [DeviceMemoryResource(device, options=options) for _ in range(nmrs)] + buffers = [mr.allocate(NBYTES) for mr, _ in zip(cycle(mrs), range(NTASKS))] + + with mp.Pool(NWORKERS) as pool: + pool.map(self.process_buffer, buffers) + + for buffer in buffers: + IPCBufferTestHelper(device, buffer).verify_buffer(flipped=True) + + def process_buffer(self, buffer): + device = Device(buffer.memory_resource.device_id) + device.set_current() + IPCBufferTestHelper(device, buffer).fill_buffer(flipped=True) + + +class TestIpcWorkerPoolUsingIPCDescriptors: + """ + Test buffer sharing using IPC descriptors. + + The memory resources are passed to subprocesses at startup. Buffers are + passed by their handles and reconstructed using the corresponding resource. + """ + + @staticmethod + def init_worker(mrs): + """Called during child process initialization to store received memory resources.""" + TestIpcWorkerPoolUsingIPCDescriptors.mrs = mrs + + @pytest.mark.parametrize("nmrs", (1, NMRS)) + def test_main(self, ipc_device, nmrs): + device = ipc_device + options = DeviceMemoryResourceOptions(max_size=POOL_SIZE, ipc_enabled=True) + mrs = [DeviceMemoryResource(device, options=options) for _ in range(nmrs)] + buffers = [mr.allocate(NBYTES) for mr, _ in zip(cycle(mrs), range(NTASKS))] + + with mp.Pool(NWORKERS, initializer=self.init_worker, initargs=(mrs,)) as pool: + pool.starmap( + self.process_buffer, + [(mrs.index(buffer.memory_resource), buffer.get_ipc_descriptor()) for buffer in buffers], + ) + + for buffer in buffers: + IPCBufferTestHelper(device, buffer).verify_buffer(flipped=True) + + def process_buffer(self, mr_idx, buffer_desc): + mr = self.mrs[mr_idx] + device = Device(mr.device_id) + device.set_current() + buffer = Buffer.from_ipc_descriptor(mr, buffer_desc) + IPCBufferTestHelper(device, buffer).fill_buffer(flipped=True) + + +class TestIpcWorkerPoolUsingRegistry: + """ + Test buffer sharing using the memory resource registry. + + The memory resources are passed to subprocesses at startup, which + implicitly registers them. Buffers are passed via serialization and matched + to the corresponding memory resource through the registry. This is more + complicated than the simple example (first, above) but passes buffers more + efficiently. + """ + + @staticmethod + def init_worker(mrs): + # Passing mrs implicitly registers them. + pass + + @pytest.mark.parametrize("nmrs", (1, NMRS)) + def test_main(self, ipc_device, nmrs): + device = ipc_device + options = DeviceMemoryResourceOptions(max_size=POOL_SIZE, ipc_enabled=True) + mrs = [DeviceMemoryResource(device, options=options) for _ in range(nmrs)] + buffers = [mr.allocate(NBYTES) for mr, _ in zip(cycle(mrs), range(NTASKS))] + + with mp.Pool(NWORKERS, initializer=self.init_worker, initargs=(mrs,)) as pool: + pool.starmap(self.process_buffer, [(device, pickle.dumps(buffer)) for buffer in buffers]) + + for buffer in buffers: + IPCBufferTestHelper(device, buffer).verify_buffer(flipped=True) + + def process_buffer(self, device, buffer_s): + device.set_current() + buffer = pickle.loads(buffer_s) # noqa: S301 + IPCBufferTestHelper(device, buffer).fill_buffer(flipped=True) diff --git a/cuda_core/tests/memory_ipc/utility.py b/cuda_core/tests/memory_ipc/utility.py new file mode 100644 index 000000000..7ce7752b6 --- /dev/null +++ b/cuda_core/tests/memory_ipc/utility.py @@ -0,0 +1,65 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import ctypes + +from cuda.core.experimental import Buffer, MemoryResource +from cuda.core.experimental._utils.cuda_utils import driver, handle_return + + +class DummyUnifiedMemoryResource(MemoryResource): + def __init__(self, device): + self.device = device + + def allocate(self, size, stream=None) -> Buffer: + ptr = handle_return(driver.cuMemAllocManaged(size, driver.CUmemAttach_flags.CU_MEM_ATTACH_GLOBAL.value)) + return Buffer.from_handle(ptr=ptr, size=size, mr=self) + + def deallocate(self, ptr, size, stream=None): + handle_return(driver.cuMemFree(ptr)) + + @property + def is_device_accessible(self) -> bool: + return True + + @property + def is_host_accessible(self) -> bool: + return True + + @property + def device_id(self) -> int: + return self.device + + +class IPCBufferTestHelper: + """A helper for manipulating memory buffers in IPC tests. + + Provides methods to fill a buffer with one of two test patterns and verify + the expected values. + """ + + def __init__(self, device, buffer): + self.device = device + self.buffer = buffer + self.scratch_buffer = DummyUnifiedMemoryResource(self.device).allocate(self.buffer.size) + self.stream = device.create_stream() + + def fill_buffer(self, flipped=False, starting_from=0): + """Fill a device buffer with test pattern using unified memory.""" + ptr = ctypes.cast(int(self.scratch_buffer.handle), ctypes.POINTER(ctypes.c_byte)) + op = (lambda i: 255 - i) if flipped else (lambda i: i) + for i in range(self.buffer.size): + ptr[i] = ctypes.c_byte(op(starting_from + i)) + self.buffer.copy_from(self.scratch_buffer, stream=self.stream) + self.device.sync() + + def verify_buffer(self, flipped=False, starting_from=0): + """Verify the buffer contents.""" + self.scratch_buffer.copy_from(self.buffer, stream=self.stream) + self.device.sync() + ptr = ctypes.cast(int(self.scratch_buffer.handle), ctypes.POINTER(ctypes.c_byte)) + op = (lambda i: 255 - i) if flipped else (lambda i: i) + for i in range(self.buffer.size): + assert ctypes.c_byte(ptr[i]).value == ctypes.c_byte(op(starting_from + i)).value, ( + f"Buffer contains incorrect data at index {i}" + ) diff --git a/cuda_core/tests/test_ipc_mempool.py b/cuda_core/tests/test_ipc_mempool.py deleted file mode 100644 index de436fd48..000000000 --- a/cuda_core/tests/test_ipc_mempool.py +++ /dev/null @@ -1,178 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: Apache-2.0 - -try: - from cuda.bindings import driver -except ImportError: - from cuda import cuda as driver - -import ctypes -import multiprocessing - -import pytest -from cuda.core.experimental import Buffer, Device, DeviceMemoryResource, IPCChannel, MemoryResource -from cuda.core.experimental._utils.cuda_utils import handle_return - -CHILD_TIMEOUT_SEC = 10 -NBYTES = 64 -POOL_SIZE = 2097152 - - -@pytest.fixture(scope="function") -def ipc_device(): - """Obtains a device suitable for IPC-enabled mempool tests, or skips.""" - # Check if IPC is supported on this platform/device - device = Device() - device.set_current() - - if not device.properties.memory_pools_supported: - pytest.skip("Device does not support mempool operations") - - # Note: Linux specific. Once Windows support for IPC is implemented, this - # test should be updated. - if not device.properties.handle_type_posix_file_descriptor_supported: - pytest.skip("Device does not support IPC") - - return device - - -def test_ipc_mempool(ipc_device): - """Test IPC with memory pools.""" - # Set up the IPC-enabled memory pool and share it. - stream = ipc_device.create_stream() - mr = DeviceMemoryResource(ipc_device, dict(max_size=POOL_SIZE, ipc_enabled=True)) - assert mr.is_ipc_enabled - channel = IPCChannel() - mr.share_to_channel(channel) - - # Start the child process. - queue = multiprocessing.Queue() - process = multiprocessing.Process(target=child_main1, args=(channel, queue)) - process.start() - - # Allocate and fill memory. - buffer = mr.allocate(NBYTES, stream=stream) - protocol = IPCBufferTestProtocol(ipc_device, buffer, stream=stream) - protocol.fill_buffer(flipped=False) - stream.sync() - - # Export the buffer via IPC. - handle = buffer.export() - queue.put(handle) - - # Wait for the child process. - process.join(timeout=CHILD_TIMEOUT_SEC) - assert process.exitcode == 0 - - # Verify that the buffer was modified. - protocol.verify_buffer(flipped=True) - - -def child_main1(channel, queue): - device = Device() - device.set_current() - stream = device.create_stream() - - mr = DeviceMemoryResource.from_shared_channel(device, channel) - handle = queue.get() # Get exported buffer data - buffer = Buffer.import_(mr, handle) - - protocol = IPCBufferTestProtocol(device, buffer, stream=stream) - protocol.verify_buffer(flipped=False) - protocol.fill_buffer(flipped=True) - stream.sync() - - -def test_shared_pool_errors(ipc_device): - """Test expected errors with allocating from a shared IPC memory pool.""" - # Set up the IPC-enabled memory pool and share it. - mr = DeviceMemoryResource(ipc_device, dict(max_size=POOL_SIZE, ipc_enabled=True)) - channel = IPCChannel() - mr.share_to_channel(channel) - - # Start a child process to generate error info. - queue = multiprocessing.Queue() - process = multiprocessing.Process(target=child_main2, args=(channel, queue)) - process.start() - - # Check the errors. - exc_type, exc_msg = queue.get(timeout=CHILD_TIMEOUT_SEC) - assert exc_type is TypeError - assert exc_msg == "Cannot allocate from shared memory pool imported via IPC" - - # Wait for the child process. - process.join(timeout=CHILD_TIMEOUT_SEC) - assert process.exitcode == 0 - - -def child_main2(channel, queue): - """Child process that pushes IPC errors to a shared queue for testing.""" - device = Device() - device.set_current() - - mr = DeviceMemoryResource.from_shared_channel(device, channel) - - # Allocating from an imported pool. - try: - mr.allocate(NBYTES) - except Exception as e: - exc_info = type(e), str(e) - queue.put(exc_info) - - -class DummyUnifiedMemoryResource(MemoryResource): - def __init__(self, device): - self.device = device - - def allocate(self, size, stream=None) -> Buffer: - ptr = handle_return(driver.cuMemAllocManaged(size, driver.CUmemAttach_flags.CU_MEM_ATTACH_GLOBAL.value)) - return Buffer.from_handle(ptr=ptr, size=size, mr=self) - - def deallocate(self, ptr, size, stream=None): - handle_return(driver.cuMemFree(ptr)) - - @property - def is_device_accessible(self) -> bool: - return True - - @property - def is_host_accessible(self) -> bool: - return True - - @property - def device_id(self) -> int: - return self.device - - -class IPCBufferTestProtocol: - """The protocol for verifying IPC. - - Provides methods to fill a buffer with one of two test patterns and verify - the expected values. - """ - - def __init__(self, device, buffer, nbytes=NBYTES, stream=None): - self.device = device - self.buffer = buffer - self.nbytes = nbytes - self.stream = stream if stream is not None else device.create_stream() - self.scratch_buffer = DummyUnifiedMemoryResource(self.device).allocate(self.nbytes, stream=self.stream) - - def fill_buffer(self, flipped=False): - """Fill a device buffer with test pattern using unified memory.""" - ptr = ctypes.cast(int(self.scratch_buffer.handle), ctypes.POINTER(ctypes.c_byte)) - op = (lambda i: 255 - i) if flipped else (lambda i: i) - for i in range(self.nbytes): - ptr[i] = ctypes.c_byte(op(i)) - self.buffer.copy_from(self.scratch_buffer, stream=self.stream) - - def verify_buffer(self, flipped=False): - """Verify the buffer contents.""" - self.scratch_buffer.copy_from(self.buffer, stream=self.stream) - self.stream.sync() - ptr = ctypes.cast(int(self.scratch_buffer.handle), ctypes.POINTER(ctypes.c_byte)) - op = (lambda i: 255 - i) if flipped else (lambda i: i) - for i in range(self.nbytes): - assert ctypes.c_byte(ptr[i]).value == ctypes.c_byte(op(i)).value, ( - f"Buffer contains incorrect data at index {i}" - ) diff --git a/cuda_core/tests/test_memory.py b/cuda_core/tests/test_memory.py index 4ffa813d6..5886433b2 100644 --- a/cuda_core/tests/test_memory.py +++ b/cuda_core/tests/test_memory.py @@ -15,7 +15,7 @@ import platform import pytest -from cuda.core.experimental import Buffer, Device, DeviceMemoryResource, MemoryResource +from cuda.core.experimental import Buffer, Device, DeviceMemoryResource, DeviceMemoryResourceOptions, MemoryResource from cuda.core.experimental._memory import DLDeviceType, IPCBufferDescriptor from cuda.core.experimental._utils.cuda_utils import handle_return from cuda.core.experimental.utils import StridedMemoryView @@ -310,7 +310,8 @@ def test_mempool(mempool_device): device = mempool_device # Test basic pool creation - mr = DeviceMemoryResource(device, dict(max_size=POOL_SIZE, ipc_enabled=False)) + options = DeviceMemoryResourceOptions(max_size=POOL_SIZE, ipc_enabled=False) + mr = DeviceMemoryResource(device, options=options) assert mr.device_id == device.device_id assert mr.is_device_accessible assert not mr.is_host_accessible @@ -353,14 +354,14 @@ def test_mempool(mempool_device): ipc_error_msg = "Memory resource is not IPC-enabled" with pytest.raises(RuntimeError, match=ipc_error_msg): - mr._get_allocation_handle() + mr.get_allocation_handle() with pytest.raises(RuntimeError, match=ipc_error_msg): - buffer.export() + buffer.get_ipc_descriptor() with pytest.raises(RuntimeError, match=ipc_error_msg): handle = IPCBufferDescriptor._init(b"", 0) - Buffer.import_(mr, handle) + Buffer.from_ipc_descriptor(mr, handle) buffer.close() @@ -385,7 +386,8 @@ def test_mempool_attributes(ipc_enabled, mempool_device, property_name, expected if platform.system() == "Windows": return # IPC not implemented for Windows - mr = DeviceMemoryResource(device, dict(max_size=POOL_SIZE, ipc_enabled=ipc_enabled)) + options = DeviceMemoryResourceOptions(max_size=POOL_SIZE, ipc_enabled=ipc_enabled) + mr = DeviceMemoryResource(device, options=options) assert mr.is_ipc_enabled == ipc_enabled # Get the property value