In [1]:
from pyspark import SparkContext
from pyspark.sql import SQLContext
from pyspark.sql.types import *
# Reuse the running SparkContext if there is one, otherwise start a new one.
sc = SparkContext.getOrCreate()
# SQLContext is the DataFrame entry point used by later cells (createDataFrame).
# It must be imported explicitly: `from pyspark.sql.types import *` does not
# provide it, so without the import above this line raises NameError on a
# fresh kernel.
sqlContext = SQLContext(sc)

In [2]:
# Load the data and create an RDD (16 pixels and label).
def _parse_pen_line(line):
    """Split one ", "-separated text line and convert every field to float."""
    return [float(field) for field in line.split(", ")]

pen_raw = sc.textFile("../Data/penbased.dat", minPartitions=4).map(_parse_pen_line)

In [8]:
# Peek at the first parsed record (a one-element list) to sanity-check parsing.
pen_raw.take(num=1)

[[47.0,
  100.0,
  27.0,
  81.0,
  57.0,
  37.0,
  26.0,
  0.0,
  0.0,
  23.0,
  56.0,
  53.0,
  100.0,
  90.0,
  40.0,
  98.0,
  8.0]]

In [3]:
# Create a DataFrame: columns pix1..pix16 followed by "label", all nullable
# doubles, in the same positional order as the fields of each parsed line.
from pyspark.sql.types import StructType, StructField, DoubleType
from pyspark.sql import Row

N_PIXELS = 16
penschema = StructType(
    [StructField("pix%d" % i, DoubleType(), True) for i in range(1, N_PIXELS + 1)]
    + [StructField("label", DoubleType(), True)]
)

# Unpack each parsed record positionally into a Row. Unlike hand-indexing
# x[0]..x[16], this does not silently drop extra fields: a record whose width
# differs from the schema is rejected by createDataFrame's schema check.
dfpen = sqlContext.createDataFrame(pen_raw.map(lambda rec: Row(*rec)), penschema)

In [9]:
dfpen.show()

+-----+-----+----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+
| pix1| pix2|pix3| pix4| pix5| pix6| pix7| pix8| pix9|pix10|pix11|pix12|pix13|pix14|pix15|pix16|label|
+-----+-----+----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+
| 47.0|100.0|27.0| 81.0| 57.0| 37.0| 26.0|  0.0|  0.0| 23.0| 56.0| 53.0|100.0| 90.0| 40.0| 98.0|  8.0|
|  0.0| 89.0|27.0|100.0| 42.0| 75.0| 29.0| 45.0| 15.0| 15.0| 37.0|  0.0| 69.0|  2.0|100.0|  6.0|  2.0|
|  0.0| 57.0|31.0| 68.0| 72.0| 90.0|100.0|100.0| 76.0| 75.0| 50.0| 51.0| 28.0| 25.0| 16.0|  0.0|  1.0|
|  0.0|100.0| 7.0| 92.0|  5.0| 68.0| 19.0| 45.0| 86.0| 34.0|100.0| 45.0| 74.0| 23.0| 67.0|  0.0|  4.0|
|  0.0| 67.0|49.0| 83.0|100.0|100.0| 81.0| 80.0| 60.0| 60.0| 40.0| 40.0| 33.0| 20.0| 47.0|  0.0|  1.0|
|100.0|100.0|88.0| 99.0| 49.0| 74.0| 17.0| 47.0|  0.0| 16.0| 37.0|  0.0| 73.0| 16.0| 20.0| 20.0|  6.0|
|  0.0|100.0| 3.0| 72.0| 26.0| 35.0| 85.0| 35.0|100.0| 71.0| 73.0| 97.0| 

In [4]:
# Merge the feature columns with VectorAssembler: every column except the
# trailing "label" is packed into a single "features" vector.
from pyspark.ml.feature import VectorAssembler
feature_cols = dfpen.columns[:-1]  # all columns but the last (label)
va = VectorAssembler(inputCols=feature_cols, outputCol="features")
penlpoints = va.transform(dfpen).select(["features", "label"])

In [5]:
# Create training and validation data with an approximately 80/20 random split.
# The split is seeded so that, given the same input data and partitioning,
# Restart & Run All reproduces the same split (and therefore the same metrics).
SPLIT_SEED = 42
pendtsets = penlpoints.randomSplit([0.8, 0.2], seed=SPLIT_SEED)
pendttrain = pendtsets[0].cache()
pendtvalid = pendtsets[1].cache()

In [27]:
# Train the model (default featuresCol="features", labelCol="label").
from pyspark.ml.classification import RandomForestClassifier
rf = RandomForestClassifier(maxDepth=20)
rfmodel = rf.fit(pendttrain)
# Dump the learned trees through the public `toDebugString` property instead of
# the private `_call_java` bridge. print(...) with a single argument behaves the
# same on Python 2 and 3; the old `print x` statement is a SyntaxError on 3.
print(rfmodel.toDebugString)

