Nits and cleanups
pitrou committed Dec 19, 2023
1 parent 811d2b5 commit 49a978f
Showing 4 changed files with 34 additions and 41 deletions.
38 changes: 16 additions & 22 deletions cpp/src/arrow/c/dlpack.cc
@@ -58,51 +58,45 @@ Result<DLDataType> GetDLDataType(const DataType& type) {
 }
 }
 
-} // namespace
 struct ManagerCtx {
-  std::shared_ptr<ArrayData> ref;
+  std::shared_ptr<ArrayData> array;
   DLManagedTensor tensor;
 };
 
+} // namespace
+
 Result<DLManagedTensor*> ExportArray(const std::shared_ptr<Array>& arr) {
   // Define DLDevice struct and check if array type is supported
   // by the DLPack protocol at the same time. Raise TypeError if not.
   // Supported data types: int, uint, float with no validity buffer.
   ARROW_ASSIGN_OR_RAISE(auto device, ExportDevice(arr))
 
   // Define the DLDataType struct
-  const DataType* arrow_type = arr->type().get();
-  ARROW_ASSIGN_OR_RAISE(auto dlpack_type, GetDLDataType(*arrow_type));
+  const DataType& type = *arr->type();
+  std::shared_ptr<ArrayData> data = arr->data();
+  ARROW_ASSIGN_OR_RAISE(auto dlpack_type, GetDLDataType(type));
 
-  // Create ManagerCtx with the reference to
-  // the data of the array
-  std::shared_ptr<ArrayData> array_ref = arr->data();
+  // Create ManagerCtx that will serve as the owner of the DLManagedTensor
   std::unique_ptr<ManagerCtx> ctx(new ManagerCtx);
-  ctx->ref = array_ref;
 
   // Define the data pointer to the DLTensor
   // If array is of length 0, data pointer should be NULL
   if (arr->length() == 0) {
     ctx->tensor.dl_tensor.data = NULL;
-  } else if (arr->offset() > 0) {
-    const auto byte_width = arr->type()->byte_width();
-    const auto start = arr->offset() * byte_width;
-    ARROW_ASSIGN_OR_RAISE(auto sliced_buffer,
-                          SliceBufferSafe(array_ref->buffers[1], start));
-    ctx->tensor.dl_tensor.data =
-        const_cast<void*>(reinterpret_cast<const void*>(sliced_buffer->address()));
   } else {
-    ctx->tensor.dl_tensor.data = const_cast<void*>(
-        reinterpret_cast<const void*>(array_ref->buffers[1]->address()));
+    const auto data_offset = data->offset * type.byte_width();
+    ctx->tensor.dl_tensor.data =
+        const_cast<uint8_t*>(data->buffers[1]->data() + data_offset);
  }
 
   ctx->tensor.dl_tensor.device = device;
   ctx->tensor.dl_tensor.ndim = 1;
   ctx->tensor.dl_tensor.dtype = dlpack_type;
-  ctx->tensor.dl_tensor.shape = const_cast<int64_t*>(&array_ref->length);
+  ctx->tensor.dl_tensor.shape = const_cast<int64_t*>(&data->length);
   ctx->tensor.dl_tensor.strides = NULL;
   ctx->tensor.dl_tensor.byte_offset = 0;
 
+  ctx->array = std::move(data);
   ctx->tensor.manager_ctx = ctx.get();
   ctx->tensor.deleter = [](struct DLManagedTensor* self) {
     delete reinterpret_cast<ManagerCtx*>(self->manager_ctx);
@@ -115,13 +109,13 @@ Result<DLDevice> ExportDevice(const std::shared_ptr<Array>& arr) {
   if (arr->null_count() > 0) {
     return Status::TypeError("Can only use DLPack on arrays with no nulls.");
   }
-  const DataType* arrow_type = arr->type().get();
-  if (arrow_type->id() == Type::BOOL) {
+  const DataType& type = *arr->type();
+  if (type.id() == Type::BOOL) {
     return Status::TypeError("Bit-packed boolean data type not supported by DLPack.");
   }
-  if (!is_integer(arrow_type->id()) && !is_floating(arrow_type->id())) {
+  if (!is_integer(type.id()) && !is_floating(type.id())) {
     return Status::TypeError("DataType is not compatible with DLPack spec: ",
-                             arrow_type->ToString());
+                             type.ToString());
   }
 
   // Define DLDevice struct
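For illustration, a minimal Python sketch of what the data-pointer handling above enables: zero-copy export of plain and sliced arrays through the DLPack protocol. This sketch is not part of the commit and assumes a pyarrow build that includes this DLPack support plus NumPy >= 1.22 for np.from_dlpack.

    import numpy as np
    import pyarrow as pa

    arr = pa.array([1, 0, 10, 0, 2, 1, 3, 5, 1, 0], type=pa.int64())
    view = np.from_dlpack(arr)             # zero-copy view over the Arrow data buffer
    sliced = np.from_dlpack(arr.slice(3))  # non-zero offset: data pointer advanced by offset * byte_width
    print(view.shape, sliced.shape)        # (10,) (7,)
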
9 changes: 4 additions & 5 deletions cpp/src/arrow/c/dlpack.h
@@ -22,14 +22,13 @@

 namespace arrow::dlpack {
 
-/// \brief DLPack protocol for producing DLManagedTensor
+/// \brief Export Arrow array as DLPack tensor.
 ///
-/// DLManagedTensor is produced from an array as defined by
-/// the DLPack protocol, see https://dmlc.github.io/dlpack/latest/.
+/// DLManagedTensor is produced as defined by the DLPack protocol,
+/// see https://dmlc.github.io/dlpack/latest/.
 ///
 /// Data types for which the protocol is supported are
-/// primitive data types without NullType, BooleanType and
-/// Decimal types.
+/// integer and floating-point data types.
 ///
 /// DLPack protocol only supports arrays with one contiguous
 /// memory region which means Arrow Arrays with validity buffers
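A short sketch of how the restrictions documented above surface to a caller (illustration only, not part of the commit; assumes a pyarrow build that includes this DLPack support). Arrays with nulls or bit-packed booleans are rejected with a TypeError, matching the messages in dlpack.cc:

    import pyarrow as pa

    pa.array([1, 2, 3], type=pa.int32()).__dlpack__()  # supported: integer, no validity buffer

    try:
        pa.array([True, False]).__dlpack__()           # bit-packed booleans are rejected
    except TypeError as exc:
        print(exc)  # Bit-packed boolean data type not supported by DLPack.

    try:
        pa.array([1, None, 3]).__dlpack__()            # arrays with nulls are rejected
    except TypeError as exc:
        print(exc)  # Can only use DLPack on arrays with no nulls.
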
20 changes: 10 additions & 10 deletions cpp/src/arrow/c/dlpack_test.cc
@@ -30,9 +30,9 @@ class TestExportArray : public ::testing::Test {
   void SetUp() {}
 };
 
-auto check_dlptensor = [](const std::shared_ptr<Array>& arr,
-                          std::shared_ptr<DataType> arrow_type,
-                          DLDataTypeCode dlpack_type, int64_t length) {
+void CheckDLTensor(const std::shared_ptr<Array>& arr,
+                   const std::shared_ptr<DataType>& arrow_type,
+                   DLDataTypeCode dlpack_type, int64_t length) {
   ASSERT_OK_AND_ASSIGN(auto dlmtensor, arrow::dlpack::ExportArray(arr));
   auto dltensor = dlmtensor->dl_tensor;
 
@@ -59,10 +59,10 @@ auto check_dlptensor = [](const std::shared_ptr<Array>& arr,
   ASSERT_EQ(0, device.device_id);
 
   dlmtensor->deleter(dlmtensor);
-};
+}
 
 TEST_F(TestExportArray, TestSupportedArray) {
-  std::vector<std::pair<std::shared_ptr<DataType>, DLDataTypeCode>> cases = {
+  const std::vector<std::pair<std::shared_ptr<DataType>, DLDataTypeCode>> cases = {
       {int8(), DLDataTypeCode::kDLInt},
       {uint8(), DLDataTypeCode::kDLUInt},
       {
@@ -89,19 +89,19 @@ TEST_F(TestExportArray, TestSupportedArray) {
   for (auto [arrow_type, dlpack_type] : cases) {
     const std::shared_ptr<Array> array =
         ArrayFromJSON(arrow_type, "[1, 0, 10, 0, 2, 1, 3, 5, 1, 0]");
-    check_dlptensor(array, arrow_type, dlpack_type, 10);
+    CheckDLTensor(array, arrow_type, dlpack_type, 10);
     ASSERT_OK_AND_ASSIGN(auto sliced_1, array->SliceSafe(1, 5));
-    check_dlptensor(sliced_1, arrow_type, dlpack_type, 5);
+    CheckDLTensor(sliced_1, arrow_type, dlpack_type, 5);
     ASSERT_OK_AND_ASSIGN(auto sliced_2, array->SliceSafe(0, 5));
-    check_dlptensor(sliced_2, arrow_type, dlpack_type, 5);
+    CheckDLTensor(sliced_2, arrow_type, dlpack_type, 5);
     ASSERT_OK_AND_ASSIGN(auto sliced_3, array->SliceSafe(3));
-    check_dlptensor(sliced_3, arrow_type, dlpack_type, 7);
+    CheckDLTensor(sliced_3, arrow_type, dlpack_type, 7);
   }
 
   ASSERT_EQ(allocated_bytes, arrow::default_memory_pool()->bytes_allocated());
 }
 
-TEST_F(TestExportArray, TestUnSupportedArray) {
+TEST_F(TestExportArray, TestErrors) {
   const std::shared_ptr<Array> array_null = ArrayFromJSON(null(), "[]");
   ASSERT_RAISES_WITH_MESSAGE(TypeError,
                              "Type error: DataType is not compatible with DLPack spec: " +
8 changes: 4 additions & 4 deletions python/pyarrow/array.pxi
@@ -1794,7 +1794,7 @@ cdef class Array(_PandasConvertible):
             A DLPack capsule for the array, pointing to a DLManagedTensor.
         """
         if stream is None:
-            dlm_tensor = GetResultValue(ExportToDLPack(pyarrow_unwrap_array(self)))
+            dlm_tensor = GetResultValue(ExportToDLPack(self.sp_array))
 
             return PyCapsule_New(dlm_tensor, 'dltensor', dlpack_pycapsule_deleter)
         else:
@@ -1804,7 +1804,7 @@

     def __dlpack_device__(self):
         """
-        Returns the DLPack device tuple this arrays resides on.
+        Return the DLPack device tuple this array resides on.
 
         Returns
         -------
@@ -1813,8 +1813,8 @@
             CPU = 1, see cpp/src/arrow/c/dlpack_abi.h) and index of the
             device which is 0 by default for CPU.
         """
-        device = GetResultValue(ExportDevice(pyarrow_unwrap_array(self)))
-        return (device.device_type, device.device_id)
+        device = GetResultValue(ExportDevice(self.sp_array))
+        return device.device_type, device.device_id
 
 
 cdef _array_like_to_pandas(obj, options, types_mapper):
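For completeness, a minimal usage sketch of the two Python methods touched here (illustration only, not part of the commit; assumes a pyarrow build that includes this DLPack support):

    import pyarrow as pa

    arr = pa.array([1.5, 2.5, 3.5], type=pa.float64())
    print(arr.__dlpack_device__())  # (1, 0): device type kDLCPU, device id 0
    capsule = arr.__dlpack__()      # PyCapsule named 'dltensor' wrapping a DLManagedTensor
    # A consumer such as np.from_dlpack takes ownership of the tensor; otherwise
    # dlpack_pycapsule_deleter releases it when the capsule is garbage-collected.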
