# Week 02 - Spark Batch Processing with DataFrames

**Topics covered:**
1. Creating DataFrames (5 methods)
2. Writing DataFrames
3. Column transformations
4. Row transformations
5. Aggregations
6. Datetime functions
7. Custom schemas
8. Efficient Writes
   - Partitioned Writes
   - Write Modes and Idempotency
9. Complex types: Struct & Array
10. Complex types: Map
11. Pivot tables
12. Window functions
13. Joins
    - Reading the Query Plan
14. User-Defined Functions (UDFs)
15. Exercises


---
## 0. SparkSession Setup

`SparkContext` is the low-level entry point for RDD operations.  
`SparkSession` is the unified entry point for DataFrame and SQL operations, and wraps the SparkContext internally.  
Enabling Hive support lets Spark persist tables to a local Hive metastore (useful for `saveAsTable` and `spark.table()`).

In [None]:
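from pyspark.sql import SparkSession
import pyspark.sql.functions as F

# Build the session first; it creates (or reuses) the SparkContext internally,
# so builder options such as Hive support are applied consistently
spark = (
    SparkSession.builder
    .master("local[*]")        # run locally, using all available cores
    .appName("week02_practice")
    .enableHiveSupport()       # persist tables to local Hive metastore
    .getOrCreate()
)

# The low-level SparkContext is still available for RDD operations
sc = spark.sparkContext

print(spark.version)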

---
## 1. Creating DataFrames

There are five common ways to create a DataFrame in PySpark.  
In practice, reading from a file source (CSV, JSON, Parquet) is most common for batch pipelines.  
Reading from tables or SQL is useful when working in a shared catalog environment.

In [None]:
# Method 1: from an RDD
first_rdd = sc.parallelize([
    (1, "Batman"),
    (2, "Superman"),
    (3, "Spiderman")
])

first_df = spark.createDataFrame(first_rdd, ["id", "name"])
first_df.show()

In [None]:
# Method 2: from a file source (CSV, JSON, Parquet)
# CSV - specify delimiter, header, and schema inference
crimes_df = (
    spark.read
    .option("sep", "\t")       # tab-separated
    .option("header", True)
    .option("inferSchema", True)
    .csv("data/input/Chicago-Crimes-2018.csv")
)

# JSON - Spark always infers the schema (nested objects become structs) unless you pass one;
# by default it expects one JSON object per line (use .option("multiLine", True) otherwise)
events_df = spark.read.json("data/input/events-500k.json")

# Parquet - schema is embedded in the file, no options needed
sales_df = spark.read.parquet("data/input/sales.parquet")
users_df = spark.read.parquet("data/input/users.parquet")

events_df.printSchema()

In [None]:
# Method 3: from a catalog table
# First we need to write a table (see Section 2), then read it back
events_df.write.mode("overwrite").saveAsTable("event_table")

events_table_df = spark.table("event_table")
events_table_df.show(3)

In [None]:
# Method 4: from a SQL statement
# Any table registered in the catalog can be queried directly
events_sql_df = spark.sql("""
    SELECT device, event_name, CURRENT_DATE() AS cdate
    FROM event_table
    LIMIT 50
""")

events_sql_df.show(5)

In [None]:
# Method 5: from Row objects
# Row is Spark's generic record type; Row("id", "product", "cost") creates a reusable row class,
# and calling it assigns values to those fields by position
from pyspark.sql import Row

my_data = Row("id", "product", "cost")   # define a row template

rows_df = spark.createDataFrame([
    my_data(1, "mac",     1000),
    my_data(2, "windows",  500),
    my_data(3, "linux",    700)
])

rows_df.show()

---
## 2. Writing DataFrames

DataFrames can be persisted to files or to a catalog table.  
Parquet is the recommended columnar format for Spark batch pipelines - it supports efficient compression and predicate pushdown.  
Temporary views are session-scoped and are useful for SQL queries without touching disk.

In [None]:
# Write to Parquet with Snappy compression
# mode options: "overwrite", "append", "error" (default), "ignore"
(
    first_df.write
    .option("compression", "snappy")
    .mode("overwrite")
    .parquet("data/output/first.parquet")
)

# Verify by reading back
spark.read.parquet("data/output/first.parquet").show()

In [None]:
# Save as a managed Hive table - persists to disk across sessions
# Use spark.catalog.listTables() to see registered tables
(
    events_df.write
    .mode("overwrite")
    .saveAsTable("event_table")
)

print("Tables in catalog:", [t.name for t in spark.catalog.listTables()])

In [None]:
# Create a temporary SQL view - session-scoped, not written to disk
events_df.createOrReplaceTempView("events_view")

spark.sql("SELECT event_name, COUNT(*) as n FROM events_view GROUP BY event_name ORDER BY n DESC").show(5)

---
## 3. Column Transformations

Column transformations produce a new DataFrame - Spark DataFrames are immutable.  
There are multiple equivalent ways to reference a column: dot notation, `F.col()`, or string indexing.  
`.withColumn()` is the idiomatic way to add or replace a column without rewriting the whole select.

In [None]:
# Selecting columns - multiple equivalent styles
events_df.select("user_id", "device").show(3)

# F.col() allows chaining methods on the column object
events_df.select(F.col("user_id"), F.col("device")).show(3)

# Wildcard: select all top-level columns
events_df.select(F.col("*")).show(3)

# Nested struct fields via dot notation in the select string
events_df.select("user_id", "geo.city", "geo.state").show(3)

In [None]:
# selectExpr allows inline SQL expressions (CASE, IN, arithmetic, etc.)
events_df.selectExpr(
    "user_id",
    "device IN ('macOS', 'iOS') AS apple_user",
    "CASE WHEN device = 'Windows' THEN 'Microsoft' ELSE 'Other' END AS platform"
).show(5)

In [None]:
# Drop multiple columns at once
anonymous_df = events_df.drop("user_id", "geo", "device")
anonymous_df.printSchema()

In [None]:
# withColumn - add a new column or replace an existing one
mobile_df = events_df.withColumn("mobile", F.col("device").isin("iOS", "Android"))
mobile_df.select("device", "mobile").show(5)

In [None]:
# Rename a column
location_df = events_df.withColumnRenamed("geo", "location")
location_df.printSchema()

In [None]:
# Conditional column values with when / otherwise
# Equivalent to a CASE WHEN statement in SQL
warranty_df = events_df.select(
    "*",
    F.when(F.col("event_name") == "warranty", "issue")
     .when(F.col("event_name") == "cart", "sale")
     .otherwise("other")
     .alias("event_class")
)

warranty_df.select("event_name", "event_class").distinct().show()

---
## 4. Row Transformations

Row transformations reduce or reorder the rows in a DataFrame without changing its schema.  
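Filtering early shrinks the data every later step has to process, and for sources such as Parquet Spark can push the predicates down into the scan itself (shown as `PushedFilters` in the plan), making it one of the most impactful optimisations.  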
`dropDuplicates(subset)` is more useful than `distinct()` when you want uniqueness on a subset of columns.
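The difference is easiest to see on a few toy rows. In this plain-Python sketch (the `distinct` and `drop_duplicates` helpers are hypothetical, not Spark's implementation), `distinct` compares whole rows, while `drop_duplicates` compares only the subset columns and keeps the first row it meets for each key. Spark also keeps one row per key, but across partitions which row survives is not guaranteed.

```python
def distinct(rows):
    """Keep one copy of each fully identical row (like DataFrame.distinct())."""
    seen, out = set(), []
    for row in rows:
        key = tuple(sorted(row.items()))
        if key not in seen:
            seen.add(key)
            out.append(row)
    return out


def drop_duplicates(rows, subset):
    """Keep one row per combination of the subset columns (like dropDuplicates(subset))."""
    seen, out = set(), []
    for row in rows:
        key = tuple(row[c] for c in subset)
        if key not in seen:
            seen.add(key)
            out.append(row)
    return out


events = [
    {"event_name": "cart", "device": "iOS"},
    {"event_name": "cart", "device": "iOS"},
    {"event_name": "cart", "device": "Android"},
    {"event_name": "login", "device": "iOS"},
]

print(len(distinct(events)))                         # 3: only the exact duplicate is removed
print(len(drop_duplicates(events, ["event_name"])))  # 2: one row per event_name
```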

In [None]:
# Filter using a SQL string expression
purchases_df = events_df.filter("ecommerce.total_item_quantity > 0")
purchases_df.show(3)

# Filter using column objects - allows combining multiple conditions
revenue_df = events_df.filter(
    (F.col("ecommerce.purchase_revenue_in_usd").isNotNull()) &
    (F.col("ecommerce.total_item_quantity") > 1)
)
revenue_df.show(3)

In [None]:
# distinct() removes fully duplicate rows
# dropDuplicates(subset) keeps one row per unique value combination (which row is kept is not guaranteed)
distinct_event_names_df = events_df.dropDuplicates(["event_name"])
distinct_event_names_df.select("event_name").show()

In [None]:
# Limit to the first n rows
events_df.limit(5).show()

In [None]:
# Sort ascending (default)
events_df.sort("event_timestamp").show(3)

# Sort descending using .desc()
events_df.sort(F.col("event_timestamp").desc()).show(3)

# Multi-column sort - orderBy is an alias for sort
events_df.orderBy(["device", "event_timestamp"]).show(3)

---
## 5. Aggregations

`groupBy()` returns a `GroupedData` object; you chain an aggregation function to get a new DataFrame.  
The `.agg()` method lets you compute multiple aggregations in a single pass over the data.  
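Use `F.approx_count_distinct()` instead of `F.countDistinct()` on large datasets: it is based on the HyperLogLog++ sketch, is much faster, and its error is bounded by a maximum relative standard deviation of 5% by default (tunable through its `rsd` argument).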

In [None]:
# count per group
events_df.groupBy("event_name").count().orderBy(F.desc("count")).show()

In [None]:
# average - the output column is auto-named after the aggregate function (e.g. avg(...));
# use .agg() with .alias() for a readable name
events_df.groupBy("geo.state").avg("ecommerce.purchase_revenue_in_usd").show(5)

In [None]:
# sum with multiple grouping keys
events_df.groupBy("geo.state", "geo.city").sum("ecommerce.total_item_quantity").show(5)

In [None]:
# agg() - multiple aggregations in one step, with aliases
state_agg_df = events_df.groupBy("geo.state").agg(
    F.sum("ecommerce.total_item_quantity").alias("total_items"),
    F.avg("ecommerce.purchase_revenue_in_usd").alias("avg_revenue"),
    F.approx_count_distinct("user_id").alias("distinct_users")
)

state_agg_df.show(5)

---
## 6. Datetime Functions

Spark timestamps are stored as `TimestampType` (microseconds since epoch internally).  
The events dataset stores `event_timestamp` as a long integer in microseconds - divide by 1e6 before casting.  
Prefer Spark's built-in datetime functions (`F.date_format`, `F.year`, etc.) over Python's `datetime`: using `datetime` would require a UDF, which the Catalyst optimizer cannot look inside and which ships the data to a Python worker process.
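To make the unit conversion concrete, here is a plain-Python sketch (no Spark needed) of the same arithmetic: microseconds are divided by 10^6 to get seconds since the epoch. It also maps Python's `weekday()` onto the numbering Spark's `dayofweek` uses (1 = Sunday). The helpers `micros_to_utc` and `spark_dayofweek` are hypothetical, not part of PySpark, and the result is shown in UTC, whereas Spark displays timestamps in the session time zone.

```python
from datetime import date, datetime, timezone


def micros_to_utc(micros: int) -> datetime:
    """Hypothetical helper: epoch microseconds -> timezone-aware UTC datetime."""
    # Same arithmetic as (F.col("event_timestamp") / 1e6).cast("timestamp")
    return datetime.fromtimestamp(micros / 1e6, tz=timezone.utc)


def spark_dayofweek(d: date) -> int:
    """Hypothetical helper: Python weekday() (Mon=0..Sun=6) -> Spark dayofweek (Sun=1..Sat=7)."""
    return (d.weekday() + 1) % 7 + 1


one_day_us = 24 * 60 * 60 * 1_000_000      # one day expressed in microseconds
print(micros_to_utc(one_day_us))            # 1970-01-02 00:00:00+00:00
print(spark_dayofweek(date(2024, 1, 7)))    # 1 (7 January 2024 was a Sunday)
```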

In [None]:
# Cast epoch microseconds to a proper timestamp
timestamp_df = events_df.withColumn(
    "event_ts",
    (F.col("event_timestamp") / 1e6).cast("timestamp")
)

timestamp_df.select("event_timestamp", "event_ts").show(3)

In [None]:
# Format a timestamp to a human-readable string
# Pattern letters follow Spark's datetime pattern reference (close to Java's DateTimeFormatter):
# MMMM = full month name, HH = 24h hour
format_df = (
    timestamp_df
    .withColumn("date_str", F.date_format("event_ts", "MMMM dd, yyyy"))
    .withColumn("time_str", F.date_format("event_ts", "HH:mm:ss"))
)

format_df.select("event_ts", "date_str", "time_str").show(3, truncate=False)

In [None]:
# Extract individual date/time parts
datetime_df = (
    timestamp_df
    .withColumn("year",      F.year("event_ts"))
    .withColumn("month",     F.month("event_ts"))
    .withColumn("dayofweek", F.dayofweek("event_ts"))  # 1 = Sunday
    .withColumn("hour",      F.hour("event_ts"))
    .withColumn("minute",    F.minute("event_ts"))
)

datetime_df.select("event_ts", "year", "month", "dayofweek", "hour", "minute").show(3)

In [None]:
# to_date() truncates to just the date part
date_df = timestamp_df.withColumn("date", F.to_date("event_ts"))
date_df.select("event_ts", "date").show(3)

In [None]:
# date_add() / date_sub() do calendar arithmetic and return a DATE (the time of day is dropped)
# To keep the time part, use interval arithmetic instead: selectExpr("event_ts + interval 2 days")
plus_df = timestamp_df.withColumn("plus_two_days", F.date_add("event_ts", 2))
plus_df.select("event_ts", "plus_two_days").show(3)

---
## 7. Custom Schemas

By default Spark *infers* the schema by scanning the data, which is expensive for large files.  
Providing an explicit schema makes reads faster and gives you full control over column names and types.  
Inferring a schema from nested Python dicts (as `spark.createDataFrame` does below) also maps them to `MapType`, which is harder to work with than `StructType` - as we'll see in Section 10.

In [None]:
# Fetch a sample JSON dataset (airlines delay statistics)
import requests

r = requests.get("https://corgis-edu.github.io/corgis/datasets/json/airlines/airlines.json")
print(r.json()[:1])   # preview first record

In [None]:
# Inferred schema: Python dicts become MapType columns - hard to query
airlines_map_df = spark.createDataFrame(r.json())
airlines_map_df.printSchema()

In [None]:
# Method 1: StructType + StructField objects
# StructField(name, dataType, nullable)
from pyspark.sql.types import (
    StructType, StructField, StringType, IntegerType, LongType
)

airport_schema = StructType([
    StructField("Airport", StructType([
        StructField("Code", StringType(), True),
        StructField("Name", StringType(), True)
    ]), True),
    StructField("Statistics", StructType([
        StructField("Carriers", StructType([
            StructField("Names", StringType(), True),
            StructField("Total", IntegerType(), True)
        ]), True),
        StructField("Flights", StructType([
            StructField("Delayed",   LongType(), True),
            StructField("Cancelled", LongType(), True),
            StructField("On Time",   LongType(), True),
            StructField("Total",     LongType(), True)
        ]), True)
    ]), True),
    StructField("Time", StructType([
        StructField("Label",      StringType(),  True),
        StructField("Month",      IntegerType(), True),
        StructField("Year",       IntegerType(), True),
        StructField("Month Name", StringType(),  True)
    ]), True)
])

airlines_df = spark.createDataFrame(r.json(), schema=airport_schema)
airlines_df.printSchema()

In [None]:
# Method 2: StructType .add() chaining - more readable for large schemas
airport_add_schema = (
    StructType()
    .add("Airport", StructType()
         .add("Code", StringType())
         .add("Name", StringType()))
    .add("Time", StructType()
         .add("Month", IntegerType())
         .add("Year",  IntegerType()))
)

airlines_add_df = spark.createDataFrame(r.json(), schema=airport_add_schema)
airlines_add_df.show(3)

In [None]:
# Method 3: DDL string - concise, good for simple schemas
# Columns absent from the data will appear as null
airport_string_schema = "Airport STRUCT<Code: STRING, Name: STRING>, Time STRUCT<Month: INTEGER, Year: INTEGER>"

airlines_str_df = spark.createDataFrame(r.json(), schema=airport_string_schema)
airlines_str_df.show(3)

In [None]:
# With a StructType schema, nested fields are easy to access with dot notation
airlines_df.select(
    "Airport.Code",
    "Airport.Name",
    "Time.Year",
    "Time.Month Name",
    "Statistics.Flights.Delayed",
    "Statistics.Flights.Cancelled"
).show(5)

---
## 8. Efficient Writes

### Partitioned Writes
When you write with `.partitionBy()`, Spark creates a folder hierarchy on disk (e.g. `year=2003/month=6/`).  
This is how batch tables are organised so downstream queries can skip irrelevant folders - a technique called *partition pruning*.  
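The partition key should match the most common filter column and have low-to-moderate cardinality: partitioning on something like a user ID would create a huge number of small files. Writing is slightly slower (sort + separate files per partition), but reads filtered on the key become dramatically faster.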
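The folder layout and the pruning step can be modelled in a few lines of plain Python. This is a simplified sketch, not Spark's implementation: the hypothetical `partition_path` builds the Hive-style `col=value` directories that `.partitionBy()` writes, and the hypothetical `prune` keeps only the directories whose values match an equality filter, which is what `PartitionFilters` in the scan node lets Spark do. The years used are toy values.

```python
from itertools import product


def partition_path(base: str, **keys) -> str:
    """Build a Hive-style partition directory, e.g. base/year=2006/month=1."""
    return base + "/" + "/".join(f"{k}={v}" for k, v in keys.items())


def prune(paths, **wanted):
    """Keep only directories whose key=value segments match every equality filter."""
    kept = []
    for p in paths:
        segments = dict(s.split("=", 1) for s in p.split("/") if "=" in s)
        if all(segments.get(k) == str(v) for k, v in wanted.items()):
            kept.append(p)
    return kept


base = "data/output/airlines_partitioned.parquet"
paths = [partition_path(base, year=y, month=m) for y, m in product([2005, 2006], range(1, 13))]

print(len(paths))                          # 24 partition folders on disk
print(prune(paths, year=2006, month=1))    # only one folder needs to be read
print(len(prune(paths, year=2005)))        # 12: filtering on year alone still skips half
```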


In [None]:
# Flatten the nested airlines_df into a wide table before partitioning
airlines_flat_df = airlines_df.select(
    F.col("Airport.Code").alias("airport_code"),
    F.col("Airport.Name").alias("airport_name"),
    F.col("Time.Year").alias("year"),
    F.col("Time.Month").alias("month"),
    F.col("Statistics.Flights.Delayed").alias("flights_delayed"),
    F.col("Statistics.Flights.Cancelled").alias("flights_cancelled"),
    F.col("Statistics.Flights.Total").alias("flights_total")
)

(
    airlines_flat_df.write
    .mode("overwrite")
    .partitionBy("year", "month")
    .parquet("data/output/airlines_partitioned.parquet")
)

airlines_flat_df.show(3)


In [None]:
# Each sub-folder is a partition - Spark reads only the folders matching your WHERE clause
!find data/output/airlines_partitioned.parquet -type d | head -20


In [None]:
# Read back with a filter - look for PartitionFilters in the scan node
# Spark skips all folders except year=2006/month=1
(
    spark.read.parquet("data/output/airlines_partitioned.parquet")
    .filter("year = 2006 AND month = 1")
    .explain(True)
)


In [None]:
# Compare: the same filter on the in-memory airlines_flat_df (not read back from the partitioned files)
# There is no PartitionFilters entry here - every row is scanned and then filtered
(
    airlines_flat_df
    .filter("year = 2006 AND month = 1")
    .explain(True)
)


### Write Modes and Idempotency

If your job fails and you rerun it, what happens?  
With `append` mode you get duplicate data; with `overwrite` mode the second run replaces the first - same input produces the same output.  
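This property is called *idempotency*: the job is safe to rerun without side-effects. In production batch pipelines, prefer overwriting the specific partition you are (re)loading over appending. Note that a plain `overwrite` on a partitioned path replaces *all* existing partitions by default; setting `spark.sql.sources.partitionOverwriteMode` to `dynamic` makes it replace only the partitions present in the DataFrame being written.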
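The difference between the write modes can be modelled with a dictionary that maps partition keys to rows. This plain-Python sketch (the `write` helper and the toy rows are hypothetical) runs the same load twice under each mode: `append` duplicates the partition, a static `overwrite` wipes partitions it was not given, and a dynamic overwrite, corresponding to `spark.sql.sources.partitionOverwriteMode` set to `dynamic`, replaces only the partition being loaded.

```python
import copy


def write(table: dict, batch: dict, mode: str, dynamic: bool = False) -> dict:
    """Apply one write of `batch` to `table`; both map partition key -> list of rows."""
    if mode == "append":
        for part, rows in batch.items():
            table.setdefault(part, []).extend(rows)
    elif mode == "overwrite" and not dynamic:
        table.clear()                      # static overwrite: the whole output path is replaced
        table.update({p: list(r) for p, r in batch.items()})
    elif mode == "overwrite":
        for part, rows in batch.items():   # dynamic overwrite: only partitions in the batch
            table[part] = list(rows)
    else:
        raise ValueError(f"unsupported mode: {mode}")
    return table


existing = {(2003, 5): ["may row"]}        # a partition loaded by an earlier run
june = {(2003, 6): ["june row"]}           # today's load, executed twice (a rerun)

results = {}
for label, mode, dynamic in [("append", "append", False),
                             ("static overwrite", "overwrite", False),
                             ("dynamic overwrite", "overwrite", True)]:
    table = copy.deepcopy(existing)
    for _ in range(2):
        write(table, june, mode, dynamic)
    results[label] = table
    print(f"{label:18} {table}")
```

Both overwrite variants are idempotent (the second run changes nothing), but only the dynamic one leaves the earlier May partition intact.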


In [None]:
# Danger: append mode creates duplicates on rerun
small_df = airlines_flat_df.filter("year = 2003 AND month = 6")

small_df.write.mode("append").parquet("data/output/append_demo")
small_df.write.mode("append").parquet("data/output/append_demo")
print("Append twice, row count:", spark.read.parquet("data/output/append_demo").count())
# Expected on a fresh output path: 2x the original count (every rerun of this cell adds another copy)

# Safe: overwrite mode is idempotent
small_df.write.mode("overwrite").parquet("data/output/overwrite_demo")
small_df.write.mode("overwrite").parquet("data/output/overwrite_demo")
print("Overwrite twice, row count:", spark.read.parquet("data/output/overwrite_demo").count())
# Expected: same as original count


---
## 9. Complex Types: Struct & Array

Spark natively handles nested data types: `StructType` (record with named fields), `ArrayType` (ordered list), and `MapType` (key-value pairs).  
`explode()` turns each array element into its own row, which is essential for flattening nested datasets.  
`explode_outer()` also keeps rows whose array is null or empty, emitting a single row with a null element (the DataFrame counterpart of SQL's `LATERAL VIEW OUTER`).
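A plain-Python model of the two functions makes the difference in row counts easy to see. The `explode_rows` helper and the toy rows below are hypothetical: rows with an empty or null array disappear under the `explode()` behaviour but survive, with a null element, under the `explode_outer()` behaviour.

```python
def explode_rows(rows, col, outer=False):
    """Emit one row per element of row[col]; mimics explode / explode_outer."""
    out = []
    for row in rows:
        items = row.get(col)
        if items:                              # non-null, non-empty array
            out.extend({**row, col: item} for item in items)
        elif outer:                            # explode_outer: keep the row, element is null
            out.append({**row, col: None})
    return out


sales = [
    {"email": "a@example.com", "items": ["Standard Queen Mattress", "Premium Pillow"]},
    {"email": "b@example.com", "items": []},
    {"email": "c@example.com", "items": None},
]

print(len(explode_rows(sales, "items")))              # 2
print(len(explode_rows(sales, "items", outer=True)))  # 4
```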

In [None]:
# The sales dataset has an 'items' array column
sales_df.printSchema()

In [None]:
# explode() - one row per array element; rows with null/empty array are dropped
# explode_outer() - same but keeps rows where the array is null or empty (the element becomes null)
details_df = (
    sales_df
    .withColumn("items", F.explode("items"))           # flatten the items array
    .select("email", "items.item_name")
    .withColumn("details", F.split(F.col("item_name"), " "))  # split string → array
)

details_df.show(5)

In [None]:
# array_contains() - boolean check
# element_at()    - index access (1-based!)
mattress_df = (
    details_df
    .filter(F.array_contains(F.col("details"), "Mattress"))
    .withColumn("size",    F.element_at(F.col("details"), 2))
    .withColumn("quality", F.element_at(F.col("details"), 1))
)

pillow_df = (
    details_df
    .filter(F.array_contains(F.col("details"), "Pillow"))
    .withColumn("size",    F.element_at(F.col("details"), 1))
    .withColumn("quality", F.element_at(F.col("details"), 2))
)

mattress_df.show(3)
pillow_df.show(3)

In [None]:
# unionByName() aligns columns by name, not by position
# Safer than union() when schemas might have columns in different orders
union_df = mattress_df.unionByName(pillow_df).drop("details")
union_df.show(5)

In [None]:
# collect_set() - aggregation that returns an array of distinct values per group
options_df = (
    union_df.groupBy("email")
    .agg(
        F.collect_set("size").alias("size_options"),
        F.collect_set("quality").alias("quality_options")
    )
)

options_df.show(5, truncate=False)

---
## 10. Complex Types: Map

`MapType` columns store key-value pairs (like Python dicts).  
They appear when Spark infers a schema from nested Python dicts, which is exactly what happened with `spark.createDataFrame(r.json())` in Section 7 (`spark.read.json`, by contrast, infers nested JSON objects as `StructType`).  
In production, prefer converting maps to structs via a custom schema for type safety and easier access.

In [None]:
# The inferred schema from the airlines JSON produces MapType columns
airlines_map_df.printSchema()

In [None]:
# Three ways to extract a value from a MapType column
airlines_map_df.select(
    F.col("Airport").getItem("Code").alias("code_getItem"),  # method 1: .getItem()
    F.col("Airport")["Code"].alias("code_bracket"),           # method 2: bracket notation
    F.col("Airport.Code").alias("code_dot"),                  # method 3: dot notation
).show(5)

In [None]:
# map_keys() and map_values() return all keys/values as arrays
airlines_map_df.select(
    F.map_keys("Airport").alias("keys"),
    F.map_values("Airport").alias("values")
).show(3, truncate=False)

---
## 11. Pivot Tables

Pivot reshapes data from a *long* format (one row per observation) to a *wide* format (one column per category value).  
The syntax is: `.groupBy(row_keys).pivot(column_key).agg(value)`.  
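Pivot can be expensive on high-cardinality columns. Pass the list of expected values as the second argument, e.g. `.pivot("accommodates", [1, 2, 3, 4])`: this caps the number of output columns and saves Spark the extra job it otherwise runs to compute the distinct values.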
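The long-to-wide reshape itself can be sketched in plain Python. `pivot_mean` below is a hypothetical helper and the listings are made-up toy values: it averages `value_key` for each (row key, pivot value) pair, fills cells that have no data, and shows that passing an explicit `values` list limits the output columns.

```python
from collections import defaultdict


def pivot_mean(rows, row_key, pivot_key, value_key, values=None, fill=None):
    """Long -> wide: one output row per row_key, one column per pivot value, cell = mean."""
    if values is None:
        # Without an explicit list, the distinct pivot values must be found first
        values = sorted({r[pivot_key] for r in rows})
    sums, counts = defaultdict(float), defaultdict(int)
    wide = {}
    for r in rows:
        wide.setdefault(r[row_key], {v: fill for v in values})
        if r[pivot_key] in values:
            cell = (r[row_key], r[pivot_key])
            sums[cell] += r[value_key]
            counts[cell] += 1
    for (rk, pv), total in sums.items():
        wide[rk][pv] = total / counts[(rk, pv)]
    return wide


listings = [
    {"neighbourhood": "Centrum-West", "accommodates": 2, "price": 100.0},
    {"neighbourhood": "Centrum-West", "accommodates": 2, "price": 140.0},
    {"neighbourhood": "Centrum-West", "accommodates": 4, "price": 200.0},
    {"neighbourhood": "Noord-Oost",   "accommodates": 2, "price": 80.0},
]

print(pivot_mean(listings, "neighbourhood", "accommodates", "price", fill=0))
# {'Centrum-West': {2: 120.0, 4: 200.0}, 'Noord-Oost': {2: 80.0, 4: 0}}
print(pivot_mean(listings, "neighbourhood", "accommodates", "price", values=[2]))
# {'Centrum-West': {2: 120.0}, 'Noord-Oost': {2: 80.0}}
```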

In [None]:
# Airbnb Amsterdam listings - average price by neighbourhood and number of guests
airbnb_df = spark.read.parquet("data/input/amsterdam-listings-2018-12-06.parquet")
airbnb_df.select("city", "neighbourhood", "accommodates", "price").show(3)

In [None]:
# Pivot: rows = neighbourhood, columns = accommodates count, values = mean price
(
    airbnb_df
    .select("city", "neighbourhood", "accommodates", "price")
    .groupBy("city", "neighbourhood")
    .pivot("accommodates")          # one column per distinct value
    .mean("price")
    .na.fill(0)
    .orderBy(F.desc("2"))           # order by the '2 guests' column
).show(10, truncate=False)

In [None]:
# Wikipedia pageviews - total requests by date and hour, split by site
wiki_df = spark.read.parquet("data/input/pageviews_by_second.parquet")
wiki_df.show(3)

In [None]:
(
    wiki_df
    .selectExpr(
        "cast(timestamp as date) AS date",
        "hour(timestamp) AS hour",
        "site",
        "requests"
    )
    .groupBy("date", "hour")
    .pivot("site")                  # one column per site (desktop / mobile)
    .sum("requests")
    .orderBy("date", "hour")
).show(10)

---
## 12. Window Functions

Window functions compute a result for each row based on a *window* of related rows - without collapsing the DataFrame like `groupBy` does.  
A `WindowSpec` defines the partition (equivalent to `PARTITION BY`) and the ordering within each partition.  
Common use cases: ranking, running totals, moving averages, and comparing each row to its predecessor.
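The three ranking functions only differ when the ordering column has ties, and `lag()` returns null for the first row of a partition unless a default is given. This plain-Python sketch of a single, already-sorted partition (the `rankings` and `lag` helpers and the heart-rate values are hypothetical) shows both behaviours.

```python
def rankings(sorted_values):
    """Return (value, row_number, rank, dense_rank) for one sorted window partition."""
    out, rank, dense, prev = [], 0, 0, object()   # object() never equals a real value
    for row_number, v in enumerate(sorted_values, start=1):
        if v != prev:            # new value: rank jumps to the row number, dense_rank adds 1
            rank, dense, prev = row_number, dense + 1, v
        out.append((v, row_number, rank, dense))
    return out


def lag(values, offset=1, default=None):
    """Value `offset` rows back within the partition, or `default` if there is none."""
    return [values[i - offset] if i >= offset else default for i in range(len(values))]


heart_rates = [60, 62, 62, 65]
for row in rankings(heart_rates):
    print(row)
# (60, 1, 1, 1)
# (62, 2, 2, 2)
# (62, 3, 2, 2)
# (65, 4, 4, 3)

previous = lag(heart_rates)
deltas = [None if p is None else hr - p for hr, p in zip(heart_rates, previous)]
print(deltas)   # [None, 2, 0, 3]
```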

In [None]:
# Load the healthcare dataset - daily measurements per patient
from pyspark.sql import Window

healthcare_df = spark.read.parquet("data/input/health_profile_data.snappy.parquet")
healthcare_df.show(5)

In [None]:
# Define a window: partition by patient, order by date
window_by_date = Window.partitionBy("_id").orderBy("dte")

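# row_number() - unique sequential number per partition (no ties)
# rank()       - same rank for ties, then skips numbers
# dense_rank() - same rank for ties, no gaps
# If each patient has at most one row per date there are no ties on dte, so all three
# columns will match; order by a column with repeated values to see them diverge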
(
    healthcare_df
    .withColumn("row_num",    F.row_number().over(window_by_date))
    .withColumn("rank",       F.rank().over(window_by_date))
    .withColumn("dense_rank", F.dense_rank().over(window_by_date))
    .select("_id", "dte", "resting_heartrate", "row_num", "rank", "dense_rank")
).show(8)

In [None]:
# lag()  - access a previous row's value (offset rows back)
# lead() - access a future row's value  (offset rows ahead)
# Useful for computing deltas between consecutive measurements.
# No default is passed, so the first/last row of each patient gets null rather than 0
# (a 0 default would make the first delta equal to the full heart rate)
(
    healthcare_df
    .withColumn("prev_hr",  F.lag("resting_heartrate", 1).over(window_by_date))
    .withColumn("next_hr",  F.lead("resting_heartrate", 1).over(window_by_date))
    .withColumn("delta_hr", F.expr("resting_heartrate - prev_hr"))
    .select("_id", "dte", "resting_heartrate", "prev_hr", "next_hr", "delta_hr")
).show(8)

In [None]:
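# Rolling window frames using rowsBetween()
# Window.unboundedPreceding = from the start of the partition
# Window.currentRow         = up to and including the current row
# rowsBetween(-6, currentRow) counts ROWS, so it equals "the last 7 days" only if each
# patient has exactly one row per day with no gaps; for calendar-based frames, order by
# a numeric day value and use rangeBetween() instead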

window_cumulative  = Window.partitionBy("_id").orderBy("dte").rowsBetween(Window.unboundedPreceding, Window.currentRow)
window_last_7_days = Window.partitionBy("_id").orderBy("dte").rowsBetween(-6, Window.currentRow)

(
    healthcare_df
    .withColumn("cumulative_avg_hr",    F.avg("resting_heartrate").over(window_cumulative))
    .withColumn("rolling_7day_avg",     F.avg("resting_heartrate").over(window_last_7_days))
    .withColumn("rolling_7day_max_bmi", F.max("BMI").over(window_last_7_days))
    .select("_id", "dte", "resting_heartrate", "cumulative_avg_hr", "rolling_7day_avg", "rolling_7day_max_bmi")
).show(10)
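Because `rowsBetween(-6, Window.currentRow)` counts rows, it only matches a 7-day window when there is exactly one row per day. This plain-Python sketch (the `rolling_rows` and `rolling_days` helpers and the data are hypothetical) compares a row-based frame with a date-based one for a patient who skipped a week; in Spark, the date-based version would typically order by a numeric day column and use `rangeBetween(-6, 0)`.

```python
from datetime import date, timedelta


def rolling_rows(values, n):
    """Mean over the current row and the n-1 rows before it (like rowsBetween(-(n-1), 0))."""
    out = []
    for i in range(len(values)):
        frame = values[max(0, i - n + 1): i + 1]
        out.append(sum(frame) / len(frame))
    return out


def rolling_days(dates, values, days):
    """Mean over rows dated within the last `days` calendar days, current day included."""
    out = []
    for d in dates:
        frame = [v for dd, v in zip(dates, values) if d - timedelta(days=days - 1) <= dd <= d]
        out.append(sum(frame) / len(frame))
    return out


dates = [date(2024, 1, 1), date(2024, 1, 2), date(2024, 1, 10)]   # gap after 2 January
heart_rates = [60, 70, 80]

print(rolling_rows(heart_rates, 7))          # [60.0, 65.0, 70.0]
print(rolling_days(dates, heart_rates, 7))   # [60.0, 65.0, 80.0]
```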

---
## 13. Joins

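Spark supports all standard SQL join types: `inner`, `left`, `right`, `outer`, `cross`, plus `left_semi` and `left_anti`.  
When you join on an expression such as `df1.user_id == df2.user_id`, both key columns are kept in the result; joining on a column name (`"email"`) or a list of names keeps a single copy. For other column names the two sides share, use `.drop()` or rename one side before joining.  
A *broadcast join* avoids an expensive shuffle by sending a copy of the smaller DataFrame to every executor. Spark does this automatically when its size estimate for one side is below `spark.sql.autoBroadcastJoinThreshold` (10 MB by default), and you can force it with a broadcast hint.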
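Conceptually, a broadcast hash join has two phases: build a hash table from the small side (the copy each executor receives), then probe it with every row of the large side, which therefore never has to be shuffled. The plain-Python sketch below is a simplified single-machine model of an inner join on one key; `broadcast_hash_join` and the toy rows are hypothetical. Merging each matched pair into one dict also mirrors joining on a column name: the key appears once.

```python
def broadcast_hash_join(large_rows, small_rows, key):
    """Inner join of two lists of dicts on `key`, hashing the small side."""
    # Build phase: hash table over the small side (this is what gets broadcast)
    lookup = {}
    for s in small_rows:
        lookup.setdefault(s[key], []).append(s)
    # Probe phase: each large-side row is matched locally, in a single pass
    return [{**r, **s} for r in large_rows for s in lookup.get(r[key], [])]


events = [
    {"event_name": "cart", "device": "iOS"},
    {"event_name": "login", "device": "macOS"},
    {"event_name": "unknown_event", "device": "Android"},
]
event_types = [
    {"event_name": "cart", "event_type": "purchase"},
    {"event_name": "login", "event_type": "initial"},
]

for row in broadcast_hash_join(events, event_types, "event_name"):
    print(row)
# {'event_name': 'cart', 'device': 'iOS', 'event_type': 'purchase'}
# {'event_name': 'login', 'device': 'macOS', 'event_type': 'initial'}
```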

In [None]:
# Outer join: all users, flagging which ones converted (made a purchase)
converted_users_df = (
    sales_df
    .select("email")
    .distinct()
    .withColumn("converted", F.lit(True))
)

conversions_df = (
    users_df
    .join(converted_users_df, "email", "outer")
    .filter(F.col("email").isNotNull())
    .na.fill(False)                    # fill null booleans → False
)

conversions_df.show(5)

In [None]:
# Left join: attach cart history (may not exist for all users)
carts_df = (
    events_df
    .withColumn("items", F.explode("items"))
    .groupBy("user_id")
    .agg(F.collect_set("items.item_id").alias("cart"))
)

email_carts_df = conversions_df.join(carts_df, "user_id", "left")
email_carts_df.show(5)

In [None]:
# Broadcast join - forces the small DataFrame to be broadcast to each executor
# Avoids a full shuffle of the large DataFrame

# Check the auto-broadcast threshold (default: 10 MB)
print("autoBroadcastJoinThreshold:", spark.conf.get("spark.sql.autoBroadcastJoinThreshold"))

# Small lookup: event name → event type
event_type_df = (
    events_df.select("event_name").distinct()
    .withColumn(
        "event_type",
        F.when(F.col("event_name").isin("register", "login"), "initial")
         .when(F.col("event_name").isin("checkout", "cart", "finalize"), "purchase")
         .otherwise("other")
    )
)

# Explicit broadcast hint on the small DataFrame (F.broadcast(event_type_df) is equivalent)
events_enriched_df = events_df.join(
    event_type_df.hint("broadcast"),
    "event_name"
)

events_enriched_df.select("event_name", "event_type", "device").show(5)

### Reading the Query Plan

`explain(True)` shows four plan levels: Parsed, Analyzed, Optimized, and Physical.  
In practice you care most about the **Physical plan** - look for `BroadcastHashJoin` vs `SortMergeJoin`, pushed filters (`PushedFilters` in the scan node), and which columns are actually read (`ReadSchema`).  
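This is how you verify the optimizer is doing what you expect before the job runs. With Adaptive Query Execution (enabled by default since Spark 3.2), the physical plan is wrapped in an `AdaptiveSparkPlan` node and the join strategy can still change at runtime, once actual data sizes are known.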


In [None]:
# Inspect the physical plan for the broadcast join we just built
events_enriched_df.explain(True)


In [None]:
# Compare: without the broadcast hint, Spark may choose SortMergeJoin,
# which is more expensive for small dimension tables.
# Locally, Spark may still broadcast event_type_df automatically (size threshold or AQE);
# on a cluster with large tables the difference matters.
events_unoptimized = (
    events_df.select("*")               # redundant; Catalyst usually prunes unused columns anyway (check ReadSchema)
    .join(event_type_df, "event_name")  # no broadcast hint on the small lookup table
    .select("event_name", "event_type", "device")
)

events_unoptimized.explain(True)


---
## 14. User-Defined Functions (UDFs)

UDFs let you apply arbitrary Python logic to DataFrame columns.  
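**Caveat:** a Python UDF runs in a separate Python worker process: rows are serialized out of the JVM, passed to your function one at a time, and serialized back, and the Catalyst optimizer cannot see inside it. UDFs are therefore significantly slower than built-in functions.  
Always prefer built-in `pyspark.sql.functions` over UDFs; use UDFs only when no built-in equivalent exists (and when you do, consider a vectorized `F.pandas_udf`, which processes Arrow batches instead of single rows).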
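Since a UDF body is an ordinary Python function, it can be unit-tested without Spark, including the `None` that Spark passes in for a SQL NULL. The sketch below uses a hypothetical `email_domain` helper; for this particular task a built-in already exists (`F.substring_index("email", "@", -1)` returns everything after the last `@`), so in a real pipeline no UDF would be needed.

```python
from typing import Optional


def email_domain(email: Optional[str]) -> Optional[str]:
    """Return the part after the last '@' (the whole string if there is none); None stays None."""
    if email is None:            # Spark hands SQL NULL to a Python UDF as None
        return None
    return email.rsplit("@", 1)[-1]


print(email_domain("jane.doe@example.com"))   # example.com
print(email_domain(None))                     # None
```

Once tested, such a function is wrapped with `F.udf(...)` in the same way as `first_letter` in the next cell.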

In [None]:
# Step 1: define a regular Python function
# Spark passes SQL NULL values in as None, so guard against them
def first_letter(email):
    return email[0] if email else None

# Step 2: wrap it as a Spark UDF (default return type is StringType)
first_letter_udf = F.udf(first_letter)

# Step 3: apply it in a select or withColumn
# (built-in equivalent that needs no UDF: F.substring("email", 1, 1))
sales_df.select(first_letter_udf(F.col("email")).alias("first_letter")).show(5)

In [None]:
# Haversine distance UDF - computes great-circle distance in km
# Return type must be declared explicitly for non-string types
from pyspark.sql.types import DoubleType
import math

def haversine_km(lat1, lon1, lat2, lon2):
    R = 6371.0                         # Earth radius in km
    dlat = math.radians(lat2 - lat1)
    dlon = math.radians(lon2 - lon1)
    a = (math.sin(dlat / 2) ** 2
         + math.cos(math.radians(lat1))
         * math.cos(math.radians(lat2))
         * math.sin(dlon / 2) ** 2)
    return R * 2 * math.atan2(math.sqrt(a), math.sqrt(1 - a))

geo_distance_udf = F.udf(haversine_km, DoubleType())
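For reference, `haversine_km` implements the haversine formula, with latitudes $\varphi$ and longitudes $\lambda$ in radians and $R = 6371$ km. A quick hand check on one degree of longitude along the equator confirms the scale of the result, since $\operatorname{atan2}(\sqrt{a}, \sqrt{1-a}) = \arcsin\sqrt{a}$ for $0 \le a \le 1$:

$$
a = \sin^2\!\left(\frac{\Delta\varphi}{2}\right) + \cos\varphi_1 \cos\varphi_2 \sin^2\!\left(\frac{\Delta\lambda}{2}\right), \qquad
d = 2R\,\operatorname{atan2}\!\left(\sqrt{a}, \sqrt{1-a}\right)
$$

$$
\varphi_1 = \varphi_2 = 0,\ \Delta\lambda = \frac{\pi}{180}
\;\Rightarrow\; a = \sin^2\!\left(\frac{\pi}{360}\right),\quad
d = 2R \cdot \frac{\pi}{360} = \frac{6371\,\pi}{180} \approx 111.19\ \text{km}
$$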

In [None]:
# Countries dataset: ISO code, lat, lon, name
countries_df = (
    spark.read
    .option("header", "true")
    .option("inferSchema", "true")
    .csv("data/input/countries.csv")
)

# Cross join Estonia against every country (including itself, at distance 0) and compute distance
distance_df = (
    countries_df.filter(F.col("country") == "EE")
    .crossJoin(                          # explicit cross join: every pair of rows
        countries_df
        .toDF("join_country", "join_latitude", "join_longitude", "join_name")
        .na.drop()
    )
    .withColumn(
        "distance_km",
        geo_distance_udf("latitude", "longitude", "join_latitude", "join_longitude")
    )
    .select("join_name", "distance_km")
    .orderBy("distance_km")
)

distance_df.show(10)

---
## 15. Exercises

Work through the exercises below. Each exercise builds on datasets and patterns from this notebook.


### Exercise 1 - Diagnose and Fix This Pipeline

The pipeline below works but has performance problems. Your task:

1. Run the pipeline and call `.explain(True)` on the final result
2. Identify at least two problems in the query plan
3. Fix each problem and verify with `.explain(True)` that the plan improved
4. Write a brief comment above each fix explaining what you changed and why

Hints: think about join strategy, filter placement, and column selection.


In [None]:
# --- Intentionally suboptimal pipeline - diagnose and fix ---

# Step 1: read events and select all columns
events_raw = spark.read.option("inferSchema", True).json("data/input/events-500k.json")

events_all_columns = events_raw.select("*")

# Step 2: join with users - users table is small (~200k rows)
users_all = spark.read.parquet("data/input/users.parquet")

joined = events_all_columns.join(users_all, events_all_columns.user_id == users_all.user_id)

# Step 3: filter AFTER the join
result = (
    joined
    .filter(F.col("event_name") == "finalize")
    .filter(F.col("ecommerce.total_item_quantity") > 0)
    .select("email", "event_name", "ecommerce.purchase_revenue_in_usd")
    .groupBy("email")
    .agg(F.sum("purchase_revenue_in_usd").alias("total_revenue"))
    .orderBy(F.desc("total_revenue"))
)

result.explain(True)
result.show(10)


### Exercise 2 - Choose a Partition Key

You have the airlines dataset as a flat Parquet file. Three teams need to query it:

- **Team A** runs a daily dashboard filtered by `airport_code` showing monthly trends
- **Team B** runs monthly reports aggregating all airports for a given `year` and `month`
- **Team C** occasionally looks up a specific `airport_code + year + month` combination

You can only choose **one** partition layout.

1. Write the table with your chosen `.partitionBy(...)` key(s)
2. For each team, write a read query with a filter and call `.explain(True)`
3. Check which queries get `PartitionFilters` (partition pruning) and which don’t
4. In a comment, explain your choice: which team benefits, which team doesn’t, and why you made that tradeoff

There is no single right answer - the point is to justify your choice.


In [None]:
# Use airlines_flat_df from Section 8 - write it with your chosen partition key(s)
# airlines_flat_df.write.mode("overwrite").partitionBy(???).parquet("data/output/airlines_exercise")

# Team A query: filtered by airport_code, all months

# Team B query: filtered by year and month, all airports

# Team C query: filtered by airport_code + year + month


---
## Further Reading

- [PySpark SQL Functions API](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/functions.html)
- [Spark SQL Built-in Functions](https://spark.apache.org/docs/latest/api/sql/index.html)
- [Window Functions in PySpark](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/window.html)
- [UDF performance deep-dive](https://medium.com/quantumblack/spark-udf-deep-insights-in-performance-f0a95a4d8c62)
- [Spark SQL Performance Tuning](https://spark.apache.org/docs/latest/sql-performance-tuning.html)
