diff --git a/cpp/src/arrow/flight/client.cc b/cpp/src/arrow/flight/client.cc index 1c927da782d43..2b7c69919763e 100644 --- a/cpp/src/arrow/flight/client.cc +++ b/cpp/src/arrow/flight/client.cc @@ -60,9 +60,6 @@ struct ClientRpc { grpc::ClientContext context; explicit ClientRpc(const FlightCallOptions& options) { - /// XXX workaround until we have a handshake in Connect - context.set_wait_for_ready(true); - if (options.timeout.count() >= 0) { std::chrono::system_clock::time_point deadline = std::chrono::time_point_cast( diff --git a/cpp/src/arrow/flight/flight-test.cc b/cpp/src/arrow/flight/flight-test.cc index cb7e57c85584b..b295878641523 100644 --- a/cpp/src/arrow/flight/flight-test.cc +++ b/cpp/src/arrow/flight/flight-test.cc @@ -176,29 +176,22 @@ TEST(TestFlight, ConnectUri) { class TestFlightClient : public ::testing::Test { public: - // Uncomment these when you want to run the server separately for - // debugging/valgrind/gdb + void SetUp() { + Location location; + std::unique_ptr server = ExampleTestServer(); - // void SetUp() { - // port_ = 92358; - // ASSERT_OK(ConnectClient()); - // } - // void TearDown() {} + ASSERT_OK(Location::ForGrpcTcp("localhost", GetListenPort(), &location)); + FlightServerOptions options(location); + ASSERT_OK(server->Init(options)); - void SetUp() { - server_.reset(new TestServer("flight-test-server")); - server_->Start(); - port_ = server_->port(); + server_.reset(new InProcessTestServer(std::move(server), location)); + ASSERT_OK(server_->Start()); ASSERT_OK(ConnectClient()); } void TearDown() { server_->Stop(); } - Status ConnectClient() { - Location location; - RETURN_NOT_OK(Location::ForGrpcTcp("localhost", port_, &location)); - return FlightClient::Connect(location, &client_); - } + Status ConnectClient() { return FlightClient::Connect(server_->location(), &client_); } template void CheckDoGet(const FlightDescriptor& descr, const BatchVector& expected_batches, @@ -236,7 +229,7 @@ class TestFlightClient : public ::testing::Test { protected: int port_; std::unique_ptr client_; - std::unique_ptr server_; + std::unique_ptr server_; }; class AuthTestServer : public FlightServerBase { @@ -249,6 +242,16 @@ class AuthTestServer : public FlightServerBase { } }; +class TlsTestServer : public FlightServerBase { + Status DoAction(const ServerCallContext& context, const Action& action, + std::unique_ptr* result) override { + std::shared_ptr buf; + RETURN_NOT_OK(Buffer::FromString("Hello, world!", &buf)); + *result = std::unique_ptr(new SimpleResultStream({Result{buf}})); + return Status::OK(); + } +}; + class DoPutTestServer : public FlightServerBase { public: Status DoPut(const ServerCallContext& context, @@ -336,6 +339,42 @@ class TestDoPut : public ::testing::Test { DoPutTestServer* do_put_server_; }; +class TestTls : public ::testing::Test { + public: + void SetUp() { + Location location; + std::unique_ptr server(new TlsTestServer); + + ASSERT_OK(Location::ForGrpcTls("localhost", GetListenPort(), &location)); + FlightServerOptions options(location); + ASSERT_RAISES(UnknownError, server->Init(options)); + ASSERT_OK(ExampleTlsCertificates(&options.tls_certificates)); + ASSERT_OK(server->Init(options)); + + server_.reset(new InProcessTestServer(std::move(server), location)); + ASSERT_OK(server_->Start()); + ASSERT_OK(ConnectClient()); + } + + void TearDown() { + if (server_) { + server_->Stop(); + } + } + + Status ConnectClient() { + auto options = FlightClientOptions(); + CertKeyPair root_cert; + RETURN_NOT_OK(ExampleTlsCertificateRoot(&root_cert)); + options.tls_root_certs = root_cert.pem_cert; + return FlightClient::Connect(server_->location(), options, &client_); + } + + protected: + std::unique_ptr client_; + std::unique_ptr server_; +}; + TEST_F(TestFlightClient, ListFlights) { std::unique_ptr listing; ASSERT_OK(client_->ListFlights(&listing)); @@ -620,5 +659,21 @@ TEST_F(TestAuthHandler, CheckPeerIdentity) { ASSERT_EQ(result->body->ToString(), "user"); } +TEST_F(TestTls, DoAction) { + FlightCallOptions options; + options.timeout = TimeoutDuration{5.0}; + Action action; + action.type = "test"; + action.body = Buffer::FromString(""); + std::unique_ptr results; + ASSERT_OK(client_->DoAction(options, action, &results)); + ASSERT_NE(results, nullptr); + + std::unique_ptr result; + ASSERT_OK(results->Next(&result)); + ASSERT_NE(result, nullptr); + ASSERT_EQ(result->body->ToString(), "Hello, world!"); +} + } // namespace flight } // namespace arrow diff --git a/cpp/src/arrow/flight/server.cc b/cpp/src/arrow/flight/server.cc index 9b6bf6ca410c8..6f3c466c4adef 100644 --- a/cpp/src/arrow/flight/server.cc +++ b/cpp/src/arrow/flight/server.cc @@ -460,7 +460,7 @@ thread_local std::atomic #endif FlightServerOptions::FlightServerOptions(const Location& location_) - : location(location_), auth_handler(nullptr) {} + : location(location_), auth_handler(nullptr), tls_certificates() {} FlightServerBase::FlightServerBase() { impl_.reset(new Impl); } @@ -483,8 +483,9 @@ Status FlightServerBase::Init(FlightServerOptions& options) { std::shared_ptr creds; if (scheme == kSchemeGrpcTls) { grpc::SslServerCredentialsOptions ssl_options; - ssl_options.pem_key_cert_pairs.push_back( - {options.tls_private_key, options.tls_cert_chain}); + for (const auto& pair : options.tls_certificates) { + ssl_options.pem_key_cert_pairs.push_back({pair.pem_key, pair.pem_cert}); + } creds = grpc::SslServerCredentials(ssl_options); } else { creds = grpc::InsecureServerCredentials(); diff --git a/cpp/src/arrow/flight/server.h b/cpp/src/arrow/flight/server.h index 7164b64c4aba3..c1bcb5c0a3dd1 100644 --- a/cpp/src/arrow/flight/server.h +++ b/cpp/src/arrow/flight/server.h @@ -106,8 +106,7 @@ class ARROW_FLIGHT_EXPORT FlightServerOptions { Location location; std::unique_ptr auth_handler; - std::string tls_cert_chain; - std::string tls_private_key; + std::vector tls_certificates; }; /// \brief Skeleton RPC server implementation which can be used to create diff --git a/cpp/src/arrow/flight/test-server.cc b/cpp/src/arrow/flight/test-server.cc index f72fd3caeead1..87ef62f17d785 100644 --- a/cpp/src/arrow/flight/test-server.cc +++ b/cpp/src/arrow/flight/test-server.cc @@ -25,120 +25,19 @@ #include -#include "arrow/buffer.h" -#include "arrow/io/test-common.h" -#include "arrow/record_batch.h" -#include "arrow/util/logging.h" - #include "arrow/flight/server.h" -#include "arrow/flight/server_auth.h" #include "arrow/flight/test-util.h" +#include "arrow/flight/types.h" +#include "arrow/util/logging.h" DEFINE_int32(port, 31337, "Server port to listen on"); -namespace arrow { -namespace flight { - -Status GetBatchForFlight(const Ticket& ticket, std::shared_ptr* out) { - if (ticket.ticket == "ticket-ints-1") { - BatchVector batches; - RETURN_NOT_OK(ExampleIntBatches(&batches)); - *out = std::make_shared(batches[0]->schema(), batches); - return Status::OK(); - } else if (ticket.ticket == "ticket-dicts-1") { - BatchVector batches; - RETURN_NOT_OK(ExampleDictBatches(&batches)); - *out = std::make_shared(batches[0]->schema(), batches); - return Status::OK(); - } else { - return Status::NotImplemented("no stream implemented for this ticket"); - } -} - -class FlightTestServer : public FlightServerBase { - Status ListFlights(const ServerCallContext& context, const Criteria* criteria, - std::unique_ptr* listings) override { - std::vector flights = ExampleFlightInfo(); - *listings = std::unique_ptr(new SimpleFlightListing(flights)); - return Status::OK(); - } - - Status GetFlightInfo(const ServerCallContext& context, const FlightDescriptor& request, - std::unique_ptr* out) override { - std::vector flights = ExampleFlightInfo(); - - for (const auto& info : flights) { - if (info.descriptor().Equals(request)) { - *out = std::unique_ptr(new FlightInfo(info)); - return Status::OK(); - } - } - return Status::Invalid("Flight not found: ", request.ToString()); - } - - Status DoGet(const ServerCallContext& context, const Ticket& request, - std::unique_ptr* data_stream) override { - // Test for ARROW-5095 - if (request.ticket == "ARROW-5095-fail") { - return Status::UnknownError("Server-side error"); - } - if (request.ticket == "ARROW-5095-success") { - return Status::OK(); - } - - std::shared_ptr batch_reader; - RETURN_NOT_OK(GetBatchForFlight(request, &batch_reader)); - - *data_stream = std::unique_ptr(new RecordBatchStream(batch_reader)); - return Status::OK(); - } - - Status RunAction1(const Action& action, std::unique_ptr* out) { - std::vector results; - for (int i = 0; i < 3; ++i) { - Result result; - std::string value = action.body->ToString() + "-part" + std::to_string(i); - RETURN_NOT_OK(Buffer::FromString(value, &result.body)); - results.push_back(result); - } - *out = std::unique_ptr(new SimpleResultStream(std::move(results))); - return Status::OK(); - } - - Status RunAction2(std::unique_ptr* out) { - // Empty - *out = std::unique_ptr(new SimpleResultStream({})); - return Status::OK(); - } - - Status DoAction(const ServerCallContext& context, const Action& action, - std::unique_ptr* out) override { - if (action.type == "action1") { - return RunAction1(action, out); - } else if (action.type == "action2") { - return RunAction2(out); - } else { - return Status::NotImplemented(action.type); - } - } - - Status ListActions(const ServerCallContext& context, - std::vector* out) override { - std::vector actions = ExampleActionTypes(); - *out = std::move(actions); - return Status::OK(); - } -}; - -} // namespace flight -} // namespace arrow - -std::unique_ptr g_server; +std::unique_ptr g_server; int main(int argc, char** argv) { gflags::ParseCommandLineFlags(&argc, &argv, true); - g_server.reset(new arrow::flight::FlightTestServer); + g_server = arrow::flight::ExampleTestServer(); arrow::flight::Location location; ARROW_CHECK_OK(arrow::flight::Location::ForGrpcTcp("0.0.0.0", FLAGS_port, &location)); diff --git a/cpp/src/arrow/flight/test-util.cc b/cpp/src/arrow/flight/test-util.cc index b20a4cbf9dea1..7dd78fdd6eb2d 100644 --- a/cpp/src/arrow/flight/test-util.cc +++ b/cpp/src/arrow/flight/test-util.cc @@ -22,6 +22,7 @@ #include #endif +#include #include #include @@ -154,6 +155,101 @@ InProcessTestServer::~InProcessTestServer() { } } +Status GetBatchForFlight(const Ticket& ticket, std::shared_ptr* out) { + if (ticket.ticket == "ticket-ints-1") { + BatchVector batches; + RETURN_NOT_OK(ExampleIntBatches(&batches)); + *out = std::make_shared(batches[0]->schema(), batches); + return Status::OK(); + } else if (ticket.ticket == "ticket-dicts-1") { + BatchVector batches; + RETURN_NOT_OK(ExampleDictBatches(&batches)); + *out = std::make_shared(batches[0]->schema(), batches); + return Status::OK(); + } else { + return Status::NotImplemented("no stream implemented for this ticket"); + } +} + +class FlightTestServer : public FlightServerBase { + Status ListFlights(const ServerCallContext& context, const Criteria* criteria, + std::unique_ptr* listings) override { + std::vector flights = ExampleFlightInfo(); + *listings = std::unique_ptr(new SimpleFlightListing(flights)); + return Status::OK(); + } + + Status GetFlightInfo(const ServerCallContext& context, const FlightDescriptor& request, + std::unique_ptr* out) override { + std::vector flights = ExampleFlightInfo(); + + for (const auto& info : flights) { + if (info.descriptor().Equals(request)) { + *out = std::unique_ptr(new FlightInfo(info)); + return Status::OK(); + } + } + return Status::Invalid("Flight not found: ", request.ToString()); + } + + Status DoGet(const ServerCallContext& context, const Ticket& request, + std::unique_ptr* data_stream) override { + // Test for ARROW-5095 + if (request.ticket == "ARROW-5095-fail") { + return Status::UnknownError("Server-side error"); + } + if (request.ticket == "ARROW-5095-success") { + return Status::OK(); + } + + std::shared_ptr batch_reader; + RETURN_NOT_OK(GetBatchForFlight(request, &batch_reader)); + + *data_stream = std::unique_ptr(new RecordBatchStream(batch_reader)); + return Status::OK(); + } + + Status RunAction1(const Action& action, std::unique_ptr* out) { + std::vector results; + for (int i = 0; i < 3; ++i) { + Result result; + std::string value = action.body->ToString() + "-part" + std::to_string(i); + RETURN_NOT_OK(Buffer::FromString(value, &result.body)); + results.push_back(result); + } + *out = std::unique_ptr(new SimpleResultStream(std::move(results))); + return Status::OK(); + } + + Status RunAction2(std::unique_ptr* out) { + // Empty + *out = std::unique_ptr(new SimpleResultStream({})); + return Status::OK(); + } + + Status DoAction(const ServerCallContext& context, const Action& action, + std::unique_ptr* out) override { + if (action.type == "action1") { + return RunAction1(action, out); + } else if (action.type == "action2") { + return RunAction2(out); + } else { + return Status::NotImplemented(action.type); + } + } + + Status ListActions(const ServerCallContext& context, + std::vector* out) override { + std::vector actions = ExampleActionTypes(); + *out = std::move(actions); + return Status::OK(); + } +}; + +std::unique_ptr ExampleTestServer() { + return std::unique_ptr(new FlightTestServer); +} + Status MakeFlightInfo(const Schema& schema, const FlightDescriptor& descriptor, const std::vector& endpoints, int64_t total_records, int64_t total_bytes, FlightInfo::Data* out) { @@ -286,5 +382,70 @@ Status TestClientAuthHandler::GetToken(std::string* token) { return Status::OK(); } +Status GetTestResourceRoot(std::string* out) { + const char* c_root = std::getenv("ARROW_TEST_DATA"); + if (!c_root) { + return Status::IOError("Test resources not found, set ARROW_TEST_DATA"); + } + *out = std::string(c_root); + return Status::OK(); +} + +Status ExampleTlsCertificates(std::vector* out) { + std::string root; + RETURN_NOT_OK(GetTestResourceRoot(&root)); + + *out = std::vector(); + for (int i = 0; i < 2; i++) { + try { + std::stringstream cert_path; + cert_path << root << "/flight/cert" << i << ".pem"; + std::stringstream key_path; + key_path << root << "/flight/cert" << i << ".key"; + + std::ifstream cert_file(cert_path.str()); + if (!cert_file) { + return Status::IOError("Could not open certificate: " + cert_path.str()); + } + std::stringstream cert; + cert << cert_file.rdbuf(); + + std::ifstream key_file(key_path.str()); + if (!key_file) { + return Status::IOError("Could not open key: " + key_path.str()); + } + std::stringstream key; + key << key_file.rdbuf(); + + out->push_back(CertKeyPair{cert.str(), key.str()}); + } catch (const std::ifstream::failure& e) { + return Status::IOError(e.what()); + } + } + return Status::OK(); +} + +Status ExampleTlsCertificateRoot(CertKeyPair* out) { + std::string root; + RETURN_NOT_OK(GetTestResourceRoot(&root)); + + std::stringstream path; + path << root << "/flight/root-ca.pem"; + + try { + std::ifstream cert_file(path.str()); + if (!cert_file) { + return Status::IOError("Could not open certificate: " + path.str()); + } + std::stringstream cert; + cert << cert_file.rdbuf(); + out->pem_cert = cert.str(); + out->pem_key = ""; + return Status::OK(); + } catch (const std::ifstream::failure& e) { + return Status::IOError(e.what()); + } +} + } // namespace flight } // namespace arrow diff --git a/cpp/src/arrow/flight/test-util.h b/cpp/src/arrow/flight/test-util.h index 2e1f4b0ed15c9..5b02630b432b5 100644 --- a/cpp/src/arrow/flight/test-util.h +++ b/cpp/src/arrow/flight/test-util.h @@ -86,6 +86,10 @@ class ARROW_FLIGHT_EXPORT InProcessTestServer { std::thread thread_; }; +/// \brief Create a simple Flight server for testing +ARROW_FLIGHT_EXPORT +std::unique_ptr ExampleTestServer(); + // ---------------------------------------------------------------------- // A RecordBatchReader for serving a sequence of in-memory record batches @@ -184,5 +188,11 @@ class ARROW_FLIGHT_EXPORT TestClientAuthHandler : public ClientAuthHandler { std::string password_; }; +ARROW_FLIGHT_EXPORT +Status ExampleTlsCertificates(std::vector* out); + +ARROW_FLIGHT_EXPORT +Status ExampleTlsCertificateRoot(CertKeyPair* out); + } // namespace flight } // namespace arrow diff --git a/cpp/src/arrow/flight/types.cc b/cpp/src/arrow/flight/types.cc index dadb51066cf3f..d982efce5cea4 100644 --- a/cpp/src/arrow/flight/types.cc +++ b/cpp/src/arrow/flight/types.cc @@ -96,6 +96,12 @@ Status Location::ForGrpcTcp(const std::string& host, const int port, Location* l return Location::Parse(uri_string.str(), location); } +Status Location::ForGrpcTls(const std::string& host, const int port, Location* location) { + std::stringstream uri_string; + uri_string << "grpc+tls://" << host << ':' << port; + return Location::Parse(uri_string.str(), location); +} + Status Location::ForGrpcUnix(const std::string& path, Location* location) { std::stringstream uri_string; uri_string << "grpc+unix://" << path; diff --git a/cpp/src/arrow/flight/types.h b/cpp/src/arrow/flight/types.h index 8d37225263606..e5f7bcdd550a1 100644 --- a/cpp/src/arrow/flight/types.h +++ b/cpp/src/arrow/flight/types.h @@ -49,6 +49,15 @@ class Uri; namespace flight { +/// \brief A TLS certificate plus key. +struct ARROW_FLIGHT_EXPORT CertKeyPair { + /// \brief The certificate in PEM format. + std::string pem_cert; + + /// \brief The key in PEM format. + std::string pem_key; +}; + /// \brief A type of action that can be performed with the DoAction RPC struct ARROW_FLIGHT_EXPORT ActionType { /// Name of action @@ -145,6 +154,13 @@ struct ARROW_FLIGHT_EXPORT Location { /// \param[out] location The resulting location static Status ForGrpcTcp(const std::string& host, const int port, Location* location); + /// \brief Initialize a location for a TLS-enabled, gRPC-based Flight + /// service from a host and port + /// \param[in] host The hostname to connect to + /// \param[in] port The port + /// \param[out] location The resulting location + static Status ForGrpcTls(const std::string& host, const int port, Location* location); + /// \brief Initialize a location for a domain socket-based Flight /// service /// \param[in] path The path to the domain socket diff --git a/python/pyarrow/_flight.pyx b/python/pyarrow/_flight.pyx index c68263507d7d6..c916e6bcf56ca 100644 --- a/python/pyarrow/_flight.pyx +++ b/python/pyarrow/_flight.pyx @@ -57,6 +57,13 @@ cdef class FlightCallOptions: "'{}'".format(type(obj))) +_CertKeyPair = collections.namedtuple('_CertKeyPair', ['cert', 'key']) + + +class CertKeyPair(_CertKeyPair): + """A TLS certificate and key for use in Flight.""" + + cdef class Action: """An action executable on a Flight service.""" cdef: @@ -227,6 +234,16 @@ cdef class Location: check_status(CLocation.ForGrpcTcp(c_host, c_port, &result.location)) return result + @staticmethod + def for_grpc_tls(host, port): + """Create a Location for a TLS-based gRPC service.""" + cdef: + c_string c_host = tobytes(host) + int c_port = port + Location result = Location.__new__(Location) + check_status(CLocation.ForGrpcTls(c_host, c_port, &result.location)) + return result + @staticmethod def for_grpc_unix(path): """Create a Location for a domain socket-based gRPC service.""" @@ -1016,12 +1033,12 @@ cdef class FlightServerBase: cdef: unique_ptr[PyFlightServer] server - def run(self, location, auth_handler=None, - tls_cert_chain=None, tls_private_key=None): + def run(self, location, auth_handler=None, tls_certificates=None): cdef: PyFlightServerVtable vtable = PyFlightServerVtable() PyFlightServer* c_server unique_ptr[CFlightServerOptions] c_options + CCertKeyPair c_cert c_options.reset(new CFlightServerOptions(Location.unwrap(location))) @@ -1032,12 +1049,11 @@ cdef class FlightServerBase: c_options.get().auth_handler.reset( ( auth_handler).to_handler()) - if tls_cert_chain: - if not tls_private_key: - raise ValueError( - "Must provide both cert chain and private key") - c_options.get().tls_cert_chain = tobytes(tls_cert_chain) - c_options.get().tls_private_key = tobytes(tls_private_key) + if tls_certificates: + for cert, key in tls_certificates: + c_cert.pem_cert = tobytes(cert) + c_cert.pem_key = tobytes(key) + c_options.get().tls_certificates.push_back(c_cert) vtable.list_flights = &_list_flights vtable.get_flight_info = &_get_flight_info diff --git a/python/pyarrow/flight.py b/python/pyarrow/flight.py index 37a21e4318163..05198e4b856a4 100644 --- a/python/pyarrow/flight.py +++ b/python/pyarrow/flight.py @@ -25,6 +25,7 @@ from pyarrow._flight import ( # noqa Action, ActionType, + CertKeyPair, DescriptorType, FlightCallOptions, FlightClient, diff --git a/python/pyarrow/includes/libarrow_flight.pxd b/python/pyarrow/includes/libarrow_flight.pxd index 4b749903d3d8f..14d1ed163d186 100644 --- a/python/pyarrow/includes/libarrow_flight.pxd +++ b/python/pyarrow/includes/libarrow_flight.pxd @@ -88,6 +88,8 @@ cdef extern from "arrow/flight/api.h" namespace "arrow" nogil: @staticmethod CStatus ForGrpcTcp(c_string& host, int port, CLocation* location) @staticmethod + CStatus ForGrpcTls(c_string& host, int port, CLocation* location) + @staticmethod CStatus ForGrpcUnix(c_string& path, CLocation* location) cdef cppclass CFlightEndpoint" arrow::flight::FlightEndpoint": @@ -154,12 +156,16 @@ cdef extern from "arrow/flight/api.h" namespace "arrow" nogil: CFlightCallOptions() CTimeoutDuration timeout + cdef cppclass CCertKeyPair" arrow::flight::CertKeyPair": + CCertKeyPair() + c_string pem_cert + c_string pem_key + cdef cppclass CFlightServerOptions" arrow::flight::FlightServerOptions": CFlightServerOptions(const CLocation& location) CLocation location unique_ptr[CServerAuthHandler] auth_handler - c_string tls_cert_chain - c_string tls_private_key + vector[CCertKeyPair] tls_certificates cdef cppclass CFlightClientOptions" arrow::flight::FlightClientOptions": CFlightClientOptions() diff --git a/python/pyarrow/tests/test_flight.py b/python/pyarrow/tests/test_flight.py index 9ce2264ee31b1..a7e6e340c68dd 100644 --- a/python/pyarrow/tests/test_flight.py +++ b/python/pyarrow/tests/test_flight.py @@ -28,12 +28,52 @@ import pyarrow as pa +from pathlib import Path from pyarrow.compat import tobytes flight = pytest.importorskip("pyarrow.flight") +def resource_root(): + """Get the path to the test resources directory.""" + if not os.environ.get("ARROW_TEST_DATA"): + raise RuntimeError("Test resources not found; set " + "ARROW_TEST_DATA to /testing") + return Path(os.environ["ARROW_TEST_DATA"]) / "flight" + + +def read_flight_resource(path): + """Get the contents of a test resource file.""" + root = resource_root() + if not root: + return None + try: + with (root / path).open("rb") as f: + return f.read() + except FileNotFoundError as e: + raise RuntimeError( + "Test resource {} not found; did you initialize the " + "test resource submodule?".format(root / path)) from e + + +def example_tls_certs(): + """Get the paths to test TLS certificates.""" + return { + "root_cert": read_flight_resource("root-ca.pem"), + "certificates": [ + flight.CertKeyPair( + cert=read_flight_resource("cert0.pem"), + key=read_flight_resource("cert0.key"), + ), + flight.CertKeyPair( + cert=read_flight_resource("cert1.pem"), + key=read_flight_resource("cert1.key"), + ), + ] + } + + def simple_ints_table(): data = [ pa.array([-10, -5, 0, 5, 10]) @@ -245,6 +285,7 @@ def get_token(self): def flight_server(server_base, *args, **kwargs): """Spawn a Flight server on a free port, shutting it down when done.""" auth_handler = kwargs.pop('auth_handler', None) + tls_certificates = kwargs.pop('tls_certificates', None) location = kwargs.pop('location', None) if location is None: @@ -254,7 +295,10 @@ def flight_server(server_base, *args, **kwargs): sock.bind(('', 0)) sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) port = sock.getsockname()[1] - location = flight.Location.for_grpc_tcp("localhost", port) + ctor = flight.Location.for_grpc_tcp + if tls_certificates: + ctor = flight.Location.for_grpc_tls + location = ctor("localhost", port) else: port = None @@ -262,11 +306,26 @@ def flight_server(server_base, *args, **kwargs): server_instance = server_base(*args, **ctor_kwargs) def _server_thread(): - server_instance.run(location, auth_handler=auth_handler) + server_instance.run( + location, + auth_handler=auth_handler, + tls_certificates=tls_certificates, + ) thread = threading.Thread(target=_server_thread, daemon=True) thread.start() + # Wait for server to start + client = flight.FlightClient.connect(location) + while True: + try: + list(client.list_flights()) + except Exception as e: + if 'Connect Failed' in str(e): + time.sleep(0.025) + continue + break + yield location server_instance.shutdown() @@ -471,3 +530,32 @@ def test_location_invalid(): server = ConstantFlightServer() with pytest.raises(pa.ArrowInvalid, match=".*Cannot parse URI:.*"): server.run("%") + + +@pytest.mark.slow +def test_tls_fails(): + """Make sure clients cannot connect when cert verification fails.""" + certs = example_tls_certs() + + with flight_server( + ConstantFlightServer, tls_certificates=certs["certificates"] + ) as server_location: + # Ensure client doesn't connect when certificate verification + # fails (this is a slow test since gRPC does retry a few times) + client = flight.FlightClient.connect(server_location) + with pytest.raises(pa.ArrowIOError, match="Connect Failed"): + client.do_get(flight.Ticket(b'ints')) + + +def test_tls_do_get(): + """Try a simple do_get call over TLS.""" + table = simple_ints_table() + certs = example_tls_certs() + + with flight_server( + ConstantFlightServer, tls_certificates=certs["certificates"] + ) as server_location: + client = flight.FlightClient.connect( + server_location, tls_root_certs=certs["root_cert"]) + data = client.do_get(flight.Ticket(b'ints')).read_all() + assert data.equals(table) diff --git a/testing b/testing index bf0abe442bf7e..12f9dbd2a37ee 160000 --- a/testing +++ b/testing @@ -1 +1 @@ -Subproject commit bf0abe442bf7e313380452c8972692940f4e56b6 +Subproject commit 12f9dbd2a37eea6fa370e108a1d797ee1167724a