In [10]:
# coding: utf-8
# NOTE(review): the next two lines drop into the interactive debugger as soon as
# this file is executed, before anything else runs. This looks like a leftover
# debugging aid -- confirm and remove if it is not needed.
import pdb
pdb.set_trace()

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from tqdm import trange

# Select the Agg backend; the figures below are written to disk with
# plt.savefig rather than shown on screen.
# NOTE(review): the backend is selected after matplotlib.pyplot has already
# been imported -- confirm this ordering is intended for the matplotlib
# version in use.
matplotlib.use('Agg')


class Bandit:
    """A k-armed bandit paired with an epsilon-greedy action-value learner.

    Constructor arguments:
        k_arm: number of arms.
        epsilon: probability of picking a uniformly random arm in ``act``.
        initial: starting value of every action-value estimate.
        step_size: constant step size used when ``sample_averages`` is False.
        sample_averages: if True, estimates are updated with incremental
            sample averages instead of the constant step size.
        true_reward: offset added to every arm's true value in ``reset``.

    ``reset`` must be called before ``act`` or ``step``: the true values,
    the estimates and the per-arm counters are only created there.
    """

    def __init__(self, k_arm=10, epsilon=0., initial=0., step_size=0.1, sample_averages=False,
                 true_reward=0.):
        # learner settings
        self.epsilon = epsilon
        self.initial = initial
        self.step_size = step_size
        self.sample_averages = sample_averages
        # problem settings
        self.k = k_arm
        self.true_reward = true_reward
        self.indices = np.arange(self.k)
        # number of steps taken since the last reset
        self.time = 0

    def reset(self):
        """Draw fresh true values and clear everything that was learned."""
        # true value of each arm: standard normal sample shifted by true_reward
        self.q_true = self.true_reward + np.random.randn(self.k)
        # index of the arm with the highest true value
        self.best_action = np.argmax(self.q_true)

        # current estimate for each arm, all starting at `initial`
        self.q_estimation = np.full(self.k, self.initial, dtype=float)
        # how many times each arm has been pulled
        self.action_count = np.zeros(self.k)

        self.time = 0

    def act(self):
        """Return the index of the arm to pull (epsilon-greedy)."""
        explore = np.random.rand() < self.epsilon
        if explore:
            return np.random.choice(self.indices)

        # Greedy choice. When two or more arms share the highest estimate,
        # one of them is picked at random.
        greedy_arms = np.flatnonzero(self.q_estimation == self.q_estimation.max())
        return np.random.choice(greedy_arms)

    def step(self, action, nonstationary):
        """Pull ``action``, update its estimate and return the reward.

        The reward is the arm's true value plus one draw from a normal
        distribution with mean 0 and standard deviation ``nonstationary``.
        The true values themselves are not modified here.
        """
        noise = np.random.normal(0, nonstationary)
        reward = self.q_true[action] + noise

        self.time += 1
        self.action_count[action] += 1

        error = reward - self.q_estimation[action]
        if self.sample_averages:
            # incremental sample average: step of 1 / (pull count of this arm)
            self.q_estimation[action] += error / self.action_count[action]
        else:
            # exponential recency-weighted average with a constant step size
            self.q_estimation[action] += self.step_size * error
        return reward


def simulate(runs, time, bandits, nonstationary):
    """Run every bandit for `runs` independent runs of `time` steps each.

    Each run starts with `bandit.reset()`; on every step the bandit picks an
    action with `act()` and receives `step(action, nonstationary)`.

    Returns a pair of arrays, both of shape (len(bandits), time) and both
    averaged over the runs:
      * the fraction of runs in which the chosen action was `best_action`,
      * the mean reward.
    """
    rewards = np.zeros((len(bandits), runs, time))
    optimal_hits = np.zeros_like(rewards)

    for b_idx, agent in enumerate(bandits):
        for run in trange(runs):
            agent.reset()
            for t in range(time):
                chosen = agent.act()
                rewards[b_idx, run, t] = agent.step(chosen, nonstationary)
                # stored as 1.0 when the optimal arm was chosen, 0.0 otherwise
                optimal_hits[b_idx, run, t] = chosen == agent.best_action

    # average over the runs axis
    return optimal_hits.mean(axis=1), rewards.mean(axis=1)
  

