Skip to content

Commit

Permalink
ARROW-16823: [C++] Arrow Substrait enhancements for UDF (#13375)
Browse files Browse the repository at this point in the history
See https://issues.apache.org/jira/browse/ARROW-16823

Authored-by: Yaron Gvili <rtpsw@hotmail.com>
Signed-off-by: Weston Pace <weston.pace@gmail.com>
  • Loading branch information
rtpsw committed Jun 30, 2022
1 parent aadb4fc commit f3bdcce
Show file tree
Hide file tree
Showing 10 changed files with 527 additions and 127 deletions.
4 changes: 1 addition & 3 deletions cpp/src/arrow/compute/registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ class ARROW_EXPORT FunctionRegistry {

/// \brief Check whether a new function options type can be added to the registry.
///
/// \returns Status::KeyError if a function options type with the same name is already
/// \return Status::KeyError if a function options type with the same name is already
/// registered.
Status CanAddFunctionOptionsType(const FunctionOptionsType* options_type,
bool allow_overwrite = false);
Expand Down Expand Up @@ -115,8 +115,6 @@ class ARROW_EXPORT FunctionRegistry {
std::unique_ptr<FunctionRegistryImpl> impl_;

explicit FunctionRegistry(FunctionRegistryImpl* impl);

class NestedFunctionRegistryImpl;
};

/// \brief Return the process-global function registry.
Expand Down
5 changes: 5 additions & 0 deletions cpp/src/arrow/engine/substrait/extension_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,11 @@ struct ExtensionIdRegistryImpl : ExtensionIdRegistry {
return Status::OK();
}

Status RegisterFunction(std::string uri, std::string name,
std::string arrow_function_name) override {
return RegisterFunction({uri, name}, arrow_function_name);
}

// owning storage of uris, names, (arrow::)function_names, types
// note that storing strings like this is safe since references into an
// unordered_set are not invalidated on insertion
Expand Down
4 changes: 4 additions & 0 deletions cpp/src/arrow/engine/substrait/extension_set.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,11 @@ class ARROW_ENGINE_EXPORT ExtensionIdRegistry {
util::string_view arrow_function_name) const = 0;
virtual Status CanRegisterFunction(Id,
const std::string& arrow_function_name) const = 0;
// registers a function without taking ownership of uri and name within Id
virtual Status RegisterFunction(Id, std::string arrow_function_name) = 0;
// registers a function while taking ownership of uri and name
virtual Status RegisterFunction(std::string uri, std::string name,
std::string arrow_function_name) = 0;
};

constexpr util::string_view kArrowExtTypesUri =
Expand Down
3 changes: 3 additions & 0 deletions cpp/src/arrow/engine/substrait/plan_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,9 @@ Status AddExtensionSetToPlan(const ExtensionSet& ext_set, substrait::Plan* plan)

Result<ExtensionSet> GetExtensionSetFromPlan(const substrait::Plan& plan,
const ExtensionIdRegistry* registry) {
if (registry == NULLPTR) {
registry = default_extension_id_registry();
}
std::unordered_map<uint32_t, util::string_view> uris;
uris.reserve(plan.extension_uris_size());
for (const auto& uri : plan.extension_uris()) {
Expand Down
3 changes: 3 additions & 0 deletions cpp/src/arrow/engine/substrait/relation_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,15 @@

#include "arrow/compute/api_scalar.h"
#include "arrow/compute/exec/options.h"
#include "arrow/dataset/file_base.h"
#include "arrow/dataset/file_ipc.h"
#include "arrow/dataset/file_parquet.h"
#include "arrow/dataset/plan.h"
#include "arrow/dataset/scanner.h"
#include "arrow/engine/substrait/expression_internal.h"
#include "arrow/engine/substrait/type_internal.h"
#include "arrow/filesystem/localfs.h"
#include "arrow/filesystem/path_util.h"
#include "arrow/filesystem/util_internal.h"

namespace arrow {
Expand Down Expand Up @@ -66,6 +68,7 @@ Result<compute::Declaration> FromProto(const substrait::Rel& rel,
ARROW_ASSIGN_OR_RAISE(auto base_schema, FromProto(read.base_schema(), ext_set));

auto scan_options = std::make_shared<dataset::ScanOptions>();
scan_options->use_threads = true;

if (read.has_filter()) {
ARROW_ASSIGN_OR_RAISE(scan_options->filter, FromProto(read.filter(), ext_set));
Expand Down
127 changes: 113 additions & 14 deletions cpp/src/arrow/engine/substrait/serde.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,65 @@ Result<compute::Declaration> DeserializeRelation(const Buffer& buf,
return FromProto(rel, ext_set);
}

using DeclarationFactory = std::function<Result<compute::Declaration>(
compute::Declaration, std::vector<std::string> names)>;

namespace {

DeclarationFactory MakeConsumingSinkDeclarationFactory(
const ConsumerFactory& consumer_factory) {
return [&consumer_factory](
compute::Declaration input,
std::vector<std::string> names) -> Result<compute::Declaration> {
std::shared_ptr<compute::SinkNodeConsumer> consumer = consumer_factory();
if (consumer == NULLPTR) {
return Status::Invalid("consumer factory is exhausted");
}
std::shared_ptr<compute::ExecNodeOptions> options =
std::make_shared<compute::ConsumingSinkNodeOptions>(
compute::ConsumingSinkNodeOptions{consumer_factory(), std::move(names)});
return compute::Declaration::Sequence(
{std::move(input), {"consuming_sink", options}});
};
}

compute::Declaration ProjectByNamesDeclaration(compute::Declaration input,
std::vector<std::string> names) {
int names_size = static_cast<int>(names.size());
if (names_size == 0) {
return input;
}
std::vector<compute::Expression> expressions;
for (int i = 0; i < names_size; i++) {
expressions.push_back(compute::field_ref(FieldRef(i)));
}
return compute::Declaration::Sequence(
{std::move(input),
{"project",
compute::ProjectNodeOptions{std::move(expressions), std::move(names)}}});
}

DeclarationFactory MakeWriteDeclarationFactory(
const WriteOptionsFactory& write_options_factory) {
return [&write_options_factory](
compute::Declaration input,
std::vector<std::string> names) -> Result<compute::Declaration> {
std::shared_ptr<dataset::WriteNodeOptions> options = write_options_factory();
if (options == NULLPTR) {
return Status::Invalid("write options factory is exhausted");
}
compute::Declaration projected = ProjectByNamesDeclaration(input, names);
return compute::Declaration::Sequence(
{std::move(projected), {"write", std::move(*options)}});
};
}

Result<std::vector<compute::Declaration>> DeserializePlans(
const Buffer& buf, const ConsumerFactory& consumer_factory,
ExtensionSet* ext_set_out) {
const Buffer& buf, DeclarationFactory declaration_factory,
const ExtensionIdRegistry* registry, ExtensionSet* ext_set_out) {
ARROW_ASSIGN_OR_RAISE(auto plan, ParseFromBuffer<substrait::Plan>(buf));

ARROW_ASSIGN_OR_RAISE(auto ext_set, GetExtensionSetFromPlan(plan));
ARROW_ASSIGN_OR_RAISE(auto ext_set, GetExtensionSetFromPlan(plan, registry));

std::vector<compute::Declaration> sink_decls;
for (const substrait::PlanRel& plan_rel : plan.relations()) {
Expand All @@ -76,12 +129,9 @@ Result<std::vector<compute::Declaration>> DeserializePlans(
names.assign(plan_rel.root().names().begin(), plan_rel.root().names().end());
}

// pipe each relation into a consuming_sink node
auto sink_decl = compute::Declaration::Sequence({
std::move(decl),
{"consuming_sink",
compute::ConsumingSinkNodeOptions{consumer_factory(), std::move(names)}},
});
// pipe each relation
ARROW_ASSIGN_OR_RAISE(auto sink_decl,
declaration_factory(std::move(decl), std::move(names)));
sink_decls.push_back(std::move(sink_decl));
}

Expand All @@ -91,11 +141,26 @@ Result<std::vector<compute::Declaration>> DeserializePlans(
return sink_decls;
}

Result<compute::ExecPlan> DeserializePlan(const Buffer& buf,
const ConsumerFactory& consumer_factory,
ExtensionSet* ext_set_out) {
ARROW_ASSIGN_OR_RAISE(auto declarations,
DeserializePlans(buf, consumer_factory, ext_set_out));
} // namespace

Result<std::vector<compute::Declaration>> DeserializePlans(
const Buffer& buf, const ConsumerFactory& consumer_factory,
const ExtensionIdRegistry* registry, ExtensionSet* ext_set_out) {
return DeserializePlans(buf, MakeConsumingSinkDeclarationFactory(consumer_factory),
registry, ext_set_out);
}

Result<std::vector<compute::Declaration>> DeserializePlans(
const Buffer& buf, const WriteOptionsFactory& write_options_factory,
const ExtensionIdRegistry* registry, ExtensionSet* ext_set_out) {
return DeserializePlans(buf, MakeWriteDeclarationFactory(write_options_factory),
registry, ext_set_out);
}

namespace {

Result<compute::ExecPlan> MakeSingleDeclarationPlan(
std::vector<compute::Declaration> declarations) {
if (declarations.size() > 1) {
return Status::Invalid("DeserializePlan does not support multiple root relations");
} else {
Expand All @@ -105,6 +170,40 @@ Result<compute::ExecPlan> DeserializePlan(const Buffer& buf,
}
}

} // namespace

Result<compute::ExecPlan> DeserializePlan(
const Buffer& buf, const std::shared_ptr<compute::SinkNodeConsumer>& consumer,
const ExtensionIdRegistry* registry, ExtensionSet* ext_set_out) {
bool factory_done = false;
auto single_consumer = [&factory_done, &consumer] {
if (factory_done) {
return std::shared_ptr<compute::SinkNodeConsumer>{};
}
factory_done = true;
return consumer;
};
ARROW_ASSIGN_OR_RAISE(auto declarations,
DeserializePlans(buf, single_consumer, registry, ext_set_out));
return MakeSingleDeclarationPlan(declarations);
}

Result<compute::ExecPlan> DeserializePlan(
const Buffer& buf, const std::shared_ptr<dataset::WriteNodeOptions>& write_options,
const ExtensionIdRegistry* registry, ExtensionSet* ext_set_out) {
bool factory_done = false;
auto single_write_options = [&factory_done, &write_options] {
if (factory_done) {
return std::shared_ptr<dataset::WriteNodeOptions>{};
}
factory_done = true;
return write_options;
};
ARROW_ASSIGN_OR_RAISE(auto declarations, DeserializePlans(buf, single_write_options,
registry, ext_set_out));
return MakeSingleDeclarationPlan(declarations);
}

Result<std::shared_ptr<Schema>> DeserializeSchema(const Buffer& buf,
const ExtensionSet& ext_set) {
ARROW_ASSIGN_OR_RAISE(auto named_struct, ParseFromBuffer<substrait::NamedStruct>(buf));
Expand Down
67 changes: 63 additions & 4 deletions cpp/src/arrow/engine/substrait/serde.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include "arrow/buffer.h"
#include "arrow/compute/exec/exec_plan.h"
#include "arrow/compute/exec/options.h"
#include "arrow/dataset/file_base.h"
#include "arrow/engine/substrait/extension_set.h"
#include "arrow/engine/substrait/visibility.h"
#include "arrow/result.h"
Expand All @@ -40,21 +41,79 @@ using ConsumerFactory = std::function<std::shared_ptr<compute::SinkNodeConsumer>

/// \brief Deserializes a Substrait Plan message to a list of ExecNode declarations
///
/// The output of each top-level Substrait relation will be sent to a caller supplied
/// consumer function provided by consumer_factory
///
/// \param[in] buf a buffer containing the protobuf serialization of a Substrait Plan
/// message
/// \param[in] consumer_factory factory function for generating the node that consumes
/// the batches produced by each toplevel Substrait relation
/// \param[in] registry an extension-id-registry to use, or null for the default one.
/// \param[out] ext_set_out if non-null, the extension mapping used by the Substrait
/// Plan is returned here.
/// \return a vector of ExecNode declarations, one for each toplevel relation in the
/// Substrait Plan
ARROW_ENGINE_EXPORT Result<std::vector<compute::Declaration>> DeserializePlans(
const Buffer& buf, const ConsumerFactory& consumer_factory,
ExtensionSet* ext_set_out = NULLPTR);
const ExtensionIdRegistry* registry = NULLPTR, ExtensionSet* ext_set_out = NULLPTR);

/// \brief Deserializes a single-relation Substrait Plan message to an execution plan
///
/// The output of each top-level Substrait relation will be sent to a caller supplied
/// consumer function provided by consumer_factory
///
/// \param[in] buf a buffer containing the protobuf serialization of a Substrait Plan
/// message
/// \param[in] consumer node that consumes the batches produced by each toplevel Substrait
/// relation
/// \param[in] registry an extension-id-registry to use, or null for the default one.
/// \param[out] ext_set_out if non-null, the extension mapping used by the Substrait
/// Plan is returned here.
/// \return an ExecNode corresponding to the single toplevel relation in the Substrait
/// Plan
Result<compute::ExecPlan> DeserializePlan(
const Buffer& buf, const std::shared_ptr<compute::SinkNodeConsumer>& consumer,
const ExtensionIdRegistry* registry = NULLPTR, ExtensionSet* ext_set_out = NULLPTR);

/// Factory function type for generating the write options of a node consuming the batches
/// produced by each toplevel Substrait relation when deserializing a Substrait Plan.
using WriteOptionsFactory = std::function<std::shared_ptr<dataset::WriteNodeOptions>()>;

/// \brief Deserializes a Substrait Plan message to a list of ExecNode declarations
///
/// The output of each top-level Substrait relation will be written to a filesystem.
/// `write_options_factory` can be used to control write behavior.
///
/// \param[in] buf a buffer containing the protobuf serialization of a Substrait Plan
/// message
/// \param[in] write_options_factory factory function for generating the write options of
/// a node consuming the batches produced by each toplevel Substrait relation
/// \param[in] registry an extension-id-registry to use, or null for the default one.
/// \param[out] ext_set_out if non-null, the extension mapping used by the Substrait
/// Plan is returned here.
/// \return a vector of ExecNode declarations, one for each toplevel relation in the
/// Substrait Plan
ARROW_ENGINE_EXPORT Result<std::vector<compute::Declaration>> DeserializePlans(
const Buffer& buf, const WriteOptionsFactory& write_options_factory,
const ExtensionIdRegistry* registry = NULLPTR, ExtensionSet* ext_set_out = NULLPTR);

Result<compute::ExecPlan> DeserializePlan(const Buffer& buf,
const ConsumerFactory& consumer_factory,
ExtensionSet* ext_set_out = NULLPTR);
/// \brief Deserializes a single-relation Substrait Plan message to an execution plan
///
/// The output of the single Substrait relation will be written to a filesystem.
/// `write_options_factory` can be used to control write behavior.
///
/// \param[in] buf a buffer containing the protobuf serialization of a Substrait Plan
/// message
/// \param[in] write_options write options of a node consuming the batches produced by
/// each toplevel Substrait relation
/// \param[in] registry an extension-id-registry to use, or null for the default one.
/// \param[out] ext_set_out if non-null, the extension mapping used by the Substrait
/// Plan is returned here.
/// \return a vector of ExecNode declarations, one for each toplevel relation in the
/// Substrait Plan
ARROW_ENGINE_EXPORT Result<compute::ExecPlan> DeserializePlan(
const Buffer& buf, const std::shared_ptr<dataset::WriteNodeOptions>& write_options,
const ExtensionIdRegistry* registry = NULLPTR, ExtensionSet* ext_set_out = NULLPTR);

/// \brief Deserializes a Substrait Type message to the corresponding Arrow type
///
Expand Down
Loading

0 comments on commit f3bdcce

Please sign in to comment.