Skip to content

Commit

Permalink
Implement Aliases for ScalarUDF (#8360)
Browse files Browse the repository at this point in the history
* Implement Aliases for ScalarUDF

Signed-off-by: veeupup <code@tanweime.com>

* fix comments

Signed-off-by: veeupup <code@tanweime.com>

---------

Signed-off-by: veeupup <code@tanweime.com>
  • Loading branch information
Veeupup committed Nov 30, 2023
1 parent a49740f commit d45cf00
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 2 deletions.
11 changes: 9 additions & 2 deletions datafusion/core/src/execution/context/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -810,9 +810,16 @@ impl SessionContext {
///
/// - `SELECT MY_FUNC(x)...` will look for a function named `"my_func"`
/// - `SELECT "my_FUNC"(x)` will look for a function named `"my_FUNC"`
/// Any functions registered with the udf name or its aliases will be overwritten with this new function
pub fn register_udf(&self, f: ScalarUDF) {
self.state
.write()
let mut state = self.state.write();
let aliases = f.aliases();
for alias in aliases {
state
.scalar_functions
.insert(alias.to_string(), Arc::new(f.clone()));
}
state
.scalar_functions
.insert(f.name().to_string(), Arc::new(f));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,43 @@ async fn case_sensitive_identifiers_user_defined_functions() -> Result<()> {
Ok(())
}

#[tokio::test]
async fn test_user_defined_functions_with_alias() -> Result<()> {
let ctx = SessionContext::new();
let arr = Int32Array::from(vec![1]);
let batch = RecordBatch::try_from_iter(vec![("i", Arc::new(arr) as _)])?;
ctx.register_batch("t", batch).unwrap();

let myfunc = |args: &[ArrayRef]| Ok(Arc::clone(&args[0]));
let myfunc = make_scalar_function(myfunc);

let udf = create_udf(
"dummy",
vec![DataType::Int32],
Arc::new(DataType::Int32),
Volatility::Immutable,
myfunc,
)
.with_aliases(vec!["dummy_alias"]);

ctx.register_udf(udf);

let expected = [
"+------------+",
"| dummy(t.i) |",
"+------------+",
"| 1 |",
"+------------+",
];
let result = plan_and_collect(&ctx, "SELECT dummy(i) FROM t").await?;
assert_batches_eq!(expected, &result);

let alias_result = plan_and_collect(&ctx, "SELECT dummy_alias(i) FROM t").await?;
assert_batches_eq!(expected, &alias_result);

Ok(())
}

fn create_udf_context() -> SessionContext {
let ctx = SessionContext::new();
// register a custom UDF
Expand Down
18 changes: 18 additions & 0 deletions datafusion/expr/src/udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ pub struct ScalarUDF {
/// the batch's row count (so that the generative zero-argument function can know
/// the result array size).
fun: ScalarFunctionImplementation,
/// Optional aliases for the function. This list should NOT include the value of `name` as well
aliases: Vec<String>,
}

impl Debug for ScalarUDF {
Expand Down Expand Up @@ -89,9 +91,20 @@ impl ScalarUDF {
signature: signature.clone(),
return_type: return_type.clone(),
fun: fun.clone(),
aliases: vec![],
}
}

/// Adds additional names that can be used to invoke this function, in addition to `name`
pub fn with_aliases(
mut self,
aliases: impl IntoIterator<Item = &'static str>,
) -> Self {
self.aliases
.extend(aliases.into_iter().map(|s| s.to_string()));
self
}

/// creates a logical expression with a call of the UDF
/// This utility allows using the UDF without requiring access to the registry.
pub fn call(&self, args: Vec<Expr>) -> Expr {
Expand All @@ -106,6 +119,11 @@ impl ScalarUDF {
&self.name
}

/// Returns the aliases for this function. See [`ScalarUDF::with_aliases`] for more details
pub fn aliases(&self) -> &[String] {
&self.aliases
}

/// Returns this function's signature (what input types are accepted)
pub fn signature(&self) -> &Signature {
&self.signature
Expand Down

0 comments on commit d45cf00

Please sign in to comment.