In [1]:
import numpy as np
import os
os.chdir(os.path.join('..', 'dynamic_programming'))
import import_ipynb
from grid_world import standard_grid, ACTION_SPACE
from iterative_policy_evaluation_deterministic import print_policy, print_values

importing Jupyter notebook from grid_world.ipynb
importing Jupyter notebook from iterative_policy_evaluation_deterministic.ipynb


In [2]:
# Discount factor: each step back in an episode multiplies the accumulated
# return G by this value before the next reward is added.
GAMMA = 0.9

In [3]:
def play_game(grid, policy, max_steps=20):
    """Roll out one episode on `grid`, following `policy` from a random start.

    The start state is drawn with `np.random.choice` from the keys of
    `grid.actions`. From there the action `policy[state]` is passed to
    `grid.move` repeatedly until `grid.game_over()` is true or `max_steps`
    moves have been made.

    Args:
        grid: environment object providing `actions`, `set_state`,
            `current_state`, `game_over` and `move`.
        policy: mapping from state to the action to take in that state.
        max_steps: upper bound on the number of moves in the episode.

    Returns:
        A `(states, rewards)` pair of lists of equal length. `states[0]` is
        the start state and `rewards[0]` is a placeholder 0; for t >= 0,
        `rewards[t + 1]` is the value returned by `grid.move` for the action
        taken in `states[t]`, and `states[t + 1]` is the state it led to.
    """
    # Choose the starting state at random among states that have actions.
    start_candidates = list(grid.actions.keys())
    start_index = np.random.choice(len(start_candidates))
    grid.set_state(start_candidates[start_index])

    current = grid.current_state()
    states = [current]
    rewards = [0]  # no reward is associated with arriving at the start state

    steps_taken = 0
    while True:
        # Stop on a finished game first, then on an exhausted step budget.
        if grid.game_over() or not steps_taken < max_steps:
            break

        action = policy[current]
        reward = grid.move(action)
        current = grid.current_state()

        # Record the transition outcome.
        rewards.append(reward)
        states.append(current)
        steps_taken += 1

    return states, rewards

In [4]:
if __name__ == '__main__' :
    # Demo: first-visit Monte Carlo evaluation of a fixed policy on the grid
    # returned by standard_grid().
    grid = standard_grid()

    print('Rewards')
    print_values(grid.rewards, grid)

    # The policy being evaluated: one fixed action per listed state.
    policy = {
        (0, 0) : 'R',
        (0, 1) : 'R',
        (0, 2) : 'R',
        (1, 0) : 'U',
        (1, 2) : 'R',
        (2, 0) : 'U',
        (2, 1) : 'R',
        (2, 2) : 'U',
        (2, 3) : 'L'
    }

    # V[s] is the current value estimate of s; returns[s] collects the sampled
    # returns observed for each non-terminal state s.
    # Every state starts with V[s] = 0 so that a non-terminal state which
    # happens never to be visited in any episode still has an entry when V is
    # printed below.
    V = {}
    returns = {}
    for s in grid.all_states() :
        V[s] = 0
        if not grid.is_terminal(s) :
            returns[s] = []

    # Sample a fixed number of episodes (no convergence test is performed).
    n_episodes = 100
    for episode in range(n_episodes) :
        states, rewards = play_game(grid, policy, max_steps=20)
        G = 0 # discounted return accumulated from the end of the episode
        # Walk the episode backwards: rewards[t + 1] is the reward received
        # after acting in states[t], so G at step t is that reward plus the
        # discounted return from step t + 1. The final state has no following
        # reward, hence the start at len(states) - 2.
        for t in range(len(states) - 2, -1, -1) :
            G = rewards[t + 1] + GAMMA * G
            s_t = states[t]

            # First-visit Monte Carlo: only the earliest occurrence of a state
            # within the episode contributes a return sample.
            if s_t not in states[:t]:
                returns[s_t].append(G)
                V[s_t] = np.mean(returns[s_t])

    print('Values')
    print_values(V, grid)
    print('Policy')
    print_policy(policy, grid)

Rewards
+---+---+---+----+
| 0 | 0 | 0 |  1 |
+---+---+---+----+
| 0 | 0 | 0 | -1 |
+---+---+---+----+
| 0 | 0 | 0 |  0 |
+---+---+---+----+
Values
+-------+--------+--------+--------+
| 0.810 |  0.900 |  1.000 |  0.000 |
+-------+--------+--------+--------+
| 0.729 |  0.000 | -1.000 |  0.000 |
+-------+--------+--------+--------+
| 0.656 | -0.810 | -0.900 | -0.810 |
+-------+--------+--------+--------+
Policy
+---+---+---+---+
| R | R | R |   |
+---+---+---+---+
| U |   | R |   |
+---+---+---+---+
| U | R | U | L |
+---+---+---+---+
