Skip to content

Commit

Permalink
swigcuvec: add retarray convenience function
Browse files Browse the repository at this point in the history
  • Loading branch information
casperdcl committed May 16, 2023
1 parent ba7b6ae commit 8331390
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 5 deletions.
3 changes: 1 addition & 2 deletions cuvec/include/cuvec.i
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,10 @@
* for external use via `%include "cuvec.i"`.
*/
%include "std_vector.i"

%{
#include "cuvec.cuh" // SwigCuVec<T>
%}

/// expose definitions
template <class T> struct SwigCuVec {
CuVec<T> vec;
std::vector<size_t> shape;
Expand Down
17 changes: 14 additions & 3 deletions cuvec/swigcuvec.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,9 +181,9 @@ def asarray(arr, dtype=None, order=None, ownership: str = 'warning') -> CuVec:
>>> res = asarray(some_swig_api_func(..., output=getattr(out, 'cuvec', None)))
`res.cuvec` and `out.cuvec` are now the same
yet garbage collected separately (dangling ptr).
Instead, use:
>>> res = some_swig_api_func(..., output=getattr(out, 'cuvec', None))
>>> res = out if hasattr(out, 'cuvec') else asarray(res)
Instead, use the `retarray` helper:
>>> raw = some_swig_api_func(..., output=getattr(out, 'cuvec', None))
>>> res = retarray(raw, out)
NB: `asarray()` is safe if the raw cuvec was created in C++/SWIG, e.g.:
>>> res = asarray(some_swig_api_func(..., output=None), ownership='debug')
"""
Expand All @@ -198,3 +198,14 @@ def asarray(arr, dtype=None, order=None, ownership: str = 'warning') -> CuVec:
if dtype is None or res.dtype == np.dtype(dtype):
return CuVec(np.asanyarray(res, order=order))
return CuVec(np.asanyarray(arr, dtype=dtype, order=order))


def retarray(raw, out: Optional[CuVec] = None):
"""
Returns `out if hasattr(out, 'cuvec') else asarray(raw, ownership='debug')`.
See `asarray` for explanation.
Args:
raw: a raw CuVec (returned by C++/SWIG function).
out: preallocated output array.
"""
return out if hasattr(out, 'cuvec') else asarray(raw, ownership='debug')

Check warning on line 211 in cuvec/swigcuvec.py

View check run for this annotation

Codecov / codecov/patch

cuvec/swigcuvec.py#L211

Added line #L211 was not covered by tests

0 comments on commit 8331390

Please sign in to comment.