In [0]:
from pyspark.sql.functions import (
    col,
    countDistinct,
    current_date,
    datediff,
    lit,
    max as spark_max,
    sum as spark_sum,
    when,
)

# Load the Silver-layer online-retail transactions; the Gold aggregations below are built from silver_df.
silver_df = spark.read.table("ecommerce.silver.silver_online_retail")

display(silver_df)


InvoiceNo,StockCode,Description,Quantity,InvoiceDate,UnitPrice,CustomerID,Country,total_amount
536365,85123A,WHITE HANGING HEART T-LIGHT HOLDER,6,2010-12-01T08:26:00.000Z,2.55,17850,United Kingdom,15.3
536365,71053,WHITE METAL LANTERN,6,2010-12-01T08:26:00.000Z,3.39,17850,United Kingdom,20.34
536365,84406B,CREAM CUPID HEARTS COAT HANGER,8,2010-12-01T08:26:00.000Z,2.75,17850,United Kingdom,22.0
536365,84029G,KNITTED UNION FLAG HOT WATER BOTTLE,6,2010-12-01T08:26:00.000Z,3.39,17850,United Kingdom,20.34
536365,84029E,RED WOOLLY HOTTIE WHITE HEART.,6,2010-12-01T08:26:00.000Z,3.39,17850,United Kingdom,20.34
536365,22752,SET 7 BABUSHKA NESTING BOXES,2,2010-12-01T08:26:00.000Z,7.65,17850,United Kingdom,15.3
536365,21730,GLASS STAR FROSTED T-LIGHT HOLDER,6,2010-12-01T08:26:00.000Z,4.25,17850,United Kingdom,25.5
536366,22633,HAND WARMER UNION JACK,6,2010-12-01T08:28:00.000Z,1.85,17850,United Kingdom,11.1
536366,22632,HAND WARMER RED POLKA DOT,6,2010-12-01T08:28:00.000Z,1.85,17850,United Kingdom,11.1
536367,84879,ASSORTED COLOUR BIRD ORNAMENT,32,2010-12-01T08:34:00.000Z,1.69,13047,United Kingdom,54.08


In [0]:
# Get dataset reference date.
# Reference date = latest InvoiceDate in the Silver data. Recency is measured
# against this fixed point rather than today's date, so re-running the notebook
# on this historical dataset always yields the same days_inactive values.
reference_date = silver_df.select(
    spark_max("InvoiceDate").alias("ref_date")
).first()["ref_date"]

# max() over an empty or all-null column returns None; passed through lit() that
# would make every days_inactive null and every churn label fall to 0 downstream.
assert reference_date is not None, "silver_online_retail has no InvoiceDate values"

reference_date


datetime.datetime(2011, 12, 9, 12, 50)

In [0]:
# CUSTOMER-LEVEL AGGREGATION
# One row per customer: total spend, distinct orders, last purchase timestamp,
# days since that purchase (relative to reference_date) and a binary churn label.

# A customer is labelled churned (1) after MORE than this many days without a purchase.
CHURN_THRESHOLD_DAYS = 90

gold_customer_df = (
    silver_df
    # Rows without a CustomerID cannot be attributed to anyone; grouping them would
    # create one synthetic "null" customer. No-op if Silver already removes them.
    .filter(col("CustomerID").isNotNull())
    .groupBy("CustomerID")
    .agg(
        spark_sum("total_amount").alias("total_spend"),
        countDistinct("InvoiceNo").alias("total_orders"),
        spark_max("InvoiceDate").alias("last_purchase_date"),
    )
    # Use the dataset's reference date instead of current_date(): the data is
    # historical, so measuring against today would label every customer as churned.
    .withColumn(
        "days_inactive",
        datediff(lit(reference_date), col("last_purchase_date")),
    )
    .withColumn(
        "churn",
        when(col("days_inactive") > CHURN_THRESHOLD_DAYS, 1).otherwise(0),
    )
)

display(gold_customer_df)


CustomerID,total_spend,total_orders,last_purchase_date,days_inactive,churn
17850,5391.209999999999,34,2010-12-02T15:27:00.000Z,372,1
13047,3237.54,10,2011-11-08T12:06:00.000Z,31,0
12583,7281.38,15,2011-12-07T08:07:00.000Z,2,0
13748,948.25,5,2011-09-05T09:45:00.000Z,95,1
15100,876.0,3,2011-01-10T10:35:00.000Z,333,1
15291,4668.3,15,2011-11-14T11:02:00.000Z,25,0
14688,5630.869999999999,21,2011-12-02T12:26:00.000Z,7,0
17809,5411.910000000001,12,2011-11-23T12:59:00.000Z,16,0
15311,60767.9,91,2011-12-09T12:00:00.000Z,0,0
16098,2005.63,7,2011-09-13T09:59:00.000Z,87,0


In [0]:
%sql
-- Make sure the Gold schema exists (no error if it already does) before
-- ecommerce.gold.gold_customer_metrics is saved into it.
CREATE SCHEMA IF NOT EXISTS
  ecommerce.gold;

In [0]:
# WRITE GOLD CUSTOMER TABLE
# Persist the customer-level metrics as a Delta table, replacing whatever a
# previous run of this notebook wrote there.
(
    gold_customer_df.write
    .format("delta")
    .mode("overwrite")
    .saveAsTable("ecommerce.gold.gold_customer_metrics")
)

In [0]:
%sql
-- Class balance of the churn label in the Gold table (1 = churned, 0 = active).
-- ORDER BY makes the rendered result stable across runs; GROUP BY alone guarantees no order.
SELECT churn, COUNT(*) AS customer_count
FROM ecommerce.gold.gold_customer_metrics
GROUP BY churn
ORDER BY churn;

churn,COUNT(*)
0,2890
1,1449
