In [82]:
%load_ext autoreload
%autoreload 2
import grid_world as gw
import numpy as np
import typing

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [72]:
def print_V(V:typing.Tuple, width:int, height:int):
    for i in range(width):
        line = ""
        for j in range(height):
            index = (i, j)
            if index in V:
                val = V[index]
                if val >= 0:
                    line += "| %1.2f" % val
                else:
                    line += "|%1.2f" % val

            else:
                line += '|' + " 0.00"
        print(line+'|')
        print('-------------------------')

In [167]:
def print_policy(grid:gw.Grid, p:typing.Callable[[typing.Tuple, any], typing.List[typing.Tuple[float,any]]], width:int, height:int):
    for i in range(width):
        line = ""
        for j in range(height):
            index = (i, j)
            if index in grid.actions:
                if type(p) == type({}):
                    actions = [(1., p[index])]
                else:
                    actions = p(index, grid)
                highest_prob = 0.
                action_to_take = None

                for a in actions:
                    prob, action = a
                    if prob > highest_prob:
                        highest_prob = prob
                        action_to_take = action

                line += '|  ' + str(action_to_take) + "  "
            else:
                line += '|  ' + "X" + "  "
        print(line+'|')
        print('-------------------------')

In [74]:
def random_policy(s:typing.Tuple, environment:any) -> typing.List[typing.Tuple[float, any]]:
    moves = environment.actions[s]
    actions = []
    for move in moves:
        actions.append((1/len(moves),move))

    return actions

In [113]:
def policy_evaluation(grid:gw.Grid, p:typing.Callable[[typing.Tuple, any],
                                                      typing.List[typing.Tuple[float,any]]],
                                                      e:float=10e-4,
                                                      gamma:float=.9):
    states = grid.all_states()
    V = {}
    
    for state in states:
        V[state] = 0
    
    i = 0
    while True:
        i += 1
        max_change = 0.

        for s in states:
            if s in grid.actions:
                v_old = V[s]

                actions = p(s, grid)
                new_v = 0.
                for a in actions:
                    if a is not None:
                        grid.set_state(s)
                        prob, action = a
                        reward = grid.move(action, force=True)
                        new_v += prob * (reward + gamma * V[grid.current_state()])
                V[s] = new_v

                max_change = max(max_change, abs(v_old - new_v))
        print("Round {}, change: {}".format(i, max_change))
        print_policy(grid=grid, width=grid.width, height=grid.height,p=p)
        print_V(V=V, width=grid.width, height=grid.height)

        if max_change < e:
            break

    return V

In [131]:
def policy_iteration(grid:gw.Grid, windy = False):
    class Policy():
        dict = {
            (2, 0):'U',
            (1, 0):'R',
            (0, 0):'L',
            (0, 1):'U',
            (0, 2):'D',
            (1, 2):'L',
            (2, 1):'R',
            (2, 2):'U',
            (2, 3):'D',
            (0, 3):None,
            (1, 1):None,
            (1, 3):None,
        }

        def __init__(self):
            pass

        def __call__(self, s:any, env:any) -> typing.List[typing.Tuple[float, any]]:
            action = self.dict[s]
            return [(1., action)]

    p = Policy()
    gamma = .9

    while True:
        policy_changed = False
        V = policy_evaluation(grid=grid, p=p, gamma=gamma)
        for s in grid.actions:
            grid.set_state(s)
            # This only works for exactly one action
            old_a = p(s, env=grid)[0][1]
            best_val = -float('inf')
            best_action = None
            for action in ['U','D','L','R']:
                grid.set_state(s)
                r = grid.move(action, force=False)
                s_prime = grid.current_state()

                # only check if we actually moved
                if s_prime != s:
                    val = 1. * r + gamma * V[s_prime]
                    if val > best_val:
                        best_val = val
                        best_action = action

            p.dict[s] = best_action if best_action is not None else p.dict[s]
            if best_action != old_a:
                policy_changed = True

        if not policy_changed:
            break


