In [53]:
import os

import findspark; findspark.init()

from pyspark.ml.classification import LogisticRegression
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml.feature import StringIndexer, VectorAssembler
from pyspark.sql import SparkSession
from pyspark.sql.types import StructField, StructType
from pyspark.sql.types import DoubleType, StringType


In [4]:
# Start (or reuse) a local Spark session that uses every available core,
# then expose its SparkContext and display the Spark version.
spark = (
    SparkSession.builder
    .master('local[*]')
    .appName('Iris_multiclass')
    .getOrCreate()
)
sc = spark.sparkContext
sc.version

'3.0.0-preview2'

In [27]:
# Read Iris Data
# The file location can be overridden with the IRIS_DATA_PATH environment
# variable so the notebook runs on other machines; when the variable is not
# set, the original local path is used unchanged.
path = os.environ.get(
    'IRIS_DATA_PATH',
    r"C:\Users\se.vi.dmitriev\Downloads\DS Materials\scripts\iris.data",
)

# The CSV is read without a header option, so column names and types are
# supplied through an explicit schema: four numeric measurements followed by
# the species name.
schema = StructType([
    StructField('sepal_length', DoubleType(), True),
    StructField('sepal_width', DoubleType(), True),
    StructField('petal_length', DoubleType(), True),
    StructField('petal_width', DoubleType(), True),
    StructField('class', StringType(), True)
])
data = spark.read.csv(path, schema=schema)

# Preview the first rows without truncating cell contents.
data.show(5, truncate=False)


+------------+-----------+------------+-----------+-----------+
|sepal_length|sepal_width|petal_length|petal_width|class      |
+------------+-----------+------------+-----------+-----------+
|5.1         |3.5        |1.4         |0.2        |Iris-setosa|
|4.9         |3.0        |1.4         |0.2        |Iris-setosa|
|4.7         |3.2        |1.3         |0.2        |Iris-setosa|
|4.6         |3.1        |1.5         |0.2        |Iris-setosa|
|5.0         |3.6        |1.4         |0.2        |Iris-setosa|
+------------+-----------+------------+-----------+-----------+
only showing top 5 rows



In [28]:
# Print the column names, types and nullability of the loaded DataFrame.
data.printSchema()

root
 |-- sepal_length: double (nullable = true)
 |-- sepal_width: double (nullable = true)
 |-- petal_length: double (nullable = true)
 |-- petal_width: double (nullable = true)
 |-- class: string (nullable = true)



In [30]:
# Show summary statistics (count, mean, stddev, min, max) for every column.
data.describe().show()

+-------+------------------+-------------------+------------------+------------------+--------------+
|summary|      sepal_length|        sepal_width|      petal_length|       petal_width|         class|
+-------+------------------+-------------------+------------------+------------------+--------------+
|  count|               150|                150|               150|               150|           150|
|   mean| 5.843333333333335| 3.0540000000000007|3.7586666666666693|1.1986666666666672|          null|
| stddev|0.8280661279778637|0.43359431136217375| 1.764420419952262|0.7631607417008414|          null|
|    min|               4.3|                2.0|               1.0|               0.1|   Iris-setosa|
|    max|               7.9|                4.4|               6.9|               2.5|Iris-virginica|
+-------+------------------+-------------------+------------------+------------------+--------------+



In [37]:
# Encode the string `class` column as a numeric `class_label` column.
# The indexer is fitted on the full dataset and then applied to it.
labelIndexer = StringIndexer(outputCol='class_label', inputCol='class')
label_indexer_model = labelIndexer.fit(data)
data_labeled = label_indexer_model.transform(data)
data_labeled.show(10, truncate=False)

+------------+-----------+------------+-----------+-----------+-----------+
|sepal_length|sepal_width|petal_length|petal_width|class      |class_label|
+------------+-----------+------------+-----------+-----------+-----------+
|5.1         |3.5        |1.4         |0.2        |Iris-setosa|0.0        |
|4.9         |3.0        |1.4         |0.2        |Iris-setosa|0.0        |
|4.7         |3.2        |1.3         |0.2        |Iris-setosa|0.0        |
|4.6         |3.1        |1.5         |0.2        |Iris-setosa|0.0        |
|5.0         |3.6        |1.4         |0.2        |Iris-setosa|0.0        |
|5.4         |3.9        |1.7         |0.4        |Iris-setosa|0.0        |
|4.6         |3.4        |1.4         |0.3        |Iris-setosa|0.0        |
|5.0         |3.4        |1.5         |0.2        |Iris-setosa|0.0        |
|4.4         |2.9        |1.4         |0.2        |Iris-setosa|0.0        |
|4.9         |3.1        |1.5         |0.1        |Iris-setosa|0.0        |
+-----------

In [43]:
# Pack the four measurement columns into a single `features` vector column,
# which is the input format the classifier below expects.
data_featured = VectorAssembler(
    inputCols=['sepal_length', 'sepal_width', 'petal_length', 'petal_width'],
    outputCol='features',
).transform(data_labeled)
data_featured.show(10, truncate=False)

