In [1]:
import numpy as np
import matplotlib.pyplot as plt

In [80]:
# Agent class for multi-armed bandit problem (bernoulli bandit)

class Bandit:

    def __init__(self, mean, var, id = None):
        self.mean = mean
        self.var = var
        self.id = None
    
    def get_internal_state(self):
        return self.mean, self.var

    def get_id(self):
        return self.id
    
class MultiArmBandits:

    def __init__(self, num_bandits, dist_mean = None, dist_var = None):
        self.num_bandits = num_bandits
        self.bandits = []
        self.dist_mean = dist_mean
        self.dist_var = dist_var
        self.sampler = np.random.default_rng()
        self.init_bandits()

    def __len__(self):
        return self.num_bandits
    
    def init_bandits(self):
        if self.dist_mean is None:
            self.dist_mean = np.random.rand(self.num_bandits)
            self.dist_var = np.random.rand(self.num_bandits)
        for i in range(self.num_bandits):
            self.bandits.append(Bandit(self.dist_mean[i],self.dist_var[i], i ))
    
    def playBandit(self, bandit_id):
        bandit = self.bandits[bandit_id]
        return self.sampler.normal(bandit.get_internal_state()[0],bandit.get_internal_state()[1])
    
class EpsGreedySolver:

    def __init__(self, bandits, epsilon = 0.1):
        self.bandits = bandits
        self.epsilon = epsilon
        self.num_bandits = len(bandits)
        self.means = np.zeros(self.num_bandits)
        self.vars = np.zeros(self.num_bandits)
        self.n = np.zeros(self.num_bandits)
        self.rewards = []
    
    def update(self, bandit_id, reward):
        self.n[bandit_id] += 1
        self.means[bandit_id] += 1/self.n[bandit_id]*(reward - self.means[bandit_id])
        if self.n[bandit_id] > 1:
            self.vars[bandit_id] = ((self.n[bandit_id] - 2)/(self.n[bandit_id] - 1))*self.vars[bandit_id] + (1/self.n[bandit_id])*(reward - self.means[bandit_id])**2
    
    def choose_bandit(self):
        if np.random.rand() < self.epsilon:
            return np.random.randint(self.num_bandits)
        else:
            return np.argmax(self.means)
    
    def learn(self, num_steps = 1000):
        for i in range(num_steps):
            bandit_id = self.choose_bandit()
            reward = self.bandits.playBandit(bandit_id)
            self.rewards.append(reward)
            self.update(bandit_id, reward)

class UCBSolver:

    def __init__(self, bandits, c = 2):
        self.bandits = bandits
        self.c = c
        self.num_bandits = len(bandits)
        self.means = np.zeros(self.num_bandits)
        self.vars = np.zeros(self.num_bandits)
        self.n = np.zeros(self.num_bandits)
        self.rewards = []
    
    def update(self, bandit_id, reward):
        self.n[bandit_id] += 1
        self.means[bandit_id] += 1/self.n[bandit_id]*(reward - self.means[bandit_id])
        if self.n[bandit_id] > 1:
            self.vars[bandit_id] = ((self.n[bandit_id] - 2)/(self.n[bandit_id] - 1))*self.vars[bandit_id] + (1/self.n[bandit_id])*(reward - self.means[bandit_id])**2
    
    def choose_bandit(self):
        return np.argmax(self.means + self.c*np.sqrt(np.log(np.sum(self.n))/self.n))
    
    def learn(self, num_steps = 1000):
        for i in range(num_steps):
            bandit_id = self.choose_bandit()
            reward = self.bandits.playBandit(bandit_id)
            self.rewards.append(reward)
            self.update(bandit_id, reward)

In [81]:
bandit = MultiArmBandits(10)
solver = EpsGreedySolver(bandit, 0.4)
ucbsolver = UCBSolver(bandit, 0.01)

In [84]:
solver.learn(1000000)

In [89]:
ucbsolver.c = 10

In [91]:
ucbsolver.learn(1000000)

In [85]:
solver.means, solver.vars

(array([0.32320956, 0.80414552, 0.94916589, 0.63893062, 0.39152017,
        0.09533283, 0.37054841, 0.35218853, 0.80888821, 0.05004385]),
 array([3.01266950e-01, 3.69713100e-01, 7.58962204e-02, 4.90765746e-01,
        1.10889008e-04, 3.34305506e-01, 6.33028401e-01, 3.31227720e-02,
        3.29819497e-01, 4.37021155e-01]))

In [92]:
ucbsolver.means, ucbsolver.vars

(array([0.32466313, 0.79673052, 0.94935308, 0.64454419, 0.39147963,
        0.08919555, 0.38311382, 0.34944024, 0.8138008 , 0.04861066]),
 array([2.96455658e-01, 3.71593674e-01, 7.58984052e-02, 4.95965293e-01,
        1.12064743e-04, 3.21949212e-01, 6.42066803e-01, 3.26319676e-02,
        3.29620488e-01, 4.33593819e-01]))

In [76]:
UCBSolver.bandits.bandits[0].var

0.8994300254860429

In [55]:
solver.bandits.bandits[0].get_internal_state()

(0.9569209378738996, 0.08662555733430755)