Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement fallible streams for FlightClient::do_put #3464

Merged
merged 8 commits into from
Feb 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 44 additions & 14 deletions arrow-flight/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
// specific language governing permissions and limitations
// under the License.

use std::task::Poll;

use crate::{
decode::FlightRecordBatchStream, flight_service_client::FlightServiceClient, Action,
ActionType, Criteria, Empty, FlightData, FlightDescriptor, FlightInfo,
Expand All @@ -24,8 +26,9 @@ use arrow_schema::Schema;
use bytes::Bytes;
use futures::{
future::ready,
ready,
stream::{self, BoxStream},
Stream, StreamExt, TryStreamExt,
FutureExt, Stream, StreamExt, TryStreamExt,
};
use tonic::{metadata::MetadataMap, transport::Channel};

Expand Down Expand Up @@ -262,6 +265,15 @@ impl FlightClient {
/// [`Stream`](futures::Stream) of [`FlightData`] and returning a
/// stream of [`PutResult`].
///
/// # Note
///
/// The input stream is [`Result`] so that this can be connected
/// to a streaming data source, such as [`FlightDataEncoder`](crate::encode::FlightDataEncoder),
/// without having to buffer. If the input stream returns an error
/// that error will not be sent to the server, instead it will be
/// placed into the result stream and the server connection
/// terminated.
///
/// # Example:
/// ```no_run
/// # async fn run() {
Expand All @@ -279,9 +291,7 @@ impl FlightClient {
///
/// // encode the batch as a stream of `FlightData`
/// let flight_data_stream = FlightDataEncoderBuilder::new()
/// .build(futures::stream::iter(vec![Ok(batch)]))
/// // data encoder return Results, but do_put requires FlightData
/// .map(|batch|batch.unwrap());
/// .build(futures::stream::iter(vec![Ok(batch)]));
///
/// // send the stream and get the results as `PutResult`
/// let response: Vec<PutResult>= client
Expand All @@ -293,20 +303,40 @@ impl FlightClient {
/// .expect("error calling do_put");
/// # }
/// ```
pub async fn do_put<S: Stream<Item = FlightData> + Send + 'static>(
pub async fn do_put<S: Stream<Item = Result<FlightData>> + Send + 'static>(
&mut self,
request: S,
) -> Result<BoxStream<'static, Result<PutResult>>> {
let request = self.make_request(request);

let response = self
.inner
.do_put(request)
.await?
.into_inner()
.map_err(FlightError::Tonic);
let (sender, mut receiver) = futures::channel::oneshot::channel();

// Intercepts client errors and sends them to the oneshot channel above
let mut request = Box::pin(request); // Pin to heap
let mut sender = Some(sender); // Wrap into Option so can be taken
let request_stream = futures::stream::poll_fn(move |cx| {
Poll::Ready(match ready!(request.poll_next_unpin(cx)) {
Some(Ok(data)) => Some(data),
Some(Err(e)) => {
let _ = sender.take().unwrap().send(e);
None
}
None => None,
})
});

let request = self.make_request(request_stream);
let mut response_stream = self.inner.do_put(request).await?.into_inner();

// Forwards errors from the error oneshot with priority over responses from server
let error_stream = futures::stream::poll_fn(move |cx| {
if let Poll::Ready(Ok(err)) = receiver.poll_unpin(cx) {
return Poll::Ready(Some(Err(err)));
}
let next = ready!(response_stream.poll_next_unpin(cx));
Poll::Ready(next.map(|x| x.map_err(FlightError::Tonic)))
});

Ok(response.boxed())
// combine the response from the server and any error from the client
Ok(error_stream.boxed())
}

/// Make a `DoExchange` call to the server with the provided
Expand Down
99 changes: 92 additions & 7 deletions arrow-flight/tests/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -248,8 +248,10 @@ async fn test_do_put() {
test_server
.set_do_put_response(expected_response.clone().into_iter().map(Ok).collect());

let input_stream = futures::stream::iter(input_flight_data.clone()).map(Ok);

let response_stream = client
.do_put(futures::stream::iter(input_flight_data.clone()))
.do_put(input_stream)
.await
.expect("error making request");

Expand All @@ -266,15 +268,15 @@ async fn test_do_put() {
}

#[tokio::test]
async fn test_do_put_error() {
async fn test_do_put_error_server() {
do_test(|test_server, mut client| async move {
client.add_header("foo-header", "bar-header-value").unwrap();

let input_flight_data = test_flight_data().await;

let response = client
.do_put(futures::stream::iter(input_flight_data.clone()))
.await;
let input_stream = futures::stream::iter(input_flight_data.clone()).map(Ok);

let response = client.do_put(input_stream).await;
let response = match response {
Ok(_) => panic!("unexpected success"),
Err(e) => e,
Expand All @@ -290,7 +292,7 @@ async fn test_do_put_error() {
}

#[tokio::test]
async fn test_do_put_error_stream() {
async fn test_do_put_error_stream_server() {
do_test(|test_server, mut client| async move {
client.add_header("foo-header", "bar-header-value").unwrap();

Expand All @@ -307,8 +309,10 @@ async fn test_do_put_error_stream() {

test_server.set_do_put_response(response);

let input_stream = futures::stream::iter(input_flight_data.clone()).map(Ok);

let response_stream = client
.do_put(futures::stream::iter(input_flight_data.clone()))
.do_put(input_stream)
.await
.expect("error making request");

Expand All @@ -326,6 +330,87 @@ async fn test_do_put_error_stream() {
.await;
}

#[tokio::test]
async fn test_do_put_error_client() {
do_test(|test_server, mut client| async move {
client.add_header("foo-header", "bar-header-value").unwrap();

let e = Status::invalid_argument("bad arg: client");

// input stream to client sends good FlightData followed by an error
let input_flight_data = test_flight_data().await;
let input_stream = futures::stream::iter(input_flight_data.clone())
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added these two tests to show what is going on -- aka that the error provided to the client do_put call need to get to the returned stream even though the error did not come from the server

.map(Ok)
.chain(futures::stream::iter(vec![Err(FlightError::from(
e.clone(),
))]));

// server responds with one good message
let response = vec![Ok(PutResult {
app_metadata: Bytes::from("foo-metadata"),
})];
test_server.set_do_put_response(response);

let response_stream = client
.do_put(input_stream)
.await
.expect("error making request");

let response: Result<Vec<_>, _> = response_stream.try_collect().await;
let response = match response {
Ok(_) => panic!("unexpected success"),
Err(e) => e,
};

// expect to the error made from the client
expect_status(response, e);
// server still got the request messages until the client sent the error
assert_eq!(test_server.take_do_put_request(), Some(input_flight_data));
ensure_metadata(&client, &test_server);
})
.await;
}

#[tokio::test]
async fn test_do_put_error_client_and_server() {
do_test(|test_server, mut client| async move {
client.add_header("foo-header", "bar-header-value").unwrap();

let e_client = Status::invalid_argument("bad arg: client");
let e_server = Status::invalid_argument("bad arg: server");

// input stream to client sends good FlightData followed by an error
let input_flight_data = test_flight_data().await;
let input_stream = futures::stream::iter(input_flight_data.clone())
.map(Ok)
.chain(futures::stream::iter(vec![Err(FlightError::from(
e_client.clone(),
))]));

// server responds with an error (e.g. because it got truncated data)
let response = vec![Err(e_server)];
test_server.set_do_put_response(response);

let response_stream = client
.do_put(input_stream)
.await
.expect("error making request");

let response: Result<Vec<_>, _> = response_stream.try_collect().await;
let response = match response {
Ok(_) => panic!("unexpected success"),
Err(e) => e,
};

// expect to the error made from the client (not the server)
expect_status(response, e_client);
// server still got the request messages until the client sent the error
assert_eq!(test_server.take_do_put_request(), Some(input_flight_data));
ensure_metadata(&client, &test_server);
})
.await;
}

#[tokio::test]
async fn test_do_exchange() {
do_test(|test_server, mut client| async move {
Expand Down