diff --git a/cuda_core/tests/helpers/buffers.py b/cuda_core/tests/helpers/buffers.py index fbd5428c28..d6166b2536 100644 --- a/cuda_core/tests/helpers/buffers.py +++ b/cuda_core/tests/helpers/buffers.py @@ -18,6 +18,30 @@ ] +def _is_managed_ptr(ptr) -> bool: + try: + attr = driver.CUpointer_attribute.CU_POINTER_ATTRIBUTE_IS_MANAGED + return bool(handle_return(driver.cuPointerGetAttribute(attr, ptr))) + except Exception: + return False + + +def _sync_for_host_managed_access(buffer) -> None: + if not _is_managed_ptr(buffer.handle): + return + device = getattr(buffer.memory_resource, "device", None) + if device is None: + try: + device = Device(int(buffer.device_id)) + except Exception: + return + try: + if not device.properties.concurrent_managed_access: + device.sync() + except AttributeError: + return + + class DummyUnifiedMemoryResource(MemoryResource): def __init__(self, device): self.device = device @@ -112,6 +136,7 @@ def verify_buffer(self, buffer, seed=None, value=None): ptr_expected = self._ptr(pattern_buffer) scratch_buffer.copy_from(buffer, stream=self.stream) self.sync_target.sync() + _sync_for_host_managed_access(scratch_buffer) assert libc.memcmp(ptr_test, ptr_expected, self.size) == 0 @staticmethod @@ -132,6 +157,7 @@ def _get_pattern_buffer(self, seed, value): else: pattern_buffer = DummyUnifiedMemoryResource(self.device).allocate(self.size) ptr = self._ptr(pattern_buffer) + _sync_for_host_managed_access(pattern_buffer) for i in range(self.size): ptr[i] = (seed + i) & 0xFF self.pattern_buffers[key] = pattern_buffer @@ -148,11 +174,14 @@ def make_scratch_buffer(device, value, nbytes): def set_buffer(buffer, value): assert 0 <= int(value) < 256 ptr = ctypes.cast(int(buffer.handle), ctypes.POINTER(ctypes.c_byte)) + _sync_for_host_managed_access(buffer) ctypes.memset(ptr, value & 0xFF, buffer.size) def compare_equal_buffers(buffer1, buffer2): """Compare the contents of two host-accessible buffers for bitwise equality.""" + _sync_for_host_managed_access(buffer1) + _sync_for_host_managed_access(buffer2) if buffer1.size != buffer2.size: return False ptr1 = ctypes.cast(int(buffer1.handle), ctypes.POINTER(ctypes.c_byte))