# Bandit Optimization

In this notebook we use `boax` to run a multi-armed bandit experiment.

In [1]:
import random

from jax import config
from matplotlib import pyplot as plt

config.update("jax_enable_x64", True)
plt.style.use('bmh')

from boax.experiments import bandit

In [2]:
# True click-through rate of each variant; the list index is the variant id
# that `objective` looks up.
CLICK_RATES = [0.042, 0.03, 0.035, 0.038, 0.045]

# Seed Python's `random` module (the source of the simulated clicks) so the
# click sequence is the same on every Restart & Run All.
SEED = 0
random.seed(SEED)

In [3]:
def objective(variant):
    """Simulate a single impression of `variant`.

    Draws one uniform random number and compares it against
    `CLICK_RATES[variant]`.

    Returns:
        1.0 if the draw falls below the variant's click rate (a click),
        otherwise 0.0.
    """
    clicked = random.random() < CLICK_RATES[variant]
    return 1.0 if clicked else 0.0

In [4]:
# The only parameter: a categorical choice among the five variant ids
variant_parameter = {
    'name': 'variant',
    'type': 'choice',
    'values': [0, 1, 2, 3, 4],
}

# Bandit experiment over that parameter, using the upper-confidence-bound method
experiment = bandit(
    parameters=[variant_parameter],
    method='upper_confidence_bound',
)

In [5]:
# Starting state for the loop: no previous step and no results to report yet
step = None
results = []

In [6]:
for i in range(10_000):
    # One progress dot per 1,000 iterations
    if not i % 1_000:
        print('.', end='')

    # Hand back the previous step and its results; receive the new step and
    # the parameterizations to evaluate next
    step, parameterizations = experiment.next(step, results)

    # Score every proposed parameterization with the simulated objective
    evaluations = [objective(candidate['variant']) for candidate in parameterizations]

    # Pair each parameterization with its evaluation for the next call
    results = list(zip(parameterizations, evaluations))

..........

In [7]:
# Predicted best: what the experiment reports as best after the final step
predicted_best = experiment.best(step)
predicted_best

({'variant': 2}, 0.04068119078874588)

In [8]:
# Actual best: the variant with the highest true click rate, derived from
# CLICK_RATES so this cell stays correct if the rates are edited
best_variant = max(range(len(CLICK_RATES)), key=CLICK_RATES.__getitem__)
{'variant': best_variant}, CLICK_RATES[best_variant]

({'variant': 4}, 0.045)