In [17]:
import numpy as np
from scipy.stats import bernoulli

In [15]:
class BernoulliBandit:
    """Multi-armed bandit whose arms pay Bernoulli (0/1) rewards.

    Attributes:
        means: per-arm success probabilities, as passed to the constructor.
        k: number of arms.
        best_mean: the largest value in ``means``.
        total_rounds: number of successful calls to ``pull`` so far.
        regret: accumulated gap between ``best_mean`` and the mean of each
            arm that was pulled (0 while only a best arm has been played).
    """

    def __init__(self, means):
        """Create a bandit with one arm per entry of ``means``.

        Args:
            means: sequence of success probabilities, each in [0, 1].

        Raises:
            AssertionError: if any mean lies outside [0, 1].
        """
        for mean in means:
            assert 0 <= mean <= 1, "every arm mean must lie in [0, 1]"
        self.means = means
        self.k = len(means)
        self.best_mean = max(means)
        self.total_rounds = 0
        self.regret = 0

    def get_K(self):
        """Return the number of arms."""
        return self.k

    def pull(self, a):
        """Pull arm ``a`` and return the sampled reward.

        Args:
            a: arm index with 0 <= a <= K - 1.

        Returns:
            The result of ``bernoulli.rvs(self.means[a], size=1)``: a
            realisation of a random variable X with P(X = 1) equal to the
            mean of the (a + 1)-th arm.

        Raises:
            AssertionError: if ``a`` is not a valid arm index.
        """
        # Chained comparison with a strict upper bound: the previous
        # `a <= self.k & a >= 0` applied bitwise `&` before the comparisons
        # and also let a == k through.
        assert 0 <= a < self.k, "arm index must satisfy 0 <= a <= K - 1"
        self.total_rounds += 1
        # Regret grows by the gap to the best arm (zero for a best arm),
        # not by the pulled arm's own mean.
        self.regret += self.best_mean - self.means[a]
        return bernoulli.rvs(self.means[a], size=1)

    def get_regret(self):
        """Return the regret accumulated over all pulls so far."""
        return self.regret

In [18]:
def follow_the_leader(n, bandit):
    """Play ``bandit`` for ``n`` rounds with the Follow-the-Leader rule.

    Each of the first ``bandit.k`` rounds pulls a different arm, in index
    order. Every later round pulls the arm whose running average reward is
    currently the largest (``np.argmax`` picks the lowest index on ties).

    Args:
        n: number of rounds to play.
        bandit: object exposing ``k`` (number of arms) and ``pull(a)``,
            which must return a single reward (a scalar or a size-1 array).

    Returns:
        Float array of length ``n`` holding the arm index chosen each round.
    """
    actions = np.zeros(shape=n)
    # Column 0: running average reward per arm; column 1: number of pulls.
    averages = np.zeros((bandit.k, 2))
    for t in range(n):
        if t < bandit.k:
            arm = t  # initial pass: pull the arms in order
        else:
            arm = int(np.argmax(averages[:, 0]))
        # pull() may hand back a size-1 array; take the scalar explicitly,
        # because storing such an array into a single element relies on an
        # implicit array-to-scalar conversion that NumPy has deprecated.
        reward = np.asarray(bandit.pull(arm)).item()
        pulls = averages[arm, 1]
        # Incremental mean; on an arm's first pull this reduces to `reward`.
        averages[arm, 0] = (averages[arm, 0] * pulls + reward) / (pulls + 1)
        averages[arm, 1] = pulls + 1
        actions[t] = arm
    return actions

# Experiment 1: three arms with closely spaced means (0.55, 0.5, 0.6).
# No seed is fixed in the visible cells, so the printed action sequence
# can differ from run to run.
bandit = BernoulliBandit(means=[0.55, 0.5, 0.6])
result = follow_the_leader(bandit=bandit, n=200)
print(result)

[0. 1. 2. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1.]


In [19]:
# Experiment 2: three arms with well-separated means (0.2, 0.5, 0.9).
# No seed is fixed in the visible cells, so the printed action sequence
# can differ from run to run.
bandit = BernoulliBandit(means=[0.2, 0.5, 0.9])
result = follow_the_leader(bandit=bandit, n=200)
print(result)

[0. 1. 2. 1. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2.
 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2.
 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2.
 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2.
 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2.
 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2.
 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2.
 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2. 2.
 2. 2. 2. 2. 2. 2. 2. 2.]
