In [1]:
from pyspark import SparkContext
from pyspark.sql import SparkSession
from pyspark.sql.types import *
# Reuse the already-running SparkContext if there is one, otherwise create it.
sc = SparkContext.getOrCreate()
# Obtain (or create) the SparkSession through its builder.
ss = SparkSession.builder.getOrCreate()

## Create dataframe

In [2]:
# Read the raw text file into 4 partitions and parse every ", "-separated
# line into a list of floats (per the original note: 16 pixels plus the label).
pen_raw = (
    sc.textFile("../Data/penbased.dat", 4)
    .map(lambda line: [float(token) for token in line.split(", ")])
)

In [3]:
# Peek at the first parsed record to check the parsing step.
pen_raw.take(1)

[[47.0,
  100.0,
  27.0,
  81.0,
  57.0,
  37.0,
  26.0,
  0.0,
  0.0,
  23.0,
  56.0,
  53.0,
  100.0,
  90.0,
  40.0,
  98.0,
  8.0]]

In [4]:
# Build a DataFrame from the parsed RDD using an explicit schema.
from pyspark.sql.types import *
from pyspark.sql import Row

# Schema: nullable double columns pix1 .. pix16, followed by the nullable
# double "label" column.
penschema = StructType(
    [StructField("pix%d" % idx, DoubleType(), True) for idx in range(1, 17)]
    + [StructField("label", DoubleType(), True)]
)

# Each record contributes its first 17 positions (indices 0..16) as one Row.
dfpen = ss.createDataFrame(
    pen_raw.map(lambda rec: Row(*[rec[pos] for pos in range(17)])),
    penschema,
)

In [5]:
# Print the leading rows of the DataFrame as a text table.
dfpen.show()

+-----+-----+----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+
| pix1| pix2|pix3| pix4| pix5| pix6| pix7| pix8| pix9|pix10|pix11|pix12|pix13|pix14|pix15|pix16|label|
+-----+-----+----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+-----+
| 47.0|100.0|27.0| 81.0| 57.0| 37.0| 26.0|  0.0|  0.0| 23.0| 56.0| 53.0|100.0| 90.0| 40.0| 98.0|  8.0|
|  0.0| 89.0|27.0|100.0| 42.0| 75.0| 29.0| 45.0| 15.0| 15.0| 37.0|  0.0| 69.0|  2.0|100.0|  6.0|  2.0|
|  0.0| 57.0|31.0| 68.0| 72.0| 90.0|100.0|100.0| 76.0| 75.0| 50.0| 51.0| 28.0| 25.0| 16.0|  0.0|  1.0|
|  0.0|100.0| 7.0| 92.0|  5.0| 68.0| 19.0| 45.0| 86.0| 34.0|100.0| 45.0| 74.0| 23.0| 67.0|  0.0|  4.0|
|  0.0| 67.0|49.0| 83.0|100.0|100.0| 81.0| 80.0| 60.0| 60.0| 40.0| 40.0| 33.0| 20.0| 47.0|  0.0|  1.0|
|100.0|100.0|88.0| 99.0| 49.0| 74.0| 17.0| 47.0|  0.0| 16.0| 37.0|  0.0| 73.0| 16.0| 20.0| 20.0|  6.0|
|  0.0|100.0| 3.0| 72.0| 26.0| 35.0| 85.0| 35.0|100.0| 71.0| 73.0| 97.0| 

## Create dataframe with a feature vector and label

In [6]:
# Combine every column except the last one (the label) into a single
# "features" vector column, then keep only that vector and the label.
from pyspark.ml.feature import VectorAssembler

va = VectorAssembler(
    inputCols=dfpen.columns[:-1],
    outputCol="features",
)
penlpoints = (
    va.transform(dfpen)
    .select("features", "label")
)

## Split dataframe into training and test sets

In [7]:
# Create training and validation data with a weighted random split
# (weights 0.8 and 0.2).
# A fixed seed keeps the split -- and therefore every metric computed further
# down -- reproducible across "Restart Kernel -> Run All".
pendtsets = penlpoints.randomSplit([0.8, 0.2], seed=42)
# Both splits are reused (fit / transform), so cache them.
pendttrain = pendtsets[0].cache()
pendtvalid = pendtsets[1].cache()

## Create a RandomForestClassifier and build a model using the training dataset

In [8]:
# Train the model.
from pyspark.ml.classification import RandomForestClassifier

