Skip to content

Commit

Permalink
apacheGH-35375: [C++][FlightRPC] Add `arrow::flight::ServerCallContex…
Browse files Browse the repository at this point in the history
…t::incoming_headers()` (apache#35376)

### Rationale for this change

It returns headers sent by a client.

We can get them only in `arrow::flight::ServerMiddlewareCactory::StartCall()` for now. But they're useful for in each RPC call.

### What changes are included in this PR?

Add the method.

### Are these changes tested?

Yes.

### Are there any user-facing changes?

Yes.
* Closes: apache#35375

Authored-by: Sutou Kouhei <kou@clear-code.com>
Signed-off-by: Sutou Kouhei <kou@clear-code.com>
  • Loading branch information
kou authored and liujiacheng777 committed May 11, 2023
1 parent 7e251d8 commit d936fb1
Show file tree
Hide file tree
Showing 8 changed files with 66 additions and 16 deletions.
2 changes: 1 addition & 1 deletion cpp/src/arrow/flight/client.h
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ class ARROW_FLIGHT_EXPORT FlightClient {
/// \param[in] username Username to use
/// \param[in] password Password to use
/// \return Arrow result with bearer token and status OK if client authenticated
/// sucessfully
/// successfully
arrow::Result<std::pair<std::string, std::string>> AuthenticateBasicToken(
const FlightCallOptions& options, const std::string& username,
const std::string& password);
Expand Down
24 changes: 24 additions & 0 deletions cpp/src/arrow/flight/flight_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,30 @@ TEST(TestFlight, DISABLED_IpV6Port) {
ASSERT_OK(client->ListFlights());
}

TEST(TestFlight, ServerCallContextIncomingHeaders) {
auto server = ExampleTestServer();
ASSERT_OK_AND_ASSIGN(auto location, Location::ForGrpcTcp("localhost", 0));
FlightServerOptions options(location);
ASSERT_OK(server->Init(options));

ASSERT_OK_AND_ASSIGN(auto client, FlightClient::Connect(server->location()));
Action action;
action.type = "list-incoming-headers";
action.body = Buffer::FromString("test-header");
FlightCallOptions call_options;
call_options.headers.emplace_back("test-header1", "value1");
call_options.headers.emplace_back("test-header2", "value2");
ASSERT_OK_AND_ASSIGN(auto stream, client->DoAction(call_options, action));
ASSERT_OK_AND_ASSIGN(auto result, stream->Next());
ASSERT_NE(result.get(), nullptr);
ASSERT_EQ(result->body->ToString(), "test-header1: value1");
ASSERT_OK_AND_ASSIGN(result, stream->Next());
ASSERT_NE(result.get(), nullptr);
ASSERT_EQ(result->body->ToString(), "test-header2: value2");
ASSERT_OK_AND_ASSIGN(result, stream->Next());
ASSERT_EQ(result.get(), nullptr);
}

// ----------------------------------------------------------------------
// Client tests

Expand Down
8 changes: 1 addition & 7 deletions cpp/src/arrow/flight/middleware.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,23 +20,17 @@

#pragma once

#include <map>
#include <memory>
#include <string>
#include <string_view>
#include <utility>

#include "arrow/flight/visibility.h" // IWYU pragma: keep
#include "arrow/flight/types.h"
#include "arrow/status.h"

namespace arrow {
namespace flight {

/// \brief Headers sent from the client or server.
///
/// Header values are ordered.
using CallHeaders = std::multimap<std::string_view, std::string_view>;

/// \brief A write-only wrapper around headers for an RPC call.
class ARROW_FLIGHT_EXPORT AddCallHeaders {
public:
Expand Down
2 changes: 2 additions & 0 deletions cpp/src/arrow/flight/server.h
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,8 @@ class ARROW_FLIGHT_EXPORT ServerCallContext {
/// \brief Check if the current RPC has been cancelled (by the client, by
/// a network error, etc.).
virtual bool is_cancelled() const = 0;
/// \brief The headers sent by the client for this call.
virtual const CallHeaders& incoming_headers() const = 0;
};

class ARROW_FLIGHT_EXPORT FlightServerOptions {
Expand Down
19 changes: 19 additions & 0 deletions cpp/src/arrow/flight/test_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -474,12 +474,31 @@ class FlightTestServer : public FlightServerBase {
return Status::OK();
}

Status ListIncomingHeaders(const ServerCallContext& context, const Action& action,
std::unique_ptr<ResultStream>* out) {
std::vector<Result> results;
std::string_view prefix(*action.body);
for (const auto& header : context.incoming_headers()) {
if (header.first.substr(0, prefix.size()) != prefix) {
continue;
}
Result result;
result.body = Buffer::FromString(std::string(header.first) + ": " +
std::string(header.second));
results.push_back(result);
}
*out = std::make_unique<SimpleResultStream>(std::move(results));
return Status::OK();
}

Status DoAction(const ServerCallContext& context, const Action& action,
std::unique_ptr<ResultStream>* out) override {
if (action.type == "action1") {
return RunAction1(action, out);
} else if (action.type == "action2") {
return RunAction2(out);
} else if (action.type == "list-incoming-headers") {
return ListIncomingHeaders(context, action, out);
} else {
return Status::NotImplemented(action.type);
}
Expand Down
19 changes: 11 additions & 8 deletions cpp/src/arrow/flight/transport/grpc/grpc_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -117,11 +117,18 @@ class GrpcServerAuthSender : public ServerAuthSender {

class GrpcServerCallContext : public ServerCallContext {
explicit GrpcServerCallContext(::grpc::ServerContext* context)
: context_(context), peer_(context_->peer()) {}
: context_(context), peer_(context_->peer()) {
for (const auto& entry : context->client_metadata()) {
incoming_headers_.insert(
{std::string_view(entry.first.data(), entry.first.length()),
std::string_view(entry.second.data(), entry.second.length())});
}
}

const std::string& peer_identity() const override { return peer_identity_; }
const std::string& peer() const override { return peer_; }
bool is_cancelled() const override { return context_->IsCancelled(); }
const CallHeaders& incoming_headers() const override { return incoming_headers_; }

// Helper method that runs interceptors given the result of an RPC,
// then returns the final gRPC status to send to the client
Expand Down Expand Up @@ -156,6 +163,7 @@ class GrpcServerCallContext : public ServerCallContext {
std::string peer_identity_;
std::vector<std::shared_ptr<ServerMiddleware>> middleware_;
std::unordered_map<std::string, std::shared_ptr<ServerMiddleware>> middleware_map_;
CallHeaders incoming_headers_;
};

class GrpcAddServerHeaders : public AddCallHeaders {
Expand Down Expand Up @@ -310,17 +318,12 @@ class GrpcServiceHandler final : public FlightService::Service {
GrpcServerCallContext& flight_context) {
// Run server middleware
const CallInfo info{method};
CallHeaders incoming_headers;
for (const auto& entry : context->client_metadata()) {
incoming_headers.insert(
{std::string_view(entry.first.data(), entry.first.length()),
std::string_view(entry.second.data(), entry.second.length())});
}

GrpcAddServerHeaders outgoing_headers(context);
for (const auto& factory : middleware_) {
std::shared_ptr<ServerMiddleware> instance;
Status result = factory.second->StartCall(info, incoming_headers, &instance);
Status result =
factory.second->StartCall(info, flight_context.incoming_headers(), &instance);
if (!result.ok()) {
// Interceptor rejected call, end the request on all existing
// interceptors
Expand Down
2 changes: 2 additions & 0 deletions cpp/src/arrow/flight/transport/ucx/ucx_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,11 @@ class UcxServerCallContext : public flight::ServerCallContext {
return nullptr;
}
bool is_cancelled() const override { return false; }
const CallHeaders& incoming_headers() const override { return incoming_headers_; }

private:
std::string peer_;
CallHeaders incoming_headers_;
};

class UcxServerStream : public internal::ServerDataStream {
Expand Down
6 changes: 6 additions & 0 deletions cpp/src/arrow/flight/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

#include <cstddef>
#include <cstdint>
#include <map>
#include <memory>
#include <string>
#include <string_view>
Expand Down Expand Up @@ -123,6 +124,11 @@ ARROW_FLIGHT_EXPORT
Status MakeFlightError(FlightStatusCode code, std::string message,
std::string extra_info = {});

/// \brief Headers sent from the client or server.
///
/// Header values are ordered.
using CallHeaders = std::multimap<std::string_view, std::string_view>;

/// \brief A TLS certificate plus key.
struct ARROW_FLIGHT_EXPORT CertKeyPair {
/// \brief The certificate in PEM format.
Expand Down

0 comments on commit d936fb1

Please sign in to comment.