In [3]:
# set up slow environment 

from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.window import Window
import random
import time
import  builtins

# ANTI-PATTERN: broadcast disabled, too many shuffle partitions
spark = SparkSession.builder \
    .appName("StreamPulse-Revenue-SLOW") \
    .master("local[*]") \
    .config("spark.driver.memory", "2g") \
    .config("spark.sql.shuffle.partitions", "200") \
    .config("spark.sql.adaptive.enabled", "false") \
    .config("spark.sql.autoBroadcastJoinThreshold", "-1") \
    .getOrCreate()

print("‚úÖ SparkSession created (intentionally misconfigured)")


‚úÖ SparkSession created (intentionally misconfigured)


In [17]:
# Generate the revenue dataset:

random.seed(42)
N = 600000

# Events (large)
event_data = []
for i in range(N):
    event_data.append((
        f"EVT-{i+1:07d}",
        f"USR-{random.randint(1, 100000):06d}",
        f"ART-{random.randint(1, 5000):05d}",
        random.choice(["Pop", "Rock", "Hip-Hop", "Jazz", "Electronic", "R&B", "Country", "Classical"]),
        random.choice(["North America", "Europe", "Asia Pacific", "Latin America", "Africa"]),
        random.randint(15, 350),
        random.choice([True, False]),
        random.choice(["mobile", "desktop", "smart_speaker", "tablet", "car", "tv"]),
        f"2024-{random.randint(1,12):02d}-{random.randint(1,28):02d}",
    ))

events = spark.createDataFrame(event_data,
    ["event_id", "user_id", "artist_id", "genre", "region",
     "duration_sec", "completed", "device", "event_date"]) \
    .withColumn("event_date", col("event_date").cast("date")) \
    .withColumn("month", month(col("event_date")))

events.write.parquet("revenue_data/events", mode="overwrite")

# Subscriptions (medium - 100K users with subscription info)
sub_data = [(f"USR-{i+1:06d}",
             random.choice(["free", "individual", "family", "student"]),
             builtins.round(random.choice([0.0, 9.99, 14.99, 4.99]), 2),
             random.choice(["US", "UK", "DE", "JP", "BR", "IN", "KR", "FR"]))
            for i in range(100000)]
subscriptions = spark.createDataFrame(sub_data, ["user_id", "plan", "monthly_price", "country"])
subscriptions.write.parquet("revenue_data/subscriptions", mode="overwrite")

# Ad rates (tiny - 8 genres x 6 devices = 48 rows)
ad_data = []
for genre in ["Pop", "Rock", "Hip-Hop", "Jazz", "Electronic", "R&B", "Country", "Classical"]:
    for device in ["mobile", "desktop", "smart_speaker", "tablet", "car", "tv"]:
        cpm = builtins.round(random.uniform(1.5, 8.0), 2)
        ad_data.append((genre, device, cpm))
ad_rates = spark.createDataFrame(ad_data, ["ad_genre", "ad_device", "cpm"])
ad_rates.write.parquet("revenue_data/ad_rates", mode="overwrite")

# Artist payout rates (small - 5000 artists)
payout_data = [(f"ART-{i+1:05d}", builtins.round(random.uniform(0.003, 0.008), 4),
                random.choice(["major", "indie", "unsigned"]))
               for i in range(5000)]
payouts = spark.createDataFrame(payout_data, ["artist_id", "per_stream_rate", "label_type"])
payouts.write.parquet("revenue_data/payouts", mode="overwrite")

# Reload from disk
events = spark.read.parquet("revenue_data/events")
subscriptions = spark.read.parquet("revenue_data/subscriptions")
ad_rates = spark.read.parquet("revenue_data/ad_rates")
payouts = spark.read.parquet("revenue_data/payouts")

print(f"Events: {events.count()} | Subs: {subscriptions.count()} | "
      f"Ad rates: {ad_rates.count()} | Payouts: {payouts.count()}")


26/02/27 15:40:04 WARN TaskSetManager: Stage 188 contains a task of very large size (4713 KiB). The maximum recommended task size is 1000 KiB.
                                                                                

Events: 600000 | Subs: 100000 | Ad rates: 48 | Payouts: 5000


In [7]:
# Part 2: The Unoptimized Pipeline (Baseline)
# Run the slow pipeline and time it:

print("=" * 60)
print("RUNNING UNOPTIMIZED PIPELINE (BASELINE)")
print("=" * 60)

