From 8d9a1dfe77cba7220d9313c32070f190b7eb30d8 Mon Sep 17 00:00:00 2001 From: Yuchen Jin Date: Thu, 6 May 2021 22:51:49 -0700 Subject: [PATCH] [DLPACK] Support the new python array api with DLPack (#7993) * [DLPACK] Support the new python array api with dlpack * Fix lint --- python/tvm/runtime/ndarray.py | 35 +++++++++++++++++++++++++---- tests/python/contrib/test_dlpack.py | 3 ++- 2 files changed, 33 insertions(+), 5 deletions(-) diff --git a/python/tvm/runtime/ndarray.py b/python/tvm/runtime/ndarray.py index 68735e0e115d..befe077a4dcd 100644 --- a/python/tvm/runtime/ndarray.py +++ b/python/tvm/runtime/ndarray.py @@ -62,6 +62,27 @@ def device(self): """Device of this array""" return self.handle.contents.device + def __dlpack__(self, stream=None): # pylint: disable=unused-argument + """Export the array for consumption by from_dlpack() as a DLPack capsule. + + Parameters + ---------- + stream : int, optional + A Python integer representing a pointer to a stream. + Stream is provided by the consumer to the producer to instruct the producer + to ensure that operations can safely be performed on the array. + + Returns + ------- + capsule : PyCapsule + A DLPack capsule for the array, containing a DLPackManagedTensor. + """ + return self.to_dlpack() + + def __dlpack_device__(self): + """Return a tuple of device_type, device_id in DLPack convention""" + return (self.handle.contents.device.device_type, self.handle.contents.device.device_id) + def __hash__(self): return ctypes.cast(self.handle, ctypes.c_void_p).value @@ -301,22 +322,28 @@ def empty(shape, dtype="float32", device=device(1, 0), mem_scope=None): def from_dlpack(dltensor): - """Produce an array from a DLPack tensor without memory copy. + """Produces an array from an object with __dlpack__ method or a DLPack tensor w/o memory copy. Retreives the underlying DLPack tensor's pointer to create an array from the data. Removes the original DLPack tensor's destructor as now the array is responsible for destruction. Parameters ---------- - dltensor : DLPack tensor - Input DLManagedTensor, can only be consumed once. + dltensor : object with __dlpack__ attribute or a DLPack capsule Returns ------- arr: tvm.nd.NDArray The array view of the tensor data. """ - return _from_dlpack(dltensor) + t = type(dltensor) + if t.__module__ == "builtins" and t.__name__ == "PyCapsule": + return _from_dlpack(dltensor) + + if hasattr(dltensor, "__dlpack__"): + dlpack_caps = dltensor.__dlpack__() + return _from_dlpack(dlpack_caps) + raise AttributeError("Required attribute __dlpack__ not found") def cpu(dev_id=0): diff --git a/tests/python/contrib/test_dlpack.py b/tests/python/contrib/test_dlpack.py index 8bf9069b78cf..ca6592b3e61e 100644 --- a/tests/python/contrib/test_dlpack.py +++ b/tests/python/contrib/test_dlpack.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. import tvm +import tvm.testing from tvm import te import numpy as np from tvm.contrib.dlpack import to_pytorch_func @@ -32,7 +33,7 @@ def test(): x = torch.rand(56, 56) tvm_x = tvm.nd.from_dlpack(torch.utils.dlpack.to_dlpack(x)) np.testing.assert_equal(x.numpy(), tvm_x.asnumpy()) - y = tvm.nd.from_dlpack(tvm_x.to_dlpack()) + y = tvm.nd.from_dlpack(tvm_x) np.testing.assert_equal(y.asnumpy(), tvm_x.asnumpy()) np.testing.assert_equal( torch.utils.dlpack.from_dlpack(y.to_dlpack()).numpy(), tvm_x.asnumpy()