The `pyspark.sql.functions.udf()` function is very important: it converts a Python **user defined function** into a function that, like those in **`pyspark.sql.functions`**, acts on the columns of a DataFrame. This makes data transformation much more flexible.

Using `udf()` can be tricky. The key to success is understanding how to define the `returnType` parameter.

```python
from pyspark import SparkConf, SparkContext
from pyspark.sql import SparkSession
from pyspark.sql.types import *  # FloatType, ArrayType, StructType, StructField, StringType, ...
from pyspark.sql.functions import udf

sc = SparkContext(conf=SparkConf())
spark = SparkSession(sparkContext=sc)

# The unnamed first column of the CSV (the car model) is read in as '_c0'
mtcars = spark.read.csv('data/mtcars.csv', inferSchema=True, header=True)
mtcars = mtcars.withColumnRenamed('_c0', 'model')
mtcars.show(5)
```

```
+-----------------+----+---+-----+---+----+-----+-----+---+---+----+----+
|            model| mpg|cyl| disp| hp|drat|   wt| qsec| vs| am|gear|carb|
+-----------------+----+---+-----+---+----+-----+-----+---+---+----+----+
|        Mazda RX4|21.0|  6|160.0|110| 3.9| 2.62|16.46|  0|  1|   4|   4|
|    Mazda RX4 Wag|21.0|  6|160.0|110| 3.9|2.875|17.02|  0|  1|   4|   4|
|       Datsun 710|22.8|  4|108.0| 93|3.85| 2.32|18.61|  1|  1|   4|   1|
|   Hornet 4 Drive|21.4|  6|258.0|110|3.08|3.215|19.44|  1|  0|   3|   1|
|Hornet Sportabout|18.7|  8|360.0|175|3.15| 3.44|17.02|  0|  0|   3|   2|
+-----------------+----+---+-----+---+----+-----+-----+---+---+----+----+
only showing top 5 rows
```



**The structure of the schema passed to `returnType` must match the structure of the data returned by the user defined function.**

**When calling a UDF, pass DataFrame column names (strings) or `Column` objects; inside the Python function, each argument is the plain Python value of that column for the current row.**

**Case 1**: divide **disp** by **hp** and put the result in a new column.

The user defined function returns a float value.

```python
def disp_by_hp(disp, hp):
    # disp and hp are the plain Python values of one row
    return disp / hp
```

```python
# The function returns a Python float, so the return type is FloatType
disp_by_hp_udf = udf(disp_by_hp, returnType=FloatType())
```

```python
# '*' keeps every existing column; the UDF result is appended as 'disp/hp'
mtcars.select('*', disp_by_hp_udf('disp', 'hp').alias('disp/hp')).show()
```

```
+-------------------+----+---+-----+---+----+-----+-----+---+---+----+----+---------+
|              model| mpg|cyl| disp| hp|drat|   wt| qsec| vs| am|gear|carb|  disp/hp|
+-------------------+----+---+-----+---+----+-----+-----+---+---+----+----+---------+
|          Mazda RX4|21.0|  6|160.0|110| 3.9| 2.62|16.46|  0|  1|   4|   4|1.4545455|
|      Mazda RX4 Wag|21.0|  6|160.0|110| 3.9|2.875|17.02|  0|  1|   4|   4|1.4545455|
|         Datsun 710|22.8|  4|108.0| 93|3.85| 2.32|18.61|  1|  1|   4|   1|1.1612903|
|     Hornet 4 Drive|21.4|  6|258.0|110|3.08|3.215|19.44|  1|  0|   3|   1|2.3454545|
|  Hornet Sportabout|18.7|  8|360.0|175|3.15| 3.44|17.02|  0|  0|   3|   2| 2.057143|
|            Valiant|18.1|  6|225.0|105|2.76| 3.46|20.22|  1|  0|   3|   1| 2.142857|
|         Duster 360|14.3|  8|360.0|245|3.21| 3.57|15.84|  0|  0|   3|   4|1.4693878|
|          Merc 240D|24.4|  4|146.7| 62|3.69| 3.19| 20.0|  1|  0|   4|   2| 2.366129|
|           Merc 230|22.8|  4|140.8| 95|3.92| 3.15| 22
```

**Case 2**: merge values from two columns into an array.

```python
def merge_two_columns(col1, col2):
    # Return the two values of the current row as a Python list
    return [col1, col2]

# mpg and disp are both floating-point columns, so every element is a float
array_type = ArrayType(FloatType())

array_merge_two_columns_udf = udf(merge_two_columns, returnType=array_type)

mtcars.select('*', array_merge_two_columns_udf('mpg', 'disp').alias('merged_col')).show(5, truncate=False)
```

```
+-----------------+----+---+-----+---+----+-----+-----+---+---+----+----+-------------+
|model            |mpg |cyl|disp |hp |drat|wt   |qsec |vs |am |gear|carb|merged_col   |
+-----------------+----+---+-----+---+----+-----+-----+---+---+----+----+-------------+
|Mazda RX4        |21.0|6  |160.0|110|3.9 |2.62 |16.46|0  |1  |4   |4   |[21.0, 160.0]|
|Mazda RX4 Wag    |21.0|6  |160.0|110|3.9 |2.875|17.02|0  |1  |4   |4   |[21.0, 160.0]|
|Datsun 710       |22.8|4  |108.0|93 |3.85|2.32 |18.61|1  |1  |4   |1   |[22.8, 108.0]|
|Hornet 4 Drive   |21.4|6  |258.0|110|3.08|3.215|19.44|1  |0  |3   |1   |[21.4, 258.0]|
|Hornet Sportabout|18.7|8  |360.0|175|3.15|3.44 |17.02|0  |0  |3   |2   |[18.7, 360.0]|
+-----------------+----+---+-----+---+----+-----+-----+---+---+----+----+-------------+
only showing top 5 rows
```



## `ArrayType` vs. `StructType`

Both `ArrayType` and `StructType` can be used as the `returnType` of a UDF that returns a list. The differences are:

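1. `ArrayType` requires all elements of the list to have the same type (its `elementType`), while each field of a `StructType` can have a different type.
2. A `StructType` value is represented in Python as a `Row` object with named fields; when the UDF returns a list, its elements are assigned to the fields by position.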

**Define an `ArrayType` whose `elementType` is `FloatType`.**

```python
array_type = ArrayType(FloatType())
array_merge_two_columns_udf = udf(merge_two_columns, returnType=array_type)
```

**Define a `StructType` with one field of type `StringType` and the other of type `FloatType`.**

```python
struct_type = StructType([
    StructField('f1', StringType()),
    StructField('f2', FloatType())
])
struct_merge_two_columns_udf = udf(merge_two_columns, returnType=struct_type)
```

**Transform data**

```python
# Append the merged column to all existing columns ('*')
mtcars_array_type = mtcars.select('*', array_merge_two_columns_udf('mpg', 'disp').alias('merged_col'))
mtcars_struct_type = mtcars.select('*', struct_merge_two_columns_udf('model', 'disp').alias('merged_col'))
```

In `mtcars_struct_type`, **merged_col** is a `Row` object; in `mtcars_array_type`, it is a plain Python list.

```python
mtcars_array_type.rdd.take(2)
```

```
[Row(model='Mazda RX4', mpg=21.0, cyl=6, disp=160.0, hp=110, drat=3.9, wt=2.62, qsec=16.46, vs=0, am=1, gear=4, carb=4, merged_col=[21.0, 160.0]),
 Row(model='Mazda RX4 Wag', mpg=21.0, cyl=6, disp=160.0, hp=110, drat=3.9, wt=2.875, qsec=17.02, vs=0, am=1, gear=4, carb=4, merged_col=[21.0, 160.0])]
```

```python
mtcars_struct_type.rdd.take(2)
```

```
[Row(model='Mazda RX4', mpg=21.0, cyl=6, disp=160.0, hp=110, drat=3.9, wt=2.62, qsec=16.46, vs=0, am=1, gear=4, carb=4, merged_col=Row(f1='Mazda RX4', f2=160.0)),
 Row(model='Mazda RX4 Wag', mpg=21.0, cyl=6, disp=160.0, hp=110, drat=3.9, wt=2.875, qsec=17.02, vs=0, am=1, gear=4, carb=4, merged_col=Row(f1='Mazda RX4 Wag', f2=160.0))]
```