In [None]:
from pyspark.sql import SparkSession, Row, Window
from pyspark.sql.functions import *
from pyspark.sql.types import IntegerType

In [None]:
# Local Spark session. autoBroadcastJoinThreshold = -1 turns off automatic broadcast joins,
# so only joins with an explicit broadcast() hint are broadcast.
spark = (
    SparkSession.builder
    .master("local")
    .config("spark.sql.autoBroadcastJoinThreshold", -1)
    .config("spark.executor.memory", "3gb")
    .appName("Exercise1")
    .getOrCreate()
)

In [None]:
# Load the three CSV inputs (header row, malformed lines dropped). No inferSchema is given,
# so every column is read as a string.
sellers, sales, products = (
    spark.read.csv(csv_path, header=True, mode="DROPMALFORMED")
    for csv_path in ("data/seller.csv", "data/sales.csv", "data/products.csv")
)

Fields (column names and types) of each DataFrame:

In [None]:
# The repr of each DataFrame lists its columns and their types
(sellers, sales, products)

## Warmup 1

How many orders, how many products and how many sellers are in the data?

In [None]:
# Row count of each input table, keyed by table name
{
    table_name: frame.count()
    for table_name, frame in (("sellers", sellers), ("sales", sales), ("products", products))
}

How many products have been sold at least once?

In [None]:
# Number of distinct product ids appearing in `sales`
sales.select('product_id').dropDuplicates().count()

Which product is contained in the most orders?

In [None]:
# Orders (rows of `sales`) per product. Keep every product tied for the maximum count.
# The previous `count >= 2` filter listed all repeat products instead of the top one(s).
orders_per_product = sales.groupBy('product_id').count()
max_orders = orders_per_product.agg(max('count')).first()[0]  # pyspark `max` (star import)
orders_per_product.filter(col('count') == max_orders).show()

## Warmup 2

How many distinct products have been sold on each date?

In [None]:
# Distinct product ids sold on each date
distinct_products_per_date = sales.groupBy('date').agg(countDistinct('product_id'))
distinct_products_per_date.show()

## Exercise 1

### Easier approach

In [None]:
# Average of price * pieces sold over all sales, via a plain inner join on product_id
(
    sales.join(products, on=sales["product_id"] == products["product_id"], how="inner")
    .agg(avg(products["price"] * sales["num_pieces_sold"]))
    .show()
)

In [None]:
# Joined sales/products rows, reused by the next two cells
mid_res = sales.join(products, on=sales["product_id"] == products["product_id"], how="inner")

In [None]:
# Same average, computed with agg() on the pre-joined frame
revenue_per_sale = mid_res["price"] * mid_res["num_pieces_sold"]
mid_res.agg(avg(revenue_per_sale)).show()

In [None]:
# Same average again, this time through select() with an aggregate expression
mid_res.select(
    avg(mid_res.price * mid_res.num_pieces_sold)
).show()

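### Efficient approach

Salt the join key of the most frequently sold products. Their rows are then spread across many distinct keys, and therefore many partitions, instead of one key.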

In [None]:
# Step 1 - Find the most frequent product keys in `sales`: these are the skewed join keys
TOP_SKEWED_PRODUCTS = 100
results = sales.groupby(sales["product_id"]).count().sort(col("count").desc()).limit(TOP_SKEWED_PRODUCTS).collect()

# Step 2 - What we want to do is:
#  a. Duplicate the entries that we have in the dimension table for the most common products, e.g.
#       product_0 will become: product_0-0, product_0-1, product_0-2 and so on
#  b. On the sales table, we are going to replace "product_0" with a random duplicate (e.g. some of them
#     will be replaced with product_0-1, others with product_0-2, etc.)
# Using the new "salted" key will unskew the join

# Let's create a dataset to do the trick: one (product_id, replication) row per skewed product and salt value
REPLICATION_FACTOR = 101
replicated_products = [_r["product_id"] for _r in results]
replication_rows = [
    (product_id, _rep) for product_id in replicated_products for _rep in range(REPLICATION_FACTOR)
]
# An explicit schema lets an empty `results` give an empty DataFrame instead of failing schema
# inference. product_id is a string because the CSVs are read without inferSchema.
replicated_df = spark.createDataFrame(replication_rows, "product_id string, replication int")

#   Step 3: Generate the salted key.
# Use new names instead of overwriting `products`/`sales`, so this cell can be re-run.
# Re-joining an already-joined `products` would make `product_id` ambiguous.
products_salted = products.join(broadcast(replicated_df), on="product_id", how="left") \
    .withColumn("salted_join_key",
                when(col("replication").isNull(), col("product_id"))
                .otherwise(concat(col("product_id"), lit("-"), col("replication"))))

# floor(rand() * R) is uniform over 0..R-1, which matches the replication values generated above.
# round(rand() * (R - 1)) gave the two end values only half the probability of the others.
sales_salted = sales.withColumn(
    "salted_join_key",
    when(col("product_id").isin(replicated_products),
         concat(col("product_id"), lit("-"), floor(rand() * REPLICATION_FACTOR).cast(IntegerType())))
    .otherwise(col("product_id")))

#   Step 4: Finally let's do the join. show() prints the table itself and returns None,
#   so its result is not wrapped in print().
sales_salted.join(products_salted,
                  sales_salted["salted_join_key"] == products_salted["salted_join_key"],
                  "inner") \
    .agg(avg(products_salted["price"] * sales_salted["num_pieces_sold"])) \
    .show()

## Exercise 2

In [None]:
# Per-seller average of num_pieces_sold / daily_target; `sellers` carries a broadcast hint
(
    sales.join(broadcast(sellers), on=sales['seller_id'] == sellers['seller_id'])
    .groupBy(sales['seller_id'])
    .agg(avg(sales["num_pieces_sold"] / sellers['daily_target']))
    .show()
)

In [None]:
# Same per-seller average, but materialise the ratio as a named column first
sales_with_ratio = sales.join(broadcast(sellers), sales["seller_id"] == sellers["seller_id"], "inner") \
    .withColumn("ratio", sales["num_pieces_sold"] / sellers["daily_target"])
sales_with_ratio.groupBy(sales["seller_id"]).agg(avg("ratio")).show()

## Exercise 3

[Caching](https://sparkbyexamples.com/spark/spark-dataframe-cache-and-persist-explained/)  
[row_number vs rand vs dense_rank](https://sparkbyexamples.com/pyspark/pyspark-window-functions/)

In [None]:
# Rank each seller within a product by total pieces sold (dense_rank, so ties share a rank).
# NOTE(review): ascending order means rank 1 = fewest pieces sold; confirm this is the intended direction.
win = Window.partitionBy("product_id").orderBy(col("n_sold").asc())

pieces_per_seller_product = (
    sales.groupBy('seller_id', 'product_id')
    .agg(sum('num_pieces_sold').alias('n_sold'))  # pyspark `sum` (star import)
    .withColumn('seller_prod_rank', dense_rank().over(win))
)

# Number of distinct sellers of each product
sellers_per_product = sales.groupBy('product_id').agg(countDistinct('seller_id').alias('n_sellers'))

df = pieces_per_seller_product.join(sellers_per_product, 'product_id')
df.cache()

In [None]:
# Column order shared by all three subsets; the later union() matches columns by position
_rank_cols = ['product_id', 'seller_id', 'seller_prod_rank']

# Products sold by a single seller
only_one = df.filter(col('n_sellers') == 1).select(_rank_cols)

# The rank-2 seller(s) of every product
seconds = df.filter(col('seller_prod_rank') == 2).select(_rank_cols)

# Highest rank per product, kept only when it is above 2 (ranks 1 and 2 are covered above)
maxs = df.groupBy('product_id') \
    .agg(max('seller_prod_rank').alias('seller_prod_rank')) \
    .filter(col('seller_prod_rank') > 2)
# Keep only the seller row(s) whose rank equals that maximum. The previous version discarded
# the max and joined on product_id alone, which returned every seller of those products.
# The final select restores _rank_cols order, because a using-join puts its keys first.
lasts = df.select(_rank_cols) \
    .join(maxs, ['product_id', 'seller_prod_rank']) \
    .select(_rank_cols)

In [None]:
# Reprs (schemas) of the three partial results
(only_one, seconds, lasts)

In [None]:
# Stack the three subsets (matching columns by name) and keep the result around for the next cells
result = only_one.unionByName(seconds).unionByName(lasts)
result.persist()
result.show()

In [None]:
# Rows for product 0 only
result.where(col('product_id') == 0).show()

In [None]:
# Bucket rows by rank as '1', '2' or 'more' (rank > 2) and count each bucket.
# All labels are string literals so the CASE WHEN branches share one type. Mixing int and
# string branch values relies on implicit coercion, which ANSI type coercion rejects.
rank_group = when(col('seller_prod_rank') == 1, lit('1')) \
    .when(col('seller_prod_rank') == 2, lit('2')) \
    .when(col('seller_prod_rank') > 2, lit('more'))
result.withColumn('n_sellers_group', rank_group) \
    .groupBy('n_sellers_group').count().show()

In [None]:
# Drop the cached intermediate and the persisted result (non-blocking, as by default)
df.unpersist(blocking=False)
result.unpersist(blocking=False)

## Exercise 4

In [None]:
import hashlib


def custom_hash(order_id: "int | str", bill: str) -> str:
    """Pure-Python equivalent of the `hashed_bill` column computed with Spark in this cell.

    Even order ids: MD5 is applied to the bill once per letter 'A' the bill contains, so a
    bill without any 'A' is returned unchanged. Odd order ids: SHA-256 is applied once.

    Args:
        order_id: order identifier, either an int or a numeric string (as read from the CSV).
        bill: raw bill text.

    Returns:
        The lowercase hex digest (or the unchanged bill, see above).
    """
    # int() so numeric strings work: on a str, `%` is string formatting, not modulo
    if int(order_id) % 2 == 0:
        digest = bill
        for _ in range(bill.count('A')):
            # hashlib needs bytes and returns a hash object; chain on the hex digest text
            digest = hashlib.md5(digest.encode("utf-8")).hexdigest()
        return digest
    return hashlib.sha256(bill.encode("utf-8")).hexdigest()


def even_hash(bill):
    """MD5-chain a bill: hash it once per letter 'A' it contains and return the hex digest.

    Used below as a Spark UDF for even order ids. A bill with no 'A' is returned unchanged.
    A None (SQL NULL) bill returns None instead of raising, matching how `sha2` handles NULL.
    Spark does not guarantee that UDFs inside `when` are short-circuited, so NULL bills of
    any order can reach this function.
    """
    if bill is None:
        return None
    res = bill.encode("utf-8")
    for _ in range(bill.count('A')):
        res = hashlib.md5(res).hexdigest().encode("utf-8")
    return res.decode('utf-8')


# Register under the SQL name 'even_hash'; the returned UDF is used in Column expressions below
even_hash_udf = spark.udf.register(name='even_hash', f=even_hash, returnType="string")


# Even orders: MD5 chain via the UDF; odd orders: built-in SHA-256
is_even_order = col('order_id') % 2 == 0
bill_text = col("bill_raw_text")

result = sales.withColumn(
    'hashed_bill',
    when(is_even_order, even_hash_udf(bill_text)).otherwise(sha2(bill_text, 256)),
)

result.show()
# Hash values shared by more than one row
result.groupBy('hashed_bill').count().filter(col('count') > 1).show()