Skip to content

Commit

Permalink
ARROW-17301: [C++] Implement compute function "binary_slice" (#14550)
Browse files Browse the repository at this point in the history
Implements `binary_slice`, which slices binary strings by byte offsets, similar to `utf8_slice_codeunits`.

Mostly based on `utf8_slice_codeunits`.

TODO:
* [x] C++ Tests 
* [x] Python Tests

Authored-by: kshitij12345 <kshitijkalambarkar@gmail.com>
Signed-off-by: Antoine Pitrou <antoine@python.org>
  • Loading branch information
kshitij12345 committed Nov 15, 2022
1 parent 4daf945 commit 058d4f6
Show file tree
Hide file tree
Showing 6 changed files with 328 additions and 6 deletions.
167 changes: 167 additions & 0 deletions cpp/src/arrow/compute/kernels/scalar_string_ascii.cc
Expand Up @@ -2409,6 +2409,172 @@ void AddAsciiStringReplaceSlice(FunctionRegistry* registry) {
DCHECK_OK(registry->AddFunction(std::move(func)));
}

// ----------------------------------------------------------------------
// Slice

namespace {
/// Transform behind the "binary_slice" kernel: for each input string, emits the
/// bytes selected by (`start`, `stop`, `step`) from `SliceOptions`, where every
/// position is a byte offset and negative positions count from the end of the
/// string.
///
/// All positions are computed as clamped int64 indices into the input instead of
/// through pointer arithmetic, so extreme option values (for instance a `stop`
/// of INT64_MAX) never form out-of-range pointers nor overflow a signed integer.
struct SliceBytesTransform : StringSliceTransformBase {
  /// Upper bound on the total number of output bytes produced for `ninputs`
  /// strings that together hold `input_bytes` bytes.
  ///
  /// `step` is expected to be non-zero here.
  int64_t MaxCodeunits(int64_t ninputs, int64_t input_bytes) override {
    const SliceOptions& opt = *this->options;
    if ((opt.start >= 0) != (opt.stop >= 0)) {
      // If start and stop don't have the same sign, we can't guess an upper bound
      // on the resulting slice lengths, so return a worst case estimate.
      return input_bytes;
    }
    // start and stop have the same sign, so their difference cannot overflow.
    const int64_t span = opt.stop - opt.start;
    // ceil(|span| / |step|) when the slice advances towards `stop`, else 0.
    // Both forms avoid `span + step`, which can overflow for large operands.
    int64_t max_slice_bytes = 0;
    if (opt.step > 0 && span > 0) {
      max_slice_bytes = (span - 1) / opt.step + 1;
    } else if (opt.step < 0 && span < 0) {
      max_slice_bytes = (span + 1) / opt.step + 1;
    }
    if (ninputs <= 0 || max_slice_bytes == 0) {
      return 0;
    }
    // Equivalent to min(input_bytes, ninputs * max_slice_bytes), but the product
    // is only formed once it is known not to exceed input_bytes.
    if (max_slice_bytes > input_bytes / ninputs) {
      return input_bytes;
    }
    return ninputs * max_slice_bytes;
  }

  /// Writes the slice of `input` (of `input_string_bytes` bytes) to `output` and
  /// returns the number of bytes written.
  int64_t Transform(const uint8_t* input, int64_t input_string_bytes, uint8_t* output) {
    if (options->step >= 1) {
      return SliceForward(input, input_string_bytes, output);
    }
    return SliceBackward(input, input_string_bytes, output);
  }

  /// Slice in forward order (step > 0). Returns the number of bytes written.
  int64_t SliceForward(const uint8_t* input, int64_t input_string_bytes,
                       uint8_t* output) {
    const SliceOptions& opt = *this->options;
    const int64_t length = input_string_bytes;
    if (length == 0) {
      return 0;
    }

    // First, compute the half-open index range [first, last) of the slice.
    int64_t first;
    int64_t last;
    if (opt.start >= 0) {
      // start counts from the left
      first = std::min<int64_t>(opt.start, length);
      if (opt.stop > opt.start) {
        // stop also counts from the left
        last = std::min<int64_t>(opt.stop, length);
      } else if (opt.stop < 0) {
        // stop counts from the right, but never ends before `first`
        last = std::max<int64_t>(length + opt.stop, first);
      } else {
        // zero length slice
        return 0;
      }
    } else {
      // start counts from the right; length + start cannot overflow because the
      // operands have opposite signs
      first = std::max<int64_t>(length + opt.start, 0);
      if (opt.stop > 0) {
        // stop counts from the left, so it may fall at or before `first`
        last = std::min<int64_t>(opt.stop, length);
        if (last <= first) {
          // zero length slice
          return 0;
        }
      } else if ((opt.stop < 0) && (opt.stop > opt.start)) {
        // stop is negative but larger than start. `first` may have been clamped
        // (e.g. start=-100 on a 10-byte string), so count from the right again.
        last = std::max<int64_t>(length + opt.stop, first);
      } else {
        // zero length slice
        return 0;
      }
    }

    // Second, copy the computed slice to output
    DCHECK(first <= last);
    if (opt.step == 1) {
      // fast case: the slice is one contiguous run of bytes
      std::copy(input + first, input + last, output);
      return last - first;
    }

    uint8_t* dest = output;
    int64_t pos = first;
    while (pos < last) {
      *dest++ = input[pos];
      // Stop before advancing past `last`; this also keeps `pos + step` from
      // overflowing for very large steps.
      if (opt.step >= last - pos) {
        break;
      }
      pos += opt.step;
    }
    return dest - output;
  }

  /// Slice in reverse order (step < 0). Returns the number of bytes written.
  int64_t SliceBackward(const uint8_t* input, int64_t input_string_bytes,
                        uint8_t* output) {
    const SliceOptions& opt = *this->options;
    const int64_t length = input_string_bytes;
    if (length == 0) {
      return 0;
    }

    // Index of the first byte to emit. It is clamped to [-1, length - 1], where
    // -1 means "before the first byte" and yields an empty slice.
    const int64_t first = (opt.start >= 0)
                              ? std::min<int64_t>(opt.start, length - 1)
                              : std::max<int64_t>(length + opt.start, -1);
    // Exclusive lower limit of the walk, clamped the same way as `first`.
    const int64_t limit = (opt.stop >= 0)
                              ? std::min<int64_t>(opt.stop, length - 1)
                              : std::max<int64_t>(length + opt.stop, -1);

    // Copy the computed slice to output, walking towards the start of the string.
    uint8_t* dest = output;
    int64_t pos = first;
    while (pos > limit) {
      *dest++ = input[pos];
      // pos >= 0 and step < 0 here, so this addition cannot overflow.
      pos += opt.step;
    }
    return dest - output;
  }
};

// Kernel exec for the binary type `Type`, obtained by instantiating
// StringTransformExec with SliceBytesTransform.
template <typename Type>
using SliceBytes = StringTransformExec<Type, SliceBytesTransform>;

} // namespace

// Documentation for the "binary_slice" compute function: a summary, a long
// description, the argument name, and the options class (SliceOptions), which
// is marked as required.
const FunctionDoc binary_slice_doc(
"Slice binary string",
("For each binary string in `strings`, emit the substring defined by\n"
"(`start`, `stop`, `step`) as given by `SliceOptions` where `start` is\n"
"inclusive and `stop` is exclusive. All three values are measured in\n"
"bytes.\n"
"If `step` is negative, the string will be advanced in reversed order.\n"
"An error is raised if `step` is zero.\n"
"Null inputs emit null."),
{"strings"}, "SliceOptions", /*options_required=*/true);

// Registers the unary "binary_slice" scalar function in `registry`, adding one
// kernel (same input and output type) for each type returned by BinaryTypes().
void AddAsciiStringSlice(FunctionRegistry* registry) {
  auto slice_func = std::make_shared<ScalarFunction>("binary_slice", Arity::Unary(),
                                                     binary_slice_doc);
  for (const auto& binary_type : BinaryTypes()) {
    auto kernel_exec = GenerateVarBinaryToVarBinary<SliceBytes>(binary_type);
    DCHECK_OK(slice_func->AddKernel({binary_type}, binary_type, std::move(kernel_exec),
                                    SliceBytesTransform::State::Init));
  }
  DCHECK_OK(registry->AddFunction(std::move(slice_func)));
}

// ----------------------------------------------------------------------
// Split by pattern

Expand Down Expand Up @@ -3206,6 +3372,7 @@ void RegisterScalarStringAscii(FunctionRegistry* registry) {
AddAsciiStringExtractRegex(registry);
#endif
AddAsciiStringReplaceSlice(registry);
AddAsciiStringSlice(registry);
AddAsciiStringSplitPattern(registry);
AddAsciiStringSplitWhitespace(registry);
#ifdef ARROW_WITH_RE2
Expand Down
132 changes: 132 additions & 0 deletions cpp/src/arrow/compute/kernels/scalar_string_test.cc
Expand Up @@ -2119,6 +2119,138 @@ TYPED_TEST(TestStringKernels, SliceCodeunitsNegPos) {

#endif // ARROW_WITH_UTF8PROC

// Basic checks for "binary_slice": a plain forward slice, two edge cases mixing
// negative start values with clamping, and the two error paths (missing options
// and a zero step).
TYPED_TEST(TestBinaryKernels, SliceBytesBasic) {
SliceOptions options{2, 4};
this->CheckUnary("binary_slice", "[\"fo\xc2\xa2\", \"fo\", null, \"fob \"]",
this->type(), "[\"\xc2\xa2\", \"\", null, \"b \"]", &options);

// end is beyond 0, but before start (hence empty)
SliceOptions options_edgecase_1{-3, 1};
this->CheckUnary("binary_slice",
"[\"f\xc2\xa2"
"ds\"]",
this->type(), R"([""])", &options_edgecase_1);

// this is a safeguard against an optimization path possible, but actually a tricky case
SliceOptions options_edgecase_2{-6, -2};
this->CheckUnary("binary_slice",
"[\"f\xc2\xa2"
"ds\"]",
this->type(), "[\"f\xc2\xa2\"]", &options_edgecase_2);

// Calling the function without options must be rejected.
auto input = ArrayFromJSON(this->type(), R"(["foods"])");
EXPECT_RAISES_WITH_MESSAGE_THAT(
Invalid,
testing::HasSubstr("Function 'binary_slice' cannot be called without options"),
CallFunction("binary_slice", {input}));

// A step of zero must be rejected.
SliceOptions options_invalid{2, 4, 0};
EXPECT_RAISES_WITH_MESSAGE_THAT(
Invalid, testing::HasSubstr("Slice step cannot be zero"),
CallFunction("binary_slice", {input}, &options_invalid));
}

// "binary_slice" with non-negative start and stop: default step, a positive
// step of 2, and a negative step of -2 (with stop=1 and then stop=0).
TYPED_TEST(TestBinaryKernels, SliceBytesPosPos) {
SliceOptions options{2, 4};
this->CheckUnary(
"binary_slice",
"[\"\", \"a\", \"ab\", \"a\xc2\xa2\", \"ab\xc2\xa2\", \"ab\xc2\xffZ\"]",
this->type(), "[\"\", \"\", \"\", \"\xa2\", \"\xc2\xa2\", \"\xc2\xff\"]", &options);
SliceOptions options_step{1, 5, 2};
this->CheckUnary(
"binary_slice",
"[\"\", \"a\", \"ab\", \"a\xc2\xa2\", \"ab\xc2\xa2\", \"ab\xc2\xffZ\"]",
this->type(), "[\"\", \"\", \"b\", \"\xc2\", \"b\xa2\", \"b\xff\"]", &options_step);
SliceOptions options_step_neg{5, 1, -2};
this->CheckUnary(
"binary_slice",
"[\"\", \"a\", \"ab\", \"a\xc2\xa2\", \"ab\xc2\xa2\", \"ab\xc2\xffZ\"]",
this->type(), "[\"\", \"\", \"\", \"\xa2\", \"\xa2\", \"Z\xc2\"]",
&options_step_neg);
// Same reverse slice, now stopping just before index 0.
options_step_neg.stop = 0;
this->CheckUnary(
"binary_slice",
"[\"\", \"a\", \"ab\", \"a\xc2\xa2\", \"aZ\xc2\xa2\", \"ab\xc2\xffZ\"]",
this->type(), "[\"\", \"\", \"b\", \"\xa2\", \"\xa2Z\", \"Z\xc2\"]",
&options_step_neg);
}

// "binary_slice" with a non-negative start and a negative stop: default step,
// a positive step of 2, and a negative step of -2 (with stop=-4 and then -5).
TYPED_TEST(TestBinaryKernels, SliceBytesPosNeg) {
SliceOptions options{2, -1};
this->CheckUnary(
"binary_slice",
"[\"\", \"a\", \"ab\", \"a\xc2\xa2\", \"aZ\xc2\xa2\", \"ab\xc2\xffZ\"]",
this->type(), "[\"\", \"\", \"\", \"\", \"\xc2\", \"\xc2\xff\"]", &options);
SliceOptions options_step{1, -1, 2};
this->CheckUnary(
"binary_slice",
"[\"\", \"a\", \"ab\", \"a\xc2\xa2\", \"aZ\xc2\xa2\", \"ab\xc2\xffZ\"]",
this->type(), "[\"\", \"\", \"\", \"\xc2\", \"Z\", \"b\xff\"]", &options_step);
SliceOptions options_step_neg{3, -4, -2};
this->CheckUnary(
"binary_slice",
"[\"\", \"a\", \"ab\", \"Z\xc2\xa2\", \"aZ\xc2\xa2\", \"ab\xc2\xffZ\"]",
this->type(), "[\"\", \"a\", \"b\", \"\xa2Z\", \"\xa2Z\", \"\xff\"]",
&options_step_neg);
// Same reverse slice with the stop moved one byte further to the left.
options_step_neg.stop = -5;
this->CheckUnary(
"binary_slice",
"[\"\", \"a\", \"ab\", \"Z\xc2\xa2\", \"aZ\xc2\xa2\", \"aP\xc2\xffZ\"]",
this->type(), "[\"\", \"a\", \"b\", \"\xa2Z\", \"\xa2Z\", \"\xffP\"]",
&options_step_neg);
}

// "binary_slice" with negative start and stop: default step, a positive step
// of 2, and a negative step of -2 (with stop=-3 and then -4).
TYPED_TEST(TestBinaryKernels, SliceBytesNegNeg) {
SliceOptions options{-2, -1};
this->CheckUnary(
"binary_slice",
"[\"\", \"a\", \"ab\", \"Z\xc2\xa2\", \"aZ\xc2\xa2\", \"ab\xc2\xffZ\"]",
this->type(), "[\"\", \"\", \"a\", \"\xc2\", \"\xc2\", \"\xff\"]", &options);
SliceOptions options_step{-4, -1, 2};
this->CheckUnary(
"binary_slice",
"[\"\", \"a\", \"ab\", \"Z\xc2\xa2\", \"aZ\xc2\xa2\", \"aP\xc2\xffZ\"]",
this->type(), "[\"\", \"\", \"a\", \"Z\", \"a\xc2\", \"P\xff\"]", &options_step);
SliceOptions options_step_neg{-1, -3, -2};
this->CheckUnary(
"binary_slice",
"[\"\", \"a\", \"ab\", \"Z\xc2\xa2\", \"aZ\xc2\xa2\", \"aP\xc2\xffZ\"]",
this->type(), "[\"\", \"a\", \"b\", \"\xa2\", \"\xa2\", \"Z\"]", &options_step_neg);
// Same reverse slice with the stop moved one byte further to the left.
options_step_neg.stop = -4;
this->CheckUnary(
"binary_slice",
"[\"\", \"a\", \"ab\", \"Z\xc2\xa2\", \"aZ\xc2\xa2\", \"aP\xc2\xffZ\"]",
this->type(), "[\"\", \"a\", \"b\", \"\xa2Z\", \"\xa2Z\", \"Z\xc2\"]",
&options_step_neg);
}

// "binary_slice" with a negative start and a non-negative stop: default step,
// a positive step of 2, and a negative step of -2 (with stop=1 and then 0).
TYPED_TEST(TestBinaryKernels, SliceBytesNegPos) {
SliceOptions options{-2, 4};
this->CheckUnary(
"binary_slice",
"[\"\", \"a\", \"ab\", \"Z\xc2\xa2\", \"aZ\xc2\xa2\", \"aP\xc2\xffZ\"]",
this->type(), "[\"\", \"a\", \"ab\", \"\xc2\xa2\", \"\xc2\xa2\", \"\xff\"]",
&options);
SliceOptions options_step{-4, 4, 2};
this->CheckUnary(
"binary_slice",
"[\"\", \"a\", \"ab\", \"Z\xc2\xa2\", \"aZ\xc2\xa2\", \"aP\xc2\xffZ\"]",
this->type(), "[\"\", \"a\", \"a\", \"Z\xa2\", \"a\xc2\", \"P\xff\"]",
&options_step);
SliceOptions options_step_neg{-1, 1, -2};
this->CheckUnary(
"binary_slice",
"[\"\", \"a\", \"ab\", \"Z\xc2\xa2\", \"aZ\xc2\xa2\", \"aP\xc2\xffZ\"]",
this->type(), "[\"\", \"\", \"\", \"\xa2\", \"\xa2\", \"Z\xc2\"]",
&options_step_neg);
// Same reverse slice, now stopping just before index 0.
options_step_neg.stop = 0;
this->CheckUnary(
"binary_slice",
"[\"\", \"a\", \"ab\", \"Z\xc2\xa2\", \"aZ\xc2\xa2\", \"aP\xc2\xffZ\"]",
this->type(), "[\"\", \"\", \"b\", \"\xa2\", \"\xa2Z\", \"Z\xc2\"]",
&options_step_neg);
}

TYPED_TEST(TestStringKernels, PadAscii) {
PadOptions options{/*width=*/5, " "};
this->CheckUnary("ascii_center", R"([null, "a", "bb", "bar", "foobar"])", this->type(),
Expand Down
15 changes: 10 additions & 5 deletions docs/source/cpp/compute.rst
Expand Up @@ -1089,13 +1089,18 @@ semantics follow Python slicing semantics: the start index is inclusive,
the stop index exclusive; if the step is negative, the sequence is followed
in reverse order.

+--------------------------+------------+----------------+-----------------+--------------------------+---------+
| Function name | Arity | Input types | Output type | Options class | Notes |
+==========================+============+================+=================+==========================+=========+
| utf8_slice_codeunits | Unary | String-like | String-like | :struct:`SliceOptions` | \(1) |
+--------------------------+------------+----------------+-----------------+--------------------------+---------+
+--------------------------+------------+-------------------------+-------------------------+--------------------------+---------+
| Function name | Arity | Input types | Output type | Options class | Notes |
+==========================+============+=========================+=========================+==========================+=========+
| binary_slice | Unary | Binary-like | Binary-like | :struct:`SliceOptions` | \(1) |
+--------------------------+------------+-------------------------+-------------------------+--------------------------+---------+
| utf8_slice_codeunits | Unary | String-like | String-like | :struct:`SliceOptions` | \(2) |
+--------------------------+------------+-------------------------+-------------------------+--------------------------+---------+

* \(1) Slice string into a substring defined by (``start``, ``stop``, ``step``)
as given by :struct:`SliceOptions` where ``start`` and ``stop`` are measured
in bytes. Null inputs emit null.
* \(2) Slice string into a substring defined by (``start``, ``stop``, ``step``)
as given by :struct:`SliceOptions` where ``start`` and ``stop`` are measured
in codeunits. Null inputs emit null.

Expand Down
1 change: 1 addition & 0 deletions docs/source/python/api/compute.rst
Expand Up @@ -342,6 +342,7 @@ String Slicing
.. autosummary::
:toctree: ../generated/

binary_slice
utf8_slice_codeunits

Containment Tests
Expand Down
17 changes: 17 additions & 0 deletions python/pyarrow/tests/test_compute.py
Expand Up @@ -18,6 +18,7 @@
from datetime import datetime
from functools import lru_cache, partial
import inspect
import itertools
import os
import pickle
import pytest
Expand Down Expand Up @@ -536,6 +537,22 @@ def test_slice_compatibility():
start, stop, step) == result


def test_binary_slice_compatibility():
    # binary_slice must agree with Python's own bytes slicing for every
    # (start, stop, step) combination tried below; a zero step is excluded.
    arr = pa.array([b"", b"a", b"a\xff", b"ab\x00", b"abc\xfb", b"ab\xf2de"])
    bounds = range(-6, 6)
    nonzero_steps = [step for step in range(-3, 4) if step != 0]
    for start, stop, step in itertools.product(bounds, bounds, nonzero_steps):
        python_slices = [value.as_py()[start:stop:step] for value in arr]
        expected = pa.array(python_slices)
        result = pc.binary_slice(arr, start=start, stop=stop, step=step)
        assert expected.equals(result)
        # Positional options
        assert pc.binary_slice(arr, start, stop, step) == result


def test_split_pattern():
arr = pa.array(["-foo---bar--", "---foo---b"])
result = pc.split_pattern(arr, pattern="---")
Expand Down
2 changes: 1 addition & 1 deletion r/src/compute.cpp
Expand Up @@ -449,7 +449,7 @@ std::shared_ptr<arrow::compute::FunctionOptions> make_compute_options(
return std::make_shared<Options>(cpp11::as_cpp<std::string>(options["characters"]));
}

if (func_name == "utf8_slice_codeunits") {
if (func_name == "utf8_slice_codeunits" || func_name == "binary_slice") {
using Options = arrow::compute::SliceOptions;

int64_t step = 1;
Expand Down

0 comments on commit 058d4f6

Please sign in to comment.