From 6c48fddd1a7381f538057907bc7cdadbfc34c23f Mon Sep 17 00:00:00 2001 From: David Li Date: Fri, 1 Jul 2022 14:33:20 -0400 Subject: [PATCH] RFC: [FlightRPC][WIP] Substrait, transaction, cancellation for Flight SQL --- .../flight_integration_test.cc | 4 + .../integration_tests/test_integration.cc | 656 ++++++++++++++++-- cpp/src/arrow/flight/server.cc | 4 +- cpp/src/arrow/flight/sql/CMakeLists.txt | 27 +- cpp/src/arrow/flight/sql/acero_test.cc | 239 +++++++ cpp/src/arrow/flight/sql/client.cc | 355 +++++++++- cpp/src/arrow/flight/sql/client.h | 152 +++- cpp/src/arrow/flight/sql/client_test.cc | 1 + cpp/src/arrow/flight/sql/column_metadata.cc | 8 +- cpp/src/arrow/flight/sql/column_metadata.h | 8 +- .../arrow/flight/sql/example/acero_main.cc | 70 ++ .../arrow/flight/sql/example/acero_server.cc | 306 ++++++++ .../arrow/flight/sql/example/acero_server.h | 37 + .../arrow/flight/sql/example/sqlite_server.cc | 205 ++++-- .../arrow/flight/sql/example/sqlite_server.h | 6 + .../flight/sql/example/sqlite_sql_info.cc | 9 +- cpp/src/arrow/flight/sql/server.cc | 424 +++++++++-- cpp/src/arrow/flight/sql/server.h | 171 ++++- cpp/src/arrow/flight/sql/server_test.cc | 83 ++- cpp/src/arrow/flight/sql/types.h | 79 ++- dev/archery/archery/integration/runner.py | 5 + docs/source/status.rst | 14 + format/FlightSql.proto | 298 +++++++- .../org/apache/arrow/flight/FlightClient.java | 7 +- .../apache/arrow/flight/FlightService.java | 2 +- .../tests/FlightSqlExtensionScenario.java | 217 ++++++ .../integration/tests/FlightSqlScenario.java | 52 +- .../tests/FlightSqlScenarioProducer.java | 383 ++++++++-- .../tests/IntegrationAssertions.java | 11 + .../flight/integration/tests/Scenarios.java | 1 + .../integration/tests/IntegrationTest.java | 70 ++ .../arrow/flight/sql/CancelListener.java | 51 ++ .../apache/arrow/flight/sql/CancelResult.java | 45 ++ .../arrow/flight/sql/FlightSqlClient.java | 451 +++++++++++- .../arrow/flight/sql/FlightSqlProducer.java | 198 +++++- .../arrow/flight/sql/FlightSqlUtils.java | 35 + .../arrow/flight/sql/NoResultListener.java | 45 ++ .../arrow/flight/sql/ProtoListener.java | 52 ++ .../arrow/flight/sql/SqlInfoBuilder.java | 41 ++ .../flight/sql/example/FlightSqlExample.java | 3 + 40 files changed, 4445 insertions(+), 380 deletions(-) create mode 100644 cpp/src/arrow/flight/sql/acero_test.cc create mode 100644 cpp/src/arrow/flight/sql/example/acero_main.cc create mode 100644 cpp/src/arrow/flight/sql/example/acero_server.cc create mode 100644 cpp/src/arrow/flight/sql/example/acero_server.h create mode 100644 java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/FlightSqlExtensionScenario.java create mode 100644 java/flight/flight-integration-tests/src/test/java/org/apache/arrow/flight/integration/tests/IntegrationTest.java create mode 100644 java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/CancelListener.java create mode 100644 java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/CancelResult.java create mode 100644 java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/NoResultListener.java create mode 100644 java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/ProtoListener.java diff --git a/cpp/src/arrow/flight/integration_tests/flight_integration_test.cc b/cpp/src/arrow/flight/integration_tests/flight_integration_test.cc index 706ac3b7d931b..e29a281f32721 100644 --- a/cpp/src/arrow/flight/integration_tests/flight_integration_test.cc +++ b/cpp/src/arrow/flight/integration_tests/flight_integration_test.cc @@ -55,6 +55,10 @@ TEST(FlightIntegration, Middleware) { ASSERT_OK(RunScenario("middleware")); } TEST(FlightIntegration, FlightSql) { ASSERT_OK(RunScenario("flight_sql")); } +TEST(FlightIntegration, FlightSqlExtension) { + ASSERT_OK(RunScenario("flight_sql:extension")); +} + } // namespace integration_tests } // namespace flight } // namespace arrow diff --git a/cpp/src/arrow/flight/integration_tests/test_integration.cc b/cpp/src/arrow/flight/integration_tests/test_integration.cc index b228f9cceba06..43c16e0b77a6d 100644 --- a/cpp/src/arrow/flight/integration_tests/test_integration.cc +++ b/cpp/src/arrow/flight/integration_tests/test_integration.cc @@ -16,25 +16,36 @@ // under the License. #include "arrow/flight/integration_tests/test_integration.h" + +#include +#include +#include +#include +#include +#include + +#include "arrow/array/array_binary.h" +#include "arrow/array/array_nested.h" +#include "arrow/array/array_primitive.h" #include "arrow/flight/client_middleware.h" #include "arrow/flight/server_middleware.h" #include "arrow/flight/sql/client.h" #include "arrow/flight/sql/column_metadata.h" #include "arrow/flight/sql/server.h" +#include "arrow/flight/sql/types.h" #include "arrow/flight/test_util.h" #include "arrow/flight/types.h" #include "arrow/ipc/dictionary.h" +#include "arrow/status.h" #include "arrow/testing/gtest_util.h" - -#include -#include -#include -#include -#include +#include "arrow/util/checked_cast.h" namespace arrow { namespace flight { namespace integration_tests { +namespace { + +using arrow::internal::checked_cast; /// \brief The server for the basic auth integration test. class AuthBasicProtoServer : public FlightServerBase { @@ -263,29 +274,56 @@ class MiddlewareScenario : public Scenario { }; /// \brief Schema to be returned for mocking the statement/prepared statement results. +/// /// Must be the same across all languages. -std::shared_ptr GetQuerySchema() { - std::string table_name = "test"; - std::string schema_name = "schema_test"; - std::string catalog_name = "catalog_test"; - std::string type_name = "type_test"; - return arrow::schema({arrow::field("id", int64(), true, - arrow::flight::sql::ColumnMetadata::Builder() - .TableName(table_name) - .IsAutoIncrement(true) - .IsCaseSensitive(false) - .TypeName(type_name) - .SchemaName(schema_name) - .IsSearchable(true) - .CatalogName(catalog_name) - .Precision(100) - .Build() - .metadata_map())}); +const std::shared_ptr& GetQuerySchema() { + static std::shared_ptr kSchema = + schema({field("id", int64(), /*nullable=*/true, + arrow::flight::sql::ColumnMetadata::Builder() + .TableName("test") + .IsAutoIncrement(true) + .IsCaseSensitive(false) + .TypeName("type_test") + .SchemaName("schema_test") + .IsSearchable(true) + .CatalogName("catalog_test") + .Precision(100) + .Build() + .metadata_map())}); + return kSchema; +} + +/// \brief Schema to be returned for queries with transactions. +/// +/// Must be the same across all languages. +std::shared_ptr GetQueryWithTransactionSchema() { + static std::shared_ptr kSchema = + schema({field("pkey", int32(), /*nullable=*/true, + arrow::flight::sql::ColumnMetadata::Builder() + .TableName("test") + .IsAutoIncrement(true) + .IsCaseSensitive(false) + .TypeName("type_test") + .SchemaName("schema_test") + .IsSearchable(true) + .CatalogName("catalog_test") + .Precision(100) + .Build() + .metadata_map())}); + return kSchema; } constexpr int64_t kUpdateStatementExpectedRows = 10000L; +constexpr int64_t kUpdateStatementWithTransactionExpectedRows = 15000L; constexpr int64_t kUpdatePreparedStatementExpectedRows = 20000L; +constexpr int64_t kUpdatePreparedStatementWithTransactionExpectedRows = 25000L; constexpr char kSelectStatement[] = "SELECT STATEMENT"; +constexpr char kSavepointId[] = "savepoint_id"; +constexpr char kSavepointName[] = "savepoint_name"; +constexpr char kSubstraitPlanText[] = "plan"; +constexpr char kSubstraitVersion[] = "version"; +static const sql::SubstraitPlan kSubstraitPlan{kSubstraitPlanText, kSubstraitVersion}; +constexpr char kTransactionId[] = "transaction_id"; template arrow::Status AssertEq(const T& expected, const T& actual, const std::string& message) { @@ -296,25 +334,83 @@ arrow::Status AssertEq(const T& expected, const T& actual, const std::string& me return Status::OK(); } +template +arrow::Status AssertUnprintableEq(const T& expected, const T& actual, + const std::string& message) { + if (expected != actual) { + return Status::Invalid(message); + } + return Status::OK(); +} + /// \brief The server used for testing Flight SQL, this implements a static Flight SQL /// server which only asserts that commands called during integration tests are being /// parsed correctly and returns the expected schemas to be validated on client. class FlightSqlScenarioServer : public sql::FlightSqlServerBase { public: + FlightSqlScenarioServer() : sql::FlightSqlServerBase() { + RegisterSqlInfo(sql::SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_SQL, + sql::SqlInfoResult(false)); + RegisterSqlInfo(sql::SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_SUBSTRAIT, + sql::SqlInfoResult(true)); + RegisterSqlInfo(sql::SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_SUBSTRAIT_MIN_VERSION, + sql::SqlInfoResult(std::string("min_version"))); + RegisterSqlInfo(sql::SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_SUBSTRAIT_MAX_VERSION, + sql::SqlInfoResult(std::string("max_version"))); + RegisterSqlInfo(sql::SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_TRANSACTION, + sql::SqlInfoResult(sql::SqlInfoOptions::SqlSupportedTransaction:: + SQL_SUPPORTED_TRANSACTION_SAVEPOINT)); + RegisterSqlInfo(sql::SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_CANCEL, + sql::SqlInfoResult(true)); + RegisterSqlInfo(sql::SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_STATEMENT_TIMEOUT, + sql::SqlInfoResult(int32_t(42))); + RegisterSqlInfo(sql::SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_TRANSACTION_TIMEOUT, + sql::SqlInfoResult(int32_t(7))); + } arrow::Result> GetFlightInfoStatement( const ServerCallContext& context, const sql::StatementQuery& command, const FlightDescriptor& descriptor) override { ARROW_RETURN_NOT_OK( AssertEq(kSelectStatement, command.query, "Unexpected statement in GetFlightInfoStatement")); - - ARROW_ASSIGN_OR_RAISE(auto handle, - sql::CreateStatementQueryTicket("SELECT STATEMENT HANDLE")); - + std::string ticket; + Schema* schema; + if (command.transaction_id.empty()) { + ticket = "SELECT STATEMENT HANDLE"; + schema = GetQuerySchema().get(); + } else { + ticket = "SELECT STATEMENT WITH TXN HANDLE"; + schema = GetQueryWithTransactionSchema().get(); + } + ARROW_ASSIGN_OR_RAISE(auto handle, sql::CreateStatementQueryTicket(ticket)); std::vector endpoints{FlightEndpoint{{handle}, {}}}; - ARROW_ASSIGN_OR_RAISE( - auto result, FlightInfo::Make(*GetQuerySchema(), descriptor, endpoints, -1, -1)) + ARROW_ASSIGN_OR_RAISE(auto result, + FlightInfo::Make(*schema, descriptor, endpoints, -1, -1)); + return std::unique_ptr(new FlightInfo(result)); + } + arrow::Result> GetFlightInfoSubstraitPlan( + const ServerCallContext& context, const sql::StatementSubstraitPlan& command, + const FlightDescriptor& descriptor) override { + ARROW_RETURN_NOT_OK( + AssertEq(kSubstraitPlanText, command.plan.plan, + "Unexpected plan in GetFlightInfoSubstraitPlan")); + ARROW_RETURN_NOT_OK( + AssertEq(kSubstraitVersion, command.plan.version, + "Unexpected version in GetFlightInfoSubstraitPlan")); + std::string ticket; + Schema* schema; + if (command.transaction_id.empty()) { + ticket = "PLAN HANDLE"; + schema = GetQuerySchema().get(); + } else { + ticket = "PLAN WITH TXN HANDLE"; + schema = GetQueryWithTransactionSchema().get(); + } + ARROW_ASSIGN_OR_RAISE(auto handle, sql::CreateStatementQueryTicket(ticket)); + std::vector endpoints{FlightEndpoint{{handle}, {}}}; + ARROW_ASSIGN_OR_RAISE(auto result, + FlightInfo::Make(*schema, descriptor, endpoints, -1, -1)); return std::unique_ptr(new FlightInfo(result)); } @@ -323,38 +419,84 @@ class FlightSqlScenarioServer : public sql::FlightSqlServerBase { const FlightDescriptor& descriptor) override { ARROW_RETURN_NOT_OK(AssertEq( kSelectStatement, command.query, "Unexpected statement in GetSchemaStatement")); - return SchemaResult::Make(*GetQuerySchema()); + if (command.transaction_id.empty()) { + return SchemaResult::Make(*GetQuerySchema()); + } else { + return SchemaResult::Make(*GetQueryWithTransactionSchema()); + } + } + + arrow::Result> GetSchemaSubstraitPlan( + const ServerCallContext& context, const sql::StatementSubstraitPlan& command, + const FlightDescriptor& descriptor) override { + ARROW_RETURN_NOT_OK( + AssertEq(kSubstraitPlanText, command.plan.plan, + "Unexpected statement in GetSchemaSubstraitPlan")); + ARROW_RETURN_NOT_OK( + AssertEq(kSubstraitVersion, command.plan.version, + "Unexpected version in GetFlightInfoSubstraitPlan")); + if (command.transaction_id.empty()) { + return SchemaResult::Make(*GetQuerySchema()); + } else { + return SchemaResult::Make(*GetQueryWithTransactionSchema()); + } } arrow::Result> DoGetStatement( const ServerCallContext& context, const sql::StatementQueryTicket& command) override { - return DoGetForTestCase(GetQuerySchema()); + if (command.statement_handle == "SELECT STATEMENT HANDLE" || + command.statement_handle == "PLAN HANDLE") { + return DoGetForTestCase(GetQuerySchema()); + } else if (command.statement_handle == "SELECT STATEMENT WITH TXN HANDLE" || + command.statement_handle == "PLAN WITH TXN HANDLE") { + return DoGetForTestCase(GetQueryWithTransactionSchema()); + } + return Status::Invalid("Unknown handle: ", command.statement_handle); } arrow::Result> GetFlightInfoPreparedStatement( const ServerCallContext& context, const sql::PreparedStatementQuery& command, const FlightDescriptor& descriptor) override { - ARROW_RETURN_NOT_OK(AssertEq("SELECT PREPARED STATEMENT HANDLE", - command.prepared_statement_handle, - "Unexpected prepared statement handle")); - - return GetFlightInfoForCommand(descriptor, GetQuerySchema()); + if (command.prepared_statement_handle == "SELECT PREPARED STATEMENT HANDLE" || + command.prepared_statement_handle == "PLAN HANDLE") { + return GetFlightInfoForCommand(descriptor, GetQuerySchema()); + } else if (command.prepared_statement_handle == + "SELECT PREPARED STATEMENT WITH TXN HANDLE" || + command.prepared_statement_handle == "PLAN WITH TXN HANDLE") { + return GetFlightInfoForCommand(descriptor, GetQueryWithTransactionSchema()); + } + return Status::Invalid("Invalid handle for GetFlightInfoForCommand: ", + command.prepared_statement_handle); } arrow::Result> GetSchemaPreparedStatement( const ServerCallContext& context, const sql::PreparedStatementQuery& command, const FlightDescriptor& descriptor) override { - ARROW_RETURN_NOT_OK(AssertEq("SELECT PREPARED STATEMENT HANDLE", - command.prepared_statement_handle, - "Unexpected prepared statement handle")); - return SchemaResult::Make(*GetQuerySchema()); + if (command.prepared_statement_handle == "SELECT PREPARED STATEMENT HANDLE" || + command.prepared_statement_handle == "PLAN HANDLE") { + return SchemaResult::Make(*GetQuerySchema()); + } else if (command.prepared_statement_handle == + "SELECT PREPARED STATEMENT WITH TXN HANDLE" || + command.prepared_statement_handle == "PLAN WITH TXN HANDLE") { + return SchemaResult::Make(*GetQueryWithTransactionSchema()); + } + return Status::Invalid("Invalid handle for GetSchemaPreparedStatement: ", + command.prepared_statement_handle); } arrow::Result> DoGetPreparedStatement( const ServerCallContext& context, const sql::PreparedStatementQuery& command) override { - return DoGetForTestCase(GetQuerySchema()); + if (command.prepared_statement_handle == "SELECT PREPARED STATEMENT HANDLE" || + command.prepared_statement_handle == "PLAN HANDLE") { + return DoGetForTestCase(GetQuerySchema()); + } else if (command.prepared_statement_handle == + "SELECT PREPARED STATEMENT WITH TXN HANDLE" || + command.prepared_statement_handle == "PLAN WITH TXN HANDLE") { + return DoGetForTestCase(GetQueryWithTransactionSchema()); + } + return Status::Invalid("Invalid handle: ", command.prepared_statement_handle); } arrow::Result> GetFlightInfoCatalogs( @@ -381,21 +523,29 @@ class FlightSqlScenarioServer : public sql::FlightSqlServerBase { arrow::Result> GetFlightInfoSqlInfo( const ServerCallContext& context, const sql::GetSqlInfo& command, const FlightDescriptor& descriptor) override { - ARROW_RETURN_NOT_OK(AssertEq(2, command.info.size(), - "Wrong number of SqlInfo values passed")); - ARROW_RETURN_NOT_OK( - AssertEq(sql::SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_NAME, - command.info[0], "Unexpected SqlInfo passed")); - ARROW_RETURN_NOT_OK( - AssertEq(sql::SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_READ_ONLY, - command.info[1], "Unexpected SqlInfo passed")); - - return GetFlightInfoForCommand(descriptor, sql::SqlSchema::GetSqlInfoSchema()); + if (command.info.size() == 2) { + // Integration test for the protocol messages + ARROW_RETURN_NOT_OK( + AssertEq(sql::SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_NAME, + command.info[0], "Unexpected SqlInfo passed")); + ARROW_RETURN_NOT_OK( + AssertEq(sql::SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_READ_ONLY, + command.info[1], "Unexpected SqlInfo passed")); + + return GetFlightInfoForCommand(descriptor, sql::SqlSchema::GetSqlInfoSchema()); + } + // Integration test for the values themselves + return sql::FlightSqlServerBase::GetFlightInfoSqlInfo(context, command, descriptor); } arrow::Result> DoGetSqlInfo( const ServerCallContext& context, const sql::GetSqlInfo& command) override { - return DoGetForTestCase(sql::SqlSchema::GetSqlInfoSchema()); + if (command.info.size() == 2) { + // Integration test for the protocol messages + return DoGetForTestCase(sql::SqlSchema::GetSqlInfoSchema()); + } + // Integration test for the values themselves + return sql::FlightSqlServerBase::DoGetSqlInfo(context, command); } arrow::Result> GetFlightInfoSchemas( @@ -539,8 +689,21 @@ class FlightSqlScenarioServer : public sql::FlightSqlServerBase { ARROW_RETURN_NOT_OK( AssertEq("UPDATE STATEMENT", command.query, "Wrong query for DoPutCommandStatementUpdate")); + return command.transaction_id.empty() ? kUpdateStatementExpectedRows + : kUpdateStatementWithTransactionExpectedRows; + } - return kUpdateStatementExpectedRows; + arrow::Result DoPutCommandSubstraitPlan( + const ServerCallContext& context, + const sql::StatementSubstraitPlan& command) override { + ARROW_RETURN_NOT_OK( + AssertEq(kSubstraitPlanText, command.plan.plan, + "Wrong plan for DoPutCommandSubstraitPlan")); + ARROW_RETURN_NOT_OK( + AssertEq(kSubstraitVersion, command.plan.version, + "Unexpected version in GetFlightInfoSubstraitPlan")); + return command.transaction_id.empty() ? kUpdateStatementExpectedRows + : kUpdateStatementWithTransactionExpectedRows; } arrow::Result CreatePreparedStatement( @@ -552,8 +715,26 @@ class FlightSqlScenarioServer : public sql::FlightSqlServerBase { } sql::ActionCreatePreparedStatementResult result; - result.prepared_statement_handle = request.query + " HANDLE"; + result.prepared_statement_handle = request.query; + if (!request.transaction_id.empty()) { + result.prepared_statement_handle += " WITH TXN"; + } + result.prepared_statement_handle += " HANDLE"; + return result; + } + arrow::Result CreatePreparedSubstraitPlan( + const ServerCallContext& context, + const sql::ActionCreatePreparedSubstraitPlanRequest& request) override { + ARROW_RETURN_NOT_OK( + AssertEq(kSubstraitPlanText, request.plan.plan, + "Wrong plan for CreatePreparedSubstraitPlan")); + ARROW_RETURN_NOT_OK( + AssertEq(kSubstraitVersion, request.plan.version, + "Unexpected version in GetFlightInfoSubstraitPlan")); + sql::ActionCreatePreparedStatementResult result; + result.prepared_statement_handle = + request.transaction_id.empty() ? "PLAN HANDLE" : "PLAN WITH TXN HANDLE"; return result; } @@ -561,7 +742,13 @@ class FlightSqlScenarioServer : public sql::FlightSqlServerBase { const ServerCallContext& context, const sql::ActionClosePreparedStatementRequest& request) override { if (request.prepared_statement_handle != "SELECT PREPARED STATEMENT HANDLE" && - request.prepared_statement_handle != "UPDATE PREPARED STATEMENT HANDLE") { + request.prepared_statement_handle != "UPDATE PREPARED STATEMENT HANDLE" && + request.prepared_statement_handle != "PLAN HANDLE" && + request.prepared_statement_handle != + "SELECT PREPARED STATEMENT WITH TXN HANDLE" && + request.prepared_statement_handle != + "UPDATE PREPARED STATEMENT WITH TXN HANDLE" && + request.prepared_statement_handle != "PLAN WITH TXN HANDLE") { return Status::Invalid("Invalid handle for ClosePreparedStatement: ", request.prepared_statement_handle); } @@ -572,28 +759,95 @@ class FlightSqlScenarioServer : public sql::FlightSqlServerBase { const sql::PreparedStatementQuery& command, FlightMessageReader* reader, FlightMetadataWriter* writer) override { - if (command.prepared_statement_handle != "SELECT PREPARED STATEMENT HANDLE") { + if (command.prepared_statement_handle != "SELECT PREPARED STATEMENT HANDLE" && + command.prepared_statement_handle != + "SELECT PREPARED STATEMENT WITH TXN HANDLE" && + command.prepared_statement_handle != "PLAN HANDLE" && + command.prepared_statement_handle != "PLAN WITH TXN HANDLE") { return Status::Invalid("Invalid handle for DoPutPreparedStatementQuery: ", command.prepared_statement_handle); } - ARROW_ASSIGN_OR_RAISE(auto actual_schema, reader->GetSchema()); ARROW_RETURN_NOT_OK(AssertEq(*GetQuerySchema(), *actual_schema, "Wrong schema for DoPutPreparedStatementQuery")); - return Status::OK(); } arrow::Result DoPutPreparedStatementUpdate( const ServerCallContext& context, const sql::PreparedStatementUpdate& command, FlightMessageReader* reader) override { - if (command.prepared_statement_handle == "UPDATE PREPARED STATEMENT HANDLE") { + if (command.prepared_statement_handle == "UPDATE PREPARED STATEMENT HANDLE" || + command.prepared_statement_handle == "PLAN HANDLE") { return kUpdatePreparedStatementExpectedRows; + } else if (command.prepared_statement_handle == + "UPDATE PREPARED STATEMENT WITH TXN HANDLE" || + command.prepared_statement_handle == "PLAN WITH TXN HANDLE") { + return kUpdatePreparedStatementWithTransactionExpectedRows; } return Status::Invalid("Invalid handle for DoPutPreparedStatementUpdate: ", command.prepared_statement_handle); } + arrow::Result BeginSavepoint( + const ServerCallContext& context, + const sql::ActionBeginSavepointRequest& request) override { + ARROW_RETURN_NOT_OK(AssertEq( + kSavepointName, request.name, "Unexpected savepoint name in BeginSavepoint")); + ARROW_RETURN_NOT_OK( + AssertEq(kTransactionId, request.transaction_id, + "Unexpected transaction ID in BeginSavepoint")); + return sql::ActionBeginSavepointResult{kSavepointId}; + } + + arrow::Result BeginTransaction( + const ServerCallContext& context, + const sql::ActionBeginTransactionRequest& request) override { + return sql::ActionBeginTransactionResult{kTransactionId}; + } + + arrow::Result CancelQuery( + const ServerCallContext& context, + const sql::ActionCancelQueryRequest& request) override { + ARROW_RETURN_NOT_OK(AssertEq(1, request.info->endpoints().size(), + "Expected 1 endpoint for CancelQuery")); + const FlightEndpoint& endpoint = request.info->endpoints()[0]; + ARROW_ASSIGN_OR_RAISE(auto ticket, + sql::StatementQueryTicket::Deserialize(endpoint.ticket.ticket)); + ARROW_RETURN_NOT_OK(AssertEq("PLAN HANDLE", ticket.statement_handle, + "Unexpected ticket in CancelQuery")); + return sql::CancelResult::kCancelled; + } + + Status EndSavepoint(const ServerCallContext& context, + const sql::ActionEndSavepointRequest& request) override { + switch (request.action) { + case sql::ActionEndSavepointRequest::kRelease: + case sql::ActionEndSavepointRequest::kRollback: + ARROW_RETURN_NOT_OK( + AssertEq(kSavepointId, request.savepoint_id, + "Unexpected savepoint ID in EndSavepoint")); + break; + default: + return Status::Invalid("Unknown action ", static_cast(request.action)); + } + return Status::OK(); + } + + Status EndTransaction(const ServerCallContext& context, + const sql::ActionEndTransactionRequest& request) override { + switch (request.action) { + case sql::ActionEndTransactionRequest::kCommit: + case sql::ActionEndTransactionRequest::kRollback: + ARROW_RETURN_NOT_OK( + AssertEq(kTransactionId, request.transaction_id, + "Unexpected transaction ID in EndTransaction")); + break; + default: + return Status::Invalid("Unknown action ", static_cast(request.action)); + } + return Status::OK(); + } + private: arrow::Result> GetFlightInfoForCommand( const FlightDescriptor& descriptor, const std::shared_ptr& schema) { @@ -615,6 +869,7 @@ class FlightSqlScenarioServer : public sql::FlightSqlServerBase { /// implementations. This should ensure that RPC objects are being built and parsed /// correctly for multiple languages and that the Arrow schemas are returned as expected. class FlightSqlScenario : public Scenario { + public: Status MakeServer(std::unique_ptr* server, FlightServerOptions* options) override { server->reset(new FlightSqlScenarioServer()); @@ -785,10 +1040,290 @@ class FlightSqlScenario : public Scenario { AssertEq(kUpdatePreparedStatementExpectedRows, updated_rows, "Wrong number of updated rows for prepared statement ExecuteUpdate")); ARROW_RETURN_NOT_OK(update_prepared_statement->Close()); + return Status::OK(); + } +}; + +/// \brief Integration test scenario for validating the Substrait and +/// transaction extensions to Flight SQL. +class FlightSqlExtensionScenario : public FlightSqlScenario { + public: + Status RunClient(std::unique_ptr client) override { + sql::FlightSqlClient sql_client(std::move(client)); + Status status; + if (!(status = ValidateMetadataRetrieval(&sql_client)).ok()) { + return status.WithMessage("MetadataRetrieval failed: ", status.message()); + } + if (!(status = ValidateStatementExecution(&sql_client)).ok()) { + return status.WithMessage("StatementExecution failed: ", status.message()); + } + if (!(status = ValidatePreparedStatementExecution(&sql_client)).ok()) { + return status.WithMessage("PreparedStatementExecution failed: ", status.message()); + } + if (!(status = ValidateTransactions(&sql_client)).ok()) { + return status.WithMessage("Transactions failed: ", status.message()); + } + return Status::OK(); + } + + Status ValidateMetadataRetrieval(sql::FlightSqlClient* sql_client) { + std::unique_ptr info; + std::vector sql_info = { + sql::SqlInfoOptions::FLIGHT_SQL_SERVER_SQL, + sql::SqlInfoOptions::FLIGHT_SQL_SERVER_SUBSTRAIT, + sql::SqlInfoOptions::FLIGHT_SQL_SERVER_SUBSTRAIT_MIN_VERSION, + sql::SqlInfoOptions::FLIGHT_SQL_SERVER_SUBSTRAIT_MAX_VERSION, + sql::SqlInfoOptions::FLIGHT_SQL_SERVER_TRANSACTION, + sql::SqlInfoOptions::FLIGHT_SQL_SERVER_CANCEL, + sql::SqlInfoOptions::FLIGHT_SQL_SERVER_STATEMENT_TIMEOUT, + sql::SqlInfoOptions::FLIGHT_SQL_SERVER_TRANSACTION_TIMEOUT, + }; + ARROW_ASSIGN_OR_RAISE(info, sql_client->GetSqlInfo({}, sql_info)); + ARROW_ASSIGN_OR_RAISE(auto reader, + sql_client->DoGet({}, info->endpoints()[0].ticket)); + + ARROW_ASSIGN_OR_RAISE(auto actual_schema, reader->GetSchema()); + if (!sql::SqlSchema::GetSqlInfoSchema()->Equals(*actual_schema, + /*check_metadata=*/true)) { + return Status::Invalid("Schemas did not match. Expected:\n", + *sql::SqlSchema::GetSqlInfoSchema(), "\nActual:\n", + *actual_schema); + } + + sql::SqlInfoResultMap info_values; + while (true) { + ARROW_ASSIGN_OR_RAISE(auto chunk, reader->Next()); + if (!chunk.data) break; + + const auto& info_name = checked_cast(*chunk.data->column(0)); + const auto& value = checked_cast(*chunk.data->column(1)); + + for (int64_t i = 0; i < chunk.data->num_rows(); i++) { + const uint32_t code = info_name.Value(i); + if (info_values.find(code) != info_values.end()) { + return Status::Invalid("Duplicate SqlInfo value ", code); + } + switch (value.type_code(i)) { + case 0: { // string + std::string slot = checked_cast(*value.field(0)) + .GetString(value.value_offset(i)); + info_values[code] = sql::SqlInfoResult(std::move(slot)); + break; + } + case 1: { // bool + bool slot = checked_cast(*value.field(1)) + .Value(value.value_offset(i)); + info_values[code] = sql::SqlInfoResult(slot); + break; + } + case 2: { // int64_t + int64_t slot = checked_cast(*value.field(2)) + .Value(value.value_offset(i)); + info_values[code] = sql::SqlInfoResult(slot); + break; + } + case 3: { // int32_t + int32_t slot = checked_cast(*value.field(3)) + .Value(value.value_offset(i)); + info_values[code] = sql::SqlInfoResult(slot); + break; + } + default: + return Status::NotImplemented("Decoding SqlInfoResult of type code ", + value.type_code(i)); + } + } + } + + ARROW_RETURN_NOT_OK(AssertUnprintableEq( + info_values[sql::SqlInfoOptions::FLIGHT_SQL_SERVER_SQL], + sql::SqlInfoResult(false), "FLIGHT_SQL_SERVER_SQL did not match")); + ARROW_RETURN_NOT_OK(AssertUnprintableEq( + info_values[sql::SqlInfoOptions::FLIGHT_SQL_SERVER_SUBSTRAIT], + sql::SqlInfoResult(true), "FLIGHT_SQL_SERVER_SUBSTRAIT did not match")); + ARROW_RETURN_NOT_OK(AssertUnprintableEq( + info_values[sql::SqlInfoOptions::FLIGHT_SQL_SERVER_SUBSTRAIT_MIN_VERSION], + sql::SqlInfoResult(std::string("min_version")), + "FLIGHT_SQL_SERVER_SUBSTRAIT_MIN_VERSION did not match")); + ARROW_RETURN_NOT_OK(AssertUnprintableEq( + info_values[sql::SqlInfoOptions::FLIGHT_SQL_SERVER_SUBSTRAIT_MAX_VERSION], + sql::SqlInfoResult(std::string("max_version")), + "FLIGHT_SQL_SERVER_SUBSTRAIT_MAX_VERSION did not match")); + ARROW_RETURN_NOT_OK(AssertUnprintableEq( + info_values[sql::SqlInfoOptions::FLIGHT_SQL_SERVER_TRANSACTION], + sql::SqlInfoResult(sql::SqlInfoOptions::SqlSupportedTransaction:: + SQL_SUPPORTED_TRANSACTION_SAVEPOINT), + "FLIGHT_SQL_SERVER_TRANSACTION did not match")); + ARROW_RETURN_NOT_OK(AssertUnprintableEq( + info_values[sql::SqlInfoOptions::FLIGHT_SQL_SERVER_CANCEL], + sql::SqlInfoResult(true), "FLIGHT_SQL_SERVER_CANCEL did not match")); + ARROW_RETURN_NOT_OK(AssertUnprintableEq( + info_values[sql::SqlInfoOptions::FLIGHT_SQL_SERVER_STATEMENT_TIMEOUT], + sql::SqlInfoResult(int32_t(42)), + "FLIGHT_SQL_SERVER_STATEMENT_TIMEOUT did not match")); + ARROW_RETURN_NOT_OK(AssertUnprintableEq( + info_values[sql::SqlInfoOptions::FLIGHT_SQL_SERVER_TRANSACTION_TIMEOUT], + sql::SqlInfoResult(int32_t(7)), + "FLIGHT_SQL_SERVER_TRANSACTION_TIMEOUT did not match")); + + return Status::OK(); + } + + Status ValidateStatementExecution(sql::FlightSqlClient* sql_client) { + ARROW_ASSIGN_OR_RAISE(std::unique_ptr info, + sql_client->ExecuteSubstrait({}, kSubstraitPlan)); + ARROW_RETURN_NOT_OK(Validate(GetQuerySchema(), *info, sql_client)); + + ARROW_ASSIGN_OR_RAISE(std::unique_ptr schema, + sql_client->GetExecuteSubstraitSchema({}, kSubstraitPlan)); + ARROW_RETURN_NOT_OK(ValidateSchema(GetQuerySchema(), *schema)); + + ARROW_ASSIGN_OR_RAISE(info, sql_client->ExecuteSubstrait({}, kSubstraitPlan)); + ARROW_ASSIGN_OR_RAISE(sql::CancelResult cancel_result, + sql_client->CancelQuery({}, *info)); + ARROW_RETURN_NOT_OK( + AssertEq(sql::CancelResult::kCancelled, cancel_result, "Wrong cancel result")); + + ARROW_ASSIGN_OR_RAISE(const int64_t updated_rows, + sql_client->ExecuteSubstraitUpdate({}, kSubstraitPlan)); + ARROW_RETURN_NOT_OK( + AssertEq(kUpdateStatementExpectedRows, updated_rows, + "Wrong number of updated rows for ExecuteSubstraitUpdate")); + + return Status::OK(); + } + + Status ValidatePreparedStatementExecution(sql::FlightSqlClient* sql_client) { + auto parameters = + RecordBatch::Make(GetQuerySchema(), 1, {ArrayFromJSON(int64(), "[1]")}); + + ARROW_ASSIGN_OR_RAISE( + std::shared_ptr substrait_prepared_statement, + sql_client->PrepareSubstrait({}, kSubstraitPlan)); + ARROW_RETURN_NOT_OK(substrait_prepared_statement->SetParameters(parameters)); + ARROW_ASSIGN_OR_RAISE(std::unique_ptr info, + substrait_prepared_statement->Execute()); + ARROW_RETURN_NOT_OK(Validate(GetQuerySchema(), *info, sql_client)); + ARROW_ASSIGN_OR_RAISE(std::unique_ptr schema, + substrait_prepared_statement->GetSchema({})); + ARROW_RETURN_NOT_OK(ValidateSchema(GetQuerySchema(), *schema)); + ARROW_RETURN_NOT_OK(substrait_prepared_statement->Close()); + + ARROW_ASSIGN_OR_RAISE( + std::shared_ptr update_substrait_prepared_statement, + sql_client->PrepareSubstrait({}, kSubstraitPlan)); + ARROW_ASSIGN_OR_RAISE(const int64_t updated_rows, + update_substrait_prepared_statement->ExecuteUpdate()); + ARROW_RETURN_NOT_OK( + AssertEq(kUpdatePreparedStatementExpectedRows, updated_rows, + "Wrong number of updated rows for prepared statement ExecuteUpdate")); + ARROW_RETURN_NOT_OK(update_substrait_prepared_statement->Close()); + + return Status::OK(); + } + + Status ValidateTransactions(sql::FlightSqlClient* sql_client) { + ARROW_ASSIGN_OR_RAISE(sql::Transaction transaction, sql_client->BeginTransaction({})); + ARROW_RETURN_NOT_OK(AssertEq( + kTransactionId, transaction.transaction_id(), "Wrong transaction ID")); + + ARROW_ASSIGN_OR_RAISE(sql::Savepoint savepoint, + sql_client->BeginSavepoint({}, transaction, kSavepointName)); + ARROW_RETURN_NOT_OK(AssertEq(kSavepointId, savepoint.savepoint_id(), + "Wrong savepoint ID")); + + ARROW_ASSIGN_OR_RAISE(std::unique_ptr info, + sql_client->Execute({}, kSelectStatement, transaction)); + ARROW_RETURN_NOT_OK(Validate(GetQueryWithTransactionSchema(), *info, sql_client)); + + ARROW_ASSIGN_OR_RAISE(info, + sql_client->ExecuteSubstrait({}, kSubstraitPlan, transaction)); + ARROW_RETURN_NOT_OK(Validate(GetQueryWithTransactionSchema(), *info, sql_client)); + + ARROW_ASSIGN_OR_RAISE( + std::unique_ptr schema, + sql_client->GetExecuteSchema({}, kSelectStatement, transaction)); + ARROW_RETURN_NOT_OK(ValidateSchema(GetQueryWithTransactionSchema(), *schema)); + + ARROW_ASSIGN_OR_RAISE( + schema, sql_client->GetExecuteSubstraitSchema({}, kSubstraitPlan, transaction)); + ARROW_RETURN_NOT_OK(ValidateSchema(GetQueryWithTransactionSchema(), *schema)); + + ARROW_ASSIGN_OR_RAISE(int64_t updated_rows, + sql_client->ExecuteUpdate({}, "UPDATE STATEMENT", transaction)); + ARROW_RETURN_NOT_OK( + AssertEq(kUpdateStatementWithTransactionExpectedRows, updated_rows, + "Wrong number of updated rows for ExecuteUpdate with transaction")); + ARROW_ASSIGN_OR_RAISE(updated_rows, sql_client->ExecuteSubstraitUpdate( + {}, kSubstraitPlan, transaction)); + ARROW_RETURN_NOT_OK(AssertEq( + kUpdateStatementWithTransactionExpectedRows, updated_rows, + "Wrong number of updated rows for ExecuteSubstraitUpdate with transaction")); + + auto parameters = + RecordBatch::Make(GetQuerySchema(), 1, {ArrayFromJSON(int64(), "[1]")}); + + ARROW_ASSIGN_OR_RAISE( + std::shared_ptr select_prepared_statement, + sql_client->Prepare({}, "SELECT PREPARED STATEMENT", transaction)); + ARROW_RETURN_NOT_OK(select_prepared_statement->SetParameters(parameters)); + ARROW_ASSIGN_OR_RAISE(info, select_prepared_statement->Execute()); + ARROW_RETURN_NOT_OK(Validate(GetQueryWithTransactionSchema(), *info, sql_client)); + ARROW_ASSIGN_OR_RAISE(schema, select_prepared_statement->GetSchema({})); + ARROW_RETURN_NOT_OK(ValidateSchema(GetQueryWithTransactionSchema(), *schema)); + ARROW_RETURN_NOT_OK(select_prepared_statement->Close()); + + ARROW_ASSIGN_OR_RAISE( + std::shared_ptr substrait_prepared_statement, + sql_client->PrepareSubstrait({}, kSubstraitPlan, transaction)); + ARROW_RETURN_NOT_OK(substrait_prepared_statement->SetParameters(parameters)); + ARROW_ASSIGN_OR_RAISE(info, substrait_prepared_statement->Execute()); + ARROW_RETURN_NOT_OK(Validate(GetQueryWithTransactionSchema(), *info, sql_client)); + ARROW_ASSIGN_OR_RAISE(schema, substrait_prepared_statement->GetSchema({})); + ARROW_RETURN_NOT_OK(ValidateSchema(GetQueryWithTransactionSchema(), *schema)); + ARROW_RETURN_NOT_OK(substrait_prepared_statement->Close()); + + ARROW_ASSIGN_OR_RAISE( + std::shared_ptr update_prepared_statement, + sql_client->Prepare({}, "UPDATE PREPARED STATEMENT", transaction)); + ARROW_ASSIGN_OR_RAISE(updated_rows, update_prepared_statement->ExecuteUpdate()); + ARROW_RETURN_NOT_OK(AssertEq(kUpdatePreparedStatementWithTransactionExpectedRows, + updated_rows, + "Wrong number of updated rows for prepared statement " + "ExecuteUpdate with transaction")); + ARROW_RETURN_NOT_OK(update_prepared_statement->Close()); + + ARROW_ASSIGN_OR_RAISE( + std::shared_ptr update_substrait_prepared_statement, + sql_client->PrepareSubstrait({}, kSubstraitPlan, transaction)); + ARROW_ASSIGN_OR_RAISE(updated_rows, + update_substrait_prepared_statement->ExecuteUpdate()); + ARROW_RETURN_NOT_OK(AssertEq(kUpdatePreparedStatementWithTransactionExpectedRows, + updated_rows, + "Wrong number of updated rows for prepared statement " + "ExecuteUpdate with transaction")); + ARROW_RETURN_NOT_OK(update_substrait_prepared_statement->Close()); + + ARROW_RETURN_NOT_OK(sql_client->Rollback({}, savepoint)); + + ARROW_ASSIGN_OR_RAISE(sql::Savepoint savepoint2, + sql_client->BeginSavepoint({}, transaction, kSavepointName)); + ARROW_RETURN_NOT_OK(AssertEq(kSavepointId, savepoint.savepoint_id(), + "Wrong savepoint ID")); + ARROW_RETURN_NOT_OK(sql_client->Release({}, savepoint)); + + ARROW_RETURN_NOT_OK(sql_client->Commit({}, transaction)); + + ARROW_ASSIGN_OR_RAISE(sql::Transaction transaction2, + sql_client->BeginTransaction({})); + ARROW_RETURN_NOT_OK(AssertEq( + kTransactionId, transaction.transaction_id(), "Wrong transaction ID")); + ARROW_RETURN_NOT_OK(sql_client->Rollback({}, transaction2)); return Status::OK(); } }; +} // namespace Status GetScenario(const std::string& scenario_name, std::shared_ptr* out) { if (scenario_name == "auth:basic_proto") { @@ -800,6 +1335,9 @@ Status GetScenario(const std::string& scenario_name, std::shared_ptr* } else if (scenario_name == "flight_sql") { *out = std::make_shared(); return Status::OK(); + } else if (scenario_name == "flight_sql:extension") { + *out = std::make_shared(); + return Status::OK(); } return Status::KeyError("Scenario not found: ", scenario_name); } diff --git a/cpp/src/arrow/flight/server.cc b/cpp/src/arrow/flight/server.cc index 1a3b52910c0ad..e9736b0615e44 100644 --- a/cpp/src/arrow/flight/server.cc +++ b/cpp/src/arrow/flight/server.cc @@ -353,7 +353,9 @@ RecordBatchStream::RecordBatchStream(const std::shared_ptr& r impl_.reset(new RecordBatchStreamImpl(reader, options)); } -RecordBatchStream::~RecordBatchStream() {} +RecordBatchStream::~RecordBatchStream() { + ARROW_WARN_NOT_OK(impl_->Close(), "Failed to close FlightDataStream"); +} Status RecordBatchStream::Close() { return impl_->Close(); } diff --git a/cpp/src/arrow/flight/sql/CMakeLists.txt b/cpp/src/arrow/flight/sql/CMakeLists.txt index f7312de23a972..14503069dd004 100644 --- a/cpp/src/arrow/flight/sql/CMakeLists.txt +++ b/cpp/src/arrow/flight/sql/CMakeLists.txt @@ -89,7 +89,11 @@ if(ARROW_BUILD_TESTS OR ARROW_BUILD_EXAMPLES) example/sqlite_statement_batch_reader.cc example/sqlite_server.cc example/sqlite_tables_schema_batch_reader.cc) + set(ARROW_FLIGHT_SQL_TEST_SRCS server_test.cc) + set(ARROW_FLIGHT_SQL_TEST_LIBS ${SQLite3_LIBRARIES}) + set(ARROW_FLIGHT_SQL_ACERO_SRCS example/acero_server.cc) + if(NOT MSVC AND NOT MINGW) # ARROW-16902: getting Protobuf generated code to have all the # proper dllexport/dllimport declarations is difficult, since @@ -98,13 +102,34 @@ if(ARROW_BUILD_TESTS OR ARROW_BUILD_EXAMPLES) list(APPEND ARROW_FLIGHT_SQL_TEST_SRCS client_test.cc) endif() + if(ARROW_COMPUTE + AND ARROW_PARQUET + AND ARROW_SUBSTRAIT) + list(APPEND ARROW_FLIGHT_SQL_TEST_SRCS ${ARROW_FLIGHT_SQL_ACERO_SRCS} acero_test.cc) + if(ARROW_FLIGHT_TEST_LINKAGE STREQUAL "static") + list(APPEND ARROW_FLIGHT_SQL_TEST_LIBS arrow_substrait_static) + else() + list(APPEND ARROW_FLIGHT_SQL_TEST_LIBS arrow_substrait_shared) + endif() + + if(ARROW_BUILD_EXAMPLES) + add_executable(acero-flight-sql-server ${ARROW_FLIGHT_SQL_ACERO_SRCS} + example/acero_main.cc) + target_link_libraries(acero-flight-sql-server + PRIVATE ${ARROW_FLIGHT_SQL_TEST_LINK_LIBS} + ${ARROW_FLIGHT_SQL_TEST_LIBS} ${GFLAGS_LIBRARIES}) + endif() + endif() + add_arrow_test(flight_sql_test SOURCES ${ARROW_FLIGHT_SQL_TEST_SRCS} ${ARROW_FLIGHT_SQL_TEST_SERVER_SRCS} STATIC_LINK_LIBS ${ARROW_FLIGHT_SQL_TEST_LINK_LIBS} - ${SQLite3_LIBRARIES} + ${ARROW_FLIGHT_SQL_TEST_LIBS} + EXTRA_INCLUDES + "${CMAKE_CURRENT_BINARY_DIR}/../" LABELS "arrow_flight_sql") diff --git a/cpp/src/arrow/flight/sql/acero_test.cc b/cpp/src/arrow/flight/sql/acero_test.cc new file mode 100644 index 0000000000000..fd3c52e74f39e --- /dev/null +++ b/cpp/src/arrow/flight/sql/acero_test.cc @@ -0,0 +1,239 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +/// Integration test using the Acero backend + +#include +#include + +#include +#include + +#include "arrow/array.h" +#include "arrow/engine/substrait/util.h" +#include "arrow/flight/server.h" +#include "arrow/flight/sql/client.h" +#include "arrow/flight/sql/example/acero_server.h" +#include "arrow/flight/sql/types.h" +#include "arrow/flight/types.h" +#include "arrow/stl_iterator.h" +#include "arrow/table.h" +#include "arrow/testing/gtest_util.h" +#include "arrow/type_fwd.h" +#include "arrow/util/checked_cast.h" + +namespace arrow { +namespace flight { +namespace sql { + +using arrow::internal::checked_cast; + +class TestAcero : public ::testing::Test { + public: + void SetUp() override { + ASSERT_OK_AND_ASSIGN(auto location, Location::ForGrpcTcp("localhost", 0)); + flight::FlightServerOptions options(location); + + ASSERT_OK_AND_ASSIGN(server_, acero_example::MakeAceroServer()); + ASSERT_OK(server_->Init(options)); + + ASSERT_OK_AND_ASSIGN(auto client, FlightClient::Connect(server_->location())); + client_.reset(new FlightSqlClient(std::move(client))); + } + + void TearDown() override { + ASSERT_OK(client_->Close()); + ASSERT_OK(server_->Shutdown()); + } + + protected: + std::unique_ptr client_; + std::unique_ptr server_; +}; + +arrow::Result> MakeSubstraitPlan() { + ARROW_ASSIGN_OR_RAISE(std::string dir_string, + arrow::internal::GetEnvVar("PARQUET_TEST_DATA")); + ARROW_ASSIGN_OR_RAISE(auto dir, + arrow::internal::PlatformFilename::FromString(dir_string)); + ARROW_ASSIGN_OR_RAISE(auto filename, dir.Join("binary.parquet")); + std::string uri = std::string("file://") + filename.ToString(); + + // TODO(ARROW-17229): we should use a RootRel here + std::string json_plan = R"({ + "relations": [ + { + "rel": { + "read": { + "base_schema": { + "struct": { + "types": [ + {"binary": {}} + ] + }, + "names": [ + "foo" + ] + }, + "local_files": { + "items": [ + { + "uri_file": "URI_PLACEHOLDER", + "parquet": {} + } + ] + } + } + } + } + ] +})"; + std::string uri_placeholder = "URI_PLACEHOLDER"; + json_plan.replace(json_plan.find(uri_placeholder), uri_placeholder.size(), uri); + return engine::SerializeJsonPlan(json_plan); +} + +TEST_F(TestAcero, GetSqlInfo) { + FlightCallOptions call_options; + std::vector sql_info_codes = { + SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_SUBSTRAIT, + SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_TRANSACTION, + }; + ASSERT_OK_AND_ASSIGN(auto flight_info, + client_->GetSqlInfo(call_options, sql_info_codes)); + ASSERT_OK_AND_ASSIGN(auto reader, + client_->DoGet(call_options, flight_info->endpoints()[0].ticket)); + ASSERT_OK_AND_ASSIGN(auto results, reader->ToTable()); + ASSERT_OK_AND_ASSIGN(auto batch, results->CombineChunksToBatch()); + ASSERT_EQ(2, results->num_rows()); + std::vector> info; + const auto& ids = checked_cast(*batch->column(0)); + const auto& values = checked_cast(*batch->column(1)); + for (int64_t i = 0; i < batch->num_rows(); i++) { + ASSERT_OK_AND_ASSIGN(auto scalar, values.GetScalar(i)); + if (ids.Value(i) == SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_SUBSTRAIT) { + ASSERT_EQ(*checked_cast(*scalar).value, + BooleanScalar(true)); + } else if (ids.Value(i) == SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_TRANSACTION) { + ASSERT_EQ( + *checked_cast(*scalar).value, + Int32Scalar( + SqlInfoOptions::SqlSupportedTransaction::SQL_SUPPORTED_TRANSACTION_NONE)); + } else { + FAIL() << "Unexpected info value: " << ids.Value(i); + } + } +} + +TEST_F(TestAcero, Scan) { +#ifdef _WIN32 + GTEST_SKIP() << "ARROW-16392: Substrait File URI not supported for Windows"; +#endif + + FlightCallOptions call_options; + ASSERT_OK_AND_ASSIGN(auto serialized_plan, MakeSubstraitPlan()); + + SubstraitPlan plan{serialized_plan->ToString(), /*version=*/"0.6.0"}; + ASSERT_OK_AND_ASSIGN(std::unique_ptr info, + client_->ExecuteSubstrait(call_options, plan)); + ipc::DictionaryMemo memo; + ASSERT_OK_AND_ASSIGN(auto schema, info->GetSchema(&memo)); + // TODO(ARROW-17229): the scanner "special" fields are still included, strip them + // manually + auto fixed_schema = arrow::schema({schema->fields()[0]}); + ASSERT_NO_FATAL_FAILURE( + AssertSchemaEqual(fixed_schema, arrow::schema({field("foo", binary())}))); + + ASSERT_EQ(1, info->endpoints().size()); + ASSERT_EQ(0, info->endpoints()[0].locations.size()); + ASSERT_OK_AND_ASSIGN(auto reader, + client_->DoGet(call_options, info->endpoints()[0].ticket)); + ASSERT_OK_AND_ASSIGN(auto reader_schema, reader->GetSchema()); + ASSERT_NO_FATAL_FAILURE(AssertSchemaEqual(schema, reader_schema)); + ASSERT_OK_AND_ASSIGN(auto table, reader->ToTable()); + ASSERT_GT(table->num_rows(), 0); +} + +TEST_F(TestAcero, Update) { +#ifdef _WIN32 + GTEST_SKIP() << "ARROW-16392: Substrait File URI not supported for Windows"; +#endif + + FlightCallOptions call_options; + ASSERT_OK_AND_ASSIGN(auto serialized_plan, MakeSubstraitPlan()); + SubstraitPlan plan{serialized_plan->ToString(), /*version=*/"0.6.0"}; + EXPECT_RAISES_WITH_MESSAGE_THAT(NotImplemented, + ::testing::HasSubstr("Updates are unsupported"), + client_->ExecuteSubstraitUpdate(call_options, plan)); +} + +TEST_F(TestAcero, Prepare) { +#ifdef _WIN32 + GTEST_SKIP() << "ARROW-16392: Substrait File URI not supported for Windows"; +#endif + + FlightCallOptions call_options; + ASSERT_OK_AND_ASSIGN(auto serialized_plan, MakeSubstraitPlan()); + SubstraitPlan plan{serialized_plan->ToString(), /*version=*/"0.6.0"}; + ASSERT_OK_AND_ASSIGN(auto prepared_statement, + client_->PrepareSubstrait(call_options, plan)); + ASSERT_NE(prepared_statement->dataset_schema(), nullptr); + ASSERT_EQ(prepared_statement->parameter_schema(), nullptr); + + auto fixed_schema = arrow::schema({prepared_statement->dataset_schema()->fields()[0]}); + ASSERT_NO_FATAL_FAILURE( + AssertSchemaEqual(fixed_schema, arrow::schema({field("foo", binary())}))); + + EXPECT_RAISES_WITH_MESSAGE_THAT(NotImplemented, + ::testing::HasSubstr("Updates are unsupported"), + prepared_statement->ExecuteUpdate()); + + ASSERT_OK_AND_ASSIGN(std::unique_ptr info, prepared_statement->Execute()); + ASSERT_EQ(1, info->endpoints().size()); + ASSERT_EQ(0, info->endpoints()[0].locations.size()); + ASSERT_OK_AND_ASSIGN(auto reader, + client_->DoGet(call_options, info->endpoints()[0].ticket)); + ASSERT_OK_AND_ASSIGN(auto reader_schema, reader->GetSchema()); + ASSERT_NO_FATAL_FAILURE( + AssertSchemaEqual(prepared_statement->dataset_schema(), reader_schema)); + ASSERT_OK_AND_ASSIGN(auto table, reader->ToTable()); + ASSERT_GT(table->num_rows(), 0); + + ASSERT_OK(prepared_statement->Close()); +} + +TEST_F(TestAcero, Transactions) { +#ifdef _WIN32 + GTEST_SKIP() << "ARROW-16392: Substrait File URI not supported for Windows"; +#endif + + FlightCallOptions call_options; + ASSERT_OK_AND_ASSIGN(auto serialized_plan, MakeSubstraitPlan()); + Transaction handle("fake-id"); + SubstraitPlan plan{serialized_plan->ToString(), /*version=*/"0.6.0"}; + + EXPECT_RAISES_WITH_MESSAGE_THAT(NotImplemented, + ::testing::HasSubstr("Transactions are unsupported"), + client_->ExecuteSubstrait(call_options, plan, handle)); + EXPECT_RAISES_WITH_MESSAGE_THAT(NotImplemented, + ::testing::HasSubstr("Transactions are unsupported"), + client_->PrepareSubstrait(call_options, plan, handle)); +} + +} // namespace sql +} // namespace flight +} // namespace arrow diff --git a/cpp/src/arrow/flight/sql/client.cc b/cpp/src/arrow/flight/sql/client.cc index e299b7ceb11d4..521cf9e8cd694 100644 --- a/cpp/src/arrow/flight/sql/client.cc +++ b/cpp/src/arrow/flight/sql/client.cc @@ -66,8 +66,65 @@ arrow::Result> GetSchemaForCommand( GetFlightDescriptorForCommand(command)); return client->GetSchema(options, descriptor); } + +::arrow::Result PackAction(const std::string& action_type, + const google::protobuf::Message& message) { + google::protobuf::Any any; + if (!any.PackFrom(message)) { + return Status::SerializationError("Could not pack ", message.GetTypeName(), + " into Any"); + } + + std::string buffer; + if (!any.SerializeToString(&buffer)) { + return Status::SerializationError("Could not serialize packed ", + message.GetTypeName()); + } + + Action action; + action.type = action_type; + action.body = Buffer::FromString(std::move(buffer)); + return action; +} + +void SetPlan(const SubstraitPlan& plan, flight_sql_pb::SubstraitPlan* pb_plan) { + pb_plan->set_plan(plan.plan); + pb_plan->set_version(plan.version); +} + +Status ReadResult(ResultStream* results, google::protobuf::Message* message) { + ARROW_ASSIGN_OR_RAISE(auto result, results->Next()); + if (!result) { + return Status::IOError("Server did not return a result for ", message->GetTypeName()); + } + + google::protobuf::Any container; + if (!container.ParseFromArray(result->body->data(), + static_cast(result->body->size()))) { + return Status::IOError("Unable to parse Any (expecting ", message->GetTypeName(), + ")"); + } + if (!container.UnpackTo(message)) { + return Status::IOError("Unable to unpack Any (expecting ", message->GetTypeName(), + ")"); + } + return Status::OK(); +} + +Status DrainResultStream(ResultStream* results) { + while (true) { + ARROW_ASSIGN_OR_RAISE(auto result, results->Next()); + if (!result) break; + } + return Status::OK(); +} } // namespace +const Transaction& no_transaction() { + static Transaction kInvalidTransaction(""); + return kInvalidTransaction; +} + FlightSqlClient::FlightSqlClient(std::shared_ptr client) : impl_(std::move(client)) {} @@ -90,25 +147,59 @@ PreparedStatement::~PreparedStatement() { } arrow::Result> FlightSqlClient::Execute( - const FlightCallOptions& options, const std::string& query) { + const FlightCallOptions& options, const std::string& query, + const Transaction& transaction) { flight_sql_pb::CommandStatementQuery command; command.set_query(query); + if (transaction.is_valid()) { + command.set_transaction_id(transaction.transaction_id()); + } return GetFlightInfoForCommand(this, options, command); } arrow::Result> FlightSqlClient::GetExecuteSchema( - const FlightCallOptions& options, const std::string& query) { + const FlightCallOptions& options, const std::string& query, + const Transaction& transaction) { flight_sql_pb::CommandStatementQuery command; command.set_query(query); + if (transaction.is_valid()) { + command.set_transaction_id(transaction.transaction_id()); + } + return GetSchemaForCommand(this, options, command); +} +arrow::Result> FlightSqlClient::ExecuteSubstrait( + const FlightCallOptions& options, const SubstraitPlan& plan, + const Transaction& transaction) { + flight_sql_pb::CommandStatementSubstraitPlan command; + SetPlan(plan, command.mutable_plan()); + if (transaction.is_valid()) { + command.set_transaction_id(transaction.transaction_id()); + } + + return GetFlightInfoForCommand(this, options, command); +} + +arrow::Result> FlightSqlClient::GetExecuteSubstraitSchema( + const FlightCallOptions& options, const SubstraitPlan& plan, + const Transaction& transaction) { + flight_sql_pb::CommandStatementSubstraitPlan command; + SetPlan(plan, command.mutable_plan()); + if (transaction.is_valid()) { + command.set_transaction_id(transaction.transaction_id()); + } return GetSchemaForCommand(this, options, command); } arrow::Result FlightSqlClient::ExecuteUpdate(const FlightCallOptions& options, - const std::string& query) { + const std::string& query, + const Transaction& transaction) { flight_sql_pb::CommandStatementUpdate command; command.set_query(query); + if (transaction.is_valid()) { + command.set_transaction_id(transaction.transaction_id()); + } ARROW_ASSIGN_OR_RAISE(FlightDescriptor descriptor, GetFlightDescriptorForCommand(command)); @@ -119,14 +210,41 @@ arrow::Result FlightSqlClient::ExecuteUpdate(const FlightCallOptions& o ARROW_RETURN_NOT_OK(DoPut(options, descriptor, arrow::schema({}), &writer, &reader)); std::shared_ptr metadata; - ARROW_RETURN_NOT_OK(reader->ReadMetadata(&metadata)); + ARROW_RETURN_NOT_OK(writer->Close()); + + flight_sql_pb::DoPutUpdateResult result; + if (!result.ParseFromArray(metadata->data(), static_cast(metadata->size()))) { + return Status::Invalid("Unable to parse DoPutUpdateResult"); + } + + return result.record_count(); +} + +arrow::Result FlightSqlClient::ExecuteSubstraitUpdate( + const FlightCallOptions& options, const SubstraitPlan& plan, + const Transaction& transaction) { + flight_sql_pb::CommandStatementSubstraitPlan command; + SetPlan(plan, command.mutable_plan()); + if (transaction.is_valid()) { + command.set_transaction_id(transaction.transaction_id()); + } + + ARROW_ASSIGN_OR_RAISE(FlightDescriptor descriptor, + GetFlightDescriptorForCommand(command)); + + std::unique_ptr writer; + std::unique_ptr reader; + + ARROW_RETURN_NOT_OK(DoPut(options, descriptor, arrow::schema({}), &writer, &reader)); - flight_sql_pb::DoPutUpdateResult doPutUpdateResult; + std::shared_ptr metadata; + ARROW_RETURN_NOT_OK(reader->ReadMetadata(&metadata)); + ARROW_RETURN_NOT_OK(writer->Close()); flight_sql_pb::DoPutUpdateResult result; if (!result.ParseFromArray(metadata->data(), static_cast(metadata->size()))) { - return Status::Invalid("Unable to parse DoPutUpdateResult object."); + return Status::Invalid("Unable to parse DoPutUpdateResult"); } return result.record_count(); @@ -357,35 +475,41 @@ arrow::Result> FlightSqlClient::DoGet( } arrow::Result> FlightSqlClient::Prepare( - const FlightCallOptions& options, const std::string& query) { - google::protobuf::Any command; + const FlightCallOptions& options, const std::string& query, + const Transaction& transaction) { flight_sql_pb::ActionCreatePreparedStatementRequest request; request.set_query(query); - command.PackFrom(request); - - Action action; - action.type = "CreatePreparedStatement"; - action.body = Buffer::FromString(command.SerializeAsString()); + if (transaction.is_valid()) { + request.set_transaction_id(transaction.transaction_id()); + } std::unique_ptr results; - + ARROW_ASSIGN_OR_RAISE(auto action, PackAction("CreatePreparedStatement", request)); ARROW_RETURN_NOT_OK(DoAction(options, action, &results)); - ARROW_ASSIGN_OR_RAISE(std::unique_ptr result, results->Next()); - - google::protobuf::Any prepared_result; + return PreparedStatement::ParseResponse(this, std::move(results)); +} - std::shared_ptr message = std::move(result->body); - if (!prepared_result.ParseFromArray(message->data(), - static_cast(message->size()))) { - return Status::Invalid("Unable to parse packed ActionCreatePreparedStatementResult"); +arrow::Result> FlightSqlClient::PrepareSubstrait( + const FlightCallOptions& options, const SubstraitPlan& plan, + const Transaction& transaction) { + flight_sql_pb::ActionCreatePreparedSubstraitPlanRequest request; + SetPlan(plan, request.mutable_plan()); + if (transaction.is_valid()) { + request.set_transaction_id(transaction.transaction_id()); } - flight_sql_pb::ActionCreatePreparedStatementResult prepared_statement_result; + std::unique_ptr results; + ARROW_ASSIGN_OR_RAISE(auto action, PackAction("CreatePreparedSubstraitPlan", request)); + ARROW_RETURN_NOT_OK(DoAction(options, action, &results)); - if (!prepared_result.UnpackTo(&prepared_statement_result)) { - return Status::Invalid("Unable to unpack ActionCreatePreparedStatementResult"); - } + return PreparedStatement::ParseResponse(this, std::move(results)); +} + +arrow::Result> PreparedStatement::ParseResponse( + FlightSqlClient* client, std::unique_ptr results) { + flight_sql_pb::ActionCreatePreparedStatementResult prepared_statement_result; + ARROW_RETURN_NOT_OK(ReadResult(results.get(), &prepared_statement_result)); const std::string& serialized_dataset_schema = prepared_statement_result.dataset_schema(); @@ -407,14 +531,14 @@ arrow::Result> FlightSqlClient::Prepare( } auto handle = prepared_statement_result.prepared_statement_handle(); - return std::make_shared(this, handle, dataset_schema, + return std::make_shared(client, handle, dataset_schema, parameter_schema); } arrow::Result> PreparedStatement::Execute( const FlightCallOptions& options) { if (is_closed_) { - return Status::Invalid("Statement already closed."); + return Status::Invalid("Statement with handle '", handle_, "' already closed"); } flight_sql_pb::CommandPreparedStatementQuery command; @@ -433,6 +557,7 @@ arrow::Result> PreparedStatement::Execute( // Wait for the server to ack the result std::shared_ptr buffer; ARROW_RETURN_NOT_OK(reader->ReadMetadata(&buffer)); + ARROW_RETURN_NOT_OK(writer->Close()); } ARROW_ASSIGN_OR_RAISE(auto flight_info, client_->GetFlightInfo(options, descriptor)); @@ -442,7 +567,7 @@ arrow::Result> PreparedStatement::Execute( arrow::Result PreparedStatement::ExecuteUpdate( const FlightCallOptions& options) { if (is_closed_) { - return Status::Invalid("Statement already closed."); + return Status::Invalid("Statement with handle '", handle_, "' already closed"); } flight_sql_pb::CommandPreparedStatementUpdate command; @@ -496,7 +621,7 @@ std::shared_ptr PreparedStatement::parameter_schema() const { arrow::Result> PreparedStatement::GetSchema( const FlightCallOptions& options) { if (is_closed_) { - return Status::Invalid("Statement already closed"); + return Status::Invalid("Statement with handle '", handle_, "' already closed"); } flight_sql_pb::CommandPreparedStatementQuery command; @@ -508,29 +633,185 @@ arrow::Result> PreparedStatement::GetSchema( Status PreparedStatement::Close(const FlightCallOptions& options) { if (is_closed_) { - return Status::Invalid("Statement already closed."); + return Status::Invalid("Statement with handle '", handle_, "' already closed"); } - google::protobuf::Any command; + flight_sql_pb::ActionClosePreparedStatementRequest request; request.set_prepared_statement_handle(handle_); - command.PackFrom(request); + std::unique_ptr results; + ARROW_ASSIGN_OR_RAISE(auto action, PackAction("ClosePreparedStatement", request)); + ARROW_RETURN_NOT_OK(client_->DoAction(options, action, &results)); + ARROW_RETURN_NOT_OK(DrainResultStream(results.get())); - Action action; - action.type = "ClosePreparedStatement"; - action.body = Buffer::FromString(command.SerializeAsString()); + is_closed_ = true; + return Status::OK(); +} + +::arrow::Result FlightSqlClient::BeginTransaction( + const FlightCallOptions& options) { + flight_sql_pb::ActionBeginTransactionRequest request; std::unique_ptr results; + ARROW_ASSIGN_OR_RAISE(auto action, PackAction("BeginTransaction", request)); + ARROW_RETURN_NOT_OK(DoAction(options, action, &results)); - ARROW_RETURN_NOT_OK(client_->DoAction(options, action, &results)); + flight_sql_pb::ActionBeginTransactionResult transaction; + ARROW_RETURN_NOT_OK(ReadResult(results.get(), &transaction)); + if (transaction.transaction_id().empty()) { + return Status::Invalid("Server returned an empty transaction ID"); + } - is_closed_ = true; + ARROW_RETURN_NOT_OK(DrainResultStream(results.get())); + return Transaction(transaction.transaction_id()); +} + +::arrow::Result FlightSqlClient::BeginSavepoint( + const FlightCallOptions& options, const Transaction& transaction, + const std::string& name) { + flight_sql_pb::ActionBeginSavepointRequest request; + + if (!transaction.is_valid()) { + return Status::Invalid("Must provide an active transaction"); + } + request.set_transaction_id(transaction.transaction_id()); + request.set_name(name); + + std::unique_ptr results; + ARROW_ASSIGN_OR_RAISE(auto action, PackAction("BeginSavepoint", request)); + ARROW_RETURN_NOT_OK(DoAction(options, action, &results)); + + flight_sql_pb::ActionBeginSavepointResult savepoint; + ARROW_RETURN_NOT_OK(ReadResult(results.get(), &savepoint)); + if (savepoint.savepoint_id().empty()) { + return Status::Invalid("Server returned an empty savepoint ID"); + } + + ARROW_RETURN_NOT_OK(DrainResultStream(results.get())); + return Savepoint(savepoint.savepoint_id()); +} + +Status FlightSqlClient::Commit(const FlightCallOptions& options, + const Transaction& transaction) { + flight_sql_pb::ActionEndTransactionRequest request; + + if (!transaction.is_valid()) { + return Status::Invalid("Must provide an active transaction"); + } + request.set_transaction_id(transaction.transaction_id()); + request.set_action(flight_sql_pb::ActionEndTransactionRequest::END_TRANSACTION_COMMIT); + + std::unique_ptr results; + ARROW_ASSIGN_OR_RAISE(auto action, PackAction("EndTransaction", request)); + ARROW_RETURN_NOT_OK(DoAction(options, action, &results)); + + ARROW_RETURN_NOT_OK(DrainResultStream(results.get())); + return Status::OK(); +} +Status FlightSqlClient::Release(const FlightCallOptions& options, + const Savepoint& savepoint) { + flight_sql_pb::ActionEndSavepointRequest request; + + if (!savepoint.is_valid()) { + return Status::Invalid("Must provide an active savepoint"); + } + request.set_savepoint_id(savepoint.savepoint_id()); + request.set_action(flight_sql_pb::ActionEndSavepointRequest::END_SAVEPOINT_RELEASE); + + std::unique_ptr results; + ARROW_ASSIGN_OR_RAISE(auto action, PackAction("EndSavepoint", request)); + ARROW_RETURN_NOT_OK(DoAction(options, action, &results)); + + ARROW_RETURN_NOT_OK(DrainResultStream(results.get())); + return Status::OK(); +} + +Status FlightSqlClient::Rollback(const FlightCallOptions& options, + const Transaction& transaction) { + flight_sql_pb::ActionEndTransactionRequest request; + + if (!transaction.is_valid()) { + return Status::Invalid("Must provide an active transaction"); + } + request.set_transaction_id(transaction.transaction_id()); + request.set_action( + flight_sql_pb::ActionEndTransactionRequest::END_TRANSACTION_ROLLBACK); + + std::unique_ptr results; + ARROW_ASSIGN_OR_RAISE(auto action, PackAction("EndTransaction", request)); + ARROW_RETURN_NOT_OK(DoAction(options, action, &results)); + + ARROW_RETURN_NOT_OK(DrainResultStream(results.get())); + return Status::OK(); +} + +Status FlightSqlClient::Rollback(const FlightCallOptions& options, + const Savepoint& savepoint) { + flight_sql_pb::ActionEndSavepointRequest request; + + if (!savepoint.is_valid()) { + return Status::Invalid("Must provide an active savepoint"); + } + request.set_savepoint_id(savepoint.savepoint_id()); + request.set_action(flight_sql_pb::ActionEndSavepointRequest::END_SAVEPOINT_ROLLBACK); + + std::unique_ptr results; + ARROW_ASSIGN_OR_RAISE(auto action, PackAction("EndSavepoint", request)); + ARROW_RETURN_NOT_OK(DoAction(options, action, &results)); + + ARROW_RETURN_NOT_OK(DrainResultStream(results.get())); return Status::OK(); } +::arrow::Result FlightSqlClient::CancelQuery( + const FlightCallOptions& options, const FlightInfo& info) { + flight_sql_pb::ActionCancelQueryRequest request; + ARROW_ASSIGN_OR_RAISE(auto serialized_info, info.SerializeToString()); + request.set_info(std::move(serialized_info)); + + std::unique_ptr results; + ARROW_ASSIGN_OR_RAISE(auto action, PackAction("CancelQuery", request)); + ARROW_RETURN_NOT_OK(DoAction(options, action, &results)); + + flight_sql_pb::ActionCancelQueryResult result; + ARROW_RETURN_NOT_OK(ReadResult(results.get(), &result)); + ARROW_RETURN_NOT_OK(DrainResultStream(results.get())); + switch (result.result()) { + case flight_sql_pb::ActionCancelQueryResult::CANCEL_RESULT_UNSPECIFIED: + return CancelResult::kUnspecified; + case flight_sql_pb::ActionCancelQueryResult::CANCEL_RESULT_CANCELLED: + return CancelResult::kCancelled; + case flight_sql_pb::ActionCancelQueryResult::CANCEL_RESULT_CANCELLING: + return CancelResult::kCancelling; + case flight_sql_pb::ActionCancelQueryResult::CANCEL_RESULT_NOT_CANCELLABLE: + return CancelResult::kNotCancellable; + default: + break; + } + return Status::IOError("Server returned unknown result ", result.result()); +} + Status FlightSqlClient::Close() { return impl_->Close(); } +std::ostream& operator<<(std::ostream& os, CancelResult result) { + switch (result) { + case CancelResult::kUnspecified: + os << "CancelResult::kUnspecified"; + break; + case CancelResult::kCancelled: + os << "CancelResult::kCancelled"; + break; + case CancelResult::kCancelling: + os << "CancelResult::kCancelling"; + break; + case CancelResult::kNotCancellable: + os << "CancelResult::kNotCancellable"; + break; + } + return os; +} + } // namespace sql } // namespace flight } // namespace arrow diff --git a/cpp/src/arrow/flight/sql/client.h b/cpp/src/arrow/flight/sql/client.h index 26315e0d234fe..db168847ed66b 100644 --- a/cpp/src/arrow/flight/sql/client.h +++ b/cpp/src/arrow/flight/sql/client.h @@ -17,6 +17,7 @@ #pragma once +#include #include #include @@ -32,6 +33,13 @@ namespace flight { namespace sql { class PreparedStatement; +class Transaction; +class Savepoint; + +/// \brief A default transaction to use when the default behavior +/// (auto-commit) is desired. +ARROW_FLIGHT_SQL_EXPORT +const Transaction& no_transaction(); /// \brief Flight client with Flight SQL semantics. /// @@ -47,23 +55,51 @@ class ARROW_FLIGHT_SQL_EXPORT FlightSqlClient { virtual ~FlightSqlClient() = default; - /// \brief Execute a query on the server. + /// \brief Execute a SQL query on the server. + /// \param[in] options RPC-layer hints for this call. + /// \param[in] query The UTF8-encoded SQL query to be executed. + /// \param[in] transaction A transaction to associate this query with. + /// \return The FlightInfo describing where to access the dataset. + arrow::Result> Execute( + const FlightCallOptions& options, const std::string& query, + const Transaction& transaction = no_transaction()); + + /// \brief Execute a Substrait plan that returns a result set on the server. /// \param[in] options RPC-layer hints for this call. - /// \param[in] query The query to be executed in the UTF-8 format. + /// \param[in] plan The plan to be executed. + /// \param[in] transaction A transaction to associate this query with. /// \return The FlightInfo describing where to access the dataset. - arrow::Result> Execute(const FlightCallOptions& options, - const std::string& query); + arrow::Result> ExecuteSubstrait( + const FlightCallOptions& options, const SubstraitPlan& plan, + const Transaction& transaction = no_transaction()); /// \brief Get the result set schema from the server. arrow::Result> GetExecuteSchema( - const FlightCallOptions& options, const std::string& query); + const FlightCallOptions& options, const std::string& query, + const Transaction& transaction = no_transaction()); + + /// \brief Get the result set schema from the server. + arrow::Result> GetExecuteSubstraitSchema( + const FlightCallOptions& options, const SubstraitPlan& plan, + const Transaction& transaction = no_transaction()); /// \brief Execute an update query on the server. /// \param[in] options RPC-layer hints for this call. - /// \param[in] query The query to be executed in the UTF-8 format. + /// \param[in] query The UTF8-encoded SQL query to be executed. + /// \param[in] transaction A transaction to associate this query with. /// \return The quantity of rows affected by the operation. arrow::Result ExecuteUpdate(const FlightCallOptions& options, - const std::string& query); + const std::string& query, + const Transaction& transaction = no_transaction()); + + /// \brief Execute a Substrait plan that does not return a result set on the server. + /// \param[in] options RPC-layer hints for this call. + /// \param[in] plan The plan to be executed. + /// \param[in] transaction A transaction to associate this query with. + /// \return The FlightInfo describing where to access the dataset. + arrow::Result ExecuteSubstraitUpdate( + const FlightCallOptions& options, const SubstraitPlan& plan, + const Transaction& transaction = no_transaction()); /// \brief Request a list of catalogs. /// \param[in] options RPC-layer hints for this call. @@ -215,9 +251,20 @@ class ARROW_FLIGHT_SQL_EXPORT FlightSqlClient { /// \brief Create a prepared statement object. /// \param[in] options RPC-layer hints for this call. /// \param[in] query The query that will be executed. + /// \param[in] transaction A transaction to associate this query with. /// \return The created prepared statement. arrow::Result> Prepare( - const FlightCallOptions& options, const std::string& query); + const FlightCallOptions& options, const std::string& query, + const Transaction& transaction = no_transaction()); + + /// \brief Create a prepared statement object. + /// \param[in] options RPC-layer hints for this call. + /// \param[in] plan The Substrait plan that will be executed. + /// \param[in] transaction A transaction to associate this query with. + /// \return The created prepared statement. + arrow::Result> PrepareSubstrait( + const FlightCallOptions& options, const SubstraitPlan& plan, + const Transaction& transaction = no_transaction()); /// \brief Call the underlying Flight client's GetFlightInfo. virtual arrow::Result> GetFlightInfo( @@ -231,6 +278,58 @@ class ARROW_FLIGHT_SQL_EXPORT FlightSqlClient { return impl_->GetSchema(options, descriptor); } + /// \brief Begin a new transaction. + ::arrow::Result BeginTransaction(const FlightCallOptions& options); + + /// \brief Create a new savepoint within a transaction. + /// \param[in] options RPC-layer hints for this call. + /// \param[in] transaction The parent transaction. + /// \param[in] name A friendly name for the savepoint. + ::arrow::Result BeginSavepoint(const FlightCallOptions& options, + const Transaction& transaction, + const std::string& name); + + /// \brief Commit a transaction. + /// + /// After this, the transaction and all associated savepoints will + /// be invalidated. + /// + /// \param[in] options RPC-layer hints for this call. + /// \param[in] transaction The transaction. + Status Commit(const FlightCallOptions& options, const Transaction& transaction); + + /// \brief Release a savepoint. + /// + /// After this, the savepoint (and all savepoints created after it) will be invalidated. + /// + /// \param[in] options RPC-layer hints for this call. + /// \param[in] savepoint The savepoint. + Status Release(const FlightCallOptions& options, const Savepoint& savepoint); + + /// \brief Rollback a transaction. + /// + /// After this, the transaction and all associated savepoints will be invalidated. + /// + /// \param[in] options RPC-layer hints for this call. + /// \param[in] transaction The transaction. + Status Rollback(const FlightCallOptions& options, const Transaction& transaction); + + /// \brief Rollback a savepoint. + /// + /// After this, the savepoint will still be valid, but all + /// savepoints created after it will be invalidated. + /// + /// \param[in] options RPC-layer hints for this call. + /// \param[in] savepoint The savepoint. + Status Rollback(const FlightCallOptions& options, const Savepoint& savepoint); + + /// \brief Explicitly cancel a query. + /// + /// \param[in] options RPC-layer hints for this call. + /// \param[in] info The FlightInfo of the query to cancel. + ::arrow::Result CancelQuery(const FlightCallOptions& options, + const FlightInfo& info); + /// \brief Explicitly shut down and clean up the client. Status Close(); @@ -278,6 +377,10 @@ class ARROW_FLIGHT_SQL_EXPORT PreparedStatement { /// errors can't be caught. ~PreparedStatement(); + /// \brief Create a PreparedStatement by parsing the server response. + static arrow::Result> ParseResponse( + FlightSqlClient* client, std::unique_ptr results); + /// \brief Executes the prepared statement query on the server. /// \return A FlightInfo object representing the stream(s) to fetch. arrow::Result> Execute( @@ -295,8 +398,8 @@ class ARROW_FLIGHT_SQL_EXPORT PreparedStatement { /// \return The ResultSet schema from the query. std::shared_ptr dataset_schema() const; - /// \brief Set a RecordBatch that contains the parameters that will be bind. - /// \param parameter_binding The parameters that will be bind. + /// \brief Set a RecordBatch that contains the parameters that will be bound. + /// \param parameter_binding The parameters that will be bound. /// \return Status. Status SetParameters(std::shared_ptr parameter_binding); @@ -305,9 +408,9 @@ class ARROW_FLIGHT_SQL_EXPORT PreparedStatement { arrow::Result> GetSchema( const FlightCallOptions& options = {}); - /// \brief Close the prepared statement, so that this PreparedStatement can not used - /// anymore and server can free up any resources. - /// \return Status. + /// \brief Close the prepared statement so the server can free up any resources. + /// + /// After this, the prepared statement may not be used anymore. Status Close(const FlightCallOptions& options = {}); /// \brief Check if the prepared statement is closed. @@ -323,6 +426,29 @@ class ARROW_FLIGHT_SQL_EXPORT PreparedStatement { bool is_closed_; }; +/// \brief A handle for a server-side savepoint. +class ARROW_FLIGHT_SQL_EXPORT Savepoint { + public: + explicit Savepoint(std::string savepoint_id) : savepoint_id_(std::move(savepoint_id)) {} + const std::string& savepoint_id() const { return savepoint_id_; } + bool is_valid() const { return !savepoint_id_.empty(); } + + private: + std::string savepoint_id_; +}; + +/// \brief A handle for a server-side transaction. +class ARROW_FLIGHT_SQL_EXPORT Transaction { + public: + explicit Transaction(std::string transaction_id) + : transaction_id_(std::move(transaction_id)) {} + const std::string& transaction_id() const { return transaction_id_; } + bool is_valid() const { return !transaction_id_.empty(); } + + private: + std::string transaction_id_; +}; + } // namespace sql } // namespace flight } // namespace arrow diff --git a/cpp/src/arrow/flight/sql/client_test.cc b/cpp/src/arrow/flight/sql/client_test.cc index acd078a847792..984bf45481682 100644 --- a/cpp/src/arrow/flight/sql/client_test.cc +++ b/cpp/src/arrow/flight/sql/client_test.cc @@ -410,6 +410,7 @@ TEST_F(TestFlightSqlClient, TestExecuteUpdate) { std::unique_ptr* writer, std::unique_ptr* reader) { reader->reset(new FlightMetadataReaderMock(&buffer_ptr)); + writer->reset(new FlightStreamWriterMock()); return Status::OK(); }); diff --git a/cpp/src/arrow/flight/sql/column_metadata.cc b/cpp/src/arrow/flight/sql/column_metadata.cc index 30ef240105c87..adfe81f17300b 100644 --- a/cpp/src/arrow/flight/sql/column_metadata.cc +++ b/cpp/src/arrow/flight/sql/column_metadata.cc @@ -118,25 +118,25 @@ const std::shared_ptr& ColumnMetadata::metadata_m } ColumnMetadata::ColumnMetadataBuilder& ColumnMetadata::ColumnMetadataBuilder::CatalogName( - std::string& catalog_name) { + const std::string& catalog_name) { metadata_map_->Append(ColumnMetadata::kCatalogName, catalog_name); return *this; } ColumnMetadata::ColumnMetadataBuilder& ColumnMetadata::ColumnMetadataBuilder::SchemaName( - std::string& schema_name) { + const std::string& schema_name) { metadata_map_->Append(ColumnMetadata::kSchemaName, schema_name); return *this; } ColumnMetadata::ColumnMetadataBuilder& ColumnMetadata::ColumnMetadataBuilder::TableName( - std::string& table_name) { + const std::string& table_name) { metadata_map_->Append(ColumnMetadata::kTableName, table_name); return *this; } ColumnMetadata::ColumnMetadataBuilder& ColumnMetadata::ColumnMetadataBuilder::TypeName( - std::string& type_name) { + const std::string& type_name) { metadata_map_->Append(ColumnMetadata::kTypeName, type_name); return *this; } diff --git a/cpp/src/arrow/flight/sql/column_metadata.h b/cpp/src/arrow/flight/sql/column_metadata.h index 15b139ec5806a..0eb53f3e0bbc6 100644 --- a/cpp/src/arrow/flight/sql/column_metadata.h +++ b/cpp/src/arrow/flight/sql/column_metadata.h @@ -122,22 +122,22 @@ class ARROW_FLIGHT_SQL_EXPORT ColumnMetadata { /// \brief Set the catalog name in the KeyValueMetadata object. /// \param[in] catalog_name The catalog name. /// \return A ColumnMetadataBuilder. - ColumnMetadataBuilder& CatalogName(std::string& catalog_name); + ColumnMetadataBuilder& CatalogName(const std::string& catalog_name); /// \brief Set the schema_name in the KeyValueMetadata object. /// \param[in] schema_name The schema_name. /// \return A ColumnMetadataBuilder. - ColumnMetadataBuilder& SchemaName(std::string& schema_name); + ColumnMetadataBuilder& SchemaName(const std::string& schema_name); /// \brief Set the table name in the KeyValueMetadata object. /// \param[in] table_name The table name. /// \return A ColumnMetadataBuilder. - ColumnMetadataBuilder& TableName(std::string& table_name); + ColumnMetadataBuilder& TableName(const std::string& table_name); /// \brief Set the type name in the KeyValueMetadata object. /// \param[in] type_name The type name. /// \return A ColumnMetadataBuilder. - ColumnMetadataBuilder& TypeName(std::string& type_name); + ColumnMetadataBuilder& TypeName(const std::string& type_name); /// \brief Set the precision in the KeyValueMetadata object. /// \param[in] precision The precision. diff --git a/cpp/src/arrow/flight/sql/example/acero_main.cc b/cpp/src/arrow/flight/sql/example/acero_main.cc new file mode 100644 index 0000000000000..111bebcbf0f03 --- /dev/null +++ b/cpp/src/arrow/flight/sql/example/acero_main.cc @@ -0,0 +1,70 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// Example Flight SQL server backed by Acero. + +#include + +#include +#include +#include +#include + +#include + +#include "arrow/flight/sql/example/acero_server.h" +#include "arrow/status.h" +#include "arrow/util/logging.h" + +namespace flight = arrow::flight; +namespace sql = arrow::flight::sql; + +DEFINE_string(location, "grpc://localhost:12345", "Location to listen on"); + +arrow::Status RunMain(const std::string& location_str) { + ARROW_ASSIGN_OR_RAISE(flight::Location location, flight::Location::Parse(location_str)); + flight::FlightServerOptions options(location); + + std::unique_ptr server; + ARROW_ASSIGN_OR_RAISE(server, sql::acero_example::MakeAceroServer()); + ARROW_RETURN_NOT_OK(server->Init(options)); + + ARROW_RETURN_NOT_OK(server->SetShutdownOnSignals({SIGTERM})); + + ARROW_LOG(INFO) << "Listening on " << location.ToString(); + + ARROW_RETURN_NOT_OK(server->Serve()); + return arrow::Status::OK(); +} + +int main(int argc, char** argv) { + gflags::ParseCommandLineFlags(&argc, &argv, true); + + arrow::util::ArrowLog::StartArrowLog("acero-flight-sql-server", + arrow::util::ArrowLogLevel::ARROW_INFO); + arrow::util::ArrowLog::InstallFailureSignalHandler(); + + arrow::Status st = RunMain(FLAGS_location); + + arrow::util::ArrowLog::ShutDownArrowLog(); + + if (!st.ok()) { + std::cerr << st << std::endl; + return EXIT_FAILURE; + } + return EXIT_SUCCESS; +} diff --git a/cpp/src/arrow/flight/sql/example/acero_server.cc b/cpp/src/arrow/flight/sql/example/acero_server.cc new file mode 100644 index 0000000000000..ce1483cb8c3f4 --- /dev/null +++ b/cpp/src/arrow/flight/sql/example/acero_server.cc @@ -0,0 +1,306 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "arrow/flight/sql/example/acero_server.h" + +#include +#include +#include +#include + +#include "arrow/engine/substrait/serde.h" +#include "arrow/flight/sql/types.h" +#include "arrow/type.h" +#include "arrow/util/logging.h" + +namespace arrow { +namespace flight { +namespace sql { +namespace acero_example { + +namespace { +/// \brief A SinkNodeConsumer that saves the schema as given to it by +/// the ExecPlan. Used to retrieve the schema of a Substrait plan to +/// fulfill the Flight SQL API contract. +class GetSchemaSinkNodeConsumer : public compute::SinkNodeConsumer { + public: + Status Init(const std::shared_ptr& schema, + compute::BackpressureControl*) override { + schema_ = schema; + return Status::OK(); + } + Status Consume(compute::ExecBatch exec_batch) override { return Status::OK(); } + Future<> Finish() override { return Status::OK(); } + + const std::shared_ptr& schema() const { return schema_; } + + private: + std::shared_ptr schema_; +}; + +/// \brief A SinkNodeConsumer that internally saves batches into a +/// queue, so that it can be read from a RecordBatchReader. In other +/// words, this bridges a push-based interface (ExecPlan) to a +/// pull-based interface (RecordBatchReader). +class QueuingSinkNodeConsumer : public compute::SinkNodeConsumer { + public: + QueuingSinkNodeConsumer() : schema_(nullptr), finished_(false) {} + + Status Init(const std::shared_ptr& schema, + compute::BackpressureControl*) override { + schema_ = schema; + return Status::OK(); + } + + Status Consume(compute::ExecBatch exec_batch) override { + { + std::lock_guard guard(mutex_); + batches_.push_back(std::move(exec_batch)); + batches_added_.notify_all(); + } + + return Status::OK(); + } + + Future<> Finish() override { + { + std::lock_guard guard(mutex_); + finished_ = true; + batches_added_.notify_all(); + } + + return Status::OK(); + } + + const std::shared_ptr& schema() const { return schema_; } + + arrow::Result> Next() { + compute::ExecBatch batch; + { + std::unique_lock guard(mutex_); + batches_added_.wait(guard, [this] { return !batches_.empty() || finished_; }); + + if (finished_ && batches_.empty()) { + return nullptr; + } + batch = std::move(batches_.front()); + batches_.pop_front(); + } + + return batch.ToRecordBatch(schema_); + } + + private: + std::mutex mutex_; + std::condition_variable batches_added_; + std::deque batches_; + std::shared_ptr schema_; + bool finished_; +}; + +/// \brief A RecordBatchReader that pulls from the +/// QueuingSinkNodeConsumer above, blocking until results are +/// available as necessary. +class ConsumerBasedRecordBatchReader : public RecordBatchReader { + public: + explicit ConsumerBasedRecordBatchReader( + std::shared_ptr plan, + std::shared_ptr consumer) + : plan_(std::move(plan)), consumer_(std::move(consumer)) {} + + std::shared_ptr schema() const override { return consumer_->schema(); } + + Status ReadNext(std::shared_ptr* batch) override { + return consumer_->Next().Value(batch); + } + + // TODO(ARROW-17242): FlightDataStream needs to call Close() + Status Close() override { return plan_->finished().status(); } + + private: + std::shared_ptr plan_; + std::shared_ptr consumer_; +}; + +/// \brief An implementation of a Flight SQL service backed by Acero. +class AceroFlightSqlServer : public FlightSqlServerBase { + public: + AceroFlightSqlServer() { + RegisterSqlInfo(SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_SUBSTRAIT, + SqlInfoResult(true)); + RegisterSqlInfo(SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_SUBSTRAIT_MIN_VERSION, + SqlInfoResult(std::string("0.6.0"))); + RegisterSqlInfo(SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_SUBSTRAIT_MAX_VERSION, + SqlInfoResult(std::string("0.6.0"))); + RegisterSqlInfo( + SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_TRANSACTION, + SqlInfoResult( + SqlInfoOptions::SqlSupportedTransaction::SQL_SUPPORTED_TRANSACTION_NONE)); + RegisterSqlInfo(SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_CANCEL, + SqlInfoResult(false)); + } + + arrow::Result> GetFlightInfoSubstraitPlan( + const ServerCallContext& context, const StatementSubstraitPlan& command, + const FlightDescriptor& descriptor) override { + if (!command.transaction_id.empty()) { + return Status::NotImplemented("Transactions are unsupported"); + } + + ARROW_ASSIGN_OR_RAISE(std::shared_ptr output_schema, + GetPlanSchema(command.plan.plan)); + + ARROW_LOG(INFO) << "GetFlightInfoSubstraitPlan: preparing plan with output schema " + << *output_schema; + + return MakeFlightInfo(command.plan.plan, descriptor, *output_schema); + } + + arrow::Result> GetFlightInfoPreparedStatement( + const ServerCallContext& context, const PreparedStatementQuery& command, + const FlightDescriptor& descriptor) override { + std::shared_ptr plan; + { + std::lock_guard guard(mutex_); + auto it = prepared_.find(command.prepared_statement_handle); + if (it == prepared_.end()) { + return Status::KeyError("Prepared statement not found"); + } + plan = it->second; + } + + return MakeFlightInfo(plan->ToString(), descriptor, Schema({})); + } + + arrow::Result> DoGetStatement( + const ServerCallContext& context, const StatementQueryTicket& command) override { + // GetFlightInfoSubstraitPlan encodes the plan into the ticket + std::shared_ptr serialized_plan = + Buffer::FromString(command.statement_handle); + std::shared_ptr consumer = + std::make_shared(); + ARROW_ASSIGN_OR_RAISE(std::shared_ptr plan, + engine::DeserializePlan(*serialized_plan, consumer)); + + ARROW_LOG(INFO) << "DoGetStatement: executing plan " << plan->ToString(); + + ARROW_RETURN_NOT_OK(plan->StartProducing()); + + auto reader = std::make_shared(std::move(plan), + std::move(consumer)); + return std::unique_ptr(new RecordBatchStream(reader)); + } + + arrow::Result DoPutCommandSubstraitPlan( + const ServerCallContext& context, const StatementSubstraitPlan& command) override { + return Status::NotImplemented("Updates are unsupported"); + } + + Status DoPutPreparedStatementQuery(const ServerCallContext& context, + const PreparedStatementQuery& command, + FlightMessageReader* reader, + FlightMetadataWriter* writer) override { + return Status::NotImplemented("NYI"); + } + + arrow::Result DoPutPreparedStatementUpdate( + const ServerCallContext& context, const PreparedStatementUpdate& command, + FlightMessageReader* reader) override { + return Status::NotImplemented("Updates are unsupported"); + } + + arrow::Result CreatePreparedSubstraitPlan( + const ServerCallContext& context, + const ActionCreatePreparedSubstraitPlanRequest& request) override { + if (!request.transaction_id.empty()) { + return Status::NotImplemented("Transactions are unsupported"); + } + // There's not any real point to precompiling the plan, since the + // consumer has to be provided here. So this is effectively the + // same as a non-prepared plan. + ARROW_ASSIGN_OR_RAISE(std::shared_ptr schema, + GetPlanSchema(request.plan.plan)); + + std::string handle; + { + std::lock_guard guard(mutex_); + handle = std::to_string(counter_++); + prepared_[handle] = Buffer::FromString(request.plan.plan); + } + + return ActionCreatePreparedStatementResult{ + /*dataset_schema=*/std::move(schema), + /*parameter_schema=*/nullptr, + handle, + }; + } + + Status ClosePreparedStatement( + const ServerCallContext& context, + const ActionClosePreparedStatementRequest& request) override { + std::lock_guard guard(mutex_); + prepared_.erase(request.prepared_statement_handle); + return Status::OK(); + } + + private: + arrow::Result> GetPlanSchema( + const std::string& serialized_plan) { + std::shared_ptr plan_buf = Buffer::FromString(serialized_plan); + std::shared_ptr consumer = + std::make_shared(); + ARROW_ASSIGN_OR_RAISE(std::shared_ptr plan, + engine::DeserializePlan(*plan_buf, consumer)); + std::shared_ptr output_schema; + for (compute::ExecNode* sink : plan->sinks()) { + // Force SinkNodeConsumer::Init to be called + ARROW_RETURN_NOT_OK(sink->StartProducing()); + output_schema = consumer->schema(); + break; + } + if (!output_schema) { + return Status::Invalid("Could not infer output schema"); + } + return output_schema; + } + + arrow::Result> MakeFlightInfo( + const std::string& plan, const FlightDescriptor& descriptor, const Schema& schema) { + ARROW_ASSIGN_OR_RAISE(auto ticket, CreateStatementQueryTicket(plan)); + std::vector endpoints{ + FlightEndpoint{Ticket{std::move(ticket)}, /*locations=*/{}}}; + ARROW_ASSIGN_OR_RAISE(auto info, + FlightInfo::Make(schema, descriptor, std::move(endpoints), + /*total_records=*/-1, /*total_bytes=*/-1)); + return std::make_unique(std::move(info)); + } + + std::mutex mutex_; + std::unordered_map> prepared_; + int64_t counter_; +}; + +} // namespace + +arrow::Result> MakeAceroServer() { + return std::unique_ptr(new AceroFlightSqlServer()); +} + +} // namespace acero_example +} // namespace sql +} // namespace flight +} // namespace arrow diff --git a/cpp/src/arrow/flight/sql/example/acero_server.h b/cpp/src/arrow/flight/sql/example/acero_server.h new file mode 100644 index 0000000000000..2e82fd3d3b6a8 --- /dev/null +++ b/cpp/src/arrow/flight/sql/example/acero_server.h @@ -0,0 +1,37 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include + +#include "arrow/flight/sql/server.h" +#include "arrow/flight/sql/visibility.h" +#include "arrow/result.h" + +namespace arrow { +namespace flight { +namespace sql { +namespace acero_example { + +/// \brief Make a Flight SQL server backed by the Acero query engine. +arrow::Result> MakeAceroServer(); + +} // namespace acero_example +} // namespace sql +} // namespace flight +} // namespace arrow diff --git a/cpp/src/arrow/flight/sql/example/sqlite_server.cc b/cpp/src/arrow/flight/sql/example/sqlite_server.cc index 35fa05468ba27..0d0a7c1ea0e01 100644 --- a/cpp/src/arrow/flight/sql/example/sqlite_server.cc +++ b/cpp/src/arrow/flight/sql/example/sqlite_server.cc @@ -20,11 +20,12 @@ #include #include -#include +#include #include #include +#include +#include -#include "arrow/api.h" #include "arrow/flight/sql/example/sqlite_sql_info.h" #include "arrow/flight/sql/example/sqlite_statement.h" #include "arrow/flight/sql/example/sqlite_statement_batch_reader.h" @@ -39,18 +40,6 @@ namespace example { namespace { -/// \brief Gets a SqliteStatement by given handle -arrow::Result> GetStatementByHandle( - const std::map>& prepared_statements, - const std::string& handle) { - auto search = prepared_statements.find(handle); - if (search == prepared_statements.end()) { - return Status::Invalid("Prepared statement not found"); - } - - return search->second; -} - std::string PrepareQueryForGetTables(const GetTables& command) { std::stringstream table_query; @@ -237,23 +226,83 @@ int32_t GetSqlTypeFromTypeName(const char* sqlite_type) { } class SQLiteFlightSqlServer::Impl { + private: sqlite3* db_; - std::map> prepared_statements_; + const std::string db_uri_; + std::mutex mutex_; + std::unordered_map> prepared_statements_; + std::unordered_map open_transactions_; std::default_random_engine gen_; + arrow::Result> GetStatementByHandle( + const std::string& handle) { + std::lock_guard guard(mutex_); + auto search = prepared_statements_.find(handle); + if (search == prepared_statements_.end()) { + return Status::KeyError("Prepared statement not found"); + } + return search->second; + } + + arrow::Result GetConnection(const std::string& transaction_id) { + if (transaction_id.empty()) return db_; + + std::lock_guard guard(mutex_); + auto it = open_transactions_.find(transaction_id); + if (it == open_transactions_.end()) { + return Status::KeyError("Unknown transaction ID: ", transaction_id); + } + return it->second; + } + + // Create a Ticket that combines a query and a transaction ID. + arrow::Result EncodeTransactionQuery(const std::string& query, + const std::string& transaction_id) { + std::string transaction_query = transaction_id; + transaction_query += ':'; + transaction_query += query; + ARROW_ASSIGN_OR_RAISE(auto ticket_string, + CreateStatementQueryTicket(transaction_query)); + return Ticket{std::move(ticket_string)}; + } + + arrow::Result> DecodeTransactionQuery( + const std::string& ticket) { + auto divider = ticket.find(':'); + if (divider == std::string::npos) { + return Status::Invalid("Malformed ticket"); + } + std::string transaction_id = ticket.substr(0, divider); + std::string query = ticket.substr(divider + 1); + return std::make_pair(std::move(query), std::move(transaction_id)); + } + public: - explicit Impl(sqlite3* db) : db_(db) {} + explicit Impl(sqlite3* db, std::string uri) : db_(db), db_uri_(std::move(uri)) {} - ~Impl() { sqlite3_close(db_); } + ~Impl() { + sqlite3_close(db_); + for (const auto& pair : open_transactions_) { + sqlite3_close(pair.second); + } + } std::string GenerateRandomString() { uint32_t length = 16; // MSVC doesn't support char types here std::uniform_int_distribution dist(static_cast('0'), - static_cast('z')); + static_cast('Z')); std::string ret(length, 0); - auto get_random_char = [&]() { return static_cast(dist(gen_)); }; + // Don't generate symbols to simplify parsing in DecodeTransactionQuery + auto get_random_char = [&]() { + char res; + while (true) { + res = static_cast(dist(gen_)); + if (res <= '9' || res >= 'A') break; + } + return res; + }; std::generate_n(ret.begin(), length, get_random_char); return ret; } @@ -262,13 +311,12 @@ class SQLiteFlightSqlServer::Impl { const ServerCallContext& context, const StatementQuery& command, const FlightDescriptor& descriptor) { const std::string& query = command.query; - - ARROW_ASSIGN_OR_RAISE(auto statement, SqliteStatement::Create(db_, query)); - + ARROW_ASSIGN_OR_RAISE(auto db, GetConnection(command.transaction_id)); + ARROW_ASSIGN_OR_RAISE(auto statement, SqliteStatement::Create(db, query)); ARROW_ASSIGN_OR_RAISE(auto schema, statement->GetSchema()); - - ARROW_ASSIGN_OR_RAISE(auto ticket_string, CreateStatementQueryTicket(query)); - std::vector endpoints{FlightEndpoint{{ticket_string}, {}}}; + ARROW_ASSIGN_OR_RAISE(auto ticket, + EncodeTransactionQuery(query, command.transaction_id)); + std::vector endpoints{FlightEndpoint{std::move(ticket), {}}}; ARROW_ASSIGN_OR_RAISE(auto result, FlightInfo::Make(*schema, descriptor, endpoints, -1, -1)) @@ -277,10 +325,13 @@ class SQLiteFlightSqlServer::Impl { arrow::Result> DoGetStatement( const ServerCallContext& context, const StatementQueryTicket& command) { - const std::string& sql = command.statement_handle; + ARROW_ASSIGN_OR_RAISE(auto pair, DecodeTransactionQuery(command.statement_handle)); + const std::string& sql = pair.first; + const std::string transaction_id = pair.second; + ARROW_ASSIGN_OR_RAISE(auto db, GetConnection(transaction_id)); std::shared_ptr statement; - ARROW_ASSIGN_OR_RAISE(statement, SqliteStatement::Create(db_, sql)); + ARROW_ASSIGN_OR_RAISE(statement, SqliteStatement::Create(db, sql)); std::shared_ptr reader; ARROW_ASSIGN_OR_RAISE(reader, SqliteStatementBatchReader::Create(statement)); @@ -375,15 +426,15 @@ class SQLiteFlightSqlServer::Impl { arrow::Result DoPutCommandStatementUpdate(const ServerCallContext& context, const StatementUpdate& command) { const std::string& sql = command.query; - - ARROW_ASSIGN_OR_RAISE(auto statement, SqliteStatement::Create(db_, sql)); - + ARROW_ASSIGN_OR_RAISE(auto db, GetConnection(command.transaction_id)); + ARROW_ASSIGN_OR_RAISE(auto statement, SqliteStatement::Create(db, sql)); return statement->ExecuteUpdate(); } arrow::Result CreatePreparedStatement( const ServerCallContext& context, const ActionCreatePreparedStatementRequest& request) { + std::lock_guard guard(mutex_); std::shared_ptr statement; ARROW_ASSIGN_OR_RAISE(statement, SqliteStatement::Create(db_, request.query)); std::string handle = GenerateRandomString(); @@ -419,6 +470,7 @@ class SQLiteFlightSqlServer::Impl { Status ClosePreparedStatement(const ServerCallContext& context, const ActionClosePreparedStatementRequest& request) { + std::lock_guard guard(mutex_); const std::string& prepared_statement_handle = request.prepared_statement_handle; auto search = prepared_statements_.find(prepared_statement_handle); @@ -434,6 +486,7 @@ class SQLiteFlightSqlServer::Impl { arrow::Result> GetFlightInfoPreparedStatement( const ServerCallContext& context, const PreparedStatementQuery& command, const FlightDescriptor& descriptor) { + std::lock_guard guard(mutex_); const std::string& prepared_statement_handle = command.prepared_statement_handle; auto search = prepared_statements_.find(prepared_statement_handle); @@ -450,6 +503,7 @@ class SQLiteFlightSqlServer::Impl { arrow::Result> DoGetPreparedStatement( const ServerCallContext& context, const PreparedStatementQuery& command) { + std::lock_guard guard(mutex_); const std::string& prepared_statement_handle = command.prepared_statement_handle; auto search = prepared_statements_.find(prepared_statement_handle); @@ -470,9 +524,8 @@ class SQLiteFlightSqlServer::Impl { FlightMessageReader* reader, FlightMetadataWriter* writer) { const std::string& prepared_statement_handle = command.prepared_statement_handle; - ARROW_ASSIGN_OR_RAISE( - auto statement, - GetStatementByHandle(prepared_statements_, prepared_statement_handle)); + ARROW_ASSIGN_OR_RAISE(auto statement, + GetStatementByHandle(prepared_statement_handle)); sqlite3_stmt* stmt = statement->GetSqlite3Stmt(); ARROW_RETURN_NOT_OK(SetParametersOnSQLiteStatement(stmt, reader)); @@ -484,9 +537,8 @@ class SQLiteFlightSqlServer::Impl { const ServerCallContext& context, const PreparedStatementUpdate& command, FlightMessageReader* reader) { const std::string& prepared_statement_handle = command.prepared_statement_handle; - ARROW_ASSIGN_OR_RAISE( - auto statement, - GetStatementByHandle(prepared_statements_, prepared_statement_handle)); + ARROW_ASSIGN_OR_RAISE(auto statement, + GetStatementByHandle(prepared_statement_handle)); sqlite3_stmt* stmt = statement->GetSqlite3Stmt(); ARROW_RETURN_NOT_OK(SetParametersOnSQLiteStatement(stmt, reader)); @@ -627,28 +679,85 @@ class SQLiteFlightSqlServer::Impl { return DoGetSQLiteQuery(db_, query, SqlSchema::GetCrossReferenceSchema()); } - Status ExecuteSql(const std::string& sql) { + Status ExecuteSql(const std::string& sql) { return ExecuteSql(db_, sql); } + + Status ExecuteSql(sqlite3* db, const std::string& sql) { char* err_msg = nullptr; - int rc = sqlite3_exec(db_, sql.c_str(), nullptr, nullptr, &err_msg); + int rc = sqlite3_exec(db, sql.c_str(), nullptr, nullptr, &err_msg); if (rc != SQLITE_OK) { std::string error_msg; if (err_msg != nullptr) { error_msg = err_msg; + sqlite3_free(err_msg); } - sqlite3_free(err_msg); - return Status::ExecutionError(error_msg); + return Status::IOError(error_msg); } + if (err_msg) sqlite3_free(err_msg); return Status::OK(); } + + arrow::Result BeginTransaction( + const ServerCallContext& context, const ActionBeginTransactionRequest& request) { + std::string handle = GenerateRandomString(); + sqlite3* new_db = nullptr; + if (sqlite3_open_v2(db_uri_.c_str(), &new_db, + SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE | SQLITE_OPEN_URI, + /*zVfs=*/nullptr) != SQLITE_OK) { + std::string error_message = "Can't open new connection: "; + if (new_db) { + error_message += sqlite3_errmsg(new_db); + sqlite3_close(new_db); + } + return Status::Invalid(error_message); + } + + ARROW_RETURN_NOT_OK(ExecuteSql(new_db, "BEGIN TRANSACTION")); + + std::lock_guard guard(mutex_); + open_transactions_[handle] = new_db; + return ActionBeginTransactionResult{std::move(handle)}; + } + + Status EndTransaction(const ServerCallContext& context, + const ActionEndTransactionRequest& request) { + Status status; + sqlite3* transaction = nullptr; + { + std::lock_guard guard(mutex_); + auto it = open_transactions_.find(request.transaction_id); + if (it == open_transactions_.end()) { + return Status::KeyError("Unknown transaction ID: ", request.transaction_id); + } + + if (request.action == ActionEndTransactionRequest::kCommit) { + status = ExecuteSql(it->second, "COMMIT"); + } else { + status = ExecuteSql(it->second, "ROLLBACK"); + } + transaction = it->second; + open_transactions_.erase(it); + } + sqlite3_close(transaction); + return status; + } }; +// Give each server instance its own in-memory DB +std::atomic kDbCounter(0); + SQLiteFlightSqlServer::SQLiteFlightSqlServer(std::shared_ptr impl) : impl_(std::move(impl)) {} arrow::Result> SQLiteFlightSqlServer::Create() { sqlite3* db = nullptr; - if (sqlite3_open(":memory:", &db)) { + // All sqlite3* instances created from this URI will share data + std::string uri = "file:memorydb"; + uri += std::to_string(kDbCounter++); + uri += "?mode=memory&cache=shared"; + if (sqlite3_open_v2(uri.c_str(), &db, + SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE | SQLITE_OPEN_URI, + /*zVfs=*/nullptr)) { std::string err_msg = "Can't open database: "; if (db != nullptr) { err_msg += sqlite3_errmsg(db); @@ -660,9 +769,10 @@ arrow::Result> SQLiteFlightSqlServer::Cre return Status::Invalid(err_msg); } - std::shared_ptr impl = std::make_shared(db); + std::shared_ptr impl = std::make_shared(db, std::move(uri)); - std::shared_ptr result(new SQLiteFlightSqlServer(impl)); + std::shared_ptr result( + new SQLiteFlightSqlServer(std::move(impl))); for (const auto& id_to_result : GetSqlInfoResultMap()) { result->RegisterSqlInfo(id_to_result.first, id_to_result.second); } @@ -855,6 +965,15 @@ SQLiteFlightSqlServer::DoGetCrossReference(const ServerCallContext& context, return impl_->DoGetCrossReference(context, command); } +arrow::Result SQLiteFlightSqlServer::BeginTransaction( + const ServerCallContext& context, const ActionBeginTransactionRequest& request) { + return impl_->BeginTransaction(context, request); +} +Status SQLiteFlightSqlServer::EndTransaction(const ServerCallContext& context, + const ActionEndTransactionRequest& request) { + return impl_->EndTransaction(context, request); +} + } // namespace example } // namespace sql } // namespace flight diff --git a/cpp/src/arrow/flight/sql/example/sqlite_server.h b/cpp/src/arrow/flight/sql/example/sqlite_server.h index 744ed068d0b13..389a2d921bbbf 100644 --- a/cpp/src/arrow/flight/sql/example/sqlite_server.h +++ b/cpp/src/arrow/flight/sql/example/sqlite_server.h @@ -141,6 +141,12 @@ class SQLiteFlightSqlServer : public FlightSqlServerBase { arrow::Result> DoGetPrimaryKeys( const ServerCallContext& context, const GetPrimaryKeys& command) override; + arrow::Result BeginTransaction( + const ServerCallContext& context, + const ActionBeginTransactionRequest& request) override; + Status EndTransaction(const ServerCallContext& context, + const ActionEndTransactionRequest& request) override; + private: class Impl; std::shared_ptr impl_; diff --git a/cpp/src/arrow/flight/sql/example/sqlite_sql_info.cc b/cpp/src/arrow/flight/sql/example/sqlite_sql_info.cc index 94f25b390170f..9737b5a3090d1 100644 --- a/cpp/src/arrow/flight/sql/example/sqlite_sql_info.cc +++ b/cpp/src/arrow/flight/sql/example/sqlite_sql_info.cc @@ -18,6 +18,7 @@ #include "arrow/flight/sql/example/sqlite_sql_info.h" #include "arrow/flight/sql/types.h" +#include "arrow/util/config.h" namespace arrow { namespace flight { @@ -33,8 +34,14 @@ SqlInfoResultMap GetSqlInfoResultMap() { {SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_VERSION, SqlInfoResult(std::string("sqlite 3"))}, {SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_ARROW_VERSION, - SqlInfoResult(std::string("7.0.0-SNAPSHOT" /* Only an example */))}, + SqlInfoResult(std::string(ARROW_VERSION_STRING))}, {SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_READ_ONLY, SqlInfoResult(false)}, + {SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_SQL, SqlInfoResult(true)}, + {SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_SUBSTRAIT, SqlInfoResult(false)}, + {SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_TRANSACTION, + SqlInfoResult(SqlInfoOptions::SqlSupportedTransaction:: + SQL_SUPPORTED_TRANSACTION_TRANSACTION)}, + {SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_CANCEL, SqlInfoResult(false)}, {SqlInfoOptions::SqlInfo::SQL_DDL_CATALOG, SqlInfoResult(false /* SQLite 3 does not support catalogs */)}, {SqlInfoOptions::SqlInfo::SQL_DDL_SCHEMA, diff --git a/cpp/src/arrow/flight/sql/server.cc b/cpp/src/arrow/flight/sql/server.cc index a8f3ed8a80c24..1905b117d611f 100644 --- a/cpp/src/arrow/flight/sql/server.cc +++ b/cpp/src/arrow/flight/sql/server.cc @@ -149,11 +149,29 @@ arrow::Result ParseCommandStatementQuery( const google::protobuf::Any& any) { pb::sql::CommandStatementQuery command; if (!any.UnpackTo(&command)) { - return Status::Invalid("Unable to unpack CommandStatementQuery."); + return Status::Invalid("Unable to unpack CommandStatementQuery"); } StatementQuery result; result.query = command.query(); + result.transaction_id = command.transaction_id(); + return result; +} + +SubstraitPlan ParseStatementSubstraitPlan(const pb::sql::SubstraitPlan& pb_plan) { + return {pb_plan.plan(), pb_plan.version()}; +} + +arrow::Result ParseCommandStatementSubstraitPlan( + const google::protobuf::Any& any) { + pb::sql::CommandStatementSubstraitPlan command; + if (!any.UnpackTo(&command)) { + return Status::Invalid("Unable to unpack CommandStatementSubstraitPlan"); + } + + StatementSubstraitPlan result; + result.plan = ParseStatementSubstraitPlan(command.plan()); + result.transaction_id = command.transaction_id(); return result; } @@ -190,18 +208,6 @@ arrow::Result ParseCommandGetTables(const google::protobuf::Any& any) return result; } -arrow::Result ParseStatementQueryTicket( - const google::protobuf::Any& any) { - pb::sql::TicketStatementQuery command; - if (!any.UnpackTo(&command)) { - return Status::Invalid("Unable to unpack TicketStatementQuery."); - } - - StatementQueryTicket result; - result.statement_handle = command.statement_handle(); - return result; -} - arrow::Result ParseCommandStatementUpdate( const google::protobuf::Any& any) { pb::sql::CommandStatementUpdate command; @@ -211,6 +217,7 @@ arrow::Result ParseCommandStatementUpdate( StatementUpdate result; result.query = command.query(); + result.transaction_id = command.transaction_id(); return result; } @@ -226,15 +233,65 @@ arrow::Result ParseCommandPreparedStatementUpdate( return result; } +arrow::Result ParseActionBeginSavepointRequest( + const google::protobuf::Any& any) { + pb::sql::ActionBeginSavepointRequest command; + if (!any.UnpackTo(&command)) { + return Status::Invalid("Unable to unpack ActionBeginSavepointRequest"); + } + + ActionBeginSavepointRequest result; + result.transaction_id = command.transaction_id(); + result.name = command.name(); + return result; +} + +arrow::Result ParseActionBeginTransactionRequest( + const google::protobuf::Any& any) { + pb::sql::ActionBeginTransactionRequest command; + if (!any.UnpackTo(&command)) { + return Status::Invalid("Unable to unpack ActionBeginTransactionRequest"); + } + + ActionBeginTransactionRequest result; + return result; +} + +arrow::Result ParseActionCancelQueryRequest( + const google::protobuf::Any& any) { + pb::sql::ActionCancelQueryRequest command; + if (!any.UnpackTo(&command)) { + return Status::Invalid("Unable to unpack ActionCancelQueryRequest"); + } + + ActionCancelQueryRequest result; + ARROW_ASSIGN_OR_RAISE(result.info, FlightInfo::Deserialize(command.info())); + return result; +} + arrow::Result ParseActionCreatePreparedStatementRequest(const google::protobuf::Any& any) { pb::sql::ActionCreatePreparedStatementRequest command; if (!any.UnpackTo(&command)) { - return Status::Invalid("Unable to unpack ActionCreatePreparedStatementRequest."); + return Status::Invalid("Unable to unpack ActionCreatePreparedStatementRequest"); } ActionCreatePreparedStatementRequest result; result.query = command.query(); + result.transaction_id = command.transaction_id(); + return result; +} + +arrow::Result +ParseActionCreatePreparedSubstraitPlanRequest(const google::protobuf::Any& any) { + pb::sql::ActionCreatePreparedSubstraitPlanRequest command; + if (!any.UnpackTo(&command)) { + return Status::Invalid("Unable to unpack ActionCreatePreparedSubstraitPlanRequest"); + } + + ActionCreatePreparedSubstraitPlanRequest result; + result.plan = ParseStatementSubstraitPlan(command.plan()); + result.transaction_id = command.transaction_id(); return result; } @@ -242,7 +299,7 @@ arrow::Result ParseActionClosePreparedStatementRequest(const google::protobuf::Any& any) { pb::sql::ActionClosePreparedStatementRequest command; if (!any.UnpackTo(&command)) { - return Status::Invalid("Unable to unpack ActionClosePreparedStatementRequest."); + return Status::Invalid("Unable to unpack ActionClosePreparedStatementRequest"); } ActionClosePreparedStatementRequest result; @@ -250,8 +307,139 @@ ParseActionClosePreparedStatementRequest(const google::protobuf::Any& any) { return result; } +arrow::Result ParseActionEndSavepointRequest( + const google::protobuf::Any& any) { + pb::sql::ActionEndSavepointRequest command; + if (!any.UnpackTo(&command)) { + return Status::Invalid("Unable to unpack ActionEndSavepointRequest"); + } + + ActionEndSavepointRequest result; + result.savepoint_id = command.savepoint_id(); + switch (command.action()) { + case pb::sql::ActionEndSavepointRequest::END_SAVEPOINT_UNSPECIFIED: + return Status::Invalid( + "ActionEndSavepointRequest.action was END_SAVEPOINT_UNSPECIFIED"); + case pb::sql::ActionEndSavepointRequest::END_SAVEPOINT_RELEASE: + result.action = ActionEndSavepointRequest::kRelease; + break; + case pb::sql::ActionEndSavepointRequest::END_SAVEPOINT_ROLLBACK: + result.action = ActionEndSavepointRequest::kRollback; + break; + default: + return Status::Invalid("Unknown value for ActionEndSavepointRequest.action: ", + command.action()); + } + return result; +} + +arrow::Result ParseActionEndTransactionRequest( + const google::protobuf::Any& any) { + pb::sql::ActionEndTransactionRequest command; + if (!any.UnpackTo(&command)) { + return Status::Invalid("Unable to unpack ActionEndTransactionRequest"); + } + + ActionEndTransactionRequest result; + result.transaction_id = command.transaction_id(); + switch (command.action()) { + case pb::sql::ActionEndTransactionRequest::END_TRANSACTION_UNSPECIFIED: + return Status::Invalid( + "ActionEndTransactionRequest.action was END_TRANSACTION_UNSPECIFIED"); + case pb::sql::ActionEndTransactionRequest::END_TRANSACTION_COMMIT: + result.action = ActionEndTransactionRequest::kCommit; + break; + case pb::sql::ActionEndTransactionRequest::END_TRANSACTION_ROLLBACK: + result.action = ActionEndTransactionRequest::kRollback; + break; + default: + return Status::Invalid("Unknown value for ActionEndTransactionRequest.action: ", + command.action()); + } + return result; +} + +arrow::Result PackActionResult(const google::protobuf::Message& message) { + google::protobuf::Any any; + if (!any.PackFrom(message)) { + return Status::IOError("Failed to pack ", message.GetTypeName()); + } + + std::string buffer; + if (!any.SerializeToString(&buffer)) { + return Status::IOError("Failed to serialize packed ", message.GetTypeName()); + } + return Result{Buffer::FromString(std::move(buffer))}; +} + +arrow::Result PackActionResult(ActionBeginSavepointResult result) { + pb::sql::ActionBeginSavepointResult pb_result; + pb_result.set_savepoint_id(std::move(result.savepoint_id)); + return PackActionResult(pb_result); +} + +arrow::Result PackActionResult(ActionBeginTransactionResult result) { + pb::sql::ActionBeginTransactionResult pb_result; + pb_result.set_transaction_id(std::move(result.transaction_id)); + return PackActionResult(pb_result); +} + +arrow::Result PackActionResult(CancelResult result) { + pb::sql::ActionCancelQueryResult pb_result; + switch (result) { + case CancelResult::kUnspecified: + pb_result.set_result(pb::sql::ActionCancelQueryResult::CANCEL_RESULT_UNSPECIFIED); + break; + case CancelResult::kCancelled: + pb_result.set_result(pb::sql::ActionCancelQueryResult::CANCEL_RESULT_CANCELLED); + break; + case CancelResult::kCancelling: + pb_result.set_result(pb::sql::ActionCancelQueryResult::CANCEL_RESULT_CANCELLING); + break; + case CancelResult::kNotCancellable: + pb_result.set_result( + pb::sql::ActionCancelQueryResult::CANCEL_RESULT_NOT_CANCELLABLE); + break; + } + return PackActionResult(pb_result); +} + +arrow::Result PackActionResult(ActionCreatePreparedStatementResult result) { + pb::sql::ActionCreatePreparedStatementResult pb_result; + pb_result.set_prepared_statement_handle(std::move(result.prepared_statement_handle)); + if (result.dataset_schema != nullptr) { + ARROW_ASSIGN_OR_RAISE(std::shared_ptr serialized, + ipc::SerializeSchema(*result.dataset_schema)); + pb_result.set_dataset_schema(reinterpret_cast(serialized->data()), + serialized->size()); + } + if (result.parameter_schema != nullptr) { + ARROW_ASSIGN_OR_RAISE(std::shared_ptr serialized, + ipc::SerializeSchema(*result.parameter_schema)); + pb_result.set_parameter_schema(reinterpret_cast(serialized->data()), + serialized->size()); + } + + return PackActionResult(pb_result); +} + } // namespace +arrow::Result StatementQueryTicket::Deserialize( + std::string_view serialized) { + pb::sql::TicketStatementQuery command; + google::protobuf::Any any; + if (!any.ParseFromArray(serialized.data(), static_cast(serialized.size()))) { + return Status::Invalid("Unable to parse ticket"); + } + if (!any.UnpackTo(&command)) { + return Status::Invalid("Unable to unpack TicketStatementQuery"); + } + StatementQueryTicket result; + result.statement_handle = command.statement_handle(); + return result; +} + arrow::Result CreateStatementQueryTicket( const std::string& statement_handle) { protocol::sql::TicketStatementQuery ticket_statement_query; @@ -282,6 +470,12 @@ Status FlightSqlServerBase::GetFlightInfo(const ServerCallContext& context, ARROW_ASSIGN_OR_RAISE(*info, GetFlightInfoStatement(context, internal_command, request)); return Status::OK(); + } else if (any.Is()) { + ARROW_ASSIGN_OR_RAISE(StatementSubstraitPlan internal_command, + ParseCommandStatementSubstraitPlan(any)); + ARROW_ASSIGN_OR_RAISE(*info, + GetFlightInfoSubstraitPlan(context, internal_command, request)); + return Status::OK(); } else if (any.Is()) { ARROW_ASSIGN_OR_RAISE(PreparedStatementQuery internal_command, ParseCommandPreparedStatementQuery(any)); @@ -358,6 +552,12 @@ Status FlightSqlServerBase::GetSchema(const ServerCallContext& context, ARROW_ASSIGN_OR_RAISE(*schema, GetSchemaStatement(context, internal_command, request)); return Status::OK(); + } else if (any.Is()) { + ARROW_ASSIGN_OR_RAISE(StatementSubstraitPlan internal_command, + ParseCommandStatementSubstraitPlan(any)); + ARROW_ASSIGN_OR_RAISE(*schema, + GetSchemaSubstraitPlan(context, internal_command, request)); + return Status::OK(); } else if (any.Is()) { ARROW_ASSIGN_OR_RAISE(PreparedStatementQuery internal_command, ParseCommandPreparedStatementQuery(any)); @@ -413,15 +613,19 @@ Status FlightSqlServerBase::GetSchema(const ServerCallContext& context, Status FlightSqlServerBase::DoGet(const ServerCallContext& context, const Ticket& request, std::unique_ptr* stream) { google::protobuf::Any any; - if (!any.ParseFromArray(request.ticket.data(), static_cast(request.ticket.size()))) { - return Status::Invalid("Unable to parse ticket."); + return Status::Invalid("Unable to parse ticket"); } if (any.Is()) { - ARROW_ASSIGN_OR_RAISE(StatementQueryTicket command, ParseStatementQueryTicket(any)); - ARROW_ASSIGN_OR_RAISE(*stream, DoGetStatement(context, command)); + pb::sql::TicketStatementQuery command; + if (!any.UnpackTo(&command)) { + return Status::Invalid("Unable to unpack TicketStatementQuery"); + } + StatementQueryTicket result; + result.statement_handle = command.statement_handle(); + ARROW_ASSIGN_OR_RAISE(*stream, DoGetStatement(context, result)); return Status::OK(); } else if (any.Is()) { ARROW_ASSIGN_OR_RAISE(PreparedStatementQuery internal_command, @@ -483,7 +687,7 @@ Status FlightSqlServerBase::DoPut(const ServerCallContext& context, google::protobuf::Any any; if (!any.ParseFromArray(request.cmd.data(), static_cast(request.cmd.size()))) { - return Status::Invalid("Unable to parse command."); + return Status::Invalid("Unable to parse command"); } if (any.Is()) { @@ -498,6 +702,18 @@ Status FlightSqlServerBase::DoPut(const ServerCallContext& context, const auto buffer = Buffer::FromString(result.SerializeAsString()); ARROW_RETURN_NOT_OK(writer->WriteMetadata(*buffer)); + return Status::OK(); + } else if (any.Is()) { + ARROW_ASSIGN_OR_RAISE(StatementSubstraitPlan internal_command, + ParseCommandStatementSubstraitPlan(any)); + ARROW_ASSIGN_OR_RAISE(auto record_count, + DoPutCommandSubstraitPlan(context, internal_command)); + + pb::sql::DoPutUpdateResult result; + result.set_record_count(record_count); + + const auto buffer = Buffer::FromString(result.SerializeAsString()); + ARROW_RETURN_NOT_OK(writer->WriteMetadata(*buffer)); return Status::OK(); } else if (any.Is()) { ARROW_ASSIGN_OR_RAISE(PreparedStatementQuery internal_command, @@ -507,78 +723,104 @@ Status FlightSqlServerBase::DoPut(const ServerCallContext& context, } else if (any.Is()) { ARROW_ASSIGN_OR_RAISE(PreparedStatementUpdate internal_command, ParseCommandPreparedStatementUpdate(any)); - ARROW_ASSIGN_OR_RAISE(auto record_count, DoPutPreparedStatementUpdate( - context, internal_command, reader.get())) + ARROW_ASSIGN_OR_RAISE( + auto record_count, + DoPutPreparedStatementUpdate(context, internal_command, reader.get())); pb::sql::DoPutUpdateResult result; result.set_record_count(record_count); const auto buffer = Buffer::FromString(result.SerializeAsString()); ARROW_RETURN_NOT_OK(writer->WriteMetadata(*buffer)); - return Status::OK(); } - return Status::Invalid("The defined request is invalid."); + return Status::NotImplemented("Command not recognized: ", any.type_url()); } Status FlightSqlServerBase::ListActions(const ServerCallContext& context, std::vector* actions) { - *actions = {FlightSqlServerBase::kCreatePreparedStatementActionType, - FlightSqlServerBase::kClosePreparedStatementActionType}; + *actions = { + FlightSqlServerBase::kBeginSavepointActionType, + FlightSqlServerBase::kBeginTransactionActionType, + FlightSqlServerBase::kCancelQueryActionType, + FlightSqlServerBase::kCreatePreparedStatementActionType, + FlightSqlServerBase::kCreatePreparedSubstraitPlanActionType, + FlightSqlServerBase::kClosePreparedStatementActionType, + FlightSqlServerBase::kEndSavepointActionType, + FlightSqlServerBase::kEndTransactionActionType, + }; return Status::OK(); } Status FlightSqlServerBase::DoAction(const ServerCallContext& context, const Action& action, std::unique_ptr* result_stream) { - if (action.type == FlightSqlServerBase::kCreatePreparedStatementActionType.type) { - google::protobuf::Any any_command; - if (!any_command.ParseFromArray(action.body->data(), - static_cast(action.body->size()))) { - return Status::Invalid("Unable to parse action."); - } + google::protobuf::Any any; + if (!any.ParseFromArray(action.body->data(), static_cast(action.body->size()))) { + return Status::Invalid("Unable to parse action"); + } + std::vector results; + if (action.type == FlightSqlServerBase::kBeginSavepointActionType.type) { + ARROW_ASSIGN_OR_RAISE(ActionBeginSavepointRequest internal_command, + ParseActionBeginSavepointRequest(any)); + ARROW_ASSIGN_OR_RAISE(ActionBeginSavepointResult result, + BeginSavepoint(context, internal_command)); + ARROW_ASSIGN_OR_RAISE(Result packed_result, PackActionResult(std::move(result))); + + results.push_back(std::move(packed_result)); + } else if (action.type == FlightSqlServerBase::kBeginTransactionActionType.type) { + ARROW_ASSIGN_OR_RAISE(ActionBeginTransactionRequest internal_command, + ParseActionBeginTransactionRequest(any)); + ARROW_ASSIGN_OR_RAISE(ActionBeginTransactionResult result, + BeginTransaction(context, internal_command)); + ARROW_ASSIGN_OR_RAISE(Result packed_result, PackActionResult(std::move(result))); + + results.push_back(std::move(packed_result)); + } else if (action.type == FlightSqlServerBase::kCancelQueryActionType.type) { + ARROW_ASSIGN_OR_RAISE(ActionCancelQueryRequest internal_command, + ParseActionCancelQueryRequest(any)); + ARROW_ASSIGN_OR_RAISE(CancelResult result, CancelQuery(context, internal_command)); + ARROW_ASSIGN_OR_RAISE(Result packed_result, PackActionResult(result)); + + results.push_back(std::move(packed_result)); + } else if (action.type == + FlightSqlServerBase::kCreatePreparedStatementActionType.type) { ARROW_ASSIGN_OR_RAISE(ActionCreatePreparedStatementRequest internal_command, - ParseActionCreatePreparedStatementRequest(any_command)); - ARROW_ASSIGN_OR_RAISE(auto result, CreatePreparedStatement(context, internal_command)) - - pb::sql::ActionCreatePreparedStatementResult action_result; - action_result.set_prepared_statement_handle(result.prepared_statement_handle); - if (result.dataset_schema != nullptr) { - ARROW_ASSIGN_OR_RAISE(auto serialized_dataset_schema, - ipc::SerializeSchema(*result.dataset_schema)) - action_result.set_dataset_schema(serialized_dataset_schema->ToString()); - } - if (result.parameter_schema != nullptr) { - ARROW_ASSIGN_OR_RAISE(auto serialized_parameter_schema, - ipc::SerializeSchema(*result.parameter_schema)) - action_result.set_parameter_schema(serialized_parameter_schema->ToString()); - } - - google::protobuf::Any any; - any.PackFrom(action_result); - - auto buf = Buffer::FromString(any.SerializeAsString()); - *result_stream = std::unique_ptr(new SimpleResultStream({Result{buf}})); - - return Status::OK(); + ParseActionCreatePreparedStatementRequest(any)); + ARROW_ASSIGN_OR_RAISE(ActionCreatePreparedStatementResult result, + CreatePreparedStatement(context, internal_command)); + ARROW_ASSIGN_OR_RAISE(Result packed_result, PackActionResult(std::move(result))); + + results.push_back(std::move(packed_result)); + } else if (action.type == + FlightSqlServerBase::kCreatePreparedSubstraitPlanActionType.type) { + ARROW_ASSIGN_OR_RAISE(ActionCreatePreparedSubstraitPlanRequest internal_command, + ParseActionCreatePreparedSubstraitPlanRequest(any)); + ARROW_ASSIGN_OR_RAISE(ActionCreatePreparedStatementResult result, + CreatePreparedSubstraitPlan(context, internal_command)); + ARROW_ASSIGN_OR_RAISE(Result packed_result, PackActionResult(std::move(result))); + + results.push_back(std::move(packed_result)); } else if (action.type == FlightSqlServerBase::kClosePreparedStatementActionType.type) { - google::protobuf::Any any; - if (!any.ParseFromArray(action.body->data(), static_cast(action.body->size()))) { - return Status::Invalid("Unable to parse action."); - } - ARROW_ASSIGN_OR_RAISE(ActionClosePreparedStatementRequest internal_command, ParseActionClosePreparedStatementRequest(any)); - ARROW_RETURN_NOT_OK(ClosePreparedStatement(context, internal_command)); - - // Need to instantiate a ResultStream, otherwise clients can not wait for completion. - *result_stream = std::unique_ptr(new SimpleResultStream({})); - return Status::OK(); + } else if (action.type == FlightSqlServerBase::kEndSavepointActionType.type) { + ARROW_ASSIGN_OR_RAISE(ActionEndSavepointRequest internal_command, + ParseActionEndSavepointRequest(any)); + ARROW_RETURN_NOT_OK(EndSavepoint(context, internal_command)); + } else if (action.type == FlightSqlServerBase::kEndTransactionActionType.type) { + ARROW_ASSIGN_OR_RAISE(ActionEndTransactionRequest internal_command, + ParseActionEndTransactionRequest(any)); + ARROW_RETURN_NOT_OK(EndTransaction(context, internal_command)); + } else { + return Status::NotImplemented("Action not implemented: ", action.type); } - return Status::Invalid("The defined request is invalid."); + *result_stream = + std::unique_ptr(new SimpleResultStream(std::move(results))); + return Status::OK(); } arrow::Result> FlightSqlServerBase::GetFlightInfoCatalogs( @@ -603,6 +845,19 @@ arrow::Result> FlightSqlServerBase::GetSchemaState return Status::NotImplemented("GetSchemaStatement not implemented"); } +arrow::Result> +FlightSqlServerBase::GetFlightInfoSubstraitPlan(const ServerCallContext& context, + const StatementSubstraitPlan& command, + const FlightDescriptor& descriptor) { + return Status::NotImplemented("GetFlightInfoSubstraitPlan not implemented"); +} + +arrow::Result> FlightSqlServerBase::GetSchemaSubstraitPlan( + const ServerCallContext& context, const StatementSubstraitPlan& command, + const FlightDescriptor& descriptor) { + return Status::NotImplemented("GetSchemaSubstraitPlan not implemented"); +} + arrow::Result> FlightSqlServerBase::DoGetStatement( const ServerCallContext& context, const StatementQueryTicket& command) { return Status::NotImplemented("DoGetStatement not implemented"); @@ -773,6 +1028,21 @@ arrow::Result> FlightSqlServerBase::DoGetCross return Status::NotImplemented("DoGetCrossReference not implemented"); } +arrow::Result FlightSqlServerBase::BeginSavepoint( + const ServerCallContext& context, const ActionBeginSavepointRequest& request) { + return Status::NotImplemented("BeginSavepoint not implemented"); +} + +arrow::Result FlightSqlServerBase::BeginTransaction( + const ServerCallContext& context, const ActionBeginTransactionRequest& request) { + return Status::NotImplemented("BeginTransaction not implemented"); +} + +arrow::Result FlightSqlServerBase::CancelQuery( + const ServerCallContext& context, const ActionCancelQueryRequest& request) { + return Status::NotImplemented("CancelQuery not implemented"); +} + arrow::Result FlightSqlServerBase::CreatePreparedStatement( const ServerCallContext& context, @@ -780,12 +1050,29 @@ FlightSqlServerBase::CreatePreparedStatement( return Status::NotImplemented("CreatePreparedStatement not implemented"); } +arrow::Result +FlightSqlServerBase::CreatePreparedSubstraitPlan( + const ServerCallContext& context, + const ActionCreatePreparedSubstraitPlanRequest& request) { + return Status::NotImplemented("CreatePreparedSubstraitPlan not implemented"); +} + Status FlightSqlServerBase::ClosePreparedStatement( const ServerCallContext& context, const ActionClosePreparedStatementRequest& request) { return Status::NotImplemented("ClosePreparedStatement not implemented"); } +Status FlightSqlServerBase::EndSavepoint(const ServerCallContext& context, + const ActionEndSavepointRequest& request) { + return Status::NotImplemented("EndSavepoint not implemented"); +} + +Status FlightSqlServerBase::EndTransaction(const ServerCallContext& context, + const ActionEndTransactionRequest& request) { + return Status::NotImplemented("EndTransaction not implemented"); +} + Status FlightSqlServerBase::DoPutPreparedStatementQuery( const ServerCallContext& context, const PreparedStatementQuery& command, FlightMessageReader* reader, FlightMetadataWriter* writer) { @@ -803,6 +1090,11 @@ arrow::Result FlightSqlServerBase::DoPutCommandStatementUpdate( return Status::NotImplemented("DoPutCommandStatementUpdate not implemented"); } +arrow::Result FlightSqlServerBase::DoPutCommandSubstraitPlan( + const ServerCallContext& context, const StatementSubstraitPlan& command) { + return Status::NotImplemented("DoPutCommandSubstraitPlan not implemented"); +} + std::shared_ptr SqlSchema::GetCatalogsSchema() { return arrow::schema({field("catalog_name", utf8(), false)}); } diff --git a/cpp/src/arrow/flight/sql/server.h b/cpp/src/arrow/flight/sql/server.h index 91dad98843f52..0fc8b714865a8 100644 --- a/cpp/src/arrow/flight/sql/server.h +++ b/cpp/src/arrow/flight/sql/server.h @@ -23,6 +23,7 @@ #include #include #include +#include #include #include "arrow/flight/server.h" @@ -44,18 +45,32 @@ namespace sql { struct ARROW_FLIGHT_SQL_EXPORT StatementQuery { /// \brief The SQL query. std::string query; + /// \brief The transaction ID, if specified (else a blank string). + std::string transaction_id; +}; + +/// \brief A Substrait plan to execute. +struct ARROW_FLIGHT_SQL_EXPORT StatementSubstraitPlan { + /// \brief The Substrait plan. + SubstraitPlan plan; + /// \brief The transaction ID, if specified (else a blank string). + std::string transaction_id; }; /// \brief A SQL update query. struct ARROW_FLIGHT_SQL_EXPORT StatementUpdate { /// \brief The SQL query. std::string query; + /// \brief The transaction ID, if specified (else a blank string). + std::string transaction_id; }; /// \brief A request to execute a query. struct ARROW_FLIGHT_SQL_EXPORT StatementQueryTicket { /// \brief The server-generated opaque identifier for the query. std::string statement_handle; + + static arrow::Result Deserialize(std::string_view serialized); }; /// \brief A prepared query statement. @@ -132,10 +147,66 @@ struct ARROW_FLIGHT_SQL_EXPORT GetCrossReference { TableRef fk_table_ref; }; +/// \brief A request to start a new transaction. +struct ARROW_FLIGHT_SQL_EXPORT ActionBeginTransactionRequest {}; + +/// \brief A request to create a new savepoint. +struct ARROW_FLIGHT_SQL_EXPORT ActionBeginSavepointRequest { + std::string transaction_id; + std::string name; +}; + +/// \brief The result of starting a new savepoint. +struct ARROW_FLIGHT_SQL_EXPORT ActionBeginSavepointResult { + std::string savepoint_id; +}; + +/// \brief The result of starting a new transaction. +struct ARROW_FLIGHT_SQL_EXPORT ActionBeginTransactionResult { + std::string transaction_id; +}; + +/// \brief A request to end a savepoint. +struct ARROW_FLIGHT_SQL_EXPORT ActionEndSavepointRequest { + enum EndSavepoint { + kRelease, + kRollback, + }; + + std::string savepoint_id; + EndSavepoint action; +}; + +/// \brief A request to end a transaction. +struct ARROW_FLIGHT_SQL_EXPORT ActionEndTransactionRequest { + enum EndTransaction { + kCommit, + kRollback, + }; + + std::string transaction_id; + EndTransaction action; +}; + +/// \brief An explicit request to cancel a running query. +struct ARROW_FLIGHT_SQL_EXPORT ActionCancelQueryRequest { + std::unique_ptr info; +}; + /// \brief A request to create a new prepared statement. struct ARROW_FLIGHT_SQL_EXPORT ActionCreatePreparedStatementRequest { /// \brief The SQL query. std::string query; + /// \brief The transaction ID, if specified (else a blank string). + std::string transaction_id; +}; + +/// \brief A request to create a new prepared statement with a Substrait plan. +struct ARROW_FLIGHT_SQL_EXPORT ActionCreatePreparedSubstraitPlanRequest { + /// \brief The serialized Substrait plan. + SubstraitPlan plan; + /// \brief The transaction ID, if specified (else a blank string). + std::string transaction_id; }; /// \brief A request to close a prepared statement. @@ -189,6 +260,15 @@ class ARROW_FLIGHT_SQL_EXPORT FlightSqlServerBase : public FlightServerBase { const ServerCallContext& context, const StatementQuery& command, const FlightDescriptor& descriptor); + /// \brief Get a FlightInfo for executing a Substrait plan. + /// \param[in] context Per-call context. + /// \param[in] command The StatementSubstraitPlan object containing the plan. + /// \param[in] descriptor The descriptor identifying the data stream. + /// \return The FlightInfo describing where to access the dataset. + virtual arrow::Result> GetFlightInfoSubstraitPlan( + const ServerCallContext& context, const StatementSubstraitPlan& command, + const FlightDescriptor& descriptor); + /// \brief Get a FlightDataStream containing the query results. /// \param[in] context Per-call context. /// \param[in] command The StatementQueryTicket containing the statement handle. @@ -231,6 +311,15 @@ class ARROW_FLIGHT_SQL_EXPORT FlightSqlServerBase : public FlightServerBase { const ServerCallContext& context, const StatementQuery& command, const FlightDescriptor& descriptor); + /// \brief Get the schema of the result set of a Substrait plan. + /// \param[in] context Per-call context. + /// \param[in] command The StatementQuery containing the plan. + /// \param[in] descriptor The descriptor identifying the data stream. + /// \return The schema of the result set. + virtual arrow::Result> GetSchemaSubstraitPlan( + const ServerCallContext& context, const StatementSubstraitPlan& command, + const FlightDescriptor& descriptor); + /// \brief Get the schema of the result set of a prepared statement. /// \param[in] context Per-call context. /// \param[in] command The PreparedStatementQuery containing the @@ -423,7 +512,14 @@ class ARROW_FLIGHT_SQL_EXPORT FlightSqlServerBase : public FlightServerBase { virtual arrow::Result DoPutCommandStatementUpdate( const ServerCallContext& context, const StatementUpdate& command); - /// \brief Create a prepared statement from given SQL statement. + /// \brief Execute an update Substrait plan. + /// \param[in] context The call context. + /// \param[in] command The StatementSubstraitPlan object containing the plan. + /// \return The changed record count. + virtual arrow::Result DoPutCommandSubstraitPlan( + const ServerCallContext& context, const StatementSubstraitPlan& command); + + /// \brief Create a prepared statement from a given SQL statement. /// \param[in] context The call context. /// \param[in] request The ActionCreatePreparedStatementRequest object containing the /// SQL statement. @@ -433,6 +529,16 @@ class ARROW_FLIGHT_SQL_EXPORT FlightSqlServerBase : public FlightServerBase { const ServerCallContext& context, const ActionCreatePreparedStatementRequest& request); + /// \brief Create a prepared statement from a Substrait plan. + /// \param[in] context The call context. + /// \param[in] request The ActionCreatePreparedSubstraitPlanRequest object containing + /// the Substrait plan. + /// \return A ActionCreatePreparedStatementResult containing the dataset + /// and parameter schemas and a handle for created statement. + virtual arrow::Result CreatePreparedSubstraitPlan( + const ServerCallContext& context, + const ActionCreatePreparedSubstraitPlanRequest& request); + /// \brief Close a prepared statement. /// \param[in] context The call context. /// \param[in] request The ActionClosePreparedStatementRequest object containing the @@ -462,6 +568,39 @@ class ARROW_FLIGHT_SQL_EXPORT FlightSqlServerBase : public FlightServerBase { const ServerCallContext& context, const PreparedStatementUpdate& command, FlightMessageReader* reader); + /// \brief Begin a new transaction. + /// \param[in] context The call context. + /// \param[in] request Request parameters. + /// \return The transaction ID. + virtual arrow::Result BeginTransaction( + const ServerCallContext& context, const ActionBeginTransactionRequest& request); + + /// \brief Create a new savepoint. + /// \param[in] context The call context. + /// \param[in] request Request parameters. + /// \return The savepoint ID. + virtual arrow::Result BeginSavepoint( + const ServerCallContext& context, const ActionBeginSavepointRequest& request); + + /// \brief Release/rollback a savepoint. + /// \param[in] context The call context. + /// \param[in] request The savepoint. + virtual Status EndSavepoint(const ServerCallContext& context, + const ActionEndSavepointRequest& request); + + /// \brief Commit/rollback a transaction. + /// \param[in] context The call context. + /// \param[in] request The tranaction. + virtual Status EndTransaction(const ServerCallContext& context, + const ActionEndTransactionRequest& request); + + /// \brief Attempt to explicitly cancel a query. + /// \param[in] context The call context. + /// \param[in] request The query to cancel. + /// \return The cancellation result. + virtual arrow::Result CancelQuery( + const ServerCallContext& context, const ActionCancelQueryRequest& request); + /// @} /// \name Utility methods @@ -492,16 +631,46 @@ class ARROW_FLIGHT_SQL_EXPORT FlightSqlServerBase : public FlightServerBase { std::unique_ptr reader, std::unique_ptr writer) final; + const ActionType kBeginSavepointActionType = + ActionType{"BeginSavepoint", + "Create a new savepoint.\n" + "Request Message: ActionBeginSavepointRequest\n" + "Response Message: ActionBeginSavepointResult"}; + const ActionType kBeginTransactionActionType = + ActionType{"BeginTransaction", + "Start a new transaction.\n" + "Request Message: ActionBeginTransactionRequest\n" + "Response Message: ActionBeginTransactionResult"}; const ActionType kCreatePreparedStatementActionType = ActionType{"CreatePreparedStatement", "Creates a reusable prepared statement resource on the server.\n" "Request Message: ActionCreatePreparedStatementRequest\n" "Response Message: ActionCreatePreparedStatementResult"}; + const ActionType kCreatePreparedSubstraitPlanActionType = + ActionType{"CreatePreparedSubstraitPlan", + "Creates a reusable prepared statement resource on the server.\n" + "Request Message: ActionCreatePreparedSubstraitPlanRequest\n" + "Response Message: ActionCreatePreparedStatementResult"}; + const ActionType kCancelQueryActionType = + ActionType{"CancelQuery", + "Explicitly cancel a running query.\n" + "Request Message: ActionCancelQueryRequest\n" + "Response Message: ActionCancelQueryResult"}; const ActionType kClosePreparedStatementActionType = ActionType{"ClosePreparedStatement", "Closes a reusable prepared statement resource on the server.\n" "Request Message: ActionClosePreparedStatementRequest\n" "Response Message: N/A"}; + const ActionType kEndSavepointActionType = + ActionType{"EndSavepoint", + "End a savepoint.\n" + "Request Message: ActionEndSavepointRequest\n" + "Response Message: N/A"}; + const ActionType kEndTransactionActionType = + ActionType{"EndTransaction", + "End a savepoint.\n" + "Request Message: ActionEndTransactionRequest\n" + "Response Message: N/A"}; Status ListActions(const ServerCallContext& context, std::vector* actions) final; diff --git a/cpp/src/arrow/flight/sql/server_test.cc b/cpp/src/arrow/flight/sql/server_test.cc index 7ba3ca4a24364..785f45551fcc5 100644 --- a/cpp/src/arrow/flight/sql/server_test.cc +++ b/cpp/src/arrow/flight/sql/server_test.cc @@ -24,9 +24,6 @@ #include #include -#include -#include - #include "arrow/flight/api.h" #include "arrow/flight/sql/api.h" #include "arrow/flight/sql/column_metadata.h" @@ -153,17 +150,12 @@ class TestFlightSqlServer : public ::testing::Test { protected: void SetUp() override { - port = GetListenPort(); - server_thread.reset(new std::thread([&]() { RunServer(); })); - - std::unique_lock lk(server_ready_m); - server_ready_cv.wait(lk); - - std::stringstream ss; - ss << "grpc://localhost:" << port; - std::string uri = ss.str(); + ASSERT_OK_AND_ASSIGN(auto location, Location::ForGrpcTcp("0.0.0.0", 0)); + arrow::flight::FlightServerOptions options(location); + ASSERT_OK_AND_ASSIGN(server, example::SQLiteFlightSqlServer::Create()); + ASSERT_OK(server->Init(options)); - ASSERT_OK_AND_ASSIGN(auto location, Location::Parse(uri)); + ASSERT_OK_AND_ASSIGN(location, Location::ForGrpcTcp("localhost", server->port())); ASSERT_OK_AND_ASSIGN(auto client, FlightClient::Connect(location)); sql_client.reset(new FlightSqlClient(std::move(client))); @@ -174,30 +166,10 @@ class TestFlightSqlServer : public ::testing::Test { sql_client.reset(); ASSERT_OK(server->Shutdown()); - server_thread->join(); - server_thread.reset(); } private: - int port; std::shared_ptr server; - std::unique_ptr server_thread; - std::condition_variable server_ready_cv; - std::mutex server_ready_m; - - void RunServer() { - ASSERT_OK_AND_ASSIGN(auto location, Location::ForGrpcTcp("localhost", port)); - arrow::flight::FlightServerOptions options(location); - - ARROW_CHECK_OK(example::SQLiteFlightSqlServer::Create().Value(&server)); - - ARROW_CHECK_OK(server->Init(options)); - // Exit with a clean error code (0) on SIGTERM - ARROW_CHECK_OK(server->SetShutdownOnSignals({SIGTERM})); - - server_ready_cv.notify_all(); - ARROW_CHECK_OK(server->Serve()); - } }; TEST_F(TestFlightSqlServer, TestCommandStatementQuery) { @@ -802,6 +774,51 @@ TEST_F(TestFlightSqlServer, TestCommandGetSqlInfoNoInfo) { sql_client->DoGet(call_options, flight_info->endpoints()[0].ticket)); } +TEST_F(TestFlightSqlServer, CancelQuery) { + // Not supported + ASSERT_OK_AND_ASSIGN(auto flight_info, sql_client->GetSqlInfo({}, {})); + ASSERT_RAISES(NotImplemented, sql_client->CancelQuery({}, *flight_info)); +} + +TEST_F(TestFlightSqlServer, Transactions) { + ASSERT_OK_AND_ASSIGN(auto handle, sql_client->BeginTransaction({})); + ASSERT_TRUE(handle.is_valid()); + ASSERT_NE(handle.transaction_id(), ""); + ASSERT_RAISES(NotImplemented, sql_client->BeginSavepoint({}, handle, "savepoint")); + + ASSERT_OK_AND_ASSIGN(auto flight_info, + sql_client->Execute({}, "SELECT * FROM intTable", handle)); + ASSERT_OK_AND_ASSIGN(auto stream, + sql_client->DoGet({}, flight_info->endpoints()[0].ticket)); + ASSERT_OK_AND_ASSIGN(auto table, stream->ToTable()); + int64_t row_count = table->num_rows(); + + int64_t result; + ASSERT_OK_AND_ASSIGN(result, + sql_client->ExecuteUpdate( + {}, + "INSERT INTO intTable (keyName, value) VALUES " + "('KEYNAME1', 1001), ('KEYNAME2', 1002), ('KEYNAME3', 1003)", + handle)); + ASSERT_EQ(3, result); + + ASSERT_OK_AND_ASSIGN(flight_info, + sql_client->Execute({}, "SELECT * FROM intTable", handle)); + ASSERT_OK_AND_ASSIGN(stream, sql_client->DoGet({}, flight_info->endpoints()[0].ticket)); + ASSERT_OK_AND_ASSIGN(table, stream->ToTable()); + ASSERT_EQ(table->num_rows(), row_count + 3); + + ASSERT_OK(sql_client->Rollback({}, handle)); + // Commit/rollback invalidate the handle + ASSERT_RAISES(KeyError, sql_client->Rollback({}, handle)); + ASSERT_RAISES(KeyError, sql_client->Commit({}, handle)); + + ASSERT_OK_AND_ASSIGN(flight_info, sql_client->Execute({}, "SELECT * FROM intTable")); + ASSERT_OK_AND_ASSIGN(stream, sql_client->DoGet({}, flight_info->endpoints()[0].ticket)); + ASSERT_OK_AND_ASSIGN(table, stream->ToTable()); + ASSERT_EQ(table->num_rows(), row_count); +} + } // namespace sql } // namespace flight } // namespace arrow diff --git a/cpp/src/arrow/flight/sql/types.h b/cpp/src/arrow/flight/sql/types.h index 20c7952d8d710..8b28ed18bdd4a 100644 --- a/cpp/src/arrow/flight/sql/types.h +++ b/cpp/src/arrow/flight/sql/types.h @@ -18,7 +18,7 @@ #pragma once #include -#include +#include #include #include #include @@ -70,6 +70,54 @@ struct ARROW_FLIGHT_SQL_EXPORT SqlInfoOptions { /// - true: if read only FLIGHT_SQL_SERVER_READ_ONLY = 3, + /// Retrieves a boolean value indicating whether the Flight SQL Server + /// supports executing SQL queries. + /// + /// Note that the absence of this info (as opposed to a false + /// value) does not necessarily mean that SQL is not supported, as + /// this property was not originally defined. + FLIGHT_SQL_SERVER_SQL = 4, + + /// Retrieves a boolean value indicating whether the Flight SQL Server + /// supports executing Substrait plans. + FLIGHT_SQL_SERVER_SUBSTRAIT = 5, + + /// Retrieves a string value indicating the minimum supported + /// Substrait version, or null if Substrait is not supported. + FLIGHT_SQL_SERVER_SUBSTRAIT_MIN_VERSION = 6, + + /// Retrieves a string value indicating the maximum supported + /// Substrait version, or null if Substrait is not supported. + FLIGHT_SQL_SERVER_SUBSTRAIT_MAX_VERSION = 7, + + /// Retrieves an int32 indicating whether the Flight SQL Server + /// supports the BeginTransaction, EndTransaction, BeginSavepoint, + /// and EndSavepoint actions. + /// + /// Even if this is not supported, the database may still support + /// explicit "BEGIN TRANSACTION"/"COMMIT" SQL statements (see + /// SQL_TRANSACTIONS_SUPPORTED); this property is only about + /// whether the server implements the Flight SQL API endpoints. + /// + /// The possible values are listed in `SqlSupportedTransaction`. + FLIGHT_SQL_SERVER_TRANSACTION = 8, + + /// Retrieves a boolean value indicating whether the Flight SQL Server + /// supports explicit query cancellation (the CancelQuery action). + FLIGHT_SQL_SERVER_CANCEL = 9, + + /// Retrieves an int32 value indicating the timeout (in milliseconds) for + /// prepared statement handles. + /// + /// If 0, there is no timeout. + FLIGHT_SQL_SERVER_STATEMENT_TIMEOUT = 100, + + /// Retrieves an int32 value indicating the timeout (in milliseconds) for + /// transactions, since transactions are not tied to a connection. + /// + /// If 0, there is no timeout. + FLIGHT_SQL_SERVER_TRANSACTION_TIMEOUT = 101, + /// @} /// \name SQL Syntax Information @@ -795,6 +843,16 @@ struct ARROW_FLIGHT_SQL_EXPORT SqlInfoOptions { /// @} }; + /// The level of support for Flight SQL transaction RPCs. + enum SqlSupportedTransaction { + /// Unknown/not indicated/no supoprt + SQL_SUPPORTED_TRANSACTION_NONE = 0, + /// Transactions, but not savepoints. + SQL_SUPPORTED_TRANSACTION_TRANSACTION = 1, + /// Transactions and savepoints. + SQL_SUPPORTED_TRANSACTION_SAVEPOINT = 2, + }; + /// Indicate whether something (e.g. an identifier) is case-sensitive. enum SqlSupportedCaseSensitivity { SQL_CASE_SENSITIVITY_UNKNOWN = 0, @@ -845,6 +903,25 @@ struct ARROW_FLIGHT_SQL_EXPORT TableRef { std::string table; }; +/// \brief A Substrait plan to be executed, along with associated metadata. +struct ARROW_FLIGHT_SQL_EXPORT SubstraitPlan { + /// \brief The serialized plan. + std::string plan; + /// \brief The Substrait release, e.g. "0.12.0". + std::string version; +}; + +/// \brief The result of cancelling a query. +enum class CancelResult : int8_t { + kUnspecified, + kCancelled, + kCancelling, + kNotCancellable, +}; + +ARROW_FLIGHT_SQL_EXPORT +std::ostream& operator<<(std::ostream& os, CancelResult result); + /// @} } // namespace sql diff --git a/dev/archery/archery/integration/runner.py b/dev/archery/archery/integration/runner.py index 05f945cb82416..887cbf92fedc5 100644 --- a/dev/archery/archery/integration/runner.py +++ b/dev/archery/archery/integration/runner.py @@ -435,6 +435,11 @@ def run_all_tests(with_cpp=True, with_java=True, with_js=True, description="Ensure Flight SQL protocol is working as expected.", skip={"Rust"} ), + Scenario( + "flight_sql:extension", + description="Ensure Flight SQL extensions work as expected.", + skip={"Rust", "Go"} + ), ] runner = IntegrationRunner(json_files, flight_scenarios, testers, **kwargs) diff --git a/docs/source/status.rst b/docs/source/status.rst index a5dd47c0b8121..fc63787225596 100644 --- a/docs/source/status.rst +++ b/docs/source/status.rst @@ -232,10 +232,22 @@ support/not support individual features. +--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ | Feature | C++ | Java | Go | JavaScript | C# | Rust | Julia | +============================================+=======+=======+=======+============+=======+=======+=======+ +| BeginSavepoint | ✓ | ✓ | | | | | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ +| BeginTransaction | ✓ | ✓ | | | | | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ +| CancelQuery | ✓ | ✓ | | | | | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ | ClosePreparedStatement | ✓ | ✓ | ✓ | | | | | +--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ | CreatePreparedStatement | ✓ | ✓ | ✓ | | | | | +--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ +| CreatePreparedSubstraitPlan | ✓ | ✓ | | | | | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ +| EndSavepoint | ✓ | ✓ | | | | | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ +| EndTransaction | ✓ | ✓ | | | | | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ | GetCatalogs | ✓ | ✓ | ✓ | | | | | +--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ | GetCrossReference | ✓ | ✓ | ✓ | | | | | @@ -260,6 +272,8 @@ support/not support individual features. +--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ | PreparedStatementUpdate | ✓ | ✓ | ✓ | | | | | +--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ +| StatementSubstraitPlan | ✓ | ✓ | | | | | | ++--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ | StatementQuery | ✓ | ✓ | ✓ | | | | | +--------------------------------------------+-------+-------+-------+------------+-------+-------+-------+ | StatementUpdate | ✓ | ✓ | ✓ | | | | | diff --git a/format/FlightSql.proto b/format/FlightSql.proto index 859427b68804b..d8a6cb5bfdb07 100644 --- a/format/FlightSql.proto +++ b/format/FlightSql.proto @@ -90,6 +90,64 @@ enum SqlInfo { */ FLIGHT_SQL_SERVER_READ_ONLY = 3; + /* + * Retrieves a boolean value indicating whether the Flight SQL Server supports executing + * SQL queries. + * + * Note that the absence of this info (as opposed to a false value) does not necessarily + * mean that SQL is not supported, as this property was not originally defined. + */ + FLIGHT_SQL_SERVER_SQL = 4; + + /* + * Retrieves a boolean value indicating whether the Flight SQL Server supports executing + * Substrait plans. + */ + FLIGHT_SQL_SERVER_SUBSTRAIT = 5; + + /* + * Retrieves a string value indicating the minimum supported Substrait version, or null + * if Substrait is not supported. + */ + FLIGHT_SQL_SERVER_SUBSTRAIT_MIN_VERSION = 6; + + /* + * Retrieves a string value indicating the maximum supported Substrait version, or null + * if Substrait is not supported. + */ + FLIGHT_SQL_SERVER_SUBSTRAIT_MAX_VERSION = 7; + + /* + * Retrieves an int32 indicating whether the Flight SQL Server supports the + * BeginTransaction/EndTransaction/BeginSavepoint/EndSavepoint actions. + * + * Even if this is not supported, the database may still support explicit "BEGIN + * TRANSACTION"/"COMMIT" SQL statements (see SQL_TRANSACTIONS_SUPPORTED); this property + * is only about whether the server implements the Flight SQL API endpoints. + * + * The possible values are listed in `SqlSupportedTransaction`. + */ + FLIGHT_SQL_SERVER_TRANSACTION = 8; + + /* + * Retrieves a boolean value indicating whether the Flight SQL Server supports explicit + * query cancellation (the CancelQuery action). + */ + FLIGHT_SQL_SERVER_CANCEL = 9; + + /* + * Retrieves an int32 indicating the timeout (in milliseconds) for prepared statement handles. + * + * If 0, there is no timeout. Servers should reset the timeout when the handle is used in a command. + */ + FLIGHT_SQL_SERVER_STATEMENT_TIMEOUT = 100; + + /* + * Retrieves an int32 indicating the timeout (in milliseconds) for transactions, since transactions are not tied to a connection. + * + * If 0, there is no timeout. Servers should reset the timeout when the handle is used in a command. + */ + FLIGHT_SQL_SERVER_TRANSACTION_TIMEOUT = 101; // SQL Syntax Information [500-1000): provides information about SQL syntax supported by the Flight SQL Server. @@ -761,6 +819,18 @@ enum SqlInfo { SQL_STORED_FUNCTIONS_USING_CALL_SYNTAX_SUPPORTED = 576; } +// The level of support for Flight SQL transaction RPCs. +enum SqlSupportedTransaction { + // Unknown/not indicated/no support + SQL_SUPPORTED_TRANSACTION_NONE = 0; + // Transactions, but not savepoints. + // A savepoint is a mark within a transaction that can be individually + // rolled back to. Not all databases support savepoints. + SQL_SUPPORTED_TRANSACTION_TRANSACTION = 1; + // Transactions and savepoints + SQL_SUPPORTED_TRANSACTION_SAVEPOINT = 2; +} + enum SqlSupportedCaseSensitivity { SQL_CASE_SENSITIVITY_UNKNOWN = 0; SQL_CASE_SENSITIVITY_CASE_INSENSITIVE = 1; @@ -1406,7 +1476,7 @@ message CommandGetCrossReference { string fk_table = 6; } -// SQL Execution Action Messages +// Query Execution Action Messages /* * Request message for the "CreatePreparedStatement" action on a Flight SQL enabled backend. @@ -1416,14 +1486,49 @@ message ActionCreatePreparedStatementRequest { // The valid SQL string to create a prepared statement for. string query = 1; + // Create/execute the prepared statement as part of this transaction (if + // unset, executions of the prepared statement will be auto-committed). + optional bytes transaction_id = 2; } /* - * Wrap the result of a "GetPreparedStatement" action. + * An embedded message describing a Substrait plan to execute. + */ +message SubstraitPlan { + option (experimental) = true; + + // The serialized substrait.Plan to create a prepared statement for. + // XXX(ARROW-16902): this is bytes instead of an embedded message + // because Protobuf does not really support one DLL using Protobuf + // definitions from another DLL. + bytes plan = 1; + // The Substrait release, e.g. "0.12.0". This information is not + // tracked in the plan itself, so this is the only way for consumers + // to potentially know if they can handle the plan. + string version = 2; +} + +/* + * Request message for the "CreatePreparedSubstraitPlan" action on a Flight SQL enabled backend. + */ +message ActionCreatePreparedSubstraitPlanRequest { + option (experimental) = true; + + // The serialized substrait.Plan to create a prepared statement for. + SubstraitPlan plan = 1; + // Create/execute the prepared statement as part of this transaction (if + // unset, executions of the prepared statement will be auto-committed). + optional bytes transaction_id = 2; +} + +/* + * Wrap the result of a "CreatePreparedStatement" or "CreatePreparedSubstraitPlan" action. * * The resultant PreparedStatement can be closed either: * - Manually, through the "ClosePreparedStatement" action; * - Automatically, by a server timeout. + * + * The result should be wrapped in a google.protobuf.Any message. */ message ActionCreatePreparedStatementResult { option (experimental) = true; @@ -1451,8 +1556,113 @@ message ActionClosePreparedStatementRequest { bytes prepared_statement_handle = 1; } +/* + * Request message for the "BeginTransaction" action. + * Begins a transaction. + */ +message ActionBeginTransactionRequest { + option (experimental) = true; +} + +/* + * Request message for the "BeginSavepoint" action. + * Creates a savepoint within a transaction. + * + * Only supported if FLIGHT_SQL_TRANSACTION is + * FLIGHT_SQL_TRANSACTION_SUPPORT_SAVEPOINT. + */ +message ActionBeginSavepointRequest { + option (experimental) = true; + + // The transaction to which a savepoint belongs. + bytes transaction_id = 1; + // Name for the savepoint. + string name = 2; +} + +/* + * The result of a "BeginTransaction" action. + * + * The transaction can be manipulated with the "EndTransaction" action, or + * automatically via server timeout. If the transaction times out, then it is + * automatically rolled back. + * + * The result should be wrapped in a google.protobuf.Any message. + */ +message ActionBeginTransactionResult { + option (experimental) = true; + + // Opaque handle for the transaction on the server. + bytes transaction_id = 1; +} + +/* + * The result of a "BeginSavepoint" action. + * + * The transaction can be manipulated with the "EndSavepoint" action. + * If the associated transaction is committed, rolled back, or times + * out, then the savepoint is also invalidated. + * + * The result should be wrapped in a google.protobuf.Any message. + */ +message ActionBeginSavepointResult { + option (experimental) = true; + + // Opaque handle for the savepoint on the server. + bytes savepoint_id = 1; +} + +/* + * Request message for the "EndTransaction" action. + * + * Commit (COMMIT) or rollback (ROLLBACK) the transaction. + * + * If the action completes successfully, the transaction handle is + * invalidated, as are all associated savepoints. + */ +message ActionEndTransactionRequest { + option (experimental) = true; -// SQL Execution Messages. + enum EndTransaction { + END_TRANSACTION_UNSPECIFIED = 0; + // Commit the transaction. + END_TRANSACTION_COMMIT = 1; + // Roll back the transaction. + END_TRANSACTION_ROLLBACK = 2; + } + // Opaque handle for the transaction on the server. + bytes transaction_id = 1; + // Whether to commit/rollback the given transaction. + EndTransaction action = 2; +} + +/* + * Request message for the "EndSavepoint" action. + * + * Release (RELEASE) the savepoint or rollback (ROLLBACK) to the + * savepoint. + * + * Releasing a savepoint invalidates that savepoint. Rolling back to + * a savepoint does not invalidate the savepoint, but invalidates all + * savepoints created after the current savepoint. + */ +message ActionEndSavepointRequest { + option (experimental) = true; + + enum EndSavepoint { + END_SAVEPOINT_UNSPECIFIED = 0; + // Release the savepoint. + END_SAVEPOINT_RELEASE = 1; + // Roll back to a savepoint. + END_SAVEPOINT_ROLLBACK = 2; + } + // Opaque handle for the savepoint on the server. + bytes savepoint_id = 1; + // Whether to rollback/release the given savepoint. + EndSavepoint action = 2; +} + +// Query Execution Messages. /* * Represents a SQL query. Used in the command member of FlightDescriptor @@ -1476,6 +1686,35 @@ message CommandStatementQuery { // The SQL syntax. string query = 1; + // Include the query as part of this transaction (if unset, the query is auto-committed). + optional bytes transaction_id = 2; +} + +/* + * Represents a Substrait plan. Used in the command member of FlightDescriptor + * for the following RPC calls: + * - GetSchema: return the Arrow schema of the query. + * Fields on this schema may contain the following metadata: + * - ARROW:FLIGHT:SQL:CATALOG_NAME - Table's catalog name + * - ARROW:FLIGHT:SQL:DB_SCHEMA_NAME - Database schema name + * - ARROW:FLIGHT:SQL:TABLE_NAME - Table name + * - ARROW:FLIGHT:SQL:TYPE_NAME - The data source-specific name for the data type of the column. + * - ARROW:FLIGHT:SQL:PRECISION - Column precision/size + * - ARROW:FLIGHT:SQL:SCALE - Column scale/decimal digits if applicable + * - ARROW:FLIGHT:SQL:IS_AUTO_INCREMENT - "1" indicates if the column is auto incremented, "0" otherwise. + * - ARROW:FLIGHT:SQL:IS_CASE_SENSITIVE - "1" indicates if the column is case sensitive, "0" otherwise. + * - ARROW:FLIGHT:SQL:IS_READ_ONLY - "1" indicates if the column is read only, "0" otherwise. + * - ARROW:FLIGHT:SQL:IS_SEARCHABLE - "1" indicates if the column is searchable via WHERE clause, "0" otherwise. + * - GetFlightInfo: execute the query. + * - DoPut: execute the query. + */ +message CommandStatementSubstraitPlan { + option (experimental) = true; + + // A serialized substrait.Plan + SubstraitPlan plan = 1; + // Include the query as part of this transaction (if unset, the query is auto-committed). + optional bytes transaction_id = 2; } /** @@ -1523,6 +1762,8 @@ message CommandStatementUpdate { // The SQL syntax. string query = 1; + // Include the query as part of this transaction (if unset, the query is auto-committed). + optional bytes transaction_id = 2; } /* @@ -1550,6 +1791,57 @@ message DoPutUpdateResult { int64 record_count = 1; } +/* + * Request message for the "CancelQuery" action. + * + * Explicitly cancel a running query. + * + * This lets a single client explicitly cancel work, no matter how many clients + * are involved/whether the query is distributed or not, given server support. + * The transaction/statement is not rolled back; it is the application's job to + * commit or rollback as appropriate. This only indicates the client no longer + * wishes to read the remainder of the query results or continue submitting + * data. + * + * This command is idempotent. + */ +message ActionCancelQueryRequest { + option (experimental) = true; + + // The result of the GetFlightInfo RPC that initiated the query. + // XXX(ARROW-16902): this must be a serialized FlightInfo, but is + // rendered as bytes because Protobuf does not really support one + // DLL using Protobuf definitions from another DLL. + bytes info = 1; +} + +/* + * The result of cancelling a query. + * + * The result should be wrapped in a google.protobuf.Any message. + */ +message ActionCancelQueryResult { + option (experimental) = true; + + enum CancelResult { + // The cancellation status is unknown. Servers should avoid using + // this value (send a NOT_FOUND error if the requested query is + // not known). Clients can retry the request. + CANCEL_RESULT_UNSPECIFIED = 0; + // The cancellation request is complete. Subsequent requests with + // the same payload may return CANCELLED or a NOT_FOUND error. + CANCEL_RESULT_CANCELLED = 1; + // The cancellation request is in progress. The client may retry + // the cancellation request. + CANCEL_RESULT_CANCELLING = 2; + // The query is not cancellable. The client should not retry the + // cancellation request. + CANCEL_RESULT_NOT_CANCELLABLE = 3; + } + + CancelResult result = 1; +} + extend google.protobuf.MessageOptions { bool experimental = 1000; } diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightClient.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightClient.java index 762b37859b948..1f50f50a293f9 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightClient.java +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightClient.java @@ -292,7 +292,12 @@ public FlightInfo getInfo(FlightDescriptor descriptor, CallOption... options) { * @param options RPC-layer hints for this call. */ public SchemaResult getSchema(FlightDescriptor descriptor, CallOption... options) { - return SchemaResult.fromProtocol(CallOptions.wrapStub(blockingStub, options).getSchema(descriptor.toProtocol())); + try { + return SchemaResult.fromProtocol(CallOptions.wrapStub(blockingStub, options) + .getSchema(descriptor.toProtocol())); + } catch (StatusRuntimeException sre) { + throw StatusUtils.fromGrpcRuntimeException(sre); + } } /** diff --git a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightService.java b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightService.java index 4fb0dea2cba26..29a4f2bbd19ea 100644 --- a/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightService.java +++ b/java/flight/flight-core/src/main/java/org/apache/arrow/flight/FlightService.java @@ -231,7 +231,7 @@ public StreamObserver doPutCustom(final StreamObserver { try { producer.acceptPut(makeContext(responseObserver), fs, ackStream).run(); - } catch (Exception ex) { + } catch (Throwable ex) { ackStream.onError(ex); } finally { // ARROW-6136: Close the stream if and only if acceptPut hasn't closed it itself diff --git a/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/FlightSqlExtensionScenario.java b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/FlightSqlExtensionScenario.java new file mode 100644 index 0000000000000..cd20ae4f46f1f --- /dev/null +++ b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/FlightSqlExtensionScenario.java @@ -0,0 +1,217 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight.integration.tests; + +import java.util.HashMap; +import java.util.Map; + +import org.apache.arrow.flight.FlightClient; +import org.apache.arrow.flight.FlightInfo; +import org.apache.arrow.flight.FlightStream; +import org.apache.arrow.flight.Location; +import org.apache.arrow.flight.SchemaResult; +import org.apache.arrow.flight.Ticket; +import org.apache.arrow.flight.sql.CancelResult; +import org.apache.arrow.flight.sql.FlightSqlClient; +import org.apache.arrow.flight.sql.FlightSqlProducer; +import org.apache.arrow.flight.sql.impl.FlightSql; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.util.Preconditions; +import org.apache.arrow.vector.UInt4Vector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.complex.DenseUnionVector; +import org.apache.arrow.vector.types.pojo.Schema; + +/** + * Integration test scenario for validating Flight SQL specs across multiple implementations. + * This should ensure that RPC objects are being built and parsed correctly for multiple languages + * and that the Arrow schemas are returned as expected. + */ +public class FlightSqlExtensionScenario extends FlightSqlScenario { + @Override + public void client(BufferAllocator allocator, Location location, FlightClient client) + throws Exception { + try (final FlightSqlClient sqlClient = new FlightSqlClient(client)) { + validateMetadataRetrieval(sqlClient); + validateStatementExecution(sqlClient); + validatePreparedStatementExecution(allocator, sqlClient); + validateTransactions(allocator, sqlClient); + } + } + + private void validateMetadataRetrieval(FlightSqlClient sqlClient) throws Exception { + FlightInfo info = sqlClient.getSqlInfo(); + Ticket ticket = info.getEndpoints().get(0).getTicket(); + + Map infoValues = new HashMap<>(); + try (FlightStream stream = sqlClient.getStream(ticket)) { + Schema actualSchema = stream.getSchema(); + IntegrationAssertions.assertEquals(FlightSqlProducer.Schemas.GET_SQL_INFO_SCHEMA, actualSchema); + + while (stream.next()) { + UInt4Vector infoName = (UInt4Vector) stream.getRoot().getVector(0); + DenseUnionVector value = (DenseUnionVector) stream.getRoot().getVector(1); + + for (int i = 0; i < stream.getRoot().getRowCount(); i++) { + final int code = infoName.get(i); + if (infoValues.containsKey(code)) { + throw new AssertionError("Duplicate SqlInfo value: " + code); + } + Object object; + byte typeId = value.getTypeId(i); + switch (typeId) { + case 0: // string + object = Preconditions.checkNotNull(value.getVarCharVector(typeId) + .getObject(value.getOffset(i))) + .toString(); + break; + case 1: // bool + object = value.getBitVector(typeId).getObject(value.getOffset(i)); + break; + case 2: // int64 + object = value.getBigIntVector(typeId).getObject(value.getOffset(i)); + break; + case 3: // int32 + object = value.getIntVector(typeId).getObject(value.getOffset(i)); + break; + default: + throw new AssertionError("Decoding SqlInfo of type code " + typeId); + } + infoValues.put(code, object); + } + } + } + + IntegrationAssertions.assertEquals(Boolean.FALSE, + infoValues.get(FlightSql.SqlInfo.FLIGHT_SQL_SERVER_SQL_VALUE)); + IntegrationAssertions.assertEquals(Boolean.TRUE, + infoValues.get(FlightSql.SqlInfo.FLIGHT_SQL_SERVER_SUBSTRAIT_VALUE)); + IntegrationAssertions.assertEquals("min_version", + infoValues.get(FlightSql.SqlInfo.FLIGHT_SQL_SERVER_SUBSTRAIT_MIN_VERSION_VALUE)); + IntegrationAssertions.assertEquals("max_version", + infoValues.get(FlightSql.SqlInfo.FLIGHT_SQL_SERVER_SUBSTRAIT_MAX_VERSION_VALUE)); + IntegrationAssertions.assertEquals(FlightSql.SqlSupportedTransaction.SQL_SUPPORTED_TRANSACTION_SAVEPOINT_VALUE, + infoValues.get(FlightSql.SqlInfo.FLIGHT_SQL_SERVER_TRANSACTION_VALUE)); + IntegrationAssertions.assertEquals(Boolean.TRUE, + infoValues.get(FlightSql.SqlInfo.FLIGHT_SQL_SERVER_CANCEL_VALUE)); + IntegrationAssertions.assertEquals(42, + infoValues.get(FlightSql.SqlInfo.FLIGHT_SQL_SERVER_STATEMENT_TIMEOUT_VALUE)); + IntegrationAssertions.assertEquals(7, + infoValues.get(FlightSql.SqlInfo.FLIGHT_SQL_SERVER_TRANSACTION_TIMEOUT_VALUE)); + } + + private void validateStatementExecution(FlightSqlClient sqlClient) throws Exception { + FlightInfo info = sqlClient.executeSubstrait(SUBSTRAIT_PLAN); + validate(FlightSqlScenarioProducer.getQuerySchema(), info, sqlClient); + + SchemaResult result = sqlClient.getExecuteSubstraitSchema(SUBSTRAIT_PLAN); + validateSchema(FlightSqlScenarioProducer.getQuerySchema(), result); + + IntegrationAssertions.assertEquals(CancelResult.CANCELLED, sqlClient.cancelQuery(info)); + + IntegrationAssertions.assertEquals(sqlClient.executeSubstraitUpdate(SUBSTRAIT_PLAN), + UPDATE_STATEMENT_EXPECTED_ROWS); + } + + private void validatePreparedStatementExecution(BufferAllocator allocator, + FlightSqlClient sqlClient) throws Exception { + try (FlightSqlClient.PreparedStatement preparedStatement = sqlClient.prepare(SUBSTRAIT_PLAN); + VectorSchemaRoot parameters = VectorSchemaRoot.create( + FlightSqlScenarioProducer.getQuerySchema(), allocator)) { + parameters.setRowCount(1); + preparedStatement.setParameters(parameters); + validate(FlightSqlScenarioProducer.getQuerySchema(), preparedStatement.execute(), sqlClient); + validateSchema(FlightSqlScenarioProducer.getQuerySchema(), preparedStatement.fetchSchema()); + } + + try (FlightSqlClient.PreparedStatement preparedStatement = sqlClient.prepare(SUBSTRAIT_PLAN)) { + IntegrationAssertions.assertEquals(preparedStatement.executeUpdate(), + UPDATE_PREPARED_STATEMENT_EXPECTED_ROWS); + } + } + + private void validateTransactions(BufferAllocator allocator, FlightSqlClient sqlClient) throws Exception { + final FlightSqlClient.Transaction transaction = sqlClient.beginTransaction(); + IntegrationAssertions.assertEquals(TRANSACTION_ID, transaction.getTransactionId()); + + final FlightSqlClient.Savepoint savepoint = sqlClient.beginSavepoint(transaction, SAVEPOINT_NAME); + IntegrationAssertions.assertEquals(SAVEPOINT_ID, savepoint.getSavepointId()); + + FlightInfo info = sqlClient.execute("SELECT STATEMENT", transaction); + validate(FlightSqlScenarioProducer.getQueryWithTransactionSchema(), info, sqlClient); + + info = sqlClient.executeSubstrait(SUBSTRAIT_PLAN, transaction); + validate(FlightSqlScenarioProducer.getQueryWithTransactionSchema(), info, sqlClient); + + SchemaResult schema = sqlClient.getExecuteSchema("SELECT STATEMENT", transaction); + validateSchema(FlightSqlScenarioProducer.getQueryWithTransactionSchema(), schema); + + schema = sqlClient.getExecuteSubstraitSchema(SUBSTRAIT_PLAN, transaction); + validateSchema(FlightSqlScenarioProducer.getQueryWithTransactionSchema(), schema); + + IntegrationAssertions.assertEquals(sqlClient.executeUpdate("UPDATE STATEMENT", transaction), + UPDATE_STATEMENT_WITH_TRANSACTION_EXPECTED_ROWS); + IntegrationAssertions.assertEquals(sqlClient.executeSubstraitUpdate(SUBSTRAIT_PLAN, transaction), + UPDATE_STATEMENT_WITH_TRANSACTION_EXPECTED_ROWS); + + try (FlightSqlClient.PreparedStatement preparedStatement = sqlClient.prepare( + "SELECT PREPARED STATEMENT", transaction); + VectorSchemaRoot parameters = VectorSchemaRoot.create( + FlightSqlScenarioProducer.getQuerySchema(), allocator)) { + parameters.setRowCount(1); + preparedStatement.setParameters(parameters); + validate(FlightSqlScenarioProducer.getQueryWithTransactionSchema(), preparedStatement.execute(), sqlClient); + schema = preparedStatement.fetchSchema(); + validateSchema(FlightSqlScenarioProducer.getQueryWithTransactionSchema(), schema); + } + + try (FlightSqlClient.PreparedStatement preparedStatement = sqlClient.prepare(SUBSTRAIT_PLAN, transaction); + VectorSchemaRoot parameters = VectorSchemaRoot.create( + FlightSqlScenarioProducer.getQuerySchema(), allocator)) { + parameters.setRowCount(1); + preparedStatement.setParameters(parameters); + validate(FlightSqlScenarioProducer.getQueryWithTransactionSchema(), preparedStatement.execute(), sqlClient); + schema = preparedStatement.fetchSchema(); + validateSchema(FlightSqlScenarioProducer.getQueryWithTransactionSchema(), schema); + } + + try (FlightSqlClient.PreparedStatement preparedStatement = + sqlClient.prepare("UPDATE PREPARED STATEMENT", transaction)) { + IntegrationAssertions.assertEquals(preparedStatement.executeUpdate(), + UPDATE_PREPARED_STATEMENT_WITH_TRANSACTION_EXPECTED_ROWS); + } + + try (FlightSqlClient.PreparedStatement preparedStatement = + sqlClient.prepare(SUBSTRAIT_PLAN, transaction)) { + IntegrationAssertions.assertEquals(preparedStatement.executeUpdate(), + UPDATE_PREPARED_STATEMENT_WITH_TRANSACTION_EXPECTED_ROWS); + } + + sqlClient.rollback(savepoint); + + final FlightSqlClient.Savepoint savepoint2 = sqlClient.beginSavepoint(transaction, SAVEPOINT_NAME); + IntegrationAssertions.assertEquals(SAVEPOINT_ID, savepoint2.getSavepointId()); + sqlClient.release(savepoint); + + sqlClient.commit(transaction); + + final FlightSqlClient.Transaction transaction2 = sqlClient.beginTransaction(); + IntegrationAssertions.assertEquals(TRANSACTION_ID, transaction2.getTransactionId()); + sqlClient.rollback(transaction); + } +} diff --git a/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/FlightSqlScenario.java b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/FlightSqlScenario.java index 19c1378cfe6c5..71f1f741d5871 100644 --- a/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/FlightSqlScenario.java +++ b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/FlightSqlScenario.java @@ -17,6 +17,7 @@ package org.apache.arrow.flight.integration.tests; +import java.nio.charset.StandardCharsets; import java.util.Arrays; import org.apache.arrow.flight.CallOption; @@ -42,9 +43,17 @@ * and that the Arrow schemas are returned as expected. */ public class FlightSqlScenario implements Scenario { - public static final long UPDATE_STATEMENT_EXPECTED_ROWS = 10000L; + public static final long UPDATE_STATEMENT_WITH_TRANSACTION_EXPECTED_ROWS = 15000L; public static final long UPDATE_PREPARED_STATEMENT_EXPECTED_ROWS = 20000L; + public static final long UPDATE_PREPARED_STATEMENT_WITH_TRANSACTION_EXPECTED_ROWS = 25000L; + public static final byte[] SAVEPOINT_ID = "savepoint_id".getBytes(StandardCharsets.UTF_8); + public static final String SAVEPOINT_NAME = "savepoint_name"; + public static final byte[] SUBSTRAIT_PLAN_TEXT = "plan".getBytes(StandardCharsets.UTF_8); + public static final String SUBSTRAIT_VERSION = "version"; + public static final FlightSqlClient.SubstraitPlan SUBSTRAIT_PLAN = + new FlightSqlClient.SubstraitPlan(SUBSTRAIT_PLAN_TEXT, SUBSTRAIT_VERSION); + public static final byte[] TRANSACTION_ID = "transaction_id".getBytes(StandardCharsets.UTF_8); @Override public FlightProducer producer(BufferAllocator allocator, Location location) throws Exception { @@ -59,13 +68,11 @@ public void buildServer(FlightServer.Builder builder) throws Exception { @Override public void client(BufferAllocator allocator, Location location, FlightClient client) throws Exception { - final FlightSqlClient sqlClient = new FlightSqlClient(client); - - validateMetadataRetrieval(sqlClient); - - validateStatementExecution(sqlClient); - - validatePreparedStatementExecution(sqlClient, allocator); + try (final FlightSqlClient sqlClient = new FlightSqlClient(client)) { + validateMetadataRetrieval(sqlClient); + validateStatementExecution(sqlClient); + validatePreparedStatementExecution(allocator, sqlClient); + } } private void validateMetadataRetrieval(FlightSqlClient sqlClient) throws Exception { @@ -122,40 +129,35 @@ private void validateMetadataRetrieval(FlightSqlClient sqlClient) throws Excepti } private void validateStatementExecution(FlightSqlClient sqlClient) throws Exception { - final CallOption[] options = new CallOption[0]; - - validate(FlightSqlScenarioProducer.getQuerySchema(), - sqlClient.execute("SELECT STATEMENT", options), sqlClient); + FlightInfo info = sqlClient.execute("SELECT STATEMENT"); + validate(FlightSqlScenarioProducer.getQuerySchema(), info, sqlClient); validateSchema(FlightSqlScenarioProducer.getQuerySchema(), - sqlClient.getExecuteSchema("SELECT STATEMENT", options)); + sqlClient.getExecuteSchema("SELECT STATEMENT")); - IntegrationAssertions.assertEquals(sqlClient.executeUpdate("UPDATE STATEMENT", options), + IntegrationAssertions.assertEquals(sqlClient.executeUpdate("UPDATE STATEMENT"), UPDATE_STATEMENT_EXPECTED_ROWS); } - private void validatePreparedStatementExecution(FlightSqlClient sqlClient, - BufferAllocator allocator) throws Exception { - final CallOption[] options = new CallOption[0]; + private void validatePreparedStatementExecution(BufferAllocator allocator, + FlightSqlClient sqlClient) throws Exception { try (FlightSqlClient.PreparedStatement preparedStatement = sqlClient.prepare( "SELECT PREPARED STATEMENT"); VectorSchemaRoot parameters = VectorSchemaRoot.create( FlightSqlScenarioProducer.getQuerySchema(), allocator)) { parameters.setRowCount(1); preparedStatement.setParameters(parameters); - - validate(FlightSqlScenarioProducer.getQuerySchema(), preparedStatement.execute(options), - sqlClient); + validate(FlightSqlScenarioProducer.getQuerySchema(), preparedStatement.execute(), sqlClient); validateSchema(FlightSqlScenarioProducer.getQuerySchema(), preparedStatement.fetchSchema()); } - try (FlightSqlClient.PreparedStatement preparedStatement = sqlClient.prepare( - "UPDATE PREPARED STATEMENT")) { - IntegrationAssertions.assertEquals(preparedStatement.executeUpdate(options), + try (FlightSqlClient.PreparedStatement preparedStatement = + sqlClient.prepare("UPDATE PREPARED STATEMENT")) { + IntegrationAssertions.assertEquals(preparedStatement.executeUpdate(), UPDATE_PREPARED_STATEMENT_EXPECTED_ROWS); } } - private void validate(Schema expectedSchema, FlightInfo flightInfo, + protected void validate(Schema expectedSchema, FlightInfo flightInfo, FlightSqlClient sqlClient) throws Exception { Ticket ticket = flightInfo.getEndpoints().get(0).getTicket(); try (FlightStream stream = sqlClient.getStream(ticket)) { @@ -164,7 +166,7 @@ private void validate(Schema expectedSchema, FlightInfo flightInfo, } } - private void validateSchema(Schema expected, SchemaResult actual) { + protected void validateSchema(Schema expected, SchemaResult actual) { IntegrationAssertions.assertEquals(expected, actual.getSchema()); } } diff --git a/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/FlightSqlScenarioProducer.java b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/FlightSqlScenarioProducer.java index 33d62b650e176..4ed9a3df0fc62 100644 --- a/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/FlightSqlScenarioProducer.java +++ b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/FlightSqlScenarioProducer.java @@ -17,11 +17,11 @@ package org.apache.arrow.flight.integration.tests; -import static com.google.protobuf.Any.pack; -import static java.util.Collections.singletonList; - +import java.util.Arrays; +import java.util.Collections; import java.util.List; +import org.apache.arrow.flight.CallStatus; import org.apache.arrow.flight.Criteria; import org.apache.arrow.flight.FlightDescriptor; import org.apache.arrow.flight.FlightEndpoint; @@ -31,8 +31,10 @@ import org.apache.arrow.flight.Result; import org.apache.arrow.flight.SchemaResult; import org.apache.arrow.flight.Ticket; +import org.apache.arrow.flight.sql.CancelResult; import org.apache.arrow.flight.sql.FlightSqlColumnMetadata; import org.apache.arrow.flight.sql.FlightSqlProducer; +import org.apache.arrow.flight.sql.SqlInfoBuilder; import org.apache.arrow.flight.sql.impl.FlightSql; import org.apache.arrow.memory.ArrowBuf; import org.apache.arrow.memory.BufferAllocator; @@ -42,7 +44,9 @@ import org.apache.arrow.vector.types.pojo.FieldType; import org.apache.arrow.vector.types.pojo.Schema; +import com.google.protobuf.Any; import com.google.protobuf.ByteString; +import com.google.protobuf.InvalidProtocolBufferException; import com.google.protobuf.Message; /** @@ -61,7 +65,7 @@ public FlightSqlScenarioProducer(BufferAllocator allocator) { */ static Schema getQuerySchema() { return new Schema( - singletonList( + Collections.singletonList( new Field("id", new FieldType(true, new ArrowType.Int(64, true), null, new FlightSqlColumnMetadata.Builder() .tableName("test") @@ -77,6 +81,94 @@ static Schema getQuerySchema() { ); } + /** + * The expected schema for queries with transactions. + *

