In [0]:
# imports
from pyspark.sql.functions import col, expr, concat_ws, year, month, sum as sql_sum, lag, concat, lit, when, min as sql_min, max, countDistinct, sum, datediff, expr, to_date  
from pyspark.sql.window import Window
from pyspark.sql.types import StringType

In [0]:
# Data import for retail.csv
# File location and type
file_location = "/FileStore/tables/jarvis/retail.csv"
file_type = "csv"

# CSV options
infer_schema = "true"
first_row_is_header = "true"
delimiter = ","

# The applied options are for CSV files. For other file types, these will be ignored.
retail = spark.read.format(file_type) \
                .option("inferSchema", infer_schema) \
                .option("header", first_row_is_header) \
                .option("sep", delimiter) \
                .load(file_location)

retail.printSchema

Out[1]: <bound method DataFrame.printSchema of DataFrame[invoice_no: string, stock_code: string, description: string, quantity: int, invoice_date: timestamp, unit_price: double, customer_id: int, country: string]>

In [0]:
# Filter out rows with null customer_id
retail = retail.filter(retail["customer_id"].isNotNull())
display(retail)

invoice_no,stock_code,description,quantity,invoice_date,unit_price,customer_id,country,total_amount,invoice_month
489434,85048,15CM CHRISTMAS GLASS BALL 20 LIGHTS,12,2009-12-01T07:45:00.000+0000,6.95,13085,United Kingdom,83.4,2009-12-01
489434,79323P,PINK CHERRY LIGHTS,12,2009-12-01T07:45:00.000+0000,6.75,13085,United Kingdom,81.0,2009-12-01
489434,79323W,WHITE CHERRY LIGHTS,12,2009-12-01T07:45:00.000+0000,6.75,13085,United Kingdom,81.0,2009-12-01
489434,22041,"""RECORD FRAME 7"""" SINGLE SIZE """,48,2009-12-01T07:45:00.000+0000,2.1,13085,United Kingdom,100.8,2009-12-01
489434,21232,STRAWBERRY CERAMIC TRINKET BOX,24,2009-12-01T07:45:00.000+0000,1.25,13085,United Kingdom,30.0,2009-12-01
489434,22064,PINK DOUGHNUT TRINKET POT,24,2009-12-01T07:45:00.000+0000,1.65,13085,United Kingdom,39.6,2009-12-01
489434,21871,SAVE THE PLANET MUG,24,2009-12-01T07:45:00.000+0000,1.25,13085,United Kingdom,30.0,2009-12-01
489434,21523,FANCY FONT HOME SWEET HOME DOORMAT,10,2009-12-01T07:45:00.000+0000,5.95,13085,United Kingdom,59.5,2009-12-01
489435,22350,CAT BOWL,12,2009-12-01T07:46:00.000+0000,2.55,13085,United Kingdom,30.6,2009-12-01
489435,22349,"DOG BOWL , CHASING BALL DESIGN",12,2009-12-01T07:46:00.000+0000,3.75,13085,United Kingdom,45.0,2009-12-01


In [0]:
retail.printSchema

Out[16]: <bound method DataFrame.printSchema of DataFrame[invoice_no: string, stock_code: string, description: string, quantity: int, invoice_date: timestamp, unit_price: double, customer_id: int, country: string, total_amount: double, invoice_month: date]>

# Total Invoice Amount Distribution
### Calculate the invoice amount. 
Note: an invoice consists of one or more items where each item is a row in the df.

In [0]:
# Calculate the total amount for each item
retail = retail.withColumn("total_amount", expr("quantity * unit_price"))

# Group by invoice number and calculate the total invoice amount
invoice_amounts = retail.groupBy("invoice_no").agg({"total_amount": "sum"})

# Rename the resulting column for clarity
invoice_amounts = invoice_amounts.withColumnRenamed("sum(total_amount)", "invoice_amount")

# Show the resulting DataFrame
display(invoice_amounts)

invoice_no,invoice_amount
489677,192.0
C491017,-4.95
491045,303.2
491658,155.05999999999997
C491705,-22.5
C492541,-99.0
C493168,-177.60000000000002
493542,118.75
493977,275.95
C493984,-10.43


