In [1]:
import numpy as np
import matplotlib.pyplot as plt

In [2]:
ACTION_SPACE = ('U', 'D', 'L', 'R')
class Grid:  # Environment

    def __init__(self,rows,cols,start):
        self.rows = rows
        self.cols = cols
        self.i = start[0]
        self.j = start[1]

    def set(self, rewards, actions):

        self.rewards = rewards
        self.actions = actions

    def set_state(self, s):
        self.i = s[0]
        self.j = s[1]

    def current_state(self):
        return (self.i, self.j)

    def is_terminal(self, s):
        return s not in self.actions
    
    def move(self, action):
    # check if legal move first
        if action in self.actions[(self.i, self.j)]:
            if action == 'U':
                self.i -= 1
            elif action == 'D':
                self.i += 1
            elif action == 'R':
                self.j += 1
            elif action == 'L':
                self.j -= 1
        return self.rewards.get((self.i, self.j), 0)

    def game_over(self):
        # returns true if game is over, else false
        # true if we are in a state where no actions are possible
        return (self.i, self.j) not in self.actions

    def all_states(self):
        return set(self.actions.keys()) | set(self.rewards.keys())


In [3]:
def standard_grid():

  # .  .  .  1
  # .  x  . -1
  # s  .  .  .

    grid = Grid(3, 4, (2, 0))
    rewards = {(0, 3): 1, (1, 3): -1}
    actions = {
        (0, 0): ('D', 'R'),
        (0, 1): ('L', 'R'),
        (0, 2): ('L', 'D', 'R'),
        (1, 0): ('U', 'D'),
        (1, 2): ('U', 'D', 'R'),
        (2, 0): ('U', 'R'),
        (2, 1): ('L', 'R'),
        (2, 2): ('L', 'R', 'U'),
        (2, 3): ('L', 'U'),
        }
    grid.set(rewards, actions)
    return grid

In [4]:
def print_values(V, g):
    for i in range(g.rows):
        print("---------------------------")
        for j in range(g.cols):
            v = V.get((i, j), 0)
            if v >= 0:
                print(" %.2f|" % v, end="")
            else:
                print("%.2f|" % v, end="")  
        print("")


def print_policy(P, g):
    for i in range(g.rows):
        print("---------------------------")
        for j in range(g.cols):
            a = P.get((i, j), " ")
            print("  %s  |" % a, end="")
        print("")


In [5]:
SMALL_ENOUGH = 1e-3
GAMMA = 0.9
ALPHA = 0.1
ALL_POSSIBLE_ACTIONS = ('U', 'D', 'L', 'R')

In [6]:
def take_action(policy,state, eps=0.1):
    a = policy[state]
    if np.random.random() < (1 - eps):
        return a
    else:
        return np.random.choice(ALL_POSSIBLE_ACTIONS)

In [7]:
def play_game(grid, policy):
    state = (2, 0)
    grid.set_state(state)
    states_and_rewards = [(state, 0)] # list of tuples of (state, reward)
    while not grid.game_over():
        a = take_action(policy,state,0.1)
        reward = grid.move(a)
        next_state = grid.current_state()
        states_and_rewards.append((next_state, reward))
        
    return states_and_rewards


In [8]:
grid = standard_grid()

# print rewards
print("rewards:")
print_values(grid.rewards, grid)

# state -> action
policy = {
(2, 0): 'U',
(1, 0): 'U',
(0, 0): 'R',
(0, 1): 'R',
(0, 2): 'R',
(1, 2): 'R',
(2, 1): 'R',
(2, 2): 'R',
(2, 3): 'U',
}

# initialize V(s) and returns
V = {}
states = grid.all_states()
for s in states:
    V[s] = 0

# repeat until convergence
for index in range(20000):
    # generate an episode using pi
    states_and_rewards = play_game(grid, policy)
    for t in range(len(states_and_rewards) - 1):
        state, _ = states_and_rewards[t]
        next_state, reward = states_and_rewards[t+1]
        # we will update V(s) AS we experience the episode
        V[state] = V[state] + ALPHA*(reward + GAMMA*V[next_state] - V[state])

print("values:")
print_values(V, grid)
print("policy:")
print_policy(policy, grid)

rewards:
---------------------------
 0.00| 0.00| 0.00| 1.00|
---------------------------
 0.00| 0.00| 0.00|-1.00|
---------------------------
 0.00| 0.00| 0.00| 0.00|
values:
---------------------------
 0.01| 0.05| 0.32| 0.00|
---------------------------
 0.01| 0.00| 0.18| 0.00|
---------------------------
 0.01| 0.03| 0.09|-0.83|
policy:
---------------------------
  R  |  R  |  R  |     |
---------------------------
  U  |     |  R  |     |
---------------------------
  U  |  R  |  R  |  U  |
