In [8]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import countDistinct, count, approx_count_distinct, collect_list, collect_set, avg
from pyspark.sql.types import *

# Create the SparkSession for this notebook via the documented
# `SparkSession.builder` entry point. getOrCreate() returns the already-active
# session if one exists in this kernel, otherwise it builds a new one.
spark = SparkSession.builder.appName('Aggregate Function in PySpark').getOrCreate()

In [3]:
# Small employee table used by every aggregate example below.
# Each tuple is (empID, Name, Gender, Salary, dept).
empData = [
    (1, "Rohit",     'M', 3000, "Data"),
    (2, "Ajay",      'M', 2000, "Data"),
    (6, "Deepshika", 'F', 2000, "Data"),
    (3, "Hemma",     'F', 2000, "HR"),
    (4, "Arti",      'F', 2000, "Marketing"),
    (5, "Kanchan",   'F', 2000, "Marketing"),
]

# Only column names are given, so Spark infers each column's type from the data.
empDataSchema = ['empID', 'Name', 'Gender', 'Salary', 'dept']

df = spark.createDataFrame(data=empData, schema=empDataSchema)
df.show()

                                                                                

+-----+---------+------+------+---------+
|empID|     Name|Gender|Salary|     dept|
+-----+---------+------+------+---------+
|    1|    Rohit|     M|  3000|     Data|
|    2|     Ajay|     M|  2000|     Data|
|    6|Deepshika|     F|  2000|     Data|
|    3|    Hemma|     F|  2000|       HR|
|    4|     Arti|     F|  2000|Marketing|
|    5|  Kanchan|     F|  2000|Marketing|
+-----+---------+------+------+---------+



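#### approx_count_distinct() --> estimates the number of distinct values in a single column (approximate, HyperLogLog++-based; the optional `rsd` argument sets the maximum allowed relative standard deviation)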

In [22]:
# approx_count_distinct(): estimated number of distinct values in one column.
# Use the exact schema name 'Salary'; lower-case 'salary' only resolves while
# spark.sql.caseSensitive is false (Spark's default), so it breaks if that changes.
df.select(approx_count_distinct('Salary').alias('dist_salary_count')).show()

# avg(): arithmetic mean of the column's values.
df.select(avg('Salary').alias('avg_salary')).show()

+-----------------+
|dist_salary_count|
+-----------------+
|                2|
+-----------------+

+------------------+
|        avg_salary|
+------------------+
|2166.6666666666665|
+------------------+



#### collect_list() --> returns an array of all values in a group, duplicates included (element order is not guaranteed)

In [16]:
# collect_list(): gather every value of the column into one array, keeping duplicates.
# Element order follows the order in which Spark processed the rows, so it is not
# guaranteed (it can change after a shuffle).
df.select(collect_list('Salary').alias('salary_list')).show(truncate=False)

# Per-group version: one array per department. groupBy() output has no defined
# row order, so orderBy('dept') keeps the displayed table stable across re-runs.
df.groupBy('dept').agg(collect_list('Salary').alias('salary_list')).orderBy('dept').show()

+------------------------------------+
|collect_list(salary)                |
+------------------------------------+
|[3000, 2000, 2000, 2000, 2000, 2000]|
+------------------------------------+

+---------+--------------------+
|     dept|collect_list(salary)|
+---------+--------------------+
|     Data|  [3000, 2000, 2000]|
|       HR|              [2000]|
|Marketing|        [2000, 2000]|
+---------+--------------------+



#### collect_set() --> returns an array of the distinct values in a group, with duplicates removed (element order is not guaranteed)

In [20]:
# collect_set(): gather the distinct values of the column into one array.
# Like collect_list(), the order of elements in the array is not guaranteed.
df.select(collect_set('Salary').alias('salary_set')).show()

# Per-group version: one de-duplicated array per department. orderBy('dept')
# makes the row order of the displayed result deterministic.
df.groupBy('dept').agg(collect_set('Salary').alias('salary_set')).orderBy('dept').show()

+-------------------+
|collect_set(salary)|
+-------------------+
|       [3000, 2000]|
+-------------------+

+---------+-------------------+
|     dept|collect_set(salary)|
+---------+-------------------+
|     Data|       [3000, 2000]|
|       HR|             [2000]|
|Marketing|             [2000]|
+---------+-------------------+



#### countDistinct() --> counts the distinct values (or distinct combinations of values) of one or more columns

##### countDistinct() is different from approx_count_distinct():
##### countDistinct() returns an exact count and accepts one or more columns, while approx_count_distinct() returns an approximate estimate for a single column (typically cheaper on large data)

In [26]:
# countDistinct(): exact number of distinct value combinations across the given
# columns. countDistinct is an alias of count_distinct() (Spark >= 3.2), which
# the PySpark docs recommend using directly.
# 1) distinct (Salary, dept) pairs
df.select(countDistinct('Salary', 'dept').alias('countDistinctRows')).show()
# 2) distinct Salary values -- same alias as the approx_count_distinct() example so
#    the exact and the approximate result for this column can be compared directly.
df.select(countDistinct('Salary').alias('dist_salary_count')).show()

+-----------------+
|countDistinctRows|
+-----------------+
|                4|
+-----------------+

+-----------------+
|countDistinctRows|
+-----------------+
|                2|
+-----------------+



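#### DataFrame.count() --> returns the number of rows in a DataFrame (not to be confused with the aggregate function pyspark.sql.functions.count, which counts non-null values in a column)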

In [34]:
# DataFrame.count() returns the number of rows. Projecting a subset of columns
# with select() does not filter rows, so all three values are equal.
df.count(), df.select('Salary').count(), df.select('Salary', 'dept').count()

(6, 6, 6)

In [35]:
# Shut down the session (this stops its underlying SparkContext) once the
# notebook is finished, releasing the cluster/local resources it holds.
spark.stop()