In [186]:
def policy_iteration_2(grid:gw.Grid, windy = False):
    V = {}
    policy = {}
    
    for key in grid.actions.keys():
        policy[key] = np.random.choice(['U', 'D', 'L', 'R'])
        
    for s in grid.all_states():
        V[s] = np.random.rand()
        
    while True:
        
        # Policy evaluation
        while True:
            max_change = -float('inf')
            for s in grid.actions:
                grid.set_state(s)
                best_val = -float('inf')
                best_action = None
                for action in ['U','D','L','R']:
                    grid.set_state(s)
                    r = grid.move(action)
                    s_prime = grid.current_state()
                    if s_prime != s:
                        val = r + 0.9 * V[s_prime]
                        if val > best_val:
                            best_val = val
                            best_action = action
                max_change = max(max_change, abs(V[s] - best_val))
                V[s] = best_val
                
            print(max_change)
            if max_change < 0.0001:
                break
        print_V(V, width=grid.width, height=grid.height)      
        
        
        policy_changed = False
        for key in policy.keys():
            possible_actions = grid.actions[key]
            action_old = policy[key]
            
            max_val = float("-inf")
            best_action = action_old
            for action in possible_actions:
                grid.set_state(key)
                r = grid.move(action)
                s_prime = grid.current_state()
                val = r + 0.9 * V[s_prime] #abs(V[s_prime] - V[s])
                if val > max_val:
                    max_val = val
                    best_action = action
                    
            if action_old != best_action:
                policy_changed = True
                policy[key] = best_action
                
        print_policy(p=policy, grid=grid, width=grid.width, height=grid.height)
        
        if not policy_changed:
            break


In [77]:
def fixed_policy(s:typing.Tuple, environment:any) -> typing.List[typing.Tuple[float, any]]:
    state_action_pairs = {
        (2, 0):'U',
        (1, 0):'U',
        (0, 0):'R',
        (0, 1):'R',
        (0, 2):'R',
        (1, 2):'R',
        (2, 1):'U',
        (2, 2):'U',
        (2, 3):'U',

        (0, 3):None,
        (1, 1):None,
        (1, 3):None,

    }

    action = state_action_pairs[s]
    return [(1., action)]

In [181]:
policy_evaluation(grid=gw.standard_grid(), p=fixed_policy)

Round 1, change: 1.0
|  R  |  R  |  R  |  X  |
-------------------------
|  U  |  X  |  R  |  X  |
-------------------------
|  U  |  U  |  U  |  U  |
-------------------------
| 0.00| 0.00| 1.00| 0.00|
-------------------------
| 0.00| 0.00|-1.00| 0.00|
-------------------------
| 0.00| 0.00|-0.90|-1.00|
-------------------------
Round 2, change: 0.9
|  R  |  R  |  R  |  X  |
-------------------------
|  U  |  X  |  R  |  X  |
-------------------------
|  U  |  U  |  U  |  U  |
-------------------------
| 0.81| 0.90| 1.00| 0.00|
-------------------------
| 0.73| 0.00|-1.00| 0.00|
-------------------------
| 0.00| 0.00|-0.90|-1.00|
-------------------------
Round 3, change: 0.6561000000000001
|  R  |  R  |  R  |  X  |
-------------------------
|  U  |  X  |  R  |  X  |
-------------------------
|  U  |  U  |  U  |  U  |
-------------------------
| 0.81| 0.90| 1.00| 0.00|
-------------------------
| 0.73| 0.00|-1.00| 0.00|
-------------------------
| 0.66| 0.00|-0.90|-1.00|
------------

{(0, 0): 0.81,
 (0, 1): 0.9,
 (0, 2): 1.0,
 (0, 3): 0,
 (1, 0): 0.7290000000000001,
 (1, 2): -1.0,
 (1, 3): 0,
 (2, 0): 0.6561000000000001,
 (2, 1): 0.0,
 (2, 2): -0.9,
 (2, 3): -1.0}

In [78]:
policy_evaluation(grid=gw.standard_grid(), p=random_policy)

