In [1]:
from pyspark.sql import SparkSession
from pyspark.ml.feature import HashingTF, IDF
import pyspark.sql.functions as psf
from pyspark.sql.types import DoubleType
from pyspark.sql.types import StructType,StructField, StringType, ArrayType

In [2]:
# Create (or reuse) the Spark session for this notebook, then keep a handle to
# its SparkContext, which the RDD-based text loading below goes through.
spark = (
    SparkSession.builder
    .appName('TF-IDF')
    # 8g each for the driver and the executors
    .config("spark.driver.memory", "8g")
    .config("spark.executor.memory", "8g")
    # Analyzer iteration cap set to 100,000
    .config('spark.sql.analyzer.maxIterations', 100_000)
    .getOrCreate()
)
sc = spark.sparkContext

22/03/19 16:01:34 WARN Utils: Your hostname, adneovrebo.local resolves to a loopback address: 127.0.0.1; using 192.168.68.108 instead (on interface en0)
22/03/19 16:01:34 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
22/03/19 16:01:35 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [3]:
def split_file(x):
    """Parse one tab-separated text line into an ``(id, words)`` tuple.

    The first tab-delimited field becomes the id, with every double quote
    removed. The second field is split on single spaces to give the word list.
    Any further tab-delimited fields are ignored.

    Raises:
        IndexError: if the line contains no tab character.
    """
    fields = x.split('\t')
    doc_id = fields[0].replace('"', "")
    words = fields[1].split(' ')
    return (doc_id, words)

In [4]:
# Read categories into a dataframe.
# Both columns are declared as strings rather than inferred: ids such as
# "0704.0001" look numeric, and a numeric type would not preserve their leading
# and trailing zeros, which the join on `id` further down relies on.
categories = spark.read.csv('categories.csv', header=True, schema='id STRING, categories STRING')

categories.show(5)

# Distinct values of the raw, space-separated `categories` string
unique_cats = categories.select('categories').distinct()

# Split each of those strings on spaces into an array of category tokens
cats_split = unique_cats.select(psf.split(psf.col('categories'), ' ').alias('categories'))

# One row per distinct category token. Empty tokens (e.g. produced by doubled
# spaces) are dropped so that no indicator column with an empty name is created.
cats_list = (cats_split
    .select(psf.explode(psf.col('categories')).alias('category'))
    .filter(psf.col('category') != '')
    .distinct())

# Build one 0/1 indicator column per category.
# Membership is tested against the split token array, not with a substring
# match, so a category whose name is contained in another category's name is
# not flagged by mistake.
category_flags = []
for row in cats_list.collect():
    category = row.category
    is_member = psf.array_contains(psf.split(psf.col('categories'), ' '), category)
    category_flags.append(psf.when(is_member, 1).otherwise(0).alias(category))

# Add all indicator columns in a single projection; calling withColumn once per
# category would stack one projection per category onto the query plan.
categories = categories.select('*', *category_flags)

                                                                                

+---------+---------------+
|       id|     categories|
+---------+---------------+
|0704.0001|         hep-ph|
|0704.0002|  math.CO cs.CG|
|0704.0003| physics.gen-ph|
|0704.0004|        math.CO|
|0704.0005|math.CA math.FA|
+---------+---------------+
only showing top 5 rows



                                                                                

In [6]:
# Schema of the parsed text file: a string id plus an array of string tokens.
schema = StructType([
    StructField('id', StringType()),
    StructField('words', ArrayType(elementType=StringType())),
])

# Every line of cleaned.txt is turned into an (id, words) pair by split_file,
# and the resulting RDD is converted to a dataframe with the schema above.
data = sc.textFile('cleaned.txt').map(split_file).toDF(schema)


In [7]:
# Attach the category columns to each document: inner join on the shared `id` column
joined_data = data.join(categories, on="id", how="inner")

In [8]:
# Hash each document's `words` list into a term-frequency vector with
# 100000 feature slots, stored in the `features` column.
hashingTF = HashingTF(
    numFeatures=100000,
    inputCol="words",
    outputCol='features',
)
tf = hashingTF.transform(joined_data)

In [9]:
# Fit IDF weights on the term-frequency vectors, then apply them; the weighted
# vectors are written to the `idf` column.
idf = IDF(outputCol='idf', inputCol='features')
model = idf.fit(tf)
tf_idf = model.transform(tf)

[Stage 9:>                                                          (0 + 1) / 1]

In [10]:
# Cosine similarity of two vectors, rounded to 2 decimals. It does not depend
# on the category, so it is built once here instead of on every loop iteration.
cosine_similarity_udf = psf.udf(lambda x,y: round(float(x.dot(y)/(x.norm(2)*y.norm(2))), 2), DoubleType())

# For every category, compare all pairs of documents flagged with that category.
for row in cats_list.collect():
    category = row.category
    # Backticks quote the column name, which may contain a dot (e.g. "math.CO")
    test = tf_idf.filter(tf_idf[f"`{category}`"] == 1)

    # count() launches a Spark job, so run it once and reuse the result
    doc_count = test.count()
    if doc_count > 1:
        print(doc_count, category)

        # Self-join on i.id < j.id yields each unordered pair exactly once
        res = (test.alias("i").join(test.alias("j"), psf.col("i.id") < psf.col("j.id"))
            .select(
                psf.col("i.id").alias("i"), 
                psf.col("j.id").alias("j"), 
                cosine_similarity_udf("i.idf", "j.idf").alias("cosine_similarity"))
            .sort("i", "j"))
        print(res.take(5))

                                                                                

50 nlin.CD


22/03/19 16:05:21 WARN DAGScheduler: Broadcasting large task binary with size 1647.2 KiB
                                                                                

[Row(i='0704.2247', j='0709.0208', cosine_similarity=0.03), Row(i='0704.2247', j='0711.4125', cosine_similarity=0.1), Row(i='0704.2247', j='0801.2927', cosine_similarity=0.06), Row(i='0704.2247', j='0803.2252', cosine_similarity=0.05), Row(i='0704.2247', j='0805.1837', cosine_similarity=0.06)]


                                                                                

2 q-fin.CP


22/03/19 16:05:58 WARN DAGScheduler: Broadcasting large task binary with size 1647.2 KiB
                                                                                

[Row(i='1106.1395', j='1507.01901', cosine_similarity=0.08)]


                                                                                

39 cs.LG


ERROR:root:KeyboardInterrupt while sending command.                 (0 + 1) / 1]
Traceback (most recent call last):
  File "/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/py4j/java_gateway.py", line 1038, in send_command
    response = connection.send_command(command)
  File "/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/py4j/clientserver.py", line 475, in send_command
    answer = smart_decode(self.stream.readline()[:-1])
  File "/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/socket.py", line 704, in readinto
    return self._sock.recv_into(b)
KeyboardInterrupt


KeyboardInterrupt: 