In [1]:
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.types import ArrayType, FloatType, IntegerType
from pyspark.ml.linalg import Vectors, VectorUDT
from pyspark.sql.functions import udf, col

spark = SparkSession.builder.master("local").appName("minLSH").getOrCreate()

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).


23/04/16 19:36:23 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
23/04/16 19:36:24 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.
23/04/16 19:36:24 WARN Utils: Service 'SparkUI' could not bind on port 4041. Attempting port 4042.


In [2]:
import random
from pyspark.sql.types import ArrayType, FloatType, IntegerType,StructType
from pyspark.ml.feature import HashingTF

# MinHashLSH

In [3]:
class minLSH:
    def __init__(self, numHashTables=5, shingleSize=5, inputCol="features", outputCol="hashes"):
        self.numHashTables = numHashTables
        self.shingleSize = shingleSize
        self.inputCol = inputCol
        self.outputCol = outputCol
        self.seed = None
        if not self.seed:
            self.seed = random.randint(0, 10000)
        random.seed(self.seed)
        self.a_vals = [random.randint(1, self.numHashTables * 10) for _ in range(self.numHashTables)]
        self.b_vals = [random.randint(0, self.numHashTables * 10) for _ in range(self.numHashTables)]


    def _shingle(self, vec):
        shingles = set()
        for i in range(len(vec) - self.shingleSize + 1):
            shingles.add(Vectors.dense(vec[i:i + self.shingleSize]))
        return list(shingles)

    def _hash_func(self, vec, a, b, m):
        hash_val = ((a * vec + b) % m) % self.numHashTables
        return hash_val
    
    def _min_hash(self, vec):
        m = 1 << 32 - 1
        hash_values = [self._hash_func(vec, a, b, m) for a, b in zip(self.a_vals, self.b_vals)]
        print("hash_values:",hash_values)
        return hash_values
    
    def _hash_func2(self, vec, a, b, m):
        hash_val = ((a * vec.values.sum() + b) % m) % self.numHashTables
        return hash_val
    
    def _min_hash2(self, shingles):
        hash_values = []
        m = 1 << 32 - 1
        for a, b in zip(self.a_vals, self.b_vals):
            min_val = float('inf')
            for shingle in shingles:
                hash_val = self._hash_func2(shingle, a, b, m)
                min_val = min(min_val, hash_val)
            hash_values.append(min_val)
        print("hash_values=", hash_values)
        return Vectors.dense(hash_values)

    def fit(self, dataset):
        return self

    def transform(self, dataset):
        # shingle_udf = udf(self._shingle, ArrayType(VectorUDT()))
        # dataset = dataset.withColumn("shingles", shingle_udf(col(self.inputCol)))
        min_hash_udf = udf(self._min_hash, ArrayType(VectorUDT()))
        dataset = dataset.withColumn(
            self.outputCol, min_hash_udf(col(self.inputCol)))
        return dataset

    def approxSimilarityJoin(self, datasetA, datasetB, threshold):
        jaccard_similarity = udf(lambda x, y: len(
            set(x).intersection(set(y))) / len(set(x).union(set(y))), FloatType())

        for col_name in datasetA.columns:
            datasetA = datasetA.withColumnRenamed(col_name, col_name + "_1")
        for col_name in datasetB.columns:
            datasetB = datasetB.withColumnRenamed(col_name, col_name + "_2")

        joined = datasetA.crossJoin(datasetB)
        joined = joined.withColumn("jaccard_similarity", jaccard_similarity(
            joined[self.outputCol + "_1"], joined[self.outputCol + "_2"]))
        result = joined.filter(joined["jaccard_similarity"] >= threshold)
        return result



In [4]:
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.linalg import Vectors

In [5]:
train = spark.read.csv('./mnist/mnist_train.csv', header=True, inferSchema=True)
columns = train.columns[1:]
# 合并为 vector
assembler = VectorAssembler(inputCols=columns, outputCol='features')
train = assembler.transform(train)
train = train.drop(*columns)
# 改为 DenseVector
toDense = lambda v: Vectors.dense(v.toArray())
toDenseUdf = udf(toDense, VectorUDT())
train = train.withColumn('features', toDenseUdf('features'))


                                                                                

In [6]:
train.select('features').show(5, truncate=False)

23/04/16 19:36:37 WARN package: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.


[Stage 2:>                                                          (0 + 1) / 1]

+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------

                                                                                

# Train

In [7]:
image_df = spark.createDataFrame(train.take(10), ["id", "features"])
min_lsh = minLSH(numHashTables=10, shingleSize=10, inputCol="features", outputCol="hashes")
min_lsh_model = min_lsh.fit(image_df)
transformed_df = min_lsh_model.transform(image_df)
transformed_df.show()

