diff --git a/datafusion/core/tests/sql/udf.rs b/datafusion/core/tests/sql/udf.rs index a31028fd71cb..0ecd5d0fde86 100644 --- a/datafusion/core/tests/sql/udf.rs +++ b/datafusion/core/tests/sql/udf.rs @@ -179,6 +179,39 @@ async fn scalar_udf_zero_params() -> Result<()> { Ok(()) } +#[tokio::test] +async fn scalar_udf_override_built_in_scalar_function() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]); + + let batch = RecordBatch::try_new( + Arc::new(schema.clone()), + vec![Arc::new(Int32Array::from(vec![-100]))], + )?; + let ctx = SessionContext::new(); + + ctx.register_batch("t", batch)?; + // register a UDF that has the same name as a builtin function (abs) and just returns 1 regardless of input + ctx.register_udf(create_udf( + "abs", + vec![DataType::Int32], + Arc::new(DataType::Int32), + Volatility::Immutable, + Arc::new(move |_| Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(1))))), + )); + + // Make sure that the UDF is used instead of the built-in function + let result = plan_and_collect(&ctx, "select abs(a) a from t").await?; + let expected = vec![ + "+---+", // + "| a |", // + "+---+", // + "| 1 |", // + "+---+", // + ]; + assert_batches_eq!(expected, &result); + Ok(()) +} + /// tests the creation, registration and usage of a UDAF #[tokio::test] async fn simple_udaf() -> Result<()> { diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index 0fb6b7554776..0289e804110c 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -47,6 +47,13 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { crate::utils::normalize_ident(function.name.0[0].clone()) }; + // user-defined function (UDF) should have precedence in case it has the same name as a scalar built-in function + if let Some(fm) = self.schema_provider.get_function_meta(&name) { + let args = + self.function_args_to_expr(function.args, schema, planner_context)?; + return Ok(Expr::ScalarUDF(ScalarUDF::new(fm, args))); + } + // next, scalar built-in if let Ok(fun) = BuiltinScalarFunction::from_str(&name) { let args = @@ -139,14 +146,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { ))); }; - // finally, user-defined functions (UDF) and UDAF - if let Some(fm) = self.schema_provider.get_function_meta(&name) { - let args = - self.function_args_to_expr(function.args, schema, planner_context)?; - return Ok(Expr::ScalarUDF(ScalarUDF::new(fm, args))); - } - - // User defined aggregate functions + // User defined aggregate functions (UDAF) if let Some(fm) = self.schema_provider.get_aggregate_meta(&name) { let args = self.function_args_to_expr(function.args, schema, planner_context)?;