In [1]:
import numpy as np
import random

In [2]:
class gridworld():
    MOVE_UP = 0
    MOVE_RIGHT = 1
    MOVE_DOWN = 2
    MOVE_LEFT = 3

    def __init__(self, blocked, rewards, size = (5,5), punishment = -0.04):
        super(gridworld, self).__init__()
        if size != blocked.shape or size != rewards.shape:
            raise ValueError(f'Size does not fit blocked and/or rewards')

        self.size = size
        self.blocked = blocked
        self.rewards = rewards
        self.punishment = punishment

        self.state = self.reset()
        self.q_values = np.zeros((4,self.size[0],self.size[1]),dtype=float)
        self.policy = np.array([[random.randint(0,3) for row in range(size[0])] for _ in range(size[1])])
        self.policy[blocked] = -1
        print(self.policy)

    def reset(self):
        self.state = (0,0)
        return self.state

    def step(self, action):
        action = self.apply_transition_uncertainty(action)

        new_state = self.make_transition(action)

        reward = self.punishment + self.rewards[new_state]

        terminal = self.is_terminal_state(new_state)

        return new_state, reward, terminal

    def visualize(self):
        for y in range(self.size[0]):
            echo = []
            for x in range(self.size[1]):
                if self.state[0] == y and self.state[1] == x:
                    echo.append("!")
                elif blocked[y][x]:
                    echo.append("X")
                else:
                    echo.append(str(rewards[y][x]))

            print(" ".join(echo))

        print(self.policy)
        print(self.q_values)

    def apply_transition_uncertainty(self, action):
        #there's a 10% chance to move left or right from the desired direction, respectively
        if random.uniform(0, 1) <= 0.1:
            action -= 1
        elif random.uniform(0, 1) >= 0.9:
            action += 1

        #correct not available actions
        if action < 0:
            action = self.MOVE_LEFT
        if action > 3:
            action = self.MOVE_UP

        return action

    def make_transition(self, action):
        new_state = list(self.state)

        #get new state coordinates depending on move direction
        if action == self.MOVE_UP:
            new_state[0] -= 1
        elif action == self.MOVE_RIGHT:
            new_state[1] += 1
        elif action == self.MOVE_DOWN:
            new_state[0] += 1
        elif action == self.MOVE_LEFT:
            new_state[1] -= 1

        #if we end up in a state not featured (possible because of uncertainty), the exception is raised and we return the initial state
        try:
            new_state = tuple(new_state)
            if not self.blocked[new_state] and new_state[0] >= 0 and new_state[1] >= 0:
                self.state = new_state
        except IndexError:
            return self.state


        return self.state

    def is_terminal_state(self, state):
        return self.rewards[state] == 1

In [3]:
def n_step_sarsa(world, steps, learning_rate, gamma, max_episodes):
    episode = 1
    while True:

        state = world.reset()

        states_count = 1
        while True:
            action = world.policy[state]

            next_action = action
            new_state, reward, terminal = world.step(next_action)

            td_error = 0
            for i in range(steps):
                td_error += (gamma ** i) * reward

                if terminal:
                    break

                next_action = world.policy[new_state]
                new_state, reward, terminal = world.step(next_action)

            world.q_values[action][state] = world.q_values[action][state] + learning_rate * (td_error + world.q_values[next_action][new_state] - world.q_values[action][state])

            #policy evaluation
            action_values = world.q_values[:,state[0],state[1]]
            greedy = np.argsort(action_values)[::-1][0]

            world.policy[state] = greedy

            state = new_state

            if terminal:
                break

            #for safety reasons
            states_count +=1
            if states_count > 10000:
                break

        episode += 1
        if episode > max_episodes:
            break

In [4]:
blocked = np.zeros((5,5),dtype=bool)

for y, x in ([(1,1),(1,2),(2,1),(3,1)]):
    blocked[y][x] = True

print(blocked)

[[False False False False False]
 [False  True  True False False]
 [False  True False False False]
 [False  True False False False]
 [False False False False False]]


In [5]:
rewards = np.zeros((5,5),dtype=int)

rewards[2][2] = 1
rewards[1][3] = -1
rewards[3][2] = -1

print(rewards)

[[ 0  0  0  0  0]
 [ 0  0  0 -1  0]
 [ 0  0  1  0  0]
 [ 0  0 -1  0  0]
 [ 0  0  0  0  0]]


In [6]:
world = gridworld(blocked, rewards, (5,5))

[[ 1  0  1  1  3]
 [ 0 -1 -1  0  3]
 [ 2 -1  1  2  2]
 [ 0 -1  2  3  2]
 [ 2  2  3  0  0]]


In [7]:
sarsa_steps = 4
learning_rate = 0.1
gamma = 0.9
max_episodes = 100000

n_step_sarsa(world, sarsa_steps, learning_rate, gamma, max_episodes)

In [8]:
world.visualize()

0 0 0 0 0
0 X X -1 0
0 X ! 0 0
0 X -1 0 0
0 0 0 0 0
[[ 1  1  1  1  2]
 [ 0 -1 -1  2  2]
 [ 2 -1  1  3  3]
 [ 2 -1  0  0  3]
 [ 1  1  1  0  3]]
[[[-0.78071006 -0.68139524 -0.63771043 -0.79005187 -0.55425271]
  [-0.47626265  0.          0.         -0.4738458  -0.50685141]
  [-0.71229394  0.          0.         -0.13982338 -0.28896361]
  [-0.68149464  0.          0.55719536  0.54308488 -0.18178869]
  [-0.56811521 -0.56219454 -0.30468983  0.13009025 -0.38998104]]

 [[-0.34848312 -0.32673055  0.12482777 -0.26532879 -0.544834  ]
  [-0.75754077  0.          0.         -0.32114744 -0.50346713]
  [-0.69489556  0.          0.         -0.08193906 -0.16479816]
  [-0.68008928  0.         -0.06498357 -0.1017536  -0.1521471 ]
  [-0.48617786 -0.39452752 -0.25114206 -0.33332346 -0.37127326]]

 [[-0.77688576 -0.68154873 -0.62850187 -0.83276451  0.25202026]
  [-0.75740817  0.          0.          0.35899635  0.38600861]
  [-0.69489005  0.          0.         -0.07102823 -0.29095304]
  [-0.63926276  0.   