In [1]:
from google.colab import drive         #掛載我的雲端硬碟上去
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:
# Run below commands
!apt-get install openjdk-8-jdk-headless -qq > /dev/null
!wget -q https://archive.apache.org/dist/spark/spark-3.0.0/spark-3.0.0-bin-hadoop3.2.tgz
!tar xf spark-3.0.0-bin-hadoop3.2.tgz
!pip install -q findspark

In [3]:
import os
os.environ["JAVA_HOME"] = "/usr/lib/jvm/java-8-openjdk-amd64"
os.environ["SPARK_HOME"] = "/content/spark-3.0.0-bin-hadoop3.2"

In [4]:
import findspark
findspark.init()

import pandas as pd
import plotly.graph_objs as go
from plotly.offline import download_plotlyjs, init_notebook_mode, plot, iplot
init_notebook_mode(connected=True)

import pyspark

from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.types import *

In [5]:
spark = SparkSession \
    .builder \
    .appName("RFM Analysis with PySpark") \
    .getOrCreate()

In [6]:
spark

In [8]:
data = spark.read.format("csv").option("header", "true").load("/content/drive/MyDrive/data.csv")

In [9]:
data

DataFrame[InvoiceNo: string, StockCode: string, Description: string, Quantity: string, InvoiceDate: string, UnitPrice: string, CustomerID: string, Country: string]

In [10]:
data.columns

['InvoiceNo',
 'StockCode',
 'Description',
 'Quantity',
 'InvoiceDate',
 'UnitPrice',
 'CustomerID',
 'Country']

In [11]:
data.printSchema()

root
 |-- InvoiceNo: string (nullable = true)
 |-- StockCode: string (nullable = true)
 |-- Description: string (nullable = true)
 |-- Quantity: string (nullable = true)
 |-- InvoiceDate: string (nullable = true)
 |-- UnitPrice: string (nullable = true)
 |-- CustomerID: string (nullable = true)
 |-- Country: string (nullable = true)



In [12]:
# cache data in memory
data.cache().count()

541909

In [13]:
data.show(5)

+---------+---------+--------------------+--------+--------------+---------+----------+--------------+
|InvoiceNo|StockCode|         Description|Quantity|   InvoiceDate|UnitPrice|CustomerID|       Country|
+---------+---------+--------------------+--------+--------------+---------+----------+--------------+
|   536365|   85123A|WHITE HANGING HEA...|       6|12/1/2010 8:26|     2.55|     17850|United Kingdom|
|   536365|    71053| WHITE METAL LANTERN|       6|12/1/2010 8:26|     3.39|     17850|United Kingdom|
|   536365|   84406B|CREAM CUPID HEART...|       8|12/1/2010 8:26|     2.75|     17850|United Kingdom|
|   536365|   84029G|KNITTED UNION FLA...|       6|12/1/2010 8:26|     3.39|     17850|United Kingdom|
|   536365|   84029E|RED WOOLLY HOTTIE...|       6|12/1/2010 8:26|     3.39|     17850|United Kingdom|
+---------+---------+--------------------+--------+--------------+---------+----------+--------------+
only showing top 5 rows



# 1 Data Pre-Processing

In [14]:
data = data.withColumn("Quantity", data["Quantity"].cast(IntegerType()))   #字串型態轉成數值
data = data.withColumn("UnitPrice", data["UnitPrice"].cast(DoubleType()))

In [15]:
# define Total column     #創立新欄位
data = data.withColumn("Total", round(data["UnitPrice"] * data["Quantity"], 2))

In [17]:
data.show(5)

+---------+---------+--------------------+--------+--------------+---------+----------+--------------+-----+
|InvoiceNo|StockCode|         Description|Quantity|   InvoiceDate|UnitPrice|CustomerID|       Country|Total|
+---------+---------+--------------------+--------+--------------+---------+----------+--------------+-----+
|   536365|   85123A|WHITE HANGING HEA...|       6|12/1/2010 8:26|     2.55|     17850|United Kingdom| 15.3|
|   536365|    71053| WHITE METAL LANTERN|       6|12/1/2010 8:26|     3.39|     17850|United Kingdom|20.34|
|   536365|   84406B|CREAM CUPID HEART...|       8|12/1/2010 8:26|     2.75|     17850|United Kingdom| 22.0|
|   536365|   84029G|KNITTED UNION FLA...|       6|12/1/2010 8:26|     3.39|     17850|United Kingdom|20.34|
|   536365|   84029E|RED WOOLLY HOTTIE...|       6|12/1/2010 8:26|     3.39|     17850|United Kingdom|20.34|
+---------+---------+--------------------+--------+--------------+---------+----------+--------------+-----+
only showing top 5 