total_start = time.time()

# Build enriched revenue DataFrame (NOT cached, recomputed every time)
def build_revenue():
    return events \
        .join(subscriptions, "user_id") \
        .join(ad_rates,
              (events.genre == ad_rates.ad_genre) & (events.device == ad_rates.ad_device)) \
        .join(payouts, "artist_id") \
        .withColumn("ad_revenue", col("cpm") / 1000) \
        .withColumn("stream_payout", col("per_stream_rate")) \
        .withColumn("is_premium", when(col("plan") != "free", True).otherwise(False))

# Report 1: Genre Revenue
revenue = build_revenue()
r1_start = time.time()
report_1 = revenue.groupBy("genre") \
    .agg(sum("ad_revenue").alias("total_ad_rev"),
         countDistinct("user_id").alias("unique_listeners")) \
    .collect()
r1_time = time.time() - r1_start

# Report 2: Regional Breakdown
revenue = build_revenue()
r2_start = time.time()
report_2 = revenue.groupBy("region", "country") \
    .agg(count("*").alias("streams"),
         sum("ad_revenue").alias("ad_rev")) \
    .collect()
r2_time = time.time() - r2_start

# Report 3: Subscription Analysis
revenue = build_revenue()
r3_start = time.time()
report_3 = revenue.groupBy("plan") \
    .agg(countDistinct("user_id").alias("users"),
         count("*").alias("total_streams"),
         avg("duration_sec").alias("avg_duration")) \
    .collect()
r3_time = time.time() - r3_start

# Report 4: Ad Performance
revenue = build_revenue()
r4_start = time.time()
report_4 = revenue.groupBy("device", "genre") \
    .agg(sum("ad_revenue").alias("total_ad_rev"),
         count("*").alias("impressions")) \
    .collect()
r4_time = time.time() - r4_start

# Report 5: Artist Payouts
revenue = build_revenue()
r5_start = time.time()
report_5 = revenue.groupBy("artist_id", "label_type") \
    .agg(sum("stream_payout").alias("total_payout"),
         count("*").alias("total_streams")) \
    .orderBy(desc("total_payout")).limit(100) \
    .collect()
r5_time = time.time() - r5_start

# Report 6: Daily Summary
revenue = build_revenue()
r6_start = time.time()
report_6 = revenue.groupBy("event_date") \
    .agg(count("*").alias("streams"),
         sum("ad_revenue").alias("ad_rev"),
         countDistinct("user_id").alias("unique_users")) \
    .orderBy("event_date") \
    .collect()
r6_time = time.time() - r6_start

baseline_total = time.time() - total_start

print(f"\nReport 1 (genre):        {r1_time:.2f}s")
print(f"Report 2 (regional):     {r2_time:.2f}s")
print(f"Report 3 (subscription): {r3_time:.2f}s")
print(f"Report 4 (ad perf):      {r4_time:.2f}s")
print(f"Report 5 (payouts):      {r5_time:.2f}s")
print(f"Report 6 (daily):        {r6_time:.2f}s")
print(f"\n‚è±Ô∏è  BASELINE TOTAL: {baseline_total:.2f}s")


RUNNING UNOPTIMIZED PIPELINE (BASELINE)





Report 1 (genre):        10.83s
Report 2 (regional):     5.16s
Report 3 (subscription): 8.64s
Report 4 (ad perf):      5.50s
Report 5 (payouts):      3.48s
Report 6 (daily):        11.67s

‚è±Ô∏è  BASELINE TOTAL: 45.59s


                                                                                

In [8]:
# Analyze the baseline plan:
print("\nBASELINE PLAN:")
build_revenue().groupBy("genre").agg(sum("ad_revenue")).explain(mode="formatted")





BASELINE PLAN:
== Physical Plan ==
* HashAggregate (33)
+- Exchange (32)
   +- * HashAggregate (31)
      +- * Project (30)
         +- * SortMergeJoin Inner (29)
            :- * Sort (23)
            :  +- Exchange (22)
            :     +- * Project (21)
            :        +- * SortMergeJoin Inner (20)
            :           :- * Sort (14)
            :           :  +- Exchange (13)
            :           :     +- * Project (12)
            :           :        +- * SortMergeJoin Inner (11)
            :           :           :- * Sort (5)
            :           :           :  +- Exchange (4)
            :           :           :     +- * Filter (3)
            :           :           :        +- * ColumnarToRow (2)
            :           :           :           +- Scan parquet  (1)
            :           :           +- * Sort (10)
            :           :              +- Exchange (9)
            :           :                 +- * Filter (8)
            :           :       

