Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Apache Arrow IPC streaming format support #39

Merged
merged 4 commits into from
Mar 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
9 changes: 8 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ path = "src/lib.rs"
[dependencies]
axum = { version = "0.7" }
bytes = "1"
futures-util = { version = "0.3", default-features = false, features = ["alloc"] }
http = "1"
http-body = "1"
mime = "0.3"
Expand All @@ -35,12 +34,14 @@ tokio-util = { version = "0.7" }
futures = "0.3"
csv = { version = "1.3", optional = true }
prost = { version= "0.12", optional = true }
arrow = { version = "51", features = ["ipc"], optional = true }

[features]
default = []
json = ["dep:serde", "dep:serde_json"]
csv = ["dep:csv", "dep:serde"]
protobuf = ["dep:prost"]
arrow = ["dep:arrow"]
text = []

[dev-dependencies]
Expand All @@ -53,6 +54,7 @@ tower-layer = "0.3"
tower-service = "0.3"
tokio = { version = "1", features = ["full"] }
prost = { version= "0.12", features = ["prost-derive"] }
arrow = { version = "51", features = ["ipc"] }

[[example]]
name = "json-example"
Expand All @@ -79,5 +81,10 @@ name = "json-array-complex-structure"
path = "examples/json-array-complex-structure.rs"
required-features = ["json"]

[[example]]
name = "arrow-example"
path = "examples/arrow-example.rs"
required-features = ["arrow"]

[build-dependencies]
cargo-husky = { version = "1.5", default-features = false, features = ["run-for-all", "prepush-hook", "run-cargo-fmt"] }
51 changes: 51 additions & 0 deletions examples/arrow-example.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
use arrow::array::*;
use arrow::datatypes::*;
use axum::response::IntoResponse;
use axum::routing::*;
use axum::Router;
use std::sync::Arc;

use futures::prelude::*;
use tokio::net::TcpListener;
use tokio_stream::StreamExt;

use axum_streams::*;

fn source_test_stream(schema: Arc<Schema>) -> impl Stream<Item = RecordBatch> {
// Simulating a stream with a plain vector and throttling to show how it works
stream::iter((0i64..10i64).map(move |idx| {
RecordBatch::try_new(
schema.clone(),
vec![
Arc::new(Int64Array::from(vec![idx, idx * 2, idx * 3])),
Arc::new(StringArray::from(vec!["New York", "London", "Gothenburg"])),
Arc::new(Float64Array::from(vec![40.7128, 51.5074, 57.7089])),
Arc::new(Float64Array::from(vec![-74.0060, -0.1278, 11.9746])),
],
)
.unwrap()
}))
.throttle(std::time::Duration::from_millis(50))
}

async fn test_text_stream() -> impl IntoResponse {
let schema = Arc::new(Schema::new(vec![
Field::new("id", DataType::Int64, false),
Field::new("city", DataType::Utf8, false),
Field::new("lat", DataType::Float64, false),
Field::new("lng", DataType::Float64, false),
]));
StreamBodyAs::arrow_ipc(schema.clone(), source_test_stream(schema.clone()))
}

#[tokio::main]
async fn main() {
// build our application with a route
let app = Router::new()
// `GET /` goes to `root`
.route("/arrow-stream", get(test_text_stream));

let listener = TcpListener::bind("127.0.0.1:8080").await.unwrap();

axum::serve(listener, app).await.unwrap();
}
162 changes: 162 additions & 0 deletions src/arrow_format.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
use crate::StreamingFormat;
use arrow::array::RecordBatch;
use arrow::datatypes::{Schema, SchemaRef};
use arrow::ipc::writer::{IpcWriteOptions, StreamWriter};
use bytes::{BufMut, BytesMut};
use futures::stream::BoxStream;
use futures::Stream;
use futures::StreamExt;
use http::HeaderMap;
use http_body::Frame;
use std::sync::Arc;

pub struct ArrowRecordBatchIpcStreamFormat {
schema: SchemaRef,
options: IpcWriteOptions,
}

