diff --git a/python/pyspark/sql/connect/functions.py b/python/pyspark/sql/connect/functions.py
index b7d7bc937cf8e..d3a05d6a1c608 100644
--- a/python/pyspark/sql/connect/functions.py
+++ b/python/pyspark/sql/connect/functions.py
@@ -1878,8 +1878,13 @@ def substring_index(str: "ColumnOrName", delim: str, count: int) -> Column:
 substring_index.__doc__ = pysparkfuncs.substring_index.__doc__
 
 
-def levenshtein(left: "ColumnOrName", right: "ColumnOrName") -> Column:
-    return _invoke_function_over_columns("levenshtein", left, right)
+def levenshtein(
+    left: "ColumnOrName", right: "ColumnOrName", threshold: Optional[int] = None
+) -> Column:
+    if threshold is None:
+        return _invoke_function_over_columns("levenshtein", left, right)
+    else:
+        return _invoke_function("levenshtein", _to_col(left), _to_col(right), lit(threshold))
 
 
 levenshtein.__doc__ = pysparkfuncs.levenshtein.__doc__
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index e9b71f7d617db..fe35f12c40215 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -6594,7 +6594,9 @@ def substring_index(str: "ColumnOrName", delim: str, count: int) -> Column:
 
 
 @try_remote_functions
-def levenshtein(left: "ColumnOrName", right: "ColumnOrName") -> Column:
+def levenshtein(
+    left: "ColumnOrName", right: "ColumnOrName", threshold: Optional[int] = None
+) -> Column:
     """Computes the Levenshtein distance of the two given strings.
 
     .. versionadded:: 1.5.0
@@ -6608,6 +6610,12 @@ def levenshtein(left: "ColumnOrName", right: "ColumnOrName") -> Column:
         first column value.
     right : :class:`~pyspark.sql.Column` or str
         second column value.
+    threshold : int, optional
+        if set, when the Levenshtein distance of the two given strings is
+        less than or equal to it, return the distance; otherwise return -1.
+
+        .. versionchanged:: 3.5.0
+            Added ``threshold`` argument.
 
     Returns
     -------
@@ -6619,8 +6627,15 @@ def levenshtein(left: "ColumnOrName", right: "ColumnOrName") -> Column:
     >>> df0 = spark.createDataFrame([('kitten', 'sitting',)], ['l', 'r'])
     >>> df0.select(levenshtein('l', 'r').alias('d')).collect()
     [Row(d=3)]
+    >>> df0.select(levenshtein('l', 'r', 2).alias('d')).collect()
+    [Row(d=-1)]
     """
-    return _invoke_function_over_columns("levenshtein", left, right)
+    if threshold is None:
+        return _invoke_function_over_columns("levenshtein", left, right)
+    else:
+        return _invoke_function(
+            "levenshtein", _to_java_column(left), _to_java_column(right), threshold
+        )
 
 
 @try_remote_functions
diff --git a/python/pyspark/sql/tests/connect/test_connect_function.py b/python/pyspark/sql/tests/connect/test_connect_function.py
index e274635d3c627..3e3b4dd5b1654 100644
--- a/python/pyspark/sql/tests/connect/test_connect_function.py
+++ b/python/pyspark/sql/tests/connect/test_connect_function.py
@@ -1924,6 +1924,11 @@ def test_string_functions_multi_args(self):
             cdf.select(CF.levenshtein(cdf.b, cdf.c)).toPandas(),
             sdf.select(SF.levenshtein(sdf.b, sdf.c)).toPandas(),
         )
+        self.assert_eq(
+            cdf.select(CF.levenshtein(cdf.b, cdf.c, 1)).toPandas(),
+            sdf.select(SF.levenshtein(sdf.b, sdf.c, 1)).toPandas(),
+        )
+
         self.assert_eq(
             cdf.select(CF.locate("e", cdf.b)).toPandas(),
             sdf.select(SF.locate("e", sdf.b)).toPandas(),
diff --git a/python/pyspark/sql/tests/test_functions.py b/python/pyspark/sql/tests/test_functions.py
index 9067de3463357..72c6c365b804b 100644
--- a/python/pyspark/sql/tests/test_functions.py
+++ b/python/pyspark/sql/tests/test_functions.py
@@ -377,6 +377,13 @@ def test_array_contains_function(self):
         actual = df.select(F.array_contains(df.data, "1").alias("b")).collect()
         self.assertEqual([Row(b=True), Row(b=False)], actual)
 
+    def test_levenshtein_function(self):
+        df = self.spark.createDataFrame([("kitten", "sitting")], ["l", "r"])
+        actual_without_threshold = df.select(F.levenshtein(df.l, df.r).alias("b")).collect()
+        self.assertEqual([Row(b=3)], actual_without_threshold)
+        actual_with_threshold = df.select(F.levenshtein(df.l, df.r, 2).alias("b")).collect()
+        self.assertEqual([Row(b=-1)], actual_with_threshold)
+
     def test_between_function(self):
         df = self.spark.createDataFrame(
             [Row(a=1, b=2, c=3), Row(a=2, b=1, c=3), Row(a=4, b=1, c=4)]
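
For reviewers, here is a minimal end-to-end sketch of the behavior this patch adds, assuming a local Spark build with the change applied; the session setup and column names below are illustrative, not part of the diff. `threshold` caps the computed distance: a distance within the threshold is returned as-is, anything beyond it comes back as -1.

```python
from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.createDataFrame([("kitten", "sitting")], ["l", "r"])

# No threshold: the full edit distance is returned ("kitten" -> "sitting" is 3).
df.select(F.levenshtein(df.l, df.r).alias("d")).show()     # d = 3

# Distance 3 exceeds threshold 2: the function returns -1.
df.select(F.levenshtein(df.l, df.r, 2).alias("d")).show()  # d = -1

# Distance 3 is within threshold 4: the distance itself is returned.
df.select(F.levenshtein(df.l, df.r, 4).alias("d")).show()  # d = 3

spark.stop()
```

A threshold is the usual optimization for fuzzy matching, since a bounded Levenshtein computation can stop early once the bound is exceeded, and the -1 sentinel keeps the result an integer column in both forms.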