Skip to content
Permalink
Browse files
[C++] Add TypeVisitor example (#166)
* Start visitor example

* Get a working visitor pattern

* Add random data generator example

* Revert makefile change

* Apply suggestions from code review

Co-authored-by: Antoine Pitrou <pitrou@free.fr>

* Merge

* Reorganize examples and add prose

* PR feedback

* Update cpp/code/creating_arrow_objects.cc

Co-authored-by: Antoine Pitrou <pitrou@free.fr>

* Adjust types to be more specific

* Update cpp/code/basic_arrow.cc

Co-authored-by: Antoine Pitrou <pitrou@free.fr>

Co-authored-by: Antoine Pitrou <pitrou@free.fr>
  • Loading branch information
wjones127 and pitrou committed Mar 28, 2022
1 parent 37ac343 commit 72f9c042029aeb28b252c26bb77ca7daf897a05d
Showing 4 changed files with 223 additions and 1 deletion.
@@ -16,6 +16,7 @@
// under the License.

#include <arrow/api.h>
#include <arrow/visit_array_inline.h>
#include <gtest/gtest.h>

#include "common.h"
@@ -63,3 +64,71 @@ arrow::Status ReturnNotOk() {
TEST(BasicArrow, ReturnNotOkNoMacro) { ASSERT_OK(ReturnNotOkMacro()); }

TEST(BasicArrow, ReturnNotOk) { ASSERT_OK(ReturnNotOk()); }

/// \brief Sum numeric values across columns
///
/// Only supports floating point and integral types. Does not support decimals.
class TableSummation {
double partial = 0.0;
public:

arrow::Result<double> Compute(std::shared_ptr<arrow::RecordBatch> batch) {
for (std::shared_ptr<arrow::Array> array : batch->columns()) {
ARROW_RETURN_NOT_OK(arrow::VisitArrayInline(*array, this));
}
return partial;
}

// Default implementation
arrow::Status Visit(const arrow::Array& array) {
return arrow::Status::NotImplemented("Can not compute sum for array of type ",
array.type()->ToString());
}

template <typename ArrayType, typename T = typename ArrayType::TypeClass>
arrow::enable_if_number<T, arrow::Status> Visit(const ArrayType& array) {
for (arrow::util::optional<typename T::c_type> value : array) {
if (value.has_value()) {
partial += static_cast<double>(value.value());
}
}
return arrow::Status::OK();
}
}; // TableSummation

arrow::Status VisitorSummationExample() {
StartRecipe("VisitorSummationExample");
std::shared_ptr<arrow::Schema> schema = arrow::schema({
arrow::field("a", arrow::int32()),
arrow::field("b", arrow::float64()),
});
int32_t num_rows = 3;
std::vector<std::shared_ptr<arrow::Array>> columns;

arrow::Int32Builder a_builder = arrow::Int32Builder();
std::vector<int32_t> a_vals = {1, 2, 3};
ARROW_RETURN_NOT_OK(a_builder.AppendValues(a_vals));
ARROW_ASSIGN_OR_RAISE(auto a_arr, a_builder.Finish());
columns.push_back(a_arr);

arrow::DoubleBuilder b_builder = arrow::DoubleBuilder();
std::vector<double> b_vals = {4.0, 5.0, 6.0};
ARROW_RETURN_NOT_OK(b_builder.AppendValues(b_vals));
ARROW_ASSIGN_OR_RAISE(auto b_arr, b_builder.Finish());
columns.push_back(b_arr);

auto batch = arrow::RecordBatch::Make(schema, num_rows, columns);

// Call
TableSummation summation;
ARROW_ASSIGN_OR_RAISE(auto total, summation.Compute(batch));

rout << "Total is " << total;

EndRecipe("VisitorSummationExample");

EXPECT_EQ(total, 21.0);
return arrow::Status::OK();
}

TEST(BasicArrow, VisitorSummationExample) { ASSERT_OK(VisitorSummationExample()); }
@@ -18,6 +18,8 @@
#include <arrow/api.h>
#include <gtest/gtest.h>

#include <random>

#include "common.h"

arrow::Status CreatingArrays() {
@@ -58,5 +60,92 @@ arrow::Status CreatingArraysPtr() {
return arrow::Status::OK();
}

/// \brief Generate random record batches for a given schema
///
/// For demonstration purposes, this only covers DoubleType and ListType
class RandomBatchGenerator {
public:
std::shared_ptr<arrow::Schema> schema;

RandomBatchGenerator(std::shared_ptr<arrow::Schema> schema) : schema(schema){};

arrow::Result<std::shared_ptr<arrow::RecordBatch>> Generate(int32_t num_rows) {
num_rows_ = num_rows;
for (std::shared_ptr<arrow::Field> field : schema->fields()) {
ARROW_RETURN_NOT_OK(arrow::VisitTypeInline(*field->type(), this));
}

return arrow::RecordBatch::Make(schema, num_rows, arrays_);
}

// Default implementation
arrow::Status Visit(const arrow::DataType& type) {
return arrow::Status::NotImplemented("Generating data for", type.ToString());
}

arrow::Status Visit(const arrow::DoubleType&) {
auto builder = arrow::DoubleBuilder();
std::normal_distribution<> d{/*mean=*/5.0, /*stddev=*/2.0};
for (int32_t i = 0; i < num_rows_; ++i) {
builder.Append(d(gen_));
}
ARROW_ASSIGN_OR_RAISE(auto array, builder.Finish());
arrays_.push_back(array);
return arrow::Status::OK();
}

arrow::Status Visit(const arrow::ListType& type) {
// Generate offsets first, which determines number of values in sub-array
std::poisson_distribution<> d{/*mean=*/4};
auto builder = arrow::Int32Builder();
builder.Append(0);
int32_t last_val = 0;
for (int32_t i = 0; i < num_rows_; ++i) {
last_val += d(gen_);
builder.Append(last_val);
}
ARROW_ASSIGN_OR_RAISE(auto offsets, builder.Finish());

// Since children of list has a new length, will use a new generator
RandomBatchGenerator value_gen(arrow::schema({arrow::field("x", type.value_type())}));
// Last index from the offsets array becomes the length of the sub-array
ARROW_ASSIGN_OR_RAISE(auto inner_batch, value_gen.Generate(last_val));
std::shared_ptr<arrow::Array> values = inner_batch->column(0);

ARROW_ASSIGN_OR_RAISE(auto array,
arrow::ListArray::FromArrays(*offsets.get(), *values.get()));
arrays_.push_back(array);

return arrow::Status::OK();
}

protected:
std::random_device rd_{};
std::mt19937 gen_{rd_()};
std::vector<std::shared_ptr<arrow::Array>> arrays_;
int32_t num_rows_;
}; // RandomBatchGenerator

arrow::Status GenerateRandomData() {
StartRecipe("GenerateRandomData");
std::shared_ptr<arrow::Schema> schema =
arrow::schema({arrow::field("x", arrow::float64()),
arrow::field("y", arrow::list(arrow::float64()))});

RandomBatchGenerator generator(schema);
ARROW_ASSIGN_OR_RAISE(std::shared_ptr<arrow::RecordBatch> batch, generator.Generate(5));

rout << "Created batch: \n" << batch->ToString();

// Consider using ValidateFull to check correctness
ARROW_RETURN_NOT_OK(batch->ValidateFull());

EndRecipe("GenerateRandomData");
EXPECT_EQ(batch->num_rows(), 5);

return arrow::Status::OK();
}

TEST(CreatingArrowObjects, CreatingArraysTest) { ASSERT_OK(CreatingArrays()); }
TEST(CreatingArrowObjects, CreatingArraysPtrTest) { ASSERT_OK(CreatingArraysPtr()); }
TEST(CreatingArrowObjects, GeneratingRandomData) { ASSERT_OK(GenerateRandomData()); }
@@ -48,3 +48,45 @@ boilerplate for you. It will run the contained expression and check the resulti
.. recipe:: ../code/basic_arrow.cc ReturnNotOk
:caption: Using ARROW_RETURN_NOT_OK to check the status
:dedent: 2


Using the Visitor Pattern
=========================

Arrow classes :cpp:class:`arrow::DataType`, :cpp:class:`arrow::Scalar`, and
:cpp:class:`arrow::Array` have specialized subclasses for each Arrow type. In
order to specialize logic for each subclass, you can use the visitor pattern.
Arrow provides inline template functions that allow you to call visitors
efficiently:

* :cpp:func:`arrow::VisitTypeInline`
* :cpp:func:`arrow::VisitScalarInline`
* :cpp:func:`arrow::VisitArrayInline`

Generate Random Data
--------------------

See example at :ref:`Generate Random Data Example`.


Generalize Computations Across Arrow Types
------------------------------------------

Array visitors can be useful when writing functions that can handle multiple
array types. However, implementing a visitor for each type individually can be
excessively verbose. Fortunately, Arrow provides type traits that allow you to
write templated functions to handle subsets of types. The example below
demonstrates a table sum function that can handle any integer or floating point
array with only a single visitor implementation by leveraging
:cpp:type:`arrow::enable_if_number`.

.. literalinclude:: ../code/basic_arrow.cc
:language: cpp
:linenos:
:start-at: class TableSummation
:end-at: }; // TableSummation
:caption: Using visitor pattern that can compute sum of table with any numeric type


.. recipe:: ../code/basic_arrow.cc VisitorSummationExample
:dedent: 2
@@ -47,4 +47,26 @@ Builders can also consume standard C++ containers:
.. note::

Builders will not take ownership of data in containers and will make a
copy of the underlying data.
copy of the underlying data.

.. _Generate Random Data Example:

Generate Random Data for a Given Schema
=======================================

To generate random data for a given schema, implementing a type visitor is a
good idea. The following example only implements double arrays and list arrays,
but could be easily extended to all types.


.. literalinclude:: ../code/creating_arrow_objects.cc
:language: cpp
:linenos:
:start-at: class RandomBatchGenerator
:end-at: }; // RandomBatchGenerator
:caption: Using visitor pattern to generate random record batches

Given such a generator, you can create random test data for any supported schema:

.. recipe:: ../code/creating_arrow_objects.cc GenerateRandomData
:dedent: 2

0 comments on commit 72f9c04

Please sign in to comment.