In [None]:
from __future__ import print_function

# findspark.init() must run before `import pyspark` so that the local Spark
# installation can be imported from this kernel.
import findspark
findspark.init()

import pyspark
from pyspark.sql import SparkSession
from pyspark.ml.feature import OneHotEncoder, StringIndexer

In [None]:
# Create (or reuse) the SparkSession that every later cell relies on.
# This runs unconditionally: an `if __name__ == "__main__":` guard is
# meaningless inside a notebook and only makes `spark` conditionally defined.
spark = (
    SparkSession.builder
    .appName("OneHotEncoder")
    .getOrCreate()
)

In [None]:
# Toy dataset: six rows whose `category1` label takes two values ("Good" / "Bad").
data = spark.createDataFrame(
    data=[(0, "Good"), (1, "Bad"), (2, "Good"),
          (3, "Good"), (4, "Bad"), (5, "Good")],
    schema=["id", "category1"],
)

In [None]:
# Display the first rows (show's defaults: up to 20 rows, long values truncated).
data.show(n=20, truncate=True)

+---+---------+
| id|category1|
+---+---------+
|  0|     Good|
|  1|      Bad|
|  2|     Good|
|  3|     Good|
|  4|      Bad|
|  5|     Good|
+---+---------+



In [None]:
# Print the inferred schema as a tree: `id` is inferred as long, `category1` as string.
data.printSchema()

root
 |-- id: long (nullable = true)
 |-- category1: string (nullable = true)



In [None]:
# OneHotEncoder expects a numeric column of category indices, not raw strings,
# so passing `category1` to it directly would raise an error.

# We therefore apply StringIndexer first (string -> index), then OneHotEncoder.

## Label Encoder or String Indexer

In [None]:
# Map each string label in `category1` to a numeric index in `category_index`.
# With the default ordering, the most frequent label ("Good" here) gets 0.0.
indexer = StringIndexer(outputCol="category_index", inputCol="category1")

In [None]:
# fit() learns the label -> index mapping from `data`; transform() appends `category_index`.
indexed = (
    indexer
    .fit(data)
    .transform(data)
)
indexed.show(n=20, truncate=True)

+---+---------+--------------+
| id|category1|category_index|
+---+---------+--------------+
|  0|     Good|           0.0|
|  1|      Bad|           1.0|
|  2|     Good|           0.0|
|  3|     Good|           0.0|
|  4|      Bad|           1.0|
|  5|     Good|           0.0|
+---+---------+--------------+



## One Hot Encoder

In [None]:
# One-hot encode the numeric index. dropLast=True (the default, spelled out here)
# omits the slot for the last category, so that category is encoded as an
# all-zeros vector -- which is why "Bad" (index 1.0) appears as (1,[],[]).
encoder = OneHotEncoder(dropLast=True, outputCol="category_encode", inputCol="category_index")

In [None]:
# Since Spark 3.0, OneHotEncoder is an Estimator: it must be fit (to learn the
# number of categories) before it can transform, and calling `.transform` on it
# directly raises AttributeError. In Spark 2.x it is a plain Transformer with no
# `fit`, so keep that path for older installs.
if hasattr(encoder, "fit"):
    encoded = encoder.fit(indexed).transform(indexed)
else:
    encoded = encoder.transform(indexed)

In [None]:
# Sparse vectors print as (size, [active indices], [values]).
encoded.show(n=20, truncate=True)

+---+---------+--------------+---------------+
| id|category1|category_index|category_encode|
+---+---------+--------------+---------------+
|  0|     Good|           0.0|  (1,[0],[1.0])|
|  1|      Bad|           1.0|      (1,[],[])|
|  2|     Good|           0.0|  (1,[0],[1.0])|
|  3|     Good|           0.0|  (1,[0],[1.0])|
|  4|      Bad|           1.0|      (1,[],[])|
|  5|     Good|           0.0|  (1,[0],[1.0])|
+---+---------+--------------+---------------+



In [None]:
# Next, an example with three categories ("Good", "Bad", "Average").

In [None]:
# Three-valued label column: "Average" occurs 4x, "Good" 3x, "Bad" 2x.
data1 = spark.createDataFrame(
    data=[(0, "Good"), (1, "Bad"), (2, "Good"),
          (3, "Average"), (4, "Bad"), (5, "Good"),
          (6, "Average"), (7, "Average"), (8, "Average")],
    schema=["id", "category1"],
)

In [None]:
# Display the first rows (show's defaults: up to 20 rows, long values truncated).
data1.show(n=20, truncate=True)

+---+---------+
| id|category1|
+---+---------+
|  0|     Good|
|  1|      Bad|
|  2|     Good|
|  3|  Average|
|  4|      Bad|
|  5|     Good|
|  6|  Average|
|  7|  Average|
|  8|  Average|
+---+---------+



## Label Encoder or String Indexer

In [None]:
# Same column mapping as the first example, redefined so this section runs on its own.
# Labels are indexed by descending frequency: "Average" -> 0.0, "Good" -> 1.0, "Bad" -> 2.0.
indexer = StringIndexer(outputCol="category_index", inputCol="category1")

In [None]:
# Fit the label -> index mapping on `data1` and append `category_index`.
indexed = (
    indexer
    .fit(data1)
    .transform(data1)
)
indexed.show(n=20, truncate=True)

+---+---------+--------------+
| id|category1|category_index|
+---+---------+--------------+
|  0|     Good|           1.0|
|  1|      Bad|           2.0|
|  2|     Good|           1.0|
|  3|  Average|           0.0|
|  4|      Bad|           2.0|
|  5|     Good|           1.0|
|  6|  Average|           0.0|
|  7|  Average|           0.0|
|  8|  Average|           0.0|
+---+---------+--------------+



## One Hot Encoder

In [None]:
# With three indexed categories and dropLast=True (the default, spelled out here),
# vectors have size 2 and the last category, "Bad" (index 2.0), is all zeros.
encoder = OneHotEncoder(dropLast=True, outputCol="category_encode", inputCol="category_index")

In [None]:
# Since Spark 3.0, OneHotEncoder is an Estimator: it must be fit (to learn the
# number of categories) before it can transform, and calling `.transform` on it
# directly raises AttributeError. In Spark 2.x it is a plain Transformer with no
# `fit`, so keep that path for older installs.
if hasattr(encoder, "fit"):
    encoded = encoder.fit(indexed).transform(indexed)
else:
    encoded = encoder.transform(indexed)

In [None]:
# Sparse vectors print as (size, [active indices], [values]).
encoded.show(n=20, truncate=True)

+---+---------+--------------+---------------+
| id|category1|category_index|category_encode|
+---+---------+--------------+---------------+
|  0|     Good|           1.0|  (2,[1],[1.0])|
|  1|      Bad|           2.0|      (2,[],[])|
|  2|     Good|           1.0|  (2,[1],[1.0])|
|  3|  Average|           0.0|  (2,[0],[1.0])|
|  4|      Bad|           2.0|      (2,[],[])|
|  5|     Good|           1.0|  (2,[1],[1.0])|
|  6|  Average|           0.0|  (2,[0],[1.0])|
|  7|  Average|           0.0|  (2,[0],[1.0])|
|  8|  Average|           0.0|  (2,[0],[1.0])|
+---+---------+--------------+---------------+



In [None]:
# Stop the SparkSession now that the walkthrough is finished.
spark.stop()