In [None]:
!pip install pyspark

In [None]:
# Start a Spark session for the tree-model experiments
# (getOrCreate reuses an already-active session if there is one).
from pyspark.sql import SparkSession

spark = (
    SparkSession.builder
    .appName("Trees")
    .getOrCreate()
)

In [None]:
# Read the LIBSVM-formatted sample file; the models below consume its
# "label" column and its "features" vector column.
libsvm_path = "../data/sample_libsvm_data.txt"
data = spark.read.load(libsvm_path, format="libsvm")

data.show()

In [None]:
# Inspect column names/types (expected: a numeric label plus a feature vector).
data.printSchema()

In [None]:
# Hold out 30% of rows for evaluation; the fixed seed makes the split reproducible for the same input.
split_weights = [0.7, 0.3]
train, test = data.randomSplit(split_weights, seed=42)

## DecisionTreeClassifier

In [None]:
# Fit a single decision tree (depth capped at 10) and score the held-out test set.
from pyspark.ml.classification import DecisionTreeClassifier

dtc = DecisionTreeClassifier(
    featuresCol="features",
    labelCol="label",
    predictionCol="prediction",
    maxDepth=10,
)

model = dtc.fit(train)

# `y_hat` is the predictions DataFrame read by the evaluation cell below.
y_hat = model.transform(test)

# Show only the informative columns: the sparse `features` vectors can be very wide
# and would otherwise flood the output.
y_hat.select("label", "prediction", "probability").show(10, truncate=False)

In [None]:
# Test-set accuracy of the decision tree: fraction of rows where prediction == label.
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

evaluator = MulticlassClassificationEvaluator(
    labelCol="label",
    predictionCol="prediction",
    metricName="accuracy",
)

accuracy = evaluator.evaluate(y_hat)

# Label the number so results from the different models in this notebook can be told apart.
print(f"Decision tree test accuracy: {accuracy:.4f}")

## RandomForestClassifier

In [None]:
# Fit a 100-tree random forest and score the held-out test set.
from pyspark.ml.classification import RandomForestClassifier

rfc = RandomForestClassifier(
    featuresCol="features",
    labelCol="label",
    predictionCol="prediction",
    numTrees=100,
    # Without an explicit seed, PySpark defaults to hash(<class name>), and Python randomises
    # string hashes per process, so bootstrap/feature sampling would change on every kernel restart.
    seed=42,
)

model = rfc.fit(train)

# `y_hat` is the predictions DataFrame read by the evaluation cell below.
y_hat = model.transform(test)

# Show only the informative columns: the sparse `features` vectors can be very wide
# and would otherwise flood the output.
y_hat.select("label", "prediction", "probability").show(10, truncate=False)

In [None]:
# Test-set accuracy of the random forest: fraction of rows where prediction == label.
evaluator = MulticlassClassificationEvaluator(
    labelCol="label",
    predictionCol="prediction",
    metricName="accuracy",
)

accuracy = evaluator.evaluate(y_hat)

# Label the number so results from the different models in this notebook can be told apart.
print(f"Random forest test accuracy: {accuracy:.4f}")

## Gradient-Boosted Trees (GBTClassifier)

In [None]:
# Fit gradient-boosted trees (100 boosting iterations) and score the held-out test set.
# NOTE: per the Spark ML docs, GBTClassifier supports binary labels only.
from pyspark.ml.classification import GBTClassifier

gbt = GBTClassifier(
    featuresCol="features",
    labelCol="label",
    predictionCol="prediction",
    maxIter=100,
    # Pin the seed (PySpark's default is hash()-based and varies across Python processes)
    # so any randomised step, e.g. row subsampling if enabled, is reproducible.
    seed=42,
)

model = gbt.fit(train)

# `y_hat` is the predictions DataFrame read by the evaluation cell below.
y_hat = model.transform(test)

# Show only the informative columns: the sparse `features` vectors can be very wide
# and would otherwise flood the output.
y_hat.select("label", "prediction", "probability").show(10, truncate=False)

In [None]:
# Test-set accuracy of the gradient-boosted trees: fraction of rows where prediction == label.
evaluator = MulticlassClassificationEvaluator(
    labelCol="label",
    predictionCol="prediction",
    metricName="accuracy",
)

accuracy = evaluator.evaluate(y_hat)

# Label the number so results from the different models in this notebook can be told apart.
print(f"Gradient-boosted trees test accuracy: {accuracy:.4f}")

In [None]:
################################################################################################################################