Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-8270][SQL] levenshtein distance #7214

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions python/pyspark/sql/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,20 @@ def explode(col):
return Column(jc)


@ignore_unicode_prefix
@since(1.5)
def levenshtein(left, right):
    """Computes the Levenshtein distance of the two given strings.

    Both arguments are converted with ``_to_java_column`` and handed to the
    JVM-side ``functions.levenshtein``; the result is wrapped in a ``Column``.

    >>> df0 = sqlContext.createDataFrame([('kitten', 'sitting',)], ['l', 'r'])
    >>> df0.select(levenshtein('l', 'r').alias('d')).collect()
    [Row(d=3)]
    """
    # Resolve the JVM function first, then convert the operands left-to-right.
    jvm_levenshtein = SparkContext._active_spark_context._jvm.functions.levenshtein
    return Column(jvm_levenshtein(_to_java_column(left), _to_java_column(right)))


@ignore_unicode_prefix
@since(1.5)
def md5(col):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ object FunctionRegistry {
expression[Lower]("lcase"),
expression[Lower]("lower"),
expression[StringLength]("length"),
expression[Levenshtein]("levenshtein"),
expression[Substring]("substr"),
expression[Substring]("substring"),
expression[Upper]("ucase"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions

import java.util.regex.Pattern

import org.apache.commons.lang3.StringUtils
import org.apache.spark.sql.catalyst.analysis.UnresolvedException
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.types._
Expand Down Expand Up @@ -298,3 +299,34 @@ case class StringLength(child: Expression) extends UnaryExpression with ExpectsI

override def prettyName: String = "length"
}

/**
* A function that returns the Levenshtein distance between the two given strings.
*/
case class Levenshtein(left: Expression, right: Expression) extends BinaryExpression
with ExpectsInputTypes {

// Both operands (left, then right) are declared as StringType.
override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType)

/**
 * Interpreted evaluation. Yields null as soon as either operand evaluates to null
 * (the right operand is not evaluated when the left one is null); otherwise returns
 * the Levenshtein distance between the string forms of the two values.
 */
override def eval(input: InternalRow): Any =
  left.eval(input) match {
    case null => null
    case lhs =>
      right.eval(input) match {
        case null => null
        case rhs => StringUtils.getLevenshteinDistance(lhs.toString, rhs.toString)
      }
  }

/**
 * Code-generated evaluation: emits a Java statement that assigns the result of
 * StringUtils.getLevenshteinDistance over the string forms of both operands.
 * The statement is wrapped by nullSafeCodeGen (defined elsewhere; presumably it
 * supplies the null handling -- confirm against its definition).
 */
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
  // Fully-qualified class name so the generated Java needs no import.
  val utilsClass = classOf[StringUtils].getName
  nullSafeCodeGen(ctx, ev, (result, lhs, rhs) =>
    s"$result = $utilsClass.getLevenshteinDistance($lhs.toString(), $rhs.toString());")
}

// The computed distance is exposed as an IntegerType value.
override def dataType: DataType = IntegerType
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you move this right after inputTypes?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sure

}
Original file line number Diff line number Diff line change
Expand Up @@ -224,4 +224,13 @@ class StringFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(StringLength(regEx), null, create_row(null))
checkEvaluation(StringLength(Literal.create(null, StringType)), null, create_row("abdef"))
}

test("Levenshtein distance") {
  // A null operand on either side must produce a null result.
  checkEvaluation(Levenshtein(Literal.create(null, StringType), Literal("")), null)
  checkEvaluation(Levenshtein(Literal(""), Literal.create(null, StringType)), null)

  // Non-null cases: (left, right, expected distance), checked in order.
  val cases = Seq(
    ("", "", 0),
    ("abc", "abc", 0),
    ("kitten", "sitting", 3),
    ("frog", "fog", 1))
  cases.foreach { case (l, r, expected) =>
    checkEvaluation(Levenshtein(Literal(l), Literal(r)), expected)
  }
}
}
19 changes: 17 additions & 2 deletions sql/core/src/main/scala/org/apache/spark/sql/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1542,19 +1542,34 @@ object functions {
//////////////////////////////////////////////////////////////////////////////////////////////

/**
 * Computes the length of a given string value.
 *
 * @param e the column whose expression is wrapped in a StringLength expression
 * @return a column built from the StringLength expression
 * @group string_funcs
 * @since 1.5.0
 */
def strlen(e: Column): Column = StringLength(e.expr)

/**
 * Computes the length of a given string column.
 *
 * @param columnName name of the column; it is turned into a Column and passed to
 *                   the Column-based overload
 * @return the result of the Column-based strlen
 * @group string_funcs
 * @since 1.5.0
 */
def strlen(columnName: String): Column = strlen(Column(columnName))

/**
 * Computes the Levenshtein distance of the two given strings.
 *
 * @param l the left operand column
 * @param r the right operand column
 * @return a column built from a Levenshtein expression over both operands' expressions
 * @group string_funcs
 * @since 1.5.0
 */
def levenshtein(l: Column, r: Column): Column = Levenshtein(l.expr, r.expr)

/**
 * Computes the Levenshtein distance of the two given string columns, identified by name.
 *
 * @param leftColumnName name of the left operand column
 * @param rightColumnName name of the right operand column
 * @return the result of the Column-based levenshtein applied to the two named columns
 * @group string_funcs
 * @since 1.5.0
 */
def levenshtein(leftColumnName: String, rightColumnName: String): Column =
  levenshtein(Column(leftColumnName), Column(rightColumnName))

//////////////////////////////////////////////////////////////////////////////////////////////
//////////////////////////////////////////////////////////////////////////////////////////////

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -225,4 +225,10 @@ class DataFrameFunctionsSuite extends QueryTest {
Row(l)
})
}

test("Levenshtein distance") {
  val wordPairs = Seq(("kitten", "sitting"), ("frog", "fog"))
  val df = wordPairs.toDF("l", "r")
  // One expected distance per input row, in row order.
  val expected = Seq(Row(3), Row(1))

  // Through the DataFrame function taking column names.
  checkAnswer(df.select(levenshtein("l", "r")), expected)
  // Through the SQL expression parser / function registry.
  checkAnswer(df.selectExpr("levenshtein(l, r)"), expected)
}
}