impl ArrowRecordBatchIpcStreamFormat {
pub fn new(schema: Arc<Schema>) -> Self {
Self::with_options(schema, IpcWriteOptions::default())
}

pub fn with_options(schema: Arc<Schema>, options: IpcWriteOptions) -> Self {
Self {
schema: schema.clone(),
options: options.clone(),
}
}
}

impl StreamingFormat<RecordBatch> for ArrowRecordBatchIpcStreamFormat {
fn to_bytes_stream<'a, 'b>(
&'a self,
stream: BoxStream<'b, RecordBatch>,
) -> BoxStream<'b, Result<Frame<axum::body::Bytes>, axum::Error>> {
let schema = self.schema.clone();
let options = self.options.clone();

let stream_bytes: BoxStream<Result<Frame<axum::body::Bytes>, axum::Error>> = Box::pin({
stream.map(move |batch| {
let buf = BytesMut::new().writer();
let mut writer = StreamWriter::try_new_with_options(buf, &schema, options.clone())
.map_err(axum::Error::new)?;
writer.write(&batch).map_err(axum::Error::new)?;
writer.finish().map_err(axum::Error::new)?;
writer
.into_inner()
.map_err(axum::Error::new)
.map(|buf| buf.into_inner().freeze())
.map(axum::body::Bytes::from)
.map(Frame::data)
})
});

Box::pin(stream_bytes)
}

fn http_response_trailers(&self) -> Option<HeaderMap> {
let mut header_map = HeaderMap::new();
header_map.insert(
http::header::CONTENT_TYPE,
http::header::HeaderValue::from_static("application/vnd.apache.arrow.stream"),
);
Some(header_map)
}
}

impl<'a> crate::StreamBodyAs<'a> {
pub fn arrow_ipc<S>(schema: SchemaRef, stream: S) -> Self
where
S: Stream<Item = RecordBatch> + 'a + Send,
{
Self::new(ArrowRecordBatchIpcStreamFormat::new(schema), stream)
}

pub fn arrow_ipc_with_options<S>(schema: SchemaRef, stream: S, options: IpcWriteOptions) -> Self
where
S: Stream<Item = RecordBatch> + 'a + Send,
{
Self::new(
ArrowRecordBatchIpcStreamFormat::with_options(schema, options),
stream,
)
}
}

