# Part 3: Spark Production Issues - Batch Processing

**Objective**: Identify, diagnose, and fix the most common Spark performance issues in batch workloads.

**Duration**: 20 minutes

**What You'll Learn**:
1. How to recognize performance bottlenecks using Spark UI
2. Fixing shuffle explosion and skewed data
3. Optimizing joins with broadcast
4. Avoiding Python UDF pitfalls
5. Leveraging Adaptive Query Execution (AQE)


In [None]:
# Setup: Import required libraries
from pyspark.sql.functions import *  # note: shadows Python built-ins such as sum, min, max, abs, round
from pyspark.sql.types import *
import time

# Load TPC-DS datasets (built into Databricks) - Scale Factor 1 (~1GB)
# TPC-DS models a retail business (stores, customers, items, sales) with many joinable
# tables, which makes it a good playground for join, shuffle, and skew tuning
customers_df = spark.read.table("samples.tpcds_sf1.customer")
store_sales_df = spark.read.table("samples.tpcds_sf1.store_sales")
item_df = spark.read.table("samples.tpcds_sf1.item")
date_dim_df = spark.read.table("samples.tpcds_sf1.date_dim")

print(f"Customers: {customers_df.count():,} rows")
print(f"Store Sales: {store_sales_df.count():,} rows")
print(f"Items: {item_df.count():,} rows")
print(f"Date Dimension: {date_dim_df.count():,} rows")


## Issue #1: Shuffle Explosion & Column Pruning

**The Problem**: Reading all columns and shuffling unnecessary data wastes network, CPU, and memory.

**Symptoms**:
- Jobs spend most time in shuffle read/write
- High network usage
- Slow stages with wide transformations

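**Root Cause**: Not selecting only the needed columns before joins and aggregations. Catalyst's column pruning removes unused columns automatically in simple DataFrame pipelines like the one below, so compare the two physical plans (`explain()`) before trusting the timings. Explicit early `.select()` matters most when the optimizer cannot see downstream usage, for example before `.cache()`, before converting to an RDD, or when writing wide intermediate tables.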
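A back-of-the-envelope model shows why column count matters: shuffle volume grows roughly with rows x bytes per row. The column widths and row count below are hypothetical, and real shuffle files are compressed, so treat the ratio as intuition rather than a prediction.

```python
# Rough shuffle-volume model: rows x total width of the columns that cross the shuffle.
# The widths and the row count are hypothetical averages, chosen only for illustration.

def estimated_shuffle_bytes(num_rows, column_widths, columns):
    """Return num_rows * (sum of the widths of the selected columns)."""
    row_width = 0
    for name in columns:  # plain loop: the notebook's star import shadows sum()
        row_width += column_widths[name]
    return num_rows * row_width

# Hypothetical average widths (bytes) for a few customer columns
widths = {
    "c_customer_sk": 8,
    "c_birth_country": 12,
    "c_first_name": 10,
    "c_last_name": 12,
    "c_email_address": 30,
    "c_login": 15,
    "c_preferred_cust_flag": 1,
    "c_birth_year": 4,
}
rows = 100_000  # hypothetical row count

wide = estimated_shuffle_bytes(rows, widths, list(widths))
pruned = estimated_shuffle_bytes(rows, widths, ["c_customer_sk", "c_birth_country"])
print(f"all columns: {wide / 1e6:.1f} MB, pruned: {pruned / 1e6:.1f} MB, ratio: {wide / pruned:.1f}x")
```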

In [None]:
### ❌ BAD: Join without pruning columns first
# Relies on the optimizer to drop unused columns; whenever pruning is defeated
# (e.g. a cached wide DataFrame or an RDD conversion), every column gets shuffled

start_time = time.time()

# Aggregate store_sales first to get revenue and sales count per customer
sales_per_customer = store_sales_df.groupBy("ss_customer_sk").agg(
    sum("ss_sales_price").alias("total_revenue"),
    count("*").alias("sales_count")
)

bad_join = customers_df.join(
    sales_per_customer,
    customers_df.c_customer_sk == sales_per_customer.ss_customer_sk,
    "inner"
).groupBy("c_birth_country").agg(
    sum("total_revenue").alias("total_revenue"),
    sum("sales_count").alias("sales_count")
)

# Force execution
result_bad = bad_join.collect()
bad_time = time.time() - start_time

print(f"⏱️ Time taken (BAD): {bad_time:.2f}s")
print(f"📊 Columns available to the join: {len(customers_df.columns) + len(sales_per_customer.columns)}")
print("\n🔍 Check Spark UI: Look at shuffle read/write sizes in the stages!")


In [None]:
### ✅ GOOD: Prune columns BEFORE the join
# Select only what you need as early as possible

start_time = time.time()

# Prune columns first!
customers_pruned = customers_df.select("c_customer_sk", "c_birth_country")
sales_per_customer_pruned = store_sales_df.select("ss_customer_sk", "ss_sales_price") \
    .groupBy("ss_customer_sk").agg(
        sum("ss_sales_price").alias("total_revenue"),
        count("*").alias("sales_count")
    )

good_join = customers_pruned.join(
    sales_per_customer_pruned,
    customers_pruned.c_customer_sk == sales_per_customer_pruned.ss_customer_sk,
    "inner"
).groupBy("c_birth_country").agg(
    sum("total_revenue").alias("total_revenue"),
    sum("sales_count").alias("sales_count")
)

# Force execution
result_good = good_join.collect()
good_time = time.time() - start_time

print(f"⏱️ Time taken (GOOD): {good_time:.2f}s")
print("📊 Columns read from the source tables: 4 (only what we need)")
print(f"🚀 Speedup: {bad_time/good_time:.1f}x faster!")
print("\n💡 Golden Rule: .select() early, shuffle less!")


## Issue #2: Broadcast Joins for Small Dimensions

**The Problem**: Sort-merge joins shuffle BOTH sides of the join, even when one side is tiny.

**Symptoms**:
- Unnecessary shuffles on small dimension tables
- Slow joins with reference/lookup tables

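**Solution**: Broadcast small tables. Spark auto-broadcasts tables below `spark.sql.autoBroadcastJoinThreshold` (10 MB by default); explicit hints for tables up to roughly 100 MB are common, provided the driver and executors have the memory to hold them.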
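A toy model in plain Python makes the trade-off concrete. It assumes a shuffle moves every row of a side exactly once and a broadcast copies the small side to every executor, ignoring compression, locality, and serialization; the row and executor counts are hypothetical. `broadcast_hash_join` sketches what each executor does after the broadcast: build a hash table from the small side, then stream its slice of the big side through it.

```python
def sort_merge_rows_moved(big_rows, small_rows):
    """Sort-merge join: both sides are hash-partitioned on the key, so both are shuffled."""
    return big_rows + small_rows

def broadcast_rows_moved(small_rows, num_executors):
    """Broadcast hash join: the big side stays put; the small side is copied to every executor."""
    return small_rows * num_executors

def broadcast_hash_join(big, small, big_key, small_key):
    """Inner join of two lists of dicts, the way each executor runs a broadcast hash join:
    build a hash table from the (broadcast) small side, then stream the big side through it."""
    table = {}
    for row in small:
        table.setdefault(row[small_key], []).append(row)
    joined = []
    for row in big:
        for match in table.get(row[big_key], []):
            joined.append({**row, **match})
    return joined

# Hypothetical rows: item 2 has no match, so the inner join drops those sales
sales = [{"ss_item_sk": i % 3, "ss_sales_price": 10.0 + i} for i in range(6)]
items = [{"i_item_sk": 0, "i_category": "Books"}, {"i_item_sk": 1, "i_category": "Music"}]
joined = broadcast_hash_join(sales, items, "ss_item_sk", "i_item_sk")

# Hypothetical sizes: 2.88M fact rows, 18K dimension rows, 8 executors
print(f"joined rows: {len(joined)}")
print(f"sort-merge moves ~{sort_merge_rows_moved(2_880_000, 18_000):,} rows; "
      f"broadcast moves ~{broadcast_rows_moved(18_000, 8):,} rows")
```

Broadcasting stops paying off once `small_rows * num_executors` approaches the cost of shuffling both sides, and the small side must also fit in memory on the driver and on every executor.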


In [None]:
### ❌ BAD: Sort-merge join (shuffles both sides)
# Tables this small would normally be auto-broadcast (and AQE can switch joins at runtime),
# so turn auto-broadcast off for this cell to reproduce the sort-merge plan
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", -1)

start_time = time.time()

# Item table is a small dimension table (perfect for broadcast)
# Join store_sales with item to get item details
# Note: We need to select ss_customer_sk for the second join
bad_broadcast = store_sales_df.select("ss_item_sk", "ss_customer_sk", "ss_sales_price", "ss_quantity") \
    .join(
        item_df.select("i_item_sk", "i_item_id", "i_category"),
        col("ss_item_sk") == col("i_item_sk")
    ).join(
        customers_df.select("c_customer_sk", "c_first_name", "c_last_name").limit(5000),
        col("ss_customer_sk") == col("c_customer_sk")
    ).groupBy("i_category", "c_first_name").agg(
        sum("ss_sales_price").alias("total_spent")
    )

result_bad = bad_broadcast.limit(10).collect()
bad_broadcast_time = time.time() - start_time

# Restore the default auto-broadcast behavior for the rest of the notebook
spark.conf.unset("spark.sql.autoBroadcastJoinThreshold")

print(f"⏱️ Time taken (NO BROADCAST): {bad_broadcast_time:.2f}s")
print("🔍 Check Spark UI: See SortMergeJoin with shuffles on BOTH sides")


In [None]:
### ✅ GOOD: Explicit broadcast join (no shuffle on small side)

from pyspark.sql.functions import broadcast

start_time = time.time()

small_customers = customers_df.select("c_customer_sk", "c_first_name", "c_last_name").limit(5000)

good_broadcast = store_sales_df.select("ss_item_sk", "ss_customer_sk", "ss_sales_price", "ss_quantity") \
    .join(
        broadcast(item_df.select("i_item_sk", "i_item_id", "i_category")),  # Broadcast item dimension
        col("ss_item_sk") == col("i_item_sk")
    ).join(
        broadcast(small_customers),  # Broadcast small customer subset
        col("ss_customer_sk") == col("c_customer_sk")
    ).groupBy("i_category", "c_first_name").agg(
        sum("ss_sales_price").alias("total_spent")
    )

result_good = good_broadcast.limit(10).collect()
good_broadcast_time = time.time() - start_time

print(f"⏱️ Time taken (WITH BROADCAST): {good_broadcast_time:.2f}s")
print(f"🚀 Speedup: {bad_broadcast_time/good_broadcast_time:.1f}x faster!")
print("\n🔍 Check Spark UI: See BroadcastHashJoin (no shuffle on small side!)")
print("\n💡 Golden Rule: broadcast(dim_table) for small lookups!")


## Issue #3: Python UDF Performance Killer

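**The Problem**: Python UDFs move data out of the JVM into separate Python worker processes: rows are pickled (in small batches), the function runs once per row, and the results are pickled back. That overhead usually dwarfs the actual computation.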

**Symptoms**:
- Low CPU utilization
- Much slower than expected
- High overhead in stages with UDFs

**Solution**: Use built-in Spark SQL functions OR vectorized pandas UDFs.
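Regular Python UDFs already ship rows to the Python workers in pickled batches, but the function body still runs once per row; a vectorized `pandas_udf` receives whole Arrow batches as pandas Series and runs once per batch. The sketch below counts invocations under both models in plain Python. The 10,000-row batch size mirrors the default of `spark.sql.execution.arrow.maxRecordsPerBatch`; the row counts are hypothetical, and in practice each partition contributes at least one batch.

```python
import math

ARROW_BATCH_ROWS = 10_000  # mirrors the default spark.sql.execution.arrow.maxRecordsPerBatch

def python_udf_calls(num_rows):
    """Plain Python UDF: the function body runs once per row."""
    return num_rows

def pandas_udf_calls(num_rows, batch_rows=ARROW_BATCH_ROWS):
    """Vectorized UDF: the function body runs once per Arrow batch."""
    return math.ceil(num_rows / batch_rows)

def apply_in_batches(values, batch_fn, batch_rows=ARROW_BATCH_ROWS):
    """Apply a column-at-a-time function to fixed-size slices and count the calls."""
    results, calls = [], 0
    for start in range(0, len(values), batch_rows):
        results.extend(batch_fn(values[start:start + batch_rows]))
        calls += 1
    return results, calls

def double_all(batch):
    # Stand-in for a vectorized computation over one batch
    return [v * 2 for v in batch]

doubled, calls = apply_in_batches(list(range(25_000)), double_all)
rows = 3_000_000  # hypothetical row count
print(f"{rows:,} rows: {python_udf_calls(rows):,} per-row calls vs {pandas_udf_calls(rows):,} batch calls")
print(f"apply_in_batches made {calls} calls for {len(doubled):,} values")
```

In a notebook, the same column-at-a-time idea is written with `@pandas_udf("double")` on a function that takes and returns `pandas.Series`; built-in expressions remain the faster choice whenever they can express the logic.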


In [None]:
### ❌ BAD: Python UDF (per-row Python calls plus serialization overhead)

from pyspark.sql.types import DoubleType

# Define a simple discount calculation UDF
@udf(returnType=DoubleType())
def calculate_discount_udf(price, quantity):
    """Apply tiered discount based on quantity"""
    # TPC-DS columns are nullable, and ss_sales_price is DECIMAL (passed in as
    # decimal.Decimal), so guard against None and convert before multiplying by a float
    if price is None or quantity is None:
        return 0.0
    price = float(price)
    if quantity >= 50:
        return price * 0.15  # 15% discount
    elif quantity >= 20:
        return price * 0.10  # 10% discount
    elif quantity >= 10:
        return price * 0.05  # 5% discount
    else:
        return 0.0

start_time = time.time()

bad_udf = store_sales_df.select(
    "ss_sales_price", 
    "ss_quantity"
).withColumn(
    "discount_amount",
    calculate_discount_udf(col("ss_sales_price"), col("ss_quantity"))
).agg(
    sum("discount_amount").alias("total_discounts")
)

result = bad_udf.collect()
bad_udf_time = time.time() - start_time

print(f"⏱️ Time taken (PYTHON UDF): {bad_udf_time:.2f}s")
print("⚠️ Every row crosses the Python-JVM boundary!")
print("📊 TPC-DS SF1 has millions of rows - this will be slow!")


In [None]:
### ✅ GOOD: Built-in Spark SQL functions (pure JVM, no serialization)

start_time = time.time()

good_builtin = store_sales_df.select(
    "ss_sales_price", 
    "ss_quantity"
).withColumn(
    "discount_amount",
    when(col("ss_quantity") >= 50, col("ss_sales_price") * 0.15)
    .when(col("ss_quantity") >= 20, col("ss_sales_price") * 0.10)
    .when(col("ss_quantity") >= 10, col("ss_sales_price") * 0.05)
    .otherwise(0.0)
).agg(
    sum("discount_amount").alias("total_discounts")
)

result = good_builtin.collect()
good_builtin_time = time.time() - start_time

print(f"⏱️ Time taken (BUILT-IN): {good_builtin_time:.2f}s")
print(f"🚀 Speedup: {bad_udf_time/good_builtin_time:.1f}x faster!")
print("\n💡 Golden Rule: Use when()/otherwise() (SQL CASE WHEN) over Python UDFs!")
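When moving logic from a UDF to built-ins, a small pure-Python reference function helps spot-check results on a sample. The hypothetical `discount_reference` below mirrors the `when()`/`otherwise()` chain, including SQL NULL handling: a NULL quantity fails every condition and falls through to `0.0`, while a NULL price makes the selected discount branch NULL. It also shows why the UDF has to convert `decimal.Decimal` prices before multiplying by a float.

```python
from decimal import Decimal

def discount_reference(price, quantity):
    """Hypothetical pure-Python mirror of the when()/otherwise() discount chain.
    SQL semantics: a NULL quantity fails every condition (-> 0.0); a NULL price
    makes the selected discount branch NULL (-> None)."""
    if quantity is not None:
        for min_qty, rate in ((50, 0.15), (20, 0.10), (10, 0.05)):
            if quantity >= min_qty:
                return None if price is None else float(price) * rate
    return 0.0

# A Python UDF receives DECIMAL columns as decimal.Decimal, which cannot be
# multiplied by a float directly; this is why the UDF converts with float() first.
try:
    Decimal("10.00") * 0.15
except TypeError as exc:
    print("Decimal * float fails:", exc)

for price, qty in [(Decimal("100.00"), 60), (Decimal("100.00"), 12), (None, 25), (Decimal("5.00"), None)]:
    print(price, qty, "->", discount_reference(price, qty))
```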


## Issue #4: Data Skew - The Silent Killer

**The Problem**: Uneven key distribution causes a few tasks to process most of the data while the others sit idle.

**Symptoms**:
- One or few tasks taking 10-100x longer than others
- Stage time dominated by stragglers
- Wasted cluster resources

**Solution**: Identify skewed keys and apply salting or repartitioning.


In [None]:
### Step 1: Detect Skew - Analyze key distribution

# Check distribution of sales per customer (potential skew)
# NULL keys form a single group too, and are a common real-world hot key
skew_analysis = store_sales_df.groupBy("ss_customer_sk").agg(
    count("*").alias("sales_count"),
    sum("ss_sales_price").alias("total_revenue")
).orderBy(desc("sales_count"))

print("📊 Top customers by sales count (potential hot keys):")
skew_analysis.show(10)

# Get statistics
stats = skew_analysis.agg(
    min("sales_count").alias("min"),
    max("sales_count").alias("max"),
    avg("sales_count").alias("avg"),
    expr("percentile(sales_count, 0.95)").alias("p95")
).collect()[0]

print("\n📈 Skew Statistics:")
print(f"   Min sales: {stats['min']}")
print(f"   Max sales: {stats['max']}")
print(f"   Avg sales: {stats['avg']:.1f}")
print(f"   95th percentile: {stats['p95']:.1f}")
print(f"   🔥 Skew factor (max / avg): {stats['max']/stats['avg']:.1f}x")


In [None]:
### Step 2: ❌ BAD - Direct aggregation on skewed keys

# Simulate heavy skew by routing half of all customers to one artificial hot key
sales_with_skew = store_sales_df.join(
    customers_df.select("c_customer_sk", "c_birth_country"),
    store_sales_df.ss_customer_sk == customers_df.c_customer_sk
).withColumn(
    "customer_segment",
    when(col("c_customer_sk") % 2 == 0, "VIP")  # ~50% of customers -> one hot key
    .otherwise(concat(lit("REGULAR_"), (col("c_customer_sk") % 50).cast("string")))
)

start_time = time.time()

# Aggregate on skewed key
bad_skew = sales_with_skew.groupBy("customer_segment").agg(
    count("*").alias("sales_count"),
    sum("ss_sales_price").alias("total_revenue"),
    avg("ss_sales_price").alias("avg_sale_value")
).orderBy(desc("total_revenue"))

result = bad_skew.collect()
bad_skew_time = time.time() - start_time

print(f"⏱️ Time taken (SKEWED): {bad_skew_time:.2f}s")
print("🔍 Check Spark UI: Compare task durations in the final aggregation stage")
print("\nTop segments:")
for row in result[:5]:
    print(f"  {row['customer_segment']}: {row['sales_count']:,} sales")


### Step 3: ✅ GOOD - Salting technique to distribute hot keys

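**Salting**: Add a random suffix to hot keys, aggregate on the salted key, then strip the salt and aggregate again. For decomposable aggregates such as `sum`, `count`, and `avg`, Spark already runs a partial (map-side) aggregation before the shuffle, so salting pays off most for skewed joins and for aggregations whose partial results do not shrink the data (such as `collect_list`).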
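The effect of salting on the slowest task can be simulated without Spark. Because one reduce task processes all rows of a key, the largest key group bounds the slowest task. The row counts are hypothetical, and `zlib.crc32` stands in for Spark's own Murmur3-based hash partitioner, so actual partition assignments in Spark will differ.

```python
import random
import zlib
from collections import Counter

SALT_FACTOR = 10     # same idea as the notebook's SALT_FACTOR
NUM_PARTITIONS = 8   # hypothetical number of reduce tasks

def salt_key(key, hot_keys, rng):
    """Append a random suffix to hot keys only, like the concat(..., rand()) column."""
    if key in hot_keys:
        return f"{key}_{rng.randrange(SALT_FACTOR)}"
    return key

def partition_of(key):
    """Deterministic stand-in for Spark's hash partitioner."""
    return zlib.crc32(key.encode("utf-8")) % NUM_PARTITIONS

rng = random.Random(42)
rows = ["VIP"] * 10_000 + [f"REGULAR_{i % 25}" for i in range(5_000)]  # hypothetical keys

before = Counter(rows)
after = Counter(salt_key(k, {"VIP"}, rng) for k in rows)

# most_common(1) instead of max(): the notebook's star import shadows max()
print("largest key group before salting:", before.most_common(1)[0])
print("largest key group after salting: ", after.most_common(1)[0])

loads = Counter()
for key, n in after.items():
    loads[partition_of(key)] += n
print("rows per partition after salting:", dict(sorted(loads.items())))
```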


In [None]:
start_time = time.time()

# Apply salting: add random salt to distribute VIP load
SALT_FACTOR = 10  # Split hot key into 10 sub-keys

sales_salted = sales_with_skew.withColumn(
    "salted_segment",
    when(
        col("customer_segment") == "VIP",  # Only salt the hot key
        concat(col("customer_segment"), lit("_"), (rand() * SALT_FACTOR).cast("int").cast("string"))
    ).otherwise(col("customer_segment"))
)

# Aggregate on salted keys (distributes VIP across multiple tasks)
# Keep the sum and the non-null count so the average can be rebuilt exactly after de-salting
good_skew = sales_salted.groupBy("salted_segment").agg(
    count("*").alias("sales_count"),
    sum("ss_sales_price").alias("total_revenue"),
    count("ss_sales_price").alias("priced_count")
)

# Remove salt and re-aggregate to get final result
final_result = good_skew.withColumn(
    "customer_segment",
    when(
        col("salted_segment").startswith("VIP_"),
        lit("VIP")
    ).otherwise(col("salted_segment"))
).groupBy("customer_segment").agg(
    sum("sales_count").alias("sales_count"),
    sum("total_revenue").alias("total_revenue"),
    # An average of averages is wrong when buckets differ in size; divide the totals instead
    (sum("total_revenue") / sum("priced_count")).alias("avg_sale_value")
).orderBy(desc("total_revenue"))

result = final_result.collect()
good_skew_time = time.time() - start_time

print(f"⏱️ Time taken (WITH SALTING): {good_skew_time:.2f}s")
print(f"🚀 Speedup: {bad_skew_time/good_skew_time:.1f}x faster!")
print("\n🔍 Check Spark UI: The VIP rows should now be spread across several tasks")
print("\n💡 Golden Rule: Salt hot keys with random suffix, aggregate, then de-salt!")
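De-salting only works if the per-salt partial results can be combined again. Sums and counts can; averages cannot, because an average of averages weights every bucket equally regardless of its size. The bucket values below are hypothetical; the same reasoning is why the salted stage carries `sum` and a non-null `count`. The sketch uses `math.fsum` and a plain loop rather than `sum()` because the notebook's star import shadows that built-in.

```python
import math

# Hypothetical price lists for two salt buckets of the same hot key (unequal sizes)
buckets = {
    "VIP_0": [10.0, 20.0, 30.0],
    "VIP_1": [100.0],
}

# Phase 1 (per salted key): keep re-aggregatable partials, i.e. (sum, count)
partials = {key: (math.fsum(prices), len(prices)) for key, prices in buckets.items()}

# Wrong: averaging the per-bucket averages weights every bucket equally
mean_of_means = math.fsum(s / c for s, c in partials.values()) / len(partials)

# Right: add up sums and counts first, divide once at the end
total_sum = math.fsum(s for s, _ in partials.values())
total_count = 0
for _, c in partials.values():
    total_count += c
true_mean = total_sum / total_count

print(f"mean of means: {mean_of_means}, true mean: {true_mean}")
```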


## Issue #5: Adaptive Query Execution (AQE) - Let Spark Optimize

**The Problem**: Static planning can't adapt to actual data characteristics at runtime.

**Solution**: Enable AQE for dynamic optimizations:
- Coalesce shuffle partitions
- Convert sort-merge to broadcast join
- Optimize skewed joins automatically
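The partition-coalescing idea can be sketched as a greedy pass that merges adjacent shuffle partitions until a target size is reached. This is a simplified illustration, not Spark's implementation, which also accounts for minimum partition sizes and parallelism; the 64 MB target mirrors the default of `spark.sql.adaptive.advisoryPartitionSizeInBytes`, and the partition sizes are hypothetical.

```python
MB = 1024 * 1024

def coalesce_partitions(sizes, target):
    """Greedily group adjacent shuffle partitions so each group stays <= target bytes.
    A partition larger than target ends up in a group of its own."""
    groups, current, current_size = [], [], 0
    for index, size in enumerate(sizes):
        if current and current_size + size > target:
            groups.append(current)
            current, current_size = [], 0
        current.append(index)
        current_size += size
    if current:
        groups.append(current)
    return groups

# Hypothetical post-shuffle partition sizes
sizes = [5 * MB, 3 * MB, 60 * MB, 1 * MB, 2 * MB, 70 * MB, 4 * MB]
groups = coalesce_partitions(sizes, target=64 * MB)
print(f"{len(sizes)} shuffle partitions -> {len(groups)} tasks: {groups}")
```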


In [None]:
### Compare: Without vs With AQE

# Disable AQE first
spark.conf.set("spark.sql.adaptive.enabled", False)
print("🔴 AQE Disabled\n")

start_time = time.time()

# Aggregate store_sales per customer first
sales_per_customer = store_sales_df.groupBy("ss_customer_sk").agg(
    sum("ss_sales_price").alias("total_revenue"),
    count("*").alias("sales_count")
)

query_no_aqe = customers_df.select("c_customer_sk", "c_birth_country") \
    .join(
        sales_per_customer,
        customers_df.c_customer_sk == sales_per_customer.ss_customer_sk
    ).groupBy("c_birth_country").agg(
        sum("sales_count").alias("sales_count"),
        sum("total_revenue").alias("revenue")
    )

result = query_no_aqe.collect()
no_aqe_time = time.time() - start_time

print(f"⏱️ Time without AQE: {no_aqe_time:.2f}s")

# Now enable AQE
spark.conf.set("spark.sql.adaptive.enabled", True)
spark.conf.set("spark.sql.adaptive.coalescePartitions.enabled", True)
print("\n🟢 AQE Enabled\n")

start_time = time.time()

# Aggregate store_sales per customer first
sales_per_customer = store_sales_df.groupBy("ss_customer_sk").agg(
    sum("ss_sales_price").alias("total_revenue"),
    count("*").alias("sales_count")
)

query_with_aqe = customers_df.select("c_customer_sk", "c_birth_country") \
    .join(
        sales_per_customer,
        customers_df.c_customer_sk == sales_per_customer.ss_customer_sk
    ).groupBy("c_birth_country").agg(
        sum("sales_count").alias("sales_count"),
        sum("total_revenue").alias("revenue")
    )

result = query_with_aqe.collect()
aqe_time = time.time() - start_time

print(f"⏱️ Time with AQE: {aqe_time:.2f}s")
print(f"🚀 Improvement: {no_aqe_time/aqe_time:.1f}x (the second run also benefits from warm caches)")
print("\n🔍 Check Spark UI: AQE adjusts partitions dynamically!")
print("💡 Golden Rule: Keep AQE enabled in production (it is on by default since Spark 3.2)!")


In [None]:
### Enable AQE with all optimizations (PRODUCTION SETTINGS)

# Core AQE settings for production (all available since Spark 3.0)
spark.conf.set("spark.sql.adaptive.enabled", True)
spark.conf.set("spark.sql.adaptive.coalescePartitions.enabled", True)
spark.conf.set("spark.sql.adaptive.skewJoin.enabled", True)
spark.conf.set("spark.sql.adaptive.localShuffleReader.enabled", True)

print("✅ AQE Configuration Set!")
print("\n📝 What AQE Does:")
print("   • Coalesces small shuffle partitions")
print("   • Converts to broadcast joins when beneficial")
print("   • Handles skewed joins automatically")
print("   • Optimizes based on runtime statistics")
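With `spark.sql.adaptive.skewJoin.enabled`, AQE treats a shuffle partition of a sort-merge join as skewed when it is both larger than `spark.sql.adaptive.skewJoin.skewedPartitionFactor` times the median partition size and larger than `spark.sql.adaptive.skewJoin.skewedPartitionThresholdInBytes` (documented defaults 5 and 256 MB; check them for your Spark version). A simplified sketch of that test, with hypothetical partition sizes:

```python
import statistics

MB = 1024 * 1024

def skewed_partitions(sizes, factor=5.0, threshold=256 * MB):
    """Return indices of partitions that are larger than factor * median size
    AND larger than threshold (both conditions must hold)."""
    median = statistics.median(sizes)
    return [i for i, size in enumerate(sizes) if size > factor * median and size > threshold]

# Hypothetical shuffle partition sizes on one side of a join
print(skewed_partitions([40 * MB] * 9 + [900 * MB]))  # one hot partition
print(skewed_partitions([40 * MB] * 9 + [220 * MB]))  # above 5x median but below 256 MB
```

Partitions flagged this way are split into smaller pieces, and the matching partition on the other side of the join is replicated for each piece.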


In [None]:
## 📊 How to Use Spark UI for Diagnosis

print("🔍 Spark UI Analysis Checklist:\n")

print("1️⃣ SQL/DataFrame Tab:")
print("   • Look at physical plan for shuffle boundaries")
print("   • Check for BroadcastHashJoin vs SortMergeJoin")
print("   • Verify column pruning worked\n")

print("2️⃣ Stages Tab:")
print("   • Check shuffle read/write sizes")
print("   • Look for task time skew (min vs max)")
print("   • Identify the longest-running stages\n")

print("3️⃣ Executors Tab:")
print("   • Monitor memory usage and GC time")
print("   • Check for failed tasks\n")

print("4️⃣ Key Metrics to Watch:")
print("   • Shuffle spill (memory/disk)")
print("   • Task skew (median vs max)")
print("   • Number of shuffle partitions")
print("   • Broadcast size")
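The "median vs max" check from the Stages tab can be turned into a quick rule of thumb. The helper below is a hypothetical sketch: the durations are made-up numbers of the kind shown in a stage's summary metrics, and the 3x alert ratio is an arbitrary cutoff, not a Spark setting.

```python
import statistics

def task_skew_report(durations_s, ratio_alert=3.0):
    """Compare the slowest task with the median task of one stage.
    ratio_alert is an arbitrary cutoff chosen for illustration."""
    ordered = sorted(durations_s)
    median = statistics.median(ordered)
    slowest = ordered[-1]  # avoids max(), which the notebook's star import shadows
    ratio = slowest / median if median else float("inf")
    return {"median_s": median, "max_s": slowest, "ratio": ratio, "skewed": ratio > ratio_alert}

# Hypothetical task durations (seconds) for one stage
report = task_skew_report([4.1, 3.9, 4.3, 4.0, 4.2, 41.0])
print(report)
```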


In [None]:
## 🎯 Production Configuration Template

print("💼 Copy-paste for production Spark configs:\n")

config_template = """
# Parallelism & Shuffles
spark.conf.set("spark.sql.shuffle.partitions", 200)  # Default is 200; tune to data volume and cluster cores

# Adaptive Query Execution (MUST HAVE; on by default since Spark 3.2)
spark.conf.set("spark.sql.adaptive.enabled", True)
spark.conf.set("spark.sql.adaptive.coalescePartitions.enabled", True)

# Broadcast Joins (default threshold is 10MB)
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", 48 * 1024 * 1024)  # 48MB

# Python Performance
# spark.python.worker.reuse (default true) is a core Spark setting: put it in the
# cluster's Spark config, since it cannot be changed from a running session

# Nested-field pruning for struct columns (default true since Spark 3.0);
# top-level column pruning is always on
spark.conf.set("spark.sql.optimizer.nestedSchemaPruning.enabled", True)
"""

print(config_template)


## 🎯 Part 3 Summary: Batch Performance Golden Rules

### Top 5 Issues & Fixes

| Issue | Symptom | Fix | Typical impact (varies by workload) |
|-------|---------|-----|--------|
| **Shuffle Explosion** | High network, slow stages | `.select()` columns early, filter early | 2-5x faster |
| **Missing Broadcast** | Unnecessary shuffles | `broadcast(small_df)` | 3-10x faster |
| **Python UDFs** | Low CPU, high overhead | Use built-in functions | 5-20x faster |
| **Data Skew** | Few tasks 10x slower | Salt hot keys, AQE | 2-5x faster |
| **No AQE** | Static planning | Enable AQE | 1.5-3x faster |

### Diagnosis Workflow

```
1. Check Spark UI SQL tab → Identify shuffle-heavy stages
2. Look at physical plan → See join types and column pruning
3. Check task distribution → Spot skew
4. Apply fixes → Re-run and compare
5. Use explain("formatted") → Verify optimizations
```
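Step 5 of the workflow can be partly automated by counting operator names in `explain()` output. The operator names are real Spark physical-plan nodes, but the sample plan is a hand-written, abbreviated illustration rather than captured output, and the hypothetical `capture_explain` helper assumes `explain()` writes through Python's stdout (as PySpark's implementation does).

```python
import contextlib
import io
import re

# Physical-plan operators worth checking (names as they appear in explain() output)
WATCHED_OPERATORS = (
    "AdaptiveSparkPlan",   # AQE is active
    "SortMergeJoin",       # shuffles both join sides
    "BroadcastHashJoin",   # small side was broadcast
    "Exchange",            # shuffle boundary (BroadcastExchange is not counted)
    "BatchEvalPython",     # regular Python UDF
    "ArrowEvalPython",     # vectorized pandas UDF
)

def summarize_plan(plan_text):
    """Count whole-word occurrences of each watched operator in plan text."""
    return {op: len(re.findall(rf"\b{op}\b", plan_text)) for op in WATCHED_OPERATORS}

def capture_explain(df, mode="formatted"):
    """Return a DataFrame's explain() output as a string.
    Assumes explain() prints via Python's stdout, as classic PySpark does."""
    buffer = io.StringIO()
    with contextlib.redirect_stdout(buffer):
        df.explain(mode=mode)
    return buffer.getvalue()

# Hand-written, abbreviated plan for illustration only (not captured output)
sample_plan = """
AdaptiveSparkPlan isFinalPlan=false
+- HashAggregate(keys=[i_category], functions=[sum(ss_sales_price)])
   +- Exchange hashpartitioning(i_category, 200)
      +- HashAggregate(keys=[i_category], functions=[partial_sum(ss_sales_price)])
         +- BroadcastHashJoin [ss_item_sk], [i_item_sk], Inner, BuildRight
            :- FileScan parquet [ss_item_sk,ss_sales_price]
            +- BroadcastExchange HashedRelationBroadcastMode
               +- FileScan parquet [i_item_sk,i_category]
"""

counts = summarize_plan(sample_plan)
print(counts)
```

For example, `summarize_plan(capture_explain(good_broadcast))` should report BroadcastHashJoin and no SortMergeJoin if the broadcast hints took effect.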

### Before Moving to Production

✅ Enable AQE  
✅ Add broadcast hints for dimension tables  
✅ Replace Python UDFs with built-ins  
✅ Prune columns before joins  
✅ Profile with Spark UI under realistic load  
✅ Set appropriate shuffle partitions  
