# Optimization Techniques for PySpark ETL

In [0]:
# Load our data
# Every input for this exercise lives in one storage directory.
file_path = "abfss://etl1@dbstoragebbpbs73u57xmm.dfs.core.windows.net/Exercise2/"
user_path = file_path + "users.csv"

input_files = dbutils.fs.ls(file_path)
display(input_files)

# Pointing spark.read at the whole directory would also ingest users.csv and
# mix user rows into the event data. Read every other data file explicitly,
# skipping sub-directories (names ending in "/") and the "_"/"." prefixed
# files that Spark itself treats as hidden.
event_paths = [
    info.path
    for info in input_files
    if info.name != "users.csv"
    and not info.name.endswith("/")
    and not info.name.startswith(("_", "."))
]
assert event_paths, f"No event data files found in {file_path} besides users.csv"

df = spark.read.option("header", "true").option("inferSchema", "true").csv(event_paths)
users_df = spark.read.option("header", "true").option("inferSchema", "true").csv(user_path)

#### OPTIMIZATION 1: BROADCAST JOINS

In [0]:
# 1. Regular join (will use shuffle)
# First, disable automatic broadcasting so the planner does not silently turn
# the baseline join into a broadcast join.
print("Disabling automatic broadcast join optimization...")
original_broadcast_threshold = spark.conf.get("spark.sql.autoBroadcastJoinThreshold")
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "-1")  # Disable automatic broadcasting

# Clear everything to ensure a clean demonstration
spark.catalog.clearCache()

print("\n1. REGULAR JOIN WITH SHUFFLE:")
print("-----------------------------")
# Adaptive Query Execution may still switch a sort-merge join to a broadcast
# join at runtime when a platform-specific adaptive broadcast threshold applies
# (e.g. on Databricks). An explicit MERGE hint pins the shuffle sort-merge
# strategy, so the timing in the comparison cell really measures a shuffle join.
regular_join = df.join(users_df.hint("merge"), on="user_id")
print("Execution plan for regular join:")
regular_join.explain(mode="formatted")

# Look for Exchange hashpartitioning in the plan - this indicates shuffling

In [0]:
# 2. Explicit broadcast join
# Marking users_df with broadcast() asks Spark to ship a full copy of it to
# every executor, so the larger df does not have to be shuffled by user_id.
from pyspark.sql.functions import broadcast

print("\n2. BROADCAST JOIN:\n-----------------")
broadcast_join = df.join(broadcast(users_df), on="user_id")

print("Execution plan for broadcast join:")
broadcast_join.explain(mode="formatted")

# In the plan, BroadcastExchange / BroadcastHashJoin confirm that broadcasting was used


In [0]:
# Compare query times
import time

print("\nComparing performance:")
print("---------------------")


def _timed_count(frame):
    """Run frame.count() once and return (elapsed_seconds, row_count).

    Uses time.perf_counter(), a monotonic clock intended for measuring
    durations (time.time() can jump if the system clock is adjusted).
    """
    started = time.perf_counter()
    rows = frame.count()
    return time.perf_counter() - started, rows


try:
    # NOTE(review): the regular join runs first and may pay one-off costs
    # (cold storage reads, JVM warm-up) that the broadcast join then avoids;
    # treat a single run as indicative rather than a rigorous benchmark.
    regular_time, regular_count = _timed_count(regular_join)
    print(f"Regular join time: {regular_time:.2f} seconds for {regular_count} records")

    broadcast_time, broadcast_count = _timed_count(broadcast_join)
    print(f"Broadcast join time: {broadcast_time:.2f} seconds for {broadcast_count} records")

    # Report a slowdown honestly instead of printing e.g. "0.80x faster".
    if regular_time >= broadcast_time > 0:
        print(f"Speedup: {regular_time / broadcast_time:.2f}x faster with broadcast join")
    elif broadcast_time > regular_time > 0:
        print(f"Slowdown: broadcast join was {broadcast_time / regular_time:.2f}x slower than the regular join")
    else:
        print("Speedup: not measurable (a timing was zero)")
finally:
    # Restore original broadcast threshold even if one of the actions above failed,
    # so later cells in the session are not left with broadcasting disabled.
    spark.conf.set("spark.sql.autoBroadcastJoinThreshold", original_broadcast_threshold)
    print(f"\nRestored broadcast threshold to original value: {original_broadcast_threshold}")

#### OPTIMIZATION 2: CACHING STRATEGIES

In [0]:
# Cache intermediate results that are used multiple times
import time

print("Caching demonstration:")
# First, clear any existing cache so the baseline really recomputes from source
df.unpersist(blocking=True)


def _time_device_type_aggregations(frame):
    """Run and show the two device_type aggregations used in this demo.

    Returns (elapsed_seconds, count_df, mean_df). The elapsed time covers both
    show() actions and is measured with the monotonic time.perf_counter() clock.
    """
    started = time.perf_counter()
    count_df = frame.groupBy("device_type").count()
    count_df.show()
    mean_df = frame.groupBy("device_type").agg({"duration_seconds": "mean"})
    mean_df.show()
    return time.perf_counter() - started, count_df, mean_df


# Non-cached approach: df is not cached, so each action recomputes it from the CSV source
uncached_seconds, result1, result2 = _time_device_type_aggregations(df)
print(f"Non-cached execution time: {uncached_seconds:.2f} seconds")

# Cached approach
df.cache()  # Mark the dataframe for caching (lazy until an action runs)
try:
    df.count()  # Force cache evaluation so the timed run reads from the cache
    cached_seconds, result1, result2 = _time_device_type_aggregations(df)
    print(f"Cached execution time: {cached_seconds:.2f} seconds")
finally:
    # Remember to unpersist when done - even if an action above failed
    df.unpersist()

#### OPTIMIZATION 3: PARTITIONED WRITES

In [0]:
# Partition by columns that are frequently used in filters.
# partitionBy("device_type") writes one sub-directory per distinct value
# (device_type=<value>/), letting later reads that filter on device_type
# skip the directories they do not need.
output_path = "/pyspark/video-streaming-data/module3-transform/optimization/optimized_output"

print("Partitioned Writes Demonstration:\n--------------------------------")
print("Writing data partitioned by device_type...")
(
    df.write
    .mode("overwrite")
    .partitionBy("device_type")
    .parquet(output_path)
)

print(f"\nPartition structure created at {output_path}:")
display(dbutils.fs.ls(output_path))

In [0]:
# Summarize why partitioning the output by a frequently-filtered column helps.
# Note: Spark reads files in parallel whether or not the output is partitioned,
# so "parallel reads" is not a benefit specific to partitionBy.
print("\nBenefits of partitioned writes:")
print("1. Enables partition pruning - Spark can skip irrelevant partitions")
print("2. Organizes data by value - each device_type is stored under its own device_type=<value>/ directory, so per-value reads touch only those files")
print("3. Supports partition-aware queries - Filters on partition columns are much faster")