In [1]:
from pyspark.sql import SparkSession

In [2]:
spark = SparkSession.builder.appName('tree').getOrCreate()

In [3]:
from pyspark.ml import Pipeline

In [4]:
from pyspark.ml.classification import RandomForestClassifier, GBTClassifier, DecisionTreeClassifier

In [5]:
data = spark.read.format("libsvm").load("../data/sample_libsvm_data.txt")

In [6]:
data.show()

+-----+--------------------+
|label|            features|
+-----+--------------------+
|  0.0|(692,[127,128,129...|
|  1.0|(692,[158,159,160...|
|  1.0|(692,[124,125,126...|
|  1.0|(692,[152,153,154...|
|  1.0|(692,[151,152,153...|
|  0.0|(692,[129,130,131...|
|  1.0|(692,[158,159,160...|
|  1.0|(692,[99,100,101,...|
|  0.0|(692,[154,155,156...|
|  0.0|(692,[127,128,129...|
|  1.0|(692,[154,155,156...|
|  0.0|(692,[153,154,155...|
|  0.0|(692,[151,152,153...|
|  1.0|(692,[129,130,131...|
|  0.0|(692,[154,155,156...|
|  1.0|(692,[150,151,152...|
|  0.0|(692,[124,125,126...|
|  0.0|(692,[152,153,154...|
|  1.0|(692,[97,98,99,12...|
|  1.0|(692,[124,125,126...|
+-----+--------------------+
only showing top 20 rows



In [7]:
(train_data, test_data) = data.randomSplit([0.7, 0.3])

In [8]:
dtc = DecisionTreeClassifier()
rfc = RandomForestClassifier(numTrees=100)
gbt = GBTClassifier()

In [9]:
dtc_model = dtc.fit(train_data)
rfc_model = rfc.fit(train_data)
gbt_model = gbt.fit(train_data)

In [10]:
dtc_preds = dtc_model.transform(test_data)
rfc_preds = rfc_model.transform(test_data)
gbt_preds = gbt_model.transform(test_data)

In [11]:
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

In [12]:
acc_eval = MulticlassClassificationEvaluator(metricName='accuracy')

In [13]:
print('DTC ACCURACY:')
acc_eval.evaluate(dtc_preds)

DTC ACCURACY:


0.967741935483871

In [14]:
print('RFC ACCURACY:')
acc_eval.evaluate(rfc_preds)

RFC ACCURACY:


1.0

In [15]:
print('GBT ACCURACY:')
acc_eval.evaluate(gbt_preds)

GBT ACCURACY:


0.967741935483871

In [16]:
rfc_model.featureImportances

SparseVector(692, {100: 0.0006, 101: 0.0013, 126: 0.0007, 128: 0.0003, 147: 0.0018, 149: 0.0003, 156: 0.0001, 158: 0.0004, 175: 0.0024, 182: 0.0001, 204: 0.0009, 211: 0.0009, 212: 0.0001, 215: 0.0023, 235: 0.0066, 238: 0.0008, 240: 0.0005, 242: 0.0026, 243: 0.0006, 244: 0.008, 245: 0.0005, 263: 0.003, 266: 0.0025, 268: 0.0012, 271: 0.011, 272: 0.0092, 274: 0.0019, 290: 0.0045, 291: 0.005, 295: 0.0006, 299: 0.0077, 301: 0.0063, 302: 0.0048, 317: 0.004, 318: 0.0068, 319: 0.008, 322: 0.0032, 323: 0.0015, 324: 0.0006, 330: 0.0067, 331: 0.0006, 341: 0.001, 343: 0.0044, 344: 0.0016, 347: 0.0011, 348: 0.0004, 350: 0.0151, 351: 0.0106, 352: 0.001, 354: 0.0027, 356: 0.0027, 357: 0.015, 373: 0.001, 375: 0.0005, 378: 0.0461, 379: 0.0036, 380: 0.0006, 382: 0.0024, 384: 0.0063, 385: 0.0012, 388: 0.0019, 402: 0.0075, 405: 0.0317, 406: 0.0495, 407: 0.0239, 408: 0.0006, 411: 0.0031, 415: 0.0068, 426: 0.005, 429: 0.0067, 433: 0.0596, 434: 0.0267, 435: 0.0087, 436: 0.0005, 438: 0.0029, 440: 0.0064, 453: