In [4]:
import matplotlib.pyplot as plt
import numpy as np

In [6]:
# A Python class which implements the greedy algorithm on the multi armed bandit problem.

class greedy:
    """Purely greedy (exploit-only) agent for the multi-armed bandit problem.

    Each arm's expected reward is estimated by the sample mean of the rewards
    observed from it. Every pull goes to the arm with the highest current
    estimate (ties go to the lowest index, as ``np.argmax`` returns the first
    maximum). There is no exploration step.

    Parameters
    ----------
    n_arm : int
        Number of arms; must equal ``len(reward_fn)``.
    reward_fn : sequence of callables
        ``reward_fn[i]()`` draws one reward from arm ``i``.

    Raises
    ------
    ValueError
        If ``n_arm`` does not match the number of reward functions.
    """

    def __init__(self, n_arm, reward_fn):
        if n_arm != len(reward_fn):
            raise ValueError(
                f"n_arm ({n_arm}) must equal the number of reward functions "
                f"({len(reward_fn)})"
            )
        # Number of arms in the bandit.
        self.n_arm = n_arm
        # One zero-argument reward sampler per arm.
        self.reward_fn = reward_fn
        # Sample-mean reward estimate per arm. Kept as float so integer rewards
        # (e.g. Bernoulli 0/1) are not truncated by the running-mean update.
        self.arm_avg = np.zeros(n_arm, dtype=float)
        # Number of observed rewards behind each estimate (none yet).
        self.num_arm = np.zeros(n_arm)
        # Total number of pulls requested through update().
        self.its = 0

    def initialize(self):
        """Pull every arm once and use that single reward as its estimate.

        Resets both the estimates and the observation counts, so calling it
        again restarts the agent from exactly one observation per arm.
        """
        self.arm_avg = np.array([reward() for reward in self.reward_fn], dtype=float)
        self.num_arm = np.ones(self.n_arm)

    def update(self, its):
        """Make ``its`` greedy pulls, updating the chosen arm's sample mean.

        Raises
        ------
        ValueError
            If ``its`` is negative.
        """
        if its < 0:
            raise ValueError(f"its must be non-negative, got {its}")
        self.its += its
        for _ in range(its):
            # Exploit: pull the arm with the best current estimate.
            arm = self.best_arm()
            reward = self.reward_fn[arm]()
            self.num_arm[arm] += 1
            # Incremental sample mean: avg_n = avg_{n-1} + (r - avg_{n-1}) / n
            self.arm_avg[arm] += (reward - self.arm_avg[arm]) / self.num_arm[arm]

    def best_arm(self):
        """Return the index (plain ``int``) of the arm with the highest estimate."""
        return int(np.argmax(self.arm_avg))
        
        

In [7]:
# Fix the global NumPy seed so the reward draws, and therefore the greedy
# agent's choices, are reproducible under Restart & Run All.
SEED = 42
np.random.seed(SEED)

# True mean reward of each arm; every arm draws unit-variance Gaussian rewards.
ARM_MEANS = [0, 2, 3, 4, 5]

# One zero-argument reward sampler per arm. `mu=mu` binds each mean when the
# lambda is created (a plain closure would make every lambda see the last mean).
fns = [lambda mu=mu: np.random.randn() + mu for mu in ARM_MEANS]

In [13]:
# Derive the arm count from the reward functions so the two cannot drift apart.
grd = greedy(len(fns), fns)

In [14]:
# Number of greedy pulls made after the one-pull-per-arm initialization.
N_STEPS = 200

grd.initialize()
grd.update(N_STEPS)
# Index of the arm the agent currently rates best, shown as a plain int.
int(grd.best_arm())

4


In [15]:
# Arm indices are 0-based: index 4 is the fifth arm, whose true mean (5) is the
# largest among the reward functions defined in `fns`.
' The arm with mean = 5 and std 1(the fifth arm) is the best arm'

' The arm with mean = 5 and std 1(the fifth arm) is the best arm'