In [6]:
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

from multi_armed_bandit import MultiArmBandit


In [7]:
def simulate(bandit, runs, time, drift_step=0.01, drift_target=0.5):
    """Play ``bandit`` for ``runs`` independent runs of ``time`` steps and average per step.

    Each run calls ``bandit.reset()`` and then, for every step:
      1. advances ``bandit.time`` by one,
      2. picks an action with ``bandit.act()`` and records whether it equals
         ``bandit.best_arm``,
      3. collects the reward returned by ``bandit.step(action)``,
      4. nudges ``bandit.q_true`` by ``drift_step`` towards ``drift_target``
         (up when strictly below it, down otherwise), making the problem
         non-stationary.

    Args:
        bandit: object exposing ``reset()``, ``act()``, ``step(action)`` and the
            attributes ``time``, ``best_arm`` and ``q_true``. ``q_true`` may be a
            scalar or a NumPy array; the drift is applied element-wise.
        runs: number of independent runs to average over.
        time: number of steps per run.
        drift_step: magnitude of the per-step drift applied to ``q_true``
            (default 0.01, the value that used to be hard-coded).
        drift_target: threshold that ``q_true`` drifts towards (default 0.5).

    Returns:
        ``(rewards, best_action_counts)``: two 1-D arrays of length ``time``.
        ``rewards[t]`` is the mean reward at step ``t`` over all runs, and
        ``best_action_counts[t]`` is the fraction of runs (in [0, 1]) whose
        action at step ``t`` equalled ``bandit.best_arm``. The same arrays are
        also stored on ``bandit.rewards`` / ``bandit.best_action_counts``.
    """
    bandit.best_action_counts = np.zeros((runs, time))
    bandit.rewards = np.zeros((runs, time))
    for r in tqdm(range(runs)):
        bandit.reset()
        for t in range(time):
            bandit.time += 1
            action = bandit.act()
            # NOTE(review): best_arm is read every step; if MultiArmBandit only
            # computes it in reset(), it will not follow the drift of q_true — confirm.
            if action == bandit.best_arm:
                bandit.best_action_counts[r, t] = 1
            bandit.rewards[r, t] = bandit.step(action)
            # np.where instead of `x if cond else y`: the ternary raises
            # "truth value of an array ... is ambiguous" when q_true is an array.
            bandit.q_true += np.where(bandit.q_true < drift_target, drift_step, -drift_step)
    # Collapse the run axis: per-step averages over all runs.
    bandit.rewards = bandit.rewards.mean(axis=0)
    bandit.best_action_counts = bandit.best_action_counts.mean(axis=0)
    return bandit.rewards, bandit.best_action_counts


In [8]:
# Experiment configuration: `runs` independent runs of `time` steps each on an
# `arms`-armed bandit. 2000 x 10000 = 20M Python-level steps; the recorded
# progress bar estimated roughly 5-9 minutes for a full run.
# NOTE(review): `time` shadows the stdlib module name if it is ever imported.
runs = 2000
time = 10000
arms = 10

# Seed NumPy's global RNG before building the bandit so a Restart-&-Run-All
# reproduces the same curves.
# NOTE(review): only effective if MultiArmBandit draws from np.random — confirm.
SEED = 0
np.random.seed(SEED)

bandit0 = MultiArmBandit(arms, epsilon=0.1)

In [9]:
# Long-running: averages `runs` runs of `time` steps; yields per-step mean reward and fraction of optimal actions.
rewards0, best_action_counts0 = simulate(bandit=bandit0, runs=runs, time=time)

  0%|          | 0/2000 [00:00<?, ?it/s]

  0%|          | 1/2000 [00:00<08:58,  3.71it/s]

  0%|          | 2/2000 [00:00<08:55,  3.73it/s]

  0%|          | 3/2000 [00:00<08:58,  3.71it/s]

  0%|          | 4/2000 [00:01<08:41,  3.82it/s]

  0%|          | 5/2000 [00:01<07:56,  4.19it/s]

  0%|          | 6/2000 [00:01<07:29,  4.44it/s]

  0%|          | 7/2000 [00:01<07:01,  4.73it/s]

  0%|          | 8/2000 [00:01<06:39,  4.99it/s]

  0%|          | 9/2000 [00:01<06:23,  5.19it/s]

  0%|          | 10/2000 [00:01<06:12,  5.34it/s]

  1%|          | 11/2000 [00:02<06:03,  5.47it/s]

  1%|          | 12/2000 [00:02<05:53,  5.63it/s]

  1%|          | 13/2000 [00:02<05:46,  5.74it/s]

  1%|          | 14/2000 [00:02<05:40,  5.83it/s]

  1%|          | 15/2000 [00:02<05:33,  5.95it/s]

  1%|          | 16/2000 [00:02<05:30,  6.01it/s]

  1%|          | 17/2000 [00:02<05:25,  6.08it/s]

  1%|          | 18/2000 [00:02<05:21,  6.17it/s]

  1%|          | 19/2000 [00:03<05:17,  6.24it/s]

  1%|          | 20/2000 [00:03<05:14,  6.29it/s]

  1%|          | 21/2000 [00:03<05:11,  6.36it/s]

  1%|          | 22/2000 [00:03<05:08,  6.40it/s]




KeyboardInterrupt: 

In [1]:
# Plot the two learning curves produced by `simulate` and save them to ./q5.png.
# Requires the imports cell (plt, np) and the simulation cell to have run in this
# kernel; running it in a fresh kernel raised `NameError: name 'plt' is not defined`.
# Only two panels are needed, so use a 1x2 grid rather than a 3x2 grid that left
# four empty axes in a 20x30-inch figure.
fig, (ax_reward, ax_optimal) = plt.subplots(1, 2, figsize=(20, 8))

# Average reward vs steps (mean over all runs at each step).
ax_reward.plot(rewards0, label='epsilon = 0.1')
ax_reward.set(title='Average reward', xlabel='steps', ylabel='average reward')
ax_reward.legend()

# Optimal action vs steps. `simulate` returns the fraction of runs (0..1) that
# picked the best arm, so scale by 100 to match the "%" axis label.
ax_optimal.plot(100 * np.asarray(best_action_counts0), label='epsilon = 0.1')
ax_optimal.set(title='Optimal action', xlabel='steps', ylabel='% optimal action')
ax_optimal.legend()

fig.savefig('./q5.png')
# Close to release the figure's memory; the plot is written to disk, not shown inline.
plt.close(fig)


NameError: name 'plt' is not defined