In [38]:
from pyspark import SparkContext
from pyspark.sql import SparkSession
import numpy as np 

In [39]:
# Create (or reuse) the Spark session used by every cell below.
# NOTE: despite its name, `sc` holds a SparkSession, not a SparkContext.
sc = (
    SparkSession.builder
    .appName("dt")
    .getOrCreate()
)

In [40]:
# Load the iris CSV: the first row supplies the column names and the column
# types are inferred from the data.
iris_csv_path = "Files/iris.csv"
df = sc.read.csv(
    iris_csv_path,
    header=True,
    inferSchema=True,
)
df.show()

+------------+-----------+------------+-----------+-------+
|sepal.length|sepal.width|petal.length|petal.width|variety|
+------------+-----------+------------+-----------+-------+
|         5.1|        3.5|         1.4|        0.2| Setosa|
|         4.9|        3.0|         1.4|        0.2| Setosa|
|         4.7|        3.2|         1.3|        0.2| Setosa|
|         4.6|        3.1|         1.5|        0.2| Setosa|
|         5.0|        3.6|         1.4|        0.2| Setosa|
|         5.4|        3.9|         1.7|        0.4| Setosa|
|         4.6|        3.4|         1.4|        0.3| Setosa|
|         5.0|        3.4|         1.5|        0.2| Setosa|
|         4.4|        2.9|         1.4|        0.2| Setosa|
|         4.9|        3.1|         1.5|        0.1| Setosa|
|         5.4|        3.7|         1.5|        0.2| Setosa|
|         4.8|        3.4|         1.6|        0.2| Setosa|
|         4.8|        3.0|         1.4|        0.1| Setosa|
|         4.3|        3.0|         1.1| 

In [44]:
# Replace the dotted column names from the CSV header with underscore names,
# one column at a time, then preview the result.
column_renames = {
    'sepal.length': 'sepal_length',
    'sepal.width': 'sepal_width',
    'petal.length': 'petal_length',
    'petal.width': 'petal_width',
}
for old_name, new_name in column_renames.items():
    df = df.withColumnRenamed(old_name, new_name)

df.show()

+------------+-----------+------------+-----------+-------+
|sepal_length|sepal_width|petal_length|petal_width|variety|
+------------+-----------+------------+-----------+-------+
|         5.1|        3.5|         1.4|        0.2| Setosa|
|         4.9|        3.0|         1.4|        0.2| Setosa|
|         4.7|        3.2|         1.3|        0.2| Setosa|
|         4.6|        3.1|         1.5|        0.2| Setosa|
|         5.0|        3.6|         1.4|        0.2| Setosa|
|         5.4|        3.9|         1.7|        0.4| Setosa|
|         4.6|        3.4|         1.4|        0.3| Setosa|
|         5.0|        3.4|         1.5|        0.2| Setosa|
|         4.4|        2.9|         1.4|        0.2| Setosa|
|         4.9|        3.1|         1.5|        0.1| Setosa|
|         5.4|        3.7|         1.5|        0.2| Setosa|
|         4.8|        3.4|         1.6|        0.2| Setosa|
|         4.8|        3.0|         1.4|        0.1| Setosa|
|         4.3|        3.0|         1.1| 

In [45]:
# Class balance: number of rows for each value of `variety`.
(
    df
    .groupBy("variety")
    .count()
    .show()
)

+----------+-----+
|   variety|count|
+----------+-----+
| Virginica|   50|
|    Setosa|   50|
|Versicolor|   50|
+----------+-----+



In [46]:
# Total number of rows in df.
df.count()

150

In [47]:
# Print each column's name, inferred type and nullability.
df.printSchema()

root
 |-- sepal_length: double (nullable = true)
 |-- sepal_width: double (nullable = true)
 |-- petal_length: double (nullable = true)
 |-- petal_width: double (nullable = true)
 |-- variety: string (nullable = true)



In [49]:
# Summary statistics (count, mean, stddev, min, max) for every column.
df.describe().show()

24/11/01 15:50:15 WARN SparkStringUtils: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.


