Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 117 additions & 8 deletions cuda_core/cuda/core/experimental/_memory/_buffer.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,14 @@

from __future__ import annotations

from libc.stdint cimport uintptr_t
from libc.stdint cimport uintptr_t, int64_t, uint64_t

from cuda.bindings cimport cydriver
from cuda.core.experimental._memory._device_memory_resource cimport DeviceMemoryResource
from cuda.core.experimental._memory._ipc cimport IPCBufferDescriptor, IPCDataForBuffer
from cuda.core.experimental._memory cimport _ipc
from cuda.core.experimental._stream cimport Stream_accept, Stream
from cuda.core.experimental._utils.cuda_utils cimport (
_check_driver_error as raise_if_driver_error,
)
from cuda.core.experimental._utils.cuda_utils cimport HANDLE_RETURN

import abc
from typing import TypeVar, Union
Expand Down Expand Up @@ -137,6 +136,7 @@ cdef class Buffer:

"""
stream = Stream_accept(stream)
cdef Stream s_stream = <Stream>stream
cdef size_t src_size = self._size

if dst is None:
Expand All @@ -150,8 +150,14 @@ cdef class Buffer:
raise ValueError( "buffer sizes mismatch between src and dst (sizes "
f"are: src={src_size}, dst={dst_size})"
)
err, = driver.cuMemcpyAsync(dst._ptr, self._ptr, src_size, stream.handle)
raise_if_driver_error(err)
cdef cydriver.CUstream s = s_stream._handle
with nogil:
HANDLE_RETURN(cydriver.cuMemcpyAsync(
<cydriver.CUdeviceptr>dst._ptr,
<cydriver.CUdeviceptr>self._ptr,
src_size,
s
))
return dst

def copy_from(self, src: Buffer, *, stream: Stream | GraphBuilder):
Expand All @@ -167,15 +173,78 @@ cdef class Buffer:

"""
stream = Stream_accept(stream)
cdef Stream s_stream = <Stream>stream
cdef size_t dst_size = self._size
cdef size_t src_size = src._size

if src_size != dst_size:
raise ValueError( "buffer sizes mismatch between src and dst (sizes "
f"are: src={src_size}, dst={dst_size})"
)
err, = driver.cuMemcpyAsync(self._ptr, src._ptr, dst_size, stream.handle)
raise_if_driver_error(err)
cdef cydriver.CUstream s = s_stream._handle
with nogil:
HANDLE_RETURN(cydriver.cuMemcpyAsync(
<cydriver.CUdeviceptr>self._ptr,
<cydriver.CUdeviceptr>src._ptr,
dst_size,
s
))

def fill(self, value: int, width: int, *, stream: Stream | GraphBuilder):
"""Fill this buffer with a value pattern asynchronously on the given stream.

Parameters
----------
value : int
Integer value to fill the buffer with
Comment on lines +198 to +199
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know the underlying driver APIs expect an integer, but I think for Buffer.fill() it would be good to support a bytes-like input as well, maybe via the buffer protocol?

Copy link
Contributor Author

@Andy-Jost Andy-Jost Dec 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To confirm my understanding, I believe the suggestion here is to accept arguments of type bytes or that provide bytes via the buffer protocol (where the number of bytes equals the buffer size) and fill our Buffer from that. Would that be better to implement as Buffer.copy_from rather than Buffer.fill? Or possibly a new function Buffer.copy_from_bytes or Buffer.copy_from_host.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The internal logic will be different from Buffer.fill, since it requires copying the bytes to a staging area, unlike cuMemset*.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was referring to allowing a user to pass a memoryview object of 1, 2, or 4 bytes in length to be used as the fill value. I.e. Buffer.fill(b'abcd').

width : int
Width in bytes for each element (must be 1, 2, or 4)
stream : :obj:`~_stream.Stream` | :obj:`~_graph.GraphBuilder`
Keyword argument specifying the stream for the asynchronous fill

Raises
------
ValueError
If width is not 1, 2, or 4, if value is out of range for the width,
or if buffer size is not divisible by width

"""
cdef Stream s_stream = Stream_accept(stream)
cdef unsigned char c_value8
cdef unsigned short c_value16
cdef unsigned int c_value32
cdef size_t N

# Validate width
if width not in (1, 2, 4):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd put the validation code closer to the top of the function so we avoid any setup work in the error case where the user passes an unsupported size to the function.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK. I fiddled with it to simplify the logic. To be honest, I don't see a big improvement here, since most of the preceding statements just declare stack variables.

raise ValueError(f"width must be 1, 2, or 4, got {width}")

# Validate buffer size modulus.
cdef size_t buffer_size = self._size
if buffer_size % width != 0:
raise ValueError(f"buffer size ({buffer_size}) must be divisible by width ({width})")

