diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt
index 1d832cc25a21c..f6d5a540c9891 100644
--- a/cpp/src/arrow/CMakeLists.txt
+++ b/cpp/src/arrow/CMakeLists.txt
@@ -399,6 +399,7 @@ if(ARROW_COMPUTE)
       compute/kernels/scalar_string.cc
       compute/kernels/scalar_validity.cc
       compute/kernels/scalar_fill_null.cc
+      compute/kernels/scalar_if_else.cc
       compute/kernels/util_internal.cc
       compute/kernels/vector_hash.cc
       compute/kernels/vector_nested.cc
diff --git a/cpp/src/arrow/compute/api_scalar.cc b/cpp/src/arrow/compute/api_scalar.cc
index 9f4ad42fecbd2..105ba7a0589af 100644
--- a/cpp/src/arrow/compute/api_scalar.cc
+++ b/cpp/src/arrow/compute/api_scalar.cc
@@ -157,5 +157,10 @@ Result<Datum> FillNull(const Datum& values, const Datum& fill_value, ExecContext* ctx) {
   return CallFunction("fill_null", {values, fill_value}, ctx);
 }
 
+Result<Datum> IfElse(const Datum& cond, const Datum& left, const Datum& right,
+                     ExecContext* ctx) {
+  return CallFunction("if_else", {cond, left, right}, ctx);
+}
+
 }  // namespace compute
 }  // namespace arrow
diff --git a/cpp/src/arrow/compute/api_scalar.h b/cpp/src/arrow/compute/api_scalar.h
index dce420b32b24d..0a05b123a442c 100644
--- a/cpp/src/arrow/compute/api_scalar.h
+++ b/cpp/src/arrow/compute/api_scalar.h
@@ -465,5 +465,21 @@ ARROW_EXPORT
 Result<Datum> FillNull(const Datum& values, const Datum& fill_value,
                        ExecContext* ctx = NULLPTR);
 
+/// \brief IfElse returns elements chosen from `left` or `right`
+/// depending on `cond`. `null` values in `cond` will be promoted to the result.
+///
+/// \param[in] cond `Boolean` condition Scalar or Array
+/// \param[in] left Scalar or Array
+/// \param[in] right Scalar or Array
+/// \param[in] ctx the function execution context, optional
+///
+/// \return the resulting datum
+///
+/// \since 5.0.0
+/// \note API not yet finalized
+ARROW_EXPORT
+Result<Datum> IfElse(const Datum& cond, const Datum& left, const Datum& right,
+                     ExecContext* ctx = NULLPTR);
+
 }  // namespace compute
 }  // namespace arrow
diff --git a/cpp/src/arrow/compute/kernels/CMakeLists.txt b/cpp/src/arrow/compute/kernels/CMakeLists.txt
index 5e223a1f906f0..fc11d14410524 100644
--- a/cpp/src/arrow/compute/kernels/CMakeLists.txt
+++ b/cpp/src/arrow/compute/kernels/CMakeLists.txt
@@ -29,6 +29,7 @@ add_arrow_compute_test(scalar_test
                        scalar_string_test.cc
                        scalar_validity_test.cc
                        scalar_fill_null_test.cc
+                       scalar_if_else_test.cc
                        test_util.cc)
 
 add_arrow_benchmark(scalar_arithmetic_benchmark PREFIX "arrow-compute")
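For orientation before the implementation below, a minimal sketch of calling the API declared above. The values and the helper names (`IfElseExample`, `MakeScalar`) are illustrative, not part of this change; scalar `left`/`right` arguments broadcast against an array condition:

#include <iostream>

#include "arrow/compute/api.h"
#include "arrow/testing/gtest_util.h"  // ArrayFromJSON, a test-only helper

arrow::Status IfElseExample() {
  // cond selects from `left` where true and from `right` where false.
  auto cond = arrow::ArrayFromJSON(arrow::boolean(), "[true, false, true]");
  // Scalars broadcast against the array condition.
  ARROW_ASSIGN_OR_RAISE(arrow::Datum out,
                        arrow::compute::IfElse(cond, arrow::MakeScalar(int32_t(1)),
                                               arrow::MakeScalar(int32_t(0))));
  std::cout << out.make_array()->ToString() << std::endl;  // -> [1, 0, 1]
  return arrow::Status::OK();
}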
diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else.cc b/cpp/src/arrow/compute/kernels/scalar_if_else.cc
new file mode 100644
index 0000000000000..63086172c9711
--- /dev/null
+++ b/cpp/src/arrow/compute/kernels/scalar_if_else.cc
@@ -0,0 +1,587 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <arrow/compute/api.h>
+#include <arrow/compute/kernels/codegen_internal.h>
+#include <arrow/compute/util_internal.h>
+#include <arrow/util/bit_block_counter.h>
+#include <arrow/util/bitmap.h>
+#include <arrow/util/bitmap_ops.h>
+
+namespace arrow {
+using internal::BitBlockCount;
+using internal::BitBlockCounter;
+using internal::Bitmap;
+
+namespace compute {
+
+namespace {
+
+constexpr uint64_t kAllNull = 0;
+constexpr uint64_t kAllValid = ~kAllNull;
+
+util::optional<uint64_t> GetConstantValidityWord(const Datum& data) {
+  if (data.is_scalar()) {
+    return data.scalar()->is_valid ? kAllValid : kAllNull;
+  }
+
+  if (data.array()->null_count == data.array()->length) return kAllNull;
+
+  if (!data.array()->MayHaveNulls()) return kAllValid;
+
+  // no constant validity word available
+  return {};
+}
+
+inline Bitmap GetBitmap(const Datum& datum, int i) {
+  if (datum.is_scalar()) return {};
+  const ArrayData& a = *datum.array();
+  return Bitmap{a.buffers[i], a.offset, a.length};
+}
+
+// if the condition is null, then the output is null; otherwise we take validity
+// from the selected argument,
+// i.e. cond.valid & (cond.data & left.valid | ~cond.data & right.valid)
+Status PromoteNullsVisitor(KernelContext* ctx, const Datum& cond_d, const Datum& left_d,
+                           const Datum& right_d, ArrayData* output) {
+  auto cond_const = GetConstantValidityWord(cond_d);
+  auto left_const = GetConstantValidityWord(left_d);
+  auto right_const = GetConstantValidityWord(right_d);
+
+  enum { COND_CONST = 1, LEFT_CONST = 2, RIGHT_CONST = 4 };
+  auto flag = COND_CONST * cond_const.has_value() | LEFT_CONST * left_const.has_value() |
+              RIGHT_CONST * right_const.has_value();
+
+  const ArrayData& cond = *cond_d.array();
+  // cond.data will always be available
+  Bitmap cond_data{cond.buffers[1], cond.offset, cond.length};
+  Bitmap cond_valid{cond.buffers[0], cond.offset, cond.length};
+  Bitmap left_valid = GetBitmap(left_d, 0);
+  Bitmap right_valid = GetBitmap(right_d, 0);
+  // sometimes Bitmaps will be ignored, in which case we replace access to them with
+  // duplicated (probably elided) access to cond_data
+  const Bitmap& _ = cond_data;
+
+  // lambda function that will be used inside the visitor
+  uint64_t* out_validity = nullptr;
+  int64_t i = 0;
+  auto apply = [&](uint64_t c_valid, uint64_t c_data, uint64_t l_valid,
+                   uint64_t r_valid) {
+    out_validity[i] = c_valid & ((c_data & l_valid) | (~c_data & r_valid));
+    i++;
+  };
+
+  // cond.valid & (cond.data & left.valid | ~cond.data & right.valid)
+  // In the following cases, we don't need to allocate an out_valid bitmap.
+
+  // if cond, left and right are all valid, the output is all valid --> out_valid = nullptr
+  if (cond_const == kAllValid && left_const == kAllValid && right_const == kAllValid) {
+    return Status::OK();
+  }
+
+  if (left_const == kAllValid && right_const == kAllValid) {
+    // if both left and right are valid, no need to calculate out_valid bitmap. Pass
+    // the cond validity buffer through.
+    // if there's an offset, copy the bitmap (cannot slice a bitmap)
+    if (cond.offset) {
+      ARROW_ASSIGN_OR_RAISE(
+          output->buffers[0],
+          arrow::internal::CopyBitmap(ctx->memory_pool(), cond.buffers[0]->data(),
+                                      cond.offset, cond.length));
+    } else {  // just copy-assign the cond validity buffer
+      output->buffers[0] = cond.buffers[0];
+    }
+    return Status::OK();
+  }
+
+  // the following cases require a separate out_valid buffer
+  ARROW_ASSIGN_OR_RAISE(output->buffers[0], ctx->AllocateBitmap(cond.length));
+  out_validity = output->GetMutableValues<uint64_t>(0);
+
+  enum { C_VALID, C_DATA, L_VALID, R_VALID };
+
+  switch (flag) {
+    case COND_CONST | LEFT_CONST | RIGHT_CONST: {
+      Bitmap bitmaps[] = {_, cond_data, _, _};
+      Bitmap::VisitWords(bitmaps, [&](std::array<uint64_t, 4> words) {
+        apply(*cond_const, words[C_DATA], *left_const, *right_const);
+      });
+      break;
+    }
+    case LEFT_CONST | RIGHT_CONST: {
+      Bitmap bitmaps[] = {cond_valid, cond_data, _, _};
+      Bitmap::VisitWords(bitmaps, [&](std::array<uint64_t, 4> words) {
+        apply(words[C_VALID], words[C_DATA], *left_const, *right_const);
+      });
+      break;
+    }
+    case COND_CONST | RIGHT_CONST: {
+      // bitmaps[C_VALID], bitmaps[R_VALID] might be null; override to make it safe for
+      // Visit()
+      Bitmap bitmaps[] = {_, cond_data, left_valid, _};
+      Bitmap::VisitWords(bitmaps, [&](std::array<uint64_t, 4> words) {
+        apply(*cond_const, words[C_DATA], words[L_VALID], *right_const);
+      });
+      break;
+    }
+    case RIGHT_CONST: {
+      // bitmaps[R_VALID] might be null; override to make it safe for Visit()
+      Bitmap bitmaps[] = {cond_valid, cond_data, left_valid, _};
+      Bitmap::VisitWords(bitmaps, [&](std::array<uint64_t, 4> words) {
+        apply(words[C_VALID], words[C_DATA], words[L_VALID], *right_const);
+      });
+      break;
+    }
+    case COND_CONST | LEFT_CONST: {
+      // bitmaps[C_VALID], bitmaps[L_VALID] might be null; override to make it safe for
+      // Visit()
+      Bitmap bitmaps[] = {_, cond_data, _, right_valid};
+      Bitmap::VisitWords(bitmaps, [&](std::array<uint64_t, 4> words) {
+        apply(*cond_const, words[C_DATA], *left_const, words[R_VALID]);
+      });
+      break;
+    }
+    case LEFT_CONST: {
+      // bitmaps[L_VALID] might be null; override to make it safe for Visit()
+      Bitmap bitmaps[] = {cond_valid, cond_data, _, right_valid};
+      Bitmap::VisitWords(bitmaps, [&](std::array<uint64_t, 4> words) {
+        apply(words[C_VALID], words[C_DATA], *left_const, words[R_VALID]);
+      });
+      break;
+    }
+    case COND_CONST: {
+      // bitmaps[C_VALID] might be null; override to make it safe for Visit()
+      Bitmap bitmaps[] = {_, cond_data, left_valid, right_valid};
+      Bitmap::VisitWords(bitmaps, [&](std::array<uint64_t, 4> words) {
+        apply(*cond_const, words[C_DATA], words[L_VALID], words[R_VALID]);
+      });
+      break;
+    }
+    case 0: {
+      Bitmap bitmaps[] = {cond_valid, cond_data, left_valid, right_valid};
+      Bitmap::VisitWords(bitmaps, [&](std::array<uint64_t, 4> words) {
+        apply(words[C_VALID], words[C_DATA], words[L_VALID], words[R_VALID]);
+      });
+      break;
+    }
+  }
+  return Status::OK();
+}
+
+template <typename Type, typename Enable = void>
+struct IfElseFunctor {};
+
+// only number types need to be handled for fixed-size primitive data types because
+// internal::GenerateTypeAgnosticPrimitive forwards types to the corresponding unsigned
+// int type
+template <typename Type>
+struct IfElseFunctor<Type, enable_if_number<Type>> {
+  using T = typename TypeTraits<Type>::CType;
+  // A - Array
+  // S - Scalar
+
+  // AAA
+  static Status Call(KernelContext* ctx, const ArrayData& cond, const ArrayData& left,
+                     const ArrayData& right, ArrayData* out) {
+    ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Buffer> out_buf,
+                          ctx->Allocate(cond.length * sizeof(T)));
+    T* out_values = reinterpret_cast<T*>(out_buf->mutable_data());
+
+    // copy right data to out_buf
+    const T* right_data = right.GetValues<T>(1);
+    std::memcpy(out_values, right_data, right.length * sizeof(T));
+
+    const auto* cond_data = cond.buffers[1]->data();  // this is a BoolArray
+    BitBlockCounter bit_counter(cond_data, cond.offset, cond.length);
+
+    // selectively copy values from left data
+    const T* left_data = left.GetValues<T>(1);
+    int64_t offset = cond.offset;
+
+    // TODO: this can be improved with intrinsics, e.g. _mm*_mask_store_e* (vmovdqa*)
+    while (offset < cond.offset + cond.length) {
+      const BitBlockCount& block = bit_counter.NextWord();
+      if (block.AllSet()) {  // all from left
+        std::memcpy(out_values, left_data, block.length * sizeof(T));
+      } else if (block.popcount) {  // selectively copy from left
+        for (int64_t i = 0; i < block.length; ++i) {
+          if (BitUtil::GetBit(cond_data, offset + i)) {
+            out_values[i] = left_data[i];
+          }
+        }
+      }
+
+      offset += block.length;
+      out_values += block.length;
+      left_data += block.length;
+    }
+
+    out->buffers[1] = std::move(out_buf);
+    return Status::OK();
+  }
+
+  // ASA
+  static Status Call(KernelContext* ctx, const ArrayData& cond, const Scalar& left,
+                     const ArrayData& right, ArrayData* out) {
+    ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Buffer> out_buf,
+                          ctx->Allocate(cond.length * sizeof(T)));
+    T* out_values = reinterpret_cast<T*>(out_buf->mutable_data());
+
+    // copy right data to out_buf
+    const T* right_data = right.GetValues<T>(1);
+    std::memcpy(out_values, right_data, right.length * sizeof(T));
+
+    const auto* cond_data = cond.buffers[1]->data();  // this is a BoolArray
+    BitBlockCounter bit_counter(cond_data, cond.offset, cond.length);
+
+    // selectively fill with the left scalar
+    T left_data = internal::UnboxScalar<Type>::Unbox(left);
+    int64_t offset = cond.offset;
+
+    // TODO: this can be improved with intrinsics, e.g. _mm*_mask_store_e* (vmovdqa*)
+    while (offset < cond.offset + cond.length) {
+      const BitBlockCount& block = bit_counter.NextWord();
+      if (block.AllSet()) {  // all from left
+        std::fill(out_values, out_values + block.length, left_data);
+      } else if (block.popcount) {  // selectively copy from left
+        for (int64_t i = 0; i < block.length; ++i) {
+          if (BitUtil::GetBit(cond_data, offset + i)) {
+            out_values[i] = left_data;
+          }
+        }
+      }
+
+      offset += block.length;
+      out_values += block.length;
+    }
+
+    out->buffers[1] = std::move(out_buf);
+    return Status::OK();
+  }
+
+  // AAS
+  static Status Call(KernelContext* ctx, const ArrayData& cond, const ArrayData& left,
+                     const Scalar& right, ArrayData* out) {
+    ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Buffer> out_buf,
+                          ctx->Allocate(cond.length * sizeof(T)));
+    T* out_values = reinterpret_cast<T*>(out_buf->mutable_data());
+
+    // copy left data to out_buf
+    const T* left_data = left.GetValues<T>(1);
+    std::memcpy(out_values, left_data, left.length * sizeof(T));
+
+    const auto* cond_data = cond.buffers[1]->data();  // this is a BoolArray
+    BitBlockCounter bit_counter(cond_data, cond.offset, cond.length);
+
+    // selectively replace with the right scalar
+    T right_data = internal::UnboxScalar<Type>::Unbox(right);
+    int64_t offset = cond.offset;
+
+    // TODO: this can be improved with intrinsics, e.g. _mm*_mask_store_e* (vmovdqa*)
+    // left data is already in the output buffer; therefore, the mask needs to be
+    // inverted
+    while (offset < cond.offset + cond.length) {
+      const BitBlockCount& block = bit_counter.NextWord();
+      if (block.NoneSet()) {  // all from right
+        std::fill(out_values, out_values + block.length, right_data);
+      } else if (block.popcount) {  // selectively copy from right
+        for (int64_t i = 0; i < block.length; ++i) {
+          if (!BitUtil::GetBit(cond_data, offset + i)) {
+            out_values[i] = right_data;
+          }
+        }
+      }
+
+      offset += block.length;
+      out_values += block.length;
+    }
+
+    out->buffers[1] = std::move(out_buf);
+    return Status::OK();
+  }
+
+  // ASS
+  static Status Call(KernelContext* ctx, const ArrayData& cond, const Scalar& left,
+                     const Scalar& right, ArrayData* out) {
+    ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Buffer> out_buf,
+                          ctx->Allocate(cond.length * sizeof(T)));
+    T* out_values = reinterpret_cast<T*>(out_buf->mutable_data());
+
+    // fill out_buf with the right scalar
+    T right_data = internal::UnboxScalar<Type>::Unbox(right);
+    std::fill(out_values, out_values + cond.length, right_data);
+
+    const auto* cond_data = cond.buffers[1]->data();  // this is a BoolArray
+    BitBlockCounter bit_counter(cond_data, cond.offset, cond.length);
+
+    // selectively fill with the left scalar
+    T left_data = internal::UnboxScalar<Type>::Unbox(left);
+    int64_t offset = cond.offset;
+
+    // TODO: this can be improved with intrinsics, e.g. _mm*_mask_store_e* (vmovdqa*)
+    while (offset < cond.offset + cond.length) {
+      const BitBlockCount& block = bit_counter.NextWord();
+      if (block.AllSet()) {  // all from left
+        std::fill(out_values, out_values + block.length, left_data);
+      } else if (block.popcount) {  // selectively copy from left
+        for (int64_t i = 0; i < block.length; ++i) {
+          if (BitUtil::GetBit(cond_data, offset + i)) {
+            out_values[i] = left_data;
+          }
+        }
+      }
+
+      offset += block.length;
+      out_values += block.length;
+    }
+
+    out->buffers[1] = std::move(out_buf);
+    return Status::OK();
+  }
+};
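For boolean values, the specialization that follows reduces to pure bitmap algebra: out = (left & cond) | (right & ~cond). A standalone byte-level sketch of that identity (illustrative math only, not the Arrow BitmapAnd/BitmapAndNot calls used below):

#include <cassert>
#include <cstdint>

int main() {
  // One byte = eight boolean values, LSB-first as in Arrow bitmaps.
  uint8_t cond = 0b10101010;
  uint8_t left = 0b11110000;
  uint8_t right = 0b00001111;
  // Keep left where cond is set, right where it is clear.
  uint8_t out = static_cast<uint8_t>((left & cond) |
                                     (right & static_cast<uint8_t>(~cond)));
  assert(out == 0b10100101);
  return 0;
}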
+
+template <typename Type>
+struct IfElseFunctor<Type, enable_if_boolean<Type>> {
+  // AAA
+  static Status Call(KernelContext* ctx, const ArrayData& cond, const ArrayData& left,
+                     const ArrayData& right, ArrayData* out) {
+    // out_buf = right & ~cond
+    ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Buffer> out_buf,
+                          arrow::internal::BitmapAndNot(
+                              ctx->memory_pool(), right.buffers[1]->data(), right.offset,
+                              cond.buffers[1]->data(), cond.offset, cond.length, 0));
+
+    // temp_buf = left & cond
+    ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Buffer> temp_buf,
+                          arrow::internal::BitmapAnd(
+                              ctx->memory_pool(), left.buffers[1]->data(), left.offset,
+                              cond.buffers[1]->data(), cond.offset, cond.length, 0));
+
+    arrow::internal::BitmapOr(out_buf->data(), 0, temp_buf->data(), 0, cond.length, 0,
+                              out_buf->mutable_data());
+    out->buffers[1] = std::move(out_buf);
+    return Status::OK();
+  }
+
+  // ASA
+  static Status Call(KernelContext* ctx, const ArrayData& cond, const Scalar& left,
+                     const ArrayData& right, ArrayData* out) {
+    // out_buf = right & ~cond
+    ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Buffer> out_buf,
+                          arrow::internal::BitmapAndNot(
+                              ctx->memory_pool(), right.buffers[1]->data(), right.offset,
+                              cond.buffers[1]->data(), cond.offset, cond.length, 0));
+
+    // out_buf |= left & cond (if left is true, that term is just cond)
+    bool left_data = internal::UnboxScalar<BooleanType>::Unbox(left);
+    if (left_data) {
+      arrow::internal::BitmapOr(out_buf->data(), 0, cond.buffers[1]->data(), cond.offset,
+                                cond.length, 0, out_buf->mutable_data());
+    }
+
+    out->buffers[1] = std::move(out_buf);
+    return Status::OK();
+  }
+
+  // AAS
+  static Status Call(KernelContext* ctx, const ArrayData& cond, const ArrayData& left,
+                     const Scalar& right, ArrayData* out) {
+    // out_buf = left & cond
+    ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Buffer> out_buf,
+                          arrow::internal::BitmapAnd(
+                              ctx->memory_pool(), left.buffers[1]->data(), left.offset,
+                              cond.buffers[1]->data(), cond.offset, cond.length, 0));
+
+    bool right_data = internal::UnboxScalar<BooleanType>::Unbox(right);
+
+    // out_buf = left & cond | right & ~cond
+    if (right_data) {
+      arrow::internal::BitmapOrNot(out_buf->data(), 0, cond.buffers[1]->data(),
+                                   cond.offset, cond.length, 0, out_buf->mutable_data());
+    }
+
+    out->buffers[1] = std::move(out_buf);
+    return Status::OK();
+  }
+
+  // ASS
+  static Status Call(KernelContext* ctx, const ArrayData& cond, const Scalar& left,
+                     const Scalar& right, ArrayData* out) {
+    bool left_data = internal::UnboxScalar<BooleanType>::Unbox(left);
+    bool right_data = internal::UnboxScalar<BooleanType>::Unbox(right);
+
+    // out_buf = left & cond | right & ~cond
+    std::shared_ptr<Buffer> out_buf = nullptr;
+    if (left_data) {
+      if (right_data) {
+        // out_buf = ones
+        ARROW_ASSIGN_OR_RAISE(out_buf, ctx->AllocateBitmap(cond.length));
+        // filling with UINT8_MAX up to the buffer's size (in bytes)
+        std::memset(out_buf->mutable_data(), UINT8_MAX, out_buf->size());
+      } else {
+        // out_buf = cond
+        out_buf = SliceBuffer(cond.buffers[1], cond.offset, cond.length);
+      }
+    } else {
+      if (right_data) {
+        // out_buf = ~cond
+        ARROW_ASSIGN_OR_RAISE(out_buf, arrow::internal::InvertBitmap(
+                                           ctx->memory_pool(), cond.buffers[1]->data(),
+                                           cond.offset, cond.length))
+      } else {
+        // out_buf = zeros
+        ARROW_ASSIGN_OR_RAISE(out_buf, ctx->AllocateBitmap(cond.length));
+      }
+    }
+    out->buffers[1] = std::move(out_buf);
+    return Status::OK();
+  }
+};
+
+template <typename Type>
+struct IfElseFunctor<Type, enable_if_null<Type>> {
+  template <typename T>
+  static inline Status ReturnCopy(const T& in, T* out) {
+    // Nothing preallocated, so we assign in into the output
+    *out = in;
+    return Status::OK();
+  }
+
+  // AAA
+  static Status Call(KernelContext* ctx, const ArrayData& cond, const ArrayData& left,
+                     const ArrayData& right, ArrayData* out) {
+    return ReturnCopy(left, out);
+  }
+
+  // ASA
+  static Status Call(KernelContext* ctx, const ArrayData& cond, const Scalar& left,
+                     const ArrayData& right, ArrayData* out) {
+    return ReturnCopy(right, out);
+  }
+
+  // AAS
+  static Status Call(KernelContext* ctx, const ArrayData& cond, const ArrayData& left,
+                     const Scalar& right, ArrayData* out) {
+    return ReturnCopy(left, out);
+  }
+
+  // ASS
+  static Status Call(KernelContext* ctx, const ArrayData& cond, const Scalar& left,
+                     const Scalar& right, ArrayData* out) {
+    return ReturnCopy(cond, out);
+  }
+};
+
+template <typename Type>
+struct ResolveIfElseExec {
+  static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
+    // cond is scalar
+    if (batch[0].is_scalar()) {
+      const auto& cond = batch[0].scalar_as<BooleanScalar>();
+      if (batch[1].is_scalar() && batch[2].is_scalar()) {
+        if (cond.is_valid) {
+          *out = cond.value ? batch[1].scalar() : batch[2].scalar();
+        } else {
+          *out = MakeNullScalar(batch[1].type());
+        }
+        return Status::OK();
+      }
+      // either left or right is an array. Output is always an array
+      if (!cond.is_valid) {
+        // cond is null; just create a null array
+        ARROW_ASSIGN_OR_RAISE(
+            *out, MakeArrayOfNull(batch[1].type(), batch.length, ctx->memory_pool()))
+        return Status::OK();
+      }
+
+      const auto& valid_data = cond.value ? batch[1] : batch[2];
+      if (valid_data.is_array()) {
+        *out = valid_data;
+      } else {
+        // valid data is a scalar that needs to be broadcast
+        ARROW_ASSIGN_OR_RAISE(
+            *out,
+            MakeArrayFromScalar(*valid_data.scalar(), batch.length, ctx->memory_pool()));
+      }
+      return Status::OK();
+    }
+
+    // cond is array. Use functors to sort things out
+    ARROW_RETURN_NOT_OK(
+        PromoteNullsVisitor(ctx, batch[0], batch[1], batch[2], out->mutable_array()));
+
+    if (batch[1].kind() == Datum::ARRAY) {
+      if (batch[2].kind() == Datum::ARRAY) {  // AAA
+        return IfElseFunctor<Type>::Call(ctx, *batch[0].array(), *batch[1].array(),
+                                         *batch[2].array(), out->mutable_array());
+      } else {  // AAS
+        return IfElseFunctor<Type>::Call(ctx, *batch[0].array(), *batch[1].array(),
+                                         *batch[2].scalar(), out->mutable_array());
+      }
+    } else {
+      if (batch[2].kind() == Datum::ARRAY) {  // ASA
+        return IfElseFunctor<Type>::Call(ctx, *batch[0].array(), *batch[1].scalar(),
+                                         *batch[2].array(), out->mutable_array());
+      } else {  // ASS
+        return IfElseFunctor<Type>::Call(ctx, *batch[0].array(), *batch[1].scalar(),
+                                         *batch[2].scalar(), out->mutable_array());
+      }
+    }
+  }
+};
+
+void AddPrimitiveIfElseKernels(const std::shared_ptr<ScalarFunction>& scalar_function,
+                               const std::vector<std::shared_ptr<DataType>>& types) {
+  for (auto&& type : types) {
+    auto exec = internal::GenerateTypeAgnosticPrimitive<ResolveIfElseExec>(*type);
+    // cond array needs to be boolean always
+    ScalarKernel kernel({boolean(), type, type}, type, exec);
+    kernel.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE;
+    kernel.mem_allocation = MemAllocation::NO_PREALLOCATE;
+
+    DCHECK_OK(scalar_function->AddKernel(std::move(kernel)));
+  }
+}
+
+}  // namespace
+
+const FunctionDoc if_else_doc{"Choose values based on a condition",
+                              ("`cond` must be a Boolean scalar or array.\n`left` and "
+                               "`right` must be scalars or arrays of the same type.\n"
+                               "`null` values in `cond` will be promoted to the"
+                               " output."),
+                              {"cond", "left", "right"}};
+
+namespace internal {
+
+void RegisterScalarIfElse(FunctionRegistry* registry) {
+  ScalarKernel scalar_kernel;
+  scalar_kernel.null_handling = NullHandling::COMPUTED_NO_PREALLOCATE;
+  scalar_kernel.mem_allocation = MemAllocation::NO_PREALLOCATE;
+
+  auto func = std::make_shared<ScalarFunction>("if_else", Arity::Ternary(), &if_else_doc);
+
+  AddPrimitiveIfElseKernels(func, NumericTypes());
+  AddPrimitiveIfElseKernels(func, TemporalTypes());
+  AddPrimitiveIfElseKernels(func, {boolean(), null()});
+  // todo add binary kernels
+
+  DCHECK_OK(registry->AddFunction(std::move(func)));
+}
+
+}  // namespace internal
+}  // namespace compute
+}  // namespace arrow
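The scalar-condition fast path in ResolveIfElseExec above never reaches the functors: a valid scalar `cond` passes one argument straight through (or yields an all-null output when `cond` itself is null). A sketch of exercising that path through the public registry (the function name `ScalarCondExample` and the values are illustrative):

#include <iostream>

#include "arrow/compute/api.h"
#include "arrow/testing/gtest_util.h"  // ArrayFromJSON, a test-only helper

arrow::Status ScalarCondExample() {
  auto left = arrow::ArrayFromJSON(arrow::int32(), "[1, 2]");
  auto right = arrow::ArrayFromJSON(arrow::int32(), "[3, 4]");
  arrow::Datum cond(std::make_shared<arrow::BooleanScalar>(true));
  // Dispatches through the registered "if_else" function; with a valid scalar
  // condition, the selected array is passed through without copying.
  ARROW_ASSIGN_OR_RAISE(arrow::Datum out,
                        arrow::compute::CallFunction("if_else", {cond, left, right}));
  std::cout << out.make_array()->ToString() << std::endl;  // -> [1, 2]
  return arrow::Status::OK();
}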
diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc
new file mode 100644
index 0000000000000..5d3d22210d238
--- /dev/null
+++ b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc
@@ -0,0 +1,275 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include <gtest/gtest.h>
+#include <arrow/compute/api_scalar.h>
+#include <arrow/compute/kernels/test_util.h>
+#include <arrow/testing/gtest_util.h>
+#include <arrow/testing/random.h>
+
+namespace arrow {
+namespace compute {
+
+void CheckIfElseOutput(const Datum& cond, const Datum& left, const Datum& right,
+                       const Datum& expected) {
+  ASSERT_OK_AND_ASSIGN(Datum datum_out, IfElse(cond, left, right));
+  if (datum_out.is_array()) {
+    std::shared_ptr<Array> result = datum_out.make_array();
+    ASSERT_OK(result->ValidateFull());
+    std::shared_ptr<Array> expected_ = expected.make_array();
+    AssertArraysEqual(*expected_, *result, /*verbose=*/true);
+  } else {  // expecting scalar
+    const std::shared_ptr<Scalar>& result = datum_out.scalar();
+    const std::shared_ptr<Scalar>& expected_ = expected.scalar();
+    AssertScalarsEqual(*expected_, *result, /*verbose=*/true);
+  }
+}
+
+class TestIfElseKernel : public ::testing::Test {};
+
+template <typename Type>
+class TestIfElsePrimitive : public ::testing::Test {};
+
+using PrimitiveTypes = ::testing::Types<Int8Type, UInt8Type, Int16Type, UInt16Type,
+                                        Int32Type, UInt32Type, Int64Type, UInt64Type,
+                                        FloatType, DoubleType, Date32Type, Date64Type>;
+
+TYPED_TEST_SUITE(TestIfElsePrimitive, PrimitiveTypes);
+
+TYPED_TEST(TestIfElsePrimitive, IfElseFixedSizeRand) {
+  using ArrayType = typename TypeTraits<TypeParam>::ArrayType;
+  auto type = TypeTraits<TypeParam>::type_singleton();
+
+  random::RandomArrayGenerator rand(/*seed=*/0);
+  int64_t len = 1000;
+  auto cond = std::static_pointer_cast<BooleanArray>(
+      rand.ArrayOf(boolean(), len, /*null_probability=*/0.01));
+  auto left = std::static_pointer_cast<ArrayType>(
+      rand.ArrayOf(type, len, /*null_probability=*/0.01));
+  auto right = std::static_pointer_cast<ArrayType>(
+      rand.ArrayOf(type, len, /*null_probability=*/0.01));
+
+  typename TypeTraits<TypeParam>::BuilderType builder;
+
+  for (int64_t i = 0; i < len; ++i) {
+    if (!cond->IsValid(i) || (cond->Value(i) && !left->IsValid(i)) ||
+        (!cond->Value(i) && !right->IsValid(i))) {
+      ASSERT_OK(builder.AppendNull());
+      continue;
+    }
+
+    if (cond->Value(i)) {
+      ASSERT_OK(builder.Append(left->Value(i)));
+    } else {
+      ASSERT_OK(builder.Append(right->Value(i)));
+    }
+  }
+  ASSERT_OK_AND_ASSIGN(auto expected_data, builder.Finish());
+
+  CheckIfElseOutput(cond, left, right, expected_data);
+}
+
+void CheckWithDifferentShapes(const std::shared_ptr<Array>& cond,
+                              const std::shared_ptr<Array>& left,
+                              const std::shared_ptr<Array>& right,
+                              const std::shared_ptr<Array>& expected) {
+  // this checks whole arrays, each scalar at the i'th index, and slicing (offsets)
+  CheckScalar("if_else", {cond, left, right}, expected);
+
+  auto len = left->length();
+
+  enum { COND_SCALAR = 1, LEFT_SCALAR = 2, RIGHT_SCALAR = 4 };
+  for (int mask = 0; mask < (COND_SCALAR | LEFT_SCALAR | RIGHT_SCALAR); ++mask) {
+    for (int64_t cond_idx = 0; cond_idx < len; ++cond_idx) {
+      Datum cond_in, cond_bcast;
+      std::string trace_cond = "Cond";
+      if (mask & COND_SCALAR) {
+        ASSERT_OK_AND_ASSIGN(cond_in, cond->GetScalar(cond_idx));
+        ASSERT_OK_AND_ASSIGN(cond_bcast, MakeArrayFromScalar(*cond_in.scalar(), len));
+        trace_cond += "@" + std::to_string(cond_idx) + "=" + cond_in.scalar()->ToString();
+      } else {
+        cond_in = cond_bcast = cond;
+      }
+      SCOPED_TRACE(trace_cond);
+
+      for (int64_t left_idx = 0; left_idx < len; ++left_idx) {
+        Datum left_in, left_bcast;
+        std::string trace_left = "Left";
+        if (mask & LEFT_SCALAR) {
+          ASSERT_OK_AND_ASSIGN(left_in, left->GetScalar(left_idx));
+          ASSERT_OK_AND_ASSIGN(left_bcast, MakeArrayFromScalar(*left_in.scalar(), len));
+          trace_left +=
+              "@" + std::to_string(left_idx) + "=" + left_in.scalar()->ToString();
+        } else {
+          left_in = left_bcast = left;
+        }
+        SCOPED_TRACE(trace_left);
+
+        for (int64_t right_idx = 0; right_idx < len; ++right_idx) {
+          Datum right_in, right_bcast;
+          std::string trace_right = "Right";
+          if (mask & RIGHT_SCALAR) {
+            ASSERT_OK_AND_ASSIGN(right_in, right->GetScalar(right_idx));
+            ASSERT_OK_AND_ASSIGN(right_bcast,
+                                 MakeArrayFromScalar(*right_in.scalar(), len));
+            trace_right +=
+                "@" + std::to_string(right_idx) + "=" + right_in.scalar()->ToString();
+          } else {
+            right_in = right_bcast = right;
+          }
+          SCOPED_TRACE(trace_right);
+
+          ASSERT_OK_AND_ASSIGN(auto exp, IfElse(cond_bcast, left_bcast, right_bcast));
+          ASSERT_OK_AND_ASSIGN(auto actual, IfElse(cond_in, left_in, right_in));
+          AssertDatumsEqual(exp, actual, /*verbose=*/true);
+
+          if (right_in.is_array()) break;
+        }
+        if (left_in.is_array()) break;
+      }
+      if (cond_in.is_array()) break;
+    }
+  }  // for (mask)
+}
+
+TYPED_TEST(TestIfElsePrimitive, IfElseFixedSize) {
+  auto type = TypeTraits<TypeParam>::type_singleton();
+
+  CheckWithDifferentShapes(ArrayFromJSON(boolean(), "[true, true, true, false]"),
+                           ArrayFromJSON(type, "[1, 2, 3, 4]"),
+                           ArrayFromJSON(type, "[5, 6, 7, 8]"),
+                           ArrayFromJSON(type, "[1, 2, 3, 8]"));
+
+  CheckWithDifferentShapes(ArrayFromJSON(boolean(), "[true, true, true, false]"),
+                           ArrayFromJSON(type, "[1, 2, 3, 4]"),
+                           ArrayFromJSON(type, "[5, 6, 7, null]"),
+                           ArrayFromJSON(type, "[1, 2, 3, null]"));
+
+  CheckWithDifferentShapes(ArrayFromJSON(boolean(), "[true, true, true, false]"),
+                           ArrayFromJSON(type, "[1, 2, null, 4]"),
+                           ArrayFromJSON(type, "[5, 6, 7, null]"),
+                           ArrayFromJSON(type, "[1, 2, null, null]"));
+
+  CheckWithDifferentShapes(ArrayFromJSON(boolean(), "[true, true, true, false]"),
+                           ArrayFromJSON(type, "[1, 2, null, 4]"),
+                           ArrayFromJSON(type, "[5, 6, 7, 8]"),
+                           ArrayFromJSON(type, "[1, 2, null, 8]"));
+
+  CheckWithDifferentShapes(ArrayFromJSON(boolean(), "[null, true, true, false]"),
+                           ArrayFromJSON(type, "[1, 2, null, 4]"),
+                           ArrayFromJSON(type, "[5, 6, 7, 8]"),
+                           ArrayFromJSON(type, "[null, 2, null, 8]"));
+
+  CheckWithDifferentShapes(ArrayFromJSON(boolean(), "[null, true, true, false]"),
+                           ArrayFromJSON(type, "[1, 2, null, 4]"),
+                           ArrayFromJSON(type, "[5, 6, 7, null]"),
+                           ArrayFromJSON(type, "[null, 2, null, null]"));
+
+  CheckWithDifferentShapes(ArrayFromJSON(boolean(), "[null, true, true, false]"),
+                           ArrayFromJSON(type, "[1, 2, 3, 4]"),
+                           ArrayFromJSON(type, "[5, 6, 7, null]"),
+                           ArrayFromJSON(type, "[null, 2, 3, null]"));
+
+  CheckWithDifferentShapes(ArrayFromJSON(boolean(), "[null, true, true, false]"),
+                           ArrayFromJSON(type, "[1, 2, 3, 4]"),
+                           ArrayFromJSON(type, "[5, 6, 7, 8]"),
+                           ArrayFromJSON(type, "[null, 2, 3, 8]"));
+}
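As a cross-check of one case above (cond = [null, true, true, false], left = [1, 2, null, 4], right = [5, 6, 7, 8] -> [null, 2, null, 8]), the expected validity falls straight out of the kernel's formula. A standalone sketch with LSB-first bits mirroring Arrow's bitmap layout (illustrative only, not test code from this change):

#include <cassert>
#include <cstdint>

int main() {
  // Bit i corresponds to row i (LSB-first).
  uint8_t cond_valid = 0b1110;   // row 0 is null
  uint8_t cond_data = 0b0110;    // rows 1-2 true, rows 0/3 false
  uint8_t left_valid = 0b1011;   // row 2 is null
  uint8_t right_valid = 0b1111;  // no nulls
  uint8_t out_valid = static_cast<uint8_t>(
      cond_valid & ((cond_data & left_valid) | (~cond_data & right_valid)));
  assert(out_valid == 0b1010);  // only rows 1 and 3 valid -> [null, 2, null, 8]
  return 0;
}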
true]")); + + CheckWithDifferentShapes(ArrayFromJSON(boolean(), "[null, true, true, false]"), + ArrayFromJSON(type, "[false, false, null, false]"), + ArrayFromJSON(type, "[true, true, true, true]"), + ArrayFromJSON(type, "[null, false, null, true]")); + + CheckWithDifferentShapes(ArrayFromJSON(boolean(), "[null, true, true, false]"), + ArrayFromJSON(type, "[false, false, null, false]"), + ArrayFromJSON(type, "[true, true, true, null]"), + ArrayFromJSON(type, "[null, false, null, null]")); + + CheckWithDifferentShapes(ArrayFromJSON(boolean(), "[null, true, true, false]"), + ArrayFromJSON(type, "[false, false, false, false]"), + ArrayFromJSON(type, "[true, true, true, null]"), + ArrayFromJSON(type, "[null, false, false, null]")); + + CheckWithDifferentShapes(ArrayFromJSON(boolean(), "[null, true, true, false]"), + ArrayFromJSON(type, "[false, false, false, false]"), + ArrayFromJSON(type, "[true, true, true, true]"), + ArrayFromJSON(type, "[null, false, false, true]")); +} + +TEST_F(TestIfElseKernel, IfElseBooleanRand) { + auto type = boolean(); + random::RandomArrayGenerator rand(/*seed=*/0); + int64_t len = 1000; + auto cond = std::static_pointer_cast( + rand.ArrayOf(boolean(), len, /*null_probability=*/0.01)); + auto left = std::static_pointer_cast( + rand.ArrayOf(type, len, /*null_probability=*/0.01)); + auto right = std::static_pointer_cast( + rand.ArrayOf(type, len, /*null_probability=*/0.01)); + + BooleanBuilder builder; + for (int64_t i = 0; i < len; ++i) { + if (!cond->IsValid(i) || (cond->Value(i) && !left->IsValid(i)) || + (!cond->Value(i) && !right->IsValid(i))) { + ASSERT_OK(builder.AppendNull()); + continue; + } + + if (cond->Value(i)) { + ASSERT_OK(builder.Append(left->Value(i))); + } else { + ASSERT_OK(builder.Append(right->Value(i))); + } + } + ASSERT_OK_AND_ASSIGN(auto expected_data, builder.Finish()); + + CheckIfElseOutput(cond, left, right, expected_data); +} + +TEST_F(TestIfElseKernel, IfElseNull) { + CheckIfElseOutput(ArrayFromJSON(boolean(), "[null, null, null, null]"), + ArrayFromJSON(null(), "[null, null, null, null]"), + ArrayFromJSON(null(), "[null, null, null, null]"), + ArrayFromJSON(null(), "[null, null, null, null]")); +} + +} // namespace compute +} // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/test_util.cc b/cpp/src/arrow/compute/kernels/test_util.cc index 672308452cf13..c74ef3b76ddd2 100644 --- a/cpp/src/arrow/compute/kernels/test_util.cc +++ b/cpp/src/arrow/compute/kernels/test_util.cc @@ -70,6 +70,8 @@ ScalarVector GetScalars(const ArrayVector& inputs, int64_t index) { return scalars; } +} // namespace + void CheckScalar(std::string func_name, const ScalarVector& inputs, std::shared_ptr expected, const FunctionOptions* options) { ASSERT_OK_AND_ASSIGN(Datum out, CallFunction(func_name, GetDatums(inputs), options)); @@ -140,8 +142,6 @@ void CheckScalar(std::string func_name, const ArrayVector& inputs, } } -} // namespace - void CheckScalarUnary(std::string func_name, std::shared_ptr input, std::shared_ptr expected, const FunctionOptions* options) { CheckScalar(std::move(func_name), {input}, expected, options); diff --git a/cpp/src/arrow/compute/kernels/test_util.h b/cpp/src/arrow/compute/kernels/test_util.h index aea3d8360e68e..cadcc4fe35cf9 100644 --- a/cpp/src/arrow/compute/kernels/test_util.h +++ b/cpp/src/arrow/compute/kernels/test_util.h @@ -19,13 +19,14 @@ // IWYU pragma: begin_exports +#include + #include #include #include -#include - #include "arrow/array.h" +#include "arrow/compute/kernel.h" #include "arrow/datum.h" #include 
"arrow/memory_pool.h" #include "arrow/pretty_print.h" @@ -34,8 +35,6 @@ #include "arrow/testing/util.h" #include "arrow/type.h" -#include "arrow/compute/kernel.h" - // IWYU pragma: end_exports namespace arrow { @@ -90,6 +89,14 @@ struct DatumEqual> { } }; +void CheckScalar(std::string func_name, const ScalarVector& inputs, + std::shared_ptr expected, + const FunctionOptions* options = nullptr); + +void CheckScalar(std::string func_name, const ArrayVector& inputs, + std::shared_ptr expected, + const FunctionOptions* options = nullptr); + void CheckScalarUnary(std::string func_name, std::shared_ptr in_ty, std::string json_input, std::shared_ptr out_ty, std::string json_expected, diff --git a/cpp/src/arrow/compute/registry.cc b/cpp/src/arrow/compute/registry.cc index 3a8a3a0eb8530..1d713b96e1ead 100644 --- a/cpp/src/arrow/compute/registry.cc +++ b/cpp/src/arrow/compute/registry.cc @@ -125,6 +125,7 @@ static std::unique_ptr CreateBuiltInRegistry() { RegisterScalarStringAscii(registry.get()); RegisterScalarValidity(registry.get()); RegisterScalarFillNull(registry.get()); + RegisterScalarIfElse(registry.get()); // Vector functions RegisterVectorHash(registry.get()); diff --git a/cpp/src/arrow/compute/registry_internal.h b/cpp/src/arrow/compute/registry_internal.h index e4008cf3f270f..f97553af4b104 100644 --- a/cpp/src/arrow/compute/registry_internal.h +++ b/cpp/src/arrow/compute/registry_internal.h @@ -34,6 +34,7 @@ void RegisterScalarSetLookup(FunctionRegistry* registry); void RegisterScalarStringAscii(FunctionRegistry* registry); void RegisterScalarValidity(FunctionRegistry* registry); void RegisterScalarFillNull(FunctionRegistry* registry); +void RegisterScalarIfElse(FunctionRegistry* registry); // Vector functions void RegisterVectorHash(FunctionRegistry* registry); diff --git a/cpp/src/arrow/util/bitmap_ops.cc b/cpp/src/arrow/util/bitmap_ops.cc index 32da60aafd9f6..a27a61cadf38a 100644 --- a/cpp/src/arrow/util/bitmap_ops.cc +++ b/cpp/src/arrow/util/bitmap_ops.cc @@ -583,5 +583,23 @@ void BitmapAndNot(const uint8_t* left, int64_t left_offset, const uint8_t* right BitmapOp(left, left_offset, right, right_offset, length, out_offset, out); } +template +struct OrNotOp { + constexpr T operator()(const T& l, const T& r) const { return l | ~r; } +}; + +Result> BitmapOrNot(MemoryPool* pool, const uint8_t* left, + int64_t left_offset, const uint8_t* right, + int64_t right_offset, int64_t length, + int64_t out_offset) { + return BitmapOp(pool, left, left_offset, right, right_offset, length, + out_offset); +} + +void BitmapOrNot(const uint8_t* left, int64_t left_offset, const uint8_t* right, + int64_t right_offset, int64_t length, int64_t out_offset, uint8_t* out) { + BitmapOp(left, left_offset, right, right_offset, length, out_offset, out); +} + } // namespace internal } // namespace arrow diff --git a/cpp/src/arrow/util/bitmap_ops.h b/cpp/src/arrow/util/bitmap_ops.h index 554e1d7468b98..40a7797a2398a 100644 --- a/cpp/src/arrow/util/bitmap_ops.h +++ b/cpp/src/arrow/util/bitmap_ops.h @@ -183,5 +183,24 @@ ARROW_EXPORT void BitmapAndNot(const uint8_t* left, int64_t left_offset, const uint8_t* right, int64_t right_offset, int64_t length, int64_t out_offset, uint8_t* out); +/// \brief Do a "bitmap or not" on right and left buffers starting at +/// their respective bit-offsets for the given bit-length and put +/// the results in out_buffer starting at the given bit-offset. +/// +/// out_buffer will be allocated and initialized to zeros using pool before +/// the operation. 
diff --git a/cpp/src/arrow/util/bitmap_ops.h b/cpp/src/arrow/util/bitmap_ops.h
index 554e1d7468b98..40a7797a2398a 100644
--- a/cpp/src/arrow/util/bitmap_ops.h
+++ b/cpp/src/arrow/util/bitmap_ops.h
@@ -183,5 +183,24 @@ ARROW_EXPORT
 void BitmapAndNot(const uint8_t* left, int64_t left_offset, const uint8_t* right,
                   int64_t right_offset, int64_t length, int64_t out_offset,
                   uint8_t* out);
 
+/// \brief Do a "bitmap or not" on right and left buffers starting at
+/// their respective bit-offsets for the given bit-length and put
+/// the results in out_buffer starting at the given bit-offset.
+///
+/// out_buffer will be allocated and initialized to zeros using pool before
+/// the operation.
+ARROW_EXPORT
+Result<std::shared_ptr<Buffer>> BitmapOrNot(MemoryPool* pool, const uint8_t* left,
+                                            int64_t left_offset, const uint8_t* right,
+                                            int64_t right_offset, int64_t length,
+                                            int64_t out_offset);
+
+/// \brief Do a "bitmap or not" on right and left buffers starting at
+/// their respective bit-offsets for the given bit-length and put
+/// the results in out starting at the given bit-offset.
+ARROW_EXPORT
+void BitmapOrNot(const uint8_t* left, int64_t left_offset, const uint8_t* right,
+                 int64_t right_offset, int64_t length, int64_t out_offset,
+                 uint8_t* out);
+
 }  // namespace internal
 }  // namespace arrow
diff --git a/docs/source/cpp/compute.rst b/docs/source/cpp/compute.rst
index 3cf244ca5e83a..4e729b055cf04 100644
--- a/docs/source/cpp/compute.rst
+++ b/docs/source/cpp/compute.rst
@@ -648,40 +648,48 @@ Structural transforms
 +==========================+============+================================================+=====================+=========+
 | fill_null                | Binary     | Boolean, Null, Numeric, Temporal, String-like  | Input type          | \(1)    |
 +--------------------------+------------+------------------------------------------------+---------------------+---------+
-| is_finite                | Unary      | Float, Double                                  | Boolean             | \(2)    |
+| if_else                  | Ternary    | Boolean, Null, Numeric, Temporal               | Input type          | \(2)    |
 +--------------------------+------------+------------------------------------------------+---------------------+---------+
-| is_inf                   | Unary      | Float, Double                                  | Boolean             | \(3)    |
+| is_finite                | Unary      | Float, Double                                  | Boolean             | \(3)    |
 +--------------------------+------------+------------------------------------------------+---------------------+---------+
-| is_nan                   | Unary      | Float, Double                                  | Boolean             | \(4)    |
+| is_inf                   | Unary      | Float, Double                                  | Boolean             | \(4)    |
 +--------------------------+------------+------------------------------------------------+---------------------+---------+
-| is_null                  | Unary      | Any                                            | Boolean             | \(5)    |
+| is_nan                   | Unary      | Float, Double                                  | Boolean             | \(5)    |
 +--------------------------+------------+------------------------------------------------+---------------------+---------+
-| is_valid                 | Unary      | Any                                            | Boolean             | \(6)    |
+| is_null                  | Unary      | Any                                            | Boolean             | \(6)    |
 +--------------------------+------------+------------------------------------------------+---------------------+---------+
-| list_value_length        | Unary      | List-like                                      | Int32 or Int64      | \(7)    |
+| is_valid                 | Unary      | Any                                            | Boolean             | \(7)    |
 +--------------------------+------------+------------------------------------------------+---------------------+---------+
-| project                  | Varargs    | Any                                            | Struct              | \(8)    |
+| list_value_length        | Unary      | List-like                                      | Int32 or Int64      | \(8)    |
++--------------------------+------------+------------------------------------------------+---------------------+---------+
+| project                  | Varargs    | Any                                            | Struct              | \(9)    |
 +--------------------------+------------+------------------------------------------------+---------------------+---------+
 
 * \(1) First input must be an array, second input a scalar of the same type.
   Output is an array of the same type as the inputs, and with the same
   values as the first input, except for nulls replaced with the second
   input value.
 
-* \(2) Output is true iff the corresponding input element is finite (not Infinity,
+* \(2) First input must be a Boolean scalar or array. Second and third inputs
+  can be scalars or arrays and must be of the same type. Output is an array
+  (or scalar if all inputs are scalar) of the same type as the second/third
+  input. If nulls are present in the first input, they are promoted to the
+  output; otherwise nulls are taken from the input selected by the first
+  input's values.
+
+* \(3) Output is true iff the corresponding input element is finite (not Infinity,
   -Infinity, or NaN).
 
-* \(3) Output is true iff the corresponding input element is Infinity/-Infinity.
+* \(4) Output is true iff the corresponding input element is Infinity/-Infinity.
 
-* \(4) Output is true iff the corresponding input element is NaN.
+* \(5) Output is true iff the corresponding input element is NaN.
 
-* \(5) Output is true iff the corresponding input element is null.
+* \(6) Output is true iff the corresponding input element is null.
 
-* \(6) Output is true iff the corresponding input element is non-null.
+* \(7) Output is true iff the corresponding input element is non-null.
 
-* \(7) Each output element is the length of the corresponding input element
+* \(8) Each output element is the length of the corresponding input element
   (null if input is null). Output type is Int32 for List, Int64 for LargeList.
 
-* \(8) The output struct's field types are the types of its arguments. The
+* \(9) The output struct's field types are the types of its arguments. The
   field names are specified using an instance of :struct:`ProjectOptions`. The
   output shape will be scalar if all inputs are scalar, otherwise any
   scalars will be broadcast to arrays.
diff --git a/docs/source/python/api/compute.rst b/docs/source/python/api/compute.rst
index 91eeeedbeaabe..3010776930f30 100644
--- a/docs/source/python/api/compute.rst
+++ b/docs/source/python/api/compute.rst
@@ -222,6 +222,7 @@ Structural Transforms
 
    binary_length
    fill_null
+   if_else
    is_finite
    is_inf
    is_nan
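Finally, an end-to-end sketch of the semantics documented in note (2): nulls in `cond` are promoted to the output, and otherwise nulls come from whichever input was selected. The function name and values are illustrative; `ArrayFromJSON` is a test-only helper:

#include <iostream>

#include "arrow/compute/api.h"
#include "arrow/testing/gtest_util.h"  // ArrayFromJSON, a test-only helper

arrow::Status NullPromotionExample() {
  auto cond = arrow::ArrayFromJSON(arrow::boolean(), "[true, false, null, true]");
  auto left = arrow::ArrayFromJSON(arrow::int64(), "[1, 2, 3, null]");
  auto right = arrow::ArrayFromJSON(arrow::int64(), "[10, 20, 30, 40]");
  ARROW_ASSIGN_OR_RAISE(arrow::Datum out, arrow::compute::IfElse(cond, left, right));
  // Row 2: null cond -> null output. Row 3: cond selects left, which is null.
  std::cout << out.make_array()->ToString() << std::endl;  // -> [1, 20, null, null]
  return arrow::Status::OK();
}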