In [1]:
!pip install pyspark
!pip install pyarrow
!pip install -q findspark



In [2]:
from pyspark.sql import SparkSession
spark=SparkSession.builder.appName("trees").getOrCreate()

In [3]:
from pyspark.ml.classification import RandomForestClassifier,GBTClassifier,DecisionTreeClassifier

In [4]:
data=spark.read.format("libsvm").load("/content/sample_libsvm_data.txt")

In [5]:
data.show()

+-----+--------------------+
|label|            features|
+-----+--------------------+
|  0.0|(692,[127,128,129...|
|  1.0|(692,[158,159,160...|
|  1.0|(692,[124,125,126...|
|  1.0|(692,[152,153,154...|
|  1.0|(692,[151,152,153...|
|  0.0|(692,[129,130,131...|
|  1.0|(692,[158,159,160...|
|  1.0|(692,[99,100,101,...|
|  0.0|(692,[154,155,156...|
|  0.0|(692,[127,128,129...|
|  1.0|(692,[154,155,156...|
|  0.0|(692,[153,154,155...|
|  0.0|(692,[151,152,153...|
|  1.0|(692,[129,130,131...|
|  0.0|(692,[154,155,156...|
|  1.0|(692,[150,151,152...|
|  0.0|(692,[124,125,126...|
|  0.0|(692,[152,153,154...|
|  1.0|(692,[97,98,99,12...|
|  1.0|(692,[124,125,126...|
+-----+--------------------+
only showing top 20 rows



In [6]:
train_data,test_data=data.randomSplit([0.7,0.3])

In [7]:
dtc=DecisionTreeClassifier()
rfc=RandomForestClassifier(numTrees=100)
gbt=GBTClassifier()

In [8]:
from pyspark.ml import Pipeline

In [9]:
dtc_model=dtc.fit(train_data)
rfc_model=rfc.fit(train_data)
gbt_model=gbt.fit(train_data)

In [10]:
dtc_preds=dtc_model.transform(test_data)
rfc_preds=rfc_model.transform(test_data)
gbt_preds=gbt_model.transform(test_data)

In [11]:
dtc_preds.show()

+-----+--------------------+-------------+-----------+----------+
|label|            features|rawPrediction|probability|prediction|
+-----+--------------------+-------------+-----------+----------+
|  0.0|(692,[98,99,100,1...|    [0.0,1.0]|  [0.0,1.0]|       1.0|
|  0.0|(692,[121,122,123...|   [32.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[122,123,124...|   [32.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[123,124,125...|   [32.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[126,127,128...|   [32.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[126,127,128...|   [32.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[126,127,128...|   [32.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[153,154,155...|   [32.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[153,154,155...|   [32.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[154,155,156...|   [32.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[234,235,237...|   [32.0,0.0]|  [1.0,0.0]|       0.0|
|  1.0|(692,[119,120,121...|   [0.0,32.0]|  [0.0,1.0]|       1.0|
|  1.0|(69

In [12]:
gbt_preds.show()

+-----+--------------------+--------------------+--------------------+----------+
|label|            features|       rawPrediction|         probability|prediction|
+-----+--------------------+--------------------+--------------------+----------+
|  0.0|(692,[98,99,100,1...|[-0.6472244504525...|[0.21510074015259...|       1.0|
|  0.0|(692,[121,122,123...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[122,123,124...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[123,124,125...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[126,127,128...|[1.27174561349051...|[0.92713503073351...|       0.0|
|  0.0|(692,[126,127,128...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[126,127,128...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[153,154,155...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[153,154,155...|[1.54350200272498...|[0.95635347857270...|       0.0|
|  0.0|(692,[154

In [13]:
from pyspark.ml.evaluation import MulticlassClassificationEvaluator 

In [14]:
evaluator=MulticlassClassificationEvaluator(metricName="accuracy")

In [15]:
print("DTC ACC: ",evaluator.evaluate(dtc_preds))
print("RFC ACC: ",evaluator.evaluate(rfc_preds))
print("GBT ACC: ",evaluator.evaluate(gbt_preds))

DTC ACC:  0.9714285714285714
RFC ACC:  1.0
GBT ACC:  0.9714285714285714


In [16]:
rfc_model.featureImportances

SparseVector(692, {100: 0.0007, 125: 0.0002, 154: 0.0005, 160: 0.0004, 181: 0.0008, 183: 0.0005, 184: 0.0013, 207: 0.0021, 208: 0.0013, 216: 0.0072, 217: 0.0004, 234: 0.0062, 235: 0.0205, 243: 0.0074, 244: 0.0154, 259: 0.0039, 262: 0.016, 263: 0.0156, 264: 0.0021, 267: 0.0006, 271: 0.0097, 272: 0.0128, 273: 0.0005, 290: 0.0082, 291: 0.0005, 293: 0.0007, 295: 0.0023, 296: 0.0082, 299: 0.0016, 300: 0.0217, 314: 0.0011, 316: 0.0007, 317: 0.0077, 321: 0.0005, 322: 0.0068, 323: 0.0183, 328: 0.0144, 344: 0.0032, 345: 0.0076, 347: 0.0007, 349: 0.003, 350: 0.0168, 351: 0.0188, 352: 0.0012, 355: 0.0016, 357: 0.0059, 359: 0.0025, 370: 0.0027, 371: 0.0004, 373: 0.0113, 374: 0.0028, 375: 0.0007, 377: 0.0163, 378: 0.0366, 379: 0.0095, 385: 0.0069, 398: 0.0015, 400: 0.0155, 402: 0.0007, 404: 0.0007, 405: 0.0084, 406: 0.0317, 407: 0.01, 410: 0.0019, 411: 0.001, 415: 0.0014, 425: 0.0016, 426: 0.0004, 427: 0.0158, 428: 0.0084, 429: 0.0063, 431: 0.0005, 432: 0.0034, 433: 0.0141, 434: 0.0269, 435: 0.0211