Performance & Optimization in PySpark. This lesson gives the practical patterns you’ll use most as a senior data scientist: partitioning, caching, shuffle-aware joins, file-format tips, and tools to inspect/measure performance.

Below is a concise, runnable script (a new file you can drop into your repo as src/lesson4_performance.py) with detailed comments, so you can run the experiments locally and push them to GitHub. Its output follows, and the core concepts are summarized at the end.

"""
Lesson 4: Performance & Optimization Examples (PySpark)
-------------------------------------------------------
Run:
    source venv/bin/activate
    python src/lesson4_performance.py

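This script demonstrates:
 - inspecting / changing partitions (repartition vs coalesce)
 - caching/persisting
 - explain() plans
 - broadcast joins
 - skew mitigation by salting (sketch only)
 - writing parquet with good partitioning, and partition pruning on read
 - checkpointing to truncate long lineage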
"""

from pathlib import Path
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.functions import broadcast
import urllib.request

# ---------------------------
# 1) Setup SparkSession
# ---------------------------
spark = (
    SparkSession.builder
    .appName("lesson4-performance")
    .master("local[*]")
    .config("spark.sql.shuffle.partitions", "8")  # adjust for local dev
    .getOrCreate()
)

print("Spark version:", spark.version)
print("spark.sql.shuffle.partitions =", spark.conf.get("spark.sql.shuffle.partitions"))
print("-" * 60)

# ---------------------------
# 2) Download small dataset (Iris) and create a larger DF for experiments
# ---------------------------
iris_url = "https://raw.githubusercontent.com/uiuc-cse/data-fa14/gh-pages/data/iris.csv"
local_path = Path("/tmp/iris.csv")
urllib.request.urlretrieve(iris_url, str(local_path))

# read once
iris_df = spark.read.csv(str(local_path), header=True, inferSchema=True)

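# Create a bigger dataset by repeating iris_df several times (to simulate larger data).
# Note: for real benchmarks use large public datasets (NYC taxi, etc.).
# Each union keeps its inputs' partitions, so with a one-partition CSV the result has
# REPEATS partitions, and the query plan grows by one branch per union.
big_df = iris_df
REPEATS = 200  # increase to simulate a larger dataset; keeps memory use moderate locally
for _ in range(REPEATS - 1):
    big_df = big_df.union(iris_df)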

print("Estimated rows (approx):", big_df.count())
print("Partitions (initial):", big_df.rdd.getNumPartitions())
print("-" * 60)

# ---------------------------
# 3) Repartition vs Coalesce
# ---------------------------
print("-> Repartitioning to 16 partitions (causes shuffle)")
big_repart = big_df.repartition(16)
print("Partitions after repartition:", big_repart.rdd.getNumPartitions())

print("-> Coalescing to 4 partitions (no full shuffle, may move data unevenly)")
big_coalesce = big_repart.coalesce(4)
print("Partitions after coalesce:", big_coalesce.rdd.getNumPartitions())
print("-" * 60)

# ---------------------------
# 4) Caching / Persisting
# ---------------------------
# Use caching when you plan to re-use the same DataFrame multiple times.
print("-> Caching big_coalesce")
big_coalesce.cache()  # DataFrame default level is MEMORY_AND_DISK (MEMORY_ONLY is the RDD.cache() default)
# Trigger cache materialization via an action
print("Count (materialize cache):", big_coalesce.count())

# Now repeated actions are faster (avoid re-compute)
print("Count again (should be faster):", big_coalesce.count())
print("-" * 60)

# ---------------------------
# 5) Explain plan: check for shuffles, broadcasts
# ---------------------------
# Simple aggregation (should cause shuffle due to groupBy)
agg = big_coalesce.groupBy("species").agg(
    F.count("*").alias("cnt"),
    F.avg("sepal_length").alias("avg_sepal_len")
)
print("Explain plan for aggregation (logical and physical):")
agg.explain(extended=True)  # prints the parsed, analyzed and optimized logical plans plus the physical plan
print("-" * 60)

# ---------------------------
# 6) Broadcast join example
# ---------------------------
# Create a small lookup table and join. Broadcast the small DF to avoid shuffle on join.
species_lookup = spark.createDataFrame(
    [("setosa", "short petals"), ("versicolor", "medium petals"), ("virginica", "long petals")],
    ["species", "description"]
)

# Use broadcast() to force a broadcast join (avoids shuffle)
joined = big_coalesce.join(broadcast(species_lookup), on="species", how="inner")
print("Explain plan for broadcast join (should include BroadcastHashJoin):")
joined.select("species", "description").explain()
print("-" * 60)

# ---------------------------
# 7) Skew mitigation (pattern)
# ---------------------------
# If one key is very heavy (skew), you can "salt" it: add a random salt column in [0, N)
# to the large side, replicate every row of the small side once per salt value, then join
# on (key, salt) so the heavy key is spread over N tasks.
print("Skew mitigation pattern: add 'salt' to keys to spread large-key loads across partitions.")
print("Example (sketch, not executed here; small_df is a placeholder lookup table):")
print("""
# sketch:
N = 10  # number of salts
big_salted = big_coalesce.withColumn('salt', F.floor(F.rand() * N))
salts = spark.range(N).withColumnRenamed('id', 'salt')   # salt values 0..N-1
small_salted = small_df.crossJoin(salts)                  # one copy of small_df per salt
joined_salted = big_salted.join(small_salted, on=['species', 'salt']).drop('salt')
""")
print("-" * 60)

# ---------------------------
# 8) Predicate pushdown & writing parquet
# ---------------------------
# Save to parquet partitioned by 'species' (good for queries filtering by species)
output_parquet = "output/iris_parquet_partitioned"
print(f"Writing partitioned parquet to {output_parquet} (this will create one folder per species)")
joined.write.mode("overwrite").partitionBy("species").parquet(output_parquet)

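# Read back with a filter on the partition column. Because species is a partition column,
# Spark prunes directories (PartitionFilters in the plan) and scans only species=setosa;
# filters on ordinary columns would appear under PushedFilters instead.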
print("Read back with predicate pushdown - only scans partition setosa")
subset = spark.read.parquet(output_parquet).filter(F.col("species") == "setosa")
print("Explain plan for filtered read (should show PartitionFilters):")
subset.explain()
print("Subset count (should be small):", subset.count())
print("-" * 60)

# ---------------------------
# 9) Checkpointing (for very long lineage)
# ---------------------------
# If your lineage graph becomes very long (many transformations), checkpointing writes the data
# out and truncates the lineage. It needs a checkpoint directory: a reliable shared store such
# as HDFS on a cluster, or a local path for local dev.
spark.sparkContext.setCheckpointDir("/tmp/spark-checkpoint")
short_df = big_coalesce.checkpoint(eager=True)  # truncates lineage immediately
print("Checkpoint done. Partitions:", short_df.rdd.getNumPartitions())
print("-" * 60)

# ---------------------------
# 10) Cleanup: unpersist and stop
# ---------------------------
big_coalesce.unpersist()
spark.stop()
print("Done. Spark stopped.")


Spark version: 3.5.7
spark.sql.shuffle.partitions = 8
------------------------------------------------------------
Estimated rows (approx): 30000
Partitions (initial): 200
------------------------------------------------------------
-> Repartitioning to 16 partitions (causes shuffle)
Partitions after repartition: 16
-> Coalescing to 4 partitions (no full shuffle, may move data unevenly)
Partitions after coalesce: 4
------------------------------------------------------------
-> Caching big_coalesce
Count (materialize cache): 30000
Count again (should be faster): 30000
------------------------------------------------------------
Explain plan for aggregation (logical and physical):
== Parsed Logical Plan ==
'Aggregate ['species], ['species, count(1) AS cnt#6328L, avg('sepal_length) AS avg_sepal_len#6330]
+- Repartition 4, false
   +- Repartition 16, true
      +- Union false, false
         :- Relation [sepal_length#17,sepal_width#18,petal_length#19,petal_width#20,species#21] csv
         :- Relation [sepal_length#27,sepal_width#28,petal_length#29,petal_width#30,species#31] csv
         :- Relation [sepal_length#37,sepal_width#38,petal_length#39,petal_width#40,species#41] csv
         :- Relation [sepal_length#47,sepal_width#48,petal_length#49,petal_width#50,species#51] csv
         :- Relation [sepal_length#57,sepal_width#58,petal_length#59,petal_width#60,species#61] csv
         :- Relation [sepal_length#67,sepal_width#68,petal_length#69,petal_width#70,species#71] csv
         :- ...
(output truncated)
Read back with predicate pushdown - only scans partition setosa
Explain plan for filtered read (should show PartitionFilters):
== Physical Plan ==
*(1) ColumnarToRow
+- FileScan parquet [sepal_length#6609,sepal_width#6610,petal_length#6611,petal_width#6612,description#6613,species#6614] Batched: true, DataFilters: [], Format: Parquet, Location: InMemoryFileIndex(1 paths)[file:/Users/debabratapati/Documents/pyspark/my_pyspark_course/research..., PartitionFilters: [isnotnull(species#6614), (species#6614 = setosa)], PushedFilters: [], ReadSchema: struct<sepal_length:double,sepal_width:double,petal_length:double,petal_width:double,description:...


Subset count (should be small): 10000
------------------------------------------------------------
Checkpoint done. Partitions: 4
------------------------------------------------------------
Done. Spark stopped.


Core concepts (short & practical)

Partitions — Spark splits data into partitions, and each partition is processed by one task, in parallel. Too few partitions → idle cores and large, memory-hungry tasks; too many → per-task scheduling overhead and many small output files.

Check partitions: df.rdd.getNumPartitions() for DataFrames, rdd.getNumPartitions() for RDDs.

Change partitions: repartition(n) (full shuffle; can increase or decrease the count), coalesce(n) (merges existing partitions without a full shuffle; can only decrease).

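Shuffle — expensive redistribution of rows across partitions (and across the network on a cluster), triggered by wide transformations such as groupBy/aggregations, non-broadcast joins, distinct, orderBy, and repartition. Each shuffle appears as an Exchange node in the physical plan. Minimize shuffles where possible.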

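Caching / Persisting — store intermediate DataFrames in memory and/or on disk to avoid recomputation. Use df.cache() or df.persist(StorageLevel.MEMORY_AND_DISK) (from pyspark import StorageLevel); both are lazy, so the cache is filled by the first action. Call df.unpersist() when done.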

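Broadcast Join — a small table is copied to every executor, so the large side is joined in place without a shuffle. Force it with broadcast(small_df) from pyspark.sql.functions, or let Spark pick it automatically for tables smaller than spark.sql.autoBroadcastJoinThreshold (in bytes, default 10 MB; -1 disables), e.g. spark.conf.set("spark.sql.autoBroadcastJoinThreshold", size).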

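File format & layout — Parquet (columnar, compressed) is preferred for analytics. Partition files by columns you often filter on so queries can skip whole directories; beware the small-files problem: partitioning by a high-cardinality column, or writing from many tasks into each folder, produces many tiny files that are slow to list and read.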

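Predicate pushdown & projection pruning — for Parquet/ORC, Spark reads only the columns a query selects and pushes filters down to the reader, which can skip row groups whose min/max statistics rule out a match; filters on partition columns prune whole directories. Select only the columns you need and filter early to benefit.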

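Skew — uneven key distribution hurts groupBy/join: a few tasks receive most of the rows and the whole stage waits for them. Use salting, broadcast the small side, or process heavy keys separately.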

Explain & UI — df.explain(True) prints the parsed, analyzed and optimized logical plans plus the physical plan. Use the Spark UI (http://localhost:4040 by default in local mode) to inspect jobs, stages, tasks, and shuffle read/write sizes.

Memory & configs — tune spark.executor.memory, spark.driver.memory and spark.sql.shuffle.partitions (default 200, usually too high for small local jobs); set values appropriate to the cluster size and data volume.

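Bucketing — for repeated joins or aggregations on the same keys, bucketing a table (df.write.bucketBy(n, col).saveAsTable(...)) pre-partitions the data at write time, so later joins between tables bucketed the same way can skip the shuffle.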