# In [10]:  (notebook cell marker left over from the export)
#######################################################################
# Copyright (C)                                                       #
# 2016-2018 Shangtong Zhang(zhangshangtong.cpp@gmail.com)             #
# 2016 Kenta Shimada(hyperkentakun@gmail.com)                         #
# Permission given to modify the code as long as you keep this        #
# declaration at the top                                              #
#######################################################################

import os

import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from tqdm import tqdm

# State layout of the random walk:
#   index 0      -> left terminal state
#   indices 1..5 -> non-terminal states A..E
#   index 6      -> right terminal state
#
# Every reward is treated as 0. Instead, the right terminal state carries a
# fixed value of 1 and the left terminal state a fixed value of 0 (the same
# trick is used in the Gambler's Problem). Non-terminal states start at 0.5.
VALUES = np.full(7, 0.5)
VALUES[0] = 0
VALUES[6] = 1

# True state values: k / 6 for state k, which also yields 0 and 1 for the
# two terminal states.
TRUE_VALUE = np.arange(7) / 6.0

# Outcome codes of the coin flip that picks the direction of each move.
ACTION_LEFT, ACTION_RIGHT = 0, 1

def temporal_difference(values, gamma=0.1, batch=False):
    """Simulate one random-walk episode, applying TD(0) updates unless batched.

    The walk starts in state 3 and moves one state to the left or right with
    equal probability until it reaches state 0 or state 6.

    Args:
        values: indexable state-value estimates. Updated in place after every
            transition when ``batch`` is False.
        gamma: step size that scales the TD error (it is not used as a
            discount factor here).
        batch: when True, ``values`` is left untouched and only the episode
            is returned.

    Returns:
        A ``(trajectory, rewards)`` pair. ``trajectory`` lists every visited
        state, including the start and the terminal state. ``rewards`` holds
        one zero per transition.
    """
    current = 3
    trajectory = [current]
    rewards = [0]
    while True:
        previous = current
        # A fair coin flip decides the direction of the move.
        moved_left = np.random.binomial(1, 0.5) == ACTION_LEFT
        current = previous - 1 if moved_left else previous + 1
        # All rewards are assumed to be 0.
        reward = 0
        trajectory.append(current)
        if not batch:
            # TD(0): move the estimate of the state just left towards the
            # one-step target.
            td_error = reward + values[current] - values[previous]
            values[previous] += gamma * td_error
        if current in (0, 6):
            break
        rewards.append(reward)
    return trajectory, rewards

def monte_carlo(values, gamma=0.1, batch=False):
    """Simulate one random-walk episode, applying constant-step MC updates unless batched.

    The walk starts in state 3 and moves one state to the left or right with
    equal probability until it reaches state 0 or state 6. The return of the
    episode is 1.0 when it ends in state 6 and 0.0 when it ends in state 0.

    Args:
        values: indexable state-value estimates. Updated in place once the
            episode has finished when ``batch`` is False.
        gamma: step size that scales the difference between the return and
            the current estimate (it is not used as a discount factor here).
        batch: when True, ``values`` is left untouched and only the episode
            is returned.

    Returns:
        A ``(trajectory, returns)`` pair. ``trajectory`` lists every visited
        state, including the start and the terminal state. ``returns``
        repeats the episode return once per non-terminal entry of the
        trajectory.
    """
    position = 3
    trajectory = [position]

    # Keep flipping a fair coin until one of the terminal states is reached.
    while position not in (0, 6):
        position += -1 if np.random.binomial(1, 0.5) == ACTION_LEFT else 1
        trajectory.append(position)

    returns = 1.0 if position == 6 else 0.0

    if not batch:
        # MC update for every state on the path except the terminal one.
        for visited in trajectory[:-1]:
            values[visited] += gamma * (returns - values[visited])
    return trajectory, [returns] * (len(trajectory) - 1)

