From b010b2f085abb584aece03275c374ebfd42cb572 Mon Sep 17 00:00:00 2001 From: niranda perera Date: Thu, 20 May 2021 14:22:48 -0400 Subject: [PATCH 01/39] adding basic structure --- cpp/src/arrow/CMakeLists.txt | 1 + cpp/src/arrow/compute/api_scalar.cc | 5 + cpp/src/arrow/compute/api_scalar.h | 16 ++ cpp/src/arrow/compute/kernels/CMakeLists.txt | 1 + .../arrow/compute/kernels/scalar_if_else.cc | 137 ++++++++++++++++++ .../compute/kernels/scalar_if_else_test.cc | 17 +++ cpp/src/arrow/compute/registry.cc | 1 + cpp/src/arrow/compute/registry_internal.h | 1 + 8 files changed, 179 insertions(+) create mode 100644 cpp/src/arrow/compute/kernels/scalar_if_else.cc create mode 100644 cpp/src/arrow/compute/kernels/scalar_if_else_test.cc diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index 1d832cc25a21c..f6d5a540c9891 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -399,6 +399,7 @@ if(ARROW_COMPUTE) compute/kernels/scalar_string.cc compute/kernels/scalar_validity.cc compute/kernels/scalar_fill_null.cc + compute/kernels/scalar_if_else.cc compute/kernels/util_internal.cc compute/kernels/vector_hash.cc compute/kernels/vector_nested.cc diff --git a/cpp/src/arrow/compute/api_scalar.cc b/cpp/src/arrow/compute/api_scalar.cc index 9f4ad42fecbd2..105ba7a0589af 100644 --- a/cpp/src/arrow/compute/api_scalar.cc +++ b/cpp/src/arrow/compute/api_scalar.cc @@ -157,5 +157,10 @@ Result FillNull(const Datum& values, const Datum& fill_value, ExecContext return CallFunction("fill_null", {values, fill_value}, ctx); } +Result IfElse(const Datum& cond, const Datum& if_true, const Datum& if_false, + ExecContext* ctx) { + return CallFunction("if_else", {cond, if_true, if_false}, ctx); +} + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/api_scalar.h b/cpp/src/arrow/compute/api_scalar.h index dce420b32b24d..c4b70b3998998 100644 --- a/cpp/src/arrow/compute/api_scalar.h +++ b/cpp/src/arrow/compute/api_scalar.h @@ -465,5 +465,21 @@ ARROW_EXPORT Result FillNull(const Datum& values, const Datum& fill_value, ExecContext* ctx = NULLPTR); +/// \brief IfElse returns elements chosen from `left` or `right` +/// depending on `cond`. `Null` values would be promoted to the result +/// +/// \param[in] cond `BooleanArray` condition array +/// \param[in] left scalar/ Array +/// \param[in] right scalar/ Array +/// \param[in] ctx the function execution context, optional +/// +/// \return the resulting datum +/// +/// \since x.x.x +/// \note API not yet finalized +ARROW_EXPORT +Result IfElse(const Datum& cond, const Datum& left, const Datum& right, + ExecContext* ctx = NULLPTR); + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/CMakeLists.txt b/cpp/src/arrow/compute/kernels/CMakeLists.txt index 5e223a1f906f0..fc11d14410524 100644 --- a/cpp/src/arrow/compute/kernels/CMakeLists.txt +++ b/cpp/src/arrow/compute/kernels/CMakeLists.txt @@ -29,6 +29,7 @@ add_arrow_compute_test(scalar_test scalar_string_test.cc scalar_validity_test.cc scalar_fill_null_test.cc + scalar_if_else_test.cc test_util.cc) add_arrow_benchmark(scalar_arithmetic_benchmark PREFIX "arrow-compute") diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else.cc b/cpp/src/arrow/compute/kernels/scalar_if_else.cc new file mode 100644 index 0000000000000..dc461b6a3d4be --- /dev/null +++ b/cpp/src/arrow/compute/kernels/scalar_if_else.cc @@ -0,0 +1,137 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include +#include + +#include "codegen_internal.h" + +namespace arrow { +namespace compute { + +namespace { + +template +struct IfElseFunctor { + static Status Call(KernelContext* ctx, const ArrayData& cond, const ArrayData& left, + const ArrayData& right, ArrayData* out) { + return Status::OK(); + } + + static Status Call(KernelContext* ctx, const ArrayData& cond, const ArrayData& left, + const Scalar& right, ArrayData* out) { + return Status::OK(); + } + + static Status Call(KernelContext* ctx, const Scalar& cond, const Scalar& left, + const Scalar& right, Scalar* out) { + return Status::OK(); + } +}; + +template +struct ResolveExec { + static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { + if (batch.length == 0) return Status::OK(); + + if (batch[0].kind() == Datum::ARRAY) { + if (batch[1].kind() == Datum::ARRAY) { + if (batch[2].kind() == Datum::ARRAY) { // AAA + return IfElseFunctor::Call(ctx, *batch[0].array(), *batch[1].array(), + *batch[2].array(), out->mutable_array()); + } else { // AAS + return IfElseFunctor::Call(ctx, *batch[0].array(), *batch[1].array(), + *batch[2].scalar(), out->mutable_array()); + } + } else { + return Status::Invalid(""); + // if (batch[2].kind() == Datum::ARRAY) { // ASA + // return IfElseFunctor::Call(ctx, *batch[0].array(), + // *batch[2].array(), + // *batch[1].scalar(), + // out->mutable_array()); + // } else { // ASS + // return IfElseFunctor::Call(ctx, *batch[0].array(), + // *batch[1].scalar(), + // *batch[2].scalar(), + // out->mutable_array()); + // } + } + } else { // when cond is scalar, output will also be scalar + if (batch[1].kind() == Datum::ARRAY) { + return Status::Invalid(""); + // if (batch[2].kind() == Datum::ARRAY) { // SAA + // return IfElseFunctor::Call(ctx, *batch[0].scalar(), + // *batch[1].array(), + // *batch[2].array(), + // out->scalar().get()); + // } else { // SAS + // return IfElseFunctor::Call(ctx, *batch[0].scalar(), + // *batch[1].array(), + // *batch[2].scalar(), + // out->scalar().get()); + // } + } else { + if (batch[2].kind() == Datum::ARRAY) { // SSA + return Status::Invalid(""); + // return IfElseFunctor::Call(ctx, *batch[0].scalar(), + // *batch[1].scalar(), + // *batch[2].array(), + // out->scalar().get()); + } else { // SSS + return IfElseFunctor::Call(ctx, *batch[0].scalar(), *batch[1].scalar(), + *batch[2].scalar(), out->scalar().get()); + } + } + } + } +}; + +void AddPrimitiveKernels(const std::shared_ptr& scalar_function, + const std::vector>& types) { + for (auto&& type : types) { + auto exec = internal::GenerateTypeAgnosticPrimitive(*type); + ScalarKernel kernel({boolean(), type, type}, type, exec); + kernel.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE; + kernel.mem_allocation = MemAllocation::NO_PREALLOCATE; + + DCHECK_OK(scalar_function->AddKernel(std::move(kernel))); + } +} + +} // namespace + +const FunctionDoc if_else_doc{"", ("`"), {"cond", "left", "right"}}; + +namespace internal { + +void RegisterScalarIfElse(FunctionRegistry* registry) { + ScalarKernel scalar_kernel; + scalar_kernel.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE; + scalar_kernel.mem_allocation = MemAllocation::NO_PREALLOCATE; + + auto func = std::make_shared("if_else", Arity::Ternary(), &if_else_doc); + + AddPrimitiveKernels(func, NumericTypes()); + // todo add temporal, boolean, null and binary kernels + + DCHECK_OK(registry->AddFunction(std::move(func))); +} + +} // namespace internal +} // namespace compute +} // namespace arrow \ No newline at end of file diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc new file mode 100644 index 0000000000000..5cd17fb5a64e8 --- /dev/null +++ b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc @@ -0,0 +1,17 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + diff --git a/cpp/src/arrow/compute/registry.cc b/cpp/src/arrow/compute/registry.cc index 3a8a3a0eb8530..1d713b96e1ead 100644 --- a/cpp/src/arrow/compute/registry.cc +++ b/cpp/src/arrow/compute/registry.cc @@ -125,6 +125,7 @@ static std::unique_ptr CreateBuiltInRegistry() { RegisterScalarStringAscii(registry.get()); RegisterScalarValidity(registry.get()); RegisterScalarFillNull(registry.get()); + RegisterScalarIfElse(registry.get()); // Vector functions RegisterVectorHash(registry.get()); diff --git a/cpp/src/arrow/compute/registry_internal.h b/cpp/src/arrow/compute/registry_internal.h index e4008cf3f270f..f97553af4b104 100644 --- a/cpp/src/arrow/compute/registry_internal.h +++ b/cpp/src/arrow/compute/registry_internal.h @@ -34,6 +34,7 @@ void RegisterScalarSetLookup(FunctionRegistry* registry); void RegisterScalarStringAscii(FunctionRegistry* registry); void RegisterScalarValidity(FunctionRegistry* registry); void RegisterScalarFillNull(FunctionRegistry* registry); +void RegisterScalarIfElse(FunctionRegistry* registry); // Vector functions void RegisterVectorHash(FunctionRegistry* registry); From c909293a211eaa1d4361d565f0eacf5451506d2d Mon Sep 17 00:00:00 2001 From: niranda perera Date: Fri, 21 May 2021 16:45:47 -0400 Subject: [PATCH 02/39] working primitive types --- .../arrow/compute/kernels/scalar_if_else.cc | 195 +++++++++++++++++- .../compute/kernels/scalar_if_else_test.cc | 92 +++++++++ 2 files changed, 281 insertions(+), 6 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else.cc b/cpp/src/arrow/compute/kernels/scalar_if_else.cc index dc461b6a3d4be..0d1b68da5e54c 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else.cc @@ -16,17 +16,199 @@ // under the License. #include -#include +#include +#include #include "codegen_internal.h" namespace arrow { +using internal::BitBlockCount; +using internal::BitBlockCounter; + namespace compute { namespace { +// nulls will be promoted as follows +// cond.val && (cond.data && left.val || ~cond.data && right.val) +Status promote_nulls(KernelContext* ctx, const ArrayData& cond, const ArrayData& left, + const ArrayData& right, ArrayData* output) { + if (!cond.MayHaveNulls() && !left.MayHaveNulls() && !right.MayHaveNulls()) { + return Status::OK(); // no nulls to handle + } + const int64_t len = cond.length; + + ARROW_ASSIGN_OR_RAISE(std::shared_ptr out_validity, ctx->AllocateBitmap(len)); + arrow::internal::InvertBitmap(out_validity->data(), 0, len, + out_validity->mutable_data(), 0); + if (right.MayHaveNulls()) { + // out_validity = right.val && ~cond.data + arrow::internal::BitmapAndNot(right.buffers[0]->data(), right.offset, + cond.buffers[1]->data(), cond.offset, len, 0, + out_validity->mutable_data()); + } + + if (left.MayHaveNulls()) { + ARROW_ASSIGN_OR_RAISE(std::shared_ptr temp_buf, ctx->AllocateBitmap(len)); + // tmp_buf = left.val && cond.data + arrow::internal::BitmapAnd(left.buffers[0]->data(), left.offset, + cond.buffers[1]->data(), cond.offset, len, 0, + temp_buf->mutable_data()); + // out_validity = cond.data && left.val || ~cond.data && right.val + arrow::internal::BitmapOr(out_validity->data(), 0, temp_buf->data(), 0, len, 0, + out_validity->mutable_data()); + } + + if (cond.MayHaveNulls()) { + // out_validity &= cond.val + ::arrow::internal::BitmapAnd(out_validity->data(), 0, cond.buffers[0]->data(), + cond.offset, len, 0, out_validity->mutable_data()); + } + + output->buffers[0] = std::move(out_validity); + output->GetNullCount(); // update null count + return Status::OK(); +} + template -struct IfElseFunctor { +struct IfElseFunctor {}; + +template +struct IfElseFunctor::value>> { + using T = typename TypeTraits::CType; + + static Status Call(KernelContext* ctx, const ArrayData& cond, const ArrayData& left, + const ArrayData& right, ArrayData* out) { + ARROW_RETURN_NOT_OK(promote_nulls(ctx, cond, left, right, out)); + + ARROW_ASSIGN_OR_RAISE(std::shared_ptr out_buf, + arrow::internal::CopyBitmap(ctx->memory_pool(), ) + + ctx->Allocate(cond.length * sizeof(T))); + T* out_values = reinterpret_cast(out_buf->mutable_data()); + + // copy right data to out_buff + const T* right_data = right.GetValues(1); + std::memcpy(out_values, right_data, right.length * sizeof(T)); + + const auto* cond_data = cond.buffers[1]->data(); // this is a BoolArray + BitBlockCounter bit_counter(cond_data, cond.offset, cond.length); + + // selectively copy values from left data + const T* left_data = left.GetValues(1); + int64_t offset = cond.offset; + + // todo this can be improved by intrinsics. ex: _mm*_mask_store_e* (vmovdqa*) + while (offset < cond.offset + cond.length) { + const BitBlockCount& block = bit_counter.NextWord(); + if (block.AllSet()) { // all from left + std::memcpy(out_values, left_data, block.length * sizeof(T)); + } else if (block.popcount) { // selectively copy from left + for (int64_t i = 0; i < block.length; ++i) { + if (BitUtil::GetBit(cond_data, offset + i)) { + out_values[i] = left_data[i]; + } + } + } + + offset += block.length; + out_values += block.length; + left_data += block.length; + } + + out->buffers[1] = std::move(out_buf); + return Status::OK(); + } + + static Status Call(KernelContext* ctx, const ArrayData& cond, const ArrayData& left, + const Scalar& right, ArrayData* out) { + // todo impl + return Status::OK(); + } + + static Status Call(KernelContext* ctx, const Scalar& cond, const Scalar& left, + const Scalar& right, Scalar* out) { + // todo impl + return Status::OK(); + } +}; + +template +struct IfElseFunctor::value>> { + static Status Call(KernelContext* ctx, const ArrayData& cond, const ArrayData& left, + const ArrayData& right, ArrayData* out) { + ARROW_RETURN_NOT_OK(promote_nulls(ctx, cond, left, right, out)); + + ARROW_ASSIGN_OR_RAISE(std::shared_ptr out_buf, + ctx->AllocateBitmap(cond.length)); + uint8_t* out_values = out_buf->mutable_data(); + + // copy right data to out_buff + const T* right_data = right.GetValues(1); + std::memcpy(out_values, right_data, right.length * sizeof(T)); + + const auto* cond_data = cond.buffers[1]->data(); // this is a BoolArray + BitBlockCounter bit_counter(cond_data, cond.offset, cond.length); + + // selectively copy values from left data + const T* left_data = left.GetValues(1); + int64_t offset = cond.offset; + + // todo this can be improved by intrinsics. ex: _mm*_mask_store_e* (vmovdqa*) + while (offset < cond.offset + cond.length) { + const BitBlockCount& block = bit_counter.NextWord(); + if (block.AllSet()) { // all from left + std::memcpy(out_values, left_data, block.length * sizeof(T)); + } else if (block.popcount) { // selectively copy from left + for (int64_t i = 0; i < block.length; ++i) { + if (BitUtil::GetBit(cond_data, offset + i)) { + out_values[i] = left_data[i]; + } + } + } + + offset += block.length; + out_values += block.length; + left_data += block.length; + } + + out->buffers[1] = std::move(out_buf); + return Status::OK(); + } + + static Status Call(KernelContext* ctx, const ArrayData& cond, const ArrayData& left, + const Scalar& right, ArrayData* out) { + // todo impl + return Status::OK(); + } + + static Status Call(KernelContext* ctx, const Scalar& cond, const Scalar& left, + const Scalar& right, Scalar* out) { + // todo impl + return Status::OK(); + } +}; + +template +struct IfElseFunctor::value>> { + static Status Call(KernelContext* ctx, const ArrayData& cond, const ArrayData& left, + const ArrayData& right, ArrayData* out) { + return Status::OK(); + } + + static Status Call(KernelContext* ctx, const ArrayData& cond, const ArrayData& left, + const Scalar& right, ArrayData* out) { + return Status::OK(); + } + + static Status Call(KernelContext* ctx, const Scalar& cond, const Scalar& left, + const Scalar& right, Scalar* out) { + return Status::OK(); + } +}; + +template +struct IfElseFunctor::value>> { static Status Call(KernelContext* ctx, const ArrayData& cond, const ArrayData& left, const ArrayData& right, ArrayData* out) { return Status::OK(); @@ -71,19 +253,19 @@ struct ResolveExec { // out->mutable_array()); // } } - } else { // when cond is scalar, output will also be scalar + } else { if (batch[1].kind() == Datum::ARRAY) { return Status::Invalid(""); // if (batch[2].kind() == Datum::ARRAY) { // SAA // return IfElseFunctor::Call(ctx, *batch[0].scalar(), // *batch[1].array(), // *batch[2].array(), - // out->scalar().get()); + // out->mutable_array()); // } else { // SAS // return IfElseFunctor::Call(ctx, *batch[0].scalar(), // *batch[1].array(), // *batch[2].scalar(), - // out->scalar().get()); + // out->mutable_array()); // } } else { if (batch[2].kind() == Datum::ARRAY) { // SSA @@ -91,7 +273,7 @@ struct ResolveExec { // return IfElseFunctor::Call(ctx, *batch[0].scalar(), // *batch[1].scalar(), // *batch[2].array(), - // out->scalar().get()); + // out->mutable_array()); } else { // SSS return IfElseFunctor::Call(ctx, *batch[0].scalar(), *batch[1].scalar(), *batch[2].scalar(), out->scalar().get()); @@ -127,6 +309,7 @@ void RegisterScalarIfElse(FunctionRegistry* registry) { auto func = std::make_shared("if_else", Arity::Ternary(), &if_else_doc); AddPrimitiveKernels(func, NumericTypes()); + AddPrimitiveKernels(func, TemporalTypes()); // todo add temporal, boolean, null and binary kernels DCHECK_OK(registry->AddFunction(std::move(func))); diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc index 5cd17fb5a64e8..9b5dac71ff473 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc @@ -15,3 +15,95 @@ // specific language governing permissions and limitations // under the License. +#include +#include +#include +#include +#include + +namespace arrow { +namespace compute { + +void CheckIfElseOutputArray(const Datum& cond, const Datum& left, const Datum& right, + const Datum& expected, bool all_valid = true) { + ASSERT_OK_AND_ASSIGN(Datum datum_out, IfElse(cond, left, right)); + std::shared_ptr result = datum_out.make_array(); + ASSERT_OK(result->ValidateFull()); + AssertArraysEqual(*expected.make_array(), *result, /*verbose=*/true); + if (all_valid) { + // Check null count of ArrayData is set, not the computed Array.null_count + ASSERT_EQ(result->data()->null_count, 0); + } +} + +void CheckIfElseOutputArray(const std::shared_ptr& type, + const std::string& cond, const std::string& left, + const std::string& right, const std::string& expected, + bool all_valid = true) { + const std::shared_ptr& cond_ = ArrayFromJSON(boolean(), cond); + const std::shared_ptr& left_ = ArrayFromJSON(type, left); + const std::shared_ptr& right_ = ArrayFromJSON(type, right); + const std::shared_ptr& expected_ = ArrayFromJSON(type, expected); + CheckIfElseOutputArray(cond_, left_, right_, expected_, all_valid); +} + +class TestIfElseNullKernel : public ::testing::Test {}; + +template +class TestIfElsePrimitive : public ::testing::Test {}; + +using PrimitiveTypes = ::testing::Types; + + +TYPED_TEST_SUITE(TestIfElsePrimitive, PrimitiveTypes); + +TYPED_TEST(TestIfElsePrimitive, IfElseFixedSize) { + // using ScalarType = typename TypeTraits::ScalarType; + auto type = TypeTraits::type_singleton(); + // auto scalar = std::make_shared(static_cast(5)); + // No Nulls + CheckIfElseOutputArray(type, "[]", "[]", "[]", "[]"); + + CheckIfElseOutputArray(type, "[true, true, true, false]", "[1, 2, 3, 4]", + "[5, 6, 7, 8]", "[1, 2, 3, 8]"); + + CheckIfElseOutputArray(type, "[true, true, null, false]", "[1, 2, 3, 4]", + "[5, 6, 7, 8]", "[1, 2, null, 8]", false); + + CheckIfElseOutputArray(type, "[true, true, true, false]", "[1, 2, null, null]", + "[null, 6, 7, null]", "[1, 2, null, null]", false); + + using ArrayType = typename TypeTraits::ArrayType; + random::RandomArrayGenerator rand(/*seed=*/0); + int64_t len = 1000; + auto cond = std::static_pointer_cast( + rand.ArrayOf(boolean(), len, /*null_probability=*/0.01)); + auto left = std::static_pointer_cast( + rand.ArrayOf(type, len, /*null_probability=*/0.01)); + auto right = std::static_pointer_cast( + rand.ArrayOf(type, len, /*null_probability=*/0.01)); + + typename TypeTraits::BuilderType builder; + + for (int64_t i = 0; i < len; ++i) { + if (!cond->IsValid(i) || (cond->Value(i) && !left->IsValid(i)) || + (!cond->Value(i) && !right->IsValid(i))) { + ASSERT_OK(builder.AppendNull()); + continue; + } + + if (cond->Value(i)) { + ASSERT_OK(builder.Append(left->Value(i))); + } else { + ASSERT_OK(builder.Append(right->Value(i))); + } + } + ASSERT_OK_AND_ASSIGN(auto expected_data, builder.Finish()); + + CheckIfElseOutputArray(cond, left, right, expected_data, false); +} + +} // namespace compute +} // namespace arrow \ No newline at end of file From d2aa90efd4ba506d5ab9d3c3faf45a8cd3cb8cff Mon Sep 17 00:00:00 2001 From: niranda perera Date: Fri, 21 May 2021 17:18:19 -0400 Subject: [PATCH 03/39] adding bool type --- .../arrow/compute/kernels/scalar_if_else.cc | 76 +++++-------------- .../compute/kernels/scalar_if_else_test.cc | 58 ++++++++++++-- 2 files changed, 71 insertions(+), 63 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else.cc b/cpp/src/arrow/compute/kernels/scalar_if_else.cc index 0d1b68da5e54c..fe3e31268acc3 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else.cc @@ -49,11 +49,11 @@ Status promote_nulls(KernelContext* ctx, const ArrayData& cond, const ArrayData& } if (left.MayHaveNulls()) { - ARROW_ASSIGN_OR_RAISE(std::shared_ptr temp_buf, ctx->AllocateBitmap(len)); // tmp_buf = left.val && cond.data - arrow::internal::BitmapAnd(left.buffers[0]->data(), left.offset, - cond.buffers[1]->data(), cond.offset, len, 0, - temp_buf->mutable_data()); + ARROW_ASSIGN_OR_RAISE(std::shared_ptr temp_buf, + arrow::internal::BitmapAnd( + ctx->memory_pool(), left.buffers[0]->data(), left.offset, + cond.buffers[1]->data(), cond.offset, len, 0)); // out_validity = cond.data && left.val || ~cond.data && right.val arrow::internal::BitmapOr(out_validity->data(), 0, temp_buf->data(), 0, len, 0, out_validity->mutable_data()); @@ -82,9 +82,7 @@ struct IfElseFunctor::value>> { ARROW_RETURN_NOT_OK(promote_nulls(ctx, cond, left, right, out)); ARROW_ASSIGN_OR_RAISE(std::shared_ptr out_buf, - arrow::internal::CopyBitmap(ctx->memory_pool(), ) - - ctx->Allocate(cond.length * sizeof(T))); + ctx->Allocate(cond.length * sizeof(T))); T* out_values = reinterpret_cast(out_buf->mutable_data()); // copy right data to out_buff @@ -139,39 +137,20 @@ struct IfElseFunctor::value>> { const ArrayData& right, ArrayData* out) { ARROW_RETURN_NOT_OK(promote_nulls(ctx, cond, left, right, out)); + // out_buff = right & ~cond ARROW_ASSIGN_OR_RAISE(std::shared_ptr out_buf, - ctx->AllocateBitmap(cond.length)); - uint8_t* out_values = out_buf->mutable_data(); - - // copy right data to out_buff - const T* right_data = right.GetValues(1); - std::memcpy(out_values, right_data, right.length * sizeof(T)); - - const auto* cond_data = cond.buffers[1]->data(); // this is a BoolArray - BitBlockCounter bit_counter(cond_data, cond.offset, cond.length); - - // selectively copy values from left data - const T* left_data = left.GetValues(1); - int64_t offset = cond.offset; - - // todo this can be improved by intrinsics. ex: _mm*_mask_store_e* (vmovdqa*) - while (offset < cond.offset + cond.length) { - const BitBlockCount& block = bit_counter.NextWord(); - if (block.AllSet()) { // all from left - std::memcpy(out_values, left_data, block.length * sizeof(T)); - } else if (block.popcount) { // selectively copy from left - for (int64_t i = 0; i < block.length; ++i) { - if (BitUtil::GetBit(cond_data, offset + i)) { - out_values[i] = left_data[i]; - } - } - } - - offset += block.length; - out_values += block.length; - left_data += block.length; - } - + arrow::internal::BitmapAndNot( + ctx->memory_pool(), right.buffers[1]->data(), right.offset, + cond.buffers[1]->data(), cond.offset, cond.length, 0)); + + // out_buff = left & cond + ARROW_ASSIGN_OR_RAISE(std::shared_ptr temp_buf, + arrow::internal::BitmapAnd( + ctx->memory_pool(), left.buffers[1]->data(), left.offset, + cond.buffers[1]->data(), cond.offset, cond.length, 0)); + + arrow::internal::BitmapOr(out_buf->data(), 0, temp_buf->data(), 0, cond.length, 0, + out_buf->mutable_data()); out->buffers[1] = std::move(out_buf); return Status::OK(); } @@ -189,24 +168,6 @@ struct IfElseFunctor::value>> { } }; -template -struct IfElseFunctor::value>> { - static Status Call(KernelContext* ctx, const ArrayData& cond, const ArrayData& left, - const ArrayData& right, ArrayData* out) { - return Status::OK(); - } - - static Status Call(KernelContext* ctx, const ArrayData& cond, const ArrayData& left, - const Scalar& right, ArrayData* out) { - return Status::OK(); - } - - static Status Call(KernelContext* ctx, const Scalar& cond, const Scalar& left, - const Scalar& right, Scalar* out) { - return Status::OK(); - } -}; - template struct IfElseFunctor::value>> { static Status Call(KernelContext* ctx, const ArrayData& cond, const ArrayData& left, @@ -310,6 +271,7 @@ void RegisterScalarIfElse(FunctionRegistry* registry) { AddPrimitiveKernels(func, NumericTypes()); AddPrimitiveKernels(func, TemporalTypes()); + AddPrimitiveKernels(func, {boolean()}); // todo add temporal, boolean, null and binary kernels DCHECK_OK(registry->AddFunction(std::move(func))); diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc index 9b5dac71ff473..f970b99456e45 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc @@ -47,7 +47,7 @@ void CheckIfElseOutputArray(const std::shared_ptr& type, CheckIfElseOutputArray(cond_, left_, right_, expected_, all_valid); } -class TestIfElseNullKernel : public ::testing::Test {}; +class TestIfElseKernel : public ::testing::Test {}; template class TestIfElsePrimitive : public ::testing::Test {}; @@ -75,18 +75,16 @@ TYPED_TEST(TestIfElsePrimitive, IfElseFixedSize) { CheckIfElseOutputArray(type, "[true, true, true, false]", "[1, 2, null, null]", "[null, 6, 7, null]", "[1, 2, null, null]", false); - using ArrayType = typename TypeTraits::ArrayType; random::RandomArrayGenerator rand(/*seed=*/0); int64_t len = 1000; auto cond = std::static_pointer_cast( rand.ArrayOf(boolean(), len, /*null_probability=*/0.01)); - auto left = std::static_pointer_cast( + auto left = std::static_pointer_cast( rand.ArrayOf(type, len, /*null_probability=*/0.01)); - auto right = std::static_pointer_cast( + auto right = std::static_pointer_cast( rand.ArrayOf(type, len, /*null_probability=*/0.01)); - typename TypeTraits::BuilderType builder; - + BooleanBuilder builder; for (int64_t i = 0; i < len; ++i) { if (!cond->IsValid(i) || (cond->Value(i) && !left->IsValid(i)) || (!cond->Value(i) && !right->IsValid(i))) { @@ -105,5 +103,53 @@ TYPED_TEST(TestIfElsePrimitive, IfElseFixedSize) { CheckIfElseOutputArray(cond, left, right, expected_data, false); } +TEST_F(TestIfElseKernel, IfElseBoolean) { + // using ScalarType = typename TypeTraits::ScalarType; + // auto scalar = std::make_shared(static_cast(5)); + auto type = boolean(); + // No Nulls + CheckIfElseOutputArray(type, "[]", "[]", "[]", "[]"); + + CheckIfElseOutputArray(type, "[true, true, true, false]", + "[false, false, false, false]", "[true, true, true, true]", + "[false, false, false, true]"); + + CheckIfElseOutputArray(type, "[true, true, null, false]", + "[false, false, false, false]", "[true, true, true, true]", + "[false, false, null, true]", false); + + CheckIfElseOutputArray(type, "[true, true, true, false]", "[true, false, null, null]", + "[null, false, true, null]", "[true, false, null, null]", false); + +// using ArrayType = typename TypeTraits::ArrayType; +// random::RandomArrayGenerator rand(/*seed=*/0); +// int64_t len = 1000; +// auto cond = std::static_pointer_cast( +// rand.ArrayOf(boolean(), len, /*null_probability=*/0.01)); +// auto left = std::static_pointer_cast( +// rand.ArrayOf(type, len, /*null_probability=*/0.01)); +// auto right = std::static_pointer_cast( +// rand.ArrayOf(type, len, /*null_probability=*/0.01)); +// +// typename TypeTraits::BuilderType builder; +// +// for (int64_t i = 0; i < len; ++i) { +// if (!cond->IsValid(i) || (cond->Value(i) && !left->IsValid(i)) || +// (!cond->Value(i) && !right->IsValid(i))) { +// ASSERT_OK(builder.AppendNull()); +// continue; +// } +// +// if (cond->Value(i)) { +// ASSERT_OK(builder.Append(left->Value(i))); +// } else { +// ASSERT_OK(builder.Append(right->Value(i))); +// } +// } +// ASSERT_OK_AND_ASSIGN(auto expected_data, builder.Finish()); +// +// CheckIfElseOutputArray(cond, left, right, expected_data, false); +} + } // namespace compute } // namespace arrow \ No newline at end of file From 672a99f181757bd29c58a85e1faf75b1cf43cf04 Mon Sep 17 00:00:00 2001 From: niranda perera Date: Fri, 21 May 2021 17:36:06 -0400 Subject: [PATCH 04/39] adding null kernel --- .../arrow/compute/kernels/scalar_if_else.cc | 7 +- .../compute/kernels/scalar_if_else_test.cc | 73 ++++++++++--------- 2 files changed, 43 insertions(+), 37 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else.cc b/cpp/src/arrow/compute/kernels/scalar_if_else.cc index fe3e31268acc3..e00ee7bb2c60c 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else.cc @@ -172,6 +172,8 @@ template struct IfElseFunctor::value>> { static Status Call(KernelContext* ctx, const ArrayData& cond, const ArrayData& left, const ArrayData& right, ArrayData* out) { + // Nothing preallocated, so we assign left into the output + *out = left; return Status::OK(); } @@ -248,6 +250,7 @@ void AddPrimitiveKernels(const std::shared_ptr& scalar_function, const std::vector>& types) { for (auto&& type : types) { auto exec = internal::GenerateTypeAgnosticPrimitive(*type); + // cond array needs to be boolean always ScalarKernel kernel({boolean(), type, type}, type, exec); kernel.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE; kernel.mem_allocation = MemAllocation::NO_PREALLOCATE; @@ -271,8 +274,8 @@ void RegisterScalarIfElse(FunctionRegistry* registry) { AddPrimitiveKernels(func, NumericTypes()); AddPrimitiveKernels(func, TemporalTypes()); - AddPrimitiveKernels(func, {boolean()}); - // todo add temporal, boolean, null and binary kernels + AddPrimitiveKernels(func, {boolean(), null()}); + // todo add binary kernels DCHECK_OK(registry->AddFunction(std::move(func))); } diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc index f970b99456e45..f1574bd3d28b9 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc @@ -60,9 +60,8 @@ using PrimitiveTypes = ::testing::Types::ScalarType; auto type = TypeTraits::type_singleton(); - // auto scalar = std::make_shared(static_cast(5)); + // No Nulls CheckIfElseOutputArray(type, "[]", "[]", "[]", "[]"); @@ -75,16 +74,18 @@ TYPED_TEST(TestIfElsePrimitive, IfElseFixedSize) { CheckIfElseOutputArray(type, "[true, true, true, false]", "[1, 2, null, null]", "[null, 6, 7, null]", "[1, 2, null, null]", false); + using ArrayType = typename TypeTraits::ArrayType; random::RandomArrayGenerator rand(/*seed=*/0); int64_t len = 1000; auto cond = std::static_pointer_cast( rand.ArrayOf(boolean(), len, /*null_probability=*/0.01)); - auto left = std::static_pointer_cast( + auto left = std::static_pointer_cast( rand.ArrayOf(type, len, /*null_probability=*/0.01)); - auto right = std::static_pointer_cast( + auto right = std::static_pointer_cast( rand.ArrayOf(type, len, /*null_probability=*/0.01)); - BooleanBuilder builder; + typename TypeTraits::BuilderType builder; + for (int64_t i = 0; i < len; ++i) { if (!cond->IsValid(i) || (cond->Value(i) && !left->IsValid(i)) || (!cond->Value(i) && !right->IsValid(i))) { @@ -104,8 +105,6 @@ TYPED_TEST(TestIfElsePrimitive, IfElseFixedSize) { } TEST_F(TestIfElseKernel, IfElseBoolean) { - // using ScalarType = typename TypeTraits::ScalarType; - // auto scalar = std::make_shared(static_cast(5)); auto type = boolean(); // No Nulls CheckIfElseOutputArray(type, "[]", "[]", "[]", "[]"); @@ -121,34 +120,38 @@ TEST_F(TestIfElseKernel, IfElseBoolean) { CheckIfElseOutputArray(type, "[true, true, true, false]", "[true, false, null, null]", "[null, false, true, null]", "[true, false, null, null]", false); -// using ArrayType = typename TypeTraits::ArrayType; -// random::RandomArrayGenerator rand(/*seed=*/0); -// int64_t len = 1000; -// auto cond = std::static_pointer_cast( -// rand.ArrayOf(boolean(), len, /*null_probability=*/0.01)); -// auto left = std::static_pointer_cast( -// rand.ArrayOf(type, len, /*null_probability=*/0.01)); -// auto right = std::static_pointer_cast( -// rand.ArrayOf(type, len, /*null_probability=*/0.01)); -// -// typename TypeTraits::BuilderType builder; -// -// for (int64_t i = 0; i < len; ++i) { -// if (!cond->IsValid(i) || (cond->Value(i) && !left->IsValid(i)) || -// (!cond->Value(i) && !right->IsValid(i))) { -// ASSERT_OK(builder.AppendNull()); -// continue; -// } -// -// if (cond->Value(i)) { -// ASSERT_OK(builder.Append(left->Value(i))); -// } else { -// ASSERT_OK(builder.Append(right->Value(i))); -// } -// } -// ASSERT_OK_AND_ASSIGN(auto expected_data, builder.Finish()); -// -// CheckIfElseOutputArray(cond, left, right, expected_data, false); + random::RandomArrayGenerator rand(/*seed=*/0); + int64_t len = 1000; + auto cond = std::static_pointer_cast( + rand.ArrayOf(boolean(), len, /*null_probability=*/0.01)); + auto left = std::static_pointer_cast( + rand.ArrayOf(type, len, /*null_probability=*/0.01)); + auto right = std::static_pointer_cast( + rand.ArrayOf(type, len, /*null_probability=*/0.01)); + + BooleanBuilder builder; + for (int64_t i = 0; i < len; ++i) { + if (!cond->IsValid(i) || (cond->Value(i) && !left->IsValid(i)) || + (!cond->Value(i) && !right->IsValid(i))) { + ASSERT_OK(builder.AppendNull()); + continue; + } + + if (cond->Value(i)) { + ASSERT_OK(builder.Append(left->Value(i))); + } else { + ASSERT_OK(builder.Append(right->Value(i))); + } + } + ASSERT_OK_AND_ASSIGN(auto expected_data, builder.Finish()); + + CheckIfElseOutputArray(cond, left, right, expected_data, false); +} + +TEST_F(TestIfElseKernel, IfElseNull) { + CheckIfElseOutputArray(null(), "[null, null, null, null]", "[null, null, null, null]", + "[null, null, null, null]", "[null, null, null, null]", + /*all_valid=*/false); } } // namespace compute From 0300b8e5770db3a6ad3ff72084a776a9eb0a8053 Mon Sep 17 00:00:00 2001 From: niranda perera Date: Mon, 24 May 2021 10:08:18 -0400 Subject: [PATCH 05/39] adding PR comments --- cpp/src/arrow/compute/api_scalar.h | 2 +- .../arrow/compute/kernels/scalar_if_else.cc | 28 +++++++++---------- .../compute/kernels/scalar_if_else_test.cc | 1 - 3 files changed, 15 insertions(+), 16 deletions(-) diff --git a/cpp/src/arrow/compute/api_scalar.h b/cpp/src/arrow/compute/api_scalar.h index c4b70b3998998..ed8a55e0a0757 100644 --- a/cpp/src/arrow/compute/api_scalar.h +++ b/cpp/src/arrow/compute/api_scalar.h @@ -479,7 +479,7 @@ Result FillNull(const Datum& values, const Datum& fill_value, /// \note API not yet finalized ARROW_EXPORT Result IfElse(const Datum& cond, const Datum& left, const Datum& right, - ExecContext* ctx = NULLPTR); + ExecContext* ctx = NULLPTR); } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else.cc b/cpp/src/arrow/compute/kernels/scalar_if_else.cc index e00ee7bb2c60c..2b971cf323618 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else.cc @@ -31,8 +31,8 @@ namespace { // nulls will be promoted as follows // cond.val && (cond.data && left.val || ~cond.data && right.val) -Status promote_nulls(KernelContext* ctx, const ArrayData& cond, const ArrayData& left, - const ArrayData& right, ArrayData* output) { +Status PromoteNulls(KernelContext* ctx, const ArrayData& cond, const ArrayData& left, + const ArrayData& right, ArrayData* output) { if (!cond.MayHaveNulls() && !left.MayHaveNulls() && !right.MayHaveNulls()) { return Status::OK(); // no nulls to handle } @@ -74,12 +74,14 @@ template struct IfElseFunctor {}; template -struct IfElseFunctor::value>> { +struct IfElseFunctor< + Type, swap, + enable_if_t::value | is_temporal_type::value>> { using T = typename TypeTraits::CType; static Status Call(KernelContext* ctx, const ArrayData& cond, const ArrayData& left, const ArrayData& right, ArrayData* out) { - ARROW_RETURN_NOT_OK(promote_nulls(ctx, cond, left, right, out)); + ARROW_RETURN_NOT_OK(PromoteNulls(ctx, cond, left, right, out)); ARROW_ASSIGN_OR_RAISE(std::shared_ptr out_buf, ctx->Allocate(cond.length * sizeof(T))); @@ -132,10 +134,10 @@ struct IfElseFunctor::value>> { }; template -struct IfElseFunctor::value>> { +struct IfElseFunctor> { static Status Call(KernelContext* ctx, const ArrayData& cond, const ArrayData& left, const ArrayData& right, ArrayData* out) { - ARROW_RETURN_NOT_OK(promote_nulls(ctx, cond, left, right, out)); + ARROW_RETURN_NOT_OK(PromoteNulls(ctx, cond, left, right, out)); // out_buff = right & ~cond ARROW_ASSIGN_OR_RAISE(std::shared_ptr out_buf, @@ -169,7 +171,7 @@ struct IfElseFunctor::value>> { }; template -struct IfElseFunctor::value>> { +struct IfElseFunctor> { static Status Call(KernelContext* ctx, const ArrayData& cond, const ArrayData& left, const ArrayData& right, ArrayData* out) { // Nothing preallocated, so we assign left into the output @@ -191,8 +193,6 @@ struct IfElseFunctor::value>> { template struct ResolveExec { static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - if (batch.length == 0) return Status::OK(); - if (batch[0].kind() == Datum::ARRAY) { if (batch[1].kind() == Datum::ARRAY) { if (batch[2].kind() == Datum::ARRAY) { // AAA @@ -246,8 +246,8 @@ struct ResolveExec { } }; -void AddPrimitiveKernels(const std::shared_ptr& scalar_function, - const std::vector>& types) { +void AddPrimitiveIfElseKernels(const std::shared_ptr& scalar_function, + const std::vector>& types) { for (auto&& type : types) { auto exec = internal::GenerateTypeAgnosticPrimitive(*type); // cond array needs to be boolean always @@ -272,9 +272,9 @@ void RegisterScalarIfElse(FunctionRegistry* registry) { auto func = std::make_shared("if_else", Arity::Ternary(), &if_else_doc); - AddPrimitiveKernels(func, NumericTypes()); - AddPrimitiveKernels(func, TemporalTypes()); - AddPrimitiveKernels(func, {boolean(), null()}); + AddPrimitiveIfElseKernels(func, NumericTypes()); + AddPrimitiveIfElseKernels(func, TemporalTypes()); + AddPrimitiveIfElseKernels(func, {boolean(), null()}); // todo add binary kernels DCHECK_OK(registry->AddFunction(std::move(func))); diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc index f1574bd3d28b9..cf81ebf944192 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc @@ -56,7 +56,6 @@ using PrimitiveTypes = ::testing::Types; - TYPED_TEST_SUITE(TestIfElsePrimitive, PrimitiveTypes); TYPED_TEST(TestIfElsePrimitive, IfElseFixedSize) { From b1ec65d6a41b891852ff37a21c1c299c70593941 Mon Sep 17 00:00:00 2001 From: niranda perera Date: Mon, 24 May 2021 20:58:18 -0400 Subject: [PATCH 06/39] adding more impl --- .../arrow/compute/kernels/scalar_if_else.cc | 315 ++++++++++++++---- 1 file changed, 248 insertions(+), 67 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else.cc b/cpp/src/arrow/compute/kernels/scalar_if_else.cc index 2b971cf323618..252408fbd8af8 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else.cc @@ -29,8 +29,11 @@ namespace compute { namespace { -// nulls will be promoted as follows -// cond.val && (cond.data && left.val || ~cond.data && right.val) +/* + * nulls will be promoted as follows + * + * cond.val && (cond.data && left.val || ~cond.data && right.val) + */ Status PromoteNulls(KernelContext* ctx, const ArrayData& cond, const ArrayData& left, const ArrayData& right, ArrayData* output) { if (!cond.MayHaveNulls() && !left.MayHaveNulls() && !right.MayHaveNulls()) { @@ -38,29 +41,72 @@ Status PromoteNulls(KernelContext* ctx, const ArrayData& cond, const ArrayData& } const int64_t len = cond.length; - ARROW_ASSIGN_OR_RAISE(std::shared_ptr out_validity, ctx->AllocateBitmap(len)); - arrow::internal::InvertBitmap(out_validity->data(), 0, len, - out_validity->mutable_data(), 0); - if (right.MayHaveNulls()) { - // out_validity = right.val && ~cond.data - arrow::internal::BitmapAndNot(right.buffers[0]->data(), right.offset, - cond.buffers[1]->data(), cond.offset, len, 0, - out_validity->mutable_data()); + // out_validity = ~cond.data + ARROW_ASSIGN_OR_RAISE( + std::shared_ptr out_validity, + arrow::internal::InvertBitmap(ctx->memory_pool(), cond.buffers[1]->data(), + cond.offset, len)); + + if (right.MayHaveNulls()) { // out_validity = right.val && ~cond.data + arrow::internal::BitmapAnd(right.buffers[0]->data(), right.offset, + out_validity->data(), 0, len, 0, + out_validity->mutable_data()); } + std::shared_ptr tmp_buf; if (left.MayHaveNulls()) { // tmp_buf = left.val && cond.data - ARROW_ASSIGN_OR_RAISE(std::shared_ptr temp_buf, - arrow::internal::BitmapAnd( - ctx->memory_pool(), left.buffers[0]->data(), left.offset, - cond.buffers[1]->data(), cond.offset, len, 0)); - // out_validity = cond.data && left.val || ~cond.data && right.val - arrow::internal::BitmapOr(out_validity->data(), 0, temp_buf->data(), 0, len, 0, - out_validity->mutable_data()); + ARROW_ASSIGN_OR_RAISE( + tmp_buf, arrow::internal::BitmapAnd(ctx->memory_pool(), left.buffers[0]->data(), + left.offset, cond.buffers[1]->data(), + cond.offset, len, 0)); + } else { // if left all valid --> tmp_buf = cond.data (zero copy slice) + tmp_buf = SliceBuffer(cond.buffers[1], cond.offset, cond.length); + } + + // out_validity = cond.data && left.val || ~cond.data && right.val + arrow::internal::BitmapOr(out_validity->data(), 0, tmp_buf->data(), 0, len, 0, + out_validity->mutable_data()); + + if (cond.MayHaveNulls()) { + // out_validity = cond.val && (cond.data && left.val || ~cond.data && right.val) + ::arrow::internal::BitmapAnd(out_validity->data(), 0, cond.buffers[0]->data(), + cond.offset, len, 0, out_validity->mutable_data()); + } + + output->buffers[0] = std::move(out_validity); + output->GetNullCount(); // update null count + return Status::OK(); +} + +// cond.val && (cond.data && left.val || ~cond.data && right.val) +Status PromoteNulls(KernelContext* ctx, const ArrayData& cond, const Scalar& left, + const ArrayData& right, ArrayData* output) { + if (!cond.MayHaveNulls() && left.is_valid && !right.MayHaveNulls()) { + return Status::OK(); // no nulls to handle + } + const int64_t len = cond.length; + + // out_validity = ~cond.data + ARROW_ASSIGN_OR_RAISE( + std::shared_ptr out_validity, + arrow::internal::InvertBitmap(ctx->memory_pool(), cond.buffers[1]->data(), + cond.offset, len)); + + if (right.MayHaveNulls()) { // out_validity = right.val && ~cond.data + arrow::internal::BitmapAnd(right.buffers[0]->data(), right.offset, + out_validity->data(), 0, len, 0, + out_validity->mutable_data()); + } + + // out_validity = cond.data || ~cond.data && right.val + if (left.is_valid) { + arrow::internal::BitmapOr(out_validity->data(), 0, cond.buffers[1]->data(), + cond.offset, len, 0, out_validity->mutable_data()); } + // out_validity = cond.val && (cond.data || ~cond.data && right.val) if (cond.MayHaveNulls()) { - // out_validity &= cond.val ::arrow::internal::BitmapAnd(out_validity->data(), 0, cond.buffers[0]->data(), cond.offset, len, 0, out_validity->mutable_data()); } @@ -70,15 +116,34 @@ Status PromoteNulls(KernelContext* ctx, const ArrayData& cond, const ArrayData& return Status::OK(); } -template +// todo: this could be dangerous because the inverted arraydata buffer[1] may not be +// available outside Exec's scope +Status InvertBoolArrayData(KernelContext* ctx, const ArrayData& input, + ArrayData* output) { + if (input.MayHaveNulls()) { + output->buffers.emplace_back( + SliceBuffer(input.buffers[0], input.offset, input.length)); + } else { + output->buffers.push_back(NULLPTR); + } + + ARROW_ASSIGN_OR_RAISE( + std::shared_ptr inv_data, + arrow::internal::InvertBitmap(ctx->memory_pool(), input.buffers[1]->data(), + input.offset, input.length)); + output->buffers.emplace_back(std::move(inv_data)); + return Status::OK(); +} + +template struct IfElseFunctor {}; -template +template struct IfElseFunctor< - Type, swap, - enable_if_t::value | is_temporal_type::value>> { + Type, enable_if_t::value | is_temporal_type::value>> { using T = typename TypeTraits::CType; + // AAA static Status Call(KernelContext* ctx, const ArrayData& cond, const ArrayData& left, const ArrayData& right, ArrayData* out) { ARROW_RETURN_NOT_OK(PromoteNulls(ctx, cond, left, right, out)); @@ -120,12 +185,70 @@ struct IfElseFunctor< return Status::OK(); } - static Status Call(KernelContext* ctx, const ArrayData& cond, const ArrayData& left, + // ASA and AAS + static Status Call(KernelContext* ctx, const ArrayData& cond, const Scalar& left, + const ArrayData& right, ArrayData* out) { + ARROW_RETURN_NOT_OK(PromoteNulls(ctx, cond, left, right, out)); + + ARROW_ASSIGN_OR_RAISE(std::shared_ptr out_buf, + ctx->Allocate(cond.length * sizeof(T))); + T* out_values = reinterpret_cast(out_buf->mutable_data()); + + // copy right data to out_buff + const T* right_data = right.GetValues(1); + std::memcpy(out_values, right_data, right.length * sizeof(T)); + + const auto* cond_data = cond.buffers[1]->data(); // this is a BoolArray + BitBlockCounter bit_counter(cond_data, cond.offset, cond.length); + + // selectively copy values from left data + T left_data = internal::UnboxScalar::Unbox(left); + int64_t offset = cond.offset; + + // todo this can be improved by intrinsics. ex: _mm*_mask_store_e* (vmovdqa*) + while (offset < cond.offset + cond.length) { + const BitBlockCount& block = bit_counter.NextWord(); + if (block.AllSet()) { // all from left + std::fill(out_values, out_values + block.length, left_data); + } else if (block.popcount) { // selectively copy from left + for (int64_t i = 0; i < block.length; ++i) { + if (BitUtil::GetBit(cond_data, offset + i)) { + out_values[i] = left_data; + } + } + } + + offset += block.length; + out_values += block.length; + } + + out->buffers[1] = std::move(out_buf); + return Status::OK(); + } + + // ASS + static Status Call(KernelContext* ctx, const ArrayData& cond, const Scalar& left, const Scalar& right, ArrayData* out) { // todo impl return Status::OK(); } + // SAA + static Status Call(KernelContext* ctx, const Scalar& cond, const ArrayData& left, + const ArrayData& right, ArrayData* out) { + *out = dynamic_cast(cond).value ? left : right; + return Status::OK(); + } + + // SSA and SAS + template + static Status Call(KernelContext* ctx, const Scalar& cond, const Scalar& left, + const ArrayData& right, ArrayData* out) { + // todo impl + return Status::OK(); + } + + // SSS static Status Call(KernelContext* ctx, const Scalar& cond, const Scalar& left, const Scalar& right, Scalar* out) { // todo impl @@ -133,8 +256,9 @@ struct IfElseFunctor< } }; -template -struct IfElseFunctor> { +template +struct IfElseFunctor> { + // AAA static Status Call(KernelContext* ctx, const ArrayData& cond, const ArrayData& left, const ArrayData& right, ArrayData* out) { ARROW_RETURN_NOT_OK(PromoteNulls(ctx, cond, left, right, out)); @@ -157,12 +281,50 @@ struct IfElseFunctor> { return Status::OK(); } - static Status Call(KernelContext* ctx, const ArrayData& cond, const ArrayData& left, + // ASA and AAS + static Status Call(KernelContext* ctx, const ArrayData& cond, const Scalar& left, + const ArrayData& right, ArrayData* out) { + ARROW_RETURN_NOT_OK(PromoteNulls(ctx, cond, left, right, out)); + + // out_buff = right & ~cond + ARROW_ASSIGN_OR_RAISE(std::shared_ptr out_buf, + arrow::internal::BitmapAndNot( + ctx->memory_pool(), right.buffers[1]->data(), right.offset, + cond.buffers[1]->data(), cond.offset, cond.length, 0)); + + // out_buff = left & cond + bool left_data = internal::UnboxScalar::Unbox(left); + if (left_data) { + arrow::internal::BitmapOr(out_buf->data(), 0, cond.buffers[1]->data(), cond.offset, + cond.length, 0, out_buf->mutable_data()); + } + + out->buffers[1] = std::move(out_buf); + return Status::OK(); + } + + // ASS + static Status Call(KernelContext* ctx, const ArrayData& cond, const Scalar& left, const Scalar& right, ArrayData* out) { // todo impl return Status::OK(); } + // SAS and SSA + static Status Call(KernelContext* ctx, const Scalar& cond, const ArrayData& left, + const ArrayData& right, ArrayData* out) { + // todo impl + return Status::OK(); + } + + // SAS and SSA + static Status Call(KernelContext* ctx, const Scalar& cond, const Scalar& left, + const ArrayData& right, ArrayData* out) { + // todo impl + return Status::OK(); + } + + // SSS static Status Call(KernelContext* ctx, const Scalar& cond, const Scalar& left, const Scalar& right, Scalar* out) { // todo impl @@ -170,28 +332,56 @@ struct IfElseFunctor> { } }; -template -struct IfElseFunctor> { - static Status Call(KernelContext* ctx, const ArrayData& cond, const ArrayData& left, - const ArrayData& right, ArrayData* out) { - // Nothing preallocated, so we assign left into the output - *out = left; +template +struct IfElseFunctor> { + template + static inline Status ReturnCopy(const T& in, T* out) { + // Nothing preallocated, so we assign in into the output + *out = in; return Status::OK(); } + // AAA static Status Call(KernelContext* ctx, const ArrayData& cond, const ArrayData& left, + const ArrayData& right, ArrayData* out) { + return ReturnCopy(left, out); + } + + // ASA and AAS + static Status Call(KernelContext* ctx, const ArrayData& cond, const Scalar& left, + const ArrayData& right, ArrayData* out) { + return ReturnCopy(right, out); + } + + // ASS + static Status Call(KernelContext* ctx, const ArrayData& cond, const Scalar& left, const Scalar& right, ArrayData* out) { - return Status::OK(); + // todo: this could be dangerous if cond is created by calling InvertBoolArrayData + // because the inverted array may not be available outside Exec's scope + return ReturnCopy(cond, out); } + // SAA + static Status Call(KernelContext* ctx, const Scalar& cond, const ArrayData& left, + const ArrayData& right, ArrayData* out) { + return ReturnCopy(left, out); + } + + // SSA and SAS + static Status Call(KernelContext* ctx, const Scalar& cond, const Scalar& left, + const ArrayData& right, ArrayData* out) { + return ReturnCopy(right, out); + } + + // SSS static Status Call(KernelContext* ctx, const Scalar& cond, const Scalar& left, const Scalar& right, Scalar* out) { - return Status::OK(); + return ReturnCopy(left, out); } }; template -struct ResolveExec { +struct ResolveIfElseExec { static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { if (batch[0].kind() == Datum::ARRAY) { if (batch[1].kind() == Datum::ARRAY) { @@ -199,44 +389,35 @@ struct ResolveExec { return IfElseFunctor::Call(ctx, *batch[0].array(), *batch[1].array(), *batch[2].array(), out->mutable_array()); } else { // AAS - return IfElseFunctor::Call(ctx, *batch[0].array(), *batch[1].array(), - *batch[2].scalar(), out->mutable_array()); + ArrayData inv_cond; + RETURN_NOT_OK(InvertBoolArrayData(ctx, *batch[0].array(), &inv_cond)); + return IfElseFunctor::Call(ctx, inv_cond, *batch[2].scalar(), + *batch[1].array(), out->mutable_array()); } } else { - return Status::Invalid(""); - // if (batch[2].kind() == Datum::ARRAY) { // ASA - // return IfElseFunctor::Call(ctx, *batch[0].array(), - // *batch[2].array(), - // *batch[1].scalar(), - // out->mutable_array()); - // } else { // ASS - // return IfElseFunctor::Call(ctx, *batch[0].array(), - // *batch[1].scalar(), - // *batch[2].scalar(), - // out->mutable_array()); - // } + if (batch[2].kind() == Datum::ARRAY) { // ASA + return IfElseFunctor::Call(ctx, *batch[0].array(), *batch[1].scalar(), + *batch[2].array(), out->mutable_array()); + } else { // ASS + return IfElseFunctor::Call(ctx, *batch[0].array(), *batch[1].scalar(), + *batch[2].scalar(), out->mutable_array()); + } } } else { if (batch[1].kind() == Datum::ARRAY) { - return Status::Invalid(""); - // if (batch[2].kind() == Datum::ARRAY) { // SAA - // return IfElseFunctor::Call(ctx, *batch[0].scalar(), - // *batch[1].array(), - // *batch[2].array(), - // out->mutable_array()); - // } else { // SAS - // return IfElseFunctor::Call(ctx, *batch[0].scalar(), - // *batch[1].array(), - // *batch[2].scalar(), - // out->mutable_array()); - // } + if (batch[2].kind() == Datum::ARRAY) { // SAA + return IfElseFunctor::Call(ctx, *batch[0].scalar(), *batch[1].array(), + *batch[2].array(), out->mutable_array()); + } else { // SAS + ArrayData inv_cond; + RETURN_NOT_OK(InvertBoolArrayData(ctx, *batch[0].array(), &inv_cond)); + return IfElseFunctor::Call(ctx, inv_cond, *batch[2].scalar(), + *batch[1].array(), out->mutable_array()); + } } else { if (batch[2].kind() == Datum::ARRAY) { // SSA - return Status::Invalid(""); - // return IfElseFunctor::Call(ctx, *batch[0].scalar(), - // *batch[1].scalar(), - // *batch[2].array(), - // out->mutable_array()); + return IfElseFunctor::Call(ctx, *batch[0].scalar(), *batch[1].scalar(), + *batch[2].array(), out->mutable_array()); } else { // SSS return IfElseFunctor::Call(ctx, *batch[0].scalar(), *batch[1].scalar(), *batch[2].scalar(), out->scalar().get()); @@ -249,7 +430,7 @@ struct ResolveExec { void AddPrimitiveIfElseKernels(const std::shared_ptr& scalar_function, const std::vector>& types) { for (auto&& type : types) { - auto exec = internal::GenerateTypeAgnosticPrimitive(*type); + auto exec = internal::GenerateTypeAgnosticPrimitive(*type); // cond array needs to be boolean always ScalarKernel kernel({boolean(), type, type}, type, exec); kernel.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE; From 52679f35482e5799c69d6d44a8e6308506288b10 Mon Sep 17 00:00:00 2001 From: niranda perera Date: Tue, 25 May 2021 11:28:22 -0400 Subject: [PATCH 07/39] simplifying exec --- .../arrow/compute/kernels/scalar_if_else.cc | 228 ++++++++++-------- 1 file changed, 123 insertions(+), 105 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else.cc b/cpp/src/arrow/compute/kernels/scalar_if_else.cc index 252408fbd8af8..8a14411aea918 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else.cc @@ -29,11 +29,10 @@ namespace compute { namespace { -/* - * nulls will be promoted as follows - * - * cond.val && (cond.data && left.val || ~cond.data && right.val) - */ +// nulls will be promoted as follows: +// cond.val && (cond.data && left.val || ~cond.data && right.val) +// Note: we have to work on ArrayData. Otherwise we won't be able to handle array offsets +// AAA Status PromoteNulls(KernelContext* ctx, const ArrayData& cond, const ArrayData& left, const ArrayData& right, ArrayData* output) { if (!cond.MayHaveNulls() && !left.MayHaveNulls() && !right.MayHaveNulls()) { @@ -41,7 +40,7 @@ Status PromoteNulls(KernelContext* ctx, const ArrayData& cond, const ArrayData& } const int64_t len = cond.length; - // out_validity = ~cond.data + // out_validity = ~cond.data --> mask right values ARROW_ASSIGN_OR_RAISE( std::shared_ptr out_validity, arrow::internal::InvertBitmap(ctx->memory_pool(), cond.buffers[1]->data(), @@ -80,6 +79,7 @@ Status PromoteNulls(KernelContext* ctx, const ArrayData& cond, const ArrayData& } // cond.val && (cond.data && left.val || ~cond.data && right.val) +// ASA and AAS Status PromoteNulls(KernelContext* ctx, const ArrayData& cond, const Scalar& left, const ArrayData& right, ArrayData* output) { if (!cond.MayHaveNulls() && left.is_valid && !right.MayHaveNulls()) { @@ -92,20 +92,57 @@ Status PromoteNulls(KernelContext* ctx, const ArrayData& cond, const Scalar& lef std::shared_ptr out_validity, arrow::internal::InvertBitmap(ctx->memory_pool(), cond.buffers[1]->data(), cond.offset, len)); - + // out_validity = ~cond.data && right.val if (right.MayHaveNulls()) { // out_validity = right.val && ~cond.data arrow::internal::BitmapAnd(right.buffers[0]->data(), right.offset, out_validity->data(), 0, len, 0, out_validity->mutable_data()); } - // out_validity = cond.data || ~cond.data && right.val + // out_validity = cond.data && left.val || ~cond.data && right.val if (left.is_valid) { arrow::internal::BitmapOr(out_validity->data(), 0, cond.buffers[1]->data(), cond.offset, len, 0, out_validity->mutable_data()); } - // out_validity = cond.val && (cond.data || ~cond.data && right.val) + // out_validity = cond.val && (cond.data && left.val || ~cond.data && right.val) + if (cond.MayHaveNulls()) { + ::arrow::internal::BitmapAnd(out_validity->data(), 0, cond.buffers[0]->data(), + cond.offset, len, 0, out_validity->mutable_data()); + } + + output->buffers[0] = std::move(out_validity); + output->GetNullCount(); // update null count + return Status::OK(); +} + +// cond.val && (cond.data && left.val || ~cond.data && right.val) +// ASS +Status PromoteNulls(KernelContext* ctx, const ArrayData& cond, const Scalar& left, + const Scalar& right, ArrayData* output) { + if (!cond.MayHaveNulls() && left.is_valid && right.is_valid) { + return Status::OK(); // no nulls to handle + } + const int64_t len = cond.length; + + std::shared_ptr out_validity; + if (right.is_valid) { + // out_validity = ~cond.data + ARROW_ASSIGN_OR_RAISE( + out_validity, arrow::internal::InvertBitmap( + ctx->memory_pool(), cond.buffers[1]->data(), cond.offset, len)); + } else { + // out_validity = [0...] + ARROW_ASSIGN_OR_RAISE(out_validity, ctx->AllocateBitmap(len)); + } + + // out_validity = cond.data && left.val || ~cond.data && right.val + if (left.is_valid) { + arrow::internal::BitmapOr(out_validity->data(), 0, cond.buffers[1]->data(), + cond.offset, len, 0, out_validity->mutable_data()); + } + + // out_validity = cond.val && (cond.data && left.val || ~cond.data && right.val) if (cond.MayHaveNulls()) { ::arrow::internal::BitmapAnd(out_validity->data(), 0, cond.buffers[0]->data(), cond.offset, len, 0, out_validity->mutable_data()); @@ -120,6 +157,7 @@ Status PromoteNulls(KernelContext* ctx, const ArrayData& cond, const Scalar& lef // available outside Exec's scope Status InvertBoolArrayData(KernelContext* ctx, const ArrayData& input, ArrayData* output) { + // null buffer if (input.MayHaveNulls()) { output->buffers.emplace_back( SliceBuffer(input.buffers[0], input.offset, input.length)); @@ -127,6 +165,7 @@ Status InvertBoolArrayData(KernelContext* ctx, const ArrayData& input, output->buffers.push_back(NULLPTR); } + // data buffer ARROW_ASSIGN_OR_RAISE( std::shared_ptr inv_data, arrow::internal::InvertBitmap(ctx->memory_pool(), input.buffers[1]->data(), @@ -229,29 +268,41 @@ struct IfElseFunctor< // ASS static Status Call(KernelContext* ctx, const ArrayData& cond, const Scalar& left, const Scalar& right, ArrayData* out) { - // todo impl - return Status::OK(); - } +/* ARROW_RETURN_NOT_OK(PromoteNulls(ctx, cond, left, right, out)); - // SAA - static Status Call(KernelContext* ctx, const Scalar& cond, const ArrayData& left, - const ArrayData& right, ArrayData* out) { - *out = dynamic_cast(cond).value ? left : right; - return Status::OK(); - } + ARROW_ASSIGN_OR_RAISE(std::shared_ptr out_buf, + ctx->Allocate(cond.length * sizeof(T))); + T* out_values = reinterpret_cast(out_buf->mutable_data()); - // SSA and SAS - template - static Status Call(KernelContext* ctx, const Scalar& cond, const Scalar& left, - const ArrayData& right, ArrayData* out) { - // todo impl - return Status::OK(); - } + // copy right data to out_buff + const T* right_data = right.GetValues(1); + std::memcpy(out_values, right_data, right.length * sizeof(T)); - // SSS - static Status Call(KernelContext* ctx, const Scalar& cond, const Scalar& left, - const Scalar& right, Scalar* out) { - // todo impl + const auto* cond_data = cond.buffers[1]->data(); // this is a BoolArray + BitBlockCounter bit_counter(cond_data, cond.offset, cond.length); + + // selectively copy values from left data + T left_data = internal::UnboxScalar::Unbox(left); + int64_t offset = cond.offset; + + // todo this can be improved by intrinsics. ex: _mm*_mask_store_e* (vmovdqa*) + while (offset < cond.offset + cond.length) { + const BitBlockCount& block = bit_counter.NextWord(); + if (block.AllSet()) { // all from left + std::fill(out_values, out_values + block.length, left_data); + } else if (block.popcount) { // selectively copy from left + for (int64_t i = 0; i < block.length; ++i) { + if (BitUtil::GetBit(cond_data, offset + i)) { + out_values[i] = left_data; + } + } + } + + offset += block.length; + out_values += block.length; + } + + out->buffers[1] = std::move(out_buf);*/ return Status::OK(); } }; @@ -309,27 +360,6 @@ struct IfElseFunctor> { // todo impl return Status::OK(); } - - // SAS and SSA - static Status Call(KernelContext* ctx, const Scalar& cond, const ArrayData& left, - const ArrayData& right, ArrayData* out) { - // todo impl - return Status::OK(); - } - - // SAS and SSA - static Status Call(KernelContext* ctx, const Scalar& cond, const Scalar& left, - const ArrayData& right, ArrayData* out) { - // todo impl - return Status::OK(); - } - - // SSS - static Status Call(KernelContext* ctx, const Scalar& cond, const Scalar& left, - const Scalar& right, Scalar* out) { - // todo impl - return Status::OK(); - } }; template @@ -356,72 +386,60 @@ struct IfElseFunctor> { // ASS static Status Call(KernelContext* ctx, const ArrayData& cond, const Scalar& left, const Scalar& right, ArrayData* out) { - // todo: this could be dangerous if cond is created by calling InvertBoolArrayData - // because the inverted array may not be available outside Exec's scope return ReturnCopy(cond, out); } - - // SAA - static Status Call(KernelContext* ctx, const Scalar& cond, const ArrayData& left, - const ArrayData& right, ArrayData* out) { - return ReturnCopy(left, out); - } - - // SSA and SAS - static Status Call(KernelContext* ctx, const Scalar& cond, const Scalar& left, - const ArrayData& right, ArrayData* out) { - return ReturnCopy(right, out); - } - - // SSS - static Status Call(KernelContext* ctx, const Scalar& cond, const Scalar& left, - const Scalar& right, Scalar* out) { - return ReturnCopy(left, out); - } }; template struct ResolveIfElseExec { static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) { - if (batch[0].kind() == Datum::ARRAY) { - if (batch[1].kind() == Datum::ARRAY) { - if (batch[2].kind() == Datum::ARRAY) { // AAA - return IfElseFunctor::Call(ctx, *batch[0].array(), *batch[1].array(), - *batch[2].array(), out->mutable_array()); - } else { // AAS - ArrayData inv_cond; - RETURN_NOT_OK(InvertBoolArrayData(ctx, *batch[0].array(), &inv_cond)); - return IfElseFunctor::Call(ctx, inv_cond, *batch[2].scalar(), - *batch[1].array(), out->mutable_array()); + // cond is scalar + if (batch[0].is_scalar()) { + const auto& cond = batch[0].scalar_as(); + + if (batch[1].is_scalar() && batch[2].is_scalar()) { + if (cond.is_valid) { + *out = cond.value ? batch[1].scalar() : batch[2].scalar(); + } else { + *out = MakeNullScalar(batch[1].type()); } - } else { - if (batch[2].kind() == Datum::ARRAY) { // ASA - return IfElseFunctor::Call(ctx, *batch[0].array(), *batch[1].scalar(), - *batch[2].array(), out->mutable_array()); - } else { // ASS - return IfElseFunctor::Call(ctx, *batch[0].array(), *batch[1].scalar(), - *batch[2].scalar(), out->mutable_array()); + } else { // either left or right is an array. output is always an array + int64_t bcast_size = std::max(batch[1].length(), batch[2].length()); + if (cond.is_valid) { + const auto& valid_data = cond.value ? batch[1] : batch[2]; + if (valid_data.is_array()) { + *out = valid_data; + } else { // valid data is a scalar that needs to be broadcasted + ARROW_ASSIGN_OR_RAISE(*out, + MakeArrayFromScalar(*valid_data.scalar(), bcast_size, + ctx->memory_pool())); + } + } else { // cond is null. create null array + ARROW_ASSIGN_OR_RAISE( + *out, MakeArrayOfNull(batch[1].type(), bcast_size, ctx->memory_pool())) } } + return Status::OK(); + } + + // cond is array. Use functors to sort things out + if (batch[1].kind() == Datum::ARRAY) { + if (batch[2].kind() == Datum::ARRAY) { // AAA + return IfElseFunctor::Call(ctx, *batch[0].array(), *batch[1].array(), + *batch[2].array(), out->mutable_array()); + } else { // AAS + ArrayData inv_cond; + RETURN_NOT_OK(InvertBoolArrayData(ctx, *batch[0].array(), &inv_cond)); + return IfElseFunctor::Call(ctx, inv_cond, *batch[2].scalar(), + *batch[1].array(), out->mutable_array()); + } } else { - if (batch[1].kind() == Datum::ARRAY) { - if (batch[2].kind() == Datum::ARRAY) { // SAA - return IfElseFunctor::Call(ctx, *batch[0].scalar(), *batch[1].array(), - *batch[2].array(), out->mutable_array()); - } else { // SAS - ArrayData inv_cond; - RETURN_NOT_OK(InvertBoolArrayData(ctx, *batch[0].array(), &inv_cond)); - return IfElseFunctor::Call(ctx, inv_cond, *batch[2].scalar(), - *batch[1].array(), out->mutable_array()); - } - } else { - if (batch[2].kind() == Datum::ARRAY) { // SSA - return IfElseFunctor::Call(ctx, *batch[0].scalar(), *batch[1].scalar(), - *batch[2].array(), out->mutable_array()); - } else { // SSS - return IfElseFunctor::Call(ctx, *batch[0].scalar(), *batch[1].scalar(), - *batch[2].scalar(), out->scalar().get()); - } + if (batch[2].kind() == Datum::ARRAY) { // ASA + return IfElseFunctor::Call(ctx, *batch[0].array(), *batch[1].scalar(), + *batch[2].array(), out->mutable_array()); + } else { // ASS + return IfElseFunctor::Call(ctx, *batch[0].array(), *batch[1].scalar(), + *batch[2].scalar(), out->mutable_array()); } } } From 2f21e9e3f336c62c30d9008e446624312310b4ec Mon Sep 17 00:00:00 2001 From: niranda perera Date: Tue, 25 May 2021 17:21:04 -0400 Subject: [PATCH 08/39] promote nulls with visitor --- .../arrow/compute/kernels/scalar_if_else.cc | 227 +++++++++++++++--- 1 file changed, 187 insertions(+), 40 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else.cc b/cpp/src/arrow/compute/kernels/scalar_if_else.cc index 8a14411aea918..f90d6146ad141 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else.cc @@ -19,21 +19,168 @@ #include #include +#include "arrow/util/bitmap.h" #include "codegen_internal.h" namespace arrow { using internal::BitBlockCount; using internal::BitBlockCounter; +using internal::Bitmap; namespace compute { namespace { +// cond.val && (cond.data && left.val || ~cond.data && right.val) +enum IEBitmapIndex { C_VALID, C_DATA, L_VALID, R_VALID }; + +Status PromoteNullsNew1(KernelContext* ctx, const ArrayData& cond, const ArrayData& left, + const ArrayData& right, ArrayData* output) { + uint8_t flag = + !right.MayHaveNulls() * 4 + !left.MayHaveNulls() * 2 + !cond.MayHaveNulls(); + + Bitmap bitmaps[4]; + bitmaps[C_VALID] = {cond.buffers[0], cond.offset, cond.length}; + bitmaps[C_DATA] = {cond.buffers[1], cond.offset, cond.length}; + bitmaps[L_VALID] = {left.buffers[0], left.offset, left.length}; + bitmaps[R_VALID] = {right.buffers[0], right.offset, right.length}; + + uint64_t* out_validity = nullptr; + if (flag < 6) { + // there will be a validity buffer in the output + ARROW_ASSIGN_OR_RAISE(output->buffers[0], ctx->AllocateBitmap(cond.length)); + out_validity = output->GetMutableValues(0); + } + + // cond.val && (cond.data && left.val || ~cond.data && right.val) + int64_t i = 0; + switch (flag) { + case 7: // RLC = 111 + break; + case 6: // RLC = 110 + output->buffers[0] = cond.buffers[0]; + break; + case 5: // RLC = 101 + Bitmap::VisitWords({bitmaps[C_DATA], bitmaps[L_VALID]}, + [&](std::array words) { + out_validity[i] = (words[0] & words[1]) | ~words[0]; + i++; + }); + break; + case 4: // RLC = 100 + Bitmap::VisitWords({bitmaps[C_VALID], bitmaps[C_DATA], bitmaps[L_VALID]}, + [&](std::array words) { + out_validity[i] = + words[0] & ((words[1] & words[2]) | ~words[1]); + i++; + }); + break; + case 3: // RLC = 011 + Bitmap::VisitWords({bitmaps[C_DATA], bitmaps[R_VALID]}, + [&](std::array words) { + out_validity[i] = words[0] | (~words[0] & words[1]); + i++; + }); + break; + case 2: // RLC = 010 + Bitmap::VisitWords({bitmaps[C_VALID], bitmaps[C_DATA], bitmaps[R_VALID]}, + [&](std::array words) { + out_validity[i] = + words[0] & (words[1] | (~words[1] & words[2])); + i++; + }); + break; + case 1: // RLC = 001 + Bitmap::VisitWords({bitmaps[C_DATA], bitmaps[L_VALID], bitmaps[R_VALID]}, + [&](std::array words) { + out_validity[i] = + (words[0] & words[1]) | (~words[0] & words[2]); + i++; + }); + break; + case 0: // RLC = 000 + Bitmap::VisitWords(bitmaps, [&](std::array words) { + out_validity[i] = words[0] & ((words[1] & words[2]) | (~words[1] & words[3])); + i++; + }); + break; + } + return Status::OK(); +} + +/*Status PromoteNullsNew(KernelContext* ctx, const ArrayData& cond, const ArrayData& left, + const ArrayData& right, ArrayData* output) { + if (!cond.MayHaveNulls() && !left.MayHaveNulls() && !right.MayHaveNulls()) { + return Status::OK(); // no nulls to handle + } + + // there will be a validity buffer in the output + ARROW_ASSIGN_OR_RAISE(output->buffers[0], ctx->AllocateBitmap(cond.length)); + auto out_validity = output->GetMutableValues(0); + + Bitmap bitmaps[4]; + bitmaps[C_VALID] = {cond.buffers[0], cond.offset, cond.length}; + bitmaps[C_DATA] = {cond.buffers[1], cond.offset, cond.length}; + bitmaps[L_VALID] = {left.buffers[0], left.offset, left.length}; + bitmaps[R_VALID] = {right.buffers[0], right.offset, right.length}; + + uint8_t flag = + (cond.null_count == 0) * 4 + (left.null_count == 0) * 2 + (right.null_count == 0); + + int64_t i = 0; + switch (flag) { + case 0: // all have nulls + Bitmap::VisitWords(bitmaps, [&](std::array words) { + out_validity[i] = words[0] & ((words[1] & words[2]) | (~words[1] & words[3])); + i++; + }); + break; + case 1: // right all valid + Bitmap::VisitWords({bitmaps[C_VALID], bitmaps[C_DATA], bitmaps[L_VALID]}, + [&](std::array words) { + out_validity[i] = + words[0] & ((words[1] & words[2]) | ~words[1]); + i++; + }); + break; + case 2: // left all valid + Bitmap::VisitWords({bitmaps[C_VALID], bitmaps[C_DATA], bitmaps[R_VALID]}, + [&](std::array words) { + out_validity[i] = + words[0] & (words[1] | (~words[1] & words[2])); + i++; + }); + break; + case 3: // left, right all valid + *ou break; + + case 7: // all valid. nothing to do + return Status::OK(); + } + + if (cond.null_count == 0) { + } + + if (right.null_count == 0) { + Bitmap::VisitWords(bitmaps, [&](std::array words) { + apply(words[C_VALID], words[C_DATA], words[L_VALID], ~uint64_t(0)); + }); + return Status::OK(); + } + + DCHECK(left.null_count != 0 && right.null_count != 0); + Bitmap::VisitWords(bitmaps, [&](std::array words) { + apply(words[C_VALID], words[C_DATA], words[L_VALID], words[R_VALID]); + }); + + return Status::OK(); +}*/ + // nulls will be promoted as follows: // cond.val && (cond.data && left.val || ~cond.data && right.val) -// Note: we have to work on ArrayData. Otherwise we won't be able to handle array offsets -// AAA -Status PromoteNulls(KernelContext* ctx, const ArrayData& cond, const ArrayData& left, +// Note: we have to work on ArrayData. Otherwise we won't be able to handle array +// offsets AAA +/*Status PromoteNulls(KernelContext* ctx, const ArrayData& cond, const ArrayData& left, const ArrayData& right, ArrayData* output) { if (!cond.MayHaveNulls() && !left.MayHaveNulls() && !right.MayHaveNulls()) { return Status::OK(); // no nulls to handle @@ -76,7 +223,7 @@ Status PromoteNulls(KernelContext* ctx, const ArrayData& cond, const ArrayData& output->buffers[0] = std::move(out_validity); output->GetNullCount(); // update null count return Status::OK(); -} +}*/ // cond.val && (cond.data && left.val || ~cond.data && right.val) // ASA and AAS @@ -115,7 +262,7 @@ Status PromoteNulls(KernelContext* ctx, const ArrayData& cond, const Scalar& lef output->GetNullCount(); // update null count return Status::OK(); } - +/* // cond.val && (cond.data && left.val || ~cond.data && right.val) // ASS Status PromoteNulls(KernelContext* ctx, const ArrayData& cond, const Scalar& left, @@ -151,7 +298,7 @@ Status PromoteNulls(KernelContext* ctx, const ArrayData& cond, const Scalar& lef output->buffers[0] = std::move(out_validity); output->GetNullCount(); // update null count return Status::OK(); -} +}*/ // todo: this could be dangerous because the inverted arraydata buffer[1] may not be // available outside Exec's scope @@ -185,7 +332,7 @@ struct IfElseFunctor< // AAA static Status Call(KernelContext* ctx, const ArrayData& cond, const ArrayData& left, const ArrayData& right, ArrayData* out) { - ARROW_RETURN_NOT_OK(PromoteNulls(ctx, cond, left, right, out)); + ARROW_RETURN_NOT_OK(PromoteNullsNew1(ctx, cond, left, right, out)); ARROW_ASSIGN_OR_RAISE(std::shared_ptr out_buf, ctx->Allocate(cond.length * sizeof(T))); @@ -268,41 +415,41 @@ struct IfElseFunctor< // ASS static Status Call(KernelContext* ctx, const ArrayData& cond, const Scalar& left, const Scalar& right, ArrayData* out) { -/* ARROW_RETURN_NOT_OK(PromoteNulls(ctx, cond, left, right, out)); - - ARROW_ASSIGN_OR_RAISE(std::shared_ptr out_buf, - ctx->Allocate(cond.length * sizeof(T))); - T* out_values = reinterpret_cast(out_buf->mutable_data()); - - // copy right data to out_buff - const T* right_data = right.GetValues(1); - std::memcpy(out_values, right_data, right.length * sizeof(T)); - - const auto* cond_data = cond.buffers[1]->data(); // this is a BoolArray - BitBlockCounter bit_counter(cond_data, cond.offset, cond.length); - - // selectively copy values from left data - T left_data = internal::UnboxScalar::Unbox(left); - int64_t offset = cond.offset; - - // todo this can be improved by intrinsics. ex: _mm*_mask_store_e* (vmovdqa*) - while (offset < cond.offset + cond.length) { - const BitBlockCount& block = bit_counter.NextWord(); - if (block.AllSet()) { // all from left - std::fill(out_values, out_values + block.length, left_data); - } else if (block.popcount) { // selectively copy from left - for (int64_t i = 0; i < block.length; ++i) { - if (BitUtil::GetBit(cond_data, offset + i)) { - out_values[i] = left_data; + /* ARROW_RETURN_NOT_OK(PromoteNulls(ctx, cond, left, right, out)); + + ARROW_ASSIGN_OR_RAISE(std::shared_ptr out_buf, + ctx->Allocate(cond.length * sizeof(T))); + T* out_values = reinterpret_cast(out_buf->mutable_data()); + + // copy right data to out_buff + const T* right_data = right.GetValues(1); + std::memcpy(out_values, right_data, right.length * sizeof(T)); + + const auto* cond_data = cond.buffers[1]->data(); // this is a BoolArray + BitBlockCounter bit_counter(cond_data, cond.offset, cond.length); + + // selectively copy values from left data + T left_data = internal::UnboxScalar::Unbox(left); + int64_t offset = cond.offset; + + // todo this can be improved by intrinsics. ex: _mm*_mask_store_e* (vmovdqa*) + while (offset < cond.offset + cond.length) { + const BitBlockCount& block = bit_counter.NextWord(); + if (block.AllSet()) { // all from left + std::fill(out_values, out_values + block.length, left_data); + } else if (block.popcount) { // selectively copy from left + for (int64_t i = 0; i < block.length; ++i) { + if (BitUtil::GetBit(cond_data, offset + i)) { + out_values[i] = left_data; + } + } } - } - } - offset += block.length; - out_values += block.length; - } + offset += block.length; + out_values += block.length; + } - out->buffers[1] = std::move(out_buf);*/ + out->buffers[1] = std::move(out_buf);*/ return Status::OK(); } }; @@ -312,7 +459,7 @@ struct IfElseFunctor> { // AAA static Status Call(KernelContext* ctx, const ArrayData& cond, const ArrayData& left, const ArrayData& right, ArrayData* out) { - ARROW_RETURN_NOT_OK(PromoteNulls(ctx, cond, left, right, out)); + ARROW_RETURN_NOT_OK(PromoteNullsNew1(ctx, cond, left, right, out)); // out_buff = right & ~cond ARROW_ASSIGN_OR_RAISE(std::shared_ptr out_buf, From 6aa1f37629f70dfdc10a63cf0201b75671d4e673 Mon Sep 17 00:00:00 2001 From: niranda perera Date: Tue, 25 May 2021 17:48:59 -0400 Subject: [PATCH 09/39] extending test cases --- .../arrow/compute/kernels/scalar_if_else.cc | 30 +++++++++---------- .../compute/kernels/scalar_if_else_test.cc | 17 +++++++++-- 2 files changed, 30 insertions(+), 17 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else.cc b/cpp/src/arrow/compute/kernels/scalar_if_else.cc index f90d6146ad141..60e2618939d0a 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else.cc @@ -31,7 +31,7 @@ namespace compute { namespace { -// cond.val && (cond.data && left.val || ~cond.data && right.val) +// cond.valid && (cond.data && left.valid || ~cond.data && right.valid) enum IEBitmapIndex { C_VALID, C_DATA, L_VALID, R_VALID }; Status PromoteNullsNew1(KernelContext* ctx, const ArrayData& cond, const ArrayData& left, @@ -52,7 +52,7 @@ Status PromoteNullsNew1(KernelContext* ctx, const ArrayData& cond, const ArrayDa out_validity = output->GetMutableValues(0); } - // cond.val && (cond.data && left.val || ~cond.data && right.val) + // cond.valid && (cond.data && left.valid || ~cond.data && right.valid) int64_t i = 0; switch (flag) { case 7: // RLC = 111 @@ -177,7 +177,7 @@ Status PromoteNullsNew1(KernelContext* ctx, const ArrayData& cond, const ArrayDa }*/ // nulls will be promoted as follows: -// cond.val && (cond.data && left.val || ~cond.data && right.val) +// cond.valid && (cond.data && left.valid || ~cond.data && right.valid) // Note: we have to work on ArrayData. Otherwise we won't be able to handle array // offsets AAA /*Status PromoteNulls(KernelContext* ctx, const ArrayData& cond, const ArrayData& left, @@ -193,7 +193,7 @@ Status PromoteNullsNew1(KernelContext* ctx, const ArrayData& cond, const ArrayDa arrow::internal::InvertBitmap(ctx->memory_pool(), cond.buffers[1]->data(), cond.offset, len)); - if (right.MayHaveNulls()) { // out_validity = right.val && ~cond.data + if (right.MayHaveNulls()) { // out_validity = right.valid && ~cond.data arrow::internal::BitmapAnd(right.buffers[0]->data(), right.offset, out_validity->data(), 0, len, 0, out_validity->mutable_data()); @@ -201,7 +201,7 @@ Status PromoteNullsNew1(KernelContext* ctx, const ArrayData& cond, const ArrayDa std::shared_ptr tmp_buf; if (left.MayHaveNulls()) { - // tmp_buf = left.val && cond.data + // tmp_buf = left.valid && cond.data ARROW_ASSIGN_OR_RAISE( tmp_buf, arrow::internal::BitmapAnd(ctx->memory_pool(), left.buffers[0]->data(), left.offset, cond.buffers[1]->data(), @@ -210,12 +210,12 @@ Status PromoteNullsNew1(KernelContext* ctx, const ArrayData& cond, const ArrayDa tmp_buf = SliceBuffer(cond.buffers[1], cond.offset, cond.length); } - // out_validity = cond.data && left.val || ~cond.data && right.val + // out_validity = cond.data && left.valid || ~cond.data && right.valid arrow::internal::BitmapOr(out_validity->data(), 0, tmp_buf->data(), 0, len, 0, out_validity->mutable_data()); if (cond.MayHaveNulls()) { - // out_validity = cond.val && (cond.data && left.val || ~cond.data && right.val) + // out_validity = cond.valid && (cond.data && left.valid || ~cond.data && right.valid) ::arrow::internal::BitmapAnd(out_validity->data(), 0, cond.buffers[0]->data(), cond.offset, len, 0, out_validity->mutable_data()); } @@ -225,7 +225,7 @@ Status PromoteNullsNew1(KernelContext* ctx, const ArrayData& cond, const ArrayDa return Status::OK(); }*/ -// cond.val && (cond.data && left.val || ~cond.data && right.val) +// cond.valid && (cond.data && left.valid || ~cond.data && right.valid) // ASA and AAS Status PromoteNulls(KernelContext* ctx, const ArrayData& cond, const Scalar& left, const ArrayData& right, ArrayData* output) { @@ -239,20 +239,20 @@ Status PromoteNulls(KernelContext* ctx, const ArrayData& cond, const Scalar& lef std::shared_ptr out_validity, arrow::internal::InvertBitmap(ctx->memory_pool(), cond.buffers[1]->data(), cond.offset, len)); - // out_validity = ~cond.data && right.val - if (right.MayHaveNulls()) { // out_validity = right.val && ~cond.data + // out_validity = ~cond.data && right.valid + if (right.MayHaveNulls()) { // out_validity = right.valid && ~cond.data arrow::internal::BitmapAnd(right.buffers[0]->data(), right.offset, out_validity->data(), 0, len, 0, out_validity->mutable_data()); } - // out_validity = cond.data && left.val || ~cond.data && right.val + // out_validity = cond.data && left.valid || ~cond.data && right.valid if (left.is_valid) { arrow::internal::BitmapOr(out_validity->data(), 0, cond.buffers[1]->data(), cond.offset, len, 0, out_validity->mutable_data()); } - // out_validity = cond.val && (cond.data && left.val || ~cond.data && right.val) + // out_validity = cond.valid && (cond.data && left.valid || ~cond.data && right.valid) if (cond.MayHaveNulls()) { ::arrow::internal::BitmapAnd(out_validity->data(), 0, cond.buffers[0]->data(), cond.offset, len, 0, out_validity->mutable_data()); @@ -263,7 +263,7 @@ Status PromoteNulls(KernelContext* ctx, const ArrayData& cond, const Scalar& lef return Status::OK(); } /* -// cond.val && (cond.data && left.val || ~cond.data && right.val) +// cond.valid && (cond.data && left.valid || ~cond.data && right.valid) // ASS Status PromoteNulls(KernelContext* ctx, const ArrayData& cond, const Scalar& left, const Scalar& right, ArrayData* output) { @@ -283,13 +283,13 @@ Status PromoteNulls(KernelContext* ctx, const ArrayData& cond, const Scalar& lef ARROW_ASSIGN_OR_RAISE(out_validity, ctx->AllocateBitmap(len)); } - // out_validity = cond.data && left.val || ~cond.data && right.val + // out_validity = cond.data && left.valid || ~cond.data && right.valid if (left.is_valid) { arrow::internal::BitmapOr(out_validity->data(), 0, cond.buffers[1]->data(), cond.offset, len, 0, out_validity->mutable_data()); } - // out_validity = cond.val && (cond.data && left.val || ~cond.data && right.val) + // out_validity = cond.valid && (cond.data && left.valid || ~cond.data && right.valid) if (cond.MayHaveNulls()) { ::arrow::internal::BitmapAnd(out_validity->data(), 0, cond.buffers[0]->data(), cond.offset, len, 0, out_validity->mutable_data()); diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc index cf81ebf944192..3e55006f8835d 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc @@ -64,14 +64,27 @@ TYPED_TEST(TestIfElsePrimitive, IfElseFixedSize) { // No Nulls CheckIfElseOutputArray(type, "[]", "[]", "[]", "[]"); + // RLC = 111 CheckIfElseOutputArray(type, "[true, true, true, false]", "[1, 2, 3, 4]", "[5, 6, 7, 8]", "[1, 2, 3, 8]"); - + // RLC = 110 CheckIfElseOutputArray(type, "[true, true, null, false]", "[1, 2, 3, 4]", "[5, 6, 7, 8]", "[1, 2, null, 8]", false); - + // RLC = 100 + CheckIfElseOutputArray(type, "[true, true, null, false]", "[1, null, 3, 4]", + "[5, 6, 7, 8]", "[1, null, null, 8]", false); + // RLC = 011 + CheckIfElseOutputArray(type, "[true, true, true, false]", "[1, 2, 3, 4]", + "[5, 6, 7, null]", "[1, 2, 3, null]", false); + // RLC = 010 + CheckIfElseOutputArray(type, "[null, true, true, false]", "[1, 2, 3, 4]", + "[5, 6, 7, null]", "[null, 2, 3, null]", false); + // RLC = 001 CheckIfElseOutputArray(type, "[true, true, true, false]", "[1, 2, null, null]", "[null, 6, 7, null]", "[1, 2, null, null]", false); + // RLC = 000 + CheckIfElseOutputArray(type, "[null, true, true, false]", "[1, 2, null, null]", + "[null, 6, 7, null]", "[null, 2, null, null]", false); using ArrayType = typename TypeTraits::ArrayType; random::RandomArrayGenerator rand(/*seed=*/0); From e41152f5c282806e95547f6f283b39655fbed200 Mon Sep 17 00:00:00 2001 From: niranda perera Date: Wed, 26 May 2021 09:14:06 -0400 Subject: [PATCH 10/39] adding array-scalar null promotion --- .../arrow/compute/kernels/scalar_if_else.cc | 98 ++++++++++++++++--- 1 file changed, 87 insertions(+), 11 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else.cc b/cpp/src/arrow/compute/kernels/scalar_if_else.cc index 60e2618939d0a..9f92f9d2918d4 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else.cc @@ -34,6 +34,77 @@ namespace { // cond.valid && (cond.data && left.valid || ~cond.data && right.valid) enum IEBitmapIndex { C_VALID, C_DATA, L_VALID, R_VALID }; +Status PromoteNullsNew1(KernelContext* ctx, const ArrayData& cond, const ArrayData& left, + const Scalar& right, ArrayData* output) { + uint8_t flag = right.is_valid * 4 + !left.MayHaveNulls() * 2 + !cond.MayHaveNulls(); + + Bitmap bitmaps[3]; + bitmaps[C_VALID] = {cond.buffers[0], cond.offset, cond.length}; + bitmaps[C_DATA] = {cond.buffers[1], cond.offset, cond.length}; + bitmaps[L_VALID] = {left.buffers[0], left.offset, left.length}; + + uint64_t* out_validity = nullptr; + if (flag < 6) { + // there will be a validity buffer in the output + ARROW_ASSIGN_OR_RAISE(output->buffers[0], ctx->AllocateBitmap(cond.length)); + out_validity = output->GetMutableValues(0); + } + + // cond.valid && (cond.data && left.valid || ~cond.data && right.valid) + int64_t i = 0; + switch (flag) { + case 7: // RLC = 111 + break; + case 6: // RLC = 110 + output->buffers[0] = SliceBuffer(cond.buffers[0], cond.offset, cond.length); + break; + case 5: // RLC = 101 + Bitmap::VisitWords({bitmaps[C_DATA], bitmaps[L_VALID]}, + [&](std::array words) { + auto c_data = words[0], l_valid = words[1]; + out_validity[i] = (c_data & l_valid) | ~c_data; + i++; + }); + break; + case 4: // RLC = 100 + Bitmap::VisitWords({bitmaps[C_VALID], bitmaps[C_DATA], bitmaps[L_VALID]}, + [&](std::array words) { + auto c_valid = words[0], c_data = words[1], l_valid = words[2]; + out_validity[i] = c_valid & ((c_data & l_valid) | ~c_data); + i++; + }); + break; + case 3: // RLC = 011 + // only cond.data is passed + output->buffers[0] = SliceBuffer(cond.buffers[1], cond.offset, cond.length); + break; + case 2: // RLC = 010 + Bitmap::VisitWords({bitmaps[C_VALID], bitmaps[C_DATA]}, + [&](std::array words) { + auto c_valid = words[0], c_data = words[1]; + out_validity[i] = c_valid & c_data; + i++; + }); + break; + case 1: // RLC = 001 + Bitmap::VisitWords({bitmaps[C_DATA], bitmaps[L_VALID]}, + [&](std::array words) { + auto c_data = words[0], l_valid = words[1]; + out_validity[i] = (c_data & l_valid); + i++; + }); + break; + case 0: // RLC = 000 + Bitmap::VisitWords(bitmaps, [&](std::array words) { + auto c_valid = words[0], c_data = words[1], l_valid = words[2]; + out_validity[i] = c_valid & c_data & l_valid; + i++; + }); + break; + } + return Status::OK(); +} + Status PromoteNullsNew1(KernelContext* ctx, const ArrayData& cond, const ArrayData& left, const ArrayData& right, ArrayData* output) { uint8_t flag = @@ -58,49 +129,53 @@ Status PromoteNullsNew1(KernelContext* ctx, const ArrayData& cond, const ArrayDa case 7: // RLC = 111 break; case 6: // RLC = 110 - output->buffers[0] = cond.buffers[0]; + output->buffers[0] = SliceBuffer(cond.buffers[0], cond.offset, cond.length); break; case 5: // RLC = 101 Bitmap::VisitWords({bitmaps[C_DATA], bitmaps[L_VALID]}, [&](std::array words) { - out_validity[i] = (words[0] & words[1]) | ~words[0]; + auto c_data = words[0], l_valid = words[1]; + out_validity[i] = (c_data & l_valid) | ~c_data; i++; }); break; case 4: // RLC = 100 Bitmap::VisitWords({bitmaps[C_VALID], bitmaps[C_DATA], bitmaps[L_VALID]}, [&](std::array words) { - out_validity[i] = - words[0] & ((words[1] & words[2]) | ~words[1]); + auto c_valid = words[0], c_data = words[1], l_valid = words[2]; + out_validity[i] = c_valid & ((c_data & l_valid) | ~c_data); i++; }); break; case 3: // RLC = 011 Bitmap::VisitWords({bitmaps[C_DATA], bitmaps[R_VALID]}, [&](std::array words) { - out_validity[i] = words[0] | (~words[0] & words[1]); + auto c_data = words[0], r_valid = words[1]; + out_validity[i] = c_data | (~c_data & r_valid); i++; }); break; case 2: // RLC = 010 Bitmap::VisitWords({bitmaps[C_VALID], bitmaps[C_DATA], bitmaps[R_VALID]}, [&](std::array words) { - out_validity[i] = - words[0] & (words[1] | (~words[1] & words[2])); + auto c_valid = words[0], c_data = words[1], r_valid = words[2]; + out_validity[i] = c_valid & (c_data | (~c_data & r_valid)); i++; }); break; case 1: // RLC = 001 Bitmap::VisitWords({bitmaps[C_DATA], bitmaps[L_VALID], bitmaps[R_VALID]}, [&](std::array words) { - out_validity[i] = - (words[0] & words[1]) | (~words[0] & words[2]); + auto c_data = words[0], l_valid = words[1], r_valid = words[2]; + out_validity[i] = (c_data & l_valid) | (~c_data & r_valid); i++; }); break; case 0: // RLC = 000 Bitmap::VisitWords(bitmaps, [&](std::array words) { - out_validity[i] = words[0] & ((words[1] & words[2]) | (~words[1] & words[3])); + auto c_valid = words[0], c_data = words[1], l_valid = words[2], + r_valid = words[3]; + out_validity[i] = c_valid & ((c_data & l_valid) | (~c_data & r_valid)); i++; }); break; @@ -374,7 +449,8 @@ struct IfElseFunctor< // ASA and AAS static Status Call(KernelContext* ctx, const ArrayData& cond, const Scalar& left, const ArrayData& right, ArrayData* out) { - ARROW_RETURN_NOT_OK(PromoteNulls(ctx, cond, left, right, out)); + // todo change this! scalar and array is swapped just for compilation + ARROW_RETURN_NOT_OK(PromoteNullsNew1(ctx, cond, right, left, out)); ARROW_ASSIGN_OR_RAISE(std::shared_ptr out_buf, ctx->Allocate(cond.length * sizeof(T))); From 172a824fb0ddbfa5da9c3ca3b1c547ebd1c1a204 Mon Sep 17 00:00:00 2001 From: niranda perera Date: Wed, 26 May 2021 09:56:18 -0400 Subject: [PATCH 11/39] adding scalar-scalar null promotion --- .../arrow/compute/kernels/scalar_if_else.cc | 125 ++++++++++++------ 1 file changed, 86 insertions(+), 39 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else.cc b/cpp/src/arrow/compute/kernels/scalar_if_else.cc index 9f92f9d2918d4..16b3696dafe45 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else.cc @@ -34,6 +34,55 @@ namespace { // cond.valid && (cond.data && left.valid || ~cond.data && right.valid) enum IEBitmapIndex { C_VALID, C_DATA, L_VALID, R_VALID }; +Status PromoteNullsNew1(KernelContext* ctx, const ArrayData& cond, const Scalar& left, + const Scalar& right, ArrayData* output) { + uint8_t flag = right.is_valid * 4 + left.is_valid * 2 + !cond.MayHaveNulls(); + + if (flag < 6 && flag != 3) { + // there will be a validity buffer in the output + ARROW_ASSIGN_OR_RAISE(output->buffers[0], ctx->AllocateBitmap(cond.length)); + } + + // cond.valid && (cond.data && left.valid || ~cond.data && right.valid) + int64_t i = 0; + switch (flag) { + case 7: // RLC = 111 + break; + case 6: // RLC = 110 + // out_valid = c_valid + output->buffers[0] = SliceBuffer(cond.buffers[0], cond.offset, cond.length); + break; + case 5: // RLC = 101 + // out_valid = ~cond.data + arrow::internal::InvertBitmap(cond.buffers[1]->data(), cond.offset, cond.length, + output->buffers[0]->mutable_data(), 0); + break; + case 4: // RLC = 100 + // out_valid = c_valid & ~cond.data + arrow::internal::BitmapAndNot(cond.buffers[0]->data(), cond.offset, + cond.buffers[1]->data(), cond.offset, cond.length, 0, + output->buffers[0]->mutable_data()); + break; + case 3: // RLC = 011 + // out_valid = cond.data + output->buffers[0] = SliceBuffer(cond.buffers[1], cond.offset, cond.length); + break; + case 2: // RLC = 010 + // out_valid = cond.valid & cond.data + arrow::internal::BitmapAnd(cond.buffers[0]->data(), cond.offset, + cond.buffers[1]->data(), cond.offset, cond.length, 0, + output->buffers[0]->mutable_data()); + break; + case 1: // RLC = 001 + // out_valid = 0 --> nothing to do; but requires out_valid to be a all-zero buffer + break; + case 0: // RLC = 000 + // out_valid = 0 --> nothing to do; but requires out_valid to be a all-zero buffer + break; + } + return Status::OK(); +} + Status PromoteNullsNew1(KernelContext* ctx, const ArrayData& cond, const ArrayData& left, const Scalar& right, ArrayData* output) { uint8_t flag = right.is_valid * 4 + !left.MayHaveNulls() * 2 + !cond.MayHaveNulls(); @@ -44,7 +93,7 @@ Status PromoteNullsNew1(KernelContext* ctx, const ArrayData& cond, const ArrayDa bitmaps[L_VALID] = {left.buffers[0], left.offset, left.length}; uint64_t* out_validity = nullptr; - if (flag < 6) { + if (flag < 6 && flag != 3) { // there will be a validity buffer in the output ARROW_ASSIGN_OR_RAISE(output->buffers[0], ctx->AllocateBitmap(cond.length)); out_validity = output->GetMutableValues(0); @@ -79,12 +128,10 @@ Status PromoteNullsNew1(KernelContext* ctx, const ArrayData& cond, const ArrayDa output->buffers[0] = SliceBuffer(cond.buffers[1], cond.offset, cond.length); break; case 2: // RLC = 010 - Bitmap::VisitWords({bitmaps[C_VALID], bitmaps[C_DATA]}, - [&](std::array words) { - auto c_valid = words[0], c_data = words[1]; - out_validity[i] = c_valid & c_data; - i++; - }); + // out_valid = cond.valid & cond.data + arrow::internal::BitmapAnd(cond.buffers[0]->data(), cond.offset, + cond.buffers[1]->data(), cond.offset, cond.length, 0, + output->buffers[0]->mutable_data()); break; case 1: // RLC = 001 Bitmap::VisitWords({bitmaps[C_DATA], bitmaps[L_VALID]}, @@ -491,41 +538,41 @@ struct IfElseFunctor< // ASS static Status Call(KernelContext* ctx, const ArrayData& cond, const Scalar& left, const Scalar& right, ArrayData* out) { - /* ARROW_RETURN_NOT_OK(PromoteNulls(ctx, cond, left, right, out)); - - ARROW_ASSIGN_OR_RAISE(std::shared_ptr out_buf, - ctx->Allocate(cond.length * sizeof(T))); - T* out_values = reinterpret_cast(out_buf->mutable_data()); - - // copy right data to out_buff - const T* right_data = right.GetValues(1); - std::memcpy(out_values, right_data, right.length * sizeof(T)); - - const auto* cond_data = cond.buffers[1]->data(); // this is a BoolArray - BitBlockCounter bit_counter(cond_data, cond.offset, cond.length); - - // selectively copy values from left data - T left_data = internal::UnboxScalar::Unbox(left); - int64_t offset = cond.offset; - - // todo this can be improved by intrinsics. ex: _mm*_mask_store_e* (vmovdqa*) - while (offset < cond.offset + cond.length) { - const BitBlockCount& block = bit_counter.NextWord(); - if (block.AllSet()) { // all from left - std::fill(out_values, out_values + block.length, left_data); - } else if (block.popcount) { // selectively copy from left - for (int64_t i = 0; i < block.length; ++i) { - if (BitUtil::GetBit(cond_data, offset + i)) { - out_values[i] = left_data; + ARROW_RETURN_NOT_OK(PromoteNullsNew1(ctx, cond, left, right, out)); + /* + ARROW_ASSIGN_OR_RAISE(std::shared_ptr out_buf, + ctx->Allocate(cond.length * sizeof(T))); + T* out_values = reinterpret_cast(out_buf->mutable_data()); + + // copy right data to out_buff + const T* right_data = right.GetValues(1); + std::memcpy(out_values, right_data, right.length * sizeof(T)); + + const auto* cond_data = cond.buffers[1]->data(); // this is a BoolArray + BitBlockCounter bit_counter(cond_data, cond.offset, cond.length); + + // selectively copy values from left data + T left_data = internal::UnboxScalar::Unbox(left); + int64_t offset = cond.offset; + + // todo this can be improved by intrinsics. ex: _mm*_mask_store_e* (vmovdqa*) + while (offset < cond.offset + cond.length) { + const BitBlockCount& block = bit_counter.NextWord(); + if (block.AllSet()) { // all from left + std::fill(out_values, out_values + block.length, left_data); + } else if (block.popcount) { // selectively copy from left + for (int64_t i = 0; i < block.length; ++i) { + if (BitUtil::GetBit(cond_data, offset + i)) { + out_values[i] = left_data; + } + } } - } - } - offset += block.length; - out_values += block.length; - } + offset += block.length; + out_values += block.length; + } - out->buffers[1] = std::move(out_buf);*/ + out->buffers[1] = std::move(out_buf);*/ return Status::OK(); } }; From 768d630b627644baadd38a086b4bb7c38bbf34c3 Mon Sep 17 00:00:00 2001 From: niranda perera Date: Wed, 26 May 2021 14:08:47 -0400 Subject: [PATCH 12/39] refactor --- .../arrow/compute/kernels/scalar_if_else.cc | 22 ++++++++++--------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else.cc b/cpp/src/arrow/compute/kernels/scalar_if_else.cc index 16b3696dafe45..af23c5c935efa 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else.cc @@ -34,8 +34,8 @@ namespace { // cond.valid && (cond.data && left.valid || ~cond.data && right.valid) enum IEBitmapIndex { C_VALID, C_DATA, L_VALID, R_VALID }; -Status PromoteNullsNew1(KernelContext* ctx, const ArrayData& cond, const Scalar& left, - const Scalar& right, ArrayData* output) { +Status PromoteNullsVisitor(KernelContext* ctx, const ArrayData& cond, const Scalar& left, + const Scalar& right, ArrayData* output) { uint8_t flag = right.is_valid * 4 + left.is_valid * 2 + !cond.MayHaveNulls(); if (flag < 6 && flag != 3) { @@ -83,8 +83,9 @@ Status PromoteNullsNew1(KernelContext* ctx, const ArrayData& cond, const Scalar& return Status::OK(); } -Status PromoteNullsNew1(KernelContext* ctx, const ArrayData& cond, const ArrayData& left, - const Scalar& right, ArrayData* output) { +Status PromoteNullsVisitor(KernelContext* ctx, const ArrayData& cond, + const ArrayData& left, const Scalar& right, + ArrayData* output) { uint8_t flag = right.is_valid * 4 + !left.MayHaveNulls() * 2 + !cond.MayHaveNulls(); Bitmap bitmaps[3]; @@ -152,8 +153,9 @@ Status PromoteNullsNew1(KernelContext* ctx, const ArrayData& cond, const ArrayDa return Status::OK(); } -Status PromoteNullsNew1(KernelContext* ctx, const ArrayData& cond, const ArrayData& left, - const ArrayData& right, ArrayData* output) { +Status PromoteNullsVisitor(KernelContext* ctx, const ArrayData& cond, + const ArrayData& left, const ArrayData& right, + ArrayData* output) { uint8_t flag = !right.MayHaveNulls() * 4 + !left.MayHaveNulls() * 2 + !cond.MayHaveNulls(); @@ -454,7 +456,7 @@ struct IfElseFunctor< // AAA static Status Call(KernelContext* ctx, const ArrayData& cond, const ArrayData& left, const ArrayData& right, ArrayData* out) { - ARROW_RETURN_NOT_OK(PromoteNullsNew1(ctx, cond, left, right, out)); + ARROW_RETURN_NOT_OK(PromoteNullsVisitor(ctx, cond, left, right, out)); ARROW_ASSIGN_OR_RAISE(std::shared_ptr out_buf, ctx->Allocate(cond.length * sizeof(T))); @@ -497,7 +499,7 @@ struct IfElseFunctor< static Status Call(KernelContext* ctx, const ArrayData& cond, const Scalar& left, const ArrayData& right, ArrayData* out) { // todo change this! scalar and array is swapped just for compilation - ARROW_RETURN_NOT_OK(PromoteNullsNew1(ctx, cond, right, left, out)); + ARROW_RETURN_NOT_OK(PromoteNullsVisitor(ctx, cond, right, left, out)); ARROW_ASSIGN_OR_RAISE(std::shared_ptr out_buf, ctx->Allocate(cond.length * sizeof(T))); @@ -538,7 +540,7 @@ struct IfElseFunctor< // ASS static Status Call(KernelContext* ctx, const ArrayData& cond, const Scalar& left, const Scalar& right, ArrayData* out) { - ARROW_RETURN_NOT_OK(PromoteNullsNew1(ctx, cond, left, right, out)); + ARROW_RETURN_NOT_OK(PromoteNullsVisitor(ctx, cond, left, right, out)); /* ARROW_ASSIGN_OR_RAISE(std::shared_ptr out_buf, ctx->Allocate(cond.length * sizeof(T))); @@ -582,7 +584,7 @@ struct IfElseFunctor> { // AAA static Status Call(KernelContext* ctx, const ArrayData& cond, const ArrayData& left, const ArrayData& right, ArrayData* out) { - ARROW_RETURN_NOT_OK(PromoteNullsNew1(ctx, cond, left, right, out)); + ARROW_RETURN_NOT_OK(PromoteNullsVisitor(ctx, cond, left, right, out)); // out_buff = right & ~cond ARROW_ASSIGN_OR_RAISE(std::shared_ptr out_buf, From 26097349230aa0953a2833b9f5d1c701f3fbb1eb Mon Sep 17 00:00:00 2001 From: niranda perera Date: Wed, 26 May 2021 20:48:39 -0400 Subject: [PATCH 13/39] readability improvements --- .../arrow/compute/kernels/scalar_if_else.cc | 474 ++++++++++-------- 1 file changed, 278 insertions(+), 196 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else.cc b/cpp/src/arrow/compute/kernels/scalar_if_else.cc index af23c5c935efa..87c5f585c2ea2 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else.cc @@ -31,9 +31,11 @@ namespace compute { namespace { -// cond.valid && (cond.data && left.valid || ~cond.data && right.valid) -enum IEBitmapIndex { C_VALID, C_DATA, L_VALID, R_VALID }; +enum { COND_ALL_VALID = 1, LEFT_ALL_VALID = 2, RIGHT_ALL_VALID = 4 }; +// if the condition is null then output is null otherwise we take validity from the +// selected argument +// ie. cond.valid & (cond.data & left.valid | ~cond.data & right.valid) Status PromoteNullsVisitor(KernelContext* ctx, const ArrayData& cond, const Scalar& left, const Scalar& right, ArrayData* output) { uint8_t flag = right.is_valid * 4 + left.is_valid * 2 + !cond.MayHaveNulls(); @@ -43,37 +45,38 @@ Status PromoteNullsVisitor(KernelContext* ctx, const ArrayData& cond, const Scal ARROW_ASSIGN_OR_RAISE(output->buffers[0], ctx->AllocateBitmap(cond.length)); } - // cond.valid && (cond.data && left.valid || ~cond.data && right.valid) - int64_t i = 0; + // if the condition is null then output is null otherwise we take validity from the + // selected argument + // ie. cond.valid & (cond.data & left.valid | ~cond.data & right.valid) switch (flag) { - case 7: // RLC = 111 + case COND_ALL_VALID | LEFT_ALL_VALID | RIGHT_ALL_VALID: // = 7 break; - case 6: // RLC = 110 + case LEFT_ALL_VALID | RIGHT_ALL_VALID: // = 6 // out_valid = c_valid output->buffers[0] = SliceBuffer(cond.buffers[0], cond.offset, cond.length); break; - case 5: // RLC = 101 + case COND_ALL_VALID | RIGHT_ALL_VALID: // = 5 // out_valid = ~cond.data arrow::internal::InvertBitmap(cond.buffers[1]->data(), cond.offset, cond.length, output->buffers[0]->mutable_data(), 0); break; - case 4: // RLC = 100 + case RIGHT_ALL_VALID: // = 4 // out_valid = c_valid & ~cond.data arrow::internal::BitmapAndNot(cond.buffers[0]->data(), cond.offset, cond.buffers[1]->data(), cond.offset, cond.length, 0, output->buffers[0]->mutable_data()); break; - case 3: // RLC = 011 + case COND_ALL_VALID | LEFT_ALL_VALID: // = 3 // out_valid = cond.data output->buffers[0] = SliceBuffer(cond.buffers[1], cond.offset, cond.length); break; - case 2: // RLC = 010 + case LEFT_ALL_VALID: // = 2 // out_valid = cond.valid & cond.data arrow::internal::BitmapAnd(cond.buffers[0]->data(), cond.offset, cond.buffers[1]->data(), cond.offset, cond.length, 0, output->buffers[0]->mutable_data()); break; - case 1: // RLC = 001 + case COND_ALL_VALID: // = 1 // out_valid = 0 --> nothing to do; but requires out_valid to be a all-zero buffer break; case 0: // RLC = 000 @@ -88,6 +91,8 @@ Status PromoteNullsVisitor(KernelContext* ctx, const ArrayData& cond, ArrayData* output) { uint8_t flag = right.is_valid * 4 + !left.MayHaveNulls() * 2 + !cond.MayHaveNulls(); + enum { C_VALID, C_DATA, L_VALID }; + Bitmap bitmaps[3]; bitmaps[C_VALID] = {cond.buffers[0], cond.offset, cond.length}; bitmaps[C_DATA] = {cond.buffers[1], cond.offset, cond.length}; @@ -100,69 +105,72 @@ Status PromoteNullsVisitor(KernelContext* ctx, const ArrayData& cond, out_validity = output->GetMutableValues(0); } - // cond.valid && (cond.data && left.valid || ~cond.data && right.valid) + // lambda function that will be used inside the visitor int64_t i = 0; + auto apply = [&](uint64_t c_valid, uint64_t c_data, uint64_t l_valid, + uint64_t r_valid) { + out_validity[i] = c_valid & ((c_data & l_valid) | (~c_data & r_valid)); + i++; + }; + + // if the condition is null then output is null otherwise we take validity from the + // selected argument + // ie. cond.valid & (cond.data & left.valid | ~cond.data & right.valid) switch (flag) { - case 7: // RLC = 111 + case COND_ALL_VALID | LEFT_ALL_VALID | RIGHT_ALL_VALID: // RLC = 111 break; - case 6: // RLC = 110 + case LEFT_ALL_VALID | RIGHT_ALL_VALID: // RLC = 110 output->buffers[0] = SliceBuffer(cond.buffers[0], cond.offset, cond.length); break; - case 5: // RLC = 101 - Bitmap::VisitWords({bitmaps[C_DATA], bitmaps[L_VALID]}, - [&](std::array words) { - auto c_data = words[0], l_valid = words[1]; - out_validity[i] = (c_data & l_valid) | ~c_data; - i++; - }); - break; - case 4: // RLC = 100 - Bitmap::VisitWords({bitmaps[C_VALID], bitmaps[C_DATA], bitmaps[L_VALID]}, - [&](std::array words) { - auto c_valid = words[0], c_data = words[1], l_valid = words[2]; - out_validity[i] = c_valid & ((c_data & l_valid) | ~c_data); - i++; - }); - break; - case 3: // RLC = 011 + case COND_ALL_VALID | RIGHT_ALL_VALID: // RLC = 101 + // bitmaps[C_VALID] might be null; override to make it safe for Visit() + bitmaps[C_VALID] = bitmaps[C_DATA]; + Bitmap::VisitWords(bitmaps, [&](std::array words) { + apply(UINT64_MAX, words[C_DATA], words[L_VALID], UINT64_MAX); + }); + break; + case RIGHT_ALL_VALID: // RLC = 100 + Bitmap::VisitWords(bitmaps, [&](std::array words) { + apply(words[C_VALID], words[C_DATA], words[L_VALID], UINT64_MAX); + }); + break; + case COND_ALL_VALID | LEFT_ALL_VALID: // RLC = 011 // only cond.data is passed output->buffers[0] = SliceBuffer(cond.buffers[1], cond.offset, cond.length); break; - case 2: // RLC = 010 + case LEFT_ALL_VALID: // RLC = 010 // out_valid = cond.valid & cond.data arrow::internal::BitmapAnd(cond.buffers[0]->data(), cond.offset, cond.buffers[1]->data(), cond.offset, cond.length, 0, output->buffers[0]->mutable_data()); break; - case 1: // RLC = 001 - Bitmap::VisitWords({bitmaps[C_DATA], bitmaps[L_VALID]}, - [&](std::array words) { - auto c_data = words[0], l_valid = words[1]; - out_validity[i] = (c_data & l_valid); - i++; - }); + case COND_ALL_VALID: // RLC = 001 + // out_valid = cond.data & left.valid + arrow::internal::BitmapAnd(cond.buffers[1]->data(), cond.offset, + left.buffers[0]->data(), left.offset, cond.length, 0, + output->buffers[0]->mutable_data()); break; case 0: // RLC = 000 Bitmap::VisitWords(bitmaps, [&](std::array words) { - auto c_valid = words[0], c_data = words[1], l_valid = words[2]; - out_validity[i] = c_valid & c_data & l_valid; - i++; + apply(words[C_VALID], words[C_DATA], words[L_VALID], 0); }); break; } return Status::OK(); } -Status PromoteNullsVisitor(KernelContext* ctx, const ArrayData& cond, - const ArrayData& left, const ArrayData& right, - ArrayData* output) { - uint8_t flag = - !right.MayHaveNulls() * 4 + !left.MayHaveNulls() * 2 + !cond.MayHaveNulls(); +// if the condition is null then output is null otherwise we take validity from the +// selected argument +// ie. cond.valid & (cond.data & left.valid | ~cond.data & right.valid) +Status PromoteNullsVisitor(KernelContext* ctx, const ArrayData& cond, const Scalar& left, + const ArrayData& right, ArrayData* output) { + uint8_t flag = !right.MayHaveNulls() * 4 + left.is_valid * 2 + !cond.MayHaveNulls(); - Bitmap bitmaps[4]; + enum { C_VALID, C_DATA, R_VALID }; + + Bitmap bitmaps[3]; bitmaps[C_VALID] = {cond.buffers[0], cond.offset, cond.length}; bitmaps[C_DATA] = {cond.buffers[1], cond.offset, cond.length}; - bitmaps[L_VALID] = {left.buffers[0], left.offset, left.length}; bitmaps[R_VALID] = {right.buffers[0], right.offset, right.length}; uint64_t* out_validity = nullptr; @@ -172,75 +180,71 @@ Status PromoteNullsVisitor(KernelContext* ctx, const ArrayData& cond, out_validity = output->GetMutableValues(0); } - // cond.valid && (cond.data && left.valid || ~cond.data && right.valid) + // lambda function that will be used inside the visitor int64_t i = 0; + auto apply = [&](uint64_t c_valid, uint64_t c_data, uint64_t l_valid, + uint64_t r_valid) { + out_validity[i] = c_valid & ((c_data & l_valid) | (~c_data & r_valid)); + i++; + }; + + // if the condition is null then output is null otherwise we take validity from the + // selected argument + // ie. cond.valid & (cond.data & left.valid | ~cond.data & right.valid) switch (flag) { - case 7: // RLC = 111 + case COND_ALL_VALID | LEFT_ALL_VALID | RIGHT_ALL_VALID: // RLC = 111 break; - case 6: // RLC = 110 + case LEFT_ALL_VALID | RIGHT_ALL_VALID: // RLC = 110 output->buffers[0] = SliceBuffer(cond.buffers[0], cond.offset, cond.length); break; - case 5: // RLC = 101 - Bitmap::VisitWords({bitmaps[C_DATA], bitmaps[L_VALID]}, - [&](std::array words) { - auto c_data = words[0], l_valid = words[1]; - out_validity[i] = (c_data & l_valid) | ~c_data; - i++; - }); - break; - case 4: // RLC = 100 - Bitmap::VisitWords({bitmaps[C_VALID], bitmaps[C_DATA], bitmaps[L_VALID]}, - [&](std::array words) { - auto c_valid = words[0], c_data = words[1], l_valid = words[2]; - out_validity[i] = c_valid & ((c_data & l_valid) | ~c_data); - i++; - }); - break; - case 3: // RLC = 011 - Bitmap::VisitWords({bitmaps[C_DATA], bitmaps[R_VALID]}, - [&](std::array words) { - auto c_data = words[0], r_valid = words[1]; - out_validity[i] = c_data | (~c_data & r_valid); - i++; - }); - break; - case 2: // RLC = 010 - Bitmap::VisitWords({bitmaps[C_VALID], bitmaps[C_DATA], bitmaps[R_VALID]}, - [&](std::array words) { - auto c_valid = words[0], c_data = words[1], r_valid = words[2]; - out_validity[i] = c_valid & (c_data | (~c_data & r_valid)); - i++; - }); - break; - case 1: // RLC = 001 - Bitmap::VisitWords({bitmaps[C_DATA], bitmaps[L_VALID], bitmaps[R_VALID]}, - [&](std::array words) { - auto c_data = words[0], l_valid = words[1], r_valid = words[2]; - out_validity[i] = (c_data & l_valid) | (~c_data & r_valid); - i++; - }); + case COND_ALL_VALID | RIGHT_ALL_VALID: // RLC = 101 + // out_valid = ~cond.data + arrow::internal::InvertBitmap(cond.buffers[1]->data(), cond.offset, cond.length, + output->buffers[0]->mutable_data(), 0); + break; + case RIGHT_ALL_VALID: // RLC = 100 + // out_valid = c_valid & ~cond.data + arrow::internal::BitmapAndNot(cond.buffers[0]->data(), cond.offset, + cond.buffers[1]->data(), cond.offset, cond.length, 0, + output->buffers[0]->mutable_data()); + break; + case COND_ALL_VALID | LEFT_ALL_VALID: // RLC = 011 + // bitmaps[C_VALID] might be null; override to make it safe for Visit() + bitmaps[C_VALID] = bitmaps[C_DATA]; + Bitmap::VisitWords(bitmaps, [&](std::array words) { + apply(UINT64_MAX, words[C_DATA], UINT64_MAX, words[R_VALID]); + }); + break; + case LEFT_ALL_VALID: // RLC = 010 + Bitmap::VisitWords(bitmaps, [&](std::array words) { + apply(words[C_VALID], words[C_DATA], UINT64_MAX, words[R_VALID]); + }); + break; + case COND_ALL_VALID: // RLC = 001 + // out_valid = ~cond.data & right.valid + arrow::internal::BitmapAndNot(right.buffers[0]->data(), right.offset, + cond.buffers[1]->data(), cond.offset, cond.length, 0, + output->buffers[0]->mutable_data()); break; case 0: // RLC = 000 - Bitmap::VisitWords(bitmaps, [&](std::array words) { - auto c_valid = words[0], c_data = words[1], l_valid = words[2], - r_valid = words[3]; - out_validity[i] = c_valid & ((c_data & l_valid) | (~c_data & r_valid)); - i++; + Bitmap::VisitWords(bitmaps, [&](std::array words) { + apply(words[C_VALID], words[C_DATA], 0, words[R_VALID]); }); break; } return Status::OK(); } -/*Status PromoteNullsNew(KernelContext* ctx, const ArrayData& cond, const ArrayData& left, - const ArrayData& right, ArrayData* output) { - if (!cond.MayHaveNulls() && !left.MayHaveNulls() && !right.MayHaveNulls()) { - return Status::OK(); // no nulls to handle - } +// if the condition is null then output is null otherwise we take validity from the +// selected argument +// ie. cond.valid & (cond.data & left.valid | ~cond.data & right.valid) +Status PromoteNullsVisitor(KernelContext* ctx, const ArrayData& cond, + const ArrayData& left, const ArrayData& right, + ArrayData* output) { + uint8_t flag = + !right.MayHaveNulls() * 4 + !left.MayHaveNulls() * 2 + !cond.MayHaveNulls(); - // there will be a validity buffer in the output - ARROW_ASSIGN_OR_RAISE(output->buffers[0], ctx->AllocateBitmap(cond.length)); - auto out_validity = output->GetMutableValues(0); + enum { C_VALID, C_DATA, L_VALID, R_VALID }; Bitmap bitmaps[4]; bitmaps[C_VALID] = {cond.buffers[0], cond.offset, cond.length}; @@ -248,57 +252,77 @@ Status PromoteNullsVisitor(KernelContext* ctx, const ArrayData& cond, bitmaps[L_VALID] = {left.buffers[0], left.offset, left.length}; bitmaps[R_VALID] = {right.buffers[0], right.offset, right.length}; - uint8_t flag = - (cond.null_count == 0) * 4 + (left.null_count == 0) * 2 + (right.null_count == 0); + uint64_t* out_validity = nullptr; + if (flag < 6) { + // there will be a validity buffer in the output + ARROW_ASSIGN_OR_RAISE(output->buffers[0], ctx->AllocateBitmap(cond.length)); + out_validity = output->GetMutableValues(0); + } + // lambda function that will be used inside the visitor int64_t i = 0; + auto apply = [&](uint64_t c_valid, uint64_t c_data, uint64_t l_valid, + uint64_t r_valid) { + out_validity[i] = c_valid & ((c_data & l_valid) | (~c_data & r_valid)); + i++; + }; + + // if the condition is null then output is null otherwise we take validity from the + // selected argument + // ie. cond.valid & (cond.data & left.valid | ~cond.data & right.valid) switch (flag) { - case 0: // all have nulls + case COND_ALL_VALID | LEFT_ALL_VALID | RIGHT_ALL_VALID: // RLC = 111 + break; + case LEFT_ALL_VALID | RIGHT_ALL_VALID: // RLC = 110 + output->buffers[0] = SliceBuffer(cond.buffers[0], cond.offset, cond.length); + break; + case COND_ALL_VALID | RIGHT_ALL_VALID: // RLC = 101 + // bitmaps[C_VALID], bitmaps[R_VALID] might be null; override to make it safe for + // Visit() + bitmaps[C_VALID] = bitmaps[C_DATA]; + bitmaps[R_VALID] = bitmaps[C_DATA]; Bitmap::VisitWords(bitmaps, [&](std::array words) { - out_validity[i] = words[0] & ((words[1] & words[2]) | (~words[1] & words[3])); - i++; + apply(UINT64_MAX, words[C_DATA], words[L_VALID], UINT64_MAX); + }); + break; + case RIGHT_ALL_VALID: // RLC = 100 + // bitmaps[R_VALID] might be null; override to make it safe for Visit() + bitmaps[R_VALID] = bitmaps[C_DATA]; + Bitmap::VisitWords(bitmaps, [&](std::array words) { + apply(words[C_VALID], words[C_DATA], words[L_VALID], UINT64_MAX); + }); + break; + case COND_ALL_VALID | LEFT_ALL_VALID: // RLC = 011 + // bitmaps[C_VALID], bitmaps[L_VALID] might be null; override to make it safe for + // Visit() + bitmaps[C_VALID] = bitmaps[C_DATA]; + bitmaps[L_VALID] = bitmaps[C_DATA]; + Bitmap::VisitWords(bitmaps, [&](std::array words) { + apply(UINT64_MAX, words[C_DATA], UINT64_MAX, words[R_VALID]); + }); + break; + case LEFT_ALL_VALID: // RLC = 010 + // bitmaps[L_VALID] might be null; override to make it safe for Visit() + bitmaps[L_VALID] = bitmaps[C_DATA]; + Bitmap::VisitWords(bitmaps, [&](std::array words) { + apply(words[C_VALID], words[C_DATA], UINT64_MAX, words[R_VALID]); + }); + break; + case COND_ALL_VALID: // RLC = 001 + // bitmaps[C_VALID] might be null; override to make it safe for Visit() + bitmaps[C_VALID] = bitmaps[C_DATA]; + Bitmap::VisitWords(bitmaps, [&](std::array words) { + apply(UINT64_MAX, words[C_DATA], words[L_VALID], words[R_VALID]); + }); + break; + case 0: // RLC = 000 + Bitmap::VisitWords(bitmaps, [&](std::array words) { + apply(words[C_VALID], words[C_DATA], words[L_VALID], words[R_VALID]); }); break; - case 1: // right all valid - Bitmap::VisitWords({bitmaps[C_VALID], bitmaps[C_DATA], bitmaps[L_VALID]}, - [&](std::array words) { - out_validity[i] = - words[0] & ((words[1] & words[2]) | ~words[1]); - i++; - }); - break; - case 2: // left all valid - Bitmap::VisitWords({bitmaps[C_VALID], bitmaps[C_DATA], bitmaps[R_VALID]}, - [&](std::array words) { - out_validity[i] = - words[0] & (words[1] | (~words[1] & words[2])); - i++; - }); - break; - case 3: // left, right all valid - *ou break; - - case 7: // all valid. nothing to do - return Status::OK(); - } - - if (cond.null_count == 0) { - } - - if (right.null_count == 0) { - Bitmap::VisitWords(bitmaps, [&](std::array words) { - apply(words[C_VALID], words[C_DATA], words[L_VALID], ~uint64_t(0)); - }); - return Status::OK(); } - - DCHECK(left.null_count != 0 && right.null_count != 0); - Bitmap::VisitWords(bitmaps, [&](std::array words) { - apply(words[C_VALID], words[C_DATA], words[L_VALID], words[R_VALID]); - }); - return Status::OK(); -}*/ +} // nulls will be promoted as follows: // cond.valid && (cond.data && left.valid || ~cond.data && right.valid) @@ -347,7 +371,7 @@ Status PromoteNullsVisitor(KernelContext* ctx, const ArrayData& cond, output->buffers[0] = std::move(out_validity); output->GetNullCount(); // update null count return Status::OK(); -}*/ +} // cond.valid && (cond.data && left.valid || ~cond.data && right.valid) // ASA and AAS @@ -386,7 +410,7 @@ Status PromoteNulls(KernelContext* ctx, const ArrayData& cond, const Scalar& lef output->GetNullCount(); // update null count return Status::OK(); } -/* + // cond.valid && (cond.data && left.valid || ~cond.data && right.valid) // ASS Status PromoteNulls(KernelContext* ctx, const ArrayData& cond, const Scalar& left, @@ -422,7 +446,7 @@ Status PromoteNulls(KernelContext* ctx, const ArrayData& cond, const Scalar& lef output->buffers[0] = std::move(out_validity); output->GetNullCount(); // update null count return Status::OK(); -}*/ +} // todo: this could be dangerous because the inverted arraydata buffer[1] may not be // available outside Exec's scope @@ -444,14 +468,19 @@ Status InvertBoolArrayData(KernelContext* ctx, const ArrayData& input, output->buffers.emplace_back(std::move(inv_data)); return Status::OK(); } + */ template struct IfElseFunctor {}; +// only number types needs to be handled for Fixed sized primitive data types because, +// internal::GenerateTypeAgnosticPrimitive forwards types to the corresponding unsigned +// int type template -struct IfElseFunctor< - Type, enable_if_t::value | is_temporal_type::value>> { +struct IfElseFunctor> { using T = typename TypeTraits::CType; + // A - Array + // S - Scalar // AAA static Status Call(KernelContext* ctx, const ArrayData& cond, const ArrayData& left, @@ -495,11 +524,10 @@ struct IfElseFunctor< return Status::OK(); } - // ASA and AAS + // ASA static Status Call(KernelContext* ctx, const ArrayData& cond, const Scalar& left, const ArrayData& right, ArrayData* out) { - // todo change this! scalar and array is swapped just for compilation - ARROW_RETURN_NOT_OK(PromoteNullsVisitor(ctx, cond, right, left, out)); + ARROW_RETURN_NOT_OK(PromoteNullsVisitor(ctx, cond, left, right, out)); ARROW_ASSIGN_OR_RAISE(std::shared_ptr out_buf, ctx->Allocate(cond.length * sizeof(T))); @@ -537,44 +565,86 @@ struct IfElseFunctor< return Status::OK(); } + // AAS + static Status Call(KernelContext* ctx, const ArrayData& cond, const ArrayData& left, + const Scalar& right, ArrayData* out) { + ARROW_RETURN_NOT_OK(PromoteNullsVisitor(ctx, cond, left, right, out)); + + ARROW_ASSIGN_OR_RAISE(std::shared_ptr out_buf, + ctx->Allocate(cond.length * sizeof(T))); + T* out_values = reinterpret_cast(out_buf->mutable_data()); + + // copy left data to out_buff + const T* left_data = left.GetValues(1); + std::memcpy(out_values, left_data, left.length * sizeof(T)); + + const auto* cond_data = cond.buffers[1]->data(); // this is a BoolArray + BitBlockCounter bit_counter(cond_data, cond.offset, cond.length); + + // selectively copy values from left data + T right_data = internal::UnboxScalar::Unbox(right); + int64_t offset = cond.offset; + + // todo this can be improved by intrinsics. ex: _mm*_mask_store_e* (vmovdqa*) + // left data is already in the output buffer. Therefore, mask needs to be inverted + while (offset < cond.offset + cond.length) { + const BitBlockCount& block = bit_counter.NextWord(); + if (block.NoneSet()) { // all from right + std::fill(out_values, out_values + block.length, right_data); + } else if (block.popcount) { // selectively copy from right + for (int64_t i = 0; i < block.length; ++i) { + if (!BitUtil::GetBit(cond_data, offset + i)) { + out_values[i] = right_data; + } + } + } + + offset += block.length; + out_values += block.length; + } + + out->buffers[1] = std::move(out_buf); + return Status::OK(); + } + // ASS static Status Call(KernelContext* ctx, const ArrayData& cond, const Scalar& left, const Scalar& right, ArrayData* out) { ARROW_RETURN_NOT_OK(PromoteNullsVisitor(ctx, cond, left, right, out)); - /* - ARROW_ASSIGN_OR_RAISE(std::shared_ptr out_buf, - ctx->Allocate(cond.length * sizeof(T))); - T* out_values = reinterpret_cast(out_buf->mutable_data()); - - // copy right data to out_buff - const T* right_data = right.GetValues(1); - std::memcpy(out_values, right_data, right.length * sizeof(T)); - - const auto* cond_data = cond.buffers[1]->data(); // this is a BoolArray - BitBlockCounter bit_counter(cond_data, cond.offset, cond.length); - - // selectively copy values from left data - T left_data = internal::UnboxScalar::Unbox(left); - int64_t offset = cond.offset; - - // todo this can be improved by intrinsics. ex: _mm*_mask_store_e* (vmovdqa*) - while (offset < cond.offset + cond.length) { - const BitBlockCount& block = bit_counter.NextWord(); - if (block.AllSet()) { // all from left - std::fill(out_values, out_values + block.length, left_data); - } else if (block.popcount) { // selectively copy from left - for (int64_t i = 0; i < block.length; ++i) { - if (BitUtil::GetBit(cond_data, offset + i)) { - out_values[i] = left_data; - } - } - } - - offset += block.length; - out_values += block.length; - } - - out->buffers[1] = std::move(out_buf);*/ + + ARROW_ASSIGN_OR_RAISE(std::shared_ptr out_buf, + ctx->Allocate(cond.length * sizeof(T))); + T* out_values = reinterpret_cast(out_buf->mutable_data()); + + // copy right data to out_buff + T right_data = internal::UnboxScalar::Unbox(right); + std::fill(out_values, out_values + cond.length, right_data); + + const auto* cond_data = cond.buffers[1]->data(); // this is a BoolArray + BitBlockCounter bit_counter(cond_data, cond.offset, cond.length); + + // selectively copy values from left data + T left_data = internal::UnboxScalar::Unbox(left); + int64_t offset = cond.offset; + + // todo this can be improved by intrinsics. ex: _mm*_mask_store_e* (vmovdqa*) + while (offset < cond.offset + cond.length) { + const BitBlockCount& block = bit_counter.NextWord(); + if (block.AllSet()) { // all from left + std::fill(out_values, out_values + block.length, left_data); + } else if (block.popcount) { // selectively copy from left + for (int64_t i = 0; i < block.length; ++i) { + if (BitUtil::GetBit(cond_data, offset + i)) { + out_values[i] = left_data; + } + } + } + + offset += block.length; + out_values += block.length; + } + + out->buffers[1] = std::move(out_buf); return Status::OK(); } }; @@ -604,10 +674,10 @@ struct IfElseFunctor> { return Status::OK(); } - // ASA and AAS + // ASA static Status Call(KernelContext* ctx, const ArrayData& cond, const Scalar& left, const ArrayData& right, ArrayData* out) { - ARROW_RETURN_NOT_OK(PromoteNulls(ctx, cond, left, right, out)); + ARROW_RETURN_NOT_OK(PromoteNullsVisitor(ctx, cond, left, right, out)); // out_buff = right & ~cond ARROW_ASSIGN_OR_RAISE(std::shared_ptr out_buf, @@ -626,6 +696,13 @@ struct IfElseFunctor> { return Status::OK(); } + // AAS + static Status Call(KernelContext* ctx, const ArrayData& cond, const ArrayData& left, + const Scalar& right, ArrayData* out) { + // todo impl + return Status::OK(); + } + // ASS static Status Call(KernelContext* ctx, const ArrayData& cond, const Scalar& left, const Scalar& right, ArrayData* out) { @@ -649,12 +726,18 @@ struct IfElseFunctor> { return ReturnCopy(left, out); } - // ASA and AAS + // ASA static Status Call(KernelContext* ctx, const ArrayData& cond, const Scalar& left, const ArrayData& right, ArrayData* out) { return ReturnCopy(right, out); } + // AAS + static Status Call(KernelContext* ctx, const ArrayData& cond, const ArrayData& left, + const Scalar& right, ArrayData* out) { + return ReturnCopy(left, out); + } + // ASS static Status Call(KernelContext* ctx, const ArrayData& cond, const Scalar& left, const Scalar& right, ArrayData* out) { @@ -700,10 +783,8 @@ struct ResolveIfElseExec { return IfElseFunctor::Call(ctx, *batch[0].array(), *batch[1].array(), *batch[2].array(), out->mutable_array()); } else { // AAS - ArrayData inv_cond; - RETURN_NOT_OK(InvertBoolArrayData(ctx, *batch[0].array(), &inv_cond)); - return IfElseFunctor::Call(ctx, inv_cond, *batch[2].scalar(), - *batch[1].array(), out->mutable_array()); + return IfElseFunctor::Call(ctx, *batch[0].array(), *batch[1].array(), + *batch[2].scalar(), out->mutable_array()); } } else { if (batch[2].kind() == Datum::ARRAY) { // ASA @@ -732,6 +813,7 @@ void AddPrimitiveIfElseKernels(const std::shared_ptr& scalar_fun } // namespace +// todo fill this const FunctionDoc if_else_doc{"", ("`"), {"cond", "left", "right"}}; namespace internal { From ba36d8d72484573217b2f338014ab6c635d83884 Mon Sep 17 00:00:00 2001 From: niranda perera Date: Wed, 26 May 2021 22:05:41 -0400 Subject: [PATCH 14/39] extending tests --- .../arrow/compute/kernels/scalar_if_else.cc | 4 +- .../compute/kernels/scalar_if_else_test.cc | 185 +++++++++++++++--- 2 files changed, 155 insertions(+), 34 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else.cc b/cpp/src/arrow/compute/kernels/scalar_if_else.cc index 87c5f585c2ea2..e646db1241d72 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else.cc @@ -17,9 +17,9 @@ #include #include +#include #include -#include "arrow/util/bitmap.h" #include "codegen_internal.h" namespace arrow { @@ -835,4 +835,4 @@ void RegisterScalarIfElse(FunctionRegistry* registry) { } // namespace internal } // namespace compute -} // namespace arrow \ No newline at end of file +} // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc index 3e55006f8835d..2b816406915a4 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc @@ -36,10 +36,9 @@ void CheckIfElseOutputArray(const Datum& cond, const Datum& left, const Datum& r } } -void CheckIfElseOutputArray(const std::shared_ptr& type, - const std::string& cond, const std::string& left, - const std::string& right, const std::string& expected, - bool all_valid = true) { +void CheckIfElseOutputAAA(const std::shared_ptr& type, const std::string& cond, + const std::string& left, const std::string& right, + const std::string& expected, bool all_valid = true) { const std::shared_ptr& cond_ = ArrayFromJSON(boolean(), cond); const std::shared_ptr& left_ = ArrayFromJSON(type, left); const std::shared_ptr& right_ = ArrayFromJSON(type, right); @@ -47,6 +46,33 @@ void CheckIfElseOutputArray(const std::shared_ptr& type, CheckIfElseOutputArray(cond_, left_, right_, expected_, all_valid); } +void CheckIfElseOutputAAS(const std::shared_ptr& type, const std::string& cond, + const std::string& left, const std::shared_ptr& right, + const std::string& expected, bool all_valid = true) { + const std::shared_ptr& cond_ = ArrayFromJSON(boolean(), cond); + const std::shared_ptr& left_ = ArrayFromJSON(type, left); + const std::shared_ptr& expected_ = ArrayFromJSON(type, expected); + CheckIfElseOutputArray(cond_, left_, right, expected_, all_valid); +} + +void CheckIfElseOutputASA(const std::shared_ptr& type, const std::string& cond, + const std::shared_ptr& left, const std::string& right, + const std::string& expected, bool all_valid = true) { + const std::shared_ptr& cond_ = ArrayFromJSON(boolean(), cond); + const std::shared_ptr& right_ = ArrayFromJSON(type, right); + const std::shared_ptr& expected_ = ArrayFromJSON(type, expected); + CheckIfElseOutputArray(cond_, left, right_, expected_, all_valid); +} + +void CheckIfElseOutputASS(const std::shared_ptr& type, const std::string& cond, + const std::shared_ptr& left, + const std::shared_ptr& right, + const std::string& expected, bool all_valid = true) { + const std::shared_ptr& cond_ = ArrayFromJSON(boolean(), cond); + const std::shared_ptr& expected_ = ArrayFromJSON(type, expected); + CheckIfElseOutputArray(cond_, left, right, expected_, all_valid); +} + class TestIfElseKernel : public ::testing::Test {}; template @@ -62,29 +88,33 @@ TYPED_TEST(TestIfElsePrimitive, IfElseFixedSize) { auto type = TypeTraits::type_singleton(); // No Nulls - CheckIfElseOutputArray(type, "[]", "[]", "[]", "[]"); + CheckIfElseOutputAAA(type, "[]", "[]", "[]", "[]"); + // -------- All arrays --------- // RLC = 111 - CheckIfElseOutputArray(type, "[true, true, true, false]", "[1, 2, 3, 4]", - "[5, 6, 7, 8]", "[1, 2, 3, 8]"); + CheckIfElseOutputAAA(type, "[true, true, true, false]", "[1, 2, 3, 4]", "[5, 6, 7, 8]", + "[1, 2, 3, 8]"); // RLC = 110 - CheckIfElseOutputArray(type, "[true, true, null, false]", "[1, 2, 3, 4]", - "[5, 6, 7, 8]", "[1, 2, null, 8]", false); + CheckIfElseOutputAAA(type, "[true, true, null, false]", "[1, 2, 3, 4]", "[5, 6, 7, 8]", + "[1, 2, null, 8]", false); + // RLC = 101 + CheckIfElseOutputAAA(type, "[true, true, true, false]", "[1, null, 3, 4]", + "[5, 6, 7, 8]", "[1, null, 3, 8]", false); // RLC = 100 - CheckIfElseOutputArray(type, "[true, true, null, false]", "[1, null, 3, 4]", - "[5, 6, 7, 8]", "[1, null, null, 8]", false); + CheckIfElseOutputAAA(type, "[true, true, null, false]", "[1, null, 3, 4]", + "[5, 6, 7, 8]", "[1, null, null, 8]", false); // RLC = 011 - CheckIfElseOutputArray(type, "[true, true, true, false]", "[1, 2, 3, 4]", - "[5, 6, 7, null]", "[1, 2, 3, null]", false); + CheckIfElseOutputAAA(type, "[true, true, true, false]", "[1, 2, 3, 4]", + "[5, 6, 7, null]", "[1, 2, 3, null]", false); // RLC = 010 - CheckIfElseOutputArray(type, "[null, true, true, false]", "[1, 2, 3, 4]", - "[5, 6, 7, null]", "[null, 2, 3, null]", false); + CheckIfElseOutputAAA(type, "[null, true, true, false]", "[1, 2, 3, 4]", + "[5, 6, 7, null]", "[null, 2, 3, null]", false); // RLC = 001 - CheckIfElseOutputArray(type, "[true, true, true, false]", "[1, 2, null, null]", - "[null, 6, 7, null]", "[1, 2, null, null]", false); + CheckIfElseOutputAAA(type, "[true, true, true, false]", "[1, 2, null, null]", + "[null, 6, 7, null]", "[1, 2, null, null]", false); // RLC = 000 - CheckIfElseOutputArray(type, "[null, true, true, false]", "[1, 2, null, null]", - "[null, 6, 7, null]", "[null, 2, null, null]", false); + CheckIfElseOutputAAA(type, "[null, true, true, false]", "[1, 2, null, null]", + "[null, 6, 7, null]", "[null, 2, null, null]", false); using ArrayType = typename TypeTraits::ArrayType; random::RandomArrayGenerator rand(/*seed=*/0); @@ -114,23 +144,114 @@ TYPED_TEST(TestIfElsePrimitive, IfElseFixedSize) { ASSERT_OK_AND_ASSIGN(auto expected_data, builder.Finish()); CheckIfElseOutputArray(cond, left, right, expected_data, false); + + // -------- Cond - Array, Left- Array, Right - Scalar --------- + + ASSERT_OK_AND_ASSIGN(std::shared_ptr valid_scalar, MakeScalar(type, 100)); + std::shared_ptr null_scalar = MakeNullScalar(type); + + // empty + CheckIfElseOutputAAS(type, "[]", "[]", valid_scalar, "[]"); + + // RLC = 111 + CheckIfElseOutputAAS(type, "[true, true, true, false]", "[1, 2, 3, 4]", valid_scalar, + "[1, 2, 3, 100]"); + // RLC = 110 + CheckIfElseOutputAAS(type, "[true, true, null, false]", "[1, 2, 3, 4]", valid_scalar, + "[1, 2, null, 100]", false); + // RLC = 101 + CheckIfElseOutputAAS(type, "[true, true, true, false]", "[1, null, 3, 4]", valid_scalar, + "[1, null, 3, 100]", false); + // RLC = 100 + CheckIfElseOutputAAS(type, "[true, true, null, false]", "[1, null, 3, 4]", valid_scalar, + "[1, null, null, 100]", false); + // RLC = 011 + CheckIfElseOutputAAS(type, "[true, true, true, false]", "[1, 2, 3, 4]", null_scalar, + "[1, 2, 3, null]", false); + // RLC = 010 + CheckIfElseOutputAAS(type, "[null, true, true, false]", "[1, 2, 3, 4]", null_scalar, + "[null, 2, 3, null]", false); + // RLC = 001 + CheckIfElseOutputAAS(type, "[true, true, true, false]", "[1, 2, null, null]", + null_scalar, "[1, 2, null, null]", false); + // RLC = 000 + CheckIfElseOutputAAS(type, "[null, true, true, false]", "[1, 2, null, null]", + null_scalar, "[null, 2, null, null]", false); + + // -------- Cond - Array, Left- Scalar, Right - Array --------- + // empty + CheckIfElseOutputASA(type, "[]", valid_scalar, "[]", "[]"); + + // LRC = 111 + CheckIfElseOutputASA(type, "[true, true, true, false]", valid_scalar, "[1, 2, 3, 4]", + "[100, 100, 100, 4]"); + // LRC = 110 + CheckIfElseOutputASA(type, "[true, true, null, false]", valid_scalar, "[1, 2, 3, 4]", + "[100, 100, null, 4]", false); + // LRC = 101 + CheckIfElseOutputASA(type, "[true, true, true, false]", valid_scalar, + "[1, null, 3, null]", "[100, 100, 100, null]", false); + // LRC = 100 + CheckIfElseOutputASA(type, "[true, true, null, false]", valid_scalar, + "[1, null, 3, null]", "[100, 100, null, null]", false); + // LRC = 011 + CheckIfElseOutputASA(type, "[true, true, true, false]", null_scalar, "[1, 2, 3, 4]", + "[null, null, null, 4]", false); + // LRC = 010 + CheckIfElseOutputASA(type, "[null, true, true, false]", null_scalar, "[1, 2, 3, 4]", + "[null, null, null, 4]", false); + // LRC = 001 + CheckIfElseOutputASA(type, "[true, true, true, false]", null_scalar, "[1, 2, null, 4]", + "[null, null, null, 4]", false); + // LRC = 000 + CheckIfElseOutputASA(type, "[true, true, null, false]", null_scalar, "[1, 2, null, 4]", + "[null, null, null, 4]", false); + + // -------- Cond - Array, Left- Scalar, Right - Scalar --------- + ASSERT_OK_AND_ASSIGN(std::shared_ptr valid_scalar1, MakeScalar(type, 111)); + + // empty + CheckIfElseOutputASS(type, "[]", valid_scalar, valid_scalar1, "[]"); + + // RLC = 111 + CheckIfElseOutputASS(type, "[true, true, true, false]", valid_scalar, valid_scalar1, + "[100, 100, 100, 111]"); + // LRC = 110 + CheckIfElseOutputASS(type, "[true, true, null, false]", valid_scalar, valid_scalar1, + "[100, 100, null, 111]", false); + // LRC = 101 + CheckIfElseOutputASS(type, "[true, true, null, false]", valid_scalar, null_scalar, + "[100, 100, null, null]", false); + // LRC = 100 + CheckIfElseOutputASS(type, "[true, true, null, false]", valid_scalar, null_scalar, + "[100, 100, null, null]", false); + // LRC = 011 + CheckIfElseOutputASS(type, "[true, true, true, false]", null_scalar, valid_scalar1, + "[null, null, null, 111]", false); + // LRC = 010 + CheckIfElseOutputASS(type, "[null, true, true, false]", null_scalar, valid_scalar1, + "[null, null, null, 111]", false); + // LRC = 001 + CheckIfElseOutputASS(type, "[true, true, true, false]", null_scalar, null_scalar, + "[null, null, null, null]", false); + // LRC = 000 + CheckIfElseOutputASS(type, "[true, true, null, false]", null_scalar, null_scalar, + "[null, null, null, null]", false); } TEST_F(TestIfElseKernel, IfElseBoolean) { auto type = boolean(); // No Nulls - CheckIfElseOutputArray(type, "[]", "[]", "[]", "[]"); + CheckIfElseOutputAAA(type, "[]", "[]", "[]", "[]"); - CheckIfElseOutputArray(type, "[true, true, true, false]", - "[false, false, false, false]", "[true, true, true, true]", - "[false, false, false, true]"); + CheckIfElseOutputAAA(type, "[true, true, true, false]", "[false, false, false, false]", + "[true, true, true, true]", "[false, false, false, true]"); - CheckIfElseOutputArray(type, "[true, true, null, false]", - "[false, false, false, false]", "[true, true, true, true]", - "[false, false, null, true]", false); + CheckIfElseOutputAAA(type, "[true, true, null, false]", "[false, false, false, false]", + "[true, true, true, true]", "[false, false, null, true]", false); - CheckIfElseOutputArray(type, "[true, true, true, false]", "[true, false, null, null]", - "[null, false, true, null]", "[true, false, null, null]", false); + CheckIfElseOutputAAA(type, "[true, true, true, false]", "[true, false, null, null]", + "[null, false, true, null]", "[true, false, null, null]", false); random::RandomArrayGenerator rand(/*seed=*/0); int64_t len = 1000; @@ -161,10 +282,10 @@ TEST_F(TestIfElseKernel, IfElseBoolean) { } TEST_F(TestIfElseKernel, IfElseNull) { - CheckIfElseOutputArray(null(), "[null, null, null, null]", "[null, null, null, null]", - "[null, null, null, null]", "[null, null, null, null]", - /*all_valid=*/false); + CheckIfElseOutputAAA(null(), "[null, null, null, null]", "[null, null, null, null]", + "[null, null, null, null]", "[null, null, null, null]", + /*all_valid=*/false); } } // namespace compute -} // namespace arrow \ No newline at end of file +} // namespace arrow From cc339f20eb009d94ecd446d93c9d4957e880c295 Mon Sep 17 00:00:00 2001 From: niranda perera Date: Thu, 27 May 2021 09:47:29 -0400 Subject: [PATCH 15/39] completing bool type impl --- .../arrow/compute/kernels/scalar_if_else.cc | 55 ++++++++++++++++++- 1 file changed, 52 insertions(+), 3 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else.cc b/cpp/src/arrow/compute/kernels/scalar_if_else.cc index e646db1241d72..375b1da2688a1 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else.cc @@ -20,7 +20,7 @@ #include #include -#include "codegen_internal.h" +#include "arrow/compute/kernels/codegen_internal.h" namespace arrow { using internal::BitBlockCount; @@ -699,14 +699,63 @@ struct IfElseFunctor> { // AAS static Status Call(KernelContext* ctx, const ArrayData& cond, const ArrayData& left, const Scalar& right, ArrayData* out) { - // todo impl + ARROW_RETURN_NOT_OK(PromoteNullsVisitor(ctx, cond, left, right, out)); + + // out_buff = left & cond + ARROW_ASSIGN_OR_RAISE(std::shared_ptr out_buf, + arrow::internal::BitmapAnd( + ctx->memory_pool(), left.buffers[1]->data(), left.offset, + cond.buffers[1]->data(), cond.offset, cond.length, 0)); + + bool right_data = internal::UnboxScalar::Unbox(right); + + // out_buff = left & cond | right & ~cond + if (right_data) { + ARROW_ASSIGN_OR_RAISE( + std::shared_ptr tmp_buf, + arrow::internal::InvertBitmap(ctx->memory_pool(), cond.buffers[1]->data(), + cond.offset, cond.length)); + arrow::internal::BitmapOr(tmp_buf->data(), 0, out_buf->data(), 0, cond.length, 0, + out_buf->mutable_data()); + } + + out->buffers[1] = std::move(out_buf); return Status::OK(); } // ASS static Status Call(KernelContext* ctx, const ArrayData& cond, const Scalar& left, const Scalar& right, ArrayData* out) { - // todo impl + ARROW_RETURN_NOT_OK(PromoteNullsVisitor(ctx, cond, left, right, out)); + + bool left_data = internal::UnboxScalar::Unbox(left); + bool right_data = internal::UnboxScalar::Unbox(right); + + // out_buf = left & cond | right & ~cond + std::shared_ptr out_buf = nullptr; + if (left_data) { + if (right_data) { + // out_buf = ones + ARROW_ASSIGN_OR_RAISE(out_buf, ctx->AllocateBitmap(cond.length)); + // filling with UINT8_MAX upto the buffer's size (in bytes) + std::fill(out_buf->mutable_data(), out_buf->mutable_data() + out_buf->size(), + UINT8_MAX); + } else { + // out_buf = cond + out_buf = SliceBuffer(cond.buffers[1], cond.offset, cond.length); + } + } else { + if (right_data) { + // out_buf = ~cond + ARROW_ASSIGN_OR_RAISE(out_buf, arrow::internal::InvertBitmap( + ctx->memory_pool(), cond.buffers[1]->data(), + cond.offset, cond.length)) + } else { + // out_buf = zeros + ARROW_ASSIGN_OR_RAISE(out_buf, ctx->AllocateBitmap(cond.length)); + } + } + out->buffers[1] = std::move(out_buf); return Status::OK(); } }; From 625f36e3137293b697e73d20692ff133be9489ee Mon Sep 17 00:00:00 2001 From: niranda perera Date: Thu, 27 May 2021 10:18:32 -0400 Subject: [PATCH 16/39] adding PR review suggestions --- cpp/src/arrow/compute/api_scalar.h | 4 +- .../arrow/compute/kernels/scalar_if_else.cc | 50 ++--- .../compute/kernels/scalar_if_else_test.cc | 198 +++++++++--------- 3 files changed, 129 insertions(+), 123 deletions(-) diff --git a/cpp/src/arrow/compute/api_scalar.h b/cpp/src/arrow/compute/api_scalar.h index ed8a55e0a0757..60a8028df8575 100644 --- a/cpp/src/arrow/compute/api_scalar.h +++ b/cpp/src/arrow/compute/api_scalar.h @@ -466,7 +466,7 @@ Result FillNull(const Datum& values, const Datum& fill_value, ExecContext* ctx = NULLPTR); /// \brief IfElse returns elements chosen from `left` or `right` -/// depending on `cond`. `Null` values would be promoted to the result +/// depending on `cond`. `null` values would be promoted to the result /// /// \param[in] cond `BooleanArray` condition array /// \param[in] left scalar/ Array @@ -475,7 +475,7 @@ Result FillNull(const Datum& values, const Datum& fill_value, /// /// \return the resulting datum /// -/// \since x.x.x +/// \since 5.0.0 /// \note API not yet finalized ARROW_EXPORT Result IfElse(const Datum& cond, const Datum& left, const Datum& right, diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else.cc b/cpp/src/arrow/compute/kernels/scalar_if_else.cc index 375b1da2688a1..90041f8444b3c 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else.cc @@ -79,7 +79,7 @@ Status PromoteNullsVisitor(KernelContext* ctx, const ArrayData& cond, const Scal case COND_ALL_VALID: // = 1 // out_valid = 0 --> nothing to do; but requires out_valid to be a all-zero buffer break; - case 0: // RLC = 000 + case 0: // out_valid = 0 --> nothing to do; but requires out_valid to be a all-zero buffer break; } @@ -117,40 +117,40 @@ Status PromoteNullsVisitor(KernelContext* ctx, const ArrayData& cond, // selected argument // ie. cond.valid & (cond.data & left.valid | ~cond.data & right.valid) switch (flag) { - case COND_ALL_VALID | LEFT_ALL_VALID | RIGHT_ALL_VALID: // RLC = 111 + case COND_ALL_VALID | LEFT_ALL_VALID | RIGHT_ALL_VALID: break; - case LEFT_ALL_VALID | RIGHT_ALL_VALID: // RLC = 110 + case LEFT_ALL_VALID | RIGHT_ALL_VALID: output->buffers[0] = SliceBuffer(cond.buffers[0], cond.offset, cond.length); break; - case COND_ALL_VALID | RIGHT_ALL_VALID: // RLC = 101 + case COND_ALL_VALID | RIGHT_ALL_VALID: // bitmaps[C_VALID] might be null; override to make it safe for Visit() bitmaps[C_VALID] = bitmaps[C_DATA]; Bitmap::VisitWords(bitmaps, [&](std::array words) { apply(UINT64_MAX, words[C_DATA], words[L_VALID], UINT64_MAX); }); break; - case RIGHT_ALL_VALID: // RLC = 100 + case RIGHT_ALL_VALID: Bitmap::VisitWords(bitmaps, [&](std::array words) { apply(words[C_VALID], words[C_DATA], words[L_VALID], UINT64_MAX); }); break; - case COND_ALL_VALID | LEFT_ALL_VALID: // RLC = 011 + case COND_ALL_VALID | LEFT_ALL_VALID: // only cond.data is passed output->buffers[0] = SliceBuffer(cond.buffers[1], cond.offset, cond.length); break; - case LEFT_ALL_VALID: // RLC = 010 + case LEFT_ALL_VALID: // out_valid = cond.valid & cond.data arrow::internal::BitmapAnd(cond.buffers[0]->data(), cond.offset, cond.buffers[1]->data(), cond.offset, cond.length, 0, output->buffers[0]->mutable_data()); break; - case COND_ALL_VALID: // RLC = 001 + case COND_ALL_VALID: // out_valid = cond.data & left.valid arrow::internal::BitmapAnd(cond.buffers[1]->data(), cond.offset, left.buffers[0]->data(), left.offset, cond.length, 0, output->buffers[0]->mutable_data()); break; - case 0: // RLC = 000 + case 0: Bitmap::VisitWords(bitmaps, [&](std::array words) { apply(words[C_VALID], words[C_DATA], words[L_VALID], 0); }); @@ -192,41 +192,41 @@ Status PromoteNullsVisitor(KernelContext* ctx, const ArrayData& cond, const Scal // selected argument // ie. cond.valid & (cond.data & left.valid | ~cond.data & right.valid) switch (flag) { - case COND_ALL_VALID | LEFT_ALL_VALID | RIGHT_ALL_VALID: // RLC = 111 + case COND_ALL_VALID | LEFT_ALL_VALID | RIGHT_ALL_VALID: break; - case LEFT_ALL_VALID | RIGHT_ALL_VALID: // RLC = 110 + case LEFT_ALL_VALID | RIGHT_ALL_VALID: output->buffers[0] = SliceBuffer(cond.buffers[0], cond.offset, cond.length); break; - case COND_ALL_VALID | RIGHT_ALL_VALID: // RLC = 101 + case COND_ALL_VALID | RIGHT_ALL_VALID: // out_valid = ~cond.data arrow::internal::InvertBitmap(cond.buffers[1]->data(), cond.offset, cond.length, output->buffers[0]->mutable_data(), 0); break; - case RIGHT_ALL_VALID: // RLC = 100 + case RIGHT_ALL_VALID: // out_valid = c_valid & ~cond.data arrow::internal::BitmapAndNot(cond.buffers[0]->data(), cond.offset, cond.buffers[1]->data(), cond.offset, cond.length, 0, output->buffers[0]->mutable_data()); break; - case COND_ALL_VALID | LEFT_ALL_VALID: // RLC = 011 + case COND_ALL_VALID | LEFT_ALL_VALID: // bitmaps[C_VALID] might be null; override to make it safe for Visit() bitmaps[C_VALID] = bitmaps[C_DATA]; Bitmap::VisitWords(bitmaps, [&](std::array words) { apply(UINT64_MAX, words[C_DATA], UINT64_MAX, words[R_VALID]); }); break; - case LEFT_ALL_VALID: // RLC = 010 + case LEFT_ALL_VALID: Bitmap::VisitWords(bitmaps, [&](std::array words) { apply(words[C_VALID], words[C_DATA], UINT64_MAX, words[R_VALID]); }); break; - case COND_ALL_VALID: // RLC = 001 + case COND_ALL_VALID: // out_valid = ~cond.data & right.valid arrow::internal::BitmapAndNot(right.buffers[0]->data(), right.offset, cond.buffers[1]->data(), cond.offset, cond.length, 0, output->buffers[0]->mutable_data()); break; - case 0: // RLC = 000 + case 0: Bitmap::VisitWords(bitmaps, [&](std::array words) { apply(words[C_VALID], words[C_DATA], 0, words[R_VALID]); }); @@ -271,12 +271,12 @@ Status PromoteNullsVisitor(KernelContext* ctx, const ArrayData& cond, // selected argument // ie. cond.valid & (cond.data & left.valid | ~cond.data & right.valid) switch (flag) { - case COND_ALL_VALID | LEFT_ALL_VALID | RIGHT_ALL_VALID: // RLC = 111 + case COND_ALL_VALID | LEFT_ALL_VALID | RIGHT_ALL_VALID: break; - case LEFT_ALL_VALID | RIGHT_ALL_VALID: // RLC = 110 + case LEFT_ALL_VALID | RIGHT_ALL_VALID: output->buffers[0] = SliceBuffer(cond.buffers[0], cond.offset, cond.length); break; - case COND_ALL_VALID | RIGHT_ALL_VALID: // RLC = 101 + case COND_ALL_VALID | RIGHT_ALL_VALID: // bitmaps[C_VALID], bitmaps[R_VALID] might be null; override to make it safe for // Visit() bitmaps[C_VALID] = bitmaps[C_DATA]; @@ -285,14 +285,14 @@ Status PromoteNullsVisitor(KernelContext* ctx, const ArrayData& cond, apply(UINT64_MAX, words[C_DATA], words[L_VALID], UINT64_MAX); }); break; - case RIGHT_ALL_VALID: // RLC = 100 + case RIGHT_ALL_VALID: // bitmaps[R_VALID] might be null; override to make it safe for Visit() bitmaps[R_VALID] = bitmaps[C_DATA]; Bitmap::VisitWords(bitmaps, [&](std::array words) { apply(words[C_VALID], words[C_DATA], words[L_VALID], UINT64_MAX); }); break; - case COND_ALL_VALID | LEFT_ALL_VALID: // RLC = 011 + case COND_ALL_VALID | LEFT_ALL_VALID: // bitmaps[C_VALID], bitmaps[L_VALID] might be null; override to make it safe for // Visit() bitmaps[C_VALID] = bitmaps[C_DATA]; @@ -301,21 +301,21 @@ Status PromoteNullsVisitor(KernelContext* ctx, const ArrayData& cond, apply(UINT64_MAX, words[C_DATA], UINT64_MAX, words[R_VALID]); }); break; - case LEFT_ALL_VALID: // RLC = 010 + case LEFT_ALL_VALID: // bitmaps[L_VALID] might be null; override to make it safe for Visit() bitmaps[L_VALID] = bitmaps[C_DATA]; Bitmap::VisitWords(bitmaps, [&](std::array words) { apply(words[C_VALID], words[C_DATA], UINT64_MAX, words[R_VALID]); }); break; - case COND_ALL_VALID: // RLC = 001 + case COND_ALL_VALID: // bitmaps[C_VALID] might be null; override to make it safe for Visit() bitmaps[C_VALID] = bitmaps[C_DATA]; Bitmap::VisitWords(bitmaps, [&](std::array words) { apply(UINT64_MAX, words[C_DATA], words[L_VALID], words[R_VALID]); }); break; - case 0: // RLC = 000 + case 0: Bitmap::VisitWords(bitmaps, [&](std::array words) { apply(words[C_VALID], words[C_DATA], words[L_VALID], words[R_VALID]); }); diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc index 2b816406915a4..0eb87d9de264a 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc @@ -25,52 +25,51 @@ namespace arrow { namespace compute { void CheckIfElseOutputArray(const Datum& cond, const Datum& left, const Datum& right, - const Datum& expected, bool all_valid = true) { + const Datum& expected) { ASSERT_OK_AND_ASSIGN(Datum datum_out, IfElse(cond, left, right)); std::shared_ptr result = datum_out.make_array(); ASSERT_OK(result->ValidateFull()); + std::shared_ptr expected_ = expected.make_array(); AssertArraysEqual(*expected.make_array(), *result, /*verbose=*/true); - if (all_valid) { - // Check null count of ArrayData is set, not the computed Array.null_count - ASSERT_EQ(result->data()->null_count, 0); - } + + ASSERT_EQ(result->data()->null_count, expected_->data()->null_count); } void CheckIfElseOutputAAA(const std::shared_ptr& type, const std::string& cond, const std::string& left, const std::string& right, - const std::string& expected, bool all_valid = true) { + const std::string& expected) { const std::shared_ptr& cond_ = ArrayFromJSON(boolean(), cond); const std::shared_ptr& left_ = ArrayFromJSON(type, left); const std::shared_ptr& right_ = ArrayFromJSON(type, right); const std::shared_ptr& expected_ = ArrayFromJSON(type, expected); - CheckIfElseOutputArray(cond_, left_, right_, expected_, all_valid); + CheckIfElseOutputArray(cond_, left_, right_, expected_); } void CheckIfElseOutputAAS(const std::shared_ptr& type, const std::string& cond, const std::string& left, const std::shared_ptr& right, - const std::string& expected, bool all_valid = true) { + const std::string& expected) { const std::shared_ptr& cond_ = ArrayFromJSON(boolean(), cond); const std::shared_ptr& left_ = ArrayFromJSON(type, left); const std::shared_ptr& expected_ = ArrayFromJSON(type, expected); - CheckIfElseOutputArray(cond_, left_, right, expected_, all_valid); + CheckIfElseOutputArray(cond_, left_, right, expected_); } void CheckIfElseOutputASA(const std::shared_ptr& type, const std::string& cond, const std::shared_ptr& left, const std::string& right, - const std::string& expected, bool all_valid = true) { + const std::string& expected) { const std::shared_ptr& cond_ = ArrayFromJSON(boolean(), cond); const std::shared_ptr& right_ = ArrayFromJSON(type, right); const std::shared_ptr& expected_ = ArrayFromJSON(type, expected); - CheckIfElseOutputArray(cond_, left, right_, expected_, all_valid); + CheckIfElseOutputArray(cond_, left, right_, expected_); } void CheckIfElseOutputASS(const std::shared_ptr& type, const std::string& cond, const std::shared_ptr& left, const std::shared_ptr& right, - const std::string& expected, bool all_valid = true) { + const std::string& expected) { const std::shared_ptr& cond_ = ArrayFromJSON(boolean(), cond); const std::shared_ptr& expected_ = ArrayFromJSON(type, expected); - CheckIfElseOutputArray(cond_, left, right, expected_, all_valid); + CheckIfElseOutputArray(cond_, left, right, expected_); } class TestIfElseKernel : public ::testing::Test {}; @@ -84,39 +83,10 @@ using PrimitiveTypes = ::testing::Types::ArrayType; auto type = TypeTraits::type_singleton(); - // No Nulls - CheckIfElseOutputAAA(type, "[]", "[]", "[]", "[]"); - - // -------- All arrays --------- - // RLC = 111 - CheckIfElseOutputAAA(type, "[true, true, true, false]", "[1, 2, 3, 4]", "[5, 6, 7, 8]", - "[1, 2, 3, 8]"); - // RLC = 110 - CheckIfElseOutputAAA(type, "[true, true, null, false]", "[1, 2, 3, 4]", "[5, 6, 7, 8]", - "[1, 2, null, 8]", false); - // RLC = 101 - CheckIfElseOutputAAA(type, "[true, true, true, false]", "[1, null, 3, 4]", - "[5, 6, 7, 8]", "[1, null, 3, 8]", false); - // RLC = 100 - CheckIfElseOutputAAA(type, "[true, true, null, false]", "[1, null, 3, 4]", - "[5, 6, 7, 8]", "[1, null, null, 8]", false); - // RLC = 011 - CheckIfElseOutputAAA(type, "[true, true, true, false]", "[1, 2, 3, 4]", - "[5, 6, 7, null]", "[1, 2, 3, null]", false); - // RLC = 010 - CheckIfElseOutputAAA(type, "[null, true, true, false]", "[1, 2, 3, 4]", - "[5, 6, 7, null]", "[null, 2, 3, null]", false); - // RLC = 001 - CheckIfElseOutputAAA(type, "[true, true, true, false]", "[1, 2, null, null]", - "[null, 6, 7, null]", "[1, 2, null, null]", false); - // RLC = 000 - CheckIfElseOutputAAA(type, "[null, true, true, false]", "[1, 2, null, null]", - "[null, 6, 7, null]", "[null, 2, null, null]", false); - - using ArrayType = typename TypeTraits::ArrayType; random::RandomArrayGenerator rand(/*seed=*/0); int64_t len = 1000; auto cond = std::static_pointer_cast( @@ -143,7 +113,44 @@ TYPED_TEST(TestIfElsePrimitive, IfElseFixedSize) { } ASSERT_OK_AND_ASSIGN(auto expected_data, builder.Finish()); - CheckIfElseOutputArray(cond, left, right, expected_data, false); + CheckIfElseOutputArray(cond, left, right, expected_data); +} + +/* + * Legend: + * C - Cond, L - Left, R - Right + * 1 - All valid (or valid scalar), 0 - Could have nulls (or invalid scalar) + */ +TYPED_TEST(TestIfElsePrimitive, IfElseFixedSize) { + auto type = TypeTraits::type_singleton(); + + // -------- All arrays --------- + // empty + CheckIfElseOutputAAA(type, "[]", "[]", "[]", "[]"); + // CLR = 111 + CheckIfElseOutputAAA(type, "[true, true, true, false]", "[1, 2, 3, 4]", "[5, 6, 7, 8]", + "[1, 2, 3, 8]"); + // CLR = 011 + CheckIfElseOutputAAA(type, "[true, true, null, false]", "[1, 2, 3, 4]", "[5, 6, 7, 8]", + "[1, 2, null, 8]"); + // CLR = 101 + CheckIfElseOutputAAA(type, "[true, true, true, false]", "[1, null, 3, 4]", + "[5, 6, 7, 8]", "[1, null, 3, 8]"); + // CLR = 001 + CheckIfElseOutputAAA(type, "[true, true, null, false]", "[1, null, 3, 4]", + "[5, 6, 7, 8]", "[1, null, null, 8]"); + // CLR = 110 + CheckIfElseOutputAAA(type, "[true, true, true, false]", "[1, 2, 3, 4]", + "[5, 6, 7, null]", "[1, 2, 3, null]"); + // CLR = 010 + CheckIfElseOutputAAA(type, "[null, true, true, false]", "[1, 2, 3, 4]", + "[5, 6, 7, null]", "[null, 2, 3, null]"); + // CLR = 100 + CheckIfElseOutputAAA(type, "[true, true, true, false]", "[1, 2, null, null]", + "[null, 6, 7, null]", "[1, 2, null, null]"); + // CLR = 000 + CheckIfElseOutputAAA(type, "[null, true, true, false]", "[1, 2, null, null]", + "[null, 6, 7, null]", "[null, 2, null, null]"); // -------- Cond - Array, Left- Array, Right - Scalar --------- @@ -153,59 +160,59 @@ TYPED_TEST(TestIfElsePrimitive, IfElseFixedSize) { // empty CheckIfElseOutputAAS(type, "[]", "[]", valid_scalar, "[]"); - // RLC = 111 + // CLR = 111 CheckIfElseOutputAAS(type, "[true, true, true, false]", "[1, 2, 3, 4]", valid_scalar, "[1, 2, 3, 100]"); - // RLC = 110 + // CLR = 011 CheckIfElseOutputAAS(type, "[true, true, null, false]", "[1, 2, 3, 4]", valid_scalar, - "[1, 2, null, 100]", false); - // RLC = 101 + "[1, 2, null, 100]"); + // CLR = 101 CheckIfElseOutputAAS(type, "[true, true, true, false]", "[1, null, 3, 4]", valid_scalar, - "[1, null, 3, 100]", false); - // RLC = 100 + "[1, null, 3, 100]"); + // CLR = 001 CheckIfElseOutputAAS(type, "[true, true, null, false]", "[1, null, 3, 4]", valid_scalar, - "[1, null, null, 100]", false); - // RLC = 011 + "[1, null, null, 100]"); + // CLR = 110 CheckIfElseOutputAAS(type, "[true, true, true, false]", "[1, 2, 3, 4]", null_scalar, - "[1, 2, 3, null]", false); - // RLC = 010 + "[1, 2, 3, null]"); + // CLR = 010 CheckIfElseOutputAAS(type, "[null, true, true, false]", "[1, 2, 3, 4]", null_scalar, - "[null, 2, 3, null]", false); - // RLC = 001 + "[null, 2, 3, null]"); + // CLR = 100 CheckIfElseOutputAAS(type, "[true, true, true, false]", "[1, 2, null, null]", - null_scalar, "[1, 2, null, null]", false); - // RLC = 000 + null_scalar, "[1, 2, null, null]"); + // CLR = 000 CheckIfElseOutputAAS(type, "[null, true, true, false]", "[1, 2, null, null]", - null_scalar, "[null, 2, null, null]", false); + null_scalar, "[null, 2, null, null]"); // -------- Cond - Array, Left- Scalar, Right - Array --------- // empty CheckIfElseOutputASA(type, "[]", valid_scalar, "[]", "[]"); - // LRC = 111 + // CLR = 111 CheckIfElseOutputASA(type, "[true, true, true, false]", valid_scalar, "[1, 2, 3, 4]", "[100, 100, 100, 4]"); - // LRC = 110 + // CLR = 011 CheckIfElseOutputASA(type, "[true, true, null, false]", valid_scalar, "[1, 2, 3, 4]", - "[100, 100, null, 4]", false); - // LRC = 101 + "[100, 100, null, 4]"); + // CLR = 110 CheckIfElseOutputASA(type, "[true, true, true, false]", valid_scalar, - "[1, null, 3, null]", "[100, 100, 100, null]", false); - // LRC = 100 + "[1, null, 3, null]", "[100, 100, 100, null]"); + // CLR = 010 CheckIfElseOutputASA(type, "[true, true, null, false]", valid_scalar, - "[1, null, 3, null]", "[100, 100, null, null]", false); - // LRC = 011 + "[1, null, 3, null]", "[100, 100, null, null]"); + // CLR = 101 CheckIfElseOutputASA(type, "[true, true, true, false]", null_scalar, "[1, 2, 3, 4]", - "[null, null, null, 4]", false); - // LRC = 010 + "[null, null, null, 4]"); + // CLR = 001 CheckIfElseOutputASA(type, "[null, true, true, false]", null_scalar, "[1, 2, 3, 4]", - "[null, null, null, 4]", false); - // LRC = 001 + "[null, null, null, 4]"); + // CLR = 100 CheckIfElseOutputASA(type, "[true, true, true, false]", null_scalar, "[1, 2, null, 4]", - "[null, null, null, 4]", false); - // LRC = 000 + "[null, null, null, 4]"); + // CLR = 000 CheckIfElseOutputASA(type, "[true, true, null, false]", null_scalar, "[1, 2, null, 4]", - "[null, null, null, 4]", false); + "[null, null, null, 4]"); // -------- Cond - Array, Left- Scalar, Right - Scalar --------- ASSERT_OK_AND_ASSIGN(std::shared_ptr valid_scalar1, MakeScalar(type, 111)); @@ -213,30 +220,30 @@ TYPED_TEST(TestIfElsePrimitive, IfElseFixedSize) { // empty CheckIfElseOutputASS(type, "[]", valid_scalar, valid_scalar1, "[]"); - // RLC = 111 + // CLR = 111 CheckIfElseOutputASS(type, "[true, true, true, false]", valid_scalar, valid_scalar1, "[100, 100, 100, 111]"); - // LRC = 110 + // CLR = 011 CheckIfElseOutputASS(type, "[true, true, null, false]", valid_scalar, valid_scalar1, - "[100, 100, null, 111]", false); - // LRC = 101 - CheckIfElseOutputASS(type, "[true, true, null, false]", valid_scalar, null_scalar, - "[100, 100, null, null]", false); - // LRC = 100 + "[100, 100, null, 111]"); + // CLR = 010 CheckIfElseOutputASS(type, "[true, true, null, false]", valid_scalar, null_scalar, - "[100, 100, null, null]", false); - // LRC = 011 + "[100, 100, null, null]"); + // CLR = 110 + CheckIfElseOutputASS(type, "[true, true, true, false]", valid_scalar, null_scalar, + "[100, 100, 100, null]"); + // CLR = 101 CheckIfElseOutputASS(type, "[true, true, true, false]", null_scalar, valid_scalar1, - "[null, null, null, 111]", false); - // LRC = 010 + "[null, null, null, 111]"); + // CLR = 001 CheckIfElseOutputASS(type, "[null, true, true, false]", null_scalar, valid_scalar1, - "[null, null, null, 111]", false); - // LRC = 001 + "[null, null, null, 111]"); + // CLR = 100 CheckIfElseOutputASS(type, "[true, true, true, false]", null_scalar, null_scalar, - "[null, null, null, null]", false); - // LRC = 000 + "[null, null, null, null]"); + // CLR = 000 CheckIfElseOutputASS(type, "[true, true, null, false]", null_scalar, null_scalar, - "[null, null, null, null]", false); + "[null, null, null, null]"); } TEST_F(TestIfElseKernel, IfElseBoolean) { @@ -248,10 +255,10 @@ TEST_F(TestIfElseKernel, IfElseBoolean) { "[true, true, true, true]", "[false, false, false, true]"); CheckIfElseOutputAAA(type, "[true, true, null, false]", "[false, false, false, false]", - "[true, true, true, true]", "[false, false, null, true]", false); + "[true, true, true, true]", "[false, false, null, true]"); CheckIfElseOutputAAA(type, "[true, true, true, false]", "[true, false, null, null]", - "[null, false, true, null]", "[true, false, null, null]", false); + "[null, false, true, null]", "[true, false, null, null]"); random::RandomArrayGenerator rand(/*seed=*/0); int64_t len = 1000; @@ -278,13 +285,12 @@ TEST_F(TestIfElseKernel, IfElseBoolean) { } ASSERT_OK_AND_ASSIGN(auto expected_data, builder.Finish()); - CheckIfElseOutputArray(cond, left, right, expected_data, false); + CheckIfElseOutputArray(cond, left, right, expected_data); } TEST_F(TestIfElseKernel, IfElseNull) { CheckIfElseOutputAAA(null(), "[null, null, null, null]", "[null, null, null, null]", - "[null, null, null, null]", "[null, null, null, null]", - /*all_valid=*/false); + "[null, null, null, null]", "[null, null, null, null]"); } } // namespace compute From 673acde6c8c9e8cb1340fb0490e14b035effc0eb Mon Sep 17 00:00:00 2001 From: niranda perera Date: Thu, 27 May 2021 10:34:49 -0400 Subject: [PATCH 17/39] adding a method for set memory --- cpp/src/arrow/compute/kernels/scalar_if_else.cc | 4 ++-- cpp/src/arrow/compute/util_internal.h | 5 +++++ 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else.cc b/cpp/src/arrow/compute/kernels/scalar_if_else.cc index 90041f8444b3c..c38ea06b16c3a 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else.cc @@ -16,6 +16,7 @@ // under the License. #include +#include #include #include #include @@ -738,8 +739,7 @@ struct IfElseFunctor> { // out_buf = ones ARROW_ASSIGN_OR_RAISE(out_buf, ctx->AllocateBitmap(cond.length)); // filling with UINT8_MAX upto the buffer's size (in bytes) - std::fill(out_buf->mutable_data(), out_buf->mutable_data() + out_buf->size(), - UINT8_MAX); + arrow::compute::internal::SetMemory(out_buf.get()); } else { // out_buf = cond out_buf = SliceBuffer(cond.buffers[1], cond.offset, cond.length); diff --git a/cpp/src/arrow/compute/util_internal.h b/cpp/src/arrow/compute/util_internal.h index 396c2ca2a0b38..bff4214217614 100644 --- a/cpp/src/arrow/compute/util_internal.h +++ b/cpp/src/arrow/compute/util_internal.h @@ -27,6 +27,11 @@ static inline void ZeroMemory(Buffer* buffer) { std::memset(buffer->mutable_data(), 0, buffer->size()); } +template +static inline void SetMemory(Buffer* buffer) { + std::memset(buffer->mutable_data(), ch, buffer->size()); +} + } // namespace internal } // namespace compute } // namespace arrow From 4ab179872484b50236e7b7dd33c1381f9f149e54 Mon Sep 17 00:00:00 2001 From: niranda perera Date: Thu, 27 May 2021 10:48:02 -0400 Subject: [PATCH 18/39] adding BitmapOrNot op --- .../arrow/compute/kernels/scalar_if_else.cc | 11 +++-------- cpp/src/arrow/util/bitmap_ops.cc | 18 ++++++++++++++++++ cpp/src/arrow/util/bitmap_ops.h | 19 +++++++++++++++++++ 3 files changed, 40 insertions(+), 8 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else.cc b/cpp/src/arrow/compute/kernels/scalar_if_else.cc index c38ea06b16c3a..92856a40e36fa 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else.cc @@ -16,13 +16,12 @@ // under the License. #include +#include #include #include #include #include -#include "arrow/compute/kernels/codegen_internal.h" - namespace arrow { using internal::BitBlockCount; using internal::BitBlockCounter; @@ -712,12 +711,8 @@ struct IfElseFunctor> { // out_buff = left & cond | right & ~cond if (right_data) { - ARROW_ASSIGN_OR_RAISE( - std::shared_ptr tmp_buf, - arrow::internal::InvertBitmap(ctx->memory_pool(), cond.buffers[1]->data(), - cond.offset, cond.length)); - arrow::internal::BitmapOr(tmp_buf->data(), 0, out_buf->data(), 0, cond.length, 0, - out_buf->mutable_data()); + arrow::internal::BitmapOrNot(out_buf->data(), 0, cond.buffers[1]->data(), + cond.offset, cond.length, 0, out_buf->mutable_data()); } out->buffers[1] = std::move(out_buf); diff --git a/cpp/src/arrow/util/bitmap_ops.cc b/cpp/src/arrow/util/bitmap_ops.cc index 32da60aafd9f6..a27a61cadf38a 100644 --- a/cpp/src/arrow/util/bitmap_ops.cc +++ b/cpp/src/arrow/util/bitmap_ops.cc @@ -583,5 +583,23 @@ void BitmapAndNot(const uint8_t* left, int64_t left_offset, const uint8_t* right BitmapOp(left, left_offset, right, right_offset, length, out_offset, out); } +template +struct OrNotOp { + constexpr T operator()(const T& l, const T& r) const { return l | ~r; } +}; + +Result> BitmapOrNot(MemoryPool* pool, const uint8_t* left, + int64_t left_offset, const uint8_t* right, + int64_t right_offset, int64_t length, + int64_t out_offset) { + return BitmapOp(pool, left, left_offset, right, right_offset, length, + out_offset); +} + +void BitmapOrNot(const uint8_t* left, int64_t left_offset, const uint8_t* right, + int64_t right_offset, int64_t length, int64_t out_offset, uint8_t* out) { + BitmapOp(left, left_offset, right, right_offset, length, out_offset, out); +} + } // namespace internal } // namespace arrow diff --git a/cpp/src/arrow/util/bitmap_ops.h b/cpp/src/arrow/util/bitmap_ops.h index 554e1d7468b98..40a7797a2398a 100644 --- a/cpp/src/arrow/util/bitmap_ops.h +++ b/cpp/src/arrow/util/bitmap_ops.h @@ -183,5 +183,24 @@ ARROW_EXPORT void BitmapAndNot(const uint8_t* left, int64_t left_offset, const uint8_t* right, int64_t right_offset, int64_t length, int64_t out_offset, uint8_t* out); +/// \brief Do a "bitmap or not" on right and left buffers starting at +/// their respective bit-offsets for the given bit-length and put +/// the results in out_buffer starting at the given bit-offset. +/// +/// out_buffer will be allocated and initialized to zeros using pool before +/// the operation. +ARROW_EXPORT +Result> BitmapOrNot(MemoryPool* pool, const uint8_t* left, + int64_t left_offset, const uint8_t* right, + int64_t right_offset, int64_t length, + int64_t out_offset); + +/// \brief Do a "bitmap or not" on right and left buffers starting at +/// their respective bit-offsets for the given bit-length and put +/// the results in out starting at the given bit-offset. +ARROW_EXPORT +void BitmapOrNot(const uint8_t* left, int64_t left_offset, const uint8_t* right, + int64_t right_offset, int64_t length, int64_t out_offset, uint8_t* out); + } // namespace internal } // namespace arrow From 0ddb75b8f3d69e0fd2ea8ce4650b44cb666a224a Mon Sep 17 00:00:00 2001 From: niranda perera Date: Thu, 27 May 2021 12:54:52 -0400 Subject: [PATCH 19/39] extending tests --- .../arrow/compute/kernels/scalar_if_else.cc | 150 +--------------- .../compute/kernels/scalar_if_else_test.cc | 168 ++++++++++++++++-- 2 files changed, 156 insertions(+), 162 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else.cc b/cpp/src/arrow/compute/kernels/scalar_if_else.cc index 92856a40e36fa..bcd9aea7bfa1d 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else.cc @@ -324,152 +324,6 @@ Status PromoteNullsVisitor(KernelContext* ctx, const ArrayData& cond, return Status::OK(); } -// nulls will be promoted as follows: -// cond.valid && (cond.data && left.valid || ~cond.data && right.valid) -// Note: we have to work on ArrayData. Otherwise we won't be able to handle array -// offsets AAA -/*Status PromoteNulls(KernelContext* ctx, const ArrayData& cond, const ArrayData& left, - const ArrayData& right, ArrayData* output) { - if (!cond.MayHaveNulls() && !left.MayHaveNulls() && !right.MayHaveNulls()) { - return Status::OK(); // no nulls to handle - } - const int64_t len = cond.length; - - // out_validity = ~cond.data --> mask right values - ARROW_ASSIGN_OR_RAISE( - std::shared_ptr out_validity, - arrow::internal::InvertBitmap(ctx->memory_pool(), cond.buffers[1]->data(), - cond.offset, len)); - - if (right.MayHaveNulls()) { // out_validity = right.valid && ~cond.data - arrow::internal::BitmapAnd(right.buffers[0]->data(), right.offset, - out_validity->data(), 0, len, 0, - out_validity->mutable_data()); - } - - std::shared_ptr tmp_buf; - if (left.MayHaveNulls()) { - // tmp_buf = left.valid && cond.data - ARROW_ASSIGN_OR_RAISE( - tmp_buf, arrow::internal::BitmapAnd(ctx->memory_pool(), left.buffers[0]->data(), - left.offset, cond.buffers[1]->data(), - cond.offset, len, 0)); - } else { // if left all valid --> tmp_buf = cond.data (zero copy slice) - tmp_buf = SliceBuffer(cond.buffers[1], cond.offset, cond.length); - } - - // out_validity = cond.data && left.valid || ~cond.data && right.valid - arrow::internal::BitmapOr(out_validity->data(), 0, tmp_buf->data(), 0, len, 0, - out_validity->mutable_data()); - - if (cond.MayHaveNulls()) { - // out_validity = cond.valid && (cond.data && left.valid || ~cond.data && right.valid) - ::arrow::internal::BitmapAnd(out_validity->data(), 0, cond.buffers[0]->data(), - cond.offset, len, 0, out_validity->mutable_data()); - } - - output->buffers[0] = std::move(out_validity); - output->GetNullCount(); // update null count - return Status::OK(); -} - -// cond.valid && (cond.data && left.valid || ~cond.data && right.valid) -// ASA and AAS -Status PromoteNulls(KernelContext* ctx, const ArrayData& cond, const Scalar& left, - const ArrayData& right, ArrayData* output) { - if (!cond.MayHaveNulls() && left.is_valid && !right.MayHaveNulls()) { - return Status::OK(); // no nulls to handle - } - const int64_t len = cond.length; - - // out_validity = ~cond.data - ARROW_ASSIGN_OR_RAISE( - std::shared_ptr out_validity, - arrow::internal::InvertBitmap(ctx->memory_pool(), cond.buffers[1]->data(), - cond.offset, len)); - // out_validity = ~cond.data && right.valid - if (right.MayHaveNulls()) { // out_validity = right.valid && ~cond.data - arrow::internal::BitmapAnd(right.buffers[0]->data(), right.offset, - out_validity->data(), 0, len, 0, - out_validity->mutable_data()); - } - - // out_validity = cond.data && left.valid || ~cond.data && right.valid - if (left.is_valid) { - arrow::internal::BitmapOr(out_validity->data(), 0, cond.buffers[1]->data(), - cond.offset, len, 0, out_validity->mutable_data()); - } - - // out_validity = cond.valid && (cond.data && left.valid || ~cond.data && right.valid) - if (cond.MayHaveNulls()) { - ::arrow::internal::BitmapAnd(out_validity->data(), 0, cond.buffers[0]->data(), - cond.offset, len, 0, out_validity->mutable_data()); - } - - output->buffers[0] = std::move(out_validity); - output->GetNullCount(); // update null count - return Status::OK(); -} - -// cond.valid && (cond.data && left.valid || ~cond.data && right.valid) -// ASS -Status PromoteNulls(KernelContext* ctx, const ArrayData& cond, const Scalar& left, - const Scalar& right, ArrayData* output) { - if (!cond.MayHaveNulls() && left.is_valid && right.is_valid) { - return Status::OK(); // no nulls to handle - } - const int64_t len = cond.length; - - std::shared_ptr out_validity; - if (right.is_valid) { - // out_validity = ~cond.data - ARROW_ASSIGN_OR_RAISE( - out_validity, arrow::internal::InvertBitmap( - ctx->memory_pool(), cond.buffers[1]->data(), cond.offset, len)); - } else { - // out_validity = [0...] - ARROW_ASSIGN_OR_RAISE(out_validity, ctx->AllocateBitmap(len)); - } - - // out_validity = cond.data && left.valid || ~cond.data && right.valid - if (left.is_valid) { - arrow::internal::BitmapOr(out_validity->data(), 0, cond.buffers[1]->data(), - cond.offset, len, 0, out_validity->mutable_data()); - } - - // out_validity = cond.valid && (cond.data && left.valid || ~cond.data && right.valid) - if (cond.MayHaveNulls()) { - ::arrow::internal::BitmapAnd(out_validity->data(), 0, cond.buffers[0]->data(), - cond.offset, len, 0, out_validity->mutable_data()); - } - - output->buffers[0] = std::move(out_validity); - output->GetNullCount(); // update null count - return Status::OK(); -} - -// todo: this could be dangerous because the inverted arraydata buffer[1] may not be -// available outside Exec's scope -Status InvertBoolArrayData(KernelContext* ctx, const ArrayData& input, - ArrayData* output) { - // null buffer - if (input.MayHaveNulls()) { - output->buffers.emplace_back( - SliceBuffer(input.buffers[0], input.offset, input.length)); - } else { - output->buffers.push_back(NULLPTR); - } - - // data buffer - ARROW_ASSIGN_OR_RAISE( - std::shared_ptr inv_data, - arrow::internal::InvertBitmap(ctx->memory_pool(), input.buffers[1]->data(), - input.offset, input.length)); - output->buffers.emplace_back(std::move(inv_data)); - return Status::OK(); -} - */ - template struct IfElseFunctor {}; @@ -795,7 +649,6 @@ struct ResolveIfElseExec { // cond is scalar if (batch[0].is_scalar()) { const auto& cond = batch[0].scalar_as(); - if (batch[1].is_scalar() && batch[2].is_scalar()) { if (cond.is_valid) { *out = cond.value ? batch[1].scalar() : batch[2].scalar(); @@ -803,7 +656,8 @@ struct ResolveIfElseExec { *out = MakeNullScalar(batch[1].type()); } } else { // either left or right is an array. output is always an array - int64_t bcast_size = std::max(batch[1].length(), batch[2].length()); + // output size is the size of the array arg + int64_t bcast_size = batch[1].is_array() ? batch[1].length() : batch[2].length(); if (cond.is_valid) { const auto& valid_data = cond.value ? batch[1] : batch[2]; if (valid_data.is_array()) { diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc index 0eb87d9de264a..def502e34402d 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc @@ -24,15 +24,19 @@ namespace arrow { namespace compute { -void CheckIfElseOutputArray(const Datum& cond, const Datum& left, const Datum& right, - const Datum& expected) { +void CheckIfElseOutput(const Datum& cond, const Datum& left, const Datum& right, + const Datum& expected) { ASSERT_OK_AND_ASSIGN(Datum datum_out, IfElse(cond, left, right)); - std::shared_ptr result = datum_out.make_array(); - ASSERT_OK(result->ValidateFull()); - std::shared_ptr expected_ = expected.make_array(); - AssertArraysEqual(*expected.make_array(), *result, /*verbose=*/true); - - ASSERT_EQ(result->data()->null_count, expected_->data()->null_count); + if (datum_out.is_array()) { + std::shared_ptr result = datum_out.make_array(); + ASSERT_OK(result->ValidateFull()); + std::shared_ptr expected_ = expected.make_array(); + AssertArraysEqual(*expected_, *result, /*verbose=*/true); + } else { // expecting scalar + const std::shared_ptr& result = datum_out.scalar(); + const std::shared_ptr& expected_ = expected.scalar(); + AssertScalarsEqual(*expected_, *result, /*verbose=*/true); + } } void CheckIfElseOutputAAA(const std::shared_ptr& type, const std::string& cond, @@ -42,7 +46,7 @@ void CheckIfElseOutputAAA(const std::shared_ptr& type, const std::stri const std::shared_ptr& left_ = ArrayFromJSON(type, left); const std::shared_ptr& right_ = ArrayFromJSON(type, right); const std::shared_ptr& expected_ = ArrayFromJSON(type, expected); - CheckIfElseOutputArray(cond_, left_, right_, expected_); + CheckIfElseOutput(cond_, left_, right_, expected_); } void CheckIfElseOutputAAS(const std::shared_ptr& type, const std::string& cond, @@ -51,7 +55,7 @@ void CheckIfElseOutputAAS(const std::shared_ptr& type, const std::stri const std::shared_ptr& cond_ = ArrayFromJSON(boolean(), cond); const std::shared_ptr& left_ = ArrayFromJSON(type, left); const std::shared_ptr& expected_ = ArrayFromJSON(type, expected); - CheckIfElseOutputArray(cond_, left_, right, expected_); + CheckIfElseOutput(cond_, left_, right, expected_); } void CheckIfElseOutputASA(const std::shared_ptr& type, const std::string& cond, @@ -60,7 +64,7 @@ void CheckIfElseOutputASA(const std::shared_ptr& type, const std::stri const std::shared_ptr& cond_ = ArrayFromJSON(boolean(), cond); const std::shared_ptr& right_ = ArrayFromJSON(type, right); const std::shared_ptr& expected_ = ArrayFromJSON(type, expected); - CheckIfElseOutputArray(cond_, left, right_, expected_); + CheckIfElseOutput(cond_, left, right_, expected_); } void CheckIfElseOutputASS(const std::shared_ptr& type, const std::string& cond, @@ -69,7 +73,34 @@ void CheckIfElseOutputASS(const std::shared_ptr& type, const std::stri const std::string& expected) { const std::shared_ptr& cond_ = ArrayFromJSON(boolean(), cond); const std::shared_ptr& expected_ = ArrayFromJSON(type, expected); - CheckIfElseOutputArray(cond_, left, right, expected_); + CheckIfElseOutput(cond_, left, right, expected_); +} + +void CheckIfElseOutputSAA(const std::shared_ptr& type, + const std::shared_ptr& cond, const std::string& left, + const std::string& right, const std::string& expected) { + const std::shared_ptr& left_ = ArrayFromJSON(type, left); + const std::shared_ptr& right_ = ArrayFromJSON(type, right); + const std::shared_ptr& expected_ = ArrayFromJSON(type, expected); + CheckIfElseOutput(cond, left_, right_, expected_); +} + +void CheckIfElseOutputSAS(const std::shared_ptr& type, + const std::shared_ptr& cond, const std::string& left, + const std::shared_ptr& right, + const std::string& expected) { + const std::shared_ptr& left_ = ArrayFromJSON(type, left); + const std::shared_ptr& expected_ = ArrayFromJSON(type, expected); + CheckIfElseOutput(cond, left_, right, expected_); +} + +void CheckIfElseOutputSSA(const std::shared_ptr& type, + const std::shared_ptr& cond, + const std::shared_ptr& left, const std::string& right, + const std::string& expected) { + const std::shared_ptr& right_ = ArrayFromJSON(type, right); + const std::shared_ptr& expected_ = ArrayFromJSON(type, expected); + CheckIfElseOutput(cond, left, right_, expected_); } class TestIfElseKernel : public ::testing::Test {}; @@ -113,7 +144,7 @@ TYPED_TEST(TestIfElsePrimitive, IfElseFixedSizeRand) { } ASSERT_OK_AND_ASSIGN(auto expected_data, builder.Finish()); - CheckIfElseOutputArray(cond, left, right, expected_data); + CheckIfElseOutput(cond, left, right, expected_data); } /* @@ -244,6 +275,112 @@ TYPED_TEST(TestIfElsePrimitive, IfElseFixedSize) { // CLR = 000 CheckIfElseOutputASS(type, "[true, true, null, false]", null_scalar, null_scalar, "[null, null, null, null]"); + + // -------- Cond - Scalar, Left- Array, Right - Array --------- + ASSERT_OK_AND_ASSIGN(std::shared_ptr bool_true, MakeScalar(boolean(), true)); + ASSERT_OK_AND_ASSIGN(std::shared_ptr bool_false, MakeScalar(boolean(), false)); + std::shared_ptr bool_null = MakeNullScalar(boolean()); + + // empty + CheckIfElseOutputSAA(type, bool_true, "[]", "[]", "[]"); + // CLR = 111 + CheckIfElseOutputSAA(type, bool_true, "[1, 2, 3, 4]", "[5, 6, 7, 8]", "[1, 2, 3, 4]"); + // CLR = 011 + CheckIfElseOutputSAA(type, bool_null, "[1, 2, 3, 4]", "[5, 6, 7, 8]", + "[null, null, null, null]"); + // CLR = 101 + CheckIfElseOutputSAA(type, bool_false, "[1, null, 3, 4]", "[5, 6, 7, 8]", + "[5, 6, 7, 8]"); + // CLR = 001 + CheckIfElseOutputSAA(type, bool_null, "[1, null, 3, 4]", "[5, 6, 7, 8]", + "[null, null, null, null]"); + // CLR = 110 + CheckIfElseOutputSAA(type, bool_false, "[1, 2, 3, 4]", "[5, 6, 7, null]", + "[5, 6, 7, null]"); + // CLR = 010 + CheckIfElseOutputSAA(type, bool_null, "[1, 2, 3, 4]", "[5, 6, 7, null]", + "[null, null, null, null]"); + // CLR = 100 + CheckIfElseOutputSAA(type, bool_true, "[1, 2, null, null]", "[null, 6, 7, null]", + "[1, 2, null, null]"); + // CLR = 000 + CheckIfElseOutputSAA(type, bool_null, "[1, 2, null, null]", "[null, 6, 7, null]", + "[null, null, null, null]"); + + // -------- Cond - Scalar, Left- Array, Right - Scalar --------- + // empty + CheckIfElseOutputSAS(type, bool_true, "[]", valid_scalar, "[]"); + + // CLR = 111 + CheckIfElseOutputSAS(type, bool_true, "[1, 2, 3, 4]", valid_scalar, "[1, 2, 3, 4]"); + // CLR = 011 + CheckIfElseOutputSAS(type, bool_null, "[1, 2, 3, 4]", valid_scalar, + "[null, null, null, null]"); + // CLR = 101 + CheckIfElseOutputSAS(type, bool_false, "[1, null, 3, 4]", valid_scalar, + "[100, 100, 100, 100]"); + // CLR = 001 + CheckIfElseOutputSAS(type, bool_null, "[1, null, 3, 4]", valid_scalar, + "[null, null, null, null]"); + // CLR = 110 + CheckIfElseOutputSAS(type, bool_true, "[1, 2, 3, 4]", null_scalar, "[1, 2, 3, 4]"); + // CLR = 010 + CheckIfElseOutputSAS(type, bool_null, "[1, 2, 3, 4]", null_scalar, + "[null, null, null, null]"); + // CLR = 100 + CheckIfElseOutputSAS(type, bool_false, "[1, 2, null, null]", null_scalar, + "[null, null, null, null]"); + // CLR = 000 + CheckIfElseOutputSAS(type, bool_null, "[1, 2, null, null]", null_scalar, + "[null, null, null, null]"); + + // -------- Cond - Scalar, Left- Scalar, Right - Array --------- + // empty + CheckIfElseOutputSSA(type, bool_true, valid_scalar, "[]", "[]"); + + // CLR = 111 + CheckIfElseOutputSSA(type, bool_true, valid_scalar, "[1, 2, 3, 4]", + "[100, 100, 100, 100]"); + // CLR = 011 + CheckIfElseOutputSSA(type, bool_null, valid_scalar, "[1, 2, 3, 4]", + "[null, null, null, null]"); + // CLR = 110 + CheckIfElseOutputSSA(type, bool_false, valid_scalar, "[1, null, 3, null]", + "[1, null, 3, null]"); + // CLR = 010 + CheckIfElseOutputSSA(type, bool_null, valid_scalar, "[1, null, 3, null]", + "[null, null, null, null]"); + // CLR = 101 + CheckIfElseOutputSSA(type, bool_true, null_scalar, "[1, 2, 3, 4]", + "[null, null, null, null]"); + // CLR = 001 + CheckIfElseOutputSSA(type, bool_null, null_scalar, "[1, 2, 3, 4]", + "[null, null, null, null]"); + // CLR = 100 + CheckIfElseOutputSSA(type, bool_false, null_scalar, "[1, 2, null, 4]", + "[1, 2, null, 4]"); + // CLR = 000 + CheckIfElseOutputSSA(type, bool_null, null_scalar, "[1, 2, null, 4]", + "[null, null, null, null]"); + + // -------- Cond - Scalar, Left- Scalar, Right - Scalar --------- + + // CLR = 111 + CheckIfElseOutput(bool_false, valid_scalar, valid_scalar1, valid_scalar1); + // CLR = 011 + CheckIfElseOutput(bool_null, valid_scalar, valid_scalar1, null_scalar); + // CLR = 110 + CheckIfElseOutput(bool_true, valid_scalar, null_scalar, valid_scalar); + // CLR = 010 + CheckIfElseOutput(bool_null, valid_scalar, null_scalar, null_scalar); + // CLR = 101 + CheckIfElseOutput(bool_false, null_scalar, valid_scalar1, valid_scalar1); + // CLR = 001 + CheckIfElseOutput(bool_null, null_scalar, valid_scalar1, null_scalar); + // CLR = 100 + CheckIfElseOutput(bool_true, null_scalar, null_scalar, null_scalar); + // CLR = 000 + CheckIfElseOutput(bool_null, null_scalar, null_scalar, null_scalar); } TEST_F(TestIfElseKernel, IfElseBoolean) { @@ -259,7 +396,10 @@ TEST_F(TestIfElseKernel, IfElseBoolean) { CheckIfElseOutputAAA(type, "[true, true, true, false]", "[true, false, null, null]", "[null, false, true, null]", "[true, false, null, null]"); +} +TYPED_TEST(TestIfElsePrimitive, IfElseBooleanRand) { + auto type = boolean(); random::RandomArrayGenerator rand(/*seed=*/0); int64_t len = 1000; auto cond = std::static_pointer_cast( @@ -285,7 +425,7 @@ TEST_F(TestIfElseKernel, IfElseBoolean) { } ASSERT_OK_AND_ASSIGN(auto expected_data, builder.Finish()); - CheckIfElseOutputArray(cond, left, right, expected_data); + CheckIfElseOutput(cond, left, right, expected_data); } TEST_F(TestIfElseKernel, IfElseNull) { From 7d62139d423fd0fc40fa6d8c3231b168987ec54c Mon Sep 17 00:00:00 2001 From: niranda perera Date: Thu, 27 May 2021 13:36:39 -0400 Subject: [PATCH 20/39] extending tests for boolean type --- .../compute/kernels/scalar_if_else_test.cc | 229 +++++++++++++++++- 1 file changed, 224 insertions(+), 5 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc index def502e34402d..a9eea1bc50bbc 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc @@ -385,17 +385,236 @@ TYPED_TEST(TestIfElsePrimitive, IfElseFixedSize) { TEST_F(TestIfElseKernel, IfElseBoolean) { auto type = boolean(); - // No Nulls - CheckIfElseOutputAAA(type, "[]", "[]", "[]", "[]"); + // -------- All arrays --------- + // empty + CheckIfElseOutputAAA(type, "[]", "[]", "[]", "[]"); + // CLR = 111 CheckIfElseOutputAAA(type, "[true, true, true, false]", "[false, false, false, false]", "[true, true, true, true]", "[false, false, false, true]"); - + // CLR = 011 CheckIfElseOutputAAA(type, "[true, true, null, false]", "[false, false, false, false]", "[true, true, true, true]", "[false, false, null, true]"); + // CLR = 101 + CheckIfElseOutputAAA(type, "[true, true, true, false]", "[false, null, false, false]", + "[true, true, true, true]", "[false, null, false, true]"); + // CLR = 001 + CheckIfElseOutputAAA(type, "[true, true, null, false]", "[false, null, false, false]", + "[true, true, true, true]", "[false, null, null, true]"); + // CLR = 110 + CheckIfElseOutputAAA(type, "[true, true, true, false]", "[false, false, false, false]", + "[true, true, true, null]", "[false, false, false, null]"); + // CLR = 010 + CheckIfElseOutputAAA(type, "[null, true, true, false]", "[false, false, false, false]", + "[true, true, true, null]", "[null, false, false, null]"); + // CLR = 100 + CheckIfElseOutputAAA(type, "[true, true, true, false]", "[false, false, null, null]", + "[null, true, true, null]", "[false, false, null, null]"); + // CLR = 000 + CheckIfElseOutputAAA(type, "[null, true, true, false]", "[false, false, null, null]", + "[null, true, true, null]", "[null, false, null, null]"); + + // -------- Cond - Array, Left- Array, Right - Scalar --------- + + ASSERT_OK_AND_ASSIGN(std::shared_ptr valid_scalar, MakeScalar(type, false)); + std::shared_ptr null_scalar = MakeNullScalar(type); + + // empty + CheckIfElseOutputAAS(type, "[]", "[]", valid_scalar, "[]"); + + // CLR = 111 + CheckIfElseOutputAAS(type, "[true, true, true, false]", "[false, false, false, false]", + valid_scalar, "[false, false, false, false]"); + // CLR = 011 + CheckIfElseOutputAAS(type, "[true, true, null, false]", "[false, false, false, false]", + valid_scalar, "[false, false, null, false]"); + // CLR = 101 + CheckIfElseOutputAAS(type, "[true, true, true, false]", "[false, null, false, false]", + valid_scalar, "[false, null, false, false]"); + // CLR = 001 + CheckIfElseOutputAAS(type, "[true, true, null, false]", "[false, null, false, false]", + valid_scalar, "[false, null, null, false]"); + // CLR = 110 + CheckIfElseOutputAAS(type, "[true, true, true, false]", "[false, false, false, false]", + null_scalar, "[false, false, false, null]"); + // CLR = 010 + CheckIfElseOutputAAS(type, "[null, true, true, false]", "[false, false, false, false]", + null_scalar, "[null, false, false, null]"); + // CLR = 100 + CheckIfElseOutputAAS(type, "[true, true, true, false]", "[false, false, null, null]", + null_scalar, "[false, false, null, null]"); + // CLR = 000 + CheckIfElseOutputAAS(type, "[null, true, true, false]", "[false, false, null, null]", + null_scalar, "[null, false, null, null]"); + + // -------- Cond - Array, Left- Scalar, Right - Array --------- + // empty + CheckIfElseOutputASA(type, "[]", valid_scalar, "[]", "[]"); + + // CLR = 111 + CheckIfElseOutputASA(type, "[true, true, true, false]", valid_scalar, + "[false, false, false, false]", "[false, false, false, false]"); + // CLR = 011 + CheckIfElseOutputASA(type, "[true, true, null, false]", valid_scalar, + "[false, false, false, false]", "[false, false, null, false]"); + // CLR = 110 + CheckIfElseOutputASA(type, "[true, true, true, false]", valid_scalar, + "[false, null, false, null]", "[false, false, false, null]"); + // CLR = 010 + CheckIfElseOutputASA(type, "[true, true, null, false]", valid_scalar, + "[false, null, false, null]", "[false, false, null, null]"); + // CLR = 101 + CheckIfElseOutputASA(type, "[true, true, true, false]", null_scalar, + "[false, false, false, false]", "[null, null, null, false]"); + // CLR = 001 + CheckIfElseOutputASA(type, "[null, true, true, false]", null_scalar, + "[false, false, false, false]", "[null, null, null, false]"); + // CLR = 100 + CheckIfElseOutputASA(type, "[true, true, true, false]", null_scalar, + "[false, false, null, false]", "[null, null, null, false]"); + // CLR = 000 + CheckIfElseOutputASA(type, "[true, true, null, false]", null_scalar, + "[false, false, null, false]", "[null, null, null, false]"); + + // -------- Cond - Array, Left- Scalar, Right - Scalar --------- + ASSERT_OK_AND_ASSIGN(std::shared_ptr valid_scalar1, MakeScalar(type, true)); + + // empty + CheckIfElseOutputASS(type, "[]", valid_scalar, valid_scalar1, "[]"); + + // CLR = 111 + CheckIfElseOutputASS(type, "[true, true, true, false]", valid_scalar, valid_scalar1, + "[false, false, false, true]"); + // CLR = 011 + CheckIfElseOutputASS(type, "[true, true, null, false]", valid_scalar, valid_scalar1, + "[false, false, null, true]"); + // CLR = 010 + CheckIfElseOutputASS(type, "[true, true, null, false]", valid_scalar, null_scalar, + "[false, false, null, null]"); + // CLR = 110 + CheckIfElseOutputASS(type, "[true, true, true, false]", valid_scalar, null_scalar, + "[false, false, false, null]"); + // CLR = 101 + CheckIfElseOutputASS(type, "[true, true, true, false]", null_scalar, valid_scalar1, + "[null, null, null, true]"); + // CLR = 001 + CheckIfElseOutputASS(type, "[null, true, true, false]", null_scalar, valid_scalar1, + "[null, null, null, true]"); + // CLR = 100 + CheckIfElseOutputASS(type, "[true, true, true, false]", null_scalar, null_scalar, + "[null, null, null, null]"); + // CLR = 000 + CheckIfElseOutputASS(type, "[true, true, null, false]", null_scalar, null_scalar, + "[null, null, null, null]"); + + // -------- Cond - Scalar, Left- Array, Right - Array --------- + ASSERT_OK_AND_ASSIGN(std::shared_ptr bool_true, MakeScalar(type, true)); + ASSERT_OK_AND_ASSIGN(std::shared_ptr bool_false, MakeScalar(type, false)); + std::shared_ptr bool_null = MakeNullScalar(type); + + // empty + CheckIfElseOutputSAA(type, bool_true, "[]", "[]", "[]"); + // CLR = 111 + CheckIfElseOutputSAA(type, bool_true, "[false, false, false, false]", + "[true, true, true, true]", "[false, false, false, false]"); + // CLR = 011 + CheckIfElseOutputSAA(type, bool_null, "[false, false, false, false]", + "[true, true, true, true]", "[null, null, null, null]"); + // CLR = 101 + CheckIfElseOutputSAA(type, bool_false, "[false, null, false, false]", + "[true, true, true, true]", "[true, true, true, true]"); + // CLR = 001 + CheckIfElseOutputSAA(type, bool_null, "[false, null, false, false]", + "[true, true, true, true]", "[null, null, null, null]"); + // CLR = 110 + CheckIfElseOutputSAA(type, bool_false, "[false, false, false, false]", + "[true, true, true, null]", "[true, true, true, null]"); + // CLR = 010 + CheckIfElseOutputSAA(type, bool_null, "[false, false, false, false]", + "[true, true, true, null]", "[null, null, null, null]"); + // CLR = 100 + CheckIfElseOutputSAA(type, bool_true, "[false, false, null, null]", + "[null, true, true, null]", "[false, false, null, null]"); + // CLR = 000 + CheckIfElseOutputSAA(type, bool_null, "[false, false, null, null]", + "[null, true, true, null]", "[null, null, null, null]"); - CheckIfElseOutputAAA(type, "[true, true, true, false]", "[true, false, null, null]", - "[null, false, true, null]", "[true, false, null, null]"); + // -------- Cond - Scalar, Left- Array, Right - Scalar --------- + // empty + CheckIfElseOutputSAS(type, bool_true, "[]", valid_scalar, "[]"); + + // CLR = 111 + CheckIfElseOutputSAS(type, bool_true, "[false, false, false, false]", valid_scalar, + "[false, false, false, false]"); + // CLR = 011 + CheckIfElseOutputSAS(type, bool_null, "[false, false, false, false]", valid_scalar, + "[null, null, null, null]"); + // CLR = 101 + CheckIfElseOutputSAS(type, bool_false, "[false, null, false, false]", valid_scalar, + "[false, false, false, false]"); + // CLR = 001 + CheckIfElseOutputSAS(type, bool_null, "[false, null, false, false]", valid_scalar, + "[null, null, null, null]"); + // CLR = 110 + CheckIfElseOutputSAS(type, bool_true, "[false, false, false, false]", null_scalar, + "[false, false, false, false]"); + // CLR = 010 + CheckIfElseOutputSAS(type, bool_null, "[false, false, false, false]", null_scalar, + "[null, null, null, null]"); + // CLR = 100 + CheckIfElseOutputSAS(type, bool_false, "[false, false, null, null]", null_scalar, + "[null, null, null, null]"); + // CLR = 000 + CheckIfElseOutputSAS(type, bool_null, "[false, false, null, null]", null_scalar, + "[null, null, null, null]"); + + // -------- Cond - Scalar, Left- Scalar, Right - Array --------- + // empty + CheckIfElseOutputSSA(type, bool_true, valid_scalar, "[]", "[]"); + + // CLR = 111 + CheckIfElseOutputSSA(type, bool_true, valid_scalar, "[false, false, false, false]", + "[false, false, false, false]"); + // CLR = 011 + CheckIfElseOutputSSA(type, bool_null, valid_scalar, "[false, false, false, false]", + "[null, null, null, null]"); + // CLR = 110 + CheckIfElseOutputSSA(type, bool_false, valid_scalar, "[false, null, false, null]", + "[false, null, false, null]"); + // CLR = 010 + CheckIfElseOutputSSA(type, bool_null, valid_scalar, "[false, null, false, null]", + "[null, null, null, null]"); + // CLR = 101 + CheckIfElseOutputSSA(type, bool_true, null_scalar, "[false, false, false, false]", + "[null, null, null, null]"); + // CLR = 001 + CheckIfElseOutputSSA(type, bool_null, null_scalar, "[false, false, false, false]", + "[null, null, null, null]"); + // CLR = 100 + CheckIfElseOutputSSA(type, bool_false, null_scalar, "[false, false, null, false]", + "[false, false, null, false]"); + // CLR = 000 + CheckIfElseOutputSSA(type, bool_null, null_scalar, "[false, false, null, false]", + "[null, null, null, null]"); + + // -------- Cond - Scalar, Left- Scalar, Right - Scalar --------- + + // CLR = 111 + CheckIfElseOutput(bool_false, valid_scalar, valid_scalar1, valid_scalar1); + // CLR = 011 + CheckIfElseOutput(bool_null, valid_scalar, valid_scalar1, null_scalar); + // CLR = 110 + CheckIfElseOutput(bool_true, valid_scalar, null_scalar, valid_scalar); + // CLR = 010 + CheckIfElseOutput(bool_null, valid_scalar, null_scalar, null_scalar); + // CLR = 101 + CheckIfElseOutput(bool_false, null_scalar, valid_scalar1, valid_scalar1); + // CLR = 001 + CheckIfElseOutput(bool_null, null_scalar, valid_scalar1, null_scalar); + // CLR = 100 + CheckIfElseOutput(bool_true, null_scalar, null_scalar, null_scalar); + // CLR = 000 + CheckIfElseOutput(bool_null, null_scalar, null_scalar, null_scalar); } TYPED_TEST(TestIfElsePrimitive, IfElseBooleanRand) { From 874b0380425715f29d18cfb2662d779c99e355a9 Mon Sep 17 00:00:00 2001 From: niranda perera Date: Thu, 27 May 2021 16:14:11 -0400 Subject: [PATCH 21/39] making tests extensible --- .../compute/kernels/scalar_if_else_test.cc | 728 +++++++----------- 1 file changed, 273 insertions(+), 455 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc index a9eea1bc50bbc..a7c985bf4169c 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc @@ -147,6 +147,277 @@ TYPED_TEST(TestIfElsePrimitive, IfElseFixedSizeRand) { CheckIfElseOutput(cond, left, right, expected_data); } +#define IF_ELSE_TEST_GEN(type, l0, l1, l2, l3, r0, r1, r2, r3, valid, valid1) \ + do { \ + /* -------- All arrays --------- */ \ + /* empty */ \ + CheckIfElseOutputAAA((type), "[]", "[]", "[]", "[]"); \ + /* CLR = 111 */ \ + CheckIfElseOutputAAA( \ + (type), "[true, true, true, false]", "[" #l0 ", " #l1 ", " #l2 ", " #l3 "]", \ + "[" #r0 ", " #r1 ", " #r2 ", " #r3 "]", "[" #l0 ", " #l1 ", " #l2 ", " #r3 "]"); \ + /* CLR = 011 */ \ + CheckIfElseOutputAAA( \ + (type), "[true, true, null, false]", "[" #l0 ", " #l1 ", " #l2 ", " #l3 "]", \ + "[" #r0 ", " #r1 ", " #r2 ", " #r3 "]", "[" #l0 ", " #l1 ", null, " #r3 "]"); \ + /* CLR = 101 */ \ + CheckIfElseOutputAAA( \ + (type), "[true, true, true, false]", "[" #l0 ", null, " #l2 ", " #l3 "]", \ + "[" #r0 ", " #r1 ", " #r2 ", " #r3 "]", "[" #l0 ", null, " #l2 ", " #r3 "]"); \ + /* CLR = 001 */ \ + CheckIfElseOutputAAA( \ + (type), "[true, true, null, false]", "[" #l0 ", null, " #l2 ", " #l3 "]", \ + "[" #r0 ", " #r1 ", " #r2 ", " #r3 "]", "[" #l0 ", null, null, " #r3 "]"); \ + /* CLR = 110 */ \ + CheckIfElseOutputAAA( \ + (type), "[true, true, true, false]", "[" #l0 ", " #l1 ", " #l2 ", " #l3 "]", \ + "[" #r0 ", " #r1 ", " #r2 ", null]", "[" #l0 ", " #l1 ", " #l2 ", null]"); \ + /* CLR = 010 */ \ + CheckIfElseOutputAAA( \ + (type), "[null, true, true, false]", "[" #l0 ", " #l1 ", " #l2 ", " #l3 "]", \ + "[" #r0 ", " #r1 ", " #r2 ", null]", "[null, " #l1 ", " #l2 ", null]"); \ + /* CLR = 100 */ \ + CheckIfElseOutputAAA( \ + (type), "[true, true, true, false]", "[" #l0 ", " #l1 ", null, null]", \ + "[null, " #r1 ", " #r2 ", null]", "[" #l0 ", " #l1 ", null, null]"); \ + /* CLR = 000 */ \ + CheckIfElseOutputAAA( \ + (type), "[null, true, true, false]", "[" #l0 ", " #l1 ", null, null]", \ + "[null, " #r1 ", " #r2 ", null]", "[null, " #l1 ", null, null]"); \ + \ + /* -------- Cond - Array, Left- Array, Right - Scalar --------- */ \ + auto valid_scalar = MakeScalar((type), (valid)).ValueOrDie(); \ + auto valid_scalar1 = MakeScalar((type), (valid1)).ValueOrDie(); \ + std::shared_ptr null_scalar = MakeNullScalar(type); \ + \ + /* empty */ \ + CheckIfElseOutputAAS((type), "[]", "[]", valid_scalar, "[]"); \ + \ + /* CLR = 111 */ \ + CheckIfElseOutputAAS((type), "[true, true, true, false]", \ + "[" #l0 ", " #l1 ", " #l2 ", " #l3 "]", valid_scalar, \ + "[" #l0 ", " #l1 ", " #l2 ", " #valid "]"); \ + /* CLR = 011 */ \ + CheckIfElseOutputAAS((type), "[true, true, null, false]", \ + "[" #l0 ", " #l1 ", " #l2 ", " #l3 "]", valid_scalar, \ + "[" #l0 ", " #l1 ", null, " #valid "]"); \ + /* CLR = 101 */ \ + CheckIfElseOutputAAS((type), "[true, true, true, false]", \ + "[" #l0 ", null, " #l2 ", " #l3 "]", valid_scalar, \ + "[" #l0 ", null, " #l2 ", " #valid "]"); \ + /* CLR = 001 */ \ + CheckIfElseOutputAAS((type), "[true, true, null, false]", \ + "[" #l0 ", null, " #l2 ", " #l3 "]", valid_scalar, \ + "[" #l0 ", null, null, " #valid "]"); \ + /* CLR = 110 */ \ + CheckIfElseOutputAAS((type), "[true, true, true, false]", \ + "[" #l0 ", " #l1 ", " #l2 ", " #l3 "]", null_scalar, \ + "[" #l0 ", " #l1 ", " #l2 ", null]"); \ + /* CLR = 010 */ \ + CheckIfElseOutputAAS((type), "[null, true, true, false]", \ + "[" #l0 ", " #l1 ", " #l2 ", " #l3 "]", null_scalar, \ + "[null, " #l1 ", " #l2 ", null]"); \ + /* CLR = 100 */ \ + CheckIfElseOutputAAS((type), "[true, true, true, false]", \ + "[" #l0 ", " #l1 ", null, null]", null_scalar, \ + "[" #l0 ", " #l1 ", null, null]"); \ + /* CLR = 000 */ \ + CheckIfElseOutputAAS((type), "[null, true, true, false]", \ + "[" #l0 ", " #l1 ", null, null]", null_scalar, \ + "[null, " #l1 ", null, null]"); \ + \ + /* -------- Cond - Array, Left- Scalar, Right - Array --------- */ \ + /* empty */ \ + CheckIfElseOutputASA((type), "[]", valid_scalar, "[]", "[]"); \ + \ + /* CLR = 111 */ \ + CheckIfElseOutputASA((type), "[true, true, true, false]", valid_scalar, \ + "[" #l0 ", " #l1 ", " #l2 ", " #l3 "]", \ + "[" #valid ", " #valid ", " #valid ", " #l3 "]"); \ + /* CLR = 011 */ \ + CheckIfElseOutputASA((type), "[true, true, null, false]", valid_scalar, \ + "[" #l0 ", " #l1 ", " #l2 ", " #l3 "]", \ + "[" #valid ", " #valid ", null, " #l3 "]"); \ + /* CLR = 110 */ \ + CheckIfElseOutputASA((type), "[true, true, true, false]", valid_scalar, \ + "[" #l0 ", null, " #l2 ", null]", \ + "[" #valid ", " #valid ", " #valid ", null]"); \ + /* CLR = 010 */ \ + CheckIfElseOutputASA((type), "[true, true, null, false]", valid_scalar, \ + "[" #l0 ", null, " #l2 ", null]", \ + "[" #valid ", " #valid ", null, null]"); \ + /* CLR = 101 */ \ + CheckIfElseOutputASA((type), "[true, true, true, false]", null_scalar, \ + "[" #l0 ", " #l1 ", " #l2 ", " #l3 "]", \ + "[null, null, null, " #l3 "]"); \ + /* CLR = 001 */ \ + CheckIfElseOutputASA((type), "[null, true, true, false]", null_scalar, \ + "[" #l0 ", " #l1 ", " #l2 ", " #l3 "]", \ + "[null, null, null, " #l3 "]"); \ + /* CLR = 100 */ \ + CheckIfElseOutputASA((type), "[true, true, true, false]", null_scalar, \ + "[" #l0 ", " #l1 ", null, " #l3 "]", \ + "[null, null, null, " #l3 "]"); \ + /* CLR = 000 */ \ + CheckIfElseOutputASA((type), "[true, true, null, false]", null_scalar, \ + "[" #l0 ", " #l1 ", null, " #l3 "]", \ + "[null, null, null, " #l3 "]"); \ + \ + /* -------- Cond - Array, Left- Scalar, Right - Scalar --------- */ \ + /* empty */ \ + CheckIfElseOutputASS((type), "[]", valid_scalar, valid_scalar1, "[]"); \ + \ + /* CLR = 111 */ \ + CheckIfElseOutputASS((type), "[true, true, true, false]", valid_scalar, \ + valid_scalar1, \ + "[" #valid ", " #valid ", " #valid ", " #valid1 "]"); \ + /* CLR = 011 */ \ + CheckIfElseOutputASS((type), "[true, true, null, false]", valid_scalar, \ + valid_scalar1, "[" #valid ", " #valid ", null, " #valid1 "]"); \ + /* CLR = 010 */ \ + CheckIfElseOutputASS((type), "[true, true, null, false]", valid_scalar, null_scalar, \ + "[" #valid ", " #valid ", null, null]"); \ + /* CLR = 110 */ \ + CheckIfElseOutputASS((type), "[true, true, true, false]", valid_scalar, null_scalar, \ + "[" #valid ", " #valid ", " #valid ", null]"); \ + /* CLR = 101 */ \ + CheckIfElseOutputASS((type), "[true, true, true, false]", null_scalar, \ + valid_scalar1, "[null, null, null, " #valid1 "]"); \ + /* CLR = 001 */ \ + CheckIfElseOutputASS((type), "[null, true, true, false]", null_scalar, \ + valid_scalar1, "[null, null, null, " #valid1 "]"); \ + /* CLR = 100 */ \ + CheckIfElseOutputASS((type), "[true, true, true, false]", null_scalar, null_scalar, \ + "[null, null, null, null]"); \ + /* CLR = 000 */ \ + CheckIfElseOutputASS((type), "[true, true, null, false]", null_scalar, null_scalar, \ + "[null, null, null, null]"); \ + \ + /* -------- Cond - Scalar, Left- Array, Right - Array --------- */ \ + auto bool_true = MakeScalar(boolean(), true).ValueOrDie(); \ + auto bool_false = MakeScalar(boolean(), false).ValueOrDie(); \ + std::shared_ptr bool_null = MakeNullScalar(boolean()); \ + \ + /* empty */ \ + CheckIfElseOutputSAA((type), bool_true, "[]", "[]", "[]"); \ + /* CLR = 111 */ \ + CheckIfElseOutputSAA((type), bool_true, "[" #l0 ", " #l1 ", " #l2 ", " #l3 "]", \ + "[" #r0 ", " #r1 ", " #r2 ", " #r3 "]", \ + "[" #l0 ", " #l1 ", " #l2 ", " #l3 "]"); \ + /* CLR = 011 */ \ + CheckIfElseOutputSAA((type), bool_null, "[" #l0 ", " #l1 ", " #l2 ", " #l3 "]", \ + "[" #r0 ", " #r1 ", " #r2 ", " #r3 "]", \ + "[null, null, null, null]"); \ + /* CLR = 101 */ \ + CheckIfElseOutputSAA((type), bool_false, "[" #l0 ", null, " #l2 ", " #l3 "]", \ + "[" #r0 ", " #r1 ", " #r2 ", " #r3 "]", \ + "[" #r0 ", " #r1 ", " #r2 ", " #r3 "]"); \ + /* CLR = 001 */ \ + CheckIfElseOutputSAA((type), bool_null, "[" #l0 ", null, " #l2 ", " #l3 "]", \ + "[" #r0 ", " #r1 ", " #r2 ", " #r3 "]", \ + "[null, null, null, null]"); \ + /* CLR = 110 */ \ + CheckIfElseOutputSAA((type), bool_false, "[" #l0 ", " #l1 ", " #l2 ", " #l3 "]", \ + "[" #r0 ", " #r1 ", " #r2 ", null]", \ + "[" #r0 ", " #r1 ", " #r2 ", null]"); \ + /* CLR = 010 */ \ + CheckIfElseOutputSAA((type), bool_null, "[" #l0 ", " #l1 ", " #l2 ", " #l3 "]", \ + "[" #r0 ", " #r1 ", " #r2 ", null]", \ + "[null, null, null, null]"); \ + /* CLR = 100 */ \ + CheckIfElseOutputSAA((type), bool_true, "[" #l0 ", " #l1 ", null, null]", \ + "[null, " #r1 ", " #r2 ", null]", \ + "[" #l0 ", " #l1 ", null, null]"); \ + /* CLR = 000 */ \ + CheckIfElseOutputSAA((type), bool_null, "[" #l0 ", " #l1 ", null, null]", \ + "[null, " #r1 ", " #r2 ", null]", "[null, null, null, null]"); \ + \ + /* -------- Cond - Scalar, Left- Array, Right - Scalar --------- */ \ + /* empty */ \ + CheckIfElseOutputSAS((type), bool_true, "[]", valid_scalar, "[]"); \ + \ + /* CLR = 111 */ \ + CheckIfElseOutputSAS((type), bool_true, "[" #l0 ", " #l1 ", " #l2 ", " #l3 "]", \ + valid_scalar, "[" #l0 ", " #l1 ", " #l2 ", " #l3 "]"); \ + /* CLR = 011 */ \ + CheckIfElseOutputSAS((type), bool_null, "[" #l0 ", " #l1 ", " #l2 ", " #l3 "]", \ + valid_scalar, "[null, null, null, null]"); \ + /* CLR = 101 */ \ + CheckIfElseOutputSAS((type), bool_false, "[" #l0 ", null, " #l2 ", " #l3 "]", \ + valid_scalar, \ + "[" #valid ", " #valid ", " #valid ", " #valid "]"); \ + /* CLR = 001 */ \ + CheckIfElseOutputSAS((type), bool_null, "[" #l0 ", null, " #l2 ", " #l3 "]", \ + valid_scalar, "[null, null, null, null]"); \ + /* CLR = 110 */ \ + CheckIfElseOutputSAS((type), bool_true, "[" #l0 ", " #l1 ", " #l2 ", " #l3 "]", \ + null_scalar, "[" #l0 ", " #l1 ", " #l2 ", " #l3 "]"); \ + /* CLR = 010 */ \ + CheckIfElseOutputSAS((type), bool_null, "[" #l0 ", " #l1 ", " #l2 ", " #l3 "]", \ + null_scalar, "[null, null, null, null]"); \ + /* CLR = 100 */ \ + CheckIfElseOutputSAS((type), bool_false, "[" #l0 ", " #l1 ", null, null]", \ + null_scalar, "[null, null, null, null]"); \ + /* CLR = 000 */ \ + CheckIfElseOutputSAS((type), bool_null, "[" #l0 ", " #l1 ", null, null]", \ + null_scalar, "[null, null, null, null]"); \ + \ + /* -------- Cond - Scalar, Left- Scalar, Right - Array --------- */ \ + /* empty */ \ + CheckIfElseOutputSSA((type), bool_true, valid_scalar, "[]", "[]"); \ + \ + /* CLR = 111 */ \ + CheckIfElseOutputSSA((type), bool_true, valid_scalar, \ + "[" #l0 ", " #l1 ", " #l2 ", " #l3 "]", \ + "[" #valid ", " #valid ", " #valid ", " #valid "]"); \ + /* CLR = 011 */ \ + CheckIfElseOutputSSA((type), bool_null, valid_scalar, \ + "[" #l0 ", " #l1 ", " #l2 ", " #l3 "]", \ + "[null, null, null, null]"); \ + /* CLR = 110 */ \ + CheckIfElseOutputSSA((type), bool_false, valid_scalar, \ + "[" #l0 ", null, " #l2 ", null]", \ + "[" #l0 ", null, " #l2 ", null]"); \ + /* CLR = 010 */ \ + CheckIfElseOutputSSA((type), bool_null, valid_scalar, \ + "[" #l0 ", null, " #l2 ", null]", "[null, null, null, null]"); \ + /* CLR = 101 */ \ + CheckIfElseOutputSSA((type), bool_true, null_scalar, \ + "[" #l0 ", " #l1 ", " #l2 ", " #l3 "]", \ + "[null, null, null, null]"); \ + /* CLR = 001 */ \ + CheckIfElseOutputSSA((type), bool_null, null_scalar, \ + "[" #l0 ", " #l1 ", " #l2 ", " #l3 "]", \ + "[null, null, null, null]"); \ + /* CLR = 100 */ \ + CheckIfElseOutputSSA((type), bool_false, null_scalar, \ + "[" #l0 ", " #l1 ", null, " #l3 "]", \ + "[" #l0 ", " #l1 ", null, " #l3 "]"); \ + /* CLR = 000 */ \ + CheckIfElseOutputSSA((type), bool_null, null_scalar, \ + "[" #l0 ", " #l1 ", null, " #l3 "]", \ + "[null, null, null, null]"); \ + \ + /* -------- Cond - Scalar, Left- Scalar, Right - Scalar --------- */ \ + \ + /* CLR = 111 */ \ + CheckIfElseOutput(bool_false, valid_scalar, valid_scalar1, valid_scalar1); \ + /* CLR = 011 */ \ + CheckIfElseOutput(bool_null, valid_scalar, valid_scalar1, null_scalar); \ + /* CLR = 110 */ \ + CheckIfElseOutput(bool_true, valid_scalar, null_scalar, valid_scalar); \ + /* CLR = 010 */ \ + CheckIfElseOutput(bool_null, valid_scalar, null_scalar, null_scalar); \ + /* CLR = 101 */ \ + CheckIfElseOutput(bool_false, null_scalar, valid_scalar1, valid_scalar1); \ + /* CLR = 001 */ \ + CheckIfElseOutput(bool_null, null_scalar, valid_scalar1, null_scalar); \ + /* CLR = 100 */ \ + CheckIfElseOutput(bool_true, null_scalar, null_scalar, null_scalar); \ + /* CLR = 000 */ \ + CheckIfElseOutput(bool_null, null_scalar, null_scalar, null_scalar); \ + } while (0) + /* * Legend: * C - Cond, L - Left, R - Right @@ -155,466 +426,13 @@ TYPED_TEST(TestIfElsePrimitive, IfElseFixedSizeRand) { TYPED_TEST(TestIfElsePrimitive, IfElseFixedSize) { auto type = TypeTraits::type_singleton(); - // -------- All arrays --------- - // empty - CheckIfElseOutputAAA(type, "[]", "[]", "[]", "[]"); - // CLR = 111 - CheckIfElseOutputAAA(type, "[true, true, true, false]", "[1, 2, 3, 4]", "[5, 6, 7, 8]", - "[1, 2, 3, 8]"); - // CLR = 011 - CheckIfElseOutputAAA(type, "[true, true, null, false]", "[1, 2, 3, 4]", "[5, 6, 7, 8]", - "[1, 2, null, 8]"); - // CLR = 101 - CheckIfElseOutputAAA(type, "[true, true, true, false]", "[1, null, 3, 4]", - "[5, 6, 7, 8]", "[1, null, 3, 8]"); - // CLR = 001 - CheckIfElseOutputAAA(type, "[true, true, null, false]", "[1, null, 3, 4]", - "[5, 6, 7, 8]", "[1, null, null, 8]"); - // CLR = 110 - CheckIfElseOutputAAA(type, "[true, true, true, false]", "[1, 2, 3, 4]", - "[5, 6, 7, null]", "[1, 2, 3, null]"); - // CLR = 010 - CheckIfElseOutputAAA(type, "[null, true, true, false]", "[1, 2, 3, 4]", - "[5, 6, 7, null]", "[null, 2, 3, null]"); - // CLR = 100 - CheckIfElseOutputAAA(type, "[true, true, true, false]", "[1, 2, null, null]", - "[null, 6, 7, null]", "[1, 2, null, null]"); - // CLR = 000 - CheckIfElseOutputAAA(type, "[null, true, true, false]", "[1, 2, null, null]", - "[null, 6, 7, null]", "[null, 2, null, null]"); - - // -------- Cond - Array, Left- Array, Right - Scalar --------- - - ASSERT_OK_AND_ASSIGN(std::shared_ptr valid_scalar, MakeScalar(type, 100)); - std::shared_ptr null_scalar = MakeNullScalar(type); - - // empty - CheckIfElseOutputAAS(type, "[]", "[]", valid_scalar, "[]"); - - // CLR = 111 - CheckIfElseOutputAAS(type, "[true, true, true, false]", "[1, 2, 3, 4]", valid_scalar, - "[1, 2, 3, 100]"); - // CLR = 011 - CheckIfElseOutputAAS(type, "[true, true, null, false]", "[1, 2, 3, 4]", valid_scalar, - "[1, 2, null, 100]"); - // CLR = 101 - CheckIfElseOutputAAS(type, "[true, true, true, false]", "[1, null, 3, 4]", valid_scalar, - "[1, null, 3, 100]"); - // CLR = 001 - CheckIfElseOutputAAS(type, "[true, true, null, false]", "[1, null, 3, 4]", valid_scalar, - "[1, null, null, 100]"); - // CLR = 110 - CheckIfElseOutputAAS(type, "[true, true, true, false]", "[1, 2, 3, 4]", null_scalar, - "[1, 2, 3, null]"); - // CLR = 010 - CheckIfElseOutputAAS(type, "[null, true, true, false]", "[1, 2, 3, 4]", null_scalar, - "[null, 2, 3, null]"); - // CLR = 100 - CheckIfElseOutputAAS(type, "[true, true, true, false]", "[1, 2, null, null]", - null_scalar, "[1, 2, null, null]"); - // CLR = 000 - CheckIfElseOutputAAS(type, "[null, true, true, false]", "[1, 2, null, null]", - null_scalar, "[null, 2, null, null]"); - - // -------- Cond - Array, Left- Scalar, Right - Array --------- - // empty - CheckIfElseOutputASA(type, "[]", valid_scalar, "[]", "[]"); - - // CLR = 111 - CheckIfElseOutputASA(type, "[true, true, true, false]", valid_scalar, "[1, 2, 3, 4]", - "[100, 100, 100, 4]"); - // CLR = 011 - CheckIfElseOutputASA(type, "[true, true, null, false]", valid_scalar, "[1, 2, 3, 4]", - "[100, 100, null, 4]"); - // CLR = 110 - CheckIfElseOutputASA(type, "[true, true, true, false]", valid_scalar, - "[1, null, 3, null]", "[100, 100, 100, null]"); - // CLR = 010 - CheckIfElseOutputASA(type, "[true, true, null, false]", valid_scalar, - "[1, null, 3, null]", "[100, 100, null, null]"); - // CLR = 101 - CheckIfElseOutputASA(type, "[true, true, true, false]", null_scalar, "[1, 2, 3, 4]", - "[null, null, null, 4]"); - // CLR = 001 - CheckIfElseOutputASA(type, "[null, true, true, false]", null_scalar, "[1, 2, 3, 4]", - "[null, null, null, 4]"); - // CLR = 100 - CheckIfElseOutputASA(type, "[true, true, true, false]", null_scalar, "[1, 2, null, 4]", - "[null, null, null, 4]"); - // CLR = 000 - CheckIfElseOutputASA(type, "[true, true, null, false]", null_scalar, "[1, 2, null, 4]", - "[null, null, null, 4]"); - - // -------- Cond - Array, Left- Scalar, Right - Scalar --------- - ASSERT_OK_AND_ASSIGN(std::shared_ptr valid_scalar1, MakeScalar(type, 111)); - - // empty - CheckIfElseOutputASS(type, "[]", valid_scalar, valid_scalar1, "[]"); - - // CLR = 111 - CheckIfElseOutputASS(type, "[true, true, true, false]", valid_scalar, valid_scalar1, - "[100, 100, 100, 111]"); - // CLR = 011 - CheckIfElseOutputASS(type, "[true, true, null, false]", valid_scalar, valid_scalar1, - "[100, 100, null, 111]"); - // CLR = 010 - CheckIfElseOutputASS(type, "[true, true, null, false]", valid_scalar, null_scalar, - "[100, 100, null, null]"); - // CLR = 110 - CheckIfElseOutputASS(type, "[true, true, true, false]", valid_scalar, null_scalar, - "[100, 100, 100, null]"); - // CLR = 101 - CheckIfElseOutputASS(type, "[true, true, true, false]", null_scalar, valid_scalar1, - "[null, null, null, 111]"); - // CLR = 001 - CheckIfElseOutputASS(type, "[null, true, true, false]", null_scalar, valid_scalar1, - "[null, null, null, 111]"); - // CLR = 100 - CheckIfElseOutputASS(type, "[true, true, true, false]", null_scalar, null_scalar, - "[null, null, null, null]"); - // CLR = 000 - CheckIfElseOutputASS(type, "[true, true, null, false]", null_scalar, null_scalar, - "[null, null, null, null]"); - - // -------- Cond - Scalar, Left- Array, Right - Array --------- - ASSERT_OK_AND_ASSIGN(std::shared_ptr bool_true, MakeScalar(boolean(), true)); - ASSERT_OK_AND_ASSIGN(std::shared_ptr bool_false, MakeScalar(boolean(), false)); - std::shared_ptr bool_null = MakeNullScalar(boolean()); - - // empty - CheckIfElseOutputSAA(type, bool_true, "[]", "[]", "[]"); - // CLR = 111 - CheckIfElseOutputSAA(type, bool_true, "[1, 2, 3, 4]", "[5, 6, 7, 8]", "[1, 2, 3, 4]"); - // CLR = 011 - CheckIfElseOutputSAA(type, bool_null, "[1, 2, 3, 4]", "[5, 6, 7, 8]", - "[null, null, null, null]"); - // CLR = 101 - CheckIfElseOutputSAA(type, bool_false, "[1, null, 3, 4]", "[5, 6, 7, 8]", - "[5, 6, 7, 8]"); - // CLR = 001 - CheckIfElseOutputSAA(type, bool_null, "[1, null, 3, 4]", "[5, 6, 7, 8]", - "[null, null, null, null]"); - // CLR = 110 - CheckIfElseOutputSAA(type, bool_false, "[1, 2, 3, 4]", "[5, 6, 7, null]", - "[5, 6, 7, null]"); - // CLR = 010 - CheckIfElseOutputSAA(type, bool_null, "[1, 2, 3, 4]", "[5, 6, 7, null]", - "[null, null, null, null]"); - // CLR = 100 - CheckIfElseOutputSAA(type, bool_true, "[1, 2, null, null]", "[null, 6, 7, null]", - "[1, 2, null, null]"); - // CLR = 000 - CheckIfElseOutputSAA(type, bool_null, "[1, 2, null, null]", "[null, 6, 7, null]", - "[null, null, null, null]"); - - // -------- Cond - Scalar, Left- Array, Right - Scalar --------- - // empty - CheckIfElseOutputSAS(type, bool_true, "[]", valid_scalar, "[]"); - - // CLR = 111 - CheckIfElseOutputSAS(type, bool_true, "[1, 2, 3, 4]", valid_scalar, "[1, 2, 3, 4]"); - // CLR = 011 - CheckIfElseOutputSAS(type, bool_null, "[1, 2, 3, 4]", valid_scalar, - "[null, null, null, null]"); - // CLR = 101 - CheckIfElseOutputSAS(type, bool_false, "[1, null, 3, 4]", valid_scalar, - "[100, 100, 100, 100]"); - // CLR = 001 - CheckIfElseOutputSAS(type, bool_null, "[1, null, 3, 4]", valid_scalar, - "[null, null, null, null]"); - // CLR = 110 - CheckIfElseOutputSAS(type, bool_true, "[1, 2, 3, 4]", null_scalar, "[1, 2, 3, 4]"); - // CLR = 010 - CheckIfElseOutputSAS(type, bool_null, "[1, 2, 3, 4]", null_scalar, - "[null, null, null, null]"); - // CLR = 100 - CheckIfElseOutputSAS(type, bool_false, "[1, 2, null, null]", null_scalar, - "[null, null, null, null]"); - // CLR = 000 - CheckIfElseOutputSAS(type, bool_null, "[1, 2, null, null]", null_scalar, - "[null, null, null, null]"); - - // -------- Cond - Scalar, Left- Scalar, Right - Array --------- - // empty - CheckIfElseOutputSSA(type, bool_true, valid_scalar, "[]", "[]"); - - // CLR = 111 - CheckIfElseOutputSSA(type, bool_true, valid_scalar, "[1, 2, 3, 4]", - "[100, 100, 100, 100]"); - // CLR = 011 - CheckIfElseOutputSSA(type, bool_null, valid_scalar, "[1, 2, 3, 4]", - "[null, null, null, null]"); - // CLR = 110 - CheckIfElseOutputSSA(type, bool_false, valid_scalar, "[1, null, 3, null]", - "[1, null, 3, null]"); - // CLR = 010 - CheckIfElseOutputSSA(type, bool_null, valid_scalar, "[1, null, 3, null]", - "[null, null, null, null]"); - // CLR = 101 - CheckIfElseOutputSSA(type, bool_true, null_scalar, "[1, 2, 3, 4]", - "[null, null, null, null]"); - // CLR = 001 - CheckIfElseOutputSSA(type, bool_null, null_scalar, "[1, 2, 3, 4]", - "[null, null, null, null]"); - // CLR = 100 - CheckIfElseOutputSSA(type, bool_false, null_scalar, "[1, 2, null, 4]", - "[1, 2, null, 4]"); - // CLR = 000 - CheckIfElseOutputSSA(type, bool_null, null_scalar, "[1, 2, null, 4]", - "[null, null, null, null]"); - - // -------- Cond - Scalar, Left- Scalar, Right - Scalar --------- - - // CLR = 111 - CheckIfElseOutput(bool_false, valid_scalar, valid_scalar1, valid_scalar1); - // CLR = 011 - CheckIfElseOutput(bool_null, valid_scalar, valid_scalar1, null_scalar); - // CLR = 110 - CheckIfElseOutput(bool_true, valid_scalar, null_scalar, valid_scalar); - // CLR = 010 - CheckIfElseOutput(bool_null, valid_scalar, null_scalar, null_scalar); - // CLR = 101 - CheckIfElseOutput(bool_false, null_scalar, valid_scalar1, valid_scalar1); - // CLR = 001 - CheckIfElseOutput(bool_null, null_scalar, valid_scalar1, null_scalar); - // CLR = 100 - CheckIfElseOutput(bool_true, null_scalar, null_scalar, null_scalar); - // CLR = 000 - CheckIfElseOutput(bool_null, null_scalar, null_scalar, null_scalar); + IF_ELSE_TEST_GEN(type, 1, 2, 3, 4, 5, 6, 7, 8, 100, 111); } TEST_F(TestIfElseKernel, IfElseBoolean) { auto type = boolean(); - // -------- All arrays --------- - // empty - CheckIfElseOutputAAA(type, "[]", "[]", "[]", "[]"); - // CLR = 111 - CheckIfElseOutputAAA(type, "[true, true, true, false]", "[false, false, false, false]", - "[true, true, true, true]", "[false, false, false, true]"); - // CLR = 011 - CheckIfElseOutputAAA(type, "[true, true, null, false]", "[false, false, false, false]", - "[true, true, true, true]", "[false, false, null, true]"); - // CLR = 101 - CheckIfElseOutputAAA(type, "[true, true, true, false]", "[false, null, false, false]", - "[true, true, true, true]", "[false, null, false, true]"); - // CLR = 001 - CheckIfElseOutputAAA(type, "[true, true, null, false]", "[false, null, false, false]", - "[true, true, true, true]", "[false, null, null, true]"); - // CLR = 110 - CheckIfElseOutputAAA(type, "[true, true, true, false]", "[false, false, false, false]", - "[true, true, true, null]", "[false, false, false, null]"); - // CLR = 010 - CheckIfElseOutputAAA(type, "[null, true, true, false]", "[false, false, false, false]", - "[true, true, true, null]", "[null, false, false, null]"); - // CLR = 100 - CheckIfElseOutputAAA(type, "[true, true, true, false]", "[false, false, null, null]", - "[null, true, true, null]", "[false, false, null, null]"); - // CLR = 000 - CheckIfElseOutputAAA(type, "[null, true, true, false]", "[false, false, null, null]", - "[null, true, true, null]", "[null, false, null, null]"); - - // -------- Cond - Array, Left- Array, Right - Scalar --------- - - ASSERT_OK_AND_ASSIGN(std::shared_ptr valid_scalar, MakeScalar(type, false)); - std::shared_ptr null_scalar = MakeNullScalar(type); - - // empty - CheckIfElseOutputAAS(type, "[]", "[]", valid_scalar, "[]"); - - // CLR = 111 - CheckIfElseOutputAAS(type, "[true, true, true, false]", "[false, false, false, false]", - valid_scalar, "[false, false, false, false]"); - // CLR = 011 - CheckIfElseOutputAAS(type, "[true, true, null, false]", "[false, false, false, false]", - valid_scalar, "[false, false, null, false]"); - // CLR = 101 - CheckIfElseOutputAAS(type, "[true, true, true, false]", "[false, null, false, false]", - valid_scalar, "[false, null, false, false]"); - // CLR = 001 - CheckIfElseOutputAAS(type, "[true, true, null, false]", "[false, null, false, false]", - valid_scalar, "[false, null, null, false]"); - // CLR = 110 - CheckIfElseOutputAAS(type, "[true, true, true, false]", "[false, false, false, false]", - null_scalar, "[false, false, false, null]"); - // CLR = 010 - CheckIfElseOutputAAS(type, "[null, true, true, false]", "[false, false, false, false]", - null_scalar, "[null, false, false, null]"); - // CLR = 100 - CheckIfElseOutputAAS(type, "[true, true, true, false]", "[false, false, null, null]", - null_scalar, "[false, false, null, null]"); - // CLR = 000 - CheckIfElseOutputAAS(type, "[null, true, true, false]", "[false, false, null, null]", - null_scalar, "[null, false, null, null]"); - - // -------- Cond - Array, Left- Scalar, Right - Array --------- - // empty - CheckIfElseOutputASA(type, "[]", valid_scalar, "[]", "[]"); - - // CLR = 111 - CheckIfElseOutputASA(type, "[true, true, true, false]", valid_scalar, - "[false, false, false, false]", "[false, false, false, false]"); - // CLR = 011 - CheckIfElseOutputASA(type, "[true, true, null, false]", valid_scalar, - "[false, false, false, false]", "[false, false, null, false]"); - // CLR = 110 - CheckIfElseOutputASA(type, "[true, true, true, false]", valid_scalar, - "[false, null, false, null]", "[false, false, false, null]"); - // CLR = 010 - CheckIfElseOutputASA(type, "[true, true, null, false]", valid_scalar, - "[false, null, false, null]", "[false, false, null, null]"); - // CLR = 101 - CheckIfElseOutputASA(type, "[true, true, true, false]", null_scalar, - "[false, false, false, false]", "[null, null, null, false]"); - // CLR = 001 - CheckIfElseOutputASA(type, "[null, true, true, false]", null_scalar, - "[false, false, false, false]", "[null, null, null, false]"); - // CLR = 100 - CheckIfElseOutputASA(type, "[true, true, true, false]", null_scalar, - "[false, false, null, false]", "[null, null, null, false]"); - // CLR = 000 - CheckIfElseOutputASA(type, "[true, true, null, false]", null_scalar, - "[false, false, null, false]", "[null, null, null, false]"); - - // -------- Cond - Array, Left- Scalar, Right - Scalar --------- - ASSERT_OK_AND_ASSIGN(std::shared_ptr valid_scalar1, MakeScalar(type, true)); - - // empty - CheckIfElseOutputASS(type, "[]", valid_scalar, valid_scalar1, "[]"); - - // CLR = 111 - CheckIfElseOutputASS(type, "[true, true, true, false]", valid_scalar, valid_scalar1, - "[false, false, false, true]"); - // CLR = 011 - CheckIfElseOutputASS(type, "[true, true, null, false]", valid_scalar, valid_scalar1, - "[false, false, null, true]"); - // CLR = 010 - CheckIfElseOutputASS(type, "[true, true, null, false]", valid_scalar, null_scalar, - "[false, false, null, null]"); - // CLR = 110 - CheckIfElseOutputASS(type, "[true, true, true, false]", valid_scalar, null_scalar, - "[false, false, false, null]"); - // CLR = 101 - CheckIfElseOutputASS(type, "[true, true, true, false]", null_scalar, valid_scalar1, - "[null, null, null, true]"); - // CLR = 001 - CheckIfElseOutputASS(type, "[null, true, true, false]", null_scalar, valid_scalar1, - "[null, null, null, true]"); - // CLR = 100 - CheckIfElseOutputASS(type, "[true, true, true, false]", null_scalar, null_scalar, - "[null, null, null, null]"); - // CLR = 000 - CheckIfElseOutputASS(type, "[true, true, null, false]", null_scalar, null_scalar, - "[null, null, null, null]"); - - // -------- Cond - Scalar, Left- Array, Right - Array --------- - ASSERT_OK_AND_ASSIGN(std::shared_ptr bool_true, MakeScalar(type, true)); - ASSERT_OK_AND_ASSIGN(std::shared_ptr bool_false, MakeScalar(type, false)); - std::shared_ptr bool_null = MakeNullScalar(type); - - // empty - CheckIfElseOutputSAA(type, bool_true, "[]", "[]", "[]"); - // CLR = 111 - CheckIfElseOutputSAA(type, bool_true, "[false, false, false, false]", - "[true, true, true, true]", "[false, false, false, false]"); - // CLR = 011 - CheckIfElseOutputSAA(type, bool_null, "[false, false, false, false]", - "[true, true, true, true]", "[null, null, null, null]"); - // CLR = 101 - CheckIfElseOutputSAA(type, bool_false, "[false, null, false, false]", - "[true, true, true, true]", "[true, true, true, true]"); - // CLR = 001 - CheckIfElseOutputSAA(type, bool_null, "[false, null, false, false]", - "[true, true, true, true]", "[null, null, null, null]"); - // CLR = 110 - CheckIfElseOutputSAA(type, bool_false, "[false, false, false, false]", - "[true, true, true, null]", "[true, true, true, null]"); - // CLR = 010 - CheckIfElseOutputSAA(type, bool_null, "[false, false, false, false]", - "[true, true, true, null]", "[null, null, null, null]"); - // CLR = 100 - CheckIfElseOutputSAA(type, bool_true, "[false, false, null, null]", - "[null, true, true, null]", "[false, false, null, null]"); - // CLR = 000 - CheckIfElseOutputSAA(type, bool_null, "[false, false, null, null]", - "[null, true, true, null]", "[null, null, null, null]"); - - // -------- Cond - Scalar, Left- Array, Right - Scalar --------- - // empty - CheckIfElseOutputSAS(type, bool_true, "[]", valid_scalar, "[]"); - - // CLR = 111 - CheckIfElseOutputSAS(type, bool_true, "[false, false, false, false]", valid_scalar, - "[false, false, false, false]"); - // CLR = 011 - CheckIfElseOutputSAS(type, bool_null, "[false, false, false, false]", valid_scalar, - "[null, null, null, null]"); - // CLR = 101 - CheckIfElseOutputSAS(type, bool_false, "[false, null, false, false]", valid_scalar, - "[false, false, false, false]"); - // CLR = 001 - CheckIfElseOutputSAS(type, bool_null, "[false, null, false, false]", valid_scalar, - "[null, null, null, null]"); - // CLR = 110 - CheckIfElseOutputSAS(type, bool_true, "[false, false, false, false]", null_scalar, - "[false, false, false, false]"); - // CLR = 010 - CheckIfElseOutputSAS(type, bool_null, "[false, false, false, false]", null_scalar, - "[null, null, null, null]"); - // CLR = 100 - CheckIfElseOutputSAS(type, bool_false, "[false, false, null, null]", null_scalar, - "[null, null, null, null]"); - // CLR = 000 - CheckIfElseOutputSAS(type, bool_null, "[false, false, null, null]", null_scalar, - "[null, null, null, null]"); - - // -------- Cond - Scalar, Left- Scalar, Right - Array --------- - // empty - CheckIfElseOutputSSA(type, bool_true, valid_scalar, "[]", "[]"); - - // CLR = 111 - CheckIfElseOutputSSA(type, bool_true, valid_scalar, "[false, false, false, false]", - "[false, false, false, false]"); - // CLR = 011 - CheckIfElseOutputSSA(type, bool_null, valid_scalar, "[false, false, false, false]", - "[null, null, null, null]"); - // CLR = 110 - CheckIfElseOutputSSA(type, bool_false, valid_scalar, "[false, null, false, null]", - "[false, null, false, null]"); - // CLR = 010 - CheckIfElseOutputSSA(type, bool_null, valid_scalar, "[false, null, false, null]", - "[null, null, null, null]"); - // CLR = 101 - CheckIfElseOutputSSA(type, bool_true, null_scalar, "[false, false, false, false]", - "[null, null, null, null]"); - // CLR = 001 - CheckIfElseOutputSSA(type, bool_null, null_scalar, "[false, false, false, false]", - "[null, null, null, null]"); - // CLR = 100 - CheckIfElseOutputSSA(type, bool_false, null_scalar, "[false, false, null, false]", - "[false, false, null, false]"); - // CLR = 000 - CheckIfElseOutputSSA(type, bool_null, null_scalar, "[false, false, null, false]", - "[null, null, null, null]"); - - // -------- Cond - Scalar, Left- Scalar, Right - Scalar --------- - - // CLR = 111 - CheckIfElseOutput(bool_false, valid_scalar, valid_scalar1, valid_scalar1); - // CLR = 011 - CheckIfElseOutput(bool_null, valid_scalar, valid_scalar1, null_scalar); - // CLR = 110 - CheckIfElseOutput(bool_true, valid_scalar, null_scalar, valid_scalar); - // CLR = 010 - CheckIfElseOutput(bool_null, valid_scalar, null_scalar, null_scalar); - // CLR = 101 - CheckIfElseOutput(bool_false, null_scalar, valid_scalar1, valid_scalar1); - // CLR = 001 - CheckIfElseOutput(bool_null, null_scalar, valid_scalar1, null_scalar); - // CLR = 100 - CheckIfElseOutput(bool_true, null_scalar, null_scalar, null_scalar); - // CLR = 000 - CheckIfElseOutput(bool_null, null_scalar, null_scalar, null_scalar); + IF_ELSE_TEST_GEN(type, false, false, false, false, true, true, true, true, false, true); } TYPED_TEST(TestIfElsePrimitive, IfElseBooleanRand) { From 68f065bd2f69992ab9771df8e1fe681fa18ecad9 Mon Sep 17 00:00:00 2001 From: niranda perera Date: Thu, 27 May 2021 16:21:55 -0400 Subject: [PATCH 22/39] cosmetic changes --- cpp/src/arrow/compute/api_scalar.h | 6 +++--- cpp/src/arrow/compute/kernels/scalar_if_else.cc | 8 ++++++-- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/cpp/src/arrow/compute/api_scalar.h b/cpp/src/arrow/compute/api_scalar.h index 60a8028df8575..d5d1f82fb35a7 100644 --- a/cpp/src/arrow/compute/api_scalar.h +++ b/cpp/src/arrow/compute/api_scalar.h @@ -468,9 +468,9 @@ Result FillNull(const Datum& values, const Datum& fill_value, /// \brief IfElse returns elements chosen from `left` or `right` /// depending on `cond`. `null` values would be promoted to the result /// -/// \param[in] cond `BooleanArray` condition array -/// \param[in] left scalar/ Array -/// \param[in] right scalar/ Array +/// \param[in] cond `Boolean` condition Scalar/ Array +/// \param[in] left Scalar/ Array +/// \param[in] right Scalar/ Array /// \param[in] ctx the function execution context, optional /// /// \return the resulting datum diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else.cc b/cpp/src/arrow/compute/kernels/scalar_if_else.cc index bcd9aea7bfa1d..a8adf10056062 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else.cc @@ -711,8 +711,12 @@ void AddPrimitiveIfElseKernels(const std::shared_ptr& scalar_fun } // namespace -// todo fill this -const FunctionDoc if_else_doc{"", ("`"), {"cond", "left", "right"}}; +const FunctionDoc if_else_doc{"Choose values based on a condition", + ("`cond` must be a Boolean scalar/ array. \n`left` or " + "`right` must be of the same type scalar/ array.\n" + "`null` values in `cond` will be promoted to the" + " output."), + {"cond", "left", "right"}}; namespace internal { From 0c25c2a2b622c88eda85a2a7016dfc6566877068 Mon Sep 17 00:00:00 2001 From: niranda perera Date: Thu, 27 May 2021 17:15:38 -0400 Subject: [PATCH 23/39] Update cpp/src/arrow/compute/kernels/scalar_if_else.cc Co-authored-by: Benjamin Kietzman --- cpp/src/arrow/compute/kernels/scalar_if_else.cc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else.cc b/cpp/src/arrow/compute/kernels/scalar_if_else.cc index a8adf10056062..aa40f852dc52c 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else.cc @@ -38,7 +38,9 @@ enum { COND_ALL_VALID = 1, LEFT_ALL_VALID = 2, RIGHT_ALL_VALID = 4 }; // ie. cond.valid & (cond.data & left.valid | ~cond.data & right.valid) Status PromoteNullsVisitor(KernelContext* ctx, const ArrayData& cond, const Scalar& left, const Scalar& right, ArrayData* output) { - uint8_t flag = right.is_valid * 4 + left.is_valid * 2 + !cond.MayHaveNulls(); + uint8_t flag = right.is_valid * RIGHT_ALL_VALID | + left.is_valid * LEFT_ALL_VALID | + !cond.MayHaveNulls() * COND_ALL_VALID; if (flag < 6 && flag != 3) { // there will be a validity buffer in the output From f3ddd66be4ec07a2ce1195d3a80b524dd5eb5635 Mon Sep 17 00:00:00 2001 From: niranda perera Date: Fri, 28 May 2021 10:01:53 -0400 Subject: [PATCH 24/39] moving test cases to a method --- .../compute/kernels/scalar_if_else_test.cc | 574 +++++++++--------- 1 file changed, 302 insertions(+), 272 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc index a7c985bf4169c..a74342c843a1e 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc @@ -147,276 +147,304 @@ TYPED_TEST(TestIfElsePrimitive, IfElseFixedSizeRand) { CheckIfElseOutput(cond, left, right, expected_data); } -#define IF_ELSE_TEST_GEN(type, l0, l1, l2, l3, r0, r1, r2, r3, valid, valid1) \ - do { \ - /* -------- All arrays --------- */ \ - /* empty */ \ - CheckIfElseOutputAAA((type), "[]", "[]", "[]", "[]"); \ - /* CLR = 111 */ \ - CheckIfElseOutputAAA( \ - (type), "[true, true, true, false]", "[" #l0 ", " #l1 ", " #l2 ", " #l3 "]", \ - "[" #r0 ", " #r1 ", " #r2 ", " #r3 "]", "[" #l0 ", " #l1 ", " #l2 ", " #r3 "]"); \ - /* CLR = 011 */ \ - CheckIfElseOutputAAA( \ - (type), "[true, true, null, false]", "[" #l0 ", " #l1 ", " #l2 ", " #l3 "]", \ - "[" #r0 ", " #r1 ", " #r2 ", " #r3 "]", "[" #l0 ", " #l1 ", null, " #r3 "]"); \ - /* CLR = 101 */ \ - CheckIfElseOutputAAA( \ - (type), "[true, true, true, false]", "[" #l0 ", null, " #l2 ", " #l3 "]", \ - "[" #r0 ", " #r1 ", " #r2 ", " #r3 "]", "[" #l0 ", null, " #l2 ", " #r3 "]"); \ - /* CLR = 001 */ \ - CheckIfElseOutputAAA( \ - (type), "[true, true, null, false]", "[" #l0 ", null, " #l2 ", " #l3 "]", \ - "[" #r0 ", " #r1 ", " #r2 ", " #r3 "]", "[" #l0 ", null, null, " #r3 "]"); \ - /* CLR = 110 */ \ - CheckIfElseOutputAAA( \ - (type), "[true, true, true, false]", "[" #l0 ", " #l1 ", " #l2 ", " #l3 "]", \ - "[" #r0 ", " #r1 ", " #r2 ", null]", "[" #l0 ", " #l1 ", " #l2 ", null]"); \ - /* CLR = 010 */ \ - CheckIfElseOutputAAA( \ - (type), "[null, true, true, false]", "[" #l0 ", " #l1 ", " #l2 ", " #l3 "]", \ - "[" #r0 ", " #r1 ", " #r2 ", null]", "[null, " #l1 ", " #l2 ", null]"); \ - /* CLR = 100 */ \ - CheckIfElseOutputAAA( \ - (type), "[true, true, true, false]", "[" #l0 ", " #l1 ", null, null]", \ - "[null, " #r1 ", " #r2 ", null]", "[" #l0 ", " #l1 ", null, null]"); \ - /* CLR = 000 */ \ - CheckIfElseOutputAAA( \ - (type), "[null, true, true, false]", "[" #l0 ", " #l1 ", null, null]", \ - "[null, " #r1 ", " #r2 ", null]", "[null, " #l1 ", null, null]"); \ - \ - /* -------- Cond - Array, Left- Array, Right - Scalar --------- */ \ - auto valid_scalar = MakeScalar((type), (valid)).ValueOrDie(); \ - auto valid_scalar1 = MakeScalar((type), (valid1)).ValueOrDie(); \ - std::shared_ptr null_scalar = MakeNullScalar(type); \ - \ - /* empty */ \ - CheckIfElseOutputAAS((type), "[]", "[]", valid_scalar, "[]"); \ - \ - /* CLR = 111 */ \ - CheckIfElseOutputAAS((type), "[true, true, true, false]", \ - "[" #l0 ", " #l1 ", " #l2 ", " #l3 "]", valid_scalar, \ - "[" #l0 ", " #l1 ", " #l2 ", " #valid "]"); \ - /* CLR = 011 */ \ - CheckIfElseOutputAAS((type), "[true, true, null, false]", \ - "[" #l0 ", " #l1 ", " #l2 ", " #l3 "]", valid_scalar, \ - "[" #l0 ", " #l1 ", null, " #valid "]"); \ - /* CLR = 101 */ \ - CheckIfElseOutputAAS((type), "[true, true, true, false]", \ - "[" #l0 ", null, " #l2 ", " #l3 "]", valid_scalar, \ - "[" #l0 ", null, " #l2 ", " #valid "]"); \ - /* CLR = 001 */ \ - CheckIfElseOutputAAS((type), "[true, true, null, false]", \ - "[" #l0 ", null, " #l2 ", " #l3 "]", valid_scalar, \ - "[" #l0 ", null, null, " #valid "]"); \ - /* CLR = 110 */ \ - CheckIfElseOutputAAS((type), "[true, true, true, false]", \ - "[" #l0 ", " #l1 ", " #l2 ", " #l3 "]", null_scalar, \ - "[" #l0 ", " #l1 ", " #l2 ", null]"); \ - /* CLR = 010 */ \ - CheckIfElseOutputAAS((type), "[null, true, true, false]", \ - "[" #l0 ", " #l1 ", " #l2 ", " #l3 "]", null_scalar, \ - "[null, " #l1 ", " #l2 ", null]"); \ - /* CLR = 100 */ \ - CheckIfElseOutputAAS((type), "[true, true, true, false]", \ - "[" #l0 ", " #l1 ", null, null]", null_scalar, \ - "[" #l0 ", " #l1 ", null, null]"); \ - /* CLR = 000 */ \ - CheckIfElseOutputAAS((type), "[null, true, true, false]", \ - "[" #l0 ", " #l1 ", null, null]", null_scalar, \ - "[null, " #l1 ", null, null]"); \ - \ - /* -------- Cond - Array, Left- Scalar, Right - Array --------- */ \ - /* empty */ \ - CheckIfElseOutputASA((type), "[]", valid_scalar, "[]", "[]"); \ - \ - /* CLR = 111 */ \ - CheckIfElseOutputASA((type), "[true, true, true, false]", valid_scalar, \ - "[" #l0 ", " #l1 ", " #l2 ", " #l3 "]", \ - "[" #valid ", " #valid ", " #valid ", " #l3 "]"); \ - /* CLR = 011 */ \ - CheckIfElseOutputASA((type), "[true, true, null, false]", valid_scalar, \ - "[" #l0 ", " #l1 ", " #l2 ", " #l3 "]", \ - "[" #valid ", " #valid ", null, " #l3 "]"); \ - /* CLR = 110 */ \ - CheckIfElseOutputASA((type), "[true, true, true, false]", valid_scalar, \ - "[" #l0 ", null, " #l2 ", null]", \ - "[" #valid ", " #valid ", " #valid ", null]"); \ - /* CLR = 010 */ \ - CheckIfElseOutputASA((type), "[true, true, null, false]", valid_scalar, \ - "[" #l0 ", null, " #l2 ", null]", \ - "[" #valid ", " #valid ", null, null]"); \ - /* CLR = 101 */ \ - CheckIfElseOutputASA((type), "[true, true, true, false]", null_scalar, \ - "[" #l0 ", " #l1 ", " #l2 ", " #l3 "]", \ - "[null, null, null, " #l3 "]"); \ - /* CLR = 001 */ \ - CheckIfElseOutputASA((type), "[null, true, true, false]", null_scalar, \ - "[" #l0 ", " #l1 ", " #l2 ", " #l3 "]", \ - "[null, null, null, " #l3 "]"); \ - /* CLR = 100 */ \ - CheckIfElseOutputASA((type), "[true, true, true, false]", null_scalar, \ - "[" #l0 ", " #l1 ", null, " #l3 "]", \ - "[null, null, null, " #l3 "]"); \ - /* CLR = 000 */ \ - CheckIfElseOutputASA((type), "[true, true, null, false]", null_scalar, \ - "[" #l0 ", " #l1 ", null, " #l3 "]", \ - "[null, null, null, " #l3 "]"); \ - \ - /* -------- Cond - Array, Left- Scalar, Right - Scalar --------- */ \ - /* empty */ \ - CheckIfElseOutputASS((type), "[]", valid_scalar, valid_scalar1, "[]"); \ - \ - /* CLR = 111 */ \ - CheckIfElseOutputASS((type), "[true, true, true, false]", valid_scalar, \ - valid_scalar1, \ - "[" #valid ", " #valid ", " #valid ", " #valid1 "]"); \ - /* CLR = 011 */ \ - CheckIfElseOutputASS((type), "[true, true, null, false]", valid_scalar, \ - valid_scalar1, "[" #valid ", " #valid ", null, " #valid1 "]"); \ - /* CLR = 010 */ \ - CheckIfElseOutputASS((type), "[true, true, null, false]", valid_scalar, null_scalar, \ - "[" #valid ", " #valid ", null, null]"); \ - /* CLR = 110 */ \ - CheckIfElseOutputASS((type), "[true, true, true, false]", valid_scalar, null_scalar, \ - "[" #valid ", " #valid ", " #valid ", null]"); \ - /* CLR = 101 */ \ - CheckIfElseOutputASS((type), "[true, true, true, false]", null_scalar, \ - valid_scalar1, "[null, null, null, " #valid1 "]"); \ - /* CLR = 001 */ \ - CheckIfElseOutputASS((type), "[null, true, true, false]", null_scalar, \ - valid_scalar1, "[null, null, null, " #valid1 "]"); \ - /* CLR = 100 */ \ - CheckIfElseOutputASS((type), "[true, true, true, false]", null_scalar, null_scalar, \ - "[null, null, null, null]"); \ - /* CLR = 000 */ \ - CheckIfElseOutputASS((type), "[true, true, null, false]", null_scalar, null_scalar, \ - "[null, null, null, null]"); \ - \ - /* -------- Cond - Scalar, Left- Array, Right - Array --------- */ \ - auto bool_true = MakeScalar(boolean(), true).ValueOrDie(); \ - auto bool_false = MakeScalar(boolean(), false).ValueOrDie(); \ - std::shared_ptr bool_null = MakeNullScalar(boolean()); \ - \ - /* empty */ \ - CheckIfElseOutputSAA((type), bool_true, "[]", "[]", "[]"); \ - /* CLR = 111 */ \ - CheckIfElseOutputSAA((type), bool_true, "[" #l0 ", " #l1 ", " #l2 ", " #l3 "]", \ - "[" #r0 ", " #r1 ", " #r2 ", " #r3 "]", \ - "[" #l0 ", " #l1 ", " #l2 ", " #l3 "]"); \ - /* CLR = 011 */ \ - CheckIfElseOutputSAA((type), bool_null, "[" #l0 ", " #l1 ", " #l2 ", " #l3 "]", \ - "[" #r0 ", " #r1 ", " #r2 ", " #r3 "]", \ - "[null, null, null, null]"); \ - /* CLR = 101 */ \ - CheckIfElseOutputSAA((type), bool_false, "[" #l0 ", null, " #l2 ", " #l3 "]", \ - "[" #r0 ", " #r1 ", " #r2 ", " #r3 "]", \ - "[" #r0 ", " #r1 ", " #r2 ", " #r3 "]"); \ - /* CLR = 001 */ \ - CheckIfElseOutputSAA((type), bool_null, "[" #l0 ", null, " #l2 ", " #l3 "]", \ - "[" #r0 ", " #r1 ", " #r2 ", " #r3 "]", \ - "[null, null, null, null]"); \ - /* CLR = 110 */ \ - CheckIfElseOutputSAA((type), bool_false, "[" #l0 ", " #l1 ", " #l2 ", " #l3 "]", \ - "[" #r0 ", " #r1 ", " #r2 ", null]", \ - "[" #r0 ", " #r1 ", " #r2 ", null]"); \ - /* CLR = 010 */ \ - CheckIfElseOutputSAA((type), bool_null, "[" #l0 ", " #l1 ", " #l2 ", " #l3 "]", \ - "[" #r0 ", " #r1 ", " #r2 ", null]", \ - "[null, null, null, null]"); \ - /* CLR = 100 */ \ - CheckIfElseOutputSAA((type), bool_true, "[" #l0 ", " #l1 ", null, null]", \ - "[null, " #r1 ", " #r2 ", null]", \ - "[" #l0 ", " #l1 ", null, null]"); \ - /* CLR = 000 */ \ - CheckIfElseOutputSAA((type), bool_null, "[" #l0 ", " #l1 ", null, null]", \ - "[null, " #r1 ", " #r2 ", null]", "[null, null, null, null]"); \ - \ - /* -------- Cond - Scalar, Left- Array, Right - Scalar --------- */ \ - /* empty */ \ - CheckIfElseOutputSAS((type), bool_true, "[]", valid_scalar, "[]"); \ - \ - /* CLR = 111 */ \ - CheckIfElseOutputSAS((type), bool_true, "[" #l0 ", " #l1 ", " #l2 ", " #l3 "]", \ - valid_scalar, "[" #l0 ", " #l1 ", " #l2 ", " #l3 "]"); \ - /* CLR = 011 */ \ - CheckIfElseOutputSAS((type), bool_null, "[" #l0 ", " #l1 ", " #l2 ", " #l3 "]", \ - valid_scalar, "[null, null, null, null]"); \ - /* CLR = 101 */ \ - CheckIfElseOutputSAS((type), bool_false, "[" #l0 ", null, " #l2 ", " #l3 "]", \ - valid_scalar, \ - "[" #valid ", " #valid ", " #valid ", " #valid "]"); \ - /* CLR = 001 */ \ - CheckIfElseOutputSAS((type), bool_null, "[" #l0 ", null, " #l2 ", " #l3 "]", \ - valid_scalar, "[null, null, null, null]"); \ - /* CLR = 110 */ \ - CheckIfElseOutputSAS((type), bool_true, "[" #l0 ", " #l1 ", " #l2 ", " #l3 "]", \ - null_scalar, "[" #l0 ", " #l1 ", " #l2 ", " #l3 "]"); \ - /* CLR = 010 */ \ - CheckIfElseOutputSAS((type), bool_null, "[" #l0 ", " #l1 ", " #l2 ", " #l3 "]", \ - null_scalar, "[null, null, null, null]"); \ - /* CLR = 100 */ \ - CheckIfElseOutputSAS((type), bool_false, "[" #l0 ", " #l1 ", null, null]", \ - null_scalar, "[null, null, null, null]"); \ - /* CLR = 000 */ \ - CheckIfElseOutputSAS((type), bool_null, "[" #l0 ", " #l1 ", null, null]", \ - null_scalar, "[null, null, null, null]"); \ - \ - /* -------- Cond - Scalar, Left- Scalar, Right - Array --------- */ \ - /* empty */ \ - CheckIfElseOutputSSA((type), bool_true, valid_scalar, "[]", "[]"); \ - \ - /* CLR = 111 */ \ - CheckIfElseOutputSSA((type), bool_true, valid_scalar, \ - "[" #l0 ", " #l1 ", " #l2 ", " #l3 "]", \ - "[" #valid ", " #valid ", " #valid ", " #valid "]"); \ - /* CLR = 011 */ \ - CheckIfElseOutputSSA((type), bool_null, valid_scalar, \ - "[" #l0 ", " #l1 ", " #l2 ", " #l3 "]", \ - "[null, null, null, null]"); \ - /* CLR = 110 */ \ - CheckIfElseOutputSSA((type), bool_false, valid_scalar, \ - "[" #l0 ", null, " #l2 ", null]", \ - "[" #l0 ", null, " #l2 ", null]"); \ - /* CLR = 010 */ \ - CheckIfElseOutputSSA((type), bool_null, valid_scalar, \ - "[" #l0 ", null, " #l2 ", null]", "[null, null, null, null]"); \ - /* CLR = 101 */ \ - CheckIfElseOutputSSA((type), bool_true, null_scalar, \ - "[" #l0 ", " #l1 ", " #l2 ", " #l3 "]", \ - "[null, null, null, null]"); \ - /* CLR = 001 */ \ - CheckIfElseOutputSSA((type), bool_null, null_scalar, \ - "[" #l0 ", " #l1 ", " #l2 ", " #l3 "]", \ - "[null, null, null, null]"); \ - /* CLR = 100 */ \ - CheckIfElseOutputSSA((type), bool_false, null_scalar, \ - "[" #l0 ", " #l1 ", null, " #l3 "]", \ - "[" #l0 ", " #l1 ", null, " #l3 "]"); \ - /* CLR = 000 */ \ - CheckIfElseOutputSSA((type), bool_null, null_scalar, \ - "[" #l0 ", " #l1 ", null, " #l3 "]", \ - "[null, null, null, null]"); \ - \ - /* -------- Cond - Scalar, Left- Scalar, Right - Scalar --------- */ \ - \ - /* CLR = 111 */ \ - CheckIfElseOutput(bool_false, valid_scalar, valid_scalar1, valid_scalar1); \ - /* CLR = 011 */ \ - CheckIfElseOutput(bool_null, valid_scalar, valid_scalar1, null_scalar); \ - /* CLR = 110 */ \ - CheckIfElseOutput(bool_true, valid_scalar, null_scalar, valid_scalar); \ - /* CLR = 010 */ \ - CheckIfElseOutput(bool_null, valid_scalar, null_scalar, null_scalar); \ - /* CLR = 101 */ \ - CheckIfElseOutput(bool_false, null_scalar, valid_scalar1, valid_scalar1); \ - /* CLR = 001 */ \ - CheckIfElseOutput(bool_null, null_scalar, valid_scalar1, null_scalar); \ - /* CLR = 100 */ \ - CheckIfElseOutput(bool_true, null_scalar, null_scalar, null_scalar); \ - /* CLR = 000 */ \ - CheckIfElseOutput(bool_null, null_scalar, null_scalar, null_scalar); \ - } while (0) +template +void DoIfElseTest(const std::shared_ptr& type, const std::array& left, + const std::array& right, const std::array& valid_scalars) { + std::array l, r; + std::array v; + + auto to_string = [](const T& i) { return std::to_string(i); }; + std::transform(left.begin(), left.end(), l.begin(), to_string); + std::transform(right.begin(), right.end(), r.begin(), to_string); + std::transform(valid_scalars.begin(), valid_scalars.end(), v.begin(), to_string); + + /* -------- All arrays --------- */ + /* empty */ + CheckIfElseOutputAAA(type, "[]", "[]", "[]", "[]"); + /* CLR = 111 */ + CheckIfElseOutputAAA(type, "[true, true, true, false]", + "[" + l[0] + ", " + l[1] + ", " + l[2] + ", " + l[3] + "]", + "[" + r[0] + ", " + r[1] + ", " + r[2] + ", " + r[3] + "]", + "[" + l[0] + ", " + l[1] + ", " + l[2] + ", " + r[3] + "]"); + /* CLR = 011 */ + CheckIfElseOutputAAA(type, "[true, true, null, false]", + "[" + l[0] + ", " + l[1] + ", " + l[2] + ", " + l[3] + "]", + "[" + r[0] + ", " + r[1] + ", " + r[2] + ", " + r[3] + "]", + "[" + l[0] + ", " + l[1] + ", null, " + r[3] + "]"); + /* CLR = 101 */ + CheckIfElseOutputAAA(type, "[true, true, true, false]", + "[" + l[0] + ", null, " + l[2] + ", " + l[3] + "]", + "[" + r[0] + ", " + r[1] + ", " + r[2] + ", " + r[3] + "]", + "[" + l[0] + ", null, " + l[2] + ", " + r[3] + "]"); + /* CLR = 001 */ + CheckIfElseOutputAAA(type, "[true, true, null, false]", + "[" + l[0] + ", null, " + l[2] + ", " + l[3] + "]", + "[" + r[0] + ", " + r[1] + ", " + r[2] + ", " + r[3] + "]", + "[" + l[0] + ", null, null, " + r[3] + "]"); + /* CLR = 110 */ + CheckIfElseOutputAAA(type, "[true, true, true, false]", + "[" + l[0] + ", " + l[1] + ", " + l[2] + ", " + l[3] + "]", + "[" + r[0] + ", " + r[1] + ", " + r[2] + ", null]", + "[" + l[0] + ", " + l[1] + ", " + l[2] + ", null]"); + /* CLR = 010 */ + CheckIfElseOutputAAA(type, "[null, true, true, false]", + "[" + l[0] + ", " + l[1] + ", " + l[2] + ", " + l[3] + "]", + "[" + r[0] + ", " + r[1] + ", " + r[2] + ", null]", + "[null, " + l[1] + ", " + l[2] + ", null]"); + /* CLR = 100 */ + CheckIfElseOutputAAA(type, "[true, true, true, false]", + "[" + l[0] + ", " + l[1] + ", null, null]", + "[null, " + r[1] + ", " + r[2] + ", null]", + "[" + l[0] + ", " + l[1] + ", null, null]"); + /* CLR = 000 */ + CheckIfElseOutputAAA( + type, "[null, true, true, false]", "[" + l[0] + ", " + l[1] + ", null, null]", + "[null, " + r[1] + ", " + r[2] + ", null]", "[null, " + l[1] + ", null, null]"); + + /* -------- Cond - Array, Left- Array, Right - Scalar --------- */ + ASSERT_OK_AND_ASSIGN(auto valid_scalar, MakeScalar(type, valid_scalars[0])); + ASSERT_OK_AND_ASSIGN(auto valid_scalar1, MakeScalar(type, valid_scalars[1])); + auto null_scalar = MakeNullScalar(type); + + /* empty */ + // CheckIfElseOutputAAS(type, "[]", "[]", valid_scalar, "[]"); + + /* CLR = 111 */ + CheckIfElseOutputAAS(type, "[true, true, true, false]", + "[" + l[0] + ", " + l[1] + ", " + l[2] + ", " + l[3] + "]", + valid_scalar, + "[" + l[0] + ", " + l[1] + ", " + l[2] + ", " + v[0] + "]"); + /* CLR = 011 */ + CheckIfElseOutputAAS(type, "[true, true, null, false]", + "[" + l[0] + ", " + l[1] + ", " + l[2] + ", " + l[3] + "]", + valid_scalar, "[" + l[0] + ", " + l[1] + ", null, " + v[0] + "]"); + /* CLR = 101 */ + CheckIfElseOutputAAS(type, "[true, true, true, false]", + "[" + l[0] + ", null, " + l[2] + ", " + l[3] + "]", valid_scalar, + "[" + l[0] + ", null, " + l[2] + ", " + v[0] + "]"); + /* CLR = 001 */ + CheckIfElseOutputAAS(type, "[true, true, null, false]", + "[" + l[0] + ", null, " + l[2] + ", " + l[3] + "]", valid_scalar, + "[" + l[0] + ", null, null, " + v[0] + "]"); + /* CLR = 110 */ + CheckIfElseOutputAAS(type, "[true, true, true, false]", + "[" + l[0] + ", " + l[1] + ", " + l[2] + ", " + l[3] + "]", + null_scalar, "[" + l[0] + ", " + l[1] + ", " + l[2] + ", null]"); + /* CLR = 010 */ + CheckIfElseOutputAAS(type, "[null, true, true, false]", + "[" + l[0] + ", " + l[1] + ", " + l[2] + ", " + l[3] + "]", + null_scalar, "[null, " + l[1] + ", " + l[2] + ", null]"); + /* CLR = 100 */ + CheckIfElseOutputAAS(type, "[true, true, true, false]", + "[" + l[0] + ", " + l[1] + ", null, null]", null_scalar, + "[" + l[0] + ", " + l[1] + ", null, null]"); + /* CLR = 000 */ + CheckIfElseOutputAAS(type, "[null, true, true, false]", + "[" + l[0] + ", " + l[1] + ", null, null]", null_scalar, + "[null, " + l[1] + ", null, null]"); + + /* -------- Cond - Array, Left- Scalar, Right - Array --------- */ + /* empty */ + CheckIfElseOutputASA(type, "[]", valid_scalar, "[]", "[]"); + + /* CLR = 111 */ + CheckIfElseOutputASA(type, "[true, true, true, false]", valid_scalar, + "[" + l[0] + ", " + l[1] + ", " + l[2] + ", " + l[3] + "]", + "[" + v[0] + ", " + v[0] + ", " + v[0] + ", " + l[3] + "]"); + /* CLR = 011 */ + CheckIfElseOutputASA(type, "[true, true, null, false]", valid_scalar, + "[" + l[0] + ", " + l[1] + ", " + l[2] + ", " + l[3] + "]", + "[" + v[0] + ", " + v[0] + ", null, " + l[3] + "]"); + /* CLR = 110 */ + CheckIfElseOutputASA(type, "[true, true, true, false]", valid_scalar, + "[" + l[0] + ", null, " + l[2] + ", null]", + "[" + v[0] + ", " + v[0] + ", " + v[0] + ", null]"); + /* CLR = 010 */ + CheckIfElseOutputASA(type, "[true, true, null, false]", valid_scalar, + "[" + l[0] + ", null, " + l[2] + ", null]", + "[" + v[0] + ", " + v[0] + ", null, null]"); + /* CLR = 101 */ + CheckIfElseOutputASA(type, "[true, true, true, false]", null_scalar, + "[" + l[0] + ", " + l[1] + ", " + l[2] + ", " + l[3] + "]", + "[null, null, null, " + l[3] + "]"); + /* CLR = 001 */ + CheckIfElseOutputASA(type, "[null, true, true, false]", null_scalar, + "[" + l[0] + ", " + l[1] + ", " + l[2] + ", " + l[3] + "]", + "[null, null, null, " + l[3] + "]"); + /* CLR = 100 */ + CheckIfElseOutputASA(type, "[true, true, true, false]", null_scalar, + "[" + l[0] + ", " + l[1] + ", null, " + l[3] + "]", + "[null, null, null, " + l[3] + "]"); + /* CLR = 000 */ + CheckIfElseOutputASA(type, "[true, true, null, false]", null_scalar, + "[" + l[0] + ", " + l[1] + ", null, " + l[3] + "]", + "[null, null, null, " + l[3] + "]"); + + /* -------- Cond - Array, Left- Scalar, Right - Scalar --------- */ + /* empty */ + CheckIfElseOutputASS(type, "[]", valid_scalar, valid_scalar1, "[]"); + + /* CLR = 111 */ + CheckIfElseOutputASS(type, "[true, true, true, false]", valid_scalar, valid_scalar1, + "[" + v[0] + ", " + v[0] + ", " + v[0] + ", " + v[1] + "]"); + /* CLR = 011 */ + CheckIfElseOutputASS(type, "[true, true, null, false]", valid_scalar, valid_scalar1, + "[" + v[0] + ", " + v[0] + ", null, " + v[1] + "]"); + /* CLR = 010 */ + CheckIfElseOutputASS(type, "[true, true, null, false]", valid_scalar, null_scalar, + "[" + v[0] + ", " + v[0] + ", null, null]"); + /* CLR = 110 */ + CheckIfElseOutputASS(type, "[true, true, true, false]", valid_scalar, null_scalar, + "[" + v[0] + ", " + v[0] + ", " + v[0] + ", null]"); + /* CLR = 101 */ + CheckIfElseOutputASS(type, "[true, true, true, false]", null_scalar, valid_scalar1, + "[null, null, null, " + v[1] + "]"); + /* CLR = 001 */ + CheckIfElseOutputASS(type, "[null, true, true, false]", null_scalar, valid_scalar1, + "[null, null, null, " + v[1] + "]"); + /* CLR = 100 */ + CheckIfElseOutputASS(type, "[true, true, true, false]", null_scalar, null_scalar, + "[null, null, null, null]"); + /* CLR = 000 */ + CheckIfElseOutputASS(type, "[true, true, null, false]", null_scalar, null_scalar, + "[null, null, null, null]"); + + /* -------- Cond - Scalar, Left- Array, Right - Array --------- */ + ASSERT_OK_AND_ASSIGN(auto bool_true, MakeScalar(boolean(), true)); + ASSERT_OK_AND_ASSIGN(auto bool_false, MakeScalar(boolean(), false)); + auto bool_null = MakeNullScalar(boolean()); + + /* empty */ + CheckIfElseOutputSAA(type, bool_true, "[]", "[]", "[]"); + /* CLR = 111 */ + CheckIfElseOutputSAA(type, bool_true, + "[" + l[0] + ", " + l[1] + ", " + l[2] + ", " + l[3] + "]", + "[" + r[0] + ", " + r[1] + ", " + r[2] + ", " + r[3] + "]", + "[" + l[0] + ", " + l[1] + ", " + l[2] + ", " + l[3] + "]"); + /* CLR = 011 */ + CheckIfElseOutputSAA(type, bool_null, + "[" + l[0] + ", " + l[1] + ", " + l[2] + ", " + l[3] + "]", + "[" + r[0] + ", " + r[1] + ", " + r[2] + ", " + r[3] + "]", + "[null, null, null, null]"); + /* CLR = 101 */ + CheckIfElseOutputSAA(type, bool_false, + "[" + l[0] + ", null, " + l[2] + ", " + l[3] + "]", + "[" + r[0] + ", " + r[1] + ", " + r[2] + ", " + r[3] + "]", + "[" + r[0] + ", " + r[1] + ", " + r[2] + ", " + r[3] + "]"); + /* CLR = 001 */ + CheckIfElseOutputSAA(type, bool_null, + "[" + l[0] + ", null, " + l[2] + ", " + l[3] + "]", + "[" + r[0] + ", " + r[1] + ", " + r[2] + ", " + r[3] + "]", + "[null, null, null, null]"); + /* CLR = 110 */ + CheckIfElseOutputSAA(type, bool_false, + "[" + l[0] + ", " + l[1] + ", " + l[2] + ", " + l[3] + "]", + "[" + r[0] + ", " + r[1] + ", " + r[2] + ", null]", + "[" + r[0] + ", " + r[1] + ", " + r[2] + ", null]"); + /* CLR = 010 */ + CheckIfElseOutputSAA( + type, bool_null, "[" + l[0] + ", " + l[1] + ", " + l[2] + ", " + l[3] + "]", + "[" + r[0] + ", " + r[1] + ", " + r[2] + ", null]", "[null, null, null, null]"); + /* CLR = 100 */ + CheckIfElseOutputSAA(type, bool_true, "[" + l[0] + ", " + l[1] + ", null, null]", + "[null, " + r[1] + ", " + r[2] + ", null]", + "[" + l[0] + ", " + l[1] + ", null, null]"); + /* CLR = 000 */ + CheckIfElseOutputSAA(type, bool_null, "[" + l[0] + ", " + l[1] + ", null, null]", + "[null, " + r[1] + ", " + r[2] + ", null]", + "[null, null, null, null]"); + + /* -------- Cond - Scalar, Left- Array, Right - Scalar --------- */ + /* empty */ + CheckIfElseOutputSAS(type, bool_true, "[]", valid_scalar, "[]"); + + /* CLR = 111 */ + CheckIfElseOutputSAS( + type, bool_true, "[" + l[0] + ", " + l[1] + ", " + l[2] + ", " + l[3] + "]", + valid_scalar, "[" + l[0] + ", " + l[1] + ", " + l[2] + ", " + l[3] + "]"); + /* CLR = 011 */ + CheckIfElseOutputSAS(type, bool_null, + "[" + l[0] + ", " + l[1] + ", " + l[2] + ", " + l[3] + "]", + valid_scalar, "[null, null, null, null]"); + /* CLR = 101 */ + CheckIfElseOutputSAS(type, bool_false, + "[" + l[0] + ", null, " + l[2] + ", " + l[3] + "]", valid_scalar, + "[" + v[0] + ", " + v[0] + ", " + v[0] + ", " + v[0] + "]"); + /* CLR = 001 */ + CheckIfElseOutputSAS(type, bool_null, + "[" + l[0] + ", null, " + l[2] + ", " + l[3] + "]", valid_scalar, + "[null, null, null, null]"); + /* CLR = 110 */ + CheckIfElseOutputSAS( + type, bool_true, "[" + l[0] + ", " + l[1] + ", " + l[2] + ", " + l[3] + "]", + null_scalar, "[" + l[0] + ", " + l[1] + ", " + l[2] + ", " + l[3] + "]"); + /* CLR = 010 */ + CheckIfElseOutputSAS(type, bool_null, + "[" + l[0] + ", " + l[1] + ", " + l[2] + ", " + l[3] + "]", + null_scalar, "[null, null, null, null]"); + /* CLR = 100 */ + CheckIfElseOutputSAS(type, bool_false, "[" + l[0] + ", " + l[1] + ", null, null]", + null_scalar, "[null, null, null, null]"); + /* CLR = 000 */ + CheckIfElseOutputSAS(type, bool_null, "[" + l[0] + ", " + l[1] + ", null, null]", + null_scalar, "[null, null, null, null]"); + + /* -------- Cond - Scalar, Left- Scalar, Right - Array --------- */ + /* empty */ + CheckIfElseOutputSSA(type, bool_true, valid_scalar, "[]", "[]"); + + /* CLR = 111 */ + CheckIfElseOutputSSA(type, bool_true, valid_scalar, + "[" + l[0] + ", " + l[1] + ", " + l[2] + ", " + l[3] + "]", + "[" + v[0] + ", " + v[0] + ", " + v[0] + ", " + v[0] + "]"); + /* CLR = 011 */ + CheckIfElseOutputSSA(type, bool_null, valid_scalar, + "[" + l[0] + ", " + l[1] + ", " + l[2] + ", " + l[3] + "]", + "[null, null, null, null]"); + /* CLR = 110 */ + CheckIfElseOutputSSA(type, bool_false, valid_scalar, + "[" + l[0] + ", null, " + l[2] + ", null]", + "[" + l[0] + ", null, " + l[2] + ", null]"); + /* CLR = 010 */ + CheckIfElseOutputSSA(type, bool_null, valid_scalar, + "[" + l[0] + ", null, " + l[2] + ", null]", + "[null, null, null, null]"); + /* CLR = 101 */ + CheckIfElseOutputSSA(type, bool_true, null_scalar, + "[" + l[0] + ", " + l[1] + ", " + l[2] + ", " + l[3] + "]", + "[null, null, null, null]"); + /* CLR = 001 */ + CheckIfElseOutputSSA(type, bool_null, null_scalar, + "[" + l[0] + ", " + l[1] + ", " + l[2] + ", " + l[3] + "]", + "[null, null, null, null]"); + /* CLR = 100 */ + CheckIfElseOutputSSA(type, bool_false, null_scalar, + "[" + l[0] + ", " + l[1] + ", null, " + l[3] + "]", + "[" + l[0] + ", " + l[1] + ", null, " + l[3] + "]"); + /* CLR = 000 */ + CheckIfElseOutputSSA(type, bool_null, null_scalar, + "[" + l[0] + ", " + l[1] + ", null, " + l[3] + "]", + "[null, null, null, null]"); + + /* -------- Cond - Scalar, Left- Scalar, Right - Scalar --------- */ + + /* CLR = 111 */ + CheckIfElseOutput(bool_false, valid_scalar, valid_scalar1, valid_scalar1); + /* CLR = 011 */ + CheckIfElseOutput(bool_null, valid_scalar, valid_scalar1, null_scalar); + /* CLR = 110 */ + CheckIfElseOutput(bool_true, valid_scalar, null_scalar, valid_scalar); + /* CLR = 010 */ + CheckIfElseOutput(bool_null, valid_scalar, null_scalar, null_scalar); + /* CLR = 101 */ + CheckIfElseOutput(bool_false, null_scalar, valid_scalar1, valid_scalar1); + /* CLR = 001 */ + CheckIfElseOutput(bool_null, null_scalar, valid_scalar1, null_scalar); + /* CLR = 100 */ + CheckIfElseOutput(bool_true, null_scalar, null_scalar, null_scalar); + /* CLR = 000 */ + CheckIfElseOutput(bool_null, null_scalar, null_scalar, null_scalar); +} /* * Legend: @@ -425,14 +453,16 @@ TYPED_TEST(TestIfElsePrimitive, IfElseFixedSizeRand) { */ TYPED_TEST(TestIfElsePrimitive, IfElseFixedSize) { auto type = TypeTraits::type_singleton(); + using T = typename TypeTraits::CType; - IF_ELSE_TEST_GEN(type, 1, 2, 3, 4, 5, 6, 7, 8, 100, 111); + DoIfElseTest(type, {1, 2, 3, 4}, {5, 6, 7, 8}, {100, 111}); } TEST_F(TestIfElseKernel, IfElseBoolean) { auto type = boolean(); - IF_ELSE_TEST_GEN(type, false, false, false, false, true, true, true, true, false, true); + DoIfElseTest(type, {false, false, false, false}, {true, true, true, true}, + {false, true}); } TYPED_TEST(TestIfElsePrimitive, IfElseBooleanRand) { From 54575a9d4d42f3ff80130f7c91ba98b7f7a93274 Mon Sep 17 00:00:00 2001 From: niranda perera Date: Fri, 28 May 2021 10:09:34 -0400 Subject: [PATCH 25/39] adding datum level null promotion --- .../arrow/compute/kernels/scalar_if_else.cc | 358 +++++------------- 1 file changed, 94 insertions(+), 264 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else.cc b/cpp/src/arrow/compute/kernels/scalar_if_else.cc index aa40f852dc52c..fab11fcaf1735 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else.cc @@ -31,158 +31,50 @@ namespace compute { namespace { -enum { COND_ALL_VALID = 1, LEFT_ALL_VALID = 2, RIGHT_ALL_VALID = 4 }; - -// if the condition is null then output is null otherwise we take validity from the -// selected argument -// ie. cond.valid & (cond.data & left.valid | ~cond.data & right.valid) -Status PromoteNullsVisitor(KernelContext* ctx, const ArrayData& cond, const Scalar& left, - const Scalar& right, ArrayData* output) { - uint8_t flag = right.is_valid * RIGHT_ALL_VALID | - left.is_valid * LEFT_ALL_VALID | - !cond.MayHaveNulls() * COND_ALL_VALID; - - if (flag < 6 && flag != 3) { - // there will be a validity buffer in the output - ARROW_ASSIGN_OR_RAISE(output->buffers[0], ctx->AllocateBitmap(cond.length)); - } - - // if the condition is null then output is null otherwise we take validity from the - // selected argument - // ie. cond.valid & (cond.data & left.valid | ~cond.data & right.valid) - switch (flag) { - case COND_ALL_VALID | LEFT_ALL_VALID | RIGHT_ALL_VALID: // = 7 - break; - case LEFT_ALL_VALID | RIGHT_ALL_VALID: // = 6 - // out_valid = c_valid - output->buffers[0] = SliceBuffer(cond.buffers[0], cond.offset, cond.length); - break; - case COND_ALL_VALID | RIGHT_ALL_VALID: // = 5 - // out_valid = ~cond.data - arrow::internal::InvertBitmap(cond.buffers[1]->data(), cond.offset, cond.length, - output->buffers[0]->mutable_data(), 0); - break; - case RIGHT_ALL_VALID: // = 4 - // out_valid = c_valid & ~cond.data - arrow::internal::BitmapAndNot(cond.buffers[0]->data(), cond.offset, - cond.buffers[1]->data(), cond.offset, cond.length, 0, - output->buffers[0]->mutable_data()); - break; - case COND_ALL_VALID | LEFT_ALL_VALID: // = 3 - // out_valid = cond.data - output->buffers[0] = SliceBuffer(cond.buffers[1], cond.offset, cond.length); - break; - case LEFT_ALL_VALID: // = 2 - // out_valid = cond.valid & cond.data - arrow::internal::BitmapAnd(cond.buffers[0]->data(), cond.offset, - cond.buffers[1]->data(), cond.offset, cond.length, 0, - output->buffers[0]->mutable_data()); - break; - case COND_ALL_VALID: // = 1 - // out_valid = 0 --> nothing to do; but requires out_valid to be a all-zero buffer - break; - case 0: - // out_valid = 0 --> nothing to do; but requires out_valid to be a all-zero buffer - break; +util::optional GetConstantValidityWord(const Datum& data) { + if (data.is_scalar()) { + return data.scalar()->is_valid ? ~uint64_t(0) : uint64_t(0); } - return Status::OK(); -} -Status PromoteNullsVisitor(KernelContext* ctx, const ArrayData& cond, - const ArrayData& left, const Scalar& right, - ArrayData* output) { - uint8_t flag = right.is_valid * 4 + !left.MayHaveNulls() * 2 + !cond.MayHaveNulls(); + if (data.array()->null_count == data.array()->length) return uint64_t(0); - enum { C_VALID, C_DATA, L_VALID }; + if (!data.array()->MayHaveNulls()) return ~uint64_t(0); - Bitmap bitmaps[3]; - bitmaps[C_VALID] = {cond.buffers[0], cond.offset, cond.length}; - bitmaps[C_DATA] = {cond.buffers[1], cond.offset, cond.length}; - bitmaps[L_VALID] = {left.buffers[0], left.offset, left.length}; - - uint64_t* out_validity = nullptr; - if (flag < 6 && flag != 3) { - // there will be a validity buffer in the output - ARROW_ASSIGN_OR_RAISE(output->buffers[0], ctx->AllocateBitmap(cond.length)); - out_validity = output->GetMutableValues(0); - } - - // lambda function that will be used inside the visitor - int64_t i = 0; - auto apply = [&](uint64_t c_valid, uint64_t c_data, uint64_t l_valid, - uint64_t r_valid) { - out_validity[i] = c_valid & ((c_data & l_valid) | (~c_data & r_valid)); - i++; - }; + // no constant validity word available + return {}; +} - // if the condition is null then output is null otherwise we take validity from the - // selected argument - // ie. cond.valid & (cond.data & left.valid | ~cond.data & right.valid) - switch (flag) { - case COND_ALL_VALID | LEFT_ALL_VALID | RIGHT_ALL_VALID: - break; - case LEFT_ALL_VALID | RIGHT_ALL_VALID: - output->buffers[0] = SliceBuffer(cond.buffers[0], cond.offset, cond.length); - break; - case COND_ALL_VALID | RIGHT_ALL_VALID: - // bitmaps[C_VALID] might be null; override to make it safe for Visit() - bitmaps[C_VALID] = bitmaps[C_DATA]; - Bitmap::VisitWords(bitmaps, [&](std::array words) { - apply(UINT64_MAX, words[C_DATA], words[L_VALID], UINT64_MAX); - }); - break; - case RIGHT_ALL_VALID: - Bitmap::VisitWords(bitmaps, [&](std::array words) { - apply(words[C_VALID], words[C_DATA], words[L_VALID], UINT64_MAX); - }); - break; - case COND_ALL_VALID | LEFT_ALL_VALID: - // only cond.data is passed - output->buffers[0] = SliceBuffer(cond.buffers[1], cond.offset, cond.length); - break; - case LEFT_ALL_VALID: - // out_valid = cond.valid & cond.data - arrow::internal::BitmapAnd(cond.buffers[0]->data(), cond.offset, - cond.buffers[1]->data(), cond.offset, cond.length, 0, - output->buffers[0]->mutable_data()); - break; - case COND_ALL_VALID: - // out_valid = cond.data & left.valid - arrow::internal::BitmapAnd(cond.buffers[1]->data(), cond.offset, - left.buffers[0]->data(), left.offset, cond.length, 0, - output->buffers[0]->mutable_data()); - break; - case 0: - Bitmap::VisitWords(bitmaps, [&](std::array words) { - apply(words[C_VALID], words[C_DATA], words[L_VALID], 0); - }); - break; - } - return Status::OK(); +inline Bitmap GetBitmap(const Datum& datum, int i) { + if (datum.is_scalar()) return {}; + const ArrayData& a = *datum.array(); + return Bitmap{a.buffers[i], a.offset, a.length}; } // if the condition is null then output is null otherwise we take validity from the // selected argument // ie. cond.valid & (cond.data & left.valid | ~cond.data & right.valid) -Status PromoteNullsVisitor(KernelContext* ctx, const ArrayData& cond, const Scalar& left, - const ArrayData& right, ArrayData* output) { - uint8_t flag = !right.MayHaveNulls() * 4 + left.is_valid * 2 + !cond.MayHaveNulls(); - - enum { C_VALID, C_DATA, R_VALID }; - - Bitmap bitmaps[3]; - bitmaps[C_VALID] = {cond.buffers[0], cond.offset, cond.length}; - bitmaps[C_DATA] = {cond.buffers[1], cond.offset, cond.length}; - bitmaps[R_VALID] = {right.buffers[0], right.offset, right.length}; - - uint64_t* out_validity = nullptr; - if (flag < 6) { - // there will be a validity buffer in the output - ARROW_ASSIGN_OR_RAISE(output->buffers[0], ctx->AllocateBitmap(cond.length)); - out_validity = output->GetMutableValues(0); - } +Status PromoteNullsVisitor(KernelContext* ctx, const Datum& cond_d, const Datum& left_d, + const Datum& right_d, ArrayData* output) { + auto cond_const = GetConstantValidityWord(cond_d); + auto left_const = GetConstantValidityWord(left_d); + auto right_const = GetConstantValidityWord(right_d); + + enum { COND_CONST = 1, LEFT_CONST = 2, RIGHT_CONST = 4 }; + auto flag = COND_CONST * cond_const.has_value() | LEFT_CONST * left_const.has_value() | + RIGHT_CONST * right_const.has_value(); + + const ArrayData& cond = *cond_d.array(); + // cond.data will always be available + Bitmap cond_data{cond.buffers[1], cond.offset, cond.length}; + Bitmap cond_valid{cond.buffers[0], cond.offset, cond.length}; + Bitmap left_valid = GetBitmap(left_d, 0); + Bitmap right_valid = GetBitmap(right_d, 0); + // sometimes Bitmaps will be ignored, in which case we replace access to them with + // duplicated (probably elided) access to cond_data + const Bitmap& _ = cond_data; // lambda function that will be used inside the visitor + uint64_t* out_validity = nullptr; int64_t i = 0; auto apply = [&](uint64_t c_valid, uint64_t c_data, uint64_t l_valid, uint64_t r_valid) { @@ -193,134 +85,84 @@ Status PromoteNullsVisitor(KernelContext* ctx, const ArrayData& cond, const Scal // if the condition is null then output is null otherwise we take validity from the // selected argument // ie. cond.valid & (cond.data & left.valid | ~cond.data & right.valid) + + // In the following cases, we dont need to allocate out_valid bitmap switch (flag) { - case COND_ALL_VALID | LEFT_ALL_VALID | RIGHT_ALL_VALID: - break; - case LEFT_ALL_VALID | RIGHT_ALL_VALID: - output->buffers[0] = SliceBuffer(cond.buffers[0], cond.offset, cond.length); - break; - case COND_ALL_VALID | RIGHT_ALL_VALID: - // out_valid = ~cond.data - arrow::internal::InvertBitmap(cond.buffers[1]->data(), cond.offset, cond.length, - output->buffers[0]->mutable_data(), 0); - break; - case RIGHT_ALL_VALID: - // out_valid = c_valid & ~cond.data - arrow::internal::BitmapAndNot(cond.buffers[0]->data(), cond.offset, - cond.buffers[1]->data(), cond.offset, cond.length, 0, - output->buffers[0]->mutable_data()); - break; - case COND_ALL_VALID | LEFT_ALL_VALID: - // bitmaps[C_VALID] might be null; override to make it safe for Visit() - bitmaps[C_VALID] = bitmaps[C_DATA]; - Bitmap::VisitWords(bitmaps, [&](std::array words) { - apply(UINT64_MAX, words[C_DATA], UINT64_MAX, words[R_VALID]); - }); - break; - case LEFT_ALL_VALID: - Bitmap::VisitWords(bitmaps, [&](std::array words) { - apply(words[C_VALID], words[C_DATA], UINT64_MAX, words[R_VALID]); - }); - break; - case COND_ALL_VALID: - // out_valid = ~cond.data & right.valid - arrow::internal::BitmapAndNot(right.buffers[0]->data(), right.offset, - cond.buffers[1]->data(), cond.offset, cond.length, 0, - output->buffers[0]->mutable_data()); - break; - case 0: - Bitmap::VisitWords(bitmaps, [&](std::array words) { - apply(words[C_VALID], words[C_DATA], 0, words[R_VALID]); - }); - break; + case COND_CONST | LEFT_CONST | RIGHT_CONST: + // if cond & left & right all ones, then output is all valid --> out_valid = nullptr + if ((*cond_const & *left_const & *right_const) == UINT64_MAX) { + return Status::OK(); + } + case LEFT_CONST | RIGHT_CONST: + // if both left and right are valid, no need to calculate out_valid bitmap. Pass + // cond validity buffer + if ((*left_const & *right_const) == UINT64_MAX) { + output->buffers[0] = SliceBuffer(cond.buffers[0], cond.offset, cond.length); + return Status::OK(); + } } - return Status::OK(); -} -// if the condition is null then output is null otherwise we take validity from the -// selected argument -// ie. cond.valid & (cond.data & left.valid | ~cond.data & right.valid) -Status PromoteNullsVisitor(KernelContext* ctx, const ArrayData& cond, - const ArrayData& left, const ArrayData& right, - ArrayData* output) { - uint8_t flag = - !right.MayHaveNulls() * 4 + !left.MayHaveNulls() * 2 + !cond.MayHaveNulls(); + // following cases requires a separate out_valid buffer + ARROW_ASSIGN_OR_RAISE(output->buffers[0], ctx->AllocateBitmap(cond.length)); + out_validity = output->GetMutableValues(0); enum { C_VALID, C_DATA, L_VALID, R_VALID }; - Bitmap bitmaps[4]; - bitmaps[C_VALID] = {cond.buffers[0], cond.offset, cond.length}; - bitmaps[C_DATA] = {cond.buffers[1], cond.offset, cond.length}; - bitmaps[L_VALID] = {left.buffers[0], left.offset, left.length}; - bitmaps[R_VALID] = {right.buffers[0], right.offset, right.length}; - - uint64_t* out_validity = nullptr; - if (flag < 6) { - // there will be a validity buffer in the output - ARROW_ASSIGN_OR_RAISE(output->buffers[0], ctx->AllocateBitmap(cond.length)); - out_validity = output->GetMutableValues(0); - } - - // lambda function that will be used inside the visitor - int64_t i = 0; - auto apply = [&](uint64_t c_valid, uint64_t c_data, uint64_t l_valid, - uint64_t r_valid) { - out_validity[i] = c_valid & ((c_data & l_valid) | (~c_data & r_valid)); - i++; - }; - - // if the condition is null then output is null otherwise we take validity from the - // selected argument - // ie. cond.valid & (cond.data & left.valid | ~cond.data & right.valid) switch (flag) { - case COND_ALL_VALID | LEFT_ALL_VALID | RIGHT_ALL_VALID: + case COND_CONST | LEFT_CONST | RIGHT_CONST: + Bitmap::VisitWords({_, cond_data, _, _}, [&](std::array words) { + apply(*cond_const, words[C_DATA], *left_const, *right_const); + }); break; - case LEFT_ALL_VALID | RIGHT_ALL_VALID: - output->buffers[0] = SliceBuffer(cond.buffers[0], cond.offset, cond.length); + case LEFT_CONST | RIGHT_CONST: + Bitmap::VisitWords( + {cond_valid, cond_data, _, _}, [&](std::array words) { + apply(words[C_VALID], words[C_DATA], *left_const, *right_const); + }); break; - case COND_ALL_VALID | RIGHT_ALL_VALID: + case COND_CONST | RIGHT_CONST: // bitmaps[C_VALID], bitmaps[R_VALID] might be null; override to make it safe for // Visit() - bitmaps[C_VALID] = bitmaps[C_DATA]; - bitmaps[R_VALID] = bitmaps[C_DATA]; - Bitmap::VisitWords(bitmaps, [&](std::array words) { - apply(UINT64_MAX, words[C_DATA], words[L_VALID], UINT64_MAX); - }); + Bitmap::VisitWords( + {_, cond_data, left_valid, _}, [&](std::array words) { + apply(*cond_const, words[C_DATA], words[L_VALID], *right_const); + }); break; - case RIGHT_ALL_VALID: + case RIGHT_CONST: // bitmaps[R_VALID] might be null; override to make it safe for Visit() - bitmaps[R_VALID] = bitmaps[C_DATA]; - Bitmap::VisitWords(bitmaps, [&](std::array words) { - apply(words[C_VALID], words[C_DATA], words[L_VALID], UINT64_MAX); - }); + Bitmap::VisitWords( + {cond_valid, cond_data, left_valid, _}, [&](std::array words) { + apply(words[C_VALID], words[C_DATA], words[L_VALID], *right_const); + }); break; - case COND_ALL_VALID | LEFT_ALL_VALID: + case COND_CONST | LEFT_CONST: // bitmaps[C_VALID], bitmaps[L_VALID] might be null; override to make it safe for // Visit() - bitmaps[C_VALID] = bitmaps[C_DATA]; - bitmaps[L_VALID] = bitmaps[C_DATA]; - Bitmap::VisitWords(bitmaps, [&](std::array words) { - apply(UINT64_MAX, words[C_DATA], UINT64_MAX, words[R_VALID]); - }); + Bitmap::VisitWords({_, cond_data, _, right_valid}, + [&](std::array words) { + apply(*cond_const, words[C_DATA], *left_const, words[R_VALID]); + }); break; - case LEFT_ALL_VALID: + case LEFT_CONST: // bitmaps[L_VALID] might be null; override to make it safe for Visit() - bitmaps[L_VALID] = bitmaps[C_DATA]; - Bitmap::VisitWords(bitmaps, [&](std::array words) { - apply(words[C_VALID], words[C_DATA], UINT64_MAX, words[R_VALID]); - }); + Bitmap::VisitWords( + {cond_valid, cond_data, _, right_valid}, [&](std::array words) { + apply(words[C_VALID], words[C_DATA], *left_const, words[R_VALID]); + }); break; - case COND_ALL_VALID: + case COND_CONST: // bitmaps[C_VALID] might be null; override to make it safe for Visit() - bitmaps[C_VALID] = bitmaps[C_DATA]; - Bitmap::VisitWords(bitmaps, [&](std::array words) { - apply(UINT64_MAX, words[C_DATA], words[L_VALID], words[R_VALID]); - }); + Bitmap::VisitWords( + {_, cond_data, left_valid, right_valid}, [&](std::array words) { + apply(*cond_const, words[C_DATA], words[L_VALID], words[R_VALID]); + }); break; case 0: - Bitmap::VisitWords(bitmaps, [&](std::array words) { - apply(words[C_VALID], words[C_DATA], words[L_VALID], words[R_VALID]); - }); + Bitmap::VisitWords({cond_valid, cond_data, left_valid, right_valid}, + [&](std::array words) { + apply(words[C_VALID], words[C_DATA], words[L_VALID], + words[R_VALID]); + }); break; } return Status::OK(); @@ -341,8 +183,6 @@ struct IfElseFunctor> { // AAA static Status Call(KernelContext* ctx, const ArrayData& cond, const ArrayData& left, const ArrayData& right, ArrayData* out) { - ARROW_RETURN_NOT_OK(PromoteNullsVisitor(ctx, cond, left, right, out)); - ARROW_ASSIGN_OR_RAISE(std::shared_ptr out_buf, ctx->Allocate(cond.length * sizeof(T))); T* out_values = reinterpret_cast(out_buf->mutable_data()); @@ -383,8 +223,6 @@ struct IfElseFunctor> { // ASA static Status Call(KernelContext* ctx, const ArrayData& cond, const Scalar& left, const ArrayData& right, ArrayData* out) { - ARROW_RETURN_NOT_OK(PromoteNullsVisitor(ctx, cond, left, right, out)); - ARROW_ASSIGN_OR_RAISE(std::shared_ptr out_buf, ctx->Allocate(cond.length * sizeof(T))); T* out_values = reinterpret_cast(out_buf->mutable_data()); @@ -424,8 +262,6 @@ struct IfElseFunctor> { // AAS static Status Call(KernelContext* ctx, const ArrayData& cond, const ArrayData& left, const Scalar& right, ArrayData* out) { - ARROW_RETURN_NOT_OK(PromoteNullsVisitor(ctx, cond, left, right, out)); - ARROW_ASSIGN_OR_RAISE(std::shared_ptr out_buf, ctx->Allocate(cond.length * sizeof(T))); T* out_values = reinterpret_cast(out_buf->mutable_data()); @@ -466,8 +302,6 @@ struct IfElseFunctor> { // ASS static Status Call(KernelContext* ctx, const ArrayData& cond, const Scalar& left, const Scalar& right, ArrayData* out) { - ARROW_RETURN_NOT_OK(PromoteNullsVisitor(ctx, cond, left, right, out)); - ARROW_ASSIGN_OR_RAISE(std::shared_ptr out_buf, ctx->Allocate(cond.length * sizeof(T))); T* out_values = reinterpret_cast(out_buf->mutable_data()); @@ -510,8 +344,6 @@ struct IfElseFunctor> { // AAA static Status Call(KernelContext* ctx, const ArrayData& cond, const ArrayData& left, const ArrayData& right, ArrayData* out) { - ARROW_RETURN_NOT_OK(PromoteNullsVisitor(ctx, cond, left, right, out)); - // out_buff = right & ~cond ARROW_ASSIGN_OR_RAISE(std::shared_ptr out_buf, arrow::internal::BitmapAndNot( @@ -533,8 +365,6 @@ struct IfElseFunctor> { // ASA static Status Call(KernelContext* ctx, const ArrayData& cond, const Scalar& left, const ArrayData& right, ArrayData* out) { - ARROW_RETURN_NOT_OK(PromoteNullsVisitor(ctx, cond, left, right, out)); - // out_buff = right & ~cond ARROW_ASSIGN_OR_RAISE(std::shared_ptr out_buf, arrow::internal::BitmapAndNot( @@ -555,8 +385,6 @@ struct IfElseFunctor> { // AAS static Status Call(KernelContext* ctx, const ArrayData& cond, const ArrayData& left, const Scalar& right, ArrayData* out) { - ARROW_RETURN_NOT_OK(PromoteNullsVisitor(ctx, cond, left, right, out)); - // out_buff = left & cond ARROW_ASSIGN_OR_RAISE(std::shared_ptr out_buf, arrow::internal::BitmapAnd( @@ -578,8 +406,6 @@ struct IfElseFunctor> { // ASS static Status Call(KernelContext* ctx, const ArrayData& cond, const Scalar& left, const Scalar& right, ArrayData* out) { - ARROW_RETURN_NOT_OK(PromoteNullsVisitor(ctx, cond, left, right, out)); - bool left_data = internal::UnboxScalar::Unbox(left); bool right_data = internal::UnboxScalar::Unbox(right); @@ -657,9 +483,10 @@ struct ResolveIfElseExec { } else { *out = MakeNullScalar(batch[1].type()); } + return Status::OK(); } else { // either left or right is an array. output is always an array // output size is the size of the array arg - int64_t bcast_size = batch[1].is_array() ? batch[1].length() : batch[2].length(); + auto bcast_size = batch.length; if (cond.is_valid) { const auto& valid_data = cond.value ? batch[1] : batch[2]; if (valid_data.is_array()) { @@ -673,11 +500,14 @@ struct ResolveIfElseExec { ARROW_ASSIGN_OR_RAISE( *out, MakeArrayOfNull(batch[1].type(), bcast_size, ctx->memory_pool())) } + return Status::OK(); } - return Status::OK(); } // cond is array. Use functors to sort things out + ARROW_RETURN_NOT_OK( + PromoteNullsVisitor(ctx, batch[0], batch[1], batch[2], out->mutable_array())); + if (batch[1].kind() == Datum::ARRAY) { if (batch[2].kind() == Datum::ARRAY) { // AAA return IfElseFunctor::Call(ctx, *batch[0].array(), *batch[1].array(), From e32d709204351be7eb5af09f9c4eea94e718f320 Mon Sep 17 00:00:00 2001 From: niranda perera Date: Mon, 31 May 2021 12:27:50 -0400 Subject: [PATCH 26/39] fixing slicing issue --- cpp/src/arrow/compute/kernels/scalar_if_else.cc | 15 ++++++++++----- .../arrow/compute/kernels/scalar_if_else_test.cc | 8 ++++++++ 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else.cc b/cpp/src/arrow/compute/kernels/scalar_if_else.cc index fab11fcaf1735..4bdf9d2261022 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else.cc @@ -82,10 +82,7 @@ Status PromoteNullsVisitor(KernelContext* ctx, const Datum& cond_d, const Datum& i++; }; - // if the condition is null then output is null otherwise we take validity from the - // selected argument - // ie. cond.valid & (cond.data & left.valid | ~cond.data & right.valid) - + // cond.valid & (cond.data & left.valid | ~cond.data & right.valid) // In the following cases, we dont need to allocate out_valid bitmap switch (flag) { case COND_CONST | LEFT_CONST | RIGHT_CONST: @@ -97,7 +94,15 @@ Status PromoteNullsVisitor(KernelContext* ctx, const Datum& cond_d, const Datum& // if both left and right are valid, no need to calculate out_valid bitmap. Pass // cond validity buffer if ((*left_const & *right_const) == UINT64_MAX) { - output->buffers[0] = SliceBuffer(cond.buffers[0], cond.offset, cond.length); + // if there's an offset, copy bitmap (cannot slice a bitmap) + if (cond.offset) { + ARROW_ASSIGN_OR_RAISE( + output->buffers[0], + arrow::internal::CopyBitmap(ctx->memory_pool(), cond.buffers[0]->data(), + cond.offset, cond.length)); + } else { // just copy assign cond validity buffer + output->buffers[0] = cond.buffers[0]; + } return Status::OK(); } } diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc index a74342c843a1e..1934d36e3aaee 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc @@ -500,5 +500,13 @@ TEST_F(TestIfElseKernel, IfElseNull) { "[null, null, null, null]", "[null, null, null, null]"); } +TEST_F(TestIfElseKernel, IfElseWithOffset) { + auto cond = ArrayFromJSON(boolean(), "[null, true, false]")->Slice(1, 2); + auto left = ArrayFromJSON(int64(), "[10, 11]"); + auto right = ArrayFromJSON(int64(), "[1, 2]"); + auto expected = ArrayFromJSON(int64(), "[10, 2]"); + CheckIfElseOutput(cond, left, right, expected); +} + } // namespace compute } // namespace arrow From e49842a5efc3c6af325231b63b807b219bb5beff Mon Sep 17 00:00:00 2001 From: niranda perera Date: Mon, 31 May 2021 18:03:08 -0400 Subject: [PATCH 27/39] rewriting the tests --- .../compute/kernels/scalar_if_else_test.cc | 503 +++++------------- 1 file changed, 129 insertions(+), 374 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc index 1934d36e3aaee..f4eb1a6e1bed3 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc @@ -39,70 +39,6 @@ void CheckIfElseOutput(const Datum& cond, const Datum& left, const Datum& right, } } -void CheckIfElseOutputAAA(const std::shared_ptr& type, const std::string& cond, - const std::string& left, const std::string& right, - const std::string& expected) { - const std::shared_ptr& cond_ = ArrayFromJSON(boolean(), cond); - const std::shared_ptr& left_ = ArrayFromJSON(type, left); - const std::shared_ptr& right_ = ArrayFromJSON(type, right); - const std::shared_ptr& expected_ = ArrayFromJSON(type, expected); - CheckIfElseOutput(cond_, left_, right_, expected_); -} - -void CheckIfElseOutputAAS(const std::shared_ptr& type, const std::string& cond, - const std::string& left, const std::shared_ptr& right, - const std::string& expected) { - const std::shared_ptr& cond_ = ArrayFromJSON(boolean(), cond); - const std::shared_ptr& left_ = ArrayFromJSON(type, left); - const std::shared_ptr& expected_ = ArrayFromJSON(type, expected); - CheckIfElseOutput(cond_, left_, right, expected_); -} - -void CheckIfElseOutputASA(const std::shared_ptr& type, const std::string& cond, - const std::shared_ptr& left, const std::string& right, - const std::string& expected) { - const std::shared_ptr& cond_ = ArrayFromJSON(boolean(), cond); - const std::shared_ptr& right_ = ArrayFromJSON(type, right); - const std::shared_ptr& expected_ = ArrayFromJSON(type, expected); - CheckIfElseOutput(cond_, left, right_, expected_); -} - -void CheckIfElseOutputASS(const std::shared_ptr& type, const std::string& cond, - const std::shared_ptr& left, - const std::shared_ptr& right, - const std::string& expected) { - const std::shared_ptr& cond_ = ArrayFromJSON(boolean(), cond); - const std::shared_ptr& expected_ = ArrayFromJSON(type, expected); - CheckIfElseOutput(cond_, left, right, expected_); -} - -void CheckIfElseOutputSAA(const std::shared_ptr& type, - const std::shared_ptr& cond, const std::string& left, - const std::string& right, const std::string& expected) { - const std::shared_ptr& left_ = ArrayFromJSON(type, left); - const std::shared_ptr& right_ = ArrayFromJSON(type, right); - const std::shared_ptr& expected_ = ArrayFromJSON(type, expected); - CheckIfElseOutput(cond, left_, right_, expected_); -} - -void CheckIfElseOutputSAS(const std::shared_ptr& type, - const std::shared_ptr& cond, const std::string& left, - const std::shared_ptr& right, - const std::string& expected) { - const std::shared_ptr& left_ = ArrayFromJSON(type, left); - const std::shared_ptr& expected_ = ArrayFromJSON(type, expected); - CheckIfElseOutput(cond, left_, right, expected_); -} - -void CheckIfElseOutputSSA(const std::shared_ptr& type, - const std::shared_ptr& cond, - const std::shared_ptr& left, const std::string& right, - const std::string& expected) { - const std::shared_ptr& right_ = ArrayFromJSON(type, right); - const std::shared_ptr& expected_ = ArrayFromJSON(type, expected); - CheckIfElseOutput(cond, left, right_, expected_); -} - class TestIfElseKernel : public ::testing::Test {}; template @@ -147,322 +83,139 @@ TYPED_TEST(TestIfElsePrimitive, IfElseFixedSizeRand) { CheckIfElseOutput(cond, left, right, expected_data); } -template -void DoIfElseTest(const std::shared_ptr& type, const std::array& left, - const std::array& right, const std::array& valid_scalars) { - std::array l, r; - std::array v; - - auto to_string = [](const T& i) { return std::to_string(i); }; - std::transform(left.begin(), left.end(), l.begin(), to_string); - std::transform(right.begin(), right.end(), r.begin(), to_string); - std::transform(valid_scalars.begin(), valid_scalars.end(), v.begin(), to_string); - - /* -------- All arrays --------- */ - /* empty */ - CheckIfElseOutputAAA(type, "[]", "[]", "[]", "[]"); - /* CLR = 111 */ - CheckIfElseOutputAAA(type, "[true, true, true, false]", - "[" + l[0] + ", " + l[1] + ", " + l[2] + ", " + l[3] + "]", - "[" + r[0] + ", " + r[1] + ", " + r[2] + ", " + r[3] + "]", - "[" + l[0] + ", " + l[1] + ", " + l[2] + ", " + r[3] + "]"); - /* CLR = 011 */ - CheckIfElseOutputAAA(type, "[true, true, null, false]", - "[" + l[0] + ", " + l[1] + ", " + l[2] + ", " + l[3] + "]", - "[" + r[0] + ", " + r[1] + ", " + r[2] + ", " + r[3] + "]", - "[" + l[0] + ", " + l[1] + ", null, " + r[3] + "]"); - /* CLR = 101 */ - CheckIfElseOutputAAA(type, "[true, true, true, false]", - "[" + l[0] + ", null, " + l[2] + ", " + l[3] + "]", - "[" + r[0] + ", " + r[1] + ", " + r[2] + ", " + r[3] + "]", - "[" + l[0] + ", null, " + l[2] + ", " + r[3] + "]"); - /* CLR = 001 */ - CheckIfElseOutputAAA(type, "[true, true, null, false]", - "[" + l[0] + ", null, " + l[2] + ", " + l[3] + "]", - "[" + r[0] + ", " + r[1] + ", " + r[2] + ", " + r[3] + "]", - "[" + l[0] + ", null, null, " + r[3] + "]"); - /* CLR = 110 */ - CheckIfElseOutputAAA(type, "[true, true, true, false]", - "[" + l[0] + ", " + l[1] + ", " + l[2] + ", " + l[3] + "]", - "[" + r[0] + ", " + r[1] + ", " + r[2] + ", null]", - "[" + l[0] + ", " + l[1] + ", " + l[2] + ", null]"); - /* CLR = 010 */ - CheckIfElseOutputAAA(type, "[null, true, true, false]", - "[" + l[0] + ", " + l[1] + ", " + l[2] + ", " + l[3] + "]", - "[" + r[0] + ", " + r[1] + ", " + r[2] + ", null]", - "[null, " + l[1] + ", " + l[2] + ", null]"); - /* CLR = 100 */ - CheckIfElseOutputAAA(type, "[true, true, true, false]", - "[" + l[0] + ", " + l[1] + ", null, null]", - "[null, " + r[1] + ", " + r[2] + ", null]", - "[" + l[0] + ", " + l[1] + ", null, null]"); - /* CLR = 000 */ - CheckIfElseOutputAAA( - type, "[null, true, true, false]", "[" + l[0] + ", " + l[1] + ", null, null]", - "[null, " + r[1] + ", " + r[2] + ", null]", "[null, " + l[1] + ", null, null]"); - - /* -------- Cond - Array, Left- Array, Right - Scalar --------- */ - ASSERT_OK_AND_ASSIGN(auto valid_scalar, MakeScalar(type, valid_scalars[0])); - ASSERT_OK_AND_ASSIGN(auto valid_scalar1, MakeScalar(type, valid_scalars[1])); - auto null_scalar = MakeNullScalar(type); - - /* empty */ - // CheckIfElseOutputAAS(type, "[]", "[]", valid_scalar, "[]"); - - /* CLR = 111 */ - CheckIfElseOutputAAS(type, "[true, true, true, false]", - "[" + l[0] + ", " + l[1] + ", " + l[2] + ", " + l[3] + "]", - valid_scalar, - "[" + l[0] + ", " + l[1] + ", " + l[2] + ", " + v[0] + "]"); - /* CLR = 011 */ - CheckIfElseOutputAAS(type, "[true, true, null, false]", - "[" + l[0] + ", " + l[1] + ", " + l[2] + ", " + l[3] + "]", - valid_scalar, "[" + l[0] + ", " + l[1] + ", null, " + v[0] + "]"); - /* CLR = 101 */ - CheckIfElseOutputAAS(type, "[true, true, true, false]", - "[" + l[0] + ", null, " + l[2] + ", " + l[3] + "]", valid_scalar, - "[" + l[0] + ", null, " + l[2] + ", " + v[0] + "]"); - /* CLR = 001 */ - CheckIfElseOutputAAS(type, "[true, true, null, false]", - "[" + l[0] + ", null, " + l[2] + ", " + l[3] + "]", valid_scalar, - "[" + l[0] + ", null, null, " + v[0] + "]"); - /* CLR = 110 */ - CheckIfElseOutputAAS(type, "[true, true, true, false]", - "[" + l[0] + ", " + l[1] + ", " + l[2] + ", " + l[3] + "]", - null_scalar, "[" + l[0] + ", " + l[1] + ", " + l[2] + ", null]"); - /* CLR = 010 */ - CheckIfElseOutputAAS(type, "[null, true, true, false]", - "[" + l[0] + ", " + l[1] + ", " + l[2] + ", " + l[3] + "]", - null_scalar, "[null, " + l[1] + ", " + l[2] + ", null]"); - /* CLR = 100 */ - CheckIfElseOutputAAS(type, "[true, true, true, false]", - "[" + l[0] + ", " + l[1] + ", null, null]", null_scalar, - "[" + l[0] + ", " + l[1] + ", null, null]"); - /* CLR = 000 */ - CheckIfElseOutputAAS(type, "[null, true, true, false]", - "[" + l[0] + ", " + l[1] + ", null, null]", null_scalar, - "[null, " + l[1] + ", null, null]"); - - /* -------- Cond - Array, Left- Scalar, Right - Array --------- */ - /* empty */ - CheckIfElseOutputASA(type, "[]", valid_scalar, "[]", "[]"); - - /* CLR = 111 */ - CheckIfElseOutputASA(type, "[true, true, true, false]", valid_scalar, - "[" + l[0] + ", " + l[1] + ", " + l[2] + ", " + l[3] + "]", - "[" + v[0] + ", " + v[0] + ", " + v[0] + ", " + l[3] + "]"); - /* CLR = 011 */ - CheckIfElseOutputASA(type, "[true, true, null, false]", valid_scalar, - "[" + l[0] + ", " + l[1] + ", " + l[2] + ", " + l[3] + "]", - "[" + v[0] + ", " + v[0] + ", null, " + l[3] + "]"); - /* CLR = 110 */ - CheckIfElseOutputASA(type, "[true, true, true, false]", valid_scalar, - "[" + l[0] + ", null, " + l[2] + ", null]", - "[" + v[0] + ", " + v[0] + ", " + v[0] + ", null]"); - /* CLR = 010 */ - CheckIfElseOutputASA(type, "[true, true, null, false]", valid_scalar, - "[" + l[0] + ", null, " + l[2] + ", null]", - "[" + v[0] + ", " + v[0] + ", null, null]"); - /* CLR = 101 */ - CheckIfElseOutputASA(type, "[true, true, true, false]", null_scalar, - "[" + l[0] + ", " + l[1] + ", " + l[2] + ", " + l[3] + "]", - "[null, null, null, " + l[3] + "]"); - /* CLR = 001 */ - CheckIfElseOutputASA(type, "[null, true, true, false]", null_scalar, - "[" + l[0] + ", " + l[1] + ", " + l[2] + ", " + l[3] + "]", - "[null, null, null, " + l[3] + "]"); - /* CLR = 100 */ - CheckIfElseOutputASA(type, "[true, true, true, false]", null_scalar, - "[" + l[0] + ", " + l[1] + ", null, " + l[3] + "]", - "[null, null, null, " + l[3] + "]"); - /* CLR = 000 */ - CheckIfElseOutputASA(type, "[true, true, null, false]", null_scalar, - "[" + l[0] + ", " + l[1] + ", null, " + l[3] + "]", - "[null, null, null, " + l[3] + "]"); - - /* -------- Cond - Array, Left- Scalar, Right - Scalar --------- */ - /* empty */ - CheckIfElseOutputASS(type, "[]", valid_scalar, valid_scalar1, "[]"); - - /* CLR = 111 */ - CheckIfElseOutputASS(type, "[true, true, true, false]", valid_scalar, valid_scalar1, - "[" + v[0] + ", " + v[0] + ", " + v[0] + ", " + v[1] + "]"); - /* CLR = 011 */ - CheckIfElseOutputASS(type, "[true, true, null, false]", valid_scalar, valid_scalar1, - "[" + v[0] + ", " + v[0] + ", null, " + v[1] + "]"); - /* CLR = 010 */ - CheckIfElseOutputASS(type, "[true, true, null, false]", valid_scalar, null_scalar, - "[" + v[0] + ", " + v[0] + ", null, null]"); - /* CLR = 110 */ - CheckIfElseOutputASS(type, "[true, true, true, false]", valid_scalar, null_scalar, - "[" + v[0] + ", " + v[0] + ", " + v[0] + ", null]"); - /* CLR = 101 */ - CheckIfElseOutputASS(type, "[true, true, true, false]", null_scalar, valid_scalar1, - "[null, null, null, " + v[1] + "]"); - /* CLR = 001 */ - CheckIfElseOutputASS(type, "[null, true, true, false]", null_scalar, valid_scalar1, - "[null, null, null, " + v[1] + "]"); - /* CLR = 100 */ - CheckIfElseOutputASS(type, "[true, true, true, false]", null_scalar, null_scalar, - "[null, null, null, null]"); - /* CLR = 000 */ - CheckIfElseOutputASS(type, "[true, true, null, false]", null_scalar, null_scalar, - "[null, null, null, null]"); - - /* -------- Cond - Scalar, Left- Array, Right - Array --------- */ - ASSERT_OK_AND_ASSIGN(auto bool_true, MakeScalar(boolean(), true)); - ASSERT_OK_AND_ASSIGN(auto bool_false, MakeScalar(boolean(), false)); - auto bool_null = MakeNullScalar(boolean()); - - /* empty */ - CheckIfElseOutputSAA(type, bool_true, "[]", "[]", "[]"); - /* CLR = 111 */ - CheckIfElseOutputSAA(type, bool_true, - "[" + l[0] + ", " + l[1] + ", " + l[2] + ", " + l[3] + "]", - "[" + r[0] + ", " + r[1] + ", " + r[2] + ", " + r[3] + "]", - "[" + l[0] + ", " + l[1] + ", " + l[2] + ", " + l[3] + "]"); - /* CLR = 011 */ - CheckIfElseOutputSAA(type, bool_null, - "[" + l[0] + ", " + l[1] + ", " + l[2] + ", " + l[3] + "]", - "[" + r[0] + ", " + r[1] + ", " + r[2] + ", " + r[3] + "]", - "[null, null, null, null]"); - /* CLR = 101 */ - CheckIfElseOutputSAA(type, bool_false, - "[" + l[0] + ", null, " + l[2] + ", " + l[3] + "]", - "[" + r[0] + ", " + r[1] + ", " + r[2] + ", " + r[3] + "]", - "[" + r[0] + ", " + r[1] + ", " + r[2] + ", " + r[3] + "]"); - /* CLR = 001 */ - CheckIfElseOutputSAA(type, bool_null, - "[" + l[0] + ", null, " + l[2] + ", " + l[3] + "]", - "[" + r[0] + ", " + r[1] + ", " + r[2] + ", " + r[3] + "]", - "[null, null, null, null]"); - /* CLR = 110 */ - CheckIfElseOutputSAA(type, bool_false, - "[" + l[0] + ", " + l[1] + ", " + l[2] + ", " + l[3] + "]", - "[" + r[0] + ", " + r[1] + ", " + r[2] + ", null]", - "[" + r[0] + ", " + r[1] + ", " + r[2] + ", null]"); - /* CLR = 010 */ - CheckIfElseOutputSAA( - type, bool_null, "[" + l[0] + ", " + l[1] + ", " + l[2] + ", " + l[3] + "]", - "[" + r[0] + ", " + r[1] + ", " + r[2] + ", null]", "[null, null, null, null]"); - /* CLR = 100 */ - CheckIfElseOutputSAA(type, bool_true, "[" + l[0] + ", " + l[1] + ", null, null]", - "[null, " + r[1] + ", " + r[2] + ", null]", - "[" + l[0] + ", " + l[1] + ", null, null]"); - /* CLR = 000 */ - CheckIfElseOutputSAA(type, bool_null, "[" + l[0] + ", " + l[1] + ", null, null]", - "[null, " + r[1] + ", " + r[2] + ", null]", - "[null, null, null, null]"); - - /* -------- Cond - Scalar, Left- Array, Right - Scalar --------- */ - /* empty */ - CheckIfElseOutputSAS(type, bool_true, "[]", valid_scalar, "[]"); - - /* CLR = 111 */ - CheckIfElseOutputSAS( - type, bool_true, "[" + l[0] + ", " + l[1] + ", " + l[2] + ", " + l[3] + "]", - valid_scalar, "[" + l[0] + ", " + l[1] + ", " + l[2] + ", " + l[3] + "]"); - /* CLR = 011 */ - CheckIfElseOutputSAS(type, bool_null, - "[" + l[0] + ", " + l[1] + ", " + l[2] + ", " + l[3] + "]", - valid_scalar, "[null, null, null, null]"); - /* CLR = 101 */ - CheckIfElseOutputSAS(type, bool_false, - "[" + l[0] + ", null, " + l[2] + ", " + l[3] + "]", valid_scalar, - "[" + v[0] + ", " + v[0] + ", " + v[0] + ", " + v[0] + "]"); - /* CLR = 001 */ - CheckIfElseOutputSAS(type, bool_null, - "[" + l[0] + ", null, " + l[2] + ", " + l[3] + "]", valid_scalar, - "[null, null, null, null]"); - /* CLR = 110 */ - CheckIfElseOutputSAS( - type, bool_true, "[" + l[0] + ", " + l[1] + ", " + l[2] + ", " + l[3] + "]", - null_scalar, "[" + l[0] + ", " + l[1] + ", " + l[2] + ", " + l[3] + "]"); - /* CLR = 010 */ - CheckIfElseOutputSAS(type, bool_null, - "[" + l[0] + ", " + l[1] + ", " + l[2] + ", " + l[3] + "]", - null_scalar, "[null, null, null, null]"); - /* CLR = 100 */ - CheckIfElseOutputSAS(type, bool_false, "[" + l[0] + ", " + l[1] + ", null, null]", - null_scalar, "[null, null, null, null]"); - /* CLR = 000 */ - CheckIfElseOutputSAS(type, bool_null, "[" + l[0] + ", " + l[1] + ", null, null]", - null_scalar, "[null, null, null, null]"); - - /* -------- Cond - Scalar, Left- Scalar, Right - Array --------- */ - /* empty */ - CheckIfElseOutputSSA(type, bool_true, valid_scalar, "[]", "[]"); - - /* CLR = 111 */ - CheckIfElseOutputSSA(type, bool_true, valid_scalar, - "[" + l[0] + ", " + l[1] + ", " + l[2] + ", " + l[3] + "]", - "[" + v[0] + ", " + v[0] + ", " + v[0] + ", " + v[0] + "]"); - /* CLR = 011 */ - CheckIfElseOutputSSA(type, bool_null, valid_scalar, - "[" + l[0] + ", " + l[1] + ", " + l[2] + ", " + l[3] + "]", - "[null, null, null, null]"); - /* CLR = 110 */ - CheckIfElseOutputSSA(type, bool_false, valid_scalar, - "[" + l[0] + ", null, " + l[2] + ", null]", - "[" + l[0] + ", null, " + l[2] + ", null]"); - /* CLR = 010 */ - CheckIfElseOutputSSA(type, bool_null, valid_scalar, - "[" + l[0] + ", null, " + l[2] + ", null]", - "[null, null, null, null]"); - /* CLR = 101 */ - CheckIfElseOutputSSA(type, bool_true, null_scalar, - "[" + l[0] + ", " + l[1] + ", " + l[2] + ", " + l[3] + "]", - "[null, null, null, null]"); - /* CLR = 001 */ - CheckIfElseOutputSSA(type, bool_null, null_scalar, - "[" + l[0] + ", " + l[1] + ", " + l[2] + ", " + l[3] + "]", - "[null, null, null, null]"); - /* CLR = 100 */ - CheckIfElseOutputSSA(type, bool_false, null_scalar, - "[" + l[0] + ", " + l[1] + ", null, " + l[3] + "]", - "[" + l[0] + ", " + l[1] + ", null, " + l[3] + "]"); - /* CLR = 000 */ - CheckIfElseOutputSSA(type, bool_null, null_scalar, - "[" + l[0] + ", " + l[1] + ", null, " + l[3] + "]", - "[null, null, null, null]"); - - /* -------- Cond - Scalar, Left- Scalar, Right - Scalar --------- */ - - /* CLR = 111 */ - CheckIfElseOutput(bool_false, valid_scalar, valid_scalar1, valid_scalar1); - /* CLR = 011 */ - CheckIfElseOutput(bool_null, valid_scalar, valid_scalar1, null_scalar); - /* CLR = 110 */ - CheckIfElseOutput(bool_true, valid_scalar, null_scalar, valid_scalar); - /* CLR = 010 */ - CheckIfElseOutput(bool_null, valid_scalar, null_scalar, null_scalar); - /* CLR = 101 */ - CheckIfElseOutput(bool_false, null_scalar, valid_scalar1, valid_scalar1); - /* CLR = 001 */ - CheckIfElseOutput(bool_null, null_scalar, valid_scalar1, null_scalar); - /* CLR = 100 */ - CheckIfElseOutput(bool_true, null_scalar, null_scalar, null_scalar); - /* CLR = 000 */ - CheckIfElseOutput(bool_null, null_scalar, null_scalar, null_scalar); +template +struct DatumWrapper { + using CType = typename TypeTraits::CType; + using ArrayType = typename TypeTraits::ArrayType; + using ScalarType = typename TypeTraits::ScalarType; + + util::Variant, std::shared_ptr> datum; + bool is_scalar; + + explicit DatumWrapper(const Datum& datum_) : is_scalar(datum_.is_scalar()) { + if (is_scalar) { + datum = std::move(std::static_pointer_cast(datum_.scalar())); + } else { + datum = std::move(std::static_pointer_cast(datum_.make_array())); + } + } + + bool IsValid(int64_t i) const { + return is_scalar ? util::get>(datum)->is_valid + : util::get>(datum)->IsValid(i); + } + + CType Value(int64_t i) const { + return is_scalar ? util::get>(datum)->value + : util::get>(datum)->Value(i); + } +}; + +template +void GenerateExpected(const Datum& cond, const Datum& left, const Datum& right, + Datum* out) { + int64_t len = cond.is_array() ? cond.length() + : left.is_array() ? left.length() + : right.is_array() ? right.length() + : 1; + + DatumWrapper cond_(cond); + DatumWrapper left_(left); + DatumWrapper right_(right); + + int64_t i = 0; + + // if all scalars + if (cond.is_scalar() && left.is_scalar() && right.is_scalar()) { + if (!cond_.IsValid(i) || (cond_.Value(i) && !left_.IsValid(i)) || + (!cond_.Value(i) && !right_.IsValid(i))) { + *out = MakeNullScalar(left.type()); + return; + } + + if (cond_.Value(i)) { + *out = left; + return; + } else { + *out = right; + return; + } + } + + typename TypeTraits::BuilderType builder; + + for (; i < len; ++i) { + if (!cond_.IsValid(i) || (cond_.Value(i) && !left_.IsValid(i)) || + (!cond_.Value(i) && !right_.IsValid(i))) { + ASSERT_OK(builder.AppendNull()); + continue; + } + + if (cond_.Value(i)) { + ASSERT_OK(builder.Append(left_.Value(i))); + } else { + ASSERT_OK(builder.Append(right_.Value(i))); + } + } + ASSERT_OK_AND_ASSIGN(auto expected_data, builder.Finish()); + + *out = expected_data; } -/* - * Legend: - * C - Cond, L - Left, R - Right - * 1 - All valid (or valid scalar), 0 - Could have nulls (or invalid scalar) - */ -TYPED_TEST(TestIfElsePrimitive, IfElseFixedSize) { +TYPED_TEST(TestIfElsePrimitive, IfElseFixedSizeGen) { auto type = TypeTraits::type_singleton(); - using T = typename TypeTraits::CType; - DoIfElseTest(type, {1, 2, 3, 4}, {5, 6, 7, 8}, {100, 111}); + std::vector cond_datums{ArrayFromJSON(boolean(), "[true, true, true, false]"), + ArrayFromJSON(boolean(), "[true, null, true, false]"), + MakeScalar(boolean(), true).ValueOrDie(), + MakeNullScalar(boolean())}; + + std::vector left_datums{ + ArrayFromJSON(type, "[1, 2, 3, 4]"), ArrayFromJSON(type, "[1, 2, null, 4]"), + MakeScalar(type, 100).ValueOrDie(), MakeNullScalar(type)}; + + std::vector right_datums{ + ArrayFromJSON(type, "[5, 6, 7, 8]"), ArrayFromJSON(type, "[5, 6, 7, null]"), + MakeScalar(type, 111).ValueOrDie(), MakeNullScalar(type)}; + + for (auto&& cond : cond_datums) { + for (auto&& left : left_datums) { + for (auto&& right : right_datums) { + Datum exp; + GenerateExpected(cond, left, right, &exp); + CheckIfElseOutput(cond, left, right, exp); + } + } + } } -TEST_F(TestIfElseKernel, IfElseBoolean) { +TEST_F(TestIfElseKernel, IfElseBooleanGen) { auto type = boolean(); - DoIfElseTest(type, {false, false, false, false}, {true, true, true, true}, - {false, true}); + std::vector cond_datums{ArrayFromJSON(boolean(), "[true, true, true, false]"), + ArrayFromJSON(boolean(), "[true, true, null, false]"), + MakeScalar(boolean(), true).ValueOrDie(), + MakeNullScalar(boolean())}; + + std::vector left_datums{ArrayFromJSON(type, "[false, false, false, false]"), + ArrayFromJSON(type, "[false, false, null, false]"), + MakeScalar(type, false).ValueOrDie(), + MakeNullScalar(type)}; + + std::vector right_datums{ArrayFromJSON(type, "[true, true, true, true]"), + ArrayFromJSON(type, "[true, true, true, null]"), + MakeScalar(type, true).ValueOrDie(), + MakeNullScalar(type)}; + + for (auto&& cond : cond_datums) { + for (auto&& left : left_datums) { + for (auto&& right : right_datums) { + Datum exp; + GenerateExpected(cond, left, right, &exp); + CheckIfElseOutput(cond, left, right, exp); + } + } + } } TYPED_TEST(TestIfElsePrimitive, IfElseBooleanRand) { @@ -496,8 +249,10 @@ TYPED_TEST(TestIfElsePrimitive, IfElseBooleanRand) { } TEST_F(TestIfElseKernel, IfElseNull) { - CheckIfElseOutputAAA(null(), "[null, null, null, null]", "[null, null, null, null]", - "[null, null, null, null]", "[null, null, null, null]"); + CheckIfElseOutput(ArrayFromJSON(boolean(), "[null, null, null, null]"), + ArrayFromJSON(null(), "[null, null, null, null]"), + ArrayFromJSON(null(), "[null, null, null, null]"), + ArrayFromJSON(null(), "[null, null, null, null]")); } TEST_F(TestIfElseKernel, IfElseWithOffset) { From c54e88e37ed8de65787378689ff043cebff6a603 Mon Sep 17 00:00:00 2001 From: niranda perera Date: Tue, 1 Jun 2021 01:32:45 +0000 Subject: [PATCH 28/39] Autoformat/render all the things [automated commit] --- cpp/src/arrow/compute/kernels/scalar_if_else_test.cc | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc index f4eb1a6e1bed3..326c129592327 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc @@ -114,10 +114,9 @@ struct DatumWrapper { template void GenerateExpected(const Datum& cond, const Datum& left, const Datum& right, Datum* out) { - int64_t len = cond.is_array() ? cond.length() - : left.is_array() ? left.length() - : right.is_array() ? right.length() - : 1; + int64_t len = cond.is_array() ? cond.length() + : left.is_array() ? left.length() + : right.is_array() ? right.length() : 1; DatumWrapper cond_(cond); DatumWrapper left_(left); From 4e5869715231edf414269d4da5288def61cecc72 Mon Sep 17 00:00:00 2001 From: niranda perera Date: Tue, 1 Jun 2021 10:26:50 -0400 Subject: [PATCH 29/39] adding comments --- .../arrow/compute/kernels/scalar_if_else.cc | 49 ++++++++++--------- 1 file changed, 25 insertions(+), 24 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else.cc b/cpp/src/arrow/compute/kernels/scalar_if_else.cc index 4bdf9d2261022..37318128defef 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else.cc @@ -31,14 +31,17 @@ namespace compute { namespace { +constexpr uint64_t kAllNull = 0; +constexpr uint64_t kAllValid = ~kAllNull; + util::optional GetConstantValidityWord(const Datum& data) { if (data.is_scalar()) { - return data.scalar()->is_valid ? ~uint64_t(0) : uint64_t(0); + return data.scalar()->is_valid ? kAllValid : kAllNull; } - if (data.array()->null_count == data.array()->length) return uint64_t(0); + if (data.array()->null_count == data.array()->length) return kAllNull; - if (!data.array()->MayHaveNulls()) return ~uint64_t(0); + if (!data.array()->MayHaveNulls()) return kAllValid; // no constant validity word available return {}; @@ -84,27 +87,25 @@ Status PromoteNullsVisitor(KernelContext* ctx, const Datum& cond_d, const Datum& // cond.valid & (cond.data & left.valid | ~cond.data & right.valid) // In the following cases, we dont need to allocate out_valid bitmap - switch (flag) { - case COND_CONST | LEFT_CONST | RIGHT_CONST: - // if cond & left & right all ones, then output is all valid --> out_valid = nullptr - if ((*cond_const & *left_const & *right_const) == UINT64_MAX) { - return Status::OK(); - } - case LEFT_CONST | RIGHT_CONST: - // if both left and right are valid, no need to calculate out_valid bitmap. Pass - // cond validity buffer - if ((*left_const & *right_const) == UINT64_MAX) { - // if there's an offset, copy bitmap (cannot slice a bitmap) - if (cond.offset) { - ARROW_ASSIGN_OR_RAISE( - output->buffers[0], - arrow::internal::CopyBitmap(ctx->memory_pool(), cond.buffers[0]->data(), - cond.offset, cond.length)); - } else { // just copy assign cond validity buffer - output->buffers[0] = cond.buffers[0]; - } - return Status::OK(); - } + + // if cond & left & right all ones, then output is all valid --> out_valid = nullptr + if (cond_const == kAllValid && left_const == kAllValid && right_const == kAllValid) { + return Status::OK(); + } + + if (left_const == kAllValid && right_const == kAllValid) { + // if both left and right are valid, no need to calculate out_valid bitmap. Pass + // cond validity buffer + // if there's an offset, copy bitmap (cannot slice a bitmap) + if (cond.offset) { + ARROW_ASSIGN_OR_RAISE( + output->buffers[0], + arrow::internal::CopyBitmap(ctx->memory_pool(), cond.buffers[0]->data(), + cond.offset, cond.length)); + } else { // just copy assign cond validity buffer + output->buffers[0] = cond.buffers[0]; + } + return Status::OK(); } // following cases requires a separate out_valid buffer From 7e9dff5321d9cce820da13be02680e82624a317b Mon Sep 17 00:00:00 2001 From: niranda perera Date: Tue, 1 Jun 2021 10:26:31 -0400 Subject: [PATCH 30/39] Update cpp/src/arrow/compute/kernels/scalar_if_else.cc Co-authored-by: Benjamin Kietzman --- .../arrow/compute/kernels/scalar_if_else.cc | 33 ++++++++++--------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else.cc b/cpp/src/arrow/compute/kernels/scalar_if_else.cc index 37318128defef..cce861235588b 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else.cc @@ -490,24 +490,25 @@ struct ResolveIfElseExec { *out = MakeNullScalar(batch[1].type()); } return Status::OK(); - } else { // either left or right is an array. output is always an array - // output size is the size of the array arg - auto bcast_size = batch.length; - if (cond.is_valid) { - const auto& valid_data = cond.value ? batch[1] : batch[2]; - if (valid_data.is_array()) { - *out = valid_data; - } else { // valid data is a scalar that needs to be broadcasted - ARROW_ASSIGN_OR_RAISE(*out, - MakeArrayFromScalar(*valid_data.scalar(), bcast_size, - ctx->memory_pool())); - } - } else { // cond is null. create null array - ARROW_ASSIGN_OR_RAISE( - *out, MakeArrayOfNull(batch[1].type(), bcast_size, ctx->memory_pool())) - } + } + // either left or right is an array. Output is always an array + if (!cond.is_valid) { + // cond is null; just create a null array + ARROW_ASSIGN_OR_RAISE( + *out, MakeArrayOfNull(batch[1].type(), bcast_size, ctx->memory_pool())) return Status::OK(); } + + const auto& valid_data = cond.value ? batch[1] : batch[2]; + if (valid_data.is_array()) { + *out = valid_data; + } else { + // valid data is a scalar that needs to be broadcasted + ARROW_ASSIGN_OR_RAISE(*out, + MakeArrayFromScalar(*valid_data.scalar(), batch.length, + ctx->memory_pool())); + } + return Status::OK(); } // cond is array. Use functors to sort things out From ee573a8ac45ff946f72a5807022a596546381595 Mon Sep 17 00:00:00 2001 From: niranda perera Date: Tue, 1 Jun 2021 10:33:58 -0400 Subject: [PATCH 31/39] minor bug fix --- cpp/src/arrow/compute/kernels/scalar_if_else.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else.cc b/cpp/src/arrow/compute/kernels/scalar_if_else.cc index cce861235588b..c559e3cf938c6 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else.cc @@ -495,7 +495,7 @@ struct ResolveIfElseExec { if (!cond.is_valid) { // cond is null; just create a null array ARROW_ASSIGN_OR_RAISE( - *out, MakeArrayOfNull(batch[1].type(), bcast_size, ctx->memory_pool())) + *out, MakeArrayOfNull(batch[1].type(), batch.length, ctx->memory_pool())) return Status::OK(); } From c6d0cdbda1b7791247594eea4bbbc0b0c0b8b9fa Mon Sep 17 00:00:00 2001 From: niranda perera Date: Tue, 1 Jun 2021 10:34:25 -0400 Subject: [PATCH 32/39] minor bug fix --- cpp/src/arrow/compute/kernels/scalar_if_else.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else.cc b/cpp/src/arrow/compute/kernels/scalar_if_else.cc index c559e3cf938c6..98939ab2f5708 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else.cc @@ -504,9 +504,9 @@ struct ResolveIfElseExec { *out = valid_data; } else { // valid data is a scalar that needs to be broadcasted - ARROW_ASSIGN_OR_RAISE(*out, - MakeArrayFromScalar(*valid_data.scalar(), batch.length, - ctx->memory_pool())); + ARROW_ASSIGN_OR_RAISE( + *out, + MakeArrayFromScalar(*valid_data.scalar(), batch.length, ctx->memory_pool())); } return Status::OK(); } From e15b9dfd8738796c115a435ad15417111ea7eea6 Mon Sep 17 00:00:00 2001 From: niranda perera Date: Tue, 1 Jun 2021 13:17:58 -0400 Subject: [PATCH 33/39] adding discussed changes to the tests --- .../compute/kernels/scalar_if_else_test.cc | 252 +++++++++--------- cpp/src/arrow/compute/kernels/test_util.cc | 4 +- cpp/src/arrow/compute/kernels/test_util.h | 15 +- 3 files changed, 142 insertions(+), 129 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc index 326c129592327..10c3d88a19b46 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc @@ -83,141 +83,147 @@ TYPED_TEST(TestIfElsePrimitive, IfElseFixedSizeRand) { CheckIfElseOutput(cond, left, right, expected_data); } -template -struct DatumWrapper { - using CType = typename TypeTraits::CType; - using ArrayType = typename TypeTraits::ArrayType; - using ScalarType = typename TypeTraits::ScalarType; - - util::Variant, std::shared_ptr> datum; - bool is_scalar; - - explicit DatumWrapper(const Datum& datum_) : is_scalar(datum_.is_scalar()) { - if (is_scalar) { - datum = std::move(std::static_pointer_cast(datum_.scalar())); - } else { - datum = std::move(std::static_pointer_cast(datum_.make_array())); - } - } - - bool IsValid(int64_t i) const { - return is_scalar ? util::get>(datum)->is_valid - : util::get>(datum)->IsValid(i); - } - - CType Value(int64_t i) const { - return is_scalar ? util::get>(datum)->value - : util::get>(datum)->Value(i); - } -}; - -template -void GenerateExpected(const Datum& cond, const Datum& left, const Datum& right, - Datum* out) { - int64_t len = cond.is_array() ? cond.length() - : left.is_array() ? left.length() - : right.is_array() ? right.length() : 1; - - DatumWrapper cond_(cond); - DatumWrapper left_(left); - DatumWrapper right_(right); - - int64_t i = 0; - - // if all scalars - if (cond.is_scalar() && left.is_scalar() && right.is_scalar()) { - if (!cond_.IsValid(i) || (cond_.Value(i) && !left_.IsValid(i)) || - (!cond_.Value(i) && !right_.IsValid(i))) { - *out = MakeNullScalar(left.type()); - return; - } - - if (cond_.Value(i)) { - *out = left; - return; - } else { - *out = right; - return; - } - } - - typename TypeTraits::BuilderType builder; - - for (; i < len; ++i) { - if (!cond_.IsValid(i) || (cond_.Value(i) && !left_.IsValid(i)) || - (!cond_.Value(i) && !right_.IsValid(i))) { - ASSERT_OK(builder.AppendNull()); - continue; - } +void CheckWithDifferentShapes(const std::shared_ptr& cond, + const std::shared_ptr& left, + const std::shared_ptr& right, + const std::shared_ptr& expected) { + // this will check for whole arrays, every scalar at i'th index and slicing (offset) + CheckScalar("if_else", {cond, left, right}, expected); + + auto len = left->length(); + + enum { COND_SCALAR = 1, LEFT_SCALAR = 2, RIGHT_SCALAR = 4 }; + for (int mask = 0; mask < (COND_SCALAR | LEFT_SCALAR | RIGHT_SCALAR); ++mask) { + for (int64_t cond_idx = 0; cond_idx < len; ++cond_idx) { + Datum cond_in, cond_bcast; + if (mask & COND_SCALAR) { + ASSERT_OK_AND_ASSIGN(cond_in, cond->GetScalar(cond_idx)); + ASSERT_OK_AND_ASSIGN(cond_bcast, MakeArrayFromScalar(*cond_in.scalar(), len)); + } else { + cond_in = cond_bcast = cond; + } - if (cond_.Value(i)) { - ASSERT_OK(builder.Append(left_.Value(i))); - } else { - ASSERT_OK(builder.Append(right_.Value(i))); + for (int64_t left_idx = 0; left_idx < len; ++left_idx) { + Datum left_in, left_bcast; + if (mask & LEFT_SCALAR) { + ASSERT_OK_AND_ASSIGN(left_in, left->GetScalar(left_idx).As()); + ASSERT_OK_AND_ASSIGN(left_bcast, MakeArrayFromScalar(*left_in.scalar(), len)); + } else { + left_in = left_bcast = left; + } + + for (int64_t right_idx = 0; right_idx < len; ++right_idx) { + Datum right_in, right_bcast; + if (mask & RIGHT_SCALAR) { + ASSERT_OK_AND_ASSIGN(right_in, right->GetScalar(right_idx)); + ASSERT_OK_AND_ASSIGN(right_bcast, + MakeArrayFromScalar(*right_in.scalar(), len)); + } else { + right_in = right_bcast = right; + } + + ASSERT_OK_AND_ASSIGN(auto exp, IfElse(cond_bcast, left_bcast, right_bcast)); + ASSERT_OK_AND_ASSIGN(auto actual, IfElse(cond_in, left_in, right_in)); + AssertDatumsEqual(exp, actual, /*verbose=*/true); + + if (right_in.is_array()) break; + } + if (left_in.is_array()) break; + } + if (cond_in.is_array()) break; } - } - ASSERT_OK_AND_ASSIGN(auto expected_data, builder.Finish()); - - *out = expected_data; + } // for (mask) } -TYPED_TEST(TestIfElsePrimitive, IfElseFixedSizeGen) { +TYPED_TEST(TestIfElsePrimitive, IfElseFixedSize) { auto type = TypeTraits::type_singleton(); - std::vector cond_datums{ArrayFromJSON(boolean(), "[true, true, true, false]"), - ArrayFromJSON(boolean(), "[true, null, true, false]"), - MakeScalar(boolean(), true).ValueOrDie(), - MakeNullScalar(boolean())}; - - std::vector left_datums{ - ArrayFromJSON(type, "[1, 2, 3, 4]"), ArrayFromJSON(type, "[1, 2, null, 4]"), - MakeScalar(type, 100).ValueOrDie(), MakeNullScalar(type)}; - - std::vector right_datums{ - ArrayFromJSON(type, "[5, 6, 7, 8]"), ArrayFromJSON(type, "[5, 6, 7, null]"), - MakeScalar(type, 111).ValueOrDie(), MakeNullScalar(type)}; - - for (auto&& cond : cond_datums) { - for (auto&& left : left_datums) { - for (auto&& right : right_datums) { - Datum exp; - GenerateExpected(cond, left, right, &exp); - CheckIfElseOutput(cond, left, right, exp); - } - } - } + CheckWithDifferentShapes(ArrayFromJSON(boolean(), "[true, true, true, false]"), + ArrayFromJSON(type, "[1, 2, 3, 4]"), + ArrayFromJSON(type, "[5, 6, 7, 8]"), + ArrayFromJSON(type, "[1, 2, 3, 8]")); + + CheckWithDifferentShapes(ArrayFromJSON(boolean(), "[true, true, true, false]"), + ArrayFromJSON(type, "[1, 2, 3, 4]"), + ArrayFromJSON(type, "[5, 6, 7, null]"), + ArrayFromJSON(type, "[1, 2, 3, null]")); + + CheckWithDifferentShapes(ArrayFromJSON(boolean(), "[true, true, true, false]"), + ArrayFromJSON(type, "[1, 2, null, 4]"), + ArrayFromJSON(type, "[5, 6, 7, null]"), + ArrayFromJSON(type, "[1, 2, null, null]")); + + CheckWithDifferentShapes(ArrayFromJSON(boolean(), "[true, true, true, false]"), + ArrayFromJSON(type, "[1, 2, null, 4]"), + ArrayFromJSON(type, "[5, 6, 7, 8]"), + ArrayFromJSON(type, "[1, 2, null, 8]")); + + CheckWithDifferentShapes(ArrayFromJSON(boolean(), "[null, true, true, false]"), + ArrayFromJSON(type, "[1, 2, null, 4]"), + ArrayFromJSON(type, "[5, 6, 7, 8]"), + ArrayFromJSON(type, "[null, 2, null, 8]")); + + CheckWithDifferentShapes(ArrayFromJSON(boolean(), "[null, true, true, false]"), + ArrayFromJSON(type, "[1, 2, null, 4]"), + ArrayFromJSON(type, "[5, 6, 7, null]"), + ArrayFromJSON(type, "[null, 2, null, null]")); + + CheckWithDifferentShapes(ArrayFromJSON(boolean(), "[null, true, true, false]"), + ArrayFromJSON(type, "[1, 2, 3, 4]"), + ArrayFromJSON(type, "[5, 6, 7, null]"), + ArrayFromJSON(type, "[null, 2, 3, null]")); + + CheckWithDifferentShapes(ArrayFromJSON(boolean(), "[null, true, true, false]"), + ArrayFromJSON(type, "[1, 2, 3, 4]"), + ArrayFromJSON(type, "[5, 6, 7, 8]"), + ArrayFromJSON(type, "[null, 2, 3, 8]")); } -TEST_F(TestIfElseKernel, IfElseBooleanGen) { +TEST_F(TestIfElseKernel, IfElseBoolean) { auto type = boolean(); - std::vector cond_datums{ArrayFromJSON(boolean(), "[true, true, true, false]"), - ArrayFromJSON(boolean(), "[true, true, null, false]"), - MakeScalar(boolean(), true).ValueOrDie(), - MakeNullScalar(boolean())}; - - std::vector left_datums{ArrayFromJSON(type, "[false, false, false, false]"), - ArrayFromJSON(type, "[false, false, null, false]"), - MakeScalar(type, false).ValueOrDie(), - MakeNullScalar(type)}; - - std::vector right_datums{ArrayFromJSON(type, "[true, true, true, true]"), - ArrayFromJSON(type, "[true, true, true, null]"), - MakeScalar(type, true).ValueOrDie(), - MakeNullScalar(type)}; - - for (auto&& cond : cond_datums) { - for (auto&& left : left_datums) { - for (auto&& right : right_datums) { - Datum exp; - GenerateExpected(cond, left, right, &exp); - CheckIfElseOutput(cond, left, right, exp); - } - } - } + CheckWithDifferentShapes(ArrayFromJSON(boolean(), "[true, true, true, false]"), + ArrayFromJSON(type, "[false, false, false, false]"), + ArrayFromJSON(type, "[true, true, true, true]"), + ArrayFromJSON(type, "[false, false, false, true]")); + + CheckWithDifferentShapes(ArrayFromJSON(boolean(), "[true, true, true, false]"), + ArrayFromJSON(type, "[false, false, false, false]"), + ArrayFromJSON(type, "[true, true, true, null]"), + ArrayFromJSON(type, "[false, false, false, null]")); + + CheckWithDifferentShapes(ArrayFromJSON(boolean(), "[true, true, true, false]"), + ArrayFromJSON(type, "[false, false, null, false]"), + ArrayFromJSON(type, "[true, true, true, null]"), + ArrayFromJSON(type, "[false, false, null, null]")); + + CheckWithDifferentShapes(ArrayFromJSON(boolean(), "[true, true, true, false]"), + ArrayFromJSON(type, "[false, false, null, false]"), + ArrayFromJSON(type, "[true, true, true, true]"), + ArrayFromJSON(type, "[false, false, null, true]")); + + CheckWithDifferentShapes(ArrayFromJSON(boolean(), "[null, true, true, false]"), + ArrayFromJSON(type, "[false, false, null, false]"), + ArrayFromJSON(type, "[true, true, true, true]"), + ArrayFromJSON(type, "[null, false, null, true]")); + + CheckWithDifferentShapes(ArrayFromJSON(boolean(), "[null, true, true, false]"), + ArrayFromJSON(type, "[false, false, null, false]"), + ArrayFromJSON(type, "[true, true, true, null]"), + ArrayFromJSON(type, "[null, false, null, null]")); + + CheckWithDifferentShapes(ArrayFromJSON(boolean(), "[null, true, true, false]"), + ArrayFromJSON(type, "[false, false, false, false]"), + ArrayFromJSON(type, "[true, true, true, null]"), + ArrayFromJSON(type, "[null, false, false, null]")); + + CheckWithDifferentShapes(ArrayFromJSON(boolean(), "[null, true, true, false]"), + ArrayFromJSON(type, "[false, false, false, false]"), + ArrayFromJSON(type, "[true, true, true, true]"), + ArrayFromJSON(type, "[null, false, false, true]")); } -TYPED_TEST(TestIfElsePrimitive, IfElseBooleanRand) { +TEST_F(TestIfElseKernel, IfElseBooleanRand) { auto type = boolean(); random::RandomArrayGenerator rand(/*seed=*/0); int64_t len = 1000; diff --git a/cpp/src/arrow/compute/kernels/test_util.cc b/cpp/src/arrow/compute/kernels/test_util.cc index 672308452cf13..c74ef3b76ddd2 100644 --- a/cpp/src/arrow/compute/kernels/test_util.cc +++ b/cpp/src/arrow/compute/kernels/test_util.cc @@ -70,6 +70,8 @@ ScalarVector GetScalars(const ArrayVector& inputs, int64_t index) { return scalars; } +} // namespace + void CheckScalar(std::string func_name, const ScalarVector& inputs, std::shared_ptr expected, const FunctionOptions* options) { ASSERT_OK_AND_ASSIGN(Datum out, CallFunction(func_name, GetDatums(inputs), options)); @@ -140,8 +142,6 @@ void CheckScalar(std::string func_name, const ArrayVector& inputs, } } -} // namespace - void CheckScalarUnary(std::string func_name, std::shared_ptr input, std::shared_ptr expected, const FunctionOptions* options) { CheckScalar(std::move(func_name), {input}, expected, options); diff --git a/cpp/src/arrow/compute/kernels/test_util.h b/cpp/src/arrow/compute/kernels/test_util.h index aea3d8360e68e..cadcc4fe35cf9 100644 --- a/cpp/src/arrow/compute/kernels/test_util.h +++ b/cpp/src/arrow/compute/kernels/test_util.h @@ -19,13 +19,14 @@ // IWYU pragma: begin_exports +#include + #include #include #include -#include - #include "arrow/array.h" +#include "arrow/compute/kernel.h" #include "arrow/datum.h" #include "arrow/memory_pool.h" #include "arrow/pretty_print.h" @@ -34,8 +35,6 @@ #include "arrow/testing/util.h" #include "arrow/type.h" -#include "arrow/compute/kernel.h" - // IWYU pragma: end_exports namespace arrow { @@ -90,6 +89,14 @@ struct DatumEqual> { } }; +void CheckScalar(std::string func_name, const ScalarVector& inputs, + std::shared_ptr expected, + const FunctionOptions* options = nullptr); + +void CheckScalar(std::string func_name, const ArrayVector& inputs, + std::shared_ptr expected, + const FunctionOptions* options = nullptr); + void CheckScalarUnary(std::string func_name, std::shared_ptr in_ty, std::string json_input, std::shared_ptr out_ty, std::string json_expected, From 4fa62c382840bea10b321c94da36d68ade02ddd8 Mon Sep 17 00:00:00 2001 From: niranda perera Date: Tue, 1 Jun 2021 14:39:22 -0400 Subject: [PATCH 34/39] removing offset test case --- cpp/src/arrow/compute/kernels/scalar_if_else_test.cc | 8 -------- 1 file changed, 8 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc index 10c3d88a19b46..26d938dc58f16 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc @@ -260,13 +260,5 @@ TEST_F(TestIfElseKernel, IfElseNull) { ArrayFromJSON(null(), "[null, null, null, null]")); } -TEST_F(TestIfElseKernel, IfElseWithOffset) { - auto cond = ArrayFromJSON(boolean(), "[null, true, false]")->Slice(1, 2); - auto left = ArrayFromJSON(int64(), "[10, 11]"); - auto right = ArrayFromJSON(int64(), "[1, 2]"); - auto expected = ArrayFromJSON(int64(), "[10, 2]"); - CheckIfElseOutput(cond, left, right, expected); -} - } // namespace compute } // namespace arrow From 6891211434c2c9b8bdf5e7a9fc4b6e60ee6f012d Mon Sep 17 00:00:00 2001 From: niranda perera Date: Tue, 1 Jun 2021 21:02:30 -0400 Subject: [PATCH 35/39] Update cpp/src/arrow/compute/kernels/scalar_if_else_test.cc Co-authored-by: Benjamin Kietzman --- cpp/src/arrow/compute/kernels/scalar_if_else_test.cc | 3 +++ 1 file changed, 3 insertions(+) diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc index 26d938dc58f16..4f8f7991c48c7 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc @@ -96,12 +96,15 @@ void CheckWithDifferentShapes(const std::shared_ptr& cond, for (int mask = 0; mask < (COND_SCALAR | LEFT_SCALAR | RIGHT_SCALAR); ++mask) { for (int64_t cond_idx = 0; cond_idx < len; ++cond_idx) { Datum cond_in, cond_bcast; + std::string trace_msg = "Cond"; if (mask & COND_SCALAR) { ASSERT_OK_AND_ASSIGN(cond_in, cond->GetScalar(cond_idx)); ASSERT_OK_AND_ASSIGN(cond_bcast, MakeArrayFromScalar(*cond_in.scalar(), len)); + trace_msg += "@" + std::to_string(cond_idx) + "=" + cond_in.scalar()->ToString(); } else { cond_in = cond_bcast = cond; } + SCOPED_TRACE(trace_msg); for (int64_t left_idx = 0; left_idx < len; ++left_idx) { Datum left_in, left_bcast; From b8802bb91258fa9d056a327291806932aa8a6ae4 Mon Sep 17 00:00:00 2001 From: niranda perera Date: Tue, 1 Jun 2021 21:14:22 -0400 Subject: [PATCH 36/39] adding PR review comments --- cpp/src/arrow/compute/kernels/scalar_if_else.cc | 2 +- .../arrow/compute/kernels/scalar_if_else_test.cc | 14 +++++++++++--- cpp/src/arrow/compute/util_internal.h | 5 ----- 3 files changed, 12 insertions(+), 9 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else.cc b/cpp/src/arrow/compute/kernels/scalar_if_else.cc index 98939ab2f5708..9a0e34442e0d6 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else.cc @@ -422,7 +422,7 @@ struct IfElseFunctor> { // out_buf = ones ARROW_ASSIGN_OR_RAISE(out_buf, ctx->AllocateBitmap(cond.length)); // filling with UINT8_MAX upto the buffer's size (in bytes) - arrow::compute::internal::SetMemory(out_buf.get()); + std::memset(out_buf->mutable_data(), UINT8_MAX, out_buf->size()); } else { // out_buf = cond out_buf = SliceBuffer(cond.buffers[1], cond.offset, cond.length); diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc index 4f8f7991c48c7..5d3d22210d238 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc @@ -96,34 +96,42 @@ void CheckWithDifferentShapes(const std::shared_ptr& cond, for (int mask = 0; mask < (COND_SCALAR | LEFT_SCALAR | RIGHT_SCALAR); ++mask) { for (int64_t cond_idx = 0; cond_idx < len; ++cond_idx) { Datum cond_in, cond_bcast; - std::string trace_msg = "Cond"; + std::string trace_cond = "Cond"; if (mask & COND_SCALAR) { ASSERT_OK_AND_ASSIGN(cond_in, cond->GetScalar(cond_idx)); ASSERT_OK_AND_ASSIGN(cond_bcast, MakeArrayFromScalar(*cond_in.scalar(), len)); - trace_msg += "@" + std::to_string(cond_idx) + "=" + cond_in.scalar()->ToString(); + trace_cond += "@" + std::to_string(cond_idx) + "=" + cond_in.scalar()->ToString(); } else { cond_in = cond_bcast = cond; } - SCOPED_TRACE(trace_msg); + SCOPED_TRACE(trace_cond); for (int64_t left_idx = 0; left_idx < len; ++left_idx) { Datum left_in, left_bcast; + std::string trace_left = "Left"; if (mask & LEFT_SCALAR) { ASSERT_OK_AND_ASSIGN(left_in, left->GetScalar(left_idx).As()); ASSERT_OK_AND_ASSIGN(left_bcast, MakeArrayFromScalar(*left_in.scalar(), len)); + trace_cond += + "@" + std::to_string(left_idx) + "=" + left_in.scalar()->ToString(); } else { left_in = left_bcast = left; } + SCOPED_TRACE(trace_left); for (int64_t right_idx = 0; right_idx < len; ++right_idx) { Datum right_in, right_bcast; + std::string trace_right = "Right"; if (mask & RIGHT_SCALAR) { ASSERT_OK_AND_ASSIGN(right_in, right->GetScalar(right_idx)); ASSERT_OK_AND_ASSIGN(right_bcast, MakeArrayFromScalar(*right_in.scalar(), len)); + trace_right += + "@" + std::to_string(right_idx) + "=" + right_in.scalar()->ToString(); } else { right_in = right_bcast = right; } + SCOPED_TRACE(trace_right); ASSERT_OK_AND_ASSIGN(auto exp, IfElse(cond_bcast, left_bcast, right_bcast)); ASSERT_OK_AND_ASSIGN(auto actual, IfElse(cond_in, left_in, right_in)); diff --git a/cpp/src/arrow/compute/util_internal.h b/cpp/src/arrow/compute/util_internal.h index bff4214217614..396c2ca2a0b38 100644 --- a/cpp/src/arrow/compute/util_internal.h +++ b/cpp/src/arrow/compute/util_internal.h @@ -27,11 +27,6 @@ static inline void ZeroMemory(Buffer* buffer) { std::memset(buffer->mutable_data(), 0, buffer->size()); } -template -static inline void SetMemory(Buffer* buffer) { - std::memset(buffer->mutable_data(), ch, buffer->size()); -} - } // namespace internal } // namespace compute } // namespace arrow From cee34c2b4ae2708b5ba8a95733d566e581667d22 Mon Sep 17 00:00:00 2001 From: niranda perera Date: Wed, 2 Jun 2021 08:12:11 -0400 Subject: [PATCH 37/39] Update cpp/src/arrow/compute/api_scalar.h Co-authored-by: Joris Van den Bossche --- cpp/src/arrow/compute/api_scalar.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/src/arrow/compute/api_scalar.h b/cpp/src/arrow/compute/api_scalar.h index d5d1f82fb35a7..0a05b123a442c 100644 --- a/cpp/src/arrow/compute/api_scalar.h +++ b/cpp/src/arrow/compute/api_scalar.h @@ -466,7 +466,7 @@ Result FillNull(const Datum& values, const Datum& fill_value, ExecContext* ctx = NULLPTR); /// \brief IfElse returns elements chosen from `left` or `right` -/// depending on `cond`. `null` values would be promoted to the result +/// depending on `cond`. `null` values in `cond` will be promoted to the result /// /// \param[in] cond `Boolean` condition Scalar/ Array /// \param[in] left Scalar/ Array From 4dbe076710367379d6e5be82271e54fea7dbde53 Mon Sep 17 00:00:00 2001 From: niranda perera Date: Wed, 2 Jun 2021 11:37:10 -0400 Subject: [PATCH 38/39] adding docs --- docs/source/cpp/compute.rst | 36 ++++++++++++++++++------------ docs/source/python/api/compute.rst | 1 + 2 files changed, 23 insertions(+), 14 deletions(-) diff --git a/docs/source/cpp/compute.rst b/docs/source/cpp/compute.rst index 3cf244ca5e83a..4e729b055cf04 100644 --- a/docs/source/cpp/compute.rst +++ b/docs/source/cpp/compute.rst @@ -648,40 +648,48 @@ Structural transforms +==========================+============+================================================+=====================+=========+ | fill_null | Binary | Boolean, Null, Numeric, Temporal, String-like | Input type | \(1) | +--------------------------+------------+------------------------------------------------+---------------------+---------+ -| is_finite | Unary | Float, Double | Boolean | \(2) | +| if_else | Ternary | Boolean, Null, Numeric, Temporal | Input type + \(2) | +--------------------------+------------+------------------------------------------------+---------------------+---------+ -| is_inf | Unary | Float, Double | Boolean | \(3) | +| is_finite | Unary | Float, Double | Boolean | \(3) | +--------------------------+------------+------------------------------------------------+---------------------+---------+ -| is_nan | Unary | Float, Double | Boolean | \(4) | +| is_inf | Unary | Float, Double | Boolean | \(4) | +--------------------------+------------+------------------------------------------------+---------------------+---------+ -| is_null | Unary | Any | Boolean | \(5) | +| is_nan | Unary | Float, Double | Boolean | \(5) | +--------------------------+------------+------------------------------------------------+---------------------+---------+ -| is_valid | Unary | Any | Boolean | \(6) | +| is_null | Unary | Any | Boolean | \(6) | +--------------------------+------------+------------------------------------------------+---------------------+---------+ -| list_value_length | Unary | List-like | Int32 or Int64 | \(7) | +| is_valid | Unary | Any | Boolean | \(7) | +--------------------------+------------+------------------------------------------------+---------------------+---------+ -| project | Varargs | Any | Struct | \(8) | +| list_value_length | Unary | List-like | Int32 or Int64 | \(8) | ++--------------------------+------------+------------------------------------------------+---------------------+---------+ +| project | Varargs | Any | Struct | \(9) | +--------------------------+------------+------------------------------------------------+---------------------+---------+ * \(1) First input must be an array, second input a scalar of the same type. Output is an array of the same type as the inputs, and with the same values as the first input, except for nulls replaced with the second input value. -* \(2) Output is true iff the corresponding input element is finite (not Infinity, +* \(2) First input must be a Boolean scalar or array. Second and third inputs + could be scalars or arrays and must be of the same type. Output is an array + (or scalar if all inputs are scalar) of the same type as the second/ third + input. If the nulls present on the first input, they will be promoted to the + output, otherwise nulls will be chosen based on the first input values. + +* \(3) Output is true iff the corresponding input element is finite (not Infinity, -Infinity, or NaN). -* \(3) Output is true iff the corresponding input element is Infinity/-Infinity. +* \(4) Output is true iff the corresponding input element is Infinity/-Infinity. -* \(4) Output is true iff the corresponding input element is NaN. +* \(5) Output is true iff the corresponding input element is NaN. -* \(5) Output is true iff the corresponding input element is null. +* \(6) Output is true iff the corresponding input element is null. -* \(6) Output is true iff the corresponding input element is non-null. +* \(7) Output is true iff the corresponding input element is non-null. -* \(7) Each output element is the length of the corresponding input element +* \(8) Each output element is the length of the corresponding input element (null if input is null). Output type is Int32 for List, Int64 for LargeList. -* \(8) The output struct's field types are the types of its arguments. The +* \(9) The output struct's field types are the types of its arguments. The field names are specified using an instance of :struct:`ProjectOptions`. The output shape will be scalar if all inputs are scalar, otherwise any scalars will be broadcast to arrays. diff --git a/docs/source/python/api/compute.rst b/docs/source/python/api/compute.rst index 91eeeedbeaabe..3010776930f30 100644 --- a/docs/source/python/api/compute.rst +++ b/docs/source/python/api/compute.rst @@ -222,6 +222,7 @@ Structural Transforms binary_length fill_null + if_else is_finite is_inf is_nan From 52b6a8301ab488e321c441ee26b1f57814364eda Mon Sep 17 00:00:00 2001 From: niranda perera Date: Thu, 3 Jun 2021 11:52:43 -0400 Subject: [PATCH 39/39] fixing gcc 4.8.1 compilation issue --- .../arrow/compute/kernels/scalar_if_else.cc | 84 ++++++++++--------- 1 file changed, 46 insertions(+), 38 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else.cc b/cpp/src/arrow/compute/kernels/scalar_if_else.cc index 9a0e34442e0d6..63086172c9711 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else.cc @@ -115,61 +115,69 @@ Status PromoteNullsVisitor(KernelContext* ctx, const Datum& cond_d, const Datum& enum { C_VALID, C_DATA, L_VALID, R_VALID }; switch (flag) { - case COND_CONST | LEFT_CONST | RIGHT_CONST: - Bitmap::VisitWords({_, cond_data, _, _}, [&](std::array words) { + case COND_CONST | LEFT_CONST | RIGHT_CONST: { + Bitmap bitmaps[] = {_, cond_data, _, _}; + Bitmap::VisitWords(bitmaps, [&](std::array words) { apply(*cond_const, words[C_DATA], *left_const, *right_const); }); break; - case LEFT_CONST | RIGHT_CONST: - Bitmap::VisitWords( - {cond_valid, cond_data, _, _}, [&](std::array words) { - apply(words[C_VALID], words[C_DATA], *left_const, *right_const); - }); + } + case LEFT_CONST | RIGHT_CONST: { + Bitmap bitmaps[] = {cond_valid, cond_data, _, _}; + Bitmap::VisitWords(bitmaps, [&](std::array words) { + apply(words[C_VALID], words[C_DATA], *left_const, *right_const); + }); break; - case COND_CONST | RIGHT_CONST: + } + case COND_CONST | RIGHT_CONST: { // bitmaps[C_VALID], bitmaps[R_VALID] might be null; override to make it safe for // Visit() - Bitmap::VisitWords( - {_, cond_data, left_valid, _}, [&](std::array words) { - apply(*cond_const, words[C_DATA], words[L_VALID], *right_const); - }); + Bitmap bitmaps[] = {_, cond_data, left_valid, _}; + Bitmap::VisitWords(bitmaps, [&](std::array words) { + apply(*cond_const, words[C_DATA], words[L_VALID], *right_const); + }); break; - case RIGHT_CONST: + } + case RIGHT_CONST: { // bitmaps[R_VALID] might be null; override to make it safe for Visit() - Bitmap::VisitWords( - {cond_valid, cond_data, left_valid, _}, [&](std::array words) { - apply(words[C_VALID], words[C_DATA], words[L_VALID], *right_const); - }); + Bitmap bitmaps[] = {cond_valid, cond_data, left_valid, _}; + Bitmap::VisitWords(bitmaps, [&](std::array words) { + apply(words[C_VALID], words[C_DATA], words[L_VALID], *right_const); + }); break; - case COND_CONST | LEFT_CONST: + } + case COND_CONST | LEFT_CONST: { // bitmaps[C_VALID], bitmaps[L_VALID] might be null; override to make it safe for // Visit() - Bitmap::VisitWords({_, cond_data, _, right_valid}, - [&](std::array words) { - apply(*cond_const, words[C_DATA], *left_const, words[R_VALID]); - }); + Bitmap bitmaps[] = {_, cond_data, _, right_valid}; + Bitmap::VisitWords(bitmaps, [&](std::array words) { + apply(*cond_const, words[C_DATA], *left_const, words[R_VALID]); + }); break; - case LEFT_CONST: + } + case LEFT_CONST: { // bitmaps[L_VALID] might be null; override to make it safe for Visit() - Bitmap::VisitWords( - {cond_valid, cond_data, _, right_valid}, [&](std::array words) { - apply(words[C_VALID], words[C_DATA], *left_const, words[R_VALID]); - }); + Bitmap bitmaps[] = {cond_valid, cond_data, _, right_valid}; + Bitmap::VisitWords(bitmaps, [&](std::array words) { + apply(words[C_VALID], words[C_DATA], *left_const, words[R_VALID]); + }); break; - case COND_CONST: + } + case COND_CONST: { // bitmaps[C_VALID] might be null; override to make it safe for Visit() - Bitmap::VisitWords( - {_, cond_data, left_valid, right_valid}, [&](std::array words) { - apply(*cond_const, words[C_DATA], words[L_VALID], words[R_VALID]); - }); + Bitmap bitmaps[] = {_, cond_data, left_valid, right_valid}; + Bitmap::VisitWords(bitmaps, [&](std::array words) { + apply(*cond_const, words[C_DATA], words[L_VALID], words[R_VALID]); + }); break; - case 0: - Bitmap::VisitWords({cond_valid, cond_data, left_valid, right_valid}, - [&](std::array words) { - apply(words[C_VALID], words[C_DATA], words[L_VALID], - words[R_VALID]); - }); + } + case 0: { + Bitmap bitmaps[] = {cond_valid, cond_data, left_valid, right_valid}; + Bitmap::VisitWords(bitmaps, [&](std::array words) { + apply(words[C_VALID], words[C_DATA], words[L_VALID], words[R_VALID]); + }); break; + } } return Status::OK(); }