In [19]:
import numpy as np
from GridWorld import standard_grid, negative_grid,print_values, print_policy

SMALL_ENOUGH = 1e-3
GAMMA = 0.9
ALL_POSSIBLE_ACTIONS = ('U', 'D', 'L', 'R')
# this grid gives you a reward of -0.1
# to find a shorter path to the goal, use negative grid
grid = negative_grid()
print(grid.i, grid.j)
print("Recompenzas:")
print_values(grid.rewards, grid)

Recompenzas:
0 0
---------------------------
-1.00|-1.00|-1.00|-1.00|
---------------------------
-1.00|-1.00|-1.00|-1.00|
---------------------------
-1.00|-1.00|-1.00|-1.00|
---------------------------
-1.00|-1.00|-1.00|-1.00|


In [20]:
# state -> action
# choose an action and update randomly 
policy = {}
for s in grid.actions.keys():
    policy[s] = np.random.choice(ALL_POSSIBLE_ACTIONS)

In [21]:
# initial policy
print("Política inicial:")
print_policy(policy, grid)

Política inicial:
---------------------------
  R  |  L  |  R  |  U  |
---------------------------
  D  |  L  |  L  |  R  |
---------------------------
  L  |  D  |  D  |  D  |
---------------------------
  U  |  R  |  U  |     |


In [22]:
# initialize V(s) - value function
V = {}
states = grid.all_states()
for s in states:
    print('s',s)
    # V[s] = 0
    if s in grid.actions:
        V[s] = np.random.random()
    else:
        # terminal state
        V[s] = 0

# initial value for all states in grid
print_values(V, grid)

s (0, 1)
s (1, 2)
s (2, 1)
s (3, 1)
s (0, 2)
s (2, 2)
s (1, 0)
s (3, 2)
s (1, 3)
s (0, 0)
s (1, 1)
s (0, 3)
s (2, 0)
s (3, 0)
s (2, 3)
s (3, 3)
---------------------------
 0.07| 0.16| 0.05| 0.19|
---------------------------
 0.44| 0.59| 0.80| 0.46|
---------------------------
 0.60| 0.98| 0.16| 0.70|
---------------------------
 0.33| 1.00| 0.29| 0.00|


In [23]:

iteration=0
# repeat until convergence
# when policy does not change, it will finish
while True:
    iteration+=1
    print("Valores %d: " % iteration)
    print_values(V, grid)
    print("Política %d: " % iteration)
    print_policy(policy, grid)

    # policy evaluation step
    while True:
        biggest_change = 0
        for s in states:
              old_v = V[s]

              # V(s) only has value if it's not a terminal state
              if s in policy:
                a = policy[s]
                grid.set_state(s)
                r = grid.move(a) #reward
                V[s] = r + GAMMA * V[grid.current_state()]
                biggest_change = max(biggest_change, np.abs(old_v - V[s]))

        if biggest_change < SMALL_ENOUGH:
              break

    # policy improvement step
    is_policy_converged = True
    for s in states:
        if s in policy:
            old_a = policy[s]
            new_a = None
            best_value = float('-inf')
            # loop through all possible actions to find the best current action
            for a in ALL_POSSIBLE_ACTIONS:
                grid.set_state(s)
                r = grid.move(a)
                v = r + GAMMA * V[grid.current_state()]
                if v > best_value:
                    best_value = v
                    new_a = a
            policy[s] = new_a
            if new_a != old_a:
                is_policy_converged = False

    if is_policy_converged:
        break

Valores 1: 
---------------------------
 0.07| 0.16| 0.05| 0.19|
---------------------------
 0.44| 0.59| 0.80| 0.46|
---------------------------
 0.60| 0.98| 0.16| 0.70|
---------------------------
 0.33| 1.00| 0.29| 0.00|
Política 1: 
---------------------------
  R  |  L  |  R  |  U  |
---------------------------
  D  |  L  |  L  |  R  |
---------------------------
  L  |  D  |  D  |  D  |
---------------------------
  U  |  R  |  U  |     |
Valores 2: 
---------------------------
-10.00|-10.00|-9.99|-9.99|
---------------------------
-9.99|-9.99|-9.99|-9.99|
---------------------------
-9.99|-9.99|-9.99|-9.99|
---------------------------
-9.99|-9.99|-9.99| 0.00|
Política 2: 
---------------------------
  R  |  R  |  U  |  D  |
---------------------------
  U  |  D  |  R  |  D  |
---------------------------
  R  |  D  |  L  |  D  |
---------------------------
  U  |  U  |  U  |     |
Valores 3: 
---------------------------
-9.99|-9.99|-9.99|-9.99|
---------------------------
-9.99|-

In [24]:
print("Valores finales:")
print_values(V, grid)
print("Política final:")
print_policy(policy, grid)

Valores finales:
---------------------------
-10.00|-10.00|-10.00|-10.00|
---------------------------
-10.00|-10.00|-10.00|-10.00|
---------------------------
-10.00|-10.00|-10.00|-10.00|
---------------------------
-10.00|-10.00|-10.00| 0.00|
Política final:
---------------------------
  R  |  U  |  U  |  L  |
---------------------------
  U  |  D  |  U  |  L  |
---------------------------
  R  |  D  |  L  |  L  |
---------------------------
  R  |  U  |  U  |     |