def compute_state_value():
    """Plot value estimates after 0, 1, 10 and 100 episodes (Example 6.2, left).

    Opens a new 10x5 figure with two side-by-side panels, one learned with
    ``temporal_difference`` and one with ``monte_carlo``, both at their
    default step size. Each panel starts from a fresh copy of ``VALUES`` and
    also shows ``TRUE_VALUE`` for comparison. The figure is laid out with
    ``tight_layout`` and passed to ``plt.show()``.
    """
    checkpoints = [0, 1, 10, 100]
    panels = (
        (temporal_difference, 'TD(0)'),
        (monte_carlo, 'Gamma-Constant Monte Carlo'),
    )

    plt.figure(figsize=(10, 5))
    for column, (run_episode, title) in enumerate(panels, start=1):
        estimates = np.copy(VALUES)
        plt.subplot(1, 2, column)
        for episode in range(checkpoints[-1] + 1):
            # Plot the estimates *before* running the episode, so that the
            # curve labelled "k Episodes" reflects exactly k completed ones.
            if episode in checkpoints:
                plt.plot(estimates, label=str(episode) + ' Episodes')
            run_episode(estimates)
        plt.plot(TRUE_VALUE, label='True Values')
        plt.xlabel('State')
        plt.ylabel('Estimated Value')
        plt.title(title)
        plt.legend()

    plt.tight_layout()
    plt.show()

def rms_error():
    """Plot the run-averaged RMS error of TD(0) and MC per episode (Example 6.2, right).

    For every configured step size, the learner is restarted ``runs`` times
    from a fresh copy of ``VALUES``. Before each episode the RMS difference
    between ``TRUE_VALUE`` and the current estimates is recorded, and the
    per-episode errors are averaged over all runs. One curve per
    configuration is drawn on the current matplotlib axes; the caller is
    responsible for creating and saving the figure.
    """
    # The same step size may appear in both lists.
    td_gammas = [0.1]
    mc_gammas = [0.1]
    episodes = 100 + 1
    runs = 100

    # One (episode function, legend label, line style, step size) entry per
    # curve. Dispatching on the function itself avoids comparing label
    # strings inside the innermost loop.
    settings = [(temporal_difference, 'TD(0)', 'solid', step)
                for step in td_gammas]
    settings += [(monte_carlo, 'Gamma-Const MC', 'dashdot', step)
                 for step in mc_gammas]

    for run_episode, method, linestyle, gamma in settings:
        total_errors = np.zeros(episodes)
        for _ in tqdm(range(runs)):
            errors = []
            current_values = np.copy(VALUES)
            # A dedicated loop name: reusing the outer loop's index here
            # would silently overwrite it.
            for _episode in range(episodes):
                # The squared error is summed over all 7 entries but divided
                # by 5, i.e. averaged over the non-terminal states only.
                squared_error = np.sum(np.power(TRUE_VALUE - current_values, 2))
                errors.append(np.sqrt(squared_error / 5.0))
                run_episode(current_values, gamma=gamma)
            total_errors += np.asarray(errors)
        total_errors /= runs
        plt.plot(total_errors, linestyle=linestyle,
                 label=method + ', gamma = %.02f' % gamma)
    plt.xlabel('Episodes')
    plt.ylabel('RMS')
    plt.legend()


def quality_estimate():
    """Render the value-estimate comparison and save it as a PNG.

    Calls ``compute_state_value()``, writes the resulting figure to
    ``./images/quality_estimate.png`` (creating ``./images`` if needed) and
    closes the figure.
    """
    # compute_state_value() opens its own figure, so no figure is created
    # here: one opened beforehand would never be drawn on, saved or closed.
    compute_state_value()

    # savefig does not create missing parent directories.
    os.makedirs('./images', exist_ok=True)
    plt.savefig('./images/quality_estimate.png')
    plt.close()

def rms_plot():
    """Render the RMS-error comparison and save it as a PNG.

    Opens a 10x10 figure with a single subplot, lets ``rms_error()`` draw
    onto it, writes the result to ``./images/rms_plot.png`` (creating
    ``./images`` if needed) and closes the figure.
    """
    plt.figure(figsize=(10, 10))

    plt.subplot(1, 1, 1)
    rms_error()
    plt.tight_layout()

    # savefig does not create missing parent directories.
    os.makedirs('./images', exist_ok=True)
    plt.savefig('./images/rms_plot.png')
    plt.close()

if __name__ == "__main__":
    # Only the RMS-error figure is produced by default; add a call to
    # quality_estimate() here to also render the value-estimate figure.
    rms_plot()

# Captured notebook output:
# 100%|██████████| 100/100 [00:00<00:00, 269.49it/s]
# 100%|██████████| 100/100 [00:00<00:00, 193.39it/s]
