In [1]:
!pip install pyspark

In [2]:
from pyspark.sql import SparkSession
# Create the Spark session for this notebook (or reuse one that already exists),
# registered under the application name "Trees".
spark = (
    SparkSession.builder
    .appName("Trees")
    .getOrCreate()
)

In [3]:
# Read the libsvm-formatted sample file into a DataFrame and preview it.
data = spark.read.load(
    "/kaggle/input/pyspark-ml-trees/sample_libsvm_data.txt",
    format="libsvm",
)

data.show()

In [4]:
# Print the DataFrame's column names and types to confirm what the libsvm reader produced.
data.printSchema()

In [5]:
# 70/30 train/test split; the fixed seed keeps the partition repeatable.
train, test = data.randomSplit([0.7, 0.3], seed=42)

## `DecisionTreeClassifier`

In [6]:
from pyspark.ml.classification import DecisionTreeClassifier

# Single decision tree, allowed to grow up to depth 10.
dtc = DecisionTreeClassifier(
    maxDepth=10,
    featuresCol="features",
    labelCol="label",
    predictionCol="prediction",
)

# Fit on the training split, then score the held-out split.
model = dtc.fit(train)
y_hat = model.transform(test)

# Preview the first 10 scored rows.
y_hat.show(10)

In [7]:
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

# Accuracy of the decision tree's "prediction" column against "label" on the test split.
evaluator = MulticlassClassificationEvaluator(
    metricName="accuracy",
    labelCol="label",
    predictionCol="prediction",
)
accuracy = evaluator.evaluate(y_hat)
print(accuracy)

## `RandomForestClassifier`

In [8]:
from pyspark.ml.classification import RandomForestClassifier

# Random forest of 100 trees, trained and scored on the same split as the
# decision tree above.
# The estimator takes a `seed` param for its randomized training; pin it to the
# same value used for the train/test split so reruns of the notebook are
# intended to yield the same forest and the same reported accuracy.
rfc = RandomForestClassifier(
    featuresCol="features",
    labelCol="label",
    predictionCol="prediction",
    numTrees=100,
    seed=42,
)

# Fit on the training split, then score the held-out split.
model = rfc.fit(train)
y_hat = model.transform(test)

# Preview the first 10 scored rows.
y_hat.show(10)

In [9]:
# Accuracy of the random forest's "prediction" column against "label" on the test split.
evaluator = MulticlassClassificationEvaluator(
    metricName="accuracy",
    labelCol="label",
    predictionCol="prediction",
)
accuracy = evaluator.evaluate(y_hat)
print(accuracy)

## `GBTClassifier` (gradient-boosted trees)

In [10]:
from pyspark.ml.classification import GBTClassifier

# Gradient-boosted trees classifier, capped at 100 iterations (maxIter).
gbt = GBTClassifier(
    maxIter=100,
    featuresCol="features",
    labelCol="label",
    predictionCol="prediction",
)

# Fit on the training split, then score the held-out split.
model = gbt.fit(train)
y_hat = model.transform(test)

# Preview the first 10 scored rows.
y_hat.show(10)

In [11]:
# Accuracy of the GBT model's "prediction" column against "label" on the test split.
evaluator = MulticlassClassificationEvaluator(
    metricName="accuracy",
    labelCol="label",
    predictionCol="prediction",
)
accuracy = evaluator.evaluate(y_hat)
print(accuracy)

In [None]:
################################################################################################################################