# Snake
## DETH Project
---
<b>Kevin Heijboer &amp; Jurjen Verbruggen</b>

---
The goal of this project is to create a rational agent that can solve a simplified game of snake by using AI algorithms. These algorithms are also compared with each other.

# 1. Definition of the Environment

The snake environment is constructed in a grid. Each cell has x and y-coordinates. The environment consists of the following components:

- States: Each state consists of the position of the snake (head), the position of the apple and the direction of the snake. The direction of the snake can be any of the following:
    * Left
    * Right
    * Up
    * Down
$$\{(snake\_x, snake\_y),\;(apple\_x, apple\_y),\;snake\_direction\}$$
- Terminal states: When a terminal state is reached the environment is set to "done". A state is a terminal state if it meets any of the following conditions:
    * The snake is on the same position as the apple
    * The snake is on the same position as one of the walls
<br></br>
- Actions: In the game snake, the snake is always moving, which is why each state includes the direction of the snake. The snake cannot reverse direction, so every action is relative to its current heading.
    * <b>Turn right:</b> the snake moves one cell to the right relative to the direction it was facing. The snake's direction changes.
    * <b>Turn left:</b> the snake moves one cell to the left relative to the direction it was facing. The snake's direction changes.
    * <b>Straight:</b> the snake moves one cell in the direction it was facing. The snake's direction stays the same.
<br></br>
- Transitions: The environment is deterministic. Every transition has a probability of 1.0 as long as the action results in a valid state. If the resulting state is invalid or the game has already finished, the probability is 0.
<br></br>
- Rewards: The environment has three types of rewards:
    * The snake eats the apple: +1
    * The snake hits a wall: -1
    * The snake does neither: -0.04
    
    To stimulate the snake to move towards the apple, every state that is not a terminal state has a reward of -0.04.
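
The relative actions described above can be sketched compactly by treating each direction as a unit vector and each turn as a 90-degree rotation. This is only an illustrative sketch, not the environment's actual code: the names `DIRECTIONS`, `turn` and `move` are hypothetical, plain strings stand in for the enums, and no wall or validity check is performed.

```python
# Hypothetical, simplified stand-in for the movement rules described above.
# Directions are unit vectors (dx, dy); "Up" increases y.
DIRECTIONS = {"Left": (-1, 0), "Right": (1, 0), "Up": (0, 1), "Down": (0, -1)}
NAMES = {vec: name for name, vec in DIRECTIONS.items()}


def turn(direction, action):
    """Return the new heading after applying a relative action."""
    dx, dy = DIRECTIONS[direction]
    if action == "TurnLeft":       # rotate 90 degrees counter-clockwise
        dx, dy = -dy, dx
    elif action == "TurnRight":    # rotate 90 degrees clockwise
        dx, dy = dy, -dx
    # "Straight" keeps the heading unchanged
    return NAMES[(dx, dy)]


def move(position, direction, action):
    """Update the heading first, then advance one cell in that heading."""
    new_direction = turn(direction, action)
    dx, dy = DIRECTIONS[new_direction]
    x, y = position
    return (x + dx, y + dy), new_direction


print(move((2, 2), "Right", "TurnLeft"))   # ((2, 3), 'Up')
```

A snake at (2, 2) heading right that turns left ends up one cell higher and heading up, because "left" is interpreted relative to the current heading.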

The code below defines all the characteristics mentioned above.

Explanation of the methods:
- <b>reset():</b> resets the environment to the initial state
- <b>step(action):</b> calculates the new state based on the action of the agent and returns the new state, whether that state is terminal, the reward, and an (empty) info dictionary.
- <b>calculate_transition(action, state):</b> helper method for calculating the new state. This method contains most of the logic for the snake's movement. 
- <b>render():</b> visualizes the current state of the environment in a simple grid.
<br></br>
- <b>get_possible_states():</b> returns all the possible states. The possible states are created based on the size of the grid and the number of possible directions.
- <b>is_done(state):</b> determines whether a state is a terminal state or not.
- <b>get_reward(state):</b> calculates the reward based on the given state.
- <b>get_transition_prob(action, state):</b> returns a list containing the probability and the resulting state for the given state and action. This method uses the calculate_transition method to determine whether the new state is valid. Note: this method does not ignore terminal states; this is done in the get_actions(state) method.
- <b>get_actions(state):</b> returns a list of actions that can be performed in the given state. No actions are given for terminal states.

In [205]:
from enum import Enum
from random import randint, choice
from copy import copy

class Action(Enum):
    def __str__(self):
        return self.name
    TurnLeft = 1
    TurnRight = 2
    Straight = 3
    

class Direction(Enum):
    def __str__(self):
        return self.name
    Left = 1
    Right = 2
    Up = 3
    Down = 4

    
class SnakeEnvironment():
    def __init__(self, initial_state, reward_per_step=-0.04, grid_size=5, gamma=0.9):
        self.__grid_size = grid_size
        self.__possible_states = []
        for snake_x in range(grid_size):
            for snake_y in range(grid_size):
                for apple_x in range(grid_size):
                    for apple_y in range(grid_size):
                        # apple position cannot be in the same position as a wall
                        if apple_x != 0 and apple_x != grid_size-1 and apple_y != 0 and apple_y != grid_size-1:
                            for direction in Direction:
                                s = {'snake_position': (snake_x, snake_y), 'apple_position': (apple_x, apple_y), 'snake_direction': direction}
                                self.__possible_states.append(s)
                                
        if initial_state in self.__possible_states:
            self.__initial_state = initial_state
        else:
            raise ValueError("Invalid initial state")
        self.__state = self.__initial_state
        self.__reward_per_step = reward_per_step
        self.__gamma = gamma

    def reset(self):
        self.__state = self.__initial_state
        return self.__state

    def __calculate_transition(self, action, old_state=None):
        if old_state is None:
            old_state = self.__state

        # determine position
        snake_x_old, snake_y_old = old_state['snake_position']
        apple_x, apple_y = old_state['apple_position']
        direction_old = old_state['snake_direction']

        if action == Action.Straight:
            if direction_old == Direction.Left:
                snake_x_new = snake_x_old - 1
                snake_y_new = snake_y_old
                direction_new = direction_old
            elif direction_old == Direction.Right:
                snake_x_new = snake_x_old + 1
                snake_y_new = snake_y_old
                direction_new = direction_old
            elif direction_old == Direction.Up:
                snake_x_new = snake_x_old
                snake_y_new = snake_y_old + 1
                direction_new = direction_old
            else: # direction_old == Direction.Down:
                snake_x_new = snake_x_old
                snake_y_new = snake_y_old - 1
                direction_new = direction_old
        elif action == Action.TurnLeft:
            if direction_old == Direction.Left:
                snake_x_new = snake_x_old
                snake_y_new = snake_y_old - 1
                direction_new = Direction.Down
            elif direction_old == Direction.Right:
                snake_x_new = snake_x_old
                snake_y_new = snake_y_old + 1
                direction_new = Direction.Up
            elif direction_old == Direction.Up:
                snake_x_new = snake_x_old - 1
                snake_y_new = snake_y_old
                direction_new = Direction.Left
            else: # direction_old == Direction.Down:
                snake_x_new = snake_x_old + 1
                snake_y_new = snake_y_old
                direction_new = Direction.Right
        else: # action == Action.TurnRight:
            if direction_old == Direction.Left:
                snake_x_new = snake_x_old
                snake_y_new = snake_y_old + 1
                direction_new = Direction.Up
            elif direction_old == Direction.Right:
                snake_x_new = snake_x_old
                snake_y_new = snake_y_old - 1
                direction_new = Direction.Down
            elif direction_old == Direction.Up:
                snake_x_new = snake_x_old + 1
                snake_y_new = snake_y_old
                direction_new = Direction.Right
            else: # direction_old == Direction.Down:
                snake_x_new = snake_x_old - 1
                snake_y_new = snake_y_old
                direction_new = Direction.Left


        new_state = {'snake_position': (snake_x_new, snake_y_new), 'apple_position': (apple_x, apple_y), 'snake_direction': direction_new}
        if new_state in self.__possible_states:
            return new_state
        else:
            return old_state  # state does not change
      
    def step(self, action):
        old_state = self.__state
        self.__state = self.__calculate_transition(action)  # state after action
        observation = self.__state  # environment is fully observable
        done = self.is_done()
        reward = self.get_reward(self.__state)
        info = {}  # optional debug info
        return observation, done, reward, info

    def render(self):
        for j in range(0, self.__grid_size):
            for i in range(0, self.__grid_size):
                print("+---", end="")
            print("+")
            for i in range(0, self.__grid_size):
                if (i, self.__grid_size-1-j) == self.__state['snake_position']:
                    print("| 🐍", end="")
                elif (i, self.__grid_size-1-j) == self.__state['apple_position']:
                    print("| 🍎", end="")
                # walls of grid
                elif i == 0 or i == self.__grid_size-1 or j == 0 or j == self.__grid_size-1:
                    print("| X ", end="")
                else:
                    print("|   ", end="")
            print("|")
        for i in range(0, self.__grid_size):
            print("+---", end="")
        print("+\n")

        if self.is_done():
            print("╔══════════╗\n"+
                  "║ Game done! ║\n"+
                  "╚══════════╝")

    
    def get_possible_states(self):
        return self.__possible_states
    
    def is_done(self, state=None):
        if state is None:
            state = self.__state

        snake_x, snake_y = state['snake_position']
        apple_x, apple_y = state['apple_position']
        
        if snake_x == apple_x and snake_y == apple_y:
            return True
        if snake_x == 0 or snake_x == self.__grid_size-1 or snake_y == 0 or snake_y == self.__grid_size-1:
            return True
        return False
    
    def get_reward(self, state):
        snake_x, snake_y = state['snake_position']
        apple_x, apple_y = state['apple_position']

        # Reward R(s) for every possible state
        if snake_x == apple_x and snake_y == apple_y:
            return 1.0
        # if state is one of the walls
        if snake_x == 0 or snake_x == self.__grid_size-1 or snake_y == 0 or snake_y == self.__grid_size-1:
            return -1.0
            
        return self.__reward_per_step
        
    def get_transition_prob(self, action, state, new_state=None):
        calculated_state = self.__calculate_transition(action, state)

        if new_state and calculated_state != new_state:
            return [(0.0, calculated_state)]

        if action: 
            return [(1.0, calculated_state)]
        else:
            return [(0.0, state)]

    def get_gamma(self):
        return self.__gamma

    def get_actions(self, state):
        # this returns the list of actions for a state except for the terminal states
        if self.is_done(state):
            return [None]
        else:
            return Action

    def get_grid_size(self):
        return self.__grid_size

    def current_state(self):
        return self.__state


## 1.1 Initializing the Environment

The SnakeEnvironment Class allows creation of an environment with the following parameters:
* Initial state: the initial positions of the snake and apple
* Reward per step: the value of the rewards used for stimulating the snake to move
* Grid size: determines the size of the grid where the snake can move
* Gamma: the discount factor. This makes sure that delayed rewards have less value than immediate ones (set to 0.9 by default).
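
To make the effect of the discount factor concrete, the sketch below compares the undiscounted and discounted sum of rewards for a hypothetical episode of five step penalties followed by eating the apple. The helper `discounted_return` is an illustrative name and is not part of the environment.

```python
def discounted_return(rewards, gamma=0.9):
    """Sum of rewards, each weighted by gamma raised to its time step."""
    return sum(gamma ** t * r for t, r in enumerate(rewards))


# Hypothetical episode: five step penalties, then the apple.
rewards = [-0.04] * 5 + [1.0]
print(round(sum(rewards), 4))                 # undiscounted: 0.8
print(round(discounted_return(rewards), 4))   # discounted:   0.4267
```

The apple is worth noticeably less when it is reached late, which is what pushes an agent towards short paths.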

In [206]:
# example of creation of an environment in the default state
initial_state = {'snake_position': (1, 1), 'apple_position': (4, 4), 'snake_direction': Direction.Right}
env = SnakeEnvironment(initial_state=initial_state, grid_size=6, reward_per_step=-0.04)
env.render()

+---+---+---+---+---+---+
| X | X | X | X | X | X |
+---+---+---+---+---+---+
| X |   |   |   | 🍎| X |
+---+---+---+---+---+---+
| X |   |   |   |   | X |
+---+---+---+---+---+---+
| X |   |   |   |   | X |
+---+---+---+---+---+---+
| X | 🐍|   |   |   | X |
+---+---+---+---+---+---+
| X | X | X | X | X | X |
+---+---+---+---+---+---+



The snake has started with a direction to the right. So any "Straight" actions will move to the right.


Steps can now be executed on the environment like so:

In [207]:
print(env.step(Action.Straight))
print(env.step(Action.Straight))
print(env.step(Action.Straight))
env.render()

({'snake_position': (2, 1), 'apple_position': (4, 4), 'snake_direction': <Direction.Right: 2>}, False, -0.04, {})
({'snake_position': (3, 1), 'apple_position': (4, 4), 'snake_direction': <Direction.Right: 2>}, False, -0.04, {})
({'snake_position': (4, 1), 'apple_position': (4, 4), 'snake_direction': <Direction.Right: 2>}, False, -0.04, {})
+---+---+---+---+---+---+
| X | X | X | X | X | X |
+---+---+---+---+---+---+
| X |   |   |   | 🍎| X |
+---+---+---+---+---+---+
| X |   |   |   |   | X |
+---+---+---+---+---+---+
| X |   |   |   |   | X |
+---+---+---+---+---+---+
| X |   |   |   | 🐍| X |
+---+---+---+---+---+---+
| X | X | X | X | X | X |
+---+---+---+---+---+---+



The output shows the rewards per step and whether or not the game is done.

The following output shows when the snake eats the apple:

In [208]:
print(env.step(Action.TurnLeft))
print(env.step(Action.Straight))
print(env.step(Action.Straight))
env.render()

({'snake_position': (4, 2), 'apple_position': (4, 4), 'snake_direction': <Direction.Up: 3>}, False, -0.04, {})
({'snake_position': (4, 3), 'apple_position': (4, 4), 'snake_direction': <Direction.Up: 3>}, False, -0.04, {})
({'snake_position': (4, 4), 'apple_position': (4, 4), 'snake_direction': <Direction.Up: 3>}, True, 1.0, {})
+---+---+---+---+---+---+
| X | X | X | X | X | X |
+---+---+---+---+---+---+
| X |   |   |   | 🐍| X |
+---+---+---+---+---+---+
| X |   |   |   |   | X |
+---+---+---+---+---+---+
| X |   |   |   |   | X |
+---+---+---+---+---+---+
| X |   |   |   |   | X |
+---+---+---+---+---+---+
| X | X | X | X | X | X |
+---+---+---+---+---+---+

╔══════════╗
║ Game done! ║
╚══════════╝


## 1.2 Transitions

Transitions can be calculated by calling method get_transition_prob(action, state).

Below are some tests that are conducted in order to verify the working order of the environment.

The transition probability can still be calculated even if a state is a terminal state. However, get_actions(state) returns [None] for such a state.

In [209]:
initial_state = {'snake_position': (2, 2), 'apple_position': (4, 4), 'snake_direction': Direction.Right}
env = SnakeEnvironment(initial_state=initial_state, grid_size=6, reward_per_step=-0.04)

print("Probabilities:")

# make a valid turn to the right
print("Turning right")
new_state = {'snake_position': (2, 1), 'apple_position': (4, 4), 'snake_direction': Direction.Down}
print(f"{env.current_state()} -> {Action.TurnRight} -> {new_state} has probability: {env.get_transition_prob(Action.TurnRight, env.current_state(), new_state)[0][0]}\n")

# make a valid turn to the left
env = SnakeEnvironment(initial_state=initial_state, grid_size=6, reward_per_step=-0.04)
print("Turning left")
new_state = {'snake_position': (2, 3), 'apple_position': (4, 4), 'snake_direction': Direction.Up}
print(f"{env.current_state()} -> {Action.TurnLeft} -> {new_state} has probability: {env.get_transition_prob(Action.TurnLeft, env.current_state(), new_state)[0][0]}\n")

# make a valid move straight
env = SnakeEnvironment(initial_state=initial_state, grid_size=6, reward_per_step=-0.04)
print("Moving straight")
new_state = {'snake_position': (3, 2), 'apple_position': (4, 4), 'snake_direction': Direction.Right}
print(f"{env.current_state()} -> {Action.Straight} -> {new_state} has probability: {env.get_transition_prob(Action.Straight, env.current_state(), new_state)[0][0]}\n")

Probabilities:
Turning right
{'snake_position': (2, 2), 'apple_position': (4, 4), 'snake_direction': <Direction.Right: 2>} -> TurnRight -> {'snake_position': (2, 1), 'apple_position': (4, 4), 'snake_direction': <Direction.Down: 4>} has probability: 1.0

Turning left
{'snake_position': (2, 2), 'apple_position': (4, 4), 'snake_direction': <Direction.Right: 2>} -> TurnLeft -> {'snake_position': (2, 3), 'apple_position': (4, 4), 'snake_direction': <Direction.Up: 3>} has probability: 1.0

Moving straight
{'snake_position': (2, 2), 'apple_position': (4, 4), 'snake_direction': <Direction.Right: 2>} -> Straight -> {'snake_position': (3, 2), 'apple_position': (4, 4), 'snake_direction': <Direction.Right: 2>} has probability: 1.0



The output above is expected and verifies that the transitions work for the "good flow" of the snake. Below are some tests with invalid states that should have 0 probability.

In [210]:
# turn to the right but next state has wrong direction
print("Turning right but next state has wrong direction")
new_state = {'snake_position': (2, 1), 'apple_position': (4, 4), 'snake_direction': Direction.Left}
print(f"{env.current_state()} -> {Action.TurnRight} -> {new_state} has probability: {env.get_transition_prob(Action.TurnRight, env.current_state(), new_state)[0][0]}\n")

# turn to the left but end in a non-neighboring cell
env = SnakeEnvironment(initial_state=initial_state, grid_size=6, reward_per_step=-0.04)
print("Turning left")
new_state = {'snake_position': (2, 4), 'apple_position': (4, 4), 'snake_direction': Direction.Up}
print(f"{env.current_state()} -> {Action.TurnLeft} -> {new_state} has probability: {env.get_transition_prob(Action.TurnLeft, env.current_state(), new_state)[0][0]}\n")

# check the available actions for a terminal state
print("Checking available actions for a terminal state (hit a wall)")
terminal_state = {'snake_position': (0, 0), 'apple_position': (4, 4), 'snake_direction': Direction.Up}
print(f"State {terminal_state} has actions: {env.get_actions(terminal_state)}")

Turning right but next state has wrong direction
{'snake_position': (2, 2), 'apple_position': (4, 4), 'snake_direction': <Direction.Right: 2>} -> TurnRight -> {'snake_position': (2, 1), 'apple_position': (4, 4), 'snake_direction': <Direction.Left: 1>} has probability: 0.0

Turning left
{'snake_position': (2, 2), 'apple_position': (4, 4), 'snake_direction': <Direction.Right: 2>} -> TurnLeft -> {'snake_position': (2, 4), 'apple_position': (4, 4), 'snake_direction': <Direction.Up: 3>} has probability: 0.0

Checking available actions for a terminal state (hit a wall)
State {'snake_position': (0, 0), 'apple_position': (4, 4), 'snake_direction': <Direction.Up: 3>} has actions: [None]


# 2. Random agent

Now that the functionality of the environment is verified, the agent can be given a concrete policy function  𝜋(𝑠)→𝑎  that determines the decisions of the snake based on its current state.

The snake game can be solved in different ways. One way is by using a random policy, which picks the next action uniformly at random from the available actions. Since there are many negative terminal states and only one positive terminal state, the chance of completing the game on pure randomness is low. The larger the grid, the less likely it is that the snake reaches the apple.

The code below shows the agent with a random policy.

For consistency, the initial position of the snake is always (1, 1) and the initial position of the apple is always (4, 4). The size of the grid is also always 7 by 7.

In [211]:
import random

def policy_random(env, state):
    action = choice([a for a in env.get_actions(state)])
    return action

# create a random environment
initial_state = {'snake_position': (1, 1), 'apple_position': (4, 4), 'snake_direction': Direction.Right}
env = SnakeEnvironment(initial_state=initial_state, grid_size=7, reward_per_step=-0.04)
state = env.reset()
print('initial state: {}'.format(state))

total_reward = 0.0
done = False
nr_steps = 0
while not done:
    next_action = policy_random(env, state)
    state, done, reward, info = env.step(next_action)
    total_reward += reward
    nr_steps += 1
    print('action: {}\tstate: {}, reward: {:5.2f}'
          .format(next_action, state, reward))
print('Episode done after {} steps. total reward: {:6.2f}'.format(nr_steps, total_reward))

initial state: {'snake_position': (1, 1), 'apple_position': (4, 4), 'snake_direction': <Direction.Right: 2>}
action: TurnRight	state: {'snake_position': (1, 0), 'apple_position': (4, 4), 'snake_direction': <Direction.Down: 4>}, reward: -1.00
Episode done after 1 steps. total reward:  -1.00


As expected, the total reward of the agent in most cases is negative.

Below are some statistics over 200 episodes of the random agent:

In [221]:
from statistics import mean, stdev

def run_one_episode(policy):
    initial_state = {'snake_position': (1, 1), 'apple_position': (4, 4), 'snake_direction': Direction.Right}
    env = SnakeEnvironment(initial_state=initial_state, grid_size=7, reward_per_step=-0.04)
    state = env.reset()
    
    total_reward = 0.0
    done = False
    while not done:
        next_action = policy(env, state)
        state, done, reward, info = env.step(next_action)
        total_reward += reward
    return total_reward

def measure_performance(policy, nr_episodes=200):
    N = nr_episodes
    print('statistics over', N, 'episodes')
    all_rewards = []
    for _ in range(N):
        episode_reward = run_one_episode(policy)
        all_rewards.append(episode_reward)

    print()
    failed = 0
    finished = 0
    for n, episode_reward in enumerate(all_rewards):
        if episode_reward < 0:
            failed += 1
        else:
            finished += 1
    print('mean: {:6.2f}, sigma: {:6.2f}'.format(mean(all_rewards), stdev(all_rewards)))
    for n, episode_reward in enumerate(all_rewards[:5], 1):
        print('ep: {:2d}, total reward: {:5.2f}'.format(n, episode_reward))
    print('......')
    for n, episode_reward in enumerate(all_rewards[-5:], len(all_rewards)-5):
        print('ep: {:2d}, total reward: {:5.2f}'.format(n, episode_reward))
    print('\n')
    print('Number of episodes that reached the apple: {:6.2f}%'.format(finished/N*100))

    return mean(all_rewards), stdev(all_rewards), (finished/N*100)

results_random = measure_performance(policy_random) 

statistics over 200 episodes

mean:  -0.98, sigma:   0.48
ep:  1, total reward: -1.16
ep:  2, total reward: -1.04
ep:  3, total reward: -1.00
ep:  4, total reward: -1.16
ep:  5, total reward: -1.28
......
ep: 195, total reward: -1.04
ep: 196, total reward: -1.04
ep: 197, total reward: -1.16
ep: 198, total reward: -1.16
ep: 199, total reward: -1.04


Number of episodes that reached the apple:   7.00%


The statistics show that only 7% of these 200 episodes reached the apple, with a mean total reward of -0.98. These results will be compared to the results of the value iteration in the next sections.
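
Because the environment is deterministic and the random policy is uniform over three actions, the success rate can also be computed exactly instead of being estimated from sampled episodes. The sketch below is a self-contained approximation of that idea and does not use the SnakeEnvironment class: the names `successors`, `is_wall` and `success_probability` are hypothetical, headings are plain unit vectors, and the grid size, apple position and start state are assumed to match the setup above.

```python
# Self-contained sketch; it does not use the SnakeEnvironment class.
GRID = 7                                        # walls occupy the outer ring
APPLE = (4, 4)
HEADINGS = [(-1, 0), (1, 0), (0, 1), (0, -1)]   # Left, Right, Up, Down


def successors(pos, heading):
    """States reached by Straight, TurnLeft and TurnRight, in that order."""
    dx, dy = heading
    for ndx, ndy in ((dx, dy), (-dy, dx), (dy, -dx)):
        yield (pos[0] + ndx, pos[1] + ndy), (ndx, ndy)


def is_wall(pos):
    return pos[0] in (0, GRID - 1) or pos[1] in (0, GRID - 1)


def success_probability(start_pos, start_heading, sweeps=500):
    """Probability that a uniformly random policy reaches the apple."""
    interior = range(1, GRID - 1)
    p = {((x, y), h): 0.0 for x in interior for y in interior for h in HEADINGS}

    def value(pos, heading):
        if pos == APPLE:
            return 1.0              # absorbed at the apple
        if is_wall(pos):
            return 0.0              # absorbed at a wall
        return p[(pos, heading)]

    # Repeatedly average over the three equally likely successor states.
    for _ in range(sweeps):
        for pos, heading in p:
            if pos != APPLE:
                p[(pos, heading)] = sum(value(*s) for s in successors(pos, heading)) / 3
    return p[(start_pos, start_heading)]


print(success_probability((1, 1), (1, 0)))      # snake at (1, 1), heading right
```

The result is a single probability that does not fluctuate between runs, which makes it a useful cross-check for the sampled percentage.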

# 3. Optimal decisions based on sums of rewards

A better way of solving snake is by obtaining an optimal policy. This can be obtained with value iteration.

First the utility is calculated for every state. The utility of a state is the expected sum of discounted future rewards, given that we start in that state and follow a particular policy $\pi$. The utility of a state can be calculated by using the Bellman equation.

$$U(s)=R(s)+\gamma\max_{a \in A(s)}\sum_{s'} P(s'\ |\ s,a)U(s')$$

Explanation of the code below:

The value_iteration method takes two parameters as input: the environment and epsilon. Epsilon is the maximum error allowed in the utility of any state. The method returns the utilities in the form of a dictionary.

Value iteration starts with any initial values for the utilities. The right-hand side of the Bellman equation is calculated and then plugged into the left-hand side. This updates the utility of each state from the utilities of its neighbors. This is called the Bellman update: 

$$ U_{i+1}(s) \leftarrow R(s) + \gamma \max_{a \in A(s)} \sum_{s'} P(s'\ |\ s,a)U_{i}(s') $$

The iteration continues until delta is small enough to guarantee that the error in the utilities stays below epsilon. Delta measures the largest change in the utility of any state between the current iteration and the previous iteration.
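
The implementation below does not compare delta against $\epsilon$ directly but against $\epsilon(1-\gamma)/\gamma$. As a worked example, with the values used in this notebook ($\epsilon = 0.001$ and $\gamma = 0.9$) that threshold is:

$$\epsilon\,\frac{1-\gamma}{\gamma} = 0.001 \cdot \frac{0.1}{0.9} \approx 1.1 \times 10^{-4}$$

So the loop stops once no utility changes by more than roughly 0.0001 between two successive sweeps.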

In [213]:
def q_value(mdp, s, a, U):
    if not a:
        return mdp.get_reward(s)
    res = 0
    for p, s_prime in mdp.get_transition_prob(a, s):
        res += p * (mdp.get_reward(s) + mdp.get_gamma() * U[str(s_prime)])
    return res


def value_iteration(mdp, epsilon=0.001):
    """Solving an MDP by value iteration. [Figure 16.6]"""

    U1 = {str(s): 0 for s in mdp.get_possible_states()}
    gamma = mdp.get_gamma()
    while True:
        U = U1.copy()
        delta = 0
        for s in mdp.get_possible_states():
            U1[str(s)] = max(q_value(mdp, s, a, U) for a in mdp.get_actions(s))
            delta = max(delta, abs(U1[str(s)] - U[str(s)]))
        if delta <= epsilon * (1 - gamma) / gamma:
            return U

The utilities for a 7 by 7 grid can now be calculated using value iteration.

In [214]:
mdp = SnakeEnvironment(initial_state={'snake_position': (1, 1), 'apple_position': (3, 3), 'snake_direction': Direction.Up}, reward_per_step=-0.04, grid_size=7, gamma=0.9)
utilities = value_iteration(mdp)
for state, utility in utilities.items():
    print("{0}:{1:7.4f}".format(state, utility))

{'snake_position': (0, 0), 'apple_position': (1, 1), 'snake_direction': <Direction.Left: 1>}:-1.0000
{'snake_position': (0, 0), 'apple_position': (1, 1), 'snake_direction': <Direction.Right: 2>}:-1.0000
{'snake_position': (0, 0), 'apple_position': (1, 1), 'snake_direction': <Direction.Up: 3>}:-1.0000
{'snake_position': (0, 0), 'apple_position': (1, 1), 'snake_direction': <Direction.Down: 4>}:-1.0000
{'snake_position': (0, 0), 'apple_position': (1, 2), 'snake_direction': <Direction.Left: 1>}:-1.0000
{'snake_position': (0, 0), 'apple_position': (1, 2), 'snake_direction': <Direction.Right: 2>}:-1.0000
{'snake_position': (0, 0), 'apple_position': (1, 2), 'snake_direction': <Direction.Up: 3>}:-1.0000
{'snake_position': (0, 0), 'apple_position': (1, 2), 'snake_direction': <Direction.Down: 4>}:-1.0000
{'snake_position': (0, 0), 'apple_position': (1, 3), 'snake_direction': <Direction.Left: 1>}:-1.0000
{'snake_position': (0, 0), 'apple_position': (1, 3), 'snake_direction': <Direction.Right: 2>}

This prints a large number of utilities, one for every possible state.
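
The size of the state space can be checked with a short count. The sketch below mirrors how the environment enumerates its states (the snake on any cell, the apple only on non-wall cells, four directions), but it is a standalone illustration that uses plain strings instead of the Direction enum.

```python
from itertools import product

grid_size = 7
directions = ["Left", "Right", "Up", "Down"]

cells = range(grid_size)              # the snake may stand on any cell, walls included
interior = range(1, grid_size - 1)    # the apple is never placed on a wall

# (snake_x, snake_y, apple_x, apple_y, direction)
states = list(product(cells, cells, interior, interior, directions))
print(len(states))   # 7 * 7 * 5 * 5 * 4 = 4900
```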

To visualize the output, we can filter the utilities on states where the apple is at position <b>(4, 3)</b> and the direction of the snake is <b>Up</b>.

In [215]:
import ast

def string_to_dict(state):
    state = state.replace("<Direction.Left: 1>", "1")
    state = state.replace("<Direction.Right: 2>", "2")
    state = state.replace("<Direction.Up: 3>", "3")
    state = state.replace("<Direction.Down: 4>", "4")
    return ast.literal_eval(state)

def visualize_iteration(mdp, utilities, apple_position, direction):
    for j in range(0, mdp.get_grid_size()):
        print("+", end="")
        for i in range(0, mdp.get_grid_size()):
            print("-"*8, end="")
        print("+")
        for i in range(0, mdp.get_grid_size()):
            for state, utility in utilities.items():
                state = string_to_dict(state)
                if state['snake_position'] == (i, mdp.get_grid_size()-1-j) and state['apple_position'] == apple_position and state['snake_direction'] == direction:
                    if utility < 0:
                        print("| \x1b[31m{:5.2f}\x1b[0m ".format(utility), end="")
                    elif utility >= 1:
                        print("| \x1b[32m{:5.2f}\x1b[0m ".format(utility), end="")
                    elif utility >= 0.7:
                        print("| \x1b[33m{:5.2f}\x1b[0m ".format(utility), end="")
                    elif utility >= 0.4:
                        print("| \x1b[34m{:5.2f}\x1b[0m ".format(utility), end="")
                    else:
                        print("| \x1b[35m{:5.2f}\x1b[0m ".format(utility), end="")

            

        print("|")
    print("+", end="")
    for i in range(0, mdp.get_grid_size()):
        print("-"*8, end="")
    print("+\n")

visualize_iteration(mdp, utilities, (4, 3), Direction.Up.value)

+--------------------------------------------------------+
| -1.00 | -1.00 | -1.00 | -1.00 | -1.00 | -1.00 | -1.00 |
+--------------------------------------------------------+
| -1.00 |  0.43 |  0.52 |  0.62 |  0.52 |  0.62 | -1.00 |
+--------------------------------------------------------+
| -1.00 |  0.52 |  0.62 |  0.73 |  0.62 |  0.73 | -1.00 |
+--------------------------------------------------------+
| -1.00 |  0.62 |  0.73 |  0.86 |  1.00 |  0.86 | -1.00 |
+--------------------------------------------------------+
| -1.00 |  0.52 |  0.62 |  0.73 |  0.86 |  0.73 | -1.00 |
+--------------------------------------------------------+
| -1.00 |  0.43 |

The visualization shows that the values increase as they get closer to the position of the apple.

However, the two positions above the apple, (4, 4) and (4, 5), do not have a high utility.

This is because the results are filtered on all the states where the snake has the direction "Up". Since the snake cannot go backwards, it would first have to make multiple turns.
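
Since every transition is deterministic, the utility of a non-terminal state depends only on how many moves the snake needs to reach the apple along its best path. The sketch below applies the Bellman update once per move, starting from the apple reward; `utility_after` is a hypothetical helper and its defaults are the reward and gamma values used in this notebook.

```python
def utility_after(steps, step_reward=-0.04, gamma=0.9, apple_reward=1.0):
    """Utility of a state that is `steps` optimal moves away from the apple."""
    u = apple_reward
    for _ in range(steps):
        u = step_reward + gamma * u   # one Bellman backup along the best path
    return u


for steps in range(6):
    print(steps, "{:5.2f}".format(utility_after(steps)))
# 0  1.00
# 1  0.86
# 2  0.73
# 3  0.62
# 4  0.52
# 5  0.43
```

A snake at (4, 4) heading up needs three consecutive turns to reach the apple at (4, 3), which gives the 0.62 shown in the grid, even though the apple is in the adjacent cell.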

The same thing will occur for other directions as shown in the output below, which is filtered on the "Right" direction.

In [216]:
visualize_iteration(mdp, utilities, (4, 3), Direction.Right.value)

+--------------------------------------------------------+
| -1.00 | -1.00 | -1.00 | -1.00 | -1.00 | -1.00 | -1.00 |
+--------------------------------------------------------+
| -1.00 |  0.43 |  0.52 |  0.62 |  0.73 |  0.62 | -1.00 |
+--------------------------------------------------------+
| -1.00 |  0.52 |  0.62 |  0.73 |  0.86 |  0.73 | -1.00 |
+--------------------------------------------------------+
| -1.00 |  0.62 |  0.73 |  0.86 |  1.00 |  0.62 | -1.00 |
+--------------------------------------------------------+
| -1.00 |  0.52 |  0.62 |  0.73 |  0.86 |  0.73 | -1.00 |
+--------------------------------------------------------+
| -1.00 |  0.43 |

Now that the utilities are calculated, the optimal policy can be derived from them. It specifies which of the three actions the agent should take in every non-terminal state.

In [217]:
pi_star = {}
for s in mdp.get_possible_states():
    if mdp.is_done(s):
        continue # policy is not needed in stop states
    max_a = float('-inf')
    argmax_a = None 
    for action in Action:
        q = q_value(mdp, s, action, utilities) 
        if q > max_a:
            max_a = q
            argmax_a = action
    pi_star[str(s)] = argmax_a

for state, policy in pi_star.items():
    print(f"{state} -> {policy}")

{'snake_position': (1, 1), 'apple_position': (1, 2), 'snake_direction': <Direction.Left: 1>} -> TurnRight
{'snake_position': (1, 1), 'apple_position': (1, 2), 'snake_direction': <Direction.Right: 2>} -> TurnLeft
{'snake_position': (1, 1), 'apple_position': (1, 2), 'snake_direction': <Direction.Up: 3>} -> Straight
{'snake_position': (1, 1), 'apple_position': (1, 2), 'snake_direction': <Direction.Down: 4>} -> TurnLeft
{'snake_position': (1, 1), 'apple_position': (1, 3), 'snake_direction': <Direction.Left: 1>} -> TurnRight
{'snake_position': (1, 1), 'apple_position': (1, 3), 'snake_direction': <Direction.Right: 2>} -> TurnLeft
{'snake_position': (1, 1), 'apple_position': (1, 3), 'snake_direction': <Direction.Up: 3>} -> Straight
{'snake_position': (1, 1), 'apple_position': (1, 3), 'snake_direction': <Direction.Down: 4>} -> TurnLeft
{'snake_position': (1, 1), 'apple_position': (1, 4), 'snake_direction': <Direction.Left: 1>} -> TurnRight
{'snake_position': (1, 1), 'apple_position': (1, 4), '

The output is again filtered on the same apple position and on the direction "Up".

Keep in mind that every arrow below shows the action chosen for a snake that is currently heading upwards.

In [218]:
def visualize_policies(mdp, policies, apple_position, direction):
    for j in range(0, mdp.get_grid_size()):
        print("+--", end="")
        for i in range(0, mdp.get_grid_size()):
            print("----", end="")
        print("+")
        for i in range(0, mdp.get_grid_size()):
            if (i, mdp.get_grid_size()-1-j) == apple_position:
                print("|🍎 ", end="")
            # walls of grid
            elif i == 0 or i == mdp.get_grid_size()-1 or j == 0 or j == mdp.get_grid_size()-1:
                print("|🧱 ", end="")
            for state, policy in policies.items():
                state = string_to_dict(state)
                if state['snake_position'] == (i, mdp.get_grid_size()-1-j) and state['apple_position'] == apple_position and state['snake_direction'] == direction:
                    if direction == Direction.Up.value:
                        if str(policy) == str(Action.TurnRight):
                            print("| ⮣ ", end="")
                        elif str(policy) == str(Action.TurnLeft):
                            print("| ⮢ ", end="")
                        else: # str(policy) == str(Action.Straight):
                            print("| 🠅 ", end="")
                    elif direction == Direction.Down.value:
                        if str(policy) == str(Action.TurnRight):
                            print("| ⮠  ", end="")
                        elif str(policy) == str(Action.TurnLeft):
                            print("| ⮡  ", end="")
                        else:
                            print("| 🠇 ", end="")
                    elif direction == Direction.Left.value:
                        if str(policy) == str(Action.TurnRight):
                            print("| ⮤ ", end="")
                        elif str(policy) == str(Action.TurnLeft):
                            print("| ⮦ ", end="")
                        else:
                            print("| 🠄 ", end="")
                    else: # direction == Direction.Right.value:
                        if str(policy) == str(Action.TurnRight):
                            print("| ⮧ ", end="")
                        elif str(policy) == str(Action.TurnLeft):
                            print("| ⮥ ", end="")
                        else:
                            print("| 🠆 ", end="")
        print("|")
    print("+--", end="")
    for i in range(0, mdp.get_grid_size()):
        print("----", end="")
    print("+\n")

visualize_policies(mdp, pi_star, (4, 3), Direction.Up.value)

+------------------------------+
|🧱 |🧱 |🧱 |🧱 |🧱 |🧱 |🧱 |
+------------------------------+
|🧱 | ⮣ | ⮣ | ⮣ | ⮢ | ⮢ |🧱 |
+------------------------------+
|🧱 | ⮣ | ⮣ | ⮣ | ⮢ | ⮢ |🧱 |
+------------------------------+
|🧱 | ⮣ | ⮣ | ⮣ |🍎 | ⮢ |🧱 |
+------------------------------+
|🧱 | ⮣ | ⮣ | ⮣ | 🠅 | ⮢ |🧱 |
+------------------------------+
|🧱 | ⮣ | ⮣ | ⮣ | 🠅 | ⮢ |🧱 |
+------------------------------+
|🧱 |🧱 |🧱 |🧱 |🧱 |🧱 |🧱 |
+------------------------------+



Again, this can also be filtered on other directions, such as going to the right.

In [219]:
visualize_policies(mdp, pi_star, (4, 3), Direction.Right.value)

+------------------------------+
|🧱 |🧱 |🧱 |🧱 |🧱 |🧱 |🧱 |
+------------------------------+
|🧱 | ⮧ | ⮧ | ⮧ | ⮧ | ⮧ |🧱 |
+------------------------------+
|🧱 | ⮧ | ⮧ | ⮧ | ⮧ | ⮧ |🧱 |
+------------------------------+
|🧱 | 🠆 | 🠆 | 🠆 |🍎 | ⮥ |🧱 |
+------------------------------+
|🧱 | ⮥ | ⮥ | ⮥ | ⮥ | ⮥ |🧱 |
+------------------------------+
|🧱 | ⮥ | ⮥ | ⮥ | ⮥ | ⮥ |🧱 |
+------------------------------+
|🧱 |🧱 |🧱 |🧱 |🧱 |🧱 |🧱 |
+------------------------------+



In [222]:
def optimal_policy(env, state):
    for state_p, policy in pi_star.items():
        state_p = string_to_dict(state_p)
        if state_p['snake_position'] == state['snake_position'] and state_p['apple_position'] == state['apple_position'] and state_p['snake_direction'] == state['snake_direction'].value:
            return policy
    return None

env.reset()
results_optimal = measure_performance(optimal_policy, nr_episodes = 100)

statistics over 100 episodes

mean:   0.80, sigma:   0.00
ep:  1, total reward:  0.80
ep:  2, total reward:  0.80
ep:  3, total reward:  0.80
ep:  4, total reward:  0.80
ep:  5, total reward:  0.80
......
ep: 95, total reward:  0.80
ep: 96, total reward:  0.80
ep: 97, total reward:  0.80
ep: 98, total reward:  0.80
ep: 99, total reward:  0.80


Number of episodes that reached the apple: 100.00%


As expected, all of the episodes reach the apple. The movements of the snake involve no randomness, so following the optimal policy always produces the same outcome.

The standard deviation of 0 indicates that the total reward is 0.80 in every episode. This is also expected: on its way from the starting position (1, 1) to the apple at (4, 4), the snake collects the step penalty 5 times, which adds up to $5 \cdot (-0.04) = -0.20$, and the move that eats the apple earns $+1$. The total reward of the optimal solution is therefore $1 - 0.20 = 0.80$.
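
The arithmetic behind the 0.80 can be checked in a few lines. This is a minimal sketch that uses only the reward values from the environment definition and assumes the total reward is the plain (undiscounted) sum of the rewards. The names `STEP_REWARD`, `APPLE_REWARD` and `episode_return` are hypothetical and not part of the notebook.

```python
STEP_REWARD = -0.04  # reward for a move that neither eats the apple nor hits a wall
APPLE_REWARD = 1.0   # reward for the move that eats the apple


def episode_return(penalized_steps: int) -> float:
    """Total reward of an episode that ends with the snake eating the apple."""
    return penalized_steps * STEP_REWARD + APPLE_REWARD


print(f"{episode_return(5):.2f}")  # prints 0.80
```

Any path that needs more penalized steps would score lower, which is why 0.80 is the best total reward for this start and apple position.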

Comparing the results of the random policies and the optimal policies shows a clear difference.

In [228]:
mean_r, std_r, percentage_r = results_random
mean_o, std_o, percentage_o = results_optimal

print(f"Random policy: mean={mean_r:.2f}, std={std_r:.2f}, reached apple={percentage_r:.2f}%")
print(f"Optimal policy: mean={mean_o:.2f}, std={std_o:.2f}, reached apple={percentage_o:.2f}%")

Random policy: mean=-0.98, std=0.48, reached apple=7.00%
Optimal policy: mean=0.80, std=0.00, reached apple=100.00%


The comparison shows that on a 7 by 7 grid, the random policy reaches the apple in only 7% of the episodes, compared with 100% for the optimal policy, a gap of 93 percentage points. This gap is expected to widen as the size of the grid increases.

Of course, the computational time of value iteration will also increase significantly as the grid size increases.
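
To get a rough feel for this growth, the sketch below computes an upper bound on the number of states. It assumes that the snake head and the apple can each occupy any cell of a square grid and that there are four directions. This is an assumption made for illustration and is not taken from the notebook's `mdp` implementation, which may exclude some of these combinations. The function name `state_count_upper_bound` is hypothetical.

```python
def state_count_upper_bound(grid_size: int) -> int:
    """Upper bound on the number of states of a square snake grid."""
    cells = grid_size * grid_size
    directions = 4  # Left, Right, Up, Down
    # snake head position x apple position x snake direction
    return cells * cells * directions


for n in (7, 10, 20):
    print(n, state_count_upper_bound(n))
# 7 9604
# 10 40000
# 20 640000
```

Under these assumptions the bound grows with the fourth power of the side length, and every sweep of value iteration has to update each of these states.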

Fortunately, there are other ways of solving the Bellman equation, such as policy iteration.