From 01224e19bd79c7e397345aa8bfda956b3012ae88 Mon Sep 17 00:00:00 2001 From: alexandreyc Date: Wed, 2 Aug 2023 12:23:42 +0200 Subject: [PATCH 1/2] Refactor LIKE kernel --- arrow-string/src/like.rs | 376 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 376 insertions(+) diff --git a/arrow-string/src/like.rs b/arrow-string/src/like.rs index 9d3abea66fb..449994d702d 100644 --- a/arrow-string/src/like.rs +++ b/arrow-string/src/like.rs @@ -235,6 +235,103 @@ dict_function!("STARTSWITH(left, right)", starts_with_dict, starts_with); dict_function!("ENDSWITH(left, right)", ends_with_dict, ends_with); dict_function!("CONTAINS(left, right)", contains_dict, contains); +macro_rules! datum_function { + ($fn_name:ident, $fn_array:ident, $fn_scalar:ident) => { + pub fn $fn_name( + left: &dyn Datum, + right: &dyn Datum, + ) -> Result { + let (left_array, left_scalar) = left.get(); + let (right_array, right_scalar) = right.get(); + let left_type = left_array.data_type(); + let right_type = right_array.data_type(); + + // TODO(alexandreyc): check if PartialEq for DataType is deep or shallow + // i.e. does it check nested subtypes for equality? + if left_type != right_type { + return Err(ArrowError::ComputeError( + "Arrays must have the same data type".to_string(), + )); + } + + if left_scalar && !right_scalar { + return Err(ArrowError::ComputeError( + "Left cannot be scalar when right is not".to_string(), + )); + } + + match left_type { + DataType::Utf8 => { + let left_array = left_array.as_string::(); + let right_array = right_array.as_string::(); + if right_scalar { + $fn_scalar(left_array, right_array.value(0)) + } else { + $fn_array(left_array, right_array) + } + } + DataType::LargeUtf8 => { + let left_array = left_array.as_string::(); + let right_array = right_array.as_string::(); + if right_scalar { + $fn_scalar(left_array, right_array.value(0)) + } else { + $fn_array(left_array, right_array) + } + } + DataType::Dictionary(_, value_type) => match **value_type { + DataType::Utf8 => { + downcast_dictionary_array!( + left_array => { + let right_array = as_dictionary_array(right_array); + let right_array = right_array.downcast_dict::>().unwrap(); + let left_array = left_array.downcast_dict::>().unwrap(); + + if right_scalar { + $fn_scalar(left_array, right_array.value(0)) + } else { + $fn_array(left_array, right_array) + } + } + t => Err(ArrowError::ComputeError(format!( + "Should be DictionaryArray but got: {}", t + ))) + ) + } + DataType::LargeUtf8 => { + downcast_dictionary_array!( + left_array => { + let left_array = left_array.downcast_dict::>().unwrap(); + let right_array = as_dictionary_array(right_array); + let right_array = right_array.downcast_dict::>().unwrap(); + + if right_scalar { + $fn_scalar(left_array, right_array.value(0)) + } else { + $fn_array(left_array, right_array) + } + } + t => Err(ArrowError::ComputeError(format!( + "Should be DictionaryArray but got: {}", t + ))) + ) + } + _ => Err(ArrowError::ComputeError(format!( + "Unsupported dictionnary value type: {}", + value_type + ))), + }, + _ => Err(ArrowError::ComputeError(format!( + "Unsupported data type: {}", + left_type + ))), + } + } + }; +} + +datum_function!(like_datum, like, like_scalar); + /// Perform SQL `left LIKE right` operation on [`StringArray`] / [`LargeStringArray`]. /// /// There are two wildcards supported with the LIKE operator: @@ -810,6 +907,7 @@ mod tests { }; } + // OK test_utf8!( test_utf8_array_like, vec!["arrow", "arrow", "arrow", "arrow", "arrow", "arrows", "arrow", "arrow"], @@ -818,6 +916,7 @@ mod tests { vec![true, true, true, false, false, true, false, false] ); + // TODO test_dict_utf8!( test_utf8_array_like_dict, vec!["arrow", "arrow", "arrow", "arrow", "arrow", "arrows", "arrow", "arrow"], @@ -826,6 +925,7 @@ mod tests { vec![true, true, true, false, false, true, false, false] ); + // OK test_utf8_scalar!( test_utf8_array_like_scalar_escape_testing, test_utf8_array_like_scalar_dyn_escape_testing, @@ -836,6 +936,7 @@ mod tests { vec![true, true, false, false] ); + // OK test_utf8_scalar!( test_utf8_array_like_scalar_escape_regex, test_utf8_array_like_scalar_dyn_escape_regex, @@ -846,6 +947,7 @@ mod tests { vec![true, false, false] ); + // OK test_utf8_scalar!( test_utf8_array_like_scalar_escape_regex_dot, test_utf8_array_like_scalar_dyn_escape_regex_dot, @@ -856,6 +958,7 @@ mod tests { vec![true, false, false] ); + // OK test_utf8_scalar!( test_utf8_array_like_scalar, test_utf8_array_like_scalar_dyn, @@ -866,6 +969,7 @@ mod tests { vec![true, true, false, false] ); + // OK test_utf8_scalar!( test_utf8_array_like_scalar_start, test_utf8_array_like_scalar_dyn_start, @@ -888,6 +992,7 @@ mod tests { vec![true, false, true, false] ); + // OK test_utf8_scalar!( test_utf8_array_like_scalar_end, test_utf8_array_like_scalar_dyn_end, @@ -910,6 +1015,7 @@ mod tests { vec![true, true, false, false] ); + // OK test_utf8_scalar!( test_utf8_array_like_scalar_equals, test_utf8_array_like_scalar_dyn_equals, @@ -920,6 +1026,7 @@ mod tests { vec![true, false, false, false] ); + // OK test_utf8_scalar!( test_utf8_array_like_scalar_one, test_utf8_array_like_scalar_dyn_one, @@ -930,6 +1037,7 @@ mod tests { vec![false, true, false, false] ); + // OK test_utf8_scalar!( test_utf8_scalar_like_escape, test_utf8_scalar_like_dyn_escape, @@ -940,6 +1048,7 @@ mod tests { vec![true, false] ); + // OK test_utf8_scalar!( test_utf8_scalar_like_escape_contains, test_utf8_scalar_like_dyn_escape_contains, @@ -1942,3 +2051,270 @@ mod tests { ); } } + +#[cfg(test)] +mod tests_datum { + use super::*; + use arrow_array::types::Int8Type; + use std::sync::Arc; + + macro_rules! test_array_array { + ($test_name:ident, $op:ident, $left:expr, $right:expr, $expected:expr) => { + #[test] + fn $test_name() { + let expected = $expected; + + // StringArray + let left = StringArray::from($left); + let right = StringArray::from($right); + let res = $op(&left, &right).unwrap(); + assert_eq!(res.len(), expected.len()); + for i in 0..res.len() { + let v = res.value(i); + assert_eq!(v, expected[i]); + } + + // LargeStringArray + let left = LargeStringArray::from($left); + let right = LargeStringArray::from($right); + let res = $op(&left, &right).unwrap(); + assert_eq!(res.len(), expected.len()); + for i in 0..res.len() { + let v = res.value(i); + assert_eq!(v, expected[i]); + } + + // DictionnaryArray + let left: DictionaryArray = $left.into_iter().collect(); + let right: DictionaryArray = $right.into_iter().collect(); + let res = $op(&left, &right).unwrap(); + assert_eq!(res.len(), expected.len()); + for i in 0..res.len() { + let v = res.value(i); + assert_eq!(v, expected[i]); + } + } + }; + } + + macro_rules! test_array_scalar { + ($test_name:ident, $op:ident, $left:expr, $right:expr, $expected:expr) => { + #[test] + fn $test_name() { + let expected = $expected; + + // StringArray + let left = StringArray::from($left); + let right = StringArray::from(vec![$right]); + let right = Scalar::new(&right); + let res = $op(&left, &right).unwrap(); + assert_eq!(res.len(), expected.len()); + for i in 0..res.len() { + let v = res.value(i); + assert_eq!(v, expected[i]); + } + + // LargeStringArray + let left = LargeStringArray::from($left); + let right = LargeStringArray::from(vec![$right]); + let right = Scalar::new(&right); + let res = $op(&left, &right).unwrap(); + assert_eq!(res.len(), expected.len()); + for i in 0..res.len() { + let v = res.value(i); + assert_eq!(v, expected[i]); + } + + // DictionnaryArray + let left: DictionaryArray = $left.into_iter().collect(); + let right: DictionaryArray = [$right].into_iter().collect(); + let right = Scalar::new(&right); + let res = $op(&left, &right).unwrap(); + assert_eq!(res.len(), expected.len()); + for i in 0..res.len() { + let v = res.value(i); + assert_eq!(v, expected[i]); + } + } + }; + } + + macro_rules! test_scalar_scalar { + ($test_name:ident, $op:ident, $left:expr, $right:expr, $expected:expr) => { + #[test] + fn $test_name() { + let expected = $expected; + + // StringArray + let left = StringArray::from(vec![$left]); + let left = Scalar::new(&left); + let right = StringArray::from(vec![$right]); + let right = Scalar::new(&right); + let res = $op(&left, &right).unwrap(); + assert_eq!(res.len(), 1); + assert_eq!(res.value(0), expected); + + // LargeStringArray + let left = LargeStringArray::from(vec![$left]); + let right = LargeStringArray::from(vec![$right]); + let right = Scalar::new(&right); + let res = $op(&left, &right).unwrap(); + assert_eq!(res.len(), 1); + assert_eq!(res.value(0), expected); + + // DictionnaryArray + let left: DictionaryArray = [$left].into_iter().collect(); + let left = Scalar::new(&left); + let right: DictionaryArray = [$right].into_iter().collect(); + let right = Scalar::new(&right); + let res = $op(&left, &right).unwrap(); + assert_eq!(res.len(), 1); + assert_eq!(res.value(0), expected); + } + }; + } + + macro_rules! test_errors { + ($test_name:ident, $op:ident) => { + #[test] + fn $test_name() { + let left = StringArray::from(vec!["a"]); + let left = Scalar::new(&left); + let right = StringArray::from(vec!["a", "b", "c", "d"]); + let res = $op(&left, &right).unwrap_err(); + assert_eq!( + res.to_string(), + "Compute error: Left cannot be scalar when right is not" + ); + + let left = StringArray::from(vec!["a", "b", "c", "d"]); + let right = StringArray::from(vec!["a"]); + let res = $op(&left, &right).unwrap_err(); + assert_eq!( + res.to_string(), + "Compute error: Cannot perform comparison operation on arrays of different length" + ); + + let left = StringArray::from(vec!["a", "b", "c", "d"]); + let right = LargeStringArray::from(vec!["a", "b", "c", "d"]); + let res = $op(&left, &right).unwrap_err(); + assert_eq!( + res.to_string(), + "Compute error: Arrays must have the same data type" + ); + + let left = Int32Array::from(vec![1, 2, 3, 4]); + let right = Int32Array::from(vec![1, 2, 3, 4]); + let res = $op(&left, &right).unwrap_err(); + assert_eq!( + res.to_string(), + format!("Compute error: Unsupported data type: {}", left.data_type()) + ); + + let values = Arc::new(BinaryArray::from_iter_values(["a", "b", "c"])); + let keys = Int8Array::from_iter_values([0, 0, 1, 2]); + let left = DictionaryArray::::try_new(keys.clone(), values.clone()).unwrap(); + let right = DictionaryArray::::try_new(keys.clone(), values.clone()).unwrap(); + let res = $op(&left, &right).unwrap_err(); + assert_eq!( + res.to_string(), + format!("Compute error: Unsupported dictionnary value type: {}", left.value_type()) + ); + } + } + } + + // LIKE + + test_array_array!( + test_like_array_array, + like_datum, + vec!["arrow", "arrow", "arrow", "arrow", "arrow", "arrows", "arrow", "arrow"], + vec!["arrow", "ar%", "%ro%", "foo", "arr", "arrow_", "arrow_", ".*"], + vec![true, true, true, false, false, true, false, false] + ); + + test_array_scalar!( + test_like_array_scalar_start, + like_datum, + vec!["arrow", "parrow", "arrows", "arr"], + "arrow%", + vec![true, false, true, false] + ); + + test_array_scalar!( + test_like_array_scalar_end, + like_datum, + vec!["arrow", "parrow", "arrows", "arr"], + "%arrow", + vec![true, true, false, false] + ); + + test_array_scalar!( + test_like_array_scalar_start_end, + like_datum, + vec!["arrow", "parquet", "datafusion", "flight"], + "%ar%", + vec![true, true, false, false] + ); + + test_array_scalar!( + test_like_array_scalar_escape_testing, + like_datum, + vec!["varchar(255)", "int(255)", "varchar", "int"], + "%(%)%", + vec![true, true, false, false] + ); + + test_array_scalar!( + test_like_array_scalar_escape_regex, + like_datum, + vec![".*", "a", "*"], + ".*", + vec![true, false, false] + ); + + test_array_scalar!( + test_like_array_scalar_escape_regex_dot, + like_datum, + vec![".", "a", "*"], + ".", + vec![true, false, false] + ); + + test_array_scalar!( + test_like_array_scalar_equals, + like_datum, + vec!["arrow", "parrow", "arrows", "arr"], + "arrow", + vec![true, false, false, false] + ); + + test_array_scalar!( + test_like_array_scalar_one, + like_datum, + vec!["arrow", "arrows", "parrow", "arr"], + "arrow_", + vec![false, true, false, false] + ); + + test_array_scalar!( + test_like_array_scalar_escape, + like_datum, + vec!["a%", "a\\x"], + "a\\%", + vec![true, false] + ); + + test_array_scalar!( + test_like_array_scalar_escape_contains, + like_datum, + vec!["ba%", "ba\\x"], + "%a\\%", + vec![true, false] + ); + + test_scalar_scalar!(test_like_scalar_scalar, like_datum, "arrow", "%rr%", true); + + test_errors!(test_like_errors, like_datum); +} From 46fa925df77b8e68c20942ad0fce701f6736a8d1 Mon Sep 17 00:00:00 2001 From: alexandreyc Date: Wed, 2 Aug 2023 15:05:53 +0200 Subject: [PATCH 2/2] Remove old TODO --- arrow-string/src/like.rs | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/arrow-string/src/like.rs b/arrow-string/src/like.rs index 449994d702d..cdb689ea1cf 100644 --- a/arrow-string/src/like.rs +++ b/arrow-string/src/like.rs @@ -246,8 +246,6 @@ macro_rules! datum_function { let left_type = left_array.data_type(); let right_type = right_array.data_type(); - // TODO(alexandreyc): check if PartialEq for DataType is deep or shallow - // i.e. does it check nested subtypes for equality? if left_type != right_type { return Err(ArrowError::ComputeError( "Arrays must have the same data type".to_string(), @@ -916,7 +914,7 @@ mod tests { vec![true, true, true, false, false, true, false, false] ); - // TODO + // OK test_dict_utf8!( test_utf8_array_like_dict, vec!["arrow", "arrow", "arrow", "arrow", "arrow", "arrows", "arrow", "arrow"],