+-------+------------------+-------------------+------------------+------------------+---------+
|summary|      sepal_length|        sepal_width|      petal_length|       petal_width|  variety|
+-------+------------------+-------------------+------------------+------------------+---------+
|  count|               150|                150|               150|               150|      150|
|   mean| 5.843333333333335|  3.057333333333334|3.7580000000000027| 1.199333333333334|     NULL|
| stddev|0.8280661279778637|0.43586628493669793|1.7652982332594662|0.7622376689603467|     NULL|
|    min|               4.3|                2.0|               1.0|               0.1|   Setosa|
|    max|               7.9|                4.4|               6.9|               2.5|Virginica|
+-------+------------------+-------------------+------------------+------------------+---------+



In [50]:
# List the current column names of df.
df.columns

['sepal_length', 'sepal_width', 'petal_length', 'petal_width', 'variety']

In [51]:
from pyspark.ml.feature import OneHotEncoder, StringIndexer

In [52]:
# Fit a StringIndexer that maps the string `variety` column to a numeric
# index column, `indexed_variety`, which later cells use as the label.
indexer = StringIndexer(
    inputCol="variety",
    outputCol="indexed_variety",
)
indexmodel = indexer.fit(df)

In [53]:
# Apply the fitted indexer model to df, adding the `indexed_variety` column.
indexdf = indexmodel.transform(df)

In [54]:
# Preview the rows together with their new `indexed_variety` value.
indexdf.show()

+------------+-----------+------------+-----------+-------+---------------+
|sepal_length|sepal_width|petal_length|petal_width|variety|indexed_variety|
+------------+-----------+------------+-----------+-------+---------------+
|         5.1|        3.5|         1.4|        0.2| Setosa|            0.0|
|         4.9|        3.0|         1.4|        0.2| Setosa|            0.0|
|         4.7|        3.2|         1.3|        0.2| Setosa|            0.0|
|         4.6|        3.1|         1.5|        0.2| Setosa|            0.0|
|         5.0|        3.6|         1.4|        0.2| Setosa|            0.0|
|         5.4|        3.9|         1.7|        0.4| Setosa|            0.0|
|         4.6|        3.4|         1.4|        0.3| Setosa|            0.0|
|         5.0|        3.4|         1.5|        0.2| Setosa|            0.0|
|         4.4|        2.9|         1.4|        0.2| Setosa|            0.0|
|         4.9|        3.1|         1.5|        0.1| Setosa|            0.0|
|         5.

In [55]:
# One-hot encode the numeric class index into the `onehot_variety` vector column.
# NOTE(review): the assembler and the final column selection further down do not
# read `onehot_variety`, so this step looks exploratory only.
encoder = OneHotEncoder(
    inputCol="indexed_variety",
    outputCol="onehot_variety",
)
encoded_model = encoder.fit(indexdf)

encoded_df = encoded_model.transform(indexdf)
encoded_df.show()

+------------+-----------+------------+-----------+-------+---------------+--------------+
|sepal_length|sepal_width|petal_length|petal_width|variety|indexed_variety|onehot_variety|
+------------+-----------+------------+-----------+-------+---------------+--------------+
|         5.1|        3.5|         1.4|        0.2| Setosa|            0.0| (2,[0],[1.0])|
|         4.9|        3.0|         1.4|        0.2| Setosa|            0.0| (2,[0],[1.0])|
|         4.7|        3.2|         1.3|        0.2| Setosa|            0.0| (2,[0],[1.0])|
|         4.6|        3.1|         1.5|        0.2| Setosa|            0.0| (2,[0],[1.0])|
|         5.0|        3.6|         1.4|        0.2| Setosa|            0.0| (2,[0],[1.0])|
|         5.4|        3.9|         1.7|        0.4| Setosa|            0.0| (2,[0],[1.0])|
|         4.6|        3.4|         1.4|        0.3| Setosa|            0.0| (2,[0],[1.0])|
|         5.0|        3.4|         1.5|        0.2| Setosa|            0.0| (2,[0],[1.0])|

