In [0]:
from pyspark.sql.functions import *
from pyspark.sql.window import Window
import time
from pyspark.sql import SparkSession
import os

spark = SparkSession.builder.appName("MusicListeningDataPipeline").getOrCreate()


print("✅ Spark session created successfully.")

# Initialize Spark Session with optimized configurations
spark = (
    SparkSession.builder.appName("COVID-19 Data Pipeline")
    .config("spark.sql.adaptive.enabled", "true")
    .config("spark.sql.adaptive.coalescePartitions.enabled", "true")
    .getOrCreate()
)

print("=" * 80)
print("COVID-19 DATA PROCESSING PIPELINE")
print("=" * 80)

✅ Spark session created successfully.
COVID-19 DATA PROCESSING PIPELINE


In [0]:
# ============================================================================
# LOAD DATA WITH DEFAULT PARTITIONING
# ============================================================================

# Start timing
start_time = time.time()

# Read the dataset
covid_df = spark.read.csv(
    "/databricks-datasets/COVID/covid-19-data/", header=True, inferSchema=True
)

# Trigger action to actually load the data (transformations are lazy!)
row_count = covid_df.count()

# Calculate elapsed time
load_time = time.time() - start_time

# Check the data
print(f"Total rows: {covid_df.count():,}")
print(f"⏱️  Time to load data: {load_time:.5f} seconds")
covid_df.printSchema()
covid_df.show(5)

Total rows: 1,227,256
⏱️  Time to load data: 1.78491 seconds
root
 |-- date: string (nullable = true)
 |-- county: string (nullable = true)
 |-- state: string (nullable = true)
 |-- fips: string (nullable = true)
 |-- cases: string (nullable = true)
 |-- deaths: string (nullable = true)

+----------+---------+----------+-----+-----+------+
|      date|   county|     state| fips|cases|deaths|
+----------+---------+----------+-----+-----+------+
|2020-01-21|Snohomish|Washington|53061|    1|     0|
|2020-01-22|Snohomish|Washington|53061|    1|     0|
|2020-01-23|Snohomish|Washington|53061|    1|     0|
|2020-01-24|     Cook|  Illinois|17031|    1|     0|
|2020-01-24|Snohomish|Washington|53061|    1|     0|
+----------+---------+----------+-----+-----+------+
only showing top 5 rows


In [0]:
# ============================================================================
# LOAD DATA WITH SPECIFIED NUMBER OF PARTITIONS
# ============================================================================
# Start timing
start_time = time.time()

# Read the dataset with specified number of partitions
covid_df = spark.read.csv(
    "/databricks-datasets/COVID/covid-19-data/", header=True, inferSchema=True
).repartition(
    8
)  # Define number of partitions (adjust based on cluster size)

# Trigger action to actually load the data (transformations are lazy!)
row_count = covid_df.count()

# Calculate elapsed time
load_time = time.time() - start_time

# Check the data
print(f"Total rows: {row_count:,}")
# print(f"Number of partitions: {covid_df.rdd.getNumPartitions()}")
print(f"⏱️  Time to load data: {load_time:.5f} seconds")
covid_df.printSchema()
covid_df.show(5)

Total rows: 1,227,256
⏱️  Time to load data: 2.20983 seconds
root
 |-- date: string (nullable = true)
 |-- county: string (nullable = true)
 |-- state: string (nullable = true)
 |-- fips: string (nullable = true)
 |-- cases: string (nullable = true)
 |-- deaths: string (nullable = true)

+----------+-----------+----------+-----+-----+------+
|      date|     county|     state| fips|cases|deaths|
+----------+-----------+----------+-----+-----+------+
|2020-01-25|     Orange|California|06059|    1|     0|
|2020-01-27|Los Angeles|California|06037|    1|     0|
|2020-01-28|  Snohomish|Washington|53061|    1|     0|
|2020-01-30|   Maricopa|   Arizona|04013|    1|     0|
|2020-01-31|  Snohomish|Washington|53061|    1|     0|
+----------+-----------+----------+-----+-----+------+
only showing top 5 rows


