diff --git a/datafusion/functions/src/unicode/substr.rs b/datafusion/functions/src/unicode/substr.rs index 46b3cc63d0b6..6eee49d49093 100644 --- a/datafusion/functions/src/unicode/substr.rs +++ b/datafusion/functions/src/unicode/substr.rs @@ -27,9 +27,13 @@ use arrow::array::{ use arrow::buffer::ScalarBuffer; use arrow::datatypes::DataType; use datafusion_common::cast::as_int64_array; -use datafusion_common::{exec_err, plan_err, Result}; +use datafusion_common::types::{ + logical_int32, logical_int64, logical_string, NativeType, +}; +use datafusion_common::{exec_err, Result}; use datafusion_expr::{ - ColumnarValue, Documentation, ScalarUDFImpl, Signature, Volatility, + Coercion, ColumnarValue, Documentation, ScalarUDFImpl, Signature, TypeSignature, + TypeSignatureClass, Volatility, }; use datafusion_macros::user_doc; @@ -44,7 +48,7 @@ use datafusion_macros::user_doc; | substr(Utf8("datafusion"),Int64(5),Int64(3)) | +----------------------------------------------+ | fus | -+----------------------------------------------+ ++----------------------------------------------+ ```"#, standard_argument(name = "str", prefix = "String"), argument( @@ -70,14 +74,30 @@ impl Default for SubstrFunc { impl SubstrFunc { pub fn new() -> Self { + let string = Coercion::new_exact(TypeSignatureClass::Native(logical_string())); + let int64 = Coercion::new_implicit( + TypeSignatureClass::Native(logical_int64()), + vec![TypeSignatureClass::Native(logical_int32())], + NativeType::Int64, + ); Self { - signature: Signature::user_defined(Volatility::Immutable) - .with_parameter_names(vec![ - "str".to_string(), - "start_pos".to_string(), - "length".to_string(), - ]) - .expect("valid parameter names"), + signature: Signature::one_of( + vec![ + TypeSignature::Coercible(vec![string.clone(), int64.clone()]), + TypeSignature::Coercible(vec![ + string.clone(), + int64.clone(), + int64.clone(), + ]), + ], + Volatility::Immutable, + ) + .with_parameter_names(vec![ + "str".to_string(), + 
"start_pos".to_string(), + "length".to_string(), + ]) + .expect("valid parameter names"), aliases: vec![String::from("substring")], } } @@ -112,72 +132,6 @@ impl ScalarUDFImpl for SubstrFunc { &self.aliases } - fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> { - if arg_types.len() < 2 || arg_types.len() > 3 { - return plan_err!( - "The {} function requires 2 or 3 arguments, but got {}.", - self.name(), - arg_types.len() - ); - } - let first_data_type = match &arg_types[0] { - DataType::Null => Ok(DataType::Utf8), - DataType::LargeUtf8 | DataType::Utf8View | DataType::Utf8 => Ok(arg_types[0].clone()), - DataType::Dictionary(key_type, value_type) => { - if key_type.is_integer() { - match value_type.as_ref() { - DataType::Null => Ok(DataType::Utf8), - DataType::LargeUtf8 | DataType::Utf8View | DataType::Utf8 => Ok(*value_type.clone()), - _ => plan_err!( - "The first argument of the {} function can only be a string, but got {:?}.", - self.name(), - arg_types[0] - ), - } - } else { - plan_err!( - "The first argument of the {} function can only be a string, but got {:?}.", - self.name(), - arg_types[0] - ) - } - } - _ => plan_err!( - "The first argument of the {} function can only be a string, but got {:?}.", - self.name(), - arg_types[0] - ) - }?; - - if ![DataType::Int64, DataType::Int32, DataType::Null].contains(&arg_types[1]) { - return plan_err!( - "The second argument of the {} function can only be an integer, but got {:?}.", - self.name(), - arg_types[1] - ); - } - - if arg_types.len() == 3 - && ![DataType::Int64, DataType::Int32, DataType::Null].contains(&arg_types[2]) - { - return plan_err!( - "The third argument of the {} function can only be an integer, but got {:?}.", - self.name(), - arg_types[2] - ); - } - - if arg_types.len() == 2 { - Ok(vec![first_data_type.to_owned(), DataType::Int64]) - } else { - Ok(vec![ - first_data_type.to_owned(), - DataType::Int64, - DataType::Int64, - ]) - } - } - fn documentation(&self) -> Option<&Documentation> { 
self.doc() } diff --git a/datafusion/sqllogictest/test_files/functions.slt b/datafusion/sqllogictest/test_files/functions.slt index 20f79622a62c..6c87d618c727 100644 --- a/datafusion/sqllogictest/test_files/functions.slt +++ b/datafusion/sqllogictest/test_files/functions.slt @@ -193,10 +193,25 @@ SELECT substr('alphabet', 3, CAST(NULL AS int)) ---- NULL -statement error The first argument of the substr function can only be a string, but got Int64 +query T +SELECT substr(NULL, 1, 2) +---- +NULL + +query T +SELECT substr('alphabet', 1, NULL) +---- +NULL + +query T +SELECT substr('alphabet', NULL, 2) +---- +NULL + +statement error Function 'substr' failed to match any signature SELECT substr(1, 3) -statement error The first argument of the substr function can only be a string, but got Int64 +statement error Function 'substr' failed to match any signature SELECT substr(1, 3, 4) query T diff --git a/datafusion/sqllogictest/test_files/named_arguments.slt b/datafusion/sqllogictest/test_files/named_arguments.slt index 4eab799fd261..a16de826a2d5 100644 --- a/datafusion/sqllogictest/test_files/named_arguments.slt +++ b/datafusion/sqllogictest/test_files/named_arguments.slt @@ -85,7 +85,7 @@ SELECT substr("STR" => 'hello world', "start_pos" => 7); # Error: wrong number of arguments # This query provides only 1 argument but substr requires 2 or 3 -query error DataFusion error: Error during planning: Execution error: Function 'substr' user-defined coercion failed with "Error during planning: The substr function requires 2 or 3 arguments, but got 1." 
+query error Function 'substr' failed to match any signature SELECT substr(str => 'hello world'); ############# diff --git a/datafusion/sqllogictest/test_files/string/string_literal.slt b/datafusion/sqllogictest/test_files/string/string_literal.slt index f602dbb54b08..f5138ab3f734 100644 --- a/datafusion/sqllogictest/test_files/string/string_literal.slt +++ b/datafusion/sqllogictest/test_files/string/string_literal.slt @@ -132,10 +132,10 @@ SELECT substr('Hello🌏世界', 5, 3) ---- o🌏世 -statement error The first argument of the substr function can only be a string, but got Int64 +statement error Function 'substr' failed to match any signature SELECT substr(1, 3) -statement error The first argument of the substr function can only be a string, but got Int64 +statement error Function 'substr' failed to match any signature SELECT substr(1, 3, 4) statement error Execution error: negative substring length not allowed