Skip to content

Commit

Permalink
ARROW-8183: [C++][Python][FlightRPC] Expose transport error metadata
Browse files Browse the repository at this point in the history
This is the C++ and Python version of ARROW-8181

- [x] c++
- [x] python

Closes #6747 from rymurr/ARROW-8183

Authored-by: Ryan Murray <rymurr@dremio.com>
Signed-off-by: David Li <li.davidm96@gmail.com>
  • Loading branch information
Ryan Murray authored and lidavidm committed Mar 30, 2020
1 parent 6be085f commit 0facdc7
Show file tree
Hide file tree
Showing 8 changed files with 157 additions and 32 deletions.
57 changes: 52 additions & 5 deletions cpp/src/arrow/flight/flight_test.cc
Expand Up @@ -15,6 +15,9 @@
// specific language governing permissions and limitations
// under the License.

#include <gmock/gmock.h>
#include <gtest/gtest.h>

#include <atomic>
#include <chrono>
#include <cstdint>
Expand All @@ -27,17 +30,13 @@
#include <thread>
#include <vector>

#include <gmock/gmock.h>
#include <gtest/gtest.h>

#include "arrow/flight/api.h"
#include "arrow/ipc/test_common.h"
#include "arrow/status.h"
#include "arrow/testing/gtest_util.h"
#include "arrow/testing/util.h"
#include "arrow/util/make_unique.h"

#include "arrow/flight/api.h"

#ifdef GRPCPP_GRPCPP_H
#error "gRPC headers should not be in public API"
#endif
Expand Down Expand Up @@ -728,6 +727,22 @@ class ReportContextTestServer : public FlightServerBase {
}
};

class ErrorMiddlewareServer : public FlightServerBase {
Status DoAction(const ServerCallContext& context, const Action& action,
std::unique_ptr<ResultStream>* result) override {
std::shared_ptr<Buffer> buf;
std::string msg = "error_message";
Status s = Buffer::FromString("", &buf);

std::shared_ptr<FlightStatusDetail> flightStatusDetail(
new FlightStatusDetail(FlightStatusCode::Failed, msg));
*result = std::unique_ptr<ResultStream>(new SimpleResultStream({Result{buf}}));
Status s_err = Status(StatusCode::ExecutionError, "test failed", flightStatusDetail);
RETURN_NOT_OK(s_err);
return Status::OK();
}
};

