Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
233 changes: 220 additions & 13 deletions datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ use datafusion_expr::{
LogicalPlanBuilder, OperateFunctionArg, ReturnFieldArgs, ScalarFunctionArgs,
ScalarUDF, ScalarUDFImpl, Signature, Volatility,
};
use datafusion_expr_common::signature::TypeSignature;
use datafusion_functions_nested::range::range_udf;
use parking_lot::Mutex;
use regex::Regex;
Expand Down Expand Up @@ -945,6 +946,7 @@ struct ScalarFunctionWrapper {
expr: Expr,
signature: Signature,
return_type: DataType,
defaults: Vec<Option<Expr>>,
}

impl ScalarUDFImpl for ScalarFunctionWrapper {
Expand Down Expand Up @@ -973,27 +975,39 @@ impl ScalarUDFImpl for ScalarFunctionWrapper {
args: Vec<Expr>,
_info: &dyn SimplifyInfo,
) -> Result<ExprSimplifyResult> {
let replacement = Self::replacement(&self.expr, &args)?;
let replacement = Self::replacement(&self.expr, &args, &self.defaults)?;

Ok(ExprSimplifyResult::Simplified(replacement))
}
}

impl ScalarFunctionWrapper {
// replaces placeholders with actual arguments
fn replacement(expr: &Expr, args: &[Expr]) -> Result<Expr> {
fn replacement(
expr: &Expr,
args: &[Expr],
defaults: &[Option<Expr>],
) -> Result<Expr> {
let result = expr.clone().transform(|e| {
let r = match e {
Expr::Placeholder(placeholder) => {
let placeholder_position =
Self::parse_placeholder_identifier(&placeholder.id)?;
if placeholder_position < args.len() {
Transformed::yes(args[placeholder_position].clone())
} else {
} else if placeholder_position >= defaults.len() {
exec_err!(
"Function argument {} not provided, argument missing!",
"Invalid placeholder, out of range: {}",
placeholder.id
)?
} else {
match defaults[placeholder_position] {
Some(ref default) => Transformed::yes(default.clone()),
None => exec_err!(
"Function argument {} not provided, argument missing!",
placeholder.id
)?,
}
}
}
_ => Transformed::no(e),
Expand Down Expand Up @@ -1021,6 +1035,32 @@ impl TryFrom<CreateFunction> for ScalarFunctionWrapper {
type Error = DataFusionError;

fn try_from(definition: CreateFunction) -> std::result::Result<Self, Self::Error> {
let args = definition.args.unwrap_or_default();
let defaults: Vec<Option<Expr>> =
args.iter().map(|a| a.default_expr.clone()).collect();
let signature: Signature = match defaults.iter().position(|v| v.is_some()) {
Some(pos) => {
let mut type_signatures: Vec<TypeSignature> = vec![];
// Generate all valid signatures
for n in pos..defaults.len() + 1 {
if n == 0 {
type_signatures.push(TypeSignature::Nullary)
} else {
type_signatures.push(TypeSignature::Exact(
args.iter().take(n).map(|a| a.data_type.clone()).collect(),
))
}
}
Signature::one_of(
type_signatures,
definition.params.behavior.unwrap_or(Volatility::Volatile),
)
}
None => Signature::exact(
args.iter().map(|a| a.data_type.clone()).collect(),
definition.params.behavior.unwrap_or(Volatility::Volatile),
),
};
Ok(Self {
name: definition.name,
expr: definition
Expand All @@ -1030,15 +1070,8 @@ impl TryFrom<CreateFunction> for ScalarFunctionWrapper {
return_type: definition
.return_type
.expect("Return type has to be defined!"),
signature: Signature::exact(
definition
.args
.unwrap_or_default()
.into_iter()
.map(|a| a.data_type)
.collect(),
definition.params.behavior.unwrap_or(Volatility::Volatile),
),
signature,
defaults,
})
}
}
Expand Down Expand Up @@ -1109,6 +1142,180 @@ async fn create_scalar_function_from_sql_statement() -> Result<()> {
"#;
assert!(ctx.sql(bad_definition_sql).await.is_err());

// FIXME: Definitions with invalid placeholders are allowed, fail at runtime
let bad_expression_sql = r#"
CREATE FUNCTION better_add(DOUBLE, DOUBLE)
RETURNS DOUBLE
RETURN $1 + $3
"#;
assert!(ctx.sql(bad_expression_sql).await.is_ok());

let err = ctx
.sql("select better_add(2.0, 2.0)")
.await?
.collect()
.await
.expect_err("unknown placeholder");
let expected = "Optimizer rule 'simplify_expressions' failed\ncaused by\nExecution error: Invalid placeholder, out of range: $3";
assert!(expected.starts_with(&err.strip_backtrace()));

Ok(())
}

/// Verifies that a SQL-defined scalar function can reference its arguments by
/// name (`$a`, `$b`) in the function body, and that mixing named and
/// positional argument declarations is rejected when the statement is planned.
#[tokio::test]
async fn create_scalar_function_from_sql_statement_named_arguments() -> Result<()> {
// Install the custom function factory on the session so that
// `CREATE FUNCTION` statements are routed to it.
let function_factory = Arc::new(CustomFunctionFactory::default());
let ctx = SessionContext::new().with_function_factory(function_factory.clone());

// Both arguments are declared with names and the body refers to them by name.
let sql = r#"
CREATE FUNCTION better_add(a DOUBLE, b DOUBLE)
RETURNS DOUBLE
RETURN $a + $b
"#;

assert!(ctx.sql(sql).await.is_ok());

// Calling the function with two arguments should evaluate `$a + $b`.
let result = ctx
.sql("select better_add(2.0, 2.0)")
.await?
.collect()
.await?;

assert_batches_eq!(
&[
"+-----------------------------------+",
"| better_add(Float64(2),Float64(2)) |",
"+-----------------------------------+",
"| 4.0 |",
"+-----------------------------------+",
],
&result
);

// Named and positional styles cannot be mixed: the first argument here is
// unnamed while the second one is named `b`.
let bad_expression_sql = r#"
CREATE FUNCTION bad_expression_fun(DOUBLE, b DOUBLE)
RETURNS DOUBLE
RETURN $1 + $b
"#;
let err = ctx
.sql(bad_expression_sql)
.await
.expect_err("cannot mix named and positional style");
let expected = "Error during planning: All function arguments must use either named or positional style.";
// NOTE(review): this passes when the error text is a prefix of `expected`
// (not the other way around) -- confirm that direction is intended.
assert!(expected.starts_with(&err.strip_backtrace()));

Ok(())
}

/// Verifies default argument values for SQL-defined scalar functions.
///
/// Covers, in order: calling a function whose arguments all have defaults with
/// zero, one and two arguments; rejecting a call with too many arguments;
/// defaults combined with positional (`$1`, `$2`) placeholders; rejecting a
/// non-default argument declared after a default one; and the current parser
/// error for the `DEFAULT` keyword with unnamed arguments.
#[tokio::test]
async fn create_scalar_function_from_sql_statement_default_arguments() -> Result<()> {
// Install the custom function factory on the session so that
// `CREATE FUNCTION` statements are routed to it.
let function_factory = Arc::new(CustomFunctionFactory::default());
let ctx = SessionContext::new().with_function_factory(function_factory.clone());

// Both named arguments carry a default value of 2.0.
let sql = r#"
CREATE FUNCTION better_add(a DOUBLE = 2.0, b DOUBLE = 2.0)
RETURNS DOUBLE
RETURN $a + $b
"#;

assert!(ctx.sql(sql).await.is_ok());

// Every arity from zero to two arguments should be accepted.
// Zero arguments: both defaults are used.
let result = ctx.sql("select better_add()").await?.collect().await?;

assert_batches_eq!(
&[
"+--------------+",
"| better_add() |",
"+--------------+",
"| 4.0 |",
"+--------------+",
],
&result
);

// One argument: the default is used for `b` only.
let result = ctx.sql("select better_add(2.0)").await?.collect().await?;

assert_batches_eq!(
&[
"+------------------------+",
"| better_add(Float64(2)) |",
"+------------------------+",
"| 4.0 |",
"+------------------------+",
],
&result
);

// Two arguments: no default is needed.
let result = ctx
.sql("select better_add(2.0, 2.0)")
.await?
.collect()
.await?;

assert_batches_eq!(
&[
"+-----------------------------------+",
"| better_add(Float64(2),Float64(2)) |",
"+-----------------------------------+",
"| 4.0 |",
"+-----------------------------------+",
],
&result
);

// Passing more arguments than declared must fail; then drop the function so
// the name can be reused below.
assert!(ctx.sql("select better_add(2.0, 2.0, 2.0)").await.is_err());
assert!(ctx.sql("drop function better_add").await.is_ok());

// Defaults also work with positional style; here only the second argument
// has a default.
let sql = r#"
CREATE FUNCTION better_add(DOUBLE, DOUBLE = 2.0)
RETURNS DOUBLE
RETURN $1 + $2
"#;
assert!(ctx.sql(sql).await.is_ok());

// The first argument has no default, so a zero-argument call must fail.
assert!(ctx.sql("select better_add()").await.is_err());
let result = ctx.sql("select better_add(2.0)").await?.collect().await?;
assert_batches_eq!(
&[
"+------------------------+",
"| better_add(Float64(2)) |",
"+------------------------+",
"| 4.0 |",
"+------------------------+",
],
&result
);

// A non-default argument cannot follow a default argument.
let bad_expression_sql = r#"
CREATE FUNCTION bad_expression_fun(a DOUBLE = 2.0, b DOUBLE)
RETURNS DOUBLE
RETURN $a + $b
"#;
let err = ctx
.sql(bad_expression_sql)
.await
.expect_err("non-default argument cannot follow default argument");
let expected =
"Error during planning: Non-default arguments cannot follow default arguments.";
// NOTE(review): this passes when the error text is a prefix of `expected`
// (not the other way around) -- confirm that direction is intended.
assert!(expected.starts_with(&err.strip_backtrace()));

// FIXME: The `DEFAULT` syntax does not work with positional params
// This case pins the parser error currently produced for that syntax.
let bad_expression_sql = r#"
CREATE FUNCTION bad_expression_fun(DOUBLE, DOUBLE DEFAULT 2.0)
RETURNS DOUBLE
RETURN $1 + $2
"#;
let err = ctx
.sql(bad_expression_sql)
.await
.expect_err("sqlparser error");
let expected =
"SQL error: ParserError(\"Expected: ), found: 2.0 at Line: 2, Column: 63\")";
assert!(expected.starts_with(&err.strip_backtrace()));
Ok(())
}

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add a test case with a positional parameter that does not exist, e.g. `$5` where there are fewer than 5 arguments.
I suspect it will fail with an index-out-of-bounds error at https://github.com/apache/datafusion/pull/18450/files#diff-647d2e08b4d044bf63b35f9e23092ba9673b80b1568e8f3abffd7f909552ea1aR999

You need to add a check similar to `if placeholder_position < defaults.len() {...}` around it and return an error in the else clause.

Copy link
Contributor Author

@r1b r1b Nov 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, addressed in 0330adb.

This case also revealed that the `DEFAULT` syntax is broken for positional params. I switched to the `=` syntax and added a test that illustrates the problem in a97ddb5.

Ref: https://github.com/apache/datafusion-sqlparser-rs/blob/308a7231bcbc5c1c8ab71fe38f17b1a21632a6c6/src/parser/mod.rs#L5536

EDIT: It seems that `=` is the "canonical" syntax in any case.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Proposed a fix for the DEFAULT syntax bug upstream: apache/datafusion-sqlparser-rs#2091

Expand Down
22 changes: 17 additions & 5 deletions datafusion/sql/src/expr/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,13 +104,13 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
}

/// Create a placeholder expression
/// This is the same as Postgres's prepare statement syntax in which a placeholder starts with `$` sign and then
/// number 1, 2, ... etc. For example, `$1` is the first placeholder; $2 is the second one and so on.
/// Both named (`$foo`) and positional (`$1`, `$2`, ...) placeholder styles are supported.
fn create_placeholder_expr(
param: String,
param_data_types: &[FieldRef],
) -> Result<Expr> {
// Parse the placeholder as a number because it is the only support from sqlparser and postgres
// Try to parse the placeholder as a number. If the placeholder does not have a valid
// positional value, assume we have a named placeholder.
let index = param[1..].parse::<usize>();
let idx = match index {
Ok(0) => {
Expand All @@ -123,12 +123,24 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
return if param_data_types.is_empty() {
Ok(Expr::Placeholder(Placeholder::new_with_field(param, None)))
} else {
// when PREPARE Statement, param_data_types length is always 0
plan_err!("Invalid placeholder, not a number: {param}")
// FIXME: This branch is shared by params from PREPARE and CREATE FUNCTION, but
// only CREATE FUNCTION currently supports named params. For now, we rewrite
// these to positional params.
Comment on lines +126 to +128
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: I explored doing this without rewriting to positional params, but I couldn't see a path forward without either:

  • Adding a way to distinguish between prepared statement vs SQL UDF context
  • Supporting named params in prepared statements (not currently supported in sqlparser AFAICT)

let named_param_pos = param_data_types
.iter()
.position(|v| v.name() == &param[1..]);
match named_param_pos {
Some(pos) => Ok(Expr::Placeholder(Placeholder::new_with_field(
format!("${}", pos + 1),
param_data_types.get(pos).cloned(),
))),
None => plan_err!("Unknown placeholder: {param}"),
}
};
}
};
// Check if the placeholder is in the parameter list
// FIXME: In the CREATE FUNCTION branch, param_type = None should raise an error
let param_type = param_data_types.get(idx);
// Data type of the parameter
debug!("type of param {param} param_data_types[idx]: {param_type:?}");
Expand Down
38 changes: 37 additions & 1 deletion datafusion/sql/src/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1222,6 +1222,28 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
}
None => None,
};
// Validate default arguments
let first_default = match args.as_ref() {
Some(arg) => arg.iter().position(|t| t.default_expr.is_some()),
None => None,
};
let last_non_default = match args.as_ref() {
Some(arg) => arg
.iter()
.rev()
.position(|t| t.default_expr.is_none())
.map(|reverse_pos| arg.len() - reverse_pos - 1),
None => None,
};
if let (Some(pos_default), Some(pos_non_default)) =
(first_default, last_non_default)
{
if pos_non_default > pos_default {
return plan_err!(
"Non-default arguments cannot follow default arguments."
);
}
}
// At the moment functions can't be qualified `schema.name`
let name = match &name.0[..] {
[] => exec_err!("Function should have name")?,
Expand All @@ -1233,9 +1255,23 @@ impl<S: ContextProvider> SqlToRel<'_, S> {
//
let arg_types = args.as_ref().map(|arg| {
arg.iter()
.map(|t| Arc::new(Field::new("", t.data_type.clone(), true)))
.map(|t| {
let name = match t.name.clone() {
Some(name) => name.value,
None => "".to_string(),
};
Arc::new(Field::new(name, t.data_type.clone(), true))
})
.collect::<Vec<_>>()
});
// Validate parameter style
if let Some(ref fields) = arg_types {
let count_positional =
fields.iter().filter(|f| f.name() == "").count();
if !(count_positional == 0 || count_positional == fields.len()) {
return plan_err!("All function arguments must use either named or positional style.");
}
}
Comment on lines +1267 to +1274
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note: I'm not sure this is actually necessary, but it made the changes easier to reason about. If we think it is valuable, I can look at relaxing this constraint.

let mut planner_context = PlannerContext::new()
.with_prepare_param_data_types(arg_types.unwrap_or_default());

Expand Down
2 changes: 1 addition & 1 deletion datafusion/sql/tests/cases/params.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ fn test_prepare_statement_to_plan_panic_param_format() {
assert_snapshot!(
logical_plan(sql).unwrap_err().strip_backtrace(),
@r###"
Error during planning: Invalid placeholder, not a number: $foo
Error during planning: Unknown placeholder: $foo
"###
);
}
Expand Down
Loading