Skip to content

Commit

Permalink
Enable user-defined context structs for async UDFs
Browse files Browse the repository at this point in the history
Define an async trait 'UdfContext' that users can implement in their UDF
definition. The struct can store shared resources used by each
invocation of the UDF, like a database connection. An Arc pointer to the
context is passed as the first argument to the UDF.
  • Loading branch information
jbeisen committed Jan 16, 2024
1 parent 9b813bf commit 7f2e18f
Show file tree
Hide file tree
Showing 11 changed files with 188 additions and 67 deletions.
97 changes: 49 additions & 48 deletions Cargo.lock

Large diffs are not rendered by default.

5 changes: 4 additions & 1 deletion arroyo-controller/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -686,16 +686,19 @@ impl ControllerServer {
}

fn cargo_toml(name: &str, dependencies: &str) -> String {
let arroyo_types = "arroyo-types = { path = \"../../../arroyo-types\" }";

format!(
r#"
[package]
name = "{}"
version = "1.0.0"
edition = "2021"
{}
{}
"#,
name, dependencies
name, dependencies, arroyo_types
)
}

Expand Down
17 changes: 15 additions & 2 deletions arroyo-datastream/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,7 @@ pub enum Operator {
ordered: bool,
function_def: String,
max_concurrency: u64,
has_context: bool,
},
}

