Skip to content

Commit

Permalink
ARROW-7982: [C++] Add function VisitArrayDataInline() helper
Browse files Browse the repository at this point in the history
Avoids the hassle of defining a dedicated struct with ArrayDataVisitor.

Closes #6535 from pitrou/ARROW-7982-data-visitor and squashes the following commits:

351812e <Antoine Pitrou> ARROW-7982:  Add function VisitArrayDataInline() helper

Authored-by: Antoine Pitrou <antoine@python.org>
Signed-off-by: Antoine Pitrou <antoine@python.org>
  • Loading branch information
pitrou committed Mar 9, 2020
1 parent 116672f commit a64f590
Show file tree
Hide file tree
Showing 7 changed files with 335 additions and 196 deletions.
23 changes: 9 additions & 14 deletions cpp/src/arrow/compute/kernels/cast.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1029,25 +1029,20 @@ struct CastFunctor<O, I,
using BuilderType = typename TypeTraits<O>::BuilderType;
using FormatterType = typename internal::StringFormatter<I>;

struct Visitor {
Visitor(FunctionContext* ctx, const ArrayData& input)
: formatter_(input.type), builder_(input.type, ctx->memory_pool()) {}
FormatterType formatter(input.type);
BuilderType builder(input.type, ctx->memory_pool());

Status VisitNull() { return builder_.AppendNull(); }

Status VisitValue(value_type value) {
return formatter_(value,
[this](util::string_view v) { return builder_.Append(v); });
auto convert_value = [&](util::optional<value_type> v) {
if (v.has_value()) {
return formatter(*v, [&](util::string_view v) { return builder.Append(v); });
} else {
return builder.AppendNull();
}

FormatterType formatter_;
BuilderType builder_;
};
RETURN_NOT_OK(VisitArrayDataInline<I>(input, std::move(convert_value)));

Visitor visitor(ctx, input);
RETURN_NOT_OK(ArrayDataVisitor<I>::Visit(input, &visitor));
std::shared_ptr<Array> output_array;
RETURN_NOT_OK(visitor.builder_.Finish(&output_array));
RETURN_NOT_OK(builder.Finish(&output_array));
*output = std::move(*output_array->data());
return Status::OK();
}
Expand Down
105 changes: 56 additions & 49 deletions cpp/src/arrow/compute/kernels/hash.cc
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ class RegularHashKernelImpl : public HashKernelImpl {

Status Append(const ArrayData& arr) override {
RETURN_NOT_OK(action_.Reserve(arr.length));
return ArrayDataVisitor<Type>::Visit(arr, this);
return DoAppend(arr);
}

Status Flush(Datum* out) override { return action_.Flush(out); }
Expand All @@ -295,59 +295,66 @@ class RegularHashKernelImpl : public HashKernelImpl {
}

template <bool HasError = with_error_status>
enable_if_t<!HasError, Status> VisitNull() {
auto on_found = [this](int32_t memo_index) { action_.ObserveNullFound(memo_index); };
auto on_not_found = [this](int32_t memo_index) {
action_.ObserveNullNotFound(memo_index);
};

if (with_memo_visit_null) {
memo_table_->GetOrInsertNull(on_found, on_not_found);
} else {
action_.ObserveNullNotFound(-1);
}
return Status::OK();
}

template <bool HasError = with_error_status>
enable_if_t<HasError, Status> VisitNull() {
Status s = Status::OK();
auto on_found = [this](int32_t memo_index) { action_.ObserveFound(memo_index); };
auto on_not_found = [this, &s](int32_t memo_index) {
action_.ObserveNotFound(memo_index, &s);
};

if (with_memo_visit_null) {
memo_table_->GetOrInsertNull(on_found, on_not_found);
} else {
action_.ObserveNullNotFound(-1);
}

return s;
}

template <bool HasError = with_error_status>
enable_if_t<!HasError, Status> VisitValue(const Scalar& value) {
auto on_found = [this](int32_t memo_index) { action_.ObserveFound(memo_index); };
auto on_not_found = [this](int32_t memo_index) {
action_.ObserveNotFound(memo_index);
enable_if_t<!HasError, Status> DoAppend(const ArrayData& arr) {
auto process_value = [this](util::optional<Scalar> v) {
if (v.has_value()) {
auto on_found = [this](int32_t memo_index) { action_.ObserveFound(memo_index); };
auto on_not_found = [this](int32_t memo_index) {
action_.ObserveNotFound(memo_index);
};

int32_t unused_memo_index;
return memo_table_->GetOrInsert(*v, std::move(on_found), std::move(on_not_found),
&unused_memo_index);
} else {
// Null
if (with_memo_visit_null) {
auto on_found = [this](int32_t memo_index) {
action_.ObserveNullFound(memo_index);
};
auto on_not_found = [this](int32_t memo_index) {
action_.ObserveNullNotFound(memo_index);
};
memo_table_->GetOrInsertNull(std::move(on_found), std::move(on_not_found));
} else {
action_.ObserveNullNotFound(-1);
}
return Status::OK();
}
};

int32_t unused_memo_index;
return memo_table_->GetOrInsert(value, on_found, on_not_found, &unused_memo_index);
return VisitArrayDataInline<Type>(arr, std::move(process_value));
}

template <bool HasError = with_error_status>
enable_if_t<HasError, Status> VisitValue(const Scalar& value) {
Status s = Status::OK();
auto on_found = [this](int32_t memo_index) { action_.ObserveFound(memo_index); };
auto on_not_found = [this, &s](int32_t memo_index) {
action_.ObserveNotFound(memo_index, &s);
enable_if_t<HasError, Status> DoAppend(const ArrayData& arr) {
auto process_value = [this](util::optional<Scalar> v) {
Status s = Status::OK();
if (v.has_value()) {
auto on_found = [this](int32_t memo_index) { action_.ObserveFound(memo_index); };
auto on_not_found = [this, &s](int32_t memo_index) {
action_.ObserveNotFound(memo_index, &s);
};

int32_t unused_memo_index;
RETURN_NOT_OK(memo_table_->GetOrInsert(
*v, std::move(on_found), std::move(on_not_found), &unused_memo_index));
} else {
// Null
if (with_memo_visit_null) {
auto on_found = [this](int32_t memo_index) {
action_.ObserveNullFound(memo_index);
};
auto on_not_found = [this, &s](int32_t memo_index) {
action_.ObserveNullNotFound(memo_index, &s);
};
memo_table_->GetOrInsertNull(std::move(on_found), std::move(on_not_found));
} else {
action_.ObserveNullNotFound(-1);
}
}
return s;
};
int32_t unused_memo_index;
RETURN_NOT_OK(
memo_table_->GetOrInsert(value, on_found, on_not_found, &unused_memo_index));
return s;
return VisitArrayDataInline<Type>(arr, std::move(process_value));
}

std::shared_ptr<DataType> out_type() const override { return action_.out_type(); }
Expand Down
50 changes: 20 additions & 30 deletions cpp/src/arrow/compute/kernels/isin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,13 +75,6 @@ class IsInKernelImpl : public UnaryKernel {

template <typename T, typename Scalar>
struct MemoTableRight {
Status VisitNull() { return Status::OK(); }

Status VisitValue(const Scalar& value) {
int32_t unused_memo_index;
return memo_table_->GetOrInsert(value, &unused_memo_index);
}

Status Reset(MemoryPool* pool) {
memo_table_.reset(new MemoTable(pool, 0));
return Status::OK();
Expand All @@ -90,7 +83,16 @@ struct MemoTableRight {
Status Append(FunctionContext* ctx, const Datum& right) {
const ArrayData& right_data = *right.array();
right_null_count += right_data.GetNullCount();
return ArrayDataVisitor<T>::Visit(right_data, this);

auto insert_value = [&](util::optional<Scalar> v) {
if (v.has_value()) {
int32_t unused_memo_index;
return memo_table_->GetOrInsert(*v, &unused_memo_index);
} else {
return Status::OK();
}
};
return VisitArrayDataInline<T>(right_data, std::move(insert_value));
}

using MemoTable = typename HashTraits<T>::MemoTableType;
Expand All @@ -106,27 +108,6 @@ class IsInKernel : public IsInKernelImpl {
IsInKernel(const std::shared_ptr<DataType>& type, MemoryPool* pool)
: type_(type), pool_(pool) {}

// \brief if left array has a null return true
Status VisitNull() {
writer->Set();
writer->Next();
return Status::OK();
}

// \brief Iterate over the left array using another visitor.
// In VisitValue, use the memo_table_ (for right array) and check if value
// in left array is in the memo_table_. Return true if condition satisfied,
// else false.
Status VisitValue(const Scalar& value) {
if (memo_table_->Get(value) != -1) {
writer->Set();
} else {
writer->Clear();
}
writer->Next();
return Status::OK();
}

Status Compute(FunctionContext* ctx, const Datum& left, Datum* out) override {
const ArrayData& left_data = *left.array();

Expand All @@ -136,7 +117,16 @@ class IsInKernel : public IsInKernelImpl {
writer = std::make_shared<internal::FirstTimeBitmapWriter>(
output.get()->buffers[1]->mutable_data(), output.get()->offset, left_data.length);

RETURN_NOT_OK(ArrayDataVisitor<Type>::Visit(left_data, this));
auto lookup_value = [&](util::optional<Scalar> v) {
if (!v.has_value() || memo_table_->Get(*v) != -1) {
writer->Set();
} else {
writer->Clear();
}
writer->Next();
};
VisitArrayDataInline<Type>(left_data, std::move(lookup_value));

writer->Finish();

// if right null count is zero and left null count is not zero, propagate nulls
Expand Down
79 changes: 24 additions & 55 deletions cpp/src/arrow/compute/kernels/sort_to_indices.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include <algorithm>
#include <limits>
#include <numeric>
#include <utility>
#include <vector>

#include "arrow/builder.h"
Expand Down Expand Up @@ -135,52 +136,28 @@ class CountSorter {
// first slot reserved for prefix sum, last slot for null value
std::vector<CounterType> counts(1 + value_range + 1);

struct UpdateCounts {
Status VisitNull() {
++counts[value_range];
return Status::OK();
auto update_counts = [&](util::optional<c_type> v) {
if (v.has_value()) {
++counts[*v - min_ + 1];
} else {
++counts[value_range + 1];
}

Status VisitValue(c_type v) {
++counts[v - min];
return Status::OK();
}

CounterType* counts;
const uint32_t value_range;
c_type min;
};
{
UpdateCounts update_counts{&counts[1], value_range, min_};
ARROW_CHECK_OK(ArrayDataVisitor<ArrowType>().Visit(*values.data(), &update_counts));
}
VisitArrayDataInline<ArrowType>(*values.data(), std::move(update_counts));

for (uint32_t i = 1; i <= value_range; ++i) {
counts[i] += counts[i - 1];
}

struct OutputIndices {
Status VisitNull() {
out_indices[counts[value_range]++] = index++;
return Status::OK();
}

Status VisitValue(c_type v) {
out_indices[counts[v - min]++] = index++;
return Status::OK();
int64_t index = 0;
auto write_index = [&](util::optional<c_type> v) {
if (v.has_value()) {
indices_begin[counts[*v - min_]++] = index++;
} else {
indices_begin[counts[value_range]++] = index++;
}

CounterType* counts;
const uint32_t value_range;
c_type min;
int64_t* out_indices;
int64_t index;
};
{
OutputIndices output_indices{&counts[0], value_range, min_, indices_begin, 0};
ARROW_CHECK_OK(
ArrayDataVisitor<ArrowType>().Visit(*values.data(), &output_indices));
}
VisitArrayDataInline<ArrowType>(*values.data(), std::move(write_index));
}
};

Expand All @@ -197,29 +174,21 @@ class CountOrCompareSorter {

void Sort(int64_t* indices_begin, int64_t* indices_end, const ArrayType& values) {
if (values.length() >= countsort_min_len_ && values.length() > values.null_count()) {
struct MinMaxScanner {
Status VisitNull() { return Status::OK(); }
c_type min{std::numeric_limits<c_type>::max()};
c_type max{std::numeric_limits<c_type>::min()};

Status VisitValue(c_type v) {
min = std::min(min, v);
max = std::max(max, v);
return Status::OK();
auto update_minmax = [&min, &max](util::optional<c_type> v) {
if (v.has_value()) {
min = std::min(min, *v);
max = std::max(max, *v);
}

c_type min{std::numeric_limits<c_type>::max()};
c_type max{std::numeric_limits<c_type>::min()};
};

MinMaxScanner minmax_scanner;
ARROW_CHECK_OK(
ArrayDataVisitor<ArrowType>().Visit(*values.data(), &minmax_scanner));

VisitArrayDataInline<ArrowType>(*values.data(), std::move(update_minmax));
// For signed int32/64, (max - min) may overflow and trigger UBSAN.
// Cast to largest unsigned type(uint64_t) before substraction.
const uint64_t min = static_cast<uint64_t>(minmax_scanner.min);
const uint64_t max = static_cast<uint64_t>(minmax_scanner.max);
if ((max - min) <= countsort_max_range_) {
count_sorter_.SetMinMax(minmax_scanner.min, minmax_scanner.max);
if (static_cast<uint64_t>(max) - static_cast<uint64_t>(min) <=
countsort_max_range_) {
count_sorter_.SetMinMax(min, max);
count_sorter_.Sort(indices_begin, indices_end, values);
return;
}
Expand Down
4 changes: 4 additions & 0 deletions cpp/src/arrow/util/functional.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,10 @@ struct call_traits {

template <typename F>
using return_type = decltype(return_type_impl(&std::decay<F>::type::operator()));

template <typename F, typename T, typename RT = T>
using enable_if_return =
typename std::enable_if<std::is_same<return_type<F>, T>::value, RT>;
};

} // namespace internal
Expand Down
2 changes: 1 addition & 1 deletion cpp/src/arrow/util/future.cc
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ class FutureWaiterImpl : public FutureWaiter {

// Is the ending condition satisfied?
bool ShouldSignal() {
bool do_signal;
bool do_signal = false;
switch (kind_) {
case ANY:
do_signal = (finished_futures_.size() > 0);
Expand Down
Loading

0 comments on commit a64f590

Please sign in to comment.