| Anti-Pattern | Description                                              | Impact                                                    |
| ------------ | -------------------------------------------------------- | --------------------------------------------------------- |
| 1            | Rebuilding the same joined DataFrame for every report    | Repeats expensive joins and shuffles six times            |
| 2            | No caching or persistence of shared intermediate results | Prevents reuse; Spark recomputes lineage for every action |
| 3            | Multiple `.collect()` calls                              | Forces full execution and driver memory pressure          |
| 4            | Large fact table (`events`) joined repeatedly            | Dominates runtime and I/O cost                            |
| 5            | No broadcast joins for small dimension tables            | Causes unnecessary shuffles and SortMergeJoins            |
---

In [None]:
#This tells Spark:
#If a table is ‚â§ 10 MB, automatically broadcast it to all executors instead of shuffling.‚Äù

spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "10485760")
# Your optimized code here ‚Äî use broadcast() for ad_rates and payouts

from pyspark.sql.functions import broadcast

def build_revenue_optimized():
    return events \
        .join(broadcast(subscriptions), "user_id") \
        .join(
            broadcast(ad_rates),
            (events.genre == ad_rates.ad_genre) &
            (events.device == ad_rates.ad_device)
        ) \
        .join(broadcast(payouts), "artist_id") \
        .withColumn("ad_revenue", col("cpm") / 1000) \
        .withColumn("stream_payout", col("per_stream_rate")) \
        .withColumn("is_premium", when(col("plan") != "free", True).otherwise(False))

In [None]:
print("=" * 60)
print("RUNNING OPTIMIZED BROADCAST PIPELINE ")
print("=" * 60)
 
total_start = time.time()

# Report 1 : Genre Revenue
revenue = build_revenue_optimized()
r1_start = time.time()
report_1 = revenue.groupBy("genre") \
    .agg(sum("ad_revenue").alias("total_ad_rev"),
         countDistinct("user_id").alias("unique_listeners")) \
    .collect()
r1_time = time.time() - r1_start

# Report 2: Regional Breakdown
revenue = build_revenue_optimized()
r2_start = time.time()
report_2 = revenue.groupBy("region", "country") \
    .agg(count("*").alias("streams"),
         sum("ad_revenue").alias("ad_rev")) \
    .collect()
r2_time = time.time() - r2_start
 
# Report 3: Subscription Analysis
revenue = build_revenue_optimized()
r3_start = time.time()
report_3 = revenue.groupBy("plan") \
    .agg(countDistinct("user_id").alias("users"),
         count("*").alias("total_streams"),
         avg("duration_sec").alias("avg_duration")) \
    .collect()
r3_time = time.time() - r3_start
 
# Report 4: Ad Performance
revenue = build_revenue_optimized()
r4_start = time.time()
report_4 = revenue.groupBy("device", "genre") \
    .agg(sum("ad_revenue").alias("total_ad_rev"),
         count("*").alias("impressions")) \
    .collect()
r4_time = time.time() - r4_start
 
# Report 5: Artist Payouts
revenue = build_revenue_optimized()
r5_start = time.time()
report_5 = revenue.groupBy("artist_id", "label_type") \
    .agg(sum("stream_payout").alias("total_payout"),
         count("*").alias("total_streams")) \
    .orderBy(desc("total_payout")).limit(100) \
    .collect()
r5_time = time.time() - r5_start
 
# Report 6: Daily Summary
revenue = build_revenue_optimized()
r6_start = time.time()
report_6 = revenue.groupBy("event_date") \
    .agg(count("*").alias("streams"),
         sum("ad_revenue").alias("ad_rev"),
         countDistinct("user_id").alias("unique_users")) \
    .orderBy("event_date") \
    .collect()
r6_time = time.time() - r6_start
 
optimized_total = time.time() - total_start
 
print(f"\nReport 1 (genre):        {r1_time:.2f}s")
print(f"Report 2 (regional):     {r2_time:.2f}s")
print(f"Report 3 (subscription): {r3_time:.2f}s")
print(f"Report 4 (ad perf):      {r4_time:.2f}s")
print(f"Report 5 (payouts):      {r5_time:.2f}s")
print(f"Report 6 (daily):        {r6_time:.2f}s")
print(f"\n‚è±Ô∏è  OPTIMIZED BROADCAST TOTAL: {optimized_total:.2f}s")



