Skip to content

Commit

Permalink
Allow multiple TLS certificates in Flight
Browse files Browse the repository at this point in the history
  • Loading branch information
lidavidm committed Jun 13, 2019
1 parent 25b4a46 commit 9d2efa2
Show file tree
Hide file tree
Showing 6 changed files with 39 additions and 15 deletions.
7 changes: 4 additions & 3 deletions cpp/src/arrow/flight/server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -460,7 +460,7 @@ thread_local std::atomic<FlightServerBase::Impl*>
#endif

FlightServerOptions::FlightServerOptions(const Location& location_)
: location(location_), auth_handler(nullptr) {}
: location(location_), auth_handler(nullptr), tls_certificates() {}

FlightServerBase::FlightServerBase() { impl_.reset(new Impl); }

Expand All @@ -483,8 +483,9 @@ Status FlightServerBase::Init(FlightServerOptions& options) {
std::shared_ptr<grpc::ServerCredentials> creds;
if (scheme == kSchemeGrpcTls) {
grpc::SslServerCredentialsOptions ssl_options;
ssl_options.pem_key_cert_pairs.push_back(
{options.tls_private_key, options.tls_cert_chain});
for (const auto& pair : options.tls_certificates) {
ssl_options.pem_key_cert_pairs.push_back({pair.pem_key, pair.pem_cert});
}
creds = grpc::SslServerCredentials(ssl_options);
} else {
creds = grpc::InsecureServerCredentials();
Expand Down
3 changes: 1 addition & 2 deletions cpp/src/arrow/flight/server.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,7 @@ class ARROW_FLIGHT_EXPORT FlightServerOptions {

Location location;
std::unique_ptr<ServerAuthHandler> auth_handler;
std::string tls_cert_chain;
std::string tls_private_key;
std::vector<CertKeyPair> tls_certificates;
};

/// \brief Skeleton RPC server implementation which can be used to create
Expand Down
6 changes: 6 additions & 0 deletions cpp/src/arrow/flight/types.cc
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,12 @@ Status Location::ForGrpcTcp(const std::string& host, const int port, Location* l
return Location::Parse(uri_string.str(), location);
}

Status Location::ForGrpcTls(const std::string& host, const int port, Location* location) {
std::stringstream uri_string;
uri_string << "grpc+tls://" << host << ':' << port;
return Location::Parse(uri_string.str(), location);
}

Status Location::ForGrpcUnix(const std::string& path, Location* location) {
std::stringstream uri_string;
uri_string << "grpc+unix://" << path;
Expand Down
16 changes: 16 additions & 0 deletions cpp/src/arrow/flight/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,15 @@ class Uri;

namespace flight {

/// \brief A TLS certificate plus key.
struct ARROW_FLIGHT_EXPORT CertKeyPair {
/// \brief The certificate in PEM format.
std::string pem_cert;

/// \brief The key in PEM format.
std::string pem_key;
};

/// \brief A type of action that can be performed with the DoAction RPC
struct ARROW_FLIGHT_EXPORT ActionType {
/// Name of action
Expand Down Expand Up @@ -145,6 +154,13 @@ struct ARROW_FLIGHT_EXPORT Location {
/// \param[out] location The resulting location
static Status ForGrpcTcp(const std::string& host, const int port, Location* location);

/// \brief Initialize a location for a TLS-enabled, gRPC-based Flight
/// service from a host and port
/// \param[in] host The hostname to connect to
/// \param[in] port The port
/// \param[out] location The resulting location
static Status ForGrpcTls(const std::string& host, const int port, Location* location);

/// \brief Initialize a location for a domain socket-based Flight
/// service
/// \param[in] path The path to the domain socket
Expand Down
14 changes: 6 additions & 8 deletions python/pyarrow/_flight.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1016,12 +1016,12 @@ cdef class FlightServerBase:
cdef:
unique_ptr[PyFlightServer] server

def run(self, location, auth_handler=None,
tls_cert_chain=None, tls_private_key=None):
def run(self, location, auth_handler=None, tls_certificates=()):
cdef:
PyFlightServerVtable vtable = PyFlightServerVtable()
PyFlightServer* c_server
unique_ptr[CFlightServerOptions] c_options
CCertKeyPair c_cert

c_options.reset(new CFlightServerOptions(Location.unwrap(location)))

Expand All @@ -1032,12 +1032,10 @@ cdef class FlightServerBase:
c_options.get().auth_handler.reset(
(<ServerAuthHandler> auth_handler).to_handler())

if tls_cert_chain:
if not tls_private_key:
raise ValueError(
"Must provide both cert chain and private key")
c_options.get().tls_cert_chain = tobytes(tls_cert_chain)
c_options.get().tls_private_key = tobytes(tls_private_key)
for cert, key in tls_certificates:
c_cert.pem_cert = tobytes(cert)
c_cert.pem_key = tobytes(key)
c_options.get().tls_certificates.push_back(c_cert)

vtable.list_flights = &_list_flights
vtable.get_flight_info = &_get_flight_info
Expand Down
8 changes: 6 additions & 2 deletions python/pyarrow/includes/libarrow_flight.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -154,12 +154,16 @@ cdef extern from "arrow/flight/api.h" namespace "arrow" nogil:
CFlightCallOptions()
CTimeoutDuration timeout

cdef cppclass CCertKeyPair" arrow::flight::CertKeyPair":
CCertKeyPair()
c_string pem_cert
c_string pem_key

cdef cppclass CFlightServerOptions" arrow::flight::FlightServerOptions":
CFlightServerOptions(const CLocation& location)
CLocation location
unique_ptr[CServerAuthHandler] auth_handler
c_string tls_cert_chain
c_string tls_private_key
vector[CCertKeyPair] tls_certificates

cdef cppclass CFlightClientOptions" arrow::flight::FlightClientOptions":
CFlightClientOptions()
Expand Down

0 comments on commit 9d2efa2

Please sign in to comment.