In [0]:
import pandas as pd
from pyspark.sql.window import Window
from pyspark.sql.functions import pandas_udf, col, lit, sum
from pyspark.sql.types import IntegerType, FloatType
from statistics import mode

In [0]:
# create a simple dataframe
data = {
  'strings': ['one', 'one', 'one', 'one', 'two', 'two', 'two', 'two', 'three', 'three', 'three', 'three', ],
  'integers': [1, 1, 2, 3, 4, 5, 5, 5, 6, 7, 8, 9, ],
}

# this workbook is designed to run on databricks so the spark context is already created as spark
df = spark.createDataFrame(pd.DataFrame(data = data))
display(df)

strings,integers
one,1
one,1
one,2
one,3
two,4
two,5
two,5
two,5
three,6
three,7


In [0]:
# a simple example of a pandas UDF

# define the return type as a parameter within the decorator
@pandas_udf(IntegerType())
# pass and return a pd.Series in the decorated function
def string_length(strings: pd.Series) -> pd.Series:
  # note use of pandas string functions
  return strings.str.len()

df = df.withColumn(
  'length', string_length('strings')
)
display(df)

strings,integers,length
one,1,3
one,1,3
one,2,3
one,3,3
two,4,3
two,5,3
two,5,3
two,5,3
three,6,5
three,7,5


In [0]:
# the return type can also be defined as a string though this is less clear
@pandas_udf('integer')
# pass and return a pd.Series in the decorated function
def string_length(strings: pd.Series) -> pd.Series:
  # note use of pandas string functions
  return strings.str.len()

df = df.withColumn(
  'length', string_length('strings')
)
display(df)

strings,integers,length
one,1,3
one,1,3
one,2,3
one,3,3
two,4,3
two,5,3
two,5,3
two,5,3
three,6,5
three,7,5


In [0]:
# alternatively we can create the Pandas UDF without using a decorator

# a simple example of a pandas UDF
def string_length(strings: pd.Series) -> pd.Series:
  # note use of pandas string functions
  return strings.str.len()

# construct our pandas udf by 
udf_string_length = pandas_udf(string_length, IntegerType())

df = df.withColumn(
  'length', udf_string_length('strings')
)
display(df)

strings,integers,length
one,1,3
one,1,3
one,2,3
one,3,3
two,4,3
two,5,3
two,5,3
two,5,3
three,6,5
three,7,5


In [0]:
# for functions to be used in group by functions note that the output
# changes from a series to a python scalar
# define the return type as a parameter within the decorator
@pandas_udf(FloatType())
# pass a pd.Series and return a python type in the decorated function
def pd_mean(values: pd.Series) -> float:
  return values.mean()

# here is a simple way to calculate medians for pyspark aggregations
@pandas_udf(IntegerType())
def pd_median(values: pd.Series) -> int:
  return values.median()

# # not all pandas functions will work smoothly as pandas_udf functions
# # this would give PythonException: 'ValueError: buffer source array is read-only'
# @pandas_udf(IntegerType())
# def pd_mode(values: pd.Series) -> int:
#   return values.mode()

# we can fix the value error by copying the series inside the function
@pandas_udf(IntegerType())
def pd_mode(values: pd.Series) -> int:
  v = values.copy()
  return v.mode()[0]

df_grouped = df.groupBy(
  'strings'
).agg(
  pd_mean('integers').alias('mean'),
  pd_median('integers').alias('median'),
  pd_mode('integers').alias('mode'),
  # note that we cannot mix in with standard pyspark functions
  # sum('integers').alias('sum')
)

display(df_grouped)


strings,mean,median,mode
one,1.75,1,1
three,7.5,7,6
two,4.75,5,5


In [0]:
# the same code also works with windows functions
window = Window.partitionBy('strings')

# another way to use a pandas UDF to calculate mode
# this time we convert the pandas series to a list
# then use statistics' mode function
@pandas_udf(IntegerType())
def pd_mode(values: pd.Series) -> int:
  # convert to list if you need to perform something pandas can't do
  new_values = values.to_list()
  return mode(new_values)

df_window = df.withColumn(
  'mean', pd_mean('integers').over(window)
).withColumn(
  'median', pd_median('integers').over(window)
).withColumn(
  'mode', pd_mode('integers').over(window)
)

display(df_window)

strings,integers,length,mean,median,mode
one,1,3,1.75,1,1
one,1,3,1.75,1,1
one,2,3,1.75,1,1
one,3,3,1.75,1,1
three,6,5,7.5,7,6
three,7,5,7.5,7,6
three,8,5,7.5,7,6
three,9,5,7.5,7,6
two,4,3,4.75,5,5
two,5,3,4.75,5,5