Expand Down Expand Up @@ -1635,17 +1636,25 @@ impl Program {
ordered,
function_def,
max_concurrency,
has_context
} => {
let in_k = parse_type(&input.unwrap().weight().key);
let in_t = parse_type(&input.unwrap().weight().value);
let out_t = parse_type(&output.unwrap().weight().value);

let mut context_t = quote! { EmptyContext };
let mut context = quote! { EmptyContext {} };

if *has_context {
context_t = quote! { udfs::Context };
context = quote! { udfs::Context::new() };
}

let udf_wrapper : syn::Expr = parse_str(function_def).unwrap();

quote! {
Box::new(AsyncMapOperator::<#in_k, #in_t, #out_t, _, _>::
new(#name.to_string(), #udf_wrapper, #ordered, #max_concurrency)
Box::new(AsyncMapOperator::<#in_k, #in_t, #out_t, _, _, #context_t>::
new(#name.to_string(), #udf_wrapper, #context, #ordered, #max_concurrency)
)
}
}
Expand Down Expand Up @@ -2131,11 +2140,13 @@ impl From<Operator> for GrpcApi::operator::Operator {
ordered,
function_def,
max_concurrency,
has_context,
} => GrpcOperator::AsyncMapOperator(GrpcApi::AsyncMapOperator {
name,
ordered,
function_def,
max_concurrency,
has_context,
}),
Operator::ArrayMapOperator {
name,
Expand Down Expand Up @@ -2450,11 +2461,13 @@ impl TryFrom<arroyo_rpc::grpc::api::Operator> for Operator {
ordered,
function_def,
max_concurrency,
has_context,
}) => Operator::AsyncMapOperator {
name,
ordered,
function_def,
max_concurrency,
has_context,
},
GrpcOperator::FlattenExpressionOperator(flatten_expression) => {
let return_type = flatten_expression.return_type().into();
Expand Down
1 change: 1 addition & 0 deletions arroyo-rpc/proto/api.proto
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@ message AsyncMapOperator {
bool ordered = 2;
string function_def = 3;
uint64 max_concurrency = 4;
bool has_context = 5;
}

message SlidingWindowAggregator {
Expand Down
2 changes: 2 additions & 0 deletions arroyo-sql/src/expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1007,6 +1007,7 @@ impl<'a> ExpressionContext<'a> {
ret_type: def.ret.clone(),
async_fn: def.async_fn,
opts: def.opts.clone(),
has_context: def.has_context,
};

if def.async_fn {
Expand Down Expand Up @@ -3689,6 +3690,7 @@ pub struct RustUdfExpression {
pub args: Vec<(TypeDef, Expression)>,
pub ret_type: TypeDef,
async_fn: bool,
pub has_context: bool,
pub opts: UdfOpts,
}

Expand Down
54 changes: 51 additions & 3 deletions arroyo-sql/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ pub struct UdfDef {
dependencies: String,
opts: UdfOpts,
async_fn: bool,
has_context: bool,
}

#[derive(Clone, Debug)]
Expand Down Expand Up @@ -275,9 +276,30 @@ impl ArroyoSchemaProvider {
};

let name = function.sig.ident.to_string();
let async_fn = function.sig.asyncness.is_some();
let mut args: Vec<TypeDef> = vec![];
let mut vec_arguments = 0;
for (i, arg) in function.sig.inputs.iter().enumerate() {

let inputs = function.sig.inputs.iter();
let mut skip = 0;
let mut has_context = false;

if async_fn {
// skip the first argument if it is a context
if function.sig.inputs.len() >= 1 {
if let FnArg::Typed(t) = function.sig.inputs.first().unwrap() {
if let syn::Pat::Ident(i) = &*t.pat {
if i.ident == "context" {
// TODO: how to ensure type is Arc<Context>?
has_context = true;
skip = 1
}
}
}
}
}

for (i, arg) in inputs.skip(skip).enumerate() {
match arg {
FnArg::Receiver(_) => {
bail!(
Expand Down Expand Up @@ -369,10 +391,11 @@ impl ArroyoSchemaProvider {
UdfDef {
args,
ret,
async_fn: function.sig.asyncness.is_some(),
async_fn,
def: unparse(&file.clone()),
dependencies: parse_dependencies(&body)?,
opts: parse_udf_opts(&body)?,
has_context,
},
);

Expand Down Expand Up @@ -403,7 +426,7 @@ pub fn parse_dependencies(definition: &str) -> Result<String> {
get_toml_value(definition)?;

// get content of dependencies comment using regex
let re = Regex::new(r"\/\*\n[\s\S]*?(\[dependencies\]\n[\s\S]*?)\*\/").unwrap();
let re = Regex::new(r"(?m)\*\n[\s\S]*?(\[dependencies\]\n[\s\S]*?)(?:^$|\*/)").unwrap();

return if let Some(captures) = re.captures(&definition) {
if captures.len() != 2 {
Expand Down Expand Up @@ -824,6 +847,31 @@ mod tests {
serde = "1.0"
*/
pub fn my_udf() -> i64 {
1
}
"#;

assert_eq!(
parse_dependencies(definition).unwrap(),
r#"[dependencies]
serde = "1.0"
"#
);
}

#[test]
fn test_parse_dependencies_valid_with_udfs() {
let definition = r#"
/*
[dependencies]
serde = "1.0"
[udfs]
async_results_ordered = true
*/
pub fn my_udf() -> i64 {
1
}
Expand Down
19 changes: 16 additions & 3 deletions arroyo-sql/src/operators.rs
Original file line number Diff line number Diff line change
Expand Up @@ -207,8 +207,19 @@ impl CodeGenerator<ValuePointerContext, StructDef, syn::Expr> for AsyncUdfProjec
let (match_terms, ids): (Vec<_>, Vec<_>) = match_term_ids.into_iter().unzip();

let function_name = format_ident!("{}", self.async_udf.name);
let args_pattern = quote!((#(#ids),*));
let timeout_seconds = self.async_udf.opts.async_timeout_seconds;

let mut context_t = quote! { EmptyContext };
let mut context_arg = quote!();

if self.async_udf.has_context {
context_t = quote! { udfs::Context };
context_arg = quote! {context.clone(), };
}

let args_pattern = quote!((#(#ids),*));
let args = quote!((#context_arg #(#ids),*));

let invocation = if may_not_invoke {
// turn ids into a tuple
let match_terms = quote!((#(#match_terms),*));
Expand All @@ -220,22 +231,24 @@ impl CodeGenerator<ValuePointerContext, StructDef, syn::Expr> for AsyncUdfProjec
quote!(
match #args_pattern {
#match_terms => {
timeout(Duration::from_secs(#timeout_seconds), udfs:: #function_name #args_pattern).await #suffix
timeout(Duration::from_secs(#timeout_seconds), udfs:: #function_name #args).await #suffix
}
_ => {
Ok(None)
}
}
)
} else {
quote!(timeout(Duration::from_secs(#timeout_seconds), udfs:: #function_name #args_pattern).await)
quote!(timeout(Duration::from_secs(#timeout_seconds), udfs:: #function_name #args).await)
};
parse_quote! {{
use tokio::time::error::Elapsed;
use tokio::time::{timeout, Duration};
use std::sync::Arc;
async fn wrapper(
index: usize,
#input_name: #input_struct,
context: Arc<#context_t>
) -> (
usize,
Result<#output_type, Elapsed>,
Expand Down
3 changes: 3 additions & 0 deletions arroyo-sql/src/pipeline.rs
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ impl RecordTransform {
a.async_udf.opts.async_results_ordered,
function_def.to_token_stream().to_string(),
a.async_udf.opts.async_max_concurrency,
a.async_udf.has_context,
)
}
}
Expand Down Expand Up @@ -1315,12 +1316,14 @@ impl MethodCompiler {
ordered: bool,
function_def: String,
max_concurrency: u64,
has_context: bool,
) -> Operator {
Operator::AsyncMapOperator {
name: name.to_string(),
ordered,
function_def,
max_concurrency,
has_context,
}
}
}
1 change: 1 addition & 0 deletions arroyo-types/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@ bincode = "2.0.0-rc.3"
serde = { version = "1.0", features = ["derive"] }
arrow = { workspace = true }
arrow-array = { workspace = true }
async-trait = "0.1.74"
7 changes: 7 additions & 0 deletions arroyo-types/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use arrow::datatypes::SchemaRef;
use arrow_array::RecordBatch;
use async_trait::async_trait;
use bincode::{config, Decode, Encode};
use serde::ser::SerializeStruct;
use serde::{Deserialize, Serialize};
Expand Down Expand Up @@ -932,3 +933,9 @@ mod tests {
);
}
}

#[async_trait]
pub trait UdfContext: Sync {
async fn init(&self) {}
async fn close(&self) {}
}
Loading

0 comments on commit 7f2e18f

Please sign in to comment.