RUNNING OPTIMIZED PIPELINE 





Report 1 (genre):        2.44s
Report 2 (regional):     1.01s
Report 3 (subscription): 1.77s
Report 4 (ad perf):      0.80s
Report 5 (payouts):      0.88s
Report 6 (daily):        5.01s

‚è±Ô∏è  OPTIMIZED TOTAL: 12.26s


                                                                                

In [11]:
#   Verifying the optimization with explain()

build_revenue_optimized() \
    .groupBy("genre") \
    .agg(sum("ad_revenue")) \
    .explain(mode="formatted")

== Physical Plan ==
* HashAggregate (24)
+- Exchange (23)
   +- * HashAggregate (22)
      +- * Project (21)
         +- * BroadcastHashJoin Inner BuildRight (20)
            :- * Project (15)
            :  +- * BroadcastHashJoin Inner BuildRight (14)
            :     :- * Project (9)
            :     :  +- * BroadcastHashJoin Inner BuildRight (8)
            :     :     :- * Filter (3)
            :     :     :  +- * ColumnarToRow (2)
            :     :     :     +- Scan parquet  (1)
            :     :     +- BroadcastExchange (7)
            :     :        +- * Filter (6)
            :     :           +- * ColumnarToRow (5)
            :     :              +- Scan parquet  (4)
            :     +- BroadcastExchange (13)
            :        +- * Filter (12)
            :           +- * ColumnarToRow (11)
            :              +- Scan parquet  (10)
            +- BroadcastExchange (19)
               +- * Filter (18)
                  +- * ColumnarToRow (17)
                

In [None]:
# Optimization 2: Cache the enriched Data Frame
# Build once, cache, reuse for all 6 reports
# Measure and compare

print("=" * 60)
print("RUNNING BROADCAST + CACHED PIPELINE ")
print("=" * 60)
 
total_start = time.time()

def build_revenue_cached():
    return events \
        .join(broadcast(subscriptions), "user_id") \
        .join(
            broadcast(ad_rates),
            (events.genre == ad_rates.ad_genre) &
            (events.device == ad_rates.ad_device)
        ) \
        .join(broadcast(payouts), "artist_id") \
        .withColumn("ad_revenue", col("cpm") / 1000) \
        .withColumn("stream_payout", col("per_stream_rate")) \
        .withColumn("is_premium", when(col("plan") != "free", True).otherwise(False))
        
# build once 
revenue = build_revenue_cached()
# Cache it
revenue.cache()
# trigger/ materialize cache population
revenue.count()
print("‚úÖ Cached filtered+joined Revenue DataFrame")

# Step 3: Reuse cached DataFrame for all 6 reports
# Report 1: Genre Revenue
r1_start = time.time()
report_1 = revenue.groupBy("genre") \
    .agg(sum("ad_revenue").alias("total_ad_rev"),
         countDistinct("user_id").alias("unique_listeners")) \
    .collect()
r1_time = time.time() - r1_start
 
# Report 2: Regional Breakdown
r2_start = time.time()
report_2 = revenue.groupBy("region", "country") \
    .agg(count("*").alias("streams"),
         sum("ad_revenue").alias("ad_rev")) \
    .collect()
r2_time = time.time() - r2_start
 
# Report 3: Subscription Analysis
r3_start = time.time()
report_3 = revenue.groupBy("plan") \
    .agg(countDistinct("user_id").alias("users"),
         count("*").alias("total_streams"),
         avg("duration_sec").alias("avg_duration")) \
    .collect()
r3_time = time.time() - r3_start
 
# Report 4: Ad Performance
r4_start = time.time()
report_4 = revenue.groupBy("device", "genre") \
    .agg(sum("ad_revenue").alias("total_ad_rev"),
         count("*").alias("impressions")) \
    .collect()
r4_time = time.time() - r4_start
 
# Report 5: Artist Payouts
r5_start = time.time()
report_5 = revenue.groupBy("artist_id", "label_type") \
    .agg(sum("stream_payout").alias("total_payout"),
         count("*").alias("total_streams")) \
    .orderBy(desc("total_payout")).limit(100) \
    .collect()
r5_time = time.time() - r5_start
 
