In [1]:
import polars as pl
import numpy as np
from scipy.special import gammaln
from scipy.optimize import minimize

import matplotlib.pyplot as plt
import matplotlib_inline

matplotlib_inline.backend_inline.set_matplotlib_formats('svg')

In [2]:
CDNOW = (
    pl.scan_csv(source='data/CDNOW/CDNOW_sample.csv',
                has_header=False,
                separator=',',
                schema={'CustID': pl.Int32,
                        'ID': pl.Int32,
                        'Date': pl.String,
                        'Quant': pl.Int16,
                        'Spend': pl.Float64})
    .with_columns(pl.col('Date').str.to_date("%Y%m%d"))
    .with_columns((pl.col('Date') - pl.date(1996,12,31)).dt.total_days().cast(pl.UInt16).alias('PurchDay'))
    .with_columns((pl.col('Spend')*100).round(0).cast(pl.Int64).alias('Spend Scaled'))
    .group_by('ID', 'Date')
    .agg(pl.col('*').exclude('PurchDay').sum(), pl.col('PurchDay').max())
    .sort('ID', 'Date')
    .with_columns((pl.col("ID").cum_count().over("ID") - 1).cast(pl.UInt16).alias("DoR"))      
    .drop('CustID')
)

In [3]:
calwk = 273 # 39 week calibration period

# The number of repeat transactions made by each customer in each period
freq_x = (
    CDNOW
    .group_by('ID', maintain_order=True)
    .agg(
        pl.col('PurchDay')
        .filter((pl.col('PurchDay') <= calwk) & (pl.col('DoR') > 0))
        .count()
        .alias('P1X'), # Period 1: Calibration Period

        pl.col('PurchDay')
        .filter((pl.col('PurchDay') > calwk) & (pl.col('DoR') > 0))
        .count()
        .alias('P2X')  # Period 2: Longitudinal Holdout Period      
    )
)

# The number of CDs purchased and total spend across these repeat transactions
pSpendQuant = (
    CDNOW
    .join(freq_x, on='ID', how='left')
    .group_by('ID', maintain_order=True)
    .agg(
        
        pl.col('Spend Scaled')
        .filter((pl.col('DoR') > 0) & (pl.col('DoR') <= pl.col('P1X')) & (pl.col('P1X') != 0))
        .sum()
        .alias('P1X Spend'),
        
        pl.col('Quant')
        .filter((pl.col('DoR') > 0) & (pl.col('DoR') <= pl.col('P1X')) & (pl.col('P1X') != 0))
        .sum()
        .alias('P1X Quant'),        
        
        pl.col('Spend Scaled')
        .filter((pl.col('DoR') > 0) & (pl.col('DoR') > pl.col('P1X')))
        .sum()
        .alias('P2X Spend'),
        
        pl.col('Quant')
        .filter((pl.col('DoR') > 0) & (pl.col('DoR') > pl.col('P1X')))
        .sum()
        .alias('P2X Quant')                
    )
)

# The average spend per repeat transaction
m_x = (
    pSpendQuant
    .join(freq_x, on='ID', how='left')
    .with_columns(
        (pl.col('P1X Spend') / pl.col('P1X')).alias('m_x_calib'),
        (pl.col('P2X Spend') / pl.col('P2X')).alias('m_x_valid')
    ).fill_nan(0)
)

# time of last calibration period repeat purchase (in weeks) - Recency
ttlrp = (
    CDNOW
    .join(freq_x, on='ID', how='left')
    .with_columns(
        pl.col('PurchDay').filter(pl.col('DoR') == 0)
        .first()
        .over('ID')
        .alias('Trial Day')
    )
    .group_by('ID', maintain_order=True)
    .agg(
        pl.col('PurchDay', 'Trial Day')
        .filter(pl.col('DoR') <= pl.col('P1X'))
        .max()
        # .alias('LastPurch')
    )
    .with_columns(
        # effective calibration period (in weeks)
        ((pl.col('PurchDay') - pl.col('Trial Day')) / 7).alias('t_x'), # Time to Last Repeat Purchase - Recency
        ((calwk - pl.col('Trial Day'))/7).alias('T')
    )
    .drop('PurchDay', 'Trial Day')
)

