In [5]:
from pyspark.sql import SparkSession
# Obtain the session for this notebook; getOrCreate() hands back an already
# running session when there is one instead of starting another.
spark = (
    SparkSession.builder
    .appName("my tree")
    .getOrCreate()
)

In [6]:
from pyspark.ml import Pipeline
from pyspark.ml.classification import (RandomForestClassifier, GBTClassifier, DecisionTreeClassifier)

In [7]:
# Read the LIBSVM-formatted sample file into a DataFrame.
data = spark.read.load("sample_libsvm_data.txt", format="libsvm")

In [8]:
# Preview the first 20 rows with long cell values truncated
# (these are show()'s defaults, written out explicitly).
data.show(n=20, truncate=True)

+-----+--------------------+
|label|            features|
+-----+--------------------+
|  0.0|(692,[127,128,129...|
|  1.0|(692,[158,159,160...|
|  1.0|(692,[124,125,126...|
|  1.0|(692,[152,153,154...|
|  1.0|(692,[151,152,153...|
|  0.0|(692,[129,130,131...|
|  1.0|(692,[158,159,160...|
|  1.0|(692,[99,100,101,...|
|  0.0|(692,[154,155,156...|
|  0.0|(692,[127,128,129...|
|  1.0|(692,[154,155,156...|
|  0.0|(692,[153,154,155...|
|  0.0|(692,[151,152,153...|
|  1.0|(692,[129,130,131...|
|  0.0|(692,[154,155,156...|
|  1.0|(692,[150,151,152...|
|  0.0|(692,[124,125,126...|
|  0.0|(692,[152,153,154...|
|  1.0|(692,[97,98,99,12...|
|  1.0|(692,[124,125,126...|
+-----+--------------------+
only showing top 20 rows



In [9]:
# Split the rows roughly 70/30 into training and held-out test sets.
# The seed is pinned so the split (and therefore every count and accuracy
# figure computed from it) can be reproduced after a kernel restart instead
# of changing on each run.
train_data, test_data = data.randomSplit([0.7, 0.3], seed=42)

In [17]:
# Row counts of the two splits, shown as (test rows, train rows).
(test_data.count(), train_data.count())

(27, 73)

In [18]:
# Three tree-based classifiers to compare, all using the default
# label/features column names. Each gets an explicit seed so that any
# randomised step in training is repeatable between runs of the notebook.
rfc = RandomForestClassifier(numTrees=100, seed=42)
gbtc = GBTClassifier(seed=42)
dtc = DecisionTreeClassifier(seed=42)

In [19]:
# Fit every estimator on the training split, in the order
# random forest -> gradient-boosted trees -> decision tree.
rfc_model, gbtc_model, dtc_model = (
    estimator.fit(train_data) for estimator in (rfc, gbtc, dtc)
)

In [22]:
# Score the held-out test split with each fitted model.
rfc_preds, gbtc_preds, dtc_preds = [
    fitted.transform(test_data)
    for fitted in (rfc_model, gbtc_model, dtc_model)
]

In [23]:
# Random forest predictions: first 20 rows, long values truncated
# (show()'s defaults, written out explicitly).
rfc_preds.show(n=20, truncate=True)

+-----+--------------------+-------------+-----------+----------+
|label|            features|rawPrediction|probability|prediction|
+-----+--------------------+-------------+-----------+----------+
|  0.0|(692,[122,123,148...|  [86.0,14.0]|[0.86,0.14]|       0.0|
|  0.0|(692,[123,124,125...|   [99.0,1.0]|[0.99,0.01]|       0.0|
|  0.0|(692,[123,124,125...|   [99.0,1.0]|[0.99,0.01]|       0.0|
|  0.0|(692,[124,125,126...|  [83.0,17.0]|[0.83,0.17]|       0.0|
|  0.0|(692,[124,125,126...|   [95.0,5.0]|[0.95,0.05]|       0.0|
|  0.0|(692,[124,125,126...|   [98.0,2.0]|[0.98,0.02]|       0.0|
|  0.0|(692,[125,126,127...|   [91.0,9.0]|[0.91,0.09]|       0.0|
|  0.0|(692,[126,127,128...|   [92.0,8.0]|[0.92,0.08]|       0.0|
|  0.0|(692,[126,127,128...|  [100.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[127,128,129...|  [100.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[127,128,129...|  [89.0,11.0]|[0.89,0.11]|       0.0|
|  0.0|(692,[127,128,129...|   [98.0,2.0]|[0.98,0.02]|       0.0|
|  0.0|(69

In [24]:
# Gradient-boosted tree predictions: first 20 rows, long values truncated
# (show()'s defaults, written out explicitly).
gbtc_preds.show(n=20, truncate=True)

+-----+--------------------+--------------------+--------------------+----------+
|label|            features|       rawPrediction|         probability|prediction|
+-----+--------------------+--------------------+--------------------+----------+
|  0.0|(692,[122,123,148...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[123,124,125...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[123,124,125...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[124,125,126...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[124,125,126...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[124,125,126...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[125,126,127...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[126,127,128...|[1.39785221844172...|[0.94244325809208...|       0.0|
|  0.0|(692,[126,127,128...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[127

In [25]:
# Decision tree predictions: first 20 rows, long values truncated
# (show()'s defaults, written out explicitly).
dtc_preds.show(n=20, truncate=True)

+-----+--------------------+-------------+-----------+----------+
|label|            features|rawPrediction|probability|prediction|
+-----+--------------------+-------------+-----------+----------+
|  0.0|(692,[122,123,148...|   [25.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[123,124,125...|   [25.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[123,124,125...|   [25.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[124,125,126...|   [25.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[124,125,126...|   [25.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[124,125,126...|   [25.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[125,126,127...|   [25.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[126,127,128...|   [25.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[126,127,128...|   [25.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[127,128,129...|   [25.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[127,128,129...|   [25.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[127,128,129...|   [25.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(69

In [26]:
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

In [38]:
# Accuracy evaluator comparing the "label" and "prediction" columns
# (the evaluator's default column names, written out explicitly).
acc_aval = MulticlassClassificationEvaluator(
    labelCol="label",
    predictionCol="prediction",
    metricName="accuracy",
)

In [40]:
# Report test-set accuracy for each model, one line per classifier.
for model_label, model_preds in (
    ("Decision tree", dtc_preds),
    ("Random Forest tree", rfc_preds),
    ("Gredient Boost tree", gbtc_preds),
):
    print(model_label, acc_aval.evaluate(model_preds))

Decision tree 0.9629629629629629
Random Forest tree 0.9629629629629629
Gredient Boost tree 0.9629629629629629


In [42]:
# Display the fitted random forest's per-feature importance weights.
# The recorded output is a SparseVector of size 692, listing only the
# feature indices that received a non-zero weight.
rfc_model.featureImportances # importance weight per input feature index

SparseVector(692, {101: 0.0003, 131: 0.0004, 156: 0.0033, 177: 0.0011, 185: 0.0005, 189: 0.0015, 208: 0.0002, 215: 0.0011, 234: 0.0045, 235: 0.0178, 243: 0.0017, 244: 0.001, 261: 0.0006, 262: 0.0214, 263: 0.0014, 271: 0.0007, 272: 0.0078, 273: 0.008, 274: 0.0036, 287: 0.0033, 290: 0.024, 291: 0.0053, 299: 0.0062, 300: 0.0166, 301: 0.011, 316: 0.0069, 317: 0.0238, 318: 0.0023, 323: 0.0007, 325: 0.0006, 328: 0.0007, 329: 0.0077, 330: 0.0038, 341: 0.0004, 343: 0.0018, 345: 0.0014, 349: 0.0003, 350: 0.0094, 351: 0.0072, 357: 0.0028, 358: 0.0164, 359: 0.0038, 372: 0.0327, 373: 0.0219, 377: 0.0055, 378: 0.0281, 383: 0.0006, 401: 0.0004, 404: 0.0008, 405: 0.0398, 406: 0.0445, 407: 0.0186, 408: 0.0006, 412: 0.0018, 425: 0.0004, 427: 0.0006, 429: 0.0011, 430: 0.0055, 432: 0.0008, 433: 0.02, 434: 0.0358, 435: 0.0105, 440: 0.0062, 442: 0.006, 454: 0.0066, 455: 0.0022, 456: 0.0007, 457: 0.036, 458: 0.0025, 460: 0.0003, 461: 0.0066, 462: 0.0408, 463: 0.0068, 469: 0.0076, 483: 0.0174, 484: 0.0101, 4