In [None]:
if spark:
    spark.stop()

In [21]:
from pyspark.sql import SparkSession

# Creating a Spark session
spark = SparkSession.builder.master("local[*]").appName("data_skew").getOrCreate()

In [22]:
# Set Spark parameters - We have to turn off AQL to demonstrate Salting
spark.conf.set("spark.sql.adaptive.enabled", False)
spark.conf.set("spark.sql.shuffle.partitions", 5)
# Check the parameters
print(spark.conf.get("spark.sql.adaptive.enabled"))
print(spark.conf.get("spark.sql.shuffle.partitions"))

false
5


In [23]:
# Lets create the example dataset of fact and dimension we would use for demonstration
# Python program to generate random Fact table data
# [1, ,"ORD1001", "D102", 56]
import random
def generate_fact_data(counter=100):
    fact_records = []
    dim_keys = ["D100", "D101", "D102", "D103", "D104"]
    order_ids = ["ORD" + str(i) for i in range(1001, 1010)]
    qty_range = [i for i in range(10, 120)]
    for i in range(counter):
        _record = [i, random.choice(order_ids), random.choice(dim_keys), random.choice(qty_range)]
        fact_records.append(_record)
    return fact_records

# We will generate 200 records with random data in Fact to create skewness
fact_records = generate_fact_data(200)
dim_records = [
    ["D100", "Product A"],
    ["D101", "Product B"],
    ["D102", "Product C"],
    ["D103", "Product D"],
    ["D104", "Product E"]
]
_fact_cols = ["id", "order_id", "prod_id", "qty"]
_dim_cols = ["prod_id", "prod_name"]

In [24]:
# Generate Fact Data Frame
fact_df = spark.createDataFrame(data = fact_records, schema=_fact_cols)
fact_df.printSchema()
fact_df.show(10, truncate = False)

root
 |-- id: long (nullable = true)
 |-- order_id: string (nullable = true)
 |-- prod_id: string (nullable = true)
 |-- qty: long (nullable = true)

+---+--------+-------+---+
|id |order_id|prod_id|qty|
+---+--------+-------+---+
|0  |ORD1002 |D104   |23 |
|1  |ORD1006 |D104   |95 |
|2  |ORD1005 |D103   |14 |
|3  |ORD1008 |D103   |31 |
|4  |ORD1007 |D103   |16 |
|5  |ORD1008 |D102   |106|
|6  |ORD1007 |D101   |14 |
|7  |ORD1009 |D101   |24 |
|8  |ORD1003 |D104   |80 |
|9  |ORD1004 |D102   |17 |
+---+--------+-------+---+
only showing top 10 rows



In [35]:
joined_df.groupBy("prod_id").agg(count("id")).show()

+-------+---------+
|prod_id|count(id)|
+-------+---------+
|   D103|       40|
|   D104|       44|
|   D100|       41|
|   D101|       34|
|   D102|       41|
+-------+---------+



In [26]:
# Generate Prod Dim Data Frame
dim_df = spark.createDataFrame(data = dim_records, schema=_dim_cols)
dim_df.printSchema()
dim_df.show(10, False)

root
 |-- prod_id: string (nullable = true)
 |-- prod_name: string (nullable = true)

+-------+---------+
|prod_id|prod_name|
+-------+---------+
|D100   |Product A|
|D101   |Product B|
|D102   |Product C|
|D103   |Product D|
|D104   |Product E|
+-------+---------+



In [33]:
# Check the parameters
print(spark.conf.get("spark.sql.adaptive.enabled"))
print(spark.conf.get("spark.sql.shuffle.partitions"))

false
5


In [None]:
# Set Spark parameters - We have to turn off AQL to demonstrate Salting
spark.conf.set("spark.sql.adaptive.enabled", False)
spark.conf.set("spark.sql.shuffle.partitions", 5)
# Check the parameters
print(spark.conf.get("spark.sql.adaptive.enabled"))
print(spark.conf.get("spark.sql.shuffle.partitions"))

In [27]:
# Lets join the fact and dim without salting
joined_df = fact_df.join(dim_df, on="prod_id", how="leftouter")
joined_df.show(10, False)

+-------+---+--------+---+---------+
|prod_id|id |order_id|qty|prod_name|
+-------+---+--------+---+---------+
|D103   |2  |ORD1005 |14 |Product D|
|D103   |3  |ORD1008 |31 |Product D|
|D103   |4  |ORD1007 |16 |Product D|
|D103   |29 |ORD1006 |21 |Product D|
|D103   |33 |ORD1001 |27 |Product D|
|D103   |52 |ORD1008 |78 |Product D|
|D103   |53 |ORD1009 |94 |Product D|
|D103   |77 |ORD1006 |63 |Product D|
|D103   |80 |ORD1003 |88 |Product D|
|D103   |103|ORD1008 |50 |Product D|
+-------+---+--------+---+---------+
only showing top 10 rows



In [36]:
# Check the partition details to understand distribution
from pyspark.sql.functions import spark_partition_id, count
partition_df = joined_df.withColumn("partition_num", spark_partition_id()).groupBy("partition_num").agg(count("id"))
partition_df.show()

+-------------+---------+
|partition_num|count(id)|
+-------------+---------+
|            4|      116|
|            2|       84|
+-------------+---------+