In [0]:
# Checkthe size of the dataset
def get_directory_size(path):
    total_size = 0
    try:
        files = dbutils.fs.ls(path)
        for file in files:
            if file.isDir():
                total_size += get_directory_size(file.path)
            else:
                total_size += file.size
    except Exception as e:
        print(f"Error: {e}")
    return total_size


path = "/databricks-datasets/COVID/covid-19-data/"
size_bytes = get_directory_size(path)
size_mb = size_bytes / (1024 * 1024)
size_gb = size_bytes / (1024 * 1024 * 1024)
print(f"Total size: {size_bytes:,} bytes")
print(f"Total size: {size_mb:.2f} MB")
print(f"Total size: {size_gb:.2f} GB")

Total size: 2,567,706,254 bytes
Total size: 2448.76 MB
Total size: 2.39 GB


In [0]:
# ============================================================================
# CREATE REFERENCE DATA FOR JOIN
# ============================================================================
print("\n" + "=" * 80)
print("CREATING REFERENCE DATA")
print("=" * 80)

# Create a second dataset for join operation (simulate external data)
# This represents state population data
state_population = spark.createDataFrame(
    [
        ("California", 39538223, "West"),
        ("New York", 20201249, "Northeast"),
        ("Texas", 29145505, "South"),
        ("Florida", 21538187, "South"),
        ("Washington", 7705281, "West"),
        ("Illinois", 12812508, "Midwest"),
    ],
    ["state", "population", "region"],
)

print(f"State population reference data loaded: {state_population.count()} states")
state_population.show()


CREATING REFERENCE DATA
State population reference data loaded: 6 states
+----------+----------+---------+
|     state|population|   region|
+----------+----------+---------+
|California|  39538223|     West|
|  New York|  20201249|Northeast|
|     Texas|  29145505|    South|
|   Florida|  21538187|    South|
|Washington|   7705281|     West|
|  Illinois|  12812508|  Midwest|
+----------+----------+---------+



In [0]:
# ============================================================================
# VERSION 1: EAGER/INEFFICIENT APPROACH ❌
# ============================================================================
print("\n" + "=" * 80)
print("VERSION 1: EAGER/INEFFICIENT APPROACH (Bad Practices)")
print("=" * 80)
print("Problems:")
print("  ❌ GroupBy BEFORE filtering (processes all 1.2M rows)")
print("  ❌ Join on full dataset BEFORE filtering")
print("  ❌ Column transformations on large dataset")
print("  ❌ Multiple shuffles due to poor ordering")
print("  ❌ No early data reduction")

eager_start = time.time()

# Step 1: GroupBy FIRST (BAD - processes all data)
print("\n[Step 1] Performing groupBy on ENTIRE dataset...")
step1_start = time.time()
eager_grouped = covid_df.groupBy("state", "county", "date").agg(
    sum("cases").alias("total_cases"), sum("deaths").alias("total_deaths")
)
step1_time = time.time() - step1_start
print(f"  Time: {step1_time:.2f}s | Rows: {eager_grouped.count():,}")

# Step 2: Add column transformations BEFORE filtering (BAD - transforms all rows)
print("\n[Step 2] Adding column transformations on full dataset...")
step2_start = time.time()
eager_transformed = (
    eager_grouped.withColumn(
        "mortality_rate",
        when(
            col("total_cases") > 0, (col("total_deaths") / col("total_cases")) * 100
        ).otherwise(0),
    )
    .withColumn("year", year(col("date")))
    .withColumn("month", month(col("date")))
    .withColumn(
        "case_category",
        when(col("total_cases") < 100, "Low")
        .when(col("total_cases") < 1000, "Medium")
        .otherwise("High"),
    )
)
step2_time = time.time() - step2_start
print(f"  Time: {step2_time:.2f}s")

# Step 3: Join BEFORE filtering (BAD - joins large dataset)
print("\n[Step 3] Joining with population data on full dataset...")
step3_start = time.time()
eager_joined = eager_transformed.join(state_population, on="state", how="left")
step3_time = time.time() - step3_start
print(f"  Time: {step3_time:.2f}s | Rows: {eager_joined.count():,}")

