In [1]:
!pip install pyspark

In [2]:
from pyspark.sql import SparkSession
# Reuse an existing Spark session if one is running, otherwise start one named "Trees".
spark = (
    SparkSession.builder
    .appName("Trees")
    .getOrCreate()
)

In [3]:
# Input location as a single config constant, so the notebook can be pointed at a
# different copy of the data without editing the read call.
# (Presumably a Kaggle-style mounted dataset path — adjust for local runs.)
DATA_PATH = "../input/decisiontree/dog_food.csv"

# header=True takes column names from the first row; inferSchema=True makes Spark
# scan the data to assign column types instead of reading every column as a string.
data = spark.read.csv(DATA_PATH, inferSchema=True, header=True)

data.printSchema()

In [4]:
# Peek at the first ten rows of the loaded DataFrame.
data.show(n=10)

In [5]:
from pyspark.ml.linalg import Vectors
from pyspark.ml.feature import VectorAssembler

In [6]:
# Spark ML estimators read predictors from one vector column, so pack the
# columns A, B, C and D into a single "features" column.
assembler = VectorAssembler(
    inputCols=["A", "B", "C", "D"],
    outputCol="features",
)

output = assembler.transform(data)

In [7]:
# 70/30 train/test split; the fixed seed keeps the split identical across re-runs.
train, test = output.randomSplit([0.7, 0.3], seed=42)

In [8]:
from pyspark.ml.classification import DecisionTreeClassifier
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.classification import GBTClassifier

In [9]:
# Column wiring shared by all three tree-based estimators, plus a fixed seed.
# Random forests sample rows/features randomly, so without a seed their results
# change on every Restart & Run All even though the train/test split is seeded.
# The same seed value (42) as randomSplit is passed to every estimator for consistency.
tree_params = dict(labelCol="Spoiled",
                   featuresCol="features",
                   predictionCol="prediction",
                   seed=42)

dtc = DecisionTreeClassifier(**tree_params)

rfc = RandomForestClassifier(**tree_params)

gbt = GBTClassifier(**tree_params)

In [10]:
# Fit each estimator on the training split, in order: tree, forest, boosted trees.
dtc_model, rfc_model, gbt_model = (
    estimator.fit(train) for estimator in (dtc, rfc, gbt)
)

In [11]:
# Score the held-out split with every fitted model.
dtc_predictions, rfc_predictions, gbt_predictions = [
    fitted.transform(test) for fitted in (dtc_model, rfc_model, gbt_model)
]

In [12]:
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

# One accuracy evaluator, reused for all three models' predictions.
evaluator = MulticlassClassificationEvaluator(
    labelCol="Spoiled",
    predictionCol="prediction",
    metricName="accuracy",
)

In [13]:
# Test-set accuracy (a fraction in [0, 1]) for each model, same order as above.
dtc_accuracy, rfc_accuracy, gbt_accuracy = (
    evaluator.evaluate(preds)
    for preds in (dtc_predictions, rfc_predictions, gbt_predictions)
)

In [14]:
# Report each model's accuracy as a percentage, with a dashed rule between entries.
print(
    f"DecisionTreeClassifier: {dtc_accuracy * 100}",
    f"RandomForestClassifier: {rfc_accuracy * 100}",
    f"GradientBoostingClassifier: {gbt_accuracy * 100}",
    sep="\n" + "-" * 50 + "\n",
)

In [15]:
# Decision-tree feature importances keyed by source column name. The raw Vector
# only shows positional indices. Index i corresponds to the i-th assembler input
# column, assuming each input column is a single numeric value — TODO confirm if
# vector-valued columns are ever added.
{
    column: float(weight)
    for column, weight in zip(
        assembler.getInputCols(), dtc_model.featureImportances.toArray()
    )
}