<a href="https://colab.research.google.com/github/DaehanKim/reinforcement_learning_pytorch/blob/master/epsilon_greedy.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Epsilon Greedy Algorithm

This notebook solves MAB(Multi-armed bandit) problem with epsilon greedy algorithm.

In [0]:
import torch
import numpy as np
from collections import Counter

def pull_bandit(action_tensor):
    reward_mask = np.random.randn(action_tensor.size(0)) > bandit[action_tensor]
    reward = torch.empty(action_tensor.size(0), dtype=torch.int32)
    reward[reward_mask] = 1
    reward[~reward_mask] = 0
    return reward


In [64]:
# config
NUM_BANDIT = 8
T = 10000
LR = 1e-3
eps = .1
n_batch = 100
LOG_INT = 1000

# define bandit
bandit = np.random.randn(NUM_BANDIT)

# learnable parameter
empherical_prob = torch.nn.Parameter(torch.FloatTensor(NUM_BANDIT).uniform_(0,1))

sgd = torch.optim.SGD([empherical_prob],lr=LR)
trying_num = torch.zeros(NUM_BANDIT, dtype=torch.int32)
running_rewards = torch.zeros(NUM_BANDIT)

for _iter in range(T):
    # building choices
    loss = .0
    
    choice = torch.empty(n_batch, dtype=torch.int32)
    use_emp = torch.FloatTensor(n_batch).uniform_(0,1) > eps
    choice[use_emp] = torch.max(empherical_prob,0).indices
    choice[~use_emp] = torch.randint(NUM_BANDIT, ((~use_emp).sum(),))
    reward = pull_bandit(choice)
    sgd.zero_grad()
    
    for k,v in Counter(choice.data.tolist()).items():
        trying_num[k] += v
        running_rewards[k] += reward[k]*v
        loss += -torch.log(empherical_prob[k])*reward[k]*v
    
    loss.backward()
    sgd.step()

    if (_iter+1) % LOG_INT == 0: print('[Iter {}] Loss = {:.5f} Running_reward = {} trying_cnt ={}'.format(_iter+1, loss.item(), np.array(running_rewards.data), np.array(trying_num.data)))

print("Estimated Optimal Strategy : {}".format(empherical_prob.max(0).indices))
print("Answer Strategy : {}".format(bandit.argmin()))


[Iter 1000] Loss = -232.10327 Running_reward = [ 1151.  1058. 83248.  1122.  1119.  1141.  1044.  1215.] trying_cnt =[ 1272  1184 91338  1262  1235  1247  1173  1289]
[Iter 2000] Loss = -288.24347 Running_reward = [  2300.   2189. 166030.   2253.   2272.   2316.   2201.   2299.] trying_cnt =[  2552   2416 182568   2521   2486   2532   2444   2481]
[Iter 3000] Loss = -291.41095 Running_reward = [  3447.   3326. 249596.   3366.   3373.   3447.   3231.   3432.] trying_cnt =[  3825   3697 273906   3773   3684   3790   3601   3724]
[Iter 4000] Loss = -307.73984 Running_reward = [  4604.   4479. 332037.   4537.   4481.   4491.   4412.   4619.] trying_cnt =[  5103   4982 365098   5059   4890   4930   4898   5040]
[Iter 5000] Loss = -324.03775 Running_reward = [  5761.   5566. 416357.   5671.   5634.   5548.   5525.   5739.] trying_cnt =[  6385   6180 456398   6312   6165   6115   6145   6300]
[Iter 6000] Loss = -318.77130 Running_reward = [  6857.   6723. 500039.   6858.   6899.   6673.   656