<h2> Imports & Configuration </h2>

In [2]:
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = 'all'

In [3]:
from pyspark.sql.types import *
import pyspark.sql.functions as F
from pyspark.sql import SparkSession

In [4]:
spark = SparkSession.builder.master("local[*]").getOrCreate()

23/07/09 21:34:24 WARN Utils: Your hostname, Afaques-MacBook-Pro.local resolves to a loopback address: 127.0.0.1; using 10.0.0.4 instead (on interface en0)
23/07/09 21:34:24 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
23/07/09 21:34:25 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [5]:
spark.conf.set("spark.sql.shuffle.partitions", "3")
spark.conf.get("spark.sql.shuffle.partitions")
spark.conf.set("spark.sql.adaptive.enabled", "false")

'3'

<h2> Simulating Skewed Join </h2>

In [6]:
df_uniform = spark.createDataFrame([i for i in range(1000000)], IntegerType())
df_uniform.show(5, False)

[Stage 0:>                                                          (0 + 1) / 1]

+-----+
|value|
+-----+
|0    |
|1    |
|2    |
|3    |
|4    |
+-----+
only showing top 5 rows



                                                                                

In [7]:
(
    df_uniform
    .withColumn("partition", F.spark_partition_id())
    .groupBy("partition")
    .count()
    .orderBy("partition")
    .show(15, False)
)

[Stage 1:====>                                                    (1 + 11) / 12]

+---------+-----+
|partition|count|
+---------+-----+
|0        |82944|
|1        |82944|
|2        |83968|
|3        |82944|
|4        |83968|
|5        |82944|
|6        |82944|
|7        |83968|
|8        |82944|
|9        |83968|
|10       |82944|
|11       |83520|
+---------+-----+



                                                                                

In [8]:
df0 = spark.createDataFrame([0] * 999990, IntegerType()).repartition(1)
df1 = spark.createDataFrame([1] * 15, IntegerType()).repartition(1)
df2 = spark.createDataFrame([2] * 10, IntegerType()).repartition(1)
df3 = spark.createDataFrame([3] * 5, IntegerType()).repartition(1)
df_skew = df0.union(df1).union(df2).union(df3)
df_skew.show(5, False)

[Stage 5:>                (0 + 12) / 12][Stage 6:>                 (0 + 0) / 12][Stage 6:>                                                        (0 + 12) / 12]

+-----+
|value|
+-----+
|0    |
|0    |
|0    |
|0    |
|0    |
+-----+
only showing top 5 rows



                                                                                

In [9]:
(
    df_skew
    .withColumn("partition", F.spark_partition_id())
    .groupBy("partition")
    .count()
    .orderBy("partition")
    .show()
)

+---------+------+
|partition| count|
+---------+------+
|        0|999990|
|        1|    15|
|        2|    10|
|        3|     5|
+---------+------+



In [10]:
df_joined_c1 = df_skew.join(df_uniform, "value", 'inner')

In [11]:
(
    df_joined_c1
    .withColumn("partition", F.spark_partition_id())
    .groupBy("partition")
    .count()
    .show(5, False)
)

[Stage 15:>                                                       (0 + 12) / 12]

+---------+-------+
|partition|count  |
+---------+-------+
|0        |1000005|
|1        |15     |
+---------+-------+





<h2> Simulating Uniform Distribution Through Salting </h2>

In [12]:
SALT_NUMBER = int(spark.conf.get("spark.sql.shuffle.partitions"))
SALT_NUMBER

3

In [13]:
df_skew = df_skew.withColumn("salt", (F.rand() * SALT_NUMBER).cast("int"))

In [14]:
df_skew.show(10, truncate=False)

+-----+----+
|value|salt|
+-----+----+
|0    |2   |
|0    |0   |
|0    |0   |
|0    |1   |
|0    |2   |
|0    |2   |
|0    |0   |
|0    |2   |
|0    |2   |
|0    |1   |
+-----+----+
only showing top 10 rows



In [15]:
df_uniform = (
    df_uniform
    .withColumn("salt_values", F.array([F.lit(i) for i in range(SALT_NUMBER)]))
    .withColumn("salt", F.explode(F.col("salt_values")))
)

In [16]:
df_uniform.show(10, truncate=False)

+-----+-----------+----+
|value|salt_values|salt|
+-----+-----------+----+
|0    |[0, 1, 2]  |0   |
|0    |[0, 1, 2]  |1   |
|0    |[0, 1, 2]  |2   |
|1    |[0, 1, 2]  |0   |
|1    |[0, 1, 2]  |1   |
|1    |[0, 1, 2]  |2   |
|2    |[0, 1, 2]  |0   |
|2    |[0, 1, 2]  |1   |
|2    |[0, 1, 2]  |2   |
|3    |[0, 1, 2]  |0   |
+-----+-----------+----+
only showing top 10 rows



In [17]:
df_joined = df_skew.join(df_uniform, ["value", "salt"], 'inner')

In [18]:
(
    df_joined
    .withColumn("partition", F.spark_partition_id())
    .groupBy("value", "partition")
    .count()
    .orderBy("value", "partition")
    .show()
)

[Stage 42:>                                                         (0 + 3) / 3]

+-----+---------+------+
|value|partition| count|
+-----+---------+------+
|    0|        0|332774|
|    0|        1|333601|
|    0|        2|333615|
|    1|        0|     6|
|    1|        1|     9|
|    2|        0|     2|
|    2|        1|     2|
|    2|        2|     6|
|    3|        0|     3|
|    3|        1|     2|
+-----+---------+------+



                                                                                

# Salting In Aggregations

In [19]:
df_skew.groupBy("value").count().show()

+-----+------+
|value| count|
+-----+------+
|    0|999990|
|    2|    10|
|    3|     5|
|    1|    15|
+-----+------+



In [None]:
(
    df_skew
    .withColumn("salt", (F.rand() * SALT_NUMBER).cast("int"))
    .groupBy("value", "salt")
    .agg(F.count("value").alias("count"))
    .groupBy("value")
    .agg(F.sum("count").alias("count"))
    .show()
)

In [None]:
spark.stop()