In [0]:
from pyspark.sql import functions as fn, types as T
from pyspark.sql import Window
import numpy as np
import pandas as pd
# import pandas_udf directly so we can use it as a decorator
from pyspark.sql.functions import pandas_udf, PandasUDFType



In [0]:
# Toy data set: every `group` id occurs twice, each time paired with a
# three-element integer array.
df_array = spark.createDataFrame(
    [
        (1, [1, 2, 3]),
        (2, [4, 5, 6]),
        (3, [7, 8, 9]),
        (1, [2, 2, 2]),
        (2, [5, 5, 5]),
        (3, [8, 8, 8]),
    ],
    schema=("group", "array"),
)

# The built-in sum() aggregate only picks up the numeric `group` column;
# the array column is left out of the result entirely.
display(df_array.groupBy().sum())

sum(group)
12


In [0]:
@pandas_udf(T.ArrayType(T.IntegerType()))
def sum_array(input: pd.Series) -> np.ndarray:
  """Element-wise sum of all arrays in a group / window frame.

  The ``pd.Series -> <non-Series>`` type hints make Spark treat this as a
  Series-to-scalar (grouped aggregate) pandas UDF: it is called once per
  group or window frame and must return a single value, here one array.

  Args:
    input: the values of the array column for the frame, one array per row.

  Returns:
    One array whose i-th entry is the sum of the i-th entries of every
    array in ``input`` (e.g. [1, 2, 3] and [2, 2, 2] give [3, 4, 5]).
  """
  # pandas' sum adds the cells together with `+`; the array cells add
  # position by position, so the result is a single array of totals.
  return input.sum()

@pandas_udf(T.ArrayType(T.FloatType()))
def avg_array(input: pd.Series) -> np.ndarray:
  """Element-wise mean of all arrays in a group / window frame.

  The ``pd.Series -> <non-Series>`` type hints make Spark treat this as a
  Series-to-scalar (grouped aggregate) pandas UDF: it is called once per
  group or window frame and must return a single value, here one array.

  Args:
    input: the values of the array column for the frame, one array per row.

  Returns:
    One array whose i-th entry is the mean of the i-th entries of every
    array in ``input`` (e.g. [1, 2, 3] and [2, 2, 2] give [1.5, 2.0, 2.5]).
    The declared Spark type is array<float>, so results are stored as
    single-precision floats.
  """
  # pandas' mean averages the array cells position by position, giving
  # one array of per-position averages.
  return input.mean()

# Frame covering the entire partition: every row of a group is handed the
# same full set of rows, so each row receives its group's aggregate.
window = (
  Window
  .partitionBy('group')
  .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
)

# Attach the per-group array sum and array mean to every original row.
df_out = (
  df_array
  .withColumn('sum_array', sum_array('array').over(window))
  .withColumn('avg_array', avg_array('array').over(window))
)

display(df_out)

group,array,sum_array,avg_array
1,"List(1, 2, 3)","List(3, 4, 5)","List(1.5, 2.0, 2.5)"
1,"List(2, 2, 2)","List(3, 4, 5)","List(1.5, 2.0, 2.5)"
2,"List(4, 5, 6)","List(9, 10, 11)","List(4.5, 5.0, 5.5)"
2,"List(5, 5, 5)","List(9, 10, 11)","List(4.5, 5.0, 5.5)"
3,"List(7, 8, 9)","List(15, 16, 17)","List(7.5, 8.0, 8.5)"
3,"List(8, 8, 8)","List(15, 16, 17)","List(7.5, 8.0, 8.5)"


In [0]:
# In older versions of PySpark (e.g. 2.4), which do not infer the UDF kind
# from Python type hints, declare it explicitly with PandasUDFType instead:

# @pandas_udf(T.ArrayType(T.IntegerType()), PandasUDFType.GROUPED_AGG)
# def sum_array(input):
#   return input.sum() 

# @pandas_udf(T.ArrayType(T.FloatType()), PandasUDFType.GROUPED_AGG)
# def avg_array(input):
#   return input.mean() 

# window = (Window
#     .partitionBy('group')
#     .rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing))

# df_out = df_array.withColumn(
#   'sum_array', sum_array('array').over(window)
# ).withColumn(
#   'avg_array', avg_array('array').over(window)
# )
#
# display(df_out)