<img src="./uva_seal.png">  

## PySpark Hotspot Demo

### University of Virginia
### DS 5110: Big Data Systems
### Last Updated: January 19, 2026

---  


### BACKGROUND

In Spark, a data hotspot occurs when one or a few keys hold far more records than the others, so the task that processes those keys does most of the work while the other tasks sit idle.

It is a data imbalance problem.

This notebook demonstrates two small examples of how it arises and how it can be remediated.
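As a rough illustration of what "imbalance" means, the sketch below measures how much of a key column belongs to its most frequent key. It is plain Python rather than Spark, and the `keys` list is a hypothetical sample chosen for illustration.

```python
from collections import Counter

# Hypothetical key column: one hot key plus a few ordinary keys
keys = ["user1"] * 10 + ["user2", "user3", "user4"]

counts = Counter(keys)
hot_key, hot_count = counts.most_common(1)[0]   # most frequent key and its row count
hot_share = hot_count / len(keys)               # fraction of all rows held by that key
mean_count = len(keys) / len(counts)            # average rows per distinct key

print(f"Hottest key: {hot_key} ({hot_count} rows, {hot_share:.0%} of all rows)")
print(f"Rows per key on average: {mean_count:.2f}")
print(f"Skew ratio (max / mean): {hot_count / mean_count:.2f}")
```

A skew ratio close to 1 suggests evenly spread keys; the larger it gets, the more one key dominates.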

#### EXAMPLE 1: SALTING A SMALLER DATASET

In [None]:
# Import modules
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, count

# Create Spark session
spark = SparkSession.builder.appName("HotspotDemo").getOrCreate()

# Create skewed data
data = [
    ("user1", "click"), ("user1", "click"), ("user1", "click"), ("user1", "click"), ("user1", "click"),
    ("user1", "click"), ("user1", "click"), ("user1", "click"), ("user1", "click"), ("user1", "click"), # Hot key
    ("user2", "click"), ("user3", "click"), ("user4", "click")
]

df = spark.createDataFrame(data, ["user_id", "event"])

# Aggregate by user_id; "user1" is the hot (skewed) key
counts = df.groupBy("user_id").agg(count("*").alias("total_events"))

counts.show()

#### THE PROBLEM

Most records belong to the "user1" key.

During `groupBy`, rows are shuffled by key, so the partition that receives "user1" becomes a hot partition.

Spark assigns a single task to that partition, so one task handles all "user1" data and slows the job.
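To make the hot partition concrete, here is a minimal plain-Python simulation of key-based partitioning. It is a sketch only: `partition_for`, `NUM_PARTITIONS`, and the use of `zlib.crc32` are assumptions made for illustration and stand in for Spark's own partitioner, which uses a different hash function.

```python
import zlib
from collections import Counter

NUM_PARTITIONS = 4  # assumed value for illustration


def partition_for(key: str, num_partitions: int = NUM_PARTITIONS) -> int:
    """Map a key to a partition with a deterministic hash (stand-in for Spark's partitioner)."""
    return zlib.crc32(key.encode("utf-8")) % num_partitions


keys = ["user1"] * 10 + ["user2", "user3", "user4"]

# Every row with the same key hashes to the same partition
rows_per_partition = Counter(partition_for(k) for k in keys)
for p in range(NUM_PARTITIONS):
    print(f"partition {p}: {rows_per_partition.get(p, 0)} rows")

hot_partition = partition_for("user1")
print(f"'user1' always lands in partition {hot_partition}")
```

Whatever hash is used, the ten "user1" rows end up together, so the partition that holds them carries most of the data.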

#### A SOLUTION: SALTING

Have Spark add a small random “salt” (0–9) to each record’s key, spreading the hot key’s data across multiple reducers.

Then aggregate back to the original key after processing.
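The two stages can be sketched in plain Python before turning to Spark. This is an illustration under assumptions: the key list is hypothetical, `NUM_SALTS` mirrors the 0 to 9 range mentioned above, and a seeded `random.Random` replaces Spark's `rand()` so the run is repeatable.

```python
import random
from collections import Counter

NUM_SALTS = 10           # salt values 0 to 9, as in the text
rng = random.Random(0)   # fixed seed so the sketch is repeatable

# Hypothetical skewed key column
keys = ["user1"] * 1000 + ["user2", "user3", "user4"]

# Stage 1: count per (key, salt) bucket; the hot key is split over up to 10 buckets
partial = Counter((k, rng.randrange(NUM_SALTS)) for k in keys)

# Stage 2: drop the salt and sum the partial counts per original key
final = Counter()
for (key, _salt), n in partial.items():
    final[key] += n

largest_bucket = max(partial.values())
print(f"Largest salted bucket: {largest_bucket} rows (hot key has {final['user1']})")
```

The final totals match a direct count by key, while no single bucket in stage 1 has to hold all of the hot key's rows.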

In [None]:
from pyspark.sql.functions import lit, concat, rand, expr

# Add a random salt to break up the skewed key
salted = df.withColumn("salted_key", concat(col("user_id"), lit("_"), (rand() * 10).cast("int")))
print('Salted data:')
salted.show()

# Aggregate by salted key
salted_counts = salted.groupBy("salted_key").agg(count("*").alias("partial_count"))
print('Salted counts by key:')
salted_counts.show()

# Aggregate back by original user_id (taking the text before '_' works here because these keys contain no underscore)
final_counts = salted_counts.groupBy(expr("split(salted_key, '_')[0]").alias("user_id")) \
                            .agg(expr("sum(partial_count)").alias("total_events"))

print('Counts by original key:')
final_counts.show()


---

#### EXAMPLE 2: LARGER EXAMPLE OF SALTING WITH RUNTIME COMPARE

First, we create a large, skewed dataset.

In [None]:
import time

# Reuse the Spark session from Example 1 if it is still active (getOrCreate returns the existing session)
spark = SparkSession.builder.appName("HotspotTimingDemo").getOrCreate()

# 99% of rows belong to one key ("hot_user")
data = [("hot_user", i) for i in range(990000)] + \
       [(f"user_{i}", i) for i in range(1, 10000)]

df = spark.createDataFrame(data, ["user_id", "event_id"])

print(f"Total rows: {df.count():,}")

In [None]:
# inspect some rows
df.head(5)

**Next, we aggregate without fixing skew and compute runtime**
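A single wall-clock measurement can be noisy, since startup and warm-up costs tend to affect whichever query runs first. One possible way to steady the numbers is to repeat the action and take the median. The helper below is a hypothetical sketch (`time_action` is not part of Spark or of this notebook), shown with a stand-in workload.

```python
import statistics
import time


def time_action(action, repeats=3):
    """Run `action` several times and return the median wall-clock seconds."""
    durations = []
    for _ in range(repeats):
        start = time.perf_counter()
        action()
        durations.append(time.perf_counter() - start)
    return statistics.median(durations)


# Stand-in workload; with Spark, the action could be a lambda that calls
# .count() on the DataFrame being measured
elapsed = time_action(lambda: sum(range(100_000)))
print(f"Median runtime: {elapsed:.4f} seconds")
```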

In [None]:
start_time = time.time()
counts_no_fix = df.groupBy("user_id").agg(count("*").alias("total_events"))
counts_no_fix.count()  # force computation
time_no_fix = time.time() - start_time
print(f"Runtime without fix: {time_no_fix:.2f} seconds")

**Salt the data to break up skew**
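Keys in this dataset contain underscores (for example `hot_user` and `user_42`), so recovering the original key by splitting a salted string on `_` and keeping the first piece would merge distinct keys. The plain-Python sketch below shows the pitfall and one safer alternative; `naive_unsalt` and `safe_unsalt` are hypothetical helper names used only for illustration.

```python
def naive_unsalt(salted_key: str) -> str:
    """Take the text before the first underscore (breaks for keys containing '_')."""
    return salted_key.split("_")[0]


def safe_unsalt(salted_key: str) -> str:
    """Strip only the trailing '_<salt>' by splitting once from the right."""
    return salted_key.rsplit("_", 1)[0]


for salted_key in ["hot_user_3", "user_42_7", "user_43_7"]:
    print(salted_key, "->", naive_unsalt(salted_key), "|", safe_unsalt(salted_key))
```

In Spark, a comparable effect could be had by keeping the salt in its own column and grouping on both columns, which avoids parsing the key at all.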

In [None]:
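# Keep the salt in its own column: user_id values such as "hot_user" and "user_1"
# contain "_", so split(salted_key, '_')[0] would collapse them into "hot" and "user"
salted = df.withColumn("salt", (rand() * 10).cast("int"))

# Stage 1: partial counts per (user_id, salt) pair
salted_counts = salted.groupBy("user_id", "salt").agg(count("*").alias("partial_count"))

# Stage 2: sum the partial counts back to the original user_id
final_counts = salted_counts.groupBy("user_id") \
                            .agg(expr("sum(partial_count)").alias("total_events"))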

**Compute and compare runtimes**

In [None]:
start_time = time.time()
final_counts.count()  # force computation
time_with_fix = time.time() - start_time
print(f"Runtime with salting fix: {time_with_fix:.2f} seconds")

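# Compare runtimes (a negative change means the salted version was slower)
change = ((time_no_fix - time_with_fix) / time_no_fix) * 100
direction = "faster" if change >= 0 else "slower"
print(f"Salted version: {abs(change):.1f}% {direction} than the unsalted version")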

spark.stop()