In [1]:
import os
from pyspark.sql.types import *
from pyspark.sql import SparkSession

# Start (or reuse) a local Spark session running on all available cores.
spark = SparkSession.builder \
    .master('local[*]') \
    .appName('Six_Spark_Excercises') \
    .getOrCreate()

In [6]:
# Display the active SparkSession object.
spark

# GET DATA

In [None]:
%%bash
# One-off step: extract the downloaded dataset archive into its own folder under Downloads.
# NOTE(review): absolute, user-specific paths — adjust for your machine.
unzip /home/artur/Downloads/DatasetToCompleteTheSixSparkExercises.zip \
-d /home/artur/Downloads/DatasetToCompleteTheSixSparkExercises

In [17]:
# List the extracted dataset: one parquet directory per table (products, sales, sellers).
!ls /home/artur/Downloads/DatasetToCompleteTheSixSparkExercises

products_parquet  sales_parquet  sellers_parquet


In [7]:
# Root directory of the extracted parquet datasets.
# Override it with the SIX_SPARK_EXERCISES_DATA_DIR environment variable to run the notebook
# on another machine; when the variable is unset, the original author's path is used.
data_dir = os.environ.get(
    "SIX_SPARK_EXERCISES_DATA_DIR",
    "/home/artur/Downloads/DatasetToCompleteTheSixSparkExercises",
)

### Products

In [72]:
# product_schema = StructType([
#     StructField("product_id", LongType(), False),
#     StructField("product_name", StringType(), True),
#     StructField("price", FloatType(), True)
# ])
# products_init = spark.read.schema(product_schema).parquet(os.path.join(data_dir, "products_parquet"))

In [8]:
# Load the raw products table (all columns arrive as strings, see the schema below).
products_init = spark.read.format("parquet").load(os.path.join(data_dir, "products_parquet"))

In [9]:
# Raw schema: every column is stored as a string, hence the casts further down.
products_init.printSchema()

root
 |-- product_id: string (nullable = true)
 |-- product_name: string (nullable = true)
 |-- price: string (nullable = true)



In [102]:
# Peek at the first rows of the raw products table.
products_init.show(3)

+----------+------------+-----+
|product_id|product_name|price|
+----------+------------+-----+
|         0|   product_0|   22|
|         1|   product_1|   30|
|         2|   product_2|   91|
+----------+------------+-----+
only showing top 3 rows



In [10]:
# Cast the string columns to numeric types: product_id -> long, price -> float.
products_casted = products_init \
    .withColumn("product_id", products_init["product_id"].cast(LongType())) \
    .withColumn("price", products_init["price"].cast(FloatType()))

In [11]:
# Confirm the casts: product_id is now long and price is float.
products_casted.printSchema()

root
 |-- product_id: long (nullable = true)
 |-- product_name: string (nullable = true)
 |-- price: float (nullable = true)



In [105]:
# Same rows as before, now with numeric prices.
products_casted.show(3)

+----------+------------+-----+
|product_id|product_name|price|
+----------+------------+-----+
|         0|   product_0| 22.0|
|         1|   product_1| 30.0|
|         2|   product_2| 91.0|
+----------+------------+-----+
only showing top 3 rows



### Sales

In [12]:
# Load the raw sales table (one row per order; all columns arrive as strings).
sales_init = spark.read.format("parquet").load(os.path.join(data_dir, "sales_parquet"))

In [13]:
# Raw sales schema: all six columns are strings.
sales_init.printSchema()

root
 |-- order_id: string (nullable = true)
 |-- product_id: string (nullable = true)
 |-- seller_id: string (nullable = true)
 |-- date: string (nullable = true)
 |-- num_pieces_sold: string (nullable = true)
 |-- bill_raw_text: string (nullable = true)



In [95]:
# First rows of the raw sales table (bill_raw_text is truncated in the display).
sales_init.show(3)

+--------+----------+---------+----------+---------------+--------------------+
|order_id|product_id|seller_id|      date|num_pieces_sold|       bill_raw_text|
+--------+----------+---------+----------+---------------+--------------------+
|       1|         0|        0|2020-07-10|             26|kyeibuumwlyhuwksx...|
|       2|         0|        0|2020-07-08|             13|jfyuoyfkeyqkckwbu...|
|       3|         0|        0|2020-07-05|             38|uyjihlzhzcswxcccx...|
+--------+----------+---------+----------+---------------+--------------------+
only showing top 3 rows



In [14]:
# Cast the id and quantity columns to long; `date` and `bill_raw_text` stay strings.
# Column order is preserved because we iterate over the original column list.
sales_casted = sales_init.select([
    sales_init[name].cast("long").alias(name)
    if name in ("order_id", "product_id", "seller_id", "num_pieces_sold")
    else sales_init[name]
    for name in sales_init.columns
])
sales_casted.printSchema()

