Skip to content

Commit

Permalink
Address feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
lidavidm committed Sep 13, 2022
1 parent 1a8af54 commit 274a3d8
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 35 deletions.
17 changes: 9 additions & 8 deletions cpp/src/arrow/flight/integration_tests/test_integration.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,15 @@
#include "arrow/ipc/dictionary.h"
#include "arrow/status.h"
#include "arrow/testing/gtest_util.h"
#include "arrow/util/checked_cast.h"

namespace arrow {
namespace flight {
namespace integration_tests {
namespace {

using arrow::internal::checked_cast;

/// \brief The server for the basic auth integration test.
class AuthBasicProtoServer : public FlightServerBase {
Status DoAction(const ServerCallContext& context, const Action& action,
Expand Down Expand Up @@ -1092,10 +1095,8 @@ class FlightSqlExtensionScenario : public FlightSqlScenario {
ARROW_ASSIGN_OR_RAISE(auto chunk, reader->Next());
if (!chunk.data) break;

const UInt32Array& info_name =
static_cast<const UInt32Array&>(*chunk.data->column(0));
const DenseUnionArray& value =
static_cast<const DenseUnionArray&>(*chunk.data->column(1));
const auto& info_name = checked_cast<const UInt32Array&>(*chunk.data->column(0));
const auto& value = checked_cast<const DenseUnionArray&>(*chunk.data->column(1));

for (int64_t i = 0; i < chunk.data->num_rows(); i++) {
const uint32_t code = info_name.Value(i);
Expand All @@ -1104,25 +1105,25 @@ class FlightSqlExtensionScenario : public FlightSqlScenario {
}
switch (value.type_code(i)) {
case 0: { // string
std::string slot = static_cast<const StringArray&>(*value.field(0))
std::string slot = checked_cast<const StringArray&>(*value.field(0))
.GetString(value.value_offset(i));
info_values[code] = sql::SqlInfoResult(std::move(slot));
break;
}
case 1: { // bool
bool slot = static_cast<const BooleanArray&>(*value.field(1))
bool slot = checked_cast<const BooleanArray&>(*value.field(1))
.Value(value.value_offset(i));
info_values[code] = sql::SqlInfoResult(slot);
break;
}
case 2: { // int64_t
int64_t slot = static_cast<const Int64Array&>(*value.field(2))
int64_t slot = checked_cast<const Int64Array&>(*value.field(2))
.Value(value.value_offset(i));
info_values[code] = sql::SqlInfoResult(slot);
break;
}
case 3: { // int32_t
int32_t slot = static_cast<const Int32Array&>(*value.field(3))
int32_t slot = checked_cast<const Int32Array&>(*value.field(3))
.Value(value.value_offset(i));
info_values[code] = sql::SqlInfoResult(slot);
break;
Expand Down
38 changes: 24 additions & 14 deletions cpp/src/arrow/flight/sql/example/acero_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ namespace sql {
namespace acero_example {

namespace {
/// \brief A SinkNodeConsumer that saves the schema as given to it by
/// the ExecPlan. Used to retrieve the schema of a Substrait plan to
/// fulfill the Flight SQL API contract.
class GetSchemaSinkNodeConsumer : public compute::SinkNodeConsumer {
public:
Status Init(const std::shared_ptr<Schema>& schema,
Expand All @@ -49,6 +52,10 @@ class GetSchemaSinkNodeConsumer : public compute::SinkNodeConsumer {
std::shared_ptr<Schema> schema_;
};

/// \brief A SinkNodeConsumer that internally saves batches into a
/// queue, so that it can be read from a RecordBatchReader. In other
/// words, this bridges a push-based interface (ExecPlan) to a
/// pull-based interface (RecordBatchReader).
class QueuingSinkNodeConsumer : public compute::SinkNodeConsumer {
public:
QueuingSinkNodeConsumer() : schema_(nullptr), finished_(false) {}
Expand Down Expand Up @@ -105,6 +112,9 @@ class QueuingSinkNodeConsumer : public compute::SinkNodeConsumer {
bool finished_;
};

/// \brief A RecordBatchReader that pulls from the
/// QueuingSinkNodeConsumer above, blocking until results are
/// available as necessary.
class ConsumerBasedRecordBatchReader : public RecordBatchReader {
public:
explicit ConsumerBasedRecordBatchReader(
Expand All @@ -126,6 +136,7 @@ class ConsumerBasedRecordBatchReader : public RecordBatchReader {
std::shared_ptr<QueuingSinkNodeConsumer> consumer_;
};

/// \brief An implementation of a Flight SQL service backed by Acero.
class AceroFlightSqlServer : public FlightSqlServerBase {
public:
AceroFlightSqlServer() {
Expand Down Expand Up @@ -156,13 +167,7 @@ class AceroFlightSqlServer : public FlightSqlServerBase {
ARROW_LOG(INFO) << "GetFlightInfoSubstraitPlan: preparing plan with output schema "
<< *output_schema;

ARROW_ASSIGN_OR_RAISE(auto ticket, CreateStatementQueryTicket(command.plan.plan));
std::vector<FlightEndpoint> endpoints{
FlightEndpoint{Ticket{std::move(ticket)}, /*locations=*/{}}};
ARROW_ASSIGN_OR_RAISE(
auto info, FlightInfo::Make(*output_schema, descriptor, std::move(endpoints),
/*total_records=*/-1, /*total_bytes=*/-1));
return std::unique_ptr<FlightInfo>(new FlightInfo(std::move(info)));
return MakeFlightInfo(command.plan.plan, descriptor, *output_schema);
}

arrow::Result<std::unique_ptr<FlightInfo>> GetFlightInfoPreparedStatement(
Expand All @@ -178,13 +183,7 @@ class AceroFlightSqlServer : public FlightSqlServerBase {
plan = it->second;
}

ARROW_ASSIGN_OR_RAISE(auto ticket, CreateStatementQueryTicket(plan->ToString()));
std::vector<FlightEndpoint> endpoints{
FlightEndpoint{Ticket{std::move(ticket)}, /*locations=*/{}}};
ARROW_ASSIGN_OR_RAISE(auto info,
FlightInfo::Make(Schema({}), descriptor, std::move(endpoints),
/*total_records=*/-1, /*total_bytes=*/-1));
return std::unique_ptr<FlightInfo>(new FlightInfo(std::move(info)));
return MakeFlightInfo(plan->ToString(), descriptor, Schema({}));
}

arrow::Result<std::unique_ptr<FlightDataStream>> DoGetStatement(
Expand Down Expand Up @@ -279,6 +278,17 @@ class AceroFlightSqlServer : public FlightSqlServerBase {
return output_schema;
}

arrow::Result<std::unique_ptr<FlightInfo>> MakeFlightInfo(
const std::string& plan, const FlightDescriptor& descriptor, const Schema& schema) {
ARROW_ASSIGN_OR_RAISE(auto ticket, CreateStatementQueryTicket(plan));
std::vector<FlightEndpoint> endpoints{
FlightEndpoint{Ticket{std::move(ticket)}, /*locations=*/{}}};
ARROW_ASSIGN_OR_RAISE(auto info,
FlightInfo::Make(schema, descriptor, std::move(endpoints),
/*total_records=*/-1, /*total_bytes=*/-1));
return std::make_unique<FlightInfo>(std::move(info));
}

std::mutex mutex_;
std::unordered_map<std::string, std::shared_ptr<arrow::Buffer>> prepared_;
int64_t counter_;
Expand Down
36 changes: 23 additions & 13 deletions cpp/src/arrow/flight/sql/example/sqlite_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -228,14 +228,15 @@ int32_t GetSqlTypeFromTypeName(const char* sqlite_type) {
class SQLiteFlightSqlServer::Impl {
private:
sqlite3* db_;
std::string db_uri_;
const std::string db_uri_;
std::mutex mutex_;
std::unordered_map<std::string, std::shared_ptr<SqliteStatement>> prepared_statements_;
std::unordered_map<std::string, sqlite3*> open_transactions_;
std::default_random_engine gen_;

arrow::Result<std::shared_ptr<SqliteStatement>> GetStatementByHandle(
const std::string& handle) {
std::lock_guard<std::mutex> guard(mutex_);
auto search = prepared_statements_.find(handle);
if (search == prepared_statements_.end()) {
return Status::KeyError("Prepared statement not found");
Expand Down Expand Up @@ -433,6 +434,7 @@ class SQLiteFlightSqlServer::Impl {
arrow::Result<ActionCreatePreparedStatementResult> CreatePreparedStatement(
const ServerCallContext& context,
const ActionCreatePreparedStatementRequest& request) {
std::lock_guard<std::mutex> guard(mutex_);
std::shared_ptr<SqliteStatement> statement;
ARROW_ASSIGN_OR_RAISE(statement, SqliteStatement::Create(db_, request.query));
std::string handle = GenerateRandomString();
Expand Down Expand Up @@ -468,6 +470,7 @@ class SQLiteFlightSqlServer::Impl {

Status ClosePreparedStatement(const ServerCallContext& context,
const ActionClosePreparedStatementRequest& request) {
std::lock_guard<std::mutex> guard(mutex_);
const std::string& prepared_statement_handle = request.prepared_statement_handle;

auto search = prepared_statements_.find(prepared_statement_handle);
Expand All @@ -483,6 +486,7 @@ class SQLiteFlightSqlServer::Impl {
arrow::Result<std::unique_ptr<FlightInfo>> GetFlightInfoPreparedStatement(
const ServerCallContext& context, const PreparedStatementQuery& command,
const FlightDescriptor& descriptor) {
std::lock_guard<std::mutex> guard(mutex_);
const std::string& prepared_statement_handle = command.prepared_statement_handle;

auto search = prepared_statements_.find(prepared_statement_handle);
Expand All @@ -499,6 +503,7 @@ class SQLiteFlightSqlServer::Impl {

arrow::Result<std::unique_ptr<FlightDataStream>> DoGetPreparedStatement(
const ServerCallContext& context, const PreparedStatementQuery& command) {
std::lock_guard<std::mutex> guard(mutex_);
const std::string& prepared_statement_handle = command.prepared_statement_handle;

auto search = prepared_statements_.find(prepared_statement_handle);
Expand Down Expand Up @@ -708,26 +713,31 @@ class SQLiteFlightSqlServer::Impl {

ARROW_RETURN_NOT_OK(ExecuteSql(new_db, "BEGIN TRANSACTION"));

std::lock_guard<std::mutex> guard(mutex_);
open_transactions_[handle] = new_db;
return ActionBeginTransactionResult{std::move(handle)};
}

Status EndTransaction(const ServerCallContext& context,
const ActionEndTransactionRequest& request) {
std::lock_guard<std::mutex> guard(mutex_);
auto it = open_transactions_.find(request.transaction_id);
if (it == open_transactions_.end()) {
return Status::KeyError("Unknown transaction ID: ", request.transaction_id);
}

Status status;
if (request.action == ActionEndTransactionRequest::kCommit) {
status = ExecuteSql(it->second, "COMMIT");
} else {
status = ExecuteSql(it->second, "ROLLBACK");
sqlite3* transaction = nullptr;
{
std::lock_guard<std::mutex> guard(mutex_);
auto it = open_transactions_.find(request.transaction_id);
if (it == open_transactions_.end()) {
return Status::KeyError("Unknown transaction ID: ", request.transaction_id);
}

if (request.action == ActionEndTransactionRequest::kCommit) {
status = ExecuteSql(it->second, "COMMIT");
} else {
status = ExecuteSql(it->second, "ROLLBACK");
}
transaction = it->second;
open_transactions_.erase(it);
}
sqlite3_close(it->second);
open_transactions_.erase(it);
sqlite3_close(transaction);
return status;
}
};
Expand Down

0 comments on commit 274a3d8

Please sign in to comment.