diff --git a/arrow-flight/src/client.rs b/arrow-flight/src/client.rs index dac086271cb7..b2059a81d0df 100644 --- a/arrow-flight/src/client.rs +++ b/arrow-flight/src/client.rs @@ -31,6 +31,7 @@ use futures::{ stream::{self, BoxStream}, }; use prost::Message; +use tonic::codegen::{Body, StdError}; use tonic::{metadata::MetadataMap, transport::Channel}; use crate::error::{FlightError, Result}; @@ -67,22 +68,28 @@ use crate::streams::{FallibleRequestStream, FallibleTonicResponseStream}; /// # } /// ``` #[derive(Debug)] -pub struct FlightClient { +pub struct FlightClient { /// Optional grpc header metadata to include with each request metadata: MetadataMap, /// The inner client - inner: FlightServiceClient, + inner: FlightServiceClient, } -impl FlightClient { - /// Creates a client client with the provided [`Channel`] - pub fn new(channel: Channel) -> Self { - Self::new_from_inner(FlightServiceClient::new(channel)) +impl FlightClient +where + T: tonic::client::GrpcService, + T::Error: Into, + T::ResponseBody: Body + std::marker::Send + 'static, + ::Error: Into + std::marker::Send, +{ + /// Creates a client with the provided transport + pub fn new(inner: T) -> Self { + Self::new_from_inner(FlightServiceClient::new(inner)) } /// Creates a new higher level client with the provided lower level client - pub fn new_from_inner(inner: FlightServiceClient) -> Self { + pub fn new_from_inner(inner: FlightServiceClient) -> Self { Self { metadata: MetadataMap::new(), inner, @@ -120,19 +127,19 @@ impl FlightClient { /// Return a reference to the underlying tonic /// [`FlightServiceClient`] - pub fn inner(&self) -> &FlightServiceClient { + pub fn inner(&self) -> &FlightServiceClient { &self.inner } /// Return a mutable reference to the underlying tonic /// [`FlightServiceClient`] - pub fn inner_mut(&mut self) -> &mut FlightServiceClient { + pub fn inner_mut(&mut self) -> &mut FlightServiceClient { &mut self.inner } /// Consume this client and return the underlying tonic /// [`FlightServiceClient`] - pub fn into_inner(self) -> FlightServiceClient { + pub fn into_inner(self) -> FlightServiceClient { self.inner } @@ -664,10 +671,272 @@ impl FlightClient { } /// return a Request, adding any configured metadata - fn make_request(&self, t: T) -> tonic::Request { + fn make_request(&self, t: R) -> tonic::Request { // Pass along metadata let mut request = tonic::Request::new(t); *request.metadata_mut() = self.metadata.clone(); request } } + +#[cfg(test)] +mod tests { + use super::FlightClient; + use crate::encode::FlightDataEncoderBuilder; + use crate::flight_service_server::{FlightService, FlightServiceServer}; + use crate::{ + Action, ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo, + HandshakeRequest, HandshakeResponse, PollInfo, PutResult, SchemaResult, Ticket, + }; + use arrow_array::{RecordBatch, UInt64Array}; + use bytes::Bytes; + use futures::{StreamExt, TryStreamExt, stream::BoxStream}; + use std::net::SocketAddr; + use std::sync::{Arc, Mutex}; + use std::time::Duration; + use tokio::net::TcpListener; + use tokio::task::JoinHandle; + use tonic::metadata::MetadataMap; + use tonic::service::interceptor::InterceptedService; + use tonic::transport::Channel; + use tonic::{Request, Response, Status, Streaming}; + use uuid::Uuid; + + /// Minimal `FlightService` that records request metadata and serves a + /// configured `do_get` response. Other RPCs return `Unimplemented`. + #[derive(Debug, Clone, Default)] + struct InterceptorTestServer { + state: Arc>, + } + + #[derive(Debug, Default)] + struct InterceptorTestState { + do_get_request: Option, + do_get_response: Option>>, + last_request_metadata: Option, + } + + impl InterceptorTestServer { + fn save_metadata(&self, request: &Request) { + self.state.lock().unwrap().last_request_metadata = Some(request.metadata().clone()); + } + + fn set_do_get_response(&self, response: Vec>) { + self.state.lock().unwrap().do_get_response = Some(response); + } + + fn take_do_get_request(&self) -> Option { + self.state.lock().unwrap().do_get_request.take() + } + + fn take_last_request_metadata(&self) -> Option { + self.state.lock().unwrap().last_request_metadata.take() + } + } + + #[tonic::async_trait] + impl FlightService for InterceptorTestServer { + type HandshakeStream = BoxStream<'static, Result>; + type ListFlightsStream = BoxStream<'static, Result>; + type DoGetStream = BoxStream<'static, Result>; + type DoPutStream = BoxStream<'static, Result>; + type DoActionStream = BoxStream<'static, Result>; + type ListActionsStream = BoxStream<'static, Result>; + type DoExchangeStream = BoxStream<'static, Result>; + + async fn do_get( + &self, + request: Request, + ) -> Result, Status> { + self.save_metadata(&request); + let mut state = self.state.lock().unwrap(); + state.do_get_request = Some(request.into_inner()); + + let batches = state + .do_get_response + .take() + .ok_or_else(|| Status::internal("no do_get response configured"))?; + let batch_stream = futures::stream::iter(batches).map_err(Into::into); + let stream = FlightDataEncoderBuilder::new() + .build(batch_stream) + .map_err(Into::into); + Ok(Response::new(stream.boxed())) + } + + async fn handshake( + &self, + _: Request>, + ) -> Result, Status> { + Err(Status::unimplemented("")) + } + async fn list_flights( + &self, + _: Request, + ) -> Result, Status> { + Err(Status::unimplemented("")) + } + async fn get_flight_info( + &self, + _: Request, + ) -> Result, Status> { + Err(Status::unimplemented("")) + } + async fn poll_flight_info( + &self, + _: Request, + ) -> Result, Status> { + Err(Status::unimplemented("")) + } + async fn get_schema( + &self, + _: Request, + ) -> Result, Status> { + Err(Status::unimplemented("")) + } + async fn do_put( + &self, + _: Request>, + ) -> Result, Status> { + Err(Status::unimplemented("")) + } + async fn do_action( + &self, + _: Request, + ) -> Result, Status> { + Err(Status::unimplemented("")) + } + async fn list_actions( + &self, + _: Request, + ) -> Result, Status> { + Err(Status::unimplemented("")) + } + async fn do_exchange( + &self, + _: Request>, + ) -> Result, Status> { + Err(Status::unimplemented("")) + } + } + + /// Spawns the test server on a background task and exposes a connected channel. + struct InterceptorTestFixture { + shutdown: Option>, + addr: SocketAddr, + handle: Option>>, + } + + impl InterceptorTestFixture { + async fn new(server: FlightServiceServer) -> Self { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let (tx, rx) = tokio::sync::oneshot::channel(); + let shutdown_future = async move { + rx.await.ok(); + }; + let serve = tonic::transport::Server::builder() + .timeout(Duration::from_secs(30)) + .add_service(server) + .serve_with_incoming_shutdown( + tokio_stream::wrappers::TcpListenerStream::new(listener), + shutdown_future, + ); + let handle = tokio::task::spawn(serve); + Self { + shutdown: Some(tx), + addr, + handle: Some(handle), + } + } + + async fn channel(&self) -> Channel { + let url = format!("http://{}", self.addr); + tonic::transport::Endpoint::from_shared(url) + .expect("valid endpoint") + .timeout(Duration::from_secs(30)) + .connect() + .await + .expect("error connecting to server") + } + + async fn shutdown_and_wait(mut self) { + if let Some(tx) = self.shutdown.take() { + tx.send(()).expect("server quit early"); + } + if let Some(handle) = self.handle.take() { + handle + .await + .expect("task join error (panic?)") + .expect("server error at shutdown"); + } + } + } + + /// Integration test: a tonic [`Channel`] wrapped in an [`InterceptedService`] + /// that injects a custom header is passed to [`FlightClient`], and the server + /// observes the header on the request. + #[tokio::test] + async fn test_flight_client_with_intercepted_channel_passes_custom_header() { + let test_server = InterceptorTestServer::default(); + let fixture = + InterceptorTestFixture::new(FlightServiceServer::new(test_server.clone())).await; + + let channel = fixture.channel().await; + + let header_name = "x-random-header"; + let header_value = format!("random-{}", Uuid::new_v4()); + let header_value_for_interceptor = header_value.clone(); + + let interceptor = move |mut req: Request<()>| -> Result, Status> { + req.metadata_mut().insert( + header_name, + header_value_for_interceptor + .parse() + .expect("valid metadata value"), + ); + Ok(req) + }; + + let intercepted = InterceptedService::new(channel, interceptor); + let mut client = FlightClient::new(intercepted); + + let ticket = Ticket { + ticket: Bytes::from("dummy-ticket"), + }; + + let batch = RecordBatch::try_from_iter(vec![( + "col", + Arc::new(UInt64Array::from_iter([1, 2, 3, 4])) as _, + )]) + .unwrap(); + + test_server.set_do_get_response(vec![Ok(batch.clone())]); + + let response_stream = client + .do_get(ticket.clone()) + .await + .expect("error making do_get request"); + + let response: Vec = response_stream + .try_collect() + .await + .expect("error streaming data"); + + assert_eq!(response, vec![batch]); + assert_eq!(test_server.take_do_get_request(), Some(ticket)); + + let metadata = test_server + .take_last_request_metadata() + .expect("server received headers") + .into_headers(); + + let received = metadata + .get(header_name) + .expect("interceptor header missing on server") + .to_str() + .expect("ascii header value"); + assert_eq!(received, header_value); + + fixture.shutdown_and_wait().await; + } +}