root
 |-- order_id: long (nullable = true)
 |-- product_id: long (nullable = true)
 |-- seller_id: long (nullable = true)
 |-- date: string (nullable = true)
 |-- num_pieces_sold: long (nullable = true)
 |-- bill_raw_text: string (nullable = true)



### Sellers

In [15]:
# Load the raw sellers table and show its (all-string) schema.
sellers_init = spark.read.format("parquet").load(os.path.join(data_dir, "sellers_parquet"))
sellers_init.printSchema()

root
 |-- seller_id: string (nullable = true)
 |-- seller_name: string (nullable = true)
 |-- daily_target: string (nullable = true)



In [107]:
# First rows of the raw sellers table.
sellers_init.show(3)

+---------+-----------+------------+
|seller_id|seller_name|daily_target|
+---------+-----------+------------+
|        0|   seller_0|     2500000|
|        1|   seller_1|      257237|
|        2|   seller_2|      754188|
+---------+-----------+------------+
only showing top 3 rows



In [16]:
# Cast seller_id and daily_target to long; seller_name stays a string.
sellers_casted = (
    sellers_init
    .withColumn("seller_id", sellers_init["seller_id"].cast(LongType()))
    .withColumn("daily_target", sellers_init["daily_target"].cast(LongType()))
)
sellers_casted.printSchema()

root
 |-- seller_id: long (nullable = true)
 |-- seller_name: string (nullable = true)
 |-- daily_target: long (nullable = true)



In [47]:
# The whole sellers table (it has only 10 rows).
sellers_casted.show()

+---------+-----------+------------+
|seller_id|seller_name|daily_target|
+---------+-----------+------------+
|        0|   seller_0|     2500000|
|        1|   seller_1|      257237|
|        2|   seller_2|      754188|
|        3|   seller_3|      310462|
|        4|   seller_4|     1532808|
|        5|   seller_5|     1199693|
|        6|   seller_6|     1055915|
|        7|   seller_7|     1946998|
|        8|   seller_8|      547320|
|        9|   seller_9|     1318051|
+---------+-----------+------------+



# WARM-UP #1

Find out how many orders, how many products and how many sellers are in the data.

In [17]:
import pyspark.sql.functions as sqlf

In [118]:
# 1st method: number of rows in the sales table (one per order, provided order_id is
# unique — checked in the next cell)
sales_casted.select("order_id").count()

20000040

In [124]:
%%time
# check if there are duplicates: the distinct count equals the plain row count above,
# so every order_id is unique
sales_casted.select("order_id").distinct().count()

CPU times: user 4.78 ms, sys: 0 ns, total: 4.78 ms
Wall time: 13.3 s


20000040

In [122]:
%%time
# 2nd method: count distinct order_ids in a single aggregation; collect() returns a
# one-element list holding a single Row
sales_casted.select(sqlf.countDistinct(sales_casted.order_id)).collect()

CPU times: user 5.55 ms, sys: 3.29 ms, total: 8.84 ms
Wall time: 13.2 s


[Row(count(DISTINCT order_id)=20000040)]

In [126]:
# Number of rows in the products table.
# NOTE(review): this counts rows, not distinct product_ids — it equals the number of products
# only if product_id is unique, which is not checked here.
products_casted.select("product_id").count()

75000000

In [129]:
# Number of distinct sellers
sellers_casted.select("seller_id").distinct().count()

10

How many products have been sold at least once? Which product appears in the most orders?

In [130]:
%%time
# Products sold at least once = distinct product_ids that appear in the sales table
sales_casted.select("product_id").distinct().count()

CPU times: user 3.1 ms, sys: 366 µs, total: 3.47 ms
Wall time: 4.64 s


993429

In [153]:
# Number of orders per product, most-ordered products first.
nb_orders_by_product = (
    sales_casted
    .groupby("product_id")
    .agg(sqlf.count("order_id").alias("count"))
    .orderBy(sqlf.desc("count"))
)

In [154]:
# Top 3 products by number of orders: product_id 0 dominates by far (a heavily skewed key)
nb_orders_by_product.show(3)

+----------+--------+
|product_id|   count|
+----------+--------+
|         0|19000000|
|  61540351|       3|
|  28592106|       3|
+----------+--------+
only showing top 3 rows



# WARM-UP #2

How many distinct products have been sold in each day?

In [206]:
%%time
# Distinct products per day: de-duplicate the (date, product_id) pairs, then count rows per date.
# `sqlf.col` is used instead of a bare `col`. At this point in the notebook `col` has not been
# imported yet (it only arrives with `from pyspark.sql.functions import *` in a later cell),
# so the bare name raises NameError on a fresh kernel.
(sales_casted.select("date", "product_id").distinct()
    .groupby("date")
    .agg(sqlf.count("product_id").alias("product_count"))
    .orderBy(sqlf.col("product_count").desc())
    .show())

