# Vectorized User Defined Functions

Let's compare the performance of UDFs, Vectorized UDFs, and built-in methods.

Start by generating some dummy data.

In [2]:
from pyspark.sql.types import *
from pyspark.sql.functions import *
from pyspark.sql import SparkSession
spark = (
    SparkSession
    .builder
    .appName("07_chap")
    .config("spark.sql.catalogImplementation", "hive")
    .getOrCreate()
    )
sc = spark.sparkContext

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/05/14 23:24:02 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable



In [3]:
from pyspark.sql.types import *
from pyspark.sql.functions import col, count, rand, collect_list, explode, struct, pandas_udf, udf

df = (spark
      .range(0, 10 * 1000 * 1000)
      .withColumn("id", (col("id") / 1000).cast("integer"))
      .withColumn("v", rand()))

df.cache()
df.count()

                                                                                

10000000
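To see what the `id` expression produces, here is a small plain-Python sketch of the same arithmetic on a scaled-down range. It assumes that casting a non-negative quotient to an integer truncates it, which matches floor division for these values; the names `n_rows` and `group_size` are illustrative and not part of the notebook.

```python
from collections import Counter

# Scaled-down stand-in for spark.range(0, 10 * 1000 * 1000):
# 10,000 rows instead of 10,000,000, with the same divisor of 1000.
n_rows = 10 * 1000
group_size = 1000

# (id / 1000).cast("integer") truncates a non-negative quotient,
# which is the same as floor division for these values.
group_ids = [i // group_size for i in range(n_rows)]

rows_per_group = Counter(group_ids)
print(len(rows_per_group))           # number of distinct ids
print(set(rows_per_group.values()))  # row counts found across groups
```

At full scale the same arithmetic gives 10,000 distinct `id` values with 1,000 rows each, which is the grouping the later examples operate on.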

## Incrementing a column by one

Let's start off with a simple example of adding one to each value in our DataFrame.

### PySpark UDF

In [4]:
@udf("double")
def plus_one(v):
    return v + 1

%timeit -n1 -r1 df.withColumn("v", plus_one(df.v)).agg(count(col("v"))).show()


+--------+
|count(v)|
+--------+
|10000000|
+--------+

3.63 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)


                                                                                

Alternate syntax: registering the function with `spark.udf.register` also makes it available in SQL expressions, such as the `selectExpr` call below.

In [5]:
from pyspark.sql.types import DoubleType

def plus_one(v):
    return v + 1
  
spark.udf.register("plus_one_udf", plus_one, DoubleType())

%timeit -n1 -r1 df.selectExpr("id", "plus_one_udf(v) as v").agg(count(col("v"))).show()


+--------+
|count(v)|
+--------+
|10000000|
+--------+

1.8 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)


                                                                                

### Scala UDF

Yikes! That took a while just to add 1 to each value. Let's see how long this takes with a Scala UDF.

In [6]:
df.createOrReplaceTempView("df")

In [8]:
# %scala
# import org.apache.spark.sql.functions._

# val df = spark.table("df")

# def plusOne: (Double => Double) = { v => v+1 }
# val plus_one = udf(plusOne)

In [10]:
# %scala
# df.withColumn("v", plus_one($"v"))
#   .agg(count(col("v")))
#   .show()

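Scala UDFs run inside the JVM, so they avoid the cost of serializing every row out to a Python worker and back, and are typically much faster than Python UDFs. (The Scala cells above are commented out, so this run records no timing for them.) However, as of Spark 2.3, there are Vectorized UDFs available in Python to help speed up the computation.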

* [Blog post](https://databricks.com/blog/2017/10/30/introducing-vectorized-udfs-for-pyspark.html)
* [Documentation](https://spark.apache.org/docs/latest/sql-programming-guide.html#pyspark-usage-guide-for-pandas-with-apache-arrow)

![Benchmark](https://databricks.com/wp-content/uploads/2017/10/image1-4.png)

Vectorized UDFs utilize Apache Arrow to speed up computation. Let's see how that helps improve our processing time.

[Apache Arrow](https://arrow.apache.org/) is an in-memory columnar data format that Spark uses to transfer data efficiently between JVM and Python processes. See more [here](https://spark.apache.org/docs/latest/sql-pyspark-pandas-with-arrow.html).
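As a rough illustration of what "columnar" means, the sketch below contrasts a row-oriented layout with a column-oriented one in plain Python. It is a toy model only: the dictionaries and lists here are assumptions made for illustration and do not reflect Arrow's actual memory layout or API.

```python
# Row-oriented: one record per row, as a row-at-a-time UDF sees the data.
rows = [
    {"id": 0, "v": 0.25},
    {"id": 0, "v": 0.50},
    {"id": 1, "v": 0.75},
]

# Column-oriented: one sequence per column, which is the general idea
# behind a columnar format (toy layout, not Arrow's own).
columns = {
    "id": [r["id"] for r in rows],
    "v": [r["v"] for r in rows],
}

# Row-at-a-time: every record is touched to update a single field.
rows_plus_one = [{**r, "v": r["v"] + 1} for r in rows]

# Columnar: only the column that changes is rewritten; "id" is reused as-is.
columns_plus_one = {**columns, "v": [v + 1 for v in columns["v"]]}

print([r["v"] for r in rows_plus_one])
print(columns_plus_one["v"])
```

Both approaches produce the same values; the difference is that the columnar form hands over a whole column at once, which is what lets a vectorized UDF process a batch in a single call.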

Let's take a look at how long it takes to convert a Spark DataFrame to Pandas with and without Apache Arrow enabled.

In [14]:
spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")

%timeit -n1 -r1 df.toPandas()


776 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)


In [12]:
spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "false")

%timeit -n1 -r1 df.toPandas()

Py4JJavaError: An error occurred while calling o34.collectToPython.
: java.lang.OutOfMemoryError: GC overhead limit exceeded

Without Arrow, the conversion ran out of memory (`GC overhead limit exceeded`) before a timing could be recorded, so this run gives no direct comparison with the 776 ms Arrow-enabled result.

### Vectorized UDF

In [15]:
@pandas_udf("double")
def vectorized_plus_one(v):
    return v + 1

%timeit -n1 -r1 df.withColumn("v", vectorized_plus_one(df.v)).agg(count(col("v"))).show()

+--------+
|count(v)|
+--------+
|10000000|
+--------+

1.23 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)



Alright! At 1.23 s, that is faster than both runs of the regular Python UDF (3.63 s and 1.8 s), though a Scala UDF, which stays inside the JVM, would typically still be faster.

Here's some alternate syntax for the Pandas UDF.

In [17]:
from pyspark.sql.functions import pandas_udf

def vectorized_plus_one(v):
    return v + 1

vectorized_plus_one_udf = pandas_udf(vectorized_plus_one, "double")

%timeit -n1 -r1 df.withColumn("v", vectorized_plus_one_udf(df.v)).agg(count(col("v"))).show()

+--------+
|count(v)|
+--------+
|10000000|
+--------+

1.19 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
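Because the function wrapped by `pandas_udf` is ordinary pandas code, its logic can be checked locally without Spark. The sketch below assumes a small hypothetical batch (`batch`) and compares the vectorized function with a row-at-a-time equivalent applied through `Series.map`; the type hints are added for illustration.

```python
import pandas as pd


def vectorized_plus_one(v: pd.Series) -> pd.Series:
    # Operates on a whole batch of values in one call.
    return v + 1


def plus_one(v: float) -> float:
    # Row-at-a-time equivalent: one Python call per value.
    return v + 1


# Hypothetical batch standing in for one chunk of the "v" column.
batch = pd.Series([0.25, 0.5, 0.75])

vectorized_result = vectorized_plus_one(batch)
row_result = batch.map(plus_one)

print(vectorized_result.tolist())
print(row_result.tolist())
```

The two results agree; the vectorized version simply makes one function call per batch instead of one per value, which is where the savings in the Spark timings come from.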


### Built-in Method

Let's compare the performance of the UDFs with using the built-in method.

In [19]:
from pyspark.sql.functions import lit

%timeit -n1 -r1 df.withColumn("v", df.v + lit(1)).agg(count(col("v"))).show()

+--------+
|count(v)|
+--------+
|10000000|
+--------+

444 ms ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
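To put the timings side by side, the short sketch below computes how much slower each approach was than the built-in expression. The numbers are the single-run `%timeit` results reported above; since each is one run (`-n1 -r1`), the ratios are only a rough guide. The labels in the dictionary are illustrative.

```python
# Single-run timings reported by %timeit above, in seconds.
timings = {
    "Python UDF (decorator)": 3.63,
    "Python UDF (registered)": 1.8,
    "Vectorized UDF (decorator)": 1.23,
    "Vectorized UDF (function)": 1.19,
    "Built-in expression": 0.444,
}

baseline = timings["Built-in expression"]

# A ratio above 1 means slower than the built-in expression.
ratios = {name: seconds / baseline for name, seconds in timings.items()}

for name, seconds in timings.items():
    print(f"{name:28s} {seconds:6.3f} s  {ratios[name]:4.1f}x")
```

On these runs the vectorized UDFs were roughly 2.7 to 2.8 times slower than the built-in expression, and the row-at-a-time Python UDFs roughly 4 to 8 times slower.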


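## Subtracting the group mean

Up above, each UDF mapped one input value to one output value (a scalar UDF). Now we can use grouped UDFs, which receive all rows of a group at once. Here, we subtract each group's mean of `v` from every value in that group.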

### PySpark UDF

In [20]:
from pyspark.sql import Row
import pandas as pd

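@udf(ArrayType(df.schema))
def subtract_mean(rows):
    # rows holds every (id, v) struct collected for one group
    vs = pd.Series([r.v for r in rows])
    vs = vs - vs.mean()
    # Rebuild one Row per input row, carrying the centered value
    return [Row(id=rows[i]["id"], v=float(vs[i])) for i in range(len(rows))]
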
%timeit -n1 -r1 (df.groupby("id").agg(collect_list(struct(df["id"], df["v"])).alias("rows")).withColumn("new_rows", subtract_mean(col("rows"))).withColumn("new_row", explode(col("new_rows"))).withColumn("id", col("new_row.id")).withColumn("v", col("new_row.v")).agg(count(col("v"))).show())


+--------+
|count(v)|
+--------+
|10000000|
+--------+

25 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)



### Vectorized UDF

In [21]:
def vectorized_subtract_mean(pdf: pd.DataFrame) -> pd.DataFrame:
    # applyInPandas passes each group as a pandas DataFrame
    return pdf.assign(v=pdf.v - pdf.v.mean())

%timeit -n1 -r1 df.groupby("id").applyInPandas(vectorized_subtract_mean, df.schema).agg(count(col("v"))).show()


+--------+
|count(v)|
+--------+
|10000000|
+--------+

5.62 s ± 0 ns per loop (mean ± std. dev. of 1 run, 1 loop each)
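The function passed to `applyInPandas` is plain pandas code that receives one group at a time as a DataFrame, so it can be sanity-checked on a small frame without Spark. The sketch below mimics the per-group calls by iterating over a pandas `groupby` and concatenating the results; the `sample` data is hypothetical, and this local loop is only an approximation of what Spark does across executors.

```python
import pandas as pd


def vectorized_subtract_mean(pdf: pd.DataFrame) -> pd.DataFrame:
    # Receives all rows of one group and returns them with v centered.
    return pdf.assign(v=pdf.v - pdf.v.mean())


# Hypothetical sample with two groups.
sample = pd.DataFrame({"id": [0, 0, 1, 1], "v": [1.0, 3.0, 10.0, 20.0]})

# Mimic applyInPandas locally: call the function once per group,
# then concatenate the per-group results.
centered = pd.concat(
    [vectorized_subtract_mean(group) for _, group in sample.groupby("id")]
)

print(centered)
```

Each group's values now sum to zero, which is a quick way to confirm the centering logic before running it on the full 10 million rows.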

