# Lesson 18 - Filtering, Sorting, and Grouping

In [0]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, expr


spark = SparkSession.builder.getOrCreate()
sc = spark.sparkContext

## Introduction

In this lesson, we will introduce DataFrame methods that can be used to perform tasks commonly required when working with tabular data. Specifically, we will discuss how to filter DataFrames, how to sort the rows of a DataFrame, and how to perform grouping and aggregation on the rows of a DataFrame. We will use the Gapminder dataset to illustrate these concepts.

In [0]:
gm_schema = (
    'country STRING, year INTEGER, continent STRING, population INTEGER, '
    'life_exp DOUBLE, gdp_per_cap DOUBLE, gini DOUBLE'
)

gm_df = (
  spark.read
  .option('delimiter', '\t')
  .option('header', True)
  .schema(gm_schema)
  .csv('/FileStore/tables/gapminder_data_2020.txt')
)
    
gm_df.printSchema()

### The `filter()` Transformation

We can use the `filter()` transformation to reduce the size of a DataFrame by keeping only the rows that satisfy a given condition. We can specify the condition by applying a Python comparison operator to a column object, or by providing the desired comparison as an expression string to `expr()`. Both forms produce the same result.

In [0]:
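# Equivalent filter written with a column object instead of an expression string:
# gm_20 = gm_df.filter(col('year') == 2020)
gm_20 = gm_df.filter(expr('year == 2020'))
gm_20.persist()  # mark gm_20 for caching, since it is reused throughout the lesson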
gm_20.show(10)

We can apply more complicated filters by chaining calls to `filter()` or by using logical operators such as `AND` and `OR` inside our expression strings.

In [0]:
gm_df.filter(expr('year == 2020 AND continent == "asia"')).count()
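To see why chaining calls to `filter()` is equivalent to combining conditions with `AND`, it may help to look at a plain-Python sketch of the same logic. This is only a mental model and not how Spark evaluates filters; the `rows` list and its values are made up for illustration and are not taken from the Gapminder data.

```python
# Hypothetical rows standing in for a few records of gm_df (values are made up).
rows = [
    {'country': 'A', 'year': 2020, 'continent': 'asia'},
    {'country': 'B', 'year': 2020, 'continent': 'europe'},
    {'country': 'C', 'year': 2019, 'continent': 'asia'},
    {'country': 'D', 'year': 2020, 'continent': 'asia'},
]

# One combined condition, analogous to 'year == 2020 AND continent == "asia"'.
combined = [r for r in rows if r['year'] == 2020 and r['continent'] == 'asia']

# Two successive filters, analogous to chaining two calls to filter().
only_2020 = [r for r in rows if r['year'] == 2020]
chained = [r for r in only_2020 if r['continent'] == 'asia']

# Both approaches keep exactly the same rows.
print(combined == chained)
print([r['country'] for r in combined])
```

When a condition is built from column objects rather than an expression string, PySpark combines conditions with the `&` and `|` operators, and each comparison should be wrapped in parentheses, as in `(col('year') == 2020) & (col('continent') == 'asia')`.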

### The `sort()` Transformation 

We can use the `sort()` transformation to sort a DataFrame by one or more columns. This transformation accepts one or more arguments, each of which is expected to be either a column object or a string containing the name of a column. In the cell below, we use `sort()` to identify the five countries with the lowest life expectancy in 2020.

In [0]:
gm_20.sort('life_exp').show(5, truncate=False)

The `sort()` method has an optional `ascending` parameter. This is equal to `True` by default, but can be set to `False` to sort the DataFrame rows in descending order.

In [0]:
gm_20.sort('life_exp', ascending=False).show(5)

We can sort by more than one column at a time by passing each column as a separate argument. The first column provided is the primary sort key, and each later column is used only to order rows that are tied on the columns before it.

In [0]:
gm_20.sort('continent', 'life_exp').show(5)

We can also specify that a column is to be sorted in descending order by calling the `desc()` method of an associated column object. This can be useful if we want to sort by multiple columns, but want to sort some of the columns in ascending order and others in descending order.

In [0]:
gm_20.sort(
    col('continent').desc(), 
    col('life_exp')
).show(5)
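The ordering produced by a multi-column sort with mixed directions can be sketched in plain Python. This is an illustration of the resulting order only, not of how Spark performs the sort; the `rows` tuples below are hypothetical values chosen for the example.

```python
# Hypothetical (continent, country, life_exp) tuples; values are made up.
rows = [
    ('europe', 'P', 81.0),
    ('asia', 'Q', 70.5),
    ('europe', 'R', 78.2),
    ('asia', 'S', 84.1),
]

# Python's sort is stable, so we sort by the secondary key first
# (life_exp, ascending) and then by the primary key (continent, descending).
ordered = sorted(rows, key=lambda r: r[2])
ordered = sorted(ordered, key=lambda r: r[0], reverse=True)

for row in ordered:
    print(row)
```

Rows are ordered by continent from Z to A, and within each continent by increasing life expectancy, which mirrors the effect of passing `col('continent').desc()` followed by `col('life_exp')` to `sort()`.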

We can also sort by functions of column values, rather than the column values themselves. In the cell below, we sort the `gm_20` DataFrame in decreasing order of the length of each country's name.

In [0]:
gm_20.sort(expr('LENGTH(country)'), ascending=False).show(10, truncate=False)

### The `groupBy()` Transformation

Spark provides tools for performing grouping and aggregation. These tools allow us to group together rows of a DataFrame that share the same values in one or more columns, and then report summary statistics for each such group. This process is performed in two steps:

1. Call the `groupBy()` method of the DataFrame, passing it one or more column objects or names of columns. This specifies the column or columns upon which to group.
2. Call the `agg()` method of the grouped DataFrame, passing it one or more expressions defining aggregations to be performed on columns of the DataFrame.

The resulting DataFrame will contain one row for each unique value (or unique combination of values, if we group on several columns) in the grouping columns. It will contain one column for each grouping column, identifying the group, and one column for each aggregation we defined.

In the cell below, we group the 2020 data according to continental region and then calculate the total population for each such region.

In [0]:
( 
    gm_20
    .groupBy('continent')
    .agg(expr('SUM(population) AS total_pop'))
    .sort('total_pop')
    .show()
)

We can perform as many aggregations as we would like within a single call to `agg()`. Each aggregation will define a new column in the grouped DataFrame.

In [0]:
( 
    gm_20
    .groupBy('continent')
    .agg(
        expr('COUNT(*) AS n_countries'),
        expr('SUM(population) AS total_pop'),
        expr('INT(MEAN(population)) AS avg_pop'),
        expr('MIN(population) AS min_pop'),
        expr('MAX(population) AS max_pop')     
    )
    .sort('continent')
    .show()
)
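The two-step pattern used in this section, grouping followed by aggregation, can be sketched in plain Python as follows. This is a simplified mental model under made-up data, not a description of how Spark executes `groupBy()` and `agg()`; the `rows` list and its values are hypothetical.

```python
from collections import defaultdict

# Hypothetical (continent, country, population) tuples; values are made up.
rows = [
    ('asia', 'A', 100),
    ('europe', 'B', 40),
    ('asia', 'C', 60),
    ('africa', 'D', 30),
    ('europe', 'E', 10),
]

# Step 1 (analogous to groupBy): collect the populations for each continent.
groups = defaultdict(list)
for continent, _country, population in rows:
    groups[continent].append(population)

# Step 2 (analogous to agg): compute one summary row per group,
# with one field per aggregation.
summary = [
    {'continent': c, 'n_countries': len(pops), 'total_pop': sum(pops)}
    for c, pops in groups.items()
]

# Analogous to sort('total_pop').
summary.sort(key=lambda r: r['total_pop'])

for row in summary:
    print(row)
```

The output has one row per distinct continent and one field per aggregation, which is the same shape as the grouped DataFrames shown above.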

## Example: Life Expectancy and Per Capita GDP

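We will close this lesson with an example in which we calculate the total population, population-weighted average life expectancy, and per capita GDP for each of the four continental regions in the Gapminder dataset. Because countries differ greatly in population, we cannot simply average the per-country values. Instead, we first multiply each country's life expectancy and per capita GDP by its population, sum those products within each region, and then divide by the region's total population.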

In [0]:
( 
    gm_20
    .select(
        '*',
        expr('gdp_per_cap * population as total_gdp'),
        expr('life_exp * population as pop_wt_life_exp')
    )
    .groupBy('continent')
    .agg(
        expr('COUNT(*) AS n_countries'),
        expr('SUM(population) AS total_pop'),
        expr('SUM(total_gdp) AS total_gdp'),
        expr('SUM(pop_wt_life_exp) AS pop_wt_life_exp')
    )
    .select(
        'continent', 'n_countries', 'total_pop',
        expr('INT(total_gdp / total_pop) AS gdp_per_cap'),
        expr('ROUND(pop_wt_life_exp / total_pop,1) AS life_exp'),
    )
    .sort('continent')
    .show()
)
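The reason for weighting by population in the cell above can be illustrated with a small plain-Python calculation. The two countries below are hypothetical, with made-up values chosen so that the arithmetic is easy to follow.

```python
# Hypothetical (population, life_exp) pairs for two countries in one region.
countries = [
    (9_000_000, 80.0),
    (1_000_000, 60.0),
]

total_pop = sum(pop for pop, _ in countries)

# Simple mean: treats both countries equally, regardless of size.
simple_mean = sum(le for _, le in countries) / len(countries)

# Population-weighted mean: sum of (life_exp * population) / total population,
# which matches the pop_wt_life_exp / total_pop calculation in the cell above.
weighted_mean = sum(pop * le for pop, le in countries) / total_pop

print(simple_mean)    # 70.0
print(weighted_mean)  # 78.0
```

The weighted result sits closer to the value for the larger country, which is what we want when describing a region's population as a whole. The same reasoning applies to per capita GDP: summing `gdp_per_cap * population` gives a region's total GDP, and dividing by its total population gives the regional per capita figure.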