In [12]:
import numpy as np
import random

In [5]:
blocked = np.zeros((5,5),dtype=bool)

for y, x in ([(1,1),(1,2),(2,1),(3,1)]):
    blocked[y][x] = True

print(blocked)

[[False False False False False]
 [False  True  True False False]
 [False  True False False False]
 [False  True False False False]
 [False False False False False]]


In [6]:
rewards = np.zeros((5,5),dtype=int)

rewards[2][2] = 1
rewards[1][3] = -1
rewards[3][2] = -1

print(rewards)

[[ 0  0  0  0  0]
 [ 0  0  0 -1  0]
 [ 0  0  1  0  0]
 [ 0  0 -1  0  0]
 [ 0  0  0  0  0]]


In [160]:
class gridworld():
    MOVE_UP = 0
    MOVE_RIGHT = 1
    MOVE_DOWN = 2
    MOVE_LEFT = 3

    def __init__(self, blocked, rewards, size = (5,5), punishment = -0.04):
        super(gridworld, self).__init__()
        if size != blocked.shape or size != rewards.shape:
            raise ValueError(f'Size does not fit blocked and/or rewards')

        self.size = size
        self.blocked = blocked
        self.rewards = rewards
        self.punishment = punishment

        self.state = self.reset()
        self.q_values = np.zeros((4,self.size[0],self.size[1]),dtype=float)
        self.policy = np.array([[random.randint(0,3) for row in range(size[0])] for _ in range(size[1])])

    def reset(self):
        self.state = (0,0)
        return self.state

    def step(self, action):
        action = self.apply_transition_uncertainty(action)

        new_state = self.make_transition(action)

        reward = self.punishment + self.rewards[new_state]

        terminal = self.is_terminal_state(new_state)

        return new_state, reward, terminal

    def visualize(self):
        for y in range(self.size[0]):
            echo = []
            for x in range(self.size[1]):
                if self.state[0] == y and self.state[1] == x:
                    echo.append("!")
                elif blocked[y][x]:
                    echo.append("X")
                else:
                    echo.append(str(rewards[y][x]))

            print(" ".join(echo))

    def apply_transition_uncertainty(self, action):
        if random.uniform(0, 1) <= 0.1:
            action -= 1
        elif random.uniform(0, 1) >= 0.9:
            action += 1

        if action < 0:
            action = self.MOVE_LEFT
        if action > 3:
            action = self.MOVE_UP

        return action

    def make_transition(self, action):
        new_state = list(self.state)

        if action == self.MOVE_UP:
            new_state[0] -= 1
        elif action == self.MOVE_RIGHT:
            new_state[1] += 1
        elif action == self.MOVE_DOWN:
            new_state[0] += 1
        elif action == self.MOVE_LEFT:
            new_state[1] -= 1

        try:
            new_state = tuple(new_state)
            if not self.blocked[new_state] and new_state[0] >= 0 and new_state[1] >= 0:
                self.state = new_state
        except IndexError:
            return self.state


        return self.state

    def is_terminal_state(self, state):
        return self.rewards[state] == 1

In [161]:
world = gridworld(blocked, rewards, (5,5))
learning_rate = 0.1
gamma = 0.9
sarsa_steps = 4

In [168]:

episode = 1
while True:

    state = world.reset()

    states_count = 1
    while True:
        # print(state)
        action = world.policy[state]
        # print(action)

        new_state, reward, terminal = world.step(action)

        # print(new_state)
        # print(reward)
        # print(terminal)


        next_action = world.policy[new_state]

        world.q_values[action][state] = world.q_values[action][state] + learning_rate * (reward + world.q_values[next_action][new_state] - world.q_values[action][state])


        action_values = world.q_values[:,state[0],state[1]]
        greedy = np.argsort(action_values)[::-1][0]

        world.policy[state] = greedy

        state = new_state

        if terminal:
            print(world.q_values)
            print(world.policy)
            break

        states_count +=1
        if states_count > 500:
            break


    episode += 1
    if episode > 500:
        break

[[[-0.34569651 -0.32310993 -0.28399442 -0.30057268 -0.16055346]
  [ 0.07458355  0.          0.         -0.1084     -0.23025036]
  [-0.28777039  0.          0.          0.         -0.0127564 ]
  [-0.25806098  0.         -0.104       0.68428829 -0.014516  ]
  [-0.22025087 -0.18425519 -0.1976      0.39228444 -0.02277468]]

 [[ 0.11054318  0.20274473  0.30086642  0.37053561 -0.164     ]
  [-0.3203135   0.          0.         -0.01095265 -0.19452456]
  [-0.2829203   0.          0.          0.         -0.0084    ]
  [-0.25826847  0.          0.3438676  -0.0084     -0.016     ]
  [-0.0676148   0.0317282   0.1782084  -0.01944386 -0.02      ]]

 [[-0.34275483 -0.32589832 -0.28350619 -0.32498943  0.52444204]
  [-0.31938522  0.          0.          0.69427627  0.67479841]
  [-0.26606306  0.          0.         -0.004      -0.012324  ]
  [-0.17789829  0.         -0.005996   -0.008      -0.01524   ]
  [-0.21891798 -0.18158763 -0.19058121 -0.02       -0.02293124]]

 [[-0.3432358  -0.32503273 -0.2847