diff --git a/cuda_core/cuda/core/experimental/_memory.pyx b/cuda_core/cuda/core/experimental/_memory.pyx index 024ffa2ae..7c79b775f 100644 --- a/cuda_core/cuda/core/experimental/_memory.pyx +++ b/cuda_core/cuda/core/experimental/_memory.pyx @@ -1119,7 +1119,7 @@ class VirtualMemoryResourceOptions: location_type: VirtualMemoryLocationTypeT = "device" handle_type: VirtualMemoryHandleTypeT = "posix_fd" granularity: VirtualMemoryGranularityT = "recommended" - gpu_direct_rdma: bool = True + gpu_direct_rdma: bool = False addr_hint: Optional[int] = 0 addr_align: Optional[int] = None peers: Iterable[int] = field(default_factory=tuple) @@ -1198,6 +1198,11 @@ class VirtualMemoryResource(MemoryResource): if platform.system() == "Windows": raise NotImplementedError("VirtualMemoryResource is not supported on Windows") + # Fail early if gpu_direct_rdma is requested but the device lacks RDMA support + if self.config.gpu_direct_rdma and self.device is not None: + if not self.device.properties.gpu_direct_rdma_supported: + raise RuntimeError("GPU Direct RDMA is not supported on this device") + @staticmethod def _align_up(size: int, gran: int) -> int: """ diff --git a/cuda_core/tests/test_memory.py b/cuda_core/tests/test_memory.py index 904997f11..3e69bab42 100644 --- a/cuda_core/tests/test_memory.py +++ b/cuda_core/tests/test_memory.py @@ -28,7 +28,7 @@ from cuda.core.experimental._utils.cuda_utils import handle_return from cuda.core.experimental.utils import StridedMemoryView -from cuda_python_test_helpers import IS_WSL, supports_ipc_mempool +from cuda_python_test_helpers import supports_ipc_mempool POOL_SIZE = 2097152 # 2MB size @@ -322,13 +322,13 @@ def test_vmm_allocator_basic_allocation(): This test verifies that VirtualMemoryResource can allocate memory using CUDA VMM APIs with default configuration. 
""" - if platform.system() == "Windows": - pytest.skip("VirtualMemoryResource is not supported on Windows TCC") - if IS_WSL: - pytest.skip("VirtualMemoryResource is not supported on WSL") - device = Device() device.set_current() + + # Skip if virtual memory management is not supported + if not device.properties.virtual_memory_management_supported: + pytest.skip("Virtual memory management is not supported on this device") + options = VirtualMemoryResourceOptions() # Create VMM allocator with default config vmm_mr = VirtualMemoryResource(device, config=options) @@ -361,13 +361,17 @@ def test_vmm_allocator_policy_configuration(): with different allocation policies and that the configuration affects the allocation behavior. """ - if platform.system() == "Windows": - pytest.skip("VirtualMemoryResource is not supported on Windows TCC") - if IS_WSL: - pytest.skip("VirtualMemoryResource is not supported on WSL") device = Device() device.set_current() + # Skip if virtual memory management is not supported + if not device.properties.virtual_memory_management_supported: + pytest.skip("Virtual memory management is not supported on this device") + + # Skip if GPU Direct RDMA is supported (we want to test the unsupported case) + if not device.properties.gpu_direct_rdma_supported: + pytest.skip("This test requires a device that doesn't support GPU Direct RDMA") + # Test with custom VMM config custom_config = VirtualMemoryResourceOptions( allocation_type="pinned", @@ -420,13 +424,13 @@ def test_vmm_allocator_grow_allocation(): This test verifies that VirtualMemoryResource can grow existing allocations while preserving the base pointer when possible. 
""" - if platform.system() == "Windows": - pytest.skip("VirtualMemoryResource is not supported on Windows TCC") - if IS_WSL: - pytest.skip("VirtualMemoryResource is not supported on WSL") device = Device() device.set_current() + # Skip if virtual memory management is not supported (we need it for VMM) + if not device.properties.virtual_memory_management_supported: + pytest.skip("Virtual memory management is not supported on this device") + options = VirtualMemoryResourceOptions() vmm_mr = VirtualMemoryResource(device, config=options) @@ -458,6 +462,29 @@ def test_vmm_allocator_grow_allocation(): grown_buffer.close() +def test_vmm_allocator_rdma_unsupported_exception(): + """Test that VirtualMemoryResource throws an exception when RDMA is requested but device doesn't support it. + + This test verifies that the VirtualMemoryResource constructor throws a RuntimeError + when gpu_direct_rdma=True is requested but the device doesn't support virtual memory management. + """ + device = Device() + device.set_current() + + # Skip if virtual memory management is not supported (we need it for VMM) + if not device.properties.virtual_memory_management_supported: + pytest.skip("Virtual memory management is not supported on this device") + + # Skip if GPU Direct RDMA is supported (we want to test the unsupported case) + if device.properties.gpu_direct_rdma_supported: + pytest.skip("This test requires a device that doesn't support GPU Direct RDMA") + + # Test that requesting RDMA on an unsupported device throws an exception + options = VirtualMemoryResourceOptions(gpu_direct_rdma=True) + with pytest.raises(RuntimeError, match="GPU Direct RDMA is not supported on this device"): + VirtualMemoryResource(device, config=options) + + def test_mempool(mempool_device): device = mempool_device