In [18]:
from pyspark.sql import functions as F   #時間轉換
data = data.withColumn("date", F.to_date(F.to_timestamp(col("InvoiceDate"), "M/d/yyyy H:mm")))

In [19]:
data.show(5)

+---------+---------+--------------------+--------+--------------+---------+----------+--------------+-----+----------+
|InvoiceNo|StockCode|         Description|Quantity|   InvoiceDate|UnitPrice|CustomerID|       Country|Total|      date|
+---------+---------+--------------------+--------+--------------+---------+----------+--------------+-----+----------+
|   536365|   85123A|WHITE HANGING HEA...|       6|12/1/2010 8:26|     2.55|     17850|United Kingdom| 15.3|2010-12-01|
|   536365|    71053| WHITE METAL LANTERN|       6|12/1/2010 8:26|     3.39|     17850|United Kingdom|20.34|2010-12-01|
|   536365|   84406B|CREAM CUPID HEART...|       8|12/1/2010 8:26|     2.75|     17850|United Kingdom| 22.0|2010-12-01|
|   536365|   84029G|KNITTED UNION FLA...|       6|12/1/2010 8:26|     3.39|     17850|United Kingdom|20.34|2010-12-01|
|   536365|   84029E|RED WOOLLY HOTTIE...|       6|12/1/2010 8:26|     3.39|     17850|United Kingdom|20.34|2010-12-01|
+---------+---------+-------------------

In [20]:
# calculate difference in days between 2011-12-31 and the Invoice Date
data = data.withColumn("RecencyDays", expr("datediff('2011-12-31', Date)"))  #跟指定日期做比較

In [21]:
data.show(5)

+---------+---------+--------------------+--------+--------------+---------+----------+--------------+-----+----------+-----------+
|InvoiceNo|StockCode|         Description|Quantity|   InvoiceDate|UnitPrice|CustomerID|       Country|Total|      date|RecencyDays|
+---------+---------+--------------------+--------+--------------+---------+----------+--------------+-----+----------+-----------+
|   536365|   85123A|WHITE HANGING HEA...|       6|12/1/2010 8:26|     2.55|     17850|United Kingdom| 15.3|2010-12-01|        395|
|   536365|    71053| WHITE METAL LANTERN|       6|12/1/2010 8:26|     3.39|     17850|United Kingdom|20.34|2010-12-01|        395|
|   536365|   84406B|CREAM CUPID HEART...|       8|12/1/2010 8:26|     2.75|     17850|United Kingdom| 22.0|2010-12-01|        395|
|   536365|   84029G|KNITTED UNION FLA...|       6|12/1/2010 8:26|     3.39|     17850|United Kingdom|20.34|2010-12-01|        395|
|   536365|   84029E|RED WOOLLY HOTTIE...|       6|12/1/2010 8:26|     3.39|

# 2 Create RFM Table

In [22]:
# Creation of RFM table

rfm_table = data.groupBy("CustomerId").agg(min("RecencyDays").alias("Recency"),
                    count("InvoiceNo").alias("Frequency"),
                    sum("Total").alias("Monetary"))

In [23]:
rfm_table = rfm_table.withColumn("Monetary", round(rfm_table["Monetary"], 2))

In [24]:
rfm_table.printSchema()

root
 |-- CustomerId: string (nullable = true)
 |-- Recency: integer (nullable = true)
 |-- Frequency: long (nullable = false)
 |-- Monetary: double (nullable = true)



In [25]:
rfm_table.show(5)

+----------+-------+---------+--------+
|CustomerId|Recency|Frequency|Monetary|
+----------+-------+---------+--------+
|     16250|    283|       24|  389.44|
|     15574|    199|      168|  702.25|
|     15555|     34|      925|  4758.2|
|     15271|     29|      275| 2485.82|
|     17714|    342|       10|   153.0|
+----------+-------+---------+--------+
only showing top 5 rows



