# Lesson 6 - Aggregations and Window Functions

---

## PySpark Technical Notes: Lesson 6 - Aggregations and Window Functions

**Objective:** This section delves into powerful data manipulation techniques within PySpark: aggregating data based on groups and performing calculations across related rows using window functions. These are fundamental operations for data summarization, analysis, and feature engineering in large datasets.

### 1. Grouped Aggregations (`groupBy` and Aggregation Functions)

**Theory:**

Aggregations are operations that compute a single result from a set of input values. In distributed data processing frameworks like Spark, aggregations are often performed after grouping data based on one or more key columns.

The `groupBy()` transformation in PySpark groups rows of a DataFrame based on the specified column(s). It doesn't compute anything immediately; instead, it returns a `GroupedData` object. This object holds the grouping information and requires an aggregation function (`agg()`, `sum()`, `count()`, `avg()`, `min()`, `max()`, etc.) to compute a final result DataFrame.

**Execution Flow:**
1.  **Shuffle Phase:** `groupBy()` is lazy, so nothing moves when it is called. When an action later executes the aggregation, Spark typically shuffles the data: rows (or per-partition partial results) with the same grouping key(s) are moved from different partitions across the cluster to the same partition. This can be I/O and network intensive.
2.  **Aggregation Phase:** Once data is co-located, the specified aggregation function(s) are applied *within each group* on the respective worker nodes.
3.  **Result:** A new DataFrame is created where each row represents a unique group key combination, and the columns contain the aggregated results.
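
The flow above can be modelled in plain Python. The sketch below is not Spark code: it assumes a hypothetical split of the lesson's sales rows into two partitions and imitates how a per-partition partial aggregation, a shuffle of the partial results, and a final merge produce `sum` and `count` per category. The helper names `partial_aggregate` and `merge_partials` are made up for illustration.

```python
from collections import defaultdict

# Hypothetical split of the sample sales rows across two partitions:
# (product_category, sales_amount)
partitions = [
    [("Electronics", 1200.00), ("Clothing", 55.50), ("Electronics", 800.50),
     ("Home Goods", 250.00), ("Clothing", 80.00)],
    [("Electronics", 150.75), ("Home Goods", 499.99), ("Clothing", 120.25),
     ("Electronics", 2200.00)],
]

def partial_aggregate(rows):
    """Combine the rows of one partition into a (sum, count) state per key."""
    acc = defaultdict(lambda: [0.0, 0])
    for key, amount in rows:
        acc[key][0] += amount
        acc[key][1] += 1
    return {key: tuple(state) for key, state in acc.items()}

def merge_partials(partials):
    """After the 'shuffle', merge the partial states that share a key."""
    merged = defaultdict(lambda: [0.0, 0])
    for partial in partials:
        for key, (part_sum, part_count) in partial.items():
            merged[key][0] += part_sum
            merged[key][1] += part_count
    return {key: tuple(state) for key, state in merged.items()}

partials = [partial_aggregate(rows) for rows in partitions]
totals = merge_partials(partials)

# Only one small record per (partition, key) has to cross the network,
# instead of every input row.
shuffled_records = sum(len(partial) for partial in partials)
print("records shuffled:", shuffled_records)
for category, (total, count) in sorted(totals.items()):
    print(category, round(total, 2), count, round(total / count, 2))
```

The average is derived from the merged `(sum, count)` state, which is why it can be computed in two phases even though averages of averages would be wrong.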

**Key Aggregation Functions (within `pyspark.sql.functions`):**
*   `count(col)`: Counts the rows in each group where `col` is not null. `count("*")` (or a count over a literal) counts every row in the group, including rows that contain nulls.
*   `countDistinct(col)`: Counts the number of distinct values in `col` for each group.
*   `sum(col)`: Computes the sum of numeric values in `col` for each group.
*   `avg(col)` or `mean(col)`: Computes the average of numeric values in `col` for each group.
*   `min(col)`: Finds the minimum value in `col` for each group.
*   `max(col)`: Finds the maximum value in `col` for each group.
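*   `agg(*exprs)`: Not a function in `pyspark.sql.functions` but a method on the `GroupedData` object. It applies one or more aggregations at once, passed either as column expressions (e.g., `F.sum(...)`) or as a dictionary mapping column names to aggregate function names.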

**Code Example:**

Let's consider a dataset of product sales and calculate total sales amount, average sales amount, and the number of transactions per product category.

```python
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.types import StructType, StructField, StringType, DoubleType, IntegerType

# Boilerplate Spark Session creation (assuming a running Spark cluster/local mode)
spark = SparkSession.builder.appName("AggregationsExample").getOrCreate()

# Sample Sales Data Schema
schema = StructType([
    StructField("transaction_id", IntegerType(), True),
    StructField("product_category", StringType(), True),
    StructField("sales_amount", DoubleType(), True),
    StructField("region", StringType(), True)
])

# Sample Sales Data
data = [
    (1, "Electronics", 1200.00, "North"),
    (2, "Clothing", 55.50, "South"),
    (3, "Electronics", 800.50, "North"),
    (4, "Home Goods", 250.00, "West"),
    (5, "Clothing", 80.00, "North"),
    (6, "Electronics", 150.75, "South"),
    (7, "Home Goods", 499.99, "North"),
    (8, "Clothing", 120.25, "West"),
    (9, "Electronics", 2200.00, "West")
]

sales_df = spark.createDataFrame(data, schema=schema)
print("--- Original Sales Data ---")
sales_df.show()

# Perform aggregations: Group by 'product_category'
category_summary_df = sales_df.groupBy("product_category") \
    .agg(
        F.sum("sales_amount").alias("total_sales"),
        F.avg("sales_amount").alias("average_sales"),
        F.count("*").alias("transaction_count") # Count all rows within the group
    )

print("--- Aggregated Sales Summary by Category ---")
category_summary_df.show()

# Perform aggregations: Group by 'product_category' and 'region'
region_category_summary_df = sales_df.groupBy("product_category", "region") \
    .agg(
        F.sum("sales_amount").alias("total_sales"),
        F.count("transaction_id").alias("transaction_count") # Count non-null transaction IDs
    )

print("--- Aggregated Sales Summary by Category and Region ---")
region_category_summary_df.show()

spark.stop()
```

**Code Explanation:**

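1.  **`from pyspark.sql import SparkSession`, `from pyspark.sql import functions as F`, `from pyspark.sql.types import ...`**: Imports the session class, the functions module, and the type classes used in the schema. `F` is a conventional alias for `pyspark.sql.functions`.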
2.  **`spark = SparkSession.builder...getOrCreate()`**: Standard way to initialize the SparkSession, the entry point for Spark functionality.
3.  **`schema = StructType([...])`**: Defines the structure and data types for the DataFrame, ensuring data integrity.
4.  **`data = [...]`**: Python list of tuples representing the raw data.
5.  **`sales_df = spark.createDataFrame(...)`**: Creates the PySpark DataFrame from the sample data and schema.
6.  **`sales_df.show()`**: Displays the initial DataFrame content.
7.  **`sales_df.groupBy("product_category")`**: Groups the DataFrame rows based on the unique values in the `product_category` column. Returns a `GroupedData` object.
8.  **`.agg(...)`**: Applies aggregation functions to the grouped data.
9.  **`F.sum("sales_amount").alias("total_sales")`**: Calculates the sum of the `sales_amount` column for each group and renames the resulting column to `total_sales`. `alias()` is crucial for clarity.
10. **`F.avg("sales_amount").alias("average_sales")`**: Calculates the average sales amount per category.
11. **`F.count("*").alias("transaction_count")`**: Counts the total number of rows within each group. Using `*` ensures all rows are counted, regardless of nulls in specific columns.
12. **`category_summary_df.show()`**: Displays the result of the first aggregation.
13. **`sales_df.groupBy("product_category", "region")`**: Groups by a combination of two columns.
14. **`F.count("transaction_id").alias("transaction_count")`**: Counts only the rows where `transaction_id` is not null within each group. If `transaction_id` could be null, this might differ from `count("*")`.
15. **`region_category_summary_df.show()`**: Displays the result of the multi-column grouping.
16. **`spark.stop()`**: Releases the resources used by the SparkSession.
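
The difference between `count("*")` and `count(col)` only shows up when nulls are present, which the sample data does not contain. As an illustration, here is a plain-Python model of the null handling (not Spark code) on a hypothetical group with missing values; `count_star`, `count_col`, and `avg_col` are made-up helper names.

```python
# Hypothetical rows of one group; None stands in for SQL NULL.
group = [
    {"transaction_id": 1, "sales_amount": 1200.00},
    {"transaction_id": None, "sales_amount": 800.50},
    {"transaction_id": 6, "sales_amount": None},
]

def count_star(rows):
    """Model of count("*"): every row counts, whatever it contains."""
    return len(rows)

def count_col(rows, col):
    """Model of count(col): only rows with a non-null value in col count."""
    return sum(1 for row in rows if row[col] is not None)

def avg_col(rows, col):
    """Model of avg(col): nulls are left out of both the sum and the divisor."""
    values = [row[col] for row in rows if row[col] is not None]
    return sum(values) / len(values) if values else None

print(count_star(group))                   # 3
print(count_col(group, "transaction_id"))  # 2
print(avg_col(group, "sales_amount"))      # 1000.25
```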

**Practical Use Cases:**
*   Calculating total revenue, average order value, or customer count per region/store/month.
*   Summarizing user activity metrics (e.g., average session duration, total clicks per user segment).
*   Aggregating sensor readings (e.g., min/max/average temperature per sensor location per hour).

**Performance Considerations:**
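*   **Cardinality and Distribution of Grouping Keys:** Grouping by columns with very high cardinality (many unique values) produces a large number of small groups, so partial aggregation removes little data before the shuffle and the result stays large. Data skew is a separate problem: when a few keys hold a disproportionate share of the rows, the tasks that process those keys run far longer than the rest.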
*   **Pre-filtering:** Filter data *before* the `groupBy` operation whenever possible to reduce the amount of data being shuffled.
*   **Combiners:** Spark automatically uses combiners within each partition *before* the shuffle to reduce the amount of data transferred, where applicable (e.g., for `sum`, `count`, `min`, `max`).

---

### 2. Pivot and Unpivot

**Theory:**

**Pivot:** This operation transforms data from a "long" format to a "wide" format. It rotates unique values from a specific *pivot column* into multiple new columns. You need:
1.  Column(s) to group by (defines the rows in the output).
2.  A *pivot column* whose distinct values will become new column headers.
3.  An aggregation function to determine the value that goes into the cells formed by the intersection of the group-by rows and the new pivot columns.
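
To make the three ingredients concrete, the following plain-Python sketch (a model of the semantics, not Spark code) pivots the lesson's sales rows by hand: category is the group-by key, region is the pivot column, and `sum` is the aggregation. The helper name `pivot_sum` is hypothetical.

```python
from collections import defaultdict

# (product_category, region, sales_amount) taken from the sample sales data
rows = [
    ("Electronics", "North", 1200.00), ("Clothing", "South", 55.50),
    ("Electronics", "North", 800.50), ("Home Goods", "West", 250.00),
    ("Clothing", "North", 80.00), ("Electronics", "South", 150.75),
    ("Home Goods", "North", 499.99), ("Clothing", "West", 120.25),
    ("Electronics", "West", 2200.00),
]
pivot_values = ["North", "South", "West"]  # these become the output columns

def pivot_sum(rows, pivot_values):
    """Group by category and sum amounts into one cell per pivot value."""
    cells = defaultdict(dict)
    for category, region, amount in rows:
        by_region = cells[category]  # the group exists even if no cell is filled
        if region in pivot_values:   # values outside the list are ignored
            by_region[region] = by_region.get(region, 0.0) + amount
    # A (category, region) pair with no input rows stays None, like SQL NULL.
    return {
        category: [by_region.get(region) for region in pivot_values]
        for category, by_region in cells.items()
    }

wide = pivot_sum(rows, pivot_values)
print(["product_category"] + pivot_values)
for category in sorted(wide):
    print([category] + wide[category])
```

Note that the 'Home Goods' row ends up with `None` in the `South` position, since no such sale exists in the input.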

**Unpivot:** This is the reverse operation of pivot, transforming data from a "wide" format back to a "long" format. It stacks multiple columns into a pair of columns: one for the original column name (or a category derived from it) and one for the value. Before Spark 3.4, PySpark had no dedicated unpivot method, and the common approach was the `stack` expression within `selectExpr` or `expr`. Spark 3.4 added the `DataFrame.unpivot` method (with `melt` as an alias).

**Code Example (Pivot):**

Let's pivot the sales data to show total sales for each category across different regions.

```python
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.types import StructType, StructField, StringType, DoubleType, IntegerType

spark = SparkSession.builder.appName("PivotExample").getOrCreate()

schema = StructType([
    StructField("transaction_id", IntegerType(), True),
    StructField("product_category", StringType(), True),
    StructField("sales_amount", DoubleType(), True),
    StructField("region", StringType(), True)
])

data = [
    (1, "Electronics", 1200.00, "North"), (2, "Clothing", 55.50, "South"),
    (3, "Electronics", 800.50, "North"), (4, "Home Goods", 250.00, "West"),
    (5, "Clothing", 80.00, "North"), (6, "Electronics", 150.75, "South"),
    (7, "Home Goods", 499.99, "North"), (8, "Clothing", 120.25, "West"),
    (9, "Electronics", 2200.00, "West")
]

sales_df = spark.createDataFrame(data, schema=schema)
print("--- Original Sales Data ---")
sales_df.show()

# Pivot: Show total sales per category for each region
# Explicitly list pivot values for better performance and predictable schema
distinct_regions = ["North", "South", "West"] # Best practice to provide values

pivoted_sales_df = sales_df.groupBy("product_category") \
    .pivot("region", distinct_regions) \
    .agg(F.sum("sales_amount"))

print("--- Pivoted Sales Data (Category vs Region Sales) ---")
pivoted_sales_df.show()

# Handle potential nulls resulting from pivot (e.g., if a category had no sales in a region)
pivoted_sales_filled_df = pivoted_sales_df.na.fill(0) # Fill nulls with 0

print("--- Pivoted Sales Data (Filled Nulls) ---")
pivoted_sales_filled_df.show()

spark.stop()
```

**Code Explanation (Pivot):**

1.  **`sales_df.groupBy("product_category")`**: Groups the data by the product category. These categories will form the rows of the pivoted table.
2.  **`.pivot("region", distinct_regions)`**: Specifies the `region` column as the pivot column. Its distinct values (`North`, `South`, `West`) will become new column headers. Providing the list `distinct_regions` is optional but recommended; otherwise, Spark needs an extra pass over the data to find the distinct values, and the resulting schema might vary.
3.  **`.agg(F.sum("sales_amount"))`**: Defines the aggregation to perform for each cell (intersection of `product_category` and `region`). Here, we sum the sales amounts. With a single aggregation the new columns are named after the pivot values; passing several aggregations is allowed, in which case Spark emits one column per (pivot value, aggregation) pair.
4.  **`pivoted_sales_df.show()`**: Displays the pivoted DataFrame. Cells whose combination does not occur in the original data contain `null`; in this dataset that is 'Home Goods' in 'South'.
5.  **`pivoted_sales_df.na.fill(0)`**: Uses the DataFrameNaFunctions (`na`) to replace any `null` values (resulting from the pivot) with `0`.

**Code Example (Unpivot using `stack` expression):**

Let's take the pivoted data (or similar wide-format data) and unpivot it.

```python
from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.appName("UnpivotExample").getOrCreate()

# Wide-format data matching the pivot output of the previous step
data_pivoted = [
    ("Clothing", 80.00, 55.50, 120.25),
    ("Electronics", 2000.50, 150.75, 2200.00),
    ("Home Goods", 499.99, None, 250.00) # No 'Home Goods' sales in South, so that cell is null
]
schema_pivoted = ["product_category", "North", "South", "West"]

pivoted_df = spark.createDataFrame(data_pivoted, schema=schema_pivoted)
print("--- Wide Format Data (Input for Unpivot) ---")
pivoted_df.show()

# Unpivot using stack expression
# stack(n, name1, val1, name2, val2, ...) -> emits n rows per input row, one per (name, value) pair
num_regions = 3 # Number of columns to unpivot (North, South, West)
unpivot_expr = f"stack({num_regions}, 'North', North, 'South', South, 'West', West) as (region, total_sales)"

unpivoted_df = pivoted_df.select("product_category", F.expr(unpivot_expr)) \
    .where(F.col("total_sales").isNotNull()) # Optional: remove rows where sales were null

print("--- Unpivoted Data (Long Format) ---")
unpivoted_df.show()

spark.stop()
```

**Code Explanation (Unpivot):**

1.  **`data_pivoted`, `schema_pivoted`, `pivoted_df`**: Creates a sample DataFrame in a "wide" format, similar to the output of the pivot operation.
2.  **`num_regions = 3`**: Defines how many columns are being unpivoted.
3.  **`unpivot_expr = f"stack(...)" `**: Constructs the `stack` expression string.
    *   `stack()` is a Spark SQL function that takes a number `n` and multiple pairs of (literal name, column value).
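    *   `n` is the number of output *rows* per input row, not the number of columns. The arguments after `n` are split evenly across those rows, so three (name, value) pairs with `n = 3` yield two output columns (here named `(region, total_sales)`).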
    *   For each input row, it generates multiple output rows – one for each (literal name, column value) pair.
    *   `'North', North`: Takes the literal string 'North' and the value from the `North` column.
    *   `'South', South`: Takes the literal string 'South' and the value from the `South` column.
    *   `'West', West`: Takes the literal string 'West' and the value from the `West` column.
    *   `as (region, total_sales)`: Names the two output columns generated by `stack`.
4.  **`pivoted_df.select("product_category", F.expr(unpivot_expr))`**: Selects the original grouping column (`product_category`) and applies the `stack` expression using `F.expr()`. This transforms the wide columns into rows.
5.  **`.where(F.col("total_sales").isNotNull())`**: Filters out rows that were created from `null` values in the original pivoted table (e.g., 'Home Goods' in 'South'). This step is optional depending on requirements.
6.  **`unpivoted_df.show()`**: Displays the final "long" format DataFrame.
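
The row-multiplying behaviour of `stack` can be checked with a small plain-Python model (not Spark code). It assumes the same wide rows as the pivot of the sample sales data; `unpivot` here is a hypothetical helper, not the Spark 3.4 method.

```python
# Wide rows: (product_category, North, South, West); None models SQL NULL.
wide_rows = [
    ("Clothing", 80.00, 55.50, 120.25),
    ("Electronics", 2000.50, 150.75, 2200.00),
    ("Home Goods", 499.99, None, 250.00),
]
value_columns = ["North", "South", "West"]

def unpivot(wide_rows, value_columns, drop_nulls=True):
    """Emit one (category, region, total_sales) row per value column."""
    long_rows = []
    for category, *values in wide_rows:
        for region, total in zip(value_columns, values):
            if drop_nulls and total is None:
                continue  # mirrors the optional isNotNull() filter
            long_rows.append((category, region, total))
    return long_rows

long_rows = unpivot(wide_rows, value_columns)
for row in long_rows:
    print(row)

# Without the null filter, every input row yields exactly len(value_columns) rows.
all_rows = unpivot(wide_rows, value_columns, drop_nulls=False)
print(len(long_rows), len(all_rows))  # 8 9
```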

**Practical Use Cases:**
*   **Pivot:** Creating summary tables for reports, transforming time-series data where dates/times become columns, feature engineering where distinct categories become binary/numeric features.
*   **Unpivot:** Normalizing data that is already in a wide format (e.g., from spreadsheets or certain database exports), preparing data for tools or models that expect a long format.

**Performance Considerations:**
*   **Pivot:** Can be expensive, especially if the number of distinct values in the pivot column is large and not explicitly provided. It involves grouping and aggregation.
*   **Unpivot (`stack`):** Generally efficient as it's primarily a projection operation, but generating many rows from one input row can increase data size.

---

### 3. Window Functions

**Theory:**

Window functions perform calculations across a set of table rows that are somehow related to the current row. Unlike `groupBy` aggregations, window functions do not collapse rows; they return a value for *each* row based on the "window" of related rows.

A window is defined using the `Window` specification in PySpark (`pyspark.sql.window.Window`), which typically involves three components:

1.  **Partitioning (`partitionBy(cols...)`):** Divides the rows into partitions (groups). The window function is applied independently within each partition. Similar to `groupBy`, but doesn't collapse rows. If omitted, the entire DataFrame is treated as a single partition.
2.  **Ordering (`orderBy(cols...)`):** Specifies the order of rows *within each partition*. This is crucial for functions sensitive to order, like ranking (`rank`, `row_number`) or accessing preceding/succeeding rows (`lag`, `lead`).
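3.  **Frame Clause (`rowsBetween(start, end)`, `rangeBetween(start, end)`):** (Optional, more advanced) Defines the exact set of rows within the partition relative to the current row (the "window frame"). If no frame is given, the default depends on ordering: without `orderBy`, the frame is the whole partition; with `orderBy`, it runs from the start of the partition through the current row and any rows tied with it (equivalent to `rangeBetween(Window.unboundedPreceding, Window.currentRow)`). Ranking functions and `lag`/`lead` work from row position in the ordering rather than from a user-supplied frame. `Window.unboundedPreceding`, `Window.currentRow`, `Window.unboundedFollowing` are common boundary markers.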
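
The practical difference between a row-based and a range-based frame appears when the ordering column has ties. The plain-Python sketch below (a model of the frame semantics, not Spark code) computes a running sum both ways over the ascending Sales salaries from this lesson's employee data; the function names are hypothetical.

```python
def running_sum_rows(sorted_values):
    """Model of rowsBetween(unboundedPreceding, currentRow):
    the frame ends at the current row's physical position."""
    out, total = [], 0.0
    for value in sorted_values:
        total += value
        out.append(total)
    return out

def running_sum_range(sorted_values):
    """Model of rangeBetween(unboundedPreceding, currentRow):
    the frame includes every row whose ordering value is <= the
    current row's value, so tied rows share the same result."""
    return [sum(v for v in sorted_values if v <= value) for value in sorted_values]

sales_salaries = [65000.0, 70000.0, 70000.0]  # ordered ascending, one tie

print(running_sum_rows(sales_salaries))   # [65000.0, 135000.0, 205000.0]
print(running_sum_range(sales_salaries))  # [65000.0, 205000.0, 205000.0]
```

This is why a running total over an ordered window usually spells out `rowsBetween(...)` explicitly rather than relying on the default frame.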

**Common Window Functions (within `pyspark.sql.functions`):**

*   **Ranking Functions:**
    *   `rank()`: Assigns ranks based on ordering within a partition. Skips ranks after ties (e.g., 1, 1, 3).
    *   `dense_rank()`: Assigns ranks without gaps (e.g., 1, 1, 2).
    *   `row_number()`: Assigns a unique sequential number within the partition based on order, arbitrarily breaking ties.
    *   `percent_rank()`: Calculates the relative rank within the partition as `(rank - 1) / (rows in partition - 1)`, giving a value between 0 and 1.
    *   `ntile(n)`: Divides rows into `n` buckets (tiles) based on order.
*   **Analytic Functions:**
    *   `lag(col, offset=1, default=None)`: Accesses the value of `col` from a preceding row within the partition (defined by `orderBy`).
    *   `lead(col, offset=1, default=None)`: Accesses the value of `col` from a succeeding row within the partition.
*   **Aggregate Functions used as Window Functions:**
    *   `sum(col)`, `avg(col)`, `min(col)`, `max(col)`, `count(col)` can also be used over a window frame (e.g., to calculate running totals or moving averages).
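
As a quick reference for how the ranking functions treat ties, here is a plain-Python model (not Spark code) that computes `row_number`, `rank`, `dense_rank`, and `percent_rank` for one partition whose values are already sorted. The helper name `ranking_columns` is hypothetical.

```python
def ranking_columns(sorted_values):
    """Return (value, row_number, rank, dense_rank, percent_rank) per row
    for a single partition that is already in window order."""
    n = len(sorted_values)
    result = []
    rank = dense_rank = 0
    previous = object()  # sentinel that compares unequal to any real value
    for position, value in enumerate(sorted_values, start=1):
        if value != previous:
            rank = position      # rank jumps to the row position after ties
            dense_rank += 1      # dense_rank only ever advances by one
            previous = value
        percent_rank = (rank - 1) / (n - 1) if n > 1 else 0.0
        result.append((value, position, rank, dense_rank, percent_rank))
    return result

# Sales salaries ordered descending, with a tie at the top
for row in ranking_columns([70000.0, 70000.0, 65000.0]):
    print(row)
# (70000.0, 1, 1, 1, 0.0)
# (70000.0, 2, 1, 1, 0.0)
# (65000.0, 3, 3, 2, 1.0)
```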

**Code Example:**

Let's rank employees by salary within each department and find the salary of the next lower-paid employee in the same department.

```python
from pyspark.sql import SparkSession, Window
from pyspark.sql import functions as F
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, DoubleType

spark = SparkSession.builder.appName("WindowFunctionsExample").getOrCreate()

# Sample Employee Data Schema
schema = StructType([
    StructField("emp_id", IntegerType(), True),
    StructField("emp_name", StringType(), True),
    StructField("department", StringType(), True),
    StructField("salary", DoubleType(), True)
])

# Sample Employee Data
data = [
    (1, "Alice", "Sales", 70000.0),
    (2, "Bob", "Sales", 65000.0),
    (3, "Charlie", "Sales", 70000.0),
    (4, "David", "IT", 80000.0),
    (5, "Eve", "IT", 95000.0),
    (6, "Frank", "IT", 80000.0),
    (7, "Grace", "HR", 55000.0),
    (8, "Heidi", "HR", 60000.0)
]

emp_df = spark.createDataFrame(data, schema=schema)
print("--- Original Employee Data ---")
emp_df.show()

# Define the window specification: Partition by department, order by salary descending
dept_window_spec = Window.partitionBy("department").orderBy(F.desc("salary"))

# Apply window functions
emp_ranked_df = emp_df.withColumn("rank", F.rank().over(dept_window_spec)) \
    .withColumn("dense_rank", F.dense_rank().over(dept_window_spec)) \
    .withColumn("row_num", F.row_number().over(dept_window_spec)) \
    .withColumn("salary_prev_higher", F.lag("salary", offset=1).over(dept_window_spec)) \
    .withColumn("salary_next_lower", F.lead("salary", offset=1).over(dept_window_spec)) # with descending order: lag = previous (higher) row, lead = next (lower) row

print("--- Employee Data with Window Function Results ---")
emp_ranked_df.orderBy("department", "rank", "emp_id").show()

# Example: Running total of salary within department (requires specific frame)
dept_window_spec_running_total = Window.partitionBy("department") \
                                        .orderBy("salary") \
                                        .rowsBetween(Window.unboundedPreceding, Window.currentRow)

emp_running_total_df = emp_df.withColumn("running_total_salary",
                                        F.sum("salary").over(dept_window_spec_running_total))

print("--- Employee Data with Running Total Salary ---")
emp_running_total_df.orderBy("department", "salary").show()


spark.stop()
```

**Code Explanation:**

1.  **`from pyspark.sql import Window`**: Imports the `Window` class needed to define window specifications.
2.  **`emp_df = spark.createDataFrame(...)`**: Creates the employee DataFrame.
3.  **`dept_window_spec = Window.partitionBy("department").orderBy(F.desc("salary"))`**: Defines the window.
    *   `partitionBy("department")`: Calculations will happen independently for 'Sales', 'IT', and 'HR'.
    *   `orderBy(F.desc("salary"))`: Within each department, rows are ordered from highest salary to lowest. This order is crucial for `rank`, `row_number`, `lag`, `lead`.
4.  **`.withColumn("rank", F.rank().over(dept_window_spec))`**: Adds a new column 'rank'.
    *   `F.rank()`: The window function being applied.
    *   `.over(dept_window_spec)`: Specifies that `rank()` should operate over the defined window. Alice and Charlie (Sales, 70k) get rank 1, Bob (Sales, 65k) gets rank 3.
5.  **`.withColumn("dense_rank", F.dense_rank().over(dept_window_spec))`**: Adds 'dense_rank'. Alice/Charlie get 1, Bob gets 2 (no gap).
6.  **`.withColumn("row_num", F.row_number().over(dept_window_spec))`**: Adds 'row_num'. Assigns unique numbers (1, 2, 3...) based on the order. Ties (e.g., Alice and Charlie) are broken arbitrarily and the choice is not guaranteed to be the same between runs; add a tie-breaking column such as `emp_id` to the `orderBy` when a deterministic result is needed.
7.  **`.withColumn("salary_prev_higher", F.lag("salary", offset=1).over(dept_window_spec))`**: Adds 'salary_prev_higher'.
    *   `F.lag("salary", offset=1)`: Gets the value from the `salary` column of the *previous* row within the partition. Because the window is ordered by descending salary, that is the next higher salary (or an equal one, for ties). For the first row in each department, this will be `null`.
8.  **`.withColumn("salary_next_lower", F.lead("salary", offset=1).over(dept_window_spec))`**: Adds 'salary_next_lower'.
    *   `F.lead("salary", offset=1)`: Gets the value from the `salary` column of the *next* row within the partition, which under the descending order is the next lower salary (or an equal one, for ties). For the last row in each department, this will be `null`.
9.  **`emp_ranked_df.orderBy(...).show()`**: Displays the results, sorted for clarity.
10. **`dept_window_spec_running_total = ...`**: Defines a new window spec specifically for a running total. Note the `orderBy("salary")` (ascending) and the explicit frame clause.
11. **`rowsBetween(Window.unboundedPreceding, Window.currentRow)`**: Defines the window frame: include all rows from the beginning of the partition up to and including the current row, based on the order.
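12. **`.withColumn("running_total_salary", F.sum("salary").over(dept_window_spec_running_total))`**: Calculates the cumulative sum of salaries within each department, ordered by salary. Because the frame is row-based, employees with equal salaries are accumulated in an arbitrary order; adding a tie-breaking column to the `orderBy` makes the intermediate totals deterministic.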
13. **`emp_running_total_df.orderBy(...).show()`**: Displays the running total results.
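
The direction of `lag` and `lead` is easy to mix up when the window is sorted in descending order. The plain-Python model below (not Spark code) applies both to the IT salaries from the sample data, already in descending order; it assumes a positive `offset`, and the helper names simply mirror the Spark functions.

```python
def lag(values, offset=1, default=None):
    """Value from `offset` rows earlier in the window order, else default."""
    return [values[i - offset] if i - offset >= 0 else default
            for i in range(len(values))]

def lead(values, offset=1, default=None):
    """Value from `offset` rows later in the window order, else default."""
    return [values[i + offset] if i + offset < len(values) else default
            for i in range(len(values))]

it_salaries_desc = [95000.0, 80000.0, 80000.0]  # Eve, then David/Frank (tied)

print(lag(it_salaries_desc))   # [None, 95000.0, 80000.0]
print(lead(it_salaries_desc))  # [80000.0, 80000.0, None]
```

Only the first row gets a `None` from `lag` and only the last row from `lead`; a tied row simply sees its peer's equal salary.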

**Practical Use Cases:**
*   Ranking items within categories (e.g., top N products per region).
*   Calculating period-over-period changes (e.g., sales difference from the previous month using `lag`).
*   Computing running totals or moving averages.
*   Sessionization: Identifying user sessions by finding time gaps between events using `lag`.
*   Filling missing values based on nearby rows within a partition.
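
The sessionization use case combines two window ideas: a `lag` to measure the gap to the previous event, and a running sum of "new session" flags to number the sessions. The sketch below models that logic in plain Python for one user's events; the timestamps, the 1800-second threshold, and the name `assign_sessions` are all assumptions for illustration.

```python
def assign_sessions(sorted_timestamps, max_gap):
    """Assign a session index to each event time (in seconds) of one user.
    A new session starts when the gap to the previous event exceeds max_gap."""
    session_ids = []
    session_id = 0
    previous = None  # plays the role of lag(timestamp) over the user's window
    for ts in sorted_timestamps:
        if previous is not None and ts - previous > max_gap:
            session_id += 1  # running sum of the "new session" flag
        session_ids.append(session_id)
        previous = ts
    return session_ids

events = [0, 60, 200, 2500, 2560, 9000]        # hypothetical event times
print(assign_sessions(events, max_gap=1800))   # [0, 0, 0, 1, 1, 2]
```

In PySpark the same pattern would typically use a window partitioned by user and ordered by event time, with the flag computed from `lag` and the session index from a cumulative `sum` over that window.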

**Performance Considerations:**
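*   **Partitioning:** Window functions shuffle data so that all rows of a window partition end up together; similar concerns about key cardinality and data skew apply as with `groupBy`. Omitting `partitionBy` is the worst case, because Spark then moves every row into a single partition to evaluate the window.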
*   **Ordering:** Sorting within partitions (`orderBy`) adds computational cost.
*   **Frame Clause:** Complex frame clauses (especially those involving large ranges or `unboundedPreceding`/`unboundedFollowing`) can impact memory usage on worker nodes, as more data per partition might need to be held for calculation.
*   **Resource Intensive:** Window functions can be more resource-intensive than simple aggregations as they don't reduce the number of rows. Ensure sufficient cluster memory.

---
**End of Lesson 6 Notes**