Skip to content

Commit

Permalink
ARROW-8112: [FlightRPC][C++] make sure status codes round-trip throug…
Browse files Browse the repository at this point in the history
…h gRPC

There are still unmapped status codes, but these are the ones that correspond closely to a gRPC one. (OutOfMemory, for instance, doesn't quite line up with RESOURCE_EXHAUSTED since the latter is intended for some application-level resource like a disk quota, not an internal server error.)

Closes #6615 from lidavidm/flight-missing-codes

Lead-authored-by: David Li <li.davidm96@gmail.com>
Co-authored-by: Antoine Pitrou <antoine@python.org>
Signed-off-by: Antoine Pitrou <antoine@python.org>
  • Loading branch information
lidavidm and pitrou committed Mar 17, 2020
1 parent 948c8f6 commit ec7fce5
Show file tree
Hide file tree
Showing 7 changed files with 189 additions and 26 deletions.
16 changes: 8 additions & 8 deletions cpp/src/arrow/flight/client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ class GrpcIpcMessageReader : public ipc::MessageReader {
protected:
Status OverrideWithServerError(Status&& st) {
// Get the gRPC status if not OK, to propagate any server error message
RETURN_NOT_OK(internal::FromGrpcStatus(stream_->Finish()));
RETURN_NOT_OK(internal::FromGrpcStatus(stream_->Finish(), &rpc_->context));
return std::move(st);
}

Expand Down Expand Up @@ -458,7 +458,7 @@ class DoPutPayloadWriter : public ipc::internal::IpcPayloadWriter {
pb::PutResult message;
while (writer_->Read(&message)) {
}
RETURN_NOT_OK(internal::FromGrpcStatus(writer_->Finish()));
RETURN_NOT_OK(internal::FromGrpcStatus(writer_->Finish(), &rpc_->context));
if (!finished_writes) {
return Status::UnknownError(
"Could not finish writing record batches before closing");
Expand Down Expand Up @@ -577,7 +577,7 @@ class FlightClient::FlightClientImpl {
RETURN_NOT_OK(auth_handler_->Authenticate(&outgoing, &incoming));
// Explicitly close our side of the connection
bool finished_writes = stream->WritesDone();
RETURN_NOT_OK(internal::FromGrpcStatus(stream->Finish()));
RETURN_NOT_OK(internal::FromGrpcStatus(stream->Finish(), &rpc.context));
if (!finished_writes) {
return Status::UnknownError("Could not finish writing before closing");
}
Expand All @@ -604,7 +604,7 @@ class FlightClient::FlightClientImpl {
}

listing->reset(new SimpleFlightListing(std::move(flights)));
return internal::FromGrpcStatus(stream->Finish());
return internal::FromGrpcStatus(stream->Finish(), &rpc.context);
}

Status DoAction(const FlightCallOptions& options, const Action& action,
Expand All @@ -628,7 +628,7 @@ class FlightClient::FlightClientImpl {

*results = std::unique_ptr<ResultStream>(
new SimpleResultStream(std::move(materialized_results)));
return internal::FromGrpcStatus(stream->Finish());
return internal::FromGrpcStatus(stream->Finish(), &rpc.context);
}

Status ListActions(const FlightCallOptions& options, std::vector<ActionType>* types) {
Expand All @@ -645,7 +645,7 @@ class FlightClient::FlightClientImpl {
RETURN_NOT_OK(internal::FromProto(pb_type, &type));
types->emplace_back(std::move(type));
}
return internal::FromGrpcStatus(stream->Finish());
return internal::FromGrpcStatus(stream->Finish(), &rpc.context);
}

Status GetFlightInfo(const FlightCallOptions& options,
Expand All @@ -659,7 +659,7 @@ class FlightClient::FlightClientImpl {
ClientRpc rpc(options);
RETURN_NOT_OK(rpc.SetToken(auth_handler_.get()));
Status s = internal::FromGrpcStatus(
stub_->GetFlightInfo(&rpc.context, pb_descriptor, &pb_response));
stub_->GetFlightInfo(&rpc.context, pb_descriptor, &pb_response), &rpc.context);
RETURN_NOT_OK(s);

FlightInfo::Data info_data;
Expand All @@ -678,7 +678,7 @@ class FlightClient::FlightClientImpl {
ClientRpc rpc(options);
RETURN_NOT_OK(rpc.SetToken(auth_handler_.get()));
Status s = internal::FromGrpcStatus(
stub_->GetSchema(&rpc.context, pb_descriptor, &pb_response));
stub_->GetSchema(&rpc.context, pb_descriptor, &pb_response), &rpc.context);
RETURN_NOT_OK(s);

std::string str;
Expand Down
25 changes: 25 additions & 0 deletions cpp/src/arrow/flight/flight_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,24 @@ TEST(TestFlight, RoundtripStatus) {
MakeFlightError(FlightStatusCode::Unavailable, "Test message"));
ASSERT_NE(nullptr, detail);
ASSERT_EQ(FlightStatusCode::Unavailable, detail->code());

Status status = internal::FromGrpcStatus(
internal::ToGrpcStatus(Status::NotImplemented("Sentinel")));
ASSERT_TRUE(status.IsNotImplemented());
ASSERT_THAT(status.message(), ::testing::HasSubstr("Sentinel"));

status = internal::FromGrpcStatus(internal::ToGrpcStatus(Status::Invalid("Sentinel")));
ASSERT_TRUE(status.IsInvalid());
ASSERT_THAT(status.message(), ::testing::HasSubstr("Sentinel"));

status = internal::FromGrpcStatus(internal::ToGrpcStatus(Status::KeyError("Sentinel")));
ASSERT_TRUE(status.IsKeyError());
ASSERT_THAT(status.message(), ::testing::HasSubstr("Sentinel"));

status =
internal::FromGrpcStatus(internal::ToGrpcStatus(Status::AlreadyExists("Sentinel")));
ASSERT_TRUE(status.IsAlreadyExists());
ASSERT_THAT(status.message(), ::testing::HasSubstr("Sentinel"));
}

TEST(TestFlight, GetPort) {
Expand Down Expand Up @@ -965,6 +983,13 @@ TEST_F(TestFlightClient, DoAction) {
ASSERT_EQ(nullptr, result);
}

TEST_F(TestFlightClient, RoundTripStatus) {
const auto descr = FlightDescriptor::Command("status-outofmemory");
std::unique_ptr<FlightInfo> info;
const auto status = client_->GetFlightInfo(descr, &info);
ASSERT_RAISES(OutOfMemory, status);
}

TEST_F(TestFlightClient, Issue5095) {
// Make sure the server-side error message is reflected to the
// client
Expand Down
111 changes: 107 additions & 4 deletions cpp/src/arrow/flight/internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "arrow/flight/protocol_internal.h"

#include <cstddef>
#include <map>
#include <memory>
#include <sstream>
#include <string>
Expand All @@ -45,12 +46,72 @@ namespace flight {
namespace internal {

const char* kGrpcAuthHeader = "auth-token-bin";
const char* kGrpcStatusCodeHeader = "x-arrow-status";
const char* kGrpcStatusMessageHeader = "x-arrow-status-message-bin";
const char* kGrpcStatusDetailHeader = "x-arrow-status-detail-bin";

static Status StatusCodeFromString(const grpc::string_ref& code_ref, StatusCode* code) {
// Bounce through std::string to get a proper null-terminated C string
const auto code_int = std::atoi(std::string(code_ref.data(), code_ref.size()).c_str());
switch (code_int) {
case static_cast<int>(StatusCode::OutOfMemory):
case static_cast<int>(StatusCode::KeyError):
case static_cast<int>(StatusCode::TypeError):
case static_cast<int>(StatusCode::Invalid):
case static_cast<int>(StatusCode::IOError):
case static_cast<int>(StatusCode::CapacityError):
case static_cast<int>(StatusCode::IndexError):
case static_cast<int>(StatusCode::UnknownError):
case static_cast<int>(StatusCode::NotImplemented):
case static_cast<int>(StatusCode::SerializationError):
case static_cast<int>(StatusCode::RError):
case static_cast<int>(StatusCode::CodeGenError):
case static_cast<int>(StatusCode::ExpressionValidationError):
case static_cast<int>(StatusCode::ExecutionError):
case static_cast<int>(StatusCode::AlreadyExists): {
*code = static_cast<StatusCode>(code_int);
return Status::OK();
}
default:
// Code is invalid
return Status::UnknownError("Unknown Arrow status code", code_ref);
}
}

Status FromGrpcStatus(const grpc::Status& grpc_status) {
if (grpc_status.ok()) {
return Status::OK();
/// Try to extract a status from gRPC trailers.
/// Return Status::OK if found, an error otherwise.
static Status FromGrpcContext(const grpc::ClientContext& ctx, Status* status) {
const std::multimap<grpc::string_ref, grpc::string_ref>& trailers =
ctx.GetServerTrailingMetadata();
const auto code_val = trailers.find(kGrpcStatusCodeHeader);
if (code_val == trailers.end()) {
return Status::IOError("Status code header not found");
}

const grpc::string_ref code_ref = (*code_val).second;
StatusCode code;
RETURN_NOT_OK(StatusCodeFromString(code_ref, &code));

const auto message_val = trailers.find(kGrpcStatusMessageHeader);
if (message_val == trailers.end()) {
return Status::IOError("Status message header not found");
}

const grpc::string_ref message_ref = (*message_val).second;
std::string message = std::string(message_ref.data(), message_ref.size());
const auto detail_val = trailers.find(kGrpcStatusDetailHeader);
if (detail_val != trailers.end()) {
const grpc::string_ref detail_ref = (*detail_val).second;
message += ". Detail: ";
message += std::string(detail_ref.data(), detail_ref.size());
}
*status = Status(code, message);
return Status::OK();
}

/// Convert a gRPC status to an Arrow status, ignoring any
/// implementation-defined headers that encode further detail.
static Status FromGrpcCode(const grpc::Status& grpc_status) {
switch (grpc_status.error_code()) {
case grpc::StatusCode::OK:
return Status::OK();
Expand Down Expand Up @@ -123,7 +184,28 @@ Status FromGrpcStatus(const grpc::Status& grpc_status) {
}
}

grpc::Status ToGrpcStatus(const Status& arrow_status) {
Status FromGrpcStatus(const grpc::Status& grpc_status, grpc::ClientContext* ctx) {
const Status status = FromGrpcCode(grpc_status);

if (!status.ok() && ctx) {
Status arrow_status;

if (!FromGrpcContext(*ctx, &arrow_status).ok()) {
// If we fail to decode a more detailed status from the headers,
// proceed normally
return status;
}

if (status.detail()) {
return arrow_status.WithDetail(status.detail());
}
return arrow_status;
}
return status;
}

/// Convert an Arrow status to a gRPC status.
static grpc::Status ToRawGrpcStatus(const Status& arrow_status) {
if (arrow_status.ok()) {
return grpc::Status::OK;
}
Expand Down Expand Up @@ -164,10 +246,31 @@ grpc::Status ToGrpcStatus(const Status& arrow_status) {
grpc_code = grpc::StatusCode::UNIMPLEMENTED;
} else if (arrow_status.IsInvalid()) {
grpc_code = grpc::StatusCode::INVALID_ARGUMENT;
} else if (arrow_status.IsKeyError()) {
grpc_code = grpc::StatusCode::NOT_FOUND;
} else if (arrow_status.IsAlreadyExists()) {
grpc_code = grpc::StatusCode::ALREADY_EXISTS;
}
return grpc::Status(grpc_code, message);
}

/// Convert an Arrow status to a gRPC status, and add extra headers to
/// the response to encode the original Arrow status.
grpc::Status ToGrpcStatus(const Status& arrow_status, grpc::ServerContext* ctx) {
grpc::Status status = ToRawGrpcStatus(arrow_status);
if (!status.ok() && ctx) {
const std::string code = std::to_string(static_cast<int>(arrow_status.code()));
ctx->AddTrailingMetadata(internal::kGrpcStatusCodeHeader, code);
ctx->AddTrailingMetadata(internal::kGrpcStatusMessageHeader, arrow_status.message());
if (arrow_status.detail()) {
const std::string detail_string = arrow_status.detail()->ToString();
ctx->AddTrailingMetadata(internal::kGrpcStatusDetailHeader, detail_string);
}
}

return status;
}

// ActionType

Status FromProto(const pb::ActionType& pb_type, ActionType* type) {
Expand Down
20 changes: 18 additions & 2 deletions cpp/src/arrow/flight/internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,14 +67,30 @@ namespace internal {
ARROW_FLIGHT_EXPORT
extern const char* kGrpcAuthHeader;

/// The name of the header used to pass the exact Arrow status code.
ARROW_FLIGHT_EXPORT
extern const char* kGrpcStatusCodeHeader;

/// The name of the header used to pass the exact Arrow status message.
ARROW_FLIGHT_EXPORT
extern const char* kGrpcStatusMessageHeader;

/// The name of the header used to pass the exact Arrow status detail.
ARROW_FLIGHT_EXPORT
extern const char* kGrpcStatusDetailHeader;

ARROW_FLIGHT_EXPORT
Status SchemaToString(const Schema& schema, std::string* out);

/// Convert a gRPC status to an Arrow status. Optionally, provide a
/// ClientContext to recover the exact Arrow status if it was passed
/// over the wire.
ARROW_FLIGHT_EXPORT
Status FromGrpcStatus(const grpc::Status& grpc_status);
Status FromGrpcStatus(const grpc::Status& grpc_status,
grpc::ClientContext* ctx = nullptr);

ARROW_FLIGHT_EXPORT
grpc::Status ToGrpcStatus(const Status& arrow_status);
grpc::Status ToGrpcStatus(const Status& arrow_status, grpc::ServerContext* ctx = nullptr);

// These functions depend on protobuf types which are not exported in the Flight DLL.

Expand Down
24 changes: 14 additions & 10 deletions cpp/src/arrow/flight/server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,8 @@ class GrpcServerAuthSender : public ServerAuthSender {

class FlightServiceImpl;
class GrpcServerCallContext : public ServerCallContext {
explicit GrpcServerCallContext(grpc::ServerContext* context) : context_(context) {}

const std::string& peer_identity() const override { return peer_identity_; }

// Helper method that runs interceptors given the result of an RPC,
Expand All @@ -248,7 +250,10 @@ class GrpcServerCallContext : public ServerCallContext {
for (const auto& instance : middleware_) {
instance->CallCompleted(status);
}
return internal::ToGrpcStatus(status);

// Set custom headers to map the exact Arrow status for clients
// who want it.
return internal::ToGrpcStatus(status, context_);
}

ServerMiddleware* GetMiddleware(const std::string& key) const override {
Expand Down Expand Up @@ -334,7 +339,6 @@ class FlightServiceImpl : public FlightService::Service {
// Authenticate the client (if applicable) and construct the call context
grpc::Status CheckAuth(const FlightMethod& method, ServerContext* context,
GrpcServerCallContext& flight_context) {
flight_context.context_ = context;
if (!auth_handler_) {
flight_context.peer_identity_ = "";
} else {
Expand Down Expand Up @@ -386,7 +390,7 @@ class FlightServiceImpl : public FlightService::Service {
grpc::Status Handshake(
ServerContext* context,
grpc::ServerReaderWriter<pb::HandshakeResponse, pb::HandshakeRequest>* stream) {
GrpcServerCallContext flight_context;
GrpcServerCallContext flight_context(context);
GRPC_RETURN_NOT_GRPC_OK(
MakeCallContext(FlightMethod::Handshake, context, flight_context));

Expand All @@ -405,7 +409,7 @@ class FlightServiceImpl : public FlightService::Service {

grpc::Status ListFlights(ServerContext* context, const pb::Criteria* request,
ServerWriter<pb::FlightInfo>* writer) {
GrpcServerCallContext flight_context;
GrpcServerCallContext flight_context(context);
GRPC_RETURN_NOT_GRPC_OK(
CheckAuth(FlightMethod::ListFlights, context, flight_context));

Expand All @@ -428,7 +432,7 @@ class FlightServiceImpl : public FlightService::Service {

grpc::Status GetFlightInfo(ServerContext* context, const pb::FlightDescriptor* request,
pb::FlightInfo* response) {
GrpcServerCallContext flight_context;
GrpcServerCallContext flight_context(context);
GRPC_RETURN_NOT_GRPC_OK(
CheckAuth(FlightMethod::GetFlightInfo, context, flight_context));

Expand All @@ -453,7 +457,7 @@ class FlightServiceImpl : public FlightService::Service {

grpc::Status GetSchema(ServerContext* context, const pb::FlightDescriptor* request,
pb::SchemaResult* response) {
GrpcServerCallContext flight_context;
GrpcServerCallContext flight_context(context);
GRPC_RETURN_NOT_GRPC_OK(CheckAuth(FlightMethod::GetSchema, context, flight_context));

CHECK_ARG_NOT_NULL(flight_context, request, "FlightDescriptor cannot be null");
Expand All @@ -477,7 +481,7 @@ class FlightServiceImpl : public FlightService::Service {

grpc::Status DoGet(ServerContext* context, const pb::Ticket* request,
ServerWriter<pb::FlightData>* writer) {
GrpcServerCallContext flight_context;
GrpcServerCallContext flight_context(context);
GRPC_RETURN_NOT_GRPC_OK(CheckAuth(FlightMethod::DoGet, context, flight_context));

CHECK_ARG_NOT_NULL(flight_context, request, "ticket cannot be null");
Expand Down Expand Up @@ -517,7 +521,7 @@ class FlightServiceImpl : public FlightService::Service {

grpc::Status DoPut(ServerContext* context,
grpc::ServerReaderWriter<pb::PutResult, pb::FlightData>* reader) {
GrpcServerCallContext flight_context;
GrpcServerCallContext flight_context(context);
GRPC_RETURN_NOT_GRPC_OK(CheckAuth(FlightMethod::DoPut, context, flight_context));

auto message_reader =
Expand All @@ -532,7 +536,7 @@ class FlightServiceImpl : public FlightService::Service {

grpc::Status ListActions(ServerContext* context, const pb::Empty* request,
ServerWriter<pb::ActionType>* writer) {
GrpcServerCallContext flight_context;
GrpcServerCallContext flight_context(context);
GRPC_RETURN_NOT_GRPC_OK(
CheckAuth(FlightMethod::ListActions, context, flight_context));
// Retrieve the listing from the implementation
Expand All @@ -543,7 +547,7 @@ class FlightServiceImpl : public FlightService::Service {

grpc::Status DoAction(ServerContext* context, const pb::Action* request,
ServerWriter<pb::Result>* writer) {
GrpcServerCallContext flight_context;
GrpcServerCallContext flight_context(context);
GRPC_RETURN_NOT_GRPC_OK(CheckAuth(FlightMethod::DoAction, context, flight_context));
CHECK_ARG_NOT_NULL(flight_context, request, "Action cannot be null");
Action action;
Expand Down
Loading

0 comments on commit ec7fce5

Please sign in to comment.