
### 1. Advanced Analytics: Window Functions

Window functions perform a calculation across a set of rows that are related to the current row, as defined by a window specification (a partition, an ordering, and optionally a frame). Unlike aggregate functions applied with `groupBy()` (e.g., `sum()`, `avg()`), which return a single value for an entire group, window functions return a value for *each row*.
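The difference can be sketched in plain Python, with no Spark required. The helper names `group_sum` and `window_sum` and the sample rows are made up for illustration. A grouped aggregate collapses each group to one value, while a window computation keeps every input row and attaches the group's value to it.

```python
from collections import defaultdict

# Illustrative rows: (department, employee, sales)
rows = [
    ("Sales", "Alice", 100),
    ("Sales", "Bob", 150),
    ("HR", "Charlie", 80),
]

def group_sum(rows):
    """Aggregate: one output value per department (like groupBy().sum())."""
    totals = defaultdict(int)
    for dept, _, sales in rows:
        totals[dept] += sales
    return dict(totals)

def window_sum(rows):
    """Window: one output row per input row, each carrying its department total
    (like sum(...).over(Window.partitionBy("Department")))."""
    totals = group_sum(rows)
    return [(dept, emp, sales, totals[dept]) for dept, emp, sales in rows]

print(group_sum(rows))  # {'Sales': 250, 'HR': 80}
for row in window_sum(rows):
    print(row)          # three rows, each with its department total appended
```

The aggregate returns two values for three input rows; the window version returns three rows.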

#### 1.1 Defining a Window Specification (`Window.partitionBy()`, `Window.orderBy()`, `rowsBetween()`)

To use a window function, you first define a window specification using the `Window` object.

*   `Window.partitionBy(*cols)`:
    *   Divides rows into groups (partitions) based on specified columns.
    *   The window function is applied independently within each partition.
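    *   Conceptually similar to `GROUP BY`, except that rows are not collapsed. If omitted, the entire DataFrame is treated as a single partition, which moves all data into one partition and can be slow on large DataFrames.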
*   `Window.orderBy(*cols)`:
    *   Orders the rows *within each partition*.
    *   Crucial for functions that depend on row order (e.g., `row_number()`, `rank()`, `lag()`, `lead()`, running totals).
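*   `rowsBetween(start, end)` / `rangeBetween(start, end)`:
    *   Defines the "frame": the subset of rows within a partition that the window function operates on for each row.
    *   `rowsBetween` counts physical rows relative to the current row; `rangeBetween` applies offsets to the value of the `orderBy` column, so rows that tie on that value fall into the same frame.
    *   If no frame is given, Spark uses the whole partition when there is no `orderBy`, and `rangeBetween(Window.unboundedPreceding, Window.currentRow)` when there is one.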
    *   `Window.unboundedPreceding`: From the beginning of the partition.
    *   `Window.currentRow`: The current row.
    *   `Window.unboundedFollowing`: To the end of the partition.
    *   **Common Frames:**
        *   `Window.unboundedPreceding, Window.currentRow`: For running sums/averages (includes all rows from start of partition up to current row).
        *   `Window.currentRow, Window.unboundedFollowing`: For calculations from the current row to the end.
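The distinction between a row-based and a range-based frame matters most when the ordering column has ties. The following is a plain-Python model of the two running-sum frames, not Spark code; the function names and the sample partition are assumptions made for illustration.

```python
def running_sum_rows(partition):
    """ROWS frame: sum of physical rows from the partition start to the current row
    (models rowsBetween(Window.unboundedPreceding, Window.currentRow))."""
    out, total = [], 0
    for _, value in partition:
        total += value
        out.append(total)
    return out

def running_sum_range(partition):
    """RANGE frame: sum of every row whose ordering key is <= the current row's key,
    so rows tied on the key receive the same total
    (models rangeBetween(Window.unboundedPreceding, Window.currentRow))."""
    return [sum(v for k2, v in partition if k2 <= k) for k, _ in partition]

# One partition as (year, sales), already sorted by year; 2023 appears twice.
partition = [(2023, 100), (2023, 150), (2024, 120)]

print(running_sum_rows(partition))   # [100, 250, 370]
print(running_sum_range(partition))  # [250, 250, 370]
```

With a row-based frame, the result for tied rows depends on which of them happens to come first, so a tie-breaking column in `orderBy` keeps running totals reproducible.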

#### 1.2 Common Window Functions

| Function          | Description                                                                                             | Notes                                                                                                                                                                                            |
| :---------------- | :------------------------------------------------------------------------------------------------------ | :----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `row_number()`    | Assigns a unique, sequential integer to each row within its partition, starting from 1.                 | Rows that tie on the ordering columns still get distinct numbers, in an arbitrary order among the tied rows.                                                                                     |
| `rank()`          | Assigns a rank to each row within its partition. Ties get the same rank, and a gap is left in the sequence. | If two rows tie for rank 2, the next rank will be 4 (not 3).                                                                                                                                     |
| `dense_rank()`    | Similar to `rank()`, but no gaps are left in the ranking sequence when there are ties.                  | If two rows tie for rank 2, the next rank will be 3.                                                                                                                                             |
| `lag(col, offset, default)` | Returns the value of `col` from the row `offset` rows *before* the current row in the partition. | `offset` (default 1) specifies how many rows back. `default` (default `None`) is returned when that row falls before the partition start (e.g., for the first row of a partition). Useful for previous-period comparisons. |
| `lead(col, offset, default)` | Returns the value of `col` from the row `offset` rows *after* the current row in the partition.  | `offset` (default 1) specifies how many rows forward. `default` (default `None`) is returned when that row falls past the partition end (e.g., for the last row of a partition). Useful for next-period comparisons.  |
| `sum(col)`        | Calculates the sum of `col` within the defined window frame.                                            | Can be used with `rowsBetween` for running totals.                                                                                                                                               |
| `avg(col)`        | Calculates the average of `col` within the defined window frame.                                        | Can be used with `rowsBetween` for moving averages.                                                                                                                                              |
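The tie-handling and offset rules in the table can be modeled in a few lines of plain Python. This is a sketch of the semantics for a single, already-sorted partition, not of Spark's implementation; the functions below are illustrative stand-ins that happen to share names with the PySpark functions.

```python
def row_number(scores):
    """1, 2, 3, ... regardless of ties (input is assumed to be sorted already)."""
    return list(range(1, len(scores) + 1))

def rank(scores):
    """Tied values share a rank; the next distinct value skips ahead, leaving a gap."""
    out = []
    for i, s in enumerate(scores):
        out.append(out[-1] if i > 0 and s == scores[i - 1] else i + 1)
    return out

def dense_rank(scores):
    """Tied values share a rank; the next distinct value gets the next integer."""
    out = []
    for i, s in enumerate(scores):
        if i == 0:
            out.append(1)
        else:
            out.append(out[-1] if s == scores[i - 1] else out[-1] + 1)
    return out

def lag(values, offset=1, default=None):
    """Value `offset` rows back, or `default` before the partition start (offset >= 0)."""
    return [values[i - offset] if i - offset >= 0 else default
            for i in range(len(values))]

def lead(values, offset=1, default=None):
    """Value `offset` rows ahead, or `default` past the partition end (offset >= 0)."""
    return [values[i + offset] if i + offset < len(values) else default
            for i in range(len(values))]

scores = [150, 120, 120, 100]    # one partition, sorted descending, with a tie at 120
print(row_number(scores))        # [1, 2, 3, 4]
print(rank(scores))              # [1, 2, 2, 4]
print(dense_rank(scores))        # [1, 2, 2, 3]

yearly_sales = [1000, 1200, 1100]  # one partition, sorted by year
print(lag(yearly_sales))           # [None, 1000, 1200]
print(lead(yearly_sales))          # [1200, 1100, None]
```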


In [1]:
from pyspark.sql import SparkSession
# Note: importing sum from pyspark.sql.functions shadows Python's built-in sum() in this cell.
from pyspark.sql.functions import col, row_number, rank, dense_rank, lag, lead, sum, avg
from pyspark.sql.window import Window

spark = SparkSession.builder.appName("WindowFunctions").getOrCreate()

data = [
    ("Sales", "Alice", 2023, 100),
    ("Sales", "Bob", 2023, 150),
    ("Sales", "Alice", 2024, 120),
    ("HR", "Charlie", 2023, 80),
    ("HR", "David", 2024, 90),
    ("Sales", "Bob", 2024, 160)
]
columns = ["Department", "Employee", "Year", "Sales"]
df = spark.createDataFrame(data, columns)
df.show()

# Define a window specification: Partition by Department, order by Year and Sales
window_spec = Window.partitionBy("Department").orderBy("Year", "Sales")

# Row Number Example
print("\nRow Number by Department and Year/Sales:")
df.withColumn("row_num", row_number().over(window_spec)).show()

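# Window for a running sum of sales per department, from the first row up to the current row.
# Ordering by Year alone leaves ties (two Sales rows share 2023), so Sales is added as a tie-breaker to keep the result deterministic.
window_sum = Window.partitionBy("Department").orderBy("Year", "Sales").rowsBetween(Window.unboundedPreceding, Window.currentRow)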

# Running Sum Example
print("\nRunning Sum of Sales per Department:")
df.withColumn("Running_Sum_Sales", sum("Sales").over(window_sum)).show()

# --- Ranking functions with ties ---
ranking_data = [
    ("DeptA", "Alice", 100),
    ("DeptA", "Bob", 120),
    ("DeptA", "Charlie", 120), # Tie for Bob and Charlie
    ("DeptA", "David", 150),
    ("DeptB", "Eve", 80),
    ("DeptB", "Frank", 80), # Tie for Eve and Frank
    ("DeptB", "Grace", 90)
]
ranking_cols = ["Department", "Employee", "Score"]
ranking_df = spark.createDataFrame(ranking_data, ranking_cols)
ranking_df.show()

# Define a window specification: Partition by Department, order by Score (descending)
window_spec_rank = Window.partitionBy("Department").orderBy(col("Score").desc())

print("\nRanking functions with ties:")
ranking_df.withColumn("row_num", row_number().over(window_spec_rank)) \
    .withColumn("rank", rank().over(window_spec_rank)) \
    .withColumn("dense_rank", dense_rank().over(window_spec_rank)) \
    .show()

# --- Lag and Lead Example ---
sales_data = [
    ("Alice", 2022, 1000),
    ("Alice", 2023, 1200),
    ("Alice", 2024, 1100),
    ("Bob", 2022, 1500),
    ("Bob", 2023, 1600)
]
sales_cols = ["Employee", "Year", "Sales"]
sales_df = spark.createDataFrame(sales_data, sales_cols)
sales_df.show()

# Window for lag/lead: partition by Employee, order by Year
window_lag_lead = Window.partitionBy("Employee").orderBy("Year")

print("\nLag and Lead for Sales:")
sales_df.withColumn("Prev_Year_Sales", lag("Sales", 1).over(window_lag_lead)) \
    .withColumn("Next_Year_Sales", lead("Sales", 1).over(window_lag_lead)) \
    .show()

# spark.stop() # Uncomment if this is the end of your script

+----------+--------+----+-----+
|Department|Employee|Year|Sales|
+----------+--------+----+-----+
|     Sales|   Alice|2023|  100|
|     Sales|     Bob|2023|  150|
|     Sales|   Alice|2024|  120|
|        HR| Charlie|2023|   80|
|        HR|   David|2024|   90|
|     Sales|     Bob|2024|  160|
+----------+--------+----+-----+


Row Number by Department and Year/Sales:
+----------+--------+----+-----+-------+
|Department|Employee|Year|Sales|row_num|
+----------+--------+----+-----+-------+
|        HR| Charlie|2023|   80|      1|
|        HR|   David|2024|   90|      2|
|     Sales|   Alice|2023|  100|      1|
|     Sales|     Bob|2023|  150|      2|
|     Sales|   Alice|2024|  120|      3|
|     Sales|     Bob|2024|  160|      4|
+----------+--------+----+-----+-------+


Running Sum of Sales per Department:
+----------+--------+----+-----+-----------------+
|Department|Employee|Year|Sales|Running_Sum_Sales|
+----------+--------+----+-----+-----------------+
|        HR| Charlie|2023

#### 1.3 Practical Use Cases for Window Functions

Window functions cover several recurring data engineering tasks.

*   **De-duplication:**
    *   Identify duplicate rows based on a subset of columns.
    *   Number the rows within each duplicate group with `row_number()`. (`rank()` also works, but rows that tie on the ordering column all get rank 1, so more than one may survive the filter.)
    *   Filter to keep only the first (or last) occurrence.
    *   **Process:**
        1.  `Window.partitionBy("unique_key_columns")`: Group rows that are considered duplicates.
        2.  `.orderBy("criteria_for_keeping_one")`: Chain onto the partitioned window to decide which duplicate to keep (e.g., latest timestamp, highest ID).
        3.  Apply `row_number().over(...)`.
        4.  `filter(col("row_num") == 1)` to keep only the first (or desired) record in each partition.

In [2]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, row_number
from pyspark.sql.window import Window

spark = SparkSession.builder.appName("Deduplication").getOrCreate()

duplicate_data = [
    ("Alice", "alice@email.com", "NY", 1),
    ("Bob", "bob@email.com", "LD", 2),
    ("Alice", "alice@email.com", "LA", 3), # Duplicate Alice, different city
    ("Charlie", "charlie@email.com", "SF", 4),
    ("Bob", "bob@email.com", "LD", 5)  # Exact duplicate Bob, higher RecordID
]
dup_cols = ["Name", "Email", "City", "RecordID"]
dup_df = spark.createDataFrame(duplicate_data, dup_cols)
print("Original DataFrame with duplicates:")
dup_df.show()

# Deduplicate keeping the record with the higher RecordID (more recent)
window_spec_dedup = Window.partitionBy("Name", "Email").orderBy(col("RecordID").desc())

deduplicated_df = dup_df.withColumn("row_num", row_number().over(window_spec_dedup)) \
    .filter(col("row_num") == 1) \
    .drop("row_num")

print("\nDeduplicated DataFrame (keeping latest record):")
deduplicated_df.show()

# spark.stop() # Uncomment if this is the end of your script

Original DataFrame with duplicates:
+-------+-----------------+----+--------+
|   Name|            Email|City|RecordID|
+-------+-----------------+----+--------+
|  Alice|  alice@email.com|  NY|       1|
|    Bob|    bob@email.com|  LD|       2|
|  Alice|  alice@email.com|  LA|       3|
|Charlie|charlie@email.com|  SF|       4|
|    Bob|    bob@email.com|  LD|       5|
+-------+-----------------+----+--------+


Deduplicated DataFrame (keeping latest record):
+-------+-----------------+----+--------+
|   Name|            Email|City|RecordID|
+-------+-----------------+----+--------+
|  Alice|  alice@email.com|  LA|       3|
|    Bob|    bob@email.com|  LD|       5|
|Charlie|charlie@email.com|  SF|       4|
+-------+-----------------+----+--------+



*   **Ranking:**
    *   Assign ranks to items within categories (e.g., top 3 students per class, top 10 products per region).
    *   Use `rank()`, `dense_rank()`, or `row_number()` based on tie-breaking requirements.

*   **Time Series Logic:**
    *   **Moving Averages/Cumulative Sums:** Use `rowsBetween(Window.unboundedPreceding, Window.currentRow)` for cumulative values, or a fixed-size frame such as `rowsBetween(-2, Window.currentRow)` for a 3-row moving average.
    *   **Period-over-period comparison:** Compare current values with previous or subsequent values using `lag()` and `lead()`.
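Two of the use cases above, top-N per category and a fixed-size moving average, can be sketched in plain Python to show what the corresponding window expressions compute. The helper names `top_n_per_group` and `moving_average` and the sample rows are hypothetical; the comments name the PySpark expression each helper is meant to model.

```python
from itertools import groupby

def top_n_per_group(rows, n):
    """Keep the n highest-scoring rows per group.
    Models filtering on row_number() <= n over a window partitioned by the
    group column and ordered by score descending."""
    ordered = sorted(rows, key=lambda r: (r[0], -r[2]))
    result = []
    for _, group in groupby(ordered, key=lambda r: r[0]):
        result.extend(list(group)[:n])
    return result

def moving_average(values, n):
    """Average of the current value and up to n - 1 preceding values.
    Models avg(...).over(w.rowsBetween(-(n - 1), Window.currentRow)) for one
    partition that is already sorted."""
    out = []
    for i in range(len(values)):
        frame = values[max(0, i - n + 1): i + 1]
        out.append(sum(frame) / len(frame))
    return out

# Illustrative rows: (department, employee, score)
scored = [
    ("DeptA", "Alice", 100),
    ("DeptA", "Bob", 120),
    ("DeptA", "David", 150),
    ("DeptB", "Eve", 80),
    ("DeptB", "Grace", 90),
]
print(top_n_per_group(scored, 2))
# [('DeptA', 'David', 150), ('DeptA', 'Bob', 120), ('DeptB', 'Grace', 90), ('DeptB', 'Eve', 80)]

print(moving_average([100, 120, 110], 2))  # [100.0, 110.0, 115.0]
```

Near the start of a partition the frame holds fewer than `n` rows, so the first averages are computed over whatever rows exist.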

In [3]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, lag, sum
from pyspark.sql.window import Window

spark = SparkSession.builder.appName("TimeSeries").getOrCreate()

time_series_data = [
    ("ProductA", "2023-01-01", 100),
    ("ProductA", "2023-02-01", 120),
    ("ProductA", "2023-03-01", 110),
    ("ProductB", "2023-01-01", 200),
    ("ProductB", "2023-02-01", 230)
]
ts_cols = ["Product", "Date", "Sales"]
ts_df = spark.createDataFrame(time_series_data, ts_cols)
print("Original Time Series DataFrame:")
ts_df.show()

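# Date is stored as a string; ISO yyyy-MM-dd strings sort chronologically, so ordering by it works here.
window_ts = Window.partitionBy("Product").orderBy("Date")

# Calculate previous month's sales and month-over-month growth.
# Note: lag() returns NULL for the first row of each product, so MoM_Growth is NULL there.
# A previous value of 0 yields NULL by default and raises an error when ANSI mode is enabled, so handle it explicitly if it can occur.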
ts_df_with_growth = ts_df.withColumn("PreviousMonthSales", lag("Sales", 1).over(window_ts)) \
    .withColumn("MoM_Growth",
                (col("Sales") - col("PreviousMonthSales")) / col("PreviousMonthSales"))

print("\nTime Series with Previous Month Sales and MoM Growth:")
ts_df_with_growth.show()

# Calculate cumulative sum of sales for each product
ts_df_cumulative = ts_df.withColumn("CumulativeSales",
    sum("Sales").over(window_ts.rowsBetween(Window.unboundedPreceding, Window.currentRow)))

print("\nTime Series with Cumulative Sales:")
ts_df_cumulative.show()

# spark.stop() # Uncomment if this is the end of your script

Original Time Series DataFrame:
+--------+----------+-----+
| Product|      Date|Sales|
+--------+----------+-----+
|ProductA|2023-01-01|  100|
|ProductA|2023-02-01|  120|
|ProductA|2023-03-01|  110|
|ProductB|2023-01-01|  200|
|ProductB|2023-02-01|  230|
+--------+----------+-----+


Time Series with Previous Month Sales and MoM Growth:
+--------+----------+-----+------------------+--------------------+
| Product|      Date|Sales|PreviousMonthSales|          MoM_Growth|
+--------+----------+-----+------------------+--------------------+
|ProductA|2023-01-01|  100|              NULL|                NULL|
|ProductA|2023-02-01|  120|               100|                 0.2|
|ProductA|2023-03-01|  110|               120|-0.08333333333333333|
|ProductB|2023-01-01|  200|              NULL|                NULL|
|ProductB|2023-02-01|  230|               200|                0.15|
+--------+----------+-----+------------------+--------------------+


Time Series with Cumulative Sales:
+--------+-