# Monthly Placed and Canceled Orders

In [0]:
# Create a new column with YYYYMM format
retail = retail.withColumn("invoice_month", to_date(concat(year("invoice_date"), lit("-"), month("invoice_date"))))

# Identify canceled orders and calculate the number of canceled invoices per month
canceled_orders = retail.filter(retail["invoice_no"].rlike("C")).groupBy("invoice_month").count()
canceled_orders = canceled_orders.withColumnRenamed("count", "canceled_orders")

# Calculate the number of placed orders per month
total_orders = retail.groupBy("invoice_month").count()
total_orders = total_orders.withColumnRenamed("count", "total_orders")

# Calculate the difference between placed and canceled orders to get monthly placed orders
monthly_placed_orders = total_orders.join(canceled_orders, "invoice_month", "left_outer")

# Fill NaN values with 0
monthly_placed_orders = monthly_placed_orders.fillna(0)
monthly_placed_orders = monthly_placed_orders.withColumn("placed_orders", monthly_placed_orders["total_orders"] - (2 * monthly_placed_orders["canceled_orders"]))
monthly_placed_orders = monthly_placed_orders.orderBy("invoice_month")

# Show the resulting DataFrame
display(monthly_placed_orders)


invoice_month,total_orders,canceled_orders,placed_orders
2009-12-01,31760,999,29762
2010-01-01,22439,661,21117
2010-02-01,23906,537,22832
2010-03-01,33114,812,31490
2010-04-01,27833,595,26643
2010-05-01,29604,960,27684
2010-06-01,31950,759,30432
2010-07-01,27746,713,26320
2010-08-01,26942,549,25844
2010-09-01,35386,784,33818


Databricks visualization. Run in Databricks to view.

# Monthly Sales
### Calculate the monthly sales data
Plot a chart to show monthly sales (e.g. x-asix=year_month, y-axis=sales_amount)

In [0]:
# Filter for numeric invoice numbers
sales_df = retail.filter(retail["invoice_no"].rlike("^[0-9]+$"))

# Calculate invoice amount
sales_df = sales_df.withColumn("invoice_amount", col("quantity") * col("unit_price"))

# Calculate the monthly sales data
window_spec = Window.partitionBy("invoice_month")
monthly_sales = sales_df.groupBy("invoice_month").agg(sql_sum("invoice_amount").alias("sales_amount"))
monthly_sales = monthly_sales.orderBy("invoice_month")

display(monthly_sales)


invoice_month,sales_amount
2009-12-01,686654.1599999949
2010-01-01,557319.0620000134
2010-02-01,506371.06600001536
2010-03-01,699608.9910000064
2010-04-01,594609.1919999977
2010-05-01,599985.7900000075
2010-06-01,639066.5800000058
2010-07-01,591636.7400000112
2010-08-01,604242.6499999989
2010-09-01,831615.0009999905


Databricks visualization. Run in Databricks to view.

# Monthly Sales Growth

In [0]:
# Filter for numeric invoice numbers
sales_growth_df = retail.filter(retail["invoice_no"].rlike("^[0-9]+$"))

# Calculate invoice amount
sales_growth_df = sales_growth_df.withColumn("invoice_amount", col("quantity") * col("unit_price"))

# Calculate the monthly sales data
window_spec = Window.orderBy("invoice_month")
monthly_sales_growth = sales_growth_df.groupBy("invoice_month").agg(sql_sum("invoice_amount").alias("sales_amount"))
monthly_sales_growth = monthly_sales_growth.orderBy("invoice_month")

# Calculate the lagged sales data for the previous month
monthly_sales_growth = monthly_sales_growth.withColumn("previous_sales", lag("sales_amount").over(window_spec))

# Calculate the percentage change in sales
monthly_sales_growth = monthly_sales_growth.withColumn("sales_growth", ((col("sales_amount") - col("previous_sales")) / col("previous_sales")) * 100)

display(monthly_sales_growth)

