Skip to content

Commit

Permalink
ARROW-2102: [C++] Implement Take kernel
Browse files Browse the repository at this point in the history
This implements take as a `BinaryKernel`

Out of bounds indices raise an error. All integer index types should be supported.

Supported value types are numeric, boolean, null, binary, dictionary, and string (untested: fixed width binary, time/date).

In addition to `TakeKernel`, a convenience function is implemented which takes arrays as its arguments (currently only array inputs are supported).

Author: Benjamin Kietzman <bengilgit@gmail.com>

Closes #3880 from bkietz/ARROW-2102-Implement-take-kernel-functions-primitiv and squashes the following commits:

a0250e5 <Benjamin Kietzman> Remove out of bounds option- always raise an error
99792b7 <Benjamin Kietzman> address review comments
3a1ef12 <Benjamin Kietzman> renaming, remove superfluous DCHECK
cc821c6 <Benjamin Kietzman> avoid conflict with macro in R_ext/RS.h
b64722d <Benjamin Kietzman> use NULLPTR in public headers
198320d <Benjamin Kietzman> incorporate Francois' suggestions
8be7df1 <Benjamin Kietzman> fix take-test
4a6932f <Benjamin Kietzman> add take kernel to api.h
c5fd669 <Benjamin Kietzman> add test for taking from DictionaryArrays
a7dd739 <Benjamin Kietzman> explain null behavior in Take doccomment
14e837e <Benjamin Kietzman> add better explanatory comment for Take
e94fe31 <Benjamin Kietzman> use Datum::make_array
5fd6a21 <Benjamin Kietzman> add default case for switch(out_of_bounds)
7b840c7 <Benjamin Kietzman> first draft of take kernel impl
  • Loading branch information
