##Calculating the Probability of Future Customer Engagement

**NOTE:** Snowpark Implementation

In non-subscription retail models, customers come and go with no long-term commitments, making it very difficult to determine whether a customer will return in the future. Determining the probability that a customer will re-engage is critical to the design of effective marketing campaigns. Different messaging and promotions may be required to incentivize customers who have likely dropped out to return to our stores. Engaged customers may be more responsive to marketing that encourages them to expand the breadth and scale of purchases with us. Understanding where our customers land with regard to the probability of future engagement is critical to tailoring our marketing efforts to them.  

The *Buy 'til You Die* (BTYD) models popularized by Peter Fader and others leverage two basic customer metrics, *i.e.* the recency of a customer's last engagement and the frequency of repeat transactions over a customer's lifetime, to derive a probability of future re-engagement. This is done by fitting customer history to curves describing the distribution of purchase frequencies and engagement drop-off following a prior purchase. The math behind these models is fairly complex but thankfully it's been encapsulated in the [lifetimes](https://pypi.org/project/Lifetimes/) library, making it much easier for traditional enterprises to employ. The purpose of this notebook is to examine how these models may be applied to customer transaction history and how they may be deployed for integration in marketing processes.

###Step 1: Setup the Environment

To run this notebook, you need to attach to a **Databricks ML Runtime** cluster leveraging Databricks version 6.5+. This version of the Databricks runtime will provide access to many of the pre-configured libraries used here.  Still, there are additional Python libraries which you will need to install and attach to your cluster.  These are:</p>

* xlrd
* lifetimes==0.10.1
* nbconvert

To install these libraries in your Databricks workspace, please follow [these steps](https://docs.databricks.com/libraries.html#workspace-libraries) using the PyPI library source in combination with the bullet-pointed library names in the provided list.  Once installed, please be sure to [attach](https://docs.databricks.com/libraries.html#install-a-library-on-a-cluster) these libraries to the cluster with which you are running this notebook.

With the libraries installed, let's load a sample dataset with which we can examine the BTYD models. The dataset we will use is the [Online Retail Data Set](http://archive.ics.uci.edu/ml/datasets/Online+Retail) available from the UCI Machine Learning Repository.  This dataset is made available as a Microsoft Excel workbook (XLSX).  Having downloaded this XLSX file to our local system, we can load it into our Databricks environment by following the steps provided [here](https://docs.databricks.com/data/tables.html#create-table-ui). Please note when performing the file import, you don't need to select the *Create Table with UI* or the *Create Table in Notebook* options to complete the import process. Also, the name of the XLSX file will be modified upon import as it includes an unsupported space character.  As a result, we will need to programmatically locate the new name for the file assigned by the import process.

Assuming we've uploaded the XLSX to the */FileStore/tables/online_retail/*, we can access it as follows:

In [0]:
from IPython.display import display, HTML, Image , Markdown

# Maximize View
display(HTML("<style>.container { width:90% !important; }</style>"))

import logging
import sys
import os ,json
import dotenv 
import pandas as pd
from snowflake.snowpark import Session
import snowflake.snowpark.functions as F
import snowflake.snowpark.types as T

logging.basicConfig(stream=sys.stdout, level=logging.CRITICAL)

#Load the snowflake login information from env file
dotenv.load_dotenv('/dbfs/FileStore/tables/sflk.env')

#Create a snowpark session
connection_parameters = {
  "account": os.getenv('DEMO_ACCOUNT'),
  "user": os.getenv('DEMO_USER'),
  "password": os.getenv('DEMO_PWD'),
  "role": "sysadmin",
  "warehouse": os.getenv('DEMO_WH'),
  "database": 'stage_db',
  "schema": 'public'
}

session = Session.builder.configs(connection_parameters).create()
print(session.sql("select current_account() ,current_warehouse(), current_database(), current_schema()").collect())

In [0]:
import pandas as pd
import numpy as np

xlsx_filename = '/dbfs/FileStore/tables/Online_Retail.xlsx'

# schema of the excel spreadsheet data range
orders_schema = {
  'InvoiceNo':str,
  'StockCode':str,
  'Description':str,
  'Quantity':np.int64,
  'InvoiceDate':np.datetime64,
  'UnitPrice':np.float64,
  'CustomerID':str,
  'Country':str  
  }

# read spreadsheet to pandas dataframe
# the xlrd library must be installed for this step to work 
orders_pd = pd.read_excel(
  xlsx_filename, 
  sheet_name='Online Retail',
  header=0, # first row is header
  dtype=orders_schema
  ,engine='openpyxl'
  )

# display first few rows from the dataset
orders_pd.head(10)

Unnamed: 0,InvoiceNo,StockCode,Description,Quantity,InvoiceDate,UnitPrice,CustomerID,Country
0,536365,85123A,WHITE HANGING HEART T-LIGHT HOLDER,6,2010-12-01 08:26:00,2.55,17850,United Kingdom
1,536365,71053,WHITE METAL LANTERN,6,2010-12-01 08:26:00,3.39,17850,United Kingdom
2,536365,84406B,CREAM CUPID HEARTS COAT HANGER,8,2010-12-01 08:26:00,2.75,17850,United Kingdom
3,536365,84029G,KNITTED UNION FLAG HOT WATER BOTTLE,6,2010-12-01 08:26:00,3.39,17850,United Kingdom
4,536365,84029E,RED WOOLLY HOTTIE WHITE HEART.,6,2010-12-01 08:26:00,3.39,17850,United Kingdom
5,536365,22752,SET 7 BABUSHKA NESTING BOXES,2,2010-12-01 08:26:00,7.65,17850,United Kingdom
6,536365,21730,GLASS STAR FROSTED T-LIGHT HOLDER,6,2010-12-01 08:26:00,4.25,17850,United Kingdom
7,536366,22633,HAND WARMER UNION JACK,6,2010-12-01 08:28:00,1.85,17850,United Kingdom
8,536366,22632,HAND WARMER RED POLKA DOT,6,2010-12-01 08:28:00,1.85,17850,United Kingdom
9,536367,84879,ASSORTED COLOUR BIRD ORNAMENT,32,2010-12-01 08:34:00,1.69,13047,United Kingdom


The data in the workbook are organized as a range in the Online Retail spreadsheet.  Each record represents a line item in a sales transaction. The fields included in the dataset are:

| Field | Description |
|-------------:|-----:|
|InvoiceNo|A 6-digit integral number uniquely assigned to each transaction|
|StockCode|A 5-digit integral number uniquely assigned to each distinct product|
|Description|The product (item) name|
|Quantity|The quantities of each product (item) per transaction|
|InvoiceDate|The invoice date and a time in mm/dd/yy hh:mm format|
|UnitPrice|The per-unit product price in pound sterling (£)|
|CustomerID| A 5-digit integral number uniquely assigned to each customer|
|Country|The name of the country where each customer resides|

Of these fields, the ones of particular interest for our work are InvoiceNo which identifies the transaction, InvoiceDate which identifies the date of that transaction, and CustomerID which uniquely identifies the customer across multiple transactions. (In a separate notebook, we will examine the monetary value of the transactions through the UnitPrice and Quantity fields.)

###Step 2: Explore the Dataset

To enable the exploration of the data using SQL statements, let's flip the pandas DataFrame into a Spark DataFrame and persist it as a temporary view:

In [0]:

t = session.create_dataframe([(1, "one"), (2, "two")], schema=["col_a", "col_b"])
t.show()
type(session)
display(orders_pd)

In [0]:
import snowflake.snowpark.functions as F
import snowflake.snowpark.types as T


type(orders_pd)
t = session.create_dataframe(orders_pd)

orders = t.select(
    F.col('\"InvoiceNo\"').cast(T.StringType()).as_('InvoiceNo')
    ,F.col('\"StockCode\"').cast(T.StringType()).as_('StockCode')
    ,F.col('\"Description\"').cast(T.StringType()).as_('Description')
    ,F.col('\"Quantity\"').cast(T.IntegerType()).as_('Quantity')
    ,F.call_builtin("to_date", F.col('\"InvoiceDate\"').cast(T.StringType())).as_('InvoiceDate') 
    ,F.col('\"UnitPrice\"').cast(T.FloatType()).as_('UnitPrice')
    ,F.col('\"CustomerID\"').cast(T.StringType()).as_('CustomerID')
    ,F.col('\"Country\"').cast(T.StringType()).as_('Country')
)


orders.write.mode("overwrite").save_as_table("orders")

display(orders.limit(5).to_pandas())

Examining the transaction activity in our dataset, we can see the first transaction occurs December 1, 2010 and the last is on December 9, 2011 making this a dataset that's a little more than 1 year in duration. The daily transaction count shows there is quite a bit of volatility in daily activity for this online retailer:

In [0]:
display(Markdown(f'unique transactions by date '))

# SELECT 
#   TO_DATE(InvoiceDate) as InvoiceDate,
#   COUNT(DISTINCT InvoiceNo) as Transactions
# FROM orders
# GROUP BY TO_DATE(InvoiceDate)
# ORDER BY InvoiceDate;

#notice the re-use of the alias invoicedate in the groupby expression
txn_by_date_df = session.sql(f'''
    SELECT 
      TO_DATE(InvoiceDate) as InvoiceDate,
      COUNT(DISTINCT InvoiceNo) as Transactions
    FROM orders
    GROUP BY InvoiceDate
    ORDER BY InvoiceDate
''')
display(txn_by_date_df.limit(5).to_pandas())

We can smooth this out a bit by summarizing activity by month. It's important to keep in mind that December 2011 only consists of 9 days so the sales decline graphed for the last month should most likely be ignored:

NOTE We will hide the SQL behind each of the following result sets for ease of viewing.  To view this code, simply click the **Show code** item above each of the following charts.

In [0]:
display(Markdown(f'unique transactions by month '))

# SELECT 
#   TRUNC(InvoiceDate, 'month') as InvoiceMonth,
#   COUNT(DISTINCT InvoiceNo) as Transactions
# FROM orders
# GROUP BY TRUNC(InvoiceDate, 'month') 
# ORDER BY InvoiceMonth;

txn_by_mon_df = session.sql(f'''
    SELECT 
      TRUNC(TO_DATE(InvoiceDate), 'month') as InvoiceMonth,
      COUNT(DISTINCT InvoiceNo) as Transactions
    FROM orders
    GROUP BY InvoiceMonth
    ORDER BY InvoiceMonth
''')

display(txn_by_mon_df.limit(5).to_pandas())

For the little more than 1-year period for which we have data, we see over four-thousand unique customers.  These customers generated about twenty-two thousand unique transactions:

In [0]:
display(Markdown(f'unique customers and transactions '))

# SELECT
#  COUNT(DISTINCT CustomerID) as Customers,
#  COUNT(DISTINCT InvoiceNo) as Transactions
# FROM orders
# WHERE CustomerID IS NOT NULL;

cust_to_txn_df = session.sql(f'''
    SELECT
     COUNT(DISTINCT CustomerID) as Customers,
     COUNT(DISTINCT InvoiceNo) as Transactions
    FROM orders
    WHERE CustomerID IS NOT NULL
''')
display(cust_to_txn_df.limit(5).to_pandas())

A little quick math may lead us to estimate that, on average, each customer is responsible for about 5 transactions, but this would not provide an accurate representation of customer activity.

Instead, if we count the unique transactions by customer and then examine the frequency of these values, we see that many of the customers have engaged in a single transaction. The distribution of the count of repeat purchases declines from there in a manner that we may describe as negative binomial distribution (which is the basis of the NBD acronym included in the name of most BTYD models):

In [0]:
display(Markdown(f'the distribution of per-customer transaction counts '))

# SELECT
#   x.Transactions,
#   COUNT(x.*) as Occurrences
# FROM (
#   SELECT
#     CustomerID,
#     COUNT(DISTINCT InvoiceNo) as Transactions 
#   FROM orders
#   WHERE CustomerID IS NOT NULL
#   GROUP BY CustomerID
#   ) x
# GROUP BY 
#   x.Transactions
# ORDER BY
#   x.Transactions;
    
t_df = session.sql('''
SELECT
  x.Transactions,
  COUNT(x.*) as Occurrences
FROM (
  SELECT
    CustomerID,
    COUNT(DISTINCT InvoiceNo) as Transactions 
  FROM orders
  WHERE CustomerID IS NOT NULL
  GROUP BY CustomerID
  ) x
GROUP BY 
  x.Transactions
ORDER BY
  x.Transactions
''')
display(t.limit(5).to_pandas())

If we alter our last analysis to group a customer's transactions that occur on the same date into a single transaction - a pattern that aligns with metrics we will calculate later - we see that a few more customers are identified as non-repeat customers but the overall pattern remains the same:

In [0]:
display(Markdown(f'''
    the distribution of per-customer transaction counts with consideration of same-day transactions as a single transaction 
'''))

# SELECT
#   x.Transactions,
#   COUNT(x.*) as Occurances
# FROM (
#   SELECT
#     CustomerID,
#     COUNT(DISTINCT TO_DATE(InvoiceDate)) as Transactions
#   FROM orders
#   WHERE CustomerID IS NOT NULL
#   GROUP BY CustomerID
#   ) x
# GROUP BY 
#   x.Transactions
# ORDER BY
#   x.Transactions;

    
t_df = session.sql('''
SELECT
  x.Transactions,
  COUNT(x.*) as Occurances
FROM (
  SELECT
    CustomerID,
    COUNT(DISTINCT TO_DATE(InvoiceDate)) as Transactions
  FROM orders
  WHERE CustomerID IS NOT NULL
  GROUP BY CustomerID
  ) x
GROUP BY 
  x.Transactions
ORDER BY
  x.Transactions
''')
display(t.limit(5).to_pandas())

Focusing on customers with repeat purchases, we can examine the distribution of the days between purchase events. What's important to note here is that most customers return to the site within 2 to 3 months of a prior purchase.  Longer gaps do occur but significantly fewer customers have longer gaps between returns.  This is important to understand in the context of our BYTD models in that the time since we last saw a customer is a critical factor to determining whether they will ever come back with the probability of return dropping as more and more time passes since a customer's last purchase event:

In [0]:
display(Markdown(f'''
    distribution of per-customer average number of days between purchase events
'''))

# WITH CustomerPurchaseDates
#   AS (
#     SELECT DISTINCT
#       CustomerID,
#       TO_DATE(InvoiceDate) as InvoiceDate
#     FROM orders 
#     WHERE CustomerId IS NOT NULL
#     )
# SELECT -- Per-Customer Average Days Between Purchase Events
#   AVG(
#     DATEDIFF(a.NextInvoiceDate, a.InvoiceDate)
#     ) as AvgDaysBetween
# FROM ( -- Purchase Event and Next Purchase Event by Customer
#   SELECT 
#     x.CustomerID,
#     x.InvoiceDate,
#     MIN(y.InvoiceDate) as NextInvoiceDate
#   FROM CustomerPurchaseDates x
#   INNER JOIN CustomerPurchaseDates y
#     ON x.CustomerID=y.CustomerID AND x.InvoiceDate < y.InvoiceDate
#   GROUP BY 
#     x.CustomerID,
#     x.InvoiceDate
#     ) a
# GROUP BY CustomerID

t_df = session.sql('''
WITH CustomerPurchaseDates
  AS (
    SELECT DISTINCT
      CustomerID,
      TO_DATE(InvoiceDate) as InvoiceDate
    FROM orders 
    WHERE CustomerId IS NOT NULL
    )
SELECT -- Per-Customer Average Days Between Purchase Events
  AVG(
    DATEDIFF(a.NextInvoiceDate, a.InvoiceDate)
    ) as AvgDaysBetween
FROM ( -- Purchase Event and Next Purchase Event by Customer
  SELECT 
    x.CustomerID,
    x.InvoiceDate,
    MIN(y.InvoiceDate) as NextInvoiceDate
  FROM CustomerPurchaseDates x
  INNER JOIN CustomerPurchaseDates y
    ON x.CustomerID=y.CustomerID AND x.InvoiceDate < y.InvoiceDate
  GROUP BY 
    x.CustomerID,
    x.InvoiceDate
    ) a
GROUP BY CustomerID
''')
display(t.limit(5).to_pandas())

###Step 3: Calculate Customer Metrics

The dataset with which we are working consists of raw transactional history.  To apply the BTYD models, we need to derive several per-customer metrics:</p>

* **Frequency** - the number of dates on which a customer made a purchase subsequent to the date of the customer's first purchase
* **Age (T)** - the number of time units, *e.g.* days, since the date of a customer's first purchase to the current date (or last date in the dataset)
* **Recency** - the age of the customer (as previously defined) at the time of their last purchase

It's important to note that when calculating metrics such as customer age that we need to consider when our dataset terminates.  Calculating these metrics relative to today's date can lead to erroneous results.  Given this, we will identify the last date in the dataset and define that as *today's date* for all calculations.

To get started with these calculations, let's take a look at how they are performed using the built-in functionality of the lifetimes library:

In [0]:
import lifetimes

# set the last transaction date as the end point for this historical dataset
current_date = orders_pd['InvoiceDate'].max()

# calculate the required customer metrics
metrics_pd = (
  lifetimes.utils.summary_data_from_transaction_data(
    orders_pd,
    customer_id_col='CustomerID',
    datetime_col='InvoiceDate',
    observation_period_end = current_date, 
    freq='D'
    )
  )

# display first few rows
metrics_pd.head(10)

Unnamed: 0_level_0,frequency,recency,T
CustomerID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
12346,0.0,0.0,325.0
12347,6.0,365.0,367.0
12348,3.0,283.0,358.0
12349,0.0,0.0,18.0
12350,0.0,0.0,310.0
12352,6.0,260.0,296.0
12353,0.0,0.0,204.0
12354,0.0,0.0,232.0
12355,0.0,0.0,214.0
12356,2.0,303.0,325.0


The lifetimes library, like many Python libraries, is single-threaded.  Using this library to derive customer metrics on larger transactional datasets may overwhelm your system or simply take too long to complete. For this reason, let's examine how these metrics can be calculated using the distributed capabilities of Apache Spark.

As SQL is frequency employed for complex data manipulation, we'll start with a Spark SQL statement.  In this statement, we first assemble each customer's order history consisting of the customer's ID, the date of their first purchase (first_at), the date on which a purchase was observed (transaction_at) and the current date (using the last date in the dataset for this value).  From this history, we can count the number of repeat transaction dates (frequency), the days between the last and first transaction dates (recency), and the days between the current date and first transaction (T) on a per-customer basis:

In [0]:
# sql statement to derive summary customer stats

#display(Markdown(f''' '''))
# SELECT
#     a.customerid as CustomerID,
#     CAST(COUNT(DISTINCT a.transaction_at) - 1 as float) as frequency,
#     CAST(DATEDIFF(MAX(a.transaction_at), a.first_at) as float) as recency,
#     CAST(DATEDIFF(a.current_dt, a.first_at) as float) as T
#   FROM ( -- customer order history
#     SELECT DISTINCT
#       x.customerid,
#       z.first_at,
#       TO_DATE(x.invoicedate) as transaction_at,
#       y.current_dt
#     FROM orders x
#     CROSS JOIN (SELECT MAX(TO_DATE(invoicedate)) as current_dt FROM orders) y                                -- current date (according to dataset)
#     INNER JOIN (SELECT customerid, MIN(TO_DATE(invoicedate)) as first_at FROM orders GROUP BY customerid) z  -- first order per customer
#       ON x.customerid=z.customerid
#     WHERE x.customerid IS NOT NULL
#     ) a
#   GROUP BY a.customerid, a.current_dt, a.first_at
#   ORDER BY CustomerID

display(Markdown(f''' 
**NOTE:** Changes made:
 - DATEDIFF
'''))

sql = '''
  SELECT
    a.customerid as CustomerID,
    CAST(COUNT(DISTINCT a.transaction_at) - 1 as float) as frequency,
    CAST(DATEDIFF(day ,MAX(a.transaction_at), a.first_at) as float) as recency,
    CAST(DATEDIFF(day ,a.current_dt, a.first_at) as float) as T
  FROM ( -- customer order history
    SELECT DISTINCT
      x.customerid,
      z.first_at,
      TO_DATE(x.invoicedate) as transaction_at,
      y.current_dt
    FROM orders x
    CROSS JOIN (SELECT MAX(TO_DATE(invoicedate)) as current_dt FROM orders) y                                -- current date (according to dataset)
    INNER JOIN (SELECT customerid, MIN(TO_DATE(invoicedate)) as first_at FROM orders GROUP BY customerid) z  -- first order per customer
      ON x.customerid=z.customerid
    WHERE x.customerid IS NOT NULL
    ) a
  GROUP BY a.customerid, a.current_dt, a.first_at
  ORDER BY CustomerID
  '''
# capture stats in dataframe 
metrics_sql = session.sql(sql)
display(metrics_sql.limit(5).to_pandas())

Of course, Spark SQL does not require the DataFrame to be accessed exclusively using a SQL statement.  We may derive this same result using the Programmatic SQL API which may align better with some Data Scientist's preferences.  The code in the next cell is purposely assembled to mirror the structure in the previous SQL statement for the purposes of comparison:

In [0]:
# from snowflake.snowpark import Session
# import snowflake.snowpark.functions as F
# import snowflake.snowpark.types as T
from snowflake.snowpark.functions import to_date, datediff, max, min, countDistinct, count, sum, when
from snowflake.snowpark.types import FloatType

display(Markdown(f''' 
**NOTE:** Changes made:
 - JOIN method
 - datediff function
 - agg
 - orderBy => sort
'''))

# valid customer orders
x = orders.where(orders.CustomerID.isNotNull())

# calculate last date in dataset
y = (
  orders
    .groupBy()
    .agg(max(to_date(orders.InvoiceDate)).alias('current_dt'))
  )

# calculate first transaction date by customer
z = (
  orders
    .groupBy(orders.CustomerID)
    .agg(min(to_date(orders.InvoiceDate)).alias('first_at'))
  )


# combine customer history with date info 
a = (x
    .crossJoin(y)
    .join(z, x.CustomerID==z.CustomerID, join_type='inner')
    .select(
      x.CustomerID.alias('customerid'), 
      z.first_at, 
      F.to_date(x.InvoiceDate).alias('transaction_at'), 
      y.current_dt
      )
     .distinct()
    )

# calculate relevant metrics by customer
metrics_api = (a
           .groupBy(a.customerid, a.current_dt, a.first_at)
           .agg(
             [(countDistinct(a.transaction_at)-1).cast(FloatType()).alias('frequency'),
             datediff('day', max(a.transaction_at), a.first_at).cast(FloatType()).alias('recency'),
             datediff('day', a.current_dt, a.first_at).cast(FloatType()).alias('T')
             ]
             )
           .select('customerid','frequency','recency','T')
           .sort('customerid')
          )
display(metrics_api.limit(5).to_pandas())

Let's take a moment to compare the data in these different metrics datasets, just to confirm the results are identical.  Instead of doing this record by record, let's calculate summary statistics across each dataset to verify their consistency:

NOTE You may notice means and standard deviations vary slightly in the hundred-thousandths and millionths decimal places.  This is a result of slight differences in data types between the pandas and Spark DataFrames but do not affect our results in a meaningful way.

In [0]:
# summary data from lifetimes
metrics_pd.describe()

Unnamed: 0,frequency,recency,T
count,4372.0,4372.0,4372.0
mean,3.413541,133.72301,225.304209
std,6.674343,133.000474,118.384168
min,0.0,0.0,0.0
25%,0.0,0.0,115.0
50%,1.0,98.0,253.0
75%,4.0,256.0,331.0
max,145.0,373.0,373.0


In [0]:
# summary data from SQL statement
metrics_sql.toPandas().describe()

Unnamed: 0,FREQUENCY,RECENCY,T
count,4372.0,4372.0,4372.0
mean,3.413541,-133.72301,-225.304209
std,6.674343,133.000474,118.384168
min,0.0,-373.0,-373.0
25%,0.0,-256.0,-331.0
50%,1.0,-98.0,-253.0
75%,4.0,0.0,-115.0
max,145.0,0.0,0.0


In [0]:
# summary data from pyspark.sql API
metrics_api.toPandas().describe()

Unnamed: 0,frequency,recency,T
count,4372.0,4372.0,4372.0
mean,3.413541,133.723007,225.304214
std,6.674344,133.000473,118.384171
min,0.0,0.0,0.0
25%,0.0,0.0,115.0
50%,1.0,98.0,253.0
75%,4.0,256.0,331.0
max,145.0,373.0,373.0


The metrics we've calculated represent summaries of a time series of data.  To support model validation and avoid overfitting, a common pattern with time series data is to train models on an earlier portion of the time series (known as the *calibration* period) and validate against a later portion of the time series (known as the *holdout* period). In the lifetimes library, the derivation of per customer metrics using calibration and holdout periods is done through a simple method call.  Because our dataset consists of a limited range for data, we will instruct this library method to use the last 90-days of data as the holdout period.  A simple parameter called a widget on the Databricks platform has been implemented to make the configuration of this setting easily changeable:

NOTE To change the number of days in the holdout period, look for the textbox widget by scrolling to the top of your Databricks notebook after running this next cell

In [0]:
!pip install lifetimes==0.10.1

In [0]:
from datetime import timedelta
from snowflake.snowpark.functions import lit
import lifetimes

# set the last transaction date as the end point for this historical dataset
current_date = orders_pd['InvoiceDate'].max()

# define end of calibration period
# holdout_days = int(dbutils.widgets.get('holdout days'))
holdout_days=90
calibration_end_date = current_date - timedelta(days = holdout_days )

# calculate the required customer metrics
metrics_cal_pd = (
  lifetimes.utils.calibration_and_holdout_data(
    orders_pd,
    customer_id_col='CustomerID',
    datetime_col='InvoiceDate',
    observation_period_end = current_date,
    calibration_period_end=calibration_end_date,
    freq='D'    
    )
  )

# display first few rows
metrics_cal_pd.head(10)

Unnamed: 0_level_0,frequency_cal,recency_cal,T_cal,frequency_holdout,duration_holdout
CustomerID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
12346,0.0,0.0,235.0,0.0,<90 * Days>
12347,4.0,238.0,277.0,2.0,<90 * Days>
12348,2.0,110.0,268.0,1.0,<90 * Days>
12350,0.0,0.0,220.0,0.0,<90 * Days>
12352,3.0,34.0,206.0,3.0,<90 * Days>
12353,0.0,0.0,114.0,0.0,<90 * Days>
12354,0.0,0.0,142.0,0.0,<90 * Days>
12355,0.0,0.0,124.0,0.0,<90 * Days>
12356,1.0,80.0,235.0,1.0,<90 * Days>
12358,0.0,0.0,60.0,1.0,<90 * Days>


As before, we may leverage Spark SQL to derive this same information.  Again, we'll examine this through both a SQL statement and the programmatic SQL API.

To understand the SQL statement, first recognize that it's divided into two main parts.  In the first, we calculate the core metrics, *i.e.* recency, frequency and age (T), per customer for the calibration period, much like we did in the previous query example. In the second part of the query, we calculate the number of purchase dates in the holdout customer for each customer.  This value (frequency_holdout) represents the incremental value to be added to the frequency for the calibration period (frequency_cal) when we examine a customer's entire transaction history across both calibration and holdout periods.

To simplify our logic, a common table expression (CTE) named CustomerHistory is defined at the top of the query.  This query extracts the relevant dates that make up a customer's transaction history and closely mirrors the logic at the center of the last SQL statement we examined.  The only difference is that we include the number of days in the holdout period (duration_holdout):

In [0]:

display(Markdown(f''' 
**NOTE:** Changes made:
 - getArgument is replaced with formatted string
'''))

# vsql = f'''
# WITH CustomerHistory 
#   AS (
#     SELECT  -- nesting req'ed b/c can't SELECT DISTINCT on widget parameter
#       m.*,
#       cast(getArgument('holdout days') as int) as duration_holdout
#     FROM (
#       SELECT DISTINCT
#         x.customerid,
#         z.first_at,
#         TO_DATE(x.invoicedate) as transaction_at,
#         y.current_dt
#       FROM orders x
#       CROSS JOIN (SELECT MAX(TO_DATE(invoicedate)) as current_dt FROM orders) y                                -- current date (according to dataset)
#       INNER JOIN (SELECT customerid, MIN(TO_DATE(invoicedate)) as first_at FROM orders GROUP BY customerid) z  -- first order per customer
#         ON x.customerid=z.customerid
#       WHERE x.customerid IS NOT NULL
#     ) m
#   )
#        SELECT
#         *
#     FROM CustomerHistory p
# '''

vsql = f'''
WITH CustomerHistory 
  AS (
    SELECT  -- nesting req'ed b/c can't SELECT DISTINCT on widget parameter
      m.*,
      cast({holdout_days} as int) as duration_holdout
    FROM (
      SELECT DISTINCT
        x.customerid,
        z.first_at,
        TO_DATE(x.invoicedate) as transaction_at,
        y.current_dt
      FROM orders x
      CROSS JOIN (SELECT MAX(TO_DATE(invoicedate)) as current_dt FROM orders) y                                -- current date (according to dataset)
      INNER JOIN (SELECT customerid, MIN(TO_DATE(invoicedate)) as first_at FROM orders GROUP BY customerid) z  -- first order per customer
        ON x.customerid=z.customerid
      WHERE x.customerid IS NOT NULL
    ) m
  )
       SELECT
        *
    FROM CustomerHistory p
'''

# vsql_df = spark.sql(vsql)
# display(vsql_df)

vsql_df = session.sql(vsql)
display(vsql_df.limit(5).to_pandas())

In [0]:

# sql = '''
# WITH CustomerHistory 
#   AS (
#     SELECT  -- nesting req'ed b/c can't SELECT DISTINCT on widget parameter
#       m.*,
#       -- getArgument('holdout days') as duration_holdout -- <== updated by venkat
#       cast(getArgument('holdout days') as int) as duration_holdout -- <== updated by venkat
#     FROM (
#       SELECT DISTINCT
#         x.customerid,
#         z.first_at,
#         TO_DATE(x.invoicedate) as transaction_at,
#         y.current_dt
#       FROM orders x
#       CROSS JOIN (SELECT MAX(TO_DATE(invoicedate)) as current_dt FROM orders) y                                -- current date (according to dataset)
#       INNER JOIN (SELECT customerid, MIN(TO_DATE(invoicedate)) as first_at FROM orders GROUP BY customerid) z  -- first order per customer
#         ON x.customerid=z.customerid
#       WHERE x.customerid IS NOT NULL
#     ) m
#   )
# SELECT
#     a.customerid as CustomerID,
#     a.frequency as frequency_cal,
#     a.recency as recency_cal,
#     a.T as T_cal,
#     COALESCE(b.frequency_holdout, 0.0) as frequency_holdout,
#     a.duration_holdout
# FROM ( -- CALIBRATION PERIOD CALCULATIONS
#     SELECT
#         p.customerid,
#         CAST(p.duration_holdout as float) as duration_holdout,
#         CAST(DATEDIFF(MAX(p.transaction_at), p.first_at) as float) as recency,
#         CAST(COUNT(DISTINCT p.transaction_at) - 1 as float) as frequency,
#         CAST(DATEDIFF(DATE_SUB(p.current_dt, p.duration_holdout), p.first_at) as float) as T
#     FROM CustomerHistory p
#     WHERE p.transaction_at < DATE_SUB(p.current_dt, p.duration_holdout)  -- LIMIT THIS QUERY TO DATA IN THE CALIBRATION PERIOD
#     GROUP BY p.customerid, p.first_at, p.current_dt, p.duration_holdout
#   ) a
# LEFT OUTER JOIN ( -- HOLDOUT PERIOD CALCULATIONS
#   SELECT
#     p.customerid,
#     CAST(COUNT(DISTINCT p.transaction_at) as float) as frequency_holdout
#   FROM CustomerHistory p
#   WHERE 
#     p.transaction_at >= DATE_SUB(p.current_dt, p.duration_holdout) AND  -- LIMIT THIS QUERY TO DATA IN THE HOLDOUT PERIOD
#     p.transaction_at <= p.current_dt
#   GROUP BY p.customerid
#   ) b
#   ON a.customerid=b.customerid
# ORDER BY CustomerID
# '''

display(Markdown(f''' 
**NOTE:** Changes made:
 - DATE_SUB replaced with dateadd
'''))

sql = f'''
WITH CustomerHistory 
  AS (
    SELECT  -- nesting req'ed b/c can't SELECT DISTINCT on widget parameter
      m.*,
      cast({holdout_days} as int) as duration_holdout
    FROM (
      SELECT DISTINCT
        x.customerid,
        z.first_at,
        TO_DATE(x.invoicedate) as transaction_at,
        y.current_dt
      FROM orders x
      CROSS JOIN (SELECT MAX(TO_DATE(invoicedate)) as current_dt FROM orders) y                                -- current date (according to dataset)
      INNER JOIN (SELECT customerid, MIN(TO_DATE(invoicedate)) as first_at FROM orders GROUP BY customerid) z  -- first order per customer
        ON x.customerid=z.customerid
      WHERE x.customerid IS NOT NULL
    ) m
  )

 SELECT
     a.customerid as CustomerID,
     a.frequency as frequency_cal,
     a.recency as recency_cal,
     a.T as T_cal,
     COALESCE(b.frequency_holdout, 0.0) as frequency_holdout,
     a.duration_holdout
 FROM ( -- CALIBRATION PERIOD CALCULATIONS
     SELECT
         p.customerid,
         CAST(p.duration_holdout as float) as duration_holdout,
         CAST(DATEDIFF( 'day', MAX(p.transaction_at), p.first_at) as float) as recency,
         CAST(COUNT(DISTINCT p.transaction_at) - 1 as float) as frequency,
         CAST(DATEDIFF( 'day'
            ,DATEADD( 'day', -1 * p.duration_holdout, p.current_dt ) 
            ,p.first_at) as float) as T
     FROM CustomerHistory p
     WHERE p.transaction_at < DATEADD( 'day', -1 * p.duration_holdout, p.current_dt )   -- LIMIT THIS QUERY TO DATA IN THE CALIBRATION PERIOD
     GROUP BY p.customerid, p.first_at, p.current_dt, p.duration_holdout
   ) a

LEFT OUTER JOIN ( -- HOLDOUT PERIOD CALCULATIONS
  SELECT
    p.customerid,
    CAST(COUNT(DISTINCT p.transaction_at) as float) as frequency_holdout
  FROM CustomerHistory p
  WHERE 
    p.transaction_at >= DATEADD( 'day', -1 * p.duration_holdout, p.current_dt )  AND  -- LIMIT THIS QUERY TO DATA IN THE HOLDOUT PERIOD
    p.transaction_at <= p.current_dt
  GROUP BY p.customerid
  ) b
  ON a.customerid=b.customerid
ORDER BY CustomerID
  
'''

# metrics_cal_sql = spark.sql(sql)
# display(metrics_cal_sql)

metrics_cal_sql = session.sql(sql)
display(metrics_cal_sql.limit(5).to_pandas())

And here is the equivalent Programmatic SQL API logic:

In [0]:
from snowflake.snowpark.functions import avg, dateadd, coalesce, lit ,col #, expr
# from snowflake.snowpark.types import FloatType

display(Markdown(f''' 
**NOTE:** Changes made:
 - join method
'''))

# valid customer orders
x = orders.where(orders.CustomerID.isNotNull())

# calculate last date in dataset
y = (
  orders
    .groupBy()
    .agg(max(to_date(orders.InvoiceDate)).alias('current_dt'))
  )

# calculate first transaction date by customer
z = (
  orders
    .groupBy(orders.CustomerID)
    .agg(min(to_date(orders.InvoiceDate)).alias('first_at'))
  )

# combine customer history with date info (CUSTOMER HISTORY)
p = (x
    .crossJoin(y)
    .join(z, x.CustomerID==z.CustomerID, join_type='inner')
    .withColumn('duration_holdout', lit(int(holdout_days)))
    .select(
      x.CustomerID.alias('customerid'),
      z.first_at, 
      to_date(x.InvoiceDate).alias('transaction_at'), 
      y.current_dt, 
      'duration_holdout'
      )
     .distinct()
    )

# calculate relevant metrics by customer
# note: date_sub requires a single integer value unless employed within an expr() call
a = (p
       #.where(p.transaction_at < expr('date_sub(current_dt, duration_holdout)')) 
       .where(p.transaction_at < dateadd('day', -1 * col('duration_holdout') ,col('current_dt')  ))   
       .groupBy(p.customerid, p.current_dt, p.duration_holdout, p.first_at)
       .agg( [
         (countDistinct(p.transaction_at)-1).cast(FloatType()).alias('frequency_cal'),
         #datediff( max(p.transaction_at), p.first_at).cast(FloatType()).alias('recency_cal'),
         datediff( 'day', max(p.transaction_at), p.first_at).cast(FloatType()).alias('recency_cal'),
         #datediff( expr('date_sub(current_dt, duration_holdout)'), p.first_at).cast(FloatType()).alias('T_cal')
         datediff( 'day'
             #expr('date_sub(current_dt, duration_holdout)')
             ,dateadd('day',  -1 * col('duration_holdout') ,col('current_dt')  )
             ,p.first_at).cast(FloatType()).alias('T_cal')
       ])
    )

b = (p
      #.where((p.transaction_at >= expr('date_sub(current_dt, duration_holdout)')) & (p.transaction_at <= p.current_dt) )
      .where((p.transaction_at >=  dateadd('day',  -1 * col('duration_holdout') ,col('current_dt')  ) ) & (p.transaction_at <= p.current_dt) )
      .groupBy(p.customerid)
      .agg(
        countDistinct(p.transaction_at).cast(FloatType()).alias('frequency_holdout')
        )
   )

metrics_cal_api = (a
                 .join(b, a.customerid==b.customerid, join_type='left')
                 .select(
                   a.customerid.alias('CustomerID'),
                   a.frequency_cal,
                   a.recency_cal,
                   a.T_cal,
                   coalesce(b.frequency_holdout, lit(0.0)).alias('frequency_holdout'),
                   a.duration_holdout
                   )
                 .sort('CustomerID')
              )

#display(metrics_cal_api)
display(metrics_cal_api.limit(5).to_pandas())

Using summary stats, we can again verify these different units of logic are returning the same results:

In [0]:
# summary data from lifetimes
metrics_cal_pd.describe()

Unnamed: 0,frequency_cal,recency_cal,T_cal,frequency_holdout
count,3412.0,3412.0,3412.0,3412.0
mean,2.677608,90.587046,185.041618,1.502345
std,5.222838,96.077761,80.771943,2.495318
min,0.0,0.0,1.0,0.0
25%,0.0,0.0,125.0,0.0
50%,1.0,59.5,197.0,1.0
75%,3.0,175.0,268.0,2.0
max,93.0,282.0,283.0,52.0


In [0]:
# summary data from SQL statement
metrics_cal_sql.toPandas().describe()

Unnamed: 0,FREQUENCY_CAL,RECENCY_CAL,T_CAL,FREQUENCY_HOLDOUT,DURATION_HOLDOUT
count,3412.0,3412.0,3412.0,3412.0,3412.0
mean,2.677608,-90.587046,-185.041618,1.502345,90.0
std,5.222838,96.077761,80.771943,2.495318,0.0
min,0.0,-282.0,-283.0,0.0,90.0
25%,0.0,-175.0,-268.0,0.0,90.0
50%,1.0,-59.5,-197.0,1.0,90.0
75%,3.0,0.0,-125.0,2.0,90.0
max,93.0,0.0,-1.0,52.0,90.0


In [0]:
# summary data from pyspark.sql API
metrics_cal_api.toPandas().describe()

Unnamed: 0,FREQUENCY_CAL,RECENCY_CAL,T_CAL,FREQUENCY_HOLDOUT,DURATION_HOLDOUT
count,3412.0,3412.0,3412.0,3412.0,3412.0
mean,2.677608,-90.587046,-185.041618,1.502345,90.0
std,5.222838,96.077761,80.771943,2.495318,0.0
min,0.0,-282.0,-283.0,0.0,90.0
25%,0.0,-175.0,-268.0,0.0,90.0
50%,1.0,-59.5,-197.0,1.0,90.0
75%,3.0,0.0,-125.0,2.0,90.0
max,93.0,0.0,-1.0,52.0,90.0
