In [None]:
!pip install pyspark

Collecting pyspark
  Downloading pyspark-3.2.1.tar.gz (281.4 MB)
[K     |████████████████████████████████| 281.4 MB 33 kB/s 
[?25hCollecting py4j==0.10.9.3
  Downloading py4j-0.10.9.3-py2.py3-none-any.whl (198 kB)
[K     |████████████████████████████████| 198 kB 33.0 MB/s 
[?25hBuilding wheels for collected packages: pyspark
  Building wheel for pyspark (setup.py) ... [?25l[?25hdone
  Created wheel for pyspark: filename=pyspark-3.2.1-py2.py3-none-any.whl size=281853642 sha256=48a5c48b81fdba8952f7186b8a01acd99f89624e4d98b9fbb48bb08cc38d685a
  Stored in directory: /root/.cache/pip/wheels/9f/f5/07/7cd8017084dce4e93e84e92efd1e1d5334db05f2e83bcef74f
Successfully built pyspark
Installing collected packages: py4j, pyspark
Successfully installed py4j-0.10.9.3 pyspark-3.2.1


In [None]:
from pyspark.sql import SparkSession
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.classification import DecisionTreeClassifier
from pyspark.ml.feature import StandardScaler
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

In [None]:
# Create (or reuse) the Spark session used by every later cell.
# No `if __name__ == "__main__"` guard: inside a notebook that check is always
# true, and if it were ever false `spark` would be undefined for the cells below.
spark = (
    SparkSession.builder
    .appName("DecisionTreeWithSpark")
    .getOrCreate()
)

In [None]:
# Load the red-wine CSV; its first row holds the column names. No schema
# inference is requested, so every column arrives as a string (cast later on).
dataset = spark.read.option("header", True).csv("winequality_red.csv")

In [None]:
# Sanity check: confirm the CSV reader returned a Spark DataFrame.
type(dataset)

pyspark.sql.dataframe.DataFrame

In [None]:
# Preview the raw rows (first 20, long cell values truncated).
dataset.show(n=20, truncate=True)

+-------------+----------------+-----------+--------------+-------------------+-------------------+--------------------+-------+----+---------+-------+-------+
|fixed acidity|volatile acidity|citric acid|residual sugar|          chlorides|free sulfur dioxide|total sulfur dioxide|density|  pH|sulphates|alcohol|quality|
+-------------+----------------+-----------+--------------+-------------------+-------------------+--------------------+-------+----+---------+-------+-------+
|          7.4|             0.7|        0.0|           1.9|              0.076|               11.0|                34.0| 0.9978|3.51|     0.56|    9.4|      5|
|          7.8|            0.88|        0.0|           2.6|              0.098|               25.0|                67.0| 0.9968| 3.2|     0.68|    9.8|      5|
|          7.8|            0.76|       0.04|           2.3|              0.092|               15.0|                54.0|  0.997|3.26|     0.65|    9.8|      5|
|         11.2|            0.28|       0

In [None]:
# Inspect column types: the CSV was read without schema inference, so expect
# every column to be a string until it is cast to a numeric type.
dataset.printSchema()

root
 |-- fixed acidity: string (nullable = true)
 |-- volatile acidity: string (nullable = true)
 |-- citric acid: string (nullable = true)
 |-- residual sugar: string (nullable = true)
 |-- chlorides: string (nullable = true)
 |-- free sulfur dioxide: string (nullable = true)
 |-- total sulfur dioxide: string (nullable = true)
 |-- density: string (nullable = true)
 |-- pH: string (nullable = true)
 |-- sulphates: string (nullable = true)
 |-- alcohol: string (nullable = true)
 |-- quality: string (nullable = true)



In [None]:
from pyspark.sql.functions import col

# Cast every column (the features and the `quality` label) to a numeric type.
# Use "double" rather than "float": float32 cannot represent values such as 7.4
# exactly, and VectorAssembler stores double vectors, which exposes artifacts
# like 7.40000009536743... in the assembled features.
new_data = dataset.select(*(col(c).cast("double").alias(c) for c in dataset.columns))

new_data.printSchema()

root
 |-- fixed acidity: float (nullable = true)
 |-- volatile acidity: float (nullable = true)
 |-- citric acid: float (nullable = true)
 |-- residual sugar: float (nullable = true)
 |-- chlorides: float (nullable = true)
 |-- free sulfur dioxide: float (nullable = true)
 |-- total sulfur dioxide: float (nullable = true)
 |-- density: float (nullable = true)
 |-- pH: float (nullable = true)
 |-- sulphates: float (nullable = true)
 |-- alcohol: float (nullable = true)
 |-- quality: float (nullable = true)



In [None]:
from pyspark.sql.functions import col, count, isnan, when

# Count missing values per column. Nulls may come from values the cast above
# could not parse; NaN may come from literal "NaN" strings. Check both.
new_data.select(
    [count(when(col(c).isNull() | isnan(col(c)), c)).alias(c) for c in new_data.columns]
).show()

+-------------+----------------+-----------+--------------+---------+-------------------+--------------------+-------+---+---------+-------+-------+
|fixed acidity|volatile acidity|citric acid|residual sugar|chlorides|free sulfur dioxide|total sulfur dioxide|density| pH|sulphates|alcohol|quality|
+-------------+----------------+-----------+--------------+---------+-------------------+--------------------+-------+---+---------+-------+-------+
|            0|               0|          0|             0|        0|                  0|                   0|      0|  0|        0|      0|      0|
+-------------+----------------+-----------+--------------+---------+-------------------+--------------------+-------+---+---------+-------+-------+



In [None]:
# Every column except the `quality` label is a model input.
cols = list(new_data.columns)
del cols[cols.index("quality")]
assembler = VectorAssembler(outputCol="features", inputCols=cols)

In [None]:
# Pack the feature columns into one vector column, then keep only what the
# classifier needs: the feature vector and the label.
data = assembler.transform(new_data).select("features", "quality")

In [None]:
# Preview the assembled (features, quality) rows.
data.show(20)

+--------------------+-------+
|            features|quality|
+--------------------+-------+
|[7.40000009536743...|    5.0|
|[7.80000019073486...|    5.0|
|[7.80000019073486...|    5.0|
|[11.1999998092651...|    6.0|
|[7.40000009536743...|    5.0|
|[7.40000009536743...|    5.0|
|[7.90000009536743...|    5.0|
|[7.30000019073486...|    7.0|
|[7.80000019073486...|    7.0|
|[7.5,0.5,0.360000...|    5.0|
|[6.69999980926513...|    5.0|
|[7.5,0.5,0.360000...|    5.0|
|[5.59999990463256...|    5.0|
|[7.80000019073486...|    5.0|
|[8.89999961853027...|    5.0|
|[8.89999961853027...|    5.0|
|[8.5,0.2800000011...|    7.0|
|[8.10000038146972...|    5.0|
|[7.40000009536743...|    4.0|
|[7.90000009536743...|    6.0|
+--------------------+-------+
only showing top 20 rows



In [None]:
from pyspark.ml.feature import StringIndexer

# Map each distinct quality value to a label index for the classifier.
# By default StringIndexer orders labels by descending frequency, so the
# indices are not the quality scores themselves.
stringIndexer = StringIndexer(outputCol="quality_index", inputCol="quality")
quality_index_model = stringIndexer.fit(data)
data_indexed = quality_index_model.transform(data)

data_indexed.show()

+--------------------+-------+-------------+
|            features|quality|quality_index|
+--------------------+-------+-------------+
|[7.40000009536743...|    5.0|          0.0|
|[7.80000019073486...|    5.0|          0.0|
|[7.80000019073486...|    5.0|          0.0|
|[11.1999998092651...|    6.0|          1.0|
|[7.40000009536743...|    5.0|          0.0|
|[7.40000009536743...|    5.0|          0.0|
|[7.90000009536743...|    5.0|          0.0|
|[7.30000019073486...|    7.0|          2.0|
|[7.80000019073486...|    7.0|          2.0|
|[7.5,0.5,0.360000...|    5.0|          0.0|
|[6.69999980926513...|    5.0|          0.0|
|[7.5,0.5,0.360000...|    5.0|          0.0|
|[5.59999990463256...|    5.0|          0.0|
|[7.80000019073486...|    5.0|          0.0|
|[8.89999961853027...|    5.0|          0.0|
|[8.89999961853027...|    5.0|          0.0|
|[8.5,0.2800000011...|    7.0|          2.0|
|[8.10000038146972...|    5.0|          0.0|
|[7.40000009536743...|    4.0|          3.0|
|[7.900000

In [None]:
# Fixed seed so the 70/30 split, and therefore the reported accuracy, is the
# same on every Restart & Run All.
SEED = 42
(train, test) = data_indexed.randomSplit([0.7, 0.3], seed=SEED)

dt = DecisionTreeClassifier(labelCol="quality_index", featuresCol="features", seed=SEED)

model = dt.fit(train)

In [None]:
# Score the held-out split; this adds rawPrediction, probability and
# prediction columns alongside the inputs.
predictions = model.transform(test)
predictions.show(n=20)

+--------------------+-------+-------------+--------------------+--------------------+----------+
|            features|quality|quality_index|       rawPrediction|         probability|prediction|
+--------------------+-------+-------------+--------------------+--------------------+----------+
|[4.59999990463256...|    4.0|          3.0|[10.0,43.0,7.0,0....|[0.16666666666666...|       1.0|
|[5.09999990463256...|    6.0|          1.0|[10.0,43.0,7.0,0....|[0.16666666666666...|       1.0|
|[5.19999980926513...|    6.0|          1.0|[10.0,43.0,7.0,0....|[0.16666666666666...|       1.0|
|[5.30000019073486...|    5.0|          0.0|[26.0,102.0,5.0,2...|[0.19259259259259...|       1.0|
|[5.40000009536743...|    6.0|          1.0|[10.0,43.0,7.0,0....|[0.16666666666666...|       1.0|
|[5.59999990463256...|    6.0|          1.0|[0.0,6.0,0.0,0.0,...|[0.0,1.0,0.0,0.0,...|       1.0|
|[5.59999990463256...|    5.0|          0.0|[10.0,43.0,7.0,0....|[0.16666666666666...|       1.0|
|[5.59999990463256..

In [None]:
predictions.select("prediction", "quality_index", "features").show(5)

# Compare predicted vs. true label index on the test split. "accuracy" is the
# fraction of rows where prediction == quality_index; test error is 1 - accuracy.
evaluator = MulticlassClassificationEvaluator(labelCol="quality_index", predictionCol="prediction", metricName="accuracy")
accuracy = evaluator.evaluate(predictions)

print("Accuracy", accuracy)
print("Test error", 1.0 - accuracy)

+----------+-------------+--------------------+
|prediction|quality_index|            features|
+----------+-------------+--------------------+
|       1.0|          3.0|[4.59999990463256...|
|       1.0|          1.0|[5.09999990463256...|
|       1.0|          1.0|[5.19999980926513...|
|       1.0|          0.0|[5.30000019073486...|
|       1.0|          1.0|[5.40000009536743...|
+----------+-------------+--------------------+
only showing top 5 rows

Accuracy 0.5823045267489712


In [None]:
# Shut down the Spark session created at the top of the notebook.
spark.stop()