# Challenge 3: Optimizing Joins with Broadcast Variables

## Task Description
In this challenge, we need to:
1. Understand when to use broadcast joins
2. Implement explicit and automatic broadcast joins
3. Measure performance improvements from broadcasting
4. Consider implications for larger datasets

In [None]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.types import *
import matplotlib.pyplot as plt
import time
import random

# Build (or reuse) a local SparkSession for the experiment. A small number of
# shuffle partitions keeps the shuffle-based baseline from being dominated by
# many near-empty partitions on a single machine.
spark = (
    SparkSession.builder
    .appName("Broadcast Join Optimization")
    .master("local[*]")
    .config("spark.sql.shuffle.partitions", 10)
    .getOrCreate()
)

# Quieten Spark's logging so benchmark output stays readable.
spark.sparkContext.setLogLevel("WARN")

## Generate Test Data

We'll create two datasets:
1. A large dataset of connection logs (many rows)
2. A small reference dataset of country information (few rows)

In [None]:
# Generate a larger dataset for a more meaningful performance comparison
import random
from datetime import datetime, timedelta

countries = ["US", "UK", "DE", "FR", "CN", "IN", "BR", "JP", "CA", "AU"]

# Fixed seed so repeated runs of the notebook benchmark the same data.
random.seed(42)

# Number of connection records; large enough that the join work is measurable.
num_records = 500000

# Anchor every timestamp to one reference time so all records fall in the same
# 30-day window (and avoid a clock call per generated row).
now = datetime.now()

data = []
for _ in range(num_records):
    # Countries are drawn uniformly (no skew) for this test.
    country = random.choice(countries)
    timestamp = now - timedelta(days=random.randint(0, 30),
                                hours=random.randint(0, 23),
                                minutes=random.randint(0, 59))
    data.append((
        f"user_{random.randint(1, 5000)}",  # user_id
        timestamp.isoformat(),              # timestamp
        country,                            # country
        f"10.{random.randint(0, 255)}.{random.randint(0, 255)}.{random.randint(0, 255)}",  # ip_address
        random.choice(["success", "failed"]),  # status
        random.randint(1, 100),             # duration_seconds
    ))

# Explicit schema: given only column names, createDataFrame infers the types by
# inspecting the local rows, which is slow for hundreds of thousands of
# records. These types match what inference would produce (str -> string,
# int -> long).
columns = ["user_id", "timestamp", "country", "ip_address", "status", "duration_seconds"]
connection_schema = StructType([
    StructField("user_id", StringType(), True),
    StructField("timestamp", StringType(), True),
    StructField("country", StringType(), True),
    StructField("ip_address", StringType(), True),
    StructField("status", StringType(), True),
    StructField("duration_seconds", LongType(), True),
])
connection_logs = spark.createDataFrame(data, connection_schema)

# Create a small lookup table for countries (this is the table we'll broadcast)
country_info = spark.createDataFrame([
    ("US", "United States", "North America", "English"),
    ("UK", "United Kingdom", "Europe", "English"),
    ("DE", "Germany", "Europe", "German"),
    ("FR", "France", "Europe", "French"),
    ("CN", "China", "Asia", "Chinese"),
    ("IN", "India", "Asia", "Hindi/English"),
    ("BR", "Brazil", "South America", "Portuguese"),
    ("JP", "Japan", "Asia", "Japanese"),
    ("CA", "Canada", "North America", "English/French"),
    ("AU", "Australia", "Oceania", "English")
], ["country_code", "country_name", "region", "language"])

# Cache both DataFrames; count() is the action that materializes each cache,
# and its result is reused below instead of counting again.
connection_logs.cache()
log_count = connection_logs.count()

country_info.cache()
country_count = country_info.count()

print(f"Created dataset with {log_count} connection logs and {country_count} country records")
print("Connection logs schema:")
connection_logs.printSchema()
print("\nCountry info schema:")
country_info.printSchema()

# Show sample data
print("\nSample connection logs:")
connection_logs.show(5)
print("\nCountry reference data:")
country_info.show()

## Check Dataset Sizes

It's important to verify that one dataset is much smaller than the other, making it a good candidate for broadcasting.

In [None]:
# TODO: Compare dataset sizes to confirm broadcasting is appropriate

# Get approximate sizes
from pyspark.sql import Row

def estimate_size_mb(df, sample_ratio=0.1):
    """Roughly estimate a DataFrame's size in MB from a sample of its rows.

    Each row's size is approximated by the length of its ``str()``
    representation, so the result is a coarse proxy suitable only for
    order-of-magnitude comparisons, not Spark's own size statistics.

    Args:
        df: DataFrame to measure.
        sample_ratio: Fraction of rows to sample when measuring row size.

    Returns:
        Estimated size in MB (1 MB = 1024 * 1024 bytes); 0.0 for an empty
        DataFrame.
    """
    total_rows = df.count()
    if total_rows == 0:
        return 0.0

    sampled = df.sample(withReplacement=False, fraction=sample_ratio)
    row_lengths = sampled.rdd.map(lambda row: len(str(row)))
    # Tiny tables (e.g. 10 rows sampled at 10%) can yield an empty sample;
    # measure every row instead of reporting a bogus 0 MB.
    if row_lengths.isEmpty():
        row_lengths = df.rdd.map(lambda row: len(str(row)))

    avg_row_bytes = row_lengths.mean()
    # Mean row size times the full row count is already the total size; the
    # sample ratio must not be divided out a second time.
    return avg_row_bytes * total_rows / (1024 * 1024)

import re


def _parse_spark_size_bytes(value):
    """Convert a Spark byte-size setting to a number of bytes.

    Accepts plain integers such as "10485760" or "-1" as well as
    unit-suffixed strings such as "10485760b", "10m" or "10MB". Spark 3 can
    report an unset spark.sql.autoBroadcastJoinThreshold in the suffixed form,
    which plain int() cannot parse.

    Raises:
        ValueError: if ``value`` is not a recognizable byte size.
    """
    match = re.fullmatch(r"(-?\d+)\s*([kmgtp]?)b?", str(value).strip().lower())
    if match is None:
        raise ValueError(f"Unrecognized Spark byte size: {value!r}")
    digits, unit = match.groups()
    exponent = {"": 0, "k": 1, "m": 2, "g": 3, "t": 4, "p": 5}[unit]
    return int(digits) * 1024 ** exponent


# Estimate sizes (rough, string-length based)
logs_size_mb = estimate_size_mb(connection_logs)
country_size_mb = estimate_size_mb(country_info)
size_ratio = logs_size_mb / country_size_mb if country_size_mb > 0 else float('inf')

print(f"Estimated connection logs size: {logs_size_mb:.2f} MB")
print(f"Estimated country info size: {country_size_mb:.2f} MB")
print(f"Size ratio (logs/country): {size_ratio:.2f}")

# Check against the broadcast threshold. Spark compares its own plan
# statistics (not the estimate above) to the threshold, so this check is only
# indicative. A negative threshold (-1) disables automatic broadcasting.
threshold_bytes = _parse_spark_size_bytes(
    spark.conf.get("spark.sql.autoBroadcastJoinThreshold"))
broadcast_threshold_mb = threshold_bytes / (1024 * 1024)

if threshold_bytes < 0:
    print("\nAutomatic broadcasting is currently disabled (threshold = -1)")
    print(f"The country_info DataFrame ({country_size_mb:.2f} MB) will require explicit broadcasting")
else:
    print(f"\nCurrent broadcast threshold: {broadcast_threshold_mb} MB")
    # Spark auto-broadcasts a side whose size is <= the threshold.
    if country_size_mb <= broadcast_threshold_mb:
        print(f"The country_info DataFrame ({country_size_mb:.2f} MB) is below the broadcast threshold and can be automatically broadcast")
    else:
        print(f"The country_info DataFrame ({country_size_mb:.2f} MB) is above the broadcast threshold and will require explicit broadcasting")

## Baseline Join (No Broadcasting)

First, let's perform a join with automatic broadcasting disabled and no broadcast hints, to establish a performance baseline.

In [None]:
# TODO: Perform regular join and measure performance

# Baseline setup: a threshold of -1 turns automatic broadcast joins off, so the
# join below cannot be silently converted into a broadcast join.
spark.conf.set(key="spark.sql.autoBroadcastJoinThreshold", value=-1)

# Define the benchmark function
def benchmark_join(join_function, name):
    """Build a join query, force its execution, and time it.

    Args:
        join_function: Zero-argument callable returning a DataFrame.
        name: Label used in the printed timing line.

    Returns:
        Tuple of (result DataFrame, elapsed wall-clock seconds).
    """
    # perf_counter is monotonic and high-resolution, unlike time.time(),
    # which can jump if the system clock is adjusted mid-measurement.
    start_time = time.perf_counter()
    result = join_function()
    # Spark is lazy: count() is the action that actually runs the query.
    # Named row_count to avoid shadowing pyspark's count() in this function.
    row_count = result.count()
    execution_time = time.perf_counter() - start_time

    print(f"{name} completed in {execution_time:.2f} seconds with {row_count} results")
    return result, execution_time

# Regular join function (no broadcast hint)
def regular_join():
    """Join logs to country info and aggregate connections per country/region.

    No broadcast hint is given, so the join strategy Spark picks depends on
    the spark.sql.autoBroadcastJoinThreshold in effect when the query is
    planned.
    """
    join_condition = connection_logs.country == country_info.country_code
    joined = connection_logs.join(country_info, join_condition)
    return joined.groupBy("country_name", "region").agg(
        count("*").alias("connection_count"),
        avg("duration_seconds").alias("avg_duration"),
    )

# Time the baseline (shuffle) join and display its aggregated rows.
regular_result, regular_time = benchmark_join(
    regular_join,
    "Regular join (no broadcast)",
)
regular_result.show()

## Explicit Broadcast Join

Now let's use the `broadcast` hint to explicitly broadcast the smaller dataframe.

In [None]:
# TODO: Implement explicit broadcasting and measure performance
from pyspark.sql.functions import broadcast

# Explicit broadcast join function
def explicit_broadcast_join():
    """Same aggregation as the baseline, with country_info marked for broadcast.

    broadcast() attaches a hint asking Spark to ship country_info to every
    executor instead of shuffling both sides.
    """
    hinted_countries = broadcast(country_info)
    join_condition = connection_logs.country == country_info.country_code
    joined = connection_logs.join(hinted_countries, join_condition)
    return joined.groupBy("country_name", "region").agg(
        count("*").alias("connection_count"),
        avg("duration_seconds").alias("avg_duration"),
    )

# Time the hinted broadcast join and display its aggregated rows.
explicit_result, explicit_time = benchmark_join(
    explicit_broadcast_join,
    "Explicit broadcast join",
)
explicit_result.show()

## Automatic Broadcast Join

Finally, let's re-enable automatic broadcasting and see if Spark chooses to broadcast without explicit hints.

In [None]:
# TODO: Enable automatic broadcasting and measure performance

# Restore automatic broadcasting with a 10 MB (10 * 1024 * 1024 bytes) threshold.
spark.conf.set(key="spark.sql.autoBroadcastJoinThreshold", value=10 * 1024 * 1024)

# Automatic broadcast join function: the exact query of regular_join; only the
# auto-broadcast threshold in effect when it runs differs.
def auto_broadcast_join():
    """Run the baseline query again, letting Spark decide whether to broadcast.

    Delegates to regular_join() so both benchmarks execute the same query and
    the only variable is spark.sql.autoBroadcastJoinThreshold.
    """
    return regular_join()

# Run automatic broadcast join benchmark
auto_result, auto_time = benchmark_join(auto_broadcast_join, "Automatic broadcast join")
auto_result.show()

## Examine Query Plans

Let's look at the query plans to confirm whether broadcasting is occurring.

In [None]:
# TODO: Compare execution plans

# Spark chooses the join strategy when a query is planned, using the
# auto-broadcast threshold in effect at that moment. Auto-broadcast was
# re-enabled for the automatic benchmark, so the baseline and explicit-hint
# plans are explained with it disabled again (-1), matching the conditions they
# were benchmarked under. Otherwise the "regular" plan would show a broadcast
# join too. The same query functions that were benchmarked are explained here.
threshold_key = "spark.sql.autoBroadcastJoinThreshold"
previous_threshold = spark.conf.get(threshold_key)

spark.conf.set(threshold_key, -1)
try:
    # Examine regular join plan (expected: a shuffle-based join, no broadcast)
    print("\nRegular Join Plan:")
    regular_plan = regular_join()
    regular_plan.explain()

    # Examine explicit broadcast join plan: the broadcast() hint is honoured
    # even with automatic broadcasting disabled.
    print("\nExplicit Broadcast Join Plan:")
    explicit_plan = explicit_broadcast_join()
    explicit_plan.explain()
finally:
    # Restore whatever threshold was active before this cell.
    spark.conf.set(threshold_key, previous_threshold)

# Examine automatic broadcast join plan under the restored threshold
print("\nAutomatic Broadcast Join Plan:")
auto_plan = auto_broadcast_join()
auto_plan.explain()

## Performance Comparison

In [None]:
# TODO: Visualize and compare performance
import builtins

# Plot performance comparison
join_types = ['Regular Join', 'Explicit Broadcast', 'Automatic Broadcast']
execution_times = [regular_time, explicit_time, auto_time]

# Calculate performance improvements (positive = faster than the baseline)
explicit_improvement = ((regular_time - explicit_time) / regular_time) * 100
auto_improvement = ((regular_time - auto_time) / regular_time) * 100

print(f"Explicit broadcast join improvement: {explicit_improvement:.2f}%")
print(f"Automatic broadcast join improvement: {auto_improvement:.2f}%")

# Create bar chart
plt.figure(figsize=(10, 6))
bars = plt.bar(join_types, execution_times, color=['red', 'green', 'blue'])

# Add labels on top of bars
for bar in bars:
    height = bar.get_height()
    plt.text(bar.get_x() + bar.get_width() / 2., height + 0.02,
             f'{height:.2f}s', ha='center', va='bottom')

plt.ylabel('Execution Time (seconds)')
plt.title('Join Performance Comparison')
# `from pyspark.sql.functions import *` replaces the built-in max() with
# Spark's column aggregate, which rejects a Python list; call the builtin.
plt.ylim(0, builtins.max(execution_times) * 1.2)  # Add some headroom for labels
plt.grid(axis='y', linestyle='--', alpha=0.7)
plt.show()

## Discussion: When to Use Broadcasting

Broadcast joins are effective when:

1. One dataset is significantly smaller than the other
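2. The small dataset fits in memory on the driver (which collects it first) and on each executor
3. The join has an equality condition on the join keys, so Spark can use a broadcast hash join (cross joins and non-equi joins fall back to a broadcast nested loop join, which can be very expensive)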

The default threshold for automatic broadcasting (`spark.sql.autoBroadcastJoinThreshold`) is 10 MB. You can raise it if your driver and executors have memory to spare, or set it to -1 to disable automatic broadcasting entirely.

In [None]:
# TODO: Experiment with different broadcast thresholds

# Test with larger threshold
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", 50 * 1024 * 1024)  # 50 MB
print("Set broadcast threshold to 50 MB")

try:
    # Run the join with larger threshold
    large_threshold_result, large_threshold_time = benchmark_join(auto_broadcast_join, "Join with 50 MB threshold")

    # Explain while the 50 MB threshold is still active. Spark picks the join
    # strategy when the query is planned, so explaining after the reset would
    # show the 10 MB plan instead. country_info is far below both thresholds,
    # so this plan is expected to match the 10 MB one.
    print("\nJoin Plan with 50 MB Threshold:")
    connection_logs \
        .join(country_info, connection_logs.country == country_info.country_code) \
        .groupBy("country_name", "region") \
        .agg(count("*").alias("connection_count")) \
        .explain()
finally:
    # Reset to the 10 MB default even if the benchmark above fails.
    spark.conf.set("spark.sql.autoBroadcastJoinThreshold", 10 * 1024 * 1024)  # 10 MB

## Advanced: Broadcast Join Trade-offs

While broadcast joins can significantly improve performance, they come with limitations:

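1. Memory limitations: The broadcast table must fit in memory on the driver and on every executor (Spark also refuses to broadcast tables larger than 8 GB)
2. Broadcast overhead: The table is first collected to the driver and then sent to all executors, which costs driver time and network bandwidth
3. Not suitable for frequent updates: The broadcast is rebuilt each time a query runs, so tables that change frequently (or queries that run very often) pay the collect-and-distribute cost again and again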

In [None]:
# TODO: Discuss memory implications of broadcast joins

# Calculate memory needed for broadcasting with different dataset sizes
def calculate_broadcast_memory(row_count, row_size_bytes, executor_count,
                               overhead_factor=1.5, include_driver=False):
    """Estimate total memory (MB) consumed across the cluster by a broadcast.

    Every executor receiving the broadcast holds its own copy, so the total
    grows linearly with the number of executors.

    Args:
        row_count: Number of rows in the broadcast table.
        row_size_bytes: Average size of one row in bytes.
        executor_count: Number of executors that receive a copy.
        overhead_factor: Multiplier applied to each copy for in-memory overhead
            (e.g. hash-table structures). The default 1.5 is a rough
            assumption, not a Spark constant.
        include_driver: Also count the copy held by the driver, which collects
            the table before distributing it. Off by default to keep the
            original executor-only model.

    Returns:
        Estimated memory in MB (1 MB = 1024 * 1024 bytes).
    """
    per_copy_bytes = row_count * row_size_bytes * overhead_factor
    copies = executor_count + (1 if include_driver else 0)
    return per_copy_bytes * copies / (1024 * 1024)

# Scenario 1: the small country_info lookup table.
row_count = country_info.count()
avg_row_size = 200  # assumed average row size in bytes (not measured)
executor_counts = [2, 5, 10, 20, 50, 100]

print("Memory needed to broadcast country_info table:")
for executor_count in executor_counts:
    memory_mb = calculate_broadcast_memory(row_count, avg_row_size, executor_count)
    print(f"  With {executor_count} executors: {memory_mb:.2f} MB")

# Scenario 2: a hypothetical 1M-row dimension table, reported in GB.
large_dim_rows, large_row_size = 1_000_000, 500  # rows, bytes per row

print("\nMemory needed to broadcast a large dimension table (1M rows):")
for executor_count in executor_counts:
    memory_mb = calculate_broadcast_memory(large_dim_rows, large_row_size, executor_count)
    memory_gb = memory_mb / 1024
    print(f"  With {executor_count} executors: {memory_gb:.2f} GB")