In [1]:
import findspark
findspark.init('/home/mysparkub/spark-3.0.0-bin-hadoop2.7')

In [10]:
from pyspark.sql import SparkSession
from pyspark.ml import Pipeline
from pyspark.ml.classification import (RandomForestClassifier, GBTClassifier,
                                      DecisionTreeClassifier)

In [3]:
spark = SparkSession.builder.appName('DecisionTree').getOrCreate()

In [5]:
data = spark.read.format('libsvm').load('files/sample_libsvm_data.txt')

In [6]:
train_data, test_data = data.randomSplit([0.7, 0.3])

In [13]:
dtc  = DecisionTreeClassifier()
rfc = RandomForestClassifier(numTrees=100)
gbt = GBTClassifier()

In [14]:
dtc_model = dtc.fit(train_data)
rfc_model = rfc.fit(train_data)
gbt_model = gbt.fit(train_data)

In [15]:
dtc_preds = dtc_model.transform(test_data)
rfc_preds = rfc_model.transform(test_data)
gbt_preds = gbt_model.transform(test_data)

In [17]:
gbt_preds.show()

+-----+--------------------+--------------------+--------------------+----------+
|label|            features|       rawPrediction|         probability|prediction|
+-----+--------------------+--------------------+--------------------+----------+
|  0.0|(692,[95,96,97,12...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[100,101,102...|[-1.5435020027249...|[0.04364652142729...|       1.0|
|  0.0|(692,[121,122,123...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[124,125,126...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[125,126,127...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[126,127,128...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[126,127,128...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[126,127,128...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[126,127,128...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[127

In [19]:
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

In [20]:
acc_eval = MulticlassClassificationEvaluator(metricName='accuracy')

In [21]:
acc_eval.evaluate(dtc_preds)

0.9

In [23]:
rfc_model.featureImportances

SparseVector(692, {99: 0.001, 101: 0.0009, 122: 0.0005, 178: 0.0005, 182: 0.0005, 184: 0.0013, 205: 0.0022, 209: 0.0037, 214: 0.0003, 234: 0.0015, 239: 0.0011, 244: 0.0161, 245: 0.0053, 262: 0.0181, 271: 0.0025, 272: 0.0203, 273: 0.0021, 274: 0.0006, 288: 0.0014, 290: 0.01, 295: 0.0131, 299: 0.0016, 301: 0.0026, 314: 0.0016, 316: 0.0005, 319: 0.0006, 320: 0.0005, 322: 0.0124, 323: 0.0064, 324: 0.0063, 332: 0.0006, 345: 0.006, 350: 0.0264, 351: 0.0191, 352: 0.0021, 356: 0.0005, 357: 0.008, 370: 0.0038, 371: 0.0121, 373: 0.0017, 377: 0.0086, 378: 0.0182, 379: 0.0274, 380: 0.0048, 382: 0.0014, 383: 0.0012, 384: 0.0137, 385: 0.0018, 388: 0.0048, 401: 0.0033, 402: 0.0005, 405: 0.0574, 406: 0.03, 407: 0.0437, 408: 0.0004, 413: 0.0006, 415: 0.0043, 426: 0.0074, 428: 0.0085, 429: 0.0072, 432: 0.0012, 433: 0.0305, 434: 0.0394, 435: 0.0199, 436: 0.0017, 438: 0.0037, 440: 0.0057, 453: 0.0005, 455: 0.0228, 456: 0.0083, 457: 0.0069, 461: 0.0376, 462: 0.0098, 463: 0.0133, 464: 0.0011, 468: 0.0078, 4