Skip to content

Commit

Permalink
Use custom Any instead of prost_types (#3360)
Browse files Browse the repository at this point in the history
* Use custom Any instead of prost_types

* Remove unnecesary path prefix
  • Loading branch information
tustvold committed Dec 19, 2022
1 parent 8cab7a2 commit 0f196b8
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 58 deletions.
5 changes: 2 additions & 3 deletions arrow-flight/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,13 @@ base64 = { version = "0.20", default-features = false, features = ["std"] }
tonic = { version = "0.8", default-features = false, features = ["transport", "codegen", "prost"] }
bytes = { version = "1", default-features = false }
prost = { version = "0.11", default-features = false }
prost-types = { version = "0.11.0", default-features = false, optional = true }
prost-derive = { version = "0.11", default-features = false }
tokio = { version = "1.0", default-features = false, features = ["macros", "rt", "rt-multi-thread"] }
futures = { version = "0.3", default-features = false, features = ["alloc"]}
futures = { version = "0.3", default-features = false, features = ["alloc"] }

[features]
default = []
flight-sql-experimental = ["prost-types"]
flight-sql-experimental = []

[dev-dependencies]
arrow = { version = "29.0.0", path = "../arrow", features = ["prettyprint"] }
Expand Down
11 changes: 6 additions & 5 deletions arrow-flight/examples/flight_sql_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,14 @@

use arrow_array::builder::StringBuilder;
use arrow_array::{ArrayRef, RecordBatch};
use arrow_flight::sql::{ActionCreatePreparedStatementResult, ProstMessageExt, SqlInfo};
use arrow_flight::sql::{
ActionCreatePreparedStatementResult, Any, ProstMessageExt, SqlInfo,
};
use arrow_flight::{
Action, FlightData, FlightEndpoint, HandshakeRequest, HandshakeResponse, IpcMessage,
Location, SchemaAsIpc, Ticket,
};
use futures::{stream, Stream};
use prost_types::Any;
use std::fs;
use std::pin::Pin;
use std::sync::Arc;
Expand Down Expand Up @@ -124,7 +125,7 @@ impl FlightSqlService for FlightSqlServiceImpl {
async fn do_get_fallback(
&self,
_request: Request<Ticket>,
_message: prost_types::Any,
_message: Any,
) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
let batch =
Self::fake_result().map_err(|e| status!("Could not fake a result", e))?;
Expand Down Expand Up @@ -474,9 +475,9 @@ impl ProstMessageExt for FetchResults {
}

fn as_any(&self) -> Any {
prost_types::Any {
Any {
type_url: FetchResults::type_url().to_string(),
value: ::prost::Message::encode_to_vec(self),
value: ::prost::Message::encode_to_vec(self).into(),
}
}
}
Expand Down
22 changes: 11 additions & 11 deletions arrow-flight/src/sql/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,12 @@ use crate::flight_service_client::FlightServiceClient;
use crate::sql::server::{CLOSE_PREPARED_STATEMENT, CREATE_PREPARED_STATEMENT};
use crate::sql::{
ActionClosePreparedStatementRequest, ActionCreatePreparedStatementRequest,
ActionCreatePreparedStatementResult, CommandGetCatalogs, CommandGetCrossReference,
CommandGetDbSchemas, CommandGetExportedKeys, CommandGetImportedKeys,
CommandGetPrimaryKeys, CommandGetSqlInfo, CommandGetTableTypes, CommandGetTables,
CommandPreparedStatementQuery, CommandStatementQuery, CommandStatementUpdate,
DoPutUpdateResult, ProstAnyExt, ProstMessageExt, SqlInfo,
ActionCreatePreparedStatementResult, Any, CommandGetCatalogs,
CommandGetCrossReference, CommandGetDbSchemas, CommandGetExportedKeys,
CommandGetImportedKeys, CommandGetPrimaryKeys, CommandGetSqlInfo,
CommandGetTableTypes, CommandGetTables, CommandPreparedStatementQuery,
CommandStatementQuery, CommandStatementUpdate, DoPutUpdateResult, ProstMessageExt,
SqlInfo,
};
use crate::{
Action, FlightData, FlightDescriptor, FlightInfo, HandshakeRequest,
Expand Down Expand Up @@ -177,8 +178,8 @@ impl FlightSqlServiceClient {
.await
.map_err(status_to_arrow_error)?
.unwrap();
let any: prost_types::Any = prost::Message::decode(&*result.app_metadata)
.map_err(decode_error_to_arrow_error)?;
let any =
Any::decode(&*result.app_metadata).map_err(decode_error_to_arrow_error)?;
let result: DoPutUpdateResult = any.unpack()?.unwrap();
Ok(result.record_count)
}
Expand Down Expand Up @@ -298,8 +299,7 @@ impl FlightSqlServiceClient {
.await
.map_err(status_to_arrow_error)?
.unwrap();
let any: prost_types::Any =
prost::Message::decode(&*result.body).map_err(decode_error_to_arrow_error)?;
let any = Any::decode(&*result.body).map_err(decode_error_to_arrow_error)?;
let prepared_result: ActionCreatePreparedStatementResult = any.unpack()?.unwrap();
let dataset_schema = match prepared_result.dataset_schema.len() {
0 => Schema::empty(),
Expand Down Expand Up @@ -384,8 +384,8 @@ impl PreparedStatement<Channel> {
.await
.map_err(status_to_arrow_error)?
.unwrap();
let any: prost_types::Any = Message::decode(&*result.app_metadata)
.map_err(decode_error_to_arrow_error)?;
let any =
Any::decode(&*result.app_metadata).map_err(decode_error_to_arrow_error)?;
let result: DoPutUpdateResult = any.unpack()?.unwrap();
Ok(result.record_count)
}
Expand Down
67 changes: 41 additions & 26 deletions arrow-flight/src/sql/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
// under the License.

use arrow_schema::ArrowError;
use bytes::Bytes;
use prost::Message;

mod gen {
Expand Down Expand Up @@ -66,8 +67,8 @@ pub trait ProstMessageExt: prost::Message + Default {
/// type_url for this Message
fn type_url() -> &'static str;

/// Convert this Message to prost_types::Any
fn as_any(&self) -> prost_types::Any;
/// Convert this Message to [`Any`]
fn as_any(&self) -> Any;
}

macro_rules! prost_message_ext {
Expand All @@ -78,10 +79,10 @@ macro_rules! prost_message_ext {
concat!("type.googleapis.com/arrow.flight.protocol.sql.", stringify!($name))
}

fn as_any(&self) -> prost_types::Any {
prost_types::Any {
fn as_any(&self) -> Any {
Any {
type_url: <$name>::type_url().to_string(),
value: self.encode_to_vec(),
value: self.encode_to_vec().into(),
}
}
}
Expand Down Expand Up @@ -111,30 +112,44 @@ prost_message_ext!(
TicketStatementQuery,
);

/// ProstAnyExt are useful utility methods for prost_types::Any
/// The API design is inspired by [rust-protobuf](https://github.com/stepancheg/rust-protobuf/blob/master/protobuf/src/well_known_types_util/any.rs)
pub trait ProstAnyExt {
/// Check if `Any` contains a message of given type.
fn is<M: ProstMessageExt>(&self) -> bool;

/// Extract a message from this `Any`.
///
/// # Returns
///
/// * `Ok(None)` when message type mismatch
/// * `Err` when parse failed
fn unpack<M: ProstMessageExt>(&self) -> Result<Option<M>, ArrowError>;

/// Pack any message into `prost_types::Any` value.
fn pack<M: ProstMessageExt>(message: &M) -> Result<prost_types::Any, ArrowError>;
/// An implementation of the protobuf [`Any`] message type
///
/// Encoded protobuf messages are not self-describing, nor contain any information
/// on the schema of the encoded payload. Consequently to decode a protobuf a client
/// must know the exact schema of the message.
///
/// This presents a problem for loosely typed APIs, where the exact message payloads
/// are not enumerable, and therefore cannot be enumerated as variants in a [oneof].
///
/// One solution is [`Any`] where the encoded payload is paired with a `type_url`
/// identifying the type of encoded message, and the resulting combination encoded.
///
/// Clients can then decode the outer [`Any`], inspect the `type_url` and if it is
/// a type they recognise, proceed to decode the embedded message `value`
///
/// [`Any`]: https://developers.google.com/protocol-buffers/docs/proto3#any
/// [oneof]: https://developers.google.com/protocol-buffers/docs/proto3#oneof
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct Any {
/// A URL/resource name that uniquely identifies the type of the serialized
/// protocol buffer message. This string must contain at least
/// one "/" character. The last segment of the URL's path must represent
/// the fully qualified name of the type (as in
/// `path/google.protobuf.Duration`). The name should be in a canonical form
/// (e.g., leading "." is not accepted).
#[prost(string, tag = "1")]
pub type_url: String,
/// Must be a valid serialized protocol buffer of the above specified type.
#[prost(bytes = "bytes", tag = "2")]
pub value: Bytes,
}

impl ProstAnyExt for prost_types::Any {
fn is<M: ProstMessageExt>(&self) -> bool {
impl Any {
pub fn is<M: ProstMessageExt>(&self) -> bool {
M::type_url() == self.type_url
}

fn unpack<M: ProstMessageExt>(&self) -> Result<Option<M>, ArrowError> {
pub fn unpack<M: ProstMessageExt>(&self) -> Result<Option<M>, ArrowError> {
if !self.is::<M>() {
return Ok(None);
}
Expand All @@ -144,7 +159,7 @@ impl ProstAnyExt for prost_types::Any {
Ok(Some(m))
}

fn pack<M: ProstMessageExt>(message: &M) -> Result<prost_types::Any, ArrowError> {
pub fn pack<M: ProstMessageExt>(message: &M) -> Result<Any, ArrowError> {
Ok(message.as_any())
}
}
Expand All @@ -170,7 +185,7 @@ mod tests {
let query = CommandStatementQuery {
query: "select 1".to_string(),
};
let any = prost_types::Any::pack(&query).unwrap();
let any = Any::pack(&query).unwrap();
assert!(any.is::<CommandStatementQuery>());
let unpack_query: CommandStatementQuery = any.unpack().unwrap().unwrap();
assert_eq!(query, unpack_query);
Expand Down
26 changes: 13 additions & 13 deletions arrow-flight/src/sql/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

use std::pin::Pin;

use crate::sql::Any;
use futures::Stream;
use prost::Message;
use tonic::{Request, Response, Status, Streaming};
Expand All @@ -32,7 +33,7 @@ use super::{
CommandGetDbSchemas, CommandGetExportedKeys, CommandGetImportedKeys,
CommandGetPrimaryKeys, CommandGetSqlInfo, CommandGetTableTypes, CommandGetTables,
CommandPreparedStatementQuery, CommandPreparedStatementUpdate, CommandStatementQuery,
CommandStatementUpdate, DoPutUpdateResult, ProstAnyExt, ProstMessageExt, SqlInfo,
CommandStatementUpdate, DoPutUpdateResult, ProstMessageExt, SqlInfo,
TicketStatementQuery,
};

Expand Down Expand Up @@ -63,7 +64,7 @@ pub trait FlightSqlService: Sync + Send + Sized + 'static {
async fn do_get_fallback(
&self,
_request: Request<Ticket>,
message: prost_types::Any,
message: Any,
) -> Result<Response<<Self as FlightService>::DoGetStream>, Status> {
Err(Status::unimplemented(format!(
"do_get: The defined request is invalid: {}",
Expand Down Expand Up @@ -311,8 +312,8 @@ where
&self,
request: Request<FlightDescriptor>,
) -> Result<Response<FlightInfo>, Status> {
let message: prost_types::Any =
Message::decode(&*request.get_ref().cmd).map_err(decode_error_to_status)?;
let message =
Any::decode(&*request.get_ref().cmd).map_err(decode_error_to_status)?;

if message.is::<CommandStatementQuery>() {
let token = message
Expand Down Expand Up @@ -411,10 +412,10 @@ where
&self,
request: Request<Ticket>,
) -> Result<Response<Self::DoGetStream>, Status> {
let msg: prost_types::Any = Message::decode(&*request.get_ref().ticket)
let msg: Any = Message::decode(&*request.get_ref().ticket)
.map_err(decode_error_to_status)?;

fn unpack<T: ProstMessageExt>(msg: prost_types::Any) -> Result<T, Status> {
fn unpack<T: ProstMessageExt>(msg: Any) -> Result<T, Status> {
msg.unpack()
.map_err(arrow_error_to_status)?
.ok_or_else(|| Status::internal("Expected a command, but found none."))
Expand Down Expand Up @@ -462,9 +463,8 @@ where
mut request: Request<Streaming<FlightData>>,
) -> Result<Response<Self::DoPutStream>, Status> {
let cmd = request.get_mut().message().await?.unwrap();
let message: prost_types::Any =
Message::decode(&*cmd.flight_descriptor.unwrap().cmd)
.map_err(decode_error_to_status)?;
let message = Any::decode(&*cmd.flight_descriptor.unwrap().cmd)
.map_err(decode_error_to_status)?;
if message.is::<CommandStatementUpdate>() {
let token = message
.unpack()
Expand Down Expand Up @@ -536,8 +536,8 @@ where
request: Request<Action>,
) -> Result<Response<Self::DoActionStream>, Status> {
if request.get_ref().r#type == CREATE_PREPARED_STATEMENT {
let any: prost_types::Any = Message::decode(&*request.get_ref().body)
.map_err(decode_error_to_status)?;
let any =
Any::decode(&*request.get_ref().body).map_err(decode_error_to_status)?;

let cmd: ActionCreatePreparedStatementRequest = any
.unpack()
Expand All @@ -556,8 +556,8 @@ where
return Ok(Response::new(Box::pin(output)));
}
if request.get_ref().r#type == CLOSE_PREPARED_STATEMENT {
let any: prost_types::Any = Message::decode(&*request.get_ref().body)
.map_err(decode_error_to_status)?;
let any =
Any::decode(&*request.get_ref().body).map_err(decode_error_to_status)?;

let cmd: ActionClosePreparedStatementRequest = any
.unpack()
Expand Down

0 comments on commit 0f196b8

Please sign in to comment.