Skip to content

Commit

Permalink
python - add vec.array_write
Browse files Browse the repository at this point in the history
  • Loading branch information
jeremylt committed Dec 15, 2021
1 parent 6bb2089 commit 3e8c176
Show file tree
Hide file tree
Showing 2 changed files with 94 additions and 0 deletions.
74 changes: 74 additions & 0 deletions python/ceed_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,54 @@ def get_array_read(self, memtype=MEM_HOST):
# return read only Numba array
return nbcuda.from_cuda_array_interface(desc)

# Get Vector's data array in write-only mode
def get_array_write(self, memtype=MEM_HOST):
"""Get write-only access to a Vector via the specified memory type.
All old values should be considered invalid.
Args:
**memtype: memory type of the array being passed, default CEED_MEM_HOST
Returns:
*array: Numpy or Numba array"""

# Retrieve the length of the array
length_pointer = ffi.new("CeedInt *")
err_code = lib.CeedVectorGetLength(self._pointer[0], length_pointer)
self._ceed._check_error(err_code)

# Setup the pointer's pointer
array_pointer = ffi.new("CeedScalar **")

# libCEED call
err_code = lib.CeedVectorGetArrayWrite(
self._pointer[0], memtype, array_pointer)
self._ceed._check_error(err_code)

# Return array created from buffer
if memtype == MEM_HOST:
# Create buffer object from returned pointer
buff = ffi.buffer(
array_pointer[0],
ffi.sizeof("CeedScalar") *
length_pointer[0])
# return read only Numpy array
ret = np.frombuffer(buff, dtype=scalar_types[lib.CEED_SCALAR_TYPE])
ret.flags['WRITEABLE'] = False
return ret
else:
# CUDA array interface
# https://numba.pydata.org/numba-doc/latest/cuda/cuda_array_interface.html
import numba.cuda as nbcuda
desc = {
'shape': (length_pointer[0]),
'typestr': '>f8',
'data': (int(ffi.cast("intptr_t", array_pointer[0])), False),
'version': 2
}
# return read only Numba array
return nbcuda.from_cuda_array_interface(desc)

# Restore the Vector's data array
def restore_array(self):
"""Restore an array obtained using get_array()."""
Expand Down Expand Up @@ -264,6 +312,32 @@ def array_read(self, *shape, memtype=MEM_HOST):
yield x
self.restore_array_read()

@contextlib.contextmanager
def array_write(self, *shape, memtype=MEM_HOST):
"""Context manager for write-only array access.
All old values should be considered invalid.
Args:
shape (tuple): shape of returned numpy.array
**memtype: memory type of the array being passed, default CEED_MEM_HOST
Returns:
np.array: write-only view of vector
Examples:
Viewing contents of a reshaped libceed.Vector view:
>>> vec = ceed.Vector(6)
>>> vec.set_value(1.3)
>>> with vec.array_read(2, 3) as x:
>>> print(x)
"""
x = self.get_array_write(memtype=memtype)
if shape:
x = x.reshape(shape)
yield x
self.restore_array()

# Get the length of a Vector
def get_length(self):
"""Get the length of a Vector.
Expand Down
20 changes: 20 additions & 0 deletions python/tests/test-1-vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,26 @@ def test_123(ceed_resource, capsys):
with x.array() as b:
assert np.allclose(-.5 * a, b)

# -------------------------------------------------------------------------------
# Test getArrayWrite to modify array
# -------------------------------------------------------------------------------


def test_124(ceed_resource):
ceed = libceed.Ceed(ceed_resource)

n = 10

x = ceed.Vector(n)

with x.array_write() as a:
for i in range(x.length):
a[i] = 3 * i

with x.array_read() as a:
for i in range(len(a)):
assert a[i] == 3 * i

# -------------------------------------------------------------------------------
# Test modification of reshaped array
# -------------------------------------------------------------------------------
Expand Down

0 comments on commit 3e8c176

Please sign in to comment.