In [1]:
import pandas as pd
import numpy as np

import findspark
findspark.init()
import pyspark
from pyspark.sql import SparkSession

from pyspark.ml.feature import VectorAssembler
from pyspark.ml.linalg import Vectors
from pyspark.ml.clustering import KMeans, BisectingKMeans

In [2]:
# setup spark
# getOrCreate() reuses a SparkContext that is already running, so re-executing this
# cell does not raise "Cannot run multiple SparkContexts at once".
sc = pyspark.SparkContext.getOrCreate()
ss = SparkSession(sc)

In [3]:
# Load the clustering CSV into a cached Spark DataFrame.
# The first line of the file is used as the header and column types are inferred.
fil = './Ex_Files_Spark_ML_AI/Ch03/03_02/clustering_dataset.csv'
data = (
    ss.read
    .csv(fil, inferSchema=True, header=True)
    .cache()
)

# Quick look: column names, schema, first row, and row count.
print(data.columns)
data.printSchema()
print(data.take(1))
print(data.count())

['col1', 'col2', 'col3']
root
 |-- col1: integer (nullable = true)
 |-- col2: integer (nullable = true)
 |-- col3: integer (nullable = true)

[Row(col1=7, col2=4, col3=1)]
75


In [4]:
# make the features vector
# Combine the three numeric columns into a single vector column named 'features'.
va = VectorAssembler(inputCols=['col1', 'col2', 'col3'], outputCol='features')

# Drop any 'features' column left over from a previous run of this cell first:
# VectorAssembler refuses to write an output column that already exists, so without
# this the cell fails when re-executed. drop() is a no-op if the column is absent.
data = va.transform(data.drop('features')).cache()
data.show()

+----+----+----+--------------+
|col1|col2|col3|      features|
+----+----+----+--------------+
|   7|   4|   1| [7.0,4.0,1.0]|
|   7|   7|   9| [7.0,7.0,9.0]|
|   7|   9|   6| [7.0,9.0,6.0]|
|   1|   6|   5| [1.0,6.0,5.0]|
|   6|   7|   7| [6.0,7.0,7.0]|
|   7|   9|   4| [7.0,9.0,4.0]|
|   7|  10|   6|[7.0,10.0,6.0]|
|   7|   8|   2| [7.0,8.0,2.0]|
|   8|   3|   8| [8.0,3.0,8.0]|
|   4|  10|   5|[4.0,10.0,5.0]|
|   7|   4|   5| [7.0,4.0,5.0]|
|   7|   8|   4| [7.0,8.0,4.0]|
|   2|   5|   1| [2.0,5.0,1.0]|
|   2|   6|   2| [2.0,6.0,2.0]|
|   2|   3|   8| [2.0,3.0,8.0]|
|   3|   9|   1| [3.0,9.0,1.0]|
|   4|   2|   9| [4.0,2.0,9.0]|
|   1|   7|   1| [1.0,7.0,1.0]|
|   6|   2|   3| [6.0,2.0,3.0]|
|   4|   1|   9| [4.0,1.0,9.0]|
+----+----+----+--------------+
only showing top 20 rows



In [5]:
# Now use k-means for clustering.
# define the model (3 clusters; fixed seed; predictions go into column 'kmeans')
kmeans = KMeans(k=3, seed=42, predictionCol='kmeans')

# Remove a stale 'kmeans' column left by an earlier run of this cell. Spark ML rejects
# a prediction column that already exists, so without this the cell cannot be re-run.
# drop() is a no-op when the column is absent (i.e. on the first run).
data = data.drop('kmeans')

# fit kmeans
kmodel = kmeans.fit(data)

# get cluster centroids
print(kmodel.clusterCenters())

# predict: appends the cluster index of each row as column 'kmeans'
data = kmodel.transform(data)

# talk: show a reproducible ~20% sample of the labelled rows
data.sample(withReplacement=False, fraction=0.2, seed=42).show()

[array([35.88461538, 31.46153846, 34.42307692]), array([5.12, 5.84, 4.84]), array([80.        , 79.20833333, 78.29166667])]
+----+----+----+-----------------+------+
|col1|col2|col3|         features|kmeans|
+----+----+----+-----------------+------+
|   7|   8|   2|    [7.0,8.0,2.0]|     1|
|   4|   2|   9|    [4.0,2.0,9.0]|     1|
|   6|   2|   3|    [6.0,2.0,3.0]|     1|
|  15|  23|  32| [15.0,23.0,32.0]|     0|
|  49|  29|  15| [49.0,29.0,15.0]|     0|
|  38|  27|  25| [38.0,27.0,25.0]|     0|
|  37|  39|  46| [37.0,39.0,46.0]|     0|
|  17|  29|  41| [17.0,29.0,41.0]|     0|
|  83|  72|  80| [83.0,72.0,80.0]|     2|
|  84|  90| 100|[84.0,90.0,100.0]|     2|
|  61|  82|  73| [61.0,82.0,73.0]|     2|
|  81|  60|  69| [81.0,60.0,69.0]|     2|
|  67|  80|  98| [67.0,80.0,98.0]|     2|
|  88|  68|  95| [88.0,68.0,95.0]|     2|
+----+----+----+-----------------+------+



In [6]:
# Now try bisecting k-means for clustering.
# define the model (3 clusters; fixed seed; predictions go into column 'bikmeans')
bkm = BisectingKMeans(k=3, seed=42, predictionCol='bikmeans')

# Remove a stale 'bikmeans' column left by an earlier run of this cell. Spark ML rejects
# a prediction column that already exists, so without this the cell cannot be re-run.
# drop() is a no-op when the column is absent (i.e. on the first run).
data = data.drop('bikmeans')

# fit bisecting k-means
# NOTE: this rebinds `kmodel`, replacing the KMeans model fitted in the previous cell.
kmodel = bkm.fit(data)

# get cluster centroids
print(kmodel.clusterCenters())

# predict: appends the cluster index of each row as column 'bikmeans'
data = kmodel.transform(data)

# talk: show a reproducible ~20% sample of the labelled rows
data.sample(withReplacement=False, fraction=0.2, seed=42).show()

[array([5.12, 5.84, 4.84]), array([35.88461538, 31.46153846, 34.42307692]), array([80.        , 79.20833333, 78.29166667])]
+----+----+----+-----------------+------+--------+
|col1|col2|col3|         features|kmeans|bikmeans|
+----+----+----+-----------------+------+--------+
|   7|   8|   2|    [7.0,8.0,2.0]|     1|       0|
|   4|   2|   9|    [4.0,2.0,9.0]|     1|       0|
|   6|   2|   3|    [6.0,2.0,3.0]|     1|       0|
|  15|  23|  32| [15.0,23.0,32.0]|     0|       1|
|  49|  29|  15| [49.0,29.0,15.0]|     0|       1|
|  38|  27|  25| [38.0,27.0,25.0]|     0|       1|
|  37|  39|  46| [37.0,39.0,46.0]|     0|       1|
|  17|  29|  41| [17.0,29.0,41.0]|     0|       1|
|  83|  72|  80| [83.0,72.0,80.0]|     2|       2|
|  84|  90| 100|[84.0,90.0,100.0]|     2|       2|
|  61|  82|  73| [61.0,82.0,73.0]|     2|       2|
|  81|  60|  69| [81.0,60.0,69.0]|     2|       2|
|  67|  80|  98| [67.0,80.0,98.0]|     2|       2|
|  88|  68|  95| [88.0,68.0,95.0]|     2|       2|
+----+---

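# Open question: how is bisecting k-means hierarchical? It is divisive (top-down): all points start in one cluster, which is recursively split in two until there are k leaf clusters; clusterCenters() returns only those leaf centers, so the result looks like a flat clustering.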