# Stochastic Bernoulli Bandit

In [None]:
import numpy as np
from rich import print

from pybandits.model import Beta
from pybandits.smab import SmabBernoulli

In [None]:
# Display floats in notebook outputs with 2 decimal places.
# This is an IPython line magic (not plain Python); it only works inside IPython/Jupyter.
# NOTE(review): %precision configures IPython's display formatter, so it presumably affects
# bare-expression outputs such as `probs` below but not output produced by rich's `print(...)` — confirm.
%precision %.2f

## 1. Initialization
The following two options are available to initialize the bandit.

### 1.1 Initialize via class constructor

You can initialize the bandit via the class constructor `SmabBernoulli()`. This is useful to encode prior knowledge in the Beta distributions.

In [None]:
# Prior knowledge is expressed per action as the Beta parameters `n_successes` and
# `n_failures`. The values below are illustrative:
#   - "a1" has rarely succeeded in the past;
#   - "a2" has succeeded about half of the time;
#   - "a3" has no history and keeps the uninformative Beta(1, 1) prior,
#     which is what cold start (section 1.2) uses for every action.
mab = SmabBernoulli(
    actions={
        "a1": Beta(n_successes=2, n_failures=8),
        "a2": Beta(n_successes=5, n_failures=5),
        "a3": Beta(n_successes=1, n_failures=1),
    }
)

In [None]:
# Show the bandit built via the constructor (rendered by rich's `print`, imported above).
print(mab)

### 1.2 Initialize via utility function (for cold start)

You can initialize the bandit via the utility function `SmabBernoulli.cold_start()`. This is particularly useful in a cold start setting, when there is no prior knowledge on the Beta distributions. In this case, `n_successes` and `n_failures` are set to `1` for all Betas.

In [None]:
# Cold start: with no prior knowledge, every listed action starts from Beta(1, 1).
mab = SmabBernoulli.cold_start(
    action_ids=["a1", "a2", "a3"],
)

In [None]:
# Show the cold-start bandit (rendered by rich's `print`, imported above).
print(mab)

## 2. Function `predict()`

In [None]:
# Show the signature and docstring of `predict()`.
help(mab.predict)

In [None]:
# Ask the bandit for a batch of 5 predictions; `predict` returns the chosen
# actions together with their associated probabilities.
actions, probs = mab.predict(
    n_samples=5,
)

In [None]:
# Display the predicted actions (the cell's last expression is rendered as its output).
actions

In [None]:
# Display the second value returned by `predict()`.
# NOTE(review): presumably the per-sample action probabilities — confirm in `help(mab.predict)` above.
probs

In [None]:
# Another batch of 5 predictions, this time with `a1` listed in `forbidden_actions`:
# it is excluded from the candidates, so it will never be predicted.
actions, probs = mab.predict(
    n_samples=5,
    forbidden_actions=["a1"],
)

In [None]:
# Display the predicted actions; `a1` should not appear since it was forbidden above.
actions

In [None]:
# Display the second value returned by the forbidden-actions `predict()` call.
probs

## 3. Function `update()`

In [None]:
# Show the signature and docstring of `update()`.
help(mab.update)

In [None]:
# Binary (0/1) rewards observed from a simulated environment,
# one for each of the 5 actions predicted in the previous cells.
rewards = [
    1,
    0,
    1,
    1,
    0,
]

In [None]:
# Feed each predicted action and its observed reward back to the bandit,
# then inspect the updated model.
mab.update(
    actions=actions,
    rewards=rewards,
)
print(mab)

## 4. Example of usage

Simulate 10 updates: for each update, we predict actions for a batch of 1000 samples and then update the bandit with the rewards received for those actions.

In [None]:
n_updates = 10
batch_size = 1000

# Simulated environment: each action has its own probability of yielding a
# reward of 1. The bandit does not know these values; it has to learn them.
# With distinct probabilities there is a best arm to discover, so after the
# loop "a3" is expected to be predicted most often. With rewards drawn
# independently of the action, every arm would look the same.
true_reward_probs = {"a1": 0.2, "a2": 0.5, "a3": 0.8}

# Seeded generator so the simulation is reproducible across runs.
rng = np.random.default_rng(seed=42)

for _ in range(n_updates):
    # predict
    actions, _ = mab.predict(n_samples=batch_size)

    # simulate rewards from the environment: reward ~ Bernoulli(p[action]).
    # Assumes every predicted action is a key of `true_reward_probs`
    # (true for the cold-start bandit with actions a1, a2, a3 defined above).
    reward_probs = np.array([true_reward_probs[action] for action in actions])
    rewards = rng.binomial(n=1, p=reward_probs).tolist()

    # update
    mab.update(actions=actions, rewards=rewards)

In [None]:
# Show the bandit after the simulated updates (rendered by rich's `print`).
print(mab)