In [2]:
from __future__ import print_function
import findspark
findspark.init()
findspark.find()
import pyspark
findspark.find()
from pyspark.sql import SparkSession
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.classification import DecisionTreeClassifier
from pyspark.ml.feature import StandardScaler
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml.feature import StringIndexer

In [3]:
if __name__ == "__main__":
    spark = SparkSession\
        .builder\
        .appName("DecisionTreeWithSpark")\
        .getOrCreate()

In [4]:
dataset = spark.read.csv("winequality_red.csv",header=True)

In [5]:
type(dataset)

pyspark.sql.dataframe.DataFrame

In [6]:
dataset.show()

+-------------+----------------+-----------+--------------+-------------------+-------------------+--------------------+-------+----+---------+-------+-------+
|fixed acidity|volatile acidity|citric acid|residual sugar|          chlorides|free sulfur dioxide|total sulfur dioxide|density|  pH|sulphates|alcohol|quality|
+-------------+----------------+-----------+--------------+-------------------+-------------------+--------------------+-------+----+---------+-------+-------+
|          7.4|             0.7|        0.0|           1.9|              0.076|               11.0|                34.0| 0.9978|3.51|     0.56|    9.4|      5|
|          7.8|            0.88|        0.0|           2.6|              0.098|               25.0|                67.0| 0.9968| 3.2|     0.68|    9.8|      5|
|          7.8|            0.76|       0.04|           2.3|              0.092|               15.0|                54.0|  0.997|3.26|     0.65|    9.8|      5|
|         11.2|            0.28|       0

In [7]:
dataset.printSchema()

root
 |-- fixed acidity: string (nullable = true)
 |-- volatile acidity: string (nullable = true)
 |-- citric acid: string (nullable = true)
 |-- residual sugar: string (nullable = true)
 |-- chlorides: string (nullable = true)
 |-- free sulfur dioxide: string (nullable = true)
 |-- total sulfur dioxide: string (nullable = true)
 |-- density: string (nullable = true)
 |-- pH: string (nullable = true)
 |-- sulphates: string (nullable = true)
 |-- alcohol: string (nullable = true)
 |-- quality: string (nullable = true)



In [8]:
from pyspark.sql.functions import col
new_data = dataset.select(*(col(c).cast("float").alias(c) for c in dataset.columns))

In [9]:
new_data.printSchema()

root
 |-- fixed acidity: float (nullable = true)
 |-- volatile acidity: float (nullable = true)
 |-- citric acid: float (nullable = true)
 |-- residual sugar: float (nullable = true)
 |-- chlorides: float (nullable = true)
 |-- free sulfur dioxide: float (nullable = true)
 |-- total sulfur dioxide: float (nullable = true)
 |-- density: float (nullable = true)
 |-- pH: float (nullable = true)
 |-- sulphates: float (nullable = true)
 |-- alcohol: float (nullable = true)
 |-- quality: float (nullable = true)



In [10]:
from pyspark.sql.functions import col, count, isnan, when
#checking for null ir nan type values in our columns
new_data.select([count(when(col(c).isNull(), c)).alias(c) for c in new_data.columns]).show()

+-------------+----------------+-----------+--------------+---------+-------------------+--------------------+-------+---+---------+-------+-------+
|fixed acidity|volatile acidity|citric acid|residual sugar|chlorides|free sulfur dioxide|total sulfur dioxide|density| pH|sulphates|alcohol|quality|
+-------------+----------------+-----------+--------------+---------+-------------------+--------------------+-------+---+---------+-------+-------+
|            0|               0|          0|             0|        0|                  0|                   0|      0|  0|        0|      0|      0|
+-------------+----------------+-----------+--------------+---------+-------------------+--------------------+-------+---+---------+-------+-------+



In [11]:
cols=new_data.columns
cols.remove("quality")
assembler = VectorAssembler(inputCols=cols,outputCol="features")
# Now let us use the transform method to transform our dataset
data=assembler.transform(new_data)
data = data.select("features",'quality')

In [12]:
data.show()