In [26]:
rfm_table.cache().count()

4373

# 3 Computing Quartiles of RFM values

In [27]:
r_quartile = rfm_table.approxQuantile("Recency", [0.25, 0.5, 0.75], 0)
f_quartile = rfm_table.approxQuantile("Frequency", [0.25, 0.5, 0.75], 0)
m_quartile = rfm_table.approxQuantile("Monetary", [0.25, 0.5, 0.75], 0)

In [29]:
r_quartile

[38.0, 72.0, 165.0]

In [28]:
# calculate Recency based on quartile

rfm_table = rfm_table.withColumn("R_Quartile", \
                                 when(col("Recency") >= r_quartile[2] , 1).\
                                 when(col("Recency") >= r_quartile[1] , 2).\
                                 when(col("Recency") >= r_quartile[0] , 3).\
                                 otherwise(4))

In [30]:
# calculate Frequency based on quartile

rfm_table = rfm_table.withColumn("F_Quartile", \
                                 when(col("Frequency") > f_quartile[2] , 4).\
                                 when(col("Frequency") > f_quartile[1] , 3).\
                                 when(col("Frequency") > f_quartile[0] , 2).\
                                 otherwise(1))

In [31]:
# calculate Monetary based on quartile

rfm_table = rfm_table.withColumn("M_Quartile", \
                                 when(col("Monetary") >= m_quartile[2] , 4).\
                                 when(col("Monetary") >= m_quartile[1] , 3).\
                                 when(col("Monetary") >= m_quartile[0] , 2).\
                                 otherwise(1))

In [32]:
# combine the scores (R_Quartile, F_Quartile,M_Quartile) together.

rfm_table = rfm_table.withColumn("RFM_Score", concat(col("R_Quartile"), col("F_Quartile"), col("M_Quartile")))

In [33]:
rfm_table.show(10)

+----------+-------+---------+--------+----------+----------+----------+---------+
|CustomerId|Recency|Frequency|Monetary|R_Quartile|F_Quartile|M_Quartile|RFM_Score|
+----------+-------+---------+--------+----------+----------+----------+---------+
|     16250|    283|       24|  389.44|         1|         2|         2|      122|
|     15574|    199|      168|  702.25|         1|         4|         3|      143|
|     15555|     34|      925|  4758.2|         4|         4|         4|      444|
|     15271|     29|      275| 2485.82|         4|         4|         4|      444|
|     17714|    342|       10|   153.0|         1|         1|         1|      111|
|     17686|     29|      286| 5739.46|         4|         4|         4|      444|
|     13865|     80|       30|  501.56|         2|         2|         2|      222|
|     14157|     41|       49|  400.43|         3|         3|         2|      332|
|     13610|     34|      228| 1115.43|         4|         4|         3|      443|
|   

# 4 RFM Analysis

In [34]:
# Best customers

rfm_table.select("CustomerID").where("RFM_Score == 444").show(10)

+----------+
|CustomerID|
+----------+
|     15555|
|     15271|
|     17686|
|     17757|
|     16549|
|     13985|
|     14525|
|     18283|
|     12957|
|     17491|
+----------+
only showing top 10 rows



In [35]:
# group by RFM Score

grouped_by_rfmscore = rfm_table.groupBy("RFM_Score").count().orderBy("count", ascending=False)

In [36]:
grouped_by_rfmscore.show()

+---------+-----+
|RFM_Score|count|
+---------+-----+
|      444|  442|
|      111|  395|
|      344|  234|
|      122|  210|
|      333|  187|
|      211|  183|
|      222|  175|
|      233|  168|
|      433|  141|
|      322|  128|
|      311|  128|
|      121|  107|
|      244|  105|
|      112|  104|
|      223|   96|
|      343|   93|
|      212|   83|
|      443|   80|
|      332|   75|
|      422|   72|
+---------+-----+
only showing top 20 rows



In [37]:
# convert Spark dataframe to pandas in order to visualize data

grouped_by_rfmscore_pandas = grouped_by_rfmscore.toPandas()

In [38]:
grouped_by_rfmscore_pandas

Unnamed: 0,RFM_Score,count
0,444,442
1,111,395
2,344,234
3,122,210
4,333,187
...,...,...
57,124,6
58,142,4
59,414,3
60,441,1
