In [0]:
#Starting a spark session
from pyspark.sql import SparkSession
spark = SparkSession.builder.appName('trees').getOrCreate()

In [0]:
#Importing required libraries
from pyspark.ml import Pipeline
from pyspark.ml.classification import RandomForestClassifier, GBTClassifier, DecisionTreeClassifier
from pyspark.ml.evaluation import BinaryClassificationEvaluator, MulticlassClassificationEvaluator

In [0]:
#Loading the data
data = spark.read.format('libsvm').load('/FileStore/tables/sample_libsvm_data-1.txt')
data.show()

+-----+--------------------+
|label|            features|
+-----+--------------------+
|  0.0|(692,[127,128,129...|
|  1.0|(692,[158,159,160...|
|  1.0|(692,[124,125,126...|
|  1.0|(692,[152,153,154...|
|  1.0|(692,[151,152,153...|
|  0.0|(692,[129,130,131...|
|  1.0|(692,[158,159,160...|
|  1.0|(692,[99,100,101,...|
|  0.0|(692,[154,155,156...|
|  0.0|(692,[127,128,129...|
|  1.0|(692,[154,155,156...|
|  0.0|(692,[153,154,155...|
|  0.0|(692,[151,152,153...|
|  1.0|(692,[129,130,131...|
|  0.0|(692,[154,155,156...|
|  1.0|(692,[150,151,152...|
|  0.0|(692,[124,125,126...|
|  0.0|(692,[152,153,154...|
|  1.0|(692,[97,98,99,12...|
|  1.0|(692,[124,125,126...|
+-----+--------------------+
only showing top 20 rows



In [0]:
#Train-Test Split
train_data, test_data = data.randomSplit([0.7,0.3])

In [0]:
#Classifiers
dt = DecisionTreeClassifier()
rf = RandomForestClassifier(numTrees=100)
gb = GBTClassifier()

In [0]:
#Fitting the models
dt_model = dt.fit(train_data)
rf_model = rf.fit(train_data)
gb_model = gb.fit(train_data)

In [0]:
#Transform test data and generate predictions
dt_preds = dt_model.transform(test_data)
rf_preds = rf_model.transform(test_data)
gb_preds = gb_model.transform(test_data)

In [0]:
dt_preds.show(3)

+-----+--------------------+-------------+-----------+----------+
|label|            features|rawPrediction|probability|prediction|
+-----+--------------------+-------------+-----------+----------+
|  0.0|(692,[122,123,124...|   [28.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[122,123,148...|   [28.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[123,124,125...|   [28.0,0.0]|  [1.0,0.0]|       0.0|
+-----+--------------------+-------------+-----------+----------+
only showing top 3 rows



In [0]:
rf_preds.show(3)

+-----+--------------------+-------------+-----------+----------+
|label|            features|rawPrediction|probability|prediction|
+-----+--------------------+-------------+-----------+----------+
|  0.0|(692,[122,123,124...|   [99.0,1.0]|[0.99,0.01]|       0.0|
|  0.0|(692,[122,123,148...|  [90.0,10.0]|  [0.9,0.1]|       0.0|
|  0.0|(692,[123,124,125...|   [96.0,4.0]|[0.96,0.04]|       0.0|
+-----+--------------------+-------------+-----------+----------+
only showing top 3 rows



In [0]:
gb_preds.show(3)

+-----+--------------------+--------------------+--------------------+----------+
|label|            features|       rawPrediction|         probability|prediction|
+-----+--------------------+--------------------+--------------------+----------+
|  0.0|(692,[122,123,124...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[122,123,148...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[123,124,125...|[1.54350200272498...|[0.95635347857270...|       0.0|
+-----+--------------------+--------------------+--------------------+----------+
only showing top 3 rows



In [0]:
accuracy = MulticlassClassificationEvaluator(metricName='accuracy')
dt_acc = accuracy.evaluate(dt_preds)
rf_acc = accuracy.evaluate(rf_preds)
gb_acc = accuracy.evaluate(gb_preds)

print('Decision Tree Accuracy: {0}'.format(dt_acc))
print('Random Forest Accuracy: {0}'.format(rf_acc))
print('Gradient Boosting Accuracy: {0}'.format(gb_acc))

Decision Tree Accuracy: 0.8571428571428571
Random Forest Accuracy: 1.0
Gradient Boosting Accuracy: 0.8571428571428571


In [0]:
#Getting feature importance
rf_model.featureImportances

Out[16]: SparseVector(692, {99: 0.0006, 127: 0.0004, 128: 0.0003, 155: 0.0017, 156: 0.0006, 159: 0.0007, 178: 0.0019, 204: 0.0007, 206: 0.005, 215: 0.0005, 216: 0.0015, 234: 0.0016, 235: 0.0024, 236: 0.001, 237: 0.002, 238: 0.0006, 240: 0.0016, 242: 0.0007, 244: 0.009, 262: 0.0011, 266: 0.0002, 267: 0.002, 271: 0.0072, 272: 0.0004, 273: 0.0165, 274: 0.0043, 289: 0.0147, 290: 0.0152, 292: 0.0003, 295: 0.0024, 296: 0.0013, 301: 0.0079, 303: 0.0016, 317: 0.0155, 319: 0.0012, 322: 0.0081, 323: 0.015, 324: 0.0072, 326: 0.0004, 327: 0.0006, 328: 0.0065, 329: 0.0067, 330: 0.0063, 331: 0.0005, 342: 0.0017, 350: 0.0034, 351: 0.0311, 352: 0.0007, 358: 0.007, 372: 0.0092, 373: 0.0067, 377: 0.0141, 378: 0.0302, 379: 0.049, 382: 0.0015, 384: 0.0053, 387: 0.0026, 397: 0.0028, 399: 0.0067, 400: 0.0098, 401: 0.0006, 405: 0.0218, 407: 0.0203, 408: 0.0022, 411: 0.0026, 412: 0.0029, 413: 0.0089, 414: 0.0179, 425: 0.0006, 427: 0.0061, 430: 0.0029, 433: 0.0132, 434: 0.0599, 435: 0.0179, 436: 0.0004, 438: 0