-
Notifications
You must be signed in to change notification settings - Fork 228
Add Buffer.fill() method for cuMemsetAsync support (#1314) #1318
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
c0d4284
8294553
a8ee5ec
7d9747d
07c65d2
35763df
8e2cddf
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -4,15 +4,14 @@ | |||||
|
|
||||||
| from __future__ import annotations | ||||||
|
|
||||||
| from libc.stdint cimport uintptr_t | ||||||
| from libc.stdint cimport uintptr_t, int64_t, uint64_t | ||||||
|
|
||||||
| from cuda.bindings cimport cydriver | ||||||
| from cuda.core.experimental._memory._device_memory_resource cimport DeviceMemoryResource | ||||||
| from cuda.core.experimental._memory._ipc cimport IPCBufferDescriptor, IPCDataForBuffer | ||||||
| from cuda.core.experimental._memory cimport _ipc | ||||||
| from cuda.core.experimental._stream cimport Stream_accept, Stream | ||||||
| from cuda.core.experimental._utils.cuda_utils cimport ( | ||||||
| _check_driver_error as raise_if_driver_error, | ||||||
| ) | ||||||
| from cuda.core.experimental._utils.cuda_utils cimport HANDLE_RETURN | ||||||
|
|
||||||
| import abc | ||||||
| from typing import TypeVar, Union | ||||||
|
|
@@ -137,6 +136,7 @@ cdef class Buffer: | |||||
|
|
||||||
| """ | ||||||
| stream = Stream_accept(stream) | ||||||
| cdef Stream s_stream = <Stream>stream | ||||||
| cdef size_t src_size = self._size | ||||||
|
|
||||||
| if dst is None: | ||||||
|
|
@@ -150,8 +150,14 @@ cdef class Buffer: | |||||
| raise ValueError( "buffer sizes mismatch between src and dst (sizes " | ||||||
| f"are: src={src_size}, dst={dst_size})" | ||||||
| ) | ||||||
| err, = driver.cuMemcpyAsync(dst._ptr, self._ptr, src_size, stream.handle) | ||||||
| raise_if_driver_error(err) | ||||||
| cdef cydriver.CUstream s = s_stream._handle | ||||||
| with nogil: | ||||||
| HANDLE_RETURN(cydriver.cuMemcpyAsync( | ||||||
| <cydriver.CUdeviceptr>dst._ptr, | ||||||
| <cydriver.CUdeviceptr>self._ptr, | ||||||
| src_size, | ||||||
| s | ||||||
| )) | ||||||
| return dst | ||||||
|
|
||||||
| def copy_from(self, src: Buffer, *, stream: Stream | GraphBuilder): | ||||||
|
|
@@ -167,15 +173,78 @@ cdef class Buffer: | |||||
|
|
||||||
| """ | ||||||
| stream = Stream_accept(stream) | ||||||
| cdef Stream s_stream = <Stream>stream | ||||||
| cdef size_t dst_size = self._size | ||||||
| cdef size_t src_size = src._size | ||||||
|
|
||||||
| if src_size != dst_size: | ||||||
| raise ValueError( "buffer sizes mismatch between src and dst (sizes " | ||||||
| f"are: src={src_size}, dst={dst_size})" | ||||||
| ) | ||||||
| err, = driver.cuMemcpyAsync(self._ptr, src._ptr, dst_size, stream.handle) | ||||||
| raise_if_driver_error(err) | ||||||
| cdef cydriver.CUstream s = s_stream._handle | ||||||
| with nogil: | ||||||
| HANDLE_RETURN(cydriver.cuMemcpyAsync( | ||||||
| <cydriver.CUdeviceptr>self._ptr, | ||||||
| <cydriver.CUdeviceptr>src._ptr, | ||||||
| dst_size, | ||||||
| s | ||||||
| )) | ||||||
|
|
||||||
| def fill(self, value: int, width: int, *, stream: Stream | GraphBuilder): | ||||||
| """Fill this buffer with a value pattern asynchronously on the given stream. | ||||||
|
|
||||||
| Parameters | ||||||
| ---------- | ||||||
| value : int | ||||||
| Integer value to fill the buffer with | ||||||
| width : int | ||||||
| Width in bytes for each element (must be 1, 2, or 4) | ||||||
| stream : :obj:`~_stream.Stream` | :obj:`~_graph.GraphBuilder` | ||||||
| Keyword argument specifying the stream for the asynchronous fill | ||||||
|
|
||||||
| Raises | ||||||
| ------ | ||||||
| ValueError | ||||||
| If width is not 1, 2, or 4, if value is out of range for the width, | ||||||
| or if buffer size is not divisible by width | ||||||
|
|
||||||
| """ | ||||||
| cdef Stream s_stream = Stream_accept(stream) | ||||||
| cdef unsigned char c_value8 | ||||||
| cdef unsigned short c_value16 | ||||||
| cdef unsigned int c_value32 | ||||||
| cdef size_t N | ||||||
|
|
||||||
| # Validate width | ||||||
| if width not in (1, 2, 4): | ||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I'd put the validation code closer to the top of the function so we avoid any setup work in the error case where the user passes an unsupported size to the function.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. OK. I fiddled with it to simplify the logic. To be honest, I don't see a big improvement here, since most of the preceding statements just declare stack variables. |
||||||
| raise ValueError(f"width must be 1, 2, or 4, got {width}") | ||||||
|
|
||||||
| # Validate buffer size modulus. | ||||||
| cdef size_t buffer_size = self._size | ||||||
| if buffer_size % width != 0: | ||||||
| raise ValueError(f"buffer size ({buffer_size}) must be divisible by width ({width})") | ||||||
|
|
||||||
| # Map width (bytes) to bitwidth and validate value | ||||||
| cdef int bitwidth = width * 8 | ||||||
| _validate_value_against_bitwidth(bitwidth, value, is_signed=False) | ||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think it is somewhat unexpected behavior that this will throw if someone passes a negative integer. |
||||||
|
|
||||||
| # Validate value fits in width and perform fill | ||||||
| cdef cydriver.CUstream s = s_stream._handle | ||||||
| if width == 1: | ||||||
leofang marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
| c_value8 = <unsigned char>value | ||||||
| N = buffer_size | ||||||
| with nogil: | ||||||
| HANDLE_RETURN(cydriver.cuMemsetD8Async(<cydriver.CUdeviceptr>self._ptr, c_value8, N, s)) | ||||||
| elif width == 2: | ||||||
| c_value16 = <unsigned short>value | ||||||
| N = buffer_size // 2 | ||||||
| with nogil: | ||||||
| HANDLE_RETURN(cydriver.cuMemsetD16Async(<cydriver.CUdeviceptr>self._ptr, c_value16, N, s)) | ||||||
| else: # width == 4 | ||||||
| c_value32 = <unsigned int>value | ||||||
| N = buffer_size // 4 | ||||||
| with nogil: | ||||||
| HANDLE_RETURN(cydriver.cuMemsetD32Async(<cydriver.CUdeviceptr>self._ptr, c_value32, N, s)) | ||||||
|
|
||||||
| def __dlpack__( | ||||||
| self, | ||||||
|
|
@@ -340,3 +409,43 @@ cdef class MemoryResource: | |||||
| and document the behavior. | ||||||
| """ | ||||||
| ... | ||||||
|
|
||||||
|
|
||||||
| # Helper Functions | ||||||
| # ---------------- | ||||||
| cdef void _validate_value_against_bitwidth(int bitwidth, int64_t value, bint is_signed=False) except *: | ||||||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nitpick: we should probably make this an inline function for performance reasons
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. and avoid
Suggested change
|
||||||
| """Validate that a value fits within the representable range for a given bitwidth. | ||||||
|
|
||||||
| Parameters | ||||||
| ---------- | ||||||
| bitwidth : int | ||||||
| Number of bits (e.g., 8, 16, 32) | ||||||
| value : int64_t | ||||||
| Value to validate | ||||||
| is_signed : bool, optional | ||||||
| Whether the value is signed (default: False) | ||||||
|
|
||||||
| Raises | ||||||
| ------ | ||||||
| ValueError | ||||||
| If value is outside the representable range for the bitwidth | ||||||
| """ | ||||||
| cdef int max_bits = bitwidth | ||||||
| assert max_bits < 64, f"bitwidth ({max_bits}) must be less than 64" | ||||||
|
|
||||||
| cdef int64_t min_value | ||||||
| cdef uint64_t max_value_unsigned | ||||||
| cdef int64_t max_value | ||||||
|
|
||||||
| if is_signed: | ||||||
| min_value = -(<int64_t>1 << (max_bits - 1)) | ||||||
| max_value = (<int64_t>1 << (max_bits - 1)) - 1 | ||||||
| else: | ||||||
| min_value = 0 | ||||||
| max_value_unsigned = (<uint64_t>1 << max_bits) - 1 | ||||||
| max_value = <int64_t>max_value_unsigned | ||||||
|
Comment on lines
+440
to
+446
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nitpick: getting the min and max values here can be cimported from the built in libc module: https://github.com/cython/cython/blob/master/Cython/Includes/libc/limits.pxd i.e. |
||||||
|
|
||||||
| if not min_value <= value <= max_value: | ||||||
| raise ValueError( | ||||||
| f"value must be in range [{min_value}, {max_value}]" | ||||||
| ) | ||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -219,6 +219,88 @@ def test_buffer_copy_from(): | |
| buffer_copy_from(DummyPinnedMemoryResource(device), device, check=True) | ||
|
|
||
|
|
||
| def buffer_fill(dummy_mr: MemoryResource, device: Device, check=False): | ||
| stream = device.create_stream() | ||
|
|
||
| # Test width=1 (byte fill) | ||
| buffer1 = dummy_mr.allocate(size=1024) | ||
| buffer1.fill(0x42, width=1, stream=stream) | ||
| device.sync() | ||
|
|
||
| if check: | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why are we parametrizing these value checks in a test suite? The memory sizes don't strike me as so large that these operations would be slow.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The values are only checked when the memory allocation is pinned. This follows the existing pattern. |
||
| ptr = ctypes.cast(buffer1.handle, ctypes.POINTER(ctypes.c_byte)) | ||
| for i in range(10): | ||
| assert ptr[i] == 0x42 | ||
|
|
||
| # Test error: invalid width | ||
| for bad_width in [w for w in range(-10, 10) if w not in (1, 2, 4)]: | ||
| with pytest.raises(ValueError, match="width must be 1, 2, or 4"): | ||
| buffer1.fill(0x42, width=bad_width, stream=stream) | ||
|
|
||
| # Test error: value out of range for width=1 | ||
| for bad_value in [-42, -1, 256]: | ||
| with pytest.raises(ValueError, match="value must be in range \\[0, 255\\]"): | ||
| buffer1.fill(bad_value, width=1, stream=stream) | ||
|
|
||
| # Test error: buffer size not divisible by width | ||
| # Allocate each listed size in turn so every odd size is actually exercised | ||
| # (previously the loop variable was ignored and only size=1025 was tested). | ||
| for bad_size in [1025, 1027, 1029, 1031]: # Not divisible by 2 | ||
| buffer_err = dummy_mr.allocate(size=bad_size) | ||
| with pytest.raises(ValueError, match="must be divisible"): | ||
| buffer_err.fill(0x1234, width=2, stream=stream) | ||
| buffer_err.close() | ||
|
|
||
| buffer1.close() | ||
|
|
||
| # Test width=2 (16-bit fill) | ||
| buffer2 = dummy_mr.allocate(size=1024) # Divisible by 2 | ||
| buffer2.fill(0x1234, width=2, stream=stream) | ||
| device.sync() | ||
|
|
||
| if check: | ||
| ptr = ctypes.cast(buffer2.handle, ctypes.POINTER(ctypes.c_uint16)) | ||
| for i in range(5): | ||
| assert ptr[i] == 0x1234 | ||
|
|
||
| # Test error: value out of range for width=2 | ||
| for bad_value in [-42, -1, 65536, 65537, 100000]: | ||
| with pytest.raises(ValueError, match="value must be in range \\[0, 65535\\]"): | ||
| buffer2.fill(bad_value, width=2, stream=stream) | ||
|
|
||
| buffer2.close() | ||
|
|
||
| # Test width=4 (32-bit fill) | ||
| buffer4 = dummy_mr.allocate(size=1024) # Divisible by 4 | ||
| buffer4.fill(0xDEADBEEF, width=4, stream=stream) | ||
| device.sync() | ||
|
|
||
| if check: | ||
| ptr = ctypes.cast(buffer4.handle, ctypes.POINTER(ctypes.c_uint32)) | ||
| for i in range(5): | ||
| assert ptr[i] == 0xDEADBEEF | ||
|
|
||
| # Test error: value out of range for width=4 | ||
| for bad_value in [-42, -1, 4294967296, 4294967297, 5000000000]: | ||
| with pytest.raises(ValueError, match="value must be in range \\[0, 4294967295\\]"): | ||
| buffer4.fill(bad_value, width=4, stream=stream) | ||
|
|
||
| # Test error: buffer size not divisible by width | ||
| for bad_size in [1025, 1026, 1027, 1029, 1030, 1031]: # Not divisible by 4 | ||
| buffer_err2 = dummy_mr.allocate(size=bad_size) | ||
| with pytest.raises(ValueError, match="must be divisible"): | ||
| buffer_err2.fill(0xDEADBEEF, width=4, stream=stream) | ||
| buffer_err2.close() | ||
|
|
||
| buffer4.close() | ||
|
|
||
|
|
||
| def test_buffer_fill(): | ||
| device = Device() | ||
| device.set_current() | ||
| buffer_fill(DummyDeviceMemoryResource(device), device) | ||
| buffer_fill(DummyUnifiedMemoryResource(device), device) | ||
| buffer_fill(DummyPinnedMemoryResource(device), device, check=True) | ||
|
|
||
|
|
||
| def buffer_close(dummy_mr: MemoryResource): | ||
| buffer = dummy_mr.allocate(size=1024) | ||
| buffer.close() | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I know the underlying driver APIs expect an integer, but I think for
`Buffer.fill()` it would be good to support a bytes-like input as well, maybe via the buffer protocol?
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To confirm my understanding, I believe the suggestion here is to accept arguments of type
`bytes` or that provide bytes via the buffer protocol (where the number of bytes equals the buffer size) and fill our `Buffer` from that. Would that be better to implement as `Buffer.copy_from` rather than `Buffer.fill`? Or possibly a new function `Buffer.copy_from_bytes` or `Buffer.copy_from_host`.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The internal logic will be different from
`Buffer.fill`, since it requires copying the bytes to a staging area, unlike `cuMemset*`.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I was referring to allowing a user to pass a memoryview object of 1, 2, or 4 bytes in length to be used as the fill value. I.e.
Buffer.fill(b'abcd').