# maxDepth=20 allows deep trees; the seed pins the random choices made while
# growing the forest so repeated runs build the same model.
rf = RandomForestClassifier(maxDepth=20, seed=42)
rfmodel = rf.fit(pendttrain)

# Use the public `toDebugString` property rather than the private
# `_call_java` hook to dump the learned trees.
print(rfmodel.toDebugString)

RandomForestClassificationModel (uid=RandomForestClassifier_325d6fa69a8e) with 20 trees
  Tree 0 (weight 1.0):
    If (feature 13 <= 59.5)
     If (feature 13 <= 16.5)
      If (feature 0 <= 59.5)
       If (feature 6 <= 55.5)
        If (feature 12 <= 47.5)
         If (feature 9 <= 32.5)
          Predict: 1.0
         Else (feature 9 > 32.5)
          If (feature 14 <= 2.5)
           If (feature 8 <= 79.5)
            If (feature 2 <= 48.5)
             Predict: 3.0
            Else (feature 2 > 48.5)
             If (feature 1 <= 57.5)
              Predict: 8.0
             Else (feature 1 > 57.5)
              Predict: 5.0
           Else (feature 8 > 79.5)
            Predict: 3.0
          Else (feature 14 > 2.5)
           If (feature 10 <= 44.5)
            If (feature 12 <= 41.5)
             Predict: 1.0
            Else (feature 12 > 41.5)
             Predict: 2.0
           Else (feature 10 > 44.5)
            Predict: 8.0
        Else (feature 12 > 47.5)
         If (f

## Evaluate the model

In [9]:
# Apply the fitted model to the validation split to obtain predictions.
rfpredicts = rfmodel.transform(pendtvalid)

In [10]:
from pyspark.mllib.evaluation import MulticlassMetrics

# MulticlassMetrics consumes an RDD of (prediction, label) pairs, so the
# column order in the select matters.
prediction_label = rfpredicts.select("prediction", "label").rdd

metrics = MulticlassMetrics(prediction_label)

# The label-free precision()/recall()/fMeasure() calls were deprecated in
# Spark 2.0 and removed in Spark 3.0, where calling them without a label
# raises a TypeError. Without a label they all returned the overall accuracy,
# so read `metrics.accuracy` once and keep the original variable names and
# printed lines for compatibility.
precision = metrics.accuracy
recall = precision
f1Score = precision
confusionMetrics = metrics.confusionMatrix()

print("Summary Stats")
print("Precision = %s" % precision)
print("Recall = %s" % recall)
print("F1 Score = %s" % f1Score)
print("Weighted recall = %s" % metrics.weightedRecall)
print("Weighted precision = %s" % metrics.weightedPrecision)
print("Weighted F(1) Score = %s" % metrics.weightedFMeasure())
print("Weighted F(0.5) Score = %s" % metrics.weightedFMeasure(beta=0.5))
print("Weighted false positive rate = %s" % metrics.weightedFalsePositiveRate)
print("Confusion Metrics = \n%s" % confusionMetrics)

Summary Stats
Precision = 0.9822695035460993
Recall = 0.9822695035460993
F1 Score = 0.9822695035460993
Weighted recall = 0.9822695035460993
Weighted precision = 0.982426914273226
Weighted F(1) Score = 0.9822526335301773
Weighted F(0.5) Score = 0.9823342492644922
Weighted false positive rate = 0.0020549266415841234
Confusion Metrics = 
DenseMatrix([[189.,   0.,   0.,   0.,   2.,   0.,   0.,   0.,   1.,   0.],
             [  0., 197.,  10.,   0.,   0.,   0.,   0.,   3.,   0.,   0.],
             [  0.,   3., 207.,   0.,   0.,   0.,   0.,   0.,   0.,   0.],
             [  0.,   0.,   1., 189.,   0.,   1.,   0.,   0.,   0.,   1.],
             [  0.,   0.,   0.,   0., 214.,   0.,   0.,   1.,   0.,   2.],
             [  0.,   0.,   0.,   2.,   0., 187.,   0.,   0.,   0.,   1.],
             [  0.,   0.,   0.,   0.,   0.,   1., 176.,   0.,   0.,   0.],
             [  0.,   1.,   0.,   0.,   0.,   0.,   0., 202.,   0.,   0.],
             [  0.,   0.,   0.,   0.,   0.,   0.,   0.,   0., 1