In [2]:
import numpy as np

In [16]:
# The grid gets "windy": actions are no longer deterministic. Each
# (state, action) pair maps to a probability distribution over next states.
class WindyGrid:
    """Grid world whose transitions are sampled from per-(state, action) probabilities.

    States are ``(row, col)`` tuples. A state is terminal when it has no entry
    in ``actions``.
    """

    def __init__(self, rows, cols, start):
        """Create a ``rows`` x ``cols`` grid positioned at ``start``.

        The reward/action/transition tables stay empty until ``initialize``
        is called.
        """
        self.rows = rows
        self.cols = cols
        self.pos = start
        # Filled in by initialize().
        self.rewards = {}
        self.actions = {}
        self.probs = {}

    def initialize(self, rewards, actions, probs):
        """Install the environment's tables.

        Args:
            rewards: dict ``(row, col) -> reward`` received on landing in that cell.
            actions: dict ``(row, col) -> sequence of allowed actions``; states
                missing from it are terminal.
            probs: dict ``((row, col), action) -> {next_state: probability}``,
                e.g. ``((1, 2), 'U'): {(0, 2): 0.5, (1, 3): 0.5}``.
        """
        self.rewards = rewards
        self.actions = actions
        self.probs = probs

    def set_position(self, p):
        """Teleport the agent to state ``p``."""
        self.pos = p

    def get_position(self):
        """Return the agent's current state."""
        return self.pos

    def is_terminal(self, p):
        """Return True if state ``p`` has no allowed actions."""
        return p not in self.actions

    def move(self, action):
        """Take ``action`` from the current state and return the reward.

        The next state is sampled from ``probs[(pos, action)]``. The agent
        moves there, and the reward for that state is returned (0 if it has
        none).

        Raises:
            KeyError: if ``(pos, action)`` has no entry in ``probs``.
        """
        next_state_probs = self.probs[(self.pos, action)]
        next_states = list(next_state_probs)
        next_probs = list(next_state_probs.values())
        # np.random.choice only samples from a 1-D population. A list of
        # (row, col) tuples becomes a 2-D array and raises ValueError, so we
        # sample an index and look the state up instead.
        idx = np.random.choice(len(next_states), p=next_probs)
        s2 = next_states[idx]

        self.pos = s2
        return self.rewards.get(s2, 0)

    def game_over(self):
        """Return True if the agent currently sits in a terminal state."""
        return self.is_terminal(self.pos)

    def all_states(self):
        """Return the set of every state that has actions or a reward."""
        return set(self.actions.keys()) | set(self.rewards.keys())

def get_windy_grid():
    """Build the 3x4 windy grid world used by the policy-evaluation cell.

    The agent starts at (2, 0). Cell (0, 3) pays +1 and (1, 3) pays -1; both
    have no actions and are therefore terminal. Cell (1, 1) is a wall.

    Every transition is deterministic except 'U' from (1, 2), which lands on
    (0, 2) or (1, 3) with probability 0.5 each. A move that would leave the
    board or enter the wall keeps the agent where it is.
    """
    grid = WindyGrid(3, 4, (2, 0))

    terminal_rewards = {(0, 3): 1.0, (1, 3): -1.0}
    allowed_actions = {
        (0, 0): ('D', 'R'),
        (0, 1): ('L', 'R'),
        (0, 2): ('L', 'D', 'R'),
        (1, 0): ('U', 'D'),
        (1, 2): ('U', 'D', 'R'),
        (2, 0): ('U', 'R'),
        (2, 1): ('L', 'R'),
        (2, 2): ('L', 'R', 'U'),
        (2, 3): ('L', 'U')
    }

    wall = (1, 1)
    offsets = {'U': (-1, 0), 'D': (1, 0), 'L': (0, -1), 'R': (0, 1)}

    def landing_cell(cell, action):
        """Deterministic destination of ``action`` from ``cell``; stay put when blocked."""
        d_row, d_col = offsets[action]
        row, col = cell[0] + d_row, cell[1] + d_col
        on_board = 0 <= row < grid.rows and 0 <= col < grid.cols
        if on_board and (row, col) != wall:
            return (row, col)
        return cell

    # The only stochastic transition: wind may push an upward move from (1, 2)
    # into the -1 terminal at (1, 3).
    gusts = {((1, 2), 'U'): {(0, 2): 0.5, (1, 3): 0.5}}

    # Non-terminal states are enumerated in a fixed order, each with actions
    # U, D, L, R, so the transition dict has a stable iteration order.
    non_terminal_order = (
        (2, 0), (1, 0), (0, 0), (0, 1), (0, 2), (2, 1), (2, 2), (2, 3), (1, 2),
    )
    transition_table = {}
    for cell in non_terminal_order:
        for action in ('U', 'D', 'L', 'R'):
            key = (cell, action)
            if key in gusts:
                transition_table[key] = gusts[key]
            else:
                transition_table[key] = {landing_cell(cell, action): 1.0}

    grid.initialize(terminal_rewards, allowed_actions, transition_table)
    return grid

