In [1]:
# SparkContext represents the connection to a Spark cluster
from pyspark.context import SparkContext
# Configuration for a Spark application
from pyspark.conf import SparkConf
# The entry point to programming Spark with the Dataset and DataFrame API
from pyspark.sql.session import SparkSession

spark = SparkSession.builder \
    .appName("P03_Clustering") \
    .config("spark.executor.memory", "4g") \
    .config("spark.driver.memory", "4g") \
    .config("spark.sql.repl.eagerEval.enabled", True) \
    .config("spark.sql.repl.eagerEval.truncate", 500) \
    .getOrCreate()


----------------------------------------
Exception occurred during processing of request from ('127.0.0.1', 35400)
Traceback (most recent call last):
  File "/opt/conda/lib/python3.11/socketserver.py", line 317, in _handle_request_noblock
    self.process_request(request, client_address)
  File "/opt/conda/lib/python3.11/socketserver.py", line 348, in process_request
    self.finish_request(request, client_address)
  File "/opt/conda/lib/python3.11/socketserver.py", line 361, in finish_request
    self.RequestHandlerClass(request, client_address, self)
  File "/opt/conda/lib/python3.11/socketserver.py", line 755, in __init__
    self.handle()
  File "/usr/local/spark/python/pyspark/accumulators.py", line 295, in handle
    poll(accum_updates)
  File "/usr/local/spark/python/pyspark/accumulators.py", line 267, in poll
    if self.rfile in r and func():
                           ^^^^^^
  File "/usr/local/spark/python/pyspark/accumulators.py", line 271, in accum_updates
    num_updates =

In [2]:
dblp_ref_file_path = "dblp-ref/dblp-ref-0.json"
papers_df = spark.read.json(dblp_ref_file_path)

papers_df.printSchema()

root
 |-- abstract: string (nullable = true)
 |-- authors: array (nullable = true)
 |    |-- element: string (containsNull = true)
 |-- id: string (nullable = true)
 |-- n_citation: long (nullable = true)
 |-- references: array (nullable = true)
 |    |-- element: string (containsNull = true)
 |-- title: string (nullable = true)
 |-- venue: string (nullable = true)
 |-- year: long (nullable = true)



In [3]:
papers_df.show()

