In [95]:
from pyspark.sql import SparkSession
from pyspark.ml.feature import HashingTF, IDF
import pyspark.sql.functions as psf
from pyspark.sql.types import DoubleType

In [96]:
# Reuse the active Spark session if one exists; otherwise start one named "TF-IDF".
spark = (
    SparkSession.builder
    .appName('TF-IDF')
    .getOrCreate()
)

# Low-level RDD entry point belonging to the same session.
sc = spark.sparkContext

In [97]:
def split_file(x):
    """Parse one ``<id>\\t<text>`` line into an ``(id, words)`` tuple.

    Args:
        x: A single input line. The first tab-separated field is the id and
            the second field is the text; any further fields are ignored.

    Returns:
        A tuple ``(id, words)`` where ``id`` is the first field as a string
        and ``words`` is the list of whitespace-separated tokens of the
        second field. ``words`` is empty when the line has no tab.
    """
    fields = x.split('\t')
    # A line without a tab has no text field; treat it as an empty document
    # instead of raising IndexError.
    text = fields[1] if len(fields) > 1 else ''
    # split() with no argument collapses runs of whitespace, so repeated or
    # trailing spaces do not produce empty-string tokens.
    return (fields[0], text.split())

In [124]:
# Read the article/category table; the 'categories' column is treated below as
# a space-separated list of category names.
categories = spark.read.csv('article_categories.csv', header=True, inferSchema=True)

# Distinct raw 'categories' strings.
unique_cats = categories.select('categories').distinct()

# Each raw string split on spaces into an array of category names.
cats_split = unique_cats.select(psf.split(psf.col('categories'), ' ').alias('categories'))

# One row per distinct category name. Empty tokens (produced by repeated or
# trailing spaces) are dropped so they never become a column name.
cats_list = (
    cats_split
    .select(psf.explode(psf.col('categories')).alias('category'))
    .where(psf.col('category') != '')
    .distinct()
)

# Sorted so the indicator columns come out in the same order on every run.
category_names = sorted(row.category for row in cats_list.collect())

# Match on whole tokens: a substring test would flag e.g. "art" for a row
# that is only tagged "party".
category_tokens = psf.split(psf.col('categories'), ' ')

# Add every 0/1 indicator column in a single projection; chaining withColumn
# in a loop adds one projection per category and inflates the query plan.
# NOTE(review): assumes no category name equals an existing column name
# (e.g. "id"), which would yield a duplicate column here - confirm.
categories = categories.select(
    '*',
    *[
        psf.when(psf.array_contains(category_tokens, name), 1).otherwise(0).alias(name)
        for name in category_names
    ],
)

                                                                                

In [125]:
# Load the corpus as lines of text and parse each one into an (id, words) row.
article_lines = sc.textFile('cleaned-test.txt')
data = article_lines.map(split_file).toDF(['id', 'words'])

data.show()

+---+--------------------+
| id|               words|
+---+--------------------+
|  3|["this, article, ...|
|  2|["lorem, ipsum, d...|
|  1|["lorem, ipsum, d...|
+---+--------------------+



In [126]:
# Inner-join each article's words with its category indicator columns on "id".
joined_data = data.join(categories, on="id", how="inner")
joined_data.show()

22/03/18 15:30:48 WARN package: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.

+---+-----+----------+-------+--------+-----+-----+-----------+-----+-----+---------------+--------+-------+-------+--------------+-------+-------+---------------+--------+-----+----------------+-----+-------+-------+-----+-------+---------------+-----------+------------------+-----+-------+-------+-------+--------+--------------+-----+----------------+--------+-----+--------+-----+-----------+-----+-------+-----+-------+--------------+-------+-------+--------------+---------------+-----+-------------+--------+---------------+--------+-------+-----+--------+-----+-----+-------+--------+--------+-------+--------+--------+-------+-----+--------+--------+-----+-------+--------+-----+-------+-------+---------------+--------------+----------------+-------+-------+-----+-------+-------+-----+-------+-------+-----------------+-----+-------+-------+---------------+-------+-----+-----+--------+-------+--------+-----+--------------+-----+-----------+-------+-----+--------+-----+------+-----+----

                                                                                

In [101]:
# Hash each article's word list into a fixed-width term-frequency vector.
hashingTF = HashingTF(
    numFeatures=100_000,
    inputCol="words",
    outputCol="features",
)
tf = hashingTF.transform(joined_data)

In [102]:
# Fit inverse document frequencies on the term-frequency vectors, then write
# the rescaled vectors to the 'idf' column.
idf = IDF(
    outputCol='idf',
    inputCol='features',
)
model = idf.fit(tf)
tf_idf = model.transform(tf)

In [112]:
def _cosine_similarity(x, y):
    """Cosine similarity of two vectors, rounded to two decimals.

    Args:
        x, y: Vectors exposing ``dot`` and ``norm`` (the values of the 'idf'
            column are passed in below).

    Returns:
        ``x.dot(y) / (|x| * |y|)`` rounded to 2 places, or 0.0 when either
        vector has zero length (the ratio is undefined there and would
        otherwise come out as NaN).
    """
    denominator = x.norm(2) * y.norm(2)
    if denominator == 0:
        return 0.0
    return round(float(x.dot(y) / denominator), 2)


# Built once: the UDF does not depend on the category being processed.
cosine_similarity_udf = psf.udf(_cosine_similarity, DoubleType())

for row in cats_list.collect():
    category = row.category
    # Backticks let column names containing dots or spaces resolve.
    # NOTE(review): `== 0` keeps the articles that are NOT tagged with this
    # category; if the intent is to compare articles within the category this
    # should be `== 1` - confirm.
    # Only 'id' and 'idf' are needed downstream, so drop the other columns
    # before the self-join.
    test = (
        tf_idf
        .filter(tf_idf[f"`{category}`"] == 0)
        .select("id", "idf")
    )

    # Self-join on i.id < j.id so each unordered pair appears exactly once.
    res = (test.alias("i").join(test.alias("j"), psf.col("i.id") < psf.col("j.id"))
        .select(
            psf.col("i.id").alias("i"),
            psf.col("j.id").alias("j"),
            cosine_similarity_udf("i.idf", "j.idf").alias("cosine_similarity"))
        .sort("i", "j"))

    # Label each table so the outputs can be told apart.
    print(f"Category: {category}")
    res.show(5)

22/03/18 15:19:52 WARN DAGScheduler: Broadcasting large task binary with size 1639.9 KiB


+---+---+-----------------+
|  i|  j|cosine_similarity|
+---+---+-----------------+
|  1|  2|             0.34|
+---+---+-----------------+



22/03/18 15:19:53 WARN DAGScheduler: Broadcasting large task binary with size 1639.9 KiB


+---+---+-----------------+
|  i|  j|cosine_similarity|
+---+---+-----------------+
|  1|  3|              0.0|
+---+---+-----------------+



22/03/18 15:19:53 WARN DAGScheduler: Broadcasting large task binary with size 1639.9 KiB


+---+---+-----------------+
|  i|  j|cosine_similarity|
+---+---+-----------------+
|  1|  3|              0.0|
+---+---+-----------------+

+---+---+-----------------+
|  i|  j|cosine_similarity|
+---+---+-----------------+
|  2|  3|              0.0|
+---+---+-----------------+



22/03/18 15:19:54 WARN DAGScheduler: Broadcasting large task binary with size 1639.9 KiB
