diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index d1f4852523377..d4fb19fbb31de 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -158,6 +158,7 @@ if(ARROW_COMPUTE) compute/kernels/hash.cc compute/kernels/mean.cc compute/kernels/sum.cc + compute/kernels/take.cc compute/kernels/util-internal.cc compute/operations/cast.cc compute/operations/literal.cc) diff --git a/cpp/src/arrow/array/builder_binary.cc b/cpp/src/arrow/array/builder_binary.cc index 4fef135b20348..26c6cb4d372ef 100644 --- a/cpp/src/arrow/array/builder_binary.cc +++ b/cpp/src/arrow/array/builder_binary.cc @@ -232,8 +232,8 @@ Status FixedSizeBinaryBuilder::AppendValues(const uint8_t* data, int64_t length, Status FixedSizeBinaryBuilder::AppendNull() { RETURN_NOT_OK(Reserve(1)); - UnsafeAppendToBitmap(false); - return byte_builder_.Advance(byte_width_); + UnsafeAppendNull(); + return Status::OK(); } void FixedSizeBinaryBuilder::Reset() { diff --git a/cpp/src/arrow/array/builder_binary.h b/cpp/src/arrow/array/builder_binary.h index c3a459b39fccd..954f58e7cecfc 100644 --- a/cpp/src/arrow/array/builder_binary.h +++ b/cpp/src/arrow/array/builder_binary.h @@ -185,8 +185,8 @@ class ARROW_EXPORT FixedSizeBinaryBuilder : public ArrayBuilder { Status Append(const uint8_t* value) { ARROW_RETURN_NOT_OK(Reserve(1)); - UnsafeAppendToBitmap(true); - return byte_builder_.Append(value, byte_width_); + UnsafeAppend(value); + return Status::OK(); } Status Append(const char* value) { @@ -194,30 +194,46 @@ class ARROW_EXPORT FixedSizeBinaryBuilder : public ArrayBuilder { } Status Append(const util::string_view& view) { -#ifndef NDEBUG - CheckValueSize(static_cast(view.size())); -#endif - return Append(reinterpret_cast(view.data())); + ARROW_RETURN_NOT_OK(Reserve(1)); + UnsafeAppend(view); + return Status::OK(); } Status Append(const std::string& s) { -#ifndef NDEBUG - CheckValueSize(static_cast(s.size())); -#endif - return Append(reinterpret_cast(s.data())); + 
ARROW_RETURN_NOT_OK(Reserve(1)); + UnsafeAppend(s); + return Status::OK(); } template Status Append(const std::array& value) { ARROW_RETURN_NOT_OK(Reserve(1)); - UnsafeAppendToBitmap(true); - return byte_builder_.Append(value); + UnsafeAppend( + util::string_view(reinterpret_cast(value.data()), value.size())); + return Status::OK(); } Status AppendValues(const uint8_t* data, int64_t length, const uint8_t* valid_bytes = NULLPTR); Status AppendNull(); + void UnsafeAppend(const uint8_t* value) { + UnsafeAppendToBitmap(true); + byte_builder_.UnsafeAppend(value, byte_width_); + } + + void UnsafeAppend(util::string_view value) { +#ifndef NDEBUG + CheckValueSize(static_cast(value.size())); +#endif + UnsafeAppend(reinterpret_cast(value.data())); + } + + void UnsafeAppendNull() { + UnsafeAppendToBitmap(false); + byte_builder_.UnsafeAdvance(byte_width_); + } + void Reset() override; Status Resize(int64_t capacity) override; Status FinishInternal(std::shared_ptr* out) override; diff --git a/cpp/src/arrow/compute/api.h b/cpp/src/arrow/compute/api.h index b6e609a71e1b2..eb0e7897e2873 100644 --- a/cpp/src/arrow/compute/api.h +++ b/cpp/src/arrow/compute/api.h @@ -27,5 +27,6 @@ #include "arrow/compute/kernels/hash.h" // IWYU pragma: export #include "arrow/compute/kernels/mean.h" // IWYU pragma: export #include "arrow/compute/kernels/sum.h" // IWYU pragma: export +#include "arrow/compute/kernels/take.h" // IWYU pragma: export #endif // ARROW_COMPUTE_API_H diff --git a/cpp/src/arrow/compute/kernels/CMakeLists.txt b/cpp/src/arrow/compute/kernels/CMakeLists.txt index 5d78747bf93b9..abdc092a59037 100644 --- a/cpp/src/arrow/compute/kernels/CMakeLists.txt +++ b/cpp/src/arrow/compute/kernels/CMakeLists.txt @@ -20,6 +20,7 @@ arrow_install_all_headers("arrow/compute/kernels") add_arrow_test(boolean-test PREFIX "arrow-compute") add_arrow_test(cast-test PREFIX "arrow-compute") add_arrow_test(hash-test PREFIX "arrow-compute") +add_arrow_test(take-test PREFIX "arrow-compute") 
add_arrow_test(util-internal-test PREFIX "arrow-compute") # Aggregates diff --git a/cpp/src/arrow/compute/kernels/take-test.cc b/cpp/src/arrow/compute/kernels/take-test.cc new file mode 100644 index 0000000000000..110e07375a98b --- /dev/null +++ b/cpp/src/arrow/compute/kernels/take-test.cc @@ -0,0 +1,166 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include +#include + +#include "arrow/compute/context.h" +#include "arrow/compute/kernels/take.h" +#include "arrow/compute/test-util.h" +#include "arrow/testing/gtest_common.h" +#include "arrow/testing/gtest_util.h" +#include "arrow/testing/random.h" +#include "arrow/testing/util.h" + +namespace arrow { +namespace compute { + +using util::string_view; + +template +class TestTakeKernel : public ComputeFixture, public TestBase { + protected: + void AssertTakeArrays(const std::shared_ptr& values, + const std::shared_ptr& indices, TakeOptions options, + const std::shared_ptr& expected) { + std::shared_ptr actual; + ASSERT_OK(arrow::compute::Take(&this->ctx_, *values, *indices, options, &actual)); + AssertArraysEqual(*expected, *actual); + } + void AssertTake(const std::shared_ptr& type, const std::string& values, + const std::string& indices, TakeOptions options, + const std::string& expected) { + std::shared_ptr actual; + + for (auto index_type : {int8(), uint32()}) { + ASSERT_OK(this->Take(type, values, index_type, indices, options, &actual)); + AssertArraysEqual(*ArrayFromJSON(type, expected), *actual); + } + } + Status Take(const std::shared_ptr& type, const std::string& values, + const std::shared_ptr& index_type, const std::string& indices, + TakeOptions options, std::shared_ptr* out) { + return arrow::compute::Take(&this->ctx_, *ArrayFromJSON(type, values), + *ArrayFromJSON(index_type, indices), options, out); + } +}; + +class TestTakeKernelWithNull : public TestTakeKernel { + protected: + void AssertTake(const std::string& values, const std::string& indices, + TakeOptions options, const std::string& expected) { + TestTakeKernel::AssertTake(utf8(), values, indices, options, expected); + } +}; + +TEST_F(TestTakeKernelWithNull, TakeNull) { + TakeOptions options; + this->AssertTake("[null, null, null]", "[0, 1, 0]", options, "[null, null, null]"); + + std::shared_ptr arr; + ASSERT_RAISES(Invalid, this->Take(null(), "[null, null, null]", int8(), "[0, 9, 0]", + 
options, &arr)); +} + +class TestTakeKernelWithBoolean : public TestTakeKernel { + protected: + void AssertTake(const std::string& values, const std::string& indices, + TakeOptions options, const std::string& expected) { + TestTakeKernel::AssertTake(boolean(), values, indices, options, + expected); + } +}; + +TEST_F(TestTakeKernelWithBoolean, TakeBoolean) { + TakeOptions options; + this->AssertTake("[true, false, true]", "[0, 1, 0]", options, "[true, false, true]"); + this->AssertTake("[null, false, true]", "[0, 1, 0]", options, "[null, false, null]"); + this->AssertTake("[true, false, true]", "[null, 1, 0]", options, "[null, false, true]"); + + std::shared_ptr arr; + ASSERT_RAISES(Invalid, this->Take(boolean(), "[true, false, true]", int8(), "[0, 9, 0]", + options, &arr)); +} + +template +class TestTakeKernelWithNumeric : public TestTakeKernel { + protected: + void AssertTake(const std::string& values, const std::string& indices, + TakeOptions options, const std::string& expected) { + TestTakeKernel::AssertTake(type_singleton(), values, indices, options, + expected); + } + std::shared_ptr type_singleton() { + return TypeTraits::type_singleton(); + } +}; + +TYPED_TEST_CASE(TestTakeKernelWithNumeric, NumericArrowTypes); +TYPED_TEST(TestTakeKernelWithNumeric, TakeNumeric) { + TakeOptions options; + this->AssertTake("[7, 8, 9]", "[0, 1, 0]", options, "[7, 8, 7]"); + this->AssertTake("[null, 8, 9]", "[0, 1, 0]", options, "[null, 8, null]"); + this->AssertTake("[7, 8, 9]", "[null, 1, 0]", options, "[null, 8, 7]"); + + std::shared_ptr arr; + ASSERT_RAISES(Invalid, this->Take(this->type_singleton(), "[7, 8, 9]", int8(), + "[0, 9, 0]", options, &arr)); +} + +class TestTakeKernelWithString : public TestTakeKernel { + protected: + void AssertTake(const std::string& values, const std::string& indices, + TakeOptions options, const std::string& expected) { + TestTakeKernel::AssertTake(utf8(), values, indices, options, expected); + } + void AssertTakeDictionary(const 
std::string& dictionary_values, + const std::string& dictionary_indices, + const std::string& indices, TakeOptions options, + const std::string& expected_indices) { + auto type = dictionary(int8(), ArrayFromJSON(utf8(), dictionary_values)); + std::shared_ptr values, actual, expected; + ASSERT_OK(DictionaryArray::FromArrays(type, ArrayFromJSON(int8(), dictionary_indices), + &values)); + ASSERT_OK(DictionaryArray::FromArrays(type, ArrayFromJSON(int8(), expected_indices), + &expected)); + auto take_indices = ArrayFromJSON(int8(), indices); + this->AssertTakeArrays(values, take_indices, options, expected); + } +}; + +TEST_F(TestTakeKernelWithString, TakeString) { + TakeOptions options; + this->AssertTake(R"(["a", "b", "c"])", "[0, 1, 0]", options, R"(["a", "b", "a"])"); + this->AssertTake(R"([null, "b", "c"])", "[0, 1, 0]", options, "[null, \"b\", null]"); + this->AssertTake(R"(["a", "b", "c"])", "[null, 1, 0]", options, R"([null, "b", "a"])"); + + std::shared_ptr arr; + ASSERT_RAISES(Invalid, this->Take(utf8(), R"(["a", "b", "c"])", int8(), "[0, 9, 0]", + options, &arr)); +} + +TEST_F(TestTakeKernelWithString, TakeDictionary) { + TakeOptions options; + auto dict = R"(["a", "b", "c", "d", "e"])"; + this->AssertTakeDictionary(dict, "[3, 4, 2]", "[0, 1, 0]", options, "[3, 4, 3]"); + this->AssertTakeDictionary(dict, "[null, 4, 2]", "[0, 1, 0]", options, + "[null, 4, null]"); + this->AssertTakeDictionary(dict, "[3, 4, 2]", "[null, 1, 0]", options, "[null, 4, 3]"); +} + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/take.cc b/cpp/src/arrow/compute/kernels/take.cc new file mode 100644 index 0000000000000..1dd34a92449c3 --- /dev/null +++ b/cpp/src/arrow/compute/kernels/take.cc @@ -0,0 +1,216 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include +#include + +#include "arrow/builder.h" +#include "arrow/compute/context.h" +#include "arrow/compute/kernels/take.h" +#include "arrow/util/logging.h" +#include "arrow/visitor_inline.h" + +namespace arrow { +namespace compute { + +Status Take(FunctionContext* context, const Array& values, const Array& indices, + const TakeOptions& options, std::shared_ptr* out) { + Datum out_datum; + RETURN_NOT_OK( + Take(context, Datum(values.data()), Datum(indices.data()), options, &out_datum)); + *out = out_datum.make_array(); + return Status::OK(); +} + +Status Take(FunctionContext* context, const Datum& values, const Datum& indices, + const TakeOptions& options, Datum* out) { + TakeKernel kernel(values.type(), options); + RETURN_NOT_OK(kernel.Call(context, values, indices, out)); + return Status::OK(); +} + +struct TakeParameters { + FunctionContext* context; + std::shared_ptr values, indices; + TakeOptions options; + std::shared_ptr* out; +}; + +template +Status UnsafeAppend(Builder* builder, Scalar&& value) { + builder->UnsafeAppend(std::forward(value)); + return Status::OK(); +} + +Status UnsafeAppend(BinaryBuilder* builder, util::string_view value) { + RETURN_NOT_OK(builder->ReserveData(static_cast(value.size()))); + builder->UnsafeAppend(value); + return Status::OK(); +} + +Status UnsafeAppend(StringBuilder* builder, util::string_view value) { + 
RETURN_NOT_OK(builder->ReserveData(static_cast(value.size()))); + builder->UnsafeAppend(value); + return Status::OK(); +} + +template +Status TakeImpl(FunctionContext*, const ValueArray& values, const IndexArray& indices, + OutBuilder* builder) { + auto raw_indices = indices.raw_values(); + for (int64_t i = 0; i < indices.length(); ++i) { + if (!AllIndicesValid && indices.IsNull(i)) { + builder->UnsafeAppendNull(); + continue; + } + auto index = static_cast(raw_indices[i]); + if (index < 0 || index >= values.length()) { + return Status::Invalid("take index out of bounds"); + } + if (!AllValuesValid && values.IsNull(index)) { + builder->UnsafeAppendNull(); + continue; + } + RETURN_NOT_OK(UnsafeAppend(builder, values.GetView(index))); + } + return Status::OK(); +} + +template +Status UnpackIndicesNullCount(FunctionContext* context, const ValueArray& values, + const IndexArray& indices, OutBuilder* builder) { + if (indices.null_count() == 0) { + return TakeImpl(context, values, indices, builder); + } + return TakeImpl(context, values, indices, builder); +} + +template +Status UnpackValuesNullCount(FunctionContext* context, const ValueArray& values, + const IndexArray& indices, OutBuilder* builder) { + if (values.null_count() == 0) { + return UnpackIndicesNullCount(context, values, indices, builder); + } + return UnpackIndicesNullCount(context, values, indices, builder); +} + +template +struct UnpackValues { + using IndexArrayRef = const typename TypeTraits::ArrayType&; + + template + Status Visit(const ValueType&) { + using ValueArrayRef = const typename TypeTraits::ArrayType&; + using OutBuilder = typename TypeTraits::BuilderType; + IndexArrayRef indices = static_cast(*params_.indices); + ValueArrayRef values = static_cast(*params_.values); + std::unique_ptr builder; + RETURN_NOT_OK(MakeBuilder(params_.context->memory_pool(), values.type(), &builder)); + RETURN_NOT_OK(builder->Reserve(indices.length())); + RETURN_NOT_OK(UnpackValuesNullCount(params_.context, values, 
indices, + static_cast(builder.get()))); + return builder->Finish(params_.out); + } + + Status Visit(const NullType& t) { + auto indices_length = params_.indices->length(); + if (indices_length != 0) { + auto indices = static_cast(*params_.indices).raw_values(); + auto minmax = std::minmax_element(indices, indices + indices_length); + auto min = static_cast(*minmax.first); + auto max = static_cast(*minmax.second); + if (min < 0 || max >= params_.values->length()) { + return Status::Invalid("out of bounds index"); + } + } + params_.out->reset(new NullArray(indices_length)); + return Status::OK(); + } + + Status Visit(const DictionaryType& t) { + std::shared_ptr taken_indices; + { + // To take from a dictionary, apply the current kernel to the dictionary's + // indices. (Use UnpackValues since IndexType is already unpacked) + auto indices = static_cast(params_.values.get())->indices(); + TakeParameters params = params_; + params.values = indices; + params.out = &taken_indices; + UnpackValues unpack = {params}; + RETURN_NOT_OK(VisitTypeInline(*t.index_type(), &unpack)); + } + // create output dictionary from taken indices + return DictionaryArray::FromArrays(dictionary(t.index_type(), t.dictionary()), + taken_indices, params_.out); + } + + Status Visit(const ExtensionType& t) { + // XXX can we just take from its storage? 
+ return Status::NotImplemented("gathering values of type ", t); + } + + Status Visit(const UnionType& t) { + return Status::NotImplemented("gathering values of type ", t); + } + + Status Visit(const ListType& t) { + return Status::NotImplemented("gathering values of type ", t); + } + + Status Visit(const StructType& t) { + return Status::NotImplemented("gathering values of type ", t); + } + + const TakeParameters& params_; +}; + +struct UnpackIndices { + template + enable_if_integer Visit(const IndexType&) { + UnpackValues unpack = {params_}; + return VisitTypeInline(*params_.values->type(), &unpack); + } + + Status Visit(const DataType& other) { + return Status::Invalid("index type not supported: ", other); + } + + const TakeParameters& params_; +}; + +Status TakeKernel::Call(FunctionContext* ctx, const Datum& values, const Datum& indices, + Datum* out) { + if (!values.is_array() || !indices.is_array()) { + return Status::Invalid("TakeKernel expects array values and indices"); + } + std::shared_ptr out_array; + TakeParameters params; + params.context = ctx; + params.values = values.make_array(); + params.indices = indices.make_array(); + params.options = options_; + params.out = &out_array; + UnpackIndices unpack = {params}; + RETURN_NOT_OK(VisitTypeInline(*indices.type(), &unpack)); + *out = Datum(out_array); + return Status::OK(); +} + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/take.h b/cpp/src/arrow/compute/kernels/take.h new file mode 100644 index 0000000000000..bfd69112786aa --- /dev/null +++ b/cpp/src/arrow/compute/kernels/take.h @@ -0,0 +1,83 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include + +#include "arrow/compute/kernel.h" +#include "arrow/status.h" +#include "arrow/util/visibility.h" + +namespace arrow { + +class Array; + +namespace compute { + +class FunctionContext; + +struct ARROW_EXPORT TakeOptions {}; + +/// \brief Take from an array of values at indices in another array +/// +/// The output array will be of the same type as the input values +/// array, with elements taken from the values array at the given +/// indices. If an index is null then the taken element will be null. 
+/// +/// For example given values = ["a", "b", "c", null, "e", "f"] and +/// indices = [2, 1, null, 3], the output will be +/// = [values[2], values[1], null, values[3]] +/// = ["c", "b", null, null] +/// +/// \param[in] context the FunctionContext +/// \param[in] values array from which to take +/// \param[in] indices which values to take +/// \param[in] options options +/// \param[out] out resulting array +ARROW_EXPORT +Status Take(FunctionContext* context, const Array& values, const Array& indices, + const TakeOptions& options, std::shared_ptr* out); + +/// \brief Take from an array of values at indices in another array +/// +/// \param[in] context the FunctionContext +/// \param[in] values datum from which to take +/// \param[in] indices which values to take +/// \param[in] options options +/// \param[out] out resulting datum +ARROW_EXPORT +Status Take(FunctionContext* context, const Datum& values, const Datum& indices, + const TakeOptions& options, Datum* out); + +/// \brief BinaryKernel implementing Take operation +class ARROW_EXPORT TakeKernel : public BinaryKernel { + public: + explicit TakeKernel(const std::shared_ptr& type, TakeOptions options = {}) + : type_(type), options_(options) {} + + Status Call(FunctionContext* ctx, const Datum& values, const Datum& indices, + Datum* out) override; + + std::shared_ptr out_type() const override { return type_; } + + private: + std::shared_ptr type_; + TakeOptions options_; +}; +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/type_traits.h b/cpp/src/arrow/type_traits.h index 0cea58483f4ab..263916fd639b6 100644 --- a/cpp/src/arrow/type_traits.h +++ b/cpp/src/arrow/type_traits.h @@ -304,9 +304,9 @@ using enable_if_primitive_ctype = template using enable_if_date = typename std::enable_if::value>::type; -template +template using enable_if_integer = - typename std::enable_if::value>::type; + typename std::enable_if::value, U>::type; template using enable_if_signed_integer =