Skip to content

Commit

Permalink
[DLPACK] Support the new python array api with DLPack (#7993)
Browse files Browse the repository at this point in the history
* [DLPACK] Support the new python array api with dlpack

* Fix lint
  • Loading branch information
YuchenJin committed May 7, 2021
1 parent d789875 commit 8d9a1df
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 5 deletions.
35 changes: 31 additions & 4 deletions python/tvm/runtime/ndarray.py
Expand Up @@ -62,6 +62,27 @@ def device(self):
"""Device of this array"""
return self.handle.contents.device

def __dlpack__(self, stream=None): # pylint: disable=unused-argument
"""Export the array for consumption by from_dlpack() as a DLPack capsule.
Parameters
----------
stream : int, optional
A Python integer representing a pointer to a stream.
Stream is provided by the consumer to the producer to instruct the producer
to ensure that operations can safely be performed on the array.
Returns
-------
capsule : PyCapsule
A DLPack capsule for the array, containing a DLPackManagedTensor.
"""
return self.to_dlpack()

def __dlpack_device__(self):
"""Return a tuple of device_type, device_id in DLPack convention"""
return (self.handle.contents.device.device_type, self.handle.contents.device.device_id)

def __hash__(self):
return ctypes.cast(self.handle, ctypes.c_void_p).value

Expand Down Expand Up @@ -301,22 +322,28 @@ def empty(shape, dtype="float32", device=device(1, 0), mem_scope=None):


def from_dlpack(dltensor):
"""Produce an array from a DLPack tensor without memory copy.
"""Produces an array from an object with __dlpack__ method or a DLPack tensor w/o memory copy.
Retreives the underlying DLPack tensor's pointer to create an array from the
data. Removes the original DLPack tensor's destructor as now the array is
responsible for destruction.
Parameters
----------
dltensor : DLPack tensor
Input DLManagedTensor, can only be consumed once.
dltensor : object with __dlpack__ attribute or a DLPack capsule
Returns
-------
arr: tvm.nd.NDArray
The array view of the tensor data.
"""
return _from_dlpack(dltensor)
t = type(dltensor)
if t.__module__ == "builtins" and t.__name__ == "PyCapsule":
return _from_dlpack(dltensor)

if hasattr(dltensor, "__dlpack__"):
dlpack_caps = dltensor.__dlpack__()
return _from_dlpack(dlpack_caps)
raise AttributeError("Required attribute __dlpack__ not found")


def cpu(dev_id=0):
Expand Down
3 changes: 2 additions & 1 deletion tests/python/contrib/test_dlpack.py
Expand Up @@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
import tvm
import tvm.testing
from tvm import te
import numpy as np
from tvm.contrib.dlpack import to_pytorch_func
Expand All @@ -32,7 +33,7 @@ def test():
x = torch.rand(56, 56)
tvm_x = tvm.nd.from_dlpack(torch.utils.dlpack.to_dlpack(x))
np.testing.assert_equal(x.numpy(), tvm_x.asnumpy())
y = tvm.nd.from_dlpack(tvm_x.to_dlpack())
y = tvm.nd.from_dlpack(tvm_x)
np.testing.assert_equal(y.asnumpy(), tvm_x.asnumpy())
np.testing.assert_equal(
torch.utils.dlpack.from_dlpack(y.to_dlpack()).numpy(), tvm_x.asnumpy()
Expand Down

0 comments on commit 8d9a1df

Please sign in to comment.