In [1]:
import numpy as np

In [19]:
"""Python class which implements UCB algorithm on multi armed bandit problem."""
class UCB:
    """Upper Confidence Bound (UCB) agent for a multi-armed bandit.

    Each round the agent pulls the arm maximizing
    ``arm_avg + c * sqrt(log(t) / num_arm)``, where ``t`` is the number of
    rounds played so far (including the current one) and ``num_arm`` is the
    per-arm pull count.

    Attributes:
        n_arm: Number of arms.
        arm_avg: 1-D float array with the running mean reward of each arm.
        num_arm: 1-D float array with the number of times each arm was pulled.
            Starts at one per arm, matching the single pull per arm made by
            ``initialize``.
        c: Exploration coefficient; larger values explore more.
        reward_fn: Sequence of zero-argument callables, one per arm, each
            returning a numeric reward when called.
        its: Number of rounds played through ``update`` so far.
    """

    def __init__(self, n_arm, reward_fn, c):
        """Create an agent with zeroed estimates.

        Args:
            n_arm: Number of arms in the bandit.
            reward_fn: Sequence of zero-argument callables; ``reward_fn[k]()``
                returns one reward sample from arm ``k``.
            c: Exploration coefficient used in the confidence bonus.
        """
        # The number of arms in the multi armed bandit problem
        self.n_arm = n_arm
        # The mean reward of the distinct arms
        self.arm_avg = np.zeros(n_arm)
        # The number of times the various arms were chosen. Starting from one
        # keeps the confidence bonus finite; if initialize() is never called
        # the zero mean above effectively counts as one observed reward of 0.
        self.num_arm = np.ones(n_arm)
        # The number which controls the degree of exploration
        self.c = c
        # The reward distribution of each arm
        self.reward_fn = reward_fn
        # Total number of rounds played so far
        self.its = 0

    def initialize(self):
        """Pull every arm once and restart the agent from those samples.

        The mean estimates become the single sampled rewards, so the pull
        counts are reset to one and the round counter to zero. This keeps the
        method safe to call again (e.g. when re-running a notebook cell).
        """
        # dtype=float: integer rewards would otherwise produce an integer
        # array that silently truncates every later mean update.
        self.arm_avg = np.array([draw() for draw in self.reward_fn], dtype=float)
        self.num_arm = np.ones(len(self.arm_avg))
        self.its = 0

    def best_arm(self):
        """Return the index of the arm with the highest upper confidence bound.

        Returns:
            Integer index (as returned by ``np.argmax``) of the selected arm.
        """
        # Clamp to 1 so that a fresh agent gets log(1) = 0 (no bonus) instead
        # of log(0) = -inf, which would turn every score into NaN.
        t = max(self.its, 1)
        bonus = self.c * np.sqrt(np.log(t) / self.num_arm)
        return np.argmax(self.arm_avg + bonus)

    def update(self, its):
        """Play ``its`` rounds of the UCB algorithm.

        Args:
            its: Number of rounds to play. Non-positive values play no rounds.
        """
        for _ in range(its):
            # Advance the round counter one step at a time so the bonus uses
            # the current round number, not the total of the whole batch.
            self.its += 1
            arm = self.best_arm()
            # Pulling the lever and getting the reward from the reward distribution
            reward = self.reward_fn[arm]()
            # Increment the number of times the arm was chosen
            self.num_arm[arm] += 1
            # Incremental form of (old_mean * old_count + reward) / new_count
            self.arm_avg[arm] += (reward - self.arm_avg[arm]) / self.num_arm[arm]

    def get_arm_avg(self):
        """Return the 1-D array of current mean reward estimates (not a copy)."""
        return self.arm_avg
                                                 
        

In [20]:
def _shifted_normal(offset):
    """Return a zero-argument sampler drawing from a unit normal plus `offset`."""
    return lambda: np.random.randn() + offset


# One reward sampler per arm; arm means are 0, 2, 3, 4 and 5 respectively.
fns = [_shifted_normal(mean) for mean in (0, 2, 3, 4, 5)]

In [31]:
# Five-armed bandit agent with exploration coefficient 2.
ucb_mab = UCB(n_arm=5, reward_fn=fns, c=2)

In [32]:
# Seed each arm's mean estimate with one sampled reward, then play 1000 UCB rounds.
ucb_mab.initialize()
ucb_mab.update(1000)

In [33]:
# Index of the arm that currently maximizes mean estimate plus exploration bonus.
ucb_mab.best_arm()

4

In [34]:
# Per-arm mean reward estimates held by the agent after the run.
ucb_mab.get_arm_avg()

array([0.39739061, 1.70003513, 2.64822582, 3.48043376, 4.9784943 ])