# Step 4: Filter LAST (BAD - after all expensive operations)
print("\n[Step 4] Finally filtering (after all expensive operations)...")
step4_start = time.time()
eager_filtered = eager_joined.filter(
    (col("year") == 2020)
    & (col("state").isin("California", "New York", "Texas", "Florida", "Washington"))
)
step4_time = time.time() - step4_start
print(f"  Time: {step4_time:.2f}s | Rows: {eager_filtered.count():,}")

# Step 5: Another filter (BAD - separate operation causes another pass)
print("\n[Step 5] Second filter (separate operation)...")
step5_start = time.time()
eager_final = eager_filtered.filter(col("total_cases") >= 10)
step5_time = time.time() - step5_start
print(f"  Time: {step5_time:.2f}s | Rows: {eager_final.count():,}")

# Step 6: Final aggregation with another groupBy (causes another shuffle)
print("\n[Step 6] Final aggregation (another shuffle)...")
step6_start = time.time()
eager_result = (
    eager_final.groupBy("state", "region", "year", "month")
    .agg(
        sum("total_cases").alias("monthly_cases"),
        sum("total_deaths").alias("monthly_deaths"),
        avg("mortality_rate").alias("avg_mortality_rate"),
        count("*").alias("counties_reported"),
    )
    .orderBy("state", "year", "month")
)

# Trigger execution
eager_count = eager_result.count()
step6_time = time.time() - step6_start
print(f"  Time: {step6_time:.2f}s | Final rows: {eager_count:,}")

eager_total = time.time() - eager_start

print("\n" + "-" * 80)
print("EAGER APPROACH SUMMARY:")
print(f"  Total execution time: {eager_total:.2f} seconds")
print(f"  Number of shuffles: ~4-5 (groupBy, join, filters, final groupBy)")
print(f"  Data processed: Full dataset multiple times")
print("-" * 80)

# Show sample results
print("\nSample results (Eager approach):")
eager_result.show(10, truncate=False)

# Show execution plan
print("\nEXECUTION PLAN (Eager - Notice multiple stages and shuffles):")
eager_result.explain(mode="simple")


VERSION 1: EAGER/INEFFICIENT APPROACH (Bad Practices)
Problems:
  ❌ GroupBy BEFORE filtering (processes all 1.2M rows)
  ❌ Join on full dataset BEFORE filtering
  ❌ Column transformations on large dataset
  ❌ Multiple shuffles due to poor ordering
  ❌ No early data reduction

[Step 1] Performing groupBy on ENTIRE dataset...
  Time: 0.00s | Rows: 1,133,113

[Step 2] Adding column transformations on full dataset...
  Time: 0.00s

[Step 3] Joining with population data on full dataset...
  Time: 0.00s | Rows: 1,133,113

[Step 4] Finally filtering (after all expensive operations)...
  Time: 0.00s | Rows: 131,570

[Step 5] Second filter (separate operation)...
  Time: 0.00s | Rows: 115,697

[Step 6] Final aggregation (another shuffle)...
  Time: 0.81s | Final rows: 52

--------------------------------------------------------------------------------
EAGER APPROACH SUMMARY:
  Total execution time: 4.29 seconds
  Number of shuffles: ~4-5 (groupBy, join, filters, final groupBy)
  Data processed

In [0]:
# ============================================================================
# VERSION 2: LAZY/OPTIMIZED APPROACH ✅
# ============================================================================
print("\n" + "=" * 80)
print("VERSION 2: LAZY/OPTIMIZED APPROACH (Best Practices)")
print("=" * 80)
print("Optimizations:")
print("  ✅ Filter EARLY (reduce data volume immediately)")
print("  ✅ Combine filters together (single pass)")
print("  ✅ Join AFTER filtering (smaller dataset)")
print("  ✅ Minimize shuffles through intelligent ordering")
print("  ✅ Use appropriate partitioning")

lazy_start = time.time()

# Step 1: Filter FIRST (GOOD - reduce data early)
print("\n[Step 1] Filtering data FIRST (early reduction)...")
step1_start = time.time()
lazy_filtered = covid_df.filter(
    (year(col("date")) == 2020)
    & (col("state").isin("California", "New York", "Texas", "Florida", "Washington"))
    & (col("cases") >= 10)  # Combine multiple filters
)
# Repartition after filtering for optimal parallelism
lazy_filtered = lazy_filtered.repartition(4, "state")
step1_time = time.time() - step1_start
print(f"  Time: {step1_time:.2f}s | Rows after filter: {lazy_filtered.count():,}")

