# Module 4 - Data Transformations & Aggregations

## Introduction

This module covers advanced DataFrame transformations including grouping, aggregations, column manipulation, and null value handling. These operations are essential for data analysis and ETL tasks.

## What You'll Learn

- Grouping data and performing aggregations
- Adding and renaming columns
- Handling null values
- Combining multiple transformations
- Common aggregation functions
- Conditional logic with when/otherwise


In [1]:
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, DoubleType
# Note: importing sum, max and min this way shadows Python's built-ins of the same name
from pyspark.sql.functions import col, when, sum, avg, count, max, min

# Create SparkSession
spark = SparkSession.builder \
    .appName("DataFrame Transformations & Aggregations") \
    .master("local[*]") \
    .getOrCreate()

# Create sample DataFrame
data = [
    ("Alice", 25, "Sales", 50000, "New York"),
    ("Bob", 30, "IT", 60000, "London"),
    ("Charlie", 35, "Sales", 70000, "Tokyo"),
    ("Diana", 28, "IT", 55000, "Paris"),
    ("Eve", 32, "HR", 65000, "Sydney"),
    ("Frank", 27, "Sales", 52000, "New York"),
    ("Grace", 29, None, 58000, "London")  # Department is null
]

schema = StructType([
    StructField("Name", StringType(), True),
    StructField("Age", IntegerType(), True),
    StructField("Department", StringType(), True),
    StructField("Salary", IntegerType(), True),
    StructField("City", StringType(), True)
])

df = spark.createDataFrame(data, schema)
print("Sample DataFrame:")
df.show()


Sample DataFrame:
+-------+---+----------+------+--------+
|   Name|Age|Department|Salary|    City|
+-------+---+----------+------+--------+
|  Alice| 25|     Sales| 50000|New York|
|    Bob| 30|        IT| 60000|  London|
|Charlie| 35|     Sales| 70000|   Tokyo|
|  Diana| 28|        IT| 55000|   Paris|
|    Eve| 32|        HR| 65000|  Sydney|
|  Frank| 27|     Sales| 52000|New York|
|  Grace| 29|      NULL| 58000|  London|
+-------+---+----------+------+--------+



## Grouping and Aggregations

Use `groupBy()` to group rows by one or more columns, then compute aggregations over each group. This is the DataFrame equivalent of SQL's `GROUP BY` clause.

### Aggregate Functions

**Aggregate functions** compute a single result, such as a count, sum or average, from a set of input rows. They are essential for summarizing data, either across an entire DataFrame or within each group produced by `groupBy()`.

PySpark provides three different styles for writing aggregations:

1. **Programmatic Style**: Using PySpark functions and column objects
2. **Column Expression Style**: Using `expr()` or `selectExpr()` with SQL-like expressions
3. **Spark SQL Style**: Using SQL queries directly with `spark.sql()`

Spark's Catalyst optimizer turns all three styles into the same kind of logical plan, so the choice is mostly about readability and team preference. Let's explore each approach.


In [2]:
# Group by Department and count
df_grouped = df.groupBy("Department").count()
print("Count by Department:")
df_grouped.show()


Count by Department:
+----------+-----+
|Department|count|
+----------+-----+
|     Sales|    3|
|        IT|    2|
|        HR|    1|
|      NULL|    1|
+----------+-----+



In [3]:
# Multiple aggregations using agg()
from pyspark.sql.functions import avg, sum, max, min

df_agg = df.groupBy("Department").agg(
    count("Name").alias("EmployeeCount"),
    avg("Salary").alias("AvgSalary"),
    sum("Salary").alias("TotalSalary"),
    max("Salary").alias("MaxSalary"),
    min("Salary").alias("MinSalary")
)

print("Multiple aggregations by Department:")
df_agg.show()


Multiple aggregations by Department:
+----------+-------------+------------------+-----------+---------+---------+
|Department|EmployeeCount|         AvgSalary|TotalSalary|MaxSalary|MinSalary|
+----------+-------------+------------------+-----------+---------+---------+
|     Sales|            3|57333.333333333336|     172000|    70000|    50000|
|        IT|            2|           57500.0|     115000|    60000|    55000|
|        HR|            1|           65000.0|      65000|    65000|    65000|
|      NULL|            1|           58000.0|      58000|    58000|    58000|
+----------+-------------+------------------+-----------+---------+---------+



### Simple Aggregations - Three Different Styles

Suppose you have an orders dataset and need to:
- Count the total number of records
- Count the number of distinct invoice IDs
- Sum the quantities
- Compute the average unit price

Here are three ways to solve this:


In [4]:
# Example: Simple Aggregations - Three Styles
# Note: This example uses a sample orders DataFrame structure

# Create sample orders data for demonstration
orders_data = [
    ("INV001", 5, 10.5, "USA"),
    ("INV001", 3, 15.0, "USA"),
    ("INV002", 2, 20.0, "UK"),
    ("INV003", 4, 12.5, "USA"),
    ("INV002", 1, 25.0, "UK"),
]

orders_df = spark.createDataFrame(orders_data, ["invoiceno", "quantity", "unitprice", "country"])

print("="*70)
print("1. PROGRAMMATIC STYLE")
print("="*70)
print("Using PySpark functions and column objects:")
print()

from pyspark.sql.functions import *

result_programmatic = orders_df.select(
    count("*").alias("row_count"),
    countDistinct("invoiceno").alias("unique_invoice"),
    sum("quantity").alias("total_quantity"),
    avg("unitprice").alias("avg_price")
)

result_programmatic.show()


1. PROGRAMMATIC STYLE
Using PySpark functions and column objects:

+---------+--------------+--------------+---------+
|row_count|unique_invoice|total_quantity|avg_price|
+---------+--------------+--------------+---------+
|        5|             3|            15|     16.6|
+---------+--------------+--------------+---------+



In [5]:
print("="*70)
print("2. COLUMN EXPRESSION STYLE")
print("="*70)
print("Using selectExpr() with SQL-like expressions:")
print()

result_expr = orders_df.selectExpr(
    "count(*) as row_count",
    "count(distinct invoiceno) as unique_invoice",
    "sum(quantity) as total_quantity",
    "avg(unitprice) as avg_price"
)

result_expr.show()


2. COLUMN EXPRESSION STYLE
Using selectExpr() with SQL-like expressions:

+---------+--------------+--------------+---------+
|row_count|unique_invoice|total_quantity|avg_price|
+---------+--------------+--------------+---------+
|        5|             3|            15|     16.6|
+---------+--------------+--------------+---------+



In [6]:
print("="*70)
print("3. SPARK SQL STYLE")
print("="*70)
print("Using SQL queries directly:")
print()

# Create a temporary view
orders_df.createOrReplaceTempView("orders")

result_sql = spark.sql("""
    SELECT 
        count(*) as row_count,
        count(distinct invoiceno) as unique_invoice,
        sum(quantity) as total_quantity,
        avg(unitprice) as avg_price
    FROM orders
""")

result_sql.show()


3. SPARK SQL STYLE
Using SQL queries directly:

+---------+--------------+--------------+---------+
|row_count|unique_invoice|total_quantity|avg_price|
+---------+--------------+--------------+---------+
|        5|             3|            15|     16.6|
+---------+--------------+--------------+---------+



### Grouping Aggregations - Three Different Styles

Now suppose you need to group the orders by invoice number and country, then:
- Find the total quantity for each group
- Find the total invoice amount (Amount = Quantity * UnitPrice)

Here are three ways to solve this:


In [7]:
# Example: Grouping Aggregations - Three Styles

print("="*70)
print("1. PROGRAMMATIC STYLE")
print("="*70)
print("Using PySpark functions with groupBy and agg():")
print()

from pyspark.sql.functions import *

summary_df_programmatic = orders_df \
    .groupBy("country", "invoiceno") \
    .agg(
        sum("quantity").alias("total_quantity"),
        sum(expr("quantity * unitprice")).alias("invoice_value")
    ) \
    .sort("invoiceno")

summary_df_programmatic.show()


1. PROGRAMMATIC STYLE
Using PySpark functions with groupBy and agg():

+-------+---------+--------------+-------------+
|country|invoiceno|total_quantity|invoice_value|
+-------+---------+--------------+-------------+
|    USA|   INV001|             8|         97.5|
|     UK|   INV002|             3|         65.0|
|    USA|   INV003|             4|         50.0|
+-------+---------+--------------+-------------+



In [8]:
print("="*70)
print("2. COLUMN EXPRESSION STYLE")
print("="*70)
print("Using expr() with groupBy and agg():")
print()

summary_df_expr = orders_df \
    .groupBy("country", "invoiceno") \
    .agg(
        expr("sum(quantity) as total_quantity"),
        expr("sum(quantity * unitprice) as invoice_value")
    ) \
    .sort("invoiceno")

summary_df_expr.show()


2. COLUMN EXPRESSION STYLE
Using expr() with groupBy and agg():

+-------+---------+--------------+-------------+
|country|invoiceno|total_quantity|invoice_value|
+-------+---------+--------------+-------------+
|    USA|   INV001|             8|         97.5|
|     UK|   INV002|             3|         65.0|
|    USA|   INV003|             4|         50.0|
+-------+---------+--------------+-------------+



In [9]:
print("="*70)
print("3. SPARK SQL STYLE")
print("="*70)
print("Using SQL queries with GROUP BY:")
print()

summary_df_sql = spark.sql("""
    SELECT 
        country, 
        invoiceno, 
        sum(quantity) as total_quantity, 
        sum(quantity * unitprice) as invoice_value
    FROM orders
    GROUP BY country, invoiceno
    ORDER BY invoiceno
""")

summary_df_sql.show()


3. SPARK SQL STYLE
Using SQL queries with GROUP BY:

+-------+---------+--------------+-------------+
|country|invoiceno|total_quantity|invoice_value|
+-------+---------+--------------+-------------+
|    USA|   INV001|             8|         97.5|
|     UK|   INV002|             3|         65.0|
|    USA|   INV003|             4|         50.0|
+-------+---------+--------------+-------------+



## Adding and Renaming Columns


In [10]:
# Add a new column
df_with_bonus = df.withColumn("Bonus", col("Salary") * 0.1)
print("DataFrame with Bonus column:")
df_with_bonus.show()

# Rename a column
df_renamed = df.withColumnRenamed("Department", "Dept")
print("\nRenamed Department to Dept:")
df_renamed.show()


DataFrame with Bonus column:
+-------+---+----------+------+--------+------+
|   Name|Age|Department|Salary|    City| Bonus|
+-------+---+----------+------+--------+------+
|  Alice| 25|     Sales| 50000|New York|5000.0|
|    Bob| 30|        IT| 60000|  London|6000.0|
|Charlie| 35|     Sales| 70000|   Tokyo|7000.0|
|  Diana| 28|        IT| 55000|   Paris|5500.0|
|    Eve| 32|        HR| 65000|  Sydney|6500.0|
|  Frank| 27|     Sales| 52000|New York|5200.0|
|  Grace| 29|      NULL| 58000|  London|5800.0|
+-------+---+----------+------+--------+------+


Renamed Department to Dept:
+-------+---+-----+------+--------+
|   Name|Age| Dept|Salary|    City|
+-------+---+-----+------+--------+
|  Alice| 25|Sales| 50000|New York|
|    Bob| 30|   IT| 60000|  London|
|Charlie| 35|Sales| 70000|   Tokyo|
|  Diana| 28|   IT| 55000|   Paris|
|    Eve| 32|   HR| 65000|  Sydney|
|  Frank| 27|Sales| 52000|New York|
|  Grace| 29| NULL| 58000|  London|
+-------+---+-----+------+--------+



In [11]:
# Add column with conditional logic
df_with_category = df.withColumn(
    "SalaryCategory",
    when(col("Salary") > 65000, "High")
    .when(col("Salary") > 55000, "Medium")
    .otherwise("Low")
)
print("DataFrame with SalaryCategory:")
df_with_category.show()


DataFrame with SalaryCategory:
+-------+---+----------+------+--------+--------------+
|   Name|Age|Department|Salary|    City|SalaryCategory|
+-------+---+----------+------+--------+--------------+
|  Alice| 25|     Sales| 50000|New York|           Low|
|    Bob| 30|        IT| 60000|  London|        Medium|
|Charlie| 35|     Sales| 70000|   Tokyo|          High|
|  Diana| 28|        IT| 55000|   Paris|           Low|
|    Eve| 32|        HR| 65000|  Sydney|        Medium|
|  Frank| 27|     Sales| 52000|New York|           Low|
|  Grace| 29|      NULL| 58000|  London|        Medium|
+-------+---+----------+------+--------+--------------+



## Handling Null Values

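Real-world datasets often contain missing values, which Spark represents as `null`. PySpark lets you filter them out, replace them with defaults, or drop the affected rows.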


In [12]:
# Filter out null values
df_no_null = df.filter(df.Department.isNotNull())
print("Rows with non-null Department:")
df_no_null.show()

# Fill null values (dict form maps column name -> replacement value)
df_filled = df.fillna({"Department": "Unknown"})
print("\nFilled null Department with 'Unknown':")
df_filled.show()


Rows with non-null Department:
+-------+---+----------+------+--------+
|   Name|Age|Department|Salary|    City|
+-------+---+----------+------+--------+
|  Alice| 25|     Sales| 50000|New York|
|    Bob| 30|        IT| 60000|  London|
|Charlie| 35|     Sales| 70000|   Tokyo|
|  Diana| 28|        IT| 55000|   Paris|
|    Eve| 32|        HR| 65000|  Sydney|
|  Frank| 27|     Sales| 52000|New York|
+-------+---+----------+------+--------+


Filled null Department with 'Unknown':
+-------+---+----------+------+--------+
|   Name|Age|Department|Salary|    City|
+-------+---+----------+------+--------+
|  Alice| 25|     Sales| 50000|New York|
|    Bob| 30|        IT| 60000|  London|
|Charlie| 35|     Sales| 70000|   Tokyo|
|  Diana| 28|        IT| 55000|   Paris|
|    Eve| 32|        HR| 65000|  Sydney|
|  Frank| 27|     Sales| 52000|New York|
|  Grace| 29|   Unknown| 58000|  London|
+-------+---+----------+------+--------+



In [13]:
# Drop rows with null values
df_dropped = df.dropna(subset=["Department"])
print("Dropped rows with null Department:")
df_dropped.show()


Dropped rows with null Department:
+-------+---+----------+------+--------+
|   Name|Age|Department|Salary|    City|
+-------+---+----------+------+--------+
|  Alice| 25|     Sales| 50000|New York|
|    Bob| 30|        IT| 60000|  London|
|Charlie| 35|     Sales| 70000|   Tokyo|
|  Diana| 28|        IT| 55000|   Paris|
|    Eve| 32|        HR| 65000|  Sydney|
|  Frank| 27|     Sales| 52000|New York|
+-------+---+----------+------+--------+



## Combining Transformations

Each transformation returns a new DataFrame, so you can chain several of them into a single pipeline. This is a common pattern in PySpark.


In [14]:
# Chain multiple transformations
result = df \
    .filter(df.Age > 28) \
    .filter(df.Department.isNotNull()) \
    .select("Name", "Department", "Salary") \
    .withColumn("Bonus", col("Salary") * 0.1) \
    .orderBy(col("Salary").desc())

print("Chained transformations:")
result.show()


Chained transformations:
+-------+----------+------+------+
|   Name|Department|Salary| Bonus|
+-------+----------+------+------+
|Charlie|     Sales| 70000|7000.0|
|    Eve|        HR| 65000|6500.0|
|    Bob|        IT| 60000|6000.0|
+-------+----------+------+------+



## Common Aggregation Functions

Here are some commonly used aggregation functions:

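- `count()`: Count rows. `count("*")` counts every row, while `count("col")` counts only the non-null values in that column
- `sum()`: Sum of values
- `avg()` / `mean()`: Average value
- `max()`: Maximum value
- `min()`: Minimum value
- `collect_list()`: Collect values into a list (duplicates kept)
- `collect_set()`: Collect unique values into a set (duplicates removed)
- `stddev()`: Sample standard deviation (alias of `stddev_samp()`)
- `variance()`: Sample variance (alias of `var_samp()`)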


In [15]:
# Example with various aggregation functions
from pyspark.sql.functions import collect_list, collect_set, stddev

df_agg_examples = df.groupBy("Department").agg(
    count("*").alias("TotalCount"),
    avg("Salary").alias("AvgSalary"),
    stddev("Salary").alias("SalaryStdDev"),
    collect_list("Name").alias("EmployeeNames")
)

print("Various aggregation functions:")
df_agg_examples.show(truncate=False)


Various aggregation functions:
+----------+----------+------------------+------------------+-----------------------+
|Department|TotalCount|AvgSalary         |SalaryStdDev      |EmployeeNames          |
+----------+----------+------------------+------------------+-----------------------+
|Sales     |3         |57333.333333333336|11015.141094572204|[Alice, Charlie, Frank]|
|IT        |2         |57500.0           |3535.5339059327375|[Bob, Diana]           |
|HR        |1         |65000.0           |NULL              |[Eve]                  |
|NULL      |1         |58000.0           |NULL              |[Grace]                |
+----------+----------+------------------+------------------+-----------------------+



## Summary

In this module, you learned:

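1. **Grouping and Aggregations**: Using `groupBy()` and `agg()` (the equivalent of SQL `GROUP BY`), written in programmatic, column-expression, or Spark SQL style
2. **Adding Columns**: Using `withColumn()` to add new columns, including conditional values with `when()`/`otherwise()`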
3. **Renaming Columns**: Using `withColumnRenamed()`
4. **Handling Nulls**: Using `fillna()`, `dropna()`, and `isNotNull()`
5. **Combining Transformations**: Chaining multiple transformations for complex data processing
6. **Common Aggregation Functions**: `sum()`, `avg()`, `count()`, `max()`, `min()`, and more

**Key Takeaway**: Transformations are lazy. They build up a plan that Spark runs only when an action such as `show()`, `count()`, or `collect()` is called. Chaining transformations keeps code readable and lets Spark optimize the whole pipeline before executing it.

**Next Steps**: In Module 5, we'll learn about Spark SQL - using SQL syntax to work with DataFrames.