+ * Must be the same across all languages. + */ + static Schema getQueryWithTransactionSchema() { + return new Schema( + Collections.singletonList( + new Field("pkey", new FieldType(true, new ArrowType.Int(32, true), + null, new FlightSqlColumnMetadata.Builder() + .tableName("test") + .isAutoIncrement(true) + .isCaseSensitive(false) + .typeName("type_test") + .schemaName("schema_test") + .isSearchable(true) + .catalogName("catalog_test") + .precision(100) + .build().getMetadataMap()), null) + ) + ); + } + + @Override + public void beginSavepoint(FlightSql.ActionBeginSavepointRequest request, CallContext context, + StreamListener listener) { + if (!request.getName().equals(FlightSqlScenario.SAVEPOINT_NAME)) { + listener.onError(CallStatus.INVALID_ARGUMENT + .withDescription(String.format("Expected name '%s', not '%s'", + FlightSqlScenario.SAVEPOINT_NAME, request.getName())) + .toRuntimeException()); + return; + } + if (!Arrays.equals(request.getTransactionId().toByteArray(), FlightSqlScenario.TRANSACTION_ID)) { + listener.onError(CallStatus.INVALID_ARGUMENT + .withDescription(String.format("Expected transaction ID '%s', not '%s'", + Arrays.toString(FlightSqlScenario.TRANSACTION_ID), + Arrays.toString(request.getTransactionId().toByteArray()))) + .toRuntimeException()); + return; + } + listener.onNext(FlightSql.ActionBeginSavepointResult.newBuilder() + .setSavepointId(ByteString.copyFrom(FlightSqlScenario.SAVEPOINT_ID)) + .build()); + listener.onCompleted(); + } + + @Override + public void beginTransaction(FlightSql.ActionBeginTransactionRequest request, CallContext context, + StreamListener listener) { + listener.onNext(FlightSql.ActionBeginTransactionResult.newBuilder() + .setTransactionId(ByteString.copyFrom(FlightSqlScenario.TRANSACTION_ID)) + .build()); + listener.onCompleted(); + } + + @Override + public void cancelQuery(FlightInfo info, CallContext context, StreamListener listener) { + final String expectedTicket = "PLAN HANDLE"; + if (info.getEndpoints().size() != 1) { + listener.onError(CallStatus.INVALID_ARGUMENT + .withDescription(String.format("Expected 1 endpoint, got %d", info.getEndpoints().size())) + .toRuntimeException()); + } + final FlightEndpoint endpoint = info.getEndpoints().get(0); + try { + final Any any = Any.parseFrom(endpoint.getTicket().getBytes()); + if (!any.is(FlightSql.TicketStatementQuery.class)) { + listener.onError(CallStatus.INVALID_ARGUMENT + .withDescription(String.format("Expected TicketStatementQuery, found '%s'", any.getTypeUrl())) + .toRuntimeException()); + } + final FlightSql.TicketStatementQuery ticket = any.unpack(FlightSql.TicketStatementQuery.class); + if (!ticket.getStatementHandle().toStringUtf8().equals(expectedTicket)) { + listener.onError(CallStatus.INVALID_ARGUMENT + .withDescription(String.format("Expected ticket '%s'", expectedTicket)) + .toRuntimeException()); + } + listener.onNext(CancelResult.CANCELLED); + listener.onCompleted(); + } catch (InvalidProtocolBufferException e) { + listener.onError(CallStatus.INVALID_ARGUMENT + .withDescription("Invalid Protobuf:" + e) + .withCause(e) + .toRuntimeException()); + } + } + @Override public void createPreparedStatement(FlightSql.ActionCreatePreparedStatementRequest request, CallContext context, StreamListener listener) { @@ -84,21 +176,106 @@ public void createPreparedStatement(FlightSql.ActionCreatePreparedStatementReque request.getQuery().equals("SELECT PREPARED STATEMENT") || request.getQuery().equals("UPDATE PREPARED STATEMENT")); + String text = request.getQuery(); + if (!request.getTransactionId().isEmpty()) { + text += " WITH TXN"; + } + text += " HANDLE"; final FlightSql.ActionCreatePreparedStatementResult result = FlightSql.ActionCreatePreparedStatementResult.newBuilder() - .setPreparedStatementHandle(ByteString.copyFromUtf8(request.getQuery() + " HANDLE")) + .setPreparedStatementHandle(ByteString.copyFromUtf8(text)) .build(); - listener.onNext(new Result(pack(result).toByteArray())); + listener.onNext(new Result(Any.pack(result).toByteArray())); + listener.onCompleted(); + } + + @Override + public void createPreparedSubstraitPlan(FlightSql.ActionCreatePreparedSubstraitPlanRequest request, + CallContext context, + StreamListener listener) { + if (!Arrays.equals(request.getPlan().getPlan().toByteArray(), FlightSqlScenario.SUBSTRAIT_PLAN_TEXT)) { + listener.onError(CallStatus.INVALID_ARGUMENT + .withDescription(String.format("Expected plan '%s', not '%s'", + Arrays.toString(FlightSqlScenario.SUBSTRAIT_PLAN_TEXT), + Arrays.toString(request.getPlan().getPlan().toByteArray()))) + .toRuntimeException()); + return; + } + if (!FlightSqlScenario.SUBSTRAIT_VERSION.equals(request.getPlan().getVersion())) { + listener.onError(CallStatus.INVALID_ARGUMENT + .withDescription(String.format("Expected version '%s', not '%s'", + FlightSqlScenario.SUBSTRAIT_VERSION, + request.getPlan().getVersion())) + .toRuntimeException()); + return; + } + final String handle = request.getTransactionId().isEmpty() ? + "PREPARED PLAN HANDLE" : "PREPARED PLAN WITH TXN HANDLE"; + final FlightSql.ActionCreatePreparedStatementResult result = + FlightSql.ActionCreatePreparedStatementResult.newBuilder() + .setPreparedStatementHandle(ByteString.copyFromUtf8(handle)) + .build(); + listener.onNext(result); listener.onCompleted(); } @Override public void closePreparedStatement(FlightSql.ActionClosePreparedStatementRequest request, CallContext context, StreamListener listener) { - IntegrationAssertions.assertTrue("Expect to be one of the two queries used on tests", - request.getPreparedStatementHandle().toStringUtf8().equals("SELECT PREPARED STATEMENT HANDLE") || - request.getPreparedStatementHandle().toStringUtf8().equals("UPDATE PREPARED STATEMENT HANDLE")); + final String handle = request.getPreparedStatementHandle().toStringUtf8(); + IntegrationAssertions.assertTrue("Expect to be one of the queries used on tests", + handle.equals("SELECT PREPARED STATEMENT HANDLE") || + handle.equals("SELECT PREPARED STATEMENT WITH TXN HANDLE") || + handle.equals("UPDATE PREPARED STATEMENT HANDLE") || + handle.equals("UPDATE PREPARED STATEMENT WITH TXN HANDLE") || + handle.equals("PREPARED PLAN HANDLE") || + handle.equals("PREPARED PLAN WITH TXN HANDLE")); + listener.onCompleted(); + } + @Override + public void endSavepoint(FlightSql.ActionEndSavepointRequest request, CallContext context, + StreamListener listener) { + switch (request.getAction()) { + case END_SAVEPOINT_RELEASE: + case END_SAVEPOINT_ROLLBACK: + if (!Arrays.equals(request.getSavepointId().toByteArray(), FlightSqlScenario.SAVEPOINT_ID)) { + listener.onError(CallStatus.INVALID_ARGUMENT + .withDescription("Unexpected ID: " + Arrays.toString(request.getSavepointId().toByteArray())) + .toRuntimeException()); + } + break; + case UNRECOGNIZED: + default: { + listener.onError(CallStatus.INVALID_ARGUMENT + .withDescription("Unknown action: " + request.getAction()) + .toRuntimeException()); + return; + } + } + listener.onCompleted(); + } + + @Override + public void endTransaction(FlightSql.ActionEndTransactionRequest request, CallContext context, + StreamListener listener) { + switch (request.getAction()) { + case END_TRANSACTION_COMMIT: + case END_TRANSACTION_ROLLBACK: + if (!Arrays.equals(request.getTransactionId().toByteArray(), FlightSqlScenario.TRANSACTION_ID)) { + listener.onError(CallStatus.INVALID_ARGUMENT + .withDescription("Unexpected ID: " + Arrays.toString(request.getTransactionId().toByteArray())) + .toRuntimeException()); + } + break; + case UNRECOGNIZED: + default: { + listener.onError(CallStatus.INVALID_ARGUMENT + .withDescription("Unknown action: " + request.getAction()) + .toRuntimeException()); + return; + } + } listener.onCompleted(); } @@ -106,11 +283,31 @@ public void closePreparedStatement(FlightSql.ActionClosePreparedStatementRequest public FlightInfo getFlightInfoStatement(FlightSql.CommandStatementQuery command, CallContext context, FlightDescriptor descriptor) { IntegrationAssertions.assertEquals(command.getQuery(), "SELECT STATEMENT"); + if (command.getTransactionId().isEmpty()) { + String handle = "SELECT STATEMENT HANDLE"; + FlightSql.TicketStatementQuery ticket = FlightSql.TicketStatementQuery.newBuilder() + .setStatementHandle(ByteString.copyFromUtf8(handle)) + .build(); + return getFlightInfoForSchema(ticket, descriptor, getQuerySchema()); + } else { + String handle = "SELECT STATEMENT WITH TXN HANDLE"; + FlightSql.TicketStatementQuery ticket = FlightSql.TicketStatementQuery.newBuilder() + .setStatementHandle(ByteString.copyFromUtf8(handle)) + .build(); + return getFlightInfoForSchema(ticket, descriptor, getQueryWithTransactionSchema()); + } + } - ByteString handle = ByteString.copyFromUtf8("SELECT STATEMENT HANDLE"); - + @Override + public FlightInfo getFlightInfoSubstraitPlan(FlightSql.CommandStatementSubstraitPlan command, CallContext context, + FlightDescriptor descriptor) { + IntegrationAssertions.assertEquals(command.getPlan().getPlan().toByteArray(), + FlightSqlScenario.SUBSTRAIT_PLAN_TEXT); + IntegrationAssertions.assertEquals(command.getPlan().getVersion(), FlightSqlScenario.SUBSTRAIT_VERSION); + String handle = command.getTransactionId().isEmpty() ? + "PLAN HANDLE" : "PLAN WITH TXN HANDLE"; FlightSql.TicketStatementQuery ticket = FlightSql.TicketStatementQuery.newBuilder() - .setStatementHandle(handle) + .setStatementHandle(ByteString.copyFromUtf8(handle)) .build(); return getFlightInfoForSchema(ticket, descriptor, getQuerySchema()); } @@ -119,37 +316,91 @@ public FlightInfo getFlightInfoStatement(FlightSql.CommandStatementQuery command public FlightInfo getFlightInfoPreparedStatement(FlightSql.CommandPreparedStatementQuery command, CallContext context, FlightDescriptor descriptor) { - IntegrationAssertions.assertEquals(command.getPreparedStatementHandle().toStringUtf8(), - "SELECT PREPARED STATEMENT HANDLE"); + String handle = command.getPreparedStatementHandle().toStringUtf8(); + if (handle.equals("SELECT PREPARED STATEMENT HANDLE") || + handle.equals("PREPARED PLAN HANDLE")) { + return getFlightInfoForSchema(command, descriptor, getQuerySchema()); + } else if (handle.equals("SELECT PREPARED STATEMENT WITH TXN HANDLE") || + handle.equals("PREPARED PLAN WITH TXN HANDLE")) { + return getFlightInfoForSchema(command, descriptor, getQueryWithTransactionSchema()); + } + throw CallStatus.INVALID_ARGUMENT.withDescription("Unknown handle: " + handle).toRuntimeException(); + } - return getFlightInfoForSchema(command, descriptor, getQuerySchema()); + @Override + public SchemaResult getSchemaStatement(FlightSql.CommandStatementQuery command, + CallContext context, FlightDescriptor descriptor) { + IntegrationAssertions.assertEquals(command.getQuery(), "SELECT STATEMENT"); + if (command.getTransactionId().isEmpty()) { + return new SchemaResult(getQuerySchema()); + } + return new SchemaResult(getQueryWithTransactionSchema()); } @Override public SchemaResult getSchemaPreparedStatement(FlightSql.CommandPreparedStatementQuery command, CallContext context, FlightDescriptor descriptor) { - IntegrationAssertions.assertEquals(command.getPreparedStatementHandle().toStringUtf8(), - "SELECT PREPARED STATEMENT HANDLE"); - return new SchemaResult(getQuerySchema()); + String handle = command.getPreparedStatementHandle().toStringUtf8(); + if (handle.equals("SELECT PREPARED STATEMENT HANDLE") || + handle.equals("PREPARED PLAN HANDLE")) { + return new SchemaResult(getQuerySchema()); + } else if (handle.equals("SELECT PREPARED STATEMENT WITH TXN HANDLE") || + handle.equals("PREPARED PLAN WITH TXN HANDLE")) { + return new SchemaResult(getQueryWithTransactionSchema()); + } + throw CallStatus.INVALID_ARGUMENT.withDescription("Unknown handle: " + handle).toRuntimeException(); } @Override - public SchemaResult getSchemaStatement(FlightSql.CommandStatementQuery command, - CallContext context, FlightDescriptor descriptor) { - IntegrationAssertions.assertEquals(command.getQuery(), "SELECT STATEMENT"); - return new SchemaResult(getQuerySchema()); + public SchemaResult getSchemaSubstraitPlan(FlightSql.CommandStatementSubstraitPlan command, CallContext context, + FlightDescriptor descriptor) { + if (!Arrays.equals(command.getPlan().getPlan().toByteArray(), FlightSqlScenario.SUBSTRAIT_PLAN_TEXT)) { + throw CallStatus.INVALID_ARGUMENT + .withDescription(String.format("Expected plan '%s', not '%s'", + Arrays.toString(FlightSqlScenario.SUBSTRAIT_PLAN_TEXT), + Arrays.toString(command.getPlan().getPlan().toByteArray()))) + .toRuntimeException(); + } + if (!FlightSqlScenario.SUBSTRAIT_VERSION.equals(command.getPlan().getVersion())) { + throw CallStatus.INVALID_ARGUMENT + .withDescription(String.format("Expected version '%s', not '%s'", + FlightSqlScenario.SUBSTRAIT_VERSION, + command.getPlan().getVersion())) + .toRuntimeException(); + } + if (command.getTransactionId().isEmpty()) { + return new SchemaResult(getQuerySchema()); + } + return new SchemaResult(getQueryWithTransactionSchema()); } @Override public void getStreamStatement(FlightSql.TicketStatementQuery ticket, CallContext context, ServerStreamListener listener) { - putEmptyBatchToStreamListener(listener, getQuerySchema()); + final String handle = ticket.getStatementHandle().toStringUtf8(); + if (handle.equals("SELECT STATEMENT HANDLE") || handle.equals("PLAN HANDLE")) { + putEmptyBatchToStreamListener(listener, getQuerySchema()); + } else if (handle.equals("SELECT STATEMENT WITH TXN HANDLE") || handle.equals("PLAN WITH TXN HANDLE")) { + putEmptyBatchToStreamListener(listener, getQueryWithTransactionSchema()); + } else { + listener.error(CallStatus.INVALID_ARGUMENT.withDescription("Unknown handle: " + handle).toRuntimeException()); + } } @Override public void getStreamPreparedStatement(FlightSql.CommandPreparedStatementQuery command, CallContext context, ServerStreamListener listener) { - putEmptyBatchToStreamListener(listener, getQuerySchema()); + String handle = command.getPreparedStatementHandle().toStringUtf8(); + if (handle.equals("SELECT PREPARED STATEMENT HANDLE") || handle.equals("PREPARED PLAN HANDLE")) { + putEmptyBatchToStreamListener(listener, getQuerySchema()); + } else if (handle.equals("SELECT PREPARED STATEMENT WITH TXN HANDLE") || + handle.equals("PREPARED PLAN WITH TXN HANDLE")) { + putEmptyBatchToStreamListener(listener, getQueryWithTransactionSchema()); + } else { + listener.error(CallStatus.INVALID_ARGUMENT + .withDescription("Unknown handle: " + handle) + .toRuntimeException()); + } } private Runnable acceptPutReturnConstant(StreamListener ackStream, long value) { @@ -170,48 +421,92 @@ public Runnable acceptPutStatement(FlightSql.CommandStatementUpdate command, Cal FlightStream flightStream, StreamListener ackStream) { IntegrationAssertions.assertEquals(command.getQuery(), "UPDATE STATEMENT"); + return acceptPutReturnConstant(ackStream, + command.getTransactionId().isEmpty() ? FlightSqlScenario.UPDATE_STATEMENT_EXPECTED_ROWS : + FlightSqlScenario.UPDATE_STATEMENT_WITH_TRANSACTION_EXPECTED_ROWS); + } - return acceptPutReturnConstant(ackStream, FlightSqlScenario.UPDATE_STATEMENT_EXPECTED_ROWS); + @Override + public Runnable acceptPutSubstraitPlan(FlightSql.CommandStatementSubstraitPlan command, CallContext context, + FlightStream flightStream, StreamListener ackStream) { + IntegrationAssertions.assertEquals(command.getPlan().getPlan().toByteArray(), + FlightSqlScenario.SUBSTRAIT_PLAN_TEXT); + IntegrationAssertions.assertEquals(command.getPlan().getVersion(), FlightSqlScenario.SUBSTRAIT_VERSION); + return acceptPutReturnConstant(ackStream, + command.getTransactionId().isEmpty() ? FlightSqlScenario.UPDATE_STATEMENT_EXPECTED_ROWS : + FlightSqlScenario.UPDATE_STATEMENT_WITH_TRANSACTION_EXPECTED_ROWS); } @Override public Runnable acceptPutPreparedStatementUpdate(FlightSql.CommandPreparedStatementUpdate command, CallContext context, FlightStream flightStream, StreamListener ackStream) { - IntegrationAssertions.assertEquals(command.getPreparedStatementHandle().toStringUtf8(), - "UPDATE PREPARED STATEMENT HANDLE"); - - return acceptPutReturnConstant(ackStream, FlightSqlScenario.UPDATE_PREPARED_STATEMENT_EXPECTED_ROWS); + final String handle = command.getPreparedStatementHandle().toStringUtf8(); + if (handle.equals("UPDATE PREPARED STATEMENT HANDLE") || + handle.equals("PREPARED PLAN HANDLE")) { + return acceptPutReturnConstant(ackStream, FlightSqlScenario.UPDATE_PREPARED_STATEMENT_EXPECTED_ROWS); + } else if (handle.equals("UPDATE PREPARED STATEMENT WITH TXN HANDLE") || + handle.equals("PREPARED PLAN WITH TXN HANDLE")) { + return acceptPutReturnConstant( + ackStream, FlightSqlScenario.UPDATE_PREPARED_STATEMENT_WITH_TRANSACTION_EXPECTED_ROWS); + } + return () -> { + ackStream.onError(CallStatus.INVALID_ARGUMENT + .withDescription("Unknown handle: " + handle) + .toRuntimeException()); + }; } @Override public Runnable acceptPutPreparedStatementQuery(FlightSql.CommandPreparedStatementQuery command, CallContext context, FlightStream flightStream, StreamListener ackStream) { - IntegrationAssertions.assertEquals(command.getPreparedStatementHandle().toStringUtf8(), - "SELECT PREPARED STATEMENT HANDLE"); - - IntegrationAssertions.assertEquals(getQuerySchema(), flightStream.getSchema()); - - return ackStream::onCompleted; + final String handle = command.getPreparedStatementHandle().toStringUtf8(); + if (handle.equals("SELECT PREPARED STATEMENT HANDLE") || + handle.equals("SELECT PREPARED STATEMENT WITH TXN HANDLE") || + handle.equals("PREPARED PLAN HANDLE") || + handle.equals("PREPARED PLAN WITH TXN HANDLE")) { + IntegrationAssertions.assertEquals(getQuerySchema(), flightStream.getSchema()); + return ackStream::onCompleted; + } + return () -> { + ackStream.onError(CallStatus.INVALID_ARGUMENT + .withDescription("Unknown handle: " + handle) + .toRuntimeException()); + }; } @Override public FlightInfo getFlightInfoSqlInfo(FlightSql.CommandGetSqlInfo request, CallContext context, FlightDescriptor descriptor) { - IntegrationAssertions.assertEquals(request.getInfoCount(), 2); - IntegrationAssertions.assertEquals(request.getInfo(0), - FlightSql.SqlInfo.FLIGHT_SQL_SERVER_NAME_VALUE); - IntegrationAssertions.assertEquals(request.getInfo(1), - FlightSql.SqlInfo.FLIGHT_SQL_SERVER_READ_ONLY_VALUE); - + if (request.getInfoCount() == 2) { + // Integration test for the protocol messages + IntegrationAssertions.assertEquals(request.getInfo(0), + FlightSql.SqlInfo.FLIGHT_SQL_SERVER_NAME_VALUE); + IntegrationAssertions.assertEquals(request.getInfo(1), + FlightSql.SqlInfo.FLIGHT_SQL_SERVER_READ_ONLY_VALUE); + } return getFlightInfoForSchema(request, descriptor, Schemas.GET_SQL_INFO_SCHEMA); } @Override public void getStreamSqlInfo(FlightSql.CommandGetSqlInfo command, CallContext context, ServerStreamListener listener) { - putEmptyBatchToStreamListener(listener, Schemas.GET_SQL_INFO_SCHEMA); + if (command.getInfoCount() == 2) { + // Integration test for the protocol messages + putEmptyBatchToStreamListener(listener, Schemas.GET_SQL_INFO_SCHEMA); + return; + } + SqlInfoBuilder sqlInfoBuilder = new SqlInfoBuilder() + .withFlightSqlServerSql(false) + .withFlightSqlServerSubstrait(true) + .withFlightSqlServerSubstraitMinVersion("min_version") + .withFlightSqlServerSubstraitMaxVersion("max_version") + .withFlightSqlServerTransaction(FlightSql.SqlSupportedTransaction.SQL_SUPPORTED_TRANSACTION_SAVEPOINT) + .withFlightSqlServerCancel(true) + .withFlightSqlServerStatementTimeout(42) + .withFlightSqlServerTransactionTimeout(7); + sqlInfoBuilder.send(command.getInfoList(), listener); } @Override @@ -373,8 +668,8 @@ public void listFlights(CallContext context, Criteria criteria, private FlightInfo getFlightInfoForSchema(final T request, final FlightDescriptor descriptor, final Schema schema) { - final Ticket ticket = new Ticket(pack(request).toByteArray()); - final List endpoints = singletonList(new FlightEndpoint(ticket)); + final Ticket ticket = new Ticket(Any.pack(request).toByteArray()); + final List endpoints = Collections.singletonList(new FlightEndpoint(ticket)); return new FlightInfo(schema, descriptor, endpoints, -1, -1); } diff --git a/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/IntegrationAssertions.java b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/IntegrationAssertions.java index 76f846a8b73d5..a60efcbb78dcb 100644 --- a/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/IntegrationAssertions.java +++ b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/IntegrationAssertions.java @@ -19,6 +19,7 @@ import java.io.PrintWriter; import java.io.StringWriter; +import java.util.Arrays; import java.util.Objects; import org.apache.arrow.flight.CallStatus; @@ -59,6 +60,16 @@ static void assertEquals(Object expected, Object actual) { } } + /** + * Assert that the two arrays are equal. + */ + static void assertEquals(byte[] expected, byte[] actual) { + if (!Arrays.equals(expected, actual)) { + throw new AssertionError( + String.format("Expected:\n%s\nbut got:\n%s", Arrays.toString(expected), Arrays.toString(actual))); + } + } + /** * Assert that the value is false, using the given message as an error otherwise. */ diff --git a/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/Scenarios.java b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/Scenarios.java index 16cc856daf567..77f7ab0006db4 100644 --- a/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/Scenarios.java +++ b/java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/Scenarios.java @@ -42,6 +42,7 @@ private Scenarios() { scenarios.put("auth:basic_proto", AuthBasicProtoScenario::new); scenarios.put("middleware", MiddlewareScenario::new); scenarios.put("flight_sql", FlightSqlScenario::new); + scenarios.put("flight_sql:extension", FlightSqlExtensionScenario::new); } private static Scenarios getInstance() { diff --git a/java/flight/flight-integration-tests/src/test/java/org/apache/arrow/flight/integration/tests/IntegrationTest.java b/java/flight/flight-integration-tests/src/test/java/org/apache/arrow/flight/integration/tests/IntegrationTest.java new file mode 100644 index 0000000000000..0751e1d7a8907 --- /dev/null +++ b/java/flight/flight-integration-tests/src/test/java/org/apache/arrow/flight/integration/tests/IntegrationTest.java @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight.integration.tests; + +import org.apache.arrow.flight.FlightClient; +import org.apache.arrow.flight.FlightServer; +import org.apache.arrow.flight.Location; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.junit.jupiter.api.Test; + +/** + * Run the integration test scenarios in-process. + */ +class IntegrationTest { + @Test + void authBasicProto() throws Exception { + testScenario("auth:basic_proto"); + } + + @Test + void middleware() throws Exception { + testScenario("middleware"); + } + + @Test + void flightSql() throws Exception { + testScenario("flight_sql"); + } + + @Test + void flightSqlExtension() throws Exception { + testScenario("flight_sql:extension"); + } + + void testScenario(String scenarioName) throws Exception { + try (final BufferAllocator allocator = new RootAllocator()) { + final FlightServer.Builder builder = FlightServer.builder() + .allocator(allocator) + .location(Location.forGrpcInsecure("0.0.0.0", 0)); + final Scenario scenario = Scenarios.getScenario(scenarioName); + scenario.buildServer(builder); + builder.producer(scenario.producer(allocator, Location.forGrpcInsecure("0.0.0.0", 0))); + + try (final FlightServer server = builder.build()) { + server.start(); + + final Location location = Location.forGrpcInsecure("localhost", server.getPort()); + try (final FlightClient client = FlightClient.builder(allocator, location).build()) { + scenario.client(allocator, location, client); + } + } + } + } +} diff --git a/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/CancelListener.java b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/CancelListener.java new file mode 100644 index 0000000000000..3438f788dcf6e --- /dev/null +++ b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/CancelListener.java @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight.sql; + +import org.apache.arrow.flight.FlightProducer; +import org.apache.arrow.flight.Result; +import org.apache.arrow.flight.sql.impl.FlightSql; + +import com.google.protobuf.Any; + +/** Typed StreamListener for cancelQuery. */ +class CancelListener implements FlightProducer.StreamListener { + private final FlightProducer.StreamListener listener; + + CancelListener(FlightProducer.StreamListener listener) { + this.listener = listener; + } + + @Override + public void onNext(CancelResult val) { + FlightSql.ActionCancelQueryResult result = FlightSql.ActionCancelQueryResult.newBuilder() + .setResult(val.toProtocol()) + .build(); + listener.onNext(new Result(Any.pack(result).toByteArray())); + } + + @Override + public void onError(Throwable t) { + listener.onError(t); + } + + @Override + public void onCompleted() { + listener.onCompleted(); + } +} diff --git a/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/CancelResult.java b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/CancelResult.java new file mode 100644 index 0000000000000..d1ae417831035 --- /dev/null +++ b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/CancelResult.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight.sql; + +import org.apache.arrow.flight.sql.impl.FlightSql; + +/** + * The result of cancelling a query. + */ +public enum CancelResult { + UNSPECIFIED, + CANCELLED, + CANCELLING, + NOT_CANCELLABLE, + ; + + FlightSql.ActionCancelQueryResult.CancelResult toProtocol() { + switch (this) { + default: + case UNSPECIFIED: + return FlightSql.ActionCancelQueryResult.CancelResult.CANCEL_RESULT_UNSPECIFIED; + case CANCELLED: + return FlightSql.ActionCancelQueryResult.CancelResult.CANCEL_RESULT_CANCELLED; + case CANCELLING: + return FlightSql.ActionCancelQueryResult.CancelResult.CANCEL_RESULT_CANCELLING; + case NOT_CANCELLABLE: + return FlightSql.ActionCancelQueryResult.CancelResult.CANCEL_RESULT_NOT_CANCELLABLE; + } + } +} diff --git a/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlClient.java b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlClient.java index f1f07a1588f57..922495a18e0c9 100644 --- a/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlClient.java +++ b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlClient.java @@ -17,8 +17,17 @@ package org.apache.arrow.flight.sql; +import static org.apache.arrow.flight.sql.impl.FlightSql.ActionBeginSavepointRequest; +import static org.apache.arrow.flight.sql.impl.FlightSql.ActionBeginSavepointResult; +import static org.apache.arrow.flight.sql.impl.FlightSql.ActionBeginTransactionRequest; +import static org.apache.arrow.flight.sql.impl.FlightSql.ActionBeginTransactionResult; +import static org.apache.arrow.flight.sql.impl.FlightSql.ActionCancelQueryRequest; +import static org.apache.arrow.flight.sql.impl.FlightSql.ActionCancelQueryResult; import static org.apache.arrow.flight.sql.impl.FlightSql.ActionClosePreparedStatementRequest; import static org.apache.arrow.flight.sql.impl.FlightSql.ActionCreatePreparedStatementRequest; +import static org.apache.arrow.flight.sql.impl.FlightSql.ActionCreatePreparedSubstraitPlanRequest; +import static org.apache.arrow.flight.sql.impl.FlightSql.ActionEndSavepointRequest; +import static org.apache.arrow.flight.sql.impl.FlightSql.ActionEndTransactionRequest; import static org.apache.arrow.flight.sql.impl.FlightSql.CommandGetCatalogs; import static org.apache.arrow.flight.sql.impl.FlightSql.CommandGetCrossReference; import static org.apache.arrow.flight.sql.impl.FlightSql.CommandGetDbSchemas; @@ -31,6 +40,7 @@ import static org.apache.arrow.flight.sql.impl.FlightSql.CommandGetXdbcTypeInfo; import static org.apache.arrow.flight.sql.impl.FlightSql.CommandPreparedStatementUpdate; import static org.apache.arrow.flight.sql.impl.FlightSql.CommandStatementQuery; +import static org.apache.arrow.flight.sql.impl.FlightSql.CommandStatementSubstraitPlan; import static org.apache.arrow.flight.sql.impl.FlightSql.CommandStatementUpdate; import static org.apache.arrow.flight.sql.impl.FlightSql.DoPutUpdateResult; import static org.apache.arrow.flight.sql.impl.FlightSql.SqlInfo; @@ -91,8 +101,51 @@ public FlightSqlClient(final FlightClient client) { * @return a FlightInfo object representing the stream(s) to fetch. */ public FlightInfo execute(final String query, final CallOption... options) { - final CommandStatementQuery.Builder builder = CommandStatementQuery.newBuilder(); - builder.setQuery(query); + return execute(query, /*transaction*/ null, options); + } + + /** + * Execute a query on the server. + * + * @param query The query to execute. + * @param transaction The transaction that this query is part of. + * @param options RPC-layer hints for this call. + * @return a FlightInfo object representing the stream(s) to fetch. + */ + public FlightInfo execute(final String query, Transaction transaction, final CallOption... options) { + final CommandStatementQuery.Builder builder = CommandStatementQuery.newBuilder().setQuery(query); + if (transaction != null) { + builder.setTransactionId(ByteString.copyFrom(transaction.getTransactionId())); + } + final FlightDescriptor descriptor = FlightDescriptor.command(Any.pack(builder.build()).toByteArray()); + return client.getInfo(descriptor, options); + } + + /** + * Execute a Substrait plan on the server. + * + * @param plan The Substrait plan to execute. + * @param options RPC-layer hints for this call. + * @return a FlightInfo object representing the stream(s) to fetch. + */ + public FlightInfo executeSubstrait(SubstraitPlan plan, CallOption... options) { + return executeSubstrait(plan, /*transaction*/ null, options); + } + + /** + * Execute a Substrait plan on the server. + * + * @param plan The Substrait plan to execute. + * @param transaction The transaction that this query is part of. + * @param options RPC-layer hints for this call. + * @return a FlightInfo object representing the stream(s) to fetch. + */ + public FlightInfo executeSubstrait(SubstraitPlan plan, Transaction transaction, CallOption... options) { + final CommandStatementSubstraitPlan.Builder builder = CommandStatementSubstraitPlan.newBuilder(); + builder.getPlanBuilder().setPlan(ByteString.copyFrom(plan.getPlan())).setVersion(plan.getVersion()); + if (transaction != null) { + builder.setTransactionId(ByteString.copyFrom(transaction.getTransactionId())); + } final FlightDescriptor descriptor = FlightDescriptor.command(Any.pack(builder.build()).toByteArray()); return client.getInfo(descriptor, options); } @@ -100,13 +153,44 @@ public FlightInfo execute(final String query, final CallOption... options) { /** * Get the schema of the result set of a query. */ - public SchemaResult getExecuteSchema(final String query, final CallOption... options) { + public SchemaResult getExecuteSchema(String query, Transaction transaction, CallOption... options) { final CommandStatementQuery.Builder builder = CommandStatementQuery.newBuilder(); builder.setQuery(query); + if (transaction != null) { + builder.setTransactionId(ByteString.copyFrom(transaction.getTransactionId())); + } final FlightDescriptor descriptor = FlightDescriptor.command(Any.pack(builder.build()).toByteArray()); return client.getSchema(descriptor, options); } + /** + * Get the schema of the result set of a query. + */ + public SchemaResult getExecuteSchema(String query, CallOption... options) { + return getExecuteSchema(query, /*transaction*/null, options); + } + + /** + * Get the schema of the result set of a Substrait plan. + */ + public SchemaResult getExecuteSubstraitSchema(SubstraitPlan plan, Transaction transaction, + final CallOption... options) { + final CommandStatementSubstraitPlan.Builder builder = CommandStatementSubstraitPlan.newBuilder(); + builder.getPlanBuilder().setPlan(ByteString.copyFrom(plan.getPlan())).setVersion(plan.getVersion()); + if (transaction != null) { + builder.setTransactionId(ByteString.copyFrom(transaction.getTransactionId())); + } + final FlightDescriptor descriptor = FlightDescriptor.command(Any.pack(builder.build()).toByteArray()); + return client.getSchema(descriptor, options); + } + + /** + * Get the schema of the result set of a Substrait plan. + */ + public SchemaResult getExecuteSubstraitSchema(SubstraitPlan substraitPlan, final CallOption... options) { + return getExecuteSubstraitSchema(substraitPlan, /*transaction*/null, options); + } + /** * Execute an update query on the server. * @@ -115,18 +199,77 @@ public SchemaResult getExecuteSchema(final String query, final CallOption... opt * @return the number of rows affected. */ public long executeUpdate(final String query, final CallOption... options) { - final CommandStatementUpdate.Builder builder = CommandStatementUpdate.newBuilder(); - builder.setQuery(query); + return executeUpdate(query, /*transaction*/ null, options); + } + + /** + * Execute an update query on the server. + * + * @param query The query to execute. + * @param transaction The transaction that this query is part of. + * @param options RPC-layer hints for this call. + * @return the number of rows affected. + */ + public long executeUpdate(final String query, Transaction transaction, final CallOption... options) { + final CommandStatementUpdate.Builder builder = CommandStatementUpdate.newBuilder().setQuery(query); + if (transaction != null) { + builder.setTransactionId(ByteString.copyFrom(transaction.getTransactionId())); + } final FlightDescriptor descriptor = FlightDescriptor.command(Any.pack(builder.build()).toByteArray()); - final SyncPutListener putListener = new SyncPutListener(); - client.startPut(descriptor, VectorSchemaRoot.of(), putListener, options); + try (final SyncPutListener putListener = new SyncPutListener()) { + final FlightClient.ClientStreamListener listener = + client.startPut(descriptor, VectorSchemaRoot.of(), putListener, options); + try (final PutResult result = putListener.read()) { + final DoPutUpdateResult doPutUpdateResult = DoPutUpdateResult.parseFrom( + result.getApplicationMetadata().nioBuffer()); + return doPutUpdateResult.getRecordCount(); + } finally { + listener.getResult(); + } + } catch (final InterruptedException | ExecutionException e) { + throw CallStatus.CANCELLED.withCause(e).toRuntimeException(); + } catch (final InvalidProtocolBufferException e) { + throw CallStatus.INTERNAL.withCause(e).toRuntimeException(); + } + } - try { - final PutResult read = putListener.read(); - try (final ArrowBuf metadata = read.getApplicationMetadata()) { - final DoPutUpdateResult doPutUpdateResult = DoPutUpdateResult.parseFrom(metadata.nioBuffer()); + /** + * Execute an update query on the server. + * + * @param plan The Substrait plan to execute. + * @param options RPC-layer hints for this call. + * @return the number of rows affected. + */ + public long executeSubstraitUpdate(SubstraitPlan plan, CallOption... options) { + return executeSubstraitUpdate(plan, /*transaction*/ null, options); + } + + /** + * Execute an update query on the server. + * + * @param plan The Substrait plan to execute. + * @param transaction The transaction that this query is part of. + * @param options RPC-layer hints for this call. + * @return the number of rows affected. + */ + public long executeSubstraitUpdate(SubstraitPlan plan, Transaction transaction, CallOption... options) { + final CommandStatementSubstraitPlan.Builder builder = CommandStatementSubstraitPlan.newBuilder(); + builder.getPlanBuilder().setPlan(ByteString.copyFrom(plan.getPlan())).setVersion(plan.getVersion()); + if (transaction != null) { + builder.setTransactionId(ByteString.copyFrom(transaction.getTransactionId())); + } + + final FlightDescriptor descriptor = FlightDescriptor.command(Any.pack(builder.build()).toByteArray()); + try (final SyncPutListener putListener = new SyncPutListener()) { + final FlightClient.ClientStreamListener listener = + client.startPut(descriptor, VectorSchemaRoot.of(), putListener, options); + try (final PutResult result = putListener.read()) { + final DoPutUpdateResult doPutUpdateResult = DoPutUpdateResult.parseFrom( + result.getApplicationMetadata().nioBuffer()); return doPutUpdateResult.getRecordCount(); + } finally { + listener.getResult(); } } catch (final InterruptedException | ExecutionException e) { throw CallStatus.CANCELLED.withCause(e).toRuntimeException(); @@ -551,14 +694,198 @@ public SchemaResult getTableTypesSchema(final CallOption... options) { } /** - * Create a prepared statement on the server. + * Create a prepared statement for a SQL query on the server. * * @param query The query to prepare. * @param options RPC-layer hints for this call. * @return The representation of the prepared statement which exists on the server. */ - public PreparedStatement prepare(final String query, final CallOption... options) { - return new PreparedStatement(client, query, options); + public PreparedStatement prepare(String query, CallOption... options) { + return prepare(query, /*transaction*/ null, options); + } + + /** + * Create a prepared statement for a SQL query on the server. + * + * @param query The query to prepare. + * @param transaction The transaction that this query is part of. + * @param options RPC-layer hints for this call. + * @return The representation of the prepared statement which exists on the server. + */ + public PreparedStatement prepare(String query, Transaction transaction, CallOption... options) { + ActionCreatePreparedStatementRequest.Builder builder = + ActionCreatePreparedStatementRequest.newBuilder().setQuery(query); + if (transaction != null) { + builder.setTransactionId(ByteString.copyFrom(transaction.getTransactionId())); + } + return new PreparedStatement(client, + new Action( + FlightSqlUtils.FLIGHT_SQL_CREATE_PREPARED_STATEMENT.getType(), + Any.pack(builder.build()).toByteArray()), + options); + } + + /** + * Create a prepared statement for a Substrait plan on the server. + * + * @param plan The query to prepare. + * @param options RPC-layer hints for this call. + * @return The representation of the prepared statement which exists on the server. + */ + public PreparedStatement prepare(SubstraitPlan plan, CallOption... options) { + return prepare(plan, /*transaction*/ null, options); + } + + /** + * Create a prepared statement for a Substrait plan on the server. + * + * @param plan The query to prepare. + * @param transaction The transaction that this query is part of. + * @param options RPC-layer hints for this call. + * @return The representation of the prepared statement which exists on the server. + */ + public PreparedStatement prepare(SubstraitPlan plan, Transaction transaction, CallOption... options) { + ActionCreatePreparedSubstraitPlanRequest.Builder builder = + ActionCreatePreparedSubstraitPlanRequest.newBuilder(); + builder.getPlanBuilder().setPlan(ByteString.copyFrom(plan.getPlan())).setVersion(plan.getVersion()); + if (transaction != null) { + builder.setTransactionId(ByteString.copyFrom(transaction.getTransactionId())); + } + return new PreparedStatement(client, + new Action( + FlightSqlUtils.FLIGHT_SQL_CREATE_PREPARED_SUBSTRAIT_PLAN.getType(), + Any.pack(builder.build()).toByteArray()), + options); + } + + /** Begin a transaction. */ + public Transaction beginTransaction(CallOption... options) { + final Action action = new Action( + FlightSqlUtils.FLIGHT_SQL_BEGIN_TRANSACTION.getType(), + Any.pack(ActionBeginTransactionRequest.getDefaultInstance()).toByteArray()); + final Iterator preparedStatementResults = client.doAction(action, options); + final ActionBeginTransactionResult result = FlightSqlUtils.unpackAndParseOrThrow( + preparedStatementResults.next().getBody(), + ActionBeginTransactionResult.class); + preparedStatementResults.forEachRemaining((ignored) -> { }); + if (result.getTransactionId().isEmpty()) { + throw CallStatus.INTERNAL.withDescription("Server returned an empty transaction ID").toRuntimeException(); + } + return new Transaction(result.getTransactionId().toByteArray()); + } + + /** Create a savepoint within a transaction. */ + public Savepoint beginSavepoint(Transaction transaction, String name, CallOption... options) { + Preconditions.checkArgument(transaction.getTransactionId().length != 0, "Transaction must be initialized"); + ActionBeginSavepointRequest request = ActionBeginSavepointRequest.newBuilder() + .setTransactionId(ByteString.copyFrom(transaction.getTransactionId())) + .setName(name) + .build(); + final Action action = new Action( + FlightSqlUtils.FLIGHT_SQL_BEGIN_SAVEPOINT.getType(), + Any.pack(request).toByteArray()); + final Iterator preparedStatementResults = client.doAction(action, options); + final ActionBeginSavepointResult result = FlightSqlUtils.unpackAndParseOrThrow( + preparedStatementResults.next().getBody(), + ActionBeginSavepointResult.class); + preparedStatementResults.forEachRemaining((ignored) -> { }); + if (result.getSavepointId().isEmpty()) { + throw CallStatus.INTERNAL.withDescription("Server returned an empty transaction ID").toRuntimeException(); + } + return new Savepoint(result.getSavepointId().toByteArray()); + } + + /** Commit a transaction. */ + public void commit(Transaction transaction, CallOption... options) { + Preconditions.checkArgument(transaction.getTransactionId().length != 0, "Transaction must be initialized"); + ActionEndTransactionRequest request = ActionEndTransactionRequest.newBuilder() + .setTransactionId(ByteString.copyFrom(transaction.getTransactionId())) + .setActionValue(ActionEndTransactionRequest.EndTransaction.END_TRANSACTION_COMMIT.getNumber()) + .build(); + final Action action = new Action( + FlightSqlUtils.FLIGHT_SQL_END_TRANSACTION.getType(), + Any.pack(request).toByteArray()); + final Iterator preparedStatementResults = client.doAction(action, options); + preparedStatementResults.forEachRemaining((ignored) -> { }); + } + + /** Release a savepoint. */ + public void release(Savepoint savepoint, CallOption... options) { + Preconditions.checkArgument(savepoint.getSavepointId().length != 0, "Savepoint must be initialized"); + ActionEndSavepointRequest request = ActionEndSavepointRequest.newBuilder() + .setSavepointId(ByteString.copyFrom(savepoint.getSavepointId())) + .setActionValue(ActionEndSavepointRequest.EndSavepoint.END_SAVEPOINT_RELEASE.getNumber()) + .build(); + final Action action = new Action( + FlightSqlUtils.FLIGHT_SQL_END_SAVEPOINT.getType(), + Any.pack(request).toByteArray()); + final Iterator preparedStatementResults = client.doAction(action, options); + preparedStatementResults.forEachRemaining((ignored) -> { }); + } + + /** Rollback a transaction. */ + public void rollback(Transaction transaction, CallOption... options) { + Preconditions.checkArgument(transaction.getTransactionId().length != 0, "Transaction must be initialized"); + ActionEndTransactionRequest request = ActionEndTransactionRequest.newBuilder() + .setTransactionId(ByteString.copyFrom(transaction.getTransactionId())) + .setActionValue(ActionEndTransactionRequest.EndTransaction.END_TRANSACTION_ROLLBACK.getNumber()) + .build(); + final Action action = new Action( + FlightSqlUtils.FLIGHT_SQL_END_TRANSACTION.getType(), + Any.pack(request).toByteArray()); + final Iterator preparedStatementResults = client.doAction(action, options); + preparedStatementResults.forEachRemaining((ignored) -> { }); + } + + /** Rollback to a savepoint. */ + public void rollback(Savepoint savepoint, CallOption... options) { + Preconditions.checkArgument(savepoint.getSavepointId().length != 0, "Savepoint must be initialized"); + ActionEndSavepointRequest request = ActionEndSavepointRequest.newBuilder() + .setSavepointId(ByteString.copyFrom(savepoint.getSavepointId())) + .setActionValue(ActionEndSavepointRequest.EndSavepoint.END_SAVEPOINT_RELEASE.getNumber()) + .build(); + final Action action = new Action( + FlightSqlUtils.FLIGHT_SQL_END_SAVEPOINT.getType(), + Any.pack(request).toByteArray()); + final Iterator preparedStatementResults = client.doAction(action, options); + preparedStatementResults.forEachRemaining((ignored) -> { }); + } + + /** + * Explicitly cancel a running query. + *

+ * This lets a single client explicitly cancel work, no matter how many clients + * are involved/whether the query is distributed or not, given server support. + * The transaction/statement is not rolled back; it is the application's job to + * commit or rollback as appropriate. This only indicates the client no longer + * wishes to read the remainder of the query results or continue submitting + * data. + */ + public CancelResult cancelQuery(FlightInfo info, CallOption... options) { + ActionCancelQueryRequest request = ActionCancelQueryRequest.newBuilder() + .setInfo(ByteString.copyFrom(info.serialize())) + .build(); + final Action action = new Action( + FlightSqlUtils.FLIGHT_SQL_CANCEL_QUERY.getType(), + Any.pack(request).toByteArray()); + final Iterator preparedStatementResults = client.doAction(action, options); + final ActionCancelQueryResult result = FlightSqlUtils.unpackAndParseOrThrow( + preparedStatementResults.next().getBody(), + ActionCancelQueryResult.class); + preparedStatementResults.forEachRemaining((ignored) -> { }); + switch (result.getResult()) { + case CANCEL_RESULT_UNSPECIFIED: + return CancelResult.UNSPECIFIED; + case CANCEL_RESULT_CANCELLED: + return CancelResult.CANCELLED; + case CANCEL_RESULT_CANCELLING: + return CancelResult.CANCELLING; + case CANCEL_RESULT_NOT_CANCELLABLE: + return CancelResult.NOT_CANCELLABLE; + case UNRECOGNIZED: + default: + throw CallStatus.INTERNAL.withDescription("Unknown result: " + result.getResult()).toRuntimeException(); + } } @Override @@ -577,28 +904,13 @@ public static class PreparedStatement implements AutoCloseable { private Schema resultSetSchema; private Schema parameterSchema; - /** - * Constructor. - * - * @param client The client. PreparedStatement does not maintain this resource. - * @param sql The query. - * @param options RPC-layer hints for this call. - */ - public PreparedStatement(final FlightClient client, final String sql, final CallOption... options) { + PreparedStatement(FlightClient client, Action action, CallOption... options) { this.client = client; - final Action action = new Action( - FlightSqlUtils.FLIGHT_SQL_CREATE_PREPARED_STATEMENT.getType(), - Any.pack(ActionCreatePreparedStatementRequest - .newBuilder() - .setQuery(sql) - .build()) - .toByteArray()); - final Iterator preparedStatementResults = client.doAction(action, options); + final Iterator preparedStatementResults = client.doAction(action, options); preparedStatementResult = FlightSqlUtils.unpackAndParseOrThrow( preparedStatementResults.next().getBody(), ActionCreatePreparedStatementResult.class); - isClosed = false; } @@ -790,4 +1102,81 @@ public boolean isClosed() { return isClosed; } } + + /** A handle for an active savepoint. */ + public static class Savepoint { + private final byte[] transactionId; + + public Savepoint(byte[] transactionId) { + this.transactionId = transactionId; + } + + public byte[] getSavepointId() { + return transactionId; + } + } + + /** A handle for an active transaction. */ + public static class Transaction { + private final byte[] transactionId; + + public Transaction(byte[] transactionId) { + this.transactionId = transactionId; + } + + public byte[] getTransactionId() { + return transactionId; + } + } + + /** A wrapper around a Substrait plan and a Substrait version. */ + public static final class SubstraitPlan { + private final byte[] plan; + private final String version; + + public SubstraitPlan(byte[] plan, String version) { + this.plan = Preconditions.checkNotNull(plan); + this.version = Preconditions.checkNotNull(version); + } + + public byte[] getPlan() { + return plan; + } + + public String getVersion() { + return version; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + SubstraitPlan that = (SubstraitPlan) o; + + if (!Arrays.equals(getPlan(), that.getPlan())) { + return false; + } + return getVersion().equals(that.getVersion()); + } + + @Override + public int hashCode() { + int result = Arrays.hashCode(getPlan()); + result = 31 * result + getVersion().hashCode(); + return result; + } + + @Override + public String toString() { + return "SubstraitPlan{" + + "plan=" + Arrays.toString(plan) + + ", version='" + version + '\'' + + '}'; + } + } } diff --git a/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlProducer.java b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlProducer.java index 4226ec9e228cf..00a83667990c5 100644 --- a/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlProducer.java +++ b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlProducer.java @@ -20,12 +20,21 @@ import static java.util.Arrays.asList; import static java.util.Collections.singletonList; import static java.util.stream.IntStream.range; +import static org.apache.arrow.flight.sql.impl.FlightSql.ActionBeginSavepointRequest; +import static org.apache.arrow.flight.sql.impl.FlightSql.ActionBeginSavepointResult; +import static org.apache.arrow.flight.sql.impl.FlightSql.ActionBeginTransactionRequest; +import static org.apache.arrow.flight.sql.impl.FlightSql.ActionBeginTransactionResult; +import static org.apache.arrow.flight.sql.impl.FlightSql.ActionCancelQueryRequest; import static org.apache.arrow.flight.sql.impl.FlightSql.ActionCreatePreparedStatementResult; +import static org.apache.arrow.flight.sql.impl.FlightSql.ActionCreatePreparedSubstraitPlanRequest; +import static org.apache.arrow.flight.sql.impl.FlightSql.ActionEndSavepointRequest; +import static org.apache.arrow.flight.sql.impl.FlightSql.ActionEndTransactionRequest; import static org.apache.arrow.flight.sql.impl.FlightSql.CommandGetCrossReference; import static org.apache.arrow.flight.sql.impl.FlightSql.CommandGetDbSchemas; import static org.apache.arrow.flight.sql.impl.FlightSql.CommandGetExportedKeys; import static org.apache.arrow.flight.sql.impl.FlightSql.CommandGetImportedKeys; import static org.apache.arrow.flight.sql.impl.FlightSql.CommandGetXdbcTypeInfo; +import static org.apache.arrow.flight.sql.impl.FlightSql.CommandStatementSubstraitPlan; import static org.apache.arrow.vector.complex.MapVector.DATA_VECTOR_NAME; import static org.apache.arrow.vector.complex.MapVector.KEY_NAME; import static org.apache.arrow.vector.complex.MapVector.VALUE_NAME; @@ -37,6 +46,8 @@ import static org.apache.arrow.vector.types.Types.MinorType.UINT4; import static org.apache.arrow.vector.types.Types.MinorType.VARCHAR; +import java.io.IOException; +import java.net.URISyntaxException; import java.util.List; import org.apache.arrow.flight.Action; @@ -95,6 +106,9 @@ default FlightInfo getFlightInfo(CallContext context, FlightDescriptor descripto if (command.is(CommandStatementQuery.class)) { return getFlightInfoStatement( FlightSqlUtils.unpackOrThrow(command, CommandStatementQuery.class), context, descriptor); + } else if (command.is(CommandStatementSubstraitPlan.class)) { + return getFlightInfoSubstraitPlan( + FlightSqlUtils.unpackOrThrow(command, CommandStatementSubstraitPlan.class), context, descriptor); } else if (command.is(CommandPreparedStatementQuery.class)) { return getFlightInfoPreparedStatement( FlightSqlUtils.unpackOrThrow(command, CommandPreparedStatementQuery.class), context, descriptor); @@ -130,7 +144,9 @@ default FlightInfo getFlightInfo(CallContext context, FlightDescriptor descripto FlightSqlUtils.unpackOrThrow(command, CommandGetXdbcTypeInfo.class), context, descriptor); } - throw CallStatus.INVALID_ARGUMENT.withDescription("The defined request is invalid.").toRuntimeException(); + throw CallStatus.INVALID_ARGUMENT + .withDescription("Unrecognized request: " + command.getTypeUrl()) + .toRuntimeException(); } /** @@ -150,6 +166,9 @@ default SchemaResult getSchema(CallContext context, FlightDescriptor descriptor) } else if (command.is(CommandPreparedStatementQuery.class)) { return getSchemaPreparedStatement( FlightSqlUtils.unpackOrThrow(command, CommandPreparedStatementQuery.class), context, descriptor); + } else if (command.is(CommandStatementSubstraitPlan.class)) { + return getSchemaSubstraitPlan( + FlightSqlUtils.unpackOrThrow(command, CommandStatementSubstraitPlan.class), context, descriptor); } else if (command.is(CommandGetCatalogs.class)) { return new SchemaResult(Schemas.GET_CATALOGS_SCHEMA); } else if (command.is(CommandGetCrossReference.class)) { @@ -175,7 +194,9 @@ default SchemaResult getSchema(CallContext context, FlightDescriptor descriptor) return new SchemaResult(Schemas.GET_TYPE_INFO_SCHEMA); } - throw CallStatus.INVALID_ARGUMENT.withDescription("Invalid command provided.").toRuntimeException(); + throw CallStatus.INVALID_ARGUMENT + .withDescription("Unrecognized request: " + command.getTypeUrl()) + .toRuntimeException(); } /** @@ -249,6 +270,10 @@ default Runnable acceptPut(CallContext context, FlightStream flightStream, Strea return acceptPutStatement( FlightSqlUtils.unpackOrThrow(command, CommandStatementUpdate.class), context, flightStream, ackStream); + } else if (command.is(CommandStatementSubstraitPlan.class)) { + return acceptPutSubstraitPlan( + FlightSqlUtils.unpackOrThrow(command, CommandStatementSubstraitPlan.class), + context, flightStream, ackStream); } else if (command.is(CommandPreparedStatementUpdate.class)) { return acceptPutPreparedStatementUpdate( FlightSqlUtils.unpackOrThrow(command, CommandPreparedStatementUpdate.class), @@ -284,19 +309,91 @@ default void listActions(CallContext context, StreamListener listene @Override default void doAction(CallContext context, Action action, StreamListener listener) { final String actionType = action.getType(); - if (actionType.equals(FlightSqlUtils.FLIGHT_SQL_CREATE_PREPARED_STATEMENT.getType())) { + + if (actionType.equals(FlightSqlUtils.FLIGHT_SQL_BEGIN_SAVEPOINT.getType())) { + final ActionBeginSavepointRequest request = + FlightSqlUtils.unpackAndParseOrThrow(action.getBody(), ActionBeginSavepointRequest.class); + beginSavepoint(request, context, new ProtoListener<>(listener)); + } else if (actionType.equals(FlightSqlUtils.FLIGHT_SQL_BEGIN_TRANSACTION.getType())) { + final ActionBeginTransactionRequest request = + FlightSqlUtils.unpackAndParseOrThrow(action.getBody(), ActionBeginTransactionRequest.class); + beginTransaction(request, context, new ProtoListener<>(listener)); + } else if (actionType.equals(FlightSqlUtils.FLIGHT_SQL_CANCEL_QUERY.getType())) { + final ActionCancelQueryRequest request = + FlightSqlUtils.unpackAndParseOrThrow(action.getBody(), ActionCancelQueryRequest.class); + final FlightInfo info; + try { + info = FlightInfo.deserialize(request.getInfo().asReadOnlyByteBuffer()); + } catch (IOException | URISyntaxException e) { + listener.onError(CallStatus.INTERNAL + .withDescription("Could not unpack FlightInfo: " + e) + .withCause(e) + .toRuntimeException()); + return; + } + cancelQuery(info, context, new CancelListener(listener)); + } else if (actionType.equals(FlightSqlUtils.FLIGHT_SQL_CREATE_PREPARED_STATEMENT.getType())) { final ActionCreatePreparedStatementRequest request = FlightSqlUtils.unpackAndParseOrThrow(action.getBody(), ActionCreatePreparedStatementRequest.class); createPreparedStatement(request, context, listener); + } else if (actionType.equals(FlightSqlUtils.FLIGHT_SQL_CREATE_PREPARED_SUBSTRAIT_PLAN.getType())) { + final ActionCreatePreparedSubstraitPlanRequest request = + FlightSqlUtils.unpackAndParseOrThrow(action.getBody(), ActionCreatePreparedSubstraitPlanRequest.class); + createPreparedSubstraitPlan(request, context, new ProtoListener<>(listener)); } else if (actionType.equals(FlightSqlUtils.FLIGHT_SQL_CLOSE_PREPARED_STATEMENT.getType())) { - final ActionClosePreparedStatementRequest request = FlightSqlUtils.unpackAndParseOrThrow(action.getBody(), - ActionClosePreparedStatementRequest.class); - closePreparedStatement(request, context, listener); + final ActionClosePreparedStatementRequest request = + FlightSqlUtils.unpackAndParseOrThrow(action.getBody(), ActionClosePreparedStatementRequest.class); + closePreparedStatement(request, context, new NoResultListener(listener)); + } else if (actionType.equals(FlightSqlUtils.FLIGHT_SQL_END_SAVEPOINT.getType())) { + ActionEndSavepointRequest request = + FlightSqlUtils.unpackAndParseOrThrow(action.getBody(), ActionEndSavepointRequest.class); + endSavepoint(request, context, new NoResultListener(listener)); + } else if (actionType.equals(FlightSqlUtils.FLIGHT_SQL_END_TRANSACTION.getType())) { + ActionEndTransactionRequest request = + FlightSqlUtils.unpackAndParseOrThrow(action.getBody(), ActionEndTransactionRequest.class); + endTransaction(request, context, new NoResultListener(listener)); } else { - throw CallStatus.INVALID_ARGUMENT.withDescription("Invalid action provided.").toRuntimeException(); + throw CallStatus.INVALID_ARGUMENT + .withDescription("Unrecognized request: " + action.getType()) + .toRuntimeException(); } } + /** + * Create a savepoint within a transaction. + * + * @param request The savepoint request. + * @param context Per-call context. + * @param listener The newly created savepoint ID. + */ + default void beginSavepoint(ActionBeginSavepointRequest request, CallContext context, + StreamListener listener) { + listener.onError(CallStatus.UNIMPLEMENTED.toRuntimeException()); + } + + /** + * Begin a transaction. + * + * @param request The transaction request. + * @param context Per-call context. + * @param listener The newly created transaction ID. + */ + default void beginTransaction(ActionBeginTransactionRequest request, CallContext context, + StreamListener listener) { + listener.onError(CallStatus.UNIMPLEMENTED.toRuntimeException()); + } + + /** + * Explicitly cancel a query. + * + * @param info The FlightInfo of the query to cancel. + * @param context Per-call context. + * @param listener Whether cancellation succeeded. + */ + default void cancelQuery(FlightInfo info, CallContext context, StreamListener listener) { + listener.onError(CallStatus.UNIMPLEMENTED.toRuntimeException()); + } + /** * Creates a prepared statement on the server and returns a handle and metadata for in a * {@link ActionCreatePreparedStatementResult} object in a {@link Result} @@ -309,6 +406,17 @@ default void doAction(CallContext context, Action action, StreamListener void createPreparedStatement(ActionCreatePreparedStatementRequest request, CallContext context, StreamListener listener); + /** + * Pre-compile a Substrait plan. + * @param request The plan. + * @param context Per-call context. + * @param listener The resulting prepared statement. + */ + default void createPreparedSubstraitPlan(ActionCreatePreparedSubstraitPlanRequest request, CallContext context, + StreamListener listener) { + listener.onError(CallStatus.UNIMPLEMENTED.toRuntimeException()); + } + /** * Closes a prepared statement on the server. No result is expected. * @@ -320,9 +428,35 @@ void closePreparedStatement(ActionClosePreparedStatementRequest request, CallCon StreamListener listener); /** - * Gets information about a particular SQL query based data stream. + * Release or roll back to a savepoint. * - * @param command The sql command to generate the data stream. + * @param request The savepoint, and whether to release/rollback. + * @param context Per-call context. + * @param listener Call {@link StreamListener#onCompleted()} or + * {@link StreamListener#onError(Throwable)} when done; do not send a result. + */ + default void endSavepoint(ActionEndSavepointRequest request, CallContext context, + StreamListener listener) { + listener.onError(CallStatus.UNIMPLEMENTED.toRuntimeException()); + } + + /** + * Commit or roll back to a transaction. + * + * @param request The transaction, and whether to release/rollback. + * @param context Per-call context. + * @param listener Call {@link StreamListener#onCompleted()} or + * {@link StreamListener#onError(Throwable)} when done; do not send a result. + */ + default void endTransaction(ActionEndTransactionRequest request, CallContext context, + StreamListener listener) { + listener.onError(CallStatus.UNIMPLEMENTED.toRuntimeException()); + } + + /** + * Evaluate a SQL query. + * + * @param command The SQL query. * @param context Per-call context. * @param descriptor The descriptor identifying the data stream. * @return Metadata about the stream. @@ -330,6 +464,19 @@ void closePreparedStatement(ActionClosePreparedStatementRequest request, CallCon FlightInfo getFlightInfoStatement(CommandStatementQuery command, CallContext context, FlightDescriptor descriptor); + /** + * Evaluate a Substrait plan. + * + * @param command The Substrait plan. + * @param context Per-call context. + * @param descriptor The descriptor identifying the data stream. + * @return Metadata about the stream. + */ + default FlightInfo getFlightInfoSubstraitPlan(CommandStatementSubstraitPlan command, CallContext context, + FlightDescriptor descriptor) { + throw CallStatus.UNIMPLEMENTED.toRuntimeException(); + } + /** * Gets information about a particular prepared statement data stream. * @@ -342,7 +489,7 @@ FlightInfo getFlightInfoPreparedStatement(CommandPreparedStatementQuery command, CallContext context, FlightDescriptor descriptor); /** - * Get the schema of the result set of a query. + * Get the result schema for a SQL query. * * @param command The SQL query. * @param context Per-call context. @@ -367,6 +514,19 @@ default SchemaResult getSchemaPreparedStatement(CommandPreparedStatementQuery co .toRuntimeException(); } + /** + * Get the result schema for a Substrait plan. + * + * @param command The Substrait plan. + * @param context Per-call context. + * @param descriptor The descriptor identifying the data stream. + * @return Schema for the stream. + */ + default SchemaResult getSchemaSubstraitPlan(CommandStatementSubstraitPlan command, CallContext context, + FlightDescriptor descriptor) { + throw CallStatus.UNIMPLEMENTED.toRuntimeException(); + } + /** * Returns data for a SQL query based data stream. * @param ticket Ticket message containing the statement handle. @@ -399,6 +559,22 @@ void getStreamPreparedStatement(CommandPreparedStatementQuery command, CallConte Runnable acceptPutStatement(CommandStatementUpdate command, CallContext context, FlightStream flightStream, StreamListener ackStream); + /** + * Handle a Substrait plan with uploaded data. + * + * @param command The Substrait plan to evaluate. + * @param context Per-call context. + * @param flightStream The data stream being uploaded. + * @param ackStream The result data stream. + * @return A runnable to process the stream. + */ + default Runnable acceptPutSubstraitPlan(CommandStatementSubstraitPlan command, CallContext context, + FlightStream flightStream, StreamListener ackStream) { + return () -> { + ackStream.onError(CallStatus.UNIMPLEMENTED.toRuntimeException()); + }; + } + /** * Accepts uploaded data for a particular prepared statement data stream. *

`PutResult`s must be in the form of a {@link DoPutUpdateResult}. @@ -450,7 +626,7 @@ FlightInfo getFlightInfoSqlInfo(CommandGetSqlInfo request, CallContext context, /** * Returns a description of all the data types supported by source. - * + * * @param request request filter parameters. * @param descriptor The descriptor identifying the data stream. * @return Metadata about the stream. diff --git a/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlUtils.java b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlUtils.java index e461515c40ecd..532921a8ac6e7 100644 --- a/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlUtils.java +++ b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/FlightSqlUtils.java @@ -31,6 +31,18 @@ * Utilities to work with Flight SQL semantics. */ public final class FlightSqlUtils { + + public static final ActionType FLIGHT_SQL_BEGIN_SAVEPOINT = + new ActionType("BeginSavepoint", + "Create a new savepoint.\n" + + "Request Message: ActionBeginSavepointRequest\n" + + "Response Message: ActionBeginSavepointResult"); + + public static final ActionType FLIGHT_SQL_BEGIN_TRANSACTION = + new ActionType("BeginTransaction", + "Start a new transaction.\n" + + "Request Message: ActionBeginTransactionRequest\n" + + "Response Message: ActionBeginTransactionResult"); public static final ActionType FLIGHT_SQL_CREATE_PREPARED_STATEMENT = new ActionType("CreatePreparedStatement", "Creates a reusable prepared statement resource on the server. \n" + "Request Message: ActionCreatePreparedStatementRequest\n" + @@ -41,6 +53,29 @@ public final class FlightSqlUtils { "Request Message: ActionClosePreparedStatementRequest\n" + "Response Message: N/A"); + public static final ActionType FLIGHT_SQL_CREATE_PREPARED_SUBSTRAIT_PLAN = + new ActionType("CreatePreparedSubstraitPlan", + "Creates a reusable prepared statement resource on the server.\n" + + "Request Message: ActionCreatePreparedSubstraitPlanRequest\n" + + "Response Message: ActionCreatePreparedStatementResult"); + + public static final ActionType FLIGHT_SQL_CANCEL_QUERY = + new ActionType("CancelQuery", + "Explicitly cancel a running query.\n" + + "Request Message: ActionCancelQueryRequest\n" + + "Response Message: ActionCancelQueryResult"); + + public static final ActionType FLIGHT_SQL_END_SAVEPOINT = + new ActionType("EndSavepoint", + "End a savepoint.\n" + + "Request Message: ActionEndSavepointRequest\n" + + "Response Message: N/A"); + public static final ActionType FLIGHT_SQL_END_TRANSACTION = + new ActionType("EndTransaction", + "End a transaction.\n" + + "Request Message: ActionEndTransactionRequest\n" + + "Response Message: N/A"); + public static final List FLIGHT_SQL_ACTIONS = ImmutableList.of( FLIGHT_SQL_CREATE_PREPARED_STATEMENT, FLIGHT_SQL_CLOSE_PREPARED_STATEMENT diff --git a/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/NoResultListener.java b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/NoResultListener.java new file mode 100644 index 0000000000000..2c80076a8f57f --- /dev/null +++ b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/NoResultListener.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight.sql; + +import org.apache.arrow.flight.FlightProducer; +import org.apache.arrow.flight.Result; + +/** A StreamListener for actions that do not return results. */ +class NoResultListener implements FlightProducer.StreamListener { + private final FlightProducer.StreamListener listener; + + NoResultListener(FlightProducer.StreamListener listener) { + this.listener = listener; + } + + @Override + public void onNext(Result val) { + throw new UnsupportedOperationException("Do not call onNext on this listener."); + } + + @Override + public void onError(Throwable t) { + listener.onError(t); + } + + @Override + public void onCompleted() { + listener.onCompleted(); + } +} diff --git a/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/ProtoListener.java b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/ProtoListener.java new file mode 100644 index 0000000000000..fd5fd0489628d --- /dev/null +++ b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/ProtoListener.java @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.flight.sql; + +import org.apache.arrow.flight.FlightProducer; +import org.apache.arrow.flight.Result; + +import com.google.protobuf.Any; +import com.google.protobuf.Message; + +/** + * A StreamListener that accepts a particular type. + * + * @param The message type to accept. + */ +class ProtoListener implements FlightProducer.StreamListener { + private final FlightProducer.StreamListener listener; + + ProtoListener(FlightProducer.StreamListener listener) { + this.listener = listener; + } + + @Override + public void onNext(T val) { + listener.onNext(new Result(Any.pack(val).toByteArray())); + } + + @Override + public void onError(Throwable t) { + listener.onError(t); + } + + @Override + public void onCompleted() { + listener.onCompleted(); + } +} diff --git a/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/SqlInfoBuilder.java b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/SqlInfoBuilder.java index 3866cb89b1f21..18793f9b905fe 100644 --- a/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/SqlInfoBuilder.java +++ b/java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/SqlInfoBuilder.java @@ -20,6 +20,7 @@ import static java.nio.charset.StandardCharsets.UTF_8; import static java.util.stream.IntStream.range; import static org.apache.arrow.flight.FlightProducer.ServerStreamListener; +import static org.apache.arrow.flight.sql.impl.FlightSql.SqlSupportedTransaction; import static org.apache.arrow.flight.sql.util.SqlInfoOptionsUtils.createBitmaskFromEnums; import java.nio.charset.StandardCharsets; @@ -118,6 +119,46 @@ public SqlInfoBuilder withFlightSqlServerArrowVersion(final String value) { return withStringProvider(SqlInfo.FLIGHT_SQL_SERVER_ARROW_VERSION_VALUE, value); } + /** Set a value for SQL support. */ + public SqlInfoBuilder withFlightSqlServerSql(boolean value) { + return withBooleanProvider(SqlInfo.FLIGHT_SQL_SERVER_SQL_VALUE, value); + } + + /** Set a value for Substrait support. */ + public SqlInfoBuilder withFlightSqlServerSubstrait(boolean value) { + return withBooleanProvider(SqlInfo.FLIGHT_SQL_SERVER_SUBSTRAIT_VALUE, value); + } + + /** Set a value for Substrait minimum version support. */ + public SqlInfoBuilder withFlightSqlServerSubstraitMinVersion(String value) { + return withStringProvider(SqlInfo.FLIGHT_SQL_SERVER_SUBSTRAIT_MIN_VERSION_VALUE, value); + } + + /** Set a value for Substrait maximum version support. */ + public SqlInfoBuilder withFlightSqlServerSubstraitMaxVersion(String value) { + return withStringProvider(SqlInfo.FLIGHT_SQL_SERVER_SUBSTRAIT_MAX_VERSION_VALUE, value); + } + + /** Set a value for transaction support. */ + public SqlInfoBuilder withFlightSqlServerTransaction(SqlSupportedTransaction value) { + return withIntProvider(SqlInfo.FLIGHT_SQL_SERVER_TRANSACTION_VALUE, value.getNumber()); + } + + /** Set a value for query cancellation support. */ + public SqlInfoBuilder withFlightSqlServerCancel(boolean value) { + return withBooleanProvider(SqlInfo.FLIGHT_SQL_SERVER_CANCEL_VALUE, value); + } + + /** Set a value for statement timeouts. */ + public SqlInfoBuilder withFlightSqlServerStatementTimeout(int value) { + return withIntProvider(SqlInfo.FLIGHT_SQL_SERVER_STATEMENT_TIMEOUT_VALUE, value); + } + + /** Set a value for transaction timeouts. */ + public SqlInfoBuilder withFlightSqlServerTransactionTimeout(int value) { + return withIntProvider(SqlInfo.FLIGHT_SQL_SERVER_TRANSACTION_TIMEOUT_VALUE, value); + } + /** * Sets a value for {@link SqlInfo#SQL_IDENTIFIER_QUOTE_CHAR} in the builder. * diff --git a/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/example/FlightSqlExample.java b/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/example/FlightSqlExample.java index d66b8df9283bf..fe1e1445afc6e 100644 --- a/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/example/FlightSqlExample.java +++ b/java/flight/flight-sql/src/test/java/org/apache/arrow/flight/sql/example/FlightSqlExample.java @@ -217,6 +217,9 @@ public FlightSqlExample(final Location location) { .withFlightSqlServerVersion(metaData.getDatabaseProductVersion()) .withFlightSqlServerArrowVersion(metaData.getDriverVersion()) .withFlightSqlServerReadOnly(metaData.isReadOnly()) + .withFlightSqlServerSql(true) + .withFlightSqlServerSubstrait(false) + .withFlightSqlServerTransaction(SqlSupportedTransaction.SQL_SUPPORTED_TRANSACTION_NONE) .withSqlIdentifierQuoteChar(metaData.getIdentifierQuoteString()) .withSqlDdlCatalog(metaData.supportsCatalogsInDataManipulation()) .withSqlDdlSchema( metaData.supportsSchemasInDataManipulation())