+----------+-------------+
|      date|product_count|
+----------+-------------+
|2020-07-06|       100765|
|2020-07-09|       100501|
|2020-07-01|       100337|
|2020-07-03|       100017|
|2020-07-02|        99807|
|2020-07-05|        99796|
|2020-07-04|        99791|
|2020-07-07|        99756|
|2020-07-08|        99662|
|2020-07-10|        98973|
+----------+-------------+

CPU times: user 12 ms, sys: 426 µs, total: 12.4 ms
Wall time: 5.64 s


In [207]:
%%time
# author's solution
from pyspark.sql.functions import *
# Same result with a single aggregation: countDistinct per date instead of distinct() + count
sales_casted.groupby(col("date")).agg(countDistinct(col("product_id")).alias("distinct_products_sold"))\
                        .orderBy(col("distinct_products_sold").desc()).show()


+----------+----------------------+
|      date|distinct_products_sold|
+----------+----------------------+
|2020-07-06|                100765|
|2020-07-09|                100501|
|2020-07-01|                100337|
|2020-07-03|                100017|
|2020-07-02|                 99807|
|2020-07-05|                 99796|
|2020-07-04|                 99791|
|2020-07-07|                 99756|
|2020-07-08|                 99662|
|2020-07-10|                 98973|
+----------+----------------------+

CPU times: user 7.14 ms, sys: 395 µs, total: 7.53 ms
Wall time: 5.54 s


# EXERCISE #1

What is the average revenue of the orders?

In [241]:
%%time
# Average order revenue: attach each order's product price, then average pieces * price.
(sales_casted
    .join(products_casted, on="product_id", how="inner")
    .agg(sqlf.avg(sqlf.col("num_pieces_sold") * sqlf.col("price")).alias("revenue"))
    .show())

+------------------+
|           revenue|
+------------------+
|1246.1338560822878|
+------------------+

CPU times: user 1.58 ms, sys: 9.57 ms, total: 11.1 ms
Wall time: 1min


### Author's optimized solution (technique known as “key salting”)

In [248]:
# Start again from the raw (uncast, all-string) tables, as in the author's original code.
products_table, sales_table, sellers_table = products_init, sales_init, sellers_init

In [243]:
# Step 1 - Check and select the skewed keys.
# Retrieve the NUM_SKEWED_KEYS most frequent product_ids: only these keys will be salted.
# One key is enough for this dataset: product_id 0 appears in 19,000,000 of the ~20M orders,
# while the next most frequent products appear only 3 times (see the warm-up output).
NUM_SKEWED_KEYS = 1
results = (sales_table.groupby(sales_table["product_id"]).count()
           .sort(sqlf.col("count").desc())
           .limit(NUM_SKEWED_KEYS)
           .collect())

In [244]:
# Step 2 - Plan for salting:
#  a. In the dimension (products) table, duplicate each skewed product once per replica:
#     key "0" becomes "0-0", "0-1", ..., one row per replica number.
#  b. In the sales table, replace each skewed key with a randomly chosen replica key,
#     so its rows are spread over many join keys instead of one.
# Joining on the new "salted" key removes the skew.

# Build the helper dataset: one (product_id, replica number) pair per replica of each skewed
# key, with replica numbers running from 0 to REPLICATION_FACTOR - 1.
REPLICATION_FACTOR = 101
replicated_products = [row["product_id"] for row in results]
l = [(product_id, replica)
     for product_id in replicated_products
     for replica in range(REPLICATION_FACTOR)]

In [245]:
from pyspark.sql import Row

# Distribute the (product_id, replica number) pairs as an RDD.
rdd = spark.sparkContext.parallelize(l)

In [246]:
# Turn the (product_id, replica number) pairs into a DataFrame, which is broadcast in the salted join.
replicated_df = spark.createDataFrame(
    rdd.map(lambda pair: Row(product_id=pair[0], replication=int(pair[1])))
)

In [249]:
from pyspark.sql.functions import *
#   Step 3: Generate the salted key on the products side.
#   Left-join the broadcast replica table: a skewed product gets one row per replica, with key
#   "<product_id>-<replication>"; every other product gets a null replication and keeps its
#   plain product_id as the key.
#   NOTE(review): the join keeps both `product_id` columns (products_table's and replicated_df's),
#   and the cell reassigns products_table, so re-running it joins the replicas a second time.
products_table = products_table.join(broadcast(replicated_df),
                                     products_table["product_id"] == replicated_df["product_id"], "left"). \
    withColumn("salted_join_key", when(replicated_df["replication"].isNull(), products_table["product_id"]).otherwise(
    concat(replicated_df["product_id"], lit("-"), replicated_df["replication"])))

