In [1]:
import findspark
findspark.init('C:\Spark\spark-3.0.1-bin-hadoop2.7')
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName('mytree').getOrCreate()

In [2]:
from pyspark.ml import Pipeline

In [3]:
from pyspark.ml.classification import (RandomForestClassifier, GBTClassifier,
                                       DecisionTreeClassifier)

In [4]:
data = spark.read.format('libsvm').option("numFeatures", "692").load('sample_libsvm_data.txt')

In [5]:
train_data, test_data = data.randomSplit([0.7, 0.3])

In [6]:
dtc = DecisionTreeClassifier()
rfc = RandomForestClassifier(numTrees=100)
gbt = GBTClassifier()

In [7]:
dtc_model = dtc.fit(train_data)
rfc_model = rfc.fit(train_data)
gbt_model = gbt.fit(train_data)

In [8]:
dtc_preds = dtc_model.transform(test_data)
rfc_preds = rfc_model.transform(test_data)
gbt_preds = gbt_model.transform(test_data)

In [9]:
dtc_preds.show(5)

+-----+--------------------+-------------+-----------+----------+
|label|            features|rawPrediction|probability|prediction|
+-----+--------------------+-------------+-----------+----------+
|  0.0|(692,[100,101,102...|   [27.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[122,123,124...|   [27.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[123,124,125...|   [27.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[124,125,126...|   [27.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[124,125,126...|   [27.0,0.0]|  [1.0,0.0]|       0.0|
+-----+--------------------+-------------+-----------+----------+
only showing top 5 rows



In [10]:
rfc_preds.show(5)

+-----+--------------------+-------------+-----------+----------+
|label|            features|rawPrediction|probability|prediction|
+-----+--------------------+-------------+-----------+----------+
|  0.0|(692,[100,101,102...|  [64.0,36.0]|[0.64,0.36]|       0.0|
|  0.0|(692,[122,123,124...|  [100.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[123,124,125...|   [96.0,4.0]|[0.96,0.04]|       0.0|
|  0.0|(692,[124,125,126...|   [98.0,2.0]|[0.98,0.02]|       0.0|
|  0.0|(692,[124,125,126...|   [94.0,6.0]|[0.94,0.06]|       0.0|
+-----+--------------------+-------------+-----------+----------+
only showing top 5 rows



In [11]:
gbt_preds.show(5)

+-----+--------------------+--------------------+--------------------+----------+
|label|            features|       rawPrediction|         probability|prediction|
+-----+--------------------+--------------------+--------------------+----------+
|  0.0|(692,[100,101,102...|[1.49968587065345...|[0.95254573612173...|       0.0|
|  0.0|(692,[122,123,124...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[123,124,125...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[124,125,126...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[124,125,126...|[1.54350200272498...|[0.95635347857270...|       0.0|
+-----+--------------------+--------------------+--------------------+----------+
only showing top 5 rows



In [12]:
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

In [13]:
acc_eval = MulticlassClassificationEvaluator(metricName='accuracy')

In [14]:
print('DTC ACCURACY:')
acc_eval.evaluate(dtc_preds)

DTC ACCURACY:


0.9642857142857143

In [15]:
print('RFC ACCURACY:')
acc_eval.evaluate(rfc_preds)

RFC ACCURACY:


1.0

In [16]:
print('GBT ACCURACY:')
acc_eval.evaluate(gbt_preds)

GBT ACCURACY:


0.9642857142857143

In [17]:
rfc_model.featureImportances

SparseVector(692, {99: 0.0009, 100: 0.0005, 123: 0.0012, 128: 0.0007, 155: 0.0006, 182: 0.0002, 183: 0.0009, 214: 0.0005, 216: 0.0063, 230: 0.0007, 232: 0.0014, 235: 0.0004, 239: 0.0012, 244: 0.0311, 245: 0.0031, 262: 0.0012, 263: 0.0078, 268: 0.0016, 272: 0.0091, 273: 0.0165, 274: 0.0014, 287: 0.0019, 291: 0.0021, 295: 0.0012, 296: 0.0005, 301: 0.0276, 303: 0.0019, 320: 0.0011, 322: 0.0007, 323: 0.0103, 324: 0.0014, 328: 0.0068, 329: 0.008, 330: 0.0145, 346: 0.0005, 350: 0.0097, 351: 0.0212, 353: 0.0011, 356: 0.0087, 357: 0.0075, 358: 0.0123, 359: 0.0007, 373: 0.0135, 374: 0.002, 378: 0.0538, 379: 0.0361, 380: 0.0158, 382: 0.0013, 384: 0.001, 399: 0.0006, 402: 0.0048, 406: 0.02, 407: 0.0293, 408: 0.0005, 425: 0.0005, 427: 0.0012, 428: 0.0051, 430: 0.0026, 433: 0.0543, 434: 0.0482, 435: 0.0196, 436: 0.0015, 438: 0.0038, 440: 0.0069, 441: 0.0073, 442: 0.0086, 443: 0.0005, 454: 0.0067, 455: 0.0181, 456: 0.0104, 457: 0.0069, 458: 0.0006, 461: 0.0301, 462: 0.0291, 463: 0.0115, 465: 0.0006,