diff --git a/datafusion/functions/src/string/contains.rs b/datafusion/functions/src/string/contains.rs index 0f75731aa1c3..86f1eda03342 100644 --- a/datafusion/functions/src/string/contains.rs +++ b/datafusion/functions/src/string/contains.rs @@ -16,8 +16,8 @@ // under the License. use crate::utils::make_scalar_function; -use arrow::array::{Array, ArrayRef, AsArray, GenericStringArray, StringViewArray}; -use arrow::compute::regexp_is_match; +use arrow::array::{Array, ArrayRef, AsArray}; +use arrow::compute::contains as arrow_contains; use arrow::datatypes::DataType; use arrow::datatypes::DataType::{Boolean, LargeUtf8, Utf8, Utf8View}; use datafusion_common::exec_err; @@ -102,40 +102,25 @@ fn get_contains_doc() -> &'static Documentation { }) } -/// use regexp_is_match_utf8_scalar to do the calculation for contains +/// use `arrow::compute::contains` to do the calculation for contains pub fn contains(args: &[ArrayRef]) -> Result { match (args[0].data_type(), args[1].data_type()) { (Utf8View, Utf8View) => { let mod_str = args[0].as_string_view(); let match_str = args[1].as_string_view(); - let res = regexp_is_match::< - StringViewArray, - StringViewArray, - GenericStringArray, - >(mod_str, match_str, None)?; - + let res = arrow_contains(mod_str, match_str)?; Ok(Arc::new(res) as ArrayRef) } (Utf8, Utf8) => { let mod_str = args[0].as_string::(); let match_str = args[1].as_string::(); - let res = regexp_is_match::< - GenericStringArray, - GenericStringArray, - GenericStringArray, - >(mod_str, match_str, None)?; - + let res = arrow_contains(mod_str, match_str)?; Ok(Arc::new(res) as ArrayRef) } (LargeUtf8, LargeUtf8) => { let mod_str = args[0].as_string::(); let match_str = args[1].as_string::(); - let res = regexp_is_match::< - GenericStringArray, - GenericStringArray, - GenericStringArray, - >(mod_str, match_str, None)?; - + let res = arrow_contains(mod_str, match_str)?; Ok(Arc::new(res) as ArrayRef) } other => { @@ -143,3 +128,31 @@ pub fn contains(args: &[ArrayRef]) -> Result { } } } + +#[cfg(test)] +mod test { + use super::ContainsFunc; + use arrow::array::{BooleanArray, StringArray}; + use datafusion_common::ScalarValue; + use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; + use std::sync::Arc; + + #[test] + fn test_contains_udf() { + let udf = ContainsFunc::new(); + let array = ColumnarValue::Array(Arc::new(StringArray::from(vec![ + Some("xxx?()"), + Some("yyy?()"), + ]))); + let scalar = ColumnarValue::Scalar(ScalarValue::Utf8(Some("x?(".to_string()))); + let actual = udf.invoke(&[array, scalar]).unwrap(); + let expect = ColumnarValue::Array(Arc::new(BooleanArray::from(vec![ + Some(true), + Some(false), + ]))); + assert_eq!( + *actual.into_array(2).unwrap(), + *expect.into_array(2).unwrap() + ); + } +}