Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions datafusion/core/tests/sql/udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,39 @@ async fn scalar_udf_zero_params() -> Result<()> {
Ok(())
}

#[tokio::test]
async fn scalar_udf_override_built_in_scalar_function() -> Result<()> {
let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);

let batch = RecordBatch::try_new(
Arc::new(schema.clone()),
vec![Arc::new(Int32Array::from(vec![-100]))],
)?;
let ctx = SessionContext::new();

ctx.register_batch("t", batch)?;
// register a UDF that has the same name as a builtin function (abs) and just returns 1 regardless of input
ctx.register_udf(create_udf(
"abs",
vec![DataType::Int32],
Arc::new(DataType::Int32),
Volatility::Immutable,
Arc::new(move |_| Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(1))))),
));

// Make sure that the UDF is used instead of the built-in function
let result = plan_and_collect(&ctx, "select abs(a) a from t").await?;
let expected = vec![
"+---+", //
"| a |", //
"+---+", //
"| 1 |", //
"+---+", //
];
assert_batches_eq!(expected, &result);
Ok(())
}

/// tests the creation, registration and usage of a UDAF
#[tokio::test]
async fn simple_udaf() -> Result<()> {
Expand Down
16 changes: 8 additions & 8 deletions datafusion/sql/src/expr/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,13 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
crate::utils::normalize_ident(function.name.0[0].clone())
};

// user-defined function (UDF) should have precedence in case it has the same name as a scalar built-in function
if let Some(fm) = self.schema_provider.get_function_meta(&name) {
let args =
self.function_args_to_expr(function.args, schema, planner_context)?;
return Ok(Expr::ScalarUDF(ScalarUDF::new(fm, args)));
}

// next, scalar built-in
if let Ok(fun) = BuiltinScalarFunction::from_str(&name) {
let args =
Expand Down Expand Up @@ -139,14 +146,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
)));
};

// finally, user-defined functions (UDF) and UDAF
if let Some(fm) = self.schema_provider.get_function_meta(&name) {
let args =
self.function_args_to_expr(function.args, schema, planner_context)?;
return Ok(Expr::ScalarUDF(ScalarUDF::new(fm, args)));
}

// User defined aggregate functions
// User defined aggregate functions (UDAF)
if let Some(fm) = self.schema_provider.get_aggregate_meta(&name) {
let args =
self.function_args_to_expr(function.args, schema, planner_context)?;
Expand Down