In [1]:
from pyspark.sql import SparkSession

In [2]:
# Create the Spark session for this notebook, or reuse one that already exists.
# `SparkSession.builder` is the documented entry point for configuring a
# session; appName sets the application name to 'iris'.
spark = SparkSession.builder.appName('iris').getOrCreate()

In [3]:
from pyspark.sql.types import StructType, StructField, DecimalType, StringType

In [4]:
# Column layout of the header-less CSV: four measurements, then the class label.
# DecimalType(scale=1) resolves to decimal(10,1). Every field is declared
# non-nullable, although the schema printed later in this notebook still
# reports these columns as nullable.
data_schema = StructType(
    [
        StructField(measurement, DecimalType(scale=1), False)
        for measurement in ('sepal_l', 'sepal_w', 'petal_l', 'petal_w')
    ]
    + [StructField('class', StringType(), False)]
)

Data source: [Iris dataset, UCI Machine Learning Repository](http://archive.ics.uci.edu/ml/datasets/Iris)

In [5]:
# Load the raw measurements with the explicit schema; the file has no header row.
df = (
    spark.read
    .schema(data_schema)
    .csv('iris.data', header=False)
)

In [6]:
# Preview the first three rows to confirm the columns parsed as expected.
df.show(n=3)

+-------+-------+-------+-------+-----------+
|sepal_l|sepal_w|petal_l|petal_w|      class|
+-------+-------+-------+-------+-----------+
|    5.1|    3.5|    1.4|    0.2|Iris-setosa|
|    4.9|    3.0|    1.4|    0.2|Iris-setosa|
|    4.7|    3.2|    1.3|    0.2|Iris-setosa|
+-------+-------+-------+-------+-----------+
only showing top 3 rows



## Prepare data

### Convert the class label to a numeric index

In [7]:
from pyspark.ml.feature import StringIndexer

In [8]:
# Encode the string `class` column as a numeric `class_index` column,
# which is the label column the classifier is configured to read.
indexer = StringIndexer(
    outputCol='class_index',
    inputCol='class',
)

### Create the feature vector

In [9]:
from pyspark.ml.feature import VectorAssembler

In [10]:
# Pack the four measurement columns into a single `features` vector column.
features_assembler = VectorAssembler(
    outputCol='features',
    inputCols=[
        'sepal_l',
        'sepal_w',
        'petal_l',
        'petal_w',
    ],
)

## Declare the ML algorithm

In [11]:
from pyspark.ml.classification import LogisticRegression

In [12]:
# Unfitted logistic-regression classifier that reads the assembled `features`
# vector and the indexed `class_index` label.
regressor = LogisticRegression(
    labelCol='class_index',
    featuresCol='features',
)

## Use a pipeline for preprocessing

In [13]:
from pyspark.ml import Pipeline

In [14]:
# Preprocessing-only pipeline: label indexing first, then feature assembly.
# `regressor` is not a stage here; it is fitted separately on the training split.
pipeline = Pipeline(
    stages=[
        indexer,             # 'class' -> 'class_index'
        features_assembler,  # measurement columns -> 'features'
    ],
)

In [15]:
# Fit the preprocessing stages on the full dataset (this happens before the train/test split).
trained_pipeline = pipeline.fit(df)

In [16]:
# Apply the fitted stages, adding the `class_index` and `features` columns to the data.
final_data = trained_pipeline.transform(df)

In [17]:
# Preview the first three rows, now including the label index and feature vector.
final_data.show(n=3)

+-------+-------+-------+-------+-----------+-----------+-----------------+
|sepal_l|sepal_w|petal_l|petal_w|      class|class_index|         features|
+-------+-------+-------+-------+-----------+-----------+-----------------+
|    5.1|    3.5|    1.4|    0.2|Iris-setosa|        0.0|[5.1,3.5,1.4,0.2]|
|    4.9|    3.0|    1.4|    0.2|Iris-setosa|        0.0|[4.9,3.0,1.4,0.2]|
|    4.7|    3.2|    1.3|    0.2|Iris-setosa|        0.0|[4.7,3.2,1.3,0.2]|
+-------+-------+-------+-------+-----------+-----------+-----------------+
only showing top 3 rows



In [18]:
# Split the prepared data with weights 0.7 (training) and 0.3 (test).
# randomSplit is random, so a fixed seed is passed to keep the split, and
# therefore every metric computed from it, the same on each full re-run.
train_data, test_data = final_data.randomSplit([0.7, 0.3], seed=42)

In [19]:
# Train the logistic-regression model on the training split only.
regression_model = regressor.fit(train_data)

## Evaluate the model

In [20]:
# Score the held-out test split; the result carries the rawPrediction,
# probability and prediction columns alongside the input columns.
results = regression_model.transform(test_data)

In [21]:
# Inspect which columns (and types) are available after scoring.
results.printSchema()

root
 |-- sepal_l: decimal(10,1) (nullable = true)
 |-- sepal_w: decimal(10,1) (nullable = true)
 |-- petal_l: decimal(10,1) (nullable = true)
 |-- petal_w: decimal(10,1) (nullable = true)
 |-- class: string (nullable = true)
 |-- class_index: double (nullable = false)
 |-- features: vector (nullable = true)
 |-- rawPrediction: vector (nullable = true)
 |-- probability: vector (nullable = true)
 |-- prediction: double (nullable = false)



In [22]:
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

In [23]:
# Evaluator that measures the accuracy of `prediction` against the
# `class_index` label.
evaluator = MulticlassClassificationEvaluator(
    metricName='accuracy',
    labelCol='class_index',
    predictionCol='prediction',
)

In [24]:
# Compute the accuracy of the predictions made on the test split.
summary = evaluator.evaluate(results)

In [25]:
# Display the test-split accuracy.
summary

0.9583333333333334

In [26]:
# Sanity check: evaluate() returned a plain Python float, not a DataFrame.
type(summary)

float

In [43]:
# Show the true label next to the model's raw scores, class probabilities and
# predicted label for three scored test rows, without truncating the vectors.
results.select(
    'class_index',
    'rawPrediction',
    'probability',
    'prediction',
).show(n=3, truncate=False)

+-----------+-----------------------------------------------------------+-------------+----------+
|class_index|rawPrediction                                              |probability  |prediction|
+-----------+-----------------------------------------------------------+-------------+----------+
|0.0        |[2180.627938136404,-425.1580266562824,-1755.4699114801222] |[1.0,0.0,0.0]|0.0       |
|0.0        |[2355.876744774526,-459.59034946521734,-1896.286395309308] |[1.0,0.0,0.0]|0.0       |
|0.0        |[2254.787869217417,-457.61298391348333,-1797.1748853039335]|[1.0,0.0,0.0]|0.0       |
+-----------+-----------------------------------------------------------+-------------+----------+
only showing top 3 rows

