ARROW-8277: [Python] implemented __eq__, __repr__, and provided a wrapper of Take() for RecordBatch

Closes #6768 from brills/record_batch

Authored-by: Zhuo Peng <1835738+brills@users.noreply.github.com>
Signed-off-by: Wes McKinney <wesm+git@apache.org>
brills authored and wesm committed Mar 31, 2020
1 parent 6ce0948 commit 58c0941
Showing 3 changed files with 73 additions and 0 deletions.
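
For readers who want to see the new surface area before the diff, here is a minimal usage sketch. It assumes a pyarrow build that includes this change, and that RecordBatch.equals() rejects non-RecordBatch arguments with a TypeError, which is what the except clause added in table.pxi anticipates:

    import pyarrow as pa

    batch = pa.record_batch(
        [pa.array([1, 2], type='int16'), pa.array([3, 4], type='int32')],
        ['c0', 'c1'])

    # __repr__ prints the class name followed by the schema.
    print(batch)
    # pyarrow.RecordBatch
    # c0: int16
    # c1: int32

    # __eq__ delegates to RecordBatch.equals(); a TypeError from equals()
    # is surfaced as NotImplemented, so comparing against an unrelated
    # type falls back to identity and evaluates to False.
    same = pa.record_batch(
        [pa.array([1, 2], type='int16'), pa.array([3, 4], type='int32')],
        ['c0', 'c1'])
    assert batch == same
    assert batch != 'not a record batch'
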
3 changes: 3 additions & 0 deletions python/pyarrow/includes/libarrow.pxd
@@ -1497,6 +1497,9 @@ cdef extern from "arrow/compute/api.h" namespace "arrow::compute" nogil:
    CStatus Take(CFunctionContext* context, const CDatum& values,
                 const CDatum& indices, const CTakeOptions& options,
                 CDatum* out)
    CStatus Take(CFunctionContext* context, const CRecordBatch& batch,
                 const CArray& indices, const CTakeOptions& options,
                 shared_ptr[CRecordBatch]* out)

    # Filter clashes with gandiva.pyx::Filter
    CStatus FilterKernel" arrow::compute::Filter"(
37 changes: 37 additions & 0 deletions python/pyarrow/table.pxi
@@ -536,6 +536,15 @@ cdef class RecordBatch(_PandasConvertible):
    def __len__(self):
        return self.batch.num_rows()

    def __eq__(self, other):
        try:
            return self.equals(other)
        except TypeError:
            return NotImplemented

    def __repr__(self):
        return 'pyarrow.{}\n{}'.format(type(self).__name__, str(self.schema))

    def validate(self, *, full=False):
        """
        Perform validation checks. An exception is raised if validation fails.
@@ -730,6 +739,34 @@ cdef class RecordBatch(_PandasConvertible):

        return result

    def take(self, Array indices):
        """
        Take rows from a RecordBatch.

        The resulting batch contains rows taken from the input batch at the
        given indices. If an index is null then all the cells in that row
        will be null.

        Parameters
        ----------
        indices : Array
            The indices of the values to extract. Array needs to be of
            integer type.

        Returns
        -------
        RecordBatch
        """
        cdef:
            CTakeOptions options
            shared_ptr[CRecordBatch] out
            CRecordBatch* this_batch = self.batch

        with nogil:
            check_status(Take(_context(), deref(this_batch),
                              deref(indices.sp_array), options, &out))

        return pyarrow_wrap_batch(out)

    def to_pydict(self):
        """
        Convert the RecordBatch to a dict or OrderedDict.
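
The null-index behavior spelled out in the take() docstring above is easiest to see with a concrete batch. A short illustrative sketch, assuming a build with this change (the data mirrors the new test below):

    import pyarrow as pa

    batch = pa.record_batch(
        [pa.array([1, 2, 3, None, 5]), pa.array(['a', 'b', 'c', 'd', 'e'])],
        ['f1', 'f2'])

    # Integer indices pick out rows in the order given.
    reordered = batch.take(pa.array([4, 0]))
    # f1: [5, 1], f2: ['e', 'a']

    # A null index produces a row whose cells are all null.
    with_null = batch.take(pa.array([2, None]))
    # f1: [3, null], f2: ['c', null]
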
33 changes: 33 additions & 0 deletions python/pyarrow/tests/test_table.py
@@ -324,6 +324,39 @@ def test_recordbatch_basics():
    # schema as first positional argument
    batch = pa.record_batch(data, schema)
    assert batch.schema == schema
    assert (str(batch) == """pyarrow.RecordBatch
c0: int16
c1: int32""")


def test_recordbatch_equals():
    data1 = [
        pa.array(range(5), type='int16'),
        pa.array([-10, -5, 0, None, 10], type='int32')
    ]
    data2 = [
        pa.array(['a', 'b', 'c']),
        pa.array([['d'], ['e'], ['f']]),
    ]
    column_names = ['c0', 'c1']

    batch = pa.record_batch(data1, column_names)
    assert batch == pa.record_batch(data1, column_names)
    assert batch.equals(pa.record_batch(data1, column_names))

    assert batch != pa.record_batch(data2, column_names)
    assert not batch.equals(pa.record_batch(data2, column_names))


def test_recordbatch_take():
    batch = pa.record_batch(
        [pa.array([1, 2, 3, None, 5]),
         pa.array(['a', 'b', 'c', 'd', 'e'])],
        ['f1', 'f2'])
    assert batch.take(pa.array([2, 3])).equals(batch.slice(2, 2))
    assert batch.take(pa.array([2, None])).equals(
        pa.record_batch([pa.array([3, None]), pa.array(['c', None])],
                        ['f1', 'f2']))


def test_recordbatch_column_sets_private_name():
