Update random_walk.py
I reset the x-axis and changed the 'alpha' notation to make the figures look better.
VEXLife committed Jun 6, 2021
1 parent ac4fbce commit 1c32673
Showing 1 changed file with 19 additions and 15 deletions.
34 changes: 19 additions & 15 deletions chapter06/random_walk.py
@@ -59,7 +59,7 @@ def temporal_difference(values, alpha=0.1, batch=False):
 # @batch: whether to update @values
 def monte_carlo(values, alpha=0.1, batch=False):
     state = 3
-    trajectory = [3]
+    trajectory = [state]
 
     # if end up with left terminal state, all returns are 0
     # if end up with right terminal state, all returns are 1
@@ -89,11 +89,11 @@ def compute_state_value():
     plt.figure(1)
     for i in range(episodes[-1] + 1):
         if i in episodes:
-            plt.plot(current_values, label=str(i) + ' episodes')
+            plt.plot(("A", "B", "C", "D", "E"), current_values[1:6], label=str(i) + ' episodes')
         temporal_difference(current_values)
-    plt.plot(TRUE_VALUE, label='true values')
-    plt.xlabel('state')
-    plt.ylabel('estimated value')
+    plt.plot(("A", "B", "C", "D", "E"), TRUE_VALUE[1:6], label='true values')
+    plt.xlabel('State')
+    plt.ylabel('Estimated Value')
     plt.legend()
 
 # Example 6.2 right
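Side note on the idiom above: passing a tuple of strings as the x argument makes Matplotlib treat the axis as categorical, so the five non-terminal states are labeled A through E instead of the raw array indices 0 to 6. A minimal standalone sketch of that idiom, with made-up value estimates in place of the script's real data:

import numpy as np
import matplotlib.pyplot as plt

# Illustrative estimates for the five non-terminal states A..E; the real
# script slices them out of a length-7 value array as values[1:6].
states = ("A", "B", "C", "D", "E")
values = np.array([0.2, 0.35, 0.5, 0.65, 0.8])

# String x-data gives a categorical axis labeled with the state names.
plt.plot(states, values, label='example estimates')
plt.xlabel('State')
plt.ylabel('Estimated Value')
plt.legend()
plt.show()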
@@ -122,9 +122,9 @@ def rms_error():
                 monte_carlo(current_values, alpha=alpha)
             total_errors += np.asarray(errors)
         total_errors /= runs
-        plt.plot(total_errors, linestyle=linestyle, label=method + ', alpha = %.02f' % (alpha))
-    plt.xlabel('episodes')
-    plt.ylabel('RMS')
+        plt.plot(total_errors, linestyle=linestyle, label=method + ', $\\alpha$ = %.02f' % (alpha))
+    plt.xlabel('Walks/Episodes')
+    plt.ylabel('Empirical RMS error, averaged over states')
     plt.legend()
 
 # Figure 6.2
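The '$\\alpha$' change relies on Matplotlib's built-in mathtext: anything wrapped in $...$ inside a label is rendered as math, so the legend shows a Greek alpha rather than the literal word 'alpha'. A small sketch of the effect, using a fabricated error curve in place of the real experiment:

import numpy as np
import matplotlib.pyplot as plt

alpha = 0.05
errors = np.linspace(0.55, 0.25, 101)  # fabricated RMS-error curve

# '$\\alpha$' is mathtext, so the legend entry reads 'TD, α = 0.05'.
plt.plot(errors, linestyle='solid', label='TD, $\\alpha$ = %.02f' % alpha)
plt.xlabel('Walks/Episodes')
plt.ylabel('Empirical RMS error, averaged over states')
plt.legend()
plt.show()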
@@ -135,6 +135,7 @@ def batch_updating(method, episodes, alpha=0.001):
     total_errors = np.zeros(episodes)
     for r in tqdm(range(0, runs)):
         current_values = np.copy(VALUES)
+        current_values[1:6] = -1
         errors = []
         # track shown trajectories and reward/return sequences
         trajectories = []
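The one change here that is not purely cosmetic is current_values[1:6] = -1, which overrides the starting estimates of the five non-terminal states for the batch experiment. In this script VALUES initializes those states to 0.5, so the override looks roughly like the sketch below (the VALUES definition is reproduced from memory of the file and should be checked against the source):

import numpy as np

# States 0 and 6 are terminal; states 1..5 correspond to A..E.
VALUES = np.zeros(7)
VALUES[1:6] = 0.5
VALUES[6] = 1

current_values = np.copy(VALUES)
current_values[1:6] = -1  # batch runs now start all non-terminal estimates at -1

This should only change where the error curves start, not the fixed point that batch TD and batch MC each converge to.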
@@ -180,13 +181,16 @@ def example_6_2():
 
 def figure_6_2():
     episodes = 100 + 1
-    td_erros = batch_updating('TD', episodes)
-    mc_erros = batch_updating('MC', episodes)
+    td_errors = batch_updating('TD', episodes)
+    mc_errors = batch_updating('MC', episodes)
 
-    plt.plot(td_erros, label='TD')
-    plt.plot(mc_erros, label='MC')
-    plt.xlabel('episodes')
-    plt.ylabel('RMS error')
+    plt.plot(td_errors, label='TD')
+    plt.plot(mc_errors, label='MC')
+    plt.title("Batch Training")
+    plt.xlabel('Walks/Episodes')
+    plt.ylabel('RMS error, averaged over states')
+    plt.xlim(0, 100)
+    plt.ylim(0, 0.25)
     plt.legend()
 
     plt.savefig('../images/figure_6_2.png')
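Presumably the new title and axis limits pin the batch-training plot to the same frame as Figure 6.2 in Sutton and Barto (0 to 100 walks/episodes, RMS error from 0 to 0.25), which is what "reset the x axis" in the commit message refers to. A self-contained sketch of just the styling, with placeholder curves standing in for batch_updating() output:

import numpy as np
import matplotlib.pyplot as plt

episodes = 100 + 1
# Placeholder error curves; the real script gets these from batch_updating().
td_errors = np.linspace(0.18, 0.04, episodes)
mc_errors = np.linspace(0.20, 0.08, episodes)

plt.plot(td_errors, label='TD')
plt.plot(mc_errors, label='MC')
plt.title("Batch Training")
plt.xlabel('Walks/Episodes')
plt.ylabel('RMS error, averaged over states')
plt.xlim(0, 100)   # clamp the x-axis to the 100 walks shown
plt.ylim(0, 0.25)  # clamp the y-axis to the figure's error range
plt.legend()
plt.show()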
