# Challenge 2: Resolving Data Skew with Salting Techniques

## Task Description
In this challenge, we need to:
1. Implement key salting to distribute skewed data more evenly
2. Compare performance before and after salting
3. Apply the correct salt factor based on skew severity
4. Maintain data integrity while resolving skew

In [None]:
import builtins
import random
import time

import matplotlib.pyplot as plt
from pyspark.sql import SparkSession
# NOTE: the star import below shadows Python built-ins such as max, min, sum,
# abs and hash with Spark column functions. Use `builtins.<name>` whenever the
# plain Python version is needed.
from pyspark.sql.functions import *
from pyspark.sql.types import *

# Build (or reuse) a local Spark session with a small, fixed number of
# shuffle partitions.
spark = (
    SparkSession.builder
    .appName("Data Skew Resolution")
    .master("local[*]")
    .config("spark.sql.shuffle.partitions", 10)
    .getOrCreate()
)

# Only show warnings and errors from Spark in the notebook output.
spark.sparkContext.setLogLevel("WARN")

## Load Skewed Connection Log Data

Let's recreate the skewed connection logs from Challenge 1 by generating synthetic data in which one country accounts for most of the records:

In [None]:
# Generate skewed data similar to Challenge 1
from datetime import datetime, timedelta

countries = ["US", "UK", "DE", "FR", "CN", "IN", "BR", "JP", "CA", "AU"]

# Number of synthetic connection records to generate
num_records = 100000

skewed_country = "US"  # This country will have most of the records
skew_percentage = 0.7  # 70% of records will be for this country

# Loop invariants, computed once instead of on every iteration:
# - the non-skewed countries to sample from
# - a single reference instant, so every timestamp is an offset from the
#   same "now" rather than from a clock that drifts while the loop runs
other_countries = [c for c in countries if c != skewed_country]
reference_time = datetime.now()

data = []

for _ in range(num_records):
    # Determine country with skew
    if random.random() < skew_percentage:
        country = skewed_country
    else:
        country = random.choice(other_countries)

    # Pick a random moment within roughly the last month
    timestamp = reference_time - timedelta(
        days=random.randint(0, 30),
        hours=random.randint(0, 23),
        minutes=random.randint(0, 59),
    )

    data.append((
        f"user_{random.randint(1, 1000)}",  # user_id
        timestamp.isoformat(),              # timestamp
        country,                            # country
        f"10.{random.randint(0, 255)}.{random.randint(0, 255)}.{random.randint(0, 255)}",  # ip_address
        random.choice(["success", "failed"]),  # status
        random.randint(1, 100),             # duration_seconds
    ))

# Turn the generated tuples into a Spark DataFrame
columns = ["user_id", "timestamp", "country", "ip_address", "status", "duration_seconds"]
connection_logs = spark.createDataFrame(data, columns)

# Lookup table with one descriptive row per country code
country_rows = [
    ("US", "United States", "North America", "English"),
    ("UK", "United Kingdom", "Europe", "English"),
    ("DE", "Germany", "Europe", "German"),
    ("FR", "France", "Europe", "French"),
    ("CN", "China", "Asia", "Chinese"),
    ("IN", "India", "Asia", "Hindi/English"),
    ("BR", "Brazil", "South America", "Portuguese"),
    ("JP", "Japan", "Asia", "Japanese"),
    ("CA", "Canada", "North America", "English/French"),
    ("AU", "Australia", "Oceania", "English"),
]
country_columns = ["country_code", "country_name", "region", "language"]
country_info = spark.createDataFrame(country_rows, country_columns)

# Mark both DataFrames for caching; later cells reuse them several times
for frame in (connection_logs, country_info):
    frame.cache()

record_total = connection_logs.count()
print(f"Created {record_total} sample records with skewed distribution")
connection_logs.show(5)

## Baseline Performance (Without Salting)

Let's measure the performance without any skew mitigation:

In [None]:
# Define a benchmark operation that will be affected by skew
def benchmark_join_operation(logs_df, country_df, description):
    """
    Time a join of connection logs with country metadata followed by a
    per-country aggregation.

    Args:
        logs_df: DataFrame of connection logs; joined on its ``country`` column
        country_df: DataFrame of country metadata; joined on ``country_code``
            and grouped by ``country_name`` and ``region``
        description: Label printed in front of the measured time

    Returns:
        Tuple of (aggregated result DataFrame, elapsed time in seconds).
        The returned DataFrame is not cached.
    """
    # perf_counter is monotonic and high-resolution, so the measurement is
    # not distorted by system clock adjustments the way time.time() can be.
    start_time = time.perf_counter()

    # Perform join and aggregation
    result = logs_df \
        .join(country_df, logs_df.country == country_df.country_code) \
        .groupBy("country_name", "region") \
        .agg(count("*").alias("connection_count"))

    # Force execution (Spark transformations are lazy until an action runs)
    result.collect()

    execution_time = time.perf_counter() - start_time
    print(f"{description}: {execution_time:.2f} seconds")

    return result, execution_time

# Measure the unsalted join so later cells have a reference time
baseline_label = "Baseline join time (without salting)"
baseline_result, baseline_time = benchmark_join_operation(
    logs_df=connection_logs,
    country_df=country_info,
    description=baseline_label,
)

# Display the per-country connection counts
baseline_result.show()

## Implement Key Salting

In [None]:
# Identify the skewed key and derive a salt factor from how severe the skew is

# First identify the skewed keys
country_distribution = connection_logs \
    .groupBy("country") \
    .count() \
    .orderBy(desc("count"))

country_distribution.show()

# Get the most skewed country and count
skewed_countries = country_distribution.collect()
if not skewed_countries:
    raise ValueError("connection_logs is empty; cannot determine a skewed key")

most_skewed_country = skewed_countries[0]["country"]
most_skewed_count = skewed_countries[0]["count"]

# The per-key counts already add up to the total, so no extra Spark job is
# needed. `builtins.sum` is used because the star import from
# pyspark.sql.functions replaces the plain `sum` with a column function.
total_records = builtins.sum(row["count"] for row in skewed_countries)

print(f"Most skewed country: {most_skewed_country} with {most_skewed_count} records")
print(f"Percentage: {(most_skewed_count / total_records) * 100:.2f}%")

# Determine appropriate salt factor based on skew: how many times larger the
# hottest key is than it would be under a perfectly even distribution across
# the keys actually present in the data.
distinct_key_count = len(skewed_countries)
skew_ratio = most_skewed_count / (total_records / distinct_key_count)

# Limit between 2 and 10. `builtins.max` / `builtins.min` are required here:
# the unqualified names refer to Spark's single-argument column aggregates
# and would raise a TypeError when called with two numbers.
salt_factor = builtins.max(2, builtins.min(10, int(skew_ratio)))

print(f"Using salt factor: {salt_factor} based on skew ratio: {skew_ratio:.2f}")

# TODO: Define a function to add salting to skewed keys
def salt_skewed_key(df, skewed_key_col, skewed_key_value, salt_factor, seed=None):
    """
    Add salting to skewed keys by creating multiple virtual keys.

    Rows whose key equals ``skewed_key_value`` get a ``salted_key`` of the
    form ``<key>_<n>`` with ``n`` drawn at random from ``0..salt_factor-1``.
    All other rows keep their original key as ``salted_key``.

    Args:
        df: Input DataFrame
        skewed_key_col: Column name containing the skewed key
        skewed_key_value: Value of the key that is skewed
        salt_factor: Number of virtual keys to create (must be >= 1)
        seed: Optional seed for the random salt, for reproducible salting.
            When None (the default) Spark picks the seed.

    Returns:
        DataFrame with salted keys

    Raises:
        ValueError: If salt_factor is less than 1
    """
    if salt_factor < 1:
        raise ValueError(f"salt_factor must be >= 1, got {salt_factor}")

    key = col(skewed_key_col)

    # rand() is uniform in [0, 1), so truncating rand() * salt_factor to an
    # integer yields a salt value in 0..salt_factor-1.
    salt_suffix = (rand(seed) * salt_factor).cast("int").cast("string")

    return df.withColumn(
        "salted_key",
        when(
            key == skewed_key_value,
            concat(key, lit("_"), salt_suffix)
        ).otherwise(key)
    )

