# Challenge 1: Detecting Data Skew in Connection Logs

## Task Description
In this challenge, we need to:
1. Analyze connection log data to identify skew
2. Detect imbalanced distributions across countries/regions
3. Visualize and quantify the skew
4. Understand the impact on processing performance

In [None]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.types import *
import matplotlib.pyplot as plt
import time

# Build (or reuse) the Spark session for this notebook.
# - master "local[*]" runs Spark inside this process
# - shuffle partitions are set to 10 for every shuffle stage
spark = (
    SparkSession.builder
    .appName("Data Skew Detection")
    .master("local[*]")
    .config("spark.sql.shuffle.partitions", 10)
    .getOrCreate()
)

# Keep the driver log quiet: only warnings and errors are emitted.
spark.sparkContext.setLogLevel("WARN")

## Load Connection Log Data

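First, let's load the VPN connection logs from our PostgreSQL database. If the database or table is unavailable, the cell below falls back to generating sample data that is deliberately skewed toward one country: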

In [None]:
# JDBC endpoint of the PostgreSQL "datamart" database
jdbc_url = "jdbc:postgresql://postgres:5432/datamart"

# Credentials and driver class handed to the Spark JDBC reader.
# NOTE(review): credentials are hard-coded here; fine for a local
# exercise, but they should come from configuration elsewhere.
connection_properties = dict(
    user="spark",
    password="spark",
    driver="org.postgresql.Driver",
)

# TODO: Load connection logs from PostgreSQL
# If your table doesn't exist yet, create sample data instead

# Option 1: Load from database if table exists
try:
    connection_logs = (
        spark.read
        .format("jdbc")
        .option("url", jdbc_url)
        .option("dbtable", "raw.connection_logs")
        .option("user", connection_properties["user"])
        .option("password", connection_properties["password"])
        .option("driver", connection_properties["driver"])
        .load()
    )

    # count() is an action, so any read failure that load() did not raise
    # surfaces here, inside the try, rather than in a later cell.
    print(f"Loaded {connection_logs.count()} records from database")

except Exception as exc:
    # Catch Exception rather than using a bare except so that
    # KeyboardInterrupt/SystemExit still stop the notebook, and report
    # what actually failed: a missing driver or an unreachable host would
    # otherwise be mislabelled as a missing table.
    print(f"Could not load from database ({type(exc).__name__})")
    print("Table not found, creating sample data instead")

    # Option 2: Create sample data with skew
    # We'll create data with heavy skew toward certain countries
    import random
    from datetime import datetime, timedelta

    countries = ["US", "UK", "DE", "FR", "CN", "IN", "BR", "JP", "CA", "AU"]

    # Generate connection records with skew
    num_records = 100000

    skewed_country = "US"  # This country will have most of the records
    skew_percentage = 0.7  # 70% of records will be for this country

    # Built once: the candidates for the remaining ~30% of records.
    # (Previously this list was rebuilt on every loop iteration.)
    other_countries = [c for c in countries if c != skewed_country]

    data = []
    for _ in range(num_records):
        # Determine country with skew
        if random.random() < skew_percentage:
            country = skewed_country
        else:
            country = random.choice(other_countries)

        # Random moment within roughly the last 31 days
        timestamp = datetime.now() - timedelta(
            days=random.randint(0, 30),
            hours=random.randint(0, 23),
            minutes=random.randint(0, 59),
        )

        data.append((
            f"user_{random.randint(1, 1000)}",  # user_id
            timestamp.isoformat(),              # timestamp (ISO-8601 string)
            country,                            # country
            f"10.{random.randint(0, 255)}.{random.randint(0, 255)}.{random.randint(0, 255)}",  # ip_address
            random.choice(["success", "failed"]),  # status
            random.randint(1, 100),             # duration_seconds
        ))

    # Create DataFrame; column types are inferred from the Python values
    columns = ["user_id", "timestamp", "country", "ip_address", "status", "duration_seconds"]
    connection_logs = spark.createDataFrame(data, columns)

    print(f"Created {connection_logs.count()} sample records with skewed distribution")

# Cache for better performance: the following cells reuse this DataFrame
connection_logs.cache()
connection_logs.show(5)

## Detect Skew by Country

In [None]:
# TODO: Analyze data distribution by country
# Number of records per country, most frequent country first
country_distribution = (
    connection_logs
    .groupBy("country")
    .count()
    .orderBy(desc("count"))
)

# Calculate skew metrics: each country's share of all records, in percent
# rounded to two decimals. `round` here is the Spark column function
# brought in by the star import, not the Python builtin.
total_records = connection_logs.count()
share_of_total = col("count") / lit(total_records) * 100
country_distribution_with_pct = country_distribution.withColumn(
    "percentage", round(share_of_total, 2)
)

country_distribution_with_pct.show()

## Visualize the Skew

In [None]:
# TODO: Create a bar chart showing record count by country
# Bring the (small) per-country aggregate to the driver for plotting
plot_data = country_distribution.collect()

# Split the rows into parallel x / y lists, keeping the sorted order
countries = []
counts = []
for row in plot_data:
    countries.append(row["country"])
    counts.append(row["count"])

# Plot: one bar per country, bar height = number of records
plt.figure(figsize=(12, 6))
plt.bar(countries, counts)
plt.xlabel('Country')
plt.ylabel('Number of Records')
plt.title('Connection Log Distribution by Country')
plt.xticks(rotation=45)  # tilt the country labels
plt.tight_layout()
plt.show()

## Detect Partition Skew

In [None]:
# TODO: Analyze how the skew affects partitioning
# Number of partitions to distribute the country key across
num_partitions = 8

# Repartition on the skewed key so all rows of a country share a partition
skewed_partitions = connection_logs.repartition(num_partitions, "country")

# Tag each row with the partition it landed in, then count rows per
# (partition, country) pair, listing the heaviest country of each
# partition first
rows_with_partition = skewed_partitions.withColumn(
    "partition_id", spark_partition_id()
)
partition_counts = (
    rows_with_partition
    .groupBy("partition_id", "country")
    .count()
    .orderBy("partition_id", desc("count"))
)

# Print up to num_partitions * 3 rows of the breakdown
partition_counts.show(num_partitions * 3)

## Measure Performance Impact of Skew

In [None]:
# TODO: Measure execution time for an operation on skewed data
from pyspark.sql.window import Window

# Small dimension table to join against: one row per country code
country_info = spark.createDataFrame([
    ("US", "United States", "North America", "English"),
    ("UK", "United Kingdom", "Europe", "English"),
    ("DE", "Germany", "Europe", "German"),
    ("FR", "France", "Europe", "French"),
    ("CN", "China", "Asia", "Chinese"),
    ("IN", "India", "Asia", "Hindi/English"),
    ("BR", "Brazil", "South America", "Portuguese"),
    ("JP", "Japan", "Asia", "Japanese"),
    ("CA", "Canada", "North America", "English/French"),
    ("AU", "Australia", "Oceania", "English"),
], ["country_code", "country_name", "region", "language"])

# Time a standard join on the skewed key.
# perf_counter() is used instead of time.time() because it is monotonic
# and high-resolution, so the measured interval cannot be distorted by
# system clock adjustments.
# NOTE(review): country_info is tiny, so Spark may plan this as a
# broadcast join, which would hide most of the join-side skew; check
# skewed_join.explain() to confirm which join strategy is measured.
start_time = time.perf_counter()

skewed_join = (
    connection_logs
    .join(country_info, connection_logs.country == country_info.country_code)
    .groupBy("country_name", "region")
    .agg(count("*").alias("connection_count"))
)

# Force execution: transformations are lazy, collect() runs the job
skewed_join.collect()

skewed_join_time = time.perf_counter() - start_time
print(f"Time to perform join on skewed data: {skewed_join_time:.2f} seconds")

# Display results (this runs the query again; it is outside the timed span)
skewed_join.show()

## Quantify Skew Using the Coefficient of Variation and Max/Min Ratio

In [None]:
# TODO: Calculate a skew factor using statistical methods
import math

# Collect partition statistics: rows per partition, then summary figures
# across partitions. avg/stddev/min/max are the Spark column functions
# from the star import, not the Python builtins.
# NOTE: partitions that received no rows produce no group here, so these
# figures describe the non-empty partitions only.
partition_stats = (
    skewed_partitions
    .withColumn("partition_id", spark_partition_id())
    .groupBy("partition_id")
    .count()
    .agg(
        avg("count").alias("avg_records_per_partition"),
        stddev("count").alias("stddev_records"),
        min("count").alias("min_records"),
        max("count").alias("max_records"),
    )
)

# Coefficient of variation (stddev / mean) and max/min ratio as measures
# of skew
skew_stats = (
    partition_stats
    .withColumn(
        "coefficient_of_variation",
        col("stddev_records") / col("avg_records_per_partition"),
    )
    .withColumn("max_to_min_ratio", col("max_records") / col("min_records"))
)

skew_stats.show()

# Get values for analysis (the aggregate always yields exactly one row)
stats = skew_stats.collect()[0]
cv = stats["coefficient_of_variation"]
ratio = stats["max_to_min_ratio"]

# The sample standard deviation is undefined (null/NaN) when fewer than
# two partitions hold data, and every figure is null when there are no
# rows at all. Guard before formatting so the cell does not crash and does
# not report an undefined value as "low skew".
cv_known = cv is not None and not math.isnan(cv)
ratio_known = ratio is not None and not math.isnan(ratio)

if cv_known:
    print(f"Coefficient of variation: {cv:.2f}")
else:
    print("Coefficient of variation: n/a (fewer than two non-empty partitions)")

if ratio_known:
    print(f"Max/Min ratio: {ratio:.2f}")
else:
    print("Max/Min ratio: n/a (no non-empty partitions)")

# Interpret results; the printed ranges match the comparisons exactly
if not cv_known:
    print("Skew level could not be determined")
elif cv > 1.0:
    print("High skew detected (CV > 1.0)")
elif cv > 0.5:
    print("Moderate skew detected (0.5 < CV <= 1.0)")
else:
    print("Low skew detected (CV <= 0.5)")

if ratio_known and ratio > 10:
    print(f"Severe imbalance: busiest partition has {ratio:.1f}x more records than emptiest")