In [25]:
# Define helper functions to print value and policy
def printValues(values, g):
    """Print a table of state values, one grid row per line.

    values: dict mapping (row, col) -> state value; cells missing from it
        are shown as 0.
    g: grid object supplying ``rows`` and ``cols``.
    """
    for row in range(g.rows):
        print("-------------------------")
        cells = []
        for col in range(g.cols):
            value = values.get((row, col), 0)
            # Non-negative numbers get a leading space so they line up with
            # negative numbers, whose '-' sign takes that column.
            fmt = " %.2f|" if value >= 0 else "%.2f|"
            cells.append(fmt % value)
        print("".join(cells))
    

def printPolicy(policy, g):
    """Print the policy as a grid, one row per line.

    policy: dict mapping (row, col) -> what to show for that cell (an action
        letter or an action-probability dict); missing cells show a blank.
    g: grid object supplying ``rows`` and ``cols``.
    """
    for row in range(g.rows):
        print("-------------------------")
        row_text = "".join(
            " %s |" % policy.get((row, col), ' ') for col in range(g.cols)
        )
        print(row_text)

# Flatten the grid's transition model into two lookups keyed by
# (state, action, next_state): the transition probability, and the reward
# received on landing in next_state (0 for cells without a reward).
transition_probs = {}
rewards = {}

ACTIONS = ['U', 'D', 'L', 'R']

grid = get_windy_grid()
for (s, a), next_state_probs in grid.probs.items():
    for s2, p in next_state_probs.items():
        transition_probs[(s, a, s2)] = p
        rewards[(s, a, s2)] = grid.rewards.get(s2, 0)

# Stop sweeping once no state value changes by at least this much.
THRESHOLD = 1e-3

# Probabilistic policy: state -> {action: probability of taking it}.
policy = {
    (2, 0): {'U': 0.5, 'R': 0.5},
    (1, 0): {'U': 1.0},
    (0, 0): {'R': 1.0},
    (0, 1): {'R': 1.0},
    (0, 2): {'R': 1.0},
    (1, 2): {'U': 1.0},
    (2, 1): {'R': 1.0},
    (2, 2): {'U': 1.0},
    (2, 3): {'L': 1.0}
}
printPolicy(policy, grid)

# The state set does not change during evaluation, so compute it once.
states = grid.all_states()
V = {s: 0 for s in states}
gamma = 0.9  # discount factor

# Iterative policy evaluation. V is updated in place, so states visited later
# in a sweep already see this sweep's new values for earlier states.
it = 0
while True:
    biggest_change = 0
    for s in states:
        if grid.is_terminal(s):
            continue  # terminal states are never updated and keep V = 0
        old_v = V[s]
        new_v = 0
        # Bellman expectation backup. Only (action, next_state) pairs with
        # nonzero policy and transition probability contribute, so visit just
        # those instead of scanning every action x every state.
        for a, action_prob in policy[s].items():
            for s2, p in grid.probs.get((s, a), {}).items():
                r = rewards[(s, a, s2)]
                new_v += action_prob * p * (r + gamma * V[s2])

        V[s] = new_v
        biggest_change = max(biggest_change, abs(old_v - new_v))

    print(f"iteration: {it}, biggest change: {biggest_change}")
    printValues(V, grid)
    it += 1

    if biggest_change < THRESHOLD:
        break
print("\n\n")

-------------------------
 {'R': 1.0} | {'R': 1.0} | {'R': 1.0} |   |
-------------------------
 {'U': 1.0} |   | {'U': 1.0} |   |
-------------------------
 {'U': 0.5, 'R': 0.5} | {'R': 1.0} | {'U': 1.0} | {'L': 1.0} |
iteration: 0, biggest change: 1.0
-------------------------
 0.00| 0.00| 1.00| 0.00|
-------------------------
 0.00| 0.00|-0.50| 0.00|
-------------------------
 0.00| 0.00|-0.45| 0.00|
iteration: 1, biggest change: 0.9
-------------------------
 0.81| 0.90| 1.00| 0.00|
-------------------------
 0.73| 0.00|-0.05| 0.00|
-------------------------
-0.18|-0.41|-0.04|-0.41|
iteration: 2, biggest change: 0.4920750000000001
-------------------------
 0.81| 0.90| 1.00| 0.00|
-------------------------
 0.73| 0.00|-0.05| 0.00|
-------------------------
 0.31|-0.04|-0.04|-0.04|
iteration: 3, biggest change: 0
-------------------------
 0.81| 0.90| 1.00| 0.00|
-------------------------
 0.73| 0.00|-0.05| 0.00|
-------------------------
 0.31|-0.04|-0.04|-0.04|