In [250]:
# Salt the sales side. Each row of a skewed product gets the key "<product_id>-<k>", with k drawn
# uniformly from 0..REPLICATION_FACTOR-1 (exactly the replica numbers created on the products side).
# All other rows keep their plain product_id as the join key.
# floor(rand() * N) is uniform over 0..N-1. The previous round(rand() * (N - 1)) gave the two end
# values only half the probability of the others, so the salted buckets were uneven.
sales_table = sales_table.withColumn(
    "salted_join_key",
    sqlf.when(
        sales_table["product_id"].isin(replicated_products),
        sqlf.concat(
            sales_table["product_id"],
            sqlf.lit("-"),
            sqlf.floor(sqlf.rand() * REPLICATION_FACTOR).cast(IntegerType()),
        ),
    ).otherwise(sales_table["product_id"]),
)

In [251]:
%%time
#   Step 4: Finally, join on the salted key and average the order revenue (price * pieces).
#   `show()` prints the table itself and returns None, so it is not wrapped in print()
#   (doing so only added a stray "None" line to the output).
(sales_table
    .join(products_table, sales_table["salted_join_key"] == products_table["salted_join_key"], "inner")
    .agg(avg(products_table["price"] * sales_table["num_pieces_sold"]))
    .show())

+------------------------------+
|avg((price * num_pieces_sold))|
+------------------------------+
|            1246.1338560822878|
+------------------------------+

None
CPU times: user 30.3 ms, sys: 5.04 ms, total: 35.3 ms
Wall time: 1min 18s


Using this technique in a local environment could lead to an increase of the execution time; in the real world, though, this trick can make the difference between completing and not completing the join.

# EXERCISE #2

For each seller, what is the average % contribution of an order to the seller's daily quota?

**Example**

If Seller_0 with `quota=250` has 3 orders:
* Order 1: 10 products sold
* Order 2: 8 products sold
* Order 3: 7 products sold

The average % contribution of orders to the seller's quota would be:
* Order 1: 10/250 = 0.04
* Order 2: 8/250 = 0.032
* Order 3: 7/250 = 0.028

Average % Contribution = (0.04+0.032+0.028)/3 = 0.03333

In [257]:
# Orders enriched with product data (price) and seller data (daily_target).
# `sales_join_products` is not assigned anywhere above (Exercise #1 only displayed its join),
# so it is built here to keep the notebook runnable top to bottom.
sales_join_products = sales_casted.join(products_casted, on="product_id", how="inner")
# The sellers table has only 10 rows, so it is broadcast to avoid shuffling the large side.
sales_all_data = sales_join_products.join(sqlf.broadcast(sellers_casted), on="seller_id", how="inner")

In [258]:
%%time
# Per seller: mean over its orders of (pieces sold in the order / seller's daily target).
(
    sales_all_data
    .groupby("seller_id", "daily_target")
    .agg(sqlf.avg(sqlf.col("num_pieces_sold") / sqlf.col("daily_target")).alias("ratio"))
    .orderBy("seller_id")
    .show()
)

+---------+------------+--------------------+
|seller_id|daily_target|               ratio|
+---------+------------+--------------------+
|        0|     2500000|2.019885898974511...|
|        1|      257237|1.964233366461027...|
|        2|      754188|6.690408001060523E-5|
|        3|      310462|1.628885370565940...|
|        4|     1532808|3.296428039825806...|
|        5|     1199693|4.211073965904036...|
|        6|     1055915|4.782147194369131...|
|        7|     1946998|2.595228787788168...|
|        8|      547320|9.213030375408928E-5|
|        9|     1318051|3.837913136180195...|
+---------+------------+--------------------+

CPU times: user 3.14 ms, sys: 12.2 ms, total: 15.4 ms
Wall time: 54.2 s


# EXERCISE #3

Who are the second most selling and the least selling persons (sellers) for each product? 

Who are they for the product with `product_id = 0`?

In [43]:
import pyspark.sql.window as sqlw

# Per-product window ordered by volume, largest first: rank 1 = top seller of the product.
window_conf = sqlw.Window.partitionBy("product_id").orderBy(sqlf.desc("volume"))

In [44]:
# Total pieces per (product, seller) and its dense rank within the product (1 = highest volume).
product_seller_volume = (
    sales_all_data
    .groupby("product_id", "seller_id")
    .agg(sqlf.sum("num_pieces_sold").alias("volume"))
    .withColumn("volume_rank", sqlf.dense_rank().over(window_conf))
)
product_seller_volume.cache()

DataFrame[product_id: bigint, seller_id: bigint, volume: bigint, volume_rank: int]

### Develop the solution on a subset of products which are sold by at least 3 sellers

In [51]:
# Products sold by at least 3 sellers: a small subset used to develop and eyeball the solution.
fltr_tbl = (
    product_seller_volume
    .groupby("product_id")
    .agg(sqlf.count("seller_id").alias("count"))
    .where(sqlf.col("count") >= 3)
)
fltr_tbl.cache()

