# **Taming Data Skew in Spark with Salting — An End-to-End Guide**

If you’ve ever run a Spark job that was supposed to finish in minutes but ended up running for an eternity (or failed altogether), you might have met the sneaky culprit: **data skew**.

Data skew happens when a few keys in your dataset have far more records than the rest. In a distributed system like Spark, this means that **some tasks (and the executors running them) do far more work than others**, and the whole stage has to wait for the slowest one.

In this post, we’ll explore **salting** — a simple but powerful trick to balance the load and speed up your Spark jobs. We’ll walk through a **complete end-to-end example** so you can see exactly how it works.

# **What is Data Skew?**

Imagine you’re grouping transactions by customer_id to calculate the total amount spent per customer.

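If **one customer** (say CUST001) has millions of transactions while the others have just a few, the shuffle becomes lopsided. Spark assigns each key to a shuffle partition by hashing it, so **everything for CUST001 is routed to the same partition and processed by a single reduce task.**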

This leads to:

*   Slow stages
*   Straggler tasks
*   Out-of-memory errors
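
Why does a single customer end up on a single task? A hash partitioner sends each key to partition `hash(key) mod numPartitions`, so identical keys always land together. The plain-Python sketch below simulates that routing. It uses `zlib.crc32` as a stand-in for the Murmur3-based hash Spark SQL actually uses. The 8 partitions and 100,000 rows are illustrative assumptions.

```python
import random
import zlib
from collections import Counter


def partition_for(key: str, num_partitions: int) -> int:
    """Route a key to a partition as a hash partitioner does: hash(key) mod n.

    zlib.crc32 is only a stand-in for Spark's Murmur3 hash; the exact
    partition differs, but identical keys still always share a partition.
    """
    return zlib.crc32(key.encode("utf-8")) % num_partitions


def rows_per_partition(keys, num_partitions):
    """Number of rows each partition (i.e. each reduce task) would receive."""
    counts = Counter(partition_for(k, num_partitions) for k in keys)
    return [counts.get(p, 0) for p in range(num_partitions)]


# Illustrative skewed workload: ~70% of rows belong to CUST001
rng = random.Random(0)
keys = ["CUST001" if rng.random() < 0.7 else f"CUST{rng.randint(2, 50)}"
        for _ in range(100_000)]

loads = rows_per_partition(keys, num_partitions=8)
print("rows per partition:", loads)
print(f"busiest partition holds {max(loads) / sum(loads):.0%} of all rows")
```

An even spread would put about 12.5% of the rows on each of the 8 partitions. Here, one partition carries the hot key's ~70% plus any other keys that happen to hash alongside it.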

# **The Idea Behind Salting**

**Salting** means adding a **small, random or deterministic “salt” value** to your keys before expensive operations like **groupBy or join.**

Instead of all records for CUST001 ending up in one partition, we spread them across multiple reducers. After the initial operation, we can **remove the salt** and combine results.
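
The same kind of simulation shows what salting changes. Appending a salt in the range 0..4 (matching the `num_salts = 5` used later) turns one hot key into five distinct keys, each holding roughly a fifth of the rows. With only a handful of salted keys, two of them can still hash to the same partition, so the spread across partitions is not guaranteed to be perfect. As before, `zlib.crc32` stands in for Spark's hash, and the numbers are illustrative.

```python
import random
import zlib
from collections import Counter

NUM_PARTITIONS = 8
NUM_SALTS = 5          # mirrors num_salts in the PySpark example
HOT_ROWS = 70_000      # illustrative number of CUST001 rows


def partition_for(key: str, num_partitions: int = NUM_PARTITIONS) -> int:
    # crc32 stands in for Spark's Murmur3-based hash partitioning
    return zlib.crc32(key.encode("utf-8")) % num_partitions


rng = random.Random(0)

# Without salt: all rows share the key "CUST001", so they share one partition
unsalted_keys = ["CUST001"] * HOT_ROWS

# With salt: each row gets "CUST001_<s>", s drawn uniformly from 0..NUM_SALTS-1
salted_keys = [f"CUST001_{rng.randrange(NUM_SALTS)}" for _ in range(HOT_ROWS)]

rows_per_key = Counter(salted_keys)
partitions_unsalted = Counter(partition_for(k) for k in unsalted_keys)
partitions_salted = Counter(partition_for(k) for k in salted_keys)

print("rows per salted key:", dict(rows_per_key))
print("partitions used without salt:", len(partitions_unsalted))
print("partitions used with salt:   ", len(partitions_salted))
```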

# **End-to-End Salting Example in PySpark**

Let’s imagine we have:

*   A large dataset of transactions keyed by customer_id
*   A goal: find the total amount spent by each customer
*   A problem: one customer_id (“CUST001”) owns most of the records → skew.

### **Step 1 — Sample Skewed Data**

```python
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, lit
import random

# Local SparkSession using all available cores
spark = SparkSession.builder \
    .appName("SaltingExample") \
    .master("local[*]") \
    .getOrCreate()

# Build a skewed dataset: ~70% of rows belong to CUST001,
# the rest are spread over CUST2..CUST50
data = []
for _ in range(1000000):
    cust_id = "CUST001" if random.random() < 0.7 else f"CUST{random.randint(2, 50)}"
    amount = random.randint(10, 500)
    data.append((cust_id, amount))

df = spark.createDataFrame(data, ["customer_id", "amount"])
df.show(5)
```


+-----------+------+
|customer_id|amount|
+-----------+------+
|    CUST001|   224|
|    CUST001|   343|
|    CUST001|   179|
|    CUST001|   254|
|    CUST001|    53|
+-----------+------+
only showing top 5 rows



```python
df.count()
# 1000000

# Rows belonging to the hot key
df.filter(col("customer_id") == "CUST001").count()
# 700638
```

**Here, about 70% of transactions (700,638 of 1,000,000) belong to CUST001, simulating skew.**

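Checking one suspected key works here because we built the skew ourselves. On real data, you would usually rank every key by its share of rows; in PySpark the counts would come from `df.groupBy("customer_id").count()`. Below is a plain-Python sketch of the ranking step. `hot_keys` and its 10% threshold are hypothetical choices, not a Spark API.

```python
import random
from collections import Counter


def hot_keys(key_counts, threshold=0.1):
    """Return (key, rows, share) for keys holding more than `threshold` of all rows.

    `threshold` is a hypothetical cut-off; tune it to your data and cluster.
    Results are sorted from the largest key down.
    """
    total = sum(key_counts.values())
    ranked = sorted(key_counts.items(), key=lambda kv: kv[1], reverse=True)
    return [(key, rows, rows / total) for key, rows in ranked if rows / total > threshold]


# Simulated data shaped like the example above (~70% CUST001)
rng = random.Random(42)
keys = ["CUST001" if rng.random() < 0.7 else f"CUST{rng.randint(2, 50)}"
        for _ in range(100_000)]

report = hot_keys(Counter(keys))
for key, rows, share in report:
    print(f"{key}: {rows} rows ({share:.1%})")
```
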
### **Step 2 — Without Salting (Skew Problem)**

```python
df.groupBy("customer_id").sum("amount").show()
```

```
+-----------+-----------+
|customer_id|sum(amount)|
+-----------+-----------+
|     CUST40|    1548054|
|     CUST49|    1628669|
|     CUST30|    1535821|
|     CUST36|    1571761|
|      CUST2|    1572690|
|     CUST43|    1581272|
|     CUST46|    1552256|
|     CUST42|    1576348|
|    CUST001|  178635654|
|      CUST4|    1574003|
|     CUST31|    1534098|
|      CUST9|    1561539|
|     CUST14|    1544118|
|      CUST6|    1555607|
|     CUST16|    1555347|
|      CUST5|    1557201|
|     CUST34|    1550529|
|     CUST47|    1536868|
|     CUST19|    1559625|
|     CUST25|    1548677|
+-----------+-----------+
only showing top 20 rows
```

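If we look at this stage in the Spark UI (per-task duration and shuffle read size), we may notice:

*   All of CUST001’s data hashes to the same shuffle partition, so the task that owns that partition does far more work than its peers → unbalanced tasks.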

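One nuance is worth knowing when you read the UI for this particular query. For an algebraic aggregate like `sum`, Spark performs a partial aggregation inside each input partition before the shuffle. The reducer for CUST001 therefore receives one partial sum per map task rather than every raw row. The imbalance is far worse for joins and for aggregations whose intermediate state grows with the number of rows (such as `collect_list`), which is where salting pays off most. The plain-Python sketch below counts the rows that would cross the shuffle with and without that map-side combine. The 10 partitions of 10,000 rows are illustrative.

```python
import random
from collections import Counter


def shuffled_rows(partitions, map_side_combine):
    """Count, per key, how many rows cross the shuffle.

    partitions: list of input partitions, each a list of (key, amount) rows.
    With map_side_combine=True every partition first sums its own rows per key,
    so each key emits at most one partial row per partition.
    """
    sent = Counter()
    for rows in partitions:
        if map_side_combine:
            sent.update({key for key, _ in rows})   # one partial sum per distinct key
        else:
            sent.update(key for key, _ in rows)     # every raw row is shuffled
    return sent


rng = random.Random(7)
partitions = [
    [("CUST001" if rng.random() < 0.7 else f"CUST{rng.randint(2, 50)}", rng.randint(10, 500))
     for _ in range(10_000)]
    for _ in range(10)
]

raw = shuffled_rows(partitions, map_side_combine=False)
combined = shuffled_rows(partitions, map_side_combine=True)
print("CUST001 rows shuffled without combine:", raw["CUST001"])
print("CUST001 rows shuffled with combine:   ", combined["CUST001"])
```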
### **Step 3 — Add Salt Key**

We add a small random number (the salt) to each row, so that rows sharing a hot key are split across several distinct keys.

```python
from pyspark.sql.functions import rand, floor, concat
# col and lit were imported in Step 1

# Number of salts (adjust based on skew severity)
num_salts = 5

# Add a salt column: a random integer from 0 to num_salts - 1
df_salted = df.withColumn("salt", floor(rand() * num_salts))

# Create a composite key: customer_id + "_" + salt, e.g. "CUST001_3"
df_salted = df_salted.withColumn(
    "customer_salt",
    concat(col("customer_id"), lit("_"), col("salt").cast("string"))
)

df_salted.show(5)
```

```
+-----------+------+----+-------------+
|customer_id|amount|salt|customer_salt|
+-----------+------+----+-------------+
|    CUST001|   224|   1|    CUST001_1|
|    CUST001|   343|   0|    CUST001_0|
|    CUST001|   179|   2|    CUST001_2|
|    CUST001|   254|   2|    CUST001_2|
|    CUST001|    53|   1|    CUST001_1|
+-----------+------+----+-------------+
only showing top 5 rows
```

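The code comment says to adjust `num_salts` based on skew severity. Here is one way to make that concrete, offered as a hypothetical rule of thumb rather than an official guideline. Pick enough salts that each salted slice of the hot key is about as large as a typical key. Then cap the value so that the extra groups (and, in joins, the replicated rows) stay manageable. `choose_num_salts` and its cap of 64 are illustrative names and values.

```python
import math


def choose_num_salts(hot_key_rows: int, typical_key_rows: int, max_salts: int = 64) -> int:
    """Hypothetical heuristic for num_salts.

    Split the hottest key into slices roughly the size of a typical key,
    never fewer than 1 and never more than max_salts.
    """
    if typical_key_rows <= 0:
        raise ValueError("typical_key_rows must be positive")
    return max(1, min(max_salts, math.ceil(hot_key_rows / typical_key_rows)))


# Using the counts from the example: 700,638 CUST001 rows vs. roughly
# 6,100 rows per other customer (299,362 rows spread over 49 customers)
print(choose_num_salts(700_638, 6_100))   # → 64 (hits the cap)
print(choose_num_salts(30_000, 6_000))    # → 5
```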

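The salt does not have to be random. Without a fixed seed, `rand()` can assign different salts on different runs of the job, and its values also depend on how the data is partitioned. The intermediate salted groups are therefore not reproducible, even though the final totals are. If you want the same row to land in the same salted group every time, you can derive the salt from a stable column instead. The dataset in this post has no such column, so the `txn-<n>` ids below are hypothetical, and `zlib.crc32` stands in for whatever hash you would use in Spark.

```python
import zlib
from collections import Counter

NUM_SALTS = 5


def deterministic_salt(row_id: str, num_salts: int = NUM_SALTS) -> int:
    """Derive the salt from a stable row attribute: the same row always gets the same salt."""
    return zlib.crc32(row_id.encode("utf-8")) % num_salts


# Hypothetical transaction ids for 10,000 CUST001 rows
txn_ids = [f"txn-{i}" for i in range(10_000)]
salts = [deterministic_salt(t) for t in txn_ids]

print("rows per salt:", dict(sorted(Counter(salts).items())))
```

In PySpark, a SQL expression such as `expr("pmod(hash(txn_id), 5)")` would be one way to express the same idea, assuming your data has a stable id column.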

### **Step 4 — Aggregate on Salted Key**

```python
# Phase 1: partial sums per salted key (CUST001 is now split into num_salts groups)
agg_salted = df_salted.groupBy("customer_salt").sum("amount")

agg_salted.show(5)
```

```
+-------------+-----------+
|customer_salt|sum(amount)|
+-------------+-----------+
|     CUST13_4|     305319|
|     CUST45_3|     321322|
|     CUST50_1|     313721|
|     CUST21_3|     318822|
|      CUST8_2|     321668|
+-------------+-----------+
only showing top 5 rows
```



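```python
# Phase 2: strip the salt and combine the partial sums per customer
from pyspark.sql.functions import split

# split(...)[0] recovers customer_id; this assumes customer_id itself never contains "_"
agg_final = agg_salted.withColumn("customer_id", split(col("customer_salt"), "_")[0]) \
                      .groupBy("customer_id") \
                      .sum("sum(amount)")

agg_final.show()
```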

```
+-----------+----------------+
|customer_id|sum(sum(amount))|
+-----------+----------------+
|     CUST40|         1548054|
|     CUST49|         1628669|
|     CUST36|         1571761|
|     CUST30|         1535821|
|      CUST2|         1572690|
|     CUST43|         1581272|
|     CUST46|         1552256|
|     CUST42|         1576348|
|    CUST001|       178635654|
|      CUST4|         1574003|
|     CUST31|         1534098|
|      CUST9|         1561539|
|     CUST14|         1544118|
|      CUST6|         1555607|
|     CUST16|         1555347|
|      CUST5|         1557201|
|     CUST34|         1550529|
|     CUST47|         1536868|
|     CUST19|         1559625|
|     CUST25|         1548677|
+-----------+----------------+
only showing top 20 rows
```

The totals match the unsalted result from Step 2 (only the row order differs). Salting changed how the work was distributed, not the answer.

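The two phases can be sketched in plain Python to check the logic without a cluster. This version keeps customer_id and the salt as separate fields of a tuple key rather than concatenating them into one string and splitting it apart again. The PySpark analogue would be `df_salted.groupBy("customer_id", "salt").sum("amount")` followed by a second `groupBy("customer_id")`. The data is simulated, and the function names are illustrative.

```python
import random
from collections import defaultdict


def two_phase_sum(rows, num_salts, seed=0):
    """Salted aggregation: sum per (customer_id, salt), then per customer_id."""
    rng = random.Random(seed)

    # Phase 1: partial sums keyed by (customer_id, salt)
    partial = defaultdict(int)
    for cust, amount in rows:
        partial[(cust, rng.randrange(num_salts))] += amount

    # Phase 2: drop the salt and add up the partial sums per customer_id
    final = defaultdict(int)
    for (cust, _salt), subtotal in partial.items():
        final[cust] += subtotal
    return dict(final), len(partial)


def direct_sum(rows):
    """Unsalted reference: a single sum per customer_id."""
    totals = defaultdict(int)
    for cust, amount in rows:
        totals[cust] += amount
    return dict(totals)


rng = random.Random(1)
rows = [("CUST001" if rng.random() < 0.7 else f"CUST{rng.randint(2, 50)}", rng.randint(10, 500))
        for _ in range(50_000)]

salted_totals, num_partial_groups = two_phase_sum(rows, num_salts=5)
print("totals match:", salted_totals == direct_sum(rows))
print("phase-1 groups:", num_partial_groups)
```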


## **When to Use Salting**
✅ Use it when:

*   You’ve identified skewed keys in joins or aggregations.
*   Spark UI shows uneven shuffle read/write sizes.
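
For joins, the salt has to appear on both sides. A common pattern has three steps:

1. Give each row of the skewed (large) side a random salt.
2. Replicate each row of the other side once per salt value.
3. Join on the pair (key, salt).

In PySpark, the replication is often written with `explode(array(*[lit(i) for i in range(num_salts)]))`. The trade-off is that the replicated side grows by a factor of `num_salts`, so the pattern suits joins where that side is small. Below is a plain-Python sketch of the pattern, checked against an ordinary join. The `segment` attribute and the data are hypothetical.

```python
import random
from collections import defaultdict

NUM_SALTS = 5


def salted_join(large, small, num_salts=NUM_SALTS, seed=0):
    """Inner-join skewed `large` with `small` on customer_id via salting.

    large: (customer_id, amount) rows, heavily skewed towards a few keys.
    small: (customer_id, segment) rows; `segment` is a hypothetical attribute.
    """
    rng = random.Random(seed)

    # Skewed side: one random salt per row
    salted_large = [(cust, rng.randrange(num_salts), amount) for cust, amount in large]

    # Other side: one copy of every row per salt value, indexed by (customer_id, salt)
    replicated_small = defaultdict(list)
    for cust, segment in small:
        for salt in range(num_salts):
            replicated_small[(cust, salt)].append(segment)

    # Join on (customer_id, salt), then drop the salt from the output
    return [(cust, amount, segment)
            for cust, salt, amount in salted_large
            for segment in replicated_small.get((cust, salt), [])]


def plain_join(large, small):
    """Unsalted reference inner join on customer_id."""
    lookup = defaultdict(list)
    for cust, segment in small:
        lookup[cust].append(segment)
    return [(cust, amount, segment)
            for cust, amount in large
            for segment in lookup.get(cust, [])]


rng = random.Random(3)
large = [("CUST001" if rng.random() < 0.7 else f"CUST{rng.randint(2, 50)}", rng.randint(10, 500))
         for _ in range(20_000)]
# CUST001 plus CUST2..CUST49; CUST50 deliberately has no match
small = [("CUST001", "gold")] + [(f"CUST{i}", "standard") for i in range(2, 50)]

joined = salted_join(large, small)
print(len(joined), "joined rows; matches plain join:",
      sorted(joined) == sorted(plain_join(large, small)))
```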

🚫 Avoid it when:

*   Your data is evenly distributed (salting just adds an extra shuffle and aggregation step).
*   Your keys have many distinct values and none of them dominates (the load is already balanced, so salting isn't needed).

## **Final Thoughts**

Salting is a **low-cost, high-impact optimization** when dealing with skewed datasets in Spark. By distributing hot keys across multiple partitions, you prevent bottlenecks and keep your cluster running efficiently.

If you regularly work with big data pipelines, **this is a trick worth keeping in your toolbox.**