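# Bayesian Bandit A/B-Testing

Simulates a Bayesian bandit over four variants (A–D; C and D only go live in later periods to mimic a cold-start problem). It then plots the posterior distributions, regret, P(a>b) and loss over time, and renders an animation of the posteriors.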

In [1]:
# !which python
# %pip install nbformat
# %pip install kaleido
# !mkdir -p images
# !mkdir -p video

In [2]:
from typing import Dict, List, Any, Union

import math
import os

import numpy as np
import pandas as pd
import plotly
import plotly.graph_objects as go
from IPython.display import display
from scipy import stats
from scipy.stats import beta, gamma
from tqdm import tqdm

# import util functions
from bayesian_bandit_test import Environment, Agent, Bandit
from bayesian_test import Bayesian_AB_Test
from graph import visualisation # conda install -n python3 -c conda-forge colorlover
from graph import Video
# Plotting helper from the local `graph` module.
# renderer: "vscode" inside VS Code, "iframe" for browser-based Jupyter.
plot = visualisation(renderer="vscode")

# Echo every top-level expression of a cell, not only the last one.
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = "all"

# Let the wide per-variant log frames render without truncation.
pd.set_option('display.max_columns', None,
              'display.max_rows', 5000,
              'display.width', 10000)

# Register tqdm's progress_apply / progress_map on pandas objects.
tqdm.pandas()

# Parameters

In [3]:
# A/B-Test
# # No cold start problem
# BANDIT_PARAMS = {'A': {'period':0, 'ctr':0.1, 'cpm':1},
#                  'B': {'period':0, 'ctr':0.3, 'cpm':2}}

# With cold start problem
BANDIT_PARAMS = {'A': {'period':0, 'ctr':0.1, 'cpm':1},
                 'B': {'period':0, 'ctr':0.3, 'cpm':2}, 
                 'C': {'period':100, 'ctr': 0.05, 'cpm': 1},
                 'D': {'period':200, 'ctr': 0.4, 'cpm': 3}}

# Plotting
WIDTH_SAVE, HEIGHT_SAVE = 1200, 400

In [4]:
# # Early case
# config = {'optimise_for': 'ctr',
#           'n_periods': 500,
#           'max_impr_before_update_param': 100,
#           'recency_param': 0.6, # decay parameter per day
#           'n_periods_per_day': 24, # number of periods per day
#           'video': 'video/bandit_abcd_ctr_slow.mp4'
#          }

# For video
config = {'optimise_for': 'ctr',                    # metric to optimise: 'ctr' | 'cpc' (see video cell)
          'n_periods': 500,
          'max_impr_before_update_param': 5000,
          'recency_param': 0.6,                     # decay parameter per day
          'n_periods_per_day': 1,                   # number of periods per day
          'video': 'video/bandit_abcd_ctr_fast.mp4' # output path of the animation
         }

bandit = Bandit(bandit_params=BANDIT_PARAMS, n_periods=config['n_periods']+1, config=config)
bandit.run()

# Show the tail of every variant's log. Iterating the configured variants (instead of
# hard-coding 'A'..'D') keeps this cell working with the A/B-only BANDIT_PARAMS, and
# display() makes the output independent of the InteractiveShell setting.
for variant in bandit.agent.variants:
    display(bandit.agent.df_log[variant].tail())
display(bandit.df_metrics.tail())

Unnamed: 0,period,n_impr,n_impr_w_sum,n_clicks,n_clicks_w_sum,cost,cost_sum,ctr,cpc,alpha,beta,a,scale,cost_w_sum
497,497,0,50.272235,0,4.253745,0.0,,0.084614,0.254978,5,47,2.084612,0.235087,1.084612
498,498,0,30.563341,0,2.952247,0.0,,0.096594,0.355921,4,29,2.050767,0.338725,1.050767
499,499,0,18.738005,0,2.171348,0.0,,0.115879,0.474572,3,18,2.03046,0.460543,1.03046
500,500,11,11.642803,1,1.702809,0.021523,,0.146254,0.597998,3,11,2.018276,0.587265,1.018276
501,501,254,13.985682,22,2.021685,0.482105,,0.144554,0.506449,3,13,2.02388,0.494637,1.02388


Unnamed: 0,period,n_impr,n_impr_w_sum,n_clicks,n_clicks_w_sum,cost,cost_sum,ctr,cpc,alpha,beta,a,scale,cost_w_sum
497,497,40,148.638025,14,45.752351,0.04,,0.307811,0.026486,47,104,2.211778,0.021857,1.211778
498,498,44,113.582815,16,36.25141,0.101773,,0.319163,0.031752,37,78,2.151067,0.027585,1.151067
499,499,20,94.949689,3,31.750846,0.045692,,0.334397,0.036273,33,64,2.151704,0.031495,1.151704
500,500,481,69.369813,126,21.250508,1.348128,,0.306337,0.052631,22,49,2.118437,0.047058,1.118437
501,501,282,330.621888,94,88.750305,0.520497,,0.268434,0.021182,90,243,2.879939,0.011268,1.879939


Unnamed: 0,period,n_impr,n_impr_w_sum,n_clicks,n_clicks_w_sum,cost,cost_sum,ctr,cpc,alpha,beta,a,scale,cost_w_sum
397,497,9,90.563874,1,5.581822,0.031806,,0.061634,0.198743,7,86,2.109346,0.179153,1.109346
398,498,0,60.138324,0,4.349093,0.0,,0.072318,0.249406,5,57,2.084691,0.229933,1.084691
399,499,0,36.482995,0,3.009456,0.0,,0.082489,0.349171,4,34,2.050815,0.332286,1.050815
400,500,0,22.289797,0,2.205674,0.0,,0.098954,0.467199,3,21,2.030489,0.453376,1.030489
401,501,6,13.773878,0,1.723404,0.006,,0.125121,0.590862,3,13,2.018293,0.580247,1.018293


Unnamed: 0,period,n_impr,n_impr_w_sum,n_clicks,n_clicks_w_sum,cost,cost_sum,ctr,cpc,alpha,beta,a,scale,cost_w_sum
297,497,3876,2893.937371,1541,1176.112288,7.929798,,0.406406,0.012715,1177,1719,15.954057,0.00085,14.954057
298,498,4825,4062.362423,1895,1630.667373,29.535848,,0.401409,0.008665,1632,2433,15.130313,0.000613,14.130313
299,499,386,5332.817454,140,2115.800424,0.682396,,0.396751,0.012572,2117,3218,27.599697,0.000473,26.599697
300,500,3362,3431.690472,1357,1353.880254,11.11812,,0.394523,0.012386,1355,2079,17.769256,0.000739,16.769256
301,501,3748,4076.614283,1479,1626.928153,14.188919,,0.399088,0.010531,1628,2451,18.132425,0.000615,17.132425


Unnamed: 0,period,n_impr,regret,P_ab_ctr,P_ab_cpc,loss_ctr,loss_cpc,p_overlap_ctr,p_overlap_cpc,n_impr_acc,regret_acc,regret_avg
496,497,3925,49,"[0.0, 0.0089, 0.0, 0.9911]","[0.0002, 0.0002, 0.0002, 0.9994]","{('A', 'B'): (0.212960845434553, 7.20479120661...","{('A', 'B'): (0.09375782339783675, 0.105478756...","{'ks': {('A', 'B'): 0.9410000000000001, ('A', ...",{'ks': {}},1227996,74015,0.060273
497,498,4869,44,"[0.0002, 0.0392, 0.0, 0.9606]","[0.0004, 0.0005, 0.0005, 0.9986]","{('A', 'B'): (0.1987111648563341, 0.0001879722...","{('A', 'B'): (0.08616561304246013, 0.114169147...","{'ks': {('A', 'B'): 0.889, ('A', 'C'): 0.954, ...",{'ks': {}},1232865,74059,0.060071
498,499,406,20,"[0.0037, 0.1245, 0.0, 0.8718]","[0.0, 0.0, 0.0, 1.0]","{('A', 'B'): (0.19641975684645438, 0.001045208...","{('A', 'B'): (0.08021563718199475, 0.121000394...","{'ks': {('A', 'B'): 0.868, ('A', 'C'): 0.971, ...",{'ks': {}},1233271,74079,0.060067
499,500,3854,492,"[0.0602, 0.0635, 0.0014, 0.8749]","[0.0001, 0.0001, 0.0, 0.9998]","{('A', 'B'): (0.11086945251522967, 0.016251977...","{('A', 'B'): (0.07288942529218133, 0.131279583...","{'ks': {('A', 'B'): 0.8160000000000001, ('A', ...",{'ks': {}},1237125,74571,0.060278
500,501,4290,542,"[0.0262, 0.0, 0.027, 0.9468]","[0.0, 0.0002, 0.0, 0.9998]","{('A', 'B'): (0.09511148140999033, 0.013169294...","{('A', 'B'): (0.11564150780160094, 0.075862587...","{'ks': {('A', 'B'): 0.91, ('A', 'C'): 0.942, (...",{'ks': {}},1241415,75113,0.060506


# Plotting

In [5]:
def extract_period(df: Dict[str, pd.DataFrame], period: int) -> Dict[str, pd.DataFrame]:
    """ Extract data for given period

    Args:
        df: Mapping of variant name -> that variant's log DataFrame, which must
            have a ``period`` column (e.g. ``bandit.agent.df_log``).
        period: Period to select.

    Returns:
        Mapping of variant name -> the rows of its log with ``period == period``.
        Variants without any such row (e.g. not live yet) are left out.
    """
    selected = {}
    for variant in df.keys():
        log = df[variant]
        in_period = log.period == period  # build the boolean mask once per variant
        if in_period.any():
            selected[variant] = log[in_period]
    return selected

In [6]:
# Impressions / Clicks over time: weighted (`*_w_sum`) counts per variant
df = bandit.agent.df_log.copy()  # shallow copy of the per-variant log mapping

p_data = []
for i, variant in enumerate(bandit.agent.variants):
    p_data += [plot.plot(x=df[variant].period, y=df[variant].n_impr_w_sum, color=i, opacity=0.4, name=f'impr. {variant}', showlegend=True),
               plot.plot(x=df[variant].period, y=df[variant].n_clicks_w_sum, color=i, opacity=0.7, name=f'clicks {variant}', showlegend=True)]
layout = plot.layout(title='Observations - impr. & clicks', x_label='time', y_label='#', theme='dark', width=1200, height=400)
# Keep the Figure object (Figure.show() returns None).
fig = go.Figure(data=p_data, layout=layout)
fig.show()
# Optional PNG export:
# layout['width'], layout['height'] = WIDTH_SAVE, HEIGHT_SAVE
# go.Figure(data=p_data, layout=layout).write_image('images/impr_clicks.png')

In [7]:
PERIOD = 300

df_T = extract_period(df=bandit.agent.df_log, period=PERIOD)

# Log rows at PERIOD for every variant that is live at that period.
# display() is explicit: a bare expression inside a loop only renders because of
# the InteractiveShell setting.
for variant in df_T:
    print(variant)
    display(df_T[variant])

# PNG exports go to images/, which may not exist on a fresh checkout.
os.makedirs('images', exist_ok=True)

# Click-Through-Rate - Beta distribution
x = np.linspace(0, 0.5, 1000)
p_data = [plot.plot(x=x, y=beta.pdf(x, df_T[variant].alpha.iloc[0], df_T[variant].beta.iloc[0]),
                    color=i, opacity=0.7, name=variant, showlegend=True)
          for i, variant in enumerate(df_T)]
layout = plot.layout(title=f'Beta distributions at T:{PERIOD}', x_label='Click-Through-Rate', y_label='pdf', theme='dark', width=1200, height=400)
layout['xaxis']['range'] = [0, 0.5]
fig = go.Figure(data=p_data, layout=layout)
fig.show()
# Re-render at the export size before writing the PNG.
layout['width'], layout['height'] = WIDTH_SAVE, HEIGHT_SAVE
go.Figure(data=p_data, layout=layout).write_image('images/bandit_beta_ab.png')

# Cost-per-Click - gamma distribution
# NOTE(review): the logged posterior CPC values are < 1, so most of this 0-50 range
# is empty at this resolution — consider narrowing it.
x = np.linspace(0, 50, 1000)
p_data = [plot.plot(x=x, y=gamma.pdf(x, a=df_T[variant].a.iloc[0], scale=df_T[variant].scale.iloc[0]),
                    color=i, opacity=0.7, name=variant, showlegend=True)
          for i, variant in enumerate(df_T)]
layout = plot.layout(title=f'Gamma distributions at T:{PERIOD}', x_label='Cost-per-Click', y_label='pdf', theme='dark', width=1200, height=400)
layout['xaxis']['range'] = [0, 50]
fig = go.Figure(data=p_data, layout=layout)
fig.show()

A


Unnamed: 0,period,n_impr,n_impr_w_sum,n_clicks,n_clicks_w_sum,cost,cost_sum,ctr,cpc,alpha,beta,a,scale,cost_w_sum
300,300,0,58.648193,0,6.988616,0.0,,0.119162,0.151759,8,53,2.060582,0.14309,1.060582


B


Unnamed: 0,period,n_impr,n_impr_w_sum,n_clicks,n_clicks_w_sum,cost,cost_sum,ctr,cpc,alpha,beta,a,scale,cost_w_sum
300,300,79,88.229468,26,27.626415,0.148221,,0.31312,0.044553,29,62,2.230842,0.036197,1.230842


C


Unnamed: 0,period,n_impr,n_impr_w_sum,n_clicks,n_clicks_w_sum,cost,cost_sum,ctr,cpc,alpha,beta,a,scale,cost_w_sum
200,300,0,30.561569,0,2.665648,0.0,,0.087222,0.391644,4,29,2.043984,0.375143,1.043984


D


Unnamed: 0,period,n_impr,n_impr_w_sum,n_clicks,n_clicks_w_sum,cost,cost_sum,ctr,cpc,alpha,beta,a,scale,cost_w_sum
100,300,4110,4473.855362,1609,1755.635351,4.11,,0.392421,0.013147,1757,2719,24.081071,0.00057,23.081071


In [8]:
# Regret over time
os.makedirs('images', exist_ok=True)  # target dir of the PNG export below
p_data = [plot.plot(x=bandit.df_metrics.period, y=bandit.df_metrics.regret, color=0, opacity=0.9, name='regret', showlegend=True)]
layout = plot.layout(title='Regret', x_label='periods', y_label='#', theme='dark', width=1200, height=400)
fig = go.Figure(data=p_data, layout=layout)
fig.show()
layout['width'], layout['height'] = WIDTH_SAVE, HEIGHT_SAVE
go.Figure(data=p_data, layout=layout).write_image('images/bandit_ab_regret.png')

# Regret - distribution over periods
# np.histogram returns len(hist) + 1 bin edges, so plot each count at its bin centre
# instead of pairing counts with the (one longer) edge array.
hist, bins = np.histogram(bandit.df_metrics.regret, bins=100)
bin_centres = (bins[:-1] + bins[1:]) / 2
p_data = [plot.plot(x=bin_centres, y=hist, color=0, opacity=0.6, name='regret', showlegend=True)]
layout = plot.layout(title='Regret - Distribution', x_label='regret', y_label='# periods', theme='dark', width=1200, height=400)
fig = go.Figure(data=p_data, layout=layout)
fig.show()

# Regret - empirical CDF: cumulative share of periods whose regret is <= the bin's
# right edge, so the cumulative counts belong at bins[1:], not at the left edges.
p_data = [plot.plot(x=bins[1:], y=np.cumsum(hist) / hist.sum(), color=0, opacity=0.6, name='regret', showlegend=True)]
layout = plot.layout(title='Regret - CDF', x_label='regret', y_label='P(regret <= x)', theme='dark', width=1200, height=400)
fig = go.Figure(data=p_data, layout=layout)
fig.show()

#### CTR

In [9]:
# P(A>B) - CTR
# map to dataframe, where each row is a period and each column is a variant
# (in the logged output each row sums to ~1 across variants)
df_p_ab = pd.DataFrame(bandit.df_metrics.P_ab_ctr.to_list(), columns=bandit.agent.variants)

p_data = [plot.plot(x=bandit.df_metrics.period, y=df_p_ab[variant], color=i, opacity=0.7, name=f'P - {variant}', showlegend=True)
          for i, variant in enumerate(df_p_ab.columns)]
layout = plot.layout(title='p_ab - CTR', x_label='periods', y_label='P', theme='dark', width=1200, height=400)
fig = go.Figure(data=p_data, layout=layout)
fig.show()

# Loss - CTR
# loss_ctr holds one dict per period, {(variant, variant): (loss, loss)}. Expanding it
# gives one row per period and one column per variant pair. Pairs without a value in a
# period are NaN and are plotted as (0, 0). Series.map replaces the deprecated
# DataFrame.applymap.
df_loss = pd.DataFrame(bandit.df_metrics.loss_ctr.to_list())
df_loss = df_loss.apply(lambda col: col.map(lambda x: (0, 0) if pd.isna(x) else x))

for pair in df_loss.columns:
    # NOTE(review): presumably element 0 / 1 of each tuple is the loss of choosing the
    # first / second variant of the pair — confirm against the metrics code.
    loss_0 = df_loss[pair].apply(lambda x: x[0])
    loss_1 = df_loss[pair].apply(lambda x: x[1])
    p_data = [plot.plot(x=bandit.df_metrics.period, y=loss_0, color=0, opacity=0.7, name=f'loss - {pair} - [0]', showlegend=True),
              plot.plot(x=bandit.df_metrics.period, y=loss_1, color=1, opacity=0.7, name=f'loss - {pair} - [1]', showlegend=True)]
    layout = plot.layout(title=f'loss - CTR - {pair}', x_label='periods', y_label='loss', theme='dark', width=1200, height=400)
    fig = go.Figure(data=p_data, layout=layout)
    fig.show()

#### CpC

In [10]:
# P(A>B) - CpC
# map to dataframe, where each row is a period and each column is a variant
df_p_ab = pd.DataFrame(bandit.df_metrics.P_ab_cpc.to_list(), columns=bandit.agent.variants)

p_data = [plot.plot(x=bandit.df_metrics.period, y=df_p_ab[variant], color=i, opacity=0.7, name=f'P - {variant}', showlegend=True)
          for i, variant in enumerate(df_p_ab.columns)]
layout = plot.layout(title='p_ab - CpC', x_label='periods', y_label='P', theme='dark', width=1200, height=400)
fig = go.Figure(data=p_data, layout=layout)
fig.show()

# Loss - CpC
# loss_cpc holds one dict per period, {(variant, variant): (loss, loss)}. Expanding it
# gives one row per period and one column per variant pair. Pairs without a value in a
# period are NaN and are plotted as (0, 0). Series.map replaces the deprecated
# DataFrame.applymap.
df_loss = pd.DataFrame(bandit.df_metrics.loss_cpc.to_list())
df_loss = df_loss.apply(lambda col: col.map(lambda x: (0, 0) if pd.isna(x) else x))

for pair in df_loss.columns:
    # NOTE(review): presumably element 0 / 1 of each tuple is the loss of choosing the
    # first / second variant of the pair — confirm against the metrics code.
    loss_0 = df_loss[pair].apply(lambda x: x[0])
    loss_1 = df_loss[pair].apply(lambda x: x[1])
    p_data = [plot.plot(x=bandit.df_metrics.period, y=loss_0, color=0, opacity=0.7, name=f'loss - {pair} - [0]', showlegend=True),
              plot.plot(x=bandit.df_metrics.period, y=loss_1, color=1, opacity=0.7, name=f'loss - {pair} - [1]', showlegend=True)]
    layout = plot.layout(title=f'loss - CpC - {pair}', x_label='periods', y_label='loss', theme='dark', width=1200, height=400)
    fig = go.Figure(data=p_data, layout=layout)
    fig.show()

<hr>

### Video

In [11]:
# Bandit - A/B-test - animated posteriors of the optimised metric
N_STEPS = bandit.df_metrics.shape[0]-1

metric = config['optimise_for']
if metric not in ('ctr', 'cpc'):
    # Previously an unknown metric silently produced a video with no curves.
    raise ValueError(f"config['optimise_for'] must be 'ctr' or 'cpc', got {metric!r}")
# P(a>b) column matching the optimised metric (P_ab_ctr / P_ab_cpc); the text overlay
# used to show the CTR probabilities even when optimising for CPC.
p_ab = bandit.df_metrics[f'P_ab_{metric}']

colormap = ['#ff0000', '#ff00ff', '#ffff00', '#00ff00']
video = Video(xlabel=metric.upper(), x_lim=0.5, y_lim=100, n_versions=4, colormap=colormap, txt_pos=0.5)

# The writer needs the output directory to exist.
os.makedirs(os.path.dirname(config['video']) or '.', exist_ok=True)
x = np.linspace(0, 1, 50000)  # loop-invariant evaluation grid for the pdfs

with video.writer.saving(video.fig, config['video'], 200):
    for period in tqdm(range(N_STEPS+1)):
        df_T = extract_period(df=bandit.agent.df_log, period=period)
        txt = 'Period: {}\n\nClicks  |  Impressions  |  P(a>b)\n'.format(period)
        for i, variant in enumerate(df_T):
            row = df_T[variant]
            if metric == 'ctr':
                pdf = beta.pdf(x, row.alpha.values[0], row.beta.values[0])
            else:  # 'cpc'
                pdf = gamma.pdf(x, a=row.a.values[0], scale=row.scale.values[0])
            video.plts[i].set_data(x, pdf)

            # NOTE(review): df_metrics is indexed by label here; the logged output shows its
            # index is period-1 (index 496 -> period 497), so this reads the metrics row of
            # period+1 — confirm that offset is intended.
            txt += '{}: {: >8.1f}  |  {: >8.1f}  |  {: >8.1f}\n'.format(variant,
                                                    row.n_clicks_w_sum.values[0],
                                                    row.n_impr_w_sum.values[0],
                                                    100*p_ab[period][i])

        txt += 'regret: {:.3f} '.format(bandit.df_metrics.regret_avg[period])
        video.txt_time.set_text(txt)
        video.writer.grab_frame(facecolor=video.fig.get_facecolor(), edgecolor='none')
print('Completed movie: {}'.format(config['video']))

100%|██████████| 501/501 [00:41<00:00, 12.00it/s]


Completed movie: video/bandit_abcd_ctr_fast.mp4