DataFrame[product_id: bigint, count: bigint]

In [67]:
# Restrict the volumes to that subset, sorted by product and ascending volume (ties by seller_id).
tmp = (
    product_seller_volume
    .join(fltr_tbl, on="product_id", how="inner")
    .orderBy("product_id", "volume", "seller_id")
)
tmp.show()

+----------+---------+------+-----------+-----+
|product_id|seller_id|volume|volume_rank|count|
+----------+---------+------+-----------+-----+
|   3534470|        3|    25|          3|    3|
|   3534470|        9|    73|          2|    3|
|   3534470|        5|    81|          1|    3|
|  10978356|        7|    27|          3|    3|
|  10978356|        9|    36|          2|    3|
|  10978356|        5|    40|          1|    3|
|  14542470|        5|     3|          3|    3|
|  14542470|        9|    35|          2|    3|
|  14542470|        2|    62|          1|    3|
|  17944574|        8|    15|          3|    3|
|  17944574|        5|    17|          2|    3|
|  17944574|        2|    32|          1|    3|
|  18182299|        7|    15|          3|    3|
|  18182299|        4|    31|          2|    3|
|  18182299|        6|    67|          1|    3|
|  19986717|        1|     9|          3|    3|
|  19986717|        9|    18|          2|    3|
|  19986717|        2|    53|          1

In [69]:
# the least selling persons (sellers) for each product
# Pick, per product, the row with the smallest volume (ties -> smallest seller_id) by taking the
# minimum of a (volume, seller_id, volume_rank) struct; structs compare field by field.
# The previous first()-after-orderBy pattern is unreliable: Spark documents first() as
# non-deterministic after a shuffle, and groupby shuffles, which discards the earlier sort.
(tmp.groupby("product_id")
    .agg(sqlf.min(sqlf.struct("volume", "seller_id", "volume_rank")).alias("lowest"))
    .select("product_id", "lowest.seller_id", "lowest.volume", "lowest.volume_rank")
    .orderBy("product_id")
    .show())

+----------+---------+------+-----------+
|product_id|seller_id|volume|volume_rank|
+----------+---------+------+-----------+
|   3534470|        3|    25|          3|
|  10978356|        7|    27|          3|
|  14542470|        5|     3|          3|
|  17944574|        8|    15|          3|
|  18182299|        7|    15|          3|
|  19986717|        1|     9|          3|
|  20774718|        9|     5|          3|
|  28592106|        5|    42|          3|
|  31136332|        9|    57|          3|
|  32602520|        9|    12|          3|
|  34681047|        5|    19|          3|
|  35669461|        4|    39|          3|
|  36269838|        8|    16|          3|
|  40496308|        5|    52|          3|
|  52606213|        7|    28|          3|
|  56011040|        5|    13|          3|
|  57735075|        9|    55|          3|
|  61475460|        7|    41|          3|
|  67723231|        5|    26|          3|
|  69790381|        5|     8|          3|
+----------+---------+------+-----

In [70]:
# Who are the second most selling persons (sellers) for each product?
# dense_rank gives tied top sellers rank 1, so rank 2 is the next distinct volume.
tmp.where(sqlf.col("volume_rank")==2).show()

+----------+---------+------+-----------+-----+
|product_id|seller_id|volume|volume_rank|count|
+----------+---------+------+-----------+-----+
|   3534470|        9|    73|          2|    3|
|  10978356|        9|    36|          2|    3|
|  14542470|        9|    35|          2|    3|
|  17944574|        5|    17|          2|    3|
|  18182299|        4|    31|          2|    3|
|  19986717|        9|    18|          2|    3|
|  20774718|        3|    16|          2|    3|
|  28592106|        4|    52|          2|    3|
|  31136332|        1|    85|          2|    3|
|  32602520|        2|    21|          2|    3|
|  34681047|        6|    46|          2|    3|
|  35669461|        3|    47|          2|    3|
|  36269838|        1|    38|          2|    3|
|  40496308|        9|    58|          2|    3|
|  52606213|        9|    61|          2|    3|
|  56011040|        1|    30|          2|    3|
|  57735075|        1|    57|          2|    3|
|  61475460|        6|    64|          2

### The final solution

In [259]:
%%time
# the least selling persons (sellers) for each product
# Pick, per product, the row with the smallest volume (ties -> smallest seller_id) by taking the
# minimum of a (volume, seller_id, volume_rank) struct; structs compare field by field.
# The previous orderBy + first() pattern is unreliable: Spark documents first() as
# non-deterministic after a shuffle, and groupby shuffles, which discards the earlier sort.
(product_seller_volume
    .groupby("product_id")
    .agg(sqlf.min(sqlf.struct("volume", "seller_id", "volume_rank")).alias("lowest"))
    .select("product_id", "lowest.seller_id", "lowest.volume", "lowest.volume_rank")
    .orderBy("product_id")
    .show())

