In [39]:
import findspark
findspark.init('/home/anycaroliny/spark-3.3.2-bin-hadoop3')
from pyspark.sql import SparkSession
spark = SparkSession.builder.appName('mytree').getOrCreate()

In [40]:
from pyspark.ml import Pipeline

In [41]:
from pyspark.ml.classification import (RandomForestClassifier,GBTClassifier,
                                       DecisionTreeClassifier)

In [42]:
data = spark.read.format('libsvm').load('sample_libsvm_data.txt')

23/04/14 13:46:12 WARN LibSVMFileFormat: 'numFeatures' option not specified, determining the number of features by going though the input. If you know the number in advance, please specify it via 'numFeatures' option to avoid the extra scan.


In [43]:
data.show()

+-----+--------------------+
|label|            features|
+-----+--------------------+
|  0.0|(692,[127,128,129...|
|  1.0|(692,[158,159,160...|
|  1.0|(692,[124,125,126...|
|  1.0|(692,[152,153,154...|
|  1.0|(692,[151,152,153...|
|  0.0|(692,[129,130,131...|
|  1.0|(692,[158,159,160...|
|  1.0|(692,[99,100,101,...|
|  0.0|(692,[154,155,156...|
|  0.0|(692,[127,128,129...|
|  1.0|(692,[154,155,156...|
|  0.0|(692,[153,154,155...|
|  0.0|(692,[151,152,153...|
|  1.0|(692,[129,130,131...|
|  0.0|(692,[154,155,156...|
|  1.0|(692,[150,151,152...|
|  0.0|(692,[124,125,126...|
|  0.0|(692,[152,153,154...|
|  1.0|(692,[97,98,99,12...|
|  1.0|(692,[124,125,126...|
+-----+--------------------+
only showing top 20 rows



In [44]:
training_data, test_data = data.randomSplit([0.7,0.3])

In [45]:
dtc = DecisionTreeClassifier()
rfc = RandomForestClassifier(numTrees=100)
gbt = GBTClassifier()

In [46]:
dtc_model = dtc.fit(training_data)
rfc_model = rfc.fit(training_data)
gbt_model = gbt.fit(training_data)

In [47]:
dtc_preds = dtc_model.transform(test_data)
rfc_preds = rfc_model.transform(test_data)
gbt_preds = gbt_model.transform(test_data)

In [48]:
# dtc_preds.show()
# rfc_preds.show()
gbt_preds.show()

+-----+--------------------+--------------------+--------------------+----------+
|label|            features|       rawPrediction|         probability|prediction|
+-----+--------------------+--------------------+--------------------+----------+
|  0.0|(692,[100,101,102...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[121,122,123...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[123,124,125...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[123,124,125...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[123,124,125...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[125,126,127...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[126,127,128...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[126,127,128...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[129,130,131...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[150

In [49]:
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

In [50]:
acc_eval = MulticlassClassificationEvaluator(metricName='accuracy')

In [51]:
print('DTC ACCURACY:')
acc_eval.evaluate(dtc_preds)

DTC ACCURACY:


1.0

In [52]:
print('GBT ACCURACY:')
acc_eval.evaluate(gbt_preds)

GBT ACCURACY:


1.0

In [53]:
rfc_model.featureImportances

SparseVector(692, {99: 0.0005, 119: 0.0006, 127: 0.0005, 132: 0.001, 187: 0.0014, 202: 0.0049, 213: 0.0072, 215: 0.0022, 216: 0.0006, 217: 0.0006, 230: 0.0015, 231: 0.0016, 246: 0.0019, 262: 0.0012, 263: 0.0174, 266: 0.0001, 272: 0.0098, 286: 0.0018, 287: 0.0012, 290: 0.0143, 295: 0.0044, 298: 0.0007, 299: 0.0011, 300: 0.0089, 301: 0.0109, 303: 0.0025, 317: 0.0117, 320: 0.0006, 321: 0.0007, 323: 0.0095, 324: 0.0011, 327: 0.0007, 328: 0.0216, 329: 0.0079, 331: 0.002, 341: 0.0012, 342: 0.0069, 344: 0.0091, 345: 0.0348, 349: 0.0005, 350: 0.0176, 351: 0.0366, 352: 0.0014, 353: 0.0003, 357: 0.0135, 358: 0.0153, 359: 0.0026, 370: 0.0008, 373: 0.0096, 374: 0.0006, 377: 0.0142, 378: 0.0094, 379: 0.008, 383: 0.0005, 384: 0.0004, 385: 0.0106, 400: 0.0077, 405: 0.0195, 406: 0.0403, 407: 0.0192, 408: 0.0016, 409: 0.0006, 414: 0.0044, 415: 0.0052, 416: 0.0015, 426: 0.0033, 427: 0.0153, 428: 0.0165, 429: 0.0147, 433: 0.0262, 434: 0.0194, 437: 0.0005, 441: 0.0202, 442: 0.0011, 443: 0.005, 444: 0.0011