Beta-geometric/Negative Binomial Distribution Model (BG/NBD) - Forecasting Individual-Level Repeat-Buying

In [14]:
import polars as pl
import numpy as np
import altair as alt
from scipy.optimize import minimize

In [23]:
class ChartTemp(alt.Chart):
    """``alt.Chart`` subclass that bundles the line-chart styling used in this notebook.

    ``line_encode`` adds the line mark and the x / y / color encodings;
    ``line_prop`` applies the shared size, title and view/axis configuration.
    """

    def __init__(self, data, **kwargs):
        """Create the chart from ``data``; extra keyword arguments go to ``alt.Chart``."""
        super().__init__(data=data, **kwargs)

    def line_encode(self,
                    y_col,
                    y_scale=None,
                    y_title='Cumulative Sales (# Transactions)',
                    dash=None,
                    color=None,
                    x_col='Week',
                    x_title='Week',
                    x_range=52):
        """Add a line mark with x, y and color encodings.

        Parameters
        ----------
        y_col : field (or shorthand) encoded on the y axis.
        y_scale : scale for the y axis; ``None`` (default) uses a fresh ``alt.Scale()``.
        y_title : y-axis title.
        dash : value passed as ``strokeDash``; ``None`` (default) uses ``[1, 0]``.
        color : color encoding; ``None`` (default) uses a fresh ``alt.Color()``.
        x_col : field (or shorthand) encoded on the x axis.
        x_title : x-axis title.
        x_range : upper bound of the x domain ``[0, x_range]``; axis values are
            placed every 4 units from 0 up to ``x_range`` inclusive.

        Returns
        -------
        The chart produced by ``self.mark_line(...).encode(...)``.
        """
        # Resolve defaults per call: default arguments are evaluated only once,
        # so a list / Scale / Color default would be shared by every call.
        if y_scale is None:
            y_scale = alt.Scale()
        if dash is None:
            dash = [1, 0]
        if color is None:
            color = alt.Color()

        line = self.mark_line(strokeWidth=2, strokeDash=dash).encode(
            x=alt.X(x_col,
                    scale=alt.Scale(domain=[0, x_range]),
                    axis=alt.Axis(
                        values=np.arange(0, x_range+1, 4),
                        labelExpr="datum.value",
                        title=x_title)
            ),
            y=alt.Y(y_col,
                    title=y_title,
                    scale=y_scale
            ),
            color=color
        )
        return line

    def line_prop(self, title):
        """Apply the shared 650x250 size and ``title``, and hide the view stroke and both axis grids."""
        line = self.properties(
            width=650,
            height=250,
            title=title
        ).configure_view(stroke=None).configure_axisY(grid=False).configure_axisX(grid=False)

        return line

def layered_line_prop(chart, title):
    """Apply the shared 650x250 size, ``title`` and view/axis configuration to ``chart``.

    Calls, in order, ``properties``, ``configure_view(stroke=None)``,
    ``configure_axisY(grid=False)`` and ``configure_axisX(grid=False)`` on
    ``chart`` and returns the result of the last call.
    """
    sized = chart.properties(width=650, height=250, title=title)
    # Remove the view stroke, then switch off the grid on each axis.
    unframed = sized.configure_view(stroke=None)
    without_y_grid = unframed.configure_axisY(grid=False)
    return without_y_grid.configure_axisX(grid=False)

In [16]:
# Lazily load the master transaction log and derive the per-transaction helper columns:
#   PurchDay     - whole days elapsed since 1996-12-31 (so 1997-01-01 is day 1)
#   Spend Scaled - Spend * 100, rounded and stored as an integer
#   DoR          - 0-based position of the purchase date within each customer's
#                  date-sorted history (0 = that customer's first purchase date)
CDNOW_master = (
    pl.scan_csv(
        source='data/CDNOW/CDNOW_master.csv',
        has_header=False,
        separator=',',
        schema={
            'CustID': pl.Int32,    # customer id
            'Date': pl.String,     # transaction date
            'Quant': pl.Int16,     # number of CDs purchased
            'Spend': pl.Float64,   # dollar value (excl. S&H)
        },
    )
    .with_columns(pl.col('Date').str.to_date("%Y%m%d"))
    .with_columns(
        (pl.col('Date') - pl.date(1996, 12, 31)).dt.total_days().cast(pl.UInt16).alias('PurchDay'),
        (pl.col('Spend') * 100).round(0).cast(pl.Int64).alias('Spend Scaled'),
    )
    # Collapse multiple transactions by one customer on the same day into a single row.
    .group_by('CustID', 'Date')
    .agg(
        pl.all().exclude('PurchDay').sum(),
        pl.col('PurchDay').max(),
    )
    .sort('CustID', 'Date')
    .with_columns(
        (pl.col("CustID").cum_count().over("CustID") - 1).cast(pl.UInt16).alias("DoR")
    )
)

