Skip to content

Commit

Permalink
Add experimental support for the BG/BB model
Browse files Browse the repository at this point in the history
Beta-Geometric/Beta-Bernoulli model is a special variant of the
beta-binomial model without the binomial coefficient. It is particularly
efficient for discrete-time analyses. It is considered "experimental" as
the associated "Fitter" in the lifetimes python package is buggy (see
CamDavidsonPilon/lifetimes#259). An override
is provided in lifetimes_ext which utilizes exponentials and natural
logarithms to reduce the observed convergence instability. See
custom_beta_geo_beta_binom_fitter.py for details.
BG/BB literature: https://brucehardie.com/papers/020/fader_et_al_mksc_10.pdf

Change-Id: Icc2701b42bc720635e5d96c9c809c48c0f93e945
GitOrigin-RevId: 9765a1b45087a8b0c05b7eb6070b32d64a81b017
  • Loading branch information
mohabfekry authored and Copybara-Service committed Aug 20, 2021
1 parent dd89587 commit f14cc0b
Show file tree
Hide file tree
Showing 13 changed files with 191 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -85,12 +85,12 @@
},
{
"regexes": [
"(BGNBD|MBGNBD|PNBD|)"
"(BGNBD|MBGNBD|BGBB|PNBD|)"
],
"name": "frequency_model_type",
"label": "Frequency Model Type",
"is_optional": true,
"helpText": "[Default MBGNBD] \"BGNBD\", \"MBGNBD\" or \"PNBD\". BG/NBD (or MBG/NBD) vs Pareto/NBD frequency models."
"helpText": "[Default MBGNBD] \"BGNBD\", \"MBGNBD\", \"BGBB\", or \"PNBD\". BG/NBD (or MBG/NBD), Pareto/NBD or the experimental BG/BB (a.k.a. Beta Bernoulli) frequency model."
},
{
"regexes": [
Expand All @@ -108,7 +108,7 @@
"name": "model_time_granularity",
"label": "Model Time Granularity",
"is_optional": true,
"helpText": "[Default Weekly] One of \"Daily\", \"Weekly\", \"Monthly\". What time granularity to run the model with. \"Daily\" is most frequently used for apps customers who have a very short period of time between repeat transactions."
"helpText": "[Default Weekly] One of \"Daily\", \"Weekly\", \"Monthly\". What time granularity to run the model with. \"Daily\" is most frequently used for apps customers who have a very short period of time between repeat transactions, while \"Monthly\" is more suited for discrete-time transaction analysis using the experimental BG/BB model."
},
{
"regexes": [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,12 +61,12 @@
},
{
"regexes": [
"(BGNBD|MBGNBD|PNBD|)"
"(BGNBD|MBGNBD|BGBB|PNBD|)"
],
"name": "frequency_model_type",
"label": "Frequency Model Type",
"is_optional": true,
"helpText": "[Default MBGNBD] \"BGNBD\", \"MBGNBD\" or \"PNBD\". BG/NBD (or MBG/NBD) vs Pareto/NBD frequency models."
"helpText": "[Default MBGNBD] \"BGNBD\", \"MBGNBD\", \"BGBB\", or \"PNBD\". BG/NBD (or MBG/NBD), Pareto/NBD or the experimental BG/BB (a.k.a. Beta Bernoulli) frequency model."
},
{
"regexes": [
Expand All @@ -84,7 +84,7 @@
"name": "model_time_granularity",
"label": "Model Time Granularity",
"is_optional": true,
"helpText": "[Default Weekly] One of \"Daily\", \"Weekly\", \"Monthly\". What time granularity to run the model with. \"Daily\" is most frequently used for apps customers who have a very short period of time between repeat transactions."
"helpText": "[Default Weekly] One of \"Daily\", \"Weekly\", \"Monthly\". What time granularity to run the model with. \"Daily\" is most frequently used for apps customers who have a very short period of time between repeat transactions, while \"Monthly\" is more suited for discrete-time transaction analysis using the experimental BG/BB model."
},
{
"regexes": [
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2019 Google LLC
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2019 Google LLC
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -24,6 +24,7 @@
import pandas as pd
from apache_beam.options import value_provider
import lifetimes
import lifetimes_ext

_SEGMENT_PREDICTION_THRESHOLD = 100000

Expand Down Expand Up @@ -59,6 +60,7 @@

_MODEL_TYPE_BGNBD = 'BGNBD'
_MODEL_TYPE_MBGNBD = 'MBGNBD'
_MODEL_TYPE_BGBB = 'BGBB'
_MODEL_TYPE_PNBD = 'PNBD'

date_formats = {
Expand Down Expand Up @@ -1095,6 +1097,35 @@ def extract_bgnbd_params(model):
return {'r': r, 'alpha': alpha, 'a': a, 'b': b}


def fit_bgbb_model(data, penalizer_coef=0.0):
    """Fits a BG/BB frequency model on the given customer data.

    Args:
        data: Pandas DataFrame containing the customers data. It is indexed
            by the 'frequency', 'recency' and 'total_time_observed' columns.
        penalizer_coef: The coefficient applied to an l2 norm on the
            parameters; forwarded to the fitter's constructor.

    Returns:
        The BG/BB fitter object after `fit` has been called on it.
    """
    fitter = lifetimes_ext.BetaGeoBetaBinomFitter(
        penalizer_coef=penalizer_coef)

    # Pull the three sufficient-statistic columns in the order `fit` expects.
    frequency = data['frequency']
    recency = data['recency']
    periods_observed = data['total_time_observed']

    fitter.fit(frequency, recency, periods_observed)
    return fitter


def extract_bgbb_params(model):
    """Reads the four parameters out of a BG/BB model.

    Args:
        model: the BG/BB model; must expose `_unload_params`.

    Returns:
        A dict holding the alpha, beta, gamma and delta params of the
        BG/BB model, keyed by parameter name.
    """
    param_names = ('alpha', 'beta', 'gamma', 'delta')
    # Unpacking into exactly four names keeps the strict arity check.
    alpha, beta, gamma, delta = model._unload_params(*param_names)
    return dict(alpha=alpha, beta=beta, gamma=gamma, delta=delta)


def fit_pnbd_model(data, penalizer_coef=0.0):
"""Generates Pareto/NBD model from the input data.
Expand Down Expand Up @@ -1186,10 +1217,11 @@ def calc_full_fit_period(calibration_start_date, holdout_end_date,
# Has no negative side-effect


def expected_cumulative_transactions(frequency_model, t_cal, t_tot):
"""Calculates expected cumulative transaction for each interval.
def expected_cumulative_transactions(model_type, frequency_model, t_cal, t_tot):
"""Calculates expected cumulative transactions for each interval.
Args:
model_type: Type of the model in use.
frequency_model: Model fitted on customer's data.
t_cal: NumPy array of Total Time Observed values.
t_tot: Total number of periods to predict.
Expand All @@ -1209,19 +1241,25 @@ def expected_cumulative_transactions(frequency_model, t_cal, t_tot):
expected_cumulative_transactions_output.append(0)
continue

# Returns 1D array
exp_purchases = frequency_model.expected_number_of_purchases_up_to_time(
t=(interval - cust_birth[np.where(cust_birth <= interval)]))
if model_type == _MODEL_TYPE_BGBB:
exp_purchases = \
frequency_model.expected_number_of_transactions_in_first_n_periods(
n=interval)["model"].tolist()
# (frequency, model)
else:
exp_purchases = frequency_model.expected_number_of_purchases_up_to_time(
t=(interval - cust_birth[np.where(cust_birth <= interval)]))

expected_cumulative_transactions_output.append(np.sum(exp_purchases))

return np.around(np.array(expected_cumulative_transactions_output), 2)


def predict_txs(frequency_model, t_cal, intervals):
def predict_txs(model_type, frequency_model, t_cal, intervals):
"""Calculates transactions by time unit predictions.
Args:
model_type: Type of the model in use.
frequency_model: Model fitted on customer's data.
t_cal: NumPy array of Total Time Observed values.
intervals: Total number of periods to predict.
Expand All @@ -1230,7 +1268,7 @@ def predict_txs(frequency_model, t_cal, intervals):
Pandas DataFrame containing predicted future total purchases
"""
expected_cumulative = expected_cumulative_transactions(
frequency_model, t_cal, intervals)
model_type, frequency_model, t_cal, intervals)
expected_incremental = expected_cumulative - np.delete(
np.hstack(([0], expected_cumulative)), expected_cumulative.size - 1)

Expand Down Expand Up @@ -1436,7 +1474,7 @@ def frequency_model_validation_to_text(model_params):
Transactions observed for validation: {model_params['num_transactions_validation']} \
({model_params['perc_transactions_validation']} % of total transactions)
Validation Mean Absolute Percent Error (MAPE): {str(model_params['validation_mape'])}%"""
Validation Mean Absolute Percent Error (MAPE): {model_params['validation_mape']}%"""

return output_text

Expand All @@ -1452,8 +1490,7 @@ def frequency_model_validation(model_type, cbs, cal_start_date, cal_end_date,
Mean Absolute Percent Error (MAPE).
Args:
model_type: String defining the type of model to be used, it can be
either 'BGNBD', 'MBGNBD' or 'PNBD'.
model_type: String defining the type of model to be used.
cbs: Customer-by-sufficient-statistic (CBS) DataFrame.
cal_start_date: Calibration start date.
cal_end_date: Calibration end date.
Expand All @@ -1480,6 +1517,9 @@ def frequency_model_validation(model_type, cbs, cal_start_date, cal_end_date,
elif model_type == _MODEL_TYPE_MBGNBD:
frequency_model = fit_mbgnbd_model(cbs, penalizer_coef)
model_params['frequency_model'] = 'MBG/NBD'
elif model_type == _MODEL_TYPE_BGBB:
frequency_model = fit_bgbb_model(cbs, penalizer_coef)
model_params['frequency_model'] = 'BG/BB'
elif model_type == _MODEL_TYPE_PNBD:
frequency_model = fit_pnbd_model(cbs, penalizer_coef)
model_params['frequency_model'] = 'Pareto/NBD'
Expand All @@ -1489,8 +1529,8 @@ def frequency_model_validation(model_type, cbs, cal_start_date, cal_end_date,
# Transactions by time unit predictions
intervals = calc_full_fit_period(cal_start_date, hold_end_date,
time_divisor)
predicted = predict_txs(frequency_model, cbs['total_time_observed'].values,
intervals)
predicted = predict_txs(model_type, frequency_model,
cbs['total_time_observed'].values, intervals)

# Actual transactions per time unit
txs = repeat_tx
Expand Down Expand Up @@ -1527,7 +1567,8 @@ def frequency_model_validation(model_type, cbs, cal_start_date, cal_end_date,
) / txs.iloc[median_line:, :]['repeat_transactions_cumulative'] * 100
mape = error_by_time.abs().mean()

model_params['validation_mape'] = round(mape, 2)
model_params['validation_mape'] = (
'N/A' if model_type == _MODEL_TYPE_BGBB else str(round(mape, 2)))

# return tuple that includes the validation MAPE, which will be used for a
# threshold check
Expand Down Expand Up @@ -1654,7 +1695,10 @@ def calculate_model_fit_validation(_, options, dates, calcbs, repeat_tx,
# the allowed threshold. If so, continue the calculation. If not,
# fail with an error and stop all calculations.
error = None
if model_params['validation_mape'] > float(options[_OPTION_TRANSACTION_FREQUENCY_THRESHOLD]):
if (
options[_OPTION_FREQUENCY_MODEL_TYPE] != _MODEL_TYPE_BGBB and
float(model_params['validation_mape']) > float(
options[_OPTION_TRANSACTION_FREQUENCY_THRESHOLD])):
model_params['invalid_mape'] = True
error = (
f"Validation Mean Absolute Percent Error (MAPE) [{model_params['validation_mape']}%]"
Expand Down Expand Up @@ -1695,6 +1739,7 @@ def calculate_prediction_to_text(prediction_params, options):
Frequency Model: {prediction_params['frequency_model']}
Model Parameters
{model_params_to_string(prediction_params['bgnbd_model_params'] or
prediction_params['bgbb_model_params'] or
prediction_params['paretonbd_model_params'])}
Gamma-Gamma Parameters
{model_params_to_string(prediction_params['gamma_gamma_params'])}"""
Expand Down Expand Up @@ -1756,6 +1801,7 @@ def calculate_prediction(_, options, fullcbs, num_customers, num_txns):

prediction_params['frequency_model'] = 'BG/NBD'
prediction_params['bgnbd_model_params'] = bgnbd_params
prediction_params['bgbb_model_params'] = None
prediction_params['paretonbd_model_params'] = None

elif frequency_model_type == _MODEL_TYPE_MBGNBD:
Expand All @@ -1764,22 +1810,38 @@ def calculate_prediction(_, options, fullcbs, num_customers, num_txns):

prediction_params['frequency_model'] = 'MBG/NBD'
prediction_params['bgnbd_model_params'] = mbgnbd_params
prediction_params['bgbb_model_params'] = None
prediction_params['paretonbd_model_params'] = None

elif frequency_model_type == _MODEL_TYPE_BGBB:
frequency_model = fit_bgbb_model(data, options[_OPTION_PENALIZER_COEF])
bgbb_params = extract_bgbb_params(frequency_model)

prediction_params['frequency_model'] = 'BG/BB'
prediction_params['bgnbd_model_params'] = None
prediction_params['bgbb_model_params'] = bgbb_params
prediction_params['paretonbd_model_params'] = None

elif frequency_model_type == _MODEL_TYPE_PNBD:
frequency_model = fit_pnbd_model(data, options[_OPTION_PENALIZER_COEF])
pnbd_params = extract_pnbd_params(frequency_model)

prediction_params['frequency_model'] = 'Pareto/NBD'
prediction_params['paretonbd_model_params'] = pnbd_params
prediction_params['bgnbd_model_params'] = None
prediction_params['bgbb_model_params'] = None
prediction_params['paretonbd_model_params'] = pnbd_params

else:
raise ValueError('Model type %s is not valid' % frequency_model_type)

# Predict probability alive for customers
data['p_alive'] = frequency_model.conditional_probability_alive(
data['frequency'], data['recency'], data['total_time_observed'])
if frequency_model_type == _MODEL_TYPE_BGBB:
data['p_alive'] = frequency_model.conditional_probability_alive(
prediction_period, data['frequency'], data['recency'],
data['total_time_observed'])
else:
data['p_alive'] = frequency_model.conditional_probability_alive(
data['frequency'], data['recency'], data['total_time_observed'])

# Predict future purchases (X weeks/days/months)
if frequency_model_type == _MODEL_TYPE_PNBD:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2019 Google LLC
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -294,7 +294,7 @@ def run(argv=None):
{'name': 'perc_customers_cohort', 'type': 'FLOAT'},
{'name': 'num_transactions_validation', 'type': 'INTEGER'},
{'name': 'perc_transactions_validation', 'type': 'FLOAT'},
{'name': 'validation_mape', 'type': 'FLOAT'},
{'name': 'validation_mape', 'type': 'STRING'},
]}
]
},
Expand Down Expand Up @@ -402,6 +402,13 @@ def run(argv=None):
{'name': 'r', 'type': 'FLOAT'},
{'name': 'alpha', 'type': 'FLOAT'}
]},
{'name': 'bgbb_model_params', 'type': 'RECORD',
'fields': [
{'name': 'alpha', 'type': 'FLOAT'},
{'name': 'beta', 'type': 'FLOAT'},
{'name': 'gamma', 'type': 'FLOAT'},
{'name': 'delta', 'type': 'FLOAT'}
]},
{'name': 'paretonbd_model_params',
'type': 'RECORD',
'fields': [
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2019 Google LLC
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#!/usr/bin/env bash
# Copyright 2019 Google LLC
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Copyright 2021 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Simplify importing of extensions."""
from .custom_beta_geo_beta_binom_fitter import BetaGeoBetaBinomFitter

__all__ = (
"BetaGeoBetaBinomFitter",
)

0 comments on commit f14cc0b

Please sign in to comment.