diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index c2a853f11bf97..3af37dc5eabb2 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -333,6 +333,7 @@ if(ARROW_COMPUTE) compute/kernels/add.cc compute/kernels/take.cc compute/kernels/isin.cc + compute/kernels/match.cc compute/kernels/util_internal.cc compute/operations/cast.cc compute/operations/literal.cc) diff --git a/cpp/src/arrow/array/dict_internal.h b/cpp/src/arrow/array/dict_internal.h index f950461dbe23d..1a21ed2a86613 100644 --- a/cpp/src/arrow/array/dict_internal.h +++ b/cpp/src/arrow/array/dict_internal.h @@ -15,6 +15,8 @@ // specific language governing permissions and limitations // under the License. +#pragma once + #include "arrow/array/builder_dict.h" #include diff --git a/cpp/src/arrow/compute/kernels/CMakeLists.txt b/cpp/src/arrow/compute/kernels/CMakeLists.txt index a54fd835193e5..8ce28df39756a 100644 --- a/cpp/src/arrow/compute/kernels/CMakeLists.txt +++ b/cpp/src/arrow/compute/kernels/CMakeLists.txt @@ -21,6 +21,7 @@ add_arrow_test(boolean_test PREFIX "arrow-compute") add_arrow_test(cast_test PREFIX "arrow-compute") add_arrow_test(hash_test PREFIX "arrow-compute") add_arrow_test(isin_test PREFIX "arrow-compute") +add_arrow_test(match_test PREFIX "arrow-compute") add_arrow_test(sort_to_indices_test PREFIX "arrow-compute") add_arrow_test(util_internal_test PREFIX "arrow-compute") add_arrow_test(add-test PREFIX "arrow-compute") diff --git a/cpp/src/arrow/compute/kernels/match.cc b/cpp/src/arrow/compute/kernels/match.cc new file mode 100644 index 0000000000000..4698a18c62f36 --- /dev/null +++ b/cpp/src/arrow/compute/kernels/match.cc @@ -0,0 +1,281 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/compute/kernels/match.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "arrow/array.h" +#include "arrow/array/dict_internal.h" +#include "arrow/buffer.h" +#include "arrow/builder.h" +#include "arrow/compute/context.h" +#include "arrow/compute/kernel.h" +#include "arrow/compute/kernels/util_internal.h" +#include "arrow/memory_pool.h" +#include "arrow/type.h" +#include "arrow/type_traits.h" +#include "arrow/util/bit_util.h" +#include "arrow/util/checked_cast.h" +#include "arrow/util/hashing.h" +#include "arrow/util/logging.h" +#include "arrow/util/macros.h" +#include "arrow/util/string_view.h" +#include "arrow/visitor_inline.h" + +namespace arrow { + +using internal::checked_cast; +using internal::DictionaryTraits; +using internal::HashTraits; + +namespace compute { + +// Interface shared by the match kernels: the output type is int32 and the +// needles are handed over through Init(). +class MatchKernelImpl : public UnaryKernel { + public: + std::shared_ptr out_type() const override { return int32(); } + + virtual Status Init(const Datum& needles) = 0; +}; + +// Match kernel backed by a memo table (needles_table_) that Init() fills +// from the needles and Call() queries for every haystack slot. +template +class MatchKernel : public MatchKernelImpl { + public: + MatchKernel(std::shared_ptr type, MemoryPool* pool) + : type_(std::move(type)), pool_(pool) {} + + // Appends, for every haystack slot, the needles_table_ index of the slot's + // value, or null when the value is not among the needles. + Status Call(FunctionContext* ctx, const Datum& haystack, Datum* out) override { + if (!haystack.is_arraylike()) { + return Status::Invalid("Haystack input to match kernel was not array-like"); + } + +
Int32Builder indices_builder; + RETURN_NOT_OK(indices_builder.Reserve(haystack.length())); + + auto lookup_value = [&](util::optional v) { + // Query needles_table_ once per slot so that a matching value is not + // hashed twice; -1 means the value (or null) is not among the needles. + const int32_t index = + v.has_value() ? needles_table_->Get(*v) : needles_table_->GetNull(); + if (index != -1) { + // matching needle (or null needle); output index from needles_table_ + indices_builder.UnsafeAppend(index); + } else { + // no matching needle; output null + indices_builder.UnsafeAppendNull(); + } + }; + + // A single array is visited directly, a chunked array chunk by chunk; the + // indices of all chunks are appended to the same builder. + if (haystack.kind() == Datum::ARRAY) { + VisitArrayDataInline(*haystack.array(), lookup_value); + } + + if (haystack.kind() == Datum::CHUNKED_ARRAY) { + for (const auto& chunk : haystack.chunked_array()->chunks()) { + VisitArrayDataInline(*chunk->data(), lookup_value); + } + } + + std::shared_ptr out_data; + RETURN_NOT_OK(indices_builder.FinishInternal(&out_data)); + out->value = std::move(out_data); + return Status::OK(); + } + + // Builds needles_table_ from every needle, null included, so that Call() + // can look haystack values up in it. + Status Init(const Datum& needles) override { + if (!needles.is_arraylike()) { + return Status::Invalid("Needles input to match kernel was not array-like"); + } + + needles_table_.reset(new MemoTable(pool_, 0)); + + auto insert_value = [&](util::optional v) { + if (v.has_value()) { + int32_t unused_memo_index; + return needles_table_->GetOrInsert(*v, &unused_memo_index); + } + needles_table_->GetOrInsertNull(); + return Status::OK(); + }; + + if (needles.kind() == Datum::ARRAY) { + return VisitArrayDataInline(*needles.array(), insert_value); + } + + for (const auto& chunk : needles.chunked_array()->chunks()) { + RETURN_NOT_OK(VisitArrayDataInline(*chunk->data(), insert_value)); + } + return Status::OK(); + } + + protected: + using MemoTable = typename HashTraits::MemoTableType; + std::unique_ptr
needles_table_; + std::shared_ptr type_; + MemoryPool* pool_; +}; + +// ---------------------------------------------------------------------- +// (NullType has a separate implementation) + +// Match kernel that needs no hash table: the result depends only on whether +// any needles were supplied. +class NullMatchKernel : public MatchKernelImpl { + public: + // The arguments mirror MatchKernel's constructor (GetMatchKernel builds + // every kernel with the same argument list) and are not used here. + NullMatchKernel(const std::shared_ptr& type, MemoryPool* pool) {} + + // Outputs null for every haystack slot when no needles were recorded, + // otherwise index 0 for every slot. + Status Call(FunctionContext* ctx, const Datum& haystack, Datum* out) override { + if (!haystack.is_arraylike()) { + return Status::Invalid("Haystack input to match kernel was not array-like"); + } + + Int32Builder indices_builder; + if (haystack.length() != 0) { + if (needles_null_count_ == 0) { + RETURN_NOT_OK(indices_builder.AppendNulls(haystack.length())); + } else { + RETURN_NOT_OK(indices_builder.Reserve(haystack.length())); + + for (int64_t i = 0; i < haystack.length(); ++i) { + indices_builder.UnsafeAppend(0); + } + } + } + + std::shared_ptr out_data; + RETURN_NOT_OK(indices_builder.FinishInternal(&out_data)); + out->value = std::move(out_data); + return Status::OK(); + } + + // Records the needles' length as their null count. + // NOTE(review): presumes every slot of a null-typed array is null -- confirm. + Status Init(const Datum& needles) override { + if (!needles.is_arraylike()) { + return Status::Invalid("Needles input to match kernel was not array-like"); + } + + needles_null_count_ = needles.length(); + return Status::OK(); + } + + private: + int64_t needles_null_count_{}; +}; + +// ---------------------------------------------------------------------- +// Kernel wrapper for generic hash table kernels +// +// MatchKernelTraits names, through its MatchKernelImpl alias, the kernel class +// that GetMatchKernel instantiates for a type: the first specialization +// selects NullMatchKernel, the remaining ones select MatchKernel. + +template +struct MatchKernelTraits; + +template <> +struct MatchKernelTraits { + using MatchKernelImpl = NullMatchKernel; +}; + +template +struct MatchKernelTraits> { + using MatchKernelImpl = MatchKernel; +}; + +template <> +struct MatchKernelTraits { + using MatchKernelImpl = MatchKernel; +}; + +template +struct MatchKernelTraits> { + using MatchKernelImpl = MatchKernel; +}; + +template +struct MatchKernelTraits> { + using MatchKernelImpl = MatchKernel; +}; + +Status GetMatchKernel(FunctionContext* ctx, const std::shared_ptr& type, + const Datum& needles,
std::unique_ptr* out) { + std::unique_ptr kernel; + +#define MATCH_CASE(InType) \ + case InType::type_id: \ + kernel.reset(new typename MatchKernelTraits::MatchKernelImpl( \ + type, ctx->memory_pool())); \ + break + + switch (type->id()) { + MATCH_CASE(NullType); + MATCH_CASE(BooleanType); + MATCH_CASE(UInt8Type); + MATCH_CASE(Int8Type); + MATCH_CASE(UInt16Type); + MATCH_CASE(Int16Type); + MATCH_CASE(UInt32Type); + MATCH_CASE(Int32Type); + MATCH_CASE(UInt64Type); + MATCH_CASE(Int64Type); + MATCH_CASE(FloatType); + MATCH_CASE(DoubleType); + MATCH_CASE(Date32Type); + MATCH_CASE(Date64Type); + MATCH_CASE(Time32Type); + MATCH_CASE(Time64Type); + MATCH_CASE(TimestampType); + MATCH_CASE(BinaryType); + MATCH_CASE(StringType); + MATCH_CASE(FixedSizeBinaryType); + MATCH_CASE(Decimal128Type); + default: + break; + } +#undef MATCH_CASE + + if (!kernel) { + return Status::NotImplemented("Match is not implemented for ", type->ToString()); + } + RETURN_NOT_OK(kernel->Init(needles)); + *out = std::move(kernel); + return Status::OK(); +} + +Status Match(FunctionContext* ctx, const Datum& haystack, const Datum& needles, + Datum* out) { + DCHECK(haystack.type()->Equals(needles.type())); + std::vector outputs; + std::unique_ptr kernel; + + RETURN_NOT_OK(GetMatchKernel(ctx, haystack.type(), needles, &kernel)); + RETURN_NOT_OK(detail::InvokeUnaryArrayKernel(ctx, kernel.get(), haystack, &outputs)); + + *out = detail::WrapDatumsLike(haystack, outputs); + return Status::OK(); +} + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/match.h b/cpp/src/arrow/compute/kernels/match.h new file mode 100644 index 0000000000000..1251ca88d255b --- /dev/null +++ b/cpp/src/arrow/compute/kernels/match.h @@ -0,0 +1,57 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include + +#include "arrow/array.h" +#include "arrow/compute/context.h" +#include "arrow/compute/kernel.h" +#include "arrow/status.h" +#include "arrow/type.h" +#include "arrow/util/visibility.h" + +namespace arrow { +namespace compute { + +/// \brief Match examines each slot in the haystack against a needles array. +/// If the value is not found in needles, null will be output. +/// If found, the index of occurrence within needles (ignoring duplicates) +/// will be output. +/// +/// For example given haystack = [99, 42, 3, null] and +/// needles = [3, 3, 99], the output will be = [1, null, 0, null] +/// +/// Note: Null in the haystack is considered to match +/// a null in the needles array.
For example given +/// haystack = [99, 42, 3, null] and needles = [3, 99, null], +/// the output will be = [1, null, 0, 2] +/// +/// \param[in] context the FunctionContext +/// \param[in] haystack array-like input whose slots are looked up in needles +/// \param[in] needles array-like input of the same type as haystack +/// \param[out] out resulting datum of int32 indices (null where nothing matched) +/// +/// \since 1.0.0 +/// \note API not yet finalized +ARROW_EXPORT +Status Match(FunctionContext* context, const Datum& haystack, const Datum& needles, + Datum* out); + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/match_test.cc b/cpp/src/arrow/compute/kernels/match_test.cc new file mode 100644 index 0000000000000..2103eaa5ae644 --- /dev/null +++ b/cpp/src/arrow/compute/kernels/match_test.cc @@ -0,0 +1,389 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License.
+ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "arrow/array.h" +#include "arrow/buffer.h" +#include "arrow/compute/context.h" +#include "arrow/compute/kernel.h" +#include "arrow/compute/kernels/match.h" +#include "arrow/compute/kernels/util_internal.h" +#include "arrow/compute/test_util.h" +#include "arrow/memory_pool.h" +#include "arrow/status.h" +#include "arrow/table.h" +#include "arrow/testing/gtest_common.h" +#include "arrow/testing/random.h" +#include "arrow/testing/util.h" +#include "arrow/type.h" +#include "arrow/type_traits.h" +#include "arrow/util/decimal.h" + +namespace arrow { +namespace compute { + +// ---------------------------------------------------------------------- +// Match tests + +// Fixture shared by all Match tests. +class TestMatchKernel : public ComputeFixture, public TestBase { + public: + // Parses the haystack and the needles (both of `type`) and the expected int32 + // indices from JSON, runs Match and asserts that the result equals the + // expectation. + void CheckMatch(const std::shared_ptr& type, const std::string& haystack_json, + const std::string& needles_json, const std::string& expected_json) { + std::shared_ptr haystack = ArrayFromJSON(type, haystack_json); + std::shared_ptr needles = ArrayFromJSON(type, needles_json); + std::shared_ptr expected = ArrayFromJSON(int32(), expected_json); + + Datum actual_datum; + ASSERT_OK(Match(&this->ctx_, haystack, needles, &actual_datum)); + std::shared_ptr actual = actual_datum.make_array(); + ASSERT_ARRAYS_EQUAL(*expected, *actual); + } +}; + +template +class TestMatchKernelPrimitive : public TestMatchKernel {}; + +using PrimitiveDictionaries = + ::testing::Types; + +TYPED_TEST_CASE(TestMatchKernelPrimitive, PrimitiveDictionaries); + +TYPED_TEST(TestMatchKernelPrimitive, Match) { + auto type = TypeTraits::type_singleton(); + + // No nulls; the duplicated needle 2 does not consume an index, so 3 maps to 2 + this->CheckMatch(type, + /* haystack= */ "[2, 1, 2, 1, 2, 3]", + /* needles= */ "[2, 1, 2, 3]", + /* expected= */ "[0, 1, 0, 1, 0, 2]"); + + // Haystack array all null + this->CheckMatch(type, + /* haystack= */ "[null, null, null, null, null, null]", + /* needles= */ "[2, 1, 3]", + /*
expected= */ "[null, null, null, null, null, null]"); + + // Needles array all null + this->CheckMatch(type, + /* haystack= */ "[2, 1, 2, 1, 2, 3]", + /* needles= */ "[null, null, null, null]", + /* expected= */ "[null, null, null, null, null, null]"); + + // Both arrays all null + this->CheckMatch(type, + /* haystack= */ "[null, null, null, null]", + /* needles= */ "[null, null]", + /* expected= */ "[0, 0, 0, 0]"); + + // Mix of matches, misses and nulls; the needles contain duplicates + this->CheckMatch(type, + /* haystack= */ "[2, null, 7, 3, 8]", + /* needles= */ "[2, null, 2, null, 6, 3, 3]", + /* expected= */ "[0, 1, null, 3, null]"); + + // Empty Arrays + this->CheckMatch(type, "[]", "[]", "[]"); +} + +// Matches an array of repeated values against itself and expects every slot +// to map to the index of the first occurrence of its value. +TYPED_TEST(TestMatchKernelPrimitive, PrimitiveResizeTable) { + using T = typename TypeParam::c_type; + + // NOTE(review): `/` binds tighter than `<<`, so this evaluates to + // 1 << (sizeof(T) / 2), i.e. at most 16 distinct values for an 8-byte T -- + // confirm that this is enough to make the memo table resize. + const int64_t kTotalValues = std::min(INT16_MAX, 1UL << sizeof(T) / 2); + const int64_t kRepeats = 5; + + Int32Builder expected_builder; + NumericBuilder haystack_builder; + ASSERT_OK(expected_builder.Resize(kTotalValues * kRepeats)); + ASSERT_OK(haystack_builder.Resize(kTotalValues * kRepeats)); + + for (int64_t i = 0; i < kTotalValues * kRepeats; i++) { + const auto index = i % kTotalValues; + + haystack_builder.UnsafeAppend(static_cast(index)); + expected_builder.UnsafeAppend(static_cast(index)); + } + + std::shared_ptr haystack, needles, expected; + ASSERT_OK(haystack_builder.Finish(&haystack)); + needles = haystack; + ASSERT_OK(expected_builder.Finish(&expected)); + + Datum actual_datum; + ASSERT_OK(Match(&this->ctx_, haystack, needles, &actual_datum)); + std::shared_ptr actual = actual_datum.make_array(); + ASSERT_ARRAYS_EQUAL(*expected, *actual); +} + +// Null-typed input: every slot maps to index 0 as soon as there is at least +// one needle, and to null when the needles are empty. +TEST_F(TestMatchKernel, MatchNull) { + CheckMatch(null(), "[null, null, null]", "[null, null]", "[0, 0, 0]"); + + CheckMatch(null(), "[null, null, null]", "[]", "[null, null, null]"); + + CheckMatch(null(), "[]", "[null, null]", "[]"); + + CheckMatch(null(), "[]", "[]", "[]"); +} + +TEST_F(TestMatchKernel, MatchTimeTimestamp) { +
CheckMatch(time32(TimeUnit::SECOND), + /* haystack= */ "[1, null, 5, 1, 2]", + /* needles= */ "[2, 1, null, 1]", + /* expected= */ "[1, 2, null, 1, 0]"); + + // Needles array has no nulls + CheckMatch(time32(TimeUnit::SECOND), + /* haystack= */ "[2, null, 5, 1]", + /* needles= */ "[2, 1, 1]", + /* expected= */ "[0, null, null, 1]"); + + // No match + CheckMatch(time32(TimeUnit::SECOND), "[3, null, 5, 3]", "[2, 1, 2, 1, 2]", + "[null, null, null, null]"); + + // Empty arrays + CheckMatch(time32(TimeUnit::SECOND), "[]", "[]", "[]"); + + // time64 and timestamp inputs, with a null among the needles + CheckMatch(time64(TimeUnit::NANO), "[2, null, 2, 1]", "[2, null, 1]", "[0, 1, 0, 2]"); + + CheckMatch(timestamp(TimeUnit::NANO), "[2, null, 2, 1]", "[2, null, 2, 1]", + "[0, 1, 0, 2]"); + + // Empty haystack array + CheckMatch(timestamp(TimeUnit::NANO), "[]", "[2, null, 2, 1]", "[]"); + + // Empty needles array + CheckMatch(timestamp(TimeUnit::NANO), "[2, null, 2, 1]", "[]", + "[null, null, null, null]"); + + // Both arrays are all null + CheckMatch(time32(TimeUnit::SECOND), "[null, null, null, null]", "[null, null]", + "[0, 0, 0, 0]"); +} + +TEST_F(TestMatchKernel, MatchBoolean) { + CheckMatch(boolean(), + /* haystack= */ "[false, null, false, true]", + /* needles= */ "[null, false, true]", + /* expected= */ "[1, 0, 1, 2]"); + + CheckMatch(boolean(), "[false, null, false, true]", "[false, true, null, true, null]", + "[0, 2, 0, 1]"); + + // No nulls + CheckMatch(boolean(), "[true, true, false, true]", "[false, true]", "[1, 1, 0, 1]"); + + CheckMatch(boolean(), "[false, true, false, true]", "[true, true, true, true]", + "[null, 0, null, 0]"); + + // No match + CheckMatch(boolean(), "[true, true, true, true]", "[false, false, false]", + "[null, null, null, null]"); + + // Haystack array all null + CheckMatch(boolean(), "[null, null, null, null]", "[true, true]", + "[null, null, null, null]"); + + // Needles array all null + CheckMatch(boolean(), "[true, true, false, true]", + "[null, null, null, null, null, null]", "[null, null, null,
null]"); + + // Both arrays are all null + CheckMatch(boolean(), "[null, null, null, null]", "[null, null, null, null]", + "[0, 0, 0, 0]"); +} + +template +class TestMatchKernelBinary : public TestMatchKernel {}; + +using BinaryTypes = ::testing::Types; +TYPED_TEST_CASE(TestMatchKernelBinary, BinaryTypes); + +TYPED_TEST(TestMatchKernelBinary, MatchBinary) { + auto type = TypeTraits::type_singleton(); + this->CheckMatch(type, R"(["foo", null, "bar", "foo"])", R"(["foo", null, "bar"])", + R"([0, 1, 2, 0])"); + + // No match + this->CheckMatch(type, + /* haystack= */ R"(["foo", null, "bar", "foo"])", + /* needles= */ R"(["baz", "bazzz", "baz", "bazzz"])", + /* expected= */ R"([null, null, null, null])"); + + // Haystack array all null + this->CheckMatch(type, + /* haystack= */ R"([null, null, null, null])", + /* needles= */ R"(["foo", "bar", "foo"])", + /* expected= */ R"([null, null, null, null])"); + + // Needles array all null + this->CheckMatch(type, R"(["foo", "bar", "foo"])", R"([null, null, null])", + R"([null, null, null])"); + + // Both arrays are all null + this->CheckMatch(type, + /* haystack= */ R"([null, null, null, null])", + /* needles= */ R"([null, null, null, null])", + /* expected= */ R"([0, 0, 0, 0])"); + + // Empty arrays + this->CheckMatch(type, R"([])", R"([])", R"([])"); + + // Empty haystack array + this->CheckMatch(type, R"([])", R"(["foo", null, "bar", null])", "[]"); + + // Empty needles array + this->CheckMatch(type, R"(["foo", null, "bar", "foo"])", "[]", + R"([null, null, null, null])"); +} + +// Matches 10000 distinct strings, each repeated kRepeats times, against +// themselves and expects every slot to map to the index of its first occurrence. +TEST_F(TestMatchKernel, BinaryResizeTable) { + const int32_t kTotalValues = 10000; +#if !defined(ARROW_VALGRIND) + const int32_t kRepeats = 10; +#else + // Mitigate Valgrind's slowness + const int32_t kRepeats = 3; +#endif + + const int32_t kBufSize = 20; + + Int32Builder expected_builder; + StringBuilder haystack_builder; + ASSERT_OK(expected_builder.Resize(kTotalValues * kRepeats)); + ASSERT_OK(haystack_builder.Resize(kTotalValues * kRepeats)); +
ASSERT_OK(haystack_builder.ReserveData(kBufSize * kTotalValues * kRepeats)); + + for (int32_t i = 0; i < kTotalValues * kRepeats; i++) { + int32_t index = i % kTotalValues; + + // "test" followed by the decimal index, e.g. "test42" + char buf[kBufSize] = "test"; + ASSERT_GE(snprintf(buf + 4, sizeof(buf) - 4, "%d", index), 0); + + haystack_builder.UnsafeAppend(util::string_view(buf)); + expected_builder.UnsafeAppend(index); + } + + std::shared_ptr haystack, needles, expected; + ASSERT_OK(haystack_builder.Finish(&haystack)); + needles = haystack; + ASSERT_OK(expected_builder.Finish(&expected)); + + Datum actual_datum; + ASSERT_OK(Match(&this->ctx_, haystack, needles, &actual_datum)); + std::shared_ptr actual = actual_datum.make_array(); + ASSERT_ARRAYS_EQUAL(*expected, *actual); +} + +TEST_F(TestMatchKernel, MatchFixedSizeBinary) { + CheckMatch(fixed_size_binary(5), + /* haystack= */ R"(["bbbbb", null, "aaaaa", "ccccc"])", + /* needles= */ R"(["bbbbb", null, "bbbbb", "aaaaa", "ccccc"])", + /* expected= */ R"([0, 1, 2, 3])"); + + // Haystack array all null + CheckMatch(fixed_size_binary(5), + /* haystack= */ R"([null, null, null, null, null])", + /* needles= */ R"(["bbbbb", "aabbb", "bbbbb", "aaaaa", "ccccc"])", + /* expected= */ R"([null, null, null, null, null])"); + + // Needles array all null + CheckMatch(fixed_size_binary(5), + /* haystack= */ R"(["bbbbb", null, "bbbbb", "aaaaa", "ccccc"])", + /* needles= */ R"([null, null, null])", + /* expected= */ R"([null, 0, null, null, null])"); + + // Both arrays are all null + CheckMatch(fixed_size_binary(5), + /* haystack= */ R"([null, null, null, null, null])", + /* needles= */ R"([null, null, null, null])", + /* expected= */ R"([0, 0, 0, 0, 0])"); + + // No match + CheckMatch(fixed_size_binary(5), + /* haystack= */ R"(["bbbbc", "bbbbc", "aaaad", "cccca"])", + /* needles= */ R"(["bbbbb", null, "bbbbb", "aaaaa", "ddddd"])", + /* expected= */ R"([null, null, null, null])"); + + // Empty haystack array + CheckMatch(fixed_size_binary(5), R"([])", + R"(["bbbbb", null, "bbbbb", "aaaaa",
"ccccc"])", R"([])"); + + // Empty needles array + CheckMatch(fixed_size_binary(5), R"(["bbbbb", null, "bbbbb", "aaaaa", "ccccc"])", + R"([])", R"([null, null, null, null, null])"); + + // Empty arrays + CheckMatch(fixed_size_binary(0), R"([])", R"([])", R"([])"); +} + +TEST_F(TestMatchKernel, MatchDecimal) { + // The null needle takes index 1 and the repeated "12" keeps index 0. + CheckMatch(decimal(2, 0), + /* haystack= */ R"(["12", null, "11", "12"])", + /* needles= */ R"(["12", null, "11", "12"])", + /* expected= */ R"([0, 1, 2, 0])"); +} + +// Chunked haystack and chunked needles: needle indices run across the needle +// chunks and the output is chunked like the haystack. +TEST_F(TestMatchKernel, MatchChunkedArrayInvoke) { + std::vector values1 = {"foo", "bar", "foo"}; + std::vector values2 = {"bar", "baz", "quuux", "foo"}; + std::vector values3 = {"foo", "bar", "foo"}; + std::vector values4 = {"bar", "baz", "barr", "foo"}; + + auto type = utf8(); + auto a1 = _MakeArray(type, values1, {}); + auto a2 = _MakeArray(type, values2, {true, true, true, false}); + auto a3 = _MakeArray(type, values3, {}); + auto a4 = _MakeArray(type, values4, {}); + + ArrayVector array1 = {a1, a2}; + auto carr = std::make_shared(array1); + ArrayVector array2 = {a3, a4}; + auto member_set = std::make_shared(array2); + + auto i1 = _MakeArray(int32(), {0, 1, 0}, {}); + auto i2 = + _MakeArray(int32(), {1, 2, 2, 2}, {true, true, false, false}); + + ArrayVector expected = {i1, i2}; + auto expected_carr = std::make_shared(expected); + + Datum encoded_out; + ASSERT_OK(Match(&this->ctx_, carr, member_set, &encoded_out)); + ASSERT_EQ(Datum::CHUNKED_ARRAY, encoded_out.kind()); + + AssertChunkedEqual(*expected_carr, *encoded_out.chunked_array()); +} + +} // namespace compute +} // namespace arrow