In [3]:
from pyspark.sql import SparkSession
spark = SparkSession.builder.master("local[*]").appName("abc").getOrCreate()

24/02/09 12:59:48 WARN Utils: Your hostname, Navneets-MacBook-Air.local resolves to a loopback address: 127.0.0.1; using 172.20.10.7 instead (on interface en0)
24/02/09 12:59:48 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/02/09 12:59:48 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
24/02/09 12:59:48 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.


In [4]:
# Set Spark parameters - We have to turn off AQL to demonstrate Salting
spark.conf.set("spark.sql.adaptive.enabled", False)
spark.conf.set("spark.sql.shuffle.partitions", 5)
# Check the parameters
print(spark.conf.get("spark.sql.adaptive.enabled"))
print(spark.conf.get("spark.sql.shuffle.partitions"))

false
5


In [5]:
# Lets create the example dataset of fact and dimension we would use for demonstration
# Python program to generate random Fact table data
# [1, ,"ORD1001", "D102", 56]
import random
def generate_fact_data(counter=100):
    fact_records = []
    dim_keys = ["D100", "D101", "D102", "D103", "D104"]
    order_ids = ["ORD" + str(i) for i in range(1001, 1010)]
    qty_range = [i for i in range(10, 120)]
    for i in range(counter):
        _record = [i, random.choice(order_ids), random.choice(dim_keys), random.choice(qty_range)]
        fact_records.append(_record)
    return fact_records

# We will generate 200 records with random data in Fact to create skewness
fact_records = generate_fact_data(200)
dim_records = [
    ["D100", "Product A"],
    ["D101", "Product B"],
    ["D102", "Product C"],
    ["D103", "Product D"],
    ["D104", "Product E"]
]
_fact_cols = ["id", "order_id", "prod_id", "qty"]
_dim_cols = ["prod_id", "prod_name"]

In [6]:
# Generate Fact Data Frame
fact_df = spark.createDataFrame(data = fact_records, schema=_fact_cols)
fact_df.printSchema()
fact_df.show(10, truncate = False)

root
 |-- id: long (nullable = true)
 |-- order_id: string (nullable = true)
 |-- prod_id: string (nullable = true)
 |-- qty: long (nullable = true)



                                                                                

+---+--------+-------+---+
|id |order_id|prod_id|qty|
+---+--------+-------+---+
|0  |ORD1006 |D100   |44 |
|1  |ORD1001 |D103   |40 |
|2  |ORD1007 |D100   |37 |
|3  |ORD1005 |D104   |96 |
|4  |ORD1009 |D103   |104|
|5  |ORD1003 |D101   |115|
|6  |ORD1003 |D100   |12 |
|7  |ORD1006 |D101   |55 |
|8  |ORD1001 |D101   |88 |
|9  |ORD1007 |D104   |60 |
+---+--------+-------+---+
only showing top 10 rows



In [18]:
fact_df.groupBy('prod_id').agg(count('id').alias('id_count')).show()

+-------+--------+
|prod_id|id_count|
+-------+--------+
|   D103|      32|
|   D104|      43|
|   D100|      40|
|   D101|      39|
|   D102|      46|
+-------+--------+



In [7]:
# Generate Prod Dim Data Frame
dim_df = spark.createDataFrame(data = dim_records, schema=_dim_cols)
dim_df.printSchema()
dim_df.show(10, False)

root
 |-- prod_id: string (nullable = true)
 |-- prod_name: string (nullable = true)

+-------+---------+
|prod_id|prod_name|
+-------+---------+
|D100   |Product A|
|D101   |Product B|
|D102   |Product C|
|D103   |Product D|
|D104   |Product E|
+-------+---------+



In [8]:
# Check the parameters
print(spark.conf.get("spark.sql.adaptive.enabled"))
print(spark.conf.get("spark.sql.shuffle.partitions"))

false
5


In [9]:
# Set Spark parameters - We have to turn off AQL to demonstrate Salting
spark.conf.set("spark.sql.adaptive.enabled", False)
spark.conf.set("spark.sql.shuffle.partitions", 5)
# Check the parameters
print(spark.conf.get("spark.sql.adaptive.enabled"))
print(spark.conf.get("spark.sql.shuffle.partitions"))

false
5


In [10]:
# Lets join the fact and dim without salting
joined_df = fact_df.join(dim_df, on="prod_id", how="leftouter")
joined_df.show(10, False)

+-------+---+--------+---+---------+
|prod_id|id |order_id|qty|prod_name|
+-------+---+--------+---+---------+
|D103   |1  |ORD1001 |40 |Product D|
|D103   |4  |ORD1009 |104|Product D|
|D103   |54 |ORD1008 |10 |Product D|
|D103   |76 |ORD1001 |86 |Product D|
|D103   |79 |ORD1008 |110|Product D|
|D103   |81 |ORD1006 |11 |Product D|
|D103   |84 |ORD1004 |106|Product D|
|D103   |104|ORD1007 |63 |Product D|
|D103   |110|ORD1008 |84 |Product D|
|D103   |135|ORD1001 |51 |Product D|
+-------+---+--------+---+---------+
only showing top 10 rows



In [11]:
# Check the partition details to understand distribution
from pyspark.sql.functions import spark_partition_id, count
partition_df = joined_df.withColumn("partition_num", spark_partition_id()).groupBy("partition_num").agg(count("id"))
partition_df.show()

+-------------+---------+
|partition_num|count(id)|
+-------------+---------+
|            4|      125|
|            2|       75|
+-------------+---------+



