In [5]:
import numpy as np
import matplotlib.pyplot as plt

In [6]:
class GridWorld():
    """A stochastic grid world with walls and reward-bearing terminal cells.

    Cells are (x, y) tuples with 1 <= x <= width and 1 <= y <= height.
    "N" increases y, "S" decreases y, "E" increases x, "W" decreases x.
    The intended move is executed with probability 0.8 and each of the two
    perpendicular moves with probability 0.1. A move that would leave the
    grid or enter a wall leaves the agent where it is.

    The defaults reproduce the 4x3 layout: a wall at (2, 2), a +1 terminal
    at (4, 3) and a -1 terminal at (4, 2).
    """

    # (dx, dy) displacement produced by each compass action.
    _MOVES = {"N": (0, 1), "S": (0, -1), "E": (1, 0), "W": (-1, 0)}

    def __init__(self, width=4, height=3, walls=((2, 2),), terminal_rewards=None):
        """Build the grid.

        Args:
            width: number of columns; x runs from 1 to width.
            height: number of rows; y runs from 1 to height.
            walls: cells that cannot be entered. They remain in
                state_space but have no successor states.
            terminal_rewards: mapping of terminal cell -> reward received in
                that cell. Defaults to {(4, 3): 1, (4, 2): -1}.
        """
        if terminal_rewards is None:
            terminal_rewards = {(4, 3): 1, (4, 2): -1}
        self.width = width
        self.height = height
        self.walls = set(walls)
        self.terminal_rewards = dict(terminal_rewards)
        # x is the outer loop, so states are ordered column by column.
        self.state_space = [(x, y) for x in range(1, width + 1)
                                for y in range(1, height + 1)]
        self.action_space = ["N", "S", "E", "W"]
        # action_probs[intended][executed]: probability that `executed` actually
        # happens when `intended` is chosen. Insertion order fixes the order
        # of successors returned by get_next_states_probs.
        self.action_probs = {"N": {"N": 0.8, "E": 0.1, "W": 0.1}, "S": {"S": 0.8, "E": 0.1, "W": 0.1},
                             "E": {"E": 0.8, "N": 0.1, "S": 0.1}, "W": {"W": 0.8, "N": 0.1, "S": 0.1}}

    def _is_open(self, cell):
        """Return True if `cell` lies inside the grid and is not a wall."""
        x, y = cell
        return (1 <= x <= self.width and 1 <= y <= self.height
                and cell not in self.walls)

    def get_next_state(self, state, action_prob, next_s, prob_next_s, a):
        """Append the outcome of executing action `a` from `state`.

        Args:
            state: current (x, y) cell.
            action_prob: probability that `a` is the action actually executed.
            next_s: list of successor states; appended to in place.
            prob_next_s: list of successor probabilities; appended to in place.
            a: executed action ("N", "S", "E" or "W"). Any other value, or a
                move off the grid / into a wall, keeps the agent at `state`.

        Returns:
            (next_s, prob_next_s): the same list objects that were passed in.
        """
        target = state
        move = self._MOVES.get(a)
        if move is not None:
            candidate = (state[0] + move[0], state[1] + move[1])
            if self._is_open(candidate):
                target = candidate
        next_s.append(target)
        prob_next_s.append(action_prob)
        return next_s, prob_next_s

    def get_next_states_probs(self, state, action):
        """Return the possible successors of `state` under `action`.

        Args:
            state: current (x, y) cell.
            action: intended action ("N", "S", "E" or "W").

        Returns:
            (next_states, probabilities) as two parallel lists, with one entry
            per possible executed action (a cell may appear more than once).
            A terminal cell maps to itself with probability 1; a wall has no
            successors, so both lists are empty.
        """
        if state in self.terminal_rewards:
            return [state], [1]
        if state in self.walls:
            return [], []
        next_s, prob_next_s = [], []
        for executed, prob in self.action_probs[action].items():
            self.get_next_state(state, prob, next_s, prob_next_s, executed)
        return next_s, prob_next_s

    def get_transition_reward(self, reward, state):
        """Return the reward collected in `state`.

        Terminal cells return their configured reward. Every other cell
        returns `reward` (the per-step "alive" reward, of either sign).
        """
        return self.terminal_rewards.get(state, reward)

In [77]:
# Value iteration for the grid world: computes optimal state values and Q-values.
def _absorbing_states(env):
    """Return the states that every action maps back to themselves with probability 1."""
    return {
        s for s in env.state_space
        if all(env.get_next_states_probs(s, a) == ([s], [1]) for a in env.action_space)
    }


