# Frozen Lake
Implementation of SARSA and Q-Learning on the Frozen Lake game:

> The agent controls the movement of a character in a grid world. Some tiles of the grid are walkable, and others lead to the agent falling into the water. Additionally, the movement direction of the agent is uncertain and only partially depends on the chosen direction. The agent is rewarded for finding a walkable path to a goal tile. [openai.com](https://gym.openai.com/envs/FrozenLake-v0)

| FROZEN                      | L | A | K | E |
|-----------------------------|---|---|---|---|
| **S:** starting point, safe | S | F | F | F |
| **F:** frozen surface, safe | F | H | F | H |
|                 **H:** hole | F | F | F | H |
|                 **G:** goal | H | F | F | G |

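- The ice is slippery: every step, the agent moves in the chosen direction with probability $1/3$ and in each of the two perpendicular directions with probability $1/3$, so it slips with probability $2/3$.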
- **H** and **G** are terminal states.
- Reaching **G** provides a reward of $1$.
- Actions are `left`, `down`, `right` and `up`.
- Environment is considered solved when reaching an average reward of $0.78$ over 100 consecutive episodes.
- The [optimum](https://github.com/openai/gym/blob/37efc3e7d876b7f28c4d675c87137a293c33e2d1/gym/envs/__init__.py#L149-L163) policy achieves an average reward of about $0.8196$ over 100 consecutive episodes.
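The slipping rule can be sketched in a few lines. This is a minimal illustration assuming the $1/3$ split described above; the helper names `executed_actions` and `sample_executed_action` are hypothetical and not part of gym:

```python
import random

ACTIONS = ['left', 'down', 'right', 'up']


def executed_actions(action):
    """Indices of the actions that may actually be executed for `action`.

    Assumption: slippery FrozenLake moves in the intended direction or in one
    of the two perpendicular directions, each with probability 1/3.
    """
    # With the ordering left, down, right, up, the neighbours (a - 1) % 4 and
    # (a + 1) % 4 are exactly the two directions perpendicular to a.
    return [(action - 1) % 4, action, (action + 1) % 4]


def sample_executed_action(action, rng=random):
    """Draw the executed action uniformly from the three candidates."""
    return rng.choice(executed_actions(action))


# Chosen 'left': the agent may go up, left or down, but never right.
print([ACTIONS[a] for a in executed_actions(ACTIONS.index('left'))])
# prints ['up', 'left', 'down']
```

The opposite direction never occurs, which is why a carefully chosen action (for example pushing against a wall) can still keep the agent away from a neighbouring hole.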

In [None]:
ACTIONS = ['left', 'down', 'right', 'up']

### Helper Functions

In [None]:
%matplotlib inline
import random
import sys
from collections import deque

import gym
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns


def evaluator():
    """Pretty print live training stats."""
    rewards = deque([], 100)
    vals = [0]

    def stats(episode, reward):
        rewards.append(reward)
        if len(rewards) == 100:
            mean = np.mean(rewards)
            if mean >= .78 and vals[0] < .78:
                sys.stdout.write(' ' * 50 + '\r')  # clear the status line
                sys.stdout.write('Beat .78 at episode {}.\n'.format(episode))
                sys.stdout.flush()
            if mean > vals[0]:
                vals[0] = mean
                sys.stdout.write('Current 100 episode highscore at {}: {:.2f}\r'
                                 .format(episode, mean))
                sys.stdout.flush()
    return stats


def plot_policy(qtable):
    """Plot the greedy argmax policy values."""
    plt.figure()
    plt.title('Policy')
    values = qtable.max(axis=1).reshape(4, 4)
    policy = qtable.argmax(axis=1).reshape(4, 4)
    labels = np.asarray(ACTIONS)[policy]
    annot = np.char.add(labels, np.char.mod('\n%.5f', values))
    sns.heatmap(values, annot=annot, mask=values == 0, fmt='s', cbar=False)


def plot_values(qtable):
    """Plot all q values in a single big heatmap."""
    left, down, right, up = np.split(qtable, 4, 1)
    qmap = np.zeros((4 * 3, 4 * 3))
    qmap[0::3, 1::3] = up.squeeze().reshape(4, 4)
    qmap[1::3, 0::3] = left.squeeze().reshape(4, 4)
    qmap[1::3, 2::3] = right.squeeze().reshape(4, 4)
    qmap[2::3, 1::3] = down.squeeze().reshape(4, 4)
    plt.figure(figsize=(10, 10))
    plt.title('Transition values')
    annotations = np.char.mod('%.5f', qmap)
    mask = qmap == 0
    ax = sns.heatmap(qmap, annot=annotations, mask=mask, fmt='s', cbar=False)
    plt.hlines([3, 6, 9], *ax.get_xlim())
    plt.vlines([3, 6, 9], *ax.get_ylim())
    plt.xticks(np.arange(1.5, 12, 3), range(0, 4))
    plt.yticks(np.arange(1.5, 12, 3), range(0, 4))


# def subplot_values(qtable, action_labels):
#     """Plot q values in individual subplots per action."""
#     fig = plt.figure(figsize=(10, 10))
#     for i, actvals in enumerate(np.split(qtable, 4, 1)):
#         actvals = actvals.squeeze().reshape(4, 4)
#         fig.add_subplot(221 + i, title='Values for: {}'.format(action_labels[i]))
#         sns.heatmap(actvals, annot=True, mask=actvals == 0, fmt='.5f', cbar=False)
#     return fig


def epsilon(episodes, final=.1, initial=1):
    """Provide a function for linear value annealing."""
    def anneal(episode):
        diff = initial - final
        return max([final, initial - (diff * episode / episodes)])

    return anneal

### Frozen Lake: Sample Run

In [None]:
env = gym.make('FrozenLake-v0')
state = env.reset()
env.render()
terminal = False
while not terminal:
    action = env.action_space.sample()
    _, _, terminal, _ = env.step(action)
    print('\nAction selected: {}\n'.format(ACTIONS[action]))
    env.render()

# Deterministic Q-Learning
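Learning in a deterministic environment is generally easier, so for this first implementation we turn off the slipperiness of the ice (`is_slippery=False`). Because every action then has a single, known outcome, each Q value can be set directly to the immediate reward plus the discounted maximum Q value of the next state, without a learning rate: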

$$Q(s_t,a_t) \leftarrow r_t + \gamma \max_{a} Q(s_{t+1},a)$$

- Single table of Q values as data structure (states * actions)
- $\epsilon$-greedy policy with linearly decaying $\epsilon$
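To see how this update propagates the goal reward backwards, here is a minimal sketch on a hypothetical four-state corridor (not FrozenLake; the `step` function, the two actions and the state layout are illustrative assumptions). Instead of sampling episodes, it applies the update above to every state-action pair for a few sweeps, which is enough in a deterministic world:

```python
import numpy as np

# Hypothetical toy problem (not FrozenLake): a deterministic corridor of
# 4 states, 0 -> 1 -> 2 -> 3, where state 3 is a terminal goal.
# Two actions: 0 = stay, 1 = move right. Reaching the goal pays 1.
N_STATES, N_ACTIONS, GOAL = 4, 2, 3
gamma = 0.9


def step(state, action):
    """Deterministic transition: return (next_state, reward, terminal)."""
    next_state = min(state + action, GOAL)
    reward = 1.0 if next_state == GOAL else 0.0
    return next_state, reward, next_state == GOAL


qtable = np.zeros((N_STATES, N_ACTIONS))
# Apply the deterministic update to every non-terminal state-action pair a few
# times; the terminal row stays 0 because no action is taken from the goal.
for _ in range(10):
    for state in range(GOAL):
        for action in range(N_ACTIONS):
            state_, reward, terminal = step(state, action)
            qtable[state, action] = reward + gamma * qtable[state_].max()

print(qtable[:GOAL, 1])  # value of moving right from states 0, 1, 2
```

After the sweeps, moving right is worth roughly $0.81$, $0.9$ and $1$ from states 0, 1 and 2: the goal reward discounted once per remaining step. The goal row stays at zero because no action is ever taken from a terminal state, which is also why the entries for **H** and **G** in the notebook's `qtable` remain 0.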


In [None]:
if 'FrozenLakeDeterministic-v0' not in gym.envs.registry.env_specs:
    gym.envs.registration.register(
        id='FrozenLakeDeterministic-v0',
        entry_point='gym.envs.toy_text:FrozenLakeEnv',
        kwargs={'map_name' : '4x4', 'is_slippery': False},
        max_episode_steps=100,
        reward_threshold=1, # optimum = 1
    )

In [None]:
env = gym.make('FrozenLakeDeterministic-v0')

episodes = 1000
eps = epsilon(episodes / 2, initial=1, final=0)
gamma = 0.9

qtable = np.zeros((4 * 4, 4))  # STATES (4 x 4 grid cells) x ACTIONS
stats = evaluator()

for episode in range(episodes):
    state = env.reset()
    terminal = False
    while not terminal:
        if random.random() < eps(episode):
            action = env.action_space.sample()
        else:
            action = qtable[state].argmax()
        state_, reward, terminal, _ = env.step(action)
        qtable[state, action] = reward + gamma * qtable[state_].max()
        state = state_
    stats(episode, reward)

env.close()
plot_policy(qtable)

# SARSA

**S**tate, **A**ction, **R**eward, **S**tate', **A**ction'

$$Q(s_t,a_t) \leftarrow Q(s_t,a_t) + \alpha [r_t + \gamma Q(s_{t+1}, a_{t+1}) - Q(s_t,a_t)]$$

- Single table of Q values as data structure (states * actions)
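- Exploration by adding Gaussian noise with linearly decaying standard deviation $\epsilon$ to the Q values of the current state and acting greedily on the result (a noisy alternative to $\epsilon$-greedy)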
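As a worked example of the SARSA update, the sketch below applies it once to a single transition. The helper `sarsa_step` and all numbers are hypothetical; unlike the notebook's `sarsa` function it takes the table and hyperparameters as arguments instead of reading globals:

```python
import numpy as np


def sarsa_step(qtable, state, action, reward, state_, action_, alpha, gamma):
    """Apply one SARSA update in place and return the TD error."""
    # The target bootstraps on the action the policy actually picked in state_.
    td_target = reward + gamma * qtable[state_, action_]
    td_error = td_target - qtable[state, action]
    qtable[state, action] += alpha * td_error
    return td_error


# Hypothetical numbers for a single transition (not taken from a real run).
q = np.zeros((2, 2))
q[0, 1] = 0.5   # current estimate Q(s, a)
q[1, 0] = 0.4   # Q(s', a') for the next action a' chosen by the policy
delta = sarsa_step(q, state=0, action=1, reward=0.0, state_=1, action_=0,
                   alpha=0.85, gamma=0.9)
print(delta, q[0, 1])
```

With $\gamma = 0.9$ the target is $0.9 \cdot 0.4 = 0.36$, the TD error is $0.36 - 0.5 = -0.14$, and with $\alpha = 0.85$ the estimate drops to $0.5 - 0.85 \cdot 0.14 = 0.381$.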


In [None]:
def sarsa(state, action, reward, state_, action_):
    """Compute SARSA Q value update."""
    future_reward = gamma * qtable[state_, action_]
    return alpha * (reward + future_reward - qtable[state, action])

In [None]:
env = gym.make('FrozenLake-v0')

episodes = 3000
eps = epsilon(episodes / 2, initial=1, final=0)
gamma = 0.9
alpha = 0.85

qtable = np.zeros((4 * 4, 4))  # STATES (4 x 4 grid cells) x ACTIONS
stats = evaluator()

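for episode in range(episodes):
    state = env.reset()
    # Pick the first action with the noisy-greedy behaviour policy.
    action = np.random.normal(qtable[state], eps(episode)).argmax()
    terminal = False
    while not terminal:
        state_, reward, terminal, _ = env.step(action)
        # Pick the next action with the same policy; SARSA bootstraps on it...
        action_ = np.random.normal(qtable[state_], eps(episode)).argmax()
        qtable[state, action] += sarsa(state, action, reward, state_, action_)
        # ...and then actually executes it on the next step (on-policy).
        state, action = state_, action_
    stats(episode, reward)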

env.close()
plot_policy(qtable)
plot_values(qtable)

# Q-Learning
Very similar to SARSA, but the Q value of the next state is estimated with the greedy (exploitation) policy instead of the current behaviour policy. SARSA learns on-policy: the future Q value is that of the next action actually chosen by the current policy. Q-Learning learns off-policy: the future Q value is the maximum over all actions in the next state, regardless of which action is taken next.

$$Q(s_t,a_t) \leftarrow Q(s_t,a_t) + \alpha [r_t + \gamma \max_a Q(s_{t+1}, a) - Q(s_t,a_t)]$$

The reward $r_t$, as in SARSA, is the reward received when taking action $a_t$ in state $s_t$. Sometimes (e.g. on Wikipedia) this reward is denoted as $r_{t+1}$.
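The only difference from SARSA is the bootstrap term. The following sketch computes both targets for the same hypothetical transition; the Q values of the next state and the explored action are made up for illustration:

```python
import numpy as np

gamma = 0.9
# Hypothetical Q values of the next state s' for the actions left, down, right, up.
q_next = np.array([0.2, 0.7, 0.1, 0.4])
reward = 0.0
# Suppose the exploring behaviour policy happened to pick 'right' (index 2) in s'.
action_ = 2

sarsa_target = reward + gamma * q_next[action_]   # on-policy: uses the action taken
qlearning_target = reward + gamma * q_next.max()  # off-policy: uses the best action

print(sarsa_target, qlearning_target)
```

Here SARSA's target is $0.9 \cdot 0.1 = 0.09$ while Q-Learning's is $0.9 \cdot 0.7 = 0.63$. The two agree whenever the behaviour policy happens to pick the greedy action in $s_{t+1}$.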

In [None]:
def update(state, action, reward, state_):
    """Compute Q value update."""
    future_reward = gamma * qtable[state_].max()
    return alpha * (reward + future_reward - qtable[state, action])

In [None]:
env = gym.make('FrozenLake-v0')

episodes = 5000
eps = epsilon(episodes / 4, initial=1, final=0)
gamma = 0.9
alpha = 0.85

qtable = np.zeros((4 * 4, 4))  # STATES (4 x 4 grid cells) x ACTIONS
stats = evaluator()

for episode in range(episodes):
    state = env.reset()
    terminal = False
    while not terminal:
        action = np.random.normal(qtable[state], eps(episode)).argmax()
        state_, reward, terminal, _ = env.step(action)
        qtable[state, action] += update(state, action, reward, state_)
        state = state_
    stats(episode, reward)

env.close()
plot_policy(qtable)
plot_values(qtable)