In [62]:
from pyspark.sql import SparkSession
from pyspark.sql.types import LongType, IntegerType

In [63]:
spark = SparkSession.builder.appName("Spark Chap5 Parctice").enableHiveSupport().getOrCreate()

#### Spark SQL UDFs

In [64]:
def cubed(s):
    return s * s * s

In [65]:
spark.udf.register("cubed",cubed, LongType())

<function __main__.cubed(s)>

In [66]:
spark.range(1,9).createOrReplaceTempView("udf_test")

In [67]:
spark.sql("SELECT id,cubed(id) as id_cubed FROM udf_test").show()

+---+--------+
| id|id_cubed|
+---+--------+
|  1|       1|
|  2|       8|
|  3|      27|
|  4|      64|
|  5|     125|
|  6|     216|
|  7|     343|
|  8|     512|
+---+--------+



##### Evaluation order and null checking in Spark SQL

In [68]:
def strlen(s):
    return len(s)

# spark.udf.register("strlen", lambda s: len(s), "int") #Alternate 
spark.udf.register("strlen", strlen, IntegerType())

<function __main__.strlen(s)>

In [69]:
spark.udf.register("strlen_nullsafe", lambda s: len(s) if not s is None else -1, "int")
# spark.sql("select s from test1 where s is not null and strlen_nullsafe(s) > 1") // ok
# spark.sql("select s from test1 where if(s is not null, strlen(s), null) > 1")   // ok

<function __main__.<lambda>(s)>

#### Speeding up and distributing PySpark UDFs with Pandas UDFs

In [70]:
# Pandas udfs over PySpark udfs because, 
# PySpark udfs required data movement btwn JVM and Python,
# which is quite expensive.

# To resolve this Pandas UDFs (also known as vectorized UDFs) were introduced
# Panda UDFs uses Apache Arrow to transfer data and Pandas to work with data.


In [71]:
# Pandas UDfs were split into two API categories
# 1. Pandas UDFs
# 2. Pandas Function APIs

In [72]:
import pandas as pd
from pyspark.sql.functions import col, pandas_udf
from pyspark.sql.types import LongType

In [73]:
def cubed(a: pd.Series) -> pd.Series:
    return a * a * a

In [74]:
cubed_udf = pandas_udf(cubed, returnType=LongType())

x = pd.Series([1,2,3])

print(cubed(x))

df = spark.range(1,4)

df.select("id", cubed_udf(col("id"))).show()

0     1
1     8
2    27
dtype: int64


Py4JJavaError: An error occurred while calling o185.showString.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 69.0 failed 1 times, most recent failure: Lost task 0.0 in stage 69.0 (TID 158, 10.0.2.93, executor driver): java.lang.UnsupportedOperationException: sun.misc.Unsafe or java.nio.DirectByteBuffer.<init>(long, int) not available
	at io.netty.util.internal.PlatformDependent.directBuffer(PlatformDependent.java:490)
	at io.netty.buffer.NettyArrowBuf.getDirectBuffer(NettyArrowBuf.java:243)
	at io.netty.buffer.NettyArrowBuf.nioBuffer(NettyArrowBuf.java:233)
	at io.netty.buffer.ArrowBuf.nioBuffer(ArrowBuf.java:245)
	at org.apache.arrow.vector.ipc.message.ArrowRecordBatch.computeBodyLength(ArrowRecordBatch.java:222)
	at org.apache.arrow.vector.ipc.message.MessageSerializer.serialize(MessageSerializer.java:240)
	at org.apache.arrow.vector.ipc.ArrowWriter.writeRecordBatch(ArrowWriter.java:132)
	at org.apache.arrow.vector.ipc.ArrowWriter.writeBatch(ArrowWriter.java:120)
	at org.apache.spark.sql.execution.python.ArrowPythonRunner$$anon$1.$anonfun$writeIteratorToStream$1(ArrowPythonRunner.scala:94)
	at scala.runtime.java8.JFunction0$mcV$sp.apply(JFunction0$mcV$sp.java:23)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1377)
	at org.apache.spark.sql.execution.python.ArrowPythonRunner$$anon$1.writeIteratorToStream(ArrowPythonRunner.scala:101)
	at org.apache.spark.api.python.BasePythonRunner$WriterThread.$anonfun$run$1(PythonRunner.scala:383)
	at org.apache.spark.util.Utils$.logUncaughtExceptions(Utils.scala:1934)
	at org.apache.spark.api.python.BasePythonRunner$WriterThread.run(PythonRunner.scala:218)

