Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
291 changes: 280 additions & 11 deletions arrow-flight/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ use futures::{
stream::{self, BoxStream},
};
use prost::Message;
use tonic::codegen::{Body, StdError};
use tonic::{metadata::MetadataMap, transport::Channel};

use crate::error::{FlightError, Result};
Expand Down Expand Up @@ -67,22 +68,28 @@ use crate::streams::{FallibleRequestStream, FallibleTonicResponseStream};
/// # }
/// ```
#[derive(Debug)]
pub struct FlightClient {
pub struct FlightClient<T = Channel> {
/// Optional grpc header metadata to include with each request
metadata: MetadataMap,

/// The inner client
inner: FlightServiceClient<Channel>,
inner: FlightServiceClient<T>,
}

impl FlightClient {
/// Creates a client client with the provided [`Channel`]
pub fn new(channel: Channel) -> Self {
Self::new_from_inner(FlightServiceClient::new(channel))
impl<T> FlightClient<T>
where
T: tonic::client::GrpcService<tonic::body::Body>,
T::Error: Into<StdError>,
T::ResponseBody: Body<Data = Bytes> + std::marker::Send + 'static,
<T::ResponseBody as Body>::Error: Into<StdError> + std::marker::Send,
{
/// Creates a client with the provided transport
pub fn new(inner: T) -> Self {
Self::new_from_inner(FlightServiceClient::new(inner))
}

/// Creates a new higher level client with the provided lower level client
pub fn new_from_inner(inner: FlightServiceClient<Channel>) -> Self {
pub fn new_from_inner(inner: FlightServiceClient<T>) -> Self {
Self {
metadata: MetadataMap::new(),
inner,
Expand Down Expand Up @@ -120,19 +127,19 @@ impl FlightClient {

/// Return a reference to the underlying tonic
/// [`FlightServiceClient`]
pub fn inner(&self) -> &FlightServiceClient<Channel> {
pub fn inner(&self) -> &FlightServiceClient<T> {
&self.inner
}

/// Return a mutable reference to the underlying tonic
/// [`FlightServiceClient`]
pub fn inner_mut(&mut self) -> &mut FlightServiceClient<Channel> {
pub fn inner_mut(&mut self) -> &mut FlightServiceClient<T> {
&mut self.inner
}

/// Consume this client and return the underlying tonic
/// [`FlightServiceClient`]
pub fn into_inner(self) -> FlightServiceClient<Channel> {
pub fn into_inner(self) -> FlightServiceClient<T> {
self.inner
}

Expand Down Expand Up @@ -664,10 +671,272 @@ impl FlightClient {
}

/// return a Request, adding any configured metadata
fn make_request<T>(&self, t: T) -> tonic::Request<T> {
fn make_request<R>(&self, t: R) -> tonic::Request<R> {
// Pass along metadata
let mut request = tonic::Request::new(t);
*request.metadata_mut() = self.metadata.clone();
request
}
}

#[cfg(test)]
mod tests {
use super::FlightClient;
use crate::encode::FlightDataEncoderBuilder;
use crate::flight_service_server::{FlightService, FlightServiceServer};
use crate::{
Action, ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo,
HandshakeRequest, HandshakeResponse, PollInfo, PutResult, SchemaResult, Ticket,
};
use arrow_array::{RecordBatch, UInt64Array};
use bytes::Bytes;
use futures::{StreamExt, TryStreamExt, stream::BoxStream};
use std::net::SocketAddr;
use std::sync::{Arc, Mutex};
use std::time::Duration;
use tokio::net::TcpListener;
use tokio::task::JoinHandle;
use tonic::metadata::MetadataMap;
use tonic::service::interceptor::InterceptedService;
use tonic::transport::Channel;
use tonic::{Request, Response, Status, Streaming};
use uuid::Uuid;

/// Minimal `FlightService` that records request metadata and serves a
/// configured `do_get` response. Other RPCs return `Unimplemented`.
#[derive(Debug, Clone, Default)]
struct InterceptorTestServer {
state: Arc<Mutex<InterceptorTestState>>,
}

#[derive(Debug, Default)]
struct InterceptorTestState {
do_get_request: Option<Ticket>,
do_get_response: Option<Vec<Result<RecordBatch, Status>>>,
last_request_metadata: Option<MetadataMap>,
}

impl InterceptorTestServer {
fn save_metadata<T>(&self, request: &Request<T>) {
self.state.lock().unwrap().last_request_metadata = Some(request.metadata().clone());
}

fn set_do_get_response(&self, response: Vec<Result<RecordBatch, Status>>) {
self.state.lock().unwrap().do_get_response = Some(response);
}

fn take_do_get_request(&self) -> Option<Ticket> {
self.state.lock().unwrap().do_get_request.take()
}

fn take_last_request_metadata(&self) -> Option<MetadataMap> {
self.state.lock().unwrap().last_request_metadata.take()
}
}

#[tonic::async_trait]
impl FlightService for InterceptorTestServer {
type HandshakeStream = BoxStream<'static, Result<HandshakeResponse, Status>>;
type ListFlightsStream = BoxStream<'static, Result<FlightInfo, Status>>;
type DoGetStream = BoxStream<'static, Result<FlightData, Status>>;
type DoPutStream = BoxStream<'static, Result<PutResult, Status>>;
type DoActionStream = BoxStream<'static, Result<crate::Result, Status>>;
type ListActionsStream = BoxStream<'static, Result<ActionType, Status>>;
type DoExchangeStream = BoxStream<'static, Result<FlightData, Status>>;

async fn do_get(
&self,
request: Request<Ticket>,
) -> Result<Response<Self::DoGetStream>, Status> {
self.save_metadata(&request);
let mut state = self.state.lock().unwrap();
state.do_get_request = Some(request.into_inner());

let batches = state
.do_get_response
.take()
.ok_or_else(|| Status::internal("no do_get response configured"))?;
let batch_stream = futures::stream::iter(batches).map_err(Into::into);
let stream = FlightDataEncoderBuilder::new()
.build(batch_stream)
.map_err(Into::into);
Ok(Response::new(stream.boxed()))
}

async fn handshake(
&self,
_: Request<Streaming<HandshakeRequest>>,
) -> Result<Response<Self::HandshakeStream>, Status> {
Err(Status::unimplemented(""))
}
async fn list_flights(
&self,
_: Request<Criteria>,
) -> Result<Response<Self::ListFlightsStream>, Status> {
Err(Status::unimplemented(""))
}
async fn get_flight_info(
&self,
_: Request<FlightDescriptor>,
) -> Result<Response<FlightInfo>, Status> {
Err(Status::unimplemented(""))
}
async fn poll_flight_info(
&self,
_: Request<FlightDescriptor>,
) -> Result<Response<PollInfo>, Status> {
Err(Status::unimplemented(""))
}
async fn get_schema(
&self,
_: Request<FlightDescriptor>,
) -> Result<Response<SchemaResult>, Status> {
Err(Status::unimplemented(""))
}
async fn do_put(
&self,
_: Request<Streaming<FlightData>>,
) -> Result<Response<Self::DoPutStream>, Status> {
Err(Status::unimplemented(""))
}
async fn do_action(
&self,
_: Request<Action>,
) -> Result<Response<Self::DoActionStream>, Status> {
Err(Status::unimplemented(""))
}
async fn list_actions(
&self,
_: Request<Empty>,
) -> Result<Response<Self::ListActionsStream>, Status> {
Err(Status::unimplemented(""))
}
async fn do_exchange(
&self,
_: Request<Streaming<FlightData>>,
) -> Result<Response<Self::DoExchangeStream>, Status> {
Err(Status::unimplemented(""))
}
}

/// Spawns the test server on a background task and exposes a connected channel.
struct InterceptorTestFixture {
shutdown: Option<tokio::sync::oneshot::Sender<()>>,
addr: SocketAddr,
handle: Option<JoinHandle<Result<(), tonic::transport::Error>>>,
}

impl InterceptorTestFixture {
async fn new(server: FlightServiceServer<InterceptorTestServer>) -> Self {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let (tx, rx) = tokio::sync::oneshot::channel();
let shutdown_future = async move {
rx.await.ok();
};
let serve = tonic::transport::Server::builder()
.timeout(Duration::from_secs(30))
.add_service(server)
.serve_with_incoming_shutdown(
tokio_stream::wrappers::TcpListenerStream::new(listener),
shutdown_future,
);
let handle = tokio::task::spawn(serve);
Self {
shutdown: Some(tx),
addr,
handle: Some(handle),
}
}

async fn channel(&self) -> Channel {
let url = format!("http://{}", self.addr);
tonic::transport::Endpoint::from_shared(url)
.expect("valid endpoint")
.timeout(Duration::from_secs(30))
.connect()
.await
.expect("error connecting to server")
}

async fn shutdown_and_wait(mut self) {
if let Some(tx) = self.shutdown.take() {
tx.send(()).expect("server quit early");
}
if let Some(handle) = self.handle.take() {
handle
.await
.expect("task join error (panic?)")
.expect("server error at shutdown");
}
}
}

/// Integration test: a tonic [`Channel`] wrapped in an [`InterceptedService`]
/// that injects a custom header is passed to [`FlightClient`], and the server
/// observes the header on the request.
#[tokio::test]
async fn test_flight_client_with_intercepted_channel_passes_custom_header() {
let test_server = InterceptorTestServer::default();
let fixture =
InterceptorTestFixture::new(FlightServiceServer::new(test_server.clone())).await;

let channel = fixture.channel().await;

let header_name = "x-random-header";
let header_value = format!("random-{}", Uuid::new_v4());
let header_value_for_interceptor = header_value.clone();

let interceptor = move |mut req: Request<()>| -> Result<Request<()>, Status> {
req.metadata_mut().insert(
header_name,
header_value_for_interceptor
.parse()
.expect("valid metadata value"),
);
Ok(req)
};

let intercepted = InterceptedService::new(channel, interceptor);
let mut client = FlightClient::new(intercepted);

let ticket = Ticket {
ticket: Bytes::from("dummy-ticket"),
};

let batch = RecordBatch::try_from_iter(vec![(
"col",
Arc::new(UInt64Array::from_iter([1, 2, 3, 4])) as _,
)])
.unwrap();

test_server.set_do_get_response(vec![Ok(batch.clone())]);

let response_stream = client
.do_get(ticket.clone())
.await
.expect("error making do_get request");

let response: Vec<RecordBatch> = response_stream
.try_collect()
.await
.expect("error streaming data");

assert_eq!(response, vec![batch]);
assert_eq!(test_server.take_do_get_request(), Some(ticket));

let metadata = test_server
.take_last_request_metadata()
.expect("server received headers")
.into_headers();

let received = metadata
.get(header_name)
.expect("interceptor header missing on server")
.to_str()
.expect("ascii header value");
assert_eq!(received, header_value);

fixture.shutdown_and_wait().await;
}
}
Loading