In [56]:
from pyspark.ml.feature import VectorAssembler

In [57]:
# List the column names after the one-hot encoding step.
encoded_df.columns

['sepal_length',
 'sepal_width',
 'petal_length',
 'petal_width',
 'variety',
 'indexed_variety',
 'onehot_variety']

In [60]:
# The four numeric measurements are packed, in this order, into a single
# vector column named `features`.
feature_columns = [
    'sepal_length',
    'sepal_width',
    'petal_length',
    'petal_width',
]
assembler = VectorAssembler(inputCols=feature_columns, outputCol="features")

In [61]:
# Append the `features` vector column built by the assembler, then preview it.
assembled_df = assembler.transform(encoded_df)
assembled_df.show()

+------------+-----------+------------+-----------+-------+---------------+--------------+-----------------+
|sepal_length|sepal_width|petal_length|petal_width|variety|indexed_variety|onehot_variety|         features|
+------------+-----------+------------+-----------+-------+---------------+--------------+-----------------+
|         5.1|        3.5|         1.4|        0.2| Setosa|            0.0| (2,[0],[1.0])|[5.1,3.5,1.4,0.2]|
|         4.9|        3.0|         1.4|        0.2| Setosa|            0.0| (2,[0],[1.0])|[4.9,3.0,1.4,0.2]|
|         4.7|        3.2|         1.3|        0.2| Setosa|            0.0| (2,[0],[1.0])|[4.7,3.2,1.3,0.2]|
|         4.6|        3.1|         1.5|        0.2| Setosa|            0.0| (2,[0],[1.0])|[4.6,3.1,1.5,0.2]|
|         5.0|        3.6|         1.4|        0.2| Setosa|            0.0| (2,[0],[1.0])|[5.0,3.6,1.4,0.2]|
|         5.4|        3.9|         1.7|        0.4| Setosa|            0.0| (2,[0],[1.0])|[5.4,3.9,1.7,0.4]|
|         4.6|     

In [62]:
# Preview only the two columns the model will use: label index and features.
(
    assembled_df
    .select('indexed_variety', 'features')
    .show()
)

+---------------+-----------------+
|indexed_variety|         features|
+---------------+-----------------+
|            0.0|[5.1,3.5,1.4,0.2]|
|            0.0|[4.9,3.0,1.4,0.2]|
|            0.0|[4.7,3.2,1.3,0.2]|
|            0.0|[4.6,3.1,1.5,0.2]|
|            0.0|[5.0,3.6,1.4,0.2]|
|            0.0|[5.4,3.9,1.7,0.4]|
|            0.0|[4.6,3.4,1.4,0.3]|
|            0.0|[5.0,3.4,1.5,0.2]|
|            0.0|[4.4,2.9,1.4,0.2]|
|            0.0|[4.9,3.1,1.5,0.1]|
|            0.0|[5.4,3.7,1.5,0.2]|
|            0.0|[4.8,3.4,1.6,0.2]|
|            0.0|[4.8,3.0,1.4,0.1]|
|            0.0|[4.3,3.0,1.1,0.1]|
|            0.0|[5.8,4.0,1.2,0.2]|
|            0.0|[5.7,4.4,1.5,0.4]|
|            0.0|[5.4,3.9,1.3,0.4]|
|            0.0|[5.1,3.5,1.4,0.3]|
|            0.0|[5.7,3.8,1.7,0.3]|
|            0.0|[5.1,3.8,1.5,0.3]|
+---------------+-----------------+
only showing top 20 rows



In [63]:
# Keep only the label index and the feature vector as the modelling dataset.
final_df = assembled_df.select('indexed_variety','features')

In [64]:
# Split into roughly 80% training / 20% test rows.
# A fixed seed makes the split repeatable, so the row counts and all metrics
# computed below stay the same across kernel restarts instead of changing on
# every run.
train_data, test_data = final_df.randomSplit([0.8, 0.2], seed=42)

In [65]:
# Row counts of the (train, test) splits.
train_data.count(), test_data.count()

