1 change: 1 addition & 0 deletions docs_input/CMakeLists.txt
@@ -17,6 +17,7 @@ set(DOXYGEN_EXCLUDE_DIR2 ${PROJECT_BINARY_DIR}/*)
set(DOXYGEN_EXCLUDE_DIR3 ${PROJECT_SOURCE_DIR}/*build*)
set(DOXYGEN_EXCLUDE_DIR4 ${PROJECT_SOURCE_DIR}/examples/cmake_sample_project/build*)
set(DOXYGEN_EXCLUDE_DIR5 ${PROJECT_SOURCE_DIR}/libmathdx/*)
set(DOXYGEN_EXCLUDE_DIR6 ${PROJECT_SOURCE_DIR}/examples/*)
set(DOXYXML_DIR ${PROJECT_BINARY_DIR}/doxygen/xml/)
set(DOXYFILE_IN ${CMAKE_CURRENT_SOURCE_DIR}/Doxyfile.in)
set(DOXYFILE_OUT ${CMAKE_CURRENT_BINARY_DIR}/Doxyfile)
1 change: 1 addition & 0 deletions docs_input/Doxyfile.in
@@ -887,6 +887,7 @@ EXCLUDE_PATTERNS += "@DOXYGEN_EXCLUDE_DIR2@"
EXCLUDE_PATTERNS += "@DOXYGEN_EXCLUDE_DIR3@"
EXCLUDE_PATTERNS += "@DOXYGEN_EXCLUDE_DIR4@"
EXCLUDE_PATTERNS += "@DOXYGEN_EXCLUDE_DIR5@"
EXCLUDE_PATTERNS += "@DOXYGEN_EXCLUDE_DIR6@"

# The EXCLUDE_SYMBOLS tag can be used to specify one or more symbol names
# (namespaces, classes, functions, etc.) that should be excluded from the
7 changes: 7 additions & 0 deletions docs_input/api/creation/tensors/make.rst
@@ -38,6 +38,13 @@ Custom Allocator Support
.. doxygenfunction:: make_tensor( TensorType &tensor, const index_t (&shape)[TensorType::Rank()], Allocator&& alloc)
.. doxygenfunction:: make_tensor( TensorType &tensor, ShapeType &&shape, Allocator&& alloc)

DLPack Support
~~~~~~~~~~~~~~
.. versionadded:: 1.1.0
.. doxygenfunction:: make_tensor( TensorType &tensor, DLManagedTensorVersioned *dlp_tensor)
.. versionadded:: 1.1.0
.. doxygenfunction:: make_tensor( TensorType &tensor, DLManagedTensor *dlp_tensor)

Return by Pointer
~~~~~~~~~~~~~~~~~
.. doxygenfunction:: make_tensor_p( const index_t (&shape)[RANK], matxMemorySpace_t space = MATX_MANAGED_MEMORY, cudaStream_t stream = 0)
58 changes: 56 additions & 2 deletions docs_input/external.rst
@@ -81,8 +81,62 @@ Care must be taken when passing either operators or pointers to existing code to
* The *kind* of the pointer must be known to the external code. For example, if the tensor was created in device memory, the external
code must access it only where device memory is accessible.

If the external code supports the *dlpack* standard, the tensor's `ToDLPack()` method can be used instead to get a `DLManagedTensor` object.
This method is much safer since all shape and ownership can be transferred.
DLPack Interoperability
=======================

If the external code supports the `DLPack exchange API <https://dmlc.github.io/dlpack/latest/>`_, MatX can exchange tensors
with full metadata (dtype, shape, strides, device) and explicit ownership.

DLPack operates on a producer-consumer model where the producer is the library
that creates the tensor and the consumer is the library that uses the tensor.
The producer allocates a `DLManagedTensorVersioned` or `DLManagedTensor`
object, which holds the tensor's metadata, a pointer to its data, and a
`deleter` function. The consumer must call the `deleter` exactly once when it
is finished with the tensor.

Exporting MatX tensors via DLPack
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

MatX supports exporting both legacy and versioned DLPack objects:

.. code-block:: cpp

auto t = matx::make_tensor<float>({10, 10});

// Versioned DLPack (v1.x style)
DLManagedTensorVersioned *versioned = t.ToDlPackVersioned();
// Legacy DLPack (v0.x style)
DLManagedTensor *legacy = t.ToDlPack();

Both calls take an additional reference on the underlying storage, so it
remains valid until the matching DLPack `deleter` is called.

Importing external DLPack tensors into MatX
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

When importing into MatX, use the `make_tensor` overloads that consume a
pointer to a `DLManagedTensorVersioned` or `DLManagedTensor` object. For
example, to convert a libtorch tensor to a MatX tensor:

.. code-block:: cpp

#include <torch/torch.h>
#include <ATen/DLConvertor.h>
#include <matx.h>

// Create a libtorch tensor
auto torch_tensor = torch::randn({10, 10});

// Convert the libtorch tensor to a MatX tensor
matx::tensor_t<float, 2> t;
matx::make_tensor(t, at::toDLPackVersioned(torch_tensor));

MatX will invoke the producer-provided DLPack deleter when the last MatX
reference to the imported storage is released.

.. important::

A `DLManagedTensorVersioned` or `DLManagedTensor` must only be consumed once.


Passing By Object
97 changes: 31 additions & 66 deletions examples/python_integration_sample/example_matxutil.py
@@ -9,69 +9,34 @@

import matxutil

# Demonstrate dlpack consumption invalidates it for future use
def dlp_usage_error():
a = cp.empty((3,3), dtype=cp.float32)
dlp = a.__dlpack__()
assert(matxutil.check_dlpack_status(dlp) == 0)
a2 = cp.from_dlpack(dlp) # causes dlp to become unused
assert(matxutil.check_dlpack_status(dlp) != 0)
return dlp

# Demonstrate cupy array stays in scope when returning valid dlp
def scope_okay():
a = cp.empty((3,3), dtype=cp.float32)
a[1,1] = 2
dlp = a.__dlpack__()
assert(matxutil.check_dlpack_status(dlp) == 0)
return dlp

#Do all cupy work using the "with stream" context manager
stream = cp.cuda.stream.Stream(non_blocking=True)
with stream:
print("Demonstrate dlpack consumption invalidates it for future use:")
dlp = dlp_usage_error()
assert(matxutil.check_dlpack_status(dlp) != 0)
print(f" dlp capsule name is: {matxutil.get_capsule_name(dlp)}")
print()

print("Demonstrate cupy array stays in scope when returning valid dlpack:")
dlp = scope_okay()
assert(matxutil.check_dlpack_status(dlp) == 0)
print(f" dlp capsule name is: {matxutil.get_capsule_name(dlp)}")
print()

print("Print info about the dlpack:")
matxutil.print_dlpack_info(dlp)
print()

print("Use MatX to print the tensor:")
matxutil.print_float_2D(dlp)
print()

print("Print current memory usage info:")
gpu_mempool = cp.get_default_memory_pool()
pinned_mempool = cp.get_default_pinned_memory_pool()
print(f" GPU mempool used bytes {gpu_mempool.used_bytes()}")
print(f" Pinned mempool n_free_blocks {pinned_mempool.n_free_blocks()}")
print()

print("Demonstrate python to C++ to python to C++ calling chain (uses mypythonlib.py):")
# This function calls back into python and executes a from_dlpack, consuming the dlp
matxutil.call_python_example(dlp)
assert(matxutil.check_dlpack_status(dlp) != 0)
del dlp

print("Demonstrate adding two tensors together using MatX:")
a = cp.array([[1,2,3],[4,5,6],[7,8,9]], dtype=cp.float32)
b = cp.array([[1,2,3],[4,5,6],[7,8,9]], dtype=cp.float32)
c = cp.empty(b.shape, dtype=b.dtype)

c_dlp = c.__dlpack__(stream=stream.ptr)
a_dlp = a.__dlpack__(stream=stream.ptr)
b_dlp = b.__dlpack__(stream=stream.ptr)
matxutil.add_float_2D(c_dlp, a_dlp, b_dlp, stream.ptr)
stream.synchronize()
print(f"Tensor a {a}")
print(f"Tensor b {b}")
print(f"Tensor c=a+b {c}")
a = cp.arange(9, dtype=cp.float32).reshape(3, 3)

# Convert the cupy array to a DLPack capsule
print("Printing tensor using MatX:")
a_dlp = a.__dlpack__()
# Print the tensor using MatX
matxutil.print_float_2D(a_dlp)

# Calling again will throw an error, as the DLPack capsule has been consumed
try:
matxutil.print_float_2D(a_dlp)
assert False, "Expected print_float_2D to throw"
except Exception:
pass

# Passing an incompatible tensor type will throw an error
try:
matxutil.print_float_2D(cp.arange(9, dtype=cp.float64).__dlpack__())
assert False, "Expected print_float_2D to throw"
except Exception:
pass

print("Printing tensor using Python called from MatX:")
# Valid, as we create a new DLPack capsule for this call
matxutil.python_print_float_2D(a.__dlpack__())

print("Adding two tensors together using MatX on the current stream:")
b = cp.ones((3, 3), dtype=cp.float32)
c = cp.empty((3, 3), dtype=cp.float32)
matxutil.add_float_2D(c.__dlpack__(), a.__dlpack__(), b.__dlpack__(), cp.cuda.get_current_stream().ptr)
print(c) # implicit stream synchronization