In [None]:
!apt-get install openjdk-8-jdk-headless -qq > /dev/null
!wget -q https://mirrors.sonic.net/apache/spark/spark-3.1.2/spark-3.1.2-bin-hadoop3.2.tgz
!tar xzf spark-3.1.2-bin-hadoop3.2.tgz
!pip install -q findspark


import os
os.environ["JAVA_HOME"] = "/usr/lib/jvm/java-8-openjdk-amd64"
os.environ["SPARK_HOME"] = "/content/spark-3.1.2-bin-hadoop3.2"


import findspark
findspark.init()
from pyspark.sql import SparkSession
spark = SparkSession.builder.master("local[*]").getOrCreate()

In [None]:
from pyspark.ml.classification import (RandomForestClassifier, GBTClassifier,
                                       DecisionTreeClassifier)
from pyspark.ml.evaluation import (BinaryClassificationEvaluator,
                                   MulticlassClassificationEvaluator)

In [None]:
# Load the LIBSVM-format sample (a label column plus a sparse feature vector).
# The Drive copy is only readable once Google Drive is mounted, and no mount step
# is visible in this notebook. Fall back to the identical sample bundled with the
# Spark distribution so Restart & Run All works on a fresh runtime.
DRIVE_DATA_PATH = '/content/drive/MyDrive/Colab Notebooks/Trees/sample_libsvm_data.txt'
BUNDLED_DATA_PATH = os.path.join(os.environ["SPARK_HOME"], 'data', 'mllib', 'sample_libsvm_data.txt')
data_path = DRIVE_DATA_PATH if os.path.exists(DRIVE_DATA_PATH) else BUNDLED_DATA_PATH
df = spark.read.format('libsvm').load(data_path)

In [None]:
# Print the schema of the loaded data: a double `label` column and a `features` vector column.
df.printSchema()

root
 |-- label: double (nullable = true)
 |-- features: vector (nullable = true)



In [None]:
# Peek at the first 20 rows; feature vectors print in sparse (size, [indices], ...) form, truncated.
df.show(n=20, truncate=True)

+-----+--------------------+
|label|            features|
+-----+--------------------+
|  0.0|(692,[127,128,129...|
|  1.0|(692,[158,159,160...|
|  1.0|(692,[124,125,126...|
|  1.0|(692,[152,153,154...|
|  1.0|(692,[151,152,153...|
|  0.0|(692,[129,130,131...|
|  1.0|(692,[158,159,160...|
|  1.0|(692,[99,100,101,...|
|  0.0|(692,[154,155,156...|
|  0.0|(692,[127,128,129...|
|  1.0|(692,[154,155,156...|
|  0.0|(692,[153,154,155...|
|  0.0|(692,[151,152,153...|
|  1.0|(692,[129,130,131...|
|  0.0|(692,[154,155,156...|
|  1.0|(692,[150,151,152...|
|  0.0|(692,[124,125,126...|
|  0.0|(692,[152,153,154...|
|  1.0|(692,[97,98,99,12...|
|  1.0|(692,[124,125,126...|
+-----+--------------------+
only showing top 20 rows



In [None]:
# 70/30 train/test split. Without an explicit seed, PySpark's randomSplit draws a
# new random seed on every call, so the split (and every metric below) changes
# on each Restart & Run All.
SPLIT_SEED = 42
train, test = df.randomSplit([0.7, 0.3], seed=SPLIT_SEED)

# Both splits are reused by three fit() and three transform() calls; cache them
# so Spark does not re-read and re-split the source file for every model.
train = train.cache()
test = test.cache()

In [None]:
# Three tree-based classifiers, trained and compared below.
# Seeds are fixed explicitly. Random forests bootstrap rows and subsample
# features per tree, and PySpark's default seed is hash() of the estimator's
# class name, which differs between Python processes under hash randomization.
MODEL_SEED = 42
NUM_TREES = 100

dtc = DecisionTreeClassifier(seed=MODEL_SEED)
rfc = RandomForestClassifier(numTrees=NUM_TREES, seed=MODEL_SEED)
gbt = GBTClassifier(seed=MODEL_SEED)

In [None]:
# Train every estimator on the same training split (decision tree, forest, boosted trees, in that order).
dtc_model, rfc_model, gbt_model = (
    estimator.fit(train) for estimator in (dtc, rfc, gbt)
)

In [None]:
# Score the held-out split with each fitted model; each result is `test` plus prediction columns.
dtc_preds, rfc_preds, gbt_preds = [
    fitted.transform(test) for fitted in (dtc_model, rfc_model, gbt_model)
]

In [None]:
# First 20 scored rows from the decision tree, with its rawPrediction, probability and prediction columns.
dtc_preds.show(n=20, truncate=True)

+-----+--------------------+-------------+-----------+----------+
|label|            features|rawPrediction|probability|prediction|
+-----+--------------------+-------------+-----------+----------+
|  0.0|(692,[95,96,97,12...|   [30.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[121,122,123...|   [30.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[123,124,125...|   [30.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[124,125,126...|   [30.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[124,125,126...|   [30.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[124,125,126...|   [30.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[126,127,128...|   [30.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[126,127,128...|   [30.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[153,154,155...|   [30.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[154,155,156...|   [30.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[154,155,156...|   [30.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(692,[155,156,180...|   [30.0,0.0]|  [1.0,0.0]|       0.0|
|  0.0|(69

In [None]:
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

# Shared accuracy evaluator (PySpark's default label/prediction column names).
evaluator = MulticlassClassificationEvaluator().setMetricName('accuracy')

In [None]:
# Compare the three models on the held-out split.
# With only a few dozen test rows, accuracy can easily tie across models.
# Area under ROC uses each model's rawPrediction confidence, so it can still
# separate them. The test size is printed so the scores can be read in context.
auc_evaluator = BinaryClassificationEvaluator(metricName='areaUnderROC')

print(f'Test rows: {test.count()}')
for name, preds in [('DTC', dtc_preds), ('RFC', rfc_preds), ('GBT', gbt_preds)]:
    accuracy = evaluator.evaluate(preds)
    auc = auc_evaluator.evaluate(preds)
    print(f'{name}: accuracy={accuracy:.4f}  AUC={auc:.4f}')

DTC: 0.9714285714285714
RFC: 0.9714285714285714
GBT: 0.9714285714285714
