In [9]:
# import libraries
from pyspark.sql import SparkSession
from pyspark.ml.linalg import Vectors
from pyspark.sql.functions import col
from pyspark.ml.feature import BucketedRandomProjectionLSH

import pandas as pd
import numpy as np

# Spark session & context
spark = SparkSession.builder.master("local").getOrCreate()
sc = spark.sparkContext

data = [(0, Vectors.dense([-1.0, -1.0 ]),),
        (1, Vectors.dense([-1.0, 1.0 ]),),
        (2, Vectors.dense([1.0, -1.0 ]),),
        (3, Vectors.dense([1.0, 1.0]),)]

data2 = [(4, Vectors.dense([2.0, 2.0 ]),),
         (5, Vectors.dense([2.0, 3.0 ]),),
         (6, Vectors.dense([3.0, 2.0 ]),),
         (7, Vectors.dense([3.0, 3.0]),)]

df = spark.createDataFrame(data, ["id", "features"])
df2 = spark.createDataFrame(data2, ["id", "features"])
df.show(), df2.show()

+---+-----------+
| id|   features|
+---+-----------+
|  0|[-1.0,-1.0]|
|  1| [-1.0,1.0]|
|  2| [1.0,-1.0]|
|  3|  [1.0,1.0]|
+---+-----------+

+---+---------+
| id| features|
+---+---------+
|  4|[2.0,2.0]|
|  5|[2.0,3.0]|
|  6|[3.0,2.0]|
|  7|[3.0,3.0]|
+---+---------+



(None, None)

In [10]:
# setting up brp
# used to hash the features based on euclidean distance
brp = BucketedRandomProjectionLSH(inputCol='features', outputCol='hashes', seed=4526, bucketLength=1.0)
model = brp.fit(df)
model.transform(df).show()

+---+-----------+--------+
| id|   features|  hashes|
+---+-----------+--------+
|  0|[-1.0,-1.0]| [[0.0]]|
|  1| [-1.0,1.0]| [[1.0]]|
|  2| [1.0,-1.0]|[[-2.0]]|
|  3|  [1.0,1.0]|[[-1.0]]|
+---+-----------+--------+



In [17]:
model.approxNearestNeighbors(df2, Vectors.dense([1.0, 2.0]), 1).collect()

[Row(id=5, features=DenseVector([2.0, 3.0]), hashes=[DenseVector([0.0])], distCol=1.4142135623730951)]

In [13]:
# test cross join with tolerance of 3 units
joined_df = model.approxSimilarityJoin(df, df2, 3.0, distCol="EuclideanDistance")
joined_df.show(10)

+--------------------+--------------------+------------------+
|            datasetA|            datasetB| EuclideanDistance|
+--------------------+--------------------+------------------+
|{3, [1.0,1.0], [[...|{7, [3.0,3.0], [[...|2.8284271247461903|
|{3, [1.0,1.0], [[...|{4, [2.0,2.0], [[...|1.4142135623730951|
+--------------------+--------------------+------------------+

