In [1]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql import Row
from pyspark.sql.types import IntegerType


In [2]:
# Create (or reuse) the local Spark session for this exercise.
# The auto-broadcast threshold is set to -1, which disables automatic
# broadcast joins; only explicit broadcast() hints will broadcast a table.
spark = (
    SparkSession.builder
    .master("local")
    .config("spark.sql.autoBroadcastJoinThreshold", -1)
    .config("spark.executor.memory", "500mb")
    .appName("Exercise1")
    .getOrCreate()
)

Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
22/04/23 14:05:13 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
22/04/23 14:05:20 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.
22/04/23 14:05:20 WARN Utils: Service 'SparkUI' could not bind on port 4041. Attempting port 4042.


In [3]:
# Load the three Parquet datasets (sales, products, sellers) in one pass.
sales_table, products_table, sellers_table = (
    spark.read.parquet(parquet_path)
    for parquet_path in (
        'DatasetToCompleteTheSixSparkExercises/sales_parquet',
        'DatasetToCompleteTheSixSparkExercises/products_parquet',
        'DatasetToCompleteTheSixSparkExercises/sellers_parquet',
    )
)

                                                                                

In [4]:
# Peek at a single row of each input table to check the column layout.
sales_table.show(n=1)
products_table.show(n=1)

                                                                                

+--------+----------+---------+----------+---------------+--------------------+
|order_id|product_id|seller_id|      date|num_pieces_sold|       bill_raw_text|
+--------+----------+---------+----------+---------------+--------------------+
|       1|         0|        0|2020-07-10|             26|kyeibuumwlyhuwksx...|
+--------+----------+---------+----------+---------------+--------------------+
only showing top 1 row

+----------+------------+-----+
|product_id|product_name|price|
+----------+------------+-----+
|         0|   product_0|   22|
+----------+------------+-----+
only showing top 1 row



In [5]:
# Enrich every sale with the matching product row (equi-join on product_id).
sales_products = sales_table.join(
    products_table,
    on='product_id',
    how='inner',
)

In [6]:
# Sanity check: one row of the joined sales/products frame.
sales_products.show(n=1)

[Stage 9:>                                                          (0 + 1) / 1]

+----------+--------+---------+----------+---------------+--------------------+----------------+-----+
|product_id|order_id|seller_id|      date|num_pieces_sold|       bill_raw_text|    product_name|price|
+----------+--------+---------+----------+---------------+--------------------+----------------+-----+
|  10005243|12478308|        6|2020-07-04|             98|qfvpgiscflyjxphcq...|product_10005243|   44|
+----------+--------+---------+----------+---------------+--------------------+----------------+-----+
only showing top 1 row



                                                                                

In [7]:
# Revenue of an order = number of pieces sold times the unit price.
sales_products_avg = sales_products.withColumn(
    'revenue',
    col('num_pieces_sold') * col('price'),
)

In [8]:
# Global (ungrouped) average of the per-order revenue.
sales_products_avg.agg(avg('revenue')).show()



+------------------+
|      avg(revenue)|
+------------------+
|1246.1338560822878|
+------------------+



                                                                                

In [9]:

# Salted join: compute the average order revenue again, but spread the most
# frequent ("skewed") product keys over several join keys so that no single
# join key carries all the rows of a hot product.
#
# The input frames `sales_table` and `products_table` are NOT overwritten: the
# salted variants get their own names, so this cell can be re-run safely and
# `products_table` keeps exactly one row per product.

# Step 1 - Check and select the skewed keys
# We retrieve the TOP_SKEWED_KEYS most frequent product ids: these will be the
# only salted keys. The collected result is tiny (at most TOP_SKEWED_KEYS rows).
TOP_SKEWED_KEYS = 100
results = (
    sales_table.groupby(sales_table["product_id"])
    .count()
    .sort(col("count").desc())
    .limit(TOP_SKEWED_KEYS)
    .collect()
)

# Step 2 - What we want to do is:
#  a. Duplicate the entries that we have in the dimension table for the most
#     common products, e.g. product_0 will become: product_0-0, product_0-1,
#     product_0-2 and so on (one copy per salt value).
#  b. On the sales table, replace "product_0" with a random duplicate (e.g.
#     some rows get product_0-1, others product_0-2, etc.).
# Using the new "salted" key will unskew the join.
# Let's create a small helper dataset (product_id, replication) to do the trick.
REPLICATION_FACTOR = 101  # number of salt values per skewed key: 0 .. REPLICATION_FACTOR - 1
replicated_products = [_r["product_id"] for _r in results]
replication_pairs = [
    (product_id, replica)
    for product_id in replicated_products
    for replica in range(REPLICATION_FACTOR)
]
replicated_df = spark.createDataFrame(replication_pairs, ["product_id", "replication"])

#   Step 3: Generate the salted key
# Dimension side: skewed products are expanded to REPLICATION_FACTOR rows
# ("<product_id>-<replication>"); all other products keep their plain id.
# Joining on the column name keeps a single, unambiguous product_id column.
products_salted = (
    products_table.join(broadcast(replicated_df), on="product_id", how="left")
    .withColumn(
        "salted_join_key",
        when(col("replication").isNull(), col("product_id")).otherwise(
            concat(col("product_id"), lit("-"), col("replication"))
        ),
    )
)

# Fact side: skewed products get a random salt. rand() is in [0, 1), so
# floor(rand() * REPLICATION_FACTOR) is uniform over 0 .. REPLICATION_FACTOR - 1
# (rounding instead would give the two end values half the weight of the rest).
sales_salted = sales_table.withColumn(
    "salted_join_key",
    when(
        sales_table["product_id"].isin(replicated_products),
        concat(
            sales_table["product_id"],
            lit("-"),
            floor(rand() * REPLICATION_FACTOR).cast(IntegerType()),
        ),
    ).otherwise(sales_table["product_id"]),
)

#   Step 4: Finally let's do the join
# show() prints the table itself and returns None, so it is not wrapped in print().
sales_salted.join(
    products_salted,
    sales_salted["salted_join_key"] == products_salted["salted_join_key"],
    "inner",
).agg(avg(products_salted["price"] * sales_salted["num_pieces_sold"])).show()

print("Ok")




+------------------------------+
|avg((price * num_pieces_sold))|
+------------------------------+
|            1246.1338560822878|
+------------------------------+

None
Ok


                                                                                