Skip to content

Commit

Permalink
Enable user-defined context structs for async UDFs
Browse files Browse the repository at this point in the history
Define an async trait 'UdfContext' that users can implement in their UDF
definition. The implementing struct can store shared resources used by each
invocation of the UDF, like a database connection. An Arc pointer to the
context is passed as the first argument to the UDF.
  • Loading branch information
jbeisen committed Jan 13, 2024
1 parent 9085f50 commit 41cfe7a
Show file tree
Hide file tree
Showing 11 changed files with 176 additions and 65 deletions.
97 changes: 49 additions & 48 deletions Cargo.lock

Large diffs are not rendered by default.

52 changes: 47 additions & 5 deletions arroyo-datastream/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,7 @@ pub enum Operator {
return_nullable: bool,
timeout_seconds: u64,
max_concurrency: u64,
has_context: bool,
},
}

Expand Down Expand Up @@ -1647,7 +1648,8 @@ impl Program {
null_handlers,
return_nullable,
timeout_seconds,
max_concurrency
max_concurrency,
has_context,
} => {
let in_k = parse_type(&input.unwrap().weight().key);
let in_t = parse_type(&input.unwrap().weight().value);
Expand All @@ -1673,17 +1675,38 @@ impl Program {
quote! ()
};

let context_type = if *has_context {
quote! {
udfs::Context
}
} else {
quote! {
EmptyContext
}
};

let udf_call = if *has_context {
quote! {
udfs::#fn_name(context.clone(), #args)
}
} else {
quote! {
udfs::#fn_name(#args)
}
};

let udf_wrapper = quote!({
use tokio::time::error::Elapsed;
use tokio::time::{timeout, Duration};
async fn wrapper(index: usize, in_data: #in_t) -> (usize, Result<#out_t, Elapsed>) {
use std::sync::Arc;
async fn wrapper(index: usize, in_data: #in_t, context: Arc<#context_type>) -> (usize, Result<#out_t, Elapsed>) {
#defs

#null_output

#null_handlers

let udf_result = timeout(Duration::from_secs(#timeout_seconds), udfs::#fn_name(#args)).await;
let udf_result = timeout(Duration::from_secs(#timeout_seconds), #udf_call).await;

let out = udf_result.map(
|udf_result| #out_t {
Expand All @@ -1696,9 +1719,24 @@ impl Program {
wrapper
});

let context = if *has_context {
quote! {
udfs::Context::new()
}
} else {
quote! {
EmptyContext {}
}
};

quote! {
Box::new(AsyncMapOperator::<#in_k, #in_t, #out_t, _, _>::
new(#name.to_string(), #udf_wrapper, #ordered, #max_concurrency)
Box::new(AsyncMapOperator::<#in_k, #in_t, #out_t, _, _, #context_type>::new(
#name.to_string(),
#udf_wrapper,
#context,
#ordered,
#max_concurrency
)
)
}
}
Expand Down Expand Up @@ -2191,6 +2229,7 @@ impl From<Operator> for GrpcApi::operator::Operator {
return_nullable,
timeout_seconds,
max_concurrency,
has_context,
} => GrpcOperator::AsyncMapOperator(GrpcApi::AsyncMapOperator {
name,
ordered,
Expand All @@ -2203,6 +2242,7 @@ impl From<Operator> for GrpcApi::operator::Operator {
return_nullable,
timeout_seconds,
max_concurrency,
has_context,
}),
Operator::ArrayMapOperator {
name,
Expand Down Expand Up @@ -2524,6 +2564,7 @@ impl TryFrom<arroyo_rpc::grpc::api::Operator> for Operator {
return_nullable,
timeout_seconds,
max_concurrency,
has_context,
}) => Operator::AsyncMapOperator {
name,
ordered,
Expand All @@ -2536,6 +2577,7 @@ impl TryFrom<arroyo_rpc::grpc::api::Operator> for Operator {
return_nullable,
timeout_seconds,
max_concurrency,
has_context,
},
GrpcOperator::FlattenExpressionOperator(flatten_expression) => {
let return_type = flatten_expression.return_type().into();
Expand Down
1 change: 1 addition & 0 deletions arroyo-rpc/proto/api.proto
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,7 @@ message AsyncMapOperator {
bool return_nullable = 9;
uint64 timeout_seconds = 10;
uint64 max_concurrency = 11;
bool has_context = 12;
}

message SlidingWindowAggregator {
Expand Down
1 change: 1 addition & 0 deletions arroyo-rpc/src/api_types/mod.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use async_trait::async_trait;
use checkpoints::*;
use connections::*;
use metrics::*;
Expand Down
1 change: 1 addition & 0 deletions arroyo-rpc/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ pub mod public_ids;
pub mod schema_resolver;
pub mod var_str;

use async_trait::async_trait;
use std::collections::HashMap;
use std::{fs, time::SystemTime};

Expand Down
2 changes: 2 additions & 0 deletions arroyo-sql/src/expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1007,6 +1007,7 @@ impl<'a> ExpressionContext<'a> {
ret_type: def.ret.clone(),
async_fn: def.async_fn,
opts: def.opts.clone(),
has_context: def.has_context,
};

if def.async_fn {
Expand Down Expand Up @@ -3689,6 +3690,7 @@ pub struct RustUdfExpression {
pub args: Vec<(TypeDef, Expression)>,
pub ret_type: TypeDef,
async_fn: bool,
pub has_context: bool,
pub opts: UdfOpts,
}

Expand Down
27 changes: 25 additions & 2 deletions arroyo-sql/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ pub struct UdfDef {
dependencies: String,
opts: UdfOpts,
async_fn: bool,
has_context: bool,
}

#[derive(Clone, Debug)]
Expand Down Expand Up @@ -275,9 +276,30 @@ impl ArroyoSchemaProvider {
};

let name = function.sig.ident.to_string();
let async_fn = function.sig.asyncness.is_some();
let mut args: Vec<TypeDef> = vec![];
let mut vec_arguments = 0;
for (i, arg) in function.sig.inputs.iter().enumerate() {

let inputs = function.sig.inputs.iter();
let mut skip = 0;
let mut has_context = false;

if async_fn {
// skip the first argument if it is a context
if function.sig.inputs.len() >= 1 {
if let FnArg::Typed(t) = function.sig.inputs.first().unwrap() {
if let syn::Pat::Ident(i) = &*t.pat {
if i.ident == "context" {
// TODO: how to ensure type is Arc<Context>?
has_context = true;
skip = 1
}
}
}
}
}

for (i, arg) in inputs.skip(skip).enumerate() {
match arg {
FnArg::Receiver(_) => {
bail!(
Expand Down Expand Up @@ -369,10 +391,11 @@ impl ArroyoSchemaProvider {
UdfDef {
args,
ret,
async_fn: function.sig.asyncness.is_some(),
async_fn,
def: unparse(&file.clone()),
dependencies: parse_dependencies(&body)?,
opts: parse_udf_opts(&body)?,
has_context,
},
);

Expand Down
3 changes: 3 additions & 0 deletions arroyo-sql/src/pipeline.rs
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,7 @@ impl RecordTransform {
a.async_udf.opts.async_results_ordered,
a.async_udf.opts.async_timeout_seconds,
a.async_udf.opts.async_max_concurrency,
a.async_udf.has_context,
)
}
}
Expand Down Expand Up @@ -1396,6 +1397,7 @@ impl MethodCompiler {
ordered: bool,
timeout_seconds: u64,
max_concurrency: u64,
has_context: bool,
) -> Operator {
Operator::AsyncMapOperator {
name: name.to_string(),
Expand All @@ -1409,6 +1411,7 @@ impl MethodCompiler {
return_nullable,
timeout_seconds,
max_concurrency,
has_context,
}
}
}
1 change: 1 addition & 0 deletions arroyo-types/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ bincode = "2.0.0-rc.3"
serde = { version = "1.0", features = ["derive"] }
arrow = { workspace = true }
arrow-array = { workspace = true }
async-trait = "0.1.74"
7 changes: 7 additions & 0 deletions arroyo-types/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use arrow::datatypes::SchemaRef;
use arrow_array::RecordBatch;
use async_trait::async_trait;
use bincode::{config, Decode, Encode};
use serde::ser::SerializeStruct;
use serde::{Deserialize, Serialize};
Expand Down Expand Up @@ -932,3 +933,9 @@ mod tests {
);
}
}

/// Context object shared by the invocations of an async UDF.
///
/// Both hooks have empty default bodies, so an implementor only needs to
/// override the ones it uses. Implementors must be `Sync`.
#[async_trait]
pub trait UdfContext: Sync {
/// Initialization hook; the default implementation does nothing.
async fn init(&self) {}
/// Shutdown hook; the default implementation does nothing.
async fn close(&self) {}
}
49 changes: 39 additions & 10 deletions arroyo-worker/src/operators/async_map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@ use crate::engine::{Context, StreamNode};
use arroyo_macro::process_fn;
use arroyo_rpc::grpc::TableDescriptor;
use arroyo_types::Record;
use arroyo_types::{CheckpointBarrier, Data, Key};
use arroyo_types::{CheckpointBarrier, Data, Key, UdfContext};
use async_trait::async_trait;
use futures::stream::{FuturesOrdered, FuturesUnordered};
use futures::StreamExt;
use std::collections::VecDeque;
use std::future::Future;
use std::marker::PhantomData;
use std::sync::Arc;
use tokio::time::error::Elapsed;
use tracing::info;

Expand Down Expand Up @@ -58,18 +60,20 @@ pub struct AsyncMapOperator<
InT: Data,
OutT: Data,
FutureT: Future<Output = (usize, Result<OutT, Elapsed>)> + Send + 'static,
FnT: Fn(usize, InT) -> FutureT + Send + 'static,
FnT: Fn(usize, InT, Arc<ContextT>) -> FutureT + Send + 'static,
ContextT: UdfContext + Send + 'static,
> {
pub name: String,

pub udf: FnT,
pub futures: FuturesWrapper<FutureT>,
udf_context: Arc<ContextT>,
max_concurrency: u64,

next_id: usize, // i.e. inputs received so far, should start at 0
inputs: VecDeque<Option<Record<InKey, InT>>>,

_t: PhantomData<(InKey, InT, OutT)>,
_t: PhantomData<(InKey, InT, OutT, ContextT)>,
}

#[process_fn(in_k = InKey, in_t = InT, out_k = InKey, out_t = OutT, futures = "futures")]
Expand All @@ -78,10 +82,17 @@ impl<
InT: Data,
OutT: Data,
FutureT: Future<Output = (usize, Result<OutT, Elapsed>)> + Send + 'static,
FnT: Fn(usize, InT) -> FutureT + Send + 'static,
> AsyncMapOperator<InKey, InT, OutT, FutureT, FnT>
FnT: Fn(usize, InT, Arc<ContextT>) -> FutureT + Send + 'static,
ContextT: UdfContext + Send + 'static,
> AsyncMapOperator<InKey, InT, OutT, FutureT, FnT, ContextT>
{
pub fn new(name: String, udf: FnT, ordered: bool, max_concurrency: u64) -> Self {
pub fn new(
name: String,
udf: FnT,
context: ContextT,
ordered: bool,
max_concurrency: u64,
) -> Self {
let futures = if ordered {
info!("Using ordered futures");
FuturesWrapper {
Expand All @@ -98,6 +109,7 @@ impl<
name,
udf,
futures,
udf_context: Arc::new(context),
max_concurrency,
next_id: 0,
inputs: VecDeque::new(),
Expand All @@ -110,6 +122,8 @@ impl<
}

async fn on_start(&mut self, ctx: &mut Context<InKey, OutT>) {
self.udf_context.init().await;

let gs = ctx
.state
.get_global_keyed_state::<(usize, usize), Record<InKey, InT>>('a')
Expand All @@ -123,8 +137,11 @@ impl<
})
.for_each(|(_, v)| {
self.inputs.insert(self.next_id, Some(v.clone()));
self.futures
.push_back((self.udf)(self.next_id, v.value.clone()));
self.futures.push_back((self.udf)(
self.next_id,
v.value.clone(),
self.udf_context.clone(),
));
self.next_id += 1;
});
}
Expand All @@ -136,8 +153,11 @@ impl<
) {
self.inputs.push_back(Some(record.clone()));

self.futures
.push_back((self.udf)(self.next_id, record.value.clone()));
self.futures.push_back((self.udf)(
self.next_id,
record.value.clone(),
self.udf_context.clone(),
));
self.next_id += 1;
}

Expand Down Expand Up @@ -166,4 +186,13 @@ impl<
/// Returns the table descriptors for this operator: a single entry built by
/// `arroyo_state::global_table` with the arguments "a" and
/// "AsyncMapOperator state".
fn tables(&self) -> Vec<TableDescriptor> {
vec![arroyo_state::global_table("a", "AsyncMapOperator state")]
}

/// Close handler: awaits the `close` hook of the stored UDF context.
/// The operator context argument `_ctx` is not used.
async fn on_close(&mut self, _ctx: &mut Context<InKey, OutT>) {
self.udf_context.close().await;
}
}

/// A context type with no fields, for async UDFs that do not use a context.
pub struct EmptyContext {}

// Empty impl body: no `UdfContext` methods are overridden here.
#[async_trait]
impl UdfContext for EmptyContext {}

0 comments on commit 41cfe7a

Please sign in to comment.