# Customer retail notebook

In [13]:
# 01_data_ingestion.py
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

df = spark.read.csv("online_retail.csv", header=True, inferSchema=True)

df.printSchema()
df.show(5)
df.select("Country").distinct().show()

                                                                                

root
 |-- InvoiceNo: string (nullable = true)
 |-- StockCode: string (nullable = true)
 |-- Description: string (nullable = true)
 |-- Quantity: integer (nullable = true)
 |-- InvoiceDate: timestamp (nullable = true)
 |-- UnitPrice: double (nullable = true)
 |-- CustomerID: double (nullable = true)
 |-- Country: string (nullable = true)

+---------+---------+--------------------+--------+-------------------+---------+----------+--------------+
|InvoiceNo|StockCode|         Description|Quantity|        InvoiceDate|UnitPrice|CustomerID|       Country|
+---------+---------+--------------------+--------+-------------------+---------+----------+--------------+
|   536365|   85123A|WHITE HANGING HEA...|       6|2010-12-01 08:26:00|     2.55|   17850.0|United Kingdom|
|   536365|    71053| WHITE METAL LANTERN|       6|2010-12-01 08:26:00|     3.39|   17850.0|United Kingdom|
|   536365|   84406B|CREAM CUPID HEART...|       8|2010-12-01 08:26:00|     2.75|   17850.0|United Kingdom|
|   536365| 



+---------------+
|        Country|
+---------------+
|         Sweden|
|      Singapore|
|        Germany|
|         France|
|         Greece|
|        Belgium|
|        Finland|
|          Italy|
|           EIRE|
|      Lithuania|
|         Norway|
|          Spain|
|        Denmark|
|      Hong Kong|
|        Iceland|
|         Israel|
|Channel Islands|
|         Cyprus|
|    Switzerland|
|        Lebanon|
+---------------+
only showing top 20 rows


                                                                                

In [14]:
# 02_data_cleaning.py

from pyspark.sql.functions import col

# Drop nulls
df_clean = df.dropna(subset=["InvoiceNo", "CustomerID", "InvoiceDate", "UnitPrice", "Quantity"])

# Filter out refunds/negative quantities
df_clean = df_clean.filter(col("Quantity") > 0)

# Remove duplicates
df_clean = df_clean.dropDuplicates()

df_clean.cache().count()

                                                                                

392732

In [17]:
# 03_rfm_segmentation.py

from pyspark.sql.functions import max, datediff, sum as _sum, count as _count
from pyspark.sql.functions import lit

snapshot_date = df_clean.agg(max("InvoiceDate")).collect()[0][0]

rfm = df_clean.groupBy("CustomerID").agg(
    datediff(lit(snapshot_date), max("InvoiceDate")).alias("Recency"),
    _count("InvoiceNo").alias("Frequency"),
    _sum(col("UnitPrice") * col("Quantity")).alias("Monetary")
)

rfm.show(5)



+----------+-------+---------+-----------------+
|CustomerID|Recency|Frequency|         Monetary|
+----------+-------+---------+-----------------+
|   17884.0|      3|      113|695.0699999999999|
|   13121.0|    269|       50|283.7300000000001|
|   12550.0|     79|       57|964.8299999999996|
|   15898.0|      1|       84|          1383.83|
|   17392.0|    306|       56|           417.27|
+----------+-------+---------+-----------------+
only showing top 5 rows


                                                                                

In [18]:
# 04_churn_prediction.py

from pyspark.ml.feature import VectorAssembler
from pyspark.ml.classification import LogisticRegression
from pyspark.ml import Pipeline

# Create churn label (example: customers with no purchases in last N days)
rfm = rfm.withColumn("Churn", (rfm["Recency"] > 90).cast("integer"))

features = ["Recency", "Frequency", "Monetary"]
assembler = VectorAssembler(inputCols=features, outputCol="features")

lr = LogisticRegression(featuresCol="features", labelCol="Churn")

pipeline = Pipeline(stages=[assembler, lr])

train, test = rfm.randomSplit([0.8, 0.2])
model = pipeline.fit(train)

predictions = model.transform(test)
predictions.select("Churn", "prediction", "probability").show(10)

25/11/27 16:30:07 WARN InstanceBuilder: Failed to load implementation from:dev.ludovic.netlib.blas.JNIBLAS

+-----+----------+-----------+
|Churn|prediction|probability|
+-----+----------+-----------+
|    0|       0.0|  [1.0,0.0]|
|    0|       0.0|  [1.0,0.0]|
|    0|       0.0|  [1.0,0.0]|
|    1|       1.0|  [0.0,1.0]|
|    0|       0.0|  [1.0,0.0]|
|    1|       1.0|  [0.0,1.0]|
|    0|       0.0|  [1.0,0.0]|
|    0|       0.0|  [1.0,0.0]|
|    1|       1.0|  [0.0,1.0]|
|    0|       0.0|  [1.0,0.0]|
+-----+----------+-----------+
only showing top 10 rows


                                                                                