+--------------------+--------------------+--------------------+----------+--------------------+--------------------+--------------------+----+
|            abstract|             authors|                  id|n_citation|          references|               title|               venue|year|
+--------------------+--------------------+--------------------+----------+--------------------+--------------------+--------------------+----+
|The purpose of th...|[Makoto Satoh, Ry...|00127ee2-cb05-48c...|         0|[51c7e02e-f5ed-43...|Preliminary Desig...|international con...|2013|
|This paper descri...|[Gareth Beale, Gr...|001c58d3-26ad-46b...|        50|[10482dd3-4642-41...|A methodology for...|visual analytics ...|2011|
|This article appl...|[Altaf Hossain, F...|001c8744-73c4-4b0...|        50|[2d84c0f2-e656-4c...|Comparison of GAR...|pattern recogniti...|2009|
|                NULL|[Jea-Bum Park, By...|00338203-9eb3-40c...|         0|[8c78e4b0-632b-42...|Development of Re...|                   

## Preprocessing

In [4]:
!pip install langdetect



In [5]:
from pyspark.sql.functions import col, udf, lower, regexp_replace, split
from pyspark.sql.types import ArrayType, StringType, BooleanType
from pyspark.ml.feature import StopWordsRemover
from langdetect import detect, LangDetectException

custom_stop_words = ['doi', 'preprint', 'copyright', 'peer', 'reviewed', 'org', 'https', 'et', 
                     'al', 'author', 'figure','rights', 'reserved', 'permission', 'used', 'using', 
                     'biorxiv', 'medrxiv', 'license', 'fig', 'fig.', 'al.', 'Elsevier', 'PMC', 'CZI', 'www']


def remove_punctuation(text):
    return regexp_replace(text, r'[!()\[\]{};:"\,<>./?@#$%^&*_~]', '')

"""
def detect_language(text):
    try:
        return detect(text) == 'en'
    except LangDetectException:
        return False

detect_language_udf = udf(detect_language, BooleanType())
"""

papers_cleaned_df = papers_df.filter(col("abstract").isNotNull() & (col("abstract").rlike(r'\w')))
#papers_cleaned_df = papers_cleaned_df.filter(detect_language_udf(col("title")))
papers_cleaned_df = papers_cleaned_df.withColumn("abstract", remove_punctuation(col("abstract")))
papers_cleaned_df = papers_cleaned_df.withColumn("abstract", lower(col("abstract")))
papers_cleaned_df = papers_cleaned_df.withColumn("words", split(col("abstract"), " "))

remover = StopWordsRemover(inputCol="words", outputCol="filtered_words")
papers_cleaned_df = remover.transform(papers_cleaned_df)


def remove_custom_stop_words(words, custom_stop_words):
    if words is None:
        return None
    return [word for word in words if word not in custom_stop_words]

remove_custom_stop_words_udf = udf(lambda words: remove_custom_stop_words(words, custom_stop_words), ArrayType(StringType()))

papers_cleaned_df = papers_cleaned_df.withColumn("final_filtered", remove_custom_stop_words_udf(col("filtered_words")))




In [6]:
papers_cleaned_df.show()

+--------------------+--------------------+--------------------+----------+--------------------+--------------------+--------------------+----+--------------------+--------------------+--------------------+
|            abstract|             authors|                  id|n_citation|          references|               title|               venue|year|               words|      filtered_words|      final_filtered|
+--------------------+--------------------+--------------------+----------+--------------------+--------------------+--------------------+----+--------------------+--------------------+--------------------+
|the purpose of th...|[Makoto Satoh, Ry...|00127ee2-cb05-48c...|         0|[51c7e02e-f5ed-43...|Preliminary Desig...|international con...|2013|[the, purpose, of...|[purpose, study, ...|[purpose, study, ...|
|this paper descri...|[Gareth Beale, Gr...|001c58d3-26ad-46b...|        50|[10482dd3-4642-41...|A methodology for...|visual analytics ...|2011|[this, paper, des...|[paper, 

## Vectorization

In [7]:
from pyspark.ml.feature import Tokenizer, HashingTF, IDF
from pyspark.sql.functions import when
#NB! this takes a long time to calculate so don't run this often
hashing_tf = HashingTF(inputCol="final_filtered", outputCol="raw_features", numFeatures=10000)
tf_df = hashing_tf.transform(papers_cleaned_df)

idf = IDF(inputCol="raw_features", outputCol="features")
idf_model = idf.fit(tf_df)
tfidf_df = idf_model.transform(tf_df)
tfidf_df = tfidf_df.filter(col("features").isNotNull())


In [8]:
# Show the final DataFrame with TF-IDF features
tfidf_df.select("final_filtered", "features").show()

+--------------------+--------------------+
|      final_filtered|            features|
+--------------------+--------------------+
|[purpose, study, ...|(10000,[1072,1241...|
|[paper, describes...|(10000,[110,157,2...|
|[article, applied...|(10000,[45,188,79...|
|[recent, achievem...|(10000,[86,157,18...|
|[recently, bridge...|(10000,[107,310,4...|
|[applications, ab...|(10000,[592,764,8...|
|[three, speech, t...|(10000,[39,72,102...|
|[paper, focuses, ...|(10000,[157,749,1...|
|[embedded, system...|(10000,[7,364,573...|
|[xax, browser, pl...|(10000,[23,286,38...|
|[recent, years, m...|(10000,[157,274,3...|
|[previous, langua...|(10000,[134,221,4...|
|[spatial, encrypt...|(10000,[46,270,27...|
|[system, operatio...|(10000,[7,15,163,...|
|[business, strate...|(10000,[120,157,2...|
|[ftp, mirror, tra...|(10000,[42,78,193...|
|[number, alternat...|(10000,[1020,1137...|
|[breast, cancer, ...|(10000,[157,274,4...|
|[development, aut...|(10000,[23,32,45,...|
|[quality, specifi...|(10000,[27

## Clustering

In [11]:
from pyspark.ml.feature import PCA

# Select the first 500 rows
limited_tfidf_df = tfidf_df.limit(500)

#throws somekind of error
pca = PCA(k=10, inputCol="features", outputCol="pca_features")
pca_model = pca.fit(limited_tfidf_df)
pca_df = pca_model.transform(limited_tfidf_df)

ERROR:root:Exception while sending command.
Traceback (most recent call last):
  File "/usr/local/spark/python/lib/py4j-0.10.9.7-src.zip/py4j/clientserver.py", line 516, in send_command
    raise Py4JNetworkError("Answer from Java side is empty")
py4j.protocol.Py4JNetworkError: Answer from Java side is empty

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/usr/local/spark/python/lib/py4j-0.10.9.7-src.zip/py4j/java_gateway.py", line 1038, in send_command
    response = connection.send_command(command)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/spark/python/lib/py4j-0.10.9.7-src.zip/py4j/clientserver.py", line 539, in send_command
    raise Py4JNetworkError(
py4j.protocol.Py4JNetworkError: Error while sending or receiving


Py4JError: An error occurred while calling o172.fit

ERROR:root:Exception while sending command.
Traceback (most recent call last):
  File "/usr/local/spark/python/lib/py4j-0.10.9.7-src.zip/py4j/clientserver.py", line 516, in send_command
    raise Py4JNetworkError("Answer from Java side is empty")
py4j.protocol.Py4JNetworkError: Answer from Java side is empty

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/usr/local/spark/python/lib/py4j-0.10.9.7-src.zip/py4j/java_gateway.py", line 1038, in send_command
    response = connection.send_command(command)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/local/spark/python/lib/py4j-0.10.9.7-src.zip/py4j/clientserver.py", line 539, in send_command
    raise Py4JNetworkError(
py4j.protocol.Py4JNetworkError: Error while sending or receiving


In [10]:
from pyspark.ml.feature import PCA
from pyspark.ml.clustering import KMeans
from pyspark.ml.evaluation import ClusteringEvaluator
from pyspark.ml.linalg import Vectors

#without PCA, this "works", need to change the featuresCol name and df name

# Determine the optimal number of clusters using the Elbow Method
cost = []
evaluator = ClusteringEvaluator()

for k in range(2, 11):
    print(k)
    kmeans = KMeans(featuresCol='pca_features', k=k)
    model = kmeans.fit(pca_df)
    predictions = model.transform(pca_df)
    silhouette = evaluator.evaluate(predictions)
    cost.append((k, silhouette))
    print(f"With K={k}, the Silhouette score is {silhouette}")

# Choose the best K (you can automate this step)
best_k = max(cost, key=lambda item: item[1])[0]
print(f"Best K found: {best_k}")

# Fit the final K-means model with the best K
kmeans = KMeans(featuresCol='pca_features', k=best_k)
model = kmeans.fit(pca_df)
predictions = model.transform(pca_df)

# Show the resulting clusters
predictions.select("id", "prediction").show()

# If you want to see the cluster centers
centers = model.clusterCenters()
print("Cluster Centers: ")
for center in centers:
    print(center)


Py4JJavaError: An error occurred while calling o155.fit.
: java.lang.OutOfMemoryError: Java heap space
	at breeze.linalg.svd$.breeze$linalg$svd$$doSVD_Double(svd.scala:94)
	at breeze.linalg.svd$Svd_DM_Impl$.apply(svd.scala:36)
	at breeze.linalg.svd$Svd_DM_Impl$.apply(svd.scala:35)
	at breeze.generic.UFunc.apply(UFunc.scala:47)
	at breeze.generic.UFunc.apply$(UFunc.scala:46)
	at breeze.linalg.svd$.apply(svd.scala:21)
	at org.apache.spark.mllib.linalg.distributed.RowMatrix.computePrincipalComponentsAndExplainedVariance(RowMatrix.scala:501)
	at org.apache.spark.mllib.feature.PCA.fit(PCA.scala:65)
	at org.apache.spark.ml.feature.PCA.fit(PCA.scala:93)
	at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
	at java.base/jdk.internal.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:77)
	at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.base/java.lang.reflect.Method.invoke(Method.java:568)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:374)
	at py4j.Gateway.invoke(Gateway.java:282)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.ClientServerConnection.waitForCommands(ClientServerConnection.java:182)
	at py4j.ClientServerConnection.run(ClientServerConnection.java:106)
	at java.base/java.lang.Thread.run(Thread.java:840)
