In [6]:
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.types import DoubleType, IntegerType, StringType

# Create (or reuse) the Spark session used by every cell in this notebook.
spark = (
    SparkSession.builder
    .appName("BONUS")
    .getOrCreate()
)
# Set the number of partitions Spark SQL uses when shuffling to 8.
spark.conf.set("spark.sql.shuffle.partitions", "8")


In [8]:
# Load the raw retail CSV: the first row holds the column names and the
# column types are inferred from the data.
df = spark.read.options(header=True, inferSchema=True).csv("online_retail.csv")
df.show(n=5)


                                                                                

+---------+---------+--------------------+--------+--------------+---------+----------+--------------+
|InvoiceNo|StockCode|         Description|Quantity|   InvoiceDate|UnitPrice|CustomerID|       Country|
+---------+---------+--------------------+--------+--------------+---------+----------+--------------+
|   536365|   85123A|WHITE HANGING HEA...|       6|12/1/2010 8:26|     2.55|     17850|United Kingdom|
|   536365|    71053| WHITE METAL LANTERN|       6|12/1/2010 8:26|     3.39|     17850|United Kingdom|
|   536365|   84406B|CREAM CUPID HEART...|       8|12/1/2010 8:26|     2.75|     17850|United Kingdom|
|   536365|   84029G|KNITTED UNION FLA...|       6|12/1/2010 8:26|     3.39|     17850|United Kingdom|
|   536365|   84029E|RED WOOLLY HOTTIE...|       6|12/1/2010 8:26|     3.39|     17850|United Kingdom|
+---------+---------+--------------------+--------+--------------+---------+----------+--------------+
only showing top 5 rows


In [9]:
# Parse the invoice timestamp and cast the numeric columns to explicit types.
df2 = df.withColumn(
    "InvoiceDateTS",
    F.to_timestamp(F.trim(F.col("InvoiceDate")), "M/d/yyyy H:mm"),
)
df2 = df2.withColumn("Quantity", F.col("Quantity").cast(IntegerType()))
df2 = df2.withColumn("UnitPrice", F.col("UnitPrice").cast(DoubleType()))
df2 = df2.withColumn("CustomerID", F.col("CustomerID").cast(IntegerType()))

# Keep rows that have a customer id and whose invoice number does not start
# with "C" (the cancelled orders), then add the per-line amount.
sales = df2.filter(F.col("CustomerID").isNotNull())
sales = sales.filter(~F.col("InvoiceNo").startswith("C"))
sales = sales.withColumn(
    "LineAmount",
    (F.col("Quantity") * F.col("UnitPrice")).cast(DoubleType()),
)

sales.show(n=5)


+---------+---------+--------------------+--------+--------------+---------+----------+--------------+-------------------+------------------+
|InvoiceNo|StockCode|         Description|Quantity|   InvoiceDate|UnitPrice|CustomerID|       Country|      InvoiceDateTS|        LineAmount|
+---------+---------+--------------------+--------+--------------+---------+----------+--------------+-------------------+------------------+
|   536365|   85123A|WHITE HANGING HEA...|       6|12/1/2010 8:26|     2.55|     17850|United Kingdom|2010-12-01 08:26:00|15.299999999999999|
|   536365|    71053| WHITE METAL LANTERN|       6|12/1/2010 8:26|     3.39|     17850|United Kingdom|2010-12-01 08:26:00|             20.34|
|   536365|   84406B|CREAM CUPID HEART...|       8|12/1/2010 8:26|     2.75|     17850|United Kingdom|2010-12-01 08:26:00|              22.0|
|   536365|   84029G|KNITTED UNION FLA...|       6|12/1/2010 8:26|     3.39|     17850|United Kingdom|2010-12-01 08:26:00|             20.34|
|   53

In [10]:
# `df2` and `sales` are already built by the previous cell. This cell used to
# repeat that exact transformation; only the preview is kept so the cleaning
# logic is defined in a single place.
sales.show(5)


+---------+---------+--------------------+--------+--------------+---------+----------+--------------+-------------------+------------------+
|InvoiceNo|StockCode|         Description|Quantity|   InvoiceDate|UnitPrice|CustomerID|       Country|      InvoiceDateTS|        LineAmount|
+---------+---------+--------------------+--------+--------------+---------+----------+--------------+-------------------+------------------+
|   536365|   85123A|WHITE HANGING HEA...|       6|12/1/2010 8:26|     2.55|     17850|United Kingdom|2010-12-01 08:26:00|15.299999999999999|
|   536365|    71053| WHITE METAL LANTERN|       6|12/1/2010 8:26|     3.39|     17850|United Kingdom|2010-12-01 08:26:00|             20.34|
|   536365|   84406B|CREAM CUPID HEART...|       8|12/1/2010 8:26|     2.75|     17850|United Kingdom|2010-12-01 08:26:00|              22.0|
|   536365|   84029G|KNITTED UNION FLA...|       6|12/1/2010 8:26|     3.39|     17850|United Kingdom|2010-12-01 08:26:00|             20.34|
|   53

In [11]:
# Sum the line amounts of each invoice (keyed by invoice number and customer)
# and round the total to 2 decimal places.
invoice_amounts = sales.groupBy("InvoiceNo", "CustomerID").agg(
    F.round(F.sum(F.col("LineAmount")), 2).alias("order_amount")
)

invoice_amounts.show(n=5)




