In [1]:
from __future__ import print_function
import findspark
findspark.init()
findspark.find()
import pyspark
findspark.find()
from pyspark.sql import SparkSession
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.feature import StandardScaler
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml.feature import StringIndexer

In [2]:
if __name__ == "__main__":
    spark = SparkSession\
        .builder\
        .appName("randomForestClassifier")\
        .getOrCreate()

In [3]:
dataset = spark.read.csv("winequality_red.csv",header=True)

In [4]:
dataset.show()

+-------------+----------------+-----------+--------------+-------------------+-------------------+--------------------+-------+----+---------+-------+-------+
|fixed acidity|volatile acidity|citric acid|residual sugar|          chlorides|free sulfur dioxide|total sulfur dioxide|density|  pH|sulphates|alcohol|quality|
+-------------+----------------+-----------+--------------+-------------------+-------------------+--------------------+-------+----+---------+-------+-------+
|          7.4|             0.7|        0.0|           1.9|              0.076|               11.0|                34.0| 0.9978|3.51|     0.56|    9.4|      5|
|          7.8|            0.88|        0.0|           2.6|              0.098|               25.0|                67.0| 0.9968| 3.2|     0.68|    9.8|      5|
|          7.8|            0.76|       0.04|           2.3|              0.092|               15.0|                54.0|  0.997|3.26|     0.65|    9.8|      5|
|         11.2|            0.28|       0

In [5]:
dataset.printSchema()

root
 |-- fixed acidity: string (nullable = true)
 |-- volatile acidity: string (nullable = true)
 |-- citric acid: string (nullable = true)
 |-- residual sugar: string (nullable = true)
 |-- chlorides: string (nullable = true)
 |-- free sulfur dioxide: string (nullable = true)
 |-- total sulfur dioxide: string (nullable = true)
 |-- density: string (nullable = true)
 |-- pH: string (nullable = true)
 |-- sulphates: string (nullable = true)
 |-- alcohol: string (nullable = true)
 |-- quality: string (nullable = true)



In [6]:
from pyspark.sql.functions import col
new_data = dataset.select(*(col(c).cast("float").alias(c) for c in dataset.columns))

In [7]:
new_data.printSchema()

root
 |-- fixed acidity: float (nullable = true)
 |-- volatile acidity: float (nullable = true)
 |-- citric acid: float (nullable = true)
 |-- residual sugar: float (nullable = true)
 |-- chlorides: float (nullable = true)
 |-- free sulfur dioxide: float (nullable = true)
 |-- total sulfur dioxide: float (nullable = true)
 |-- density: float (nullable = true)
 |-- pH: float (nullable = true)
 |-- sulphates: float (nullable = true)
 |-- alcohol: float (nullable = true)
 |-- quality: float (nullable = true)



In [8]:
from pyspark.sql.functions import col, count, isnan, when
#checking for null ir nan type values in our columns
new_data.select([count(when(col(c).isNull(), c)).alias(c) for c in new_data.columns]).show()

+-------------+----------------+-----------+--------------+---------+-------------------+--------------------+-------+---+---------+-------+-------+
|fixed acidity|volatile acidity|citric acid|residual sugar|chlorides|free sulfur dioxide|total sulfur dioxide|density| pH|sulphates|alcohol|quality|
+-------------+----------------+-----------+--------------+---------+-------------------+--------------------+-------+---+---------+-------+-------+
|            0|               0|          0|             0|        0|                  0|                   0|      0|  0|        0|      0|      0|
+-------------+----------------+-----------+--------------+---------+-------------------+--------------------+-------+---+---------+-------+-------+



In [9]:
cols=new_data.columns
cols.remove("quality")
assembler = VectorAssembler(inputCols=cols,outputCol="features")
# Now let us use the transform method to transform our dataset
data=assembler.transform(new_data)
data = data.select("features",'quality')

In [10]:
data.show()

