In [1]:
import findspark
findspark.init('C:/spark/spark-2.4.5-bin-hadoop2.7')

from pyspark.sql import SparkSession
import pyspark

In [33]:
spark = SparkSession.builder.appName('trees').getOrCreate()

In [11]:
from pyspark.ml.classification import (DecisionTreeClassifier,
                                       RandomForestClassifier, 
                                       GBTClassifier)

In [12]:
from pyspark.ml import Pipeline

In [13]:
data = spark.read.format('libsvm').load('sample_libsvm_data.txt')

In [20]:
data.show()

+-----+--------------------+
|label|            features|
+-----+--------------------+
|  0.0|(692,[127,128,129...|
|  1.0|(692,[158,159,160...|
|  1.0|(692,[124,125,126...|
|  1.0|(692,[152,153,154...|
|  1.0|(692,[151,152,153...|
|  0.0|(692,[129,130,131...|
|  1.0|(692,[158,159,160...|
|  1.0|(692,[99,100,101,...|
|  0.0|(692,[154,155,156...|
|  0.0|(692,[127,128,129...|
|  1.0|(692,[154,155,156...|
|  0.0|(692,[153,154,155...|
|  0.0|(692,[151,152,153...|
|  1.0|(692,[129,130,131...|
|  0.0|(692,[154,155,156...|
|  1.0|(692,[150,151,152...|
|  0.0|(692,[124,125,126...|
|  0.0|(692,[152,153,154...|
|  1.0|(692,[97,98,99,12...|
|  1.0|(692,[124,125,126...|
+-----+--------------------+
only showing top 20 rows



In [15]:
train_data, test_data = data.randomSplit([0.7, 0.3])

In [18]:
dtc = DecisionTreeClassifier()
rfc = RandomForestClassifier(numTrees=100)
gbt = GBTClassifier()

In [19]:
dtc_model = dtc.fit(train_data)
rfc_model = rfc.fit(train_data)
gbt_model = gbt.fit(train_data)

In [22]:
dtc_preds = dtc_model.transform(test_data)
rfc_preds = rfc_model.transform(test_data)
gbt_preds = gbt_model.transform(test_data)

In [27]:
# dtc_preds.show()
# rfc_preds.show()
# gbt_preds.show()

In [28]:
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

In [29]:
acc_eval = MulticlassClassificationEvaluator(metricName='accuracy')

In [31]:
print('DTC Accuracy: ', acc_eval.evaluate(dtc_preds))
print('RFC Accuracy: ', acc_eval.evaluate(rfc_preds))
print('GBT Accuracy: ', acc_eval.evaluate(gbt_preds))

DTC Accuracy:  0.9230769230769231
RFC Accuracy:  0.9615384615384616
GBT Accuracy:  0.9230769230769231


In [32]:
rfc_model.featureImportances

SparseVector(692, {101: 0.0006, 119: 0.0005, 124: 0.0003, 148: 0.0006, 155: 0.0005, 177: 0.0002, 183: 0.0014, 184: 0.0017, 203: 0.0015, 214: 0.0004, 215: 0.0066, 241: 0.0007, 243: 0.0094, 244: 0.0115, 245: 0.0021, 260: 0.0026, 262: 0.0274, 271: 0.018, 272: 0.0193, 287: 0.0014, 289: 0.0163, 290: 0.0079, 292: 0.0004, 301: 0.0084, 317: 0.0008, 322: 0.0089, 323: 0.0176, 324: 0.001, 326: 0.0005, 327: 0.005, 328: 0.0021, 330: 0.0058, 331: 0.0006, 345: 0.0129, 347: 0.001, 350: 0.038, 351: 0.0103, 352: 0.0047, 354: 0.0011, 355: 0.0006, 371: 0.0055, 374: 0.0018, 377: 0.0095, 378: 0.054, 379: 0.0158, 380: 0.0006, 382: 0.004, 384: 0.0032, 385: 0.0145, 399: 0.005, 400: 0.0071, 401: 0.0144, 405: 0.0643, 406: 0.0082, 407: 0.0345, 408: 0.0025, 409: 0.0005, 413: 0.003, 416: 0.0023, 426: 0.0094, 427: 0.022, 428: 0.0086, 430: 0.003, 433: 0.0209, 434: 0.049, 435: 0.0072, 437: 0.0006, 441: 0.0079, 442: 0.001, 443: 0.0053, 454: 0.0022, 455: 0.0266, 461: 0.0324, 462: 0.0184, 463: 0.0103, 464: 0.0005, 466: 0