Similar to the SQL GROUP BY clause, the PySpark groupBy() function collects rows that share the same values in the given columns into groups on a DataFrame, so that aggregate functions such as count, sum, avg, min, and max can be applied to each group.

In [2]:
from pyspark.sql import SparkSession
spark=SparkSession.builder.appName('groupBy()').getOrCreate()
spark

In [3]:
data = [("James","Sales","NY",90000,34,10000),
    ("Michael","Sales","NY",86000,56,20000),
    ("Robert","Sales","CA",81000,30,23000),
    ("Maria","Finance","CA",90000,24,23000),
    ("Raman","Finance","CA",99000,40,24000),
    ("Scott","Finance","NY",83000,36,19000),
    ("Jen","Finance","NY",79000,53,15000),
    ("Jeff","Marketing","CA",80000,25,18000),
    ("Kumar","Marketing","NY",91000,50,21000)
  ]

schema = ["employee_name","department","state","salary","age","bonus"]
df=spark.createDataFrame(data,schema)
df.printSchema()


root
 |-- employee_name: string (nullable = true)
 |-- department: string (nullable = true)
 |-- state: string (nullable = true)
 |-- salary: long (nullable = true)
 |-- age: long (nullable = true)
 |-- bonus: long (nullable = true)



In [4]:
help(df.groupBy)

Help on method groupBy in module pyspark.sql.dataframe:

groupBy(*cols: 'ColumnOrName') -> 'GroupedData' method of pyspark.sql.dataframe.DataFrame instance
    Groups the :class:`DataFrame` using the specified columns,
    so we can run aggregation on them. See :class:`GroupedData`
    for all the available aggregate functions.
    
    :func:`groupby` is an alias for :func:`groupBy`.
    
    .. versionadded:: 1.3.0
    
    .. versionchanged:: 3.4.0
        Supports Spark Connect.
    
    Parameters
    ----------
    cols : list, str or :class:`Column`
        columns to group by.
        Each element should be a column name (string) or an expression (:class:`Column`)
        or list of them.
    
    Returns
    -------
    :class:`GroupedData`
        Grouped data by given columns.
    
    Examples
    --------
    >>> df = spark.createDataFrame([
    ...     (2, "Alice"), (2, "Bob"), (2, "Bob"), (5, "Bob")], schema=["age", "name"])
    
    Empty grouping columns triggers a glo

PySpark groupBy on DataFrame Columns

In [6]:
# Group by the department column, then compute the total salary of each department using sum().
df2=df.groupBy('department').sum('salary')
df2.printSchema()
# df2.show()

root
 |-- department: string (nullable = true)
 |-- sum(salary): long (nullable = true)
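
Since df2.show() is commented out, the totals themselves are not displayed. As a rough illustration of what groupBy('department').sum('salary') computes on the sample data, here is a minimal plain-Python sketch. It is not how Spark implements grouping, and the helper name sum_by_key is hypothetical, introduced only for this example.

```python
from collections import defaultdict

# (employee_name, department, salary) copied from the sample data
rows = [
    ("James", "Sales", 90000),
    ("Michael", "Sales", 86000),
    ("Robert", "Sales", 81000),
    ("Maria", "Finance", 90000),
    ("Raman", "Finance", 99000),
    ("Scott", "Finance", 83000),
    ("Jen", "Finance", 79000),
    ("Jeff", "Marketing", 80000),
    ("Kumar", "Marketing", 91000),
]


def sum_by_key(records, key_index, value_index):
    """Group records by one field and add up another field per group."""
    totals = defaultdict(int)
    for record in records:
        totals[record[key_index]] += record[value_index]
    return dict(totals)


salary_by_department = sum_by_key(rows, key_index=1, value_index=2)
print(salary_by_department)
```

Each distinct department becomes one key, which mirrors the one-row-per-group shape of the aggregated DataFrame.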



In [7]:
# Count the number of employees in each department using count().
df3=df.groupBy('department').count()
df3.printSchema()
# df3.show()

root
 |-- department: string (nullable = true)
 |-- count: long (nullable = false)



In [8]:
#Calculate the minimum salary of each department using min()
df4=df.groupBy(df.department).min('salary')
df4.printSchema()
# df4.show()

root
 |-- department: string (nullable = true)
 |-- min(salary): long (nullable = true)



In [10]:
# Calculate the maximum salary of each department using max()

df5=df.groupBy('department').max('salary')
df5.printSchema()
# df5.show()


root
 |-- department: string (nullable = true)
 |-- max(salary): long (nullable = true)



In [11]:
#Calculate the average salary of each department using avg()

df6=df.groupBy('department').avg('salary')
df6.printSchema()
# df6.show()


root
 |-- department: string (nullable = true)
 |-- avg(salary): double (nullable = true)



In [12]:
#Calculate the mean salary of each department using mean()
df7=df.groupBy('department').mean('salary')
df7.printSchema()
# df7.show()



root
 |-- department: string (nullable = true)
 |-- avg(salary): double (nullable = true)
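
The schemas above show that mean() and avg() both yield an avg(salary) column. To make the per-group results concrete, the following plain-Python sketch computes count, min, max, and average salary per department for the same sample data. It only illustrates the semantics and does not use Spark; the variable names groups and summary are assumptions for this example.

```python
from collections import defaultdict

# (department, salary) pairs copied from the sample data
rows = [
    ("Sales", 90000), ("Sales", 86000), ("Sales", 81000),
    ("Finance", 90000), ("Finance", 99000), ("Finance", 83000), ("Finance", 79000),
    ("Marketing", 80000), ("Marketing", 91000),
]

# Step 1: split the rows into one list of salaries per department
groups = defaultdict(list)
for department, salary in rows:
    groups[department].append(salary)

# Step 2: apply each aggregate to every group
summary = {
    department: {
        "count": len(salaries),
        "min": min(salaries),
        "max": max(salaries),
        "avg": sum(salaries) / len(salaries),  # true division, so always a float
    }
    for department, salaries in groups.items()
}

for department, stats in summary.items():
    print(department, stats)
```

The average is a float even when the inputs are integers, which matches the double type of avg(salary) in the printed schema.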



We can also run groupBy() and aggregate on two or more DataFrame columns.

In [13]:
# Group by department and state, then apply sum() to the salary and bonus columns.
df7=df.groupBy('department','state').sum('salary','bonus')
df7.printSchema()
# df7.show()

root
 |-- department: string (nullable = true)
 |-- state: string (nullable = true)
 |-- sum(salary): long (nullable = true)
 |-- sum(bonus): long (nullable = true)
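
Grouping on two columns means each distinct (department, state) pair forms its own group. A minimal plain-Python sketch of that idea, using a tuple as the group key, might look like this. It is an illustration of the semantics only, not Spark code, and the name totals is an assumption for this example.

```python
from collections import defaultdict

# (department, state, salary, bonus) copied from the sample data
rows = [
    ("Sales", "NY", 90000, 10000),
    ("Sales", "NY", 86000, 20000),
    ("Sales", "CA", 81000, 23000),
    ("Finance", "CA", 90000, 23000),
    ("Finance", "CA", 99000, 24000),
    ("Finance", "NY", 83000, 19000),
    ("Finance", "NY", 79000, 15000),
    ("Marketing", "CA", 80000, 18000),
    ("Marketing", "NY", 91000, 21000),
]

# Each key is a (department, state) tuple; each value holds [sum_salary, sum_bonus]
totals = defaultdict(lambda: [0, 0])
for department, state, salary, bonus in rows:
    group = totals[(department, state)]
    group[0] += salary
    group[1] += bonus

for (department, state), (sum_salary, sum_bonus) in sorted(totals.items()):
    print(department, state, sum_salary, sum_bonus)
```

With three departments and two states, the sample data produces six groups, one per combination that actually occurs.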



Running more aggregates at a time

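Using the agg() function, we can compute several aggregations in a single statement with SQL functions such as sum(), avg(), min(), max(), mean(), and count(). To use them, import them first: "from pyspark.sql.functions import sum,avg,max,min,mean,count". Note that importing sum, max, and min by name shadows the Python built-ins of the same name for the rest of the notebook; importing the module under an alias (for example, "from pyspark.sql import functions as F") avoids that.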

In [15]:
# Compute several aggregates per department in a single agg() call.
from pyspark.sql.functions import sum, avg, max, count
df9 = df.groupBy("department") \
    .agg(count('*'),
         sum("salary"),
         avg("salary"),
         sum("bonus"),
         max("bonus"))
df9.printSchema()


root
 |-- department: string (nullable = true)
 |-- count(1): long (nullable = false)
 |-- sum(salary): long (nullable = true)
 |-- avg(salary): double (nullable = true)
 |-- sum(bonus): long (nullable = true)
 |-- max(bonus): long (nullable = true)



In [16]:
# Same aggregates, with alias() giving each result column a readable name.
from pyspark.sql.functions import sum, avg, max, count
df9 = df.groupBy("department") \
    .agg(count('*').alias('count all'),
         sum("salary").alias("sum_salary"),
         avg("salary").alias("avg_salary"),
         sum("bonus").alias("sum_bonus"),
         max("bonus").alias("max_bonus"))
df9.printSchema()


root
 |-- department: string (nullable = true)
 |-- count all: long (nullable = false)
 |-- sum_salary: long (nullable = true)
 |-- avg_salary: double (nullable = true)
 |-- sum_bonus: long (nullable = true)
 |-- max_bonus: long (nullable = true)



Using filter on aggregate data with where()

In [19]:
# Aggregate per department, then keep only departments whose total bonus is at least 50000.
from pyspark.sql.functions import sum, avg, max, col
df9 = df.groupBy("department") \
    .agg(sum("salary").alias("sum_salary"),
         avg("salary").alias("avg_salary"),
         sum("bonus").alias("sum_bonus"),
         max("bonus").alias("max_bonus")) \
    .where(col("sum_bonus") >= 50000)
df9.printSchema()
# df9.show()

root
 |-- department: string (nullable = true)
 |-- sum_salary: long (nullable = true)
 |-- avg_salary: double (nullable = true)
 |-- sum_bonus: long (nullable = true)
 |-- max_bonus: long (nullable = true)
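
Applying where() after agg() filters the aggregated rows rather than the input rows, which plays the role of HAVING in SQL. The schema alone does not show which departments survive the filter, so here is a small plain-Python sketch of the same logic on the sample data. It does not use Spark, and the names sum_bonus and kept are assumptions for this example.

```python
from collections import defaultdict

# (department, bonus) pairs copied from the sample data
rows = [
    ("Sales", 10000), ("Sales", 20000), ("Sales", 23000),
    ("Finance", 23000), ("Finance", 24000), ("Finance", 19000), ("Finance", 15000),
    ("Marketing", 18000), ("Marketing", 21000),
]

# Aggregate first. The loop accumulates manually so the sketch does not depend on
# the built-in sum(), which the notebook's PySpark import shadows.
sum_bonus = defaultdict(int)
for department, bonus in rows:
    sum_bonus[department] += bonus

# Then filter on the aggregated value, as where(col("sum_bonus") >= 50000) does
kept = {department: total for department, total in sum_bonus.items() if total >= 50000}
print(kept)
```

On this data, Marketing has a total bonus of 39000 and is dropped, while Sales and Finance remain.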

