In [1]:
from pyspark import SparkContext
from pyspark.sql import SparkSession
from pyspark.sql.types import *
sc = SparkContext.getOrCreate()
ss = SparkSession.builder.getOrCreate()

## Create dataframe

In [2]:
#Load the data and create an RDD (16 pixels and label)
pen_raw = sc.textFile("../Data/penbased.dat", 4).map(lambda x:  x.split(", ")).map(lambda row: [float(x) for x in row])

In [3]:
#Create a DataFrame
from pyspark.sql.types import *
from pyspark.sql import Row
penschema = StructType([
    StructField("pix1",DoubleType(),True),
    StructField("pix2",DoubleType(),True),
    StructField("pix3",DoubleType(),True),
    StructField("pix4",DoubleType(),True),
    StructField("pix5",DoubleType(),True),
    StructField("pix6",DoubleType(),True),
    StructField("pix7",DoubleType(),True),
    StructField("pix8",DoubleType(),True),
    StructField("pix9",DoubleType(),True),
    StructField("pix10",DoubleType(),True),
    StructField("pix11",DoubleType(),True),
    StructField("pix12",DoubleType(),True),
    StructField("pix13",DoubleType(),True),
    StructField("pix14",DoubleType(),True),
    StructField("pix15",DoubleType(),True),
    StructField("pix16",DoubleType(),True),
    StructField("label",DoubleType(),True)
])

dfpen = ss.createDataFrame(pen_raw.map(lambda x : Row(x[0],x[1],x[2],x[3],x[4],x[5],x[6],x[7],x[8],x[9],x[10],x[11],x[12],x[13],x[14],x[15],x[16])), penschema)

## Split dataframe into training and test sets

In [4]:
# Create Training and Test data.
pendtsets = dfpen.randomSplit([0.8, 0.2])
pendttrain = pendtsets[0].cache()
pendtvalid = pendtsets[1].cache()

## Define transformer and estimator and add to a pipeline.

In [5]:
# Transformer - Vector Assembler.
from pyspark.ml.feature import VectorAssembler
va = VectorAssembler(outputCol="features", inputCols=dfpen.columns[0:-1]) #except the last col.

In [6]:
# Estimator - DecisionTreeClassifier which creates a transformer (Decision Tree Classifier model)
from pyspark.ml.classification import DecisionTreeClassifier
dt = DecisionTreeClassifier(maxDepth=20, maxBins= 32, minInstancesPerNode=1, minInfoGain = 0)

In [7]:
# Fit the pipeline to training documents.
from pyspark.ml import Pipeline
pipeline = Pipeline(stages=[va,dt])

## Fit the training dataset to pipeline and create a model

In [8]:
dtmodel = pipeline.fit(pendttrain)

## Apply the model to the training data set

In [9]:
dtpredicts = dtmodel.transform(pendtvalid)

## Evaluate the model

In [10]:
from pyspark.mllib.evaluation import MulticlassMetrics

#prediction and label
prediction_label = dtpredicts.select("prediction", "label").rdd

metrics = MulticlassMetrics(prediction_label)

precision = metrics.precision()
recall = metrics.recall()
f1Score = metrics.fMeasure()
confusionMetrics = metrics.confusionMatrix()

print("Summary Stats")
print("Precision = %s" % precision)
print("Recall = %s" % recall)
print("F1 Score = %s" % f1Score)
print("Weighted recall = %s" % metrics.weightedRecall)
print("Weighted precision = %s" % metrics.weightedPrecision)
print("Weighted F(1) Score = %s" % metrics.weightedFMeasure())
print("Weighted F(0.5) Score = %s" % metrics.weightedFMeasure(beta=0.5))
print("Weighted false positive rate = %s" % metrics.weightedFalsePositiveRate)
print("Confusion Metrics = \n%s" % confusionMetrics)

Summary Stats
Precision = 0.9589384076114171
Recall = 0.9589384076114171
F1 Score = 0.9589384076114171
Weighted recall = 0.9589384076114172
Weighted precision = 0.958984465987408
Weighted F(1) Score = 0.958922953503431
Weighted F(0.5) Score = 0.958950599535092
Weighted false positive rate = 0.004532062852722623
Confusion Metrics = 
DenseMatrix([[209.,   1.,   0.,   0.,   2.,   0.,   0.,   0.,   0.,   0.],
             [  0., 188.,   3.,   5.,   2.,   0.,   0.,   0.,   2.,   2.],
             [  0.,  12., 176.,   1.,   0.,   0.,   0.,   2.,   0.,   0.],
             [  0.,   0.,   2., 189.,   0.,   1.,   0.,   2.,   0.,   1.],
             [  0.,   0.,   0.,   0., 211.,   0.,   2.,   0.,   0.,   3.],
             [  0.,   0.,   1.,   2.,   0., 184.,   1.,   0.,   1.,   2.],
             [  1.,   0.,   1.,   0.,   1.,   0., 190.,   0.,   0.,   0.],
             [  0.,   3.,   3.,   0.,   0.,   0.,   0., 198.,   1.,   1.],
             [  2.,   0.,   0.,   0.,   0.,   2.,   1.,   1., 187.