# Step 2: Column transformations on FILTERED data (GOOD - fewer rows)
print("\n[Step 2] Adding column transformations on filtered data...")
step2_start = time.time()
lazy_transformed = (
    lazy_filtered.withColumn("cases_int", col("cases").cast("integer"))
    .withColumn("deaths_int", col("deaths").cast("integer"))
    .withColumn(
        "mortality_rate",
        when(
            col("cases_int") > 0, (col("deaths_int") / col("cases_int")) * 100
        ).otherwise(0),
    )
    .withColumn("year", year(col("date")))
    .withColumn("month", month(col("date")))
    .withColumn(
        "case_category",
        when(col("cases_int") < 100, "Low")
        .when(col("cases_int") < 1000, "Medium")
        .otherwise("High"),
    )
)
step2_time = time.time() - step2_start
print(f"  Time: {step2_time:.2f}s")

# Step 3: GroupBy on FILTERED data (GOOD - much less data to shuffle)
print("\n[Step 3] GroupBy on filtered dataset...")
step3_start = time.time()
lazy_grouped = lazy_transformed.groupBy("state", "county", "date", "year", "month").agg(
    sum("cases_int").alias("total_cases"),
    sum("deaths_int").alias("total_deaths"),
    avg("mortality_rate").alias("avg_mortality_rate"),
    first("case_category").alias("case_category"),
)
step3_time = time.time() - step3_start
print(f"  Time: {step3_time:.2f}s | Rows: {lazy_grouped.count():,}")

# Step 4: Join on SMALL dataset (GOOD - broadcast join possible)
print("\n[Step 4] Joining with population data (small dataset)...")
step4_start = time.time()
# Broadcast the small dimension table
lazy_joined = lazy_grouped.join(broadcast(state_population), on="state", how="left")
step4_time = time.time() - step4_start
print(f"  Time: {step4_time:.2f}s")

# Step 5: Final aggregation (GOOD - single optimized operation)
print("\n[Step 5] Final aggregation...")
step5_start = time.time()
lazy_result = (
    lazy_joined.groupBy("state", "region", "year", "month", "population")
    .agg(
        sum("total_cases").alias("monthly_cases"),
        sum("total_deaths").alias("monthly_deaths"),
        avg("avg_mortality_rate").alias("avg_mortality_rate"),
        count("*").alias("counties_reported"),
    )
    .withColumn("cases_per_100k", (col("monthly_cases") / col("population")) * 100000)
    .orderBy("state", "year", "month")
)

# Trigger execution
lazy_count = lazy_result.count()
step5_time = time.time() - step5_start
print(f"  Time: {step5_time:.2f}s | Final rows: {lazy_count:,}")

lazy_total = time.time() - lazy_start

print("\n" + "-" * 80)
print("LAZY/OPTIMIZED APPROACH SUMMARY:")
print(f"  Total execution time: {lazy_total:.2f} seconds")
print(f"  Number of shuffles: ~2 (one groupBy, one final aggregation)")
print(f"  Data processed: Filtered dataset only")
print("-" * 80)

# Show sample results
print("\nSample results (Optimized approach):")
lazy_result.show(10, truncate=False)

# Show execution plan
print("\nEXECUTION PLAN (Optimized - Notice fewer stages):")
lazy_result.explain(mode="simple")


VERSION 2: LAZY/OPTIMIZED APPROACH (Best Practices)
Optimizations:
  ✅ Filter EARLY (reduce data volume immediately)
  ✅ Combine filters together (single pass)
  ✅ Join AFTER filtering (smaller dataset)
  ✅ Minimize shuffles through intelligent ordering
  ✅ Use appropriate partitioning

[Step 1] Filtering data FIRST (early reduction)...
  Time: 0.00s | Rows after filter: 115,697

[Step 2] Adding column transformations on filtered data...
  Time: 0.00s

[Step 3] GroupBy on filtered dataset...
  Time: 0.00s | Rows: 115,697

[Step 4] Joining with population data (small dataset)...
  Time: 0.00s

[Step 5] Final aggregation...
  Time: 0.76s | Final rows: 52

