Skip to content

Commit

Permalink
ARROW-5397: [FlightRPC] Add TLS certificates for testing Flight
Browse files Browse the repository at this point in the history
This needs apache/arrow-testing#2.

Author: David Li <li.davidm96@gmail.com>

Closes #4510 from lihalite/flight-tls and squashes the following commits:

5eff724 <David Li> Don't set wait_for_ready in Flight
776b9d0 <David Li> Add tests for TLS in Flight (C++, Python)
9d2efa2 <David Li> Allow multiple TLS certificates in Flight
  • Loading branch information
lidavidm authored and pitrou committed Jun 13, 2019
1 parent b21999e commit b7e8ed7
Show file tree
Hide file tree
Showing 14 changed files with 398 additions and 143 deletions.
3 changes: 0 additions & 3 deletions cpp/src/arrow/flight/client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,6 @@ struct ClientRpc {
grpc::ClientContext context;

explicit ClientRpc(const FlightCallOptions& options) {
/// XXX workaround until we have a handshake in Connect
context.set_wait_for_ready(true);

if (options.timeout.count() >= 0) {
std::chrono::system_clock::time_point deadline =
std::chrono::time_point_cast<std::chrono::system_clock::time_point::duration>(
Expand Down
89 changes: 72 additions & 17 deletions cpp/src/arrow/flight/flight-test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -176,29 +176,22 @@ TEST(TestFlight, ConnectUri) {

class TestFlightClient : public ::testing::Test {
public:
// Uncomment these when you want to run the server separately for
// debugging/valgrind/gdb
void SetUp() {
Location location;
std::unique_ptr<FlightServerBase> server = ExampleTestServer();

// void SetUp() {
// port_ = 92358;
// ASSERT_OK(ConnectClient());
// }
// void TearDown() {}
ASSERT_OK(Location::ForGrpcTcp("localhost", GetListenPort(), &location));
FlightServerOptions options(location);
ASSERT_OK(server->Init(options));

void SetUp() {
server_.reset(new TestServer("flight-test-server"));
server_->Start();
port_ = server_->port();
server_.reset(new InProcessTestServer(std::move(server), location));
ASSERT_OK(server_->Start());
ASSERT_OK(ConnectClient());
}

void TearDown() { server_->Stop(); }

Status ConnectClient() {
Location location;
RETURN_NOT_OK(Location::ForGrpcTcp("localhost", port_, &location));
return FlightClient::Connect(location, &client_);
}
Status ConnectClient() { return FlightClient::Connect(server_->location(), &client_); }

template <typename EndpointCheckFunc>
void CheckDoGet(const FlightDescriptor& descr, const BatchVector& expected_batches,
Expand Down Expand Up @@ -236,7 +229,7 @@ class TestFlightClient : public ::testing::Test {
protected:
int port_;
std::unique_ptr<FlightClient> client_;
std::unique_ptr<TestServer> server_;
std::unique_ptr<InProcessTestServer> server_;
};

class AuthTestServer : public FlightServerBase {
Expand All @@ -249,6 +242,16 @@ class AuthTestServer : public FlightServerBase {
}
};

class TlsTestServer : public FlightServerBase {
Status DoAction(const ServerCallContext& context, const Action& action,
std::unique_ptr<ResultStream>* result) override {
std::shared_ptr<Buffer> buf;
RETURN_NOT_OK(Buffer::FromString("Hello, world!", &buf));
*result = std::unique_ptr<ResultStream>(new SimpleResultStream({Result{buf}}));
return Status::OK();
}
};

class DoPutTestServer : public FlightServerBase {
public:
Status DoPut(const ServerCallContext& context,
Expand Down Expand Up @@ -336,6 +339,42 @@ class TestDoPut : public ::testing::Test {
DoPutTestServer* do_put_server_;
};

class TestTls : public ::testing::Test {
public:
void SetUp() {
Location location;
std::unique_ptr<FlightServerBase> server(new TlsTestServer);

ASSERT_OK(Location::ForGrpcTls("localhost", GetListenPort(), &location));
FlightServerOptions options(location);
ASSERT_RAISES(UnknownError, server->Init(options));
ASSERT_OK(ExampleTlsCertificates(&options.tls_certificates));
ASSERT_OK(server->Init(options));

server_.reset(new InProcessTestServer(std::move(server), location));
ASSERT_OK(server_->Start());
ASSERT_OK(ConnectClient());
}

void TearDown() {
if (server_) {
server_->Stop();
}
}

Status ConnectClient() {
auto options = FlightClientOptions();
CertKeyPair root_cert;
RETURN_NOT_OK(ExampleTlsCertificateRoot(&root_cert));
options.tls_root_certs = root_cert.pem_cert;
return FlightClient::Connect(server_->location(), options, &client_);
}

protected:
std::unique_ptr<FlightClient> client_;
std::unique_ptr<InProcessTestServer> server_;
};

TEST_F(TestFlightClient, ListFlights) {
std::unique_ptr<FlightListing> listing;
ASSERT_OK(client_->ListFlights(&listing));
Expand Down Expand Up @@ -620,5 +659,21 @@ TEST_F(TestAuthHandler, CheckPeerIdentity) {
ASSERT_EQ(result->body->ToString(), "user");
}

TEST_F(TestTls, DoAction) {
FlightCallOptions options;
options.timeout = TimeoutDuration{5.0};
Action action;
action.type = "test";
action.body = Buffer::FromString("");
std::unique_ptr<ResultStream> results;
ASSERT_OK(client_->DoAction(options, action, &results));
ASSERT_NE(results, nullptr);

std::unique_ptr<Result> result;
ASSERT_OK(results->Next(&result));
ASSERT_NE(result, nullptr);
ASSERT_EQ(result->body->ToString(), "Hello, world!");
}

} // namespace flight
} // namespace arrow
7 changes: 4 additions & 3 deletions cpp/src/arrow/flight/server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -460,7 +460,7 @@ thread_local std::atomic<FlightServerBase::Impl*>
#endif

FlightServerOptions::FlightServerOptions(const Location& location_)
: location(location_), auth_handler(nullptr) {}
: location(location_), auth_handler(nullptr), tls_certificates() {}

FlightServerBase::FlightServerBase() { impl_.reset(new Impl); }

Expand All @@ -483,8 +483,9 @@ Status FlightServerBase::Init(FlightServerOptions& options) {
std::shared_ptr<grpc::ServerCredentials> creds;
if (scheme == kSchemeGrpcTls) {
grpc::SslServerCredentialsOptions ssl_options;
ssl_options.pem_key_cert_pairs.push_back(
{options.tls_private_key, options.tls_cert_chain});
for (const auto& pair : options.tls_certificates) {
ssl_options.pem_key_cert_pairs.push_back({pair.pem_key, pair.pem_cert});
}
creds = grpc::SslServerCredentials(ssl_options);
} else {
creds = grpc::InsecureServerCredentials();
Expand Down
3 changes: 1 addition & 2 deletions cpp/src/arrow/flight/server.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,7 @@ class ARROW_FLIGHT_EXPORT FlightServerOptions {

Location location;
std::unique_ptr<ServerAuthHandler> auth_handler;
std::string tls_cert_chain;
std::string tls_private_key;
std::vector<CertKeyPair> tls_certificates;
};

/// \brief Skeleton RPC server implementation which can be used to create
Expand Down
109 changes: 4 additions & 105 deletions cpp/src/arrow/flight/test-server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,120 +25,19 @@

#include <gflags/gflags.h>

#include "arrow/buffer.h"
#include "arrow/io/test-common.h"
#include "arrow/record_batch.h"
#include "arrow/util/logging.h"

#include "arrow/flight/server.h"
#include "arrow/flight/server_auth.h"
#include "arrow/flight/test-util.h"
#include "arrow/flight/types.h"
#include "arrow/util/logging.h"

DEFINE_int32(port, 31337, "Server port to listen on");

namespace arrow {
namespace flight {

Status GetBatchForFlight(const Ticket& ticket, std::shared_ptr<RecordBatchReader>* out) {
if (ticket.ticket == "ticket-ints-1") {
BatchVector batches;
RETURN_NOT_OK(ExampleIntBatches(&batches));
*out = std::make_shared<BatchIterator>(batches[0]->schema(), batches);
return Status::OK();
} else if (ticket.ticket == "ticket-dicts-1") {
BatchVector batches;
RETURN_NOT_OK(ExampleDictBatches(&batches));
*out = std::make_shared<BatchIterator>(batches[0]->schema(), batches);
return Status::OK();
} else {
return Status::NotImplemented("no stream implemented for this ticket");
}
}

class FlightTestServer : public FlightServerBase {
Status ListFlights(const ServerCallContext& context, const Criteria* criteria,
std::unique_ptr<FlightListing>* listings) override {
std::vector<FlightInfo> flights = ExampleFlightInfo();
*listings = std::unique_ptr<FlightListing>(new SimpleFlightListing(flights));
return Status::OK();
}

Status GetFlightInfo(const ServerCallContext& context, const FlightDescriptor& request,
std::unique_ptr<FlightInfo>* out) override {
std::vector<FlightInfo> flights = ExampleFlightInfo();

for (const auto& info : flights) {
if (info.descriptor().Equals(request)) {
*out = std::unique_ptr<FlightInfo>(new FlightInfo(info));
return Status::OK();
}
}
return Status::Invalid("Flight not found: ", request.ToString());
}

Status DoGet(const ServerCallContext& context, const Ticket& request,
std::unique_ptr<FlightDataStream>* data_stream) override {
// Test for ARROW-5095
if (request.ticket == "ARROW-5095-fail") {
return Status::UnknownError("Server-side error");
}
if (request.ticket == "ARROW-5095-success") {
return Status::OK();
}

std::shared_ptr<RecordBatchReader> batch_reader;
RETURN_NOT_OK(GetBatchForFlight(request, &batch_reader));

*data_stream = std::unique_ptr<FlightDataStream>(new RecordBatchStream(batch_reader));
return Status::OK();
}

Status RunAction1(const Action& action, std::unique_ptr<ResultStream>* out) {
std::vector<Result> results;
for (int i = 0; i < 3; ++i) {
Result result;
std::string value = action.body->ToString() + "-part" + std::to_string(i);
RETURN_NOT_OK(Buffer::FromString(value, &result.body));
results.push_back(result);
}
*out = std::unique_ptr<ResultStream>(new SimpleResultStream(std::move(results)));
return Status::OK();
}

Status RunAction2(std::unique_ptr<ResultStream>* out) {
// Empty
*out = std::unique_ptr<ResultStream>(new SimpleResultStream({}));
return Status::OK();
}

Status DoAction(const ServerCallContext& context, const Action& action,
std::unique_ptr<ResultStream>* out) override {
if (action.type == "action1") {
return RunAction1(action, out);
} else if (action.type == "action2") {
return RunAction2(out);
} else {
return Status::NotImplemented(action.type);
}
}

Status ListActions(const ServerCallContext& context,
std::vector<ActionType>* out) override {
std::vector<ActionType> actions = ExampleActionTypes();
*out = std::move(actions);
return Status::OK();
}
};

} // namespace flight
} // namespace arrow

std::unique_ptr<arrow::flight::FlightTestServer> g_server;
std::unique_ptr<arrow::flight::FlightServerBase> g_server;

int main(int argc, char** argv) {
gflags::ParseCommandLineFlags(&argc, &argv, true);

g_server.reset(new arrow::flight::FlightTestServer);
g_server = arrow::flight::ExampleTestServer();

arrow::flight::Location location;
ARROW_CHECK_OK(arrow::flight::Location::ForGrpcTcp("0.0.0.0", FLAGS_port, &location));
Expand Down
Loading

0 comments on commit b7e8ed7

Please sign in to comment.