# PYSPARK PRACTICE

## Introduction

We'll be working with high-level operations, so we need to create a SparkSession (a unified entry point for Spark that contains SparkContext, StreamingContext and SQLContext). If we wanted to work directly with RDDs, using a SparkContext would be more appropriate.

// TODO I'd like to test SparkContext at some point

In [1]:
from pyspark.sql import SparkSession

spark = (
   SparkSession.builder.appName('Spark Demo') 
  .master('local[*]') #local[*] - Run Spark locally with as many worker threads as logical cores on your machine. I could choose any number really
                    # YARN when deploying to a cluster with YARN
  .config("spark.sql.execution.arrow.maxRecordsPerBatch", "100")  # For testing purposes
  .getOrCreate()
)
spark

25/04/19 15:21:40 WARN Utils: Your hostname, alejandro resolves to a loopback address: 127.0.1.1; using 192.168.3.37 instead (on interface wlp2s0)
25/04/19 15:21:40 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/04/19 15:21:41 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
25/04/19 15:21:42 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.


I won't be using it here, but I should give an honorable mention to spark-submit, the script used to deploy jobs in a Spark cluster:

```bash
# --py-files: these files will be passed to the worker nodes
# --archives: we could package the whole conda environment if we wanted to. Or a virtualenv
./bin/spark-submit \
  --class <main-class> \
  --master <master-url> \
  --deploy-mode <deploy-mode> \
  --conf <key>=<value> \
  --py-files file1.py,file2.py,file3.zip,file4.egg \
  --archives conda_env.tar.gz \
  ... # other options
  <application-jar> \
  [application-arguments]
```

Before we continue, some imports:

In [37]:
import time

from pyspark.sql import functions as F
from pyspark.sql import types as T
from pyspark.sql.functions import lit
from typing import Iterator, Tuple

import numpy as np
import pandas as pd

## Explore the data - Introduction

Let's just load a parquet file.

In [3]:
data_path_green = './data/green_tripdata_2025-01.parquet'
data_path_yellow = './data/yellow_tripdata_2025-01.parquet'
# Parquet files embed their own schema, so CSV-style options such as header/inferSchema are not needed
df_green = spark.read.parquet(data_path_green) # In local mode, it reads the file from my local file system
df_yellow = spark.read.parquet(data_path_yellow)
display(df_green.printSchema()) # display() is part of Jupyter. printSchema() prints the schema and returns None, hence the "None" below
display(df_yellow.printSchema())

                                                                                

root
 |-- VendorID: integer (nullable = true)
 |-- lpep_pickup_datetime: timestamp_ntz (nullable = true)
 |-- lpep_dropoff_datetime: timestamp_ntz (nullable = true)
 |-- store_and_fwd_flag: string (nullable = true)
 |-- RatecodeID: long (nullable = true)
 |-- PULocationID: integer (nullable = true)
 |-- DOLocationID: integer (nullable = true)
 |-- passenger_count: long (nullable = true)
 |-- trip_distance: double (nullable = true)
 |-- fare_amount: double (nullable = true)
 |-- extra: double (nullable = true)
 |-- mta_tax: double (nullable = true)
 |-- tip_amount: double (nullable = true)
 |-- tolls_amount: double (nullable = true)
 |-- ehail_fee: double (nullable = true)
 |-- improvement_surcharge: double (nullable = true)
 |-- total_amount: double (nullable = true)
 |-- payment_type: long (nullable = true)
 |-- trip_type: long (nullable = true)
 |-- congestion_surcharge: double (nullable = true)



None

root
 |-- VendorID: integer (nullable = true)
 |-- tpep_pickup_datetime: timestamp_ntz (nullable = true)
 |-- tpep_dropoff_datetime: timestamp_ntz (nullable = true)
 |-- passenger_count: long (nullable = true)
 |-- trip_distance: double (nullable = true)
 |-- RatecodeID: long (nullable = true)
 |-- store_and_fwd_flag: string (nullable = true)
 |-- PULocationID: integer (nullable = true)
 |-- DOLocationID: integer (nullable = true)
 |-- payment_type: long (nullable = true)
 |-- fare_amount: double (nullable = true)
 |-- extra: double (nullable = true)
 |-- mta_tax: double (nullable = true)
 |-- tip_amount: double (nullable = true)
 |-- tolls_amount: double (nullable = true)
 |-- improvement_surcharge: double (nullable = true)
 |-- total_amount: double (nullable = true)
 |-- congestion_surcharge: double (nullable = true)
 |-- Airport_fee: double (nullable = true)



None

In [4]:
df_green.show(1)

+--------+--------------------+---------------------+------------------+----------+------------+------------+---------------+-------------+-----------+-----+-------+----------+------------+---------+---------------------+------------+------------+---------+--------------------+
|VendorID|lpep_pickup_datetime|lpep_dropoff_datetime|store_and_fwd_flag|RatecodeID|PULocationID|DOLocationID|passenger_count|trip_distance|fare_amount|extra|mta_tax|tip_amount|tolls_amount|ehail_fee|improvement_surcharge|total_amount|payment_type|trip_type|congestion_surcharge|
+--------+--------------------+---------------------+------------------+----------+------------+------------+---------------+-------------+-----------+-----+-------+----------+------------+---------+---------------------+------------+------------+---------+--------------------+
|       2| 2025-01-01 00:03:01|  2025-01-01 00:17:12|                 N|         1|          75|         235|              1|         5.93|       24.7|  1.0|    0.

What if I want to manually create a DF? Let's do it here and use it later

In [5]:
data = [
    (1, "Luis Mendoza", "Mexico"),
    (2, "Amina Yusuf", "Nigeria"),
    (3, "Chen Wei", "China"),
    (4, "Elena Petrova", "Russia"),
    (5, "David O'Connor", "Ireland"),
    (6, "Raj Patel", "India"),
    (7, "Fatima Al-Sayed", "Egypt"),
    (8, "John Smith", "USA"),
    (9, "Mateo García", "Spain"),
    (10, "Anna Müller", "Germany")
]

# Manually define schema
schema = T.StructType([
    T.StructField("id", T.IntegerType(), False),
    T.StructField("name", T.StringType(), False),
    T.StructField("country", T.StringType(), False)
])

# Create DataFrame
df_drivers = spark.createDataFrame(data, schema)
df_drivers.show()


                                                                                

+---+---------------+-------+
| id|           name|country|
+---+---------------+-------+
|  1|   Luis Mendoza| Mexico|
|  2|    Amina Yusuf|Nigeria|
|  3|       Chen Wei|  China|
|  4|  Elena Petrova| Russia|
|  5| David O'Connor|Ireland|
|  6|      Raj Patel|  India|
|  7|Fatima Al-Sayed|  Egypt|
|  8|     John Smith|    USA|
|  9|   Mateo García|  Spain|
| 10|    Anna Müller|Germany|
+---+---------------+-------+



Three important terms to remember here:
* *StructType*: A built-in DataType from org.apache.spark.sql.types that implements scala.collection.Seq<StructField>. Basically, it is a `Seq` of `StructField` that describes the schema.
* *StructField*: The name, data type and nullability flag of a column. You could see it as metadata of whatever you'll find in that position of a row.
* *Row*: The values of one record, one value per `StructField`.

Let's do something a little bit more complicated. Let us show columns in groups of 4

In [6]:
def show_split(df, split=-1, n_samples=10):
    n_cols = len(df.columns)
    if split <= 0:
        split = n_cols
    i = 0
    j = i + split
    while i < n_cols:
        df.select(*df.columns[i:j]).show(n_samples) # That * operator will unpack the columns from [c1, c2, c3] to c1, c2, c3
        i = j
        j = i + split
        
show_split(df_green, 4, 2)

+--------+--------------------+---------------------+------------------+
|VendorID|lpep_pickup_datetime|lpep_dropoff_datetime|store_and_fwd_flag|
+--------+--------------------+---------------------+------------------+
|       2| 2025-01-01 00:03:01|  2025-01-01 00:17:12|                 N|
|       2| 2025-01-01 00:19:59|  2025-01-01 00:25:52|                 N|
+--------+--------------------+---------------------+------------------+
only showing top 2 rows

+----------+------------+------------+---------------+
|RatecodeID|PULocationID|DOLocationID|passenger_count|
+----------+------------+------------+---------------+
|         1|          75|         235|              1|
|         1|         166|          75|              1|
+----------+------------+------------+---------------+
only showing top 2 rows

+-------------+-----------+-----+-------+
|trip_distance|fare_amount|extra|mta_tax|
+-------------+-----------+-----+-------+
|         5.93|       24.7|  1.0|    0.5|
|         1.32

We can gather quick information about certain columns by typing a simple line.

In [7]:
df_green.describe('tip_amount', 'tolls_amount').show()

+-------+------------------+-------------------+
|summary|        tip_amount|       tolls_amount|
+-------+------------------+-------------------+
|  count|             48326|              48326|
|   mean|2.4818590406821954|0.17746099408185748|
| stddev| 3.213612255803839| 1.1929839044404538|
|    min|              -0.9|                0.0|
|    max|            252.05|              48.76|
+-------+------------------+-------------------+



We can be a little more specific with the values we want

In [8]:
# summary() takes the statistics to compute as params
df_green.select('fare_amount', 'extra').summary('count', 'min', 'max', 'mean', '50%').show()

+-------+------------------+------------------+
|summary|       fare_amount|             extra|
+-------+------------------+------------------+
|  count|             48326|             48326|
|    min|            -113.0|              -5.0|
|    max|             336.2|               7.5|
|   mean|16.762465546496703|0.9323242147084385|
|    50%|              13.5|               0.0|
+-------+------------------+------------------+



As of today (04/2025), Python does not have access to the Dataset class; it can only use DataFrames. But what is the difference between these two? And what makes them different from an RDD?
* At the core, an RDD is an immutable distributed collection of elements of your data, partitioned across nodes in your cluster, that can be operated on in parallel with a low-level API that offers transformations and actions.
  * They have no schema.
* Like an RDD, a DataFrame is an immutable distributed collection of data. Unlike an RDD, data is organized into named columns, allowing higher-level abstraction. We could see these structures as Dataset[Row].
* Datasets are just strongly-typed DataFrames.

You cannot overlook the space efficiency and performance gains in using DataFrames and Dataset APIs for two reasons.
* First, because DataFrame and Dataset APIs are built on top of the Spark SQL engine, it uses Catalyst to generate an optimized logical and physical query plan. All relation type queries undergo the same code optimizer, providing the space and speed efficiency. Whereas the Dataset[T] typed API is optimized for data engineering tasks, the untyped Dataset[Row] (an alias of DataFrame) is even faster and suitable for interactive analysis.
* Second, since Spark as a compiler understands your Dataset type JVM object, it maps your type-specific JVM object to Tungsten's internal memory representation using Encoders. As a result, Tungsten Encoders can efficiently serialize/deserialize JVM objects as well as generate compact bytecode that can execute at superior speeds.
  * Tungsten is the codename for the umbrella project to make changes to Apache Spark’s execution engine that focuses on substantially improving the efficiency of memory and CPU for Spark applications

From: https://www.databricks.com/blog/2016/07/14/a-tale-of-three-apache-spark-apis-rdds-dataframes-and-datasets.html

Let's keep messing with our datasets. Now I want to add (and remove) a column.

In [9]:
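df_green.withColumn("tmp_column", lit(10)) # lit() means "literal": it tells Spark to put that literal value in every record.
                                            # withColumn() expects a Column, so passing a bare 10 would fail.
                                            # DataFrames are immutable: withColumn() returns a new DataFrame, which is discarded here
df_green.drop('tmp_column')                 # drop() also returns a new DataFrame; dropping a column that does not exist is a no-op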

DataFrame[VendorID: int, lpep_pickup_datetime: timestamp_ntz, lpep_dropoff_datetime: timestamp_ntz, store_and_fwd_flag: string, RatecodeID: bigint, PULocationID: int, DOLocationID: int, passenger_count: bigint, trip_distance: double, fare_amount: double, extra: double, mta_tax: double, tip_amount: double, tolls_amount: double, ehail_fee: double, improvement_surcharge: double, total_amount: double, payment_type: bigint, trip_type: bigint, congestion_surcharge: double]

And now I want to change the value of a column

In [10]:
# df_green.withColumn("passenger_count", lit(2)) # The name of the column already exists, so no need to add new column

## Explore the data in SQL-like fashion

We can use PySpark-like syntax. In the example below, I access columns in three different ways. All of them work, although the third one (`F.col`) does require an import.
What is the difference between them? https://stackoverflow.com/questions/55105363/pyspark-dataframe-column-reference-df-col-vs-dfcol-vs-f-colcol

In [11]:
df_green.where(df_green['total_amount']>=100).select(df_green.total_amount, F.col('trip_type')).show(10)

+------------+---------+
|total_amount|trip_type|
+------------+---------+
|       108.9|        2|
|       101.0|        2|
|       222.0|        2|
|      164.38|        2|
|       124.5|        2|
|       120.0|        2|
|       114.0|        1|
|      121.28|        2|
|       110.9|        1|
|       103.2|        2|
+------------+---------+
only showing top 10 rows



We can also submit an SQL query directly.

`createOrReplaceTempView()` creates (or replaces if that view name already exists) a lazily evaluated "view". We need it to run queries this way. It does not persist to memory unless you cache the dataset that underpins the view. 

In [12]:
df_green.createOrReplaceTempView('df') # Register the dataframe as a view. This operation is lazy, won't materialize until we need it.
                                        # It does not cache unless otherwise specified.
spark.sql('''
    SELECT total_amount, trip_type FROM df
    WHERE total_amount >= 100 
''').show(10)



+------------+---------+
|total_amount|trip_type|
+------------+---------+
|       108.9|        2|
|       101.0|        2|
|       222.0|        2|
|      164.38|        2|
|       124.5|        2|
|       120.0|        2|
|       114.0|        1|
|      121.28|        2|
|       110.9|        1|
|       103.2|        2|
+------------+---------+
only showing top 10 rows



If I modify the underlying df, would the view be affected by it? Let's have a look!

In [13]:
df_green.withColumn("tmp_column", lit(10))

#spark.sql('''
#    SELECT total_amount, trip_type, tmp_column FROM df
#    WHERE total_amount >= 100 
#''').show(10)

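## It fails because the view does not have the new column: withColumn() returned a new DataFrame
## (discarded above), and the view still points to the original df_green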

df_green.drop('tmp_column')

DataFrame[VendorID: int, lpep_pickup_datetime: timestamp_ntz, lpep_dropoff_datetime: timestamp_ntz, store_and_fwd_flag: string, RatecodeID: bigint, PULocationID: int, DOLocationID: int, passenger_count: bigint, trip_distance: double, fare_amount: double, extra: double, mta_tax: double, tip_amount: double, tolls_amount: double, ehail_fee: double, improvement_surcharge: double, total_amount: double, payment_type: bigint, trip_type: bigint, congestion_surcharge: double]
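
The experiment above can be modelled without Spark. The following is only a toy sketch in plain Python, assuming that a temp view is just a name bound to one immutable DataFrame definition. The `Frame` class and the `views` dict are hypothetical stand-ins, not PySpark APIs:

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Frame:
    """Hypothetical stand-in for an immutable DataFrame: it only tracks column names."""
    columns: tuple

    def with_column(self, name: str) -> "Frame":
        # Like DataFrame.withColumn(): returns a NEW object, the original is untouched
        return Frame(self.columns + (name,))


views = {}  # Hypothetical stand-in for the session's temp-view catalog

df = Frame(("total_amount", "trip_type"))
views["df"] = df                      # ~ df.createOrReplaceTempView('df')

df.with_column("tmp_column")          # Result discarded: neither df nor the view changes
view_after_discard = views["df"].columns

df_new = df.with_column("tmp_column")
views["df"] = df_new                  # ~ df_new.createOrReplaceTempView('df')
view_after_replace = views["df"].columns

print(view_after_discard)
print(view_after_replace)
```

In PySpark terms, the query would only see `tmp_column` if the DataFrame returned by `withColumn()` is kept and registered again under the same view name.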

## Some more practice with more "complex" queries

Groups and filtering

In [14]:
# Out of the first 50 vendors, which ones have a total amount bigger than a given number?
(  
  df_green.where(df_green['VendorID']<=50) 
  .groupBy('VendorID') 
  .agg(F.sum('total_amount').alias('sumAmount')) # agg() defines how the values of each group are combined. Here it is a sum,
                                            # but it could be a count, an avg, min, max, variance, etc.
  .where(F.col('sumAmount') > 300000)   # This second "where" acts like a HAVING (a WHERE applied to the aggregated values)
  .show(10)
)

+--------+-----------------+
|VendorID|        sumAmount|
+--------+-----------------+
|       2|954358.4599999995|
+--------+-----------------+
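
To make the two `where` calls easier to tell apart, here is a plain-Python sketch of the same pipeline (row filter, group, sum, group filter). The rows and the threshold are made up for illustration, and Spark is not involved:

```python
from collections import defaultdict

# Made-up (VendorID, total_amount) rows, for illustration only
rows = [(1, 10.0), (1, 15.0), (2, 40.0), (2, 70.0), (99, 500.0)]

# First where(): filters individual rows BEFORE grouping (SQL WHERE)
filtered = [(vendor, amount) for vendor, amount in rows if vendor <= 50]

# groupBy('VendorID').agg(F.sum(...)): one accumulated value per group
sums = defaultdict(float)
for vendor, amount in filtered:
    sums[vendor] += amount

# Second where(): filters whole groups AFTER aggregating (SQL HAVING)
threshold = 100.0
result = {vendor: total for vendor, total in sums.items() if total > threshold}

print(result)  # {2: 110.0}
```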



Table unions (They work as a UNION ALL, with duplicates, unless stated otherwise)

In [15]:
df_split = df_green.randomSplit([0.5, 0.5])[0]
display(df_green.count())
display(df_split.count())  
display(df_green.union(df_split).count()) # Duplicates are still there!
display(df_green.union(df_split).distinct().count())

48326

24147

72473

48326
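
One subtlety worth keeping in mind: `distinct()` runs on the whole result, so it would also drop duplicates that were already present in either input, not only the ones introduced by the union. A plain-Python sketch with made-up rows, where lists stand in for DataFrames:

```python
# Made-up rows; note that `left` already contains a duplicate
left = [("a", 1), ("a", 1), ("b", 2)]
right = [("b", 2), ("c", 3)]

union_all = left + right          # ~ df.union(other): keeps every row (UNION ALL)
union_distinct = set(union_all)   # ~ df.union(other).distinct(): SQL UNION

print(len(union_all))       # 5
print(len(union_distinct))  # 3
```

In the cell above, the distinct count matches `df_green.count()`, which suggests that `df_green` had no fully duplicated rows to begin with.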

Table Joins

In [16]:
# Maybe I need a better example for this
(
df_green.where(df_green["VendorID"]<=10)
.join(df_drivers, df_green.VendorID == df_drivers.id, "inner")
.select(F.col('VendorID'), F.col('total_amount'), F.col('name'), F.col('country'))
.show(10)
)


+--------+------------+------------+-------+
|VendorID|total_amount|        name|country|
+--------+------------+------------+-------+
|       1|        57.2|Luis Mendoza| Mexico|
|       1|       13.49|Luis Mendoza| Mexico|
|       1|       16.95|Luis Mendoza| Mexico|
|       1|       59.57|Luis Mendoza| Mexico|
|       1|       17.65|Luis Mendoza| Mexico|
|       1|        18.8|Luis Mendoza| Mexico|
|       1|       90.21|Luis Mendoza| Mexico|
|       1|       17.94|Luis Mendoza| Mexico|
|       1|        29.7|Luis Mendoza| Mexico|
|       1|       16.12|Luis Mendoza| Mexico|
+--------+------------+------------+-------+
only showing top 10 rows



## UDFs

What are UDFs? User-Defined Functions (UDFs) are a feature of Spark SQL that lets you define new Column-based functions, extending the vocabulary of Spark SQL's DSL for transforming Datasets.

In [17]:
df_green.select(df_green['payment_type']).distinct().show()

+------------+
|payment_type|
+------------+
|           5|
|           1|
|           3|
|           2|
|           4|
|        NULL|
+------------+



In [18]:
def give_payment_name(payment_type:int) -> str:
    if payment_type == 1:
        return "Cash"
    if payment_type == 2:
        return "National_Card"
    if payment_type == 3:
        return "International_Card"
    if payment_type == 4:
        return "Bizum"
    if payment_type == 5:
        return "Begged_for_mercy"
    return "Null?"

# Let's create a UDF
givePaymentName = F.udf(give_payment_name, T.StringType()) # Important!! Apparently, DataTypes are not singletons, so we have to "create" one by
                                                    # calling the constructor i.e. adding those "()"

And what if we do not want to wrap the function manually? There is a decorator for it as well.

In [19]:
@F.udf(T.StringType())
def givePaymentName(payment_type:int) -> str:
    if payment_type == 1:
        return "Cash"
    if payment_type == 2:
        return "National_Card"
    if payment_type == 3:
        return "International_Card"
    if payment_type == 4:
        return "Bizum"
    if payment_type == 5:
        return "Begged_for_mercy"
    return "Null?"

In [20]:
df_green.select('VendorID', givePaymentName(df_green.payment_type)).show(10)


+--------+-----------------------------+
|VendorID|givePaymentName(payment_type)|
+--------+-----------------------------+
|       2|                         Cash|
|       2|                National_Card|
|       2|                National_Card|
|       2|                         Cash|
|       2|                         Cash|
|       2|                National_Card|
|       2|                         Cash|
|       2|                National_Card|
|       2|                         Cash|
|       2|                National_Card|
+--------+-----------------------------+
only showing top 10 rows



                                                                                

What if we want to use this UDF with the SQL syntax? 

In [21]:
spark.udf.register('givePaymentName', give_payment_name, T.StringType()) # We need to register the function with this one
                                                                            # The F.udf used above won't cut it 
spark.sql('''
    SELECT VendorID, givePaymentName(payment_type) PaymentTypeString
    FROM df
''').show(10)



+--------+-----------------+
|VendorID|PaymentTypeString|
+--------+-----------------+
|       2|             Cash|
|       2|    National_Card|
|       2|    National_Card|
|       2|             Cash|
|       2|             Cash|
|       2|    National_Card|
|       2|             Cash|
|       2|    National_Card|
|       2|             Cash|
|       2|    National_Card|
+--------+-----------------+
only showing top 10 rows



### Pandas UDF vs Python UDF


* In a Python UDF, when you pass column objects to your UDF, PySpark will unpack each value/row (convert the data between the Python environment and the JVM, basically a serialization), perform the computation, and then return the value for each record.
* In a pandas (Series to Series) UDF, PySpark will serialize (through Apache Arrow, via the PyArrow library) each column in batches into pandas Series objects. You then perform the operations on the Series object directly (avoiding the overhead of unpacking each row individually), returning a Series of the same length from your UDF.

From an end user perspective, they are the same functionally. Because pandas is optimized for rapid data manipulation, it is preferable to use a Series to Series UDF when you can instead of using a regular Python UDF, as it’ll be much faster.

Of course, this is only a problem for PySpark. If we were to use Spark and Scala/Java, that would be no problem.

From: https://gist.github.com/ThaiDat/81c3662801aa8410a65b94f3c993c377
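
A rough way to picture the difference is to count how many times user code is invoked. The sketch below is plain Python, with lists standing in for `pd.Series` batches and an arbitrary batch size. It is not how PySpark is implemented, only an illustration of per-row versus per-batch calls:

```python
PAYMENT_NAMES = {1: "Cash", 2: "National_Card"}
values = [1, 2, 2, 1, 1, 2, None, 1]
calls = {"row": 0, "batch": 0}


def row_udf(value):
    # Python UDF style: invoked once per record
    calls["row"] += 1
    return PAYMENT_NAMES.get(value, "Null?")


def batch_udf(batch):
    # pandas UDF style: invoked once per batch, works on the whole batch at once
    calls["batch"] += 1
    return [PAYMENT_NAMES.get(value, "Null?") for value in batch]


batch_size = 4  # Arbitrary value for this sketch
row_result = [row_udf(v) for v in values]
batch_result = []
for start in range(0, len(values), batch_size):
    batch_result.extend(batch_udf(values[start:start + batch_size]))

print(calls)  # {'row': 8, 'batch': 2}
```

Both variants produce the same values; only the number of round trips into user code changes.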

In [22]:
@F.pandas_udf(T.StringType())
def giveTripTypeName(trip_type: pd.Series) -> pd.Series:
    return trip_type.map({
        1: "Normal",
        2: "Premium"
    }).fillna("Null?")

# We can also promote a plain function to a pandas UDF explicitly, with F.pandas_udf(function, T.StringType()), as done below
def give_payment_name(payment_type:pd.Series) -> pd.Series:   # We are overwriting the method from above
    return payment_type.map({
        1: "Cash",
        2: "National_Card",
        3: "International_Card",
        4: "Bizum",
        5: "Begged_for_mercy"
    }).fillna("Null?")


givePaymentName = F.pandas_udf(give_payment_name, T.StringType())

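We cannot use these UDFs with the SQL syntax as they are; they would need to be registered first with `spark.udf.register()`, which also accepts pandas UDF objects.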

In [23]:
df_green.select('VendorID', givePaymentName(df_green.payment_type).alias('PaymentTypeName'), giveTripTypeName(df_green.trip_type).alias("TripTypeClass")).show(10)

+--------+---------------+-------------+
|VendorID|PaymentTypeName|TripTypeClass|
+--------+---------------+-------------+
|       2|           Cash|       Normal|
|       2|  National_Card|       Normal|
|       2|  National_Card|      Premium|
|       2|           Cash|       Normal|
|       2|           Cash|       Normal|
|       2|  National_Card|       Normal|
|       2|           Cash|       Normal|
|       2|  National_Card|       Normal|
|       2|           Cash|       Normal|
|       2|  National_Card|       Normal|
+--------+---------------+-------------+
only showing top 10 rows



TODO: Is there a UDF that could not be transformed into a pandas UDF?

Now, what if, instead of passing one column to our UDF, we want to pass it two or three? We could just pass several `pd.Series` as parameters.

In [24]:
@F.pandas_udf(returnType=T.FloatType())
def tip_percent_of_fare_1(tip_amount: pd.Series, total_amount: pd.Series) -> pd.Series:
    return (tip_amount / total_amount) * 100

display(df_green.select(F.col('VendorId'), tip_percent_of_fare_1(F.col('tip_amount'), F.col('total_amount')).alias('TipPercentage')).show())

+--------+-------------+
|VendorId|TipPercentage|
+--------+-------------+
|       2|         20.0|
|       2|          0.0|
|       2|          0.0|
|       2|      16.6712|
|       2|    16.666666|
|       2|          0.0|
|       2|    61.281338|
|       2|          0.0|
|       2|    16.666666|
|       2|          0.0|
|       2|    19.985518|
|       2|     5.263158|
|       1|          0.0|
|       1|          0.0|
|       2|    5.9620595|
|       2|    16.666666|
|       2|    16.666666|
|       2|    16.666666|
|       2|    6.6666665|
|       2|          0.0|
+--------+-------------+
only showing top 20 rows



None

All previous UDFs are processed in batches. What if there is an operation we do not want to repeat over and over for every batch? What if we only want to do it once and reuse the result of said operation in every batch (loading an ML model, for example)? In this case, instead of defining the input and output of our UDF as a pd.Series, we could define it as an iterator of batches of pd.Series.

In [39]:
@F.pandas_udf(T.StringType())
def give_payment_name(batch_iter: Iterator[pd.Series]) -> Iterator[pd.Series]:   # We are overwriting the method from above
    # Some very expensive operation: it runs once per iterator, not once per batch
    time.sleep(10)
    
    try:
        for x in batch_iter:
            yield x.map({
                1: "Cash",
                2: "National_Card",
                3: "International_Card",
                4: "Bizum",
                5: "Begged_for_mercy"
            }).fillna("Null?")
    finally:
        pass  # release resources here, if any

df_green.select('VendorID', give_payment_name(df_green.payment_type).alias('PaymentTypeName'), giveTripTypeName(df_green.trip_type).alias("TripTypeClass")).show(10)


+--------+---------------+-------------+
|VendorID|PaymentTypeName|TripTypeClass|
+--------+---------------+-------------+
|       2|           Cash|       Normal|
|       2|  National_Card|       Normal|
|       2|  National_Card|      Premium|
|       2|           Cash|       Normal|
|       2|           Cash|       Normal|
|       2|  National_Card|       Normal|
|       2|           Cash|       Normal|
|       2|  National_Card|       Normal|
|       2|           Cash|       Normal|
|       2|  National_Card|       Normal|
+--------+---------------+-------------+
only showing top 10 rows
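
The key property of the iterator variant is that the code before the loop runs once per iterator, not once per batch. A plain-Python sketch of that pattern, where lists stand in for `pd.Series` batches and `setup_calls` is a hypothetical counter replacing the expensive operation:

```python
from typing import Dict, Iterator, List

setup_calls: Dict[str, int] = {"count": 0}  # Hypothetical counter for the expensive step


def payment_names(batch_iter: Iterator[List[int]]) -> Iterator[List[str]]:
    setup_calls["count"] += 1             # One-time setup (e.g. loading a model)
    names = {1: "Cash", 2: "National_Card"}
    for batch in batch_iter:              # Per-batch work reuses the setup result
        yield [names.get(value, "Null?") for value in batch]


batches = iter([[1, 2], [2, 1], [1, 5]])
results = list(payment_names(batches))

print(results)
print(setup_calls["count"])  # 1
```

As far as I understand, in Spark the function receives one iterator per task (roughly, per partition), so the setup would run once per task rather than once for the whole DataFrame.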



                                                                                

And what if we need to use this solution but use several columns as input? No problem either:

In [40]:
@F.pandas_udf(returnType=T.FloatType())
def tip_percent_of_fare(values: Iterator[Tuple[pd.Series, pd.Series]]) -> Iterator[pd.Series]:
    # Some very expensive operation
    time.sleep(10)

    for tip, total in values:
        yield (tip / total) * 100

display(df_green.select(F.col('VendorId'), tip_percent_of_fare(F.col('tip_amount'), F.col('total_amount')).alias('TipPercentage')).show())



+--------+-------------+
|VendorId|TipPercentage|
+--------+-------------+
|       2|         20.0|
|       2|          0.0|
|       2|          0.0|
|       2|      16.6712|
|       2|    16.666666|
|       2|          0.0|
|       2|    61.281338|
|       2|          0.0|
|       2|    16.666666|
|       2|          0.0|
|       2|    19.985518|
|       2|     5.263158|
|       1|          0.0|
|       1|          0.0|
|       2|    5.9620595|
|       2|    16.666666|
|       2|    16.666666|
|       2|    16.666666|
|       2|    6.6666665|
|       2|          0.0|
+--------+-------------+
only showing top 20 rows



                                                                                

None

All the operations above return a Series of the same length as the input. What if we want something different? We can also do that.

We can return only one value per group

In [45]:
@F.pandas_udf(T.DoubleType())
def GetStdDeviation(series: pd.Series) -> float:
    return series.std()

(
    df_green.groupBy('VendorID')
    .agg(
        GetStdDeviation(df_green.total_amount).alias('var')
    )
    .orderBy('var', ascending=False)
    .show(10)
)


+--------+------------------+
|VendorID|               var|
+--------+------------------+
|       2|15.678956050473168|
|       1|13.682413066267353|
+--------+------------------+
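
As a sanity check of what the Series to scalar UDF computes, the same per-group reduction can be sketched with the standard library. The rows are made up; `statistics.stdev` is the sample standard deviation, which is also what `pd.Series.std()` computes by default:

```python
import statistics
from collections import defaultdict

# Made-up (VendorID, total_amount) rows, for illustration only
rows = [(1, 10.0), (1, 14.0), (2, 5.0), (2, 5.0), (2, 8.0)]

groups = defaultdict(list)
for vendor, amount in rows:
    groups[vendor].append(amount)     # ~ groupBy('VendorID')

# Many rows in, one value out per group
std_per_vendor = {vendor: statistics.stdev(amounts) for vendor, amounts in groups.items()}

print(std_per_vendor)
```

Note that the `var` column in the cell above actually holds a standard deviation, not a variance. For this particular statistic Spark also ships a built-in, `F.stddev`, which avoids the Python round trip; a pandas UDF is mainly useful for aggregations without a built-in equivalent.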



Or we can return a completely new struct!

In [55]:
def getVar(data: pd.DataFrame) -> pd.DataFrame:
    return pd.DataFrame({
        'VendorID': [data['VendorID'].iloc[0]],  # Because of the previous groupBy(), each group will only have one VendorID
        'var': [data['total_amount'].std()]
    })

# We can also use the SQL (DDL) version of the schema:
# schema = 'VendorID int, var float'
schema = T.StructType([
    T.StructField('VendorID', T.IntegerType()),
    T.StructField('var', T.FloatType())
])

(
    df_green.groupBy('VendorID')
    .applyInPandas(getVar, schema)   # When using this function, we do not need to define the method as a UDF.
                                                # IMPORTANT! Apparently, this is the recommended way in Spark 3, and the older method
                                                # (a grouped-map pandas UDF) is deprecated.
                                                # It can only be used when dealing with pd.DataFrame as input and output
    .show(10)
)

+--------+---------+
|VendorID|      var|
+--------+---------+
|       1|13.682413|
|       2|15.678956|
+--------+---------+



Assuming this last case (the number of records of the output does not need to be the same as the number of records of the input), the function you need to use depends on the output:

| Input | Output | Function |
| --- | --- | --- |
| One row | One row | Pandas UDF |
| Many rows | One row | `applyInPandas()` or `groupBy().agg()` |
| One row | Many rows | `mapInPandas()` or `explode` |

In this table, "row" means "record".
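
For the last row of the table, the function passed to `mapInPandas()` receives an iterator of batches and may yield any number of rows per batch. A plain-Python sketch of that contract, with lists of dicts standing in for `pd.DataFrame` batches and a made-up list-valued `stops` column:

```python
from typing import Dict, Iterator, List

Batch = List[Dict[str, object]]  # Stand-in for a pd.DataFrame batch


def explode_stops(batches: Iterator[Batch]) -> Iterator[Batch]:
    # One input record can produce zero, one or many output records
    for batch in batches:
        yield [
            {"trip_id": record["trip_id"], "stop": stop}
            for record in batch
            for stop in record["stops"]
        ]


# Made-up input: `stops` is a hypothetical list-valued column
input_batches = iter([
    [{"trip_id": 1, "stops": ["A", "B"]}, {"trip_id": 2, "stops": []}],
    [{"trip_id": 3, "stops": ["C"]}],
])
output = [record for batch in explode_stops(input_batches) for record in batch]

print(output)
```

Trip 2 has no stops, so it disappears from the output, while trip 1 becomes two records.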

In [None]:
## Windows