display(CDNOW_master.head().collect())

CustID,Date,Quant,Spend,Spend Scaled,PurchDay,DoR
i32,date,i64,f64,i64,u16,u16
1,1997-01-01,1,11.77,1177,1,0
2,1997-01-12,6,89.0,8900,12,0
3,1997-01-02,2,20.76,2076,2,0
3,1997-03-30,2,20.76,2076,89,1
3,1997-04-02,2,19.54,1954,92,2


In [17]:
# Per-customer summary over purchase days 1..273:
#   Repeat Spend (Scaled) - sum of 'Spend Scaled' over rows with DoR > 0
#   Trial Day             - PurchDay of the customer's DoR == 0 row
RptSpend = (
    CDNOW_master
    .filter(pl.col('PurchDay') <= 273)
    .with_columns(
        # Rows with DoR == 0 contribute 0 to the repeat-spend total.
        pl.when(pl.col('DoR') > 0)
        .then(pl.col('Spend Scaled'))
        .otherwise(0)
        .alias('Repeat Spend (Scaled)'),

        # Broadcast the DoR == 0 purchase day to every row of the customer.
        pl.col('PurchDay')
        .filter(pl.col('DoR') == 0)
        .first()
        .over('CustID')
        .alias('Trial Day'),
    )
    .group_by('CustID')
    .agg(
        pl.col('Repeat Spend (Scaled)').sum(),
        pl.col('Trial Day').max(),
    )
    .sort('CustID')
)

RptSpend.head().collect()

CustID,Repeat Spend (Scaled),Trial Day
i32,i64,u16
1,0,1
2,0,12
3,4030,2
4,4469,1
5,23188,1


In [18]:
# Sampling technique - Python Method:
# id_df = (
#     RptSpend.collect()
#     .with_columns(((pl.col('Trial Day') - 1) // 7 + 1).alias('Trial Week'))
#     .sort(['Trial Week','Repeat Spend (Scaled)', 'CustID'], descending=[False, True, False], maintain_order=True)
# )

# sampledID = id_df[9::10].select('CustID')

# MATLAB Sampling (due to numerical float precision handling differences, original sampling results cannot be replicated unless spend is scaled in MATLAB)
# Lazily load the pre-built sample file and derive the same helper columns as for
# the master file: PurchDay (days since 1996-12-31), Spend Scaled (Spend * 100 as
# an integer) and DoR (0-based purchase index within each customer).
CDNOW_sample = (
    pl.scan_csv(source='data/CDNOW/CDNOW_sample.csv',
                has_header=False,
                separator=',',
                schema={'CustID': pl.Int32,
                        'NewID': pl.Int32,
                        'Date': pl.String,
                        'Quant': pl.Int16,
                        'Spend': pl.Float64})
    .with_columns(pl.col('Date').str.to_date("%Y%m%d"))
    .with_columns((pl.col('Date') - pl.date(1996,12,31)).dt.total_days().cast(pl.UInt16).alias('PurchDay'))
    .with_columns((pl.col('Spend')*100).round(0).cast(pl.Int64).alias('Spend Scaled'))
    # Multiple transactions by a customer on a single day are aggregated into one.
    .group_by('CustID', 'Date')
    .agg(
        # NewID is an identifier, not a quantity: keep it as-is instead of summing it
        # across same-day rows (assumes NewID is constant within a CustID).
        pl.col('NewID').first(),
        pl.col('Quant', 'Spend', 'Spend Scaled').sum(),
        pl.col('PurchDay').max(),
    )
    .sort('CustID', 'Date')
    .with_columns((pl.col("CustID").cum_count().over("CustID") - 1).cast(pl.UInt16).alias("DoR"))
)

We now create summaries of the 1/10th sample data, given the xMAT data structure. We divide the 78 weeks in half: Period 1 is a 39-week calibration period, while Period 2 is a 39-week longitudinal holdout used for model validation.