In [37]:
# Let prepare the salt
import random
from pyspark.sql.functions import udf
# UDF to return a random number every time
def rand(): return random.randint(0, 4) #Since we are distributing the data in 5 partitions
rand_udf = udf(rand)
# Salt Data Frame to add to dimension
salt_df = spark.range(0, 5)
salt_df.show()


+---+
| id|
+---+
|  0|
|  1|
|  2|
|  3|
|  4|
+---+



In [38]:
# Salted Fact
from pyspark.sql.functions import lit, expr, concat
salted_fact_df = fact_df.withColumn("salted_prod_id", concat("prod_id",lit("_"), lit(rand_udf())))
salted_fact_df.show(10, False)

# adding salt id sampled from a uniform distribution to a bunch of product_ids creates a uniformly distributed product_ids

+---+--------+-------+---+--------------+
|id |order_id|prod_id|qty|salted_prod_id|
+---+--------+-------+---+--------------+
|0  |ORD1002 |D104   |23 |D104_2        |
|1  |ORD1006 |D104   |95 |D104_1        |
|2  |ORD1005 |D103   |14 |D103_0        |
|3  |ORD1008 |D103   |31 |D103_0        |
|4  |ORD1007 |D103   |16 |D103_3        |
|5  |ORD1008 |D102   |106|D102_3        |
|6  |ORD1007 |D101   |14 |D101_3        |
|7  |ORD1009 |D101   |24 |D101_2        |
|8  |ORD1003 |D104   |80 |D104_4        |
|9  |ORD1004 |D102   |17 |D102_1        |
+---+--------+-------+---+--------------+
only showing top 10 rows



In [39]:
# Salted DIM
salted_dim_df = dim_df.join(salt_df, how="cross").withColumn("salted_prod_id", concat("prod_id", lit("_"), "id")).drop("id")
salted_dim_df.show()


+-------+---------+--------------+
|prod_id|prod_name|salted_prod_id|
+-------+---------+--------------+
|   D100|Product A|        D100_0|
|   D100|Product A|        D100_1|
|   D100|Product A|        D100_2|
|   D100|Product A|        D100_3|
|   D100|Product A|        D100_4|
|   D101|Product B|        D101_0|
|   D101|Product B|        D101_1|
|   D101|Product B|        D101_2|
|   D101|Product B|        D101_3|
|   D101|Product B|        D101_4|
|   D102|Product C|        D102_0|
|   D102|Product C|        D102_1|
|   D102|Product C|        D102_2|
|   D102|Product C|        D102_3|
|   D102|Product C|        D102_4|
|   D103|Product D|        D103_0|
|   D103|Product D|        D103_1|
|   D103|Product D|        D103_2|
|   D103|Product D|        D103_3|
|   D103|Product D|        D103_4|
+-------+---------+--------------+
only showing top 20 rows



In [40]:
# Lets make the salted join now
salted_joined_df = salted_fact_df.join(salted_dim_df, on="salted_prod_id", how="leftouter")
salted_joined_df.show(10, False)

+--------------+---+--------+-------+---+-------+---------+
|salted_prod_id|id |order_id|prod_id|qty|prod_id|prod_name|
+--------------+---+--------+-------+---+-------+---------+
|D100_0        |75 |ORD1002 |D100   |64 |D100   |Product A|
|D100_0        |79 |ORD1008 |D100   |82 |D100   |Product A|
|D100_0        |85 |ORD1006 |D100   |44 |D100   |Product A|
|D100_0        |100|ORD1002 |D100   |39 |D100   |Product A|
|D100_0        |125|ORD1006 |D100   |99 |D100   |Product A|
|D100_0        |129|ORD1005 |D100   |88 |D100   |Product A|
|D100_1        |126|ORD1007 |D100   |68 |D100   |Product A|
|D100_1        |182|ORD1004 |D100   |90 |D100   |Product A|
|D104_1        |8  |ORD1003 |D104   |80 |D104   |Product E|
|D104_1        |82 |ORD1003 |D104   |84 |D104   |Product E|
+--------------+---+--------+-------+---+-------+---------+
only showing top 10 rows



In [43]:
salted_joined_df.groupBy("salted_prod_id").agg(count("id")).show()

+--------------+---------+
|salted_prod_id|count(id)|
+--------------+---------+
|        D100_0|        7|
|        D100_1|       10|
|        D104_1|       13|
|        D101_1|        7|
|        D103_1|       11|
|        D103_2|        7|
|        D103_3|       11|
|        D104_0|       10|
|        D104_2|        6|
|        D104_3|        7|
|        D100_2|        8|
|        D102_1|        7|
|        D104_4|        8|
|        D101_0|        4|
|        D101_2|        6|
|        D101_4|        7|
|        D102_0|        6|
|        D102_2|        9|
|        D102_3|       11|
|        D102_4|        8|
+--------------+---------+
only showing top 20 rows



In [44]:
# Check the partition details to understand distribution
from pyspark.sql.functions import spark_partition_id, count
partition_df = salted_joined_df.withColumn("partition_num", spark_partition_id()).groupBy("partition_num") \
    .agg(count(lit(1)).alias("count")).orderBy("partition_num")
partition_df.show()

+-------------+-----+
|partition_num|count|
+-------------+-----+
|            0|   24|
|            1|   55|
|            2|   26|
|            3|   66|
|            4|   29|
+-------------+-----+

