diff --git a/cpp/src/arrow/flight/integration_tests/test_integration.cc b/cpp/src/arrow/flight/integration_tests/test_integration.cc index 104f1e5a60f96..43c16e0b77a6d 100644 --- a/cpp/src/arrow/flight/integration_tests/test_integration.cc +++ b/cpp/src/arrow/flight/integration_tests/test_integration.cc @@ -38,12 +38,15 @@ #include "arrow/ipc/dictionary.h" #include "arrow/status.h" #include "arrow/testing/gtest_util.h" +#include "arrow/util/checked_cast.h" namespace arrow { namespace flight { namespace integration_tests { namespace { +using arrow::internal::checked_cast; + /// \brief The server for the basic auth integration test. class AuthBasicProtoServer : public FlightServerBase { Status DoAction(const ServerCallContext& context, const Action& action, @@ -1092,10 +1095,8 @@ class FlightSqlExtensionScenario : public FlightSqlScenario { ARROW_ASSIGN_OR_RAISE(auto chunk, reader->Next()); if (!chunk.data) break; - const UInt32Array& info_name = - static_cast(*chunk.data->column(0)); - const DenseUnionArray& value = - static_cast(*chunk.data->column(1)); + const auto& info_name = checked_cast(*chunk.data->column(0)); + const auto& value = checked_cast(*chunk.data->column(1)); for (int64_t i = 0; i < chunk.data->num_rows(); i++) { const uint32_t code = info_name.Value(i); @@ -1104,25 +1105,25 @@ class FlightSqlExtensionScenario : public FlightSqlScenario { } switch (value.type_code(i)) { case 0: { // string - std::string slot = static_cast(*value.field(0)) + std::string slot = checked_cast(*value.field(0)) .GetString(value.value_offset(i)); info_values[code] = sql::SqlInfoResult(std::move(slot)); break; } case 1: { // bool - bool slot = static_cast(*value.field(1)) + bool slot = checked_cast(*value.field(1)) .Value(value.value_offset(i)); info_values[code] = sql::SqlInfoResult(slot); break; } case 2: { // int64_t - int64_t slot = static_cast(*value.field(2)) + int64_t slot = checked_cast(*value.field(2)) .Value(value.value_offset(i)); info_values[code] = sql::SqlInfoResult(slot); break; } case 3: { // int32_t - int32_t slot = static_cast(*value.field(3)) + int32_t slot = checked_cast(*value.field(3)) .Value(value.value_offset(i)); info_values[code] = sql::SqlInfoResult(slot); break; diff --git a/cpp/src/arrow/flight/sql/example/acero_server.cc b/cpp/src/arrow/flight/sql/example/acero_server.cc index 03e7352d468ee..ce1483cb8c3f4 100644 --- a/cpp/src/arrow/flight/sql/example/acero_server.cc +++ b/cpp/src/arrow/flight/sql/example/acero_server.cc @@ -33,6 +33,9 @@ namespace sql { namespace acero_example { namespace { +/// \brief A SinkNodeConsumer that saves the schema as given to it by +/// the ExecPlan. Used to retrieve the schema of a Substrait plan to +/// fulfill the Flight SQL API contract. class GetSchemaSinkNodeConsumer : public compute::SinkNodeConsumer { public: Status Init(const std::shared_ptr& schema, @@ -49,6 +52,10 @@ class GetSchemaSinkNodeConsumer : public compute::SinkNodeConsumer { std::shared_ptr schema_; }; +/// \brief A SinkNodeConsumer that internally saves batches into a +/// queue, so that it can be read from a RecordBatchReader. In other +/// words, this bridges a push-based interface (ExecPlan) to a +/// pull-based interface (RecordBatchReader). class QueuingSinkNodeConsumer : public compute::SinkNodeConsumer { public: QueuingSinkNodeConsumer() : schema_(nullptr), finished_(false) {} @@ -105,6 +112,9 @@ class QueuingSinkNodeConsumer : public compute::SinkNodeConsumer { bool finished_; }; +/// \brief A RecordBatchReader that pulls from the +/// QueuingSinkNodeConsumer above, blocking until results are +/// available as necessary. class ConsumerBasedRecordBatchReader : public RecordBatchReader { public: explicit ConsumerBasedRecordBatchReader( @@ -126,6 +136,7 @@ class ConsumerBasedRecordBatchReader : public RecordBatchReader { std::shared_ptr consumer_; }; +/// \brief An implementation of a Flight SQL service backed by Acero. class AceroFlightSqlServer : public FlightSqlServerBase { public: AceroFlightSqlServer() { @@ -156,13 +167,7 @@ class AceroFlightSqlServer : public FlightSqlServerBase { ARROW_LOG(INFO) << "GetFlightInfoSubstraitPlan: preparing plan with output schema " << *output_schema; - ARROW_ASSIGN_OR_RAISE(auto ticket, CreateStatementQueryTicket(command.plan.plan)); - std::vector endpoints{ - FlightEndpoint{Ticket{std::move(ticket)}, /*locations=*/{}}}; - ARROW_ASSIGN_OR_RAISE( - auto info, FlightInfo::Make(*output_schema, descriptor, std::move(endpoints), - /*total_records=*/-1, /*total_bytes=*/-1)); - return std::unique_ptr(new FlightInfo(std::move(info))); + return MakeFlightInfo(command.plan.plan, descriptor, *output_schema); } arrow::Result> GetFlightInfoPreparedStatement( @@ -178,13 +183,7 @@ class AceroFlightSqlServer : public FlightSqlServerBase { plan = it->second; } - ARROW_ASSIGN_OR_RAISE(auto ticket, CreateStatementQueryTicket(plan->ToString())); - std::vector endpoints{ - FlightEndpoint{Ticket{std::move(ticket)}, /*locations=*/{}}}; - ARROW_ASSIGN_OR_RAISE(auto info, - FlightInfo::Make(Schema({}), descriptor, std::move(endpoints), - /*total_records=*/-1, /*total_bytes=*/-1)); - return std::unique_ptr(new FlightInfo(std::move(info))); + return MakeFlightInfo(plan->ToString(), descriptor, Schema({})); } arrow::Result> DoGetStatement( @@ -279,6 +278,17 @@ class AceroFlightSqlServer : public FlightSqlServerBase { return output_schema; } + arrow::Result> MakeFlightInfo( + const std::string& plan, const FlightDescriptor& descriptor, const Schema& schema) { + ARROW_ASSIGN_OR_RAISE(auto ticket, CreateStatementQueryTicket(plan)); + std::vector endpoints{ + FlightEndpoint{Ticket{std::move(ticket)}, /*locations=*/{}}}; + ARROW_ASSIGN_OR_RAISE(auto info, + FlightInfo::Make(schema, descriptor, std::move(endpoints), + /*total_records=*/-1, /*total_bytes=*/-1)); + return std::make_unique(std::move(info)); + } + std::mutex mutex_; std::unordered_map> prepared_; int64_t counter_; diff --git a/cpp/src/arrow/flight/sql/example/sqlite_server.cc b/cpp/src/arrow/flight/sql/example/sqlite_server.cc index e509dbe860107..0d0a7c1ea0e01 100644 --- a/cpp/src/arrow/flight/sql/example/sqlite_server.cc +++ b/cpp/src/arrow/flight/sql/example/sqlite_server.cc @@ -228,7 +228,7 @@ int32_t GetSqlTypeFromTypeName(const char* sqlite_type) { class SQLiteFlightSqlServer::Impl { private: sqlite3* db_; - std::string db_uri_; + const std::string db_uri_; std::mutex mutex_; std::unordered_map> prepared_statements_; std::unordered_map open_transactions_; @@ -236,6 +236,7 @@ class SQLiteFlightSqlServer::Impl { arrow::Result> GetStatementByHandle( const std::string& handle) { + std::lock_guard guard(mutex_); auto search = prepared_statements_.find(handle); if (search == prepared_statements_.end()) { return Status::KeyError("Prepared statement not found"); @@ -433,6 +434,7 @@ class SQLiteFlightSqlServer::Impl { arrow::Result CreatePreparedStatement( const ServerCallContext& context, const ActionCreatePreparedStatementRequest& request) { + std::lock_guard guard(mutex_); std::shared_ptr statement; ARROW_ASSIGN_OR_RAISE(statement, SqliteStatement::Create(db_, request.query)); std::string handle = GenerateRandomString(); @@ -468,6 +470,7 @@ class SQLiteFlightSqlServer::Impl { Status ClosePreparedStatement(const ServerCallContext& context, const ActionClosePreparedStatementRequest& request) { + std::lock_guard guard(mutex_); const std::string& prepared_statement_handle = request.prepared_statement_handle; auto search = prepared_statements_.find(prepared_statement_handle); @@ -483,6 +486,7 @@ class SQLiteFlightSqlServer::Impl { arrow::Result> GetFlightInfoPreparedStatement( const ServerCallContext& context, const PreparedStatementQuery& command, const FlightDescriptor& descriptor) { + std::lock_guard guard(mutex_); const std::string& prepared_statement_handle = command.prepared_statement_handle; auto search = prepared_statements_.find(prepared_statement_handle); @@ -499,6 +503,7 @@ class SQLiteFlightSqlServer::Impl { arrow::Result> DoGetPreparedStatement( const ServerCallContext& context, const PreparedStatementQuery& command) { + std::lock_guard guard(mutex_); const std::string& prepared_statement_handle = command.prepared_statement_handle; auto search = prepared_statements_.find(prepared_statement_handle); @@ -708,26 +713,31 @@ class SQLiteFlightSqlServer::Impl { ARROW_RETURN_NOT_OK(ExecuteSql(new_db, "BEGIN TRANSACTION")); + std::lock_guard guard(mutex_); open_transactions_[handle] = new_db; return ActionBeginTransactionResult{std::move(handle)}; } Status EndTransaction(const ServerCallContext& context, const ActionEndTransactionRequest& request) { - std::lock_guard guard(mutex_); - auto it = open_transactions_.find(request.transaction_id); - if (it == open_transactions_.end()) { - return Status::KeyError("Unknown transaction ID: ", request.transaction_id); - } - Status status; - if (request.action == ActionEndTransactionRequest::kCommit) { - status = ExecuteSql(it->second, "COMMIT"); - } else { - status = ExecuteSql(it->second, "ROLLBACK"); + sqlite3* transaction = nullptr; + { + std::lock_guard guard(mutex_); + auto it = open_transactions_.find(request.transaction_id); + if (it == open_transactions_.end()) { + return Status::KeyError("Unknown transaction ID: ", request.transaction_id); + } + + if (request.action == ActionEndTransactionRequest::kCommit) { + status = ExecuteSql(it->second, "COMMIT"); + } else { + status = ExecuteSql(it->second, "ROLLBACK"); + } + transaction = it->second; + open_transactions_.erase(it); } - sqlite3_close(it->second); - open_transactions_.erase(it); + sqlite3_close(transaction); return status; } };