From db0c9fa03968be84692256daed9f2b77e4e04a68 Mon Sep 17 00:00:00 2001 From: Rostislav Rumenov Date: Tue, 5 May 2026 15:01:07 +0000 Subject: [PATCH 1/3] generic channel --- arrow-flight/src/client.rs | 29 ++++++++++++++++++----------- 1 file changed, 18 insertions(+), 11 deletions(-) diff --git a/arrow-flight/src/client.rs b/arrow-flight/src/client.rs index dac086271cb7..660514acf14b 100644 --- a/arrow-flight/src/client.rs +++ b/arrow-flight/src/client.rs @@ -31,6 +31,7 @@ use futures::{ stream::{self, BoxStream}, }; use prost::Message; +use tonic::codegen::{Body, StdError}; use tonic::{metadata::MetadataMap, transport::Channel}; use crate::error::{FlightError, Result}; @@ -67,22 +68,28 @@ use crate::streams::{FallibleRequestStream, FallibleTonicResponseStream}; /// # } /// ``` #[derive(Debug)] -pub struct FlightClient { +pub struct FlightClient { /// Optional grpc header metadata to include with each request metadata: MetadataMap, /// The inner client - inner: FlightServiceClient, + inner: FlightServiceClient, } -impl FlightClient { - /// Creates a client client with the provided [`Channel`] - pub fn new(channel: Channel) -> Self { - Self::new_from_inner(FlightServiceClient::new(channel)) +impl FlightClient +where + T: tonic::client::GrpcService, + T::Error: Into, + T::ResponseBody: Body + std::marker::Send + 'static, + ::Error: Into + std::marker::Send, +{ + /// Creates a client with the provided transport + pub fn new(inner: T) -> Self { + Self::new_from_inner(FlightServiceClient::new(inner)) } /// Creates a new higher level client with the provided lower level client - pub fn new_from_inner(inner: FlightServiceClient) -> Self { + pub fn new_from_inner(inner: FlightServiceClient) -> Self { Self { metadata: MetadataMap::new(), inner, @@ -120,19 +127,19 @@ impl FlightClient { /// Return a reference to the underlying tonic /// [`FlightServiceClient`] - pub fn inner(&self) -> &FlightServiceClient { + pub fn inner(&self) -> &FlightServiceClient { &self.inner } /// Return a mutable reference to the underlying tonic /// [`FlightServiceClient`] - pub fn inner_mut(&mut self) -> &mut FlightServiceClient { + pub fn inner_mut(&mut self) -> &mut FlightServiceClient { &mut self.inner } /// Consume this client and return the underlying tonic /// [`FlightServiceClient`] - pub fn into_inner(self) -> FlightServiceClient { + pub fn into_inner(self) -> FlightServiceClient { self.inner } @@ -664,7 +671,7 @@ impl FlightClient { } /// return a Request, adding any configured metadata - fn make_request(&self, t: T) -> tonic::Request { + fn make_request(&self, t: R) -> tonic::Request { // Pass along metadata let mut request = tonic::Request::new(t); *request.metadata_mut() = self.metadata.clone(); From 718f1e05c4436df484608b13ecf42b138a3c6a00 Mon Sep 17 00:00:00 2001 From: Rostislav Rumenov Date: Wed, 6 May 2026 14:26:10 +0000 Subject: [PATCH 2/3] . --- arrow-flight/tests/client_interceptor.rs | 100 +++++++++++++++++++++++ 1 file changed, 100 insertions(+) create mode 100644 arrow-flight/tests/client_interceptor.rs diff --git a/arrow-flight/tests/client_interceptor.rs b/arrow-flight/tests/client_interceptor.rs new file mode 100644 index 000000000000..de39e9471808 --- /dev/null +++ b/arrow-flight/tests/client_interceptor.rs @@ -0,0 +1,100 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Integration test: a tonic [`Channel`] wrapped in an [`InterceptedService`] +//! that injects a custom header is passed to [`FlightClient`], and the server +//! observes the header on the request. + +mod common; + +use std::sync::Arc; + +use arrow_array::{RecordBatch, UInt64Array}; +use arrow_flight::{FlightClient, Ticket}; +use bytes::Bytes; +use futures::TryStreamExt; +use tonic::Request; +use tonic::Status; +use tonic::service::interceptor::InterceptedService; +use uuid::Uuid; + +use crate::common::fixture::TestFixture; +use crate::common::server::TestFlightServer; + +#[tokio::test] +async fn test_flight_client_with_intercepted_channel_passes_custom_header() { + let test_server = TestFlightServer::new(); + let fixture = TestFixture::new(test_server.service()).await; + + let channel = fixture.channel().await; + + let header_name = "x-random-header"; + let header_value = format!("random-{}", Uuid::new_v4()); + let header_value_for_interceptor = header_value.clone(); + + let interceptor = move |mut req: Request<()>| -> Result, Status> { + req.metadata_mut().insert( + header_name, + header_value_for_interceptor + .parse() + .expect("valid metadata value"), + ); + Ok(req) + }; + + let intercepted = InterceptedService::new(channel, interceptor); + let mut client = FlightClient::new(intercepted); + + let ticket = Ticket { + ticket: Bytes::from("dummy-ticket"), + }; + + let batch = RecordBatch::try_from_iter(vec![( + "col", + Arc::new(UInt64Array::from_iter([1, 2, 3, 4])) as _, + )]) + .unwrap(); + + test_server.set_do_get_response(vec![Ok(batch.clone())]); + + let response_stream = client + .do_get(ticket.clone()) + .await + .expect("error making do_get request"); + + let response: Vec = response_stream + .try_collect() + .await + .expect("error streaming data"); + + assert_eq!(response, vec![batch]); + assert_eq!(test_server.take_do_get_request(), Some(ticket)); + + let metadata = test_server + .take_last_request_metadata() + .expect("server received headers") + .into_headers(); + + let received = metadata + .get(header_name) + .expect("interceptor header missing on server") + .to_str() + .expect("ascii header value"); + assert_eq!(received, header_value); + + fixture.shutdown_and_wait().await +} From 8d066dcba81c4f5a4d27da68abbf335db199e74b Mon Sep 17 00:00:00 2001 From: Rostislav Rumenov Date: Thu, 7 May 2026 14:59:26 +0000 Subject: [PATCH 3/3] . --- arrow-flight/src/client.rs | 262 +++++++++++++++++++++++ arrow-flight/tests/client_interceptor.rs | 100 --------- 2 files changed, 262 insertions(+), 100 deletions(-) delete mode 100644 arrow-flight/tests/client_interceptor.rs diff --git a/arrow-flight/src/client.rs b/arrow-flight/src/client.rs index 660514acf14b..b2059a81d0df 100644 --- a/arrow-flight/src/client.rs +++ b/arrow-flight/src/client.rs @@ -678,3 +678,265 @@ where request } } + +#[cfg(test)] +mod tests { + use super::FlightClient; + use crate::encode::FlightDataEncoderBuilder; + use crate::flight_service_server::{FlightService, FlightServiceServer}; + use crate::{ + Action, ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo, + HandshakeRequest, HandshakeResponse, PollInfo, PutResult, SchemaResult, Ticket, + }; + use arrow_array::{RecordBatch, UInt64Array}; + use bytes::Bytes; + use futures::{StreamExt, TryStreamExt, stream::BoxStream}; + use std::net::SocketAddr; + use std::sync::{Arc, Mutex}; + use std::time::Duration; + use tokio::net::TcpListener; + use tokio::task::JoinHandle; + use tonic::metadata::MetadataMap; + use tonic::service::interceptor::InterceptedService; + use tonic::transport::Channel; + use tonic::{Request, Response, Status, Streaming}; + use uuid::Uuid; + + /// Minimal `FlightService` that records request metadata and serves a + /// configured `do_get` response. Other RPCs return `Unimplemented`. + #[derive(Debug, Clone, Default)] + struct InterceptorTestServer { + state: Arc>, + } + + #[derive(Debug, Default)] + struct InterceptorTestState { + do_get_request: Option, + do_get_response: Option>>, + last_request_metadata: Option, + } + + impl InterceptorTestServer { + fn save_metadata(&self, request: &Request) { + self.state.lock().unwrap().last_request_metadata = Some(request.metadata().clone()); + } + + fn set_do_get_response(&self, response: Vec>) { + self.state.lock().unwrap().do_get_response = Some(response); + } + + fn take_do_get_request(&self) -> Option { + self.state.lock().unwrap().do_get_request.take() + } + + fn take_last_request_metadata(&self) -> Option { + self.state.lock().unwrap().last_request_metadata.take() + } + } + + #[tonic::async_trait] + impl FlightService for InterceptorTestServer { + type HandshakeStream = BoxStream<'static, Result>; + type ListFlightsStream = BoxStream<'static, Result>; + type DoGetStream = BoxStream<'static, Result>; + type DoPutStream = BoxStream<'static, Result>; + type DoActionStream = BoxStream<'static, Result>; + type ListActionsStream = BoxStream<'static, Result>; + type DoExchangeStream = BoxStream<'static, Result>; + + async fn do_get( + &self, + request: Request, + ) -> Result, Status> { + self.save_metadata(&request); + let mut state = self.state.lock().unwrap(); + state.do_get_request = Some(request.into_inner()); + + let batches = state + .do_get_response + .take() + .ok_or_else(|| Status::internal("no do_get response configured"))?; + let batch_stream = futures::stream::iter(batches).map_err(Into::into); + let stream = FlightDataEncoderBuilder::new() + .build(batch_stream) + .map_err(Into::into); + Ok(Response::new(stream.boxed())) + } + + async fn handshake( + &self, + _: Request>, + ) -> Result, Status> { + Err(Status::unimplemented("")) + } + async fn list_flights( + &self, + _: Request, + ) -> Result, Status> { + Err(Status::unimplemented("")) + } + async fn get_flight_info( + &self, + _: Request, + ) -> Result, Status> { + Err(Status::unimplemented("")) + } + async fn poll_flight_info( + &self, + _: Request, + ) -> Result, Status> { + Err(Status::unimplemented("")) + } + async fn get_schema( + &self, + _: Request, + ) -> Result, Status> { + Err(Status::unimplemented("")) + } + async fn do_put( + &self, + _: Request>, + ) -> Result, Status> { + Err(Status::unimplemented("")) + } + async fn do_action( + &self, + _: Request, + ) -> Result, Status> { + Err(Status::unimplemented("")) + } + async fn list_actions( + &self, + _: Request, + ) -> Result, Status> { + Err(Status::unimplemented("")) + } + async fn do_exchange( + &self, + _: Request>, + ) -> Result, Status> { + Err(Status::unimplemented("")) + } + } + + /// Spawns the test server on a background task and exposes a connected channel. + struct InterceptorTestFixture { + shutdown: Option>, + addr: SocketAddr, + handle: Option>>, + } + + impl InterceptorTestFixture { + async fn new(server: FlightServiceServer) -> Self { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let (tx, rx) = tokio::sync::oneshot::channel(); + let shutdown_future = async move { + rx.await.ok(); + }; + let serve = tonic::transport::Server::builder() + .timeout(Duration::from_secs(30)) + .add_service(server) + .serve_with_incoming_shutdown( + tokio_stream::wrappers::TcpListenerStream::new(listener), + shutdown_future, + ); + let handle = tokio::task::spawn(serve); + Self { + shutdown: Some(tx), + addr, + handle: Some(handle), + } + } + + async fn channel(&self) -> Channel { + let url = format!("http://{}", self.addr); + tonic::transport::Endpoint::from_shared(url) + .expect("valid endpoint") + .timeout(Duration::from_secs(30)) + .connect() + .await + .expect("error connecting to server") + } + + async fn shutdown_and_wait(mut self) { + if let Some(tx) = self.shutdown.take() { + tx.send(()).expect("server quit early"); + } + if let Some(handle) = self.handle.take() { + handle + .await + .expect("task join error (panic?)") + .expect("server error at shutdown"); + } + } + } + + /// Integration test: a tonic [`Channel`] wrapped in an [`InterceptedService`] + /// that injects a custom header is passed to [`FlightClient`], and the server + /// observes the header on the request. + #[tokio::test] + async fn test_flight_client_with_intercepted_channel_passes_custom_header() { + let test_server = InterceptorTestServer::default(); + let fixture = + InterceptorTestFixture::new(FlightServiceServer::new(test_server.clone())).await; + + let channel = fixture.channel().await; + + let header_name = "x-random-header"; + let header_value = format!("random-{}", Uuid::new_v4()); + let header_value_for_interceptor = header_value.clone(); + + let interceptor = move |mut req: Request<()>| -> Result, Status> { + req.metadata_mut().insert( + header_name, + header_value_for_interceptor + .parse() + .expect("valid metadata value"), + ); + Ok(req) + }; + + let intercepted = InterceptedService::new(channel, interceptor); + let mut client = FlightClient::new(intercepted); + + let ticket = Ticket { + ticket: Bytes::from("dummy-ticket"), + }; + + let batch = RecordBatch::try_from_iter(vec![( + "col", + Arc::new(UInt64Array::from_iter([1, 2, 3, 4])) as _, + )]) + .unwrap(); + + test_server.set_do_get_response(vec![Ok(batch.clone())]); + + let response_stream = client + .do_get(ticket.clone()) + .await + .expect("error making do_get request"); + + let response: Vec = response_stream + .try_collect() + .await + .expect("error streaming data"); + + assert_eq!(response, vec![batch]); + assert_eq!(test_server.take_do_get_request(), Some(ticket)); + + let metadata = test_server + .take_last_request_metadata() + .expect("server received headers") + .into_headers(); + + let received = metadata + .get(header_name) + .expect("interceptor header missing on server") + .to_str() + .expect("ascii header value"); + assert_eq!(received, header_value); + + fixture.shutdown_and_wait().await; + } +} diff --git a/arrow-flight/tests/client_interceptor.rs b/arrow-flight/tests/client_interceptor.rs deleted file mode 100644 index de39e9471808..000000000000 --- a/arrow-flight/tests/client_interceptor.rs +++ /dev/null @@ -1,100 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Integration test: a tonic [`Channel`] wrapped in an [`InterceptedService`] -//! that injects a custom header is passed to [`FlightClient`], and the server -//! observes the header on the request. - -mod common; - -use std::sync::Arc; - -use arrow_array::{RecordBatch, UInt64Array}; -use arrow_flight::{FlightClient, Ticket}; -use bytes::Bytes; -use futures::TryStreamExt; -use tonic::Request; -use tonic::Status; -use tonic::service::interceptor::InterceptedService; -use uuid::Uuid; - -use crate::common::fixture::TestFixture; -use crate::common::server::TestFlightServer; - -#[tokio::test] -async fn test_flight_client_with_intercepted_channel_passes_custom_header() { - let test_server = TestFlightServer::new(); - let fixture = TestFixture::new(test_server.service()).await; - - let channel = fixture.channel().await; - - let header_name = "x-random-header"; - let header_value = format!("random-{}", Uuid::new_v4()); - let header_value_for_interceptor = header_value.clone(); - - let interceptor = move |mut req: Request<()>| -> Result, Status> { - req.metadata_mut().insert( - header_name, - header_value_for_interceptor - .parse() - .expect("valid metadata value"), - ); - Ok(req) - }; - - let intercepted = InterceptedService::new(channel, interceptor); - let mut client = FlightClient::new(intercepted); - - let ticket = Ticket { - ticket: Bytes::from("dummy-ticket"), - }; - - let batch = RecordBatch::try_from_iter(vec![( - "col", - Arc::new(UInt64Array::from_iter([1, 2, 3, 4])) as _, - )]) - .unwrap(); - - test_server.set_do_get_response(vec![Ok(batch.clone())]); - - let response_stream = client - .do_get(ticket.clone()) - .await - .expect("error making do_get request"); - - let response: Vec = response_stream - .try_collect() - .await - .expect("error streaming data"); - - assert_eq!(response, vec![batch]); - assert_eq!(test_server.take_do_get_request(), Some(ticket)); - - let metadata = test_server - .take_last_request_metadata() - .expect("server received headers") - .into_headers(); - - let received = metadata - .get(header_name) - .expect("interceptor header missing on server") - .to_str() - .expect("ascii header value"); - assert_eq!(received, header_value); - - fixture.shutdown_and_wait().await -}