In [2]:
from pyspark.sql import SparkSession
# Create the Spark session for this notebook (or reuse one that is already
# running), registered under the application name 'tree'.
spark = (
    SparkSession.builder
    .appName('tree')
    .getOrCreate()
)

In [3]:
from pyspark.ml import Pipeline

In [4]:
from pyspark.ml.classification import (RandomForestClassifier,
                                       GBTClassifier,
                                       DecisionTreeClassifier)

In [7]:
# Load the sample dataset (LIBSVM text format) into a DataFrame.
# The path is relative to the directory the notebook is run from.
data = spark.read.load('sample_libsvm_data.txt', format='libsvm')

In [8]:
# Preview the first 20 rows, with long cell values truncated for display
data.show(n=20, truncate=True)

+-----+--------------------+
|label|            features|
+-----+--------------------+
|  0.0|(692,[127,128,129...|
|  1.0|(692,[158,159,160...|
|  1.0|(692,[124,125,126...|
|  1.0|(692,[152,153,154...|
|  1.0|(692,[151,152,153...|
|  0.0|(692,[129,130,131...|
|  1.0|(692,[158,159,160...|
|  1.0|(692,[99,100,101,...|
|  0.0|(692,[154,155,156...|
|  0.0|(692,[127,128,129...|
|  1.0|(692,[154,155,156...|
|  0.0|(692,[153,154,155...|
|  0.0|(692,[151,152,153...|
|  1.0|(692,[129,130,131...|
|  0.0|(692,[154,155,156...|
|  1.0|(692,[150,151,152...|
|  0.0|(692,[124,125,126...|
|  0.0|(692,[152,153,154...|
|  1.0|(692,[97,98,99,12...|
|  1.0|(692,[124,125,126...|
+-----+--------------------+
only showing top 20 rows



In [9]:
# Split the rows roughly 70/30 into a training set and a held-out test set.
# A fixed seed is passed so the split does not change from run to run; without
# it every execution draws a different split and the accuracy figures further
# down cannot be reproduced.
# NOTE: outputs already saved in this notebook were produced before the seed
# was added, so the numbers may differ after re-running.
train_data, test_data = data.randomSplit([0.7, 0.3], seed=42)

In [11]:
# Build the three tree-based classifiers to compare.
# The column names are the library defaults, written out explicitly here:
# inputs are read from 'features', targets from 'label'.
dtc = DecisionTreeClassifier(featuresCol='features', labelCol='label')
rfc = RandomForestClassifier(
    featuresCol='features',
    labelCol='label',
    numTrees=100,  # forest size
)
gbt = GBTClassifier(featuresCol='features', labelCol='label')

In [12]:
# Train each classifier on the training split; every fit() call returns a
# fitted model object that is used for prediction below.
dtcModel = dtc.fit(dataset=train_data)  # single decision tree
rfcModel = rfc.fit(dataset=train_data)  # random forest
gbtModel = gbt.fit(dataset=train_data)  # gradient-boosted trees

In [14]:
# Run every fitted model over the held-out test split.
# Order matters: decision tree, random forest, gradient-boosted trees.
dtc_preds, rfc_preds, gbt_preds = (
    fitted.transform(test_data)
    for fitted in (dtcModel, rfcModel, gbtModel)
)

In [20]:
# Optional: uncomment any of these lines to inspect the prediction DataFrames
#dtc_preds.show()
#rfc_preds.show()
#gbt_preds.show()

In [21]:
# EVALUATE YOUR MODELS
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

In [22]:
# Evaluator that scores a predictions DataFrame by accuracy.
# The column names are the library defaults, written out explicitly.
acc_eval = MulticlassClassificationEvaluator(
    predictionCol='prediction',
    labelCol='label',
    metricName='accuracy',
)

In [25]:
# Accuracy of the decision tree on the test split; the evaluate() call is the
# last expression so the notebook displays its value as the cell output.
print('DTC ACCURACY')
acc_eval.evaluate(dataset=dtc_preds)

DTC ACCURACY


0.9655172413793104

In [26]:
# Accuracy of the random forest on the test split; the evaluate() call is the
# last expression so the notebook displays its value as the cell output.
print('RFC ACCURACY')
acc_eval.evaluate(dataset=rfc_preds)

RFC ACCURACY


1.0

In [27]:
# Accuracy of the gradient-boosted trees on the test split; the evaluate()
# call is the last expression so the notebook displays its value as output.
print('GBT ACCURACY')
acc_eval.evaluate(dataset=gbt_preds)

GBT ACCURACY


0.9655172413793104

In [28]:
# Feature importances of the fitted random forest.
# Displayed as a SparseVector with one slot per input feature (692 here);
# only the feature indices with a non-zero importance are listed.
rfcModel.featureImportances

SparseVector(692, {100: 0.0017, 121: 0.001, 154: 0.0013, 156: 0.0006, 157: 0.0005, 158: 0.0003, 180: 0.0006, 184: 0.0002, 185: 0.0006, 209: 0.0015, 212: 0.0009, 214: 0.0002, 217: 0.0007, 230: 0.0005, 234: 0.0015, 236: 0.0004, 239: 0.0006, 240: 0.0008, 242: 0.0005, 243: 0.0009, 244: 0.0004, 245: 0.0064, 262: 0.016, 263: 0.0198, 271: 0.0072, 272: 0.001, 273: 0.0005, 289: 0.0065, 291: 0.0106, 292: 0.001, 295: 0.0006, 298: 0.0005, 299: 0.0028, 300: 0.0077, 301: 0.0248, 302: 0.001, 303: 0.0018, 317: 0.0141, 318: 0.0017, 322: 0.0051, 323: 0.0009, 327: 0.0021, 329: 0.01, 330: 0.0072, 331: 0.0008, 345: 0.0003, 346: 0.0006, 347: 0.0009, 350: 0.0329, 351: 0.0054, 352: 0.0006, 355: 0.0046, 356: 0.0086, 357: 0.007, 358: 0.0006, 360: 0.0006, 372: 0.0053, 373: 0.0098, 375: 0.0005, 378: 0.0268, 379: 0.0091, 383: 0.0005, 385: 0.0183, 386: 0.0028, 387: 0.0009, 405: 0.0344, 406: 0.0619, 407: 0.0007, 408: 0.0025, 409: 0.0011, 411: 0.0097, 412: 0.0069, 413: 0.0005, 424: 0.0005, 426: 0.016, 428: 0.0027, 42