# Report 6: Daily Summary
r6_start = time.time()
report_6 = revenue.groupBy("event_date") \
    .agg(count("*").alias("streams"),
         sum("ad_revenue").alias("ad_rev"),
         countDistinct("user_id").alias("unique_users")) \
    .orderBy("event_date") \
    .collect()
r6_time = time.time() - r6_start
 
cached_total = time.time() - total_start
 
print(f"\nReport 1 (genre):        {r1_time:.2f}s")
print(f"Report 2 (regional):     {r2_time:.2f}s")
print(f"Report 3 (subscription): {r3_time:.2f}s")
print(f"Report 4 (ad perf):      {r4_time:.2f}s")
print(f"Report 5 (payouts):      {r5_time:.2f}s")
print(f"Report 6 (daily):        {r6_time:.2f}s")
print(f"\n‚è±Ô∏è  OPTIMIZED BROADCAST + CACHHED TOTAL: {cached_total:.2f}s")






RUNNING BROADCAST + CACHED PIPELINE 


26/02/27 15:21:10 WARN CacheManager: Asked to cache already cached data.


‚úÖ Cached filtered+joined Revenue DataFrame





Report 1 (genre):        1.25s
Report 2 (regional):     0.25s
Report 3 (subscription): 1.04s
Report 4 (ad perf):      0.29s
Report 5 (payouts):      0.42s
Report 6 (daily):        4.14s

‚è±Ô∏è  OPTIMIZED BROADCAST + CACHHED TOTAL: 7.67s


                                                                                

In [None]:
# compare cached optimization worked 
revenue.groupBy("genre")\
    .agg(sum("ad_revenue"))\
        .explain(mode = "formatted")


In [15]:
#Optimization 3: Reduce shuffle partitions
spark.conf.set("spark.sql.shuffle.partitions", "8")

print("=" * 60)
print("RUNNING BROADCAST + CACHED + REDUCE SHUFFLE PIPELINE ")
print("=" * 60)
 
total_start = time.time()

def build_revenue_cached():
    return events \
        .join(broadcast(subscriptions), "user_id") \
        .join(
            broadcast(ad_rates),
            (events.genre == ad_rates.ad_genre) &
            (events.device == ad_rates.ad_device)
        ) \
        .join(broadcast(payouts), "artist_id") \
        .withColumn("ad_revenue", col("cpm") / 1000) \
        .withColumn("stream_payout", col("per_stream_rate")) \
        .withColumn("is_premium", when(col("plan") != "free", True).otherwise(False))
        
# build once 
revenue = build_revenue_cached()
# Cache it
revenue.cache()
# trigger/ materialize cache population
revenue.count()
print("‚úÖ Cached filtered+joined Revenue DataFrame")

# Step 3: Reuse cached DataFrame for all 6 reports
# Report 1: Genre Revenue
r1_start = time.time()
report_1 = revenue.groupBy("genre") \
    .agg(sum("ad_revenue").alias("total_ad_rev"),
         countDistinct("user_id").alias("unique_listeners")) \
    .collect()
r1_time = time.time() - r1_start
 
# Report 2: Regional Breakdown
r2_start = time.time()
report_2 = revenue.groupBy("region", "country") \
    .agg(count("*").alias("streams"),
         sum("ad_revenue").alias("ad_rev")) \
    .collect()
r2_time = time.time() - r2_start
 
# Report 3: Subscription Analysis
r3_start = time.time()
report_3 = revenue.groupBy("plan") \
    .agg(countDistinct("user_id").alias("users"),
         count("*").alias("total_streams"),
         avg("duration_sec").alias("avg_duration")) \
    .collect()
r3_time = time.time() - r3_start
 
# Report 4: Ad Performance
r4_start = time.time()
report_4 = revenue.groupBy("device", "genre") \
    .agg(sum("ad_revenue").alias("total_ad_rev"),
         count("*").alias("impressions")) \
    .collect()
r4_time = time.time() - r4_start
 
# Report 5: Artist Payouts
r5_start = time.time()
report_5 = revenue.groupBy("artist_id", "label_type") \
    .agg(sum("stream_payout").alias("total_payout"),
         count("*").alias("total_streams")) \
    .orderBy(desc("total_payout")).limit(100) \
    .collect()
r5_time = time.time() - r5_start
 
# Report 6: Daily Summary
r6_start = time.time()
report_6 = revenue.groupBy("event_date") \
    .agg(count("*").alias("streams"),
         sum("ad_revenue").alias("ad_rev"),
         countDistinct("user_id").alias("unique_users")) \
    .orderBy("event_date") \
    .collect()