--------------------------------------------------------------------------------
LAZY/OPTIMIZED APPROACH SUMMARY:
  Total execution time: 2.26 seconds
  Number of shuffles: ~2 (one groupBy, one final aggregation)
  Data processed: Filtered dataset only
--------------------------------------------------------------------------------

Sa

In [0]:
# ============================================================================
# PERFORMANCE COMPARISON
# ============================================================================
print("\n" + "=" * 80)
print("PERFORMANCE COMPARISON")
print("=" * 80)

speedup = eager_total / lazy_total
time_saved = eager_total - lazy_total
percent_improvement = ((eager_total - lazy_total) / eager_total) * 100

print(f"Eager/Inefficient Approach: {eager_total:6.2f} seconds")
print(f"Lazy/Optimized Approach: {lazy_total:6.2f} seconds")
print(f"Time Saved: {time_saved:6.2f} seconds")
print(f"Speedup: {speedup:6.2f}x faster")
print(f"Performance Improvement: {percent_improvement:6.1f}%")


PERFORMANCE COMPARISON
Eager/Inefficient Approach:   4.29 seconds
Lazy/Optimized Approach:   2.26 seconds
Time Saved:   2.03 seconds
Speedup:   1.90x faster
Performance Improvement:   47.4%


In [0]:
# SQL QUERYING
# Make sure we have data in the correct format (extra check)
# First, cast the string columns to integers in the original dataframe
covid_df = covid_df.withColumn("cases", col("cases").cast("integer")).withColumn(
    "deaths", col("deaths").cast("integer")
)

# Cast columns and add transformations
enriched_df = (
    covid_df.withColumn(
        "mortality_rate",
        when(col("cases") > 0, (col("deaths") / col("cases")) * 100).otherwise(0),
    )
    .withColumn("month", month(col("date")))
    .withColumn("year", year(col("date")))
)

# Create state_stats view with properly typed columns
state_stats = enriched_df.groupBy("state", "county").agg(
    sum("cases").alias("total_cases"),
    sum("deaths").alias("total_deaths"),
    avg("mortality_rate").alias("avg_mortality_rate"),
    count("*").alias("days_reported"),
)

In [0]:
# ============================================================================
# SQL QUERY 1: TOP 10 COUNTIES WITH HIGHEST MORTALITY RATES
# ============================================================================

print("\n" + "=" * 80)
print("SQL QUERY 1: TOP 10 COUNTIES WITH HIGHEST MORTALITY RATES")
print("=" * 80)

# ----------------------------------------------------------------------------
# VERSION 1A: EAGER/INEFFICIENT SQL QUERY ❌
# ----------------------------------------------------------------------------
print("\n" + "-" * 80)
print("VERSION 1A: EAGER/INEFFICIENT APPROACH")
print("-" * 80)
print("Problems:")
print("  ❌ Computes statistics on ALL counties first")
print("  ❌ Filters AFTER expensive aggregation")
print("  ❌ No predicate pushdown")
print("  ❌ Processes unnecessary data")

eager_q1_start = time.time()

# Register view
state_stats.createOrReplaceTempView("state_stats_eager")

# INEFFICIENT: No filtering before aggregation, sorting all data
sql_eager_q1 = """
    SELECT 
        state,
        county,
        total_cases,
        total_deaths,
        ROUND(avg_mortality_rate, 2) as mortality_rate_pct,
        days_reported
    FROM (
        SELECT 
            state,
            county,
            total_cases,
            total_deaths,
            avg_mortality_rate,
            days_reported,
            ROW_NUMBER() OVER (ORDER BY avg_mortality_rate DESC) as rank
        FROM state_stats_eager
    ) ranked
    WHERE total_cases >= 100 AND rank <= 10
"""

print("\n[Executing inefficient query...]")
eager_result_q1 = spark.sql(sql_eager_q1)
eager_q1_time = time.time() - eager_q1_start

print(f"⏱️  Eager Query 1 Time: {eager_q1_time:.5f} seconds")

print("\nExecution Plan (Eager - notice full table scan and window function):")
eager_result_q1.explain(mode="simple")


SQL QUERY 1: TOP 10 COUNTIES WITH HIGHEST MORTALITY RATES

