# In [None]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
import time
import random

# Size of the synthetic dataset generated in Step 1.
NUM_RECORDS = 1_000_000
NUM_USERS = 10_000  # UserIDs cycle through this many distinct values

# Bound up front so the ``finally`` clause can tell whether a session was
# actually created, instead of probing ``locals()``.
spark = None

try:
    spark = SparkSession.builder \
        .appName("Complex Spark Job") \
        .config("spark.sql.adaptive.enabled", "true") \
        .config("spark.sql.adaptive.coalescePartitions.enabled", "true") \
        .getOrCreate()

    print("Starting complex Spark job...")
    start_time = time.perf_counter()  # monotonic clock for elapsed time

    # NOTE: ``from pyspark.sql.functions import *`` shadows the builtins
    # ``min``, ``max``, ``sum`` and ``round``; every use of those names below
    # is the Spark column function, not the Python builtin.

    # Step 1: build the rows on the driver. Tuple fields are evaluated left to
    # right, so the random draws happen in the same order as an explicit loop.
    print("Step 1: Generating large dataset...")
    large_data = [
        (
            f"User_{i % NUM_USERS}",  # UserID
            random.randint(18, 80),  # Age (inclusive range)
            random.choice(['M', 'F']),  # Gender
            random.choice(['NY', 'CA', 'TX', 'FL', 'WA']),  # State
            random.uniform(1000, 10000),  # Salary
            random.choice(['Tech', 'Finance', 'Healthcare', 'Education', 'Retail'])  # Industry
        )
        for i in range(NUM_RECORDS)
    ]

    # Explicit schema: the same types inference would pick for these Python
    # values (str -> STRING, int -> BIGINT, float -> DOUBLE), without an
    # inference pass over every tuple.
    df = spark.createDataFrame(
        large_data,
        "UserID STRING, Age BIGINT, Gender STRING, State STRING, "
        "Salary DOUBLE, Industry STRING",
    )
    del large_data  # Spark has its own copy; release the driver-side list
    df.cache()  # Cache for multiple operations

    print(f"Created dataset with {df.count():,} records")

    # Step 2: bucket ages, then aggregate salary statistics per
    # (AgeGroup, State, Industry).
    print("Step 2: Performing complex aggregations...")

    df_with_age_groups = df.withColumn(
        "AgeGroup",
        when(col("Age") < 25, "18-24")
        .when(col("Age") < 35, "25-34")
        .when(col("Age") < 45, "35-44")
        .when(col("Age") < 55, "45-54")
        .otherwise("55+")
    )

    age_stats = df_with_age_groups.groupBy("AgeGroup", "State", "Industry") \
        .agg(
            count("*").alias("Count"),
            avg("Salary").alias("AvgSalary"),
            stddev("Salary").alias("SalaryStdDev"),
            min("Salary").alias("MinSalary"),
            max("Salary").alias("MaxSalary")
        ) \
        .orderBy("AgeGroup", "State", "Industry")

    print("Age group statistics by state and industry:")
    age_stats.show(50)

    # Step 3: join a per-industry benchmark salary, then rank within each
    # (State, Industry) partition.
    print("Step 3: Complex joins and window functions...")

    from pyspark.sql.window import Window

    salary_benchmarks = spark.createDataFrame([
        ("Tech", 8000), ("Finance", 7500), ("Healthcare", 6500),
        ("Education", 5000), ("Retail", 4500)
    ], ["Industry", "BenchmarkSalary"])

    df_with_benchmark = df.join(salary_benchmarks, "Industry")

    # Highest salary first within each partition.
    window_spec = Window.partitionBy("State", "Industry").orderBy(desc("Salary"))

    df_ranked = df_with_benchmark.withColumn(
        "SalaryRank",
        row_number().over(window_spec)
    ).withColumn(
        # Because the window is ordered descending, 0.0 is the top earner.
        "SalaryPercentile",
        percent_rank().over(window_spec)
    ).withColumn(
        # Percentage above (+) or below (-) the industry benchmark.
        "SalaryVsBenchmark",
        round((col("Salary") / col("BenchmarkSalary") - 1) * 100, 2)
    )

    # Ordered so the printed listing is grouped and reproducible.
    top_performers = df_ranked.filter(col("SalaryRank") <= 10) \
        .orderBy("State", "Industry", "SalaryRank")
    print("Top 10 salary performers by state and industry:")
    top_performers.select("State", "Industry", "UserID", "Salary", "SalaryVsBenchmark").show(100)

    # Step 4: single-pass statistics over the cached data.
    print("Step 4: Complex statistical operations...")

    df_numeric = df.select("Age", "Salary")
    correlation = df_numeric.stat.corr("Age", "Salary")
    print(f"Age-Salary Correlation: {correlation:.4f}")

    crosstab = df.stat.crosstab("Gender", "Industry")
    print("Gender vs Industry cross-tabulation:")
    crosstab.show()

    print("Step 5: Multiple data transformations...")

    df_enriched = df.withColumn("SalaryTier",
        when(col("Salary") < 3000, "Low")
        .when(col("Salary") < 7000, "Medium")
        .otherwise("High")
    ).withColumn("IsHighEarner", col("Salary") > 8000) \
    .withColumn("NormalizedAge", (col("Age") - 18) / (80 - 18)) \
    .withColumn("SalaryPerAge", col("Salary") / col("Age"))

    final_summary = df_enriched.groupBy("State") \
        .agg(
            count("*").alias("TotalUsers"),
            countDistinct("UserID").alias("UniqueUsers"),
            avg("Salary").alias("AvgSalary"),
            sum(when(col("IsHighEarner"), 1).otherwise(0)).alias("HighEarners"),
            # collect_list kept one entry per row (roughly NUM_RECORDS / 5 per
            # state), all of which show(truncate=False) printed; the distinct,
            # sorted industries keep the column small and readable.
            array_sort(collect_set("Industry")).alias("Industries")
        ) \
        .orderBy("State")

    print("Final state summary:")
    final_summary.show(truncate=False)

    # Step 6: force full evaluation with whole-dataset actions.
    print("Step 6: Final data processing...")
    total_records = df_enriched.count()
    high_earners = df_enriched.filter(col("IsHighEarner")).count()
    # Guard against an empty dataset (e.g. NUM_RECORDS set to 0).
    high_earner_percentage = (high_earners / total_records * 100) if total_records else 0.0

    processing_time = time.perf_counter() - start_time

    print("\nJob completed!")
    print(f"Total records processed: {total_records:,}")
    print(f"High earner percentage: {high_earner_percentage:.2f}%")
    print(f"Processing time: {processing_time:.2f} seconds")

    # Clean up
    df.unpersist()

except Exception as e:
    print(f"An error occurred: {e}")
    # Re-raise so the caller (spark-submit, notebook, scheduler) sees the
    # failure instead of a job that appears to have completed normally.
    raise
finally:
    # Only pause and stop when a session was actually created.
    if spark is not None:
        time.sleep(5)  # Brief pause before cleanup
        spark.stop()
        print("Spark session stopped")