#[cfg(test)]
mod tests {
use super::*;
use crate::test_client::*;
use crate::StreamBodyAs;
use arrow::array::*;
use arrow::datatypes::*;
use axum::{routing::*, Router};
use futures::stream;
use std::sync::Arc;

#[tokio::test]
async fn serialize_arrow_stream_format() {
let schema = Arc::new(Schema::new(vec![
Field::new("id", DataType::Int64, false),
Field::new("city", DataType::Utf8, false),
Field::new("lat", DataType::Float64, false),
Field::new("lng", DataType::Float64, false),
]));

fn create_test_batch(schema_ref: SchemaRef) -> Vec<RecordBatch> {
let vec_schema = schema_ref.clone();
(0i64..10i64)
.map(move |idx| {
RecordBatch::try_new(
vec_schema.clone(),
vec![
Arc::new(Int64Array::from(vec![idx, idx * 2, idx * 3])),
Arc::new(StringArray::from(vec!["New York", "London", "Gothenburg"])),
Arc::new(Float64Array::from(vec![40.7128, 51.5074, 57.7089])),
Arc::new(Float64Array::from(vec![-74.0060, -0.1278, 11.9746])),
],
)
.unwrap()
})
.collect()
}

let test_stream = Box::pin(stream::iter(create_test_batch(schema.clone())));

let app_schema = schema.clone();

let app = Router::new().route(
"/",
get(|| async move {
StreamBodyAs::new(
ArrowRecordBatchIpcStreamFormat::new(app_schema.clone()),
test_stream,
)
}),
);

let client = TestClient::new(app).await;

let expected_buf: Vec<u8> = create_test_batch(schema.clone())
.iter()
.flat_map(|batch| {
let mut writer = StreamWriter::try_new(Vec::new(), &schema).expect("writer failed");
writer.write(&batch).expect("write failed");
writer.finish().expect("writer failed");
writer.into_inner().expect("writer failed")
})
.collect();

let res = client.get("/").send().await.unwrap();
assert_eq!(
res.headers()
.get("content-type")
.and_then(|h| h.to_str().ok()),
Some("application/vnd.apache.arrow.stream")
);
let body = res.bytes().await.unwrap().to_vec();

assert_eq!(body, expected_buf);
}
}
6 changes: 3 additions & 3 deletions src/csv_format.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::stream_format::StreamingFormat;
use futures::stream::BoxStream;
use futures::Stream;
use futures_util::stream::BoxStream;
use futures_util::StreamExt;
use futures::StreamExt;
use http::HeaderMap;
use http_body::Frame;
use serde::Serialize;
Expand Down Expand Up @@ -159,7 +159,7 @@ mod tests {
use crate::test_client::*;
use crate::StreamBodyAs;
use axum::{routing::*, Router};
use futures_util::stream;
use futures::stream;
use std::ops::Add;

#[tokio::test]
Expand Down
10 changes: 5 additions & 5 deletions src/json_formats.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use crate::stream_format::StreamingFormat;
use crate::StreamFormatEnvelope;
use bytes::{BufMut, BytesMut};
use futures::stream::BoxStream;
use futures::Stream;
use futures_util::stream::BoxStream;
use futures_util::StreamExt;
use futures::StreamExt;
use http::HeaderMap;
use http_body::Frame;
use serde::Serialize;
Expand Down Expand Up @@ -66,7 +66,7 @@ where
});

let prepend_stream: BoxStream<Result<Frame<axum::body::Bytes>, axum::Error>> =
Box::pin(futures_util::stream::once(futures_util::future::ready({
Box::pin(futures::stream::once(futures::future::ready({
if let Some(envelope) = &self.envelope {
match serde_json::to_vec(&envelope.object) {
Ok(envelope_bytes) if envelope_bytes.len() > 1 => {
Expand Down Expand Up @@ -108,7 +108,7 @@ where
})));

let append_stream: BoxStream<Result<Frame<axum::body::Bytes>, axum::Error>> =
Box::pin(futures_util::stream::once(futures_util::future::ready({
Box::pin(futures::stream::once(futures::future::ready({
if self.envelope.is_some() {
Ok::<_, axum::Error>(Frame::data(axum::body::Bytes::from(
JSON_ARRAY_ENVELOP_END_BYTES,
Expand Down Expand Up @@ -216,7 +216,7 @@ mod tests {
use crate::test_client::*;
use crate::StreamBodyAs;
use axum::{routing::*, Router};
use futures_util::stream;
use futures::stream;

#[tokio::test]
async fn serialize_json_array_stream_format() {
Expand Down
5 changes: 5 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -100,5 +100,10 @@ mod protobuf_format;
#[cfg(feature = "protobuf")]
pub use protobuf_format::ProtobufStreamFormat;

#[cfg(feature = "arrow")]
mod arrow_format;
#[cfg(feature = "arrow")]
pub use arrow_format::ArrowRecordBatchIpcStreamFormat;

#[cfg(test)]
mod test_client;
6 changes: 3 additions & 3 deletions src/protobuf_format.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::stream_format::StreamingFormat;
use futures::stream::BoxStream;
use futures::Stream;
use futures_util::stream::BoxStream;
use futures_util::StreamExt;
use futures::StreamExt;
use http::HeaderMap;
use http_body::Frame;

Expand Down Expand Up @@ -72,7 +72,7 @@ mod tests {
use crate::test_client::*;
use crate::StreamBodyAs;
use axum::{routing::*, Router};
use futures_util::stream;
use futures::stream;
use prost::Message;

#[tokio::test]
Expand Down