# Interview Challenge 10: Custom UDFs, Joins & Data Skew Handling

## Problem Statement

You are tasked with building a complex analytics pipeline that requires custom business logic, multiple data source joins, and handling data skew issues. This challenge focuses on advanced PySpark operations that require deep understanding of the framework.

## Dataset Description

**Users Table:**
- `user_id` (string) - Unique user identifier
- `registration_date` (date) - When user registered
- `country` (string) - User country
- `subscription_tier` (string) - Free, Basic, Premium, Enterprise
- `last_login` (timestamp) - Last login timestamp

**Events Table:**
- `event_id` (string) - Unique event identifier
- `user_id` (string) - User who triggered event
- `event_type` (string) - login, purchase, feature_use, support_ticket
- `event_timestamp` (timestamp) - When event occurred
- `metadata` (string) - JSON string with additional data
- `value` (double) - Numeric value associated with event

**Products Table:**
- `product_id` (string) - Unique product identifier
- `name` (string) - Product name
- `category` (string) - Product category
- `price` (double) - Product price
- `created_date` (date) - When product was created
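
The `metadata` column in the Events table is a JSON string, so any field it carries has to be parsed defensively. One way to approach this is to keep the parsing logic in a plain Python function that never raises, which can be unit-tested without a Spark session and wrapped as a UDF later. The sketch below is illustrative only: the helper name `extract_quantity` and its fallback rules are assumptions, chosen to match the sample `metadata` values used in this challenge.

```python
import json
from typing import Optional


def extract_quantity(metadata: Optional[str]) -> Optional[int]:
    """Return the integer `quantity` field from a JSON metadata string.

    Returns None for null input, malformed JSON, JSON that is not an
    object, a missing field, or a value that is not a whole number.
    """
    if not metadata:
        return None
    try:
        data = json.loads(metadata)
    except (json.JSONDecodeError, TypeError):
        return None
    if not isinstance(data, dict):
        return None  # e.g. metadata was "[]" or "42"
    quantity = data.get("quantity")
    # bool is a subclass of int in Python, so exclude it explicitly.
    if isinstance(quantity, bool) or not isinstance(quantity, int):
        return None
    return quantity


print(extract_quantity('{"product_id": "prod001", "quantity": 2}'))  # 2
print(extract_quantity('{}'))        # None
print(extract_quantity('not json'))  # None
print(extract_quantity(None))        # None
```

To use such a function from both APIs, it could be wrapped with `udf(extract_quantity, IntegerType())` for DataFrame code and registered with `spark.udf.register("extract_quantity", extract_quantity, IntegerType())` for SQL. For simple field extraction, the built-in `get_json_object` or `from_json` functions usually avoid the serialization cost of a Python UDF.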

## Tasks

1. **Custom UDF Implementation**
   - Create UDF to parse JSON metadata and extract specific fields
   - Implement business logic UDF for user segmentation
   - Build custom aggregation UDF for complex metrics
   - Handle null values and edge cases in UDFs

2. **Complex Join Operations**
   - Perform multi-table joins (users → events → products)
   - Implement different join strategies (broadcast, sort-merge, shuffle-hash)
   - Handle join conditions with custom logic
   - Optimize join performance

3. **Data Skew Resolution**
   - Identify skewed data distributions
   - Implement salting techniques for skewed joins
   - Use custom partitioning to balance data
   - Apply broadcast joins for small tables

4. **Advanced Analytics**
   - Calculate user lifetime value with complex business rules
   - Implement cohort analysis
   - Build product recommendation logic
   - Generate predictive features

5. **Performance Optimization**
   - Implement proper caching strategies
   - Use appropriate data structures
   - Optimize memory usage
   - Monitor and tune performance
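
For the cohort analysis in Task 4, retention tables are usually keyed by the number of months between registration and activity rather than by two calendar months. A minimal plain-Python sketch of that offset follows. The helper names `cohort_offset` and `retention_table` are hypothetical, and the activity dates are toy values made up for illustration.

```python
from collections import defaultdict
from datetime import date


def cohort_offset(registration: date, activity: date) -> int:
    """Whole calendar months from the registration month to the activity month."""
    return (activity.year - registration.year) * 12 + (activity.month - registration.month)


def retention_table(rows):
    """rows: iterable of (user_id, registration_date, activity_date).

    Returns {(cohort_month, offset): number_of_distinct_users}.
    """
    users = defaultdict(set)
    for user_id, registration, activity in rows:
        cohort = registration.strftime("%Y-%m")
        users[(cohort, cohort_offset(registration, activity))].add(user_id)
    return {key: len(ids) for key, ids in users.items()}


# Toy rows: registration dates follow the sample users, activity dates are invented.
rows = [
    ("user001", date(2020, 1, 15), date(2020, 1, 20)),
    ("user001", date(2020, 1, 15), date(2020, 3, 2)),
    ("user005", date(2020, 1, 30), date(2020, 3, 9)),
    ("user002", date(2020, 3, 20), date(2020, 4, 1)),
]
for key, n in sorted(retention_table(rows).items()):
    print(key, n)
```

In PySpark the same offset can be expressed with `year()` and `month()` column functions, so this step does not need a UDF at all.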

## Technical Requirements
- Implement custom UDFs using both SQL and DataFrame API
- Demonstrate different join strategies and their use cases
- Handle data skew with practical solutions
- Optimize for large-scale data processing
- Include proper error handling and validation
- Document performance considerations
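
Handling skew in practice usually starts with measuring it from per-key row counts. As a rough sketch, the ratio of the hottest key's count to the median count gives a quick signal. The `skew_report` helper and the threshold of 5 are hypothetical choices for illustration, not a Spark convention.

```python
from collections import Counter
from statistics import median


def skew_report(keys, threshold=5.0):
    """Summarise key skew from an iterable of join-key values.

    Returns a dict with the hottest key, its row count, the median
    count per key, and whether max/median exceeds `threshold`.
    """
    counts = Counter(keys)
    if not counts:
        return {"hot_key": None, "max_count": 0, "median_count": 0,
                "ratio": 0.0, "skewed": False}
    hot_key, max_count = counts.most_common(1)[0]
    median_count = median(counts.values())  # always >= 1 for a non-empty Counter
    ratio = max_count / median_count
    return {
        "hot_key": hot_key,
        "max_count": max_count,
        "median_count": median_count,
        "ratio": ratio,
        "skewed": ratio > threshold,
    }


# Toy distribution: one user produces most of the events.
user_ids = ["user001"] * 90 + ["user002"] * 4 + ["user003"] * 3 + ["user004"] * 3
report = skew_report(user_ids)
print(report["hot_key"], report["max_count"], report["skewed"])
```

On a real DataFrame the counts would come from `groupBy('user_id').count()`, with the median taken from an approximate percentile rather than by collecting every key to the driver.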

## Key Concepts to Demonstrate
- UDF registration and usage
- Join optimization techniques
- Data skew detection and resolution
- Memory management
- Custom business logic implementation
- Performance monitoring
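
The resolution half of skew handling is often salting. The idea can be modelled in plain Python before moving to Spark: the skewed side gets a random salt per row, and the small side is replicated once per salt value so every salted key still finds its match. This is a sketch under stated assumptions, not the reference solution. `salted_join`, `NUM_SALTS`, and the toy rows are all hypothetical.

```python
import random
from collections import Counter, defaultdict

NUM_SALTS = 4  # hypothetical bucket count; tune to the observed skew


def salted_join(events, users, num_salts=NUM_SALTS, seed=0):
    """Inner-join events to users on user_id via salted keys.

    events: list of (user_id, event_id) rows (the large, skewed side)
    users:  list of (user_id, country) rows (the small side)
    Returns (joined_rows, rows_per_salted_key).
    """
    rng = random.Random(seed)

    # Large side: append a RANDOM salt so one hot key spreads over buckets.
    salted_events = [((uid, rng.randrange(num_salts)), eid) for uid, eid in events]

    # Small side: replicate each row once per salt value.
    salted_users = defaultdict(list)
    for uid, country in users:
        for salt in range(num_salts):
            salted_users[(uid, salt)].append(country)

    joined = [
        (key[0], eid, country)
        for key, eid in salted_events
        for country in salted_users.get(key, [])
    ]
    load = Counter(key for key, _ in salted_events)
    return joined, load


events = [("user001", f"evt{i:03d}") for i in range(1000)] + [("user002", "evt_x")]
users = [("user001", "US"), ("user002", "UK")]

joined, load = salted_join(events, users)
print(len(joined))         # same row count as an unsalted inner join
print(max(load.values()))  # largest salted bucket, far fewer rows than the hot key
```

The salt must be random (or otherwise vary within a key). A salt computed from the key itself would place every row of a hot key in the same bucket and leave the skew untouched. With adaptive query execution and `spark.sql.adaptive.skewJoin.enabled` set, Spark 3 may already split skewed partitions in sort-merge joins, so manual salting is worth measuring before adopting.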

In [None]:
# Import required libraries
from pyspark.sql import SparkSession
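# NOTE: the wildcard import shadows Python built-ins such as sum, max, and hash
# with their Spark column equivalents; the code below relies on that.
from pyspark.sql.functions import *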
from pyspark.sql.types import *
from pyspark.sql import Window
import json
import hashlib
from datetime import datetime, timedelta

# Create Spark session with optimizations
spark = SparkSession.builder \
    .appName("AdvancedUDFsJoinsChallenge") \
    .config("spark.sql.adaptive.enabled", "true") \
    .config("spark.sql.adaptive.coalescePartitions.enabled", "true") \
    .config("spark.sql.adaptive.skewJoin.enabled", "true") \
    .config("spark.sql.autoBroadcastJoinThreshold", "50MB") \
    .getOrCreate()

# Sample data
users_data = [
    ('user001', '2020-01-15', 'US', 'Premium', '2023-12-15 10:30:00'),
    ('user002', '2020-03-20', 'UK', 'Basic', '2023-12-14 15:45:00'),
    ('user003', '2020-02-10', 'US', 'Enterprise', '2023-12-15 09:15:00'),
    ('user004', '2020-04-05', 'DE', 'Free', '2023-12-13 14:20:00'),
    ('user005', '2020-01-30', 'US', 'Premium', '2023-12-15 11:00:00')
]

events_data = [
    ('evt001', 'user001', 'purchase', '2023-12-15 10:35:00', '{"product_id": "prod001", "quantity": 2}', 299.98),
    ('evt002', 'user001', 'login', '2023-12-15 10:30:00', '{}', 0.0),
    ('evt003', 'user002', 'feature_use', '2023-12-14 15:50:00', '{"feature": "analytics", "duration": 25}', 0.0),
    ('evt004', 'user003', 'purchase', '2023-12-15 09:20:00', '{"product_id": "prod002", "quantity": 1}', 149.99),
    ('evt005', 'user001', 'support_ticket', '2023-12-15 11:15:00', '{"category": "billing", "priority": "high"}', 0.0),
    ('evt006', 'user004', 'login', '2023-12-13 14:20:00', '{}', 0.0),
    ('evt007', 'user005', 'purchase', '2023-12-15 11:05:00', '{"product_id": "prod003", "quantity": 3}', 449.97)
]

products_data = [
    ('prod001', 'Laptop Pro', 'Electronics', 149.99, '2020-01-01'),
    ('prod002', 'Wireless Headphones', 'Electronics', 149.99, '2020-02-15'),
    ('prod003', 'Office Chair', 'Furniture', 149.99, '2020-03-10')
]

# Define schemas
users_schema = StructType([
    StructField('user_id', StringType(), True),
    StructField('registration_date', StringType(), True),
    StructField('country', StringType(), True),
    StructField('subscription_tier', StringType(), True),
    StructField('last_login', StringType(), True)
])

events_schema = StructType([
    StructField('event_id', StringType(), True),
    StructField('user_id', StringType(), True),
    StructField('event_type', StringType(), True),
    StructField('event_timestamp', StringType(), True),
    StructField('metadata', StringType(), True),
    StructField('value', DoubleType(), True)
])

products_schema = StructType([
    StructField('product_id', StringType(), True),
    StructField('name', StringType(), True),
    StructField('category', StringType(), True),
    StructField('price', DoubleType(), True),
    StructField('created_date', StringType(), True)
])

# Create DataFrames
users_df = spark.createDataFrame(users_data, users_schema)
events_df = spark.createDataFrame(events_data, events_schema)
products_df = spark.createDataFrame(products_data, products_schema)

# Convert date strings to proper types
users_df = users_df \
    .withColumn('registration_date', to_date('registration_date')) \
    .withColumn('last_login', to_timestamp('last_login'))

events_df = events_df.withColumn('event_timestamp', to_timestamp('event_timestamp'))
products_df = products_df.withColumn('created_date', to_date('created_date'))

print("Data loaded successfully!")

# === YOUR SOLUTION GOES HERE ===
# Implement advanced UDFs, joins, and skew handling

# Task 1: Custom UDFs
print("\n1. Implementing Custom UDFs:")

# UDF to parse JSON metadata and extract product_id
def extract_product_id(metadata):
    """Extract product_id from JSON metadata"""
    if not metadata or metadata == '{}':
        return None
    try:
        data = json.loads(metadata)
        return data.get('product_id')
    except (json.JSONDecodeError, TypeError, AttributeError):
        # AttributeError covers valid JSON that is not an object, e.g. '[]'
        return None

# Register UDF
extract_product_udf = udf(extract_product_id, StringType())
spark.udf.register("extract_product_id", extract_product_id, StringType())

# UDF for user segmentation based on activity
def segment_user(activity_score, subscription_tier):
    """Segment users based on activity and subscription"""
    if subscription_tier == 'Enterprise':
        return 'Enterprise'
    # A null score arrives as None; comparing None with a number raises TypeError
    if activity_score is None:
        return 'Inactive'
    if activity_score > 100:
        return 'Power_User'
    elif activity_score > 50:
        return 'Regular'
    elif activity_score > 10:
        return 'Occasional'
    else:
        return 'Inactive'

segment_udf = udf(segment_user, StringType())

# Apply the metadata UDF to add a product_id column to events
enhanced_events = events_df.withColumn(
    'product_id', extract_product_udf('metadata')
)

print("Enhanced events with extracted product IDs:")
enhanced_events.select('event_id', 'user_id', 'event_type', 'product_id', 'value').show()

# Task 2: Complex Joins
print("\n2. Performing Complex Joins:")

# Multi-table join: users -> events -> products
# First, check data sizes for join strategy
users_count = users_df.count()
events_count = enhanced_events.count()
products_count = products_df.count()

print(f"Data sizes - Users: {users_count}, Events: {events_count}, Products: {products_count}")

# Join users with events (likely large tables)
user_events = users_df.join(enhanced_events, 'user_id', 'left')
print(f"User-events join result: {user_events.count()} rows")

# Join with products (products table is small, use broadcast)
complete_data = user_events.join(
    broadcast(products_df),
    user_events.product_id == products_df.product_id,
    'left'
).drop(products_df.product_id)

print(f"Complete joined data: {complete_data.count()} rows")
complete_data.select('user_id', 'country', 'event_type', 'name', 'category', 'value').show(10)

# Task 3: Data Skew Handling
print("\n3. Handling Data Skew:")

# Check for skew in user events
user_event_counts = enhanced_events.groupBy('user_id').count().orderBy(desc('count'))
print("Top users by event count (checking for skew):")
user_event_counts.show(5)

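# Implement salting for skewed joins
# Note: a salt derived from hash(user_id) is identical for every row of a hot key,
# so it cannot spread that key. Salt the skewed side randomly and replicate the other.
def add_random_salt(df, num_buckets=10):
    """Skewed (large) side: give each row a random salt in [0, num_buckets)"""
    return df.withColumn('salt', floor(rand() * num_buckets).cast('int'))

def replicate_salts(df, num_buckets=10):
    """Small side: one copy of each row per salt value, so every salted key matches"""
    return df.withColumn('salt', explode(array(*[lit(i) for i in range(num_buckets)])))

# If we had severe skew, we would use (inner join, events being the skewed side):
# salted_events = add_random_salt(enhanced_events)
# salted_users = replicate_salts(users_df)
# balanced_join = salted_events.join(salted_users, ['user_id', 'salt']).drop('salt')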

# Task 4: Advanced Analytics with Custom Logic
print("\n4. Advanced Analytics:")

# Calculate user activity scores
user_activity = enhanced_events.groupBy('user_id').agg(
    count('event_id').alias('total_events'),
    sum(when(col('event_type') == 'purchase', col('value')).otherwise(0)).alias('total_spent'),
    countDistinct(when(col('event_type') == 'login', col('event_timestamp').cast('date'))).alias('active_days'),
    max('event_timestamp').alias('last_activity')
)

# Calculate activity score (custom business logic)
user_activity = user_activity.withColumn(
    'activity_score',
    col('total_events') * 10 + 
    col('total_spent') * 0.1 + 
    col('active_days') * 5
)

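# Apply user segmentation using each user's actual subscription tier
user_segments = user_activity.join(
    users_df.select('user_id', 'subscription_tier'), 'user_id', 'left'
).withColumn(
    'segment', segment_udf('activity_score', 'subscription_tier')
).drop('subscription_tier')  # dropped again so the cohort join can re-add it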

print("User segmentation results:")
user_segments.select('user_id', 'total_events', 'total_spent', 'activity_score', 'segment').show()

# Task 5: Cohort Analysis
print("\n5. Cohort Analysis:")

# Join with user registration data
cohort_data = user_segments.join(
    users_df.select('user_id', 'registration_date', 'subscription_tier'),
    'user_id'
).withColumn(
    'cohort_month', date_format('registration_date', 'yyyy-MM')
).withColumn(
    'activity_month', date_format('last_activity', 'yyyy-MM')
)

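# Count active users per (registration cohort, last-activity month) pair.
# This is a rough retention proxy: each user contributes only their latest activity month.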
cohort_retention = cohort_data.groupBy('cohort_month', 'activity_month').agg(
    countDistinct('user_id').alias('active_users'),
    avg('activity_score').alias('avg_activity_score')
).orderBy('cohort_month', 'activity_month')

print("Cohort analysis results:")
cohort_retention.show(10)

# Task 6: Performance Optimization
print("\n6. Performance Optimizations:")

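# Cache frequently used DataFrames (in a real pipeline, cache before the first reuse;
# caching here, after the heavy work, only helps the summary counts below)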
users_df.cache()
enhanced_events.cache()

# Force cache materialization
users_cached = users_df.count()
events_cached = enhanced_events.count()

print(f"Cached {users_cached} users and {events_cached} events")

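# Use appropriate partitioning for large datasets
# In production, you would repartition based on expected query patterns.
# complete_data has no event_date column, so derive it first:
# optimized_data = complete_data \
#     .withColumn('event_date', to_date('event_timestamp')) \
#     .repartition('country', 'event_date')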

# Final results summary
print("\n=== FINAL RESULTS ===")
print(f"Total users processed: {users_df.count()}")
print(f"Total events processed: {enhanced_events.count()}")
print(f"Total products: {products_df.count()}")
print(f"Joined records: {complete_data.count()}")

# Show sample of final enriched data
print("\nSample enriched user events:")
complete_data.select(
    'user_id', 'country', 'subscription_tier', 
    'event_type', 'name', 'category', 'value', 'event_timestamp'
).orderBy('user_id', 'event_timestamp').show(10)

# Cleanup
spark.catalog.clearCache()
print("\nCache cleared, processing complete!")

print("\n✅ Advanced UDFs and Joins Challenge completed!")
print("Key Learnings:")
print("- UDFs enable custom business logic in distributed processing")
print("- Join strategies (broadcast vs sort-merge) depend on data sizes")
print("- Salting helps handle data skew in distributed joins")
print("- Caching optimizes repeated access to intermediate results")
print("- Custom partitioning can improve query performance when it matches join and filter keys")
