In [29]:
from pyspark.sql import SparkSession

In [30]:
spark = SparkSession.builder.appName('mytree').getOrCreate()

In [31]:
from pyspark.ml import Pipeline
from pyspark.ml.classification import DecisionTreeClassifier,RandomForestClassifier,GBTClassifier

In [32]:
data = spark.read.format('libsvm').load('sample_libsvm_data.txt')

In [33]:
data.show()

+-----+--------------------+
|label|            features|
+-----+--------------------+
|  0.0|(692,[127,128,129...|
|  1.0|(692,[158,159,160...|
|  1.0|(692,[124,125,126...|
|  1.0|(692,[152,153,154...|
|  1.0|(692,[151,152,153...|
|  0.0|(692,[129,130,131...|
|  1.0|(692,[158,159,160...|
|  1.0|(692,[99,100,101,...|
|  0.0|(692,[154,155,156...|
|  0.0|(692,[127,128,129...|
|  1.0|(692,[154,155,156...|
|  0.0|(692,[153,154,155...|
|  0.0|(692,[151,152,153...|
|  1.0|(692,[129,130,131...|
|  0.0|(692,[154,155,156...|
|  1.0|(692,[150,151,152...|
|  0.0|(692,[124,125,126...|
|  0.0|(692,[152,153,154...|
|  1.0|(692,[97,98,99,12...|
|  1.0|(692,[124,125,126...|
+-----+--------------------+
only showing top 20 rows



In [34]:
train_data,test_data = data.randomSplit([0.7,0.3],seed=4)

In [35]:
dtc = DecisionTreeClassifier()
rfc = RandomForestClassifier(numTrees=100)
gbt = GBTClassifier()

In [36]:
dtc_model = dtc.fit(train_data)
rfc_model = rfc.fit(train_data)
gbt_model = rfc.fit(train_data)

In [37]:
dtc_preds = dtc_model.transform(test_data)
rfc_preds = rfc_model.transform(test_data)
gbt_preds = gbt_model.transform(test_data)

In [38]:
dtc_preds.show()

+-----+--------------------+-------------+-----------+----------+
|label|            features|rawPrediction|probability|prediction|
+-----+--------------------+-------------+-----------+----------+
|  0.0|(692,[95,96,97,12...|   [28.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[123,124,125...|   [28.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[124,125,126...|   [28.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[124,125,126...|   [28.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[126,127,128...|   [28.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[126,127,128...|   [28.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[126,127,128...|   [28.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[127,128,129...|   [28.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[127,128,129...|   [28.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[151,152,153...|   [28.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[152,153,154...|   [28.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[152,153,154...|   [28.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(69

In [39]:
rfc_preds.show()

+-----+--------------------+-------------+-----------+----------+
|label|            features|rawPrediction|probability|prediction|
+-----+--------------------+-------------+-----------+----------+
|  0.0|(692,[95,96,97,12...|   [96.0,4.0]|[0.96,0.04]|       0.0|
|  0.0|(692,[123,124,125...|   [99.0,1.0]|[0.99,0.01]|       0.0|
|  0.0|(692,[124,125,126...|  [87.0,13.0]|[0.87,0.13]|       0.0|
|  0.0|(692,[124,125,126...|   [98.0,2.0]|[0.98,0.02]|       0.0|
|  0.0|(692,[126,127,128...|  [89.0,11.0]|[0.89,0.11]|       0.0|
|  0.0|(692,[126,127,128...|   [97.0,3.0]|[0.97,0.03]|       0.0|
|  0.0|(692,[126,127,128...|  [100.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[127,128,129...|   [98.0,2.0]|[0.98,0.02]|       0.0|
|  0.0|(692,[127,128,129...|   [94.0,6.0]|[0.94,0.06]|       0.0|
|  0.0|(692,[151,152,153...|   [97.0,3.0]|[0.97,0.03]|       0.0|
|  0.0|(692,[152,153,154...|   [96.0,4.0]|[0.96,0.04]|       0.0|
|  0.0|(692,[152,153,154...|   [91.0,9.0]|[0.91,0.09]|       0.0|
|  0.0|(69

In [40]:
gbt_preds.show()

+-----+--------------------+-------------+-----------+----------+
|label|            features|rawPrediction|probability|prediction|
+-----+--------------------+-------------+-----------+----------+
|  0.0|(692,[95,96,97,12...|   [96.0,4.0]|[0.96,0.04]|       0.0|
|  0.0|(692,[123,124,125...|   [99.0,1.0]|[0.99,0.01]|       0.0|
|  0.0|(692,[124,125,126...|  [87.0,13.0]|[0.87,0.13]|       0.0|
|  0.0|(692,[124,125,126...|   [98.0,2.0]|[0.98,0.02]|       0.0|
|  0.0|(692,[126,127,128...|  [89.0,11.0]|[0.89,0.11]|       0.0|
|  0.0|(692,[126,127,128...|   [97.0,3.0]|[0.97,0.03]|       0.0|
|  0.0|(692,[126,127,128...|  [100.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[127,128,129...|   [98.0,2.0]|[0.98,0.02]|       0.0|
|  0.0|(692,[127,128,129...|   [94.0,6.0]|[0.94,0.06]|       0.0|
|  0.0|(692,[151,152,153...|   [97.0,3.0]|[0.97,0.03]|       0.0|
|  0.0|(692,[152,153,154...|   [96.0,4.0]|[0.96,0.04]|       0.0|
|  0.0|(692,[152,153,154...|   [91.0,9.0]|[0.91,0.09]|       0.0|
|  0.0|(69

In [41]:
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

In [42]:
acc_eval = MulticlassClassificationEvaluator(metricName='accuracy')

In [43]:
print('DTC ACCURACY: ')
acc_eval.evaluate(dtc_preds)

DTC ACCURACY: 


0.96875

In [44]:
print('RFC ACCURACY: ')
acc_eval.evaluate(rfc_preds)

RFC ACCURACY: 


1.0

In [45]:
print('GBT ACCURACY: ')
acc_eval.evaluate(gbt_preds)

GBT ACCURACY: 


1.0

In [46]:
rfc_model.featureImportances

SparseVector(692, {100: 0.0011, 101: 0.0013, 120: 0.0005, 149: 0.0013, 155: 0.0006, 176: 0.0005, 181: 0.0006, 183: 0.001, 184: 0.0019, 208: 0.0132, 209: 0.002, 211: 0.0013, 213: 0.0014, 215: 0.0005, 235: 0.0058, 244: 0.0106, 257: 0.0006, 263: 0.0192, 265: 0.0015, 272: 0.0167, 273: 0.008, 286: 0.0013, 289: 0.0059, 295: 0.0079, 300: 0.0068, 303: 0.001, 304: 0.0005, 317: 0.0019, 318: 0.001, 323: 0.0134, 324: 0.0006, 325: 0.0004, 329: 0.0157, 345: 0.0074, 349: 0.0019, 350: 0.017, 351: 0.0126, 352: 0.0018, 356: 0.0075, 357: 0.0092, 358: 0.002, 369: 0.0009, 370: 0.0005, 373: 0.0007, 374: 0.0006, 376: 0.0012, 377: 0.0002, 378: 0.041, 379: 0.0301, 384: 0.0082, 385: 0.0067, 386: 0.0064, 397: 0.0044, 400: 0.0081, 403: 0.0023, 405: 0.0511, 406: 0.02, 407: 0.023, 411: 0.0051, 412: 0.0006, 414: 0.0015, 425: 0.0005, 426: 0.0017, 427: 0.0111, 428: 0.0032, 429: 0.0071, 430: 0.005, 431: 0.0004, 433: 0.0545, 434: 0.037, 435: 0.002, 439: 0.0051, 440: 0.0113, 441: 0.0075, 453: 0.0011, 455: 0.0094, 456: 0.