In [1]:
from pyspark import SparkConf, SparkContext
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.ml.feature import VectorAssembler, StringIndexer
from pyspark.ml import Pipeline
from pyspark.ml.classification import DecisionTreeClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

In [2]:
# Reuse the session the PySpark shell/kernel already provides, or create one so the
# notebook also survives "Restart Kernel -> Run All" on a plain Python kernel where
# `spark` is not predefined.
spark = SparkSession.builder.getOrCreate()

# Raw string: "\S", "\I" and "\i" are not valid escape sequences, so a normal string
# literal only works by accident and emits a warning on recent Python versions.
# The resulting path is character-for-character the same as before.
# header=True takes the column names from the first CSV row; no schema is inferred,
# so every column is loaded as a string and converted explicitly later on.
iris_df = spark.read.csv(r"G:\SparkProject\Iris data\Iris Data\iris.csv", header=True)

In [3]:
# Peek at the first record to check the header names and the raw (string) values.
iris_df.limit(1).collect()

[Row(sepal.length='5.1', sepal.width='3.5', petal.length='1.4', petal.width='.2', variety='Setosa')]

In [4]:
# Confirm that the reader returned a Spark DataFrame.
frame_type = type(iris_df)
print(frame_type)

<class 'pyspark.sql.dataframe.DataFrame'>


In [5]:
# Swap the dotted CSV header names for underscore names and call the class column
# "species". Pairs are (current name, new name) and are applied in order.
column_renames = [
    ("sepal.length", "sepal_length"),
    ("sepal.width", "sepal_width"),
    ("petal.length", "petal_length"),
    ("petal.width", "petal_width"),
    ("variety", "species"),
]
for current_name, new_name in column_renames:
    iris_df = iris_df.withColumnRenamed(current_name, new_name)

In [6]:
# The CSV values were loaded as strings; convert the four measurement columns to
# doubles in place (same column names, new type).
for measurement in ('sepal_length', 'sepal_width', 'petal_length', 'petal_width'):
    iris_df = iris_df.withColumn(measurement, iris_df[measurement].cast("double"))

In [7]:
#iris_df.take(1)

In [8]:
# Pipeline stage 1: turn the string class name in "species" into a numeric "label".
indexer = StringIndexer(
    inputCol='species',
    outputCol='label',
)

In [9]:
# Pipeline stage 2: pack the four numeric measurements into one "features" vector.
feature_columns = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width']
vectorAssembler = VectorAssembler(inputCols=feature_columns, outputCol='features')

In [10]:
# Pipeline stage 3: decision tree trained on "features" to predict "label".
# All other hyper-parameters are left at the library defaults.
dclassifier = DecisionTreeClassifier(
    featuresCol='features',
    labelCol='label',
)

In [11]:
# Chain the stages in execution order: label indexing -> feature assembly -> classifier.
pipeline_stages = [indexer, vectorAssembler, dclassifier]
iris_pipe = Pipeline(stages=pipeline_stages)

In [12]:
# Random 60% / 40% train/test split; the fixed seed makes the split repeatable.
train_df, test_df = iris_df.randomSplit([0.6, 0.4], seed=1)

In [13]:
# Fit every pipeline stage on the training rows only ...
dt_model = iris_pipe.fit(dataset=train_df)
# ... then run the fitted pipeline over the held-out rows to get predictions.
dt_predictions = dt_model.transform(dataset=test_df)

In [14]:
# Evaluate: compare the "prediction" column against "label" using plain accuracy.
dt_evaluator = MulticlassClassificationEvaluator(
    labelCol='label',
    predictionCol='prediction',
    metricName='accuracy',
)

In [16]:
# Accuracy of the fitted pipeline on the held-out test predictions.
dt_accuracy = dt_evaluator.evaluate(dataset=dt_predictions)

In [17]:
# Show the test-set accuracy computed by the evaluator.
print(f"{dt_accuracy}")

0.9310344827586207


In [18]:
# Side-by-side view of predicted index, original class name and true index
# (show() prints the first 20 rows by default).
prediction_preview = dt_predictions.select('prediction', 'species', 'label')
prediction_preview.show()

+----------+----------+-----+
|prediction|   species|label|
+----------+----------+-----+
|       1.0|    Setosa|  1.0|
|       1.0|    Setosa|  1.0|
|       1.0|    Setosa|  1.0|
|       1.0|    Setosa|  1.0|
|       0.0|Versicolor|  0.0|
|       1.0|    Setosa|  1.0|
|       1.0|    Setosa|  1.0|
|       1.0|    Setosa|  1.0|
|       1.0|    Setosa|  1.0|
|       1.0|    Setosa|  1.0|
|       0.0|Versicolor|  0.0|
|       1.0|    Setosa|  1.0|
|       1.0|    Setosa|  1.0|
|       1.0|    Setosa|  1.0|
|       1.0|    Setosa|  1.0|
|       1.0|    Setosa|  1.0|
|       1.0|    Setosa|  1.0|
|       1.0|    Setosa|  1.0|
|       1.0|    Setosa|  1.0|
|       0.0|Versicolor|  0.0|
+----------+----------+-----+
only showing top 20 rows



In [19]:
# Confusion-matrix style summary: number of test rows for each (label, prediction)
# pair. No ordering is requested, so the rows are printed in whatever order Spark
# returns them.
confusion_counts = dt_predictions.groupBy('label', 'prediction').count()
confusion_counts.show()

+-----+----------+-----+
|label|prediction|count|
+-----+----------+-----+
|  2.0|       0.0|    3|
|  1.0|       1.0|   18|
|  2.0|       2.0|   21|
|  0.0|       0.0|   15|
|  0.0|       2.0|    1|
+-----+----------+-----+