+----------+---------+---------+-----------+
|product_id|seller_id|   volume|volume_rank|
+----------+---------+---------+-----------+
|         0|        0|959445802|          1|
|       141|        7|       20|          1|
|       170|        8|       71|          1|
|       188|        5|       45|          1|
|       189|        2|       86|          1|
|       369|        6|       22|          1|
|       442|        3|       47|          1|
|       452|        3|       44|          1|
|       534|        8|       51|          1|
|       707|        2|       44|          1|
|       712|        8|      100|          1|
|       765|        7|       93|          1|
|       816|        5|       93|          1|
|       888|        6|       23|          1|
|       907|        8|       66|          1|
|      1037|        8|       59|          1|
|      1044|        5|       88|          1|
|      1216|        3|       46|          1|
|      1368|        8|       66|          1|
|      138

In [73]:
# Who are the second most selling persons (sellers) for each product?
# dense_rank gives tied top sellers rank 1, so rank 2 is the next distinct volume.
product_seller_volume.where(sqlf.col("volume_rank")==2).show()

+----------+---------+------+-----------+
|product_id|seller_id|volume|volume_rank|
+----------+---------+------+-----------+
|   1016707|        4|     9|          2|
|   3165246|        5|    19|          2|
|   3336320|        2|    17|          2|
|   8517688|        4|    60|          2|
|   8613392|        4|    39|          2|
|   9267327|        2|    68|          2|
|  10437166|        6|    59|          2|
|  11374026|        8|     8|          2|
|  21799676|        7|    37|          2|
|  24237246|        1|    59|          2|
|  25980426|        7|     2|          2|
|  30007038|        3|     2|          2|
|  33057106|        9|    74|          2|
|  33456955|        1|    36|          2|
|  34929513|        3|    10|          2|
|  36361547|        1|    59|          2|
|  39375459|        2|    18|          2|
|  45658314|        4|    60|          2|
|  46669674|        4|    15|          2|
|  52650256|        4|    33|          2|
+----------+---------+------+-----

In [74]:
# Second seller of product 0. The result is empty: product 0's least seller already has
# volume_rank 1 (see the table above), so all of its sellers share rank 1 and none has rank 2.
product_seller_volume.where((sqlf.col("volume_rank")==2) & (sqlf.col("product_id")==0)).show()

+----------+---------+------+-----------+
|product_id|seller_id|volume|volume_rank|
+----------+---------+------+-----------+
+----------+---------+------+-----------+



### Author's solution

Let’s analyze the question: for each product, we need the second most selling and the least selling employees (sellers): we are probably going to need two rankings, one to get the second and the other one to get the last in the sales chart. We also need to handle some edge cases:

* If a product has been sold by only one seller, we’ll put it into a special category (category: Only seller or multiple sellers with the same results).
* If a product has been sold by more than one seller, but all of them sold the same quantity, we are going to put them in the same category as if they were only a single seller for that product (category: Only seller or multiple sellers with the same results).
* If the “least selling” is also the “second selling”, we will count it only as “second seller”

Let’s draft a strategy:

* We get the sum of sales for each product and seller pairs.
* We add two new ranking columns: one that ranks the products’ sales in descending order and another one that ranks in ascending order.
* We split the dataset obtained in three pieces: one for each case that we want to handle (second top selling, least selling, single selling).
* When calculating the “least selling”, we exclude those products that have a single seller and those where the least selling employee is also the second most selling
* We merge the pieces back together.

In [261]:
# Start again from the raw (uncast, all-string) tables, as in the author's original code.
products_table, sales_table, sellers_table = products_init, sales_init, sellers_init

In [263]:
from pyspark.sql import Window
# Total pieces sold by each seller for each product.
# `sum` here is pyspark.sql.functions.sum (brought in by the star import above, shadowing the builtin).
sales_table = (
    sales_table
    .groupby(col("product_id"), col("seller_id"))
    .agg(sum("num_pieces_sold").alias("num_pieces_sold"))
)

# Two windows over the same per-product partition, ordered by pieces sold in opposite directions.
window_desc = Window.partitionBy(col("product_id")).orderBy(col("num_pieces_sold").desc())
window_asc = Window.partitionBy(col("product_id")).orderBy(col("num_pieces_sold").asc())

# Dense ranks (no gaps after ties) in both directions.
sales_table = (
    sales_table
    .withColumn("rank_asc", dense_rank().over(window_asc))
    .withColumn("rank_desc", dense_rank().over(window_desc))
)

