In [4]:
import numpy as np

In [5]:
"""Python class which implements the epsilon greedy algorithm on the multi armed bandit problem"""
class epsilon_greedy:
    """Epsilon-greedy agent for the multi-armed bandit problem.

    On each step a uniform random number in [0, 1) is drawn. When it is
    below ``eps`` a uniformly random arm is pulled (exploration); otherwise
    the arm with the largest running mean reward is pulled (exploitation).

    Parameters
    ----------
    n_arm : int
        Number of arms of the bandit.
    reward_fn : indexable collection of callables
        ``reward_fn[k]()`` is called with no arguments to obtain one reward
        sample for arm ``k``.
    eps : float, optional
        Exploration threshold compared against the uniform draw
        (default 0.1).

    Attributes
    ----------
    arm_avg : numpy.ndarray
        Running mean reward per arm, initialised to zeros.
    num_arm : numpy.ndarray
        Number of pulls per arm (stored as floats), initialised to zeros.
    its : int
        Sum of all iteration counts passed to ``update`` so far.
    """

    def __init__(self,n_arm,reward_fn,eps=0.1):
        """Store the bandit description and reset all running statistics."""
        # Problem definition
        self.n_arm = n_arm
        self.reward_fn = reward_fn
        self.eps = eps
        # Running statistics: one slot per arm, all starting at zero
        self.arm_avg = np.zeros(n_arm)
        self.num_arm = np.zeros(n_arm)
        # Total number of iterations requested so far
        self.its = 0

    def update(self,its):
        """Play ``its`` rounds, refining the per-arm mean reward estimates.

        Each round chooses an arm (explore or exploit), samples its reward
        function once, and folds the reward into that arm's running mean.
        """
        self.its += its
        for _ in range(its):
            # One uniform draw decides between exploring and exploiting
            explore = np.random.uniform(0,1,1)[0] < self.eps
            if explore:
                arm = np.random.choice(self.n_arm,1)[0]
            else:
                arm = self.best_arm()

            # Pull the lever of the selected arm
            reward = self.reward_fn[arm]()

            # Rebuild the reward total from mean * count, add the new reward,
            # and divide by the new pull count to get the updated mean
            pulls = self.num_arm[arm]
            total = self.arm_avg[arm] * pulls + reward
            self.arm_avg[arm] = total/(pulls + 1.0)
            self.num_arm[arm] = pulls + 1

    def best_arm(self):
        """Return the index of the arm with the highest mean reward so far."""
        return np.argmax(self.arm_avg)

    def get_arm_avg(self):
        """Return the array of running mean rewards, one entry per arm."""
        return self.arm_avg
        

In [6]:
# One zero-argument reward sampler per arm: a standard-normal draw shifted by
# the arm's mean (0, 2, 3, 4 and 5 respectively), so the last arm is the best.
fns = [
    lambda: np.random.standard_normal(),
    lambda: 2 + np.random.standard_normal(),
    lambda: 3 + np.random.standard_normal(),
    lambda: 4 + np.random.standard_normal(),
    lambda: 5 + np.random.standard_normal(),
]

In [7]:
# Build the agent with one arm per reward sampler and a 10% exploration rate
eps_grd = epsilon_greedy(n_arm=len(fns), reward_fn=fns, eps=0.1)

In [8]:
# Play 10000 rounds so the per-arm mean estimates can settle
eps_grd.update(its=10000)

In [9]:
# Index of the arm whose estimated mean reward is currently the largest
eps_grd.best_arm()

4

In [10]:
# Estimated mean reward of every arm after the rounds played above
eps_grd.get_arm_avg()

array([0.0871986 , 1.88698143, 2.92570556, 4.05134568, 5.00976063])