Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion cuda_core/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ cu12 = ["cuda-bindings[all]==12.*"]
cu13 = ["cuda-bindings[all]==13.*"]

[dependency-groups]
test = ["cython>=3.2,<3.3", "setuptools", "pytest>=6.2.4"]
# use cffi for VMM tests on Windows
test = ["cython>=3.2,<3.3", "setuptools", "pytest>=6.2.4", "cffi; platform_system == \"Windows\""]
test-cu12 = ["cuda-core[test]", "cupy-cuda12x; python_version < '3.14'", "cuda-toolkit[cudart]==12.*"] # runtime headers needed by CuPy
test-cu13 = ["cuda-core[test]", "cupy-cuda13x; python_version < '3.14'", "cuda-toolkit[cudart]==13.*"] # runtime headers needed by CuPy
# free threaded build, cupy doesn't support free-threaded builds yet, so avoid installing it for now
Expand Down
77 changes: 56 additions & 21 deletions cuda_core/tests/test_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
# SPDX-License-Identifier: Apache-2.0

import ctypes
import functools
import sys
from ctypes import wintypes

try:
from cuda.bindings import driver
Expand Down Expand Up @@ -314,30 +314,65 @@ def test_device_memory_resource_initialization(mempool_device, use_device_object
buffer.close()


def get_handle_type():
def get_sa():
class SECURITY_ATTRIBUTES(ctypes.Structure):
_fields_ = [
("nLength", wintypes.DWORD),
("lpSecurityDescriptor", wintypes.LPVOID),
("bInheritHandle", wintypes.BOOL),
]
@functools.cache
def get_win32_object_attributes():
"""If success, the returned pointer address to the OBJECT_ATTRIBUTES
object is valid within the test process lifetime. This helper function
is needed because OBJECT_ATTRIBUTES must be initialized by a macro
function, and the usual FFI/ctypes tricks do not work.
"""
import tempfile

from cffi import FFI

prog = FFI()
prog.cdef("uintptr_t get_ptr_to_oa();")
source = r"""
#define _AMD64_
#include <ntdef.h>

static OBJECT_ATTRIBUTES objAttributes;

extern "C"
uintptr_t get_ptr_to_oa() {
static bool objAttributesConfigured = false;
if (!objAttributesConfigured) {
InitializeObjectAttributes(&objAttributes, NULL, 0, NULL, NULL);
objAttributesConfigured = true;
}
return reinterpret_cast<uintptr_t>(&objAttributes);
}
"""
prog.set_source(
"mod_oa",
source,
source_extension=".cpp",
extra_compile_args=("/std:c++17",),
)
temp_dir = tempfile.mkdtemp()
prog.compile(tmpdir=temp_dir)

sa = SECURITY_ATTRIBUTES()
sa.nLength = ctypes.sizeof(sa)
sa.lpSecurityDescriptor = None
sa.bInheritHandle = False # TODO: why?
sys.path.append(temp_dir)
from mod_oa.lib import get_ptr_to_oa

return sa
return get_ptr_to_oa()


def get_handle_types():
if IS_WINDOWS:
return (("win32", get_sa()), ("win32_kmt", None))
tests = [("win32_kmt", None)]
try:
ptr_to_oa = get_win32_object_attributes()
tests.append(("win32", ptr_to_oa))
except Exception: # noqa: S110
pass
return tests
else:
return (("posix_fd", None),)


@pytest.mark.parametrize("use_device_object", [True, False])
@pytest.mark.parametrize("handle_type", get_handle_type())
@pytest.mark.parametrize("handle_type", get_handle_types())
def test_vmm_allocator_basic_allocation(use_device_object, handle_type):
"""Test basic VMM allocation functionality.

Expand All @@ -351,8 +386,8 @@ def test_vmm_allocator_basic_allocation(use_device_object, handle_type):
if not device.properties.virtual_memory_management_supported:
pytest.skip("Virtual memory management is not supported on this device")

handle_type, security_attribute = handle_type # unpack
win32_handle_metadata = ctypes.addressof(security_attribute) if security_attribute else 0
handle_type, object_attributes = handle_type # unpack
win32_handle_metadata = object_attributes if handle_type == "win32" else 0
options = VirtualMemoryResourceOptions(
handle_type=handle_type,
win32_handle_metadata=win32_handle_metadata,
Expand Down Expand Up @@ -446,7 +481,7 @@ def test_vmm_allocator_policy_configuration():
modified_buffer.close()


@pytest.mark.parametrize("handle_type", get_handle_type())
@pytest.mark.parametrize("handle_type", get_handle_types())
def test_vmm_allocator_grow_allocation(handle_type):
"""Test VMM allocator's ability to grow existing allocations.

Expand All @@ -460,8 +495,8 @@ def test_vmm_allocator_grow_allocation(handle_type):
if not device.properties.virtual_memory_management_supported:
pytest.skip("Virtual memory management is not supported on this device")

handle_type, security_attribute = handle_type # unpack
win32_handle_metadata = ctypes.addressof(security_attribute) if security_attribute else 0
handle_type, object_attributes = handle_type # unpack
win32_handle_metadata = object_attributes if handle_type == "win32" else 0
options = VirtualMemoryResourceOptions(
handle_type=handle_type,
win32_handle_metadata=win32_handle_metadata,
Expand Down
Loading