invoice_month,sales_amount,previous_sales,sales_growth
2009-12-01,686654.1599999949,,
2010-01-01,557319.0620000134,686654.1599999949,-18.835551509654067
2010-02-01,506371.06600001536,557319.0620000134,-9.141620926649166
2010-03-01,699608.9910000064,506371.06600001536,38.16132831728289
2010-04-01,594609.1919999977,699608.9910000064,-15.008354716814637
2010-05-01,599985.7900000075,594609.1919999977,0.9042238284149832
2010-06-01,639066.5800000058,599985.7900000075,6.513619264215875
2010-07-01,591636.7400000112,639066.5800000058,-7.421736871296599
2010-08-01,604242.6499999989,591636.7400000112,2.130684108628456
2010-09-01,831615.0009999905,604242.6499999989,37.62931183358077


Databricks visualization. Run in Databricks to view.

# Monthly Active Users
### Compute number of active users for each month

In [0]:
# Group by invoice_month and count distinct CustomerID for each month
monthly_active_users = retail.groupBy("invoice_month").agg(countDistinct("customer_id").alias("active_users"))

# Sort the result by invoice_month in ascending order
monthly_active_users = monthly_active_users.orderBy("invoice_month")

display(monthly_active_users)

invoice_month,active_users
2009-12-01,1045
2010-01-01,786
2010-02-01,807
2010-03-01,1111
2010-04-01,998
2010-05-01,1062
2010-06-01,1095
2010-07-01,988
2010-08-01,964
2010-09-01,1202


Databricks visualization. Run in Databricks to view.

# New and Existing Users

Plot a diagram to show new and exiting user for each month.
A user is identified as a new user when he/she makes the first purchase
A user is identified as an existing user when he/she made purchases in the past

In [0]:
# Calculate the first purchase year-month for each customer
window_spec = Window.partitionBy("customer_id").orderBy("invoice_month")
first_purchase_month = retail.select("customer_id", "invoice_month").distinct().withColumn("first_purchase_month", sql_min("invoice_month").over(window_spec))

# Join the first purchase data with the retail data to identify new/existing users
retail_with_first_purchase = retail.join(first_purchase_month, ["customer_id", "invoice_month"], "left")

# Create a new column to categorize users as "New" or "Existing"
retail_with_first_purchase = retail_with_first_purchase.withColumn("user_type", when(col("invoice_month") == col("first_purchase_month"), "New").otherwise("Existing"))

# Group by invoice_month and user_type to count new and existing users for each month
user_type_counts = retail_with_first_purchase.groupBy("invoice_month", "user_type").agg(countDistinct("customer_id").alias("user_count"))

# Sort the result by invoice_month in ascending order
user_type_counts = user_type_counts.orderBy("invoice_month")

display(user_type_counts)

invoice_month,user_type,user_count
2009-12-01,New,1045
2010-01-01,New,394
2010-01-01,Existing,392
2010-02-01,Existing,444
2010-02-01,New,363
2010-03-01,New,436
2010-03-01,Existing,675
2010-04-01,New,291
2010-04-01,Existing,707
2010-05-01,New,254


Databricks visualization. Run in Databricks to view.

# Finding RFM

Note: To simplify the problem, let's keep all placed and canceled orders.

In [0]:
# Calculate Recency, Frequency, and Monetary Value for each customer
rfm_data = retail.groupBy("customer_id").agg(
            max("invoice_date").alias("Recency"),
            countDistinct("invoice_no").alias("Frequency"),
            sql_sum(expr("unit_price")).alias("MonetaryValue")
)

# Calculate Recency in days
max_date = retail.selectExpr("max(invoice_date)").collect()[0][0]
rfm_data = rfm_data.withColumn("Recency", datediff(lit(max_date), rfm_data["Recency"]))

