Skip to content

Commit

Permalink
Native hamming distance (#4)
Browse files Browse the repository at this point in the history
* Add native hamming distance
  • Loading branch information
nvander1 authored and MrPowers committed May 31, 2019
1 parent 0d8e4ba commit d967ac9
Show file tree
Hide file tree
Showing 5 changed files with 185 additions and 9 deletions.
Original file line number Diff line number Diff line change
@@ -1,12 +1,22 @@
package com.github.mrpowers.spark.stringmetric

import com.github.mrpowers.spark.stringmetric.expressions.HammingDistance
import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.functions._

import java.util.Locale

import org.apache.commons.text.similarity._
import org.apache.commons.text.similarity.{
CosineDistance,
JaccardSimilarity,
JaroWinklerDistance,
FuzzyScore
}


object SimilarityFunctions {
private def withExpr(expr: Expression): Column = new Column(expr)

val cosine_distance = udf[Option[Double], String, String](cosineDistanceFun)

Expand All @@ -26,13 +36,8 @@ object SimilarityFunctions {
Some(f.fuzzyScore(str1, str2))
}

val hamming = udf[Option[Int], String, String](hammingFun)

def hammingFun(s1: String, s2: String): Option[Int] = {
val str1 = Option(s1).getOrElse(return None)
val str2 = Option(s2).getOrElse(return None)
val h = new HammingDistance()
Some(h.apply(str1, str2))
def hamming(s1: Column, s2: Column): Column = withExpr {
HammingDistance(s1.expr, s2.expr)
}

val jaccard_similarity = udf[Option[Double], String, String](jaccardSimilarityFun)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package com.github.mrpowers.spark.stringmetric.expressions

import com.github.mrpowers.spark.stringmetric.unsafe.UTF8StringFunctions
import org.apache.commons.text.similarity.CosineDistance
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{
CodegenContext,
ExprCode
}
import org.apache.spark.sql.types.{ DataType, IntegerType, StringType }

/*
* Alleviates painfully long codegen strings.
*
* TODO: See if there is a less hacky way to inject the imports
* into generated code.
*/
trait UTF8StringFunctionsHelper {
val stringFuncs: String = "com.github.mrpowers.spark.stringmetric.unsafe.UTF8StringFunctions"
}

trait StringString2IntegerExpression
extends ImplicitCastInputTypes
with NullIntolerant
with UTF8StringFunctionsHelper { self: BinaryExpression =>
override def dataType: DataType = IntegerType
override def inputTypes: Seq[DataType] = Seq(StringType, StringType)

protected override def nullSafeEval(left: Any, right: Any): Any = -1
}

case class HammingDistance(left: Expression, right: Expression)
extends BinaryExpression with StringString2IntegerExpression {
override def prettyName: String = "hamming"

override def nullSafeEval(leftVal: Any, righValt: Any): Any = {
val leftStr = left.asInstanceOf[UTF8String]
val rightStr = right.asInstanceOf[UTF8String]
UTF8StringFunctions.hammingDistance(leftStr, rightStr)
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
defineCodeGen(ctx, ev, (s1, s2) => s"$stringFuncs.hammingDistance($s1, $s2)")
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
package com.github.mrpowers.spark.stringmetric.unsafe;

import org.apache.spark.unsafe.types.UTF8String;

public class UTF8StringFunctions {
// Wish these methods weren't private in UTF8String:
// - bytesOfCodePointInUTF8
// - numBytesForFirstByte
// https://github.com/apache/spark/blob/master/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
private static byte[] bytesOfCodePointInUTF8 = {
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 0x00..0x0F
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 0x10..0x1F
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 0x20..0x2F
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 0x30..0x3F
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 0x40..0x4F
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 0x50..0x5F
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 0x60..0x6F
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 0x70..0x7F
// Continuation bytes cannot appear as the first byte
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 0x80..0x8F
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 0x90..0x9F
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 0xA0..0xAF
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 0xB0..0xBF
0, 0, // 0xC0..0xC1 - disallowed in UTF-8
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, // 0xC2..0xCF
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, // 0xD0..0xDF
3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, // 0xE0..0xEF
4, 4, 4, 4, 4, // 0xF0..0xF4
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 // 0xF5..0xFF - disallowed in UTF-8
};

/**
* Returns the number of bytes for a code point with the first byte as `b`
* @param b The first byte of a code point
*/
private static int numBytesForFirstByte(final byte b) {
final int offset = b & 0xFF;
byte numBytes = bytesOfCodePointInUTF8[offset];
return (numBytes == 0) ? 1: numBytes; // Skip the first byte disallowed in UTF-8
}

// Adapted from org.apache.commons.text.similarity.HammingDistance
public static int hammingDistance(UTF8String left, UTF8String right) {
int n = left.numChars();
if (n != right.numChars()) {
throw new java.lang.IllegalArgumentException(
"Hamming distance is only defined for strings of same length!"
);
}

int distance = 0;
byte[] leftBytes = left.getBytes();
byte[] rightBytes = right.getBytes();

int leftIdx = 0;
int rightIdx = 0;
int c = 0;
while (c < n) {
int leftCharSize = numBytesForFirstByte(leftBytes[leftIdx]);
int rightCharSize = numBytesForFirstByte(rightBytes[rightIdx]);

if (leftCharSize != rightCharSize) {
distance++;
} else {
for (int i = 0; i < leftCharSize; i++) {
if (leftBytes[i + leftIdx] != rightBytes[i + rightIdx]) {
distance++;
break;
}
}
}

leftIdx += leftCharSize;
rightIdx += rightCharSize;
c++; // would rather be writing this than Java!

This comment has been minimized.

Copy link
@MrPowers

MrPowers May 31, 2019

Owner

@nvander1 - LOL!!!!!!

}
return distance;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ class SimilarityFunctionsSpec

describe("hamming") {

it("computes the hamming metric") {
it("computes the hamming metric for single byte chars") {

val sourceDF = spark.createDF(
List(
Expand Down Expand Up @@ -145,6 +145,50 @@ class SimilarityFunctionsSpec

}

it("computes the hamming metric for multi byte chars") {

val sourceDF = spark.createDF(
List(
("aab", "aaa"),
("aa¢", "aaa"), // ¢ is 2 bytes
("aaह", "aaa"), // ह is 3 bytes
("aa€", "aaa"), // € is 3 bytes
("aa𐍈", "aaa"), // 𐍈 is 4 bytes
("𐍈€¢", "aaa"),
("𐍈€a¢€b", "b€𐍈𐍈ab"),
("𐍈€a¢€b", "b€𐍈𐍈€b")
), List(
("word1", StringType, true),
("word2", StringType, true)
)
)

val actualDF = sourceDF.withColumn(
"w1_w2_hamming",
SimilarityFunctions.hamming(col("word1"), col("word2"))
)

val expectedDF = spark.createDF(
List(
("aab", "aaa", 1),
("aa¢", "aaa", 1), // ¢ is 2 bytes
("aaह", "aaa", 1), // ह is 3 bytes
("aa€", "aaa", 1), // € is 3 bytes
("aa𐍈", "aaa", 1), // 𐍈 is 4 bytes
("𐍈€¢", "aaa", 3),
("𐍈€a¢€b", "b€𐍈𐍈ab", 4),
("𐍈€a¢€b", "b€𐍈𐍈€b", 3)
), List(
("word1", StringType, true),
("word2", StringType, true),
("w1_w2_hamming", IntegerType, true)
)
)

assertSmallDataFrameEquality(actualDF, expectedDF)

}

}

describe("jaccard_similarity") {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
package com.github.mrpowers.spark.stringmetric

import org.apache.spark.sql.SparkSession
import org.apache.log4j.{Logger, Level}

trait SparkSessionTestWrapper {

lazy val spark: SparkSession = {
Logger.getLogger("org").setLevel(Level.OFF)
SparkSession
.builder()
.master("local")
Expand Down

0 comments on commit d967ac9

Please sign in to comment.