In [119]:
import sys
from pyspark.sql import SparkSession, DataFrame
from pyspark.ml.feature import HashingTF, IDF
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
import pandas as pd
from pyspark.sql.types import *
from functools import reduce

In [72]:
def split_file(x):
    value=x.split('\t')
    return (value[0].replace('"', ''), value[1].split(' '))

In [59]:
spark = SparkSession.builder.appName('SparkByExamples.com').getOrCreate()
df = spark.read.text('cleaned.txt')
rdd = df.rdd.map(lambda x: split_file(x))
df2 = rdd.toDF().withColumnRenamed('_2', 'content').withColumnRenamed('_1', 'id')

hashingTF = HashingTF(inputCol="content", outputCol='features')
hashingTF.setNumFeatures(1000)

tf = hashingTF.transform(df2)

idf = IDF()
idf.setInputCol('features')
idf.setOutputCol('idf')
model = idf.fit(tf)
tf_idf = model.transform(tf)

In [65]:
sc = spark.sparkContext

In [76]:
# Read categories into dataframe
categories = spark.read.csv('cat_test.csv', header=True, inferSchema=True)

# Make a column for each unique category, split each category on spaces
unique_cats = categories.select('categories').distinct()

In [77]:

# Split unique categories on spaces
cats_split = unique_cats.select(psf.split(psf.col('categories'), ' ').alias('categories'))

# Make a list of all the categories
cats_list = cats_split.select(psf.explode(psf.col('categories')).alias('category')).distinct()

In [78]:
# # Add to categories dataframe a column for each unique category
for row in cats_list.collect():
    category = row.category
    categories = categories.withColumn(category, psf.when(categories['categories'].contains(category), 1).otherwise(0))

In [79]:
categories.take(1)

[Row(id='0908.1812', categories='astro-ph.SR astro-ph.CO astro-ph.HE', gr-qc=0, astro-ph.SR=1, astro-ph.CO=1, astro-ph.HE=1, stat.AP=0, hep-ph=0)]

In [83]:
data = (sc.textFile('cleaned-test.txt')
    .map(lambda x: split_file(x))
    .toDF(['id', 'words'])
)


In [84]:
joined_data = data.join(categories, ["id"])


In [85]:
joined_data.show()

+-----------+--------------------+--------------------+-----+-----------+-----------+-----------+-------+------+
|         id|               words|          categories|gr-qc|astro-ph.SR|astro-ph.CO|astro-ph.HE|stat.AP|hep-ph|
+-----------+--------------------+--------------------+-----+-----------+-----------+-----------+-------+------+
|  0807.5065|["one, of, the, m...|               gr-qc|    1|          0|          0|          0|      0|     0|
| 1512.09024|["recently, it, w...|              hep-ph|    0|          0|          0|          0|      0|     1|
|  0908.1812|["this, review, f...|astro-ph.SR astro...|    0|          1|          1|          1|      0|     0|
|1009.3123-1|["for, about, 20,...| stat.AP astro-ph.SR|    0|          1|          0|          0|      1|     0|
|  1009.3123|["for, about, 20,...| stat.AP astro-ph.SR|    0|          1|          0|          0|      1|     0|
+-----------+--------------------+--------------------+-----+-----------+-----------+-----------

In [106]:
hashingTF = HashingTF(inputCol="words", outputCol='features')
hashingTF.setNumFeatures(1000)

tf = hashingTF.transform(joined_data)

idf = IDF()
idf.setInputCol('features')
idf.setOutputCol('idf')
model = idf.fit(tf)
tf_idf = model.transform(tf)

In [125]:
cosine_similarity_udf = psf.udf(lambda x,y: round(float(x.dot(y)/(x.norm(2) * y.norm(2))), 4), DoubleType())

result = spark.createDataFrame([], StructType([]))

result_append = []

for row in cats_list.collect(): # Uses collect here since it is a limited number of categories
    cat_df = tf_idf.filter(tf_idf[f'`{row.category}`'] == 1)

    res = cat_df.alias("i").join(cat_df.alias("j"), psf.col("i.id") < psf.col("j.id"))\
        .select(
            psf.col("i.id").alias("i"), 
            psf.col("j.id").alias("j"), 
            cosine_similarity_udf("i.idf", "j.idf").alias("similarity"))\
        .sort("i", "j")
    result_append.append(res)

df_series = reduce(DataFrame.unionAll, result_append).distinct()
df_series.show()


+---------+-----------+----------+
|        i|          j|similarity|
+---------+-----------+----------+
|0908.1812|  1009.3123|    0.0223|
|0908.1812|1009.3123-1|    0.0223|
|1009.3123|1009.3123-1|       1.0|
+---------+-----------+----------+



In [32]:
import pyspark.sql.functions as psf