Round 1, change: 0.5
|  D  |  L  |  L  |  X  |
-------------------------
|  U  |  X  |  U  |  X  |
-------------------------
|  U  |  L  |  L  |  L  |
-------------------------
| 0.00| 0.00| 0.23| 0.00|
-------------------------
| 0.00| 0.00|-0.33| 0.00|
-------------------------
| 0.00| 0.00|-0.25|-0.50|
-------------------------
Round 2, change: 0.11250000000000004
|  D  |  L  |  L  |  X  |
-------------------------
|  U  |  X  |  U  |  X  |
-------------------------
|  U  |  L  |  L  |  L  |
-------------------------
| 0.05| 0.11| 0.26| 0.00|
-------------------------
|-0.00| 0.00|-0.34| 0.00|
-------------------------
|-0.05|-0.11|-0.32|-0.61|
-------------------------
Round 3, change: 0.05383125000000001
|  D  |  L  |  L  |  X  |
-------------------------
|  U  |  X  |  U  |  X  |
-------------------------
|  U  |  L  |  L  |  L  |
-------------------------
| 0.06| 0.14| 0.27| 0.00|
-------------------------
|-0.01| 0.00|-0.35| 0.00|
-------------------------
|-0.08|-0.17|-0.35|-0

{(0, 0): 0.055352482520719955,
 (0, 1): 0.1457866028842206,
 (0, 2): 0.2674431564711638,
 (0, 3): 0,
 (1, 0): -0.023585994548220707,
 (1, 2): -0.3654205257581189,
 (1, 3): 0,
 (2, 0): -0.1077658037389882,
 (2, 1): -0.21669847769290862,
 (2, 2): -0.3752241775722589,
 (2, 3): -0.6686282551231688}

In [110]:
policy_iteration(grid=gw.negative_grid())

Round 1, change: 0.19
|  L  |  U  |  D  |  X  |
-------------------------
|  R  |  X  |  L  |  X  |
-------------------------
|  U  |  R  |  U  |  D  |
-------------------------
|-0.10|-0.10|-0.19| 0.00|
-------------------------
|-0.10| 0.00|-0.10| 0.00|
-------------------------
|-0.10|-0.10|-0.19|-0.10|
-------------------------
Round 2, change: 0.171
|  L  |  U  |  D  |  X  |
-------------------------
|  R  |  X  |  L  |  X  |
-------------------------
|  U  |  R  |  U  |  D  |
-------------------------
|-0.19|-0.19|-0.27| 0.00|
-------------------------
|-0.19| 0.00|-0.19| 0.00|
-------------------------
|-0.19|-0.27|-0.27|-0.19|
-------------------------
Round 3, change: 0.08100000000000002
|  L  |  U  |  D  |  X  |
-------------------------
|  R  |  X  |  L  |  X  |
-------------------------
|  U  |  R  |  U  |  D  |
-------------------------
|-0.27|-0.27|-0.34| 0.00|
-------------------------
|-0.27| 0.00|-0.27| 0.00|
-------------------------
|-0.27|-0.34|-0.34|-0.27|
--------

In [126]:
policy_iteration(grid=gw.negative_grid(windy=True))

Round 1, change: 0.19
|  L  |  U  |  D  |  X  |
-------------------------
|  R  |  X  |  L  |  X  |
-------------------------
|  U  |  R  |  U  |  D  |
-------------------------
|-0.10|-0.10|-0.19| 0.00|
-------------------------
|-0.10| 0.00|-0.10| 0.00|
-------------------------
|-0.10|-0.10|-0.19|-0.10|
-------------------------
Round 2, change: 0.171
|  L  |  U  |  D  |  X  |
-------------------------
|  R  |  X  |  L  |  X  |
-------------------------
|  U  |  R  |  U  |  D  |
-------------------------
|-0.19|-0.19|-0.27| 0.00|
-------------------------
|-0.19| 0.00|-0.19| 0.00|
-------------------------
|-0.19|-0.27|-0.27|-0.19|
-------------------------
Round 3, change: 0.08100000000000002
|  L  |  U  |  D  |  X  |
-------------------------
|  R  |  X  |  L  |  X  |
-------------------------
|  U  |  R  |  U  |  D  |
-------------------------
|-0.27|-0.27|-0.34| 0.00|
-------------------------
|-0.27| 0.00|-0.27| 0.00|
-------------------------
|-0.27|-0.34|-0.34|-0.27|
--------