# Create bins or segments for Recency, Frequency, and Monetary Value
quantiles = rfm_data.approxQuantile(["Recency", "Frequency", "MonetaryValue"], [0.25, 0.5, 0.75], 0.05)
rfm_segments = rfm_data.withColumn("RecencyScore", 
    expr(
        "CASE " +
        "WHEN Recency <= {} THEN 4 ".format(quantiles[0][0]) +
        "WHEN Recency <= {} THEN 3 ".format(quantiles[0][1]) +
        "WHEN Recency <= {} THEN 2 ".format(quantiles[0][2]) +
        "ELSE 1 END"
    )
).withColumn("FrequencyScore", 
    expr(
        "CASE " +
        "WHEN Frequency <= {} THEN 1 ".format(quantiles[1][0]) +
        "WHEN Frequency <= {} THEN 2 ".format(quantiles[1][1]) +
        "WHEN Frequency <= {} THEN 3 ".format(quantiles[1][2]) +
        "ELSE 4 END"
    )
).withColumn("MonetaryScore", 
    expr(
        "CASE " +
        "WHEN MonetaryValue <= {} THEN 4 ".format(quantiles[2][0]) +
        "WHEN MonetaryValue <= {} THEN 3 ".format(quantiles[2][1]) +
        "WHEN MonetaryValue <= {} THEN 2 ".format(quantiles[2][2]) +
        "ELSE 1 END"
    )
)

# Calculate RFM Score
rfm_segments = rfm_segments.withColumn(
    "RFM_Score",
    expr("concat(RecencyScore, FrequencyScore, MonetaryScore)")
)

# Show the results
display(rfm_segments)

customer_id,Recency,Frequency,MonetaryValue,RecencyScore,FrequencyScore,MonetaryScore,RFM_Score
18051,634,8,113.36,1,4,3,143
13623,30,15,1082.4999999999998,3,4,1,341
14832,630,3,920.01,1,2,1,121
17389,0,77,2447.380000000001,4,4,1,441
15447,330,6,121.92,2,3,3,233
15727,16,15,2474.36,4,4,1,441
17753,464,5,205.84,1,3,2,132
17679,52,11,291.41,3,4,2,342
13285,23,6,539.3000000000001,4,3,1,431
13289,723,1,70.25,1,1,3,113


# RFM Segmentation

In [0]:
# Define segmentation mapping
seg_map = {
    r'[1-2][1-2]': 'Hibernating',
    r'[1-2][3-4]': 'At Risk',
    r'[1-2]5': 'Can\'t Lose',
    r'3[1-2]': 'About to Sleep',
    r'33': 'Need Attention',
    r'[3-4][4-5]': 'Loyal Customers',
    r'41': 'Promising',
    r'51': 'New Customers',
    r'[4-5][2-3]': 'Potential Loyalists',
    r'5[4-5]': 'Champions'
}

# Create a new 'Segment' column by combining RecencyScore and FrequencyScore
rfm_segments = rfm_segments.withColumn("Segment", concat_ws("", col("RecencyScore"), col("FrequencyScore")))

# Replace the Segment values based on the seg_map
for pattern, segment_name in seg_map.items():
    rfm_segments = rfm_segments.withColumn("Segment", when(col("Segment").rlike(pattern), segment_name).otherwise(col("Segment")))

# Group by Segment and calculate mean and count
segmented_df = rfm_segments.groupBy("Segment").agg(
    {"Recency": "mean", "Frequency": "mean", "MonetaryValue": "mean", "Segment": "count"}
)

# Filter out rows with "nan" in the Segment column
segmented_df = segmented_df.filter(col("Segment").isNotNull())

# Rename columns for better understanding
segmented_df = segmented_df.withColumnRenamed("avg(Recency)", "RecencyMean")
segmented_df = segmented_df.withColumnRenamed("avg(Frequency)", "FrequencyMean")
segmented_df = segmented_df.withColumnRenamed("avg(MonetaryValue)", "MonetaryValueMean")
segmented_df = segmented_df.withColumnRenamed("count(Segment)", "CustomerCount")


display(segmented_df)

Segment,RecencyMean,MonetaryValueMean,FrequencyMean,CustomerCount
Promising,13.180616740088103,107.21700440528632,1.5726872246696035,227
At Risk,276.13467656415696,610.1436097560971,7.837751855779428,943
About to Sleep,53.98181818181818,124.72738363636358,1.8709090909090909,550
Hibernating,413.7672497570457,146.07566763848394,1.6141885325558796,2058
Potential Loyalists,12.02186878727634,324.47602385685883,4.7673956262425445,503
Loyal Customers,24.517241379310345,1390.2217068965506,22.22100313479624,1276
Need Attention,54.6987012987013,324.56987272727264,5.259740259740259,385
