Skip to content

Commit

Permalink
ARROW-2276: [Python] Expose buffer protocol on Tensor
Browse files Browse the repository at this point in the history
Also add a bit_width property to the DataType class.

Author: Antoine Pitrou <antoine@python.org>

Closes #1741 from pitrou/ARROW-2276-tensor-buffer-protocol and squashes the following commits:

104388a <Antoine Pitrou> ARROW-2276:  Expose buffer protocol on Tensor
  • Loading branch information
pitrou committed Apr 4, 2018
1 parent 26bc4ab commit 640fc83
Show file tree
Hide file tree
Showing 5 changed files with 113 additions and 0 deletions.
24 changes: 24 additions & 0 deletions python/pyarrow/array.pxi
Expand Up @@ -651,6 +651,30 @@ strides: {0.strides}""".format(self)
self._validate()
return tuple(self.tp.strides())

def __getbuffer__(self, cp.Py_buffer* buffer, int flags):
self._validate()

buffer.buf = <char *> self.tp.data().get().data()
pep3118_format = self.type.pep3118_format
if pep3118_format is None:
raise NotImplementedError("type %s not supported for buffer "
"protocol" % (self.type,))
buffer.format = pep3118_format
buffer.itemsize = self.type.bit_width // 8
buffer.internal = NULL
buffer.len = self.tp.size() * buffer.itemsize
buffer.ndim = self.tp.ndim()
buffer.obj = self
if self.tp.is_mutable():
buffer.readonly = 0
else:
buffer.readonly = 1
# NOTE: This assumes Py_ssize_t == int64_t, and that the shape
# and strides arrays lifetime is tied to the tensor's
buffer.shape = <Py_ssize_t *> &self.tp.shape()[0]
buffer.strides = <Py_ssize_t *> &self.tp.strides()[0]
buffer.suboffsets = NULL


cdef wrap_array_output(PyObject* output):
cdef object obj = PyObject_to_object(output)
Expand Down
3 changes: 3 additions & 0 deletions python/pyarrow/lib.pxd
Expand Up @@ -20,6 +20,8 @@ from pyarrow.includes.libarrow cimport *
from pyarrow.includes.libarrow cimport CStatus
from cpython cimport PyObject
from libcpp cimport nullptr
from libcpp.cast cimport dynamic_cast


cdef extern from "Python.h":
int PySlice_Check(object)
Expand All @@ -42,6 +44,7 @@ cdef class DataType:
cdef:
shared_ptr[CDataType] sp_type
CDataType* type
bytes pep3118_format

cdef void init(self, const shared_ptr[CDataType]& type)

Expand Down
27 changes: 27 additions & 0 deletions python/pyarrow/tests/test_tensor.py
Expand Up @@ -165,3 +165,30 @@ def test_read_tensor(tmpdir):
read_mmap = pa.memory_map(path, mode='r')
array = pa.read_tensor(read_mmap).to_numpy()
np.testing.assert_equal(data, array)


@pytest.mark.skipif(sys.version_info < (3,),
reason="requires Python 3+")
def test_tensor_memoryview():
# Tensors support the PEP 3118 buffer protocol
for dtype, expected_format in [(np.int8, '=b'),
(np.int64, '=q'),
(np.uint64, '=Q'),
(np.float16, 'e'),
(np.float64, 'd'),
]:
data = np.arange(10, dtype=dtype)
dtype = data.dtype
lst = data.tolist()
tensor = pa.Tensor.from_numpy(data)
m = memoryview(tensor)
assert m.format == expected_format
assert m.shape == data.shape
assert m.strides == data.strides
assert m.ndim == 1
assert m.nbytes == data.nbytes
assert m.itemsize == data.itemsize
assert m.itemsize * 8 == tensor.type.bit_width
assert np.frombuffer(m, dtype).tolist() == lst
del tensor, data
assert np.frombuffer(m, dtype).tolist() == lst
13 changes: 13 additions & 0 deletions python/pyarrow/tests/test_types.py
Expand Up @@ -230,6 +230,19 @@ def test_exact_primitive_types(t, check_func):
assert check_func(t)


def test_bit_width():
for ty, expected in [(pa.bool_(), 1),
(pa.int8(), 8),
(pa.uint32(), 32),
(pa.float16(), 16),
(pa.decimal128(19, 4), 128),
(pa.binary(42), 42 * 8)]:
assert ty.bit_width == expected
for ty in [pa.binary(), pa.string(), pa.list_(pa.int16())]:
with pytest.raises(ValueError, match="fixed width"):
ty.bit_width


def test_fixed_size_binary_byte_width():
ty = pa.binary(5)
assert ty.byte_width == 5
Expand Down
46 changes: 46 additions & 0 deletions python/pyarrow/types.pxi
Expand Up @@ -43,6 +43,42 @@ cdef dict _pandas_type_map = {
_Type_DECIMAL: np.object_,
}

cdef dict _pep3118_type_map = {
_Type_INT8: b'b',
_Type_INT16: b'h',
_Type_INT32: b'i',
_Type_INT64: b'q',
_Type_UINT8: b'B',
_Type_UINT16: b'H',
_Type_UINT32: b'I',
_Type_UINT64: b'Q',
_Type_HALF_FLOAT: b'e',
_Type_FLOAT: b'f',
_Type_DOUBLE: b'd',
}


cdef bytes _datatype_to_pep3118(CDataType* type):
"""
Construct a PEP 3118 format string describing the given datatype.
None is returned for unsupported types.
"""
try:
char = _pep3118_type_map[type.id()]
except KeyError:
return None
else:
if char in b'bBhHiIqQ':
# Use "standard" int widths, not native
return b'=' + char
else:
return char


# Workaround for Cython parsing bug
# https://github.com/cython/cython/issues/2143
ctypedef CFixedWidthType* _CFixedWidthTypePtr


cdef class DataType:
"""
Expand All @@ -54,12 +90,22 @@ cdef class DataType:
cdef void init(self, const shared_ptr[CDataType]& type):
self.sp_type = type
self.type = type.get()
self.pep3118_format = _datatype_to_pep3118(self.type)

property id:

def __get__(self):
return self.type.id()

property bit_width:

def __get__(self):
cdef _CFixedWidthTypePtr ty
ty = dynamic_cast[_CFixedWidthTypePtr](self.type)
if ty == nullptr:
raise ValueError("Non-fixed width type")
return ty.bit_width()

def __str__(self):
if self.type is NULL:
raise TypeError(
Expand Down

0 comments on commit 640fc83

Please sign in to comment.