In [1]:
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("Data Skew Example").master("local[*]").getOrCreate();


In [2]:
spark

In [3]:
df = spark.read.csv("/opt/data/ncr_ride_bookings.csv", header=True, inferSchema=True)

In [4]:
# 🔎 1. What is Skew?

# In distributed systems like Spark, data is partitioned across many worker nodes.
# 👉 Skew happens when some keys have way more rows than others.


from functools import reduce
from pyspark.sql import functions as F

def skew_summary(df):
    results = []
    for col in df.columns:
        stats = (
            df.groupBy(col).count()
              .agg(
                  F.count("*").alias("distinct_values"),
                  F.min("count").alias("min_count"),
                  F.expr("percentile_approx(count, 0.5)").alias("median_count"),
                  F.mean("count").alias("mean_count"),
                  F.max("count").alias("max_count")
              )
              .withColumn("column", F.lit(col))
              .withColumn("skew_ratio", F.col("max_count") / F.col("mean_count"))
              .withColumn(
                  "skew_class",
                  F.when(F.col("skew_ratio") <= 2, "Balanced")
                   .when((F.col("skew_ratio") > 2) & (F.col("skew_ratio") <= 10), "Moderately Skewed")
                   .otherwise("Highly Skewed")
              )
        )
        results.append(stats)
    
    # Merge all per-column results into one DataFrame
    summary_df = reduce(lambda a, b: a.unionByName(b), results)
    return summary_df

# Run summary
summary_df = skew_summary(df)
summary_df.show(truncate=False)


+---------------+---------+------------+------------------+---------+---------------------------------+------------------+-----------------+
|distinct_values|min_count|median_count|mean_count        |max_count|column                           |skew_ratio        |skew_class       |
+---------------+---------+------------+------------------+---------+---------------------------------+------------------+-----------------+
|365            |355      |411         |410.958904109589  |462      |Date                             |1.1242            |Balanced         |
|62910          |1        |2           |2.384358607534573 |16       |Time                             |6.7104            |Moderately Skewed|
|148767         |1        |1           |1.008288128415576 |3        |Booking ID                       |2.97534           |Moderately Skewed|
|5              |9000     |10500       |30000.0           |93000    |Booking Status                   |3.1               |Moderately Skewed|
|148788      

In [5]:
df.head(20)

