<h2> Imports & Configuration </h2>

In [0]:
from IPython.core.interactiveshell import InteractiveShell

# Echo the value of every bare expression in a cell, not only the last one.
InteractiveShell.ast_node_interactivity = "all"

In [0]:
from functools import reduce

from pyspark.sql.types import *
import pyspark.sql.functions as F
from pyspark.sql import SparkSession

In [0]:
# Local session using all available cores; reuses an existing session if one is active.
spark = (
    SparkSession.builder
    .master("local[*]")
    .getOrCreate()
)

In [0]:
# Turn adaptive query execution off -- presumably so it cannot coalesce or split
# shuffle partitions and mask the skew demonstrated below (TODO confirm intent).
spark.conf.set("spark.sql.adaptive.enabled", "false")
# A small, fixed number of shuffle partitions keeps the per-partition tables readable.
spark.conf.set("spark.sql.shuffle.partitions", "3")
spark.conf.get("spark.sql.shuffle.partitions")

<h2> Simulating Skewed Join </h2>

In [0]:
# One million consecutive integers in a single `value` column (the "uniform" join side).
df_uniform = spark.createDataFrame(list(range(1_000_000)), IntegerType())
df_uniform.show(n=5, truncate=False)

In [0]:
# Number of rows held by each physical partition of the uniform input.
(
    df_uniform
    .select(F.spark_partition_id().alias("partition"))
    .groupBy("partition")
    .count()
    .orderBy("partition")
    .show(n=15, truncate=False)
)

In [0]:
# Rows per join key on the skewed side: key 0 is the hot key, keys 1-3 are tiny.
SKEW_ROW_COUNTS = {0: 999_990, 1: 15, 2: 10, 3: 5}

# Each key becomes its own single-partition DataFrame. Unioning them keeps those
# partitions, so `df_skew` ends up with one partition per key and the imbalance
# shows up directly in the per-partition counts.
df_skew = reduce(
    lambda left, right: left.union(right),
    [
        spark.createDataFrame([key] * n_rows, IntegerType()).repartition(1)
        for key, n_rows in SKEW_ROW_COUNTS.items()
    ],
)
df_skew.show(5, False)

In [0]:
# Rows per physical partition of the skewed input: one partition per key.
(
    df_skew
    .select(F.spark_partition_id().alias("partition"))
    .groupBy("partition")
    .count()
    .orderBy("partition")
    .show()
)

In [0]:
# Unsalted baseline: inner equi-join on `value` alone.
df_joined_c1 = df_skew.join(df_uniform, on="value", how="inner")

In [0]:
# How many partitions the joined result is split into.
num_join_partitions = df_joined_c1.rdd.getNumPartitions()
num_join_partitions

In [0]:
# Print the physical plan chosen for the unsalted join.
df_joined_c1.explain(extended=False)

In [0]:
# Rows per partition of the unsalted join result. If the planner chose a shuffle
# join (see the plan above), every row of the hot key 0 is hashed to the same
# partition, so one partition should dwarf the others.
# `.show()` is used instead of `.display()`: `display` is not a DataFrame method in
# open-source PySpark, and `.show()` matches every other cell in this notebook.
(
    df_joined_c1
    .withColumn("partition", F.spark_partition_id())
    .groupBy("partition")
    .count()
    .orderBy("partition")
    .show()
)

<h2> Simulating Uniform Distribution Through Salting </h2>

In [0]:
# Use as many distinct salt values as there are shuffle partitions.
shuffle_partitions = spark.conf.get("spark.sql.shuffle.partitions")
SALT_NUMBER = int(shuffle_partitions)
SALT_NUMBER

In [0]:
# Tag every skewed row with a random salt in 0..SALT_NUMBER-1: rand() is in [0, 1),
# so the scaled value truncates to an int below SALT_NUMBER. An unseeded F.rand()
# yields different salts (and different per-partition tables below) on every run;
# the fixed seed makes the demonstration reproducible.
df_skew = df_skew.withColumn("salt", (F.rand(seed=42) * SALT_NUMBER).cast("int"))

In [0]:
# Peek at the skewed rows together with their salt.
df_skew.show(n=10, truncate=False)

In [0]:
# Replicate every row of the uniform side once per salt value 0..SALT_NUMBER-1, so
# whichever salt a skewed row drew, it still finds its partner in the join.
# The salts are exploded straight from an array literal rather than kept in a
# helper `salt_values` column, which would otherwise be carried into the join output.
# `df_uniform` is rebound in place, so re-running this cell on the already-salted
# frame would explode it again (multiplying its rows by SALT_NUMBER each time). The
# guard makes the cell safe to re-run.
if "salt" not in df_uniform.columns:
    df_uniform = df_uniform.withColumn(
        "salt", F.explode(F.array([F.lit(i) for i in range(SALT_NUMBER)]))
    )

In [0]:
# Peek at the salted uniform side: each value now appears once per salt.
df_uniform.show(n=10, truncate=False)

In [0]:
# Salted join: the composite key (value, salt) spreads each value over several join keys.
df_joined = df_skew.join(df_uniform, on=["value", "salt"], how="inner")

In [0]:
# For each key, how its joined rows are distributed over partitions after salting.
(
    df_joined
    .select("value", F.spark_partition_id().alias("partition"))
    .groupBy("value", "partition")
    .count()
    .orderBy("value", "partition")
    .show()
)

<h2> Salting In Aggregations </h2>

In [0]:
# Baseline: plain per-key counts. groupBy output order is not guaranteed, so sort by
# key to get a stable table that lines up row by row with the salted result below.
df_skew.groupBy("value").count().orderBy("value").show()

In [0]:
# Two-stage (salted) aggregation:
#   1. count per (value, salt), which spreads the hot key over up to SALT_NUMBER groups;
#   2. sum those partial counts per value.
# Summing per-(value, salt) counts gives the same per-value totals as the plain
# groupBy. The seed only makes the intermediate salt assignment reproducible
# between runs, and the sort makes the displayed table stable.
(
    df_skew
    .withColumn("salt", (F.rand(seed=42) * SALT_NUMBER).cast("int"))
    .groupBy("value", "salt")
    .agg(F.count("value").alias("count"))
    .groupBy("value")
    .agg(F.sum("count").alias("count"))
    .orderBy("value")
    .show()
)

In [0]:
# Shut down the Spark session created at the top of the notebook.
spark.stop()