# Apply salting to the connection logs
salted_logs = salt_skewed_key(connection_logs, "country", most_skewed_country, salt_factor)

# The lookup side of the join must contain one row per salted key,
# otherwise salted log rows would find no match and be dropped by the join.
skewed_country_row = country_info.filter(col("country_code") == most_skewed_country).first()
if skewed_country_row is None:
    raise ValueError(
        f"Skewed country {most_skewed_country!r} has no row in country_info"
    )

# For the skewed country, create multiple rows with salted keys
salt_rows = [
    (
        f"{most_skewed_country}_{i}",
        skewed_country_row["country_name"],
        skewed_country_row["region"],
        skewed_country_row["language"],
    )
    for i in range(salt_factor)
]

# Create DataFrame with salt rows
salt_df = spark.createDataFrame(salt_rows, country_info.schema)

# Remove the original skewed country row and union with salted rows
expanded_country_info = country_info.filter(col("country_code") != most_skewed_country) \
    .union(salt_df)

# Show the salted data
print("Sample of salted connection logs:")
salted_logs.select("country", "salted_key").show(10)

print("Expanded country info:")
expanded_country_info.show()

## Check Distribution After Salting

In [None]:
# Verify that salting has improved the distribution
salted_distribution = salted_logs \
    .groupBy("salted_key") \
    .count() \
    .orderBy(desc("count"))

salted_distribution.show(20)

# Visualize the before and after distribution
country_data = country_distribution.collect()
salted_data = salted_distribution.collect()

# Extract data for plotting
original_keys = [row["country"] for row in country_data]
original_counts = [row["count"] for row in country_data]

salted_keys = [row["salted_key"] for row in salted_data]
salted_counts = [row["count"] for row in salted_data]

# Create a side-by-side bar chart. The y-axis is shared so that bar heights
# in the two panels are directly comparable; with independent scales the
# salted chart would look just as skewed as the original at a glance.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6), sharey=True)

ax1.bar(original_keys, original_counts)
ax1.set_title('Original Distribution')
ax1.set_xlabel('Country')
ax1.set_ylabel('Count')
ax1.tick_params(axis='x', rotation=45)

ax2.bar(salted_keys, salted_counts)
ax2.set_title('Salted Distribution')
ax2.set_xlabel('Salted Key')
ax2.set_ylabel('Count')
ax2.tick_params(axis='x', rotation=90)

plt.tight_layout()
plt.show()

## Salted Join Performance

In [None]:
# TODO: Measure performance with salting

# Define a function for salted join
def salted_join_benchmark(salted_logs, expanded_country_info, description="Salted join time"):
    """
    Time a join on the salted key followed by a per-country aggregation.

    Args:
        salted_logs: DataFrame of connection logs with a ``salted_key`` column
        expanded_country_info: Country metadata whose ``country_code`` column
            holds the salted key variants to join against
        description: Label printed in front of the measured time, so that
            different salting strategies can be told apart in the output

    Returns:
        Tuple of (aggregated result DataFrame, elapsed time in seconds).
        The returned DataFrame is not cached.
    """
    # perf_counter is monotonic and high-resolution, so the measurement is
    # not distorted by system clock adjustments the way time.time() can be.
    start_time = time.perf_counter()

    # Join on the salted key. Grouping by country_name/region folds the
    # salted variants of one country back into a single output row.
    result = salted_logs \
        .join(expanded_country_info, salted_logs.salted_key == expanded_country_info.country_code) \
        .groupBy("country_name", "region") \
        .agg(count("*").alias("connection_count"))

    # Force execution (Spark transformations are lazy until an action runs)
    result.collect()

    execution_time = time.perf_counter() - start_time
    print(f"{description}: {execution_time:.2f} seconds")

    return result, execution_time

# Run the salted join benchmark
salted_result, salted_time = salted_join_benchmark(salted_logs, expanded_country_info)

# Compare results
print("\nResults verification (should match):")
print("Baseline result:")
baseline_result.show()

print("Salted result:")
salted_result.show()

# Check data integrity programmatically instead of relying on reading the
# two tables by eye. Row objects are tuples, so the collected rows can be
# compared as sets regardless of output order.
baseline_rows = set(baseline_result.collect())
salted_rows = set(salted_result.collect())
if baseline_rows == salted_rows:
    print("Verification passed: salted result matches baseline")