+--------------------+-------+
|            features|quality|
+--------------------+-------+
|[7.40000009536743...|    5.0|
|[7.80000019073486...|    5.0|
|[7.80000019073486...|    5.0|
|[11.1999998092651...|    6.0|
|[7.40000009536743...|    5.0|
|[7.40000009536743...|    5.0|
|[7.90000009536743...|    5.0|
|[7.30000019073486...|    7.0|
|[7.80000019073486...|    7.0|
|[7.5,0.5,0.360000...|    5.0|
|[6.69999980926513...|    5.0|
|[7.5,0.5,0.360000...|    5.0|
|[5.59999990463256...|    5.0|
|[7.80000019073486...|    5.0|
|[8.89999961853027...|    5.0|
|[8.89999961853027...|    5.0|
|[8.5,0.2800000011...|    7.0|
|[8.10000038146972...|    5.0|
|[7.40000009536743...|    4.0|
|[7.90000009536743...|    6.0|
+--------------------+-------+
only showing top 20 rows



In [13]:
stringIndexer = StringIndexer(inputCol="quality", outputCol="quality_index")
data_indexed = stringIndexer.fit(data).transform(data)

In [14]:
data_indexed.show()

+--------------------+-------+-------------+
|            features|quality|quality_index|
+--------------------+-------+-------------+
|[7.40000009536743...|    5.0|          0.0|
|[7.80000019073486...|    5.0|          0.0|
|[7.80000019073486...|    5.0|          0.0|
|[11.1999998092651...|    6.0|          1.0|
|[7.40000009536743...|    5.0|          0.0|
|[7.40000009536743...|    5.0|          0.0|
|[7.90000009536743...|    5.0|          0.0|
|[7.30000019073486...|    7.0|          2.0|
|[7.80000019073486...|    7.0|          2.0|
|[7.5,0.5,0.360000...|    5.0|          0.0|
|[6.69999980926513...|    5.0|          0.0|
|[7.5,0.5,0.360000...|    5.0|          0.0|
|[5.59999990463256...|    5.0|          0.0|
|[7.80000019073486...|    5.0|          0.0|
|[8.89999961853027...|    5.0|          0.0|
|[8.89999961853027...|    5.0|          0.0|
|[8.5,0.2800000011...|    7.0|          2.0|
|[8.10000038146972...|    5.0|          0.0|
|[7.40000009536743...|    4.0|          3.0|
|[7.900000

In [15]:
(train, test) = data_indexed.randomSplit([0.7, 0.3])

In [16]:
dt = DecisionTreeClassifier(labelCol="quality_index", featuresCol="features")

In [17]:
model = dt.fit(train)

In [18]:
predictions = model.transform(test)

In [19]:
predictions.show()

+--------------------+-------+-------------+--------------------+--------------------+----------+
|            features|quality|quality_index|       rawPrediction|         probability|prediction|
+--------------------+-------+-------------+--------------------+--------------------+----------+
|[4.69999980926513...|    6.0|          1.0|[4.0,40.0,16.0,1....|[0.06557377049180...|       1.0|
|[5.0,1.0199999809...|    4.0|          3.0|[0.0,7.0,0.0,1.0,...|[0.0,0.7777777777...|       1.0|
|[5.0,1.0399999618...|    5.0|          0.0|[50.0,163.0,25.0,...|[0.19841269841269...|       1.0|
|[5.09999990463256...|    7.0|          2.0|[2.0,14.0,24.0,0....|[0.04166666666666...|       2.0|
|[5.09999990463256...|    7.0|          2.0|[2.0,14.0,24.0,0....|[0.04166666666666...|       2.0|
|[5.09999990463256...|    7.0|          2.0|[2.0,14.0,24.0,0....|[0.04166666666666...|       2.0|
|[5.19999980926513...|    6.0|          1.0|[1.0,11.0,7.0,0.0...|[0.05263157894736...|       1.0|
|[5.19999980926513..

In [20]:
predictions.select("prediction", "quality_index", "features").show(5)

+----------+-------------+--------------------+
|prediction|quality_index|            features|
+----------+-------------+--------------------+
|       1.0|          1.0|[4.69999980926513...|
|       1.0|          3.0|[5.0,1.0199999809...|
|       1.0|          0.0|[5.0,1.0399999618...|
|       2.0|          2.0|[5.09999990463256...|
|       2.0|          2.0|[5.09999990463256...|
+----------+-------------+--------------------+
only showing top 5 rows



In [21]:
# Select (prediction, true label) and compute test error
evaluator = MulticlassClassificationEvaluator(
    labelCol="quality_index", predictionCol="prediction", metricName="accuracy")
accuracy = evaluator.evaluate(predictions)


In [22]:
print ("Accuracy",accuracy)

Accuracy 0.5155925155925156


In [23]:
spark.stop()