In [1]:
from pyspark.ml.clustering import KMeans
from pyspark.ml.functions import array_to_vector
from pyspark.sql import functions as F

In [None]:
# Relative base path; the dataset and model paths in this notebook are built from it.
ROOT = "../"

In [None]:
# Load the pre-computed text vectors and show a sample.
# `spark` is expected to be provided by the notebook environment.
#
# JSON carries no ML vector type, so `output_vectors` is presumably inferred
# as array<double> on read (TODO confirm with vectorized_df.printSchema()).
# The rows printed further down show it as DenseVector, so the conversion is
# done here, as part of the load, to keep the notebook reproducible from a
# fresh kernel. Reading and converting in one expression keeps the cell
# idempotent on re-run.
vectorized_df = (
    spark.read.json(f'{ROOT}/dataset/combined/combined_vectors')
    .withColumn('output_vectors', array_to_vector(F.col('output_vectors')))
)
vectorized_df.show()

+--------------------+--------------------+
|        cleaned_text|      output_vectors|
+--------------------+--------------------+
|[one, dead, hurt,...|[0.03610290982760...|
|[shelf, silver, s...|[-0.4456068724393...|
|[year, satyam, fr...|[-0.4698768357435...|
|       [sonic, boom]|[-0.3521701395511...|
|          [cat, men]|[0.07913395762443...|
|[rebuild, holmes,...|[0.09574120491743...|
|[be, mahindras, s...|[-0.1619699504226...|
|   [form, substance]|[-0.2782560214400...|
|[medical, insuran...|[-0.4382889032363...|
|[shah, rukh, twee...|[-0.0087734460830...|
|[trouble, snow, a...|[-0.1241957495609...|
|[next, bill, gate...|[-0.3776799477636...|
|[indian, villages...|[-0.3177983223771...|
|[indian, youth, s...|[-0.1780846191104...|
|[temple, sao, jos...|[-0.0150217153131...|
|[govt, ease, educ...|[-0.6779461907488...|
|[india, beat, tur...|[-0.3868040855973...|
|[musharraf, meet,...|[-0.3752494673244...|
|[india, war, mong...|[-0.1327664434909...|
|[us-pakistan, bic...|[-0.157507

In [5]:
# Number of rows in the loaded dataset; count() is a Spark action, so this
# runs a job over the whole input.
total_rows = vectorized_df.count()
total_rows

40954102

In [6]:
# Fit a 5-cluster KMeans model on the text vectors.
# featuresCol must name a column that exists in vectorized_df; its columns are
# `cleaned_text` and `output_vectors`, so the features come from
# `output_vectors` (there is no `article_embedding` column to fit on).
# seed is fixed so the cluster assignment is repeatable between runs.
kmeans = KMeans(featuresCol='output_vectors', predictionCol='prediction', k=5, maxIter=10, seed=1)
model = kmeans.fit(vectorized_df)

In [8]:
# Inspect the centres of the fitted clusters.
cluster_centers = model.clusterCenters()
cluster_centers

[array([-0.62283286]),
 array([0.23639457]),
 array([-0.36741422]),
 array([-0.17766151]),
 array([0.01256126])]

In [9]:
# Assign every row to a cluster; transform() adds the `prediction` column
# configured on the estimator.
result_df = model.transform(dataset=vectorized_df)

In [10]:
# 0 -> pop culture
# 1 -> finance
# 2 -> politics
# 3 -> breaking news
# 4 -> sports/entertainment

In [11]:
# Sample ten rows assigned to cluster 3 for manual inspection.
cluster_3_rows = result_df.where(F.col('prediction') == 3)
cluster_3_rows.head(10)

[Row(cleaned_text=['be', 'mahindras', 'sell', 'lemon'], output_vectors=DenseVector([-0.162]), prediction=3),
 Row(cleaned_text=['trouble', 'snow', 'afghanistan'], output_vectors=DenseVector([-0.1242]), prediction=3),
 Row(cleaned_text=['indian', 'youth', 'set', 'australia'], output_vectors=DenseVector([-0.1781]), prediction=3),
 Row(cleaned_text=['india', 'war', 'monger', 'country', 'antony'], output_vectors=DenseVector([-0.1328]), prediction=3),
 Row(cleaned_text=['us-pakistan', 'bicker', 'get', 'ugly'], output_vectors=DenseVector([-0.1575]), prediction=3),
 Row(cleaned_text=['india', 'aim', 'upstage', 'rampage', 'lankans'], output_vectors=DenseVector([-0.1428]), prediction=3),
 Row(cleaned_text=['india', 'war', 'monger', 'country', 'antony'], output_vectors=DenseVector([-0.1328]), prediction=3),
 Row(cleaned_text=['ibf', 'float', 'tender', 'exclusive', 'market', 'agent'], output_vectors=DenseVector([-0.1654]), prediction=3),
 Row(cleaned_text=["'re", 'vigorously', 'follow', 'attack',

In [12]:
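# 0 -> (not labelled yet)
# 1 -> technology?
# 2 -> politics
# 3 -> finance
# 4 -> sports
# 8 -> health  (NOTE: the model in this notebook is fit with k=5, so predictions are 0-4; confirm which run cluster 8 refers to)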

In [None]:
# Persist the fitted model, replacing anything already saved at this path.
model_path = f'{ROOT}/models/k_means'
model.write().overwrite().save(model_path)