In [33]:
import numpy as np
from grid_world import standard_grid, ACTION_SPACE

In [34]:
# Convergence threshold: the value-iteration loops below stop once the largest
# change to any V[s] within a single sweep drops below this value.
SMALL_ENOUGH = 1e-3
# Discount factor multiplied into the next state's value in every backup.
GAMMA = 0.9

In [35]:
def print_values(V, g):
  """Print V as a text table with g.rows rows and g.cols columns.

  V is looked up with (row, col) tuples; a cell that has no entry in V is
  shown as 0. Every value is rendered with two decimals followed by '|'.
  """
  divider = "---------------------------"
  for row in range(g.rows):
    print(divider)
    for col in range(g.cols):
      value = V.get((row, col), 0)
      # Non-negative numbers get a leading space where a minus sign would
      # otherwise go, so that columns line up.
      cell_format = " %.2f|" if value >= 0 else "%.2f|"
      print(cell_format % value, end="")
    print("")


def print_policy(P, g):
  """Print the policy P as a text table with g.rows rows and g.cols columns.

  P is looked up with (row, col) tuples; a cell that has no entry in P is
  shown as a single blank.
  """
  divider = "---------------------------"
  blank = ' '
  for row in range(g.rows):
    print(divider)
    for col in range(g.cols):
      action = P.get((row, col), blank)
      print("  %s  |" % action, end="")
    print("")


In [36]:
def get_transition_probs_and_rewards(grid):
  """Tabulate the grid's dynamics as two dictionaries keyed by (s, a, s2).

  Returns:
    (transition_probs, rewards) where
      transition_probs[(s, a, s2)] is p(s2 | s, a); a key that is absent is
        to be treated as probability 0. Every entry written here is 1, i.e.
        each (s, a) pair gets exactly one successor from
        grid.get_next_state.
      rewards[(s, a, s2)] is the reward for that transition; an entry is
        only stored when s2 appears in grid.rewards. The stored reward is
        looked up by s2 alone, so it does not actually vary with (s, a).
  """
  transition_probs = {}
  rewards = {}

  for i in range(grid.rows):
    for j in range(grid.cols):
      s = (i, j)
      # nothing to tabulate for states the grid reports as terminal
      if grid.is_terminal(s):
        continue
      for a in ACTION_SPACE:
        s2 = grid.get_next_state(s, a)
        key = (s, a, s2)
        transition_probs[key] = 1
        if s2 in grid.rewards:
          rewards[key] = grid.rewards[s2]

  return transition_probs, rewards

In [37]:
def greedy_policy_from_values(V, grid, transition_probs, rewards, gamma=GAMMA):
  """Return the policy that is greedy with respect to the value table V.

  For every state in grid.actions and every action in ACTION_SPACE the
  one-step lookahead value
      sum over s2 of p(s2 | s, a) * (r(s, a, s2) + gamma * V[s2])
  is computed; missing entries in transition_probs / rewards count as 0.
  The action with the largest lookahead value is chosen, and on ties the
  action that comes first in ACTION_SPACE wins.

  Returns:
    dict mapping each state in grid.actions to its chosen action.
  """
  policy = {}
  for s in grid.actions.keys():
    best_a, best_v = None, float('-inf')
    for a in ACTION_SPACE:
      lookahead = 0.0
      for s2 in grid.all_states():
        key = (s, a, s2)
        r = rewards.get(key, 0)
        p = transition_probs.get(key, 0)
        lookahead += p * (r + gamma * V[s2])
      # strict '>' keeps the earliest action when several are equally good
      if lookahead > best_v:
        best_a, best_v = a, lookahead
    policy[s] = best_a
  return policy


In [38]:
# Build the environment and tabulate its dynamics.
grid = standard_grid()
transition_probs, rewards = get_transition_probs_and_rewards(grid)

# Show the reward attached to each cell.
print("rewards:")
print_values(grid.rewards, grid)

# Value table: every state starts at 0.
states = grid.all_states()
V = dict.fromkeys(states, 0)

# Sweep counter; 0 denotes the initial, not-yet-updated table.
it = 0

# Report the starting point: the all-zero values and the policy that is
# greedy with respect to them.
print(f"\n=== Iteration {it} ===")
print("values:")
print_values(V, grid)
print("policy (greedy w.r.t. current V):")
initial_policy = greedy_policy_from_values(V, grid, transition_probs, rewards)
print_policy(initial_policy, grid)

rewards:
---------------------------
 0.00| 0.00| 0.00| 1.00|
---------------------------
 0.00| 0.00| 0.00|-1.00|
