<a href="https://colab.research.google.com/github/DaehanKim/reinforcement_learning_pytorch/blob/master/epsilon_greedy.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Epsilon Greedy Algorithm

This notebook solves MAB(Multi-armed bandit) problem with epsilon greedy algorithm.

In [0]:
import torch
import numpy as np
from collections import Counter


In [6]:
def pull_bandit(action, bandit):
    if np.random.uniform(0,1,(1,)) > bandit[action]:
        return 0
    else: return 1

# config
NUM_BANDIT = 8
T = 1000
eps = .1
LOG_INT = 100

# define bandit
bandit = np.random.uniform(0,1,(NUM_BANDIT,))

# Initialize empherical prob ~ U(0,1)
empherical_prob = np.random.uniform(0,1,(NUM_BANDIT,))

trying_cnt = np.zeros(NUM_BANDIT)
running_rewards = np.zeros(NUM_BANDIT)

for _iter in range(T):
    # building choice 
    if np.random.uniform(0,1,(1,)) > eps:
        choice = empherical_prob.argmax()
    else: 
        choice = np.random.randint(NUM_BANDIT)
    reward = pull_bandit(choice, bandit)
    
    running_rewards[choice] += reward
    empherical_prob[choice] = (empherical_prob[choice]*trying_cnt[choice] + reward) / (trying_cnt[choice] + 1)
    trying_cnt[choice] += 1

    if (_iter+1) % LOG_INT == 0: print('[Iter {}] Running_reward = {} trying_cnt ={}'.format(_iter+1, running_rewards, trying_cnt))


print("Estimated Optimal Strategy : {}".format(empherical_prob.argmax()))
print("Answer Strategy : {}".format(bandit.argmax()))


[Iter 100] Running_reward = [ 1.  5.  8.  0.  2.  0.  0. 49.] trying_cnt =[ 2.  8. 13.  0.  4.  2.  0. 71.]
[Iter 200] Running_reward = [ 1.  5.  8. 58.  3.  0.  3. 68.] trying_cnt =[  2.  10.  14.  59.   5.   3.   5. 102.]
[Iter 300] Running_reward = [  1.   5.  12. 151.   4.   0.   4.  68.] trying_cnt =[  2.  11.  18. 152.   6.   3.   6. 102.]
[Iter 400] Running_reward = [  1.   5.  13. 240.   5.   0.   4.  70.] trying_cnt =[  2.  12.  19. 243.   8.   3.   8. 105.]
[Iter 500] Running_reward = [  2.   6.  14. 326.   6.   0.   6.  74.] trying_cnt =[  3.  14.  20. 330.   9.   3.  12. 109.]
[Iter 600] Running_reward = [  4.   8.  15. 415.   6.   0.   6.  74.] trying_cnt =[  6.  17.  21. 421.  10.   3.  13. 109.]
[Iter 700] Running_reward = [  4.   8.  17. 506.   6.   0.   8.  74.] trying_cnt =[  6.  17.  24. 512.  11.   5.  15. 110.]
[Iter 800] Running_reward = [  4.   8.  17. 598.   6.   0.   8.  75.] trying_cnt =[  6.  18.  26. 604.  12.   5.  18. 111.]
[Iter 900] Running_reward = [  5