diff --git a/datafusion/functions/src/regex/regexpreplace.rs b/datafusion/functions/src/regex/regexpreplace.rs index d28c6cd36d65..f033b7a51948 100644 --- a/datafusion/functions/src/regex/regexpreplace.rs +++ b/datafusion/functions/src/regex/regexpreplace.rs @@ -16,15 +16,15 @@ // under the License. //! Regx expressions -use arrow::array::new_null_array; use arrow::array::ArrayAccessor; use arrow::array::ArrayDataBuilder; use arrow::array::BufferBuilder; use arrow::array::GenericStringArray; use arrow::array::StringViewBuilder; +use arrow::array::{new_null_array, ArrayIter, AsArray}; use arrow::array::{Array, ArrayRef, OffsetSizeTrait}; use arrow::datatypes::DataType; -use datafusion_common::cast::as_string_view_array; +use datafusion_common::cast::{as_string_array, as_string_view_array}; use datafusion_common::exec_err; use datafusion_common::plan_err; use datafusion_common::ScalarValue; @@ -187,27 +187,34 @@ fn regex_replace_posix_groups(replacement: &str) -> String { /// # Ok(()) /// # } /// ``` -pub fn regexp_replace(args: &[ArrayRef]) -> Result { +pub fn regexp_replace<'a, T: OffsetSizeTrait, V, B>( + string_array: V, + pattern_array: B, + replacement_array: B, + flags: Option<&ArrayRef>, +) -> Result +where + V: ArrayAccessor, + B: ArrayAccessor, +{ // Default implementation for regexp_replace, assumes all args are arrays // and args is a sequence of 3 or 4 elements. // creating Regex is expensive so create hashmap for memoization let mut patterns: HashMap = HashMap::new(); - match args.len() { - 3 => { - let string_array = as_generic_string_array::(&args[0])?; - let pattern_array = as_generic_string_array::(&args[1])?; - let replacement_array = as_generic_string_array::(&args[2])?; + let string_array_iter = ArrayIter::new(string_array); + let pattern_array_iter = ArrayIter::new(pattern_array); + let replacement_array_iter = ArrayIter::new(replacement_array); - let result = string_array - .iter() - .zip(pattern_array.iter()) - .zip(replacement_array.iter()) + match flags { + None => { + let result = string_array_iter + .zip(pattern_array_iter) + .zip(replacement_array_iter) .map(|((string, pattern), replacement)| match (string, pattern, replacement) { (Some(string), Some(pattern), Some(replacement)) => { let replacement = regex_replace_posix_groups(replacement); - // if patterns hashmap already has regexp then use else create and return let re = match patterns.get(pattern) { Some(re) => Ok(re), @@ -230,16 +237,12 @@ pub fn regexp_replace(args: &[ArrayRef]) -> Result Ok(Arc::new(result) as ArrayRef) } - 4 => { - let string_array = as_generic_string_array::(&args[0])?; - let pattern_array = as_generic_string_array::(&args[1])?; - let replacement_array = as_generic_string_array::(&args[2])?; - let flags_array = as_generic_string_array::(&args[3])?; + Some(flags) => { + let flags_array = as_generic_string_array::(flags)?; - let result = string_array - .iter() - .zip(pattern_array.iter()) - .zip(replacement_array.iter()) + let result = string_array_iter + .zip(pattern_array_iter) + .zip(replacement_array_iter) .zip(flags_array.iter()) .map(|(((string, pattern), replacement), flags)| match (string, pattern, replacement, flags) { (Some(string), Some(pattern), Some(replacement), Some(flags)) => { @@ -283,7 +286,7 @@ pub fn regexp_replace(args: &[ArrayRef]) -> Result Ok(Arc::new(result) as ArrayRef) } other => exec_err!( - "regexp_replace was called with {other} arguments. It requires at least 3 and at most 4." + "regexp_replace was called with {other:?} arguments. It requires at least 3 and at most 4." ), } } @@ -496,7 +499,69 @@ pub fn specialize_regexp_replace( .iter() .map(|arg| arg.clone().into_array(inferred_length)) .collect::>>()?; - regexp_replace::(&args) + + match args[0].data_type() { + DataType::Utf8View => { + let string_array = args[0].as_string_view(); + let pattern_array = args[1].as_string::(); + let replacement_array = args[2].as_string::(); + let regexp_replace_result = regexp_replace::( + string_array, + pattern_array, + replacement_array, + args.get(3), + )?; + + if regexp_replace_result.data_type() == &DataType::Utf8 { + let string_view_array = + as_string_array(®exp_replace_result)?.to_owned(); + + let mut builder = + StringViewBuilder::with_capacity(string_view_array.len()) + .with_block_size(1024 * 1024 * 2); + + for val in string_view_array.iter() { + if let Some(val) = val { + builder.append_value(val); + } else { + builder.append_null(); + } + } + + let result = builder.finish(); + Ok(Arc::new(result) as ArrayRef) + } else { + Ok(regexp_replace_result) + } + } + DataType::Utf8 => { + let string_array = args[0].as_string::(); + let pattern_array = args[1].as_string::(); + let replacement_array = args[2].as_string::(); + regexp_replace::( + string_array, + pattern_array, + replacement_array, + args.get(3), + ) + } + DataType::LargeUtf8 => { + let string_array = args[0].as_string::(); + let pattern_array = args[1].as_string::(); + let replacement_array = args[2].as_string::(); + regexp_replace::( + string_array, + pattern_array, + replacement_array, + args.get(3), + ) + } + other => { + exec_err!( + "Unsupported data type {other:?} for function regex_replace" + ) + } + } } } } diff --git a/datafusion/sqllogictest/test_files/string_view.slt b/datafusion/sqllogictest/test_files/string_view.slt index 83c75b8df38c..a50724b36468 100644 --- a/datafusion/sqllogictest/test_files/string_view.slt +++ b/datafusion/sqllogictest/test_files/string_view.slt @@ -460,6 +460,52 @@ Xiangpeng Raphael NULL +### Test REGEXP_REPLACE + +# Should run REGEXP_REPLACE with Scalar value for utf8view +query T +SELECT + REGEXP_REPLACE(column1_utf8view, 'e', 'f') AS k +FROM test; +---- +Andrfw +Xiangpfng +Raphafl +NULL + +# Should run REGEXP_REPLACE with Scalar value for utf8 +query T +SELECT + REGEXP_REPLACE(column1_utf8, 'e', 'f') AS k +FROM test; +---- +Andrfw +Xiangpfng +Raphafl +NULL + +# Should run REGEXP_REPLACE with ScalarArray value for utf8view +query T +SELECT + REGEXP_REPLACE(column1_utf8view, lower(column1_utf8view), 'bar') AS k +FROM test; +---- +Andrew +Xiangpeng +Raphael +NULL + +# Should run REGEXP_REPLACE with ScalarArray value for utf8 +query T +SELECT + REGEXP_REPLACE(column1_utf8, lower(column1_utf8), 'bar') AS k +FROM test; +---- +Andrew +Xiangpeng +Raphael +NULL + ### Initcap query TT