r6_time = time.time() - r6_start
 
reduce_shuffle_total = time.time() - total_start
 
print(f"\nReport 1 (genre):        {r1_time:.2f}s")
print(f"Report 2 (regional):     {r2_time:.2f}s")
print(f"Report 3 (subscription): {r3_time:.2f}s")
print(f"Report 4 (ad perf):      {r4_time:.2f}s")
print(f"Report 5 (payouts):      {r5_time:.2f}s")
print(f"Report 6 (daily):        {r6_time:.2f}s")
print(f"\n‚è±Ô∏è  OPTIMIZED BROADCAST + CACHHED + REDUCE SHUFFLE TOTAL: {reduce_shuffle_total:.2f}s")







RUNNING BROADCAST + CACHED + REDUCE SHUFFLE PIPELINE 
‚úÖ Cached filtered+joined Revenue DataFrame


26/02/27 15:29:03 WARN CacheManager: Asked to cache already cached data.



Report 1 (genre):        0.46s
Report 2 (regional):     0.18s
Report 3 (subscription): 0.20s
Report 4 (ad perf):      0.10s
Report 5 (payouts):      0.18s
Report 6 (daily):        0.30s

‚è±Ô∏è  OPTIMIZED BROADCAST + CACHHED + REDUCE SHUFFLE TOTAL: 1.56s


In [None]:
# compare reduce shuffle partition  worked 
revenue.groupBy("genre")\
    .agg(sum("ad_revenue"))\
        .explain(mode = "formatted")

In [20]:
# Optimization 4: Column pruning ‚Äî select only needed columns from each table

print("=" * 60)
print("RUNNING BROADCAST + CACHED + REDUCE SHUFFLE PIPELINE + PRUNED ")
print("=" * 60)

total_start = time.time()
# Prune colums at source
events_pruned = events.select(
    "event_id",
    "user_id",
    "artist_id",
    "genre",
    "device",
    "region",
    "event_date",
    "duration_sec",
    "completed",
    "month"
)

subscriptions_pruned = subscriptions.select(
    "user_id",
    "plan",
    "country"
)

ad_rates_pruned = ad_rates.select(
    col("ad_genre"),
    col("ad_device"),
    col("cpm")
)

payouts_pruned = payouts.select(
    "artist_id",
    "label_type",
    "per_stream_rate"
)

# Step 2: Build optimized, cached revenue DataFrame
def build_revenue_pruned():
    return events_pruned \
        .join(broadcast(subscriptions_pruned), "user_id") \
        .join(
            broadcast(ad_rates_pruned),
            (events_pruned.genre == ad_rates_pruned.ad_genre) &
            (events_pruned.device == ad_rates_pruned.ad_device)
        ) \
        .join(broadcast(payouts_pruned), "artist_id") \
        .withColumn("ad_revenue", col("cpm") / 1000) \
        .withColumn("stream_payout", col("per_stream_rate")) \
        .withColumn("is_premium", when(col("plan") != "free", True).otherwise(False))
        
# Step 3: Cache and materialize
revenue = build_revenue_pruned()
revenue.cache()
revenue.count()

# Step 4:  build reports 

# Report 1: Genre Revenue
r1_start = time.time()
report_1 = revenue.groupBy("genre") \
    .agg(sum("ad_revenue").alias("total_ad_rev"),
         countDistinct("user_id").alias("unique_listeners")) \
    .collect()
r1_time = time.time() - r1_start
 
# Report 2: Regional Breakdown
r2_start = time.time()
report_2 = revenue.groupBy("region", "country") \
    .agg(count("*").alias("streams"),
         sum("ad_revenue").alias("ad_rev")) \
    .collect()
r2_time = time.time() - r2_start
 
# Report 3: Subscription Analysis
r3_start = time.time()
report_3 = revenue.groupBy("plan") \
    .agg(countDistinct("user_id").alias("users"),
         count("*").alias("total_streams"),
         avg("duration_sec").alias("avg_duration")) \
    .collect()
r3_time = time.time() - r3_start
 
# Report 4: Ad Performance
r4_start = time.time()
report_4 = revenue.groupBy("device", "genre") \
    .agg(sum("ad_revenue").alias("total_ad_rev"),
         count("*").alias("impressions")) \
    .collect()
