
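We import NumPy for numerical computation, Matplotlib for plotting, and tqdm for progress bars. Matplotlib is switched to the non-interactive `Agg` backend, so figures are written to a file rather than displayed inline in the notebook.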

In [26]:
import numpy as np
import matplotlib
matplotlib.use('Agg')  # Use a non-interactive backend for plotting
import matplotlib.pyplot as plt
from tqdm import tqdm  # Progress bar for loops

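We define constants used throughout the simulation: the two action codes and the true state values used to measure error. Under the equiprobable random policy, the true value of state $i$ ($i = 1,\dots,5$, i.e. 'A' to 'E') is $i/6$, the probability of eventually terminating on the right. The terminal entries are set to 0 (left) and 1 (right).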

In [27]:
# Action encoding shared by the environment and the agent.
ACTION_LEFT, ACTION_RIGHT = 0, 1

# True values indexed by state 0..6: state i is worth i/6, which gives
# 0 for the left terminal state, 1/6..5/6 for 'A'..'E', and 1 for the
# right terminal state.
TRUE_VALUES = np.arange(7) / 6.0

The `RandomWalkEnvironment` class simulates the random walk. It has five non-terminal states labeled 'A' to 'E' (states 1 to 5) and two terminal states (states 0 and 6). The agent starts in state 3 ('C') and moves left or right until it reaches a terminal state. Every step returns a reward of 0. The +1 reward for terminating on the right is represented instead by fixing the value of the right terminal state at 1.

In [28]:
class RandomWalkEnvironment:
    """Random walk over states 0..6, where 0 and 6 are terminal.

    Episodes begin in state 3 ('C'). Each step moves one state left or right
    and always yields reward 0. The +1 for leaving on the right is carried by
    the value 1 stored for state 6 in the value function.
    """

    def __init__(self):
        self.left_terminal_state, self.right_terminal_state = 0, 6
        self.start_state = 3  # 'C', the middle non-terminal state
        self.actions = [ACTION_LEFT, ACTION_RIGHT]
        self.reset()

    def reset(self):
        """Put the walker back on the start state."""
        self.current_state = self.start_state

    def step(self, action):
        """Move one state in the direction of `action`.

        Returns:
            (next_state, reward, done): reward is always 0, and done is True
            when next_state is one of the two terminal states.

        Raises:
            ValueError: if `action` is neither ACTION_LEFT nor ACTION_RIGHT.
        """
        if action == ACTION_LEFT:
            offset = -1
        elif action == ACTION_RIGHT:
            offset = 1
        else:
            raise ValueError("Invalid action")
        self.current_state += offset
        done = self.current_state in (self.left_terminal_state, self.right_terminal_state)
        return self.current_state, 0, done

The ValueFunction class stores and updates the estimated values of each state. The initial values are set to 0.5 for non-terminal states, with the left terminal state at 0 and the right terminal state at 1.

In [29]:
class ValueFunction:
    """Table of state-value estimates indexed by state number (0..6).

    Args:
        initial_values: optional array of starting estimates. It is copied, so
            later updates never touch the caller's array. When omitted, the
            five non-terminal states start at 0.5 and the left/right terminal
            states are set to 0 and 1.
    """

    def __init__(self, initial_values=None):
        if initial_values is not None:
            self.values = initial_values.copy()
            return
        self.values = np.full(7, 0.5)
        self.values[[0, 6]] = [0, 1]  # terminal states: left = 0, right = 1

    def update(self, state, delta):
        """Add `delta` to the estimate of `state` in place."""
        self.values[state] += delta

    def get_value(self, state):
        """Return the current estimate for `state`."""
        return self.values[state]

The `Agent` class follows the equiprobable random policy and learns state values with either TD(0) or constant-α Monte Carlo (MC). It returns the visited states and rewards of each episode so they can be reused for batch updating.

In [30]:
class Agent:
    """Random-policy prediction agent that learns state values with TD(0) or
    constant-alpha Monte Carlo (MC).

    Args:
        env: environment exposing ``actions``, ``reset()``, ``current_state``,
            ``step(action) -> (next_state, reward, done)`` and
            ``right_terminal_state`` (as ``RandomWalkEnvironment`` does).
        value_function: object with ``get_value(state)`` and
            ``update(state, delta)``.
        alpha: constant step size used by both update rules.
    """

    def __init__(self, env, value_function, alpha=0.1):
        self.env = env
        self.value_function = value_function
        self.alpha = alpha  # Step size parameter

    def choose_action(self):
        """Pick one of ``env.actions`` uniformly at random."""
        return np.random.choice(self.env.actions)

    def play_episode_td(self, batch=False):
        """Play one episode with a TD(0) update after every step.

        When ``batch`` is True, the episode is only recorded and no update is
        made. Returns ``(trajectory, rewards)``.
        """
        return self.run_episode(batch, self.temporal_difference_update)

    def play_episode_mc(self, batch=False):
        """Play one episode and apply constant-alpha MC updates after it ends.

        Unless ``batch`` is True, every visited non-terminal state is updated.
        The MC target for a state is the return observed after it, which is
        only known once the episode has terminated. The episode is therefore
        always collected without online updates, and the updates are applied
        afterwards in visit order. Returns ``(trajectory, rewards)``.
        """
        trajectory, rewards = self.run_episode(True, self.monte_carlo_update)
        if not batch:
            episode_returns = self._returns(trajectory, rewards)
            for state, reward, g in zip(trajectory[:-1], rewards, episode_returns):
                self.monte_carlo_update(state, None, reward, returns=g)
        return trajectory, rewards

    def run_episode(self, batch, update_func):
        """Roll out one episode under the random policy.

        Args:
            batch: if True, only record the episode. Otherwise call
                ``update_func(state, next_state, reward)`` after every step.
                Per-step updates suit TD(0); MC must wait for the episode to
                end (see ``play_episode_mc``).
            update_func: per-step update rule.

        Returns:
            ``(trajectory, rewards)``. ``trajectory`` lists every visited state,
            from the start state to the final terminal state. ``rewards[i]`` is
            the reward for the move out of ``trajectory[i]``.
        """
        trajectory = []  # List to store the sequence of visited states
        rewards = []     # List to store the rewards received
        self.env.reset()
        state = self.env.current_state
        trajectory.append(state)

        while True:
            action = self.choose_action()
            next_state, reward, done = self.env.step(action)
            trajectory.append(next_state)
            rewards.append(reward)

            if not batch:
                # Online (per-step) update
                update_func(state, next_state, reward)

            if done:
                break  # Episode ends when a terminal state is reached
            state = next_state

        return trajectory, rewards

    def _returns(self, trajectory, rewards):
        """Undiscounted return following each non-terminal state of a finished
        episode.

        Each return is the remaining rewards plus an outcome of 1.0 if the walk
        ended in the right terminal state and 0.0 otherwise.
        RandomWalkEnvironment always returns reward 0, so this outcome is what
        carries the +1.
        """
        g = 1.0 if trajectory[-1] == self.env.right_terminal_state else 0.0
        returns = []
        # Accumulate backwards so each entry includes all later rewards.
        for reward in reversed(rewards):
            g += reward
            returns.append(g)
        returns.reverse()
        return returns

    def temporal_difference_update(self, state, next_state, reward):
        """TD(0): V(s) += alpha * (r + V(s') - V(s)) (undiscounted)."""
        delta = self.alpha * (reward + self.value_function.get_value(next_state) - self.value_function.get_value(state))
        self.value_function.update(state, delta)

    def monte_carlo_update(self, state, _, reward, returns=None):
        """Constant-alpha MC: V(s) += alpha * (G - V(s)).

        Args:
            state: state to update.
            _: unused (kept so the signature matches the per-step update rules).
            reward: unused; the return ``G`` already includes rewards.
            returns: the return ``G`` observed from ``state``. When omitted, it
                falls back to the episode outcome read from the environment,
                which is only correct after the episode has terminated.
        """
        if returns is None:
            returns = 1.0 if self.env.current_state == self.env.right_terminal_state else 0.0
        delta = self.alpha * (returns - self.value_function.get_value(state))
        self.value_function.update(state, delta)

This function runs episodes and collects the estimated state values at specified episodes for plotting.

In [31]:
def compute_state_values_over_episodes(agent, episodes, method):
    """Train `agent` online and snapshot its value estimates at chosen episodes.

    Args:
        agent: object with ``play_episode_td()`` / ``play_episode_mc()`` and a
            ``value_function.values`` array.
        episodes: episode counts at which to record a snapshot. ``0`` records
            the estimates before any training.
        method: ``'TD'`` or ``'MC'``.

    Returns:
        List of ``(episode, values_copy)`` tuples in increasing episode order.
        An empty ``episodes`` gives an empty list.

    Raises:
        ValueError: if ``method`` is neither ``'TD'`` nor ``'MC'``.
    """
    if method == 'TD':
        play_episode = agent.play_episode_td
    elif method == 'MC':
        play_episode = agent.play_episode_mc
    else:
        raise ValueError("Method must be 'TD' or 'MC'")

    checkpoints = set(episodes)
    values_over_episodes = []
    # Checkpoint 0 means "before any training". The loop below starts at
    # episode 1, so it must be captured here.
    if 0 in checkpoints:
        values_over_episodes.append((0, agent.value_function.values.copy()))
    for episode in range(1, max(checkpoints, default=0) + 1):
        play_episode()
        if episode in checkpoints:
            values_over_episodes.append((episode, agent.value_function.values.copy()))
    return values_over_episodes

This function plots the estimated state values against the true state values.

In [32]:
def plot_state_values(values_over_episodes, true_values):
    """Draw one estimated-value curve per snapshot, plus the true values,
    over the non-terminal states A-E on the current axes."""
    state_labels = list("ABCDE")
    for n_episodes, snapshot in values_over_episodes:
        plt.plot(state_labels, snapshot[1:6], label=f'{n_episodes} episodes')
    plt.plot(state_labels, true_values[1:6], label='True values')
    plt.xlabel('State')
    plt.ylabel('Estimated Value')
    plt.legend()

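These functions compute the root-mean-square (RMS) error between the estimated and true values of states A–E after every episode. The error is averaged over many independent runs, once for each step size α.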

In [33]:
def compute_rms_errors(agent_class, method, alphas, episodes, runs, true_values):
    """Return a 2-D array of averaged RMS-error curves: one row per step size
    in `alphas`, one column per episode."""
    return np.array([
        calculate_average_error(agent_class, method, alpha, episodes, runs, true_values)
        for alpha in alphas
    ])

def calculate_average_error(agent_class, method, alpha, episodes, runs, true_values):
    """Mean per-episode RMS error over `runs` independent learning runs.

    Each run builds a fresh environment, value function and agent, so runs
    share no learned state.
    """
    accumulated = np.zeros(episodes)
    for _ in tqdm(range(runs), leave=False):
        learner = agent_class(RandomWalkEnvironment(), ValueFunction(), alpha=alpha)
        accumulated += run_episodes(learner, episodes, method, true_values)
    return accumulated / runs

def run_episodes(agent, episodes, method, true_values):
    """Play `episodes` online-learning episodes with `agent`.

    Returns:
        Array with the RMS error over states 1..5 measured after each episode.

    Raises:
        ValueError: on the first episode if `method` is not 'TD' or 'MC'.
    """
    rms_per_episode = []
    target = true_values[1:6]
    for _ in range(episodes):
        if method == 'MC':
            agent.play_episode_mc()
        elif method == 'TD':
            agent.play_episode_td()
        else:
            raise ValueError("Method must be 'TD' or 'MC'")
        gap = agent.value_function.values[1:6] - target
        rms_per_episode.append(np.sqrt(np.mean(gap ** 2)))
    return np.array(rms_per_episode)

These functions implement batch updating. After each new episode, the updates from all episodes collected so far are summed and applied repeatedly until the value estimates converge.

In [34]:
def batch_updating(agent_class, method, episodes, alpha, runs, true_values):
    """Average the per-episode RMS error of batch updating over `runs` runs.

    Every run starts from a fresh environment and a value function whose
    non-terminal estimates are zeroed.
    """
    summed = np.zeros(episodes)
    for _ in tqdm(range(runs), leave=False):
        environment = RandomWalkEnvironment()
        estimates = ValueFunction()
        estimates.values[1:6] = 0  # batch runs start from all-zero non-terminal estimates
        learner = agent_class(environment, estimates, alpha=alpha)
        summed += run_batch_updates(learner, episodes, method, true_values)
    return summed / runs

def run_batch_updates(agent, episodes, method, true_values):
    """Collect episodes one at a time, re-solving the batch problem after each.

    Episodes are generated without online learning. After every new episode,
    `batch_update` is run on the whole history, and the RMS error over states
    1..5 is recorded. Returns the errors as an array.
    """
    def collect_episode():
        # Record one episode without touching the value function.
        if method == 'TD':
            return agent.play_episode_td(batch=True)
        if method == 'MC':
            return agent.play_episode_mc(batch=True)
        raise ValueError("Method must be 'TD' or 'MC'")

    trajectories, rewards_list, errors = [], [], []
    for _ in range(episodes):
        trajectory, rewards = collect_episode()
        trajectories.append(trajectory)
        rewards_list.append(rewards)

        agent.value_function = batch_update(agent.value_function, trajectories, rewards_list, agent.alpha, method)

        residual = agent.value_function.values[1:6] - true_values[1:6]
        errors.append(np.sqrt(np.mean(residual ** 2)))
    return np.array(errors)

def batch_update(value_function, trajectories, rewards_list, alpha, method):
    """Apply the summed batch increment repeatedly until it is negligible.

    Each sweep computes the total update over all stored episodes and scales
    it by `alpha`. Sweeping stops once the scaled total's absolute sum drops
    below 1e-3. Only states 1..5 are modified. Returns the same
    (mutated) `value_function`.
    """
    converged = False
    while not converged:
        step = alpha * compute_updates(value_function, trajectories, rewards_list, method)
        converged = np.abs(step).sum() < 1e-3
        if not converged:
            for state in range(1, 6):
                value_function.update(state, step[state])
    return value_function

def compute_updates(value_function, trajectories, rewards_list, method):
    """Sum the per-visit errors ``target - V(s)`` over every stored episode.

    Targets (undiscounted):
        * TD: ``r_i + V(s_{i+1})``.
        * MC: the return from step ``i``. This is the sum of the remaining
          rewards plus the value stored for the episode's final (terminal)
          state. With ``ValueFunction``'s defaults, that value is 0 on the left
          and 1 on the right, which is how the +1 reward is encoded.

    Args:
        value_function: object exposing ``get_value(state)``.
        trajectories: list of state sequences, each ending in a terminal state.
        rewards_list: per-episode reward lists; ``rewards[i]`` belongs to the
            move out of ``trajectory[i]``.
        method: ``'TD'`` or ``'MC'``.

    Returns:
        Length-7 array of summed errors. Entries for states never visited as
        a non-terminal state stay 0.

    Raises:
        ValueError: if ``method`` is neither ``'TD'`` nor ``'MC'``.
    """
    if method not in ('TD', 'MC'):
        raise ValueError("Method must be 'TD' or 'MC'")

    updates = np.zeros(7)
    for trajectory, rewards in zip(trajectories, rewards_list):
        if len(trajectory) < 2:
            continue  # no transitions in this episode
        if method == 'MC':
            # Reverse cumulative sum gives the rewards remaining from each step.
            terminal_value = value_function.get_value(trajectory[-1])
            returns = np.cumsum(np.asarray(rewards, dtype=float)[::-1])[::-1] + terminal_value
        for i, state in enumerate(trajectory[:-1]):
            if method == 'TD':
                target = rewards[i] + value_function.get_value(trajectory[i + 1])
            else:
                target = returns[i]
            updates[state] += target - value_function.get_value(state)
    return updates

This function plots the RMS errors for different alpha values and methods.

In [35]:
def plot_rms_errors(td_alphas, mc_alphas, errors_td, errors_mc):
    """Plot one RMS-error curve per step size.

    TD curves are solid and MC curves dash-dotted. Row ``i`` of each error
    array belongs to the ``i``-th alpha of the matching list.
    """
    curve_groups = (
        ('TD, α = {:.02f}', 'solid', td_alphas, errors_td),
        ('MC, α = {:.02f}', 'dashdot', mc_alphas, errors_mc),
    )
    for label_fmt, line_style, alphas, curves in curve_groups:
        for row, alpha in enumerate(alphas):
            plt.plot(curves[row], linestyle=line_style, label=label_fmt.format(alpha))
    plt.xlabel('Walks/Episodes')
    plt.ylabel('Empirical RMS Error (averaged over states)')
    plt.legend()

This function plots the RMS errors during batch updating for TD and MC methods.

In [36]:
def plot_batch_errors(errors_td, errors_mc):
    """Overlay the batch-updating RMS-error curves of TD and MC."""
    for curve, method_name in ((errors_td, 'TD'), (errors_mc, 'MC')):
        plt.plot(curve, label=method_name)
    plt.xlabel('Walks/Episodes')
    plt.ylabel('Empirical RMS Error (averaged over states)')
    plt.legend()

Finally, we run the experiments and generate the plots to compare the performance of TD and MC methods.

In [37]:
def example_6_2(seed=None, output_path='example_6_2.png'):
    """Run the random-walk experiments and save a three-panel figure.

    Panels:
        1. TD(0) value estimates at the checkpoints in ``episodes_to_plot``,
           against the true values.
        2. Online RMS-error curves for TD and MC at several step sizes,
           averaged over ``runs`` runs.
        3. RMS error under batch updating for TD and MC.

    Args:
        seed: optional seed for NumPy's global RNG, which ``Agent.choose_action``
            draws from. Pass an int for a reproducible figure; ``None`` leaves
            the RNG untouched.
        output_path: file the figure is written to.
    """
    if seed is not None:
        np.random.seed(seed)

    # Reuse the module-level ground truth rather than re-deriving it here.
    true_values = TRUE_VALUES.copy()

    plt.figure(figsize=(10, 20))

    # Panel 1: estimated state values over episodes
    plt.subplot(3, 1, 1)
    env = RandomWalkEnvironment()
    value_function = ValueFunction()
    agent = Agent(env, value_function)
    episodes_to_plot = [0, 1, 10, 100]
    values_over_episodes = compute_state_values_over_episodes(agent, episodes_to_plot, method='TD')
    plot_state_values(values_over_episodes, true_values)
    plt.title('TD(0) value estimates')

    # Panel 2: online RMS errors for different step sizes
    plt.subplot(3, 1, 2)
    td_alphas = [0.15, 0.1, 0.05]
    mc_alphas = [0.01, 0.02, 0.03, 0.04]
    episodes = 101  # Number of episodes per run
    runs = 100      # Number of independent runs to average over

    errors_td = compute_rms_errors(Agent, 'TD', td_alphas, episodes, runs, true_values)
    errors_mc = compute_rms_errors(Agent, 'MC', mc_alphas, episodes, runs, true_values)
    plot_rms_errors(td_alphas, mc_alphas, errors_td, errors_mc)
    plt.title('Online learning: RMS error vs. episodes')

    # Panel 3: RMS errors during batch updating
    plt.subplot(3, 1, 3)
    batch_episodes = 100
    batch_alpha = 0.001  # small step so the repeated batch sweeps converge smoothly
    batch_runs = 100

    errors_td_batch = batch_updating(Agent, 'TD', batch_episodes, batch_alpha, batch_runs, true_values)
    errors_mc_batch = batch_updating(Agent, 'MC', batch_episodes, batch_alpha, batch_runs, true_values)
    plot_batch_errors(errors_td_batch, errors_mc_batch)
    plt.title(f'Batch updating (α = {batch_alpha})')

    plt.tight_layout()
    plt.savefig(output_path)
    plt.close()

In [38]:
# Entry point. A Jupyter kernel also runs code as __main__, so executing this
# cell runs the whole experiment; importing the code as a module does not.
if __name__ == '__main__':
    example_6_2()

                                                 