In [1]:
from pyspark.sql import SparkSession
import pyspark.sql.functions as F
from pyspark.sql.types import *

# Local Spark session that uses every available core on this machine.
spark = (
    SparkSession.builder
    .master("local[*]")
    .appName("salting")
    .getOrCreate()
)

In [2]:
# Pin every shuffle to exactly 3 output partitions so the distribution tables below stay small.
spark.conf.set("spark.sql.shuffle.partitions", 3)
# Turn Adaptive Query Execution off.
# NOTE(review): presumably done so AQE cannot coalesce or split shuffle partitions
# at runtime and hide the skew this notebook demonstrates.
spark.conf.set("spark.sql.adaptive.enabled", False)

# Simulating Skewed Join

In [4]:
# Build a deliberately skewed dataframe: key 0 dominates, keys 1-3 are rare.
# Each key is collapsed into its own single partition and the pieces are unioned
# in order, so partition i of the result holds only the rows for key i.
skew_spec = [(0, 999990), (1, 15), (2, 10), (3, 5)]
df_skew = None
for key, n_rows in skew_spec:
    piece = spark.createDataFrame([key] * n_rows, IntegerType()).repartition(1)
    df_skew = piece if df_skew is None else df_skew.union(piece)
df_skew.show(5, truncate=False)

+-----+
|value|
+-----+
|0    |
|0    |
|0    |
|0    |
|0    |
+-----+
only showing top 5 rows



In [5]:
# Tag each row with the id of the partition it currently lives in, then count
# rows per partition (ordered by id so the table is easy to read).
rows_per_partition = (
    df_skew
    .select(F.spark_partition_id().alias("partition"))
    .groupBy("partition")
    .count()
    .orderBy("partition")
)
rows_per_partition.show()

# Partition 0 holds almost every record (all the 0s) while the other partitions
# hold only a handful each -- that imbalance is the data skew.

+---------+------+
|partition| count|
+---------+------+
|        0|999990|
|        1|    15|
|        2|    10|
|        3|     5|
+---------+------+



In [6]:
# The other side of the join: one row per key 0..999999, so every key
# (including the hot key 0) appears exactly once.
uniform_keys = list(range(1000000))
df_uniform = spark.createDataFrame(uniform_keys, IntegerType())
df_uniform.show(5, truncate=False)

+-----+
|value|
+-----+
|0    |
|1    |
|2    |
|3    |
|4    |
+-----+
only showing top 5 rows



In [9]:
# Baseline: join the skewed frame to the uniform one on the raw key, no salting yet.
df_joined_c1 = df_skew.join(df_uniform, on="value", how="inner")

In [10]:
# Lets check how data is distributed across partitions after join
(
    df_joined_c1
    .withColumn("partition", F.spark_partition_id())
    .groupBy("partition")
    .count()
    .show(5, False)
)

# we can see that 

+---------+-------+
|partition|count  |
+---------+-------+
|0        |1000005|
|1        |15     |
+---------+-------+



# Simulating Uniform Distribution Through Salting

In [13]:
# Number of distinct salt values. Each key of the skewed side is split across this
# many sub-keys, and the other side of the join is replicated this many times.
SALT_NUMBER = 3

In [14]:
# Attach a random salt in [0, SALT_NUMBER) to every row of the skewed side
# (F.rand() draws from [0.0, 1.0), so the int cast yields 0 .. SALT_NUMBER-1).
#
# F.rand() without a seed picks a fresh seed on every evaluation. Because df_skew
# is not cached, each action would then see different salts: the rows shown here
# would differ from the ones joined later, and the partition counts would change
# on every run. A fixed seed makes the salting reproducible, provided the input's
# partitioning and row order are stable (Spark documents rand as non-deterministic
# in the general case).
SALT_SEED = 42
df_skew = df_skew.withColumn("salt", (F.rand(seed=SALT_SEED) * SALT_NUMBER).cast("int"))
df_skew.show(10, truncate=False)

+-----+----+
|value|salt|
+-----+----+
|0    |2   |
|0    |1   |
|0    |0   |
|0    |2   |
|0    |0   |
|0    |0   |
|0    |1   |
|0    |2   |
|0    |2   |
|0    |0   |
+-----+----+
only showing top 10 rows