+---------+----------+------------+
|InvoiceNo|CustomerID|order_amount|
+---------+----------+------------+
|   536376|     15291|       328.8|
|   536378|     14688|      444.98|
|   536393|     13747|        79.6|
|   536406|     17850|      353.14|
|   536464|     17968|      277.35|
+---------+----------+------------+
only showing top 5 rows


                                                                                

In [12]:
# bucket orders as small / medium / large
def bucket(x: float) -> str:
    """Classify an order amount into a size bucket.

    Args:
        x: Order amount. May be None or NaN when the amount is missing.

    Returns:
        "Unknown" for None or NaN, "Small" for amounts below 50,
        "Medium" for amounts from 50 through 200 inclusive, and
        "Large" for amounts above 200.
    """
    # NaN is the only value that is not equal to itself. Without this guard
    # both comparisons below are False for NaN, so it would be labelled "Large".
    if x is None or x != x:
        return "Unknown"
    if x < 50:
        return "Small"
    if x <= 200:
        return "Medium"
    return "Large"

# Wrap the Python classifier as a Spark UDF that returns a string column.
bucket_udf = F.udf(bucket, returnType=StringType())


In [13]:
# Label every invoice with its size bucket by applying the UDF to the total.
invoice_out = invoice_amounts.select(
    "*",
    bucket_udf("order_amount").alias("order_size"),
)
invoice_out.show(n=5, truncate=False)

# Write the labelled invoices as a single CSV part file with a header row,
# replacing any previous output in the target directory.
invoice_out.coalesce(1).write.csv(
    "midterm/output/bonus_udf_orders",
    mode="overwrite",
    header=True,
)


                                                                                

+---------+----------+------------+----------+
|InvoiceNo|CustomerID|order_amount|order_size|
+---------+----------+------------+----------+
|536376   |15291     |328.8       |Large     |
|536378   |14688     |444.98      |Large     |
|536393   |13747     |79.6        |Medium    |
|536406   |17850     |353.14      |Large     |
|536464   |17968     |277.35      |Large     |
+---------+----------+------------+----------+
only showing top 5 rows


                                                                                

In [14]:
# Sum the line amounts per customer and round the total to 2 decimal places.
customer_total = sales.groupBy("CustomerID").agg(
    F.round(F.sum(F.col("LineAmount")), 2).alias("total_spent")
)

customer_total.show(n=5)




+----------+-----------+
|CustomerID|total_spent|
+----------+-----------+
|     17850|    5391.21|
|     17420|     598.83|
|     15862|     832.88|
|     14045|    1659.75|
|     13694|   65039.62|
+----------+-----------+
only showing top 5 rows


                                                                                

In [15]:
# customers with spending > threshold
THRESH = 2000.0

# One row per customer whose total spend exceeds THRESH, tagged with the tier
# name and the threshold that was applied.
loyalty = (
    customer_total.filter(F.col("total_spent") > F.lit(THRESH))
                  .select(
                      "CustomerID",
                      F.lit("Gold").alias("loyalty_tier"),
                      F.lit(THRESH).alias("threshold")
                  )
)

# Preview the plain DataFrame. Calling show() on the broadcast-hinted frame
# makes Spark log "A join hint (strategy=broadcast) is specified but it is not
# part of a join relation", because the hint only has meaning inside a join.
loyalty.show(5)

# broadcast since it is small; the hint takes effect when loyalty_b is joined
loyalty_b = F.broadcast(loyalty)


25/11/10 23:11:38 WARN HintErrorLogger: A join hint (strategy=broadcast) is specified but it is not part of a join relation.
[Stage 26:>                                                         (0 + 2) / 2]

+----------+------------+---------+
|CustomerID|loyalty_tier|threshold|
+----------+------------+---------+
|     17850|        Gold|   2000.0|
|     13694|        Gold|   2000.0|
|     12921|        Gold|   2000.0|
|     13777|        Gold|   2000.0|
|     13090|        Gold|   2000.0|
+----------+------------+---------+
only showing top 5 rows


                                                                                

In [16]:
# Inner-join the per-customer totals with the loyalty table on CustomerID so
# only qualifying customers remain, then list the biggest spenders first.
inner_join = (
    customer_total.alias("c")
    .join(loyalty_b.alias("l"), on="CustomerID", how="inner")
    .select(
        "CustomerID",
        F.col("c.total_spent").alias("total_spent"),
        "loyalty_tier",
        "threshold",
    )
    .orderBy(F.desc("total_spent"))
)

inner_join.show(n=10, truncate=False)

# Write the joined result as a single CSV part file with a header row,
# replacing any previous output in the target directory.
inner_join.coalesce(1).write.csv(
    "midterm/output/bonus_join_inner",
    mode="overwrite",
    header=True,
)


                                                                                

+----------+-----------+------------+---------+
|CustomerID|total_spent|loyalty_tier|threshold|
+----------+-----------+------------+---------+
|14646     |280206.02  |Gold        |2000.0   |
|18102     |259657.3   |Gold        |2000.0   |
|17450     |194550.79  |Gold        |2000.0   |
|16446     |168472.5   |Gold        |2000.0   |
|14911     |143825.06  |Gold        |2000.0   |
|12415     |124914.53  |Gold        |2000.0   |
|14156     |117379.63  |Gold        |2000.0   |
|17511     |91062.38   |Gold        |2000.0   |
|16029     |81024.84   |Gold        |2000.0   |
|12346     |77183.6    |Gold        |2000.0   |
+----------+-----------+------------+---------+
only showing top 10 rows


                                                                                

In [17]:
# Shut down the Spark session; no Spark work can run after this cell
# until a new session is created.
spark.stop()