rfm_data = (
    m_x
    .join(other=ttlrp, on="ID", how="left")
    .rename({'P1X': 'x'})
    .select('ID', 'x', 't_x', 'T')
)

In [27]:
def bgnbd_est(rfm_data, guess={'r': 0.1, 'alpha': 0.1, 'a': 0.1, 'b':0.1}):
    
    def log_likelihood(x):
        r, alpha, a, b = x
        return -(
            rfm_data
            .with_columns(
                (gammaln(pl.col('x') + r) - gammaln(r) + r * np.log(alpha)).alias('ln(A_1)'),
                (gammaln(a + b) + gammaln(b + pl.col('x')) - gammaln(b) - gammaln(a + b + pl.col('x'))).alias('ln(A_2)'),
                (-(r + pl.col('x')) * np.log(alpha + pl.col('T'))).alias('ln(A_3)'),
                (pl.when(pl.col('x') > 0)
                .then(np.log(a) - np.log(b + pl.col('x') - 1) - (r + pl.col('x')) * np.log(alpha + pl.col('t_x')))
                .otherwise(0)        
                ).alias('ln(A_4)')
            ).with_columns(
                (pl.col('ln(A_1)') + pl.col('ln(A_2)') + np.log(np.exp(pl.col('ln(A_3)')) + (pl.col('x') > 0) * np.exp(pl.col('ln(A_4)')))).alias('ln(.)')
            ).select(pl.col('ln(.)').sum())
            .collect().item(0,0)
        )
    bnds = [(0, np.inf) for _ in range(4)]
    return minimize(log_likelihood, x0=list(guess.values()), bounds=bnds)

In [28]:
bgnbd_est(rfm_data)

  message: CONVERGENCE: REL_REDUCTION_OF_F_<=_FACTR*EPSMCH
  success: True
   status: 0
      fun: 9582.44760081516
        x: [ 2.420e-01  4.394e+00  8.272e-01  2.559e+00]
      nit: 18
      jac: [-5.275e-02 -1.484e-01  1.306e-01  2.123e-01]
     nfev: 100
     njev: 20
 hess_inv: <4x4 LbfgsInvHessProduct with dtype=float64>

In [None]:
def bgnbd_est(rfm_data, guess={'r': 0.1, 'alpha': 0.1, 'a': 0.1, 'b':0.1}):
    
    def log_likelihood(x):
        r, alpha, a, b = x

        ln_A_1 = gammaln(rfm_data[:,0] + r) - gammaln(r) + r * np.log(alpha)
        ln_A_2 = gammaln(a + b) + gammaln(b + rfm_data[:,0]) - gammaln(b) - gammaln(a + b + rfm_data[:,0])
        ln_A_3 = -(r + rfm_data[:,0]) * np.log(alpha + rfm_data[:,2])
        ln_A_4 = np.where(rfm_data[:,0] > 0, 
                          np.log(a) - np.log(b + rfm_data[:,0] - 1) - (r + rfm_data[:,0]) * np.log(alpha + rfm_data[:,1]),
                          0)
        return -np.sum(ln_A_1 + ln_A_2 + np.log(np.exp(ln_A_3) + (rfm_data[:,0] > 0) * np.exp(ln_A_4)))
    
    bnds = [(0, np.inf) for _ in range(4)]
    return minimize(log_likelihood, x0=list(guess.values()), bounds=bnds)

bgnbd_est(rfm_data.select('x', 't_x', 'T').collect().to_numpy())

  np.log(a) - np.log(b + rfm_data[:,0] - 1) - (r + rfm_data[:,0]) * np.log(alpha + rfm_data[:,1]),


  message: CONVERGENCE: REL_REDUCTION_OF_F_<=_FACTR*EPSMCH
  success: True
   status: 0
      fun: 9582.447599720595
        x: [ 2.420e-01  4.394e+00  8.272e-01  2.559e+00]
      nit: 18
      jac: [-5.148e-02 -1.481e-01  1.308e-01  2.126e-01]
     nfev: 100
     njev: 20
 hess_inv: <4x4 LbfgsInvHessProduct with dtype=float64>