In [1]:
# Requires Spark >= 2.3 (grouped-map pandas UDFs were introduced in that release)
# Requires pandas and PyArrow (pip install pandas pyarrow)

In [2]:
from pyspark import SparkContext, SQLContext

import pyspark.sql.types as T
import pyspark.sql.functions as F

from pyspark.sql.functions import pandas_udf
from pyspark.sql.functions import PandasUDFType

In [3]:
# Reuse the active SparkContext when one already exists. A bare SparkContext()
# raises "ValueError: Cannot run multiple SparkContexts at once" if this cell
# is executed a second time in the same kernel; getOrCreate() makes the cell
# safe to re-run while behaving identically on a fresh kernel.
sc = SparkContext.getOrCreate()

# SQL entry point used below to build DataFrames from local data.
sqlc = SQLContext(sc)

In [4]:
# Input schema: one string grouping key followed by two integer columns.
# StructType.add(name, type) appends a nullable StructField and returns the
# StructType itself, so the calls can be chained.
schema = (
    T.StructType()
    .add('group_col', T.StringType())
    .add('column1', T.IntegerType())
    .add('column2', T.IntegerType())
)

# Small sample dataset: six rows spread over the groups 'a', 'c' and 'e'.
df = sqlc.createDataFrame(
    [
        ('a', 1, 1),
        ('c', 3, 4),
        ('e', 5, 6),
        ('a', 2, 7),
        ('a', 3, 2),
        ('e', 3, 6),
    ],
    schema=schema,
)

# Collect to the driver as a pandas DataFrame so the notebook renders it.
df.toPandas()

Unnamed: 0,group_col,column1,column2
0,a,1,1
1,c,3,4
2,e,5,6
3,a,2,7
4,a,3,2
5,e,3,6


In [5]:
# Schema of the DataFrame returned by the UDF: every column of the input
# DataFrame, in the same order, followed by an integer 'size' column.
udf_schema = T.StructType(
    list(df.schema.fields)
    + [T.StructField('size', T.IntegerType())]
)

# Grouped-map pandas UDF: Spark calls the function once per group, passing the
# group's rows as a pandas DataFrame, and expects a pandas DataFrame matching
# `udf_schema` in return.

@pandas_udf(udf_schema, PandasUDFType.GROUPED_MAP)
def add_count(grouped_df):
    """Append the group's row count to every row of the group.

    Parameters
    ----------
    grouped_df : pandas.DataFrame
        The rows belonging to a single group.

    Returns
    -------
    pandas.DataFrame
        A copy of ``grouped_df`` with an extra ``size`` column whose value is
        the number of rows in the group.
    """
    n_rows = grouped_df.shape[0]
    with_size = grouped_df.copy()
    with_size['size'] = n_rows
    return with_size

In [7]:
# Run the grouped-map UDF once per 'group_col' value and collect the combined
# result to the driver for display.
(
    df
    .groupBy('group_col')
    .apply(add_count)
    .toPandas()
)

Unnamed: 0,group_col,column1,column2,size
0,e,5,6,2
1,e,3,6,2
2,c,3,4,1
3,a,1,1,3
4,a,2,7,3
5,a,3,2,3