hash_values: [DenseVector([9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 2.0, 7.0, 7.0, 7.0, 5.0, 5.0, 4.0, 5.0, 5.0, 4.0, 6.0, 6.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 5.0, 3.0, 3.0, 9.0, 2.0, 2.0, 2.0, 2.0, 2.0, 4.0, 1.0, 2.0, 1.0, 4.0, 3.0, 9.0, 9.0, 9.0

+---+--------------------+--------------------+
| id|            features|              hashes|
+---+--------------------+--------------------+
|  5|[0.0,0.0,0.0,0.0,...|[[9.0,9.0,9.0,9.0...|
|  0|[0.0,0.0,0.0,0.0,...|[[9.0,9.0,9.0,9.0...|
|  4|[0.0,0.0,0.0,0.0,...|[[9.0,9.0,9.0,9.0...|
|  1|[0.0,0.0,0.0,0.0,...|[[9.0,9.0,9.0,9.0...|
|  9|[0.0,0.0,0.0,0.0,...|[[9.0,9.0,9.0,9.0...|
|  2|[0.0,0.0,0.0,0.0,...|[[9.0,9.0,9.0,9.0...|
|  1|[0.0,0.0,0.0,0.0,...|[[9.0,9.0,9.0,9.0...|
|  3|[0.0,0.0,0.0,0.0,...|[[9.0,9.0,9.0,9.0...|
|  1|[0.0,0.0,0.0,0.0,...|[[9.0,9.0,9.0,9.0...|
|  4|[0.0,0.0,0.0,0.0,...|[[9.0,9.0,9.0,9.0...|
+---+--------------------+--------------------+



                                                                                

In [8]:
result = min_lsh_model.approxSimilarityJoin(transformed_df, transformed_df, threshold=0)
result.where("id_1=1").show()

23/04/16 19:36:49 WARN ExtractPythonUDFFromJoinCondition: The join condition:(<lambda>(hashes_1#4774, hashes_2#4786)#4798 >= 0.0) of the join plan contains PythonUDF only, it will be moved out and the join plan will be turned to cross join.


hash_values: [DenseVector([9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 2.0, 7.0, 7.0, 7.0, 5.0, 5.0, 4.0, 5.0, 5.0, 4.0, 6.0, 6.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 9.0, 5.0, 3.0, 3.0, 9.0, 2.0, 2.0, 2.0, 2.0, 2.0, 4.0, 1.0, 2.0, 1.0, 4.0, 3.0, 9.0, 9.0, 9.0

+----+--------------------+--------------------+----+--------------------+--------------------+------------------+
|id_1|          features_1|            hashes_1|id_2|          features_2|            hashes_2|jaccard_similarity|
+----+--------------------+--------------------+----+--------------------+--------------------+------------------+
|   1|[0.0,0.0,0.0,0.0,...|[[9.0,9.0,9.0,9.0...|   5|[0.0,0.0,0.0,0.0,...|[[9.0,9.0,9.0,9.0...|               0.0|
|   1|[0.0,0.0,0.0,0.0,...|[[9.0,9.0,9.0,9.0...|   0|[0.0,0.0,0.0,0.0,...|[[9.0,9.0,9.0,9.0...|               0.0|
|   1|[0.0,0.0,0.0,0.0,...|[[9.0,9.0,9.0,9.0...|   4|[0.0,0.0,0.0,0.0,...|[[9.0,9.0,9.0,9.0...|               0.0|
|   1|[0.0,0.0,0.0,0.0,...|[[9.0,9.0,9.0,9.0...|   1|[0.0,0.0,0.0,0.0,...|[[9.0,9.0,9.0,9.0...|               1.0|
|   1|[0.0,0.0,0.0,0.0,...|[[9.0,9.0,9.0,9.0...|   9|[0.0,0.0,0.0,0.0,...|[[9.0,9.0,9.0,9.0...|               0.0|
|   1|[0.0,0.0,0.0,0.0,...|[[9.0,9.0,9.0,9.0...|   2|[0.0,0.0,0.0,0.0,...|[[9.0,

In [38]:
hashes2 = result.select("hashes_2").collect()
hashes2[1].asDict()['hashes_2']


23/04/16 16:35:58 WARN ExtractPythonUDFFromJoinCondition: The join condition:(<lambda>(hashes_1#5729, hashes_2#5741)#5753 >= 0.0) of the join plan contains PythonUDF only, it will be moved out and the join plan will be turned to cross join.


hash_values: [DenseVector([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 2.0, 7.0, 7.0, 7.0, 3.0, 3.0, 6.0, 3.0, 3.0, 6.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 3.0, 9.0, 9.0, 1.0, 2.0, 2.0, 2.0, 2.0, 2.0, 6.0, 5.0, 2.0, 5.0, 6.0, 9.0, 1.0, 1.0, 1.0

[DenseVector([1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 8.0, 4.0, 2.0, 4.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 7.0, 7.0, 5.0, 5.0, 5.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 9.0, 0.0, 2.0, 5.0, 4.0, 2.0, 5.0, 0.0, 3.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1