In [15]:
# Explode the other joining dataframe with salt values. Every row is replicated
# once per possible salt (0 .. SALT_NUMBER-1), so each salted row of df_skew still
# finds exactly one matching partner.
#
# Idempotence: the unsalted frame is remembered the first time this cell runs, and
# the explode is always rebuilt from it. Without this, re-running the cell would
# explode the already-exploded df_uniform again. That multiplies its row count by
# SALT_NUMBER on every run and produces duplicate matches in the salted join.
# Rebuilding from the base also picks up a changed SALT_NUMBER.
# (If df_uniform already carries "salt" but df_uniform_unsalted is undefined,
# re-run the cell that creates df_uniform first.)
if "salt" not in df_uniform.columns:
    df_uniform_unsalted = df_uniform
df_uniform = (
    df_uniform_unsalted
    .withColumn("salt_values", F.array([F.lit(i) for i in range(SALT_NUMBER)]))
    .withColumn("salt", F.explode(F.col("salt_values")))
)
df_uniform.show(10, truncate=False)

+-----+-----------+----+
|value|salt_values|salt|
+-----+-----------+----+
|0    |[0, 1, 2]  |0   |
|0    |[0, 1, 2]  |1   |
|0    |[0, 1, 2]  |2   |
|1    |[0, 1, 2]  |0   |
|1    |[0, 1, 2]  |1   |
|1    |[0, 1, 2]  |2   |
|2    |[0, 1, 2]  |0   |
|2    |[0, 1, 2]  |1   |
|2    |[0, 1, 2]  |2   |
|3    |[0, 1, 2]  |0   |
+-----+-----------+----+
only showing top 10 rows



In [17]:
# Join on the composite key (value, salt). A skewed row only meets the copy of its
# uniform partner that carries the same salt. Assuming df_uniform holds each value
# once per salt, every original match is therefore produced exactly once.
df_joined = df_skew.join(df_uniform, on=["value", "salt"], how="inner")

In [18]:
# Let's see the distribution across partitions after joining salted data,
# broken down per key so the spread of each value is visible.
(
    df_joined
    .withColumn("partition", F.spark_partition_id())
    .groupBy("value", "partition")
    .count()
    .orderBy("value", "partition")
    .show()
)

# We see that the hot value (0) is now distributed across SALT_NUMBER (= 3)
# partitions, roughly a third of its rows in each, instead of all landing in one.

+-----+---------+------+
|value|partition| count|
+-----+---------+------+
|    0|        0|333447|
|    0|        1|333538|
|    0|        2|333005|
|    1|        0|     3|
|    1|        1|    12|
|    2|        0|     4|
|    2|        1|     4|
|    2|        2|     2|
|    3|        0|     1|
|    3|        1|     3|
|    3|        2|     1|
+-----+---------+------+



# Salting In Aggregations

In [19]:
# Baseline: a plain aggregation on the skewed key, for comparison with the salted
# two-stage aggregation that follows.
counts_by_value = df_skew.groupBy("value").count()
counts_by_value.show()

+-----+------+
|value| count|
+-----+------+
|    0|999990|
|    2|    10|
|    3|     5|
|    1|    15|
+-----+------+



In [23]:
# Salted aggregation in two stages:
#   1. (re)attach a random salt -- withColumn overwrites any existing "salt"
#      column -- and aggregate on (value, salt), so a hot value's rows are split
#      across up to SALT_NUMBER groups;
#   2. drop the salt by re-aggregating on value alone, summing the partial counts.
# The final counts match the unsalted aggregation.
partial_counts = (
    df_skew
    .withColumn("salt", (F.rand() * SALT_NUMBER).cast("int"))
    .groupBy("value", "salt")
    .agg(F.count("value").alias("count"))
)
final_counts = partial_counts.groupBy("value").agg(F.sum("count").alias("count"))
final_counts.show()

+-----+------+
|value| count|
+-----+------+
|    0|999990|
|    2|    10|
|    3|     5|
|    1|    15|
+-----+------+



In [24]:
# Stop the Spark session (and its underlying SparkContext) now that the demo is finished.
spark.stop()