# Salting
- Salting is a technique in which you add a column (the *salt*) to a skewed dataset so that rows sharing the same key are split much more evenly across partitions.
- It's one of the ways to handle data skew.
- For instance, say you have 3 partitions and one of them holds 1M records while the others hold just 10-20 records each. With salting, you add a column that splits those records evenly (~333K records per partition).
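The effect can be sketched without Spark. The snippet below is a plain-Python simulation, not Spark's implementation. `toy_partition` is a hypothetical stand-in for Spark's hash partitioner (Spark actually uses a Murmur3 hash), and the data is a scaled-down version of the example above. Partitioning by the key alone sends every row of the hot key to one partition. Partitioning by `(key, salt)` spreads those rows out.

```python
import random
from collections import Counter

NUM_PARTITIONS = 3
SALT_NUMBER = 3

def toy_partition(*cols, n=NUM_PARTITIONS):
    """Hypothetical stand-in for a hash partitioner (Spark really uses Murmur3)."""
    h = 0
    for c in cols:
        h = h * 31 + c
    return h % n

# Skewed keys: key 0 dominates, like the 1M-vs-10-20 example (scaled down)
rows = [0] * 100_000 + [1] * 15 + [2] * 10
random.seed(42)

# Without salt: the partition is decided by the key alone
unsalted = Counter(toy_partition(k) for k in rows)

# With salt: each row gets one random salt in [0, SALT_NUMBER) and the
# partition is decided by (key, salt)
salted = Counter(toy_partition(k, random.randrange(SALT_NUMBER)) for k in rows)

print("unsalted:", dict(sorted(unsalted.items())))
print("salted:  ", dict(sorted(salted.items())))
```

With 3 salt values, the hot key's 100,000 rows split into roughly 33K per partition instead of all landing in partition 0.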

## The Salting Process

1. First, choose a salt number (let's say 3).

2. For the LARGER dataset (typically the orders/transactions table):
   - Add a RANDOM salt value between 0 and 2 to each row
   - Important: we don't replicate this data
   - Each row gets just ONE random salt value

3. For the SMALLER dataset (typically the dimension/lookup table):
   - REPLICATE each row 3 times (one for each salt: 0, 1, 2)
   - This is where we use `explode` or `crossJoin`

4. The distribution happens when we join:
   - Records only match when BOTH the join key AND the salt values match
   - This naturally distributes the rows of a hot key across partitions, because the shuffle now hashes on (key, salt), so the 3 salt values of one key can land in different partitions
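The four steps above can be simulated in plain Python to check two things: salting does not change the join result, and the hot key's rows end up in several `(key, salt)` groups instead of one. This is a minimal sketch on assumed toy data (`large`, `small` and the payload strings are hypothetical). A Python dictionary plays the role of Spark's shuffle-and-join.

```python
import random
from collections import Counter, defaultdict

SALT_NUMBER = 3
random.seed(0)

# Large, skewed side (e.g. orders): (key, payload)
large = [(0, f"order{i}") for i in range(9_000)] + [(1, "order_a"), (2, "order_b")]
# Small side (e.g. a lookup table): (key, attribute)
small = [(0, "zero"), (1, "one"), (2, "two")]

# Step 2: each large-side row gets exactly ONE random salt (no replication)
large_salted = [(k, random.randrange(SALT_NUMBER), v) for k, v in large]

# Step 3: each small-side row is replicated once per salt value (0..N-1)
small_salted = [(k, s, a) for k, a in small for s in range(SALT_NUMBER)]

# Step 4: join on BOTH key and salt (a simple hash join keyed by (key, salt))
lookup = defaultdict(list)
for k, s, a in small_salted:
    lookup[(k, s)].append(a)
salted_join = [(k, v, a) for k, s, v in large_salted for a in lookup[(k, s)]]

# Reference: an ordinary join on the key alone
plain_lookup = defaultdict(list)
for k, a in small:
    plain_lookup[k].append(a)
plain_join = [(k, v, a) for k, v in large for a in plain_lookup[k]]

# Salting must not change the join result, only how the work is divided
assert Counter(salted_join) == Counter(plain_join)

# The hot key's rows are now split across SALT_NUMBER (key, salt) groups
hot_groups = Counter(s for k, s, _ in large_salted if k == 0)
print(hot_groups)
```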

In [22]:
from pyspark.sql import functions as F
from pyspark.sql import SparkSession
from pyspark.sql.types import *

spark = SparkSession.builder.appName("salting").getOrCreate()

sc = spark.sparkContext
sc.setLogLevel("ERROR")

In [3]:
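# Adaptive Query Execution (AQE) is enabled by default since Spark 3.2.
# Disable it so it doesn't coalesce shuffle partitions or split skewed join
# partitions by itself, which would hide the effect demonstrated here.
spark.conf.set("spark.sql.adaptive.enabled", False)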

In [20]:
# Use only 3 shuffle partitions so the effect of skew and salting is easy to see
spark.conf.set("spark.sql.shuffle.partitions", "3")
spark.conf.get("spark.sql.shuffle.partitions")

In [23]:
# creating a uniform dataset
df_uniform = spark.createDataFrame([i for i in range(1000000)], IntegerType())
df_uniform.show(5, False)

+-----+
|value|
+-----+
|0    |
|1    |
|2    |
|3    |
|4    |
+-----+
only showing top 5 rows



In [24]:
# F.spark_partition_id() returns the id of the partition for each row
(
    df_uniform
    .withColumn("partition", F.spark_partition_id())
    .groupBy("partition")
    .count()
    .orderBy("partition")
    .show(15, False)
)

+---------+------+
|partition|count |
+---------+------+
|0        |124928|
|1        |124928|
|2        |124928|
|3        |124928|
|4        |124928|
|5        |124928|
|6        |124928|
|7        |125504|
+---------+------+



In [25]:
# let's create a skewed dataset
df0 = spark.createDataFrame([0] * 999990, IntegerType()).repartition(1)
df1 = spark.createDataFrame([1] * 15, IntegerType()).repartition(1)
df2 = spark.createDataFrame([2] * 10, IntegerType()).repartition(1)
df3 = spark.createDataFrame([3] * 5, IntegerType()).repartition(1)
df_skew = df0.union(df1).union(df2).union(df3)
df_skew.show(5, False)

+-----+
|value|
+-----+
|0    |
|0    |
|0    |
|0    |
|0    |
+-----+
only showing top 5 rows



In [26]:
(
    df_skew
    .withColumn("partition", F.spark_partition_id())
    .groupBy("partition")
    .count()
    .orderBy("partition")
    .show()
)

+---------+------+
|partition| count|
+---------+------+
|        0|999990|
|        1|    15|
|        2|    10|
|        3|     5|
+---------+------+



In [27]:
df_joined_c1 = df_skew.join(df_uniform, "value", 'inner')
df_joined_c1.show(3)
df_joined_c1.explain(mode='formatted')

+-----+
|value|
+-----+
|    0|
|    0|
|    0|
+-----+
only showing top 3 rows

== Physical Plan ==
* Project (21)
+- * SortMergeJoin Inner (20)
   :- * Sort (15)
   :  +- Exchange (14)
   :     +- Union (13)
   :        :- Exchange (3)
   :        :  +- * Filter (2)
   :        :     +- * Scan ExistingRDD (1)
   :        :- Exchange (6)
   :        :  +- * Filter (5)
   :        :     +- * Scan ExistingRDD (4)
   :        :- Exchange (9)
   :        :  +- * Filter (8)
   :        :     +- * Scan ExistingRDD (7)
   :        +- Exchange (12)
   :           +- * Filter (11)
   :              +- * Scan ExistingRDD (10)
   +- * Sort (19)
      +- Exchange (18)
         +- * Filter (17)
            +- * Scan ExistingRDD (16)


(1) Scan ExistingRDD [codegen id : 1]
Output [1]: [value#241]
Arguments: [value#241], MapPartitionsRDD[179] at applySchemaToPythonRDD at NativeMethodAccessorImpl.java:0, ExistingRDD, UnknownPartitioning(0)

(2) Filter [codegen id : 1]
Input [1]: [value#241]
Condition

In [28]:
(
    df_joined_c1
    .withColumn("partition", F.spark_partition_id())
    .groupBy("partition")
    .count()
    .show(5, False)
)

+---------+-------+
|partition|count  |
+---------+-------+
|0        |1000005|
|1        |15     |
+---------+-------+



### Joining with Salting

In [29]:
SALT_NUMBER = int(spark.conf.get("spark.sql.shuffle.partitions"))
SALT_NUMBER

3

In [30]:
df_skew = df_skew.withColumn("salt", (F.rand()*SALT_NUMBER).cast("int"))
df_skew.show(10)

+-----+----+
|value|salt|
+-----+----+
|    0|   1|
|    0|   2|
|    0|   1|
|    0|   0|
|    0|   2|
|    0|   2|
|    0|   0|
|    0|   0|
|    0|   0|
|    0|   2|
+-----+----+
only showing top 10 rows
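`(F.rand() * SALT_NUMBER).cast("int")` works because `F.rand()` draws uniformly from [0.0, 1.0). The product therefore falls in [0, SALT_NUMBER), and casting it to an integer truncates it to 0, 1, ..., SALT_NUMBER - 1 with roughly equal frequency. Below is a plain-Python sketch of the same arithmetic, with `random.random()` standing in for `F.rand()` as an illustrative assumption.

```python
import random
from collections import Counter

SALT_NUMBER = 3
random.seed(7)

# Same arithmetic as (F.rand() * SALT_NUMBER).cast("int"):
# random.random() is uniform on [0.0, 1.0), so the product lies in [0, SALT_NUMBER)
# and int() truncates it to one of 0, 1, ..., SALT_NUMBER - 1
salts = [int(random.random() * SALT_NUMBER) for _ in range(99_999)]
counts = Counter(salts)
print(sorted(counts.items()))
```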



In [33]:
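# Give every row of the smaller dataset one copy per salt value (0 .. SALT_NUMBER-1).
# Caution: this reassigns df_uniform, so re-running the cell explodes the rows again.
df_uniform = (
    df_uniform
    .withColumn("salt_values_array", F.array([F.lit(i) for i in range(SALT_NUMBER)]))
    .withColumn("salt", F.explode(F.col("salt_values_array")))
)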

In [34]:
df_uniform.show(10)

+-----+-----------------+----+
|value|salt_values_array|salt|
+-----+-----------------+----+
|    0|        [0, 1, 2]|   0|
|    0|        [0, 1, 2]|   1|
|    0|        [0, 1, 2]|   2|
|    0|        [0, 1, 2]|   0|
|    0|        [0, 1, 2]|   1|
|    0|        [0, 1, 2]|   2|
|    0|        [0, 1, 2]|   0|
|    0|        [0, 1, 2]|   1|
|    0|        [0, 1, 2]|   2|
|    1|        [0, 1, 2]|   0|
+-----+-----------------+----+
only showing top 10 rows
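Each value should appear once per salt, i.e. three times. The output above instead shows nine rows per value, with the salts cycling 0, 1, 2 three times. That pattern is consistent with the cell having been executed twice: because it reassigns `df_uniform`, a second run explodes the already-exploded rows again. The plain-Python analogue of the array-plus-explode step below reproduces this. `explode_salts` is a hypothetical helper, not a Spark API.

```python
SALT_NUMBER = 3

def explode_salts(rows, n=SALT_NUMBER):
    """Plain-Python analogue of adding F.array(0..n-1) and exploding it:
    every input row is emitted once per salt value, overwriting any old salt."""
    return [{**row, "salt": s} for row in rows for s in range(n)]

small = [{"value": 0}, {"value": 1}]

once = explode_salts(small)
twice = explode_salts(once)  # what happens if the cell is executed again

print(len(once), len(twice))
print([r["salt"] for r in twice if r["value"] == 0])
```

This also accounts for the salted-join counts below summing to three times the input rows per value (e.g. 2,999,970 = 3 × 999,990 for value `0`). Writing the exploded result to a new name, such as a hypothetical `df_uniform_salted`, keeps the cell safe to re-run.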



In [35]:
df_joined_salted = df_skew.join(df_uniform, on = ["value", "salt"], how = "inner")

In [37]:
(
    df_joined_salted
    .withColumn("partition", F.spark_partition_id())
    .groupBy("value","partition")
    .count()
    .orderBy("value","partition")
    .show()
)

+-----+---------+-------+
|value|partition|  count|
+-----+---------+-------+
|    0|        0|1000293|
|    0|        1|1002156|
|    0|        2| 997521|
|    1|        0|     24|
|    1|        1|     21|
|    2|        0|     12|
|    2|        1|      9|
|    2|        2|      9|
|    3|        1|      9|
|    3|        2|      6|
+-----+---------+-------+



**You can see that the rows for value `0` are now spread almost evenly across the three partitions instead of all landing in a single one.**

## Salting in Aggregation
1. Choose salt number (say N=3)

2. Assign a random salt (0, 1 or 2) to each row.
   For example, when aggregating customer purchases:
   ```
   (customer1, $100) -> (customer1, $100, salt=0)
   (customer1, $200) -> (customer1, $200, salt=2)
   (customer1, $300) -> (customer1, $300, salt=1)
   ```

3. First shuffle and groupBy(customer, salt):
   - Data gets distributed using hash(customer, salt)
   - Instead of all customer1 records going to one partition:
     - (customer1, salt=0) records go to partition X
     - (customer1, salt=1) records go to partition Y
     - (customer1, salt=2) records go to partition Z

4. Partial aggregation in each partition:
   ```
   Partition X: (customer1, salt=0) -> $100
   Partition Y: (customer1, salt=1) -> $300
   Partition Z: (customer1, salt=2) -> $200
   ```

5. Final groupBy(customer) to combine the results:
   - Combines all partial aggregations for each customer
   - Final result: customer1 -> $600

The key advantage is that the expensive computation (the initial aggregation) is distributed across partitions, so no single task has to process every row of a hot key. The final aggregation only has to combine at most N partial results per key.
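The five steps can be traced in plain Python with the customer example above. This is a minimal sketch: the list `purchases` (including the extra `customer2` row) is hypothetical data, and two dictionaries stand in for the two `groupBy` stages. Whatever salts are drawn, the final totals match an unsalted sum.

```python
import random
from collections import defaultdict

SALT_NUMBER = 3
random.seed(1)

purchases = [("customer1", 100), ("customer1", 200), ("customer1", 300), ("customer2", 50)]

# Step 2: tag every row with a random salt in [0, SALT_NUMBER)
salted = [(cust, amt, random.randrange(SALT_NUMBER)) for cust, amt in purchases]

# Steps 3-4: first groupBy(customer, salt) produces partial sums
partial = defaultdict(int)
for cust, amt, salt in salted:
    partial[(cust, salt)] += amt

# Step 5: second groupBy(customer) combines at most SALT_NUMBER partials per customer
final = defaultdict(int)
for (cust, _salt), subtotal in partial.items():
    final[cust] += subtotal

print(dict(final))
```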

In [39]:
(
    df_skew
    .groupBy("value", "salt")                  # stage 1: partial count per (value, salt)
    .agg(F.count("value").alias("count"))
    .groupBy("value")                          # stage 2: combine the partial counts per value
    .agg(F.sum("count").alias("count"))
    .show()
)

+-----+------+
|value| count|
+-----+------+
|    0|999990|
|    2|    10|
|    3|     5|
|    1|    15|
+-----+------+
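The second stage uses `F.sum` on the partial counts rather than another `F.count`, because stage 2 must merge stage-1 results with a function that combines them correctly. Sums and counts merge by addition, but a mean does not: averaging per-salt averages gives a wrong answer whenever the salt groups differ in size. The sketch below carries a sum and a count through both stages and divides only at the end, using toy data with fixed, hypothetical salts. In PySpark, the first stage would aggregate `F.sum` and `F.count`, and the second would sum both before dividing.

```python
from collections import defaultdict

# (key, salt, amount) rows; salts are fixed here so the example is deterministic
rows = [("a", 0, 10), ("a", 0, 20), ("a", 1, 30), ("a", 2, 40), ("b", 1, 5)]

# Stage 1: per (key, salt), keep a partial SUM and a partial COUNT
partial = defaultdict(lambda: [0, 0])
for key, salt, amount in rows:
    partial[(key, salt)][0] += amount
    partial[(key, salt)][1] += 1

# Stage 2: per key, add up the partial sums and counts, then divide once
totals = defaultdict(lambda: [0, 0])
for (key, _salt), (part_sum, part_count) in partial.items():
    totals[key][0] += part_sum
    totals[key][1] += part_count
mean = {key: total / count for key, (total, count) in totals.items()}

# Averaging the stage-1 averages instead weights each salt group equally (wrong)
naive = defaultdict(list)
for (key, _salt), (part_sum, part_count) in partial.items():
    naive[key].append(part_sum / part_count)
naive_mean = {key: sum(v) / len(v) for key, v in naive.items()}

print(mean, naive_mean)
```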

