In [109]:
import polars as pl
import numpy as np
from scipy.special import gammaln, hyp2f1
from scipy.optimize import minimize

from IPython.display import display_markdown
import matplotlib.pyplot as plt
import matplotlib_inline

matplotlib_inline.backend_inline.set_matplotlib_formats('svg')

In [110]:
CDNOW = (
    pl.scan_csv(source='data/CDNOW/CDNOW_sample.csv',
                has_header=False,
                separator=',',
                schema={'CustID': pl.Int32,
                        'ID': pl.Int32,
                        'Date': pl.String,
                        'Quant': pl.Int16,
                        'Spend': pl.Float64})
    .with_columns(pl.col('Date').str.to_date("%Y%m%d"))
    .with_columns((pl.col('Date') - pl.date(1996,12,31)).dt.total_days().cast(pl.UInt16).alias('PurchDay'))
    .with_columns((pl.col('Spend')*100).round(0).cast(pl.Int64).alias('Spend Scaled'))
    .group_by('ID', 'Date')
    .agg(pl.col('*').exclude('PurchDay').sum(), pl.col('PurchDay').max())
    .sort('ID', 'Date')
    .with_columns((pl.col("ID").cum_count().over("ID") - 1).cast(pl.UInt16).alias("DoR"))      
    .drop('CustID')
)

In [111]:
calwk = 273 # 39 week calibration period

# The number of repeat transactions made by each customer in each period
freq_x = (
    CDNOW
    .group_by('ID', maintain_order=True)
    .agg(
        pl.col('PurchDay')
        .filter((pl.col('PurchDay') <= calwk) & (pl.col('DoR') > 0))
        .count()
        .alias('P1X'), # Period 1: Calibration Period

        pl.col('PurchDay')
        .filter((pl.col('PurchDay') > calwk) & (pl.col('DoR') > 0))
        .count()
        .alias('P2X')  # Period 2: Longitudinal Holdout Period      
    )
)

# The number of CDs purchased and total spend across these repeat transactions
pSpendQuant = (
    CDNOW
    .join(freq_x, on='ID', how='left')
    .group_by('ID', maintain_order=True)
    .agg(
        
        pl.col('Spend Scaled')
        .filter((pl.col('DoR') > 0) & (pl.col('DoR') <= pl.col('P1X')) & (pl.col('P1X') != 0))
        .sum()
        .alias('P1X Spend'),
        
        pl.col('Quant')
        .filter((pl.col('DoR') > 0) & (pl.col('DoR') <= pl.col('P1X')) & (pl.col('P1X') != 0))
        .sum()
        .alias('P1X Quant'),        
        
        pl.col('Spend Scaled')
        .filter((pl.col('DoR') > 0) & (pl.col('DoR') > pl.col('P1X')))
        .sum()
        .alias('P2X Spend'),
        
        pl.col('Quant')
        .filter((pl.col('DoR') > 0) & (pl.col('DoR') > pl.col('P1X')))
        .sum()
        .alias('P2X Quant')                
    )
)

# The average spend per repeat transaction
m_x = (
    pSpendQuant
    .join(freq_x, on='ID', how='left')
    .with_columns(
        (pl.col('P1X Spend') / pl.col('P1X')).alias('m_x_calib'),
        (pl.col('P2X Spend') / pl.col('P2X')).alias('m_x_valid')
    ).fill_nan(0)
)

# time of last calibration period repeat purchase (in weeks) - Recency
ttlrp = (
    CDNOW
    .join(freq_x, on='ID', how='left')
    .with_columns(
        pl.col('PurchDay').filter(pl.col('DoR') == 0)
        .first()
        .over('ID')
        .alias('Trial Day')
    )
    .group_by('ID', maintain_order=True)
    .agg(
        pl.col('PurchDay', 'Trial Day')
        .filter(pl.col('DoR') <= pl.col('P1X'))
        .max()
        # .alias('LastPurch')
    )
    .with_columns(
        # effective calibration period (in weeks)
        ((pl.col('PurchDay') - pl.col('Trial Day')) / 7).alias('t_x'), # Time to Last Repeat Purchase - Recency
        ((calwk - pl.col('Trial Day'))/7).alias('T')
    )
    .drop('PurchDay', 'Trial Day')
)

