In [21]:
from pyspark.sql import SparkSession
from pyspark.ml.feature import (VectorAssembler, VectorIndexer, OneHotEncoder, StringIndexer)
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

In [2]:
# Create the Spark session for this notebook, or reuse one that is already
# running in this process (getOrCreate never starts a second session).
spark = (
    SparkSession.builder
    .appName("dog_food_decision_tree")
    .getOrCreate()
)

24/01/13 16:57:28 WARN Utils: Your hostname, DESKTOP-0HJMRJF resolves to a loopback address: 127.0.1.1; using 172.25.120.25 instead (on interface eth0)
24/01/13 16:57:28 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/01/13 16:57:29 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [3]:
# Load the dog-food CSV: the first line supplies the column names and Spark
# samples the file to infer each column's type instead of reading strings.
df = (
    spark.read
    .options(header=True, inferSchema=True)
    .csv("./data/ML/tree/dog_food.csv")
)

In [46]:
# Preview the loaded data (show()'s defaults spelled out: 20 rows, long cells truncated).
df.show(n=20, truncate=True)

+---+---+----+---+-------+
|  A|  B|   C|  D|Spoiled|
+---+---+----+---+-------+
|  4|  2|12.0|  3|    1.0|
|  5|  6|12.0|  7|    1.0|
|  6|  2|13.0|  6|    1.0|
|  4|  2|12.0|  1|    1.0|
|  4|  2|12.0|  3|    1.0|
| 10|  3|13.0|  9|    1.0|
|  8|  5|14.0|  5|    1.0|
|  5|  8|12.0|  8|    1.0|
|  6|  5|12.0|  9|    1.0|
|  3|  3|12.0|  1|    1.0|
|  9|  8|11.0|  3|    1.0|
|  1| 10|12.0|  3|    1.0|
|  1|  5|13.0| 10|    1.0|
|  2| 10|12.0|  6|    1.0|
|  1| 10|11.0|  4|    1.0|
|  5|  3|12.0|  2|    1.0|
|  4|  9|11.0|  8|    1.0|
|  5|  1|11.0|  1|    1.0|
|  4|  9|12.0| 10|    1.0|
|  5|  8|10.0|  9|    1.0|
+---+---+----+---+-------+


In [47]:
# Pack the four input columns A-D into a single vector column named
# "features"; the classifier below reads its inputs from that one column.
assembler = VectorAssembler(
    inputCols=["A", "B", "C", "D"],
    outputCol="features",
)
output = assembler.transform(df)

# Preview the assembled frame (show()'s defaults spelled out).
output.show(n=20, truncate=True)

+---+---+----+---+-------+-------------------+
|  A|  B|   C|  D|Spoiled|           features|
+---+---+----+---+-------+-------------------+
|  4|  2|12.0|  3|    1.0| [4.0,2.0,12.0,3.0]|
|  5|  6|12.0|  7|    1.0| [5.0,6.0,12.0,7.0]|
|  6|  2|13.0|  6|    1.0| [6.0,2.0,13.0,6.0]|
|  4|  2|12.0|  1|    1.0| [4.0,2.0,12.0,1.0]|
|  4|  2|12.0|  3|    1.0| [4.0,2.0,12.0,3.0]|
| 10|  3|13.0|  9|    1.0|[10.0,3.0,13.0,9.0]|
|  8|  5|14.0|  5|    1.0| [8.0,5.0,14.0,5.0]|
|  5|  8|12.0|  8|    1.0| [5.0,8.0,12.0,8.0]|
|  6|  5|12.0|  9|    1.0| [6.0,5.0,12.0,9.0]|
|  3|  3|12.0|  1|    1.0| [3.0,3.0,12.0,1.0]|
|  9|  8|11.0|  3|    1.0| [9.0,8.0,11.0,3.0]|
|  1| 10|12.0|  3|    1.0|[1.0,10.0,12.0,3.0]|
|  1|  5|13.0| 10|    1.0|[1.0,5.0,13.0,10.0]|
|  2| 10|12.0|  6|    1.0|[2.0,10.0,12.0,6.0]|
|  1| 10|11.0|  4|    1.0|[1.0,10.0,11.0,4.0]|
|  5|  3|12.0|  2|    1.0| [5.0,3.0,12.0,2.0]|
|  4|  9|11.0|  8|    1.0| [4.0,9.0,11.0,8.0]|
|  5|  1|11.0|  1|    1.0| [5.0,1.0,11.0,1.0]|
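|  4|  9|12.0| 10|    1.0|[4.0,9.0,12.0,10.0]|
|  5|  8|10.0|  9|    1.0| [5.0,8.0,10.0,9.0]|
+---+---+----+---+-------+-------------------+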

In [48]:
# Random forest that predicts the "Spoiled" label from the assembled
# "features" vector. Random forests are stochastic (bootstrap sampling and
# random feature subsets), so a fixed seed is set to keep the fitted model and
# its feature importances reproducible across Restart & Run All.
rfc = RandomForestClassifier(labelCol="Spoiled", featuresCol="features", seed=42)

In [49]:
# Train the forest on the assembled frame. No train/test split is made here:
# the fitted model is only inspected for feature importances below.
rfc_model = rfc.fit(dataset=output)

In [57]:
# Display the fitted model's summary repr.
rfc_model

RandomForestClassificationModel: uid=RandomForestClassifier_90e3898e010c, numTrees=20, numClasses=2, numFeatures=4

In [58]:
# Per-feature importance scores from the fitted forest. Positions in this
# vector correspond to positions in the assembled "features" vector.
importances = rfc_model.featureImportances
importances

SparseVector(4, {0: 0.0211, 1: 0.0155, 2: 0.9432, 3: 0.0202})

In [59]:
# Find the position of the largest importance score directly with argmax on
# the dense array, instead of taking max() over the vector and then searching
# a list for an equal float. Ties resolve to the first position, as before.
most_important_feature_index = int(importances.toArray().argmax())

# The score at that position, kept available for later cells.
max_importance = importances[most_important_feature_index]

In [60]:
# Display the position of the most important feature within the feature vector.
most_important_feature_index

2

In [63]:
# The index refers to a position in the assembled feature vector, i.e. in the
# assembler's inputCols, not to a column position in the DataFrame. Map it
# through inputCols so the lookup stays correct even if the feature columns
# are not the DataFrame's leading columns or are listed in a different order.
output[assembler.getInputCols()[most_important_feature_index]]

Column<'C'>