Skip to content

Commit

Permalink
ARROW-13136: [C++] Add coalesce function
Browse files Browse the repository at this point in the history
Closes #10608 from lidavidm/arrow-13136

Authored-by: David Li <li.davidm96@gmail.com>
Signed-off-by: Benjamin Kietzman <bengilgit@gmail.com>
  • Loading branch information
lidavidm authored and bkietz committed Jul 19, 2021
1 parent c9b9fa4 commit c848f12
Show file tree
Hide file tree
Showing 7 changed files with 562 additions and 20 deletions.
4 changes: 3 additions & 1 deletion cpp/src/arrow/compute/kernels/codegen_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,9 @@ const std::vector<std::shared_ptr<DataType>>& ExampleParametricTypes() {
// work above

// Output-type resolver: the result takes its type from the first argument and
// its shape from GetBroadcastShape() applied to all arguments.
//
// A stale `return descrs[0];` previously preceded this logic, which made the
// broadcast-shape assignment unreachable.
// NOTE(review): descrs is assumed to be non-empty (front() is called
// unconditionally) - confirm every kernel using this resolver has min arity >= 1.
Result<ValueDescr> FirstType(KernelContext*, const std::vector<ValueDescr>& descrs) {
  ValueDescr result = descrs.front();
  // The first argument's own shape is not necessarily the output shape, so
  // recompute it from the whole argument list.
  result.shape = GetBroadcastShape(descrs);
  return result;
}

void EnsureDictionaryDecoded(std::vector<ValueDescr>* descrs) {
Expand Down
273 changes: 273 additions & 0 deletions cpp/src/arrow/compute/kernels/scalar_if_else.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1166,6 +1166,22 @@ void CopyValues(const Datum& in_values, const int64_t in_offset, const int64_t l
}
}

// Copy exactly one slot from raw input buffers into raw output buffers.
//
// Taking pre-fetched buffer pointers (rather than an ArrayData/Datum) lets callers
// hoist MayHaveNulls() and Buffer::data() out of tight loops, where their internal
// checks would otherwise add up.
//
// - in_valid may be null, in which case the input slot is treated as valid.
// - out_valid may be null, in which case no validity bit is written.
// - in_offset / out_offset are passed straight through to GetBit / SetBitTo /
//   CopyArray, so they are absolute positions into the given buffers.
template <typename Type>
void CopyOneArrayValue(const DataType& type, const uint8_t* in_valid,
                       const uint8_t* in_values, const int64_t in_offset,
                       uint8_t* out_valid, uint8_t* out_values,
                       const int64_t out_offset) {
  if (out_valid) {
    const bool slot_is_valid = !in_valid || BitUtil::GetBit(in_valid, in_offset);
    BitUtil::SetBitTo(out_valid, out_offset, slot_is_valid);
  }
  CopyFixedWidth<Type>::CopyArray(type, in_values, in_offset, /*length=*/1, out_values,
                                  out_offset);
}

struct CaseWhenFunction : ScalarFunction {
using ScalarFunction::ScalarFunction;

Expand Down Expand Up @@ -1375,6 +1391,221 @@ struct CaseWhenFunctor<NullType> {
}
};

// ScalarFunction subclass for 'coalesce' that customizes kernel dispatch: when no
// kernel matches the argument types exactly, dictionary arguments are decoded and
// numeric arguments are cast to a common numeric type before retrying.
struct CoalesceFunction : ScalarFunction {
  using ScalarFunction::ScalarFunction;

  // Resolve the kernel for `values`, possibly rewriting `values` in place with the
  // types the caller must cast its arguments to.
  Result<const Kernel*> DispatchBest(std::vector<ValueDescr>* values) const override {
    using arrow::compute::detail::DispatchExactImpl;

    RETURN_NOT_OK(CheckArity(*values));

    auto kernel = DispatchExactImpl(this, *values);
    if (!kernel) {
      // No exact match: normalize the argument types, then try once more.
      EnsureDictionaryDecoded(values);
      if (auto common_type = CommonNumeric(*values)) {
        ReplaceTypes(common_type, values);
      }
      kernel = DispatchExactImpl(this, *values);
    }
    if (kernel) return kernel;
    return arrow::compute::detail::NoMatchingKernel(this, *values);
  }
};

// 'coalesce' (SQL) for the all-scalar case: the output becomes the first valid
// scalar among the inputs. If every input is null, `out` is left untouched.
Status ExecScalarCoalesce(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
  for (auto it = batch.values.begin(); it != batch.values.end(); ++it) {
    if (!it->scalar()->is_valid) continue;
    *out = *it;
    return Status::OK();
  }
  return Status::OK();
}

// Helper: fill every slot of the output that is still null (validity bit unset)
// with the corresponding value from `source`.
//
// The output validity bitmap is scanned one word-sized block at a time:
//   - block entirely null  -> one bulk CopyValues call for the whole block
//   - block entirely valid -> nothing to do
//   - mixed block          -> copy slot by slot, skipping already-valid slots
// Offsets passed to CopyValues for `source` are relative (0-based within the batch);
// offsets into the output include `out_offset`.
template <typename Type>
void CopyValuesAllValid(Datum source, uint8_t* out_valid, uint8_t* out_values,
                        const int64_t out_offset, const int64_t length) {
  BitBlockCounter counter(out_valid, out_offset, length);
  for (int64_t position = 0; position < length;) {
    const auto block = counter.NextWord();
    if (block.NoneSet()) {
      CopyValues<Type>(source, position, block.length, out_valid, out_values,
                       out_offset + position);
    } else if (!block.AllSet()) {
      const int64_t block_end = position + block.length;
      for (int64_t slot = position; slot < block_end; ++slot) {
        if (BitUtil::GetBit(out_valid, out_offset + slot)) continue;
        CopyValues<Type>(source, slot, 1, out_valid, out_values, out_offset + slot);
      }
    }
    position += block.length;
  }
}

// Helper: zero the values buffer of the output wherever the slot is null, so that
// no uninitialized memory is left behind null slots.
//
// `type` must be a FixedWidthType (it is checked_cast to one). Types with a bit
// width of 1 are handled as bit-packed values; everything else is handled as
// BytesForBits(bit_width) bytes per slot.
void InitializeNullSlots(const DataType& type, uint8_t* out_valid, uint8_t* out_values,
                         const int64_t out_offset, const int64_t length) {
  BitBlockCounter counter(out_valid, out_offset, length);
  auto bit_width = checked_cast<const FixedWidthType&>(type).bit_width();
  auto byte_width = BitUtil::BytesForBits(bit_width);
  const bool is_bit_packed = (bit_width == 1);

  for (int64_t position = 0; position < length;) {
    const auto block = counter.NextWord();
    const int64_t first_slot = out_offset + position;
    if (block.NoneSet()) {
      // Whole block is null: clear it in one go.
      if (is_bit_packed) {
        BitUtil::SetBitsTo(out_values, first_slot, block.length, false);
      } else {
        std::memset(out_values + first_slot * byte_width, 0x00,
                    byte_width * block.length);
      }
    } else if (!block.AllSet()) {
      // Mixed block: clear only the slots whose validity bit is unset.
      for (int64_t j = 0; j < block.length; ++j) {
        const int64_t slot = first_slot + j;
        if (BitUtil::GetBit(out_valid, slot)) continue;
        if (is_bit_packed) {
          BitUtil::ClearBit(out_values, slot);
        } else {
          std::memset(out_values + slot * byte_width, 0x00, byte_width);
        }
      }
    }
    position += block.length;
  }
}

// Implement 'coalesce' for any mix of scalar/array arguments for any fixed-width type.
//
// The output validity bitmap doubles as the "already filled" mask: it is cleared up
// front, and each argument (in order) only fills slots whose output validity bit is
// still unset. Processing stops at the first argument that is a valid scalar or an
// array without nulls, since it fills every remaining slot. Slots never filled by any
// argument stay null and have their value bytes zeroed at the end.
//
// NOTE(review): output->buffers[0] and buffers[1] are written unconditionally, so
// both are assumed to be preallocated by the caller - confirm against the kernel's
// mem_allocation / null_handling settings.
template <typename Type>
Status ExecArrayCoalesce(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
ArrayData* output = out->mutable_array();
const int64_t out_offset = output->offset;
// Use output validity buffer as mask to decide what values to copy
uint8_t* out_valid = output->buffers[0]->mutable_data();
// Clear output validity bits - no slots are filled initially
BitUtil::SetBitsTo(out_valid, out_offset, batch.length, false);
uint8_t* out_values = output->buffers[1]->mutable_data();

for (const auto& datum : batch.values) {
if ((datum.is_scalar() && datum.scalar()->is_valid) ||
(datum.is_array() && !datum.array()->MayHaveNulls())) {
// Valid scalar, or all-valid array: fills every slot that is still null, so no
// later argument can contribute anything.
CopyValuesAllValid<Type>(datum, out_valid, out_values, out_offset, batch.length);
break;
} else if (datum.is_array()) {
// Array with nulls (null scalars match neither branch and are simply skipped)
const ArrayData& arr = *datum.array();
const DataType& type = *datum.type();
const uint8_t* in_valid = arr.buffers[0]->data();
const uint8_t* in_values = arr.buffers[1]->data();
// Walk the input validity and the output validity side by side, one word-sized
// block at a time. The per-slot check below copies only where the input is valid
// and the output is still null; presumably NextAndNotWord() counts exactly those
// slots (in_valid AND NOT out_valid) - TODO confirm.
BinaryBitBlockCounter counter(in_valid, arr.offset, out_valid, out_offset,
batch.length);
int64_t offset = 0;
while (offset < batch.length) {
const auto block = counter.NextAndNotWord();
if (block.AllSet()) {
// Every slot in this block qualifies: copy the whole block at once.
CopyValues<Type>(datum, offset, block.length, out_valid, out_values,
out_offset + offset);
} else if (block.popcount) {
// Only some slots qualify: test and copy them one by one.
for (int64_t j = 0; j < block.length; ++j) {
if (!BitUtil::GetBit(out_valid, out_offset + offset + j) &&
BitUtil::GetBit(in_valid, arr.offset + offset + j)) {
// This version lets us avoid calling MayHaveNulls() on every iteration
// (which does an atomic load and can add up)
CopyOneArrayValue<Type>(type, in_valid, in_values, arr.offset + offset + j,
out_valid, out_values, out_offset + offset + j);
}
}
}
offset += block.length;
}
}
}

// Initialize any remaining null slots (uninitialized memory)
InitializeNullSlots(*out->type(), out_valid, out_values, out_offset, batch.length);
return Status::OK();
}

// Generic 'coalesce' kernel entry point (used unless a specialization applies):
// routes to the array implementation as soon as any argument is an array, and to
// the scalar implementation when all arguments are scalars.
template <typename Type, typename Enable = void>
struct CoalesceFunctor {
  static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
    bool has_array_argument = false;
    for (const auto& datum : batch.values) {
      if (datum.is_array()) {
        has_array_argument = true;
        break;
      }
    }
    if (has_array_argument) {
      return ExecArrayCoalesce<Type>(ctx, batch, out);
    }
    return ExecScalarCoalesce(ctx, batch, out);
  }
};

