diff --git a/.gitmodules b/.gitmodules index 81fceeb..19f90c2 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,9 +1,11 @@ [submodule "duckdb"] path = duckdb - url = https://github.com/duckdb/duckdb + url = https://github.com/drin/duckdb + branch = coop-decomp [submodule "duckdb-r"] path = duckdb-r url = https://github.com/duckdb/duckdb-r [submodule "substrait"] path = substrait - url = https://github.com/substrait-io/substrait + url = https://github.com/drin/substrait + branch = mohair diff --git a/CMakeLists.txt b/CMakeLists.txt index f180c46..668d080 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,4 +1,4 @@ -cmake_minimum_required(VERSION 2.8.12) +cmake_minimum_required(VERSION 3.25.1) # Set extension name here set(TARGET_NAME substrait) @@ -94,12 +94,27 @@ set(SUBSTRAIT_SOURCES third_party/substrait/substrait/type_expressions.pb.cc third_party/substrait/substrait/extensions/extensions.pb.cc) -set(EXTENSION_SOURCES - src/to_substrait.cpp +# custom sources for mohair integration +set(MOHAIR_SOURCES + src/plans.cpp + src/engine_duckdb.cpp + src/translation/duckdb_expressions.cpp + src/translation/duckdb_operators.cpp) + #src/transpilation/duckdb_expressions.cpp + #src/transpilation/duckdb_operators.cpp) + +# official sources for substrait integration +set(SUBSTRAIT_EXT_SOURCES src/from_substrait.cpp + src/to_substrait.cpp) + +# primary sources first, then others +set(EXTENSION_SOURCES src/substrait_extension.cpp src/custom_extensions.cpp src/custom_extensions_generated.cpp + ${MOHAIR_SOURCES} + ${SUBSTRAIT_EXT_SOURCES} ${SUBSTRAIT_SOURCES} ${PROTOBUF_SOURCES}) diff --git a/duckdb b/duckdb index 0e78476..8e26647 160000 --- a/duckdb +++ b/duckdb @@ -1 +1 @@ -Subproject commit 0e784765f6f87bd1ce9034afcce1e7f89fcd8777 +Subproject commit 8e266478d19636d381a328fce8551ada4c70e993 diff --git a/src/engine_duckdb.cpp b/src/engine_duckdb.cpp new file mode 100644 index 0000000..2d312c9 --- /dev/null +++ b/src/engine_duckdb.cpp @@ -0,0 +1,190 @@ +// ------------------------------ +// License +// +// Copyright 2024 Aldrin Montana +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + + +// ------------------------------ +// Dependencies + +#include "engine_duckdb.hpp" + + +// ------------------------------ +// Functions + +// >> DuckDB-specific function renaming and validation + +namespace duckdb { + + // >> Static data and related functions for mapping functions from substrait -> duckdb + + static FunctionRenameMap engine_remapped_functions { + {"modulus" , "mod" } + ,{"std_dev" , "stddev" } + ,{"starts_with", "prefix" } + ,{"ends_with" , "suffix" } + ,{"substring" , "substr" } + ,{"char_length", "length" } + ,{"is_nan" , "isnan" } + ,{"is_finite" , "isfinite" } + ,{"is_infinite", "isinf" } + ,{"like" , "~~" } + ,{"extract" , "date_part"} + }; + + string RemoveExtension(string &function_name) { + string name; + + for (auto &c : function_name) { + if (c == ':') { break; } + name += c; + } + + return name; + } + + string RemapFunctionName(string &function_name) { + string name { RemoveExtension(function_name) }; + + auto it = engine_remapped_functions.find(name); + if (it != engine_remapped_functions.end()) { name = it->second; } + + return name; + } + + + // >> Static data and related functions for extraction of date subfields + + static case_insensitive_set_t engine_date_subfields { + "year" , "month" , "day" + ,"decade" , "century" , "millenium" + ,"quarter" + ,"microsecond", "milliseconds", "second" + ,"minute" , "hour" + }; + + void AssertValidDateSubfield(const string& subfield) { + D_ASSERT(engine_date_subfields.count(subfield)); + } + +} // namespace: duckdb + + +namespace duckdb { + + //! Constructor for DuckDBTranslator + DuckDBTranslator::DuckDBTranslator(ClientContext &ctxt): context(ctxt) { + t_conn = make_uniq(*ctxt.db); + functions_map = make_uniq(); + + // create an http state, but I don't know what this is for + auto http_state = HTTPState::TryGetState(*(t_conn->context)); + http_state->Reset(); + } + + // >> Entry points for substrait plan (json or binary) -> duckdb logical plan + shared_ptr + DuckDBTranslator::TranspilePlanMessage(shared_ptr sys_plan) { + shared_ptr plan_rel = sys_plan->engine; + + // Transform Relation to QueryNode and wrap in a SQLStatement + auto plan_wrapper = make_uniq(); + plan_wrapper->node = plan_rel->GetQueryNode(); + + // Create a planner to go from SQLStatement -> LogicalOperator + Planner planner { context }; + planner.CreatePlan(std::move(plan_wrapper)); + shared_ptr logical_plan { std::move(planner.plan) }; + + return make_shared(sys_plan->substrait, logical_plan); + } + + shared_ptr + DuckDBTranslator::TranslateLogicalPlan( shared_ptr engine_plan + ,bool optimize) { + // Make a copy that is a unique_ptr + auto logical_plan = engine_plan->engine->Copy(context); + + // optimization + if (optimize) { + shared_ptr binder { Binder::CreateBinder(context) }; + Optimizer optimizer { *binder, context }; + + logical_plan = optimizer.Optimize(std::move(logical_plan)); + } + + // transformation to physical plan + PhysicalPlanGenerator physical_planner { context }; + shared_ptr physical_plan { + physical_planner.CreatePlan(std::move(logical_plan)) + }; + + return make_shared(engine_plan->substrait, physical_plan); + } + + bool ShouldKeepExecuting(PendingExecutionResult& exec_result) { + switch (exec_result) { + case PendingExecutionResult::RESULT_NOT_READY: + case PendingExecutionResult::RESULT_READY: + break; + + case PendingExecutionResult::BLOCKED: + std::cout << "\t[Executor]: blocked" << std::endl; + break; + + case PendingExecutionResult::NO_TASKS_AVAILABLE: + std::cout << "\t[Executor]: waiting for tasks" << std::endl; + break; + + case PendingExecutionResult::EXECUTION_ERROR: + std::cerr << "\t[Executor]: execution error" << std::endl; + return false; + + default: + std::cerr << "\t[Executor]: unknown execution result type" << std::endl; + return false; + } + + return true; + } + + unique_ptr DuckDBExecutor::Execute() { + constexpr bool dry_run { false }; + Executor plan_executor { context }; + + plan_executor.Initialize( + PhysicalResultCollector::GetResultCollector(context, plan_data) + ); + + auto exec_result = plan_executor.ExecuteTask(dry_run); + while (exec_result != PendingExecutionResult::RESULT_READY) { + if (not ShouldKeepExecuting(exec_result)) { + std::cerr << "\t\t" << plan_executor.GetError().Message() << std::endl; + break; + } + + exec_result = plan_executor.ExecuteTask(dry_run); + } + + if ( exec_result == PendingExecutionResult::RESULT_READY + and plan_executor.HasResultCollector()) { + return std::move(plan_executor.GetResult()); + } + + return nullptr; + } + +} // namespace: duckdb diff --git a/src/include/custom_extensions/custom_extensions.hpp b/src/include/custom_extensions/custom_extensions.hpp index 3bd1062..802efa2 100644 --- a/src/include/custom_extensions/custom_extensions.hpp +++ b/src/include/custom_extensions/custom_extensions.hpp @@ -6,65 +6,103 @@ // //===----------------------------------------------------------------------===// + +// ------------------------------ +// Dependencies #pragma once -#include "duckdb/common/types/hash.hpp" -#include #include +#include + +#include "duckdb/common/types/hash.hpp" + + +// ------------------------------ +// Dependencies namespace duckdb { -struct SubstraitCustomFunction { -public: - SubstraitCustomFunction(string name_p, vector arg_types_p) - : name(std::move(name_p)), arg_types(std::move(arg_types_p)) {}; - - SubstraitCustomFunction() = default; - bool operator==(const SubstraitCustomFunction &other) const { - return name == other.name && arg_types == other.arg_types; - } - string GetName(); - string name; - vector arg_types; -}; -//! Here we define function extensions -class SubstraitFunctionExtensions { -public: - SubstraitFunctionExtensions(SubstraitCustomFunction function_p, string extension_path_p) - : function(std::move(function_p)), extension_path(std::move(extension_path_p)) {}; - SubstraitFunctionExtensions() = default; - - string GetExtensionURI(); - bool IsNative(); - - SubstraitCustomFunction function; - string extension_path; -}; - -struct HashSubstraitFunctions { - size_t operator()(SubstraitCustomFunction const &custom_function) const noexcept { - // Hash Name - auto hash_name = Hash(custom_function.name.c_str()); - // Hash Input Types - auto &i_types = custom_function.arg_types; - auto hash_type = Hash(i_types[0].c_str()); - for (idx_t i = 1; i < i_types.size(); i++) { - hash_type = CombineHash(hash_type, Hash(i_types[i].c_str())); - } - // Combine name and inputs - return CombineHash(hash_name, hash_type); - } -}; - -class SubstraitCustomFunctions { -public: - SubstraitCustomFunctions(); - SubstraitFunctionExtensions Get(const string &name, const vector<::substrait::Type> &types) const; - void Initialize(); - -private: - std::unordered_map custom_functions; - void InsertCustomFunction(string name_p, vector types_p, string file_path); -}; - -} // namespace duckdb \ No newline at end of file + //! Class to describe a custom, individual substrait function + struct SubstraitCustomFunction { + + // Constructors + SubstraitCustomFunction() = default; + SubstraitCustomFunction(string name_p, vector arg_types_p) + : name(std::move(name_p)), arg_types(std::move(arg_types_p)) {} + + // Functions + string GetName(); + bool operator==(const SubstraitCustomFunction &other) const { + return name == other.name && arg_types == other.arg_types; + } + + // Attributes + string name; + vector arg_types; + }; + + + //! Class to describe a substrait function extension + struct SubstraitFunctionExtensions { + + // Constructors + SubstraitFunctionExtensions() = default; + SubstraitFunctionExtensions( SubstraitCustomFunction function_p + ,string extension_path_p) + : function (std::move(function_p)) + ,extension_path(std::move(extension_path_p)) {} + + // Functions + bool IsNative(); + string GetExtensionURI(); + + // Attributes + SubstraitCustomFunction function; + string extension_path; + }; + + + //! Hash function for a custom substrait function based on its signature. + struct HashSubstraitFunctions { + size_t operator()(SubstraitCustomFunction const &custom_function) const noexcept { + auto& fn_argtypes = custom_function.arg_types; + + // Hash the type name for each function argument + auto hashed_argtypes = Hash(fn_argtypes[0].c_str()); + for (idx_t arg_ndx = 1; arg_ndx < fn_argtypes.size(); arg_ndx++) { + hashed_argtypes = CombineHash( + hashed_argtypes + ,Hash(fn_argtypes[arg_ndx].c_str()) + ); + } + + // Combine hashes for the function name and its argument types + auto hashed_name = Hash(custom_function.name.c_str()); + return CombineHash(hashed_name, hashed_argtypes); + } + }; + + + //! Class that contains custom substrait functions and function extensions + struct SubstraitCustomFunctions { + // type aliases for convenience + using SubstraitTypeVec = vector<::substrait::Type>; + using SubstraitFnMap = std::unordered_map< SubstraitCustomFunction + ,SubstraitFunctionExtensions + ,HashSubstraitFunctions >; + + // Constructors + SubstraitCustomFunctions(); + + // Functions + void Initialize(); + SubstraitFunctionExtensions Get( const string& name + ,const SubstraitTypeVec& types) const; + + private: + SubstraitFnMap custom_functions; + + void InsertCustomFunction(string name_p, vector types_p, string file_path); + }; + +} // namespace duckdb diff --git a/src/include/engine_duckdb.hpp b/src/include/engine_duckdb.hpp new file mode 100644 index 0000000..d2a1b87 --- /dev/null +++ b/src/include/engine_duckdb.hpp @@ -0,0 +1,206 @@ +// ------------------------------ +// License +// +// Copyright 2024 Aldrin Montana +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + + +// ------------------------------ +// Dependencies +#pragma once + +#include +#include +#include // for debugging + +#include "duckdb/common/exception.hpp" +#include "duckdb/common/helper.hpp" +#include "duckdb/common/http_state.hpp" +#include "duckdb/common/types.hpp" +#include "duckdb/common/types/value.hpp" +#include "duckdb/common/enums/set_operation_type.hpp" + +#include "duckdb/main/connection.hpp" +#include "duckdb/main/client_data.hpp" +#include "duckdb/main/prepared_statement_data.hpp" + +#include "duckdb/main/relation/join_relation.hpp" +#include "duckdb/main/relation/cross_product_relation.hpp" +#include "duckdb/main/relation/limit_relation.hpp" +#include "duckdb/main/relation/projection_relation.hpp" +#include "duckdb/main/relation/setop_relation.hpp" +#include "duckdb/main/relation/aggregate_relation.hpp" +#include "duckdb/main/relation/filter_relation.hpp" +#include "duckdb/main/relation/order_relation.hpp" + +#include "duckdb/parser/parser.hpp" +#include "duckdb/parser/expression/list.hpp" +#include "duckdb/parser/expression/comparison_expression.hpp" + +#include "duckdb/planner/planner.hpp" +#include "duckdb/planner/bound_result_modifier.hpp" +#include "duckdb/planner/expression.hpp" +#include "duckdb/planner/joinside.hpp" +#include "duckdb/planner/logical_operator.hpp" +#include "duckdb/planner/table_filter.hpp" + +#include "duckdb/optimizer/optimizer.hpp" + +#include "duckdb/function/table_function.hpp" +#include "duckdb/execution/operator/helper/physical_result_collector.hpp" + +#include "custom_extensions/custom_extensions.hpp" +#include "plans.hpp" + + +// ------------------------------ +// Convenience aliases + +using duckdb::unique_ptr; +using std::string; + +using FunctionRenameMap = std::unordered_map; + +using DuckSystemPlan = mohair::SystemPlan; +using DuckLogicalPlan = mohair::SystemPlan; +using DuckPhysicalPlan = mohair::SystemPlan; + + +// ------------------------------ +// Convenience functions + +namespace duckdb { + + //! Remove extension id from a function name + string RemoveExtension(string &function_name); + + //! Return the duckdb function name for the given substrait function name + string RemapFunctionName(string &function_name); + + //! Validate the subfield name is supported by duckdb date types + void AssertValidDateSubfield(const string& subfield); + + //! Convenience function to convert a substrait type to a logical duckdb type + LogicalType SubstraitToDuckType(const substrait::Type& s_type); + +} // namespace: duckdb + + +// ------------------------------ +// Data Classes for binding table functions + +namespace duckdb { + + // Forward declaration to keep TableFunctionData types visible + struct DuckDBTranslator; + + struct FnDataSubstraitTranslation : public TableFunctionData { + FnDataSubstraitTranslation() = default; + + unique_ptr translator; + shared_ptr sys_plan; + shared_ptr engine_plan; + shared_ptr exec_plan; + shared_ptr plan_data; + bool enable_optimizer; + bool finished { false }; + }; + + + struct FnDataSubstraitExecution : public TableFunctionData { + FnDataSubstraitExecution() = default; + + unique_ptr translator; + unique_ptr sys_plan; + unique_ptr result; + }; + +} // namespace: duckdb + + +// ------------------------------ +// Classes for engine-specific translations + +namespace duckdb { + + + //! Translator from substrait plan to DuckDB plans + struct DuckDBTranslator { + + //! Initializes a Translator instance + DuckDBTranslator(ClientContext& context); + + //! Transforms Substrait Plan to DuckDB Relation + unique_ptr TranslatePlanMessage(const string& serialized_msg); + unique_ptr TranslatePlanJson(const string& json_msg); + + //! Transforms DuckDB Relation to DuckDB Logical Operator + shared_ptr TranspilePlanMessage(shared_ptr sys_plan); + + //! Transforms DuckDB Relation to DuckDB Physical Operator + shared_ptr TranslateLogicalPlan( shared_ptr engine_plan + ,bool optimize); + + private: + ClientContext& context; + unique_ptr t_conn; + unique_ptr functions_map; + + // >> Internal functions + private: + shared_ptr TranslateRootOp(const substrait::RelRoot& sop); + shared_ptr TranslateOp (const substrait::Rel& sop); + + //! Translate a substrait expression to a duckdb expression + unique_ptr TranslateExpr(const substrait::Expression &sexpr); + + // >> Internal translation functions for operators + // NOTE: these member methods eventually use t_conn and functions_map + shared_ptr TranslateJoinOp (const substrait::JoinRel& sjoin); + shared_ptr TranslateCrossProductOp (const substrait::CrossRel& scross); + shared_ptr TranslateFetchOp (const substrait::FetchRel& slimit); + shared_ptr TranslateFilterOp (const substrait::FilterRel& sfilter); + shared_ptr TranslateProjectOp (const substrait::ProjectRel& sproj); + shared_ptr TranslateAggregateOp (const substrait::AggregateRel& saggr); + shared_ptr TranslateReadOp (const substrait::ReadRel& sget); + shared_ptr TranslateSortOp (const substrait::SortRel& ssort); + shared_ptr TranslateSetOp (const substrait::SetRel& sset); + + //! Translate Substrait Sort Order to DuckDB Order + OrderByNode TranslateOrder(const substrait::SortField& sordf); + + // >> Internal translation functions for expressions + unique_ptr TranslateSelectionExpr(const substrait::Expression& sexpr); + unique_ptr TranslateIfThenExpr (const substrait::Expression& sexpr); + unique_ptr TranslateCastExpr (const substrait::Expression& sexpr); + unique_ptr TranslateInExpr (const substrait::Expression& sexpr); + + unique_ptr + TranslateLiteralExpr(const substrait::Expression::Literal& slit); + + unique_ptr + TranslateScalarFunctionExpr(const substrait::Expression& sexpr); + }; + + struct DuckDBExecutor { + ClientContext& context; + PreparedStatementData& plan_data; + + DuckDBExecutor(ClientContext& context, PreparedStatementData& plan_data) + : context(context), plan_data(plan_data) {} + + unique_ptr Execute(); + }; + +} // namespace duckdb diff --git a/src/include/plans.hpp b/src/include/plans.hpp new file mode 100644 index 0000000..5e14775 --- /dev/null +++ b/src/include/plans.hpp @@ -0,0 +1,101 @@ +// ------------------------------ +// License +// +// Copyright 2024 Aldrin Montana +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + + +// ------------------------------ +// Dependencies +#pragma once + +#include +#include + +#include "duckdb.hpp" +#include "google/protobuf/util/json_util.h" + +#include "substrait/plan.pb.h" +#include "substrait/algebra.pb.h" + + +// ------------------------------ +// Macros and Type Aliases + +// Standard types +using std::string; +using std::unordered_map; + +// protobuf types and functions from duckdb namespace +using duckdb::unique_ptr; +using duckdb::shared_ptr; +using duckdb::google::protobuf::util::Status; +using duckdb::google::protobuf::util::JsonStringToMessage; + + +// ------------------------------ +// Classes and structs + +namespace mohair { + + //! Alias template for transpilation functions + template + using TranspileSysPlanFnType = std::function; + + template + using TranspileSubPlanFnType = std::function; + + //! Templated class to hold a substrait plan and a "system plan". + /** The system plan is flexibly represented by the templated type so that any particular + * system can translate to the preferred level of abstraction. + */ + template + struct SystemPlan { + //! A system-level query plan represented as substrait + shared_ptr substrait; + + //! A system-level query plan represented by a specific query engine + shared_ptr engine; + + + SystemPlan( shared_ptr s_plan + ,shared_ptr e_plan) + : substrait(s_plan), engine(e_plan) {} + + + // TODO: + // arrow::Table Execute(); + }; + + //! Builder function that constructs SystemPlan from a serialized substrait message + unique_ptr SubstraitPlanFromSubstraitMessage(const string& serialized_msg); + + //! Builder function that constructs SystemPlan from a JSON-formatted substrait message + unique_ptr SubstraitPlanFromSubstraitJson(const string& json_msg); + +} // namespace: mohair + + +// >> Code for managing substrait function extensions. Likely to move in the future. +namespace mohair { + + struct SubstraitFunctionMap { + //! Registry of substrait function extensions used in substrait plan + unordered_map fn_map; + + void RegisterExtensionFunctions(substrait::Plan& plan); + string FindExtensionFunction(uint64_t id); + }; + +} // namespace: mohair diff --git a/src/plans.cpp b/src/plans.cpp new file mode 100644 index 0000000..3c65285 --- /dev/null +++ b/src/plans.cpp @@ -0,0 +1,83 @@ +// ------------------------------ +// License +// +// Copyright 2024 Aldrin Montana +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + + +// ------------------------------ +// Dependencies + +#include "plans.hpp" + + +// ------------------------------ +// Macros and Type Aliases + + +// ------------------------------ +// Classes and structs + +namespace mohair { + + // ------------------------------ + // Methods for SubstraitFunctionMap + + //! Register extension functions from the substrait plan in the function map + void SubstraitFunctionMap::RegisterExtensionFunctions(substrait::Plan& plan) { + for (auto &sext : plan.extensions()) { + if (!sext.has_extension_function()) { continue; } + + const auto fn_anchor = sext.extension_function().function_anchor(); + fn_map[fn_anchor] = sext.extension_function().name(); + } + } + + string SubstraitFunctionMap::FindExtensionFunction(uint64_t id) { + if (fn_map.find(id) == fn_map.end()) { + throw duckdb::InternalException( + "Could not find aggregate function " + std::to_string(id) + ); + } + + return fn_map[id]; + } + + + // ------------------------------ + // Builder Functions for SystemPlan + + //! Builder function that constructs SystemPlan from a serialized substrait message + unique_ptr SubstraitPlanFromSubstraitMessage(const string& serialized_msg) { + auto plan = duckdb::make_uniq(); + if (not plan->ParseFromString(serialized_msg)) { + throw std::runtime_error("Error parsing serialized Substrait Plan"); + } + + return plan; + } + + //! Builder function that constructs SystemPlan from a JSON-formatted substrait message + unique_ptr SubstraitPlanFromSubstraitJson(const string& json_msg) { + auto plan = duckdb::make_uniq(); + + Status status = JsonStringToMessage(json_msg, plan.get()); + if (not status.ok()) { + throw std::runtime_error("Error parsing JSON Substrait Plan: " + status.ToString()); + } + + return plan; + } + +} // namespace: mohair diff --git a/src/substrait_extension.cpp b/src/substrait_extension.cpp index fae645c..2946093 100644 --- a/src/substrait_extension.cpp +++ b/src/substrait_extension.cpp @@ -3,6 +3,8 @@ #include "from_substrait.hpp" #include "substrait_extension.hpp" #include "to_substrait.hpp" +#include "plans.hpp" +#include "engine_duckdb.hpp" #ifndef DUCKDB_AMALGAMATION #include "duckdb/common/enums/optimizer_type.hpp" @@ -16,309 +18,487 @@ namespace duckdb { -struct ToSubstraitFunctionData : public TableFunctionData { - ToSubstraitFunctionData() { - } - string query; - bool enable_optimizer; - bool finished = false; -}; - -static void ToJsonFunctionInternal(ClientContext &context, ToSubstraitFunctionData &data, DataChunk &output, - Connection &new_conn, unique_ptr &query_plan, string &serialized); -static void ToSubFunctionInternal(ClientContext &context, ToSubstraitFunctionData &data, DataChunk &output, - Connection &new_conn, unique_ptr &query_plan, string &serialized); - -static void VerifyJSONRoundtrip(unique_ptr &query_plan, Connection &con, ToSubstraitFunctionData &data, - const string &serialized); -static void VerifyBlobRoundtrip(unique_ptr &query_plan, Connection &con, ToSubstraitFunctionData &data, - const string &serialized); - -static bool SetOptimizationOption(const ClientConfig &config, const duckdb::named_parameter_map_t &named_params) { - for (const auto ¶m : named_params) { - auto loption = StringUtil::Lower(param.first); - // If the user has explicitly requested to enable/disable the optimizer when - // generating Substrait, then that takes precedence. - if (loption == "enable_optimizer") { - return BooleanValue::Get(param.second); - } - } - - // If the user has not specified what they want, fall back to the settings - // on the connection (e.g. if the optimizer was disabled by the user at - // the connection level, it would be surprising to enable the optimizer - // when generating Substrait). - return config.enable_optimizer; -} - -static unique_ptr InitToSubstraitFunctionData(const ClientConfig &config, - TableFunctionBindInput &input) { - auto result = make_uniq(); - result->query = input.inputs[0].ToString(); - result->enable_optimizer = SetOptimizationOption(config, input.named_parameters); - return std::move(result); -} - -static unique_ptr ToSubstraitBind(ClientContext &context, TableFunctionBindInput &input, - vector &return_types, vector &names) { - return_types.emplace_back(LogicalType::BLOB); - names.emplace_back("Plan Blob"); - auto result = InitToSubstraitFunctionData(context.config, input); - return std::move(result); -} - -static unique_ptr ToJsonBind(ClientContext &context, TableFunctionBindInput &input, - vector &return_types, vector &names) { - return_types.emplace_back(LogicalType::VARCHAR); - names.emplace_back("Json"); - auto result = InitToSubstraitFunctionData(context.config, input); - return std::move(result); -} - -shared_ptr SubstraitPlanToDuckDBRel(Connection &conn, const string &serialized, bool json = false) { - SubstraitToDuckDB transformer_s2d(conn, serialized, json); - return transformer_s2d.TransformPlan(); -} - -static void VerifySubstraitRoundtrip(unique_ptr &query_plan, Connection &con, - ToSubstraitFunctionData &data, const string &serialized, bool is_json) { - // We round-trip the generated json and verify if the result is the same - auto actual_result = con.Query(data.query); - - auto sub_relation = SubstraitPlanToDuckDBRel(con, serialized, is_json); - auto substrait_result = sub_relation->Execute(); - substrait_result->names = actual_result->names; - unique_ptr substrait_materialized; - - if (substrait_result->type == QueryResultType::STREAM_RESULT) { - auto &stream_query = substrait_result->Cast(); - - substrait_materialized = stream_query.Materialize(); - } else if (substrait_result->type == QueryResultType::MATERIALIZED_RESULT) { - substrait_materialized = unique_ptr_cast(std::move(substrait_result)); - } - auto actual_col_coll = actual_result->Collection(); - auto subs_col_coll = substrait_materialized->Collection(); - string error_message; - if (!ColumnDataCollection::ResultEquals(actual_col_coll, subs_col_coll, error_message)) { - query_plan->Print(); - sub_relation->Print(); - throw InternalException("The query result of DuckDB's query plan does not match Substrait : " + error_message); - } -} - -static void VerifyBlobRoundtrip(unique_ptr &query_plan, Connection &con, ToSubstraitFunctionData &data, - const string &serialized) { - VerifySubstraitRoundtrip(query_plan, con, data, serialized, false); -} - -static void VerifyJSONRoundtrip(unique_ptr &query_plan, Connection &con, ToSubstraitFunctionData &data, - const string &serialized) { - VerifySubstraitRoundtrip(query_plan, con, data, serialized, true); -} - -static DuckDBToSubstrait InitPlanExtractor(ClientContext &context, ToSubstraitFunctionData &data, Connection &new_conn, - unique_ptr &query_plan) { - // The user might want to disable the optimizer of the new connection - new_conn.context->config.enable_optimizer = data.enable_optimizer; - new_conn.context->config.use_replacement_scans = false; - - // We want for sure to disable the internal compression optimizations. - // These are DuckDB specific, no other system implements these. Also, - // respect the user's settings if they chose to disable any specific optimizers. - // - // The InClauseRewriter optimization converts large `IN` clauses to a - // "mark join" against a `ColumnDataCollection`, which may not make - // sense in other systems and would complicate the conversion to Substrait. - set disabled_optimizers = DBConfig::GetConfig(context).options.disabled_optimizers; - disabled_optimizers.insert(OptimizerType::IN_CLAUSE); - disabled_optimizers.insert(OptimizerType::COMPRESSED_MATERIALIZATION); - DBConfig::GetConfig(*new_conn.context).options.disabled_optimizers = disabled_optimizers; - - query_plan = new_conn.context->ExtractPlan(data.query); - return DuckDBToSubstrait(context, *query_plan); -} - -static void ToSubFunctionInternal(ClientContext &context, ToSubstraitFunctionData &data, DataChunk &output, - Connection &new_conn, unique_ptr &query_plan, string &serialized) { - output.SetCardinality(1); - auto transformer_d2s = InitPlanExtractor(context, data, new_conn, query_plan); - serialized = transformer_d2s.SerializeToString(); - output.SetValue(0, 0, Value::BLOB_RAW(serialized)); -} - -static void ToSubFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { - auto &data = (ToSubstraitFunctionData &)*data_p.bind_data; - if (data.finished) { - return; - } - auto new_conn = Connection(*context.db); - - unique_ptr query_plan; - string serialized; - ToSubFunctionInternal(context, data, output, new_conn, query_plan, serialized); - - data.finished = true; - - if (!context.config.query_verification_enabled) { - return; - } - VerifyBlobRoundtrip(query_plan, new_conn, data, serialized); - // Also run the ToJson path and verify round-trip for that - DataChunk other_output; - other_output.Initialize(context, {LogicalType::VARCHAR}); - ToJsonFunctionInternal(context, data, other_output, new_conn, query_plan, serialized); - VerifyJSONRoundtrip(query_plan, new_conn, data, serialized); -} - -static void ToJsonFunctionInternal(ClientContext &context, ToSubstraitFunctionData &data, DataChunk &output, - Connection &new_conn, unique_ptr &query_plan, string &serialized) { - output.SetCardinality(1); - auto transformer_d2s = InitPlanExtractor(context, data, new_conn, query_plan); - serialized = transformer_d2s.SerializeToJson(); - output.SetValue(0, 0, serialized); -} - -static void ToJsonFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { - auto &data = (ToSubstraitFunctionData &)*data_p.bind_data; - if (data.finished) { - return; - } - auto new_conn = Connection(*context.db); - - unique_ptr query_plan; - string serialized; - ToJsonFunctionInternal(context, data, output, new_conn, query_plan, serialized); - - data.finished = true; - - if (!context.config.query_verification_enabled) { - return; - } - VerifyJSONRoundtrip(query_plan, new_conn, data, serialized); - // Also run the ToJson path and verify round-trip for that - DataChunk other_output; - other_output.Initialize(context, {LogicalType::BLOB}); - ToSubFunctionInternal(context, data, other_output, new_conn, query_plan, serialized); - VerifyBlobRoundtrip(query_plan, new_conn, data, serialized); -} - -struct FromSubstraitFunctionData : public TableFunctionData { - FromSubstraitFunctionData() = default; - shared_ptr plan; - unique_ptr res; - unique_ptr conn; -}; - -static unique_ptr SubstraitBind(ClientContext &context, TableFunctionBindInput &input, - vector &return_types, vector &names, bool is_json) { - auto result = make_uniq(); - result->conn = make_uniq(*context.db); - if (input.inputs[0].IsNull()) { - throw BinderException("from_substrait cannot be called with a NULL parameter"); - } - string serialized = input.inputs[0].GetValueUnsafe(); - result->plan = SubstraitPlanToDuckDBRel(*result->conn, serialized, is_json); - for (auto &column : result->plan->Columns()) { - return_types.emplace_back(column.Type()); - names.emplace_back(column.Name()); - } - return std::move(result); -} - -static unique_ptr FromSubstraitBind(ClientContext &context, TableFunctionBindInput &input, + // Containers for bound data received from a Table Function + struct ToSubstraitFunctionData : public TableFunctionData { + ToSubstraitFunctionData() {} + + string query; + bool enable_optimizer; + bool finished { false }; + }; + + // Helper functions for `to_substrait` table function + static void ToJsonFunctionInternal( + ClientContext& context + ,ToSubstraitFunctionData& data + ,DataChunk& output + ,Connection& new_conn + ,unique_ptr& query_plan + ,string& serialized + ); + + static void ToSubFunctionInternal( + ClientContext& context + ,ToSubstraitFunctionData& data + ,DataChunk& output + ,Connection& new_conn + ,unique_ptr& query_plan + ,string& serialized + ); + + static void VerifyJSONRoundtrip( + unique_ptr& query_plan + ,Connection& con + ,ToSubstraitFunctionData& data + ,const string& serialized + ); + + static void VerifyBlobRoundtrip( + unique_ptr& query_plan + ,Connection& con + ,ToSubstraitFunctionData& data + ,const string& serialized + ); + + static bool SetOptimizationOption( const ClientConfig &config + ,const duckdb::named_parameter_map_t &named_params) { + + // First, check if the user has explicitly requested to enable/disable the optimizer + for (const auto ¶m : named_params) { + auto loption = StringUtil::Lower(param.first); + if (loption == "enable_optimizer") { return BooleanValue::Get(param.second); } + } + + // Default to the connection-level setting + return config.enable_optimizer; + } + + static unique_ptr InitToSubstraitFunctionData(const ClientConfig &config, + TableFunctionBindInput &input) { + auto result = make_uniq(); + result->query = input.inputs[0].ToString(); + result->enable_optimizer = SetOptimizationOption(config, input.named_parameters); + return std::move(result); + } + + static unique_ptr ToSubstraitBind(ClientContext &context, TableFunctionBindInput &input, vector &return_types, vector &names) { - return SubstraitBind(context, input, return_types, names, false); -} - -static unique_ptr FromSubstraitBindJSON(ClientContext &context, TableFunctionBindInput &input, - vector &return_types, vector &names) { - return SubstraitBind(context, input, return_types, names, true); -} - -static void FromSubFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { - auto &data = (FromSubstraitFunctionData &)*data_p.bind_data; - if (!data.res) { - data.res = data.plan->Execute(); - } - auto result_chunk = data.res->Fetch(); - if (!result_chunk) { - return; - } - output.Move(*result_chunk); -} - -void InitializeGetSubstrait(Connection &con) { - auto &catalog = Catalog::GetSystemCatalog(*con.context); - - // create the get_substrait table function that allows us to get a substrait - // binary from a valid SQL Query - TableFunction to_sub_func("get_substrait", {LogicalType::VARCHAR}, ToSubFunction, ToSubstraitBind); - to_sub_func.named_parameters["enable_optimizer"] = LogicalType::BOOLEAN; - CreateTableFunctionInfo to_sub_info(to_sub_func); - catalog.CreateTableFunction(*con.context, to_sub_info); -} - -void InitializeGetSubstraitJSON(Connection &con) { - auto &catalog = Catalog::GetSystemCatalog(*con.context); - - // create the get_substrait table function that allows us to get a substrait - // JSON from a valid SQL Query - TableFunction get_substrait_json("get_substrait_json", {LogicalType::VARCHAR}, ToJsonFunction, ToJsonBind); - - get_substrait_json.named_parameters["enable_optimizer"] = LogicalType::BOOLEAN; - CreateTableFunctionInfo get_substrait_json_info(get_substrait_json); - catalog.CreateTableFunction(*con.context, get_substrait_json_info); -} - -void InitializeFromSubstrait(Connection &con) { - auto &catalog = Catalog::GetSystemCatalog(*con.context); - - // create the from_substrait table function that allows us to get a query - // result from a substrait plan - TableFunction from_sub_func("from_substrait", {LogicalType::BLOB}, FromSubFunction, FromSubstraitBind); - CreateTableFunctionInfo from_sub_info(from_sub_func); - catalog.CreateTableFunction(*con.context, from_sub_info); -} - -void InitializeFromSubstraitJSON(Connection &con) { - auto &catalog = Catalog::GetSystemCatalog(*con.context); - - // create the from_substrait table function that allows us to get a query - // result from a substrait plan - TableFunction from_sub_func_json("from_substrait_json", {LogicalType::VARCHAR}, FromSubFunction, - FromSubstraitBindJSON); - CreateTableFunctionInfo from_sub_info_json(from_sub_func_json); - catalog.CreateTableFunction(*con.context, from_sub_info_json); -} - -void SubstraitExtension::Load(DuckDB &db) { - Connection con(db); - con.BeginTransaction(); - - InitializeGetSubstrait(con); - InitializeGetSubstraitJSON(con); - - InitializeFromSubstrait(con); - InitializeFromSubstraitJSON(con); - - con.Commit(); -} - -std::string SubstraitExtension::Name() { - return "substrait"; -} + return_types.emplace_back(LogicalType::BLOB); + names.emplace_back("Plan Blob"); + auto result = InitToSubstraitFunctionData(context.config, input); + return std::move(result); + } + + static unique_ptr ToJsonBind(ClientContext &context, TableFunctionBindInput &input, + vector &return_types, vector &names) { + return_types.emplace_back(LogicalType::VARCHAR); + names.emplace_back("Json"); + auto result = InitToSubstraitFunctionData(context.config, input); + return std::move(result); + } + + shared_ptr SubstraitPlanToDuckDBRel(Connection &conn, const string &serialized, bool json = false) { + SubstraitToDuckDB transformer_s2d(conn, serialized, json); + return transformer_s2d.TransformPlan(); + } + + static void VerifySubstraitRoundtrip(unique_ptr &query_plan, Connection &con, + ToSubstraitFunctionData &data, const string &serialized, bool is_json) { + // We round-trip the generated json and verify if the result is the same + auto actual_result = con.Query(data.query); + + auto sub_relation = SubstraitPlanToDuckDBRel(con, serialized, is_json); + auto substrait_result = sub_relation->Execute(); + substrait_result->names = actual_result->names; + unique_ptr substrait_materialized; + + if (substrait_result->type == QueryResultType::STREAM_RESULT) { + auto &stream_query = substrait_result->Cast(); + + substrait_materialized = stream_query.Materialize(); + } else if (substrait_result->type == QueryResultType::MATERIALIZED_RESULT) { + substrait_materialized = unique_ptr_cast(std::move(substrait_result)); + } + auto actual_col_coll = actual_result->Collection(); + auto subs_col_coll = substrait_materialized->Collection(); + string error_message; + if (!ColumnDataCollection::ResultEquals(actual_col_coll, subs_col_coll, error_message)) { + query_plan->Print(); + sub_relation->Print(); + throw InternalException("The query result of DuckDB's query plan does not match Substrait : " + error_message); + } + } + + static void VerifyBlobRoundtrip(unique_ptr &query_plan, Connection &con, ToSubstraitFunctionData &data, + const string &serialized) { + VerifySubstraitRoundtrip(query_plan, con, data, serialized, false); + } + + static void VerifyJSONRoundtrip(unique_ptr &query_plan, Connection &con, ToSubstraitFunctionData &data, + const string &serialized) { + VerifySubstraitRoundtrip(query_plan, con, data, serialized, true); + } + + static DuckDBToSubstrait InitPlanExtractor(ClientContext &context, ToSubstraitFunctionData &data, Connection &new_conn, + unique_ptr &query_plan) { + // The user might want to disable the optimizer of the new connection + new_conn.context->config.enable_optimizer = data.enable_optimizer; + new_conn.context->config.use_replacement_scans = false; + + // We want for sure to disable the internal compression optimizations. + // These are DuckDB specific, no other system implements these. Also, + // respect the user's settings if they chose to disable any specific optimizers. + // + // The InClauseRewriter optimization converts large `IN` clauses to a + // "mark join" against a `ColumnDataCollection`, which may not make + // sense in other systems and would complicate the conversion to Substrait. + set disabled_optimizers = DBConfig::GetConfig(context).options.disabled_optimizers; + disabled_optimizers.insert(OptimizerType::IN_CLAUSE); + disabled_optimizers.insert(OptimizerType::COMPRESSED_MATERIALIZATION); + DBConfig::GetConfig(*new_conn.context).options.disabled_optimizers = disabled_optimizers; + + query_plan = new_conn.context->ExtractPlan(data.query); + return DuckDBToSubstrait(context, *query_plan); + } + + static void ToSubFunctionInternal(ClientContext &context, ToSubstraitFunctionData &data, DataChunk &output, + Connection &new_conn, unique_ptr &query_plan, string &serialized) { + output.SetCardinality(1); + auto transformer_d2s = InitPlanExtractor(context, data, new_conn, query_plan); + serialized = transformer_d2s.SerializeToString(); + output.SetValue(0, 0, Value::BLOB_RAW(serialized)); + } + + static void ToSubFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { + auto &data = (ToSubstraitFunctionData &)*data_p.bind_data; + if (data.finished) { + return; + } + auto new_conn = Connection(*context.db); + + unique_ptr query_plan; + string serialized; + ToSubFunctionInternal(context, data, output, new_conn, query_plan, serialized); + + data.finished = true; + + if (!context.config.query_verification_enabled) { + return; + } + VerifyBlobRoundtrip(query_plan, new_conn, data, serialized); + // Also run the ToJson path and verify round-trip for that + DataChunk other_output; + other_output.Initialize(context, {LogicalType::VARCHAR}); + ToJsonFunctionInternal(context, data, other_output, new_conn, query_plan, serialized); + VerifyJSONRoundtrip(query_plan, new_conn, data, serialized); + } + + static void ToJsonFunctionInternal(ClientContext &context, ToSubstraitFunctionData &data, DataChunk &output, + Connection &new_conn, unique_ptr &query_plan, string &serialized) { + output.SetCardinality(1); + auto transformer_d2s = InitPlanExtractor(context, data, new_conn, query_plan); + serialized = transformer_d2s.SerializeToJson(); + output.SetValue(0, 0, serialized); + } + + static void ToJsonFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { + auto &data = (ToSubstraitFunctionData &)*data_p.bind_data; + if (data.finished) { + return; + } + auto new_conn = Connection(*context.db); + + unique_ptr query_plan; + string serialized; + ToJsonFunctionInternal(context, data, output, new_conn, query_plan, serialized); + + data.finished = true; + + if (!context.config.query_verification_enabled) { + return; + } + VerifyJSONRoundtrip(query_plan, new_conn, data, serialized); + // Also run the ToJson path and verify round-trip for that + DataChunk other_output; + other_output.Initialize(context, {LogicalType::BLOB}); + ToSubFunctionInternal(context, data, other_output, new_conn, query_plan, serialized); + VerifyBlobRoundtrip(query_plan, new_conn, data, serialized); + } + + struct FromSubstraitFunctionData : public TableFunctionData { + FromSubstraitFunctionData() = default; + shared_ptr plan; + unique_ptr result; + unique_ptr conn; + }; + + static unique_ptr SubstraitBind(ClientContext &context, TableFunctionBindInput &input, + vector &return_types, vector &names, bool is_json) { + auto result = make_uniq(); + result->conn = make_uniq(*context.db); + if (input.inputs[0].IsNull()) { + throw BinderException("from_substrait cannot be called with a NULL parameter"); + } + string serialized = input.inputs[0].GetValueUnsafe(); + result->plan = SubstraitPlanToDuckDBRel(*result->conn, serialized, is_json); + for (auto &column : result->plan->Columns()) { + return_types.emplace_back(column.Type()); + names.emplace_back(column.Name()); + } + return std::move(result); + } + + + + static unique_ptr + FromSubstraitBind( ClientContext& context + ,TableFunctionBindInput& input + ,vector& return_types + ,vector& names) { + return SubstraitBind(context, input, return_types, names, false); + } + + static unique_ptr FromSubstraitBindJSON(ClientContext &context, TableFunctionBindInput &input, + vector &return_types, vector &names) { + return SubstraitBind(context, input, return_types, names, true); + } + + static void FromSubFunction(ClientContext &context, TableFunctionInput &data_p, DataChunk &output) { + auto &data = (FromSubstraitFunctionData &)*data_p.bind_data; + if (!data.result) { + data.result = data.plan->Execute(); + } + auto result_chunk = data.result->Fetch(); + if (!result_chunk) { + return; + } + output.Move(*result_chunk); + } + + + // ------------------------------ + // Supporting functions for Table Function "translate_mohair" + + static unique_ptr + BindingFnTranslateMohair( ClientContext& context + ,TableFunctionBindInput& input + ,vector& return_types + ,vector& names) { + if (input.inputs[0].IsNull()) { + throw BinderException("from_substrait cannot be called with a NULL parameter"); + } + string plan_msg { input.inputs[0].GetValueUnsafe() }; + bool enable_optimizer { SetOptimizationOption(context.config, input.named_parameters) }; + + // Prepare a FunctionData instance to return + auto fn_data = make_uniq(); + fn_data->translator = make_uniq(context); + fn_data->sys_plan = fn_data->translator->TranslatePlanMessage(plan_msg); + fn_data->plan_data = std::make_shared(StatementType::SELECT_STATEMENT); + + // For us to further build PreparedStatementData + // (probably affects our ResultCollector) + for (auto &column : fn_data->sys_plan->engine->Columns()) { + fn_data->plan_data->types.emplace_back(column.Type()); + fn_data->plan_data->names.emplace_back(column.Name()); + } + + // Set result schema (binding) + return_types.emplace_back(LogicalType::VARCHAR); + names.emplace_back("Physical Plan"); + + return std::move(fn_data); + } + + static void + TableFnTranslateMohair( ClientContext& context + ,TableFunctionInput& data_p + ,DataChunk& output) { + auto &fn_data = (FnDataSubstraitTranslation&) *(data_p.bind_data); + if (fn_data.finished) { return; } + + if (not fn_data.exec_plan) { + // Convert plan to engine plan + fn_data.engine_plan = fn_data.translator->TranspilePlanMessage(fn_data.sys_plan); + + // Convert engine plan to execution plan + fn_data.exec_plan = fn_data.translator->TranslateLogicalPlan( + fn_data.engine_plan, fn_data.enable_optimizer + ); + + fn_data.finished = true; + } + + // output.Initialize(context, { LogicalType::VARCHAR }, 1); + output.SetCardinality(1); + output.SetValue(0, 0, fn_data.exec_plan->engine->ToString()); + } + + + // ------------------------------ + // Supporting functions for Table Function "execute_mohair" + + static unique_ptr + BindingFnExecuteMohair( ClientContext& context + ,TableFunctionBindInput& input + ,vector& return_types + ,vector& names) { + if (input.inputs[0].IsNull()) { + throw BinderException("from_substrait cannot be called with a NULL parameter"); + } + + string plan_msg { input.inputs[0].GetValueUnsafe() }; + + auto result = make_uniq(); + result->translator = make_uniq(context); + result->sys_plan = result->translator->TranslatePlanMessage(plan_msg); + + for (auto &column : result->sys_plan->engine->Columns()) { + return_types.emplace_back(column.Type()); + names.emplace_back(column.Name()); + } + + return result; + } + + static void + TableFnExecuteMohair( ClientContext& context + ,TableFunctionInput& data_p + ,DataChunk& output) { + auto& fn_data = (FnDataSubstraitExecution&) *(data_p.bind_data); + + if (!fn_data.result) { + fn_data.result = fn_data.sys_plan->engine->Execute(); + } + + auto result_chunk = fn_data.result->Fetch(); + if (result_chunk) { output.Move(*result_chunk); } + } + + // ------------------------------ + // Initializers for Table Functions that implement extension logic + + //! Create a TableFunction, "get_substrait", then register it with the catalog + void InitializeGetSubstrait(Connection &con) { + TableFunction to_sub_func("get_substrait" + ,{ LogicalType::VARCHAR } + ,ToSubFunction + ,ToSubstraitBind + ); + + to_sub_func.named_parameters["enable_optimizer"] = LogicalType::BOOLEAN; + CreateTableFunctionInfo to_sub_info(to_sub_func); + + auto &catalog = Catalog::GetSystemCatalog(*con.context); + catalog.CreateTableFunction(*con.context, to_sub_info); + } + + //! Create a TableFunction, "get_substrait_json", then register it with the catalog + void InitializeGetSubstraitJSON(Connection &con) { + TableFunction get_substrait_json("get_substrait_json" + ,{ LogicalType::VARCHAR } + ,ToJsonFunction + ,ToJsonBind + ); + + get_substrait_json.named_parameters["enable_optimizer"] = LogicalType::BOOLEAN; + CreateTableFunctionInfo get_substrait_json_info(get_substrait_json); + + auto &catalog = Catalog::GetSystemCatalog(*con.context); + catalog.CreateTableFunction(*con.context, get_substrait_json_info); + } + + //! Create a TableFunction, "translate_mohair", then register it with the catalog + void InitializeTranslateMohair(Connection &con) { + TableFunction tablefn_mohair( + "translate_mohair" + ,{ LogicalType::BLOB } + ,TableFnTranslateMohair + ,BindingFnTranslateMohair + ); + + CreateTableFunctionInfo fninfo_mohair(tablefn_mohair); + + auto &catalog = Catalog::GetSystemCatalog(*(con.context)); + catalog.CreateTableFunction(*(con.context), fninfo_mohair); + } + + //! Create a TableFunction, "execute_mohair", then register it with the catalog + void InitializeExecuteMohair(Connection &con) { + TableFunction tablefn_mohair( + "execute_mohair" + ,{ LogicalType::BLOB } + ,TableFnExecuteMohair + ,BindingFnExecuteMohair + ); + + CreateTableFunctionInfo fninfo_mohair(tablefn_mohair); + + auto &catalog = Catalog::GetSystemCatalog(*(con.context)); + catalog.CreateTableFunction(*(con.context), fninfo_mohair); + } + + //! Create a TableFunction, "from_substrait", then register it with the catalog + void InitializeFromSubstrait(Connection &con) { + TableFunction from_sub_func("from_substrait" + ,{ LogicalType::BLOB } + ,FromSubFunction + ,FromSubstraitBind + ); + + CreateTableFunctionInfo from_sub_info(from_sub_func); + + auto &catalog = Catalog::GetSystemCatalog(*con.context); + catalog.CreateTableFunction(*con.context, from_sub_info); + } + + //! Create a TableFunction, "from_substrait_json", then register it with the catalog + void InitializeFromSubstraitJSON(Connection &con) { + TableFunction from_sub_func_json("from_substrait_json" + ,{ LogicalType::VARCHAR } + ,FromSubFunction + ,FromSubstraitBindJSON + ); + + CreateTableFunctionInfo from_sub_info_json(from_sub_func_json); + + auto &catalog = Catalog::GetSystemCatalog(*con.context); + catalog.CreateTableFunction(*con.context, from_sub_info_json); + } + + //! Logic for loading this extension + void SubstraitExtension::Load(DuckDB &db) { + Connection con(db); + con.BeginTransaction(); + + InitializeGetSubstrait(con); + InitializeGetSubstraitJSON(con); + + InitializeFromSubstrait(con); + InitializeFromSubstraitJSON(con); + InitializeTranslateMohair(con); + InitializeExecuteMohair(con); + + con.Commit(); + } + + std::string SubstraitExtension::Name() { + return "substrait"; + } } // namespace duckdb + extern "C" { -DUCKDB_EXTENSION_API void substrait_init(duckdb::DatabaseInstance &db) { - duckdb::DuckDB db_wrapper(db); - db_wrapper.LoadExtension(); -} + DUCKDB_EXTENSION_API + void substrait_init(duckdb::DatabaseInstance& db) { + duckdb::DuckDB db_wrapper(db); + db_wrapper.LoadExtension(); + } -DUCKDB_EXTENSION_API const char *substrait_version() { - return duckdb::DuckDB::LibraryVersion(); -} + DUCKDB_EXTENSION_API + const char* substrait_version() { + return duckdb::DuckDB::LibraryVersion(); + } } diff --git a/src/translation/duckdb_expressions.cpp b/src/translation/duckdb_expressions.cpp new file mode 100644 index 0000000..2368c90 --- /dev/null +++ b/src/translation/duckdb_expressions.cpp @@ -0,0 +1,421 @@ +// ------------------------------ +// License +// +// Copyright 2024 Aldrin Montana +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + + +// ------------------------------ +// Dependencies + +#include "engine_duckdb.hpp" + + +// ------------------------------ +// Functions + +namespace duckdb { + + LogicalType SubstraitToDuckType(const substrait::Type& s_type) { + if (s_type.has_bool_()) { return LogicalType(LogicalTypeId::BOOLEAN ); } + else if (s_type.has_i16() ) { return LogicalType(LogicalTypeId::SMALLINT); } + else if (s_type.has_i32() ) { return LogicalType(LogicalTypeId::INTEGER ); } + else if (s_type.has_i64() ) { return LogicalType(LogicalTypeId::BIGINT ); } + else if (s_type.has_date() ) { return LogicalType(LogicalTypeId::DATE ); } + else if (s_type.has_fp64() ) { return LogicalType(LogicalTypeId::DOUBLE ); } + + else if (s_type.has_varchar() || s_type.has_string()) { + return LogicalType(LogicalTypeId::VARCHAR); + } + + else if (s_type.has_decimal()) { + auto &s_decimal_type = s_type.decimal(); + return LogicalType::DECIMAL(s_decimal_type.precision(), s_decimal_type.scale()); + } + + else { + throw InternalException("Substrait type not yet supported"); + } + } + + using SLiteralType = substrait::Expression::Literal::LiteralTypeCase; + + unique_ptr + DuckDBTranslator::TranslateLiteralExpr(const substrait::Expression::Literal& slit) { + Value dval; + + if (slit.has_null()) { + dval = Value(LogicalType::SQLNULL); + return make_uniq(dval); + } + + switch (slit.literal_type_case()) { + case SLiteralType::kFp64: + dval = Value::DOUBLE(slit.fp64()); + break; + + case SLiteralType::kFp32: + dval = Value::FLOAT(slit.fp32()); + break; + + case SLiteralType::kString: + dval = Value(slit.string()); + break; + + case SLiteralType::kDecimal: { + const auto& sdecimal = slit.decimal(); + auto decimal_type = LogicalType::DECIMAL( + sdecimal.precision() + ,sdecimal.scale() + ); + + hugeint_t substrait_value; + auto raw_value = (uint64_t *) sdecimal.value().c_str(); + substrait_value.lower = raw_value[0]; + substrait_value.upper = raw_value[1]; + Value val = Value::HUGEINT(substrait_value); + + // cast to correct value + switch (decimal_type.InternalType()) { + case PhysicalType::INT8: + dval = Value::DECIMAL( + val.GetValue() + ,sdecimal.precision() + ,sdecimal.scale() + ); + break; + + case PhysicalType::INT16: + dval = Value::DECIMAL( + val.GetValue() + ,sdecimal.precision() + ,sdecimal.scale() + ); + break; + + case PhysicalType::INT32: + dval = Value::DECIMAL( + val.GetValue() + ,sdecimal.precision() + ,sdecimal.scale() + ); + break; + + case PhysicalType::INT64: + dval = Value::DECIMAL( + val.GetValue() + ,sdecimal.precision() + ,sdecimal.scale() + ); + break; + + case PhysicalType::INT128: + dval = Value::DECIMAL(substrait_value, sdecimal.precision(), sdecimal.scale()); + break; + + default: + throw InternalException("Not accepted internal type for decimal"); + } + break; + } + + case SLiteralType::kBoolean: { + dval = Value(slit.boolean()); + break; + } + + case SLiteralType::kI8: + dval = Value::TINYINT(slit.i8()); + break; + + case SLiteralType::kI32: + dval = Value::INTEGER(slit.i32()); + break; + + case SLiteralType::kI64: + dval = Value::BIGINT(slit.i64()); + break; + + case SLiteralType::kDate: { + date_t date(slit.date()); + dval = Value::DATE(date); + break; + } + + case SLiteralType::kTime: { + dtime_t time(slit.time()); + dval = Value::TIME(time); + break; + } + + case SLiteralType::kIntervalYearToMonth: { + interval_t interval; + interval.months = slit.interval_year_to_month().months(); + interval.days = 0; + interval.micros = 0; + dval = Value::INTERVAL(interval); + break; + } + + case SLiteralType::kIntervalDayToSecond: { + interval_t interval; + interval.months = 0; + interval.days = slit.interval_day_to_second().days(); + interval.micros = slit.interval_day_to_second().microseconds(); + dval = Value::INTERVAL(interval); + break; + } + + default: + throw InternalException(to_string(slit.literal_type_case())); + } + + return make_uniq(dval); + } + + + unique_ptr + DuckDBTranslator::TranslateSelectionExpr(const substrait::Expression &sexpr) { + if ( !sexpr.selection().has_direct_reference() + || !sexpr.selection().direct_reference().has_struct_field()) { + throw InternalException("Can only have direct struct references in selections"); + } + + return make_uniq( + sexpr.selection().direct_reference().struct_field().field() + 1 + ); + } + + unique_ptr + DuckDBTranslator::TranslateScalarFunctionExpr(const substrait::Expression& sexpr) { + auto function_id = sexpr.scalar_function().function_reference(); + auto function_name = functions_map->FindExtensionFunction(function_id); + function_name = RemoveExtension(function_name); + + vector> children; + vector enum_expressions; + + auto& function_arguments = sexpr.scalar_function().arguments(); + for (auto& sarg : function_arguments) { + + // value expression + if (sarg.has_value()) { children.push_back(TranslateExpr(sarg.value())); } + + // type expression + else if (sarg.has_type()) { + throw NotImplementedException( + "Type arguments in Substrait expressions are not supported yet!" + ); + } + + // enum expression + else { + D_ASSERT(sarg.has_enum_()); + auto &enum_str = sarg.enum_(); + enum_expressions.push_back(enum_str); + } + } + + // string compare galore + // TODO simplify this + if (function_name == "and") { + return make_uniq( + ExpressionType::CONJUNCTION_AND, std::move(children) + ); + } + + else if (function_name == "or") { + return make_uniq( + ExpressionType::CONJUNCTION_OR, std::move(children) + ); + } + + else if (function_name == "lt") { + D_ASSERT(children.size() == 2); + return make_uniq( + ExpressionType::COMPARE_LESSTHAN + ,std::move(children[0]) + ,std::move(children[1]) + ); + } + + else if (function_name == "equal") { + D_ASSERT(children.size() == 2); + return make_uniq( + ExpressionType::COMPARE_EQUAL + ,std::move(children[0]) + ,std::move(children[1]) + ); + } + + else if (function_name == "not_equal") { + D_ASSERT(children.size() == 2); + return make_uniq( + ExpressionType::COMPARE_NOTEQUAL + ,std::move(children[0]) + ,std::move(children[1]) + ); + } + + else if (function_name == "lte") { + D_ASSERT(children.size() == 2); + return make_uniq( + ExpressionType::COMPARE_LESSTHANOREQUALTO + ,std::move(children[0]) + ,std::move(children[1]) + ); + } + + else if (function_name == "gte") { + D_ASSERT(children.size() == 2); + return make_uniq( + ExpressionType::COMPARE_GREATERTHANOREQUALTO + ,std::move(children[0]) + ,std::move(children[1]) + ); + } + + else if (function_name == "gt") { + D_ASSERT(children.size() == 2); + return make_uniq( + ExpressionType::COMPARE_GREATERTHAN + ,std::move(children[0]) + ,std::move(children[1]) + ); + } + + else if (function_name == "is_not_null") { + D_ASSERT(children.size() == 1); + return make_uniq( + ExpressionType::OPERATOR_IS_NOT_NULL + ,std::move(children[0]) + ); + } + + else if (function_name == "is_null") { + D_ASSERT(children.size() == 1); + return make_uniq( + ExpressionType::OPERATOR_IS_NULL + ,std::move(children[0]) + ); + } + + else if (function_name == "not") { + D_ASSERT(children.size() == 1); + return make_uniq( + ExpressionType::OPERATOR_NOT + ,std::move(children[0]) + ); + } + + else if (function_name == "is_not_distinct_from") { + D_ASSERT(children.size() == 2); + return make_uniq( + ExpressionType::COMPARE_NOT_DISTINCT_FROM + ,std::move(children[0]) + ,std::move(children[1]) + ); + } + + else if (function_name == "between") { + // FIXME: ADD between to substrait extension + D_ASSERT(children.size() == 3); + return make_uniq( + std::move(children[0]) + ,std::move(children[1]) + ,std::move(children[2]) + ); + } + + else if (function_name == "extract") { + D_ASSERT(enum_expressions.size() == 1); + + auto& subfield = enum_expressions[0]; + AssertValidDateSubfield(subfield); + + auto constant_expression = make_uniq(Value(subfield)); + children.insert(children.begin(), std::move(constant_expression)); + } + + return make_uniq( + RemapFunctionName(function_name), std::move(children) + ); + } + + + unique_ptr + DuckDBTranslator::TranslateIfThenExpr(const substrait::Expression &sexpr) { + const auto& scase = sexpr.if_then(); + auto dcase = make_uniq(); + + for (const auto &sif : scase.ifs()) { + CaseCheck dif; + dif.when_expr = TranslateExpr(sif.if_()); + dif.then_expr = TranslateExpr(sif.then()); + dcase->case_checks.push_back(std::move(dif)); + } + + dcase->else_expr = TranslateExpr(scase.else_()); + return std::move(dcase); + } + + + unique_ptr + DuckDBTranslator::TranslateCastExpr(const substrait::Expression &sexpr) { + const auto& scast = sexpr.cast(); + auto cast_type = SubstraitToDuckType(scast.type()); + auto cast_child = TranslateExpr(scast.input()); + + return make_uniq(cast_type, std::move(cast_child)); + } + + + unique_ptr + DuckDBTranslator::TranslateInExpr(const substrait::Expression& sexpr) { + const auto &substrait_in = sexpr.singular_or_list(); + + vector> values; + values.emplace_back(TranslateExpr(substrait_in.value())); + + for (idx_t i = 0; i < (idx_t)substrait_in.options_size(); i++) { + values.emplace_back(TranslateExpr(substrait_in.options(i))); + } + + return make_uniq( + ExpressionType::COMPARE_IN, std::move(values) + ); + } + + + // >> Top-level translation function + using SExprType = substrait::Expression::RexTypeCase; + + unique_ptr + DuckDBTranslator::TranslateExpr(const substrait::Expression& sexpr) { + switch (sexpr.rex_type_case()) { + case SExprType::kLiteral: return TranslateLiteralExpr (sexpr.literal()); + case SExprType::kSelection: return TranslateSelectionExpr (sexpr); + case SExprType::kScalarFunction: return TranslateScalarFunctionExpr(sexpr); + case SExprType::kIfThen: return TranslateIfThenExpr (sexpr); + case SExprType::kCast: return TranslateCastExpr (sexpr); + case SExprType::kSingularOrList: return TranslateInExpr (sexpr); + case SExprType::kSubquery: + default: + throw InternalException( + "Unsupported expression type " + to_string(sexpr.rex_type_case()) + ); + } + } + +} // namespace: duckdb diff --git a/src/translation/duckdb_operators.cpp b/src/translation/duckdb_operators.cpp new file mode 100644 index 0000000..dc9968e --- /dev/null +++ b/src/translation/duckdb_operators.cpp @@ -0,0 +1,374 @@ +// ------------------------------ +// License +// +// Copyright 2024 Aldrin Montana +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + + +// ------------------------------ +// Dependencies + +#include "engine_duckdb.hpp" + + +// ------------------------------ +// Macros and Type Aliases + + +// ------------------------------ +// Functions + +namespace duckdb { + + static duckdb::SetOperationType + TranslateSetOperationType(substrait::SetRel::SetOp setop) { + switch (setop) { + case substrait::SetRel::SET_OP_UNION_ALL: { + return duckdb::SetOperationType::UNION; + } + + case substrait::SetRel::SET_OP_MINUS_PRIMARY: { + return duckdb::SetOperationType::EXCEPT; + } + + case substrait::SetRel::SET_OP_INTERSECTION_PRIMARY: { + return duckdb::SetOperationType::INTERSECT; + } + + default: { + throw duckdb::NotImplementedException( + "SetOperationType transform not implemented for SetRel type %d" + ,setop + ); + } + } + } + + + static duckdb::JoinType + TranslateJoinType(const substrait::JoinRel& sjoin) { + switch (sjoin.type()) { + case substrait::JoinRel::JOIN_TYPE_INNER: return duckdb::JoinType::INNER; + case substrait::JoinRel::JOIN_TYPE_LEFT: return duckdb::JoinType::LEFT; + case substrait::JoinRel::JOIN_TYPE_RIGHT: return duckdb::JoinType::RIGHT; + case substrait::JoinRel::JOIN_TYPE_SINGLE: return duckdb::JoinType::SINGLE; + case substrait::JoinRel::JOIN_TYPE_SEMI: return duckdb::JoinType::SEMI; + + default: + throw InternalException("Unsupported join type"); + } + } + + + OrderByNode + DuckDBTranslator::TranslateOrder(const substrait::SortField& sordf) { + OrderType dordertype; + OrderByNullType dnullorder; + + switch (sordf.direction()) { + case substrait::SortField::SORT_DIRECTION_ASC_NULLS_FIRST: + dordertype = OrderType::ASCENDING; + dnullorder = OrderByNullType::NULLS_FIRST; + break; + + case substrait::SortField::SORT_DIRECTION_ASC_NULLS_LAST: + dordertype = OrderType::ASCENDING; + dnullorder = OrderByNullType::NULLS_LAST; + break; + + case substrait::SortField::SORT_DIRECTION_DESC_NULLS_FIRST: + dordertype = OrderType::DESCENDING; + dnullorder = OrderByNullType::NULLS_FIRST; + break; + + case substrait::SortField::SORT_DIRECTION_DESC_NULLS_LAST: + dordertype = OrderType::DESCENDING; + dnullorder = OrderByNullType::NULLS_LAST; + break; + + default: + throw InternalException("Unsupported ordering " + to_string(sordf.direction())); + } + + return { dordertype, dnullorder, TranslateExpr(sordf.expr()) }; + } + + shared_ptr + DuckDBTranslator::TranslateJoinOp(const substrait::JoinRel& sjoin) { + JoinType djointype = TranslateJoinType(sjoin); + unique_ptr join_condition = TranslateExpr(sjoin.expression()); + + return make_shared( + TranslateOp(sjoin.left())->Alias("left") + ,TranslateOp(sjoin.right())->Alias("right") + ,std::move(join_condition) + ,djointype + ); + } + + shared_ptr + DuckDBTranslator::TranslateCrossProductOp(const substrait::CrossRel& scross) { + return make_shared( + TranslateOp(scross.left())->Alias("left") + ,TranslateOp(scross.right())->Alias("right") + ); + } + + shared_ptr + DuckDBTranslator::TranslateFetchOp(const substrait::FetchRel& slimit) { + return make_shared( + TranslateOp(slimit.input()) + ,slimit.count() + ,slimit.offset() + ); + } + + shared_ptr + DuckDBTranslator::TranslateFilterOp(const substrait::FilterRel& sfilter) { + return make_shared( + TranslateOp(sfilter.input()) + ,TranslateExpr(sfilter.condition()) + ); + } + + shared_ptr + DuckDBTranslator::TranslateProjectOp(const substrait::ProjectRel& sproj) { + vector> expressions; + for (auto &sexpr : sproj.expressions()) { + expressions.push_back(TranslateExpr(sexpr)); + } + + vector mock_aliases; + for (size_t i = 0; i < expressions.size(); i++) { + mock_aliases.push_back("expr_" + to_string(i)); + } + + return make_shared( + TranslateOp(sproj.input()) + ,std::move(expressions) + ,std::move(mock_aliases) + ); + } + + + shared_ptr + DuckDBTranslator::TranslateAggregateOp(const substrait::AggregateRel& saggr) { + vector> groups, expressions; + + if (saggr.groupings_size() > 0) { + for (auto &sgrp : saggr.groupings()) { + for (auto &sgrpexpr : sgrp.grouping_expressions()) { + groups.push_back(TranslateExpr(sgrpexpr)); + expressions.push_back(TranslateExpr(sgrpexpr)); + } + } + } + + for (auto &smeas : saggr.measures()) { + vector> children; + for (auto &sarg : smeas.measure().arguments()) { + children.push_back(TranslateExpr(sarg.value())); + } + + auto function_id = smeas.measure().function_reference(); + auto function_name = functions_map->FindExtensionFunction(function_id); + if (function_name == "count" && children.empty()) { function_name = "count_star"; } + + expressions.push_back( + make_uniq( + RemapFunctionName(function_name), std::move(children) + ) + ); + } + + return make_shared( + TranslateOp(saggr.input()) + ,std::move(expressions) + ,std::move(groups) + ); + } + + + shared_ptr + DuckDBTranslator::TranslateReadOp(const substrait::ReadRel& sget) { + shared_ptr scan; + // Find a table or view with given name + if (sget.has_named_table()) { + try { scan = t_conn->Table(sget.named_table().names(0)); } + catch (...) { scan = t_conn->View (sget.named_table().names(0)); } + } + + // Otherwise, try scanning from list of parquet files + else if (sget.has_local_files()) { + vector parquet_files; + + auto local_file_items = sget.local_files().items(); + for (auto ¤t_file : local_file_items) { + if (current_file.has_parquet()) { + + if (current_file.has_uri_file()) { + parquet_files.emplace_back(current_file.uri_file()); + } + + else if (current_file.has_uri_path()) { + parquet_files.emplace_back(current_file.uri_path()); + } + + else { + throw NotImplementedException( + "Unsupported type for file path, Only uri_file and uri_path are " + "currently supported" + ); + } + } + + else { + throw NotImplementedException( + "Unsupported type of local file for read operator on substrait" + ); + } + } + + string name = "parquet_" + StringUtil::GenerateRandomName(); + named_parameter_map_t named_parameters({{"binary_as_string", Value::BOOLEAN(false)}}); + + scan = t_conn->TableFunction( "parquet_scan" + ,{Value::LIST(parquet_files)} + ,named_parameters + )->Alias(name); + } + + else { + throw NotImplementedException("Unsupported type of read operator for substrait"); + } + + // Filter predicate for scan operation + if (sget.has_filter()) { + scan = make_shared(std::move(scan), TranslateExpr(sget.filter())); + } + + // Projection predicate for scan operation + if (sget.has_projection()) { + vector aliases; + vector> expressions; + + idx_t expr_idx = 0; + for (auto &sproj : sget.projection().select().struct_items()) { + // FIXME how to get actually alias? + aliases.push_back("expr_" + to_string(expr_idx++)); + + // TODO make sure nothing else is in there + expressions.push_back( + make_uniq(sproj.field() + 1) + ); + } + + scan = make_shared( + std::move(scan), std::move(expressions), std::move(aliases) + ); + } + + return scan; + } + + + shared_ptr + DuckDBTranslator::TranslateSortOp(const substrait::SortRel &ssort) { + vector order_nodes; + for (auto &sordf : ssort.sorts()) { + order_nodes.push_back(TranslateOrder(sordf)); + } + + return make_shared(TranslateOp(ssort.input()), std::move(order_nodes)); + } + + + shared_ptr + DuckDBTranslator::TranslateSetOp(const substrait::SetRel &sset) { + // TODO: see if this is necessary for some cases + // D_ASSERT(sop.has_set()); + + auto type = TranslateSetOperationType(sset.op()); + auto& inputs = sset.inputs(); + if (sset.inputs_size() > 2) { + throw NotImplementedException( + "Too many inputs (%d) for this set operation", sset.inputs_size() + ); + } + + auto lhs = TranslateOp(inputs[0]); + auto rhs = TranslateOp(inputs[1]); + return make_shared(std::move(lhs), std::move(rhs), type); + } + + + //! Translate Substrait Operations to DuckDB Relations + using SRelType = substrait::Rel::RelTypeCase; + shared_ptr DuckDBTranslator::TranslateOp(const substrait::Rel& sop) { + switch (sop.rel_type_case()) { + case SRelType::kJoin: return TranslateJoinOp (sop.join()); + case SRelType::kCross: return TranslateCrossProductOp(sop.cross()); + case SRelType::kFetch: return TranslateFetchOp (sop.fetch()); + case SRelType::kFilter: return TranslateFilterOp (sop.filter()); + case SRelType::kProject: return TranslateProjectOp (sop.project()); + case SRelType::kAggregate: return TranslateAggregateOp (sop.aggregate()); + case SRelType::kRead: return TranslateReadOp (sop.read()); + case SRelType::kSort: return TranslateSortOp (sop.sort()); + case SRelType::kSet: return TranslateSetOp (sop.set()); + + default: + throw InternalException( + "Unsupported relation type " + to_string(sop.rel_type_case()) + ); + } + } + + + //! Translates Substrait Plan Root To a DuckDB Relation + shared_ptr DuckDBTranslator::TranslateRootOp(const substrait::RelRoot& sop) { + vector aliases; + vector> expressions; + + int id = 1; + for (auto &column_name : sop.names()) { + aliases.push_back(column_name); + expressions.push_back(make_uniq(id++)); + } + + return make_shared( + TranslateOp(sop.input()), std::move(expressions), aliases + ); + } + + + // >> Entry points into the functions implemented here + unique_ptr + DuckDBTranslator::TranslatePlanMessage(const string& serialized_msg) { + auto plan = mohair::SubstraitPlanFromSubstraitMessage(serialized_msg); + functions_map->RegisterExtensionFunctions(*plan); + + auto engine_plan = TranslateRootOp(plan->relations(0).root()); + return make_uniq(std::move(plan), engine_plan); + } + + unique_ptr + DuckDBTranslator::TranslatePlanJson(const string& json_msg) { + auto plan = mohair::SubstraitPlanFromSubstraitJson(json_msg); + functions_map->RegisterExtensionFunctions(*plan); + + auto engine_plan = TranslateRootOp(plan->relations(0).root()); + return make_uniq(std::move(plan), engine_plan); + } + +} // namespace: duckdb diff --git a/substrait b/substrait index d9b9672..d95bcf8 160000 --- a/substrait +++ b/substrait @@ -1 +1 @@ -Subproject commit d9b9672fd3c24285afdee9344fc2f4f7fcd70afb +Subproject commit d95bcf83a9763e9b317f3b5332995b01b8c6e392