RandomForestClassificationModel (uid=RandomForestClassifier_4cdbad0e1dd83d79951b) with 20 trees
  Tree 0 (weight 1.0):
    If (feature 15 <= 21.0)
     If (feature 1 <= 99.0)
      If (feature 14 <= 97.0)
       If (feature 7 <= 78.0)
        If (feature 4 <= 56.0)
         If (feature 9 <= 62.0)
          If (feature 9 <= 27.0)
           If (feature 13 <= 5.0)
            Predict: 1.0
           Else (feature 13 > 5.0)
            If (feature 14 <= 46.0)
             Predict: 6.0
            Else (feature 14 > 46.0)
             Predict: 4.0
          Else (feature 9 > 27.0)
           If (feature 11 <= 44.0)
            If (feature 5 <= 76.0)
             If (feature 0 <= 19.0)
              Predict: 3.0
             Else (feature 0 > 19.0)
              If (feature 8 <= 79.0)
               If (feature 7 <= 56.0)
                Predict: 5.0
               Else (feature 7 > 56.0)
                Predict: 9.0
              Else (feature 8 > 79.0)
               If (feature 6 <= 53.0

In [28]:
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
# Score the held-out split, then report the misclassification rate (1 - accuracy).
rfpredicts = rfmodel.transform(pendtvalid)
evaluator = MulticlassClassificationEvaluator(
    metricName="accuracy",
    labelCol="label",
    predictionCol="prediction",
)
accuracy = evaluator.evaluate(rfpredicts)
test_error = 1.0 - accuracy
print("Test Error = %g" % test_error)

Test Error = 0.0177215


In [29]:
# Confusion matrix as sparse counts: {Row(label, prediction): n} for every
# (label, prediction) pair that actually occurs in the validation predictions.
label_pred_rows = rfpredicts.select('label', 'prediction').rdd
label_pred_rows.countByValue()

defaultdict(int,
            {Row(label=0.0, prediction=0.0): 206,
             Row(label=0.0, prediction=4.0): 1,
             Row(label=0.0, prediction=8.0): 1,
             Row(label=0.0, prediction=9.0): 2,
             Row(label=1.0, prediction=1.0): 198,
             Row(label=1.0, prediction=2.0): 7,
             Row(label=1.0, prediction=3.0): 1,
             Row(label=1.0, prediction=5.0): 1,
             Row(label=2.0, prediction=1.0): 3,
             Row(label=2.0, prediction=2.0): 186,
             Row(label=2.0, prediction=3.0): 1,
             Row(label=2.0, prediction=7.0): 1,
             Row(label=3.0, prediction=1.0): 1,
             Row(label=3.0, prediction=3.0): 195,
             Row(label=3.0, prediction=7.0): 1,
             Row(label=4.0, prediction=4.0): 204,
             Row(label=4.0, prediction=9.0): 1,
             Row(label=5.0, prediction=3.0): 2,
             Row(label=5.0, prediction=5.0): 174,
             Row(label=5.0, prediction=8.0): 1,
           

In [18]:
# Alternative: evaluate on the validation split with the RDD-based MLlib API,
# which exposes the full dense confusion matrix (rows = true label,
# columns = predicted label, both in ascending label order).
from pyspark.mllib.evaluation import MulticlassMetrics
rfpredicts = rfmodel.transform(pendtvalid)
# MulticlassMetrics expects (prediction, label) pairs -- prediction first.
rfresrdd = rfpredicts.select("prediction", "label").rdd
rfmm = MulticlassMetrics(rfresrdd)
print(rfmm.confusionMatrix())

DenseMatrix([[ 206.,    0.,    0.,    0.,    1.,    0.,    0.,    0.,    1.,
                 2.],
             [   0.,  198.,    7.,    1.,    0.,    1.,    0.,    0.,    0.,
                 0.],
             [   0.,    3.,  186.,    1.,    0.,    0.,    0.,    1.,    0.,
                 0.],
             [   0.,    1.,    0.,  195.,    0.,    0.,    0.,    1.,    0.,
                 0.],
             [   0.,    0.,    0.,    0.,  204.,    0.,    0.,    0.,    0.,
                 1.],
             [   0.,    0.,    0.,    2.,    0.,  174.,    0.,    0.,    1.,
                 2.],
             [   1.,    0.,    0.,    0.,    0.,    1.,  196.,    0.,    0.,
                 0.],
             [   0.,    0.,    0.,    0.,    0.,    0.,    0.,  202.,    1.,
                 0.],
             [   1.,    0.,    0.,    0.,    0.,    1.,    0.,    2.,  196.,
                 0.],
             [   0.,    0.,    0.,    0.,    0.,    0.,    0.,    1.,    1.,
               183.]])
