In [1]:

from pyspark.sql import SparkSession
from pyspark.ml import Pipeline
from pyspark.ml.classification import DecisionTreeClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml.tuning import ParamGridBuilder, TrainValidationSplit
from pyspark.ml.feature import StandardScaler




# Create (or reuse) the Spark session that every stage below runs on.
spark = (
    SparkSession.builder
    .appName("Spark ML App")
    .getOrCreate()
)

# Pendigits training data in LIBSVM format: a "label" column plus a sparse
# "features" vector column.
trainingData = spark.read.format("libsvm").load("resources/pendigits")
# describe() only summarizes non-vector columns, so this reports on "label".
print(trainingData.describe().toPandas().transpose())

# Without "numFeatures" the LIBSVM reader infers the vector width from the
# largest feature index present in each file independently. Because the test
# file is loaded separately, pin it to the training width so both DataFrames
# carry vectors of the same size; a mismatch would break (or silently
# misalign) the pipeline fitted on the training vectors.
numFeatures = trainingData.first()["features"].size
testingData = (
    spark.read.format("libsvm")
    .option("numFeatures", numFeatures)
    .load("resources/pendigits.t")
)
trainingData.show(truncate=False)
# Stage 1: center and scale each raw feature into a new "std_features" column.
standardizer = StandardScaler(
    inputCol='features',
    outputCol='std_features',
    withMean=True,
    withStd=True,
)

# Stage 2: decision tree trained on the standardized vectors.
# NOTE(review): tree splits are per-feature thresholds, so per-feature
# affine scaling is unlikely to change accuracy; the scaler is kept as-is.
dt = DecisionTreeClassifier(featuresCol="std_features", labelCol="label")
pipeline = Pipeline(stages=[standardizer, dt])

# Fit on the training data, then predict on the separate test file.
dtModel = pipeline.fit(trainingData)
dtPredictions = dtModel.transform(testingData)
dtPredictions.select("prediction", "label", "std_features").show(5)

# Accuracy = fraction of test rows whose prediction equals the label.
evaluator = MulticlassClassificationEvaluator(
    metricName="accuracy",
    labelCol="label",
    predictionCol="prediction",
)
accuracy = evaluator.evaluate(dtPredictions)

print(f"Accuracy on test data = {accuracy:g}")
# Candidate tree depths to compare; every other pipeline setting is fixed.
paramGrid = (
    ParamGridBuilder()
    .addGrid(dt.maxDepth, [14, 18, 24, 30])
    .build()
)

# Split the training data once (80% train / 20% validation) and keep the
# depth that scores best under the accuracy evaluator defined above.
# A fixed seed makes the train/validation split reproducible across runs.
tvs = TrainValidationSplit(estimator=pipeline,
                           estimatorParamMaps=paramGrid,
                           evaluator=evaluator,
                           trainRatio=0.8,
                           seed=42)

tvsModel = tvs.fit(trainingData)

# validationMetrics is ordered like estimatorParamMaps, so print each
# parameter setting next to the validation score it produced.
for paramMap, metric in zip(paramGrid, tvsModel.validationMetrics):
    settings = {param.name: value for param, value in paramMap.items()}
    print(settings, f"validation accuracy = {metric:g}")

# Score the selected model on the held-out test file. Previously the test
# predictions were displayed but never evaluated.
prediction = tvsModel.transform(testingData)
testAccuracy = evaluator.evaluate(prediction)
print(f"Tuned model accuracy on test data = {testAccuracy:g}")

# Show only the informative columns; the full output also includes wide
# rawPrediction/probability vectors.
prediction.select("prediction", "label", "std_features").show(truncate=False)


             0                  1                  2    3    4
summary  count               mean             stddev  min  max
label     7494  4.430878035761943  2.876980684619264  0.0  9.0
+-----+------------------------------------------------------------------------------------------------------------------+
|label|features                                                                                                          |
+-----+------------------------------------------------------------------------------------------------------------------+
|8.0  |(16,[0,1,2,3,4,5,6,9,10,11,12,13,14,15],[47.0,100.0,27.0,81.0,57.0,37.0,26.0,23.0,56.0,53.0,100.0,90.0,40.0,98.0])|
|2.0  |(16,[1,2,3,4,5,6,7,8,9,10,12,13,14,15],[89.0,27.0,100.0,42.0,75.0,29.0,45.0,15.0,15.0,37.0,69.0,2.0,100.0,6.0])   |
|1.0  |(16,[1,2,3,4,5,6,7,8,9,10,11,12,13,14],[57.0,31.0,68.0,72.0,90.0,100.0,100.0,76.0,75.0,50.0,51.0,28.0,25.0,16.0]) |
|4.0  |(16,[1,2,3,4,5,6,7,8,9,10,11,12,13,14],[100.0,7.0,92.0,5.0,68.0,19