In [1]:
from pyspark.sql import SparkSession
from pyspark.ml.feature import HashingTF, IDF
import pyspark.sql.functions as psf
from pyspark.sql.types import DoubleType
from pyspark.sql.types import StructType,StructField, StringType, ArrayType

In [2]:
spark = (SparkSession.builder.appName('TF-IDF')
    .config("spark.driver.memory", "8g")
    .config("spark.executor.memory", "8g")
    .config('spark.sql.analyzer.maxIterations', 100_000)
    .getOrCreate())
sc = spark.sparkContext

22/04/01 17:21:55 WARN Utils: Your hostname, adneovrebo.local resolves to a loopback address: 127.0.0.1; using 152.94.128.66 instead (on interface en0)
22/04/01 17:21:55 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
22/04/01 17:21:56 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [3]:
def split_file(x):
    value=x.split('\t')
    return (value[0].replace('"',""), value[1].split(' '))

In [4]:
# Read categories into dataframe
categories = spark.read.csv('categories.csv', header=True, inferSchema=True)

categories.show(5)

# Make a column for each unique category, split each category on spaces
unique_cats = categories.select('categories').distinct()

# Split unique categories on spaces
cats_split = unique_cats.select(psf.split(psf.col('categories'), ' ').alias('categories'))

# Make a list of all the categories
cats_list = cats_split.select(psf.explode(psf.col('categories')).alias('category')).distinct()

# Add to categories dataframe a column for each unique category
for row in cats_list.collect():
    category = row.category
    categories = categories.withColumn(category, psf.when(categories['categories'].contains(category), 1).otherwise(0))

                                                                                

+---------+---------------+
|       id|     categories|
+---------+---------------+
|0704.0001|         hep-ph|
|0704.0002|  math.CO cs.CG|
|0704.0003| physics.gen-ph|
|0704.0004|        math.CO|
|0704.0005|math.CA math.FA|
+---------+---------------+
only showing top 5 rows



                                                                                

In [4]:
schema = StructType([
        StructField('id', StringType()),
        StructField('words', ArrayType(elementType=StringType()))
])


data = (sc.textFile('cleaned-test.txt')
    .map(lambda x: split_file(x))
    .toDF(schema)
).limit(100)




In [6]:
# Join categories with data
joined_data = data#.join(categories, ["id"])

In [8]:
hashingTF = HashingTF(inputCol="words", outputCol='features', numFeatures=10)
tf = hashingTF.transform(joined_data)

In [9]:
idf = IDF(inputCol='features', outputCol='idf')
model = idf.fit(tf)
tf_idf = model.transform(tf)

                                                                                

In [10]:
cosine_similarity_udf = psf.udf(lambda x,y: round(float(x.dot(y)/(x.norm(2)*y.norm(2))), 2), DoubleType())

test = (tf_idf.alias("i").join(tf_idf.alias("j"), psf.col("i.id") < psf.col("j.id")))

In [11]:
test.show()

+-----------+--------------------+--------------------+--------------------+-----------+--------------------+--------------------+--------------------+
|         id|               words|            features|                 idf|         id|               words|            features|                 idf|
+-----------+--------------------+--------------------+--------------------+-----------+--------------------+--------------------+--------------------+
|  0807.5065|["one, of, the, m...|(10,[0,1,2,3,4,5,...|(10,[0,1,2,3,4,5,...| 1512.09024|["recently, it, w...|(10,[0,1,2,3,4,5,...|(10,[0,1,2,3,4,5,...|
|  0807.5065|["one, of, the, m...|(10,[0,1,2,3,4,5,...|(10,[0,1,2,3,4,5,...|  0908.1812|["this, review, f...|(10,[0,1,2,3,4,5,...|(10,[0,1,2,3,4,5,...|
|  0807.5065|["one, of, the, m...|(10,[0,1,2,3,4,5,...|(10,[0,1,2,3,4,5,...|1009.3123-1|["for, about, 20,...|(10,[0,1,2,3,4,5,...|(10,[0,1,2,3,4,5,...|
|  0807.5065|["one, of, the, m...|(10,[0,1,2,3,4,5,...|(10,[0,1,2,3,4,5,...|  1009.3123|

In [14]:
import pyspark.sql.functions as f

test.withColumn('dot',   f.expr('aggregate(arrays_zip(i.idf, j.idf), 0D, (acc, x) -> acc + (x.Col1 * x.Col2))')) \
  .withColumn('norm1', f.expr('sqrt(aggregate(Col1, 0D, (acc, x) -> acc + (x * x)))')) \
  .withColumn('norm2', f.expr('sqrt(aggregate(Col2, 0D, (acc, x) -> acc + (x * x)))')) \
  .withColumn('cosine', f.expr('dot / (norm1 * norm2)')) \
  .show(truncate=False)

AnalysisException: Can't extract value from idf#21: need struct type but got struct<type:tinyint,size:int,indices:array<int>,values:array<double>>

In [10]:
# res = test.select(
#     psf.col("i.id").alias("i"), 
#     psf.col("j.id").alias("j"), 
#     cosine_similarity_udf("i.idf", "j.idf").alias("cosine_similarity"))\
#     .sort("i", "j")

res = test.select(psf.col("i.id"))

In [11]:
res.take(1)

[Row(id='0807.5065')]

In [12]:
for row in cats_list.collect():
    category = row.category
    test = tf_idf.filter(tf_idf[f"`{category}`"] == 1)

    if test.count() > 1:
        print(test.count(), category)

        cosine_similarity_udf = psf.udf(lambda x,y: round(float(x.dot(y)/(x.norm(2)*y.norm(2))), 2), DoubleType())

        res = (test.alias("i").join(test.alias("j"), psf.col("i.id") < psf.col("j.id"))
            .select(
                psf.col("i.id").alias("i"), 
                psf.col("j.id").alias("j"), 
                cosine_similarity_udf("i.idf", "j.idf").alias("cosine_similarity"))
            .sort("i", "j"))
        print(res.take(5))

2 astro-ph.SR


ERROR:root:KeyboardInterrupt while sending command.             (24 + 12) / 144]
Traceback (most recent call last):
  File "/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/py4j/java_gateway.py", line 1038, in send_command
    response = connection.send_command(command)
  File "/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/py4j/clientserver.py", line 475, in send_command
    answer = smart_decode(self.stream.readline()[:-1])
  File "/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/socket.py", line 704, in readinto
    return self._sock.recv_into(b)
KeyboardInterrupt

KeyboardInterrupt: 

Exception in thread "serve-DataFrame" java.net.SocketTimeoutException: Accept timed out
	at java.base/java.net.PlainSocketImpl.socketAccept(Native Method)
	at java.base/java.net.AbstractPlainSocketImpl.accept(AbstractPlainSocketImpl.java:458)
	at java.base/java.net.ServerSocket.implAccept(ServerSocket.java:565)
	at java.base/java.net.ServerSocket.accept(ServerSocket.java:533)
	at org.apache.spark.security.SocketAuthServer$$anon$1.run(SocketAuthServer.scala:64)