(118, 32)

In [66]:
from pyspark.ml.classification import DecisionTreeClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

In [67]:
# Train a decision tree on the training split, using `features` as input and
# `indexed_variety` as the label.
# NOTE: `df_classifier` is the fitted model, not a DataFrame.
df_classifier = (
    DecisionTreeClassifier(featuresCol="features", labelCol="indexed_variety")
    .fit(train_data)
)

In [69]:
# Score the held-out test rows with the fitted tree; the result carries the
# rawPrediction, probability and prediction columns shown in the preview.
df_pred = df_classifier.transform(test_data)
df_pred.show()

+---------------+-----------------+--------------+-------------+----------+
|indexed_variety|         features| rawPrediction|  probability|prediction|
+---------------+-----------------+--------------+-------------+----------+
|            0.0|[4.7,3.2,1.3,0.2]|[37.0,0.0,0.0]|[1.0,0.0,0.0]|       0.0|
|            0.0|[4.8,3.0,1.4,0.3]|[37.0,0.0,0.0]|[1.0,0.0,0.0]|       0.0|
|            0.0|[4.9,3.1,1.5,0.2]|[37.0,0.0,0.0]|[1.0,0.0,0.0]|       0.0|
|            0.0|[5.0,3.3,1.4,0.2]|[37.0,0.0,0.0]|[1.0,0.0,0.0]|       0.0|
|            0.0|[5.0,3.6,1.4,0.2]|[37.0,0.0,0.0]|[1.0,0.0,0.0]|       0.0|
|            0.0|[5.1,3.4,1.5,0.2]|[37.0,0.0,0.0]|[1.0,0.0,0.0]|       0.0|
|            0.0|[5.1,3.8,1.5,0.3]|[37.0,0.0,0.0]|[1.0,0.0,0.0]|       0.0|
|            0.0|[5.1,3.8,1.9,0.4]|[37.0,0.0,0.0]|[1.0,0.0,0.0]|       0.0|
|            0.0|[5.2,3.4,1.4,0.2]|[37.0,0.0,0.0]|[1.0,0.0,0.0]|       0.0|
|            0.0|[5.4,3.4,1.5,0.4]|[37.0,0.0,0.0]|[1.0,0.0,0.0]|       0.0|
|           

In [76]:
# Evaluate the test-set predictions with one shared evaluator; the metric is
# selected per call through a param override instead of building a second
# evaluator for each metric.
evaluator = MulticlassClassificationEvaluator(labelCol="indexed_variety")

accuracy = evaluator.evaluate(df_pred, {evaluator.metricName: "accuracy"})
precision = evaluator.evaluate(df_pred, {evaluator.metricName: "weightedPrecision"})

# Backward-compatible alias: the original (misspelled) name is still what the
# following cell displays.
accruacy = accuracy

In [77]:
# Display (accuracy, weighted precision) on the test split.
# NOTE: `accruacy` is a misspelling of "accuracy"; the name matches the cell above.
accruacy,precision

(0.90625, 0.9090277777777778)

In [78]:
# Feature importances of the fitted tree. Indices refer to positions in the
# `features` vector, i.e. the assembler's inputCols order
# (0=sepal_length, 1=sepal_width, 2=petal_length, 3=petal_width).
df_classifier.featureImportances

SparseVector(4, {2: 0.5367, 3: 0.4633})

In [79]:
# With no metricName given, MulticlassClassificationEvaluator reports its
# default metric, weighted F1 -- not an area under the ROC curve. State the
# metric explicitly and store it under a name that says what it is.
f1 = MulticlassClassificationEvaluator(labelCol="indexed_variety",
                                       metricName="f1").evaluate(df_pred)

# Backward-compatible alias: the following cell still displays `auc`.
auc = f1

In [80]:
# Display the score computed in the preceding cell.
# NOTE(review): despite the name, that evaluator was built without a
# metricName, so this is the evaluator's default metric rather than an AUC --
# confirm against the MulticlassClassificationEvaluator documentation.
auc

0.90625