# Advanced PySpark Workshop: Performance, Reliability, Streaming, and ML

Purpose: End-to-end, runnable notebook that demonstrates high-impact PySpark patterns for production data engineering.

What you will learn:
- Where shuffles occur; how to control partitions and use broadcast joins
- How to enable AQE, mitigate skew, and read plans
- Schema/pushdown/pruning hygiene and small-files mitigation
- Reliable streaming with watermarks, output modes, and checkpointing
- Leakage-safe ML pipelines with validation and proper feature engineering

Assumptions:
- Local Spark available (Spark 3.2+ recommended). Delta Lake features are optional and shown as notes.
- Paths used are relative to this workspace under `tmp/advanced_demo/`.




In [None]:
# Setup Spark session and environment
import os
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.window import Window

# Workspace-relative root for every output and checkpoint written below
BASE = "tmp/advanced_demo"
os.makedirs(BASE, exist_ok=True)

# Session settings, applied to the builder in insertion order
_session_conf = {
    "spark.sql.shuffle.partitions": "16",  # small for local demos
    "spark.sql.adaptive.enabled": "true",
    "spark.sql.session.timeZone": "UTC",
    "spark.python.worker.reuse": "true",
}

_builder = SparkSession.builder.appName("AdvancedPySparkWorkshop").master("local[*]")
for _conf_key, _conf_value in _session_conf.items():
    _builder = _builder.config(_conf_key, _conf_value)
spark = _builder.getOrCreate()

spark.sparkContext.setLogLevel("WARN")

# Echo the settings the later cells rely on
print("Spark version:", spark.version)
for _label, _conf_key in (
    ("shuffle.partitions:", "spark.sql.shuffle.partitions"),
    ("AQE enabled:", "spark.sql.adaptive.enabled"),
):
    print(_label, spark.conf.get(_conf_key))


## Part 1: PySpark Speedrun — Narrow vs Wide, Shuffles, Partitions, Caching

Goals:
- Understand narrow vs wide transformations and where shuffles happen
- Read physical plans with `explain("formatted")`
- Tune `spark.sql.shuffle.partitions` and compare effects
- Cache correctly: cache → materialize → unpersist



In [None]:
# Build a small demo DataFrame: 10,000 rows spread over 100 user ids
user_id_col = (F.col("id") % 100).cast("int")
amount_col = (F.rand() * 100).cast("double")
run_date_col = F.date_format(F.current_timestamp(), "yyyy-MM-dd")

sales = (
    spark.range(0, 10000)
    .withColumn("user_id", user_id_col)
    .withColumn("amount", amount_col)
    .withColumn("dt", run_date_col)
)

# Narrow transformation: filter (no shuffle)
print("NARROW filter plan:")
filtered = sales.filter(F.col("amount") > 50)
filtered.explain("formatted")

# Wide transformation: groupBy (shuffle)
print("\nWIDE groupBy plan:")
total_per_user = F.sum("amount").alias("total")
by_user = sales.groupBy("user_id").agg(total_per_user)
by_user.explain("formatted")


In [None]:
# Compare shuffle partitions impact (very rough timing on local)
import time

# NOTE(review): the first timed run likely also pays one-off warm-up costs,
# so treat the numbers as indicative only.
try:
    for num_parts in (8, 64):
        spark.conf.set("spark.sql.shuffle.partitions", num_parts)
        # perf_counter() is a monotonic clock intended for measuring
        # intervals; time.time() can jump when the system clock is adjusted.
        start = time.perf_counter()
        _ = sales.repartition(num_parts, "user_id").groupBy("user_id").count().count()
        elapsed = time.perf_counter() - start
        print(f"time with {num_parts} partitions:", round(elapsed, 3), "s")
finally:
    # reset to 16 for consistency, even if one of the timed runs failed
    spark.conf.set("spark.sql.shuffle.partitions", 16)

print("current shuffle.partitions:", spark.conf.get("spark.sql.shuffle.partitions"))


In [None]:
# Cache lifecycle demo: cache -> materialize -> reuse -> unpersist
bucket_col = F.col("user_id") % 5
expensive = (
    sales.withColumn("bucket", bucket_col)
    .groupBy("bucket")
    .agg(F.sum("amount").alias("s"))
)

# Mark the aggregate for caching, then run one action to fill the cache
cached = expensive.cache()
_ = cached.count()  # materialize cache

# A second action on the cached DataFrame
cached_rows = cached.count()
print("Cached rows:", cached_rows)

# Release the cache once it is no longer needed
cached.unpersist()
print("Unpersisted cache.")


## Part 2: Data Engineer's Toolkit — Joins, AQE & Skew, Schema/Pushdown, Small Files

Goals:
- Choose join strategies (broadcast vs sort-merge) and validate in plans
- Enable AQE; see AdaptiveSparkPlan; mitigate skew with salting
- Enforce schemas; demonstrate pruning/pushdown; avoid function-wrapped partition filters
- Avoid small-files explosion; compact sensibly



In [None]:
# Join performance: shuffle-based join vs explicit broadcast join
fact = (
    spark.range(0, 10_000)
    .withColumn("key", F.col("id") % 1000)
    .withColumn("v", F.rand() * 10)
)
dim = (
    spark.range(0, 1000)
    .withColumnRenamed("id", "key")
    .withColumn("attr", F.expr("concat('A', key)"))
    .select("key", "attr")
)

joined_default = fact.join(dim, "key")

# Both inputs are tiny, so under the default
# spark.sql.autoBroadcastJoinThreshold Spark would choose a broadcast join on
# its own and the two plans below would show the same strategy. Automatic
# broadcasting is switched off only while the baseline plan is printed, then
# the previous value is restored.
_auto_broadcast_key = "spark.sql.autoBroadcastJoinThreshold"
_saved_threshold = spark.conf.get(_auto_broadcast_key)
spark.conf.set(_auto_broadcast_key, "-1")
try:
    print("Default join plan (auto-broadcast disabled):")
    joined_default.explain("formatted")
finally:
    spark.conf.set(_auto_broadcast_key, _saved_threshold)

# An explicit broadcast hint asks Spark to ship `dim` to every executor
from pyspark.sql.functions import broadcast
joined_broadcast = fact.join(broadcast(dim), "key")
print("\nBroadcast join plan:")
joined_broadcast.explain("formatted")


In [None]:
# AQE: compare plans with AQE off vs on
# DataFrame.explain() prints the plan and returns None, so its result is not
# assigned to a variable.
try:
    for aqe_flag in ("false", "true"):
        spark.conf.set("spark.sql.adaptive.enabled", aqe_flag)
        print("AQE:", spark.conf.get("spark.sql.adaptive.enabled"))
        # Build a fresh DataFrame per setting so each plan is produced while
        # that setting is active.
        joined_default.groupBy("attr").agg(F.count("v")).explain("formatted")
finally:
    # Keep AQE on for rest, even if one of the explain calls failed
    spark.conf.set("spark.sql.adaptive.enabled", "true")


In [None]:
# Skew mitigation via salting (illustration)
# Every row of `hot` carries the same join key (k = 1)
hot = spark.range(0, 100000).select(
    F.lit(1).alias("k"),
    (F.rand() * 10).alias("m"),
)
small = spark.createDataFrame([(1, "A")], ["k", "attr"])  # tiny dim

# Number of salt buckets
k = 16

# Large side: tag each row with a random salt in [0, k)
random_salt = (F.rand() * k).cast("int")
hot_salted = hot.withColumn("salt", random_salt)

# Small side: replicate each row once per salt value 0..k-1
all_salts = F.sequence(F.lit(0), F.lit(k - 1))
small_salted = small.withColumn("salt", all_salts).selectExpr(
    "k", "explode(salt) as salt", "attr"
)

# Same aggregate over the plain join and the salted join
total_m = F.sum("m").alias("sum_m")
out_no_salt = hot.join(small, "k").groupBy("k").agg(total_m)
out_salt = hot_salted.join(small_salted, ["k", "salt"]).groupBy("k").agg(total_m)

print("No-salt result:")
out_no_salt.show(3, truncate=False)
print("Salted result (should match):")
out_salt.show(3, truncate=False)


In [None]:
# Schema & pushdown: demonstrate pruning and function-wrapping pitfall
import shutil
from glob import glob

data_path = os.path.join(BASE, "prune_demo")
shutil.rmtree(data_path, ignore_errors=True)

# Create 3 day partitions: id % 3 selects one of three consecutive dates
day_ix = (F.lit(0) + (F.col("id") % 3)).cast("int")
day_label = (
    F.when(day_ix == 0, "2025-11-08")
    .when(day_ix == 1, "2025-11-09")
    .otherwise("2025-11-10")
)
rows = (
    spark.range(0, 3000)
    .withColumn("dt", day_label)
    .withColumn("val", F.rand())
)
(
    rows.repartition(3, "dt")
    .write.partitionBy("dt")
    .mode("overwrite")
    .parquet(data_path)
)

base_df = spark.read.parquet(data_path)
source_file = F.input_file_name()

# Good: a direct comparison on the partition column
pruned = base_df.filter("dt = '2025-11-10'").withColumn("src", source_file)
pruned_files = pruned.select("src").distinct().count()
print("Pruned distinct files:", pruned_files)

# Pitfall shown by the workshop: the partition column wrapped in functions.
# NOTE(review): this predicate matches all three dt values written above, so
# the file count also reflects matching data, not only lost pruning; confirm
# the PartitionFilters in explain() before drawing conclusions.
in_2025 = F.year(F.to_date("dt")) == 2025
wrapped = base_df.filter(in_2025).withColumn("src", source_file)
wrapped_files = wrapped.select("src").distinct().count()
print("Wrapped distinct files (more scanned):", wrapped_files)


In [None]:
# Small files: write the same rows as many small files vs a few larger ones
small_path = os.path.join(BASE, "small_files")
compacted_path = os.path.join(BASE, "small_files_compacted")
for stale_dir in (small_path, compacted_path):
    shutil.rmtree(stale_dir, ignore_errors=True)

many = spark.range(0, 20000).withColumn("dt", F.lit("2025-11-10"))

# Bad: too many partitions → many small files
over_partitioned = many.repartition(40)
over_partitioned.write.mode("overwrite").parquet(small_path)

# Better: coalesce before write
compacted = many.coalesce(4)
compacted.write.mode("overwrite").parquet(compacted_path)

# Count the parquet data files at the top level of each output directory
num_small, num_comp = (
    len(glob(os.path.join(out_dir, "*.parquet")))
    for out_dir in (small_path, compacted_path)
)
print("files (small):", num_small, "| files (compacted):", num_comp)


## Part 3: ML Capstone — Leakage-Safe Pipeline, Validation, Feature Engineering

Goals:
- Build a Pipeline with proper train-only fitting of transformers
- Use TrainValidationSplit (or CrossValidator) for robust evaluation
- Feature engineering with StringIndexer, OneHotEncoder, StandardScaler
- Persist/unpersist appropriately and avoid collecting large data



In [None]:
# Synthetic classification dataset
from pyspark.sql.functions import when

age_col = (F.rand() * 50 + 18).cast("int")
country_col = F.expr("case when id % 3 = 0 then 'US' when id % 3 = 1 then 'IN' else 'DE' end")
income_col = F.rand() * 90000 + 10000

# target with some non-linear signal: positive only when both thresholds hold
is_positive = (F.col("age") > 40) & (F.col("income") > 60000)

df_ml = (
    spark.range(0, 20000)
    .withColumn("age", age_col)
    .withColumn("country", country_col)
    .withColumn("income", income_col)
    .withColumn("label", when(is_positive, 1).otherwise(0))
)

# 80/20 split with a fixed seed
train, test = df_ml.randomSplit([0.8, 0.2], seed=42)

# Cache train as it is reused; count() materializes the cache
train_cache = train.cache()
_ = train_cache.count()



In [None]:
# Pipeline: StringIndexer -> OneHotEncoder -> Assembler -> StandardScaler -> LogisticRegression
from pyspark.ml import Pipeline
from pyspark.ml.feature import StringIndexer, OneHotEncoder, VectorAssembler, StandardScaler
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.tuning import TrainValidationSplit, ParamGridBuilder
from pyspark.ml.evaluation import BinaryClassificationEvaluator

# Categorical feature: index the country, then one-hot encode the index.
# handleInvalid="keep" on both stages keeps rows with unseen categories.
country_ix = StringIndexer(inputCol="country", outputCol="country_ix", handleInvalid="keep")
country_ohe = OneHotEncoder(inputCols=["country_ix"], outputCols=["country_ohe"], handleInvalid="keep")

# Assemble numeric and encoded columns into one vector, then scale it
assembler = VectorAssembler(
    inputCols=["age", "income", "country_ohe"],
    outputCol="features_raw"
)
scaler = StandardScaler(inputCol="features_raw", outputCol="features")
clf = LogisticRegression(featuresCol="features", labelCol="label", maxIter=30)

# All transformers live inside the Pipeline, so they are fitted together with
# the classifier on whatever data the pipeline is fitted on.
pipe = Pipeline(stages=[country_ix, country_ohe, assembler, scaler, clf])

evalr = BinaryClassificationEvaluator(labelCol="label", rawPredictionCol="rawPrediction", metricName="areaUnderROC")

paramGrid = (
    ParamGridBuilder()
    .addGrid(clf.regParam, [0.0, 0.01, 0.1])
    .addGrid(clf.elasticNetParam, [0.0, 0.5, 1.0])
    .build()
)

tvs = TrainValidationSplit(
    estimator=pipe,
    estimatorParamMaps=paramGrid,
    evaluator=evalr,
    trainRatio=0.8,
    # Fixed seed, matching randomSplit(seed=42): re-running this cell on the
    # same data picks the same internal train/validation split.
    seed=42,
)

try:
    model = tvs.fit(train_cache)

    pred_test = model.transform(test)
    auc = evalr.evaluate(pred_test)
    print("Test AUC:", round(auc, 4))
finally:
    # Unpersist cached data, even if fitting or evaluation failed
    train_cache.unpersist()


## Part 4: Structured Streaming — Watermarks, Output Modes, Checkpointing, foreachBatch

Goals:
- Use watermarks to bound state on event-time aggregations
- Choose correct output mode for operation
- Set checkpointing for reliable restarts
- Demonstrate a simple foreachBatch pattern (idempotent sink concept)

Note: To keep the notebook runnable, every streaming query here is short-lived: it uses the `availableNow=True` trigger where available (Spark 3.3+) and otherwise runs for a brief timeout.



In [None]:
# Streaming source: rate (synthetic)
chk = os.path.join(BASE, "chk_rate")
shutil.rmtree(chk, ignore_errors=True)

rate = (spark.readStream.format("rate").option("rowsPerSecond", 50).load()
         .select(F.col("value").alias("user_id"), F.col("timestamp").alias("ts")))

# Windowed aggregation with watermark
agg = (rate
       .withWatermark("ts", "30 seconds")
       .groupBy(F.window("ts", "20 seconds"))
       .agg(F.count("user_id").alias("cnt")))

# Memory sink for easy demo; the writer is defined once and only the trigger
# differs between the two run modes.
agg_writer = (agg.writeStream
              .outputMode("update")
              .format("memory")
              .queryName("agg_demo")
              .option("checkpointLocation", chk))

# Only the trigger() call is guarded, so a TypeError raised anywhere else is
# not mistaken for "availableNow is unsupported".
# NOTE(review): the rate source generates rows over time, so an availableNow
# run may finish with few or no rows; confirm on your Spark version.
try:
    agg_writer = agg_writer.trigger(availableNow=True)
    agg_timeout = None  # the query terminates on its own
except TypeError:
    # This PySpark's trigger() does not accept availableNow: run briefly instead
    agg_timeout = 5

q = agg_writer.start()
try:
    q.awaitTermination(agg_timeout)
finally:
    # In the timeout path the query is still running; stop it so it does not
    # keep producing rows in the background.
    q.stop()



In [None]:
# Inspect memory sink results: one row per window, ordered by window start
agg_demo_sql = "SELECT window.start, window.end, cnt FROM agg_demo ORDER BY window.start"
try:
    window_counts = spark.sql(agg_demo_sql)
    window_counts.show(truncate=False)
except Exception as err:
    # Best effort: report the problem instead of failing the notebook
    print("Memory table not available yet:", err)


In [None]:
# foreachBatch skeleton: simulate idempotent sink semantics
# Output directory for the demo sink; removed first so a re-run starts empty
sink_path = os.path.join(BASE, "sink_foreach")
shutil.rmtree(sink_path, ignore_errors=True)

from typing import Any

def upsert_like(batch_df, batch_id: int):
    """Append one micro-batch to the parquet sink, one row per ``user_id``.

    For demo purposes only: duplicates are dropped within the batch and the
    result is appended. In production, use Delta MERGE or a JDBC upsert.

    Args:
        batch_df: DataFrame holding the rows of the current micro-batch.
        batch_id: Value stored in the ``batch_id`` column of every written row.
    """
    one_row_per_user = batch_df.dropDuplicates(["user_id"])  # toy idempotency within batch
    tagged = one_row_per_user.withColumn("batch_id", F.lit(batch_id))
    tagged.write.mode("append").parquet(sink_path)

# Short stream that drives upsert_like through foreachBatch
chk2 = os.path.join(BASE, "chk_foreach")
shutil.rmtree(chk2, ignore_errors=True)

rate2 = (
    spark.readStream.format("rate").option("rowsPerSecond", 20).load()
    .select(F.col("value").alias("user_id"), F.col("timestamp").alias("ts"))
)

# Define the writer once; only the trigger differs between the two run modes
batch_writer = (rate2.writeStream
                .foreachBatch(upsert_like)
                .option("checkpointLocation", chk2))

# Only the trigger() call is guarded, so a TypeError raised anywhere else is
# not mistaken for "availableNow is unsupported".
# NOTE(review): the rate source generates rows over time, so an availableNow
# run may finish with few or no batches; confirm on your Spark version.
try:
    batch_writer = batch_writer.trigger(availableNow=True)
    batch_timeout = None  # the query terminates on its own
except TypeError:
    # This PySpark's trigger() does not accept availableNow: run briefly instead
    batch_timeout = 5

q2 = batch_writer.start()
try:
    q2.awaitTermination(batch_timeout)
finally:
    # In the timeout path the query is still running; stop it so no further
    # batches are appended after the file count below is taken.
    q2.stop()

print("Parquet files written:", len(glob(os.path.join(sink_path, "*.parquet"))))


## Cleanup and Run Notes

- This notebook uses local `tmp/advanced_demo/` for outputs and checkpoints.
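- On re-run, previous outputs and checkpoints are deleted first, so every run starts from a clean state. Values generated with `rand()` are unseeded and will still differ between runs.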
- Delta-specific operations (e.g., OPTIMIZE, VACUUM) are noted conceptually; if you have Delta Lake, you can adapt the foreachBatch upsert to Delta MERGE.
- If `availableNow=True` is unsupported in your Spark version, the streaming examples fall back to a short `awaitTermination(5)` timeout.

When finished, stop Spark:



In [None]:
# Stop the SparkSession created in the setup cell, then confirm on stdout
spark.stop()
print("Spark stopped.")