// Specialization for NullType arguments: Exec performs no work and reports success,
// leaving `out` exactly as the caller provided it.
template <>
struct CoalesceFunctor<NullType> {
static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
// Nothing to compute; `out` is not touched here.
return Status::OK();
}
};

// 'coalesce' kernel for base binary (variable-length) types. The output is
// assembled row by row with a builder rather than written into preallocated
// fixed-width buffers.
template <typename Type>
struct CoalesceFunctor<Type, enable_if_base_binary<Type>> {
  using offset_type = typename Type::offset_type;
  using BuilderType = typename TypeTraits<Type>::BuilderType;

  // Entry point: use the array path if any argument is an array, otherwise the
  // shared all-scalar path.
  static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
    bool has_array_argument = false;
    for (const auto& datum : batch.values) {
      if (datum.is_array()) {
        has_array_argument = true;
        break;
      }
    }
    if (has_array_argument) {
      return ExecArray(ctx, batch, out);
    }
    return ExecScalarCoalesce(ctx, batch, out);
  }

  static Status ExecArray(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
    // Fast path: skip leading null scalars; if the first remaining argument is a
    // valid scalar or an array without nulls, it alone determines the whole output.
    for (const auto& datum : batch.values) {
      if (datum.is_scalar()) {
        if (datum.scalar()->is_valid) {
          ARROW_ASSIGN_OR_RAISE(
              *out,
              MakeArrayFromScalar(*datum.scalar(), batch.length, ctx->memory_pool()));
          return Status::OK();
        }
        continue;
      }
      if (datum.is_array() && !datum.array()->MayHaveNulls()) {
        *out = datum;
        return Status::OK();
      }
      break;
    }

    ArrayData* output = out->mutable_array();
    BuilderType builder(batch[0].type(), ctx->memory_pool());
    RETURN_NOT_OK(builder.Reserve(batch.length));

    // General path: for each row, append the value of the first argument that is
    // non-null in that row, or a null if there is none.
    for (int64_t row = 0; row < batch.length; ++row) {
      bool appended = false;
      for (const auto& datum : batch.values) {
        if (datum.is_scalar()) {
          if (!datum.scalar()->is_valid) continue;
          RETURN_NOT_OK(builder.Append(UnboxScalar<Type>::Unbox(*datum.scalar())));
          appended = true;
          break;
        }
        const ArrayData& source = *datum.array();
        const bool row_is_valid =
            !source.MayHaveNulls() ||
            BitUtil::GetBit(source.buffers[0]->data(), source.offset + row);
        if (!row_is_valid) continue;
        // buffers[2] holds the value bytes; buffer 1 (read via GetValues) holds the
        // offsets delimiting each row's bytes.
        const uint8_t* data = source.buffers[2]->data();
        const offset_type* offsets = source.GetValues<offset_type>(1);
        const offset_type value_begin = offsets[row];
        const offset_type value_end = offsets[row + 1];
        RETURN_NOT_OK(builder.Append(data + value_begin, value_end - value_begin));
        appended = true;
        break;
      }
      if (!appended) RETURN_NOT_OK(builder.AppendNull());
    }

    ARROW_ASSIGN_OR_RAISE(auto temp_output, builder.Finish());
    *output = *temp_output->data();
    // Builder type != logical type due to GenerateTypeAgnosticVarBinaryBase
    output->type = batch[0].type();
    return Status::OK();
  }
};