In [12]:
# Let prepare the salt
import random
from pyspark.sql.functions import udf
# UDF to return a random number every time
def rand(): return random.randint(0, 4) #Since we are distributing the data in 5 partitions
rand_udf = udf(rand)
# Salt Data Frame to add to dimension
salt_df = spark.range(0, 5)
salt_df.show()


+---+
| id|
+---+
|  0|
|  1|
|  2|
|  3|
|  4|
+---+



In [13]:
# Salted Fact
from pyspark.sql.functions import lit, expr, concat
salted_fact_df = fact_df.withColumn("salted_prod_id", concat("prod_id",lit("_"), lit(rand_udf())))
salted_fact_df.show(10, False)

# adding salt id sampled from a uniform distribution to a bunch of product_ids creates a uniformly distributed product_ids

+---+--------+-------+---+--------------+
|id |order_id|prod_id|qty|salted_prod_id|
+---+--------+-------+---+--------------+
|0  |ORD1006 |D100   |44 |D100_2        |
|1  |ORD1001 |D103   |40 |D103_2        |
|2  |ORD1007 |D100   |37 |D100_3        |
|3  |ORD1005 |D104   |96 |D104_3        |
|4  |ORD1009 |D103   |104|D103_3        |
|5  |ORD1003 |D101   |115|D101_3        |
|6  |ORD1003 |D100   |12 |D100_1        |
|7  |ORD1006 |D101   |55 |D101_0        |
|8  |ORD1001 |D101   |88 |D101_0        |
|9  |ORD1007 |D104   |60 |D104_4        |
+---+--------+-------+---+--------------+
only showing top 10 rows



In [20]:
# Salted DIM
salted_dim_df = dim_df.join(salt_df, how="cross").withColumn("salted_prod_id", concat("prod_id", lit("_"), "id")).drop("id")
salted_dim_df.show()


+-------+---------+--------------+
|prod_id|prod_name|salted_prod_id|
+-------+---------+--------------+
|   D100|Product A|        D100_0|
|   D100|Product A|        D100_1|
|   D100|Product A|        D100_2|
|   D100|Product A|        D100_3|
|   D100|Product A|        D100_4|
|   D101|Product B|        D101_0|
|   D101|Product B|        D101_1|
|   D101|Product B|        D101_2|
|   D101|Product B|        D101_3|
|   D101|Product B|        D101_4|
|   D102|Product C|        D102_0|
|   D102|Product C|        D102_1|
|   D102|Product C|        D102_2|
|   D102|Product C|        D102_3|
|   D102|Product C|        D102_4|
|   D103|Product D|        D103_0|
|   D103|Product D|        D103_1|
|   D103|Product D|        D103_2|
|   D103|Product D|        D103_3|
|   D103|Product D|        D103_4|
+-------+---------+--------------+
only showing top 20 rows



In [21]:
# Lets make the salted join now
salted_joined_df = salted_fact_df.join(salted_dim_df, on="salted_prod_id", how="leftouter")
salted_joined_df.show(10, False)

+--------------+---+--------+-------+---+-------+---------+
|salted_prod_id|id |order_id|prod_id|qty|prod_id|prod_name|
+--------------+---+--------+-------+---+-------+---------+
|D100_0        |2  |ORD1007 |D100   |37 |D100   |Product A|
|D100_0        |27 |ORD1006 |D100   |103|D100   |Product A|
|D100_0        |80 |ORD1001 |D100   |29 |D100   |Product A|
|D100_1        |0  |ORD1006 |D100   |44 |D100   |Product A|
|D100_1        |52 |ORD1004 |D100   |92 |D100   |Product A|
|D100_1        |85 |ORD1009 |D100   |115|D100   |Product A|
|D100_1        |102|ORD1001 |D100   |10 |D100   |Product A|
|D100_1        |179|ORD1006 |D100   |73 |D100   |Product A|
|D100_1        |185|ORD1009 |D100   |105|D100   |Product A|
|D104_1        |3  |ORD1005 |D104   |96 |D104   |Product E|
+--------------+---+--------+-------+---+-------+---------+
only showing top 10 rows



In [22]:
salted_joined_df.groupBy("salted_prod_id").agg(count("id")).show()

+--------------+---------+
|salted_prod_id|count(id)|
+--------------+---------+
|        D100_0|       13|
|        D100_1|        5|
|        D104_1|       11|
|        D101_1|       11|
|        D103_1|        4|
|        D103_2|        3|
|        D103_3|       13|
|        D104_0|        9|
|        D104_2|        5|
|        D104_3|       10|
|        D100_2|        9|
|        D102_1|        8|
|        D104_4|        8|
|        D101_0|        6|
|        D101_2|        5|
|        D101_4|       14|
|        D102_0|       13|
|        D102_2|        9|
|        D102_3|        7|
|        D102_4|        9|
+--------------+---------+
only showing top 20 rows



In [44]:
# Check the partition details to understand distribution
from pyspark.sql.functions import spark_partition_id, count
partition_df = salted_joined_df.withColumn("partition_num", spark_partition_id()).groupBy("partition_num") \
    .agg(count(lit(1)).alias("count")).orderBy("partition_num")
partition_df.show()

+-------------+-----+
|partition_num|count|
+-------------+-----+
|            0|   24|
|            1|   55|
|            2|   26|
|            3|   66|
|            4|   29|
+-------------+-----+



In [23]:
dict(zip(('a','b','c','d','e'),(1,2,3,4,5)))

{'a': 1, 'b': 2, 'c': 3, 'd': 4, 'e': 5}

SyntaxError: EOL while scanning string literal (1088809615.py, line 1)