+------------+-----------+------------+-----------+-----------+-----------+-----------------+
|sepal_length|sepal_width|petal_length|petal_width|class      |class_label|features         |
+------------+-----------+------------+-----------+-----------+-----------+-----------------+
|5.1         |3.5        |1.4         |0.2        |Iris-setosa|0.0        |[5.1,3.5,1.4,0.2]|
|4.9         |3.0        |1.4         |0.2        |Iris-setosa|0.0        |[4.9,3.0,1.4,0.2]|
|4.7         |3.2        |1.3         |0.2        |Iris-setosa|0.0        |[4.7,3.2,1.3,0.2]|
|4.6         |3.1        |1.5         |0.2        |Iris-setosa|0.0        |[4.6,3.1,1.5,0.2]|
|5.0         |3.6        |1.4         |0.2        |Iris-setosa|0.0        |[5.0,3.6,1.4,0.2]|
|5.4         |3.9        |1.7         |0.4        |Iris-setosa|0.0        |[5.4,3.9,1.7,0.4]|
|4.6         |3.4        |1.4         |0.3        |Iris-setosa|0.0        |[4.6,3.4,1.4,0.3]|
|5.0         |3.4        |1.5         |0.2        |Iris-seto

In [47]:
# Estimate correlation between each feature column and the indexed class label.
# NOTE(review): `class_label` is an index assigned by StringIndexer, so these
# values depend on how the classes were numbered — treat them as a rough guide.
for feature_name in ('sepal_length', 'sepal_width', 'petal_length', 'petal_width'):
    print(data_featured.corr(feature_name, 'class_label'))

0.7825612318100821
-0.41944620026002677
0.9490425448523336
0.9564638238016178


In [48]:
# Split the assembled data into roughly 80% training and 20% test rows.
# A fixed seed is passed so the split is requested reproducibly.
data_featured_train, data_featured_test = data_featured.randomSplit(
    weights=[0.8, 0.2],
    seed=1234,
)

In [52]:
# Fit a logistic-regression classifier on the training split, then score the
# held-out test split and show true labels next to predicted labels.
lg_clf = LogisticRegression(labelCol='class_label', featuresCol='features')
model = lg_clf.fit(data_featured_train)
predictions = model.transform(data_featured_test)
predictions.select(
    'sepal_length', 'sepal_width', 'petal_length', 'petal_width',
    'class_label', 'prediction',
).show(20, truncate=False)

+------------+-----------+------------+-----------+-----------+----------+
|sepal_length|sepal_width|petal_length|petal_width|class_label|prediction|
+------------+-----------+------------+-----------+-----------+----------+
|4.4         |2.9        |1.4         |0.2        |0.0        |0.0       |
|4.5         |2.3        |1.3         |0.3        |0.0        |1.0       |
|4.9         |3.1        |1.5         |0.1        |0.0        |0.0       |
|5.0         |3.0        |1.6         |0.2        |0.0        |0.0       |
|5.0         |3.2        |1.2         |0.2        |0.0        |0.0       |
|5.0         |3.3        |1.4         |0.2        |0.0        |0.0       |
|5.0         |3.4        |1.5         |0.2        |0.0        |0.0       |
|5.1         |3.5        |1.4         |0.3        |0.0        |0.0       |
|5.3         |3.7        |1.5         |0.2        |0.0        |0.0       |
|5.4         |3.4        |1.5         |0.4        |0.0        |0.0       |
|5.5         |2.3        

In [51]:
# Print the schema of the scored DataFrame, including the columns added by the model.
predictions.printSchema()

root
 |-- sepal_length: double (nullable = true)
 |-- sepal_width: double (nullable = true)
 |-- petal_length: double (nullable = true)
 |-- petal_width: double (nullable = true)
 |-- class: string (nullable = true)
 |-- class_label: double (nullable = false)
 |-- features: vector (nullable = true)
 |-- rawPrediction: vector (nullable = true)
 |-- probability: vector (nullable = true)
 |-- prediction: double (nullable = false)



In [55]:
# Score the test-set predictions: accuracy of `prediction` against `class_label`.
evaluator = MulticlassClassificationEvaluator(
    labelCol='class_label',
    predictionCol='prediction',
    metricName='accuracy',
)
evaluator.evaluate(predictions)

0.972972972972973

In [57]:
# Get confusion matrix
# Long format: one row per (true label, predicted label) pair that occurs in
# the test predictions, with the number of rows for that pair. Pairs that never
# occur are simply absent.
# Row order after a groupby/count is not guaranteed, so sort explicitly to keep
# the displayed table stable from run to run.
predictions.groupby('class_label', 'prediction').count().orderBy('class_label', 'prediction').show()

+-----------+----------+-----+
|class_label|prediction|count|
+-----------+----------+-----+
|        0.0|       0.0|   13|
|        0.0|       1.0|    1|
|        1.0|       1.0|   12|
|        2.0|       2.0|   11|
+-----------+----------+-----+

