Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 6 additions & 25 deletions python/pyarrow/array.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -17,27 +17,6 @@

import warnings

from cpython.object cimport Py_LT, Py_EQ, Py_GT, Py_LE, Py_NE, Py_GE


cdef str _op_to_function_name(int op):
    """
    Map a rich-comparison opcode to the name of a compute function.

    Parameters
    ----------
    op : int
        One of Py_EQ, Py_NE, Py_GT, Py_GE, Py_LT or Py_LE.

    Returns
    -------
    str
        "equal", "not_equal", "greater", "greater_equal", "less" or
        "less_equal" respectively.

    Raises
    ------
    ValueError
        If ``op`` is not one of the six opcodes listed above.
    """
    if op == Py_EQ:
        return "equal"
    elif op == Py_NE:
        return "not_equal"
    elif op == Py_GT:
        return "greater"
    elif op == Py_GE:
        return "greater_equal"
    elif op == Py_LT:
        return "less"
    elif op == Py_LE:
        return "less_equal"

    # Fail loudly instead of handing back a name that was never assigned.
    raise ValueError("Unsupported comparison opcode: {}".format(op))


cdef _sequence_to_array(object sequence, object mask, object size,
DataType type, CMemoryPool* pool, c_bool from_pandas):
Expand Down Expand Up @@ -773,10 +752,6 @@ cdef class Array(_PandasConvertible):
with nogil:
check_status(DebugPrint(deref(self.ap), 0))

def __richcmp__(self, other, int op):
    """
    Dispatch a comparison operator to the compute function named by
    ``_op_to_function_name(op)``, called with ``[self, other]``.
    """
    kernel_name = _op_to_function_name(op)
    compute = _pc()
    return compute.call_function(kernel_name, [self, other])

def diff(self, Array other):
"""
Compare contents of this array against another one.
Expand Down Expand Up @@ -999,6 +974,12 @@ cdef class Array(_PandasConvertible):
def __str__(self):
    """
    Return the string produced by ``self.to_string()``.
    """
    return self.to_string()

def __eq__(self, other):
    """
    Implement ``==`` in terms of ``self.equals(other)``.

    Returns ``NotImplemented`` when ``equals`` raises TypeError, so that
    Python can fall back to its default comparison handling.
    """
    try:
        outcome = self.equals(other)
    except TypeError:
        return NotImplemented
    else:
        return outcome

def equals(Array self, Array other not None):
    """
    Return the result of comparing this array with ``other`` through the
    underlying C++ ``Equals`` call.

    Parameters
    ----------
    other : Array
        Array to compare against. Must not be None.

    Raises
    ------
    TypeError
        If ``other`` is None or not an Array (argument type check).
    """
    # `not None` matters: a typed extension-class argument accepts None by
    # default, and `other.ap` would then be read from an object that has
    # no such field before being dereferenced.
    return self.ap.Equals(deref(other.ap))

Expand Down
7 changes: 7 additions & 0 deletions python/pyarrow/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,13 @@ def func(left, right):
subtract = _simple_binary_function('subtract')
multiply = _simple_binary_function('multiply')

# Binary comparison wrappers: each name is bound to the result of
# _simple_binary_function called with the matching function name.
equal = _simple_binary_function('equal')
not_equal = _simple_binary_function('not_equal')
greater = _simple_binary_function('greater')
greater_equal = _simple_binary_function('greater_equal')
less = _simple_binary_function('less')
less_equal = _simple_binary_function('less_equal')


def binary_contains_exact(array, pattern):
"""
Expand Down
10 changes: 6 additions & 4 deletions python/pyarrow/table.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,6 @@ cdef class ChunkedArray(_PandasConvertible):
def __reduce__(self):
    """
    Pickle support: return the ``chunked_array`` callable together with
    the ``(chunks, type)`` arguments to call it with.
    """
    return chunked_array, (self.chunks, self.type)

def __richcmp__(self, other, int op):
    """
    Dispatch a comparison operator to the compute function named by
    ``_op_to_function_name(op)``, called with ``[self, other]``.
    """
    kernel_name = _op_to_function_name(op)
    compute = _pc()
    return compute.call_function(kernel_name, [self, other])

@property
def data(self):
import warnings
Expand Down Expand Up @@ -189,6 +185,12 @@ cdef class ChunkedArray(_PandasConvertible):
"""
return _pc().is_valid(self)

def __eq__(self, other):
    """
    Implement ``==`` in terms of ``self.equals(other)``.

    Returns ``NotImplemented`` when ``equals`` raises TypeError, so that
    Python can fall back to its default comparison handling.
    """
    try:
        outcome = self.equals(other)
    except TypeError:
        return NotImplemented
    else:
        return outcome

def equals(self, ChunkedArray other):
"""
Return whether the contents of two chunked arrays are equal.
Expand Down
13 changes: 13 additions & 0 deletions python/pyarrow/tests/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,6 +484,19 @@ def test_array_ref_to_ndarray_base():
assert sys.getrefcount(arr) == (refcount + 1)