class PropagatingTestServer : public FlightServerBase {
public:
explicit PropagatingTestServer(std::unique_ptr<FlightClient> client)
Expand Down Expand Up @@ -858,6 +873,38 @@ class TestPropagatingMiddleware : public ::testing::Test {
std::shared_ptr<PropagatingClientMiddlewareFactory> client_middleware_;
};

class TestErrorMiddleware : public ::testing::Test {
public:
void SetUp() {
ASSERT_OK(MakeServer<ErrorMiddlewareServer>(
&server_, &client_, [](FlightServerOptions* options) { return Status::OK(); },
[](FlightClientOptions* options) { return Status::OK(); }));
}

void TearDown() { ASSERT_OK(server_->Shutdown()); }

protected:
std::unique_ptr<FlightClient> client_;
std::unique_ptr<FlightServerBase> server_;
};

TEST_F(TestErrorMiddleware, TestMetadata) {
Action action;
std::unique_ptr<ResultStream> stream;

// Run action1
action.type = "action1";

const std::string action1_value = "action1-content";
ASSERT_OK(Buffer::FromString(action1_value, &action.body));
Status s = client_->DoAction(action, &stream);
ASSERT_FALSE(s.ok());
std::shared_ptr<FlightStatusDetail> flightStatusDetail =
FlightStatusDetail::UnwrapStatus(s);
ASSERT_TRUE(flightStatusDetail);
ASSERT_EQ(flightStatusDetail->extra_info(), "error_message");
}

TEST_F(TestFlightClient, ListFlights) {
std::unique_ptr<FlightListing> listing;
ASSERT_OK(client_->ListFlights(&listing));
Expand Down
38 changes: 27 additions & 11 deletions cpp/src/arrow/flight/internal.cc
Expand Up @@ -16,8 +16,6 @@
// under the License.

#include "arrow/flight/internal.h"
#include "arrow/flight/platform.h"
#include "arrow/flight/protocol_internal.h"

#include <cstddef>
#include <map>
Expand All @@ -26,6 +24,10 @@
#include <string>
#include <utility>

#include "arrow/flight/platform.h"
#include "arrow/flight/protocol_internal.h"
#include "arrow/flight/types.h"

#ifdef GRPCPP_PP_INCLUDE
#include <grpcpp/grpcpp.h>
#else
Expand All @@ -49,6 +51,7 @@ const char* kGrpcAuthHeader = "auth-token-bin";
const char* kGrpcStatusCodeHeader = "x-arrow-status";
const char* kGrpcStatusMessageHeader = "x-arrow-status-message-bin";
const char* kGrpcStatusDetailHeader = "x-arrow-status-detail-bin";
const char* kBinaryErrorDetailsKey = "grpc-status-details-bin";

static Status StatusCodeFromString(const grpc::string_ref& code_ref, StatusCode* code) {
// Bounce through std::string to get a proper null-terminated C string
Expand Down Expand Up @@ -80,15 +83,16 @@ static Status StatusCodeFromString(const grpc::string_ref& code_ref, StatusCode*

/// Try to extract a status from gRPC trailers.
/// Return Status::OK if found, an error otherwise.
static Status FromGrpcContext(const grpc::ClientContext& ctx, Status* status) {
static Status FromGrpcContext(const grpc::ClientContext& ctx, Status* status,
std::shared_ptr<FlightStatusDetail> flightStatusDetail) {
const std::multimap<grpc::string_ref, grpc::string_ref>& trailers =
ctx.GetServerTrailingMetadata();
const auto code_val = trailers.find(kGrpcStatusCodeHeader);
if (code_val == trailers.end()) {
return Status::IOError("Status code header not found");
}

const grpc::string_ref code_ref = (*code_val).second;
const grpc::string_ref code_ref = code_val->second;
StatusCode code = {};
RETURN_NOT_OK(StatusCodeFromString(code_ref, &code));

Expand All @@ -97,15 +101,25 @@ static Status FromGrpcContext(const grpc::ClientContext& ctx, Status* status) {
return Status::IOError("Status message header not found");
}

const grpc::string_ref message_ref = (*message_val).second;
const grpc::string_ref message_ref = message_val->second;
std::string message = std::string(message_ref.data(), message_ref.size());
const auto detail_val = trailers.find(kGrpcStatusDetailHeader);
if (detail_val != trailers.end()) {
const grpc::string_ref detail_ref = (*detail_val).second;
const grpc::string_ref detail_ref = detail_val->second;
message += ". Detail: ";
message += std::string(detail_ref.data(), detail_ref.size());
}
*status = Status(code, message);
const auto grpc_detail_val = trailers.find(kBinaryErrorDetailsKey);
if (grpc_detail_val != trailers.end()) {
const grpc::string_ref detail_ref = grpc_detail_val->second;
std::string bin_detail = std::string(detail_ref.data(), detail_ref.size());
if (!flightStatusDetail) {
flightStatusDetail =
std::make_shared<FlightStatusDetail>(FlightStatusCode::Internal);
}
flightStatusDetail->set_extra_info(bin_detail);
}
*status = Status(code, message, flightStatusDetail);
return Status::OK();
}

Expand Down Expand Up @@ -190,15 +204,13 @@ Status FromGrpcStatus(const grpc::Status& grpc_status, grpc::ClientContext* ctx)
if (!status.ok() && ctx) {
Status arrow_status;

if (!FromGrpcContext(*ctx, &arrow_status).ok()) {
if (!FromGrpcContext(*ctx, &arrow_status, FlightStatusDetail::UnwrapStatus(status))
.ok()) {
// If we fail to decode a more detailed status from the headers,
// proceed normally
return status;
}

if (status.detail()) {
return arrow_status.WithDetail(status.detail());
}
return arrow_status;
}
return status;
Expand Down Expand Up @@ -266,6 +278,10 @@ grpc::Status ToGrpcStatus(const Status& arrow_status, grpc::ServerContext* ctx)
const std::string detail_string = arrow_status.detail()->ToString();
ctx->AddTrailingMetadata(internal::kGrpcStatusDetailHeader, detail_string);
}
auto fsd = FlightStatusDetail::UnwrapStatus(arrow_status);
if (fsd && !fsd->extra_info().empty()) {
ctx->AddTrailingMetadata(internal::kBinaryErrorDetailsKey, fsd->extra_info());
}
}

return status;
Expand Down
3 changes: 3 additions & 0 deletions cpp/src/arrow/flight/internal.h
Expand Up @@ -79,6 +79,9 @@ extern const char* kGrpcStatusMessageHeader;
ARROW_FLIGHT_EXPORT
extern const char* kGrpcStatusDetailHeader;

ARROW_FLIGHT_EXPORT
extern const char* kBinaryErrorDetailsKey;

ARROW_FLIGHT_EXPORT
Status SchemaToString(const Schema& schema, std::string* out);

Expand Down
13 changes: 13 additions & 0 deletions cpp/src/arrow/flight/types.cc
Expand Up @@ -45,6 +45,12 @@ std::string FlightStatusDetail::ToString() const { return CodeAsString(); }

FlightStatusCode FlightStatusDetail::code() const { return code_; }

std::string FlightStatusDetail::extra_info() const { return extra_info_; }

void FlightStatusDetail::set_extra_info(std::string extra_info) {
extra_info_ = std::move(extra_info);
}

std::string FlightStatusDetail::CodeAsString() const {
switch (code()) {
case FlightStatusCode::Internal:
Expand Down Expand Up @@ -77,6 +83,13 @@ Status MakeFlightError(FlightStatusCode code, const std::string& message) {
return arrow::Status(arrow_code, message, std::make_shared<FlightStatusDetail>(code));
}

Status MakeFlightError(FlightStatusCode code, const std::string& message,
const std::string& extra_info) {
StatusCode arrow_code = arrow::StatusCode::IOError;
return arrow::Status(arrow_code, message,
std::make_shared<FlightStatusDetail>(code, extra_info));
}

bool FlightDescriptor::Equals(const FlightDescriptor& other) const {
if (type != other.type) {
return false;
Expand Down
17 changes: 17 additions & 0 deletions cpp/src/arrow/flight/types.h
Expand Up @@ -80,13 +80,19 @@ enum class FlightStatusCode : int8_t {
class ARROW_FLIGHT_EXPORT FlightStatusDetail : public arrow::StatusDetail {
public:
explicit FlightStatusDetail(FlightStatusCode code) : code_{code} {}
explicit FlightStatusDetail(FlightStatusCode code, std::string extra_info)
: code_{code}, extra_info_(std::move(extra_info)) {}
const char* type_id() const override;
std::string ToString() const override;

/// \brief Get the Flight status code.
FlightStatusCode code() const;
/// \brief Get the extra error info
std::string extra_info() const;
/// \brief Get the human-readable name of the status code.
std::string CodeAsString() const;
/// \brief Set the extra error info
void set_extra_info(std::string extra_info);

/// \brief Try to extract a \a FlightStatusDetail from any Arrow
/// status.
Expand All @@ -97,6 +103,7 @@ class ARROW_FLIGHT_EXPORT FlightStatusDetail : public arrow::StatusDetail {

private:
FlightStatusCode code_;
std::string extra_info_;
};

#ifdef _MSC_VER
Expand All @@ -111,6 +118,16 @@ class ARROW_FLIGHT_EXPORT FlightStatusDetail : public arrow::StatusDetail {
ARROW_FLIGHT_EXPORT
Status MakeFlightError(FlightStatusCode code, const std::string& message);

/// \brief Make an appropriate Arrow status for the given
/// Flight-specific status.
///
/// \param code The status code.
/// \param message The message for the error.
/// \param extra_info The extra binary info for the error (eg protobuf)
ARROW_FLIGHT_EXPORT
Status MakeFlightError(FlightStatusCode code, const std::string& message,
const std::string& extra_info);

/// \brief A TLS certificate plus key.
struct ARROW_FLIGHT_EXPORT CertKeyPair {
/// \brief The certificate in PEM format.
Expand Down
41 changes: 25 additions & 16 deletions python/pyarrow/_flight.pyx
Expand Up @@ -52,21 +52,22 @@ cdef int check_flight_status(const CStatus& status) nogil except -1:
if detail:
with gil:
message = frombytes(status.message())
detail_msg = detail.get().extra_info()
if detail.get().code() == CFlightStatusInternal:
raise FlightInternalError(message)
raise FlightInternalError(message, detail_msg)
elif detail.get().code() == CFlightStatusFailed:
message = _munge_grpc_python_error(message)
raise FlightServerError(message)
raise FlightServerError(message, detail_msg)
elif detail.get().code() == CFlightStatusTimedOut:
raise FlightTimedOutError(message)
raise FlightTimedOutError(message, detail_msg)
elif detail.get().code() == CFlightStatusCancelled:
raise FlightCancelledError(message)
raise FlightCancelledError(message, detail_msg)
elif detail.get().code() == CFlightStatusUnauthenticated:
raise FlightUnauthenticatedError(message)
raise FlightUnauthenticatedError(message, detail_msg)
elif detail.get().code() == CFlightStatusUnauthorized:
raise FlightUnauthorizedError(message)
raise FlightUnauthorizedError(message, detail_msg)
elif detail.get().code() == CFlightStatusUnavailable:
raise FlightUnavailableError(message)
raise FlightUnavailableError(message, detail_msg)

return check_status(status)

Expand Down Expand Up @@ -126,46 +127,54 @@ class CertKeyPair(_CertKeyPair):
cdef class FlightError(Exception):
cdef dict __dict__

def __init__(self, message='', extra_info=b''):
super().__init__(message)
self.extra_info = tobytes(extra_info)

cdef CStatus to_status(self):
message = tobytes("Flight error: {}".format(str(self)))
return CStatus_UnknownError(message)


cdef class FlightInternalError(FlightError, ArrowException):
cdef CStatus to_status(self):
return MakeFlightError(CFlightStatusInternal, tobytes(str(self)))
return MakeFlightError(CFlightStatusInternal,
tobytes(str(self)), self.extra_info)


cdef class FlightTimedOutError(FlightError, ArrowException):
cdef CStatus to_status(self):
return MakeFlightError(CFlightStatusTimedOut, tobytes(str(self)))
return MakeFlightError(CFlightStatusTimedOut,
tobytes(str(self)), self.extra_info)


cdef class FlightCancelledError(FlightError, ArrowException):
cdef CStatus to_status(self):
return MakeFlightError(CFlightStatusCancelled, tobytes(str(self)))
return MakeFlightError(CFlightStatusCancelled, tobytes(str(self)),
self.extra_info)


cdef class FlightServerError(FlightError, ArrowException):
cdef CStatus to_status(self):
return MakeFlightError(CFlightStatusFailed, tobytes(str(self)))
return MakeFlightError(CFlightStatusFailed, tobytes(str(self)),
self.extra_info)


cdef class FlightUnauthenticatedError(FlightError, ArrowException):
cdef CStatus to_status(self):
return MakeFlightError(
CFlightStatusUnauthenticated, tobytes(str(self)))
CFlightStatusUnauthenticated, tobytes(str(self)), self.extra_info)


cdef class FlightUnauthorizedError(FlightError, ArrowException):
cdef CStatus to_status(self):
return MakeFlightError(CFlightStatusUnauthorized, tobytes(str(self)))
return MakeFlightError(CFlightStatusUnauthorized, tobytes(str(self)),
self.extra_info)


cdef class FlightUnavailableError(FlightError, ArrowException):
cdef CStatus to_status(self):
return MakeFlightError(CFlightStatusUnavailable, tobytes(str(self)))

return MakeFlightError(CFlightStatusUnavailable, tobytes(str(self)),
self.extra_info)

cdef class Action:
"""An action executable on a Flight service."""
Expand Down
5 changes: 5 additions & 0 deletions python/pyarrow/includes/libarrow_flight.pxd
Expand Up @@ -323,12 +323,17 @@ cdef extern from "arrow/flight/api.h" namespace "arrow" nogil:

cdef cppclass FlightStatusDetail" arrow::flight::FlightStatusDetail":
CFlightStatusCode code()
c_string extra_info()
@staticmethod
shared_ptr[FlightStatusDetail] UnwrapStatus(const CStatus& status)

cdef CStatus MakeFlightError" arrow::flight::MakeFlightError" \
(CFlightStatusCode code, const c_string& message)

cdef CStatus MakeFlightError" arrow::flight::MakeFlightError" \
(CFlightStatusCode code,
const c_string& message,
const c_string& extra_info)

# Callbacks for implementing Flight servers
# Use typedef to emulate syntax for std::function<void(..)>
Expand Down

0 comments on commit 0facdc7

Please sign in to comment.