Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,6 @@ class AuronSparkTestSettings extends SparkTestSettings {
enableSuite[AuronMiscFunctionsSuite]

enableSuite[AuronStringFunctionsSuite]
// Native levenshtein has a Spark 3.5+ result or schema comparison mismatch.
.exclude("string Levenshtein distance")
// Native substr does not support BinaryType inputs.
// See https://github.com/apache/auron/issues/1724
.exclude("string / binary substring function")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,6 @@ class AuronSparkTestSettings extends SparkTestSettings {
enableSuite[AuronMiscFunctionsSuite]

enableSuite[AuronStringFunctionsSuite]
// Spark 4 adds the threshold argument, but native levenshtein currently supports only
// two arguments.
.exclude("string Levenshtein distance")
// Native substr does not support BinaryType inputs.
.exclude("string / binary substring function")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,6 @@ class AuronSparkTestSettings extends SparkTestSettings {
enableSuite[AuronMiscFunctionsSuite]

enableSuite[AuronStringFunctionsSuite]
// Spark 4 adds the threshold argument, but native levenshtein currently supports only
// two arguments.
.exclude("string Levenshtein distance")
// Native substr does not support BinaryType inputs.
.exclude("string / binary substring function")

Expand Down
2 changes: 1 addition & 1 deletion native-engine/auron-planner/proto/auron.proto
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ enum ScalarFunction {
Hex=66;
Power=67;
IsNaN=69;
Levenshtein=80;
// Levenshtein=80;
FindInSet=81;
Nvl=82;
Nvl2=83;
Expand Down
3 changes: 1 addition & 2 deletions native-engine/auron-planner/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1292,8 +1292,7 @@ impl From<protobuf::ScalarFunction> for Arc<ScalarUDF> {
ScalarFunction::Rpad => f::unicode::rpad(),
ScalarFunction::SplitPart => f::string::split_part(),
ScalarFunction::StartsWith => f::string::starts_with(),
ScalarFunction::Levenshtein => f::string::levenshtein(),

// ScalarFunction::Levenshtein => f::string::levenshtein(),
ScalarFunction::FindInSet => f::unicode::find_in_set(),
ScalarFunction::Strpos => f::unicode::strpos(),
ScalarFunction::Substr => f::unicode::substr(),
Expand Down
1 change: 1 addition & 0 deletions native-engine/datafusion-ext-functions/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ pub fn create_auron_ext_function(
"Spark_StringConcatWs" => Arc::new(spark_strings::string_concat_ws),
"Spark_StringLower" => Arc::new(spark_strings::string_lower),
"Spark_StringUpper" => Arc::new(spark_strings::string_upper),
"Spark_Levenshtein" => Arc::new(spark_strings::spark_levenshtein),
"Spark_InitCap" => Arc::new(spark_initcap::string_initcap),
"Spark_Year" => Arc::new(spark_dates::spark_year),
"Spark_Month" => Arc::new(spark_dates::spark_month),
Expand Down
242 changes: 237 additions & 5 deletions native-engine/datafusion-ext-functions/src/spark_strings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@
use std::sync::Arc;

use arrow::{
array::{Array, ArrayRef, AsArray, ListArray, ListBuilder, StringArray, StringBuilder},
array::{
Array, ArrayRef, AsArray, Int32Array, ListArray, ListBuilder, StringArray, StringBuilder,
},
datatypes::DataType,
};
use datafusion::{
Expand Down Expand Up @@ -114,6 +116,118 @@ pub fn string_split(args: &[ColumnarValue]) -> Result<ColumnarValue> {
Ok(ColumnarValue::Array(Arc::new(splitted_builder.finish())))
}

/// Levenshtein distance entry point taking two string arguments and an optional
/// third Int32 threshold argument.
///
/// Behavior visible in this function:
/// - Any argument count other than 2 or 3 produces an execution error.
/// - When every argument is a scalar, the result is a scalar `Int32`.
/// - Otherwise every argument is expanded to an array of the first array
///   argument's length and the result is an `Int32Array` of that length.
/// - A null threshold (null scalar, `DataType::Null` array, or an invalid
///   entry in the threshold array) yields a null result for that row.
/// - The distance itself, including null handling for the two strings, is
///   delegated to `compute_levenshtein`.
pub fn spark_levenshtein(args: &[ColumnarValue]) -> Result<ColumnarValue> {
Copy link
Copy Markdown
Contributor

@ShreyeshArangath ShreyeshArangath May 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct me if I'm wrong but, in spark_levenshtein, a null threshold is coerced to Some(0), which then returns -1 for any non-zero distance and 0 for equal strings:

  Some(ColumnarValue::Scalar(scalar)) if scalar.is_null() => Some(0),
  ...
  Some(array) if array.data_type() == &DataType::Null => Some(0),
  Some(_) => thresholds.map(|array| if array.is_valid(i) { array.value(i) } else { 0 }),

I think that doesn't match Spark. In Spark, Levenshtein.eval only null-checks left/right..a null threshold at runtime causes an NPE during v.asInstanceOf[Int]. The defensible options are (a) propagate null when threshold is null, or (b) mirror Spark and error

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for catching this! You're correct. I double-confirmed this and have fixed it by propagating null. Updated tests accordingly.

// Only the 2-argument and 3-argument (with threshold) forms are accepted.
if args.len() != 2 && args.len() != 3 {
df_execution_err!(
"levenshtein was called with {} arguments. It requires 2 or 3.",
args.len(),
)?;
}

// All-scalar path: evaluate once and return a scalar instead of an array.
if args
.iter()
.all(|arg| matches!(arg, ColumnarValue::Scalar(_)))
{
// Both string operands must be Utf8 scalars; any other scalar type is an error.
let left = match &args[0] {
ColumnarValue::Scalar(ScalarValue::Utf8(value)) => value.as_deref(),
_ => df_execution_err!("levenshtein only supports utf8 string arguments")?,
};
let right = match &args[1] {
ColumnarValue::Scalar(ScalarValue::Utf8(value)) => value.as_deref(),
_ => df_execution_err!("levenshtein only supports utf8 string arguments")?,
};
// Threshold: a non-null Int32 is used as-is, any null scalar short-circuits
// to a null Int32 result, and every other scalar type is rejected.
let threshold = match args.get(2) {
Some(ColumnarValue::Scalar(ScalarValue::Int32(Some(value)))) => Some(*value),
Some(ColumnarValue::Scalar(scalar)) if scalar.is_null() => {
return Ok(ColumnarValue::Scalar(ScalarValue::Int32(None)));
}
Some(_) => df_execution_err!("levenshtein threshold only supports int32")?,
None => None,
};
return Ok(ColumnarValue::Scalar(ScalarValue::Int32(
compute_levenshtein(left, right, threshold),
)));
}

// Array path. The all-scalar case returned above, so at least one argument
// is an array here and the `expect` below is not reachable with `None`.
let array_len = args
.iter()
.find_map(|arg| match arg {
ColumnarValue::Array(array) => Some(array.len()),
ColumnarValue::Scalar(_) => None,
})
.expect("levenshtein arguments include an array");
// Expand every argument (scalar or array) to an array of `array_len` rows.
let left_array = args[0].clone().into_array(array_len)?;
let right_array = args[1].clone().into_array(array_len)?;
let threshold_array = args
.get(2)
.map(|threshold| threshold.clone().into_array(array_len))
.transpose()?;

let left_strings = as_string_array(&left_array)?;
let right_strings = as_string_array(&right_array)?;
// A `DataType::Null` threshold array is filtered out before the Int32 cast;
// the per-row closure below handles that case separately.
let thresholds = threshold_array
.as_ref()
.filter(|array| array.data_type() != &DataType::Null)
.map(|array| as_int32_array(array))
.transpose()?;

let result = Int32Array::from_iter((0..array_len).map(|i| {
// Resolve the threshold for row `i`; a null threshold makes the row null.
let threshold = match &threshold_array {
Some(array) if array.data_type() == &DataType::Null => return None,
Some(_) => match thresholds {
Some(arr) if arr.is_valid(i) => Some(arr.value(i)),
_ => return None,
},
None => None,
};
// Invalid (null) string slots are passed as `None` to `compute_levenshtein`.
compute_levenshtein(
left_strings.is_valid(i).then(|| left_strings.value(i)),
right_strings.is_valid(i).then(|| right_strings.value(i)),
threshold,
)
}));
Ok(ColumnarValue::Array(Arc::new(result)))
}

/// Computes the character-level Levenshtein distance between two optional strings.
///
/// Returns `None` when either string is `None`. When `threshold` is `Some(t)`
/// and the distance is greater than `t`, the result is `Some(-1)`; otherwise
/// the result is `Some(distance)`.
fn compute_levenshtein(
    left: Option<&str>,
    right: Option<&str>,
    threshold: Option<i32>,
) -> Option<i32> {
    let source = left?;
    let target = right?;

    // Capping is identical for every way of obtaining the distance.
    let cap = |distance: i32| -> i32 {
        if matches!(threshold, Some(limit) if distance > limit) {
            -1
        } else {
            distance
        }
    };

    // Identical inputs need no work at all.
    if source == target {
        return Some(cap(0));
    }

    // Compare by Unicode scalar values rather than bytes.
    let source_chars: Vec<char> = source.chars().collect();
    let target_chars: Vec<char> = target.chars().collect();
    let width = target_chars.len();

    // One side empty: the distance is the length of the other side.
    if source_chars.is_empty() {
        return Some(cap(width as i32));
    }
    if target_chars.is_empty() {
        return Some(cap(source_chars.len() as i32));
    }

    // Two-row dynamic programme: `above` is the finished row, `row` is being filled.
    let mut above: Vec<usize> = (0..=width).collect();
    let mut row: Vec<usize> = vec![0; width + 1];
    for (row_index, source_char) in source_chars.iter().enumerate() {
        row[0] = row_index + 1;
        for (col, target_char) in target_chars.iter().enumerate() {
            let insertion = row[col] + 1;
            let deletion = above[col + 1] + 1;
            let substitution = above[col] + usize::from(source_char != target_char);
            row[col + 1] = insertion.min(deletion).min(substitution);
        }
        std::mem::swap(&mut above, &mut row);
    }

    // After the final swap the completed row lives in `above`.
    Some(cap(above[width] as i32))
}

/// concat() function compatible with spark (returns null if any param is null)
/// concat('abcde', 2, 22) = 'abcde222'
/// concat('abcde', 2, NULL, 22) = NULL
Expand Down Expand Up @@ -322,19 +436,19 @@ pub fn string_concat_ws(args: &[ColumnarValue]) -> Result<ColumnarValue> {
mod test {
use std::sync::Arc;

use arrow::array::{Int32Array, ListBuilder, StringArray, StringBuilder};
use arrow::array::{Int32Array, ListBuilder, NullArray, StringArray, StringBuilder};
use datafusion::{
common::{
Result, ScalarValue,
cast::{as_list_array, as_string_array},
cast::{as_int32_array, as_list_array, as_string_array},
},
physical_plan::ColumnarValue,
};
use datafusion_ext_commons::df_execution_err;

use crate::spark_strings::{
string_concat, string_concat_ws, string_lower, string_repeat, string_space, string_split,
string_upper,
spark_levenshtein, string_concat, string_concat_ws, string_lower, string_repeat,
string_space, string_split, string_upper,
};

#[test]
Expand Down Expand Up @@ -395,6 +509,124 @@ mod test {
}
}

// Two string columns, no threshold: covers ASCII pairs, a multi-byte pair and
// a null left operand (which must produce a null distance).
#[test]
fn test_spark_levenshtein_array() -> Result<()> {
    let lefts = StringArray::from_iter(vec![
        Some("kitten"),
        Some("frog"),
        Some("千世"),
        None,
    ]);
    let rights = StringArray::from_iter(vec![
        Some("sitting"),
        Some("fog"),
        Some("世界千世"),
        Some("abc"),
    ]);

    let output = spark_levenshtein(&[
        ColumnarValue::Array(Arc::new(lefts)),
        ColumnarValue::Array(Arc::new(rights)),
    ])?;
    let distances = output.into_array(4)?;
    let actual: Vec<Option<i32>> = as_int32_array(&distances)?.iter().collect();

    assert_eq!(actual, vec![Some(3), Some(1), Some(2), None]);
    Ok(())
}

// Three-argument form on array inputs: distances above the per-row threshold
// become -1, and a null threshold (null slot or NullArray) yields a null row.
#[test]
fn test_spark_levenshtein_threshold() -> Result<()> {
    // Builds a non-null Utf8 column from string literals.
    let utf8_column = |values: &[&str]| -> ColumnarValue {
        ColumnarValue::Array(Arc::new(StringArray::from_iter(
            values.iter().map(|value| Some(value.to_string())),
        )))
    };

    let lefts = utf8_column(&["kitten", "kitten", "abc", "abc", "", "abc"]);
    let rights = utf8_column(&["sitting", "sitting", "abc", "xyz", "abc", "abc"]);
    let limits = ColumnarValue::Array(Arc::new(Int32Array::from_iter(vec![
        Some(2),
        Some(3),
        Some(0),
        None,
        Some(3),
        Some(-1),
    ])));

    let output = spark_levenshtein(&[lefts, rights, limits])?;
    let distances = output.into_array(6)?;
    let actual: Vec<Option<i32>> = as_int32_array(&distances)?.iter().collect();
    assert_eq!(
        actual,
        vec![Some(-1), Some(3), Some(0), None, Some(3), Some(-1)]
    );

    // A threshold column of Null type must null out every row.
    let output = spark_levenshtein(&[
        utf8_column(&["abc"]),
        utf8_column(&["xyz"]),
        ColumnarValue::Array(Arc::new(NullArray::new(1))),
    ])?;
    let distances = output.into_array(1)?;
    let actual: Vec<Option<i32>> = as_int32_array(&distances)?.iter().collect();
    assert_eq!(actual, vec![None]);
    Ok(())
}

// All-scalar inputs must come back as a scalar Int32: over-threshold, exactly
// at threshold, null threshold, and null left operand.
#[test]
fn test_spark_levenshtein_scalar() -> Result<()> {
    // Wraps a string literal as a non-null Utf8 scalar argument.
    let utf8 = |text: &str| ColumnarValue::Scalar(ScalarValue::Utf8(Some(text.to_string())));

    // Distance 3 exceeds threshold 2 -> -1.
    let outcome = spark_levenshtein(&[
        utf8("kitten"),
        utf8("sitting"),
        ColumnarValue::Scalar(ScalarValue::Int32(Some(2))),
    ])?;
    if !matches!(
        outcome,
        ColumnarValue::Scalar(ScalarValue::Int32(Some(-1)))
    ) {
        df_execution_err!("Expected Int32(-1) scalar, got: {:?}", outcome)?;
    }

    // Distance 3 does not exceed threshold 3 -> 3.
    let outcome = spark_levenshtein(&[
        utf8("kitten"),
        utf8("sitting"),
        ColumnarValue::Scalar(ScalarValue::Int32(Some(3))),
    ])?;
    if !matches!(
        outcome,
        ColumnarValue::Scalar(ScalarValue::Int32(Some(3)))
    ) {
        df_execution_err!("Expected Int32(3) scalar, got: {:?}", outcome)?;
    }

    // Null threshold -> null result.
    let outcome = spark_levenshtein(&[
        utf8("kitten"),
        utf8("sitting"),
        ColumnarValue::Scalar(ScalarValue::Null),
    ])?;
    if !matches!(outcome, ColumnarValue::Scalar(ScalarValue::Int32(None))) {
        df_execution_err!("Expected null Int32 scalar, got: {:?}", outcome)?;
    }

    // Null left operand, no threshold -> null result.
    let outcome = spark_levenshtein(&[
        ColumnarValue::Scalar(ScalarValue::Utf8(None)),
        utf8("sitting"),
    ])?;
    if !matches!(outcome, ColumnarValue::Scalar(ScalarValue::Int32(None))) {
        df_execution_err!("Expected null Int32 scalar, got: {:?}", outcome)?;
    }
    Ok(())
}

#[test]
fn test_string_repeat() -> Result<()> {
// positive case
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -975,7 +975,7 @@ object NativeConverters extends Logging {
buildTimePartExt("Spark_Quarter", child, isPruningExpr, fallback)

case e: Levenshtein =>
buildScalarFunction(pb.ScalarFunction.Levenshtein, e.children, e.dataType)
buildExtScalarFunction("Spark_Levenshtein", e.children, e.dataType)

case e: Hour if datetimeExtractEnabled =>
buildTimePartExt("Spark_Hour", e.children.head, isPruningExpr, fallback)
Expand Down
Loading