# k-armed Bandits

## Imports

In [1]:
import random
import numpy as np

## Simple Bandits Implementation

In [2]:
# for a full implementation of bandits in gym look for the github repo by JKCooper2
# https://github.com/JKCooper2/gym-bandits
# this implementation is a simplified version

class Bandit():
    """Simplified k-armed bandit whose arms pay a fixed reward or nothing.

    Pulling arm ``a`` yields ``rewards[a]`` when a uniform draw from
    ``random.random()`` falls below ``probs[a]``; otherwise it yields 0.

    Attributes:
        probs: per-arm payout probabilities.
        rewards: per-arm payout amounts.
        action_space: number of arms, taken from ``len(probs)``.
        observation_space: fixed at 1.
    """

    def __init__(self, probs=None, rewards=None):
        """Store the arm configuration.

        Args:
            probs: sequence of payout probabilities, one per arm.
                Defaults to a fresh empty list.
            rewards: sequence of payout amounts, indexed like ``probs``.
                Defaults to a fresh empty list.
        """
        # None sentinels instead of ``[]`` defaults: a mutable default is
        # created once and would be shared by every no-argument instance.
        self.probs = [] if probs is None else probs
        self.rewards = [] if rewards is None else rewards

        # k as in k-armed bandits, number of arms
        self.action_space = len(self.probs)
        self.observation_space = 1

    def step(self, action):
        """Pull arm ``action`` once and return the sampled reward.

        Args:
            action: index of the arm to pull.

        Returns:
            ``rewards[action]`` if the arm pays out on this pull, else 0.
        """
        if random.random() < self.probs[action]:
            return self.rewards[action]
        return 0


In [3]:
# Two arms: arm 0 pays 10 with probability 0.5, arm 1 pays 1 on every pull.
bandit = Bandit(
    probs=[0.5, 1],
    rewards=[10, 1],
)

---
Take action 0.

In [4]:
# Sample arm 0 ten times; each pull yields either 10 or 0.
for pull in range(10):
    reward = bandit.step(0)
    print(f'Reward: {reward}')

Reward: 10
Reward: 0
Reward: 0
Reward: 0
Reward: 0
Reward: 0
Reward: 0
Reward: 0
Reward: 10
Reward: 0


---
Take action 1.

In [5]:
# Sample arm 1 ten times; its payout probability is 1, so every pull yields 1.
for pull in range(10):
    reward = bandit.step(1)
    print(f'Reward: {reward}')

Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1
Reward: 1


In [6]:
# Three arms with expected payouts 0.01*1000 = 10, 0.5*10 = 5 and 1*1 = 1.
bandit = Bandit(
    probs=[0.01, 0.5, 1],
    rewards=[1000, 10, 1],
)
# Action set: one integer index per arm.
A = list(range(bandit.action_space))

In [9]:
def bandit_algorithm(bandit, A, num_episodes=1000000, alpha=0.00001, epsilon=0.5, log_every=100000):
    """Estimate per-action values with an epsilon-greedy, constant step-size update.

    On every episode one action is chosen (uniformly at random with
    probability ``epsilon``, otherwise the current argmax of ``Q``), the
    bandit is stepped once, and the chosen action's estimate is moved
    towards the observed reward by a fraction ``alpha``.

    Args:
        bandit: object exposing ``step(action)`` that returns a numeric reward.
        A: collection of actions; only its length is used, and actions are
            addressed by integer index ``0 .. len(A) - 1``.
        num_episodes: number of pulls to perform.
        alpha: step size of the incremental update.
        epsilon: probability of picking a random action instead of the
            greedy one.
        log_every: print ``Q`` whenever ``episode % log_every == 0``.
            Pass ``None`` or ``0`` to disable printing.

    Returns:
        ``numpy.ndarray`` of length ``len(A)`` holding the value estimates.
    """
    num_actions = len(A)
    Q = np.zeros(num_actions)

    for episode in range(num_episodes):
        # Explore with probability epsilon, otherwise exploit the best
        # estimate so far (argmax breaks ties towards the lowest index).
        if np.random.rand() < epsilon:
            action = np.random.choice(num_actions)
        else:
            action = Q.argmax()

        reward = bandit.step(action)

        # Move the estimate a fraction alpha towards the observed reward.
        Q[action] += alpha * (reward - Q[action])

        # Falsy log_every (None / 0) disables the progress print and also
        # avoids a modulo-by-zero.
        if log_every and episode % log_every == 0:
            print(Q)

    return Q

In [10]:
# Run with the default hyperparameters; the returned estimates are the cell output.
bandit_algorithm(bandit=bandit, A=A)

[0.e+00 0.e+00 1.e-05]
[1.62649711 2.41596374 0.15391956]
[3.00282233 3.66976933 0.28346033]
[4.12498715 4.32330869 0.3946706 ]
[6.62687958 4.48600634 0.48777093]
[8.18477955 4.56768371 0.56618373]
[9.17256206 4.632986   0.63263707]
[9.60565039 4.69325087 0.68863388]
[9.66814323 4.73439188 0.73619805]
[9.99095331 4.78215745 0.77715916]


array([9.93379602, 4.80734032, 0.81119081])