Result<ValueDescr> LastType(KernelContext*, const std::vector<ValueDescr>& descrs) {
ValueDescr result = descrs.back();
result.shape = GetBroadcastShape(descrs);
Expand Down Expand Up @@ -1402,6 +1633,25 @@ void AddPrimitiveCaseWhenKernels(const std::shared_ptr<CaseWhenFunction>& scalar
}
}

void AddCoalesceKernel(const std::shared_ptr<ScalarFunction>& scalar_function,
detail::GetTypeId get_id, ArrayKernelExec exec) {
ScalarKernel kernel(KernelSignature::Make({InputType(get_id.id)}, OutputType(FirstType),
/*is_varargs=*/true),
exec);
kernel.null_handling = NullHandling::COMPUTED_PREALLOCATE;
kernel.mem_allocation = MemAllocation::PREALLOCATE;
kernel.can_write_into_slices = is_fixed_width(get_id.id);
DCHECK_OK(scalar_function->AddKernel(std::move(kernel)));
}

void AddPrimitiveCoalesceKernels(const std::shared_ptr<ScalarFunction>& scalar_function,
const std::vector<std::shared_ptr<DataType>>& types) {
for (auto&& type : types) {
auto exec = GenerateTypeAgnosticPrimitive<CoalesceFunctor>(*type);
AddCoalesceKernel(scalar_function, type, std::move(exec));
}
}