rfm_data = (
    m_x
    .join(other=ttlrp, on="ID", how="left")
    .rename({'P1X': 'x'})
    .select('ID', 'x', 't_x', 'T')
)

In [113]:
def bgnbd_est(rfm_data, guess={'r': 0.01, 'alpha': 0.01, 'a': 0.01, 'b':0.01}):
    
    def log_likelihood(x):
        r, alpha, a, b = x

        ln_A_1 = gammaln(rfm_data[:,0] + r) - gammaln(r) + r * np.log(alpha)
        ln_A_2 = gammaln(a + b) + gammaln(b + rfm_data[:,0]) - gammaln(b) - gammaln(a + b + rfm_data[:,0])
        ln_A_3 = -(r + rfm_data[:,0]) * np.log(alpha + rfm_data[:,2])
        ln_A_4 = np.where(rfm_data[:,0] > 0, 
                          np.log(a) - np.log(b + rfm_data[:,0] - 1) - (r + rfm_data[:,0]) * np.log(alpha + rfm_data[:,1]),
                          0)
        return -np.sum(ln_A_1 + ln_A_2 + np.log(np.exp(ln_A_3) + (rfm_data[:,0] > 0) * np.exp(ln_A_4)))
    
    bnds = [(0, np.inf) for _ in range(4)]
    return minimize(log_likelihood, x0=list(guess.values()), bounds=bnds)

result = bgnbd_est(rfm_data.select('x', 't_x', 'T').collect().to_numpy())
r, alpha, a, b = result.x
ll = result.fun

display_markdown(f'''$r$ = {r:0.4f}

$\\alpha$ = {alpha:0.4f}

$a$ = {a:0.4f}

$b$ = {b:0.4f}

Log-Likelihood = {-ll:0.4f}''', raw=True)

  np.log(a) - np.log(b + rfm_data[:,0] - 1) - (r + rfm_data[:,0]) * np.log(alpha + rfm_data[:,1]),


$r$ = 0.2426

$\alpha$ = 4.4136

$a$ = 0.7929

$b$ = 2.4259

Log-Likelihood = -9582.4292

In [114]:
forecast_horizon = (calwk * 2) // 7

t = np.arange(1/7, forecast_horizon, 1/7)
z = t / (alpha + t)
h2f1 = hyp2f1(r, b, (a + b - 1), z)
E_X_t = (a + b - 1) / (a - 1) * (1 - (alpha / (alpha + t))**r * h2f1)

In [136]:
tffr = (
    ttlrp
    .with_columns((39 - pl.col('T')).alias('Time of First Repeat'))
    .group_by('Time of First Repeat').agg(pl.len().alias('Count'))
    .sort('Time of First Repeat')
    .collect()
    .to_numpy()
)

num_triers = tffr[:, 1]
trial_week = tffr[:, 0]
time_trial_week = np.arange(1/7, np.max(trial_week), 1/7)

a, _ = np.meshgrid(time_trial_week,t)

test = ((t.reshape(-1,1) - a) * 7).astype(np.int16)

index = np.where(test < 0, 0, test) - 1

np.tril(E_X_t[np.clip(index, 0, E_X_t.shape[0] - 1)], k=-1)

array([[0.        , 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ],
       [0.00781366, 0.        , 0.        , ..., 0.        , 0.        ,
        0.        ],
       [0.01555185, 0.00781366, 0.        , ..., 0.        , 0.        ,
        0.        ],
       ...,
       [1.85180813, 1.84770019, 1.84770019, ..., 1.67810192, 1.67360794,
        1.67135695],
       [1.8538589 , 1.84975523, 1.84770019, ..., 1.68034491, 1.67585626,
        1.67360794],
       [1.8538589 , 1.85180813, 1.84975523, ..., 1.68034491, 1.67810192,
        1.67585626]], shape=(546, 84))