Driver stacktrace:
	at org.apache.spark.scheduler.DAGScheduler.failJobAndIndependentStages(DAGScheduler.scala:2059)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2(DAGScheduler.scala:2008)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2$adapted(DAGScheduler.scala:2007)
	at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
	at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
	at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:2007)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1(DAGScheduler.scala:973)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1$adapted(DAGScheduler.scala:973)
	at scala.Option.foreach(Option.scala:407)
	at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:973)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:2239)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2188)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2177)
	at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.runJob(DAGScheduler.scala:775)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2114)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2135)
	at org.apache.spark.SparkContext.runJob(SparkContext.scala:2154)
	at org.apache.spark.sql.execution.SparkPlan.executeTake(SparkPlan.scala:472)
	at org.apache.spark.sql.execution.SparkPlan.executeTake(SparkPlan.scala:425)
	at org.apache.spark.sql.execution.CollectLimitExec.executeCollect(limit.scala:47)
	at org.apache.spark.sql.Dataset.collectFromPlan(Dataset.scala:3627)
	at org.apache.spark.sql.Dataset.$anonfun$head$1(Dataset.scala:2697)
	at org.apache.spark.sql.Dataset.$anonfun$withAction$1(Dataset.scala:3618)
	at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId$5(SQLExecution.scala:100)
	at org.apache.spark.sql.execution.SQLExecution$.withSQLConfPropagated(SQLExecution.scala:160)
	at org.apache.spark.sql.execution.SQLExecution$.$anonfun$withNewExecutionId$1(SQLExecution.scala:87)
	at org.apache.spark.sql.SparkSession.withActive(SparkSession.scala:764)
	at org.apache.spark.sql.execution.SQLExecution$.withNewExecutionId(SQLExecution.scala:64)
	at org.apache.spark.sql.Dataset.withAction(Dataset.scala:3616)
	at org.apache.spark.sql.Dataset.head(Dataset.scala:2697)
	at org.apache.spark.sql.Dataset.take(Dataset.scala:2904)
	at org.apache.spark.sql.Dataset.getRows(Dataset.scala:300)
	at org.apache.spark.sql.Dataset.showString(Dataset.scala:337)
	at jdk.internal.reflect.GeneratedMethodAccessor113.invoke(Unknown Source)
	at java.base/jdk.internal.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
	at java.base/java.lang.reflect.Method.invoke(Method.java:566)
	at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:244)
	at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357)
	at py4j.Gateway.invoke(Gateway.java:282)
	at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132)
	at py4j.commands.CallCommand.execute(CallCommand.java:79)
	at py4j.GatewayConnection.run(GatewayConnection.java:238)
	at java.base/java.lang.Thread.run(Thread.java:834)
Caused by: java.lang.UnsupportedOperationException: sun.misc.Unsafe or java.nio.DirectByteBuffer.<init>(long, int) not available
	at io.netty.util.internal.PlatformDependent.directBuffer(PlatformDependent.java:490)
	at io.netty.buffer.NettyArrowBuf.getDirectBuffer(NettyArrowBuf.java:243)
	at io.netty.buffer.NettyArrowBuf.nioBuffer(NettyArrowBuf.java:233)
	at io.netty.buffer.ArrowBuf.nioBuffer(ArrowBuf.java:245)
	at org.apache.arrow.vector.ipc.message.ArrowRecordBatch.computeBodyLength(ArrowRecordBatch.java:222)
	at org.apache.arrow.vector.ipc.message.MessageSerializer.serialize(MessageSerializer.java:240)
	at org.apache.arrow.vector.ipc.ArrowWriter.writeRecordBatch(ArrowWriter.java:132)
	at org.apache.arrow.vector.ipc.ArrowWriter.writeBatch(ArrowWriter.java:120)
	at org.apache.spark.sql.execution.python.ArrowPythonRunner$$anon$1.$anonfun$writeIteratorToStream$1(ArrowPythonRunner.scala:94)
	at scala.runtime.java8.JFunction0$mcV$sp.apply(JFunction0$mcV$sp.java:23)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:1377)
	at org.apache.spark.sql.execution.python.ArrowPythonRunner$$anon$1.writeIteratorToStream(ArrowPythonRunner.scala:101)
	at org.apache.spark.api.python.BasePythonRunner$WriterThread.$anonfun$run$1(PythonRunner.scala:383)
	at org.apache.spark.util.Utils$.logUncaughtExceptions(Utils.scala:1934)
	at org.apache.spark.api.python.BasePythonRunner$WriterThread.run(PythonRunner.scala:218)


#### Higher-Order Functions

In [75]:
from pyspark.sql.types import *

In [76]:
schema = StructType([StructField("celsius", ArrayType(IntegerType()))])

In [77]:
t_list = [[35,36,32,30,40,42,38]],[[31,32,34,55,56]]
t_c = spark.createDataFrame(t_list,schema)
t_c.createOrReplaceTempView("tC")

In [78]:
t_c.show(truncate=False)

+----------------------------+
|celsius                     |
+----------------------------+
|[35, 36, 32, 30, 40, 42, 38]|
|[31, 32, 34, 55, 56]        |
+----------------------------+



