# One Hot Encoding

### Refreshing the Concept

Just like `StringIndexer`, One Hot Encoding (OHE) transforms strings into numbers. However, OHE creates a new column for each category, as shown in the following image:

![ohe](../../notebook-images/ohe.png)

[Image source: Fernando Amaral's PySpark Course.](https://www.udemy.com/course/machine-learning-com-spark-e-pyspark/)

One difference between Spark's OHE and the implementations in other libraries is that Spark expects the input column to already be numeric (a category index).

So, if our column is a string, we first have to transform it into an index with `StringIndexer` before applying the One Hot Encoding.

Spark's OHE outputs a vector column, like `VectorAssembler` and `PCA` do, but each row is stored as a sparse vector rather than a dense one.

This sparse notation can be a little confusing at first, so let's take a moment to understand it.

The OHE output in Spark looks like the following:


![ohe](../../notebook-images/spark-ohe.png)

[Image source: Fernando Amaral's PySpark Course.](https://www.udemy.com/course/machine-learning-com-spark-e-pyspark/)

The `onehot_c1` column is the OHE of the Geography column.

So, we have:

- **France**: index 0
- **Germany**: index 1
- **Spain**: index 2

To make this clearer, the following image shows what each component represents:

![ohe](../../notebook-images/single-ohe.png)

The first item is the size of the vector. By default, Spark sets it to `n - 1`, where `n` is the number of unique categories, because the encoder drops the last category (`dropLast=True`).

The second item lists the indices of the non-zero positions. In the example, `[0]` means the row is France, and `[1]` means it is Germany.

The last item lists the values stored at those indices. For OHE, a present category is always `1.0`.

In our example, we have 3 categories (France, Germany, and Spain), so the vector size is 2.

An important thing to note is that Spark discards the last category. So, empty brackets, as in `(2,[],[])`, mean the row belongs to the last category (Spain in this case).
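To make the notation concrete, here is a minimal plain-Python sketch that expands a `(size, indices, values)` triple into an ordinary list. The helper name `sparse_to_dense` is hypothetical and is not part of Spark. It only mimics how the printed sparse vector should be read.

```python
def sparse_to_dense(size, indices, values):
    """Expand a (size, indices, values) triple into a plain list of floats."""
    dense = [0.0] * size
    for position, value in zip(indices, values):
        dense[position] = value
    return dense


# Triples as Spark prints them for a 3-category column (vector size 2).
france = sparse_to_dense(2, [0], [1.0])
germany = sparse_to_dense(2, [1], [1.0])
spain = sparse_to_dense(2, [], [])  # last category: nothing is set

print(france)   # [1.0, 0.0]
print(germany)  # [0.0, 1.0]
print(spain)    # [0.0, 0.0]
```

Read this way, Spain is the row where both positions are zero, which is why Spark does not need a third position for it.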

## Importing

In [1]:
import findspark

# Locate the Spark installation before importing pyspark
findspark.init()

from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("ohe").getOrCreate()

In [2]:
from pyspark.ml.feature import StringIndexer, OneHotEncoder

## Loading Data

In [3]:
churn = spark.read.load(
    "../../data/Churn.csv",
    format="csv",
    sep=";",
    header=True,
    inferSchema=True)

churn.show(2)

+-----------+---------+------+---+------+-------+-------------+---------+--------------+---------------+------+
|CreditScore|Geography|Gender|Age|Tenure|Balance|NumOfProducts|HasCrCard|IsActiveMember|EstimatedSalary|Exited|
+-----------+---------+------+---+------+-------+-------------+---------+--------------+---------------+------+
|        619|   France|Female| 42|     2|      0|            1|        1|             1|       10134888|     1|
|        608|    Spain|Female| 41|     1|8380786|            1|        0|             1|       11254258|     0|
+-----------+---------+------+---+------+-------+-------------+---------+--------------+---------------+------+
only showing top 2 rows



## Turning Strings into Numbers with StringIndexer

In [4]:
# Geography
indexer = StringIndexer(
    inputCol="Geography",
    outputCol="geo_index"
)
churn = indexer.fit(churn).transform(churn)

# Gender
indexer = StringIndexer(
    inputCol="Gender",
    outputCol="gender_index"
)
churn = indexer.fit(churn).transform(churn)

In [5]:
churn.show(2)

+-----------+---------+------+---+------+-------+-------------+---------+--------------+---------------+------+---------+------------+
|CreditScore|Geography|Gender|Age|Tenure|Balance|NumOfProducts|HasCrCard|IsActiveMember|EstimatedSalary|Exited|geo_index|gender_index|
+-----------+---------+------+---+------+-------+-------------+---------+--------------+---------------+------+---------+------------+
|        619|   France|Female| 42|     2|      0|            1|        1|             1|       10134888|     1|      0.0|         1.0|
|        608|    Spain|Female| 41|     1|8380786|            1|        0|             1|       11254258|     0|      2.0|         1.0|
+-----------+---------+------+---+------+-------+-------------+---------+--------------+---------------+------+---------+------------+
only showing top 2 rows
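The index values above are not arbitrary. By default, `StringIndexer` orders labels by descending frequency, so the most common label gets index `0.0`. The following plain-Python sketch imitates that ordering. The function `frequency_index` and the `sample` list are hypothetical, and breaking ties alphabetically is an assumption of this sketch.

```python
from collections import Counter


def frequency_index(labels):
    """Map each label to a float index, most frequent label first."""
    counts = Counter(labels)
    # Sort by descending count; ties are broken alphabetically (an assumption).
    ordered = sorted(counts, key=lambda label: (-counts[label], label))
    return {label: float(position) for position, label in enumerate(ordered)}


# Made-up sample, not taken from Churn.csv.
sample = ["France", "Spain", "France", "Germany", "France",
          "Germany", "Germany", "France", "Spain"]

mapping = frequency_index(sample)
print(mapping)  # {'France': 0.0, 'Germany': 1.0, 'Spain': 2.0}
```

In this made-up sample France is the most frequent label and Spain the least frequent, which reproduces the 0, 1, 2 assignment seen in `geo_index`.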



Once both columns are in numeric format, we can perform the One Hot Encoding.

## Using One Hot Encoder

In [7]:
ohe = OneHotEncoder(
    inputCols=["geo_index","gender_index"],
    outputCols=["geo_ohe", "gender_ohe"]
)

model = ohe.fit(churn)

churn = model.transform(churn)

In [8]:
churn.show(2)

+-----------+---------+------+---+------+-------+-------------+---------+--------------+---------------+------+---------+------------+-------------+----------+
|CreditScore|Geography|Gender|Age|Tenure|Balance|NumOfProducts|HasCrCard|IsActiveMember|EstimatedSalary|Exited|geo_index|gender_index|      geo_ohe|gender_ohe|
+-----------+---------+------+---+------+-------+-------------+---------+--------------+---------------+------+---------+------------+-------------+----------+
|        619|   France|Female| 42|     2|      0|            1|        1|             1|       10134888|     1|      0.0|         1.0|(2,[0],[1.0])| (1,[],[])|
|        608|    Spain|Female| 41|     1|8380786|            1|        0|             1|       11254258|     0|      2.0|         1.0|    (2,[],[])| (1,[],[])|
+-----------+---------+------+---+------+-------+-------------+---------+--------------+---------------+------+---------+------------+-------------+----------+
only showing top 2 rows



In [9]:
churn.select("Geography", "geo_index", "geo_ohe").show(5)

+---------+---------+-------------+
|Geography|geo_index|      geo_ohe|
+---------+---------+-------------+
|   France|      0.0|(2,[0],[1.0])|
|    Spain|      2.0|    (2,[],[])|
|   France|      0.0|(2,[0],[1.0])|
|   France|      0.0|(2,[0],[1.0])|
|    Spain|      2.0|    (2,[],[])|
+---------+---------+-------------+
only showing top 5 rows
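As a cross-check of the output above, this plain-Python sketch rebuilds the printed triple from a category index. It is only an illustration of the drop-last rule, not Spark's implementation. The function `one_hot_sparse` and its `drop_last` argument are hypothetical names chosen to mirror the encoder's `dropLast` parameter.

```python
def one_hot_sparse(index, num_categories, drop_last=True):
    """Return the (size, indices, values) triple for one category index."""
    position = int(index)
    if not 0 <= position < num_categories:
        raise ValueError("index must be between 0 and num_categories - 1")
    # With drop_last, the last category has no position of its own.
    size = num_categories - 1 if drop_last else num_categories
    if position < size:
        return (size, [position], [1.0])
    return (size, [], [])


print(one_hot_sparse(0.0, 3))  # (2, [0], [1.0])  France
print(one_hot_sparse(2.0, 3))  # (2, [], [])      Spain, the dropped category
print(one_hot_sparse(1.0, 2))  # (1, [], [])      Female in gender_ohe
print(one_hot_sparse(2.0, 3, drop_last=False))  # (3, [2], [1.0])
```

The last call shows what the triple would look like if every category kept its own position.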

