Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GH-34056: [C++] Add Utility function to simplify converting any row-based structure into an arrow::RecordBatchReader or an arrow::Table #34057

Merged
merged 20 commits into from
Feb 21, 2023
Merged
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cpp/src/arrow/util/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ add_arrow_test(utility-test
queue_test.cc
range_test.cc
reflection_test.cc
rows_to_batches_test.cc
small_vector_test.cc
stl_util_test.cc
string_test.cc
Expand Down
159 changes: 159 additions & 0 deletions cpp/src/arrow/util/rows_to_batches.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#pragma once

#include "arrow/record_batch.h"
#include "arrow/result.h"
#include "arrow/status.h"
#include "arrow/table_builder.h"
#include "arrow/util/iterator.h"

#include <type_traits>

namespace arrow::util {

namespace detail {

// Default identity function row accessor. Used to for the common case where the value
// of each row iterated over is it's self also directly iterable.
[[nodiscard]] constexpr inline auto MakeDefaultRowAccessor() {
return [](auto& x) -> Result<decltype(std::ref(x))> { return std::ref(x); };
}

// Meta-funciton to check if a type `T` is a range (iterable using `std::begin()` /
gringasalpastor marked this conversation as resolved.
Show resolved Hide resolved
// `std::end()`). `is_range<T>::value` will be false if `T` is not a valid range.
template <typename T, typename = void>
struct is_range : std::false_type {};

template <typename T>
struct is_range<T, std::void_t<decltype(std::begin(std::declval<T>())),
decltype(std::end(std::declval<T>()))>> : std::true_type {
};

} // namespace detail

/// Delete overload for `const Range&& rows` because the data's lifetime must exceed
/// the lifetime of the function call. `data` will be read when client uses the
/// `RecordBatchReader`
template <class Range, class DataPointConvertor,
class RowAccessor = decltype(detail::MakeDefaultRowAccessor())>
[[nodiscard]] typename std::enable_if_t<detail::is_range<Range>::value,
Result<std::shared_ptr<RecordBatchReader>>>
/* Result<std::shared_ptr<RecordBatchReader>>> */ RowsToBatches(
const std::shared_ptr<Schema>& schema, const Range&& rows,
DataPointConvertor&& data_point_convertor,
RowAccessor&& row_accessor = detail::MakeDefaultRowAccessor(),
MemoryPool* pool = default_memory_pool(),
const std::size_t batch_size = 1024) = delete;

/// \brief Utility function for converting any row-based structure into an
/// `arrow::RecordBatchReader` (this can be easily converted to an `arrow::Table` using
/// `arrow::RecordBatchReader::ToTable()`).
///
/// Examples of supported types:
/// - `std::vector<std::vector<std::variant<int, bsl::string>>>`
/// - `std::vector<MyRowStruct>`

/// If `rows` (client’s row-based structure) is not a valid C++ range, the client will
/// need to either make it iterable, or make an adapter/wrapper that is a valid C++
/// range.

/// The client must provide a `DataPointConvertor` callable type that will convert the
/// structure’s data points into the corresponding arrow types.

/// Complex nested rows can be supported by providing a custom `row_accessor` instead
/// of the default.

/// Example usage:
/// \code{.cpp}
/// auto IntConvertor = [](ArrayBuilder& array_builder, int value) {
/// return static_cast<Int64Builder&>(array_builder).Append(value);
/// };
/// std::vector<std::vector<int>> data = {{1, 2, 4}, {5, 6, 7}};
/// auto batches = RowsToBatches(kTestSchema, data, IntConvertor);
/// \endcode

/// \param[in] schema - the schema to be used in the `RecordBatchReader`

/// \param[in] rows - iterable row-based structure that will be converted to arrow
/// batches

/// \param[in] data_point_convertor - client provided callable type that will convert
/// the structure’s data points into the corresponding arrow types. The convertor must
/// return an error `Status` if an error happens during conversion.

/// \param[in] row_accessor - In the common case where the value of each row iterated
/// over is it's self also directly iterable, the client can just use the default.
/// The provided callable must take the values of the `rows` range and return a
/// `std::reference_wrapper<Range>` to the data points in a given row. The data points
/// must be in order of their corresponding fields in the schema.
/// see: /ref `MakeDefaultRowAccessor`

/// \return `Result<std::shared_ptr<RecordBatchReader>>>` result will be a
/// `std::shared_ptr<RecordBatchReader>>` if not errors occurred, else an error status.
template <class Range, class DataPointConvertor,
class RowAccessor = decltype(detail::MakeDefaultRowAccessor())>
[[nodiscard]] typename std::enable_if_t<detail::is_range<Range>::value,
Result<std::shared_ptr<RecordBatchReader>>>
/* Result<std::shared_ptr<RecordBatchReader>>> */ RowsToBatches(
const std::shared_ptr<Schema>& schema, const Range& rows,
DataPointConvertor&& data_point_convertor,
RowAccessor&& row_accessor = detail::MakeDefaultRowAccessor(),
MemoryPool* pool = default_memory_pool(), const std::size_t batch_size = 1024) {
auto make_next_batch =
[pool = pool, batch_size = batch_size, rows_ittr = std::begin(rows),
rows_ittr_end = std::end(rows), schema = schema,
row_accessor = std::forward<RowAccessor>(row_accessor),
data_point_convertor = std::forward<DataPointConvertor>(
data_point_convertor)]() mutable -> Result<std::shared_ptr<RecordBatch>> {
if (rows_ittr == rows_ittr_end) return NULLPTR;

ARROW_ASSIGN_OR_RAISE(auto record_batch_builder,
RecordBatchBuilder::Make(schema, pool, batch_size));

for (size_t i = 0; i < batch_size && (rows_ittr != rows_ittr_end);
i++, std::advance(rows_ittr, 1)) {
int col_index = 0;
ARROW_ASSIGN_OR_RAISE(const auto row, row_accessor(*rows_ittr));

// If the accessor returns a `std::reference_wrapper` unwrap if
const auto& row_unwrapped = [&]() {
if constexpr (detail::is_range<decltype(row)>::value)
return row;
else
return row.get();
}();

for (auto& data_point : row_unwrapped) {
ArrayBuilder* array_builder = record_batch_builder->GetField(col_index);
ARROW_RETURN_IF(array_builder == NULLPTR,
Status::Invalid("array_builder == NULLPTR"));

ARROW_RETURN_NOT_OK(data_point_convertor(*array_builder, data_point));
col_index++;
}
}

ARROW_ASSIGN_OR_RAISE(auto result, record_batch_builder->Flush());
return result;
};
return RecordBatchReader::MakeFromIterator(MakeFunctionIterator(make_next_batch),
schema);
}

} // namespace arrow::util
120 changes: 120 additions & 0 deletions cpp/src/arrow/util/rows_to_batches_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#include <vector>

#include <gtest/gtest.h>

#include "arrow/array/builder_binary.h"
#include "arrow/array/builder_primitive.h"
#include "arrow/scalar.h"
#include "arrow/table.h"
#include "arrow/testing/gtest_util.h"
#include "arrow/util/rows_to_batches.h"

namespace arrow::util {

const auto kTestSchema = schema(
{field("field_1", int64()), field("field_2", int64()), field("field_3", int64())});

auto IntConvertor = [](ArrayBuilder& array_builder, int value) {
return static_cast<Int64Builder&>(array_builder).Append(value);
};

bool CompareJson(const arrow::Table& arrow_table, const std::string& json,
const std::string& field_name) {
const auto col = arrow_table.GetColumnByName(field_name);
return arrow::ChunkedArrayFromJSON(col->type(), {json})->Equals(col);
}

TEST(RowsToBatches, BasicUsage) {
std::vector<std::vector<int>> data = {{1, 2, 4}, {5, 6, 7}};
auto batches = RowsToBatches(kTestSchema, data, IntConvertor).ValueOrDie();
auto table = batches->ToTable().ValueOrDie();

EXPECT_TRUE(CompareJson(*table, R"([1, 5])", "field_1"));
EXPECT_TRUE(CompareJson(*table, R"([2, 6])", "field_2"));
EXPECT_TRUE(CompareJson(*table, R"([4, 7])", "field_3"));
}

TEST(RowsToBatches, ConstRange) {
const std::vector<std::vector<int>> data = {{1, 2, 4}, {5, 6, 7}};
auto batches = RowsToBatches(kTestSchema, data, IntConvertor).ValueOrDie();
auto table = batches->ToTable().ValueOrDie();

EXPECT_TRUE(CompareJson(*table, R"([1, 5])", "field_1"));
EXPECT_TRUE(CompareJson(*table, R"([2, 6])", "field_2"));
EXPECT_TRUE(CompareJson(*table, R"([4, 7])", "field_3"));
}

TEST(RowsToBatches, StructAccessor) {
struct TestStruct {
std::vector<int> values;
};
std::vector<TestStruct> data = {TestStruct{{1, 2, 4}}, TestStruct{{5, 6, 7}}};

auto accessor =
[](const TestStruct& s) -> Result<std::reference_wrapper<const std::vector<int>>> {
return std::cref(s.values);
};

auto batches = RowsToBatches(kTestSchema, data, IntConvertor, accessor).ValueOrDie();

auto table = batches->ToTable().ValueOrDie();

EXPECT_TRUE(CompareJson(*table, R"([1, 5])", "field_1"));
EXPECT_TRUE(CompareJson(*table, R"([2, 6])", "field_2"));
EXPECT_TRUE(CompareJson(*table, R"([4, 7])", "field_3"));

// Test accessor that returns by value instead of using `std::reference_wrapper`
auto accessor_by_value = [](const TestStruct& s) -> Result<std::set<int>> {
return std::set(std::begin(s.values), std::end(s.values));
};
auto batches_by_value =
RowsToBatches(kTestSchema, data, IntConvertor, accessor_by_value).ValueOrDie();

auto table_by_value = batches_by_value->ToTable().ValueOrDie();

EXPECT_TRUE(CompareJson(*table_by_value, R"([1, 5])", "field_1"));
EXPECT_TRUE(CompareJson(*table_by_value, R"([2, 6])", "field_2"));
EXPECT_TRUE(CompareJson(*table_by_value, R"([4, 7])", "field_3"));
}

TEST(RowsToBatches, Variant) {
auto VariantConvertor = [](ArrayBuilder& array_builder,
const std::variant<int, std::string>& value) {
if (std::holds_alternative<int>(value))
return dynamic_cast<Int64Builder&>(array_builder).Append(std::get<int>(value));
else
return dynamic_cast<arrow::StringBuilder&>(array_builder)
.Append(std::get<std::string>(value).c_str(),
std::get<std::string>(value).length());
gringasalpastor marked this conversation as resolved.
Show resolved Hide resolved
};

const auto test_schema = schema({field("x", int64()), field("y", utf8())});
std::vector<std::vector<std::variant<int, std::string>>> data = {{1, std::string("2")},
{4, std::string("5")}};

auto batches = RowsToBatches(test_schema, data, VariantConvertor).ValueOrDie();

auto table = batches->ToTable().ValueOrDie();

EXPECT_TRUE(CompareJson(*table, R"([1, 4])", "x"));
EXPECT_TRUE(CompareJson(*table, R"(["2", "5"])", "y"));
}

} // namespace arrow::util