# Demonstration code for "Building a complete RL system" lecture

## Introduction
This code demonstrates our implementation of SARSA for a deterministic FrozenLake task and accompanies the "Building a complete RL system" lecture.
The lecture is delivered as part of the [Reinforcement Learning (2020) course](http://www.drps.ed.ac.uk/19-20/dpt/cxinfr11010.html) at the University of Edinburgh.

## Environment: FrozenLake ... but less frozen
Let's first define our deterministic FrozenLake environment. Usually, ice is pretty slippery ... but we prefer it stable. Hence, we remove any slipping and stochasticity and comfortably solve the newly created, deterministic task.

In [1]:
from gym.envs.registration import register
import numpy as np

# register non-slippery / deterministic FrozenLake environment
register(
    id='FrozenLakeNotSlippery-v0',
    entry_point='gym.envs.toy_text:FrozenLakeEnv',
    kwargs={'map_name' : '4x4', 'is_slippery': False},
)
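
To make "deterministic" concrete, the following sketch mimics the movement rule on the 4x4 map without using gym at all: with `is_slippery` set to `False`, every action moves the agent exactly one cell in the chosen direction, and moves into a border leave it in place. The helper `next_position` is a hypothetical name introduced only for this illustration; the action ids (0 = left, 1 = down, 2 = right, 3 = up) follow the convention used throughout this notebook.

```python
# Illustrative sketch only: deterministic movement on a 4x4 grid.
# `next_position` is a hypothetical helper, not part of gym.
N_ROWS, N_COLS = 4, 4

def next_position(pos: int, act: int) -> int:
    """Return the cell (0..15) reached from `pos` by taking action `act`."""
    row, col = divmod(pos, N_COLS)
    if act == 0:    # left
        col = max(col - 1, 0)
    elif act == 1:  # down
        row = min(row + 1, N_ROWS - 1)
    elif act == 2:  # right
        col = min(col + 1, N_COLS - 1)
    elif act == 3:  # up
        row = max(row - 1, 0)
    else:
        raise ValueError(f"Invalid action {act}")
    return row * N_COLS + col

print(next_position(0, 2))  # one step right from the start cell
print(next_position(0, 0))  # moving left at the border keeps the agent in place
```

Positions are numbered row by row, so cell `pos` sits in row `pos // 4` and column `pos % 4`.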

## Human Interface
To get familiar with the task, we first create an interface that lets a human player play the game. After all, playing is fun! And it helps us understand the task at hand. So let's go ...

In [2]:
# imports for human player
import contextlib
import termios

@contextlib.contextmanager
def raw_mode(file):
    # temporarily disable echo and line buffering on the terminal behind `file` (Unix only)
    old_attrs = termios.tcgetattr(file.fileno())
    new_attrs = old_attrs[:]
    new_attrs[3] = new_attrs[3] & ~(termios.ECHO | termios.ICANON)
    try:
        termios.tcsetattr(file.fileno(), termios.TCSADRAIN, new_attrs)
        yield
    finally:
        termios.tcsetattr(file.fileno(), termios.TCSADRAIN, old_attrs)

In [3]:
def str_to_action(act: str) -> int:
    """
    Input transferred to action id (for FrozenLake)

    :param act (str): received input
    :return (int): action id for FrozenLake (and -1 for STOP)
    """
    if act == "A":
        return 0
    elif act == "S":
        return 1
    elif act == "D":
        return 2
    elif act == "W":
        return 3
    elif act == "STOP":
        return -1
    else:
        raise ValueError(f"Unknown input {act}!")
        
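def get_input() -> int:
    # prompt until the player enters a valid key; str_to_action maps STOP to -1
    while True:
        act = input("Choose action [WASD | STOP]").strip().upper()
        try:
            return str_to_action(act)
        except ValueError as err:
            # report the invalid input and ask again instead of aborting the game
            print(err)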

In [6]:
import gym

def human_player(env: gym.Env):
    """
    Play FrozenLake as a human player with WASD keys
    """
    print("Use WASD to move in the environment and end game with STOP or keyboard interrupt (Ctrl-C)")
    env.reset()
    env.render()

    while True:
        act = get_input()
        
        if act == -1:
            return
        _, rew, done, _ = env.step(act)
        env.render()
        if done:
            if rew == 1:
                print("EPISODE FINISHED - SOLVED")
            else:
                print("EPISODE FINISHED - FAILED")
            env.reset()
            env.render()

In [8]:
env = gym.make('FrozenLakeNotSlippery-v0')
human_player(env)
env.close()

Use WASD to move in the environment and end game with STOP or keyboard interrupt (Ctrl-C)

SFFF
FHFH
FFFH
HFFG
Choose action [WASD | STOP]STOP


## SARSA Agent
Now that we understand the FrozenLake task, let's implement an agent for the on-policy TD control algorithm, also known as SARSA. If you need a refresher, have a look at [lecture 6 on Temporal Difference Learning](https://www.learn.ed.ac.uk/bbcswebdav/pid-4067850-dt-content-rid-11604420_1/xid-11604420_1) (slides 15-17) or at [section 6.4 in the RL book](http://www.incompleteideas.net/book/RLbook2018.pdf#page=153) on the SARSA method.
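
As a quick reminder, SARSA moves Q(s, a) towards the target r + gamma * Q(s', a'), where a' is the action the agent actually takes next, and the bootstrap term is dropped when s' is terminal. The sketch below works through two such updates by hand. The helper `sarsa_update` is a hypothetical standalone function written for this illustration, and the values alpha = 0.1 and gamma = 0.99 are simply the ones used in the configuration of this notebook.

```python
# Illustrative sketch only: the SARSA update on plain numbers.
# `sarsa_update` is a hypothetical helper, not the agent class of this notebook.
def sarsa_update(q, reward, q_next, done, alpha=0.1, gamma=0.99):
    """Return the updated Q(s, a) after observing (s, a, r, s', a')."""
    # terminal transitions do not bootstrap from the next state-action value
    target = reward + gamma * (1 - done) * q_next
    return q + alpha * (target - q)

# Step into the goal: reward 1, episode ends, so the target is just the reward.
q_goal_step = sarsa_update(q=0.0, reward=1.0, q_next=0.0, done=True)

# One step earlier: no reward, but the next state-action pair now has value.
q_prev_step = sarsa_update(q=0.0, reward=0.0, q_next=q_goal_step, done=False)

print(q_goal_step, q_prev_step)
```

The first value comes out at 0.1 and the second at roughly 0.0099, which shows how the goal reward propagates backwards one step at a time, scaled by alpha and gamma.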

In [16]:
# visualisation tools for final Q-tables
def visualise_q_table(q_table):
    """
    Print q_table in human-readable format

    :param q_table (Dict): q_table in form of a dict mapping (observation, action) pairs to
        q-values
    """
    for key in sorted(q_table.keys()):
        obs, act = key
        act_name = act_to_str(act)
        q_value = q_table[key]
        print(f"Pos={obs}\tAct={act_name}\t->\t{q_value}")

def act_to_str(act: int):
    if act == 0:
        return "L"
    elif act == 1:
        return "D"
    elif act == 2:
        return "R"
    elif act == 3:
        return "U"
    else:
        raise ValueError("Invalid action value")
        
def visualise_policy(q_table):
    """
    Given q_table print greedy policy for each FrozenLake position

    :param q_table (Dict): q_table in form of a dict mapping (observation, action) pairs to
        q-values
    """
    # extract best acts
    act_table = np.zeros((4,4))
    str_table = []
    for row in range(4):
        str_table.append("")
        for col in range(4):
            pos = row * 4 + col
            max_q = None
            max_a = None
            for a in range(4):
                q = q_table[(pos, a)]
                if max_q is None or q > max_q:
                    max_q = q
                    max_a = a
            act_table[row, col] = max_a
            str_table[row] += act_to_str(max_a)
    
    # print best actions in human_readable format
    print("\nAction selection table:")
    for row_str in str_table:
        print(row_str)
    print()
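
The greedy policy simply picks, for each position, the action with the highest Q-value. The following compact sketch shows the same idea on a toy Q-table. The helper `greedy_action` is a hypothetical name and the Q-values are made up purely for illustration.

```python
# Illustrative sketch only: greedy action selection on a toy Q-table.
# `greedy_action` is a hypothetical helper; the Q-values below are invented.
from collections import defaultdict

toy_q_table = defaultdict(float)  # unseen (position, action) pairs default to 0.0
toy_q_table[(0, 1)] = 0.7         # position 0, action 1 (down)
toy_q_table[(0, 2)] = 0.5         # position 0, action 2 (right)
toy_q_table[(1, 2)] = 0.3         # position 1, action 2 (right)

def greedy_action(q_table, pos, n_acts=4):
    """Return the action id with the highest Q-value at `pos` (first one wins ties)."""
    return max(range(n_acts), key=lambda a: q_table[(pos, a)])

for pos in range(3):
    print(pos, greedy_action(toy_q_table, pos))
```

Ties resolve to the lowest action id here, which matches the strict `>` comparison in `visualise_policy`. As a consequence, positions the agent never updated are displayed as "L".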

In [13]:
from abc import ABC
from collections import defaultdict
import random
from typing import DefaultDict


class SARSA(ABC):
    """Base class for SARSA agent

    :attr n_acts (int): number of actions
    :attr gamma (float): discount factor gamma
    :attr epsilon (float): epsilon hyperparameter for epsilon-greedy policy
    :attr alpha (float): learning rate alpha for updates
    :attr q_table (DefaultDict): table for Q-values mapping (OBS, ACT) pairs of observations
        and actions to respective Q-values
    """

    def __init__(
            self,
            num_acts: int,
            gamma: float,
            epsilon: float = 0.9,
            alpha: float = 0.1
        ):
        """Constructor for SARSA agent

        Initializes the basic variables of the agent, namely epsilon, the learning rate and
        the discount factor.

        :param num_acts (int): number of possible actions
        :param gamma (float): discount factor (gamma)
        :param epsilon (float): initial epsilon for epsilon-greedy action selection
        :param alpha (float): learning rate alpha
        """
        self.n_acts: int = num_acts
        self.gamma: float = gamma
        self.epsilon: float = epsilon
        self.alpha: float = alpha

        self.q_table: DefaultDict = defaultdict(lambda: 0)

    def act(self, obs: int) -> int:
        """Epsilon-greedy action selection 

        :param obs (int): received observation representing the current environmental state
            (for FrozenLake the index of the agent's position; it must be hashable because
            it is used as part of the Q-table key)
        :return (int): index of selected action
        """
        act_vals = [self.q_table[(obs, act)] for act in range(self.n_acts)]
        max_val = max(act_vals)
        max_acts = [idx for idx, act_val in enumerate(act_vals) if act_val == max_val]

        if random.random() < self.epsilon:
            return random.randint(0, self.n_acts - 1)
        else:
            return random.choice(max_acts)

    def learn(
            self,
            obs: int,
            action: int,
            reward: float,
            n_obs: int,
            n_action: int,
            done: bool
        ) -> float:
        """Updates the Q-table based on agent experience

        :param obs (int): received observation representing the current environmental state
        :param action (int): index of applied action
        :param reward (float): received reward
        :param n_obs (int): received observation representing the next environmental state
        :param n_action (int): index of the action selected in the next state
        :param done (bool): flag indicating whether a terminal state has been reached
        :return (float): updated Q-value for current observation-action pair
        """
        target_value = reward + self.gamma * (1 - done) * self.q_table[(n_obs, n_action)]
        self.q_table[(obs, action)] += self.alpha * (
            target_value - self.q_table[(obs, action)]
        )
        return self.q_table[(obs, action)]

    def schedule_hyperparameters(self, timestep: int, max_timestep: int):
        """Updates the hyperparameters

        This function is called after every environment step of the training loop and
        allows you to schedule your hyperparameters.

        :param timestep (int): number of environment steps taken so far
        :param max_timestep (int): horizon that the decay is scaled by
        """
        # linearly decay epsilon from 1.0 to 0.05 over the first 7% of max_timestep
        self.epsilon = 1.0-(min(1.0, timestep/(0.07*max_timestep)))*0.95
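
To see what this schedule does, the sketch below evaluates the same formula as a standalone function. The name `scheduled_epsilon` and its keyword parameters `start`, `end` and `fraction` are hypothetical and introduced only for this illustration; their defaults reproduce the constants of the formula above (1.0, 1.0 - 0.95 = 0.05, and 0.07).

```python
# Illustrative sketch only: the linear epsilon decay as a standalone function.
# `scheduled_epsilon`, `start`, `end` and `fraction` are hypothetical names.
def scheduled_epsilon(timestep, max_timestep, start=1.0, end=0.05, fraction=0.07):
    """Linearly decay from `start` to `end` over the first `fraction` of `max_timestep`."""
    progress = min(1.0, timestep / (fraction * max_timestep))
    return start - progress * (start - end)

# With max_timestep = 1000 the decay is finished after 70 timesteps.
for t in (0, 35, 70, 1000):
    print(t, round(scheduled_epsilon(t, 1000), 3))
```

Epsilon starts at 1.0 (fully random actions), passes through roughly 0.525 halfway along the decay, and stays at about 0.05 afterwards, so the agent keeps a small amount of exploration for the rest of training.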

In [14]:
from time import sleep

CONFIG = {
    "env": "FrozenLakeNotSlippery-v0",
    "target_reward": 1,
    "eval_solved_goal": 10,
    "total_eps": 1000,
    "eval_episodes": 1,
    "eval_freq": 10,
    "gamma": 0.99,
    "alpha": 0.1,
    "epsilon": 0.9,
}

RENDER = False

def evaluate(env, config, q_table, eval_episodes=10, render=False, output=True):
    """
    Evaluate configuration of SARSA on given environment initialised with given Q-table

    :param env (gym.Env): environment to execute evaluation on
    :param config (Dict[str, float]): configuration dictionary containing hyperparameters
    :param q_table (Dict[(Obs, Act), float]): Q-table mapping observation-action to Q-values
    :param eval_episodes (int): number of evaluation episodes
    :param render (bool): flag whether evaluation runs should be rendered
    :param output (bool): flag whether mean evaluation performance should be printed
    :return (float, float): mean and standard deviation of reward received over episodes
    """
    eval_agent = SARSA(
            num_acts=env.action_space.n,
            gamma=config["gamma"],
            epsilon=0.0, 
            alpha=config["alpha"],
    )
    eval_agent.q_table = q_table
    episodic_rewards = []
    for eps_num in range(eval_episodes):
        obs = env.reset()
        if render:
            env.render()
            sleep(1)
        episodic_reward = 0
        done = False

        while not done:
            act = eval_agent.act(obs)
            n_obs, reward, done, info = env.step(act)
            if render:
                env.render()
                sleep(1)

            episodic_reward += reward

            obs = n_obs

        episodic_rewards.append(episodic_reward)

    mean_reward = np.mean(episodic_rewards)
    std_reward = np.std(episodic_rewards)

    if output:
        print(f"EVALUATION: MEAN REWARD OF {mean_reward}")
        if mean_reward == 1.0:
            print("EVALUATION: SOLVED")
        else:
            print("EVALUATION: NOT SOLVED!")
    return mean_reward, std_reward


def train(env, config, output=True):
    """
    Train and evaluate SARSA on given environment with provided hyperparameters

    :param env (gym.Env): environment to execute evaluation on
    :param config (Dict[str, float]): configuration dictionary containing hyperparameters
    :param output (bool): flag if mean evaluation results should be printed
    :return (float, List[float], List[float], Dict[(Obs, Act), float]):
        total reward over all episodes, list of means and standard deviations of evaluation
        rewards, final Q-table
    """
    agent = SARSA(
            num_acts=env.action_space.n,
            gamma=config["gamma"],
            epsilon=config["epsilon"],
            alpha=config["alpha"],
    )

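    step_counter = 0
    # NOTE: the schedule horizon is set to the number of episodes, while step_counter counts
    # environment steps, so epsilon reaches its minimum after 0.07 * total_eps steps
    max_steps = config["total_eps"]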
    
    total_reward = 0
    evaluation_reward_means = []
    evaluation_reward_stds = []
    eval_solved = 0

    for eps_num in range(config["total_eps"]):
        obs = env.reset()
        episodic_reward = 0
        done = False

        # take first action
        act = agent.act(obs)

        while not done:
            n_obs, reward, done, info = env.step(act)
            step_counter += 1
            episodic_reward += reward

            agent.schedule_hyperparameters(step_counter, max_steps)
            n_act = agent.act(n_obs)
            agent.learn(obs, act, reward, n_obs, n_act, done)

            obs = n_obs
            act = n_act

        total_reward += episodic_reward

        if eps_num > 0 and eps_num % config["eval_freq"] == 0:
            mean_reward, std_reward = evaluate(
                    env,
                    config,
                    agent.q_table,
                    eval_episodes=config["eval_episodes"],
                    render=RENDER,
                    output=output
            )
            evaluation_reward_means.append(mean_reward)
            evaluation_reward_stds.append(std_reward)

            if mean_reward >= config["target_reward"]:
                eval_solved += 1
                if output:
                    print(f"Reached reward {mean_reward} >= {config['target_reward']} (target reward)")
                if eval_solved == config["eval_solved_goal"]:
                    if output:
                        print(f"Solved evaluation {eval_solved} times in a row --> terminate training")
                    break
            else:
                eval_solved = 0

    return total_reward, evaluation_reward_means, evaluation_reward_stds, agent.q_table

In [17]:
env = gym.make(CONFIG["env"])
    
total_reward, _, _, q_table = train(env, CONFIG)
# print("Q-table:")
# visualise_q_table(q_table)
print()
visualise_policy(q_table)

EVALUATION: MEAN REWARD OF 0.0
EVALUATION: NOT SOLVED!
EVALUATION: MEAN REWARD OF 0.0
EVALUATION: NOT SOLVED!
EVALUATION: MEAN REWARD OF 0.0
EVALUATION: NOT SOLVED!
EVALUATION: MEAN REWARD OF 0.0
EVALUATION: NOT SOLVED!
EVALUATION: MEAN REWARD OF 1.0
EVALUATION: SOLVED
Reached reward 1.0 >= 1 (target reward)
EVALUATION: MEAN REWARD OF 1.0
EVALUATION: SOLVED
Reached reward 1.0 >= 1 (target reward)
EVALUATION: MEAN REWARD OF 1.0
EVALUATION: SOLVED
Reached reward 1.0 >= 1 (target reward)
EVALUATION: MEAN REWARD OF 1.0
EVALUATION: SOLVED
Reached reward 1.0 >= 1 (target reward)
EVALUATION: MEAN REWARD OF 1.0
EVALUATION: SOLVED
Reached reward 1.0 >= 1 (target reward)
EVALUATION: MEAN REWARD OF 1.0
EVALUATION: SOLVED
Reached reward 1.0 >= 1 (target reward)
EVALUATION: MEAN REWARD OF 1.0
EVALUATION: SOLVED
Reached reward 1.0 >= 1 (target reward)
EVALUATION: MEAN REWARD OF 1.0
EVALUATION: SOLVED
Reached reward 1.0 >= 1 (target reward)
EVALUATION: MEAN REWARD OF 1.0
EVALUATION: SOLVED
Reached re