r4_time = time.time() - r4_start
 
# Report 5: Artist Payouts
r5_start = time.time()
report_5 = revenue.groupBy("artist_id", "label_type") \
    .agg(sum("stream_payout").alias("total_payout"),
         count("*").alias("total_streams")) \
    .orderBy(desc("total_payout")).limit(100) \
    .collect()
r5_time = time.time() - r5_start
 
# Report 6: Daily Summary
r6_start = time.time()
report_6 = revenue.groupBy("event_date") \
    .agg(count("*").alias("streams"),
         sum("ad_revenue").alias("ad_rev"),
         countDistinct("user_id").alias("unique_users")) \
    .orderBy("event_date") \
    .collect()
r6_time = time.time() - r6_start
 
pruned_total = time.time() - total_start
 
print(f"\nReport 1 (genre):        {r1_time:.2f}s")
print(f"Report 2 (regional):     {r2_time:.2f}s")
print(f"Report 3 (subscription): {r3_time:.2f}s")
print(f"Report 4 (ad perf):      {r4_time:.2f}s")
print(f"Report 5 (payouts):      {r5_time:.2f}s")
print(f"Report 6 (daily):        {r6_time:.2f}s")
print(f"\n‚è±Ô∏è  OPTIMIZED BROADCAST + CACHHED + REDUCE SHUFFLE + PRUNED TOTAL: {pruned_total:.2f}s")







RUNNING BROADCAST + CACHED + REDUCE SHUFFLE PIPELINE + PRUNED 


                                                                                


Report 1 (genre):        0.25s
Report 2 (regional):     0.10s
Report 3 (subscription): 0.16s
Report 4 (ad perf):      0.06s
Report 5 (payouts):      0.10s
Report 6 (daily):        0.22s

‚è±Ô∏è  OPTIMIZED BROADCAST + CACHHED + REDUCE SHUFFLE + PRUNED TOTAL: 2.63s


In [None]:
# compare column pruning  worked 
revenue.groupBy("genre")\
    .agg(sum("ad_revenue"))\
        .explain(mode = "formatted")

In [21]:
# Build the Fully Optimized Pipeline
# Combine ALL optimizations into a production-ready pipeline

print("=" * 60)
print("RUNNING FULLY OPTIMIZED PIPELINE")
print("=" * 60)

# Reset config
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "10485760")
spark.conf.set("spark.sql.shuffle.partitions", "8")

total_start = time.time()

# Build enriched DataFrame ONCE with all optimizations
revenue_opt = events \
    .select("event_id", "user_id", "artist_id", "genre", "region",
            "duration_sec", "completed", "device", "event_date", "month") \
    .join(subscriptions.select("user_id", "plan", "country"), "user_id") \
    .join(broadcast(ad_rates),
          (col("genre") == col("ad_genre")) & (col("device") == col("ad_device"))) \
    .join(broadcast(payouts), "artist_id") \
    .withColumn("ad_revenue", col("cpm") / 1000) \
    .withColumn("stream_payout", col("per_stream_rate")) \
    .drop("ad_genre", "ad_device")

# Cache the shared DataFrame
revenue_opt.cache()
cache_start = time.time()
row_count = revenue_opt.count()
cache_time = time.time() - cache_start
print(f"‚úÖ Cached {row_count} rows in {cache_time:.2f}s")

# Run all 6 reports from cache
r1 = revenue_opt.groupBy("genre").agg(sum("ad_revenue"), countDistinct("user_id")).collect()
r2 = revenue_opt.groupBy("region", "country").agg(count("*"), sum("ad_revenue")).collect()
r3 = revenue_opt.groupBy("plan").agg(countDistinct("user_id"), count("*"), avg("duration_sec")).collect()
r4 = revenue_opt.groupBy("device", "genre").agg(sum("ad_revenue"), count("*")).collect()
r5 = revenue_opt.groupBy("artist_id", "label_type") \
    .agg(sum("stream_payout"), count("*")) \
    .orderBy(desc("sum(stream_payout)")).limit(100).collect()
r6 = revenue_opt.groupBy("event_date") \
    .agg(count("*"), sum("ad_revenue"), countDistinct("user_id")) \
    .orderBy("event_date").collect()

optimized_total = time.time() - total_start