+--------------------+-------+
|            features|quality|
+--------------------+-------+
|[7.40000009536743...|    5.0|
|[7.80000019073486...|    5.0|
|[7.80000019073486...|    5.0|
|[11.1999998092651...|    6.0|
|[7.40000009536743...|    5.0|
|[7.40000009536743...|    5.0|
|[7.90000009536743...|    5.0|
|[7.30000019073486...|    7.0|
|[7.80000019073486...|    7.0|
|[7.5,0.5,0.360000...|    5.0|
|[6.69999980926513...|    5.0|
|[7.5,0.5,0.360000...|    5.0|
|[5.59999990463256...|    5.0|
|[7.80000019073486...|    5.0|
|[8.89999961853027...|    5.0|
|[8.89999961853027...|    5.0|
|[8.5,0.2800000011...|    7.0|
|[8.10000038146972...|    5.0|
|[7.40000009536743...|    4.0|
|[7.90000009536743...|    6.0|
+--------------------+-------+
only showing top 20 rows



In [11]:
stringIndexer = StringIndexer(inputCol="quality", outputCol="quality_index")
data_indexed = stringIndexer.fit(data).transform(data)

In [12]:
data_indexed.show()

+--------------------+-------+-------------+
|            features|quality|quality_index|
+--------------------+-------+-------------+
|[7.40000009536743...|    5.0|          0.0|
|[7.80000019073486...|    5.0|          0.0|
|[7.80000019073486...|    5.0|          0.0|
|[11.1999998092651...|    6.0|          1.0|
|[7.40000009536743...|    5.0|          0.0|
|[7.40000009536743...|    5.0|          0.0|
|[7.90000009536743...|    5.0|          0.0|
|[7.30000019073486...|    7.0|          2.0|
|[7.80000019073486...|    7.0|          2.0|
|[7.5,0.5,0.360000...|    5.0|          0.0|
|[6.69999980926513...|    5.0|          0.0|
|[7.5,0.5,0.360000...|    5.0|          0.0|
|[5.59999990463256...|    5.0|          0.0|
|[7.80000019073486...|    5.0|          0.0|
|[8.89999961853027...|    5.0|          0.0|
|[8.89999961853027...|    5.0|          0.0|
|[8.5,0.2800000011...|    7.0|          2.0|
|[8.10000038146972...|    5.0|          0.0|
|[7.40000009536743...|    4.0|          3.0|
|[7.900000

In [13]:
(train, test) = data_indexed.randomSplit([0.7, 0.3])

In [14]:
random_forest_classifier = RandomForestClassifier(labelCol="quality_index", featuresCol="features", numTrees=10)

In [15]:
model = random_forest_classifier.fit(train)

In [16]:
predictions = model.transform(test)

In [17]:
predictions.show()

+--------------------+-------+-------------+--------------------+--------------------+----------+
|            features|quality|quality_index|       rawPrediction|         probability|prediction|
+--------------------+-------+-------------+--------------------+--------------------+----------+
|[4.69999980926513...|    6.0|          1.0|[4.32319937367372...|[0.43231993736737...|       0.0|
|[4.90000009536743...|    7.0|          2.0|[1.44811678034167...|[0.14481167803416...|       1.0|
|[5.09999990463256...|    7.0|          2.0|[1.99423718167946...|[0.19942371816794...|       1.0|
|[5.09999990463256...|    7.0|          2.0|[2.08890398054961...|[0.20889039805496...|       1.0|
|[5.19999980926513...|    5.0|          0.0|[6.86608360822343...|[0.68660836082234...|       0.0|
|[5.19999980926513...|    6.0|          1.0|[0.68216204751131...|[0.06821620475113...|       1.0|
|[5.30000019073486...|    7.0|          2.0|[1.66470864569932...|[0.16647086456993...|       2.0|
|[5.40000009536743..

In [18]:
# Select (prediction, true label) and compute test error
evaluator = MulticlassClassificationEvaluator(
    labelCol="quality_index", predictionCol="prediction", metricName="accuracy")
accuracy = evaluator.evaluate(predictions)


In [19]:
print ("Accuracy",accuracy)

Accuracy 0.6042884990253411


In [20]:
spark.stop()