###### transform()

In [79]:
spark.sql("""
SELECT celsius, transform(celsius, t -> ((t * 9) div 5) + 32) as farenheit from tc

""").show(truncate=False)

+----------------------------+-------------------------------+
|celsius                     |farenheit                      |
+----------------------------+-------------------------------+
|[35, 36, 32, 30, 40, 42, 38]|[95, 96, 89, 86, 104, 107, 100]|
|[31, 32, 34, 55, 56]        |[87, 89, 93, 131, 132]         |
+----------------------------+-------------------------------+



###### filter()

In [80]:
spark.sql("""
SELECT celsius, filter(celsius, t -> t > 38) as high from tC
""").show()

+--------------------+--------+
|             celsius|    high|
+--------------------+--------+
|[35, 36, 32, 30, ...|[40, 42]|
|[31, 32, 34, 55, 56]|[55, 56]|
+--------------------+--------+



###### exists()

In [81]:
spark.sql("""
SELECT celsius, exists(celsius, t -> t = 38) as threshold FROM tC
""").show(truncate=False)

+----------------------------+---------+
|celsius                     |threshold|
+----------------------------+---------+
|[35, 36, 32, 30, 40, 42, 38]|true     |
|[31, 32, 34, 55, 56]        |false    |
+----------------------------+---------+



###### reduce()

In [82]:
# spark.sql("""
# SELECT celsius,
# reduce(
# celsius,
# 0,
# (t, acc) -> t + acc,
# acc -> (acc div size(celsius) * 9 div 5) + 32
# ) as avgFahrenheit
# FROM tC
# """).show()

### Common DataFrames and Spark SQL Operations

#### UNION

In [83]:
from pyspark.sql.functions import expr

In [84]:
tripdelaysFilePath = "../databricks-datasets/learning-spark-v2/flights/departuredelays.csv"

In [85]:
airportsnaFilePath = "../databricks-datasets/learning-spark-v2/flights/airport-codes-na.txt"

In [86]:
airportsna = (spark.read
.format("csv")
.options(header="true", inferSchema="true", sep="\t")
.load(airportsnaFilePath))

In [87]:
airportsna.createOrReplaceTempView("airports_na")

In [88]:
departureDelays = (spark.read
.format("csv")
.options(header="true")
.load(tripdelaysFilePath))

In [89]:
departureDelays = (departureDelays
.withColumn("delay", expr("CAST(delay as INT) as delay"))
.withColumn("distance", expr("CAST(distance as INT) as distance")))

In [90]:
departureDelays.createOrReplaceTempView("departureDelays")

In [91]:
foo = (
    departureDelays
.filter(expr("""origin == 'SEA' and destination == 'SFO' and
date like '01010%' and delay > 0""")))

In [92]:
foo.createOrReplaceTempView("foo")

In [93]:
departureDelays.count()

1391578

In [94]:
foo.count()

3

In [95]:
spark.sql("SELECT * FROM airports_na LIMIT 10").show()

+-----------+-----+-------+----+
|       City|State|Country|IATA|
+-----------+-----+-------+----+
| Abbotsford|   BC| Canada| YXX|
|   Aberdeen|   SD|    USA| ABR|
|    Abilene|   TX|    USA| ABI|
|      Akron|   OH|    USA| CAK|
|    Alamosa|   CO|    USA| ALS|
|     Albany|   GA|    USA| ABY|
|     Albany|   NY|    USA| ALB|
|Albuquerque|   NM|    USA| ABQ|
| Alexandria|   LA|    USA| AEX|
|  Allentown|   PA|    USA| ABE|
+-----------+-----+-------+----+



In [96]:
spark.sql("SELECT * FROM departureDelays LIMIT 10").show()

+--------+-----+--------+------+-----------+
|    date|delay|distance|origin|destination|
+--------+-----+--------+------+-----------+
|01011245|    6|     602|   ABE|        ATL|
|01020600|   -8|     369|   ABE|        DTW|
|01021245|   -2|     602|   ABE|        ATL|
|01020605|   -4|     602|   ABE|        ATL|
|01031245|   -4|     602|   ABE|        ATL|
|01030605|    0|     602|   ABE|        ATL|
|01041243|   10|     602|   ABE|        ATL|
|01040605|   28|     602|   ABE|        ATL|
|01051245|   88|     602|   ABE|        ATL|
|01050605|    9|     602|   ABE|        ATL|
+--------+-----+--------+------+-----------+



In [97]:
spark.sql("SELECT * FROM foo").show()

+--------+-----+--------+------+-----------+
|    date|delay|distance|origin|destination|
+--------+-----+--------+------+-----------+
|01010710|   31|     590|   SEA|        SFO|
|01010955|  104|     590|   SEA|        SFO|
|01010730|    5|     590|   SEA|        SFO|
+--------+-----+--------+------+-----------+



In [98]:
departureDelays.filter(
    expr("""origin == 'SEA' AND destination == 'SFO' AND date LIKE '01010%' AND delay > 0""")
).show()

+--------+-----+--------+------+-----------+
|    date|delay|distance|origin|destination|
+--------+-----+--------+------+-----------+
|01010710|   31|     590|   SEA|        SFO|
|01010955|  104|     590|   SEA|        SFO|
|01010730|    5|     590|   SEA|        SFO|
+--------+-----+--------+------+-----------+



In [99]:
bar = departureDelays.union(foo)
bar.createOrReplaceTempView("bar")

In [100]:
bar.filter(
    expr("""origin == 'SEA' AND destination == 'SFO' AND date LIKE '01010%' AND delay > 0""")
).show()

+--------+-----+--------+------+-----------+
|    date|delay|distance|origin|destination|
+--------+-----+--------+------+-----------+
|01010710|   31|     590|   SEA|        SFO|
|01010955|  104|     590|   SEA|        SFO|
|01010730|    5|     590|   SEA|        SFO|
|01010710|   31|     590|   SEA|        SFO|
|01010955|  104|     590|   SEA|        SFO|
|01010730|    5|     590|   SEA|        SFO|
+--------+-----+--------+------+-----------+



In [101]:
bar.count()

1391581

In [102]:
spark.sql("""
SELECT *
FROM bar
WHERE origin = 'SEA'
AND destination = 'SFO'
AND date LIKE '01010%'
AND delay > 0
""").show()

+--------+-----+--------+------+-----------+
|    date|delay|distance|origin|destination|
+--------+-----+--------+------+-----------+
|01010710|   31|     590|   SEA|        SFO|
|01010955|  104|     590|   SEA|        SFO|
|01010730|    5|     590|   SEA|        SFO|
|01010710|   31|     590|   SEA|        SFO|
|01010955|  104|     590|   SEA|        SFO|
|01010730|    5|     590|   SEA|        SFO|
+--------+-----+--------+------+-----------+



#### JOIN

In [103]:
foo.join(
airportsna,
airportsna.IATA == foo.origin
).select("City", "State", "date", "delay", "distance", "destination").show()

+-------+-----+--------+-----+--------+-----------+
|   City|State|    date|delay|distance|destination|
+-------+-----+--------+-----+--------+-----------+
|Seattle|   WA|01010710|   31|     590|        SFO|
|Seattle|   WA|01010955|  104|     590|        SFO|
|Seattle|   WA|01010730|    5|     590|        SFO|
+-------+-----+--------+-----+--------+-----------+



In [104]:
spark.sql("""
SELECT a.City, a.State, f.date, f.delay, f.distance, f.destination
FROM foo f
JOIN airports_na a
ON a.IATA = f.origin
""").show()

+-------+-----+--------+-----+--------+-----------+
|   City|State|    date|delay|distance|destination|
+-------+-----+--------+-----+--------+-----------+
|Seattle|   WA|01010710|   31|     590|        SFO|
|Seattle|   WA|01010955|  104|     590|        SFO|
|Seattle|   WA|01010730|    5|     590|        SFO|
+-------+-----+--------+-----+--------+-----------+



#### WINDOWING

In [105]:
spark.sql("DROP TABLE IF EXISTS departureDelaysWindow")

DataFrame[]

In [106]:
spark.sql("""
CREATE TABLE departureDelaysWindow AS
SELECT origin, destination, SUM(delay) AS TotalDelays
FROM departureDelays
WHERE origin IN ('SEA', 'SFO', 'JFK')
AND destination IN ('SEA', 'SFO', 'JFK', 'DEN', 'ORD', 'LAX', 'ATL')
GROUP BY origin, destination;
""")

DataFrame[]

In [107]:
spark.sql("""
SELECT * FROM departureDelaysWindow
""").show()

+------+-----------+-----------+
|origin|destination|TotalDelays|
+------+-----------+-----------+
|   SEA|        ORD|      10041|
|   SEA|        JFK|       4667|
|   SFO|        SEA|      17080|
|   SEA|        SFO|      22293|
|   JFK|        SFO|      35619|
|   JFK|        ORD|       5608|
|   JFK|        SEA|       7856|
|   SEA|        LAX|       9359|
|   SFO|        LAX|      40798|
|   JFK|        DEN|       4315|
|   SFO|        ATL|       5091|
|   SEA|        ATL|       4535|
|   JFK|        LAX|      35755|
|   SFO|        JFK|      24100|
|   SFO|        DEN|      18688|
|   JFK|        ATL|      12141|
|   SFO|        ORD|      27412|
|   SEA|        DEN|      13645|
+------+-----------+-----------+



In [108]:
#pg150