|-0.97|-0.96|-0.97|-0.96|
-------------------------
Round 17, change: 0.007248888065061965
|  R  |  R  |  R  |  X  |
-------------------------
|  U  |  X  |  U  |  X  |
-------------------------
|  R  |  L  |  L  |  L  |
-------------------------
| 0.62| 0.80| 1.00| 0.00|
-------------------------
| 0.46| 0.00| 0.80| 0.00|
-------------------------
|-0.97|-0.97|-0.97|-0.97|
-------------------------
Round 18, change: 0.005871599332700206
|  R  |  R  |  R  |  X  |
-------------------------
|  U  |  X  |  U  |  X  |
-------------------------
|  R  |  L  |  L  |  L  |
-------------------------
| 0.62| 0.80| 1.00| 0.00|
-------------------------
| 0.46| 0.00| 0.80| 0.00|
-------------------------
|-0.98|-0.97|-0.98|-0.97|
-------------------------
Round 19, change: 0.004755995459487128
|  R  |  R  |  R  |  X  |
-------------------------
|  U  |  X  |  U  |  X  |
-------------------------
|  R  |  L  |  L  |  L  |
-------------------------
| 0.62| 0.80| 1.00| 0.00|
-------------------------

-------------------------
Round 11, change: 0.03486784400999998
|  D  |  D  |  R  |  X  |
-------------------------
|  D  |  X  |  U  |  X  |
-------------------------
|  R  |  L  |  U  |  L  |
-------------------------
|-0.90|-0.69| 1.00| 0.00|
-------------------------
|-0.91| 0.00| 0.80| 0.00|
-------------------------
|-0.90|-0.89| 0.62| 0.46|
-------------------------
Round 12, change: 0.03138105960899995
|  D  |  D  |  R  |  X  |
-------------------------
|  D  |  X  |  U  |  X  |
-------------------------
|  R  |  L  |  U  |  L  |
-------------------------
|-0.92|-0.72| 1.00| 0.00|
-------------------------
|-0.93| 0.00| 0.80| 0.00|
-------------------------
|-0.92|-0.91| 0.62| 0.46|
-------------------------
Round 13, change: 0.028242953648100033
|  D  |  D  |  R  |  X  |
-------------------------
|  D  |  X  |  U  |  X  |
-------------------------
|  R  |  L  |  U  |  L  |
-------------------------
|-0.94|-0.75| 1.00| 0.00|
-------------------------
|-0.94| 0.00| 0.80| 0.00|
-

|-1.00| 0.00| 0.80| 0.00|
-------------------------
|-1.00|-1.00|-1.00|-1.00|
-------------------------
Round 28, change: 0.005814973700304038
|  L  |  R  |  R  |  X  |
-------------------------
|  D  |  X  |  U  |  X  |
-------------------------
|  R  |  R  |  L  |  L  |
-------------------------
|-0.95| 0.80| 1.00| 0.00|
-------------------------
|-1.00| 0.00| 0.80| 0.00|
-------------------------
|-1.00|-1.00|-1.00|-1.00|
-------------------------
Round 29, change: 0.005233476330273601
|  L  |  R  |  R  |  X  |
-------------------------
|  D  |  X  |  U  |  X  |
-------------------------
|  R  |  R  |  L  |  L  |
-------------------------
|-0.95| 0.80| 1.00| 0.00|
-------------------------
|-1.00| 0.00| 0.80| 0.00|
-------------------------
|-1.00|-1.00|-1.00|-1.00|
-------------------------
Round 30, change: 0.004710128697246185
|  L  |  R  |  R  |  X  |
-------------------------
|  D  |  X  |  U  |  X  |
-------------------------
|  R  |  R  |  L  |  L  |
-------------------------

