# Pandas UDFs

When coding with Spark, you should generally try to use native Spark functions wherever possible (i.e. the functions in `pyspark.sql.functions` for PySpark). However, sometimes no Spark function does what you need. In these cases, the next best option is usually a Pandas UDF.

### Why are Pandas UDFs better than Python UDFs?

Ordinary Python UDFs are not very efficient in Spark. Their code cannot be executed in the Java Virtual Machine (JVM), the platform that runs Spark, so each row of the dataframe has to be serialised, sent to a Python process, deserialised, processed one row at a time, and the result sent back to the JVM. This causes large (de)serialisation overheads and a lot of data copying in memory, which makes Python UDFs very slow.

Pandas UDFs (also known as vectorised UDFs) still run in Python processes on the Spark executors, but they process data in batches rather than one row at a time. Each batch of rows is passed to the UDF as a `pandas.Series` or `pandas.DataFrame`, so the UDF can use vectorised Pandas operations. The batches are transferred between the JVM and the Python processes using Apache Arrow, a columnar in-memory data format that both Java and Python understand, which keeps data copying and (de)serialisation overheads low.
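To make the row-at-a-time versus batch difference concrete, here is a pure-pandas simulation (no Spark involved) that counts how often each style of function is called on a 10,000-value column. The function names are hypothetical, and the batch size of 2,500 is an arbitrary assumption standing in for the Arrow batch size Spark would use.

```python
import pandas as pd

# Counters for how many times each function is invoked
calls = {"row": 0, "batch": 0}

def add_one_row(x: int) -> int:
    # Row-at-a-time style, like a plain Python UDF: one call per value
    calls["row"] += 1
    return x + 1

def add_one_batch(s: pd.Series) -> pd.Series:
    # Vectorised style, like a scalar Pandas UDF: one call per batch
    calls["batch"] += 1
    return s + 1

column = pd.Series(range(10_000), dtype="int64")
batch_size = 2_500  # assumption: stands in for Spark's Arrow batch size

row_result = pd.Series([add_one_row(x) for x in column])
batch_result = pd.concat(
    [add_one_batch(column.iloc[start:start + batch_size])
     for start in range(0, len(column), batch_size)]
)

print(calls)  # {'row': 10000, 'batch': 4}
```

Both versions produce the same values; the vectorised function is simply entered 4 times instead of 10,000. In Spark, each of those crossings also involves moving data between the JVM and Python, which is the overhead that Arrow batching reduces.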

### Types of Pandas UDFs

Spark 2.4 supports three types of Pandas UDFs:
- Scalar
- Grouped Map
- Grouped Aggregate

Spark 3.x supports these additional types:
- Scalar Iterator
- Map
- Cogrouped Map

This demo will cover the three types supported in Spark 2.4. Additional information on the other types for Spark 3.x can be found [here](https://towardsdatascience.com/pyspark-or-pandas-why-not-both-95523946ec7c).

### Declaring Pandas UDFs in Spark

You can create a Pandas UDF in several ways:
- Using the `@pandas_udf(<return_type>, F.PandasUDFType.<udf_type>)` decorator.
- Wrapping an existing function with `<udf_name>_udf = F.pandas_udf(<udf_name>, returnType=<return_type>)`.
- Using Python type hints (Spark 3.0+), which make Pandas UDFs more descriptive and let Spark infer the UDF type from the function signature. For more info see [here](https://www.databricks.com/blog/2020/05/20/new-pandas-udfs-and-python-type-hints-in-the-upcoming-release-of-apache-spark-3-0.html).

This demo uses the `pandas_udf` decorator together with Python type hints to annotate the input and output types of each function.


### Spark Setup for Pandas UDFs

For Spark 3, to use Pandas UDFs you will have to install Pandas and PyArrow using:
```
pip install pandas
pip install pyarrow
```

For Spark 2.4 you will need to install older versions of Pandas and PyArrow. We found the following to be the highest versions that work with Pandas UDFs, although you should expect a deprecation warning:
```
pip install pandas==1.1.5
pip install pyarrow==0.14.0
```
The [official PySpark Usage Guide for Pandas with Apache Arrow](https://spark.apache.org/docs/2.4.7/sql-pyspark-pandas-with-arrow.html#compatibiliy-setting-for-pyarrow--0150-and-spark-23x-24x) advises that compatibility and data correctness are not guaranteed for versions above Pandas 0.19.2 and PyArrow 0.8.0, so if you run into issues it may be best to switch to these lower versions.
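Before creating a Spark session it can help to confirm which versions are actually installed in the Python environment Spark will use. A minimal sketch using the standard library's `importlib.metadata` (Python 3.8+); the helper name `installed_versions` and the default package list are assumptions you can adjust:

```python
from importlib.metadata import PackageNotFoundError, version

def installed_versions(packages=("pandas", "pyarrow", "pyspark")):
    """Return the installed version string of each package, or None if it is missing."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found

print(installed_versions())
```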

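You should also add the configs below. The first makes Spark use Arrow when converting between Spark and Pandas dataframes (for example in `toPandas()` and `spark.createDataFrame()`), and the second falls back to the non-Arrow conversion if Arrow cannot be used. Pandas UDFs use Arrow regardless of these settings, so the configs mainly speed up Spark/Pandas conversion and are worth setting even if you are not using Pandas UDFs. In Spark 3.x the preferred names are `spark.sql.execution.arrow.pyspark.enabled` and `spark.sql.execution.arrow.pyspark.fallback.enabled`; the names below are deprecated there but still accepted.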

```
.config('spark.sql.execution.arrow.enabled', 'true')
.config('spark.sql.execution.arrow.fallback.enabled', 'true')
```

If you use Spark 2.4 with PyArrow 0.15.0 or later, you will also need to add the configs below to your Spark session so that PyArrow uses the legacy Arrow IPC format that Spark 2.4 expects:

```
.config('spark.executorEnv.ARROW_PRE_0_15_IPC_FORMAT', 1)
.config('spark.workerEnv.ARROW_PRE_0_15_IPC_FORMAT', 1)
```

In this demo we use a local Spark session and the animal rescue dataset used throughout this book. The Pandas UDFs below are applied to the rescue dataset for demonstration only; in a real-life scenario you would use native Spark functions for these tasks. For other simple examples of Pandas UDFs, see [Databricks Introduction to Pandas UDFs](https://www.databricks.com/blog/2017/10/30/introducing-vectorized-udfs-for-pyspark.html).

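First, import the packages and set up a Spark session with Arrow optimisation and Spark 2.4 compatibility enabled. Then read the data, select and rename the relevant columns, add the London Fire Brigade (LFB) easting and northing as constant columns, and remove rows with a missing incident number: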

```python
import pandas as pd
from pyspark.sql import SparkSession
import pyspark.sql.functions as F
from pyspark.sql.functions import pandas_udf
import pyspark.sql.types as T

spark = (SparkSession.builder.master("local[2]")
         .appName("pandas_udfs")
         # Use Arrow for Spark <-> Pandas conversion, falling back if Arrow is unavailable
         .config('spark.sql.execution.arrow.enabled', 'true')
         .config('spark.sql.execution.arrow.fallback.enabled', 'true')
         # Spark 2.4 compatibility when PyArrow >= 0.15.0 is installed
         .config('spark.executorEnv.ARROW_PRE_0_15_IPC_FORMAT', 1)
         .config('spark.workerEnv.ARROW_PRE_0_15_IPC_FORMAT', 1)
         .getOrCreate())

rescue = spark.read.csv("../data/animal_rescue.csv", header=True, inferSchema=True)

# Select and rename the relevant columns
rescue = rescue.select(F.col('IncidentNumber').alias('incident_number'),
                       F.col('AnimalGroupParent').alias('animal_group'),
                       F.col('IncidentNotionalCost(£)').alias('incident_cost'),
                       F.col('Easting_rounded').alias('easting_rounded'),
                       F.col('Northing_rounded').alias('northing_rounded'))

# Add the London Fire Brigade location as constant easting/northing columns
rescue = (rescue.withColumn('lfb_easting', F.lit(532066))
                .withColumn('lfb_northing', F.lit(180010)))

# Remove rows without an incident number
rescue = rescue.filter(F.col('incident_number').isNotNull())

rescue.limit(10).toPandas()
```

|   | incident_number | animal_group | incident_cost | easting_rounded | northing_rounded | lfb_easting | lfb_northing |
|---|---|---|---|---|---|---|---|
| 0 | 139091 | Dog | 510.0 | 532350 | 170050 | 532066 | 180010 |
| 1 | 275091 | Fox | 255.0 | 534750 | 167550 | 532066 | 180010 |
| 2 | 2075091 | Dog | 255.0 | 528050 | 164950 | 532066 | 180010 |
| 3 | 2872091 | Horse | 255.0 | 504650 | 190650 | 532066 | 180010 |
| 4 | 3553091 | Rabbit | 255.0 | 554650 | 192350 | 532066 | 180010 |
| 5 | 3742091 | Unknown - Heavy Livestock Animal | 255.0 | 549350 | 184950 | 532066 | 180010 |
| 6 | 4011091 | Dog | 255.0 | 539050 | 186150 | 532066 | 180010 |
| 7 | 4211091 | Dog | 255.0 | 541350 | 186650 | 532066 | 180010 |
| 8 | 4306091 | Squirrel | 255.0 | 538750 | 163350 | 532066 | 180010 |
| 9 | 4715091 | Dog | 255.0 | 535450 | 186750 | 532066 | 180010 |


### Scalar Pandas UDFs

Scalar Pandas UDFs take one or more `pandas.Series` as input and return a `pandas.Series` of the same length.

Spark executes a scalar Pandas UDF by splitting the input columns into batches, converting each batch into `pandas.Series` objects, and calling the UDF on each batch as a subset of the data. The results are then concatenated back together into a single column.

Be aware that, because each call only sees one batch, a scalar UDF that calculates statistics such as a mean or standard deviation computes them per batch rather than over the whole column. This can give incorrect results when your dataframe spans more than one batch or partition. For more info on this issue and how to overcome it, see [this Towards Data Science post](https://towardsdatascience.com/pyspark-or-pandas-why-not-both-95523946ec7c).
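Both the batching and the pitfall can be reproduced without Spark. This pure-pandas sketch splits a column into fixed-size batches, applies a function to each batch and concatenates the results. All function names are hypothetical, and the batch size of 2 is purely illustrative: in Spark the batch size depends on the partitioning and on `spark.sql.execution.arrow.maxRecordsPerBatch` (10,000 rows by default).

```python
import pandas as pd

def split_into_batches(s: pd.Series, batch_size: int):
    """Yield consecutive slices of s, roughly how Spark feeds batches to a scalar Pandas UDF."""
    for start in range(0, len(s), batch_size):
        yield s.iloc[start:start + batch_size]

def run_in_batches(func, s: pd.Series, batch_size: int) -> pd.Series:
    """Apply func to each batch separately and concatenate the results."""
    return pd.concat([func(batch) for batch in split_into_batches(s, batch_size)])

def to_km(metres: pd.Series) -> pd.Series:
    # Element-wise: each output depends only on its own input value
    return metres / 1000

def centre(values: pd.Series) -> pd.Series:
    # Uses the mean of whatever rows it receives, so it depends on the batch
    return values - values.mean()

costs = pd.Series([100.0, 200.0, 300.0, 400.0])

# Element-wise function: batching does not change the result
print(run_in_batches(to_km, costs, 2).tolist() == to_km(costs).tolist())  # True

# Mean-based function: each batch is centred on its own mean (150, then 350),
# not on the mean of the whole column (250)
print(run_in_batches(centre, costs, 2).tolist())  # [-50.0, 50.0, -50.0, 50.0]
print(centre(costs).tolist())                     # [-150.0, -50.0, 50.0, 150.0]
```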

In the decorator of a scalar Pandas UDF, the first argument is the data type of the returned column and the second argument is the UDF type. Here is a simple scalar Pandas UDF that calculates the straight-line distance between where an animal was rescued (`easting_rounded` and `northing_rounded`) and the London Fire Brigade (`lfb_easting` and `lfb_northing`). The output is in km.

```python
@pandas_udf(T.DoubleType(), F.PandasUDFType.SCALAR)
def distx(easting_rounded: pd.Series, lfb_easting: pd.Series,
          northing_rounded: pd.Series, lfb_northing: pd.Series) -> pd.Series:
    # Straight-line (Euclidean) distance in metres between the two grid references
    distance = ((easting_rounded - lfb_easting)**2 + (northing_rounded - lfb_northing)**2)**0.5
    # Convert metres to kilometres
    distance_km = distance / 1000
    return distance_km
```
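Because the body of a Pandas UDF is ordinary Pandas code, one way to check it is to run the same calculation on small `pandas.Series` before using it in Spark. This sketch copies the body of `distx` into a hypothetical undecorated function, `distx_body`, and applies it to the first three rows of the sample rescue output shown earlier. (PySpark also keeps the original Python function on the UDF as `distx.func`, which can be used the same way.)

```python
import pandas as pd

def distx_body(easting_rounded: pd.Series, lfb_easting: pd.Series,
               northing_rounded: pd.Series, lfb_northing: pd.Series) -> pd.Series:
    # Same calculation as distx, without the @pandas_udf decorator
    distance = ((easting_rounded - lfb_easting) ** 2
                + (northing_rounded - lfb_northing) ** 2) ** 0.5
    return distance / 1000

# First three rows of the sample rescue output
sample = pd.DataFrame({
    "easting_rounded": [532350, 534750, 528050],
    "northing_rounded": [170050, 167550, 164950],
    "lfb_easting": [532066, 532066, 532066],
    "lfb_northing": [180010, 180010, 180010],
})

distance_km = distx_body(sample["easting_rounded"], sample["lfb_easting"],
                         sample["northing_rounded"], sample["lfb_northing"])
print(distance_km.round(3).tolist())  # [9.964, 12.746, 15.586]
```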



```python
lfb_distance = rescue.withColumn(
    "distance (km)",
    distx('easting_rounded', 'lfb_easting', 'northing_rounded', 'lfb_northing'))
lfb_distance.limit(5).toPandas()
```



### Grouped Map Pandas UDFs

Grouped map Pandas UDFs take a `pandas.DataFrame` as input and return a `pandas.DataFrame`.

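Grouped map UDFs follow the split-apply-combine pattern. The data is first split into groups using `.groupBy()`. The UDF is then applied to each group: it receives all rows of that group as a single `pandas.DataFrame` and returns a `pandas.DataFrame`. Finally, the returned Pandas dataframes are combined into a new Spark dataframe. Because all the data for a group is loaded into memory before the UDF is called, very large or skewed groups can cause out-of-memory errors.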
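The split-apply-combine steps can be sketched in plain Pandas. This example uses made-up toy data (not the rescue dataset) and a hypothetical undecorated function, `percentage_difference_body`, that calculates the percentage difference between each incident's cost and the mean cost of its group.

```python
import pandas as pd

def percentage_difference_body(pdf: pd.DataFrame) -> pd.DataFrame:
    # pdf holds all rows for one group; compare each cost with the group's mean
    mean_cost = pdf["incident_cost"].mean()
    return pdf.assign(incident_cost=(pdf["incident_cost"] - mean_cost) / mean_cost * 100)

# Made-up toy data for illustration only
toy = pd.DataFrame({
    "animal_group": ["Dog", "Fox", "Dog", "Fox"],
    "incident_cost": [200.0, 255.0, 400.0, 255.0],
})

# Split into one DataFrame per group, apply the function to each, then combine
pieces = [percentage_difference_body(group) for _, group in toy.groupby("animal_group")]
result = pd.concat(pieces, ignore_index=True)
print(result)
```

In Spark 3.x the preferred way to run a grouped map function is `GroupedData.applyInPandas`, which takes the plain function and the output schema instead of a decorated UDF, e.g. `rescue.groupBy('animal_group').applyInPandas(percentage_difference_body, schema=rescue.schema)`.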

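In the decorator of a grouped map Pandas UDF, the first argument is the schema of the output dataframe and the second argument is the UDF type. Here is a grouped map UDF that replaces each incident's `incident_cost` with its percentage difference from the mean `incident_cost` of its animal group. Because only the values in an existing column change, the output schema is the same as `rescue.schema`.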

```python
@pandas_udf(rescue.schema, F.PandasUDFType.GROUPED_MAP)
def percentage_difference(pdf: pd.DataFrame) -> pd.DataFrame:
    # pdf holds every row of one animal group
    mean_cost = pdf.incident_cost.mean()
    # Replace incident_cost with its percentage difference from the group mean
    return pdf.assign(incident_cost=(pdf.incident_cost - mean_cost) / mean_cost * 100)

cost_by_group = rescue.groupBy('animal_group').apply(percentage_difference)
cost_by_group.limit(10).toPandas()
```



### Grouped Aggregate Pandas UDFs

Grouped aggregate Pandas UDFs take one or more `pandas.Series` as input and return a scalar.

Grouped aggregate Pandas UDFs are used with `.groupBy()` and `.agg()`, where they behave like Spark's built-in aggregate functions such as `F.mean()`.

In the decorator of a grouped aggregate Pandas UDF, the first argument is the data type of the value the UDF returns (here `T.DoubleType()`) and the second argument is the UDF type (`F.PandasUDFType.GROUPED_AGG`). Below is an example of a grouped aggregate Pandas UDF that finds the mean incident cost of each animal group.

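```python
import pandas as pd
import pyspark.sql.functions as F
import pyspark.sql.types as T
from pyspark.sql.functions import pandas_udf

@pandas_udf(T.DoubleType(), F.PandasUDFType.GROUPED_AGG)
def mean_incident_cost(incidentcost: pd.Series) -> float:
    # Receives every incident_cost value in one group and returns a single number
    return incidentcost.mean()

mean_cost_of_incident = (rescue.groupBy('animal_group')
      .agg(mean_incident_cost(rescue['incident_cost']).alias('mean_incident_cost'))
      .orderBy('mean_incident_cost', ascending=False))

mean_cost_of_incident.limit(10).toPandas()
```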

|   | animal_group | mean_incident_cost |
|---|---|---|
| 0 | Goat | 1180.0 |
| 1 | Bull | 780.0 |
| 2 | Fish | 780.0 |
| 3 | Horse | 747.435065 |
| 4 | Unknown - Animal rescue from water - Farm animal | 709.666667 |
| 5 | Cow | 624.166667 |
| 6 | Hedgehog | 520.0 |
| 7 | Lamb | 520.0 |
| 8 | Deer | 423.882979 |
| 9 | Unknown - Wild Animal | 390.036364 |
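Because the body of a grouped aggregate UDF is ordinary pandas code, one way to test its logic without starting Spark is to keep it in a plain function and run it through pandas' own `groupby().agg()`. The sketch below is an illustration under assumptions: `mean_cost` and `cost_per_duration` are hypothetical helper names, `cost_per_duration` is a made-up two-input aggregate included only to show the "one or more `pandas.Series`" case, and `toy` is a handful of hypothetical rows, not the rescue dataset, so its results will not match the table above.

```python
import pandas as pd


def mean_cost(incident_cost: pd.Series) -> float:
    """Series -> scalar: the logic a GROUPED_AGG pandas UDF would run for each group."""
    return float(incident_cost.mean())


def cost_per_duration(incident_cost: pd.Series, incident_duration: pd.Series) -> float:
    """Hypothetical two-input aggregate: total cost divided by total duration."""
    return float(incident_cost.sum() / incident_duration.sum())


# Hypothetical toy data standing in for the rescue dataframe
toy = pd.DataFrame({
    "animal_group": ["Goat", "Fish", "Fish", "Horse", "Horse"],
    "incident_duration": [2.0, 1.0, 2.5, 1.5, 1.0],
    "incident_cost": [1180.0, 260.0, 1300.0, 780.0, 260.0],
})

# pandas analogue of rescue.groupBy('animal_group').agg(mean_incident_cost(...))
mean_by_group = (toy.groupby("animal_group")["incident_cost"]
                    .agg(mean_cost)
                    .sort_values(ascending=False))

# Multi-input case: pass each group's columns to the function explicitly
rate_by_group = (toy.groupby("animal_group")[["incident_cost", "incident_duration"]]
                    .apply(lambda g: cost_per_duration(g["incident_cost"], g["incident_duration"])))

print(mean_by_group)
print(rate_by_group)
```

Once the plain function behaves as expected, it can be wrapped for Spark without the decorator syntax, for example `pandas_udf(T.DoubleType(), F.PandasUDFType.GROUPED_AGG)(mean_cost)`, so the tested code and the UDF share a single definition.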


Grouped aggregate Pandas UDFs can also be used with PySpark window functions, in which case each `pandas.Series` passed to the UDF holds the values of one window rather than one group. In Spark 2.4, grouped aggregate UDFs do not support partial aggregation, and only unbounded windows are supported. Here is an example:

```python
from pyspark.sql.window import Window

# Unbounded window over each animal_group: every row sees its whole group
w = (Window.partitionBy('animal_group')
           .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing))

# Attach each group's mean incident cost to every row in that group
rescue = (rescue.withColumn('mean_incident_cost', mean_incident_cost(rescue['incident_cost']).over(w))
                .orderBy('mean_incident_cost', ascending=False))

rescue.limit(10).toPandas()
```

|   | incident_number | animal_group | incident_duration | incident_cost | mean_incident_cost |
|---|---|---|---|---|---|
| 0 | 72214141 | Goat | 2.0 | 1180.0 | 1180.0 |
| 1 | 165500101 | Bull | 3.0 | 780.0 | 780.0 |
| 2 | 129048101 | Fish | 1.0 | 260.0 | 780.0 |
| 3 | 129261101 | Fish | 2.5 | 1300.0 | 780.0 |
| 4 | 102278091 | Horse | 1.5 | 780.0 | 747.435065 |
| 5 | 151984091 | Horse | 1.0 | 260.0 | 747.435065 |
| 6 | 123495091 | Horse | 1.0 | 260.0 | 747.435065 |
| 7 | 2872091 | Horse | 1.0 | 255.0 | 747.435065 |
| 8 | 137525091 | Horse | 5.0 | 1300.0 | 747.435065 |
| 9 | 87402091 | Horse | 2.0 | 520.0 | 747.435065 |
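With an unbounded window, the UDF still computes one scalar per partition, and Spark repeats that scalar on every row of the partition. A rough pandas analogue, useful for sanity-checking the expected column before running the Spark version, is `groupby().transform()`, which broadcasts each group's result back to the group's rows. This is a sketch under assumptions: `mean_cost` is a hypothetical helper and `toy` is hypothetical data, not the rescue dataset.

```python
import pandas as pd


def mean_cost(incident_cost: pd.Series) -> float:
    """Series -> scalar, the same logic as the grouped aggregate UDF."""
    return float(incident_cost.mean())


# Hypothetical toy data standing in for the rescue dataframe
toy = pd.DataFrame({
    "animal_group": ["Goat", "Fish", "Fish", "Horse", "Horse"],
    "incident_cost": [1180.0, 260.0, 1300.0, 780.0, 260.0],
})

# Analogue of mean_incident_cost(...).over(w) with an unbounded window
# partitioned by animal_group: each row receives its whole group's mean
toy["mean_incident_cost"] = (toy.groupby("animal_group")["incident_cost"]
                                .transform(mean_cost))
toy = toy.sort_values("mean_incident_cost", ascending=False)
print(toy)
```

Unlike the `.agg()` version, the number of rows is unchanged: both Fish rows keep their own `incident_cost` and share the same `mean_incident_cost`, mirroring the Spark output above.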


### Additional Resources

For a more in-depth explanation of what happens under the hood of the different types of Pandas UDFs, see:
- [Big Data is Just a Lot of Small Data: Using pandas UDF part 1](https://freecontent.manning.com/big-data-is-just-a-lot-of-small-data-using-pandas-udf/)
- [Big Data is Just a Lot of Small Data: Using pandas UDF part 2](https://freecontent.manning.com/big-data-is-just-a-lot-of-small-data-using-pandas-udf-part-2/)