In [1]:
from pyspark.sql import functions as F
import plotly.express as px
from pyspark.ml.clustering import KMeans
from pyspark.sql.types import *

In [2]:
# Read the combined feature vectors from their JSON export on shared storage.
vectors_path = '/common/users/shared/cs543_fall22_group3/combined/combined_vectors'
vectors = spark.read.json(vectors_path)

In [17]:
# Sanity check: number of rows loaded into `vectors`.
vectors.count()

1001

In [14]:
# Summarise each feature vector as [average, max, min].
def stats(line):
    """Collapse one feature record into three summary statistics.

    Args:
        line: An indexable record whose element at position 1 is the
            sequence of numeric vector values (presumably the ``values``
            field of the JSON-decoded vector struct -- confirm against the
            input schema).

    Returns:
        ``[average, max_val, min_val]`` as Python floats, in that order.

    Raises:
        ZeroDivisionError: If the value sequence is empty.
    """
    values = line[1]

    # Cast every result to float: this function backs a FloatType UDF, and
    # PySpark does not coerce Python ints for a float return type -- an
    # all-integer vector would otherwise yield int max/min and null elements.
    average = float(sum(values) / len(values))
    max_val = float(max(values))
    min_val = float(min(values))

    return [average, max_val, min_val]

# Expose `stats` to Spark as a UDF returning a non-null-element float array,
# then overwrite "features" with the per-row [average, max, min] summary.
stats_udf = F.udf(
    stats,
    ArrayType(FloatType(), containsNull=False),
)
processed_df = vectors.withColumn("features", stats_udf("features"))

In [15]:
# Confirm that "features" now has the array-of-float type declared for stats_udf.
processed_df.printSchema()

root
 |-- features: array (nullable = true)
 |    |-- element: float (containsNull = false)
 |-- weighCol: double (nullable = true)



In [16]:
# Preview the summarised rows (show() prints the first 20 and truncates long cells).
processed_df.show()

+--------------------+--------+
|            features|weighCol|
+--------------------+--------+
|[-0.5102026, 0.73...|     1.0|
|[-0.30525059, 0.5...|     1.0|
|[-0.1497361, 0.80...|     1.0|
|[-0.4475176, 0.26...|     1.0|
|[0.002356287, 1.0...|     1.0|
|[0.067357205, 0.6...|     1.0|
|[-0.12473157, 0.6...|     1.0|
|[-0.051773936, 0....|     1.0|
|[-0.3069521, 0.35...|     1.0|
|[-0.16905668, 0.4...|     1.0|
|[-0.22439532, 0.5...|     1.0|
|[0.07051405, 0.80...|     1.0|
|[-0.31240213, 0.6...|     1.0|
|[0.039188497, 0.6...|     1.0|
|[-0.047275685, 0....|     1.0|
|[0.003494953, 0.7...|     1.0|
|[0.35350364, 1.04...|     1.0|
|[-0.05593809, 0.8...|     1.0|
|[-0.24546224, 0.5...|     1.0|
|[0.12138077, 0.71...|     1.0|
+--------------------+--------+
only showing top 20 rows



In [17]:
# K-means with 5 clusters, a fixed seed, and at most 10 iterations.
# "weighCol" is spelled exactly as the column appears in the DataFrame schema.
kmeans = KMeans(k=5, seed=1, weightCol="weighCol", maxIter=10)
model = kmeans.fit(processed_df)

In [18]:
# Fitted cluster centroids. Each centre has three coordinates, in the order
# the "features" array was built by stats(): [average, max, min].
centers = model.clusterCenters()
print(centers)

[array([-0.4042186 ,  0.05086916, -0.92345153]), array([-0.67591357,  0.26159427, -6.74898863]), array([ 0.16861255,  0.76447216, -0.32973419]), array([-0.09461035,  0.28119146, -0.46025592]), array([-0.10190278,  0.58812955, -0.87756608])]