bkietz authored and pitrou committed Apr 9, 2019
1 parent d8e4763 commit b2adf33
Show file tree
Hide file tree
Showing 9 changed files with 500 additions and 16 deletions.
1 change: 1 addition & 0 deletions cpp/src/arrow/CMakeLists.txt
Expand Up @@ -158,6 +158,7 @@ if(ARROW_COMPUTE)
compute/kernels/hash.cc
compute/kernels/mean.cc
compute/kernels/sum.cc
compute/kernels/take.cc
compute/kernels/util-internal.cc
compute/operations/cast.cc
compute/operations/literal.cc)
Expand Down
4 changes: 2 additions & 2 deletions cpp/src/arrow/array/builder_binary.cc
Expand Up @@ -232,8 +232,8 @@ Status FixedSizeBinaryBuilder::AppendValues(const uint8_t* data, int64_t length,

Status FixedSizeBinaryBuilder::AppendNull() {
RETURN_NOT_OK(Reserve(1));
UnsafeAppendToBitmap(false);
return byte_builder_.Advance(byte_width_);
UnsafeAppendNull();
return Status::OK();
}

void FixedSizeBinaryBuilder::Reset() {
Expand Down
40 changes: 28 additions & 12 deletions cpp/src/arrow/array/builder_binary.h
Expand Up @@ -185,39 +185,55 @@ class ARROW_EXPORT FixedSizeBinaryBuilder : public ArrayBuilder {

Status Append(const uint8_t* value) {
ARROW_RETURN_NOT_OK(Reserve(1));
UnsafeAppendToBitmap(true);
return byte_builder_.Append(value, byte_width_);
UnsafeAppend(value);
return Status::OK();
}

// Convenience overload for character buffers: the pointer is reinterpreted as
// raw bytes and handed to the `const uint8_t*` overload of Append.
Status Append(const char* value) {
  const auto* bytes = reinterpret_cast<const uint8_t*>(value);
  return Append(bytes);
}

Status Append(const util::string_view& view) {
#ifndef NDEBUG
CheckValueSize(static_cast<int64_t>(view.size()));
#endif
return Append(reinterpret_cast<const uint8_t*>(view.data()));
ARROW_RETURN_NOT_OK(Reserve(1));
UnsafeAppend(view);
return Status::OK();
}

Status Append(const std::string& s) {
#ifndef NDEBUG
CheckValueSize(static_cast<int64_t>(s.size()));
#endif
return Append(reinterpret_cast<const uint8_t*>(s.data()));
ARROW_RETURN_NOT_OK(Reserve(1));
UnsafeAppend(s);
return Status::OK();
}

template <size_t NBYTES>
Status Append(const std::array<uint8_t, NBYTES>& value) {
ARROW_RETURN_NOT_OK(Reserve(1));
UnsafeAppendToBitmap(true);
return byte_builder_.Append(value);
UnsafeAppend(
util::string_view(reinterpret_cast<const char*>(value.data()), value.size()));
return Status::OK();
}

Status AppendValues(const uint8_t* data, int64_t length,
const uint8_t* valid_bytes = NULLPTR);
Status AppendNull();

// Appends one value without calling Reserve(): records a `true` entry via
// UnsafeAppendToBitmap, then passes `value` together with byte_width_ to
// byte_builder_.UnsafeAppend.
// NOTE(review): the checked Append() overloads call Reserve(1) before this;
// presumably callers of the Unsafe variant must guarantee capacity themselves
// -- confirm against the builder contract.
void UnsafeAppend(const uint8_t* value) {
UnsafeAppendToBitmap(true);
byte_builder_.UnsafeAppend(value, byte_width_);
}

// Appends a value given as a string_view, without calling Reserve().
//
// When NDEBUG is not defined, the view's size is first passed to
// CheckValueSize; the view's data pointer is then forwarded to the
// `const uint8_t*` overload of UnsafeAppend.
void UnsafeAppend(util::string_view value) {
#ifndef NDEBUG
  // Cast to int64_t like the other CheckValueSize call sites; the previous
  // static_cast<size_t> on size() was a no-op.
  CheckValueSize(static_cast<int64_t>(value.size()));
#endif
  UnsafeAppend(reinterpret_cast<const uint8_t*>(value.data()));
}

// Appends a null slot without calling Reserve(): records a `false` entry via
// UnsafeAppendToBitmap, then advances byte_builder_ by byte_width_ through
// UnsafeAdvance.
// NOTE(review): the advance presumably keeps the data buffer aligned with the
// slot count for fixed-width values -- confirm against FinishInternal.
void UnsafeAppendNull() {
UnsafeAppendToBitmap(false);
byte_builder_.UnsafeAdvance(byte_width_);
}

void Reset() override;
Status Resize(int64_t capacity) override;
Status FinishInternal(std::shared_ptr<ArrayData>* out) override;
Expand Down
1 change: 1 addition & 0 deletions cpp/src/arrow/compute/api.h
Expand Up @@ -27,5 +27,6 @@
#include "arrow/compute/kernels/hash.h" // IWYU pragma: export
#include "arrow/compute/kernels/mean.h" // IWYU pragma: export
#include "arrow/compute/kernels/sum.h" // IWYU pragma: export
#include "arrow/compute/kernels/take.h" // IWYU pragma: export

#endif // ARROW_COMPUTE_API_H
1 change: 1 addition & 0 deletions cpp/src/arrow/compute/kernels/CMakeLists.txt
Expand Up @@ -20,6 +20,7 @@ arrow_install_all_headers("arrow/compute/kernels")
add_arrow_test(boolean-test PREFIX "arrow-compute")
add_arrow_test(cast-test PREFIX "arrow-compute")
add_arrow_test(hash-test PREFIX "arrow-compute")
add_arrow_test(take-test PREFIX "arrow-compute")
add_arrow_test(util-internal-test PREFIX "arrow-compute")

# Aggregates
Expand Down
166 changes: 166 additions & 0 deletions cpp/src/arrow/compute/kernels/take-test.cc
@@ -0,0 +1,166 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#include <memory>
#include <vector>

#include "arrow/compute/context.h"
#include "arrow/compute/kernels/take.h"
#include "arrow/compute/test-util.h"
#include "arrow/testing/gtest_common.h"
#include "arrow/testing/gtest_util.h"
#include "arrow/testing/random.h"
#include "arrow/testing/util.h"

namespace arrow {
namespace compute {

using util::string_view;

// Common fixture for the Take kernel tests. It offers helpers that invoke
// arrow::compute::Take on arrays (passed directly or described as JSON) and
// compare the produced array against an expected one.
template <typename ArrowType>
class TestTakeKernel : public ComputeFixture, public TestBase {
 protected:
  // Takes `indices` from `values` and asserts that the output equals `expected`.
  void AssertTakeArrays(const std::shared_ptr<Array>& values,
                        const std::shared_ptr<Array>& indices, TakeOptions options,
                        const std::shared_ptr<Array>& expected) {
    std::shared_ptr<Array> result;
    ASSERT_OK(arrow::compute::Take(&this->ctx_, *values, *indices, options, &result));
    AssertArraysEqual(*expected, *result);
  }

  // JSON-driven variant: `values`, `indices` and `expected` are JSON array
  // literals. The check is repeated with int8 and uint32 index arrays.
  void AssertTake(const std::shared_ptr<DataType>& type, const std::string& values,
                  const std::string& indices, TakeOptions options,
                  const std::string& expected) {
    std::shared_ptr<Array> result;

    for (const auto& index_type : {int8(), uint32()}) {
      ASSERT_OK(this->Take(type, values, index_type, indices, options, &result));
      const auto expected_array = ArrayFromJSON(type, expected);
      AssertArraysEqual(*expected_array, *result);
    }
  }

  // Builds the value and index arrays from JSON and forwards them to
  // arrow::compute::Take, returning its Status so callers can test failures.
  Status Take(const std::shared_ptr<DataType>& type, const std::string& values,
              const std::shared_ptr<DataType>& index_type, const std::string& indices,
              TakeOptions options, std::shared_ptr<Array>* out) {
    const auto values_array = ArrayFromJSON(type, values);
    const auto indices_array = ArrayFromJSON(index_type, indices);
    return arrow::compute::Take(&this->ctx_, *values_array, *indices_array, options,
                                out);
  }
};

// Fixture for taking from arrays of NullType.
class TestTakeKernelWithNull : public TestTakeKernel<NullType> {
 protected:
  // Runs the generic JSON-based check with null() as the value type.
  // utf8() was passed here before, so this fixture built string arrays and
  // the positive-path assertion never exercised NullType values.
  void AssertTake(const std::string& values, const std::string& indices,
                  TakeOptions options, const std::string& expected) {
    TestTakeKernel<NullType>::AssertTake(null(), values, indices, options, expected);
  }
};

// Valid indices into three nulls give three nulls; an out-of-range index on a
// null-typed array must be reported as Invalid.
TEST_F(TestTakeKernelWithNull, TakeNull) {
  TakeOptions take_options;
  this->AssertTake("[null, null, null]", "[0, 1, 0]", take_options,
                   "[null, null, null]");

  // Index 9 is outside the three-element input.
  std::shared_ptr<Array> ignored_result;
  ASSERT_RAISES(Invalid, this->Take(null(), "[null, null, null]", int8(), "[0, 9, 0]",
                                    take_options, &ignored_result));
}

// Fixture for taking from boolean arrays.
class TestTakeKernelWithBoolean : public TestTakeKernel<BooleanType> {
 protected:
  // Same as the base-class JSON check, with the value type fixed to boolean().
  void AssertTake(const std::string& values, const std::string& indices,
                  TakeOptions options, const std::string& expected) {
    const auto value_type = boolean();
    TestTakeKernel<BooleanType>::AssertTake(value_type, values, indices, options,
                                            expected);
  }
};

// Covers plain values, a null among the values, a null among the indices, and
// an out-of-range index for boolean input.
TEST_F(TestTakeKernelWithBoolean, TakeBoolean) {
  TakeOptions take_options;
  this->AssertTake("[true, false, true]", "[0, 1, 0]", take_options,
                   "[true, false, true]");
  // A null value is carried through to every position that selects it.
  this->AssertTake("[null, false, true]", "[0, 1, 0]", take_options,
                   "[null, false, null]");
  // A null index produces a null output slot.
  this->AssertTake("[true, false, true]", "[null, 1, 0]", take_options,
                   "[null, false, true]");

  // Index 9 is outside the three-element input.
  std::shared_ptr<Array> ignored_result;
  ASSERT_RAISES(Invalid, this->Take(boolean(), "[true, false, true]", int8(),
                                    "[0, 9, 0]", take_options, &ignored_result));
}

// Fixture for taking from numeric arrays; instantiated once per numeric type.
template <typename ArrowType>
class TestTakeKernelWithNumeric : public TestTakeKernel<ArrowType> {
 protected:
  // DataType instance for ArrowType, obtained from TypeTraits.
  std::shared_ptr<DataType> type_singleton() {
    return TypeTraits<ArrowType>::type_singleton();
  }

  // Same as the base-class JSON check, with the value type fixed to ArrowType.
  void AssertTake(const std::string& values, const std::string& indices,
                  TakeOptions options, const std::string& expected) {
    const std::shared_ptr<DataType> value_type = this->type_singleton();
    TestTakeKernel<ArrowType>::AssertTake(value_type, values, indices, options,
                                          expected);
  }
};

// Register the numeric fixture for every type listed in NumericArrowTypes.
TYPED_TEST_CASE(TestTakeKernelWithNumeric, NumericArrowTypes);

// Covers plain values, a null among the values, a null among the indices, and
// an out-of-range index for each numeric type.
TYPED_TEST(TestTakeKernelWithNumeric, TakeNumeric) {
  TakeOptions take_options;
  this->AssertTake("[7, 8, 9]", "[0, 1, 0]", take_options, "[7, 8, 7]");
  // A null value is carried through to every position that selects it.
  this->AssertTake("[null, 8, 9]", "[0, 1, 0]", take_options, "[null, 8, null]");
  // A null index produces a null output slot.
  this->AssertTake("[7, 8, 9]", "[null, 1, 0]", take_options, "[null, 8, 7]");

  // Index 9 is outside the three-element input.
  std::shared_ptr<Array> ignored_result;
  ASSERT_RAISES(Invalid, this->Take(this->type_singleton(), "[7, 8, 9]", int8(),
                                    "[0, 9, 0]", take_options, &ignored_result));
}

// Fixture for taking from utf8 string arrays and from dictionary arrays whose
// dictionary holds utf8 values.
class TestTakeKernelWithString : public TestTakeKernel<StringType> {
 protected:
  // Same as the base-class JSON check, with the value type fixed to utf8().
  void AssertTake(const std::string& values, const std::string& indices,
                  TakeOptions options, const std::string& expected) {
    const auto value_type = utf8();
    TestTakeKernel<StringType>::AssertTake(value_type, values, indices, options,
                                           expected);
  }

  // Builds an int8-indexed dictionary array from `dictionary_values` and
  // `dictionary_indices`, takes `indices` from it, and asserts the result
  // equals a dictionary array over the same dictionary with
  // `expected_indices`. All string arguments are JSON array literals.
  void AssertTakeDictionary(const std::string& dictionary_values,
                            const std::string& dictionary_indices,
                            const std::string& indices, TakeOptions options,
                            const std::string& expected_indices) {
    auto dict_type = dictionary(int8(), ArrayFromJSON(utf8(), dictionary_values));

    std::shared_ptr<Array> input;
    std::shared_ptr<Array> result;
    std::shared_ptr<Array> wanted;
    ASSERT_OK(DictionaryArray::FromArrays(
        dict_type, ArrayFromJSON(int8(), dictionary_indices), &input));
    ASSERT_OK(DictionaryArray::FromArrays(
        dict_type, ArrayFromJSON(int8(), expected_indices), &wanted));

    auto selection = ArrayFromJSON(int8(), indices);
    this->AssertTakeArrays(input, selection, options, wanted);
  }
};

// Covers plain values, a null among the values, a null among the indices, and
// an out-of-range index for utf8 input.
TEST_F(TestTakeKernelWithString, TakeString) {
  TakeOptions take_options;
  this->AssertTake(R"(["a", "b", "c"])", "[0, 1, 0]", take_options,
                   R"(["a", "b", "a"])");
  // A null value is carried through to every position that selects it.
  this->AssertTake(R"([null, "b", "c"])", "[0, 1, 0]", take_options,
                   "[null, \"b\", null]");
  // A null index produces a null output slot.
  this->AssertTake(R"(["a", "b", "c"])", "[null, 1, 0]", take_options,
                   R"([null, "b", "a"])");

  // Index 9 is outside the three-element input.
  std::shared_ptr<Array> ignored_result;
  ASSERT_RAISES(Invalid, this->Take(utf8(), R"(["a", "b", "c"])", int8(), "[0, 9, 0]",
                                    take_options, &ignored_result));
}

// Taking from a dictionary array: the expected output is a dictionary array
// over the same dictionary whose indices are the taken input indices.
TEST_F(TestTakeKernelWithString, TakeDictionary) {
  TakeOptions take_options;
  const char* const dictionary_json = R"(["a", "b", "c", "d", "e"])";

  this->AssertTakeDictionary(dictionary_json, "[3, 4, 2]", "[0, 1, 0]", take_options,
                             "[3, 4, 3]");
  // A null dictionary index is carried through to every position selecting it.
  this->AssertTakeDictionary(dictionary_json, "[null, 4, 2]", "[0, 1, 0]",
                             take_options, "[null, 4, null]");
  // A null take index produces a null output slot.
  this->AssertTakeDictionary(dictionary_json, "[3, 4, 2]", "[null, 1, 0]",
                             take_options, "[null, 4, 3]");
}

} // namespace compute
} // namespace arrow

0 comments on commit b2adf33

Please sign in to comment.