# Developing an exploitative alternative to A/B testing

This is the code accompanying my [blog post](https://medium.com/@Lando-L/beyond-the-basics-reinforcement-learning-with-jax-part-ii-developing-an-exploitative-9423cb6b2fa5) on multi-armed bandits.

## Implementing the environment

In [None]:
import jax

# Numpy API with hardware acceleration and automatic differentiation
from jax import numpy as jnp

# Low level operators
from jax import lax

# API for working with pseudorandom number generators
from jax import random

In [None]:
# Seed of the pseudorandom number generator, fixed so that
# the experiment can be replicated
SEED = 42

# Total number of simulated visitors
NUM_VISITS = 10_000

# Expected click rate of each of the five variants
CLICK_RATES = [
    0.042,
    0.03,
    0.035,
    0.038,
    0.045,
]

In [None]:
def visit(state, timestep, click_rates, policy_fn, update_fn):
    """
    Runs one simulated user visit and returns the next experiment
    state together with the click outcome.

    state is a pair of the agent's parameters and a random number
    generator. policy_fn(params, timestep, rng) picks the variant and
    update_fn(params, variant, clicked) produces the new parameters.
    """

    agent_params, rng = state

    # Deriving three generators from the current one: the first is
    # carried over to the next visit, the second drives the policy,
    # and the third drives the simulated user
    keys = random.split(rng, num=3)
    carried_rng = keys[0]
    policy_rng = keys[1]
    user_rng = keys[2]

    # Asking the policy which variant to show for this timestep
    variant = policy_fn(agent_params, timestep, policy_rng)

    # The user clicks when a uniform draw falls below
    # the click rate of the shown variant
    draw = random.uniform(user_rng)
    clicked = draw < click_rates[variant]

    # Letting the agent learn from the shown variant and the outcome
    updated_params = update_fn(agent_params, variant, clicked)

    next_state = (updated_params, carried_rng)
    return next_state, clicked

## Implementing the policies

In [None]:
def action_value_init(num_variants):
    """
    Builds the starting action-value parameters: an int32 counter 'n'
    and a float32 estimate 'q' per variant, every entry set to one
    """

    shape = (num_variants,)
    counts = jnp.ones(shape, dtype=jnp.int32)
    estimates = jnp.ones(shape, dtype=jnp.float32)

    return dict(n=counts, q=estimates)

def action_value_update(params, variant, clicked):
    """
    Returns new action values after showing `variant` and observing
    whether the user clicked. The input parameters are left untouched.
    """

    counts, estimates = params['n'], params['q']

    # The reward is 1.0 for a click and 0.0 otherwise
    reward = clicked.astype(jnp.float32)

    # Step towards the observed reward, scaled by the selected
    # variant's counter as it was before this update
    step = (reward - estimates[variant]) / counts[variant]

    return dict(
        # One more observation for the selected variant
        n=counts.at[variant].add(1),

        # Moving the selected variant's estimate by the step
        q=estimates.at[variant].add(step),
    )

In [None]:
def epsilon_greedy_policy(params, timestep, rng, epsilon):
    """
    With probability epsilon picks a uniformly random variant,
    otherwise picks the variant with the highest action-value estimate.
    """

    estimates = params['q']

    # One generator for the explore/exploit decision,
    # one for the random pick itself
    gate_rng, pick_rng = random.split(rng)

    # Exploring whenever the uniform draw falls below epsilon
    should_explore = random.uniform(gate_rng) < epsilon

    return lax.cond(
        should_explore,
        # Explore: any variant, chosen uniformly at random
        lambda q, key: random.choice(key, jnp.arange(len(q))),
        # Exploit: the variant with the highest estimate
        lambda q, key: jnp.argmax(q),
        estimates,
        pick_rng,
    )

In [None]:
def boltzmann_policy(params, timestep, rng, tau):
    """
    Samples a variant with probability given by the softmax of the
    action-value estimates divided by tau
    """

    estimates = params['q']

    # Softmax over the tau-scaled estimates yields
    # one selection probability per variant
    probabilities = jax.nn.softmax(estimates / tau)

    variants = jnp.arange(len(estimates))

    return random.choice(rng, variants, p=probabilities)

In [None]:
def upper_confidence_bound_policy(params, timestep, rng, confidence):
    """
    Picks the variant whose action-value estimate plus exploration
    bonus is highest. The rng argument is not used.
    """

    counts, estimates = params['n'], params['q']

    # The bonus grows with the logarithm of the timestep and
    # shrinks with the variant's counter
    bonus = confidence * jnp.sqrt(jnp.log(timestep) / counts)

    return jnp.argmax(estimates + bonus)

In [None]:
def beta_init(num_variants):
    """
    Builds the starting beta-distribution hyperparameters:
    int32 arrays 'a' and 'b' with one entry per variant, all set to one
    """

    shape = (num_variants,)
    alpha = jnp.ones(shape, dtype=jnp.int32)
    beta = jnp.ones(shape, dtype=jnp.int32)

    return dict(a=alpha, b=beta)

def beta_update(params, variant, clicked):
    """
    Returns new beta-distribution hyperparameters after showing
    `variant`: 'a' grows by one on a click, 'b' grows by one otherwise
    """

    # A click counts towards the selected variant's alpha
    def on_click(current):
        return {'a': current['a'].at[variant].add(1), 'b': current['b']}

    # A missing click counts towards the selected variant's beta
    def on_no_click(current):
        return {'a': current['a'], 'b': current['b'].at[variant].add(1)}

    # Only the two hyperparameter arrays are handed to the branches
    hyperparams = {'a': params['a'], 'b': params['b']}

    return lax.cond(clicked, on_click, on_no_click, hyperparams)

In [None]:
def thompson_policy(params, timestep, rng):
    """
    Draws one beta-distributed sample per variant from the current
    hyperparameters and picks the variant with the highest sample
    """

    alpha, beta = params['a'], params['b']

    # One sampled click rate per variant
    sampled_rates = random.beta(rng, alpha, beta)

    return jnp.argmax(sampled_rates)

## Implementing the evaluation

In [None]:
from functools import partial
from matplotlib import pyplot as plt

In [None]:
def evaluate(policy_fn, init_fn, update_fn):
    """
    Runs the simulated experiment for NUM_VISITS visits and returns
    the result of the scan: the final experiment state (parameters
    and random number generator) and the per-visit click outcomes
    """

    # Fixing everything but the state and the timestep of a visit,
    # and compiling the step as well as the given policy and update
    # functions with just-in-time (JIT) compilation
    step_fn = jax.jit(
        partial(
            visit,
            click_rates=jnp.array(CLICK_RATES),
            policy_fn=jax.jit(policy_fn),
            update_fn=jax.jit(update_fn),
        )
    )

    # The experiment starts from freshly initialised parameters
    # and a PRNG key derived from SEED
    initial_state = (init_fn(len(CLICK_RATES)), random.PRNGKey(SEED))

    # Timesteps are numbered from 1 to NUM_VISITS inclusive
    timesteps = jnp.arange(1, NUM_VISITS + 1)

    return lax.scan(step_fn, initial_state, timesteps)

In [None]:
def regret(history, optimal_rate=0.045):
    """
    Calculates the regret after every visit in the experiment history.

    The regret after visit t is the optimal click rate minus the
    average click rate observed over the first t visits.

    history is the sequence of per-visit click outcomes.
    optimal_rate is the click rate of the best variant; it defaults
    to 0.045 and should be overridden whenever the best variant's
    click rate is different.

    Returns an array with one regret value per visit.
    """

    rewards = jnp.asarray(history)

    # Number of visits seen so far: 1, 2, ..., len(history)
    visits = jnp.arange(1, rewards.shape[0] + 1)

    # Running average click rate: cumulative clicks over visits so far
    running_click_rate = jnp.cumsum(rewards) / visits

    return optimal_rate - running_click_rate

In [None]:
# Epsilon greedy policy: explores with epsilon = 0.1
(epsilon_greedy_params, _), epsilon_greedy_history = evaluate(
    partial(epsilon_greedy_policy, epsilon=0.1),
    action_value_init,
    action_value_update,
)

# Boltzmann policy: softmax over action values with tau = 1.0
(boltzmann_params, _), boltzmann_history = evaluate(
    partial(boltzmann_policy, tau=1.0),
    action_value_init,
    action_value_update,
)

# Upper confidence bound policy: confidence = 2
(ucb_params, _), ucb_history = evaluate(
    partial(upper_confidence_bound_policy, confidence=2),
    action_value_init,
    action_value_update,
)

# Thompson sampling policy: uses the beta hyperparameters
(ts_params, _), ts_history = evaluate(
    thompson_policy,
    beta_init,
    beta_update,
)

In [None]:
# Visualisation: regret of every policy over the number of visits
fig, ax = plt.subplots(figsize=(16, 8))

x = jnp.arange(1, NUM_VISITS + 1)

ax.set_xlabel('Number of visits')
ax.set_ylabel('Regret')

# The A/B testing baseline is modelled by the mean click rate across
# all variants. To be comparable with the regret curves below, it has
# to be plotted as a regret too: the best click rate minus that mean
# (plotting the mean click rate itself would mix two different scales)
ab_testing_regret = jnp.max(jnp.array(CLICK_RATES)) - jnp.mean(jnp.array(CLICK_RATES))

ax.plot(x, jnp.repeat(ab_testing_regret, NUM_VISITS), label='A/B Testing')
ax.plot(x, regret(epsilon_greedy_history), label='Epsilon Greedy Policy')
ax.plot(x, regret(boltzmann_history), label='Boltzmann Policy')
ax.plot(x, regret(ucb_history), label='UCB Policy')
ax.plot(x, regret(ts_history), label='TS Policy')

plt.legend()
plt.show()