In [2]:
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from tqdm import tqdm

In [1]:
class example():
    """Random-walk experiments comparing TD(0) with constant-alpha Monte Carlo.

    The walk has seven states indexed 0..6. States 0 and 6 are terminal;
    every episode starts in state 3 and moves one step left or right with
    equal probability until it reaches a terminal state.

    Attributes:
        values: initial value estimates (0 for state 0, 0.5 for states 1-5,
            1 for state 6). Experiments copy this array and never modify it.
        true_value: reference values used for error measurement
            (0 for state 0, k/6 for states k = 1..5, 1 for state 6).
        action_left / action_right: outcomes of the Bernoulli draw that are
            interpreted as a step to the left / right.
    """

    def __init__(self):
        # Initial estimates. The terminal entry values[6] = 1 stands in for
        # the reward of exiting on the right, so per-step rewards can be 0.
        self.values = np.zeros(7)
        self.values[1:6] = 0.5
        self.values[6] = 1
        # Reference values the estimates are compared against.
        self.true_value = np.zeros(7)
        self.true_value[1:6] = np.arange(1, 6) / 6.0
        self.true_value[6] = 1
        self.action_left = 0
        self.action_right = 1

    def temporal_difference(self, values, alpha=0.1, batch=False):
        """Run one episode and, unless ``batch`` is set, apply TD(0) updates.

        Args:
            values: array of value estimates indexed by state. Updated in
                place after every step when ``batch`` is False.
            alpha: step size of the TD update.
            batch: when True the episode is only sampled; ``values`` is left
                untouched so the caller can apply batch updates itself.

        Returns:
            ``(moves, rewards)`` where ``moves`` is the list of visited states
            (including the terminal one) and ``rewards`` holds one entry per
            transition, i.e. ``len(rewards) == len(moves) - 1``.
        """
        state = 3
        moves = [state]
        rewards = [0]
        while True:
            old_state = state
            if np.random.binomial(1, 0.5) == self.action_left:
                state = state - 1
            else:
                state = state + 1
            # Every transition has reward 0; see the note in __init__.
            reward = 0
            moves.append(state)
            # TD update on the array supplied by the caller (not self.values,
            # which holds the shared initial estimates).
            if not batch:
                values[old_state] += alpha * (reward + values[state] - values[old_state])
            if state == 6 or state == 0:
                break
            # Appended after the terminal check: together with the initial
            # entry this keeps exactly one reward per transition.
            rewards.append(reward)
        return moves, rewards

    def monte_carlo(self, values, alpha=0.1, batch=False):
        """Run one episode and, unless ``batch`` is set, apply MC updates.

        Args:
            values: array of value estimates indexed by state. Updated in
                place once the episode has finished when ``batch`` is False.
            alpha: step size of the constant-alpha Monte Carlo update.
            batch: when True the episode is only sampled; ``values`` is left
                untouched.

        Returns:
            ``(moves, returns)`` where ``moves`` is the list of visited states
            (including the terminal one) and ``returns`` repeats the episode
            return (1.0 if it ended in state 6, 0.0 if it ended in state 0)
            once per transition.
        """
        state = 3
        moves = [3]
        while True:
            if np.random.binomial(1, 0.5) == self.action_left:
                state = state - 1
            else:
                state = state + 1
            moves.append(state)
            if state == 6:
                returns = 1.0
                break
            elif state == 0:
                returns = 0.0
                break
        if not batch:
            # Every visit of a non-terminal state is updated, so a state seen
            # several times in the episode is moved several times.
            for state_ in moves[:-1]:
                # MC update step on the caller's array.
                values[state_] = values[state_] + alpha * (returns - values[state_])
        return moves, [returns] * (len(moves) - 1)

    # Example 6.2 left
    def compute_state_value(self):
        """Plot the TD estimates after 0, 1, 10 and 100 episodes.

        Works on a copy of ``self.values`` and draws onto pyplot figure 1,
        together with the reference curve ``self.true_value``.
        """
        episodes = [0, 1, 10, 100]
        current_values = np.copy(self.values)
        plt.figure(1)
        for i in range(episodes[-1] + 1):
            # Snapshot is taken before episode i runs, so the label is the
            # number of episodes already learned from.
            if i in episodes:
                plt.plot(current_values, label=str(i) + ' episodes')
            self.temporal_difference(current_values)
        plt.plot(self.true_value, label='true values')
        plt.xlabel('state')
        plt.ylabel('estimated value')
        plt.legend()

    def rms_error(self):
        """Plot averaged RMS error curves for several TD and MC step sizes.

        For every step size the error is recorded before each of 101
        episodes and averaged over 100 independent runs, each starting from
        a fresh copy of ``self.values``.
        """
        td_alphas = [0.15, 0.1, 0.05]
        mc_alphas = [0.01, 0.02, 0.03, 0.04]
        episodes = 100 + 1
        runs = 100
        for i, alpha in enumerate(td_alphas + mc_alphas):
            total_errors = np.zeros(episodes)
            # The first len(td_alphas) step sizes belong to TD, the rest to MC.
            if i < len(td_alphas):
                method = 'TD'
                linestyle = 'solid'
            else:
                method = 'MC'
                linestyle = 'dashdot'
            for _ in tqdm(range(runs)):
                errors = []
                current_values = np.copy(self.values)
                for _ in range(0, episodes):
                    # Divide by 5: only the five non-terminal states can
                    # differ from their reference values.
                    errors.append(np.sqrt(np.sum(np.power(self.true_value - current_values, 2)) / 5.0))
                    if method == 'TD':
                        self.temporal_difference(current_values, alpha=alpha)
                    else:
                        self.monte_carlo(current_values, alpha=alpha)
                total_errors += np.asarray(errors)
            total_errors /= runs
            plt.plot(total_errors, linestyle=linestyle, label=method + ', alpha = %.02f' % (alpha))
        plt.xlabel('episodes')
        plt.ylabel('RMS')
        plt.legend()

    def batch_updating(self, method, episodes, alpha=0.001):
        """Measure the RMS error of batch TD or batch MC training.

        After each new episode, all episodes seen so far are replayed
        repeatedly; the increments are summed over the whole batch and
        applied together until their total magnitude drops below 1e-3.

        Args:
            method: ``'TD'`` for batch TD; any other value selects MC.
            episodes: number of episodes per run.
            alpha: step size applied to the summed increments.

        Returns:
            Array of length ``episodes`` with the RMS error after each
            episode, averaged over 100 runs.
        """
        runs = 100
        total_errors = np.zeros(episodes)
        for _ in tqdm(range(0, runs)):
            current_values = np.copy(self.values)
            errors = []
            trajectories = []
            rewards = []
            for _ in range(episodes):
                # batch=True: only sample the episode, do not update online.
                if method == 'TD':
                    trajectory_, rewards_ = self.temporal_difference(current_values, batch=True)
                else:
                    trajectory_, rewards_ = self.monte_carlo(current_values, batch=True)
                trajectories.append(trajectory_)
                rewards.append(rewards_)
                while True:
                    # Accumulate the increments of every stored transition
                    # before changing any estimate.
                    updates = np.zeros(7)
                    for trajectory_, rewards_ in zip(trajectories, rewards):
                        for i in range(0, len(trajectory_) - 1):
                            if method == 'TD':
                                updates[trajectory_[i]] += rewards_[i] + current_values[trajectory_[i + 1]] - current_values[trajectory_[i]]
                            else:
                                # For MC, rewards_ holds the episode return.
                                updates[trajectory_[i]] += rewards_[i] - current_values[trajectory_[i]]
                    updates *= alpha
                    if np.sum(np.abs(updates)) < 1e-3:
                        break
                    current_values += updates
                # calculate rms error
                errors.append(np.sqrt(np.sum(np.power(current_values - self.true_value, 2)) / 5.0))
            total_errors += np.asarray(errors)
        total_errors /= runs
        return total_errors

    def example_6_2(self):
        """Draw the value-estimate and RMS-error panels and save them.

        The two panels are stacked vertically and written to
        ``example_6_2.png``; the figure is closed afterwards.
        """
        plt.figure(figsize=(10, 20))
        plt.subplot(2, 1, 1)
        self.compute_state_value()

        plt.subplot(2, 1, 2)
        self.rms_error()
        plt.tight_layout()

        plt.savefig('example_6_2.png')
        plt.close()

    def figure_6_2(self):
        """Plot batch TD against batch MC error curves and save them.

        Runs ``batch_updating`` for both methods over 101 episodes, writes
        the plot to ``figure_6_2.png`` and closes the figure.
        """
        episodes = 100 + 1
        td_errors = self.batch_updating('TD', episodes)
        mc_errors = self.batch_updating('MC', episodes)

        plt.plot(td_errors, label='TD')
        plt.plot(mc_errors, label='MC')
        plt.xlabel('episodes')
        plt.ylabel('RMS error')
        plt.legend()

        plt.savefig('figure_6_2.png')
        plt.close()

if __name__ == '__main__':
    # Entry point: build one experiment object and invoke both
    # figure-producing methods on it, in order.
    e = example()
    # example_6_2() ends with plt.savefig('example_6_2.png').
    e.example_6_2()
    # figure_6_2() ends with plt.savefig('figure_6_2.png').
    e.figure_6_2()

100%|██████████| 100/100 [00:00<00:00, 257.76it/s]
100%|██████████| 100/100 [00:00<00:00, 242.19it/s]
100%|██████████| 100/100 [00:00<00:00, 277.72it/s]
100%|██████████| 100/100 [00:00<00:00, 331.97it/s]
100%|██████████| 100/100 [00:00<00:00, 323.45it/s]
100%|██████████| 100/100 [00:00<00:00, 337.55it/s]
100%|██████████| 100/100 [00:00<00:00, 320.30it/s]
100%|██████████| 100/100 [00:45<00:00,  2.01it/s]
100%|██████████| 100/100 [00:43<00:00,  2.49it/s]