|  U  |  X  |  U  |  X  |
-------------------------
|  U  |  R  |  U  |  L  |
-------------------------
|-0.99|-0.99| 1.00| 0.00|
-------------------------
|-0.99| 0.00| 0.80| 0.00|
-------------------------
|-0.99| 0.46| 0.62| 0.46|
-------------------------
Round 45, change: 0.0009697737297875708
|  L  |  L  |  R  |  X  |
-------------------------
|  U  |  X  |  U  |  X  |
-------------------------
|  U  |  R  |  U  |  L  |
-------------------------
|-0.99|-0.99| 1.00| 0.00|
-------------------------
|-0.99| 0.00| 0.80| 0.00|
-------------------------
|-0.99| 0.46| 0.62| 0.46|
-------------------------
Due to wind I switched action U to R
Due to wind I switched action U to L
Due to wind I switched action L to U
Due to wind I switched action R to L
Round 1, change: 1.0
|  U  |  R  |  R  |  X  |
-------------------------
|  U  |  X  |  U  |  X  |
-------------------------
|  R  |  R  |  L  |  L  |
-------------------------
|-0.10|-0.10| 1.00| 0.00|
-------------------------
|-0.19| 0.0

|  D  |  R  |  R  |  X  |
-------------------------
|  U  |  X  |  U  |  X  |
-------------------------
|  R  |  L  |  U  |  L  |
-------------------------
|-0.99| 0.80| 1.00| 0.00|
-------------------------
|-0.99| 0.00| 0.80| 0.00|
-------------------------
|-0.99|-0.99| 0.62| 0.46|
-------------------------
Round 26, change: 0.0010880192104343323
|  D  |  R  |  R  |  X  |
-------------------------
|  U  |  X  |  U  |  X  |
-------------------------
|  R  |  L  |  U  |  L  |
-------------------------
|-1.00| 0.80| 1.00| 0.00|
-------------------------
|-1.00| 0.00| 0.80| 0.00|
-------------------------
|-1.00|-1.00| 0.62| 0.46|
-------------------------
Round 27, change: 0.0008812955604516892
|  D  |  R  |  R  |  X  |
-------------------------
|  U  |  X  |  U  |  X  |
-------------------------
|  R  |  L  |  U  |  L  |
-------------------------
|-1.00| 0.80| 1.00| 0.00|
-------------------------
|-1.00| 0.00| 0.80| 0.00|
-------------------------
|-1.00|-1.00| 0.62| 0.46|
----------

In [192]:
policy_iteration_2(grid=gw.standard_grid())

0.9718563444935845
0.3696621062511587
0.33269589562604285
0.0
| 0.94| 1.05| 1.17| 0.18|
-------------------------
| 0.85| 0.00| 1.05| 0.65|
-------------------------
| 0.76| 0.85| 0.94| 0.85|
-------------------------
|  R  |  R  |  R  |  X  |
-------------------------
|  U  |  X  |  U  |  X  |
-------------------------
|  U  |  R  |  U  |  L  |
-------------------------
0.0
| 0.94| 1.05| 1.17| 0.18|
-------------------------
| 0.85| 0.00| 1.05| 0.65|
-------------------------
| 0.76| 0.85| 0.94| 0.85|
-------------------------
|  R  |  R  |  R  |  X  |
-------------------------
|  U  |  X  |  U  |  X  |
-------------------------
|  U  |  R  |  U  |  L  |
-------------------------


In [193]:
policy_iteration_2(grid=gw.negative_grid())

0.9260149607542584
0.6192802318343924
0.5573522086509533
0.0
| 0.95| 1.16| 1.40| 0.45|
-------------------------
| 0.75| 0.00| 1.16| 0.36|
-------------------------
| 0.58| 0.75| 0.95| 0.75|
-------------------------
|  R  |  R  |  R  |  X  |
-------------------------
|  U  |  X  |  U  |  X  |
-------------------------
|  U  |  R  |  U  |  L  |
-------------------------
0.0
| 0.95| 1.16| 1.40| 0.45|
-------------------------
| 0.75| 0.00| 1.16| 0.36|
-------------------------
| 0.58| 0.75| 0.95| 0.75|
-------------------------
|  R  |  R  |  R  |  X  |
-------------------------
|  U  |  X  |  U  |  X  |
-------------------------
|  U  |  R  |  U  |  L  |
-------------------------
