In [1]:
import numpy as np

import sys
sys.path.append('../utils')
from GridWorld import get_standard_grid, get_negative_grid
from iterative_policy_eval import printValues, printPolicy

GAMMA = 0.9

In [2]:
def playGame(grid, policy, max_steps=20):
    """Play one episode under `policy`, starting from a random state.

    The start state is drawn uniformly from the keys of ``grid.actions``.
    The grid is then advanced with ``grid.move(policy[s])`` until
    ``grid.game_over()`` is true or ``max_steps`` moves have been made,
    whichever comes first.

    Args:
        grid: environment object; this function uses its ``actions``
            mapping and its ``set_state``, ``current_state``,
            ``game_over`` and ``move`` methods.
        policy: mapping from state to action. It must contain every
            state the episode passes through before it ends.
        max_steps: upper bound on the number of moves. A value <= 0
            yields an episode that contains only the start state.

    Returns:
        A ``(states, rewards)`` tuple of two equally long lists.
        ``rewards[i]`` is the reward returned by the move that led to
        ``states[i]``; ``rewards[0]`` is a placeholder 0 for the start
        state so that the two lists stay aligned.
    """
    startStates = list(grid.actions.keys())
    startIdx = np.random.choice(len(startStates))
    grid.set_state(startStates[startIdx])

    s = grid.current_state()

    states = [s]
    rewards = [0]  # placeholder: no reward is received for the start state

    steps = 0
    # The step cap is part of the loop condition so that it is honoured
    # before the first move as well (max_steps <= 0 takes no move at all).
    while steps < max_steps and not grid.game_over():
        r = grid.move(policy[s])
        s = grid.current_state()

        states.append(s)
        rewards.append(r)
        steps += 1

    return states, rewards

In [3]:
# First-visit Monte Carlo evaluation of a fixed policy on the standard grid.
grid = get_standard_grid()

print("rewards:")
printValues(grid.rewards, grid)

# Fixed deterministic policy to evaluate: state (row, col) -> action.
policy = {
    (2, 0): 'U',
    (1, 0): 'U',
    (0, 0): 'R',
    (0, 1): 'R',
    (0, 2): 'R',
    (1, 2): 'R',
    (2, 1): 'R',
    (2, 2): 'R',
    (2, 3): 'U'
}

# Number of sampled episodes used for the estimate.
N_EPISODES = 100

# initialize V(s) and returns
# V gets an entry for every state so that a state which never shows up in a
# sampled episode still has a value (0) when V is printed below. States that
# are keys of grid.actions additionally get a list collecting their returns.
V = {}
returns = {}
for s in grid.all_states():
    V[s] = 0
    if s in grid.actions:
        returns[s] = []

for episode in range(N_EPISODES):
    states, rewards = playGame(grid, policy)
    G = 0
    T = len(states)
    # Walk the episode backwards, skipping the final state, and accumulate
    # the discounted return G that follows each visited state.
    for t in range(T-2, -1, -1):
        s = states[t]
        r = rewards[t+1]  # reward received on leaving states[t]
        G = r + GAMMA * G

        # first-visit: record G only at the earliest occurrence of s
        if s not in states[:t]:
            returns[s].append(G)
            V[s] = np.mean(returns[s])

print("values:")
printValues(V, grid)

print("policy:")
printPolicy(policy, grid)

rewards:
-------------------------
 0.00| 0.00| 0.00| 1.00|
-------------------------
 0.00| 0.00| 0.00|-1.00|
-------------------------
 0.00| 0.00| 0.00| 0.00|
values:
-------------------------
 0.81| 0.90| 1.00| 0.00|
-------------------------
 0.73| 0.00|-1.00| 0.00|
-------------------------
 0.66|-0.81|-0.90|-1.00|
policy:
-------------------------
 R | R | R |   |
-------------------------
 U |   | R |   |
-------------------------
 U | R | R | U |
