In [1]:
import findspark
findspark.init('/home/danial/spark-3.3.2-bin-hadoop3')

In [2]:
from pyspark.sql import SparkSession

In [3]:
from pyspark.ml import Pipeline
from pyspark.ml.classification import (RandomForestClassifier,
                                       GBTClassifier,
                                       DecisionTreeClassifier)

In [6]:
spark = SparkSession.builder.appName('trees').getOrCreate()

In [7]:
path = '/home/danial/Desktop/myspark/Apache-Spark/Python-and-Spark-for-Big-Data-master/Spark_for_Machine_Learning/Tree_Methods/'

In [8]:
data = spark.read.format('libsvm').load(path + 'sample_libsvm_data.txt')

23/04/10 14:33:41 WARN LibSVMFileFormat: 'numFeatures' option not specified, determining the number of features by going though the input. If you know the number in advance, please specify it via 'numFeatures' option to avoid the extra scan.


In [9]:
data.printSchema()

root
 |-- label: double (nullable = true)
 |-- features: vector (nullable = true)



In [10]:
data.show()

+-----+--------------------+
|label|            features|
+-----+--------------------+
|  0.0|(692,[127,128,129...|
|  1.0|(692,[158,159,160...|
|  1.0|(692,[124,125,126...|
|  1.0|(692,[152,153,154...|
|  1.0|(692,[151,152,153...|
|  0.0|(692,[129,130,131...|
|  1.0|(692,[158,159,160...|
|  1.0|(692,[99,100,101,...|
|  0.0|(692,[154,155,156...|
|  0.0|(692,[127,128,129...|
|  1.0|(692,[154,155,156...|
|  0.0|(692,[153,154,155...|
|  0.0|(692,[151,152,153...|
|  1.0|(692,[129,130,131...|
|  0.0|(692,[154,155,156...|
|  1.0|(692,[150,151,152...|
|  0.0|(692,[124,125,126...|
|  0.0|(692,[152,153,154...|
|  1.0|(692,[97,98,99,12...|
|  1.0|(692,[124,125,126...|
+-----+--------------------+
only showing top 20 rows



In [11]:
train_data, test_data = data.randomSplit([0.7, 0.3])

In [12]:
train_data.describe().show()

+-------+------------------+
|summary|             label|
+-------+------------------+
|  count|                77|
|   mean|0.5454545454545454|
| stddev|0.5011947448335864|
|    min|               0.0|
|    max|               1.0|
+-------+------------------+



In [15]:
test_data.describe().show()

+-------+-------------------+
|summary|              label|
+-------+-------------------+
|  count|                 23|
|   mean| 0.6521739130434783|
| stddev|0.48698475355767396|
|    min|                0.0|
|    max|                1.0|
+-------+-------------------+



In [16]:
data.count()

100

In [18]:
dtc = DecisionTreeClassifier()
rfc = RandomForestClassifier(numTrees=100)
gbt = GBTClassifier()

In [19]:
dtc_model = dtc.fit(train_data)
rfc_model = rfc.fit(train_data)
gbt_model = gbt.fit(train_data)

                                                                                

In [21]:
dtc_preds = dtc_model.transform(test_data)
rfc_preds = rfc_model.transform(test_data)
gbt_preds = gbt_model.transform(test_data)

In [22]:
dtc_preds.show()

+-----+--------------------+-------------+-----------+----------+
|label|            features|rawPrediction|probability|prediction|
+-----+--------------------+-------------+-----------+----------+
|  0.0|(692,[95,96,97,12...|   [35.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[122,123,148...|   [35.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[123,124,125...|   [35.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[124,125,126...|   [35.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[124,125,126...|   [35.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[124,125,126...|   [35.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[126,127,128...|   [35.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[127,128,129...|   [35.0,0.0]|  [1.0,0.0]|       0.0|
|  1.0|(692,[123,124,125...|   [0.0,41.0]|  [0.0,1.0]|       1.0|
|  1.0|(692,[123,124,125...|   [0.0,41.0]|  [0.0,1.0]|       1.0|
|  1.0|(692,[124,125,126...|   [0.0,41.0]|  [0.0,1.0]|       1.0|
|  1.0|(692,[125,126,127...|   [0.0,41.0]|  [0.0,1.0]|       1.0|
|  1.0|(69

In [23]:
rfc_preds.show()

+-----+--------------------+-------------+-----------+----------+
|label|            features|rawPrediction|probability|prediction|
+-----+--------------------+-------------+-----------+----------+
|  0.0|(692,[95,96,97,12...|   [98.0,2.0]|[0.98,0.02]|       0.0|
|  0.0|(692,[122,123,148...|  [82.0,18.0]|[0.82,0.18]|       0.0|
|  0.0|(692,[123,124,125...|  [100.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[124,125,126...|   [98.0,2.0]|[0.98,0.02]|       0.0|
|  0.0|(692,[124,125,126...|  [85.0,15.0]|[0.85,0.15]|       0.0|
|  0.0|(692,[124,125,126...|   [95.0,5.0]|[0.95,0.05]|       0.0|
|  0.0|(692,[126,127,128...|   [99.0,1.0]|[0.99,0.01]|       0.0|
|  0.0|(692,[127,128,129...|   [94.0,6.0]|[0.94,0.06]|       0.0|
|  1.0|(692,[123,124,125...|  [0.0,100.0]|  [0.0,1.0]|       1.0|
|  1.0|(692,[123,124,125...|  [0.0,100.0]|  [0.0,1.0]|       1.0|
|  1.0|(692,[124,125,126...|  [0.0,100.0]|  [0.0,1.0]|       1.0|
|  1.0|(692,[125,126,127...|  [0.0,100.0]|  [0.0,1.0]|       1.0|
|  1.0|(69

In [25]:
gbt_preds.show()

+-----+--------------------+--------------------+--------------------+----------+
|label|            features|       rawPrediction|         probability|prediction|
+-----+--------------------+--------------------+--------------------+----------+
|  0.0|(692,[95,96,97,12...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[122,123,148...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[123,124,125...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[124,125,126...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[124,125,126...|[0.95256924138510...|[0.87047199588301...|       0.0|
|  0.0|(692,[124,125,126...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[126,127,128...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[127,128,129...|[0.91510992413135...|[0.86178791492802...|       0.0|
|  1.0|(692,[123,124,125...|[-1.5435020027249...|[0.04364652142729...|       1.0|
|  1.0|(692,[123

In [26]:
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

In [27]:
acc_eval = MulticlassClassificationEvaluator(metricName='accuracy')

In [28]:
print ('DTC Accuracy:')
acc_eval.evaluate(dtc_preds)

DTC Accuracy:


1.0

In [29]:
print ('RFC Accuracy:')
acc_eval.evaluate(rfc_preds)

RFC Accuracy:


1.0

In [30]:
print ('GBT Accuracy:')
acc_eval.evaluate(gbt_preds)

GBT Accuracy:


1.0

In [31]:
rfc_model.featureImportances

SparseVector(692, {98: 0.001, 99: 0.0005, 100: 0.0004, 102: 0.0009, 128: 0.0004, 156: 0.0005, 178: 0.0005, 187: 0.0003, 215: 0.0017, 216: 0.0064, 232: 0.0002, 234: 0.0084, 242: 0.0005, 243: 0.0005, 244: 0.0153, 245: 0.0008, 263: 0.009, 264: 0.0018, 268: 0.0017, 270: 0.0004, 271: 0.0061, 272: 0.0041, 275: 0.0017, 289: 0.0158, 290: 0.0022, 291: 0.0085, 296: 0.0008, 300: 0.0066, 301: 0.0169, 314: 0.0019, 316: 0.0031, 318: 0.0033, 321: 0.0004, 322: 0.0067, 323: 0.0005, 327: 0.0018, 328: 0.0152, 329: 0.0057, 345: 0.0165, 346: 0.0063, 347: 0.0005, 350: 0.0005, 351: 0.0369, 352: 0.0008, 355: 0.0005, 356: 0.0203, 357: 0.0014, 369: 0.0004, 370: 0.0003, 373: 0.0078, 377: 0.0241, 378: 0.0307, 379: 0.0094, 380: 0.0005, 381: 0.0014, 382: 0.0016, 383: 0.0003, 384: 0.0018, 386: 0.0011, 387: 0.0031, 400: 0.022, 401: 0.0004, 402: 0.0011, 403: 0.0016, 405: 0.0151, 406: 0.0209, 407: 0.0029, 408: 0.0038, 412: 0.0094, 413: 0.0073, 416: 0.001, 426: 0.002, 427: 0.0027, 428: 0.0007, 429: 0.0108, 431: 0.0006, 