# Imports

In [1]:
import pandas as pd
from typing import Iterator, Tuple
from pyspark.sql import SparkSession, Window
import pyspark.sql.functions as F
from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import *

In [2]:
spark = (SparkSession
         .builder
         .appName('Chapter 5')
         .enableHiveSupport()
         .getOrCreate())
spark

# User Defined Functions (UDFs)

UDFs extend Spark's capabilities by letting us define our own functions in Python. They are scoped to the SparkSession in which they are registered: they do not persist across sessions, and Spark does not store them in the underlying metastore.

In [3]:
x = spark.range(10)

In [4]:
x.createOrReplaceTempView('udf_test')

In [5]:
def cubed(s):
    return s * s * s

In [6]:
# Register the function so it can be called from Spark SQL in this session
spark.udf.register(name='cube', f=cubed, returnType=LongType())

<function __main__.cubed(s)>

In [7]:
spark.sql('SELECT id, cube(id) FROM udf_test').show(10)

+---+--------+
| id|cube(id)|
+---+--------+
|  0|       0|
|  1|       1|
|  2|       8|
|  3|      27|
|  4|      64|
|  5|     125|
|  6|     216|
|  7|     343|
|  8|     512|
|  9|     729|
+---+--------+



Spark SQL (SQL, DataFrame, and Dataset) does not guarantee the order of execution of subexpressions. For example,
```python
spark.sql("SELECT s FROM test1 WHERE s IS NOT NULL AND strlen(s) > 1")
```
The above query does not guarantee that the NULL check is evaluated before `strlen(s)`. Therefore, it is recommended that we do one of the following:
- Make the __UDF__ itself null-aware and do null checking inside the UDF.
- Use `IF` or `CASE WHEN` expressions to do the null check and invoke the UDF in a conditional branch.
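
As a minimal sketch of the first recommendation, a null-aware string-length function could look like the following. The name `strlen_null_safe` is hypothetical, and the registration call is shown only as a comment because it needs an active `spark` session.

```python
from typing import Optional


def strlen_null_safe(s: Optional[str]) -> Optional[int]:
    """Return the length of s, or None when s is NULL (None in Python)."""
    if s is None:
        return None
    return len(s)


# With an active session this could be registered like `cubed` above, e.g.
# spark.udf.register('strlen', strlen_null_safe, returnType=IntegerType())
print(strlen_null_safe('spark'), strlen_null_safe(None))
```

Because the function handles `None` itself, the result no longer depends on whether Spark evaluates the `IS NOT NULL` filter first.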

Pandas UDFs (`pandas_udf`) let us use vectorized operations instead of operating row by row, which is very slow. Spark uses Apache Arrow to transfer the data to the Python worker, and pandas then operates on it in batches. Because pandas can work with data in Arrow format directly, the per-row serialization/pickling cost of a regular Python UDF is avoided. To turn a Python UDF into a pandas UDF, decorate the function with `@pandas_udf` and add __type hints__.

## Series to Series

Takes a `pd.Series` as input and returns a `pd.Series` of the same length as output.

In [8]:
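# Keep the plain function so it can also be called on pandas objects locally
def squares(s: pd.Series) -> pd.Series:
    return s * 2  # note: this doubles the input, despite the name

# Wrap it as a pandas UDF for use on Spark columns
squares_udf = pandas_udf(squares, returnType='long')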

In [9]:
x = pd.Series(range(10))
x

0    0
1    1
2    2
3    3
4    4
5    5
6    6
7    7
8    8
9    9
dtype: int64

In [10]:
squares(x)

0     0
1     2
2     4
3     6
4     8
5    10
6    12
7    14
8    16
9    18
dtype: int64

In [11]:
df = spark.range(10)

In [12]:
df.select('id', squares(F.col('id'))).show()

+---+--------+
| id|(id * 2)|
+---+--------+
|  0|       0|
|  1|       2|
|  2|       4|
|  3|       6|
|  4|       8|
|  5|      10|
|  6|      12|
|  7|      14|
|  8|      16|
|  9|      18|
+---+--------+



The plain (undecorated) function can also be called on pandas objects locally.

In [13]:
squares(pd.Series(range(10)))

0     0
1     2
2     4
3     6
4     8
5    10
6    12
7    14
8    16
9    18
dtype: int64

## Iterator of Series to Iterator of Series

- The Python function
    - Takes an iterator of batches instead of a single input batch as input.
    - Returns an iterator of output batches instead of a single output batch.
- The length of the entire output in the iterator should be the same as the length of the entire input.
- The wrapped pandas UDF takes a single Spark column as an input.
- Specify the Python type hint as `Iterator[pandas.Series] -> Iterator[pandas.Series]`.

It is useful when some expensive state has to be initialized once, such as loading a machine learning model file, and then applied to every input batch for inference.

In [15]:
pdf = pd.DataFrame([1, 2, 3], columns=["x"])
df = spark.createDataFrame(pdf)

# When the UDF is called with the column,
# the input to the underlying function is an iterator of pd.Series.
@pandas_udf("long")
def plus_one(batch_iter: Iterator[pd.Series]) -> Iterator[pd.Series]:
    for x in batch_iter:
        print(type(x))
        yield x + 1

df.select(plus_one(F.col("x"))).show()
# df.select(F.col("x")).show()

In [28]:
# In the UDF, you can initialize some state before processing batches.
# Wrap your code with try/finally or use context managers to ensure
# the release of resources at the end.
y_bc = spark.sparkContext.broadcast(1)

@pandas_udf("long")
def plus_y(batch_iter: Iterator[pd.Series]) -> Iterator[pd.Series]:
    y = y_bc.value  # initialize states
    try:
        for x in batch_iter:
            yield x + y
    finally:
        pass  # release resources here, if any

df.select(plus_y(F.col("x"))).show()

## Iterator of multiple Series to Iterator of multiple Series

The differences from Iterator of Series to Iterator of Series are:
- The underlying Python function takes an iterator of a tuple of pandas Series.
- The wrapped pandas UDF takes multiple Spark columns as an input.

In [32]:
pdf = pd.DataFrame([1, 2, 3], columns=["x"])
df = spark.createDataFrame(pdf)

@pandas_udf("long")
def multiply_two_cols(
        iterator: Iterator[Tuple[pd.Series, pd.Series]]) -> Iterator[pd.Series]:
    for a, b in iterator:
        yield a * b

df.select(multiply_two_cols("x", "x")).show()

## Series to scalar

It is typically used to aggregate one or more pandas Series into a scalar, such as a __mean__. The scalar can be a Python primitive type or a NumPy data type. All the data of a group is loaded into memory.

In [33]:
df = spark.createDataFrame(
    [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)],
    ("id", "v"))
df.show()

+---+----+
| id|   v|
+---+----+
|  1| 1.0|
|  1| 2.0|
|  2| 3.0|
|  2| 5.0|
|  2|10.0|
+---+----+



In [34]:
@pandas_udf('double')
def mean_udf(x: pd.Series) -> float:
    return x.mean()

In [36]:
# df.select(mean_udf('v')).show()
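
In a grouped aggregation, the UDF receives one pandas Series per group and returns one scalar for it. The plain-Python model below is only a sketch of that behavior and does not use Spark; it shows the values to expect for the sample rows above when the UDF is applied per `id`.

```python
from collections import defaultdict
from statistics import mean

# Same rows as the sample DataFrame: (id, v)
rows = [(1, 1.0), (1, 2.0), (2, 3.0), (2, 5.0), (2, 10.0)]

# Collect the values of each group, as happens before the UDF is called
groups = defaultdict(list)
for id_, v in rows:
    groups[id_].append(v)

# One scalar per group, mirroring df.groupby('id').agg(mean_udf('v'))
group_means = {id_: mean(vs) for id_, vs in groups.items()}
print(group_means)
```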

# Higher Order Functions in DataFrames and Spark SQL

There are typically two solutions for the manipulation of complex data types:
- Exploding the nested structure into individual rows, applying some function, and then re-creating the nested structure as noted in the code snippet below (see Option 1)
- Building a User Defined Function (UDF)

In [16]:
# Create an array dataset
arrayData = [[1, (1, 2, 3)], [2, (2, 3, 4)], [3, (3, 4, 5)]]

# Create schema
from pyspark.sql.types import *
arraySchema = (StructType([
      StructField("id", IntegerType(), True), 
      StructField("values", ArrayType(IntegerType()), True)
      ]))

# Create DataFrame
df = spark.createDataFrame(spark.sparkContext.parallelize(arrayData), arraySchema)
df.createOrReplaceTempView("table")
df.printSchema()
df.show()

root
 |-- id: integer (nullable = true)
 |-- values: array (nullable = true)
 |    |-- element: integer (containsNull = true)

+---+---------+
| id|   values|
+---+---------+
|  1|[1, 2, 3]|
|  2|[2, 3, 4]|
|  3|[3, 4, 5]|
+---+---------+



`explode(values)` creates a new row (with the `id`) for each element (`value`) within `values`.

In [17]:
spark.sql('''
SELECT id, explode(values) AS value
FROM table
''').show()

+---+-----+
| id|value|
+---+-----+
|  1|    1|
|  1|    2|
|  1|    3|
|  2|    2|
|  2|    3|
|  2|    4|
|  3|    3|
|  3|    4|
|  3|    5|
+---+-----+



## Option 1: Explode & Collect

In [18]:
spark.sql('''
SELECT id, collect_list(value + 1) AS new_values
FROM (SELECT id, explode(values) AS value
    FROM table) AS A
GROUP BY id
''').show()

+---+----------+
| id|new_values|
+---+----------+
|  1| [2, 3, 4]|
|  3| [4, 5, 6]|
|  2| [3, 4, 5]|
+---+----------+



The order of the values may differ from their order in the original array because of shuffling. This option is very expensive: `GROUP BY` requires a shuffle, and the arrays can be very wide/long. `collect_list` may also cause out-of-memory issues on larger datasets.

## Option 2: User Defined Function

In [19]:
def plus_one(values):
    return [v + 1 for v in values]

In [20]:
spark.udf.register('plus_one', plus_one, returnType=ArrayType(IntegerType()))

<function __main__.plus_one(values)>

In [21]:
spark.sql(
'''
SELECT id, plus_one(values) AS new_values
FROM table
''').show()

+---+----------+
| id|new_values|
+---+----------+
|  1| [2, 3, 4]|
|  2| [3, 4, 5]|
|  3| [4, 5, 6]|
+---+----------+



The major drawback of this approach is that it requires serializing and deserializing the data to run the Python function, which may be expensive.

# Higher Order Functions

In [22]:
schema = StructType([StructField("celsius", ArrayType(IntegerType()))])
t_list = [[35, 36, 32, 30, 40, 42, 38]], [[31, 32, 34, 55, 56]]
t_c = spark.createDataFrame(t_list, schema)
t_c.createOrReplaceTempView("tC")
t_c.show()

+--------------------+
|             celsius|
+--------------------+
|[35, 36, 32, 30, ...|
|[31, 32, 34, 55, 56]|
+--------------------+



## `transform`

It takes an array and a lambda function, applies the function to each element of the array, and returns a new array. It gives the same result as a UDF but is more efficient.

In [23]:
spark.sql(
'''
SELECT celsius, transform(celsius, v -> ((v * 9) div 5) + 32) AS fahrenheit
FROM tC
''').show()

+--------------------+--------------------+
|             celsius|          fahrenheit|
+--------------------+--------------------+
|[35, 36, 32, 30, ...|[95, 96, 89, 86, ...|
|[31, 32, 34, 55, 56]|[87, 89, 93, 131,...|
+--------------------+--------------------+
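
The truncated output above can be checked with a plain-Python model of the lambda. This is a sketch of the arithmetic only, not Spark code; `//` stands in for SQL `div`, which gives the same result for the non-negative values used here.

```python
# Same arrays as the tC view
celsius_rows = [[35, 36, 32, 30, 40, 42, 38], [31, 32, 34, 55, 56]]

# v -> ((v * 9) div 5) + 32, applied to every element of every array
fahrenheit_rows = [[(v * 9) // 5 + 32 for v in row] for row in celsius_rows]
print(fahrenheit_rows)
```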



## `filter`

Returns an array containing only the elements for which the predicate is true.

In [24]:
spark.sql(
'''
SELECT celsius, filter(celsius, v -> v > 38) AS fahrenheit
FROM tC
''').show()

+--------------------+----------+
|             celsius|fahrenheit|
+--------------------+----------+
|[35, 36, 32, 30, ...|  [40, 42]|
|[31, 32, 34, 55, 56]|  [55, 56]|
+--------------------+----------+



## `exists`

Returns __True__ if the predicate is true for any element in the array.

In [25]:
spark.sql(
'''
SELECT celsius, exists(celsius, v -> v = 38) AS threshold
FROM tC
''').show()

+--------------------+---------+
|             celsius|threshold|
+--------------------+---------+
|[35, 36, 32, 30, ...|     true|
|[31, 32, 34, 55, 56]|    false|
+--------------------+---------+



## `reduce`

In [27]:
spark.sql(
'''
SELECT celsius, reduce(
                    celsius, 
                    0, 
                    (x, y) -> x + y,
                    y -> (y div size(celsius) * 9 div 5) + 32) AS avg_fahrenheit
FROM tC
''').show()
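
`reduce` takes the array, an initial value, a merge function that folds each element into the accumulator, and a finish function applied to the final accumulator. The plain-Python sketch below models the query with `functools.reduce`; it illustrates the semantics and is not Spark code. Depending on the Spark version, this function may only be available under the name `aggregate`.

```python
from functools import reduce

# Same arrays as the tC view
celsius_rows = [[35, 36, 32, 30, 40, 42, 38], [31, 32, 34, 55, 56]]


def avg_fahrenheit(celsius):
    # Merge step: (x, y) -> x + y, starting from the initial value 0
    total = reduce(lambda x, y: x + y, celsius, 0)
    # Finish step: y -> (y div size(celsius) * 9 div 5) + 32, evaluated left to right
    return (total // len(celsius) * 9 // 5) + 32


print([avg_fahrenheit(row) for row in celsius_rows])
```

Note that the integer division happens before the conversion, so the average is truncated to a whole degree Celsius first.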

# DataFrames and Spark SQL Common Relational Operators 

## Data 

In [28]:
airports_na = (spark
               .read
               .format('csv')
               .option('header', 'true')
               .option('inferSchema', 'true')
               .option('sep', '\t')
               .load('data/airport-codes-na.txt'))
airports_na.show(5)

+----------+-----+-------+----+
|      City|State|Country|IATA|
+----------+-----+-------+----+
|Abbotsford|   BC| Canada| YXX|
|  Aberdeen|   SD|    USA| ABR|
|   Abilene|   TX|    USA| ABI|
|     Akron|   OH|    USA| CAK|
|   Alamosa|   CO|    USA| ALS|
+----------+-----+-------+----+
only showing top 5 rows



In [29]:
airports_na.createOrReplaceTempView('airports_na')

In [30]:
schema = '''
`date` STRING, `delay` INT, `distance` INT, `origin` STRING, `destination` STRING
'''

In [31]:
departure_delays = (spark
               .read
               .format('csv')
               .option('header', 'true')
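               # NOTE: 'schema' is not a CSV reader option, so it is ignored here and
               # every column is read as a string; `.schema(schema)` would apply it.
               .option('schema', schema)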
               .load('data/departuredelays.csv'))
departure_delays.show(5)

+--------+-----+--------+------+-----------+
|    date|delay|distance|origin|destination|
+--------+-----+--------+------+-----------+
|01011245|    6|     602|   ABE|        ATL|
|01020600|   -8|     369|   ABE|        DTW|
|01021245|   -2|     602|   ABE|        ATL|
|01020605|   -4|     602|   ABE|        ATL|
|01031245|   -4|     602|   ABE|        ATL|
+--------+-----+--------+------+-----------+
only showing top 5 rows



In [32]:
departure_delays.printSchema()

root
 |-- date: string (nullable = true)
 |-- delay: string (nullable = true)
 |-- distance: string (nullable = true)
 |-- origin: string (nullable = true)
 |-- destination: string (nullable = true)



In [33]:
departure_delays = (departure_delays
                    .withColumn('delay', F.col('delay').cast('int'))
                    .withColumn('distance', F.col('distance').cast('int')))

In [34]:
departure_delays.printSchema()

root
 |-- date: string (nullable = true)
 |-- delay: integer (nullable = true)
 |-- distance: integer (nullable = true)
 |-- origin: string (nullable = true)
 |-- destination: string (nullable = true)



In [35]:
departure_delays.createOrReplaceTempView('departure_delays')

In [36]:
foo = (departure_delays
       .filter(F.expr(
           """origin == 'SEA' and destination == 'SFO' and date like '01010%' and delay > 0""")
              ))
foo.show()

+--------+-----+--------+------+-----------+
|    date|delay|distance|origin|destination|
+--------+-----+--------+------+-----------+
|01010710|   31|     590|   SEA|        SFO|
|01010955|  104|     590|   SEA|        SFO|
|01010730|    5|     590|   SEA|        SFO|
+--------+-----+--------+------+-----------+



## Union

`union` behaves like `UNION ALL` in SQL: duplicate rows are kept. To get the SQL-like `UNION`, call `distinct` after `union`.

In [37]:
bar = departure_delays.union(foo)

In [38]:
(bar
 .filter(F.expr(
           """origin == 'SEA' and destination == 'SFO' and date like '01010%' and delay > 0""")
              )
 .show())

+--------+-----+--------+------+-----------+
|    date|delay|distance|origin|destination|
+--------+-----+--------+------+-----------+
|01010710|   31|     590|   SEA|        SFO|
|01010955|  104|     590|   SEA|        SFO|
|01010730|    5|     590|   SEA|        SFO|
|01010710|   31|     590|   SEA|        SFO|
|01010955|  104|     590|   SEA|        SFO|
|01010730|    5|     590|   SEA|        SFO|
+--------+-----+--------+------+-----------+



We can also use `unionByName`, which matches columns by name instead of by position.
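
The difference between the three behaviors can be modeled in plain Python. This is a sketch of the semantics only, using a few columns of the `foo` rows shown earlier; it does not call Spark.

```python
# Rows as tuples in column order (date, delay, origin)
columns = ('date', 'delay', 'origin')
foo_rows = [('01010710', 31, 'SEA'), ('01010955', 104, 'SEA')]

# union keeps duplicates, like UNION ALL
union_all = foo_rows + foo_rows
# union followed by distinct drops them, like UNION
union_distinct = list(dict.fromkeys(union_all))

# unionByName aligns on column names, so a different column order is harmless
other_columns = ('origin', 'date', 'delay')
other_rows = [('SEA', '01010730', 5)]
aligned = [tuple(dict(zip(other_columns, r))[c] for c in columns) for r in other_rows]
by_name = foo_rows + aligned
print(len(union_all), len(union_distinct), by_name[-1])
```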

## Joins

The default join is __inner__.

In [39]:
foo.join(airports_na, foo.origin == airports_na.IATA).show()

+--------+-----+--------+------+-----------+-------+-----+-------+----+
|    date|delay|distance|origin|destination|   City|State|Country|IATA|
+--------+-----+--------+------+-----------+-------+-----+-------+----+
|01010710|   31|     590|   SEA|        SFO|Seattle|   WA|    USA| SEA|
|01010955|  104|     590|   SEA|        SFO|Seattle|   WA|    USA| SEA|
|01010730|    5|     590|   SEA|        SFO|Seattle|   WA|    USA| SEA|
+--------+-----+--------+------+-----------+-------+-----+-------+----+



## Windowing

In [40]:
spark.sql("DROP TABLE IF EXISTS departure_delays_window")
spark.sql("""
CREATE TABLE departure_delays_window AS
SELECT origin, destination, sum(delay) as total_delays 
  FROM departure_delays 
 WHERE origin IN ('SEA', 'SFO', 'JFK') 
   AND destination IN ('SEA', 'SFO', 'JFK', 'DEN', 'ORD', 'LAX', 'ATL') 
 GROUP BY origin, destination
""")
spark.sql("""SELECT * FROM departure_delays_window""").show()

+------+-----------+------------+
|origin|destination|total_delays|
+------+-----------+------------+
|   SFO|        LAX|       40798|
|   SEA|        ORD|       10041|
|   SEA|        JFK|        4667|
|   SEA|        SFO|       22293|
|   JFK|        DEN|        4315|
|   JFK|        LAX|       35755|
|   SFO|        JFK|       24100|
|   JFK|        SEA|        7856|
|   SEA|        LAX|        9359|
|   SFO|        ATL|        5091|
|   JFK|        SFO|       35619|
|   SEA|        ATL|        4535|
|   JFK|        ORD|        5608|
|   SEA|        DEN|       13645|
|   JFK|        ATL|       12141|
|   SFO|        ORD|       27412|
|   SFO|        DEN|       18688|
|   SFO|        SEA|       17080|
+------+-----------+------------+



In [41]:
spark.sql(
"""
SELECT origin, destination, total_delays, rank
FROM (
    SELECT origin, destination, total_delays, dense_rank() OVER (PARTITION BY origin ORDER BY total_delays DESC) AS rank
    FROM departure_delays_window
) AS A
WHERE rank <= 3
""").show()

+------+-----------+------------+----+
|origin|destination|total_delays|rank|
+------+-----------+------------+----+
|   SEA|        SFO|       22293|   1|
|   SEA|        DEN|       13645|   2|
|   SEA|        ORD|       10041|   3|
|   SFO|        LAX|       40798|   1|
|   SFO|        ORD|       27412|   2|
|   SFO|        JFK|       24100|   3|
|   JFK|        LAX|       35755|   1|
|   JFK|        SFO|       35619|   2|
|   JFK|        ATL|       12141|   3|
+------+-----------+------------+----+
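
The ranking for a single partition can be reproduced in plain Python. The sketch below models `dense_rank` for origin `SEA` with the totals from `departure_delays_window`; it is an illustration of the semantics, not Spark code. With `dense_rank`, ties share a rank and no rank is skipped afterwards, which is what distinguishes it from `rank`.

```python
# total_delays for origin 'SEA', taken from departure_delays_window above
sea = {'ORD': 10041, 'JFK': 4667, 'SFO': 22293,
       'LAX': 9359, 'ATL': 4535, 'DEN': 13645}

# dense_rank: position of each value among the distinct values, highest first
distinct_desc = sorted(set(sea.values()), reverse=True)
dense_rank = {dest: distinct_desc.index(total) + 1 for dest, total in sea.items()}

# WHERE rank <= 3, ordered by rank
top3 = sorted((d for d, r in dense_rank.items() if r <= 3), key=dense_rank.get)
print(top3)
```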



## Adding new columns 

In [42]:
foo2 = (foo
        .withColumn('status', F.expr("CASE WHEN delay <= 10 THEN 'on_time' ELSE 'delayed' END")))
foo2.show()

+--------+-----+--------+------+-----------+-------+
|    date|delay|distance|origin|destination| status|
+--------+-----+--------+------+-----------+-------+
|01010710|   31|     590|   SEA|        SFO|delayed|
|01010955|  104|     590|   SEA|        SFO|delayed|
|01010730|    5|     590|   SEA|        SFO|on_time|
+--------+-----+--------+------+-----------+-------+



## Dropping columns

In [43]:
foo3 = foo2.drop('delay')
foo3.show()

+--------+--------+------+-----------+-------+
|    date|distance|origin|destination| status|
+--------+--------+------+-----------+-------+
|01010710|     590|   SEA|        SFO|delayed|
|01010955|     590|   SEA|        SFO|delayed|
|01010730|     590|   SEA|        SFO|on_time|
+--------+--------+------+-----------+-------+



## Renaming Columns

In [44]:
foo4 = (foo3
        .withColumnRenamed('status', 'flight_status'))
foo4.show()

+--------+--------+------+-----------+-------------+
|    date|distance|origin|destination|flight_status|
+--------+--------+------+-----------+-------------+
|01010710|     590|   SEA|        SFO|      delayed|
|01010955|     590|   SEA|        SFO|      delayed|
|01010730|     590|   SEA|        SFO|      on_time|
+--------+--------+------+-----------+-------------+



## Pivoting

In [45]:
spark.sql(
"""
SELECT destination, CAST(SUBSTRING(date, 0, 2) AS INT), delay
FROM departure_delays
WHERE origin = 'SEA'
"""
).show()

+-----------+----------------------------------+-----+
|destination|CAST(substring(date, 0, 2) AS INT)|delay|
+-----------+----------------------------------+-----+
|        ORD|                                 1|   92|
|        JFK|                                 1|   -7|
|        DFW|                                 1|   -5|
|        MIA|                                 1|   -3|
|        DFW|                                 1|   -3|
|        DFW|                                 1|    1|
|        ORD|                                 1|  -10|
|        DFW|                                 1|   -6|
|        DFW|                                 1|   -2|
|        ORD|                                 1|   -3|
|        ORD|                                 1|    0|
|        DFW|                                 1|   23|
|        DFW|                                 1|   36|
|        ORD|                                 1|  298|
|        JFK|                                 1|    4|
|        D

In [46]:
spark.sql(
"""
SELECT *
FROM (
SELECT destination, CAST(SUBSTRING(date, 0, 2) AS INT) AS month, delay
FROM departure_delays
WHERE origin = 'SEA'
)
PIVOT (
    CAST(AVG(delay) AS DECIMAL(4, 2)) AS avg_delay, MAX(delay) AS max_delay
    FOR month in (1 Jan, 2 Feb)
)
"""
).show()

+-----------+-------------+-------------+-------------+-------------+
|destination|Jan_avg_delay|Jan_max_delay|Feb_avg_delay|Feb_max_delay|
+-----------+-------------+-------------+-------------+-------------+
|        GEG|         2.28|           63|         2.87|           60|
|        BUR|        -2.03|           56|        -1.89|           78|
|        SNA|        -3.58|           82|        -0.41|           90|
|        OAK|        15.82|          385|         8.12|          150|
|        DCA|        -1.15|           50|         0.07|           34|
|        KTN|         2.40|           64|         5.00|          128|
|        LIH|        -1.64|           15|        -1.70|           21|
|        IAH|        17.91|          227|         7.79|          466|
|        HNL|         7.43|          294|         2.94|          132|
|        SJC|         4.46|          256|         2.74|           74|
|        CVG|        -0.50|            4|         null|         null|
|        AUS|       