# Creating and using UDFs in PySpark

In [1]:
import numpy as np
import pandas as pd 
import pyspark.sql.types as T
import pyspark.sql.functions as F
from pyspark.sql import SparkSession
from pyspark.sql.window import Window

In [11]:
from pyspark.sql import SQLContext

In [2]:
# Start (or reuse) a local Spark session that runs on all available cores.
spark = (
    SparkSession.builder
    .master("local[*]")
    .appName("ShortNSimple")
    .getOrCreate()
)
spark

In [12]:
# Legacy SQL entry point built on the session's SparkContext; it is used further
# down to register the UDF for SQL and to run the query.
# NOTE(review): recent Spark releases expose the same functionality directly on
# `spark` (spark.sql / spark.udf) — confirm against the Spark version in use.
sqlContext = SQLContext(spark.sparkContext)

In [3]:
# Read the iris CSV with pandas, then convert the pandas frame to a Spark DataFrame.
data = spark.createDataFrame(
    pd.read_csv("./datasets/iris.csv")
)
# Preview the first five rows without truncating cell contents.
data.show(n=5, truncate=False)

+------------+-----------+------------+-----------+-----------+
|sepal_length|sepal_width|petal_length|petal_width|class      |
+------------+-----------+------------+-----------+-----------+
|5.1         |3.5        |1.4         |0.2        |Iris-setosa|
|4.9         |3.0        |1.4         |0.2        |Iris-setosa|
|4.7         |3.2        |1.3         |0.2        |Iris-setosa|
|4.6         |3.1        |1.5         |0.2        |Iris-setosa|
|5.0         |3.6        |1.4         |0.2        |Iris-setosa|
+------------+-----------+------------+-----------+-----------+
only showing top 5 rows



In [4]:
# Total number of rows in the Spark DataFrame.
data.count()

150

In [5]:
# List the distinct class labels that need to be encoded.
(
    data
    .select("class")
    .distinct()
    .show(n=10, truncate=False)
)

+---------------+
|class          |
+---------------+
|Iris-virginica |
|Iris-setosa    |
|Iris-versicolor|
+---------------+



In [None]:
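# Target encoding: two indicator columns; Iris-versicolor is the all-zero baseline.
#
# class              Iris-setosa  Iris-virginica
# Iris-setosa        1            0
# Iris-virginica     0            1
# Iris-versicolor    0            0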

In [None]:
# N classes => N-1 indicator columns (the remaining class is encoded as all zeros).

# One-hot encoding using UDF

In [6]:
def one_hot_encode(class_label):
    """Encode an iris class label as a two-element indicator list.

    The encoding uses N-1 = 2 indicators for the 3 classes:

    * ``'iris-setosa'``    -> ``[1, 0]``
    * ``'iris-virginica'`` -> ``[0, 1]``
    * any other label (e.g. ``'iris-versicolor'``) -> ``[0, 0]``

    Matching is case-insensitive.

    Args:
        class_label: The class name as a string, or ``None``.

    Returns:
        A list of two ints, or ``None`` when ``class_label`` is ``None``.
    """
    # A missing label has no encoding. Returning None (instead of crashing on
    # `.lower()`) lets a null input row yield a null output when this function
    # is used as a Spark UDF.
    if class_label is None:
        return None

    temp_code = [0, 0]
    label = class_label.lower()  # normalise once rather than once per branch
    if label == 'iris-setosa':
        temp_code[0] = 1
    elif label == 'iris-virginica':
        temp_code[1] = 1
    return temp_code

# Sanity-check the encoder on one example of each class.
for sample_label in ("Iris-setosa", "Iris-virginica", "Iris-versicolor"):
    print(one_hot_encode(sample_label))

[1, 0]
[0, 1]
[0, 0]


In [7]:
# Inspect the Python type returned by the encoder; the Spark return type
# declared for the UDF has to correspond to it.
type(one_hot_encode("Iris-setosa"))

list

In [8]:
# Inspect the type of a single element of the returned list; the element type
# declared for the UDF's array return type has to correspond to it.
type(one_hot_encode("Iris-setosa")[0])

int

In [9]:
# Wrap the plain Python function as a Spark UDF whose result is an array of ints.
one_hot_encode_udf = F.udf(
    one_hot_encode,
    returnType=T.ArrayType(T.IntegerType()),
)

In [10]:
# Apply the UDF to the "class" column and store the result in "one_hot_values".
data = data.withColumn("one_hot_values", one_hot_encode_udf(data["class"]))

# Show 110 untruncated rows.
# NOTE(review): 110 is presumably chosen so that rows of all three classes appear — confirm.
data.show(n=110, truncate=False)

+------------+-----------+------------+-----------+---------------+--------------+
|sepal_length|sepal_width|petal_length|petal_width|class          |one_hot_values|
+------------+-----------+------------+-----------+---------------+--------------+
|5.1         |3.5        |1.4         |0.2        |Iris-setosa    |[1, 0]        |
|4.9         |3.0        |1.4         |0.2        |Iris-setosa    |[1, 0]        |
|4.7         |3.2        |1.3         |0.2        |Iris-setosa    |[1, 0]        |
|4.6         |3.1        |1.5         |0.2        |Iris-setosa    |[1, 0]        |
|5.0         |3.6        |1.4         |0.2        |Iris-setosa    |[1, 0]        |
|5.4         |3.9        |1.7         |0.4        |Iris-setosa    |[1, 0]        |
|4.6         |3.4        |1.4         |0.3        |Iris-setosa    |[1, 0]        |
|5.0         |3.4        |1.5         |0.2        |Iris-setosa    |[1, 0]        |
|4.4         |2.9        |1.4         |0.2        |Iris-setosa    |[1, 0]        |
|4.9

# UDFs as SQL/Hive functions

In [13]:
# Register the Python function under the SQL-callable name "one_hot_encode_udf",
# with an array-of-int return type. `spark.udf.register` is the supported
# replacement for the deprecated `sqlContext.registerFunction`.
spark.udf.register("one_hot_encode_udf", one_hot_encode, T.ArrayType(T.IntegerType()))

<function __main__.one_hot_encode(class_label)>

In [14]:
# Expose the DataFrame to SQL under the name "data". `createOrReplaceTempView`
# is the supported replacement for the deprecated `registerTempTable`, and it
# makes re-running this cell safe because an existing view is replaced.
data.createOrReplaceTempView("data")

In [15]:
# Call the registered UDF from SQL: keep every existing column and add the
# encoded class as "one_hot_sql".
sql_output = sqlContext.sql("""
    SELECT
        *,
        one_hot_encode_udf(class) as one_hot_sql
    FROM
        data
""")

# Print up to 110 untruncated rows of the query result.
sql_output.show(n=110, truncate=False)

+------------+-----------+------------+-----------+---------------+--------------+-----------+
|sepal_length|sepal_width|petal_length|petal_width|class          |one_hot_values|one_hot_sql|
+------------+-----------+------------+-----------+---------------+--------------+-----------+
|5.1         |3.5        |1.4         |0.2        |Iris-setosa    |[1, 0]        |[1, 0]     |
|4.9         |3.0        |1.4         |0.2        |Iris-setosa    |[1, 0]        |[1, 0]     |
|4.7         |3.2        |1.3         |0.2        |Iris-setosa    |[1, 0]        |[1, 0]     |
|4.6         |3.1        |1.5         |0.2        |Iris-setosa    |[1, 0]        |[1, 0]     |
|5.0         |3.6        |1.4         |0.2        |Iris-setosa    |[1, 0]        |[1, 0]     |
|5.4         |3.9        |1.7         |0.4        |Iris-setosa    |[1, 0]        |[1, 0]     |
|4.6         |3.4        |1.4         |0.3        |Iris-setosa    |[1, 0]        |[1, 0]     |
|5.0         |3.4        |1.5         |0.2        

# Explode!!!

In [22]:
from pyspark.sql.window import Window

In [23]:
# Attach a sequential "id" to every row so that exploded values can be joined
# back to their source row later.
# The window orders by "class" only, so rows sharing a class are tied and their
# relative numbering is whatever order Spark happens to produce.
sql_output = sql_output.withColumn(
    "id", F.row_number().over(Window.orderBy(F.col("class")))
)

sql_output.show(n=5, truncate=False)

+------------+-----------+------------+-----------+-----------+--------------+-----------+---+
|sepal_length|sepal_width|petal_length|petal_width|class      |one_hot_values|one_hot_sql|id |
+------------+-----------+------------+-----------+-----------+--------------+-----------+---+
|5.1         |3.5        |1.4         |0.2        |Iris-setosa|[1, 0]        |[1, 0]     |1  |
|4.9         |3.0        |1.4         |0.2        |Iris-setosa|[1, 0]        |[1, 0]     |2  |
|4.7         |3.2        |1.3         |0.2        |Iris-setosa|[1, 0]        |[1, 0]     |3  |
|4.6         |3.1        |1.5         |0.2        |Iris-setosa|[1, 0]        |[1, 0]     |4  |
|5.0         |3.6        |1.4         |0.2        |Iris-setosa|[1, 0]        |[1, 0]     |5  |
+------------+-----------+------------+-----------+-----------+--------------+-----------+---+
only showing top 5 rows



In [24]:
# Turn each element of the "one_hot_values" array into its own row, keeping the
# row id; posexplode yields the element position ("pos") and value ("col").
exploded_values = sql_output.select(
    "id",
    F.posexplode(F.col("one_hot_values")),
)
exploded_values.show(n=5)

+---+---+---+
| id|pos|col|
+---+---+---+
|  1|  0|  1|
|  1|  1|  0|
|  2|  0|  1|
|  2|  1|  0|
|  3|  0|  1|
+---+---+---+
only showing top 5 rows



In [26]:
# Join the exploded (id, pos, col) rows back onto the full rows by "id", giving
# one output row per array element.
# Caution: this rebinds `sql_output`, so running the cell a second time joins
# the already-joined frame with `exploded_values` again.
sql_output = sql_output.join(exploded_values, on="id", how="inner")

sql_output.show(n=10, truncate=False)

+---+------------+-----------+------------+-----------+-----------+--------------+-----------+---+---+
|id |sepal_length|sepal_width|petal_length|petal_width|class      |one_hot_values|one_hot_sql|pos|col|
+---+------------+-----------+------------+-----------+-----------+--------------+-----------+---+---+
|1  |5.1         |3.5        |1.4         |0.2        |Iris-setosa|[1, 0]        |[1, 0]     |0  |1  |
|1  |5.1         |3.5        |1.4         |0.2        |Iris-setosa|[1, 0]        |[1, 0]     |1  |0  |
|2  |4.9         |3.0        |1.4         |0.2        |Iris-setosa|[1, 0]        |[1, 0]     |0  |1  |
|2  |4.9         |3.0        |1.4         |0.2        |Iris-setosa|[1, 0]        |[1, 0]     |1  |0  |
|3  |4.7         |3.2        |1.3         |0.2        |Iris-setosa|[1, 0]        |[1, 0]     |0  |1  |
|3  |4.7         |3.2        |1.3         |0.2        |Iris-setosa|[1, 0]        |[1, 0]     |1  |0  |
|4  |4.6         |3.1        |1.5         |0.2        |Iris-setosa|[1, 0]