diff --git a/auron-spark-tests/spark35/src/test/scala/org/apache/auron/utils/AuronSparkTestSettings.scala b/auron-spark-tests/spark35/src/test/scala/org/apache/auron/utils/AuronSparkTestSettings.scala
index 049ba09e4..a1ddd7204 100644
--- a/auron-spark-tests/spark35/src/test/scala/org/apache/auron/utils/AuronSparkTestSettings.scala
+++ b/auron-spark-tests/spark35/src/test/scala/org/apache/auron/utils/AuronSparkTestSettings.scala
@@ -59,8 +59,6 @@ class AuronSparkTestSettings extends SparkTestSettings {
   enableSuite[AuronMiscFunctionsSuite]
 
   enableSuite[AuronStringFunctionsSuite]
-    // Native levenshtein has a Spark 3.5+ result or schema comparison mismatch.
-    .exclude("string Levenshtein distance")
     // Native substr does not support BinaryType inputs.
     // See https://github.com/apache/auron/issues/1724
     .exclude("string / binary substring function")
diff --git a/auron-spark-tests/spark40/src/test/scala/org/apache/auron/utils/AuronSparkTestSettings.scala b/auron-spark-tests/spark40/src/test/scala/org/apache/auron/utils/AuronSparkTestSettings.scala
index 2d96d470c..a24c6afe8 100644
--- a/auron-spark-tests/spark40/src/test/scala/org/apache/auron/utils/AuronSparkTestSettings.scala
+++ b/auron-spark-tests/spark40/src/test/scala/org/apache/auron/utils/AuronSparkTestSettings.scala
@@ -59,9 +59,6 @@ class AuronSparkTestSettings extends SparkTestSettings {
   enableSuite[AuronMiscFunctionsSuite]
 
   enableSuite[AuronStringFunctionsSuite]
-    // Spark 4 adds the threshold argument, but native levenshtein currently supports only
-    // two arguments.
-    .exclude("string Levenshtein distance")
     // Native substr does not support BinaryType inputs.
     .exclude("string / binary substring function")
 
diff --git a/auron-spark-tests/spark41/src/test/scala/org/apache/auron/utils/AuronSparkTestSettings.scala b/auron-spark-tests/spark41/src/test/scala/org/apache/auron/utils/AuronSparkTestSettings.scala
index 2d96d470c..a24c6afe8 100644
--- a/auron-spark-tests/spark41/src/test/scala/org/apache/auron/utils/AuronSparkTestSettings.scala
+++ b/auron-spark-tests/spark41/src/test/scala/org/apache/auron/utils/AuronSparkTestSettings.scala
@@ -59,9 +59,6 @@ class AuronSparkTestSettings extends SparkTestSettings {
   enableSuite[AuronMiscFunctionsSuite]
 
   enableSuite[AuronStringFunctionsSuite]
-    // Spark 4 adds the threshold argument, but native levenshtein currently supports only
-    // two arguments.
-    .exclude("string Levenshtein distance")
     // Native substr does not support BinaryType inputs.
     .exclude("string / binary substring function")
 
diff --git a/native-engine/auron-planner/proto/auron.proto b/native-engine/auron-planner/proto/auron.proto
index ab471fbb2..8540de5ea 100644
--- a/native-engine/auron-planner/proto/auron.proto
+++ b/native-engine/auron-planner/proto/auron.proto
@@ -281,7 +281,7 @@ enum ScalarFunction {
   Hex=66;
   Power=67;
   IsNaN=69;
-  Levenshtein=80;
+  reserved 80; // was Levenshtein, now served by the Spark_Levenshtein ext function
   FindInSet=81;
   Nvl=82;
   Nvl2=83;
diff --git a/native-engine/auron-planner/src/planner.rs b/native-engine/auron-planner/src/planner.rs
index e1e3df149..8c7da0fec 100644
--- a/native-engine/auron-planner/src/planner.rs
+++ b/native-engine/auron-planner/src/planner.rs
@@ -1292,8 +1292,6 @@ impl From for Arc {
             ScalarFunction::Rpad => f::unicode::rpad(),
             ScalarFunction::SplitPart => f::string::split_part(),
             ScalarFunction::StartsWith => f::string::starts_with(),
-            ScalarFunction::Levenshtein => f::string::levenshtein(),
-
             ScalarFunction::FindInSet => f::unicode::find_in_set(),
             ScalarFunction::Strpos => f::unicode::strpos(),
             ScalarFunction::Substr => f::unicode::substr(),
diff --git a/native-engine/datafusion-ext-functions/src/lib.rs b/native-engine/datafusion-ext-functions/src/lib.rs
index 7eb1f63a6..eb97b5ee5 100644
--- a/native-engine/datafusion-ext-functions/src/lib.rs
+++ b/native-engine/datafusion-ext-functions/src/lib.rs
@@ -75,6 +75,7 @@ pub fn create_auron_ext_function(
         "Spark_StringConcatWs" => Arc::new(spark_strings::string_concat_ws),
         "Spark_StringLower" => Arc::new(spark_strings::string_lower),
         "Spark_StringUpper" => Arc::new(spark_strings::string_upper),
+        "Spark_Levenshtein" => Arc::new(spark_strings::spark_levenshtein),
         "Spark_InitCap" => Arc::new(spark_initcap::string_initcap),
         "Spark_Year" => Arc::new(spark_dates::spark_year),
         "Spark_Month" => Arc::new(spark_dates::spark_month),
diff --git a/native-engine/datafusion-ext-functions/src/spark_strings.rs b/native-engine/datafusion-ext-functions/src/spark_strings.rs
index 43b1f136f..58897dcde 100644
--- a/native-engine/datafusion-ext-functions/src/spark_strings.rs
+++ b/native-engine/datafusion-ext-functions/src/spark_strings.rs
@@ -16,7 +16,9 @@
 use std::sync::Arc;
 
 use arrow::{
-    array::{Array, ArrayRef, AsArray, ListArray, ListBuilder, StringArray, StringBuilder},
+    array::{
+        Array, ArrayRef, AsArray, Int32Array, ListArray, ListBuilder, StringArray, StringBuilder,
+    },
     datatypes::DataType,
 };
 use datafusion::{
@@ -114,6 +116,129 @@ pub fn string_split(args: &[ColumnarValue]) -> Result<ColumnarValue> {
     Ok(ColumnarValue::Array(Arc::new(splitted_builder.finish())))
 }
 
+/// levenshtein() function compatible with spark (returns null if any param is null)
+/// levenshtein('kitten', 'sitting') = 3
+/// levenshtein('kitten', 'sitting', 2) = -1 (distance exceeds the optional threshold)
+pub fn spark_levenshtein(args: &[ColumnarValue]) -> Result<ColumnarValue> {
+    if args.len() != 2 && args.len() != 3 {
+        df_execution_err!(
+            "levenshtein was called with {} arguments. It requires 2 or 3.",
+            args.len(),
+        )?;
+    }
+
+    // all-scalar input: evaluate once and return a scalar instead of an array
+    if args
+        .iter()
+        .all(|arg| matches!(arg, ColumnarValue::Scalar(_)))
+    {
+        let left = match &args[0] {
+            ColumnarValue::Scalar(ScalarValue::Utf8(value)) => value.as_deref(),
+            _ => df_execution_err!("levenshtein only supports utf8 string arguments")?,
+        };
+        let right = match &args[1] {
+            ColumnarValue::Scalar(ScalarValue::Utf8(value)) => value.as_deref(),
+            _ => df_execution_err!("levenshtein only supports utf8 string arguments")?,
+        };
+        let threshold = match args.get(2) {
+            Some(ColumnarValue::Scalar(ScalarValue::Int32(Some(value)))) => Some(*value),
+            Some(ColumnarValue::Scalar(scalar)) if scalar.is_null() => {
+                return Ok(ColumnarValue::Scalar(ScalarValue::Int32(None)));
+            }
+            Some(_) => df_execution_err!("levenshtein threshold only supports int32")?,
+            None => None,
+        };
+        return Ok(ColumnarValue::Scalar(ScalarValue::Int32(
+            compute_levenshtein(left, right, threshold),
+        )));
+    }
+
+    // at least one argument is an array here, so the lookup below cannot fail
+    let array_len = args
+        .iter()
+        .find_map(|arg| match arg {
+            ColumnarValue::Array(array) => Some(array.len()),
+            ColumnarValue::Scalar(_) => None,
+        })
+        .expect("levenshtein arguments include an array");
+    let left_array = args[0].clone().into_array(array_len)?;
+    let right_array = args[1].clone().into_array(array_len)?;
+    let threshold_array = args
+        .get(2)
+        .map(|threshold| threshold.clone().into_array(array_len))
+        .transpose()?;
+
+    let left_strings = as_string_array(&left_array)?;
+    let right_strings = as_string_array(&right_array)?;
+    let thresholds = threshold_array
+        .as_ref()
+        .filter(|array| array.data_type() != &DataType::Null)
+        .map(|array| as_int32_array(array))
+        .transpose()?;
+
+    let result = Int32Array::from_iter((0..array_len).map(|i| {
+        // a threshold was passed but is null for this row (or for every row): result is null
+        let threshold = match &threshold_array {
+            Some(array) if array.data_type() == &DataType::Null => return None,
+            Some(_) => match thresholds {
+                Some(arr) if arr.is_valid(i) => Some(arr.value(i)),
+                _ => return None,
+            },
+            None => None,
+        };
+        compute_levenshtein(
+            left_strings.is_valid(i).then(|| left_strings.value(i)),
+            right_strings.is_valid(i).then(|| right_strings.value(i)),
+            threshold,
+        )
+    }));
+    Ok(ColumnarValue::Array(Arc::new(result)))
+}
+
+/// Edit distance between `left` and `right`, counted in chars rather than bytes.
+/// Returns `None` if either string is null, and `-1` if a threshold is given and the
+/// distance exceeds it.
+fn compute_levenshtein(
+    left: Option<&str>,
+    right: Option<&str>,
+    threshold: Option<i32>,
+) -> Option<i32> {
+    let left = left?;
+    let right = right?;
+    let distance = if left == right {
+        0
+    } else {
+        let left_chars = left.chars().collect::<Vec<_>>();
+        let right_chars = right.chars().collect::<Vec<_>>();
+        if left_chars.is_empty() {
+            right_chars.len() as i32
+        } else if right_chars.is_empty() {
+            left_chars.len() as i32
+        } else {
+            // two-row dynamic programming: `previous` is the row for the left chars consumed
+            // so far, `current` is the row being filled in
+            let mut previous = (0..=right_chars.len()).collect::<Vec<_>>();
+            let mut current = vec![0; right_chars.len() + 1];
+
+            for (i, left_char) in left_chars.iter().enumerate() {
+                current[0] = i + 1;
+                for (j, right_char) in right_chars.iter().enumerate() {
+                    let substitution_cost = usize::from(left_char != right_char);
+                    current[j + 1] = (current[j] + 1)
+                        .min(previous[j + 1] + 1)
+                        .min(previous[j] + substitution_cost);
+                }
+                std::mem::swap(&mut previous, &mut current);
+            }
+            previous[right_chars.len()] as i32
+        }
+    };
+    Some(match threshold {
+        Some(threshold) if distance > threshold => -1,
+        _ => distance,
+    })
+}
+
 /// concat() function compatible with spark (returns null if any param is null)
 /// concat('abcde', 2, 22) = 'abcde222
 /// concat('abcde', 2, NULL, 22) = NULL
@@ -322,19 +447,19 @@ pub fn string_concat_ws(args: &[ColumnarValue]) -> Result<ColumnarValue> {
 mod test {
     use std::sync::Arc;
 
-    use arrow::array::{Int32Array, ListBuilder, StringArray, StringBuilder};
+    use arrow::array::{Int32Array, ListBuilder, NullArray, StringArray, StringBuilder};
     use datafusion::{
         common::{
             Result, ScalarValue,
-            cast::{as_list_array, as_string_array},
+            cast::{as_int32_array, as_list_array, as_string_array},
         },
         physical_plan::ColumnarValue,
     };
     use datafusion_ext_commons::df_execution_err;
 
     use crate::spark_strings::{
-        string_concat, string_concat_ws, string_lower, string_repeat, string_space, string_split,
-        string_upper,
+        spark_levenshtein, string_concat, string_concat_ws, string_lower, string_repeat,
+        string_space, string_split, string_upper,
     };
 
     #[test]
@@ -395,6 +520,124 @@ mod test {
         }
     }
 
+    #[test]
+    fn test_spark_levenshtein_array() -> Result<()> {
+        let r = spark_levenshtein(&vec![
+            ColumnarValue::Array(Arc::new(StringArray::from_iter(vec![
+                Some("kitten".to_string()),
+                Some("frog".to_string()),
+                Some("千世".to_string()),
+                None,
+            ]))),
+            ColumnarValue::Array(Arc::new(StringArray::from_iter(vec![
+                Some("sitting".to_string()),
+                Some("fog".to_string()),
+                Some("世界千世".to_string()),
+                Some("abc".to_string()),
+            ]))),
+        ])?;
+        let s = r.into_array(4)?;
+        assert_eq!(
+            as_int32_array(&s)?.into_iter().collect::<Vec<_>>(),
+            vec![Some(3), Some(1), Some(2), None]
+        );
+        Ok(())
+    }
+
+    #[test]
+    fn test_spark_levenshtein_threshold() -> Result<()> {
+        let r = spark_levenshtein(&vec![
+            ColumnarValue::Array(Arc::new(StringArray::from_iter(vec![
+                Some("kitten".to_string()),
+                Some("kitten".to_string()),
+                Some("abc".to_string()),
+                Some("abc".to_string()),
+                Some("".to_string()),
+                Some("abc".to_string()),
+            ]))),
+            ColumnarValue::Array(Arc::new(StringArray::from_iter(vec![
+                Some("sitting".to_string()),
+                Some("sitting".to_string()),
+                Some("abc".to_string()),
+                Some("xyz".to_string()),
+                Some("abc".to_string()),
+                Some("abc".to_string()),
+            ]))),
+            ColumnarValue::Array(Arc::new(Int32Array::from_iter(vec![
+                Some(2),
+                Some(3),
+                Some(0),
+                None,
+                Some(3),
+                Some(-1),
+            ]))),
+        ])?;
+        let s = r.into_array(6)?;
+        assert_eq!(
+            as_int32_array(&s)?.into_iter().collect::<Vec<_>>(),
+            vec![Some(-1), Some(3), Some(0), None, Some(3), Some(-1)]
+        );
+
+        let r = spark_levenshtein(&vec![
+            ColumnarValue::Array(Arc::new(StringArray::from_iter(vec![Some(
+                "abc".to_string(),
+            )]))),
+            ColumnarValue::Array(Arc::new(StringArray::from_iter(vec![Some(
+                "xyz".to_string(),
+            )]))),
+            ColumnarValue::Array(Arc::new(NullArray::new(1))),
+        ])?;
+        let s = r.into_array(1)?;
+        assert_eq!(
+            as_int32_array(&s)?.into_iter().collect::<Vec<_>>(),
+            vec![None]
+        );
+        Ok(())
+    }
+
+    #[test]
+    fn test_spark_levenshtein_scalar() -> Result<()> {
+        let r = spark_levenshtein(&vec![
+            ColumnarValue::Scalar(ScalarValue::Utf8(Some("kitten".to_string()))),
+            ColumnarValue::Scalar(ScalarValue::Utf8(Some("sitting".to_string()))),
+            ColumnarValue::Scalar(ScalarValue::Int32(Some(2))),
+        ])?;
+        match r {
+            ColumnarValue::Scalar(ScalarValue::Int32(Some(-1))) => {}
+            other => df_execution_err!("Expected Int32(-1) scalar, got: {:?}", other)?,
+        }
+
+        let r = spark_levenshtein(&vec![
+            ColumnarValue::Scalar(ScalarValue::Utf8(Some("kitten".to_string()))),
+            ColumnarValue::Scalar(ScalarValue::Utf8(Some("sitting".to_string()))),
+            ColumnarValue::Scalar(ScalarValue::Int32(Some(3))),
+        ])?;
+        match r {
+            ColumnarValue::Scalar(ScalarValue::Int32(Some(3))) => {}
+            other => df_execution_err!("Expected Int32(3) scalar, got: {:?}", other)?,
+        }
+
+        let r = spark_levenshtein(&vec![
+            ColumnarValue::Scalar(ScalarValue::Utf8(Some("kitten".to_string()))),
+            ColumnarValue::Scalar(ScalarValue::Utf8(Some("sitting".to_string()))),
+            ColumnarValue::Scalar(ScalarValue::Null),
+        ])?;
+        match r {
+            ColumnarValue::Scalar(ScalarValue::Int32(None)) => {}
+            other => df_execution_err!("Expected null Int32 scalar, got: {:?}", other)?,
+        }
+
+        let r = spark_levenshtein(&vec![
+            ColumnarValue::Scalar(ScalarValue::Utf8(None)),
+            ColumnarValue::Scalar(ScalarValue::Utf8(Some("sitting".to_string()))),
+        ])?;
+        match r {
+            ColumnarValue::Scalar(ScalarValue::Int32(None)) => {}
+            other => df_execution_err!("Expected null Int32 scalar, got: {:?}", other)?,
+        }
+        Ok(())
+    }
+
     #[test]
     fn test_string_repeat() -> Result<()> {
         // positive case
diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala b/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala
index dbe1781ee..5ddeb3a2c 100644
--- a/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala
+++ b/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala
@@ -975,7 +975,7 @@ object NativeConverters extends Logging {
         buildTimePartExt("Spark_Quarter", child, isPruningExpr, fallback)
 
       case e: Levenshtein =>
-        buildScalarFunction(pb.ScalarFunction.Levenshtein, e.children, e.dataType)
+        buildExtScalarFunction("Spark_Levenshtein", e.children, e.dataType)
 
       case e: Hour if datetimeExtractEnabled =>
         buildTimePartExt("Spark_Hour", e.children.head, isPruningExpr, fallback)