In [1]:
from pyspark.sql import SparkSession
from pyspark.ml.feature import NGram, BucketedRandomProjectionLSH
from pyspark.ml.feature import CountVectorizer
from pyspark.sql.types import StructType,StructField, StringType, ArrayType

In [2]:
# Start (or reuse) a Spark session named "lsh" on the local master, asking
# for 8 GB of memory for both the driver and the executor settings.
# NOTE(review): with a local master the executor setting is probably
# redundant, since work runs inside the driver process; confirm before
# relying on it.
spark = (
    SparkSession.builder
    .master("local[*]")
    .appName("lsh")
    .config("spark.driver.memory", "8g")
    .config("spark.executor.memory", "8g")
    .getOrCreate()
)

# Handle to the underlying SparkContext, needed for the RDD API below.
sc = spark.sparkContext

22/03/17 20:02:07 WARN Utils: Your hostname, adneovrebo.local resolves to a loopback address: 127.0.0.1; using 192.168.68.108 instead (on interface en0)
22/03/17 20:02:07 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
22/03/17 20:02:08 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
22/03/17 20:02:09 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.
22/03/17 20:02:09 WARN Utils: Service 'SparkUI' could not bind on port 4041. Attempting port 4042.


# Bucketing

In [6]:
def _parse_article(line):
    """Turn one tab-separated line into an ``(id, words)`` pair.

    The first tab-separated field is kept as the id; the second field is
    split on single spaces to produce the list of words. Any further
    fields on the line are ignored.
    """
    fields = line.split('\t')
    return (fields[0], fields[1].split(" "))


# Two-column schema: a string id and an array of string tokens.
schema = (StructType()
          .add('id', StringType())
          .add('words', ArrayType(StringType())))

# Read the corpus line by line and parse every line into (id, words).
rdd = sc.textFile('cleaned.txt').map(_parse_article)

df = spark.createDataFrame(rdd, schema)


# Build bigrams (n=2) from the "words" column into a new "ngrams" column.
ngram = (NGram()
         .setN(2)
         .setInputCol("words")
         .setOutputCol("ngrams"))
ngram_df = ngram.transform(df)

# Count-vectorize the bigrams into the "features" column. The vocabulary is
# capped at 100,000 terms, and minDF=2 is passed as the minimum document
# frequency.
cv = (CountVectorizer()
      .setInputCol("ngrams")
      .setOutputCol("features")
      .setVocabSize(100_000)
      .setMinDF(2))
cv_model = cv.fit(ngram_df)
cv_df = cv_model.transform(ngram_df)

# Fit a Euclidean-distance LSH model on the count-vector features.
# An explicit seed pins the random projection vectors so that the hash
# tables, and therefore the approximate nearest neighbours found later,
# are reproducible across kernel restarts.
# NOTE(review): with bucketLength this large the recorded output only shows
# hash values of -1.0 and 0.0; confirm that such coarse buckets are intended.
brp = BucketedRandomProjectionLSH(
    inputCol="features",
    outputCol="hashes",
    bucketLength=1_000_000,
    numHashTables=100,
    seed=42,
)
model = brp.fit(cv_df)

                                                                                

In [10]:
# Load the query text from review.txt and split it on single spaces, the
# same way the corpus text was tokenized. The context manager closes the
# file handle as soon as the text has been read.
with open('review.txt', 'r') as review_file:
    text = review_file.read().split(" ")

# Wrap the tokens in a one-row DataFrame so the fitted transformers apply.
text_df = spark.createDataFrame([(text, )], ['words'])
# Build n-grams of the query text with the same NGram transformer.
text_ngram = ngram.transform(text_df)
# Vectorize with the vocabulary already fitted on the corpus.
text_cv = cv_model.transform(text_ngram)
# The feature vector of the single row is the search key.
key = text_cv.first()["features"]

# Find the 10 approximate nearest neighbours of the key in the corpus.
res = model.approxNearestNeighbors(cv_df, key, 10)
res.show()

22/03/17 20:13:37 WARN DAGScheduler: Broadcasting large task binary with size 1411.6 KiB
22/03/17 20:13:37 WARN DAGScheduler: Broadcasting large task binary with size 1411.6 KiB
22/03/17 20:13:37 WARN DAGScheduler: Broadcasting large task binary with size 1411.6 KiB
22/03/17 20:13:38 WARN DAGScheduler: Broadcasting large task binary with size 77.7 MiB

+-----------------+--------------------+--------------------+--------------------+--------------------+------------------+
|               id|               words|              ngrams|            features|              hashes|           distCol|
+-----------------+--------------------+--------------------+--------------------+--------------------+------------------+
|  "hep-ph9602258"|["citex12, citeac...|["citex12 citeaci...|(100000,[0,1,2,3,...|[[-1.0], [-1.0], ...|29.866369046136157|
|      "0809.4539"|["the, discovery,...|["the discovery, ...|(100000,[0,1,2,3,...|[[-1.0], [0.0], [...| 57.60208329565867|
|  "hep-ph0207019"|["when, confronti...|["when confrontin...|(100000,[0,1,2,3,...|[[-1.0], [-1.0], ...| 57.86190456595773|
| "hep-lat9911007"|["a, rather, comp...|["a rather, rathe...|(100000,[0,1,2,5,...|[[-1.0], [0.0], [...| 58.98304841223451|
|     "1505.07378"|["neutrons, emitt...|["neutrons emitte...|(100000,[0,1,3,5,...|[[-1.0], [0.0], [...|              59.0|
|"astro-ph980518

22/03/17 20:14:01 WARN DAGScheduler: Broadcasting large task binary with size 77.7 MiB
                                                                                