In [264]:
# Get products that only have one row OR the products in which multiple sellers sold the same amount
# (i.e. all the employees that ever sold the product, sold the same exact amount).
# A row is both the top (rank_desc == 1) and the bottom (rank_asc == 1) of its product only when
# every seller of that product sold the same quantity. Testing rank_asc == rank_desc is not enough:
# with an odd number of distinct quantities, the middle seller also has equal ranks (quantities
# 25/73/81 give the 73-seller rank 2 in both directions) and would wrongly be labelled a single seller.
single_seller = sales_table.where((col("rank_asc") == 1) & (col("rank_desc") == 1)).select(
    col("product_id").alias("single_seller_product_id"), col("seller_id").alias("single_seller_seller_id"),
    lit("Only seller or multiple sellers with the same results").alias("type")
)

# Get the second top sellers (rank 2 by descending quantity)
second_seller = sales_table.where(col("rank_desc") == 2).select(
    col("product_id").alias("second_seller_product_id"), col("seller_id").alias("second_seller_seller_id"),
    lit("Second top seller").alias("type")
)

In [265]:
# Count every "least seller" row (rank_asc == 1) BEFORE any exclusion, for comparison with the
# filtered `least_seller` count computed further down. Nothing is excluded in this cell.
sales_table.where(col("rank_asc") == 1).select(
    col("product_id"), col("seller_id"),
    lit("Least Seller").alias("type")
).count()

993482

In [267]:
# Get the least sellers and exclude those rows that are already included in the first piece
# We also exclude the "second top sellers" that are also "least sellers"
# Both exclusions are left_anti joins on (seller_id, product_id). The conditions refer to
# sales_table's columns, which resolve to the product_id / seller_id selected from it just before.
least_seller = sales_table.where(col("rank_asc") == 1).select(
    col("product_id"), col("seller_id"),
    lit("Least Seller").alias("type")
).join(single_seller, (sales_table["seller_id"] == single_seller["single_seller_seller_id"]) & (
        sales_table["product_id"] == single_seller["single_seller_product_id"]), "left_anti"). \
    join(second_seller, (sales_table["seller_id"] == second_seller["second_seller_seller_id"]) & (
        sales_table["product_id"] == second_seller["second_seller_product_id"]), "left_anti")

In [268]:
# Number of least sellers left after the exclusions
least_seller.count()

21

In [269]:
%%time
# Union all the pieces. union() is positional, so every piece is first reduced to
# (product_id, seller_id, type) in that order.
union_table = (
    least_seller.select("product_id", "seller_id", "type")
    .union(second_seller
           .select("second_seller_product_id", "second_seller_seller_id", "type")
           .toDF("product_id", "seller_id", "type"))
    .union(single_seller
           .select("single_seller_product_id", "single_seller_seller_id", "type")
           .toDF("product_id", "seller_id", "type"))
)
union_table.show()

# Which are the second top seller and least seller of product 0?
union_table.where(col("product_id") == 0).show()

+----------+---------+------------+
|product_id|seller_id|        type|
+----------+---------+------------+
|  19986717|        1|Least Seller|
|  40496308|        5|Least Seller|
|  52606213|        7|Least Seller|
|  14542470|        5|Least Seller|
|  28592106|        5|Least Seller|
|  17944574|        8|Least Seller|
|  61475460|        7|Least Seller|
|   3534470|        3|Least Seller|
|  35669461|        4|Least Seller|
|  32602520|        9|Least Seller|
|  72017876|        1|Least Seller|
|  67723231|        5|Least Seller|
|  56011040|        5|Least Seller|
|  34681047|        5|Least Seller|
|  57735075|        9|Least Seller|
|  18182299|        7|Least Seller|
|  69790381|        5|Least Seller|
|  31136332|        9|Least Seller|
|  10978356|        7|Least Seller|
|  20774718|        9|Least Seller|
+----------+---------+------------+
only showing top 20 rows

+----------+---------+--------------------+
|product_id|seller_id|                type|
+----------+---------+

# EXERCISE #4

Create a new column called "hashed_bill" defined as follows:

- if the order_id is even: apply MD5 hashing iteratively to the bill_raw_text field, once for each 'A' (capital 'A') present in the text. E.g. if the bill text is 'nbAAnllA', you would apply hashing three times iteratively (only if the order number is even)
- if the order_id is odd: apply SHA256 hashing to the bill text

Finally, check whether there are any duplicates in the new column.

In [271]:
# Number of capital 'A's in the bill text. It is only needed for even order_ids; odd order_ids
# get the sentinel -1 (they are SHA-256 hashed instead).
# Spark does not guarantee that a UDF inside when() runs only for rows matching the condition,
# so the UDF is kept null-safe: a null bill yields a null count instead of failing the job.
count_A_udf = sqlf.udf(lambda s: None if s is None else s.count('A'), IntegerType())