def _print_values(v, action_values):
    """Print state values, then per-state action values, rounded to 2 decimals."""
    print("State Values:")
    for k in v.keys():
        print(k, ':', round(v[k], 2))
    print("")
    print("Action Values:")
    for state, q_by_action in action_values.items():
        cell_values = str(state) + " "
        for a, q in q_by_action.items():
            cell_values += str(a) + ":" + str(round(q, 2)) + " "
        print(cell_values)
        print("")


def find_optimal_policy(env, gamma, alive_reward, num_iter):
    """Run value iteration on `env`, print the results and return them.

    Args:
        env: environment exposing `state_space`, `action_space`,
            `get_next_states_probs(state, action)` and
            `get_transition_reward(reward, state)` (e.g. GridWorld).
        gamma: discount factor applied to successor values.
        alive_reward: reward passed to `env.get_transition_reward` for the
            current state on every backup.
        num_iter: number of synchronous Bellman backups to perform.

    Returns:
        (v, action_values): `v` maps state -> value and `action_values` maps
        state -> {action: Q-value}, both from the final sweep. The optimal
        (greedy) action in a state is the argmax of `action_values[state]`.

    Terminal states are recognised structurally, as absorbing states whose
    every action returns to themselves with probability 1. Their value is
    just the reward collected there, with no discounted future. Detecting
    terminals by value (v == +/-1) would wrongly freeze any ordinary state
    whose value happens to equal exactly +/-1.0. For example, with
    alive_reward=-1 every non-terminal state is -1.0 after the first sweep.
    States with no successors (walls) get Q = 0 for every action.
    """
    terminal_states = _absorbing_states(env)
    v = {s: 0 for s in env.state_space}
    action_values = {s: {a: 0 for a in env.action_space} for s in env.state_space}
    for _ in range(num_iter):
        v_prime = {}
        for s in env.state_space:
            # The reward depends only on the current state, not the successor.
            r = env.get_transition_reward(alive_reward, s)
            is_terminal = s in terminal_states
            q_max = float('-inf')
            for a in env.action_space:
                q_sa = 0
                next_s, prob_next_s = env.get_next_states_probs(s, a)
                for s_prime, p in zip(next_s, prob_next_s):
                    # Episodes end in a terminal state: no future value is added.
                    future = 0.0 if is_terminal else gamma * v[s_prime]
                    q_sa += p * (r + future)
                action_values[s][a] = q_sa
                q_max = max(q_max, q_sa)
            v_prime[s] = q_max
        v = v_prime
    _print_values(v, action_values)
    return v, action_values

In [78]:
# Experiment settings: discount factor, per-step "alive" reward and the
# number of value-iteration sweeps.
GAMMA = 0.9
ALIVE_REWARD = 0
NUM_ITERATIONS = 100

grid = GridWorld()
# Trailing ';' keeps Jupyter from echoing any return value below the printout.
find_optimal_policy(grid, GAMMA, ALIVE_REWARD, NUM_ITERATIONS);

State Values:
(1, 1) : 0.49
(1, 2) : 0.57
(1, 3) : 0.64
(2, 1) : 0.43
(2, 2) : 0
(2, 3) : 0.74
(3, 1) : 0.48
(3, 2) : 0.57
(3, 3) : 0.85
(4, 1) : 0.28
(4, 2) : -1.0
(4, 3) : 1.0

Action Values:
(1, 1) N:0.49 S:0.44 E:0.41 W:0.45 

(1, 2) N:0.57 S:0.46 E:0.51 W:0.51 

(1, 3) N:0.59 S:0.53 E:0.64 W:0.57 

(2, 1) N:0.4 S:0.4 E:0.42 W:0.43 

(2, 2) N:0 S:0 E:0 W:0 

(2, 3) N:0.67 S:0.67 E:0.74 W:0.6 

(3, 1) N:0.48 S:0.41 E:0.29 W:0.4 

(3, 2) N:0.57 S:0.3 E:-0.6 W:0.53 

(3, 3) N:0.77 S:0.57 E:0.85 W:0.66 

(4, 1) N:-0.65 S:0.27 E:0.13 W:0.28 

(4, 2) N:-1.0 S:-1.0 E:-1.0 W:-1.0 

(4, 3) N:1.0 S:1.0 E:1.0 W:1.0 

