diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs index 3ca8f846aa5e..1361091a4cb5 100644 --- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs @@ -47,6 +47,7 @@ use datafusion_expr::{ LogicalPlanBuilder, OperateFunctionArg, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, Volatility, }; +use datafusion_expr_common::signature::TypeSignature; use datafusion_functions_nested::range::range_udf; use parking_lot::Mutex; use regex::Regex; @@ -945,6 +946,7 @@ struct ScalarFunctionWrapper { expr: Expr, signature: Signature, return_type: DataType, + defaults: Vec<Option<Expr>>, } impl ScalarUDFImpl for ScalarFunctionWrapper { @@ -973,7 +975,7 @@ impl ScalarUDFImpl for ScalarFunctionWrapper { args: Vec<Expr>, _info: &dyn SimplifyInfo, ) -> Result<ExprSimplifyResult> { - let replacement = Self::replacement(&self.expr, &args)?; + let replacement = Self::replacement(&self.expr, &args, &self.defaults)?; Ok(ExprSimplifyResult::Simplified(replacement)) } @@ -981,7 +983,11 @@ impl ScalarUDFImpl for ScalarFunctionWrapper { impl ScalarFunctionWrapper { // replaces placeholders with actual arguments - fn replacement(expr: &Expr, args: &[Expr]) -> Result<Expr> { + fn replacement( + expr: &Expr, + args: &[Expr], + defaults: &[Option<Expr>], + ) -> Result<Expr> { let result = expr.clone().transform(|e| { let r = match e { Expr::Placeholder(placeholder) => { @@ -989,11 +995,19 @@ impl ScalarFunctionWrapper { Self::parse_placeholder_identifier(&placeholder.id)?; if placeholder_position < args.len() { Transformed::yes(args[placeholder_position].clone()) - } else { + } else if placeholder_position >= defaults.len() { exec_err!( - "Function argument {} not provided, argument missing!", + "Invalid placeholder, out of range: {}", placeholder.id )? 
+ } else { + match defaults[placeholder_position] { + Some(ref default) => Transformed::yes(default.clone()), + None => exec_err!( + "Function argument {} not provided, argument missing!", + placeholder.id + )?, + } } } _ => Transformed::no(e), @@ -1021,6 +1035,32 @@ impl TryFrom<CreateFunction> for ScalarFunctionWrapper { type Error = DataFusionError; fn try_from(definition: CreateFunction) -> std::result::Result<Self, Self::Error> { + let args = definition.args.unwrap_or_default(); + let defaults: Vec<Option<Expr>> = + args.iter().map(|a| a.default_expr.clone()).collect(); + let signature: Signature = match defaults.iter().position(|v| v.is_some()) { + Some(pos) => { + let mut type_signatures: Vec<TypeSignature> = vec![]; + // Generate all valid signatures + for n in pos..defaults.len() + 1 { + if n == 0 { + type_signatures.push(TypeSignature::Nullary) + } else { + type_signatures.push(TypeSignature::Exact( + args.iter().take(n).map(|a| a.data_type.clone()).collect(), + )) + } + } + Signature::one_of( + type_signatures, + definition.params.behavior.unwrap_or(Volatility::Volatile), + ) + } + None => Signature::exact( + args.iter().map(|a| a.data_type.clone()).collect(), + definition.params.behavior.unwrap_or(Volatility::Volatile), + ), + }; Ok(Self { name: definition.name, expr: definition @@ -1030,15 +1070,8 @@ impl TryFrom<CreateFunction> for ScalarFunctionWrapper { return_type: definition .return_type .expect("Return type has to be defined!"), - signature: Signature::exact( - definition - .args - .unwrap_or_default() - .into_iter() - .map(|a| a.data_type) - .collect(), - definition.params.behavior.unwrap_or(Volatility::Volatile), - ), + signature, + defaults, }) } } @@ -1109,6 +1142,180 @@ async fn create_scalar_function_from_sql_statement() -> Result<()> { "#; assert!(ctx.sql(bad_definition_sql).await.is_err()); + // FIXME: Definitions with invalid placeholders are allowed, fail at runtime + let bad_expression_sql = r#" + CREATE FUNCTION better_add(DOUBLE, DOUBLE) + RETURNS DOUBLE + RETURN $1 + $3 + "#; + 
assert!(ctx.sql(bad_expression_sql).await.is_ok()); + + let err = ctx + .sql("select better_add(2.0, 2.0)") + .await? + .collect() + .await + .expect_err("unknown placeholder"); + let expected = "Optimizer rule 'simplify_expressions' failed\ncaused by\nExecution error: Invalid placeholder, out of range: $3"; + assert!(expected.starts_with(&err.strip_backtrace())); + + Ok(()) +} + +#[tokio::test] +async fn create_scalar_function_from_sql_statement_named_arguments() -> Result<()> { + let function_factory = Arc::new(CustomFunctionFactory::default()); + let ctx = SessionContext::new().with_function_factory(function_factory.clone()); + + let sql = r#" + CREATE FUNCTION better_add(a DOUBLE, b DOUBLE) + RETURNS DOUBLE + RETURN $a + $b + "#; + + assert!(ctx.sql(sql).await.is_ok()); + + let result = ctx + .sql("select better_add(2.0, 2.0)") + .await? + .collect() + .await?; + + assert_batches_eq!( + &[ + "+-----------------------------------+", + "| better_add(Float64(2),Float64(2)) |", + "+-----------------------------------+", + "| 4.0 |", + "+-----------------------------------+", + ], + &result + ); + + // cannot mix named and positional style + let bad_expression_sql = r#" + CREATE FUNCTION bad_expression_fun(DOUBLE, b DOUBLE) + RETURNS DOUBLE + RETURN $1 + $b + "#; + let err = ctx + .sql(bad_expression_sql) + .await + .expect_err("cannot mix named and positional style"); + let expected = "Error during planning: All function arguments must use either named or positional style."; + assert!(expected.starts_with(&err.strip_backtrace())); + + Ok(()) +} + +#[tokio::test] +async fn create_scalar_function_from_sql_statement_default_arguments() -> Result<()> { + let function_factory = Arc::new(CustomFunctionFactory::default()); + let ctx = SessionContext::new().with_function_factory(function_factory.clone()); + + let sql = r#" + CREATE FUNCTION better_add(a DOUBLE = 2.0, b DOUBLE = 2.0) + RETURNS DOUBLE + RETURN $a + $b + "#; + + assert!(ctx.sql(sql).await.is_ok()); + + // 
Check all function arity supported + let result = ctx.sql("select better_add()").await?.collect().await?; + + assert_batches_eq!( + &[ + "+--------------+", + "| better_add() |", + "+--------------+", + "| 4.0 |", + "+--------------+", + ], + &result + ); + + let result = ctx.sql("select better_add(2.0)").await?.collect().await?; + + assert_batches_eq!( + &[ + "+------------------------+", + "| better_add(Float64(2)) |", + "+------------------------+", + "| 4.0 |", + "+------------------------+", + ], + &result + ); + + let result = ctx + .sql("select better_add(2.0, 2.0)") + .await? + .collect() + .await?; + + assert_batches_eq!( + &[ + "+-----------------------------------+", + "| better_add(Float64(2),Float64(2)) |", + "+-----------------------------------+", + "| 4.0 |", + "+-----------------------------------+", + ], + &result + ); + + assert!(ctx.sql("select better_add(2.0, 2.0, 2.0)").await.is_err()); + assert!(ctx.sql("drop function better_add").await.is_ok()); + + // works with positional style + let sql = r#" + CREATE FUNCTION better_add(DOUBLE, DOUBLE = 2.0) + RETURNS DOUBLE + RETURN $1 + $2 + "#; + assert!(ctx.sql(sql).await.is_ok()); + + assert!(ctx.sql("select better_add()").await.is_err()); + let result = ctx.sql("select better_add(2.0)").await?.collect().await?; + assert_batches_eq!( + &[ + "+------------------------+", + "| better_add(Float64(2)) |", + "+------------------------+", + "| 4.0 |", + "+------------------------+", + ], + &result + ); + + // non-default argument cannot follow default argument + let bad_expression_sql = r#" + CREATE FUNCTION bad_expression_fun(a DOUBLE = 2.0, b DOUBLE) + RETURNS DOUBLE + RETURN $a + $b + "#; + let err = ctx + .sql(bad_expression_sql) + .await + .expect_err("non-default argument cannot follow default argument"); + let expected = + "Error during planning: Non-default arguments cannot follow default arguments."; + assert!(expected.starts_with(&err.strip_backtrace())); + + // FIXME: The `DEFAULT` syntax does 
not work with positional params + let bad_expression_sql = r#" + CREATE FUNCTION bad_expression_fun(DOUBLE, DOUBLE DEFAULT 2.0) + RETURNS DOUBLE + RETURN $1 + $2 + "#; + let err = ctx + .sql(bad_expression_sql) + .await + .expect_err("sqlparser error"); + let expected = + "SQL error: ParserError(\"Expected: ), found: 2.0 at Line: 2, Column: 63\")"; + assert!(expected.starts_with(&err.strip_backtrace())); Ok(()) } diff --git a/datafusion/sql/src/expr/value.rs b/datafusion/sql/src/expr/value.rs index 3abb2752988f..8ca059d08c16 100644 --- a/datafusion/sql/src/expr/value.rs +++ b/datafusion/sql/src/expr/value.rs @@ -104,13 +104,13 @@ impl SqlToRel<'_, S> { } /// Create a placeholder expression - /// This is the same as Postgres's prepare statement syntax in which a placeholder starts with `$` sign and then - /// number 1, 2, ... etc. For example, `$1` is the first placeholder; $2 is the second one and so on. + /// Both named (`$foo`) and positional (`$1`, `$2`, ...) placeholder styles are supported. fn create_placeholder_expr( param: String, param_data_types: &[FieldRef], ) -> Result<Expr> { - // Parse the placeholder as a number because it is the only support from sqlparser and postgres + // Try to parse the placeholder as a number. If the placeholder does not have a valid + // positional value, assume we have a named placeholder. let index = param[1..].parse::<usize>(); let idx = match index { Ok(0) => { @@ -123,12 +123,24 @@ impl SqlToRel<'_, S> { return if param_data_types.is_empty() { Ok(Expr::Placeholder(Placeholder::new_with_field(param, None))) } else { - // when PREPARE Statement, param_data_types length is always 0 - plan_err!("Invalid placeholder, not a number: {param}") + // FIXME: This branch is shared by params from PREPARE and CREATE FUNCTION, but + // only CREATE FUNCTION currently supports named params. For now, we rewrite + // these to positional params. 
+ let named_param_pos = param_data_types + .iter() + .position(|v| v.name() == &param[1..]); + match named_param_pos { + Some(pos) => Ok(Expr::Placeholder(Placeholder::new_with_field( + format!("${}", pos + 1), + param_data_types.get(pos).cloned(), + ))), + None => plan_err!("Unknown placeholder: {param}"), + } }; } }; // Check if the placeholder is in the parameter list + // FIXME: In the CREATE FUNCTION branch, param_type = None should raise an error let param_type = param_data_types.get(idx); // Data type of the parameter debug!("type of param {param} param_data_types[idx]: {param_type:?}"); diff --git a/datafusion/sql/src/statement.rs b/datafusion/sql/src/statement.rs index 81381bf49fc5..84ce6bc1673a 100644 --- a/datafusion/sql/src/statement.rs +++ b/datafusion/sql/src/statement.rs @@ -1222,6 +1222,28 @@ impl SqlToRel<'_, S> { } None => None, }; + // Validate default arguments + let first_default = match args.as_ref() { + Some(arg) => arg.iter().position(|t| t.default_expr.is_some()), + None => None, + }; + let last_non_default = match args.as_ref() { + Some(arg) => arg + .iter() + .rev() + .position(|t| t.default_expr.is_none()) + .map(|reverse_pos| arg.len() - reverse_pos - 1), + None => None, + }; + if let (Some(pos_default), Some(pos_non_default)) = + (first_default, last_non_default) + { + if pos_non_default > pos_default { + return plan_err!( + "Non-default arguments cannot follow default arguments." + ); + } + } // At the moment functions can't be qualified `schema.name` let name = match &name.0[..] 
{ [] => exec_err!("Function should have name")?, @@ -1233,9 +1255,23 @@ impl SqlToRel<'_, S> { // let arg_types = args.as_ref().map(|arg| { arg.iter() - .map(|t| Arc::new(Field::new("", t.data_type.clone(), true))) + .map(|t| { + let name = match t.name.clone() { + Some(name) => name.value, + None => "".to_string(), + }; + Arc::new(Field::new(name, t.data_type.clone(), true)) + }) .collect::<Vec<_>>() }); + // Validate parameter style + if let Some(ref fields) = arg_types { + let count_positional = + fields.iter().filter(|f| f.name() == "").count(); + if !(count_positional == 0 || count_positional == fields.len()) { + return plan_err!("All function arguments must use either named or positional style."); + } + } let mut planner_context = PlannerContext::new() .with_prepare_param_data_types(arg_types.unwrap_or_default()); diff --git a/datafusion/sql/tests/cases/params.rs b/datafusion/sql/tests/cases/params.rs index 147628656d8f..a697fa460bb6 100644 --- a/datafusion/sql/tests/cases/params.rs +++ b/datafusion/sql/tests/cases/params.rs @@ -105,7 +105,7 @@ fn test_prepare_statement_to_plan_panic_param_format() { assert_snapshot!( logical_plan(sql).unwrap_err().strip_backtrace(), @r###" - Error during planning: Invalid placeholder, not a number: $foo + Error during planning: Unknown placeholder: $foo "### ); } diff --git a/datafusion/sqllogictest/test_files/prepare.slt b/datafusion/sqllogictest/test_files/prepare.slt index d61603ae6558..da28b4fe7e3f 100644 --- a/datafusion/sqllogictest/test_files/prepare.slt +++ b/datafusion/sqllogictest/test_files/prepare.slt @@ -34,7 +34,7 @@ statement error DataFusion error: SQL error: ParserError PREPARE AS SELECT id, age FROM person WHERE age = $foo; # param following a non-number, $foo, not supported -statement error Invalid placeholder, not a number: \$foo +statement error Unknown placeholder: \$foo PREPARE my_plan(INT) AS SELECT id, age FROM person WHERE age = $foo; # not specify table hence cannot specify columns