In [5]:
from pyspark.sql import SparkSession
from pyspark.ml import Pipeline
from pyspark.ml.classification import (RandomForestClassifier, GBTClassifier,
                                          DecisionTreeClassifier)

In [2]:
spark = SparkSession.builder.appName("app").getOrCreate()

In [7]:
data = spark.read.format("libsvm").load("data/sample_libsvm_data.txt")
data.show()

+-----+--------------------+
|label|            features|
+-----+--------------------+
|  0.0|(692,[127,128,129...|
|  1.0|(692,[158,159,160...|
|  1.0|(692,[124,125,126...|
|  1.0|(692,[152,153,154...|
|  1.0|(692,[151,152,153...|
|  0.0|(692,[129,130,131...|
|  1.0|(692,[158,159,160...|
|  1.0|(692,[99,100,101,...|
|  0.0|(692,[154,155,156...|
|  0.0|(692,[127,128,129...|
|  1.0|(692,[154,155,156...|
|  0.0|(692,[153,154,155...|
|  0.0|(692,[151,152,153...|
|  1.0|(692,[129,130,131...|
|  0.0|(692,[154,155,156...|
|  1.0|(692,[150,151,152...|
|  0.0|(692,[124,125,126...|
|  0.0|(692,[152,153,154...|
|  1.0|(692,[97,98,99,12...|
|  1.0|(692,[124,125,126...|
+-----+--------------------+
only showing top 20 rows



In [8]:
train_data, test_data = data.randomSplit([0.7, 0.3])

In [28]:
dtc = DecisionTreeClassifier()
rfc = RandomForestClassifier(numTrees=100)
gbt = GBTClassifier()

In [29]:
dtc_model = dtc.fit(train_data)
rfc_model = rfc.fit(train_data)
gbt_model = gbt.fit(train_data)

In [30]:
dtc_pred = dtc_model.transform(test_data)
rfc_pred = rfc_model.transform(test_data)
gbt_pred = gbt_model.transform(test_data)

In [31]:
dtc_pred.show()

+-----+--------------------+-------------+-----------+----------+
|label|            features|rawPrediction|probability|prediction|
+-----+--------------------+-------------+-----------+----------+
|  0.0|(692,[122,123,148...|   [33.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[123,124,125...|   [33.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[124,125,126...|   [33.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[126,127,128...|   [33.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[126,127,128...|   [33.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[128,129,130...|   [33.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[129,130,131...|   [33.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[152,153,154...|   [33.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[153,154,155...|   [33.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[153,154,155...|   [33.0,0.0]|  [1.0,0.0]|       0.0|
|  1.0|(692,[99,100,101,...|   [33.0,0.0]|  [1.0,0.0]|       0.0|
|  1.0|(692,[123,124,125...|   [0.0,45.0]|  [0.0,1.0]|       1.0|
|  1.0|(69

In [32]:
dtc_pred.show()

+-----+--------------------+-------------+-----------+----------+
|label|            features|rawPrediction|probability|prediction|
+-----+--------------------+-------------+-----------+----------+
|  0.0|(692,[122,123,148...|   [33.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[123,124,125...|   [33.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[124,125,126...|   [33.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[126,127,128...|   [33.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[126,127,128...|   [33.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[128,129,130...|   [33.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[129,130,131...|   [33.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[152,153,154...|   [33.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[153,154,155...|   [33.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[153,154,155...|   [33.0,0.0]|  [1.0,0.0]|       0.0|
|  1.0|(692,[99,100,101,...|   [33.0,0.0]|  [1.0,0.0]|       0.0|
|  1.0|(692,[123,124,125...|   [0.0,45.0]|  [0.0,1.0]|       1.0|
|  1.0|(69

In [33]:
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

In [34]:
acc_eval = MulticlassClassificationEvaluator(metricName='accuracy')

In [35]:
print("DTC Accuracy")
acc_eval.evaluate(dtc_pred)

DTC Accuracy


0.9545454545454546

In [36]:
print("RFC Accuracy")
acc_eval.evaluate(rfc_pred)

RFC Accuracy


0.9545454545454546

In [38]:
 rfc_model.featureImportances

SparseVector(692, {121: 0.0005, 131: 0.0016, 151: 0.0024, 154: 0.0011, 155: 0.0009, 159: 0.0001, 185: 0.0007, 187: 0.0017, 189: 0.0008, 190: 0.0011, 214: 0.0004, 217: 0.0032, 234: 0.007, 235: 0.0005, 236: 0.0018, 243: 0.0003, 244: 0.0006, 264: 0.0001, 271: 0.0138, 273: 0.0027, 289: 0.0002, 302: 0.002, 315: 0.0005, 318: 0.0068, 319: 0.0003, 322: 0.0035, 324: 0.0006, 326: 0.0001, 327: 0.0048, 328: 0.0072, 329: 0.0058, 330: 0.0076, 341: 0.0005, 344: 0.0022, 345: 0.0091, 347: 0.0064, 349: 0.0003, 350: 0.0287, 351: 0.0283, 353: 0.0008, 357: 0.0111, 370: 0.001, 372: 0.0253, 373: 0.0082, 375: 0.003, 377: 0.0016, 378: 0.0235, 379: 0.0104, 380: 0.0019, 382: 0.0032, 383: 0.0011, 384: 0.006, 400: 0.0131, 405: 0.038, 406: 0.04, 407: 0.0239, 408: 0.0051, 411: 0.0015, 412: 0.0058, 413: 0.0064, 414: 0.0023, 415: 0.0008, 426: 0.0005, 429: 0.0083, 430: 0.001, 432: 0.0009, 433: 0.0486, 434: 0.0323, 435: 0.01, 441: 0.0017, 455: 0.0078, 456: 0.0232, 457: 0.0078, 460: 0.0012, 461: 0.0175, 462: 0.0379, 463: