In [1]:
from pyspark.sql import SparkSession
spark = SparkSession.builder.appName('mytree').getOrCreate()

In [2]:
from pyspark.ml import Pipeline

In [3]:
from pyspark.ml.classification import RandomForestClassifier, GBTClassifier, DecisionTreeClassifier

In [4]:
data = spark.read.format('libsvm').load('sample_libsvm_data.txt')

In [8]:
print(data.count())
data.show(5)

100
+-----+--------------------+
|label|            features|
+-----+--------------------+
|  0.0|(692,[127,128,129...|
|  1.0|(692,[158,159,160...|
|  1.0|(692,[124,125,126...|
|  1.0|(692,[152,153,154...|
|  1.0|(692,[151,152,153...|
+-----+--------------------+
only showing top 5 rows



In [9]:
train_data, test_data = data.randomSplit([0.7, 0.3])

In [10]:
dtc = DecisionTreeClassifier()
rfc = RandomForestClassifier(numTrees=100)
gbt = GBTClassifier()

In [12]:
dtc_model = dtc.fit(train_data)
rfc_model = rfc.fit(train_data)
gbt_model = gbt.fit(train_data)

In [13]:
dtc_preds = dtc_model.transform(test_data)
rfc_preds = rfc_model.transform(test_data)
gbt_preds = gbt_model.transform(test_data)

In [16]:
dtc_preds.show()

+-----+--------------------+-------------+-----------+----------+
|label|            features|rawPrediction|probability|prediction|
+-----+--------------------+-------------+-----------+----------+
|  0.0|(692,[121,122,123...|   [26.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[123,124,125...|   [26.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[123,124,125...|   [26.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[124,125,126...|   [26.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[124,125,126...|   [26.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[124,125,126...|   [26.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[126,127,128...|   [26.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[126,127,128...|   [26.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[126,127,128...|   [26.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[129,130,131...|   [26.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[150,151,152...|   [26.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[151,152,153...|   [26.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(69

In [17]:
rfc_preds.show()

+-----+--------------------+-------------+-----------+----------+
|label|            features|rawPrediction|probability|prediction|
+-----+--------------------+-------------+-----------+----------+
|  0.0|(692,[121,122,123...|   [98.0,2.0]|[0.98,0.02]|       0.0|
|  0.0|(692,[123,124,125...|  [100.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[123,124,125...|   [96.0,4.0]|[0.96,0.04]|       0.0|
|  0.0|(692,[124,125,126...|   [97.0,3.0]|[0.97,0.03]|       0.0|
|  0.0|(692,[124,125,126...|  [100.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[124,125,126...|  [83.0,17.0]|[0.83,0.17]|       0.0|
|  0.0|(692,[126,127,128...|   [95.0,5.0]|[0.95,0.05]|       0.0|
|  0.0|(692,[126,127,128...|   [97.0,3.0]|[0.97,0.03]|       0.0|
|  0.0|(692,[126,127,128...|   [91.0,9.0]|[0.91,0.09]|       0.0|
|  0.0|(692,[129,130,131...|  [83.0,17.0]|[0.83,0.17]|       0.0|
|  0.0|(692,[150,151,152...|  [85.0,15.0]|[0.85,0.15]|       0.0|
|  0.0|(692,[151,152,153...|   [99.0,1.0]|[0.99,0.01]|       0.0|
|  0.0|(69

In [18]:
gbt_preds.show()

+-----+--------------------+--------------------+--------------------+----------+
|label|            features|       rawPrediction|         probability|prediction|
+-----+--------------------+--------------------+--------------------+----------+
|  0.0|(692,[121,122,123...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[123,124,125...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[123,124,125...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[124,125,126...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[124,125,126...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[124,125,126...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[126,127,128...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[126,127,128...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[126,127,128...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[129

In [19]:
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

In [20]:
acc_eval = MulticlassClassificationEvaluator(metricName='accuracy')

In [22]:
print('DTC Accuracy:', acc_eval.evaluate(dtc_preds))
print('RFC Accuracy:', acc_eval.evaluate(rfc_preds))
print('GBT Accuracy:', acc_eval.evaluate(gbt_preds))

DTC Accuracy: 0.9459459459459459
RFC Accuracy: 1.0
GBT Accuracy: 0.9459459459459459


In [23]:
rfc_model.featureImportances

SparseVector(692, {119: 0.001, 155: 0.001, 158: 0.0009, 159: 0.0005, 181: 0.0015, 184: 0.0003, 190: 0.0009, 207: 0.0027, 209: 0.0018, 215: 0.0053, 216: 0.0004, 235: 0.0006, 243: 0.0021, 244: 0.0082, 258: 0.0005, 262: 0.006, 263: 0.0024, 268: 0.0013, 271: 0.0009, 273: 0.0076, 289: 0.0023, 291: 0.0021, 299: 0.004, 300: 0.0112, 301: 0.0153, 314: 0.0014, 315: 0.0012, 319: 0.0058, 322: 0.0007, 323: 0.0018, 327: 0.0053, 328: 0.0087, 329: 0.0063, 330: 0.0144, 341: 0.0007, 344: 0.0084, 345: 0.0093, 347: 0.0006, 350: 0.0282, 351: 0.0341, 353: 0.0001, 355: 0.0026, 359: 0.0027, 370: 0.0064, 373: 0.0182, 378: 0.0199, 379: 0.0258, 380: 0.0023, 381: 0.0021, 383: 0.0023, 386: 0.0154, 399: 0.0081, 400: 0.018, 401: 0.0034, 405: 0.0188, 406: 0.0271, 407: 0.018, 408: 0.0013, 409: 0.0012, 411: 0.0006, 412: 0.0071, 413: 0.0015, 414: 0.0088, 415: 0.0018, 427: 0.0075, 428: 0.0132, 429: 0.017, 430: 0.0024, 432: 0.0026, 433: 0.0268, 434: 0.0377, 435: 0.0074, 438: 0.0006, 439: 0.0022, 441: 0.017, 444: 0.0013, 4