In [None]:
import numpy as np
import plotly.express as px

In [None]:
class BernoulliBandit:
  """A k-armed bandit whose arms pay out with fixed, hidden probabilities.

  Each arm's success probability is drawn once, uniformly from [0, 1), using
  NumPy's global random state when the bandit is constructed.
  """

  def __init__(self, k: int):
    # One hidden success probability per arm, drawn in arm order.
    self.arms = [np.random.random() for _ in range(k)]

  def pull(self, arm: int) -> bool:
    """Pull `arm` once; True (a win) with probability self.arms[arm]."""
    return np.random.random() < self.arms[arm]

  def __str__(self):
    # Show the true arm probabilities, e.g. "[0.12, 0.87, ...]".
    return str(self.arms)


In [None]:
# Seed NumPy's global RNG so the bandit's hidden arm probabilities and every
# experiment below are reproducible on Restart Kernel -> Run All.
SEED = 42
N_ARMS = 10
np.random.seed(SEED)
bernoulli_bandit = BernoulliBandit(N_ARMS)

In [None]:
class BetaDistribution:
  """Parameters (a, b) of a Beta distribution used as a per-arm posterior."""

  def __init__(self, a: float, b: float):
    self.a = a  # pseudo-count of successes (+1 per observed win)
    self.b = b  # pseudo-count of failures (+1 per observed loss)


class Agent:
  """Thompson-sampling agent for a k-armed Bernoulli bandit.

  Each arm's success probability is modelled with a Beta(a, b) posterior,
  starting from the uniform prior Beta(1, 1).
  """

  def __init__(self, k: int):
    """Create an agent with `k` arms, each with a Beta(1, 1) prior."""
    self.estimates = [BetaDistribution(a=1, b=1) for _ in range(k)]
    # -1 means "no action chosen yet"; observe() refuses to run in that state.
    self.last_action = -1

  def choose(self) -> int:
    """Draw one sample from each arm's posterior and pick the largest.

    Ties are broken in favour of the lowest arm index.

    Returns:
      Index of the chosen arm (also stored in ``self.last_action``).

    Raises:
      ValueError: if the agent has no arms.
    """
    if not self.estimates:
      raise ValueError("Agent has no arms to choose from")

    max_i = 0
    # Start below any possible Beta draw so the first sample always registers.
    max_v = -np.inf
    for i, beta_distribution in enumerate(self.estimates):
      v = np.random.beta(beta_distribution.a, beta_distribution.b)
      if v > max_v:
        max_v = v
        max_i = i

    self.last_action = max_i
    return max_i

  def observe(self, reward: bool):
    """Update the posterior of the most recently chosen arm.

    Args:
      reward: True for a win (increments ``a``), False for a loss
        (increments ``b``).

    Raises:
      RuntimeError: if called before any call to ``choose()``; otherwise
        ``self.estimates[-1]`` (the last arm) would be silently updated.
    """
    if self.last_action < 0:
      raise RuntimeError("observe() called before choose()")
    estimate = self.estimates[self.last_action]
    if reward:
      estimate.a += 1
    else:
      estimate.b += 1

In [None]:
print(bernoulli_bandit)

# For a short, medium and long run, start a fresh agent and plot how often
# it pulls each arm.
for n in (10, 100, 1000):
  agent = Agent(len(bernoulli_bandit.arms))
  arms_chosen = []
  for i in range(n):
    arm = agent.choose()
    reward = bernoulli_bandit.pull(arm)
    arms_chosen.append(arm)
    agent.observe(reward)
  px.histogram(
    arms_chosen,
    labels={"value": "arm #"},
    title=f"{n} iterations",
  ).show()


In [None]:
print(bernoulli_bandit)

# Index of the arm with the highest true success probability.
best_arm = int(np.argmax(bernoulli_bandit.arms))

# For each horizon n, run a fresh agent for n pulls and record, as percentages,
# how often it picked the best arm and how often it was rewarded.
# NOTE: the horizons add up to 505,000 pulls in total, so expect this cell to
# take a while.
best_arm_percentage = {}
win_percentage = {}
for n in range(100, 10100, 100):
  agent = Agent(len(bernoulli_bandit.arms))
  best_arm_chosen_times = 0
  wins = 0
  for i in range(n):
    arm = agent.choose()
    if arm == best_arm:
      best_arm_chosen_times += 1
    reward = bernoulli_bandit.pull(arm)
    if reward:
      wins += 1
    agent.observe(reward)

  # Scale to 0-100 so the values match the "(%)" axis labels below.
  best_arm_percentage[n] = 100.0 * best_arm_chosen_times / n
  win_percentage[n] = 100.0 * wins / n

px.line(
  x=list(best_arm_percentage.keys()),
  y=list(best_arm_percentage.values()),
  labels={"x": "Iterations", "y": "Best arm chosen (%)"},
).show()
px.line(
  x=list(win_percentage.keys()),
  y=list(win_percentage.values()),
  labels={"x": "Iterations", "y": "Wins (%)"},
).show()
