# Part 3: Spark Advanced

**Objective**: Identify, diagnose, and fix the most common Spark performance issues in batch workloads.


**What You'll Learn**:
1. Optimizing joins with broadcast
2. Avoiding Python UDF pitfalls
3. Handling data skew with salting
4. Leveraging Adaptive Query Execution (AQE)


In [None]:
# Setup: Import required libraries
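# Wildcard import: shadows Python built-ins such as sum, min and max with their
# Spark column-function counterparts, which the cells below rely on
from pyspark.sql.functions import *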
from pyspark.sql.types import *
import time

# Load TPC-DS datasets (built into Databricks) - Scale Factor 1 (~1GB)
# TPC-DS is a decision-support benchmark with many fact and dimension tables,
# which makes it well suited to join, shuffle and skew experiments.
# The tables model a retailer with stores, customers, and sales.
customers_df = spark.read.table("samples.tpcds_sf1.customer")
store_sales_df = spark.read.table("samples.tpcds_sf1.store_sales")
item_df = spark.read.table("samples.tpcds_sf1.item")
date_dim_df = spark.read.table("samples.tpcds_sf1.date_dim")

print(f"Customers: {customers_df.count():,} rows")
print(f"Store Sales: {store_sales_df.count():,} rows")
print(f"Items: {item_df.count():,} rows")
print(f"Date Dimension: {date_dim_df.count():,} rows")


## Issue #1: Broadcast Joins for Small Dimensions

**The Problem**: Sort-merge joins shuffle BOTH sides of the join, even when one side is tiny.

**Symptoms**:
- Unnecessary shuffles on small dimension tables
- Slow joins with reference/lookup tables

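**Solution**: Use broadcast joins for small tables (as a rule of thumb, up to roughly 100MB). Spark broadcasts automatically only when its size estimate is below `spark.sql.autoBroadcastJoinThreshold` (10MB by default). The `broadcast()` hint requests a broadcast explicitly.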
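The following plain-Python sketch shows why broadcasting removes the shuffle. It is a simplified model of a broadcast hash join, not Spark's implementation. The small dimension becomes a hash table that every partition receives in full. Each partition of the large fact table then probes that table locally, so fact rows never move between partitions. The partition layout and sample rows are hypothetical.

```python
from typing import Dict, List, Tuple

# Hypothetical row shapes: fact rows are (item_sk, price), dimension rows are (item_sk, category)
FactRow = Tuple[int, float]
DimRow = Tuple[int, str]


def broadcast_hash_join(
    fact_partitions: List[List[FactRow]],
    small_dim: List[DimRow],
) -> List[List[Tuple[int, float, str]]]:
    """Inner-join each fact partition against a full copy of the small dimension."""
    # Build the hash table once from the small side (this is what gets broadcast)
    lookup: Dict[int, str] = {item_sk: category for item_sk, category in small_dim}

    joined_partitions = []
    for partition in fact_partitions:
        # Each partition probes its own copy of the lookup table locally,
        # so no fact rows are sent to other partitions (no shuffle of the big side)
        joined = [
            (item_sk, price, lookup[item_sk])
            for item_sk, price in partition
            if item_sk in lookup  # inner-join semantics: drop unmatched rows
        ]
        joined_partitions.append(joined)
    return joined_partitions


if __name__ == "__main__":
    facts = [[(1, 9.99), (2, 4.50)], [(2, 3.25), (3, 7.00), (4, 1.00)]]
    items = [(1, "Books"), (2, "Music"), (3, "Sports")]
    for i, part in enumerate(broadcast_hash_join(facts, items)):
        print(f"partition {i}: {part}")
```

The trade-off is memory. Every executor must hold the whole lookup table, so broadcasting only pays off when one side is small.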


In [None]:
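### ❌ BAD: Default sort-merge join (shuffles both sides)
# Without a hint, Spark plans a sort-merge join unless a side's size estimate is
# below spark.sql.autoBroadcastJoinThreshold or AQE converts the join at runtime,
# so check the plan to see which strategy was actually chosen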

start_time = time.time()

# Item table is a small dimension table (perfect for broadcast)
# Join store_sales with item to get item details
# Note: We need to select ss_customer_sk for the second join
bad_broadcast = store_sales_df.select("ss_item_sk", "ss_customer_sk", "ss_sales_price", "ss_quantity") \
    .join(
        item_df.select("i_item_sk", "i_item_id", "i_category"),
        col("ss_item_sk") == col("i_item_sk")
    ).join(
        customers_df.select("c_customer_sk", "c_first_name", "c_last_name").limit(5000),
        col("ss_customer_sk") == col("c_customer_sk")
    ).groupBy("i_category", "c_first_name").agg(
        sum("ss_sales_price").alias("total_spent")
    )

result_bad = bad_broadcast.limit(10).collect()
bad_broadcast_time = time.time() - start_time

print(f"⏱️ Time taken (NO BROADCAST): {bad_broadcast_time:.2f}s")
print("🔍 Check Spark UI: look for SortMergeJoin with an Exchange (shuffle) on BOTH inputs")


In [None]:
### ✅ GOOD: Explicit broadcast join (no shuffle on small side)

from pyspark.sql.functions import broadcast

start_time = time.time()

small_customers = customers_df.select("c_customer_sk", "c_first_name", "c_last_name").limit(5000)

good_broadcast = store_sales_df.select("ss_item_sk", "ss_customer_sk", "ss_sales_price", "ss_quantity") \
    .join(
        broadcast(item_df.select("i_item_sk", "i_item_id", "i_category")),  # Broadcast item dimension
        col("ss_item_sk") == col("i_item_sk")
    ).join(
        broadcast(small_customers),  # Broadcast small customer subset
        col("ss_customer_sk") == col("c_customer_sk")
    ).groupBy("i_category", "c_first_name").agg(
        sum("ss_sales_price").alias("total_spent")
    )

result_good = good_broadcast.limit(10).collect()
good_broadcast_time = time.time() - start_time

print(f"⏱️ Time taken (WITH BROADCAST): {good_broadcast_time:.2f}s")
print(f"🚀 Speedup: {bad_broadcast_time/good_broadcast_time:.1f}x (the second run may also benefit from warm caches)")
print("\n🔍 Check Spark UI: See BroadcastHashJoin (no shuffle on small side!)")
print("\n💡 Golden Rule: broadcast(dim_table) for small lookups!")


## Issue #2: Python UDF Performance Killer

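**The Problem**: Python UDFs run in a separate Python worker process. Every row is serialized out of the JVM, passed through a Python function call, and serialized back, which adds heavy per-row overhead.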

**Symptoms**:
- Low CPU utilization
- Much slower than expected
- High overhead in stages with UDFs

**Solution**: Use built-in Spark SQL functions OR vectorized pandas UDFs.


In [None]:
### ❌ BAD: Python UDF (row-by-row serialization overhead)
 
from pyspark.sql.types import DoubleType
 
# Define a simple discount calculation UDF
@udf(returnType=DoubleType())
def calculate_discount_udf(price, quantity):
    if price is None or quantity is None:
        return 0.0
    price = float(price)
    if quantity >= 50:
        return price * 0.15
    elif quantity >= 20:
        return price * 0.10
    elif quantity >= 10:
        return price * 0.05
    else:
        return 0.0
 
start_time = time.time()
 
bad_udf = store_sales_df.select(
    "ss_sales_price",
    "ss_quantity"
).withColumn(
    "discount_amount",
    calculate_discount_udf(col("ss_sales_price"), col("ss_quantity"))
).agg(
    sum("discount_amount").alias("total_discounts")
)
 
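result = bad_udf.collect()
bad_udf_time = time.time() - start_time

print(f"⏱️ Time taken (PYTHON UDF): {bad_udf_time:.2f}s")
print("🔍 Check Spark UI: the plan shows a Python evaluation step (BatchEvalPython or ArrowEvalPython)")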

In [None]:
### ✅ GOOD: Built-in Spark SQL functions (pure JVM, no serialization)

start_time = time.time()

good_builtin = store_sales_df.select(
    "ss_sales_price", 
    "ss_quantity"
).withColumn(
    "discount_amount",
    when(col("ss_quantity") >= 50, col("ss_sales_price") * 0.15)
    .when(col("ss_quantity") >= 20, col("ss_sales_price") * 0.10)
    .when(col("ss_quantity") >= 10, col("ss_sales_price") * 0.05)
    .otherwise(0.0)
).agg(
    sum("discount_amount").alias("total_discounts")
)

result = good_builtin.collect()
good_builtin_time = time.time() - start_time

print(f"⏱️ Time taken (BUILT-IN): {good_builtin_time:.2f}s")
print(f"🚀 Speedup: {bad_udf_time/good_builtin_time:.1f}x")
print("\n💡 Golden Rule: Use when()/otherwise() (SQL CASE WHEN) over Python UDFs!")
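
The solution above also mentions vectorized pandas UDFs. A pandas UDF receives whole Arrow batches as `pandas.Series`, so the per-row Python call overhead goes away when logic cannot be expressed with built-in functions. The sketch below assumes pandas, NumPy and PyArrow are available; they normally ship with Databricks Runtime. The names `discount_vectorized` and `calculate_discount_pandas` are hypothetical. The price column should be cast to `double` in Spark so that Arrow hands pandas a float Series instead of `Decimal` objects.

```python
import numpy as np
import pandas as pd


def discount_vectorized(price: pd.Series, quantity: pd.Series) -> pd.Series:
    """Same tiers as calculate_discount_udf, computed on a whole batch at once."""
    # Pick the discount rate per row; NaN quantities match no condition -> 0.0
    rate = np.select(
        [quantity >= 50, quantity >= 20, quantity >= 10],
        [0.15, 0.10, 0.05],
        default=0.0,
    )
    # Null prices arrive as NaN; treat them as "no discount" like the row-wise UDF
    return price.fillna(0.0) * rate


try:
    from pyspark.sql.functions import pandas_udf
    from pyspark.sql.types import DoubleType

    # Wrap the plain function as a Series-to-Series pandas UDF returning DOUBLE
    calculate_discount_pandas = pandas_udf(discount_vectorized, DoubleType())
except ImportError:
    # PySpark or PyArrow not installed locally: the pandas logic above still works
    calculate_discount_pandas = None
```

In the notebook, it would be called like the row-wise UDF, for example `calculate_discount_pandas(col("ss_sales_price").cast("double"), col("ss_quantity"))`. Built-in `when` expressions still avoid the Python worker entirely, so they remain the first choice when they can express the logic.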


## Issue #3: Data Skew - The Silent Killer

**The Problem**: Uneven key distribution causes few tasks to process most data while others idle.

**Real-World Example**: Christmas shopping season! Peak days such as Dec 24-26 can see many times the transactions of an ordinary day, turning those dates into hot keys.

**Symptoms**:
- One or few tasks taking 10-100x longer than others
- Stage time dominated by stragglers
- Wasted cluster resources

**Solution**: Identify skewed keys and apply salting or repartitioning.


In [None]:
### Step 1: Detect Skew - Analyze key distribution

# Join with date dimension to get actual dates
sales_with_dates = store_sales_df.join(
    date_dim_df.select("d_date_sk", "d_date", "d_month_seq"),
    store_sales_df.ss_sold_date_sk == date_dim_df.d_date_sk
)

# Analyze sales distribution by date (look for seasonal peaks around the holidays)
skew_analysis = sales_with_dates.groupBy("d_date").agg(
    count("*").alias("sales_count"),
    sum("ss_sales_price").alias("total_revenue")
).orderBy(desc("sales_count"))

print("📊 Top dates by sales count (potential hot keys):")
skew_analysis.show(10)

# Get statistics
stats = skew_analysis.agg(
    min("sales_count").alias("min"),
    max("sales_count").alias("max"),
    avg("sales_count").alias("avg"),
    expr("percentile(sales_count, 0.95)").alias("p95")
).collect()[0]

print("\n📈 Skew Statistics:")
print(f"   Min sales per day: {stats['min']}")
print(f"   Max sales per day: {stats['max']}")
print(f"   Avg sales per day: {stats['avg']:.1f}")
print(f"   95th percentile: {stats['p95']:.1f}")
print(f"   🔥 Skew factor (max / avg): {stats['max']/stats['avg']:.1f}x")
print("\n🎄 Large skew factors point to hot keys such as peak shopping days")


### Step 2: ❌ BAD - Direct aggregation on skewed dates (Christmas hot keys)


In [None]:
# Create shopping period buckets with Christmas as the hot key
sales_with_skew = sales_with_dates.withColumn(
    "shopping_period",
    # Dec 24-26, 2000 form the CHRISTMAS_PEAK key. Note that each MONTH_n bucket
    # collects that calendar month from every year in the data, so compare volumes.
    when(col("d_date").isin("2000-12-24", "2000-12-25", "2000-12-26"), "CHRISTMAS_PEAK")
    .when(col("d_date").between("2000-12-01", "2000-12-23"), "DECEMBER")
    .when(col("d_date").between("2000-11-01", "2000-11-30"), "NOVEMBER")
    .otherwise(concat(lit("MONTH_"), month(col("d_date")).cast("string")))
)

start_time = time.time()

# Direct aggregation on skewed key (Christmas days will bottleneck)
bad_skew = sales_with_skew.groupBy("shopping_period").agg(
    count("*").alias("sales_count"),
    sum("ss_sales_price").alias("total_revenue"),
    avg("ss_sales_price").alias("avg_sale_value")
).orderBy(desc("sales_count"))

result = bad_skew.collect()
bad_skew_time = time.time() - start_time

print(f"⏱️ Time taken (SKEWED - Christmas bottleneck): {bad_skew_time:.2f}s")
print("🔍 Check Spark UI: compare task durations in the aggregation stage to spot stragglers")
print("\n🎄 Shopping periods by volume:")
for row in result[:8]:
    print(f"  {row['shopping_period']}: {row['sales_count']:,} sales")


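### Step 3: ✅ GOOD - Salting technique to distribute Christmas hot keys

**Salting**: Append a random suffix to the hot key, aggregate the salted keys in parallel, then strip the salt and combine the partial results.

**How it works**: Instead of one task handling every CHRISTMAS_PEAK row, the key is split into 10 sub-keys that can hash to 10 different shuffle partitions. Because `count`, `sum` and `avg` are already partially aggregated on each task before the shuffle, Spark softens this kind of skew on its own. Salting pays off most for skewed joins and for aggregations that cannot be pre-combined. Check the volumes printed in Step 2 to confirm which key is actually hottest before salting it.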
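The effect of salting on partition load can be modeled in plain Python. This is a simplified model, not Spark's partitioner. Rows are assigned to shuffle partitions by hashing the key modulo the partition count, so every row of a hot key lands on one partition. Appending a salt in `0..salt_factor-1` spreads that key over several partitions. The row counts are made-up illustration values, and `zlib.crc32` stands in for Spark's Murmur3 hash.

```python
import random
import zlib
from collections import Counter
from typing import Dict


def partition_of(key: str, num_partitions: int) -> int:
    # Stable hash modulo partition count (crc32 stands in for Spark's Murmur3 hash)
    return zlib.crc32(key.encode("utf-8")) % num_partitions


def partition_loads(key_counts: Dict[str, int], num_partitions: int,
                    hot_key: str, salt_factor: int = 1, seed: int = 42) -> Counter:
    """Rows landing on each shuffle partition; only hot_key is salted."""
    rng = random.Random(seed)
    loads: Counter = Counter()
    for key, n_rows in key_counts.items():
        for _ in range(n_rows):
            shuffle_key = key
            if key == hot_key and salt_factor > 1:
                # Same scheme as the notebook: KEY_SALT_<0..salt_factor-1>
                shuffle_key = f"{key}_SALT_{rng.randrange(salt_factor)}"
            loads[partition_of(shuffle_key, num_partitions)] += 1
    return loads


if __name__ == "__main__":
    # Hypothetical row counts per key; one key dominates
    counts = {"CHRISTMAS_PEAK": 50_000, "NOVEMBER": 5_000,
              "DECEMBER": 5_000, "MONTH_1": 5_000}
    before = partition_loads(counts, 16, "CHRISTMAS_PEAK")
    after = partition_loads(counts, 16, "CHRISTMAS_PEAK", salt_factor=10)
    # most_common(1) instead of max(): the notebook's wildcard import shadows max
    print("largest partition, unsalted:", before.most_common(1)[0][1])
    print("largest partition, salted:  ", after.most_common(1)[0][1])
```

The model only counts rows per partition. It ignores the map-side partial aggregation that Spark applies before the shuffle.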


In [None]:
start_time = time.time()

# Apply salting: add random salt to distribute Christmas peak load
SALT_FACTOR = 10  # Split Christmas hot key into 10 sub-keys for parallel processing

sales_salted = sales_with_skew.withColumn(
    "salted_period",
    when(
        col("shopping_period") == "CHRISTMAS_PEAK",  # Only salt the hot key
        concat(col("shopping_period"), lit("_SALT_"), (rand() * SALT_FACTOR).cast("int").cast("string"))
    ).otherwise(col("shopping_period"))
)

# Aggregate on salted keys (the hot key now spreads over up to SALT_FACTOR shuffle partitions)
good_skew = sales_salted.groupBy("salted_period").agg(
    count("*").alias("sales_count"),
    sum("ss_sales_price").alias("total_revenue"),
    count("ss_sales_price").alias("price_count")  # non-null prices, needed to rebuild avg exactly
)

# Remove salt and re-aggregate to get final result
final_result = good_skew.withColumn(
    "shopping_period",
    when(
        col("salted_period").startswith("CHRISTMAS_PEAK_SALT_"),
        lit("CHRISTMAS_PEAK")
    ).otherwise(col("salted_period"))
).groupBy("shopping_period").agg(
    sum("sales_count").alias("sales_count"),
    sum("total_revenue").alias("total_revenue"),
    # Average = merged sum / merged count (averaging per-salt averages would be wrong)
    (sum("total_revenue") / sum("price_count")).alias("avg_sale_value")
).orderBy(desc("sales_count"))

result = final_result.collect()
good_skew_time = time.time() - start_time

print(f"⏱️ Time taken (WITH SALTING): {good_skew_time:.2f}s")
print(f"🚀 Speedup: {bad_skew_time/good_skew_time:.1f}x")
print("\n🔍 Check Spark UI: the CHRISTMAS_PEAK rows are now spread over up to 10 tasks")
print("\n🎄 Shopping periods (after de-salting):")
for row in result[:8]:
    print(f"  {row['shopping_period']}: {row['sales_count']:,} sales")
print("\n💡 Golden Rule: Salt hot keys with a random suffix, aggregate in parallel, then de-salt!")
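
De-salting has to merge partial results. Sums and counts merge exactly, but averages do not. Averaging per-salt averages gives every salt bucket equal weight, even when the buckets hold different numbers of rows. Here is a worked example with made-up numbers:

```python
import math
from typing import List, Tuple

# Hypothetical partial results for one key after salting: (sum_of_prices, row_count)
PARTIALS: List[Tuple[float, int]] = [(10.0, 1), (45.0, 9)]


def avg_of_avgs(parts: List[Tuple[float, int]]) -> float:
    # Gives every salt bucket equal weight, however many rows it holds
    return math.fsum(s / c for s, c in parts) / len(parts)


def merged_avg(parts: List[Tuple[float, int]]) -> float:
    # Merge the exactly-combinable pieces (sum and count) first, then divide once.
    # math.fsum and an explicit loop avoid the notebook's shadowed sum().
    total = math.fsum(s for s, _ in parts)
    rows = 0
    for _, c in parts:
        rows += c
    return total / rows


if __name__ == "__main__":
    print("average of per-salt averages:", avg_of_avgs(PARTIALS))  # 7.5
    print("true average:                ", merged_avg(PARTIALS))   # 5.5
```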


## Issue #4: Adaptive Query Execution (AQE) - Let Spark Optimize

**The Problem**: Static planning can't adapt to actual data characteristics at runtime.

**Solution**: Enable AQE for dynamic optimizations:
- Coalesce shuffle partitions
- Convert sort-merge to broadcast join
- Optimize skewed joins automatically


In [None]:
### Compare: Without vs With AQE

# Disable AQE first
try:
    spark.conf.set("spark.sql.adaptive.enabled", False)
    print("🔴 AQE Disabled\n")
    aqe_configurable = True
except Exception as e:
    print("⚠️ AQE configuration not available on Databricks Serverless")
    print("   → AQE is always enabled and optimized automatically")
    print("   → Skipping manual AQE comparison\n")
    aqe_configurable = False

if aqe_configurable:
    start_time = time.time()
    
    # Aggregate store_sales per customer first
    sales_per_customer = store_sales_df.groupBy("ss_customer_sk").agg(
        sum("ss_sales_price").alias("total_revenue"),
        count("*").alias("sales_count")
    )
    
    query_no_aqe = customers_df.select("c_customer_sk", "c_birth_country") \
        .join(
            sales_per_customer,
            customers_df.c_customer_sk == sales_per_customer.ss_customer_sk
        ).groupBy("c_birth_country").agg(
            sum("sales_count").alias("sales_count"),
            sum("total_revenue").alias("revenue")
        )
    
    result = query_no_aqe.collect()
    no_aqe_time = time.time() - start_time
    
    print(f"⏱️ Time without AQE: {no_aqe_time:.2f}s")
    
    # Now enable AQE
    spark.conf.set("spark.sql.adaptive.enabled", True)
    spark.conf.set("spark.sql.adaptive.coalescePartitions.enabled", True)
    print("\n🟢 AQE Enabled\n")
    
    start_time = time.time()
    
    # Aggregate store_sales per customer first
    sales_per_customer = store_sales_df.groupBy("ss_customer_sk").agg(
        sum("ss_sales_price").alias("total_revenue"),
        count("*").alias("sales_count")
    )
    
    query_with_aqe = customers_df.select("c_customer_sk", "c_birth_country") \
        .join(
            sales_per_customer,
            customers_df.c_customer_sk == sales_per_customer.ss_customer_sk
        ).groupBy("c_birth_country").agg(
            sum("sales_count").alias("sales_count"),
            sum("total_revenue").alias("revenue")
        )
    
    result = query_with_aqe.collect()
    aqe_time = time.time() - start_time
    
    print(f"⏱️ Time with AQE: {aqe_time:.2f}s")
    print(f"🚀 Improvement: {no_aqe_time/aqe_time:.1f}x")
    print("\n🔍 Check Spark UI: AQE adjusts partitions dynamically!")
    print("💡 Golden Rule: keep AQE enabled in production (it is on by default since Spark 3.2)!")
else:
    print("💡 On Databricks Serverless:")
    print("   ✅ AQE is always enabled - no manual configuration needed")
    print("   ✅ Automatic partition coalescing")
    print("   ✅ Automatic broadcast join conversion")
    print("   ✅ Automatic skew join handling")
    print("   → Focus on query optimization, let serverless handle the rest!")


In [None]:
### Enable AQE with all optimizations (PRODUCTION SETTINGS)

print("📝 AQE Configuration for Traditional Spark Clusters:")
print("=" * 60)

# Core AQE settings for production (traditional clusters)
try:
    spark.conf.set("spark.sql.adaptive.enabled", True)
    spark.conf.set("spark.sql.adaptive.coalescePartitions.enabled", True)
    spark.conf.set("spark.sql.adaptive.skewJoin.enabled", True)  # Split skewed partitions in sort-merge joins
    spark.conf.set("spark.sql.adaptive.localShuffleReader.enabled", True)
    
    print("✅ AQE Configuration Set!")
    print("\n📝 What AQE Does:")
    print("   • Coalesces small shuffle partitions")
    print("   • Converts to broadcast joins when beneficial")
    print("   • Handles skewed joins automatically")
    print("   • Optimizes based on runtime statistics")
except Exception as e:
    print("⚠️ Manual AQE configuration not available on Databricks Serverless")
    print("\n✅ On Serverless: AQE is ALWAYS enabled with optimal settings!")
    print("\n📝 What Serverless AQE Does Automatically:")
    print("   • Coalesces small shuffle partitions")
    print("   • Converts to broadcast joins when beneficial")
    print("   • Handles skewed joins automatically")
    print("   • Optimizes based on runtime statistics")
    print("   • Adjusts resources dynamically")
    print("\n💡 No configuration needed - focus on query optimization!")
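
For the skew-join handling mentioned above, AQE flags a shuffle partition as skewed when its size exceeds two bounds. The first bound is `spark.sql.adaptive.skewJoin.skewedPartitionFactor` times the median partition size. The second is `spark.sql.adaptive.skewJoin.skewedPartitionThresholdInBytes`. The defaults are 5 and 256MB respectively, and AQE then splits each flagged partition into smaller pieces. The plain-Python sketch below covers only that detection rule and uses hypothetical partition sizes. Spark's own median computation may differ slightly.

```python
import statistics
from typing import List

MB = 1024 * 1024


def skewed_partitions(sizes: List[int], factor: float = 5.0,
                      threshold: int = 256 * MB) -> List[int]:
    """Indices of partitions larger than BOTH factor * median and threshold."""
    median = statistics.median(sizes)
    # Larger of the two bounds (written out because the notebook shadows max())
    limit = factor * median if factor * median > threshold else threshold
    return [i for i, size in enumerate(sizes) if size > limit]


if __name__ == "__main__":
    # Hypothetical join-side partition sizes: one hot-key partition dominates
    sizes = [40 * MB] * 9 + [900 * MB]
    print("skewed partitions:", skewed_partitions(sizes))  # skewed partitions: [9]
```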


In [None]:
## 🎯 Production Configuration Template

print("💼 Production Spark Configurations\n")
print("=" * 70)

print("\n📋 FOR TRADITIONAL SPARK CLUSTERS:")
print("-" * 70)

config_template = """
# Parallelism & Shuffles
spark.conf.set("spark.sql.shuffle.partitions", 200)  # Tune to cluster size

# Adaptive Query Execution (MUST HAVE)
spark.conf.set("spark.sql.adaptive.enabled", True)
spark.conf.set("spark.sql.adaptive.coalescePartitions.enabled", True)

# Broadcast Joins
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", 48 * 1024 * 1024)  # 48MB

# Python workers: reuse processes across tasks (already the default)
spark.conf.set("spark.python.worker.reuse", True)

# Nested-column pruning (on by default in Spark 3.x; verify it's on)
spark.conf.set("spark.sql.optimizer.nestedSchemaPruning.enabled", True)
"""

print(config_template)

print("\n📋 FOR DATABRICKS SERVERLESS:")
print("-" * 70)
print("""
⚠️ Manual Spark configurations are NOT needed on Databricks Serverless!

✅ What Serverless Handles Automatically:
   • Shuffle partitions (dynamically optimized)
   • Adaptive Query Execution (always enabled)
   • Broadcast join thresholds (auto-tuned)
   • Resource allocation (scales automatically)
   • Memory management (optimized for workload)

✅ What YOU Should Focus On:
   • Query-level optimizations:
     - Column pruning: .select() early
     - Predicate pushdown: .filter() early
     - Broadcast hints: broadcast(small_df)
     - Efficient joins: join order matters
   • Data layout:
     - Partition data appropriately
     - Use Delta Lake optimization (OPTIMIZE, ZORDER)
   • Code quality:
     - Avoid Python UDFs (use built-in functions)
     - Minimize shuffles where possible

💡 Bottom Line: Write efficient queries, let serverless handle resources!
""")


## Golden Rules

| Issue | Symptoms | Fix | Typical Gain (workload-dependent) |
|---|---|---|---|
| **Shuffle Explosion** | High network, slow stages | `.select()` columns early, filter early | 2-5x faster |
| **Missing Broadcast** | Unnecessary shuffles | `broadcast(small_df)` | 3-10x faster |
| **Python UDFs** | Low CPU, high overhead | Use built-in functions | 5-20x faster |
| **Data Skew** | Few tasks 10x slower | Salt hot keys, AQE | 2-5x faster |
| **No AQE** | Static planning | Enable AQE | 1.5-3x faster |
