From 83313909e3fa5b95ee8bfb5c0d8cb47080234696 Mon Sep 17 00:00:00 2001 From: Casper da Costa-Luis Date: Tue, 16 May 2023 19:03:23 +0530 Subject: [PATCH] swigcuvec: add retarray convenience function --- cuvec/include/cuvec.i | 3 +-- cuvec/swigcuvec.py | 17 ++++++++++++++--- 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/cuvec/include/cuvec.i b/cuvec/include/cuvec.i index 701c848..b60efce 100644 --- a/cuvec/include/cuvec.i +++ b/cuvec/include/cuvec.i @@ -3,11 +3,10 @@ * for external use via `%include "cuvec.i"`. */ %include "std_vector.i" - %{ #include "cuvec.cuh" // SwigCuVec %} - +/// expose definitions template struct SwigCuVec { CuVec vec; std::vector shape; diff --git a/cuvec/swigcuvec.py b/cuvec/swigcuvec.py index 7dd5f16..9f91ba2 100644 --- a/cuvec/swigcuvec.py +++ b/cuvec/swigcuvec.py @@ -181,9 +181,9 @@ def asarray(arr, dtype=None, order=None, ownership: str = 'warning') -> CuVec: >>> res = asarray(some_swig_api_func(..., output=getattr(out, 'cuvec', None))) `res.cuvec` and `out.cuvec` are now the same yet garbage collected separately (dangling ptr). - Instead, use: - >>> res = some_swig_api_func(..., output=getattr(out, 'cuvec', None)) - >>> res = out if hasattr(out, 'cuvec') else asarray(res) + Instead, use the `retarray` helper: + >>> raw = some_swig_api_func(..., output=getattr(out, 'cuvec', None)) + >>> res = retarray(raw, out) NB: `asarray()` is safe if the raw cuvec was created in C++/SWIG, e.g.: >>> res = asarray(some_swig_api_func(..., output=None), ownership='debug') """ @@ -198,3 +198,14 @@ def asarray(arr, dtype=None, order=None, ownership: str = 'warning') -> CuVec: if dtype is None or res.dtype == np.dtype(dtype): return CuVec(np.asanyarray(res, order=order)) return CuVec(np.asanyarray(arr, dtype=dtype, order=order)) + + +def retarray(raw, out: Optional[CuVec] = None): + """ + Returns `out if hasattr(out, 'cuvec') else asarray(raw, ownership='debug')`. + See `asarray` for explanation. + Args: + raw: a raw CuVec (returned by C++/SWIG function). + out: preallocated output array. + """ + return out if hasattr(out, 'cuvec') else asarray(raw, ownership='debug')