From 6c48fddd1a7381f538057907bc7cdadbfc34c23f Mon Sep 17 00:00:00 2001
From: David Li
Date: Fri, 1 Jul 2022 14:33:20 -0400
Subject: [PATCH] RFC: [FlightRPC][WIP] Substrait, transaction, cancellation
for Flight SQL
---
.../flight_integration_test.cc | 4 +
.../integration_tests/test_integration.cc | 656 ++++++++++++++++--
cpp/src/arrow/flight/server.cc | 4 +-
cpp/src/arrow/flight/sql/CMakeLists.txt | 27 +-
cpp/src/arrow/flight/sql/acero_test.cc | 239 +++++++
cpp/src/arrow/flight/sql/client.cc | 355 +++++++++-
cpp/src/arrow/flight/sql/client.h | 152 +++-
cpp/src/arrow/flight/sql/client_test.cc | 1 +
cpp/src/arrow/flight/sql/column_metadata.cc | 8 +-
cpp/src/arrow/flight/sql/column_metadata.h | 8 +-
.../arrow/flight/sql/example/acero_main.cc | 70 ++
.../arrow/flight/sql/example/acero_server.cc | 306 ++++++++
.../arrow/flight/sql/example/acero_server.h | 37 +
.../arrow/flight/sql/example/sqlite_server.cc | 205 ++++--
.../arrow/flight/sql/example/sqlite_server.h | 6 +
.../flight/sql/example/sqlite_sql_info.cc | 9 +-
cpp/src/arrow/flight/sql/server.cc | 424 +++++++++--
cpp/src/arrow/flight/sql/server.h | 171 ++++-
cpp/src/arrow/flight/sql/server_test.cc | 83 ++-
cpp/src/arrow/flight/sql/types.h | 79 ++-
dev/archery/archery/integration/runner.py | 5 +
docs/source/status.rst | 14 +
format/FlightSql.proto | 298 +++++++-
.../org/apache/arrow/flight/FlightClient.java | 7 +-
.../apache/arrow/flight/FlightService.java | 2 +-
.../tests/FlightSqlExtensionScenario.java | 217 ++++++
.../integration/tests/FlightSqlScenario.java | 52 +-
.../tests/FlightSqlScenarioProducer.java | 383 ++++++++--
.../tests/IntegrationAssertions.java | 11 +
.../flight/integration/tests/Scenarios.java | 1 +
.../integration/tests/IntegrationTest.java | 70 ++
.../arrow/flight/sql/CancelListener.java | 51 ++
.../apache/arrow/flight/sql/CancelResult.java | 45 ++
.../arrow/flight/sql/FlightSqlClient.java | 451 +++++++++++-
.../arrow/flight/sql/FlightSqlProducer.java | 198 +++++-
.../arrow/flight/sql/FlightSqlUtils.java | 35 +
.../arrow/flight/sql/NoResultListener.java | 45 ++
.../arrow/flight/sql/ProtoListener.java | 52 ++
.../arrow/flight/sql/SqlInfoBuilder.java | 41 ++
.../flight/sql/example/FlightSqlExample.java | 3 +
40 files changed, 4445 insertions(+), 380 deletions(-)
create mode 100644 cpp/src/arrow/flight/sql/acero_test.cc
create mode 100644 cpp/src/arrow/flight/sql/example/acero_main.cc
create mode 100644 cpp/src/arrow/flight/sql/example/acero_server.cc
create mode 100644 cpp/src/arrow/flight/sql/example/acero_server.h
create mode 100644 java/flight/flight-integration-tests/src/main/java/org/apache/arrow/flight/integration/tests/FlightSqlExtensionScenario.java
create mode 100644 java/flight/flight-integration-tests/src/test/java/org/apache/arrow/flight/integration/tests/IntegrationTest.java
create mode 100644 java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/CancelListener.java
create mode 100644 java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/CancelResult.java
create mode 100644 java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/NoResultListener.java
create mode 100644 java/flight/flight-sql/src/main/java/org/apache/arrow/flight/sql/ProtoListener.java
diff --git a/cpp/src/arrow/flight/integration_tests/flight_integration_test.cc b/cpp/src/arrow/flight/integration_tests/flight_integration_test.cc
index 706ac3b7d931b..e29a281f32721 100644
--- a/cpp/src/arrow/flight/integration_tests/flight_integration_test.cc
+++ b/cpp/src/arrow/flight/integration_tests/flight_integration_test.cc
@@ -55,6 +55,10 @@ TEST(FlightIntegration, Middleware) { ASSERT_OK(RunScenario("middleware")); }
TEST(FlightIntegration, FlightSql) { ASSERT_OK(RunScenario("flight_sql")); }
+TEST(FlightIntegration, FlightSqlExtension) {
+ ASSERT_OK(RunScenario("flight_sql:extension"));
+}
+
} // namespace integration_tests
} // namespace flight
} // namespace arrow
diff --git a/cpp/src/arrow/flight/integration_tests/test_integration.cc b/cpp/src/arrow/flight/integration_tests/test_integration.cc
index b228f9cceba06..43c16e0b77a6d 100644
--- a/cpp/src/arrow/flight/integration_tests/test_integration.cc
+++ b/cpp/src/arrow/flight/integration_tests/test_integration.cc
@@ -16,25 +16,36 @@
// under the License.
#include "arrow/flight/integration_tests/test_integration.h"
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include "arrow/array/array_binary.h"
+#include "arrow/array/array_nested.h"
+#include "arrow/array/array_primitive.h"
#include "arrow/flight/client_middleware.h"
#include "arrow/flight/server_middleware.h"
#include "arrow/flight/sql/client.h"
#include "arrow/flight/sql/column_metadata.h"
#include "arrow/flight/sql/server.h"
+#include "arrow/flight/sql/types.h"
#include "arrow/flight/test_util.h"
#include "arrow/flight/types.h"
#include "arrow/ipc/dictionary.h"
+#include "arrow/status.h"
#include "arrow/testing/gtest_util.h"
-
-#include
-#include
-#include
-#include
-#include
+#include "arrow/util/checked_cast.h"
namespace arrow {
namespace flight {
namespace integration_tests {
+namespace {
+
+using arrow::internal::checked_cast;
/// \brief The server for the basic auth integration test.
class AuthBasicProtoServer : public FlightServerBase {
@@ -263,29 +274,56 @@ class MiddlewareScenario : public Scenario {
};
/// \brief Schema to be returned for mocking the statement/prepared statement results.
+///
/// Must be the same across all languages.
-std::shared_ptr GetQuerySchema() {
- std::string table_name = "test";
- std::string schema_name = "schema_test";
- std::string catalog_name = "catalog_test";
- std::string type_name = "type_test";
- return arrow::schema({arrow::field("id", int64(), true,
- arrow::flight::sql::ColumnMetadata::Builder()
- .TableName(table_name)
- .IsAutoIncrement(true)
- .IsCaseSensitive(false)
- .TypeName(type_name)
- .SchemaName(schema_name)
- .IsSearchable(true)
- .CatalogName(catalog_name)
- .Precision(100)
- .Build()
- .metadata_map())});
+const std::shared_ptr& GetQuerySchema() {
+ static std::shared_ptr kSchema =
+ schema({field("id", int64(), /*nullable=*/true,
+ arrow::flight::sql::ColumnMetadata::Builder()
+ .TableName("test")
+ .IsAutoIncrement(true)
+ .IsCaseSensitive(false)
+ .TypeName("type_test")
+ .SchemaName("schema_test")
+ .IsSearchable(true)
+ .CatalogName("catalog_test")
+ .Precision(100)
+ .Build()
+ .metadata_map())});
+ return kSchema;
+}
+
+/// \brief Schema to be returned for queries with transactions.
+///
+/// Must be the same across all languages.
+std::shared_ptr GetQueryWithTransactionSchema() {
+ static std::shared_ptr kSchema =
+ schema({field("pkey", int32(), /*nullable=*/true,
+ arrow::flight::sql::ColumnMetadata::Builder()
+ .TableName("test")
+ .IsAutoIncrement(true)
+ .IsCaseSensitive(false)
+ .TypeName("type_test")
+ .SchemaName("schema_test")
+ .IsSearchable(true)
+ .CatalogName("catalog_test")
+ .Precision(100)
+ .Build()
+ .metadata_map())});
+ return kSchema;
}
constexpr int64_t kUpdateStatementExpectedRows = 10000L;
+constexpr int64_t kUpdateStatementWithTransactionExpectedRows = 15000L;
constexpr int64_t kUpdatePreparedStatementExpectedRows = 20000L;
+constexpr int64_t kUpdatePreparedStatementWithTransactionExpectedRows = 25000L;
constexpr char kSelectStatement[] = "SELECT STATEMENT";
+constexpr char kSavepointId[] = "savepoint_id";
+constexpr char kSavepointName[] = "savepoint_name";
+constexpr char kSubstraitPlanText[] = "plan";
+constexpr char kSubstraitVersion[] = "version";
+static const sql::SubstraitPlan kSubstraitPlan{kSubstraitPlanText, kSubstraitVersion};
+constexpr char kTransactionId[] = "transaction_id";
template
arrow::Status AssertEq(const T& expected, const T& actual, const std::string& message) {
@@ -296,25 +334,83 @@ arrow::Status AssertEq(const T& expected, const T& actual, const std::string& me
return Status::OK();
}
+template
+arrow::Status AssertUnprintableEq(const T& expected, const T& actual,
+ const std::string& message) {
+ if (expected != actual) {
+ return Status::Invalid(message);
+ }
+ return Status::OK();
+}
+
/// \brief The server used for testing Flight SQL, this implements a static Flight SQL
/// server which only asserts that commands called during integration tests are being
/// parsed correctly and returns the expected schemas to be validated on client.
class FlightSqlScenarioServer : public sql::FlightSqlServerBase {
public:
+ FlightSqlScenarioServer() : sql::FlightSqlServerBase() {
+ RegisterSqlInfo(sql::SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_SQL,
+ sql::SqlInfoResult(false));
+ RegisterSqlInfo(sql::SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_SUBSTRAIT,
+ sql::SqlInfoResult(true));
+ RegisterSqlInfo(sql::SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_SUBSTRAIT_MIN_VERSION,
+ sql::SqlInfoResult(std::string("min_version")));
+ RegisterSqlInfo(sql::SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_SUBSTRAIT_MAX_VERSION,
+ sql::SqlInfoResult(std::string("max_version")));
+ RegisterSqlInfo(sql::SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_TRANSACTION,
+ sql::SqlInfoResult(sql::SqlInfoOptions::SqlSupportedTransaction::
+ SQL_SUPPORTED_TRANSACTION_SAVEPOINT));
+ RegisterSqlInfo(sql::SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_CANCEL,
+ sql::SqlInfoResult(true));
+ RegisterSqlInfo(sql::SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_STATEMENT_TIMEOUT,
+ sql::SqlInfoResult(int32_t(42)));
+ RegisterSqlInfo(sql::SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_TRANSACTION_TIMEOUT,
+ sql::SqlInfoResult(int32_t(7)));
+ }
arrow::Result> GetFlightInfoStatement(
const ServerCallContext& context, const sql::StatementQuery& command,
const FlightDescriptor& descriptor) override {
ARROW_RETURN_NOT_OK(
AssertEq(kSelectStatement, command.query,
"Unexpected statement in GetFlightInfoStatement"));
-
- ARROW_ASSIGN_OR_RAISE(auto handle,
- sql::CreateStatementQueryTicket("SELECT STATEMENT HANDLE"));
-
+ std::string ticket;
+ Schema* schema;
+ if (command.transaction_id.empty()) {
+ ticket = "SELECT STATEMENT HANDLE";
+ schema = GetQuerySchema().get();
+ } else {
+ ticket = "SELECT STATEMENT WITH TXN HANDLE";
+ schema = GetQueryWithTransactionSchema().get();
+ }
+ ARROW_ASSIGN_OR_RAISE(auto handle, sql::CreateStatementQueryTicket(ticket));
std::vector endpoints{FlightEndpoint{{handle}, {}}};
- ARROW_ASSIGN_OR_RAISE(
- auto result, FlightInfo::Make(*GetQuerySchema(), descriptor, endpoints, -1, -1))
+ ARROW_ASSIGN_OR_RAISE(auto result,
+ FlightInfo::Make(*schema, descriptor, endpoints, -1, -1));
+ return std::unique_ptr(new FlightInfo(result));
+ }
+ arrow::Result> GetFlightInfoSubstraitPlan(
+ const ServerCallContext& context, const sql::StatementSubstraitPlan& command,
+ const FlightDescriptor& descriptor) override {
+ ARROW_RETURN_NOT_OK(
+ AssertEq(kSubstraitPlanText, command.plan.plan,
+ "Unexpected plan in GetFlightInfoSubstraitPlan"));
+ ARROW_RETURN_NOT_OK(
+ AssertEq(kSubstraitVersion, command.plan.version,
+ "Unexpected version in GetFlightInfoSubstraitPlan"));
+ std::string ticket;
+ Schema* schema;
+ if (command.transaction_id.empty()) {
+ ticket = "PLAN HANDLE";
+ schema = GetQuerySchema().get();
+ } else {
+ ticket = "PLAN WITH TXN HANDLE";
+ schema = GetQueryWithTransactionSchema().get();
+ }
+ ARROW_ASSIGN_OR_RAISE(auto handle, sql::CreateStatementQueryTicket(ticket));
+ std::vector endpoints{FlightEndpoint{{handle}, {}}};
+ ARROW_ASSIGN_OR_RAISE(auto result,
+ FlightInfo::Make(*schema, descriptor, endpoints, -1, -1));
return std::unique_ptr(new FlightInfo(result));
}
@@ -323,38 +419,84 @@ class FlightSqlScenarioServer : public sql::FlightSqlServerBase {
const FlightDescriptor& descriptor) override {
ARROW_RETURN_NOT_OK(AssertEq(
kSelectStatement, command.query, "Unexpected statement in GetSchemaStatement"));
- return SchemaResult::Make(*GetQuerySchema());
+ if (command.transaction_id.empty()) {
+ return SchemaResult::Make(*GetQuerySchema());
+ } else {
+ return SchemaResult::Make(*GetQueryWithTransactionSchema());
+ }
+ }
+
+ arrow::Result> GetSchemaSubstraitPlan(
+ const ServerCallContext& context, const sql::StatementSubstraitPlan& command,
+ const FlightDescriptor& descriptor) override {
+ ARROW_RETURN_NOT_OK(
+ AssertEq(kSubstraitPlanText, command.plan.plan,
+ "Unexpected statement in GetSchemaSubstraitPlan"));
+ ARROW_RETURN_NOT_OK(
+ AssertEq(kSubstraitVersion, command.plan.version,
+ "Unexpected version in GetFlightInfoSubstraitPlan"));
+ if (command.transaction_id.empty()) {
+ return SchemaResult::Make(*GetQuerySchema());
+ } else {
+ return SchemaResult::Make(*GetQueryWithTransactionSchema());
+ }
}
arrow::Result> DoGetStatement(
const ServerCallContext& context,
const sql::StatementQueryTicket& command) override {
- return DoGetForTestCase(GetQuerySchema());
+ if (command.statement_handle == "SELECT STATEMENT HANDLE" ||
+ command.statement_handle == "PLAN HANDLE") {
+ return DoGetForTestCase(GetQuerySchema());
+ } else if (command.statement_handle == "SELECT STATEMENT WITH TXN HANDLE" ||
+ command.statement_handle == "PLAN WITH TXN HANDLE") {
+ return DoGetForTestCase(GetQueryWithTransactionSchema());
+ }
+ return Status::Invalid("Unknown handle: ", command.statement_handle);
}
arrow::Result> GetFlightInfoPreparedStatement(
const ServerCallContext& context, const sql::PreparedStatementQuery& command,
const FlightDescriptor& descriptor) override {
- ARROW_RETURN_NOT_OK(AssertEq("SELECT PREPARED STATEMENT HANDLE",
- command.prepared_statement_handle,
- "Unexpected prepared statement handle"));
-
- return GetFlightInfoForCommand(descriptor, GetQuerySchema());
+ if (command.prepared_statement_handle == "SELECT PREPARED STATEMENT HANDLE" ||
+ command.prepared_statement_handle == "PLAN HANDLE") {
+ return GetFlightInfoForCommand(descriptor, GetQuerySchema());
+ } else if (command.prepared_statement_handle ==
+ "SELECT PREPARED STATEMENT WITH TXN HANDLE" ||
+ command.prepared_statement_handle == "PLAN WITH TXN HANDLE") {
+ return GetFlightInfoForCommand(descriptor, GetQueryWithTransactionSchema());
+ }
+ return Status::Invalid("Invalid handle for GetFlightInfoForCommand: ",
+ command.prepared_statement_handle);
}
arrow::Result> GetSchemaPreparedStatement(
const ServerCallContext& context, const sql::PreparedStatementQuery& command,
const FlightDescriptor& descriptor) override {
- ARROW_RETURN_NOT_OK(AssertEq("SELECT PREPARED STATEMENT HANDLE",
- command.prepared_statement_handle,
- "Unexpected prepared statement handle"));
- return SchemaResult::Make(*GetQuerySchema());
+ if (command.prepared_statement_handle == "SELECT PREPARED STATEMENT HANDLE" ||
+ command.prepared_statement_handle == "PLAN HANDLE") {
+ return SchemaResult::Make(*GetQuerySchema());
+ } else if (command.prepared_statement_handle ==
+ "SELECT PREPARED STATEMENT WITH TXN HANDLE" ||
+ command.prepared_statement_handle == "PLAN WITH TXN HANDLE") {
+ return SchemaResult::Make(*GetQueryWithTransactionSchema());
+ }
+ return Status::Invalid("Invalid handle for GetSchemaPreparedStatement: ",
+ command.prepared_statement_handle);
}
arrow::Result> DoGetPreparedStatement(
const ServerCallContext& context,
const sql::PreparedStatementQuery& command) override {
- return DoGetForTestCase(GetQuerySchema());
+ if (command.prepared_statement_handle == "SELECT PREPARED STATEMENT HANDLE" ||
+ command.prepared_statement_handle == "PLAN HANDLE") {
+ return DoGetForTestCase(GetQuerySchema());
+ } else if (command.prepared_statement_handle ==
+ "SELECT PREPARED STATEMENT WITH TXN HANDLE" ||
+ command.prepared_statement_handle == "PLAN WITH TXN HANDLE") {
+ return DoGetForTestCase(GetQueryWithTransactionSchema());
+ }
+ return Status::Invalid("Invalid handle: ", command.prepared_statement_handle);
}
arrow::Result> GetFlightInfoCatalogs(
@@ -381,21 +523,29 @@ class FlightSqlScenarioServer : public sql::FlightSqlServerBase {
arrow::Result> GetFlightInfoSqlInfo(
const ServerCallContext& context, const sql::GetSqlInfo& command,
const FlightDescriptor& descriptor) override {
- ARROW_RETURN_NOT_OK(AssertEq(2, command.info.size(),
- "Wrong number of SqlInfo values passed"));
- ARROW_RETURN_NOT_OK(
- AssertEq(sql::SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_NAME,
- command.info[0], "Unexpected SqlInfo passed"));
- ARROW_RETURN_NOT_OK(
- AssertEq(sql::SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_READ_ONLY,
- command.info[1], "Unexpected SqlInfo passed"));
-
- return GetFlightInfoForCommand(descriptor, sql::SqlSchema::GetSqlInfoSchema());
+ if (command.info.size() == 2) {
+ // Integration test for the protocol messages
+ ARROW_RETURN_NOT_OK(
+ AssertEq(sql::SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_NAME,
+ command.info[0], "Unexpected SqlInfo passed"));
+ ARROW_RETURN_NOT_OK(
+ AssertEq(sql::SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_READ_ONLY,
+ command.info[1], "Unexpected SqlInfo passed"));
+
+ return GetFlightInfoForCommand(descriptor, sql::SqlSchema::GetSqlInfoSchema());
+ }
+ // Integration test for the values themselves
+ return sql::FlightSqlServerBase::GetFlightInfoSqlInfo(context, command, descriptor);
}
arrow::Result> DoGetSqlInfo(
const ServerCallContext& context, const sql::GetSqlInfo& command) override {
- return DoGetForTestCase(sql::SqlSchema::GetSqlInfoSchema());
+ if (command.info.size() == 2) {
+ // Integration test for the protocol messages
+ return DoGetForTestCase(sql::SqlSchema::GetSqlInfoSchema());
+ }
+ // Integration test for the values themselves
+ return sql::FlightSqlServerBase::DoGetSqlInfo(context, command);
}
arrow::Result> GetFlightInfoSchemas(
@@ -539,8 +689,21 @@ class FlightSqlScenarioServer : public sql::FlightSqlServerBase {
ARROW_RETURN_NOT_OK(
AssertEq("UPDATE STATEMENT", command.query,
"Wrong query for DoPutCommandStatementUpdate"));
+ return command.transaction_id.empty() ? kUpdateStatementExpectedRows
+ : kUpdateStatementWithTransactionExpectedRows;
+ }
- return kUpdateStatementExpectedRows;
+ arrow::Result DoPutCommandSubstraitPlan(
+ const ServerCallContext& context,
+ const sql::StatementSubstraitPlan& command) override {
+ ARROW_RETURN_NOT_OK(
+ AssertEq(kSubstraitPlanText, command.plan.plan,
+ "Wrong plan for DoPutCommandSubstraitPlan"));
+ ARROW_RETURN_NOT_OK(
+ AssertEq(kSubstraitVersion, command.plan.version,
+ "Unexpected version in GetFlightInfoSubstraitPlan"));
+ return command.transaction_id.empty() ? kUpdateStatementExpectedRows
+ : kUpdateStatementWithTransactionExpectedRows;
}
arrow::Result CreatePreparedStatement(
@@ -552,8 +715,26 @@ class FlightSqlScenarioServer : public sql::FlightSqlServerBase {
}
sql::ActionCreatePreparedStatementResult result;
- result.prepared_statement_handle = request.query + " HANDLE";
+ result.prepared_statement_handle = request.query;
+ if (!request.transaction_id.empty()) {
+ result.prepared_statement_handle += " WITH TXN";
+ }
+ result.prepared_statement_handle += " HANDLE";
+ return result;
+ }
+ arrow::Result CreatePreparedSubstraitPlan(
+ const ServerCallContext& context,
+ const sql::ActionCreatePreparedSubstraitPlanRequest& request) override {
+ ARROW_RETURN_NOT_OK(
+ AssertEq(kSubstraitPlanText, request.plan.plan,
+ "Wrong plan for CreatePreparedSubstraitPlan"));
+ ARROW_RETURN_NOT_OK(
+ AssertEq(kSubstraitVersion, request.plan.version,
+ "Unexpected version in GetFlightInfoSubstraitPlan"));
+ sql::ActionCreatePreparedStatementResult result;
+ result.prepared_statement_handle =
+ request.transaction_id.empty() ? "PLAN HANDLE" : "PLAN WITH TXN HANDLE";
return result;
}
@@ -561,7 +742,13 @@ class FlightSqlScenarioServer : public sql::FlightSqlServerBase {
const ServerCallContext& context,
const sql::ActionClosePreparedStatementRequest& request) override {
if (request.prepared_statement_handle != "SELECT PREPARED STATEMENT HANDLE" &&
- request.prepared_statement_handle != "UPDATE PREPARED STATEMENT HANDLE") {
+ request.prepared_statement_handle != "UPDATE PREPARED STATEMENT HANDLE" &&
+ request.prepared_statement_handle != "PLAN HANDLE" &&
+ request.prepared_statement_handle !=
+ "SELECT PREPARED STATEMENT WITH TXN HANDLE" &&
+ request.prepared_statement_handle !=
+ "UPDATE PREPARED STATEMENT WITH TXN HANDLE" &&
+ request.prepared_statement_handle != "PLAN WITH TXN HANDLE") {
return Status::Invalid("Invalid handle for ClosePreparedStatement: ",
request.prepared_statement_handle);
}
@@ -572,28 +759,95 @@ class FlightSqlScenarioServer : public sql::FlightSqlServerBase {
const sql::PreparedStatementQuery& command,
FlightMessageReader* reader,
FlightMetadataWriter* writer) override {
- if (command.prepared_statement_handle != "SELECT PREPARED STATEMENT HANDLE") {
+ if (command.prepared_statement_handle != "SELECT PREPARED STATEMENT HANDLE" &&
+ command.prepared_statement_handle !=
+ "SELECT PREPARED STATEMENT WITH TXN HANDLE" &&
+ command.prepared_statement_handle != "PLAN HANDLE" &&
+ command.prepared_statement_handle != "PLAN WITH TXN HANDLE") {
return Status::Invalid("Invalid handle for DoPutPreparedStatementQuery: ",
command.prepared_statement_handle);
}
-
ARROW_ASSIGN_OR_RAISE(auto actual_schema, reader->GetSchema());
ARROW_RETURN_NOT_OK(AssertEq(*GetQuerySchema(), *actual_schema,
"Wrong schema for DoPutPreparedStatementQuery"));
-
return Status::OK();
}
arrow::Result DoPutPreparedStatementUpdate(
const ServerCallContext& context, const sql::PreparedStatementUpdate& command,
FlightMessageReader* reader) override {
- if (command.prepared_statement_handle == "UPDATE PREPARED STATEMENT HANDLE") {
+ if (command.prepared_statement_handle == "UPDATE PREPARED STATEMENT HANDLE" ||
+ command.prepared_statement_handle == "PLAN HANDLE") {
return kUpdatePreparedStatementExpectedRows;
+ } else if (command.prepared_statement_handle ==
+ "UPDATE PREPARED STATEMENT WITH TXN HANDLE" ||
+ command.prepared_statement_handle == "PLAN WITH TXN HANDLE") {
+ return kUpdatePreparedStatementWithTransactionExpectedRows;
}
return Status::Invalid("Invalid handle for DoPutPreparedStatementUpdate: ",
command.prepared_statement_handle);
}
+ arrow::Result BeginSavepoint(
+ const ServerCallContext& context,
+ const sql::ActionBeginSavepointRequest& request) override {
+ ARROW_RETURN_NOT_OK(AssertEq(
+ kSavepointName, request.name, "Unexpected savepoint name in BeginSavepoint"));
+ ARROW_RETURN_NOT_OK(
+ AssertEq(kTransactionId, request.transaction_id,
+ "Unexpected transaction ID in BeginSavepoint"));
+ return sql::ActionBeginSavepointResult{kSavepointId};
+ }
+
+ arrow::Result BeginTransaction(
+ const ServerCallContext& context,
+ const sql::ActionBeginTransactionRequest& request) override {
+ return sql::ActionBeginTransactionResult{kTransactionId};
+ }
+
+ arrow::Result CancelQuery(
+ const ServerCallContext& context,
+ const sql::ActionCancelQueryRequest& request) override {
+ ARROW_RETURN_NOT_OK(AssertEq(1, request.info->endpoints().size(),
+ "Expected 1 endpoint for CancelQuery"));
+ const FlightEndpoint& endpoint = request.info->endpoints()[0];
+ ARROW_ASSIGN_OR_RAISE(auto ticket,
+ sql::StatementQueryTicket::Deserialize(endpoint.ticket.ticket));
+ ARROW_RETURN_NOT_OK(AssertEq("PLAN HANDLE", ticket.statement_handle,
+ "Unexpected ticket in CancelQuery"));
+ return sql::CancelResult::kCancelled;
+ }
+
+ Status EndSavepoint(const ServerCallContext& context,
+ const sql::ActionEndSavepointRequest& request) override {
+ switch (request.action) {
+ case sql::ActionEndSavepointRequest::kRelease:
+ case sql::ActionEndSavepointRequest::kRollback:
+ ARROW_RETURN_NOT_OK(
+ AssertEq(kSavepointId, request.savepoint_id,
+ "Unexpected savepoint ID in EndSavepoint"));
+ break;
+ default:
+ return Status::Invalid("Unknown action ", static_cast(request.action));
+ }
+ return Status::OK();
+ }
+
+ Status EndTransaction(const ServerCallContext& context,
+ const sql::ActionEndTransactionRequest& request) override {
+ switch (request.action) {
+ case sql::ActionEndTransactionRequest::kCommit:
+ case sql::ActionEndTransactionRequest::kRollback:
+ ARROW_RETURN_NOT_OK(
+ AssertEq(kTransactionId, request.transaction_id,
+ "Unexpected transaction ID in EndTransaction"));
+ break;
+ default:
+ return Status::Invalid("Unknown action ", static_cast(request.action));
+ }
+ return Status::OK();
+ }
+
private:
arrow::Result> GetFlightInfoForCommand(
const FlightDescriptor& descriptor, const std::shared_ptr& schema) {
@@ -615,6 +869,7 @@ class FlightSqlScenarioServer : public sql::FlightSqlServerBase {
/// implementations. This should ensure that RPC objects are being built and parsed
/// correctly for multiple languages and that the Arrow schemas are returned as expected.
class FlightSqlScenario : public Scenario {
+ public:
Status MakeServer(std::unique_ptr* server,
FlightServerOptions* options) override {
server->reset(new FlightSqlScenarioServer());
@@ -785,10 +1040,290 @@ class FlightSqlScenario : public Scenario {
AssertEq(kUpdatePreparedStatementExpectedRows, updated_rows,
"Wrong number of updated rows for prepared statement ExecuteUpdate"));
ARROW_RETURN_NOT_OK(update_prepared_statement->Close());
+ return Status::OK();
+ }
+};
+
+/// \brief Integration test scenario for validating the Substrait and
+/// transaction extensions to Flight SQL.
+class FlightSqlExtensionScenario : public FlightSqlScenario {
+ public:
+ Status RunClient(std::unique_ptr client) override {
+ sql::FlightSqlClient sql_client(std::move(client));
+ Status status;
+ if (!(status = ValidateMetadataRetrieval(&sql_client)).ok()) {
+ return status.WithMessage("MetadataRetrieval failed: ", status.message());
+ }
+ if (!(status = ValidateStatementExecution(&sql_client)).ok()) {
+ return status.WithMessage("StatementExecution failed: ", status.message());
+ }
+ if (!(status = ValidatePreparedStatementExecution(&sql_client)).ok()) {
+ return status.WithMessage("PreparedStatementExecution failed: ", status.message());
+ }
+ if (!(status = ValidateTransactions(&sql_client)).ok()) {
+ return status.WithMessage("Transactions failed: ", status.message());
+ }
+ return Status::OK();
+ }
+
+ Status ValidateMetadataRetrieval(sql::FlightSqlClient* sql_client) {
+ std::unique_ptr info;
+ std::vector sql_info = {
+ sql::SqlInfoOptions::FLIGHT_SQL_SERVER_SQL,
+ sql::SqlInfoOptions::FLIGHT_SQL_SERVER_SUBSTRAIT,
+ sql::SqlInfoOptions::FLIGHT_SQL_SERVER_SUBSTRAIT_MIN_VERSION,
+ sql::SqlInfoOptions::FLIGHT_SQL_SERVER_SUBSTRAIT_MAX_VERSION,
+ sql::SqlInfoOptions::FLIGHT_SQL_SERVER_TRANSACTION,
+ sql::SqlInfoOptions::FLIGHT_SQL_SERVER_CANCEL,
+ sql::SqlInfoOptions::FLIGHT_SQL_SERVER_STATEMENT_TIMEOUT,
+ sql::SqlInfoOptions::FLIGHT_SQL_SERVER_TRANSACTION_TIMEOUT,
+ };
+ ARROW_ASSIGN_OR_RAISE(info, sql_client->GetSqlInfo({}, sql_info));
+ ARROW_ASSIGN_OR_RAISE(auto reader,
+ sql_client->DoGet({}, info->endpoints()[0].ticket));
+
+ ARROW_ASSIGN_OR_RAISE(auto actual_schema, reader->GetSchema());
+ if (!sql::SqlSchema::GetSqlInfoSchema()->Equals(*actual_schema,
+ /*check_metadata=*/true)) {
+ return Status::Invalid("Schemas did not match. Expected:\n",
+ *sql::SqlSchema::GetSqlInfoSchema(), "\nActual:\n",
+ *actual_schema);
+ }
+
+ sql::SqlInfoResultMap info_values;
+ while (true) {
+ ARROW_ASSIGN_OR_RAISE(auto chunk, reader->Next());
+ if (!chunk.data) break;
+
+ const auto& info_name = checked_cast(*chunk.data->column(0));
+ const auto& value = checked_cast(*chunk.data->column(1));
+
+ for (int64_t i = 0; i < chunk.data->num_rows(); i++) {
+ const uint32_t code = info_name.Value(i);
+ if (info_values.find(code) != info_values.end()) {
+ return Status::Invalid("Duplicate SqlInfo value ", code);
+ }
+ switch (value.type_code(i)) {
+ case 0: { // string
+ std::string slot = checked_cast(*value.field(0))
+ .GetString(value.value_offset(i));
+ info_values[code] = sql::SqlInfoResult(std::move(slot));
+ break;
+ }
+ case 1: { // bool
+ bool slot = checked_cast(*value.field(1))
+ .Value(value.value_offset(i));
+ info_values[code] = sql::SqlInfoResult(slot);
+ break;
+ }
+ case 2: { // int64_t
+ int64_t slot = checked_cast(*value.field(2))
+ .Value(value.value_offset(i));
+ info_values[code] = sql::SqlInfoResult(slot);
+ break;
+ }
+ case 3: { // int32_t
+ int32_t slot = checked_cast(*value.field(3))
+ .Value(value.value_offset(i));
+ info_values[code] = sql::SqlInfoResult(slot);
+ break;
+ }
+ default:
+ return Status::NotImplemented("Decoding SqlInfoResult of type code ",
+ value.type_code(i));
+ }
+ }
+ }
+
+ ARROW_RETURN_NOT_OK(AssertUnprintableEq(
+ info_values[sql::SqlInfoOptions::FLIGHT_SQL_SERVER_SQL],
+ sql::SqlInfoResult(false), "FLIGHT_SQL_SERVER_SQL did not match"));
+ ARROW_RETURN_NOT_OK(AssertUnprintableEq(
+ info_values[sql::SqlInfoOptions::FLIGHT_SQL_SERVER_SUBSTRAIT],
+ sql::SqlInfoResult(true), "FLIGHT_SQL_SERVER_SUBSTRAIT did not match"));
+ ARROW_RETURN_NOT_OK(AssertUnprintableEq(
+ info_values[sql::SqlInfoOptions::FLIGHT_SQL_SERVER_SUBSTRAIT_MIN_VERSION],
+ sql::SqlInfoResult(std::string("min_version")),
+ "FLIGHT_SQL_SERVER_SUBSTRAIT_MIN_VERSION did not match"));
+ ARROW_RETURN_NOT_OK(AssertUnprintableEq(
+ info_values[sql::SqlInfoOptions::FLIGHT_SQL_SERVER_SUBSTRAIT_MAX_VERSION],
+ sql::SqlInfoResult(std::string("max_version")),
+ "FLIGHT_SQL_SERVER_SUBSTRAIT_MAX_VERSION did not match"));
+ ARROW_RETURN_NOT_OK(AssertUnprintableEq(
+ info_values[sql::SqlInfoOptions::FLIGHT_SQL_SERVER_TRANSACTION],
+ sql::SqlInfoResult(sql::SqlInfoOptions::SqlSupportedTransaction::
+ SQL_SUPPORTED_TRANSACTION_SAVEPOINT),
+ "FLIGHT_SQL_SERVER_TRANSACTION did not match"));
+ ARROW_RETURN_NOT_OK(AssertUnprintableEq(
+ info_values[sql::SqlInfoOptions::FLIGHT_SQL_SERVER_CANCEL],
+ sql::SqlInfoResult(true), "FLIGHT_SQL_SERVER_CANCEL did not match"));
+ ARROW_RETURN_NOT_OK(AssertUnprintableEq(
+ info_values[sql::SqlInfoOptions::FLIGHT_SQL_SERVER_STATEMENT_TIMEOUT],
+ sql::SqlInfoResult(int32_t(42)),
+ "FLIGHT_SQL_SERVER_STATEMENT_TIMEOUT did not match"));
+ ARROW_RETURN_NOT_OK(AssertUnprintableEq(
+ info_values[sql::SqlInfoOptions::FLIGHT_SQL_SERVER_TRANSACTION_TIMEOUT],
+ sql::SqlInfoResult(int32_t(7)),
+ "FLIGHT_SQL_SERVER_TRANSACTION_TIMEOUT did not match"));
+
+ return Status::OK();
+ }
+
+ Status ValidateStatementExecution(sql::FlightSqlClient* sql_client) {
+ ARROW_ASSIGN_OR_RAISE(std::unique_ptr info,
+ sql_client->ExecuteSubstrait({}, kSubstraitPlan));
+ ARROW_RETURN_NOT_OK(Validate(GetQuerySchema(), *info, sql_client));
+
+ ARROW_ASSIGN_OR_RAISE(std::unique_ptr schema,
+ sql_client->GetExecuteSubstraitSchema({}, kSubstraitPlan));
+ ARROW_RETURN_NOT_OK(ValidateSchema(GetQuerySchema(), *schema));
+
+ ARROW_ASSIGN_OR_RAISE(info, sql_client->ExecuteSubstrait({}, kSubstraitPlan));
+ ARROW_ASSIGN_OR_RAISE(sql::CancelResult cancel_result,
+ sql_client->CancelQuery({}, *info));
+ ARROW_RETURN_NOT_OK(
+ AssertEq(sql::CancelResult::kCancelled, cancel_result, "Wrong cancel result"));
+
+ ARROW_ASSIGN_OR_RAISE(const int64_t updated_rows,
+ sql_client->ExecuteSubstraitUpdate({}, kSubstraitPlan));
+ ARROW_RETURN_NOT_OK(
+ AssertEq(kUpdateStatementExpectedRows, updated_rows,
+ "Wrong number of updated rows for ExecuteSubstraitUpdate"));
+
+ return Status::OK();
+ }
+
+ Status ValidatePreparedStatementExecution(sql::FlightSqlClient* sql_client) {
+ auto parameters =
+ RecordBatch::Make(GetQuerySchema(), 1, {ArrayFromJSON(int64(), "[1]")});
+
+ ARROW_ASSIGN_OR_RAISE(
+ std::shared_ptr substrait_prepared_statement,
+ sql_client->PrepareSubstrait({}, kSubstraitPlan));
+ ARROW_RETURN_NOT_OK(substrait_prepared_statement->SetParameters(parameters));
+ ARROW_ASSIGN_OR_RAISE(std::unique_ptr info,
+ substrait_prepared_statement->Execute());
+ ARROW_RETURN_NOT_OK(Validate(GetQuerySchema(), *info, sql_client));
+ ARROW_ASSIGN_OR_RAISE(std::unique_ptr schema,
+ substrait_prepared_statement->GetSchema({}));
+ ARROW_RETURN_NOT_OK(ValidateSchema(GetQuerySchema(), *schema));
+ ARROW_RETURN_NOT_OK(substrait_prepared_statement->Close());
+
+ ARROW_ASSIGN_OR_RAISE(
+ std::shared_ptr update_substrait_prepared_statement,
+ sql_client->PrepareSubstrait({}, kSubstraitPlan));
+ ARROW_ASSIGN_OR_RAISE(const int64_t updated_rows,
+ update_substrait_prepared_statement->ExecuteUpdate());
+ ARROW_RETURN_NOT_OK(
+ AssertEq(kUpdatePreparedStatementExpectedRows, updated_rows,
+ "Wrong number of updated rows for prepared statement ExecuteUpdate"));
+ ARROW_RETURN_NOT_OK(update_substrait_prepared_statement->Close());
+
+ return Status::OK();
+ }
+
+ Status ValidateTransactions(sql::FlightSqlClient* sql_client) {
+ ARROW_ASSIGN_OR_RAISE(sql::Transaction transaction, sql_client->BeginTransaction({}));
+ ARROW_RETURN_NOT_OK(AssertEq(
+ kTransactionId, transaction.transaction_id(), "Wrong transaction ID"));
+
+ ARROW_ASSIGN_OR_RAISE(sql::Savepoint savepoint,
+ sql_client->BeginSavepoint({}, transaction, kSavepointName));
+ ARROW_RETURN_NOT_OK(AssertEq(kSavepointId, savepoint.savepoint_id(),
+ "Wrong savepoint ID"));
+
+ ARROW_ASSIGN_OR_RAISE(std::unique_ptr info,
+ sql_client->Execute({}, kSelectStatement, transaction));
+ ARROW_RETURN_NOT_OK(Validate(GetQueryWithTransactionSchema(), *info, sql_client));
+
+ ARROW_ASSIGN_OR_RAISE(info,
+ sql_client->ExecuteSubstrait({}, kSubstraitPlan, transaction));
+ ARROW_RETURN_NOT_OK(Validate(GetQueryWithTransactionSchema(), *info, sql_client));
+
+ ARROW_ASSIGN_OR_RAISE(
+ std::unique_ptr schema,
+ sql_client->GetExecuteSchema({}, kSelectStatement, transaction));
+ ARROW_RETURN_NOT_OK(ValidateSchema(GetQueryWithTransactionSchema(), *schema));
+
+ ARROW_ASSIGN_OR_RAISE(
+ schema, sql_client->GetExecuteSubstraitSchema({}, kSubstraitPlan, transaction));
+ ARROW_RETURN_NOT_OK(ValidateSchema(GetQueryWithTransactionSchema(), *schema));
+
+ ARROW_ASSIGN_OR_RAISE(int64_t updated_rows,
+ sql_client->ExecuteUpdate({}, "UPDATE STATEMENT", transaction));
+ ARROW_RETURN_NOT_OK(
+ AssertEq(kUpdateStatementWithTransactionExpectedRows, updated_rows,
+ "Wrong number of updated rows for ExecuteUpdate with transaction"));
+ ARROW_ASSIGN_OR_RAISE(updated_rows, sql_client->ExecuteSubstraitUpdate(
+ {}, kSubstraitPlan, transaction));
+ ARROW_RETURN_NOT_OK(AssertEq(
+ kUpdateStatementWithTransactionExpectedRows, updated_rows,
+ "Wrong number of updated rows for ExecuteSubstraitUpdate with transaction"));
+
+ auto parameters =
+ RecordBatch::Make(GetQuerySchema(), 1, {ArrayFromJSON(int64(), "[1]")});
+
+ ARROW_ASSIGN_OR_RAISE(
+ std::shared_ptr select_prepared_statement,
+ sql_client->Prepare({}, "SELECT PREPARED STATEMENT", transaction));
+ ARROW_RETURN_NOT_OK(select_prepared_statement->SetParameters(parameters));
+ ARROW_ASSIGN_OR_RAISE(info, select_prepared_statement->Execute());
+ ARROW_RETURN_NOT_OK(Validate(GetQueryWithTransactionSchema(), *info, sql_client));
+ ARROW_ASSIGN_OR_RAISE(schema, select_prepared_statement->GetSchema({}));
+ ARROW_RETURN_NOT_OK(ValidateSchema(GetQueryWithTransactionSchema(), *schema));
+ ARROW_RETURN_NOT_OK(select_prepared_statement->Close());
+
+ ARROW_ASSIGN_OR_RAISE(
+ std::shared_ptr substrait_prepared_statement,
+ sql_client->PrepareSubstrait({}, kSubstraitPlan, transaction));
+ ARROW_RETURN_NOT_OK(substrait_prepared_statement->SetParameters(parameters));
+ ARROW_ASSIGN_OR_RAISE(info, substrait_prepared_statement->Execute());
+ ARROW_RETURN_NOT_OK(Validate(GetQueryWithTransactionSchema(), *info, sql_client));
+ ARROW_ASSIGN_OR_RAISE(schema, substrait_prepared_statement->GetSchema({}));
+ ARROW_RETURN_NOT_OK(ValidateSchema(GetQueryWithTransactionSchema(), *schema));
+ ARROW_RETURN_NOT_OK(substrait_prepared_statement->Close());
+
+ ARROW_ASSIGN_OR_RAISE(
+ std::shared_ptr update_prepared_statement,
+ sql_client->Prepare({}, "UPDATE PREPARED STATEMENT", transaction));
+ ARROW_ASSIGN_OR_RAISE(updated_rows, update_prepared_statement->ExecuteUpdate());
+ ARROW_RETURN_NOT_OK(AssertEq(kUpdatePreparedStatementWithTransactionExpectedRows,
+ updated_rows,
+ "Wrong number of updated rows for prepared statement "
+ "ExecuteUpdate with transaction"));
+ ARROW_RETURN_NOT_OK(update_prepared_statement->Close());
+
+ ARROW_ASSIGN_OR_RAISE(
+ std::shared_ptr update_substrait_prepared_statement,
+ sql_client->PrepareSubstrait({}, kSubstraitPlan, transaction));
+ ARROW_ASSIGN_OR_RAISE(updated_rows,
+ update_substrait_prepared_statement->ExecuteUpdate());
+ ARROW_RETURN_NOT_OK(AssertEq(kUpdatePreparedStatementWithTransactionExpectedRows,
+ updated_rows,
+ "Wrong number of updated rows for prepared statement "
+ "ExecuteUpdate with transaction"));
+ ARROW_RETURN_NOT_OK(update_substrait_prepared_statement->Close());
+
+ ARROW_RETURN_NOT_OK(sql_client->Rollback({}, savepoint));
+
+ ARROW_ASSIGN_OR_RAISE(sql::Savepoint savepoint2,
+ sql_client->BeginSavepoint({}, transaction, kSavepointName));
+ ARROW_RETURN_NOT_OK(AssertEq(kSavepointId, savepoint.savepoint_id(),
+ "Wrong savepoint ID"));
+ ARROW_RETURN_NOT_OK(sql_client->Release({}, savepoint));
+
+ ARROW_RETURN_NOT_OK(sql_client->Commit({}, transaction));
+
+ ARROW_ASSIGN_OR_RAISE(sql::Transaction transaction2,
+ sql_client->BeginTransaction({}));
+ ARROW_RETURN_NOT_OK(AssertEq(
+ kTransactionId, transaction.transaction_id(), "Wrong transaction ID"));
+ ARROW_RETURN_NOT_OK(sql_client->Rollback({}, transaction2));
return Status::OK();
}
};
+} // namespace
Status GetScenario(const std::string& scenario_name, std::shared_ptr* out) {
if (scenario_name == "auth:basic_proto") {
@@ -800,6 +1335,9 @@ Status GetScenario(const std::string& scenario_name, std::shared_ptr*
} else if (scenario_name == "flight_sql") {
*out = std::make_shared();
return Status::OK();
+ } else if (scenario_name == "flight_sql:extension") {
+ *out = std::make_shared();
+ return Status::OK();
}
return Status::KeyError("Scenario not found: ", scenario_name);
}
diff --git a/cpp/src/arrow/flight/server.cc b/cpp/src/arrow/flight/server.cc
index 1a3b52910c0ad..e9736b0615e44 100644
--- a/cpp/src/arrow/flight/server.cc
+++ b/cpp/src/arrow/flight/server.cc
@@ -353,7 +353,9 @@ RecordBatchStream::RecordBatchStream(const std::shared_ptr& r
impl_.reset(new RecordBatchStreamImpl(reader, options));
}
-RecordBatchStream::~RecordBatchStream() {}
+RecordBatchStream::~RecordBatchStream() {
+ ARROW_WARN_NOT_OK(impl_->Close(), "Failed to close FlightDataStream");
+}
Status RecordBatchStream::Close() { return impl_->Close(); }
diff --git a/cpp/src/arrow/flight/sql/CMakeLists.txt b/cpp/src/arrow/flight/sql/CMakeLists.txt
index f7312de23a972..14503069dd004 100644
--- a/cpp/src/arrow/flight/sql/CMakeLists.txt
+++ b/cpp/src/arrow/flight/sql/CMakeLists.txt
@@ -89,7 +89,11 @@ if(ARROW_BUILD_TESTS OR ARROW_BUILD_EXAMPLES)
example/sqlite_statement_batch_reader.cc
example/sqlite_server.cc
example/sqlite_tables_schema_batch_reader.cc)
+
set(ARROW_FLIGHT_SQL_TEST_SRCS server_test.cc)
+ set(ARROW_FLIGHT_SQL_TEST_LIBS ${SQLite3_LIBRARIES})
+ set(ARROW_FLIGHT_SQL_ACERO_SRCS example/acero_server.cc)
+
if(NOT MSVC AND NOT MINGW)
# ARROW-16902: getting Protobuf generated code to have all the
# proper dllexport/dllimport declarations is difficult, since
@@ -98,13 +102,34 @@ if(ARROW_BUILD_TESTS OR ARROW_BUILD_EXAMPLES)
list(APPEND ARROW_FLIGHT_SQL_TEST_SRCS client_test.cc)
endif()
+ if(ARROW_COMPUTE
+ AND ARROW_PARQUET
+ AND ARROW_SUBSTRAIT)
+ list(APPEND ARROW_FLIGHT_SQL_TEST_SRCS ${ARROW_FLIGHT_SQL_ACERO_SRCS} acero_test.cc)
+ if(ARROW_FLIGHT_TEST_LINKAGE STREQUAL "static")
+ list(APPEND ARROW_FLIGHT_SQL_TEST_LIBS arrow_substrait_static)
+ else()
+ list(APPEND ARROW_FLIGHT_SQL_TEST_LIBS arrow_substrait_shared)
+ endif()
+
+ if(ARROW_BUILD_EXAMPLES)
+ add_executable(acero-flight-sql-server ${ARROW_FLIGHT_SQL_ACERO_SRCS}
+ example/acero_main.cc)
+ target_link_libraries(acero-flight-sql-server
+ PRIVATE ${ARROW_FLIGHT_SQL_TEST_LINK_LIBS}
+ ${ARROW_FLIGHT_SQL_TEST_LIBS} ${GFLAGS_LIBRARIES})
+ endif()
+ endif()
+
add_arrow_test(flight_sql_test
SOURCES
${ARROW_FLIGHT_SQL_TEST_SRCS}
${ARROW_FLIGHT_SQL_TEST_SERVER_SRCS}
STATIC_LINK_LIBS
${ARROW_FLIGHT_SQL_TEST_LINK_LIBS}
- ${SQLite3_LIBRARIES}
+ ${ARROW_FLIGHT_SQL_TEST_LIBS}
+ EXTRA_INCLUDES
+ "${CMAKE_CURRENT_BINARY_DIR}/../"
LABELS
"arrow_flight_sql")
diff --git a/cpp/src/arrow/flight/sql/acero_test.cc b/cpp/src/arrow/flight/sql/acero_test.cc
new file mode 100644
index 0000000000000..fd3c52e74f39e
--- /dev/null
+++ b/cpp/src/arrow/flight/sql/acero_test.cc
@@ -0,0 +1,239 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+/// Integration test using the Acero backend
+
+#include
+#include
+
+#include
+#include
+
+#include "arrow/array.h"
+#include "arrow/engine/substrait/util.h"
+#include "arrow/flight/server.h"
+#include "arrow/flight/sql/client.h"
+#include "arrow/flight/sql/example/acero_server.h"
+#include "arrow/flight/sql/types.h"
+#include "arrow/flight/types.h"
+#include "arrow/stl_iterator.h"
+#include "arrow/table.h"
+#include "arrow/testing/gtest_util.h"
+#include "arrow/type_fwd.h"
+#include "arrow/util/checked_cast.h"
+
+namespace arrow {
+namespace flight {
+namespace sql {
+
+using arrow::internal::checked_cast;
+
+class TestAcero : public ::testing::Test {
+ public:
+ void SetUp() override {
+ ASSERT_OK_AND_ASSIGN(auto location, Location::ForGrpcTcp("localhost", 0));
+ flight::FlightServerOptions options(location);
+
+ ASSERT_OK_AND_ASSIGN(server_, acero_example::MakeAceroServer());
+ ASSERT_OK(server_->Init(options));
+
+ ASSERT_OK_AND_ASSIGN(auto client, FlightClient::Connect(server_->location()));
+ client_.reset(new FlightSqlClient(std::move(client)));
+ }
+
+ void TearDown() override {
+ ASSERT_OK(client_->Close());
+ ASSERT_OK(server_->Shutdown());
+ }
+
+ protected:
+ std::unique_ptr client_;
+ std::unique_ptr server_;
+};
+
+arrow::Result> MakeSubstraitPlan() {
+ ARROW_ASSIGN_OR_RAISE(std::string dir_string,
+ arrow::internal::GetEnvVar("PARQUET_TEST_DATA"));
+ ARROW_ASSIGN_OR_RAISE(auto dir,
+ arrow::internal::PlatformFilename::FromString(dir_string));
+ ARROW_ASSIGN_OR_RAISE(auto filename, dir.Join("binary.parquet"));
+ std::string uri = std::string("file://") + filename.ToString();
+
+ // TODO(ARROW-17229): we should use a RootRel here
+ std::string json_plan = R"({
+ "relations": [
+ {
+ "rel": {
+ "read": {
+ "base_schema": {
+ "struct": {
+ "types": [
+ {"binary": {}}
+ ]
+ },
+ "names": [
+ "foo"
+ ]
+ },
+ "local_files": {
+ "items": [
+ {
+ "uri_file": "URI_PLACEHOLDER",
+ "parquet": {}
+ }
+ ]
+ }
+ }
+ }
+ }
+ ]
+})";
+ std::string uri_placeholder = "URI_PLACEHOLDER";
+ json_plan.replace(json_plan.find(uri_placeholder), uri_placeholder.size(), uri);
+ return engine::SerializeJsonPlan(json_plan);
+}
+
+TEST_F(TestAcero, GetSqlInfo) {
+ FlightCallOptions call_options;
+ std::vector sql_info_codes = {
+ SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_SUBSTRAIT,
+ SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_TRANSACTION,
+ };
+ ASSERT_OK_AND_ASSIGN(auto flight_info,
+ client_->GetSqlInfo(call_options, sql_info_codes));
+ ASSERT_OK_AND_ASSIGN(auto reader,
+ client_->DoGet(call_options, flight_info->endpoints()[0].ticket));
+ ASSERT_OK_AND_ASSIGN(auto results, reader->ToTable());
+ ASSERT_OK_AND_ASSIGN(auto batch, results->CombineChunksToBatch());
+ ASSERT_EQ(2, results->num_rows());
+ std::vector> info;
+ const auto& ids = checked_cast(*batch->column(0));
+ const auto& values = checked_cast(*batch->column(1));
+ for (int64_t i = 0; i < batch->num_rows(); i++) {
+ ASSERT_OK_AND_ASSIGN(auto scalar, values.GetScalar(i));
+ if (ids.Value(i) == SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_SUBSTRAIT) {
+ ASSERT_EQ(*checked_cast(*scalar).value,
+ BooleanScalar(true));
+ } else if (ids.Value(i) == SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_TRANSACTION) {
+ ASSERT_EQ(
+ *checked_cast(*scalar).value,
+ Int32Scalar(
+ SqlInfoOptions::SqlSupportedTransaction::SQL_SUPPORTED_TRANSACTION_NONE));
+ } else {
+ FAIL() << "Unexpected info value: " << ids.Value(i);
+ }
+ }
+}
+
+TEST_F(TestAcero, Scan) {
+#ifdef _WIN32
+ GTEST_SKIP() << "ARROW-16392: Substrait File URI not supported for Windows";
+#endif
+
+ FlightCallOptions call_options;
+ ASSERT_OK_AND_ASSIGN(auto serialized_plan, MakeSubstraitPlan());
+
+ SubstraitPlan plan{serialized_plan->ToString(), /*version=*/"0.6.0"};
+ ASSERT_OK_AND_ASSIGN(std::unique_ptr info,
+ client_->ExecuteSubstrait(call_options, plan));
+ ipc::DictionaryMemo memo;
+ ASSERT_OK_AND_ASSIGN(auto schema, info->GetSchema(&memo));
+ // TODO(ARROW-17229): the scanner "special" fields are still included, strip them
+ // manually
+ auto fixed_schema = arrow::schema({schema->fields()[0]});
+ ASSERT_NO_FATAL_FAILURE(
+ AssertSchemaEqual(fixed_schema, arrow::schema({field("foo", binary())})));
+
+ ASSERT_EQ(1, info->endpoints().size());
+ ASSERT_EQ(0, info->endpoints()[0].locations.size());
+ ASSERT_OK_AND_ASSIGN(auto reader,
+ client_->DoGet(call_options, info->endpoints()[0].ticket));
+ ASSERT_OK_AND_ASSIGN(auto reader_schema, reader->GetSchema());
+ ASSERT_NO_FATAL_FAILURE(AssertSchemaEqual(schema, reader_schema));
+ ASSERT_OK_AND_ASSIGN(auto table, reader->ToTable());
+ ASSERT_GT(table->num_rows(), 0);
+}
+
+TEST_F(TestAcero, Update) {
+#ifdef _WIN32
+ GTEST_SKIP() << "ARROW-16392: Substrait File URI not supported for Windows";
+#endif
+
+ FlightCallOptions call_options;
+ ASSERT_OK_AND_ASSIGN(auto serialized_plan, MakeSubstraitPlan());
+ SubstraitPlan plan{serialized_plan->ToString(), /*version=*/"0.6.0"};
+ EXPECT_RAISES_WITH_MESSAGE_THAT(NotImplemented,
+ ::testing::HasSubstr("Updates are unsupported"),
+ client_->ExecuteSubstraitUpdate(call_options, plan));
+}
+
+TEST_F(TestAcero, Prepare) {
+#ifdef _WIN32
+ GTEST_SKIP() << "ARROW-16392: Substrait File URI not supported for Windows";
+#endif
+
+ FlightCallOptions call_options;
+ ASSERT_OK_AND_ASSIGN(auto serialized_plan, MakeSubstraitPlan());
+ SubstraitPlan plan{serialized_plan->ToString(), /*version=*/"0.6.0"};
+ ASSERT_OK_AND_ASSIGN(auto prepared_statement,
+ client_->PrepareSubstrait(call_options, plan));
+ ASSERT_NE(prepared_statement->dataset_schema(), nullptr);
+ ASSERT_EQ(prepared_statement->parameter_schema(), nullptr);
+
+ auto fixed_schema = arrow::schema({prepared_statement->dataset_schema()->fields()[0]});
+ ASSERT_NO_FATAL_FAILURE(
+ AssertSchemaEqual(fixed_schema, arrow::schema({field("foo", binary())})));
+
+ EXPECT_RAISES_WITH_MESSAGE_THAT(NotImplemented,
+ ::testing::HasSubstr("Updates are unsupported"),
+ prepared_statement->ExecuteUpdate());
+
+ ASSERT_OK_AND_ASSIGN(std::unique_ptr info, prepared_statement->Execute());
+ ASSERT_EQ(1, info->endpoints().size());
+ ASSERT_EQ(0, info->endpoints()[0].locations.size());
+ ASSERT_OK_AND_ASSIGN(auto reader,
+ client_->DoGet(call_options, info->endpoints()[0].ticket));
+ ASSERT_OK_AND_ASSIGN(auto reader_schema, reader->GetSchema());
+ ASSERT_NO_FATAL_FAILURE(
+ AssertSchemaEqual(prepared_statement->dataset_schema(), reader_schema));
+ ASSERT_OK_AND_ASSIGN(auto table, reader->ToTable());
+ ASSERT_GT(table->num_rows(), 0);
+
+ ASSERT_OK(prepared_statement->Close());
+}
+
+TEST_F(TestAcero, Transactions) {
+#ifdef _WIN32
+ GTEST_SKIP() << "ARROW-16392: Substrait File URI not supported for Windows";
+#endif
+
+ FlightCallOptions call_options;
+ ASSERT_OK_AND_ASSIGN(auto serialized_plan, MakeSubstraitPlan());
+ Transaction handle("fake-id");
+ SubstraitPlan plan{serialized_plan->ToString(), /*version=*/"0.6.0"};
+
+ EXPECT_RAISES_WITH_MESSAGE_THAT(NotImplemented,
+ ::testing::HasSubstr("Transactions are unsupported"),
+ client_->ExecuteSubstrait(call_options, plan, handle));
+ EXPECT_RAISES_WITH_MESSAGE_THAT(NotImplemented,
+ ::testing::HasSubstr("Transactions are unsupported"),
+ client_->PrepareSubstrait(call_options, plan, handle));
+}
+
+} // namespace sql
+} // namespace flight
+} // namespace arrow
diff --git a/cpp/src/arrow/flight/sql/client.cc b/cpp/src/arrow/flight/sql/client.cc
index e299b7ceb11d4..521cf9e8cd694 100644
--- a/cpp/src/arrow/flight/sql/client.cc
+++ b/cpp/src/arrow/flight/sql/client.cc
@@ -66,8 +66,65 @@ arrow::Result> GetSchemaForCommand(
GetFlightDescriptorForCommand(command));
return client->GetSchema(options, descriptor);
}
+
+::arrow::Result PackAction(const std::string& action_type,
+ const google::protobuf::Message& message) {
+ google::protobuf::Any any;
+ if (!any.PackFrom(message)) {
+ return Status::SerializationError("Could not pack ", message.GetTypeName(),
+ " into Any");
+ }
+
+ std::string buffer;
+ if (!any.SerializeToString(&buffer)) {
+ return Status::SerializationError("Could not serialize packed ",
+ message.GetTypeName());
+ }
+
+ Action action;
+ action.type = action_type;
+ action.body = Buffer::FromString(std::move(buffer));
+ return action;
+}
+
+void SetPlan(const SubstraitPlan& plan, flight_sql_pb::SubstraitPlan* pb_plan) {
+ pb_plan->set_plan(plan.plan);
+ pb_plan->set_version(plan.version);
+}
+
+Status ReadResult(ResultStream* results, google::protobuf::Message* message) {
+ ARROW_ASSIGN_OR_RAISE(auto result, results->Next());
+ if (!result) {
+ return Status::IOError("Server did not return a result for ", message->GetTypeName());
+ }
+
+ google::protobuf::Any container;
+ if (!container.ParseFromArray(result->body->data(),
+ static_cast(result->body->size()))) {
+ return Status::IOError("Unable to parse Any (expecting ", message->GetTypeName(),
+ ")");
+ }
+ if (!container.UnpackTo(message)) {
+ return Status::IOError("Unable to unpack Any (expecting ", message->GetTypeName(),
+ ")");
+ }
+ return Status::OK();
+}
+
+Status DrainResultStream(ResultStream* results) {
+ while (true) {
+ ARROW_ASSIGN_OR_RAISE(auto result, results->Next());
+ if (!result) break;
+ }
+ return Status::OK();
+}
} // namespace
+const Transaction& no_transaction() {
+ static Transaction kInvalidTransaction("");
+ return kInvalidTransaction;
+}
+
FlightSqlClient::FlightSqlClient(std::shared_ptr client)
: impl_(std::move(client)) {}
@@ -90,25 +147,59 @@ PreparedStatement::~PreparedStatement() {
}
arrow::Result> FlightSqlClient::Execute(
- const FlightCallOptions& options, const std::string& query) {
+ const FlightCallOptions& options, const std::string& query,
+ const Transaction& transaction) {
flight_sql_pb::CommandStatementQuery command;
command.set_query(query);
+ if (transaction.is_valid()) {
+ command.set_transaction_id(transaction.transaction_id());
+ }
return GetFlightInfoForCommand(this, options, command);
}
arrow::Result> FlightSqlClient::GetExecuteSchema(
- const FlightCallOptions& options, const std::string& query) {
+ const FlightCallOptions& options, const std::string& query,
+ const Transaction& transaction) {
flight_sql_pb::CommandStatementQuery command;
command.set_query(query);
+ if (transaction.is_valid()) {
+ command.set_transaction_id(transaction.transaction_id());
+ }
+ return GetSchemaForCommand(this, options, command);
+}
+arrow::Result> FlightSqlClient::ExecuteSubstrait(
+ const FlightCallOptions& options, const SubstraitPlan& plan,
+ const Transaction& transaction) {
+ flight_sql_pb::CommandStatementSubstraitPlan command;
+ SetPlan(plan, command.mutable_plan());
+ if (transaction.is_valid()) {
+ command.set_transaction_id(transaction.transaction_id());
+ }
+
+ return GetFlightInfoForCommand(this, options, command);
+}
+
+arrow::Result> FlightSqlClient::GetExecuteSubstraitSchema(
+ const FlightCallOptions& options, const SubstraitPlan& plan,
+ const Transaction& transaction) {
+ flight_sql_pb::CommandStatementSubstraitPlan command;
+ SetPlan(plan, command.mutable_plan());
+ if (transaction.is_valid()) {
+ command.set_transaction_id(transaction.transaction_id());
+ }
return GetSchemaForCommand(this, options, command);
}
arrow::Result FlightSqlClient::ExecuteUpdate(const FlightCallOptions& options,
- const std::string& query) {
+ const std::string& query,
+ const Transaction& transaction) {
flight_sql_pb::CommandStatementUpdate command;
command.set_query(query);
+ if (transaction.is_valid()) {
+ command.set_transaction_id(transaction.transaction_id());
+ }
ARROW_ASSIGN_OR_RAISE(FlightDescriptor descriptor,
GetFlightDescriptorForCommand(command));
@@ -119,14 +210,41 @@ arrow::Result FlightSqlClient::ExecuteUpdate(const FlightCallOptions& o
ARROW_RETURN_NOT_OK(DoPut(options, descriptor, arrow::schema({}), &writer, &reader));
std::shared_ptr metadata;
-
ARROW_RETURN_NOT_OK(reader->ReadMetadata(&metadata));
+ ARROW_RETURN_NOT_OK(writer->Close());
+
+ flight_sql_pb::DoPutUpdateResult result;
+ if (!result.ParseFromArray(metadata->data(), static_cast(metadata->size()))) {
+ return Status::Invalid("Unable to parse DoPutUpdateResult");
+ }
+
+ return result.record_count();
+}
+
+arrow::Result FlightSqlClient::ExecuteSubstraitUpdate(
+ const FlightCallOptions& options, const SubstraitPlan& plan,
+ const Transaction& transaction) {
+ flight_sql_pb::CommandStatementSubstraitPlan command;
+ SetPlan(plan, command.mutable_plan());
+ if (transaction.is_valid()) {
+ command.set_transaction_id(transaction.transaction_id());
+ }
+
+ ARROW_ASSIGN_OR_RAISE(FlightDescriptor descriptor,
+ GetFlightDescriptorForCommand(command));
+
+ std::unique_ptr writer;
+ std::unique_ptr reader;
+
+ ARROW_RETURN_NOT_OK(DoPut(options, descriptor, arrow::schema({}), &writer, &reader));
- flight_sql_pb::DoPutUpdateResult doPutUpdateResult;
+ std::shared_ptr metadata;
+ ARROW_RETURN_NOT_OK(reader->ReadMetadata(&metadata));
+ ARROW_RETURN_NOT_OK(writer->Close());
flight_sql_pb::DoPutUpdateResult result;
if (!result.ParseFromArray(metadata->data(), static_cast(metadata->size()))) {
- return Status::Invalid("Unable to parse DoPutUpdateResult object.");
+ return Status::Invalid("Unable to parse DoPutUpdateResult");
}
return result.record_count();
@@ -357,35 +475,41 @@ arrow::Result> FlightSqlClient::DoGet(
}
arrow::Result> FlightSqlClient::Prepare(
- const FlightCallOptions& options, const std::string& query) {
- google::protobuf::Any command;
+ const FlightCallOptions& options, const std::string& query,
+ const Transaction& transaction) {
flight_sql_pb::ActionCreatePreparedStatementRequest request;
request.set_query(query);
- command.PackFrom(request);
-
- Action action;
- action.type = "CreatePreparedStatement";
- action.body = Buffer::FromString(command.SerializeAsString());
+ if (transaction.is_valid()) {
+ request.set_transaction_id(transaction.transaction_id());
+ }
std::unique_ptr results;
-
+ ARROW_ASSIGN_OR_RAISE(auto action, PackAction("CreatePreparedStatement", request));
ARROW_RETURN_NOT_OK(DoAction(options, action, &results));
- ARROW_ASSIGN_OR_RAISE(std::unique_ptr result, results->Next());
-
- google::protobuf::Any prepared_result;
+ return PreparedStatement::ParseResponse(this, std::move(results));
+}
- std::shared_ptr message = std::move(result->body);
- if (!prepared_result.ParseFromArray(message->data(),
- static_cast(message->size()))) {
- return Status::Invalid("Unable to parse packed ActionCreatePreparedStatementResult");
+arrow::Result> FlightSqlClient::PrepareSubstrait(
+ const FlightCallOptions& options, const SubstraitPlan& plan,
+ const Transaction& transaction) {
+ flight_sql_pb::ActionCreatePreparedSubstraitPlanRequest request;
+ SetPlan(plan, request.mutable_plan());
+ if (transaction.is_valid()) {
+ request.set_transaction_id(transaction.transaction_id());
}
- flight_sql_pb::ActionCreatePreparedStatementResult prepared_statement_result;
+ std::unique_ptr results;
+ ARROW_ASSIGN_OR_RAISE(auto action, PackAction("CreatePreparedSubstraitPlan", request));
+ ARROW_RETURN_NOT_OK(DoAction(options, action, &results));
- if (!prepared_result.UnpackTo(&prepared_statement_result)) {
- return Status::Invalid("Unable to unpack ActionCreatePreparedStatementResult");
- }
+ return PreparedStatement::ParseResponse(this, std::move(results));
+}
+
+arrow::Result> PreparedStatement::ParseResponse(
+ FlightSqlClient* client, std::unique_ptr results) {
+ flight_sql_pb::ActionCreatePreparedStatementResult prepared_statement_result;
+ ARROW_RETURN_NOT_OK(ReadResult(results.get(), &prepared_statement_result));
const std::string& serialized_dataset_schema =
prepared_statement_result.dataset_schema();
@@ -407,14 +531,14 @@ arrow::Result> FlightSqlClient::Prepare(
}
auto handle = prepared_statement_result.prepared_statement_handle();
- return std::make_shared(this, handle, dataset_schema,
+ return std::make_shared(client, handle, dataset_schema,
parameter_schema);
}
arrow::Result> PreparedStatement::Execute(
const FlightCallOptions& options) {
if (is_closed_) {
- return Status::Invalid("Statement already closed.");
+ return Status::Invalid("Statement with handle '", handle_, "' already closed");
}
flight_sql_pb::CommandPreparedStatementQuery command;
@@ -433,6 +557,7 @@ arrow::Result> PreparedStatement::Execute(
// Wait for the server to ack the result
std::shared_ptr buffer;
ARROW_RETURN_NOT_OK(reader->ReadMetadata(&buffer));
+ ARROW_RETURN_NOT_OK(writer->Close());
}
ARROW_ASSIGN_OR_RAISE(auto flight_info, client_->GetFlightInfo(options, descriptor));
@@ -442,7 +567,7 @@ arrow::Result> PreparedStatement::Execute(
arrow::Result PreparedStatement::ExecuteUpdate(
const FlightCallOptions& options) {
if (is_closed_) {
- return Status::Invalid("Statement already closed.");
+ return Status::Invalid("Statement with handle '", handle_, "' already closed");
}
flight_sql_pb::CommandPreparedStatementUpdate command;
@@ -496,7 +621,7 @@ std::shared_ptr PreparedStatement::parameter_schema() const {
arrow::Result> PreparedStatement::GetSchema(
const FlightCallOptions& options) {
if (is_closed_) {
- return Status::Invalid("Statement already closed");
+ return Status::Invalid("Statement with handle '", handle_, "' already closed");
}
flight_sql_pb::CommandPreparedStatementQuery command;
@@ -508,29 +633,185 @@ arrow::Result> PreparedStatement::GetSchema(
Status PreparedStatement::Close(const FlightCallOptions& options) {
if (is_closed_) {
- return Status::Invalid("Statement already closed.");
+ return Status::Invalid("Statement with handle '", handle_, "' already closed");
}
- google::protobuf::Any command;
+
flight_sql_pb::ActionClosePreparedStatementRequest request;
request.set_prepared_statement_handle(handle_);
- command.PackFrom(request);
+ std::unique_ptr results;
+ ARROW_ASSIGN_OR_RAISE(auto action, PackAction("ClosePreparedStatement", request));
+ ARROW_RETURN_NOT_OK(client_->DoAction(options, action, &results));
+ ARROW_RETURN_NOT_OK(DrainResultStream(results.get()));
- Action action;
- action.type = "ClosePreparedStatement";
- action.body = Buffer::FromString(command.SerializeAsString());
+ is_closed_ = true;
+ return Status::OK();
+}
+
+::arrow::Result FlightSqlClient::BeginTransaction(
+ const FlightCallOptions& options) {
+ flight_sql_pb::ActionBeginTransactionRequest request;
std::unique_ptr results;
+ ARROW_ASSIGN_OR_RAISE(auto action, PackAction("BeginTransaction", request));
+ ARROW_RETURN_NOT_OK(DoAction(options, action, &results));
- ARROW_RETURN_NOT_OK(client_->DoAction(options, action, &results));
+ flight_sql_pb::ActionBeginTransactionResult transaction;
+ ARROW_RETURN_NOT_OK(ReadResult(results.get(), &transaction));
+ if (transaction.transaction_id().empty()) {
+ return Status::Invalid("Server returned an empty transaction ID");
+ }
- is_closed_ = true;
+ ARROW_RETURN_NOT_OK(DrainResultStream(results.get()));
+ return Transaction(transaction.transaction_id());
+}
+
+::arrow::Result FlightSqlClient::BeginSavepoint(
+ const FlightCallOptions& options, const Transaction& transaction,
+ const std::string& name) {
+ flight_sql_pb::ActionBeginSavepointRequest request;
+
+ if (!transaction.is_valid()) {
+ return Status::Invalid("Must provide an active transaction");
+ }
+ request.set_transaction_id(transaction.transaction_id());
+ request.set_name(name);
+
+ std::unique_ptr results;
+ ARROW_ASSIGN_OR_RAISE(auto action, PackAction("BeginSavepoint", request));
+ ARROW_RETURN_NOT_OK(DoAction(options, action, &results));
+
+ flight_sql_pb::ActionBeginSavepointResult savepoint;
+ ARROW_RETURN_NOT_OK(ReadResult(results.get(), &savepoint));
+ if (savepoint.savepoint_id().empty()) {
+ return Status::Invalid("Server returned an empty savepoint ID");
+ }
+
+ ARROW_RETURN_NOT_OK(DrainResultStream(results.get()));
+ return Savepoint(savepoint.savepoint_id());
+}
+
+Status FlightSqlClient::Commit(const FlightCallOptions& options,
+ const Transaction& transaction) {
+ flight_sql_pb::ActionEndTransactionRequest request;
+
+ if (!transaction.is_valid()) {
+ return Status::Invalid("Must provide an active transaction");
+ }
+ request.set_transaction_id(transaction.transaction_id());
+ request.set_action(flight_sql_pb::ActionEndTransactionRequest::END_TRANSACTION_COMMIT);
+
+ std::unique_ptr results;
+ ARROW_ASSIGN_OR_RAISE(auto action, PackAction("EndTransaction", request));
+ ARROW_RETURN_NOT_OK(DoAction(options, action, &results));
+
+ ARROW_RETURN_NOT_OK(DrainResultStream(results.get()));
+ return Status::OK();
+}
+Status FlightSqlClient::Release(const FlightCallOptions& options,
+ const Savepoint& savepoint) {
+ flight_sql_pb::ActionEndSavepointRequest request;
+
+ if (!savepoint.is_valid()) {
+ return Status::Invalid("Must provide an active savepoint");
+ }
+ request.set_savepoint_id(savepoint.savepoint_id());
+ request.set_action(flight_sql_pb::ActionEndSavepointRequest::END_SAVEPOINT_RELEASE);
+
+ std::unique_ptr results;
+ ARROW_ASSIGN_OR_RAISE(auto action, PackAction("EndSavepoint", request));
+ ARROW_RETURN_NOT_OK(DoAction(options, action, &results));
+
+ ARROW_RETURN_NOT_OK(DrainResultStream(results.get()));
+ return Status::OK();
+}
+
+Status FlightSqlClient::Rollback(const FlightCallOptions& options,
+ const Transaction& transaction) {
+ flight_sql_pb::ActionEndTransactionRequest request;
+
+ if (!transaction.is_valid()) {
+ return Status::Invalid("Must provide an active transaction");
+ }
+ request.set_transaction_id(transaction.transaction_id());
+ request.set_action(
+ flight_sql_pb::ActionEndTransactionRequest::END_TRANSACTION_ROLLBACK);
+
+ std::unique_ptr results;
+ ARROW_ASSIGN_OR_RAISE(auto action, PackAction("EndTransaction", request));
+ ARROW_RETURN_NOT_OK(DoAction(options, action, &results));
+
+ ARROW_RETURN_NOT_OK(DrainResultStream(results.get()));
+ return Status::OK();
+}
+
+Status FlightSqlClient::Rollback(const FlightCallOptions& options,
+ const Savepoint& savepoint) {
+ flight_sql_pb::ActionEndSavepointRequest request;
+
+ if (!savepoint.is_valid()) {
+ return Status::Invalid("Must provide an active savepoint");
+ }
+ request.set_savepoint_id(savepoint.savepoint_id());
+ request.set_action(flight_sql_pb::ActionEndSavepointRequest::END_SAVEPOINT_ROLLBACK);
+
+ std::unique_ptr results;
+ ARROW_ASSIGN_OR_RAISE(auto action, PackAction("EndSavepoint", request));
+ ARROW_RETURN_NOT_OK(DoAction(options, action, &results));
+
+ ARROW_RETURN_NOT_OK(DrainResultStream(results.get()));
return Status::OK();
}
+::arrow::Result FlightSqlClient::CancelQuery(
+ const FlightCallOptions& options, const FlightInfo& info) {
+ flight_sql_pb::ActionCancelQueryRequest request;
+ ARROW_ASSIGN_OR_RAISE(auto serialized_info, info.SerializeToString());
+ request.set_info(std::move(serialized_info));
+
+ std::unique_ptr results;
+ ARROW_ASSIGN_OR_RAISE(auto action, PackAction("CancelQuery", request));
+ ARROW_RETURN_NOT_OK(DoAction(options, action, &results));
+
+ flight_sql_pb::ActionCancelQueryResult result;
+ ARROW_RETURN_NOT_OK(ReadResult(results.get(), &result));
+ ARROW_RETURN_NOT_OK(DrainResultStream(results.get()));
+ switch (result.result()) {
+ case flight_sql_pb::ActionCancelQueryResult::CANCEL_RESULT_UNSPECIFIED:
+ return CancelResult::kUnspecified;
+ case flight_sql_pb::ActionCancelQueryResult::CANCEL_RESULT_CANCELLED:
+ return CancelResult::kCancelled;
+ case flight_sql_pb::ActionCancelQueryResult::CANCEL_RESULT_CANCELLING:
+ return CancelResult::kCancelling;
+ case flight_sql_pb::ActionCancelQueryResult::CANCEL_RESULT_NOT_CANCELLABLE:
+ return CancelResult::kNotCancellable;
+ default:
+ break;
+ }
+ return Status::IOError("Server returned unknown result ", result.result());
+}
+
Status FlightSqlClient::Close() { return impl_->Close(); }
+std::ostream& operator<<(std::ostream& os, CancelResult result) {
+ switch (result) {
+ case CancelResult::kUnspecified:
+ os << "CancelResult::kUnspecified";
+ break;
+ case CancelResult::kCancelled:
+ os << "CancelResult::kCancelled";
+ break;
+ case CancelResult::kCancelling:
+ os << "CancelResult::kCancelling";
+ break;
+ case CancelResult::kNotCancellable:
+ os << "CancelResult::kNotCancellable";
+ break;
+ }
+ return os;
+}
+
} // namespace sql
} // namespace flight
} // namespace arrow
diff --git a/cpp/src/arrow/flight/sql/client.h b/cpp/src/arrow/flight/sql/client.h
index 26315e0d234fe..db168847ed66b 100644
--- a/cpp/src/arrow/flight/sql/client.h
+++ b/cpp/src/arrow/flight/sql/client.h
@@ -17,6 +17,7 @@
#pragma once
+#include
#include
#include
@@ -32,6 +33,13 @@ namespace flight {
namespace sql {
class PreparedStatement;
+class Transaction;
+class Savepoint;
+
+/// \brief A default transaction to use when the default behavior
+/// (auto-commit) is desired.
+ARROW_FLIGHT_SQL_EXPORT
+const Transaction& no_transaction();
/// \brief Flight client with Flight SQL semantics.
///
@@ -47,23 +55,51 @@ class ARROW_FLIGHT_SQL_EXPORT FlightSqlClient {
virtual ~FlightSqlClient() = default;
- /// \brief Execute a query on the server.
+ /// \brief Execute a SQL query on the server.
+ /// \param[in] options RPC-layer hints for this call.
+ /// \param[in] query The UTF8-encoded SQL query to be executed.
+ /// \param[in] transaction A transaction to associate this query with.
+ /// \return The FlightInfo describing where to access the dataset.
+ arrow::Result> Execute(
+ const FlightCallOptions& options, const std::string& query,
+ const Transaction& transaction = no_transaction());
+
+ /// \brief Execute a Substrait plan that returns a result set on the server.
/// \param[in] options RPC-layer hints for this call.
- /// \param[in] query The query to be executed in the UTF-8 format.
+ /// \param[in] plan The plan to be executed.
+ /// \param[in] transaction A transaction to associate this query with.
/// \return The FlightInfo describing where to access the dataset.
- arrow::Result> Execute(const FlightCallOptions& options,
- const std::string& query);
+ arrow::Result> ExecuteSubstrait(
+ const FlightCallOptions& options, const SubstraitPlan& plan,
+ const Transaction& transaction = no_transaction());
/// \brief Get the result set schema from the server.
arrow::Result> GetExecuteSchema(
- const FlightCallOptions& options, const std::string& query);
+ const FlightCallOptions& options, const std::string& query,
+ const Transaction& transaction = no_transaction());
+
+ /// \brief Get the result set schema from the server.
+ arrow::Result> GetExecuteSubstraitSchema(
+ const FlightCallOptions& options, const SubstraitPlan& plan,
+ const Transaction& transaction = no_transaction());
/// \brief Execute an update query on the server.
/// \param[in] options RPC-layer hints for this call.
- /// \param[in] query The query to be executed in the UTF-8 format.
+ /// \param[in] query The UTF8-encoded SQL query to be executed.
+ /// \param[in] transaction A transaction to associate this query with.
/// \return The quantity of rows affected by the operation.
arrow::Result ExecuteUpdate(const FlightCallOptions& options,
- const std::string& query);
+ const std::string& query,
+ const Transaction& transaction = no_transaction());
+
+ /// \brief Execute a Substrait plan that does not return a result set on the server.
+ /// \param[in] options RPC-layer hints for this call.
+ /// \param[in] plan The plan to be executed.
+ /// \param[in] transaction A transaction to associate this query with.
+ /// \return The FlightInfo describing where to access the dataset.
+ arrow::Result ExecuteSubstraitUpdate(
+ const FlightCallOptions& options, const SubstraitPlan& plan,
+ const Transaction& transaction = no_transaction());
/// \brief Request a list of catalogs.
/// \param[in] options RPC-layer hints for this call.
@@ -215,9 +251,20 @@ class ARROW_FLIGHT_SQL_EXPORT FlightSqlClient {
/// \brief Create a prepared statement object.
/// \param[in] options RPC-layer hints for this call.
/// \param[in] query The query that will be executed.
+ /// \param[in] transaction A transaction to associate this query with.
/// \return The created prepared statement.
arrow::Result> Prepare(
- const FlightCallOptions& options, const std::string& query);
+ const FlightCallOptions& options, const std::string& query,
+ const Transaction& transaction = no_transaction());
+
+ /// \brief Create a prepared statement object.
+ /// \param[in] options RPC-layer hints for this call.
+ /// \param[in] plan The Substrait plan that will be executed.
+ /// \param[in] transaction A transaction to associate this query with.
+ /// \return The created prepared statement.
+ arrow::Result> PrepareSubstrait(
+ const FlightCallOptions& options, const SubstraitPlan& plan,
+ const Transaction& transaction = no_transaction());
/// \brief Call the underlying Flight client's GetFlightInfo.
virtual arrow::Result> GetFlightInfo(
@@ -231,6 +278,58 @@ class ARROW_FLIGHT_SQL_EXPORT FlightSqlClient {
return impl_->GetSchema(options, descriptor);
}
+ /// \brief Begin a new transaction.
+ ::arrow::Result BeginTransaction(const FlightCallOptions& options);
+
+ /// \brief Create a new savepoint within a transaction.
+ /// \param[in] options RPC-layer hints for this call.
+ /// \param[in] transaction The parent transaction.
+ /// \param[in] name A friendly name for the savepoint.
+ ::arrow::Result BeginSavepoint(const FlightCallOptions& options,
+ const Transaction& transaction,
+ const std::string& name);
+
+ /// \brief Commit a transaction.
+ ///
+ /// After this, the transaction and all associated savepoints will
+ /// be invalidated.
+ ///
+ /// \param[in] options RPC-layer hints for this call.
+ /// \param[in] transaction The transaction.
+ Status Commit(const FlightCallOptions& options, const Transaction& transaction);
+
+ /// \brief Release a savepoint.
+ ///
+ /// After this, the savepoint (and all savepoints created after it) will be invalidated.
+ ///
+ /// \param[in] options RPC-layer hints for this call.
+ /// \param[in] savepoint The savepoint.
+ Status Release(const FlightCallOptions& options, const Savepoint& savepoint);
+
+ /// \brief Rollback a transaction.
+ ///
+ /// After this, the transaction and all associated savepoints will be invalidated.
+ ///
+ /// \param[in] options RPC-layer hints for this call.
+ /// \param[in] transaction The transaction.
+ Status Rollback(const FlightCallOptions& options, const Transaction& transaction);
+
+ /// \brief Rollback a savepoint.
+ ///
+ /// After this, the savepoint will still be valid, but all
+ /// savepoints created after it will be invalidated.
+ ///
+ /// \param[in] options RPC-layer hints for this call.
+ /// \param[in] savepoint The savepoint.
+ Status Rollback(const FlightCallOptions& options, const Savepoint& savepoint);
+
+ /// \brief Explicitly cancel a query.
+ ///
+ /// \param[in] options RPC-layer hints for this call.
+ /// \param[in] info The FlightInfo of the query to cancel.
+ ::arrow::Result CancelQuery(const FlightCallOptions& options,
+ const FlightInfo& info);
+
/// \brief Explicitly shut down and clean up the client.
Status Close();
@@ -278,6 +377,10 @@ class ARROW_FLIGHT_SQL_EXPORT PreparedStatement {
/// errors can't be caught.
~PreparedStatement();
+ /// \brief Create a PreparedStatement by parsing the server response.
+ static arrow::Result> ParseResponse(
+ FlightSqlClient* client, std::unique_ptr results);
+
/// \brief Executes the prepared statement query on the server.
/// \return A FlightInfo object representing the stream(s) to fetch.
arrow::Result> Execute(
@@ -295,8 +398,8 @@ class ARROW_FLIGHT_SQL_EXPORT PreparedStatement {
/// \return The ResultSet schema from the query.
std::shared_ptr dataset_schema() const;
- /// \brief Set a RecordBatch that contains the parameters that will be bind.
- /// \param parameter_binding The parameters that will be bind.
+ /// \brief Set a RecordBatch that contains the parameters that will be bound.
+ /// \param parameter_binding The parameters that will be bound.
/// \return Status.
Status SetParameters(std::shared_ptr parameter_binding);
@@ -305,9 +408,9 @@ class ARROW_FLIGHT_SQL_EXPORT PreparedStatement {
arrow::Result> GetSchema(
const FlightCallOptions& options = {});
- /// \brief Close the prepared statement, so that this PreparedStatement can not used
- /// anymore and server can free up any resources.
- /// \return Status.
+ /// \brief Close the prepared statement so the server can free up any resources.
+ ///
+ /// After this, the prepared statement may not be used anymore.
Status Close(const FlightCallOptions& options = {});
/// \brief Check if the prepared statement is closed.
@@ -323,6 +426,29 @@ class ARROW_FLIGHT_SQL_EXPORT PreparedStatement {
bool is_closed_;
};
+/// \brief A handle for a server-side savepoint.
+class ARROW_FLIGHT_SQL_EXPORT Savepoint {
+ public:
+ explicit Savepoint(std::string savepoint_id) : savepoint_id_(std::move(savepoint_id)) {}
+ const std::string& savepoint_id() const { return savepoint_id_; }
+ bool is_valid() const { return !savepoint_id_.empty(); }
+
+ private:
+ std::string savepoint_id_;
+};
+
+/// \brief A handle for a server-side transaction.
+class ARROW_FLIGHT_SQL_EXPORT Transaction {
+ public:
+ explicit Transaction(std::string transaction_id)
+ : transaction_id_(std::move(transaction_id)) {}
+ const std::string& transaction_id() const { return transaction_id_; }
+ bool is_valid() const { return !transaction_id_.empty(); }
+
+ private:
+ std::string transaction_id_;
+};
+
} // namespace sql
} // namespace flight
} // namespace arrow
diff --git a/cpp/src/arrow/flight/sql/client_test.cc b/cpp/src/arrow/flight/sql/client_test.cc
index acd078a847792..984bf45481682 100644
--- a/cpp/src/arrow/flight/sql/client_test.cc
+++ b/cpp/src/arrow/flight/sql/client_test.cc
@@ -410,6 +410,7 @@ TEST_F(TestFlightSqlClient, TestExecuteUpdate) {
std::unique_ptr* writer,
std::unique_ptr* reader) {
reader->reset(new FlightMetadataReaderMock(&buffer_ptr));
+ writer->reset(new FlightStreamWriterMock());
return Status::OK();
});
diff --git a/cpp/src/arrow/flight/sql/column_metadata.cc b/cpp/src/arrow/flight/sql/column_metadata.cc
index 30ef240105c87..adfe81f17300b 100644
--- a/cpp/src/arrow/flight/sql/column_metadata.cc
+++ b/cpp/src/arrow/flight/sql/column_metadata.cc
@@ -118,25 +118,25 @@ const std::shared_ptr& ColumnMetadata::metadata_m
}
ColumnMetadata::ColumnMetadataBuilder& ColumnMetadata::ColumnMetadataBuilder::CatalogName(
- std::string& catalog_name) {
+ const std::string& catalog_name) {
metadata_map_->Append(ColumnMetadata::kCatalogName, catalog_name);
return *this;
}
ColumnMetadata::ColumnMetadataBuilder& ColumnMetadata::ColumnMetadataBuilder::SchemaName(
- std::string& schema_name) {
+ const std::string& schema_name) {
metadata_map_->Append(ColumnMetadata::kSchemaName, schema_name);
return *this;
}
ColumnMetadata::ColumnMetadataBuilder& ColumnMetadata::ColumnMetadataBuilder::TableName(
- std::string& table_name) {
+ const std::string& table_name) {
metadata_map_->Append(ColumnMetadata::kTableName, table_name);
return *this;
}
ColumnMetadata::ColumnMetadataBuilder& ColumnMetadata::ColumnMetadataBuilder::TypeName(
- std::string& type_name) {
+ const std::string& type_name) {
metadata_map_->Append(ColumnMetadata::kTypeName, type_name);
return *this;
}
diff --git a/cpp/src/arrow/flight/sql/column_metadata.h b/cpp/src/arrow/flight/sql/column_metadata.h
index 15b139ec5806a..0eb53f3e0bbc6 100644
--- a/cpp/src/arrow/flight/sql/column_metadata.h
+++ b/cpp/src/arrow/flight/sql/column_metadata.h
@@ -122,22 +122,22 @@ class ARROW_FLIGHT_SQL_EXPORT ColumnMetadata {
/// \brief Set the catalog name in the KeyValueMetadata object.
/// \param[in] catalog_name The catalog name.
/// \return A ColumnMetadataBuilder.
- ColumnMetadataBuilder& CatalogName(std::string& catalog_name);
+ ColumnMetadataBuilder& CatalogName(const std::string& catalog_name);
/// \brief Set the schema_name in the KeyValueMetadata object.
/// \param[in] schema_name The schema_name.
/// \return A ColumnMetadataBuilder.
- ColumnMetadataBuilder& SchemaName(std::string& schema_name);
+ ColumnMetadataBuilder& SchemaName(const std::string& schema_name);
/// \brief Set the table name in the KeyValueMetadata object.
/// \param[in] table_name The table name.
/// \return A ColumnMetadataBuilder.
- ColumnMetadataBuilder& TableName(std::string& table_name);
+ ColumnMetadataBuilder& TableName(const std::string& table_name);
/// \brief Set the type name in the KeyValueMetadata object.
/// \param[in] type_name The type name.
/// \return A ColumnMetadataBuilder.
- ColumnMetadataBuilder& TypeName(std::string& type_name);
+ ColumnMetadataBuilder& TypeName(const std::string& type_name);
/// \brief Set the precision in the KeyValueMetadata object.
/// \param[in] precision The precision.
diff --git a/cpp/src/arrow/flight/sql/example/acero_main.cc b/cpp/src/arrow/flight/sql/example/acero_main.cc
new file mode 100644
index 0000000000000..111bebcbf0f03
--- /dev/null
+++ b/cpp/src/arrow/flight/sql/example/acero_main.cc
@@ -0,0 +1,70 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Example Flight SQL server backed by Acero.
+
+#include
+
+#include
+#include
+#include
+#include
+
+#include
+
+#include "arrow/flight/sql/example/acero_server.h"
+#include "arrow/status.h"
+#include "arrow/util/logging.h"
+
+namespace flight = arrow::flight;
+namespace sql = arrow::flight::sql;
+
+DEFINE_string(location, "grpc://localhost:12345", "Location to listen on");
+
+arrow::Status RunMain(const std::string& location_str) {
+ ARROW_ASSIGN_OR_RAISE(flight::Location location, flight::Location::Parse(location_str));
+ flight::FlightServerOptions options(location);
+
+ std::unique_ptr server;
+ ARROW_ASSIGN_OR_RAISE(server, sql::acero_example::MakeAceroServer());
+ ARROW_RETURN_NOT_OK(server->Init(options));
+
+ ARROW_RETURN_NOT_OK(server->SetShutdownOnSignals({SIGTERM}));
+
+ ARROW_LOG(INFO) << "Listening on " << location.ToString();
+
+ ARROW_RETURN_NOT_OK(server->Serve());
+ return arrow::Status::OK();
+}
+
+int main(int argc, char** argv) {
+ gflags::ParseCommandLineFlags(&argc, &argv, true);
+
+ arrow::util::ArrowLog::StartArrowLog("acero-flight-sql-server",
+ arrow::util::ArrowLogLevel::ARROW_INFO);
+ arrow::util::ArrowLog::InstallFailureSignalHandler();
+
+ arrow::Status st = RunMain(FLAGS_location);
+
+ arrow::util::ArrowLog::ShutDownArrowLog();
+
+ if (!st.ok()) {
+ std::cerr << st << std::endl;
+ return EXIT_FAILURE;
+ }
+ return EXIT_SUCCESS;
+}
diff --git a/cpp/src/arrow/flight/sql/example/acero_server.cc b/cpp/src/arrow/flight/sql/example/acero_server.cc
new file mode 100644
index 0000000000000..ce1483cb8c3f4
--- /dev/null
+++ b/cpp/src/arrow/flight/sql/example/acero_server.cc
@@ -0,0 +1,306 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/flight/sql/example/acero_server.h"
+
+#include
+#include
+#include
+#include
+
+#include "arrow/engine/substrait/serde.h"
+#include "arrow/flight/sql/types.h"
+#include "arrow/type.h"
+#include "arrow/util/logging.h"
+
+namespace arrow {
+namespace flight {
+namespace sql {
+namespace acero_example {
+
+namespace {
+/// \brief A SinkNodeConsumer that saves the schema as given to it by
+/// the ExecPlan. Used to retrieve the schema of a Substrait plan to
+/// fulfill the Flight SQL API contract.
+class GetSchemaSinkNodeConsumer : public compute::SinkNodeConsumer {
+ public:
+ Status Init(const std::shared_ptr& schema,
+ compute::BackpressureControl*) override {
+ schema_ = schema;
+ return Status::OK();
+ }
+ Status Consume(compute::ExecBatch exec_batch) override { return Status::OK(); }
+ Future<> Finish() override { return Status::OK(); }
+
+ const std::shared_ptr& schema() const { return schema_; }
+
+ private:
+ std::shared_ptr schema_;
+};
+
+/// \brief A SinkNodeConsumer that internally saves batches into a
+/// queue, so that it can be read from a RecordBatchReader. In other
+/// words, this bridges a push-based interface (ExecPlan) to a
+/// pull-based interface (RecordBatchReader).
+class QueuingSinkNodeConsumer : public compute::SinkNodeConsumer {
+ public:
+ QueuingSinkNodeConsumer() : schema_(nullptr), finished_(false) {}
+
+ Status Init(const std::shared_ptr& schema,
+ compute::BackpressureControl*) override {
+ schema_ = schema;
+ return Status::OK();
+ }
+
+ Status Consume(compute::ExecBatch exec_batch) override {
+ {
+ std::lock_guard guard(mutex_);
+ batches_.push_back(std::move(exec_batch));
+ batches_added_.notify_all();
+ }
+
+ return Status::OK();
+ }
+
+ Future<> Finish() override {
+ {
+ std::lock_guard guard(mutex_);
+ finished_ = true;
+ batches_added_.notify_all();
+ }
+
+ return Status::OK();
+ }
+
+ const std::shared_ptr& schema() const { return schema_; }
+
+ arrow::Result> Next() {
+ compute::ExecBatch batch;
+ {
+ std::unique_lock guard(mutex_);
+ batches_added_.wait(guard, [this] { return !batches_.empty() || finished_; });
+
+ if (finished_ && batches_.empty()) {
+ return nullptr;
+ }
+ batch = std::move(batches_.front());
+ batches_.pop_front();
+ }
+
+ return batch.ToRecordBatch(schema_);
+ }
+
+ private:
+ std::mutex mutex_;
+ std::condition_variable batches_added_;
+ std::deque batches_;
+ std::shared_ptr schema_;
+ bool finished_;
+};
+
+/// \brief A RecordBatchReader that pulls from the
+/// QueuingSinkNodeConsumer above, blocking until results are
+/// available as necessary.
+class ConsumerBasedRecordBatchReader : public RecordBatchReader {
+ public:
+ explicit ConsumerBasedRecordBatchReader(
+ std::shared_ptr plan,
+ std::shared_ptr consumer)
+ : plan_(std::move(plan)), consumer_(std::move(consumer)) {}
+
+ std::shared_ptr schema() const override { return consumer_->schema(); }
+
+ Status ReadNext(std::shared_ptr* batch) override {
+ return consumer_->Next().Value(batch);
+ }
+
+ // TODO(ARROW-17242): FlightDataStream needs to call Close()
+ Status Close() override { return plan_->finished().status(); }
+
+ private:
+ std::shared_ptr plan_;
+ std::shared_ptr consumer_;
+};
+
+/// \brief An implementation of a Flight SQL service backed by Acero.
+class AceroFlightSqlServer : public FlightSqlServerBase {
+ public:
+ AceroFlightSqlServer() {
+ RegisterSqlInfo(SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_SUBSTRAIT,
+ SqlInfoResult(true));
+ RegisterSqlInfo(SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_SUBSTRAIT_MIN_VERSION,
+ SqlInfoResult(std::string("0.6.0")));
+ RegisterSqlInfo(SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_SUBSTRAIT_MAX_VERSION,
+ SqlInfoResult(std::string("0.6.0")));
+ RegisterSqlInfo(
+ SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_TRANSACTION,
+ SqlInfoResult(
+ SqlInfoOptions::SqlSupportedTransaction::SQL_SUPPORTED_TRANSACTION_NONE));
+ RegisterSqlInfo(SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_CANCEL,
+ SqlInfoResult(false));
+ }
+
+ arrow::Result> GetFlightInfoSubstraitPlan(
+ const ServerCallContext& context, const StatementSubstraitPlan& command,
+ const FlightDescriptor& descriptor) override {
+ if (!command.transaction_id.empty()) {
+ return Status::NotImplemented("Transactions are unsupported");
+ }
+
+ ARROW_ASSIGN_OR_RAISE(std::shared_ptr output_schema,
+ GetPlanSchema(command.plan.plan));
+
+ ARROW_LOG(INFO) << "GetFlightInfoSubstraitPlan: preparing plan with output schema "
+ << *output_schema;
+
+ return MakeFlightInfo(command.plan.plan, descriptor, *output_schema);
+ }
+
+ arrow::Result> GetFlightInfoPreparedStatement(
+ const ServerCallContext& context, const PreparedStatementQuery& command,
+ const FlightDescriptor& descriptor) override {
+ std::shared_ptr plan;
+ {
+ std::lock_guard guard(mutex_);
+ auto it = prepared_.find(command.prepared_statement_handle);
+ if (it == prepared_.end()) {
+ return Status::KeyError("Prepared statement not found");
+ }
+ plan = it->second;
+ }
+
+ return MakeFlightInfo(plan->ToString(), descriptor, Schema({}));
+ }
+
+ arrow::Result> DoGetStatement(
+ const ServerCallContext& context, const StatementQueryTicket& command) override {
+ // GetFlightInfoSubstraitPlan encodes the plan into the ticket
+ std::shared_ptr serialized_plan =
+ Buffer::FromString(command.statement_handle);
+ std::shared_ptr consumer =
+ std::make_shared();
+ ARROW_ASSIGN_OR_RAISE(std::shared_ptr plan,
+ engine::DeserializePlan(*serialized_plan, consumer));
+
+ ARROW_LOG(INFO) << "DoGetStatement: executing plan " << plan->ToString();
+
+ ARROW_RETURN_NOT_OK(plan->StartProducing());
+
+ auto reader = std::make_shared(std::move(plan),
+ std::move(consumer));
+ return std::unique_ptr(new RecordBatchStream(reader));
+ }
+
+ arrow::Result DoPutCommandSubstraitPlan(
+ const ServerCallContext& context, const StatementSubstraitPlan& command) override {
+ return Status::NotImplemented("Updates are unsupported");
+ }
+
+ Status DoPutPreparedStatementQuery(const ServerCallContext& context,
+ const PreparedStatementQuery& command,
+ FlightMessageReader* reader,
+ FlightMetadataWriter* writer) override {
+ return Status::NotImplemented("NYI");
+ }
+
+ arrow::Result DoPutPreparedStatementUpdate(
+ const ServerCallContext& context, const PreparedStatementUpdate& command,
+ FlightMessageReader* reader) override {
+ return Status::NotImplemented("Updates are unsupported");
+ }
+
+ arrow::Result CreatePreparedSubstraitPlan(
+ const ServerCallContext& context,
+ const ActionCreatePreparedSubstraitPlanRequest& request) override {
+ if (!request.transaction_id.empty()) {
+ return Status::NotImplemented("Transactions are unsupported");
+ }
+ // There's not any real point to precompiling the plan, since the
+ // consumer has to be provided here. So this is effectively the
+ // same as a non-prepared plan.
+ ARROW_ASSIGN_OR_RAISE(std::shared_ptr schema,
+ GetPlanSchema(request.plan.plan));
+
+ std::string handle;
+ {
+ std::lock_guard guard(mutex_);
+ handle = std::to_string(counter_++);
+ prepared_[handle] = Buffer::FromString(request.plan.plan);
+ }
+
+ return ActionCreatePreparedStatementResult{
+ /*dataset_schema=*/std::move(schema),
+ /*parameter_schema=*/nullptr,
+ handle,
+ };
+ }
+
+ Status ClosePreparedStatement(
+ const ServerCallContext& context,
+ const ActionClosePreparedStatementRequest& request) override {
+ std::lock_guard guard(mutex_);
+ prepared_.erase(request.prepared_statement_handle);
+ return Status::OK();
+ }
+
+ private:
+ arrow::Result> GetPlanSchema(
+ const std::string& serialized_plan) {
+ std::shared_ptr plan_buf = Buffer::FromString(serialized_plan);
+ std::shared_ptr consumer =
+ std::make_shared();
+ ARROW_ASSIGN_OR_RAISE(std::shared_ptr plan,
+ engine::DeserializePlan(*plan_buf, consumer));
+ std::shared_ptr output_schema;
+ for (compute::ExecNode* sink : plan->sinks()) {
+ // Force SinkNodeConsumer::Init to be called
+ ARROW_RETURN_NOT_OK(sink->StartProducing());
+ output_schema = consumer->schema();
+ break;
+ }
+ if (!output_schema) {
+ return Status::Invalid("Could not infer output schema");
+ }
+ return output_schema;
+ }
+
+ arrow::Result> MakeFlightInfo(
+ const std::string& plan, const FlightDescriptor& descriptor, const Schema& schema) {
+ ARROW_ASSIGN_OR_RAISE(auto ticket, CreateStatementQueryTicket(plan));
+ std::vector endpoints{
+ FlightEndpoint{Ticket{std::move(ticket)}, /*locations=*/{}}};
+ ARROW_ASSIGN_OR_RAISE(auto info,
+ FlightInfo::Make(schema, descriptor, std::move(endpoints),
+ /*total_records=*/-1, /*total_bytes=*/-1));
+ return std::make_unique(std::move(info));
+ }
+
+ std::mutex mutex_;
+ std::unordered_map> prepared_;
+ int64_t counter_;
+};
+
+} // namespace
+
+arrow::Result> MakeAceroServer() {
+ return std::unique_ptr(new AceroFlightSqlServer());
+}
+
+} // namespace acero_example
+} // namespace sql
+} // namespace flight
+} // namespace arrow
diff --git a/cpp/src/arrow/flight/sql/example/acero_server.h b/cpp/src/arrow/flight/sql/example/acero_server.h
new file mode 100644
index 0000000000000..2e82fd3d3b6a8
--- /dev/null
+++ b/cpp/src/arrow/flight/sql/example/acero_server.h
@@ -0,0 +1,37 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#pragma once
+
+#include
+
+#include "arrow/flight/sql/server.h"
+#include "arrow/flight/sql/visibility.h"
+#include "arrow/result.h"
+
+namespace arrow {
+namespace flight {
+namespace sql {
+namespace acero_example {
+
+/// \brief Make a Flight SQL server backed by the Acero query engine.
+arrow::Result> MakeAceroServer();
+
+} // namespace acero_example
+} // namespace sql
+} // namespace flight
+} // namespace arrow
diff --git a/cpp/src/arrow/flight/sql/example/sqlite_server.cc b/cpp/src/arrow/flight/sql/example/sqlite_server.cc
index 35fa05468ba27..0d0a7c1ea0e01 100644
--- a/cpp/src/arrow/flight/sql/example/sqlite_server.cc
+++ b/cpp/src/arrow/flight/sql/example/sqlite_server.cc
@@ -20,11 +20,12 @@
#include
#include
-#include