[Row(Date=datetime.date(2024, 3, 23), Time=datetime.datetime(2025, 9, 14, 12, 29, 38), Booking ID='"""CNR5884300"""', Booking Status='No Driver Found', Customer ID='"""CID1982111"""', Vehicle Type='eBike', Pickup Location='Palam Vihar', Drop Location='Jhilmil', Avg VTAT='null', Avg CTAT='null', Cancelled Rides by Customer='null', Reason for cancelling by Customer='null', Cancelled Rides by Driver='null', Driver Cancellation Reason='null', Incomplete Rides='null', Incomplete Rides Reason='null', Booking Value='null', Ride Distance='null', Driver Ratings='null', Customer Rating='null', Payment Method='null'),
 Row(Date=datetime.date(2024, 11, 29), Time=datetime.datetime(2025, 9, 14, 18, 1, 39), Booking ID='"""CNR1326809"""', Booking Status='Incomplete', Customer ID='"""CID4604802"""', Vehicle Type='Go Sedan', Pickup Location='Shastri Nagar', Drop Location='Gurgaon Sector 56', Avg VTAT='4.9', Avg CTAT='14.0', Cancelled Rides by Customer='null', Reason for cancelling by Customer='null', Ca

In [7]:
from pyspark.sql.functions import col, rand, concat_ws, count

# Before salting: check distribution
print("=== BEFORE SALTING ===")
df.groupBy("Booking Status").agg(count("*").alias("count")).orderBy(col("count").desc()).show(truncate=False)

# Add salt
salted_df = df.withColumn(
    "BookingStatus_salted",
    concat_ws("_", col("Booking Status"), (rand()*10).cast("int"))  # 10 salts
)

# After salting: check distribution
print("=== AFTER SALTING ===")
salted_df.groupBy("BookingStatus_salted").agg(count("*").alias("count")).orderBy(col("count").desc()).show(truncate=False)


=== BEFORE SALTING ===
+---------------------+-----+
|Booking Status       |count|
+---------------------+-----+
|Completed            |93000|
|Cancelled by Driver  |27000|
|No Driver Found      |10500|
|Cancelled by Customer|10500|
|Incomplete           |9000 |
+---------------------+-----+

=== AFTER SALTING ===
+---------------------+-----+
|BookingStatus_salted |count|
+---------------------+-----+
|Completed_1          |9447 |
|Completed_6          |9442 |
|Completed_4          |9325 |
|Completed_8          |9310 |
|Completed_9          |9301 |
|Completed_3          |9286 |
|Completed_7          |9253 |
|Completed_2          |9248 |
|Completed_5          |9228 |
|Completed_0          |9160 |
|Cancelled by Driver_7|2802 |
|Cancelled by Driver_0|2735 |
|Cancelled by Driver_5|2720 |
|Cancelled by Driver_3|2717 |
|Cancelled by Driver_2|2708 |
|Cancelled by Driver_9|2695 |
|Cancelled by Driver_1|2689 |
|Cancelled by Driver_4|2662 |
|Cancelled by Driver_6|2638 |
|Cancelled by Driver_8|2

In [9]:
partial = salted_df.groupBy("BookingStatus_salted").agg(count("*").alias("partial_count"))


In [10]:
from pyspark.sql.functions import split, sum as _sum

# Step 1: GroupBy on salted key (parallelism fixed)
partial = salted_df.groupBy("BookingStatus_salted").agg(count("*").alias("partial_count"))

# Step 2: Extract original key (remove "_0", "_1", …)
unsalted = partial.withColumn("Booking Status", split(col("BookingStatus_salted"), "_")[0])

# Step 3: Aggregate again to restore correct totals
final = unsalted.groupBy("Booking Status").agg(_sum("partial_count").alias("total_count"))
final.show()


+--------------------+-----------+
|      Booking Status|total_count|
+--------------------+-----------+
|           Completed|      93000|
|     No Driver Found|      10500|
| Cancelled by Driver|      27000|
|Cancelled by Cust...|      10500|
|          Incomplete|       9000|
+--------------------+-----------+



In [11]:
from pyspark.sql.functions import count

print("=== Without Salting ===")
df.groupBy("Booking Status") \
  .agg(count("*").alias("count")) \
  .orderBy("count", ascending=False) \
  .show(truncate=False)


=== Without Salting ===
+---------------------+-----+
|Booking Status       |count|
+---------------------+-----+
|Completed            |93000|
|Cancelled by Driver  |27000|
|No Driver Found      |10500|
|Cancelled by Customer|10500|
|Incomplete           |9000 |
+---------------------+-----+



In [12]:
from pyspark.sql.functions import col, rand, concat_ws, split, sum as _sum

# Add salt (10 buckets)
salted = df.withColumn(
    "BookingStatus_salted",
    concat_ws("_", col("Booking Status"), (rand()*10).cast("int"))
)

# Partial aggregation (parallelized!)
partial = salted.groupBy("BookingStatus_salted") \
    .agg(count("*").alias("partial_count"))

# Unsalt (merge back into original key)
final = partial.withColumn("Booking Status", split(col("BookingStatus_salted"), "_")[0]) \
    .groupBy("Booking Status") \
    .agg(_sum("partial_count").alias("count"))

print("=== With Salting ===")
final.orderBy("count", ascending=False).show(truncate=False)


=== With Salting ===
+---------------------+-----+
|Booking Status       |count|
+---------------------+-----+
|Completed            |93000|
|Cancelled by Driver  |27000|
|No Driver Found      |10500|
|Cancelled by Customer|10500|
|Incomplete           |9000 |
+---------------------+-----+



In [13]:
# Repartition / Coalesce

# Repartition based on skewed column
df.repartition(20, "Booking Status").groupBy("Booking Status").count().show()


+--------------------+-----+
|      Booking Status|count|
+--------------------+-----+
|Cancelled by Cust...|10500|
| Cancelled by Driver|27000|
|     No Driver Found|10500|
|          Incomplete| 9000|
|           Completed|93000|
+--------------------+-----+



In [14]:
from pyspark.sql.functions import col, rand, concat_ws, sum as _sum

# Add salt
salted = df.withColumn("BookingValue_salted",
                       concat_ws("_", col("Booking Value"), (rand()*50).cast("int")))

# Partial aggregation
partial = salted.groupBy("BookingValue_salted").agg(_sum("Ride Distance").alias("partial_sum"))

# Final aggregation (remove salt)
final = partial.withColumn("Booking Value", col("BookingValue_salted").substr(0, 10)) \
               .groupBy("Booking Value").agg(_sum("partial_sum").alias("total_sum"))

final.show()


+-------------+------------------+
|Booking Value|         total_sum|
+-------------+------------------+
|      null_44|              NULL|
|      null_27|              NULL|
|       187_17|             40.04|
|       417_14|            131.88|
|      1042_23|             25.72|
|       713_15|             12.72|
|        92_25|             99.03|
|       282_23| 96.45000000000002|
|      1106_20|             43.19|
|       495_17|108.65999999999998|
|       284_22| 56.53999999999999|
|       842_18|             78.93|
|        91_21|             36.39|
|       432_17|            113.68|
|       390_15|46.330000000000005|
|       730_35|             94.63|
|       363_11|              49.7|
|       202_19|             22.25|
|       428_17| 70.22999999999999|
|        569_6|             32.75|
+-------------+------------------+
only showing top 20 rows



In [15]:
from pyspark.sql.functions import split

# Correct unsalting
final = partial.withColumn(
    "Booking Value", split(col("BookingValue_salted"), "_")[0]
).groupBy("Booking Value").agg(_sum("partial_sum").alias("total_sum"))

final.show()


+-------------+------------------+
|Booking Value|         total_sum|
+-------------+------------------+
|          691|1575.2400000000002|
|          467|3219.9700000000003|
|          296| 2866.810000000001|
|         4032|             32.11|
|         1436|113.60999999999999|
|          675|1615.0099999999995|
|          829|1447.7599999999998|
|         1159|            308.44|
|         1090| 699.5000000000001|
|         1512|246.42999999999995|
|         1572|            248.12|
|         2136|56.739999999999995|
|         2162|             49.44|
|         2294|             21.22|
|         2088|             29.09|
|          125|           4429.68|
|          451|           3484.38|
|          944| 952.4700000000001|
|          800|           1631.54|
|          853|1350.1800000000003|
+-------------+------------------+
only showing top 20 rows



In [16]:
# 🔹 What is Adaptive Query Execution (AQE)?

# Normally, Spark’s query plan (how it decides to run joins, shuffles, etc.) is decided before execution.
# But sometimes Spark guesses wrong because it doesn’t know the real data distribution until it actually runs.

# 🔹 What AQE does

# AQE has three main features:
# Dynamically coalesce shuffle partitions
# If Spark creates 200 shuffle partitions, but only 5 of them have data → AQE merges them automatically.
# This avoids tiny tasks overhead.
# Handle data skew dynamicall
# If one partition is much larger than others, AQE can split it into smaller ones (similar to salting, but automatic).
# This avoids stragglers (slow tasks due to skew).
# Change join strategy at runtime
# Spark may think a Sort-Merge Join is needed, but after seeing actual data sizes it realizes a Broadcast Join is faster.
# AQE switches automatically.

# 🔹 When to use AQE?

# When you have skewed data (e.g., one category dominates).

# When shuffle partitions are too many/too few.

# When data size is hard to estimate before execution.

spark.conf.set("spark.sql.adaptive.enabled", True)
spark.conf.set("spark.sql.adaptive.coalescePartitions.enabled", True)
spark.conf.set("spark.sql.adaptive.skewJoin.enabled", True)

from pyspark.sql import Row
data = [Row(BookingStatus="Completed") for _ in range(95000)] + \
       [Row(BookingStatus="Cancelled") for _ in range(5000)]

df = spark.createDataFrame(data)
# Aggregation that triggers shuffle
agg = df.groupBy("BookingStatus").count()
agg.explain(True)


== Parsed Logical Plan ==
'Aggregate ['BookingStatus], ['BookingStatus, count(1) AS count#3565L]
+- LogicalRDD [BookingStatus#3561], false

== Analyzed Logical Plan ==
BookingStatus: string, count: bigint
Aggregate [BookingStatus#3561], [BookingStatus#3561, count(1) AS count#3565L]
+- LogicalRDD [BookingStatus#3561], false

== Optimized Logical Plan ==
Aggregate [BookingStatus#3561], [BookingStatus#3561, count(1) AS count#3565L]
+- LogicalRDD [BookingStatus#3561], false

== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- HashAggregate(keys=[BookingStatus#3561], functions=[count(1)], output=[BookingStatus#3561, count#3565L])
   +- Exchange hashpartitioning(BookingStatus#3561, 200), ENSURE_REQUIREMENTS, [plan_id=8824]
      +- HashAggregate(keys=[BookingStatus#3561], functions=[partial_count(1)], output=[BookingStatus#3561, count#3569L])
         +- Scan ExistingRDD[BookingStatus#3561]



In [17]:
spark.conf.set("spark.sql.adaptive.enabled", False)
df.groupBy("BookingStatus").count().explain(True)


== Parsed Logical Plan ==
'Aggregate ['BookingStatus], ['BookingStatus, count(1) AS count#3572L]
+- LogicalRDD [BookingStatus#3561], false

== Analyzed Logical Plan ==
BookingStatus: string, count: bigint
Aggregate [BookingStatus#3561], [BookingStatus#3561, count(1) AS count#3572L]
+- LogicalRDD [BookingStatus#3561], false

== Optimized Logical Plan ==
Aggregate [BookingStatus#3561], [BookingStatus#3561, count(1) AS count#3572L]
+- LogicalRDD [BookingStatus#3561], false

== Physical Plan ==
*(2) HashAggregate(keys=[BookingStatus#3561], functions=[count(1)], output=[BookingStatus#3561, count#3572L])
+- Exchange hashpartitioning(BookingStatus#3561, 200), ENSURE_REQUIREMENTS, [plan_id=8839]
   +- *(1) HashAggregate(keys=[BookingStatus#3561], functions=[partial_count(1)], output=[BookingStatus#3561, count#3576L])
      +- *(1) Scan ExistingRDD[BookingStatus#3561]



In [18]:
# So the logical/initial plan looks the same — but the key difference is:

# ❌ Without AQE → fixed number of shuffle partitions.

# ✅ With AQE → partitions can shrink/merge at runtime → fewer, bigger tasks, faster execution
# Say you group 100k rows into just 2 categories (Completed, Cancelled):
# Without AQE → Spark still spawns 200 shuffle partitions. 198 tasks are empty/useless.
# With AQE → Spark will merge down to 2 partitions. Only 2 tasks run after shuffle.
# That’s why in your plan you see AdaptiveSparkPlan isFinalPlan=false. When the job finishes, Spark will mark it as isFinalPlan=true with coalesced partitions.

spark.conf.set("spark.sql.adaptive.enabled", True)
df.groupBy("BookingStatus").count().explain(True)


== Parsed Logical Plan ==
'Aggregate ['BookingStatus], ['BookingStatus, count(1) AS count#3579L]
+- LogicalRDD [BookingStatus#3561], false

== Analyzed Logical Plan ==
BookingStatus: string, count: bigint
Aggregate [BookingStatus#3561], [BookingStatus#3561, count(1) AS count#3579L]
+- LogicalRDD [BookingStatus#3561], false

== Optimized Logical Plan ==
Aggregate [BookingStatus#3561], [BookingStatus#3561, count(1) AS count#3579L]
+- LogicalRDD [BookingStatus#3561], false

== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- HashAggregate(keys=[BookingStatus#3561], functions=[count(1)], output=[BookingStatus#3561, count#3579L])
   +- Exchange hashpartitioning(BookingStatus#3561, 200), ENSURE_REQUIREMENTS, [plan_id=8860]
      +- HashAggregate(keys=[BookingStatus#3561], functions=[partial_count(1)], output=[BookingStatus#3561, count#3583L])
         +- Scan ExistingRDD[BookingStatus#3561]



In [19]:
from pyspark.sql import Row

# Large dataset
big = spark.createDataFrame([Row(booking_id=i, status_id=i % 5) for i in range(100000)])

# Small dataset
small = spark.createDataFrame([
    Row(status_id=0, status="Completed"),
    Row(status_id=1, status="Cancelled by Driver"),
    Row(status_id=2, status="Cancelled by Customer"),
    Row(status_id=3, status="No Driver Found"),
    Row(status_id=4, status="Incomplete"),
])


In [20]:
joined = big.join(small, "status_id")
joined.explain(True)


== Parsed Logical Plan ==
'Join UsingJoin(Inner, [status_id])
:- LogicalRDD [booking_id#3584L, status_id#3585L], false
+- LogicalRDD [status_id#3588L, status#3589], false

== Analyzed Logical Plan ==
status_id: bigint, booking_id: bigint, status: string
Project [status_id#3585L, booking_id#3584L, status#3589]
+- Join Inner, (status_id#3585L = status_id#3588L)
   :- LogicalRDD [booking_id#3584L, status_id#3585L], false
   +- LogicalRDD [status_id#3588L, status#3589], false

== Optimized Logical Plan ==
Project [status_id#3585L, booking_id#3584L, status#3589]
+- Join Inner, (status_id#3585L = status_id#3588L)
   :- Filter isnotnull(status_id#3585L)
   :  +- LogicalRDD [booking_id#3584L, status_id#3585L], false
   +- Filter isnotnull(status_id#3588L)
      +- LogicalRDD [status_id#3588L, status#3589], false

== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- Project [status_id#3585L, booking_id#3584L, status#3589]
   +- SortMergeJoin [status_id#3585L], [status_id#3588L], Inner
   

In [21]:
from pyspark.sql.functions import broadcast

joined_b = big.join(broadcast(small), "status_id")
joined_b.explain(True)


== Parsed Logical Plan ==
'Join UsingJoin(Inner, [status_id])
:- LogicalRDD [booking_id#3584L, status_id#3585L], false
+- ResolvedHint (strategy=broadcast)
   +- LogicalRDD [status_id#3588L, status#3589], false

== Analyzed Logical Plan ==
status_id: bigint, booking_id: bigint, status: string
Project [status_id#3585L, booking_id#3584L, status#3589]
+- Join Inner, (status_id#3585L = status_id#3588L)
   :- LogicalRDD [booking_id#3584L, status_id#3585L], false
   +- ResolvedHint (strategy=broadcast)
      +- LogicalRDD [status_id#3588L, status#3589], false

== Optimized Logical Plan ==
Project [status_id#3585L, booking_id#3584L, status#3589]
+- Join Inner, (status_id#3585L = status_id#3588L), rightHint=(strategy=broadcast)
   :- Filter isnotnull(status_id#3585L)
   :  +- LogicalRDD [booking_id#3584L, status_id#3585L], false
   +- Filter isnotnull(status_id#3588L)
      +- LogicalRDD [status_id#3588L, status#3589], false

== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- Project [