In [1]:
import matplotlib.pyplot as plt
from multi_armed_bandit import MultiArmBandit
import multi_armed_bandit
import numpy as np
from tqdm import tqdm

In [2]:
def simulate(bandit, runs, time):
    """Run ``bandit`` for ``runs`` independent runs of ``time`` steps each.

    While simulating, the per-run/per-step records are stored on the bandit
    as ``bandit.rewards`` and ``bandit.best_action_counts`` (both of shape
    ``(runs, time)``). Once all runs finish, both attributes are replaced by
    their average over the runs (shape ``(time,)``).

    Args:
        bandit: object providing ``reset()``, ``act()``, ``step(action)``,
            a ``best_arm`` attribute and a ``time`` counter.
        runs: number of independent runs to average over.
        time: number of steps per run.

    Returns:
        Tuple ``(rewards, best_action_counts)``: the per-step reward averaged
        over runs, and the per-step fraction of runs in which the chosen
        action equalled ``bandit.best_arm``.
    """
    bandit.best_action_counts = np.zeros((runs, time))
    bandit.rewards = np.zeros((runs, time))

    for run_idx in tqdm(range(runs)):
        bandit.reset()
        for step_idx in range(time):
            # The step counter is advanced here, before the bandit chooses.
            bandit.time += 1
            chosen = bandit.act()
            # Score the choice against best_arm before the environment step.
            hit = 1 if chosen == bandit.best_arm else 0
            bandit.best_action_counts[run_idx, step_idx] = hit
            bandit.rewards[run_idx, step_idx] = bandit.step(chosen)

    # Collapse the run axis: each entry becomes the mean over all runs.
    mean_rewards = bandit.rewards.mean(axis=0)
    bandit.rewards = mean_rewards
    mean_best_action = bandit.best_action_counts.mean(axis=0)
    bandit.best_action_counts = mean_best_action
    return mean_rewards, mean_best_action


In [3]:
# Experiment configuration shared by the simulation and plotting cells.

# Fixed seed so a fresh Restart & Run All reproduces the same curves.
# NOTE(review): this only takes effect if MultiArmBandit draws from NumPy's
# global RNG -- confirm against multi_armed_bandit.py.
SEED = 0
np.random.seed(SEED)

runs = 2000  # independent runs averaged by simulate()
time = 1000  # steps per run
arms = 10    # passed as the first argument to every MultiArmBandit

# Baseline, labelled 'epsilon=0.1 greedy' in the plots.
bandit0 = MultiArmBandit(arms, epsilon=0.1)

# UPPER_BOUND_CONFIDENCE bandits with confidence = 2, 1 and 4 respectively.
bandit1 = MultiArmBandit(arms, method=multi_armed_bandit.UPPER_BOUND_CONFIDENCE, confidence=2)
bandit2 = MultiArmBandit(arms, method=multi_armed_bandit.UPPER_BOUND_CONFIDENCE, confidence=1)
bandit3 = MultiArmBandit(arms, method=multi_armed_bandit.UPPER_BOUND_CONFIDENCE, confidence=4)


In [4]:
# Run every bandit through the same experiment, in order, keeping only the
# averaged reward curve (element 0 of simulate's return value); the
# best-action fraction is not plotted below.
# Unpacking from a generator replaces four copy-pasted calls and avoids
# binding `_`, which IPython uses for the previous cell output.
rewards0, rewards1, rewards2, rewards3 = (
    simulate(bandit, runs, time)[0]
    for bandit in (bandit0, bandit1, bandit2, bandit3)
)


In [5]:
# Average reward vs. steps: one panel per UCB setting, each drawn against the
# same epsilon-greedy baseline (rewards0).
panels = [
    (rewards1, 'UCB with c=2'),
    (rewards2, 'UCB with c=1'),
    (rewards3, 'UCB with c=4'),
]

fig, axes = plt.subplots(len(panels), 1, figsize=(10, 30))
for ax, (ucb_rewards, ucb_label) in zip(axes, panels):
    ax.plot(rewards0, label='epsilon=0.1 greedy')
    ax.plot(ucb_rewards, label=ucb_label)
    ax.set_xlabel('steps')
    ax.set_ylabel('average reward')
    # A title lets each panel stand alone when the figure is skimmed.
    ax.set_title(f'epsilon=0.1 greedy vs {ucb_label}')
    ax.legend()

fig.savefig('./q6.png')