def test_array_eq():
    # ARROW-2150 / ARROW-9445: `==` and `!=` on arrays are defined as
    # whole-data equality and yield a plain bool, not an element-wise
    # result
    int32_a = pa.array([1, 2, 3], type=pa.int32())
    int32_b = pa.array([1, 2, 3], type=pa.int32())
    int64_c = pa.array([1, 2, 3], type=pa.int64())

    # Same values, same type
    same_type_eq = int32_a == int32_b
    assert same_type_eq is True
    same_type_ne = int32_a != int32_b
    assert same_type_ne is False

    # Same values, different type
    other_type_eq = int32_a == int64_c
    assert other_type_eq is False
    other_type_ne = int32_a != int64_c
    assert other_type_ne is True


def test_array_from_buffers():
values_buf = pa.py_buffer(np.int16([4, 5, 6, 7]))
nulls_buf = pa.py_buffer(np.uint8([0b00001101]))
Expand Down
33 changes: 17 additions & 16 deletions python/pyarrow/tests/test_compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,22 +376,22 @@ def con(values): return pa.chunked_array([values])
arr1 = con([1, 2, 3, 4, None])
arr2 = con([1, 1, 4, None, 4])

result = arr1 == arr2
result = pc.equal(arr1, arr2)
assert result.equals(con([True, False, False, None, None]))

result = arr1 != arr2
result = pc.not_equal(arr1, arr2)
assert result.equals(con([False, True, True, None, None]))

result = arr1 < arr2
result = pc.less(arr1, arr2)
assert result.equals(con([False, False, True, None, None]))

result = arr1 <= arr2
result = pc.less_equal(arr1, arr2)
assert result.equals(con([True, False, True, None, None]))

result = arr1 > arr2
result = pc.greater(arr1, arr2)
assert result.equals(con([False, True, False, None, None]))

result = arr1 >= arr2
result = pc.greater_equal(arr1, arr2)
assert result.equals(con([True, True, False, None, None]))


Expand All @@ -406,22 +406,22 @@ def con(values): return pa.chunked_array([values])
# TODO this is a hacky way to construct a scalar ..
scalar = pa.array([2]).sum()

result = arr == scalar
result = pc.equal(arr, scalar)
assert result.equals(con([False, True, False, None]))

result = arr != scalar
result = pc.not_equal(arr, scalar)
assert result.equals(con([True, False, True, None]))

result = arr < scalar
result = pc.less(arr, scalar)
assert result.equals(con([True, False, False, None]))

result = arr <= scalar
result = pc.less_equal(arr, scalar)
assert result.equals(con([True, True, False, None]))

result = arr > scalar
result = pc.greater(arr, scalar)
assert result.equals(con([False, False, True, None]))

result = arr >= scalar
result = pc.greater_equal(arr, scalar)
assert result.equals(con([False, True, True, None]))


Expand All @@ -432,11 +432,12 @@ def test_compare_chunked_array_mixed():

expected = pa.chunked_array([[True, True, True, True, None]])

for result in [
arr == arr_chunked,
arr_chunked == arr,
arr_chunked == arr_chunked2,
for left, right in [
(arr, arr_chunked),
(arr_chunked, arr),
(arr_chunked, arr_chunked2),
]:
result = pc.equal(left, right)
assert result.equals(expected)


Expand Down
4 changes: 2 additions & 2 deletions python/pyarrow/tests/test_scalars.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ def test_list(ty, klass):
assert s.type == ty
assert len(s) == 2
assert isinstance(s.values, pa.Array)
assert s.values == v
assert s.values.to_pylist() == v
assert isinstance(s, klass)
assert repr(v) in repr(s)
assert s.as_py() == v
Expand Down Expand Up @@ -496,7 +496,7 @@ def test_dictionary():
assert s.as_py() == v
assert s.value.as_py() == v
assert s.index.as_py() == i
assert s.dictionary == dictionary
assert s.dictionary.to_pylist() == dictionary

with pytest.warns(FutureWarning):
assert s.index_value.as_py() == i
Expand Down
3 changes: 3 additions & 0 deletions python/pyarrow/tests/test_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,8 @@ def eq(xarrs, yarrs):
y = pa.chunked_array(yarrs)
assert x.equals(y)
assert y.equals(x)
assert x == y
assert x != str(y)

def ne(xarrs, yarrs):
if isinstance(xarrs, pa.ChunkedArray):
Expand All @@ -140,6 +142,7 @@ def ne(xarrs, yarrs):
y = pa.chunked_array(yarrs)
assert not x.equals(y)
assert not y.equals(x)
assert x != y

eq(pa.chunked_array([], type=pa.int32()),
pa.chunked_array([], type=pa.int32()))
Expand Down