In [1]:
from pyspark import SparkContext
from pyspark.sql.types import *
sc = SparkContext.getOrCreate()
sqlContext = SQLContext(sc)

In [2]:
#https://spark.apache.org/docs/latest/api/python/pyspark.ml.html?highlight=kmeans#pyspark.ml.clustering.KMeans

In [3]:
#Load the data and create an RDD (16 pixels and label)
pen_raw = sc.textFile("../Data/penbased.dat", 4).map(lambda x:  x.split(", ")).map(lambda row: [float(x) for x in row])

In [4]:
#Create a DataFrame
from pyspark.sql.types import *
from pyspark.sql import Row
penschema = StructType([
    StructField("pix1",DoubleType(),True),
    StructField("pix2",DoubleType(),True),
    StructField("pix3",DoubleType(),True),
    StructField("pix4",DoubleType(),True),
    StructField("pix5",DoubleType(),True),
    StructField("pix6",DoubleType(),True),
    StructField("pix7",DoubleType(),True),
    StructField("pix8",DoubleType(),True),
    StructField("pix9",DoubleType(),True),
    StructField("pix10",DoubleType(),True),
    StructField("pix11",DoubleType(),True),
    StructField("pix12",DoubleType(),True),
    StructField("pix13",DoubleType(),True),
    StructField("pix14",DoubleType(),True),
    StructField("pix15",DoubleType(),True),
    StructField("pix16",DoubleType(),True),
    StructField("label",DoubleType(),True)
])

dfpen = sqlContext.createDataFrame(pen_raw.map(lambda x : Row(x[0],x[1],x[2],x[3],x[4],x[5],x[6],x[7],x[8],x[9],x[10],x[11],x[12],x[13],x[14],x[15],x[16])), penschema)

In [27]:
# Merging the data with Vector Assembler.
from pyspark.ml.feature import VectorAssembler
va = VectorAssembler(outputCol="features", inputCols=dfpen.columns[0:-1]) #except the last col.
penlpoints = va.transform(dfpen)

In [28]:
from pyspark.ml.clustering import KMeans
kmeans =  KMeans(k = 10, maxIter = 200, tol = 0.1) # k = 10 as there are 10 different handwritten numbers.
model = kmeans.fit(penlpoints)

In [18]:
# Evaluate clustering by computing Within Set Sum of Squared Errors
wssse = model.computeCost(penlpoints) 
print("Within Set Sum of Squared Errors = " + str(wssse))

Within Set Sum of Squared Errors = 46160858.0701


In [23]:
# Average distance from the center (max = 100)
import math
print("Average distance from the center = " + str(math.sqrt(wssse/pen_raw.count())))

Average distance from the center = 68.2427139368


In [8]:
# Shows the result.
centers = model.clusterCenters()
print("Cluster Centers: ")
for center in centers:
    print(center)

Cluster Centers: 
[ 88.00580833  97.78993224  52.64762827  87.28848015  21.28944821
  59.95062924   7.01548887  28.31945789  32.33881897   4.47918683
  79.51016457  11.47821878  62.06582769  30.81219748  13.39303001
  24.91674734]
[ 27.44980443  83.71968709  63.03259452  94.55997392  85.55280313
  87.22946545  55.13233377  65.58148631  69.64602347  45.38787484
  87.32920469  22.85397653  52.22946545   7.26597132   4.30247718
   9.58083442]
[ 44.53996448  98.30195382  13.6660746   77.04795737   5.36234458
  49.47424512  69.21669627  47.98401421  96.60923623  65.72824156
  77.97513321  67.89698046  62.92717584  34.38543517  50.60035524
   0.34280639]
[ 87.46744186  87.90813953  58.16046512  92.40232558  35.72325581
  79.76046512  56.75813953  74.44883721  80.47674419  63.75232558
  81.59302326  32.18372093  48.70232558   7.74534884   4.32906977
   3.93604651]
[  3.26024096  60.78554217  30.41204819  72.74578313  72.10240964
  89.77951807  91.46746988  94.43493976  79.77228916  73.8590361

In [34]:
model.transform(penlpoints).select('label', 'prediction').rdd.map(lambda x : (x,1)).countByKey()  
# prediction is a group, not an actual label.

defaultdict(int,
            {Row(label=0.0, prediction=0): 31,
             Row(label=0.0, prediction=2): 12,
             Row(label=0.0, prediction=3): 1,
             Row(label=0.0, prediction=4): 2,
             Row(label=0.0, prediction=5): 352,
             Row(label=0.0, prediction=6): 630,
             Row(label=0.0, prediction=8): 6,
             Row(label=0.0, prediction=9): 3,
             Row(label=1.0, prediction=0): 7,
             Row(label=1.0, prediction=1): 70,
             Row(label=1.0, prediction=3): 66,
             Row(label=1.0, prediction=4): 573,
             Row(label=1.0, prediction=8): 304,
             Row(label=1.0, prediction=9): 2,
             Row(label=2.0, prediction=4): 16,
             Row(label=2.0, prediction=8): 1006,
             Row(label=3.0, prediction=1): 919,
             Row(label=3.0, prediction=3): 2,
             Row(label=3.0, prediction=4): 19,
             Row(label=3.0, prediction=8): 1,
             Row(label=3.0, prediction=9): 1