# What are Window Aggregations?

Unlike standard aggregations (`groupBy().agg()`), which collapse each group into a single output row, window aggregations compute over a set of rows (a "window") defined relative to the *current* row, such as the rows in the same partition or the rows within a given distance of it. The key benefit is that they **return a value for every input row** rather than collapsing them. This is useful for tasks like:

- Calculating running totals or moving averages.
- Ranking rows within groups (e.g., top N products per category).
- Calculating differences between the current row and preceding/succeeding rows.
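
The contrast between collapsing rows and keeping them can be sketched in plain Python, without Spark. This is only a model of the behaviour described above; the sample rows and the variable names `grouped` and `windowed` are made up for illustration.

```python
from collections import defaultdict

# Hypothetical sample rows: (department, employee, salary)
rows = [
    ("Sales", "Alice", 5000),
    ("Sales", "Bob", 4500),
    ("HR", "David", 3500),
    ("HR", "Eve", 4000),
]

# Sum salaries per department once; both styles below reuse these totals.
totals = defaultdict(int)
for dept, _, salary in rows:
    totals[dept] += salary

# Group aggregation: one output row per department.
grouped = sorted(totals.items())

# Window aggregation: every input row is kept and gains a new column.
windowed = [(dept, name, salary, totals[dept]) for dept, name, salary in rows]

print(grouped)   # [('HR', 7500), ('Sales', 9500)]
for row in windowed:
    print(row)   # four rows, each carrying its department total
```

The grouped result has two rows, while the windowed result still has all four input rows.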

## Core Components in PySpark

1.  **Window Functions:** These are the functions you apply over the window (e.g., `sum`, `avg`, `rank`, `lag`). Standard aggregate functions such as `sum` and `avg` become window functions when called with `.over(window_spec)`; dedicated window functions such as `rank` and `lag` are available in `pyspark.sql.functions`.
2.  **Window Specification (`WindowSpec`):** This defines the window (the set of rows) for the calculation. You create it using the `Window` class from `pyspark.sql.window`.

**Defining the Window (`WindowSpec`)**

A `WindowSpec` is defined primarily by three components:

1.  **`partitionBy(*cols)`:**
    * **Purpose:** Divides the rows of the DataFrame into independent partitions based on the specified column(s). The window function is applied separately within each partition.
    * **Analogy:** Similar to `GROUP BY`.
    * **Example:** `Window.partitionBy("department")` - calculations will restart for each department.

2.  **`orderBy(*cols)`:**
    * **Purpose:** Orders the rows *within* each partition based on the specified column(s) and direction (ascending default, use `.desc()` for descending).
    * **Importance:** Crucial for functions that depend on row order, like ranking (`rank`, `row_number`), offset (`lag`, `lead`), and cumulative calculations.
    * **Example:** `Window.partitionBy("department").orderBy("salary")` - rows within each department are ordered by salary.

3.  **Frame Definition (`rowsBetween(start, end)` / `rangeBetween(start, end)`)**
    * **Purpose:** Specifies the exact set of rows relative to the current row *within the ordered partition* to include in the window frame for calculation.
    * **Boundaries:** Defined relative to the current row. Common boundaries include:
        * `Window.unboundedPreceding`: The first row of the partition.
        * `Window.unboundedFollowing`: The last row of the partition.
        * `Window.currentRow`: The current row.
        * Integer offsets: Negative values point before the current row and positive values after it (e.g., `-1` is the previous row, `1` is the next row).
    * **`rowsBetween(start, end)`:** Defines the frame based on a fixed number of rows relative to the current row (physical offset). Example: `rowsBetween(-1, 1)` includes the previous row, current row, and next row.
    * **`rangeBetween(start, end)`:** Defines the frame based on a *value* range relative to the current row's value in the `orderBy` column. Numeric offsets require ordering by exactly *one* column; the `Window.unboundedPreceding`, `Window.unboundedFollowing`, and `Window.currentRow` boundaries carry no such restriction. All rows whose ordering value falls within the range are included, so rows tied with the current row always belong to its frame. Example: `rangeBetween(Window.unboundedPreceding, Window.currentRow)` includes all rows from the start of the partition up to the current row's *value* (useful for cumulative sums where rows with the same value are treated together).
    * **Default Frame:**
        * If only `partitionBy` is used: The frame is the entire partition (`rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)`).
        * If `orderBy` is used without specifying a frame: The default is `rangeBetween(Window.unboundedPreceding, Window.currentRow)`, which suits cumulative calculations. Because this is a range frame, rows tied on the ordering value share the same result.
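
The difference between row-based and value-based frames shows up as soon as the ordering column contains ties. The following plain-Python sketch models the frame semantics described above for one partition that is already sorted ascending. It is not how Spark implements frames, and the helper names are hypothetical.

```python
# Salaries of a single partition, already sorted ascending by the ordering column.
salaries = [6000, 6000, 6500]

def cumulative_rows(values):
    """Model of rowsBetween(unboundedPreceding, currentRow): the frame ends at the current row's position."""
    return [sum(values[: i + 1]) for i in range(len(values))]

def cumulative_range(values):
    """Model of rangeBetween(unboundedPreceding, currentRow): the frame holds every row
    whose ordering value is <= the current row's value, so tied rows are included."""
    return [sum(v for v in values if v <= current) for current in values]

def neighbour_rows(values):
    """Model of rowsBetween(-1, 1): previous row, current row, and next row where they exist."""
    return [sum(values[max(0, i - 1): i + 2]) for i in range(len(values))]

print(cumulative_rows(salaries))   # [6000, 12000, 18500]
print(cumulative_range(salaries))  # [12000, 12000, 18500]
print(neighbour_rows(salaries))    # [12000, 18500, 12500]
```

With the range frame, both rows with salary 6000 receive the same cumulative sum; with the row frame, each row gets its own running value.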

## Common Window Functions (`pyspark.sql.functions`)

* **Aggregate Functions:** `sum()`, `avg()`, `count()`, `min()`, `max()` applied over the window.
* **Ranking Functions:**
    * `rank()`: Assigns rank based on order; skips ranks after ties (e.g., 1, 1, 3).
    * `dense_rank()`: Assigns rank based on order; does *not* skip ranks after ties (e.g., 1, 1, 2).
    * `row_number()`: Assigns a unique, sequential number within the partition based on order, regardless of ties.
    * `percent_rank()`: Relative rank within the partition, computed as `(rank - 1) / (rows in partition - 1)`, so values run from 0 to 1.
    * `ntile(n)`: Divides rows into `n` ranked groups (buckets).
* **Analytic (Offset) Functions:**
    * `lag(col, offset=1, default=None)`: Gets the value of `col` from a previous row within the partition (defined by `offset`).
    * `lead(col, offset=1, default=None)`: Gets the value of `col` from a subsequent row within the partition.
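
How the ranking and offset functions treat ties is easiest to see on a small partition. The sketch below recomputes their results in plain Python for one partition sorted in descending order. It models the definitions listed above rather than calling Spark, and the sample salaries are made up.

```python
# Salaries of a single partition, sorted descending (the window's ordering).
salaries = [6500, 6000, 6000, 5500]
n = len(salaries)

# row_number: sequential position, ties broken by position.
row_number = list(range(1, n + 1))

# rank: 1 + number of rows strictly ahead of the current value (gaps after ties).
rank = [1 + sum(1 for other in salaries if other > s) for s in salaries]

# dense_rank: 1 + number of distinct values strictly ahead (no gaps).
dense_rank = [1 + len({other for other in salaries if other > s}) for s in salaries]

# percent_rank: (rank - 1) / (rows in partition - 1).
percent_rank = [(r - 1) / (n - 1) for r in rank]

# lag / lead with offset 1; None plays the role of the default for missing rows.
lag = [None] + salaries[:-1]
lead = salaries[1:] + [None]

print(row_number)  # [1, 2, 3, 4]
print(rank)        # [1, 2, 2, 4]
print(dense_rank)  # [1, 2, 2, 3]
print(lag)         # [None, 6500, 6000, 6000]
print(lead)        # [6000, 6000, 5500, None]
```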

**Example:**

Let's calculate the salary rank for each employee within their department and the difference from the average department salary.

```python
from pyspark.sql import SparkSession, Window
import pyspark.sql.functions as F

# Create SparkSession
spark = SparkSession.builder.appName("WindowFunctionsExample").getOrCreate()

# Sample Data
data = [
    ("Sales", "Alice", 5000), ("Sales", "Bob", 4500), ("Sales", "Charlie", 5500),
    ("HR", "David", 3500), ("HR", "Eve", 4000),
    ("IT", "Frank", 6000), ("IT", "Grace", 6500), ("IT", "Heidi", 6000)
]
columns = ["department", "employee", "salary"]
df = spark.createDataFrame(data, columns)

print("Original DataFrame:")
df.show()

# --- Define Windows ---

# Window partitioned by department, ordered by salary descending for ranking
dept_salary_rank_window = Window.partitionBy("department").orderBy(F.col("salary").desc())

# Window partitioned by department (no ordering needed for overall avg)
dept_avg_window = Window.partitionBy("department")

# --- Apply Window Functions ---

df_results = df.withColumn(
    "rank_in_dept",
    F.rank().over(dept_salary_rank_window) # Rank within department
).withColumn(
    "dense_rank_in_dept",
    F.dense_rank().over(dept_salary_rank_window) # Dense Rank
).withColumn(
    "avg_dept_salary",
    F.avg("salary").over(dept_avg_window) # Avg salary for the whole department partition
).withColumn(
    "diff_from_avg",
    F.col("salary") - F.col("avg_dept_salary") # Calculate difference from dept avg
)

# Example with Lag: Find previous employee's salary in the ranked list
df_results = df_results.withColumn(
    "prev_emp_salary",
    F.lag("salary", 1).over(dept_salary_rank_window) # Salary of the person ranked just above
)

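# Example with Cumulative Sum
# rowsBetween counts physical rows, so tied salaries (the two 6000s in IT) get different
# running totals, and which tied row comes first is not deterministic.
dept_cumulative_window = (
    Window.partitionBy("department")
    .orderBy("salary")
    .rowsBetween(Window.unboundedPreceding, Window.currentRow)
)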
df_results = df_results.withColumn(
    "cumulative_salary",
    F.sum("salary").over(dept_cumulative_window)
)


print("\nDataFrame with Window Function Results:")
df_results.show()

spark.stop()
```

> This example demonstrates how to define different window specifications and apply various functions (`rank`, `dense_rank`, `avg`, `lag`, `sum`) over those windows to add insightful columns without collapsing the original data structure. Remember that the choice of `partitionBy`, `orderBy`, and the frame definition (`rowsBetween`/`rangeBetween`) is critical to getting the desired calculation.
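
To sanity-check the `avg_dept_salary` and `diff_from_avg` columns without a Spark session, the same numbers can be recomputed in plain Python from the sample data. This is only a cross-check sketch; the dictionary names are hypothetical.

```python
from statistics import mean

# Same sample data as the PySpark example: (department, employee, salary)
data = [
    ("Sales", "Alice", 5000), ("Sales", "Bob", 4500), ("Sales", "Charlie", 5500),
    ("HR", "David", 3500), ("HR", "Eve", 4000),
    ("IT", "Frank", 6000), ("IT", "Grace", 6500), ("IT", "Heidi", 6000),
]

# Collect salaries per department (the partition).
salaries_by_dept = {}
for dept, _, salary in data:
    salaries_by_dept.setdefault(dept, []).append(salary)

# Average over the whole partition, then the per-employee difference from it.
avg_by_dept = {dept: mean(values) for dept, values in salaries_by_dept.items()}
diff_from_avg = {name: salary - avg_by_dept[dept] for dept, name, salary in data}

print(avg_by_dept["Sales"], avg_by_dept["HR"])     # 5000 3750
print(diff_from_avg["Bob"], diff_from_avg["Eve"])  # -500 250
```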

# Exercise

In the previous exercise on grouping aggregates, we produced a weekly summary DataFrame per country (loaded and displayed in cell `In [2]` below).

Now extend this summary to meet the following requirement:
- Compute a week-by-week running total of `InvoiceValue` for each country.

In [1]:
from pyspark.sql import DataFrame, SparkSession

spark: SparkSession = (
    SparkSession.builder
    .master("local[3]")  # type: ignore
    .appName("WindowAggregate")
    .getOrCreate()
)

spark.version

'3.5.5'

In [2]:
summaryDF: DataFrame = spark.read.format("parquet").load("output/*.parquet")  # previous solution

summarySortedDF = summaryDF.sort("Country", "WeekNumber")

summarySortedDF.show()

+---------------+----------+-----------+-------------+------------+
|        Country|WeekNumber|NumInvoices|TotalQuantity|InvoiceValue|
+---------------+----------+-----------+-------------+------------+
|      Australia|        48|          1|          107|      358.25|
|      Australia|        49|          1|          214|       258.9|
|      Australia|        50|          2|          133|      387.95|
|        Austria|        50|          2|            3|      257.04|
|        Bahrain|        51|          1|           54|      205.74|
|        Belgium|        48|          1|          528|       346.1|
|        Belgium|        50|          2|          285|      625.16|
|        Belgium|        51|          2|          942|      838.65|
|Channel Islands|        49|          1|           80|      363.53|
|         Cyprus|        50|          1|          917|     1590.82|
|        Denmark|        49|          1|          454|      1281.5|
|           EIRE|        48|          7|        

1. Break the DataFrame into partitions by `Country`.
2. Order each partition by `WeekNumber` to get a week-by-week total.
3. Compute the total over a sliding window of records.

# Three-Step Process

- Identify the partitioning columns (here, `Country`).
- Identify the ordering requirement (here, `WeekNumber`).
- Define the window start and end (here, from the first record of a country up to and including the current record).
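
The three steps can be traced by hand in plain Python before writing the PySpark version. The sketch below uses the Australia and Belgium rows from the summary shown earlier, deliberately shuffled. The function name `running_totals` is hypothetical, and this is a model of the logic, not the Spark solution.

```python
from itertools import groupby
from operator import itemgetter

# (Country, WeekNumber, InvoiceValue) rows from the summary, in arbitrary order.
rows = [
    ("Belgium", 50, 625.16),
    ("Australia", 48, 358.25),
    ("Belgium", 48, 346.10),
    ("Australia", 50, 387.95),
    ("Australia", 49, 258.90),
    ("Belgium", 51, 838.65),
]

def running_totals(rows):
    result = []
    # Steps 1 and 2: sorting by (Country, WeekNumber) groups each country
    # together and orders the weeks within it.
    ordered = sorted(rows, key=itemgetter(0, 1))
    for country, partition in groupby(ordered, key=itemgetter(0)):
        total = 0.0  # the running total restarts for every country
        for _, week, value in partition:
            # Step 3: the frame runs from the first record up to the current one.
            total += value
            result.append((country, week, round(total, 2)))
    return result

for row in running_totals(rows):
    print(row)
```

The total restarts at each new country, which is what partitioning by `Country` achieves.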

In [4]:
from pyspark.sql import Window
from pyspark.sql import functions as f

# Frame: first row of the Country partition (Window.unboundedPreceding) through the current row.
# For a sliding total over the current row and the 2 rows before it, use .rowsBetween(-2, Window.currentRow).
runningWindowTotal = Window.partitionBy("Country").orderBy("WeekNumber").rowsBetween(Window.unboundedPreceding, Window.currentRow)

# round(..., 2) trims floating-point noise; the explicit sort makes the displayed row order deterministic.
summaryDF.withColumn(
    "RunningTotal", f.round(f.sum("InvoiceValue").over(runningWindowTotal), 2)
).sort("Country", "WeekNumber").show()

+---------------+----------+-----------+-------------+------------+------------+
|        Country|WeekNumber|NumInvoices|TotalQuantity|InvoiceValue|RunningTotal|
+---------------+----------+-----------+-------------+------------+------------+
|      Australia|        48|          1|          107|      358.25|      358.25|
|      Australia|        49|          1|          214|       258.9|      617.15|
|      Australia|        50|          2|          133|      387.95|      1005.1|
|        Austria|        50|          2|            3|      257.04|      257.04|
|        Bahrain|        51|          1|           54|      205.74|      205.74|
|        Belgium|        48|          1|          528|       346.1|       346.1|
|        Belgium|        50|          2|          285|      625.16|      971.26|
|        Belgium|        51|          2|          942|      838.65|     1809.91|
|Channel Islands|        49|          1|           80|      363.53|      363.53|
|         Cyprus|        50|