Skip to content
9 changes: 7 additions & 2 deletions python/pyspark/sql/connect/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -1878,8 +1878,13 @@ def substring_index(str: "ColumnOrName", delim: str, count: int) -> Column:
substring_index.__doc__ = pysparkfuncs.substring_index.__doc__


def levenshtein(left: "ColumnOrName", right: "ColumnOrName") -> Column:
return _invoke_function_over_columns("levenshtein", left, right)
def levenshtein(
left: "ColumnOrName", right: "ColumnOrName", threshold: Optional[int] = None
) -> Column:
if threshold is None:
return _invoke_function_over_columns("levenshtein", left, right)
else:
return _invoke_function("levenshtein", _to_col(left), _to_col(right), lit(threshold))


levenshtein.__doc__ = pysparkfuncs.levenshtein.__doc__
Expand Down
19 changes: 17 additions & 2 deletions python/pyspark/sql/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6594,7 +6594,9 @@ def substring_index(str: "ColumnOrName", delim: str, count: int) -> Column:


@try_remote_functions
def levenshtein(left: "ColumnOrName", right: "ColumnOrName") -> Column:
def levenshtein(
left: "ColumnOrName", right: "ColumnOrName", threshold: Optional[int] = None
) -> Column:
"""Computes the Levenshtein distance of the two given strings.

.. versionadded:: 1.5.0
Expand All @@ -6608,6 +6610,12 @@ def levenshtein(left: "ColumnOrName", right: "ColumnOrName") -> Column:
first column value.
right : :class:`~pyspark.sql.Column` or str
second column value.
threshold : int, optional
if set when the levenshtein distance of the two given strings
less than or equal to a given threshold then return result distance, or -1

.. versionchanged: 3.5.0
Added ``threshold`` argument.

Returns
-------
Expand All @@ -6619,8 +6627,15 @@ def levenshtein(left: "ColumnOrName", right: "ColumnOrName") -> Column:
>>> df0 = spark.createDataFrame([('kitten', 'sitting',)], ['l', 'r'])
>>> df0.select(levenshtein('l', 'r').alias('d')).collect()
[Row(d=3)]
>>> df0.select(levenshtein('l', 'r', 2).alias('d')).collect()
[Row(d=-1)]
"""
return _invoke_function_over_columns("levenshtein", left, right)
if threshold is None:
return _invoke_function_over_columns("levenshtein", left, right)
else:
return _invoke_function(
"levenshtein", _to_java_column(left), _to_java_column(right), threshold
)


@try_remote_functions
Expand Down
5 changes: 5 additions & 0 deletions python/pyspark/sql/tests/connect/test_connect_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -1924,6 +1924,11 @@ def test_string_functions_multi_args(self):
cdf.select(CF.levenshtein(cdf.b, cdf.c)).toPandas(),
sdf.select(SF.levenshtein(sdf.b, sdf.c)).toPandas(),
)
self.assert_eq(
cdf.select(CF.levenshtein(cdf.b, cdf.c, 1)).toPandas(),
sdf.select(SF.levenshtein(sdf.b, sdf.c, 1)).toPandas(),
)

self.assert_eq(
cdf.select(CF.locate("e", cdf.b)).toPandas(),
sdf.select(SF.locate("e", sdf.b)).toPandas(),
Expand Down
7 changes: 7 additions & 0 deletions python/pyspark/sql/tests/test_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,13 @@ def test_array_contains_function(self):
actual = df.select(F.array_contains(df.data, "1").alias("b")).collect()
self.assertEqual([Row(b=True), Row(b=False)], actual)

def test_levenshtein_function(self):
df = self.spark.createDataFrame([("kitten", "sitting")], ["l", "r"])
actual_without_threshold = df.select(F.levenshtein(df.l, df.r).alias("b")).collect()
self.assertEqual([Row(b=3)], actual_without_threshold)
actual_with_threshold = df.select(F.levenshtein(df.l, df.r, 2).alias("b")).collect()
self.assertEqual([Row(b=-1)], actual_with_threshold)

def test_between_function(self):
df = self.spark.createDataFrame(
[Row(a=1, b=2, c=3), Row(a=2, b=1, c=3), Row(a=4, b=1, c=4)]
Expand Down