Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion docs/source/python/api/compute.rst
Original file line number Diff line number Diff line change
Expand Up @@ -532,8 +532,8 @@ Selections
drop_null
filter
inverse_permutation
take
scatter
take

Sorts and Partitions
--------------------
Expand Down Expand Up @@ -606,6 +606,7 @@ Compute Options
ExtractRegexSpanOptions
FilterOptions
IndexOptions
InversePermutationOptions
JoinOptions
ListFlattenOptions
ListSliceOptions
Expand Down Expand Up @@ -635,6 +636,7 @@ Compute Options
SkewOptions
SliceOptions
SortOptions
ScatterOptions
SplitOptions
SplitPatternOptions
StrftimeOptions
Expand Down
56 changes: 56 additions & 0 deletions python/pyarrow/_compute.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1444,6 +1444,62 @@ class RunEndEncodeOptions(_RunEndEncodeOptions):
self._set_options(run_end_type)


cdef class _InversePermutationOptions(FunctionOptions):
    def _set_options(self, max_index, output_type):
        # With an explicit output type, resolve it through ensure_type and
        # hand it to the two-argument C++ constructor.
        if output_type is not None:
            resolved_type = ensure_type(output_type)
            self.wrapped.reset(
                new CInversePermutationOptions(
                    max_index, pyarrow_unwrap_data_type(resolved_type)))
            return
        # No output type requested: use the constructor taking only max_index.
        self.wrapped.reset(new CInversePermutationOptions(max_index))


class InversePermutationOptions(_InversePermutationOptions):
    """
    Options for the `inverse_permutation` function.

    Parameters
    ----------
    max_index : int, default -1
        The max value in the input indices to allow.
        The length of the function's output will be this value plus 1.
        If negative, this value will be set to the length of the input
        indices minus 1 and the length of the function's output will be
        the length of the input indices.
    output_type : DataType, default None
        The type of the output inverse permutation.
        If None, the output will be of the same type as the input indices,
        otherwise it must be a signed integer type. An invalid error will
        be reported if this type is not able to store the length of the
        input indices.
    """

    def __init__(self, max_index=-1, output_type=None):
        self._set_options(max_index, output_type)


cdef class _ScatterOptions(FunctionOptions):
    def _set_options(self, max_index):
        # Point the wrapped options at a freshly constructed CScatterOptions
        # built from max_index.
        self.wrapped.reset(new CScatterOptions(max_index))


class ScatterOptions(_ScatterOptions):
    """
    Options for the `scatter` function.

    Parameters
    ----------
    max_index : int, default -1
        The max value in the input indices to allow.
        The length of the function's output will be this value plus 1.
        If negative, this value will be set to the length of the input
        indices minus 1 and the length of the function's output will be
        the length of the input indices.
    """

    def __init__(self, max_index=-1):
        self._set_options(max_index)


cdef class _TakeOptions(FunctionOptions):
def _set_options(self, boundscheck):
self.wrapped.reset(new CTakeOptions(boundscheck))
Expand Down
2 changes: 2 additions & 0 deletions python/pyarrow/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
ExtractRegexSpanOptions,
FilterOptions,
IndexOptions,
InversePermutationOptions,
JoinOptions,
ListSliceOptions,
ListFlattenOptions,
Expand All @@ -66,6 +67,7 @@
RoundTemporalOptions,
RoundToMultipleOptions,
ScalarAggregateOptions,
ScatterOptions,
SelectKOptions,
SetLookupOptions,
SkewOptions,
Expand Down
12 changes: 12 additions & 0 deletions python/pyarrow/includes/libarrow.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -2588,6 +2588,18 @@ cdef extern from "arrow/compute/api.h" namespace "arrow::compute" nogil:
CTakeOptions(c_bool boundscheck)
c_bool boundscheck

# Cython declaration of arrow::compute::InversePermutationOptions.
# Two constructors are exposed: one taking only max_index, and one that
# additionally takes the output data type.
cdef cppclass CInversePermutationOptions \
        "arrow::compute::InversePermutationOptions"(CFunctionOptions):
    CInversePermutationOptions(int64_t max_index)
    CInversePermutationOptions(int64_t max_index, shared_ptr[CDataType] output_type)
    int64_t max_index
    shared_ptr[CDataType] output_type

# Cython declaration of arrow::compute::ScatterOptions, constructed from a
# single max_index value.
cdef cppclass CScatterOptions \
        "arrow::compute::ScatterOptions"(CFunctionOptions):
    CScatterOptions(int64_t max_index)
    int64_t max_index

cdef cppclass CStrptimeOptions \
"arrow::compute::StrptimeOptions"(CFunctionOptions):
CStrptimeOptions(c_string format, TimeUnit unit, c_bool raise_error)
Expand Down
40 changes: 40 additions & 0 deletions python/pyarrow/tests/test_compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,7 @@ def test_option_class_equality(request):
pc.WeekOptions(week_starts_monday=True, count_from_zero=False,
first_week_is_fully_in_year=False),
pc.ZeroFillOptions(4, "0"),
pc.InversePermutationOptions(-1, output_type=pa.int32()),
]
# Timezone database might not be installed on Windows or Emscripten
if request.config.pyarrow.is_enabled["timezone_data"]:
Expand Down Expand Up @@ -1590,6 +1591,45 @@ def test_filter_null_type():
assert len(table.filter(mask).column(0)) == 5


def test_inverse_permutation():
    """Check `inverse_permutation` on a chunked input and its rejected kwargs.

    The input is a chunked array (including empty chunks) holding a
    permutation of 0..9; the result is compared against a single-chunk
    int32 chunked array. The Python-level function is then expected to
    raise TypeError for both an ``options`` object and a bare
    ``max_index`` keyword.
    """
    arr0 = pa.array([], type=pa.int32())
    arr = pa.chunked_array([
        arr0, [9, 7, 5, 3, 1], [0], [2, 4, 6], [8], arr0,
    ])
    result = pc.inverse_permutation(arr)
    expected = pa.chunked_array(
        [[5, 4, 6, 3, 7, 2, 8, 1, 9, 0]], type=pa.int32())
    assert result.equals(expected)

    # `inverse_permutation` kernel currently does not accept options
    options = pc.InversePermutationOptions(max_index=4, output_type=pa.int64())
    with pytest.raises(TypeError,
                       match="an unexpected keyword argument 'options'"):
        pc.inverse_permutation(arr, options=options)

    # `inverse_permutation` kernel currently won't accept max_index
    with pytest.raises(TypeError,
                       match="an unexpected keyword argument 'max_index'"):
        pc.inverse_permutation(arr, max_index=4)


def test_scatter():
    """Check `scatter` with reversed indices and its rejected kwargs."""
    source = pa.array(
        [True, False, True, True, False, False, True, True, True, False])
    positions = pa.array([9, 8, 7, 6, 5, 4, 3, 2, 1, 0])
    wanted = pa.array(
        [False, True, True, True, False, False, True, True, False, True])
    assert pc.scatter(source, positions).equals(wanted)

    # The `scatter` kernel currently accepts neither an options object nor
    # a bare max_index keyword; both must fail with TypeError.
    rejected_kwargs = {
        "options": pc.ScatterOptions(max_index=4),
        "max_index": 4,
    }
    for keyword, value in rejected_kwargs.items():
        pattern = f"unexpected keyword argument '{keyword}'"
        with pytest.raises(TypeError, match=pattern):
            pc.scatter(source, positions, **{keyword: value})


@pytest.mark.parametrize("typ", ["array", "chunked_array"])
def test_compare_array(typ):
if typ == "array":
Expand Down
Loading