# Map width (bytes) to bitwidth and validate value
cdef int bitwidth = width * 8
_validate_value_against_bitwidth(bitwidth, value, is_signed=False)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it is somewhat unexpected that this raises if someone passes a negative integer.


# Validate value fits in width and perform fill
cdef cydriver.CUstream s = s_stream._handle
if width == 1:
c_value8 = <unsigned char>value
N = buffer_size
with nogil:
HANDLE_RETURN(cydriver.cuMemsetD8Async(<cydriver.CUdeviceptr>self._ptr, c_value8, N, s))
elif width == 2:
c_value16 = <unsigned short>value
N = buffer_size // 2
with nogil:
HANDLE_RETURN(cydriver.cuMemsetD16Async(<cydriver.CUdeviceptr>self._ptr, c_value16, N, s))
else: # width == 4
c_value32 = <unsigned int>value
N = buffer_size // 4
with nogil:
HANDLE_RETURN(cydriver.cuMemsetD32Async(<cydriver.CUdeviceptr>self._ptr, c_value32, N, s))

def __dlpack__(
self,
Expand Down Expand Up @@ -340,3 +409,43 @@ cdef class MemoryResource:
and document the behavior.
"""
...


# Helper Functions
# ----------------
cdef void _validate_value_against_bitwidth(int bitwidth, int64_t value, bint is_signed=False) except *:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nitpick: we should probably make this an inline function for performance reasons

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and avoid except*, so it should return an int instead, something like

Suggested change
cdef void _validate_value_against_bitwidth(int bitwidth, int64_t value, bint is_signed=False) except *:
cdef int _validate_value_against_bitwidth(int bitwidth, int64_t value, bint is_signed=False) except?-1:

"""Validate that a value fits within the representable range for a given bitwidth.

Parameters
----------
bitwidth : int
Number of bits (e.g., 8, 16, 32)
value : int64_t
Value to validate
is_signed : bool, optional
Whether the value is signed (default: False)

Raises
------
ValueError
If value is outside the representable range for the bitwidth
"""
cdef int max_bits = bitwidth
assert max_bits < 64, f"bitwidth ({max_bits}) must be less than 64"

cdef int64_t min_value
cdef uint64_t max_value_unsigned
cdef int64_t max_value

if is_signed:
min_value = -(<int64_t>1 << (max_bits - 1))
max_value = (<int64_t>1 << (max_bits - 1)) - 1
else:
min_value = 0
max_value_unsigned = (<uint64_t>1 << max_bits) - 1
max_value = <int64_t>max_value_unsigned
Comment on lines +440 to +446
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nitpick: the min and max values here can be cimported from the built-in libc module: https://github.com/cython/cython/blob/master/Cython/Includes/libc/limits.pxd

i.e. from libc.limits cimport INT_MAX


if not min_value <= value <= max_value:
raise ValueError(
f"value must be in range [{min_value}, {max_value}]"
)
43 changes: 29 additions & 14 deletions cuda_core/tests/test_graph_mem.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,14 @@ def free(self, buffers):


@pytest.mark.parametrize("mode", ["no_graph", "global", "thread_local", "relaxed"])
def test_graph_alloc(mempool_device, mode):
"""Test basic graph capture with memory allocated and deallocated by GraphMemoryResource."""
@pytest.mark.parametrize("action", ["incr", "fill"])
def test_graph_alloc(mempool_device, mode, action):
"""Test basic graph capture with memory allocated and deallocated by
GraphMemoryResource.

This test verifies graph capture for Buffer operations including copy_from,
copy_to, fill, and kernel launch operations.
"""
NBYTES = 64
device = mempool_device
stream = device.create_stream()
Expand All @@ -93,14 +99,22 @@ def test_graph_alloc(mempool_device, mode):
config = LaunchConfig(grid=1, block=1)
launch(stream, config, set_zero, out, NBYTES)

# Increments out by 3
def apply_kernels(mr, stream, out):
buffer = mr.allocate(NBYTES, stream=stream)
buffer.copy_from(out, stream=stream)
for kernel in [add_one, add_one, add_one]:
launch(stream, config, kernel, buffer, NBYTES)
out.copy_from(buffer, stream=stream)
buffer.close()
if action == "incr":
# Increments out by 3
def apply_kernels(mr, stream, out):
buffer = mr.allocate(NBYTES, stream=stream)
buffer.copy_from(out, stream=stream)
for kernel in [add_one, add_one, add_one]:
launch(stream, config, kernel, buffer, NBYTES)
out.copy_from(buffer, stream=stream)
buffer.close()
elif action == "fill":
# Fills out with 3
def apply_kernels(mr, stream, out):
buffer = mr.allocate(NBYTES, stream=stream)
buffer.fill(3, width=1, stream=stream)
out.copy_from(buffer, stream=stream)
buffer.close()

# Apply kernels, with or without graph capture.
if mode == "no_graph":
Expand All @@ -121,10 +135,11 @@ def apply_kernels(mr, stream, out):
assert compare_buffer_to_constant(out, 3)

# Second launch.
graph.upload(stream)
graph.launch(stream)
stream.sync()
assert compare_buffer_to_constant(out, 6)
if action == "incr":
graph.upload(stream)
graph.launch(stream)
stream.sync()
assert compare_buffer_to_constant(out, 6)


@pytest.mark.skipif(IS_WINDOWS or IS_WSL, reason="auto_free_on_launch not supported on Windows")
Expand Down
82 changes: 82 additions & 0 deletions cuda_core/tests/test_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,88 @@ def test_buffer_copy_from():
buffer_copy_from(DummyPinnedMemoryResource(device), device, check=True)


def buffer_fill(dummy_mr: MemoryResource, device: Device, check=False):
stream = device.create_stream()

# Test width=1 (byte fill)
buffer1 = dummy_mr.allocate(size=1024)
buffer1.fill(0x42, width=1, stream=stream)
device.sync()

if check:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are we parametrizing these value checks in a test suite? The memory sizes don't strike me as so large that these operations would be slow.

Copy link
Contributor Author

@Andy-Jost Andy-Jost Dec 5, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The values are only checked when the memory allocation is pinned. This follows the existing pattern.

ptr = ctypes.cast(buffer1.handle, ctypes.POINTER(ctypes.c_byte))
for i in range(10):
assert ptr[i] == 0x42

# Test error: invalid width
for bad_width in [w for w in range(-10, 10) if w not in (1, 2, 4)]:
with pytest.raises(ValueError, match="width must be 1, 2, or 4"):
buffer1.fill(0x42, width=bad_width, stream=stream)

# Test error: value out of range for width=1
for bad_value in [-42, -1, 256]:
with pytest.raises(ValueError, match="value must be in range \\[0, 255\\]"):
buffer1.fill(bad_value, width=1, stream=stream)

# Test error: buffer size not divisible by width
# Allocate each listed odd size in turn so that every entry is actually
# exercised (a fixed allocation size would repeat the same case).
for bad_size in [1025, 1027, 1029, 1031]:  # Not divisible by 2
    buffer_err = dummy_mr.allocate(size=bad_size)
    with pytest.raises(ValueError, match="must be divisible"):
        buffer_err.fill(0x1234, width=2, stream=stream)
    buffer_err.close()

buffer1.close()

# Test width=2 (16-bit fill)
buffer2 = dummy_mr.allocate(size=1024) # Divisible by 2
buffer2.fill(0x1234, width=2, stream=stream)
device.sync()

if check:
ptr = ctypes.cast(buffer2.handle, ctypes.POINTER(ctypes.c_uint16))
for i in range(5):
assert ptr[i] == 0x1234

# Test error: value out of range for width=2
for bad_value in [-42, -1, 65536, 65537, 100000]:
with pytest.raises(ValueError, match="value must be in range \\[0, 65535\\]"):
buffer2.fill(bad_value, width=2, stream=stream)

buffer2.close()

# Test width=4 (32-bit fill)
buffer4 = dummy_mr.allocate(size=1024) # Divisible by 4
buffer4.fill(0xDEADBEEF, width=4, stream=stream)
device.sync()

if check:
ptr = ctypes.cast(buffer4.handle, ctypes.POINTER(ctypes.c_uint32))
for i in range(5):
assert ptr[i] == 0xDEADBEEF

# Test error: value out of range for width=4
for bad_value in [-42, -1, 4294967296, 4294967297, 5000000000]:
with pytest.raises(ValueError, match="value must be in range \\[0, 4294967295\\]"):
buffer4.fill(bad_value, width=4, stream=stream)

# Test error: buffer size not divisible by width
for bad_size in [1025, 1026, 1027, 1029, 1030, 1031]: # Not divisible by 4
buffer_err2 = dummy_mr.allocate(size=bad_size)
with pytest.raises(ValueError, match="must be divisible"):
buffer_err2.fill(0xDEADBEEF, width=4, stream=stream)
buffer_err2.close()

buffer4.close()


def test_buffer_fill():
    """Run the ``buffer_fill`` checks against each dummy memory resource.

    The device, unified, and pinned resources are exercised in that order;
    only the pinned resource is run with ``check=True``.
    """
    device = Device()
    device.set_current()

    # (memory resource class, whether to pass check=True)
    cases = (
        (DummyDeviceMemoryResource, False),
        (DummyUnifiedMemoryResource, False),
        (DummyPinnedMemoryResource, True),
    )
    for make_resource, verify in cases:
        buffer_fill(make_resource(device), device, check=verify)


def buffer_close(dummy_mr: MemoryResource):
buffer = dummy_mr.allocate(size=1024)
buffer.close()
Expand Down
Loading