# 🚀 Module 5: PySpark Performance Optimization
*Comprehensive Guide to Optimizing PySpark Applications for Production*

## 📋 Learning Objectives
By the end of this module, you will master:

🎯 **Partitioning Strategies**
- Hash, range, and custom partitioning
- Bucketing for join optimization
- Partition pruning techniques

⚡ **Caching & Persistence**
- Storage levels and memory management
- Checkpoint operations
- When and how to cache effectively

🔧 **Query Optimization**
- Catalyst optimizer deep dive
- Adaptive Query Execution (AQE)
- Broadcast joins and predicate pushdown

💪 **Resource Management**
- Dynamic allocation
- Memory tuning and garbage collection
- Parallelism optimization

📊 **Performance Monitoring**
- Spark UI analysis
- Metrics and monitoring tools
- Bottleneck identification

---

## 🏗️ Module Structure
1. **Partitioning Strategies** - Data distribution optimization
2. **Caching & Persistence** - Memory management techniques  
3. **Query Optimization** - Catalyst and AQE optimization
4. **Resource Management** - Cluster resource tuning
5. **Performance Monitoring** - Real-time performance analysis
6. **Production Best Practices** - Enterprise-ready optimizations

In [None]:
# 🚀 Module 5: PySpark Performance Optimization Setup
print("🔧 Setting up PySpark Performance Optimization Environment...")

import os
import time
import random
from datetime import datetime, timedelta
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.types import *
from pyspark.sql.window import Window
from pyspark.storagelevel import StorageLevel
import pyspark.sql.functions as F

# Configure Spark for performance optimization demonstrations
spark = SparkSession.builder \
    .appName("PySpark-Performance-Optimization") \
    .config("spark.sql.adaptive.enabled", "true") \
    .config("spark.sql.adaptive.coalescePartitions.enabled", "true") \
    .config("spark.sql.adaptive.skewJoin.enabled", "true") \
    .config("spark.sql.adaptive.localShuffleReader.enabled", "true") \
    .config("spark.serializer", "org.apache.spark.serializer.KryoSerializer") \
    .config("spark.sql.execution.arrow.pyspark.enabled", "true") \
    .config("spark.default.parallelism", "8") \
    .config("spark.sql.shuffle.partitions", "8") \
    .getOrCreate()

# Set log level to reduce noise
spark.sparkContext.setLogLevel("WARN")

print("✅ Spark Session Created with Performance Optimizations")
print(f"🎯 Spark Version: {spark.version}")
print(f"⚡ Default Parallelism: {spark.sparkContext.defaultParallelism}")
print(f"🔄 Shuffle Partitions: {spark.conf.get('spark.sql.shuffle.partitions')}")
print(f"🧠 AQE Enabled: {spark.conf.get('spark.sql.adaptive.enabled')}")

# Display current Spark configuration
print("\n📊 Key Performance Configurations:")
perf_configs = [
    "spark.sql.adaptive.enabled",
    "spark.sql.adaptive.coalescePartitions.enabled", 
    "spark.sql.adaptive.skewJoin.enabled",
    "spark.serializer",
    "spark.sql.execution.arrow.pyspark.enabled"
]

for config in perf_configs:
    value = spark.conf.get(config, "Not Set")
    print(f"   {config}: {value}")
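The setup pins `spark.sql.shuffle.partitions` to 8, which suits this small demo dataset. For larger jobs, a common starting point is to size shuffle partitions from the expected shuffle volume. The sketch below makes two assumptions. First, it targets about 128 MiB per partition; this is an illustrative target, not a Spark default for this setting. Second, it optionally rounds the count up to a multiple of the cluster's total cores. The function name is hypothetical. With AQE partition coalescing enabled, as in this session, Spark can still merge small shuffle partitions at runtime toward `spark.sql.adaptive.advisoryPartitionSizeInBytes`.

```python
import math

MIB = 1024 * 1024


def suggest_shuffle_partitions(shuffle_bytes: int,
                               target_bytes: int = 128 * MIB,
                               multiple_of: int = 1) -> int:
    """Hypothetical heuristic: enough partitions for ~target_bytes each,
    rounded up to a multiple of `multiple_of` (e.g. total executor cores)."""
    if shuffle_bytes < 0 or target_bytes <= 0 or multiple_of <= 0:
        raise ValueError("sizes must be non-negative and divisors positive")
    # At least one partition, even for an empty shuffle
    partitions = max(1, math.ceil(shuffle_bytes / target_bytes))
    # Round up so the final wave of tasks can keep every core busy
    return math.ceil(partitions / multiple_of) * multiple_of


# Example: ~10 GiB of shuffle data on a cluster with 16 cores in total
n = suggest_shuffle_partitions(10 * 1024 * MIB, multiple_of=16)
print(n)
# The value would then be applied with:
# spark.conf.set("spark.sql.shuffle.partitions", str(n))
```

Treat the output as a first guess to validate in the Spark UI, not a tuned value.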

In [None]:
# 📊 Generate Test Dataset for Performance Demonstrations
print("📦 Creating Performance Test Dataset...")

# Create a medium-sized dataset for performance testing
from pyspark.sql.functions import rand, when, floor, date_add, lit
from datetime import date

# Generate synthetic sales data efficiently using Spark functions
print("⚙️ Generating synthetic sales data...")

# Create base DataFrame with sequential IDs
base_df = spark.range(1, 100001).withColumnRenamed("id", "transaction_id")

# Add synthetic columns using Spark functions for better performance
sales_df = base_df \
    .withColumn("customer_id", floor(rand() * 50000).cast("int")) \
    .withColumn("product_id", floor(rand() * 10000).cast("int")) \
    .withColumn("category", 
                when(col("transaction_id") % 6 == 0, "Electronics")
                .when(col("transaction_id") % 6 == 1, "Clothing") 
                .when(col("transaction_id") % 6 == 2, "Books")
                .when(col("transaction_id") % 6 == 3, "Home")
                .when(col("transaction_id") % 6 == 4, "Sports")
                .otherwise("Automotive")) \
    .withColumn("region",
                when(col("transaction_id") % 5 == 0, "North")
                .when(col("transaction_id") % 5 == 1, "South")
                .when(col("transaction_id") % 5 == 2, "East") 
                .when(col("transaction_id") % 5 == 3, "West")
                .otherwise("Central")) \
    .withColumn("amount", (rand() * 1990 + 10).cast("decimal(10,2)")) \
    .withColumn("quantity", floor(rand() * 10 + 1).cast("int")) \
    .withColumn("transaction_date", 
                date_add(lit(date(2023, 1, 1)), floor(rand() * 365).cast("int"))) \
    .withColumn("discount_pct", (rand() * 30).cast("decimal(5,2)"))

# Cache the DataFrame for reuse
sales_df.cache()

# Trigger action to materialize the data
record_count = sales_df.count()

print(f"✅ Performance Test Dataset Created")
print(f"📊 Records: {record_count:,}")
print(f"🧩 Partitions: {sales_df.rdd.getNumPartitions()}")
print(f"💾 Cached: {sales_df.is_cached}")

# Show sample data
print("\n🔍 Sample Data:")
sales_df.show(5, truncate=False)

print("\n📋 Schema:")
sales_df.printSchema()

---

# 🎯 Section 1: Partitioning Strategies

## 📚 Core Concepts

**Partitioning** is fundamental to Spark performance. It determines:
- How data is distributed across the cluster
- Parallelism level for operations  
- Network shuffle requirements
- Join optimization opportunities
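Each partition becomes one task. An executor runs as many tasks at once as it has task slots; with the default `spark.task.cpus=1`, that is one slot per core. The simplified model below assumes every task takes the same time. Under that assumption, it shows how a partition count that is not a multiple of the slot count leaves cores idle in the final "wave". The helper names are illustrative.

```python
import math


def task_waves(num_partitions: int, executors: int, cores_per_executor: int) -> int:
    """Number of scheduling waves needed to run one task per partition."""
    slots = executors * cores_per_executor  # assumes spark.task.cpus = 1
    return math.ceil(num_partitions / slots)


def last_wave_utilization(num_partitions: int, executors: int, cores_per_executor: int) -> float:
    """Fraction of task slots busy during the final wave (1.0 = no idle cores)."""
    slots = executors * cores_per_executor
    remainder = num_partitions % slots
    return 1.0 if remainder == 0 else remainder / slots


# Spark's default of 200 shuffle partitions on 4 executors x 4 cores = 16 slots
print(task_waves(200, 4, 4))
print(last_wave_utilization(200, 4, 4))
```

With 200 partitions and 16 slots, the model needs 13 waves, and only half the slots are busy in the last one.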

### 🔑 Key Partitioning Types

1. **Hash Partitioning** - Default for shuffles (joins, aggregations, `repartition` by column)
2. **Range Partitioning** - Ordered data distribution
3. **Custom Partitioning** - Application-specific logic
4. **Bucketing** - Pre-partitioned storage optimization
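The DataFrame API exposes hash partitioning (`repartition(n, *cols)`) and range partitioning (`repartitionByRange(n, *cols)`) directly. Custom logic requires dropping to a key-value RDD and calling `RDD.partitionBy(numPartitions, partitionFunc)`. PySpark then places each key in partition `partitionFunc(key) % numPartitions`. Below is a minimal sketch of such a partition function, assuming the five region names used later in this module. `REGION_IDS` and `region_partitioner` are hypothetical names.

```python
# Hypothetical mapping: one dedicated partition per known region
REGION_IDS = {"North": 0, "South": 1, "East": 2, "West": 3, "Central": 4}
OVERFLOW_PARTITION = len(REGION_IDS)  # unknown regions share one extra partition


def region_partitioner(region: str) -> int:
    """Custom partition function for RDD.partitionBy: key -> partition index."""
    return REGION_IDS.get(region, OVERFLOW_PARTITION)


# Simulate the index partitionBy would compute for a few keys
num_partitions = len(REGION_IDS) + 1
for key in ["North", "Central", "Atlantis"]:
    print(key, "->", region_partitioner(key) % num_partitions)

# In Spark (not executed here):
# pair_rdd = data_df.rdd.map(lambda row: (row.region, row))
# by_region = pair_rdd.partitionBy(num_partitions, region_partitioner)
```

Converting to an RDD gives up Catalyst optimizations for that step. For that reason, custom partitioners are usually reserved for layouts that hash or range partitioning cannot express.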

### ⚡ Performance Impact

- **Good partitioning**: Parallel processing, minimal shuffles
- **Poor partitioning**: Data skew, excessive network I/O, slow joins
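You can quantify skew from per-partition record counts, such as those Section 1.1 collects with `rdd.glom().map(len).collect()`. The sketch below uses the ratio of the largest partition to the mean, where 1.0 means perfectly even. The 2.0 threshold is an arbitrary illustration, not a Spark setting. AQE's skew-join handling uses its own rule: it compares partition sizes against the median using `spark.sql.adaptive.skewJoin.skewedPartitionFactor` plus a byte threshold. `skew_report` is a hypothetical helper.

```python
import builtins  # explicit builtins: `from pyspark.sql.functions import *` shadows max/sum


def skew_report(partition_counts, threshold=2.0):
    """Summarize partition balance from a list of per-partition record counts."""
    if not partition_counts:
        raise ValueError("need at least one partition")
    total = builtins.sum(partition_counts)
    avg = total / len(partition_counts)
    ratio = builtins.max(partition_counts) / avg if avg else 0.0
    return {
        "partitions": len(partition_counts),
        "empty_partitions": partition_counts.count(0),
        "max_to_mean": ratio,
        "skewed": ratio > threshold,
    }


print(skew_report([25, 25, 25, 25]))
print(skew_report([10, 10, 10, 70]))
```

In the notebook it could be called as `skew_report(sample_df.rdd.glom().map(len).collect())`.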

---

In [None]:
# 🎯 Section 1.1: Analyzing Current Partitioning
print("🔍 Analyzing Current Data Partitioning...")

# First, let's create a simple dataset to work with
from pyspark.sql.functions import rand, when, floor, date_add, lit
from datetime import date

# Create sample dataset for partitioning demonstrations
sample_df = spark.range(1, 50001) \
    .withColumnRenamed("id", "record_id") \
    .withColumn("region", when(col("record_id") % 4 == 0, "North")
                .when(col("record_id") % 4 == 1, "South") 
                .when(col("record_id") % 4 == 2, "East")
                .otherwise("West")) \
    .withColumn("category", when(col("record_id") % 3 == 0, "A")
                .when(col("record_id") % 3 == 1, "B")
                .otherwise("C")) \
    .withColumn("value", (rand() * 1000).cast("decimal(10,2)"))

print(f"✅ Sample Dataset Created: {sample_df.count():,} records")
print(f"📊 Current Partitions: {sample_df.rdd.getNumPartitions()}")

# Analyze partition distribution.
# `builtins` is used explicitly because `from pyspark.sql.functions import *`
# shadows Python's min/max/sum with Spark's column functions.
import builtins

print("\n🔍 Partition Distribution Analysis:")
partition_counts = sample_df.rdd.glom().map(len).collect()
print(f"Records per partition: {partition_counts}")
print(f"Min records/partition: {builtins.min(partition_counts):,}")
print(f"Max records/partition: {builtins.max(partition_counts):,}")
print(f"Avg records/partition: {builtins.sum(partition_counts)/len(partition_counts):.1f}")

# Check for data skew by region
print("\n📈 Data Distribution by Region:")
sample_df.groupBy("region").count().orderBy("region").show()

print("\n📈 Data Distribution by Category:")
sample_df.groupBy("category").count().orderBy("category").show()

In [None]:
# 🎯 Section 1.2: Hash Partitioning Strategies
print("🔗 Demonstrating Hash Partitioning Techniques...")

# Create a dataset for partitioning demonstrations
data_df = spark.range(1, 100001) \
    .withColumnRenamed("id", "customer_id") \
    .withColumn("region", when(col("customer_id") % 5 == 0, "North")
                .when(col("customer_id") % 5 == 1, "South")
                .when(col("customer_id") % 5 == 2, "East") 
                .when(col("customer_id") % 5 == 3, "West")
                .otherwise("Central")) \
    .withColumn("amount", (rand() * 1000).cast("decimal(10,2)"))

print(f"📊 Original dataset: {data_df.rdd.getNumPartitions()} partitions")

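# 1. Hash partitioning by region: every row of a given region lands in the same partition.
#    There are only 5 distinct regions, so at most 5 of the 8 partitions can hold data.
print("\n🔗 Hash Partitioning by Region:")
hash_partitioned = data_df.repartition(8, "region")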
print(f"   Partitions after repartition: {hash_partitioned.rdd.getNumPartitions()}")

# Check distribution across partitions
print("   Records per partition:")
partition_sizes = hash_partitioned.rdd.glom().map(len).collect()
for i, size in enumerate(partition_sizes):
    print(f"     Partition {i}: {size:,} records")

# 2. Multiple column hash partitioning
print("\n🔗 Multi-Column Hash Partitioning:")
multi_hash = data_df.repartition(8, "region", "customer_id")
print(f"   Partitions: {multi_hash.rdd.getNumPartitions()}")

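# 3. Compare performance for region-based aggregation.
#    Treat these single-run timings as illustrative: the repartition shuffle is lazy and runs
#    inside the second timed query, and small local runs are dominated by scheduling overhead.
print("\n⚡ Performance Comparison: Region Aggregation")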

# Original partitioning
start_time = time.time()
result1 = data_df.groupBy("region").agg(
    count("*").alias("total_customers"),
    sum("amount").alias("total_amount"),
    avg("amount").alias("avg_amount")
).collect()
original_time = time.time() - start_time

# Hash partitioned by region
start_time = time.time()
result2 = hash_partitioned.groupBy("region").agg(
    count("*").alias("total_customers"), 
    sum("amount").alias("total_amount"),
    avg("amount").alias("avg_amount")
).collect()
hash_time = time.time() - start_time

print(f"   Original partitioning: {original_time:.3f}s")
print(f"   Hash partitioned: {hash_time:.3f}s")
print(f"   Relative difference: {((original_time - hash_time) / original_time * 100):.1f}% (positive = hash-partitioned faster)")

# Display results
print("\n📈 Aggregation Results:")
result_df = spark.createDataFrame(result2)
result_df.show()
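Why can `repartition(8, "region")` leave partitions empty? Hash partitioning assigns each row to `hash(key) mod numPartitions`, so every row with the same key lands in the same partition. Five distinct regions can therefore occupy at most 5 of 8 partitions. The pure-Python model below makes this visible. It uses `zlib.crc32` as a stand-in hash, so its exact assignments will differ from Spark's, which uses Murmur3 with a non-negative modulo.

```python
import zlib
from collections import Counter


def hash_partition(key: str, num_partitions: int) -> int:
    """Model of hash partitioning: stable hash of the key, non-negative modulo."""
    return zlib.crc32(key.encode("utf-8")) % num_partitions


regions = ["North", "South", "East", "West", "Central"]
rows = [regions[i % len(regions)] for i in range(100_000)]  # 20,000 rows per region

num_partitions = 8
sizes = Counter(hash_partition(region, num_partitions) for region in rows)
per_partition = [sizes.get(p, 0) for p in range(num_partitions)]

print("Records per partition:", per_partition)
print("Empty partitions:", per_partition.count(0))
```

Raising the partition count does not help with a low-cardinality key. A composite key such as `("region", "customer_id")`, as in `multi_hash`, spreads rows evenly, but it no longer co-locates each region.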

In [None]:
# 🎯 Section 1.3: Range Partitioning for Sorted Data
print("📊 Demonstrating Range Partitioning...")

# Range partitioning is useful for:
# - Time series data
# - Ordered data access patterns  
# - Range queries

# Create time series dataset
from pyspark.sql.functions import date_add, lit, expr
from datetime import date

time_series_df = spark.range(1, 100001) \
    .withColumnRenamed("id", "event_id") \
    .withColumn("event_date", 
                date_add(lit(date(2023, 1, 1)), 
                        floor(col("event_id") / 274).cast("int"))) \
    .withColumn("sensor_id", floor(rand() * 100).cast("int")) \
    .withColumn("temperature", (rand() * 40 + 10).cast("decimal(5,2)")) \
    .withColumn("humidity", (rand() * 100).cast("decimal(5,2)"))

print(f"📅 Time Series Dataset: {time_series_df.count():,} records")
print(f"📊 Original partitions: {time_series_df.rdd.getNumPartitions()}")

# Show date range
date_range = time_series_df.agg(
    min("event_date").alias("start_date"),
    max("event_date").alias("end_date")
).collect()[0]
print(f"📅 Date range: {date_range['start_date']} to {date_range['end_date']}")

# 1. Range partition by date (useful for time-based queries and ordered output)
print("\n📊 Range Partitioning by Date:")
range_partitioned = time_series_df.repartitionByRange(8, "event_date")
print(f"   Partitions: {range_partitioned.rdd.getNumPartitions()}")

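# Check how data is distributed across partitions.
# Summarize each partition on the executors instead of collecting every row with glom().collect().
import builtins

print("   Data distribution by partition:")

def summarize_partition(rows):
    dates = [row.event_date for row in rows]
    if dates:
        # builtins.min/max: the star import shadows min/max with Spark column functions
        yield (len(dates), builtins.min(dates), builtins.max(dates))
    else:
        yield (0, None, None)

partition_stats = range_partitioned.rdd.mapPartitions(summarize_partition).collect()
for i, (n_rows, min_date, max_date) in enumerate(partition_stats):
    if n_rows:
        print(f"     Partition {i}: {n_rows:,} records ({min_date} to {max_date})")
    else:
        print(f"     Partition {i}: 0 records (empty)")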

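# 2. Performance comparison for date range queries.
#    The range-partitioned plan also runs a sampling job and a shuffle, so single-run timings are only indicative.
print("\n⚡ Performance Test: Date Range Query")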

# Test query: Get data for a specific month
test_date_start = date(2023, 6, 1)
test_date_end = date(2023, 6, 30)

# Original partitioning
start_time = time.time()
result1 = time_series_df.filter(
    (col("event_date") >= lit(test_date_start)) & 
    (col("event_date") <= lit(test_date_end))
).agg(
    count("*").alias("records"),
    avg("temperature").alias("avg_temp"),
    avg("humidity").alias("avg_humidity")
).collect()[0]
original_time = time.time() - start_time

# Range partitioned
start_time = time.time() 
result2 = range_partitioned.filter(
    (col("event_date") >= lit(test_date_start)) & 
    (col("event_date") <= lit(test_date_end))
).agg(
    count("*").alias("records"),
    avg("temperature").alias("avg_temp"), 
    avg("humidity").alias("avg_humidity")
).collect()[0]
range_time = time.time() - start_time

print(f"   Original partitioning: {original_time:.3f}s")
print(f"   Range partitioned: {range_time:.3f}s")
print(f"   Query result: {result2['records']:,} records, Temp: {result2['avg_temp']:.1f}°C")

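# 3. Partition pruning: what range partitioning does and does not provide
print("\n🎯 Partition Pruning Notes:")
print("   repartitionByRange does not prune in-memory partitions:")
print("   - Spark still runs one task per partition, and each task applies the filter")
print("   - True partition pruning needs an on-disk layout, e.g. df.write.partitionBy('event_date')")
print("   - Writing range-partitioned data gives each file a narrow date range, so Parquet min/max stats can skip row groups")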
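Real partition pruning operates on the on-disk layout. Writing a DataFrame with `df.write.partitionBy("event_date").parquet(path)` makes Spark create one directory per value, such as `event_date=2023-06-01/`. A filter on `event_date` then lets the planner skip non-matching directories before opening any file; this shows up as `PartitionFilters` on the scan node in `explain()` output. The pure-Python model below mimics that directory-level decision. The base path and file names are made up for illustration.

```python
from datetime import date, timedelta


def partition_values(path: str) -> dict:
    """Extract Hive-style key=value directory segments from a file path."""
    values = {}
    for segment in path.split("/"):
        if "=" in segment:
            key, _, value = segment.partition("=")
            values[key] = value
    return values


def prune_paths(paths, column: str, start: date, end: date):
    """Keep only files whose partition value falls within [start, end]."""
    return [
        p for p in paths
        if start <= date.fromisoformat(partition_values(p)[column]) <= end
    ]


# Hypothetical layout: one directory (and one file) per day of 2023
all_paths = [
    f"/data/sensors/event_date={date(2023, 1, 1) + timedelta(days=d)}/part-00000.parquet"
    for d in range(365)
]

june_paths = prune_paths(all_paths, "event_date", date(2023, 6, 1), date(2023, 6, 30))
print(f"Files scanned: {len(june_paths)} of {len(all_paths)}")
```

Directory-per-value layouts work best for low-to-moderate cardinality columns. Very fine-grained keys produce many small files, which can cost more than pruning saves.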

---

# 💾 Section 2: Caching & Persistence Strategies

## 🎯 Key Concepts

**Caching** stores DataFrames in memory/disk for faster subsequent access:
- **Memory-only**: Fastest access, limited by available memory
- **Memory + Disk**: Spills to disk when memory full
- **Disk-only**: Slower but handles large datasets
- **Serialized**: More compact in memory, but each access pays a deserialization cost

### 🔑 Storage Levels

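The table below lists constants available in PySpark's `StorageLevel`. The `*_SER` variants (e.g. `MEMORY_ONLY_SER`) exist only in the Scala/Java API. PySpark's levels are serialized unless the name contains `DESER`; `MEMORY_AND_DISK_DESER` is the default used by `DataFrame.cache()` in Spark 3.x.

| Level | Memory | Disk | Deserialized | Replication |
|-------|--------|------|--------------|-------------|
| `MEMORY_ONLY` | ✅ | ❌ | ❌ | 1x |
| `MEMORY_AND_DISK` | ✅ | ✅ | ❌ | 1x |
| `MEMORY_AND_DISK_DESER` | ✅ | ✅ | ✅ | 1x |
| `DISK_ONLY` | ❌ | ✅ | ❌ | 1x |
| `MEMORY_AND_DISK_2` | ✅ | ✅ | ❌ | 2x |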

### ⚡ When to Cache

✅ **Good candidates:**
- DataFrames used multiple times
- Intermediate results in iterative algorithms
- Lookup tables and dimension data
- Expensive computations

❌ **Avoid caching:**
- Data used only once
- Very large datasets (memory pressure)
- Simple transformations (filtering, selecting)
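Whether caching pays off is simple arithmetic. Let `t_full` be the cost of an uncached action, `t_fill` the cost of materializing the cache, and `t_hit` the cost of each action on cached data. Then `n` reuses are cheaper with caching once `t_fill + n * t_hit < n * t_full`. The hypothetical helper below returns the smallest such `n`. You could plug in the timings measured in Section 2.1: `cache_load_time`, plus the per-operation averages of `no_cache_time` and `cached_time`.

```python
import math


def cache_break_even(t_fill: float, t_full: float, t_hit: float):
    """Smallest number of reuses n with t_fill + n * t_hit < n * t_full.

    Returns None when reading from the cache is not faster than recomputing.
    """
    if t_hit >= t_full:
        return None
    # n * (t_full - t_hit) > t_fill  =>  n > t_fill / (t_full - t_hit)
    return math.floor(t_fill / (t_full - t_hit)) + 1


# Example: filling the cache costs 3 s, a full recompute 1 s, a cached read 0.25 s
print(cache_break_even(t_fill=3.0, t_full=1.0, t_hit=0.25))  # 5
```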

---

In [None]:
# 💾 Section 2.1: Caching Performance Demonstration
print("🔄 Demonstrating Caching Performance Benefits...")

# Create a computationally expensive DataFrame
expensive_df = spark.range(1, 200001) \
    .withColumnRenamed("id", "transaction_id") \
    .withColumn("customer_segment", 
                when(col("transaction_id") % 10 == 0, "Premium")
                .when(col("transaction_id") % 5 == 0, "Gold") 
                .otherwise("Standard")) \
    .withColumn("complex_calc", 
                # Stand-in for an expensive derived column (real pipelines would add joins, UDFs, etc.)
                sqrt(col("transaction_id")) * sin(col("transaction_id") / 1000) + 
                cos(col("transaction_id") / 500)) \
    .withColumn("amount", (rand() * 2000 + 100).cast("decimal(10,2)"))

print(f"📊 Expensive DataFrame created: {expensive_df.count():,} records")

# Test 1: Without caching - multiple operations
print("\n🚫 Performance WITHOUT Caching:")
start_time = time.time()

# First operation
result1 = expensive_df.groupBy("customer_segment").agg(
    count("*").alias("count"),
    avg("amount").alias("avg_amount")
).collect()

# Second operation  
result2 = expensive_df.filter(col("amount") > 1000).count()

# Third operation
result3 = expensive_df.agg(
    sum("complex_calc").alias("total_calc"),
    max("amount").alias("max_amount")
).collect()[0]

no_cache_time = time.time() - start_time
print(f"   Total time (3 operations): {no_cache_time:.3f}s")

# Test 2: With caching - same operations
print("\n✅ Performance WITH Caching:")
cached_df = expensive_df.cache()

# Trigger caching with first action
cache_start = time.time()
cached_count = cached_df.count()
cache_load_time = time.time() - cache_start

# Now run the same operations
start_time = time.time()

result1_cached = cached_df.groupBy("customer_segment").agg(
    count("*").alias("count"), 
    avg("amount").alias("avg_amount")
).collect()

result2_cached = cached_df.filter(col("amount") > 1000).count()

result3_cached = cached_df.agg(
    sum("complex_calc").alias("total_calc"),
    max("amount").alias("max_amount")
).collect()[0]

cached_time = time.time() - start_time
total_cached_time = cache_load_time + cached_time

print(f"   Cache loading time: {cache_load_time:.3f}s")
print(f"   Operations time: {cached_time:.3f}s")
print(f"   Total time: {total_cached_time:.3f}s")

# Performance comparison
print(f"\n⚡ Performance Improvement:")
if total_cached_time < no_cache_time:
    improvement = ((no_cache_time - total_cached_time) / no_cache_time) * 100
    print(f"   Speedup: {improvement:.1f}% faster with caching")
else:
    overhead = ((total_cached_time - no_cache_time) / no_cache_time) * 100
    print(f"   Overhead: {overhead:.1f}% slower (cache loading cost)")

print(f"   Break-even: caching pays off once the DataFrame is reused enough times to amortize the cache-loading cost")

# Display cache statistics
print(f"\n📈 Cache Statistics:")
print(f"   Dataset cached: {cached_df.is_cached}")
print(f"   Storage level: {cached_df.storageLevel}")
print(f"   Records cached: {cached_count:,}")

# Show results verification
print(f"\n✅ Results Verification:")
print(f"   High-value transactions: {result2:,} (no cache) vs {result2_cached:,} (cached)")
print(f"   Results match: {result2 == result2_cached}")

In [None]:
# 💾 Section 2.2: Storage Levels Comparison
print("🔍 Comparing Different Storage Levels...")

# Create test dataset
test_df = spark.range(1, 50001) \
    .withColumnRenamed("id", "record_id") \
    .withColumn("data", concat(lit("DATA_"), col("record_id").cast("string"))) \
    .withColumn("value", (rand() * 1000).cast("decimal(10,2)")) \
    .withColumn("category", when(col("record_id") % 3 == 0, "A")
                .when(col("record_id") % 3 == 1, "B")
                .otherwise("C"))

print(f"📊 Test dataset: {test_df.count():,} records")

# Test different storage levels.
# PySpark has no MEMORY_ONLY_SER (Scala/Java only), so MEMORY_AND_DISK_DESER is compared instead.
storage_levels = {
    "MEMORY_ONLY": StorageLevel.MEMORY_ONLY,
    "MEMORY_AND_DISK": StorageLevel.MEMORY_AND_DISK,
    "MEMORY_AND_DISK_DESER": StorageLevel.MEMORY_AND_DISK_DESER,
    "DISK_ONLY": StorageLevel.DISK_ONLY
}

results = {}

for level_name, storage_level in storage_levels.items():
    print(f"\n🔍 Testing {level_name}:")
    
    # Persist the DataFrame with the storage level under test
    test_cached = test_df.persist(storage_level)
    
    # Time the caching operation (the first action materializes the cache)
    start_time = time.time()
    n_rows = test_cached.count()  # not `count`: that name would shadow pyspark's count()
    cache_time = time.time() - start_time
    
    # Time a simple operation  
    start_time = time.time()
    agg_result = test_cached.groupBy("category").count().collect()
    operation_time = time.time() - start_time
    
    results[level_name] = {
        "cache_time": cache_time,
        "operation_time": operation_time,
        "storage_level": storage_level
    }
    
    print(f"   Cache time: {cache_time:.3f}s")
    print(f"   Operation time: {operation_time:.3f}s")
    print(f"   Storage level: {storage_level}")
    
    # Unpersist (blocking) so the next storage level starts from an empty cache
    test_cached.unpersist(blocking=True)

# Summary comparison
print(f"\n📊 Storage Level Performance Summary:")
print(f"{'Level':<24} {'Cache Time':<12} {'Op Time':<10} {'Total':<10}")
print("-" * 59)

for level_name, metrics in results.items():
    total_time = metrics["cache_time"] + metrics["operation_time"]
    print(f"{level_name:<24} {metrics['cache_time']:<12.3f} {metrics['operation_time']:<10.3f} {total_time:<10.3f}")

# Recommendations
print(f"\n💡 Storage Level Recommendations:")
print(f"   🚀 MEMORY_ONLY: Fastest when the data fits in memory; partitions that don't fit are recomputed")
print(f"   ⚖️  MEMORY_AND_DISK: Good default; spills partitions that don't fit to disk")
print(f"   💾 MEMORY_AND_DISK_DESER: PySpark's default for DataFrame.cache() in Spark 3.x")
print(f"   🐌 DISK_ONLY: Slowest access, but avoids memory pressure for very large datasets")

# Cache management best practices
print(f"\n🎯 Cache Management Best Practices:")
print(f"   1. Monitor memory usage with Spark UI")
print(f"   2. Unpersist DataFrames when no longer needed")
print(f"   3. Use broadcast for small lookup tables")
print(f"   4. Consider serialization for memory-constrained environments")
print(f"   5. Test different storage levels for your use case")
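Best practice 2 (unpersist when no longer needed) is easy to forget when a cell fails halfway. One option is a small context manager that scopes a cache to a `with` block and always calls `unpersist()`, even on errors. `persisted` is a hypothetical helper, not part of PySpark; it relies only on `DataFrame.persist()` and `DataFrame.unpersist()`.

```python
from contextlib import contextmanager


@contextmanager
def persisted(df, storage_level=None):
    """Persist `df` for the duration of a with-block, then release it.

    With storage_level=None, df.persist() falls back to Spark's default level.
    """
    cached = df.persist() if storage_level is None else df.persist(storage_level)
    try:
        yield cached
    finally:
        # Runs on normal exit and on exceptions, so cached blocks are always released
        cached.unpersist()


# Usage in this notebook (not executed here):
# with persisted(test_df, StorageLevel.MEMORY_AND_DISK) as df:
#     df.count()                             # first action materializes the cache
#     df.groupBy("category").count().show()
```

The cache is released when the block exits. Any DataFrame derived inside the block will therefore recompute from source if used afterwards.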