print(f"\n‚è±Ô∏è  OPTIMIZED TOTAL: {optimized_total:.2f}s")
print(f"‚è±Ô∏è  BASELINE TOTAL:  {baseline_total:.2f}s")
print(f"üìà SPEEDUP:          {baseline_total/optimized_total:.1f}x")
print(f"üìâ TIME SAVED:       {baseline_total - optimized_total:.2f}s ({(1-optimized_total/baseline_total)*100:.0f}%)")

# Verify the plan
print("\nOPTIMIZED PLAN:")
revenue_opt.groupBy("genre").agg(sum("ad_revenue")).explain(mode="formatted")

revenue_opt.unpersist()



RUNNING FULLY OPTIMIZED PIPELINE


                                                                                

‚úÖ Cached 600000 rows in 1.32s

‚è±Ô∏è  OPTIMIZED TOTAL: 2.61s
‚è±Ô∏è  BASELINE TOTAL:  45.59s
üìà SPEEDUP:          17.5x
üìâ TIME SAVED:       42.98s (94%)

OPTIMIZED PLAN:
== Physical Plan ==
* HashAggregate (26)
+- Exchange (25)
   +- * HashAggregate (24)
      +- InMemoryTableScan (1)
            +- InMemoryRelation (2)
                  +- * Project (23)
                     +- * BroadcastHashJoin Inner BuildRight (22)
                        :- * Project (17)
                        :  +- * BroadcastHashJoin Inner BuildRight (16)
                        :     :- * Project (11)
                        :     :  +- * BroadcastHashJoin Inner BuildRight (10)
                        :     :     :- * Filter (5)
                        :     :     :  +- * ColumnarToRow (4)
                        :     :     :     +- Scan parquet  (3)
                        :     :     +- BroadcastExchange (9)
                        :     :        +- * Filter (8)
                        :     :    

DataFrame[artist_id: string, user_id: string, event_id: string, genre: string, region: string, duration_sec: bigint, completed: boolean, device: string, event_date: date, month: int, plan: string, country: string, cpm: double, per_stream_rate: double, label_type: string, ad_revenue: double, stream_payout: double]

In [23]:
print("=" * 65)
print("OPTIMIZATION REPORT ‚Äî StreamPulse Revenue Pipeline")
print("=" * 65)

print(f"""
Pipeline: Revenue Analytics (6 reports from joined data)

CONFIGURATION CHANGES:
  spark.sql.autoBroadcastJoinThreshold: -1 ‚Üí 10MB
  spark.sql.shuffle.partitions: 200 ‚Üí 8
  spark.sql.adaptive.enabled: false ‚Üí (unchanged for testing)

CODE CHANGES:
  1. broadcast() on ad_rates (48 rows) and payouts (5K rows)
  2. .cache() on enriched DataFrame (built once, used 6 times)
  3. Column pruning on all source tables
  4. Single build_revenue() call instead of 6 separate calls

RESULTS:
  Baseline:  {baseline_total:.2f}s
  Optimized: {optimized_total:.2f}s
  Speedup:   {baseline_total/optimized_total:.1f}x

PLAN IMPROVEMENTS:
  - SortMergeJoin ‚Üí BroadcastHashJoin (ad_rates, payouts)
  - 6 full recomputations ‚Üí 1 computation + 5 cache reads
  - 200 shuffle partitions ‚Üí 8 (matched to local cores)
  - ReadSchema reduced (column pruning)
""")


OPTIMIZATION REPORT ‚Äî StreamPulse Revenue Pipeline

Pipeline: Revenue Analytics (6 reports from joined data)

CONFIGURATION CHANGES:
  spark.sql.autoBroadcastJoinThreshold: -1 ‚Üí 10MB
  spark.sql.shuffle.partitions: 200 ‚Üí 8
  spark.sql.adaptive.enabled: false ‚Üí (unchanged for testing)

CODE CHANGES:
  1. broadcast() on ad_rates (48 rows) and payouts (5K rows)
  2. .cache() on enriched DataFrame (built once, used 6 times)
  3. Column pruning on all source tables
  4. Single build_revenue() call instead of 6 separate calls

RESULTS:
  Baseline:  45.59s
  Optimized: 2.61s
  Speedup:   17.5x

PLAN IMPROVEMENTS:
  - SortMergeJoin ‚Üí BroadcastHashJoin (ad_rates, payouts)
  - 6 full recomputations ‚Üí 1 computation + 5 cache reads
  - 200 shuffle partitions ‚Üí 8 (matched to local cores)
  - ReadSchema reduced (column pruning)