const FunctionDoc if_else_doc{"Choose values based on a condition",
("`cond` must be a Boolean scalar/ array. \n`left` or "
"`right` must be of the same type scalar/ array.\n"
Expand All @@ -1422,6 +1672,13 @@ const FunctionDoc case_when_doc{
"Essentially, this implements a switch-case or if-else, if-else... "
"statement."),
{"cond", "*cases"}};

// User-facing documentation for the "coalesce" function: one-line summary, longer
// description, and argument names.
// NOTE(review): the leading '*' in "*values" presumably marks a variadic argument,
// mirroring "*cases" in case_when_doc - confirm against FunctionDoc conventions.
const FunctionDoc coalesce_doc{
"Select the first non-null value in each slot",
("Each row of the output will be the value from the first corresponding input "
"for which the value is not null. If all inputs are null in a row, the output "
"will be null."),
{"*values"}};
} // namespace

void RegisterScalarIfElse(FunctionRegistry* registry) {
Expand Down Expand Up @@ -1450,6 +1707,22 @@ void RegisterScalarIfElse(FunctionRegistry* registry) {
AddCaseWhenKernel(func, Type::DECIMAL256, CaseWhenFunctor<Decimal256Type>::Exec);
DCHECK_OK(registry->AddFunction(std::move(func)));
}
{
auto func = std::make_shared<CoalesceFunction>(
"coalesce", Arity::VarArgs(/*min_args=*/1), &coalesce_doc);
AddPrimitiveCoalesceKernels(func, NumericTypes());
AddPrimitiveCoalesceKernels(func, TemporalTypes());
AddPrimitiveCoalesceKernels(
func, {boolean(), null(), day_time_interval(), month_interval()});
AddCoalesceKernel(func, Type::FIXED_SIZE_BINARY,
CoalesceFunctor<FixedSizeBinaryType>::Exec);
AddCoalesceKernel(func, Type::DECIMAL128, CoalesceFunctor<Decimal128Type>::Exec);
AddCoalesceKernel(func, Type::DECIMAL256, CoalesceFunctor<Decimal256Type>::Exec);
for (const auto& ty : BaseBinaryTypes()) {
AddCoalesceKernel(func, ty, GenerateTypeAgnosticVarBinaryBase<CoalesceFunctor>(ty));
}
DCHECK_OK(registry->AddFunction(std::move(func)));
}
}

} // namespace internal
Expand Down
61 changes: 61 additions & 0 deletions cpp/src/arrow/compute/kernels/scalar_if_else_benchmark.cc
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,61 @@ static void CaseWhenBench64Contiguous(benchmark::State& state) {
return CaseWhenBenchContiguous<UInt64Type>(state);
}

