Skip to content

Commit

Permalink
[C++][Java] Test SqlInfo values as well
Browse files Browse the repository at this point in the history
  • Loading branch information
lidavidm committed Aug 29, 2022
1 parent 9ff2dc1 commit 265f37a
Show file tree
Hide file tree
Showing 7 changed files with 320 additions and 29 deletions.
177 changes: 160 additions & 17 deletions cpp/src/arrow/flight/integration_tests/test_integration.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,17 @@
// under the License.

#include "arrow/flight/integration_tests/test_integration.h"

#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "arrow/array/array_binary.h"
#include "arrow/array/array_nested.h"
#include "arrow/array/array_primitive.h"
#include "arrow/flight/client_middleware.h"
#include "arrow/flight/server_middleware.h"
#include "arrow/flight/sql/client.h"
Expand All @@ -28,12 +39,6 @@
#include "arrow/status.h"
#include "arrow/testing/gtest_util.h"

#include <iostream>
#include <memory>
#include <string>
#include <utility>
#include <vector>

namespace arrow {
namespace flight {
namespace integration_tests {
Expand Down Expand Up @@ -326,11 +331,37 @@ arrow::Status AssertEq(const T& expected, const T& actual, const std::string& me
return Status::OK();
}

template <typename T>
arrow::Status AssertUnprintableEq(const T& expected, const T& actual,
const std::string& message) {
if (expected != actual) {
return Status::Invalid(message);
}
return Status::OK();
}

/// \brief The server used for testing Flight SQL, this implements a static Flight SQL
/// server which only asserts that commands called during integration tests are being
/// parsed correctly and returns the expected schemas to be validated on client.
class FlightSqlScenarioServer : public sql::FlightSqlServerBase {
public:
FlightSqlScenarioServer() : sql::FlightSqlServerBase() {
RegisterSqlInfo(sql::SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_SUBSTRAIT,
sql::SqlInfoResult(true));
RegisterSqlInfo(sql::SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_SUBSTRAIT_MIN_VERSION,
sql::SqlInfoResult("min_version"));
RegisterSqlInfo(sql::SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_SUBSTRAIT_MAX_VERSION,
sql::SqlInfoResult("max_version"));
RegisterSqlInfo(sql::SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_TRANSACTION,
sql::SqlInfoResult(sql::SqlInfoOptions::SqlSupportedTransaction::
SQL_SUPPORTED_TRANSACTION_SAVEPOINT));
RegisterSqlInfo(sql::SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_CANCEL,
sql::SqlInfoResult(true));
RegisterSqlInfo(sql::SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_STATEMENT_TIMEOUT,
sql::SqlInfoResult(int32_t(42)));
RegisterSqlInfo(sql::SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_TRANSACTION_TIMEOUT,
sql::SqlInfoResult(int32_t(7)));
}
arrow::Result<std::unique_ptr<FlightInfo>> GetFlightInfoStatement(
const ServerCallContext& context, const sql::StatementQuery& command,
const FlightDescriptor& descriptor) override {
Expand Down Expand Up @@ -487,21 +518,29 @@ class FlightSqlScenarioServer : public sql::FlightSqlServerBase {
arrow::Result<std::unique_ptr<FlightInfo>> GetFlightInfoSqlInfo(
const ServerCallContext& context, const sql::GetSqlInfo& command,
const FlightDescriptor& descriptor) override {
ARROW_RETURN_NOT_OK(AssertEq<int64_t>(2, command.info.size(),
"Wrong number of SqlInfo values passed"));
ARROW_RETURN_NOT_OK(
AssertEq<int32_t>(sql::SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_NAME,
command.info[0], "Unexpected SqlInfo passed"));
ARROW_RETURN_NOT_OK(
AssertEq<int32_t>(sql::SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_READ_ONLY,
command.info[1], "Unexpected SqlInfo passed"));

return GetFlightInfoForCommand(descriptor, sql::SqlSchema::GetSqlInfoSchema());
if (command.info.size() == 2) {
// Integration test for the protocol messages
ARROW_RETURN_NOT_OK(
AssertEq<int32_t>(sql::SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_NAME,
command.info[0], "Unexpected SqlInfo passed"));
ARROW_RETURN_NOT_OK(
AssertEq<int32_t>(sql::SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_READ_ONLY,
command.info[1], "Unexpected SqlInfo passed"));

return GetFlightInfoForCommand(descriptor, sql::SqlSchema::GetSqlInfoSchema());
}
// Integration test for the values themselves
return sql::FlightSqlServerBase::GetFlightInfoSqlInfo(context, command, descriptor);
}

arrow::Result<std::unique_ptr<FlightDataStream>> DoGetSqlInfo(
const ServerCallContext& context, const sql::GetSqlInfo& command) override {
return DoGetForTestCase(sql::SqlSchema::GetSqlInfoSchema());
if (command.info.size() == 2) {
// Integration test for the protocol messages
return DoGetForTestCase(sql::SqlSchema::GetSqlInfoSchema());
}
// Integration test for the values themselves
return sql::FlightSqlServerBase::DoGetSqlInfo(context, command);
}

arrow::Result<std::unique_ptr<FlightInfo>> GetFlightInfoSchemas(
Expand Down Expand Up @@ -1007,6 +1046,9 @@ class FlightSqlExtensionScenario : public FlightSqlScenario {
Status RunClient(std::unique_ptr<FlightClient> client) override {
sql::FlightSqlClient sql_client(std::move(client));
Status status;
if (!(status = ValidateMetadataRetrieval(&sql_client)).ok()) {
return status.WithMessage("MetadataRetrieval failed: ", status.message());
}
if (!(status = ValidateStatementExecution(&sql_client)).ok()) {
return status.WithMessage("StatementExecution failed: ", status.message());
}
Expand All @@ -1019,6 +1061,107 @@ class FlightSqlExtensionScenario : public FlightSqlScenario {
return Status::OK();
}

Status ValidateMetadataRetrieval(sql::FlightSqlClient* sql_client) {
std::unique_ptr<FlightInfo> info;
std::vector<int32_t> sql_info = {
sql::SqlInfoOptions::FLIGHT_SQL_SERVER_SUBSTRAIT,
sql::SqlInfoOptions::FLIGHT_SQL_SERVER_SUBSTRAIT_MIN_VERSION,
sql::SqlInfoOptions::FLIGHT_SQL_SERVER_SUBSTRAIT_MAX_VERSION,
sql::SqlInfoOptions::FLIGHT_SQL_SERVER_TRANSACTION,
sql::SqlInfoOptions::FLIGHT_SQL_SERVER_CANCEL,
sql::SqlInfoOptions::FLIGHT_SQL_SERVER_STATEMENT_TIMEOUT,
sql::SqlInfoOptions::FLIGHT_SQL_SERVER_TRANSACTION_TIMEOUT,
};
ARROW_ASSIGN_OR_RAISE(info, sql_client->GetSqlInfo({}, sql_info));
ARROW_ASSIGN_OR_RAISE(auto reader,
sql_client->DoGet({}, info->endpoints()[0].ticket));

ARROW_ASSIGN_OR_RAISE(auto actual_schema, reader->GetSchema());
if (!sql::SqlSchema::GetSqlInfoSchema()->Equals(*actual_schema,
/*check_metadata=*/true)) {
return Status::Invalid("Schemas did not match. Expected:\n",
*sql::SqlSchema::GetSqlInfoSchema(), "\nActual:\n",
*actual_schema);
}

sql::SqlInfoResultMap info_values;
while (true) {
ARROW_ASSIGN_OR_RAISE(auto chunk, reader->Next());
if (!chunk.data) break;

const UInt32Array& info_name =
static_cast<const UInt32Array&>(*chunk.data->column(0));
const DenseUnionArray& value =
static_cast<const DenseUnionArray&>(*chunk.data->column(1));

for (int64_t i = 0; i < chunk.data->num_rows(); i++) {
const uint32_t code = info_name.Value(i);
if (info_values.find(code) != info_values.end()) {
return Status::Invalid("Duplicate SqlInfo value ", code);
}
switch (value.type_code(i)) {
case 0: { // string
std::string slot = static_cast<const StringArray&>(*value.field(0))
.GetString(value.value_offset(i));
info_values[code] = sql::SqlInfoResult(std::move(slot));
break;
}
case 1: { // bool
bool slot = static_cast<const BooleanArray&>(*value.field(1))
.Value(value.value_offset(i));
info_values[code] = sql::SqlInfoResult(slot);
break;
}
case 2: { // int64_t
int64_t slot = static_cast<const Int64Array&>(*value.field(2))
.Value(value.value_offset(i));
info_values[code] = sql::SqlInfoResult(slot);
break;
}
case 3: { // int32_t
int32_t slot = static_cast<const Int32Array&>(*value.field(3))
.Value(value.value_offset(i));
info_values[code] = sql::SqlInfoResult(slot);
break;
}
default:
return Status::NotImplemented("Decoding SqlInfoResult of type code ",
value.type_code(i));
}
}
}

ARROW_RETURN_NOT_OK(AssertUnprintableEq(
info_values[sql::SqlInfoOptions::FLIGHT_SQL_SERVER_SUBSTRAIT],
sql::SqlInfoResult(true), "FLIGHT_SQL_SERVER_SUBSTRAIT did not match"));
ARROW_RETURN_NOT_OK(AssertUnprintableEq(
info_values[sql::SqlInfoOptions::FLIGHT_SQL_SERVER_SUBSTRAIT_MIN_VERSION],
sql::SqlInfoResult("min_version"),
"FLIGHT_SQL_SERVER_SUBSTRAIT_MIN_VERSION did not match"));
ARROW_RETURN_NOT_OK(AssertUnprintableEq(
info_values[sql::SqlInfoOptions::FLIGHT_SQL_SERVER_SUBSTRAIT_MAX_VERSION],
sql::SqlInfoResult("max_version"),
"FLIGHT_SQL_SERVER_SUBSTRAIT_MAX_VERSION did not match"));
ARROW_RETURN_NOT_OK(AssertUnprintableEq(
info_values[sql::SqlInfoOptions::FLIGHT_SQL_SERVER_TRANSACTION],
sql::SqlInfoResult(sql::SqlInfoOptions::SqlSupportedTransaction::
SQL_SUPPORTED_TRANSACTION_SAVEPOINT),
"FLIGHT_SQL_SERVER_TRANSACTION did not match"));
ARROW_RETURN_NOT_OK(AssertUnprintableEq(
info_values[sql::SqlInfoOptions::FLIGHT_SQL_SERVER_CANCEL],
sql::SqlInfoResult(true), "FLIGHT_SQL_SERVER_CANCEL did not match"));
ARROW_RETURN_NOT_OK(AssertUnprintableEq(
info_values[sql::SqlInfoOptions::FLIGHT_SQL_SERVER_STATEMENT_TIMEOUT],
sql::SqlInfoResult(int32_t(42)),
"FLIGHT_SQL_SERVER_STATEMENT_TIMEOUT did not match"));
ARROW_RETURN_NOT_OK(AssertUnprintableEq(
info_values[sql::SqlInfoOptions::FLIGHT_SQL_SERVER_TRANSACTION_TIMEOUT],
sql::SqlInfoResult(int32_t(7)),
"FLIGHT_SQL_SERVER_TRANSACTION_TIMEOUT did not match"));

return Status::OK();
}

Status ValidateStatementExecution(sql::FlightSqlClient* sql_client) {
ARROW_ASSIGN_OR_RAISE(std::unique_ptr<FlightInfo> info,
sql_client->ExecuteSubstrait({}, kSubstraitPlan));
Expand Down
6 changes: 6 additions & 0 deletions cpp/src/arrow/flight/sql/example/acero_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -131,10 +131,16 @@ class AceroFlightSqlServer : public FlightSqlServerBase {
AceroFlightSqlServer() {
RegisterSqlInfo(SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_SUBSTRAIT,
SqlInfoResult(true));
RegisterSqlInfo(SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_SUBSTRAIT_MIN_VERSION,
SqlInfoResult("0.6.0"));
RegisterSqlInfo(SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_SUBSTRAIT_MAX_VERSION,
SqlInfoResult("0.6.0"));
RegisterSqlInfo(
SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_TRANSACTION,
SqlInfoResult(
SqlInfoOptions::SqlSupportedTransaction::SQL_SUPPORTED_TRANSACTION_NONE));
RegisterSqlInfo(SqlInfoOptions::SqlInfo::FLIGHT_SQL_SERVER_CANCEL,
SqlInfoResult(false));
}

arrow::Result<std::unique_ptr<FlightInfo>> GetFlightInfoSubstraitPlan(
Expand Down
14 changes: 11 additions & 3 deletions cpp/src/arrow/flight/sql/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,14 @@ struct ARROW_FLIGHT_SQL_EXPORT SqlInfoOptions {
/// supports executing Substrait plans.
FLIGHT_SQL_SERVER_SUBSTRAIT = 4,

/// Retrieves a string value indicating the minimum supported
/// Substrait version, or null if Substrait is not supported.
FLIGHT_SQL_SERVER_SUBSTRAIT_MIN_VERSION = 5,

/// Retrieves a string value indicating the maximum supported
/// Substrait version, or null if Substrait is not supported.
FLIGHT_SQL_SERVER_SUBSTRAIT_MAX_VERSION = 6,

/// Retrieves an int32 indicating whether the Flight SQL Server
/// supports the BeginTransaction, EndTransaction, BeginSavepoint,
/// and EndSavepoint actions.
Expand All @@ -85,11 +93,11 @@ struct ARROW_FLIGHT_SQL_EXPORT SqlInfoOptions {
/// whether the server implements the Flight SQL API endpoints.
///
/// The possible values are listed in `SqlSupportedTransaction`.
FLIGHT_SQL_SERVER_TRANSACTION = 5,
FLIGHT_SQL_SERVER_TRANSACTION = 7,

/// Retrieves a boolean value indicating whether the Flight SQL Server
/// supports explicit query cancellation (the CancelQuery action).
FLIGHT_SQL_SERVER_CANCEL = 6,
FLIGHT_SQL_SERVER_CANCEL = 8,

/// Retrieves an int32 value indicating the timeout (in milliseconds) for
/// prepared statement handles.
Expand All @@ -101,7 +109,7 @@ struct ARROW_FLIGHT_SQL_EXPORT SqlInfoOptions {
/// transactions, since transactions are not tied to a connection.
///
/// If 0, there is no timeout.
FLIGHT_SQL_TRANSACTION_TIMEOUT = 101,
FLIGHT_SQL_SERVER_TRANSACTION_TIMEOUT = 101,

/// @}

Expand Down
16 changes: 14 additions & 2 deletions format/FlightSql.proto
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,18 @@ enum SqlInfo {
*/
FLIGHT_SQL_SERVER_SUBSTRAIT = 4;

/*
* Retrieves a string value indicating the minimum supported Substrait version, or null
* if Substrait is not supported.
*/
FLIGHT_SQL_SERVER_SUBSTRAIT_MIN_VERSION = 5;

/*
* Retrieves a string value indicating the maximum supported Substrait version, or null
* if Substrait is not supported.
*/
FLIGHT_SQL_SERVER_SUBSTRAIT_MAX_VERSION = 6;

/*
* Retrieves an int32 indicating whether the Flight SQL Server supports the
* BeginTransaction/EndTransaction/BeginSavepoint/EndSavepoint actions.
Expand All @@ -106,13 +118,13 @@ enum SqlInfo {
*
* The possible values are listed in `SqlSupportedTransaction`.
*/
FLIGHT_SQL_SERVER_TRANSACTION = 5;
FLIGHT_SQL_SERVER_TRANSACTION = 7;

/*
* Retrieves a boolean value indicating whether the Flight SQL Server supports explicit
* query cancellation (the CancelQuery action).
*/
FLIGHT_SQL_SERVER_CANCEL = 6;
FLIGHT_SQL_SERVER_CANCEL = 8;

/*
* Retrieves an int32 indicating the timeout (in milliseconds) for prepared statement handles.
Expand Down
Loading

0 comments on commit 265f37a

Please sign in to comment.