Skip to content

Commit

Permalink
Updates and perf improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
pitrou committed Jun 3, 2021
1 parent 08b6bc3 commit 9353e4e
Show file tree
Hide file tree
Showing 4 changed files with 240 additions and 111 deletions.
32 changes: 23 additions & 9 deletions cpp/src/arrow/array/builder_binary.h
Expand Up @@ -77,23 +77,21 @@ class BaseBinaryBuilder : public ArrayBuilder {
return Append(value.data(), static_cast<offset_type>(value.size()));
}

/// AppendCurrent does not add a new offset
Status AppendCurrent(const uint8_t* value, offset_type length) {
/// \brief Extend the most recently appended value with additional bytes
///
/// In contrast to Append, no new offset is created: the bytes are only added
/// to the value data.
///
/// \param[in] value pointer to the bytes to add
/// \param[in] length number of bytes to add; nothing is appended unless positive
/// \return the first non-OK Status from the overflow validation or the data
///         append, Status::OK() otherwise
Status AppendToCurrent(const uint8_t* value, offset_type length) {
  // Safety check for UBSAN: leave early instead of appending zero bytes.
  if (!ARROW_PREDICT_TRUE(length > 0)) {
    return Status::OK();
  }
  ARROW_RETURN_NOT_OK(ValidateOverflow(length));
  ARROW_RETURN_NOT_OK(value_data_builder_.Append(value, length));
  return Status::OK();
}

Status AppendCurrent(const char* value, offset_type length) {
return AppendCurrent(reinterpret_cast<const uint8_t*>(value), length);
}

Status AppendCurrent(util::string_view value) {
return AppendCurrent(value.data(), static_cast<offset_type>(value.size()));
/// \brief Extend the most recently appended value with the bytes of a view
///
/// Forwards to the pointer/length overload of AppendToCurrent.
///
/// \param[in] value the bytes to add
Status AppendToCurrent(util::string_view value) {
  const auto* bytes = reinterpret_cast<const uint8_t*>(value.data());
  const auto num_bytes = static_cast<offset_type>(value.size());
  return AppendToCurrent(bytes, num_bytes);
}

Status AppendNulls(int64_t length) final {
Expand Down Expand Up @@ -152,12 +150,28 @@ class BaseBinaryBuilder : public ArrayBuilder {
UnsafeAppend(value.data(), static_cast<offset_type>(value.size()));
}

/// \brief Extend the most recently appended value, without checking capacity
///
/// Same purpose as AppendToCurrent, but the bytes are handed directly to the
/// unchecked append of the value data builder.
///
/// \param[in] value pointer to the bytes to add
/// \param[in] length number of bytes to add
void UnsafeAppendToCurrent(const uint8_t* value, offset_type length) {
  value_data_builder_.UnsafeAppend(value, length);
}

/// \brief Extend the most recently appended value with the bytes of a view,
///        without checking capacity
///
/// Forwards to the pointer/length overload of UnsafeAppendToCurrent.
///
/// \param[in] value the bytes to add
void UnsafeAppendToCurrent(util::string_view value) {
  const auto* bytes = reinterpret_cast<const uint8_t*>(value.data());
  const auto num_bytes = static_cast<offset_type>(value.size());
  UnsafeAppendToCurrent(bytes, num_bytes);
}

/// \brief Append a null entry using only unchecked builder operations
///
/// The current length of the value data is recorded as the entry's offset and
/// a `false` bit is pushed to the validity bitmap.
/// NOTE(review): presumably the caller must have reserved room for the entry
/// beforehand -- confirm against the builder's Reserve contract.
void UnsafeAppendNull() {
  const auto offset = static_cast<offset_type>(value_data_builder_.length());
  offsets_builder_.UnsafeAppend(offset);
  UnsafeAppendToBitmap(false);
}

/// \brief Append an empty, non-null entry using only unchecked builder operations
///
/// The current length of the value data is recorded as the entry's offset and
/// a `true` bit is pushed to the validity bitmap; no value bytes are added.
/// NOTE(review): presumably the caller must have reserved room for the entry
/// beforehand -- confirm against the builder's Reserve contract.
void UnsafeAppendEmptyValue() {
  const auto offset = static_cast<offset_type>(value_data_builder_.length());
  offsets_builder_.UnsafeAppend(offset);
  UnsafeAppendToBitmap(true);
}

/// \brief Append a sequence of strings in one shot.
///
/// \param[in] values a vector of strings
Expand Down
272 changes: 171 additions & 101 deletions cpp/src/arrow/compute/kernels/scalar_string.cc
Expand Up @@ -2427,8 +2427,6 @@ void AddUtf8Length(FunctionRegistry* registry) {
DCHECK_OK(registry->AddFunction(std::move(func)));
}

// binary join

template <typename Type>
struct BinaryJoin {
using ArrayType = typename TypeTraits<Type>::ArrayType;
Expand All @@ -2438,119 +2436,191 @@ struct BinaryJoin {

/// \brief Kernel entry point: choose an implementation from the input kinds
///
/// Dispatches on whether each of the two inputs (`batch[0]`, the list, and
/// `batch[1]`, the separator) is a scalar or an array. The calls to
/// ExecScalarScalar, ExecArrayScalar and ExecArrayArray cover the
/// scalar/scalar, array/scalar and array/array combinations. A scalar list
/// paired with an array separator has no dedicated path and ends at the
/// trailing `return Status::OK()`.
static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
if (batch[0].kind() == Datum::SCALAR) {
const ListScalar& list = checked_cast<const ListScalar&>(*batch[0].scalar());
if (!list.is_valid) {
return Status::OK();
}
if (batch[1].kind() == Datum::SCALAR) {
const BaseBinaryScalar& separator_scalar =
checked_cast<const BaseBinaryScalar&>(*batch[1].scalar());
if (!separator_scalar.is_valid) {
return Status::OK();
}
util::string_view separator(*separator_scalar.value);

TypedBufferBuilder<uint8_t> builder(ctx->memory_pool());
auto Append = [&](util::string_view value) {
return builder.Append(reinterpret_cast<const uint8_t*>(value.data()),
static_cast<offset_type>(value.size()));
};

const ArrayType* strings = static_cast<const ArrayType*>(list.value.get());
if (strings->null_count() > 0) {
// since the input list is not null, the out datum needs to be assigned to
*out = MakeNullScalar(list.value->type());
return Status::OK();
}
if (strings->length() > 0) {
RETURN_NOT_OK(Append(strings->GetView(0)));
for (int64_t j = 1; j < strings->length(); j++) {
RETURN_NOT_OK(Append(separator));
RETURN_NOT_OK(Append(strings->GetView(j)));
}
}
std::shared_ptr<Buffer> string_buffer;
RETURN_NOT_OK(builder.Finish(&string_buffer));
ARROW_ASSIGN_OR_RAISE(auto scalar_right_type,
MakeScalar<std::shared_ptr<Buffer>>(
list.value->type(), std::move(string_buffer)));
*out = scalar_right_type;
return ExecScalarScalar(ctx, *batch[0].scalar(), *batch[1].scalar(), out);
}
// XXX do we want to support scalar[list[str]] with array[str] ?
} else {
const ListArrayType list(batch[0].array());
ArrayData* output = out->mutable_array();
DCHECK_EQ(batch[0].kind(), Datum::ARRAY);
if (batch[1].kind() == Datum::SCALAR) {
return ExecArrayScalar(ctx, batch[0].array(), *batch[1].scalar(), out);
}
DCHECK_EQ(batch[1].kind(), Datum::ARRAY);
return ExecArrayArray(ctx, batch[0].array(), batch[1].array(), out);
}
return Status::OK();
}

BuilderType builder(ctx->memory_pool());
RETURN_NOT_OK(builder.Reserve(list.length()));
if (batch[1].kind() == Datum::ARRAY) {
ArrayType separator_array(batch[1].array());
for (int64_t i = 0; i < list.length(); ++i) {
const std::shared_ptr<Array> slice = list.value_slice(i);
const ArrayType* strings = static_cast<const ArrayType*>(slice.get());
if ((strings->null_count() > 0) || (list.IsNull(i)) ||
separator_array.IsNull(i)) {
RETURN_NOT_OK(builder.AppendNull());
} else {
const auto separator = separator_array.GetView(i);
if (strings->length() > 0) {
RETURN_NOT_OK(builder.Append(strings->GetView(0)));
for (int64_t j = 1; j < strings->length(); j++) {
RETURN_NOT_OK(builder.AppendCurrent(separator));
RETURN_NOT_OK(builder.AppendCurrent(strings->GetView(j)));
}
} else {
RETURN_NOT_OK(builder.AppendEmptyValue());
}
}
}
} else if (batch[1].kind() == Datum::SCALAR) {
const auto& separator_scalar =
checked_cast<const BaseBinaryScalar&>(*batch[1].scalar());
if (!separator_scalar.is_valid) {
ARROW_ASSIGN_OR_RAISE(
auto nulls,
MakeArrayOfNull(list.value_type(), list.length(), ctx->memory_pool()));
*output = *nulls->data();
output->type = list.value_type();
return Status::OK();
}
util::string_view separator(*separator_scalar.value);

for (int64_t i = 0; i < list.length(); ++i) {
const std::shared_ptr<Array> slice = list.value_slice(i);
const ArrayType* strings = static_cast<const ArrayType*>(slice.get());
if ((strings->null_count() > 0) || (list.IsNull(i))) {
RETURN_NOT_OK(builder.AppendNull());
} else {
if (strings->length() > 0) {
RETURN_NOT_OK(builder.Append(strings->GetView(0)));
for (int64_t j = 1; j < strings->length(); j++) {
RETURN_NOT_OK(builder.AppendCurrent(separator));
RETURN_NOT_OK(builder.AppendCurrent(strings->GetView(j)));
}
} else {
RETURN_NOT_OK(builder.AppendEmptyValue());
}
}
}
// Scalar, scalar -> scalar
//
// Joins the strings held by the list scalar `left`, inserting the separator
// scalar `right` between consecutive strings.
// - If the list or the separator is null, returns OK without writing `out`.
// - If any string in the list is null, `out` is set to a null scalar of the
//   list's value type.
// - Otherwise `out` is set to a scalar of the list's value type built from the
//   joined bytes.
static Status ExecScalarScalar(KernelContext* ctx, const Scalar& left,
const Scalar& right, Datum* out) {
const auto& list = checked_cast<const ListScalar&>(left);
const auto& separator_scalar = checked_cast<const BaseBinaryScalar&>(right);
if (!list.is_valid || !separator_scalar.is_valid) {
return Status::OK();
}
util::string_view separator(*separator_scalar.value);

TypedBufferBuilder<uint8_t> builder(ctx->memory_pool());
// Helper: append the raw bytes of a string view to the output buffer
auto Append = [&](util::string_view value) {
return builder.Append(reinterpret_cast<const uint8_t*>(value.data()),
static_cast<offset_type>(value.size()));
};

const auto& strings = checked_cast<const ArrayType&>(*list.value);
if (strings.null_count() > 0) {
// The list itself is valid but holds a null string, so a null scalar must
// be assigned to `out` explicitly
*out = MakeNullScalar(list.value->type());
return Status::OK();
}
if (strings.length() > 0) {
// Output size: all string bytes plus one separator between each pair of
// neighbouring strings
auto data_length =
strings.total_values_length() + (strings.length() - 1) * separator.length();
RETURN_NOT_OK(builder.Reserve(data_length));
RETURN_NOT_OK(Append(strings.GetView(0)));
for (int64_t j = 1; j < strings.length(); j++) {
RETURN_NOT_OK(Append(separator));
RETURN_NOT_OK(Append(strings.GetView(j)));
}
std::shared_ptr<Array> string_array;
RETURN_NOT_OK(builder.Finish(&string_array));
*output = *string_array->data();
// correct the output type based on the input
output->type = list.value_type();
}
std::shared_ptr<Buffer> string_buffer;
RETURN_NOT_OK(builder.Finish(&string_buffer));
ARROW_ASSIGN_OR_RAISE(auto joined, MakeScalar<std::shared_ptr<Buffer>>(
list.value->type(), std::move(string_buffer)));
*out = std::move(joined);
return Status::OK();
}

/// \brief Join every list of strings with one scalar separator
///        (array, scalar -> array)
///
/// A null separator yields an all-null array of the list's value type.
/// Otherwise the data buffer is sized in a first pass and the actual joining
/// is delegated to JoinStrings.
///
/// \param[in] ctx kernel context; its memory pool backs the allocations
/// \param[in] left data of the list array whose child strings are joined
/// \param[in] right separator scalar
/// \param[out] out receives the data of the resulting array
/// \return Status::OK() on success, otherwise the first error encountered
static Status ExecArrayScalar(KernelContext* ctx,
                              const std::shared_ptr<ArrayData>& left,
                              const Scalar& right, Datum* out) {
  const ListArrayType list(left);
  const auto& sep_scalar = checked_cast<const BaseBinaryScalar&>(right);

  // A null separator makes every output slot null
  if (!sep_scalar.is_valid) {
    ARROW_ASSIGN_OR_RAISE(
        auto all_null,
        MakeArrayOfNull(list.value_type(), list.length(), ctx->memory_pool()));
    *out = *all_null->data();
    return Status::OK();
  }

  const util::string_view separator(*sep_scalar.value);
  const auto& strings = checked_cast<const ArrayType&>(*list.values());
  const auto offsets = list.raw_value_offsets();

  BuilderType builder(ctx->memory_pool());
  RETURN_NOT_OK(builder.Reserve(list.length()));

  // True when at least one child string in [first, last) is null
  const auto any_null = [&strings](int64_t first, int64_t last) {
    for (int64_t k = first; k < last; ++k) {
      if (strings.IsNull(k)) {
        return true;
      }
    }
    return false;
  };

  // Reserve the data buffer once up front: start from the bytes of all child
  // strings, then add the separators of every non-empty list without nulls.
  int64_t data_size = strings.total_values_length();
  for (int64_t row = 0; row < list.length(); ++row) {
    const auto first = offsets[row], last = offsets[row + 1];
    if (last > first && !any_null(first, last)) {
      data_size += (last - first - 1) * separator.length();
    }
  }
  RETURN_NOT_OK(builder.ReserveData(data_size));

  // Presents the single separator through the per-row interface that
  // JoinStrings expects: never null, same view for every row.
  struct ConstantSeparator {
    const util::string_view separator;

    bool IsNull(int64_t) { return false; }
    util::string_view GetView(int64_t) { return separator; }
  };
  return JoinStrings(list, strings, ConstantSeparator{separator}, &builder, out);
}

/// \brief Join every list of strings with the separator found at the same
///        index (array, array -> array)
///
/// A first pass computes how many data bytes the joined strings need, so the
/// builder's data buffer is reserved once; the actual joining is delegated to
/// JoinStrings, which appends without capacity checks.
///
/// \param[in] ctx kernel context; its memory pool backs the allocations
/// \param[in] left data of the list array whose child strings are joined
/// \param[in] right data of the separator array (one separator per list)
/// \param[out] out receives the data of the resulting array
/// \return Status::OK() on success, otherwise the first error encountered
static Status ExecArrayArray(KernelContext* ctx, const std::shared_ptr<ArrayData>& left,
                             const std::shared_ptr<ArrayData>& right, Datum* out) {
  const ListArrayType list(left);
  const auto& strings = checked_cast<const ArrayType&>(*list.values());
  const auto list_offsets = list.raw_value_offsets();
  const auto string_offsets = strings.raw_value_offsets();
  const ArrayType separators(right);

  BuilderType builder(ctx->memory_pool());
  RETURN_NOT_OK(builder.Reserve(list.length()));

  // Presize data to avoid multiple reallocations when joining strings.
  //
  // Every quantity is widened to int64_t *before* being combined. The list
  // offsets and the separator length use the arrays' own offset types, which
  // may be 32-bit; multiplying them in that width can overflow for a long list
  // with a long separator. A wrapped total would reserve too little space for
  // the unchecked appends done by JoinStrings.
  // NOTE(review): an oversized total is presumably rejected by ReserveData's
  // own validation -- confirm.
  int64_t total_data_length = 0;
  for (int64_t i = 0; i < list.length(); ++i) {
    if (separators.IsNull(i)) {
      // Null separator -> null output, no data bytes needed
      continue;
    }
    const auto j_start = list_offsets[i], j_end = list_offsets[i + 1];
    bool has_null_string = false;
    for (int64_t j = j_start; !has_null_string && j < j_end; ++j) {
      has_null_string = strings.IsNull(j);
    }
    if (!has_null_string && j_end > j_start) {
      const int64_t strings_length = static_cast<int64_t>(string_offsets[j_end]) -
                                     static_cast<int64_t>(string_offsets[j_start]);
      const int64_t num_separators =
          static_cast<int64_t>(j_end) - static_cast<int64_t>(j_start) - 1;
      const int64_t separator_length = separators.value_length(i);
      total_data_length += strings_length + num_separators * separator_length;
    }
  }
  RETURN_NOT_OK(builder.ReserveData(total_data_length));

  // Presents the separator array through the per-row interface that
  // JoinStrings expects.
  struct SeparatorLookup {
    const ArrayType& separators;

    bool IsNull(int64_t i) { return separators.IsNull(i); }
    util::string_view GetView(int64_t i) { return separators.GetView(i); }
  };
  return JoinStrings(list, strings, SeparatorLookup{separators}, &builder, out);
}

/// \brief Fill `builder` with one joined string per list and emit the result
///
/// Every append goes through the builder's Unsafe* methods, so the caller is
/// expected to have reserved the entries and data bytes beforehand (as
/// ExecArrayScalar and ExecArrayArray do).
///
/// For each list slot the output is:
/// - null if the list or its separator is null, or if any string in it is null;
/// - an empty value if the list is empty;
/// - otherwise the strings concatenated with the separator in between.
///
/// \tparam SeparatorLookup type offering IsNull(i) and GetView(i) for the
///         separator of list `i`
/// \param[in] list the list array being joined
/// \param[in] strings the child string values of `list`
/// \param[in] separators per-list separator accessor
/// \param[in,out] builder pre-sized builder receiving the joined strings
/// \param[out] out receives the data of the finished array, with its type set
///             to the list's value type
/// \return Status::OK() on success, or the error returned by Finish
template <typename SeparatorLookup>
static Status JoinStrings(const ListArrayType& list, const ArrayType& strings,
                          SeparatorLookup&& separators, BuilderType* builder,
                          Datum* out) {
  const auto offsets = list.raw_value_offsets();

  for (int64_t row = 0; row < list.length(); ++row) {
    if (list.IsNull(row) || separators.IsNull(row)) {
      builder->UnsafeAppendNull();
      continue;
    }

    const auto first = offsets[row], last = offsets[row + 1];
    if (first == last) {
      // Empty list joins to an empty string
      builder->UnsafeAppendEmptyValue();
      continue;
    }

    // Scan the child strings, stopping at the first null one
    bool all_valid = true;
    for (int64_t k = first; all_valid && k < last; ++k) {
      all_valid = !strings.IsNull(k);
    }
    if (!all_valid) {
      builder->UnsafeAppendNull();
      continue;
    }

    // The first string opens a new value; the rest extend that same value
    builder->UnsafeAppend(strings.GetView(first));
    for (int64_t k = first + 1; k < last; ++k) {
      builder->UnsafeAppendToCurrent(separators.GetView(row));
      builder->UnsafeAppendToCurrent(strings.GetView(k));
    }
  }

  std::shared_ptr<Array> joined;
  RETURN_NOT_OK(builder->Finish(&joined));
  *out = *joined->data();
  // Correct the output type based on the input
  out->mutable_array()->type = list.value_type();
  return Status::OK();
}
};

const FunctionDoc binary_join_doc(
"Join a list of strings together with a `separator` to form a single string",
("Insert `separator` between each list element, and concatenate them."),
("Insert `separator` between `list` elements, and concatenate them.\n"
"Any null input and any null `list` element emits a null output.\n"),
{"list", "separator"});

void AddJoin(FunctionRegistry* registry) {
void AddBinaryJoin(FunctionRegistry* registry) {
auto func =
std::make_shared<ScalarFunction>("binary_join", Arity::Binary(), &binary_join_doc);
for (const std::shared_ptr<DataType>& ty : BaseBinaryTypes()) {
Expand Down Expand Up @@ -2870,7 +2940,7 @@ void RegisterScalarStringAscii(FunctionRegistry* registry) {
AddExtractRegex(registry);
#endif
AddStrptime(registry);
AddJoin(registry);
AddBinaryJoin(registry);
}

} // namespace internal
Expand Down

0 comments on commit 9353e4e

Please sign in to comment.