--------------------------------------------------------------------------------
VERSION 1A: EAGER/INEFFICIENT APPROACH
--------------------------------------------------------------------------------
Problems:
  ❌ Computes statistics on ALL counties first
  ❌ Filters AFTER expensive aggregation
  ❌ No predicate pushdown
  ❌ Processes unnecessary data

[Executing inefficient query...]
⏱️  Eager Query 1 Time: 0.16685 seconds

Execution Plan (Eager - notice full table scan and window function):
== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- == Initial Plan ==
   ColumnarToRow
   +- PhotonResultStage
      +- PhotonProject [state#14466, county#14465, total_cases#16491L, total_deaths#16492L, round(avg_mortality_rate#16493, 2) AS mortality_rate_pct#16894, days_reported#16494L]
         +- PhotonFilter ((isnotnull(total_cases#16491L) AND (rank#16893 <= 10)) AND (total_cases#16491L >= 100))
            +- PhotonWindow [st

In [0]:
# ----------------------------------------------------------------------------
# VERSION 1B: LAZY/OPTIMIZED SQL QUERY ✅
# ----------------------------------------------------------------------------
print("\n" + "-" * 80)
print("VERSION 1B: LAZY/OPTIMIZED APPROACH")
print("-" * 80)
print("Optimizations:")
print("  ✅ Filters applied early (total_cases >= 100)")
print("  ✅ Uses LIMIT instead of window function")
print("  ✅ Predicate pushdown reduces data scanned")
print("  ✅ Simplified query plan")

lazy_q1_start = time.time()

# Register view
state_stats.createOrReplaceTempView("state_stats_lazy")

# OPTIMIZED: Filter early, simple ORDER BY with LIMIT
sql_lazy_q1 = """
    SELECT 
        state,
        county,
        total_cases,
        total_deaths,
        ROUND(avg_mortality_rate, 2) as mortality_rate_pct,
        days_reported
    FROM state_stats_lazy
    WHERE total_cases >= 100
    ORDER BY avg_mortality_rate DESC
    LIMIT 10
"""

print("\n[Executing optimized query...]")
lazy_result_q1 = spark.sql(sql_lazy_q1)
lazy_q1_time = time.time() - lazy_q1_start

print(f"⏱️  Lazy Query 1 Time: {lazy_q1_time:.5f} seconds")

print("\nExecution Plan (Optimized - notice simpler plan):")
lazy_result_q1.explain(mode="simple")


--------------------------------------------------------------------------------
VERSION 1B: LAZY/OPTIMIZED APPROACH
--------------------------------------------------------------------------------
Optimizations:
  ✅ Filters applied early (total_cases >= 100)
  ✅ Uses LIMIT instead of window function
  ✅ Predicate pushdown reduces data scanned
  ✅ Simplified query plan

[Executing optimized query...]
⏱️  Lazy Query 1 Time: 0.15117 seconds

Execution Plan (Optimized - notice simpler plan):
== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- == Initial Plan ==
   ColumnarToRow
   +- PhotonResultStage
      +- PhotonProject [state#14466, county#14465, total_cases#16491L, total_deaths#16492L, mortality_rate_pct#17136, days_reported#16494L]
         +- PhotonTopK(sortOrder=[avg_mortality_rate#16493 DESC NULLS LAST], partitionOrderCount=0)
            +- PhotonShuffleExchangeSource
               +- PhotonShuffleMapStage ENSURE_REQUIREMENTS, [id=#25223]
                  +- PhotonShu

In [0]:
# ============================================================================
# PERFORMANCE COMPARISON
# ============================================================================
print("\n" + "=" * 80)
print("PERFORMANCE COMPARISON")
print("=" * 80)

speedup_q1 = eager_q1_time / lazy_q1_time
time_saved_q1 = eager_q1_time - lazy_q1_time
percent_improvement_q1 = ((eager_q1_time - lazy_q1_time) / eager_q1_time) * 100

print(f"Eager/Inefficient Approach: {eager_q1_time:6.2f} seconds")
print(f"Lazy/Optimized Approach: {lazy_q1_time:6.2f} seconds")
print(f"Time Saved: {time_saved_q1:6.2f} seconds")
print(f"Speedup: {speedup_q1:6.2f}x faster")
print(f"Performance Improvement: {percent_improvement_q1:6.1f}%")


PERFORMANCE COMPARISON
Eager/Inefficient Approach:   0.17 seconds
Lazy/Optimized Approach:   0.15 seconds
Time Saved:   0.02 seconds
Speedup:   1.10x faster
Performance Improvement:    9.4%


In [0]:
# ============================================================================
# SQL QUERY 2: STATE-LEVEL COMPARISON
# ============================================================================

print("\n" + "=" * 80)
print("SQL QUERY 2: STATE-LEVEL COMPARISON WITH AGGREGATIONS")
print("=" * 80)

# ----------------------------------------------------------------------------
# VERSION 2A: EAGER/INEFFICIENT SQL QUERY ❌
# ----------------------------------------------------------------------------
print("\n" + "-" * 80)
print("VERSION 2A: EAGER/INEFFICIENT APPROACH")
print("-" * 80)
print("Problems:")
print("  ❌ Multiple subqueries with repeated aggregations")
print("  ❌ Computes aggregations separately (inefficient)")
print("  ❌ Multiple passes over the same data")
print("  ❌ Unnecessary intermediate results")

eager_q2_start = time.time()

# INEFFICIENT: Multiple subqueries, repeated computations
sql_eager_q2 = """
    SELECT 
        s1.state,
        s1.num_counties,
        s2.state_total_cases,
        s3.state_total_deaths,
        s4.avg_mortality_rate,
        s5.max_county_cases
    FROM (
        SELECT state, COUNT(DISTINCT county) as num_counties
        FROM state_stats_eager
        GROUP BY state
    ) s1
    LEFT JOIN (
        SELECT state, SUM(total_cases) as state_total_cases
        FROM state_stats_eager
        GROUP BY state
    ) s2 ON s1.state = s2.state
    LEFT JOIN (
        SELECT state, SUM(total_deaths) as state_total_deaths
        FROM state_stats_eager
        GROUP BY state
    ) s3 ON s1.state = s3.state
    LEFT JOIN (
        SELECT state, ROUND(AVG(avg_mortality_rate), 2) as avg_mortality_rate
        FROM state_stats_eager
        GROUP BY state
    ) s4 ON s1.state = s4.state
    LEFT JOIN (
        SELECT state, MAX(total_cases) as max_county_cases
        FROM state_stats_eager
        GROUP BY state
    ) s5 ON s1.state = s5.state
    ORDER BY s2.state_total_cases DESC
"""

print("\n[Executing inefficient query...]")
eager_result_q2 = spark.sql(sql_eager_q2)
eager_count_q2 = eager_result_q2.count()
eager_q2_time = time.time() - eager_q2_start

print(f"⏱️  Eager Query 2 Time: {eager_q2_time:.2f} seconds")

print("\nExecution Plan (Eager - notice multiple scans and joins):")
eager_result_q2.explain(mode="simple")


SQL QUERY 2: STATE-LEVEL COMPARISON WITH AGGREGATIONS

--------------------------------------------------------------------------------
VERSION 2A: EAGER/INEFFICIENT APPROACH
--------------------------------------------------------------------------------
Problems:
  ❌ Multiple subqueries with repeated aggregations
  ❌ Computes aggregations separately (inefficient)
  ❌ Multiple passes over the same data
  ❌ Unnecessary intermediate results

[Executing inefficient query...]
⏱️  Eager Query 2 Time: 0.80 seconds

Execution Plan (Eager - notice multiple scans and joins):
== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- == Initial Plan ==
   ColumnarToRow
   +- PhotonResultStage
      +- PhotonSort [state_total_cases#17617L DESC NULLS LAST]
         +- PhotonShuffleExchangeSource
            +- PhotonShuffleMapStage ENSURE_REQUIREMENTS, [id=#27470]
               +- PhotonShuffleExchangeSink rangepartitioning(state_total_cases#17617L DESC NULLS LAST, 1024)
                  +- Ph

In [0]:
# ----------------------------------------------------------------------------
# VERSION 2B: LAZY/OPTIMIZED SQL QUERY ✅
# ----------------------------------------------------------------------------
print("\n" + "-" * 80)
print("VERSION 2B: LAZY/OPTIMIZED APPROACH")
print("-" * 80)
print("Optimizations:")
print("  ✅ Single GROUP BY with all aggregations together")
print("  ✅ One pass over data")
print("  ✅ No unnecessary joins")
print("  ✅ Efficient execution plan")

lazy_q2_start = time.time()

# OPTIMIZED: Single GROUP BY with all aggregations
sql_lazy_q2 = """
    SELECT 
        state,
        COUNT(DISTINCT county) as num_counties,
        SUM(total_cases) as state_total_cases,
        SUM(total_deaths) as state_total_deaths,
        ROUND(AVG(avg_mortality_rate), 2) as avg_mortality_rate,
        MAX(total_cases) as max_county_cases
    FROM state_stats_lazy
    GROUP BY state
    ORDER BY state_total_cases DESC
"""

print("\n[Executing optimized query...]")
lazy_result_q2 = spark.sql(sql_lazy_q2)
lazy_count_q2 = lazy_result_q2.count()
lazy_q2_time = time.time() - lazy_q2_start

print(f"⏱️  Lazy Query 2 Time: {lazy_q2_time:.2f} seconds")

print("\nExecution Plan (Optimized - notice single scan):")
lazy_result_q2.explain(mode="simple")


--------------------------------------------------------------------------------
VERSION 2B: LAZY/OPTIMIZED APPROACH
--------------------------------------------------------------------------------
Optimizations:
  ✅ Single GROUP BY with all aggregations together
  ✅ One pass over data
  ✅ No unnecessary joins
  ✅ Efficient execution plan

[Executing optimized query...]
⏱️  Lazy Query 2 Time: 0.73 seconds

Execution Plan (Optimized - notice single scan):
== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- == Initial Plan ==
   ColumnarToRow
   +- PhotonResultStage
      +- PhotonSort [state_total_cases#17835L DESC NULLS LAST]
         +- PhotonShuffleExchangeSource
            +- PhotonShuffleMapStage ENSURE_REQUIREMENTS, [id=#28003]
               +- PhotonShuffleExchangeSink rangepartitioning(state_total_cases#17835L DESC NULLS LAST, 1024)
                  +- PhotonGroupingAgg(keys=[state#14466], functions=[finalmerge_sum(merge sum#17847L) AS sum(total_cases)#17840L, finalme

In [0]:
# ============================================================================
# PERFORMANCE COMPARISON
# ============================================================================
print("\n" + "=" * 80)
print("PERFORMANCE COMPARISON")
print("=" * 80)

speedup_q2 = eager_q2_time / lazy_q2_time
time_saved_q2 = eager_q2_time - lazy_q2_time
percent_improvement_q2 = ((eager_q2_time - lazy_q2_time) / eager_q2_time) * 100

print(f"Eager/Inefficient Approach: {eager_q2_time:6.2f} seconds")
print(f"Lazy/Optimized Approach: {lazy_q2_time:6.2f} seconds")
print(f"Time Saved: {time_saved_q2:6.2f} seconds")
print(f"Speedup: {speedup_q2:6.2f}x faster")
print(f"Performance Improvement: {percent_improvement_q2:6.1f}%")


PERFORMANCE COMPARISON
Eager/Inefficient Approach:   0.80 seconds
Lazy/Optimized Approach:   0.73 seconds
Time Saved:   0.07 seconds
Speedup:   1.10x faster
Performance Improvement:    8.8%


In [0]:
# Write results to a destination - my VOLUME on Databricks
output_base = "/Volumes/pyspark_assn_aj463/table1_output_schema/table1_volume_test1/"

output_path_state_pop = f"{output_base}/output_path_state_pop"
state_population.write.mode("overwrite").parquet(output_path_state_pop)

In [0]:
# writing all outputs
output_path_eager_result = f"{output_base}/eager_result"
eager_result.write.mode("overwrite").parquet(output_path_eager_result)

output_path_eager_result = f"{output_base}/eager_result"
state_population.write.mode("overwrite").parquet(output_path_eager_result)

output_path_lazy_result = f"{output_base}/lazy_result"
lazy_result.write.mode("overwrite").parquet(output_path_lazy_result)