---------------------------
 0.00| 0.00| 0.00| 0.00|

=== Iteration 0 ===
values:
---------------------------
 0.00| 0.00| 0.00| 0.00|
---------------------------
 0.00| 0.00| 0.00| 0.00|
---------------------------
 0.00| 0.00| 0.00| 0.00|
policy (greedy w.r.t. current V):
---------------------------
  U  |  U  |  R  |     |
---------------------------
  U  |     |  U  |     |
---------------------------
  U  |  U  |  U  |  D  |


In [39]:
############################### Implementation 1 ################################
# Value iteration that records the maximising action while it updates V.
# Updates are in place: a V[s] written early in a sweep is already visible to
# the backups of states visited later in the same sweep.
it = 0
policy = {}

converged = False
while not converged:
    biggest_change = 0.0

    for s in grid.all_states():
        if not grid.is_terminal(s):
            old_v = V[s]
            best_a, best_val = None, float('-inf')

            # every action in ACTION_SPACE is tried in every state; if actions
            # are state-dependent, iterate over grid.actions.get(s, []) instead
            for a in ACTION_SPACE:
                v = 0.0
                for s2 in grid.all_states():
                    key = (s, a, s2)
                    r = rewards.get(key, 0.0)
                    p = transition_probs.get(key, 0.0)
                    v += p * (r + GAMMA * V[s2])

                # strict '>' keeps the first action on ties
                if v > best_val:
                    best_a, best_val = a, v

            V[s] = best_val
            policy[s] = best_a
            delta = abs(old_v - best_val)
            if delta > biggest_change:
                biggest_change = delta

    it += 1
    converged = biggest_change < SMALL_ENOUGH

# the point of this cell is to check that the result matches policy iteration
print("values:")
print_values(V, grid)
print("policy:")
print_policy(policy, grid)
# NOTE: policy[s] is the argmax recorded during the final sweep, i.e. with
# respect to V as it stood when s was visited. No extra pass is made here to
# recompute the policy from the final V.


values:
---------------------------
 0.81| 0.90| 1.00| 0.00|
---------------------------
 0.73| 0.00| 0.90| 0.00|
---------------------------
 0.66| 0.73| 0.81| 0.73|
policy:
---------------------------
  R  |  R  |  R  |     |
---------------------------
  U  |     |  U  |     |
---------------------------
  U  |  R  |  U  |  L  |


In [40]:
############################### Implementation 2 ################################
# Value iteration first, then a separate pass that extracts the greedy policy.
# Repeat until convergence:
#   V[s] = max[a]{ sum[s',r] { p(s',r|s,a)[r + gamma*V[s']] } }

# Start from a fresh all-zero value table. V is shared notebook state that an
# earlier cell has already driven to convergence; without this reset the loop
# below would stop after a single sweep and this cell would not independently
# verify anything. The reset also makes the cell give the same run every time
# it is executed.
V = {s: 0 for s in grid.all_states()}

it = 0  # number of completed sweeps
while True:
    biggest_change = 0
    for s in grid.all_states():
        # terminal states keep their value of 0 and are never backed up
        if grid.is_terminal(s):
            continue

        old_v = V[s]
        new_v = float('-inf')

        for a in ACTION_SPACE:
            v = 0
            for s2 in grid.all_states():
                # reward is a function of (s, a, s'), 0 if not specified
                r = rewards.get((s, a, s2), 0)
                v += transition_probs.get((s, a, s2), 0) * (r + GAMMA * V[s2])

            # keep v if it's better
            if v > new_v:
                new_v = v

        V[s] = new_v
        # built-in abs keeps biggest_change a plain Python number
        biggest_change = max(biggest_change, abs(old_v - V[s]))

    it += 1
    if biggest_change < SMALL_ENOUGH:
        break

# Find a policy that leads to the optimal value function. This reuses the
# helper defined earlier in the notebook rather than repeating the same
# argmax loop inline; it uses GAMMA as its default discount.
policy = greedy_policy_from_values(V, grid, transition_probs, rewards)

# our goal here is to verify that we get the same answer as with policy iteration
print("values:")
print_values(V, grid)
print("policy:")
print_policy(policy, grid)

values:
---------------------------
 0.81| 0.90| 1.00| 0.00|
---------------------------
 0.73| 0.00| 0.90| 0.00|
---------------------------
 0.66| 0.73| 0.81| 0.73|
policy:
---------------------------
  R  |  R  |  R  |     |
---------------------------
  U  |     |  U  |     |
---------------------------
  U  |  R  |  U  |  L  |