In [19]:
# Last purchase day (inclusive) of the 39-week calibration period: 39 * 7 = 273.
calwk = 273

# Repeat-transaction counts per customer (rows with DoR > 0 only):
#   P1X - purchases on or before day `calwk`
#   P2X - purchases after day `calwk`
px = (
    CDNOW_sample
    .collect()
    .group_by('CustID', maintain_order=True)
    .agg(
        pl.col('PurchDay')
        .filter((pl.col('DoR') > 0) & (pl.col('PurchDay') <= calwk))
        .count()
        .alias('P1X'),

        pl.col('PurchDay')
        .filter((pl.col('DoR') > 0) & (pl.col('PurchDay') > calwk))
        .count()
        .alias('P2X'),
    )
)

# Units and (scaled) spend summed over those repeat transactions.
# Rows are date-sorted within a customer, so repeat purchases 1..P1X are the
# period-1 ones and any with DoR > P1X belong to period 2.
pSpendQuant = (
    CDNOW_sample
    .collect()
    .join(px, how='left', on='CustID')
    .group_by('CustID', maintain_order=True)
    .agg(
        pl.col('Spend Scaled')
        .filter((pl.col('P1X') != 0) & (pl.col('DoR') > 0) & (pl.col('DoR') <= pl.col('P1X')))
        .sum()
        .alias('P1X Spend'),

        pl.col('Quant')
        .filter((pl.col('P1X') != 0) & (pl.col('DoR') > 0) & (pl.col('DoR') <= pl.col('P1X')))
        .sum()
        .alias('P1X Quant'),

        pl.col('Spend Scaled')
        .filter((pl.col('DoR') > pl.col('P1X')) & (pl.col('DoR') > 0))
        .sum()
        .alias('P2X Spend'),

        pl.col('Quant')
        .filter((pl.col('DoR') > pl.col('P1X')) & (pl.col('DoR') > 0))
        .sum()
        .alias('P2X Quant'),
    )
)

# Average (scaled) spend per period-1 repeat transaction.
# Customers with P1X == 0 give 0 / 0 = NaN, which fill_nan turns into 0.
mx = (
    pSpendQuant
    .join(px, how='left', on='CustID')
    .with_columns(
        (pl.col('P1X Spend') / pl.col('P1X')).alias('Avg Spend per Repeat')
    )
    .fill_nan(0)
)

When fitting models such as the Pareto/NBD and BG/NBD to these data, we also want to know the “recency” information for each customer, as well as their effective calibration period:

In [20]:
# Per-customer recency inputs, both expressed in weeks (days / 7):
#   Time to Last Repeat     - last purchase day with DoR <= P1X, minus the trial day
#   Eff. Calibration Period - calwk minus the trial day
ttlrp = (
    CDNOW_sample
    .collect()
    .join(px, how='left', on='CustID')
    .with_columns(
        # Broadcast the DoR == 0 purchase day to every row of the customer.
        pl.col('PurchDay')
        .filter(pl.col('DoR') == 0)
        .first()
        .over('CustID')
        .alias('Trial Day')
    )
    .group_by('CustID', maintain_order=True)
    .agg(
        # Only rows up to the P1X-th repeat purchase are considered.
        pl.col('PurchDay').filter(pl.col('DoR') <= pl.col('P1X')).max(),
        pl.col('Trial Day').filter(pl.col('DoR') <= pl.col('P1X')).max(),
    )
    .with_columns(
        ((pl.col('PurchDay') - pl.col('Trial Day')) / 7).alias('Time to Last Repeat'),
        ((calwk - pl.col('Trial Day')) / 7).alias('Eff. Calibration Period'),
    )
    .drop('PurchDay', 'Trial Day')
)

Creating Summaries

In [25]:
# What is the total number of CDs purchased each week?
# Week 1 covers purchase days 1-7, week 2 days 8-14, and so on.
weeklysales = (
    CDNOW_master
    .with_columns(((pl.col('PurchDay') - 1) // 7 + 1).alias('Week'))
    .group_by('Week')
    .agg(pl.col('Quant').sum())
    .sort('Week')
    .collect()
)

# The series sums Quant over every transaction (first purchases included, no
# DoR > 0 filter), so the title must not describe it as repeat sales.
(
    ChartTemp(weeklysales)
    .line_encode(y_col='Quant', y_title='Units Purchased', x_range=78)
    .line_prop('Weekly Sales (# CDs)')
)