### 创建session

In [1]:
from pyspark.sql import SparkSession
from pyspark.conf import SparkConf
# The same JVM option is handed to both the driver and the executors.
# NOTE(review): presumably needed so Netty (used by Arrow) may use reflection on newer JDKs -- confirm.
netty_reflection_flag = "-Dio.netty.tryReflectionSetAccessible=true"
conf = (
    SparkConf()
    .set("spark.driver.extraJavaOptions", netty_reflection_flag)
    .set("spark.executor.extraJavaOptions", netty_reflection_flag)
)
# Reuse an existing session if one is available, otherwise start one with this configuration.
spark = SparkSession.builder.config(conf=conf).getOrCreate()

### 导入数据

In [4]:
# Read the CSV, treating its first line as the column header, then preview the rows.
df = spark.read.csv("color.csv", header=True)
df.show()

+------+
| Color|
+------+
|   Red|
|   Red|
|Yellow|
| Green|
|Yellow|
+------+



### 创建唯一值

In [5]:
#   ##  import the required libraries
from pyspark.sql.functions import udf, col
from pyspark.sql.types import IntegerType

In [6]:
# Collect the unique entries of the "Color" column on the driver, going through pandas.
distinct_colors_pdf = df.select("Color").distinct().toPandas()
distinct_values = list(distinct_colors_pdf["Color"])

In [7]:
# Inspect the collected values (their order is not the row order of the input data).
distinct_values

['Green', 'Yellow', 'Red']

In [8]:
# Alternative that avoids pandas: every Row holds a single field, so flattening
# the rows of the underlying RDD yields the plain column values.
distinct_values = (
    df.select("Color")
    .distinct()
    .rdd
    .flatMap(lambda row: row)
    .collect()
)

### 做数据转换，建立onehot encoder

In [9]:
# Add one 0/1 indicator column per distinct colour, named "Color_<value>".
# A native column expression is used instead of a Python UDF: it is evaluated inside
# the JVM (no row-by-row round trip through Python) and does not depend on a lambda
# closing over the loop variable.
for distinct_value in distinct_values:
    new_column_name = "Color" + "_" + distinct_value
    # eqNullSafe treats a null Color as "not equal" (False rather than null),
    # so rows with a missing Color get 0 instead of null after the cast.
    is_match = col("Color").eqNullSafe(distinct_value)
    df = df.withColumn(new_column_name, is_match.cast("int"))
df.show()

+------+-----------+------------+---------+
| Color|Color_Green|Color_Yellow|Color_Red|
+------+-----------+------------+---------+
|   Red|          0|           0|        1|
|   Red|          0|           0|        1|
|Yellow|          0|           1|        0|
| Green|          1|           0|        0|
|Yellow|          0|           1|        0|
+------+-----------+------------+---------+

