In [12]:
from pyspark.sql import SparkSession
from pyspark.sql.types import *


In [13]:
spark = SparkSession.builder.appName("iris_clf").getOrCreate()

In [14]:
df = spark.read.csv('iris.csv', header=True, inferSchema=True)
df.show(5)

+------------+-----------+------------+-----------+-------+
|sepal_length|sepal_width|petal_length|petal_width|variety|
+------------+-----------+------------+-----------+-------+
|         5.1|        3.5|         1.4|        0.2| Setosa|
|         4.9|        3.0|         1.4|        0.2| Setosa|
|         4.7|        3.2|         1.3|        0.2| Setosa|
|         4.6|        3.1|         1.5|        0.2| Setosa|
|         5.0|        3.6|         1.4|        0.2| Setosa|
+------------+-----------+------------+-----------+-------+
only showing top 5 rows



In [15]:
df.printSchema()

root
 |-- sepal_length: double (nullable = true)
 |-- sepal_width: double (nullable = true)
 |-- petal_length: double (nullable = true)
 |-- petal_width: double (nullable = true)
 |-- variety: string (nullable = true)



In [16]:
schema = StructType([
    StructField("sepal_length", DoubleType()),
    StructField("sepal_width", DoubleType()),
    StructField("petal_length", DoubleType()),
    StructField("petal_width", DoubleType()),
    StructField("type", StringType(), True)
])

In [17]:
df2 = spark.read.csv('iris.csv', header=True, schema=schema)
df.show(5)

+------------+-----------+------------+-----------+-------+
|sepal_length|sepal_width|petal_length|petal_width|variety|
+------------+-----------+------------+-----------+-------+
|         5.1|        3.5|         1.4|        0.2| Setosa|
|         4.9|        3.0|         1.4|        0.2| Setosa|
|         4.7|        3.2|         1.3|        0.2| Setosa|
|         4.6|        3.1|         1.5|        0.2| Setosa|
|         5.0|        3.6|         1.4|        0.2| Setosa|
+------------+-----------+------------+-----------+-------+
only showing top 5 rows



In [18]:
from pyspark.ml.feature import VectorAssembler

In [19]:
input_col = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width']
vectorizer = VectorAssembler(inputCols = input_col, outputCol='features')

df = vectorizer.transform(df)

df.show(5)

+------------+-----------+------------+-----------+-------+-----------------+
|sepal_length|sepal_width|petal_length|petal_width|variety|         features|
+------------+-----------+------------+-----------+-------+-----------------+
|         5.1|        3.5|         1.4|        0.2| Setosa|[5.1,3.5,1.4,0.2]|
|         4.9|        3.0|         1.4|        0.2| Setosa|[4.9,3.0,1.4,0.2]|
|         4.7|        3.2|         1.3|        0.2| Setosa|[4.7,3.2,1.3,0.2]|
|         4.6|        3.1|         1.5|        0.2| Setosa|[4.6,3.1,1.5,0.2]|
|         5.0|        3.6|         1.4|        0.2| Setosa|[5.0,3.6,1.4,0.2]|
+------------+-----------+------------+-----------+-------+-----------------+
only showing top 5 rows



In [20]:
from pyspark.ml.feature import StringIndexer

In [22]:
indexer = StringIndexer(inputCol='variety', outputCol='indexed_type')
df = indexer.fit(df).transform(df)
df.show(5)

+------------+-----------+------------+-----------+-------+-----------------+------------+
|sepal_length|sepal_width|petal_length|petal_width|variety|         features|indexed_type|
+------------+-----------+------------+-----------+-------+-----------------+------------+
|         5.1|        3.5|         1.4|        0.2| Setosa|[5.1,3.5,1.4,0.2]|         0.0|
|         4.9|        3.0|         1.4|        0.2| Setosa|[4.9,3.0,1.4,0.2]|         0.0|
|         4.7|        3.2|         1.3|        0.2| Setosa|[4.7,3.2,1.3,0.2]|         0.0|
|         4.6|        3.1|         1.5|        0.2| Setosa|[4.6,3.1,1.5,0.2]|         0.0|
|         5.0|        3.6|         1.4|        0.2| Setosa|[5.0,3.6,1.4,0.2]|         0.0|
+------------+-----------+------------+-----------+-------+-----------------+------------+
only showing top 5 rows