def figure_2_1(runs=2000, time=10000):
    """Compare sample-average and constant-step-size updates under noisy rewards.

    Two epsilon-greedy bandits (epsilon=0.1) are simulated with the same
    reward-noise level: one updates its estimates with sample averages, the
    other with the constant step size. Two figures are written:

      * ../images/figure_2_3.png -- average reward per step
      * ../images/figure_2_4.png -- fraction of runs choosing the best action

    runs: number of independent runs to average over.
    time: number of steps per run.
    """
    # standard deviation of the reward noise passed to Bandit.step
    nonstationary = 3
    bandits = [
        Bandit(epsilon=0.1, sample_averages=True),
        Bandit(epsilon=0.1, sample_averages=False),
    ]
    best_action_counts, average_rewards = simulate(runs, time, bandits, nonstationary)

    # Build the title from the value actually simulated so the plots cannot
    # be labelled with a different noise level than the one that was used.
    title = 'nonstationary={}'.format(nonstationary)
    curve_labels = ('sample_average', 'exponential recency-weighted average')

    for curve, label in zip(average_rewards, curve_labels):
        plt.plot(curve, label=label)
    plt.xlabel('Steps')
    plt.ylabel('Average reward')
    plt.legend()
    plt.title(title)

    plt.savefig('../images/figure_2_3.png')
    plt.close()

    for curve, label in zip(best_action_counts, curve_labels):
        plt.plot(curve, label=label)
    plt.xlabel('Steps')
    plt.ylabel('% optimal action')
    plt.legend()
    plt.title(title)

    plt.savefig('../images/figure_2_4.png')
    plt.close()
    
    
def figure_2_2(runs=2000, time=10000):
    """Show how the reward-noise level affects a sample-average bandit.

    One epsilon-greedy bandit (epsilon=0.1, sample averages) is simulated
    once for each noise standard deviation in `deviation`. Two figures are
    written:

      * ../images/figure_2_5.png -- 2x2 grid of average-reward curves,
        one panel per noise level
      * ../images/figure_2_6.png -- fraction of runs choosing the best
        action, all noise levels on one axes

    runs: number of independent runs to average over.
    time: number of steps per run.
    """
    deviation = [0, 0.5, 1, 1.5]
    # the same bandit object is reused; simulate() resets it for every run
    bandits = [Bandit(epsilon=0.1, sample_averages=True)]

    best_action_counts = []
    average_rewards = []
    for factor in deviation:
        counts, rewards = simulate(runs, time, bandits, factor)
        # only one bandit is simulated, so keep its single row
        best_action_counts.append(counts[0])
        average_rewards.append(rewards[0])

    # (index into `deviation`, line colour) for subplot positions 1..4;
    # the largest noise level is drawn first.
    panels = [(3, 'b'), (2, 'r'), (1, 'g'), (0, 'y')]
    for position, (idx, color) in enumerate(panels, start=1):
        plt.subplot(2, 2, position)
        plt.plot(average_rewards[idx],
                 label='nonstationary={}'.format(deviation[idx]),
                 color=color)
        if position > 2:
            # bottom row carries the x-axis label
            plt.xlabel('Steps')
        if position % 2 == 1:
            # left column carries the y-axis label
            plt.ylabel('Average reward')
        plt.legend()
    plt.tight_layout()

    # explicit '.png' extension so the file is written under exactly this name
    plt.savefig('../images/figure_2_5.png')
    plt.close()

    for factor, counts in zip(deviation, best_action_counts):
        plt.plot(counts, label='nonstationary={}'.format(factor))
    plt.xlabel('Steps')
    plt.ylabel('% optimal action')
    plt.legend()

    plt.savefig('../images/figure_2_6.png')
    plt.close()
    
    
if __name__ == '__main__':
    # Regenerate every figure, in order, when the file is run as a script.
    for make_figure in (figure_2_1, figure_2_2):
        make_figure()

--Return--
> <ipython-input-10-132cfa94a693>(12)<module>()->None
-> pdb.set_trace()
(Pdb) c


100%|██████████| 2000/2000 [05:16<00:00,  6.34it/s]
100%|██████████| 2000/2000 [05:08<00:00,  6.56it/s]
100%|██████████| 2000/2000 [05:13<00:00,  5.77it/s]
100%|██████████| 2000/2000 [05:18<00:00,  6.42it/s]
