In [1]:
import findspark
findspark.init()

In [2]:
from pyspark.sql import SparkSession

In [3]:
spark = SparkSession.builder.appName('TreeMethods').getOrCreate()

In [4]:
from pyspark.ml import pipeline
from pyspark.ml.classification import (RandomForestClassifier,
                                      GBTClassifier,
                                      DecisionTreeClassifier)

In [5]:
data = spark.read.format('libsvm').load('../../SparkData/sample_libsvm_data.txt')

In [6]:
train_data, test_data = data.randomSplit([0.7, 0.3])

In [8]:
dtc = DecisionTreeClassifier()
rfc = RandomForestClassifier(numTrees=100)
gbt = GBTClassifier()

In [9]:
dtc_model = dtc.fit(train_data)
rfc_model = rfc.fit(train_data)
gbt_model = gbt.fit(train_data)

In [10]:
dtc_preds = dtc_model.transform(test_data)
rfc_preds = rfc_model.transform(test_data)
gbt_preds = gbt_model.transform(test_data)

In [12]:
dtc_preds.show()

+-----+--------------------+-------------+-----------+----------+
|label|            features|rawPrediction|probability|prediction|
+-----+--------------------+-------------+-----------+----------+
|  0.0|(692,[98,99,100,1...|    [0.0,1.0]|  [0.0,1.0]|       1.0|
|  0.0|(692,[100,101,102...|   [30.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[124,125,126...|   [30.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[124,125,126...|   [30.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[124,125,126...|   [30.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[126,127,128...|   [30.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[126,127,128...|   [30.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[126,127,128...|   [30.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[128,129,130...|   [30.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[129,130,131...|   [30.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[151,152,153...|   [30.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[154,155,156...|   [30.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(69

In [13]:
# with multiclass evaluator you can access more metrics
# as opposed to binary evaluation
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

In [14]:
acc_eval = MulticlassClassificationEvaluator(metricName='accuracy')

In [16]:
print('DTC Accuracy:')
acc_eval.evaluate(dtc_preds)

DTC Accuracy:


0.9615384615384616

In [17]:
print('RFC Accuracy:')
acc_eval.evaluate(rfc_preds)

RFC Accuracy:


1.0

In [18]:
print('GBT Accuracy:')
acc_eval.evaluate(gbt_preds)

GBT Accuracy:


0.9615384615384616

In [19]:
rfc_model.featureImportances

SparseVector(692, {99: 0.0006, 101: 0.0011, 123: 0.0003, 127: 0.0005, 131: 0.0005, 132: 0.0005, 154: 0.001, 182: 0.0016, 183: 0.0046, 208: 0.0063, 209: 0.0075, 213: 0.0004, 217: 0.0017, 234: 0.0009, 235: 0.0138, 239: 0.0006, 240: 0.0006, 243: 0.0011, 244: 0.0056, 245: 0.0015, 263: 0.01, 264: 0.0006, 265: 0.0002, 272: 0.0238, 273: 0.007, 291: 0.0052, 292: 0.0014, 293: 0.0005, 301: 0.0088, 302: 0.0016, 303: 0.0006, 315: 0.0006, 317: 0.0073, 318: 0.0028, 319: 0.0005, 320: 0.001, 322: 0.0013, 323: 0.0073, 327: 0.0012, 328: 0.0177, 330: 0.0079, 344: 0.0056, 345: 0.0016, 350: 0.0188, 351: 0.0372, 352: 0.0039, 356: 0.0147, 358: 0.0027, 359: 0.0013, 370: 0.0017, 371: 0.0019, 374: 0.0011, 377: 0.0136, 378: 0.0262, 379: 0.0285, 380: 0.0036, 383: 0.001, 384: 0.0148, 385: 0.0076, 386: 0.0045, 387: 0.0013, 397: 0.0019, 398: 0.0008, 399: 0.0009, 400: 0.0065, 405: 0.0032, 406: 0.025, 407: 0.0409, 412: 0.0021, 416: 0.0035, 427: 0.0066, 428: 0.0011, 429: 0.0063, 430: 0.0028, 431: 0.0011, 433: 0.003, 43