In [44]:
import numpy as np
from scipy.stats import bernoulli

class Bernoulli_Bandit:
    """A K-armed stochastic bandit whose arms yield Bernoulli (0/1) rewards.

    Arm ``a`` returns 1 with probability ``means[a]`` and 0 otherwise. Across
    all calls to ``pull`` the bandit tracks two cumulative regret measures:

    * random regret:        sum of (best_mean - observed reward)
    * pseudo-random regret: sum of (best_mean - mean of the pulled arm)
    """

    # Accept a list of K >= 2 floats, each lying in [0,1]
    def __init__(self, means, random_state=None):
        """Create the bandit.

        Parameters
        ----------
        means : sequence of float
            Success probability of each arm; every value must lie in [0, 1].
            Intended to hold K >= 2 arms.
        random_state : None, int or numpy.random.Generator, optional
            Randomness source for ``pull``. ``None`` (default) draws from
            NumPy's global random state; an int seeds a private generator so
            runs are reproducible.

        Raises
        ------
        ValueError
            If ``means`` is empty or any mean lies outside [0, 1].
        """
        # Copy so that later mutation of the caller's list cannot
        # desynchronise bandit_means from best_mean.
        self.bandit_means = list(means)
        if not self.bandit_means:
            raise ValueError("means must contain at least one arm")
        # Written as `not (0 <= p <= 1)` so NaN is rejected as well.
        if any(not (0.0 <= p <= 1.0) for p in self.bandit_means):
            raise ValueError("every arm mean must lie in [0, 1]")
        self.best_mean = max(self.bandit_means)
        # Private names: public attributes called rand_regret /
        # pseudo_rand_regret would shadow the accessor methods below.
        self._rand_regret = 0.0
        self._pseudo_rand_regret = 0.0
        self._rng = None if random_state is None else np.random.default_rng(random_state)

    def K(self):
        """Return the number of bandit arms."""
        return len(self.bandit_means)

    def pull(self, a):
        """Pull arm ``a`` (0 <= a <= K-1) and return its Bernoulli reward.

        X is drawn with P(X=1) equal to the mean of the (a+1)-th arm, and both
        regret totals are updated.

        Returns
        -------
        numpy.ndarray
            Array of shape (1,) holding the sampled reward (0 or 1).
        """
        selected_arm_mean = self.bandit_means[a]
        selected_arm_sample = bernoulli.rvs(selected_arm_mean, size=1, random_state=self._rng)
        # Extract a plain scalar so the regret totals stay floats, not arrays.
        reward = int(selected_arm_sample[0])
        # Random regret depends on the realised reward.
        self._rand_regret += self.best_mean - reward
        # Pseudo regret uses the arm's expected reward, so it is noise-free.
        self._pseudo_rand_regret += self.best_mean - selected_arm_mean
        return selected_arm_sample

    def rand_regret(self):
        """Return the cumulative random regret incurred so far."""
        return self._rand_regret

    def pseudo_rand_regret(self):
        """Return the cumulative pseudo-random (expected) regret so far."""
        return self._pseudo_rand_regret

In [45]:
def greedy_algorithm(bandit, n, rng=None):
    """Play ``n`` rounds of explore-once-then-commit greedy on ``bandit``.

    Each arm is pulled once (exploration). The arm with the highest observed
    reward is then pulled for every remaining round (exploitation). The total
    number of pulls never exceeds ``n``: if ``n < bandit.K()``, only the first
    ``n`` arms are explored.

    Parameters
    ----------
    bandit : object
        Must provide ``K()`` (number of arms) and ``pull(a)``, which returns
        the reward of arm ``a`` as a scalar or a size-1 array.
    n : int
        Horizon, i.e. the total number of pulls allowed.
    rng : numpy.random.Generator, optional
        If given, ties for the best observed reward are broken uniformly at
        random with this generator. If ``None`` (default), the lowest-index
        tied arm is chosen, as ``np.argmax`` does.

    Returns
    -------
    int or None
        Index of the arm committed to, or ``None`` if ``n <= 0`` (no pulls).
    """
    num_explore = max(0, min(n, bandit.K()))
    if num_explore == 0:
        return None

    # Exploration: one pull per arm. np.ravel(...)[0] accepts both scalar and
    # size-1 array rewards without relying on array-to-scalar coercion.
    arm_values = np.array(
        [np.ravel(bandit.pull(i))[0] for i in range(num_explore)], dtype=float
    )

    if rng is None:
        best_arm = int(np.argmax(arm_values))
    else:
        # With 0/1 rewards, several arms often tie after a single pull each.
        tied_arms = np.flatnonzero(arm_values == arm_values.max())
        best_arm = int(rng.choice(tied_arms))

    # Exploitation: commit to best_arm for the rest of the horizon.
    for _ in range(n - num_explore):
        bandit.pull(best_arm)
    return best_arm
        

In [46]:
def main():
    """Simulate the greedy strategy for 100 rounds on a four-armed Bernoulli bandit.

    The arms succeed with probabilities 0.5, 0.65, 0.1 and 0.25. Nothing is
    returned or printed; the bandit object is discarded after the run.
    """
    means = [0.5, 0.65, 0.1, 0.25]
    horizon = 100
    bandit = Bernoulli_Bandit(means)
    greedy_algorithm(bandit, horizon)


if __name__ == "__main__":
    main()

In [10]:
# Sanity check: with ties, np.argmax reports the FIRST index of the maximum
# (here the value 4 occurs at indices 3 and 5, and 3 is returned).
a = np.array([1, 2, 3, 4, 2, 4, 1])
first_max_index = np.argmax(a)
print(first_max_index)

3
