<a href="https://colab.research.google.com/github/DaehanKim/reinforcement_learning_pytorch/blob/master/epsilon_greedy.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Epsilon Greedy Algorithm

This notebook solves MAB(Multi-armed bandit) problem with epsilon greedy algorithm.

In [0]:
import torch
import numpy as np
from collections import Counter

def pull_bandit(action_tensor):
    reward_mask = np.random.randn(action_tensor.size(0)) > bandit[action_tensor]
    reward = torch.empty(action_tensor.size(0), dtype=torch.int32)
    reward[reward_mask] = 1
    reward[~reward_mask] = 0
    return reward


In [23]:
# config
NUM_BANDIT = 8
T = 10000
LR = 1e-3
eps = .1
n_batch = 100
LOG_INT = 1000

# define bandit
bandit = np.random.randn(NUM_BANDIT)

# learnable parameter
empherical_prob = torch.FloatTensor(NUM_BANDIT).uniform_(0,1)


# sgd = torch.optim.SGD([empherical_prob],lr=LR)
trying_cnt = torch.zeros(NUM_BANDIT, dtype=torch.int32)
# running_rewards = torch.zeros(NUM_BANDIT)

for _iter in range(T):
    # building choices
    
    choice = torch.empty(n_batch, dtype=torch.int32)
    use_emp = torch.FloatTensor(n_batch).uniform_(0,1) > eps
    choice[use_emp] = torch.max(empherical_prob,0).indices
    choice[~use_emp] = torch.randint(NUM_BANDIT, ((~use_emp).sum(),))
    reward = pull_bandit(choice)
    
    for k,v in Counter(choice.data.tolist()).items():
        running_rewards[k] += reward[k]*v
        empherical_prob[k] = (empherical_prob[k]*trying_cnt[k] + reward[k]*v) / (trying_cnt[k] + v)
        trying_cnt[k] += v

    if (_iter+1) % LOG_INT == 0: print('[Iter {}] Running_reward = {} trying_cnt ={}'.format(_iter+1, np.array(running_rewards.data), np.array(trying_cnt.data)))


print("Estimated Optimal Strategy : {}".format(empherical_prob.max(0).indices))
print("Answer Strategy : {}".format(bandit.argmin()))


[Iter 1000] Running_reward = [ 479327.  120083.  390726.  357010.  119679. 1024317.  122613.  119086.] trying_cnt =[ 1326  1283  1349  1359  1410 90309  1267  1697]
[Iter 2000] Running_reward = [ 480329.  121019.  391849.  357997.  120672. 1096804.  123597.  120031.] trying_cnt =[  2598   2472   2699   2571   2681 181544   2531   2904]
[Iter 3000] Running_reward = [ 481313.  122028.  392828.  358956.  121623. 1168557.  124622.  121012.] trying_cnt =[  3825   3763   3915   3780   3907 272792   3835   4183]
[Iter 4000] Running_reward = [ 482262.  123070.  393812.  359932.  122592. 1241268.  125624.  122006.] trying_cnt =[  5035   5034   5129   5028   5170 364099   5086   5419]
[Iter 5000] Running_reward = [ 483299.  124068.  395146.  360850.  123648. 1311305.  126562.  122986.] trying_cnt =[  6318   6327   6910   6228   6477 454763   6311   6666]
[Iter 6000] Running_reward = [ 484304.  125078.  396130.  361879.  124596. 1383849.  127630.  123949.] trying_cnt =[  7551   7596   8163   7510