// Benchmark 'coalesce' over four random arrays of `Type`, each generated with a 25%
// null probability.
//
// state.range(0) is the generated array length and state.range(1) is the offset each
// array is sliced at before being passed to the function.
template <typename Type>
static void CoalesceBench(benchmark::State& state) {
  using CType = typename Type::c_type;
  constexpr int kNumArguments = 4;

  const int64_t len = state.range(0);
  const int64_t offset = state.range(1);
  auto type = TypeTraits<Type>::type_singleton();

  random::RandomArrayGenerator rand(/*seed=*/0);

  std::vector<Datum> arguments;
  for (int arg = 0; arg < kNumArguments; ++arg) {
    auto array = rand.ArrayOf(type, len, /*null_probability=*/0.25);
    arguments.emplace_back(array->Slice(offset));
  }

  for (auto _ : state) {
    ABORT_NOT_OK(CallFunction("coalesce", arguments));
  }

  // Bytes processed per iteration: every argument contributes (len - offset) values.
  state.SetBytesProcessed(state.iterations() * arguments.size() * (len - offset) *
                          sizeof(CType));
}

// Benchmark 'coalesce' over two random arrays of `Type`: the first generated with a
// 25% null probability, the second with a null probability of 0.
//
// state.range(0) is the generated array length and state.range(1) is the offset each
// array is sliced at before being passed to the function.
template <typename Type>
static void CoalesceNonNullBench(benchmark::State& state) {
  using CType = typename Type::c_type;

  const int64_t len = state.range(0);
  const int64_t offset = state.range(1);
  auto type = TypeTraits<Type>::type_singleton();

  random::RandomArrayGenerator rand(/*seed=*/0);

  // Generation order matters for reproducibility: nullable array first.
  auto with_nulls = rand.ArrayOf(type, len, /*null_probability=*/0.25);
  auto without_nulls = rand.ArrayOf(type, len, /*null_probability=*/0);

  std::vector<Datum> arguments;
  arguments.emplace_back(with_nulls->Slice(offset));
  arguments.emplace_back(without_nulls->Slice(offset));

  for (auto _ : state) {
    ABORT_NOT_OK(CallFunction("coalesce", arguments));
  }

  // Bytes processed per iteration: every argument contributes (len - offset) values.
  state.SetBytesProcessed(state.iterations() * arguments.size() * (len - offset) *
                          sizeof(CType));
}

// Benchmark entry point: CoalesceBench instantiated for Int64Type.
static void CoalesceBench64(benchmark::State& state) { CoalesceBench<Int64Type>(state); }

// Benchmark entry point: CoalesceNonNullBench instantiated for Int64Type.
//
// This previously forwarded to CoalesceBench<Int64Type>, so the benchmark reported
// under this name merely repeated CoalesceBench64 and the non-null scenario was
// never measured.
static void CoalesceNonNullBench64(benchmark::State& state) {
  return CoalesceNonNullBench<Int64Type>(state);
}

BENCHMARK(IfElseBench32)->Args({elems, 0});
BENCHMARK(IfElseBench64)->Args({elems, 0});

Expand All @@ -251,5 +306,11 @@ BENCHMARK(CaseWhenBench64)->Args({elems, 99});
BENCHMARK(CaseWhenBench64Contiguous)->Args({elems, 0});
BENCHMARK(CaseWhenBench64Contiguous)->Args({elems, 99});

// Coalesce benchmarks. The two Args values are read by the benchmark bodies as
// state.range(0) (array length) and state.range(1) (slice offset), so each benchmark
// is registered once unsliced (offset 0) and once sliced at offset 99.
BENCHMARK(CoalesceBench64)->Args({elems, 0});
BENCHMARK(CoalesceBench64)->Args({elems, 99});

BENCHMARK(CoalesceNonNullBench64)->Args({elems, 0});
BENCHMARK(CoalesceNonNullBench64)->Args({elems, 99});

} // namespace compute
} // namespace arrow
Loading

0 comments on commit c848f12

Please sign in to comment.