sales_ex4 = sales_casted.withColumn(
    "nb_of_A",
    sqlf.when(sqlf.col("order_id") % 2 == 0, count_A_udf("bill_raw_text")).otherwise(-1),
)

In [272]:
import hashlib

@sqlf.udf(returnType=StringType())
def get_md5(arr):
    """Apply MD5 to a string `n` times in a row.

    `arr` is the Spark array [bill_raw_text, nb_of_A] built with `sqlf.array`. Its elements are
    presumably coerced to a common string type (hence the int() on the count) — TODO confirm.
    Returns the hex digest after `n` rounds (the text itself when n == 0), or None when the array,
    the text or the count is null.
    """
    # The old check ran after str()/int(), so str(None) became the literal "None" and int(None)
    # raised: the null guard could never trigger. Check the raw values first instead.
    if arr is None or arr[0] is None or arr[1] is None:
        return None
    s, n = str(arr[0]), int(arr[1])
    for _ in range(n):
        s = hashlib.md5(s.encode()).hexdigest()
    return s

In [273]:
%%time
# Even orders (nb_of_A > -1): iterated MD5 through the UDF; odd orders: one SHA-256.
sales_ex4 = sales_ex4.withColumn(
    "hashed_bill",
    sqlf.when(sqlf.col("nb_of_A") > -1, get_md5(sqlf.array("bill_raw_text", "nb_of_A")))
        .otherwise(sqlf.sha2("bill_raw_text", 256)),
)
sales_ex4.show()

+--------+----------+---------+----------+---------------+--------------------+-------+--------------------+
|order_id|product_id|seller_id|      date|num_pieces_sold|       bill_raw_text|nb_of_A|         hashed_bill|
+--------+----------+---------+----------+---------------+--------------------+-------+--------------------+
|       1|         0|        0|2020-07-10|             26|kyeibuumwlyhuwksx...|     -1|f6fa2a8be04a4ead6...|
|       2|         0|        0|2020-07-08|             13|jfyuoyfkeyqkckwbu...|      0|jfyuoyfkeyqkckwbu...|
|       3|         0|        0|2020-07-05|             38|uyjihlzhzcswxcccx...|     -1|416376a64cd652e7b...|
|       4|         0|        0|2020-07-05|             56|umnxvoqbdzpbwjqmz...|      0|umnxvoqbdzpbwjqmz...|
|       5|         0|        0|2020-07-05|             11|zmqexmaawmvdpqhih...|     -1|787d361b162a6aa1a...|
|       6|         0|        0|2020-07-01|             82|lmuhhkpyuoyslwmvX...|      0|lmuhhkpyuoyslwmvX...|
|       7|         

In [202]:
%%time
# Rows left after de-duplicating on hashed_bill. This equals the total number of orders
# (20000040), so no two bills share a hash.
sales_ex4.drop_duplicates(["hashed_bill"]).count()

CPU times: user 138 ms, sys: 84.7 ms, total: 223 ms
Wall time: 16min 41s


20000040

### Author's solution

In [274]:
# Start again from the raw (uncast, all-string) tables, as in the author's original code.
products_table, sales_table, sellers_table = products_init, sales_init, sellers_init

In [275]:
%%time
#   Define the UDF function
def algo(order_id, bill_text):
    """Hash a bill according to the parity of its order id.

    Even order_id: MD5 applied iteratively, once per capital 'A' in `bill_text`
    (with no 'A' the text is returned unchanged). Odd order_id: a single SHA-256.
    `order_id` may be an int or a numeric string (the raw sales table stores it as a string).

    Returns the hex digest as a str, or None when either input is null, so a null row
    produces a null hash instead of failing the whole Spark job.
    """
    if order_id is None or bill_text is None:
        return None
    ret = bill_text.encode("utf-8")
    #   If number is even: one MD5 round per capital 'A'
    if int(order_id) % 2 == 0:
        for _ in range(bill_text.count("A")):
            ret = hashlib.md5(ret).hexdigest().encode("utf-8")
        return ret.decode("utf-8")
    return hashlib.sha256(ret).hexdigest()

#   Register `algo` as a Spark UDF.
algo_udf = spark.udf.register("algo", algo)

#   Hash every bill with `algo_udf` and list hashes shared by more than one bill
#   (an empty table means there are no duplicates).
(sales_table
    .withColumn("hashed_bill", algo_udf(col("order_id"), col("bill_raw_text")))
    .groupby("hashed_bill")
    .agg(count("*").alias("cnt"))
    .filter(col("cnt") > 1)
    .show())

+-----------+---+
|hashed_bill|cnt|
+-----------+---+
+-----------+---+

CPU times: user 58.8 ms, sys: 67.5 ms, total: 126 ms
Wall time: 8min 43s


In [276]:
# Shut down the Spark session at the end of the notebook.
spark.stop()