else:
    print("WARNING: salted result differs from baseline")
    print(f"  Only in baseline: {sorted(baseline_rows - salted_rows)}")
    print(f"  Only in salted:   {sorted(salted_rows - baseline_rows)}")

# Calculate improvement
improvement = ((baseline_time - salted_time) / baseline_time) * 100
print(f"\nPerformance improvement: {improvement:.2f}%")

## Additional Salting Approaches

In [None]:
# TODO: Implement an alternative salting approach
# Instead of random distribution, use a deterministic salt based on a different field

def deterministic_salt(df, skewed_key_col, skewed_key_value, secondary_field, salt_factor):
    """
    Add salting to skewed keys based on a hash of another column.
    This ensures consistent salting when data is reprocessed.

    Rows whose key equals ``skewed_key_value`` get a ``salted_key`` of the
    form ``<key>_<n>`` where ``n`` is the hash of ``secondary_field`` reduced
    to ``0..salt_factor-1``. All other rows keep their original key.

    Args:
        df: Input DataFrame
        skewed_key_col: Column name containing the skewed key
        skewed_key_value: Value of the key that is skewed
        secondary_field: Another column to use for consistent salting
        salt_factor: Number of virtual keys to create (must be >= 1)

    Returns:
        DataFrame with deterministically salted keys

    Raises:
        ValueError: If salt_factor is less than 1
    """
    if salt_factor < 1:
        raise ValueError(f"salt_factor must be >= 1, got {salt_factor}")

    key = col(skewed_key_col)

    # `hash` is the Spark column function brought in by the star import from
    # pyspark.sql.functions, not Python's built-in hash().
    hashed = hash(col(secondary_field))

    # Non-negative modulo. abs(hash) % n is avoided because abs() cannot
    # represent the absolute value of the minimum 32-bit integer, so the
    # result is not guaranteed to be non-negative. A negative suffix would
    # match no row on the lookup side and the record would be lost in the join.
    bucket = ((hashed % salt_factor) + salt_factor) % salt_factor

    return df.withColumn(
        "salted_key",
        when(
            key == skewed_key_value,
            concat(key, lit("_"), bucket.cast("string"))
        ).otherwise(key)
    )

# Salt the skewed country using a hash of user_id as the salt source
det_salted_logs = deterministic_salt(
    df=connection_logs,
    skewed_key_col="country",
    skewed_key_value=most_skewed_country,
    secondary_field="user_id",
    salt_factor=salt_factor,
)

# Inspect how many rows land on each salted key, largest first
det_salted_distribution = (
    det_salted_logs
    .groupBy("salted_key")
    .count()
    .orderBy(desc("count"))
)
det_salted_distribution.show(20)

# Time the join using the deterministically salted keys
det_salted_result, det_salted_time = salted_join_benchmark(
    det_salted_logs,
    expanded_country_info,
)

# Express the time saved as a percentage of the baseline time
time_saved = baseline_time - det_salted_time
det_improvement = (time_saved / baseline_time) * 100
print(f"\nDeterministic salting improvement: {det_improvement:.2f}%")

## Performance Comparison

In [None]:
# Visualize the performance differences

approaches = ['Baseline (No Salting)', 'Random Salting', 'Deterministic Salting']
times = [baseline_time, salted_time, det_salted_time]

plt.figure(figsize=(10, 6))
bars = plt.bar(approaches, times, color=['red', 'green', 'blue'])

# Add labels on top of bars
for bar in bars:
    height = bar.get_height()
    plt.text(bar.get_x() + bar.get_width()/2., height + 0.02,
            f'{height:.2f}s', ha='center', va='bottom')

plt.ylabel('Execution Time (seconds)')
plt.title('Performance Comparison of Skew Handling Approaches')

# Add some headroom for labels. `builtins.max` is required: the unqualified
# `max` is Spark's column aggregate (from the pyspark.sql.functions star
# import) and fails when given a Python list of floats.
plt.ylim(0, builtins.max(times) * 1.2)
plt.grid(axis='y', linestyle='--', alpha=0.7)
plt.show()